# initializer.py

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. """
  15. This code is based on https://github.com/pytorch/pytorch/blob/master/torch/nn/init.py
  16. Ths copyright of pytorch/pytorch is a BSD-style license, as found in the LICENSE file.
  17. """
  18. import math
  19. import numpy as np
  20. import paddle
  21. import paddle.nn as nn

__all__ = [
    'uniform_',
    'normal_',
    'constant_',
    'ones_',
    'zeros_',
    'xavier_uniform_',
    'xavier_normal_',
    'kaiming_uniform_',
    'kaiming_normal_',
    'linear_init_',
    'conv_init_',
    'reset_initialized_parameter',
]


def _no_grad_uniform_(tensor, a, b):
    with paddle.no_grad():
        tensor.set_value(
            paddle.uniform(
                shape=tensor.shape, dtype=tensor.dtype, min=a, max=b))
    return tensor


def _no_grad_normal_(tensor, mean=0., std=1.):
    with paddle.no_grad():
        tensor.set_value(paddle.normal(mean=mean, std=std, shape=tensor.shape))
    return tensor


def _no_grad_fill_(tensor, value=0.):
    with paddle.no_grad():
        tensor.set_value(paddle.full_like(tensor, value, dtype=tensor.dtype))
    return tensor


def uniform_(tensor, a, b):
    """
    Modify tensor in place using a uniform distribution U(a, b).
    Args:
        tensor (paddle.Tensor): paddle Tensor
        a (float|int): min value.
        b (float|int): max value.
    Return:
        tensor
    """
    return _no_grad_uniform_(tensor, a, b)
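
# Illustrative usage (not from the original file; the layer sizes are
# arbitrary). Note that a Paddle nn.Linear weight has shape [in, out]:
#
#   fc = nn.Linear(4, 6)
#   uniform_(fc.weight, -0.1, 0.1)  # every entry now lies in [-0.1, 0.1]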


def normal_(tensor, mean=0., std=1.):
    """
    Modify tensor in place using a normal distribution N(mean, std).
    Args:
        tensor (paddle.Tensor): paddle Tensor
        mean (float|int): mean value.
        std (float|int): std value.
    Return:
        tensor
    """
    return _no_grad_normal_(tensor, mean, std)


def constant_(tensor, value=0.):
    """
    Modify tensor in place by filling it with a constant value.
    Args:
        tensor (paddle.Tensor): paddle Tensor
        value (float|int): value to fill tensor.
    Return:
        tensor
    """
    return _no_grad_fill_(tensor, value)


def ones_(tensor):
    """
    Modify tensor in place by filling it with ones.
    Args:
        tensor (paddle.Tensor): paddle Tensor
    Return:
        tensor
    """
    return _no_grad_fill_(tensor, 1)


def zeros_(tensor):
    """
    Modify tensor in place by filling it with zeros.
    Args:
        tensor (paddle.Tensor): paddle Tensor
    Return:
        tensor
    """
    return _no_grad_fill_(tensor, 0)
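
# Illustrative usage (not from the original file): a common pattern is small
# random weights and zero bias, e.g.
#
#   fc = nn.Linear(4, 6)
#   normal_(fc.weight, mean=0., std=0.02)
#   zeros_(fc.bias)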


def vector_(tensor, vector):
    """Modify tensor in place by copying values from `vector` (shape must match)."""
    with paddle.no_grad():
        tensor.set_value(paddle.to_tensor(vector, dtype=tensor.dtype))
    return tensor
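
# Illustrative usage (not from the original file):
#
#   bias = paddle.zeros([3])
#   vector_(bias, [0.1, 0.2, 0.3])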


def _calculate_fan_in_and_fan_out(tensor, reverse=False):
    """
    Calculate (fan_in, fan_out) for the given tensor.
    Args:
        tensor (paddle.Tensor): paddle Tensor
        reverse (bool): tensor data format order, False by default as
            [fout, fin, ...], e.g. conv.weight [cout, cin, kh, kw] uses False;
            linear.weight [cin, cout] uses True.
    Return:
        Tuple[fan_in, fan_out]
    """
    if tensor.ndim < 2:
        raise ValueError(
            "Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
        )

    if reverse:
        num_input_fmaps, num_output_fmaps = tensor.shape[0], tensor.shape[1]
    else:
        num_input_fmaps, num_output_fmaps = tensor.shape[1], tensor.shape[0]

    receptive_field_size = 1
    if tensor.ndim > 2:
        receptive_field_size = np.prod(tensor.shape[2:])

    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out
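
# Worked example (illustrative, not from the original file): for a conv weight
# of shape [64, 3, 3, 3] with reverse=False, fan_in = 3 * 3 * 3 = 27 and
# fan_out = 64 * 3 * 3 = 576.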


def xavier_uniform_(tensor, gain=1., reverse=False):
    """
    Modify tensor in place using the Xavier (Glorot) uniform method.
    Args:
        tensor (paddle.Tensor): paddle Tensor
        gain (float): scaling factor, 1. by default.
        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
    Return:
        tensor
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    k = math.sqrt(3.0) * std
    return _no_grad_uniform_(tensor, -k, k)
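
# The bound k above expands to gain * sqrt(6 / (fan_in + fan_out)), the
# standard Glorot uniform bound. Illustrative usage (not from the original
# file):
#
#   conv = nn.Conv2D(3, 64, 3)
#   xavier_uniform_(conv.weight)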


def xavier_normal_(tensor, gain=1., reverse=False):
    """
    Modify tensor in place using the Xavier (Glorot) normal method.
    Args:
        tensor (paddle.Tensor): paddle Tensor
        gain (float): scaling factor, 1. by default.
        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
    Return:
        tensor
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse=reverse)
    std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    return _no_grad_normal_(tensor, 0, std)


# reference: https://pytorch.org/docs/stable/_modules/torch/nn/init.html
def _calculate_correct_fan(tensor, mode, reverse=False):
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(
            mode, valid_modes))

    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor, reverse)
    return fan_in if mode == 'fan_in' else fan_out


def _calculate_gain(nonlinearity, param=None):
    linear_fns = [
        'linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d',
        'conv_transpose2d', 'conv_transpose3d'
    ]
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif not isinstance(param, bool) and isinstance(
                param, int) or isinstance(param, float):
            # True/False are instances of int, hence check above
            negative_slope = param
        else:
            raise ValueError("negative_slope {} not a valid number".format(
                param))
        return math.sqrt(2.0 / (1 + negative_slope**2))
    elif nonlinearity == 'selu':
        return 3.0 / 4
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
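
# Illustrative gain values (not from the original file): linear/sigmoid -> 1,
# tanh -> 5/3, relu -> sqrt(2) ~= 1.414, leaky_relu(0.01) ->
# sqrt(2 / 1.0001) ~= 1.414, selu -> 0.75.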


def kaiming_uniform_(tensor,
                     a=0,
                     mode='fan_in',
                     nonlinearity='leaky_relu',
                     reverse=False):
    """
    Modify tensor in place using the Kaiming uniform method.
    Args:
        tensor (paddle.Tensor): paddle Tensor
        a (float): the negative slope used with 'leaky_relu', 0 by default.
        mode (str): ['fan_in', 'fan_out'], 'fan_in' by default
        nonlinearity (str): nonlinearity method name
        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
    Return:
        tensor
    """
    fan = _calculate_correct_fan(tensor, mode, reverse)
    gain = _calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    k = math.sqrt(3.0) * std
    return _no_grad_uniform_(tensor, -k, k)
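
# The bound k above expands to gain * sqrt(3 / fan), the standard Kaiming
# uniform bound. Illustrative usage (not from the original file):
#
#   conv = nn.Conv2D(3, 64, 3)
#   kaiming_uniform_(conv.weight, nonlinearity='relu')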


def kaiming_normal_(tensor,
                    a=0,
                    mode='fan_in',
                    nonlinearity='leaky_relu',
                    reverse=False):
    """
    Modify tensor in place using the Kaiming normal method.
    Args:
        tensor (paddle.Tensor): paddle Tensor
        a (float): the negative slope used with 'leaky_relu', 0 by default.
        mode (str): ['fan_in', 'fan_out'], 'fan_in' by default
        nonlinearity (str): nonlinearity method name
        reverse (bool): tensor data format order, False by default as [fout, fin, ...].
    Return:
        tensor
    """
    fan = _calculate_correct_fan(tensor, mode, reverse)
    gain = _calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    return _no_grad_normal_(tensor, 0, std)


def linear_init_(module):
    # In Paddle, linear.weight has shape [in_features, out_features], so
    # weight.shape[0] is the fan-in.
    bound = 1 / math.sqrt(module.weight.shape[0])
    uniform_(module.weight, -bound, bound)
    uniform_(module.bias, -bound, bound)


def conv_init_(module):
    # conv.weight has shape [cout, cin, kh, kw]; shape[1:] gives the fan-in.
    bound = 1 / np.sqrt(np.prod(module.weight.shape[1:]))
    uniform_(module.weight, -bound, bound)
    if module.bias is not None:
        uniform_(module.bias, -bound, bound)
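
# Illustrative usage (not from the original file):
#
#   fc = nn.Linear(16, 4)
#   linear_init_(fc)       # bound = 1 / sqrt(16) = 0.25
#   conv = nn.Conv2D(3, 8, 3)
#   conv_init_(conv)       # bound = 1 / sqrt(3 * 3 * 3) ~= 0.192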


def bias_init_with_prob(prior_prob=0.01):
    """initialize conv/fc bias value according to a given probability value."""
    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
    return bias_init
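
# Worked example (illustrative, not from the original file): for
# prior_prob=0.01, bias_init = -log(99) ~= -4.595, so that
# sigmoid(bias_init) ~= 0.01. This kind of init is commonly used for
# focal-loss classification heads.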


@paddle.no_grad()
def reset_initialized_parameter(model, include_self=True):
    """
    Reset the initialized parameters using the default scheme for [conv, linear, embedding, bn]
    Args:
        model (paddle.nn.Layer): paddle Layer
        include_self (bool: True): include_self for the Layer.named_sublayers method. Indicates whether to include the layer itself
    Return:
        None
    """
    for _, m in model.named_sublayers(include_self=include_self):
        if isinstance(m, nn.Conv2D):
            k = float(m._groups) / (m._in_channels * m._kernel_size[0] *
                                    m._kernel_size[1])
            k = math.sqrt(k)
            _no_grad_uniform_(m.weight, -k, k)
            if hasattr(m, 'bias') and getattr(m, 'bias') is not None:
                _no_grad_uniform_(m.bias, -k, k)

        elif isinstance(m, nn.Linear):
            k = math.sqrt(1. / m.weight.shape[0])
            _no_grad_uniform_(m.weight, -k, k)
            if hasattr(m, 'bias') and getattr(m, 'bias') is not None:
                _no_grad_uniform_(m.bias, -k, k)

        elif isinstance(m, nn.Embedding):
            _no_grad_normal_(m.weight, mean=0., std=1.)

        elif isinstance(m, (nn.BatchNorm2D, nn.LayerNorm)):
            _no_grad_fill_(m.weight, 1.)
            if hasattr(m, 'bias') and getattr(m, 'bias') is not None:
                _no_grad_fill_(m.bias, 0)
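

# A minimal, runnable usage sketch (not part of the original file; the model
# and layer sizes are arbitrary illustrative choices):
if __name__ == '__main__':
    model = nn.Sequential(
        nn.Conv2D(3, 8, 3), nn.BatchNorm2D(8), nn.ReLU())
    # Re-initialize every conv/bn sublayer with the default scheme above.
    reset_initialized_parameter(model)

    fc = nn.Linear(16, 4)
    linear_init_(fc)
    # For a Paddle linear weight of shape [in, out], pass reverse=True so the
    # fan-in is read from shape[0].
    kaiming_normal_(fc.weight, nonlinearity='relu', reverse=True)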