Browse Source

db分类网络

yangjun 1 năm trước cách đây
mục cha
commit
7a30ac2d7f

+ 167 - 0
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student_tw_card.yml

@@ -0,0 +1,167 @@
+Global:
+  debug: false
+  use_gpu: false
+  epoch_num: 500
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/ch_PP-OCR_V3_det_front/
+  save_epoch_step: 100
+  eval_batch_step:
+  - 0
+  - 400
+  cal_metric_during_train: false
+  pretrained_model: ./ch_PP-OCRv3_det_distill_train/student.pdparams
+  checkpoints: null
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: /home/aistudio/work/PaddleOCR/images/tw_01.png
+  save_res_path: ./jerome/det_db/predicts_db.txt
+  distributed: true
+  label_list: /home/aistudio/work/PaddleOCR/configs/det/ch_PP-OCRv3/label_list.txt
+  num_classes: 14
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform:
+  Backbone:
+    name: MobileNetV3
+    scale: 0.5
+    model_name: large
+    disable_se: True
+  Neck:
+    name: RSEFPN
+    out_channels: 96
+    shortcut: True
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.0001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 5.0e-05
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: /home/aistudio/data/data181819/tw_idcard_det_front_1000/
+    label_file_list:
+      - /home/aistudio/data/data181819/tw_idcard_det_front_1000/Label.txt
+    ratio_list: [1.0]
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 960
+        - 960
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.4
+        thresh_min: 0.3
+        thresh_max: 0.7
+    - MakeShrinkMap:
+        shrink_ratio: 0.4
+        min_text_size: 8
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+        - class_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 4
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: /home/aistudio/data/data181819/tw_idcard_det_front_1000/
+    label_file_list:
+      - /home/aistudio/data/data181819/tw_idcard_det_front_1000/Label.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest:
+        resize_long: 960
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 2

+ 167 - 0
configs/det/ch_PP-OCRv3/ch_PP-OCRv3_det_student_tw_card_back.yml

@@ -0,0 +1,167 @@
+Global:
+  debug: false
+  use_gpu: false
+  epoch_num: 500
+  log_smooth_window: 20
+  print_batch_step: 10
+  save_model_dir: ./output/ch_PP-OCR_V3_det_back_11111/
+  save_epoch_step: 100
+  eval_batch_step:
+  - 0
+  - 400
+  cal_metric_during_train: false
+  pretrained_model: ./ch_PP-OCRv3_det_distill_train/student.pdparams
+  checkpoints: null
+  save_inference_dir: null
+  use_visualdl: false
+  infer_img: /home/aistudio/work/PaddleOCR/images/tw_01.png
+  save_res_path: ./jerome/det_db/predicts_db.txt
+  distributed: true
+  label_list: /home/aistudio/work/PaddleOCR/configs/det/ch_PP-OCRv3/label_list.txt
+  num_classes: 11
+
+Architecture:
+  model_type: det
+  algorithm: DB
+  Transform:
+  Backbone:
+    name: MobileNetV3
+    scale: 0.5
+    model_name: large
+    disable_se: True
+  Neck:
+    name: RSEFPN
+    out_channels: 96
+    shortcut: True
+  Head:
+    name: DBHead
+    k: 50
+
+Loss:
+  name: DBLoss
+  balance_loss: true
+  main_loss_type: DiceLoss
+  alpha: 5
+  beta: 10
+  ohem_ratio: 3
+Optimizer:
+  name: Adam
+  beta1: 0.9
+  beta2: 0.999
+  lr:
+    name: Cosine
+    learning_rate: 0.0001
+    warmup_epoch: 2
+  regularizer:
+    name: L2
+    factor: 5.0e-05
+PostProcess:
+  name: DBPostProcess
+  thresh: 0.3
+  box_thresh: 0.6
+  max_candidates: 1000
+  unclip_ratio: 1.5
+Metric:
+  name: DetMetric
+  main_indicator: hmean
+Train:
+  dataset:
+    name: SimpleDataSet
+    data_dir: /home/aistudio/data/data183949/
+    label_file_list:
+      - /home/aistudio/data/data183949/Label.txt
+    ratio_list: [1.0]
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - IaaAugment:
+        augmenter_args:
+        - type: Fliplr
+          args:
+            p: 0.5
+        - type: Affine
+          args:
+            rotate:
+            - -10
+            - 10
+        - type: Resize
+          args:
+            size:
+            - 0.5
+            - 3
+    - EastRandomCropData:
+        size:
+        - 960
+        - 960
+        max_tries: 50
+        keep_ratio: true
+    - MakeBorderMap:
+        shrink_ratio: 0.4
+        thresh_min: 0.3
+        thresh_max: 0.7
+    - MakeShrinkMap:
+        shrink_ratio: 0.4
+        min_text_size: 8
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - threshold_map
+        - threshold_mask
+        - shrink_map
+        - shrink_mask
+        - class_mask
+  loader:
+    shuffle: true
+    drop_last: false
+    batch_size_per_card: 8
+    num_workers: 4
+Eval:
+  dataset:
+    name: SimpleDataSet
+    data_dir: /home/aistudio/data/data183949/
+    label_file_list:
+      - /home/aistudio/data/data183949/Label.txt
+    transforms:
+    - DecodeImage:
+        img_mode: BGR
+        channel_first: false
+    - DetLabelEncode: null
+    - DetResizeForTest:
+        resize_long: 960
+    - NormalizeImage:
+        scale: 1./255.
+        mean:
+        - 0.485
+        - 0.456
+        - 0.406
+        std:
+        - 0.229
+        - 0.224
+        - 0.225
+        order: hwc
+    - ToCHWImage: null
+    - KeepKeys:
+        keep_keys:
+        - image
+        - shape
+        - polys
+        - ignore_tags
+  loader:
+    shuffle: false
+    drop_last: false
+    batch_size_per_card: 1
+    num_workers: 2

