det_infer.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
import glob
import math
from functools import reduce

import cv2
import numpy as np
import yaml

import paddle
from paddle.inference import Config
from paddle.inference import create_predictor

# add deploy path of PaddleDetection to sys.path
parent_path = os.path.abspath(os.path.join(__file__, '..'))
sys.path.insert(0, parent_path)

from benchmark_utils import PaddleInferBenchmark
from picodet_postprocess import PicoDetPostProcess
from preprocess import preprocess, Resize, NormalizeImage, Permute, PadStride, LetterBoxResize, Pad, decode_image
from mot.visualize import visualize_box_mask
from mot_utils import argsparser, Timer, get_current_memory_mb

# Global set of supported model architectures
SUPPORT_MODELS = {
    'YOLO',
    'PPYOLOE',
    'PicoDet',
    'JDE',
    'FairMOT',
    'DeepSORT',
    'StrongBaseline',
}


def bench_log(detector, img_list, model_info, batch_size=1, name=None):
    mems = {
        'cpu_rss_mb': detector.cpu_mem / len(img_list),
        'gpu_rss_mb': detector.gpu_mem / len(img_list),
        'gpu_util': detector.gpu_util * 100 / len(img_list)
    }
    perf_info = detector.det_times.report(average=True)
    data_info = {
        'batch_size': batch_size,
        'shape': "dynamic_shape",
        'data_num': perf_info['img_num']
    }
    log = PaddleInferBenchmark(detector.config, model_info, data_info,
                               perf_info, mems)
    log(name)


class Detector(object):
    """
    Args:
        model_dir (str): root path of model.pdiparams, model.pdmodel and infer_cfg.yml
        device (str): device to run on, one of CPU/GPU/XPU, default is CPU
        run_mode (str): inference mode (paddle/trt_fp32/trt_fp16/trt_int8)
        batch_size (int): batch size for inference
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): if the model is produced by TRT offline quantization
            calibration, trt_calib_mode should be set to True
        cpu_threads (int): number of CPU threads
        enable_mkldnn (bool): whether to enable MKLDNN
        output_dir (str): directory to save output files
        threshold (float): score threshold for visualization
    """

    def __init__(self,
                 model_dir,
                 device='CPU',
                 run_mode='paddle',
                 batch_size=1,
                 trt_min_shape=1,
                 trt_max_shape=1280,
                 trt_opt_shape=640,
                 trt_calib_mode=False,
                 cpu_threads=1,
                 enable_mkldnn=False,
                 output_dir='output',
                 threshold=0.5):
        self.pred_config = self.set_config(model_dir)
        self.predictor, self.config = load_predictor(
            model_dir,
            run_mode=run_mode,
            batch_size=batch_size,
            min_subgraph_size=self.pred_config.min_subgraph_size,
            device=device,
            use_dynamic_shape=self.pred_config.use_dynamic_shape,
            trt_min_shape=trt_min_shape,
            trt_max_shape=trt_max_shape,
            trt_opt_shape=trt_opt_shape,
            trt_calib_mode=trt_calib_mode,
            cpu_threads=cpu_threads,
            enable_mkldnn=enable_mkldnn)
        self.det_times = Timer()
        self.cpu_mem, self.gpu_mem, self.gpu_util = 0, 0, 0
        self.batch_size = batch_size
        self.output_dir = output_dir
        self.threshold = threshold

    def set_config(self, model_dir):
        return PredictConfig(model_dir)

    def preprocess(self, image_list):
        preprocess_ops = []
        for op_info in self.pred_config.preprocess_infos:
            new_op_info = op_info.copy()
            op_type = new_op_info.pop('type')
            preprocess_ops.append(eval(op_type)(**new_op_info))

        input_im_lst = []
        input_im_info_lst = []
        for im_path in image_list:
            im, im_info = preprocess(im_path, preprocess_ops)
            input_im_lst.append(im)
            input_im_info_lst.append(im_info)
        inputs = create_inputs(input_im_lst, input_im_info_lst)
        input_names = self.predictor.get_input_names()
        for i in range(len(input_names)):
            input_tensor = self.predictor.get_input_handle(input_names[i])
            input_tensor.copy_from_cpu(inputs[input_names[i]])
        return inputs

    def postprocess(self, inputs, result):
        # postprocess output of predictor
        np_boxes_num = result['boxes_num']
        if np_boxes_num[0] <= 0:
            print('[WARNING] No object detected.')
            result = {'boxes': np.zeros([0, 6]), 'boxes_num': [0]}
        result = {k: v for k, v in result.items() if v is not None}
        return result

    def predict(self, repeats=1):
        '''
        Args:
            repeats (int): repeat number for prediction
        Returns:
            result (dict): 'boxes': np.ndarray of shape [N, 6], N: number of boxes,
                each row: [class, score, x_min, y_min, x_max, y_max]
        '''
        # model prediction
        np_boxes, np_boxes_num = None, None
        for i in range(repeats):
            self.predictor.run()
            output_names = self.predictor.get_output_names()
            boxes_tensor = self.predictor.get_output_handle(output_names[0])
            np_boxes = boxes_tensor.copy_to_cpu()
            boxes_num = self.predictor.get_output_handle(output_names[1])
            np_boxes_num = boxes_num.copy_to_cpu()
        result = dict(boxes=np_boxes, boxes_num=np_boxes_num)
        return result

    def merge_batch_result(self, batch_result):
        if len(batch_result) == 1:
            return batch_result[0]
        res_key = batch_result[0].keys()
        results = {k: [] for k in res_key}
        for res in batch_result:
            for k, v in res.items():
                results[k].append(v)
        for k, v in results.items():
            results[k] = np.concatenate(v)
        return results

    def get_timer(self):
        return self.det_times

    def predict_image(self,
                      image_list,
                      run_benchmark=False,
                      repeats=1,
                      visual=True):
        batch_loop_cnt = math.ceil(float(len(image_list)) / self.batch_size)
        results = []
        for i in range(batch_loop_cnt):
            start_index = i * self.batch_size
            end_index = min((i + 1) * self.batch_size, len(image_list))
            batch_image_list = image_list[start_index:end_index]
            if run_benchmark:
                # preprocess
                inputs = self.preprocess(batch_image_list)  # warmup
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                result = self.predict(repeats=repeats)  # warmup
                self.det_times.inference_time_s.start()
                result = self.predict(repeats=repeats)
                self.det_times.inference_time_s.end(repeats=repeats)

                # postprocess
                result_warmup = self.postprocess(inputs, result)  # warmup
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                cm, gm, gu = get_current_memory_mb()
                self.cpu_mem += cm
                self.gpu_mem += gm
                self.gpu_util += gu
            else:
                # preprocess
                self.det_times.preprocess_time_s.start()
                inputs = self.preprocess(batch_image_list)
                self.det_times.preprocess_time_s.end()

                # model prediction
                self.det_times.inference_time_s.start()
                result = self.predict()
                self.det_times.inference_time_s.end()

                # postprocess
                self.det_times.postprocess_time_s.start()
                result = self.postprocess(inputs, result)
                self.det_times.postprocess_time_s.end()
                self.det_times.img_num += len(batch_image_list)

                if visual:
                    visualize(
                        batch_image_list,
                        result,
                        self.pred_config.labels,
                        output_dir=self.output_dir,
                        threshold=self.threshold)
            results.append(result)
            if visual:
                print('Test iter {}'.format(i))
        results = self.merge_batch_result(results)
        return results

    def predict_video(self, video_file, camera_id):
        video_out_name = 'output.mp4'
        if camera_id != -1:
            capture = cv2.VideoCapture(camera_id)
        else:
            capture = cv2.VideoCapture(video_file)
            video_out_name = os.path.split(video_file)[-1]
        # get video info: resolution, fps, frame count
        width = int(capture.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = int(capture.get(cv2.CAP_PROP_FPS))
        frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT))
        print("fps: %d, frame_count: %d" % (fps, frame_count))

        if not os.path.exists(self.output_dir):
            os.makedirs(self.output_dir)
        out_path = os.path.join(self.output_dir, video_out_name)
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        writer = cv2.VideoWriter(out_path, fourcc, fps, (width, height))
        index = 1
        while True:
            ret, frame = capture.read()
            if not ret:
                break
            print('detect frame: %d' % index)
            index += 1
            results = self.predict_image([frame], visual=False)
            im = visualize_box_mask(
                frame,
                results,
                self.pred_config.labels,
                threshold=self.threshold)
            im = np.array(im)
            writer.write(im)
            if camera_id != -1:
                cv2.imshow('Mask Detection', im)
                if cv2.waitKey(1) & 0xFF == ord('q'):
                    break
        writer.release()
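

# Usage sketch for Detector (the model path and image name are illustrative;
# assumes an exported model directory with infer_cfg.yml next to the weights):
#
#   detector = Detector('./output_inference/ppyoloe_crn_l_300e_coco',
#                       device='GPU', run_mode='paddle', batch_size=1)
#   results = detector.predict_image(['demo.jpg'], visual=True)
#   print(results['boxes_num'], results['boxes'].shape)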


def create_inputs(imgs, im_info):
    """Generate input for different model types
    Args:
        imgs (list(numpy)): list of images (np.ndarray)
        im_info (list(dict)): list of image info
    Returns:
        inputs (dict): input of model
    """
    inputs = {}
    im_shape = []
    scale_factor = []
    if len(imgs) == 1:
        inputs['image'] = np.array((imgs[0], )).astype('float32')
        inputs['im_shape'] = np.array(
            (im_info[0]['im_shape'], )).astype('float32')
        inputs['scale_factor'] = np.array(
            (im_info[0]['scale_factor'], )).astype('float32')
        return inputs

    for e in im_info:
        im_shape.append(np.array((e['im_shape'], )).astype('float32'))
        scale_factor.append(np.array((e['scale_factor'], )).astype('float32'))
    inputs['im_shape'] = np.concatenate(im_shape, axis=0)
    inputs['scale_factor'] = np.concatenate(scale_factor, axis=0)

    # zero-pad every image in the batch to the max H/W so they can be stacked
    imgs_shape = [[e.shape[1], e.shape[2]] for e in imgs]
    max_shape_h = max([e[0] for e in imgs_shape])
    max_shape_w = max([e[1] for e in imgs_shape])
    padding_imgs = []
    for img in imgs:
        im_c, im_h, im_w = img.shape[:]
        padding_im = np.zeros(
            (im_c, max_shape_h, max_shape_w), dtype=np.float32)
        padding_im[:, :im_h, :im_w] = img
        padding_imgs.append(padding_im)
    inputs['image'] = np.stack(padding_imgs, axis=0)
    return inputs
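
# For example, batching a (3, 608, 608) image with a (3, 640, 608) image
# yields inputs['image'] of shape (2, 3, 640, 608), with the smaller image
# zero-padded at the bottom (shapes here are illustrative).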


class PredictConfig():
    """Set config of preprocess, postprocess and visualize
    Args:
        model_dir (str): root path of infer_cfg.yml
    """

    def __init__(self, model_dir):
        # parse YAML config for preprocess
        deploy_file = os.path.join(model_dir, 'infer_cfg.yml')
        with open(deploy_file) as f:
            yml_conf = yaml.safe_load(f)
        self.check_model(yml_conf)
        self.arch = yml_conf['arch']
        self.preprocess_infos = yml_conf['Preprocess']
        self.min_subgraph_size = yml_conf['min_subgraph_size']
        self.labels = yml_conf['label_list']
        self.mask = False
        self.use_dynamic_shape = yml_conf['use_dynamic_shape']
        if 'mask' in yml_conf:
            self.mask = yml_conf['mask']
        self.tracker = None
        if 'tracker' in yml_conf:
            self.tracker = yml_conf['tracker']
        if 'NMS' in yml_conf:
            self.nms = yml_conf['NMS']
        if 'fpn_stride' in yml_conf:
            self.fpn_stride = yml_conf['fpn_stride']
        self.print_config()

    def check_model(self, yml_conf):
        """
        Raises:
            ValueError: loaded model not in supported model type
        """
        for support_model in SUPPORT_MODELS:
            if support_model in yml_conf['arch']:
                return True
        raise ValueError("Unsupported arch: {}, expect {}".format(
            yml_conf['arch'], SUPPORT_MODELS))

    def print_config(self):
        print('----------- Model Configuration -----------')
        print('%s: %s' % ('Model Arch', self.arch))
        print('%s: ' % ('Transform Order'))
        for op_info in self.preprocess_infos:
            print('--%s: %s' % ('transform op', op_info['type']))
        print('--------------------------------------------')
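

# A minimal sketch of the infer_cfg.yml fields PredictConfig reads
# (key names come from the parsing above; the values are illustrative only):
#
#   arch: YOLO
#   min_subgraph_size: 3
#   use_dynamic_shape: false
#   Preprocess:
#   - type: Resize
#     target_size: [640, 640]
#     keep_ratio: false
#   - type: NormalizeImage
#     mean: [0.485, 0.456, 0.406]
#     std: [0.229, 0.224, 0.225]
#   - type: Permute
#   label_list:
#   - person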


def load_predictor(model_dir,
                   run_mode='paddle',
                   batch_size=1,
                   device='CPU',
                   min_subgraph_size=3,
                   use_dynamic_shape=False,
                   trt_min_shape=1,
                   trt_max_shape=1280,
                   trt_opt_shape=640,
                   trt_calib_mode=False,
                   cpu_threads=1,
                   enable_mkldnn=False):
    """Set AnalysisConfig, generate AnalysisPredictor
    Args:
        model_dir (str): root path of model.pdmodel and model.pdiparams
        device (str): device to run on, one of CPU/GPU/XPU, default is CPU
        run_mode (str): inference mode (paddle/trt_fp32/trt_fp16/trt_int8)
        use_dynamic_shape (bool): whether to use dynamic shape
        trt_min_shape (int): min shape for dynamic shape in trt
        trt_max_shape (int): max shape for dynamic shape in trt
        trt_opt_shape (int): opt shape for dynamic shape in trt
        trt_calib_mode (bool): if the model is produced by TRT offline quantization
            calibration, trt_calib_mode should be set to True
    Returns:
        predictor (PaddlePredictor): AnalysisPredictor
    Raises:
        ValueError: predicting with TensorRT requires device == 'GPU'.
    """
    if device != 'GPU' and run_mode != 'paddle':
        raise ValueError(
            "Predict by TensorRT mode: {}, expect device == 'GPU', but device == {}"
            .format(run_mode, device))
    infer_model = os.path.join(model_dir, 'model.pdmodel')
    infer_params = os.path.join(model_dir, 'model.pdiparams')
    if not os.path.exists(infer_model):
        infer_model = os.path.join(model_dir, 'inference.pdmodel')
        infer_params = os.path.join(model_dir, 'inference.pdiparams')
        if not os.path.exists(infer_model):
            raise ValueError(
                "Cannot find any inference model in dir: {}".format(model_dir))
    config = Config(infer_model, infer_params)
    if device == 'GPU':
        # initial GPU memory (MB), device ID
        config.enable_use_gpu(200, 0)
        # optimize graph and fuse ops
        config.switch_ir_optim(True)
    elif device == 'XPU':
        config.enable_lite_engine()
        config.enable_xpu(10 * 1024 * 1024)
    else:
        config.disable_gpu()
        config.set_cpu_math_library_num_threads(cpu_threads)
        if enable_mkldnn:
            try:
                # cache 10 different shapes for mkldnn to avoid memory leak
                config.set_mkldnn_cache_capacity(10)
                config.enable_mkldnn()
            except Exception:
                print(
                    "The current environment does not support `mkldnn`, so disable mkldnn."
                )

    precision_map = {
        'trt_int8': Config.Precision.Int8,
        'trt_fp32': Config.Precision.Float32,
        'trt_fp16': Config.Precision.Half
    }
    if run_mode in precision_map.keys():
        config.enable_tensorrt_engine(
            workspace_size=1 << 25,
            max_batch_size=batch_size,
            min_subgraph_size=min_subgraph_size,
            precision_mode=precision_map[run_mode],
            use_static=False,
            use_calib_mode=trt_calib_mode)
        if use_dynamic_shape:
            min_input_shape = {
                'image': [batch_size, 3, trt_min_shape, trt_min_shape]
            }
            max_input_shape = {
                'image': [batch_size, 3, trt_max_shape, trt_max_shape]
            }
            opt_input_shape = {
                'image': [batch_size, 3, trt_opt_shape, trt_opt_shape]
            }
            config.set_trt_dynamic_shape_info(min_input_shape, max_input_shape,
                                              opt_input_shape)
            print('trt set dynamic shape done!')

    # disable print log when predict
    config.disable_glog_info()
    # enable shared memory
    config.enable_memory_optim()
    # disable feed, fetch OP, needed by zero_copy_run
    config.switch_use_feed_fetch_ops(False)
    predictor = create_predictor(config)
    return predictor, config
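

# Usage sketch for load_predictor (the model path and shapes are illustrative;
# assumes an exported model directory containing model.pdmodel/model.pdiparams):
#
#   predictor, config = load_predictor(
#       './output_inference/ppyoloe_crn_l_300e_coco',
#       run_mode='trt_fp16',
#       device='GPU',
#       use_dynamic_shape=True,
#       trt_min_shape=320,
#       trt_max_shape=1280,
#       trt_opt_shape=640)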


def get_test_images(infer_dir, infer_img):
    """
    Get image path list in TEST mode
    """
    assert infer_img is not None or infer_dir is not None, \
        "--infer_img or --infer_dir should be set"
    assert infer_img is None or os.path.isfile(infer_img), \
        "{} is not a file".format(infer_img)
    assert infer_dir is None or os.path.isdir(infer_dir), \
        "{} is not a directory".format(infer_dir)

    # infer_img has a higher priority
    if infer_img and os.path.isfile(infer_img):
        return [infer_img]

    images = set()
    infer_dir = os.path.abspath(infer_dir)
    assert os.path.isdir(infer_dir), \
        "infer_dir {} is not a directory".format(infer_dir)
    exts = ['jpg', 'jpeg', 'png', 'bmp']
    exts += [ext.upper() for ext in exts]
    for ext in exts:
        images.update(glob.glob('{}/*.{}'.format(infer_dir, ext)))
    images = list(images)

    assert len(images) > 0, "no image found in {}".format(infer_dir)
    print("Found {} inference images in total.".format(len(images)))
    return images


def visualize(image_list, result, labels, output_dir='output/', threshold=0.5):
    # visualize the predict result
    start_idx = 0
    for idx, image_file in enumerate(image_list):
        im_bboxes_num = result['boxes_num'][idx]
        im_results = {}
        if 'boxes' in result:
            im_results['boxes'] = result['boxes'][start_idx:start_idx +
                                                  im_bboxes_num, :]
        start_idx += im_bboxes_num
        im = visualize_box_mask(
            image_file, im_results, labels, threshold=threshold)
        img_name = os.path.split(image_file)[-1]
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)
        out_path = os.path.join(output_dir, img_name)
        im.save(out_path, quality=95)
        print("save result to: " + out_path)


def print_arguments(args):
    print('----------- Running Arguments -----------')
    for arg, value in sorted(vars(args).items()):
        print('%s: %s' % (arg, value))
    print('------------------------------------------')


def main():
    deploy_file = os.path.join(FLAGS.model_dir, 'infer_cfg.yml')
    with open(deploy_file) as f:
        yml_conf = yaml.safe_load(f)
    arch = yml_conf['arch']
    # this script always uses the base Detector; arch is read for reference
    detector_func = 'Detector'
    detector = eval(detector_func)(FLAGS.model_dir,
                                   device=FLAGS.device,
                                   run_mode=FLAGS.run_mode,
                                   batch_size=FLAGS.batch_size,
                                   trt_min_shape=FLAGS.trt_min_shape,
                                   trt_max_shape=FLAGS.trt_max_shape,
                                   trt_opt_shape=FLAGS.trt_opt_shape,
                                   trt_calib_mode=FLAGS.trt_calib_mode,
                                   cpu_threads=FLAGS.cpu_threads,
                                   enable_mkldnn=FLAGS.enable_mkldnn,
                                   threshold=FLAGS.threshold,
                                   output_dir=FLAGS.output_dir)

    # predict from video file or camera video stream
    if FLAGS.video_file is not None or FLAGS.camera_id != -1:
        detector.predict_video(FLAGS.video_file, FLAGS.camera_id)
    else:
        # predict from image
        if FLAGS.image_dir is None and FLAGS.image_file is not None:
            assert FLAGS.batch_size == 1, "batch_size should be 1, when image_file is not None"
        img_list = get_test_images(FLAGS.image_dir, FLAGS.image_file)
        detector.predict_image(img_list, FLAGS.run_benchmark, repeats=10)
        if not FLAGS.run_benchmark:
            detector.det_times.info(average=True)
        else:
            mode = FLAGS.run_mode
            model_dir = FLAGS.model_dir
            model_info = {
                'model_name': model_dir.strip('/').split('/')[-1],
                'precision': mode.split('_')[-1]
            }
            bench_log(detector, img_list, model_info, name='DET')


if __name__ == '__main__':
    paddle.enable_static()
    parser = argsparser()
    FLAGS = parser.parse_args()
    print_arguments(FLAGS)
    FLAGS.device = FLAGS.device.upper()
    assert FLAGS.device in ['CPU', 'GPU', 'XPU'], \
        "device should be CPU, GPU or XPU"

    main()
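
# Example invocation (flag names follow the FLAGS attributes used in main(),
# assuming mot_utils.argsparser defines them with matching dests; the model
# path and image name are illustrative):
#
#   python det_infer.py \
#       --model_dir=./output_inference/ppyoloe_crn_l_300e_coco \
#       --image_file=demo.jpg --device=GPU --run_mode=trt_fp16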