[aota] Needs autograd if an input requires_grad, agnostic to enable_grad (#128890)

Original issue: https://github.com/pytorch/pytorch/issues/114338

Reland of:  https://github.com/pytorch/pytorch/pull/128016

Summary from the previous PR:

We assume only two possible, mutually exclusive scenarios:

1. Running the compiled region for training (any input has `requires_grad`): produced differentiable outputs should have `requires_grad`.
2. Running the compiled region for inference (no input has `requires_grad`): no output has `requires_grad`.

Even if the user runs the region under `no_grad()` but passes an input Tensor with `requires_grad`, we go with training scenario (1), as illustrated below.
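A minimal user-level illustration of scenario (1), mirroring the new `test_aot_dispatch_output_requires_grad_in_no_grad` test added below; this sketch uses the public `torch.compile` entry point rather than the internal `aot_function` helper used by the test:

```
import torch

def fn(x):
    out1 = x.sin()          # executed under no_grad at call time -> requires_grad=False
    with torch.enable_grad():
        out2 = x.cos()      # nested enable_grad -> requires_grad=True
    return out1, out2

compiled_fn = torch.compile(fn)

x = torch.ones(10, requires_grad=True)
with torch.no_grad():
    eager_out = fn(x)
    compiled_out = compiled_fn(x)

# The input requires grad, so AOTAutograd traces a training (joint) graph even
# inside no_grad(); the compiled outputs must carry the same requires_grad
# flags as eager mode.
for e, c in zip(eager_out, compiled_out):
    assert e.requires_grad == c.requires_grad
```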

With the current state, that means:

1. `needs_autograd` should not check `torch.is_grad_enabled()`; it should only check whether any input has `requires_grad`.
2. If `needs_autograd`, then `trace_joint` (we are in training scenario 1), so the compiled region is always run under `enable_grad()` (see the pseudocode sketch below).
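In simplified pseudocode (a paraphrase of the aot_autograd.py and runtime-wrapper hunks further down, not the verbatim source):

```
import torch

def needs_autograd(fake_flat_args):
    # Decided purely from the inputs; torch.is_grad_enabled() is not consulted.
    return any(
        isinstance(x, torch.Tensor) and x.requires_grad for x in fake_flat_args
    )

def run_compiled_region(compiled_fn, args, trace_joint):
    if trace_joint:
        # Training scenario (1): a joint graph was traced, so the compiled
        # region always executes with grad enabled, even if the caller is
        # inside torch.no_grad().
        with torch.enable_grad():
            return compiled_fn(*args)
    # Inference scenario (2): no input requires grad; run as-is.
    return compiled_fn(*args)
```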

Why changes in the partitioner?

The inference and training graphs differed in their return container (list vs. tuple). The partitioner changes unify this so that a tuple is always returned. As a result, some expected graph contents in test_aotdispatch.py change from list to tuple (see the torch.fx sketch below).
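For illustration only, here is a standalone torch.fx sketch of what normalizing a graph's return container to a tuple looks like (this is not the partitioner code itself):

```
import torch
import torch.fx

def f(x):
    return [x.sin()]  # traced graph ends with: return [sin]

gm = torch.fx.symbolic_trace(f)

# Rewrite the graph's output node so the return container is always a tuple.
output_node = next(n for n in gm.graph.nodes if n.op == "output")
(vals,) = output_node.args
if isinstance(vals, list):
    output_node.args = (tuple(vals),)
gm.recompile()

print(gm.code)  # ... return (sin,)
```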

Why was the original PR reverted?

There was an inference performance regression for the hf_Reformer model:
```
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python benchmarks/dynamo/torchbench.py --performance --inference --bfloat16 --backend inductor --device cuda --only hf_Reformer --cold-start-latency --use-eval-mode
```

Because one of the compiled graphs contained outputs that are aliases of inputs which are `nn.Parameter`s (`requires_grad=True`).

Even though the torchbench inference benchmark runs under `torch.no_grad()`, alias ops (specifically `expand` in hf_Reformer) preserve `requires_grad`, as the eager-mode snippet below shows.
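For instance, in plain eager PyTorch (no compilation involved):

```
import torch

p = torch.nn.Parameter(torch.ones(2, 3))  # requires_grad=True
with torch.no_grad():
    y = p.sin()            # regular op: output does not require grad
    v = p.expand(4, 2, 3)  # alias/view op: output is still a view of p

print(y.requires_grad)  # False
print(v.requires_grad)  # True -- views report the requires_grad of their base
```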

As a result, we started compiling a training graph instead of an inference graph.

Fix for view ops:

If we have outputs that are aliases of inputs with `requires_grad`, the fact that those outputs require grad is not a reason to generate a training graph.

This is handled in aot_autograd.py, where `output_and_mutation_safe` is calculated; a sketch of the rule follows.
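A self-contained paraphrase of that check, using hypothetical stand-in types (the real `OutputType` and `fw_metadata` structures are the ones visible in the aot_autograd.py hunk at the bottom of this page):

```
from dataclasses import dataclass
from enum import Enum, auto

class OutputType(Enum):   # stand-in for the real enum
    non_alias = auto()
    alias_of_input = auto()
    is_input = auto()

@dataclass
class OutputInfo:         # stand-in for an fw_metadata.output_info entry
    requires_grad: bool
    output_type: OutputType
    base_idx: int = 0

def forces_training_graph(out, input_requires_grad):
    # An output pushes us toward a training graph only if it requires grad AND
    # is not merely an alias of an input that itself requires grad; such
    # aliases are regenerated by view-replay at runtime.
    if not out.requires_grad:
        return False
    if out.output_type in (OutputType.alias_of_input, OutputType.is_input):
        return not input_requires_grad[out.base_idx]
    return True

# x.view(-1) of a requires_grad input does not, by itself, force a training graph:
assert not forces_training_graph(
    OutputInfo(True, OutputType.alias_of_input, base_idx=0), [True]
)
```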

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128890
Approved by: https://github.com/bdhirsh
Author: IvanKobzarev, 2024-07-30 12:21:32 -07:00; committed by PyTorch MergeBot
commit a94e507c39 (parent e9d1c26275)
3 changed files with 80 additions and 13 deletions

Changes in test_aotdispatch.py:

```
@@ -1547,8 +1547,8 @@ def forward(self, arg0_1):
         self.assertExpectedInline(
             fw_graph.code.strip(),
             """\
-def forward(self, primals_1):
-    view = torch.ops.aten.view.default(primals_1, [-1]);  primals_1 = None
+def forward(self, arg0_1):
+    view = torch.ops.aten.view.default(arg0_1, [-1]);  arg0_1 = None
     return (view,)""",
         )
@@ -2199,12 +2199,11 @@ def forward(self, primals_1):
         self.assertExpectedInline(
             fw_graph.code.strip(),
             """\
-def forward(self, primals_1, primals_2):
-    view = torch.ops.aten.view.default(primals_1, [3, 3]);  primals_1 = None
-    t = torch.ops.aten.t.default(view);  view = None
-    view_1 = torch.ops.aten.view.default(primals_2, [3, 3]);  primals_2 = None
-    view_2 = torch.ops.aten.view.default(t, [3, 3])
-    return (t, view_1, view_2)""",
+def forward(self, arg0_1, arg1_1):
+    t = torch.ops.aten.t.default(arg0_1);  arg0_1 = None
+    view = torch.ops.aten.view.default(arg1_1, [3, 3]);  arg1_1 = None
+    view_1 = torch.ops.aten.view.default(t, [3, 3])
+    return (t, view, view_1)""",
         )

     def test_view_detach(self):
@@ -5744,6 +5743,59 @@ def forward(self, tangents_1, tangents_2):
         self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
         self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)

+    def test_aot_dispatch_output_requires_grad_in_no_grad(self):
+        def fn(x):
+            out1 = x.sin()
+            with torch.enable_grad():
+                out2 = x.cos()
+            return out1, out2
+
+        inp_fns = [
+            lambda: torch.ones(10, requires_grad=True),
+            lambda: torch.ones(10, requires_grad=False),
+        ]
+
+        compiled_f = aot_function(fn, nop)
+        for inp_fn in inp_fns:
+            with torch.no_grad():
+                ref_x = inp_fn()
+                ref_out = fn(ref_x)
+
+                x = inp_fn()
+                out = compiled_f(x)
+                for r, o in zip(ref_out, out):
+                    self.assertEqual(r.requires_grad, o.requires_grad)
+
+            if ref_x.requires_grad:
+                with torch.enable_grad():
+                    (ref_out[0] + ref_out[1]).sum().backward()
+                    (out[0] + out[1]).sum().backward()
+                self.assertEqual(ref_x.grad, x.grad)
+                assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3)
+
+    def test_aot_dispatch_output_requires_grad_in_no_grad_views(self):
+        # view-type ops preserve requires_grad even in no_grad.
+        def fn(x):
+            return x.view(-1), x.sin()
+
+        inference_graph_cell = [None]
+        inference_compiler = make_boxed_compiler(
+            partial(extract_graph, graph_cell=inference_graph_cell)
+        )
+        compiled_fn = aot_function(fn, nop, inference_compiler=inference_compiler)
+
+        inp_x0 = torch.ones(2, 3, requires_grad=True)
+        # Clone in no_grad will make requires_grad=False tensors, keep clone outside of no_grad
+        ref_x0 = inp_x0.clone()
+        x0 = inp_x0.clone()
+
+        with torch.no_grad():
+            ref_out1, ref_out2 = fn(ref_x0)
+
+            out1, out2 = compiled_fn(x0)
+            # Assert that we executed inference graph
+            self.assertTrue(inference_graph_cell[0] is not None)
+
+            self.assertEqual(ref_out1.requires_grad, out1.requires_grad)
+            self.assertEqual(ref_out2.requires_grad, out2.requires_grad)
+

 class TestAOTModuleSimplified(AOTTestCase):
     def test_aot_module_simplified(self):
```

Changes in the runtime wrapper (`_create_runtime_wrapper`):

```
@@ -310,7 +310,13 @@ def _create_runtime_wrapper(
         for idx in indices_of_inps_to_detach:
             if isinstance(args_[idx], torch.Tensor):
                 args_[idx] = args_[idx].detach()
-        with torch.autograd._force_original_view_tracking(True):
+
+        # It's possible to have trace_joint inside a user-specified no_grad() region,
+        # if there is a nested enable_grad() that forces some outputs to require gradients.
+        # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution.
+        with torch.autograd._force_original_view_tracking(
+            True
+        ), torch.enable_grad():
             all_outs = call_func_at_runtime_with_args(
                 compiled_fn, args_, disable_amp=disable_amp, steal_args=True
             )
```

Changes in aot_autograd.py (`create_aot_dispatcher_function`):

```
@@ -572,9 +572,8 @@ def create_aot_dispatcher_function(
         fake_flat_args = process_inputs(flat_args)

-        needs_autograd = (
-            any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
-            and torch.is_grad_enabled()
+        needs_autograd = any(
+            x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
         )

         with enable_python_dispatcher():
@@ -600,7 +599,17 @@
             )

             output_and_mutation_safe = not any(
-                x.requires_grad for x in fw_metadata.output_info
+                x.requires_grad
+                # view-type operations preserve requires_grad even in no_grad.
+                # Do not count aliases of inputs with requires_grad as a reason to make a training graph,
+                # since AOTAutograd will perform view-replay to regenerate the view outputs at runtime,
+                # setting their grad_fn properly.
+                and not (
+                    x.output_type
+                    in (OutputType.alias_of_input, OutputType.is_input)
+                    and fw_metadata.input_info[x.base_idx].requires_grad
+                )
+                for x in fw_metadata.output_info
             ) and not any(
                 x.requires_grad
                 and x.mutates_data
```