+ 11 - 0
configs/det/ch_PP-OCRv3/label_list.txt

@@ -0,0 +1,11 @@
+background
+name
+bore_date
+gender
+id
+certification_date
+parent_name
+spouse_name_level
+birth_place
+address
+number

+ 24 - 4
ppocr/data/imaug/label_ops.py

@@ -43,14 +43,28 @@ class ClsLabelEncode(object):
 
 
 class DetLabelEncode(object):
-    def __init__(self, **kwargs):
-        pass
+    #def __init__(self, **kwargs):
+    #    pass
+    def __init__(self, label_list, num_classes=1, **kwargs):
+        self.num_classes = num_classes
+        self.label_list = []
+        if label_list is not None:
+            if isinstance(label_list, str):
+                with open(label_list, "r+", encoding="utf-8") as f:
+                    for line in f.readlines():
+                        self.label_list.append(line.replace("\n", ""))
+            else:
+                self.label_list = label_list
+ 
+        if num_classes != len(self.label_list):
+            assert "label_list长度与num_classes长度不符合"
 
     def __call__(self, data):
         label = data['label']
         label = json.loads(label)
         nBox = len(label)
         boxes, txts, txt_tags = [], [], []
+        classes = []
         for bno in range(0, nBox):
             box = label[bno]['points']
             txt = label[bno]['transcription']
@@ -58,17 +72,23 @@ class DetLabelEncode(object):
             txts.append(txt)
             if txt in ['*', '###']:
                 txt_tags.append(True)
+                if self.num_classes > 1:
+                    classes.append(-2)
             else:
                 txt_tags.append(False)
+                if self.num_classes > 1:
+                    classes.append(int(self.label_list.index(txt)))
         if len(boxes) == 0:
             return None
         boxes = self.expand_points_num(boxes)
         boxes = np.array(boxes, dtype=np.float32)
-        txt_tags = np.array(txt_tags, dtype=np.bool_)
-
+        txt_tags = np.array(txt_tags, dtype=np.bool)
+        classes = classes
         data['polys'] = boxes
         data['texts'] = txts
         data['ignore_tags'] = txt_tags
+        if self.num_classes > 1:
+            data['classes'] = classes
         return data
 
     def order_points_clockwise(self, pts):

+ 26 - 0
ppocr/data/imaug/label_ops.py.rej

@@ -0,0 +1,26 @@
+diff a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py	(rejected hunks)
+@@ -58,17 +72,24 @@ class DetLabelEncode(object):
+             txts.append(txt)
+             if txt in ['*', '###']:
+                 txt_tags.append(True)
++                if self.num_classes > 1:
++                    classes.append(-2)
+             else:
+                 txt_tags.append(False)
++                if self.num_classes > 1:
++                    classes.append(int(self.label_list.index(txt)))
+         if len(boxes) == 0:
+             return None
+         boxes = self.expand_points_num(boxes)
+         boxes = np.array(boxes, dtype=np.float32)
+         txt_tags = np.array(txt_tags, dtype=np.bool)
++        classes = classes
+ 
+         data['polys'] = boxes
+         data['texts'] = txts
+         data['ignore_tags'] = txt_tags
++        if self.num_classes > 1:
++            data['classes'] = classes
+         return data
+ 
+     def order_points_clockwise(self, pts):

