# label_ops.py
# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import copy
import numpy as np
import string
from shapely.geometry import LineString, Point, Polygon
import json
import copy
from random import sample

from ppocr.utils.logging import get_logger
from ppocr.data.imaug.vqa.augment import order_by_tbyx
  27. class ClsLabelEncode(object):
  28. def __init__(self, label_list, **kwargs):
  29. self.label_list = label_list
  30. def __call__(self, data):
  31. label = data['label']
  32. if label not in self.label_list:
  33. return None
  34. label = self.label_list.index(label)
  35. data['label'] = label
  36. return data
  37. class DetLabelEncode(object):
  38. #def __init__(self, **kwargs):
  39. # pass
  40. def __init__(self, label_list, num_classes=1, **kwargs):
  41. self.num_classes = num_classes
  42. self.label_list = []
  43. if label_list is not None:
  44. if isinstance(label_list, str):
  45. with open(label_list, "r+", encoding="utf-8") as f:
  46. for line in f.readlines():
  47. self.label_list.append(line.replace("\n", ""))
  48. else:
  49. self.label_list = label_list
  50. if num_classes != len(self.label_list):
  51. assert "label_list长度与num_classes长度不符合"
  52. def __call__(self, data):
  53. label = data['label']
  54. label = json.loads(label)
  55. nBox = len(label)
  56. boxes, txts, txt_tags = [], [], []
  57. classes = []
  58. for bno in range(0, nBox):
  59. box = label[bno]['points']
  60. txt = label[bno]['transcription']
  61. boxes.append(box)
  62. txts.append(txt)
  63. if txt in ['*', '###']:
  64. txt_tags.append(True)
  65. if self.num_classes > 1:
  66. classes.append(-2)
  67. else:
  68. txt_tags.append(False)
  69. if self.num_classes > 1:
  70. classes.append(int(self.label_list.index(txt)))
  71. if len(boxes) == 0:
  72. return None
  73. boxes = self.expand_points_num(boxes)
  74. boxes = np.array(boxes, dtype=np.float32)
  75. txt_tags = np.array(txt_tags, dtype=np.bool)
  76. classes = classes
  77. data['polys'] = boxes
  78. data['texts'] = txts
  79. data['ignore_tags'] = txt_tags
  80. if self.num_classes > 1:
  81. data['classes'] = classes
  82. return data
  83. def order_points_clockwise(self, pts):
  84. rect = np.zeros((4, 2), dtype="float32")
  85. s = pts.sum(axis=1)
  86. rect[0] = pts[np.argmin(s)]
  87. rect[2] = pts[np.argmax(s)]
  88. tmp = np.delete(pts, (np.argmin(s), np.argmax(s)), axis=0)
  89. diff = np.diff(np.array(tmp), axis=1)
  90. rect[1] = tmp[np.argmin(diff)]
  91. rect[3] = tmp[np.argmax(diff)]
  92. return rect
  93. def expand_points_num(self, boxes):
  94. max_points_num = 0
  95. for box in boxes:
  96. if len(box) > max_points_num:
  97. max_points_num = len(box)
  98. ex_boxes = []
  99. for box in boxes:
  100. ex_box = box + [box[-1]] * (max_points_num - len(box))
  101. ex_boxes.append(ex_box)
  102. return ex_boxes
  103. class BaseRecLabelEncode(object):
  104. """ Convert between text-label and text-index """
  105. def __init__(self,
  106. max_text_length,
  107. character_dict_path=None,
  108. use_space_char=False,
  109. lower=False):
  110. self.max_text_len = max_text_length
  111. self.beg_str = "sos"
  112. self.end_str = "eos"
  113. self.lower = lower
  114. if character_dict_path is None:
  115. logger = get_logger()
  116. logger.warning(
  117. "The character_dict_path is None, model can only recognize number and lower letters"
  118. )
  119. self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
  120. dict_character = list(self.character_str)
  121. self.lower = True
  122. else:
  123. self.character_str = []
  124. with open(character_dict_path, "rb") as fin:
  125. lines = fin.readlines()
  126. for line in lines:
  127. line = line.decode('utf-8').strip("\n").strip("\r\n")
  128. self.character_str.append(line)
  129. if use_space_char:
  130. self.character_str.append(" ")
  131. dict_character = list(self.character_str)
  132. dict_character = self.add_special_char(dict_character)
  133. self.dict = {}
  134. for i, char in enumerate(dict_character):
  135. self.dict[char] = i
  136. self.character = dict_character
  137. def add_special_char(self, dict_character):
  138. return dict_character
  139. def encode(self, text):
  140. """convert text-label into text-index.
  141. input:
  142. text: text labels of each image. [batch_size]
  143. output:
  144. text: concatenated text index for CTCLoss.
  145. [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
  146. length: length of each text. [batch_size]
  147. """
  148. if len(text) == 0 or len(text) > self.max_text_len:
  149. return None
  150. if self.lower:
  151. text = text.lower()
  152. text_list = []
  153. for char in text:
  154. if char not in self.dict:
  155. # logger = get_logger()
  156. # logger.warning('{} is not in dict'.format(char))
  157. continue
  158. text_list.append(self.dict[char])
  159. if len(text_list) == 0:
  160. return None
  161. return text_list
  162. class CTCLabelEncode(BaseRecLabelEncode):
  163. """ Convert between text-label and text-index """
  164. def __init__(self,
  165. max_text_length,
  166. character_dict_path=None,
  167. use_space_char=False,
  168. **kwargs):
  169. super(CTCLabelEncode, self).__init__(
  170. max_text_length, character_dict_path, use_space_char)
  171. def __call__(self, data):
  172. text = data['label']
  173. text = self.encode(text)
  174. if text is None:
  175. return None
  176. data['length'] = np.array(len(text))
  177. text = text + [0] * (self.max_text_len - len(text))
  178. data['label'] = np.array(text)
  179. label = [0] * len(self.character)
  180. for x in text:
  181. label[x] += 1
  182. data['label_ace'] = np.array(label)
  183. return data
  184. def add_special_char(self, dict_character):
  185. dict_character = ['blank'] + dict_character
  186. return dict_character
  187. class E2ELabelEncodeTest(BaseRecLabelEncode):
  188. def __init__(self,
  189. max_text_length,
  190. character_dict_path=None,
  191. use_space_char=False,
  192. **kwargs):
  193. super(E2ELabelEncodeTest, self).__init__(
  194. max_text_length, character_dict_path, use_space_char)
  195. def __call__(self, data):
  196. import json
  197. padnum = len(self.dict)
  198. label = data['label']
  199. label = json.loads(label)
  200. nBox = len(label)
  201. boxes, txts, txt_tags = [], [], []
  202. for bno in range(0, nBox):
  203. box = label[bno]['points']
  204. txt = label[bno]['transcription']
  205. boxes.append(box)
  206. txts.append(txt)
  207. if txt in ['*', '###']:
  208. txt_tags.append(True)
  209. else:
  210. txt_tags.append(False)
  211. boxes = np.array(boxes, dtype=np.float32)
  212. txt_tags = np.array(txt_tags, dtype=np.bool_)
  213. data['polys'] = boxes
  214. data['ignore_tags'] = txt_tags
  215. temp_texts = []
  216. for text in txts:
  217. text = text.lower()
  218. text = self.encode(text)
  219. if text is None:
  220. return None
  221. text = text + [padnum] * (self.max_text_len - len(text)
  222. ) # use 36 to pad
  223. temp_texts.append(text)
  224. data['texts'] = np.array(temp_texts)
  225. return data
  226. class E2ELabelEncodeTrain(object):
  227. def __init__(self, **kwargs):
  228. pass
  229. def __call__(self, data):
  230. import json
  231. label = data['label']
  232. label = json.loads(label)
  233. nBox = len(label)
  234. boxes, txts, txt_tags = [], [], []
  235. for bno in range(0, nBox):
  236. box = label[bno]['points']
  237. txt = label[bno]['transcription']
  238. boxes.append(box)
  239. txts.append(txt)
  240. if txt in ['*', '###']:
  241. txt_tags.append(True)
  242. else:
  243. txt_tags.append(False)
  244. boxes = np.array(boxes, dtype=np.float32)
  245. txt_tags = np.array(txt_tags, dtype=np.bool_)
  246. data['polys'] = boxes
  247. data['texts'] = txts
  248. data['ignore_tags'] = txt_tags
  249. return data
