Enable fused foreach Adam compilation (#104121)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104121
Approved by: https://github.com/janeyx99
parent 01e6d64dd2
commit a290cbf32b
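
This change removes adam from the set of optimizer modules that dynamo's TorchPatcher patches out and lets tracing take the capturable route automatically, so an Adam step can be compiled with inductor. As a rough sketch of the workflow the new test below exercises (not part of the commit; a CUDA device is assumed, and the layer size and learning rate are illustrative):

import torch
import torch._dynamo
import torch._inductor

model = torch.nn.Linear(10, 10, device="cuda")
model(torch.ones(10, 10, device="cuda")).sum().backward()
opt = torch.optim.Adam(model.parameters(), lr=0.01)

# As in the test: run the patcher so step() has the structure dynamo expects,
# then unwrap the outer wrapper and compile the bare step with inductor.
torch._dynamo.eval_frame.TorchPatcher.patch()
step_fn = opt.step.__wrapped__

def fn():
    step_fn(opt)

with torch.set_grad_enabled(False):
    torch.compile(fn, backend="inductor", fullgraph=True)()

The new test runs exactly this compiled step and compares the resulting parameters against an ordinary eager opt.step() on a copy of the model.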

test/inductor/test_compiled_optimizers.py (new file, 93 lines)

# Owner(s): ["module: inductor"]
import sys
import unittest

from copy import deepcopy

import torch

import torch._inductor

from torch.testing._internal.common_utils import TestCase


aten = torch.ops.aten

try:
    try:
        from .test_torchinductor import check_model, check_model_cuda, requires_cuda
    except ImportError:
        from test_torchinductor import check_model, check_model_cuda, requires_cuda
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise


def make_test(optim_cls, closure=None, **kwargs):
    @requires_cuda()
    def test_fn(self):
        input = torch.ones([10, 10], device="cuda:0")
        model_eager = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, device="cuda:0") for _ in range(2)]
        )
        model_eager(input).sum().backward()

        input = torch.ones([10, 10], device="cuda:0")
        model_compiled = deepcopy(model_eager)
        model_compiled(input).sum().backward()

        opt_eager = optim_cls(model_eager.parameters(), **kwargs)
        opt_compiled = optim_cls(model_compiled.parameters(), **kwargs)
        # run the patcher so that step has the expected structure
        torch._dynamo.eval_frame.TorchPatcher.patch()

        # unwrap step to avoid a deliberate graph break due to
        # a limitation of functionalization/no_grad detection
        # see the [Note on graph break] in optimizer.py
        # This ignores the outer _use_grad_if_differentiable wrapper
        # and instead manually disables grad before calling step, which is fine
        # for now as dynamo does not support differentiable optimizers anyway
        step_fn = opt_compiled.step.__wrapped__
        if closure is not None:

            def fn():
                step_fn(opt_compiled, closure)

        else:

            def fn():
                step_fn(opt_compiled)

        with torch.set_grad_enabled(False):
            torch.compile(fn, backend="inductor", fullgraph=True)()
            opt_eager.step()

        self.assertEqual(
            list(model_eager.parameters()), list(model_compiled.parameters())
        )
        if self.check_kernel_count:
            # currently, we compile the step and the rest of the computation
            # separately because the step is a single element tensor
            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    return test_fn


class CompiledOptimizerTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    test_adam = make_test(torch.optim.Adam, lr=0.01)
    test_adam_weight_decay = make_test(torch.optim.Adam, lr=0.01, weight_decay=0.01)

torch/_dynamo/eval_frame.py

@@ -1243,7 +1243,6 @@ class TorchPatcher:
         from ..optim import (
             adadelta,
             adagrad,
-            adam,
             adamax,
             adamw,
             asgd,
@@ -1256,7 +1255,6 @@ class TorchPatcher:
         for opt_mod in (
             adadelta,
             adagrad,
-            adam,
             adamax,
             adamw,
             asgd,

torch/_dynamo/variables/optimizer.py

@@ -25,6 +25,11 @@ class GuardInstallException(Exception):
 class OptimizerVariable(UserDefinedObjectVariable):
     def __init__(self, value, grad_to_source=None, **kwargs):
         super().__init__(value, **kwargs)
 
+        for group in self.value.param_groups:
+            if "capturable" in group:
+                group["capturable"] = True
+
         if grad_to_source is None:
             self.grad_to_source = {}
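
The added loop flips capturable to True on the traced optimizer's param groups. For context, a minimal sketch (not part of the diff; CUDA assumed, sizes illustrative) of what the capturable flag changes for eager Adam: the step state is hosted on the parameter's device instead of the CPU, which is what lets the whole update be captured without host synchronization.

import torch

params = [torch.nn.Parameter(torch.randn(4, device="cuda")) for _ in range(2)]
opt = torch.optim.Adam(params, lr=0.01, capturable=True)
for p in params:
    p.grad = torch.randn_like(p)
# (eager may warn once that a capturable step is running outside graph capture)
opt.step()
print(opt.state[params[0]]["step"].device)  # cuda:0 when capturable=True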

torch/optim/adam.py

@@ -279,7 +279,9 @@ def adam(params: List[Tensor],
     if foreach is None:
         foreach = False
 
-    if not all(isinstance(t, torch.Tensor) for t in state_steps):
+    # this check is slow during compilation, so we skip it
+    # if it's strictly needed we can add this check back in dynamo
+    if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")
 
     if foreach and torch.jit.is_scripting():
@@ -339,7 +341,8 @@ def _single_tensor_adam(params: List[Tensor],
         exp_avg_sq = exp_avg_sqs[i]
         step_t = state_steps[i]
 
-        if capturable:
+        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+        if not torch._utils.is_compiling() and capturable:
             assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors."
 
         # update step
@@ -428,7 +431,8 @@ def _multi_tensor_adam(params: List[Tensor],
     if len(params) == 0:
         return
 
-    if capturable:
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if not torch._utils.is_compiling() and capturable:
         assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
             "If capturable=True, params and state_steps must be CUDA tensors."
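
All three hunks rely on the same guard: torch._utils.is_compiling() is False in eager mode and True while dynamo traces, so eager-only validation can be skipped during compilation. A small standalone illustration of that mechanism (the function name is made up, not from the commit):

import torch

def checked_increment(x):
    # Skipped under dynamo tracing, runs in eager mode.
    if not torch._utils.is_compiling():
        assert isinstance(x, torch.Tensor), "expected a tensor"
    return x + 1

print(checked_increment(torch.ones(2)))                              # eager: check runs
# dynamo tracing alone flips the flag, so the eager backend suffices here
print(torch.compile(checked_increment, backend="eager")(torch.ones(2)))  # traced: check skipped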

torch/optim/adamw.py

@@ -304,7 +304,7 @@ def adamw(
     See :class:`~torch.optim.AdamW` for details.
     """
 
-    if not all(isinstance(t, torch.Tensor) for t in state_steps):
+    if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError(
             "API has changed, `state_steps` argument must contain a list of singleton tensors"
         )
@@ -382,7 +382,8 @@ def _single_tensor_adamw(
         exp_avg_sq = exp_avg_sqs[i]
         step_t = state_steps[i]
 
-        if capturable:
+        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+        if not torch._utils.is_compiling() and capturable:
             assert (
                 param.is_cuda and step_t.is_cuda
             ), "If capturable=True, params and state_steps must be CUDA tensors."
@@ -479,7 +480,8 @@ def _multi_tensor_adamw(
     if len(params) == 0:
         return
 
-    if capturable:
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if not torch._utils.is_compiling() and capturable:
         assert all(
             p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
         ), "If capturable=True, params and state_steps must be CUDA tensors."

torch/optim/optimizer.py

@@ -252,9 +252,15 @@ class Optimizer:
 
     # Currently needed by Adam and AdamW
     def _cuda_graph_capture_health_check(self):
-        # If we are compiling, we take the capturable path automatically
-        # One caveat here is that if we are compiling, we *permit* step/param tensors to be on CPU
-        # so we do not explicitly enable the capturable flag. Inductor will decide whether cudagraphs
+        # Note [torch.compile x capturable]
+        # If we are compiling, we try to take the capturable path automatically by
+        # setting the flag to True during tracing. Due to this, we skip all the checks
+        # normally required for determining whether we can use CUDA graphs and
+        # shunt the responsibility to torch.inductor. This saves time during tracing
+        # since the checks are slow without sacrificing UX since inductor will warn
+        # later if CUDA graphs cannot be enabled, e.g.,
+        # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
+        # Thus, when compiling, inductor will determine if cudagraphs
         # can be enabled based on whether there is input mutation or CPU tensors.
         if not is_compiling() and torch.backends.cuda.is_built() and torch.cuda.is_available():
             capturing = torch.cuda.is_current_stream_capturing()
@@ -422,11 +428,16 @@ class Optimizer:
                 capturable = pg["capturable"] if "capturable" in pg else False
                 break
 
-        if key != "step" or capturable or fused:
+        if key == 'step':
+            if capturable or fused:
+                return value.to(dtype=torch.float32, device=param.device)
+            else:
+                return value
+        else:
             if param.is_floating_point():
                 return value.to(dtype=param.dtype, device=param.device)
-            return value.to(device=param.device)
-        return value
+            else:
+                return value.to(device=param.device)
 
     def load_state_dict(self, state_dict):
         r"""Loads the optimizer state.
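
Restated outside the diff for clarity, a hypothetical standalone function mirroring the new casting rule (the helper name is made up; the behavior follows the hunk above): the step counter is left untouched unless capturable or fused, in which case it is hosted on the param's device as float32; every other state value follows the param's device and, for floating-point params, its dtype.

import torch

def cast_state_value(param, value, key, capturable=False, fused=False):
    # Mirrors the branch structure added to optimizer.py above.
    if key == "step":
        if capturable or fused:
            return value.to(dtype=torch.float32, device=param.device)
        return value  # leave the step counter where and as it is
    if param.is_floating_point():
        return value.to(dtype=param.dtype, device=param.device)
    return value.to(device=param.device)

param = torch.zeros(2, dtype=torch.float16)
step = torch.tensor(3.0)
print(cast_state_value(param, step, "step").dtype)                   # untouched: torch.float32 on CPU
print(cast_state_value(param, step, "step", capturable=True).dtype)  # float32 on the param's device
print(cast_state_value(param, step, "exp_avg").dtype)                # follows the param: torch.float16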