trainer_ssod.py

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import copy
import time
import typing
import math
import numpy as np

import paddle
import paddle.nn as nn
import paddle.distributed as dist
from paddle.distributed import fleet
from ppdet.optimizer import ModelEMA, SimpleModelEMA
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
import ppdet.utils.stats as stats
from ppdet.utils import profiler
from ppdet.modeling.ssod_utils import align_weak_strong_shape
from .trainer import Trainer

from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')

__all__ = ['Trainer_DenseTeacher']


class Trainer_DenseTeacher(Trainer):
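    """
    Trainer for DenseTeacher semi-supervised object detection (SSOD).

    Each step runs a supervised branch on labeled data and, after
    `semi_start_iters`, an unsupervised branch in which an EMA teacher
    produces dense pseudo targets on weakly augmented unlabeled images
    that are distilled into the student on strongly augmented views.
    """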
    def __init__(self, cfg, mode='train'):
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
            "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False
        self.use_amp = self.cfg.get('amp', False)
        self.amp_level = self.cfg.get('amp_level', 'O1')
        self.custom_white_list = self.cfg.get('custom_white_list', None)
        self.custom_black_list = self.cfg.get('custom_black_list', None)

        # build data loader
        capital_mode = self.mode.capitalize()
        self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create(
            '{}Dataset'.format(capital_mode))()

        if self.mode == 'train':
            self.dataset_unlabel = self.cfg['UnsupTrainDataset'] = create(
                'UnsupTrainDataset')
            self.loader = create('SemiTrainReader')(
                self.dataset, self.dataset_unlabel, cfg.worker_num)
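            # SemiTrainReader pairs the labeled and unlabeled datasets and
            # yields (sup_weak, sup_strong, unsup_weak, unsup_strong) batches,
            # which are unpacked per step in train() below.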
        # build model
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            self.model = self.cfg.model
            self.is_loaded_weights = True

        # EvalDataset is built with a BatchSampler to evaluate on a single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            self._eval_batch_sampler = paddle.io.BatchSampler(
                self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
            # If metric is VOC, collate_batch needs to be set to False.
            if cfg.metric == 'VOC':
                cfg['EvalReader']['collate_batch'] = False
            self.loader = create('EvalReader')(self.dataset, cfg.worker_num,
                                               self._eval_batch_sampler)
        # TestDataset is built after the user sets images, so skip loader creation here

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            if steps_per_epoch < 1:
                logger.warning(
                    "Samples in dataset are less than batch_size, please set smaller batch_size in TrainReader."
                )
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

            # Unstructured pruner is only enabled in the train mode.
            if self.cfg.get('unstructured_prune'):
                self.pruner = create('UnstructuredPruner')(self.model,
                                                           steps_per_epoch)
        if self.use_amp and self.amp_level == 'O2':
            self.model, self.optimizer = paddle.amp.decorate(
                models=self.model,
                optimizers=self.optimizer,
                level=self.amp_level)

        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            ema_black_list = self.cfg.get('ema_black_list', None)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                ema_decay_type=ema_decay_type,
                cycle_epoch=cycle_epoch,
                ema_black_list=ema_black_list)
            self.ema_start_iters = self.cfg.get('ema_start_iters', 0)

        # simple_ema for SSOD
        self.use_simple_ema = ('use_simple_ema' in cfg and
                               cfg['use_simple_ema'])
        if self.use_simple_ema:
            self.use_ema = True
            ema_decay = self.cfg.get('ema_decay', 0.9996)
            self.ema = SimpleModelEMA(self.model, decay=ema_decay)
            self.ema_start_iters = self.cfg.get('ema_start_iters', 0)
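        # During SSOD training, self.ema.model acts as the teacher and
        # self.model as the student; the teacher is only updated via EMA.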
        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        self.status = {}
        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initialize default callbacks
        self._init_callbacks()

        # initialize default metrics
        self._init_metrics()
        self._reset_metrics()
    def load_weights(self, weights):
        if self.is_loaded_weights:
            return
        self.start_epoch = 0
        load_pretrain_weight(self.model, weights)
        load_pretrain_weight(self.ema.model, weights)
        logger.info("Load weights {} to start training for teacher and student".
                    format(weights))

    def resume_weights(self, weights, exchange=True):
        # support Distill resume weights
        if hasattr(self.model, 'student_model'):
            self.start_epoch = load_weight(self.model.student_model, weights,
                                           self.optimizer, exchange)
        else:
            self.start_epoch = load_weight(self.model, weights, self.optimizer,
                                           self.ema
                                           if self.use_ema else None, exchange)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))
    def train(self, validate=False):
        self.semi_start_iters = self.cfg.get('semi_start_iters', 5000)
        Init_mark = False
        if validate:
            self.cfg['EvalDataset'] = self.cfg.EvalDataset = create(
                "EvalDataset")()

        sync_bn = (getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
                   self.cfg.use_gpu and self._nranks > 1)
        if sync_bn:
            self.model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(
                self.model)

        if self.cfg.get('fleet', False):
            self.model = fleet.distributed_model(self.model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            self.model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)
            self.ema.model = paddle.DataParallel(
                self.ema.model, find_unused_parameters=find_unused_parameters)

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader),
            'exchange_save_model': True,
        })
        # Note: exchange_save_model
        # In DenseTeacher SSOD the teacher model usually reaches higher accuracy,
        # so teacher and student weights are exchanged when saving pdparams.

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num)
            self._flops(flops_loader)
        profiler_options = self.cfg.get('profiler_options', None)
        self._compose_callback.on_train_begin(self.status)

        train_cfg = self.cfg.DenseTeacher['train_cfg']
        concat_sup_data = train_cfg.get('concat_sup_data', True)
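        # When concat_sup_data is True, the weak and strong supervised batches
        # are concatenated and pass through the student in one forward;
        # otherwise two forwards are run and their losses averaged.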
        for param in self.ema.model.parameters():
            param.stop_gradient = True

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset_label.set_epoch(epoch_id)
            self.loader.dataset_unlabel.set_epoch(epoch_id)
            iter_tic = time.time()

            loss_dict = {
                'loss': paddle.to_tensor([0]),
                'loss_sup_sum': paddle.to_tensor([0]),
                'loss_unsup_sum': paddle.to_tensor([0]),
                'fg_sum': paddle.to_tensor([0]),
            }
            if self._nranks > 1:
                for k in self.model._layers.get_loss_keys():
                    loss_dict.update({k: paddle.to_tensor([0.])})
                for k in self.model._layers.get_loss_keys():
                    loss_dict.update({'distill_' + k: paddle.to_tensor([0.])})
            else:
                for k in self.model.get_loss_keys():
                    loss_dict.update({k: paddle.to_tensor([0.])})
                for k in self.model.get_loss_keys():
                    loss_dict.update({'distill_' + k: paddle.to_tensor([0.])})

            # Note: for step_id, data in enumerate(self.loader): # enumerate bug
            for step_id in range(len(self.loader)):
                data = next(self.loader)

                self.model.train()
                self.ema.model.eval()
                data_sup_w, data_sup_s, data_unsup_w, data_unsup_s = data

                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)

                if data_sup_w['image'].shape != data_sup_s['image'].shape:
                    data_sup_w, data_sup_s = align_weak_strong_shape(data_sup_w,
                                                                     data_sup_s)

                data_sup_w['epoch_id'] = epoch_id
                data_sup_s['epoch_id'] = epoch_id

                if concat_sup_data:
                    for k, v in data_sup_s.items():
                        if k in ['epoch_id']:
                            continue
                        data_sup_s[k] = paddle.concat([v, data_sup_w[k]])
                    loss_dict_sup = self.model(data_sup_s)
                else:
                    loss_dict_sup_w = self.model(data_sup_w)
                    loss_dict_sup = self.model(data_sup_s)
                    for k, v in loss_dict_sup_w.items():
                        loss_dict_sup[k] = (loss_dict_sup[k] + v) * 0.5

                losses_sup = loss_dict_sup['loss'] * train_cfg['sup_weight']
                losses_sup.backward()

                losses = losses_sup.detach()
                loss_dict.update(loss_dict_sup)
                loss_dict.update({'loss_sup_sum': loss_dict['loss']})

                curr_iter = len(self.loader) * epoch_id + step_id
                st_iter = self.semi_start_iters
                if curr_iter == st_iter:
                    logger.info("***" * 30)
                    logger.info('Semi starting ...')
                    logger.info("***" * 30)
                if curr_iter > st_iter:
                    unsup_weight = train_cfg['unsup_weight']
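                    # Warm up the unsupervised weight after semi training starts:
                    # 'linear' scales it by (curr_iter - st_iter) / st_iter until 2 * st_iter,
                    # 'exp' scales it by exp((curr_iter - tar_iter) / 1000) for the first 2000 iters,
                    # 'step' uses a fixed 0.25 factor until 2 * st_iter.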
                    if train_cfg['suppress'] == 'linear':
                        tar_iter = st_iter * 2
                        if curr_iter <= tar_iter:
                            unsup_weight *= (curr_iter - st_iter) / st_iter
                    elif train_cfg['suppress'] == 'exp':
                        tar_iter = st_iter + 2000
                        if curr_iter <= tar_iter:
                            scale = np.exp((curr_iter - tar_iter) / 1000)
                            unsup_weight *= scale
                    elif train_cfg['suppress'] == 'step':
                        tar_iter = st_iter * 2
                        if curr_iter <= tar_iter:
                            unsup_weight *= 0.25
                    else:
                        raise ValueError

                    if data_unsup_w['image'].shape != data_unsup_s[
                            'image'].shape:
                        data_unsup_w, data_unsup_s = align_weak_strong_shape(
                            data_unsup_w, data_unsup_s)

                    data_unsup_w['epoch_id'] = epoch_id
                    data_unsup_s['epoch_id'] = epoch_id
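                    # Unsupervised (distillation) branch: the student runs on the
                    # strongly augmented unlabeled batch while the EMA teacher runs
                    # under no_grad on the weakly augmented view; the 'get_data' and
                    # 'is_teacher' flags are consumed inside the detector so that it
                    # returns dense predictions for distillation instead of losses.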
                    data_unsup_s['get_data'] = True
                    student_preds = self.model(data_unsup_s)

                    with paddle.no_grad():
                        data_unsup_w['is_teacher'] = True
                        teacher_preds = self.ema.model(data_unsup_w)

                    if self._nranks > 1:
                        loss_dict_unsup = self.model._layers.get_distill_loss(
                            student_preds,
                            teacher_preds,
                            ratio=train_cfg['ratio'])
                    else:
                        loss_dict_unsup = self.model.get_distill_loss(
                            student_preds,
                            teacher_preds,
                            ratio=train_cfg['ratio'])
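                    # 'fg_sum' returned by get_distill_loss is only a foreground
                    # count kept for logging; drop it before applying the per-term
                    # distill loss weights below.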
                    fg_num = loss_dict_unsup["fg_sum"]
                    del loss_dict_unsup["fg_sum"]
                    distill_weights = train_cfg['loss_weight']
                    loss_dict_unsup = {
                        k: v * distill_weights[k]
                        for k, v in loss_dict_unsup.items()
                    }

                    losses_unsup = sum([
                        metrics_value
                        for metrics_value in loss_dict_unsup.values()
                    ]) * unsup_weight
                    losses_unsup.backward()

                    loss_dict.update(loss_dict_unsup)
                    loss_dict.update({'loss_unsup_sum': losses_unsup})
                    losses += losses_unsup.detach()
                    loss_dict.update({"fg_sum": fg_num})
                    loss_dict['loss'] = losses

                self.optimizer.step()
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr
                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(loss_dict)
                    self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)

                # Note: ema_start_iters
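                # At ema_start_iters the teacher is hard-copied from the student
                # (decay=0); after that it tracks the student as an exponential
                # moving average.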
                if self.use_ema and curr_iter == self.ema_start_iters:
                    logger.info("***" * 30)
                    logger.info('EMA starting ...')
                    logger.info("***" * 30)
                    self.ema.update(self.model, decay=0)
                elif self.use_ema and curr_iter > self.ema_start_iters:
                    self.ema.update(self.model)
                iter_tic = time.time()

            is_snapshot = (self._nranks < 2 or self._local_rank == 0) \
                   and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 or epoch_id == self.end_epoch - 1)
            if is_snapshot and self.use_ema:
                # apply ema weight on model
                weight = copy.deepcopy(self.ema.model.state_dict())
                for k, v in weight.items():
                    if paddle.is_floating_point(v):
                        weight[k].stop_gradient = True
                self.status['weight'] = weight
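                # status['weight'] is picked up by the checkpoint callback, so the
                # EMA (teacher) weights are the ones written out for this snapshot.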
            self._compose_callback.on_epoch_end(self.status)

            if validate and is_snapshot:
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    # If metric is VOC, collate_batch needs to be set to False.
                    if self.cfg.metric == 'VOC':
                        self.cfg['EvalReader']['collate_batch'] = False
                    self._eval_loader = create('EvalReader')(
                        self._eval_dataset,
                        self.cfg.worker_num,
                        batch_sampler=self._eval_batch_sampler)
                # if validation in training is enabled, metrics should be re-initialized;
                # Init_mark makes sure this code will only execute once
                if validate and Init_mark == False:
                    Init_mark = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()

                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            if is_snapshot and self.use_ema:
                self.status.pop('weight')

        self._compose_callback.on_train_end(self.status)
    def evaluate(self):
        # get distributed model
        if self.cfg.get('fleet', False):
            self.model = fleet.distributed_model(self.model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            self.model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)
        with paddle.no_grad():
            self._eval_with_loader(self.loader)

    def _eval_with_loader(self, loader):
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        self.status['mode'] = 'eval'

        test_cfg = self.cfg.DenseTeacher['test_cfg']
        if test_cfg['inference_on'] == 'teacher':
            logger.info("***** teacher model evaluating *****")
            eval_model = self.ema.model
        else:
            logger.info("***** student model evaluating *****")
            eval_model = self.model

        eval_model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
            self._flops(flops_loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            if self.use_amp:
                with paddle.amp.auto_cast(
                        enable=self.cfg.use_gpu or self.cfg.use_mlu,
                        custom_white_list=self.custom_white_list,
                        custom_black_list=self.custom_black_list,
                        level=self.amp_level):
                    outs = eval_model(data)
            else:
                outs = eval_model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            # multi-scale inputs: all inputs have same im_id
            if isinstance(data, typing.Sequence):
                sample_num += data[0]['im_id'].numpy().shape[0]
            else:
                sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metric to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states, as the metrics may be evaluated multiple times
        self._reset_metrics()