# distill.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle import ParamAttr

from ppdet.core.workspace import register, create, load_config
from ppdet.modeling import ops
from ppdet.utils.checkpoint import load_pretrain_weight
from ppdet.utils.logger import setup_logger

logger = setup_logger(__name__)

class DistillModel(nn.Layer):
    def __init__(self, cfg, slim_cfg):
        super(DistillModel, self).__init__()

        self.student_model = create(cfg.architecture)
        logger.debug('Load student model pretrain_weights:{}'.format(
            cfg.pretrain_weights))
        load_pretrain_weight(self.student_model, cfg.pretrain_weights)

        slim_cfg = load_config(slim_cfg)
        self.teacher_model = create(slim_cfg.architecture)
        self.distill_loss = create(slim_cfg.distill_loss)
        logger.debug('Load teacher model pretrain_weights:{}'.format(
            slim_cfg.pretrain_weights))
        load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights)

        for param in self.teacher_model.parameters():
            param.trainable = False

    def parameters(self):
        return self.student_model.parameters()

    def forward(self, inputs):
        if self.training:
            teacher_loss = self.teacher_model(inputs)
            student_loss = self.student_model(inputs)
            loss = self.distill_loss(self.teacher_model, self.student_model)
            student_loss['distill_loss'] = loss
            student_loss['teacher_loss'] = teacher_loss['loss']
            student_loss['loss'] += student_loss['distill_loss']
            return student_loss
        else:
            return self.student_model(inputs)

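# A minimal usage sketch (not part of the original file): DistillModel wraps a
# frozen teacher and a trainable student and adds `distill_loss` on top of the
# student loss. The config paths below are hypothetical placeholders.
#
#   from ppdet.core.workspace import load_config
#   cfg = load_config('configs/yolov3/yolov3_mobilenet_v1_270e_coco.yml')
#   model = DistillModel(cfg, slim_cfg='configs/slim/distill/yolov3_distill.yml')
#   loss_dict = model(batch)   # training mode: student losses plus 'distill_loss'
#   model.eval()
#   preds = model(batch)       # eval mode: plain student predictions
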
class FGDDistillModel(nn.Layer):
    """
    Build FGD distill model.
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(FGDDistillModel, self).__init__()

        self.is_inherit = True
        # build student model before load slim config
        self.student_model = create(cfg.architecture)
        self.arch = cfg.architecture
        stu_pretrain = cfg['pretrain_weights']
        slim_cfg = load_config(slim_cfg)
        self.teacher_cfg = slim_cfg
        self.loss_cfg = slim_cfg
        tea_pretrain = cfg['pretrain_weights']

        self.teacher_model = create(self.teacher_cfg.architecture)
        self.teacher_model.eval()

        for param in self.teacher_model.parameters():
            param.trainable = False

        if 'pretrain_weights' in cfg and stu_pretrain:
            if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
                load_pretrain_weight(self.student_model,
                                     self.teacher_cfg.pretrain_weights)
                logger.debug(
                    "Inheriting! loading teacher weights to student model!")

            load_pretrain_weight(self.student_model, stu_pretrain)

        if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
            load_pretrain_weight(self.teacher_model,
                                 self.teacher_cfg.pretrain_weights)

        self.fgd_loss_dic = self.build_loss(
            self.loss_cfg.distill_loss,
            name_list=self.loss_cfg['distill_loss_name'])

    def build_loss(self,
                   cfg,
                   name_list=[
                       'neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
                       'neck_f_0'
                   ]):
        loss_func = dict()
        for idx, k in enumerate(name_list):
            loss_func[k] = create(cfg)
        return loss_func

    def forward(self, inputs):
        if self.training:
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)
            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)

            loss_dict = {}
            for idx, k in enumerate(self.fgd_loss_dic):
                loss_dict[k] = self.fgd_loss_dic[k](s_neck_feats[idx],
                                                    t_neck_feats[idx], inputs)
            if self.arch == "RetinaNet":
                loss = self.student_model.head(s_neck_feats, inputs)
            elif self.arch == "PicoDet":
                head_outs = self.student_model.head(
                    s_neck_feats, self.student_model.export_post_process)
                loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
                total_loss = paddle.add_n(list(loss_gfl.values()))
                loss = {}
                loss.update(loss_gfl)
                loss.update({'loss': total_loss})
            else:
                raise ValueError(f"Unsupported model {self.arch}")
            for k in loss_dict:
                loss['loss'] += loss_dict[k]
                loss[k] = loss_dict[k]
            return loss
        else:
            body_feats = self.student_model.backbone(inputs)
            neck_feats = self.student_model.neck(body_feats)
            head_outs = self.student_model.head(neck_feats)
            if self.arch == "RetinaNet":
                bbox, bbox_num = self.student_model.head.post_process(
                    head_outs, inputs['im_shape'], inputs['scale_factor'])
                return {'bbox': bbox, 'bbox_num': bbox_num}
            elif self.arch == "PicoDet":
                head_outs = self.student_model.head(
                    neck_feats, self.student_model.export_post_process)
                scale_factor = inputs['scale_factor']
                bboxes, bbox_num = self.student_model.head.post_process(
                    head_outs,
                    scale_factor,
                    export_nms=self.student_model.export_nms)
                return {'bbox': bboxes, 'bbox_num': bbox_num}
            else:
                raise ValueError(f"Unsupported model {self.arch}")

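# Note: the loss keys built by `build_loss` (taken from `distill_loss_name` in
# the slim config) are matched to FPN levels purely by enumeration order in
# `forward`, so the list is expected to provide one entry per neck output that
# should be distilled.
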
class CWDDistillModel(nn.Layer):
    """
    Build CWD distill model.
    Args:
        cfg: The student config.
        slim_cfg: The teacher and distill config.
    """

    def __init__(self, cfg, slim_cfg):
        super(CWDDistillModel, self).__init__()

        self.is_inherit = False
        # build student model before load slim config
        self.student_model = create(cfg.architecture)
        self.arch = cfg.architecture
        if self.arch not in ['GFL', 'RetinaNet']:
            raise ValueError(
                f"The arch can only be one of ['GFL', 'RetinaNet'], but received {self.arch}"
            )
        stu_pretrain = cfg['pretrain_weights']
        slim_cfg = load_config(slim_cfg)
        self.teacher_cfg = slim_cfg
        self.loss_cfg = slim_cfg
        tea_pretrain = cfg['pretrain_weights']

        self.teacher_model = create(self.teacher_cfg.architecture)
        self.teacher_model.eval()

        for param in self.teacher_model.parameters():
            param.trainable = False

        if 'pretrain_weights' in cfg and stu_pretrain:
            if self.is_inherit and 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
                load_pretrain_weight(self.student_model,
                                     self.teacher_cfg.pretrain_weights)
                logger.debug(
                    "Inheriting! loading teacher weights to student model!")

            load_pretrain_weight(self.student_model, stu_pretrain)

        if 'pretrain_weights' in self.teacher_cfg and self.teacher_cfg.pretrain_weights:
            load_pretrain_weight(self.teacher_model,
                                 self.teacher_cfg.pretrain_weights)

        self.loss_dic = self.build_loss(
            self.loss_cfg.distill_loss,
            name_list=self.loss_cfg['distill_loss_name'])

    def build_loss(self,
                   cfg,
                   name_list=[
                       'neck_f_4', 'neck_f_3', 'neck_f_2', 'neck_f_1',
                       'neck_f_0'
                   ]):
        loss_func = dict()
        for idx, k in enumerate(name_list):
            loss_func[k] = create(cfg)
        return loss_func

    def get_loss_retinanet(self, stu_fea_list, tea_fea_list, inputs):
        loss = self.student_model.head(stu_fea_list, inputs)
        distill_loss = {}
        # cwd kd loss
        for idx, k in enumerate(self.loss_dic):
            distill_loss[k] = self.loss_dic[k](stu_fea_list[idx],
                                               tea_fea_list[idx])
            loss['loss'] += distill_loss[k]
            loss[k] = distill_loss[k]
        return loss

    def get_loss_gfl(self, stu_fea_list, tea_fea_list, inputs):
        loss = {}
        head_outs = self.student_model.head(stu_fea_list)
        loss_gfl = self.student_model.head.get_loss(head_outs, inputs)
        loss.update(loss_gfl)
        total_loss = paddle.add_n(list(loss.values()))
        loss.update({'loss': total_loss})
        # cwd kd loss
        feat_loss = {}
        loss_dict = {}
        s_cls_feat, t_cls_feat = [], []
        for s_neck_f, t_neck_f in zip(stu_fea_list, tea_fea_list):
            conv_cls_feat, _ = self.student_model.head.conv_feat(s_neck_f)
            cls_score = self.student_model.head.gfl_head_cls(conv_cls_feat)
            t_conv_cls_feat, _ = self.teacher_model.head.conv_feat(t_neck_f)
            t_cls_score = self.teacher_model.head.gfl_head_cls(t_conv_cls_feat)
            s_cls_feat.append(cls_score)
            t_cls_feat.append(t_cls_score)

        for idx, k in enumerate(self.loss_dic):
            loss_dict[k] = self.loss_dic[k](s_cls_feat[idx], t_cls_feat[idx])
            feat_loss[f"neck_f_{idx}"] = self.loss_dic[k](stu_fea_list[idx],
                                                          tea_fea_list[idx])
        for k in feat_loss:
            loss['loss'] += feat_loss[k]
            loss[k] = feat_loss[k]
        for k in loss_dict:
            loss['loss'] += loss_dict[k]
            loss[k] = loss_dict[k]
        return loss

    def forward(self, inputs):
        if self.training:
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)
            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)

            if self.arch == "RetinaNet":
                loss = self.get_loss_retinanet(s_neck_feats, t_neck_feats,
                                               inputs)
            elif self.arch == "GFL":
                loss = self.get_loss_gfl(s_neck_feats, t_neck_feats, inputs)
            else:
                raise ValueError(f"unsupported arch {self.arch}")
            return loss
        else:
            body_feats = self.student_model.backbone(inputs)
            neck_feats = self.student_model.neck(body_feats)
            head_outs = self.student_model.head(neck_feats)
            if self.arch == "RetinaNet":
                bbox, bbox_num = self.student_model.head.post_process(
                    head_outs, inputs['im_shape'], inputs['scale_factor'])
                return {'bbox': bbox, 'bbox_num': bbox_num}
            elif self.arch == "GFL":
                bbox_pred, bbox_num = head_outs
                output = {'bbox': bbox_pred, 'bbox_num': bbox_num}
                return output
            else:
                raise ValueError(f"unsupported arch {self.arch}")

@register
class ChannelWiseDivergence(nn.Layer):
    def __init__(self, student_channels, teacher_channels, tau=1.0, weight=1.0):
        super(ChannelWiseDivergence, self).__init__()
        self.tau = tau
        self.loss_weight = weight
        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0)
        else:
            self.align = None

    def distill_softmax(self, x, t):
        _, _, w, h = paddle.shape(x)
        x = paddle.reshape(x, [-1, w * h])
        x /= t
        return F.softmax(x, axis=1)

    def forward(self, preds_s, preds_t):
        assert preds_s.shape[-2:] == preds_t.shape[
            -2:], 'the output dim of teacher and student differ'
        N, C, W, H = preds_s.shape
        eps = 1e-5
        if self.align is not None:
            preds_s = self.align(preds_s)
        softmax_pred_s = self.distill_softmax(preds_s, self.tau)
        softmax_pred_t = self.distill_softmax(preds_t, self.tau)
        loss = paddle.sum(-softmax_pred_t * paddle.log(eps + softmax_pred_s) +
                          softmax_pred_t * paddle.log(eps + softmax_pred_t))
        return self.loss_weight * loss / (C * N)

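# What the loss computes: each (sample, channel) feature map is flattened and
# softmax-normalised over its spatial positions with temperature `tau`
# (`distill_softmax`); the expression above is then the KL divergence
# KL(p_teacher || p_student), summed over positions and channel maps, averaged
# over the N * C maps and scaled by `loss_weight` (`eps` guards the log).
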
@register
class DistillYOLOv3Loss(nn.Layer):
    def __init__(self, weight=1000):
        super(DistillYOLOv3Loss, self).__init__()
        self.weight = weight

    def obj_weighted_reg(self, sx, sy, sw, sh, tx, ty, tw, th, tobj):
        loss_x = ops.sigmoid_cross_entropy_with_logits(sx, F.sigmoid(tx))
        loss_y = ops.sigmoid_cross_entropy_with_logits(sy, F.sigmoid(ty))
        loss_w = paddle.abs(sw - tw)
        loss_h = paddle.abs(sh - th)
        loss = paddle.add_n([loss_x, loss_y, loss_w, loss_h])
        weighted_loss = paddle.mean(loss * F.sigmoid(tobj))
        return weighted_loss

    def obj_weighted_cls(self, scls, tcls, tobj):
        loss = ops.sigmoid_cross_entropy_with_logits(scls, F.sigmoid(tcls))
        weighted_loss = paddle.mean(paddle.multiply(loss, F.sigmoid(tobj)))
        return weighted_loss

    def obj_loss(self, sobj, tobj):
        obj_mask = paddle.cast(tobj > 0., dtype="float32")
        obj_mask.stop_gradient = True
        loss = paddle.mean(
            ops.sigmoid_cross_entropy_with_logits(sobj, obj_mask))
        return loss

    def forward(self, teacher_model, student_model):
        teacher_distill_pairs = teacher_model.yolo_head.loss.distill_pairs
        student_distill_pairs = student_model.yolo_head.loss.distill_pairs
        distill_reg_loss, distill_cls_loss, distill_obj_loss = [], [], []
        # each distill pair is unpacked as (x, y, w, h, obj, cls) logits below
        for s_pair, t_pair in zip(student_distill_pairs, teacher_distill_pairs):
            distill_reg_loss.append(
                self.obj_weighted_reg(s_pair[0], s_pair[1], s_pair[2], s_pair[
                    3], t_pair[0], t_pair[1], t_pair[2], t_pair[3], t_pair[4]))
            distill_cls_loss.append(
                self.obj_weighted_cls(s_pair[5], t_pair[5], t_pair[4]))
            distill_obj_loss.append(self.obj_loss(s_pair[4], t_pair[4]))
        distill_reg_loss = paddle.add_n(distill_reg_loss)
        distill_cls_loss = paddle.add_n(distill_cls_loss)
        distill_obj_loss = paddle.add_n(distill_obj_loss)
        loss = (distill_reg_loss + distill_cls_loss + distill_obj_loss
                ) * self.weight
        return loss

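# DistillYOLOv3Loss matches the `distill_loss` call signature used by
# DistillModel above: its forward takes the teacher and student models directly
# and reads the `distill_pairs` cached by their YOLOv3 head losses during the
# forward pass.
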
def parameter_init(mode="kaiming", value=0.):
    if mode == "kaiming":
        weight_attr = paddle.nn.initializer.KaimingUniform()
    elif mode == "constant":
        weight_attr = paddle.nn.initializer.Constant(value=value)
    else:
        weight_attr = paddle.nn.initializer.KaimingUniform()

    weight_init = ParamAttr(initializer=weight_attr)
    return weight_init

@register
class FGDFeatureLoss(nn.Layer):
    """
    The code is referenced from https://github.com/yzd-v/FGD/blob/master/mmdet/distillation/losses/fgd.py
    Paddle version of `Focal and Global Knowledge Distillation for Detectors`

    Args:
        student_channels (int): The number of channels in the student's FPN feature map. Defaults to 256.
        teacher_channels (int): The number of channels in the teacher's FPN feature map. Defaults to 256.
        temp (float, optional): The temperature coefficient. Defaults to 0.5.
        alpha_fgd (float, optional): The weight of fg_loss. Defaults to 0.001.
        beta_fgd (float, optional): The weight of bg_loss. Defaults to 0.0005.
        gamma_fgd (float, optional): The weight of mask_loss. Defaults to 0.001.
        lambda_fgd (float, optional): The weight of relation_loss. Defaults to 0.000005.
    """

    def __init__(self,
                 student_channels=256,
                 teacher_channels=256,
                 temp=0.5,
                 alpha_fgd=0.001,
                 beta_fgd=0.0005,
                 gamma_fgd=0.001,
                 lambda_fgd=0.000005):
        super(FGDFeatureLoss, self).__init__()
        self.temp = temp
        self.alpha_fgd = alpha_fgd
        self.beta_fgd = beta_fgd
        self.gamma_fgd = gamma_fgd
        self.lambda_fgd = lambda_fgd

        kaiming_init = parameter_init("kaiming")
        zeros_init = parameter_init("constant", 0.0)

        if student_channels != teacher_channels:
            self.align = nn.Conv2D(
                student_channels,
                teacher_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                weight_attr=kaiming_init)
            student_channels = teacher_channels
        else:
            self.align = None

        self.conv_mask_s = nn.Conv2D(
            student_channels, 1, kernel_size=1, weight_attr=kaiming_init)
        self.conv_mask_t = nn.Conv2D(
            teacher_channels, 1, kernel_size=1, weight_attr=kaiming_init)

        self.stu_conv_block = nn.Sequential(
            nn.Conv2D(
                student_channels,
                student_channels // 2,
                kernel_size=1,
                weight_attr=zeros_init),
            nn.LayerNorm([student_channels // 2, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                student_channels // 2,
                student_channels,
                kernel_size=1,
                weight_attr=zeros_init))
        self.tea_conv_block = nn.Sequential(
            nn.Conv2D(
                teacher_channels,
                teacher_channels // 2,
                kernel_size=1,
                weight_attr=zeros_init),
            nn.LayerNorm([teacher_channels // 2, 1, 1]),
            nn.ReLU(),
            nn.Conv2D(
                teacher_channels // 2,
                teacher_channels,
                kernel_size=1,
                weight_attr=zeros_init))

    def spatial_channel_attention(self, x, t=0.5):
        shape = paddle.shape(x)
        N, C, H, W = shape
        _f = paddle.abs(x)
        spatial_map = paddle.reshape(
            paddle.mean(
                _f, axis=1, keepdim=True) / t, [N, -1])
        spatial_map = F.softmax(spatial_map, axis=1, dtype="float32") * H * W
        spatial_att = paddle.reshape(spatial_map, [N, H, W])

        channel_map = paddle.mean(
            paddle.mean(
                _f, axis=2, keepdim=False), axis=2, keepdim=False)
        channel_att = F.softmax(channel_map / t, axis=1, dtype="float32") * C
        return [spatial_att, channel_att]

    def spatial_pool(self, x, mode="teacher"):
        batch, channel, width, height = x.shape
        x_copy = x
        x_copy = paddle.reshape(x_copy, [batch, channel, height * width])
        x_copy = x_copy.unsqueeze(1)
        if mode.lower() == "student":
            context_mask = self.conv_mask_s(x)
        else:
            context_mask = self.conv_mask_t(x)
        context_mask = paddle.reshape(context_mask, [batch, 1, height * width])
        context_mask = F.softmax(context_mask, axis=2)
        context_mask = context_mask.unsqueeze(-1)
        context = paddle.matmul(x_copy, context_mask)
        context = paddle.reshape(context, [batch, channel, 1, 1])
        return context

    def mask_loss(self, stu_channel_att, tea_channel_att, stu_spatial_att,
                  tea_spatial_att):
        def _func(a, b):
            return paddle.sum(paddle.abs(a - b)) / len(a)

        mask_loss = _func(stu_channel_att, tea_channel_att) + _func(
            stu_spatial_att, tea_spatial_att)
        return mask_loss

    def feature_loss(self, stu_feature, tea_feature, Mask_fg, Mask_bg,
                     tea_channel_att, tea_spatial_att):
        Mask_fg = Mask_fg.unsqueeze(axis=1)
        Mask_bg = Mask_bg.unsqueeze(axis=1)

        tea_channel_att = tea_channel_att.unsqueeze(axis=-1)
        tea_channel_att = tea_channel_att.unsqueeze(axis=-1)

        tea_spatial_att = tea_spatial_att.unsqueeze(axis=1)

        fea_t = paddle.multiply(tea_feature, paddle.sqrt(tea_spatial_att))
        fea_t = paddle.multiply(fea_t, paddle.sqrt(tea_channel_att))
        fg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_fg))
        bg_fea_t = paddle.multiply(fea_t, paddle.sqrt(Mask_bg))

        fea_s = paddle.multiply(stu_feature, paddle.sqrt(tea_spatial_att))
        fea_s = paddle.multiply(fea_s, paddle.sqrt(tea_channel_att))
        fg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_fg))
        bg_fea_s = paddle.multiply(fea_s, paddle.sqrt(Mask_bg))

        fg_loss = F.mse_loss(fg_fea_s, fg_fea_t, reduction="sum") / len(Mask_fg)
        bg_loss = F.mse_loss(bg_fea_s, bg_fea_t, reduction="sum") / len(Mask_bg)
        return fg_loss, bg_loss

    def relation_loss(self, stu_feature, tea_feature):
        context_s = self.spatial_pool(stu_feature, "student")
        context_t = self.spatial_pool(tea_feature, "teacher")

        out_s = stu_feature + self.stu_conv_block(context_s)
        out_t = tea_feature + self.tea_conv_block(context_t)

        rela_loss = F.mse_loss(out_s, out_t, reduction="sum") / len(out_s)
        return rela_loss

    def mask_value(self, mask, xl, xr, yl, yr, value):
        mask[xl:xr, yl:yr] = paddle.maximum(mask[xl:xr, yl:yr], value)
        return mask

    def forward(self, stu_feature, tea_feature, inputs):
        """Forward function.
        Args:
            stu_feature (Tensor): Bs*C*H*W, student's feature map
            tea_feature (Tensor): Bs*C*H*W, teacher's feature map
            inputs: The inputs with gt bbox and input shape info.
        """
        assert stu_feature.shape[-2:] == tea_feature.shape[-2:], \
            f'The shape of Student feature {stu_feature.shape} and Teacher feature {tea_feature.shape} should be the same.'
        assert "gt_bbox" in inputs.keys() and "im_shape" in inputs.keys(
        ), "ERROR! FGDFeatureLoss need gt_bbox and im_shape as inputs."
        gt_bboxes = inputs['gt_bbox']
        ins_shape = [
            inputs['im_shape'][i] for i in range(inputs['im_shape'].shape[0])
        ]

        index_gt = []
        for i in range(len(gt_bboxes)):
            if gt_bboxes[i].size > 2:
                index_gt.append(i)
        # only distill feature with labeled GT box
        if len(index_gt) != len(gt_bboxes):
            index_gt_t = paddle.to_tensor(index_gt)
            stu_feature = paddle.index_select(stu_feature, index_gt_t)
            tea_feature = paddle.index_select(tea_feature, index_gt_t)

            ins_shape = [ins_shape[c] for c in index_gt]
            gt_bboxes = [gt_bboxes[c] for c in index_gt]
            assert len(gt_bboxes) == tea_feature.shape[
                0], f"The number of selected GT box [{len(gt_bboxes)}] should be same with first dim of input tensor [{tea_feature.shape[0]}]."

        if self.align is not None:
            stu_feature = self.align(stu_feature)

        N, C, H, W = stu_feature.shape

        tea_spatial_att, tea_channel_att = self.spatial_channel_attention(
            tea_feature, self.temp)
        stu_spatial_att, stu_channel_att = self.spatial_channel_attention(
            stu_feature, self.temp)

        Mask_fg = paddle.zeros(tea_spatial_att.shape)
        Mask_bg = paddle.ones_like(tea_spatial_att)
        one_tmp = paddle.ones([*tea_spatial_att.shape[1:]])
        zero_tmp = paddle.zeros([*tea_spatial_att.shape[1:]])
        Mask_fg.stop_gradient = True
        Mask_bg.stop_gradient = True
        one_tmp.stop_gradient = True
        zero_tmp.stop_gradient = True

        wmin, wmax, hmin, hmax, area = [], [], [], [], []

        for i in range(N):
            tmp_box = paddle.ones_like(gt_bboxes[i])
            tmp_box.stop_gradient = True
            tmp_box[:, 0] = gt_bboxes[i][:, 0] / ins_shape[i][1] * W
            tmp_box[:, 2] = gt_bboxes[i][:, 2] / ins_shape[i][1] * W
            tmp_box[:, 1] = gt_bboxes[i][:, 1] / ins_shape[i][0] * H
            tmp_box[:, 3] = gt_bboxes[i][:, 3] / ins_shape[i][0] * H

            zero = paddle.zeros_like(tmp_box[:, 0], dtype="int32")
            ones = paddle.ones_like(tmp_box[:, 2], dtype="int32")
            zero.stop_gradient = True
            ones.stop_gradient = True

            wmin.append(
                paddle.cast(paddle.floor(tmp_box[:, 0]), "int32").maximum(zero))
            wmax.append(paddle.cast(paddle.ceil(tmp_box[:, 2]), "int32"))
            hmin.append(
                paddle.cast(paddle.floor(tmp_box[:, 1]), "int32").maximum(zero))
            hmax.append(paddle.cast(paddle.ceil(tmp_box[:, 3]), "int32"))

            area_recip = 1.0 / (
                hmax[i].reshape([1, -1]) + 1 - hmin[i].reshape([1, -1])) / (
                    wmax[i].reshape([1, -1]) + 1 - wmin[i].reshape([1, -1]))

            for j in range(len(gt_bboxes[i])):
                Mask_fg[i] = self.mask_value(Mask_fg[i], hmin[i][j],
                                             hmax[i][j] + 1, wmin[i][j],
                                             wmax[i][j] + 1, area_recip[0][j])

            Mask_bg[i] = paddle.where(Mask_fg[i] > zero_tmp, zero_tmp, one_tmp)
            if paddle.sum(Mask_bg[i]):
                Mask_bg[i] /= paddle.sum(Mask_bg[i])

        fg_loss, bg_loss = self.feature_loss(stu_feature, tea_feature, Mask_fg,
                                             Mask_bg, tea_channel_att,
                                             tea_spatial_att)
        mask_loss = self.mask_loss(stu_channel_att, tea_channel_att,
                                   stu_spatial_att, tea_spatial_att)
        rela_loss = self.relation_loss(stu_feature, tea_feature)

        loss = self.alpha_fgd * fg_loss + self.beta_fgd * bg_loss \
               + self.gamma_fgd * mask_loss + self.lambda_fgd * rela_loss
        return loss

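# A minimal smoke-test sketch (not from the original file; the tensor shapes
# and the dummy `inputs` dict below are illustrative assumptions about the
# expected format, not a documented API):
#
#   loss_fn = FGDFeatureLoss(student_channels=256, teacher_channels=256)
#   stu = paddle.rand([2, 256, 32, 32])
#   tea = paddle.rand([2, 256, 32, 32])
#   inputs = {
#       'gt_bbox': [paddle.to_tensor([[10., 10., 100., 120.]]),
#                   paddle.to_tensor([[30., 40., 200., 180.]])],
#       'im_shape': paddle.to_tensor([[256., 256.], [256., 256.]]),
#   }
#   fgd_loss = loss_fn(stu, tea, inputs)   # scalar distillation loss
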
class LDDistillModel(nn.Layer):
    def __init__(self, cfg, slim_cfg):
        super(LDDistillModel, self).__init__()
        self.student_model = create(cfg.architecture)
        logger.debug('Load student model pretrain_weights:{}'.format(
            cfg.pretrain_weights))
        load_pretrain_weight(self.student_model, cfg.pretrain_weights)

        slim_cfg = load_config(slim_cfg)  # rewrite student cfg
        self.teacher_model = create(slim_cfg.architecture)
        logger.debug('Load teacher model pretrain_weights:{}'.format(
            slim_cfg.pretrain_weights))
        load_pretrain_weight(self.teacher_model, slim_cfg.pretrain_weights)

        for param in self.teacher_model.parameters():
            param.trainable = False

    def parameters(self):
        return self.student_model.parameters()

    def forward(self, inputs):
        if self.training:
            with paddle.no_grad():
                t_body_feats = self.teacher_model.backbone(inputs)
                t_neck_feats = self.teacher_model.neck(t_body_feats)
                t_head_outs = self.teacher_model.head(t_neck_feats)

            # student_loss = self.student_model(inputs)
            s_body_feats = self.student_model.backbone(inputs)
            s_neck_feats = self.student_model.neck(s_body_feats)
            s_head_outs = self.student_model.head(s_neck_feats)

            soft_label_list = t_head_outs[0]
            soft_targets_list = t_head_outs[1]
            student_loss = self.student_model.head.get_loss(
                s_head_outs, inputs, soft_label_list, soft_targets_list)
            total_loss = paddle.add_n(list(student_loss.values()))
            student_loss['loss'] = total_loss
            return student_loss
        else:
            return self.student_model(inputs)

@register
class KnowledgeDistillationKLDivLoss(nn.Layer):
    """Loss function for knowledge distillation using KL divergence.

    Args:
        reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
        loss_weight (float): Loss weight of current loss.
        T (int): Temperature for distillation.
    """

    def __init__(self, reduction='mean', loss_weight=1.0, T=10):
        super(KnowledgeDistillationKLDivLoss, self).__init__()
        assert reduction in ('none', 'mean', 'sum')
        assert T >= 1
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.T = T

    def knowledge_distillation_kl_div_loss(self,
                                           pred,
                                           soft_label,
                                           T,
                                           detach_target=True):
        r"""Loss function for knowledge distillation using KL divergence.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
            T (int): Temperature for distillation.
            detach_target (bool): Remove soft_label from automatic differentiation.

        Returns:
            Tensor: Loss tensor with shape (N,).
        """
        assert pred.shape == soft_label.shape
        target = F.softmax(soft_label / T, axis=1)
        if detach_target:
            target = target.detach()

        kd_loss = F.kl_div(
            F.log_softmax(
                pred / T, axis=1), target, reduction='none').mean(1) * (T * T)
        return kd_loss

    def forward(self,
                pred,
                soft_label,
                weight=None,
                avg_factor=None,
                reduction_override=None):
        """Forward function.

        Args:
            pred (Tensor): Predicted logits with shape (N, n + 1).
            soft_label (Tensor): Target logits with shape (N, n + 1).
            weight (Tensor, optional): The weight of loss for each
                prediction. Defaults to None.
            avg_factor (int, optional): Average factor that is used to average
                the loss. Defaults to None.
            reduction_override (str, optional): The reduction method used to
                override the original reduction method of the loss.
                Defaults to None.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (reduction_override
                     if reduction_override else self.reduction)

        loss_kd_out = self.knowledge_distillation_kl_div_loss(
            pred, soft_label, T=self.T)

        if weight is not None:
            loss_kd_out = weight * loss_kd_out

        if avg_factor is None:
            if reduction == 'none':
                loss = loss_kd_out
            elif reduction == 'mean':
                loss = loss_kd_out.mean()
            elif reduction == 'sum':
                loss = loss_kd_out.sum()
        else:
            # if reduction is 'mean', average the loss by avg_factor
            if reduction == 'mean':
                loss = loss_kd_out.sum() / avg_factor
            # if reduction is 'none', keep the per-element loss;
            # otherwise ('sum') avg_factor is not supported
            elif reduction == 'none':
                loss = loss_kd_out
            else:
                raise ValueError(
                    'avg_factor can not be used with reduction="sum"')

        loss_kd = self.loss_weight * loss
        return loss_kd
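

# A minimal usage sketch (illustrative, not from the original file): both logit
# sets are softened with temperature T, the per-class KL terms are averaged for
# each row, and the result is rescaled by T*T (the usual KD temperature
# correction) before the configured reduction is applied.
#
#   kd = KnowledgeDistillationKLDivLoss(reduction='mean', loss_weight=1.0, T=10)
#   student_logits = paddle.rand([8, 81])
#   teacher_logits = paddle.rand([8, 81])
#   loss = kd(student_logits, teacher_logits)   # scalar tensor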