Add expanded_def option for FX printing, render descriptor, update tests (#158708)

----

- First, we add a new expanded_def option to FX printing, which expands
  variable definitions onto multiple lines, one per variable definition.
  This makes extremely long argument/return lists much more readable
  (see the usage sketch below).

- Next, we extend this mechanism to also print descriptors for
  placeholders and return values as comments, when available.  This is
  how we will test descriptors.

- We update AOTAutograd's tlparse (structured trace) output to use this
  format.

- We update expect tests to use this new format so you can inspect what
  the output looks like.  There may be other tests I should update;
  open to suggestions.
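
As an illustration, here is a minimal sketch of the new option in action
(the toy module is hypothetical and purely for demonstration; descriptor
comments only appear where node.meta["desc"] has been attached, as
AOTAutograd now does for its forward/backward graphs):

    import torch
    from torch.fx import symbolic_trace

    class Toy(torch.nn.Module):  # hypothetical module, just for demonstration
        def forward(self, x, y):
            return x + y, x * y

    gm = symbolic_trace(Toy())

    # Classic single-line rendering of the signature and return tuple.
    print(gm.print_readable(print_output=False))

    # Expanded rendering: one line per placeholder and per returned value,
    # with a trailing descriptor comment wherever node.meta["desc"] is set.
    print(gm.print_readable(print_output=False, expanded_def=True))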

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158708
Approved by: https://github.com/wconstab
ghstack dependencies: #158624
Edward Z. Yang 2025-07-24 19:06:11 -07:00 committed by PyTorch MergeBot
parent bf311141d6
commit 204eb4da5e
11 changed files with 748 additions and 125 deletions

View File

@ -2246,24 +2246,57 @@ class GraphModule(torch.nn.Module):
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
def forward(
self,
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
):
mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7)
return (
mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
mul_3, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_1, # SavedForBackwardsAOTOutput(idx=0)
primals_5, # SavedForBackwardsAOTOutput(idx=1)
primals_7, # SavedForBackwardsAOTOutput(idx=2)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
def forward(
self,
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
tangents_1: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
tangents_2: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
):
mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7)
return (
None, # None
None, # None
mul_8, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
mul_9, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
)
""", # noqa: B950
)
@ -2279,27 +2312,58 @@ class GraphModule(torch.nn.Module):
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
def forward(
self,
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7)
return (
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
primals_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
primals_5, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_5, # SavedForBackwardsAOTOutput(idx=0)
primals_7, # SavedForBackwardsAOTOutput(idx=1)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16, s47]", tangents_2: "f32[s16, s47]"):
def forward(
self,
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
tangents_1: "f32[s16, s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
tangents_2: "f32[s16, s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
return (
None, # None
None, # None
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
)
""", # noqa: B950
)
@ -2317,10 +2381,19 @@ class GraphModule(torch.nn.Module):
fw, bw = self._compile_check(f, [(tt, a, b)], dynamic=True, call_backward=True)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_3: "f32[s97, s98]", primals_4: "f32[s97, s98]", primals_5: "Sym(s97)", primals_6: "Sym(s98)", primals_7: "Sym(s98)"):
def forward(
self,
primals_1: "Sym(s97)", # PlainAOTInput(idx=0)
primals_2: "Sym(s98)", # PlainAOTInput(idx=1)
primals_3: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
primals_4: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
primals_5: "Sym(s97)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_6: "Sym(s98)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
primals_7: "Sym(s98)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
):
mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
@ -2329,15 +2402,33 @@ class GraphModule(torch.nn.Module):
mul_19: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
return (mul_24, mul_27, primals_5, primals_7, primals_7, primals_1, primals_2, primals_5, primals_7)
return (
mul_24, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
mul_27, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_1, # SavedForBackwardsAOTOutput(idx=0)
primals_2, # SavedForBackwardsAOTOutput(idx=1)
primals_5, # SavedForBackwardsAOTOutput(idx=2)
primals_7, # SavedForBackwardsAOTOutput(idx=3)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_5: "Sym(s97)", primals_7: "Sym(s98)", tangents_1: "f32[s97, s98]", tangents_2: "f32[s97, s98]"):
def forward(
self,
primals_1: "Sym(s97)", # PlainAOTInput(idx=0)
primals_2: "Sym(s98)", # PlainAOTInput(idx=1)
primals_5: "Sym(s97)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_7: "Sym(s98)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
tangents_1: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
tangents_2: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
):
mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
@ -2346,7 +2437,15 @@ class GraphModule(torch.nn.Module):
mul_37: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
return (None, None, mul_38, mul_39, primals_5, primals_7, primals_7)
return (
None, # None
None, # None
mul_38, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
mul_39, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
)
""", # noqa: B950
)
@ -2362,27 +2461,58 @@ class GraphModule(torch.nn.Module):
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
def forward(
self,
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
view: "f32[s47, s16]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7)
return (
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_5, # SavedForBackwardsAOTOutput(idx=0)
primals_7, # SavedForBackwardsAOTOutput(idx=1)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
def forward(
self,
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
tangents_1: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
tangents_2: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
return (
None, # None
None, # None
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
)
""", # noqa: B950
)
@ -2398,28 +2528,57 @@ class GraphModule(torch.nn.Module):
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
def forward(
self,
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (view, view_1, mul_6, primals_5, primals_7)
return (
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
mul_6, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_5, # SavedForBackwardsAOTOutput(idx=0)
primals_7, # SavedForBackwardsAOTOutput(idx=1)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
def forward(
self,
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
tangents_1: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
tangents_2: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
return (
None, # None
None, # None
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
)
""", # noqa: B950
)
@ -2435,28 +2594,58 @@ class GraphModule(torch.nn.Module):
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
def forward(
self,
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
):
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6])
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
return (clone, view, view_1, mul_6, primals_5, primals_7)
return (
clone, # PlainAOTOutput(idx=0)
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
mul_6, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0)
primals_5, # SavedForBackwardsAOTOutput(idx=0)
primals_7, # SavedForBackwardsAOTOutput(idx=1)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
def forward(
self,
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
tangents_1: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
tangents_2: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
):
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
return (
None, # None
None, # None
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
)
""", # noqa: B950
)
@ -2517,53 +2706,94 @@ class GraphModule(torch.nn.Module):
)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[3, 4]", primals_2: "f32[3, 4]"):
def forward(
self,
primals_1: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a')
primals_2: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b')
):
clone: "f32[3, 4]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
view: "f32[12]" = torch.ops.aten.view.default(clone, [-1])
view_1: "f32[12]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, clone_1)
return (
clone, # PlainAOTOutput(idx=0)
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
clone_1, # PlainAOTOutput(idx=2)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(fw[1].print_readable(print_output=False)),
normalize_gm(fw[1].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
def forward(
self,
primals_1: "Sym(s16)", # PlainAOTInput(idx=0)
primals_2: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='a')
primals_3: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='b')
primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1)
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
):
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
return (
clone, # PlainAOTOutput(idx=0)
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
sym_size_int_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0)
clone_1, # PlainAOTOutput(idx=2)
primals_5, # SavedForBackwardsAOTOutput(idx=0)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[12]", tangents_2: "f32[12]"):
def forward(
self,
tangents_1: "f32[12]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
tangents_2: "f32[12]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
):
view_2: "f32[3, 4]" = torch.ops.aten.view.default(tangents_1, [3, 4]); tangents_1 = None
view_3: "f32[3, 4]" = torch.ops.aten.view.default(tangents_2, [3, 4]); tangents_2 = None
return (view_2, view_3)
return (
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a')
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b')
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[1].print_readable(print_output=False)),
normalize_gm(bw[1].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
def forward(
self,
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
tangents_1: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
tangents_2: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
):
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (None, view_2, view_3, primals_5, primals_5)
return (
None, # None
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='a')
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='b')
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1)
primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0)
)
""", # noqa: B950
)
@ -2587,28 +2817,53 @@ class GraphModule(torch.nn.Module):
)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
def forward(
self,
primals_1: "Sym(s16)", # PlainAOTInput(idx=0)
primals_2: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='a')
primals_3: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='b')
primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1)
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
):
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
return (
clone, # PlainAOTOutput(idx=0)
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
sym_size_int_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0)
clone_1, # PlainAOTOutput(idx=2)
primals_5, # SavedForBackwardsAOTOutput(idx=0)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
def forward(
self,
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
tangents_1: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
tangents_2: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
):
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
return (None, view_2, view_3, primals_5, primals_5)
return (
None, # None
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='a')
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='b')
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1)
primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0)
)
""", # noqa: B950
)
@ -2624,27 +2879,41 @@ class GraphModule(torch.nn.Module):
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[24]", primals_2: "f32[24]"):
def forward(
self,
primals_1: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a')
primals_2: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b')
):
clone: "f32[24]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
clone_1: "f32[24]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
view: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone, [3, 2, 4]); clone = None
view_1: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone_1, [3, 2, 4]); clone_1 = None
return (view, view_1)
return (
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, tangents_1: "f32[3, 2, 4]", tangents_2: "f32[3, 2, 4]"):
def forward(
self,
tangents_1: "f32[3, 2, 4]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
tangents_2: "f32[3, 2, 4]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
):
view_2: "f32[24]" = torch.ops.aten.view.default(tangents_1, [24]); tangents_1 = None
view_3: "f32[24]" = torch.ops.aten.view.default(tangents_2, [24]); tangents_2 = None
return (view_2, view_3)
return (
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a')
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b')
)
""", # noqa: B950
)
@ -2771,24 +3040,67 @@ class GraphModule(torch.nn.Module):
fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
def forward(
self,
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
primals_2: "Sym(s71)", # PlainAOTInput(idx=1)
primals_3: "Sym(s55)", # PlainAOTInput(idx=2)
primals_4: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values')
primals_5: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets')
primals_6: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor')
primals_7: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor')
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
primals_9: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2)
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
):
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10)
return (
mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
primals_7, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
primals_8, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
primals_10, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
primals_10, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
primals_1, # SavedForBackwardsAOTOutput(idx=0)
primals_8, # SavedForBackwardsAOTOutput(idx=1)
primals_10, # SavedForBackwardsAOTOutput(idx=2)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s51)", primals_8: "Sym(s51)", primals_10: "Sym(s55)", tangents_1: "f64[s64, s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
def forward(
self,
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
tangents_1: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values')
tangents_2: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_offsets')
tangents_3: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_min_seqlen_tensor')
tangents_4: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_max_seqlen_tensor')
):
mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
return (
None, # None
None, # None
None, # None
mul_1, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_values')
tangents_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_offsets')
tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor')
tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor')
primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0)
primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1)
)
""", # noqa: B950
)
@ -2804,28 +3116,71 @@ class GraphModule(torch.nn.Module):
fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
def forward(
self,
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
primals_2: "Sym(s71)", # PlainAOTInput(idx=1)
primals_3: "Sym(s55)", # PlainAOTInput(idx=2)
primals_4: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values')
primals_5: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets')
primals_6: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor')
primals_7: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor')
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
primals_9: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2)
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
):
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
add_2: "Sym(2*s55)" = primals_10 + primals_10
return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2)
return (
cat, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
primals_7, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
primals_8, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
add_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
add_2, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
primals_8, # SavedForBackwardsAOTOutput(idx=0)
primals_10, # SavedForBackwardsAOTOutput(idx=1)
add_2, # SavedForBackwardsAOTOutput(idx=2)
)
""", # noqa: B950
)
self.assertExpectedInline(
normalize_gm(bw[0].print_readable(print_output=False)),
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_8: "Sym(s51)", primals_10: "Sym(s55)", add_2: "Sym(2*s55)", tangents_1: "f64[s64, 2*s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
def forward(
self,
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
add_2: "Sym(2*s55)",
tangents_1: "f64[s64, 2*s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values')
tangents_2: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_offsets')
tangents_3: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_min_seqlen_tensor')
tangents_4: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_max_seqlen_tensor')
):
slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
return (
None, # None
None, # None
None, # None
add_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_values')
tangents_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_offsets')
tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor')
tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor')
primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0)
primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1)
)
""", # noqa: B950
)
@ -2850,10 +3205,22 @@ class GraphModule(torch.nn.Module):
)
self.assertExpectedInline(
normalize_gm(fw[0].print_readable(print_output=False)),
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "Sym(s51)", arg1_1: "Sym(s71)", arg2_1: "Sym(s55)", arg3_1: "f64[9, s55]", arg4_1: "i64[s51 + 1]", arg5_1: "f32[s0, 0]", arg6_1: "f32[s83, 0]", arg7_1: "Sym(s51)", arg8_1: "Sym(s55)", arg9_1: "Sym(s55)"):
def forward(
self,
arg0_1: "Sym(s51)", # PlainAOTInput(idx=0)
arg1_1: "Sym(s71)", # PlainAOTInput(idx=1)
arg2_1: "Sym(s55)", # PlainAOTInput(idx=2)
arg3_1: "f64[9, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values')
arg4_1: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets')
arg5_1: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor')
arg6_1: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor')
arg7_1: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
arg8_1: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2)
arg9_1: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
):
randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
@ -2874,7 +3241,14 @@ class <lambda>(torch.nn.Module):
sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
return (mul, cat_1, zeros_1, zeros_2, sym_size_int, sym_stride_int)
return (
mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
cat_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
zeros_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
zeros_2, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
sym_size_int, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
sym_stride_int, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
)
""", # noqa: B950
)

View File

@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None

View File

@ -4837,10 +4837,14 @@ def forward(self, arg0_1):
inps = [torch.randn(2, 2), torch.ones(2)]
gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
self.assertExpectedInline(
normalize_gm(gm.print_readable(False)),
normalize_gm(gm.print_readable(False, expanded_def=True)),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):
def forward(
self,
arg0_1: "f32[2, 2]", # PlainAOTInput(idx=0)
arg1_1: "f32[2]", # PlainAOTInput(idx=1)
):
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1)
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
@ -4851,7 +4855,10 @@ class <lambda>(torch.nn.Module):
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem, 3)
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
return (add, add_1)
return (
add, # PlainAOTOutput(idx=0)
add_1, # PlainAOTOutput(idx=1)
)
class true_graph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):
@ -4925,10 +4932,14 @@ class <lambda>(torch.nn.Module):
inps = [torch.randn(2, 2), torch.ones(2)]
gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
self.assertExpectedInline(
normalize_gm(gm.print_readable(False)),
normalize_gm(gm.print_readable(False, expanded_def=True)),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):
def forward(
self,
arg0_1: "f32[2, 2]", # PlainAOTInput(idx=0)
arg1_1: "f32[2]", # PlainAOTInput(idx=1)
):
cos: "f32[2, 2]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
_set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None
@ -4939,7 +4950,9 @@ class <lambda>(torch.nn.Module):
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
return (add,)
return (
add, # PlainAOTOutput(idx=0)
)
class body_graph_0(torch.nn.Module):
def forward(self, arg0_1: "f32[2]", arg1_1: "f32[2]"):
@ -5098,10 +5111,20 @@ def forward(self, arg0_1):
for node in fx_g.graph.nodes:
node.meta.pop("stack_trace", None)
self.assertExpectedInline(
fx_g.print_readable(print_output=False),
fx_g.print_readable(print_output=False, expanded_def=True),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"):
def forward(
self,
arg0_1: "f32[3, 1, 1, 1]",
arg1_1: "f32[3]",
arg2_1: "f32[3]",
arg3_1: "f32[3]",
arg4_1: "f32[3]",
arg5_1: "f32[3]",
arg6_1: "i64[]",
arg7_1: "f32[1, 1, 3, 3]",
):
# No stacktrace found for following nodes
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg1_1 = None
add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None
@ -5183,10 +5206,20 @@ class <lambda>(torch.nn.Module):
for node in fx_g_inference.graph.nodes:
node.meta.pop("stack_trace", None)
self.assertExpectedInline(
fx_g_inference.print_readable(print_output=False),
fx_g_inference.print_readable(print_output=False, expanded_def=True),
"""\
class <lambda>(torch.nn.Module):
def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"):
def forward(
self,
arg0_1: "f32[3, 1, 1, 1]", # PlainAOTInput(idx=0)
arg1_1: "f32[3]", # PlainAOTInput(idx=1)
arg2_1: "f32[3]", # PlainAOTInput(idx=2)
arg3_1: "f32[3]", # PlainAOTInput(idx=3)
arg4_1: "f32[3]", # PlainAOTInput(idx=4)
arg5_1: "f32[3]", # PlainAOTInput(idx=5)
arg6_1: "i64[]", # PlainAOTInput(idx=6)
arg7_1: "f32[1, 1, 3, 3]", # PlainAOTInput(idx=7)
):
# No stacktrace found for following nodes
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg7_1 = arg0_1 = arg1_1 = None
add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None
@ -5199,7 +5232,13 @@ class <lambda>(torch.nn.Module):
detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None
detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
return (getitem_3, getitem_4, add, sum_1, detach_2)
return (
getitem_3, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=4))
getitem_4, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=5))
add, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=6))
sum_1, # PlainAOTOutput(idx=0)
detach_2, # PlainAOTOutput(idx=1)
)
""", # noqa: B950
)
# Some important characteristics of the exported graph below:

View File

@ -739,3 +739,11 @@ class DummyAOTOutput(AOTOutput):
def expr(self) -> str:
return f"__dummy{self.idx}"
@dataclasses.dataclass(frozen=True)
class SavedForBackwardsAOTOutput(AOTOutput):
idx: int
def expr(self) -> str:
return f"__saved_for_backwards_{self.idx}"

View File

@ -19,7 +19,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torchgen.utils import dataclass_repr
from .. import config
from .descriptors import AOTInput
from .descriptors import AOTInput, BackwardTokenAOTInput
from .functional_utils import (
assert_functional_graph,
propagate_input_mutation_stacktraces,
@ -36,6 +36,7 @@ from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationM
from .utils import (
call_and_expect_output_descs,
copy_fwd_metadata_to_bw_nodes,
fn_wrappers,
register_buffer_assignment_hook,
root_module_when_exporting_non_strict,
unlift_tokens,
@ -65,11 +66,10 @@ def _create_graph(
@functools.wraps(f)
def inner_f(*args):
nonlocal out_descs
assert out_descs is None
out, out_descs = call_and_expect_output_descs(f, args)
return out
# TODO: save args_descs/out_descs to the produced FX graph
with (
enable_python_dispatcher(),
FunctionalTensorMode(
@ -86,6 +86,49 @@ def _create_graph(
pre_dispatch=aot_config.pre_dispatch,
)(*args)
if args_descs is not None:
flat_args_descs, _ = pytree.tree_flatten(args_descs)
flat_out_descs, _ = pytree.tree_flatten(out_descs)
# Unfortunately, flat_args_descs is not guaranteed to match the
# number of actual arguments that show up on the FX graph.
# Specifically, allow_token_discovery=True means that we will
# silently add extra token arguments to the backwards graph.
#
# Although there are a few ways to detect what these tokens are,
# we are going to settle for something dodgy but simple to
# implement: match tangents_token placeholders specifically,
# as these are the only placeholders that are created by token
# discovery (NB: there is NO other code that treats this name
# as load bearing, so this is a bit naughty!)
#
# I originally wanted to detect tokens in exactly the same way
# that they are detected at normal runtime, but to be honest
# the normal runtime detection is pretty strange: it seems the
# backward tokens are not reliably at the end of the argument list
# but *precede* the RNG arguments (I don't understand why this is
# the case). And in unlift_tokens, token arguments are detected
# by seeing if they feed into an effects call! Dastardly. Why
# didn't we just introduce a new type.
i = 0
j = 0
for n in fx_g.graph.nodes:
if n.op == "placeholder":
if n.name.startswith("tangents_token"):
n.meta["desc"] = BackwardTokenAOTInput(j)
j += 1
else:
assert i < len(flat_args_descs), (
(fn_wrappers(inner_f)),
[n for n in fx_g.graph.nodes if n.op == "placeholder"],
flat_args_descs,
)
n.meta["desc"] = flat_args_descs[i]
i += 1
elif n.op == "output":
n.meta["desc"] = flat_out_descs
return fx_g
@ -276,7 +319,10 @@ def aot_dispatch_base_graph(
trace_structured(
"aot_inference_graph",
payload_fn=lambda: fw_module.print_readable(
print_output=False, include_stride=True, include_device=True
print_output=False,
include_stride=True,
include_device=True,
expanded_def=True,
),
)

View File

@ -266,6 +266,7 @@ def aot_stage2_inference(
include_stride=True,
include_device=True,
fast_sympy_print=True,
expanded_def=True,
)
fakified_out_wrapper = FakifiedOutWrapper()
@ -877,7 +878,10 @@ def maybe_log_graph(
def gm_str_fn() -> str:
return gm.print_readable(
print_output=False, include_stride=True, include_device=True
print_output=False,
include_stride=True,
include_device=True,
expanded_def=True,
)
if out_structured_logs is not None:
@ -1347,7 +1351,10 @@ def aot_stage2_autograd(
),
)
joint_graph_str = fx_g.print_readable(
print_output=False, include_stride=True, include_device=True
print_output=False,
include_stride=True,
include_device=True,
expanded_def=True,
)
trace_structured(
"aot_joint_graph",
@ -1596,10 +1603,16 @@ def aot_stage2_autograd(
),
)
fw_module_str = fw_module.print_readable(
print_output=False, include_stride=True, include_device=True
print_output=False,
include_stride=True,
include_device=True,
expanded_def=True,
)
bw_module_str = bw_module.print_readable(
print_output=False, include_stride=True, include_device=True
print_output=False,
include_stride=True,
include_device=True,
expanded_def=True,
)
trace_structured(

View File

@ -544,3 +544,12 @@ def call_and_expect_output_descs(fn, args):
outs_descs,
)
return outs_pair
def fn_wrappers(fn):
fns = [fn]
f = fn
while hasattr(f, "__wrapped__"):
f = f.__wrapped__
fns.append(f)
return fns

View File

@ -48,6 +48,7 @@ from ._activation_checkpointing.knapsack import (
ilp_knapsack,
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
@ -175,6 +176,7 @@ def _extract_graph_with_inputs_outputs(
joint_graph: fx.Graph,
inputs: list[fx.Node],
outputs: list[fx.Node],
outputs_descs: list[AOTOutput],
subgraph: Optional[str] = None,
) -> fx.Graph:
"""
@ -239,7 +241,8 @@ def _extract_graph_with_inputs_outputs(
output_values.append(env[x])
else:
output_values.append(x)
new_graph.output(tuple(output_values))
out = new_graph.output(tuple(output_values))
out.meta["desc"] = outputs_descs
new_graph.eliminate_dead_code()
new_graph.lint()
@ -299,13 +302,20 @@ def _must_be_in_backward(node: fx.Node) -> bool:
def _extract_fwd_bwd_outputs(
joint_module: fx.GraphModule, *, num_fwd_outputs
) -> tuple[list[fx.Node], list[fx.Node]]:
) -> tuple[list[fx.Node], list[fx.Node], list[AOTOutput], list[AOTOutput]]:
outputs = pytree.arg_tree_leaves(
*(node.args for node in joint_module.graph.find_nodes(op="output"))
)
outputs_descs = pytree.arg_tree_leaves(
next(iter(joint_module.graph.find_nodes(op="output"))).meta.get(
"desc", [None] * len(outputs)
)
)
fwd_outputs = outputs[:num_fwd_outputs]
bwd_outputs = outputs[num_fwd_outputs:]
return fwd_outputs, bwd_outputs
fwd_outputs_descs = outputs_descs[:num_fwd_outputs]
bwd_outputs_descs = outputs_descs[num_fwd_outputs:]
return fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs
def _remove_by_name(saved_values: list[fx.Node], name: str):
@ -827,8 +837,8 @@ def _extract_fwd_bwd_modules(
num_fwd_outputs: int,
static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
joint_module, num_fwd_outputs=num_fwd_outputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
placeholders = joint_module.graph.find_nodes(op="placeholder")
primal_inputs = [*filter(_is_primal, placeholders)]
@ -841,6 +851,7 @@ def _extract_fwd_bwd_modules(
joint_module.graph,
saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
bwd_outputs,
bwd_outputs_descs,
"backward",
)
@ -914,6 +925,11 @@ def _extract_fwd_bwd_modules(
joint_module.graph,
primal_inputs + fwd_seed_offset_inputs,
fwd_outputs + saved_values + saved_sym_nodes,
fwd_outputs_descs
+ [
SavedForBackwardsAOTOutput(i)
for i in range(len(saved_values) + len(saved_sym_nodes))
],
"forward",
)
bwd_graph = _extract_graph_with_inputs_outputs(
@ -924,6 +940,7 @@ def _extract_fwd_bwd_modules(
+ bwd_seed_offset_inputs
+ backward_state_inputs,
bwd_outputs,
bwd_outputs_descs,
"backward",
)
@ -976,11 +993,11 @@ def default_partition(
primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
joint_module, num_fwd_outputs=num_fwd_outputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, "forward"
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
forward_node_names = OrderedSet(
node.name for node in forward_only_graph.nodes if node.op != "output"
@ -2651,14 +2668,14 @@ def min_cut_rematerialization_partition(
filter(_is_fwd_seed_offset, joint_module.graph.nodes)
)
inputs = primal_inputs + fwd_seed_offset_inputs
fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
joint_module, num_fwd_outputs=num_fwd_outputs
fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
_extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
)
required_bw_nodes.update(
o for o in bwd_outputs if o is not None and o.op != "output"
)
forward_only_graph = _extract_graph_with_inputs_outputs(
joint_module.graph, inputs, fwd_outputs, "forward"
joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
)
required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
name_to_node[node.name]

View File

@ -880,6 +880,8 @@ def _export_to_aten_ir(
new_output_node = list(new_gm.graph.nodes)[-1]
assert old_output_node.op == "output" and new_output_node.op == "output"
# make sure we don't override any meta
if "desc" in new_output_node.meta:
del new_output_node.meta["desc"]
assert len(new_output_node.meta) == 0
new_output_node.meta.update(old_output_node.meta)

View File

@ -324,7 +324,55 @@ class CodeGen:
self._body_transformer: Optional[TransformCodeFunc] = None
self._func_name: str = "forward"
def gen_fn_def(self, free_vars: list[str], maybe_return_annotation: str) -> str:
def _format_multiline_args(self, args: list[str]) -> str:
"""Helper to format function arguments in expanded multiline format."""
return "".join(self._format_single_arg(arg) for arg in args)
def _format_single_arg(self, arg: str) -> str:
"""Helper to format a single argument with optional comment."""
if "#" in arg:
arg_part, comment_part = arg.split("#", 1)
return f" {arg_part.rstrip()}, # {comment_part.lstrip()}\n"
else:
return f" {arg},\n"
def _get_delimiters(self, container) -> tuple[str, str]:
"""Helper to get opening and closing delimiters for containers."""
return ("(", ")") if isinstance(container, tuple) else ("[", "]")
def _format_multiline_container(self, items, descs=None, prefix="") -> str:
"""Helper to format containers (lists/tuples) in multiline format."""
ldelim, rdelim = self._get_delimiters(items)
desc_trailers = self._get_desc_trailers(items, descs)
return (
f"{prefix}{ldelim}\n"
+ "".join(
f" {item},{trailer}\n" for item, trailer in zip(items, desc_trailers)
)
+ f"{rdelim}"
)
def _get_desc_trailers(self, items, descs):
"""Helper to generate description trailers for items."""
if descs is None:
return [""] * len(items)
return [f" # {desc}" for desc in descs]
def _call_method_with_signature_check(self, method, *args, **kwargs):
"""Helper to call a method with optional parameters based on signature."""
sig = inspect.signature(method)
# Filter kwargs to only include parameters that exist in the method signature
filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
return method(*args, **filtered_kwargs)
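
(Not part of the diff.) Why the signature check matters, sketched with illustrative names: CodeGen subclasses written before this change may override gen_fn_def or generate_output without the new keyword arguments, so new kwargs are forwarded only when the override accepts them.

import inspect

class LegacyCodeGen:
    # An override that predates the expanded_def keyword.
    def gen_fn_def(self, free_vars, maybe_return_annotation):
        return f"def forward({', '.join(free_vars)}){maybe_return_annotation}:"

def call_with_signature_check(method, *args, **kwargs):
    sig = inspect.signature(method)
    filtered = {k: v for k, v in kwargs.items() if k in sig.parameters}
    return method(*args, **filtered)

# expanded_def is silently dropped for the legacy override instead of raising TypeError.
print(call_with_signature_check(LegacyCodeGen().gen_fn_def, ["self", "x"], "", expanded_def=True))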
def gen_fn_def(
self,
free_vars: list[str],
maybe_return_annotation: str,
*,
expanded_def: bool = False,
) -> str:
"""
Given the free variables and a return annotation, generates the beginning of the FX function.
By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'`
@ -333,16 +381,26 @@ class CodeGen:
# would have added it.
if len(free_vars) == 0 or free_vars[0] != "self":
free_vars.insert(0, "self")
return (
f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"
)
def generate_output(self, output_args: Argument) -> str:
if expanded_def:
args_formatted = self._format_multiline_args(free_vars)
return (
f"def {self._func_name}(\n{args_formatted}){maybe_return_annotation}:"
)
else:
return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"
def generate_output(
self, output_args: Argument, *, descs: Optional[Any] = None
) -> str:
"""
Given the output arguments, generates the return statement of the FX function.
Note: The returned statement should not be indented.
"""
return f"return {repr(output_args)}"
if descs is not None and isinstance(output_args, (list, tuple)):
return self._format_multiline_container(output_args, descs, "return ")
else:
return f"return {repr(output_args)}"
def process_inputs(self, *args: Any) -> Any:
"""
@ -380,6 +438,8 @@ class CodeGen:
include_stride: bool = False,
include_device: bool = False,
colored: bool = False,
# Render each argument on its own line
expanded_def: bool = False,
) -> PythonCode:
free_vars: list[str] = []
body: list[str] = []
@ -586,6 +646,7 @@ class CodeGen:
maybe_type_annotation = (
"" if node.type is None else f" : {type_repr(node.type)}"
)
maybe_comment = ""
if verbose:
# override annotation with more detailed information
@ -617,13 +678,20 @@ class CodeGen:
elif isinstance(meta_val, TensorMetadata):
maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'
desc = None
if expanded_def:
desc = node.meta.get("desc", None)
if desc is not None and node.op == "placeholder":
maybe_comment += f" # {desc}"
# output is handled specially
if node.op == "placeholder":
assert isinstance(node.target, str)
maybe_default_arg = (
"" if not node.args else f" = {_get_repr(node.args[0])}"
)
free_vars.append(
f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
f"{node.target}{maybe_type_annotation}{maybe_default_arg}{maybe_comment}"
)
raw_name = node.target.replace("*", "")
if raw_name != repr(node):
@ -699,7 +767,13 @@ class CodeGen:
elif node.op == "output":
if node.type is not None:
maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
body.append(self.generate_output(node.args[0]))
body.append(
self._call_method_with_signature_check(
self.generate_output,
node.args[0],
descs=desc if expanded_def else None,
)
)
return
raise NotImplementedError(f"node: {node.op} {node.target}")
@ -733,7 +807,12 @@ class CodeGen:
for name, value in self.additional_globals():
add_global(name, value)
prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
prologue = self._call_method_with_signature_check(
self.gen_fn_def,
free_vars,
maybe_return_annotation[0],
expanded_def=expanded_def,
)
# remove counter and generate lineno to node index mapping
lineno_map: dict[int, Optional[int]] = {}
@ -782,7 +861,23 @@ class _PyTreeCodeGen(CodeGen):
assert self.pytree_info.out_spec is not None
return pytree.tree_unflatten(out, self.pytree_info.out_spec)
def gen_fn_def(self, free_vars, maybe_return_annotation):
def _format_annotations(self, free_vars: list[str], expanded_def: bool) -> str:
"""Helper to format annotations for variables in pytree codegen."""
if not free_vars:
return ""
has_annotation = [x for x in free_vars if ":" in x]
if not has_annotation:
return ""
if expanded_def:
return "\n " + "\n ".join(has_annotation)
else:
return "\n " + "".join(x + "; " for x in has_annotation) + "\n"
def gen_fn_def(
self, free_vars, maybe_return_annotation, *, expanded_def: bool = False
):
# Given a user function/model:
# myargs = [myargs0, myargs1]
# mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}
@ -799,13 +894,17 @@ class _PyTreeCodeGen(CodeGen):
# If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec
# e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec)
if self.pytree_info is None:
return super().gen_fn_def(free_vars, maybe_return_annotation)
return super().gen_fn_def(
free_vars, maybe_return_annotation, expanded_def=expanded_def
)
fn_args = self.pytree_info.orig_args
has_orig_self = (fn_args[0] == "self") if len(fn_args) > 0 else False
if has_orig_self:
free_vars.insert(0, "self")
fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation)
fn_definition = super().gen_fn_def(
fn_args[:], maybe_return_annotation, expanded_def=expanded_def
)
if len(free_vars) > 0: # pytree has placeholders in it
# when kwargs is present, in_spec is tuple(args, kwargs)
@ -837,19 +936,27 @@ class _PyTreeCodeGen(CodeGen):
# we need to split it to two lines:
# one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
# one for code: `var1, var2, = function_call()`
without_annotation = [x.split(":")[0] for x in free_vars]
has_annotation = [x + "; " for x in free_vars if ":" in x]
if len(has_annotation) > 0:
fn_definition += "\n " + "".join(has_annotation) + "\n"
without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
fn_definition += self._format_annotations(free_vars, expanded_def)
fn_definition += f"""
{", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
return fn_definition
def generate_output(self, output_args):
def generate_output(self, output_args, *, descs: Optional[Any] = None):
if self.pytree_info and self.pytree_info.out_spec:
return f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)"
if descs is not None and isinstance(output_args, (list, tuple)):
return (
self._format_multiline_container(
output_args, descs, "return pytree.tree_unflatten("
)
+ ", self._out_spec)"
)
else:
return (
f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)"
)
else:
return super().generate_output(output_args)
return super().generate_output(output_args, descs=descs)
class _FindNodesLookupTable:
@ -1534,6 +1641,7 @@ class Graph:
include_stride: bool = False,
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
) -> PythonCode:
"""
Turn this ``Graph`` into valid Python code.
@ -1600,6 +1708,7 @@ class Graph:
include_stride=include_stride,
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
)
def _python_code(
@ -1611,6 +1720,7 @@ class Graph:
include_stride: bool = False,
include_device: bool = False,
colored: bool = False,
expanded_def: bool = False,
) -> PythonCode:
return self._codegen._gen_python_code(
self.nodes,
@ -1620,6 +1730,7 @@ class Graph:
include_stride=include_stride,
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
)
def __str__(self) -> str:
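
(Not part of the diff.) A minimal usage sketch of the new knob at the Graph level, assuming only the signature change shown above:

import torch
from torch.fx import symbolic_trace

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y, x * y

gm = symbolic_trace(M())
# expanded_def=True renders one placeholder per line and, when a node carries a
# "desc" meta entry, appends it as a trailing comment.
print(gm.graph.python_code(root_module="self", expanded_def=True).src)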

View File

@ -309,6 +309,7 @@ def _print_readable(
include_stride=False,
include_device=False,
colored=False,
expanded_def=False,
):
graph = module.graph
assert graph is not None and isinstance(graph, torch.fx.Graph), (
@ -321,6 +322,7 @@ def _print_readable(
include_stride=include_stride,
include_device=include_device,
colored=colored,
expanded_def=expanded_def,
)
module_code = verbose_python_code.src
module_code = module_code.lstrip("\n")
@ -935,6 +937,7 @@ class {module_name}(torch.nn.Module):
# If `fast_sympy_print` is True then we use a sympy printer which is faster
# but may result in less-readable output.
fast_sympy_print: bool = False,
expanded_def: bool = False,
):
"""
Return the Python code generated for current GraphModule and its children GraphModules
@ -956,6 +959,7 @@ class {module_name}(torch.nn.Module):
include_stride,
include_device,
colored,
expanded_def,
)
return r
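
(Not part of the diff.) The user-facing entry point after this change, as a minimal sketch:

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return torch.sin(x) + 1

gm = torch.fx.symbolic_trace(M())
# Same defaults as before; expanded_def=True switches to the multiline definition
# and return formatting shown in the expect tests, with descriptor comments only
# when nodes carry a "desc" meta entry (as AOTAutograd now sets).
src = gm.print_readable(print_output=False, expanded_def=True)
print(src)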