+ 9 - 1
ppocr/data/imaug/make_shrink_map.py

@@ -35,19 +35,23 @@ class MakeShrinkMap(object):
     Typically following the process of class `MakeICDARData`.
     '''
 
-    def __init__(self, min_text_size=8, shrink_ratio=0.4, **kwargs):
+    def __init__(self, min_text_size=8, shrink_ratio=0.4, num_classes=1, **kwargs):
         self.min_text_size = min_text_size
         self.shrink_ratio = shrink_ratio
+        self.num_classes = num_classes
 
     def __call__(self, data):
         image = data['image']
         text_polys = data['polys']
         ignore_tags = data['ignore_tags']
+        if self.num_classes > 1:
+            classes = data['classes']
 
         h, w = image.shape[:2]
         text_polys, ignore_tags = self.validate_polygons(text_polys,
                                                          ignore_tags, h, w)
         gt = np.zeros((h, w), dtype=np.float32)
+        gt_class = np.zeros((h, w), dtype=np.float32)
         mask = np.ones((h, w), dtype=np.float32)
         for i in range(len(text_polys)):
             polygon = text_polys[i]
@@ -87,8 +91,12 @@ class MakeShrinkMap(object):
                 for each_shirnk in shrinked:
                     shirnk = np.array(each_shirnk).reshape(-1, 2)
                     cv2.fillPoly(gt, [shirnk.astype(np.int32)], 1)
+                    if self.num_classes > 1:
+                        cv2.fillPoly(gt_class, polygon.astype(np.int32)[np.newaxis, :, :], classes[i])
 
         data['shrink_map'] = gt
+        if self.num_classes > 1:
+            data['class_mask'] = gt_class
         data['shrink_mask'] = mask
         return data
 

+ 11 - 1
ppocr/data/imaug/random_crop_data.py

@@ -130,17 +130,21 @@ class EastRandomCropData(object):
                  max_tries=10,
                  min_crop_side_ratio=0.1,
                  keep_ratio=True,
+                 num_classes=1,
                  **kwargs):
         self.size = size
         self.max_tries = max_tries
         self.min_crop_side_ratio = min_crop_side_ratio
         self.keep_ratio = keep_ratio
+        self.num_classes = num_classes
 
     def __call__(self, data):
         img = data['image']
         text_polys = data['polys']
         ignore_tags = data['ignore_tags']
         texts = data['texts']
+        if self.num_classes > 1:
+            classes = data['classes']
         all_care_polys = [
             text_polys[i] for i, tag in enumerate(ignore_tags) if not tag
         ]
@@ -167,16 +171,22 @@ class EastRandomCropData(object):
         text_polys_crop = []
         ignore_tags_crop = []
         texts_crop = []
-        for poly, text, tag in zip(text_polys, texts, ignore_tags):
+        classes_crop = []
+
+        for poly, text, tag, class_index in zip(text_polys, texts, ignore_tags, classes):
             poly = ((poly - (crop_x, crop_y)) * scale).tolist()
             if not is_poly_outside_rect(poly, 0, 0, w, h):
                 text_polys_crop.append(poly)
                 ignore_tags_crop.append(tag)
                 texts_crop.append(text)
+                if self.num_classes > 1:
+                    classes_crop.append(class_index)
         data['image'] = img
         data['polys'] = np.array(text_polys_crop)
         data['ignore_tags'] = ignore_tags_crop
         data['texts'] = texts_crop
+        if self.num_classes > 1:
+            data['classes'] = classes_crop
         return data
 
 

+ 150 - 8
ppocr/losses/det_db_loss.py

@@ -24,6 +24,129 @@ from paddle import nn
 
 from .det_basic_loss import BalanceLoss, MaskL1Loss, DiceLoss
 
+import paddle
+import paddle.nn.functional as F
+
+
+class CrossEntropyLoss(nn.Layer):
+    """
+    Implements the cross entropy loss function.
+
+    Args:
+        weight (tuple|list|ndarray|Tensor, optional): A manual rescaling weight
+            given to each class. Its length must be equal to the number of classes.
+            Default ``None``.
+        ignore_index (int64, optional): Specifies a target value that is ignored
+            and does not contribute to the input gradient. Default ``255``.
+        top_k_percent_pixels (float, optional): the value lies in [0.0, 1.0].
+            When its value < 1.0, only compute the loss for the top k percent pixels
+            (e.g., the top 20% pixels). This is useful for hard pixel mining. Default ``1.0``.
+        data_format (str, optional): The tensor format to use, 'NCHW' or 'NHWC'. Default ``'NCHW'``.
+    """
+
+    def __init__(self,
+                 weight=None,
+                 ignore_index=255,
+                 top_k_percent_pixels=1.0,
+                 data_format='NCHW'):
+        super(CrossEntropyLoss, self).__init__()
+        self.ignore_index = ignore_index
+        self.top_k_percent_pixels = top_k_percent_pixels
+        self.EPS = 1e-8
+        self.data_format = data_format
+        if weight is not None:
+            self.weight = paddle.to_tensor(weight, dtype='float32')
+        else:
+            self.weight = None
+
+    def forward(self, logit, label, semantic_weights=None):
+        """
+        Forward computation.
+
+        Args:
+            logit (Tensor): Logit tensor, the data type is float32, float64. Shape is
+                (N, C), where C is number of classes, and if shape is more than 2D, this
+                is (N, C, D1, D2,..., Dk), k >= 1.
+            label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
+                value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
+                (N, D1, D2,..., Dk), k >= 1.
+            semantic_weights (Tensor, optional): Weights about loss for each pixels,
+                shape is the same as label. Default: None.
+        Returns:
+            (Tensor): The average loss.
+        """
+        channel_axis = 1 if self.data_format == 'NCHW' else -1
+        if self.weight is not None and logit.shape[channel_axis] != len(
+                self.weight):
+            raise ValueError(
+                'The number of weights = {} must be the same as the number of classes = {}.'
+                    .format(len(self.weight), logit.shape[channel_axis]))
+
+        if channel_axis == 1:
+            logit = paddle.transpose(logit, [0, 2, 3, 1])
+        label = label.astype('int64')
+        # In F.cross_entropy, the ignore_index is invalid, which needs to be fixed.
+        # When there is 255 in the label and paddle version <= 2.1.3, the cross_entropy OP will report an error, which is fixed in paddle develop version.
+        loss = F.cross_entropy(
+            logit,
+            label,
+            ignore_index=self.ignore_index,
+            reduction='none',
+            weight=self.weight)
+
+        return self._post_process_loss(logit, label, semantic_weights, loss)
+
+    def _post_process_loss(self, logit, label, semantic_weights, loss):
+        """
+        Consider mask and top_k to calculate the final loss.
+
+        Args:
+            logit (Tensor): Logit tensor, the data type is float32, float64. Shape is
+                (N, C), where C is number of classes, and if shape is more than 2D, this
+                is (N, C, D1, D2,..., Dk), k >= 1.
+            label (Tensor): Label tensor, the data type is int64. Shape is (N), where each
+                value is 0 <= label[i] <= C-1, and if shape is more than 2D, this is
+                (N, D1, D2,..., Dk), k >= 1.
+            semantic_weights (Tensor, optional): Weights about loss for each pixels,
+                shape is the same as label.
+            loss (Tensor): Loss tensor which is the output of cross_entropy. If soft_label
+                is False in cross_entropy, the shape of loss should be the same as the label.
+                If soft_label is True in cross_entropy, the shape of loss should be
+                (N, D1, D2,..., Dk, 1).
+        Returns:
+            (Tensor): The average loss.
+        """
+        mask = label != self.ignore_index
+        mask = paddle.cast(mask, 'float32')
+        label.stop_gradient = True
+        mask.stop_gradient = True
+
+        if loss.ndim > mask.ndim:
+            loss = paddle.squeeze(loss, axis=-1)
+        loss = loss * mask
+        if semantic_weights is not None:
+            loss = loss * semantic_weights
+
+        if self.weight is not None:
+            _one_hot = F.one_hot(label, logit.shape[-1])
+            coef = paddle.sum(_one_hot * self.weight, axis=-1)
+        else:
+            coef = paddle.ones_like(label)
+
+        if self.top_k_percent_pixels == 1.0:
+            avg_loss = paddle.mean(loss) / (paddle.mean(mask * coef) + self.EPS)
+        else:
+            loss = loss.reshape((-1,))
+            top_k_pixels = int(self.top_k_percent_pixels * loss.numel())
+            loss, indices = paddle.topk(loss, top_k_pixels)
+            coef = coef.reshape((-1,))
+            coef = paddle.gather(coef, indices)
+            coef.stop_gradient = True
+            coef = coef.astype('float32')
+            avg_loss = loss.mean() / (paddle.mean(coef) + self.EPS)
+
+        return avg_loss
+
 
 class DBLoss(nn.Layer):
     """
@@ -39,21 +162,27 @@ class DBLoss(nn.Layer):
                  beta=10,
                  ohem_ratio=3,
                  eps=1e-6,
+                 num_classes=1,
                  **kwargs):
         super(DBLoss, self).__init__()
         self.alpha = alpha
         self.beta = beta
+        self.num_classes = num_classes
         self.dice_loss = DiceLoss(eps=eps)
         self.l1_loss = MaskL1Loss(eps=eps)
         self.bce_loss = BalanceLoss(
             balance_loss=balance_loss,
             main_loss_type=main_loss_type,
             negative_ratio=ohem_ratio)
+        self.loss_func = CrossEntropyLoss()
 
     def forward(self, predicts, labels):
         predict_maps = predicts['maps']
-        label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[
-            1:]
+        if self.num_classes > 1:
+            predict_classes = predicts['classes']
+            label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask, class_mask = labels[1:]
+        else:
+            label_threshold_map, label_threshold_mask, label_shrink_map, label_shrink_mask = labels[1:]
         shrink_maps = predict_maps[:, 0, :, :]
         threshold_maps = predict_maps[:, 1, :, :]
         binary_maps = predict_maps[:, 2, :, :]
@@ -67,10 +196,23 @@ class DBLoss(nn.Layer):
         loss_shrink_maps = self.alpha * loss_shrink_maps
         loss_threshold_maps = self.beta * loss_threshold_maps
 
-        loss_all = loss_shrink_maps + loss_threshold_maps \
-                   + loss_binary_maps
-        losses = {'loss': loss_all, \
-                  "loss_shrink_maps": loss_shrink_maps, \
-                  "loss_threshold_maps": loss_threshold_maps, \
-                  "loss_binary_maps": loss_binary_maps}
+
+        # 处理
+        if self.num_classes > 1:
+            loss_classes = self.loss_func(predict_classes, class_mask)
+
+            loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps + loss_classes
+
+            losses = {'loss': loss_all,
+                      "loss_shrink_maps": loss_shrink_maps,
+                      "loss_threshold_maps": loss_threshold_maps,
+                      "loss_binary_maps": loss_binary_maps,
+                      "loss_classes": loss_classes}
+        else:
+            loss_all = loss_shrink_maps + loss_threshold_maps + loss_binary_maps
+
+            losses = {'loss': loss_all,
+                      "loss_shrink_maps": loss_shrink_maps,
+                      "loss_threshold_maps": loss_threshold_maps,
+                      "loss_binary_maps": loss_binary_maps}
         return losses

+ 20 - 6
ppocr/modeling/heads/det_db_head.py

@@ -31,8 +31,9 @@ def get_bias_attr(k):
 
 
 class Head(nn.Layer):
-    def __init__(self, in_channels, kernel_list=[3, 2, 2], **kwargs):
+    def __init__(self, in_channels, kernel_list=[3, 2, 2], num_classes=1 , **kwargs):
         super(Head, self).__init__()
+        self.num_classes = num_classes
 
         self.conv1 = nn.Conv2D(
             in_channels=in_channels,
@@ -65,7 +66,7 @@ class Head(nn.Layer):
             act="relu")
         self.conv3 = nn.Conv2DTranspose(
             in_channels=in_channels // 4,
-            out_channels=1,
+            out_channels=num_classes,
             kernel_size=kernel_list[2],
             stride=2,
             weight_attr=ParamAttr(
@@ -78,7 +79,8 @@ class Head(nn.Layer):
         x = self.conv2(x)
         x = self.conv_bn2(x)
         x = self.conv3(x)
-        x = F.sigmoid(x)
+        if self.num_classes == 1:
+            x = F.sigmoid(x)
         return x
 
 
@@ -90,11 +92,16 @@ class DBHead(nn.Layer):
         params(dict): super parameters for build DB network
     """
 
-    def __init__(self, in_channels, k=50, **kwargs):
+    def __init__(self, in_channels, num_classes=1, k=50, **kwargs):
         super(DBHead, self).__init__()
         self.k = k
+        self.num_classes = num_classes
         self.binarize = Head(in_channels, **kwargs)
         self.thresh = Head(in_channels, **kwargs)
+        if num_classes != 1:
+            self.classes = Head(in_channels, num_classes=num_classes)
+        else:
+            self.classes = None
 
     def step_function(self, x, y):
         return paddle.reciprocal(1 + paddle.exp(-self.k * (x - y)))
@@ -102,9 +109,16 @@ class DBHead(nn.Layer):
     def forward(self, x, targets=None):
         shrink_maps = self.binarize(x)
         if not self.training:
-            return {'maps': shrink_maps}
+            if self.num_classes == 1:
+                return {'maps': shrink_maps}
+            else:
+                classes = paddle.argmax(self.classes(x), axis=1, keepdim=True, dtype='int32')
+                return {'maps': shrink_maps, "classes": classes}
 
         threshold_maps = self.thresh(x)
         binary_maps = self.step_function(shrink_maps, threshold_maps)
         y = paddle.concat([shrink_maps, threshold_maps, binary_maps], axis=1)
-        return {'maps': y}
+        if self.num_classes == 1:
+            return {'maps': y}
+        else:
+            return {'maps': y, "classes": self.classes(x)}

+ 103 - 10
ppocr/postprocess/db_postprocess.py

@@ -77,6 +77,7 @@ class DBPostProcess(object):
                 continue
 
             score = self.box_score_fast(pred, points.reshape(-1, 2))
+            
             if self.box_thresh > score:
                 continue
 
@@ -101,7 +102,7 @@ class DBPostProcess(object):
             scores.append(score)
         return boxes, scores
 
-    def boxes_from_bitmap(self, pred, _bitmap, dest_width, dest_height):
+    def boxes_from_bitmap(self, pred, _bitmap,classes, dest_width, dest_height):
         '''
         _bitmap: single map with shape (1, H, W),
                 whose values are binarized as {0, 1}
