123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693 |
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import os
- import sys
- import platform
- import yaml
- import time
- import datetime
- import paddle
- import paddle.distributed as dist
- from tqdm import tqdm
- import cv2
- import numpy as np
- from argparse import ArgumentParser, RawDescriptionHelpFormatter
- from ppocr.utils.stats import TrainingStats
- from ppocr.utils.save_load import save_model
- from ppocr.utils.utility import print_dict, AverageMeter
- from ppocr.utils.logging import get_logger
- from ppocr.utils.loggers import VDLLogger, WandbLogger, Loggers
- from ppocr.utils import profiler
- from ppocr.data import build_dataloader
class ArgsParser(ArgumentParser):
    """Command-line parser for the config path, config overrides and profiler options.

    Recognised flags:
        -c / --config            path of the configuration file (required)
        -o / --opt               one or more ``key=value`` overrides
        -p / --profiler_options  profiler option string
    """

    def __init__(self):
        super(ArgsParser, self).__init__(
            formatter_class=RawDescriptionHelpFormatter)
        self.add_argument("-c", "--config", help="configuration file to use")
        self.add_argument(
            "-o", "--opt", nargs='+', help="set configuration options")
        self.add_argument(
            '-p',
            '--profiler_options',
            type=str,
            default=None,
            help='The option of profiler, which should be in format ' \
                 '\"key1=value1;key2=value2;key3=value3\".'
        )

    def parse_args(self, argv=None):
        """Parse ``argv`` and convert the ``-o`` overrides into a dict.

        Args:
            argv (list[str] | None): Arguments to parse; ``None`` lets
                argparse fall back to ``sys.argv``.

        Returns:
            The argparse namespace, with ``opt`` replaced by a
            ``{key: parsed_value}`` dict (empty when no ``-o`` was given).

        Raises:
            AssertionError: If no config file was specified.
        """
        args = super(ArgsParser, self).parse_args(argv)
        assert args.config is not None, \
            "Please specify --config=configure_file_path."
        args.opt = self._parse_opt(args.opt)
        return args

    def _parse_opt(self, opts):
        """Turn a list of ``key=value`` strings into a dict of parsed values.

        Only the first ``=`` separates key from value, so values that contain
        ``=`` themselves (paths, URLs, nested expressions) stay intact.

        Raises:
            ValueError: If an entry contains no ``=`` at all.
        """
        config = {}
        if not opts:
            return config
        for s in opts:
            s = s.strip()
            k, sep, v = s.partition('=')
            if not sep:
                raise ValueError(
                    "Invalid option '{}': expected the form key=value".format(
                        s))
            # NOTE(security): yaml.Loader can construct arbitrary Python
            # objects. Acceptable for values typed on the command line by the
            # operator, but do not feed untrusted strings through this path.
            config[k] = yaml.load(v, Loader=yaml.Loader)
        return config
def load_config(file_path):
    """
    Load config from yml/yaml file.
    Args:
        file_path (str): Path of the config file to be loaded.
    Returns: global config
    Raises:
        AssertionError: If the file extension is neither .yml nor .yaml.
    """
    _, ext = os.path.splitext(file_path)
    assert ext in ['.yml', '.yaml'], "only support yaml files for now"
    # Use a context manager so the file handle is released even if parsing
    # fails.
    # NOTE(security): yaml.Loader can construct arbitrary Python objects;
    # only load configuration files from trusted sources.
    with open(file_path, 'rb') as config_file:
        config = yaml.load(config_file, Loader=yaml.Loader)
    return config
def merge_config(config, opts):
    """
    Merge override options into the global config, in place.

    Plain keys either update an existing dict entry (when the override value
    is itself a dict) or overwrite/create the entry. Dotted keys such as
    ``"Global.epoch_num"`` walk down the nested config and assign the leaf.

    Args:
        config (dict): Global config that receives the overrides.
        opts (dict): Overrides to apply, keyed by plain or dotted names.
    Returns: the same ``config`` object, after modification
    """
    for key, value in opts.items():
        if "." in key:
            parts = key.split('.')
            root = parts[0]
            assert (
                root in config
            ), "the sub_keys can only be one of global_config: {}, but get: " \
                "{}, please check your running command".format(
                    config.keys(), root)
            # Descend through every intermediate level, then set the leaf.
            node = config[root]
            for name in parts[1:-1]:
                node = node[name]
            node[parts[-1]] = value
        elif isinstance(value, dict) and key in config:
            config[key].update(value)
        else:
            config[key] = value
    return config
