# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F
from paddle.nn import AdaptiveAvgPool2D, Linear
from paddle.regularizer import L2Decay
from paddle import ParamAttr
from paddle.nn.initializer import Normal, Uniform
from numbers import Integral
import math

from ppdet.core.workspace import register
from ..shape_spec import ShapeSpec

__all__ = ['HRNet']
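
# ConvNormLayer: a bias-free Conv2D (Normal(0, 0.01) weight init) followed by
# BatchNorm/SyncBN/GroupNorm and an optional ReLU. With freeze_norm=True the
# norm parameters get zero learning rate and stop_gradient, and BatchNorm
# switches to use_global_stats (running statistics even in train mode).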
class ConvNormLayer(nn.Layer):
    def __init__(self,
                 ch_in,
                 ch_out,
                 filter_size,
                 stride=1,
                 norm_type='bn',
                 norm_groups=32,
                 use_dcn=False,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 freeze_norm=False,
                 act=None,
                 name=None):
        super(ConvNormLayer, self).__init__()
        assert norm_type in ['bn', 'sync_bn', 'gn']

        self.act = act
        self.conv = nn.Conv2D(
            in_channels=ch_in,
            out_channels=ch_out,
            kernel_size=filter_size,
            stride=stride,
            padding=(filter_size - 1) // 2,
            groups=1,
            weight_attr=ParamAttr(initializer=Normal(
                mean=0., std=0.01)),
            bias_attr=False)

        norm_lr = 0. if freeze_norm else 1.
        param_attr = ParamAttr(
            learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
        bias_attr = ParamAttr(
            learning_rate=norm_lr, regularizer=L2Decay(norm_decay))
        global_stats = True if freeze_norm else None
        if norm_type in ['bn', 'sync_bn']:
            self.norm = nn.BatchNorm2D(
                ch_out,
                momentum=norm_momentum,
                weight_attr=param_attr,
                bias_attr=bias_attr,
                use_global_stats=global_stats)
        elif norm_type == 'gn':
            self.norm = nn.GroupNorm(
                num_groups=norm_groups,
                num_channels=ch_out,
                weight_attr=param_attr,
                bias_attr=bias_attr)

        norm_params = self.norm.parameters()
        if freeze_norm:
            for param in norm_params:
                param.stop_gradient = True

    def forward(self, inputs):
        out = self.conv(inputs)
        out = self.norm(out)
        if self.act == 'relu':
            out = F.relu(out)
        return out
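
# Layer1: the HRNet stem stage, four BottleneckBlocks on the full-resolution
# path that expand the 64-channel stem output to 256 channels (the first
# block carries the projection shortcut).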
class Layer1(nn.Layer):
    def __init__(self,
                 num_channels,
                 has_se=False,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(Layer1, self).__init__()

        self.bottleneck_block_list = []
        for i in range(4):
            bottleneck_block = self.add_sublayer(
                "block_{}_{}".format(name, i + 1),
                BottleneckBlock(
                    num_channels=num_channels if i == 0 else 256,
                    num_filters=64,
                    has_se=has_se,
                    stride=1,
                    downsample=True if i == 0 else False,
                    norm_momentum=norm_momentum,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    name=name + '_' + str(i + 1)))
            self.bottleneck_block_list.append(bottleneck_block)

    def forward(self, input):
        conv = input
        for block_func in self.bottleneck_block_list:
            conv = block_func(conv)
        return conv
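
# TransitionLayer: adapts the set of branches between stages. An existing
# branch gets a 3x3 conv only if its channel count changes; each newly added
# branch is created from the lowest-resolution input with a stride-2 3x3 conv.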
class TransitionLayer(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(TransitionLayer, self).__init__()

        num_in = len(in_channels)
        num_out = len(out_channels)
        self.conv_bn_func_list = []
        for i in range(num_out):
            residual = None
            if i < num_in:
                if in_channels[i] != out_channels[i]:
                    residual = self.add_sublayer(
                        "transition_{}_layer_{}".format(name, i + 1),
                        ConvNormLayer(
                            ch_in=in_channels[i],
                            ch_out=out_channels[i],
                            filter_size=3,
                            norm_momentum=norm_momentum,
                            norm_decay=norm_decay,
                            freeze_norm=freeze_norm,
                            act='relu',
                            name=name + '_layer_' + str(i + 1)))
            else:
                residual = self.add_sublayer(
                    "transition_{}_layer_{}".format(name, i + 1),
                    ConvNormLayer(
                        ch_in=in_channels[-1],
                        ch_out=out_channels[i],
                        filter_size=3,
                        stride=2,
                        norm_momentum=norm_momentum,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        act='relu',
                        name=name + '_layer_' + str(i + 1)))
            self.conv_bn_func_list.append(residual)

    def forward(self, input):
        outs = []
        for idx, conv_bn_func in enumerate(self.conv_bn_func_list):
            if conv_bn_func is None:
                outs.append(input[idx])
            else:
                if idx < len(input):
                    outs.append(conv_bn_func(input[idx]))
                else:
                    outs.append(conv_bn_func(input[-1]))
        return outs
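
# Branches: runs block_num BasicBlocks independently on every resolution
# branch; branch i maps in_channels[i] to out_channels[i].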
class Branches(nn.Layer):
    def __init__(self,
                 block_num,
                 in_channels,
                 out_channels,
                 has_se=False,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(Branches, self).__init__()

        self.basic_block_list = []
        for i in range(len(out_channels)):
            self.basic_block_list.append([])
            for j in range(block_num):
                in_ch = in_channels[i] if j == 0 else out_channels[i]
                basic_block_func = self.add_sublayer(
                    "bb_{}_branch_layer_{}_{}".format(name, i + 1, j + 1),
                    BasicBlock(
                        num_channels=in_ch,
                        num_filters=out_channels[i],
                        has_se=has_se,
                        norm_momentum=norm_momentum,
                        norm_decay=norm_decay,
                        freeze_norm=freeze_norm,
                        name=name + '_branch_layer_' + str(i + 1) + '_' +
                        str(j + 1)))
                self.basic_block_list[i].append(basic_block_func)

    def forward(self, inputs):
        outs = []
        for idx, input in enumerate(inputs):
            conv = input
            basic_block_list = self.basic_block_list[idx]
            for basic_block_func in basic_block_list:
                conv = basic_block_func(conv)
            outs.append(conv)
        return outs
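
# BottleneckBlock: residual 1x1 -> 3x3 -> 1x1 block with 4x channel
# expansion, an optional SE block, and an optional 1x1 projection shortcut.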
class BottleneckBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se,
                 stride=1,
                 downsample=False,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(BottleneckBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = ConvNormLayer(
            ch_in=num_channels,
            ch_out=num_filters,
            filter_size=1,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvNormLayer(
            ch_in=num_filters,
            ch_out=num_filters,
            filter_size=3,
            stride=stride,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act="relu",
            name=name + "_conv2")
        self.conv3 = ConvNormLayer(
            ch_in=num_filters,
            ch_out=num_filters * 4,
            filter_size=1,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act=None,
            name=name + "_conv3")

        if self.downsample:
            self.conv_down = ConvNormLayer(
                ch_in=num_channels,
                ch_out=num_filters * 4,
                filter_size=1,
                norm_momentum=norm_momentum,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                act=None,
                name=name + "_downsample")

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters * 4,
                num_filters=num_filters * 4,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)
        conv3 = self.conv3(conv2)

        if self.downsample:
            residual = self.conv_down(input)

        if self.has_se:
            conv3 = self.se(conv3)

        y = paddle.add(x=residual, y=conv3)
        y = F.relu(y)
        return y
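
# BasicBlock: residual block of two 3x3 convs with an optional SE block; the
# branches of the HRNet exchange modules are built from these.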
class BasicBlock(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride=1,
                 has_se=False,
                 downsample=False,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(BasicBlock, self).__init__()

        self.has_se = has_se
        self.downsample = downsample

        self.conv1 = ConvNormLayer(
            ch_in=num_channels,
            ch_out=num_filters,
            filter_size=3,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            stride=stride,
            act="relu",
            name=name + "_conv1")
        self.conv2 = ConvNormLayer(
            ch_in=num_filters,
            ch_out=num_filters,
            filter_size=3,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            stride=1,
            act=None,
            name=name + "_conv2")

        if self.downsample:
            # the projection shortcut must match conv2's output width
            # (num_filters) so the residual add is well-formed
            self.conv_down = ConvNormLayer(
                ch_in=num_channels,
                ch_out=num_filters,
                filter_size=1,
                norm_momentum=norm_momentum,
                norm_decay=norm_decay,
                freeze_norm=freeze_norm,
                act=None,
                name=name + "_downsample")

        if self.has_se:
            self.se = SELayer(
                num_channels=num_filters,
                num_filters=num_filters,
                reduction_ratio=16,
                name='fc' + name)

    def forward(self, input):
        residual = input
        conv1 = self.conv1(input)
        conv2 = self.conv2(conv1)

        if self.downsample:
            residual = self.conv_down(input)

        if self.has_se:
            conv2 = self.se(conv2)

        y = paddle.add(x=residual, y=conv2)
        y = F.relu(y)
        return y
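
# SELayer: squeeze-and-excitation. Global average pool, a Linear squeeze to
# num_channels / reduction_ratio, ReLU, a Linear excitation back to
# num_filters, sigmoid, then channel-wise rescaling of the input.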
class SELayer(nn.Layer):
    def __init__(self, num_channels, num_filters, reduction_ratio, name=None):
        super(SELayer, self).__init__()
        self.pool2d_gap = AdaptiveAvgPool2D(1)
        self._num_channels = num_channels
        med_ch = int(num_channels / reduction_ratio)
        stdv = 1.0 / math.sqrt(num_channels * 1.0)
        self.squeeze = Linear(
            num_channels,
            med_ch,
            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))
        stdv = 1.0 / math.sqrt(med_ch * 1.0)
        self.excitation = Linear(
            med_ch,
            num_filters,
            weight_attr=ParamAttr(initializer=Uniform(-stdv, stdv)))

    def forward(self, input):
        pool = self.pool2d_gap(input)
        pool = paddle.squeeze(pool, axis=[2, 3])
        squeeze = self.squeeze(pool)
        squeeze = F.relu(squeeze)
        excitation = self.excitation(squeeze)
        excitation = F.sigmoid(excitation)
        excitation = paddle.unsqueeze(excitation, axis=[2, 3])
        out = input * excitation
        return out
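
# Stage: chains num_modules HighResolutionModules; only the final module
# honors multi_scale_output=False.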
class Stage(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_modules,
                 num_filters,
                 has_se=False,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 freeze_norm=True,
                 multi_scale_output=True,
                 name=None):
        super(Stage, self).__init__()

        self._num_modules = num_modules
        self.stage_func_list = []
        for i in range(num_modules):
            # only the last module may drop the multi-scale output
            is_last = i == num_modules - 1 and not multi_scale_output
            stage_func = self.add_sublayer(
                "stage_{}_{}".format(name, i + 1),
                HighResolutionModule(
                    num_channels=num_channels,
                    num_filters=num_filters,
                    has_se=has_se,
                    norm_momentum=norm_momentum,
                    norm_decay=norm_decay,
                    freeze_norm=freeze_norm,
                    multi_scale_output=not is_last,
                    name=name + '_' + str(i + 1)))
            self.stage_func_list.append(stage_func)

    def forward(self, input):
        out = input
        for idx in range(self._num_modules):
            out = self.stage_func_list[idx](out)
        return out
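
# HighResolutionModule: one HRNet exchange unit, i.e. four BasicBlocks per
# branch (Branches) followed by cross-resolution fusion (FuseLayers).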
class HighResolutionModule(nn.Layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 has_se=False,
                 multi_scale_output=True,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(HighResolutionModule, self).__init__()
        self.branches_func = Branches(
            block_num=4,
            in_channels=num_channels,
            out_channels=num_filters,
            has_se=has_se,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name)

        self.fuse_func = FuseLayers(
            in_channels=num_filters,
            out_channels=num_filters,
            multi_scale_output=multi_scale_output,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name=name)

    def forward(self, input):
        out = self.branches_func(input)
        out = self.fuse_func(out)
        return out
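
# FuseLayers: sums every input branch into each output branch. A
# lower-resolution input (j > i) goes through a 1x1 conv and nearest-neighbor
# upsampling by 2**(j - i); a higher-resolution input (j < i) is downsampled
# with a chain of stride-2 3x3 convs.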
class FuseLayers(nn.Layer):
    def __init__(self,
                 in_channels,
                 out_channels,
                 multi_scale_output=True,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 freeze_norm=True,
                 name=None):
        super(FuseLayers, self).__init__()

        self._actual_ch = len(in_channels) if multi_scale_output else 1
        self._in_channels = in_channels

        self.residual_func_list = []
        for i in range(self._actual_ch):
            for j in range(len(in_channels)):
                residual_func = None
                if j > i:
                    residual_func = self.add_sublayer(
                        "residual_{}_layer_{}_{}".format(name, i + 1, j + 1),
                        ConvNormLayer(
                            ch_in=in_channels[j],
                            ch_out=out_channels[i],
                            filter_size=1,
                            stride=1,
                            act=None,
                            norm_momentum=norm_momentum,
                            norm_decay=norm_decay,
                            freeze_norm=freeze_norm,
                            name=name + '_layer_' + str(i + 1) + '_' +
                            str(j + 1)))
                    self.residual_func_list.append(residual_func)
                elif j < i:
                    pre_num_filters = in_channels[j]
                    for k in range(i - j):
                        if k == i - j - 1:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                ConvNormLayer(
                                    ch_in=pre_num_filters,
                                    ch_out=out_channels[i],
                                    filter_size=3,
                                    stride=2,
                                    norm_momentum=norm_momentum,
                                    norm_decay=norm_decay,
                                    freeze_norm=freeze_norm,
                                    act=None,
                                    name=name + '_layer_' + str(i + 1) + '_' +
                                    str(j + 1) + '_' + str(k + 1)))
                            pre_num_filters = out_channels[i]
                        else:
                            residual_func = self.add_sublayer(
                                "residual_{}_layer_{}_{}_{}".format(
                                    name, i + 1, j + 1, k + 1),
                                ConvNormLayer(
                                    ch_in=pre_num_filters,
                                    ch_out=out_channels[j],
                                    filter_size=3,
                                    stride=2,
                                    norm_momentum=norm_momentum,
                                    norm_decay=norm_decay,
                                    freeze_norm=freeze_norm,
                                    act="relu",
                                    name=name + '_layer_' + str(i + 1) + '_' +
                                    str(j + 1) + '_' + str(k + 1)))
                            pre_num_filters = out_channels[j]
                        self.residual_func_list.append(residual_func)

    def forward(self, input):
        outs = []
        residual_func_idx = 0
        for i in range(self._actual_ch):
            residual = input[i]
            for j in range(len(self._in_channels)):
                if j > i:
                    y = self.residual_func_list[residual_func_idx](input[j])
                    residual_func_idx += 1
                    y = F.interpolate(y, scale_factor=2**(j - i))
                    residual = paddle.add(x=residual, y=y)
                elif j < i:
                    y = input[j]
                    for k in range(i - j):
                        y = self.residual_func_list[residual_func_idx](y)
                        residual_func_idx += 1
                    residual = paddle.add(x=residual, y=y)

            residual = F.relu(residual)
            outs.append(residual)

        return outs

@register
class HRNet(nn.Layer):
    """
    HRNet, see https://arxiv.org/abs/1908.07919

    Args:
        width (int): base channel width of HRNet; one of 18, 30, 32, 40, 44,
            48, 60 or 64
        has_se (bool): whether to add an SE block after each stage
        freeze_at (int): the stage output to freeze (its gradient is stopped)
        freeze_norm (bool): whether to freeze the norm layers in HRNet
        norm_momentum (float): momentum of BatchNorm
        norm_decay (float): weight decay for normalization layer weights
        return_idx (List): indices of the stage outputs to return
        upsample (bool): whether to upsample and concat the backbone feats
            into a single stride-4 feature map
        downsample (bool): whether to run the downsampling head that merges
            the branches into a single stride-32 feature map
    """

    def __init__(self,
                 width=18,
                 has_se=False,
                 freeze_at=0,
                 freeze_norm=True,
                 norm_momentum=0.9,
                 norm_decay=0.,
                 return_idx=[0, 1, 2, 3],
                 upsample=False,
                 downsample=False):
        super(HRNet, self).__init__()

        self.width = width
        self.has_se = has_se
        if isinstance(return_idx, Integral):
            return_idx = [return_idx]
        assert len(return_idx) > 0, "need one or more return index"
        self.freeze_at = freeze_at
        self.return_idx = return_idx
        self.upsample = upsample
        self.downsample = downsample
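
        # branch widths for stages 2 to 4: each added branch halves the
        # spatial resolution and doubles the channels of the previous one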
        self.channels = {
            18: [[18, 36], [18, 36, 72], [18, 36, 72, 144]],
            30: [[30, 60], [30, 60, 120], [30, 60, 120, 240]],
            32: [[32, 64], [32, 64, 128], [32, 64, 128, 256]],
            40: [[40, 80], [40, 80, 160], [40, 80, 160, 320]],
            44: [[44, 88], [44, 88, 176], [44, 88, 176, 352]],
            48: [[48, 96], [48, 96, 192], [48, 96, 192, 384]],
            60: [[60, 120], [60, 120, 240], [60, 120, 240, 480]],
            64: [[64, 128], [64, 128, 256], [64, 128, 256, 512]]
        }

        channels_2, channels_3, channels_4 = self.channels[width]
        num_modules_2, num_modules_3, num_modules_4 = 1, 4, 3
        self._out_channels = [sum(channels_4)] if self.upsample else channels_4
        self._out_strides = [4] if self.upsample else [4, 8, 16, 32]

        self.conv_layer1_1 = ConvNormLayer(
            ch_in=3,
            ch_out=64,
            filter_size=3,
            stride=2,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act='relu',
            name="layer1_1")

        self.conv_layer1_2 = ConvNormLayer(
            ch_in=64,
            ch_out=64,
            filter_size=3,
            stride=2,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            act='relu',
            name="layer1_2")

        self.la1 = Layer1(
            num_channels=64,
            has_se=has_se,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="layer2")

        self.tr1 = TransitionLayer(
            in_channels=[256],
            out_channels=channels_2,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="tr1")

        self.st2 = Stage(
            num_channels=channels_2,
            num_modules=num_modules_2,
            num_filters=channels_2,
            has_se=self.has_se,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="st2")

        self.tr2 = TransitionLayer(
            in_channels=channels_2,
            out_channels=channels_3,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="tr2")

        self.st3 = Stage(
            num_channels=channels_3,
            num_modules=num_modules_3,
            num_filters=channels_3,
            has_se=self.has_se,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="st3")

        self.tr3 = TransitionLayer(
            in_channels=channels_3,
            out_channels=channels_4,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            name="tr3")

        self.st4 = Stage(
            num_channels=channels_4,
            num_modules=num_modules_4,
            num_filters=channels_4,
            has_se=self.has_se,
            norm_momentum=norm_momentum,
            norm_decay=norm_decay,
            freeze_norm=freeze_norm,
            multi_scale_output=len(return_idx) > 1,
            name="st4")

        if self.downsample:
            self.incre_modules, self.downsamp_modules, \
                self.final_layer = self._make_head(
                    channels_4,
                    norm_momentum=norm_momentum,
                    has_se=self.has_se)

    def _make_layer(self,
                    block,
                    inplanes,
                    planes,
                    blocks,
                    stride=1,
                    norm_momentum=0.9,
                    has_se=False,
                    name=None):
        downsample = None
        if stride != 1 or inplanes != planes * 4:
            downsample = True

        layers = []
        layers.append(
            block(
                inplanes,
                planes,
                has_se,
                stride,
                downsample,
                norm_momentum=norm_momentum,
                freeze_norm=False,
                name=name + "_s0"))
        inplanes = planes * 4
        for i in range(1, blocks):
            layers.append(
                block(
                    inplanes,
                    planes,
                    has_se,
                    norm_momentum=norm_momentum,
                    freeze_norm=False,
                    name=name + "_s" + str(i)))

        return nn.Sequential(*layers)

    def _make_head(self, pre_stage_channels, norm_momentum=0.9, has_se=False):
        head_block = BottleneckBlock
        head_channels = [32, 64, 128, 256]

        # Increasing the #channels on each resolution
        # from C, 2C, 4C, 8C to 128, 256, 512, 1024
        incre_modules = []
        for i, channels in enumerate(pre_stage_channels):
            incre_module = self._make_layer(
                head_block,
                channels,
                head_channels[i],
                1,
                stride=1,
                norm_momentum=norm_momentum,
                has_se=has_se,
                name='incre' + str(i))
            incre_modules.append(incre_module)
        incre_modules = nn.LayerList(incre_modules)

        # downsampling modules
        downsamp_modules = []
        for i in range(len(pre_stage_channels) - 1):
            in_channels = head_channels[i] * 4
            out_channels = head_channels[i + 1] * 4

            downsamp_module = nn.Sequential(
                nn.Conv2D(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=2,
                    padding=1),
                nn.BatchNorm2D(
                    out_channels, momentum=norm_momentum),
                nn.ReLU())
            downsamp_modules.append(downsamp_module)
        downsamp_modules = nn.LayerList(downsamp_modules)

        final_layer = nn.Sequential(
            nn.Conv2D(
                in_channels=head_channels[3] * 4,
                out_channels=2048,
                kernel_size=1,
                stride=1,
                padding=0),
            nn.BatchNorm2D(
                2048, momentum=norm_momentum),
            nn.ReLU())

        return incre_modules, downsamp_modules, final_layer

    def forward(self, inputs):
        x = inputs['image']
        conv1 = self.conv_layer1_1(x)
        conv2 = self.conv_layer1_2(conv1)

        la1 = self.la1(conv2)
        tr1 = self.tr1([la1])
        st2 = self.st2(tr1)
        tr2 = self.tr2(st2)
        st3 = self.st3(tr2)
        tr3 = self.tr3(st3)
        st4 = self.st4(tr3)

        if self.upsample:
            # Upsampling
            x0_h, x0_w = st4[0].shape[2:4]
            x1 = F.upsample(st4[1], size=(x0_h, x0_w), mode='bilinear')
            x2 = F.upsample(st4[2], size=(x0_h, x0_w), mode='bilinear')
            x3 = F.upsample(st4[3], size=(x0_h, x0_w), mode='bilinear')
            x = paddle.concat([st4[0], x1, x2, x3], 1)
            return x

        if self.downsample:
            y = self.incre_modules[0](st4[0])
            for i in range(len(self.downsamp_modules)):
                y = self.incre_modules[i + 1](st4[i + 1]) + \
                    self.downsamp_modules[i](y)
            y = self.final_layer(y)
            return y

        res = []
        for i, layer in enumerate(st4):
            if i == self.freeze_at:
                layer.stop_gradient = True
            if i in self.return_idx:
                res.append(layer)
        return res

    @property
    def out_shape(self):
        if self.upsample:
            self.return_idx = [0]
        return [
            ShapeSpec(
                channels=self._out_channels[i], stride=self._out_strides[i])
            for i in self.return_idx
        ]
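

# A minimal smoke-test sketch, not part of the upstream file. It assumes this
# module is importable inside ppdet (the relative shape_spec import must
# resolve); the 1x3x512x512 input and width=18 are arbitrary illustration
# choices. For width=18 this should print four maps at strides 4/8/16/32
# with 18/36/72/144 channels.
if __name__ == '__main__':
    model = HRNet(width=18, freeze_norm=False, return_idx=[0, 1, 2, 3])
    feats = model({'image': paddle.randn([1, 3, 512, 512])})
    for spec, feat in zip(model.out_shape, feats):
        print(spec.channels, spec.stride, feat.shape)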