# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Merge per-patch rotated-box predictions back onto the original images.

Each input txt file holds the predictions of one image patch.  Patch file
names encode the original image name, the resize rate and the patch offset
(``<image>__<rate>__<x>___<y>``).  Predictions are mapped back to original
image coordinates, de-duplicated per class with polygon NMS and written to
one ``<class_name>.txt`` file per class.
"""

import argparse
import glob
import os
import re
from functools import partial
from multiprocessing import Pool

import numpy as np
from shapely.geometry import Polygon

wordname_15 = [
    'plane', 'baseball-diamond', 'bridge', 'ground-track-field',
    'small-vehicle', 'large-vehicle', 'ship', 'tennis-court',
    'basketball-court', 'storage-tank', 'soccer-ball-field', 'roundabout',
    'harbor', 'swimming-pool', 'helicopter'
]

wordname_16 = wordname_15 + ['container-crane']

wordname_18 = wordname_16 + ['airport', 'helipad']

DATA_CLASSES = {
    'dota10': wordname_15,
    'dota15': wordname_16,
    'dota20': wordname_18
}

# Patch offset part of a patch name, e.g. ``__0___512`` -> x=0, y=512.
_PATCH_OFFSET_RE = re.compile(r'__\d+___\d+')
# Resize rate part of a patch name, e.g. ``__1.0__0___`` -> rate "1.0".
_PATCH_RATE_RE = re.compile(r'__([\d+\.]+)__\d+___')


def rbox_iou(g, p):
    """Compute the IoU of two quadrilaterals.

    Args:
        g: sequence whose first 8 values are x1, y1, ..., x4, y4
        p: sequence whose first 8 values are x1, y1, ..., x4, y4

    Returns:
        intersection over union of the two polygons; 0 if either polygon is
        invalid or the union has zero area
    """
    g = Polygon(np.asarray(g)[:8].reshape((4, 2)))
    p = Polygon(np.asarray(p)[:8].reshape((4, 2)))
    # buffer(0) repairs self-intersecting quadrilaterals where possible.
    g = g.buffer(0)
    p = p.buffer(0)
    if not g.is_valid or not p.is_valid:
        return 0
    # Intersect the repaired geometries directly: buffer(0) may return a
    # geometry that is not a plain Polygon, so it must not be re-wrapped.
    inter = g.intersection(p).area
    union = g.area + p.area - inter
    if union == 0:
        return 0
    return inter / union


def py_cpu_nms_poly_fast(dets, thresh):
    """Polygon NMS with a cheap horizontal-box pre-filter.

    Args:
        dets: array of shape (N, 9); columns 0-7 are the polygon
            coordinates x1, y1, ..., x4, y4 and column 8 is the score
        thresh: nms threshold; a box is suppressed when its polygon IoU
            with a kept box is greater than this value

    Returns:
        list of indices (into dets) of the kept boxes, highest score first
    """
    dets = np.asarray(dets)
    if dets.size == 0:
        return []

    # Axis-aligned bounding box of every polygon.
    obbs = dets[:, 0:-1]
    x1 = np.min(obbs[:, 0::2], axis=1)
    y1 = np.min(obbs[:, 1::2], axis=1)
    x2 = np.max(obbs[:, 0::2], axis=1)
    y2 = np.max(obbs[:, 1::2], axis=1)
    scores = dets[:, 8]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    polys = dets[:, :8]

    order = scores.argsort()[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        rest = order[1:]

        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(0.0, xx2 - xx1)
        h = np.maximum(0.0, yy2 - yy1)
        hbb_inter = w * h
        overlap = hbb_inter / (areas[i] + areas[rest] - hbb_inter)

        # Only polygons whose bounding boxes overlap can intersect, so the
        # expensive polygon IoU is computed for those candidates only.
        for j in np.where(overlap > 0)[0]:
            overlap[j] = rbox_iou(polys[i], polys[rest[j]])

        order = rest[np.where(overlap <= thresh)[0]]
    return keep


def poly2origpoly(poly, x, y, rate):
    """Map patch-local polygon coordinates back to the original image.

    Args:
        poly: flat list x1, y1, x2, y2, ... in patch coordinates
        x: x offset of the patch
        y: y offset of the patch
        rate: resize rate of the patch (number or numeric string)

    Returns:
        flat list of coordinates in original image space
    """
    rate = float(rate)
    origpoly = []
    for px, py in zip(poly[0::2], poly[1::2]):
        origpoly.append(float(px + x) / rate)
        origpoly.append(float(py + y) / rate)
    return origpoly


def nmsbynamedict(nameboxdict, nms, thresh):
    """Apply nms separately to the detections of every image.

    Args:
        nameboxdict: dict mapping image name to a list of detections
        nms: nms function called as ``nms(dets_array, thresh)`` and
            returning the indices to keep
        thresh: nms threshold

    Returns:
        dict mapping image name to the list of kept detections
    """
    nameboxnmsdict = {}
    for imgname, dets in nameboxdict.items():
        keep = nms(np.array(dets), thresh)
        nameboxnmsdict[imgname] = [dets[index] for index in keep]
    return nameboxnmsdict


def _parse_patch_name(subname):
    """Split a patch name into (original image name, x, y, rate)."""
    offsets = _PATCH_OFFSET_RE.findall(subname)
    rates = _PATCH_RATE_RE.findall(subname)
    if not offsets or not rates:
        raise ValueError(
            "patch name '{}' does not match '<image>__<rate>__<x>___<y>'".
            format(subname))
    oriname = subname.split('__')[0]
    x_str, y_str = re.findall(r'\d+', offsets[0])[:2]
    return oriname, int(x_str), int(y_str), rates[0]


def merge_single(output_dir, nms, nms_thresh, pred_class_lst):
    """Merge the patch predictions of one class and write them to a file.

    Args:
        output_dir: directory in which ``<class_name>.txt`` is written
        nms: nms function called as ``nms(dets_array, thresh)``
        nms_thresh: nms threshold
        pred_class_lst: tuple ``(class_name, pred_bbox_list)`` where every
            entry of pred_bbox_list is a line
            ``<patch_name> <score> <x1> <y1> ... <x4> <y4>``

    Raises:
        ValueError: if a patch name does not encode rate and offset
    """
    class_name, pred_bbox_list = pred_class_lst
    nameboxdict = {}
    for line in pred_bbox_list:
        splitline = line.split()
        if not splitline:
            # Tolerate blank lines in the prediction files.
            continue
        oriname, x, y, rate = _parse_patch_name(splitline[0])
        confidence = float(splitline[1])
        poly = list(map(float, splitline[2:]))
        det = poly2origpoly(poly, x, y, rate)
        det.append(confidence)
        nameboxdict.setdefault(oriname, []).append(det)

    nameboxnmsdict = nmsbynamedict(nameboxdict, nms, nms_thresh)

    # write result
    dstname = os.path.join(output_dir, class_name + '.txt')
    with open(dstname, 'w') as f_out:
        for imgname, dets in nameboxnmsdict.items():
            for det in dets:
                confidence = det[-1]
                bbox = det[0:-1]
                outline = imgname + ' ' + str(confidence) + ' ' + ' '.join(
                    map(str, bbox))
                f_out.write(outline + '\n')


def generate_result(pred_txt_dir,
                    output_dir='output',
                    class_names=wordname_15,
                    nms_thresh=0.1):
    """Merge all patch predictions in a directory into per-class files.

    Args:
        pred_txt_dir: dir of pred txt; every file holds lines
            ``<class_name> <score> <x1> <y1> ... <x4> <y4>`` for one patch
        output_dir: dir of output, created if it does not exist
        class_names: class names of data
        nms_thresh: nms threshold used while merging

    Raises:
        KeyError: if a prediction uses a class not in class_names
    """
    # Sorted so that the merged output does not depend on directory order.
    pred_txt_list = sorted(glob.glob("{}/*.txt".format(pred_txt_dir)))

    # step1: summary pred bbox
    pred_classes = {class_name: [] for class_name in class_names}
    for current_txt in pred_txt_list:
        img_id = os.path.splitext(os.path.basename(current_txt))[0]
        with open(current_txt) as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                pred_class = fields[0]
                if pred_class not in pred_classes:
                    raise KeyError(
                        "unknown class '{}' in {}; expected one of {}".format(
                            pred_class, current_txt, list(pred_classes)))
                # Replace the class name by the patch name for merging.
                fields[0] = img_id
                pred_classes[pred_class].append(' '.join(fields))

    pred_classes_lst = []
    for class_name, pred_bboxes in pred_classes.items():
        print('class_name: {}, count: {}'.format(class_name,
                                                 len(pred_bboxes)))
        pred_classes_lst.append((class_name, pred_bboxes))

    # step2: merge
    # merge_single writes into output_dir, so it has to exist.
    os.makedirs(output_dir, exist_ok=True)
    nms = py_cpu_nms_poly_fast
    mergesingle_fn = partial(merge_single, output_dir, nms, nms_thresh)
    # Pool needs at least one worker; the context manager releases the
    # worker processes once all classes are merged.
    with Pool(max(1, len(pred_classes_lst))) as pool:
        pool.map(mergesingle_fn, pred_classes_lst)


def parse_args():
    """Parse the command line arguments of the merge script."""
    parser = argparse.ArgumentParser(description='generate test results')
    parser.add_argument('--pred_txt_dir', type=str, help='path of pred txt dir')
    parser.add_argument(
        '--output_dir', type=str, default='output', help='path of output dir')
    parser.add_argument(
        '--data_type', type=str, default='dota10', help='data type')
    parser.add_argument(
        '--nms_thresh',
        type=float,
        default=0.1,
        help='nms threshold while merging results')

    return parser.parse_args()


if __name__ == '__main__':
    args = parse_args()

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    class_names = DATA_CLASSES[args.data_type]

    # Forward the user supplied threshold; it was parsed but never used.
    generate_result(args.pred_txt_dir, output_dir, class_names,
                    args.nms_thresh)
    print('done!')