trainer.py

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import sys
import copy
import time
import typing

import numpy as np
from tqdm import tqdm
from PIL import Image, ImageOps, ImageFile

ImageFile.LOAD_TRUNCATED_IMAGES = True

import paddle
import paddle.nn as nn
import paddle.distributed as dist
from paddle.distributed import fleet
from paddle.static import InputSpec
from paddle.distributed.fleet.utils.hybrid_parallel_util import fused_allreduce_gradients

from ppdet.optimizer import ModelEMA
from ppdet.core.workspace import create
from ppdet.utils.checkpoint import load_weight, load_pretrain_weight
from ppdet.utils.visualizer import visualize_results, save_result
from ppdet.metrics import Metric, COCOMetric, VOCMetric, WiderFaceMetric, get_infer_results, KeyPointTopDownCOCOEval, KeyPointTopDownMPIIEval, Pose3DEval
from ppdet.metrics import RBoxMetric, JDEDetMetric, SNIPERCOCOMetric
from ppdet.data.source.sniper_coco import SniperCOCODataSet
from ppdet.data.source.category import get_categories
import ppdet.utils.stats as stats
from ppdet.utils.fuse_utils import fuse_conv_bn
from ppdet.utils import profiler
from ppdet.modeling.post_process import multiclass_nms

from .callbacks import Callback, ComposeCallback, LogPrinter, Checkpointer, WiferFaceEval, VisualDLWriter, SniperProposalsGenerator, WandbCallback
from .export_utils import _dump_infer_config, _prune_input_spec, apply_to_static

from ppdet.utils.logger import setup_logger
logger = setup_logger('ppdet.engine')

__all__ = ['Trainer']

MOT_ARCH = ['JDE', 'FairMOT', 'DeepSORT', 'ByteTrack', 'CenterTrack']


class Trainer(object):
    def __init__(self, cfg, mode='train'):
        self.cfg = cfg
        assert mode.lower() in ['train', 'eval', 'test'], \
            "mode should be 'train', 'eval' or 'test'"
        self.mode = mode.lower()
        self.optimizer = None
        self.is_loaded_weights = False
        self.use_amp = self.cfg.get('amp', False)
        self.amp_level = self.cfg.get('amp_level', 'O1')
        self.custom_white_list = self.cfg.get('custom_white_list', None)
        self.custom_black_list = self.cfg.get('custom_black_list', None)

        # build data loader
        capital_mode = self.mode.capitalize()
        if cfg.architecture in MOT_ARCH and self.mode in [
                'eval', 'test'
        ] and cfg.metric not in ['COCO', 'VOC']:
            self.dataset = self.cfg['{}MOTDataset'.format(
                capital_mode)] = create('{}MOTDataset'.format(capital_mode))()
        else:
            self.dataset = self.cfg['{}Dataset'.format(capital_mode)] = create(
                '{}Dataset'.format(capital_mode))()

        if cfg.architecture == 'DeepSORT' and self.mode == 'train':
            logger.error('DeepSORT does not need training on a MOT dataset.')
            sys.exit(1)

        if cfg.architecture == 'FairMOT' and self.mode == 'eval':
            images = self.parse_mot_images(cfg)
            self.dataset.set_images(images)

        if self.mode == 'train':
            self.loader = create('{}Reader'.format(capital_mode))(
                self.dataset, cfg.worker_num)

        if cfg.architecture == 'JDE' and self.mode == 'train':
            cfg['JDEEmbeddingHead'][
                'num_identities'] = self.dataset.num_identities_dict[0]
            # JDE only supports single-class MOT for now.

        if cfg.architecture == 'FairMOT' and self.mode == 'train':
            cfg['FairMOTEmbeddingHead'][
                'num_identities_dict'] = self.dataset.num_identities_dict
            # FairMOT supports both single-class and multi-class MOT.

        # build model
        if 'model' not in self.cfg:
            self.model = create(cfg.architecture)
        else:
            self.model = self.cfg.model
            self.is_loaded_weights = True

        if cfg.architecture == 'YOLOX':
            for k, m in self.model.named_sublayers():
                if isinstance(m, nn.BatchNorm2D):
                    m._epsilon = 1e-3  # for amp(fp16)
                    m._momentum = 0.97  # 0.03 in pytorch

        # normalize params for deploy
        if 'slim' in cfg and cfg['slim_type'] == 'OFA':
            self.model.model.load_meanstd(cfg['TestReader'][
                'sample_transforms'])
        elif 'slim' in cfg and cfg['slim_type'] == 'Distill':
            self.model.student_model.load_meanstd(cfg['TestReader'][
                'sample_transforms'])
        elif 'slim' in cfg and cfg[
                'slim_type'] == 'DistillPrune' and self.mode == 'train':
            self.model.student_model.load_meanstd(cfg['TestReader'][
                'sample_transforms'])
        else:
            self.model.load_meanstd(cfg['TestReader']['sample_transforms'])

        # EvalDataset is built with a BatchSampler to evaluate on a single device
        # TODO: multi-device evaluate
        if self.mode == 'eval':
            if cfg.architecture == 'FairMOT':
                self.loader = create('EvalMOTReader')(self.dataset, 0)
            elif cfg.architecture == "METRO_Body":
                reader_name = '{}Reader'.format(self.mode.capitalize())
                self.loader = create(reader_name)(self.dataset, cfg.worker_num)
            else:
                self._eval_batch_sampler = paddle.io.BatchSampler(
                    self.dataset, batch_size=self.cfg.EvalReader['batch_size'])
                reader_name = '{}Reader'.format(self.mode.capitalize())
                # If metric is VOC, collate_batch must be set to False.
                if cfg.metric == 'VOC':
                    cfg[reader_name]['collate_batch'] = False
                self.loader = create(reader_name)(self.dataset, cfg.worker_num,
                                                  self._eval_batch_sampler)
        # TestDataset is built after the user sets images, so skip loader creation here

        # get Params
        print_params = self.cfg.get('print_params', False)
        if print_params:
            params = sum([
                p.numel() for n, p in self.model.named_parameters()
                if all([x not in n for x in ['_mean', '_variance', 'aux_']])
            ])  # exclude BatchNorm running status
            logger.info('Model Params : {} M.'.format((params / 1e6).numpy()[
                0]))

        # build optimizer in train mode
        if self.mode == 'train':
            steps_per_epoch = len(self.loader)
            if steps_per_epoch < 1:
                logger.warning(
                    "Samples in dataset are fewer than batch_size, please set a smaller batch_size in TrainReader."
                )
            self.lr = create('LearningRate')(steps_per_epoch)
            self.optimizer = create('OptimizerBuilder')(self.lr, self.model)

            # Unstructured pruner is only enabled in the train mode.
            if self.cfg.get('unstructured_prune'):
                self.pruner = create('UnstructuredPruner')(self.model,
                                                           steps_per_epoch)
        if self.use_amp and self.amp_level == 'O2':
            self.model, self.optimizer = paddle.amp.decorate(
                models=self.model,
                optimizers=self.optimizer,
                level=self.amp_level)
        self.use_ema = ('use_ema' in cfg and cfg['use_ema'])
        if self.use_ema:
            ema_decay = self.cfg.get('ema_decay', 0.9998)
            ema_decay_type = self.cfg.get('ema_decay_type', 'threshold')
            cycle_epoch = self.cfg.get('cycle_epoch', -1)
            ema_black_list = self.cfg.get('ema_black_list', None)
            self.ema = ModelEMA(
                self.model,
                decay=ema_decay,
                ema_decay_type=ema_decay_type,
                cycle_epoch=cycle_epoch,
                ema_black_list=ema_black_list)

        self._nranks = dist.get_world_size()
        self._local_rank = dist.get_rank()

        self.status = {}

        self.start_epoch = 0
        self.end_epoch = 0 if 'epoch' not in cfg else cfg.epoch

        # initial default callbacks
        self._init_callbacks()

        # initial default metrics
        self._init_metrics()
        self._reset_metrics()
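
    # Build the default callback list for the current mode: checkpointing is
    # train-only, LogPrinter runs in train and eval, and the VisualDL / W&B
    # writers are attached only when enabled in the config.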
    def _init_callbacks(self):
        if self.mode == 'train':
            self._callbacks = [LogPrinter(self), Checkpointer(self)]
            if self.cfg.get('use_vdl', False):
                self._callbacks.append(VisualDLWriter(self))
            if self.cfg.get('save_proposals', False):
                self._callbacks.append(SniperProposalsGenerator(self))
            if self.cfg.get('use_wandb', False) or 'wandb' in self.cfg:
                self._callbacks.append(WandbCallback(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'eval':
            self._callbacks = [LogPrinter(self)]
            if self.cfg.metric == 'WiderFace':
                self._callbacks.append(WiferFaceEval(self))
            self._compose_callback = ComposeCallback(self._callbacks)
        elif self.mode == 'test' and self.cfg.get('use_vdl', False):
            self._callbacks = [VisualDLWriter(self)]
            self._compose_callback = ComposeCallback(self._callbacks)
        else:
            self._callbacks = []
            self._compose_callback = None
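
    # Build the metric objects matching cfg.metric. In 'test' mode, and in
    # 'train' mode without validation, no metrics are needed and the list is
    # left empty.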
    def _init_metrics(self, validate=False):
        if self.mode == 'test' or (self.mode == 'train' and not validate):
            self._metrics = []
            return
        classwise = self.cfg['classwise'] if 'classwise' in self.cfg else False
        if self.cfg.metric == 'COCO' or self.cfg.metric == "SNIPERCOCO":
            # TODO: bias should be unified
            bias = 1 if self.cfg.get('bias', False) else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            # pass clsid2catid info to the metric instance to avoid loading
            # the annotation file multiple times
            clsid2catid = {v: k for k, v in self.dataset.catid2clsid.items()} \
                if self.mode == 'eval' else None

            # when doing validation during training, the annotation file
            # should be taken from EvalDataset instead of self.dataset
            # (which is the train dataset)
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()
                dataset = eval_dataset
            else:
                dataset = self.dataset
                anno_file = dataset.get_anno()

            IouType = self.cfg['IouType'] if 'IouType' in self.cfg else 'bbox'
            if self.cfg.metric == "COCO":
                self._metrics = [
                    COCOMetric(
                        anno_file=anno_file,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
            elif self.cfg.metric == "SNIPERCOCO":  # sniper
                self._metrics = [
                    SNIPERCOCOMetric(
                        anno_file=anno_file,
                        dataset=dataset,
                        clsid2catid=clsid2catid,
                        classwise=classwise,
                        output_eval=output_eval,
                        bias=bias,
                        IouType=IouType,
                        save_prediction_only=save_prediction_only)
                ]
        elif self.cfg.metric == 'RBOX':
            # TODO: bias should be unified
            bias = self.cfg['bias'] if 'bias' in self.cfg else 0
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            imid2path = self.cfg.get('imid2path', None)

            # when doing validation during training, the annotation file
            # should be taken from EvalDataset instead of self.dataset
            # (which is the train dataset)
            anno_file = self.dataset.get_anno()
            if self.mode == 'train' and validate:
                eval_dataset = self.cfg['EvalDataset']
                eval_dataset.check_or_download_dataset()
                anno_file = eval_dataset.get_anno()

            self._metrics = [
                RBoxMetric(
                    anno_file=anno_file,
                    classwise=classwise,
                    output_eval=output_eval,
                    bias=bias,
                    save_prediction_only=save_prediction_only,
                    imid2path=imid2path)
            ]
        elif self.cfg.metric == 'VOC':
            output_eval = self.cfg['output_eval'] \
                if 'output_eval' in self.cfg else None
            save_prediction_only = self.cfg.get('save_prediction_only', False)

            self._metrics = [
                VOCMetric(
                    label_list=self.dataset.get_label_list(),
                    class_num=self.cfg.num_classes,
                    map_type=self.cfg.map_type,
                    classwise=classwise,
                    output_eval=output_eval,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'WiderFace':
            multi_scale = self.cfg.multi_scale_eval if 'multi_scale_eval' in self.cfg else True
            self._metrics = [
                WiderFaceMetric(
                    image_dir=os.path.join(self.dataset.dataset_dir,
                                           self.dataset.image_dir),
                    anno_file=self.dataset.get_anno(),
                    multi_scale=multi_scale)
            ]
        elif self.cfg.metric == 'KeyPointTopDownCOCOEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownCOCOEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'KeyPointTopDownMPIIEval':
            eval_dataset = self.cfg['EvalDataset']
            eval_dataset.check_or_download_dataset()
            anno_file = eval_dataset.get_anno()
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                KeyPointTopDownMPIIEval(
                    anno_file,
                    len(eval_dataset),
                    self.cfg.num_joints,
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'Pose3DEval':
            save_prediction_only = self.cfg.get('save_prediction_only', False)
            self._metrics = [
                Pose3DEval(
                    self.cfg.save_dir,
                    save_prediction_only=save_prediction_only)
            ]
        elif self.cfg.metric == 'MOTDet':
            self._metrics = [JDEDetMetric(), ]
        else:
            logger.warning("Metric type {} is not supported".format(
                self.cfg.metric))
            self._metrics = []

    def _reset_metrics(self):
        for metric in self._metrics:
            metric.reset()

    def register_callbacks(self, callbacks):
        callbacks = [c for c in list(callbacks) if c is not None]
        for c in callbacks:
            assert isinstance(c, Callback), \
                "callbacks should be instances of a subclass of Callback"
        self._callbacks.extend(callbacks)
        self._compose_callback = ComposeCallback(self._callbacks)

    def register_metrics(self, metrics):
        metrics = [m for m in list(metrics) if m is not None]
        for m in metrics:
            assert isinstance(m, Metric), \
                "metrics should be instances of a subclass of Metric"
        self._metrics.extend(metrics)

    def load_weights(self, weights):
        if self.is_loaded_weights:
            return
        self.start_epoch = 0
        load_pretrain_weight(self.model, weights)
        logger.debug("Load weights {} to start training".format(weights))

    def load_weights_sde(self, det_weights, reid_weights):
        if self.model.detector:
            load_weight(self.model.detector, det_weights)
            if self.model.reid:
                load_weight(self.model.reid, reid_weights)
        else:
            load_weight(self.model.reid, reid_weights)

    def resume_weights(self, weights):
        # support Distill resume weights
        if hasattr(self.model, 'student_model'):
            self.start_epoch = load_weight(self.model.student_model, weights,
                                           self.optimizer)
        else:
            self.start_epoch = load_weight(self.model, weights, self.optimizer,
                                           self.ema if self.use_ema else None)
        logger.debug("Resume weights of epoch {}".format(self.start_epoch))
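
    # Main training loop: wraps the model for sync-BN / fleet / DataParallel
    # as configured, then iterates epochs with optional AMP, EMA weight
    # swapping at snapshot epochs, and in-training validation.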
    def train(self, validate=False):
        assert self.mode == 'train', "Model not in 'train' mode"
        Init_mark = False
        if validate:
            self.cfg['EvalDataset'] = self.cfg.EvalDataset = create(
                "EvalDataset")()

        model = self.model
        if self.cfg.get('to_static', False):
            model = apply_to_static(self.cfg, model)
        sync_bn = (
            getattr(self.cfg, 'norm_type', None) == 'sync_bn' and
            (self.cfg.use_gpu or self.cfg.use_npu or self.cfg.use_mlu) and
            self._nranks > 1)
        if sync_bn:
            model = paddle.nn.SyncBatchNorm.convert_sync_batchnorm(model)

        # enable auto mixed precision mode
        if self.use_amp:
            scaler = paddle.amp.GradScaler(
                enable=self.cfg.use_gpu or self.cfg.use_npu or self.cfg.use_mlu,
                init_loss_scaling=self.cfg.get('init_loss_scaling', 1024))

        # get distributed model
        if self.cfg.get('fleet', False):
            model = fleet.distributed_model(model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            model = paddle.DataParallel(
                model, find_unused_parameters=find_unused_parameters)

        self.status.update({
            'epoch_id': self.start_epoch,
            'step_id': 0,
            'steps_per_epoch': len(self.loader)
        })

        self.status['batch_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['data_time'] = stats.SmoothedValue(
            self.cfg.log_iter, fmt='{avg:.4f}')
        self.status['training_staus'] = stats.TrainingStats(self.cfg.log_iter)

        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num)
            self._flops(flops_loader)
        profiler_options = self.cfg.get('profiler_options', None)

        self._compose_callback.on_train_begin(self.status)

        use_fused_allreduce_gradients = self.cfg[
            'use_fused_allreduce_gradients'] if 'use_fused_allreduce_gradients' in self.cfg else False

        for epoch_id in range(self.start_epoch, self.cfg.epoch):
            self.status['mode'] = 'train'
            self.status['epoch_id'] = epoch_id
            self._compose_callback.on_epoch_begin(self.status)
            self.loader.dataset.set_epoch(epoch_id)
            model.train()
            iter_tic = time.time()
            for step_id, data in enumerate(self.loader):
                self.status['data_time'].update(time.time() - iter_tic)
                self.status['step_id'] = step_id
                profiler.add_profiler_step(profiler_options)
                self._compose_callback.on_step_begin(self.status)
                data['epoch_id'] = epoch_id

                if self.use_amp:
                    if isinstance(
                            model, paddle.
                            DataParallel) and use_fused_allreduce_gradients:
                        with model.no_sync():
                            with paddle.amp.auto_cast(
                                    enable=self.cfg.use_gpu or
                                    self.cfg.use_npu or self.cfg.use_mlu,
                                    custom_white_list=self.custom_white_list,
                                    custom_black_list=self.custom_black_list,
                                    level=self.amp_level):
                                # model forward
                                outputs = model(data)
                                loss = outputs['loss']
                            # model backward
                            scaled_loss = scaler.scale(loss)
                            scaled_loss.backward()
                        fused_allreduce_gradients(
                            list(model.parameters()), None)
                    else:
                        with paddle.amp.auto_cast(
                                enable=self.cfg.use_gpu or self.cfg.use_npu or
                                self.cfg.use_mlu,
                                custom_white_list=self.custom_white_list,
                                custom_black_list=self.custom_black_list,
                                level=self.amp_level):
                            # model forward
                            outputs = model(data)
                            loss = outputs['loss']
                        # model backward
                        scaled_loss = scaler.scale(loss)
                        scaled_loss.backward()
                    # in dygraph mode, optimizer.minimize is equal to optimizer.step
                    scaler.minimize(self.optimizer, scaled_loss)
                else:
                    if isinstance(
                            model, paddle.
                            DataParallel) and use_fused_allreduce_gradients:
                        with model.no_sync():
                            # model forward
                            outputs = model(data)
                            loss = outputs['loss']
                            # model backward
                            loss.backward()
                        fused_allreduce_gradients(
                            list(model.parameters()), None)
                    else:
                        # model forward
                        outputs = model(data)
                        loss = outputs['loss']
                        # model backward
                        loss.backward()
                    self.optimizer.step()
                curr_lr = self.optimizer.get_lr()
                self.lr.step()
                if self.cfg.get('unstructured_prune'):
                    self.pruner.step()
                self.optimizer.clear_grad()
                self.status['learning_rate'] = curr_lr

                if self._nranks < 2 or self._local_rank == 0:
                    self.status['training_staus'].update(outputs)

                self.status['batch_time'].update(time.time() - iter_tic)
                self._compose_callback.on_step_end(self.status)
                if self.use_ema:
                    self.ema.update()
                iter_tic = time.time()

            if self.cfg.get('unstructured_prune'):
                self.pruner.update_params()

            is_snapshot = (self._nranks < 2 or (self._local_rank == 0 or self.cfg.metric == "Pose3DEval")) \
                and ((epoch_id + 1) % self.cfg.snapshot_epoch == 0 or epoch_id == self.end_epoch - 1)
            if is_snapshot and self.use_ema:
                # apply ema weight on model
                weight = copy.deepcopy(self.model.state_dict())
                self.model.set_dict(self.ema.apply())
                self.status['weight'] = weight

            self._compose_callback.on_epoch_end(self.status)

            if validate and is_snapshot:
                if not hasattr(self, '_eval_loader'):
                    # build evaluation dataset and loader
                    self._eval_dataset = self.cfg.EvalDataset
                    self._eval_batch_sampler = \
                        paddle.io.BatchSampler(
                            self._eval_dataset,
                            batch_size=self.cfg.EvalReader['batch_size'])
                    # If metric is VOC, collate_batch must be set to False.
                    if self.cfg.metric == 'VOC':
                        self.cfg['EvalReader']['collate_batch'] = False
                    if self.cfg.metric == "Pose3DEval":
                        self._eval_loader = create('EvalReader')(
                            self._eval_dataset, self.cfg.worker_num)
                    else:
                        self._eval_loader = create('EvalReader')(
                            self._eval_dataset,
                            self.cfg.worker_num,
                            batch_sampler=self._eval_batch_sampler)
                # if validation in training is enabled, metrics should be re-initialized;
                # Init_mark makes sure this code only executes once
                if validate and not Init_mark:
                    Init_mark = True
                    self._init_metrics(validate=validate)
                    self._reset_metrics()

                with paddle.no_grad():
                    self.status['save_best_model'] = True
                    self._eval_with_loader(self._eval_loader)

            if is_snapshot and self.use_ema:
                # restore the original (non-EMA) weights
                self.model.set_dict(weight)
                self.status.pop('weight')

        self._compose_callback.on_train_end(self.status)
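
    # Shared evaluation loop: runs the model over the given loader, feeds
    # every batch to the registered metrics, then accumulates, logs, and
    # resets them.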
    def _eval_with_loader(self, loader):
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        self.status['mode'] = 'eval'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
            self._flops(flops_loader)
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            if self.use_amp:
                with paddle.amp.auto_cast(
                        enable=self.cfg.use_gpu or self.cfg.use_npu or
                        self.cfg.use_mlu,
                        custom_white_list=self.custom_white_list,
                        custom_black_list=self.custom_black_list,
                        level=self.amp_level):
                    outs = self.model(data)
            else:
                outs = self.model(data)

            # update metrics
            for metric in self._metrics:
                metric.update(data, outs)

            # multi-scale inputs: all inputs have the same im_id
            if isinstance(data, typing.Sequence):
                sample_num += data[0]['im_id'].numpy().shape[0]
            else:
                sample_num += data['im_id'].numpy().shape[0]
            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metrics to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states, as a metric may be evaluated multiple times
        self._reset_metrics()

    def evaluate(self):
        # get distributed model
        if self.cfg.get('fleet', False):
            self.model = fleet.distributed_model(self.model)
            self.optimizer = fleet.distributed_optimizer(self.optimizer)
        elif self._nranks > 1:
            find_unused_parameters = self.cfg[
                'find_unused_parameters'] if 'find_unused_parameters' in self.cfg else False
            self.model = paddle.DataParallel(
                self.model, find_unused_parameters=find_unused_parameters)
        with paddle.no_grad():
            self._eval_with_loader(self.loader)
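
    # Sliced evaluation (SAHI-style): the loader yields image slices, whose
    # predictions are shifted back by the slice origin ('st_pix') and fused
    # per original image with NMS or concatenation once 'is_last' is seen.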
    def _eval_with_loader_slice(self,
                                loader,
                                slice_size=[640, 640],
                                overlap_ratio=[0.25, 0.25],
                                combine_method='nms',
                                match_threshold=0.6,
                                match_metric='iou'):
        sample_num = 0
        tic = time.time()
        self._compose_callback.on_epoch_begin(self.status)
        self.status['mode'] = 'eval'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('{}Reader'.format(self.mode.capitalize()))(
                self.dataset, self.cfg.worker_num, self._eval_batch_sampler)
            self._flops(flops_loader)

        merged_bboxs = []
        for step_id, data in enumerate(loader):
            self.status['step_id'] = step_id
            self._compose_callback.on_step_begin(self.status)
            # forward
            if self.use_amp:
                with paddle.amp.auto_cast(
                        enable=self.cfg.use_gpu or self.cfg.use_npu or
                        self.cfg.use_mlu,
                        custom_white_list=self.custom_white_list,
                        custom_black_list=self.custom_black_list,
                        level=self.amp_level):
                    outs = self.model(data)
            else:
                outs = self.model(data)

            shift_amount = data['st_pix']
            outs['bbox'][:, 2:4] = outs['bbox'][:, 2:4] + shift_amount
            outs['bbox'][:, 4:6] = outs['bbox'][:, 4:6] + shift_amount
            merged_bboxs.append(outs['bbox'])

            if data['is_last'] > 0:
                # merge matching predictions
                merged_results = {'bbox': []}
                if combine_method == 'nms':
                    final_boxes = multiclass_nms(
                        np.concatenate(merged_bboxs), self.cfg.num_classes,
                        match_threshold, match_metric)
                    merged_results['bbox'] = np.concatenate(final_boxes)
                elif combine_method == 'concat':
                    merged_results['bbox'] = np.concatenate(merged_bboxs)
                else:
                    raise ValueError(
                        "Only 'nms' and 'concat' are supported for fusing detection results."
                    )
                merged_results['im_id'] = np.array([[0]])
                merged_results['bbox_num'] = np.array(
                    [len(merged_results['bbox'])])

                merged_bboxs = []
                data['im_id'] = data['ori_im_id']
                # update metrics
                for metric in self._metrics:
                    metric.update(data, merged_results)

                # multi-scale inputs: all inputs have the same im_id
                if isinstance(data, typing.Sequence):
                    sample_num += data[0]['im_id'].numpy().shape[0]
                else:
                    sample_num += data['im_id'].numpy().shape[0]

            self._compose_callback.on_step_end(self.status)

        self.status['sample_num'] = sample_num
        self.status['cost_time'] = time.time() - tic

        # accumulate metrics to log out
        for metric in self._metrics:
            metric.accumulate()
            metric.log()
        self._compose_callback.on_epoch_end(self.status)
        # reset metric states, as a metric may be evaluated multiple times
        self._reset_metrics()

    def evaluate_slice(self,
                       slice_size=[640, 640],
                       overlap_ratio=[0.25, 0.25],
                       combine_method='nms',
                       match_threshold=0.6,
                       match_metric='iou'):
        with paddle.no_grad():
            self._eval_with_loader_slice(self.loader, slice_size, overlap_ratio,
                                         combine_method, match_threshold,
                                         match_metric)
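
    # Sliced inference on user-supplied images: the dataset cuts each image
    # into overlapping slices, and per-slice detections are fused back into
    # full-image results before metric update and visualization.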
    def slice_predict(self,
                      images,
                      slice_size=[640, 640],
                      overlap_ratio=[0.25, 0.25],
                      combine_method='nms',
                      match_threshold=0.6,
                      match_metric='iou',
                      draw_threshold=0.5,
                      output_dir='output',
                      save_results=False,
                      visualize=True):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        self.dataset.set_slice_images(images, slice_size, overlap_ratio)
        loader = create('TestReader')(self.dataset, 0)
        imid2path = self.dataset.get_imid2path()

        def setup_metrics_for_loader():
            # remember the current state
            metrics = copy.deepcopy(self._metrics)
            mode = self.mode
            save_prediction_only = self.cfg[
                'save_prediction_only'] if 'save_prediction_only' in self.cfg else None
            output_eval = self.cfg[
                'output_eval'] if 'output_eval' in self.cfg else None

            # modify
            self.mode = '_test'
            self.cfg['save_prediction_only'] = True
            self.cfg['output_eval'] = output_dir
            self.cfg['imid2path'] = imid2path
            self._init_metrics()

            # restore
            self.mode = mode
            self.cfg.pop('save_prediction_only')
            if save_prediction_only is not None:
                self.cfg['save_prediction_only'] = save_prediction_only

            self.cfg.pop('output_eval')
            if output_eval is not None:
                self.cfg['output_eval'] = output_eval

            self.cfg.pop('imid2path')

            _metrics = copy.deepcopy(self._metrics)
            self._metrics = metrics

            return _metrics

        if save_results:
            metrics = setup_metrics_for_loader()
        else:
            metrics = []

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('TestReader')(self.dataset, 0)
            self._flops(flops_loader)

        results = []  # all images
        merged_bboxs = []  # single image
        for step_id, data in enumerate(tqdm(loader)):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            outs['bbox'] = outs['bbox'].numpy()  # only in test mode
            shift_amount = data['st_pix']
            outs['bbox'][:, 2:4] = outs['bbox'][:, 2:4] + shift_amount.numpy()
            outs['bbox'][:, 4:6] = outs['bbox'][:, 4:6] + shift_amount.numpy()
            merged_bboxs.append(outs['bbox'])

            if data['is_last'] > 0:
                # merge matching predictions
                merged_results = {'bbox': []}
                if combine_method == 'nms':
                    final_boxes = multiclass_nms(
                        np.concatenate(merged_bboxs), self.cfg.num_classes,
                        match_threshold, match_metric)
                    merged_results['bbox'] = np.concatenate(final_boxes)
                elif combine_method == 'concat':
                    merged_results['bbox'] = np.concatenate(merged_bboxs)
                else:
                    raise ValueError(
                        "Only 'nms' and 'concat' are supported for fusing detection results."
                    )
                merged_results['im_id'] = np.array([[0]])
                merged_results['bbox_num'] = np.array(
                    [len(merged_results['bbox'])])

                merged_bboxs = []
                data['im_id'] = data['ori_im_id']

                for _m in metrics:
                    _m.update(data, merged_results)

                for key in ['im_shape', 'scale_factor', 'im_id']:
                    if isinstance(data, typing.Sequence):
                        merged_results[key] = data[0][key]
                    else:
                        merged_results[key] = data[key]
                for key, value in merged_results.items():
                    if hasattr(value, 'numpy'):
                        merged_results[key] = value.numpy()
                results.append(merged_results)

        for _m in metrics:
            _m.accumulate()
            _m.reset()

        if visualize:
            for outs in results:
                batch_res = get_infer_results(outs, clsid2catid)
                bbox_num = outs['bbox_num']

                start = 0
                for i, im_id in enumerate(outs['im_id']):
                    image_path = imid2path[int(im_id)]
                    image = Image.open(image_path).convert('RGB')
                    image = ImageOps.exif_transpose(image)
                    self.status['original_image'] = np.array(image.copy())

                    end = start + bbox_num[i]
                    bbox_res = batch_res['bbox'][start:end] \
                        if 'bbox' in batch_res else None
                    mask_res = batch_res['mask'][start:end] \
                        if 'mask' in batch_res else None
                    segm_res = batch_res['segm'][start:end] \
                        if 'segm' in batch_res else None
                    keypoint_res = batch_res['keypoint'][start:end] \
                        if 'keypoint' in batch_res else None
                    pose3d_res = batch_res['pose3d'][start:end] \
                        if 'pose3d' in batch_res else None
                    image = visualize_results(
                        image, bbox_res, mask_res, segm_res, keypoint_res,
                        pose3d_res, int(im_id), catid2name, draw_threshold)
                    self.status['result_image'] = np.array(image.copy())
                    if self._compose_callback:
                        self._compose_callback.on_step_end(self.status)

                    # save image with detections
                    save_name = self._get_save_image_name(output_dir,
                                                          image_path)
                    logger.info("Detection bbox results saved in {}".format(
                        save_name))
                    image.save(save_name, quality=95)

                    start = end

    def predict(self,
                images,
                draw_threshold=0.5,
                output_dir='output',
                save_results=False,
                visualize=True):
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        self.dataset.set_images(images)
        loader = create('TestReader')(self.dataset, 0)
        imid2path = self.dataset.get_imid2path()

        def setup_metrics_for_loader():
            # remember the current state
            metrics = copy.deepcopy(self._metrics)
            mode = self.mode
            save_prediction_only = self.cfg[
                'save_prediction_only'] if 'save_prediction_only' in self.cfg else None
            output_eval = self.cfg[
                'output_eval'] if 'output_eval' in self.cfg else None

            # modify
            self.mode = '_test'
            self.cfg['save_prediction_only'] = True
            self.cfg['output_eval'] = output_dir
            self.cfg['imid2path'] = imid2path
            self._init_metrics()

            # restore
            self.mode = mode
            self.cfg.pop('save_prediction_only')
            if save_prediction_only is not None:
                self.cfg['save_prediction_only'] = save_prediction_only

            self.cfg.pop('output_eval')
            if output_eval is not None:
                self.cfg['output_eval'] = output_eval

            self.cfg.pop('imid2path')

            _metrics = copy.deepcopy(self._metrics)
            self._metrics = metrics

            return _metrics

        if save_results:
            metrics = setup_metrics_for_loader()
        else:
            metrics = []

        anno_file = self.dataset.get_anno()
        clsid2catid, catid2name = get_categories(
            self.cfg.metric, anno_file=anno_file)

        # Run Infer
        self.status['mode'] = 'test'
        self.model.eval()
        if self.cfg.get('print_flops', False):
            flops_loader = create('TestReader')(self.dataset, 0)
            self._flops(flops_loader)
        results = []
        for step_id, data in enumerate(tqdm(loader)):
            self.status['step_id'] = step_id
            # forward
            outs = self.model(data)

            for _m in metrics:
                _m.update(data, outs)

            for key in ['im_shape', 'scale_factor', 'im_id']:
                if isinstance(data, typing.Sequence):
                    outs[key] = data[0][key]
                else:
                    outs[key] = data[key]
            for key, value in outs.items():
                if hasattr(value, 'numpy'):
                    outs[key] = value.numpy()
            results.append(outs)

        # sniper
        if type(self.dataset) == SniperCOCODataSet:
            results = self.dataset.anno_cropper.aggregate_chips_detections(
                results)

        for _m in metrics:
            _m.accumulate()
            _m.reset()

        if visualize:
            for outs in results:
                batch_res = get_infer_results(outs, clsid2catid)
                bbox_num = outs['bbox_num']

                start = 0
                for i, im_id in enumerate(outs['im_id']):
                    image_path = imid2path[int(im_id)]
                    image = Image.open(image_path).convert('RGB')
                    image = ImageOps.exif_transpose(image)
                    self.status['original_image'] = np.array(image.copy())

                    end = start + bbox_num[i]
                    bbox_res = batch_res['bbox'][start:end] \
                        if 'bbox' in batch_res else None
                    mask_res = batch_res['mask'][start:end] \
                        if 'mask' in batch_res else None
                    segm_res = batch_res['segm'][start:end] \
                        if 'segm' in batch_res else None
                    keypoint_res = batch_res['keypoint'][start:end] \
                        if 'keypoint' in batch_res else None
                    pose3d_res = batch_res['pose3d'][start:end] \
                        if 'pose3d' in batch_res else None
                    image = visualize_results(
                        image, bbox_res, mask_res, segm_res, keypoint_res,
                        pose3d_res, int(im_id), catid2name, draw_threshold)
                    self.status['result_image'] = np.array(image.copy())
                    if self._compose_callback:
                        self._compose_callback.on_step_end(self.status)

                    # save image with detections
                    save_name = self._get_save_image_name(output_dir,
                                                          image_path)
                    logger.info("Detection bbox results saved in {}".format(
                        save_name))
                    image.save(save_name, quality=95)

                    start = end

    def _get_save_image_name(self, output_dir, image_path):
        """
        Get save image name from source image path.
        """
        image_name = os.path.split(image_path)[-1]
        name, ext = os.path.splitext(image_name)
        return os.path.join(output_dir, "{}".format(name)) + ext
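
    # Build the deploy config (infer_cfg.yml) and the InputSpec list used for
    # dynamic-to-static conversion; the image shape defaults to
    # [None, 3, -1, -1] unless the reader's inputs_def pins an image_shape.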
    def _get_infer_cfg_and_input_spec(self,
                                      save_dir,
                                      prune_input=True,
                                      kl_quant=False):
        image_shape = None
        im_shape = [None, 2]
        scale_factor = [None, 2]
        if self.cfg.architecture in MOT_ARCH:
            test_reader_name = 'TestMOTReader'
        else:
            test_reader_name = 'TestReader'
        if 'inputs_def' in self.cfg[test_reader_name]:
            inputs_def = self.cfg[test_reader_name]['inputs_def']
            image_shape = inputs_def.get('image_shape', None)
        # set image_shape=[None, 3, -1, -1] as default
        if image_shape is None:
            image_shape = [None, 3, -1, -1]

        if len(image_shape) == 3:
            image_shape = [None] + image_shape
        else:
            im_shape = [image_shape[0], 2]
            scale_factor = [image_shape[0], 2]

        if hasattr(self.model, 'deploy'):
            self.model.deploy = True

        if 'slim' not in self.cfg:
            for layer in self.model.sublayers():
                if hasattr(layer, 'convert_to_deploy'):
                    layer.convert_to_deploy()

        if hasattr(self.cfg, 'export') and 'fuse_conv_bn' in self.cfg[
                'export'] and self.cfg['export']['fuse_conv_bn']:
            self.model = fuse_conv_bn(self.model)

        export_post_process = self.cfg['export'].get(
            'post_process', False) if hasattr(self.cfg, 'export') else True
        export_nms = self.cfg['export'].get('nms', False) if hasattr(
            self.cfg, 'export') else True
        export_benchmark = self.cfg['export'].get(
            'benchmark', False) if hasattr(self.cfg, 'export') else False
        if hasattr(self.model, 'fuse_norm'):
            self.model.fuse_norm = self.cfg['TestReader'].get('fuse_normalize',
                                                              False)
        if hasattr(self.model, 'export_post_process'):
            self.model.export_post_process = export_post_process if not export_benchmark else False
        if hasattr(self.model, 'export_nms'):
            self.model.export_nms = export_nms if not export_benchmark else False
        if export_post_process and not export_benchmark:
            image_shape = [None] + image_shape[1:]

        # Save infer cfg
        _dump_infer_config(self.cfg,
                           os.path.join(save_dir, 'infer_cfg.yml'), image_shape,
                           self.model)

        input_spec = [{
            "image": InputSpec(
                shape=image_shape, name='image'),
            "im_shape": InputSpec(
                shape=im_shape, name='im_shape'),
            "scale_factor": InputSpec(
                shape=scale_factor, name='scale_factor')
        }]
        if self.cfg.architecture == 'DeepSORT':
            input_spec[0].update({
                "crops": InputSpec(
                    shape=[None, 3, 192, 64], name='crops')
            })
        if prune_input:
            static_model = paddle.jit.to_static(
                self.model, input_spec=input_spec)
            # NOTE: dy2st does not prune the program, but jit.save prunes it
            # according to the input spec, so prune the input spec here and
            # save with the pruned spec
            pruned_input_spec = _prune_input_spec(
                input_spec, static_model.forward.main_program,
                static_model.forward.outputs)
        else:
            static_model = None
            pruned_input_spec = input_spec

        # TODO: Hard code, delete it when pruning input_spec is supported.
        if self.cfg.architecture == 'PicoDet' and not export_post_process:
            pruned_input_spec = [{
                "image": InputSpec(
                    shape=image_shape, name='image')
            }]
        if kl_quant:
            if self.cfg.architecture == 'PicoDet' or 'ppyoloe' in self.cfg.weights:
                pruned_input_spec = [{
                    "image": InputSpec(
                        shape=image_shape, name='image'),
                    "scale_factor": InputSpec(
                        shape=scale_factor, name='scale_factor')
                }]
            elif 'tinypose' in self.cfg.weights:
                pruned_input_spec = [{
                    "image": InputSpec(
                        shape=image_shape, name='image')
                }]

        return static_model, pruned_input_spec
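
    # Export the model as a static-graph inference model under
    # output_dir/<config_name>/, via paddle.jit.save or, for QAT slim
    # models, the slim quantized-model saver.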
    def export(self, output_dir='output_inference'):
        if hasattr(self.model, 'aux_neck'):
            self.model.__delattr__('aux_neck')
        if hasattr(self.model, 'aux_head'):
            self.model.__delattr__('aux_head')
        self.model.eval()

        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        static_model, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir)

        # dy2st and save model
        if 'slim' not in self.cfg or 'QAT' not in self.cfg['slim_type']:
            paddle.jit.save(
                static_model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        else:
            self.cfg.slim.save_quantized_model(
                self.model,
                os.path.join(save_dir, 'model'),
                input_spec=pruned_input_spec)
        logger.info("Exported model and saved it in {}".format(save_dir))
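
    # Post-training quantization: run a few calibration batches through the
    # model, then save the quantized inference model via the slim config.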
    def post_quant(self, output_dir='output_inference'):
        model_name = os.path.splitext(os.path.split(self.cfg.filename)[-1])[0]
        save_dir = os.path.join(output_dir, model_name)
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        for idx, data in enumerate(self.loader):
            self.model(data)
            if idx == int(self.cfg.get('quant_batch_num', 10)):
                break

        # TODO: support pruning input_spec
        kl_quant = True if hasattr(self.cfg.slim, 'ptq') else False
        _, pruned_input_spec = self._get_infer_cfg_and_input_spec(
            save_dir, prune_input=False, kl_quant=kl_quant)

        self.cfg.slim.save_quantized_model(
            self.model,
            os.path.join(save_dir, 'model'),
            input_spec=pruned_input_spec)
        logger.info("Exported Post-Quant model and saved it in {}".format(
            save_dir))
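
    # Estimate FLOPs on a single batch sample using paddleslim's
    # dygraph_flops; logs a warning and returns if paddleslim is missing.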
    def _flops(self, loader):
        if hasattr(self.model, 'aux_neck'):
            self.model.__delattr__('aux_neck')
        if hasattr(self.model, 'aux_head'):
            self.model.__delattr__('aux_head')
        self.model.eval()
        try:
            import paddleslim
        except Exception as e:
            logger.warning(
                'Unable to calculate flops, please install paddleslim, for example: `pip install paddleslim`'
            )
            return

        from paddleslim.analysis import dygraph_flops as flops
        input_data = None
        for data in loader:
            input_data = data
            break

        input_spec = [{
            "image": input_data['image'][0].unsqueeze(0),
            "im_shape": input_data['im_shape'][0].unsqueeze(0),
            "scale_factor": input_data['scale_factor'][0].unsqueeze(0)
        }]
        flops = flops(self.model, input_spec) / (1000**3)
        logger.info(" Model FLOPs : {:.6f}G. (image shape is {})".format(
            flops, input_data['image'][0].unsqueeze(0).shape))

    def parse_mot_images(self, cfg):
        import glob
        # for quant
        dataset_dir = cfg['EvalMOTDataset'].dataset_dir
        data_root = cfg['EvalMOTDataset'].data_root
        data_root = '{}/{}'.format(dataset_dir, data_root)
        seqs = os.listdir(data_root)
        seqs.sort()
        all_images = []
        for seq in seqs:
            infer_dir = os.path.join(data_root, seq)
            assert infer_dir is None or os.path.isdir(infer_dir), \
                "{} is not a directory".format(infer_dir)
            images = set()
            exts = ['jpg', 'jpeg', 'png', 'bmp']
            exts += [ext.upper() for ext in exts]
            for ext in exts:
                images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
            images = list(images)
            images.sort()
            assert len(images) > 0, "no image found in {}".format(infer_dir)
            all_images.extend(images)
        logger.info("Found {} inference images in total.".format(
            len(all_images)))
        return all_images
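

# Minimal usage sketch (illustrative, not part of this module). Assuming a
# standard PaddleDetection config file and ppdet's config loader, training
# with periodic validation looks roughly like:
#
#     from ppdet.core.workspace import load_config
#     from ppdet.engine import Trainer
#
#     cfg = load_config('configs/yolov3/yolov3_darknet53_270e_coco.yml')
#     trainer = Trainer(cfg, mode='train')
#     trainer.load_weights(cfg.pretrain_weights)
#     trainer.train(validate=True)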