def check_device(use_gpu, use_xpu=False, use_npu=False, use_mlu=False):
    """
    Print an error and exit when a device flag is enabled but the installed
    paddle build was not compiled with support for that device.

    Args:
        use_gpu (bool): Whether the config requests CUDA/GPU.
        use_xpu (bool): Whether the config requests XPU.
        use_npu (bool): Whether the config requests NPU.
        use_mlu (bool): Whether the config requests MLU.

    The check is best-effort: any ordinary exception raised while probing
    (for instance a probe function that is missing from the installed paddle)
    is ignored and the remaining checks are skipped. ``sys.exit`` raises
    ``SystemExit``, which is not an ``Exception`` subclass, so the exit on a
    failed check still propagates.
    """
    err = "Config {} cannot be set as true while your paddle " \
          "is not compiled with {} ! \nPlease try: \n" \
          "\t1. Install paddlepaddle to run model on {} \n" \
          "\t2. Set {} as false in config file to run " \
          "model on CPU"

    # (enabled, config flag, backend name, device name, lazy probe).
    # Probes are wrapped in lambdas so that attribute lookup on paddle only
    # happens for flags that are actually enabled.
    checks = (
        (use_gpu, "use_gpu", "cuda", "gpu",
         lambda: paddle.is_compiled_with_cuda()),
        (use_xpu, "use_xpu", "xpu", "xpu",
         lambda: paddle.device.is_compiled_with_xpu()),
        (use_npu, "use_npu", "npu", "npu",
         lambda: paddle.device.is_compiled_with_npu()),
        (use_mlu, "use_mlu", "mlu", "mlu",
         lambda: paddle.device.is_compiled_with_mlu()),
    )

    try:
        if use_gpu and use_xpu:
            # Only a warning: execution continues with the checks below.
            print("use_xpu and use_gpu can not both be true.")
        for enabled, flag, backend, target, is_compiled in checks:
            if enabled and not is_compiled():
                print(err.format(flag, backend, target, flag))
                sys.exit(1)
    except Exception:
        # Deliberately swallowed; see the docstring.
        pass
def to_float32(preds):
    """
    Cast every paddle.Tensor found in ``preds`` to float32.

    ``preds`` may be a tensor, or an arbitrarily nested combination of dicts
    and lists holding tensors. Dicts and lists are updated in place and
    returned; a bare tensor yields a new float32 tensor; any other value is
    returned unchanged.
    """
    if isinstance(preds, dict):
        for name, value in preds.items():
            preds[name] = to_float32(value)
        return preds
    if isinstance(preds, list):
        for pos, value in enumerate(preds):
            preds[pos] = to_float32(value)
        return preds
    if isinstance(preds, paddle.Tensor):
        return preds.astype(paddle.float32)
    return preds
