Revert "[cond] don't trace fw and bw graph in autograd key (#148930)"
This reverts commit 6e843a51dd.
Reverted https://github.com/pytorch/pytorch/pull/148930 on behalf of https://github.com/ydwu4 due to Test failure is legit ([comment](https://github.com/pytorch/pytorch/pull/148930#issuecomment-2741585315))
This commit is contained in:
parent 4a4a71a73c
commit 24176f6e32
@@ -4945,7 +4945,7 @@ def forward(self, arg0_1):
gt = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
- cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (arg0_1,)); gt = true_graph_0 = false_graph_0 = arg0_1 = None
+ cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [arg0_1]); gt = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
add = torch.ops.aten.add.Tensor(getitem, 3)
add_1 = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
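The expected graphs above are traces of torch.cond, with the reverted expectation passing the operands as a list again. For orientation, a minimal use of the public API being traced (an illustrative sketch, not code from this patch):

import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4)
# torch.cond picks a branch from a boolean predicate and stays traceable;
# under export/compile it lowers to torch.ops.higher_order.cond as in the graphs above.
out = torch.cond(x.sum() > 4, true_fn, false_fn, (x,))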
@@ -503,7 +503,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@@ -544,7 +544,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@@ -653,7 +653,7 @@ def forward(self, pred_1, x_1, y_1, z_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (z_1, y_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = z_1 = y_1 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, z_1, y_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = z_1 = y_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; cond_1 = getitem_2 = None
return (getitem_1,)""", # noqa: B950
@@ -714,7 +714,7 @@ def forward(self, pred_1, x_1):
_param_constant0_1 = self._param_constant0
_param_constant1_1 = self._param_constant1
_tensor_constant0_1 = self._tensor_constant0
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (_param_constant0_1, _param_constant1_1, x_1, sym_size_int, _tensor_constant0_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = _param_constant0_1 = _param_constant1_1 = x_1 = sym_size_int = _tensor_constant0_1 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _param_constant0_1, _param_constant1_1, x_1, sym_size_int, _tensor_constant0_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _param_constant0_1 = _param_constant1_1 = x_1 = sym_size_int = _tensor_constant0_1 = None
getitem_1 = cond_1[0]; getitem_1 = None
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; getitem_3 = None
@@ -832,7 +832,7 @@ def forward(self, pred_1, a_1, b_1, c_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2, ones_like)); pred_1 = true_graph_1 = false_graph_1 = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2)); pred_1 = true_graph_1 = false_graph_1 = ones_like = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; getitem_3 = None
@@ -854,9 +854,11 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1):
gm.true_graph_1.code.strip(),
"""\
def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
- add = torch.ops.aten.add.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = add = None
- zeros_like = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None
- return [arg6_1, arg6_1, None, None, zeros_like, None]""",
+ add = torch.ops.aten.add.Tensor(arg1_1, arg2_1); arg1_1 = arg2_1 = add = None
+ zeros_like = torch.ops.aten.zeros_like.default(arg5_1, pin_memory = False); arg5_1 = None
+ clone = torch.ops.aten.clone.default(arg0_1)
+ clone_1 = torch.ops.aten.clone.default(arg0_1); arg0_1 = None
+ return [clone, clone_1, None, None, zeros_like, None]""",
)

def test_cond_autograd_pytree_input(self):
@@ -908,7 +910,7 @@ def forward(self, pred_1):
_tensor_constant0_1 = self._tensor_constant0
_tensor_constant1_1 = self._tensor_constant1
_tensor_constant2_1 = self._tensor_constant2
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (_tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, _tensor_constant0_1, _tensor_constant1_1, _tensor_constant2_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = _tensor_constant0_1 = _tensor_constant1_1 = _tensor_constant2_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]
getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
@@ -1037,7 +1039,7 @@ def forward(self, pred_1, x_1):
_param_constant3_1 = self._param_constant3
_param_constant4_1 = self._param_constant4
_param_constant5_1 = self._param_constant5
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1, _param_constant0_1, _param_constant1_1, _param_constant2_1, _param_constant3_1, _param_constant4_1, _param_constant5_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = _param_constant0_1 = _param_constant1_1 = _param_constant2_1 = _param_constant3_1 = _param_constant4_1 = _param_constant5_1 = None
getitem_1 = cond_1[0]
getitem_2 = cond_1[1]; getitem_2 = None
getitem_3 = cond_1[2]; getitem_3 = None
@@ -1095,7 +1097,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@@ -1150,7 +1152,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@@ -1192,7 +1194,7 @@ def forward(self, pred_1, x_1):
ones_like = torch.ops.aten.ones_like.default(getitem, pin_memory = False); getitem = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
- cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (x_1, ones_like)); pred_1 = true_graph_1 = false_graph_1 = x_1 = ones_like = None
+ cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (ones_like, x_1)); pred_1 = true_graph_1 = false_graph_1 = ones_like = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
return (getitem_1,)""", # noqa: B950
)
@@ -1322,18 +1324,22 @@ def forward(self, pred_1, x_1):
"""\
def forward(self):
_tensor_constant0 = self._tensor_constant0
- ones_like = torch.ops.aten.ones_like.default(_tensor_constant0, pin_memory = False, memory_format = torch.preserve_format); _tensor_constant0 = None
+ ones_like = torch.ops.aten.ones_like.default(_tensor_constant0, pin_memory = False,\
+ memory_format = torch.preserve_format); _tensor_constant0 = None
_tensor_constant1 = self._tensor_constant1
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
_tensor_constant2 = self._tensor_constant2
_tensor_constant3 = self._tensor_constant3
_tensor_constant4 = self._tensor_constant4
- cond = torch.ops.higher_order.cond(_tensor_constant1, true_graph_0, false_graph_0, (_tensor_constant2, _tensor_constant3, _tensor_constant4, ones_like)); _tensor_constant1 = true_graph_0 = false_graph_0 = _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = ones_like = None
+ cond = torch.ops.higher_order.cond(_tensor_constant1, true_graph_0, false_graph_0,\
+ (ones_like, _tensor_constant2, _tensor_constant3, _tensor_constant4));\
+ _tensor_constant1 = true_graph_0 = false_graph_0 = ones_like = _tensor_constant2 =\
+ _tensor_constant3 = _tensor_constant4 = None
getitem = cond[0]; getitem = None
getitem_1 = cond[1]
getitem_2 = cond[2]; cond = None
- return (getitem_1, getitem_2)""", # noqa: B950
+ return (getitem_1, getitem_2)""",
)
else:
self.assertExpectedInline(
@@ -1419,56 +1425,31 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1):

if compile_mode == "eager" or compile_mode == "none":
self.assertExpectedInline(
- normalize_gm(gm.print_readable(print_output=False)),
+ gm.code.strip(),
"""\
- class f(torch.nn.Module):
- def forward(self):
- _tensor_constant0 = self._tensor_constant0
- ones_like: "f32[]" = torch.ops.aten.ones_like.default(_tensor_constant0, pin_memory = False, memory_format = torch.preserve_format); _tensor_constant0 = None
- _tensor_constant1 = self._tensor_constant1
- true_graph_0 = self.true_graph_0
- false_graph_0 = self.false_graph_0
- _tensor_constant2 = self._tensor_constant2
- _tensor_constant3 = self._tensor_constant3
- _tensor_constant4 = self._tensor_constant4
- _tensor_constant5 = self._tensor_constant5
- _tensor_constant6 = self._tensor_constant6
- cond = torch.ops.higher_order.cond(_tensor_constant1, true_graph_0, false_graph_0, (_tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5, _tensor_constant6, ones_like)); _tensor_constant1 = true_graph_0 = false_graph_0 = _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 = _tensor_constant6 = ones_like = None
- getitem: "f32[4, 5]" = cond[0]; getitem = None
- getitem_1: "f32[2, 4]" = cond[1]
- getitem_2: "f32[2, 1]" = cond[2]
- getitem_3: "f32[2, 4]" = cond[3]
- getitem_4: "f32[1, 5]" = cond[4]; cond = None
- return (getitem_1, getitem_2, getitem_3, getitem_4)
-
- class true_graph_0(torch.nn.Module):
- def forward(self, arg0_1: "f32[4, 5]", arg1_1: "f32[2, 4]", arg2_1: "f32[2, 1]", arg3_1: "f32[2, 4]", arg4_1: "f32[1, 5]", arg5_1: "f32[]"):
- mm: "f32[2, 5]" = torch.ops.aten.mm.default(arg1_1, arg0_1); arg1_1 = None
- add: "f32[2, 5]" = torch.ops.aten.add.Tensor(mm, arg2_1); mm = arg2_1 = None
- sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = sum_1 = None
- expand: "f32[2, 5]" = torch.ops.aten.expand.default(arg5_1, [2, 5]); arg5_1 = None
- sum_2: "f32[2, 1]" = torch.ops.aten.sum.dim_IntList(expand, [1], True)
- t: "f32[5, 4]" = torch.ops.aten.t.default(arg0_1)
- mm_1: "f32[2, 4]" = torch.ops.aten.mm.default(expand, t); expand = t = None
- zeros_like: "f32[4, 5]" = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False); arg0_1 = None
- zeros_like_1: "f32[2, 4]" = torch.ops.aten.zeros_like.default(arg3_1, pin_memory = False); arg3_1 = None
- zeros_like_2: "f32[1, 5]" = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None
- return [zeros_like, mm_1, sum_2, zeros_like_1, zeros_like_2]
-
- class false_graph_0(torch.nn.Module):
- def forward(self, arg0_1: "f32[4, 5]", arg1_1: "f32[2, 4]", arg2_1: "f32[2, 1]", arg3_1: "f32[2, 4]", arg4_1: "f32[1, 5]", arg5_1: "f32[]"):
- mm: "f32[2, 5]" = torch.ops.aten.mm.default(arg3_1, arg0_1); arg3_1 = None
- add: "f32[2, 5]" = torch.ops.aten.add.Tensor(mm, arg4_1); mm = arg4_1 = None
- sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = sum_1 = None
- expand: "f32[2, 5]" = torch.ops.aten.expand.default(arg5_1, [2, 5]); arg5_1 = None
- sum_2: "f32[1, 5]" = torch.ops.aten.sum.dim_IntList(expand, [0], True)
- t: "f32[5, 4]" = torch.ops.aten.t.default(arg0_1)
- mm_1: "f32[2, 4]" = torch.ops.aten.mm.default(expand, t); expand = t = None
- zeros_like: "f32[4, 5]" = torch.ops.aten.zeros_like.default(arg0_1, pin_memory = False); arg0_1 = None
- zeros_like_1: "f32[2, 4]" = torch.ops.aten.zeros_like.default(arg1_1, pin_memory = False); arg1_1 = None
- zeros_like_2: "f32[2, 1]" = torch.ops.aten.zeros_like.default(arg2_1, pin_memory = False); arg2_1 = None
- return [zeros_like, zeros_like_1, zeros_like_2, mm_1, sum_2]
- """, # noqa: B950
+ def forward(self):
+ _tensor_constant0 = self._tensor_constant0
+ ones_like = torch.ops.aten.ones_like.default(_tensor_constant0, pin_memory = False,\
+ memory_format = torch.preserve_format); _tensor_constant0 = None
+ _tensor_constant1 = self._tensor_constant1
+ true_graph_0 = self.true_graph_0
+ false_graph_0 = self.false_graph_0
+ _tensor_constant2 = self._tensor_constant2
+ _tensor_constant3 = self._tensor_constant3
+ _tensor_constant4 = self._tensor_constant4
+ _tensor_constant5 = self._tensor_constant5
+ _tensor_constant6 = self._tensor_constant6
+ cond = torch.ops.higher_order.cond(_tensor_constant1, true_graph_0, false_graph_0,\
+ (ones_like, _tensor_constant2, _tensor_constant3, _tensor_constant4, _tensor_constant5,\
+ _tensor_constant6)); _tensor_constant1 = true_graph_0 = false_graph_0 = ones_like =\
+ _tensor_constant2 = _tensor_constant3 = _tensor_constant4 = _tensor_constant5 =\
+ _tensor_constant6 = None
+ getitem = cond[0]; getitem = None
+ getitem_1 = cond[1]
+ getitem_2 = cond[2]
+ getitem_3 = cond[3]
+ getitem_4 = cond[4]; cond = None
+ return (getitem_1, getitem_2, getitem_3, getitem_4)""",
)
else:
self.assertExpectedInline(
@@ -4417,47 +4398,37 @@ def forward(self, L_pred_ : torch.Tensor, L_x_ : torch.Tensor):
grad_out = torch.ones_like(result)
return (result, grad_out)""", # noqa: B950
)

self.assertExpectedInline(
- normalize_gm(backend.graphs[1].print_readable(print_output=False)),
+ gm.cond_true_0.code.strip(),
"""\
- class GraphModule(torch.nn.Module):
- def forward(self, L_ctx_saved_tensors_0_: "f32[4]", L_ctx_pred: "b8[]", L_args_1_: "f32[4]"):
- l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
- l_ctx_pred = L_ctx_pred
- l_args_1_ = L_args_1_
+ def forward(self, l_x_):
+ l_x__1 = l_x_
+ sin = l_x__1.sin(); l_x__1 = None
+ return (sin,)""", # noqa: B950
+ )
+ self.assertExpectedInline(
+ gm.cond_false_0.code.strip(),
+ """\
+ def forward(self, l_x_):
+ l_x__1 = l_x_
+ cos = l_x__1.cos(); l_x__1 = None
+ return (cos,)""", # noqa: B950
+ )
+
- cond_true_0 = self.cond_true_0
- cond_false_0 = self.cond_false_0
- cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, [l_args_1_, l_ctx_saved_tensors_0_]); l_ctx_pred = cond_true_0 = cond_false_0 = l_args_1_ = l_ctx_saved_tensors_0_ = None
- getitem: "f32[4]" = cond[0]; cond = None
- return (getitem,)
-
- class cond_true_0(torch.nn.Module):
- def forward(self, l_args_1_, l_ctx_saved_tensors_0_):
- l_args_1__1 = l_args_1_
- l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_
-
- sin: "f32[4]" = torch.ops.aten.sin.default(l_ctx_saved_tensors_0__1); sin = None
-
- cos: "f32[4]" = torch.ops.aten.cos.default(l_ctx_saved_tensors_0__1); l_ctx_saved_tensors_0__1 = None
-
- mul: "f32[4]" = torch.ops.aten.mul.Tensor(l_args_1__1, cos); l_args_1__1 = cos = None
- return (mul,)
-
- class cond_false_0(torch.nn.Module):
- def forward(self, l_args_1_, l_ctx_saved_tensors_0_):
- l_args_1__1 = l_args_1_
- l_ctx_saved_tensors_0__1 = l_ctx_saved_tensors_0_
-
- cos: "f32[4]" = torch.ops.aten.cos.default(l_ctx_saved_tensors_0__1); cos = None
-
- sin: "f32[4]" = torch.ops.aten.sin.default(l_ctx_saved_tensors_0__1); l_ctx_saved_tensors_0__1 = None
-
- neg: "f32[4]" = torch.ops.aten.neg.default(sin); sin = None
-
- mul: "f32[4]" = torch.ops.aten.mul.Tensor(l_args_1__1, neg); l_args_1__1 = neg = None
- return (mul,)
- """, # noqa: B950
+ backward_gm = backend.graphs[1]
+ self.assertExpectedInline(
+ backward_gm.code.strip(),
+ """\
+ def forward(self, L_ctx_saved_tensors_0_ : torch.Tensor, L_ctx_pred : torch.Tensor, L_flat_grads_0_ : torch.Tensor):
+ l_ctx_saved_tensors_0_ = L_ctx_saved_tensors_0_
+ l_ctx_pred = L_ctx_pred
+ l_flat_grads_0_ = L_flat_grads_0_
+ cond_true_0 = self.cond_true_0
+ cond_false_0 = self.cond_false_0
+ cond = torch.ops.higher_order.cond(l_ctx_pred, cond_true_0, cond_false_0, [l_ctx_saved_tensors_0_, l_flat_grads_0_]); l_ctx_pred = cond_true_0 = cond_false_0 = l_ctx_saved_tensors_0_ = l_flat_grads_0_ = None
+ getitem = cond[0]; cond = None
+ return (getitem,)""", # noqa: B950
)

def test_while_loop_op_mismatch_in_meta(self):
@@ -4994,7 +4965,7 @@ def forward(self, a_1, b_1):
sym_size_int_3 = torch.ops.aten.sym_size.int(b_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
- cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3)); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
+ cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [a_1, b_1, sym_size_int, sym_size_int_1, sym_size_int_2, sym_size_int_3]); gt = true_graph_0 = false_graph_0 = a_1 = b_1 = sym_size_int = sym_size_int_1 = sym_size_int_2 = sym_size_int_3 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@@ -5474,11 +5445,11 @@ def forward(self, x_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
- cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); pred_1 = true_graph_0 = false_graph_0 = None
+ cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
- cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
+ cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
@@ -5647,11 +5618,11 @@ def forward(self, arg0_1):
def forward(self, x_1, pred_1, pred2_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
- cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, (x_1,)); pred_1 = true_graph_0 = false_graph_0 = None
+ cond = torch.ops.higher_order.cond(pred_1, true_graph_0, false_graph_0, [x_1]); pred_1 = true_graph_0 = false_graph_0 = None
getitem = cond[0]; cond = None
true_graph_1 = self.true_graph_1
false_graph_1 = self.false_graph_1
- cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, (x_1,)); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
+ cond_1 = torch.ops.higher_order.cond(pred2_1, true_graph_1, false_graph_1, [x_1]); pred2_1 = true_graph_1 = false_graph_1 = x_1 = None
getitem_1 = cond_1[0]; cond_1 = None
add = torch.ops.aten.add.Tensor(getitem, getitem_1); getitem = getitem_1 = None
return add""", # noqa: B950
@@ -6125,7 +6096,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
- cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
+ cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, sym_size_int, sym_size_int_1]); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@@ -6157,7 +6128,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 0)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
- cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, (x_1, sym_size_int_1)); gt = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = None
+ cond = torch.ops.higher_order.cond(gt, true_graph_0, false_graph_0, [x_1, sym_size_int_1]); gt = true_graph_0 = false_graph_0 = x_1 = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@@ -6238,7 +6209,7 @@ def forward(self, x_1):
false_graph_0 = self.false_graph_0
_tensor_constant0 = self._tensor_constant0
_tensor_constant1 = self._tensor_constant1
- cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1)); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None
+ cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, _tensor_constant0, sym_size_int, sym_size_int_1, _tensor_constant1]); eq = true_graph_0 = false_graph_0 = x_1 = _tensor_constant0 = sym_size_int = sym_size_int_1 = _tensor_constant1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@@ -6389,7 +6360,7 @@ def forward(self, pred_1, x_1):
def forward(self, arg0_1, arg1_1):
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
- cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, (arg0_1,)); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
+ cond = torch.ops.higher_order.cond(arg1_1, true_graph_0, false_graph_0, [arg0_1]); arg1_1 = true_graph_0 = false_graph_0 = arg0_1 = None
getitem = cond[0]; cond = None
return [getitem]""", # noqa: B950
)
@@ -6478,7 +6449,7 @@ def forward(self, x_1):
sym_size_int_1 = torch.ops.aten.sym_size.int(x_1, 1)
true_graph_0 = self.true_graph_0
false_graph_0 = self.false_graph_0
- cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, (x_1, sym_size_int, sym_size_int_1)); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
+ cond = torch.ops.higher_order.cond(eq, true_graph_0, false_graph_0, [x_1, sym_size_int, sym_size_int_1]); eq = true_graph_0 = false_graph_0 = x_1 = sym_size_int = sym_size_int_1 = None
getitem = cond[0]; cond = None
return getitem""", # noqa: B950
)
@@ -20,7 +20,6 @@ from torch._functorch.utils import exposed_in
from torch._higher_order_ops.utils import (
_has_potential_branch_input_alias,
_has_potential_branch_input_mutation,
- _maybe_reenter_make_fx,
_maybe_run_with_interpreter,
_set_compilation_env,
reenter_make_fx,
@@ -242,79 +241,6 @@ def create_fw_bw_graph_branches(true_fn, false_fn, *operands):
return fw_true_graph, fw_false_graph, joint_true_graph, joint_false_graph


- def materialize_as_graph(
- fn: Callable,
- args: tuple[Any],
- include_key_set: torch._C.DispatchKeySet,
- exclude_key_set: torch._C.DispatchKeySet,
- force_enable_grad=False,
- ) -> torch.fx.GraphModule:
- @torch._dynamo.disable(recursive=True)
- def _materialize_as_graph_inner():
- with suspend_functionalization(), disable_functional_mode():
- with disable_proxy_modes_tracing():
- unfunc_t = [_from_fun(arg) for arg in args]
- with contextlib.ExitStack() as stack:
- stack.enter_context(
- torch._C._ForceDispatchKeyGuard(include_key_set, exclude_key_set),
- )
- if force_enable_grad:
- stack.enter_context(torch.enable_grad())
- return _maybe_reenter_make_fx(fn)(*unfunc_t)
-
- gm = _materialize_as_graph_inner()
- assert gm is not None
- return gm
-
-
- def create_bw_fn(fn: Callable, args: tuple[Any]) -> Callable:
- """
- For a fn that accepts flat inputs and returns flat outputs:
- fw_out = fn(*args),
- this function returns:
- grad_args = bw_fn(*args_and_grad_output)
- with the following invariants:
- 1. args + fw_out has an 1-1 correspondence to args_and_grad_output
- 2. grad_args has an 1-1 corresponsence to args
- 3. for tensor arg whose requires_grad is False, its corresponding grad in
- grad_args will be a zero tensor with the same shape.
- """
-
- from torch._functorch.aot_autograd import AOTConfig, create_joint
- from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad
-
- dummy_aot_config = AOTConfig(
- fw_compiler=None, # type: ignore[arg-type]
- bw_compiler=None, # type: ignore[arg-type]
- partition_fn=None, # type: ignore[arg-type]
- decompositions={},
- num_params_buffers=0,
- aot_id=0,
- keep_inference_input_mutations=False,
- )
- n_primals = len(args)
-
- bw_fn = create_joint(
- prepare_fw_with_masks_all_requires_grad(fn), aot_config=dummy_aot_config
- )
-
- def flat_fn(*args_and_grad_outs):
- primals = args_and_grad_outs[:n_primals]
- tangents = args_and_grad_outs[n_primals:]
- grad_args = bw_fn(primals, tangents)[1]
- assert len(args) == len(grad_args)
- return [
- (
- torch.zeros_like(arg)
- if isinstance(arg, torch.Tensor) and grad is None
- else grad
- )
- for grad, arg in zip(grad_args, primals)
- ]
-
- return flat_fn
-
-
def trace_cond(proxy_mode, func_overload, pred, true_fn, false_fn, operands):
assert isinstance(
operands, (list, tuple)
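The removed create_bw_fn above documents a contract between the forward arguments, the gradients of the outputs, and the returned gradients. A rough eager-mode sketch of that contract, assuming floating-point tensor args and using plain torch.autograd.grad (make_bw_fn is a hypothetical name, not the deleted helper):

import torch

def make_bw_fn(fn, args):
    n_primals = len(args)

    def bw_fn(*args_and_grad_outs):
        primals = args_and_grad_outs[:n_primals]
        tangents = args_and_grad_outs[n_primals:]
        # Re-run the forward on grad-enabled copies so it can be differentiated.
        leaves = [p.detach().requires_grad_(True) for p in primals]
        out = fn(*leaves)
        outs = out if isinstance(out, (tuple, list)) else (out,)
        grads = torch.autograd.grad(outs, leaves, tangents, allow_unused=True)
        # Invariant 3: an arg that receives no gradient gets an all-zero tensor of its shape.
        return [torch.zeros_like(p) if g is None else g for g, p in zip(grads, primals)]

    return bw_fn

For fn = lambda x, y: x * y, bw_fn(x, y, g) returns [g * y, g * x], matching the 1-1 correspondence described in the removed docstring.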
@@ -381,69 +307,60 @@ class CondAutogradOp(torch.autograd.Function):
def forward(
ctx,
pred,
- true_fn,
- false_fn,
+ fw_true_graph,
+ fw_false_graph,
+ joint_true_graph,
+ joint_false_graph,
*operands,
):
ctx._pred = pred
- ctx._true_bw_fn = create_bw_fn(
- true_fn,
- operands,
- )
- ctx._false_bw_fn = create_bw_fn(
- false_fn,
- operands,
- )
- # We snapshot the dispatch keys in forward for materializing the
- # the bw_graph in backward.
- ctx._fw_include_key_set = torch._C._dispatch_tls_local_include_set()
- ctx._fw_exclude_key_set = torch._C._dispatch_tls_local_exclude_set()
+ ctx._joint_true_graph = joint_true_graph
+ ctx._joint_false_graph = joint_false_graph
save_tensors_and_symints_for_backward(ctx, operands)

with torch._C._AutoDispatchBelowAutograd():
- return cond_op(pred, true_fn, false_fn, operands)
+ return cond_op(pred, fw_true_graph, fw_false_graph, operands)

@staticmethod
def backward(ctx, *flat_grads):
operands = saved_tensors_and_symints(ctx)
- args = operands + flat_grads
- # TODO: we need to materialize the bw graphs because dynamo is unable to
- # trace through the joint funcion when torch.compile torch.autograd.grad.
- true_bw_gm = materialize_as_graph(
- ctx._true_bw_fn,
- args,
- ctx._fw_include_key_set,
- ctx._fw_exclude_key_set,
- force_enable_grad=True,
- )
- false_bw_gm = materialize_as_graph(
- ctx._false_bw_fn,
- args,
- ctx._fw_include_key_set,
- ctx._fw_exclude_key_set,
- force_enable_grad=True,
- )
+
grads = cond_op(
ctx._pred,
- true_bw_gm,
- false_bw_gm,
- args,
+ ctx._joint_true_graph,
+ ctx._joint_false_graph,
+ flat_grads + operands,
)
- return None, None, None, *grads
+ return None, None, None, None, None, *grads


- # Note:
- # As long as one of the tensors in pred or operands requires grad,
- # all the output would require grad with backward fn set to be the CondAutogradOp.
- # This is consistent with autograd.Function's semantic.
@cond_op.py_impl(DispatchKey.Autograd)
def cond_autograd(pred, true_fn, false_fn, operands):
- return CondAutogradOp.apply(
+ # A shortcut for the case where all inputs don't require gradient,
+ # we skip tracing the forward and backward graph.
+ if pytree.tree_all_only(
+ torch.Tensor,
+ lambda t: not t.requires_grad, # type: ignore[union-attr]
+ (pred, operands),
+ ):
+ with torch._C._AutoDispatchBelowAutograd():
+ return cond_op(pred, true_fn, false_fn, operands)
+
+ (
+ fw_true_graph,
+ fw_false_graph,
+ joint_true_graph,
+ joint_false_graph,
+ ) = create_fw_bw_graph_branches(true_fn, false_fn, *operands)
+ flat_out = CondAutogradOp.apply(
pred,
- true_fn,
- false_fn,
+ fw_true_graph,
+ fw_false_graph,
+ joint_true_graph,
+ joint_false_graph,
*operands,
)
+ return flat_out


@cond_op.py_impl(ProxyTorchDispatchMode)
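With the restored Autograd-key implementation above, the backward pass reuses the pre-traced joint graph of each branch through another cond call. A small end-to-end check of gradients flowing through the public API (illustrative, not part of the patch):

import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

x = torch.randn(4, requires_grad=True)
pred = torch.tensor(True)
out = torch.cond(pred, true_fn, false_fn, (x,))
out.sum().backward()
# d/dx sin(x) = cos(x), so the true branch's joint graph should yield cos(x).
print(torch.allclose(x.grad, x.detach().cos()))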
@@ -424,26 +424,6 @@ def prepare_fw_with_masks(fn):
return fw_with_masks


- def prepare_fw_with_masks_all_requires_grad(fn):
- def fw_with_masks(*args):
- fw_out = fn(*args)
- # Note [force all outputs to be require grad]
- # Instead of using the original fn, we set the output of original
- # fn to all require grad. This is consistent with the behavior
- # of autograd.Function, where if any one of the inputs requires grad
- # all output will be require grad. This also makes the downstream
- # require_gradness reasoning much easier.
- if pytree.tree_any_only(torch.Tensor, lambda t: t.requires_grad, args):
- fw_out = pytree.tree_map_only(
- torch.Tensor, lambda x: x.requires_grad_(True), fw_out
- )
- return fw_out, pytree.tree_map_only(
- torch.Tensor, lambda x: x.requires_grad, fw_out
- )
-
- return fw_with_masks
-
-
# This function replaces None gradients with all-zero gradients.
# `None` gradients are problematic for CUDA graphs. Those gradients are
# replaced with an all-zero tensor for better optimization
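The trailing comment above refers to replacing None gradients with all-zero tensors. A minimal illustration of that idea in plain autograd (fill_none_grads is an illustrative name, not the utils.py helper):

import torch

def fill_none_grads(grads, primals):
    # None means autograd found no path to this input; CUDA-graph-friendly code
    # substitutes an all-zero tensor of the same shape instead.
    return [
        torch.zeros_like(p) if g is None and isinstance(p, torch.Tensor) else g
        for g, p in zip(grads, primals)
    ]

x = torch.randn(2, requires_grad=True)
y = torch.randn(2, requires_grad=True)
out = (x * 2).sum()  # y is never used, so its gradient comes back as None
gx, gy = torch.autograd.grad(out, (x, y), allow_unused=True)
print(fill_none_grads([gx, gy], [x, y])[1])  # tensor([0., 0.])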