callbacks.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. import datetime
  20. import six
  21. import copy
  22. import json
  23. import paddle
  24. import paddle.distributed as dist
  25. from ppdet.utils.checkpoint import save_model
  26. from ppdet.metrics import get_infer_results
  27. from ppdet.utils.logger import setup_logger
  28. logger = setup_logger('ppdet.engine')
# Public names of this module.
# NOTE(review): WiferFaceEval and WandbCallback are defined in this file but
# not exported here — presumably intentional; confirm before relying on
# `from ppdet.engine.callbacks import *`.
__all__ = [
    'Callback', 'ComposeCallback', 'LogPrinter', 'Checkpointer',
    'VisualDLWriter', 'SniperProposalsGenerator'
]
  33. class Callback(object):
  34. def __init__(self, model):
  35. self.model = model
  36. def on_step_begin(self, status):
  37. pass
  38. def on_step_end(self, status):
  39. pass
  40. def on_epoch_begin(self, status):
  41. pass
  42. def on_epoch_end(self, status):
  43. pass
  44. def on_train_begin(self, status):
  45. pass
  46. def on_train_end(self, status):
  47. pass
  48. class ComposeCallback(object):
  49. def __init__(self, callbacks):
  50. callbacks = [c for c in list(callbacks) if c is not None]
  51. for c in callbacks:
  52. assert isinstance(
  53. c, Callback), "callback should be subclass of Callback"
  54. self._callbacks = callbacks
  55. def on_step_begin(self, status):
  56. for c in self._callbacks:
  57. c.on_step_begin(status)
  58. def on_step_end(self, status):
  59. for c in self._callbacks:
  60. c.on_step_end(status)
  61. def on_epoch_begin(self, status):
  62. for c in self._callbacks:
  63. c.on_epoch_begin(status)
  64. def on_epoch_end(self, status):
  65. for c in self._callbacks:
  66. c.on_epoch_end(status)
  67. def on_train_begin(self, status):
  68. for c in self._callbacks:
  69. c.on_train_begin(status)
  70. def on_train_end(self, status):
  71. for c in self._callbacks:
  72. c.on_train_end(status)
class LogPrinter(Callback):
    """Print periodic training progress and eval progress to the logger.

    All output is restricted to the master rank in distributed runs.
    """

    def __init__(self, model):
        super(LogPrinter, self).__init__(model)

    def on_step_end(self, status):
        # Log only from rank 0 (or when not running distributed).
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            mode = status['mode']
            if mode == 'train':
                epoch_id = status['epoch_id']
                step_id = status['step_id']
                steps_per_epoch = status['steps_per_epoch']
                # NOTE: 'training_staus' (sic, misspelled) is the key the
                # trainer fills; it holds the smoothed loss/metric meters.
                training_staus = status['training_staus']
                batch_time = status['batch_time']
                data_time = status['data_time']

                epoches = self.model.cfg.epoch
                # Reader config key is e.g. 'TrainReader' for mode 'train'.
                batch_size = self.model.cfg['{}Reader'.format(mode.capitalize(
                ))]['batch_size']

                logs = training_staus.log()
                # Pad the step id to the width of steps_per_epoch so the
                # "[step/steps]" column stays aligned across iterations.
                space_fmt = ':' + str(len(str(steps_per_epoch))) + 'd'
                if step_id % self.model.cfg.log_iter == 0:
                    # ETA = remaining steps across all epochs * global avg
                    # step time.
                    eta_steps = (epoches - epoch_id) * steps_per_epoch - step_id
                    eta_sec = eta_steps * batch_time.global_avg
                    eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
                    # Throughput in images/s, from the recent-window average.
                    ips = float(batch_size) / batch_time.avg
                    fmt = ' '.join([
                        'Epoch: [{}]',
                        '[{' + space_fmt + '}/{}]',
                        'learning_rate: {lr:.6f}',
                        '{meters}',
                        'eta: {eta}',
                        'batch_cost: {btime}',
                        'data_cost: {dtime}',
                        'ips: {ips:.4f} images/s',
                    ])
                    fmt = fmt.format(
                        epoch_id,
                        step_id,
                        steps_per_epoch,
                        lr=status['learning_rate'],
                        meters=logs,
                        eta=eta_str,
                        btime=str(batch_time),
                        dtime=str(data_time),
                        ips=ips)
                    logger.info(fmt)
            if mode == 'eval':
                # Lightweight heartbeat during evaluation.
                step_id = status['step_id']
                if step_id % 100 == 0:
                    logger.info("Eval iter: {}".format(step_id))

    def on_epoch_end(self, status):
        # Summarize eval throughput once the epoch's evaluation finishes.
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            mode = status['mode']
            if mode == 'eval':
                sample_num = status['sample_num']
                cost_time = status['cost_time']
                logger.info('Total sample number: {}, average FPS: {}'.format(
                    sample_num, sample_num / cost_time))
