metrics.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import json
import paddle
import numpy as np
import typing
from collections import defaultdict
from pathlib import Path

from .map_utils import prune_zero_padding, DetectionMAP
from .coco_utils import get_infer_results, cocoapi_eval
from .widerface_utils import face_eval_run
from ppdet.data.source.category import get_categories
from ppdet.modeling.rbox_utils import poly2rbox_np
from ppdet.utils.logger import setup_logger

logger = setup_logger(__name__)

__all__ = [
    'Metric', 'COCOMetric', 'VOCMetric', 'WiderFaceMetric', 'get_infer_results',
    'RBoxMetric', 'SNIPERCOCOMetric'
]
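
# OKS sigmas for keypoint evaluation: COCO (17 keypoints) and CrowdPose
# (14 keypoints); dividing by 10 yields the per-keypoint standard deviations.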
COCO_SIGMAS = np.array([
    .26, .25, .25, .35, .35, .79, .79, .72, .72, .62, .62, 1.07, 1.07, .87,
    .87, .89, .89
]) / 10.0
CROWD_SIGMAS = np.array(
    [.79, .79, .72, .72, .62, .62, 1.07, 1.07, .87, .87, .89, .89, .79,
     .79]) / 10.0


class Metric(paddle.metric.Metric):
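    """Base class for evaluation metrics in ppdet.

    A metric consumes one batch of (inputs, outputs) at a time via
    ``update``, aggregates everything in ``accumulate``, and exposes the
    aggregated numbers through ``log`` and ``get_results``.
    """
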
    def name(self):
        return self.__class__.__name__

    def reset(self):
        pass

    def accumulate(self):
        pass

    # paddle.metric.Metric defines :meth:`update`, :meth:`accumulate` and
    # :meth:`reset`; in ppdet, we also need the following 2 methods:

    # abstract method for logging metric results
    def log(self):
        pass

    # abstract method for getting metric results
    def get_results(self):
        pass


class COCOMetric(Metric):
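    """COCO-style evaluation for bbox / mask / segm / keypoint predictions.

    Args:
        anno_file (str): path to the COCO-format annotation file.
        clsid2catid (dict, optional): class id to COCO category id mapping;
            derived from ``anno_file`` when not given.
        classwise (bool, optional): whether to report per-category AP.
        output_eval (str, optional): directory to save the result json files.
        bias (int, optional): bias passed to ``get_infer_results`` when
            converting raw outputs.
        save_prediction_only (bool, optional): only dump predictions without
            running the COCO evaluation.
        IouType (str, optional): e.g. 'bbox', 'segm', 'keypoints' or
            'keypoints_crowd'; defaults to 'bbox'.

    Usage sketch (illustrative only -- ``model`` and ``eval_loader`` are
    placeholders, not defined in this module):

        metric = COCOMetric(anno_file='annotations/instances_val2017.json')
        for data in eval_loader:
            outs = model(data)
            metric.update(data, outs)
        metric.accumulate()
        metric.log()
        stats = metric.get_results()
    """
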
    def __init__(self, anno_file, **kwargs):
        self.anno_file = anno_file
        self.clsid2catid = kwargs.get('clsid2catid', None)
        if self.clsid2catid is None:
            self.clsid2catid, _ = get_categories('COCO', anno_file)
        self.classwise = kwargs.get('classwise', False)
        self.output_eval = kwargs.get('output_eval', None)
        # TODO: bias should be unified
        self.bias = kwargs.get('bias', 0)
        self.save_prediction_only = kwargs.get('save_prediction_only', False)
        self.iou_type = kwargs.get('IouType', 'bbox')

        if not self.save_prediction_only:
            assert os.path.isfile(anno_file), \
                "anno_file {} not a file".format(anno_file)

        if self.output_eval is not None:
            Path(self.output_eval).mkdir(exist_ok=True)

        self.reset()
    def reset(self):
        # bbox / mask / segm / keypoint results are collected here
        self.results = {'bbox': [], 'mask': [], 'segm': [], 'keypoint': []}
        self.eval_results = {}

    def update(self, inputs, outputs):
        outs = {}
        # outputs Tensor -> numpy.ndarray
        for k, v in outputs.items():
            outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v

        # multi-scale inputs: all inputs have same im_id
        if isinstance(inputs, typing.Sequence):
            im_id = inputs[0]['im_id']
        else:
            im_id = inputs['im_id']
        outs['im_id'] = im_id.numpy() if isinstance(im_id,
                                                    paddle.Tensor) else im_id

        infer_results = get_infer_results(
            outs, self.clsid2catid, bias=self.bias)
        self.results['bbox'] += infer_results[
            'bbox'] if 'bbox' in infer_results else []
        self.results['mask'] += infer_results[
            'mask'] if 'mask' in infer_results else []
        self.results['segm'] += infer_results[
            'segm'] if 'segm' in infer_results else []
        self.results['keypoint'] += infer_results[
            'keypoint'] if 'keypoint' in infer_results else []
    def accumulate(self):
        if len(self.results['bbox']) > 0:
            output = "bbox.json"
            if self.output_eval:
                output = os.path.join(self.output_eval, output)
            with open(output, 'w') as f:
                json.dump(self.results['bbox'], f)
                logger.info('The bbox result is saved to {}.'.format(output))

            if self.save_prediction_only:
                logger.info('The bbox result is saved to {} and the mAP '
                            'will not be evaluated.'.format(output))
            else:
                bbox_stats = cocoapi_eval(
                    output,
                    'bbox',
                    anno_file=self.anno_file,
                    classwise=self.classwise)
                self.eval_results['bbox'] = bbox_stats
                sys.stdout.flush()

        if len(self.results['mask']) > 0:
            output = "mask.json"
            if self.output_eval:
                output = os.path.join(self.output_eval, output)
            with open(output, 'w') as f:
                json.dump(self.results['mask'], f)
                logger.info('The mask result is saved to {}.'.format(output))

            if self.save_prediction_only:
                logger.info('The mask result is saved to {} and the mAP '
                            'will not be evaluated.'.format(output))
            else:
                seg_stats = cocoapi_eval(
                    output,
                    'segm',
                    anno_file=self.anno_file,
                    classwise=self.classwise)
                self.eval_results['mask'] = seg_stats
                sys.stdout.flush()

        if len(self.results['segm']) > 0:
            output = "segm.json"
            if self.output_eval:
                output = os.path.join(self.output_eval, output)
            with open(output, 'w') as f:
                json.dump(self.results['segm'], f)
                logger.info('The segm result is saved to {}.'.format(output))

            if self.save_prediction_only:
                logger.info('The segm result is saved to {} and the mAP '
                            'will not be evaluated.'.format(output))
            else:
                seg_stats = cocoapi_eval(
                    output,
                    'segm',
                    anno_file=self.anno_file,
                    classwise=self.classwise)
                self.eval_results['mask'] = seg_stats
                sys.stdout.flush()

        if len(self.results['keypoint']) > 0:
            output = "keypoint.json"
            if self.output_eval:
                output = os.path.join(self.output_eval, output)
            with open(output, 'w') as f:
                json.dump(self.results['keypoint'], f)
                logger.info('The keypoint result is saved to {}.'.format(
                    output))

            if self.save_prediction_only:
                logger.info('The keypoint result is saved to {} and the mAP '
                            'will not be evaluated.'.format(output))
            else:
                style = 'keypoints'
                use_area = True
                sigmas = COCO_SIGMAS
                if self.iou_type == 'keypoints_crowd':
                    style = 'keypoints_crowd'
                    use_area = False
                    sigmas = CROWD_SIGMAS
                keypoint_stats = cocoapi_eval(
                    output,
                    style,
                    anno_file=self.anno_file,
                    classwise=self.classwise,
                    sigmas=sigmas,
                    use_area=use_area)
                self.eval_results['keypoint'] = keypoint_stats
                sys.stdout.flush()

    def log(self):
        pass

    def get_results(self):
        return self.eval_results


class VOCMetric(Metric):
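    """Pascal VOC style mAP evaluation based on ``DetectionMAP``.

    Args:
        label_list (str): path to the label list file, one class name per line.
        class_num (int): number of classes, 20 by default.
        overlap_thresh (float): IoU threshold for matching, 0.5 by default.
        map_type (str): mAP computation method, '11point' or 'integral'.
        is_bbox_normalized (bool): whether predicted boxes are normalized to
            [0, 1].
        evaluate_difficult (bool): whether to evaluate difficult ground truths.
        classwise (bool): whether to report per-category AP.
        output_eval (str): directory to save the prediction json file.
        save_prediction_only (bool): only dump predictions without evaluating.
    """
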
    def __init__(self,
                 label_list,
                 class_num=20,
                 overlap_thresh=0.5,
                 map_type='11point',
                 is_bbox_normalized=False,
                 evaluate_difficult=False,
                 classwise=False,
                 output_eval=None,
                 save_prediction_only=False):
        assert os.path.isfile(label_list), \
            "label_list {} not a file".format(label_list)
        self.clsid2catid, self.catid2name = get_categories('VOC', label_list)

        self.overlap_thresh = overlap_thresh
        self.map_type = map_type
        self.evaluate_difficult = evaluate_difficult
        self.output_eval = output_eval
        self.save_prediction_only = save_prediction_only
        self.detection_map = DetectionMAP(
            class_num=class_num,
            overlap_thresh=overlap_thresh,
            map_type=map_type,
            is_bbox_normalized=is_bbox_normalized,
            evaluate_difficult=evaluate_difficult,
            catid2name=self.catid2name,
            classwise=classwise)

        self.reset()

    def reset(self):
        self.results = {'bbox': [], 'score': [], 'label': []}
        self.detection_map.reset()

    def update(self, inputs, outputs):
        bbox_np = outputs['bbox'].numpy() if isinstance(
            outputs['bbox'], paddle.Tensor) else outputs['bbox']
        bboxes = bbox_np[:, 2:]
        scores = bbox_np[:, 1]
        labels = bbox_np[:, 0]
        bbox_lengths = outputs['bbox_num'].numpy() if isinstance(
            outputs['bbox_num'], paddle.Tensor) else outputs['bbox_num']

        self.results['bbox'].append(bboxes.tolist())
        self.results['score'].append(scores.tolist())
        self.results['label'].append(labels.tolist())

        if bboxes is None or bboxes.shape == (1, 1):
            return
        if self.save_prediction_only:
            return

        gt_boxes = inputs['gt_bbox']
        gt_labels = inputs['gt_class']
        difficults = inputs['difficult'] if not self.evaluate_difficult \
            else None

        if 'scale_factor' in inputs:
            scale_factor = inputs['scale_factor'].numpy() if isinstance(
                inputs['scale_factor'],
                paddle.Tensor) else inputs['scale_factor']
        else:
            scale_factor = np.ones((gt_boxes.shape[0], 2)).astype('float32')

        bbox_idx = 0
        for i in range(len(gt_boxes)):
            gt_box = gt_boxes[i].numpy() if isinstance(
                gt_boxes[i], paddle.Tensor) else gt_boxes[i]
            h, w = scale_factor[i]
            gt_box = gt_box / np.array([w, h, w, h])
            gt_label = gt_labels[i].numpy() if isinstance(
                gt_labels[i], paddle.Tensor) else gt_labels[i]
            if difficults is not None:
                difficult = difficults[i].numpy() if isinstance(
                    difficults[i], paddle.Tensor) else difficults[i]
            else:
                difficult = None
            bbox_num = bbox_lengths[i]
            bbox = bboxes[bbox_idx:bbox_idx + bbox_num]
            score = scores[bbox_idx:bbox_idx + bbox_num]
            label = labels[bbox_idx:bbox_idx + bbox_num]
            gt_box, gt_label, difficult = prune_zero_padding(gt_box, gt_label,
                                                             difficult)
            self.detection_map.update(bbox, score, label, gt_box, gt_label,
                                      difficult)
            bbox_idx += bbox_num

    def accumulate(self):
        output = "bbox.json"
        if self.output_eval:
            output = os.path.join(self.output_eval, output)
        with open(output, 'w') as f:
            json.dump(self.results, f)
            logger.info('The bbox result is saved to {}.'.format(output))
        if self.save_prediction_only:
            return

        logger.info("Accumulating evaluation results...")
        self.detection_map.accumulate()

    def log(self):
        map_stat = 100. * self.detection_map.get_map()
        logger.info("mAP({:.2f}, {}) = {:.2f}%".format(self.overlap_thresh,
                                                       self.map_type, map_stat))

    def get_results(self):
        return {'bbox': [self.detection_map.get_map()]}


class WiderFaceMetric(Metric):
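    """WIDER FACE evaluation.

    Unlike the other metrics, ``update`` receives the model itself and
    delegates the whole (optionally multi-scale) evaluation over
    ``image_dir`` to ``face_eval_run``.
    """
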
    def __init__(self, image_dir, anno_file, multi_scale=True):
        self.image_dir = image_dir
        self.anno_file = anno_file
        self.multi_scale = multi_scale
        self.clsid2catid, self.catid2name = get_categories('widerface')

    def update(self, model):
        face_eval_run(
            model,
            self.image_dir,
            self.anno_file,
            pred_dir='output/pred',
            eval_mode='widerface',
            multi_scale=self.multi_scale)


class RBoxMetric(Metric):
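    """Rotated-box mAP evaluation.

    Predictions are converted with ``get_infer_results`` and matched against
    the ``gt_poly`` ground truths using the same ``DetectionMAP`` machinery
    as ``VOCMetric``.

    Args:
        anno_file (str): annotation file used to build the category mapping.
        **kwargs: optional ``classwise``, ``output_eval``,
            ``save_prediction_only``, ``overlap_thresh``, ``map_type``,
            ``evaluate_difficult`` and ``imid2path``; see ``__init__`` for
            the defaults.
    """
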
    def __init__(self, anno_file, **kwargs):
        self.anno_file = anno_file
        self.clsid2catid, self.catid2name = get_categories('RBOX', anno_file)
        self.catid2clsid = {v: k for k, v in self.clsid2catid.items()}
        self.classwise = kwargs.get('classwise', False)
        self.output_eval = kwargs.get('output_eval', None)
        self.save_prediction_only = kwargs.get('save_prediction_only', False)
        self.overlap_thresh = kwargs.get('overlap_thresh', 0.5)
        self.map_type = kwargs.get('map_type', '11point')
        self.evaluate_difficult = kwargs.get('evaluate_difficult', False)
        self.imid2path = kwargs.get('imid2path', None)
        class_num = len(self.catid2name)
        self.detection_map = DetectionMAP(
            class_num=class_num,
            overlap_thresh=self.overlap_thresh,
            map_type=self.map_type,
            is_bbox_normalized=False,
            evaluate_difficult=self.evaluate_difficult,
            catid2name=self.catid2name,
            classwise=self.classwise)

        self.reset()

    def reset(self):
        self.results = []
        self.detection_map.reset()

    def update(self, inputs, outputs):
        outs = {}
        # outputs Tensor -> numpy.ndarray
        for k, v in outputs.items():
            outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v

        im_id = inputs['im_id']
        im_id = im_id.numpy() if isinstance(im_id, paddle.Tensor) else im_id
        outs['im_id'] = im_id

        infer_results = get_infer_results(outs, self.clsid2catid)
        infer_results = infer_results['bbox'] if 'bbox' in infer_results else []
        self.results += infer_results
        if self.save_prediction_only:
            return

        gt_boxes = inputs['gt_poly']
        gt_labels = inputs['gt_class']

        if 'scale_factor' in inputs:
            scale_factor = inputs['scale_factor'].numpy() if isinstance(
                inputs['scale_factor'],
                paddle.Tensor) else inputs['scale_factor']
        else:
            scale_factor = np.ones((gt_boxes.shape[0], 2)).astype('float32')

        for i in range(len(gt_boxes)):
            gt_box = gt_boxes[i].numpy() if isinstance(
                gt_boxes[i], paddle.Tensor) else gt_boxes[i]
            h, w = scale_factor[i]
            gt_box = gt_box / np.array([w, h, w, h, w, h, w, h])
            gt_label = gt_labels[i].numpy() if isinstance(
                gt_labels[i], paddle.Tensor) else gt_labels[i]
            gt_box, gt_label, _ = prune_zero_padding(gt_box, gt_label)
            bbox = [
                res['bbox'] for res in infer_results
                if int(res['image_id']) == int(im_id[i])
            ]
            score = [
                res['score'] for res in infer_results
                if int(res['image_id']) == int(im_id[i])
            ]
            label = [
                self.catid2clsid[int(res['category_id'])]
                for res in infer_results
                if int(res['image_id']) == int(im_id[i])
            ]
            self.detection_map.update(bbox, score, label, gt_box, gt_label)

    def save_results(self, results, output_dir, imid2path):
        if imid2path:
            data_dicts = defaultdict(list)
            for result in results:
                image_id = result['image_id']
                data_dicts[image_id].append(result)

            for image_id, image_path in imid2path.items():
                basename = os.path.splitext(os.path.split(image_path)[-1])[0]
                output = os.path.join(output_dir, "{}.txt".format(basename))
                dets = data_dicts.get(image_id, [])
                with open(output, 'w') as f:
                    for det in dets:
                        catid, bbox, score = det['category_id'], det[
                            'bbox'], det['score']
                        bbox_pred = '{} {} '.format(self.catid2name[catid],
                                                    score) + ' '.join(
                                                        [str(e) for e in bbox])
                        f.write(bbox_pred + '\n')
            logger.info('The bbox result is saved to {}.'.format(output_dir))
        else:
            output = os.path.join(output_dir, "bbox.json")
            with open(output, 'w') as f:
                json.dump(results, f)

            logger.info('The bbox result is saved to {}.'.format(output))

    def accumulate(self):
        if self.output_eval:
            self.save_results(self.results, self.output_eval, self.imid2path)

        if not self.save_prediction_only:
            logger.info("Accumulating evaluation results...")
            self.detection_map.accumulate()

    def log(self):
        map_stat = 100. * self.detection_map.get_map()
        logger.info("mAP({:.2f}, {}) = {:.2f}%".format(self.overlap_thresh,
                                                       self.map_type, map_stat))

    def get_results(self):
        return {'bbox': [self.detection_map.get_map()]}


class SNIPERCOCOMetric(COCOMetric):
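    """COCO evaluation for SNIPER-style chip inference.

    ``update`` only buffers per-chip outputs; ``accumulate`` merges them back
    into full-image detections via
    ``dataset.anno_cropper.aggregate_chips_detections`` and then runs the
    standard COCO evaluation from ``COCOMetric``.

    Note: the ``dataset`` keyword argument is required.
    """
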
    def __init__(self, anno_file, **kwargs):
        super(SNIPERCOCOMetric, self).__init__(anno_file, **kwargs)
        self.dataset = kwargs["dataset"]
        self.chip_results = []

    def reset(self):
        # bbox / mask / segm / keypoint results are collected here
        self.results = {'bbox': [], 'mask': [], 'segm': [], 'keypoint': []}
        self.eval_results = {}
        self.chip_results = []

    def update(self, inputs, outputs):
        outs = {}
        # outputs Tensor -> numpy.ndarray
        for k, v in outputs.items():
            outs[k] = v.numpy() if isinstance(v, paddle.Tensor) else v
        im_id = inputs['im_id']
        outs['im_id'] = im_id.numpy() if isinstance(im_id,
                                                    paddle.Tensor) else im_id

        self.chip_results.append(outs)

    def accumulate(self):
        results = self.dataset.anno_cropper.aggregate_chips_detections(
            self.chip_results)
        for outs in results:
            infer_results = get_infer_results(
                outs, self.clsid2catid, bias=self.bias)
            self.results['bbox'] += infer_results[
                'bbox'] if 'bbox' in infer_results else []

        super(SNIPERCOCOMetric, self).accumulate()