class KieLabelEncode(object):
    """Encode key-information-extraction (KIE) annotations into the
    fixed-size numpy tensors consumed by the model (boxes, pairwise
    relations, text indices, labels and edge links).

    Args:
        character_dict_path: text file with one character per line; index 0
            is reserved for the empty string.
        class_path: text file with one class name per line; the line number
            becomes the class id.
        norm: divisor used to normalize relative box offsets.
        directed: if True, keep an edge only where it agrees with the label.
    """

    def __init__(self,
                 character_dict_path,
                 class_path,
                 norm=10,
                 directed=False,
                 **kwargs):
        super(KieLabelEncode, self).__init__()
        self.dict = dict({'': 0})  # char -> index; '' reserved at 0
        self.label2classid_map = dict()
        with open(character_dict_path, 'r', encoding='utf-8') as fr:
            idx = 1
            for line in fr:
                char = line.strip()
                self.dict[char] = idx
                idx += 1
        with open(class_path, "r") as fin:
            lines = fin.readlines()
            for idx, line in enumerate(lines):
                line = line.strip("\n")
                self.label2classid_map[line] = idx
        self.norm = norm
        self.directed = directed

    def compute_relation(self, boxes):
        """Compute relation between every two boxes."""
        # boxes is an (n, 8) int array of 4 corner points per box; corners
        # 0 and 2 give the (x1, y1) / (x2, y2) extent of each box.
        x1s, y1s = boxes[:, 0:1], boxes[:, 1:2]
        x2s, y2s = boxes[:, 4:5], boxes[:, 5:6]
        ws, hs = x2s - x1s + 1, np.maximum(y2s - y1s + 1, 1)
        # Pairwise offsets normalized by self.norm, plus width/height ratios.
        dxs = (x1s[:, 0][None] - x1s) / self.norm
        dys = (y1s[:, 0][None] - y1s) / self.norm
        xhhs, xwhs = hs[:, 0][None] / hs, ws[:, 0][None] / hs
        whs = ws / hs + np.zeros_like(xhhs)
        relations = np.stack([dxs, dys, whs, xhhs, xwhs], -1)
        bboxes = np.concatenate([x1s, y1s, x2s, y2s], -1).astype(np.float32)
        return relations, bboxes

    def pad_text_indices(self, text_inds):
        """Pad text index to same length."""
        max_len = 300
        recoder_len = max([len(text_ind) for text_ind in text_inds])
        # Unused slots are filled with -1.
        padded_text_inds = -np.ones((len(text_inds), max_len), np.int32)
        for idx, text_ind in enumerate(text_inds):
            padded_text_inds[idx, :len(text_ind)] = np.array(text_ind)
        return padded_text_inds, recoder_len

    def list_to_numpy(self, ann_infos):
        """Convert bboxes, relations, texts and labels to ndarray."""
        boxes, text_inds = ann_infos['points'], ann_infos['text_inds']
        boxes = np.array(boxes, np.int32)
        relations, bboxes = self.compute_relation(boxes)

        labels = ann_infos.get('labels', None)
        if labels is not None:
            labels = np.array(labels, np.int32)
            edges = ann_infos.get('edges', None)
            if edges is not None:
                labels = labels[:, None]
                edges = np.array(edges)
                # Boxes sharing the same edge id are linked (1), else 0.
                edges = (edges[:, None] == edges[None, :]).astype(np.int32)
                if self.directed:
                    # NOTE(review): `==` binds tighter than `&`, so this
                    # evaluates as `edges & (labels == 1)` — confirm that is
                    # the intended semantics for directed edges.
                    edges = (edges & labels == 1).astype(np.int32)
                np.fill_diagonal(edges, -1)  # self-links marked as ignored
                labels = np.concatenate([labels, edges], -1)
        padded_text_inds, recoder_len = self.pad_text_indices(text_inds)
        # Everything below is padded into fixed 300-slot tensors so batches
        # stack; `tag` records the true box count and longest text length.
        max_num = 300
        temp_bboxes = np.zeros([max_num, 4])
        h, _ = bboxes.shape
        temp_bboxes[:h, :] = bboxes

        temp_relations = np.zeros([max_num, max_num, 5])
        temp_relations[:h, :h, :] = relations

        # Text padding width (300 in pad_text_indices) equals max_num here.
        temp_padded_text_inds = np.zeros([max_num, max_num])
        temp_padded_text_inds[:h, :] = padded_text_inds

        temp_labels = np.zeros([max_num, max_num])
        temp_labels[:h, :h + 1] = labels

        tag = np.array([h, recoder_len])
        return dict(
            image=ann_infos['image'],
            points=temp_bboxes,
            relations=temp_relations,
            texts=temp_padded_text_inds,
            labels=temp_labels,
            tag=tag)

    def convert_canonical(self, points_x, points_y):
        """Rotate the 4 corners (cyclic order preserved) so the corner
        nearest the polygon's top-left bound comes first."""
        assert len(points_x) == 4
        assert len(points_y) == 4
        points = [Point(points_x[i], points_y[i]) for i in range(4)]
        polygon = Polygon([(p.x, p.y) for p in points])
        min_x, min_y, _, _ = polygon.bounds
        # Distance from each corner to the bounding box's top-left corner.
        points_to_lefttop = [
            LineString([points[i], Point(min_x, min_y)]) for i in range(4)
        ]
        distances = np.array([line.length for line in points_to_lefttop])
        sort_dist_idx = np.argsort(distances)
        lefttop_idx = sort_dist_idx[0]
        if lefttop_idx == 0:
            point_orders = [0, 1, 2, 3]
        elif lefttop_idx == 1:
            point_orders = [1, 2, 3, 0]
        elif lefttop_idx == 2:
            point_orders = [2, 3, 0, 1]
        else:
            point_orders = [3, 0, 1, 2]
        sorted_points_x = [points_x[i] for i in point_orders]
        sorted_points_y = [points_y[j] for j in point_orders]
        return sorted_points_x, sorted_points_y

    def sort_vertex(self, points_x, points_y):
        """Sort the 4 corners by angle around their centroid, then
        canonicalize so the top-left-most corner comes first."""
        assert len(points_x) == 4
        assert len(points_y) == 4
        x = np.array(points_x)
        y = np.array(points_y)
        center_x = np.sum(x) * 0.25
        center_y = np.sum(y) * 0.25
        x_arr = np.array(x - center_x)
        y_arr = np.array(y - center_y)
        # Angle of each corner around the centroid determines the order.
        angle = np.arctan2(y_arr, x_arr) * 180.0 / np.pi
        sort_idx = np.argsort(angle)
        sorted_points_x, sorted_points_y = [], []
        for i in range(4):
            sorted_points_x.append(points_x[sort_idx[i]])
            sorted_points_y.append(points_y[sort_idx[i]])
        return self.convert_canonical(sorted_points_x, sorted_points_y)

    def __call__(self, data):
        """Parse the JSON annotation in data['label'] and return the padded
        tensor dict produced by list_to_numpy."""
        import json
        label = data['label']
        annotations = json.loads(label)
        boxes, texts, text_inds, labels, edges = [], [], [], [], []
        for ann in annotations:
            box = ann['points']
            x_list = [box[i][0] for i in range(4)]
            y_list = [box[i][1] for i in range(4)]
            sorted_x_list, sorted_y_list = self.sort_vertex(x_list, y_list)
            # Flatten sorted corners into [x0, y0, x1, y1, x2, y2, x3, y3].
            sorted_box = []
            for x, y in zip(sorted_x_list, sorted_y_list):
                sorted_box.append(x)
                sorted_box.append(y)
            boxes.append(sorted_box)
            text = ann['transcription']
            texts.append(ann['transcription'])
            # Characters not present in the dict are dropped.
            text_ind = [self.dict[c] for c in text if c in self.dict]
            text_inds.append(text_ind)
            if 'label' in ann.keys():
                labels.append(self.label2classid_map[ann['label']])
            elif 'key_cls' in ann.keys():
                labels.append(ann['key_cls'])
            else:
                raise ValueError(
                    "Cannot found 'key_cls' in ann.keys(), please check your training annotation."
                )
            edges.append(ann.get('edge', 0))
        ann_infos = dict(
            image=data['image'],
            points=boxes,
            texts=texts,
            text_inds=text_inds,
            edges=edges,
            labels=labels)
        return self.list_to_numpy(ann_infos)
  404. class AttnLabelEncode(BaseRecLabelEncode):
  405. """ Convert between text-label and text-index """
  406. def __init__(self,
  407. max_text_length,
  408. character_dict_path=None,
  409. use_space_char=False,
  410. **kwargs):
  411. super(AttnLabelEncode, self).__init__(
  412. max_text_length, character_dict_path, use_space_char)
  413. def add_special_char(self, dict_character):
  414. self.beg_str = "sos"
  415. self.end_str = "eos"
  416. dict_character = [self.beg_str] + dict_character + [self.end_str]
  417. return dict_character
  418. def __call__(self, data):
  419. text = data['label']
  420. text = self.encode(text)
  421. if text is None:
  422. return None
  423. if len(text) >= self.max_text_len:
  424. return None
  425. data['length'] = np.array(len(text))
  426. text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
  427. - len(text) - 2)
  428. data['label'] = np.array(text)
  429. return data
  430. def get_ignored_tokens(self):
  431. beg_idx = self.get_beg_end_flag_idx("beg")
  432. end_idx = self.get_beg_end_flag_idx("end")
  433. return [beg_idx, end_idx]
  434. def get_beg_end_flag_idx(self, beg_or_end):
  435. if beg_or_end == "beg":
  436. idx = np.array(self.dict[self.beg_str])
  437. elif beg_or_end == "end":
  438. idx = np.array(self.dict[self.end_str])
  439. else:
  440. assert False, "Unsupport type %s in get_beg_end_flag_idx" \
  441. % beg_or_end
  442. return idx
  443. class RFLLabelEncode(BaseRecLabelEncode):
  444. """ Convert between text-label and text-index """
  445. def __init__(self,
  446. max_text_length,
  447. character_dict_path=None,
  448. use_space_char=False,
  449. **kwargs):
  450. super(RFLLabelEncode, self).__init__(
  451. max_text_length, character_dict_path, use_space_char)
  452. def add_special_char(self, dict_character):
  453. self.beg_str = "sos"
  454. self.end_str = "eos"
  455. dict_character = [self.beg_str] + dict_character + [self.end_str]
  456. return dict_character
  457. def encode_cnt(self, text):
  458. cnt_label = [0.0] * len(self.character)
  459. for char_ in text:
  460. cnt_label[char_] += 1
  461. return np.array(cnt_label)
  462. def __call__(self, data):
  463. text = data['label']
  464. text = self.encode(text)
  465. if text is None:
  466. return None
  467. if len(text) >= self.max_text_len:
  468. return None
  469. cnt_label = self.encode_cnt(text)
  470. data['length'] = np.array(len(text))
  471. text = [0] + text + [len(self.character) - 1] + [0] * (self.max_text_len
  472. - len(text) - 2)
  473. if len(text) != self.max_text_len:
  474. return None
  475. data['label'] = np.array(text)
  476. data['cnt_label'] = cnt_label
  477. return data
  478. def get_ignored_tokens(self):
  479. beg_idx = self.get_beg_end_flag_idx("beg")
  480. end_idx = self.get_beg_end_flag_idx("end")
  481. return [beg_idx, end_idx]
  482. def get_beg_end_flag_idx(self, beg_or_end):
  483. if beg_or_end == "beg":
  484. idx = np.array(self.dict[self.beg_str])
  485. elif beg_or_end == "end":
  486. idx = np.array(self.dict[self.end_str])
  487. else:
  488. assert False, "Unsupport type %s in get_beg_end_flag_idx" \
  489. % beg_or_end
  490. return idx
  491. class SEEDLabelEncode(BaseRecLabelEncode):
  492. """ Convert between text-label and text-index """
  493. def __init__(self,
  494. max_text_length,
  495. character_dict_path=None,
  496. use_space_char=False,
  497. **kwargs):
  498. super(SEEDLabelEncode, self).__init__(
  499. max_text_length, character_dict_path, use_space_char)
  500. def add_special_char(self, dict_character):
  501. self.padding = "padding"
  502. self.end_str = "eos"
  503. self.unknown = "unknown"
  504. dict_character = dict_character + [
  505. self.end_str, self.padding, self.unknown
  506. ]
  507. return dict_character
  508. def __call__(self, data):
  509. text = data['label']
  510. text = self.encode(text)
  511. if text is None:
  512. return None
  513. if len(text) >= self.max_text_len:
  514. return None
  515. data['length'] = np.array(len(text)) + 1 # conclude eos
  516. text = text + [len(self.character) - 3] + [len(self.character) - 2] * (
  517. self.max_text_len - len(text) - 1)
  518. data['label'] = np.array(text)
  519. return data
  520. class SRNLabelEncode(BaseRecLabelEncode):
  521. """ Convert between text-label and text-index """
  522. def __init__(self,
  523. max_text_length=25,
  524. character_dict_path=None,
  525. use_space_char=False,
  526. **kwargs):
  527. super(SRNLabelEncode, self).__init__(
  528. max_text_length, character_dict_path, use_space_char)
  529. def add_special_char(self, dict_character):
  530. dict_character = dict_character + [self.beg_str, self.end_str]
  531. return dict_character
  532. def __call__(self, data):
  533. text = data['label']
  534. text = self.encode(text)
  535. char_num = len(self.character)
  536. if text is None:
  537. return None
  538. if len(text) > self.max_text_len:
  539. return None
  540. data['length'] = np.array(len(text))
  541. text = text + [char_num - 1] * (self.max_text_len - len(text))
  542. data['label'] = np.array(text)
  543. return data
  544. def get_ignored_tokens(self):
  545. beg_idx = self.get_beg_end_flag_idx("beg")
  546. end_idx = self.get_beg_end_flag_idx("end")
  547. return [beg_idx, end_idx]
  548. def get_beg_end_flag_idx(self, beg_or_end):
  549. if beg_or_end == "beg":
  550. idx = np.array(self.dict[self.beg_str])
  551. elif beg_or_end == "end":
  552. idx = np.array(self.dict[self.end_str])
  553. else:
  554. assert False, "Unsupport type %s in get_beg_end_flag_idx" \
  555. % beg_or_end
  556. return idx
