[optim] abstract out _default_to_foreach_util (#92305)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92305
Approved by: https://github.com/albanD
commit 4fc796daf9 (parent 5c9c39a83f)
torch/optim/adadelta.py

@@ -1,7 +1,7 @@
 import torch
 from torch import Tensor
 
-from .optimizer import Optimizer, _use_grad_for_differentiable
+from .optimizer import Optimizer, _use_grad_for_differentiable, _default_to_foreach
 from torch.utils._foreach_utils import _group_tensors_by_device_and_dtype
 from typing import List, Optional
 
@@ -192,19 +192,9 @@ def adadelta(
     See :class:`~torch.optim.Adadelta` for details.
     """
 
-    # We try to use the foreach implementation on CUDA whenever possible since
-    # it is faster than the for-loop implementation. However, the foreach
-    # implementation is not differentiable, so we must check differentiable=False.
     # We still respect when the user inputs False for foreach.
     if foreach is None:
-        all_tensors = []
-        all_tensors.extend(params)
-        all_tensors.extend(grads)
-        all_tensors.extend(square_avgs)
-        all_tensors.extend(acc_deltas)
-        foreach = not torch.jit.is_scripting() and not differentiable and all(
-            p.is_cuda for p in all_tensors
-        )
+        foreach = _default_to_foreach([params, grads, square_avgs, acc_deltas], differentiable=differentiable)
 
     if foreach and torch.jit.is_scripting():
         raise RuntimeError("torch.jit.script not supported with foreach optimizers")
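
The call-site change above is a behavioral no-op: the new one-liner computes the same boolean the removed inline block did. A minimal sketch of the equivalence follows; the standalone function and its name are illustrative only, not part of the diff, with arguments mirroring adadelta().

import torch

# Illustrative only: the decision adadelta() used to make inline before this commit.
def _old_adadelta_default(params, grads, square_avgs, acc_deltas, differentiable):
    all_tensors = []
    all_tensors.extend(params)
    all_tensors.extend(grads)
    all_tensors.extend(square_avgs)
    all_tensors.extend(acc_deltas)
    return not torch.jit.is_scripting() and not differentiable and all(
        p.is_cuda for p in all_tensors
    )

# After this commit, the same decision is delegated to the shared helper:
#     foreach = _default_to_foreach([params, grads, square_avgs, acc_deltas],
#                                   differentiable=differentiable)
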
torch/optim/optimizer.py

@@ -6,7 +6,7 @@ import warnings
 import functools
 import math
 
-from typing import Callable, Dict
+from typing import Callable, Dict, List
 
 import torch.utils.hooks as hooks
 from torch.utils.hooks import RemovableHandle
@@ -54,6 +54,19 @@ def _dispatch_sqrt(x: float):  # float annotation is needed because of torchscript
     else:
         return math.sqrt(x)
 
+
+# We try to use the foreach implementation on CUDA whenever possible since
+# it is faster than the for-loop implementation. However, the foreach
+# implementation is not differentiable, so we must check differentiable=False.
+def _default_to_foreach(tensorlists: List[List[torch.Tensor]], differentiable: bool = False) -> bool:
+    all_tensors = []
+    for tensorlist in tensorlists:
+        all_tensors.extend(tensorlist)
+    return not torch.jit.is_scripting() and not differentiable and all(
+        p.is_cuda for p in all_tensors
+    )
+
+
 def register_optimizer_step_pre_hook(hook: Callable[..., None]) -> RemovableHandle:
     r"""Register a pre hook common to all optimizers. The hook should have the following
     signature::
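
For reference, a minimal usage sketch of the new helper. It assumes the private function is importable from torch.optim.optimizer as added in this commit; the tensors, list names, and the CUDA-availability guard are illustrative, not taken from the diff.

import torch
from torch.optim.optimizer import _default_to_foreach  # private helper added above

# Per-parameter state lists, as an optimizer's functional entry point would pass them.
params = [torch.randn(4, requires_grad=True)]
grads = [torch.randn(4)]
square_avgs = [torch.zeros(4)]
acc_deltas = [torch.zeros(4)]

# CPU tensors: not every tensor is on CUDA, so the default is the for-loop path.
print(_default_to_foreach([params, grads, square_avgs, acc_deltas]))                       # False
# differentiable=True always opts out of foreach, regardless of device.
print(_default_to_foreach([params, grads, square_avgs, acc_deltas], differentiable=True))  # False

if torch.cuda.is_available():
    cuda_lists = [[t.detach().cuda() for t in tl] for tl in (params, grads, square_avgs, acc_deltas)]
    # All tensors on CUDA, differentiable=False, not scripting: default to the faster foreach path.
    print(_default_to_foreach(cuda_lists))  # True
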