sparsercnn_head.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. This code is based on https://github.com/PeizeSun/SparseR-CNN/blob/main/projects/SparseRCNN/sparsercnn/head.py
  16. The copyright of PeizeSun/SparseR-CNN is as follows:
  17. MIT License [see LICENSE for details]
  18. """
  19. from __future__ import absolute_import
  20. from __future__ import division
  21. from __future__ import print_function
  22. import math
  23. import copy
  24. import paddle
  25. import paddle.nn as nn
  26. from ppdet.core.workspace import register
  27. from ppdet.modeling.heads.roi_extractor import RoIAlign
  28. from ppdet.modeling.bbox_utils import delta2bbox
  29. from .. import initializer as init
  # Default for RCNNHead's ``scale_clamp`` argument: log(100000 / 16).
  _DEFAULT_SCALE_CLAMP = math.log(1e5 / 16.0)
  class DynamicConv(nn.Layer):
      """Dynamic instance interaction used by each Sparse R-CNN stage.

      Every proposal feature is projected by ``dynamic_layer`` into two
      parameter matrices. These are applied with batched matrix products to
      that proposal's pooled RoI features. The result is flattened and
      projected back to ``head_hidden_dim`` channels.

      Args:
          head_hidden_dim (int): channel width of the proposal and RoI
              features.
          head_dim_dynamic (int): inner width between the two dynamic
              projections.
          head_num_dynamic (int): number of dynamic parameter matrices
              generated per proposal. ``forward`` uses exactly two of them
              (``param1`` and ``param2``), so only 2 is supported.
          pooler_resolution (int): side length of the pooled RoI grid. The
              RoI features must contain ``pooler_resolution ** 2`` positions.
              Defaults to 7, the value previously hard-coded here.

      Raises:
          ValueError: if ``head_num_dynamic`` is not 2.
      """

      def __init__(
              self,
              head_hidden_dim,
              head_dim_dynamic,
              head_num_dynamic,
              pooler_resolution=7, ):
          super().__init__()
          # forward() splits the generated parameters into exactly two
          # matrices. Any other count used to construct fine and then fail
          # later in forward with a shape error, so reject it up front.
          if head_num_dynamic != 2:
              raise ValueError(
                  'DynamicConv only supports head_num_dynamic == 2, '
                  f'but got {head_num_dynamic}.')
          self.hidden_dim = head_hidden_dim
          self.dim_dynamic = head_dim_dynamic
          self.num_dynamic = head_num_dynamic
          # Number of elements of one dynamic matrix (hidden_dim x dim_dynamic).
          self.num_params = self.hidden_dim * self.dim_dynamic
          self.dynamic_layer = nn.Linear(self.hidden_dim,
                                         self.num_dynamic * self.num_params)
          self.norm1 = nn.LayerNorm(self.dim_dynamic)
          self.norm2 = nn.LayerNorm(self.hidden_dim)
          self.activation = nn.ReLU()
          num_output = self.hidden_dim * pooler_resolution**2
          self.out_layer = nn.Linear(num_output, self.hidden_dim)
          self.norm3 = nn.LayerNorm(self.hidden_dim)

      def forward(self, pro_features, roi_features):
          '''
          Args:
              pro_features: (1, N * nr_boxes, hidden_dim) proposal features.
              roi_features: (pooler_resolution**2, N * nr_boxes, hidden_dim)
                  pooled RoI features.

          Returns:
              Tensor of shape (N * nr_boxes, hidden_dim).
          '''
          # -> (N * nr_boxes, pooler_resolution**2, hidden_dim)
          features = roi_features.transpose(perm=[1, 0, 2])
          # -> (N * nr_boxes, 1, num_dynamic * num_params)
          parameters = self.dynamic_layer(pro_features).transpose(
              perm=[1, 0, 2])
          # Per-proposal weights of the two dynamic projections.
          param1 = parameters[:, :, :self.num_params].reshape(
              [-1, self.hidden_dim, self.dim_dynamic])
          param2 = parameters[:, :, self.num_params:].reshape(
              [-1, self.dim_dynamic, self.hidden_dim])
          # hidden_dim -> dim_dynamic -> hidden_dim, separately per proposal.
          features = paddle.bmm(features, param1)
          features = self.norm1(features)
          features = self.activation(features)
          features = paddle.bmm(features, param2)
          features = self.norm2(features)
          features = self.activation(features)
          # Flatten the RoI grid and project back to hidden_dim channels.
          features = features.flatten(1)
          features = self.out_layer(features)
          features = self.norm3(features)
          features = self.activation(features)
          return features
  class RCNNHead(nn.Layer):
      """A single refinement stage of the Sparse R-CNN head.

      The stage does the following, in order:
      - pools RoI features for the current boxes;
      - applies self-attention among the proposals of each image;
      - mixes proposal and RoI features with ``DynamicConv``;
      - runs a feed-forward block;
      - predicts class logits and refined boxes.

      Args:
          d_model (int): channel width of proposal / RoI features.
          num_classes (int): output size of ``class_logits``.
          dim_feedforward (int): hidden width of the feed-forward block.
          nhead (int): number of attention heads.
          dropout (float): dropout probability used by attention and the
              residual branches.
          head_cls (int): number of Linear-LayerNorm-ReLU groups in the
              classification tower.
          head_reg (int): number of Linear-LayerNorm-ReLU groups in the
              regression tower.
          head_dim_dynamic (int): forwarded to ``DynamicConv``.
          head_num_dynamic (int): forwarded to ``DynamicConv``.
          scale_clamp (float): kept as ``self.scale_clamp``.
              NOTE(review): ``forward`` never reads it, and box decoding is
              delegated entirely to ``delta2bbox``. Confirm this is intended.
          bbox_weights (tuple): delta weights handed to ``delta2bbox``.
      """

      def __init__(
              self,
              d_model,
              num_classes,
              dim_feedforward,
              nhead,
              dropout,
              head_cls,
              head_reg,
              head_dim_dynamic,
              head_num_dynamic,
              scale_clamp: float=_DEFAULT_SCALE_CLAMP,
              bbox_weights=(2.0, 2.0, 1.0, 1.0), ):
          super().__init__()
          self.d_model = d_model
          # Sublayers are created in a fixed order. Keeping that order keeps
          # parameter creation, state-dict keys and sublayers() traversal
          # unchanged.

          # Proposal/proposal and proposal/RoI interaction.
          self.self_attn = nn.MultiHeadAttention(
              d_model, nhead, dropout=dropout)
          self.inst_interact = DynamicConv(d_model, head_dim_dynamic,
                                           head_num_dynamic)

          # Feed-forward block plus the norms/dropouts of the three residuals.
          self.linear1 = nn.Linear(d_model, dim_feedforward)
          self.dropout = nn.Dropout(dropout)
          self.linear2 = nn.Linear(dim_feedforward, d_model)
          self.norm1 = nn.LayerNorm(d_model)
          self.norm2 = nn.LayerNorm(d_model)
          self.norm3 = nn.LayerNorm(d_model)
          self.dropout1 = nn.Dropout(dropout)
          self.dropout2 = nn.Dropout(dropout)
          self.dropout3 = nn.Dropout(dropout)
          self.activation = nn.ReLU()

          # Classification and regression towers.
          self.cls_module = nn.LayerList(self._build_tower(d_model, head_cls))
          self.reg_module = nn.LayerList(self._build_tower(d_model, head_reg))

          # Output projections.
          self.class_logits = nn.Linear(d_model, num_classes)
          self.bboxes_delta = nn.Linear(d_model, 4)
          self.scale_clamp = scale_clamp
          self.bbox_weights = bbox_weights

      @staticmethod
      def _build_tower(d_model, depth):
          """Return ``depth`` groups of bias-free Linear -> LayerNorm -> ReLU."""
          layers = []
          for _ in range(depth):
              layers.extend([
                  nn.Linear(d_model, d_model, bias_attr=False),
                  nn.LayerNorm(d_model),
                  nn.ReLU(),
              ])
          return layers

      @staticmethod
      def _run_tower(layers, x):
          """Apply ``layers`` to ``x`` one after another."""
          for layer in layers:
              x = layer(x)
          return x

      def forward(self, features, bboxes, pro_features, pooler):
          """Run one refinement stage.

          Args:
              features: feature maps, passed unchanged to ``pooler``.
              bboxes: (N, nr_boxes, 4) boxes refined by this stage.
              pro_features: proposal features with N * nr_boxes * d_model
                  elements. They are reshaped here to (N, nr_boxes, d_model).
              pooler: RoI extractor, called as
                  ``pooler(features, boxes_per_image, roi_num)``.

          Returns:
              tuple: ``(class_logits, pred_bboxes, obj_features)``.
              - ``class_logits``: reshaped to (N, nr_boxes, -1).
              - ``pred_bboxes``: reshaped to (N, nr_boxes, -1).
              - ``obj_features``: updated proposal features of shape
                (1, N * nr_boxes, d_model).
          """
          num_images, num_boxes = bboxes.shape[:2]
          dim = self.d_model

          # RoI pooling: one box tensor per image, num_boxes boxes each.
          boxes_per_image = [bboxes[i] for i in range(num_images)]
          rois_per_image = paddle.full([num_images], num_boxes).astype("int32")
          roi_feats = pooler(features, boxes_per_image, rois_per_image)
          # -> (pooled positions, N * nr_boxes, d_model)
          roi_feats = roi_feats.reshape(
              [num_images * num_boxes, dim, -1]).transpose(perm=[2, 0, 1])

          # Self-attention among the proposals of each image (batch-first).
          batched = pro_features.reshape([num_images, num_boxes, dim])
          attended = self.self_attn(batched, batched, value=batched)
          # Residual in (nr_boxes, N, d_model) layout.
          seq_first = batched.transpose(perm=[1, 0, 2]) + self.dropout1(
              attended.transpose(perm=[1, 0, 2]))
          seq_first = self.norm1(seq_first)

          # Instance interaction: back to image-major (1, N * nr_boxes, d_model).
          flat = seq_first.reshape([num_boxes, num_images, dim]).transpose(
              perm=[1, 0, 2]).reshape([1, num_images * num_boxes, dim])
          interacted = self.inst_interact(flat, roi_feats)
          obj_features = self.norm2(flat + self.dropout2(interacted))

          # Feed-forward block.
          ffn_out = self.linear2(
              self.dropout(self.activation(self.linear1(obj_features))))
          obj_features = self.norm3(obj_features + self.dropout3(ffn_out))

          # Prediction towers and heads.
          fc_feature = obj_features.transpose(perm=[1, 0, 2]).reshape(
              [num_images * num_boxes, -1])
          cls_feature = self._run_tower(self.cls_module, fc_feature.clone())
          reg_feature = self._run_tower(self.reg_module, fc_feature.clone())
          class_logits = self.class_logits(cls_feature)
          bboxes_deltas = self.bboxes_delta(reg_feature)
          pred_bboxes = delta2bbox(bboxes_deltas,
                                   bboxes.reshape([-1, 4]), self.bbox_weights)
          return (class_logits.reshape([num_images, num_boxes, -1]),
                  pred_bboxes.reshape([num_images, num_boxes, -1]),
                  obj_features)
  @register
  class SparseRCNNHead(nn.Layer):
      '''
      SparseRCNNHead: a cascade of ``RCNNHead`` stages that iteratively refine
      a fixed set of learnable proposal boxes and proposal features.

      Args:
          head_hidden_dim (int): width of the proposal features. It is the
              ``d_model`` of every stage and the embedding size of the
              proposal features.
          head_dim_feedforward (int): hidden width of each stage's
              feed-forward block.
          nhead (int): number of heads of each stage's MultiHeadAttention.
          head_dropout (float): dropout probability inside every stage.
          head_cls (int): number of Linear-LayerNorm-ReLU groups in the
              classification tower of each stage.
          head_reg (int): number of Linear-LayerNorm-ReLU groups in the
              regression tower of each stage.
          head_dim_dynamic (int): inner width of each stage's DynamicConv.
          head_num_dynamic (int): number of dynamic parameter matrices of each
              stage's DynamicConv.
          head_num_heads (int): number of cascaded RCNNHead stages (> 0).
          deep_supervision (bool): whether to also return the predictions of
              the intermediate stages under ``aux_outputs``.
          num_proposals (int): number of learnable proposal boxes and
              features.
          num_classes (int): number of classes.
          loss_func (str|object): loss used by ``get_loss``; injected through
              ``__inject__``.
          roi_input_shape (list[ShapeSpec]|None): output shapes of the FPN,
              used to derive the RoI pooling scales.
      '''
      __inject__ = ['loss_func']
      __shared__ = ['num_classes']

      def __init__(
              self,
              head_hidden_dim,
              head_dim_feedforward,
              nhead,
              head_dropout,
              head_cls,
              head_reg,
              head_dim_dynamic,
              head_num_dynamic,
              head_num_heads,
              deep_supervision,
              num_proposals,
              num_classes=80,
              loss_func="SparseRCNNLoss",
              roi_input_shape=None, ):
          super().__init__()
          assert head_num_heads > 0, \
              f'At least one RoI Head is required, but {head_num_heads}.'

          # Build RoI.
          box_pooler = self._init_box_pooler(roi_input_shape)
          self.box_pooler = box_pooler

          # Build heads: independent deep copies of one template stage.
          rcnn_head = RCNNHead(
              head_hidden_dim,
              num_classes,
              head_dim_feedforward,
              nhead,
              head_dropout,
              head_cls,
              head_reg,
              head_dim_dynamic,
              head_num_dynamic, )
          self.head_series = nn.LayerList(
              [copy.deepcopy(rcnn_head) for _ in range(head_num_heads)])
          self.return_intermediate = deep_supervision
          self.num_classes = num_classes

          # Learnable initial proposals: features and (cx, cy, w, h) boxes.
          self.init_proposal_features = nn.Embedding(num_proposals,
                                                     head_hidden_dim)
          self.init_proposal_boxes = nn.Embedding(num_proposals, 4)
          # Expected to be the injected loss object (see ``__inject__``);
          # get_loss calls it and reads its ``weight_dict``.
          self.lossfunc = loss_func

          # Init parameters.
          init.reset_initialized_parameter(self)
          self._reset_parameters()

      def _reset_parameters(self):
          """Xavier-init weights and set the focal-loss prior on classifiers."""
          # Classification bias prior: sigmoid(bias_value) == prior_prob.
          prior_prob = 0.01
          bias_value = -math.log((1 - prior_prob) / prior_prob)
          for m in self.sublayers():
              if isinstance(m, nn.Linear):
                  # Paddle's nn.Linear weight is laid out
                  # [in_features, out_features], hence reverse=True.
                  init.xavier_normal_(m.weight, reverse=True)
              elif not isinstance(m, nn.Embedding) and hasattr(
                      m, "weight") and m.weight.dim() > 1:
                  init.xavier_normal_(m.weight, reverse=False)
          # Apply the prior only to the classification layers. The previous
          # heuristic (any bias whose last dim == num_classes) also hit
          # unrelated layers whose width happened to equal num_classes, e.g.
          # LayerNorm biases when num_classes == head_hidden_dim.
          for rcnn_head in self.head_series:
              init.constant_(rcnn_head.class_logits.bias, bias_value)
          # Every proposal starts as (cx, cy, w, h) = (0.5, 0.5, 1, 1). After
          # conversion and scaling by input_whwh in forward, that box covers
          # the whole image.
          init_bboxes = paddle.empty_like(self.init_proposal_boxes.weight)
          init_bboxes[:, :2] = 0.5
          init_bboxes[:, 2:] = 1.0
          self.init_proposal_boxes.weight.set_value(init_bboxes)

      @staticmethod
      def _init_box_pooler(input_shape):
          """Build the multi-level RoIAlign shared by every stage.

          The 7x7 output must match the RoI grid size that DynamicConv is
          built for (7 by default).

          - With ``input_shape``: one level per entry, with scale
            1 / stride. All levels must have the same channel count.
          - Without it: four levels with strides 4, 8, 16 and 32.
          """
          pooler_resolution = 7
          sampling_ratio = 2
          if input_shape is not None:
              pooler_scales = tuple(1.0 / input_shape[k].stride
                                    for k in range(len(input_shape)))
              in_channels = [
                  input_shape[f].channels for f in range(len(input_shape))
              ]
              end_level = len(input_shape) - 1
              # Check all channel counts are equal.
              assert len(set(in_channels)) == 1, in_channels
          else:
              pooler_scales = [1.0 / 4.0, 1.0 / 8.0, 1.0 / 16.0, 1.0 / 32.0]
              end_level = 3
          box_pooler = RoIAlign(
              resolution=pooler_resolution,
              spatial_scale=pooler_scales,
              sampling_ratio=sampling_ratio,
              end_level=end_level,
              aligned=True)
          return box_pooler

      def forward(self, features, input_whwh):
          """Run all cascaded stages.

          Args:
              features (list[Tensor]): multi-level feature maps. The batch
                  size is taken as ``len(features[0])``.
              input_whwh (Tensor): per-image (w, h, w, h) scale, presumably
                  of shape (bs, 4). It multiplies the learned proposal boxes.

          Returns:
              dict: ``pred_logits`` / ``pred_boxes`` of the last stage. When
              ``deep_supervision`` is on, it also holds ``aux_outputs``, a
              list of the same two keys for each earlier stage.
          """
          bs = len(features[0])
          # Learned (cx, cy, w, h) -> (x1, y1, x2, y2), then scaled per image.
          bboxes = box_cxcywh_to_xyxy(self.init_proposal_boxes.weight.clone(
          )).unsqueeze(0)
          bboxes = bboxes * input_whwh.unsqueeze(-2)
          # (1, bs * num_proposals, hidden): the same features for every image.
          init_features = self.init_proposal_features.weight.unsqueeze(0).tile(
              [1, bs, 1])
          proposal_features = init_features.clone()

          inter_class_logits = []
          inter_pred_bboxes = []
          for stage, rcnn_head in enumerate(self.head_series):
              class_logits, pred_bboxes, proposal_features = rcnn_head(
                  features, bboxes, proposal_features, self.box_pooler)
              if self.return_intermediate or stage == len(self.head_series) - 1:
                  inter_class_logits.append(class_logits)
                  inter_pred_bboxes.append(pred_bboxes)
              # The next stage refines these boxes without back-propagating
              # into them.
              bboxes = pred_bboxes.detach()

          output = {
              'pred_logits': inter_class_logits[-1],
              'pred_boxes': inter_pred_bboxes[-1]
          }
          if self.return_intermediate:
              output['aux_outputs'] = [{
                  'pred_logits': a,
                  'pred_boxes': b
              } for a, b in zip(inter_class_logits[:-1], inter_pred_bboxes[:-1])]
          return output

      def get_loss(self, outputs, targets):
          """Compute losses and scale each one by ``lossfunc.weight_dict``.

          Losses whose key is absent from ``weight_dict`` are returned
          unscaled.
          """
          losses = self.lossfunc(outputs, targets)
          weight_dict = self.lossfunc.weight_dict
          for k in losses:
              if k in weight_dict:
                  losses[k] *= weight_dict[k]
          return losses
  def box_cxcywh_to_xyxy(x):
      """Convert boxes from (cx, cy, w, h) to (x1, y1, x2, y2).

      Args:
          x: tensor whose last dimension has size 4 and holds
              (center_x, center_y, width, height).

      Returns:
          Tensor of the same shape holding (x_min, y_min, x_max, y_max).
      """
      center_x, center_y, width, height = x.unbind(-1)
      half_w = 0.5 * width
      half_h = 0.5 * height
      corners = [
          center_x - half_w,
          center_y - half_h,
          center_x + half_w,
          center_y + half_h,
      ]
      return paddle.stack(corners, axis=-1)