# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import re
import glob
import argparse
from functools import partial
from multiprocessing import Pool

import numpy as np
from shapely.geometry import Polygon
# Category names for DOTA-v1.0; v1.5 adds container-crane, v2.0 adds airport and helipad.
wordname_15 = [
    'plane', 'baseball-diamond', 'bridge', 'ground-track-field',
    'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
    'harbor', 'swimming-pool', 'helicopter'
]
wordname_16 = wordname_15 + ['container-crane']
wordname_18 = wordname_16 + ['airport', 'helipad']

DATA_CLASSES = {
    'dota10': wordname_15,
    'dota15': wordname_16,
    'dota20': wordname_18
}
def rbox_iou(g, p):
    """
    IoU of two rotated boxes, each given as 8 polygon coords (x1, y1, ..., x4, y4).
    """
    g = Polygon(np.array(g)[:8].reshape((4, 2)))
    p = Polygon(np.array(p)[:8].reshape((4, 2)))
    # buffer(0) repairs slightly self-intersecting polygons
    g = g.buffer(0)
    p = p.buffer(0)
    if not g.is_valid or not p.is_valid:
        return 0
    inter = g.intersection(p).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    return inter / union
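# Illustrative check of rbox_iou (values chosen for demonstration only, not
# dataset output): two unit squares offset by 0.5 along x overlap in a
# 0.5 x 1 strip, so the expected IoU is 0.5 / 1.5 ~= 0.333.
#   rbox_iou([0, 0, 1, 0, 1, 1, 0, 1],
#            [0.5, 0, 1.5, 0, 1.5, 1, 0.5, 1])  # -> ~0.3333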
def py_cpu_nms_poly_fast(dets, thresh):
    """
    Poly NMS with a horizontal-bbox prefilter: rotated IoU is only computed
    for pairs whose axis-aligned boxes actually overlap.
    Args:
        dets (np.ndarray): [N, 9] array of 8 polygon coords + score
        thresh (float): nms threshold
    Returns: indices of kept detections
    """
    obbs = dets[:, 0:-1]
    # axis-aligned bounding boxes of the polygons
    x1 = np.min(obbs[:, 0::2], axis=1)
    y1 = np.min(obbs[:, 1::2], axis=1)
    x2 = np.max(obbs[:, 0::2], axis=1)
    y2 = np.max(obbs[:, 1::2], axis=1)
    scores = dets[:, 8]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)

    polys = np.array([dets[i][:8].tolist() for i in range(len(dets))])

    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # horizontal-bbox IoU against all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        hbb_inter = w * h
        hbb_ovr = hbb_inter / (areas[i] + areas[order[1:]] - hbb_inter)
        # refine with rotated IoU only where the horizontal boxes overlap
        h_inds = np.where(hbb_ovr > 0)[0]
        tmp_order = order[h_inds + 1]
        for j in range(tmp_order.size):
            iou = rbox_iou(polys[i], polys[tmp_order[j]])
            hbb_ovr[h_inds[j]] = iou
        inds = np.where(hbb_ovr <= thresh)[0]
        order = order[inds + 1]
    return keep
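# Small worked example (toy boxes, not dataset output): box 1 duplicates box 0
# with a lower score and is suppressed, box 2 does not overlap and is kept.
#   dets = np.array([
#       [0, 0, 1, 0, 1, 1, 0, 1, 0.9],   # box 0
#       [0, 0, 1, 0, 1, 1, 0, 1, 0.8],   # duplicate of box 0, lower score
#       [5, 5, 6, 5, 6, 6, 5, 6, 0.7],   # far-away box
#   ])
#   py_cpu_nms_poly_fast(dets, 0.1)  # -> [0, 2]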
def poly2origpoly(poly, x, y, rate):
    """Map polygon coords from a cropped/scaled sub-image back to the original image."""
    origpoly = []
    for i in range(int(len(poly) / 2)):
        tmp_x = float(poly[i * 2] + x) / float(rate)
        tmp_y = float(poly[i * 2 + 1] + y) / float(rate)
        origpoly.append(tmp_x)
        origpoly.append(tmp_y)
    return origpoly
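# Example mapping (placeholder numbers): a point predicted at (10, 20) in a
# sub-image cropped at offset (x=100, y=200) and scaled by rate=0.5 maps back
# to ((10 + 100) / 0.5, (20 + 200) / 0.5) = (220, 440) in the original image.
#   poly2origpoly([10, 20, 30, 40], 100, 200, 0.5)  # -> [220.0, 440.0, 260.0, 480.0]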
def nmsbynamedict(nameboxdict, nms, thresh):
    """
    Run NMS independently for every original image.
    Args:
        nameboxdict (dict): image name -> list of detections
        nms (callable): nms function
        thresh (float): nms threshold
    Returns: nms result as dict
    """
    nameboxnmsdict = {x: [] for x in nameboxdict}
    for imgname in nameboxdict:
        keep = nms(np.array(nameboxdict[imgname]), thresh)
        outdets = []
        for index in keep:
            outdets.append(nameboxdict[imgname][index])
        nameboxnmsdict[imgname] = outdets
    return nameboxnmsdict
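# Usage sketch (toy input): detections are grouped per original image, so NMS
# never suppresses boxes across different images.
#   nmsbynamedict({'P0001': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9],
#                            [0, 0, 1, 0, 1, 1, 0, 1, 0.8]]},
#                 py_cpu_nms_poly_fast, 0.1)
#   # -> {'P0001': [[0, 0, 1, 0, 1, 1, 0, 1, 0.9]]}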
def merge_single(output_dir, nms, nms_thresh, pred_class_lst):
    """
    Merge sub-image predictions of one class back to the original images and
    write the per-class result file.
    Args:
        output_dir (str): output dir
        nms (callable): nms function
        nms_thresh (float): nms threshold
        pred_class_lst (tuple): (class_name, list of prediction lines)
    """
    class_name, pred_bbox_list = pred_class_lst
    nameboxdict = {}
    for line in pred_bbox_list:
        splitline = line.split(' ')
        subname = splitline[0]
        # sub-image names have the shape '<origname>__<rate>__<x>___<y>'
        splitname = subname.split('__')
        oriname = splitname[0]
        pattern1 = re.compile(r'__\d+___\d+')
        x_y = re.findall(pattern1, subname)
        x_y_2 = re.findall(r'\d+', x_y[0])
        x, y = int(x_y_2[0]), int(x_y_2[1])
        pattern2 = re.compile(r'__([\d+\.]+)__\d+___')
        rate = re.findall(pattern2, subname)[0]
        confidence = splitline[1]
        poly = list(map(float, splitline[2:]))
        origpoly = poly2origpoly(poly, x, y, rate)
        det = origpoly
        det.append(confidence)
        det = list(map(float, det))
        if oriname not in nameboxdict:
            nameboxdict[oriname] = []
        nameboxdict[oriname].append(det)
    nameboxnmsdict = nmsbynamedict(nameboxdict, nms, nms_thresh)
    # write result
    dstname = os.path.join(output_dir, class_name + '.txt')
    with open(dstname, 'w') as f_out:
        for imgname in nameboxnmsdict:
            for det in nameboxnmsdict[imgname]:
                confidence = det[-1]
                bbox = det[0:-1]
                outline = imgname + ' ' + str(confidence) + ' ' + ' '.join(
                    map(str, bbox))
                f_out.write(outline + '\n')
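# Name-parsing example, following the sub-image naming shape the regexes above
# expect ('<origname>__<rate>__<x>___<y>'; the concrete name below is made up):
#   subname = 'P0706__0.5__1024___2048'
#   -> oriname = 'P0706', rate = '0.5', x = 1024, y = 2048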
def generate_result(pred_txt_dir,
                    output_dir='output',
                    class_names=wordname_15,
                    nms_thresh=0.1):
    """
    pred_txt_dir: dir of per-image pred txt files
    output_dir: dir of output
    class_names: class names of data
    nms_thresh: nms threshold used while merging results
    """
    pred_txt_list = glob.glob("{}/*.txt".format(pred_txt_dir))

    # step1: group pred bboxes by class
    pred_classes = {}
    for class_name in class_names:
        pred_classes[class_name] = []

    for current_txt in pred_txt_list:
        img_id = os.path.split(current_txt)[1]
        img_id = img_id.split('.txt')[0]
        with open(current_txt) as f:
            res = f.readlines()
        for item in res:
            item = item.split(' ')
            pred_class = item[0]
            item[0] = img_id
            pred_bbox = ' '.join(item)
            pred_classes[pred_class].append(pred_bbox)

    pred_classes_lst = []
    for class_name in pred_classes.keys():
        print('class_name: {}, count: {}'.format(class_name,
                                                 len(pred_classes[class_name])))
        pred_classes_lst.append((class_name, pred_classes[class_name]))

    # step2: merge each class in parallel
    pool = Pool(len(class_names))
    nms = py_cpu_nms_poly_fast
    mergesingle_fn = partial(merge_single, output_dir, nms, nms_thresh)
    pool.map(mergesingle_fn, pred_classes_lst)
    pool.close()
    pool.join()
def parse_args():
    parser = argparse.ArgumentParser(description='generate test results')
    parser.add_argument(
        '--pred_txt_dir', type=str, help='path of pred txt dir')
    parser.add_argument(
        '--output_dir', type=str, default='output', help='path of output dir')
    parser.add_argument(
        '--data_type', type=str, default='dota10', help='data type')
    parser.add_argument(
        '--nms_thresh',
        type=float,
        default=0.1,
        help='nms threshold while merging results')
    return parser.parse_args()
if __name__ == '__main__':
    args = parse_args()
    output_dir = args.output_dir
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)
    class_names = DATA_CLASSES[args.data_type]
    generate_result(args.pred_txt_dir, output_dir, class_names, args.nms_thresh)
    print('done!')
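# Example invocation (flag names come from parse_args above; the script path
# and input/output directories are placeholders):
#   python generate_result.py --pred_txt_dir ./pred_txt --output_dir ./output \
#       --data_type dota10 --nms_thresh 0.1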