def train(config,
          train_dataloader,
          valid_dataloader,
          device,
          model,
          loss_class,
          optimizer,
          lr_scheduler,
          post_process_class,
          eval_class,
          pre_best_model_dict,
          logger,
          log_writer=None,
          scaler=None,
          amp_level='O2',
          amp_custom_black_list=[]):
    """
    Run the training loop with periodic logging, evaluation and checkpointing.

    Args:
        config (dict): Global config. Reads ``Global`` (cal_metric_during_train,
            calc_epoch_interval, log_smooth_window, epoch_num,
            print_batch_step, eval_batch_step, save_epoch_step,
            save_model_dir), ``profiler_options``, ``Architecture``
            (algorithm, optional model_type, Models for Distillation) and
            ``Loss.name``.
        train_dataloader: Training data loader; rebuilt at the start of an
            epoch when ``train_dataloader.dataset.need_reset`` is truthy.
        valid_dataloader: Data loader passed to ``eval``.
        device: Forwarded to ``build_dataloader`` when the loader is rebuilt.
        model: Network to train.
        loss_class: Callable ``loss_class(preds, batch)`` returning a dict
            that contains at least the key ``'loss'``.
        optimizer: Optimizer; ``get_lr``, ``step`` and ``clear_grad`` are used.
        lr_scheduler: Stepped once per iteration unless it is a plain float.
        post_process_class: Converts predictions before they reach
            ``eval_class``.
        eval_class: Metric object exposing ``main_indicator`` and
            ``get_metric``.
        pre_best_model_dict (dict): Previously saved state; may provide
            ``global_step``, ``start_epoch`` and earlier best metrics.
        logger: Logger used for progress messages.
        log_writer: Optional metrics/model logger; only used on rank 0.
        scaler: When truthy, the forward pass runs under
            ``paddle.amp.auto_cast`` and this scaler drives backward/step.
        amp_level: ``level`` argument for ``paddle.amp.auto_cast``.
        amp_custom_black_list: ``custom_black_list`` argument for
            ``paddle.amp.auto_cast``; only read in this function.

    Returns:
        None.
    """
    cal_metric_during_train = config['Global'].get('cal_metric_during_train',
                                                   False)
    calc_epoch_interval = config['Global'].get('calc_epoch_interval', 1)
    log_smooth_window = config['Global']['log_smooth_window']
    epoch_num = config['Global']['epoch_num']
    print_batch_step = config['Global']['print_batch_step']
    eval_batch_step = config['Global']['eval_batch_step']
    profiler_options = config['profiler_options']

    # Resume the iteration counter when the checkpoint provides one.
    global_step = 0
    if 'global_step' in pre_best_model_dict:
        global_step = pre_best_model_dict['global_step']
    start_eval_step = 0
    # eval_batch_step given as [start, interval]: evaluation begins after
    # `start` iterations and then runs every `interval` iterations.
    if type(eval_batch_step) == list and len(eval_batch_step) >= 2:
        start_eval_step = eval_batch_step[0]
        eval_batch_step = eval_batch_step[1]
        if len(valid_dataloader) == 0:
            logger.info(
                'No Images in eval dataset, evaluation during training ' \
                'will be disabled'
            )
            # Huge sentinel so `global_step > start_eval_step` never holds.
            start_eval_step = 1e111
        logger.info(
            "During the training process, after the {}th iteration, " \
            "an evaluation is run every {} iterations".
            format(start_eval_step, eval_batch_step))
    save_epoch_step = config['Global']['save_epoch_step']
    save_model_dir = config['Global']['save_model_dir']
    if not os.path.exists(save_model_dir):
        os.makedirs(save_model_dir)
    main_indicator = eval_class.main_indicator
    # Start from a zero best score, then overlay whatever was restored.
    best_model_dict = {main_indicator: 0}
    best_model_dict.update(pre_best_model_dict)
    train_stats = TrainingStats(log_smooth_window, ['lr'])
    model_average = False
    model.train()

    use_srn = config['Architecture']['algorithm'] == "SRN"
    # Algorithms whose forward pass is called as model(images, data=batch[1:]).
    extra_input_models = [
        "SRN", "NRTR", "SAR", "SEED", "SVTR", "SPIN", "VisionLAN",
        "RobustScanner", "RFL", 'DRRG'
    ]
    extra_input = False
    if config['Architecture']['algorithm'] == 'Distillation':
        # For distillation, any sub-model needing extra input enables it.
        for key in config['Architecture']["Models"]:
            extra_input = extra_input or config['Architecture']['Models'][key][
                'algorithm'] in extra_input_models
    else:
        extra_input = config['Architecture']['algorithm'] in extra_input_models
    # model_type is optional in the config; fall back to None when absent.
    # NOTE(review): the bare except also hides any other error raised here.
    try:
        model_type = config['Architecture']['model_type']
    except:
        model_type = None

    algorithm = config['Architecture']['algorithm']

    start_epoch = best_model_dict[
        'start_epoch'] if 'start_epoch' in best_model_dict else 1

    # Accumulators for the periodic throughput log; reset after each log line.
    total_samples = 0
    train_reader_cost = 0.0
    train_batch_cost = 0.0
    reader_start = time.time()
    eta_meter = AverageMeter()

    # On Windows the loop stops one batch early (see the `idx >= max_iter`
    # check below). NOTE(review): the reason is not visible in this file.
    max_iter = len(train_dataloader) - 1 if platform.system(
    ) == "Windows" else len(train_dataloader)

    for epoch in range(start_epoch, epoch_num + 1):
        if train_dataloader.dataset.need_reset:
            # Rebuild the loader, seeding it with the epoch number.
            train_dataloader = build_dataloader(
                config, 'Train', device, logger, seed=epoch)
            max_iter = len(train_dataloader) - 1 if platform.system(
            ) == "Windows" else len(train_dataloader)

        for idx, batch in enumerate(train_dataloader):
            profiler.add_profiler_step(profiler_options)
            # Time elapsed since the previous iteration ended = data loading.
            train_reader_cost += time.time() - reader_start
            if idx >= max_iter:
                break
            lr = optimizer.get_lr()
            images = batch[0]
            if use_srn:
                model_average = True
            # Mixed-precision path: forward under auto_cast, then cast the
            # predictions back to float32 before computing the loss.
            if scaler:
                with paddle.amp.auto_cast(
                        level=amp_level,
                        custom_black_list=amp_custom_black_list):
                    if model_type == 'table' or extra_input:
                        preds = model(images, data=batch[1:])
                    # NOTE(review): the non-AMP branch below also routes 'sr'
                    # through model(batch); here 'sr' falls through to
                    # model(images). Confirm whether that is intended.
                    elif model_type in ["kie"]:
                        preds = model(batch)
                    elif algorithm in ['CAN']:
                        preds = model(batch[:3])
                    else:
                        preds = model(images)
                preds = to_float32(preds)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                scaled_avg_loss = scaler.scale(avg_loss)
                scaled_avg_loss.backward()
                scaler.minimize(optimizer, scaled_avg_loss)
            else:
                # Full-precision path.
                if model_type == 'table' or extra_input:
                    preds = model(images, data=batch[1:])
                elif model_type in ["kie", 'sr']:
                    preds = model(batch)
                elif algorithm in ['CAN']:
                    preds = model(batch[:3])
                else:
                    preds = model(images)
                loss = loss_class(preds, batch)
                avg_loss = loss['loss']
                avg_loss.backward()
                optimizer.step()

            optimizer.clear_grad()

            # Optional training-time metric, computed only on epochs that are
            # a multiple of calc_epoch_interval (only rec and cls need it).
            if cal_metric_during_train and epoch % calc_epoch_interval == 0:
                batch = [item.numpy() for item in batch]
                if model_type in ['kie', 'sr']:
                    eval_class(preds, batch)
                elif model_type in ['table']:
                    post_result = post_process_class(preds, batch)
                    eval_class(post_result, batch)
                elif algorithm in ['CAN']:
                    # NOTE(review): model_type is switched to 'can' only when
                    # this branch runs, yet eval() below dispatches on 'can';
                    # verify CAN evaluation when cal_metric_during_train is
                    # off.
                    model_type = 'can'
                    eval_class(preds[0], batch[2:], epoch_reset=(idx == 0))
                else:
                    if config['Loss']['name'] in ['MultiLoss', 'MultiLoss_v2'
                                                  ]:  # for multi head loss
                        post_result = post_process_class(
                            preds['ctc'], batch[1])  # for CTC head out
                    elif config['Loss']['name'] in ['VLLoss']:
                        post_result = post_process_class(preds, batch[1],
                                                         batch[-1])
                    else:
                        post_result = post_process_class(preds, batch[1])
                    eval_class(post_result, batch)
                metric = eval_class.get_metric()
                train_stats.update(metric)

            # Whole-iteration time, measured from the same start point as the
            # reader cost, so it includes data loading.
            train_batch_time = time.time() - reader_start
            train_batch_cost += train_batch_time
            eta_meter.update(train_batch_time)
            global_step += 1
            total_samples += len(images)

            # A float lr_scheduler is a constant learning rate: nothing to step.
            if not isinstance(lr_scheduler, float):
                lr_scheduler.step()

            # Collect loss terms and learning rate for logging.
            stats = {k: v.numpy().mean() for k, v in loss.items()}
            stats['lr'] = lr
            train_stats.update(stats)

            if log_writer is not None and dist.get_rank() == 0:
                log_writer.log_metrics(
                    metrics=train_stats.get(), prefix="TRAIN", step=global_step)

            # Rank 0 prints every print_batch_step iterations and on the last
            # batch of each epoch.
            if dist.get_rank() == 0 and (
                (global_step > 0 and global_step % print_batch_step == 0) or
                (idx >= len(train_dataloader) - 1)):
                logs = train_stats.log()

                # Remaining iterations (rest of this epoch plus all later
                # epochs) multiplied by the average iteration time.
                eta_sec = ((epoch_num + 1 - epoch) * \
                    len(train_dataloader) - idx - 1) * eta_meter.avg
                eta_sec_format = str(datetime.timedelta(seconds=int(eta_sec)))
                strs = 'epoch: [{}/{}], global_step: {}, {}, avg_reader_cost: ' \
                       '{:.5f} s, avg_batch_cost: {:.5f} s, avg_samples: {}, ' \
                       'ips: {:.5f} samples/s, eta: {}'.format(
                    epoch, epoch_num, global_step, logs,
                    train_reader_cost / print_batch_step,
                    train_batch_cost / print_batch_step,
                    total_samples / print_batch_step,
                    total_samples / train_batch_cost, eta_sec_format)
                logger.info(strs)

                total_samples = 0
                train_reader_cost = 0.0
                train_batch_cost = 0.0
            # Periodic evaluation, rank 0 only.
            if global_step > start_eval_step and \
                    (global_step - start_eval_step) % eval_batch_step == 0 \
                    and dist.get_rank() == 0:
                if model_average:
                    # A new ModelAverage is built over the model parameters
                    # and applied before every evaluation.
                    Model_Average = paddle.incubate.optimizer.ModelAverage(
                        0.15,
                        parameters=model.parameters(),
                        min_average_window=10000,
                        max_average_window=15625)
                    Model_Average.apply()
                cur_metric = eval(
                    model,
                    valid_dataloader,
                    post_process_class,
                    eval_class,
                    model_type,
                    extra_input=extra_input,
                    scaler=scaler,
                    amp_level=amp_level,
                    amp_custom_black_list=amp_custom_black_list)
                cur_metric_str = 'cur metric, {}'.format(', '.join(
                    ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
                logger.info(cur_metric_str)

                # Log the evaluation metrics.
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics=cur_metric, prefix="EVAL", step=global_step)

                # Ties count as an improvement (>=), so the newer model wins.
                if cur_metric[main_indicator] >= best_model_dict[
                        main_indicator]:
                    best_model_dict.update(cur_metric)
                    best_model_dict['best_epoch'] = epoch
                    save_model(
                        model,
                        optimizer,
                        save_model_dir,
                        logger,
                        config,
                        is_best=True,
                        prefix='best_accuracy',
                        best_model_dict=best_model_dict,
                        epoch=epoch,
                        global_step=global_step)
                best_str = 'best metric, {}'.format(', '.join([
                    '{}: {}'.format(k, v) for k, v in best_model_dict.items()
                ]))
                logger.info(best_str)
                # Log the best metric and the best model.
                if log_writer is not None:
                    log_writer.log_metrics(
                        metrics={
                            "best_{}".format(main_indicator):
                            best_model_dict[main_indicator]
                        },
                        prefix="EVAL",
                        step=global_step)

                    log_writer.log_model(
                        is_best=True,
                        prefix="best_accuracy",
                        metadata=best_model_dict)

            # Restart the timer so the next iteration measures its own
            # data-loading time.
            reader_start = time.time()
        # End of epoch: rank 0 always refreshes the 'latest' checkpoint.
        if dist.get_rank() == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='latest',
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)

            if log_writer is not None:
                log_writer.log_model(is_best=False, prefix="latest")

        # Additionally keep a numbered checkpoint every save_epoch_step epochs.
        if dist.get_rank() == 0 and epoch > 0 and epoch % save_epoch_step == 0:
            save_model(
                model,
                optimizer,
                save_model_dir,
                logger,
                config,
                is_best=False,
                prefix='iter_epoch_{}'.format(epoch),
                best_model_dict=best_model_dict,
                epoch=epoch,
                global_step=global_step)
            if log_writer is not None:
                log_writer.log_model(
                    is_best=False, prefix='iter_epoch_{}'.format(epoch))

    # Training finished: report the best metrics and close the writer.
    best_str = 'best metric, {}'.format(', '.join(
        ['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
    logger.info(best_str)
    if dist.get_rank() == 0 and log_writer is not None:
        log_writer.close()
    return
def _eval_forward(model, batch, model_type, extra_input):
    """Run one forward pass, feeding the model the inputs its type expects."""
    images = batch[0]
    if model_type == 'table' or extra_input:
        return model(images, data=batch[1:])
    if model_type in ['kie', 'sr']:
        return model(batch)
    if model_type in ['can']:
        return model(batch[:3])
    return model(images)


def eval(model,
         valid_dataloader,
         post_process_class,
         eval_class,
         model_type=None,
         extra_input=False,
         scaler=None,
         amp_level='O2',
         amp_custom_black_list=[]):
    """
    Evaluate ``model`` on ``valid_dataloader`` and return the metric dict.

    Args:
        model: Network to evaluate; switched to eval mode for the duration of
            the call and back to train mode afterwards, even on error.
        valid_dataloader: Iterable of batches; ``batch[0]`` holds the images.
        post_process_class: Converts predictions before they are passed to
            ``eval_class``. May be None for 'table'/'kie' model types, in
            which case raw predictions are used.
        eval_class: Metric accumulator; called once per batch, then
            ``get_metric()`` is read at the end.
        model_type: Selects how the model is called and how results are
            scored ('table', 'kie', 'sr', 'can', or anything else for the
            default path).
        extra_input (bool): When True the model is called as
            ``model(images, data=batch[1:])``.
        scaler: When truthy, the forward pass runs under
            ``paddle.amp.auto_cast`` and predictions are cast to float32.
        amp_level: ``level`` argument for ``paddle.amp.auto_cast``.
        amp_custom_black_list: ``custom_black_list`` argument for
            ``paddle.amp.auto_cast``. The shared default list is only read,
            never mutated, by this function.

    Returns:
        dict: The result of ``eval_class.get_metric()`` with an extra
        ``'fps'`` entry (images per second over the timed forward and
        numpy-conversion section; 0.0 when nothing was timed).
    """
    model.eval()
    total_frame = 0.0
    total_time = 0.0
    pbar = tqdm(
        total=len(valid_dataloader),
        desc='eval model:',
        position=0,
        leave=True)
    # On Windows the loop stops one batch early.
    # NOTE(review): the reason is not visible in this file.
    max_iter = len(valid_dataloader) - 1 if platform.system(
    ) == "Windows" else len(valid_dataloader)
    try:
        with paddle.no_grad():
            for idx, batch in enumerate(valid_dataloader):
                if idx >= max_iter:
                    break
                images = batch[0]
                start = time.time()

                if scaler:
                    # Mixed-precision inference; cast back to float32 so the
                    # post-processing sees a uniform dtype.
                    with paddle.amp.auto_cast(
                            level=amp_level,
                            custom_black_list=amp_custom_black_list):
                        preds = _eval_forward(model, batch, model_type,
                                              extra_input)
                    preds = to_float32(preds)
                else:
                    preds = _eval_forward(model, batch, model_type,
                                          extra_input)

                # Tensors become numpy arrays; other items pass through.
                batch_numpy = [
                    item.numpy() if isinstance(item, paddle.Tensor) else item
                    for item in batch
                ]
                total_time += time.time() - start

                # Evaluate the results of the current batch.
                if model_type in ['table', 'kie']:
                    if post_process_class is None:
                        eval_class(preds, batch_numpy)
                    else:
                        post_result = post_process_class(preds, batch_numpy)
                        eval_class(post_result, batch_numpy)
                elif model_type in ['sr']:
                    eval_class(preds, batch_numpy)
                elif model_type in ['can']:
                    eval_class(
                        preds[0], batch_numpy[2:], epoch_reset=(idx == 0))
                else:
                    post_result = post_process_class(preds, batch_numpy[1])
                    eval_class(post_result, batch_numpy)

                pbar.update(1)
                total_frame += len(images)
            # Get final metric, e.g. acc or hmean.
            metric = eval_class.get_metric()
    finally:
        # Always release the progress bar and restore training mode.
        pbar.close()
        model.train()
    # Guard against an empty run (no batches timed) to avoid dividing by zero.
    metric['fps'] = total_frame / total_time if total_time > 0 else 0.0
    return metric
def update_center(char_center, post_result, preds):
    """
    Fold one batch into the running per-class mean feature vectors.

    Args:
        char_center (dict): Maps a predicted class index to
            ``[mean_feature, count]``; updated in place.
        post_result: Pair ``(result, label)``; a sample contributes only when
            ``result[i][0] == label[i][0]``.
        preds: Pair ``(feats, logits)`` of tensors; the class index for each
            step is the argmax of ``logits`` over the last axis.

    Returns:
        dict: The same ``char_center`` object.
    """
    result, label = post_result
    feats, logits = preds
    logits = paddle.argmax(logits, axis=-1)
    feats = feats.numpy()
    logits = logits.numpy()

    for sample_no in range(len(label)):
        # Skip samples whose prediction does not match the label.
        if not (result[sample_no][0] == label[sample_no][0]):
            continue
        sample_feat = feats[sample_no]
        sample_ids = logits[sample_no]
        for step, char_id in enumerate(sample_ids):
            if char_id in char_center:
                entry = char_center[char_id]
                # Incremental mean: (old_mean * n + new) / (n + 1).
                entry[0] = (entry[0] * entry[1] + sample_feat[step]) / (
                    entry[1] + 1)
                entry[1] += 1
            else:
                char_center[char_id] = [sample_feat[step], 1]
    return char_center
def get_center(model, eval_dataloader, post_process_class):
    """
    Compute the mean feature vector per predicted class over a dataset.

    Args:
        model: Called as ``model(images)``; its output is passed to both
            ``post_process_class`` and ``update_center``.
        eval_dataloader: Iterable of batches; every item in a batch must
            provide ``.numpy()``. ``batch[0]`` holds the images and
            ``batch[1]`` is passed to the post-processor.
        post_process_class: Callable ``post_process_class(preds, batch[1])``.

    Returns:
        dict: Maps each class index to its mean feature (the running counts
        kept by ``update_center`` are dropped).
    """
    pbar = tqdm(total=len(eval_dataloader), desc='get center:')
    # On Windows the loop stops one batch early.
    # NOTE(review): the reason is not visible in this file.
    max_iter = len(eval_dataloader) - 1 if platform.system(
    ) == "Windows" else len(eval_dataloader)
    char_center = dict()
    try:
        for idx, batch in enumerate(eval_dataloader):
            if idx >= max_iter:
                break
            images = batch[0]
            preds = model(images)

            batch = [item.numpy() for item in batch]
            # Obtain usable results from post-processing methods
            post_result = post_process_class(preds, batch[1])

            # Update the running per-class centers.
            char_center = update_center(char_center, post_result, preds)
            pbar.update(1)
    finally:
        # Release the progress bar even if a batch fails.
        pbar.close()

    # Keep only the mean feature; discard the sample count.
    for key in char_center:
        char_center[key] = char_center[key][0]
    return char_center
def preprocess(is_train=False):
    """
    Parse the command line, load the config and set up device and loggers.

    Args:
        is_train (bool): When True, the merged config is written to
            ``<save_model_dir>/config.yml`` and logging also goes to
            ``<save_model_dir>/train.log``.

    Returns:
        tuple: ``(config, device, logger, log_writer)`` where ``log_writer``
        is a ``Loggers`` wrapper around the enabled VisualDL/W&B writers, or
        None when neither is enabled.

    Raises:
        AssertionError: If ``Architecture.algorithm`` is not a supported name.
    """
    FLAGS = ArgsParser().parse_args()
    config = load_config(FLAGS.config)
    config = merge_config(config, FLAGS.opt)
    profile_dic = {"profiler_options": FLAGS.profiler_options}
    config = merge_config(config, profile_dic)

    if is_train:
        # Persist the effective config next to the checkpoints.
        save_model_dir = config['Global']['save_model_dir']
        os.makedirs(save_model_dir, exist_ok=True)
        with open(os.path.join(save_model_dir, 'config.yml'), 'w') as f:
            yaml.dump(
                dict(config), f, default_flow_style=False, sort_keys=False)
        log_file = '{}/train.log'.format(save_model_dir)
    else:
        log_file = None
    logger = get_logger(log_file=log_file)

    # check if set use_gpu=True in paddlepaddle cpu version
    use_gpu = config['Global'].get('use_gpu', False)
    use_xpu = config['Global'].get('use_xpu', False)
    use_npu = config['Global'].get('use_npu', False)
    use_mlu = config['Global'].get('use_mlu', False)

    alg = config['Architecture']['algorithm']
    assert alg in [
        'EAST', 'DB', 'SAST', 'Rosetta', 'CRNN', 'STARNet', 'RARE', 'SRN',
        'CLS', 'PGNet', 'Distillation', 'NRTR', 'TableAttn', 'SAR', 'PSE',
        'SEED', 'SDMGR', 'LayoutXLM', 'LayoutLM', 'LayoutLMv2', 'PREN', 'FCE',
        'SVTR', 'ViTSTR', 'ABINet', 'DB++', 'TableMaster', 'SPIN', 'VisionLAN',
        'Gestalt', 'SLANet', 'RobustScanner', 'CT', 'RFL', 'DRRG', 'CAN',
        'Telescope'
    ]

    # Device precedence: xpu, then npu, then mlu, then gpu, otherwise cpu.
    if use_xpu:
        device = 'xpu:{0}'.format(os.getenv('FLAGS_selected_xpus', 0))
    elif use_npu:
        device = 'npu:{0}'.format(os.getenv('FLAGS_selected_npus', 0))
    elif use_mlu:
        device = 'mlu:{0}'.format(os.getenv('FLAGS_selected_mlus', 0))
    else:
        device = 'gpu:{}'.format(dist.ParallelEnv()
                                 .dev_id) if use_gpu else 'cpu'
    check_device(use_gpu, use_xpu, use_npu, use_mlu)

    device = paddle.set_device(device)

    config['Global']['distributed'] = dist.get_world_size() != 1

    loggers = []

    if 'use_visualdl' in config['Global'] and config['Global']['use_visualdl']:
        save_model_dir = config['Global']['save_model_dir']
        vdl_writer_path = '{}/vdl/'.format(save_model_dir)
        loggers.append(VDLLogger(vdl_writer_path))
    if ('use_wandb' in config['Global'] and
            config['Global']['use_wandb']) or 'wandb' in config:
        save_dir = config['Global']['save_model_dir']
        if "wandb" in config:
            # Updated in place, so the printed/saved config shows save_dir.
            wandb_params = config['wandb']
        else:
            wandb_params = dict()
        # Use the locally read save_dir: `save_model_dir` is only bound when
        # is_train is True or VisualDL is enabled, so relying on it here
        # raised UnboundLocalError in W&B-only, non-training runs.
        wandb_params.update({'save_dir': save_dir})
        loggers.append(WandbLogger(**wandb_params, config=config))
    print_dict(config, logger)

    if loggers:
        log_writer = Loggers(loggers)
    else:
        log_writer = None

    logger.info('train with paddle {} and device {}'.format(paddle.__version__,
                                                            device))
    return config, device, logger, log_writer
|