target_layer.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import sys
  15. import paddle
  16. from ppdet.core.workspace import register, serializable
  17. from .target import rpn_anchor_target, generate_proposal_target, generate_mask_target, libra_generate_proposal_target
  18. import numpy as np
  19. @register
  20. @serializable
  21. class RPNTargetAssign(object):
  22. __shared__ = ['assign_on_cpu']
  23. """
  24. RPN targets assignment module
  25. The assignment consists of three steps:
  26. 1. Match anchor and ground-truth box, label the anchor with foreground
  27. or background sample
  28. 2. Sample anchors to keep the properly ratio between foreground and
  29. background
  30. 3. Generate the targets for classification and regression branch
  31. Args:
  32. batch_size_per_im (int): Total number of RPN samples per image.
  33. default 256
  34. fg_fraction (float): Fraction of anchors that is labeled
  35. foreground, default 0.5
  36. positive_overlap (float): Minimum overlap required between an anchor
  37. and ground-truth box for the (anchor, gt box) pair to be
  38. a foreground sample. default 0.7
  39. negative_overlap (float): Maximum overlap allowed between an anchor
  40. and ground-truth box for the (anchor, gt box) pair to be
  41. a background sample. default 0.3
  42. ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
  43. if the value is larger than zero.
  44. use_random (bool): Use random sampling to choose foreground and
  45. background boxes, default true.
  46. assign_on_cpu (bool): In case the number of gt box is too large,
  47. compute IoU on CPU, default false.
  48. """
  49. def __init__(self,
  50. batch_size_per_im=256,
  51. fg_fraction=0.5,
  52. positive_overlap=0.7,
  53. negative_overlap=0.3,
  54. ignore_thresh=-1.,
  55. use_random=True,
  56. assign_on_cpu=False):
  57. super(RPNTargetAssign, self).__init__()
  58. self.batch_size_per_im = batch_size_per_im
  59. self.fg_fraction = fg_fraction
  60. self.positive_overlap = positive_overlap
  61. self.negative_overlap = negative_overlap
  62. self.ignore_thresh = ignore_thresh
  63. self.use_random = use_random
  64. self.assign_on_cpu = assign_on_cpu
  65. def __call__(self, inputs, anchors):
  66. """
  67. inputs: ground-truth instances.
  68. anchor_box (Tensor): [num_anchors, 4], num_anchors are all anchors in all feature maps.
  69. """
  70. gt_boxes = inputs['gt_bbox']
  71. is_crowd = inputs.get('is_crowd', None)
  72. batch_size = len(gt_boxes)
  73. tgt_labels, tgt_bboxes, tgt_deltas = rpn_anchor_target(
  74. anchors,
  75. gt_boxes,
  76. self.batch_size_per_im,
  77. self.positive_overlap,
  78. self.negative_overlap,
  79. self.fg_fraction,
  80. self.use_random,
  81. batch_size,
  82. self.ignore_thresh,
  83. is_crowd,
  84. assign_on_cpu=self.assign_on_cpu)
  85. norm = self.batch_size_per_im * batch_size
  86. return tgt_labels, tgt_bboxes, tgt_deltas, norm
  87. @register
  88. class BBoxAssigner(object):
  89. __shared__ = ['num_classes', 'assign_on_cpu']
  90. """
  91. RCNN targets assignment module
  92. The assignment consists of three steps:
  93. 1. Match RoIs and ground-truth box, label the RoIs with foreground
  94. or background sample
  95. 2. Sample anchors to keep the properly ratio between foreground and
  96. background
  97. 3. Generate the targets for classification and regression branch
  98. Args:
  99. batch_size_per_im (int): Total number of RoIs per image.
  100. default 512
  101. fg_fraction (float): Fraction of RoIs that is labeled
  102. foreground, default 0.25
  103. fg_thresh (float): Minimum overlap required between a RoI
  104. and ground-truth box for the (roi, gt box) pair to be
  105. a foreground sample. default 0.5
  106. bg_thresh (float): Maximum overlap allowed between a RoI
  107. and ground-truth box for the (roi, gt box) pair to be
  108. a background sample. default 0.5
  109. ignore_thresh(float): Threshold for ignoring the is_crowd ground-truth
  110. if the value is larger than zero.
  111. use_random (bool): Use random sampling to choose foreground and
  112. background boxes, default true
  113. cascade_iou (list[iou]): The list of overlap to select foreground and
  114. background of each stage, which is only used In Cascade RCNN.
  115. num_classes (int): The number of class.
  116. assign_on_cpu (bool): In case the number of gt box is too large,
  117. compute IoU on CPU, default false.
  118. """
  119. def __init__(self,
  120. batch_size_per_im=512,
  121. fg_fraction=.25,
  122. fg_thresh=.5,
  123. bg_thresh=.5,
  124. ignore_thresh=-1.,
  125. use_random=True,
  126. cascade_iou=[0.5, 0.6, 0.7],
  127. num_classes=80,
  128. assign_on_cpu=False):
  129. super(BBoxAssigner, self).__init__()
  130. self.batch_size_per_im = batch_size_per_im
  131. self.fg_fraction = fg_fraction
  132. self.fg_thresh = fg_thresh
  133. self.bg_thresh = bg_thresh
  134. self.ignore_thresh = ignore_thresh
  135. self.use_random = use_random
  136. self.cascade_iou = cascade_iou
  137. self.num_classes = num_classes
  138. self.assign_on_cpu = assign_on_cpu
  139. def __call__(self,
  140. rpn_rois,
  141. rpn_rois_num,
  142. inputs,
  143. stage=0,
  144. is_cascade=False,
  145. add_gt_as_proposals=True):
  146. gt_classes = inputs['gt_class']
  147. gt_boxes = inputs['gt_bbox']
  148. is_crowd = inputs.get('is_crowd', None)
  149. # rois, tgt_labels, tgt_bboxes, tgt_gt_inds
  150. # new_rois_num
  151. outs = generate_proposal_target(
  152. rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
  153. self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
  154. self.ignore_thresh, is_crowd, self.use_random, is_cascade,
  155. self.cascade_iou[stage], self.assign_on_cpu, add_gt_as_proposals)
  156. rois = outs[0]
  157. rois_num = outs[-1]
  158. # tgt_labels, tgt_bboxes, tgt_gt_inds
  159. targets = outs[1:4]
  160. return rois, rois_num, targets
  161. @register
  162. class BBoxLibraAssigner(object):
  163. __shared__ = ['num_classes']
  164. """
  165. Libra-RCNN targets assignment module
  166. The assignment consists of three steps:
  167. 1. Match RoIs and ground-truth box, label the RoIs with foreground
  168. or background sample
  169. 2. Sample anchors to keep the properly ratio between foreground and
  170. background
  171. 3. Generate the targets for classification and regression branch
  172. Args:
  173. batch_size_per_im (int): Total number of RoIs per image.
  174. default 512
  175. fg_fraction (float): Fraction of RoIs that is labeled
  176. foreground, default 0.25
  177. fg_thresh (float): Minimum overlap required between a RoI
  178. and ground-truth box for the (roi, gt box) pair to be
  179. a foreground sample. default 0.5
  180. bg_thresh (float): Maximum overlap allowed between a RoI
  181. and ground-truth box for the (roi, gt box) pair to be
  182. a background sample. default 0.5
  183. use_random (bool): Use random sampling to choose foreground and
  184. background boxes, default true
  185. cascade_iou (list[iou]): The list of overlap to select foreground and
  186. background of each stage, which is only used In Cascade RCNN.
  187. num_classes (int): The number of class.
  188. num_bins (int): The number of libra_sample.
  189. """
  190. def __init__(self,
  191. batch_size_per_im=512,
  192. fg_fraction=.25,
  193. fg_thresh=.5,
  194. bg_thresh=.5,
  195. use_random=True,
  196. cascade_iou=[0.5, 0.6, 0.7],
  197. num_classes=80,
  198. num_bins=3):
  199. super(BBoxLibraAssigner, self).__init__()
  200. self.batch_size_per_im = batch_size_per_im
  201. self.fg_fraction = fg_fraction
  202. self.fg_thresh = fg_thresh
  203. self.bg_thresh = bg_thresh
  204. self.use_random = use_random
  205. self.cascade_iou = cascade_iou
  206. self.num_classes = num_classes
  207. self.num_bins = num_bins
  208. def __call__(self,
  209. rpn_rois,
  210. rpn_rois_num,
  211. inputs,
  212. stage=0,
  213. is_cascade=False):
  214. gt_classes = inputs['gt_class']
  215. gt_boxes = inputs['gt_bbox']
  216. # rois, tgt_labels, tgt_bboxes, tgt_gt_inds
  217. outs = libra_generate_proposal_target(
  218. rpn_rois, gt_classes, gt_boxes, self.batch_size_per_im,
  219. self.fg_fraction, self.fg_thresh, self.bg_thresh, self.num_classes,
  220. self.use_random, is_cascade, self.cascade_iou[stage], self.num_bins)
  221. rois = outs[0]
  222. rois_num = outs[-1]
  223. # tgt_labels, tgt_bboxes, tgt_gt_inds
  224. targets = outs[1:4]
  225. return rois, rois_num, targets
  226. @register
  227. @serializable
  228. class MaskAssigner(object):
  229. __shared__ = ['num_classes', 'mask_resolution']
  230. """
  231. Mask targets assignment module
  232. The assignment consists of three steps:
  233. 1. Select RoIs labels with foreground.
  234. 2. Encode the RoIs and corresponding gt polygons to generate
  235. mask target
  236. Args:
  237. num_classes (int): The number of class
  238. mask_resolution (int): The resolution of mask target, default 14
  239. """
  240. def __init__(self, num_classes=80, mask_resolution=14):
  241. super(MaskAssigner, self).__init__()
  242. self.num_classes = num_classes
  243. self.mask_resolution = mask_resolution
  244. def __call__(self, rois, tgt_labels, tgt_gt_inds, inputs):
  245. gt_segms = inputs['gt_poly']
  246. outs = generate_mask_target(gt_segms, rois, tgt_labels, tgt_gt_inds,
  247. self.num_classes, self.mask_resolution)
  248. # mask_rois, mask_rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights
  249. return outs
  250. @register
  251. class RBoxAssigner(object):
  252. """
  253. assigner of rbox
  254. Args:
  255. pos_iou_thr (float): threshold of pos samples
  256. neg_iou_thr (float): threshold of neg samples
  257. min_iou_thr (float): the min threshold of samples
  258. ignore_iof_thr (int): the ignored threshold
  259. """
  260. def __init__(self,
  261. pos_iou_thr=0.5,
  262. neg_iou_thr=0.4,
  263. min_iou_thr=0.0,
  264. ignore_iof_thr=-2):
  265. super(RBoxAssigner, self).__init__()
  266. self.pos_iou_thr = pos_iou_thr
  267. self.neg_iou_thr = neg_iou_thr
  268. self.min_iou_thr = min_iou_thr
  269. self.ignore_iof_thr = ignore_iof_thr
  270. def anchor_valid(self, anchors):
  271. """
  272. Args:
  273. anchor: M x 4
  274. Returns:
  275. """
  276. if anchors.ndim == 3:
  277. anchors = anchors.reshape(-1, anchors.shape[-1])
  278. assert anchors.ndim == 2
  279. anchor_num = anchors.shape[0]
  280. anchor_valid = np.ones((anchor_num), np.int32)
  281. anchor_inds = np.arange(anchor_num)
  282. return anchor_inds
  283. def rbox2delta(self,
  284. proposals,
  285. gt,
  286. means=[0, 0, 0, 0, 0],
  287. stds=[1, 1, 1, 1, 1]):
  288. """
  289. Args:
  290. proposals: tensor [N, 5]
  291. gt: gt [N, 5]
  292. means: means [5]
  293. stds: stds [5]
  294. Returns:
  295. """
  296. proposals = proposals.astype(np.float64)
  297. PI = np.pi
  298. gt_widths = gt[..., 2]
  299. gt_heights = gt[..., 3]
  300. gt_angle = gt[..., 4]
  301. proposals_widths = proposals[..., 2]
  302. proposals_heights = proposals[..., 3]
  303. proposals_angle = proposals[..., 4]
  304. coord = gt[..., 0:2] - proposals[..., 0:2]
  305. dx = (np.cos(proposals[..., 4]) * coord[..., 0] +
  306. np.sin(proposals[..., 4]) * coord[..., 1]) / proposals_widths
  307. dy = (-np.sin(proposals[..., 4]) * coord[..., 0] +
  308. np.cos(proposals[..., 4]) * coord[..., 1]) / proposals_heights
  309. dw = np.log(gt_widths / proposals_widths)
  310. dh = np.log(gt_heights / proposals_heights)
  311. da = (gt_angle - proposals_angle)
  312. da = (da + PI / 4) % PI - PI / 4
  313. da /= PI
  314. deltas = np.stack([dx, dy, dw, dh, da], axis=-1)
  315. means = np.array(means, dtype=deltas.dtype)
  316. stds = np.array(stds, dtype=deltas.dtype)
  317. deltas = (deltas - means) / stds
  318. deltas = deltas.astype(np.float32)
  319. return deltas
  320. def assign_anchor(self,
  321. anchors,
  322. gt_bboxes,
  323. gt_labels,
  324. pos_iou_thr,
  325. neg_iou_thr,
  326. min_iou_thr=0.0,
  327. ignore_iof_thr=-2):
  328. assert anchors.shape[1] == 4 or anchors.shape[1] == 5
  329. assert gt_bboxes.shape[1] == 4 or gt_bboxes.shape[1] == 5
  330. anchors_xc_yc = anchors
  331. gt_bboxes_xc_yc = gt_bboxes
  332. # calc rbox iou
  333. anchors_xc_yc = anchors_xc_yc.astype(np.float32)
  334. gt_bboxes_xc_yc = gt_bboxes_xc_yc.astype(np.float32)
  335. anchors_xc_yc = paddle.to_tensor(anchors_xc_yc)
  336. gt_bboxes_xc_yc = paddle.to_tensor(gt_bboxes_xc_yc)
  337. try:
  338. from ext_op import rbox_iou
  339. except Exception as e:
  340. print("import custom_ops error, try install ext_op " \
  341. "following ppdet/ext_op/README.md", e)
  342. sys.stdout.flush()
  343. sys.exit(-1)
  344. iou = rbox_iou(gt_bboxes_xc_yc, anchors_xc_yc)
  345. iou = iou.numpy()
  346. iou = iou.T
  347. # every gt's anchor's index
  348. gt_bbox_anchor_inds = iou.argmax(axis=0)
  349. gt_bbox_anchor_iou = iou[gt_bbox_anchor_inds, np.arange(iou.shape[1])]
  350. gt_bbox_anchor_iou_inds = np.where(iou == gt_bbox_anchor_iou)[0]
  351. # every anchor's gt bbox's index
  352. anchor_gt_bbox_inds = iou.argmax(axis=1)
  353. anchor_gt_bbox_iou = iou[np.arange(iou.shape[0]), anchor_gt_bbox_inds]
  354. # (1) set labels=-2 as default
  355. labels = np.ones((iou.shape[0], ), dtype=np.int32) * ignore_iof_thr
  356. # (2) assign ignore
  357. labels[anchor_gt_bbox_iou < min_iou_thr] = ignore_iof_thr
  358. # (3) assign neg_ids -1
  359. assign_neg_ids1 = anchor_gt_bbox_iou >= min_iou_thr
  360. assign_neg_ids2 = anchor_gt_bbox_iou < neg_iou_thr
  361. assign_neg_ids = np.logical_and(assign_neg_ids1, assign_neg_ids2)
  362. labels[assign_neg_ids] = -1
  363. # anchor_gt_bbox_iou_inds
  364. # (4) assign max_iou as pos_ids >=0
  365. anchor_gt_bbox_iou_inds = anchor_gt_bbox_inds[gt_bbox_anchor_iou_inds]
  366. # gt_bbox_anchor_iou_inds = np.logical_and(gt_bbox_anchor_iou_inds, anchor_gt_bbox_iou >= min_iou_thr)
  367. labels[gt_bbox_anchor_iou_inds] = gt_labels[anchor_gt_bbox_iou_inds]
  368. # (5) assign >= pos_iou_thr as pos_ids
  369. iou_pos_iou_thr_ids = anchor_gt_bbox_iou >= pos_iou_thr
  370. iou_pos_iou_thr_ids_box_inds = anchor_gt_bbox_inds[iou_pos_iou_thr_ids]
  371. labels[iou_pos_iou_thr_ids] = gt_labels[iou_pos_iou_thr_ids_box_inds]
  372. return anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels
  373. def __call__(self, anchors, gt_bboxes, gt_labels, is_crowd):
  374. assert anchors.ndim == 2
  375. assert anchors.shape[1] == 5
  376. assert gt_bboxes.ndim == 2
  377. assert gt_bboxes.shape[1] == 5
  378. pos_iou_thr = self.pos_iou_thr
  379. neg_iou_thr = self.neg_iou_thr
  380. min_iou_thr = self.min_iou_thr
  381. ignore_iof_thr = self.ignore_iof_thr
  382. anchor_num = anchors.shape[0]
  383. gt_bboxes = gt_bboxes
  384. is_crowd_slice = is_crowd
  385. not_crowd_inds = np.where(is_crowd_slice == 0)
  386. # Step1: match anchor and gt_bbox
  387. anchor_gt_bbox_inds, anchor_gt_bbox_iou, labels = self.assign_anchor(
  388. anchors, gt_bboxes,
  389. gt_labels.reshape(-1), pos_iou_thr, neg_iou_thr, min_iou_thr,
  390. ignore_iof_thr)
  391. # Step2: sample anchor
  392. pos_inds = np.where(labels >= 0)[0]
  393. neg_inds = np.where(labels == -1)[0]
  394. # Step3: make output
  395. anchors_num = anchors.shape[0]
  396. bbox_targets = np.zeros_like(anchors)
  397. bbox_weights = np.zeros_like(anchors)
  398. bbox_gt_bboxes = np.zeros_like(anchors)
  399. pos_labels = np.zeros(anchors_num, dtype=np.int32)
  400. pos_labels_weights = np.zeros(anchors_num, dtype=np.float32)
  401. pos_sampled_anchors = anchors[pos_inds]
  402. pos_sampled_gt_boxes = gt_bboxes[anchor_gt_bbox_inds[pos_inds]]
  403. if len(pos_inds) > 0:
  404. pos_bbox_targets = self.rbox2delta(pos_sampled_anchors,
  405. pos_sampled_gt_boxes)
  406. bbox_targets[pos_inds, :] = pos_bbox_targets
  407. bbox_gt_bboxes[pos_inds, :] = pos_sampled_gt_boxes
  408. bbox_weights[pos_inds, :] = 1.0
  409. pos_labels[pos_inds] = labels[pos_inds]
  410. pos_labels_weights[pos_inds] = 1.0
  411. if len(neg_inds) > 0:
  412. pos_labels_weights[neg_inds] = 1.0
  413. return (pos_labels, pos_labels_weights, bbox_targets, bbox_weights,
  414. bbox_gt_bboxes, pos_inds, neg_inds)