mask_head.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from paddle.nn.initializer import KaimingNormal
  18. from ppdet.core.workspace import register, create
  19. from ppdet.modeling.layers import ConvNormLayer
  20. from .roi_extractor import RoIAlign
  21. from ..cls_utils import _get_class_default_kwargs
  22. @register
  23. class MaskFeat(nn.Layer):
  24. """
  25. Feature extraction in Mask head
  26. Args:
  27. in_channel (int): Input channels
  28. out_channel (int): Output channels
  29. num_convs (int): The number of conv layers, default 4
  30. norm_type (string | None): Norm type, bn, gn, sync_bn are available,
  31. default None
  32. """
  33. def __init__(self,
  34. in_channel=256,
  35. out_channel=256,
  36. num_convs=4,
  37. norm_type=None):
  38. super(MaskFeat, self).__init__()
  39. self.num_convs = num_convs
  40. self.in_channel = in_channel
  41. self.out_channel = out_channel
  42. self.norm_type = norm_type
  43. fan_conv = out_channel * 3 * 3
  44. fan_deconv = out_channel * 2 * 2
  45. mask_conv = nn.Sequential()
  46. if norm_type == 'gn':
  47. for i in range(self.num_convs):
  48. conv_name = 'mask_inter_feat_{}'.format(i + 1)
  49. mask_conv.add_sublayer(
  50. conv_name,
  51. ConvNormLayer(
  52. ch_in=in_channel if i == 0 else out_channel,
  53. ch_out=out_channel,
  54. filter_size=3,
  55. stride=1,
  56. norm_type=self.norm_type,
  57. initializer=KaimingNormal(fan_in=fan_conv),
  58. skip_quant=True))
  59. mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())
  60. else:
  61. for i in range(self.num_convs):
  62. conv_name = 'mask_inter_feat_{}'.format(i + 1)
  63. conv = nn.Conv2D(
  64. in_channels=in_channel if i == 0 else out_channel,
  65. out_channels=out_channel,
  66. kernel_size=3,
  67. padding=1,
  68. weight_attr=paddle.ParamAttr(
  69. initializer=KaimingNormal(fan_in=fan_conv)))
  70. conv.skip_quant = True
  71. mask_conv.add_sublayer(conv_name, conv)
  72. mask_conv.add_sublayer(conv_name + 'act', nn.ReLU())
  73. mask_conv.add_sublayer(
  74. 'conv5_mask',
  75. nn.Conv2DTranspose(
  76. in_channels=self.out_channel if num_convs > 0 else self.in_channel,
  77. out_channels=self.out_channel,
  78. kernel_size=2,
  79. stride=2,
  80. weight_attr=paddle.ParamAttr(
  81. initializer=KaimingNormal(fan_in=fan_deconv))))
  82. mask_conv.add_sublayer('conv5_mask' + 'act', nn.ReLU())
  83. self.upsample = mask_conv
  84. @classmethod
  85. def from_config(cls, cfg, input_shape):
  86. if isinstance(input_shape, (list, tuple)):
  87. input_shape = input_shape[0]
  88. return {'in_channel': input_shape.channels, }
  89. def out_channels(self):
  90. return self.out_channel
  91. def forward(self, feats):
  92. return self.upsample(feats)
  93. @register
  94. class MaskHead(nn.Layer):
  95. __shared__ = ['num_classes', 'export_onnx']
  96. __inject__ = ['mask_assigner']
  97. """
  98. RCNN mask head
  99. Args:
  100. head (nn.Layer): Extract feature in mask head
  101. roi_extractor (object): The module of RoI Extractor
  102. mask_assigner (object): The module of Mask Assigner,
  103. label and sample the mask
  104. num_classes (int): The number of classes
  105. share_bbox_feat (bool): Whether to share the feature from bbox head,
  106. default false
  107. """
  108. def __init__(self,
  109. head,
  110. roi_extractor=_get_class_default_kwargs(RoIAlign),
  111. mask_assigner='MaskAssigner',
  112. num_classes=80,
  113. share_bbox_feat=False,
  114. export_onnx=False):
  115. super(MaskHead, self).__init__()
  116. self.num_classes = num_classes
  117. self.export_onnx = export_onnx
  118. self.roi_extractor = roi_extractor
  119. if isinstance(roi_extractor, dict):
  120. self.roi_extractor = RoIAlign(**roi_extractor)
  121. self.head = head
  122. self.in_channels = head.out_channels()
  123. self.mask_assigner = mask_assigner
  124. self.share_bbox_feat = share_bbox_feat
  125. self.bbox_head = None
  126. self.mask_fcn_logits = nn.Conv2D(
  127. in_channels=self.in_channels,
  128. out_channels=self.num_classes,
  129. kernel_size=1,
  130. weight_attr=paddle.ParamAttr(initializer=KaimingNormal(
  131. fan_in=self.num_classes)))
  132. self.mask_fcn_logits.skip_quant = True
  133. @classmethod
  134. def from_config(cls, cfg, input_shape):
  135. roi_pooler = cfg['roi_extractor']
  136. assert isinstance(roi_pooler, dict)
  137. kwargs = RoIAlign.from_config(cfg, input_shape)
  138. roi_pooler.update(kwargs)
  139. kwargs = {'input_shape': input_shape}
  140. head = create(cfg['head'], **kwargs)
  141. return {
  142. 'roi_extractor': roi_pooler,
  143. 'head': head,
  144. }
  145. def get_loss(self, mask_logits, mask_label, mask_target, mask_weight):
  146. mask_label = F.one_hot(mask_label, self.num_classes).unsqueeze([2, 3])
  147. mask_label = paddle.expand_as(mask_label, mask_logits)
  148. mask_label.stop_gradient = True
  149. mask_pred = paddle.gather_nd(mask_logits, paddle.nonzero(mask_label))
  150. shape = mask_logits.shape
  151. mask_pred = paddle.reshape(mask_pred, [shape[0], shape[2], shape[3]])
  152. mask_target = mask_target.cast('float32')
  153. mask_weight = mask_weight.unsqueeze([1, 2])
  154. loss_mask = F.binary_cross_entropy_with_logits(
  155. mask_pred, mask_target, weight=mask_weight, reduction="mean")
  156. return loss_mask
  157. def forward_train(self, body_feats, rois, rois_num, inputs, targets,
  158. bbox_feat):
  159. """
  160. body_feats (list[Tensor]): Multi-level backbone features
  161. rois (list[Tensor]): Proposals for each batch with shape [N, 4]
  162. rois_num (Tensor): The number of proposals for each batch
  163. inputs (dict): ground truth info
  164. """
  165. tgt_labels, _, tgt_gt_inds = targets
  166. rois, rois_num, tgt_classes, tgt_masks, mask_index, tgt_weights = self.mask_assigner(
  167. rois, tgt_labels, tgt_gt_inds, inputs)
  168. if self.share_bbox_feat:
  169. rois_feat = paddle.gather(bbox_feat, mask_index)
  170. else:
  171. rois_feat = self.roi_extractor(body_feats, rois, rois_num)
  172. mask_feat = self.head(rois_feat)
  173. mask_logits = self.mask_fcn_logits(mask_feat)
  174. loss_mask = self.get_loss(mask_logits, tgt_classes, tgt_masks,
  175. tgt_weights)
  176. return {'loss_mask': loss_mask}
  177. def forward_test(self,
  178. body_feats,
  179. rois,
  180. rois_num,
  181. scale_factor,
  182. feat_func=None):
  183. """
  184. body_feats (list[Tensor]): Multi-level backbone features
  185. rois (Tensor): Prediction from bbox head with shape [N, 6]
  186. rois_num (Tensor): The number of prediction for each batch
  187. scale_factor (Tensor): The scale factor from origin size to input size
  188. """
  189. if not self.export_onnx and rois.shape[0] == 0:
  190. mask_out = paddle.full([1, 1, 1], -1)
  191. else:
  192. bbox = [rois[:, 2:]]
  193. labels = rois[:, 0].cast('int32')
  194. rois_feat = self.roi_extractor(body_feats, bbox, rois_num)
  195. if self.share_bbox_feat:
  196. assert feat_func is not None
  197. rois_feat = feat_func(rois_feat)
  198. mask_feat = self.head(rois_feat)
  199. mask_logit = self.mask_fcn_logits(mask_feat)
  200. if self.num_classes == 1:
  201. mask_out = F.sigmoid(mask_logit)[:, 0, :, :]
  202. else:
  203. num_masks = paddle.shape(mask_logit)[0]
  204. index = paddle.arange(num_masks).cast('int32')
  205. mask_out = mask_logit[index, labels]
  206. mask_out_shape = paddle.shape(mask_out)
  207. mask_out = paddle.reshape(mask_out, [
  208. paddle.shape(index), mask_out_shape[-2], mask_out_shape[-1]
  209. ])
  210. mask_out = F.sigmoid(mask_out)
  211. return mask_out
  212. def forward(self,
  213. body_feats,
  214. rois,
  215. rois_num,
  216. inputs,
  217. targets=None,
  218. bbox_feat=None,
  219. feat_func=None):
  220. if self.training:
  221. return self.forward_train(body_feats, rois, rois_num, inputs,
  222. targets, bbox_feat)
  223. else:
  224. im_scale = inputs['scale_factor']
  225. return self.forward_test(body_feats, rois, rois_num, im_scale,
  226. feat_func)