# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

import numpy as np
import scipy.io as io

from ppocr.utils.utility import check_install
from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
  19. def get_socre_A(gt_dir, pred_dict):
  20. allInputs = 1
  21. def input_reading_mod(pred_dict):
  22. """This helper reads input from txt files"""
  23. det = []
  24. n = len(pred_dict)
  25. for i in range(n):
  26. points = pred_dict[i]['points']
  27. text = pred_dict[i]['texts']
  28. point = ",".join(map(str, points.reshape(-1, )))
  29. det.append([point, text])
  30. return det
  31. def gt_reading_mod(gt_dict):
  32. """This helper reads groundtruths from mat files"""
  33. gt = []
  34. n = len(gt_dict)
  35. for i in range(n):
  36. points = gt_dict[i]['points'].tolist()
  37. h = len(points)
  38. text = gt_dict[i]['text']
  39. xx = [
  40. np.array(
  41. ['x:'], dtype='<U2'), 0, np.array(
  42. ['y:'], dtype='<U2'), 0, np.array(
  43. ['#'], dtype='<U1'), np.array(
  44. ['#'], dtype='<U1')
  45. ]
  46. t_x, t_y = [], []
  47. for j in range(h):
  48. t_x.append(points[j][0])
  49. t_y.append(points[j][1])
  50. xx[1] = np.array([t_x], dtype='int16')
  51. xx[3] = np.array([t_y], dtype='int16')
  52. if text != "":
  53. xx[4] = np.array([text], dtype='U{}'.format(len(text)))
  54. xx[5] = np.array(['c'], dtype='<U1')
  55. gt.append(xx)
  56. return gt
  57. def detection_filtering(detections, groundtruths, threshold=0.5):
  58. for gt_id, gt in enumerate(groundtruths):
  59. if (gt[5] == '#') and (gt[1].shape[1] > 1):
  60. gt_x = list(map(int, np.squeeze(gt[1])))
  61. gt_y = list(map(int, np.squeeze(gt[3])))
  62. for det_id, detection in enumerate(detections):
  63. detection_orig = detection
  64. detection = [float(x) for x in detection[0].split(',')]
  65. detection = list(map(int, detection))
  66. det_x = detection[0::2]
  67. det_y = detection[1::2]
  68. det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
  69. if det_gt_iou > threshold:
  70. detections[det_id] = []
  71. detections[:] = [item for item in detections if item != []]
  72. return detections
  73. def sigma_calculation(det_x, det_y, gt_x, gt_y):
  74. """
  75. sigma = inter_area / gt_area
  76. """
  77. return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
  78. area(gt_x, gt_y)), 2)
  79. def tau_calculation(det_x, det_y, gt_x, gt_y):
  80. if area(det_x, det_y) == 0.0:
  81. return 0
  82. return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
  83. area(det_x, det_y)), 2)
  84. ##############################Initialization###################################
  85. # global_sigma = []
  86. # global_tau = []
  87. # global_pred_str = []
  88. # global_gt_str = []
  89. ###############################################################################
  90. for input_id in range(allInputs):
  91. if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
  92. input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
  93. input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
  94. and (input_id != 'Deteval_result_non_curved.txt'):
  95. detections = input_reading_mod(pred_dict)
  96. groundtruths = gt_reading_mod(gt_dir)
  97. detections = detection_filtering(
  98. detections,
  99. groundtruths) # filters detections overlapping with DC area
  100. dc_id = []
  101. for i in range(len(groundtruths)):
  102. if groundtruths[i][5] == '#':
  103. dc_id.append(i)
  104. cnt = 0
  105. for a in dc_id:
  106. num = a - cnt
  107. del groundtruths[num]
  108. cnt += 1
  109. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  110. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  111. local_pred_str = {}
  112. local_gt_str = {}
  113. for gt_id, gt in enumerate(groundtruths):
  114. if len(detections) > 0:
  115. for det_id, detection in enumerate(detections):
  116. detection_orig = detection
  117. detection = [float(x) for x in detection[0].split(',')]
  118. detection = list(map(int, detection))
  119. pred_seq_str = detection_orig[1].strip()
  120. det_x = detection[0::2]
  121. det_y = detection[1::2]
  122. gt_x = list(map(int, np.squeeze(gt[1])))
  123. gt_y = list(map(int, np.squeeze(gt[3])))
  124. gt_seq_str = str(gt[4].tolist()[0])
  125. local_sigma_table[gt_id, det_id] = sigma_calculation(
  126. det_x, det_y, gt_x, gt_y)
  127. local_tau_table[gt_id, det_id] = tau_calculation(
  128. det_x, det_y, gt_x, gt_y)
  129. local_pred_str[det_id] = pred_seq_str
  130. local_gt_str[gt_id] = gt_seq_str
  131. global_sigma = local_sigma_table
  132. global_tau = local_tau_table
  133. global_pred_str = local_pred_str
  134. global_gt_str = local_gt_str
  135. single_data = {}
  136. single_data['sigma'] = global_sigma
  137. single_data['global_tau'] = global_tau
  138. single_data['global_pred_str'] = global_pred_str
  139. single_data['global_gt_str'] = global_gt_str
  140. return single_data
  141. def get_socre_B(gt_dir, img_id, pred_dict):
  142. allInputs = 1
  143. def input_reading_mod(pred_dict):
  144. """This helper reads input from txt files"""
  145. det = []
  146. n = len(pred_dict)
  147. for i in range(n):
  148. points = pred_dict[i]['points']
  149. text = pred_dict[i]['texts']
  150. point = ",".join(map(str, points.reshape(-1, )))
  151. det.append([point, text])
  152. return det
  153. def gt_reading_mod(gt_dir, gt_id):
  154. gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
  155. gt = gt['polygt']
  156. return gt
  157. def detection_filtering(detections, groundtruths, threshold=0.5):
  158. for gt_id, gt in enumerate(groundtruths):
  159. if (gt[5] == '#') and (gt[1].shape[1] > 1):
  160. gt_x = list(map(int, np.squeeze(gt[1])))
  161. gt_y = list(map(int, np.squeeze(gt[3])))
  162. for det_id, detection in enumerate(detections):
  163. detection_orig = detection
  164. detection = [float(x) for x in detection[0].split(',')]
  165. detection = list(map(int, detection))
  166. det_x = detection[0::2]
  167. det_y = detection[1::2]
  168. det_gt_iou = iod(det_x, det_y, gt_x, gt_y)
  169. if det_gt_iou > threshold:
  170. detections[det_id] = []
  171. detections[:] = [item for item in detections if item != []]
  172. return detections
  173. def sigma_calculation(det_x, det_y, gt_x, gt_y):
  174. """
  175. sigma = inter_area / gt_area
  176. """
  177. return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
  178. area(gt_x, gt_y)), 2)
  179. def tau_calculation(det_x, det_y, gt_x, gt_y):
  180. if area(det_x, det_y) == 0.0:
  181. return 0
  182. return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
  183. area(det_x, det_y)), 2)
  184. ##############################Initialization###################################
  185. # global_sigma = []
  186. # global_tau = []
  187. # global_pred_str = []
  188. # global_gt_str = []
  189. ###############################################################################
  190. for input_id in range(allInputs):
  191. if (input_id != '.DS_Store') and (input_id != 'Pascal_result.txt') and (
  192. input_id != 'Pascal_result_curved.txt') and (input_id != 'Pascal_result_non_curved.txt') and (
  193. input_id != 'Deteval_result.txt') and (input_id != 'Deteval_result_curved.txt') \
  194. and (input_id != 'Deteval_result_non_curved.txt'):
  195. detections = input_reading_mod(pred_dict)
  196. groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
  197. detections = detection_filtering(
  198. detections,
  199. groundtruths) # filters detections overlapping with DC area
  200. dc_id = []
  201. for i in range(len(groundtruths)):
  202. if groundtruths[i][5] == '#':
  203. dc_id.append(i)
  204. cnt = 0
  205. for a in dc_id:
  206. num = a - cnt
  207. del groundtruths[num]
  208. cnt += 1
  209. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  210. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  211. local_pred_str = {}
  212. local_gt_str = {}
  213. for gt_id, gt in enumerate(groundtruths):
  214. if len(detections) > 0:
  215. for det_id, detection in enumerate(detections):
  216. detection_orig = detection
  217. detection = [float(x) for x in detection[0].split(',')]
  218. detection = list(map(int, detection))
  219. pred_seq_str = detection_orig[1].strip()
  220. det_x = detection[0::2]
  221. det_y = detection[1::2]
  222. gt_x = list(map(int, np.squeeze(gt[1])))
  223. gt_y = list(map(int, np.squeeze(gt[3])))
  224. gt_seq_str = str(gt[4].tolist()[0])
  225. local_sigma_table[gt_id, det_id] = sigma_calculation(
  226. det_x, det_y, gt_x, gt_y)
  227. local_tau_table[gt_id, det_id] = tau_calculation(
  228. det_x, det_y, gt_x, gt_y)
  229. local_pred_str[det_id] = pred_seq_str
  230. local_gt_str[gt_id] = gt_seq_str
  231. global_sigma = local_sigma_table
  232. global_tau = local_tau_table
  233. global_pred_str = local_pred_str
  234. global_gt_str = local_gt_str
  235. single_data = {}
  236. single_data['sigma'] = global_sigma
  237. single_data['global_tau'] = global_tau
  238. single_data['global_pred_str'] = global_pred_str
  239. single_data['global_gt_str'] = global_gt_str
  240. return single_data
  241. def get_score_C(gt_label, text, pred_bboxes):
  242. """
  243. get score for CentripetalText (CT) prediction.
  244. """
  245. check_install("Polygon", "Polygon3")
  246. import Polygon as plg
  247. def gt_reading_mod(gt_label, text):
  248. """This helper reads groundtruths from mat files"""
  249. groundtruths = []
  250. nbox = len(gt_label)
  251. for i in range(nbox):
  252. label = {"transcription": text[i][0], "points": gt_label[i].numpy()}
  253. groundtruths.append(label)
  254. return groundtruths
  255. def get_union(pD, pG):
  256. areaA = pD.area()
  257. areaB = pG.area()
  258. return areaA + areaB - get_intersection(pD, pG)
  259. def get_intersection(pD, pG):
  260. pInt = pD & pG
  261. if len(pInt) == 0:
  262. return 0
  263. return pInt.area()
  264. def detection_filtering(detections, groundtruths, threshold=0.5):
  265. for gt in groundtruths:
  266. point_num = gt['points'].shape[1] // 2
  267. if gt['transcription'] == '###' and (point_num > 1):
  268. gt_p = np.array(gt['points']).reshape(point_num,
  269. 2).astype('int32')
  270. gt_p = plg.Polygon(gt_p)
  271. for det_id, detection in enumerate(detections):
  272. det_y = detection[0::2]
  273. det_x = detection[1::2]
  274. det_p = np.concatenate((np.array(det_x), np.array(det_y)))
  275. det_p = det_p.reshape(2, -1).transpose()
  276. det_p = plg.Polygon(det_p)
  277. try:
  278. det_gt_iou = get_intersection(det_p,
  279. gt_p) / det_p.area()
  280. except:
  281. print(det_x, det_y, gt_p)
  282. if det_gt_iou > threshold:
  283. detections[det_id] = []
  284. detections[:] = [item for item in detections if item != []]
  285. return detections
  286. def sigma_calculation(det_p, gt_p):
  287. """
  288. sigma = inter_area / gt_area
  289. """
  290. if gt_p.area() == 0.:
  291. return 0
  292. return get_intersection(det_p, gt_p) / gt_p.area()
  293. def tau_calculation(det_p, gt_p):
  294. """
  295. tau = inter_area / det_area
  296. """
  297. if det_p.area() == 0.:
  298. return 0
  299. return get_intersection(det_p, gt_p) / det_p.area()
  300. detections = []
  301. for item in pred_bboxes:
  302. detections.append(item[:, ::-1].reshape(-1))
  303. groundtruths = gt_reading_mod(gt_label, text)
  304. detections = detection_filtering(
  305. detections, groundtruths) # filters detections overlapping with DC area
  306. for idx in range(len(groundtruths) - 1, -1, -1):
  307. #NOTE: source code use 'orin' to indicate '#', here we use 'anno',
  308. # which may cause slight drop in fscore, about 0.12
  309. if groundtruths[idx]['transcription'] == '###':
  310. groundtruths.pop(idx)
  311. local_sigma_table = np.zeros((len(groundtruths), len(detections)))
  312. local_tau_table = np.zeros((len(groundtruths), len(detections)))
  313. for gt_id, gt in enumerate(groundtruths):
  314. if len(detections) > 0:
  315. for det_id, detection in enumerate(detections):
  316. point_num = gt['points'].shape[1] // 2
  317. gt_p = np.array(gt['points']).reshape(point_num,
  318. 2).astype('int32')
  319. gt_p = plg.Polygon(gt_p)
  320. det_y = detection[0::2]
  321. det_x = detection[1::2]
  322. det_p = np.concatenate((np.array(det_x), np.array(det_y)))
  323. det_p = det_p.reshape(2, -1).transpose()
  324. det_p = plg.Polygon(det_p)
  325. local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
  326. gt_p)
  327. local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
  328. data = {}
  329. data['sigma'] = local_sigma_table
  330. data['global_tau'] = local_tau_table
  331. data['global_pred_str'] = ''
  332. data['global_gt_str'] = ''
  333. return data
