Revert "[aota] Needs autograd if an input requires_grad, agnostic to enable_grad (#128890)"

This reverts commit 08d5423d33.

Reverted https://github.com/pytorch/pytorch/pull/128890 on behalf of https://github.com/clee2000 because it broke inductor/test_flex_attention (https://github.com/pytorch/pytorch/actions/runs/9879109008/job/27286339304, 08d5423d33); the failing test was not run on the PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/128890#issuecomment-2221368245))
PyTorch MergeBot 2024-07-10 20:22:24 +00:00
parent 1b3b4c2fb9
commit b81767161e
5 changed files with 60 additions and 128 deletions

View File

@@ -509,7 +509,7 @@ def forward(self, primals_1):
 wait_tensor = torch.ops._c10d_functional.wait_tensor.default(primals_1)
 sin = torch.ops.aten.sin.default(wait_tensor)
 sin_1 = torch.ops.aten.sin.default(sin); sin = None
-return (sin_1, primals_1, wait_tensor)""",
+return [sin_1, primals_1, wait_tensor]""",
 )
 @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")

View File

@@ -572,7 +572,7 @@ def forward(self, primals_1):
 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
 mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None
 mul_1 = torch.ops.aten.mul.Tensor(mul, 3)
-return (mul, mul_1)""",
+return [mul, mul_1]""",
 )
 def test_input_mutation_set__input_mutation(self):
@@ -642,7 +642,7 @@ def forward(self, primals_1, primals_2):
 add = torch.ops.aten.add.Tensor(mul, mul)
 set_ = torch.ops.aten.set_.source_Tensor(primals_1, mul); primals_1 = None
 copy_ = torch.ops.aten.copy_.default(primals_2, mul); primals_2 = mul = None
-return (add,)""",
+return [add]""",
 )
 # This is a (hopefully) extremely rare case that is difficult to handle,
@@ -721,7 +721,7 @@ def forward(self, primals_1):
 alias = torch.ops.aten.alias.default(primals_1); primals_1 = None
 view = torch.ops.aten.view.default(arange, [3, 3]); arange = None
 add = torch.ops.aten.add.Tensor(alias, view); alias = view = None
-return (add,)""",
+return [add]""",
 )
 def test_input_mutation_simple_with_none_and_nontensor(self):
@@ -966,7 +966,7 @@ def forward(self, primals_1):
 def forward(self, primals_1, primals_2, primals_3):
 t = torch.ops.aten.t.default(primals_1); primals_1 = None
 addmm = torch.ops.aten.addmm.default(primals_2, primals_3, t); primals_2 = None
-return (addmm, primals_3, t)""",
+return [addmm, primals_3, t]""",
 )
 with torch.inference_mode():
@@ -1017,7 +1017,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
 def forward(self, primals_1):
 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
 mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None
-return (mul, mul)""",
+return [mul, mul]""",
 )
 def test_input_mutation_multiple(self):
@@ -1046,7 +1046,7 @@ def forward(self, primals_1, primals_2, primals_3):
 mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None
 add = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None
 add_1 = torch.ops.aten.add.Tensor(add, mul_1); add = None
-return (mul, mul_1, add_1)""",
+return [mul, mul_1, add_1]""",
 )
 def test_input_mutation_return(self):
@@ -1117,7 +1117,7 @@ def forward(self, primals_1):
 copy = torch.ops.aten.copy.default(primals_1, ones); ones = None
 add = torch.ops.aten.add.Tensor(copy, 1)
 copy_ = torch.ops.aten.copy_.default(primals_1, copy); primals_1 = copy = None
-return (add,)""",
+return [add]""",
 )
 def test_input_mutation_storage_resize_down(self):
@@ -1150,7 +1150,7 @@ def forward(self, primals_1):
 def forward(self, primals_1):
 sin = torch.ops.aten.sin.default(primals_1)
 resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 0)
-return (sin, primals_1)""",
+return [sin, primals_1]""",
 )
 # def test_input_mutation_storage_resize_up_down(self):
@@ -1251,7 +1251,7 @@ def forward(self, primals_1, primals_2):
 sin = torch.ops.aten.sin.default(cat)
 resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(cat, 0)
 set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat); primals_1 = None
-return (sin, cat)""",
+return [sin, cat]""",
 )
 def test_input_mutation_storage_resize_before_set_(self):
@@ -1347,7 +1347,7 @@ def forward(self, primals_1):
 view_1 = torch.ops.aten.view.default(mul, [4]); mul = None
 add = torch.ops.aten.add.Tensor(view_1, 1)
 copy_ = torch.ops.aten.copy_.default(primals_1, view_1); primals_1 = view_1 = None
-return (add,)""",
+return [add]""",
 )
 def test_input_mutation_requires_grad_no_grad(self):
@@ -1369,7 +1369,7 @@ def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 2)
 add = torch.ops.aten.add.Tensor(mul, 3)
 copy_ = torch.ops.aten.copy_.default(primals_1, mul); primals_1 = mul = None
-return (add,)""",
+return [add]""",
 )
 def test_input_mutation_requires_grad_no_grad_inference_graph(self):
@@ -1493,9 +1493,9 @@ def forward(self, arg0_1):
 self.assertExpectedInline(
 fw_graph.code.strip(),
 """\
-def forward(self, arg0_1):
-view = torch.ops.aten.view.default(arg0_1, [-1]); arg0_1 = None
-return (view,)""",
+def forward(self, primals_1):
+view = torch.ops.aten.view.default(primals_1, [-1]); primals_1 = None
+return [view]""",
 )
 def test_input_output_view_mutate_multiple(self):
@@ -1527,7 +1527,7 @@ def forward(self, primals_1, primals_2, primals_3):
 mul_1 = torch.ops.aten.mul.Tensor(clone_1, 3); clone_1 = None
 view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None
 view_2 = torch.ops.aten.view.default(mul_1, [2, 2])
-return (mul, mul_1, view, view_2)""",
+return [mul, mul_1, view, view_2]""",
 )
 def test_input_output_view_metadata_mutate_multiple(self):
@@ -1561,7 +1561,7 @@ def forward(self, primals_1, primals_2, primals_3):
 view_1 = torch.ops.aten.view.default(primals_1, [2, 2]); primals_1 = None
 view_3 = torch.ops.aten.view.default(t, [2, 2])
 view_4 = torch.ops.aten.view.default(mul, [2, 2])
-return (mul, t, view_1, view_4, view_3)""",
+return [mul, t, view_1, view_4, view_3]""",
 )
 def test_input_mutation_and_output_view(self):
@@ -1583,7 +1583,7 @@ def forward(self, primals_1):
 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
 add = torch.ops.aten.add.Tensor(clone, 1); clone = None
 view_1 = torch.ops.aten.view.default(add, [-1])
-return (add, view_1)""",
+return [add, view_1]""",
 )
 def test_input_mutation_output_view_multiple(self):
@@ -1617,7 +1617,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4):
 add_1 = torch.ops.aten.add.Tensor(primals_4, 1); primals_4 = None
 diagonal = torch.ops.aten.diagonal.default(transpose)
 add_2 = torch.ops.aten.add.Tensor(primals_1, add); primals_1 = None
-return (transpose, add, add_1, diagonal, add_2)""",
+return [transpose, add, add_1, diagonal, add_2]""",
 )
 def test_output_aliases_intermediate_single(self):
@@ -1637,7 +1637,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4):
 def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1]); mul = None
-return (view,)""",
+return [view]""",
 )
 def test_output_aliases_input_multi_output_view_should_raise_autograd_error(self):
@@ -1872,7 +1872,7 @@ def forward(self, primals_1, primals_2):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1]); mul = None
 add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
-return (view, add)""",
+return [view, add]""",
 )
 def test_output_aliases_intermediate_returned_multiple_times(self):
@@ -1903,7 +1903,7 @@ def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1])
 view_1 = torch.ops.aten.view.default(mul, [-1])
-return (view, view_1, mul)""",
+return [view, view_1, mul]""",
 )
 def test_output_aliases_intermediate_and_returned(self):
@@ -1923,7 +1923,7 @@ def forward(self, primals_1):
 def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1])
-return (view, mul)""",
+return [view, mul]""",
 )
 def test_output_aliases_intermediate_and_returned_flipped(self):
@@ -1943,7 +1943,7 @@ def forward(self, primals_1):
 def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1])
-return (mul, view)""",
+return [mul, view]""",
 )
 def test_output_aliases_intermediate_and_returned_different_grad(self):
@@ -1967,7 +1967,7 @@ def forward(self, primals_1):
 detach = torch.ops.aten.detach.default(select); select = None
 detach_1 = torch.ops.aten.detach.default(detach); detach = None
 detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None
-return (view, mul, detach_2)""",
+return [view, mul, detach_2]""",
 )
 def test_output_aliases_intermediate_inplace_view(self):
@@ -2003,7 +2003,7 @@ def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3)
 t = torch.ops.aten.t.default(mul); mul = None
 add = torch.ops.aten.add.Tensor(primals_1, 1); primals_1 = None
-return (t, add)""",
+return [t, add]""",
 )
 def test_output_aliases_intermediate_inplace_view_and_view(self):
@@ -2040,7 +2040,7 @@ def forward(self, primals_1):
 view = torch.ops.aten.view.default(mul, [-1])
 transpose = torch.ops.aten.transpose.int(mul_1, 1, 0); mul_1 = None
 transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0)
-return (view, transpose, transpose_1, mul)""",
+return [view, transpose, transpose_1, mul]""",
 )
 def test_output_all_alias_types(self):
@@ -2076,7 +2076,7 @@ def forward(self, primals_1):
 squeeze = torch.ops.aten.squeeze.default(mul)
 transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0)
 unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0)
-return (transpose, squeeze, transpose_1, unsqueeze, mul)""",
+return [transpose, squeeze, transpose_1, unsqueeze, mul]""",
 )
 @parametrize("req_grad", [False, True])
@@ -2126,7 +2126,7 @@ def forward(self, primals_1):
 t_4 = torch.ops.aten.t.default(t_2)
 t_6 = torch.ops.aten.t.default(t_2); t_2 = None
 view_1 = torch.ops.aten.view.default(t_6, [3, 3]); t_6 = None
-return (t_4, view_1)""",
+return [t_4, view_1]""",
 )
 def test_view_and_inplace_view(self):
@@ -2145,11 +2145,12 @@ def forward(self, primals_1):
 self.assertExpectedInline(
 fw_graph.code.strip(),
 """\
-def forward(self, arg0_1, arg1_1):
-t = torch.ops.aten.t.default(arg0_1); arg0_1 = None
-view = torch.ops.aten.view.default(arg1_1, [3, 3]); arg1_1 = None
-view_1 = torch.ops.aten.view.default(t, [3, 3])
-return (t, view, view_1)""",
+def forward(self, primals_1, primals_2):
+view = torch.ops.aten.view.default(primals_1, [3, 3]); primals_1 = None
+t = torch.ops.aten.t.default(view); view = None
+view_1 = torch.ops.aten.view.default(primals_2, [3, 3]); primals_2 = None
+view_2 = torch.ops.aten.view.default(t, [3, 3])
+return [t, view_1, view_2]""",
 )
 def test_view_detach(self):
@@ -2181,7 +2182,7 @@ def forward(self, arg0_1, arg1_1):
 def forward(self, primals_1, primals_2):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 mul_1 = torch.ops.aten.mul.Tensor(primals_2, 4); primals_2 = None
-return (mul, mul_1)""",
+return [mul, mul_1]""",
 )
 # This is a torture test:
@@ -2701,7 +2702,7 @@ def forward(self, primals_1):
 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
 as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
 add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
-return (as_strided_scatter, add_1)""",
+return [as_strided_scatter, add_1]""",
 ) # noqa: B950
 def test_input_mutation_aliases_other_input2(self):
@@ -2734,7 +2735,7 @@ def forward(self, primals_1):
 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
 as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
 add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
-return (as_strided_scatter, add_1)""",
+return [as_strided_scatter, add_1]""",
 ) # noqa: B950
 def test_input_mutation_aliases_and_output_alias(self):
@@ -2765,7 +2766,7 @@ def forward(self, primals_1):
 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
 as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
 view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
-return (as_strided_scatter, view_1)""",
+return [as_strided_scatter, view_1]""",
 ) # noqa: B950
 def test_input_aliased_with_mutation_output_alias(self):
@@ -2802,7 +2803,7 @@ def forward(self, primals_1, primals_2):
 add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
 as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
 view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None
-return (as_strided_scatter, add, view_1)""",
+return [as_strided_scatter, add, view_1]""",
 ) # noqa: B950
 def test_input_metadata_mutation_aliases(self):
@@ -2831,7 +2832,7 @@ def forward(self, primals_1, primals_2):
 def forward(self, primals_1, primals_2):
 t = torch.ops.aten.t.default(primals_1); primals_1 = None
 add = torch.ops.aten.add.Tensor(t, primals_2); t = primals_2 = None
-return (add,)""",
+return [add]""",
 )
 def test_input_mutation_aliases_and_none_require_gradients(self):
@@ -2874,7 +2875,7 @@ def forward(self, primals_1, primals_2):
 as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
 add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None
 add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
-return (as_strided_scatter, add, add_1)""",
+return [as_strided_scatter, add, add_1]""",
 ) # noqa: B950
 @skipIfDynamoInput("Fails with dynamo")
@@ -2933,7 +2934,7 @@ def forward(self, primals_1, primals_2, primals_3):
 add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None
 as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
 view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
-return (as_strided_scatter, add_2, view_2, unsqueeze_1)""",
+return [as_strided_scatter, add_2, view_2, unsqueeze_1]""",
 ) # noqa: B950
 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
@@ -3007,7 +3008,7 @@ def forward(self, primals_1, primals_2):
 view_1 = torch.ops.aten.view.default(add, [-1])
 t_1 = torch.ops.aten.t.default(t)
 unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0)
-return (as_strided_scatter, t, view_1, t_1, unsqueeze, add)""",
+return [as_strided_scatter, t, view_1, t_1, unsqueeze, add]""",
 ) # noqa: B950
 def test_dynamic_shape_output_not_in_bw_graph(self):
@@ -3035,7 +3036,7 @@ def forward(self, primals_1, primals_2):
 bw_graph_cell[0].code.strip(),
 """\
 def forward(self, tangents_1):
-return (tangents_1,)""",
+return [tangents_1]""",
 )
 def test_no_grad_input_output(self):
@@ -3555,7 +3556,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4):
 sum_2 = torch.ops.aten.sum.default(add)
 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
 copy_ = torch.ops.aten.copy_.default(primals_3, add); primals_3 = add = None
-return (add_1, primals_1, primals_2, primals_4, mul)""",
+return [add_1, primals_1, primals_2, primals_4, mul]""",
 )
 self.assertEqual(out_ref, out_test)
@@ -3610,7 +3611,7 @@ def forward(self, primals_1, primals_2, primals_3):
 sum_2 = torch.ops.aten.sum.default(add)
 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
 copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = None
-return (add_1, primals_1, primals_3)""",
+return [add_1, primals_1, primals_3]""",
 )
 self.assertEqual(out_ref, out_test)
@@ -3668,7 +3669,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals
 copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = None
 copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = None
 copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = None
-return (getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4)""", # noqa: B950
+return [getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4]""", # noqa: B950
 )
 self.assertEqual(out_ref, out_test)
@@ -3688,7 +3689,7 @@ def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem
 getitem_5 = native_batch_norm_backward[0]
 getitem_6 = native_batch_norm_backward[1]
 getitem_7 = native_batch_norm_backward[2]; native_batch_norm_backward = None
-return (getitem_6, getitem_7, None, None, None, getitem_5)""", # noqa: B950
+return [getitem_6, getitem_7, None, None, None, getitem_5]""", # noqa: B950
 )
 self.assertEqual(inp_ref.grad, inp_test.grad)
@@ -3729,7 +3730,7 @@ def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem
 def forward(self, primals_1, primals_2):
 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
 add = torch.ops.aten.add.Tensor(clone, primals_2); clone = primals_2 = None
-return (add, add)""",
+return [add, add]""",
 ) # noqa: B950
 self.assertEqual(out_ref, out_test)
@@ -3741,7 +3742,7 @@ def forward(self, primals_1, primals_2):
 bw_graph_cell[0].code.strip(),
 """\
 def forward(self, tangents_1):
-return (None, tangents_1)""",
+return [None, tangents_1]""",
 ) # noqa: B950
 def test_real_weights_in_symbolic_mode(self):
@@ -3829,7 +3830,6 @@ def forward(self, tangents_1):
 # since they are the only ones that will get reconstructed.
 def wrapper(g, *args, **kwargs):
     outs = g(*args, **kwargs)
-    outs = list(outs)
     for i in output_view_indices:
         outs[i] = NoViewReplayTensor(outs[i])
     return outs
@@ -5333,7 +5333,7 @@ def forward(self, primals_1, primals_2, primals_3):
 div = torch.ops.aten.div.Tensor(primals_3, 2); primals_3 = None
 add = torch.ops.aten.add.Tensor(mul, div); mul = None
 add_1 = torch.ops.aten.add.Tensor(mul_1, div); mul_1 = div = None
-return (add, add_1)""",
+return [add, add_1]""",
 )
 # Important pieces of the graph:
@@ -5351,7 +5351,7 @@ def forward(self, tangents_1, tangents_2):
 div_2 = torch.ops.aten.div.Tensor(tangents_2, 2)
 mul_2 = torch.ops.aten.mul.Tensor(tangents_1, 6); tangents_1 = None
 mul_3 = torch.ops.aten.mul.Tensor(tangents_2, 6); tangents_2 = None
-return (mul_2, mul_3, div_1, div_2)""",
+return [mul_2, mul_3, div_1, div_2]""",
 )
 def test_aot_dispatch_inference(self):
@@ -5666,59 +5666,6 @@ def forward(self, tangents_1, tangents_2):
 self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
 self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)
-def test_aot_dispatch_output_requires_grad_in_no_grad(self):
-    def fn(x):
-        out1 = x.sin()
-        with torch.enable_grad():
-            out2 = x.cos()
-        return out1, out2
-    inp_fns = [
-        lambda: torch.ones(10, requires_grad=True),
-        lambda: torch.ones(10, requires_grad=False),
-    ]
-    compiled_f = aot_function(fn, nop)
-    for inp_fn in inp_fns:
-        with torch.no_grad():
-            ref_x = inp_fn()
-            ref_out = fn(ref_x)
-            x = inp_fn()
-            out = compiled_f(x)
-            for r, o in zip(ref_out, out):
-                self.assertEqual(r.requires_grad, o.requires_grad)
-        if ref_x.requires_grad:
-            with torch.enable_grad():
-                (ref_out[0] + ref_out[1]).sum().backward()
-                (out[0] + out[1]).sum().backward()
-            self.assertEqual(ref_x.grad, x.grad)
-            assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3)
-def test_aot_dispatch_output_requires_grad_in_no_grad_views(self):
-    # view-type ops preserve requires_grad even in no_grad.
-    def fn(x):
-        return x.view(-1), x.sin()
-    inference_graph_cell = [None]
-    inference_compiler = make_boxed_compiler(
-        partial(extract_graph, graph_cell=inference_graph_cell)
-    )
-    compiled_fn = aot_function(fn, nop, inference_compiler=inference_compiler)
-    inp_x0 = torch.ones(2, 3, requires_grad=True)
-    # Clone in no_grad will make requires_grad=False tensors, keep clone outside of no_grad
-    ref_x0 = inp_x0.clone()
-    x0 = inp_x0.clone()
-    with torch.no_grad():
-        ref_out1, ref_out2 = fn(ref_x0)
-        out1, out2 = compiled_fn(x0)
-    # Assert that we executed inference graph
-    self.assertTrue(inference_graph_cell[0] is not None)
-    self.assertEqual(ref_out1.requires_grad, out1.requires_grad)
-    self.assertEqual(ref_out2.requires_grad, out2.requires_grad)
 class TestAOTModuleSimplified(AOTTestCase):
 def test_aot_module_simplified(self):

View File

@@ -304,13 +304,7 @@ def _create_runtime_wrapper(
 for idx in indices_of_inps_to_detach:
     if isinstance(args_[idx], torch.Tensor):
         args_[idx] = args_[idx].detach()
-# It's possible to have trace_joint inside user specified with no_grad() region,
-# if there is a nested with enable_grad(), that forces some outputs to require gradients.
-# Therefore, we unconditionally turn on enable_grad() for compiled_fn execution.
-with torch.autograd._force_original_view_tracking(
-    True
-), torch.enable_grad():
+with torch.autograd._force_original_view_tracking(True):
     all_outs = call_func_at_runtime_with_args(
         compiled_fn, args_, disable_amp=disable_amp, steal_args=True
     )
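
The comment removed above describes the pattern exercised by the tests deleted earlier in this diff: a compiled function can be called inside a user no_grad() region while containing a nested enable_grad() block, so some of its outputs still require gradients. A minimal sketch of that situation in plain PyTorch (no AOTAutograd involved), mirroring the deleted test's fn:

import torch

def fn(x):
    out1 = x.sin()  # created under the caller's no_grad(): does not require grad
    with torch.enable_grad():
        out2 = x.cos()  # grad mode re-enabled here: requires_grad follows x
    return out1, out2

x = torch.ones(10, requires_grad=True)
with torch.no_grad():
    out1, out2 = fn(x)
print(out1.requires_grad, out2.requires_grad)  # False True

The reverted PR handled this case by always running the compiled function under enable_grad(); the revert goes back to running it under whatever grad mode the caller has set.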

View File

@@ -566,8 +566,9 @@ def create_aot_dispatcher_function(
 fake_flat_args = process_inputs(flat_args)
-needs_autograd = any(
-    x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
-)
+needs_autograd = (
+    any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
+    and torch.is_grad_enabled()
+)
 with enable_python_dispatcher():
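
This hunk restores the original predicate: AOTAutograd only treats a compilation as needing autograd when some tensor input requires grad and grad mode is enabled at trace time; the reverted PR had dropped the grad-mode check. A rough side-by-side of the two predicates, with illustrative function names and fake_flat_args as in the surrounding code:

import torch
from torch import Tensor

def needs_autograd_with_revert(fake_flat_args):
    # Restored behavior: some input requires grad AND grad mode is enabled.
    return (
        any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
        and torch.is_grad_enabled()
    )

def needs_autograd_reverted_pr(fake_flat_args):
    # Behavior of the reverted PR: look only at the inputs, ignore the grad mode.
    return any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))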
@@ -592,17 +593,7 @@ def create_aot_dispatcher_function(
 )
 output_and_mutation_safe = not any(
-    x.requires_grad
-    # view-type operations preserve requires_grad even in no_grad.
-    # Do not count aliases of inputs with requires_grad as reason to make a training graph,
-    # as AOTAutograd will perform view-replay to regenerate the view outputs at runtime,
-    # setting their grad_fn properly.
-    and not (
-        x.output_type
-        in (OutputType.alias_of_input, OutputType.is_input)
-        and fw_metadata.input_info[x.base_idx].requires_grad
-    )
-    for x in fw_metadata.output_info
+    x.requires_grad for x in fw_metadata.output_info
 ) and not any(
     x.requires_grad
     and x.mutates_data
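
The removed comment explains why the reverted PR did not count outputs that merely alias a requires_grad input against output_and_mutation_safe: view-type ops preserve requires_grad even under no_grad, and AOTAutograd regenerates such outputs via view replay at runtime. The underlying autograd behavior can be checked with plain PyTorch, which is what the deleted test_aot_dispatch_output_requires_grad_in_no_grad_views asserted:

import torch

x = torch.ones(2, 3, requires_grad=True)
with torch.no_grad():
    v = x.view(-1)  # view of a tensor that requires grad
    s = x.sin()     # ordinary op under no_grad
print(v.requires_grad, s.requires_grad)  # True False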

View File

@@ -204,7 +204,7 @@ def _extract_graph_with_inputs_outputs(
     output_values.append(env[x])
 else:
     output_values.append(x)
-new_graph.output(tuple(output_values))
+new_graph.output(output_values)
 new_graph.eliminate_dead_code()
 new_graph.lint()
@@ -727,7 +727,7 @@ def functionalize_rng_ops(
 sym_node_start_idx = len(fw_outputs) - num_sym_nodes
 outputs = (
     fw_outputs[:sym_node_start_idx]
-    + tuple(fw_rng_state_outputs)
+    + fw_rng_state_outputs
     + fw_outputs[sym_node_start_idx:]
 )
 fw_module.graph.output(outputs)