centernet_head.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import math
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddle.nn.initializer import Constant, Uniform
  19. from ppdet.core.workspace import register
  20. from ppdet.modeling.losses import CTFocalLoss, GIoULoss
  21. class ConvLayer(nn.Layer):
  22. def __init__(self,
  23. ch_in,
  24. ch_out,
  25. kernel_size,
  26. stride=1,
  27. padding=0,
  28. dilation=1,
  29. groups=1,
  30. bias=False):
  31. super(ConvLayer, self).__init__()
  32. bias_attr = False
  33. fan_in = ch_in * kernel_size**2
  34. bound = 1 / math.sqrt(fan_in)
  35. param_attr = paddle.ParamAttr(initializer=Uniform(-bound, bound))
  36. if bias:
  37. bias_attr = paddle.ParamAttr(initializer=Constant(0.))
  38. self.conv = nn.Conv2D(
  39. in_channels=ch_in,
  40. out_channels=ch_out,
  41. kernel_size=kernel_size,
  42. stride=stride,
  43. padding=padding,
  44. dilation=dilation,
  45. groups=groups,
  46. weight_attr=param_attr,
  47. bias_attr=bias_attr)
  48. def forward(self, inputs):
  49. out = self.conv(inputs)
  50. return out
  51. @register
  52. class CenterNetHead(nn.Layer):
  53. """
  54. Args:
  55. in_channels (int): the channel number of input to CenterNetHead.
  56. num_classes (int): the number of classes, 80 (COCO dataset) by default.
  57. head_planes (int): the channel number in all head, 256 by default.
  58. prior_bias (float): prior bias in heatmap head, -2.19 by default, -4.6 in CenterTrack
  59. regress_ltrb (bool): whether to regress left/top/right/bottom or
  60. width/height for a box, True by default.
  61. size_loss (str): the type of size regression loss, 'L1' by default, can be 'giou'.
  62. loss_weight (dict): the weight of each loss.
  63. add_iou (bool): whether to add iou branch, False by default.
  64. """
  65. __shared__ = ['num_classes']
  66. def __init__(self,
  67. in_channels,
  68. num_classes=80,
  69. head_planes=256,
  70. prior_bias=-2.19,
  71. regress_ltrb=True,
  72. size_loss='L1',
  73. loss_weight={
  74. 'heatmap': 1.0,
  75. 'size': 0.1,
  76. 'offset': 1.0,
  77. 'iou': 0.0,
  78. },
  79. add_iou=False):
  80. super(CenterNetHead, self).__init__()
  81. self.regress_ltrb = regress_ltrb
  82. self.loss_weight = loss_weight
  83. self.add_iou = add_iou
  84. # heatmap head
  85. self.heatmap = nn.Sequential(
  86. ConvLayer(
  87. in_channels, head_planes, kernel_size=3, padding=1, bias=True),
  88. nn.ReLU(),
  89. ConvLayer(
  90. head_planes,
  91. num_classes,
  92. kernel_size=1,
  93. stride=1,
  94. padding=0,
  95. bias=True))
  96. with paddle.no_grad():
  97. self.heatmap[2].conv.bias[:] = prior_bias
  98. # size(ltrb or wh) head
  99. self.size = nn.Sequential(
  100. ConvLayer(
  101. in_channels, head_planes, kernel_size=3, padding=1, bias=True),
  102. nn.ReLU(),
  103. ConvLayer(
  104. head_planes,
  105. 4 if regress_ltrb else 2,
  106. kernel_size=1,
  107. stride=1,
  108. padding=0,
  109. bias=True))
  110. self.size_loss = size_loss
  111. # offset head
  112. self.offset = nn.Sequential(
  113. ConvLayer(
  114. in_channels, head_planes, kernel_size=3, padding=1, bias=True),
  115. nn.ReLU(),
  116. ConvLayer(
  117. head_planes, 2, kernel_size=1, stride=1, padding=0, bias=True))
  118. # iou head (optinal)
  119. if self.add_iou and 'iou' in self.loss_weight:
  120. self.iou = nn.Sequential(
  121. ConvLayer(
  122. in_channels,
  123. head_planes,
  124. kernel_size=3,
  125. padding=1,
  126. bias=True),
  127. nn.ReLU(),
  128. ConvLayer(
  129. head_planes,
  130. 4 if regress_ltrb else 2,
  131. kernel_size=1,
  132. stride=1,
  133. padding=0,
  134. bias=True))
  135. @classmethod
  136. def from_config(cls, cfg, input_shape):
  137. if isinstance(input_shape, (list, tuple)):
  138. input_shape = input_shape[0]
  139. return {'in_channels': input_shape.channels}
  140. def forward(self, feat, inputs):
  141. heatmap = F.sigmoid(self.heatmap(feat))
  142. size = self.size(feat)
  143. offset = self.offset(feat)
  144. head_outs = {'heatmap': heatmap, 'size': size, 'offset': offset}
  145. if self.add_iou and 'iou' in self.loss_weight:
  146. iou = self.iou(feat)
  147. head_outs.update({'iou': iou})
  148. if self.training:
  149. losses = self.get_loss(inputs, self.loss_weight, head_outs)
  150. return losses
  151. else:
  152. return head_outs
  153. def get_loss(self, inputs, weights, head_outs):
  154. # 1.heatmap(hm) head loss: CTFocalLoss
  155. heatmap = head_outs['heatmap']
  156. heatmap_target = inputs['heatmap']
  157. heatmap = paddle.clip(heatmap, 1e-4, 1 - 1e-4)
  158. ctfocal_loss = CTFocalLoss()
  159. heatmap_loss = ctfocal_loss(heatmap, heatmap_target)
  160. # 2.size(wh) head loss: L1 loss or GIoU loss
  161. size = head_outs['size']
  162. index = inputs['index']
  163. mask = inputs['index_mask']
  164. size = paddle.transpose(size, perm=[0, 2, 3, 1])
  165. size_n, _, _, size_c = size.shape
  166. size = paddle.reshape(size, shape=[size_n, -1, size_c])
  167. index = paddle.unsqueeze(index, 2)
  168. batch_inds = list()
  169. for i in range(size_n):
  170. batch_ind = paddle.full(
  171. shape=[1, index.shape[1], 1], fill_value=i, dtype='int64')
  172. batch_inds.append(batch_ind)
  173. batch_inds = paddle.concat(batch_inds, axis=0)
  174. index = paddle.concat(x=[batch_inds, index], axis=2)
  175. pos_size = paddle.gather_nd(size, index=index)
  176. mask = paddle.unsqueeze(mask, axis=2)
  177. size_mask = paddle.expand_as(mask, pos_size)
  178. size_mask = paddle.cast(size_mask, dtype=pos_size.dtype)
  179. pos_num = size_mask.sum()
  180. size_mask.stop_gradient = True
  181. if self.size_loss == 'L1':
  182. if self.regress_ltrb:
  183. size_target = inputs['size']
  184. # shape: [bs, max_per_img, 4]
  185. else:
  186. if inputs['size'].shape[-1] == 2:
  187. # inputs['size'] is wh, and regress as wh
  188. # shape: [bs, max_per_img, 2]
  189. size_target = inputs['size']
  190. else:
  191. # inputs['size'] is ltrb, but regress as wh
  192. # shape: [bs, max_per_img, 4]
  193. size_target = inputs['size'][:, :, 0:2] + inputs[
  194. 'size'][:, :, 2:]
  195. size_target.stop_gradient = True
  196. size_loss = F.l1_loss(
  197. pos_size * size_mask, size_target * size_mask, reduction='sum')
  198. size_loss = size_loss / (pos_num + 1e-4)
  199. elif self.size_loss == 'giou':
  200. size_target = inputs['bbox_xys']
  201. size_target.stop_gradient = True
  202. centers_x = (size_target[:, :, 0:1] + size_target[:, :, 2:3]) / 2.0
  203. centers_y = (size_target[:, :, 1:2] + size_target[:, :, 3:4]) / 2.0
  204. x1 = centers_x - pos_size[:, :, 0:1]
  205. y1 = centers_y - pos_size[:, :, 1:2]
  206. x2 = centers_x + pos_size[:, :, 2:3]
  207. y2 = centers_y + pos_size[:, :, 3:4]
  208. pred_boxes = paddle.concat([x1, y1, x2, y2], axis=-1)
  209. giou_loss = GIoULoss(reduction='sum')
  210. size_loss = giou_loss(
  211. pred_boxes * size_mask,
  212. size_target * size_mask,
  213. iou_weight=size_mask,
  214. loc_reweight=None)
  215. size_loss = size_loss / (pos_num + 1e-4)
  216. # 3.offset(reg) head loss: L1 loss
  217. offset = head_outs['offset']
  218. offset_target = inputs['offset']
  219. offset = paddle.transpose(offset, perm=[0, 2, 3, 1])
  220. offset_n, _, _, offset_c = offset.shape
  221. offset = paddle.reshape(offset, shape=[offset_n, -1, offset_c])
  222. pos_offset = paddle.gather_nd(offset, index=index)
  223. offset_mask = paddle.expand_as(mask, pos_offset)
  224. offset_mask = paddle.cast(offset_mask, dtype=pos_offset.dtype)
  225. pos_num = offset_mask.sum()
  226. offset_mask.stop_gradient = True
  227. offset_target.stop_gradient = True
  228. offset_loss = F.l1_loss(
  229. pos_offset * offset_mask,
  230. offset_target * offset_mask,
  231. reduction='sum')
  232. offset_loss = offset_loss / (pos_num + 1e-4)
  233. # 4.iou head loss: GIoU loss (optinal)
  234. if self.add_iou and 'iou' in self.loss_weight:
  235. iou = head_outs['iou']
  236. iou = paddle.transpose(iou, perm=[0, 2, 3, 1])
  237. iou_n, _, _, iou_c = iou.shape
  238. iou = paddle.reshape(iou, shape=[iou_n, -1, iou_c])
  239. pos_iou = paddle.gather_nd(iou, index=index)
  240. iou_mask = paddle.expand_as(mask, pos_iou)
  241. iou_mask = paddle.cast(iou_mask, dtype=pos_iou.dtype)
  242. pos_num = iou_mask.sum()
  243. iou_mask.stop_gradient = True
  244. gt_bbox_xys = inputs['bbox_xys']
  245. gt_bbox_xys.stop_gradient = True
  246. centers_x = (gt_bbox_xys[:, :, 0:1] + gt_bbox_xys[:, :, 2:3]) / 2.0
  247. centers_y = (gt_bbox_xys[:, :, 1:2] + gt_bbox_xys[:, :, 3:4]) / 2.0
  248. x1 = centers_x - pos_size[:, :, 0:1]
  249. y1 = centers_y - pos_size[:, :, 1:2]
  250. x2 = centers_x + pos_size[:, :, 2:3]
  251. y2 = centers_y + pos_size[:, :, 3:4]
  252. pred_boxes = paddle.concat([x1, y1, x2, y2], axis=-1)
  253. giou_loss = GIoULoss(reduction='sum')
  254. iou_loss = giou_loss(
  255. pred_boxes * iou_mask,
  256. gt_bbox_xys * iou_mask,
  257. iou_weight=iou_mask,
  258. loc_reweight=None)
  259. iou_loss = iou_loss / (pos_num + 1e-4)
  260. losses = {
  261. 'heatmap_loss': heatmap_loss,
  262. 'size_loss': size_loss,
  263. 'offset_loss': offset_loss,
  264. }
  265. det_loss = weights['heatmap'] * heatmap_loss + weights[
  266. 'size'] * size_loss + weights['offset'] * offset_loss
  267. if self.add_iou and 'iou' in self.loss_weight:
  268. losses.update({'iou_loss': iou_loss})
  269. det_loss += weights['iou'] * iou_loss
  270. losses.update({'det_loss': det_loss})
  271. return losses