# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json

import numpy as np
import scipy.io as io

from ppocr.utils.e2e_metric.polygon_fast import iod, area_of_intersection, area
from ppocr.utils.utility import check_install
def get_socre_A(gt_dir, pred_dict):
    """Build the per-image DetEval sigma/tau tables from in-memory labels.

    (The 'socre' typo in the name is kept: callers import this symbol.)

    Args:
        gt_dir: sequence of ground-truth dicts with 'points' (array-like,
            Nx2) and 'text' entries.  NOTE(review): despite the name this
            is annotation data, not a directory path — confirm at callers.
        pred_dict: sequence of prediction dicts with 'points' (numpy array)
            and 'texts' entries.

    Returns:
        dict with:
            'sigma': (num_gt, num_det) array of inter_area / gt_area,
            'global_tau': (num_gt, num_det) array of inter_area / det_area,
            'global_pred_str': {det_id: predicted text},
            'global_gt_str': {gt_id: ground-truth text}.
    """

    def input_reading_mod(pred_dict):
        """Flatten each prediction to [comma-joined coords string, text]."""
        det = []
        for pred in pred_dict:
            point = ",".join(map(str, pred['points'].reshape(-1, )))
            det.append([point, pred['texts']])
        return det

    def gt_reading_mod(gt_dict):
        """Convert each ground truth to the legacy .mat row layout:
        ['x:', x-array, 'y:', y-array, transcription, care-flag]."""
        gt = []
        for item in gt_dict:
            points = item['points'].tolist()
            text = item['text']
            xx = [
                np.array(['x:'], dtype='<U2'), 0,
                np.array(['y:'], dtype='<U2'), 0,
                np.array(['#'], dtype='<U1'),
                np.array(['#'], dtype='<U1'),
            ]
            xx[1] = np.array([[p[0] for p in points]], dtype='int16')
            xx[3] = np.array([[p[1] for p in points]], dtype='int16')
            if text != "":
                xx[4] = np.array([text], dtype='U{}'.format(len(text)))
                xx[5] = np.array(['c'], dtype='<U1')  # 'c' marks a cared region
            gt.append(xx)
        return gt

    def _parse_detection(detection):
        """Split the joined coordinate string into int x/y lists."""
        coords = list(map(int, [float(x) for x in detection[0].split(',')]))
        return coords[0::2], coords[1::2]

    def detection_filtering(detections, groundtruths, threshold=0.5):
        """Drop detections overlapping a don't-care ('#') gt above threshold."""
        for gt in groundtruths:
            if (gt[5] == '#') and (gt[1].shape[1] > 1):
                gt_x = list(map(int, np.squeeze(gt[1])))
                gt_y = list(map(int, np.squeeze(gt[3])))
                for det_id, detection in enumerate(detections):
                    det_x, det_y = _parse_detection(detection)
                    if iod(det_x, det_y, gt_x, gt_y) > threshold:
                        detections[det_id] = []
                detections[:] = [item for item in detections if item != []]
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """sigma = inter_area / gt_area."""
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
                         area(gt_x, gt_y)), 2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        """tau = inter_area / det_area (0 for degenerate detections)."""
        if area(det_x, det_y) == 0.0:
            return 0
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
                         area(det_x, det_y)), 2)

    # The upstream version looped `for input_id in range(1)` and compared the
    # int loop index against filename strings — always True — so the body ran
    # exactly once.  Flattened to straight-line code here.
    detections = input_reading_mod(pred_dict)
    groundtruths = gt_reading_mod(gt_dir)
    # Filter detections overlapping don't-care areas, then drop the
    # don't-care ground truths themselves.
    detections = detection_filtering(detections, groundtruths)
    groundtruths = [gt for gt in groundtruths if not (gt[5] == '#')]

    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
    local_tau_table = np.zeros((len(groundtruths), len(detections)))
    local_pred_str = {}
    local_gt_str = {}

    for gt_id, gt in enumerate(groundtruths):
        gt_x = list(map(int, np.squeeze(gt[1])))
        gt_y = list(map(int, np.squeeze(gt[3])))
        for det_id, detection in enumerate(detections):
            det_x, det_y = _parse_detection(detection)
            local_sigma_table[gt_id, det_id] = sigma_calculation(
                det_x, det_y, gt_x, gt_y)
            local_tau_table[gt_id, det_id] = tau_calculation(
                det_x, det_y, gt_x, gt_y)
            local_pred_str[det_id] = detection[1].strip()
            # Matches upstream: gt strings are only recorded while at least
            # one detection exists.
            local_gt_str[gt_id] = str(gt[4].tolist()[0])

    return {
        'sigma': local_sigma_table,
        'global_tau': local_tau_table,
        'global_pred_str': local_pred_str,
        'global_gt_str': local_gt_str,
    }
