fcosr_head.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from ppdet.core.workspace import register
  18. from paddle import ParamAttr
  19. from paddle.regularizer import L2Decay
  20. from .fcos_head import ScaleReg
  21. from ..initializer import bias_init_with_prob, constant_, normal_
  22. from ..ops import get_act_fn, anchor_generator
  23. from ..rbox_utils import box2corners
  24. from ..losses import ProbIoULoss
  25. import numpy as np
  26. __all__ = ['FCOSRHead']
  27. def trunc_div(a, b):
  28. ipt = paddle.divide(a, b)
  29. sign_ipt = paddle.sign(ipt)
  30. abs_ipt = paddle.abs(ipt)
  31. abs_ipt = paddle.floor(abs_ipt)
  32. out = paddle.multiply(sign_ipt, abs_ipt)
  33. return out
  34. def fmod(a, b):
  35. return a - trunc_div(a, b) * b
  36. def fmod_eval(a, b):
  37. return a - a.divide(b).cast(paddle.int32).cast(paddle.float32) * b
  38. class ConvBNLayer(nn.Layer):
  39. def __init__(self,
  40. ch_in,
  41. ch_out,
  42. filter_size=3,
  43. stride=1,
  44. groups=1,
  45. padding=0,
  46. norm_cfg={'name': 'gn',
  47. 'num_groups': 32},
  48. act=None):
  49. super(ConvBNLayer, self).__init__()
  50. self.conv = nn.Conv2D(
  51. in_channels=ch_in,
  52. out_channels=ch_out,
  53. kernel_size=filter_size,
  54. stride=stride,
  55. padding=padding,
  56. groups=groups,
  57. bias_attr=False)
  58. norm_type = norm_cfg['name']
  59. if norm_type in ['sync_bn', 'bn']:
  60. self.norm = nn.BatchNorm2D(
  61. ch_out,
  62. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  63. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  64. else:
  65. groups = norm_cfg.get('num_groups', 1)
  66. self.norm = nn.GroupNorm(
  67. num_groups=groups,
  68. num_channels=ch_out,
  69. weight_attr=ParamAttr(regularizer=L2Decay(0.0)),
  70. bias_attr=ParamAttr(regularizer=L2Decay(0.0)))
  71. self.act = get_act_fn(act) if act is None or isinstance(act, (
  72. str, dict)) else act
  73. def forward(self, x):
  74. x = self.conv(x)
  75. x = self.norm(x)
  76. x = self.act(x)
  77. return x
  78. @register
  79. class FCOSRHead(nn.Layer):
  80. """ FCOSR Head, refer to https://arxiv.org/abs/2111.10780 for details """
  81. __shared__ = ['num_classes', 'trt']
  82. __inject__ = ['assigner', 'nms']
  83. def __init__(self,
  84. num_classes=15,
  85. in_channels=256,
  86. feat_channels=256,
  87. stacked_convs=4,
  88. act='relu',
  89. fpn_strides=[4, 8, 16, 32, 64],
  90. trt=False,
  91. loss_weight={'class': 1.0,
  92. 'probiou': 1.0},
  93. norm_cfg={'name': 'gn',
  94. 'num_groups': 32},
  95. assigner='FCOSRAssigner',
  96. nms='MultiClassNMS'):
  97. super(FCOSRHead, self).__init__()
  98. self.in_channels = in_channels
  99. self.num_classes = num_classes
  100. self.fpn_strides = fpn_strides
  101. self.stacked_convs = stacked_convs
  102. self.loss_weight = loss_weight
  103. self.half_pi = paddle.to_tensor(
  104. [1.5707963267948966], dtype=paddle.float32)
  105. self.probiou_loss = ProbIoULoss(mode='l1')
  106. act = get_act_fn(
  107. act, trt=trt) if act is None or isinstance(act,
  108. (str, dict)) else act
  109. self.trt = trt
  110. self.loss_weight = loss_weight
  111. self.assigner = assigner
  112. self.nms = nms
  113. # stem
  114. self.stem_cls = nn.LayerList()
  115. self.stem_reg = nn.LayerList()
  116. for i in range(self.stacked_convs):
  117. self.stem_cls.append(
  118. ConvBNLayer(
  119. self.in_channels[i],
  120. feat_channels,
  121. filter_size=3,
  122. stride=1,
  123. padding=1,
  124. norm_cfg=norm_cfg,
  125. act=act))
  126. self.stem_reg.append(
  127. ConvBNLayer(
  128. self.in_channels[i],
  129. feat_channels,
  130. filter_size=3,
  131. stride=1,
  132. padding=1,
  133. norm_cfg=norm_cfg,
  134. act=act))
  135. self.scales = nn.LayerList(
  136. [ScaleReg() for _ in range(len(fpn_strides))])
  137. # prediction
  138. self.pred_cls = nn.Conv2D(feat_channels, self.num_classes, 3, padding=1)
  139. self.pred_xy = nn.Conv2D(feat_channels, 2, 3, padding=1)
  140. self.pred_wh = nn.Conv2D(feat_channels, 2, 3, padding=1)
  141. self.pred_angle = nn.Conv2D(feat_channels, 1, 3, padding=1)
  142. self._init_weights()
  143. def _init_weights(self):
  144. for cls_, reg_ in zip(self.stem_cls, self.stem_reg):
  145. normal_(cls_.conv.weight, std=0.01)
  146. normal_(reg_.conv.weight, std=0.01)
  147. bias_cls = bias_init_with_prob(0.01)
  148. normal_(self.pred_cls.weight, std=0.01)
  149. constant_(self.pred_cls.bias, bias_cls)
  150. normal_(self.pred_xy.weight, std=0.01)
  151. normal_(self.pred_wh.weight, std=0.01)
  152. normal_(self.pred_angle.weight, std=0.01)
  153. @classmethod
  154. def from_config(cls, cfg, input_shape):
  155. return {'in_channels': [i.channels for i in input_shape], }
  156. def _generate_anchors(self, feats):
  157. if self.trt:
  158. anchor_points = []
  159. for feat, stride in zip(feats, self.fpn_strides):
  160. _, _, h, w = paddle.shape(feat)
  161. anchor, _ = anchor_generator(
  162. feat,
  163. stride * 4,
  164. 1.0, [1.0, 1.0, 1.0, 1.0], [stride, stride],
  165. offset=0.5)
  166. x1, y1, x2, y2 = paddle.split(anchor, 4, axis=-1)
  167. xc = (x1 + x2 + 1) / 2
  168. yc = (y1 + y2 + 1) / 2
  169. anchor_point = paddle.concat(
  170. [xc, yc], axis=-1).reshape((1, h * w, 2))
  171. anchor_points.append(anchor_point)
  172. anchor_points = paddle.concat(anchor_points, axis=1)
  173. return anchor_points, None, None
  174. else:
  175. anchor_points = []
  176. stride_tensor = []
  177. num_anchors_list = []
  178. for feat, stride in zip(feats, self.fpn_strides):
  179. _, _, h, w = paddle.shape(feat)
  180. shift_x = (paddle.arange(end=w) + 0.5) * stride
  181. shift_y = (paddle.arange(end=h) + 0.5) * stride
  182. shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
  183. anchor_point = paddle.cast(
  184. paddle.stack(
  185. [shift_x, shift_y], axis=-1), dtype='float32')
  186. anchor_points.append(anchor_point.reshape([1, -1, 2]))
  187. stride_tensor.append(
  188. paddle.full(
  189. [1, h * w, 1], stride, dtype='float32'))
  190. num_anchors_list.append(h * w)
  191. anchor_points = paddle.concat(anchor_points, axis=1)
  192. stride_tensor = paddle.concat(stride_tensor, axis=1)
  193. return anchor_points, stride_tensor, num_anchors_list
  194. def forward(self, feats, target=None):
  195. if self.training:
  196. return self.forward_train(feats, target)
  197. else:
  198. return self.forward_eval(feats, target)
  199. def forward_train(self, feats, target=None):
  200. anchor_points, stride_tensor, num_anchors_list = self._generate_anchors(
  201. feats)
  202. cls_pred_list, reg_pred_list = [], []
  203. for stride, feat, scale in zip(self.fpn_strides, feats, self.scales):
  204. # cls
  205. cls_feat = feat
  206. for cls_layer in self.stem_cls:
  207. cls_feat = cls_layer(cls_feat)
  208. cls_pred = F.sigmoid(self.pred_cls(cls_feat))
  209. cls_pred_list.append(cls_pred.flatten(2).transpose((0, 2, 1)))
  210. # reg
  211. reg_feat = feat
  212. for reg_layer in self.stem_reg:
  213. reg_feat = reg_layer(reg_feat)
  214. reg_xy = scale(self.pred_xy(reg_feat)) * stride
  215. reg_wh = F.elu(scale(self.pred_wh(reg_feat)) + 1.) * stride
  216. reg_angle = self.pred_angle(reg_feat)
  217. reg_angle = fmod(reg_angle, self.half_pi)
  218. reg_pred = paddle.concat([reg_xy, reg_wh, reg_angle], axis=1)
  219. reg_pred_list.append(reg_pred.flatten(2).transpose((0, 2, 1)))
  220. cls_pred_list = paddle.concat(cls_pred_list, axis=1)
  221. reg_pred_list = paddle.concat(reg_pred_list, axis=1)
  222. return self.get_loss([
  223. cls_pred_list, reg_pred_list, anchor_points, stride_tensor,
  224. num_anchors_list
  225. ], target)
  226. def forward_eval(self, feats, target=None):
  227. cls_pred_list, reg_pred_list = [], []
  228. anchor_points, _, _ = self._generate_anchors(feats)
  229. for stride, feat, scale in zip(self.fpn_strides, feats, self.scales):
  230. b, _, h, w = paddle.shape(feat)
  231. # cls
  232. cls_feat = feat
  233. for cls_layer in self.stem_cls:
  234. cls_feat = cls_layer(cls_feat)
  235. cls_pred = F.sigmoid(self.pred_cls(cls_feat))
  236. cls_pred_list.append(cls_pred.reshape([b, self.num_classes, h * w]))
  237. # reg
  238. reg_feat = feat
  239. for reg_layer in self.stem_reg:
  240. reg_feat = reg_layer(reg_feat)
  241. reg_xy = scale(self.pred_xy(reg_feat)) * stride
  242. reg_wh = F.elu(scale(self.pred_wh(reg_feat)) + 1.) * stride
  243. reg_angle = self.pred_angle(reg_feat)
  244. reg_angle = fmod_eval(reg_angle, self.half_pi)
  245. reg_pred = paddle.concat([reg_xy, reg_wh, reg_angle], axis=1)
  246. reg_pred = reg_pred.reshape([b, 5, h * w]).transpose((0, 2, 1))
  247. reg_pred_list.append(reg_pred)
  248. cls_pred_list = paddle.concat(cls_pred_list, axis=2)
  249. reg_pred_list = paddle.concat(reg_pred_list, axis=1)
  250. reg_pred_list = self._bbox_decode(anchor_points, reg_pred_list)
  251. return cls_pred_list, reg_pred_list
  252. def _bbox_decode(self, points, reg_pred_list):
  253. xy, wha = paddle.split(reg_pred_list, [2, 3], axis=-1)
  254. xy = xy + points
  255. return paddle.concat([xy, wha], axis=-1)
  256. def _box2corners(self, pred_bboxes):
  257. """ convert (x, y, w, h, angle) to (x1, y1, x2, y2, x3, y3, x4, y4)
  258. Args:
  259. pred_bboxes (Tensor): [B, N, 5]
  260. Returns:
  261. polys (Tensor): [B, N, 8]
  262. """
  263. x, y, w, h, angle = paddle.split(pred_bboxes, 5, axis=-1)
  264. cos_a_half = paddle.cos(angle) * 0.5
  265. sin_a_half = paddle.sin(angle) * 0.5
  266. w_x = cos_a_half * w
  267. w_y = sin_a_half * w
  268. h_x = -sin_a_half * h
  269. h_y = cos_a_half * h
  270. return paddle.concat(
  271. [
  272. x + w_x + h_x, y + w_y + h_y, x - w_x + h_x, y - w_y + h_y,
  273. x - w_x - h_x, y - w_y - h_y, x + w_x - h_x, y + w_y - h_y
  274. ],
  275. axis=-1)
  276. def get_loss(self, head_outs, gt_meta):
  277. cls_pred_list, reg_pred_list, anchor_points, stride_tensor, num_anchors_list = head_outs
  278. gt_labels = gt_meta['gt_class']
  279. gt_bboxes = gt_meta['gt_bbox']
  280. gt_rboxes = gt_meta['gt_rbox']
  281. pad_gt_mask = gt_meta['pad_gt_mask']
  282. # decode
  283. pred_rboxes = self._bbox_decode(anchor_points, reg_pred_list)
  284. # label assignment
  285. assigned_labels, assigned_rboxes, assigned_scores = \
  286. self.assigner(
  287. anchor_points,
  288. stride_tensor,
  289. num_anchors_list,
  290. gt_labels,
  291. gt_bboxes,
  292. gt_rboxes,
  293. pad_gt_mask,
  294. self.num_classes,
  295. pred_rboxes
  296. )
  297. # reg_loss
  298. mask_positive = (assigned_labels != self.num_classes)
  299. num_pos = mask_positive.sum().item()
  300. if num_pos > 0:
  301. bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 5])
  302. pred_rboxes_pos = paddle.masked_select(pred_rboxes,
  303. bbox_mask).reshape([-1, 5])
  304. assigned_rboxes_pos = paddle.masked_select(
  305. assigned_rboxes, bbox_mask).reshape([-1, 5])
  306. bbox_weight = paddle.masked_select(
  307. assigned_scores.sum(-1), mask_positive).reshape([-1])
  308. avg_factor = bbox_weight.sum()
  309. loss_probiou = self.probiou_loss(pred_rboxes_pos,
  310. assigned_rboxes_pos)
  311. loss_probiou = paddle.sum(loss_probiou * bbox_weight) / avg_factor
  312. else:
  313. loss_probiou = pred_rboxes.sum() * 0.
  314. avg_factor = max(num_pos, 1.0)
  315. # cls_loss
  316. loss_cls = self._qfocal_loss(
  317. cls_pred_list, assigned_scores, reduction='sum')
  318. loss_cls = loss_cls / avg_factor
  319. loss = self.loss_weight['class'] * loss_cls + \
  320. self.loss_weight['probiou'] * loss_probiou
  321. out_dict = {
  322. 'loss': loss,
  323. 'loss_probiou': loss_probiou,
  324. 'loss_cls': loss_cls
  325. }
  326. return out_dict
  327. @staticmethod
  328. def _qfocal_loss(score, label, gamma=2.0, reduction='sum'):
  329. weight = (score - label).pow(gamma)
  330. loss = F.binary_cross_entropy(
  331. score, label, weight=weight, reduction=reduction)
  332. return loss
  333. def post_process(self, head_outs, scale_factor):
  334. pred_scores, pred_rboxes = head_outs
  335. # [B, N, 5] -> [B, N, 4, 2] -> [B, N, 8]
  336. pred_rboxes = self._box2corners(pred_rboxes)
  337. # scale bbox to origin
  338. scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
  339. scale_factor = paddle.concat(
  340. [
  341. scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x,
  342. scale_y
  343. ],
  344. axis=-1).reshape([-1, 1, 8])
  345. pred_rboxes /= scale_factor
  346. bbox_pred, bbox_num, _ = self.nms(pred_rboxes, pred_scores)
  347. return bbox_pred, bbox_num