123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191 |
- # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- This code is refer from:
- https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textdet/dense_heads/drrg_head.py
- """
- from __future__ import absolute_import
- from __future__ import division
- from __future__ import print_function
- import warnings
- import cv2
- import numpy as np
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- from .gcn import GCN
- from .local_graph import LocalGraphs
- from .proposal_local_graph import ProposalLocalGraphs
- class DRRGHead(nn.Layer):
- def __init__(self,
- in_channels,
- k_at_hops=(8, 4),
- num_adjacent_linkages=3,
- node_geo_feat_len=120,
- pooling_scale=1.0,
- pooling_output_size=(4, 3),
- nms_thr=0.3,
- min_width=8.0,
- max_width=24.0,
- comp_shrink_ratio=1.03,
- comp_ratio=0.4,
- comp_score_thr=0.3,
- text_region_thr=0.2,
- center_region_thr=0.2,
- center_region_area_thr=50,
- local_graph_thr=0.7,
- **kwargs):
- super().__init__()
- assert isinstance(in_channels, int)
- assert isinstance(k_at_hops, tuple)
- assert isinstance(num_adjacent_linkages, int)
- assert isinstance(node_geo_feat_len, int)
- assert isinstance(pooling_scale, float)
- assert isinstance(pooling_output_size, tuple)
- assert isinstance(comp_shrink_ratio, float)
- assert isinstance(nms_thr, float)
- assert isinstance(min_width, float)
- assert isinstance(max_width, float)
- assert isinstance(comp_ratio, float)
- assert isinstance(comp_score_thr, float)
- assert isinstance(text_region_thr, float)
- assert isinstance(center_region_thr, float)
- assert isinstance(center_region_area_thr, int)
- assert isinstance(local_graph_thr, float)
- self.in_channels = in_channels
- self.out_channels = 6
- self.downsample_ratio = 1.0
- self.k_at_hops = k_at_hops
- self.num_adjacent_linkages = num_adjacent_linkages
- self.node_geo_feat_len = node_geo_feat_len
- self.pooling_scale = pooling_scale
- self.pooling_output_size = pooling_output_size
- self.comp_shrink_ratio = comp_shrink_ratio
- self.nms_thr = nms_thr
- self.min_width = min_width
- self.max_width = max_width
- self.comp_ratio = comp_ratio
- self.comp_score_thr = comp_score_thr
- self.text_region_thr = text_region_thr
- self.center_region_thr = center_region_thr
- self.center_region_area_thr = center_region_area_thr
- self.local_graph_thr = local_graph_thr
- self.out_conv = nn.Conv2D(
- in_channels=self.in_channels,
- out_channels=self.out_channels,
- kernel_size=1,
- stride=1,
- padding=0)
- self.graph_train = LocalGraphs(
- self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
- self.pooling_scale, self.pooling_output_size, self.local_graph_thr)
- self.graph_test = ProposalLocalGraphs(
- self.k_at_hops, self.num_adjacent_linkages, self.node_geo_feat_len,
- self.pooling_scale, self.pooling_output_size, self.nms_thr,
- self.min_width, self.max_width, self.comp_shrink_ratio,
- self.comp_ratio, self.comp_score_thr, self.text_region_thr,
- self.center_region_thr, self.center_region_area_thr)
- pool_w, pool_h = self.pooling_output_size
- node_feat_len = (pool_w * pool_h) * (
- self.in_channels + self.out_channels) + self.node_geo_feat_len
- self.gcn = GCN(node_feat_len)
- def forward(self, inputs, targets=None):
- """
- Args:
- inputs (Tensor): Shape of :math:`(N, C, H, W)`.
- gt_comp_attribs (list[ndarray]): The padded text component
- attributes. Shape: (num_component, 8).
- Returns:
- tuple: Returns (pred_maps, (gcn_pred, gt_labels)).
- - | pred_maps (Tensor): Prediction map with shape
- :math:`(N, C_{out}, H, W)`.
- - | gcn_pred (Tensor): Prediction from GCN module, with
- shape :math:`(N, 2)`.
- - | gt_labels (Tensor): Ground-truth label with shape
- :math:`(N, 8)`.
- """
- if self.training:
- assert targets is not None
- gt_comp_attribs = targets[7]
- pred_maps = self.out_conv(inputs)
- feat_maps = paddle.concat([inputs, pred_maps], axis=1)
- node_feats, adjacent_matrices, knn_inds, gt_labels = self.graph_train(
- feat_maps, np.stack(gt_comp_attribs))
- gcn_pred = self.gcn(node_feats, adjacent_matrices, knn_inds)
- return pred_maps, (gcn_pred, gt_labels)
- else:
- return self.single_test(inputs)
- def single_test(self, feat_maps):
- r"""
- Args:
- feat_maps (Tensor): Shape of :math:`(N, C, H, W)`.
- Returns:
- tuple: Returns (edge, score, text_comps).
- - | edge (ndarray): The edge array of shape :math:`(N, 2)`
- where each row is a pair of text component indices
- that makes up an edge in graph.
- - | score (ndarray): The score array of shape :math:`(N,)`,
- corresponding to the edge above.
- - | text_comps (ndarray): The text components of shape
- :math:`(N, 9)` where each row corresponds to one box and
- its score: (x1, y1, x2, y2, x3, y3, x4, y4, score).
- """
- pred_maps = self.out_conv(feat_maps)
- feat_maps = paddle.concat([feat_maps, pred_maps], axis=1)
- none_flag, graph_data = self.graph_test(pred_maps, feat_maps)
- (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
- pivot_local_graphs, text_comps) = graph_data
- if none_flag:
- return None, None, None
- gcn_pred = self.gcn(local_graphs_node_feat, adjacent_matrices,
- pivots_knn_inds)
- pred_labels = F.softmax(gcn_pred, axis=1)
- edges = []
- scores = []
- pivot_local_graphs = pivot_local_graphs.squeeze().numpy()
- for pivot_ind, pivot_local_graph in enumerate(pivot_local_graphs):
- pivot = pivot_local_graph[0]
- for k_ind, neighbor_ind in enumerate(pivots_knn_inds[pivot_ind]):
- neighbor = pivot_local_graph[neighbor_ind.item()]
- edges.append([pivot, neighbor])
- scores.append(pred_labels[pivot_ind * pivots_knn_inds.shape[1] +
- k_ind, 1].item())
- edges = np.asarray(edges)
- scores = np.asarray(scores)
- return edges, scores, text_comps
|