def get_socre_B(gt_dir, img_id, pred_dict):
    """Build the per-image DetEval sigma/tau tables, loading ground truth
    from a Total-Text style .mat file '<gt_dir>/poly_gt_img<img_id>.mat'.

    (The 'socre' typo in the name is kept: callers import this symbol.)

    Args:
        gt_dir: directory containing the per-image .mat annotations.
        img_id: image index used to form the .mat filename.
        pred_dict: sequence of prediction dicts with 'points' (numpy array)
            and 'texts' entries.

    Returns:
        dict with keys 'sigma', 'global_tau', 'global_pred_str',
        'global_gt_str' (same layout as get_socre_A).
    """

    def input_reading_mod(pred_dict):
        """Flatten each prediction to [comma-joined coords string, text]."""
        det = []
        for pred in pred_dict:
            point = ",".join(map(str, pred['points'].reshape(-1, )))
            det.append([point, pred['texts']])
        return det

    def gt_reading_mod(gt_dir, gt_id):
        """Load the 'polygt' annotation array from the per-image .mat file."""
        gt = io.loadmat('%s/poly_gt_img%s.mat' % (gt_dir, gt_id))
        return gt['polygt']

    def _parse_detection(detection):
        """Split the joined coordinate string into int x/y lists."""
        coords = list(map(int, [float(x) for x in detection[0].split(',')]))
        return coords[0::2], coords[1::2]

    def detection_filtering(detections, groundtruths, threshold=0.5):
        """Drop detections overlapping a don't-care ('#') gt above threshold."""
        for gt in groundtruths:
            if (gt[5] == '#') and (gt[1].shape[1] > 1):
                gt_x = list(map(int, np.squeeze(gt[1])))
                gt_y = list(map(int, np.squeeze(gt[3])))
                for det_id, detection in enumerate(detections):
                    det_x, det_y = _parse_detection(detection)
                    if iod(det_x, det_y, gt_x, gt_y) > threshold:
                        detections[det_id] = []
                detections[:] = [item for item in detections if item != []]
        return detections

    def sigma_calculation(det_x, det_y, gt_x, gt_y):
        """sigma = inter_area / gt_area."""
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
                         area(gt_x, gt_y)), 2)

    def tau_calculation(det_x, det_y, gt_x, gt_y):
        """tau = inter_area / det_area (0 for degenerate detections)."""
        if area(det_x, det_y) == 0.0:
            return 0
        return np.round((area_of_intersection(det_x, det_y, gt_x, gt_y) /
                         area(det_x, det_y)), 2)

    # The upstream version looped `for input_id in range(1)` and compared the
    # int loop index against filename strings — always True — so the body ran
    # exactly once.  Flattened to straight-line code here.
    detections = input_reading_mod(pred_dict)
    groundtruths = gt_reading_mod(gt_dir, img_id).tolist()
    # Filter detections overlapping don't-care areas, then drop the
    # don't-care ground truths themselves.
    detections = detection_filtering(detections, groundtruths)
    groundtruths = [gt for gt in groundtruths if not (gt[5] == '#')]

    local_sigma_table = np.zeros((len(groundtruths), len(detections)))
    local_tau_table = np.zeros((len(groundtruths), len(detections)))
    local_pred_str = {}
    local_gt_str = {}

    for gt_id, gt in enumerate(groundtruths):
        gt_x = list(map(int, np.squeeze(gt[1])))
        gt_y = list(map(int, np.squeeze(gt[3])))
        for det_id, detection in enumerate(detections):
            det_x, det_y = _parse_detection(detection)
            local_sigma_table[gt_id, det_id] = sigma_calculation(
                det_x, det_y, gt_x, gt_y)
            local_tau_table[gt_id, det_id] = tau_calculation(
                det_x, det_y, gt_x, gt_y)
            local_pred_str[det_id] = detection[1].strip()
            # Matches upstream: gt strings are only recorded while at least
            # one detection exists.
            local_gt_str[gt_id] = str(gt[4].tolist()[0])

    return {
        'sigma': local_sigma_table,
        'global_tau': local_tau_table,
        'global_pred_str': local_pred_str,
        'global_gt_str': local_gt_str,
    }