@@ -118,9 +119,11 @@ class DBPostProcess(object):
             contours, _ = outs[0], outs[1]
 
         num_contours = min(len(contours), self.max_candidates)
-
+        
         boxes = []
         scores = []
+        class_indexes = []
+        class_scores = []
         for index in range(num_contours):
             contour = contours[index]
             points, sside = self.get_mini_boxes(contour)
@@ -128,9 +131,12 @@ class DBPostProcess(object):
                 continue
             points = np.array(points)
             if self.score_mode == "fast":
-                score = self.box_score_fast(pred, points.reshape(-1, 2))
+                score, class_index, class_score = self.box_score_fast(pred, points.reshape(-1, 2), classes)
             else:
-                score = self.box_score_slow(pred, contour)
+                score, class_index, class_score = self.box_score_slow(pred, contour, classes)
+
+
+            print("origin score:" + str(score))
             if self.box_thresh > score:
                 continue
 
@@ -146,7 +152,15 @@ class DBPostProcess(object):
                 np.round(box[:, 1] / height * dest_height), 0, dest_height)
             boxes.append(box.astype("int32"))
             scores.append(score)
-        return np.array(boxes, dtype="int32"), scores
+
+            class_indexes.append(class_index)
+            class_scores.append(class_score)
+
+        if classes is None:
+            return np.array(boxes, dtype="int32"), scores
+        else:
+            return np.array(boxes, dtype="int32"), scores, class_indexes, class_scores
+
 
     def unclip(self, box, unclip_ratio):
         poly = Polygon(box)
