[optim] abstract out _default_to_foreach util (#92305)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/92305
Approved by: https://github.com/albanD
Jane Xu 2023-01-17 15:30:00 +00:00 committed by PyTorch MergeBot
parent 5c9c39a83f
commit 4fc796daf9
2 changed files with 16 additions and 13 deletions

torch/optim/adadelta.py

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
-from .optimizer import Optimizer, _use_grad_for_differentiable
+from .optimizer import Optimizer, _use_grad_for_differentiable, _default_to_foreach
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional
@@ -192,19 +192,9 @@ def adadelta(
     See :class:`~torch.optim.Adadelta` for details.
     """
-    # We try to use the foreach implementation on CUDA whenever possible since
-    # it is faster than the for-loop implementation. However, the foreach
-    # implementation is not differentiable, so we must check differentiable=False.
     # We still respect when the user inputs False for foreach.
     if foreach is None:
-        all_tensors = []
-        all_tensors.extend(params)
-        all_tensors.extend(grads)
-        all_tensors.extend(square_avgs)
-        all_tensors.extend(acc_deltas)
-        foreach = not torch.jit.is_scripting() and not differentiable and all(
-            p.is_cuda for p in all_tensors
-        )
+        foreach = _default_to_foreach([params, grads, square_avgs, acc_deltas], differentiable=differentiable)
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")
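
The same one-liner pattern generalizes to the functional entry points of other optimizers, which is the point of abstracting the check out. A minimal sketch follows; the optimizer name and the `exp_avgs` tensor list are illustrative only and not part of this PR, and `_default_to_foreach` is a private helper that may move or be renamed in later releases:

import torch
from torch.optim.optimizer import _default_to_foreach  # private helper added by this commit

def _hypothetical_functional_optimizer(params, grads, exp_avgs, *, foreach=None, differentiable=False):
    # Default to the faster foreach path only when it is safe: not under TorchScript,
    # not differentiable, and every tensor in every list lives on CUDA.
    if foreach is None:
        foreach = _default_to_foreach([params, grads, exp_avgs], differentiable=differentiable)
    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    # ... dispatch to a multi-tensor or single-tensor implementation here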

torch/optim/optimizer.py

@@ -6,7 +6,7 @@ import warnings
 import functools
 import math
-from typing import Callable, Dict
+from typing import Callable, Dict, List
 import torch.utils.hooks as hooks
 from torch.utils.hooks import RemovableHandle
@@ -54,6 +54,19 @@ def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript
     else:
         return math.sqrt(x)
 
+# We try to use the foreach implementation on CUDA whenever possible since
+# it is faster than the for-loop implementation. However, the foreach
+# implementation is not differentiable, so we must check differentiable=False.
+def _default_to_foreach(tensorlists: List[List[torch.Tensor]], differentiable: bool = False) -> bool:
+    all_tensors = []
+    for tensorlist in tensorlists:
+        all_tensors.extend(tensorlist)
+    return not torch.jit.is_scripting() and not differentiable and all(
+        p.is_cuda for p in all_tensors
+    )
+
 def register_optimizer_step_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
     r"""Register a pre hook common to all optimizers. The hook should have the following
     signature::
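
A quick illustration of the new helper's decision logic. This is a sketch assuming a PyTorch build that contains this commit; `_default_to_foreach` is a private helper, so its location and name may change in later versions:

import torch
from torch.optim.optimizer import _default_to_foreach

cpu_params = [torch.zeros(3), torch.zeros(3)]
# False: not every tensor is on CUDA, so the default stays on the for-loop path.
print(_default_to_foreach([cpu_params]))

if torch.cuda.is_available():
    cuda_params = [torch.zeros(3, device="cuda")]
    # True: all tensors are on CUDA, we are not scripting, and differentiable=False.
    print(_default_to_foreach([cuda_params]))
    # False: the foreach implementation is not differentiable, so differentiable=True opts out.
    print(_default_to_foreach([cuda_params], differentiable=True))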