- def get_score_C(gt_label, text, pred_bboxes):
- """
- get score for CentripetalText (CT) prediction.
- """
- check_install("Polygon", "Polygon3")
- import Polygon as plg
- def gt_reading_mod(gt_label, text):
- """This helper reads groundtruths from mat files"""
- groundtruths = []
- nbox = len(gt_label)
- for i in range(nbox):
- label = {"transcription": text[i][0], "points": gt_label[i].numpy()}
- groundtruths.append(label)
- return groundtruths
- def get_union(pD, pG):
- areaA = pD.area()
- areaB = pG.area()
- return areaA + areaB - get_intersection(pD, pG)
- def get_intersection(pD, pG):
- pInt = pD & pG
- if len(pInt) == 0:
- return 0
- return pInt.area()
- def detection_filtering(detections, groundtruths, threshold=0.5):
- for gt in groundtruths:
- point_num = gt['points'].shape[1] // 2
- if gt['transcription'] == '###' and (point_num > 1):
- gt_p = np.array(gt['points']).reshape(point_num,
- 2).astype('int32')
- gt_p = plg.Polygon(gt_p)
- for det_id, detection in enumerate(detections):
- det_y = detection[0::2]
- det_x = detection[1::2]
- det_p = np.concatenate((np.array(det_x), np.array(det_y)))
- det_p = det_p.reshape(2, -1).transpose()
- det_p = plg.Polygon(det_p)
- try:
- det_gt_iou = get_intersection(det_p,
- gt_p) / det_p.area()
- except:
- print(det_x, det_y, gt_p)
- if det_gt_iou > threshold:
- detections[det_id] = []
- detections[:] = [item for item in detections if item != []]
- return detections
- def sigma_calculation(det_p, gt_p):
- """
- sigma = inter_area / gt_area
- """
- if gt_p.area() == 0.:
- return 0
- return get_intersection(det_p, gt_p) / gt_p.area()
- def tau_calculation(det_p, gt_p):
- """
- tau = inter_area / det_area
- """
- if det_p.area() == 0.:
- return 0
- return get_intersection(det_p, gt_p) / det_p.area()
- detections = []
- for item in pred_bboxes:
- detections.append(item[:, ::-1].reshape(-1))
- groundtruths = gt_reading_mod(gt_label, text)
- detections = detection_filtering(
- detections, groundtruths) # filters detections overlapping with DC area
- for idx in range(len(groundtruths) - 1, -1, -1):
- #NOTE: source code use 'orin' to indicate '#', here we use 'anno',
- # which may cause slight drop in fscore, about 0.12
- if groundtruths[idx]['transcription'] == '###':
- groundtruths.pop(idx)
- local_sigma_table = np.zeros((len(groundtruths), len(detections)))
- local_tau_table = np.zeros((len(groundtruths), len(detections)))
- for gt_id, gt in enumerate(groundtruths):
- if len(detections) > 0:
- for det_id, detection in enumerate(detections):
- point_num = gt['points'].shape[1] // 2
- gt_p = np.array(gt['points']).reshape(point_num,
- 2).astype('int32')
- gt_p = plg.Polygon(gt_p)
- det_y = detection[0::2]
- det_x = detection[1::2]
- det_p = np.concatenate((np.array(det_x), np.array(det_y)))
- det_p = det_p.reshape(2, -1).transpose()
- det_p = plg.Polygon(det_p)
- local_sigma_table[gt_id, det_id] = sigma_calculation(det_p,
- gt_p)
- local_tau_table[gt_id, det_id] = tau_calculation(det_p, gt_p)
- data = {}
- data['sigma'] = local_sigma_table
- data['global_tau'] = local_tau_table
- data['global_pred_str'] = ''
- data['global_gt_str'] = ''
- return data
def combine_results(all_data, rec_flag=True):
    """Aggregate per-image sigma/tau tables into final DetEval metrics.

    Implements the DetEval matching protocol: each image's sigma
    (inter_area / gt_area) and tau (inter_area / det_area) tables are
    scanned for one-to-one, one-to-many and many-to-one matches,
    accumulating recall/precision mass; when `rec_flag` is True, the
    recognized strings are also compared (exact first, then
    case-insensitive) to count end-to-end hits.

    Args:
        all_data: list of per-image dicts as produced by get_socre_A /
            get_socre_B / get_score_C, with keys 'sigma', 'global_tau',
            'global_pred_str', 'global_gt_str'.
        rec_flag: whether to evaluate recognition (end-to-end) as well.

    Returns:
        dict with detection metrics (recall / precision / f_score), raw
        accumulators, and end-to-end metrics (seqerr, recall_e2e,
        precision_e2e, f_score_e2e).
    """
    # DetEval constants: tr thresholds sigma (gt coverage), tp thresholds
    # tau (detection coverage); fsc_k is the fragmentation penalty applied
    # to split/merge matches and k the minimum number of fragments.
    tr = 0.7
    tp = 0.6
    fsc_k = 0.8
    k = 2
    global_sigma = []
    global_tau = []
    global_pred_str = []
    global_gt_str = []
    # Collect every image's tables and strings into parallel lists.
    for data in all_data:
        global_sigma.append(data['sigma'])
        global_tau.append(data['global_tau'])
        global_pred_str.append(data['global_pred_str'])
        global_gt_str.append(data['global_gt_str'])

    global_accumulative_recall = 0
    global_accumulative_precision = 0
    total_num_gt = 0
    total_num_det = 0
    hit_str_count = 0
    hit_count = 0  # NOTE(review): never read; kept as-is.

    # NOTE(review): the three matchers below read `num_gt` / `num_det`
    # (plus tr/tp/fsc_k/k and the global_* lists) from the enclosing scope;
    # they are only valid after the per-image loop below assigns them.
    def one_to_one(local_sigma_table, local_tau_table,
                   local_accumulative_recall, local_accumulative_precision,
                   global_accumulative_recall, global_accumulative_precision,
                   gt_flag, det_flag, idy, rec_flag):
        """Credit gt/detection pairs that qualify each other uniquely."""
        hit_str_num = 0
        for gt_id in range(num_gt):
            # Detections whose sigma/tau against this gt pass the thresholds.
            gt_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[gt_id, :] > tr)
            gt_matching_num_qualified_sigma_candidates = gt_matching_qualified_sigma_candidates[
                0].shape[0]
            gt_matching_qualified_tau_candidates = np.where(
                local_tau_table[gt_id, :] > tp)
            gt_matching_num_qualified_tau_candidates = gt_matching_qualified_tau_candidates[
                0].shape[0]
            # Reverse direction: gts qualified through those candidate columns.
            det_matching_qualified_sigma_candidates = np.where(
                local_sigma_table[:, gt_matching_qualified_sigma_candidates[0]]
                > tr)
            det_matching_num_qualified_sigma_candidates = det_matching_qualified_sigma_candidates[
                0].shape[0]
            det_matching_qualified_tau_candidates = np.where(
                local_tau_table[:, gt_matching_qualified_tau_candidates[0]] >
                tp)
            det_matching_num_qualified_tau_candidates = det_matching_qualified_tau_candidates[
                0].shape[0]
            # One-to-one iff qualification is unique in both directions.
            if (gt_matching_num_qualified_sigma_candidates == 1) and (gt_matching_num_qualified_tau_candidates == 1) and \
                    (det_matching_num_qualified_sigma_candidates == 1) and (
                        det_matching_num_qualified_tau_candidates == 1):
                global_accumulative_recall = global_accumulative_recall + 1.0
                global_accumulative_precision = global_accumulative_precision + 1.0
                local_accumulative_recall = local_accumulative_recall + 1.0
                local_accumulative_precision = local_accumulative_precision + 1.0

                gt_flag[0, gt_id] = 1
                matched_det_id = np.where(local_sigma_table[gt_id, :] > tr)
                # recg start: end-to-end hit when the strings match,
                # exact first, then case-insensitive.
                if rec_flag:
                    gt_str_cur = global_gt_str[idy][gt_id]
                    pred_str_cur = global_pred_str[idy][matched_det_id[0]
                                                        .tolist()[0]]
                    if pred_str_cur == gt_str_cur:
                        hit_str_num += 1
                    else:
                        if pred_str_cur.lower() == gt_str_cur.lower():
                            hit_str_num += 1
                # recg end
                det_flag[0, matched_det_id] = 1
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    def one_to_many(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy, rec_flag):
        """Credit a gt split across several detections (penalized by fsc_k)."""
        hit_str_num = 0
        for gt_id in range(num_gt):
            # skip the following if the groundtruth was matched
            if gt_flag[0, gt_id] > 0:
                continue

            non_zero_in_sigma = np.where(local_sigma_table[gt_id, :] > 0)
            num_non_zero_in_sigma = non_zero_in_sigma[0].shape[0]

            if num_non_zero_in_sigma >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_tau_candidates = np.where((local_tau_table[
                    gt_id, :] >= tp) & (det_flag[0, :] == 0))
                num_qualified_tau_candidates = qualified_tau_candidates[
                    0].shape[0]

                if num_qualified_tau_candidates == 1:
                    if ((local_tau_table[gt_id, qualified_tau_candidates] >= tp)
                            and
                        (local_sigma_table[gt_id, qualified_tau_candidates] >=
                         tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, gt_id] = 1
                        det_flag[0, qualified_tau_candidates] = 1
                        # recg start
                        if rec_flag:
                            gt_str_cur = global_gt_str[idy][gt_id]
                            pred_str_cur = global_pred_str[idy][
                                qualified_tau_candidates[0].tolist()[0]]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                        # recg end
                elif (np.sum(local_sigma_table[gt_id, qualified_tau_candidates])
                      >= tr):
                    # Fragments jointly cover the gt: fractional credit.
                    gt_flag[0, gt_id] = 1
                    det_flag[0, qualified_tau_candidates] = 1
                    # recg start: only the first fragment's string is compared.
                    if rec_flag:
                        gt_str_cur = global_gt_str[idy][gt_id]
                        pred_str_cur = global_pred_str[idy][
                            qualified_tau_candidates[0].tolist()[0]]
                        if pred_str_cur == gt_str_cur:
                            hit_str_num += 1
                        else:
                            if pred_str_cur.lower() == gt_str_cur.lower():
                                hit_str_num += 1
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + fsc_k
                    global_accumulative_precision = global_accumulative_precision + num_qualified_tau_candidates * fsc_k

                    local_accumulative_recall = local_accumulative_recall + fsc_k
                    local_accumulative_precision = local_accumulative_precision + num_qualified_tau_candidates * fsc_k

        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    def many_to_one(local_sigma_table, local_tau_table,
                    local_accumulative_recall, local_accumulative_precision,
                    global_accumulative_recall, global_accumulative_precision,
                    gt_flag, det_flag, idy, rec_flag):
        """Credit a detection covering several gts (penalized by fsc_k)."""
        hit_str_num = 0
        for det_id in range(num_det):
            # skip the following if the detection was matched
            if det_flag[0, det_id] > 0:
                continue

            non_zero_in_tau = np.where(local_tau_table[:, det_id] > 0)
            num_non_zero_in_tau = non_zero_in_tau[0].shape[0]

            if num_non_zero_in_tau >= k:
                ####search for all detections that overlaps with this groundtruth
                qualified_sigma_candidates = np.where((
                    local_sigma_table[:, det_id] >= tp) & (gt_flag[0, :] == 0))
                num_qualified_sigma_candidates = qualified_sigma_candidates[
                    0].shape[0]

                if num_qualified_sigma_candidates == 1:
                    if ((local_tau_table[qualified_sigma_candidates, det_id] >=
                         tp) and
                        (local_sigma_table[qualified_sigma_candidates, det_id]
                         >= tr)):
                        # became an one-to-one case
                        global_accumulative_recall = global_accumulative_recall + 1.0
                        global_accumulative_precision = global_accumulative_precision + 1.0
                        local_accumulative_recall = local_accumulative_recall + 1.0
                        local_accumulative_precision = local_accumulative_precision + 1.0

                        gt_flag[0, qualified_sigma_candidates] = 1
                        det_flag[0, det_id] = 1
                        # recg start: a hit if the prediction matches ANY of
                        # the merged gts' strings.
                        if rec_flag:
                            pred_str_cur = global_pred_str[idy][det_id]
                            gt_len = len(qualified_sigma_candidates[0])
                            for idx in range(gt_len):
                                ele_gt_id = qualified_sigma_candidates[
                                    0].tolist()[idx]
                                # gt_str dicts may lack entries (see
                                # get_socre_A); skip missing ids.
                                if ele_gt_id not in global_gt_str[idy]:
                                    continue
                                gt_str_cur = global_gt_str[idy][ele_gt_id]
                                if pred_str_cur == gt_str_cur:
                                    hit_str_num += 1
                                    break
                                else:
                                    if pred_str_cur.lower() == gt_str_cur.lower(
                                    ):
                                        hit_str_num += 1
                                    break
                        # recg end
                elif (np.sum(local_tau_table[qualified_sigma_candidates,
                                             det_id]) >= tp):
                    # Detection jointly covers the gts: fractional credit.
                    det_flag[0, det_id] = 1
                    gt_flag[0, qualified_sigma_candidates] = 1
                    # recg start
                    if rec_flag:
                        pred_str_cur = global_pred_str[idy][det_id]
                        gt_len = len(qualified_sigma_candidates[0])
                        for idx in range(gt_len):
                            ele_gt_id = qualified_sigma_candidates[0].tolist()[
                                idx]
                            if ele_gt_id not in global_gt_str[idy]:
                                continue
                            gt_str_cur = global_gt_str[idy][ele_gt_id]
                            if pred_str_cur == gt_str_cur:
                                hit_str_num += 1
                                break
                            else:
                                if pred_str_cur.lower() == gt_str_cur.lower():
                                    hit_str_num += 1
                                break
                    # recg end

                    global_accumulative_recall = global_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    global_accumulative_precision = global_accumulative_precision + fsc_k

                    local_accumulative_recall = local_accumulative_recall + num_qualified_sigma_candidates * fsc_k
                    local_accumulative_precision = local_accumulative_precision + fsc_k
        return local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, gt_flag, det_flag, hit_str_num

    # Per-image matching: one-to-one first, then split (one-to-many), then
    # merge (many-to-one); flags prevent double-crediting across phases.
    for idx in range(len(global_sigma)):
        local_sigma_table = np.array(global_sigma[idx])
        local_tau_table = global_tau[idx]

        num_gt = local_sigma_table.shape[0]
        num_det = local_sigma_table.shape[1]

        total_num_gt = total_num_gt + num_gt
        total_num_det = total_num_det + num_det

        local_accumulative_recall = 0
        local_accumulative_precision = 0
        gt_flag = np.zeros((1, num_gt))
        det_flag = np.zeros((1, num_det))

        #######first check for one-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_one(local_sigma_table, local_tau_table,
                                                        local_accumulative_recall, local_accumulative_precision,
                                                        global_accumulative_recall, global_accumulative_precision,
                                                        gt_flag, det_flag, idx, rec_flag)

        hit_str_count += hit_str_num
        #######then check for one-to-many case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = one_to_many(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx, rec_flag)
        hit_str_count += hit_str_num
        #######then check for many-to-one case##########
        local_accumulative_recall, local_accumulative_precision, global_accumulative_recall, global_accumulative_precision, \
            gt_flag, det_flag, hit_str_num = many_to_one(local_sigma_table, local_tau_table,
                                                         local_accumulative_recall, local_accumulative_precision,
                                                         global_accumulative_recall, global_accumulative_precision,
                                                         gt_flag, det_flag, idx, rec_flag)
        hit_str_count += hit_str_num

    # Guard every ratio against empty inputs (no gts / no detections /
    # zero accumulated recall).
    try:
        recall = global_accumulative_recall / total_num_gt
    except ZeroDivisionError:
        recall = 0

    try:
        precision = global_accumulative_precision / total_num_det
    except ZeroDivisionError:
        precision = 0

    try:
        f_score = 2 * precision * recall / (precision + recall)
    except ZeroDivisionError:
        f_score = 0

    try:
        seqerr = 1 - float(hit_str_count) / global_accumulative_recall
    except ZeroDivisionError:
        seqerr = 1

    try:
        recall_e2e = float(hit_str_count) / total_num_gt
    except ZeroDivisionError:
        recall_e2e = 0

    try:
        precision_e2e = float(hit_str_count) / total_num_det
    except ZeroDivisionError:
        precision_e2e = 0

    try:
        f_score_e2e = 2 * precision_e2e * recall_e2e / (
            precision_e2e + recall_e2e)
    except ZeroDivisionError:
        f_score_e2e = 0

    final = {
        'total_num_gt': total_num_gt,
        'total_num_det': total_num_det,
        'global_accumulative_recall': global_accumulative_recall,
        'hit_str_count': hit_str_count,
        'recall': recall,
        'precision': precision,
        'f_score': f_score,
        'seqerr': seqerr,
        'recall_e2e': recall_e2e,
        'precision_e2e': precision_e2e,
        'f_score_e2e': f_score_e2e
    }
    return final