class TableLabelEncode(AttnLabelEncode):
    """Encode table-recognition labels: HTML structure tokens plus the
    bounding box of every cell, padded to a fixed length."""

    def __init__(self,
                 max_text_length,
                 character_dict_path,
                 replace_empty_cell_token=False,
                 merge_no_span_structure=False,
                 learn_empty_box=False,
                 loc_reg_num=4,
                 **kwargs):
        self.max_text_len = max_text_length
        self.lower = False
        self.learn_empty_box = learn_empty_box
        self.merge_no_span_structure = merge_no_span_structure
        self.replace_empty_cell_token = replace_empty_cell_token

        dict_character = []
        with open(character_dict_path, "rb") as fin:
            lines = fin.readlines()
            for line in lines:
                line = line.decode('utf-8').strip("\n").strip("\r\n")
                dict_character.append(line)

        if self.merge_no_span_structure:
            # '<td>' is replaced by the merged '<td></td>' token.
            if "<td></td>" not in dict_character:
                dict_character.append("<td></td>")
            if "<td>" in dict_character:
                dict_character.remove("<td>")

        # add_special_char (inherited from AttnLabelEncode) prepends sos and
        # appends eos and sets self.beg_str / self.end_str.
        dict_character = self.add_special_char(dict_character)
        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i
        self.idx2char = {v: k for k, v in self.dict.items()}

        self.character = dict_character
        self.loc_reg_num = loc_reg_num
        self.pad_idx = self.dict[self.beg_str]  # sos doubles as padding here
        self.start_idx = self.dict[self.beg_str]
        self.end_idx = self.dict[self.end_str]

        # Structure tokens that correspond to a table cell (carry a bbox).
        self.td_token = ['<td>', '<td', '<eb></eb>', '<td></td>']
        # Maps the repr() of an empty cell's token list to a dedicated
        # empty-bbox token.
        self.empty_bbox_token_dict = {
            "[]": '<eb></eb>',
            "[' ']": '<eb1></eb1>',
            "['<b>', ' ', '</b>']": '<eb2></eb2>',
            "['\\u2028', '\\u2028']": '<eb3></eb3>',
            "['<sup>', ' ', '</sup>']": '<eb4></eb4>',
            "['<b>', '</b>']": '<eb5></eb5>',
            "['<i>', ' ', '</i>']": '<eb6></eb6>',
            "['<b>', '<i>', '</i>', '</b>']": '<eb7></eb7>',
            "['<b>', '<i>', ' ', '</i>', '</b>']": '<eb8></eb8>',
            "['<i>', '</i>']": '<eb9></eb9>',
            "['<b>', ' ', '\\u2028', ' ', '\\u2028', ' ', '</b>']":
            '<eb10></eb10>',
        }

    @property
    def _max_text_len(self):
        # Two extra slots for the sos / eos tokens.
        return self.max_text_len + 2

    def __call__(self, data):
        """Encode data['structure'] tokens and per-cell data['cells'] into
        'structure', 'bboxes' and 'bbox_masks' arrays; None if too long."""
        cells = data['cells']
        structure = data['structure']
        if self.merge_no_span_structure:
            structure = self._merge_no_span_structure(structure)
        if self.replace_empty_cell_token:
            structure = self._replace_empty_cell_token(structure, cells)
        # remove empty token and add " " to span token
        new_structure = []
        for token in structure:
            if token != '':
                if 'span' in token and token[0] != ' ':
                    token = ' ' + token
                new_structure.append(token)
        # encode structure
        structure = self.encode(new_structure)
        if structure is None:
            return None

        structure = [self.start_idx] + structure + [self.end_idx
                                                    ]  # add sos abd eos
        structure = structure + [self.pad_idx] * (self._max_text_len -
                                                  len(structure))  # pad
        structure = np.array(structure)

        data['structure'] = structure

        if len(structure) > self._max_text_len:
            return None

        # encode box
        bboxes = np.zeros(
            (self._max_text_len, self.loc_reg_num), dtype=np.float32)
        bbox_masks = np.zeros((self._max_text_len, 1), dtype=np.float32)

        bbox_idx = 0

        for i, token in enumerate(structure):
            if self.idx2char[token] in self.td_token:
                # Cells with tokens get their bbox copied in and the mask set;
                # empty cells only contribute when learn_empty_box is on.
                if 'bbox' in cells[bbox_idx] and len(cells[bbox_idx][
                        'tokens']) > 0:
                    bbox = cells[bbox_idx]['bbox'].copy()
                    bbox = np.array(bbox, dtype=np.float32).reshape(-1)
                    bboxes[i] = bbox
                    bbox_masks[i] = 1.0
                if self.learn_empty_box:
                    bbox_masks[i] = 1.0
                bbox_idx += 1
        data['bboxes'] = bboxes
        data['bbox_masks'] = bbox_masks
        return data

    def _merge_no_span_structure(self, structure):
        """
        This code is refer from:
        https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
        """
        # Replace each '<td>' + following token pair with one '<td></td>'.
        new_structure = []
        i = 0
        while i < len(structure):
            token = structure[i]
            if token == '<td>':
                token = '<td></td>'
                i += 1
            new_structure.append(token)
            i += 1
        return new_structure

    def _replace_empty_cell_token(self, token_list, cells):
        """
        This fun code is refer from:
        https://github.com/JiaquanYe/TableMASTER-mmocr/blob/master/table_recognition/data_preprocess.py
        """
        # Cells without a bbox are mapped to a dedicated empty-bbox token
        # chosen by the repr of their token list.
        bbox_idx = 0
        add_empty_bbox_token_list = []
        for token in token_list:
            if token in ['<td></td>', '<td', '<td>']:
                if 'bbox' not in cells[bbox_idx].keys():
                    content = str(cells[bbox_idx]['tokens'])
                    token = self.empty_bbox_token_dict[content]
                add_empty_bbox_token_list.append(token)
                bbox_idx += 1
            else:
                add_empty_bbox_token_list.append(token)
        return add_empty_bbox_token_list
  688. class TableMasterLabelEncode(TableLabelEncode):
  689. """ Convert between text-label and text-index """
  690. def __init__(self,
  691. max_text_length,
  692. character_dict_path,
  693. replace_empty_cell_token=False,
  694. merge_no_span_structure=False,
  695. learn_empty_box=False,
  696. loc_reg_num=4,
  697. **kwargs):
  698. super(TableMasterLabelEncode, self).__init__(
  699. max_text_length, character_dict_path, replace_empty_cell_token,
  700. merge_no_span_structure, learn_empty_box, loc_reg_num, **kwargs)
  701. self.pad_idx = self.dict[self.pad_str]
  702. self.unknown_idx = self.dict[self.unknown_str]
  703. @property
  704. def _max_text_len(self):
  705. return self.max_text_len
  706. def add_special_char(self, dict_character):
  707. self.beg_str = '<SOS>'
  708. self.end_str = '<EOS>'
  709. self.unknown_str = '<UKN>'
  710. self.pad_str = '<PAD>'
  711. dict_character = dict_character
  712. dict_character = dict_character + [
  713. self.unknown_str, self.beg_str, self.end_str, self.pad_str
  714. ]
  715. return dict_character
  716. class TableBoxEncode(object):
  717. def __init__(self, in_box_format='xyxy', out_box_format='xyxy', **kwargs):
  718. assert out_box_format in ['xywh', 'xyxy', 'xyxyxyxy']
  719. self.in_box_format = in_box_format
  720. self.out_box_format = out_box_format
  721. def __call__(self, data):
  722. img_height, img_width = data['image'].shape[:2]
  723. bboxes = data['bboxes']
  724. if self.in_box_format != self.out_box_format:
  725. if self.out_box_format == 'xywh':
  726. if self.in_box_format == 'xyxyxyxy':
  727. bboxes = self.xyxyxyxy2xywh(bboxes)
  728. elif self.in_box_format == 'xyxy':
  729. bboxes = self.xyxy2xywh(bboxes)
  730. bboxes[:, 0::2] /= img_width
  731. bboxes[:, 1::2] /= img_height
  732. data['bboxes'] = bboxes
  733. return data
  734. def xyxyxyxy2xywh(self, boxes):
  735. new_bboxes = np.zeros([len(bboxes), 4])
  736. new_bboxes[:, 0] = bboxes[:, 0::2].min() # x1
  737. new_bboxes[:, 1] = bboxes[:, 1::2].min() # y1
  738. new_bboxes[:, 2] = bboxes[:, 0::2].max() - new_bboxes[:, 0] # w
  739. new_bboxes[:, 3] = bboxes[:, 1::2].max() - new_bboxes[:, 1] # h
  740. return new_bboxes
  741. def xyxy2xywh(self, bboxes):
  742. new_bboxes = np.empty_like(bboxes)
  743. new_bboxes[:, 0] = (bboxes[:, 0] + bboxes[:, 2]) / 2 # x center
  744. new_bboxes[:, 1] = (bboxes[:, 1] + bboxes[:, 3]) / 2 # y center
  745. new_bboxes[:, 2] = bboxes[:, 2] - bboxes[:, 0] # width
  746. new_bboxes[:, 3] = bboxes[:, 3] - bboxes[:, 1] # height
  747. return new_bboxes
  748. class SARLabelEncode(BaseRecLabelEncode):
  749. """ Convert between text-label and text-index """
  750. def __init__(self,
  751. max_text_length,
  752. character_dict_path=None,
  753. use_space_char=False,
  754. **kwargs):
  755. super(SARLabelEncode, self).__init__(
  756. max_text_length, character_dict_path, use_space_char)
  757. def add_special_char(self, dict_character):
  758. beg_end_str = "<BOS/EOS>"
  759. unknown_str = "<UKN>"
  760. padding_str = "<PAD>"
  761. dict_character = dict_character + [unknown_str]
  762. self.unknown_idx = len(dict_character) - 1
  763. dict_character = dict_character + [beg_end_str]
  764. self.start_idx = len(dict_character) - 1
  765. self.end_idx = len(dict_character) - 1
  766. dict_character = dict_character + [padding_str]
  767. self.padding_idx = len(dict_character) - 1
  768. return dict_character
  769. def __call__(self, data):
  770. text = data['label']
  771. text = self.encode(text)
  772. if text is None:
  773. return None
  774. if len(text) >= self.max_text_len - 1:
  775. return None
  776. data['length'] = np.array(len(text))
  777. target = [self.start_idx] + text + [self.end_idx]
  778. padded_text = [self.padding_idx for _ in range(self.max_text_len)]
  779. padded_text[:len(target)] = target
  780. data['label'] = np.array(padded_text)
  781. return data
  782. def get_ignored_tokens(self):
  783. return [self.padding_idx]
  784. class PRENLabelEncode(BaseRecLabelEncode):
  785. def __init__(self,
  786. max_text_length,
  787. character_dict_path,
  788. use_space_char=False,
  789. **kwargs):
  790. super(PRENLabelEncode, self).__init__(
  791. max_text_length, character_dict_path, use_space_char)
  792. def add_special_char(self, dict_character):
  793. padding_str = '<PAD>' # 0
  794. end_str = '<EOS>' # 1
  795. unknown_str = '<UNK>' # 2
  796. dict_character = [padding_str, end_str, unknown_str] + dict_character
  797. self.padding_idx = 0
  798. self.end_idx = 1
  799. self.unknown_idx = 2
  800. return dict_character
  801. def encode(self, text):
  802. if len(text) == 0 or len(text) >= self.max_text_len:
  803. return None
  804. if self.lower:
  805. text = text.lower()
  806. text_list = []
  807. for char in text:
  808. if char not in self.dict:
  809. text_list.append(self.unknown_idx)
  810. else:
  811. text_list.append(self.dict[char])
  812. text_list.append(self.end_idx)
  813. if len(text_list) < self.max_text_len:
  814. text_list += [self.padding_idx] * (
  815. self.max_text_len - len(text_list))
  816. return text_list
  817. def __call__(self, data):
  818. text = data['label']
  819. encoded_text = self.encode(text)
  820. if encoded_text is None:
  821. return None
  822. data['label'] = np.array(encoded_text)
  823. return data