@@ -179,24 +193,53 @@ class DBPostProcess(object):
         ]
         return box, min(bounding_box[1])
 
-    def box_score_fast(self, bitmap, _box):
+    def box_score_fast(self, bitmap, _box,classes):
         '''
         box_score_fast: use bbox mean score as the mean score
         '''
+        # print(classes)
         h, w = bitmap.shape[:2]
         box = _box.copy()
+
         xmin = np.clip(np.floor(box[:, 0].min()).astype("int32"), 0, w - 1)
         xmax = np.clip(np.ceil(box[:, 0].max()).astype("int32"), 0, w - 1)
         ymin = np.clip(np.floor(box[:, 1].min()).astype("int32"), 0, h - 1)
         ymax = np.clip(np.ceil(box[:, 1].max()).astype("int32"), 0, h - 1)
 
+        # box__ = box.reshape(1, -1, 2)
+        
+
+
+
         mask = np.zeros((ymax - ymin + 1, xmax - xmin + 1), dtype=np.uint8)
         box[:, 0] = box[:, 0] - xmin
         box[:, 1] = box[:, 1] - ymin
         cv2.fillPoly(mask, box.reshape(1, -1, 2).astype("int32"), 1)
-        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+        if classes is None:
+            return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None
+        else:
+            k = 255
+            class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32)
+            
+            cv2.fillPoly(class_mask, box.reshape(1, -1, 2).astype(np.int32), 0)
+            classes = classes[ymin:ymax + 1, xmin:xmax + 1]
+
+            new_classes = classes + class_mask
+
+            # 拉平
+            a = new_classes.reshape(-1)
+            b = np.where(a >= k)
+            # print(len(b[0].tolist()))
+            classes = np.delete(a, b[0].tolist())
+
+            class_index = np.argmax(np.bincount(classes))
+            print(class_index)
+            class_score = np.sum(classes == class_index) / len(classes)
+            print(class_score)
+            return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score
+
+    def box_score_slow(self, bitmap, contour,classes):
 
