1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283 |
- import copy
- import paddle
- import paddle.nn as nn
- from .basic_loss import LossFromOutput
- from .det_db_loss import DBLoss
- from .det_east_loss import EASTLoss
- from .det_sast_loss import SASTLoss
- from .det_pse_loss import PSELoss
- from .det_fce_loss import FCELoss
- from .det_ct_loss import CTLoss
- from .det_drrg_loss import DRRGLoss
- from .rec_ctc_loss import CTCLoss
- from .rec_att_loss import AttentionLoss
- from .rec_srn_loss import SRNLoss
- from .rec_ce_loss import CELoss
- from .rec_sar_loss import SARLoss
- from .rec_aster_loss import AsterLoss
- from .rec_pren_loss import PRENLoss
- from .rec_multi_loss import MultiLoss
- from .rec_vl_loss import VLLoss
- from .rec_spin_att_loss import SPINAttentionLoss
- from .rec_rfl_loss import RFLLoss
- from .rec_can_loss import CANLoss
- from .cls_loss import ClsLoss
- from .e2e_pg_loss import PGLoss
- from .kie_sdmgr_loss import SDMGRLoss
- from .basic_loss import DistanceLoss
- from .combined_loss import CombinedLoss
- from .table_att_loss import TableAttentionLoss, SLALoss
- from .table_master_loss import TableMasterLoss
- from .vqa_token_layoutlm_loss import VQASerTokenLayoutLMLoss
- from .stroke_focus_loss import StrokeFocusLoss
- from .text_focus_loss import TelescopeLoss
def build_loss(config):
    """Instantiate a loss object from a configuration dict.

    Args:
        config (dict): Must contain a ``'name'`` key naming one of the
            supported loss classes listed below. Every other key is passed
            to that class's constructor as a keyword argument. The dict is
            deep-copied first, so the caller's object is not mutated.

    Returns:
        An instance of the selected loss class.

    Raises:
        KeyError: if ``config`` has no ``'name'`` key.
        AssertionError: if ``config['name']`` is not a supported loss name.
            Raised explicitly (not via ``assert``) so the check still runs
            under ``python -O``.
    """
    support_dict = [
        'DBLoss', 'PSELoss', 'EASTLoss', 'SASTLoss', 'FCELoss', 'CTCLoss',
        'ClsLoss', 'AttentionLoss', 'SRNLoss', 'PGLoss', 'CombinedLoss',
        'CELoss', 'TableAttentionLoss', 'SARLoss', 'AsterLoss', 'SDMGRLoss',
        'VQASerTokenLayoutLMLoss', 'LossFromOutput', 'PRENLoss', 'MultiLoss',
        'TableMasterLoss', 'SPINAttentionLoss', 'VLLoss', 'StrokeFocusLoss',
        'SLALoss', 'CTLoss', 'RFLLoss', 'DRRGLoss', 'CANLoss', 'TelescopeLoss'
    ]
    # Work on a copy: popping 'name' must not alter the caller's config.
    config = copy.deepcopy(config)
    module_name = config.pop('name')
    # Explicit raise rather than `assert`: asserts are removed under -O,
    # which would let an unchecked config string reach the lookup below.
    # AssertionError is kept so existing callers that catch it still work.
    if module_name not in support_dict:
        raise AssertionError('loss only support {}, but got {!r}'.format(
            support_dict, module_name))
    # Resolve the class from this module's namespace instead of eval(), so a
    # config value is never executed as Python code.
    loss_class = globals()[module_name]
    return loss_class(**config)
|