[optim] abstract out _default_to_foreach_util (#92305)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92305
Approved by: https://github.com/albanD
parent 5c9c39a83f
commit 4fc796daf9
torch/optim/adadelta.py
@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 
-from .optimizer import Optimizer, _use_grad_for_differentiable
+from .optimizer import Optimizer, _use_grad_for_differentiable, _default_to_foreach
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional
 
@@ -192,19 +192,9 @@ def adadelta(
     See :class:`~torch.optim.Adadelta` for details.
     """
 
-    # We try to use the foreach implementation on CUDA whenever possible since
-    # it is faster than the for-loop implementation. However, the foreach
-    # implementation is not differentiable, so we must check differentiable=False.
     # We still respect when the user inputs False for foreach.
     if foreach is None:
-        all_tensors = []
-        all_tensors.extend(params)
-        all_tensors.extend(grads)
-        all_tensors.extend(square_avgs)
-        all_tensors.extend(acc_deltas)
-        foreach = not torch.jit.is_scripting() and not differentiable and all(
-            p.is_cuda for p in all_tensors
-        )
+        foreach = _default_to_foreach([params, grads, square_avgs, acc_deltas], differentiable=differentiable)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")
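For context, "foreach" here refers to the torch._foreach_* operator family, which applies one operation to an entire list of tensors instead of looping over them in Python; that single-call dispatch is what makes the foreach path faster, especially on CUDA. The snippet below is an illustrative sketch only (the _foreach_* ops are private APIs and are not part of this change):

    import torch

    # Illustrative list of state tensors, like square_avgs in adadelta.
    square_avgs = [torch.rand(10) for _ in range(4)]
    rho = 0.9

    # For-loop path: one op dispatch per tensor.
    for avg in square_avgs:
        avg.mul_(rho)

    # Foreach path: a single call covers the whole list.
    torch._foreach_mul_(square_avgs, rho)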
torch/optim/optimizer.py
@@ -6,7 +6,7 @@ import warnings
 import functools
 import math
 
-from typing import Callable, Dict
+from typing import Callable, Dict, List
 
 import torch.utils.hooks as hooks
 from torch.utils.hooks import RemovableHandle
@@ -54,6 +54,19 @@ def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript
     else:
         return math.sqrt(x)
 
+
+# We try to use the foreach implementation on CUDA whenever possible since
+# it is faster than the for-loop implementation. However, the foreach
+# implementation is not differentiable, so we must check differentiable=False.
+def _default_to_foreach(tensorlists: List[List[torch.Tensor]], differentiable: bool = False) -> bool:
+    all_tensors = []
+    for tensorlist in tensorlists:
+        all_tensors.extend(tensorlist)
+    return not torch.jit.is_scripting() and not differentiable and all(
+        p.is_cuda for p in all_tensors
+    )
+
+
 def register_optimizer_step_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
     r"""Register a pre hook common to all optimizers. The hook should have the following
     signature::
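A hedged, standalone sketch of the decision the new helper encodes (reimplemented locally here rather than importing the private helper, whose name and location may change between releases): the foreach path becomes the default only when TorchScript is not active, differentiable mode is off, and every tensor in every list lives on CUDA.

    import torch
    from typing import List

    # Local mirror of the helper's logic, for illustration only.
    def default_to_foreach(tensorlists: List[List[torch.Tensor]], differentiable: bool = False) -> bool:
        all_tensors = [t for tensorlist in tensorlists for t in tensorlist]
        return (
            not torch.jit.is_scripting()
            and not differentiable
            and all(t.is_cuda for t in all_tensors)
        )

    params = [torch.zeros(3)]
    grads = [torch.zeros(3)]
    print(default_to_foreach([params, grads]))                       # False: CPU tensors
    print(default_to_foreach([params, grads], differentiable=True))  # False: differentiable mode
    if torch.cuda.is_available():
        cuda_lists = [[t.cuda() for t in params], [t.cuda() for t in grads]]
        print(default_to_foreach(cuda_lists))                        # True: every tensor is on CUDA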