class Checkpointer(Callback):
    """Save model checkpoints: periodic snapshots during training and a
    'best_model' copy when evaluation improves.

    Only the master rank writes checkpoints.
    """

    def __init__(self, model):
        super(Checkpointer, self).__init__(model)
        # Sentinel low enough that the first eval always becomes "best".
        self.best_ap = -1000.
        self.save_dir = os.path.join(self.model.cfg.save_dir,
                                     self.model.cfg.filename)
        # For distillation setups, checkpoint the student model only.
        if hasattr(self.model.model, 'student_model'):
            self.weight = self.model.model.student_model
        else:
            self.weight = self.model.model

    def on_epoch_end(self, status):
        # Checkpointer only performed during training
        mode = status['mode']
        epoch_id = status['epoch_id']
        weight = None
        save_name = None
        if dist.get_world_size() < 2 or dist.get_rank() == 0:
            if mode == 'train':
                end_epoch = self.model.cfg.epoch
                # Snapshot every `snapshot_epoch` epochs and always at the
                # final epoch (saved as "model_final").
                if (
                        epoch_id + 1
                ) % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
                    save_name = str(
                        epoch_id) if epoch_id != end_epoch - 1 else "model_final"
                    weight = self.weight.state_dict()
            elif mode == 'eval':
                if 'save_best_model' in status and status['save_best_model']:
                    for metric in self.model._metrics:
                        map_res = metric.get_results()
                        eval_func = "ap"
                        # Pick the first matching result key by priority.
                        if 'pose3d' in map_res:
                            key = 'pose3d'
                            eval_func = "mpjpe"
                        elif 'bbox' in map_res:
                            key = 'bbox'
                        elif 'keypoint' in map_res:
                            key = 'keypoint'
                        else:
                            key = 'mask'
                        if key not in map_res:
                            logger.warning("Evaluation results empty, this may be due to " \
                                           "training iterations being too few or not " \
                                           "loading the correct weights.")
                            return
                        # NOTE(review): a ">=" compare is used even for
                        # mpjpe (lower is better); presumably the pose3d
                        # metric reports a negated value (abs() is used when
                        # logging) — confirm against the metric class.
                        if map_res[key][0] >= self.best_ap:
                            self.best_ap = map_res[key][0]
                            save_name = 'best_model'
                            weight = self.weight.state_dict()
                        logger.info("Best test {} {} is {:0.3f}.".format(
                            key, eval_func, abs(self.best_ap)))
            if weight:
                if self.model.use_ema:
                    exchange_save_model = status.get('exchange_save_model',
                                                     False)
                    # NOTE(review): with EMA enabled, `weight` above is the
                    # EMA state and status['weight'] the raw model state —
                    # verify against the trainer that fills status.
                    if not exchange_save_model:
                        # save model and ema_model
                        save_model(
                            status['weight'],
                            self.model.optimizer,
                            self.save_dir,
                            save_name,
                            epoch_id + 1,
                            ema_model=weight)
                    else:
                        # save model(student model) and ema_model(teacher model)
                        # in DenseTeacher SSOD, the teacher model will be higher,
                        # so exchange when saving pdparams
                        student_model = status['weight']  # model
                        teacher_model = weight  # ema_model
                        save_model(
                            teacher_model,
                            self.model.optimizer,
                            self.save_dir,
                            save_name,
                            epoch_id + 1,
                            ema_model=student_model)
                        del teacher_model
                        del student_model
                else:
                    save_model(weight, self.model.optimizer, self.save_dir,
                               save_name, epoch_id + 1)
  210. class WiferFaceEval(Callback):
  211. def __init__(self, model):
  212. super(WiferFaceEval, self).__init__(model)
  213. def on_epoch_begin(self, status):
  214. assert self.model.mode == 'eval', \
  215. "WiferFaceEval can only be set during evaluation"
  216. for metric in self.model._metrics:
  217. metric.update(self.model.model)
  218. sys.exit()
  219. class VisualDLWriter(Callback):
  220. """
  221. Use VisualDL to log data or image
  222. """
  223. def __init__(self, model):
  224. super(VisualDLWriter, self).__init__(model)
  225. assert six.PY3, "VisualDL requires Python >= 3.5"
  226. try:
  227. from visualdl import LogWriter
  228. except Exception as e:
  229. logger.error('visualdl not found, plaese install visualdl. '
  230. 'for example: `pip install visualdl`.')
  231. raise e
  232. self.vdl_writer = LogWriter(
  233. model.cfg.get('vdl_log_dir', 'vdl_log_dir/scalar'))
  234. self.vdl_loss_step = 0
  235. self.vdl_mAP_step = 0
  236. self.vdl_image_step = 0
  237. self.vdl_image_frame = 0
  238. def on_step_end(self, status):
  239. mode = status['mode']
  240. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  241. if mode == 'train':
  242. training_staus = status['training_staus']
  243. for loss_name, loss_value in training_staus.get().items():
  244. self.vdl_writer.add_scalar(loss_name, loss_value,
  245. self.vdl_loss_step)
  246. self.vdl_loss_step += 1
  247. elif mode == 'test':
  248. ori_image = status['original_image']
  249. result_image = status['result_image']
  250. self.vdl_writer.add_image(
  251. "original/frame_{}".format(self.vdl_image_frame), ori_image,
  252. self.vdl_image_step)
  253. self.vdl_writer.add_image(
  254. "result/frame_{}".format(self.vdl_image_frame),
  255. result_image, self.vdl_image_step)
  256. self.vdl_image_step += 1
  257. # each frame can display ten pictures at most.
  258. if self.vdl_image_step % 10 == 0:
  259. self.vdl_image_step = 0
  260. self.vdl_image_frame += 1
  261. def on_epoch_end(self, status):
  262. mode = status['mode']
  263. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  264. if mode == 'eval':
  265. for metric in self.model._metrics:
  266. for key, map_value in metric.get_results().items():
  267. self.vdl_writer.add_scalar("{}-mAP".format(key),
  268. map_value[0],
  269. self.vdl_mAP_step)
  270. self.vdl_mAP_step += 1
  271. class WandbCallback(Callback):
  272. def __init__(self, model):
  273. super(WandbCallback, self).__init__(model)
  274. try:
  275. import wandb
  276. self.wandb = wandb
  277. except Exception as e:
  278. logger.error('wandb not found, please install wandb. '
  279. 'Use: `pip install wandb`.')
  280. raise e
  281. self.wandb_params = model.cfg.get('wandb', None)
  282. self.save_dir = os.path.join(self.model.cfg.save_dir,
  283. self.model.cfg.filename)
  284. if self.wandb_params is None:
  285. self.wandb_params = {}
  286. for k, v in model.cfg.items():
  287. if k.startswith("wandb_"):
  288. self.wandb_params.update({k.lstrip("wandb_"): v})
  289. self._run = None
  290. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  291. _ = self.run
  292. self.run.config.update(self.model.cfg)
  293. self.run.define_metric("epoch")
  294. self.run.define_metric("eval/*", step_metric="epoch")
  295. self.best_ap = -1000.
  296. self.fps = []
  297. @property
  298. def run(self):
  299. if self._run is None:
  300. if self.wandb.run is not None:
  301. logger.info(
  302. "There is an ongoing wandb run which will be used"
  303. "for logging. Please use `wandb.finish()` to end that"
  304. "if the behaviour is not intended")
  305. self._run = self.wandb.run
  306. else:
  307. self._run = self.wandb.init(**self.wandb_params)
  308. return self._run
  309. def save_model(self,
  310. optimizer,
  311. save_dir,
  312. save_name,
  313. last_epoch,
  314. ema_model=None,
  315. ap=None,
  316. fps=None,
  317. tags=None):
  318. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  319. model_path = os.path.join(save_dir, save_name)
  320. metadata = {}
  321. metadata["last_epoch"] = last_epoch
  322. if ap:
  323. metadata["ap"] = ap
  324. if fps:
  325. metadata["fps"] = fps
  326. if ema_model is None:
  327. ema_artifact = self.wandb.Artifact(
  328. name="ema_model-{}".format(self.run.id),
  329. type="model",
  330. metadata=metadata)
  331. model_artifact = self.wandb.Artifact(
  332. name="model-{}".format(self.run.id),
  333. type="model",
  334. metadata=metadata)
  335. ema_artifact.add_file(model_path + ".pdema", name="model_ema")
  336. model_artifact.add_file(model_path + ".pdparams", name="model")
  337. self.run.log_artifact(ema_artifact, aliases=tags)
  338. self.run.log_artfact(model_artifact, aliases=tags)
  339. else:
  340. model_artifact = self.wandb.Artifact(
  341. name="model-{}".format(self.run.id),
  342. type="model",
  343. metadata=metadata)
  344. model_artifact.add_file(model_path + ".pdparams", name="model")
  345. self.run.log_artifact(model_artifact, aliases=tags)
  346. def on_step_end(self, status):
  347. mode = status['mode']
  348. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  349. if mode == 'train':
  350. training_status = status['training_staus'].get()
  351. for k, v in training_status.items():
  352. training_status[k] = float(v)
  353. # calculate ips, data_cost, batch_cost
  354. batch_time = status['batch_time']
  355. data_time = status['data_time']
  356. batch_size = self.model.cfg['{}Reader'.format(mode.capitalize(
  357. ))]['batch_size']
  358. ips = float(batch_size) / float(batch_time.avg)
  359. data_cost = float(data_time.avg)
  360. batch_cost = float(batch_time.avg)
  361. metrics = {"train/" + k: v for k, v in training_status.items()}
  362. metrics["train/ips"] = ips
  363. metrics["train/data_cost"] = data_cost
  364. metrics["train/batch_cost"] = batch_cost
  365. self.fps.append(ips)
  366. self.run.log(metrics)
  367. def on_epoch_end(self, status):
  368. mode = status['mode']
  369. epoch_id = status['epoch_id']
  370. save_name = None
  371. if dist.get_world_size() < 2 or dist.get_rank() == 0:
  372. if mode == 'train':
  373. fps = sum(self.fps) / len(self.fps)
  374. self.fps = []
  375. end_epoch = self.model.cfg.epoch
  376. if (
  377. epoch_id + 1
  378. ) % self.model.cfg.snapshot_epoch == 0 or epoch_id == end_epoch - 1:
  379. save_name = str(
  380. epoch_id) if epoch_id != end_epoch - 1 else "model_final"
  381. tags = ["latest", "epoch_{}".format(epoch_id)]
  382. self.save_model(
  383. self.model.optimizer,
  384. self.save_dir,
  385. save_name,
  386. epoch_id + 1,
  387. self.model.use_ema,
  388. fps=fps,
  389. tags=tags)
  390. if mode == 'eval':
  391. sample_num = status['sample_num']
  392. cost_time = status['cost_time']
  393. fps = sample_num / cost_time
  394. merged_dict = {}
  395. for metric in self.model._metrics:
  396. for key, map_value in metric.get_results().items():
  397. merged_dict["eval/{}-mAP".format(key)] = map_value[0]
  398. merged_dict["epoch"] = status["epoch_id"]
  399. merged_dict["eval/fps"] = sample_num / cost_time
  400. self.run.log(merged_dict)
  401. if 'save_best_model' in status and status['save_best_model']:
  402. for metric in self.model._metrics:
  403. map_res = metric.get_results()
  404. if 'pose3d' in map_res:
  405. key = 'pose3d'
  406. elif 'bbox' in map_res:
  407. key = 'bbox'
  408. elif 'keypoint' in map_res:
  409. key = 'keypoint'
  410. else:
  411. key = 'mask'
  412. if key not in map_res:
  413. logger.warning("Evaluation results empty, this may be due to " \
  414. "training iterations being too few or not " \
  415. "loading the correct weights.")
  416. return
  417. if map_res[key][0] >= self.best_ap:
  418. self.best_ap = map_res[key][0]
  419. save_name = 'best_model'
  420. tags = ["best", "epoch_{}".format(epoch_id)]
  421. self.save_model(
  422. self.model.optimizer,
  423. self.save_dir,
  424. save_name,
  425. last_epoch=epoch_id + 1,
  426. ema_model=self.model.use_ema,
  427. ap=abs(self.best_ap),
  428. fps=fps,
  429. tags=tags)
  430. def on_train_end(self, status):
  431. self.run.finish()
  432. class SniperProposalsGenerator(Callback):
  433. def __init__(self, model):
  434. super(SniperProposalsGenerator, self).__init__(model)
  435. ori_dataset = self.model.dataset
  436. self.dataset = self._create_new_dataset(ori_dataset)
  437. self.loader = self.model.loader
  438. self.cfg = self.model.cfg
  439. self.infer_model = self.model.model
  440. def _create_new_dataset(self, ori_dataset):
  441. dataset = copy.deepcopy(ori_dataset)
  442. # init anno_cropper
  443. dataset.init_anno_cropper()
  444. # generate infer roidbs
  445. ori_roidbs = dataset.get_ori_roidbs()
  446. roidbs = dataset.anno_cropper.crop_infer_anno_records(ori_roidbs)
  447. # set new roidbs
  448. dataset.set_roidbs(roidbs)
  449. return dataset
  450. def _eval_with_loader(self, loader):
  451. results = []
  452. with paddle.no_grad():
  453. self.infer_model.eval()
  454. for step_id, data in enumerate(loader):
  455. outs = self.infer_model(data)
  456. for key in ['im_shape', 'scale_factor', 'im_id']:
  457. outs[key] = data[key]
  458. for key, value in outs.items():
  459. if hasattr(value, 'numpy'):
  460. outs[key] = value.numpy()
  461. results.append(outs)
  462. return results
  463. def on_train_end(self, status):
  464. self.loader.dataset = self.dataset
  465. results = self._eval_with_loader(self.loader)
  466. results = self.dataset.anno_cropper.aggregate_chips_detections(results)
  467. # sniper
  468. proposals = []
  469. clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()}
  470. for outs in results:
  471. batch_res = get_infer_results(outs, clsid2catid)
  472. start = 0
  473. for i, im_id in enumerate(outs['im_id']):
  474. bbox_num = outs['bbox_num']
  475. end = start + bbox_num[i]
  476. bbox_res = batch_res['bbox'][start:end] \
  477. if 'bbox' in batch_res else None
  478. if bbox_res:
  479. proposals += bbox_res
  480. logger.info("save proposals in {}".format(self.cfg.proposals_path))
  481. with open(self.cfg.proposals_path, 'w') as f:
  482. json.dump(proposals, f)