-    def box_score_slow(self, bitmap, contour):
         '''
         box_score_slow: use polyon mean score as the mean score
         '''
@@ -215,7 +258,27 @@ class DBPostProcess(object):
         contour[:, 1] = contour[:, 1] - ymin
 
         cv2.fillPoly(mask, contour.reshape(1, -1, 2).astype("int32"), 1)
-        return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0]
+        if classes is None:
+            return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], None, None
+        else:
+            k = 999
+            class_mask = np.full((ymax - ymin + 1, xmax - xmin + 1), k, dtype=np.int32)
+
+            cv2.fillPoly(class_mask, contour.reshape(1, -1, 2).astype("int32"), 0)
+            classes = classes[ymin:ymax + 1, xmin:xmax + 1]
+
+            new_classes = classes + class_mask
+
+            # 拉平
+            a = new_classes.reshape(-1)
+            b = np.where(a >= k)
+            classes = np.delete(a, b[0].tolist())
+
+            class_index = np.argmax(np.bincount(classes))
+            class_score = np.sum(classes == class_index) / len(classes)
+
+            return cv2.mean(bitmap[ymin:ymax + 1, xmin:xmax + 1], mask)[0], class_index, class_score
+
 
     def __call__(self, outs_dict, shape_list):
         pred = outs_dict['maps']
