ppyoloe_r_head.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle
  15. import paddle.nn as nn
  16. import paddle.nn.functional as F
  17. from ppdet.core.workspace import register
  18. from ..losses import ProbIoULoss
  19. from ..initializer import bias_init_with_prob, constant_, normal_, vector_
  20. from ppdet.modeling.backbones.cspresnet import ConvBNLayer
  21. from ppdet.modeling.ops import get_static_shape, get_act_fn, anchor_generator
  22. from ppdet.modeling.layers import MultiClassNMS
  23. __all__ = ['PPYOLOERHead']
  24. class ESEAttn(nn.Layer):
  25. def __init__(self, feat_channels, act='swish'):
  26. super(ESEAttn, self).__init__()
  27. self.fc = nn.Conv2D(feat_channels, feat_channels, 1)
  28. self.conv = ConvBNLayer(feat_channels, feat_channels, 1, act=act)
  29. self._init_weights()
  30. def _init_weights(self):
  31. normal_(self.fc.weight, std=0.01)
  32. def forward(self, feat, avg_feat):
  33. weight = F.sigmoid(self.fc(avg_feat))
  34. return self.conv(feat * weight)
  35. @register
  36. class PPYOLOERHead(nn.Layer):
  37. __shared__ = ['num_classes', 'trt', 'export_onnx']
  38. __inject__ = ['static_assigner', 'assigner', 'nms']
  39. def __init__(self,
  40. in_channels=[1024, 512, 256],
  41. num_classes=15,
  42. act='swish',
  43. fpn_strides=(32, 16, 8),
  44. grid_cell_offset=0.5,
  45. angle_max=90,
  46. use_varifocal_loss=True,
  47. static_assigner_epoch=4,
  48. trt=False,
  49. export_onnx=False,
  50. static_assigner='ATSSAssigner',
  51. assigner='TaskAlignedAssigner',
  52. nms='MultiClassNMS',
  53. loss_weight={'class': 1.0,
  54. 'iou': 2.5,
  55. 'dfl': 0.05}):
  56. super(PPYOLOERHead, self).__init__()
  57. assert len(in_channels) > 0, "len(in_channels) should > 0"
  58. self.in_channels = in_channels
  59. self.num_classes = num_classes
  60. self.fpn_strides = fpn_strides
  61. self.grid_cell_offset = grid_cell_offset
  62. self.angle_max = angle_max
  63. self.loss_weight = loss_weight
  64. self.use_varifocal_loss = use_varifocal_loss
  65. self.half_pi = paddle.to_tensor(
  66. [1.5707963267948966], dtype=paddle.float32)
  67. self.half_pi_bin = self.half_pi / angle_max
  68. self.iou_loss = ProbIoULoss()
  69. self.static_assigner_epoch = static_assigner_epoch
  70. self.static_assigner = static_assigner
  71. self.assigner = assigner
  72. self.nms = nms
  73. # stem
  74. self.stem_cls = nn.LayerList()
  75. self.stem_reg = nn.LayerList()
  76. self.stem_angle = nn.LayerList()
  77. trt = False if export_onnx else trt
  78. self.export_onnx = export_onnx
  79. act = get_act_fn(
  80. act, trt=trt) if act is None or isinstance(act,
  81. (str, dict)) else act
  82. self.trt = trt
  83. for in_c in self.in_channels:
  84. self.stem_cls.append(ESEAttn(in_c, act=act))
  85. self.stem_reg.append(ESEAttn(in_c, act=act))
  86. self.stem_angle.append(ESEAttn(in_c, act=act))
  87. # pred head
  88. self.pred_cls = nn.LayerList()
  89. self.pred_reg = nn.LayerList()
  90. self.pred_angle = nn.LayerList()
  91. for in_c in self.in_channels:
  92. self.pred_cls.append(
  93. nn.Conv2D(
  94. in_c, self.num_classes, 3, padding=1))
  95. self.pred_reg.append(nn.Conv2D(in_c, 4, 3, padding=1))
  96. self.pred_angle.append(
  97. nn.Conv2D(
  98. in_c, self.angle_max + 1, 3, padding=1))
  99. self.angle_proj_conv = nn.Conv2D(
  100. self.angle_max + 1, 1, 1, bias_attr=False)
  101. self._init_weights()
  102. @classmethod
  103. def from_config(cls, cfg, input_shape):
  104. return {'in_channels': [i.channels for i in input_shape], }
  105. def _init_weights(self):
  106. bias_cls = bias_init_with_prob(0.01)
  107. bias_angle = [10.] + [1.] * self.angle_max
  108. for cls_, reg_, angle_ in zip(self.pred_cls, self.pred_reg,
  109. self.pred_angle):
  110. normal_(cls_.weight, std=0.01)
  111. constant_(cls_.bias, bias_cls)
  112. normal_(reg_.weight, std=0.01)
  113. constant_(reg_.bias)
  114. constant_(angle_.weight)
  115. vector_(angle_.bias, bias_angle)
  116. angle_proj = paddle.linspace(0, self.angle_max, self.angle_max + 1)
  117. self.angle_proj = angle_proj * self.half_pi_bin
  118. self.angle_proj_conv.weight.set_value(
  119. self.angle_proj.reshape([1, self.angle_max + 1, 1, 1]))
  120. self.angle_proj_conv.weight.stop_gradient = True
  121. def _generate_anchors(self, feats):
  122. if self.trt:
  123. anchor_points = []
  124. for feat, stride in zip(feats, self.fpn_strides):
  125. _, _, h, w = paddle.shape(feat)
  126. anchor, _ = anchor_generator(
  127. feat,
  128. stride * 4,
  129. 1.0, [1.0, 1.0, 1.0, 1.0], [stride, stride],
  130. offset=0.5)
  131. x1, y1, x2, y2 = paddle.split(anchor, 4, axis=-1)
  132. xc = (x1 + x2 + 1) / 2
  133. yc = (y1 + y2 + 1) / 2
  134. anchor_point = paddle.concat(
  135. [xc, yc], axis=-1).reshape((1, h * w, 2))
  136. anchor_points.append(anchor_point)
  137. anchor_points = paddle.concat(anchor_points, axis=1)
  138. return anchor_points, None, None
  139. else:
  140. anchor_points = []
  141. stride_tensor = []
  142. num_anchors_list = []
  143. for feat, stride in zip(feats, self.fpn_strides):
  144. _, _, h, w = paddle.shape(feat)
  145. shift_x = (paddle.arange(end=w) + 0.5) * stride
  146. shift_y = (paddle.arange(end=h) + 0.5) * stride
  147. shift_y, shift_x = paddle.meshgrid(shift_y, shift_x)
  148. anchor_point = paddle.cast(
  149. paddle.stack(
  150. [shift_x, shift_y], axis=-1), dtype='float32')
  151. anchor_points.append(anchor_point.reshape([1, -1, 2]))
  152. stride_tensor.append(
  153. paddle.full(
  154. [1, h * w, 1], stride, dtype='float32'))
  155. num_anchors_list.append(h * w)
  156. anchor_points = paddle.concat(anchor_points, axis=1)
  157. stride_tensor = paddle.concat(stride_tensor, axis=1)
  158. return anchor_points, stride_tensor, num_anchors_list
  159. def forward(self, feats, targets=None):
  160. assert len(feats) == len(self.fpn_strides), \
  161. "The size of feats is not equal to size of fpn_strides"
  162. if self.training:
  163. return self.forward_train(feats, targets)
  164. else:
  165. return self.forward_eval(feats)
  166. def forward_train(self, feats, targets):
  167. anchor_points, stride_tensor, num_anchors_list = self._generate_anchors(
  168. feats)
  169. cls_score_list, reg_dist_list, reg_angle_list = [], [], []
  170. for i, feat in enumerate(feats):
  171. avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
  172. cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
  173. feat)
  174. reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
  175. reg_angle = self.pred_angle[i](self.stem_angle[i](feat, avg_feat))
  176. # cls and reg
  177. cls_score = F.sigmoid(cls_logit)
  178. cls_score_list.append(cls_score.flatten(2).transpose([0, 2, 1]))
  179. reg_dist_list.append(reg_dist.flatten(2).transpose([0, 2, 1]))
  180. reg_angle_list.append(reg_angle.flatten(2).transpose([0, 2, 1]))
  181. cls_score_list = paddle.concat(cls_score_list, axis=1)
  182. reg_dist_list = paddle.concat(reg_dist_list, axis=1)
  183. reg_angle_list = paddle.concat(reg_angle_list, axis=1)
  184. return self.get_loss([
  185. cls_score_list, reg_dist_list, reg_angle_list, anchor_points,
  186. num_anchors_list, stride_tensor
  187. ], targets)
  188. def forward_eval(self, feats):
  189. cls_score_list, reg_box_list = [], []
  190. anchor_points, _, _ = self._generate_anchors(feats)
  191. for i, (feat, stride) in enumerate(zip(feats, self.fpn_strides)):
  192. b, _, h, w = paddle.shape(feat)
  193. l = h * w
  194. # cls
  195. avg_feat = F.adaptive_avg_pool2d(feat, (1, 1))
  196. cls_logit = self.pred_cls[i](self.stem_cls[i](feat, avg_feat) +
  197. feat)
  198. # reg
  199. reg_dist = self.pred_reg[i](self.stem_reg[i](feat, avg_feat))
  200. reg_xy, reg_wh = paddle.split(reg_dist, 2, axis=1)
  201. reg_xy = reg_xy * stride
  202. reg_wh = (F.elu(reg_wh) + 1.) * stride
  203. reg_angle = self.pred_angle[i](self.stem_angle[i](feat, avg_feat))
  204. reg_angle = self.angle_proj_conv(F.softmax(reg_angle, axis=1))
  205. reg_box = paddle.concat([reg_xy, reg_wh, reg_angle], axis=1)
  206. # cls and reg
  207. cls_score = F.sigmoid(cls_logit)
  208. cls_score_list.append(cls_score.reshape([b, self.num_classes, l]))
  209. reg_box_list.append(reg_box.reshape([b, 5, l]))
  210. cls_score_list = paddle.concat(cls_score_list, axis=-1)
  211. reg_box_list = paddle.concat(reg_box_list, axis=-1).transpose([0, 2, 1])
  212. reg_xy, reg_wha = paddle.split(reg_box_list, [2, 3], axis=-1)
  213. reg_xy = reg_xy + anchor_points
  214. reg_box_list = paddle.concat([reg_xy, reg_wha], axis=-1)
  215. return cls_score_list, reg_box_list
  216. def _bbox_decode(self, points, pred_dist, pred_angle, stride_tensor):
  217. # predict vector to x, y, w, h, angle
  218. b, l = pred_angle.shape[:2]
  219. xy, wh = paddle.split(pred_dist, 2, axis=-1)
  220. xy = xy * stride_tensor + points
  221. wh = (F.elu(wh) + 1.) * stride_tensor
  222. angle = F.softmax(pred_angle.reshape([b, l, 1, self.angle_max + 1
  223. ])).matmul(self.angle_proj)
  224. return paddle.concat([xy, wh, angle], axis=-1)
  225. def get_loss(self, head_outs, gt_meta):
  226. pred_scores, pred_dist, pred_angle, \
  227. anchor_points, num_anchors_list, stride_tensor = head_outs
  228. # [B, N, 5] -> [B, N, 5]
  229. pred_bboxes = self._bbox_decode(anchor_points, pred_dist, pred_angle,
  230. stride_tensor)
  231. gt_labels = gt_meta['gt_class']
  232. # [B, N, 5]
  233. gt_bboxes = gt_meta['gt_rbox']
  234. pad_gt_mask = gt_meta['pad_gt_mask']
  235. # label assignment
  236. if gt_meta['epoch_id'] < self.static_assigner_epoch:
  237. assigned_labels, assigned_bboxes, assigned_scores = \
  238. self.static_assigner(
  239. anchor_points,
  240. stride_tensor,
  241. num_anchors_list,
  242. gt_labels,
  243. gt_meta['gt_bbox'],
  244. gt_bboxes,
  245. pad_gt_mask,
  246. self.num_classes,
  247. pred_bboxes.detach()
  248. )
  249. else:
  250. assigned_labels, assigned_bboxes, assigned_scores = \
  251. self.assigner(
  252. pred_scores.detach(),
  253. pred_bboxes.detach(),
  254. anchor_points,
  255. num_anchors_list,
  256. gt_labels,
  257. gt_bboxes,
  258. pad_gt_mask,
  259. bg_index=self.num_classes)
  260. alpha_l = -1
  261. # cls loss
  262. if self.use_varifocal_loss:
  263. one_hot_label = F.one_hot(assigned_labels,
  264. self.num_classes + 1)[..., :-1]
  265. loss_cls = self._varifocal_loss(pred_scores, assigned_scores,
  266. one_hot_label)
  267. else:
  268. loss_cls = self._focal_loss(pred_scores, assigned_scores, alpha_l)
  269. assigned_scores_sum = assigned_scores.sum()
  270. if paddle.distributed.get_world_size() > 1:
  271. paddle.distributed.all_reduce(assigned_scores_sum)
  272. assigned_scores_sum = paddle.clip(
  273. assigned_scores_sum / paddle.distributed.get_world_size(),
  274. min=1.)
  275. else:
  276. assigned_scores_sum = paddle.clip(assigned_scores_sum, min=1.)
  277. loss_cls /= assigned_scores_sum
  278. loss_iou, loss_dfl = self._bbox_loss(pred_angle, pred_bboxes,
  279. anchor_points, assigned_labels,
  280. assigned_bboxes, assigned_scores,
  281. assigned_scores_sum, stride_tensor)
  282. loss = self.loss_weight['class'] * loss_cls + \
  283. self.loss_weight['iou'] * loss_iou + \
  284. self.loss_weight['dfl'] * loss_dfl
  285. out_dict = {
  286. 'loss': loss,
  287. 'loss_cls': loss_cls,
  288. 'loss_iou': loss_iou,
  289. 'loss_dfl': loss_dfl
  290. }
  291. return out_dict
  292. @staticmethod
  293. def _focal_loss(score, label, alpha=0.25, gamma=2.0):
  294. weight = (score - label).pow(gamma)
  295. if alpha > 0:
  296. alpha_t = alpha * label + (1 - alpha) * (1 - label)
  297. weight *= alpha_t
  298. loss = F.binary_cross_entropy(
  299. score, label, weight=weight, reduction='sum')
  300. return loss
  301. @staticmethod
  302. def _varifocal_loss(pred_score, gt_score, label, alpha=0.75, gamma=2.0):
  303. weight = alpha * pred_score.pow(gamma) * (1 - label) + gt_score * label
  304. loss = F.binary_cross_entropy(
  305. pred_score, gt_score, weight=weight, reduction='sum')
  306. return loss
  307. @staticmethod
  308. def _df_loss(pred_dist, target):
  309. target_left = paddle.cast(target, 'int64')
  310. target_right = target_left + 1
  311. weight_left = target_right.astype('float32') - target
  312. weight_right = 1 - weight_left
  313. loss_left = F.cross_entropy(
  314. pred_dist, target_left, reduction='none') * weight_left
  315. loss_right = F.cross_entropy(
  316. pred_dist, target_right, reduction='none') * weight_right
  317. return (loss_left + loss_right).mean(-1, keepdim=True)
  318. def _bbox_loss(self, pred_angle, pred_bboxes, anchor_points,
  319. assigned_labels, assigned_bboxes, assigned_scores,
  320. assigned_scores_sum, stride_tensor):
  321. # select positive samples mask
  322. mask_positive = (assigned_labels != self.num_classes)
  323. num_pos = mask_positive.sum()
  324. # pos/neg loss
  325. if num_pos > 0:
  326. # iou
  327. bbox_mask = mask_positive.unsqueeze(-1).tile([1, 1, 5])
  328. pred_bboxes_pos = paddle.masked_select(pred_bboxes,
  329. bbox_mask).reshape([-1, 5])
  330. assigned_bboxes_pos = paddle.masked_select(
  331. assigned_bboxes, bbox_mask).reshape([-1, 5])
  332. bbox_weight = paddle.masked_select(
  333. assigned_scores.sum(-1), mask_positive).reshape([-1])
  334. loss_iou = self.iou_loss(pred_bboxes_pos,
  335. assigned_bboxes_pos) * bbox_weight
  336. loss_iou = loss_iou.sum() / assigned_scores_sum
  337. # dfl
  338. angle_mask = mask_positive.unsqueeze(-1).tile(
  339. [1, 1, self.angle_max + 1])
  340. pred_angle_pos = paddle.masked_select(
  341. pred_angle, angle_mask).reshape([-1, self.angle_max + 1])
  342. assigned_angle_pos = (
  343. assigned_bboxes_pos[:, 4] /
  344. self.half_pi_bin).clip(0, self.angle_max - 0.01)
  345. loss_dfl = self._df_loss(pred_angle_pos, assigned_angle_pos)
  346. else:
  347. loss_iou = pred_bboxes.sum() * 0.
  348. loss_dfl = paddle.zeros([1])
  349. return loss_iou, loss_dfl
  350. def _box2corners(self, pred_bboxes):
  351. """ convert (x, y, w, h, angle) to (x1, y1, x2, y2, x3, y3, x4, y4)
  352. Args:
  353. pred_bboxes (Tensor): [B, N, 5]
  354. Returns:
  355. polys (Tensor): [B, N, 8]
  356. """
  357. x, y, w, h, angle = paddle.split(pred_bboxes, 5, axis=-1)
  358. cos_a_half = paddle.cos(angle) * 0.5
  359. sin_a_half = paddle.sin(angle) * 0.5
  360. w_x = cos_a_half * w
  361. w_y = sin_a_half * w
  362. h_x = -sin_a_half * h
  363. h_y = cos_a_half * h
  364. return paddle.concat(
  365. [
  366. x + w_x + h_x, y + w_y + h_y, x - w_x + h_x, y - w_y + h_y,
  367. x - w_x - h_x, y - w_y - h_y, x + w_x - h_x, y + w_y - h_y
  368. ],
  369. axis=-1)
  370. def post_process(self, head_outs, scale_factor):
  371. pred_scores, pred_bboxes = head_outs
  372. # [B, N, 5] -> [B, N, 8]
  373. pred_bboxes = self._box2corners(pred_bboxes)
  374. # scale bbox to origin
  375. scale_y, scale_x = paddle.split(scale_factor, 2, axis=-1)
  376. scale_factor = paddle.concat(
  377. [
  378. scale_x, scale_y, scale_x, scale_y, scale_x, scale_y, scale_x,
  379. scale_y
  380. ],
  381. axis=-1).reshape([-1, 1, 8])
  382. pred_bboxes /= scale_factor
  383. if self.export_onnx:
  384. return pred_bboxes, pred_scores
  385. bbox_pred, bbox_num, _ = self.nms(pred_bboxes, pred_scores)
  386. return bbox_pred, bbox_num