Revert "[cond] don't trace fw and bw graph in autograd key (#148930)"

This reverts commit 6e843a51dd.

Reverted https://github.com/pytorch/pytorch/pull/148930 on behalf of https://github.com/ydwu4 due to Test failure is legit ([comment](https://github.com/pytorch/pytorch/pull/148930#issuecomment-2741585315))
PyTorch MergeBot 2025-03-20 20:28:29 +00:00
parent 4a4a71a73c
commit 24176f6e32
4 changed files with 119 additions and 251 deletions

View File

@ -4945,7 +4945,7 @@ def forward(self, arg0_1):
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (arg0_1,)); gt = true_graph_0 = false_graph_0 = arg0_1 = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None

View File

@ -503,7 +503,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@ -544,7 +544,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@ -653,7 +653,7 @@ def forward(self, pred_1, x_1, y_1, z_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (z_1, y_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = z_1 = y_1 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
return (getitem_1,)""", # noqa: B950
@ -714,7 +714,7 @@ def forward(self, pred_1, x_1):
_param_constant0_1 = self._param_constant0
_param_constant1_1 = self._param_constant1
_tensor_constant0_1 = self._tensor_constant0
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (_param_constant0_1, _param_constant1_1, x_1, sym_size_int, _tensor_constant0_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = _param_constant0_1 = _param_constant1_1 = x_1 = sym_size_int = _tensor_constant0_1 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, sym_size_int, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = sym_size_int = _tensor_constant0_1 = None
getitem_1 = cond_1[0]; getitem_1 = None
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; getitem_3 = None
@ -832,7 +832,7 @@ def forward(self, pred_1, a_1, b_1, c_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2, ones_like)); pred_1 = true_graph_1 = false_graph_1 = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; getitem_3 = None
@ -854,9 +854,11 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
gm.true_graph_1.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = add = None
zeros_like = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None
return [arg6_1, arg6_1, None, None, zeros_like, None]""",
add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = add = None
zeros_like = torch.ops.aten.zeros_like.default(arg5_1, pin_memory = False); arg5_1 = None
clone = torch.ops.aten.clone.default(arg0_1)
clone_1 = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
return [clone, clone_1, None, None, zeros_like, None]""",
)
def test_cond_autograd_pytree_input(self):
@ -908,7 +910,7 @@ def forward(self, pred_1):
_tensor_constant0_1 = self._tensor_constant0
_tensor_constant1_1 = self._tensor_constant1
_tensor_constant2_1 = self._tensor_constant2
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (_tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
@ -1037,7 +1039,7 @@ def forward(self, pred_1, x_1):
_param_constant3_1 = self._param_constant3
_param_constant4_1 = self._param_constant4
_param_constant5_1 = self._param_constant5
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; getitem_2 = None
getitem_3 = cond_1[2]; getitem_3 = None
@ -1095,7 +1097,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@ -1150,7 +1152,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@ -1192,7 +1194,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@ -1322,18 +1324,22 @@ def forward(self, pred_1, x_1):
"""\
def forward(self):
_tensor_constant0 = self._tensor_constant0
ones_like = torch.ops.aten.ones_like.default(_tensor_constant0, pin_memory = False, memory_format = torch.preserve_format); _tensor_constant0 = None
_tensor_constant1 = self._tensor_constant1
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
cond = torch.ops.higher_order.cond(_tensor_constant1, true_graph_0, false_graph_0, (_tensor_constant2, _tensor_constant3, _tensor_constant4, ones_like)); _tensor_constant1 = true_graph_0 = false_graph_0 = _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = ones_like = None
cond = torch.ops.higher_order.cond(_tensor_constant1, true_graph_0, false_graph_0, (ones_like, _tensor_constant2, _tensor_constant3, _tensor_constant4)); _tensor_constant1 = true_graph_0 = false_graph_0 = ones_like = _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = None
getitem = cond[0]; getitem = None
getitem_1 = cond[1]
getitem_2 = cond[2]; cond = None
return (getitem_1, getitem_2)""", # noqa: B950
return (getitem_1, getitem_2)""",
)
else:
self.assertExpectedInline(
@ -1419,56 +1425,31 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):
if compile_mode == "eager" or compile_mode == "none":
self.assertExpectedInline(
normalize_gm(gm.print_readable(print_output=False)),
gm.code.strip(),
"""\
class f(torch.nn.Module):
def forward(self):
_tensor_constant0 = self._tensor_constant0
ones_like: "f32[]" = torch.ops.aten.ones_like.default(_tensor_constant0, pin_memory = False, memory_format = torch.preserve_format); _tensor_constant0 = None
_tensor_constant1 = self._tensor_constant1
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
cond = torch.ops.higher_order.cond(_tensor_constant1, true_graph_0, false_graph_0, (_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, ones_like)); _tensor_constant1 = true_graph_0 = false_graph_0 = _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = ones_like = None
getitem: "f32[4, 5]" = cond[0]; getitem = None
getitem_1: "f32[2, 4]" = cond[1]
getitem_2: "f32[2, 1]" = cond[2]
getitem_3: "f32[2, 4]" = cond[3]
getitem_4: "f32[1, 5]" = cond[4]; cond = None
return (getitem_1, getitem_2, getitem_3, getitem_4)
class true_graph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[4, 5]", arg1_1: "f32[2, 4]", arg2_1: "f32[2, 1]", arg3_1: "f32[2, 4]", arg4_1: "f32[1, 5]", arg5_1: "f32[]"):
mm: "f32[2, 5]" = torch.ops.aten.mm.default(arg1_1, arg0_1); arg1_1 = None
add: "f32[2, 5]" = torch.ops.aten.add.Tensor(mm, arg2_1); mm = arg2_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = sum_1 = None
expand: "f32[2, 5]" = torch.ops.aten.expand.default(arg5_1, [2, 5]); arg5_1 = None
sum_2: "f32[2, 1]" = torch.ops.aten.sum.dim_IntList(expand, [1], True)
t: "f32[5, 4]" = torch.ops.aten.t.default(arg0_1)
mm_1: "f32[2, 4]" = torch.ops.aten.mm.default(expand, t); expand = t = None
zeros_like: "f32[4, 5]" = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False); arg0_1 = None
zeros_like_1: "f32[2, 4]" = torch.ops.aten.zeros_like.default(arg3_1, pin_memory = False); arg3_1 = None
zeros_like_2: "f32[1, 5]" = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None
return [zeros_like, mm_1, sum_2, zeros_like_1, zeros_like_2]
class false_graph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[4, 5]", arg1_1: "f32[2, 4]", arg2_1: "f32[2, 1]", arg3_1: "f32[2, 4]", arg4_1: "f32[1, 5]", arg5_1: "f32[]"):
mm: "f32[2, 5]" = torch.ops.aten.mm.default(arg3_1, arg0_1); arg3_1 = None
add: "f32[2, 5]" = torch.ops.aten.add.Tensor(mm, arg4_1); mm = arg4_1 = None
sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = sum_1 = None
expand: "f32[2, 5]" = torch.ops.aten.expand.default(arg5_1, [2, 5]); arg5_1 = None
sum_2: "f32[1, 5]" = torch.ops.aten.sum.dim_IntList(expand, [0], True)
t: "f32[5, 4]" = torch.ops.aten.t.default(arg0_1)
mm_1: "f32[2, 4]" = torch.ops.aten.mm.default(expand, t); expand = t = None
zeros_like: "f32[4, 5]" = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False); arg0_1 = None
zeros_like_1: "f32[2, 4]" = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False); arg1_1 = None
zeros_like_2: "f32[2, 1]" = torch.ops.aten.zeros_like.default(arg2_1, pin_memory = False); arg2_1 = None
return [zeros_like, zeros_like_1, zeros_like_2, mm_1, sum_2]
""", # noqa: B950
def forward(self):
_tensor_constant0 = self._tensor_constant0
ones_like = torch.ops.aten.ones_like.default(_tensor_constant0, pin_memory = False, memory_format = torch.preserve_format); _tensor_constant0 = None
_tensor_constant1 = self._tensor_constant1
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
_tensor_constant5 = self._tensor_constant5
_tensor_constant6 = self._tensor_constant6
cond = torch.ops.higher_order.cond(_tensor_constant1, true_graph_0, false_graph_0, (ones_like, _tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6)); _tensor_constant1 = true_graph_0 = false_graph_0 = ones_like = _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = None
getitem = cond[0]; getitem = None
getitem_1 = cond[1]
getitem_2 = cond[2]
getitem_3 = cond[3]
getitem_4 = cond[4]; cond = None
return (getitem_1, getitem_2, getitem_3, getitem_4)""",
)
else:
self.assertExpectedInline(
@ -4417,47 +4398,37 @@ def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor):
grad_out = torch.ones_like(result)
return (result, grad_out)""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(backend.graphs[1].print_readable(print_output=False)),
gm.cond_true_0.code.strip(),
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_ctx_saved_tensors_0_: "f32[4]", L_ctx_pred: "b8[]", L_args_1_: "f32[4]"):
l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
l_ctx_pred = L_ctx_pred
l_args_1_ = L_args_1_
def forward(self, l_x_):
l_x__1 = l_x_
sin = l_x__1.sin(); l_x__1 = None
return (sin,)""", # noqa: B950
)
self.assertExpectedInline(
gm.cond_false_0.code.strip(),
"""\
def forward(self, l_x_):
l_x__1 = l_x_
cos = l_x__1.cos(); l_x__1 = None
return (cos,)""", # noqa: B950
)
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, [l_args_1_, l_ctx_saved_tensors_0_]); l_ctx_pred = cond_true_0 = cond_false_0 = l_args_1_ = l_ctx_saved_tensors_0_ = None
getitem: "f32[4]" = cond[0]; cond = None
return (getitem,)
class cond_true_0(torch.nn.Module):
def forward(self, l_args_1_, l_ctx_saved_tensors_0_):
l_args_1__1 = l_args_1_
l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_
sin: "f32[4]" = torch.ops.aten.sin.default(l_ctx_saved_tensors_0__1); sin = None
cos: "f32[4]" = torch.ops.aten.cos.default(l_ctx_saved_tensors_0__1); l_ctx_saved_tensors_0__1 = None
mul: "f32[4]" = torch.ops.aten.mul.Tensor(l_args_1__1, cos); l_args_1__1 = cos = None
return (mul,)
class cond_false_0(torch.nn.Module):
def forward(self, l_args_1_, l_ctx_saved_tensors_0_):
l_args_1__1 = l_args_1_
l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_
cos: "f32[4]" = torch.ops.aten.cos.default(l_ctx_saved_tensors_0__1); cos = None
sin: "f32[4]" = torch.ops.aten.sin.default(l_ctx_saved_tensors_0__1); l_ctx_saved_tensors_0__1 = None
neg: "f32[4]" = torch.ops.aten.neg.default(sin); sin = None
mul: "f32[4]" = torch.ops.aten.mul.Tensor(l_args_1__1, neg); l_args_1__1 = neg = None
return (mul,)
""", # noqa: B950
backward_gm = backend.graphs[1]
self.assertExpectedInline(
backward_gm.code.strip(),
"""\
def forward(self, L_ctx_saved_tensors_0_ : torch.Tensor, L_ctx_pred : torch.Tensor, L_flat_grads_0_ : torch.Tensor):
l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
l_ctx_pred = L_ctx_pred
l_flat_grads_0_ = L_flat_grads_0_
cond_true_0 = self.cond_true_0
cond_false_0 = self.cond_false_0
cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, [l_ctx_saved_tensors_0_, l_flat_grads_0_]); l_ctx_pred = cond_true_0 = cond_false_0 = l_ctx_saved_tensors_0_ = l_flat_grads_0_ = None
getitem = cond[0]; cond = None
return (getitem,)""", # noqa: B950
)
def test_while_loop_op_mismatch_in_meta(self):
@ -4994,7 +4965,7 @@ def forward(self, a_1, b_1):
sym_size_int_3 = torch.ops.aten.sym_size.int(b_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -5474,11 +5445,11 @@ def forward(self, x_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); pred_1 = true_graph_0 = false_graph_0 = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
@ -5647,11 +5618,11 @@ def forward(self, arg0_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); pred_1 = true_graph_0 = false_graph_0 = None
cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
@ -6125,7 +6096,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, sym_size_int, sym_size_int_1]); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -6157,7 +6128,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x_1, sym_size_int_1)); gt = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = None
cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1, sym_size_int_1]); gt = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -6238,7 +6209,7 @@ def forward(self, x_1):
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@ -6389,7 +6360,7 @@ def forward(self, pred_1, x_1):
def forward(self, arg0_1, arg1_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, (arg0_1,)); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
return [getitem]""", # noqa: B950
)
@ -6478,7 +6449,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, sym_size_int, sym_size_int_1]); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)

View File

@ -20,7 +20,6 @@ from torch._functorch.utils import exposed_in
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
_maybe_reenter_make_fx,
_maybe_run_with_interpreter,
_set_compilation_env,
reenter_make_fx,
@ -242,79 +241,6 @@ def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph
def materialize_as_graph(
fn: Callable,
args: tuple[Any],
include_key_set: torch._C.DispatchKeySet,
exclude_key_set: torch._C.DispatchKeySet,
force_enable_grad=False,
) -> torch.fx.GraphModule:
@torch._dynamo.disable(recursive=True)
def _materialize_as_graph_inner():
with suspend_functionalization(), disable_functional_mode():
with disable_proxy_modes_tracing():
unfunc_t = [_from_fun(arg) for arg in args]
with contextlib.ExitStack() as stack:
stack.enter_context(
torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
)
if force_enable_grad:
stack.enter_context(torch.enable_grad())
return _maybe_reenter_make_fx(fn)(*unfunc_t)
gm = _materialize_as_graph_inner()
assert gm is not None
return gm
def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
"""
For a fn that accepts flat inputs and returns flat outputs:
fw_out = fn(*args),
this function returns:
grad_args = bw_fn(*args_and_grad_output)
with the following invariants:
1. args + fw_out has a 1-1 correspondence to args_and_grad_output
2. grad_args has a 1-1 correspondence to args
3. for a tensor arg whose requires_grad is False, its corresponding grad in
grad_args will be a zero tensor with the same shape.
"""
from torch._functorch.aot_autograd import AOTConfig, create_joint
from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad
dummy_aot_config = AOTConfig(
fw_compiler=None, # type: ignore[arg-type]
bw_compiler=None, # type: ignore[arg-type]
partition_fn=None, # type: ignore[arg-type]
decompositions={},
num_params_buffers=0,
aot_id=0,
keep_inference_input_mutations=False,
)
n_primals = len(args)
bw_fn = create_joint(
prepare_fw_with_masks_all_requires_grad(fn), aot_config=dummy_aot_config
)
def flat_fn(*args_and_grad_outs):
primals = args_and_grad_outs[:n_primals]
tangents = args_and_grad_outs[n_primals:]
grad_args = bw_fn(primals, tangents)[1]
assert len(args) == len(grad_args)
return [
(
torch.zeros_like(arg)
if isinstance(arg, torch.Tensor) and grad is None
else grad
)
for grad, arg in zip(grad_args, primals)
]
return flat_fn
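
For reference, here is a minimal standalone sketch of the contract documented in the create_bw_fn docstring above, written against plain torch.autograd.grad rather than the helper itself; toy_fn and the tensor names are illustrative only and are not part of this diff.

import torch

def toy_fn(x, y):
    # flat inputs in, flat outputs out
    return ((x * y).sin(),)

x = torch.randn(3, requires_grad=True)
y = torch.randn(3)  # requires_grad is False
(fw_out,) = toy_fn(x, y)
grad_out = torch.ones_like(fw_out)

# The bw_fn returned by create_bw_fn would map (x, y, grad_out) -> (grad_x, grad_y):
# one grad per input, with a zero tensor standing in for y since it does not require grad.
(grad_x,) = torch.autograd.grad(fw_out, (x,), grad_outputs=(grad_out,))
grad_y = torch.zeros_like(y)
assert len((grad_x, grad_y)) == len((x, y))  # invariant 2: 1-1 correspondence with args
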
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(
operands, (list, tuple)
@ -381,69 +307,60 @@ class CondAutogradOp(torch.autograd.Function):
def forward(
ctx,
pred,
true_fn,
false_fn,
fw_true_graph,
fw_false_graph,
joint_true_graph,
joint_false_graph,
*operands,
):
ctx._pred = pred
ctx._true_bw_fn = create_bw_fn(
true_fn,
operands,
)
ctx._false_bw_fn = create_bw_fn(
false_fn,
operands,
)
# We snapshot the dispatch keys in forward for materializing the
# bw_graph in backward.
ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
ctx._joint_true_graph = joint_true_graph
ctx._joint_false_graph = joint_false_graph
save_tensors_and_symints_for_backward(ctx, operands)
with torch._C._AutoDispatchBelowAutograd():
return cond_op(pred, true_fn, false_fn, operands)
return cond_op(pred, fw_true_graph, fw_false_graph, operands)
@staticmethod
def backward(ctx, *flat_grads):
operands = saved_tensors_and_symints(ctx)
args = operands + flat_grads
# TODO: we need to materialize the bw graphs because dynamo is unable to
# trace through the joint function when torch.compile is applied to torch.autograd.grad.
true_bw_gm = materialize_as_graph(
ctx._true_bw_fn,
args,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
force_enable_grad=True,
)
false_bw_gm = materialize_as_graph(
ctx._false_bw_fn,
args,
ctx._fw_include_key_set,
ctx._fw_exclude_key_set,
force_enable_grad=True,
)
grads = cond_op(
ctx._pred,
true_bw_gm,
false_bw_gm,
args,
ctx._joint_true_graph,
ctx._joint_false_graph,
flat_grads + operands,
)
return None, None, None, *grads
return None, None, None, None, None, *grads
# Note:
# As long as one of the tensors in pred or operands requires grad,
# all the outputs would require grad, with the backward fn set to CondAutogradOp.
# This is consistent with autograd.Function's semantics.
@cond_op.py_impl(DispatchKey.Autograd)
def cond_autograd(pred, true_fn, false_fn, operands):
return CondAutogradOp.apply(
# A shortcut for the case where none of the inputs requires gradient:
# we skip tracing the forward and backward graphs.
if pytree.tree_all_only(
torch.Tensor,
lambda t: not t.requires_grad, # type: ignore[union-attr]
(pred, operands),
):
with torch._C._AutoDispatchBelowAutograd():
return cond_op(pred, true_fn, false_fn, operands)
(
fw_true_graph,
fw_false_graph,
joint_true_graph,
joint_false_graph,
) = create_fw_bw_graph_branches(true_fn, false_fn, *operands)
flat_out = CondAutogradOp.apply(
pred,
true_fn,
false_fn,
fw_true_graph,
fw_false_graph,
joint_true_graph,
joint_false_graph,
*operands,
)
return flat_out
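
As a rough eager-mode illustration of the Note above (a hypothetical snippet, assuming a recent PyTorch build where torch.cond supports autograd; it is not part of this diff): as soon as one tensor input requires grad, the cond output requires grad and backward dispatches through the autograd implementation shown here.

import torch

pred = torch.tensor(True)
x = torch.randn(4, requires_grad=True)

out = torch.cond(pred, lambda t: t.sin(), lambda t: t.cos(), (x,))
assert out.requires_grad   # any grad-requiring input makes the output require grad
out.sum().backward()       # gradients flow through the chosen (true) branch
assert x.grad is not None  # d/dx sin(x) = cos(x)
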
@cond_op.py_impl(ProxyTorchDispatchMode)

View File

@ -424,26 +424,6 @@ def prepare_fw_with_masks(fn):
return fw_with_masks
def prepare_fw_with_masks_all_requires_grad(fn):
def fw_with_masks(*args):
fw_out = fn(*args)
# Note [force all outputs to be require grad]
# Instead of using the original fn, we make all outputs of the original
# fn require grad. This is consistent with the behavior of
# autograd.Function, where if any one of the inputs requires grad,
# all outputs will require grad. This also makes the downstream
# require_gradness reasoning much easier.
if pytree.tree_any_only(torch.Tensor, lambda t: t.requires_grad, args):
fw_out = pytree.tree_map_only(
torch.Tensor, lambda x: x.requires_grad_(True), fw_out
)
return fw_out, pytree.tree_map_only(
torch.Tensor, lambda x: x.requires_grad, fw_out
)
return fw_with_masks
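
For context, a small standalone demo of the autograd.Function behavior that the Note above mirrors (hypothetical code, not part of this diff): if any input requires grad, every tensor output of the Function requires grad.

import torch

class TwoOutputs(torch.autograd.Function):
    @staticmethod
    def forward(ctx, a, b):
        return a + b, a - b

    @staticmethod
    def backward(ctx, g1, g2):
        # grads w.r.t. a and b
        return g1 + g2, g1 - g2

a = torch.randn(3, requires_grad=True)
b = torch.randn(3)  # does not require grad
out1, out2 = TwoOutputs.apply(a, b)
print(out1.requires_grad, out2.requires_grad)  # True True
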
# This function replaces None gradients with all-zero gradients.
# `None` gradients are problematic for CUDA graphs. Those gradients are
# replaced with an all-zero tensor for better optimization