- # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import math
- import six
- import numpy as np
- from numbers import Integral
- import paddle
- import paddle.nn as nn
- from paddle import ParamAttr
- from paddle import to_tensor
- import paddle.nn.functional as F
- from paddle.nn.initializer import Normal, Constant, XavierUniform
- from paddle.regularizer import L2Decay
- from ppdet.core.workspace import register, serializable
- from ppdet.modeling.bbox_utils import delta2bbox
- from . import ops
- from .initializer import xavier_uniform_, constant_
- from paddle.vision.ops import DeformConv2D
- def _to_list(l):
- if isinstance(l, (list, tuple)):
- return list(l)
- return [l]
- class AlignConv(nn.Layer):
- def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
- super(AlignConv, self).__init__()
- self.kernel_size = kernel_size
- self.align_conv = paddle.vision.ops.DeformConv2D(
- in_channels,
- out_channels,
- kernel_size=self.kernel_size,
- padding=(self.kernel_size - 1) // 2,
- groups=groups,
- weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
- bias_attr=None)
- @paddle.no_grad()
- def get_offset(self, anchors, featmap_size, stride):
- """
- Args:
- anchors (Tensor): refined anchors in (xc, yc, w, h, angle) format, shape [B, L, 5] with L = feat_h * feat_w
- featmap_size (tuple): (feat_h, feat_w)
- stride (int): stride of the feature map, e.g. 8
- Returns:
- offset (Tensor): offsets for the deformable conv, shape [B, 2 * kernel_size**2, feat_h, feat_w]
- """
- batch = anchors.shape[0]
- dtype = anchors.dtype
- feat_h, feat_w = featmap_size
- pad = (self.kernel_size - 1) // 2
- idx = paddle.arange(-pad, pad + 1, dtype=dtype)
- yy, xx = paddle.meshgrid(idx, idx)
- xx = paddle.reshape(xx, [-1])
- yy = paddle.reshape(yy, [-1])
- # get sampling locations of default conv
- xc = paddle.arange(0, feat_w, dtype=dtype)
- yc = paddle.arange(0, feat_h, dtype=dtype)
- yc, xc = paddle.meshgrid(yc, xc)
- xc = paddle.reshape(xc, [-1, 1])
- yc = paddle.reshape(yc, [-1, 1])
- x_conv = xc + xx
- y_conv = yc + yy
- # get sampling locations of anchors
- x_ctr, y_ctr, w, h, a = paddle.split(anchors, 5, axis=-1)
- x_ctr = x_ctr / stride
- y_ctr = y_ctr / stride
- w_s = w / stride
- h_s = h / stride
- cos, sin = paddle.cos(a), paddle.sin(a)
- dw, dh = w_s / self.kernel_size, h_s / self.kernel_size
- x, y = dw * xx, dh * yy
- xr = cos * x - sin * y
- yr = sin * x + cos * y
- x_anchor, y_anchor = xr + x_ctr, yr + y_ctr
- # get offset field
- offset_x = x_anchor - x_conv
- offset_y = y_anchor - y_conv
- offset = paddle.stack([offset_y, offset_x], axis=-1)
- offset = offset.reshape(
- [batch, feat_h, feat_w, self.kernel_size * self.kernel_size * 2])
- offset = offset.transpose([0, 3, 1, 2])
- return offset
- def forward(self, x, refine_anchors, featmap_size, stride):
- batch = paddle.shape(x)[0].numpy()
- offset = self.get_offset(refine_anchors, featmap_size, stride)
- if self.training:
- x = F.relu(self.align_conv(x, offset.detach()))
- else:
- x = F.relu(self.align_conv(x, offset))
- return x
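- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # AlignConv expects one refined anchor (xc, yc, w, h, angle) per feature-map
- # location, i.e. refine_anchors has shape [B, H * W, 5]; shapes and values below
- # are arbitrary.
- #
- #   align_conv = AlignConv(in_channels=256, out_channels=256, kernel_size=3)
- #   feat = paddle.rand([1, 256, 32, 32])
- #   anchors = paddle.rand([1, 32 * 32, 5])                # xc, yc, w, h, angle
- #   out = align_conv(feat, anchors, (32, 32), stride=8)   # [1, 256, 32, 32]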
- class DeformableConvV2(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- weight_attr=None,
- bias_attr=None,
- lr_scale=1,
- regularizer=None,
- skip_quant=False,
- dcn_bias_regularizer=L2Decay(0.),
- dcn_bias_lr_scale=2.):
- super(DeformableConvV2, self).__init__()
- self.offset_channel = 2 * kernel_size**2
- self.mask_channel = kernel_size**2
- if lr_scale == 1 and regularizer is None:
- offset_bias_attr = ParamAttr(initializer=Constant(0.))
- else:
- offset_bias_attr = ParamAttr(
- initializer=Constant(0.),
- learning_rate=lr_scale,
- regularizer=regularizer)
- self.conv_offset = nn.Conv2D(
- in_channels,
- 3 * kernel_size**2,
- kernel_size,
- stride=stride,
- padding=(kernel_size - 1) // 2,
- weight_attr=ParamAttr(initializer=Constant(0.0)),
- bias_attr=offset_bias_attr)
- if skip_quant:
- self.conv_offset.skip_quant = True
- if bias_attr:
- # the FCOS-DCN head specifically needs a learning_rate and regularizer on the DCN bias
- dcn_bias_attr = ParamAttr(
- initializer=Constant(value=0),
- regularizer=dcn_bias_regularizer,
- learning_rate=dcn_bias_lr_scale)
- else:
- # the ResNet backbone does not need a bias
- dcn_bias_attr = False
- self.conv_dcn = DeformConv2D(
- in_channels,
- out_channels,
- kernel_size,
- stride=stride,
- padding=(kernel_size - 1) // 2 * dilation,
- dilation=dilation,
- groups=groups,
- weight_attr=weight_attr,
- bias_attr=dcn_bias_attr)
- def forward(self, x):
- offset_mask = self.conv_offset(x)
- offset, mask = paddle.split(
- offset_mask,
- num_or_sections=[self.offset_channel, self.mask_channel],
- axis=1)
- mask = F.sigmoid(mask)
- y = self.conv_dcn(x, offset, mask=mask)
- return y
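- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # DeformableConvV2 predicts its own offsets and modulation mask from the input,
- # so it can be used as a drop-in replacement for a regular Conv2D; values below
- # are arbitrary.
- #
- #   dcn = DeformableConvV2(in_channels=64, out_channels=128, kernel_size=3)
- #   y = dcn(paddle.rand([2, 64, 56, 56]))                 # [2, 128, 56, 56]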
- class ConvNormLayer(nn.Layer):
- def __init__(self,
- ch_in,
- ch_out,
- filter_size,
- stride,
- groups=1,
- norm_type='bn',
- norm_decay=0.,
- norm_groups=32,
- use_dcn=False,
- bias_on=False,
- lr_scale=1.,
- freeze_norm=False,
- initializer=Normal(
- mean=0., std=0.01),
- skip_quant=False,
- dcn_lr_scale=2.,
- dcn_regularizer=L2Decay(0.)):
- super(ConvNormLayer, self).__init__()
- assert norm_type in ['bn', 'sync_bn', 'gn', None]
- if bias_on:
- bias_attr = ParamAttr(
- initializer=Constant(value=0.), learning_rate=lr_scale)
- else:
- bias_attr = False
- if not use_dcn:
- self.conv = nn.Conv2D(
- in_channels=ch_in,
- out_channels=ch_out,
- kernel_size=filter_size,
- stride=stride,
- padding=(filter_size - 1) // 2,
- groups=groups,
- weight_attr=ParamAttr(
- initializer=initializer, learning_rate=1.),
- bias_attr=bias_attr)
- if skip_quant:
- self.conv.skip_quant = True
- else:
- # the FCOS-DCN head specifically needs a learning_rate and regularizer
- self.conv = DeformableConvV2(
- in_channels=ch_in,
- out_channels=ch_out,
- kernel_size=filter_size,
- stride=stride,
- padding=(filter_size - 1) // 2,
- groups=groups,
- weight_attr=ParamAttr(
- initializer=initializer, learning_rate=1.),
- bias_attr=True,
- lr_scale=dcn_lr_scale,
- regularizer=dcn_regularizer,
- dcn_bias_regularizer=dcn_regularizer,
- dcn_bias_lr_scale=dcn_lr_scale,
- skip_quant=skip_quant)
- norm_lr = 0. if freeze_norm else 1.
- param_attr = ParamAttr(
- learning_rate=norm_lr,
- regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
- bias_attr = ParamAttr(
- learning_rate=norm_lr,
- regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
- if norm_type in ['bn', 'sync_bn']:
- self.norm = nn.BatchNorm2D(
- ch_out, weight_attr=param_attr, bias_attr=bias_attr)
- elif norm_type == 'gn':
- self.norm = nn.GroupNorm(
- num_groups=norm_groups,
- num_channels=ch_out,
- weight_attr=param_attr,
- bias_attr=bias_attr)
- else:
- self.norm = None
- def forward(self, inputs):
- out = self.conv(inputs)
- if self.norm is not None:
- out = self.norm(out)
- return out
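- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # ConvNormLayer bundles an (optionally deformable) convolution with an optional
- # bn / sync_bn / gn normalization; values below are arbitrary.
- #
- #   conv_gn = ConvNormLayer(ch_in=64, ch_out=64, filter_size=3, stride=1,
- #                           norm_type='gn', norm_groups=32)
- #   y = conv_gn(paddle.rand([2, 64, 32, 32]))             # [2, 64, 32, 32]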
- class LiteConv(nn.Layer):
- def __init__(self,
- in_channels,
- out_channels,
- stride=1,
- with_act=True,
- norm_type='sync_bn',
- name=None):
- super(LiteConv, self).__init__()
- self.lite_conv = nn.Sequential()
- conv1 = ConvNormLayer(
- in_channels,
- in_channels,
- filter_size=5,
- stride=stride,
- groups=in_channels,
- norm_type=norm_type,
- initializer=XavierUniform())
- conv2 = ConvNormLayer(
- in_channels,
- out_channels,
- filter_size=1,
- stride=stride,
- norm_type=norm_type,
- initializer=XavierUniform())
- conv3 = ConvNormLayer(
- out_channels,
- out_channels,
- filter_size=1,
- stride=stride,
- norm_type=norm_type,
- initializer=XavierUniform())
- conv4 = ConvNormLayer(
- out_channels,
- out_channels,
- filter_size=5,
- stride=stride,
- groups=out_channels,
- norm_type=norm_type,
- initializer=XavierUniform())
- conv_list = [conv1, conv2, conv3, conv4]
- self.lite_conv.add_sublayer('conv1', conv1)
- self.lite_conv.add_sublayer('relu6_1', nn.ReLU6())
- self.lite_conv.add_sublayer('conv2', conv2)
- if with_act:
- self.lite_conv.add_sublayer('relu6_2', nn.ReLU6())
- self.lite_conv.add_sublayer('conv3', conv3)
- self.lite_conv.add_sublayer('relu6_3', nn.ReLU6())
- self.lite_conv.add_sublayer('conv4', conv4)
- if with_act:
- self.lite_conv.add_sublayer('relu6_4', nn.ReLU6())
- def forward(self, inputs):
- out = self.lite_conv(inputs)
- return out
- class DropBlock(nn.Layer):
- def __init__(self, block_size, keep_prob, name=None, data_format='NCHW'):
- """
- DropBlock layer, see https://arxiv.org/abs/1810.12890
- Args:
- block_size (int): block size
- keep_prob (float): keep probability
- name (str): layer name
- data_format (str): data format, NCHW or NHWC
- """
- super(DropBlock, self).__init__()
- self.block_size = block_size
- self.keep_prob = keep_prob
- self.name = name
- self.data_format = data_format
- def forward(self, x):
- if not self.training or self.keep_prob == 1:
- return x
- else:
- gamma = (1. - self.keep_prob) / (self.block_size**2)
- if self.data_format == 'NCHW':
- shape = x.shape[2:]
- else:
- shape = x.shape[1:3]
- for s in shape:
- gamma *= s / (s - self.block_size + 1)
- matrix = paddle.cast(paddle.rand(x.shape) < gamma, x.dtype)
- mask_inv = F.max_pool2d(
- matrix,
- self.block_size,
- stride=1,
- padding=self.block_size // 2,
- data_format=self.data_format)
- mask = 1. - mask_inv
- y = x * mask * (mask.numel() / mask.sum())
- return y
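- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # DropBlock is only active in training mode; in eval mode or with keep_prob == 1
- # the input passes through unchanged. Values below are arbitrary.
- #
- #   drop_block = DropBlock(block_size=3, keep_prob=0.9)
- #   drop_block.train()
- #   out = drop_block(paddle.rand([2, 64, 32, 32]))        # same shape, blocks zeroed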
- @register
- @serializable
- class AnchorGeneratorSSD(object):
- def __init__(self,
- steps=[8, 16, 32, 64, 100, 300],
- aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
- min_ratio=15,
- max_ratio=90,
- base_size=300,
- min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
- max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
- offset=0.5,
- flip=True,
- clip=False,
- min_max_aspect_ratios_order=False):
- self.steps = steps
- self.aspect_ratios = aspect_ratios
- self.min_ratio = min_ratio
- self.max_ratio = max_ratio
- self.base_size = base_size
- self.min_sizes = min_sizes
- self.max_sizes = max_sizes
- self.offset = offset
- self.flip = flip
- self.clip = clip
- self.min_max_aspect_ratios_order = min_max_aspect_ratios_order
- if self.min_sizes == [] and self.max_sizes == []:
- num_layer = len(aspect_ratios)
- step = int(math.floor((self.max_ratio - self.min_ratio) / (num_layer - 2)))
- for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1,
- step):
- self.min_sizes.append(self.base_size * ratio / 100.)
- self.max_sizes.append(self.base_size * (ratio + step) / 100.)
- self.min_sizes = [self.base_size * .10] + self.min_sizes
- self.max_sizes = [self.base_size * .20] + self.max_sizes
- self.num_priors = []
- for aspect_ratio, min_size, max_size in zip(
- aspect_ratios, self.min_sizes, self.max_sizes):
- if isinstance(min_size, (list, tuple)):
- self.num_priors.append(
- len(_to_list(min_size)) + len(_to_list(max_size)))
- else:
- self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
- _to_list(min_size)) + len(_to_list(max_size)))
- def __call__(self, inputs, image):
- boxes = []
- for input, min_size, max_size, aspect_ratio, step in zip(
- inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
- self.steps):
- box, _ = ops.prior_box(
- input=input,
- image=image,
- min_sizes=_to_list(min_size),
- max_sizes=_to_list(max_size),
- aspect_ratios=aspect_ratio,
- flip=self.flip,
- clip=self.clip,
- steps=[step, step],
- offset=self.offset,
- min_max_aspect_ratios_order=self.min_max_aspect_ratios_order)
- boxes.append(paddle.reshape(box, [-1, 4]))
- return boxes
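- # --- Worked example (editor's note, not part of the original module) ---
- # With the default (scalar) sizes, the number of priors per location follows the
- # standard SSD300 layout of (len(aspect_ratio) * 2 + 1) + 1 boxes per layer:
- #
- #   anchor_gen = AnchorGeneratorSSD()
- #   print(anchor_gen.num_priors)                          # [4, 6, 6, 6, 4, 4]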
- @register
- @serializable
- class RCNNBox(object):
- __shared__ = ['num_classes', 'export_onnx']
- def __init__(self,
- prior_box_var=[10., 10., 5., 5.],
- code_type="decode_center_size",
- box_normalized=False,
- num_classes=80,
- export_onnx=False):
- super(RCNNBox, self).__init__()
- self.prior_box_var = prior_box_var
- self.code_type = code_type
- self.box_normalized = box_normalized
- self.num_classes = num_classes
- self.export_onnx = export_onnx
- def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
- bbox_pred = bbox_head_out[0]
- cls_prob = bbox_head_out[1]
- roi = rois[0]
- rois_num = rois[1]
- if self.export_onnx:
- onnx_rois_num_per_im = rois_num[0]
- origin_shape = paddle.expand(im_shape[0, :],
- [onnx_rois_num_per_im, 2])
- else:
- origin_shape_list = []
- if isinstance(roi, list):
- batch_size = len(roi)
- else:
- batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])
- # bbox_pred.shape: [N, C*4]
- for idx in range(batch_size):
- rois_num_per_im = rois_num[idx]
- expand_im_shape = paddle.expand(im_shape[idx, :],
- [rois_num_per_im, 2])
- origin_shape_list.append(expand_im_shape)
- origin_shape = paddle.concat(origin_shape_list)
- # bbox_pred.shape: [N, C*4]
- # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
- bbox = paddle.concat(roi)
- bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
- scores = cls_prob[:, :-1]
- # bbox.shape: [N, C, 4]
- # bbox.shape[1] must be equal to scores.shape[1]
- total_num = bbox.shape[0]
- bbox_dim = bbox.shape[-1]
- bbox = paddle.expand(bbox, [total_num, self.num_classes, bbox_dim])
- origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
- origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
- zeros = paddle.zeros_like(origin_h)
- x1 = paddle.maximum(paddle.minimum(bbox[:, :, 0], origin_w), zeros)
- y1 = paddle.maximum(paddle.minimum(bbox[:, :, 1], origin_h), zeros)
- x2 = paddle.maximum(paddle.minimum(bbox[:, :, 2], origin_w), zeros)
- y2 = paddle.maximum(paddle.minimum(bbox[:, :, 3], origin_h), zeros)
- bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
- bboxes = (bbox, rois_num)
- return bboxes, scores
- @register
- @serializable
- class MultiClassNMS(object):
- def __init__(self,
- score_threshold=.05,
- nms_top_k=-1,
- keep_top_k=100,
- nms_threshold=.5,
- normalized=True,
- nms_eta=1.0,
- return_index=False,
- return_rois_num=True,
- trt=False):
- super(MultiClassNMS, self).__init__()
- self.score_threshold = score_threshold
- self.nms_top_k = nms_top_k
- self.keep_top_k = keep_top_k
- self.nms_threshold = nms_threshold
- self.normalized = normalized
- self.nms_eta = nms_eta
- self.return_index = return_index
- self.return_rois_num = return_rois_num
- self.trt = trt
- def __call__(self, bboxes, score, background_label=-1):
- """
- bboxes (Tensor|List[Tensor]): 1. (Tensor) Predicted bboxes with shape
- [N, M, 4], N is the batch size and M
- is the number of bboxes
- 2. (List[Tensor]) bboxes and bbox_num,
- bboxes have shape of [M, C, 4], C
- is the class number and bbox_num means
- the number of bboxes of each batch with
- shape [N,]
- score (Tensor): Predicted scores with shape [N, C, M] or [M, C]
- background_label (int): Background label index to ignore, e.g.
- num_classes for RCNN and -1 for YOLO.
- """
- kwargs = self.__dict__.copy()
- if isinstance(bboxes, tuple):
- bboxes, bbox_num = bboxes
- kwargs.update({'rois_num': bbox_num})
- if background_label > -1:
- kwargs.update({'background_label': background_label})
- kwargs.pop('trt')
- # TODO(wangxinxin08): Paddle develop or version >= 2.3 is required to run NMS on TensorRT
- if self.trt and (int(paddle.version.major) == 0 or
- (int(paddle.version.major) >= 2 and
- int(paddle.version.minor) >= 3)):
- # TODO(wangxinxin08): tricky switch to run nms on tensorrt
- kwargs.update({'nms_eta': 1.1})
- bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs)
- bbox = bbox.reshape([1, -1, 6])
- idx = paddle.nonzero(bbox[..., 0] != -1)
- bbox = paddle.gather_nd(bbox, idx)
- return bbox, bbox_num, None
- else:
- return ops.multiclass_nms(bboxes, score, **kwargs)
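- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # MultiClassNMS is configured once and then called with decoded boxes and class
- # scores; shapes follow the __call__ docstring, and the three-tuple unpacking
- # mirrors the TensorRT branch above (index is None when return_index is False).
- #
- #   nms = MultiClassNMS(score_threshold=0.05, nms_threshold=0.5, keep_top_k=100)
- #   bboxes = paddle.rand([1, 1000, 4])                    # [N, M, 4]
- #   scores = paddle.rand([1, 80, 1000])                   # [N, C, M]
- #   bbox, bbox_num, _ = nms(bboxes, scores)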
- @register
- @serializable
- class MatrixNMS(object):
- __append_doc__ = True
- def __init__(self,
- score_threshold=.05,
- post_threshold=.05,
- nms_top_k=-1,
- keep_top_k=100,
- use_gaussian=False,
- gaussian_sigma=2.,
- normalized=False,
- background_label=0):
- super(MatrixNMS, self).__init__()
- self.score_threshold = score_threshold
- self.post_threshold = post_threshold
- self.nms_top_k = nms_top_k
- self.keep_top_k = keep_top_k
- self.normalized = normalized
- self.use_gaussian = use_gaussian
- self.gaussian_sigma = gaussian_sigma
- self.background_label = background_label
- def __call__(self, bbox, score, *args):
- return ops.matrix_nms(
- bboxes=bbox,
- scores=score,
- score_threshold=self.score_threshold,
- post_threshold=self.post_threshold,
- nms_top_k=self.nms_top_k,
- keep_top_k=self.keep_top_k,
- use_gaussian=self.use_gaussian,
- gaussian_sigma=self.gaussian_sigma,
- background_label=self.background_label,
- normalized=self.normalized)
- @register
- @serializable
- class YOLOBox(object):
- __shared__ = ['num_classes']
- def __init__(self,
- num_classes=80,
- conf_thresh=0.005,
- downsample_ratio=32,
- clip_bbox=True,
- scale_x_y=1.):
- self.num_classes = num_classes
- self.conf_thresh = conf_thresh
- self.downsample_ratio = downsample_ratio
- self.clip_bbox = clip_bbox
- self.scale_x_y = scale_x_y
- def __call__(self,
- yolo_head_out,
- anchors,
- im_shape,
- scale_factor,
- var_weight=None):
- boxes_list = []
- scores_list = []
- origin_shape = im_shape / scale_factor
- origin_shape = paddle.cast(origin_shape, 'int32')
- for i, head_out in enumerate(yolo_head_out):
- boxes, scores = paddle.vision.ops.yolo_box(
- head_out,
- origin_shape,
- anchors[i],
- self.num_classes,
- self.conf_thresh,
- self.downsample_ratio // 2**i,
- self.clip_bbox,
- scale_x_y=self.scale_x_y)
- boxes_list.append(boxes)
- scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
- yolo_boxes = paddle.concat(boxes_list, axis=1)
- yolo_scores = paddle.concat(scores_list, axis=2)
- return yolo_boxes, yolo_scores
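- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # Each element of yolo_head_out has shape [N, num_anchors * (5 + num_classes), H, W]
- # and anchors[i] lists the anchor sizes of that level; values below are arbitrary.
- #
- #   yolo_box = YOLOBox(num_classes=80)
- #   head_outs = [paddle.rand([1, 3 * 85, 13, 13])]
- #   anchors = [[116, 90, 156, 198, 373, 326]]
- #   im_shape = paddle.to_tensor([[416., 416.]])
- #   scale_factor = paddle.to_tensor([[1., 1.]])
- #   boxes, scores = yolo_box(head_outs, anchors, im_shape, scale_factor)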
- @register
- @serializable
- class SSDBox(object):
- def __init__(self,
- is_normalized=True,
- prior_box_var=[0.1, 0.1, 0.2, 0.2],
- use_fuse_decode=False):
- self.is_normalized = is_normalized
- self.norm_delta = float(not self.is_normalized)
- self.prior_box_var = prior_box_var
- self.use_fuse_decode = use_fuse_decode
- def __call__(self,
- preds,
- prior_boxes,
- im_shape,
- scale_factor,
- var_weight=None):
- boxes, scores = preds
- boxes = paddle.concat(boxes, axis=1)
- prior_boxes = paddle.concat(prior_boxes)
- if self.use_fuse_decode:
- output_boxes = ops.box_coder(
- prior_boxes,
- self.prior_box_var,
- boxes,
- code_type="decode_center_size",
- box_normalized=self.is_normalized)
- else:
- pb_w = prior_boxes[:, 2] - prior_boxes[:, 0] + self.norm_delta
- pb_h = prior_boxes[:, 3] - prior_boxes[:, 1] + self.norm_delta
- pb_x = prior_boxes[:, 0] + pb_w * 0.5
- pb_y = prior_boxes[:, 1] + pb_h * 0.5
- out_x = pb_x + boxes[:, :, 0] * pb_w * self.prior_box_var[0]
- out_y = pb_y + boxes[:, :, 1] * pb_h * self.prior_box_var[1]
- out_w = paddle.exp(boxes[:, :, 2] * self.prior_box_var[2]) * pb_w
- out_h = paddle.exp(boxes[:, :, 3] * self.prior_box_var[3]) * pb_h
- output_boxes = paddle.stack(
- [
- out_x - out_w / 2., out_y - out_h / 2., out_x + out_w / 2.,
- out_y + out_h / 2.
- ],
- axis=-1)
- if self.is_normalized:
- h = (im_shape[:, 0] / scale_factor[:, 0]).unsqueeze(-1)
- w = (im_shape[:, 1] / scale_factor[:, 1]).unsqueeze(-1)
- im_shape = paddle.stack([w, h, w, h], axis=-1)
- output_boxes *= im_shape
- else:
- output_boxes[..., -2:] -= 1.0
- output_scores = F.softmax(paddle.concat(
- scores, axis=1)).transpose([0, 2, 1])
- return output_boxes, output_scores
- @register
- class TTFBox(object):
- __shared__ = ['down_ratio']
- def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4):
- super(TTFBox, self).__init__()
- self.max_per_img = max_per_img
- self.score_thresh = score_thresh
- self.down_ratio = down_ratio
- def _simple_nms(self, heat, kernel=3):
- """
- Use max pooling to keep only the local peaks of the heatmap (suppress non-maximum points).
- """
- pad = (kernel - 1) // 2
- hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
- keep = paddle.cast(hmax == heat, 'float32')
- return heat * keep
- def _topk(self, scores):
- """
- Select top k scores and decode to get xy coordinates.
- """
- k = self.max_per_img
- shape_fm = paddle.shape(scores)
- shape_fm.stop_gradient = True
- cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
- # batch size is 1
- scores_r = paddle.reshape(scores, [cat, -1])
- topk_scores, topk_inds = paddle.topk(scores_r, k)
- topk_ys = topk_inds // width
- topk_xs = topk_inds % width
- topk_score_r = paddle.reshape(topk_scores, [-1])
- topk_score, topk_ind = paddle.topk(topk_score_r, k)
- k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
- topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')
- topk_inds = paddle.reshape(topk_inds, [-1])
- topk_ys = paddle.reshape(topk_ys, [-1, 1])
- topk_xs = paddle.reshape(topk_xs, [-1, 1])
- topk_inds = paddle.gather(topk_inds, topk_ind)
- topk_ys = paddle.gather(topk_ys, topk_ind)
- topk_xs = paddle.gather(topk_xs, topk_ind)
- return topk_score, topk_inds, topk_clses, topk_ys, topk_xs
- def _decode(self, hm, wh, im_shape, scale_factor):
- heatmap = F.sigmoid(hm)
- heat = self._simple_nms(heatmap)
- scores, inds, clses, ys, xs = self._topk(heat)
- ys = paddle.cast(ys, 'float32') * self.down_ratio
- xs = paddle.cast(xs, 'float32') * self.down_ratio
- scores = paddle.tensor.unsqueeze(scores, [1])
- clses = paddle.tensor.unsqueeze(clses, [1])
- wh_t = paddle.transpose(wh, [0, 2, 3, 1])
- wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
- wh = paddle.gather(wh, inds)
- x1 = xs - wh[:, 0:1]
- y1 = ys - wh[:, 1:2]
- x2 = xs + wh[:, 2:3]
- y2 = ys + wh[:, 3:4]
- bboxes = paddle.concat([x1, y1, x2, y2], axis=1)
- scale_y = scale_factor[:, 0:1]
- scale_x = scale_factor[:, 1:2]
- scale_expand = paddle.concat(
- [scale_x, scale_y, scale_x, scale_y], axis=1)
- boxes_shape = paddle.shape(bboxes)
- boxes_shape.stop_gradient = True
- scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
- bboxes = paddle.divide(bboxes, scale_expand)
- results = paddle.concat([clses, scores, bboxes], axis=1)
- # hack: append a result with cls=-1 and score=1. to avoid the case where all
- # scores are below score_thresh, which would cause an error in gather.
- fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]]))
- fill_r = paddle.cast(fill_r, results.dtype)
- results = paddle.concat([results, fill_r])
- scores = results[:, 1]
- valid_ind = paddle.nonzero(scores > self.score_thresh)
- results = paddle.gather(results, valid_ind)
- return results, paddle.shape(results)[0:1]
- def __call__(self, hm, wh, im_shape, scale_factor):
- results = []
- results_num = []
- for i in range(scale_factor.shape[0]):
- result, num = self._decode(hm[i:i + 1, ], wh[i:i + 1, ],
- im_shape[i:i + 1, ],
- scale_factor[i:i + 1, ])
- results.append(result)
- results_num.append(num)
- results = paddle.concat(results, axis=0)
- results_num = paddle.concat(results_num, axis=0)
- return results, results_num
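- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # TTFBox decodes CenterNet-style outputs: hm is the per-class heatmap and the 4
- # channels of wh are the distances from each center to the left/top/right/bottom
- # box edges; values below are arbitrary.
- #
- #   ttf_box = TTFBox(max_per_img=100, score_thresh=0.01, down_ratio=4)
- #   hm = paddle.rand([1, 80, 128, 128])
- #   wh = paddle.rand([1, 4, 128, 128])
- #   im_shape = paddle.to_tensor([[512., 512.]])
- #   scale_factor = paddle.to_tensor([[1., 1.]])
- #   results, results_num = ttf_box(hm, wh, im_shape, scale_factor)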
- @register
- @serializable
- class JDEBox(object):
- __shared__ = ['num_classes']
- def __init__(self, num_classes=1, conf_thresh=0.3, downsample_ratio=32):
- self.num_classes = num_classes
- self.conf_thresh = conf_thresh
- self.downsample_ratio = downsample_ratio
- def generate_anchor(self, nGh, nGw, anchor_wh):
- nA = len(anchor_wh)
- yv, xv = paddle.meshgrid([paddle.arange(nGh), paddle.arange(nGw)])
- mesh = paddle.stack(
- (xv, yv), axis=0).cast(dtype='float32') # 2 x nGh x nGw
- meshs = paddle.tile(mesh, [nA, 1, 1, 1])
- anchor_offset_mesh = anchor_wh[:, :, None][:, :, :, None].repeat(
- int(nGh), axis=-2).repeat(
- int(nGw), axis=-1)
- anchor_offset_mesh = paddle.to_tensor(
- anchor_offset_mesh.astype(np.float32))
- # nA x 2 x nGh x nGw
- anchor_mesh = paddle.concat([meshs, anchor_offset_mesh], axis=1)
- anchor_mesh = paddle.transpose(anchor_mesh,
- [0, 2, 3, 1]) # (nA x nGh x nGw) x 4
- return anchor_mesh
- def decode_delta(self, delta, fg_anchor_list):
- px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:,1], \
- fg_anchor_list[:, 2], fg_anchor_list[:,3]
- dx, dy, dw, dh = delta[:, 0], delta[:, 1], delta[:, 2], delta[:, 3]
- gx = pw * dx + px
- gy = ph * dy + py
- gw = pw * paddle.exp(dw)
- gh = ph * paddle.exp(dh)
- gx1 = gx - gw * 0.5
- gy1 = gy - gh * 0.5
- gx2 = gx + gw * 0.5
- gy2 = gy + gh * 0.5
- return paddle.stack([gx1, gy1, gx2, gy2], axis=1)
- def decode_delta_map(self, nA, nGh, nGw, delta_map, anchor_vec):
- anchor_mesh = self.generate_anchor(nGh, nGw, anchor_vec)
- anchor_mesh = paddle.unsqueeze(anchor_mesh, 0)
- pred_list = self.decode_delta(
- paddle.reshape(
- delta_map, shape=[-1, 4]),
- paddle.reshape(
- anchor_mesh, shape=[-1, 4]))
- pred_map = paddle.reshape(pred_list, shape=[nA * nGh * nGw, 4])
- return pred_map
- def _postprocessing_by_level(self, nA, stride, head_out, anchor_vec):
- boxes_shape = head_out.shape # [nB, nA*6, nGh, nGw]
- nGh, nGw = boxes_shape[-2], boxes_shape[-1]
- nB = 1 # TODO: only support bs=1 now
- boxes_list, scores_list = [], []
- for idx in range(nB):
- p = paddle.reshape(
- head_out[idx], shape=[nA, self.num_classes + 5, nGh, nGw])
- p = paddle.transpose(p, perm=[0, 2, 3, 1]) # [nA, nGh, nGw, 6]
- delta_map = p[:, :, :, :4]
- boxes = self.decode_delta_map(nA, nGh, nGw, delta_map, anchor_vec)
- # [nA * nGh * nGw, 4]
- boxes_list.append(boxes * stride)
- p_conf = paddle.transpose(
- p[:, :, :, 4:6], perm=[3, 0, 1, 2]) # [2, nA, nGh, nGw]
- p_conf = F.softmax(
- p_conf, axis=0)[1, :, :, :].unsqueeze(-1) # [nA, nGh, nGw, 1]
- scores = paddle.reshape(p_conf, shape=[nA * nGh * nGw, 1])
- scores_list.append(scores)
- boxes_results = paddle.stack(boxes_list)
- scores_results = paddle.stack(scores_list)
- return boxes_results, scores_results
- def __call__(self, yolo_head_out, anchors):
- bbox_pred_list = []
- for i, head_out in enumerate(yolo_head_out):
- stride = self.downsample_ratio // 2**i
- anc_w, anc_h = anchors[i][0::2], anchors[i][1::2]
- anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride
- nA = len(anc_w)
- boxes, scores = self._postprocessing_by_level(nA, stride, head_out,
- anchor_vec)
- bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1))
- yolo_boxes_scores = paddle.concat(bbox_pred_list, axis=1)
- boxes_idx_over_conf_thr = paddle.nonzero(
- yolo_boxes_scores[:, :, -1] > self.conf_thresh)
- boxes_idx_over_conf_thr.stop_gradient = True
- return boxes_idx_over_conf_thr, yolo_boxes_scores
- @register
- @serializable
- class MaskMatrixNMS(object):
- """
- Matrix NMS for multi-class masks.
- Args:
- update_threshold (float): Updated threshold of category score after Matrix NMS.
- pre_nms_top_n (int): Number of total instances to be kept per image before NMS.
- post_nms_top_n (int): Number of total instances to be kept per image after NMS.
- kernel (str): 'linear' or 'gaussian'.
- sigma (float): std in gaussian method.
- Input:
- seg_preds (Variable): shape (n, h, w), segmentation feature maps
- seg_masks (Variable): shape (n, h, w), binarized segmentation masks
- cate_labels (Variable): shape (n), mask labels in descending order
- cate_scores (Variable): shape (n), mask scores in descending order
- sum_masks (Variable): shape (n), per-instance pixel sums of seg_masks
- Returns:
- Variable: cate_scores, tensors of shape (n)
- """
- def __init__(self,
- update_threshold=0.05,
- pre_nms_top_n=500,
- post_nms_top_n=100,
- kernel='gaussian',
- sigma=2.0):
- super(MaskMatrixNMS, self).__init__()
- self.update_threshold = update_threshold
- self.pre_nms_top_n = pre_nms_top_n
- self.post_nms_top_n = post_nms_top_n
- self.kernel = kernel
- self.sigma = sigma
- def _sort_score(self, scores, top_num):
- if paddle.shape(scores)[0] > top_num:
- return paddle.topk(scores, top_num)[1]
- else:
- return paddle.argsort(scores, descending=True)
- def __call__(self,
- seg_preds,
- seg_masks,
- cate_labels,
- cate_scores,
- sum_masks=None):
- # sort and keep top nms_pre
- sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
- seg_masks = paddle.gather(seg_masks, index=sort_inds)
- seg_preds = paddle.gather(seg_preds, index=sort_inds)
- sum_masks = paddle.gather(sum_masks, index=sort_inds)
- cate_scores = paddle.gather(cate_scores, index=sort_inds)
- cate_labels = paddle.gather(cate_labels, index=sort_inds)
- seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
- # inter.
- inter_matrix = paddle.mm(seg_masks, paddle.transpose(seg_masks, [1, 0]))
- n_samples = paddle.shape(cate_labels)
- # union.
- sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
- # iou.
- iou_matrix = (inter_matrix / (
- sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix))
- iou_matrix = paddle.triu(iou_matrix, diagonal=1)
- # label_specific matrix.
- cate_labels_x = paddle.expand(cate_labels, shape=[n_samples, n_samples])
- label_matrix = paddle.cast(
- (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])),
- 'float32')
- label_matrix = paddle.triu(label_matrix, diagonal=1)
- # IoU compensation
- compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
- compensate_iou = paddle.expand(
- compensate_iou, shape=[n_samples, n_samples])
- compensate_iou = paddle.transpose(compensate_iou, [1, 0])
- # IoU decay
- decay_iou = iou_matrix * label_matrix
- # matrix nms
- if self.kernel == 'gaussian':
- decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2))
- compensate_matrix = paddle.exp(-1 * self.sigma *
- (compensate_iou**2))
- decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
- axis=0)
- elif self.kernel == 'linear':
- decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
- decay_coefficient = paddle.min(decay_matrix, axis=0)
- else:
- raise NotImplementedError
- # update the score.
- cate_scores = cate_scores * decay_coefficient
- y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32')
- keep = paddle.where(cate_scores >= self.update_threshold, cate_scores,
- y)
- keep = paddle.nonzero(keep)
- keep = paddle.squeeze(keep, axis=[1])
- # Prevent an empty selection by always appending the last index as a fallback
- keep = paddle.concat(
- [keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')])
- seg_preds = paddle.gather(seg_preds, index=keep)
- cate_scores = paddle.gather(cate_scores, index=keep)
- cate_labels = paddle.gather(cate_labels, index=keep)
- # sort and keep top_k
- sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
- seg_preds = paddle.gather(seg_preds, index=sort_inds)
- cate_scores = paddle.gather(cate_scores, index=sort_inds)
- cate_labels = paddle.gather(cate_labels, index=sort_inds)
- return seg_preds, cate_scores, cate_labels
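- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # Inputs follow the class docstring: n candidate masks with scores sorted in
- # descending order, seg_masks being the binarized seg_preds and sum_masks their
- # per-instance pixel sums; values below are arbitrary.
- #
- #   mask_nms = MaskMatrixNMS(kernel='gaussian', sigma=2.0)
- #   seg_preds = paddle.rand([10, 64, 64])
- #   seg_masks = paddle.cast(seg_preds > 0.5, 'float32')
- #   cate_labels = paddle.zeros([10], dtype='int64')
- #   cate_scores = paddle.sort(paddle.rand([10]), descending=True)
- #   sum_masks = seg_masks.sum(axis=[1, 2])
- #   seg_preds, cate_scores, cate_labels = mask_nms(
- #       seg_preds, seg_masks, cate_labels, cate_scores, sum_masks=sum_masks)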
- def Conv2d(in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- bias=True,
- weight_init=Normal(std=0.001),
- bias_init=Constant(0.)):
- weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
- if bias:
- bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
- else:
- bias_attr = False
- conv = nn.Conv2D(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- dilation,
- groups,
- weight_attr=weight_attr,
- bias_attr=bias_attr)
- return conv
- def ConvTranspose2d(in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- output_padding=0,
- groups=1,
- bias=True,
- dilation=1,
- weight_init=Normal(std=0.001),
- bias_init=Constant(0.)):
- weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
- if bias:
- bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
- else:
- bias_attr = False
- conv = nn.Conv2DTranspose(
- in_channels,
- out_channels,
- kernel_size,
- stride,
- padding,
- output_padding=output_padding,
- groups=groups,
- dilation=dilation,
- weight_attr=weight_attr,
- bias_attr=bias_attr)
- return conv
- def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True):
- if not affine:
- weight_attr = False
- bias_attr = False
- else:
- weight_attr = None
- bias_attr = None
- batchnorm = nn.BatchNorm2D(
- num_features,
- momentum,
- eps,
- weight_attr=weight_attr,
- bias_attr=bias_attr)
- return batchnorm
- def ReLU():
- return nn.ReLU()
- def Upsample(scale_factor=None, mode='nearest', align_corners=False):
- return nn.Upsample(None, scale_factor, mode, align_corners)
- def MaxPool(kernel_size, stride, padding, ceil_mode=False):
- return nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode)
- class Concat(nn.Layer):
- def __init__(self, dim=0):
- super(Concat, self).__init__()
- self.dim = dim
- def forward(self, inputs):
- return paddle.concat(inputs, axis=self.dim)
- def extra_repr(self):
- return 'dim={}'.format(self.dim)
- def _convert_attention_mask(attn_mask, dtype):
- """
- Convert the attention mask to the target dtype we expect.
- Parameters:
- attn_mask (Tensor, optional): A tensor used in multi-head attention
- to prevent attention to some unwanted positions, usually the
- paddings or the subsequent positions. It is a tensor with shape
- broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
- When the data type is bool, the unwanted positions have `False`
- values and the others have `True` values. When the data type is
- int, the unwanted positions have 0 values and the others have 1
- values. When the data type is float, the unwanted positions have
- `-INF` values and the others have 0 values. It can be None when no
- attention needs to be prevented. Default None.
- dtype (VarType): The target type of `attn_mask` we expect.
- Returns:
- Tensor: A Tensor with shape same as input `attn_mask`, with data type `dtype`.
- """
- return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)
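- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # A boolean mask is converted so that masked (False) positions become a large
- # negative value that can simply be added to the attention logits.
- #
- #   mask = paddle.to_tensor([[True, True, False]])
- #   float_mask = _convert_attention_mask(mask, paddle.float32)
- #   # float_mask is roughly [[0., 0., -1e9]], broadcastable to
- #   # [batch_size, n_head, sequence_length, sequence_length]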
- class MultiHeadAttention(nn.Layer):
- """
- Attention maps queries and a set of key-value pairs to outputs, and
- Multi-Head Attention runs several attention heads in parallel to jointly attend
- to information from different representation subspaces.
- Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
- for more details.
- Parameters:
- embed_dim (int): The expected feature size in the input and output.
- num_heads (int): The number of heads in multi-head attention.
- dropout (float, optional): The dropout probability used on attention
- weights to drop some attention targets. 0 for no dropout. Default 0
- kdim (int, optional): The feature size in key. If None, assumed equal to
- `embed_dim`. Default None.
- vdim (int, optional): The feature size in value. If None, assumed equal to
- `embed_dim`. Default None.
- need_weights (bool, optional): Indicate whether to return the attention
- weights. Default False.
- Examples:
- .. code-block:: python
- import paddle
- # encoder input: [batch_size, sequence_length, d_model]
- query = paddle.rand((2, 4, 128))
- # self attention mask: [batch_size, num_heads, query_len, query_len]
- attn_mask = paddle.rand((2, 2, 4, 4))
- multi_head_attn = paddle.nn.MultiHeadAttention(128, 2)
- output = multi_head_attn(query, None, None, attn_mask=attn_mask) # [2, 4, 128]
- """
- def __init__(self,
- embed_dim,
- num_heads,
- dropout=0.,
- kdim=None,
- vdim=None,
- need_weights=False):
- super(MultiHeadAttention, self).__init__()
- self.embed_dim = embed_dim
- self.kdim = kdim if kdim is not None else embed_dim
- self.vdim = vdim if vdim is not None else embed_dim
- self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
- self.num_heads = num_heads
- self.dropout = dropout
- self.need_weights = need_weights
- self.head_dim = embed_dim // num_heads
- assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
- if self._qkv_same_embed_dim:
- self.in_proj_weight = self.create_parameter(
- shape=[embed_dim, 3 * embed_dim],
- attr=None,
- dtype=self._dtype,
- is_bias=False)
- self.in_proj_bias = self.create_parameter(
- shape=[3 * embed_dim],
- attr=None,
- dtype=self._dtype,
- is_bias=True)
- else:
- self.q_proj = nn.Linear(embed_dim, embed_dim)
- self.k_proj = nn.Linear(self.kdim, embed_dim)
- self.v_proj = nn.Linear(self.vdim, embed_dim)
- self.out_proj = nn.Linear(embed_dim, embed_dim)
- self._type_list = ('q_proj', 'k_proj', 'v_proj')
- self._reset_parameters()
- def _reset_parameters(self):
- for p in self.parameters():
- if p.dim() > 1:
- xavier_uniform_(p)
- else:
- constant_(p)
- def compute_qkv(self, tensor, index):
- if self._qkv_same_embed_dim:
- tensor = F.linear(
- x=tensor,
- weight=self.in_proj_weight[:, index * self.embed_dim:(index + 1)
- * self.embed_dim],
- bias=self.in_proj_bias[index * self.embed_dim:(index + 1) *
- self.embed_dim]
- if self.in_proj_bias is not None else None)
- else:
- tensor = getattr(self, self._type_list[index])(tensor)
- tensor = tensor.reshape(
- [0, 0, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
- return tensor
- def forward(self, query, key=None, value=None, attn_mask=None):
- r"""
- Applies multi-head attention to map queries and a set of key-value pairs
- to outputs.
- Parameters:
- query (Tensor): The queries for multi-head attention. It is a
- tensor with shape `[batch_size, query_length, embed_dim]`. The
- data type should be float32 or float64.
- key (Tensor, optional): The keys for multi-head attention. It is
- a tensor with shape `[batch_size, key_length, kdim]`. The
- data type should be float32 or float64. If None, use `query` as
- `key`. Default None.
- value (Tensor, optional): The values for multi-head attention. It
- is a tensor with shape `[batch_size, value_length, vdim]`.
- The data type should be float32 or float64. If None, use `query` as
- `value`. Default None.
- attn_mask (Tensor, optional): A tensor used in multi-head attention
- to prevent attention to some unwanted positions, usually the
- paddings or the subsequent positions. It is a tensor with shape
- broadcasted to `[batch_size, n_head, sequence_length, sequence_length]`.
- When the data type is bool, the unwanted positions have `False`
- values and the others have `True` values. When the data type is
- int, the unwanted positions have 0 values and the others have 1
- values. When the data type is float, the unwanted positions have
- `-INF` values and the others have 0 values. It can be None when no
- attention needs to be prevented. Default None.
- Returns:
- Tensor|tuple: A tensor with the same shape and data type as `query`, \
- representing the attention output. If `need_weights` is True, a tuple \
- is returned instead, which additionally contains the attention weights \
- tensor of shape `[batch_size, num_heads, query_length, key_length]`.
- """
- key = query if key is None else key
- value = query if value is None else value
- # compute q, k, v
- q, k, v = (self.compute_qkv(t, i)
- for i, t in enumerate([query, key, value]))
- # scaled dot-product attention
- product = paddle.matmul(x=q, y=k, transpose_y=True)
- scaling = float(self.head_dim)**-0.5
- product = product * scaling
- if attn_mask is not None:
- # Support bool or int mask
- attn_mask = _convert_attention_mask(attn_mask, product.dtype)
- product = product + attn_mask
- weights = F.softmax(product)
- if self.dropout:
- weights = F.dropout(
- weights,
- self.dropout,
- training=self.training,
- mode="upscale_in_train")
- out = paddle.matmul(weights, v)
- # combine heads
- out = paddle.transpose(out, perm=[0, 2, 1, 3])
- out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
- # project to output
- out = self.out_proj(out)
- outs = [out]
- if self.need_weights:
- outs.append(weights)
- return out if len(outs) == 1 else tuple(outs)
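- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # Complementing the self-attention example in the class docstring, kdim/vdim
- # allow keys and values whose feature size differs from the queries; values
- # below are arbitrary.
- #
- #   cross_attn = MultiHeadAttention(embed_dim=128, num_heads=4, kdim=64, vdim=64)
- #   q = paddle.rand([2, 10, 128])
- #   kv = paddle.rand([2, 20, 64])
- #   out = cross_attn(q, kv, kv)                           # [2, 10, 128]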
- @register
- class ConvMixer(nn.Layer):
- def __init__(
- self,
- dim,
- depth,
- kernel_size=3, ):
- super().__init__()
- self.dim = dim
- self.depth = depth
- self.kernel_size = kernel_size
- self.mixer = self.conv_mixer(dim, depth, kernel_size)
- def forward(self, x):
- return self.mixer(x)
- @staticmethod
- def conv_mixer(
- dim,
- depth,
- kernel_size, ):
- Seq, ActBn = nn.Sequential, lambda x: Seq(x, nn.GELU(), nn.BatchNorm2D(dim))
- Residual = type('Residual', (Seq, ),
- {'forward': lambda self, x: self[0](x) + x})
- return Seq(*[
- Seq(Residual(
- ActBn(
- nn.Conv2D(
- dim, dim, kernel_size, groups=dim, padding="same"))),
- ActBn(nn.Conv2D(dim, dim, 1))) for i in range(depth)
- ])
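- # --- Illustrative usage (editor's sketch, not part of the original module) ---
- # ConvMixer stacks `depth` blocks of a residual depthwise conv followed by a
- # pointwise conv, each wrapped in GELU + BatchNorm; "same" padding keeps the
- # spatial size. Values below are arbitrary.
- #
- #   mixer = ConvMixer(dim=64, depth=2, kernel_size=3)
- #   y = mixer(paddle.rand([1, 64, 32, 32]))               # [1, 64, 32, 32]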