[optim] abstract out _default_to_foreach_util (#92305)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92305
Approved by: https://github.com/albanD
Jane Xu 2023-01-17 15:30:00 +00:00 committed by PyTorch MergeBot
parent 5c9c39a83f
commit 4fc796daf9
2 changed files with 16 additions and 13 deletions
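For context, the point of pulling the defaulting logic into torch/optim/optimizer.py is that every functional optimizer can share one rule for when to prefer the foreach implementation. The sketch below is illustrative only: my_optimizer_step and its tensor lists are made up and are not part of this commit; the import path assumes the helper exactly as added here (it is a private symbol and may move or be renamed later). It just shows how a call site would use the new helper.

# Hypothetical call site, not part of this commit: a functional optimizer
# entry point that reuses the shared helper to pick its default implementation.
from typing import List, Optional

import torch
from torch.optim.optimizer import _default_to_foreach  # helper as added by this PR


def my_optimizer_step(
    params: List[torch.Tensor],
    grads: List[torch.Tensor],
    exp_avgs: List[torch.Tensor],
    *,
    foreach: Optional[bool] = None,
    differentiable: bool = False,
) -> None:
    if foreach is None:
        # Default to the (faster) foreach path only when it is safe: not under
        # torch.jit.script, not differentiable, and every tensor is on CUDA.
        foreach = _default_to_foreach(
            [params, grads, exp_avgs], differentiable=differentiable
        )
    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    # ... dispatch to a foreach or single-tensor implementation here ...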

torch/optim/adadelta.py

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
-from .optimizer import Optimizer, _use_grad_for_differentiable
+from .optimizer import Optimizer, _use_grad_for_differentiable, _default_to_foreach
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional
@@ -192,19 +192,9 @@ def adadelta(
     See :class:`~torch.optim.Adadelta` for details.
     """
 
-    # We try to use the foreach implementation on CUDA whenever possible since
-    # it is faster than the for-loop implementation. However, the foreach
-    # implementation is not differentiable, so we must check differentiable=False.
     # We still respect when the user inputs False for foreach.
     if foreach is None:
-        all_tensors = []
-        all_tensors.extend(params)
-        all_tensors.extend(grads)
-        all_tensors.extend(square_avgs)
-        all_tensors.extend(acc_deltas)
-        foreach = not torch.jit.is_scripting() and not differentiable and all(
-            p.is_cuda for p in all_tensors
-        )
+        foreach = _default_to_foreach([params, grads, square_avgs, acc_deltas], differentiable=differentiable)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")

torch/optim/optimizer.py

@@ -6,7 +6,7 @@ import warnings
 import functools
 import math
-from typing import Callable, Dict
+from typing import Callable, Dict, List
 import torch.utils.hooks as hooks
 from torch.utils.hooks import RemovableHandle
@@ -54,6 +54,19 @@ def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript
     else:
         return math.sqrt(x)
 
+
+# We try to use the foreach implementation on CUDA whenever possible since
+# it is faster than the for-loop implementation. However, the foreach
+# implementation is not differentiable, so we must check differentiable=False.
+def _default_to_foreach(tensorlists: List[List[torch.Tensor]], differentiable: bool = False) -> bool:
+    all_tensors = []
+    for tensorlist in tensorlists:
+        all_tensors.extend(tensorlist)
+    return not torch.jit.is_scripting() and not differentiable and all(
+        p.is_cuda for p in all_tensors
+    )
+
+
 def register_optimizer_step_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
     r"""Register a pre hook common to all optimizers. The hook should have the following
     signature::
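
To make the helper's behavior concrete, here is a small standalone sketch. It re-declares _default_to_foreach exactly as added above (so it runs without importing the private symbol from torch.optim) and exercises the CPU and differentiable cases; the CUDA case is left as a comment because it needs a GPU.

# Standalone sketch: re-declares the helper as added above so the example
# does not depend on importing a private symbol from torch.optim.
from typing import List

import torch


def _default_to_foreach(tensorlists: List[List[torch.Tensor]], differentiable: bool = False) -> bool:
    all_tensors = []
    for tensorlist in tensorlists:
        all_tensors.extend(tensorlist)
    return not torch.jit.is_scripting() and not differentiable and all(
        p.is_cuda for p in all_tensors
    )


params = [torch.zeros(2), torch.zeros(3)]  # CPU tensors
grads = [torch.zeros(2), torch.zeros(3)]

# CPU tensors -> stay on the for-loop implementation.
assert _default_to_foreach([params, grads]) is False

# Asking for a differentiable step also disables the foreach default,
# regardless of device.
assert _default_to_foreach([params, grads], differentiable=True) is False

# With every tensor on CUDA and differentiable=False, the helper would
# return True (needs a GPU, so it is only shown as a comment here):
# cuda_params = [t.cuda() for t in params]
# assert _default_to_foreach([cuda_params]) is True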