train.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import absolute_import
  15. from __future__ import division
  16. from __future__ import print_function
  17. import os
  18. import sys
  19. # add python path of PaddleDetection to sys.path
  20. parent_path = os.path.abspath(os.path.join(__file__, *(['..'] * 2)))
  21. sys.path.insert(0, parent_path)
  22. # ignore warning log
  23. import warnings
  24. warnings.filterwarnings('ignore')
  25. import paddle
  26. from ppdet.core.workspace import load_config, merge_config
  27. from ppdet.engine import Trainer, TrainerCot, init_parallel_env, set_random_seed, init_fleet_env
  28. from ppdet.engine.trainer_ssod import Trainer_DenseTeacher
  29. from ppdet.slim import build_slim_model
  30. from ppdet.utils.cli import ArgsParser, merge_args
  31. import ppdet.utils.check as check
  32. from ppdet.utils.logger import setup_logger
  33. logger = setup_logger('train')
  34. def parse_args():
  35. parser = ArgsParser()
  36. parser.add_argument(
  37. "--eval",
  38. action='store_true',
  39. default=False,
  40. help="Whether to perform evaluation in train")
  41. parser.add_argument(
  42. "-r", "--resume", default=None, help="weights path for resume")
  43. parser.add_argument(
  44. "--slim_config",
  45. default=None,
  46. type=str,
  47. help="Configuration file of slim method.")
  48. parser.add_argument(
  49. "--enable_ce",
  50. type=bool,
  51. default=False,
  52. help="If set True, enable continuous evaluation job."
  53. "This flag is only used for internal test.")
  54. parser.add_argument(
  55. "--amp",
  56. action='store_true',
  57. default=False,
  58. help="Enable auto mixed precision training.")
  59. parser.add_argument(
  60. "--fleet", action='store_true', default=False, help="Use fleet or not")
  61. parser.add_argument(
  62. "--use_vdl",
  63. type=bool,
  64. default=False,
  65. help="whether to record the data to VisualDL.")
  66. parser.add_argument(
  67. '--vdl_log_dir',
  68. type=str,
  69. default="vdl_log_dir/scalar",
  70. help='VisualDL logging directory for scalar.')
  71. parser.add_argument(
  72. "--use_wandb",
  73. type=bool,
  74. default=False,
  75. help="whether to record the data to wandb.")
  76. parser.add_argument(
  77. '--save_prediction_only',
  78. action='store_true',
  79. default=False,
  80. help='Whether to save the evaluation results only')
  81. parser.add_argument(
  82. '--profiler_options',
  83. type=str,
  84. default=None,
  85. help="The option of profiler, which should be in "
  86. "format \"key1=value1;key2=value2;key3=value3\"."
  87. "please see ppdet/utils/profiler.py for detail.")
  88. parser.add_argument(
  89. '--save_proposals',
  90. action='store_true',
  91. default=False,
  92. help='Whether to save the train proposals')
  93. parser.add_argument(
  94. '--proposals_path',
  95. type=str,
  96. default="sniper/proposals.json",
  97. help='Train proposals directory')
  98. parser.add_argument(
  99. "--to_static",
  100. action='store_true',
  101. default=False,
  102. help="Enable dy2st to train.")
  103. args = parser.parse_args()
  104. return args
  105. def run(FLAGS, cfg):
  106. # init fleet environment
  107. if cfg.fleet:
  108. init_fleet_env(cfg.get('find_unused_parameters', False))
  109. else:
  110. # init parallel environment if nranks > 1
  111. init_parallel_env()
  112. if FLAGS.enable_ce:
  113. set_random_seed(0)
  114. # build trainer
  115. ssod_method = cfg.get('ssod_method', None)
  116. if ssod_method is not None:
  117. if ssod_method == 'DenseTeacher':
  118. trainer = Trainer_DenseTeacher(cfg, mode='train')
  119. else:
  120. raise ValueError(
  121. "Semi-Supervised Object Detection only support DenseTeacher now."
  122. )
  123. elif cfg.get('use_cot', False):
  124. trainer = TrainerCot(cfg, mode='train')
  125. else:
  126. trainer = Trainer(cfg, mode='train')
  127. # load weights
  128. if FLAGS.resume is not None:
  129. trainer.resume_weights(FLAGS.resume)
  130. elif 'pretrain_weights' in cfg and cfg.pretrain_weights:
  131. trainer.load_weights(cfg.pretrain_weights)
  132. # training
  133. trainer.train(FLAGS.eval)
  134. def main():
  135. FLAGS = parse_args()
  136. cfg = load_config(FLAGS.config)
  137. merge_args(cfg, FLAGS)
  138. merge_config(FLAGS.opt)
  139. # disable npu in config by default
  140. if 'use_npu' not in cfg:
  141. cfg.use_npu = False
  142. # disable xpu in config by default
  143. if 'use_xpu' not in cfg:
  144. cfg.use_xpu = False
  145. if 'use_gpu' not in cfg:
  146. cfg.use_gpu = False
  147. # disable mlu in config by default
  148. if 'use_mlu' not in cfg:
  149. cfg.use_mlu = False
  150. if cfg.use_gpu:
  151. place = paddle.set_device('gpu')
  152. elif cfg.use_npu:
  153. place = paddle.set_device('npu')
  154. elif cfg.use_xpu:
  155. place = paddle.set_device('xpu')
  156. elif cfg.use_mlu:
  157. place = paddle.set_device('mlu')
  158. else:
  159. place = paddle.set_device('cpu')
  160. if FLAGS.slim_config:
  161. cfg = build_slim_model(cfg, FLAGS.slim_config)
  162. # FIXME: Temporarily solve the priority problem of FLAGS.opt
  163. merge_config(FLAGS.opt)
  164. check.check_config(cfg)
  165. check.check_gpu(cfg.use_gpu)
  166. check.check_npu(cfg.use_npu)
  167. check.check_xpu(cfg.use_xpu)
  168. check.check_mlu(cfg.use_mlu)
  169. check.check_version()
  170. run(FLAGS, cfg)
  171. if __name__ == "__main__":
  172. main()