# s2anet_head.py (30 KB)
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# The code is based on https://github.com/csuhan/s2anet/blob/master/mmdet/models/anchor_heads_rotated/s2anet_head.py
  16. import paddle
  17. from paddle import ParamAttr
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from paddle.nn.initializer import Normal, Constant
  21. from ppdet.core.workspace import register
  22. from ppdet.modeling.proposal_generator.target_layer import RBoxAssigner
  23. from ppdet.modeling.proposal_generator.anchor_generator import S2ANetAnchorGenerator
  24. from ppdet.modeling.layers import AlignConv
  25. from ..cls_utils import _get_class_default_kwargs
  26. import numpy as np
  27. @register
  28. class S2ANetHead(nn.Layer):
  29. """
  30. S2Anet head
  31. Args:
  32. stacked_convs (int): number of stacked_convs
  33. feat_in (int): input channels of feat
  34. feat_out (int): output channels of feat
  35. num_classes (int): num_classes
  36. anchor_strides (list): stride of anchors
  37. anchor_scales (list): scale of anchors
  38. anchor_ratios (list): ratios of anchors
  39. target_means (list): target_means
  40. target_stds (list): target_stds
  41. align_conv_type (str): align_conv_type ['Conv', 'AlignConv']
  42. align_conv_size (int): kernel size of align_conv
  43. use_sigmoid_cls (bool): use sigmoid_cls or not
  44. reg_loss_weight (list): loss weight for regression
  45. """
  46. __shared__ = ['num_classes']
  47. __inject__ = ['anchor_assign', 'nms']
  48. def __init__(self,
  49. stacked_convs=2,
  50. feat_in=256,
  51. feat_out=256,
  52. num_classes=15,
  53. anchor_strides=[8, 16, 32, 64, 128],
  54. anchor_scales=[4],
  55. anchor_ratios=[1.0],
  56. target_means=0.0,
  57. target_stds=1.0,
  58. align_conv_type='AlignConv',
  59. align_conv_size=3,
  60. use_sigmoid_cls=True,
  61. anchor_assign=_get_class_default_kwargs(RBoxAssigner),
  62. reg_loss_weight=[1.0, 1.0, 1.0, 1.0, 1.1],
  63. cls_loss_weight=[1.1, 1.05],
  64. reg_loss_type='l1',
  65. nms_pre=2000,
  66. nms='MultiClassNMS'):
  67. super(S2ANetHead, self).__init__()
  68. self.stacked_convs = stacked_convs
  69. self.feat_in = feat_in
  70. self.feat_out = feat_out
  71. self.anchor_list = None
  72. self.anchor_scales = anchor_scales
  73. self.anchor_ratios = anchor_ratios
  74. self.anchor_strides = anchor_strides
  75. self.anchor_strides = paddle.to_tensor(anchor_strides)
  76. self.anchor_base_sizes = list(anchor_strides)
  77. self.means = paddle.ones(shape=[5]) * target_means
  78. self.stds = paddle.ones(shape=[5]) * target_stds
  79. assert align_conv_type in ['AlignConv', 'Conv', 'DCN']
  80. self.align_conv_type = align_conv_type
  81. self.align_conv_size = align_conv_size
  82. self.use_sigmoid_cls = use_sigmoid_cls
  83. self.cls_out_channels = num_classes if self.use_sigmoid_cls else num_classes + 1
  84. self.sampling = False
  85. self.anchor_assign = anchor_assign
  86. self.reg_loss_weight = reg_loss_weight
  87. self.cls_loss_weight = cls_loss_weight
  88. self.alpha = 1.0
  89. self.beta = 1.0
  90. self.reg_loss_type = reg_loss_type
  91. self.nms_pre = nms_pre
  92. self.nms = nms
  93. self.fake_bbox = paddle.to_tensor(
  94. np.array(
  95. [[-1, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]],
  96. dtype='float32'))
  97. self.fake_bbox_num = paddle.to_tensor(np.array([1], dtype='int32'))
  98. # anchor
  99. self.anchor_generators = []
  100. for anchor_base in self.anchor_base_sizes:
  101. self.anchor_generators.append(
  102. S2ANetAnchorGenerator(anchor_base, anchor_scales,
  103. anchor_ratios))
  104. self.anchor_generators = nn.LayerList(self.anchor_generators)
  105. self.fam_cls_convs = nn.Sequential()
  106. self.fam_reg_convs = nn.Sequential()
  107. for i in range(self.stacked_convs):
  108. chan_in = self.feat_in if i == 0 else self.feat_out
  109. self.fam_cls_convs.add_sublayer(
  110. 'fam_cls_conv_{}'.format(i),
  111. nn.Conv2D(
  112. in_channels=chan_in,
  113. out_channels=self.feat_out,
  114. kernel_size=3,
  115. padding=1,
  116. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  117. bias_attr=ParamAttr(initializer=Constant(0))))
  118. self.fam_cls_convs.add_sublayer('fam_cls_conv_{}_act'.format(i),
  119. nn.ReLU())
  120. self.fam_reg_convs.add_sublayer(
  121. 'fam_reg_conv_{}'.format(i),
  122. nn.Conv2D(
  123. in_channels=chan_in,
  124. out_channels=self.feat_out,
  125. kernel_size=3,
  126. padding=1,
  127. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  128. bias_attr=ParamAttr(initializer=Constant(0))))
  129. self.fam_reg_convs.add_sublayer('fam_reg_conv_{}_act'.format(i),
  130. nn.ReLU())
  131. self.fam_reg = nn.Conv2D(
  132. self.feat_out,
  133. 5,
  134. 1,
  135. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  136. bias_attr=ParamAttr(initializer=Constant(0)))
  137. prior_prob = 0.01
  138. bias_init = float(-np.log((1 - prior_prob) / prior_prob))
  139. self.fam_cls = nn.Conv2D(
  140. self.feat_out,
  141. self.cls_out_channels,
  142. 1,
  143. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  144. bias_attr=ParamAttr(initializer=Constant(bias_init)))
  145. if self.align_conv_type == "AlignConv":
  146. self.align_conv = AlignConv(self.feat_out, self.feat_out,
  147. self.align_conv_size)
  148. elif self.align_conv_type == "Conv":
  149. self.align_conv = nn.Conv2D(
  150. self.feat_out,
  151. self.feat_out,
  152. self.align_conv_size,
  153. padding=(self.align_conv_size - 1) // 2,
  154. bias_attr=ParamAttr(initializer=Constant(0)))
  155. elif self.align_conv_type == "DCN":
  156. self.align_conv_offset = nn.Conv2D(
  157. self.feat_out,
  158. 2 * self.align_conv_size**2,
  159. 1,
  160. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  161. bias_attr=ParamAttr(initializer=Constant(0)))
  162. self.align_conv = paddle.vision.ops.DeformConv2D(
  163. self.feat_out,
  164. self.feat_out,
  165. self.align_conv_size,
  166. padding=(self.align_conv_size - 1) // 2,
  167. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  168. bias_attr=False)
  169. self.or_conv = nn.Conv2D(
  170. self.feat_out,
  171. self.feat_out,
  172. kernel_size=3,
  173. padding=1,
  174. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  175. bias_attr=ParamAttr(initializer=Constant(0)))
  176. # ODM
  177. self.odm_cls_convs = nn.Sequential()
  178. self.odm_reg_convs = nn.Sequential()
  179. for i in range(self.stacked_convs):
  180. ch_in = self.feat_out
  181. # ch_in = int(self.feat_out / 8) if i == 0 else self.feat_out
  182. self.odm_cls_convs.add_sublayer(
  183. 'odm_cls_conv_{}'.format(i),
  184. nn.Conv2D(
  185. in_channels=ch_in,
  186. out_channels=self.feat_out,
  187. kernel_size=3,
  188. stride=1,
  189. padding=1,
  190. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  191. bias_attr=ParamAttr(initializer=Constant(0))))
  192. self.odm_cls_convs.add_sublayer('odm_cls_conv_{}_act'.format(i),
  193. nn.ReLU())
  194. self.odm_reg_convs.add_sublayer(
  195. 'odm_reg_conv_{}'.format(i),
  196. nn.Conv2D(
  197. in_channels=self.feat_out,
  198. out_channels=self.feat_out,
  199. kernel_size=3,
  200. stride=1,
  201. padding=1,
  202. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  203. bias_attr=ParamAttr(initializer=Constant(0))))
  204. self.odm_reg_convs.add_sublayer('odm_reg_conv_{}_act'.format(i),
  205. nn.ReLU())
  206. self.odm_cls = nn.Conv2D(
  207. self.feat_out,
  208. self.cls_out_channels,
  209. 3,
  210. padding=1,
  211. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  212. bias_attr=ParamAttr(initializer=Constant(bias_init)))
  213. self.odm_reg = nn.Conv2D(
  214. self.feat_out,
  215. 5,
  216. 3,
  217. padding=1,
  218. weight_attr=ParamAttr(initializer=Normal(0.0, 0.01)),
  219. bias_attr=ParamAttr(initializer=Constant(0)))
  220. def forward(self, feats, targets=None):
  221. fam_reg_list, fam_cls_list = [], []
  222. odm_reg_list, odm_cls_list = [], []
  223. num_anchors_list, base_anchors_list, refine_anchors_list = [], [], []
  224. for i, feat in enumerate(feats):
  225. # get shape
  226. B = feat.shape[0]
  227. H, W = paddle.shape(feat)[2], paddle.shape(feat)[3]
  228. NA = H * W
  229. num_anchors_list.append(NA)
  230. fam_cls_feat = self.fam_cls_convs(feat)
  231. fam_cls = self.fam_cls(fam_cls_feat)
  232. # [N, CLS, H, W] --> [N, H, W, CLS]
  233. fam_cls = fam_cls.transpose([0, 2, 3, 1]).reshape(
  234. [B, NA, self.cls_out_channels])
  235. fam_cls_list.append(fam_cls)
  236. fam_reg_feat = self.fam_reg_convs(feat)
  237. fam_reg = self.fam_reg(fam_reg_feat)
  238. # [N, 5, H, W] --> [N, H, W, 5]
  239. fam_reg = fam_reg.transpose([0, 2, 3, 1]).reshape([B, NA, 5])
  240. fam_reg_list.append(fam_reg)
  241. # prepare anchor
  242. init_anchors = self.anchor_generators[i]((H, W),
  243. self.anchor_strides[i])
  244. init_anchors = init_anchors.reshape([1, NA, 5])
  245. base_anchors_list.append(init_anchors.squeeze(0))
  246. if self.training:
  247. refine_anchor = self.bbox_decode(fam_reg.detach(), init_anchors)
  248. else:
  249. refine_anchor = self.bbox_decode(fam_reg, init_anchors)
  250. refine_anchors_list.append(refine_anchor)
  251. if self.align_conv_type == 'AlignConv':
  252. align_feat = self.align_conv(feat,
  253. refine_anchor.clone(), (H, W),
  254. self.anchor_strides[i])
  255. elif self.align_conv_type == 'DCN':
  256. align_offset = self.align_conv_offset(feat)
  257. align_feat = self.align_conv(feat, align_offset)
  258. elif self.align_conv_type == 'Conv':
  259. align_feat = self.align_conv(feat)
  260. or_feat = self.or_conv(align_feat)
  261. odm_reg_feat = or_feat
  262. odm_cls_feat = or_feat
  263. odm_reg_feat = self.odm_reg_convs(odm_reg_feat)
  264. odm_cls_feat = self.odm_cls_convs(odm_cls_feat)
  265. odm_cls = self.odm_cls(odm_cls_feat)
  266. # [N, CLS, H, W] --> [N, H, W, CLS]
  267. odm_cls = odm_cls.transpose([0, 2, 3, 1]).reshape(
  268. [B, NA, self.cls_out_channels])
  269. odm_cls_list.append(odm_cls)
  270. odm_reg = self.odm_reg(odm_reg_feat)
  271. # [N, 5, H, W] --> [N, H, W, 5]
  272. odm_reg = odm_reg.transpose([0, 2, 3, 1]).reshape([B, NA, 5])
  273. odm_reg_list.append(odm_reg)
  274. if self.training:
  275. return self.get_loss([
  276. fam_cls_list, fam_reg_list, odm_cls_list, odm_reg_list,
  277. num_anchors_list, base_anchors_list, refine_anchors_list
  278. ], targets)
  279. else:
  280. odm_bboxes_list = []
  281. for odm_reg, refine_anchor in zip(odm_reg_list,
  282. refine_anchors_list):
  283. odm_bboxes = self.bbox_decode(odm_reg, refine_anchor)
  284. odm_bboxes_list.append(odm_bboxes)
  285. return [odm_bboxes_list, odm_cls_list]
  286. def get_bboxes(self, head_outs):
  287. perd_bboxes_list, pred_scores_list = head_outs
  288. batch = paddle.shape(pred_scores_list[0])[0]
  289. bboxes, bbox_num = [], []
  290. for i in range(batch):
  291. pred_scores_per_image = [t[i] for t in pred_scores_list]
  292. pred_bboxes_per_image = [t[i] for t in perd_bboxes_list]
  293. bbox_per_image, bbox_num_per_image = self.get_bboxes_single(
  294. pred_scores_per_image, pred_bboxes_per_image)
  295. bboxes.append(bbox_per_image)
  296. bbox_num.append(bbox_num_per_image)
  297. bboxes = paddle.concat(bboxes)
  298. bbox_num = paddle.concat(bbox_num)
  299. return bboxes, bbox_num
  300. def get_pred(self, bboxes, bbox_num, im_shape, scale_factor):
  301. """
  302. Rescale, clip and filter the bbox from the output of NMS to
  303. get final prediction.
  304. Args:
  305. bboxes(Tensor): bboxes [N, 10]
  306. bbox_num(Tensor): bbox_num
  307. im_shape(Tensor): [1 2]
  308. scale_factor(Tensor): [1 2]
  309. Returns:
  310. bbox_pred(Tensor): The output is the prediction with shape [N, 8]
  311. including labels, scores and bboxes. The size of
  312. bboxes are corresponding to the original image.
  313. """
  314. origin_shape = paddle.floor(im_shape / scale_factor + 0.5)
  315. origin_shape_list = []
  316. scale_factor_list = []
  317. # scale_factor: scale_y, scale_x
  318. for i in range(bbox_num.shape[0]):
  319. expand_shape = paddle.expand(origin_shape[i:i + 1, :],
  320. [bbox_num[i], 2])
  321. scale_y, scale_x = scale_factor[i][0], scale_factor[i][1]
  322. scale = paddle.concat([
  323. scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x,
  324. scale_y
  325. ])
  326. expand_scale = paddle.expand(scale, [bbox_num[i], 8])
  327. origin_shape_list.append(expand_shape)
  328. scale_factor_list.append(expand_scale)
  329. origin_shape_list = paddle.concat(origin_shape_list)
  330. scale_factor_list = paddle.concat(scale_factor_list)
  331. # bboxes: [N, 10], label, score, bbox
  332. pred_label_score = bboxes[:, 0:2]
  333. pred_bbox = bboxes[:, 2:]
  334. # rescale bbox to original image
  335. pred_bbox = pred_bbox.reshape([-1, 8])
  336. scaled_bbox = pred_bbox / scale_factor_list
  337. origin_h = origin_shape_list[:, 0]
  338. origin_w = origin_shape_list[:, 1]
  339. bboxes = scaled_bbox
  340. zeros = paddle.zeros_like(origin_h)
  341. x1 = paddle.maximum(paddle.minimum(bboxes[:, 0], origin_w - 1), zeros)
  342. y1 = paddle.maximum(paddle.minimum(bboxes[:, 1], origin_h - 1), zeros)
  343. x2 = paddle.maximum(paddle.minimum(bboxes[:, 2], origin_w - 1), zeros)
  344. y2 = paddle.maximum(paddle.minimum(bboxes[:, 3], origin_h - 1), zeros)
  345. x3 = paddle.maximum(paddle.minimum(bboxes[:, 4], origin_w - 1), zeros)
  346. y3 = paddle.maximum(paddle.minimum(bboxes[:, 5], origin_h - 1), zeros)
  347. x4 = paddle.maximum(paddle.minimum(bboxes[:, 6], origin_w - 1), zeros)
  348. y4 = paddle.maximum(paddle.minimum(bboxes[:, 7], origin_h - 1), zeros)
  349. pred_bbox = paddle.stack([x1, y1, x2, y2, x3, y3, x4, y4], axis=-1)
  350. pred_result = paddle.concat([pred_label_score, pred_bbox], axis=1)
  351. return pred_result
  352. def get_bboxes_single(self, cls_score_list, bbox_pred_list):
  353. mlvl_bboxes = []
  354. mlvl_scores = []
  355. for cls_score, bbox_pred in zip(cls_score_list, bbox_pred_list):
  356. if self.use_sigmoid_cls:
  357. scores = F.sigmoid(cls_score)
  358. else:
  359. scores = F.softmax(cls_score, axis=-1)
  360. if scores.shape[0] > self.nms_pre:
  361. # Get maximum scores for foreground classes.
  362. if self.use_sigmoid_cls:
  363. max_scores = paddle.max(scores, axis=1)
  364. else:
  365. max_scores = paddle.max(scores[:, :-1], axis=1)
  366. topk_val, topk_inds = paddle.topk(max_scores, self.nms_pre)
  367. bbox_pred = paddle.gather(bbox_pred, topk_inds)
  368. scores = paddle.gather(scores, topk_inds)
  369. mlvl_bboxes.append(bbox_pred)
  370. mlvl_scores.append(scores)
  371. mlvl_bboxes = paddle.concat(mlvl_bboxes)
  372. mlvl_scores = paddle.concat(mlvl_scores)
  373. mlvl_polys = self.rbox2poly(mlvl_bboxes).unsqueeze(0)
  374. mlvl_scores = paddle.transpose(mlvl_scores, [1, 0]).unsqueeze(0)
  375. bbox, bbox_num, _ = self.nms(mlvl_polys, mlvl_scores)
  376. if bbox.shape[0] <= 0:
  377. bbox = self.fake_bbox
  378. bbox_num = self.fake_bbox_num
  379. return bbox, bbox_num
  380. def smooth_l1_loss(self, pred, label, delta=1.0 / 9.0):
  381. """
  382. Args:
  383. pred: pred score
  384. label: label
  385. delta: delta
  386. Returns: loss
  387. """
  388. assert pred.shape == label.shape and label.numel() > 0
  389. assert delta > 0
  390. diff = paddle.abs(pred - label)
  391. loss = paddle.where(diff < delta, 0.5 * diff * diff / delta,
  392. diff - 0.5 * delta)
  393. return loss
  394. def get_fam_loss(self, fam_target, s2anet_head_out, reg_loss_type='l1'):
  395. (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes,
  396. pos_inds, neg_inds) = fam_target
  397. fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list, num_anchors_list = s2anet_head_out
  398. fam_cls_losses = []
  399. fam_bbox_losses = []
  400. st_idx = 0
  401. num_total_samples = len(pos_inds) + len(
  402. neg_inds) if self.sampling else len(pos_inds)
  403. num_total_samples = max(1, num_total_samples)
  404. for idx, feat_anchor_num in enumerate(num_anchors_list):
  405. # step1: get data
  406. feat_labels = labels[st_idx:st_idx + feat_anchor_num]
  407. feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
  408. feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
  409. feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
  410. # step2: calc cls loss
  411. feat_labels = feat_labels.reshape(-1)
  412. feat_label_weights = feat_label_weights.reshape(-1)
  413. fam_cls_score = fam_cls_branch_list[idx]
  414. fam_cls_score = paddle.squeeze(fam_cls_score, axis=0)
  415. fam_cls_score1 = fam_cls_score
  416. feat_labels = paddle.to_tensor(feat_labels)
  417. feat_labels_one_hot = paddle.nn.functional.one_hot(
  418. feat_labels, self.cls_out_channels + 1)
  419. feat_labels_one_hot = feat_labels_one_hot[:, 1:]
  420. feat_labels_one_hot.stop_gradient = True
  421. num_total_samples = paddle.to_tensor(
  422. num_total_samples, dtype='float32', stop_gradient=True)
  423. fam_cls = F.sigmoid_focal_loss(
  424. fam_cls_score1,
  425. feat_labels_one_hot,
  426. normalizer=num_total_samples,
  427. reduction='none')
  428. feat_label_weights = feat_label_weights.reshape(
  429. feat_label_weights.shape[0], 1)
  430. feat_label_weights = np.repeat(
  431. feat_label_weights, self.cls_out_channels, axis=1)
  432. feat_label_weights = paddle.to_tensor(
  433. feat_label_weights, stop_gradient=True)
  434. fam_cls = fam_cls * feat_label_weights
  435. fam_cls_total = paddle.sum(fam_cls)
  436. fam_cls_losses.append(fam_cls_total)
  437. # step3: regression loss
  438. feat_bbox_targets = paddle.to_tensor(
  439. feat_bbox_targets, dtype='float32', stop_gradient=True)
  440. feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
  441. fam_bbox_pred = fam_reg_branch_list[idx]
  442. fam_bbox_pred = paddle.squeeze(fam_bbox_pred, axis=0)
  443. fam_bbox_pred = paddle.reshape(fam_bbox_pred, [-1, 5])
  444. fam_bbox = self.smooth_l1_loss(fam_bbox_pred, feat_bbox_targets)
  445. loss_weight = paddle.to_tensor(
  446. self.reg_loss_weight, dtype='float32', stop_gradient=True)
  447. fam_bbox = paddle.multiply(fam_bbox, loss_weight)
  448. feat_bbox_weights = paddle.to_tensor(
  449. feat_bbox_weights, stop_gradient=True)
  450. fam_bbox = fam_bbox * feat_bbox_weights
  451. fam_bbox_total = paddle.sum(fam_bbox) / num_total_samples
  452. fam_bbox_losses.append(fam_bbox_total)
  453. st_idx += feat_anchor_num
  454. fam_cls_loss = paddle.add_n(fam_cls_losses)
  455. fam_cls_loss_weight = paddle.to_tensor(
  456. self.cls_loss_weight[0], dtype='float32', stop_gradient=True)
  457. fam_cls_loss = fam_cls_loss * fam_cls_loss_weight
  458. fam_reg_loss = paddle.add_n(fam_bbox_losses)
  459. return fam_cls_loss, fam_reg_loss
  460. def get_odm_loss(self, odm_target, s2anet_head_out, reg_loss_type='l1'):
  461. (labels, label_weights, bbox_targets, bbox_weights, bbox_gt_bboxes,
  462. pos_inds, neg_inds) = odm_target
  463. fam_cls_branch_list, fam_reg_branch_list, odm_cls_branch_list, odm_reg_branch_list, num_anchors_list = s2anet_head_out
  464. odm_cls_losses = []
  465. odm_bbox_losses = []
  466. st_idx = 0
  467. num_total_samples = len(pos_inds) + len(
  468. neg_inds) if self.sampling else len(pos_inds)
  469. num_total_samples = max(1, num_total_samples)
  470. for idx, feat_anchor_num in enumerate(num_anchors_list):
  471. # step1: get data
  472. feat_labels = labels[st_idx:st_idx + feat_anchor_num]
  473. feat_label_weights = label_weights[st_idx:st_idx + feat_anchor_num]
  474. feat_bbox_targets = bbox_targets[st_idx:st_idx + feat_anchor_num, :]
  475. feat_bbox_weights = bbox_weights[st_idx:st_idx + feat_anchor_num, :]
  476. # step2: calc cls loss
  477. feat_labels = feat_labels.reshape(-1)
  478. feat_label_weights = feat_label_weights.reshape(-1)
  479. odm_cls_score = odm_cls_branch_list[idx]
  480. odm_cls_score = paddle.squeeze(odm_cls_score, axis=0)
  481. odm_cls_score1 = odm_cls_score
  482. feat_labels = paddle.to_tensor(feat_labels)
  483. feat_labels_one_hot = paddle.nn.functional.one_hot(
  484. feat_labels, self.cls_out_channels + 1)
  485. feat_labels_one_hot = feat_labels_one_hot[:, 1:]
  486. feat_labels_one_hot.stop_gradient = True
  487. num_total_samples = paddle.to_tensor(
  488. num_total_samples, dtype='float32', stop_gradient=True)
  489. odm_cls = F.sigmoid_focal_loss(
  490. odm_cls_score1,
  491. feat_labels_one_hot,
  492. normalizer=num_total_samples,
  493. reduction='none')
  494. feat_label_weights = feat_label_weights.reshape(
  495. feat_label_weights.shape[0], 1)
  496. feat_label_weights = np.repeat(
  497. feat_label_weights, self.cls_out_channels, axis=1)
  498. feat_label_weights = paddle.to_tensor(feat_label_weights)
  499. feat_label_weights.stop_gradient = True
  500. odm_cls = odm_cls * feat_label_weights
  501. odm_cls_total = paddle.sum(odm_cls)
  502. odm_cls_losses.append(odm_cls_total)
  503. # # step3: regression loss
  504. feat_bbox_targets = paddle.to_tensor(
  505. feat_bbox_targets, dtype='float32')
  506. feat_bbox_targets = paddle.reshape(feat_bbox_targets, [-1, 5])
  507. feat_bbox_targets.stop_gradient = True
  508. odm_bbox_pred = odm_reg_branch_list[idx]
  509. odm_bbox_pred = paddle.squeeze(odm_bbox_pred, axis=0)
  510. odm_bbox_pred = paddle.reshape(odm_bbox_pred, [-1, 5])
  511. odm_bbox = self.smooth_l1_loss(odm_bbox_pred, feat_bbox_targets)
  512. loss_weight = paddle.to_tensor(
  513. self.reg_loss_weight, dtype='float32', stop_gradient=True)
  514. odm_bbox = paddle.multiply(odm_bbox, loss_weight)
  515. feat_bbox_weights = paddle.to_tensor(
  516. feat_bbox_weights, stop_gradient=True)
  517. odm_bbox = odm_bbox * feat_bbox_weights
  518. odm_bbox_total = paddle.sum(odm_bbox) / num_total_samples
  519. odm_bbox_losses.append(odm_bbox_total)
  520. st_idx += feat_anchor_num
  521. odm_cls_loss = paddle.add_n(odm_cls_losses)
  522. odm_cls_loss_weight = paddle.to_tensor(
  523. self.cls_loss_weight[1], dtype='float32', stop_gradient=True)
  524. odm_cls_loss = odm_cls_loss * odm_cls_loss_weight
  525. odm_reg_loss = paddle.add_n(odm_bbox_losses)
  526. return odm_cls_loss, odm_reg_loss
  527. def get_loss(self, head_outs, inputs):
  528. fam_cls_list, fam_reg_list, odm_cls_list, odm_reg_list, \
  529. num_anchors_list, base_anchors_list, refine_anchors_list = head_outs
  530. # compute loss
  531. fam_cls_loss_lst = []
  532. fam_reg_loss_lst = []
  533. odm_cls_loss_lst = []
  534. odm_reg_loss_lst = []
  535. batch = len(inputs['gt_rbox'])
  536. for i in range(batch):
  537. # data_format: (xc, yc, w, h, theta)
  538. gt_mask = inputs['pad_gt_mask'][i, :, 0]
  539. gt_idx = paddle.nonzero(gt_mask).squeeze(-1)
  540. gt_bboxes = paddle.gather(inputs['gt_rbox'][i], gt_idx).numpy()
  541. gt_labels = paddle.gather(inputs['gt_class'][i], gt_idx).numpy()
  542. is_crowd = paddle.gather(inputs['is_crowd'][i], gt_idx).numpy()
  543. gt_labels = gt_labels + 1
  544. anchors_per_image = np.concatenate(base_anchors_list)
  545. fam_cls_per_image = [t[i] for t in fam_cls_list]
  546. fam_reg_per_image = [t[i] for t in fam_reg_list]
  547. odm_cls_per_image = [t[i] for t in odm_cls_list]
  548. odm_reg_per_image = [t[i] for t in odm_reg_list]
  549. im_s2anet_head_out = (fam_cls_per_image, fam_reg_per_image,
  550. odm_cls_per_image, odm_reg_per_image,
  551. num_anchors_list)
  552. # FAM
  553. im_fam_target = self.anchor_assign(anchors_per_image, gt_bboxes,
  554. gt_labels, is_crowd)
  555. if im_fam_target is not None:
  556. im_fam_cls_loss, im_fam_reg_loss = self.get_fam_loss(
  557. im_fam_target, im_s2anet_head_out, self.reg_loss_type)
  558. fam_cls_loss_lst.append(im_fam_cls_loss)
  559. fam_reg_loss_lst.append(im_fam_reg_loss)
  560. # ODM
  561. refine_anchors_per_image = [t[i] for t in refine_anchors_list]
  562. refine_anchors_per_image = paddle.concat(
  563. refine_anchors_per_image).numpy()
  564. im_odm_target = self.anchor_assign(refine_anchors_per_image,
  565. gt_bboxes, gt_labels, is_crowd)
  566. if im_odm_target is not None:
  567. im_odm_cls_loss, im_odm_reg_loss = self.get_odm_loss(
  568. im_odm_target, im_s2anet_head_out, self.reg_loss_type)
  569. odm_cls_loss_lst.append(im_odm_cls_loss)
  570. odm_reg_loss_lst.append(im_odm_reg_loss)
  571. fam_cls_loss = paddle.add_n(fam_cls_loss_lst) / batch
  572. fam_reg_loss = paddle.add_n(fam_reg_loss_lst) / batch
  573. odm_cls_loss = paddle.add_n(odm_cls_loss_lst) / batch
  574. odm_reg_loss = paddle.add_n(odm_reg_loss_lst) / batch
  575. loss = fam_cls_loss + fam_reg_loss + odm_cls_loss + odm_reg_loss
  576. return {
  577. 'loss': loss,
  578. 'fam_cls_loss': fam_cls_loss,
  579. 'fam_reg_loss': fam_reg_loss,
  580. 'odm_cls_loss': odm_cls_loss,
  581. 'odm_reg_loss': odm_reg_loss
  582. }
  583. def bbox_decode(self, preds, anchors, wh_ratio_clip=1e-6):
  584. """decode bbox from deltas
  585. Args:
  586. preds: [B, L, 5]
  587. anchors: [1, L, 5]
  588. return:
  589. bboxes: [B, L, 5]
  590. """
  591. preds = paddle.add(paddle.multiply(preds, self.stds), self.means)
  592. dx, dy, dw, dh, dangle = paddle.split(preds, 5, axis=-1)
  593. max_ratio = np.abs(np.log(wh_ratio_clip))
  594. dw = paddle.clip(dw, min=-max_ratio, max=max_ratio)
  595. dh = paddle.clip(dh, min=-max_ratio, max=max_ratio)
  596. rroi_x, rroi_y, rroi_w, rroi_h, rroi_angle = paddle.split(
  597. anchors, 5, axis=-1)
  598. gx = dx * rroi_w * paddle.cos(rroi_angle) - dy * rroi_h * paddle.sin(
  599. rroi_angle) + rroi_x
  600. gy = dx * rroi_w * paddle.sin(rroi_angle) + dy * rroi_h * paddle.cos(
  601. rroi_angle) + rroi_y
  602. gw = rroi_w * dw.exp()
  603. gh = rroi_h * dh.exp()
  604. ga = np.pi * dangle + rroi_angle
  605. ga = (ga + np.pi / 4) % np.pi - np.pi / 4
  606. bboxes = paddle.concat([gx, gy, gw, gh, ga], axis=-1)
  607. return bboxes
  608. def rbox2poly(self, rboxes):
  609. """
  610. rboxes: [x_ctr,y_ctr,w,h,angle]
  611. to
  612. polys: [x0,y0,x1,y1,x2,y2,x3,y3]
  613. """
  614. N = paddle.shape(rboxes)[0]
  615. x_ctr = rboxes[:, 0]
  616. y_ctr = rboxes[:, 1]
  617. width = rboxes[:, 2]
  618. height = rboxes[:, 3]
  619. angle = rboxes[:, 4]
  620. tl_x, tl_y, br_x, br_y = -width * 0.5, -height * 0.5, width * 0.5, height * 0.5
  621. normal_rects = paddle.stack(
  622. [tl_x, br_x, br_x, tl_x, tl_y, tl_y, br_y, br_y], axis=0)
  623. normal_rects = paddle.reshape(normal_rects, [2, 4, N])
  624. normal_rects = paddle.transpose(normal_rects, [2, 0, 1])
  625. sin, cos = paddle.sin(angle), paddle.cos(angle)
  626. # M: [N,2,2]
  627. M = paddle.stack([cos, -sin, sin, cos], axis=0)
  628. M = paddle.reshape(M, [2, 2, N])
  629. M = paddle.transpose(M, [2, 0, 1])
  630. # polys: [N,8]
  631. polys = paddle.matmul(M, normal_rects)
  632. polys = paddle.transpose(polys, [2, 1, 0])
  633. polys = paddle.reshape(polys, [-1, N])
  634. polys = paddle.transpose(polys, [1, 0])
  635. tmp = paddle.stack(
  636. [x_ctr, y_ctr, x_ctr, y_ctr, x_ctr, y_ctr, x_ctr, y_ctr], axis=1)
  637. polys = polys + tmp
  638. return polys