detr_head.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import paddle
  18. import paddle.nn as nn
  19. import paddle.nn.functional as F
  20. from ppdet.core.workspace import register
  21. import pycocotools.mask as mask_util
  22. from ..initializer import linear_init_, constant_
  23. from ..transformers.utils import inverse_sigmoid
  24. __all__ = ['DETRHead', 'DeformableDETRHead', 'DINOHead']
  25. class MLP(nn.Layer):
  26. """This code is based on
  27. https://github.com/facebookresearch/detr/blob/main/models/detr.py
  28. """
  29. def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
  30. super().__init__()
  31. self.num_layers = num_layers
  32. h = [hidden_dim] * (num_layers - 1)
  33. self.layers = nn.LayerList(
  34. nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
  35. self._reset_parameters()
  36. def _reset_parameters(self):
  37. for l in self.layers:
  38. linear_init_(l)
  39. def forward(self, x):
  40. for i, layer in enumerate(self.layers):
  41. x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
  42. return x
  43. class MultiHeadAttentionMap(nn.Layer):
  44. """This code is based on
  45. https://github.com/facebookresearch/detr/blob/main/models/segmentation.py
  46. This is a 2D attention module, which only returns the attention softmax (no multiplication by value)
  47. """
  48. def __init__(self, query_dim, hidden_dim, num_heads, dropout=0.0,
  49. bias=True):
  50. super().__init__()
  51. self.num_heads = num_heads
  52. self.hidden_dim = hidden_dim
  53. self.dropout = nn.Dropout(dropout)
  54. weight_attr = paddle.ParamAttr(
  55. initializer=paddle.nn.initializer.XavierUniform())
  56. bias_attr = paddle.framework.ParamAttr(
  57. initializer=paddle.nn.initializer.Constant()) if bias else False
  58. self.q_proj = nn.Linear(query_dim, hidden_dim, weight_attr, bias_attr)
  59. self.k_proj = nn.Conv2D(
  60. query_dim,
  61. hidden_dim,
  62. 1,
  63. weight_attr=weight_attr,
  64. bias_attr=bias_attr)
  65. self.normalize_fact = float(hidden_dim / self.num_heads)**-0.5
  66. def forward(self, q, k, mask=None):
  67. q = self.q_proj(q)
  68. k = self.k_proj(k)
  69. bs, num_queries, n, c, h, w = q.shape[0], q.shape[1], self.num_heads,\
  70. self.hidden_dim // self.num_heads, k.shape[-2], k.shape[-1]
  71. qh = q.reshape([bs, num_queries, n, c])
  72. kh = k.reshape([bs, n, c, h, w])
  73. # weights = paddle.einsum("bqnc,bnchw->bqnhw", qh * self.normalize_fact, kh)
  74. qh = qh.transpose([0, 2, 1, 3]).reshape([-1, num_queries, c])
  75. kh = kh.reshape([-1, c, h * w])
  76. weights = paddle.bmm(qh * self.normalize_fact, kh).reshape(
  77. [bs, n, num_queries, h, w]).transpose([0, 2, 1, 3, 4])
  78. if mask is not None:
  79. weights += mask
  80. # fix a potenial bug: https://github.com/facebookresearch/detr/issues/247
  81. weights = F.softmax(weights.flatten(3), axis=-1).reshape(weights.shape)
  82. weights = self.dropout(weights)
  83. return weights
  84. class MaskHeadFPNConv(nn.Layer):
  85. """This code is based on
  86. https://github.com/facebookresearch/detr/blob/main/models/segmentation.py
  87. Simple convolutional head, using group norm.
  88. Upsampling is done using a FPN approach
  89. """
  90. def __init__(self, input_dim, fpn_dims, context_dim, num_groups=8):
  91. super().__init__()
  92. inter_dims = [input_dim,
  93. ] + [context_dim // (2**i) for i in range(1, 5)]
  94. weight_attr = paddle.ParamAttr(
  95. initializer=paddle.nn.initializer.KaimingUniform())
  96. bias_attr = paddle.framework.ParamAttr(
  97. initializer=paddle.nn.initializer.Constant())
  98. self.conv0 = self._make_layers(input_dim, input_dim, 3, num_groups,
  99. weight_attr, bias_attr)
  100. self.conv_inter = nn.LayerList()
  101. for in_dims, out_dims in zip(inter_dims[:-1], inter_dims[1:]):
  102. self.conv_inter.append(
  103. self._make_layers(in_dims, out_dims, 3, num_groups, weight_attr,
  104. bias_attr))
  105. self.conv_out = nn.Conv2D(
  106. inter_dims[-1],
  107. 1,
  108. 3,
  109. padding=1,
  110. weight_attr=weight_attr,
  111. bias_attr=bias_attr)
  112. self.adapter = nn.LayerList()
  113. for i in range(len(fpn_dims)):
  114. self.adapter.append(
  115. nn.Conv2D(
  116. fpn_dims[i],
  117. inter_dims[i + 1],
  118. 1,
  119. weight_attr=weight_attr,
  120. bias_attr=bias_attr))
  121. def _make_layers(self,
  122. in_dims,
  123. out_dims,
  124. kernel_size,
  125. num_groups,
  126. weight_attr=None,
  127. bias_attr=None):
  128. return nn.Sequential(
  129. nn.Conv2D(
  130. in_dims,
  131. out_dims,
  132. kernel_size,
  133. padding=kernel_size // 2,
  134. weight_attr=weight_attr,
  135. bias_attr=bias_attr),
  136. nn.GroupNorm(num_groups, out_dims),
  137. nn.ReLU())
  138. def forward(self, x, bbox_attention_map, fpns):
  139. x = paddle.concat([
  140. x.tile([bbox_attention_map.shape[1], 1, 1, 1]),
  141. bbox_attention_map.flatten(0, 1)
  142. ], 1)
  143. x = self.conv0(x)
  144. for inter_layer, adapter_layer, feat in zip(self.conv_inter[:-1],
  145. self.adapter, fpns):
  146. feat = adapter_layer(feat).tile(
  147. [bbox_attention_map.shape[1], 1, 1, 1])
  148. x = inter_layer(x)
  149. x = feat + F.interpolate(x, size=feat.shape[-2:])
  150. x = self.conv_inter[-1](x)
  151. x = self.conv_out(x)
  152. return x
  153. @register
  154. class DETRHead(nn.Layer):
  155. __shared__ = ['num_classes', 'hidden_dim', 'use_focal_loss']
  156. __inject__ = ['loss']
  157. def __init__(self,
  158. num_classes=80,
  159. hidden_dim=256,
  160. nhead=8,
  161. num_mlp_layers=3,
  162. loss='DETRLoss',
  163. fpn_dims=[1024, 512, 256],
  164. with_mask_head=False,
  165. use_focal_loss=False):
  166. super(DETRHead, self).__init__()
  167. # add background class
  168. self.num_classes = num_classes if use_focal_loss else num_classes + 1
  169. self.hidden_dim = hidden_dim
  170. self.loss = loss
  171. self.with_mask_head = with_mask_head
  172. self.use_focal_loss = use_focal_loss
  173. self.score_head = nn.Linear(hidden_dim, self.num_classes)
  174. self.bbox_head = MLP(hidden_dim,
  175. hidden_dim,
  176. output_dim=4,
  177. num_layers=num_mlp_layers)
  178. if self.with_mask_head:
  179. self.bbox_attention = MultiHeadAttentionMap(hidden_dim, hidden_dim,
  180. nhead)
  181. self.mask_head = MaskHeadFPNConv(hidden_dim + nhead, fpn_dims,
  182. hidden_dim)
  183. self._reset_parameters()
  184. def _reset_parameters(self):
  185. linear_init_(self.score_head)
  186. @classmethod
  187. def from_config(cls, cfg, hidden_dim, nhead, input_shape):
  188. return {
  189. 'hidden_dim': hidden_dim,
  190. 'nhead': nhead,
  191. 'fpn_dims': [i.channels for i in input_shape[::-1]][1:]
  192. }
  193. @staticmethod
  194. def get_gt_mask_from_polygons(gt_poly, pad_mask):
  195. out_gt_mask = []
  196. for polygons, padding in zip(gt_poly, pad_mask):
  197. height, width = int(padding[:, 0].sum()), int(padding[0, :].sum())
  198. masks = []
  199. for obj_poly in polygons:
  200. rles = mask_util.frPyObjects(obj_poly, height, width)
  201. rle = mask_util.merge(rles)
  202. masks.append(
  203. paddle.to_tensor(mask_util.decode(rle)).astype('float32'))
  204. masks = paddle.stack(masks)
  205. masks_pad = paddle.zeros(
  206. [masks.shape[0], pad_mask.shape[1], pad_mask.shape[2]])
  207. masks_pad[:, :height, :width] = masks
  208. out_gt_mask.append(masks_pad)
  209. return out_gt_mask
  210. def forward(self, out_transformer, body_feats, inputs=None):
  211. r"""
  212. Args:
  213. out_transformer (Tuple): (feats: [num_levels, batch_size,
  214. num_queries, hidden_dim],
  215. memory: [batch_size, hidden_dim, h, w],
  216. src_proj: [batch_size, h*w, hidden_dim],
  217. src_mask: [batch_size, 1, 1, h, w])
  218. body_feats (List(Tensor)): list[[B, C, H, W]]
  219. inputs (dict): dict(inputs)
  220. """
  221. feats, memory, src_proj, src_mask = out_transformer
  222. outputs_logit = self.score_head(feats)
  223. outputs_bbox = F.sigmoid(self.bbox_head(feats))
  224. outputs_seg = None
  225. if self.with_mask_head:
  226. bbox_attention_map = self.bbox_attention(feats[-1], memory,
  227. src_mask)
  228. fpn_feats = [a for a in body_feats[::-1]][1:]
  229. outputs_seg = self.mask_head(src_proj, bbox_attention_map,
  230. fpn_feats)
  231. outputs_seg = outputs_seg.reshape([
  232. feats.shape[1], feats.shape[2], outputs_seg.shape[-2],
  233. outputs_seg.shape[-1]
  234. ])
  235. if self.training:
  236. assert inputs is not None
  237. assert 'gt_bbox' in inputs and 'gt_class' in inputs
  238. gt_mask = self.get_gt_mask_from_polygons(
  239. inputs['gt_poly'],
  240. inputs['pad_mask']) if 'gt_poly' in inputs else None
  241. return self.loss(
  242. outputs_bbox,
  243. outputs_logit,
  244. inputs['gt_bbox'],
  245. inputs['gt_class'],
  246. masks=outputs_seg,
  247. gt_mask=gt_mask)
  248. else:
  249. return (outputs_bbox[-1], outputs_logit[-1], outputs_seg)
  250. @register
  251. class DeformableDETRHead(nn.Layer):
  252. __shared__ = ['num_classes', 'hidden_dim']
  253. __inject__ = ['loss']
  254. def __init__(self,
  255. num_classes=80,
  256. hidden_dim=512,
  257. nhead=8,
  258. num_mlp_layers=3,
  259. loss='DETRLoss'):
  260. super(DeformableDETRHead, self).__init__()
  261. self.num_classes = num_classes
  262. self.hidden_dim = hidden_dim
  263. self.nhead = nhead
  264. self.loss = loss
  265. self.score_head = nn.Linear(hidden_dim, self.num_classes)
  266. self.bbox_head = MLP(hidden_dim,
  267. hidden_dim,
  268. output_dim=4,
  269. num_layers=num_mlp_layers)
  270. self._reset_parameters()
  271. def _reset_parameters(self):
  272. linear_init_(self.score_head)
  273. constant_(self.score_head.bias, -4.595)
  274. constant_(self.bbox_head.layers[-1].weight)
  275. with paddle.no_grad():
  276. bias = paddle.zeros_like(self.bbox_head.layers[-1].bias)
  277. bias[2:] = -2.0
  278. self.bbox_head.layers[-1].bias.set_value(bias)
  279. @classmethod
  280. def from_config(cls, cfg, hidden_dim, nhead, input_shape):
  281. return {'hidden_dim': hidden_dim, 'nhead': nhead}
  282. def forward(self, out_transformer, body_feats, inputs=None):
  283. r"""
  284. Args:
  285. out_transformer (Tuple): (feats: [num_levels, batch_size,
  286. num_queries, hidden_dim],
  287. memory: [batch_size,
  288. \sum_{l=0}^{L-1} H_l \cdot W_l, hidden_dim],
  289. reference_points: [batch_size, num_queries, 2])
  290. body_feats (List(Tensor)): list[[B, C, H, W]]
  291. inputs (dict): dict(inputs)
  292. """
  293. feats, memory, reference_points = out_transformer
  294. reference_points = inverse_sigmoid(reference_points.unsqueeze(0))
  295. outputs_bbox = self.bbox_head(feats)
  296. # It's equivalent to "outputs_bbox[:, :, :, :2] += reference_points",
  297. # but the gradient is wrong in paddle.
  298. outputs_bbox = paddle.concat(
  299. [
  300. outputs_bbox[:, :, :, :2] + reference_points,
  301. outputs_bbox[:, :, :, 2:]
  302. ],
  303. axis=-1)
  304. outputs_bbox = F.sigmoid(outputs_bbox)
  305. outputs_logit = self.score_head(feats)
  306. if self.training:
  307. assert inputs is not None
  308. assert 'gt_bbox' in inputs and 'gt_class' in inputs
  309. return self.loss(outputs_bbox, outputs_logit, inputs['gt_bbox'],
  310. inputs['gt_class'])
  311. else:
  312. return (outputs_bbox[-1], outputs_logit[-1], None)
  313. @register
  314. class DINOHead(nn.Layer):
  315. __inject__ = ['loss']
  316. def __init__(self, loss='DINOLoss'):
  317. super(DINOHead, self).__init__()
  318. self.loss = loss
  319. def forward(self, out_transformer, body_feats, inputs=None):
  320. (dec_out_bboxes, dec_out_logits, enc_topk_bboxes, enc_topk_logits,
  321. dn_meta) = out_transformer
  322. if self.training:
  323. assert inputs is not None
  324. assert 'gt_bbox' in inputs and 'gt_class' in inputs
  325. if dn_meta is not None:
  326. dn_out_bboxes, dec_out_bboxes = paddle.split(
  327. dec_out_bboxes, dn_meta['dn_num_split'], axis=2)
  328. dn_out_logits, dec_out_logits = paddle.split(
  329. dec_out_logits, dn_meta['dn_num_split'], axis=2)
  330. else:
  331. dn_out_bboxes, dn_out_logits = None, None
  332. out_bboxes = paddle.concat(
  333. [enc_topk_bboxes.unsqueeze(0), dec_out_bboxes])
  334. out_logits = paddle.concat(
  335. [enc_topk_logits.unsqueeze(0), dec_out_logits])
  336. return self.loss(
  337. out_bboxes,
  338. out_logits,
  339. inputs['gt_bbox'],
  340. inputs['gt_class'],
  341. dn_out_bboxes=dn_out_bboxes,
  342. dn_out_logits=dn_out_logits,
  343. dn_meta=dn_meta)
  344. else:
  345. return (dec_out_bboxes[-1], dec_out_logits[-1], None)