def combine_results(all_data, rec_flag=True):
    """Aggregate per-image sigma/tau tables into DetEval detection and
    end-to-end metrics.

    Implements the DetEval matching scheme: first one-to-one matches, then
    one-to-many (one gt split across several detections), then many-to-one
    (several gts merged into one detection).  Split/merge matches are
    discounted by fsc_k.

    Args:
        all_data: iterable of per-image dicts produced by get_socre_A/B or
            get_score_C, with keys 'sigma', 'global_tau',
            'global_pred_str', 'global_gt_str'.
        rec_flag: when True, also compare recognized strings against the
            groundtruth transcriptions (exact, then case-insensitive).

    Returns:
        dict of accumulated counts plus recall/precision/f_score for
        detection and *_e2e variants for end-to-end recognition.
    """
    # DetEval thresholds: tr gates sigma (recall side), tp gates tau
    # (precision side); fsc_k discounts split/merge matches; k is the
    # minimum overlap count to consider a one-to-many / many-to-one case.
    tr = 0.7
    tp = 0.6
    fsc_k = 0.8
    k = 2
    global_sigma = []
    global_tau = []
    global_pred_str = []
    global_gt_str = []
    for data in all_data:
        global_sigma.append(data['sigma'])
        global_tau.append(data['global_tau'])
        global_pred_str.append(data['global_pred_str'])
        global_gt_str.append(data['global_gt_str'])

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0  # NOTE(review): never incremented below; appears vestigial

    # The three matchers below read `num_gt` / `num_det` from the enclosing
    # scope; those are (re)bound in the per-image loop before each call.
    # `idy` indexes the current image in global_pred_str / global_gt_str.
    def one_to_one(local_sigma_table, local_tau_table,
                   local_accumulative_recall, local_accumulative_precision,
                   global_accumulative_recall, global_accumulative_precision,
                   gt_flag, det_flag, idy, rec_flag):
        """Credit gts matched by exactly one detection (and vice versa)."""
        hit_str_num = 0
        for gt_id in range(num_gt):
            # detections qualifying against this gt on sigma and on tau
            gt_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
                0].shape[0]
            gt_matching_qualified_tau_candidates = np.where(
                local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
                0].shape[0]
            # gts qualifying against those candidate detections (symmetry check)
            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
                > tr)
            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
                0].shape[0]
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
                tp)
            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
                0].shape[0]

            # a one-to-one match requires uniqueness in all four directions
            if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
                    (det_matching_num_qualified_sigma_candidates == 1) and (
                    det_matching_num_qualified_tau_candidates == 1):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0

                gt_flag[0, gt_id] = 1
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start
                if rec_flag:
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][matched_det_id[0]
                                                        .tolist()[0]]
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        # fall back to case-insensitive comparison
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    def one_to_many(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy, rec_flag):
        """Credit gts covered by several detections (split case)."""
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the groundtruth was matched
            if gt_flag[0, gt_id] > 0:
                continue

            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]

            if num_non_zero_in_sigma >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_tau_candidates = np.where((local_tau_table[
                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[
                    0].shape[0]

                if num_qualified_tau_candidates == 1:
                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
                            and
                            (local_sigma_table[gt_id, qualified_tau_candidates] >=
                             tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        if rec_flag:
                            gt_str_cur = global_gt_str[idy][gt_id]
                            pred_str_cur = global_pred_str[idy][
                                qualified_tau_candidates[0].tolist()[0]]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                        # recg end
                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
                      >= tr):
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start
                    if rec_flag:
                        gt_str_cur = global_gt_str[idy][gt_id]
                        pred_str_cur = global_pred_str[idy][
                            qualified_tau_candidates[0].tolist()[0]]
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                    # recg end

                    # split match: recall discounted by fsc_k, precision
                    # discounted per contributing detection
                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k

                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k

        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    def many_to_one(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy, rec_flag):
        """Credit detections covering several gts (merge case)."""
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was matched
            if det_flag[0, det_id] > 0:
                continue

            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]

            if num_non_zero_in_tau >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_sigma_candidates = np.where((
                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[
                    0].shape[0]

                if num_qualified_sigma_candidates == 1:
                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
                         tp) and
                            (local_sigma_table[qualified_sigma_candidates, det_id]
                             >= tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start
                        if rec_flag:
                            pred_str_cur = global_pred_str[idy][det_id]
                            gt_len = len(qualified_sigma_candidates[0])
                            for idx in range(gt_len):
                                ele_gt_id = qualified_sigma_candidates[
                                    0].tolist()[idx]
                                # gt may have no transcription entry
                                if ele_gt_id not in global_gt_str[idy]:
                                    continue
                                gt_str_cur = global_gt_str[idy][ele_gt_id]
                                if pred_str_cur == gt_str_cur:
                                    hit_str_num += 1
                                    break
                                else:
                                    if pred_str_cur.lower() == gt_str_cur.lower(
                                    ):
                                        hit_str_num += 1
                                    break
                        # recg end
                elif (np.sum(local_tau_table[qualified_sigma_candidates,
                                             det_id]) >= tp):
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    if rec_flag:
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
                                idx]
                            if ele_gt_id not in global_gt_str[idy]:
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                break
                    # recg end

                    # merge match: recall credited per merged gt, precision
                    # discounted by fsc_k
                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    global_accumulative_precision = global_accumulative_precision + fsc_k

                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    local_accumulative_precision = local_accumulative_precision + fsc_k
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    # per-image accumulation: the three matchers are applied in priority
    # order, each marking matched gts/dets via gt_flag / det_flag so later
    # stages skip them
    for idx in range(len(global_sigma)):
        local_sigma_table = np.array(global_sigma[idx])
        local_tau_table = global_tau[idx]

        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]

        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        local_accumulative_recall = 0
        local_accumulative_precision = 0
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))

        #######first check for one-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
                                                        local_accumulative_recall, local_accumulative_precision,
                                                        global_accumulative_recall, global_accumulative_precision,
                                                        gt_flag, det_flag, idx, rec_flag)

        hit_str_count += hit_str_num
        #######then check for one-to-many case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx, rec_flag)
        hit_str_count += hit_str_num
        #######then check for many-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx, rec_flag)
        hit_str_count += hit_str_num

    # final metrics; each ratio guarded against empty gt/det sets
    try:
        recall = global_accumulative_recall / total_num_gt
    except ZeroDivisionError:
        recall = 0

    try:
        precision = global_accumulative_precision / total_num_det
    except ZeroDivisionError:
        precision = 0

    try:
        f_score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        f_score = 0

    try:
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1

    try:
        recall_e2e = float(hit_str_count) / total_num_gt
    except ZeroDivisionError:
        recall_e2e = 0

    try:
        precision_e2e = float(hit_str_count) / total_num_det
    except ZeroDivisionError:
        precision_e2e = 0

    try:
        f_score_e2e = 2 * precision_e2e * recall_e2e / (
            precision_e2e + recall_e2e)
    except ZeroDivisionError:
        f_score_e2e = 0

    final = {
        'total_num_gt': total_num_gt,
        'total_num_det': total_num_det,
        'global_accumulative_recall': global_accumulative_recall,
        'hit_str_count': hit_str_count,
        'recall': recall,
        'precision': precision,
        'f_score': f_score,
        'seqerr': seqerr,
        'recall_e2e': recall_e2e,
        'precision_e2e': precision_e2e,
        'f_score_e2e': f_score_e2e
    }
    return final