@@ -224,6 +287,29 @@ class DBPostProcess(object):
         pred = pred[:, 0, :, :]
         segmentation = pred > self.thresh
 
+        print(pred.shape)
+
+        if "classes" in outs_dict:
+            classes = outs_dict['classes']
+            # print(classes)
+            # print("jerome1")
+            # print(classes.shape)
+            # print(classes)
+
+            # np.set_printoptions(threshold=np.inf)
+
+            if isinstance(classes, paddle.Tensor):
+                # classes = paddle.argmax(classes, axis=1, dtype='int32')
+                classes = classes.numpy()
+            # else:
+                # classes = np.argmax(classes, axis=1).astype(np.int32)
+
+            classes = classes[:, 0, :, :]
+            print(classes.shape)
+
+        else:
+            classes = None
+
         boxes_batch = []
         for batch_index in range(pred.shape[0]):
             src_h, src_w, ratio_h, ratio_w = shape_list[batch_index]
@@ -237,8 +323,15 @@ class DBPostProcess(object):
                 boxes, scores = self.polygons_from_bitmap(pred[batch_index],
                                                           mask, src_w, src_h)
             elif self.box_type == 'quad':
-                boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask,
+                if classes is None:
+                    boxes, scores = self.boxes_from_bitmap(pred[batch_index], mask, None,
                                                        src_w, src_h)
+              
+                else:
+                    boxes, scores, class_indexes, class_scores = self.boxes_from_bitmap(pred[batch_index], mask,
+                                                                                      classes[batch_index],
+                                                                                      src_w, src_h)
+                    boxes_batch.append({'points': boxes, "classes": class_indexes, "class_scores": class_scores})
             else:
                 raise ValueError("box_type can only be one of ['quad', 'poly']")
 

