Enable fused foreach Adam compilation (#104121)

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/104121
Approved by: https://github.com/janeyx99
commit a290cbf32b
parent 01e6d64dd2
Author: Michael Lazos
Date: 2023-07-05 23:40:03 +00:00
Committed by: PyTorch MergeBot
6 changed files with 127 additions and 14 deletions
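The practical effect is that dynamo/inductor can now trace Adam's foreach step() instead of falling back to eager. Below is a minimal sketch of the kind of usage the new test exercises (requires CUDA; the TorchPatcher.patch() call and the step.__wrapped__ unwrapping mirror the test and work around a known graph break, see the comments in the test file):

import torch
import torch._dynamo

model = torch.nn.Linear(10, 10, device="cuda")
model(torch.ones(10, 10, device="cuda")).sum().backward()
opt = torch.optim.Adam(model.parameters(), lr=0.01)

# Patch dynamo and unwrap step(), as the new test does, so the whole step
# can be captured with fullgraph=True.
torch._dynamo.eval_frame.TorchPatcher.patch()
step_fn = opt.step.__wrapped__

with torch.set_grad_enabled(False):
    torch.compile(lambda: step_fn(opt), backend="inductor", fullgraph=True)()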

View File

@@ -0,0 +1,93 @@
# Owner(s): ["module: inductor"]
import sys
import unittest
from copy import deepcopy

import torch
import torch._inductor

from torch.testing._internal.common_utils import TestCase

aten = torch.ops.aten

try:
    try:
        from .test_torchinductor import check_model, check_model_cuda, requires_cuda
    except ImportError:
        from test_torchinductor import check_model, check_model_cuda, requires_cuda
except (unittest.SkipTest, ImportError) as e:
    sys.stderr.write(f"{type(e)}: {e}\n")
    if __name__ == "__main__":
        sys.exit(0)
    raise


def make_test(optim_cls, closure=None, **kwargs):
    @requires_cuda()
    def test_fn(self):
        input = torch.ones([10, 10], device="cuda:0")
        model_eager = torch.nn.Sequential(
            *[torch.nn.Linear(10, 10, device="cuda:0") for _ in range(2)]
        )
        model_eager(input).sum().backward()

        input = torch.ones([10, 10], device="cuda:0")
        model_compiled = deepcopy(model_eager)
        model_compiled(input).sum().backward()

        opt_eager = optim_cls(model_eager.parameters(), **kwargs)
        opt_compiled = optim_cls(model_compiled.parameters(), **kwargs)

        # run the patcher so that step has the expected structure
        torch._dynamo.eval_frame.TorchPatcher.patch()

        # unwrap step to avoid a deliberate graph break due to
        # a limitation of functionalization/no_grad detection
        # see the [Note on graph break] in optimizer.py
        # This ignores the outer _use_grad_if_differentiable wrapper
        # and instead manually disables grad before calling step, which is fine
        # for now as dynamo does not support differentiable optimizers anyway
        step_fn = opt_compiled.step.__wrapped__
        if closure is not None:

            def fn():
                step_fn(opt_compiled, closure)

        else:

            def fn():
                step_fn(opt_compiled)

        with torch.set_grad_enabled(False):
            torch.compile(fn, backend="inductor", fullgraph=True)()
        opt_eager.step()

        self.assertEqual(
            list(model_eager.parameters()), list(model_compiled.parameters())
        )

        if self.check_kernel_count:
            # currently, we compile the step and the rest of the computation
            # separately because the step is a single element tensor
            self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)

    return test_fn


class CompiledOptimizerTests(TestCase):
    check_model_cuda = check_model_cuda
    check_model_cpu = check_model
    check_kernel_count = True

    def setUp(self):
        super().setUp()
        torch._inductor.metrics.reset()

    def tearDown(self):
        super().tearDown()
        torch._inductor.metrics.reset()

    test_adam = make_test(torch.optim.Adam, lr=0.01)
    test_adam_weight_decay = make_test(torch.optim.Adam, lr=0.01, weight_decay=0.01)
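For reference, the generated_kernel_count assertion above uses inductor's global metrics module; a stripped-down, hypothetical standalone version of that bookkeeping (not part of this commit) looks like:

import torch
import torch._inductor.metrics as metrics

metrics.reset()  # zero the counters, as setUp()/tearDown() do above

@torch.compile(backend="inductor")
def f(x):
    return (x * 2).sum()

f(torch.ones(8, device="cuda"))  # assumes a CUDA device
print(metrics.generated_kernel_count)  # how many kernels inductor emitted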

View File

@ -1243,7 +1243,6 @@ class TorchPatcher:
from ..optim import (
adadelta,
adagrad,
adam,
adamax,
adamw,
asgd,
@ -1256,7 +1255,6 @@ class TorchPatcher:
for opt_mod in (
adadelta,
adagrad,
adam,
adamax,
adamw,
asgd,
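For context, these tuples live in TorchPatcher.patch(), which wraps certain per-optimizer helpers so dynamo skips them; dropping adam from both lists is what lets its foreach implementation be traced. The loop body is not shown in this hunk, so the following is only a hypothetical illustration of that wrapping for a module that remains in the list:

# Hypothetical sketch (the real loop body is outside this hunk): wrap the
# module's _multi_tensor_* helper in torch._dynamo.disable so it runs eagerly
# under torch.compile. Modules removed from the tuple, like adam, no longer
# get this treatment.
import torch._dynamo
from torch.optim import adamax  # still listed after this change

name = adamax.__name__.split(".")[-1]      # "adamax"
helper_name = f"_multi_tensor_{name}"
if hasattr(adamax, helper_name):
    setattr(adamax, helper_name, torch._dynamo.disable(getattr(adamax, helper_name)))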

View File

@@ -25,6 +25,11 @@ class GuardInstallException(Exception):
 class OptimizerVariable(UserDefinedObjectVariable):
     def __init__(self, value, grad_to_source=None, **kwargs):
         super().__init__(value, **kwargs)
+        for group in self.value.param_groups:
+            if "capturable" in group:
+                group["capturable"] = True
         if grad_to_source is None:
             self.grad_to_source = {}
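The added loop mutates the real optimizer object while dynamo traces it, so after compiling a step the param groups should report capturable=True even though the user never set it. A rough sketch of observing that side effect (assumes CUDA and that tracing reaches OptimizerVariable):

import torch
import torch._dynamo

p = torch.nn.Parameter(torch.randn(8, device="cuda"))
p.grad = torch.zeros_like(p)
opt = torch.optim.Adam([p], lr=0.01)
print(opt.param_groups[0]["capturable"])  # False by default

torch._dynamo.eval_frame.TorchPatcher.patch()
step_fn = opt.step.__wrapped__
with torch.set_grad_enabled(False):
    torch.compile(lambda: step_fn(opt), backend="inductor")()

print(opt.param_groups[0]["capturable"])  # expected: True, flipped during tracing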

View File

@@ -279,7 +279,9 @@ def adam(params: List[Tensor],
     if foreach is None:
         foreach = False

-    if not all(isinstance(t, torch.Tensor) for t in state_steps):
+    # this check is slow during compilation, so we skip it
+    # if it's strictly needed we can add this check back in dynamo
+    if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError("API has changed, `state_steps` argument must contain a list of singleton tensors")

     if foreach and torch.jit.is_scripting():
@@ -339,7 +341,8 @@ def _single_tensor_adam(params: List[Tensor],
         exp_avg_sq = exp_avg_sqs[i]
         step_t = state_steps[i]

-        if capturable:
+        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+        if not torch._utils.is_compiling() and capturable:
             assert param.is_cuda and step_t.is_cuda, "If capturable=True, params and state_steps must be CUDA tensors."

         # update step
@@ -428,7 +431,8 @@ def _multi_tensor_adam(params: List[Tensor],
     if len(params) == 0:
         return

-    if capturable:
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if not torch._utils.is_compiling() and capturable:
         assert all(p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)), \
             "If capturable=True, params and state_steps must be CUDA tensors."

View File

@@ -304,7 +304,7 @@ def adamw(
     See :class:`~torch.optim.AdamW` for details.
     """

-    if not all(isinstance(t, torch.Tensor) for t in state_steps):
+    if not torch._utils.is_compiling() and not all(isinstance(t, torch.Tensor) for t in state_steps):
         raise RuntimeError(
             "API has changed, `state_steps` argument must contain a list of singleton tensors"
         )
@@ -382,7 +382,8 @@ def _single_tensor_adamw(
         exp_avg_sq = exp_avg_sqs[i]
         step_t = state_steps[i]

-        if capturable:
+        # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+        if not torch._utils.is_compiling() and capturable:
             assert (
                 param.is_cuda and step_t.is_cuda
             ), "If capturable=True, params and state_steps must be CUDA tensors."
@@ -479,7 +480,8 @@ def _multi_tensor_adamw(
     if len(params) == 0:
         return

-    if capturable:
+    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
+    if not torch._utils.is_compiling() and capturable:
         assert all(
             p.is_cuda and step.is_cuda for p, step in zip(params, state_steps)
         ), "If capturable=True, params and state_steps must be CUDA tensors."

View File

@@ -252,9 +252,15 @@ class Optimizer:
     # Currently needed by Adam and AdamW
     def _cuda_graph_capture_health_check(self):
-        # If we are compiling, we take the capturable path automatically
-        # One caveat here is that if we are compiling, we *permit* step/param tensors to be on CPU
-        # so we do not explicitly enable the capturable flag. Inductor will decide whether cudagraphs
+        # Note [torch.compile x capturable]
+        # If we are compiling, we try to take the capturable path automatically by
+        # setting the flag to True during tracing. Due to this, we skip all the checks
+        # normally required for determining whether we can use CUDA graphs and
+        # shunt the responsibility to torch.inductor. This saves time during tracing
+        # since the checks are slow without sacrificing UX since inductor will warn
+        # later if CUDA graphs cannot be enabled, e.g.,
+        # https://github.com/pytorch/pytorch/blob/d3ba8901d8640eb16f88b2bfef9df7fa383d4b47/torch/_inductor/compile_fx.py#L390.
+        # Thus, when compiling, inductor will determine if cudagraphs
         # can be enabled based on whether there is input mutation or CPU tensors.
         if not is_compiling() and torch.backends.cuda.is_built() and torch.cuda.is_available():
             capturing = torch.cuda.is_current_stream_capturing()
@@ -422,11 +428,16 @@ class Optimizer:
                 capturable = pg["capturable"] if "capturable" in pg else False
                 break
-        if key != "step" or capturable or fused:
+        if key == 'step':
+            if capturable or fused:
+                return value.to(dtype=torch.float32, device=param.device)
+            else:
+                return value
+        else:
             if param.is_floating_point():
                 return value.to(dtype=param.dtype, device=param.device)
-            return value.to(device=param.device)
-        return value
+            else:
+                return value.to(device=param.device)

     def load_state_dict(self, state_dict):
         r"""Loads the optimizer state.