# fcos_head.py
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import math
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Constant

from ppdet.core.workspace import register
from ppdet.modeling.layers import ConvNormLayer, MultiClassNMS

__all__ = ['FCOSFeat', 'FCOSHead']


class ScaleReg(nn.Layer):
    """
    Learnable parameter for scaling the regression outputs.
    """

    def __init__(self):
        super(ScaleReg, self).__init__()
        self.scale_reg = self.create_parameter(
            shape=[1],
            attr=ParamAttr(initializer=Constant(value=1.)),
            dtype="float32")

    def forward(self, inputs):
        out = inputs * self.scale_reg
        return out


@register
class FCOSFeat(nn.Layer):
    """
    FCOSFeat of FCOS

    Args:
        feat_in (int): The channel number of input Tensor.
        feat_out (int): The channel number of output Tensor.
        num_convs (int): The convolution number of the FCOSFeat.
        norm_type (str): Normalization type, 'bn'/'sync_bn'/'gn'.
        use_dcn (bool): Whether to use dcn in tower or not.
    """

    def __init__(self,
                 feat_in=256,
                 feat_out=256,
                 num_convs=4,
                 norm_type='bn',
                 use_dcn=False):
        super(FCOSFeat, self).__init__()
        self.feat_in = feat_in
        self.feat_out = feat_out
        self.num_convs = num_convs
        self.norm_type = norm_type
        self.cls_subnet_convs = []
        self.reg_subnet_convs = []
        for i in range(self.num_convs):
            in_c = feat_in if i == 0 else feat_out

            cls_conv_name = 'fcos_head_cls_tower_conv_{}'.format(i)
            cls_conv = self.add_sublayer(
                cls_conv_name,
                ConvNormLayer(
                    ch_in=in_c,
                    ch_out=feat_out,
                    filter_size=3,
                    stride=1,
                    norm_type=norm_type,
                    use_dcn=use_dcn,
                    bias_on=True,
                    lr_scale=2.))
            self.cls_subnet_convs.append(cls_conv)

            reg_conv_name = 'fcos_head_reg_tower_conv_{}'.format(i)
            reg_conv = self.add_sublayer(
                reg_conv_name,
                ConvNormLayer(
                    ch_in=in_c,
                    ch_out=feat_out,
                    filter_size=3,
                    stride=1,
                    norm_type=norm_type,
                    use_dcn=use_dcn,
                    bias_on=True,
                    lr_scale=2.))
            self.reg_subnet_convs.append(reg_conv)

    def forward(self, fpn_feat):
        cls_feat = fpn_feat
        reg_feat = fpn_feat
        for i in range(self.num_convs):
            cls_feat = F.relu(self.cls_subnet_convs[i](cls_feat))
            reg_feat = F.relu(self.reg_subnet_convs[i](reg_feat))
        return cls_feat, reg_feat
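
# The classification and regression towers above do not share weights with each
# other; FCOSHead below applies a single FCOSFeat instance to every FPN level,
# so the tower weights are shared across levels.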

@register
class FCOSHead(nn.Layer):
    """
    FCOSHead

    Args:
        num_classes (int): Number of classes
        fcos_feat (object): Instance of 'FCOSFeat'
        fpn_stride (list): The stride of each FPN Layer
        prior_prob (float): Used to set the bias init for the class prediction layer
        multiply_strides_reg_targets (bool): Whether to multiply the normalized
            regression outputs by the FPN stride
        norm_reg_targets (bool): Normalize the regression targets if true
        centerness_on_reg (bool): Whether to predict centerness on the regression
            branch instead of the classification branch
        num_shift (float): Relative offset between the center of the first shift and the top-left corner of img
        sqrt_score (bool): Whether to take the square root of the fused
            classification-centerness score in post-processing
        fcos_loss (object): Instance of 'FCOSLoss'
        nms (object): Instance of 'MultiClassNMS'
        trt (bool): Whether to use TensorRT for NMS during deployment
    """
    __inject__ = ['fcos_feat', 'fcos_loss', 'nms']
    __shared__ = ['num_classes', 'trt']

    def __init__(self,
                 num_classes=80,
                 fcos_feat='FCOSFeat',
                 fpn_stride=[8, 16, 32, 64, 128],
                 prior_prob=0.01,
                 multiply_strides_reg_targets=False,
                 norm_reg_targets=True,
                 centerness_on_reg=True,
                 num_shift=0.5,
                 sqrt_score=False,
                 fcos_loss='FCOSLoss',
                 nms='MultiClassNMS',
                 trt=False):
        super(FCOSHead, self).__init__()
        self.fcos_feat = fcos_feat
        self.num_classes = num_classes
        self.fpn_stride = fpn_stride
        self.prior_prob = prior_prob
        self.fcos_loss = fcos_loss
        self.norm_reg_targets = norm_reg_targets
        self.centerness_on_reg = centerness_on_reg
        self.multiply_strides_reg_targets = multiply_strides_reg_targets
        self.num_shift = num_shift
        self.nms = nms
        if isinstance(self.nms, MultiClassNMS) and trt:
            self.nms.trt = trt
        self.sqrt_score = sqrt_score
        self.is_teacher = False

        conv_cls_name = "fcos_head_cls"
        # Initialize the classification bias so that the initial foreground
        # probability equals prior_prob (focal-loss style prior initialization).
        bias_init_value = -math.log((1 - self.prior_prob) / self.prior_prob)
        self.fcos_head_cls = self.add_sublayer(
            conv_cls_name,
            nn.Conv2D(
                in_channels=256,
                out_channels=self.num_classes,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=ParamAttr(
                    initializer=Constant(value=bias_init_value))))

        conv_reg_name = "fcos_head_reg"
        self.fcos_head_reg = self.add_sublayer(
            conv_reg_name,
            nn.Conv2D(
                in_channels=256,
                out_channels=4,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(value=0))))

        conv_centerness_name = "fcos_head_centerness"
        self.fcos_head_centerness = self.add_sublayer(
            conv_centerness_name,
            nn.Conv2D(
                in_channels=256,
                out_channels=1,
                kernel_size=3,
                stride=1,
                padding=1,
                weight_attr=ParamAttr(initializer=Normal(
                    mean=0., std=0.01)),
                bias_attr=ParamAttr(initializer=Constant(value=0))))

        # One learnable ScaleReg per FPN level.
        self.scales_regs = []
        for i in range(len(self.fpn_stride)):
            lvl = int(math.log(int(self.fpn_stride[i]), 2))
            feat_name = 'p{}_feat'.format(lvl)
            scale_reg = self.add_sublayer(feat_name, ScaleReg())
            self.scales_regs.append(scale_reg)

    def _compute_locations_by_level(self, fpn_stride, feature, num_shift=0.5):
        """
        Compute locations of anchor points of each FPN layer

        Args:
            fpn_stride (int): The stride of current FPN feature map
            feature (Tensor): Tensor of current FPN feature map
        Returns:
            Anchor point locations of current FPN feature map
        """
        h, w = feature.shape[2], feature.shape[3]
        shift_x = paddle.arange(0, w * fpn_stride, fpn_stride)
        shift_y = paddle.arange(0, h * fpn_stride, fpn_stride)
        shift_x = paddle.unsqueeze(shift_x, axis=0)
        shift_y = paddle.unsqueeze(shift_y, axis=1)
        shift_x = paddle.expand(shift_x, shape=[h, w])
        shift_y = paddle.expand(shift_y, shape=[h, w])
        shift_x = paddle.reshape(shift_x, shape=[-1])
        shift_y = paddle.reshape(shift_y, shape=[-1])
        location = paddle.stack(
            [shift_x, shift_y], axis=-1) + float(fpn_stride * num_shift)
        return location
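
    # Worked example (illustrative): with fpn_stride=8, a 2x2 feature map and
    # the default num_shift=0.5, shift_x = shift_y = [0, 8], so the returned
    # anchor points are (4, 4), (12, 4), (4, 12), (12, 12) -- the centers of
    # the 8x8 input-image cells covered by the feature-map locations.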
    def forward(self, fpn_feats, targets=None):
        assert len(fpn_feats) == len(
            self.fpn_stride
        ), "The size of fpn_feats is not equal to size of fpn_stride"
        cls_logits_list = []
        bboxes_reg_list = []
        centerness_list = []
        for scale_reg, fpn_stride, fpn_feat in zip(self.scales_regs,
                                                   self.fpn_stride, fpn_feats):
            fcos_cls_feat, fcos_reg_feat = self.fcos_feat(fpn_feat)
            cls_logits = self.fcos_head_cls(fcos_cls_feat)
            bbox_reg = scale_reg(self.fcos_head_reg(fcos_reg_feat))
            if self.centerness_on_reg:
                centerness = self.fcos_head_centerness(fcos_reg_feat)
            else:
                centerness = self.fcos_head_centerness(fcos_cls_feat)
            if self.norm_reg_targets:
                bbox_reg = F.relu(bbox_reg)
                if self.multiply_strides_reg_targets:
                    bbox_reg = bbox_reg * fpn_stride
                else:
                    if not self.training or targets.get(
                            'get_data',
                            False) or targets.get('is_teacher', False):
                        bbox_reg = bbox_reg * fpn_stride
            else:
                bbox_reg = paddle.exp(bbox_reg)
            cls_logits_list.append(cls_logits)
            bboxes_reg_list.append(bbox_reg)
            centerness_list.append(centerness)

        if targets is not None:
            self.is_teacher = targets.get('is_teacher', False)
            if self.is_teacher:
                # acting as a teacher model: return the raw per-level predictions
                return [cls_logits_list, bboxes_reg_list, centerness_list]

        if self.training and targets is not None:
            get_data = targets.get('get_data', False)
            if get_data:
                # raw predictions requested: skip loss computation
                return [cls_logits_list, bboxes_reg_list, centerness_list]

            losses = {}
            fcos_head_outs = [cls_logits_list, bboxes_reg_list, centerness_list]
            losses_fcos = self.get_loss(fcos_head_outs, targets)
            losses.update(losses_fcos)
            total_loss = paddle.add_n(list(losses.values()))
            losses.update({'loss': total_loss})
            return losses
        else:
            # eval or infer
            locations_list = []
            for fpn_stride, feature in zip(self.fpn_stride, fpn_feats):
                location = self._compute_locations_by_level(fpn_stride, feature,
                                                            self.num_shift)
                locations_list.append(location)
            fcos_head_outs = [
                locations_list, cls_logits_list, bboxes_reg_list,
                centerness_list
            ]
            return fcos_head_outs

    def get_loss(self, fcos_head_outs, targets):
        cls_logits, bboxes_reg, centerness = fcos_head_outs

        # gather per-level labels, reg_target and centerness targets
        tag_labels, tag_bboxes, tag_centerness = [], [], []
        for i in range(len(self.fpn_stride)):
            k_lbl = 'labels{}'.format(i)
            if k_lbl in targets:
                tag_labels.append(targets[k_lbl])
            k_box = 'reg_target{}'.format(i)
            if k_box in targets:
                tag_bboxes.append(targets[k_box])
            k_ctn = 'centerness{}'.format(i)
            if k_ctn in targets:
                tag_centerness.append(targets[k_ctn])

        losses_fcos = self.fcos_loss(cls_logits, bboxes_reg, centerness,
                                     tag_labels, tag_bboxes, tag_centerness)
        return losses_fcos

    def _post_process_by_level(self,
                               locations,
                               box_cls,
                               box_reg,
                               box_ctn,
                               sqrt_score=False):
        box_scores = F.sigmoid(box_cls).flatten(2).transpose([0, 2, 1])
        box_centerness = F.sigmoid(box_ctn).flatten(2).transpose([0, 2, 1])
        pred_scores = box_scores * box_centerness
        if sqrt_score:
            pred_scores = paddle.sqrt(pred_scores)

        box_reg_ch_last = box_reg.flatten(2).transpose([0, 2, 1])
        box_reg_decoding = paddle.stack(
            [
                locations[:, 0] - box_reg_ch_last[:, :, 0],
                locations[:, 1] - box_reg_ch_last[:, :, 1],
                locations[:, 0] + box_reg_ch_last[:, :, 2],
                locations[:, 1] + box_reg_ch_last[:, :, 3]
            ],
            axis=1)
        pred_boxes = box_reg_decoding.transpose([0, 2, 1])
        return pred_scores, pred_boxes
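
    # Decoding above: each anchor point (cx, cy) predicts distances
    # (l, t, r, b) to the four box sides, so the recovered box is
    # (x1, y1, x2, y2) = (cx - l, cy - t, cx + r, cy + b).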
    def post_process(self, fcos_head_outs, scale_factor):
        locations, cls_logits, bboxes_reg, centerness = fcos_head_outs
        pred_bboxes, pred_scores = [], []
        for pts, cls, reg, ctn in zip(locations, cls_logits, bboxes_reg,
                                      centerness):
            scores, boxes = self._post_process_by_level(pts, cls, reg, ctn,
                                                        self.sqrt_score)
            pred_scores.append(scores)
            pred_bboxes.append(boxes)
        pred_bboxes = paddle.concat(pred_bboxes, axis=1)
        pred_scores = paddle.concat(pred_scores, axis=1)

        # scale bbox to origin image size
        scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
        scale_factor = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=-1).reshape([-1, 1, 4])
        pred_bboxes /= scale_factor

        pred_scores = pred_scores.transpose([0, 2, 1])
        bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
        return bbox_pred, bbox_num
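

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the original module. It assumes
    # paddle and ppdet are installed and uses illustrative feature shapes for a
    # 256x256 input with the default FPN strides [8, 16, 32, 64, 128].
    feat = FCOSFeat(feat_in=256, feat_out=256, num_convs=4, norm_type='bn')
    head = FCOSHead(fcos_feat=feat, fcos_loss=None, nms=None)
    head.eval()  # exercise the eval/infer branch, which needs no targets
    fpn_feats = [paddle.rand([1, 256, s, s]) for s in [32, 16, 8, 4, 2]]
    locations, cls_logits, bboxes_reg, centerness = head(fpn_feats)
    for loc, cls, reg in zip(locations, cls_logits, bboxes_reg):
        # per level: [H*W, 2] anchor points, [1, num_classes, H, W] logits,
        # [1, 4, H, W] regression outputs
        print(loc.shape, cls.shape, reg.shape)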