[aota] Needs autograd if an input requires_grad, agnostic to enable_grad (#128890)
Original issue: https://github.com/pytorch/pytorch/issues/114338
Reland of: https://github.com/pytorch/pytorch/pull/128016

Summary from the previous PR:

We assume only two possible, mutually exclusive scenarios:

1. Running the compiled region for training (any input has requires_grad): produced differentiable outputs should have requires_grad.
2. Running the compiled region for inference (no input has requires_grad): no output has requires_grad.

Even if the user runs the region under `no_grad()` but passes an input tensor with requires_grad, we go with the training scenario (1). With the current state that means:

1. `needs_autograd` should not check `torch.is_grad_enabled()`, only whether any input requires_grad.
2. If `needs_autograd`, then `trace_joint` (we are in training scenario 1), so the compiled region always runs under `enable_grad()`.

Changes in the partitioner? Inference and training graphs differed in their return container (list vs. tuple). The partitioner changes unify this to always return a tuple. As a result, some expected graph contents in test_aotdispatch.py change from list to tuple.

Why was it reverted? There was an inference regression of the hf_Reformer model:

```
TORCHINDUCTOR_FX_GRAPH_CACHE=0 python benchmarks/dynamo/torchbench.py --performance --inference --bfloat16 --backend inductor --device cuda --only hf_Reformer --cold-start-latency --use-eval-mode
```

One of the compiled graphs contained outputs that are aliases of inputs which are nn.Parameter(requires_grad=True). Even though the torchbench inference benchmark runs inside `torch.no_grad()`, alias ops (specifically `expand`, in hf_Reformer's case) preserve requires_grad. As a result we started compiling a training graph instead of an inference graph.

Fix for view ops: if an output is an alias of an input that requires grad, that output's requires_grad is not by itself a reason to generate a training graph. This is handled in aot_autograd.py, where output_and_mutation_safe is calculated.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128890
Approved by: https://github.com/bdhirsh
Parent: e9d1c26275
Commit: a94e507c39
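For context, the behavior the fix targets can be reproduced in plain eager mode: alias/view ops keep the requires_grad of their base even inside `no_grad()`, while regular ops do not. A minimal illustration (not part of this PR's code):

```python
import torch

x = torch.ones(3, requires_grad=True)

with torch.no_grad():
    out = x.sin()       # regular op: output does not require grad under no_grad()
    alias = x.view(-1)  # alias/view op: output still reports requires_grad=True

print(out.requires_grad)    # False
print(alias.requires_grad)  # True -- views inherit requires_grad from their base
```

This is why an inference run of hf_Reformer under `torch.no_grad()` could still produce graph outputs that require grad, and why such alias outputs must not by themselves trigger a training graph.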
@@ -1547,8 +1547,8 @@ def forward(self, arg0_1):
         self.assertExpectedInline(
             fw_graph.code.strip(),
             """\
-def forward(self, primals_1):
-    view = torch.ops.aten.view.default(primals_1, [-1]); primals_1 = None
+def forward(self, arg0_1):
+    view = torch.ops.aten.view.default(arg0_1, [-1]); arg0_1 = None
     return (view,)""",
         )
 
@@ -2199,12 +2199,11 @@ def forward(self, primals_1):
         self.assertExpectedInline(
             fw_graph.code.strip(),
             """\
-def forward(self, primals_1, primals_2):
-    view = torch.ops.aten.view.default(primals_1, [3, 3]); primals_1 = None
-    t = torch.ops.aten.t.default(view); view = None
-    view_1 = torch.ops.aten.view.default(primals_2, [3, 3]); primals_2 = None
-    view_2 = torch.ops.aten.view.default(t, [3, 3])
-    return (t, view_1, view_2)""",
+def forward(self, arg0_1, arg1_1):
+    t = torch.ops.aten.t.default(arg0_1); arg0_1 = None
+    view = torch.ops.aten.view.default(arg1_1, [3, 3]); arg1_1 = None
+    view_1 = torch.ops.aten.view.default(t, [3, 3])
+    return (t, view, view_1)""",
         )
 
     def test_view_detach(self):
@@ -5744,6 +5743,59 @@ def forward(self, tangents_1, tangents_2):
         self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
         self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
 
+    def test_aot_dispatch_output_requires_grad_in_no_grad(self):
+        def fn(x):
+            out1 = x.sin()
+            with torch.enable_grad():
+                out2 = x.cos()
+            return out1, out2
+
+        inp_fns = [
+            lambda: torch.ones(10, requires_grad=True),
+            lambda: torch.ones(10, requires_grad=False),
+        ]
+
+        compiled_f = aot_function(fn, nop)
+        for inp_fn in inp_fns:
+            with torch.no_grad():
+                ref_x = inp_fn()
+                ref_out = fn(ref_x)
+                x = inp_fn()
+                out = compiled_f(x)
+                for r, o in zip(ref_out, out):
+                    self.assertEqual(r.requires_grad, o.requires_grad)
+            if ref_x.requires_grad:
+                with torch.enable_grad():
+                    (ref_out[0] + ref_out[1]).sum().backward()
+                    (out[0] + out[1]).sum().backward()
+                self.assertEqual(ref_x.grad, x.grad)
+                assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3)
+
+    def test_aot_dispatch_output_requires_grad_in_no_grad_views(self):
+        # view-type ops preserve requires_grad even in no_grad.
+        def fn(x):
+            return x.view(-1), x.sin()
+
+        inference_graph_cell = [None]
+        inference_compiler = make_boxed_compiler(
+            partial(extract_graph, graph_cell=inference_graph_cell)
+        )
+        compiled_fn = aot_function(fn, nop, inference_compiler=inference_compiler)
+
+        inp_x0 = torch.ones(2, 3, requires_grad=True)
+        # Clone in no_grad will make requires_grad=False tensors, keep clone outside of no_grad
+        ref_x0 = inp_x0.clone()
+        x0 = inp_x0.clone()
+        with torch.no_grad():
+            ref_out1, ref_out2 = fn(ref_x0)
+
+            out1, out2 = compiled_fn(x0)
+            # Assert that we executed inference graph
+            self.assertTrue(inference_graph_cell[0] is not None)
+
+            self.assertEqual(ref_out1.requires_grad, out1.requires_grad)
+            self.assertEqual(ref_out2.requires_grad, out2.requires_grad)
+
 
 class TestAOTModuleSimplified(AOTTestCase):
     def test_aot_module_simplified(self):
@@ -310,7 +310,13 @@ def _create_runtime_wrapper(
             for idx in indices_of_inps_to_detach:
                 if isinstance(args_[idx], torch.Tensor):
                     args_[idx] = args_[idx].detach()
-            with torch.autograd._force_original_view_tracking(True):
+
+            # It's possible to have trace_joint inside user specified with no_grad() region,
+            # if there is a nested with enable_grad(), that forces some outputs to require gradients.
+            # Therefore, we unconditionally turn on enable_grad() for compiled_fn execution.
+            with torch.autograd._force_original_view_tracking(
+                True
+            ), torch.enable_grad():
                 all_outs = call_func_at_runtime_with_args(
                     compiled_fn, args_, disable_amp=disable_amp, steal_args=True
                 )
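The rationale spelled out in the new comments above can be seen without AOTAutograd at all: a nested `enable_grad()` inside a function that the caller wraps in `no_grad()` re-enables gradient tracking for that part of the computation. A minimal eager-mode illustration mirroring the new test (not part of the diff itself):

```python
import torch

def fn(x):
    out1 = x.sin()
    with torch.enable_grad():  # nested inside the caller's no_grad() region
        out2 = x.cos()
    return out1, out2

x = torch.ones(10, requires_grad=True)
with torch.no_grad():
    out1, out2 = fn(x)

print(out1.requires_grad)  # False: computed while grad mode was disabled
print(out2.requires_grad)  # True: the nested enable_grad() re-enabled tracking
```

Because such outputs can require grad even though the compiled call happens under `no_grad()`, the runtime wrapper now executes the compiled function under `enable_grad()` unconditionally when a joint graph was traced.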
@@ -572,9 +572,8 @@ def create_aot_dispatcher_function(
         fake_flat_args = process_inputs(flat_args)
 
-        needs_autograd = (
-            any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
-            and torch.is_grad_enabled()
+        needs_autograd = any(
+            x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
         )
 
         with enable_python_dispatcher():
@@ -600,7 +599,17 @@ def create_aot_dispatcher_function(
                 )
 
                 output_and_mutation_safe = not any(
-                    x.requires_grad for x in fw_metadata.output_info
+                    x.requires_grad
+                    # view-type operations preserve requires_grad even in no_grad.
+                    # Do not count aliases of inputs with requires_grad as reason to make a training graph,
+                    # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime,
+                    # setting their grad_fn properly.
+                    and not (
+                        x.output_type
+                        in (OutputType.alias_of_input, OutputType.is_input)
+                        and fw_metadata.input_info[x.base_idx].requires_grad
+                    )
+                    for x in fw_metadata.output_info
                 ) and not any(
                     x.requires_grad
                     and x.mutates_data
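For readers skimming the hunk above, the new condition can be restated as a standalone predicate. This is a simplified sketch with illustrative stand-in types (`InputSpec`, `OutputSpec`, and `outputs_safe_for_inference` are hypothetical names, not the real ViewAndMutationMeta structures); only the fields visible in the diff are modeled:

```python
from dataclasses import dataclass
from enum import Enum, auto
from typing import List, Optional


class OutputType(Enum):
    non_alias = auto()
    alias_of_input = auto()
    is_input = auto()


@dataclass
class InputSpec:
    requires_grad: bool


@dataclass
class OutputSpec:
    requires_grad: bool
    output_type: OutputType
    base_idx: Optional[int] = None  # index of the aliased input, if any


def outputs_safe_for_inference(outputs: List[OutputSpec], inputs: List[InputSpec]) -> bool:
    """An output forces a training graph only if it requires grad AND is not
    merely an alias of a requires_grad input. Alias outputs are regenerated at
    runtime via view-replay, which sets their grad_fn correctly, so they do not
    need a joint graph."""
    return not any(
        out.requires_grad
        and not (
            out.output_type in (OutputType.alias_of_input, OutputType.is_input)
            and inputs[out.base_idx].requires_grad
        )
        for out in outputs
    )
```

The full check in aot_autograd.py additionally requires that no input both requires grad and has its data mutated (the `and not any(... x.mutates_data ...)` clause that continues past the end of this hunk).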