dla.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ppdet.core.workspace import register, serializable
from ppdet.modeling.layers import ConvNormLayer

from ..shape_spec import ShapeSpec
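
# Per-depth config: (levels, channels). For the two stem stages the level
# entry is the conv repeat count; for the Tree stages it is the tree depth.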
DLA_cfg = {34: ([1, 1, 1, 2, 2, 1], [16, 32, 64, 128, 256, 512])}


class BasicBlock(nn.Layer):
    """Basic residual block: two 3x3 conv-norm layers with an identity
    (or externally supplied) shortcut."""

    def __init__(self, ch_in, ch_out, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = ConvNormLayer(
            ch_in,
            ch_out,
            filter_size=3,
            stride=stride,
            bias_on=False,
            norm_decay=None)
        self.conv2 = ConvNormLayer(
            ch_out,
            ch_out,
            filter_size=3,
            stride=1,
            bias_on=False,
            norm_decay=None)

    def forward(self, inputs, residual=None):
        if residual is None:
            residual = inputs
        out = self.conv1(inputs)
        out = F.relu(out)
        out = self.conv2(out)
        out = paddle.add(x=out, y=residual)
        out = F.relu(out)
        return out


class Root(nn.Layer):
    """Aggregation node: concatenates its children along the channel axis,
    projects with a 1x1 conv-norm layer, and optionally adds a residual
    connection from the first child."""

    def __init__(self, ch_in, ch_out, kernel_size, residual):
        super(Root, self).__init__()
        # kernel_size is kept for interface compatibility; the aggregation
        # conv is always 1x1.
        self.conv = ConvNormLayer(
            ch_in,
            ch_out,
            filter_size=1,
            stride=1,
            bias_on=False,
            norm_decay=None)
        self.residual = residual

    def forward(self, inputs):
        children = inputs
        out = self.conv(paddle.concat(inputs, axis=1))
        if self.residual:
            out = paddle.add(x=out, y=children[0])
        out = F.relu(out)
        return out


class Tree(nn.Layer):
    """Recursive hierarchical aggregation. A level-1 Tree holds two blocks
    joined by a Root; deeper Trees nest two sub-Trees, and aggregation
    happens at the level-1 roots."""

    def __init__(self,
                 level,
                 block,
                 ch_in,
                 ch_out,
                 stride=1,
                 level_root=False,
                 root_dim=0,
                 root_kernel_size=1,
                 root_residual=False):
        super(Tree, self).__init__()
        if root_dim == 0:
            root_dim = 2 * ch_out
        if level_root:
            root_dim += ch_in
        if level == 1:
            self.tree1 = block(ch_in, ch_out, stride)
            self.tree2 = block(ch_out, ch_out, 1)
        else:
            self.tree1 = Tree(
                level - 1,
                block,
                ch_in,
                ch_out,
                stride,
                root_dim=0,
                root_kernel_size=root_kernel_size,
                root_residual=root_residual)
            self.tree2 = Tree(
                level - 1,
                block,
                ch_out,
                ch_out,
                1,
                root_dim=root_dim + ch_out,
                root_kernel_size=root_kernel_size,
                root_residual=root_residual)
        if level == 1:
            self.root = Root(root_dim, ch_out, root_kernel_size,
                             root_residual)
        self.level_root = level_root
        self.root_dim = root_dim
        self.downsample = None
        self.project = None
        self.level = level
        if stride > 1:
            self.downsample = nn.MaxPool2D(stride, stride=stride)
        if ch_in != ch_out:
            self.project = ConvNormLayer(
                ch_in,
                ch_out,
                filter_size=1,
                stride=1,
                bias_on=False,
                norm_decay=None)

    def forward(self, x, residual=None, children=None):
        children = [] if children is None else children
        bottom = self.downsample(x) if self.downsample else x
        residual = self.project(bottom) if self.project else bottom
        if self.level_root:
            children.append(bottom)
        x1 = self.tree1(x, residual)
        if self.level == 1:
            x2 = self.tree2(x1)
            x = self.root([x2, x1] + children)
        else:
            children.append(x1)
            x = self.tree2(x1, children=children)
        return x
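
# For example, Tree(level=2, ...) nests two Tree(level=1) sub-trees; each
# level-1 tree joins its two blocks through a Root, and the outer tree
# forwards its first sub-tree's output as an extra child of the inner Root.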


@register
@serializable
class DLA(nn.Layer):
    """
    DLA, see https://arxiv.org/pdf/1707.06484.pdf

    Args:
        depth (int): DLA depth, only 34 is supported now.
        residual_root (bool): whether to use a residual layer in the root
            block.
        pre_img (bool): whether to add the previous frame image branch,
            only used in CenterTrack.
        pre_hm (bool): whether to add the previous frame heatmap branch,
            only used in CenterTrack.
    """

    def __init__(self,
                 depth=34,
                 residual_root=False,
                 pre_img=False,
                 pre_hm=False):
        super(DLA, self).__init__()
        assert depth == 34, 'Only DLA with depth 34 is supported now.'
        if depth == 34:
            block = BasicBlock
        levels, channels = DLA_cfg[depth]
        self.channels = channels
        self.num_levels = len(levels)
        self.base_layer = nn.Sequential(
            ConvNormLayer(
                3,
                channels[0],
                filter_size=7,
                stride=1,
                bias_on=False,
                norm_decay=None),
            nn.ReLU())
        self.level0 = self._make_conv_level(channels[0], channels[0],
                                            levels[0])
        self.level1 = self._make_conv_level(
            channels[0], channels[1], levels[1], stride=2)
        self.level2 = Tree(
            levels[2],
            block,
            channels[1],
            channels[2],
            2,
            level_root=False,
            root_residual=residual_root)
        self.level3 = Tree(
            levels[3],
            block,
            channels[2],
            channels[3],
            2,
            level_root=True,
            root_residual=residual_root)
        self.level4 = Tree(
            levels[4],
            block,
            channels[3],
            channels[4],
            2,
            level_root=True,
            root_residual=residual_root)
        self.level5 = Tree(
            levels[5],
            block,
            channels[4],
            channels[5],
            2,
            level_root=True,
            root_residual=residual_root)
        if pre_img:
            # Branch that embeds the previous frame image (CenterTrack).
            self.pre_img_layer = nn.Sequential(
                ConvNormLayer(
                    3,
                    channels[0],
                    filter_size=7,
                    stride=1,
                    bias_on=False,
                    norm_decay=None),
                nn.ReLU())
        if pre_hm:
            # Branch that embeds the previous frame heatmap (CenterTrack).
            self.pre_hm_layer = nn.Sequential(
                ConvNormLayer(
                    1,
                    channels[0],
                    filter_size=7,
                    stride=1,
                    bias_on=False,
                    norm_decay=None),
                nn.ReLU())
        self.pre_img = pre_img
        self.pre_hm = pre_hm

    def _make_conv_level(self, ch_in, ch_out, conv_num, stride=1):
        # Stack `conv_num` 3x3 conv-norm-ReLU units; only the first one
        # applies the stride.
        modules = []
        for i in range(conv_num):
            modules.extend([
                ConvNormLayer(
                    ch_in,
                    ch_out,
                    filter_size=3,
                    stride=stride if i == 0 else 1,
                    bias_on=False,
                    norm_decay=None), nn.ReLU()
            ])
            ch_in = ch_out
        return nn.Sequential(*modules)

    @property
    def out_shape(self):
        return [
            ShapeSpec(channels=self.channels[i])
            for i in range(self.num_levels)
        ]

    def forward(self, inputs):
        outs = []
        feats = self.base_layer(inputs['image'])
        if (self.pre_img and 'pre_image' in inputs and
                inputs['pre_image'] is not None):
            feats = feats + self.pre_img_layer(inputs['pre_image'])
        if self.pre_hm and 'pre_hm' in inputs and inputs['pre_hm'] is not None:
            feats = feats + self.pre_hm_layer(inputs['pre_hm'])
        for i in range(self.num_levels):
            feats = getattr(self, 'level{}'.format(i))(feats)
            outs.append(feats)
        return outs
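

# Minimal smoke-test sketch, assuming paddlepaddle and ppdet are installed;
# the 512x512 input size is illustrative only.
if __name__ == '__main__':
    model = DLA(depth=34)
    dummy = {'image': paddle.randn([1, 3, 512, 512])}
    outs = model(dummy)
    # Six feature maps, one per level, at strides 1, 2, 4, 8, 16, 32 with
    # channels 16, 32, 64, 128, 256, 512.
    for i, feat in enumerate(outs):
        print(i, feat.shape)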