bbox_head.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import numpy as np
  15. import paddle
  16. import paddle.nn as nn
  17. import paddle.nn.functional as F
  18. from paddle.nn.initializer import Normal, XavierUniform, KaimingNormal
  19. from paddle.regularizer import L2Decay
  20. from ppdet.core.workspace import register, create
  21. from .roi_extractor import RoIAlign
  22. from ..shape_spec import ShapeSpec
  23. from ..bbox_utils import bbox2delta
  24. from ..cls_utils import _get_class_default_kwargs
  25. from ppdet.modeling.layers import ConvNormLayer
  26. __all__ = ['TwoFCHead', 'XConvNormHead', 'BBoxHead']
  27. @register
  28. class TwoFCHead(nn.Layer):
  29. """
  30. RCNN bbox head with Two fc layers to extract feature
  31. Args:
  32. in_channel (int): Input channel which can be derived by from_config
  33. out_channel (int): Output channel
  34. resolution (int): Resolution of input feature map, default 7
  35. """
  36. def __init__(self, in_channel=256, out_channel=1024, resolution=7):
  37. super(TwoFCHead, self).__init__()
  38. self.in_channel = in_channel
  39. self.out_channel = out_channel
  40. fan = in_channel * resolution * resolution
  41. self.fc6 = nn.Linear(
  42. in_channel * resolution * resolution,
  43. out_channel,
  44. weight_attr=paddle.ParamAttr(
  45. initializer=XavierUniform(fan_out=fan)))
  46. self.fc6.skip_quant = True
  47. self.fc7 = nn.Linear(
  48. out_channel,
  49. out_channel,
  50. weight_attr=paddle.ParamAttr(initializer=XavierUniform()))
  51. self.fc7.skip_quant = True
  52. @classmethod
  53. def from_config(cls, cfg, input_shape):
  54. s = input_shape
  55. s = s[0] if isinstance(s, (list, tuple)) else s
  56. return {'in_channel': s.channels}
  57. @property
  58. def out_shape(self):
  59. return [ShapeSpec(channels=self.out_channel, )]
  60. def forward(self, rois_feat):
  61. rois_feat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
  62. fc6 = self.fc6(rois_feat)
  63. fc6 = F.relu(fc6)
  64. fc7 = self.fc7(fc6)
  65. fc7 = F.relu(fc7)
  66. return fc7
  67. @register
  68. class XConvNormHead(nn.Layer):
  69. __shared__ = ['norm_type', 'freeze_norm']
  70. """
  71. RCNN bbox head with serveral convolution layers
  72. Args:
  73. in_channel (int): Input channels which can be derived by from_config
  74. num_convs (int): The number of conv layers
  75. conv_dim (int): The number of channels for the conv layers
  76. out_channel (int): Output channels
  77. resolution (int): Resolution of input feature map
  78. norm_type (string): Norm type, bn, gn, sync_bn are available,
  79. default `gn`
  80. freeze_norm (bool): Whether to freeze the norm
  81. stage_name (string): Prefix name for conv layer, '' by default
  82. """
  83. def __init__(self,
  84. in_channel=256,
  85. num_convs=4,
  86. conv_dim=256,
  87. out_channel=1024,
  88. resolution=7,
  89. norm_type='gn',
  90. freeze_norm=False,
  91. stage_name=''):
  92. super(XConvNormHead, self).__init__()
  93. self.in_channel = in_channel
  94. self.num_convs = num_convs
  95. self.conv_dim = conv_dim
  96. self.out_channel = out_channel
  97. self.norm_type = norm_type
  98. self.freeze_norm = freeze_norm
  99. self.bbox_head_convs = []
  100. fan = conv_dim * 3 * 3
  101. initializer = KaimingNormal(fan_in=fan)
  102. for i in range(self.num_convs):
  103. in_c = in_channel if i == 0 else conv_dim
  104. head_conv_name = stage_name + 'bbox_head_conv{}'.format(i)
  105. head_conv = self.add_sublayer(
  106. head_conv_name,
  107. ConvNormLayer(
  108. ch_in=in_c,
  109. ch_out=conv_dim,
  110. filter_size=3,
  111. stride=1,
  112. norm_type=self.norm_type,
  113. freeze_norm=self.freeze_norm,
  114. initializer=initializer))
  115. self.bbox_head_convs.append(head_conv)
  116. fan = conv_dim * resolution * resolution
  117. self.fc6 = nn.Linear(
  118. conv_dim * resolution * resolution,
  119. out_channel,
  120. weight_attr=paddle.ParamAttr(
  121. initializer=XavierUniform(fan_out=fan)),
  122. bias_attr=paddle.ParamAttr(
  123. learning_rate=2., regularizer=L2Decay(0.)))
  124. @classmethod
  125. def from_config(cls, cfg, input_shape):
  126. s = input_shape
  127. s = s[0] if isinstance(s, (list, tuple)) else s
  128. return {'in_channel': s.channels}
  129. @property
  130. def out_shape(self):
  131. return [ShapeSpec(channels=self.out_channel, )]
  132. def forward(self, rois_feat):
  133. for i in range(self.num_convs):
  134. rois_feat = F.relu(self.bbox_head_convs[i](rois_feat))
  135. rois_feat = paddle.flatten(rois_feat, start_axis=1, stop_axis=-1)
  136. fc6 = F.relu(self.fc6(rois_feat))
  137. return fc6
  138. @register
  139. class BBoxHead(nn.Layer):
  140. __shared__ = ['num_classes', 'use_cot']
  141. __inject__ = ['bbox_assigner', 'bbox_loss', 'loss_cot']
  142. """
  143. RCNN bbox head
  144. Args:
  145. head (nn.Layer): Extract feature in bbox head
  146. in_channel (int): Input channel after RoI extractor
  147. roi_extractor (object): The module of RoI Extractor
  148. bbox_assigner (object): The module of Box Assigner, label and sample the
  149. box.
  150. with_pool (bool): Whether to use pooling for the RoI feature.
  151. num_classes (int): The number of classes
  152. bbox_weight (List[float]): The weight to get the decode box
  153. cot_classes (int): The number of base classes
  154. loss_cot (object): The module of Label-cotuning
  155. use_cot(bool): whether to use Label-cotuning
  156. """
  157. def __init__(self,
  158. head,
  159. in_channel,
  160. roi_extractor=_get_class_default_kwargs(RoIAlign),
  161. bbox_assigner='BboxAssigner',
  162. with_pool=False,
  163. num_classes=80,
  164. bbox_weight=[10., 10., 5., 5.],
  165. bbox_loss=None,
  166. loss_normalize_pos=False,
  167. cot_classes=None,
  168. loss_cot='COTLoss',
  169. use_cot=False):
  170. super(BBoxHead, self).__init__()
  171. self.head = head
  172. self.roi_extractor = roi_extractor
  173. if isinstance(roi_extractor, dict):
  174. self.roi_extractor = RoIAlign(**roi_extractor)
  175. self.bbox_assigner = bbox_assigner
  176. self.with_pool = with_pool
  177. self.num_classes = num_classes
  178. self.bbox_weight = bbox_weight
  179. self.bbox_loss = bbox_loss
  180. self.loss_normalize_pos = loss_normalize_pos
  181. self.loss_cot = loss_cot
  182. self.cot_relation = None
  183. self.cot_classes = cot_classes
  184. self.use_cot = use_cot
  185. if use_cot:
  186. self.cot_bbox_score = nn.Linear(
  187. in_channel,
  188. self.num_classes + 1,
  189. weight_attr=paddle.ParamAttr(initializer=Normal(
  190. mean=0.0, std=0.01)))
  191. self.bbox_score = nn.Linear(
  192. in_channel,
  193. self.cot_classes + 1,
  194. weight_attr=paddle.ParamAttr(initializer=Normal(
  195. mean=0.0, std=0.01)))
  196. self.cot_bbox_score.skip_quant = True
  197. else:
  198. self.bbox_score = nn.Linear(
  199. in_channel,
  200. self.num_classes + 1,
  201. weight_attr=paddle.ParamAttr(initializer=Normal(
  202. mean=0.0, std=0.01)))
  203. self.bbox_score.skip_quant = True
  204. self.bbox_delta = nn.Linear(
  205. in_channel,
  206. 4 * self.num_classes,
  207. weight_attr=paddle.ParamAttr(initializer=Normal(
  208. mean=0.0, std=0.001)))
  209. self.bbox_delta.skip_quant = True
  210. self.assigned_label = None
  211. self.assigned_rois = None
  212. def init_cot_head(self, relationship):
  213. self.cot_relation = relationship
  214. @classmethod
  215. def from_config(cls, cfg, input_shape):
  216. roi_pooler = cfg['roi_extractor']
  217. assert isinstance(roi_pooler, dict)
  218. kwargs = RoIAlign.from_config(cfg, input_shape)
  219. roi_pooler.update(kwargs)
  220. kwargs = {'input_shape': input_shape}
  221. head = create(cfg['head'], **kwargs)
  222. return {
  223. 'roi_extractor': roi_pooler,
  224. 'head': head,
  225. 'in_channel': head.out_shape[0].channels
  226. }
  227. def forward(self, body_feats=None, rois=None, rois_num=None, inputs=None, cot=False):
  228. """
  229. body_feats (list[Tensor]): Feature maps from backbone
  230. rois (list[Tensor]): RoIs generated from RPN module
  231. rois_num (Tensor): The number of RoIs in each image
  232. inputs (dict{Tensor}): The ground-truth of image
  233. """
  234. if self.training:
  235. rois, rois_num, targets = self.bbox_assigner(rois, rois_num, inputs)
  236. self.assigned_rois = (rois, rois_num)
  237. self.assigned_targets = targets
  238. rois_feat = self.roi_extractor(body_feats, rois, rois_num)
  239. bbox_feat = self.head(rois_feat)
  240. if self.with_pool:
  241. feat = F.adaptive_avg_pool2d(bbox_feat, output_size=1)
  242. feat = paddle.squeeze(feat, axis=[2, 3])
  243. else:
  244. feat = bbox_feat
  245. if self.use_cot:
  246. scores = self.cot_bbox_score(feat)
  247. cot_scores = self.bbox_score(feat)
  248. else:
  249. scores = self.bbox_score(feat)
  250. deltas = self.bbox_delta(feat)
  251. if self.training:
  252. loss = self.get_loss(
  253. scores,
  254. deltas,
  255. targets,
  256. rois,
  257. self.bbox_weight,
  258. loss_normalize_pos=self.loss_normalize_pos)
  259. if self.cot_relation is not None:
  260. loss_cot = self.loss_cot(cot_scores, targets, self.cot_relation)
  261. loss.update(loss_cot)
  262. return loss, bbox_feat
  263. else:
  264. if cot:
  265. pred = self.get_prediction(cot_scores, deltas)
  266. else:
  267. pred = self.get_prediction(scores, deltas)
  268. return pred, self.head
  269. def get_loss(self,
  270. scores,
  271. deltas,
  272. targets,
  273. rois,
  274. bbox_weight,
  275. loss_normalize_pos=False):
  276. """
  277. scores (Tensor): scores from bbox head outputs
  278. deltas (Tensor): deltas from bbox head outputs
  279. targets (list[List[Tensor]]): bbox targets containing tgt_labels, tgt_bboxes and tgt_gt_inds
  280. rois (List[Tensor]): RoIs generated in each batch
  281. """
  282. cls_name = 'loss_bbox_cls'
  283. reg_name = 'loss_bbox_reg'
  284. loss_bbox = {}
  285. # TODO: better pass args
  286. tgt_labels, tgt_bboxes, tgt_gt_inds = targets
  287. # bbox cls
  288. tgt_labels = paddle.concat(tgt_labels) if len(
  289. tgt_labels) > 1 else tgt_labels[0]
  290. valid_inds = paddle.nonzero(tgt_labels >= 0).flatten()
  291. if valid_inds.shape[0] == 0:
  292. loss_bbox[cls_name] = paddle.zeros([1], dtype='float32')
  293. else:
  294. tgt_labels = tgt_labels.cast('int64')
  295. tgt_labels.stop_gradient = True
  296. if not loss_normalize_pos:
  297. loss_bbox_cls = F.cross_entropy(
  298. input=scores, label=tgt_labels, reduction='mean')
  299. else:
  300. loss_bbox_cls = F.cross_entropy(
  301. input=scores, label=tgt_labels,
  302. reduction='none').sum() / (tgt_labels.shape[0] + 1e-7)
  303. loss_bbox[cls_name] = loss_bbox_cls
  304. # bbox reg
  305. cls_agnostic_bbox_reg = deltas.shape[1] == 4
  306. fg_inds = paddle.nonzero(
  307. paddle.logical_and(tgt_labels >= 0, tgt_labels <
  308. self.num_classes)).flatten()
  309. if fg_inds.numel() == 0:
  310. loss_bbox[reg_name] = paddle.zeros([1], dtype='float32')
  311. return loss_bbox
  312. if cls_agnostic_bbox_reg:
  313. reg_delta = paddle.gather(deltas, fg_inds)
  314. else:
  315. fg_gt_classes = paddle.gather(tgt_labels, fg_inds)
  316. reg_row_inds = paddle.arange(fg_gt_classes.shape[0]).unsqueeze(1)
  317. reg_row_inds = paddle.tile(reg_row_inds, [1, 4]).reshape([-1, 1])
  318. reg_col_inds = 4 * fg_gt_classes.unsqueeze(1) + paddle.arange(4)
  319. reg_col_inds = reg_col_inds.reshape([-1, 1])
  320. reg_inds = paddle.concat([reg_row_inds, reg_col_inds], axis=1)
  321. reg_delta = paddle.gather(deltas, fg_inds)
  322. reg_delta = paddle.gather_nd(reg_delta, reg_inds).reshape([-1, 4])
  323. rois = paddle.concat(rois) if len(rois) > 1 else rois[0]
  324. tgt_bboxes = paddle.concat(tgt_bboxes) if len(
  325. tgt_bboxes) > 1 else tgt_bboxes[0]
  326. reg_target = bbox2delta(rois, tgt_bboxes, bbox_weight)
  327. reg_target = paddle.gather(reg_target, fg_inds)
  328. reg_target.stop_gradient = True
  329. if self.bbox_loss is not None:
  330. reg_delta = self.bbox_transform(reg_delta)
  331. reg_target = self.bbox_transform(reg_target)
  332. if not loss_normalize_pos:
  333. loss_bbox_reg = self.bbox_loss(
  334. reg_delta, reg_target).sum() / tgt_labels.shape[0]
  335. loss_bbox_reg *= self.num_classes
  336. else:
  337. loss_bbox_reg = self.bbox_loss(
  338. reg_delta, reg_target).sum() / (tgt_labels.shape[0] + 1e-7)
  339. else:
  340. loss_bbox_reg = paddle.abs(reg_delta - reg_target).sum(
  341. ) / tgt_labels.shape[0]
  342. loss_bbox[reg_name] = loss_bbox_reg
  343. return loss_bbox
  344. def bbox_transform(self, deltas, weights=[0.1, 0.1, 0.2, 0.2]):
  345. wx, wy, ww, wh = weights
  346. deltas = paddle.reshape(deltas, shape=(0, -1, 4))
  347. dx = paddle.slice(deltas, axes=[2], starts=[0], ends=[1]) * wx
  348. dy = paddle.slice(deltas, axes=[2], starts=[1], ends=[2]) * wy
  349. dw = paddle.slice(deltas, axes=[2], starts=[2], ends=[3]) * ww
  350. dh = paddle.slice(deltas, axes=[2], starts=[3], ends=[4]) * wh
  351. dw = paddle.clip(dw, -1.e10, np.log(1000. / 16))
  352. dh = paddle.clip(dh, -1.e10, np.log(1000. / 16))
  353. pred_ctr_x = dx
  354. pred_ctr_y = dy
  355. pred_w = paddle.exp(dw)
  356. pred_h = paddle.exp(dh)
  357. x1 = pred_ctr_x - 0.5 * pred_w
  358. y1 = pred_ctr_y - 0.5 * pred_h
  359. x2 = pred_ctr_x + 0.5 * pred_w
  360. y2 = pred_ctr_y + 0.5 * pred_h
  361. x1 = paddle.reshape(x1, shape=(-1, ))
  362. y1 = paddle.reshape(y1, shape=(-1, ))
  363. x2 = paddle.reshape(x2, shape=(-1, ))
  364. y2 = paddle.reshape(y2, shape=(-1, ))
  365. return paddle.concat([x1, y1, x2, y2])
  366. def get_prediction(self, score, delta):
  367. bbox_prob = F.softmax(score)
  368. return delta, bbox_prob
  369. def get_head(self, ):
  370. return self.head
  371. def get_assigned_targets(self, ):
  372. return self.assigned_targets
  373. def get_assigned_rois(self, ):
  374. return self.assigned_rois