senet.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle.nn as nn
  15. from ppdet.core.workspace import register, serializable
  16. from .resnet import ResNet, Blocks, BasicBlock, BottleNeck
  17. from ..shape_spec import ShapeSpec
  18. from .name_adapter import NameAdapter
  19. __all__ = ['SENet', 'SERes5Head']
  20. @register
  21. @serializable
  22. class SENet(ResNet):
  23. __shared__ = ['norm_type']
  24. def __init__(self,
  25. depth=50,
  26. variant='b',
  27. lr_mult_list=[1.0, 1.0, 1.0, 1.0],
  28. groups=1,
  29. base_width=64,
  30. norm_type='bn',
  31. norm_decay=0,
  32. freeze_norm=True,
  33. freeze_at=0,
  34. return_idx=[0, 1, 2, 3],
  35. dcn_v2_stages=[-1],
  36. std_senet=True,
  37. num_stages=4):
  38. """
  39. Squeeze-and-Excitation Networks, see https://arxiv.org/abs/1709.01507
  40. Args:
  41. depth (int): SENet depth, should be 50, 101, 152
  42. variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
  43. lr_mult_list (list): learning rate ratio of different resnet stages(2,3,4,5),
  44. lower learning rate ratio is need for pretrained model
  45. got using distillation(default as [1.0, 1.0, 1.0, 1.0]).
  46. groups (int): group convolution cardinality
  47. base_width (int): base width of each group convolution
  48. norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
  49. norm_decay (float): weight decay for normalization layer weights
  50. freeze_norm (bool): freeze normalization layers
  51. freeze_at (int): freeze the backbone at which stage
  52. return_idx (list): index of the stages whose feature maps are returned
  53. dcn_v2_stages (list): index of stages who select deformable conv v2
  54. std_senet (bool): whether use senet, default True
  55. num_stages (int): total num of stages
  56. """
  57. super(SENet, self).__init__(
  58. depth=depth,
  59. variant=variant,
  60. lr_mult_list=lr_mult_list,
  61. ch_in=128,
  62. groups=groups,
  63. base_width=base_width,
  64. norm_type=norm_type,
  65. norm_decay=norm_decay,
  66. freeze_norm=freeze_norm,
  67. freeze_at=freeze_at,
  68. return_idx=return_idx,
  69. dcn_v2_stages=dcn_v2_stages,
  70. std_senet=std_senet,
  71. num_stages=num_stages)
  72. @register
  73. class SERes5Head(nn.Layer):
  74. def __init__(self,
  75. depth=50,
  76. variant='b',
  77. lr_mult=1.0,
  78. groups=1,
  79. base_width=64,
  80. norm_type='bn',
  81. norm_decay=0,
  82. dcn_v2=False,
  83. freeze_norm=False,
  84. std_senet=True):
  85. """
  86. SERes5Head layer
  87. Args:
  88. depth (int): SENet depth, should be 50, 101, 152
  89. variant (str): ResNet variant, supports 'a', 'b', 'c', 'd' currently
  90. lr_mult (list): learning rate ratio of SERes5Head, default as 1.0.
  91. groups (int): group convolution cardinality
  92. base_width (int): base width of each group convolution
  93. norm_type (str): normalization type, 'bn', 'sync_bn' or 'affine_channel'
  94. norm_decay (float): weight decay for normalization layer weights
  95. dcn_v2_stages (list): index of stages who select deformable conv v2
  96. std_senet (bool): whether use senet, default True
  97. """
  98. super(SERes5Head, self).__init__()
  99. ch_out = 512
  100. ch_in = 256 if depth < 50 else 1024
  101. na = NameAdapter(self)
  102. block = BottleNeck if depth >= 50 else BasicBlock
  103. self.res5 = Blocks(
  104. block,
  105. ch_in,
  106. ch_out,
  107. count=3,
  108. name_adapter=na,
  109. stage_num=5,
  110. variant=variant,
  111. groups=groups,
  112. base_width=base_width,
  113. lr=lr_mult,
  114. norm_type=norm_type,
  115. norm_decay=norm_decay,
  116. freeze_norm=freeze_norm,
  117. dcn_v2=dcn_v2,
  118. std_senet=std_senet)
  119. self.ch_out = ch_out * block.expansion
  120. @property
  121. def out_shape(self):
  122. return [ShapeSpec(
  123. channels=self.ch_out,
  124. stride=16, )]
  125. def forward(self, roi_feat):
  126. y = self.res5(roi_feat)
  127. return y