# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
import six
import numpy as np
from numbers import Integral

import paddle
import paddle.nn as nn
from paddle import ParamAttr
from paddle import to_tensor
import paddle.nn.functional as F
from paddle.nn.initializer import Normal, Constant, XavierUniform
from paddle.regularizer import L2Decay
from paddle.vision.ops import DeformConv2D

from ppdet.core.workspace import register, serializable
from ppdet.modeling.bbox_utils import delta2bbox
from . import ops
from .initializer import xavier_uniform_, constant_


def _to_list(l):
    if isinstance(l, (list, tuple)):
        return list(l)
    return [l]


class AlignConv(nn.Layer):
    def __init__(self, in_channels, out_channels, kernel_size=3, groups=1):
        super(AlignConv, self).__init__()
        self.kernel_size = kernel_size
        self.align_conv = paddle.vision.ops.DeformConv2D(
            in_channels,
            out_channels,
            kernel_size=self.kernel_size,
            padding=(self.kernel_size - 1) // 2,
            groups=groups,
            weight_attr=ParamAttr(initializer=Normal(0, 0.01)),
            bias_attr=None)

    @paddle.no_grad()
    def get_offset(self, anchors, featmap_size, stride):
        """
        Args:
            anchors (Tensor): [B, L, 5] anchors as xc, yc, w, h, angle
            featmap_size (tuple): (feat_h, feat_w)
            stride (int): feature map stride, e.g. 8
        Returns:
            offset (Tensor): [B, 2 * kernel_size**2, feat_h, feat_w]
        """
        batch = anchors.shape[0]
        dtype = anchors.dtype
        feat_h, feat_w = featmap_size
        pad = (self.kernel_size - 1) // 2
        idx = paddle.arange(-pad, pad + 1, dtype=dtype)
        yy, xx = paddle.meshgrid(idx, idx)
        xx = paddle.reshape(xx, [-1])
        yy = paddle.reshape(yy, [-1])

        # get sampling locations of default conv
        xc = paddle.arange(0, feat_w, dtype=dtype)
        yc = paddle.arange(0, feat_h, dtype=dtype)
        yc, xc = paddle.meshgrid(yc, xc)
        xc = paddle.reshape(xc, [-1, 1])
        yc = paddle.reshape(yc, [-1, 1])
        x_conv = xc + xx
        y_conv = yc + yy

        # get sampling locations of anchors
        x_ctr, y_ctr, w, h, a = paddle.split(anchors, 5, axis=-1)
        x_ctr = x_ctr / stride
        y_ctr = y_ctr / stride
        w_s = w / stride
        h_s = h / stride
        cos, sin = paddle.cos(a), paddle.sin(a)
        dw, dh = w_s / self.kernel_size, h_s / self.kernel_size
        x, y = dw * xx, dh * yy
        xr = cos * x - sin * y
        yr = sin * x + cos * y
        x_anchor, y_anchor = xr + x_ctr, yr + y_ctr

        # get offset field
        offset_x = x_anchor - x_conv
        offset_y = y_anchor - y_conv
        offset = paddle.stack([offset_y, offset_x], axis=-1)
        offset = offset.reshape(
            [batch, feat_h, feat_w, self.kernel_size * self.kernel_size * 2])
        offset = offset.transpose([0, 3, 1, 2])
        return offset

    def forward(self, x, refine_anchors, featmap_size, stride):
        batch = paddle.shape(x)[0].numpy()
        offset = self.get_offset(refine_anchors, featmap_size, stride)
        if self.training:
            x = F.relu(self.align_conv(x, offset.detach()))
        else:
            x = F.relu(self.align_conv(x, offset))
        return x
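

# NOTE: the demo below is an illustrative usage sketch, not part of the
# original file. All shapes are assumptions; `anchors` must provide one
# rotated box (xc, yc, w, h, angle) per feature-map location.
def _align_conv_demo():
    align = AlignConv(in_channels=256, out_channels=256, kernel_size=3)
    feat = paddle.rand([2, 256, 32, 32])       # [B, C, H, W]
    anchors = paddle.rand([2, 32 * 32, 5])     # [B, H*W, 5]
    # features are re-sampled at the rotated-anchor locations
    return align(feat, anchors, featmap_size=(32, 32), stride=8)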


class DeformableConvV2(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 weight_attr=None,
                 bias_attr=None,
                 lr_scale=1,
                 regularizer=None,
                 skip_quant=False,
                 dcn_bias_regularizer=L2Decay(0.),
                 dcn_bias_lr_scale=2.):
        super(DeformableConvV2, self).__init__()
        self.offset_channel = 2 * kernel_size**2
        self.mask_channel = kernel_size**2

        if lr_scale == 1 and regularizer is None:
            offset_bias_attr = ParamAttr(initializer=Constant(0.))
        else:
            offset_bias_attr = ParamAttr(
                initializer=Constant(0.),
                learning_rate=lr_scale,
                regularizer=regularizer)
        self.conv_offset = nn.Conv2D(
            in_channels,
            3 * kernel_size**2,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2,
            weight_attr=ParamAttr(initializer=Constant(0.0)),
            bias_attr=offset_bias_attr)
        if skip_quant:
            self.conv_offset.skip_quant = True

        if bias_attr:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            dcn_bias_attr = ParamAttr(
                initializer=Constant(value=0),
                regularizer=dcn_bias_regularizer,
                learning_rate=dcn_bias_lr_scale)
        else:
            # in ResNet backbone, do not need bias
            dcn_bias_attr = False
        self.conv_dcn = DeformConv2D(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=(kernel_size - 1) // 2 * dilation,
            dilation=dilation,
            groups=groups,
            weight_attr=weight_attr,
            bias_attr=dcn_bias_attr)

    def forward(self, x):
        offset_mask = self.conv_offset(x)
        offset, mask = paddle.split(
            offset_mask,
            num_or_sections=[self.offset_channel, self.mask_channel],
            axis=1)
        mask = F.sigmoid(mask)
        y = self.conv_dcn(x, offset, mask=mask)
        return y
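

# NOTE: illustrative usage sketch, not part of the original file; shapes are
# assumptions. Since conv_offset is zero-initialized, the first forward pass
# behaves like a plain 3x3 convolution with a constant 0.5 modulation mask.
def _deformable_conv_v2_demo():
    dcn = DeformableConvV2(64, 128, kernel_size=3, bias_attr=True)
    x = paddle.rand([2, 64, 32, 32])           # [B, C, H, W]
    return dcn(x)                              # [2, 128, 32, 32]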


class ConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride,
                 groups=1,
                 norm_type='bn',
                 norm_decay=0.,
                 norm_groups=32,
                 use_dcn=False,
                 bias_on=False,
                 lr_scale=1.,
                 freeze_norm=False,
                 initializer=Normal(
                     mean=0., std=0.01),
                 skip_quant=False,
                 dcn_lr_scale=2.,
                 dcn_regularizer=L2Decay(0.)):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn', None]

        if bias_on:
            bias_attr = ParamAttr(
                initializer=Constant(value=0.), learning_rate=lr_scale)
        else:
            bias_attr = False

        if not use_dcn:
            self.conv = nn.Conv2D(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=bias_attr)
            if skip_quant:
                self.conv.skip_quant = True
        else:
            # in FCOS-DCN head, specifically need learning_rate and regularizer
            self.conv = DeformableConvV2(
                in_channels=ch_in,
                out_channels=ch_out,
                kernel_size=filter_size,
                stride=stride,
                padding=(filter_size - 1) // 2,
                groups=groups,
                weight_attr=ParamAttr(
                    initializer=initializer, learning_rate=1.),
                bias_attr=True,
                lr_scale=dcn_lr_scale,
                regularizer=dcn_regularizer,
                dcn_bias_regularizer=dcn_regularizer,
                dcn_bias_lr_scale=dcn_lr_scale,
                skip_quant=skip_quant)

        norm_lr = 0. if freeze_norm else 1.
        param_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
        bias_attr = ParamAttr(
            learning_rate=norm_lr,
            regularizer=L2Decay(norm_decay) if norm_decay is not None else None)
        if norm_type in ['bn', 'sync_bn']:
            self.norm = nn.BatchNorm2D(
                ch_out, weight_attr=param_attr, bias_attr=bias_attr)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=norm_groups,
                num_channels=ch_out,
                weight_attr=param_attr,
                bias_attr=bias_attr)
        else:
            self.norm = None

    def forward(self, inputs):
        out = self.conv(inputs)
        if self.norm is not None:
            out = self.norm(out)
        return out
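

# NOTE: illustrative usage sketch, not part of the original file; shapes are
# assumptions. ConvNormLayer pairs a conv (optionally DCNv2) with its norm.
def _conv_norm_layer_demo():
    conv_bn = ConvNormLayer(
        ch_in=32, ch_out=64, filter_size=3, stride=1, norm_type='bn')
    x = paddle.rand([2, 32, 16, 16])
    return conv_bn(x)                          # [2, 64, 16, 16]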


class LiteConv(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 with_act=True,
                 norm_type='sync_bn',
                 name=None):
        super(LiteConv, self).__init__()
        self.lite_conv = nn.Sequential()
        conv1 = ConvNormLayer(
            in_channels,
            in_channels,
            filter_size=5,
            stride=stride,
            groups=in_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv2 = ConvNormLayer(
            in_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv3 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=1,
            stride=stride,
            norm_type=norm_type,
            initializer=XavierUniform())
        conv4 = ConvNormLayer(
            out_channels,
            out_channels,
            filter_size=5,
            stride=stride,
            groups=out_channels,
            norm_type=norm_type,
            initializer=XavierUniform())
        self.lite_conv.add_sublayer('conv1', conv1)
        self.lite_conv.add_sublayer('relu6_1', nn.ReLU6())
        self.lite_conv.add_sublayer('conv2', conv2)
        if with_act:
            self.lite_conv.add_sublayer('relu6_2', nn.ReLU6())
        self.lite_conv.add_sublayer('conv3', conv3)
        self.lite_conv.add_sublayer('relu6_3', nn.ReLU6())
        self.lite_conv.add_sublayer('conv4', conv4)
        if with_act:
            self.lite_conv.add_sublayer('relu6_4', nn.ReLU6())

    def forward(self, inputs):
        out = self.lite_conv(inputs)
        return out


class DropBlock(nn.Layer):
    def __init__(self, block_size, keep_prob, name=None, data_format='NCHW'):
        """
        DropBlock layer, see https://arxiv.org/abs/1810.12890

        Args:
            block_size (int): block size
            keep_prob (float): keep probability
            name (str): layer name
            data_format (str): data format, NCHW or NHWC
        """
        super(DropBlock, self).__init__()
        self.block_size = block_size
        self.keep_prob = keep_prob
        self.name = name
        self.data_format = data_format

    def forward(self, x):
        if not self.training or self.keep_prob == 1:
            return x
        else:
            gamma = (1. - self.keep_prob) / (self.block_size**2)
            if self.data_format == 'NCHW':
                shape = x.shape[2:]
            else:
                shape = x.shape[1:3]
            for s in shape:
                gamma *= s / (s - self.block_size + 1)

            matrix = paddle.cast(paddle.rand(x.shape) < gamma, x.dtype)
            mask_inv = F.max_pool2d(
                matrix,
                self.block_size,
                stride=1,
                padding=self.block_size // 2,
                data_format=self.data_format)
            mask = 1. - mask_inv
            y = x * mask * (mask.numel() / mask.sum())
            return y
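

# NOTE: illustrative usage sketch, not part of the original file. DropBlock
# only masks in training mode; at inference it is the identity mapping.
def _drop_block_demo():
    drop = DropBlock(block_size=3, keep_prob=0.9)
    drop.train()                               # enable masking
    x = paddle.rand([2, 8, 16, 16])
    return drop(x)                             # same shape, blocks zeroed and rescaled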


@register
@serializable
class AnchorGeneratorSSD(object):
    def __init__(self,
                 steps=[8, 16, 32, 64, 100, 300],
                 aspect_ratios=[[2.], [2., 3.], [2., 3.], [2., 3.], [2.], [2.]],
                 min_ratio=15,
                 max_ratio=90,
                 base_size=300,
                 min_sizes=[30.0, 60.0, 111.0, 162.0, 213.0, 264.0],
                 max_sizes=[60.0, 111.0, 162.0, 213.0, 264.0, 315.0],
                 offset=0.5,
                 flip=True,
                 clip=False,
                 min_max_aspect_ratios_order=False):
        self.steps = steps
        self.aspect_ratios = aspect_ratios
        self.min_ratio = min_ratio
        self.max_ratio = max_ratio
        self.base_size = base_size
        self.min_sizes = min_sizes
        self.max_sizes = max_sizes
        self.offset = offset
        self.flip = flip
        self.clip = clip
        self.min_max_aspect_ratios_order = min_max_aspect_ratios_order

        if self.min_sizes == [] and self.max_sizes == []:
            num_layer = len(aspect_ratios)
            step = int(
                math.floor((self.max_ratio - self.min_ratio) / (num_layer - 2)))
            for ratio in six.moves.range(self.min_ratio, self.max_ratio + 1,
                                         step):
                self.min_sizes.append(self.base_size * ratio / 100.)
                self.max_sizes.append(self.base_size * (ratio + step) / 100.)
            self.min_sizes = [self.base_size * .10] + self.min_sizes
            self.max_sizes = [self.base_size * .20] + self.max_sizes

        self.num_priors = []
        for aspect_ratio, min_size, max_size in zip(
                aspect_ratios, self.min_sizes, self.max_sizes):
            if isinstance(min_size, (list, tuple)):
                self.num_priors.append(
                    len(_to_list(min_size)) + len(_to_list(max_size)))
            else:
                self.num_priors.append((len(aspect_ratio) * 2 + 1) * len(
                    _to_list(min_size)) + len(_to_list(max_size)))

    def __call__(self, inputs, image):
        boxes = []
        for input, min_size, max_size, aspect_ratio, step in zip(
                inputs, self.min_sizes, self.max_sizes, self.aspect_ratios,
                self.steps):
            box, _ = ops.prior_box(
                input=input,
                image=image,
                min_sizes=_to_list(min_size),
                max_sizes=_to_list(max_size),
                aspect_ratios=aspect_ratio,
                flip=self.flip,
                clip=self.clip,
                steps=[step, step],
                offset=self.offset,
                min_max_aspect_ratios_order=self.min_max_aspect_ratios_order)
            boxes.append(paddle.reshape(box, [-1, 4]))
        return boxes


@register
@serializable
class RCNNBox(object):
    __shared__ = ['num_classes', 'export_onnx']

    def __init__(self,
                 prior_box_var=[10., 10., 5., 5.],
                 code_type="decode_center_size",
                 box_normalized=False,
                 num_classes=80,
                 export_onnx=False):
        super(RCNNBox, self).__init__()
        self.prior_box_var = prior_box_var
        self.code_type = code_type
        self.box_normalized = box_normalized
        self.num_classes = num_classes
        self.export_onnx = export_onnx

    def __call__(self, bbox_head_out, rois, im_shape, scale_factor):
        bbox_pred = bbox_head_out[0]
        cls_prob = bbox_head_out[1]
        roi = rois[0]
        rois_num = rois[1]

        if self.export_onnx:
            onnx_rois_num_per_im = rois_num[0]
            origin_shape = paddle.expand(im_shape[0, :],
                                         [onnx_rois_num_per_im, 2])
        else:
            origin_shape_list = []
            if isinstance(roi, list):
                batch_size = len(roi)
            else:
                batch_size = paddle.slice(paddle.shape(im_shape), [0], [0], [1])

            # bbox_pred.shape: [N, C*4]
            for idx in range(batch_size):
                rois_num_per_im = rois_num[idx]
                expand_im_shape = paddle.expand(im_shape[idx, :],
                                                [rois_num_per_im, 2])
                origin_shape_list.append(expand_im_shape)

            origin_shape = paddle.concat(origin_shape_list)

        # bbox_pred.shape: [N, C*4]
        # C=num_classes in faster/mask rcnn(bbox_head), C=1 in cascade rcnn(cascade_head)
        bbox = paddle.concat(roi)
        bbox = delta2bbox(bbox_pred, bbox, self.prior_box_var)
        scores = cls_prob[:, :-1]

        # bbox.shape: [N, C, 4]
        # bbox.shape[1] must be equal to scores.shape[1]
        total_num = bbox.shape[0]
        bbox_dim = bbox.shape[-1]
        bbox = paddle.expand(bbox, [total_num, self.num_classes, bbox_dim])

        origin_h = paddle.unsqueeze(origin_shape[:, 0], axis=1)
        origin_w = paddle.unsqueeze(origin_shape[:, 1], axis=1)
        zeros = paddle.zeros_like(origin_h)
        x1 = paddle.maximum(paddle.minimum(bbox[:, :, 0], origin_w), zeros)
        y1 = paddle.maximum(paddle.minimum(bbox[:, :, 1], origin_h), zeros)
        x2 = paddle.maximum(paddle.minimum(bbox[:, :, 2], origin_w), zeros)
        y2 = paddle.maximum(paddle.minimum(bbox[:, :, 3], origin_h), zeros)
        bbox = paddle.stack([x1, y1, x2, y2], axis=-1)
        bboxes = (bbox, rois_num)
        return bboxes, scores


@register
@serializable
class MultiClassNMS(object):
    def __init__(self,
                 score_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 nms_threshold=.5,
                 normalized=True,
                 nms_eta=1.0,
                 return_index=False,
                 return_rois_num=True,
                 trt=False):
        super(MultiClassNMS, self).__init__()
        self.score_threshold = score_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.nms_threshold = nms_threshold
        self.normalized = normalized
        self.nms_eta = nms_eta
        self.return_index = return_index
        self.return_rois_num = return_rois_num
        self.trt = trt

    def __call__(self, bboxes, score, background_label=-1):
        """
        Args:
            bboxes (Tensor|List[Tensor]):
                1. (Tensor) Predicted bboxes with shape [N, M, 4], where N is
                   the batch size and M is the number of bboxes.
                2. (List[Tensor]) bboxes and bbox_num, where bboxes have shape
                   [M, C, 4] (C is the class number) and bbox_num has shape
                   [N,] and holds the number of bboxes for each image.
            score (Tensor): Predicted scores with shape [N, C, M] or [M, C].
            background_label (int): Background label index to ignore;
                e.g. num_classes for RCNN and -1 for YOLO.
        """
        kwargs = self.__dict__.copy()
        if isinstance(bboxes, tuple):
            bboxes, bbox_num = bboxes
            kwargs.update({'rois_num': bbox_num})
        if background_label > -1:
            kwargs.update({'background_label': background_label})
        kwargs.pop('trt')
        # TODO(wangxinxin08): paddle version should be develop or 2.3 and above to run nms on tensorrt
        if self.trt and (int(paddle.version.major) == 0 or
                         (int(paddle.version.major) >= 2 and
                          int(paddle.version.minor) >= 3)):
            # TODO(wangxinxin08): tricky switch to run nms on tensorrt
            kwargs.update({'nms_eta': 1.1})
            bbox, bbox_num, _ = ops.multiclass_nms(bboxes, score, **kwargs)
            bbox = bbox.reshape([1, -1, 6])
            idx = paddle.nonzero(bbox[..., 0] != -1)
            bbox = paddle.gather_nd(bbox, idx)
            return bbox, bbox_num, None
        else:
            return ops.multiclass_nms(bboxes, score, **kwargs)


@register
@serializable
class MatrixNMS(object):
    __append_doc__ = True

    def __init__(self,
                 score_threshold=.05,
                 post_threshold=.05,
                 nms_top_k=-1,
                 keep_top_k=100,
                 use_gaussian=False,
                 gaussian_sigma=2.,
                 normalized=False,
                 background_label=0):
        super(MatrixNMS, self).__init__()
        self.score_threshold = score_threshold
        self.post_threshold = post_threshold
        self.nms_top_k = nms_top_k
        self.keep_top_k = keep_top_k
        self.normalized = normalized
        self.use_gaussian = use_gaussian
        self.gaussian_sigma = gaussian_sigma
        self.background_label = background_label

    def __call__(self, bbox, score, *args):
        return ops.matrix_nms(
            bboxes=bbox,
            scores=score,
            score_threshold=self.score_threshold,
            post_threshold=self.post_threshold,
            nms_top_k=self.nms_top_k,
            keep_top_k=self.keep_top_k,
            use_gaussian=self.use_gaussian,
            gaussian_sigma=self.gaussian_sigma,
            background_label=self.background_label,
            normalized=self.normalized)


@register
@serializable
class YOLOBox(object):
    __shared__ = ['num_classes']

    def __init__(self,
                 num_classes=80,
                 conf_thresh=0.005,
                 downsample_ratio=32,
                 clip_bbox=True,
                 scale_x_y=1.):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio
        self.clip_bbox = clip_bbox
        self.scale_x_y = scale_x_y

    def __call__(self,
                 yolo_head_out,
                 anchors,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes_list = []
        scores_list = []
        origin_shape = im_shape / scale_factor
        origin_shape = paddle.cast(origin_shape, 'int32')
        for i, head_out in enumerate(yolo_head_out):
            boxes, scores = paddle.vision.ops.yolo_box(
                head_out,
                origin_shape,
                anchors[i],
                self.num_classes,
                self.conf_thresh,
                self.downsample_ratio // 2**i,
                self.clip_bbox,
                scale_x_y=self.scale_x_y)
            boxes_list.append(boxes)
            scores_list.append(paddle.transpose(scores, perm=[0, 2, 1]))
        yolo_boxes = paddle.concat(boxes_list, axis=1)
        yolo_scores = paddle.concat(scores_list, axis=2)
        return yolo_boxes, yolo_scores
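

# NOTE: illustrative usage sketch, not part of the original file; anchor
# values and shapes are assumptions. Each head output carries
# nA * (5 + num_classes) channels per location (3 * 85 for COCO).
def _yolo_box_demo():
    decoder = YOLOBox(num_classes=80)
    head_outs = [
        paddle.rand([1, 3 * 85, 13, 13]),      # stride-32 level
        paddle.rand([1, 3 * 85, 26, 26]),      # stride-16 level
    ]
    anchors = [[116, 90, 156, 198, 373, 326], [30, 61, 62, 45, 59, 119]]
    im_shape = paddle.to_tensor([[416., 416.]])
    scale_factor = paddle.to_tensor([[1., 1.]])
    return decoder(head_outs, anchors, im_shape, scale_factor)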


@register
@serializable
class SSDBox(object):
    def __init__(self,
                 is_normalized=True,
                 prior_box_var=[0.1, 0.1, 0.2, 0.2],
                 use_fuse_decode=False):
        self.is_normalized = is_normalized
        self.norm_delta = float(not self.is_normalized)
        self.prior_box_var = prior_box_var
        self.use_fuse_decode = use_fuse_decode

    def __call__(self,
                 preds,
                 prior_boxes,
                 im_shape,
                 scale_factor,
                 var_weight=None):
        boxes, scores = preds
        boxes = paddle.concat(boxes, axis=1)
        prior_boxes = paddle.concat(prior_boxes)
        if self.use_fuse_decode:
            output_boxes = ops.box_coder(
                prior_boxes,
                self.prior_box_var,
                boxes,
                code_type="decode_center_size",
                box_normalized=self.is_normalized)
        else:
            pb_w = prior_boxes[:, 2] - prior_boxes[:, 0] + self.norm_delta
            pb_h = prior_boxes[:, 3] - prior_boxes[:, 1] + self.norm_delta
            pb_x = prior_boxes[:, 0] + pb_w * 0.5
            pb_y = prior_boxes[:, 1] + pb_h * 0.5
            out_x = pb_x + boxes[:, :, 0] * pb_w * self.prior_box_var[0]
            out_y = pb_y + boxes[:, :, 1] * pb_h * self.prior_box_var[1]
            out_w = paddle.exp(boxes[:, :, 2] * self.prior_box_var[2]) * pb_w
            out_h = paddle.exp(boxes[:, :, 3] * self.prior_box_var[3]) * pb_h
            output_boxes = paddle.stack(
                [
                    out_x - out_w / 2., out_y - out_h / 2., out_x + out_w / 2.,
                    out_y + out_h / 2.
                ],
                axis=-1)

        if self.is_normalized:
            h = (im_shape[:, 0] / scale_factor[:, 0]).unsqueeze(-1)
            w = (im_shape[:, 1] / scale_factor[:, 1]).unsqueeze(-1)
            im_shape = paddle.stack([w, h, w, h], axis=-1)
            output_boxes *= im_shape
        else:
            output_boxes[..., -2:] -= 1.0
        output_scores = F.softmax(paddle.concat(
            scores, axis=1)).transpose([0, 2, 1])
        return output_boxes, output_scores


@register
class TTFBox(object):
    __shared__ = ['down_ratio']

    def __init__(self, max_per_img=100, score_thresh=0.01, down_ratio=4):
        super(TTFBox, self).__init__()
        self.max_per_img = max_per_img
        self.score_thresh = score_thresh
        self.down_ratio = down_ratio

    def _simple_nms(self, heat, kernel=3):
        """Use max pooling to keep only local peak scores."""
        pad = (kernel - 1) // 2
        hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
        keep = paddle.cast(hmax == heat, 'float32')
        return heat * keep

    def _topk(self, scores):
        """Select the top-k scores and decode their xy coordinates."""
        k = self.max_per_img
        shape_fm = paddle.shape(scores)
        shape_fm.stop_gradient = True
        cat, height, width = shape_fm[1], shape_fm[2], shape_fm[3]
        # batch size is 1
        scores_r = paddle.reshape(scores, [cat, -1])
        topk_scores, topk_inds = paddle.topk(scores_r, k)
        topk_ys = topk_inds // width
        topk_xs = topk_inds % width

        topk_score_r = paddle.reshape(topk_scores, [-1])
        topk_score, topk_ind = paddle.topk(topk_score_r, k)
        k_t = paddle.full(paddle.shape(topk_ind), k, dtype='int64')
        topk_clses = paddle.cast(paddle.floor_divide(topk_ind, k_t), 'float32')

        topk_inds = paddle.reshape(topk_inds, [-1])
        topk_ys = paddle.reshape(topk_ys, [-1, 1])
        topk_xs = paddle.reshape(topk_xs, [-1, 1])
        topk_inds = paddle.gather(topk_inds, topk_ind)
        topk_ys = paddle.gather(topk_ys, topk_ind)
        topk_xs = paddle.gather(topk_xs, topk_ind)
        return topk_score, topk_inds, topk_clses, topk_ys, topk_xs

    def _decode(self, hm, wh, im_shape, scale_factor):
        heatmap = F.sigmoid(hm)
        heat = self._simple_nms(heatmap)
        scores, inds, clses, ys, xs = self._topk(heat)
        ys = paddle.cast(ys, 'float32') * self.down_ratio
        xs = paddle.cast(xs, 'float32') * self.down_ratio
        scores = paddle.tensor.unsqueeze(scores, [1])
        clses = paddle.tensor.unsqueeze(clses, [1])

        wh_t = paddle.transpose(wh, [0, 2, 3, 1])
        wh = paddle.reshape(wh_t, [-1, paddle.shape(wh_t)[-1]])
        wh = paddle.gather(wh, inds)

        x1 = xs - wh[:, 0:1]
        y1 = ys - wh[:, 1:2]
        x2 = xs + wh[:, 2:3]
        y2 = ys + wh[:, 3:4]
        bboxes = paddle.concat([x1, y1, x2, y2], axis=1)

        scale_y = scale_factor[:, 0:1]
        scale_x = scale_factor[:, 1:2]
        scale_expand = paddle.concat(
            [scale_x, scale_y, scale_x, scale_y], axis=1)
        boxes_shape = paddle.shape(bboxes)
        boxes_shape.stop_gradient = True
        scale_expand = paddle.expand(scale_expand, shape=boxes_shape)
        bboxes = paddle.divide(bboxes, scale_expand)
        results = paddle.concat([clses, scores, bboxes], axis=1)
        # hack: append a fake result with cls=-1 and score=1. so that gather
        # does not fail when every real score is below score_thresh.
        fill_r = paddle.to_tensor(np.array([[-1, 1, 0, 0, 0, 0]]))
        fill_r = paddle.cast(fill_r, results.dtype)
        results = paddle.concat([results, fill_r])
        scores = results[:, 1]
        valid_ind = paddle.nonzero(scores > self.score_thresh)
        results = paddle.gather(results, valid_ind)
        return results, paddle.shape(results)[0:1]

    def __call__(self, hm, wh, im_shape, scale_factor):
        results = []
        results_num = []
        for i in range(scale_factor.shape[0]):
            result, num = self._decode(hm[i:i + 1, ], wh[i:i + 1, ],
                                       im_shape[i:i + 1, ],
                                       scale_factor[i:i + 1, ])
            results.append(result)
            results_num.append(num)
        results = paddle.concat(results, axis=0)
        results_num = paddle.concat(results_num, axis=0)
        return results, results_num
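

# NOTE: illustrative usage sketch, not part of the original file; shapes are
# assumptions (4 classes, a 32x32 feature map, batch size 1 per decode call).
def _ttf_box_demo():
    decoder = TTFBox(max_per_img=100, score_thresh=0.01, down_ratio=4)
    hm = paddle.rand([1, 4, 32, 32])           # class-wise center heatmaps
    wh = paddle.rand([1, 4, 32, 32])           # distances to the 4 box sides
    im_shape = paddle.to_tensor([[128., 128.]])
    scale_factor = paddle.to_tensor([[1., 1.]])
    return decoder(hm, wh, im_shape, scale_factor)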


@register
@serializable
class JDEBox(object):
    __shared__ = ['num_classes']

    def __init__(self, num_classes=1, conf_thresh=0.3, downsample_ratio=32):
        self.num_classes = num_classes
        self.conf_thresh = conf_thresh
        self.downsample_ratio = downsample_ratio

    def generate_anchor(self, nGh, nGw, anchor_wh):
        nA = len(anchor_wh)
        yv, xv = paddle.meshgrid([paddle.arange(nGh), paddle.arange(nGw)])
        mesh = paddle.stack(
            (xv, yv), axis=0).cast(dtype='float32')  # 2 x nGh x nGw
        meshs = paddle.tile(mesh, [nA, 1, 1, 1])

        anchor_offset_mesh = anchor_wh[:, :, None][:, :, :, None].repeat(
            int(nGh), axis=-2).repeat(
                int(nGw), axis=-1)
        anchor_offset_mesh = paddle.to_tensor(
            anchor_offset_mesh.astype(np.float32))
        # nA x 2 x nGh x nGw

        anchor_mesh = paddle.concat([meshs, anchor_offset_mesh], axis=1)
        anchor_mesh = paddle.transpose(anchor_mesh,
                                       [0, 2, 3, 1])  # (nA x nGh x nGw) x 4
        return anchor_mesh

    def decode_delta(self, delta, fg_anchor_list):
        px, py, pw, ph = fg_anchor_list[:, 0], fg_anchor_list[:, 1], \
            fg_anchor_list[:, 2], fg_anchor_list[:, 3]
        dx, dy, dw, dh = delta[:, 0], delta[:, 1], delta[:, 2], delta[:, 3]
        gx = pw * dx + px
        gy = ph * dy + py
        gw = pw * paddle.exp(dw)
        gh = ph * paddle.exp(dh)
        gx1 = gx - gw * 0.5
        gy1 = gy - gh * 0.5
        gx2 = gx + gw * 0.5
        gy2 = gy + gh * 0.5
        return paddle.stack([gx1, gy1, gx2, gy2], axis=1)

    def decode_delta_map(self, nA, nGh, nGw, delta_map, anchor_vec):
        anchor_mesh = self.generate_anchor(nGh, nGw, anchor_vec)
        anchor_mesh = paddle.unsqueeze(anchor_mesh, 0)
        pred_list = self.decode_delta(
            paddle.reshape(
                delta_map, shape=[-1, 4]),
            paddle.reshape(
                anchor_mesh, shape=[-1, 4]))
        pred_map = paddle.reshape(pred_list, shape=[nA * nGh * nGw, 4])
        return pred_map

    def _postprocessing_by_level(self, nA, stride, head_out, anchor_vec):
        boxes_shape = head_out.shape  # [nB, nA*6, nGh, nGw]
        nGh, nGw = boxes_shape[-2], boxes_shape[-1]
        nB = 1  # TODO: only support bs=1 now
        boxes_list, scores_list = [], []
        for idx in range(nB):
            p = paddle.reshape(
                head_out[idx], shape=[nA, self.num_classes + 5, nGh, nGw])
            p = paddle.transpose(p, perm=[0, 2, 3, 1])  # [nA, nGh, nGw, 6]
            delta_map = p[:, :, :, :4]
            boxes = self.decode_delta_map(nA, nGh, nGw, delta_map, anchor_vec)
            # [nA * nGh * nGw, 4]
            boxes_list.append(boxes * stride)

            p_conf = paddle.transpose(
                p[:, :, :, 4:6], perm=[3, 0, 1, 2])  # [2, nA, nGh, nGw]
            p_conf = F.softmax(
                p_conf, axis=0)[1, :, :, :].unsqueeze(-1)  # [nA, nGh, nGw, 1]
            scores = paddle.reshape(p_conf, shape=[nA * nGh * nGw, 1])
            scores_list.append(scores)

        boxes_results = paddle.stack(boxes_list)
        scores_results = paddle.stack(scores_list)
        return boxes_results, scores_results

    def __call__(self, yolo_head_out, anchors):
        bbox_pred_list = []
        for i, head_out in enumerate(yolo_head_out):
            stride = self.downsample_ratio // 2**i
            anc_w, anc_h = anchors[i][0::2], anchors[i][1::2]
            anchor_vec = np.stack((anc_w, anc_h), axis=1) / stride
            nA = len(anc_w)
            boxes, scores = self._postprocessing_by_level(nA, stride, head_out,
                                                          anchor_vec)
            bbox_pred_list.append(paddle.concat([boxes, scores], axis=-1))

        yolo_boxes_scores = paddle.concat(bbox_pred_list, axis=1)
        boxes_idx_over_conf_thr = paddle.nonzero(
            yolo_boxes_scores[:, :, -1] > self.conf_thresh)
        boxes_idx_over_conf_thr.stop_gradient = True
        return boxes_idx_over_conf_thr, yolo_boxes_scores


@register
@serializable
class MaskMatrixNMS(object):
    """
    Matrix NMS for multi-class masks.

    Args:
        update_threshold (float): Updated threshold of category score in the second pass.
        pre_nms_top_n (int): Number of total instances to be kept per image before NMS.
        post_nms_top_n (int): Number of total instances to be kept per image after NMS.
        kernel (str): 'linear' or 'gaussian'.
        sigma (float): std in gaussian method.

    Input:
        seg_preds (Variable): shape (n, h, w), segmentation feature maps
        seg_masks (Variable): shape (n, h, w), segmentation feature maps
        cate_labels (Variable): shape (n), mask labels in descending order
        cate_scores (Variable): shape (n), mask scores in descending order
        sum_masks (Variable): a float tensor of the sum of seg_masks

    Returns:
        Variable: cate_scores, tensors of shape (n)
    """

    def __init__(self,
                 update_threshold=0.05,
                 pre_nms_top_n=500,
                 post_nms_top_n=100,
                 kernel='gaussian',
                 sigma=2.0):
        super(MaskMatrixNMS, self).__init__()
        self.update_threshold = update_threshold
        self.pre_nms_top_n = pre_nms_top_n
        self.post_nms_top_n = post_nms_top_n
        self.kernel = kernel
        self.sigma = sigma

    def _sort_score(self, scores, top_num):
        if paddle.shape(scores)[0] > top_num:
            return paddle.topk(scores, top_num)[1]
        else:
            return paddle.argsort(scores, descending=True)

    def __call__(self,
                 seg_preds,
                 seg_masks,
                 cate_labels,
                 cate_scores,
                 sum_masks=None):
        # sort and keep top nms_pre
        sort_inds = self._sort_score(cate_scores, self.pre_nms_top_n)
        seg_masks = paddle.gather(seg_masks, index=sort_inds)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        sum_masks = paddle.gather(sum_masks, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)

        seg_masks = paddle.flatten(seg_masks, start_axis=1, stop_axis=-1)
        # inter.
        inter_matrix = paddle.mm(seg_masks, paddle.transpose(seg_masks, [1, 0]))
        n_samples = paddle.shape(cate_labels)
        # union.
        sum_masks_x = paddle.expand(sum_masks, shape=[n_samples, n_samples])
        # iou.
        iou_matrix = (inter_matrix / (
            sum_masks_x + paddle.transpose(sum_masks_x, [1, 0]) - inter_matrix))
        iou_matrix = paddle.triu(iou_matrix, diagonal=1)
        # label_specific matrix.
        cate_labels_x = paddle.expand(cate_labels, shape=[n_samples, n_samples])
        label_matrix = paddle.cast(
            (cate_labels_x == paddle.transpose(cate_labels_x, [1, 0])),
            'float32')
        label_matrix = paddle.triu(label_matrix, diagonal=1)

        # IoU compensation
        compensate_iou = paddle.max((iou_matrix * label_matrix), axis=0)
        compensate_iou = paddle.expand(
            compensate_iou, shape=[n_samples, n_samples])
        compensate_iou = paddle.transpose(compensate_iou, [1, 0])

        # IoU decay
        decay_iou = iou_matrix * label_matrix

        # matrix nms
        if self.kernel == 'gaussian':
            decay_matrix = paddle.exp(-1 * self.sigma * (decay_iou**2))
            compensate_matrix = paddle.exp(-1 * self.sigma *
                                           (compensate_iou**2))
            decay_coefficient = paddle.min(decay_matrix / compensate_matrix,
                                           axis=0)
        elif self.kernel == 'linear':
            decay_matrix = (1 - decay_iou) / (1 - compensate_iou)
            decay_coefficient = paddle.min(decay_matrix, axis=0)
        else:
            raise NotImplementedError

        # update the score.
        cate_scores = cate_scores * decay_coefficient
        y = paddle.zeros(shape=paddle.shape(cate_scores), dtype='float32')
        keep = paddle.where(cate_scores >= self.update_threshold, cate_scores,
                            y)
        keep = paddle.nonzero(keep)
        keep = paddle.squeeze(keep, axis=[1])
        # Prevent empty results: append the index of the last (fake) entry
        keep = paddle.concat(
            [keep, paddle.cast(paddle.shape(cate_scores)[0] - 1, 'int64')])

        seg_preds = paddle.gather(seg_preds, index=keep)
        cate_scores = paddle.gather(cate_scores, index=keep)
        cate_labels = paddle.gather(cate_labels, index=keep)

        # sort and keep top_k
        sort_inds = self._sort_score(cate_scores, self.post_nms_top_n)
        seg_preds = paddle.gather(seg_preds, index=sort_inds)
        cate_scores = paddle.gather(cate_scores, index=sort_inds)
        cate_labels = paddle.gather(cate_labels, index=sort_inds)
        return seg_preds, cate_scores, cate_labels
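

# NOTE: illustrative usage sketch, not part of the original file; random
# binary masks stand in for real SOLOv2 outputs. Scores are decayed by
# mask-IoU overlap rather than hard-suppressed.
def _mask_matrix_nms_demo():
    nms = MaskMatrixNMS(post_nms_top_n=10)
    seg_preds = paddle.rand([20, 32, 32])
    seg_masks = paddle.cast(seg_preds > 0.5, 'float32')
    cate_labels = paddle.randint(0, 3, shape=[20])
    cate_scores = paddle.rand([20])
    sum_masks = seg_masks.sum(axis=[1, 2])
    return nms(seg_preds, seg_masks, cate_labels, cate_scores,
               sum_masks=sum_masks)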


def Conv2d(in_channels,
           out_channels,
           kernel_size,
           stride=1,
           padding=0,
           dilation=1,
           groups=1,
           bias=True,
           weight_init=Normal(std=0.001),
           bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    conv = nn.Conv2D(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        dilation,
        groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv


def ConvTranspose2d(in_channels,
                    out_channels,
                    kernel_size,
                    stride=1,
                    padding=0,
                    output_padding=0,
                    groups=1,
                    bias=True,
                    dilation=1,
                    weight_init=Normal(std=0.001),
                    bias_init=Constant(0.)):
    weight_attr = paddle.framework.ParamAttr(initializer=weight_init)
    if bias:
        bias_attr = paddle.framework.ParamAttr(initializer=bias_init)
    else:
        bias_attr = False
    # pass dilation and groups by keyword so the positional order of
    # nn.Conv2DTranspose cannot silently swap them
    conv = nn.Conv2DTranspose(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        output_padding,
        dilation=dilation,
        groups=groups,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return conv


def BatchNorm2d(num_features, eps=1e-05, momentum=0.9, affine=True):
    if not affine:
        weight_attr = False
        bias_attr = False
    else:
        weight_attr = None
        bias_attr = None
    batchnorm = nn.BatchNorm2D(
        num_features,
        momentum=momentum,
        epsilon=eps,
        weight_attr=weight_attr,
        bias_attr=bias_attr)
    return batchnorm


def ReLU():
    return nn.ReLU()


def Upsample(scale_factor=None, mode='nearest', align_corners=False):
    return nn.Upsample(None, scale_factor, mode, align_corners)


def MaxPool(kernel_size, stride, padding, ceil_mode=False):
    return nn.MaxPool2D(kernel_size, stride, padding, ceil_mode=ceil_mode)


class Concat(nn.Layer):
    def __init__(self, dim=0):
        super(Concat, self).__init__()
        self.dim = dim

    def forward(self, inputs):
        return paddle.concat(inputs, axis=self.dim)

    def extra_repr(self):
        return 'dim={}'.format(self.dim)


def _convert_attention_mask(attn_mask, dtype):
    """
    Convert the attention mask to the target dtype we expect.

    Parameters:
        attn_mask (Tensor, optional): A tensor used in multi-head attention
            to prevent attention to some unwanted positions, usually the
            paddings or the subsequent positions. It is a tensor with shape
            broadcastable to `[batch_size, n_head, sequence_length, sequence_length]`.
            When the data type is bool, the unwanted positions have `False`
            values and the others have `True` values. When the data type is
            int, the unwanted positions have 0 values and the others have 1
            values. When the data type is float, the unwanted positions have
            `-INF` values and the others have 0 values. It can be None when
            nothing needs to be masked. Default None.
        dtype (VarType): The target type of `attn_mask` we expect.

    Returns:
        Tensor: A Tensor with the same shape as the input `attn_mask`, with data type `dtype`.
    """
    return nn.layer.transformer._convert_attention_mask(attn_mask, dtype)


class MultiHeadAttention(nn.Layer):
    """
    Attention maps queries and a set of key-value pairs to outputs, and
    Multi-Head Attention performs multiple attention operations in parallel
    to jointly attend to information from different representation subspaces.

    Please refer to `Attention Is All You Need <https://arxiv.org/pdf/1706.03762.pdf>`_
    for more details.

    Parameters:
        embed_dim (int): The expected feature size in the input and output.
        num_heads (int): The number of heads in multi-head attention.
        dropout (float, optional): The dropout probability used on attention
            weights to drop some attention targets. 0 for no dropout. Default 0.
        kdim (int, optional): The feature size in key. If None, assumed equal to
            `embed_dim`. Default None.
        vdim (int, optional): The feature size in value. If None, assumed equal to
            `embed_dim`. Default None.
        need_weights (bool, optional): Indicate whether to return the attention
            weights. Default False.

    Examples:
        .. code-block:: python

            import paddle

            # encoder input: [batch_size, sequence_length, d_model]
            query = paddle.rand((2, 4, 128))
            # self attention mask: [batch_size, num_heads, query_len, query_len]
            attn_mask = paddle.rand((2, 2, 4, 4))
            multi_head_attn = paddle.nn.MultiHeadAttention(128, 2)
            output = multi_head_attn(query, None, None, attn_mask=attn_mask)  # [2, 4, 128]
    """

    def __init__(self,
                 embed_dim,
                 num_heads,
                 dropout=0.,
                 kdim=None,
                 vdim=None,
                 need_weights=False):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.kdim = kdim if kdim is not None else embed_dim
        self.vdim = vdim if vdim is not None else embed_dim
        self._qkv_same_embed_dim = self.kdim == embed_dim and self.vdim == embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.need_weights = need_weights

        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"

        if self._qkv_same_embed_dim:
            self.in_proj_weight = self.create_parameter(
                shape=[embed_dim, 3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=False)
            self.in_proj_bias = self.create_parameter(
                shape=[3 * embed_dim],
                attr=None,
                dtype=self._dtype,
                is_bias=True)
        else:
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(self.kdim, embed_dim)
            self.v_proj = nn.Linear(self.vdim, embed_dim)

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self._type_list = ('q_proj', 'k_proj', 'v_proj')

        self._reset_parameters()

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                xavier_uniform_(p)
            else:
                constant_(p)

    def compute_qkv(self, tensor, index):
        if self._qkv_same_embed_dim:
            tensor = F.linear(
                x=tensor,
                weight=self.in_proj_weight[:, index * self.embed_dim:(index + 1)
                                           * self.embed_dim],
                bias=self.in_proj_bias[index * self.embed_dim:(index + 1) *
                                       self.embed_dim]
                if self.in_proj_bias is not None else None)
        else:
            tensor = getattr(self, self._type_list[index])(tensor)
        tensor = tensor.reshape(
            [0, 0, self.num_heads, self.head_dim]).transpose([0, 2, 1, 3])
        return tensor

    def forward(self, query, key=None, value=None, attn_mask=None):
        r"""
        Applies multi-head attention to map queries and a set of key-value pairs
        to outputs.

        Parameters:
            query (Tensor): The queries for multi-head attention. It is a
                tensor with shape `[batch_size, query_length, embed_dim]`. The
                data type should be float32 or float64.
            key (Tensor, optional): The keys for multi-head attention. It is
                a tensor with shape `[batch_size, key_length, kdim]`. The
                data type should be float32 or float64. If None, use `query` as
                `key`. Default None.
            value (Tensor, optional): The values for multi-head attention. It
                is a tensor with shape `[batch_size, value_length, vdim]`.
                The data type should be float32 or float64. If None, use `query` as
                `value`. Default None.
            attn_mask (Tensor, optional): A tensor used in multi-head attention
                to prevent attention to some unwanted positions, usually the
                paddings or the subsequent positions. It is a tensor with shape
                broadcastable to `[batch_size, n_head, sequence_length, sequence_length]`.
                When the data type is bool, the unwanted positions have `False`
                values and the others have `True` values. When the data type is
                int, the unwanted positions have 0 values and the others have 1
                values. When the data type is float, the unwanted positions have
                `-INF` values and the others have 0 values. It can be None when
                nothing needs to be masked. Default None.

        Returns:
            Tensor|tuple: It is a tensor that has the same shape and data type \
                as `query`, representing attention output. Or a tuple if \
                `need_weights` is True or `cache` is not None. If `need_weights` \
                is True, except for attention output, the tuple also includes \
                the attention weights tensor shaped `[batch_size, num_heads, query_length, key_length]`. \
                If `cache` is not None, the tuple then includes the new cache \
                having the same type as `cache`, and if it is `StaticCache`, it \
                is the same as the input `cache`; if it is `Cache`, the new cache \
                reserves tensors concatenating raw tensors with intermediate \
                results of the current query.
        """
        key = query if key is None else key
        value = query if value is None else value
        # compute q, k, v
        q, k, v = (self.compute_qkv(t, i)
                   for i, t in enumerate([query, key, value]))

        # scaled dot-product attention
        product = paddle.matmul(x=q, y=k, transpose_y=True)
        scaling = float(self.head_dim)**-0.5
        product = product * scaling

        if attn_mask is not None:
            # Support bool or int mask
            attn_mask = _convert_attention_mask(attn_mask, product.dtype)
            product = product + attn_mask
        weights = F.softmax(product)
        if self.dropout:
            weights = F.dropout(
                weights,
                self.dropout,
                training=self.training,
                mode="upscale_in_train")
        out = paddle.matmul(weights, v)

        # combine heads
        out = paddle.transpose(out, perm=[0, 2, 1, 3])
        out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])

        # project to output
        out = self.out_proj(out)

        outs = [out]
        if self.need_weights:
            outs.append(weights)
        return out if len(outs) == 1 else tuple(outs)


@register
class ConvMixer(nn.Layer):
    def __init__(
            self,
            dim,
            depth,
            kernel_size=3, ):
        super().__init__()
        self.dim = dim
        self.depth = depth
        self.kernel_size = kernel_size

        self.mixer = self.conv_mixer(dim, depth, kernel_size)

    def forward(self, x):
        return self.mixer(x)

    @staticmethod
    def conv_mixer(
            dim,
            depth,
            kernel_size, ):
        Seq = nn.Sequential
        ActBn = lambda x: Seq(x, nn.GELU(), nn.BatchNorm2D(dim))
        Residual = type('Residual', (Seq, ),
                        {'forward': lambda self, x: self[0](x) + x})
        return Seq(*[
            Seq(Residual(
                ActBn(
                    nn.Conv2D(
                        dim, dim, kernel_size, groups=dim, padding="same"))),
                ActBn(nn.Conv2D(dim, dim, 1))) for _ in range(depth)
        ])
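

# NOTE: illustrative usage sketch, not part of the original file; sizes are
# assumptions. Each block is a residual depthwise conv followed by a 1x1
# pointwise conv, both wrapped in GELU + BatchNorm, so channels are preserved.
def _conv_mixer_demo():
    mixer = ConvMixer(dim=64, depth=4, kernel_size=3)
    x = paddle.rand([2, 64, 32, 32])
    return mixer(x)                            # [2, 64, 32, 32]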