+ 2 - 0
tools/eval.py

@@ -73,6 +73,8 @@ def main():
         else:  # base rec model
             config['Architecture']["Head"]['out_channels'] = char_num
 
+    if "num_classes" in global_config:
+        config['Architecture']["Head"]['num_classes'] = global_config["num_classes"]
     model = build_model(config['Architecture'])
     extra_input_models = [
         "SRN", "NRTR", "SAR", "SEED", "SVTR", "VisionLAN", "RobustScanner"

+ 3 - 0
tools/export_model.py

@@ -238,6 +238,9 @@ def main():
     # for sr algorithm
     if config["Architecture"]["model_type"] == "sr":
         config['Architecture']["Transform"]['infer_mode'] = True
+        
+    if "num_classes" in config['Global']:
+        config['Architecture']["Head"]['num_classes'] = config['Global']["num_classes"]
     model = build_model(config["Architecture"])
     load_model(config, model, model_type=config['Architecture']["model_type"])
     model.eval()

+ 1 - 1
tools/infer/predict_det.py

@@ -217,7 +217,7 @@ class TextDetector(object):
         dt_boxes = np.array(dt_boxes_new)
         return dt_boxes
 
-    def __call__(self, img):
+    def __call__(self, img, cls=True):
         ori_im = img.copy()
         data = {'image': img}
 

+ 32 - 1
tools/infer_det.py

@@ -52,11 +52,39 @@ def draw_det_res(dt_boxes, config, img, img_name, save_path):
         cv2.imwrite(save_path, src_im)
         logger.info("The detected Image saved in {}".format(save_path))
 
+def draw_det_res_and_label(dt_boxes, classes, config, img, img_name, save_path):
+    label_list = config["Global"]["label_list"]
+    labels = []
+    if label_list is not None:
+        if isinstance(label_list, str):
+            with open(label_list, "r+", encoding="utf-8") as f:
+                for line in f.readlines():
+                    labels.append(line.replace("\n", ""))
+        else:
+            labels = label_list
+    if len(dt_boxes) > 0:
+        import cv2
+        index = 0
+        src_im = img
+        for box in dt_boxes:
+            box = box.astype(np.int32).reshape((-1, 1, 2))
+            cv2.polylines(src_im, [box], True, color=(255, 255, 0), thickness=2)
+
+            font = cv2.FONT_HERSHEY_SIMPLEX
+            src_im = cv2.putText(src_im, labels[classes[index]], (box[0][0][0], box[0][0][1]), font, 0.5, (255, 0, 0), 1)
+            index += 1
+        if not os.path.exists(save_path):
+            os.makedirs(save_path)
+        save_path = os.path.join(save_path, os.path.basename(img_name))
+        cv2.imwrite(save_path, src_im)
+        logger.info("The detected Image saved in {}".format(save_path))
 
 @paddle.no_grad()
 def main():
     global_config = config['Global']
 
+    if "num_classes" in global_config:
+        config['Architecture']["Head"]['num_classes'] = global_config["num_classes"]
     # build model
     model = build_model(config['Architecture'])
 
@@ -122,7 +150,10 @@ def main():
                     dt_boxes_json.append(tmp_json)
                 save_det_path = os.path.dirname(config['Global'][
                     'save_res_path']) + "/det_results/"
-                draw_det_res(boxes, config, src_img, file, save_det_path)
+                if "classes" in post_result[0]:
+                    draw_det_res_and_label(boxes, post_result[0]["classes"], config, src_img, file, save_det_path)
+                else:
+                    draw_det_res(boxes, config, src_img, file, save_det_path)
             otstr = file + "\t" + json.dumps(dt_boxes_json) + "\n"
             fout.write(otstr.encode())
 

+ 3 - 0
tools/train.py

@@ -118,6 +118,9 @@ def main(config, device, logger, vdl_writer):
         if config['PostProcess']['name'] == 'SARLabelDecode':  # for SAR model
             config['Loss']['ignore_index'] = char_num - 1
 
+    if "num_classes" in global_config:
+        config['Architecture']["Head"]['num_classes'] = global_config["num_classes"]
+        config['Loss']['num_classes'] = global_config["num_classes"]
     model = build_model(config['Architecture'])
 
     use_sync_bn = config["Global"].get("use_sync_bn", False)