fpn.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231
  1. # Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import paddle.nn as nn
  15. import paddle.nn.functional as F
  16. from paddle import ParamAttr
  17. from paddle.nn.initializer import XavierUniform
  18. from ppdet.core.workspace import register, serializable
  19. from ppdet.modeling.layers import ConvNormLayer
  20. from ..shape_spec import ShapeSpec
# Names exported by a star-import of this module.
__all__ = ['FPN']
  22. @register
  23. @serializable
  24. class FPN(nn.Layer):
  25. """
  26. Feature Pyramid Network, see https://arxiv.org/abs/1612.03144
  27. Args:
  28. in_channels (list[int]): input channels of each level which can be
  29. derived from the output shape of backbone by from_config
  30. out_channel (int): output channel of each level
  31. spatial_scales (list[float]): the spatial scales between input feature
  32. maps and original input image which can be derived from the output
  33. shape of backbone by from_config
  34. has_extra_convs (bool): whether to add extra conv to the last level.
  35. default False
  36. extra_stage (int): the number of extra stages added to the last level.
  37. default 1
  38. use_c5 (bool): Whether to use c5 as the input of extra stage,
  39. otherwise p5 is used. default True
  40. norm_type (string|None): The normalization type in FPN module. If
  41. norm_type is None, norm will not be used after conv and if
  42. norm_type is string, bn, gn, sync_bn are available. default None
  43. norm_decay (float): weight decay for normalization layer weights.
  44. default 0.
  45. freeze_norm (bool): whether to freeze normalization layer.
  46. default False
  47. relu_before_extra_convs (bool): whether to add relu before extra convs.
  48. default False
  49. """
  50. def __init__(self,
  51. in_channels,
  52. out_channel,
  53. spatial_scales=[0.25, 0.125, 0.0625, 0.03125],
  54. has_extra_convs=False,
  55. extra_stage=1,
  56. use_c5=True,
  57. norm_type=None,
  58. norm_decay=0.,
  59. freeze_norm=False,
  60. relu_before_extra_convs=True):
  61. super(FPN, self).__init__()
  62. self.out_channel = out_channel
  63. for s in range(extra_stage):
  64. spatial_scales = spatial_scales + [spatial_scales[-1] / 2.]
  65. self.spatial_scales = spatial_scales
  66. self.has_extra_convs = has_extra_convs
  67. self.extra_stage = extra_stage
  68. self.use_c5 = use_c5
  69. self.relu_before_extra_convs = relu_before_extra_convs
  70. self.norm_type = norm_type
  71. self.norm_decay = norm_decay
  72. self.freeze_norm = freeze_norm
  73. self.lateral_convs = []
  74. self.fpn_convs = []
  75. fan = out_channel * 3 * 3
  76. # stage index 0,1,2,3 stands for res2,res3,res4,res5 on ResNet Backbone
  77. # 0 <= st_stage < ed_stage <= 3
  78. st_stage = 4 - len(in_channels)
  79. ed_stage = st_stage + len(in_channels) - 1
  80. for i in range(st_stage, ed_stage + 1):
  81. if i == 3:
  82. lateral_name = 'fpn_inner_res5_sum'
  83. else:
  84. lateral_name = 'fpn_inner_res{}_sum_lateral'.format(i + 2)
  85. in_c = in_channels[i - st_stage]
  86. if self.norm_type is not None:
  87. lateral = self.add_sublayer(
  88. lateral_name,
  89. ConvNormLayer(
  90. ch_in=in_c,
  91. ch_out=out_channel,
  92. filter_size=1,
  93. stride=1,
  94. norm_type=self.norm_type,
  95. norm_decay=self.norm_decay,
  96. freeze_norm=self.freeze_norm,
  97. initializer=XavierUniform(fan_out=in_c)))
  98. else:
  99. lateral = self.add_sublayer(
  100. lateral_name,
  101. nn.Conv2D(
  102. in_channels=in_c,
  103. out_channels=out_channel,
  104. kernel_size=1,
  105. weight_attr=ParamAttr(
  106. initializer=XavierUniform(fan_out=in_c))))
  107. self.lateral_convs.append(lateral)
  108. fpn_name = 'fpn_res{}_sum'.format(i + 2)
  109. if self.norm_type is not None:
  110. fpn_conv = self.add_sublayer(
  111. fpn_name,
  112. ConvNormLayer(
  113. ch_in=out_channel,
  114. ch_out=out_channel,
  115. filter_size=3,
  116. stride=1,
  117. norm_type=self.norm_type,
  118. norm_decay=self.norm_decay,
  119. freeze_norm=self.freeze_norm,
  120. initializer=XavierUniform(fan_out=fan)))
  121. else:
  122. fpn_conv = self.add_sublayer(
  123. fpn_name,
  124. nn.Conv2D(
  125. in_channels=out_channel,
  126. out_channels=out_channel,
  127. kernel_size=3,
  128. padding=1,
  129. weight_attr=ParamAttr(
  130. initializer=XavierUniform(fan_out=fan))))
  131. self.fpn_convs.append(fpn_conv)
  132. # add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
  133. if self.has_extra_convs:
  134. for i in range(self.extra_stage):
  135. lvl = ed_stage + 1 + i
  136. if i == 0 and self.use_c5:
  137. in_c = in_channels[-1]
  138. else:
  139. in_c = out_channel
  140. extra_fpn_name = 'fpn_{}'.format(lvl + 2)
  141. if self.norm_type is not None:
  142. extra_fpn_conv = self.add_sublayer(
  143. extra_fpn_name,
  144. ConvNormLayer(
  145. ch_in=in_c,
  146. ch_out=out_channel,
  147. filter_size=3,
  148. stride=2,
  149. norm_type=self.norm_type,
  150. norm_decay=self.norm_decay,
  151. freeze_norm=self.freeze_norm,
  152. initializer=XavierUniform(fan_out=fan)))
  153. else:
  154. extra_fpn_conv = self.add_sublayer(
  155. extra_fpn_name,
  156. nn.Conv2D(
  157. in_channels=in_c,
  158. out_channels=out_channel,
  159. kernel_size=3,
  160. stride=2,
  161. padding=1,
  162. weight_attr=ParamAttr(
  163. initializer=XavierUniform(fan_out=fan))))
  164. self.fpn_convs.append(extra_fpn_conv)
  165. @classmethod
  166. def from_config(cls, cfg, input_shape):
  167. return {
  168. 'in_channels': [i.channels for i in input_shape],
  169. 'spatial_scales': [1.0 / i.stride for i in input_shape],
  170. }
  171. def forward(self, body_feats):
  172. laterals = []
  173. num_levels = len(body_feats)
  174. for i in range(num_levels):
  175. laterals.append(self.lateral_convs[i](body_feats[i]))
  176. for i in range(1, num_levels):
  177. lvl = num_levels - i
  178. upsample = F.interpolate(
  179. laterals[lvl],
  180. scale_factor=2.,
  181. mode='nearest', )
  182. laterals[lvl - 1] += upsample
  183. fpn_output = []
  184. for lvl in range(num_levels):
  185. fpn_output.append(self.fpn_convs[lvl](laterals[lvl]))
  186. if self.extra_stage > 0:
  187. # use max pool to get more levels on top of outputs (Faster R-CNN, Mask R-CNN)
  188. if not self.has_extra_convs:
  189. assert self.extra_stage == 1, 'extra_stage should be 1 if FPN has not extra convs'
  190. fpn_output.append(F.max_pool2d(fpn_output[-1], 1, stride=2))
  191. # add extra conv levels for RetinaNet(use_c5)/FCOS(use_p5)
  192. else:
  193. if self.use_c5:
  194. extra_source = body_feats[-1]
  195. else:
  196. extra_source = fpn_output[-1]
  197. fpn_output.append(self.fpn_convs[num_levels](extra_source))
  198. for i in range(1, self.extra_stage):
  199. if self.relu_before_extra_convs:
  200. fpn_output.append(self.fpn_convs[num_levels + i](F.relu(
  201. fpn_output[-1])))
  202. else:
  203. fpn_output.append(self.fpn_convs[num_levels + i](
  204. fpn_output[-1]))
  205. return fpn_output
  206. @property
  207. def out_shape(self):
  208. return [
  209. ShapeSpec(
  210. channels=self.out_channel, stride=1. / s)
  211. for s in self.spatial_scales
  212. ]