generate_result.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import glob
import argparse
from functools import partial
from multiprocessing import Pool

import numpy as np
from shapely.geometry import Polygon

wordname_15 = [
    'plane', 'baseball-diamond', 'bridge', 'ground-track-field',
    'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
    'harbor', 'swimming-pool', 'helicopter'
]
wordname_16 = wordname_15 + ['container-crane']
wordname_18 = wordname_16 + ['airport', 'helipad']

DATA_CLASSES = {
    'dota10': wordname_15,
    'dota15': wordname_16,
    'dota20': wordname_18
}
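
# The three class lists above match the DOTA dataset versions: v1.0 has the
# 15 base classes, v1.5 adds 'container-crane', and v2.0 further adds
# 'airport' and 'helipad'; the dict keys correspond to the --data_type flag
# parsed below.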


def rbox_iou(g, p):
    """
    IoU of two rotated boxes, each given as an 8-value polygon
    (x1, y1, ..., x4, y4); any trailing values (e.g. a score) are ignored.
    """
    g = np.array(g)
    p = np.array(p)
    g = Polygon(g[:8].reshape((4, 2)))
    p = Polygon(p[:8].reshape((4, 2)))
    # buffer(0) repairs slightly self-intersecting polygons where possible
    g = g.buffer(0)
    p = p.buffer(0)
    if not g.is_valid or not p.is_valid:
        return 0
    inter = g.intersection(p).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    else:
        return inter / union
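
# Quick sanity check (hypothetical values): two axis-aligned 10x10 squares
# offset by (5, 5) overlap in a 5x5 region, so the IoU is 25 / 175 ~= 0.143.
#   rbox_iou([0, 0, 10, 0, 10, 10, 0, 10],
#            [5, 5, 15, 5, 15, 15, 5, 15])  # -> ~0.1429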


def py_cpu_nms_poly_fast(dets, thresh):
    """
    NMS for rotated boxes: a cheap horizontal-bbox overlap test prefilters
    candidate pairs, then the exact polygon IoU is computed only for those.

    Args:
        dets (np.ndarray): pred results, one row per detection as
            [x1, y1, x2, y2, x3, y3, x4, y4, score]
        thresh (float): nms threshold

    Returns: indices of the detections to keep
    """
    obbs = dets[:, 0:-1]
    # horizontal bounding box of each polygon
    x1 = np.min(obbs[:, 0::2], axis=1)
    y1 = np.min(obbs[:, 1::2], axis=1)
    x2 = np.max(obbs[:, 0::2], axis=1)
    y2 = np.max(obbs[:, 1::2], axis=1)
    scores = dets[:, 8]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    polys = []
    for i in range(len(dets)):
        tm_polygon = [
            dets[i][0], dets[i][1], dets[i][2], dets[i][3], dets[i][4],
            dets[i][5], dets[i][6], dets[i][7]
        ]
        polys.append(tm_polygon)
    polys = np.array(polys)

    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # horizontal IoU against all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        hbb_inter = w * h
        hbb_ovr = hbb_inter / (areas[i] + areas[order[1:]] - hbb_inter)
        # refine with the exact rotated IoU only where the hbb overlap is non-zero
        h_inds = np.where(hbb_ovr > 0)[0]
        tmp_order = order[h_inds + 1]
        for j in range(tmp_order.size):
            iou = rbox_iou(polys[i], polys[tmp_order[j]])
            hbb_ovr[h_inds[j]] = iou
        inds = np.where(hbb_ovr <= thresh)[0]
        order = order[inds + 1]
    return keep
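
# Minimal usage sketch with made-up boxes: two heavily overlapping squares,
# the lower-scoring one is suppressed and only index 0 is kept.
#   dets = np.array([[0, 0, 10, 0, 10, 10, 0, 10, 0.9],
#                    [1, 1, 11, 1, 11, 11, 1, 11, 0.8]])
#   py_cpu_nms_poly_fast(dets, thresh=0.1)  # -> [0]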


def poly2origpoly(poly, x, y, rate):
    """
    Map a polygon predicted on an image patch back to the original image:
    add the patch offset (x, y), then undo the resize ratio `rate`.
    """
    origpoly = []
    for i in range(int(len(poly) / 2)):
        tmp_x = float(poly[i * 2] + x) / float(rate)
        tmp_y = float(poly[i * 2 + 1] + y) / float(rate)
        origpoly.append(tmp_x)
        origpoly.append(tmp_y)
    return origpoly
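
# Example (hypothetical numbers): a point at (10, 20) in a patch cropped at
# offset (100, 200) from an image resized by rate 0.5 maps back to (220, 440).
#   poly2origpoly([10.0, 20.0], x=100, y=200, rate='0.5')  # -> [220.0, 440.0]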


def nmsbynamedict(nameboxdict, nms, thresh):
    """
    Apply `nms` to the detections of each original image.

    Args:
        nameboxdict (dict): original image name -> list of detections
        nms (callable): nms function to apply
        thresh (float): nms threshold

    Returns: dict with the same keys, keeping only the surviving detections
    """
    nameboxnmsdict = {x: [] for x in nameboxdict}
    for imgname in nameboxdict:
        keep = nms(np.array(nameboxdict[imgname]), thresh)
        outdets = []
        for index in keep:
            outdets.append(nameboxdict[imgname][index])
        nameboxnmsdict[imgname] = outdets
    return nameboxnmsdict
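
# The regexes in merge_single below assume the DOTA_devkit patch naming
# convention: each prediction line starts with a patch id of the form
# '{origname}__{rate}__{x}___{y}', followed by a score and 8 polygon
# coordinates, e.g. (made-up values):
#   P0706__1__0___330 0.95 10 20 60 20 60 70 10 70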


def merge_single(output_dir, nms, nms_thresh, pred_class_lst):
    """
    Merge the patch-level predictions of one class, map them back to the
    original images and run per-image NMS.

    Args:
        output_dir: output dir
        nms: nms function
        nms_thresh: nms threshold
        pred_class_lst: tuple of (class_name, list of prediction lines)
    """
    class_name, pred_bbox_list = pred_class_lst
    nameboxdict = {}
    for line in pred_bbox_list:
        splitline = line.split(' ')
        subname = splitline[0]
        splitname = subname.split('__')
        oriname = splitname[0]
        # patch offset (x, y) inside the original image
        pattern1 = re.compile(r'__\d+___\d+')
        x_y = re.findall(pattern1, subname)
        x_y_2 = re.findall(r'\d+', x_y[0])
        x, y = int(x_y_2[0]), int(x_y_2[1])
        # resize rate used when the patch was generated
        pattern2 = re.compile(r'__([\d+\.]+)__\d+___')
        rate = re.findall(pattern2, subname)[0]

        confidence = splitline[1]
        poly = list(map(float, splitline[2:]))
        origpoly = poly2origpoly(poly, x, y, rate)
        det = origpoly
        det.append(confidence)
        det = list(map(float, det))
        if oriname not in nameboxdict:
            nameboxdict[oriname] = []
        nameboxdict[oriname].append(det)

    nameboxnmsdict = nmsbynamedict(nameboxdict, nms, nms_thresh)

    # write result
    dstname = os.path.join(output_dir, class_name + '.txt')
    with open(dstname, 'w') as f_out:
        for imgname in nameboxnmsdict:
            for det in nameboxnmsdict[imgname]:
                confidence = det[-1]
                bbox = det[0:-1]
                outline = imgname + ' ' + str(confidence) + ' ' + ' '.join(
                    map(str, bbox))
                f_out.write(outline + '\n')


def generate_result(pred_txt_dir,
                    output_dir='output',
                    class_names=wordname_15,
                    nms_thresh=0.1):
    """
    Args:
        pred_txt_dir: dir of pred txt
        output_dir: dir of output
        class_names: class names of data
        nms_thresh: nms threshold used while merging results
    """
    pred_txt_list = glob.glob("{}/*.txt".format(pred_txt_dir))

    # step1: group pred bboxes by class
    pred_classes = {}
    for class_name in class_names:
        pred_classes[class_name] = []

    for current_txt in pred_txt_list:
        img_id = os.path.split(current_txt)[1]
        img_id = img_id.split('.txt')[0]
        with open(current_txt) as f:
            res = f.readlines()
        for item in res:
            item = item.split(' ')
            pred_class = item[0]
            item[0] = img_id
            pred_bbox = ' '.join(item)
            pred_classes[pred_class].append(pred_bbox)

    pred_classes_lst = []
    for class_name in pred_classes.keys():
        print('class_name: {}, count: {}'.format(class_name,
                                                 len(pred_classes[class_name])))
        pred_classes_lst.append((class_name, pred_classes[class_name]))

    # step2: merge each class in parallel, one worker per class
    pool = Pool(len(class_names))
    nms = py_cpu_nms_poly_fast
    mergesingle_fn = partial(merge_single, output_dir, nms, nms_thresh)
    pool.map(mergesingle_fn, pred_classes_lst)
    pool.close()
    pool.join()
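
# Programmatic usage sketch (the directory names here are hypothetical):
#   generate_result('patch_pred_txts', 'output', DATA_CLASSES['dota10'],
#                   nms_thresh=0.1)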


def parse_args():
    parser = argparse.ArgumentParser(description='generate test results')
    parser.add_argument('--pred_txt_dir', type=str, help='path of pred txt dir')
    parser.add_argument(
        '--output_dir', type=str, default='output', help='path of output dir')
    parser.add_argument(
        '--data_type', type=str, default='dota10', help='data type')
    parser.add_argument(
        '--nms_thresh',
        type=float,
        default=0.1,
        help='nms threshold while merging results')
    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()
    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    class_names = DATA_CLASSES[args.data_type]
    generate_result(args.pred_txt_dir, output_dir, class_names,
                    args.nms_thresh)
    print('done!')
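
# Command-line usage sketch (the pred dir path is hypothetical):
#   python generate_result.py --pred_txt_dir ./patch_pred_txts \
#       --output_dir output --data_type dota10 --nms_thresh 0.1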