# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import math
import copy

import paddle
import paddle.nn as nn
import paddle.optimizer as optimizer
import paddle.regularizer as regularizer

from ppdet.core.workspace import register, serializable
from ppdet.utils.logger import setup_logger

from .adamw import AdamWDL, build_adamwdl

__all__ = ['LearningRate', 'OptimizerBuilder']

logger = setup_logger(__name__)


@serializable
class CosineDecay(object):
    """
    Cosine learning rate decay

    Args:
        max_epochs (int): max epochs for the training process.
            If you combine cosine decay with warmup, it is recommended
            that max_iters is much larger than the number of warmup iters.
        use_warmup (bool): whether to use warmup. Default: True.
        min_lr_ratio (float): minimum learning rate ratio. Default: 0.
        last_plateau_epochs (int): use the minimum learning rate in
            the last few epochs. Default: 0.
    """

    def __init__(self,
                 max_epochs=1000,
                 use_warmup=True,
                 min_lr_ratio=0.,
                 last_plateau_epochs=0):
        self.max_epochs = max_epochs
        self.use_warmup = use_warmup
        self.min_lr_ratio = min_lr_ratio
        self.last_plateau_epochs = last_plateau_epochs

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        assert base_lr is not None, "either base LR or values should be provided"

        max_iters = self.max_epochs * int(step_per_epoch)
        last_plateau_iters = self.last_plateau_epochs * int(step_per_epoch)
        min_lr = base_lr * self.min_lr_ratio

        if boundary is not None and value is not None and self.use_warmup:
            # use warmup
            warmup_iters = len(boundary)
            for i in range(int(boundary[-1]), max_iters):
                boundary.append(i)
                if i < max_iters - last_plateau_iters:
                    decayed_lr = min_lr + (base_lr - min_lr) * 0.5 * (math.cos(
                        (i - warmup_iters) * math.pi /
                        (max_iters - warmup_iters - last_plateau_iters)) + 1)
                    value.append(decayed_lr)
                else:
                    value.append(min_lr)
            return optimizer.lr.PiecewiseDecay(boundary, value)
        elif last_plateau_iters > 0:
            # no warmup, but `last_plateau_epochs` > 0
            boundary = []
            value = []
            for i in range(max_iters):
                if i < max_iters - last_plateau_iters:
                    decayed_lr = min_lr + (base_lr - min_lr) * 0.5 * (math.cos(
                        i * math.pi / (max_iters - last_plateau_iters)) + 1)
                    value.append(decayed_lr)
                else:
                    value.append(min_lr)
                if i > 0:
                    boundary.append(i)
            return optimizer.lr.PiecewiseDecay(boundary, value)

        return optimizer.lr.CosineAnnealingDecay(
            base_lr, T_max=max_iters, eta_min=min_lr)
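

# Illustrative usage sketch: the values below are arbitrary and only show how
# CosineDecay behaves when no warmup boundary/value lists are supplied (it
# falls back to a plain CosineAnnealingDecay).
def _example_cosine_decay():
    sched = CosineDecay(max_epochs=300, use_warmup=False, min_lr_ratio=0.05)
    # max_iters = 300 * step_per_epoch, eta_min = base_lr * min_lr_ratio
    return sched(base_lr=0.01, step_per_epoch=500)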


@serializable
class PiecewiseDecay(object):
    """
    Multi-step learning rate decay

    Args:
        gamma (float | list): decay factor(s). A scalar is expanded to
            [gamma, gamma / 10, ...], one entry per milestone.
        milestones (list): epochs at which to decay the learning rate
        values (list|None): learning rate values set directly in the config;
            overrides `gamma` when provided. Default: None.
        use_warmup (bool): whether warmup boundaries/values are passed in.
            Default: True.
    """

    def __init__(self,
                 gamma=[0.1, 0.01],
                 milestones=[8, 11],
                 values=None,
                 use_warmup=True):
        super(PiecewiseDecay, self).__init__()
        if type(gamma) is not list:
            self.gamma = []
            for i in range(len(milestones)):
                self.gamma.append(gamma / 10**i)
        else:
            self.gamma = gamma
        self.milestones = milestones
        self.values = values
        self.use_warmup = use_warmup

    def __call__(self,
                 base_lr=None,
                 boundary=None,
                 value=None,
                 step_per_epoch=None):
        if boundary is not None and self.use_warmup:
            boundary.extend([int(step_per_epoch) * i for i in self.milestones])
        else:
            # do not use LinearWarmup
            boundary = [int(step_per_epoch) * i for i in self.milestones]
            value = [base_lr]  # lr during steps [0, boundary[0]) is base_lr

        # self.values is set directly in the config
        if self.values is not None:
            assert len(self.milestones) + 1 == len(self.values)
            return optimizer.lr.PiecewiseDecay(boundary, self.values)

        # otherwise value is computed from self.gamma
        value = value if value is not None else [base_lr]
        for i in self.gamma:
            value.append(base_lr * i)

        return optimizer.lr.PiecewiseDecay(boundary, value)
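

# Illustrative usage sketch: a standalone PiecewiseDecay (no warmup) with
# arbitrary example milestones; a scalar gamma is expanded per milestone.
def _example_piecewise_decay():
    sched = PiecewiseDecay(gamma=0.1, milestones=[8, 11], use_warmup=False)
    # gamma expands to [0.1, 0.01]: lr drops to base_lr * 0.1 at epoch 8
    # and to base_lr * 0.01 at epoch 11.
    return sched(base_lr=0.01, step_per_epoch=500)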


@serializable
class LinearWarmup(object):
    """
    Warm up learning rate linearly

    Args:
        steps (int): warm up steps
        start_factor (float): initial learning rate factor
        epochs (int|None): use epochs as warm up steps, the priority
            of `epochs` is higher than `steps`. Default: None.
    """

    def __init__(self, steps=500, start_factor=1. / 3, epochs=None):
        super(LinearWarmup, self).__init__()
        self.steps = steps
        self.start_factor = start_factor
        self.epochs = epochs

    def __call__(self, base_lr, step_per_epoch):
        boundary = []
        value = []
        warmup_steps = self.epochs * step_per_epoch \
            if self.epochs is not None else self.steps
        warmup_steps = max(warmup_steps, 1)
        for i in range(warmup_steps + 1):
            if warmup_steps > 0:
                alpha = i / warmup_steps
                factor = self.start_factor * (1 - alpha) + alpha
                lr = base_lr * factor
                value.append(lr)
            if i > 0:
                boundary.append(i)
        return boundary, value
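

# Illustrative usage sketch: LinearWarmup returns raw (boundary, value) lists
# that a decay scheduler extends; the numbers below are arbitrary examples.
def _example_linear_warmup():
    warmup = LinearWarmup(steps=500, start_factor=1. / 3)
    # value ramps linearly from base_lr / 3 to base_lr over 500 steps;
    # boundary holds the matching step indices 1..500.
    return warmup(base_lr=0.01, step_per_epoch=500)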


@serializable
class ExpWarmup(object):
    """
    Warm up learning rate in exponential mode

    Args:
        steps (int): warm up steps.
        epochs (int|None): use epochs as warm up steps, the priority
            of `epochs` is higher than `steps`. Default: None.
        power (int): exponential coefficient. Default: 2.
    """

    def __init__(self, steps=1000, epochs=None, power=2):
        super(ExpWarmup, self).__init__()
        self.steps = steps
        self.epochs = epochs
        self.power = power

    def __call__(self, base_lr, step_per_epoch):
        boundary = []
        value = []
        warmup_steps = self.epochs * step_per_epoch if self.epochs is not None else self.steps
        warmup_steps = max(warmup_steps, 1)
        for i in range(warmup_steps + 1):
            factor = (i / float(warmup_steps))**self.power
            value.append(base_lr * factor)
            if i > 0:
                boundary.append(i)
        return boundary, value
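

# Illustrative usage sketch: ExpWarmup with arbitrary example settings; the
# ramp follows base_lr * (i / warmup_steps) ** power.
def _example_exp_warmup():
    warmup = ExpWarmup(steps=1000, power=2)
    # value[i] == 0.01 * (i / 1000.) ** 2, reaching base_lr at step 1000.
    return warmup(base_lr=0.01, step_per_epoch=500)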


@register
class LearningRate(object):
    """
    Learning Rate configuration

    Args:
        base_lr (float): base learning rate
        schedulers (list): learning rate schedulers
    """
    __category__ = 'optim'

    def __init__(self,
                 base_lr=0.01,
                 schedulers=[PiecewiseDecay(), LinearWarmup()]):
        super(LearningRate, self).__init__()
        self.base_lr = base_lr
        self.schedulers = []

        schedulers = copy.deepcopy(schedulers)
        for sched in schedulers:
            if isinstance(sched, dict):
                # support instantiating a scheduler from a dict config
                module = sys.modules[__name__]
                sched_type = sched.pop("name")
                scheduler = getattr(module, sched_type)(**sched)
                self.schedulers.append(scheduler)
            else:
                self.schedulers.append(sched)

    def __call__(self, step_per_epoch):
        assert len(self.schedulers) >= 1
        if not self.schedulers[0].use_warmup:
            return self.schedulers[0](base_lr=self.base_lr,
                                      step_per_epoch=step_per_epoch)

        # TODO: split warmup & decay
        # warmup
        boundary, value = self.schedulers[1](self.base_lr, step_per_epoch)
        # decay
        decay_lr = self.schedulers[0](self.base_lr, boundary, value,
                                      step_per_epoch)
        return decay_lr
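

# Illustrative usage sketch: LearningRate pairs a decay scheduler (index 0)
# with a warmup scheduler (index 1); the values below are arbitrary examples,
# not taken from any shipped config.
def _example_learning_rate():
    lr_cfg = LearningRate(
        base_lr=0.01,
        schedulers=[PiecewiseDecay(milestones=[8, 11]),
                    LinearWarmup(steps=500)])
    # step_per_epoch normally comes from the length of the train dataloader.
    return lr_cfg(step_per_epoch=500)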


@register
class OptimizerBuilder():
    """
    Build optimizer handles

    Args:
        clip_grad_by_norm (float|None): clip gradients by a global norm
        clip_grad_by_value (float|None): clip gradients by an absolute value
        regularizer (object): a `Regularizer` instance
        optimizer (object): an `Optimizer` instance
    """
    __category__ = 'optim'

    def __init__(self,
                 clip_grad_by_norm=None,
                 clip_grad_by_value=None,
                 regularizer={'type': 'L2',
                              'factor': .0001},
                 optimizer={'type': 'Momentum',
                            'momentum': .9}):
        self.clip_grad_by_norm = clip_grad_by_norm
        self.clip_grad_by_value = clip_grad_by_value
        self.regularizer = regularizer
        self.optimizer = optimizer

    def __call__(self, learning_rate, model=None):
        # gradient clipping
        if self.clip_grad_by_norm is not None:
            grad_clip = nn.ClipGradByGlobalNorm(
                clip_norm=self.clip_grad_by_norm)
        elif self.clip_grad_by_value is not None:
            var = abs(self.clip_grad_by_value)
            grad_clip = nn.ClipGradByValue(min=-var, max=var)
        else:
            grad_clip = None

        # weight decay / regularization
        if self.regularizer and self.regularizer != 'None':
            reg_type = self.regularizer['type'] + 'Decay'
            reg_factor = self.regularizer['factor']
            regularization = getattr(regularizer, reg_type)(reg_factor)
        else:
            regularization = None

        optim_args = self.optimizer.copy()
        optim_type = optim_args['type']
        del optim_args['type']

        if optim_type == 'AdamWDL':
            return build_adamwdl(model, lr=learning_rate, **optim_args)

        if optim_type != 'AdamW':
            optim_args['weight_decay'] = regularization

        op = getattr(optimizer, optim_type)

        if 'param_groups' in optim_args:
            assert isinstance(optim_args['param_groups'], list), \
                '`param_groups` should be a list of dicts'
            param_groups = optim_args.pop('param_groups')

            params, visited = [], []
            for group in param_groups:
                assert isinstance(group, dict) and 'params' in group and \
                    isinstance(group['params'], list), \
                    'each group needs a `params` list of name patterns'
                _params = {
                    n: p
                    for n, p in model.named_parameters()
                    if any([k in n for k in group['params']]) and
                    p.trainable is True
                }
                _group = group.copy()
                _group.update({'params': list(_params.values())})
                params.append(_group)
                visited.extend(list(_params.keys()))

            ext_params = [
                p for n, p in model.named_parameters()
                if n not in visited and p.trainable is True
            ]
            if len(ext_params) < len(model.parameters()):
                params.append({'params': ext_params})
            elif len(ext_params) > len(model.parameters()):
                raise RuntimeError
        else:
            _params = model.parameters()
            params = [param for param in _params if param.trainable is True]

        return op(learning_rate=learning_rate,
                  parameters=params,
                  grad_clip=grad_clip,
                  **optim_args)
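

# Illustrative usage sketch: OptimizerBuilder is normally instantiated from a
# YAML config; the dicts below mirror its defaults, and `model` / `lr_sched`
# are assumed to be provided by the trainer.
def _example_optimizer_builder(model, lr_sched):
    builder = OptimizerBuilder(
        clip_grad_by_norm=35.,
        regularizer={'type': 'L2', 'factor': 0.0001},
        optimizer={'type': 'Momentum', 'momentum': 0.9})
    # Yields paddle.optimizer.Momentum with L2 weight decay and global-norm
    # gradient clipping over the model's trainable parameters.
    return builder(learning_rate=lr_sched, model=model)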