rpn_head.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn.initializer import Normal
  18. from ppdet.core.workspace import register
  19. from .anchor_generator import AnchorGenerator
  20. from .target_layer import RPNTargetAssign
  21. from .proposal_generator import ProposalGenerator
  22. from ..cls_utils import _get_class_default_kwargs
  23. class RPNFeat(nn.Layer):
  24. """
  25. Feature extraction in RPN head
  26. Args:
  27. in_channel (int): Input channel
  28. out_channel (int): Output channel
  29. """
  30. def __init__(self, in_channel=1024, out_channel=1024):
  31. super(RPNFeat, self).__init__()
  32. # rpn feat is shared with each level
  33. self.rpn_conv = nn.Conv2D(
  34. in_channels=in_channel,
  35. out_channels=out_channel,
  36. kernel_size=3,
  37. padding=1,
  38. weight_attr=paddle.ParamAttr(initializer=Normal(
  39. mean=0., std=0.01)))
  40. self.rpn_conv.skip_quant = True
  41. def forward(self, feats):
  42. rpn_feats = []
  43. for feat in feats:
  44. rpn_feats.append(F.relu(self.rpn_conv(feat)))
  45. return rpn_feats
@register
class RPNHead(nn.Layer):
    """
    Region Proposal Network head.

    Applies a shared conv (RPNFeat) to every input level. Two shared 1x1
    convs then predict, per location, one objectness logit per anchor
    (``num_anchors`` channels) and four box deltas per anchor
    (``4 * num_anchors`` channels). Proposals are decoded from these
    predictions, and during training the RPN losses are also computed.

    Args:
        anchor_generator (dict|AnchorGenerator): configure of anchor
            generation; a dict is expanded into ``AnchorGenerator(**dict)``
        rpn_target_assign (dict|RPNTargetAssign): configure of rpn targets
            assignment; a dict is expanded into ``RPNTargetAssign(**dict)``
        train_proposal (dict|ProposalGenerator): configure of proposals
            generation at the stage of training
        test_proposal (dict|ProposalGenerator): configure of proposals
            generation at the stage of prediction
        in_channel (int): channel of input feature maps which can be
            derived by from_config
        export_onnx (bool): if True, proposal generation only processes the
            first image of the batch (slices ``[0:1]``) and applies a single
            top-k over all levels; intended for ONNX export with batch size 1
        loss_rpn_bbox (callable|None): box regression loss called as
            ``loss_rpn_bbox(pred, target)`` and then summed; when None, an
            L1 loss (sum of absolute differences) is used
    """
    # NOTE(review): these class attributes are presumably consumed by
    # ppdet's config system via @register (shared global option / injected
    # sub-module); confirm against ppdet.core.workspace.
    __shared__ = ['export_onnx']
    __inject__ = ['loss_rpn_bbox']

    def __init__(self,
                 anchor_generator=_get_class_default_kwargs(AnchorGenerator),
                 rpn_target_assign=_get_class_default_kwargs(RPNTargetAssign),
                 # Extra positional args 12000, 2000 are forwarded to the
                 # defaults helper (presumably pre-/post-NMS top-n; confirm).
                 train_proposal=_get_class_default_kwargs(ProposalGenerator,
                                                         12000, 2000),
                 test_proposal=_get_class_default_kwargs(ProposalGenerator),
                 in_channel=1024,
                 export_onnx=False,
                 loss_rpn_bbox=None):
        super(RPNHead, self).__init__()
        self.anchor_generator = anchor_generator
        self.rpn_target_assign = rpn_target_assign
        self.train_proposal = train_proposal
        self.test_proposal = test_proposal
        self.export_onnx = export_onnx
        # Config dicts are turned into module instances; already-built
        # instances are kept as-is. The default dicts are only unpacked,
        # never mutated, so sharing them across instances is harmless.
        if isinstance(anchor_generator, dict):
            self.anchor_generator = AnchorGenerator(**anchor_generator)
        if isinstance(rpn_target_assign, dict):
            self.rpn_target_assign = RPNTargetAssign(**rpn_target_assign)
        if isinstance(train_proposal, dict):
            self.train_proposal = ProposalGenerator(**train_proposal)
        if isinstance(test_proposal, dict):
            self.test_proposal = ProposalGenerator(**test_proposal)
        self.loss_rpn_bbox = loss_rpn_bbox

        num_anchors = self.anchor_generator.num_anchors
        self.rpn_feat = RPNFeat(in_channel, in_channel)
        # rpn head is shared with each level
        # rpn roi classification scores: one logit per anchor per location
        self.rpn_rois_score = nn.Conv2D(
            in_channels=in_channel,
            out_channels=num_anchors,
            kernel_size=1,
            padding=0,
            weight_attr=paddle.ParamAttr(initializer=Normal(
                mean=0., std=0.01)))
        self.rpn_rois_score.skip_quant = True

        # rpn roi bbox regression deltas: four values per anchor per location
        self.rpn_rois_delta = nn.Conv2D(
            in_channels=in_channel,
            out_channels=4 * num_anchors,
            kernel_size=1,
            padding=0,
            weight_attr=paddle.ParamAttr(initializer=Normal(
                mean=0., std=0.01)))
        self.rpn_rois_delta.skip_quant = True

    @classmethod
    def from_config(cls, cfg, input_shape):
        """
        Derive ``in_channel`` from the input shape spec.

        When a list/tuple of shapes is given (one per FPN level), only the
        first one is read, because all levels share the same head.
        """
        # FPN share same rpn head
        if isinstance(input_shape, (list, tuple)):
            input_shape = input_shape[0]
        return {'in_channel': input_shape.channels}

    def forward(self, feats, inputs):
        """
        Args:
            feats (list[Tensor]): input feature maps, one per level
            inputs (dict): batch inputs; ``inputs['im_shape']`` is read for
                proposal generation, and the whole dict is passed to
                ``rpn_target_assign`` when training

        Returns:
            tuple: ``(rois, rois_num, loss)`` where ``rois``/``rois_num`` come
            from ``_gen_proposal`` and ``loss`` is the dict returned by
            ``get_loss`` in training mode, or None otherwise.
        """
        rpn_feats = self.rpn_feat(feats)
        scores = []
        deltas = []

        # Same score/delta convs are applied to every level.
        for rpn_feat in rpn_feats:
            rrs = self.rpn_rois_score(rpn_feat)
            rrd = self.rpn_rois_delta(rpn_feat)
            scores.append(rrs)
            deltas.append(rrd)

        anchors = self.anchor_generator(rpn_feats)

        # Proposals are generated in both modes; the loss only in training.
        rois, rois_num = self._gen_proposal(scores, deltas, anchors, inputs)
        if self.training:
            loss = self.get_loss(scores, deltas, anchors, inputs)
            return rois, rois_num, loss
        else:
            return rois, rois_num, None

    def _gen_proposal(self, scores, bbox_deltas, anchors, inputs):
        """
        Decode proposals per image from multi-level predictions.

        scores (list[Tensor]): Multi-level scores prediction
        bbox_deltas (list[Tensor]): Multi-level deltas prediction
        anchors (list[Tensor]): Multi-level anchors
        inputs (dict): ground truth info; only ``inputs['im_shape']`` is
            read here

        Returns:
            tuple: ``(output_rois, output_rois_num)``.
            Normal mode: ``output_rois`` is a list with one RoI tensor per
            image, and ``output_rois_num`` is the concatenation of each
            image's RoI count.
            ``export_onnx`` mode: ``output_rois`` is a one-element list
            holding the top-k RoIs of the first image, and
            ``output_rois_num`` is their count.
        """
        prop_gen = self.train_proposal if self.training else self.test_proposal
        im_shape = inputs['im_shape']

        # Collect multi-level proposals for each batch
        # Get 'topk' of them as final output

        if self.export_onnx:
            # bs = 1 when exporting onnx: only image 0 is processed.
            onnx_rpn_rois_list = []
            onnx_rpn_prob_list = []
            onnx_rpn_rois_num_list = []

            for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
                                                    anchors):
                onnx_rpn_rois, onnx_rpn_rois_prob, onnx_rpn_rois_num, onnx_post_nms_top_n = prop_gen(
                    scores=rpn_score[0:1],
                    bbox_deltas=rpn_delta[0:1],
                    anchors=anchor,
                    im_shape=im_shape[0:1])

                onnx_rpn_rois_list.append(onnx_rpn_rois)
                onnx_rpn_prob_list.append(onnx_rpn_rois_prob)
                onnx_rpn_rois_num_list.append(onnx_rpn_rois_num)

            onnx_rpn_rois = paddle.concat(onnx_rpn_rois_list)
            onnx_rpn_prob = paddle.concat(onnx_rpn_prob_list).flatten()

            # k = min(post_nms_top_n, #rois), computed with tensor ops so it
            # is expressed in the exported graph rather than in Python.
            onnx_top_n = paddle.to_tensor(onnx_post_nms_top_n).cast('int32')
            onnx_num_rois = paddle.shape(onnx_rpn_prob)[0].cast('int32')
            k = paddle.minimum(onnx_top_n, onnx_num_rois)
            onnx_topk_prob, onnx_topk_inds = paddle.topk(onnx_rpn_prob, k)
            onnx_topk_rois = paddle.gather(onnx_rpn_rois, onnx_topk_inds)
            # TODO(wangguanzhong): Now bs_rois_collect in export_onnx is moved outside conditional branch
            # due to problems in dy2static of paddle. Will fix it when updating paddle framework.
            # bs_rois_collect = [onnx_topk_rois]
            # bs_rois_num_collect = paddle.shape(onnx_topk_rois)[0]

        else:
            bs_rois_collect = []
            bs_rois_num_collect = []

            # First dimension of im_shape (number of images), kept as a
            # one-element tensor.
            batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])

            # Generate proposals for each level and each batch.
            # Discard batch-computing to avoid sorting bbox cross different batches.
            for i in range(batch_size):
                rpn_rois_list = []
                rpn_prob_list = []
                rpn_rois_num_list = []

                for rpn_score, rpn_delta, anchor in zip(scores, bbox_deltas,
                                                        anchors):
                    rpn_rois, rpn_rois_prob, rpn_rois_num, post_nms_top_n = prop_gen(
                        scores=rpn_score[i:i + 1],
                        bbox_deltas=rpn_delta[i:i + 1],
                        anchors=anchor,
                        im_shape=im_shape[i:i + 1])
                    rpn_rois_list.append(rpn_rois)
                    rpn_prob_list.append(rpn_rois_prob)
                    rpn_rois_num_list.append(rpn_rois_num)

                if len(scores) > 1:
                    # Multiple levels: merge them and keep at most
                    # post_nms_top_n RoIs ranked by probability.
                    rpn_rois = paddle.concat(rpn_rois_list)
                    rpn_prob = paddle.concat(rpn_prob_list).flatten()

                    num_rois = paddle.shape(rpn_prob)[0].cast('int32')
                    # NOTE: both branches assign topk_rois and topk_prob;
                    # this looks deliberate so the tensor-valued condition
                    # stays convertible by dy2static.
                    if num_rois > post_nms_top_n:
                        topk_prob, topk_inds = paddle.topk(rpn_prob,
                                                           post_nms_top_n)
                        topk_rois = paddle.gather(rpn_rois, topk_inds)
                    else:
                        topk_rois = rpn_rois
                        topk_prob = rpn_prob
                else:
                    # Single level: prop_gen output is used unchanged.
                    topk_rois = rpn_rois_list[0]
                    topk_prob = rpn_prob_list[0].flatten()

                bs_rois_collect.append(topk_rois)
                bs_rois_num_collect.append(paddle.shape(topk_rois)[0])

            bs_rois_num_collect = paddle.concat(bs_rois_num_collect)

        if self.export_onnx:
            output_rois = [onnx_topk_rois]
            output_rois_num = paddle.shape(onnx_topk_rois)[0]
        else:
            output_rois = bs_rois_collect
            output_rois_num = bs_rois_num_collect

        return output_rois, output_rois_num

    def get_loss(self, pred_scores, pred_deltas, anchors, inputs):
        """
        Compute RPN objectness and box-regression losses.

        pred_scores (list[Tensor]): Multi-level scores prediction
        pred_deltas (list[Tensor]): Multi-level deltas prediction
        anchors (list[Tensor]): Multi-level anchors
        inputs (dict): ground truth info, passed unchanged to
            ``rpn_target_assign``

        Returns:
            dict: ``{'loss_rpn_cls', 'loss_rpn_reg'}``. Both are summed
            losses divided by the ``norm`` returned from
            ``rpn_target_assign``. Either one is a constant zero tensor of
            shape [1] when there are no valid / positive anchors.
        """
        # All anchors of all levels as one (A, 4) tensor.
        anchors = [paddle.reshape(a, shape=(-1, 4)) for a in anchors]
        anchors = paddle.concat(anchors)

        # Move axis 1 last and flatten to (batch, -1, 1) / (batch, -1, 4).
        # Assumes NCHW predictions, so the per-anchor order matches the
        # anchors above; TODO confirm against AnchorGenerator.
        scores = [
            paddle.reshape(
                paddle.transpose(
                    v, perm=[0, 2, 3, 1]),
                shape=(v.shape[0], -1, 1)) for v in pred_scores
        ]
        scores = paddle.concat(scores, axis=1)

        deltas = [
            paddle.reshape(
                paddle.transpose(
                    v, perm=[0, 2, 3, 1]),
                shape=(v.shape[0], -1, 4)) for v in pred_deltas
        ]
        deltas = paddle.concat(deltas, axis=1)

        score_tgt, bbox_tgt, loc_tgt, norm = self.rpn_target_assign(inputs,
                                                                    anchors)

        # Flatten the batch dimension so predictions index like the
        # concatenated targets.
        scores = paddle.reshape(x=scores, shape=(-1, ))
        deltas = paddle.reshape(x=deltas, shape=(-1, 4))

        score_tgt = paddle.concat(score_tgt)
        score_tgt.stop_gradient = True

        # Target value 1 marks positives (used for regression); any value
        # >= 0 marks anchors used for classification. Other values
        # (presumably -1 = ignore; confirm in RPNTargetAssign) are skipped.
        pos_mask = score_tgt == 1
        pos_ind = paddle.nonzero(pos_mask)

        valid_mask = score_tgt >= 0
        valid_ind = paddle.nonzero(valid_mask)

        # cls loss
        if valid_ind.shape[0] == 0:
            loss_rpn_cls = paddle.zeros([1], dtype='float32')
        else:
            score_pred = paddle.gather(scores, valid_ind)
            score_label = paddle.gather(score_tgt, valid_ind).cast('float32')
            score_label.stop_gradient = True
            loss_rpn_cls = F.binary_cross_entropy_with_logits(
                logit=score_pred, label=score_label, reduction="sum")

        # reg loss
        if pos_ind.shape[0] == 0:
            loss_rpn_reg = paddle.zeros([1], dtype='float32')
        else:
            loc_pred = paddle.gather(deltas, pos_ind)
            loc_tgt = paddle.concat(loc_tgt)
            loc_tgt = paddle.gather(loc_tgt, pos_ind)
            loc_tgt.stop_gradient = True

            if self.loss_rpn_bbox is None:
                # Default: summed L1 loss.
                loss_rpn_reg = paddle.abs(loc_pred - loc_tgt).sum()
            else:
                loss_rpn_reg = self.loss_rpn_bbox(loc_pred, loc_tgt).sum()

        return {
            'loss_rpn_cls': loss_rpn_cls / norm,
            'loss_rpn_reg': loss_rpn_reg / norm
        }