from typing import cast, List, Optional, Dict
import torch
from torch import Tensor
from .optimizer import (Optimizer, _use_grad_for_differentiable, _get_value,
                        _stack_if_compiling, _dispatch_sqrt, _default_to_fused_or_foreach,
                        _capturable_doc, _differentiable_doc, _foreach_doc, _maximize_doc)
from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype

__all__ = ['Adam', 'adam']


# TODO(crcrpar): Move this somewhere else (e.g. torch/optim/_utils?) when adding another fused optimizer.
# NOTE(crcrpar): Almost the same as `_MultiDeviceReplicator` defined in
# torch/cuda/amp/grad_scaler.py except for the key being str only for torch script.
class _MultiDeviceReplicator:
    main_tensor: Tensor
    _per_device_tensors: Dict[str, Tensor]

    def __init__(self, main_tensor: Tensor) -> None:
        self.main_tensor = main_tensor
        self._per_device_tensors = {str(main_tensor.device): main_tensor}

    def get(self, device: str):
        if device in self._per_device_tensors:
            return self._per_device_tensors[device]
        tensor = self.main_tensor.to(device=device, non_blocking=True, copy=True)
        self._per_device_tensors[device] = tensor
        return tensor


# TODO(crcrpar): Move this to another place when adding another fused optimizer.
def _get_fp16AMP_params(
    *,
    optimizer: Optimizer,
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    device: torch.device,
) -> Optional[_MultiDeviceReplicator]:
    if grad_scaler is None:
        return None
    found_inf_dict = grad_scaler._check_inf_per_device(optimizer)
    # Combines found_inf tensors from all devices. As in GradScaler.update(),
    # tensors are combined on the scale's device, which is an arbitrary but
    # reasonable choice that avoids new context creation.
    found_infs = [f.to(device, non_blocking=True) for f in found_inf_dict.values()]
    assert len(found_infs) > 0, "No inf checks were recorded in _check_inf_per_device."
    with torch.no_grad():
        found_inf_combined = cast(torch.Tensor, sum(found_infs))
    return _MultiDeviceReplicator(found_inf_combined)


class Adam(Optimizer):
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, amsgrad=False, *, foreach: Optional[bool] = None,
                 maximize: bool = False, capturable: bool = False,
                 differentiable: bool = False, fused: Optional[bool] = None):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(lr=lr, betas=betas, eps=eps,
                        weight_decay=weight_decay, amsgrad=amsgrad,
                        maximize=maximize, foreach=foreach, capturable=capturable,
                        differentiable=differentiable, fused=fused)
        super(Adam, self).__init__(params, defaults)

        if fused:
            if differentiable:
                raise RuntimeError("`fused` does not support `differentiable`")
            self._step_supports_amp_scaling = True
            # TODO(crcrpar): [low prec params & their higher prec copy]
            # Support AMP with FP16/BF16 model params, which would need a
            # higher prec copy of params to do update math in higher prec to
            # alleviate the loss of information.
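            # The check below enforces the fused kernel's requirement that every
            # parameter be a CUDA floating-point tensor, so misuse fails here at
            # construction time rather than inside the first step().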
            if not all(
                p.is_cuda and torch.is_floating_point(p)
                for pg in self.param_groups for p in pg['params']
            ):
                raise RuntimeError("FusedAdam requires all the params to be CUDA, floating point")

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('amsgrad', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('capturable', False)
            group.setdefault('differentiable', False)
            group.setdefault('fused', None)
        state_values = list(self.state.values())
        step_is_tensor = (len(state_values) != 0) and torch.is_tensor(state_values[0]['step'])
        if not step_is_tensor:
            for s in state_values:
                s['step'] = torch.tensor(float(s['step']))

    def _init_group(
        self,
        group,
        grad_scaler,
        params_with_grad,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps
    ):
        grad_scale = None
        found_inf = None
        if group['fused'] and grad_scaler is not None:
            grad_scale = grad_scaler._get_scale_async()
            device = grad_scale.device
            grad_scale = _MultiDeviceReplicator(grad_scale)
            found_inf = _get_fp16AMP_params(optimizer=self, grad_scaler=grad_scaler, device=device)
        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                if p.grad.is_sparse:
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
                grads.append(p.grad)

                state = self.state[p]
                # Lazy state initialization
                if len(state) == 0:
                    state['step'] = (
                        torch.zeros((1,), dtype=torch.float, device=p.device)
                        if self.defaults['capturable'] or self.defaults['fused']
                        else torch.tensor(0.)
                    )
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    if group['amsgrad']:
                        # Maintains max of all exp. moving avg. of sq. grad. values
                        state['max_exp_avg_sq'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                exp_avgs.append(state['exp_avg'])
                exp_avg_sqs.append(state['exp_avg_sq'])

                if group['amsgrad']:
                    max_exp_avg_sqs.append(state['max_exp_avg_sq'])
                if group['differentiable'] and state['step'].requires_grad:
                    raise RuntimeError('`requires_grad` is not supported for `step` in differentiable mode')
                state_steps.append(state['step'])
        return grad_scale, found_inf

    @_use_grad_for_differentiable
    def step(self, closure=None, *, grad_scaler=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
            grad_scaler (:class:`torch.cuda.amp.GradScaler`, optional): A GradScaler which is
                supplied from ``grad_scaler.step(optimizer)``.
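
        Example (a minimal AMP sketch; assumes a CUDA ``model``, a ``loss_fn``,
        and a ``data`` iterable, none of which are defined here)::

            >>> scaler = torch.cuda.amp.GradScaler()
            >>> optimizer = torch.optim.Adam(model.parameters(), fused=True)
            >>> for input, target in data:
            >>>     optimizer.zero_grad()
            >>>     with torch.autocast(device_type="cuda"):
            >>>         loss = loss_fn(model(input), target)
            >>>     scaler.scale(loss).backward()
            >>>     # Because this optimizer sets _step_supports_amp_scaling,
            >>>     # scaler.step() forwards itself via the ``grad_scaler`` kwarg.
            >>>     scaler.step(optimizer)
            >>>     scaler.update()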
""" self._cuda_graph_capture_health_check() loss = None if closure is not None: with torch.enable_grad(): loss = closure() for group in self.param_groups: params_with_grad = [] grads = [] exp_avgs = [] exp_avg_sqs = [] max_exp_avg_sqs = [] state_steps = [] beta1, beta2 = group['betas'] grad_scale, found_inf = self._init_group( group, grad_scaler, params_with_grad, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps) adam(params_with_grad, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad=group['amsgrad'], beta1=beta1, beta2=beta2, lr=group['lr'], weight_decay=group['weight_decay'], eps=group['eps'], maximize=group['maximize'], foreach=group['foreach'], capturable=group['capturable'], differentiable=group['differentiable'], fused=group['fused'], grad_scale=grad_scale, found_inf=found_inf) return loss Adam.__doc__ = r"""Implements Adam algorithm. .. math:: \begin{aligned} &\rule{110mm}{0.4pt} \\ &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2 \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\ &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad}, \:\textit{maximize} \\ &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)}, v_0\leftarrow 0 \text{ (second moment)},\: \widehat{v_0}^{max}\leftarrow 0\\[-1.ex] &\rule{110mm}{0.4pt} \\ &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\ &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\ &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm}\textbf{else} \\ &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\ &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\ &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\ &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\ &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\ &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\ &\hspace{5mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\ &\hspace{5mm}\textbf{if} \: amsgrad \\ &\hspace{10mm}\widehat{v_t}^{max} \leftarrow \mathrm{max}(\widehat{v_t}^{max}, \widehat{v_t}) \\ &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ \big(\sqrt{\widehat{v_t}^{max}} + \epsilon \big) \\ &\hspace{5mm}\textbf{else} \\ &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/ \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\ &\rule{110mm}{0.4pt} \\[-1.ex] &\bf{return} \: \theta_t \\[-1.ex] &\rule{110mm}{0.4pt} \\[-1.ex] \end{aligned} For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_. """ + r""" Args: params (iterable): iterable of parameters to optimize or dicts defining parameter groups lr (float, optional): learning rate (default: 1e-3) betas (Tuple[float, float], optional): coefficients used for computing running averages of gradient and its square (default: (0.9, 0.999)) eps (float, optional): term added to the denominator to improve numerical stability (default: 1e-8) weight_decay (float, optional): weight decay (L2 penalty) (default: 0) amsgrad (bool, optional): whether to use the AMSGrad variant of this algorithm from the paper `On the Convergence of Adam and Beyond`_ (default: False) {foreach} {maximize} {capturable} {differentiable} fused (bool, optional): whether the fused implementation (CUDA only) is used. Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16` are supported. 
            Since the fused implementation is usually significantly faster than
            the for-loop implementation, we try to use it whenever possible (all
            parameters are on CUDA and are of a supported type). Otherwise, we
            attempt to use the foreach implementation and lastly fall back to
            the for-loop implementation. (default: None)

    .. note:: The foreach and fused implementations are typically faster than the for-loop,
              single-tensor implementation, so we will try to default to them IF the user has
              not specified either flag (i.e., when foreach = fused = None). For example, if
              the user specifies True for foreach but nothing for fused, we will run the foreach
              implementation. If the user specifies False for fused but nothing for foreach, we
              will run the for-loop implementation. If the user specifies True for both foreach
              and fused, we will prioritize fused over foreach. We attempt to use the fastest,
              so the hierarchy goes fused -> foreach -> for-loop.

    .. _Adam\: A Method for Stochastic Optimization:
        https://arxiv.org/abs/1412.6980
    .. _On the Convergence of Adam and Beyond:
        https://openreview.net/forum?id=ryQu7f-RZ

    """.format(foreach=_foreach_doc, maximize=_maximize_doc,
               capturable=_capturable_doc, differentiable=_differentiable_doc)


def adam(params: List[Tensor],
         grads: List[Tensor],
         exp_avgs: List[Tensor],
         exp_avg_sqs: List[Tensor],
         max_exp_avg_sqs: List[Tensor],
         state_steps: List[Tensor],
         # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
         # setting these as kwargs for now as the functional API is compiled by torch/distributed/optim
         foreach: Optional[bool] = None,
         capturable: bool = False,
         differentiable: bool = False,
         fused: Optional[bool] = None,
         grad_scale: Optional[_MultiDeviceReplicator] = None,
         found_inf: Optional[_MultiDeviceReplicator] = None,
         *,
         amsgrad: bool,
         beta1: float,
         beta2: float,
         lr: float,
         weight_decay: float,
         eps: float,
         maximize: bool):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.
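
    Example (a minimal single-parameter sketch on CPU; every tensor name here
    is illustrative, not part of the API)::

        >>> param = torch.zeros(3)
        >>> grad = torch.ones(3)
        >>> exp_avg, exp_avg_sq = torch.zeros(3), torch.zeros(3)
        >>> step = torch.tensor(0.)
        >>> adam([param], [grad], [exp_avg], [exp_avg_sq], [], [step],
        ...      amsgrad=False, beta1=0.9, beta2=0.999, lr=1e-3,
        ...      weight_decay=0., eps=1e-8, maximize=False)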
""" if fused is None and foreach is None: fused, foreach = _default_to_fused_or_foreach( [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps], differentiable, has_fused=True) if fused is None: fused = False if foreach is None: foreach = False if not all(isinstance(t, torch.Tensor) for t in state_steps): raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors") if foreach and torch.jit.is_scripting(): raise RuntimeError('torch.jit.script not supported with foreach optimizers') if fused and not torch.jit.is_scripting(): func = _fused_adam elif foreach and not torch.jit.is_scripting(): func = _multi_tensor_adam else: func = _single_tensor_adam func(params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps, amsgrad=amsgrad, beta1=beta1, beta2=beta2, lr=lr, weight_decay=weight_decay, eps=eps, maximize=maximize, capturable=capturable, differentiable=differentiable, grad_scale=grad_scale, found_inf=found_inf) def _single_tensor_adam(params: List[Tensor], grads: List[Tensor], exp_avgs: List[Tensor], exp_avg_sqs: List[Tensor], max_exp_avg_sqs: List[Tensor], state_steps: List[Tensor], grad_scale: Optional[_MultiDeviceReplicator], found_inf: Optional[_MultiDeviceReplicator], *, amsgrad: bool, beta1: float, beta2: float, lr: float, weight_decay: float, eps: float, maximize: bool, capturable: bool, differentiable: bool): assert grad_scale is None and found_inf is None for i, param in enumerate(params): grad = grads[i] if not maximize else -grads[i] exp_avg = exp_avgs[i] exp_avg_sq = exp_avg_sqs[i] step_t = state_steps[i] if capturable: assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors." # update step step_t += 1 if weight_decay != 0: grad = grad.add(param, alpha=weight_decay) if torch.is_complex(param): grad = torch.view_as_real(grad) exp_avg = torch.view_as_real(exp_avg) exp_avg_sq = torch.view_as_real(exp_avg_sq) param = torch.view_as_real(param) # Decay the first and second moment running average coefficient exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2) if capturable or differentiable: step = step_t # 1 - beta1 ** step can't be captured in a CUDA graph, even if step is a CUDA tensor # (incurs "RuntimeError: CUDA error: operation not permitted when stream is capturing") bias_correction1 = 1 - torch.pow(beta1, step) bias_correction2 = 1 - torch.pow(beta2, step) step_size = lr / bias_correction1 step_size_neg = step_size.neg() bias_correction2_sqrt = bias_correction2.sqrt() if amsgrad: # Maintains the maximum of all 2nd moment running avg. till now if differentiable: max_exp_avg_sqs_i = max_exp_avg_sqs[i].clone() else: max_exp_avg_sqs_i = max_exp_avg_sqs[i] max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sqs_i, exp_avg_sq)) # Uses the max. for normalizing running avg. 
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value to be a Number, not a Tensor)
                denom = (max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)
            else:
                denom = (exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)).add_(eps / step_size_neg)

            param.addcdiv_(exp_avg, denom)
        else:
            step = _get_value(step_t)

            bias_correction1 = 1 - beta1 ** step
            bias_correction2 = 1 - beta2 ** step

            step_size = lr / bias_correction1

            bias_correction2_sqrt = _dispatch_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])

                # Use the max. for normalizing running avg. of gradient
                denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
            else:
                denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)

            param.addcdiv_(exp_avg, denom, value=-step_size)


def _multi_tensor_adam(params: List[Tensor],
                       grads: List[Tensor],
                       exp_avgs: List[Tensor],
                       exp_avg_sqs: List[Tensor],
                       max_exp_avg_sqs: List[Tensor],
                       state_steps: List[Tensor],
                       grad_scale: Optional[_MultiDeviceReplicator],
                       found_inf: Optional[_MultiDeviceReplicator],
                       *,
                       amsgrad: bool,
                       beta1: float,
                       beta2: float,
                       lr: float,
                       weight_decay: float,
                       eps: float,
                       maximize: bool,
                       capturable: bool,
                       differentiable: bool):
    if len(params) == 0:
        return

    if capturable:
        assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
            "If capturable=True, params and state_steps must be CUDA tensors."

    assert grad_scale is None and found_inf is None

    assert not differentiable, "_foreach ops don't support autograd"

    grouped_tensors = _group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device_params, device_grads, device_exp_avgs, device_exp_avg_sqs,
         device_max_exp_avg_sqs, device_state_steps) in grouped_tensors.values():

        if maximize:
            device_grads = torch._foreach_neg(tuple(device_grads))  # type: ignore[assignment]

        # Handle complex parameters
        device_grads = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_grads]
        device_exp_avgs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avgs]
        device_exp_avg_sqs = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_exp_avg_sqs]
        params_ = [torch.view_as_real(x) if torch.is_complex(x) else x for x in device_params]

        # update steps
        torch._foreach_add_(device_state_steps, 1)

        if weight_decay != 0:
            device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

        # Decay the first and second moment running average coefficient
        torch._foreach_mul_(device_exp_avgs, beta1)
        torch._foreach_add_(device_exp_avgs, device_grads, alpha=1 - beta1)

        torch._foreach_mul_(device_exp_avg_sqs, beta2)
        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, 1 - beta2)

        if capturable:
            # TODO: use foreach_pow if/when foreach_pow is added
            bias_correction1 = [torch.pow(beta1, step) for step in device_state_steps]
            bias_correction2 = [torch.pow(beta2, step) for step in device_state_steps]
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_correction1, 1)
            torch._foreach_sub_(bias_correction2, 1)
            torch._foreach_neg_(bias_correction1)
            torch._foreach_neg_(bias_correction2)

            # foreach_div doesn't allow a scalar as the first arg
            step_size = torch._foreach_div(bias_correction1, lr)
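            # The reciprocal and negation below turn this into
            # step_size = -lr / bias_correction1 elementwise; the roundabout
            # form exists because foreach_div cannot take a scalar numerator.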
            torch._foreach_reciprocal_(step_size)
            torch._foreach_neg_(step_size)

            bias_correction2_sqrt = torch._foreach_sqrt(bias_correction2)

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)  # type: ignore[assignment]

                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
                # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
                # (can't fold it into addcdiv_ below because addcdiv_ requires value to be a Number, not a Tensor)
                torch._foreach_div_(max_exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size))
                eps_over_step_size = torch._foreach_div(step_size, eps)
                torch._foreach_reciprocal_(eps_over_step_size)
                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps_over_step_size)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
                torch._foreach_div_(exp_avg_sq_sqrt, torch._foreach_mul(bias_correction2_sqrt, step_size))
                eps_over_step_size = torch._foreach_div(step_size, eps)
                torch._foreach_reciprocal_(eps_over_step_size)
                denom = torch._foreach_add(exp_avg_sq_sqrt, eps_over_step_size)

            torch._foreach_addcdiv_(params_, device_exp_avgs, denom)
        else:
            bias_correction1 = [1 - beta1 ** _get_value(step) for step in device_state_steps]
            bias_correction2 = [1 - beta2 ** _get_value(step) for step in device_state_steps]

            step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])

            bias_correction2_sqrt = [_dispatch_sqrt(bc) for bc in bias_correction2]

            if amsgrad:
                # Maintains the maximum of all 2nd moment running avg. till now
                torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)

                # Use the max. for normalizing running avg. of gradient
                max_exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
                torch._foreach_div_(max_exp_avg_sq_sqrt, bias_correction2_sqrt)
                denom = torch._foreach_add(max_exp_avg_sq_sqrt, eps)
            else:
                exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
                torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
                denom = torch._foreach_add(exp_avg_sq_sqrt, eps)

            torch._foreach_addcdiv_(params_, device_exp_avgs, denom, step_size)


def _fused_adam(
    params: List[Tensor],
    grads: List[Tensor],
    exp_avgs: List[Tensor],
    exp_avg_sqs: List[Tensor],
    max_exp_avg_sqs: List[Tensor],
    state_steps: List[Tensor],
    grad_scale: Optional[_MultiDeviceReplicator],
    found_inf: Optional[_MultiDeviceReplicator],
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    eps: float,
    maximize: bool,
    capturable: bool,  # Needed for consistency.
    differentiable: bool,
) -> None:
    grouped_tensors = _group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps])
    for (device, dtype) in grouped_tensors:
        (
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
        ) = grouped_tensors[(device, dtype)]
        if grad_scale is not None and found_inf is not None:
            device_grad_scale = grad_scale.get(str(device))
            device_found_inf = found_inf.get(str(device))
        else:
            device_grad_scale = None
            device_found_inf = None
        torch._foreach_add_(device_state_steps, 1)
        torch._fused_adam_(
            device_params,
            device_grads,
            device_exp_avgs,
            device_exp_avg_sqs,
            device_max_exp_avg_sqs,
            device_state_steps,
            amsgrad=amsgrad,
            lr=lr,
            beta1=beta1,
            beta2=beta2,
            weight_decay=weight_decay,
            eps=eps,
            maximize=maximize,
            grad_scale=device_grad_scale,
            found_inf=device_found_inf,
        )
        if device_found_inf is not None:
            # device_found_inf is 1. when an inf/nan was found on this device
            # (in which case the fused kernel skipped the update), 0. otherwise;
            # subtracting it rolls the step counters back so they stay in sync.
            torch._foreach_sub_(device_state_steps, [device_found_inf] * len(device_state_steps))
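

# A minimal end-to-end sketch of the fused -> foreach -> for-loop hierarchy
# described in the class docstring (illustrative only; `model`, `loss_fn`, and
# `data` are assumed to exist, with parameters on CUDA for the fused variant):
#
#     opt = Adam(model.parameters(), lr=1e-3)                # implementation decided lazily at step()
#     opt_foreach = Adam(model.parameters(), foreach=True)   # force the multi-tensor path
#     opt_fused = Adam(model.parameters(), fused=True)       # force the fused CUDA kernel
#     for input, target in data:
#         opt.zero_grad()
#         loss_fn(model(input), target).backward()
#         opt.step()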