Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "[aota] Needs autograd if an input requires_grad, agnostic to enable_grad (#128890)"
This reverts commit 08d5423d33. Reverted https://github.com/pytorch/pytorch/pull/128890 on behalf of https://github.com/clee2000 due to broke inductor/test_flex_attention https://github.com/pytorch/pytorch/actions/runs/9879109008/job/27286339304, 08d5423d33 test was not run on PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/128890#issuecomment-2221368245))
This commit is contained in:
parent 1b3b4c2fb9
commit b81767161e
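The reverted PR made AOTAutograd's "do we need an autograd graph?" decision depend only on whether any input requires grad, ignoring the grad mode at trace time; this revert restores the torch.is_grad_enabled() check (see the create_aot_dispatcher_function hunks below). For context, a minimal eager-mode sketch of the behavior the reverted change was targeting, adapted from the test_aot_dispatch_output_requires_grad_in_no_grad test that this revert deletes; tensor shapes and values here are illustrative:

import torch

def fn(x):
    out1 = x.sin()             # runs under the caller's no_grad
    with torch.enable_grad():  # nested enable_grad re-enables tracking
        out2 = x.cos()
    return out1, out2

x = torch.ones(10, requires_grad=True)
with torch.no_grad():
    out1, out2 = fn(x)

# Eager PyTorch: out1 does not require grad, out2 does, even though the call
# site is inside no_grad(). A compiled version has to reproduce this split.
print(out1.requires_grad, out2.requires_grad)  # False True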
@@ -509,7 +509,7 @@ def forward(self, primals_1):
 wait_tensor = torch.ops._c10d_functional.wait_tensor.default(primals_1)
 sin = torch.ops.aten.sin.default(wait_tensor)
 sin_1 = torch.ops.aten.sin.default(sin); sin = None
-return (sin_1, primals_1, wait_tensor)""",
+return [sin_1, primals_1, wait_tensor]""",
 )

 @unittest.skipIf(not has_triton(), "Inductor+gpu needs triton and recent GPU arch")

@@ -572,7 +572,7 @@ def forward(self, primals_1):
 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
 mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None
 mul_1 = torch.ops.aten.mul.Tensor(mul, 3)
-return (mul, mul_1)""",
+return [mul, mul_1]""",
 )

 def test_input_mutation_set__input_mutation(self):

@@ -642,7 +642,7 @@ def forward(self, primals_1, primals_2):
 add = torch.ops.aten.add.Tensor(mul, mul)
 set_ = torch.ops.aten.set_.source_Tensor(primals_1, mul); primals_1 = None
 copy_ = torch.ops.aten.copy_.default(primals_2, mul); primals_2 = mul = None
-return (add,)""",
+return [add]""",
 )

 # This is a (hopefully) extremely rare case that is difficult to handle,

@@ -721,7 +721,7 @@ def forward(self, primals_1):
 alias = torch.ops.aten.alias.default(primals_1); primals_1 = None
 view = torch.ops.aten.view.default(arange, [3, 3]); arange = None
 add = torch.ops.aten.add.Tensor(alias, view); alias = view = None
-return (add,)""",
+return [add]""",
 )

 def test_input_mutation_simple_with_none_and_nontensor(self):

@@ -966,7 +966,7 @@ def forward(self, primals_1):
 def forward(self, primals_1, primals_2, primals_3):
 t = torch.ops.aten.t.default(primals_1); primals_1 = None
 addmm = torch.ops.aten.addmm.default(primals_2, primals_3, t); primals_2 = None
-return (addmm, primals_3, t)""",
+return [addmm, primals_3, t]""",
 )

 with torch.inference_mode():

@@ -1017,7 +1017,7 @@ def forward(self, arg0_1, arg1_1, arg2_1):
 def forward(self, primals_1):
 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
 mul = torch.ops.aten.mul.Tensor(clone, 2); clone = None
-return (mul, mul)""",
+return [mul, mul]""",
 )

 def test_input_mutation_multiple(self):

@@ -1046,7 +1046,7 @@ def forward(self, primals_1, primals_2, primals_3):
 mul_1 = torch.ops.aten.mul.Tensor(clone_1, 2); clone_1 = None
 add = torch.ops.aten.add.Tensor(mul, primals_2); primals_2 = None
 add_1 = torch.ops.aten.add.Tensor(add, mul_1); add = None
-return (mul, mul_1, add_1)""",
+return [mul, mul_1, add_1]""",
 )

 def test_input_mutation_return(self):

@@ -1117,7 +1117,7 @@ def forward(self, primals_1):
 copy = torch.ops.aten.copy.default(primals_1, ones); ones = None
 add = torch.ops.aten.add.Tensor(copy, 1)
 copy_ = torch.ops.aten.copy_.default(primals_1, copy); primals_1 = copy = None
-return (add,)""",
+return [add]""",
 )

 def test_input_mutation_storage_resize_down(self):

@@ -1150,7 +1150,7 @@ def forward(self, primals_1):
 def forward(self, primals_1):
 sin = torch.ops.aten.sin.default(primals_1)
 resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(primals_1, 0)
-return (sin, primals_1)""",
+return [sin, primals_1]""",
 )

 # def test_input_mutation_storage_resize_up_down(self):

@@ -1251,7 +1251,7 @@ def forward(self, primals_1, primals_2):
 sin = torch.ops.aten.sin.default(cat)
 resize_storage_bytes_ = torch.ops.inductor.resize_storage_bytes_.default(cat, 0)
 set_ = torch.ops.aten.set_.source_Tensor(primals_1, cat); primals_1 = None
-return (sin, cat)""",
+return [sin, cat]""",
 )

 def test_input_mutation_storage_resize_before_set_(self):

@@ -1347,7 +1347,7 @@ def forward(self, primals_1):
 view_1 = torch.ops.aten.view.default(mul, [4]); mul = None
 add = torch.ops.aten.add.Tensor(view_1, 1)
 copy_ = torch.ops.aten.copy_.default(primals_1, view_1); primals_1 = view_1 = None
-return (add,)""",
+return [add]""",
 )

 def test_input_mutation_requires_grad_no_grad(self):

@@ -1369,7 +1369,7 @@ def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 2)
 add = torch.ops.aten.add.Tensor(mul, 3)
 copy_ = torch.ops.aten.copy_.default(primals_1, mul); primals_1 = mul = None
-return (add,)""",
+return [add]""",
 )

 def test_input_mutation_requires_grad_no_grad_inference_graph(self):
@@ -1493,9 +1493,9 @@ def forward(self, arg0_1):
 self.assertExpectedInline(
 fw_graph.code.strip(),
 """\
-def forward(self, arg0_1):
-view = torch.ops.aten.view.default(arg0_1, [-1]); arg0_1 = None
-return (view,)""",
+def forward(self, primals_1):
+view = torch.ops.aten.view.default(primals_1, [-1]); primals_1 = None
+return [view]""",
 )

 def test_input_output_view_mutate_multiple(self):

@@ -1527,7 +1527,7 @@ def forward(self, primals_1, primals_2, primals_3):
 mul_1 = torch.ops.aten.mul.Tensor(clone_1, 3); clone_1 = None
 view = torch.ops.aten.view.default(primals_2, [2, 2]); primals_2 = None
 view_2 = torch.ops.aten.view.default(mul_1, [2, 2])
-return (mul, mul_1, view, view_2)""",
+return [mul, mul_1, view, view_2]""",
 )

 def test_input_output_view_metadata_mutate_multiple(self):

@@ -1561,7 +1561,7 @@ def forward(self, primals_1, primals_2, primals_3):
 view_1 = torch.ops.aten.view.default(primals_1, [2, 2]); primals_1 = None
 view_3 = torch.ops.aten.view.default(t, [2, 2])
 view_4 = torch.ops.aten.view.default(mul, [2, 2])
-return (mul, t, view_1, view_4, view_3)""",
+return [mul, t, view_1, view_4, view_3]""",
 )

 def test_input_mutation_and_output_view(self):

@@ -1583,7 +1583,7 @@ def forward(self, primals_1):
 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
 add = torch.ops.aten.add.Tensor(clone, 1); clone = None
 view_1 = torch.ops.aten.view.default(add, [-1])
-return (add, view_1)""",
+return [add, view_1]""",
 )

 def test_input_mutation_output_view_multiple(self):

@@ -1617,7 +1617,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4):
 add_1 = torch.ops.aten.add.Tensor(primals_4, 1); primals_4 = None
 diagonal = torch.ops.aten.diagonal.default(transpose)
 add_2 = torch.ops.aten.add.Tensor(primals_1, add); primals_1 = None
-return (transpose, add, add_1, diagonal, add_2)""",
+return [transpose, add, add_1, diagonal, add_2]""",
 )

 def test_output_aliases_intermediate_single(self):

@@ -1637,7 +1637,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4):
 def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1]); mul = None
-return (view,)""",
+return [view]""",
 )

 def test_output_aliases_input_multi_output_view_should_raise_autograd_error(self):

@@ -1872,7 +1872,7 @@ def forward(self, primals_1, primals_2):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1]); mul = None
 add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
-return (view, add)""",
+return [view, add]""",
 )

 def test_output_aliases_intermediate_returned_multiple_times(self):

@@ -1903,7 +1903,7 @@ def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1])
 view_1 = torch.ops.aten.view.default(mul, [-1])
-return (view, view_1, mul)""",
+return [view, view_1, mul]""",
 )

 def test_output_aliases_intermediate_and_returned(self):

@@ -1923,7 +1923,7 @@ def forward(self, primals_1):
 def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1])
-return (view, mul)""",
+return [view, mul]""",
 )

 def test_output_aliases_intermediate_and_returned_flipped(self):

@@ -1943,7 +1943,7 @@ def forward(self, primals_1):
 def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 view = torch.ops.aten.view.default(mul, [-1])
-return (mul, view)""",
+return [mul, view]""",
 )

 def test_output_aliases_intermediate_and_returned_different_grad(self):
@@ -1967,7 +1967,7 @@ def forward(self, primals_1):
 detach = torch.ops.aten.detach.default(select); select = None
 detach_1 = torch.ops.aten.detach.default(detach); detach = None
 detach_2 = torch.ops.aten.detach.default(detach_1); detach_1 = None
-return (view, mul, detach_2)""",
+return [view, mul, detach_2]""",
 )

 def test_output_aliases_intermediate_inplace_view(self):

@@ -2003,7 +2003,7 @@ def forward(self, primals_1):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3)
 t = torch.ops.aten.t.default(mul); mul = None
 add = torch.ops.aten.add.Tensor(primals_1, 1); primals_1 = None
-return (t, add)""",
+return [t, add]""",
 )

 def test_output_aliases_intermediate_inplace_view_and_view(self):

@@ -2040,7 +2040,7 @@ def forward(self, primals_1):
 view = torch.ops.aten.view.default(mul, [-1])
 transpose = torch.ops.aten.transpose.int(mul_1, 1, 0); mul_1 = None
 transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0)
-return (view, transpose, transpose_1, mul)""",
+return [view, transpose, transpose_1, mul]""",
 )

 def test_output_all_alias_types(self):

@@ -2076,7 +2076,7 @@ def forward(self, primals_1):
 squeeze = torch.ops.aten.squeeze.default(mul)
 transpose_1 = torch.ops.aten.transpose.int(mul, 1, 0)
 unsqueeze = torch.ops.aten.unsqueeze.default(transpose, 0)
-return (transpose, squeeze, transpose_1, unsqueeze, mul)""",
+return [transpose, squeeze, transpose_1, unsqueeze, mul]""",
 )

 @parametrize("req_grad", [False, True])

@@ -2126,7 +2126,7 @@ def forward(self, primals_1):
 t_4 = torch.ops.aten.t.default(t_2)
 t_6 = torch.ops.aten.t.default(t_2); t_2 = None
 view_1 = torch.ops.aten.view.default(t_6, [3, 3]); t_6 = None
-return (t_4, view_1)""",
+return [t_4, view_1]""",
 )

 def test_view_and_inplace_view(self):

@@ -2145,11 +2145,12 @@ def forward(self, primals_1):
 self.assertExpectedInline(
 fw_graph.code.strip(),
 """\
-def forward(self, arg0_1, arg1_1):
-t = torch.ops.aten.t.default(arg0_1); arg0_1 = None
-view = torch.ops.aten.view.default(arg1_1, [3, 3]); arg1_1 = None
-view_1 = torch.ops.aten.view.default(t, [3, 3])
-return (t, view, view_1)""",
+def forward(self, primals_1, primals_2):
+view = torch.ops.aten.view.default(primals_1, [3, 3]); primals_1 = None
+t = torch.ops.aten.t.default(view); view = None
+view_1 = torch.ops.aten.view.default(primals_2, [3, 3]); primals_2 = None
+view_2 = torch.ops.aten.view.default(t, [3, 3])
+return [t, view_1, view_2]""",
 )

 def test_view_detach(self):

@@ -2181,7 +2182,7 @@ def forward(self, arg0_1, arg1_1):
 def forward(self, primals_1, primals_2):
 mul = torch.ops.aten.mul.Tensor(primals_1, 3); primals_1 = None
 mul_1 = torch.ops.aten.mul.Tensor(primals_2, 4); primals_2 = None
-return (mul, mul_1)""",
+return [mul, mul_1]""",
 )

 # This is a torture test:

@@ -2701,7 +2702,7 @@ def forward(self, primals_1):
 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
 as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
 add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
-return (as_strided_scatter, add_1)""",
+return [as_strided_scatter, add_1]""",
 ) # noqa: B950

 def test_input_mutation_aliases_other_input2(self):

@@ -2734,7 +2735,7 @@ def forward(self, primals_1):
 as_strided_2 = torch.ops.aten.as_strided.default(as_strided_scatter, [2], [1], 0)
 as_strided_5 = torch.ops.aten.as_strided.default(as_strided_scatter, [2, 2], [2, 1], 0)
 add_1 = torch.ops.aten.add.Tensor(as_strided_2, as_strided_5); as_strided_2 = as_strided_5 = None
-return (as_strided_scatter, add_1)""",
+return [as_strided_scatter, add_1]""",
 ) # noqa: B950

 def test_input_mutation_aliases_and_output_alias(self):
@@ -2765,7 +2766,7 @@ def forward(self, primals_1):
 as_strided_scatter = torch.ops.aten.as_strided_scatter.default(clone, add, [4], [1], 0); clone = add = None
 as_strided_8 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
 view_1 = torch.ops.aten.view.default(as_strided_8, [4]); as_strided_8 = None
-return (as_strided_scatter, view_1)""",
+return [as_strided_scatter, view_1]""",
 ) # noqa: B950

 def test_input_aliased_with_mutation_output_alias(self):

@@ -2802,7 +2803,7 @@ def forward(self, primals_1, primals_2):
 add = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
 as_strided_7 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
 view_1 = torch.ops.aten.view.default(as_strided_7, [-1]); as_strided_7 = None
-return (as_strided_scatter, add, view_1)""",
+return [as_strided_scatter, add, view_1]""",
 ) # noqa: B950

 def test_input_metadata_mutation_aliases(self):

@@ -2831,7 +2832,7 @@ def forward(self, primals_1, primals_2):
 def forward(self, primals_1, primals_2):
 t = torch.ops.aten.t.default(primals_1); primals_1 = None
 add = torch.ops.aten.add.Tensor(t, primals_2); t = primals_2 = None
-return (add,)""",
+return [add]""",
 )

 def test_input_mutation_aliases_and_none_require_gradients(self):

@@ -2874,7 +2875,7 @@ def forward(self, primals_1, primals_2):
 as_strided_3 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
 add = torch.ops.aten.add.Tensor(as_strided_3, 1); as_strided_3 = None
 add_1 = torch.ops.aten.add.Tensor(primals_2, 1); primals_2 = None
-return (as_strided_scatter, add, add_1)""",
+return [as_strided_scatter, add, add_1]""",
 ) # noqa: B950

 @skipIfDynamoInput("Fails with dynamo")

@@ -2933,7 +2934,7 @@ def forward(self, primals_1, primals_2, primals_3):
 add_2 = torch.ops.aten.add.Tensor(add_1, unsqueeze_1); add_1 = None
 as_strided_14 = torch.ops.aten.as_strided.default(as_strided_scatter, [4], [1], 0)
 view_2 = torch.ops.aten.view.default(as_strided_14, [-1]); as_strided_14 = None
-return (as_strided_scatter, add_2, view_2, unsqueeze_1)""",
+return [as_strided_scatter, add_2, view_2, unsqueeze_1]""",
 ) # noqa: B950

 @unittest.skipIf(not torch.cuda.is_available(), "CUDA is unavailable")

@@ -3007,7 +3008,7 @@ def forward(self, primals_1, primals_2):
 view_1 = torch.ops.aten.view.default(add, [-1])
 t_1 = torch.ops.aten.t.default(t)
 unsqueeze = torch.ops.aten.unsqueeze.default(view_1, 0)
-return (as_strided_scatter, t, view_1, t_1, unsqueeze, add)""",
+return [as_strided_scatter, t, view_1, t_1, unsqueeze, add]""",
 ) # noqa: B950

 def test_dynamic_shape_output_not_in_bw_graph(self):

@@ -3035,7 +3036,7 @@ def forward(self, primals_1, primals_2):
 bw_graph_cell[0].code.strip(),
 """\
 def forward(self, tangents_1):
-return (tangents_1,)""",
+return [tangents_1]""",
 )

 def test_no_grad_input_output(self):

@@ -3555,7 +3556,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4):
 sum_2 = torch.ops.aten.sum.default(add)
 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
 copy_ = torch.ops.aten.copy_.default(primals_3, add); primals_3 = add = None
-return (add_1, primals_1, primals_2, primals_4, mul)""",
+return [add_1, primals_1, primals_2, primals_4, mul]""",
 )

 self.assertEqual(out_ref, out_test)

@@ -3610,7 +3611,7 @@ def forward(self, primals_1, primals_2, primals_3):
 sum_2 = torch.ops.aten.sum.default(add)
 add_1 = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
 copy_ = torch.ops.aten.copy_.default(primals_2, add); primals_2 = add = None
-return (add_1, primals_1, primals_3)""",
+return [add_1, primals_1, primals_3]""",
 )
 self.assertEqual(out_ref, out_test)

@@ -3668,7 +3669,7 @@ def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5, primals
 copy_ = torch.ops.aten.copy_.default(primals_3, getitem_3); primals_3 = None
 copy__1 = torch.ops.aten.copy_.default(primals_4, getitem_4); primals_4 = None
 copy__2 = torch.ops.aten.copy_.default(primals_5, add); primals_5 = add = None
-return (getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4)""", # noqa: B950
+return [getitem, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem_4]""", # noqa: B950
 )

 self.assertEqual(out_ref, out_test)

@@ -3688,7 +3689,7 @@ def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem
 getitem_5 = native_batch_norm_backward[0]
 getitem_6 = native_batch_norm_backward[1]
 getitem_7 = native_batch_norm_backward[2]; native_batch_norm_backward = None
-return (getitem_6, getitem_7, None, None, None, getitem_5)""", # noqa: B950
+return [getitem_6, getitem_7, None, None, None, getitem_5]""", # noqa: B950
 )

 self.assertEqual(inp_ref.grad, inp_test.grad)
@@ -3729,7 +3730,7 @@ def forward(self, primals_1, primals_6, getitem_1, getitem_2, getitem_3, getitem
 def forward(self, primals_1, primals_2):
 clone = torch.ops.aten.clone.default(primals_1); primals_1 = None
 add = torch.ops.aten.add.Tensor(clone, primals_2); clone = primals_2 = None
-return (add, add)""",
+return [add, add]""",
 ) # noqa: B950

 self.assertEqual(out_ref, out_test)

@@ -3741,7 +3742,7 @@ def forward(self, primals_1, primals_2):
 bw_graph_cell[0].code.strip(),
 """\
 def forward(self, tangents_1):
-return (None, tangents_1)""",
+return [None, tangents_1]""",
 ) # noqa: B950

 def test_real_weights_in_symbolic_mode(self):

@@ -3829,7 +3830,6 @@ def forward(self, tangents_1):
 # since they are the only ones that will get reconstructed.
 def wrapper(g, *args, **kwargs):
 outs = g(*args, **kwargs)
-outs = list(outs)
 for i in output_view_indices:
 outs[i] = NoViewReplayTensor(outs[i])
 return outs

@@ -5333,7 +5333,7 @@ def forward(self, primals_1, primals_2, primals_3):
 div = torch.ops.aten.div.Tensor(primals_3, 2); primals_3 = None
 add = torch.ops.aten.add.Tensor(mul, div); mul = None
 add_1 = torch.ops.aten.add.Tensor(mul_1, div); mul_1 = div = None
-return (add, add_1)""",
+return [add, add_1]""",
 )

 # Important pieces of the graph:

@@ -5351,7 +5351,7 @@ def forward(self, tangents_1, tangents_2):
 div_2 = torch.ops.aten.div.Tensor(tangents_2, 2)
 mul_2 = torch.ops.aten.mul.Tensor(tangents_1, 6); tangents_1 = None
 mul_3 = torch.ops.aten.mul.Tensor(tangents_2, 6); tangents_2 = None
-return (mul_2, mul_3, div_1, div_2)""",
+return [mul_2, mul_3, div_1, div_2]""",
 )

 def test_aot_dispatch_inference(self):
@@ -5666,59 +5666,6 @@ def forward(self, tangents_1, tangents_2):
 self.assertEqual(a_ref_base.grad.a, a_test_base.grad.a)
 self.assertEqual(a_ref_base.grad.b, a_test_base.grad.b)

-def test_aot_dispatch_output_requires_grad_in_no_grad(self):
-def fn(x):
-out1 = x.sin()
-with torch.enable_grad():
-out2 = x.cos()
-return out1, out2
-
-inp_fns = [
-lambda: torch.ones(10, requires_grad=True),
-lambda: torch.ones(10, requires_grad=False),
-]
-
-compiled_f = aot_function(fn, nop)
-for inp_fn in inp_fns:
-with torch.no_grad():
-ref_x = inp_fn()
-ref_out = fn(ref_x)
-x = inp_fn()
-out = compiled_f(x)
-for r, o in zip(ref_out, out):
-self.assertEqual(r.requires_grad, o.requires_grad)
-if ref_x.requires_grad:
-with torch.enable_grad():
-(ref_out[0] + ref_out[1]).sum().backward()
-(out[0] + out[1]).sum().backward()
-self.assertEqual(ref_x.grad, x.grad)
-assert torch.allclose(ref_x.grad, x.grad, atol=1e-3, rtol=1e-3)
-
-def test_aot_dispatch_output_requires_grad_in_no_grad_views(self):
-# view-type ops preserve requires_grad even in no_grad.
-def fn(x):
-return x.view(-1), x.sin()
-
-inference_graph_cell = [None]
-inference_compiler = make_boxed_compiler(
-partial(extract_graph, graph_cell=inference_graph_cell)
-)
-compiled_fn = aot_function(fn, nop, inference_compiler=inference_compiler)
-
-inp_x0 = torch.ones(2, 3, requires_grad=True)
-# Clone in no_grad will make requires_grad=False tensors, keep clone outside of no_grad
-ref_x0 = inp_x0.clone()
-x0 = inp_x0.clone()
-with torch.no_grad():
-ref_out1, ref_out2 = fn(ref_x0)
-
-out1, out2 = compiled_fn(x0)
-# Assert that we executed inference graph
-self.assertTrue(inference_graph_cell[0] is not None)
-
-self.assertEqual(ref_out1.requires_grad, out1.requires_grad)
-self.assertEqual(ref_out2.requires_grad, out2.requires_grad)
-

 class TestAOTModuleSimplified(AOTTestCase):
 def test_aot_module_simplified(self):
@@ -304,13 +304,7 @@ def _create_runtime_wrapper(
 for idx in indices_of_inps_to_detach:
 if isinstance(args_[idx], torch.Tensor):
 args_[idx] = args_[idx].detach()
-
-# It's possible to have trace_joint inside user specified with no_grad() region,
-# if there is a nested with enable_grad(), that forces some outputs to require gradients.
-# Therefore, we unconditionally turn on enable_grad() for compiled_fn execution.
-with torch.autograd._force_original_view_tracking(
-True
-), torch.enable_grad():
+with torch.autograd._force_original_view_tracking(True):
 all_outs = call_func_at_runtime_with_args(
 compiled_fn, args_, disable_amp=disable_amp, steal_args=True
 )
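A simplified sketch of what the hunk above changes at runtime (this is not the actual AOTAutograd wrapper; compiled_fn and args are placeholders): the reverted commit ran the compiled function with grad mode forced on, so outputs produced by a nested enable_grad() region kept their autograd history even when the caller was under no_grad(); the revert goes back to only forcing original view tracking.

import torch

def run_reverted_variant(compiled_fn, args):
    # Behavior added by the reverted commit: unconditionally enable grad mode.
    with torch.autograd._force_original_view_tracking(True), torch.enable_grad():
        return compiled_fn(*args)

def run_restored_variant(compiled_fn, args):
    # Behavior restored by this revert: leave the caller's grad mode alone.
    with torch.autograd._force_original_view_tracking(True):
        return compiled_fn(*args)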
@@ -566,8 +566,9 @@ def create_aot_dispatcher_function(

 fake_flat_args = process_inputs(flat_args)

-needs_autograd = any(
-x.requires_grad for x in fake_flat_args if isinstance(x, Tensor)
+needs_autograd = (
+any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
+and torch.is_grad_enabled()
 )

 with enable_python_dispatcher():
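The hunk above is the core of the revert: needs_autograd once again requires grad mode to be enabled at trace time, not just an input with requires_grad=True. A standalone sketch of the two conditions (names follow the diff; fake_flat_args here is just a plain list of tensors):

import torch
from torch import Tensor

def needs_autograd_after_revert(fake_flat_args) -> bool:
    return (
        any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
        and torch.is_grad_enabled()
    )

def needs_autograd_reverted_pr(fake_flat_args) -> bool:
    # The reverted behavior: agnostic to the grad mode at trace time.
    return any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))

args = [torch.ones(3, requires_grad=True)]
with torch.no_grad():
    print(needs_autograd_after_revert(args))  # False -> inference graph
    print(needs_autograd_reverted_pr(args))   # True  -> training graph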
@@ -592,17 +593,7 @@ def create_aot_dispatcher_function(
 )

 output_and_mutation_safe = not any(
-x.requires_grad
-# view-type operations preserve requires_grad even in no_grad.
-# Do not count aliases of inputs with requires_grad as reason to make a training graph,
-# as AOTAutograd will perform view-replay to regenerate the view outputs at runtime,
-# setting their grad_fn properly.
-and not (
-x.output_type
-in (OutputType.alias_of_input, OutputType.is_input)
-and fw_metadata.input_info[x.base_idx].requires_grad
-)
-for x in fw_metadata.output_info
+x.requires_grad for x in fw_metadata.output_info
 ) and not any(
 x.requires_grad
 and x.mutates_data
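The comments removed above encode the autograd rule the reverted PR relied on: view-type ops preserve requires_grad even under no_grad, while regular ops do not. A small eager illustration (shapes are illustrative):

import torch

x = torch.ones(2, 3, requires_grad=True)
with torch.no_grad():
    v = x.view(-1)  # alias of the input: keeps requires_grad=True
    s = x.sin()     # regular op under no_grad: requires_grad=False

print(v.requires_grad, s.requires_grad)  # True False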
@@ -204,7 +204,7 @@ def _extract_graph_with_inputs_outputs(
 output_values.append(env[x])
 else:
 output_values.append(x)
-new_graph.output(tuple(output_values))
+new_graph.output(output_values)

 new_graph.eliminate_dead_code()
 new_graph.lint()
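This one-line partitioner change is what flips every expected graph string earlier in the diff from "return (...)" back to "return [...]": torch.fx codegen prints the output node's argument as given, so a list of outputs renders with brackets and a tuple with parentheses. A minimal sketch with a hand-built FX graph (the torch.sin target is only for illustration):

import torch
from torch import fx

def build_graph(as_list: bool) -> fx.GraphModule:
    g = fx.Graph()
    x = g.placeholder("x")
    s = g.call_function(torch.sin, (x,))
    # The only difference: hand the output node a list vs. a tuple of values.
    g.output([s, x] if as_list else (s, x))
    return fx.GraphModule(torch.nn.Module(), g)

print(build_graph(as_list=False).code)  # ... return (sin, x)
print(build_graph(as_list=True).code)   # ... return [sin, x]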
@@ -727,7 +727,7 @@ def functionalize_rng_ops(
 sym_node_start_idx = len(fw_outputs) - num_sym_nodes
 outputs = (
 fw_outputs[:sym_node_start_idx]
-+ tuple(fw_rng_state_outputs)
++ fw_rng_state_outputs
 + fw_outputs[sym_node_start_idx:]
 )
 fw_module.graph.output(outputs)
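The tuple(...) wrapper above is dropped for the same reason: with the revert, the forward graph's outputs come back as a list, and Python will not concatenate a list with a tuple. A tiny sketch with placeholder values (the strings stand in for FX output nodes):

fw_outputs = ["out0", "sym_size0"]     # placeholder forward outputs
fw_rng_state_outputs = ["rng_state0"]  # placeholder RNG state outputs
num_sym_nodes = 1
sym_node_start_idx = len(fw_outputs) - num_sym_nodes

# Splice the RNG state outputs in before the sym nodes, keeping one flat list.
outputs = (
    fw_outputs[:sym_node_start_idx]
    + fw_rng_state_outputs
    + fw_outputs[sym_node_start_idx:]
)
print(outputs)  # ['out0', 'rng_state0', 'sym_size0']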