cascade_head.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn.initializer import Normal
  18. from ppdet.core.workspace import register
  19. from .bbox_head import BBoxHead, TwoFCHead, XConvNormHead
  20. from .roi_extractor import RoIAlign
  21. from ..shape_spec import ShapeSpec
  22. from ..bbox_utils import delta2bbox, clip_bbox, nonempty_bbox
  23. from ..cls_utils import _get_class_default_kwargs
  24. __all__ = ['CascadeTwoFCHead', 'CascadeXConvNormHead', 'CascadeHead']
  25. @register
  26. class CascadeTwoFCHead(nn.Layer):
  27. __shared__ = ['num_cascade_stage']
  28. """
  29. Cascade RCNN bbox head with Two fc layers to extract feature
  30. Args:
  31. in_channel (int): Input channel which can be derived by from_config
  32. out_channel (int): Output channel
  33. resolution (int): Resolution of input feature map, default 7
  34. num_cascade_stage (int): The number of cascade stage, default 3
  35. """
  36. def __init__(self,
  37. in_channel=256,
  38. out_channel=1024,
  39. resolution=7,
  40. num_cascade_stage=3):
  41. super(CascadeTwoFCHead, self).__init__()
  42. self.in_channel = in_channel
  43. self.out_channel = out_channel
  44. self.head_list = []
  45. for stage in range(num_cascade_stage):
  46. head_per_stage = self.add_sublayer(
  47. str(stage), TwoFCHead(in_channel, out_channel, resolution))
  48. self.head_list.append(head_per_stage)
  49. @classmethod
  50. def from_config(cls, cfg, input_shape):
  51. s = input_shape
  52. s = s[0] if isinstance(s, (list, tuple)) else s
  53. return {'in_channel': s.channels}
  54. @property
  55. def out_shape(self):
  56. return [ShapeSpec(channels=self.out_channel, )]
  57. def forward(self, rois_feat, stage=0):
  58. out = self.head_list[stage](rois_feat)
  59. return out
  60. @register
  61. class CascadeXConvNormHead(nn.Layer):
  62. __shared__ = ['norm_type', 'freeze_norm', 'num_cascade_stage']
  63. """
  64. Cascade RCNN bbox head with serveral convolution layers
  65. Args:
  66. in_channel (int): Input channels which can be derived by from_config
  67. num_convs (int): The number of conv layers
  68. conv_dim (int): The number of channels for the conv layers
  69. out_channel (int): Output channels
  70. resolution (int): Resolution of input feature map
  71. norm_type (string): Norm type, bn, gn, sync_bn are available,
  72. default `gn`
  73. freeze_norm (bool): Whether to freeze the norm
  74. num_cascade_stage (int): The number of cascade stage, default 3
  75. """
  76. def __init__(self,
  77. in_channel=256,
  78. num_convs=4,
  79. conv_dim=256,
  80. out_channel=1024,
  81. resolution=7,
  82. norm_type='gn',
  83. freeze_norm=False,
  84. num_cascade_stage=3):
  85. super(CascadeXConvNormHead, self).__init__()
  86. self.in_channel = in_channel
  87. self.out_channel = out_channel
  88. self.head_list = []
  89. for stage in range(num_cascade_stage):
  90. head_per_stage = self.add_sublayer(
  91. str(stage),
  92. XConvNormHead(
  93. in_channel,
  94. num_convs,
  95. conv_dim,
  96. out_channel,
  97. resolution,
  98. norm_type,
  99. freeze_norm,
  100. stage_name='stage{}_'.format(stage)))
  101. self.head_list.append(head_per_stage)
  102. @classmethod
  103. def from_config(cls, cfg, input_shape):
  104. s = input_shape
  105. s = s[0] if isinstance(s, (list, tuple)) else s
  106. return {'in_channel': s.channels}
  107. @property
  108. def out_shape(self):
  109. return [ShapeSpec(channels=self.out_channel, )]
  110. def forward(self, rois_feat, stage=0):
  111. out = self.head_list[stage](rois_feat)
  112. return out
  113. @register
  114. class CascadeHead(BBoxHead):
  115. __shared__ = ['num_classes', 'num_cascade_stages']
  116. __inject__ = ['bbox_assigner', 'bbox_loss']
  117. """
  118. Cascade RCNN bbox head
  119. Args:
  120. head (nn.Layer): Extract feature in bbox head
  121. in_channel (int): Input channel after RoI extractor
  122. roi_extractor (object): The module of RoI Extractor
  123. bbox_assigner (object): The module of Box Assigner, label and sample the
  124. box.
  125. num_classes (int): The number of classes
  126. bbox_weight (List[List[float]]): The weight to get the decode box and the
  127. length of weight is the number of cascade stage
  128. num_cascade_stages (int): THe number of stage to refine the box
  129. """
  130. def __init__(self,
  131. head,
  132. in_channel,
  133. roi_extractor=_get_class_default_kwargs(RoIAlign),
  134. bbox_assigner='BboxAssigner',
  135. num_classes=80,
  136. bbox_weight=[[10., 10., 5., 5.], [20.0, 20.0, 10.0, 10.0],
  137. [30.0, 30.0, 15.0, 15.0]],
  138. num_cascade_stages=3,
  139. bbox_loss=None,
  140. reg_class_agnostic=True,
  141. stage_loss_weights=None,
  142. loss_normalize_pos=False,
  143. add_gt_as_proposals=[True, False, False]):
  144. nn.Layer.__init__(self, )
  145. self.head = head
  146. self.roi_extractor = roi_extractor
  147. if isinstance(roi_extractor, dict):
  148. self.roi_extractor = RoIAlign(**roi_extractor)
  149. self.bbox_assigner = bbox_assigner
  150. self.num_classes = num_classes
  151. self.bbox_weight = bbox_weight
  152. self.num_cascade_stages = num_cascade_stages
  153. self.bbox_loss = bbox_loss
  154. self.stage_loss_weights = [
  155. 1. / num_cascade_stages for _ in range(num_cascade_stages)
  156. ] if stage_loss_weights is None else stage_loss_weights
  157. self.add_gt_as_proposals = add_gt_as_proposals
  158. assert len(
  159. self.stage_loss_weights
  160. ) == num_cascade_stages, f'stage_loss_weights({len(self.stage_loss_weights)}) do not equal to num_cascade_stages({num_cascade_stages})'
  161. self.reg_class_agnostic = reg_class_agnostic
  162. num_bbox_delta = 4 if reg_class_agnostic else 4 * num_classes
  163. self.loss_normalize_pos = loss_normalize_pos
  164. self.bbox_score_list = []
  165. self.bbox_delta_list = []
  166. for i in range(num_cascade_stages):
  167. score_name = 'bbox_score_stage{}'.format(i)
  168. delta_name = 'bbox_delta_stage{}'.format(i)
  169. bbox_score = self.add_sublayer(
  170. score_name,
  171. nn.Linear(
  172. in_channel,
  173. self.num_classes + 1,
  174. weight_attr=paddle.ParamAttr(initializer=Normal(
  175. mean=0.0, std=0.01))))
  176. bbox_delta = self.add_sublayer(
  177. delta_name,
  178. nn.Linear(
  179. in_channel,
  180. num_bbox_delta,
  181. weight_attr=paddle.ParamAttr(initializer=Normal(
  182. mean=0.0, std=0.001))))
  183. self.bbox_score_list.append(bbox_score)
  184. self.bbox_delta_list.append(bbox_delta)
  185. self.assigned_label = None
  186. self.assigned_rois = None
  187. def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None):
  188. """
  189. body_feats (list[Tensor]): Feature maps from backbone
  190. rois (Tensor): RoIs generated from RPN module
  191. rois_num (Tensor): The number of RoIs in each image
  192. inputs (dict{Tensor}): The ground-truth of image
  193. """
  194. targets = []
  195. if self.training:
  196. rois, rois_num, targets = self.bbox_assigner(
  197. rois,
  198. rois_num,
  199. inputs,
  200. add_gt_as_proposals=self.add_gt_as_proposals[0])
  201. targets_list = [targets]
  202. self.assigned_rois = (rois, rois_num)
  203. self.assigned_targets = targets
  204. pred_bbox = None
  205. head_out_list = []
  206. for i in range(self.num_cascade_stages):
  207. if i > 0:
  208. rois, rois_num = self._get_rois_from_boxes(pred_bbox,
  209. inputs['im_shape'])
  210. if self.training:
  211. rois, rois_num, targets = self.bbox_assigner(
  212. rois,
  213. rois_num,
  214. inputs,
  215. i,
  216. is_cascade=True,
  217. add_gt_as_proposals=self.add_gt_as_proposals[i])
  218. targets_list.append(targets)
  219. rois_feat = self.roi_extractor(body_feats, rois, rois_num)
  220. bbox_feat = self.head(rois_feat, i)
  221. scores = self.bbox_score_list[i](bbox_feat)
  222. deltas = self.bbox_delta_list[i](bbox_feat)
  223. # TODO (lyuwenyu) Is it correct for only one class ?
  224. if not self.reg_class_agnostic and i < self.num_cascade_stages - 1:
  225. deltas = deltas.reshape([deltas.shape[0], self.num_classes, 4])
  226. labels = scores[:, :-1].argmax(axis=-1)
  227. if self.training:
  228. deltas = deltas[paddle.arange(deltas.shape[0]), labels]
  229. else:
  230. deltas = deltas[((deltas + 10000) * F.one_hot(
  231. labels, num_classes=self.num_classes).unsqueeze(-1) != 0
  232. ).nonzero(as_tuple=True)].reshape(
  233. [deltas.shape[0], 4])
  234. head_out_list.append([scores, deltas, rois])
  235. pred_bbox = self._get_pred_bbox(deltas, rois, self.bbox_weight[i])
  236. if self.training:
  237. loss = {}
  238. for stage, value in enumerate(zip(head_out_list, targets_list)):
  239. (scores, deltas, rois), targets = value
  240. loss_stage = self.get_loss(
  241. scores,
  242. deltas,
  243. targets,
  244. rois,
  245. self.bbox_weight[stage],
  246. loss_normalize_pos=self.loss_normalize_pos)
  247. for k, v in loss_stage.items():
  248. loss[k + "_stage{}".format(
  249. stage)] = v * self.stage_loss_weights[stage]
  250. return loss, bbox_feat
  251. else:
  252. scores, deltas, self.refined_rois = self.get_prediction(
  253. head_out_list)
  254. return (deltas, scores), self.head
  255. def _get_rois_from_boxes(self, boxes, im_shape):
  256. rois = []
  257. for i, boxes_per_image in enumerate(boxes):
  258. clip_box = clip_bbox(boxes_per_image, im_shape[i])
  259. if self.training:
  260. keep = nonempty_bbox(clip_box)
  261. if keep.shape[0] == 0:
  262. keep = paddle.zeros([1], dtype='int32')
  263. clip_box = paddle.gather(clip_box, keep)
  264. rois.append(clip_box)
  265. rois_num = paddle.concat([paddle.shape(r)[0] for r in rois])
  266. return rois, rois_num
  267. def _get_pred_bbox(self, deltas, proposals, weights):
  268. pred_proposals = paddle.concat(proposals) if len(
  269. proposals) > 1 else proposals[0]
  270. pred_bbox = delta2bbox(deltas, pred_proposals, weights)
  271. pred_bbox = paddle.reshape(pred_bbox, [-1, deltas.shape[-1]])
  272. num_prop = []
  273. for p in proposals:
  274. num_prop.append(p.shape[0])
  275. # NOTE(dev): num_prob will be tagged as LoDTensorArray because it
  276. # depends on batch_size under @to_static. However the argument
  277. # num_or_sections in paddle.split does not support LoDTensorArray,
  278. # so we use [-1] to replace it if num_prop is not list. The modification
  279. # This ensures the correctness of both dynamic and static graphs.
  280. if not isinstance(num_prop, list):
  281. num_prop = [-1]
  282. return pred_bbox.split(num_prop)
  283. def get_prediction(self, head_out_list):
  284. """
  285. head_out_list(List[Tensor]): scores, deltas, rois
  286. """
  287. pred_list = []
  288. scores_list = [F.softmax(head[0]) for head in head_out_list]
  289. scores = paddle.add_n(scores_list) / self.num_cascade_stages
  290. # Get deltas and rois from the last stage
  291. _, deltas, rois = head_out_list[-1]
  292. return scores, deltas, rois
  293. def get_refined_rois(self, ):
  294. return self.refined_rois