class VQATokenLabelEncode(object):
    """
    Label encode for NLP VQA methods
    """

    def __init__(self,
                 class_path,
                 contains_re=False,
                 add_special_ids=False,
                 algorithm='LayoutXLM',
                 use_textline_bbox_info=True,
                 order_method=None,
                 infer_mode=False,
                 ocr_engine=None,
                 **kwargs):
        # class_path: BIO label-map file used to build self.label2id_map.
        # contains_re: also prepare relation-extraction targets when training.
        # add_special_ids: keep tokenizer special ids and give them dummy
        #     [0, 0, 0, 0] boxes instead of stripping them.
        # algorithm: tokenizer family to load; must be a key of tokenizer_dict.
        # use_textline_bbox_info: reuse the whole text-line bbox for every
        #     token instead of splitting the bbox per word.
        # order_method: optional sort of OCR lines; only "tb-yx" is supported.
        # infer_mode: True -> OCR comes from ocr_engine, no labels expected.
        # ocr_engine: engine used to produce OCR results in infer mode.
        super(VQATokenLabelEncode, self).__init__()
        # Lazy imports: paddlenlp is only required when this op is instantiated.
        from paddlenlp.transformers import LayoutXLMTokenizer, LayoutLMTokenizer, LayoutLMv2Tokenizer
        from ppocr.utils.utility import load_vqa_bio_label_maps
        # algorithm name -> tokenizer class + pretrained weight identifier
        tokenizer_dict = {
            'LayoutXLM': {
                'class': LayoutXLMTokenizer,
                'pretrained_model': 'layoutxlm-base-uncased'
            },
            'LayoutLM': {
                'class': LayoutLMTokenizer,
                'pretrained_model': 'layoutlm-base-uncased'
            },
            'LayoutLMv2': {
                'class': LayoutLMv2Tokenizer,
                'pretrained_model': 'layoutlmv2-base-uncased'
            }
        }
        self.contains_re = contains_re
        tokenizer_config = tokenizer_dict[algorithm]
        self.tokenizer = tokenizer_config['class'].from_pretrained(
            tokenizer_config['pretrained_model'])
        # NOTE(review): the id->label map is discarded; only label->id is kept.
        self.label2id_map, id2label_map = load_vqa_bio_label_maps(class_path)
        self.add_special_ids = add_special_ids
        self.infer_mode = infer_mode
        self.ocr_engine = ocr_engine
        self.use_textline_bbox_info = use_textline_bbox_info
        self.order_method = order_method
        assert self.order_method in [None, "tb-yx"]

    def split_bbox(self, bbox, text, tokenizer):
        # Split a text-line bbox across its words, assuming every character
        # occupies equal width; each word bbox is repeated once per sub-token
        # produced by the tokenizer for that word.
        # Assumes len(text) > 0 — callers filter out empty transcriptions.
        words = text.split()
        token_bboxes = []
        curr_word_idx = 0  # NOTE(review): assigned but never used below
        x1, y1, x2, y2 = bbox
        unit_w = (x2 - x1) / len(text)
        for idx, word in enumerate(words):
            curr_w = len(word) * unit_w
            word_bbox = [x1, y1, x1 + curr_w, y2]
            token_bboxes.extend([word_bbox] * len(tokenizer.tokenize(word)))
            # Advance past the word plus one separating space.
            x1 += (len(word) + 1) * unit_w
        return token_bboxes

    def filter_empty_contents(self, ocr_info):
        """
        find out the empty texts and remove the links
        """
        new_ocr_info = []
        empty_index = []
        for idx, info in enumerate(ocr_info):
            if len(info["transcription"]) > 0:
                new_ocr_info.append(copy.deepcopy(info))
            else:
                empty_index.append(info["id"])
        # Drop any link whose endpoint referenced an empty text region.
        for idx, info in enumerate(new_ocr_info):
            new_link = []
            for link in info["linking"]:
                if link[0] in empty_index or link[1] in empty_index:
                    continue
                new_link.append(link)
            new_ocr_info[idx]["linking"] = new_link
        return new_ocr_info

    def __call__(self, data):
        # Tokenize every OCR text region and build the flat model inputs:
        # input_ids / token_type_ids / bbox / attention_mask, SER labels, and
        # (when training relation extraction) entities + relations.
        # load bbox and label info
        ocr_info = self._load_ocr_info(data)

        # Derive an axis-aligned bbox from the polygon when not provided.
        for idx in range(len(ocr_info)):
            if "bbox" not in ocr_info[idx]:
                ocr_info[idx]["bbox"] = self.trans_poly_to_bbox(ocr_info[idx][
                    "points"])

        if self.order_method == "tb-yx":
            ocr_info = order_by_tbyx(ocr_info)

        # for re
        train_re = self.contains_re and not self.infer_mode
        if train_re:
            ocr_info = self.filter_empty_contents(ocr_info)

        height, width, _ = data['image'].shape

        words_list = []
        bbox_list = []
        input_ids_list = []
        token_type_ids_list = []
        segment_offset_id = []
        gt_label_list = []

        entities = []

        # RE-only accumulators.
        if train_re:
            relations = []
            id2label = {}
            entity_id_to_index_map = {}
            empty_entity = set()

        data['ocr_info'] = copy.deepcopy(ocr_info)

        for info in ocr_info:
            text = info["transcription"]
            if len(text) <= 0:
                continue
            if train_re:
                # for re
                # NOTE(review): unreachable — empty texts were skipped above
                # and already removed by filter_empty_contents.
                if len(text) == 0:
                    empty_entity.add(info["id"])
                    continue
                id2label[info["id"]] = info["label"]
                # Links are normalized to sorted tuples so duplicates collapse.
                relations.extend([tuple(sorted(l)) for l in info["linking"]])
            # smooth_box
            info["bbox"] = self.trans_poly_to_bbox(info["points"])

            encode_res = self.tokenizer.encode(
                text,
                pad_to_max_seq_len=False,
                return_attention_mask=True,
                return_token_type_ids=True)

            if not self.add_special_ids:
                # TODO: use tok.all_special_ids to remove
                # Strip the leading/trailing special tokens ([CLS]/[SEP]).
                encode_res["input_ids"] = encode_res["input_ids"][1:-1]
                encode_res["token_type_ids"] = encode_res["token_type_ids"][1:
                                                                            -1]
                encode_res["attention_mask"] = encode_res["attention_mask"][1:
                                                                            -1]

            # One bbox per token: either the text-line bbox repeated, or the
            # line bbox split per word.
            if self.use_textline_bbox_info:
                bbox = [info["bbox"]] * len(encode_res["input_ids"])
            else:
                bbox = self.split_bbox(info["bbox"], info["transcription"],
                                       self.tokenizer)
            if len(bbox) <= 0:
                continue
            bbox = self._smooth_box(bbox, height, width)
            if self.add_special_ids:
                # Dummy boxes for the kept special tokens.
                bbox.insert(0, [0, 0, 0, 0])
                bbox.append([0, 0, 0, 0])

            # parse label
            if not self.infer_mode:
                label = info['label']
                gt_label = self._parse_label(label, encode_res)

            # construct entities for re
            if train_re:
                if gt_label[0] != self.label2id_map["O"]:
                    entity_id_to_index_map[info["id"]] = len(entities)
                    # NOTE(review): redundant — "label" below upper-cases again.
                    label = label.upper()
                    entities.append({
                        "start": len(input_ids_list),
                        "end":
                        len(input_ids_list) + len(encode_res["input_ids"]),
                        "label": label.upper(),
                    })
            else:
                entities.append({
                    "start": len(input_ids_list),
                    "end": len(input_ids_list) + len(encode_res["input_ids"]),
                    "label": 'O',
                })
            input_ids_list.extend(encode_res["input_ids"])
            token_type_ids_list.extend(encode_res["token_type_ids"])
            bbox_list.extend(bbox)
            words_list.append(text)
            # Offsets mark the end of each segment in the flat token stream.
            segment_offset_id.append(len(input_ids_list))
            if not self.infer_mode:
                gt_label_list.extend(gt_label)

        data['input_ids'] = input_ids_list
        data['token_type_ids'] = token_type_ids_list
        data['bbox'] = bbox_list
        data['attention_mask'] = [1] * len(input_ids_list)
        data['labels'] = gt_label_list
        data['segment_offset_id'] = segment_offset_id
        data['tokenizer_params'] = dict(
            padding_side=self.tokenizer.padding_side,
            pad_token_type_id=self.tokenizer.pad_token_type_id,
            pad_token_id=self.tokenizer.pad_token_id)
        data['entities'] = entities

        if train_re:
            data['relations'] = relations
            data['id2label'] = id2label
            data['empty_entity'] = empty_entity
            data['entity_id_to_index_map'] = entity_id_to_index_map
        return data

    def trans_poly_to_bbox(self, poly):
        # Axis-aligned integer bounding box of a polygon [(x, y), ...].
        x1 = int(np.min([p[0] for p in poly]))
        x2 = int(np.max([p[0] for p in poly]))
        y1 = int(np.min([p[1] for p in poly]))
        y2 = int(np.max([p[1] for p in poly]))
        return [x1, y1, x2, y2]

    def _load_ocr_info(self, data):
        # Infer mode: run OCR on the image; otherwise parse the JSON label.
        if self.infer_mode:
            ocr_result = self.ocr_engine.ocr(data['image'], cls=False)[0]
            ocr_info = []
            for res in ocr_result:
                ocr_info.append({
                    "transcription": res[1][0],
                    "bbox": self.trans_poly_to_bbox(res[0]),
                    "points": res[0],
                })
            return ocr_info
        else:
            info = data['label']
            # read text info
            info_dict = json.loads(info)
            return info_dict

    def _smooth_box(self, bboxes, height, width):
        # Rescale boxes to the 0-1000 integer coordinate space the LayoutLM
        # family of models expects.
        bboxes = np.array(bboxes)
        bboxes[:, 0] = bboxes[:, 0] * 1000 / width
        bboxes[:, 2] = bboxes[:, 2] * 1000 / width
        bboxes[:, 1] = bboxes[:, 1] * 1000 / height
        bboxes[:, 3] = bboxes[:, 3] * 1000 / height
        bboxes = bboxes.astype("int64").tolist()
        return bboxes

    def _parse_label(self, label, encode_res):
        # Expand a segment-level class label into per-token BIO ids:
        # first token gets B-<label>, the rest I-<label>; "other"-like
        # labels map every token to class 0.
        gt_label = []
        if label.lower() in ["other", "others", "ignore"]:
            gt_label.extend([0] * len(encode_res["input_ids"]))
        else:
            gt_label.append(self.label2id_map[("b-" + label).upper()])
            gt_label.extend([self.label2id_map[("i-" + label).upper()]] *
                            (len(encode_res["input_ids"]) - 1))
        return gt_label
  1044. class MultiLabelEncode(BaseRecLabelEncode):
  1045. def __init__(self,
  1046. max_text_length,
  1047. character_dict_path=None,
  1048. use_space_char=False,
  1049. **kwargs):
  1050. super(MultiLabelEncode, self).__init__(
  1051. max_text_length, character_dict_path, use_space_char)
  1052. self.ctc_encode = CTCLabelEncode(max_text_length, character_dict_path,
  1053. use_space_char, **kwargs)
  1054. self.sar_encode = SARLabelEncode(max_text_length, character_dict_path,
  1055. use_space_char, **kwargs)
  1056. def __call__(self, data):
  1057. data_ctc = copy.deepcopy(data)
  1058. data_sar = copy.deepcopy(data)
  1059. data_out = dict()
  1060. data_out['img_path'] = data.get('img_path', None)
  1061. data_out['image'] = data['image']
  1062. ctc = self.ctc_encode.__call__(data_ctc)
  1063. sar = self.sar_encode.__call__(data_sar)
  1064. if ctc is None or sar is None:
  1065. return None
  1066. data_out['label_ctc'] = ctc['label']
  1067. data_out['label_sar'] = sar['label']
  1068. data_out['length'] = ctc['length']
  1069. return data_out
  1070. class NRTRLabelEncode(BaseRecLabelEncode):
  1071. """ Convert between text-label and text-index """
  1072. def __init__(self,
  1073. max_text_length,
  1074. character_dict_path=None,
  1075. use_space_char=False,
  1076. **kwargs):
  1077. super(NRTRLabelEncode, self).__init__(
  1078. max_text_length, character_dict_path, use_space_char)
  1079. def __call__(self, data):
  1080. text = data['label']
  1081. text = self.encode(text)
  1082. if text is None:
  1083. return None
  1084. if len(text) >= self.max_text_len - 1:
  1085. return None
  1086. data['length'] = np.array(len(text))
  1087. text.insert(0, 2)
  1088. text.append(3)
  1089. text = text + [0] * (self.max_text_len - len(text))
  1090. data['label'] = np.array(text)
  1091. return data
  1092. def add_special_char(self, dict_character):
  1093. dict_character = ['blank', '<unk>', '<s>', '</s>'] + dict_character
  1094. return dict_character
  1095. class ViTSTRLabelEncode(BaseRecLabelEncode):
  1096. """ Convert between text-label and text-index """
  1097. def __init__(self,
  1098. max_text_length,
  1099. character_dict_path=None,
  1100. use_space_char=False,
  1101. ignore_index=0,
  1102. **kwargs):
  1103. super(ViTSTRLabelEncode, self).__init__(
  1104. max_text_length, character_dict_path, use_space_char)
  1105. self.ignore_index = ignore_index
  1106. def __call__(self, data):
  1107. text = data['label']
  1108. text = self.encode(text)
  1109. if text is None:
  1110. return None
  1111. if len(text) >= self.max_text_len:
  1112. return None
  1113. data['length'] = np.array(len(text))
  1114. text.insert(0, self.ignore_index)
  1115. text.append(1)
  1116. text = text + [self.ignore_index] * (self.max_text_len + 2 - len(text))
  1117. data['label'] = np.array(text)
  1118. return data
  1119. def add_special_char(self, dict_character):
  1120. dict_character = ['<s>', '</s>'] + dict_character
  1121. return dict_character
  1122. class ABINetLabelEncode(BaseRecLabelEncode):
  1123. """ Convert between text-label and text-index """
  1124. def __init__(self,
  1125. max_text_length,
  1126. character_dict_path=None,
  1127. use_space_char=False,
  1128. ignore_index=100,
  1129. **kwargs):
  1130. super(ABINetLabelEncode, self).__init__(
  1131. max_text_length, character_dict_path, use_space_char)
  1132. self.ignore_index = ignore_index
  1133. def __call__(self, data):
  1134. text = data['label']
  1135. text = self.encode(text)
  1136. if text is None:
  1137. return None
  1138. if len(text) >= self.max_text_len:
  1139. return None
  1140. data['length'] = np.array(len(text))
  1141. text.append(0)
  1142. text = text + [self.ignore_index] * (self.max_text_len + 1 - len(text))
  1143. data['label'] = np.array(text)
  1144. return data
  1145. def add_special_char(self, dict_character):
  1146. dict_character = ['</s>'] + dict_character
  1147. return dict_character
  1148. class SRLabelEncode(BaseRecLabelEncode):
  1149. def __init__(self,
  1150. max_text_length,
  1151. character_dict_path=None,
  1152. use_space_char=False,
  1153. **kwargs):
  1154. super(SRLabelEncode, self).__init__(max_text_length,
  1155. character_dict_path, use_space_char)
  1156. self.dic = {}
  1157. with open(character_dict_path, 'r') as fin:
  1158. for line in fin.readlines():
  1159. line = line.strip()
  1160. character, sequence = line.split()
  1161. self.dic[character] = sequence
  1162. english_stroke_alphabet = '0123456789'
  1163. self.english_stroke_dict = {}
  1164. for index in range(len(english_stroke_alphabet)):
  1165. self.english_stroke_dict[english_stroke_alphabet[index]] = index
  1166. def encode(self, label):
  1167. stroke_sequence = ''
  1168. for character in label:
  1169. if character not in self.dic:
  1170. continue
  1171. else:
  1172. stroke_sequence += self.dic[character]
  1173. stroke_sequence += '0'
  1174. label = stroke_sequence
  1175. length = len(label)
  1176. input_tensor = np.zeros(self.max_text_len).astype("int64")
  1177. for j in range(length - 1):
  1178. input_tensor[j + 1] = self.english_stroke_dict[label[j]]
  1179. return length, input_tensor
  1180. def __call__(self, data):
  1181. text = data['label']
  1182. length, input_tensor = self.encode(text)
  1183. data["length"] = length
  1184. data["input_tensor"] = input_tensor
  1185. if text is None:
  1186. return None
  1187. return data
  1188. class SPINLabelEncode(AttnLabelEncode):
  1189. """ Convert between text-label and text-index """
  1190. def __init__(self,
  1191. max_text_length,
  1192. character_dict_path=None,
  1193. use_space_char=False,
  1194. lower=True,
  1195. **kwargs):
  1196. super(SPINLabelEncode, self).__init__(
  1197. max_text_length, character_dict_path, use_space_char)
  1198. self.lower = lower
  1199. def add_special_char(self, dict_character):
  1200. self.beg_str = "sos"
  1201. self.end_str = "eos"
  1202. dict_character = [self.beg_str] + [self.end_str] + dict_character
  1203. return dict_character
  1204. def __call__(self, data):
  1205. text = data['label']
  1206. text = self.encode(text)
  1207. if text is None:
  1208. return None
  1209. if len(text) > self.max_text_len:
  1210. return None
  1211. data['length'] = np.array(len(text))
  1212. target = [0] + text + [1]
  1213. padded_text = [0 for _ in range(self.max_text_len + 2)]
  1214. padded_text[:len(target)] = target
  1215. data['label'] = np.array(padded_text)
  1216. return data
  1217. class VLLabelEncode(BaseRecLabelEncode):
  1218. """ Convert between text-label and text-index """
  1219. def __init__(self,
  1220. max_text_length,
  1221. character_dict_path=None,
  1222. use_space_char=False,
  1223. **kwargs):
  1224. super(VLLabelEncode, self).__init__(max_text_length,
  1225. character_dict_path, use_space_char)
  1226. self.dict = {}
  1227. for i, char in enumerate(self.character):
  1228. self.dict[char] = i
  1229. def __call__(self, data):
  1230. text = data['label'] # original string
  1231. # generate occluded text
  1232. len_str = len(text)
  1233. if len_str <= 0:
  1234. return None
  1235. change_num = 1
  1236. order = list(range(len_str))
  1237. change_id = sample(order, change_num)[0]
  1238. label_sub = text[change_id]
  1239. if change_id == (len_str - 1):
  1240. label_res = text[:change_id]
  1241. elif change_id == 0:
  1242. label_res = text[1:]
  1243. else:
  1244. label_res = text[:change_id] + text[change_id + 1:]
  1245. data['label_res'] = label_res # remaining string
  1246. data['label_sub'] = label_sub # occluded character
  1247. data['label_id'] = change_id # character index
  1248. # encode label
  1249. text = self.encode(text)
  1250. if text is None:
  1251. return None
  1252. text = [i + 1 for i in text]
  1253. data['length'] = np.array(len(text))
  1254. text = text + [0] * (self.max_text_len - len(text))
  1255. data['label'] = np.array(text)
  1256. label_res = self.encode(label_res)
  1257. label_sub = self.encode(label_sub)
  1258. if label_res is None:
  1259. label_res = []
  1260. else:
  1261. label_res = [i + 1 for i in label_res]
  1262. if label_sub is None:
  1263. label_sub = []
  1264. else:
  1265. label_sub = [i + 1 for i in label_sub]
  1266. data['length_res'] = np.array(len(label_res))
  1267. data['length_sub'] = np.array(len(label_sub))
  1268. label_res = label_res + [0] * (self.max_text_len - len(label_res))
  1269. label_sub = label_sub + [0] * (self.max_text_len - len(label_sub))
  1270. data['label_res'] = np.array(label_res)
  1271. data['label_sub'] = np.array(label_sub)
  1272. return data
  1273. class CTLabelEncode(object):
  1274. def __init__(self, **kwargs):
  1275. pass
  1276. def __call__(self, data):
  1277. label = data['label']
  1278. label = json.loads(label)
  1279. nBox = len(label)
  1280. boxes, txts = [], []
  1281. for bno in range(0, nBox):
  1282. box = label[bno]['points']
  1283. box = np.array(box)
  1284. boxes.append(box)
  1285. txt = label[bno]['transcription']
  1286. txts.append(txt)
  1287. if len(boxes) == 0:
  1288. return None
  1289. data['polys'] = boxes
  1290. data['texts'] = txts
  1291. return data
  1292. class CANLabelEncode(BaseRecLabelEncode):
  1293. def __init__(self,
  1294. character_dict_path,
  1295. max_text_length=100,
  1296. use_space_char=False,
  1297. lower=True,
  1298. **kwargs):
  1299. super(CANLabelEncode, self).__init__(
  1300. max_text_length, character_dict_path, use_space_char, lower)
  1301. def encode(self, text_seq):
  1302. text_seq_encoded = []
  1303. for text in text_seq:
  1304. if text not in self.character:
  1305. continue
  1306. text_seq_encoded.append(self.dict.get(text))
  1307. if len(text_seq_encoded) == 0:
  1308. return None
  1309. return text_seq_encoded
  1310. def __call__(self, data):
  1311. label = data['label']
  1312. if isinstance(label, str):
  1313. label = label.strip().split()
  1314. label.append(self.end_str)
  1315. data['label'] = self.encode(label)
  1316. return data