Add expanded_def option for FX printing, render descriptor, update tests (#158708)
- First, we add a new expanded_def option to FX, which expands the definitions of variables into multiple lines, one per variable definition. This makes extremely long args/return lists much more readable.
- Next, we extend this mechanism to also print descriptors on placeholders and return values, as comments, when available. This is how we will test descriptors.
- We update tlparse for AOTAutograd to use this format.
- We update expect tests to use this format and refresh their expected outputs, so you can inspect what it looks like. There may be other tests I should update; open to suggestions.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158708
Approved by: https://github.com/wconstab
ghstack dependencies: #158624
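As a quick, hedged illustration of the new flag (the module below is a throwaway example written for this description, not taken from the PR's tests):

import torch
from torch import fx

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y, x * y

# Any GraphModule works; symbolic_trace is just the easiest way to get one here.
gm = fx.symbolic_trace(M())

# Default compact rendering vs. the expanded form this PR adds:
compact = gm.print_readable(print_output=False)
expanded = gm.print_readable(print_output=False, expanded_def=True)
print(expanded)  # one placeholder / returned value per line; descriptor comments
                 # appear when a node carries descriptor metadata (AOTAutograd graphs)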
This commit is contained in:
parent bf311141d6
commit 204eb4da5e
@ -2246,24 +2246,57 @@ class GraphModule(torch.nn.Module):
|
|||
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
||||
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
||||
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
||||
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
):
|
||||
mul: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
|
||||
mul_3: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
|
||||
return (mul, mul_3, primals_5, primals_7, primals_7, primals_1, primals_5, primals_7)
|
||||
return (
|
||||
mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
||||
mul_3, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
||||
primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_1, # SavedForBackwardsAOTOutput(idx=0)
|
||||
primals_5, # SavedForBackwardsAOTOutput(idx=1)
|
||||
primals_7, # SavedForBackwardsAOTOutput(idx=2)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
tangents_1: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
||||
tangents_2: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
||||
):
|
||||
mul_8: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = None
|
||||
mul_9: "f32[s47, s16]" = torch.ops.aten.mul.Tensor(tangents_2, primals_1); tangents_2 = primals_1 = None
|
||||
return (None, None, mul_8, mul_9, primals_5, primals_7, primals_7)
|
||||
return (
|
||||
None, # None
|
||||
None, # None
|
||||
mul_8, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
||||
mul_9, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
||||
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2279,27 +2312,58 @@ class GraphModule(torch.nn.Module):
|
|||
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
||||
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
||||
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
||||
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
):
|
||||
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
view: "f32[s16, s47]" = torch.ops.aten.view.default(clone, [primals_2, primals_1]); clone = None
|
||||
view_1: "f32[s16, s47]" = torch.ops.aten.view.default(clone_1, [primals_2, primals_1]); clone_1 = primals_1 = None
|
||||
return (view, view_1, primals_2, primals_5, primals_5, primals_5, primals_7)
|
||||
return (
|
||||
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
||||
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
||||
primals_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
||||
primals_5, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
||||
primals_7, # SavedForBackwardsAOTOutput(idx=1)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16, s47]", tangents_2: "f32[s16, s47]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
tangents_1: "f32[s16, s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
||||
tangents_2: "f32[s16, s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
||||
):
|
||||
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
return (
|
||||
None, # None
|
||||
None, # None
|
||||
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
||||
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
||||
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2317,10 +2381,19 @@ class GraphModule(torch.nn.Module):
|
|||
fw, bw = self._compile_check(f, [(tt, a, b)], dynamic=True, call_backward=True)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_3: "f32[s97, s98]", primals_4: "f32[s97, s98]", primals_5: "Sym(s97)", primals_6: "Sym(s98)", primals_7: "Sym(s98)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s97)", # PlainAOTInput(idx=0)
|
||||
primals_2: "Sym(s98)", # PlainAOTInput(idx=1)
|
||||
primals_3: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
||||
primals_4: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
||||
primals_5: "Sym(s97)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_6: "Sym(s98)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
||||
primals_7: "Sym(s98)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
):
|
||||
mul: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_3, primals_1); primals_3 = None
|
||||
mul_3: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(primals_4, primals_1); primals_4 = None
|
||||
mul_8: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul, primals_2); mul = None
|
||||
|
|
@ -2329,15 +2402,33 @@ class GraphModule(torch.nn.Module):
|
|||
mul_19: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_11, primals_1); mul_11 = None
|
||||
mul_24: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_16, primals_2); mul_16 = None
|
||||
mul_27: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_19, primals_2); mul_19 = None
|
||||
return (mul_24, mul_27, primals_5, primals_7, primals_7, primals_1, primals_2, primals_5, primals_7)
|
||||
return (
|
||||
mul_24, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
||||
mul_27, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
||||
primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_1, # SavedForBackwardsAOTOutput(idx=0)
|
||||
primals_2, # SavedForBackwardsAOTOutput(idx=1)
|
||||
primals_5, # SavedForBackwardsAOTOutput(idx=2)
|
||||
primals_7, # SavedForBackwardsAOTOutput(idx=3)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s97)", primals_2: "Sym(s98)", primals_5: "Sym(s97)", primals_7: "Sym(s98)", tangents_1: "f32[s97, s98]", tangents_2: "f32[s97, s98]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s97)", # PlainAOTInput(idx=0)
|
||||
primals_2: "Sym(s98)", # PlainAOTInput(idx=1)
|
||||
primals_5: "Sym(s97)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_7: "Sym(s98)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
tangents_1: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
||||
tangents_2: "f32[s97, s98]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
||||
):
|
||||
mul_32: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_1, primals_2); tangents_1 = None
|
||||
mul_33: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(tangents_2, primals_2); tangents_2 = None
|
||||
mul_34: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_32, primals_1); mul_32 = None
|
||||
|
|
@ -2346,7 +2437,15 @@ class GraphModule(torch.nn.Module):
|
|||
mul_37: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_35, primals_2); mul_35 = primals_2 = None
|
||||
mul_38: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_36, primals_1); mul_36 = None
|
||||
mul_39: "f32[s97, s98]" = torch.ops.aten.mul.Tensor(mul_37, primals_1); mul_37 = primals_1 = None
|
||||
return (None, None, mul_38, mul_39, primals_5, primals_7, primals_7)
|
||||
return (
|
||||
None, # None
|
||||
None, # None
|
||||
mul_38, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
||||
mul_39, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
||||
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2362,27 +2461,58 @@ class GraphModule(torch.nn.Module):
|
|||
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
||||
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
||||
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
||||
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
):
|
||||
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
view: "f32[s47, s16]" = torch.ops.aten.view.default(clone, [primals_1, primals_2]); clone = None
|
||||
view_1: "f32[s47, s16]" = torch.ops.aten.view.default(clone_1, [primals_1, primals_2]); clone_1 = primals_1 = primals_2 = None
|
||||
return (view, view_1, primals_5, primals_7, primals_7, primals_5, primals_7)
|
||||
return (
|
||||
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
||||
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_7, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
||||
primals_7, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
||||
primals_7, # SavedForBackwardsAOTOutput(idx=1)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s47, s16]", tangents_2: "f32[s47, s16]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
tangents_1: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
||||
tangents_2: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
||||
):
|
||||
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
return (
|
||||
None, # None
|
||||
None, # None
|
||||
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
||||
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
||||
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2398,28 +2528,57 @@ class GraphModule(torch.nn.Module):
|
|||
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
||||
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
||||
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
||||
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
):
|
||||
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
||||
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6]); clone = None
|
||||
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
||||
return (view, view_1, mul_6, primals_5, primals_7)
|
||||
return (
|
||||
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
||||
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
||||
mul_6, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
||||
primals_7, # SavedForBackwardsAOTOutput(idx=1)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
tangents_1: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
||||
tangents_2: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
||||
):
|
||||
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
return (
|
||||
None, # None
|
||||
None, # None
|
||||
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
||||
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
||||
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2435,28 +2594,58 @@ class GraphModule(torch.nn.Module):
|
|||
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s47)", primals_2: "Sym(s16)", primals_3: "f32[s47, s16]", primals_4: "f32[s47, s16]", primals_5: "Sym(s47)", primals_6: "Sym(s16)", primals_7: "Sym(s16)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s47)", # PlainAOTInput(idx=0)
|
||||
primals_2: "Sym(s16)", # PlainAOTInput(idx=1)
|
||||
primals_3: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='a')
|
||||
primals_4: "f32[s47, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=2), attr='b')
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_6: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=1)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
):
|
||||
clone: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
clone_1: "f32[s47, s16]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
mul_6: "Sym(s16*s47)" = primals_1 * primals_2; primals_1 = primals_2 = None
|
||||
view: "f32[s16*s47]" = torch.ops.aten.view.default(clone, [mul_6])
|
||||
view_1: "f32[s16*s47]" = torch.ops.aten.view.default(clone_1, [mul_6]); clone_1 = None
|
||||
return (clone, view, view_1, mul_6, primals_5, primals_7)
|
||||
return (
|
||||
clone, # PlainAOTOutput(idx=0)
|
||||
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
|
||||
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
|
||||
mul_6, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0)
|
||||
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
||||
primals_7, # SavedForBackwardsAOTOutput(idx=1)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s47)", primals_7: "Sym(s16)", tangents_1: "f32[s16*s47]", tangents_2: "f32[s16*s47]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_5: "Sym(s47)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
primals_7: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=2), idx=0)
|
||||
tangents_1: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
|
||||
tangents_2: "f32[s16*s47]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
|
||||
):
|
||||
view_2: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_1, [primals_5, primals_7]); tangents_1 = None
|
||||
view_3: "f32[s47, s16]" = torch.ops.aten.view.default(tangents_2, [primals_5, primals_7]); tangents_2 = None
|
||||
return (None, None, view_2, view_3, primals_5, primals_7, primals_7)
|
||||
return (
|
||||
None, # None
|
||||
None, # None
|
||||
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='a')
|
||||
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
primals_7, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=1)
|
||||
primals_7, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=2)), idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2517,53 +2706,94 @@ class GraphModule(torch.nn.Module):
|
|||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[3, 4]", primals_2: "f32[3, 4]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a')
|
||||
primals_2: "f32[3, 4]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b')
|
||||
):
|
||||
clone: "f32[3, 4]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
clone_1: "f32[3, 4]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
|
||||
view: "f32[12]" = torch.ops.aten.view.default(clone, [-1])
|
||||
view_1: "f32[12]" = torch.ops.aten.view.default(clone_1, [-1])
|
||||
return (clone, view, view_1, clone_1)
|
||||
return (
|
||||
clone, # PlainAOTOutput(idx=0)
|
||||
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
|
||||
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
|
||||
clone_1, # PlainAOTOutput(idx=2)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[1].print_readable(print_output=False)),
|
||||
normalize_gm(fw[1].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s16)", # PlainAOTInput(idx=0)
|
||||
primals_2: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='a')
|
||||
primals_3: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='b')
|
||||
primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1)
|
||||
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
|
||||
):
|
||||
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
|
||||
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
|
||||
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
|
||||
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
|
||||
return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
|
||||
return (
|
||||
clone, # PlainAOTOutput(idx=0)
|
||||
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
|
||||
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
|
||||
sym_size_int_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0)
|
||||
clone_1, # PlainAOTOutput(idx=2)
|
||||
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, tangents_1: "f32[12]", tangents_2: "f32[12]"):
|
||||
def forward(
|
||||
self,
|
||||
tangents_1: "f32[12]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
|
||||
tangents_2: "f32[12]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
|
||||
):
|
||||
view_2: "f32[3, 4]" = torch.ops.aten.view.default(tangents_1, [3, 4]); tangents_1 = None
|
||||
view_3: "f32[3, 4]" = torch.ops.aten.view.default(tangents_2, [3, 4]); tangents_2 = None
|
||||
return (view_2, view_3)
|
||||
return (
|
||||
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a')
|
||||
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b')
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[1].print_readable(print_output=False)),
|
||||
normalize_gm(bw[1].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
|
||||
tangents_1: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
|
||||
tangents_2: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
|
||||
):
|
||||
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
||||
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
||||
return (None, view_2, view_3, primals_5, primals_5)
|
||||
return (
|
||||
None, # None
|
||||
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='a')
|
||||
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1)
|
||||
primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2587,28 +2817,53 @@ class GraphModule(torch.nn.Module):
|
|||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s16)", primals_2: "f32[3, s16]", primals_3: "f32[3, s16]", primals_4: "Sym(s16)", primals_5: "Sym(s16)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s16)", # PlainAOTInput(idx=0)
|
||||
primals_2: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='a')
|
||||
primals_3: "f32[3, s16]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=1), attr='b')
|
||||
primals_4: "Sym(s16)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=1), idx=1)
|
||||
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
|
||||
):
|
||||
clone: "f32[3, s16]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
clone_1: "f32[3, s16]" = torch.ops.aten.clone.default(primals_3); primals_3 = None
|
||||
|
||||
view: "f32[3*s16]" = torch.ops.aten.view.default(clone, [-1])
|
||||
sym_size_int_2: "Sym(3*s16)" = torch.ops.aten.sym_size.int(view, 0)
|
||||
view_1: "f32[3*s16]" = torch.ops.aten.view.default(clone_1, [-1])
|
||||
return (clone, view, view_1, sym_size_int_2, clone_1, primals_5)
|
||||
return (
|
||||
clone, # PlainAOTOutput(idx=0)
|
||||
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='a')
|
||||
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=1), attr='b')
|
||||
sym_size_int_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=1), idx=0)
|
||||
clone_1, # PlainAOTOutput(idx=2)
|
||||
primals_5, # SavedForBackwardsAOTOutput(idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_5: "Sym(s16)", tangents_1: "f32[3*s16]", tangents_2: "f32[3*s16]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_5: "Sym(s16)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=1), idx=0)
|
||||
tangents_1: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='a')
|
||||
tangents_2: "f32[3*s16]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=1)), attr='b')
|
||||
):
|
||||
view_2: "f32[3, s16]" = torch.ops.aten.view.default(tangents_1, [3, primals_5]); tangents_1 = None
|
||||
view_3: "f32[3, s16]" = torch.ops.aten.view.default(tangents_2, [3, primals_5]); tangents_2 = None
|
||||
return (None, view_2, view_3, primals_5, primals_5)
|
||||
return (
|
||||
None, # None
|
||||
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='a')
|
||||
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), attr='b')
|
||||
primals_5, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=1)
|
||||
primals_5, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=1)), idx=0)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2624,27 +2879,41 @@ class GraphModule(torch.nn.Module):
|
|||
fw, bw = self._compile_check(f, [(tt,)], dynamic=True, call_backward=True)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "f32[24]", primals_2: "f32[24]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='a')
|
||||
primals_2: "f32[24]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=0), attr='b')
|
||||
):
|
||||
clone: "f32[24]" = torch.ops.aten.clone.default(primals_1); primals_1 = None
|
||||
clone_1: "f32[24]" = torch.ops.aten.clone.default(primals_2); primals_2 = None
|
||||
|
||||
view: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone, [3, 2, 4]); clone = None
|
||||
view_1: "f32[3, 2, 4]" = torch.ops.aten.view.default(clone_1, [3, 2, 4]); clone_1 = None
|
||||
return (view, view_1)
|
||||
return (
|
||||
view, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='a')
|
||||
view_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='b')
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, tangents_1: "f32[3, 2, 4]", tangents_2: "f32[3, 2, 4]"):
|
||||
def forward(
|
||||
self,
|
||||
tangents_1: "f32[3, 2, 4]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='a')
|
||||
tangents_2: "f32[3, 2, 4]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='b')
|
||||
):
|
||||
view_2: "f32[24]" = torch.ops.aten.view.default(tangents_1, [24]); tangents_1 = None
|
||||
view_3: "f32[24]" = torch.ops.aten.view.default(tangents_2, [24]); tangents_2 = None
|
||||
return (view_2, view_3)
|
||||
return (
|
||||
view_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='a')
|
||||
view_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=0)), attr='b')
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2771,24 +3040,67 @@ class GraphModule(torch.nn.Module):
|
|||
fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
|
||||
primals_2: "Sym(s71)", # PlainAOTInput(idx=1)
|
||||
primals_3: "Sym(s55)", # PlainAOTInput(idx=2)
|
||||
primals_4: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values')
|
||||
primals_5: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets')
|
||||
primals_6: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor')
|
||||
primals_7: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor')
|
||||
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
||||
primals_9: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2)
|
||||
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
||||
):
|
||||
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
mul: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(clone, primals_1); clone = None
|
||||
return (mul, primals_5, primals_6, primals_7, primals_8, primals_10, primals_10, primals_1, primals_8, primals_10)
|
||||
return (
|
||||
mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
|
||||
primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
|
||||
primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
|
||||
primals_7, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
|
||||
primals_8, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
primals_10, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
|
||||
primals_10, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
||||
primals_1, # SavedForBackwardsAOTOutput(idx=0)
|
||||
primals_8, # SavedForBackwardsAOTOutput(idx=1)
|
||||
primals_10, # SavedForBackwardsAOTOutput(idx=2)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s51)", primals_8: "Sym(s51)", primals_10: "Sym(s55)", tangents_1: "f64[s64, s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
|
||||
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
||||
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
||||
tangents_1: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values')
|
||||
tangents_2: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_offsets')
|
||||
tangents_3: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_min_seqlen_tensor')
|
||||
tangents_4: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_max_seqlen_tensor')
|
||||
):
|
||||
mul_1: "f64[s64, s55]" = torch.ops.aten.mul.Tensor(tangents_1, primals_1); tangents_1 = primals_1 = None
|
||||
return (None, None, None, mul_1, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
|
||||
return (
|
||||
None, # None
|
||||
None, # None
|
||||
None, # None
|
||||
mul_1, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_values')
|
||||
tangents_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_offsets')
|
||||
tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor')
|
||||
tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor')
|
||||
primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0)
|
||||
primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
|
||||
primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2804,28 +3116,71 @@ class GraphModule(torch.nn.Module):
|
|||
fw, bw = self._compile_check(f, [(nt,)], dynamic=True, call_backward=True)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_1: "Sym(s51)", primals_2: "Sym(s71)", primals_3: "Sym(s55)", primals_4: "f64[s64, s55]", primals_5: "i64[s51 + 1]", primals_6: "f32[s0, 0]", primals_7: "f32[s83, 0]", primals_8: "Sym(s51)", primals_9: "Sym(s55)", primals_10: "Sym(s55)"):
|
||||
def forward(
|
||||
self,
|
||||
primals_1: "Sym(s51)", # PlainAOTInput(idx=0)
|
||||
primals_2: "Sym(s71)", # PlainAOTInput(idx=1)
|
||||
primals_3: "Sym(s55)", # PlainAOTInput(idx=2)
|
||||
primals_4: "f64[s64, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values')
|
||||
primals_5: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets')
|
||||
primals_6: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor')
|
||||
primals_7: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor')
|
||||
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
||||
primals_9: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2)
|
||||
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
||||
):
|
||||
clone: "f64[s64, s55]" = torch.ops.aten.clone.default(primals_4); primals_4 = None
|
||||
|
||||
cat: "f64[s64, 2*s55]" = torch.ops.aten.cat.default([clone, clone], 1); clone = None
|
||||
add_2: "Sym(2*s55)" = primals_10 + primals_10
|
||||
return (cat, primals_5, primals_6, primals_7, primals_8, add_2, add_2, primals_8, primals_10, add_2)
|
||||
return (
|
||||
cat, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
|
||||
primals_5, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
|
||||
primals_6, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
|
||||
primals_7, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
|
||||
primals_8, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=0)
|
||||
add_2, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
|
||||
add_2, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
||||
primals_8, # SavedForBackwardsAOTOutput(idx=0)
|
||||
primals_10, # SavedForBackwardsAOTOutput(idx=1)
|
||||
add_2, # SavedForBackwardsAOTOutput(idx=2)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(bw[0].print_readable(print_output=False)),
|
||||
normalize_gm(bw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class GraphModule(torch.nn.Module):
|
||||
def forward(self, primals_8: "Sym(s51)", primals_10: "Sym(s55)", add_2: "Sym(2*s55)", tangents_1: "f64[s64, 2*s55]", tangents_2: "i64[s51 + 1]", tangents_3: "f32[s0, 0]", tangents_4: "f32[s83, 0]"):
|
||||
def forward(
|
||||
self,
|
||||
primals_8: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
||||
primals_10: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
||||
add_2: "Sym(2*s55)",
|
||||
tangents_1: "f64[s64, 2*s55]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_values')
|
||||
tangents_2: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_offsets')
|
||||
tangents_3: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_min_seqlen_tensor')
|
||||
tangents_4: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=TangentAOTInput(output=PlainAOTOutput(idx=0)), attr='_max_seqlen_tensor')
|
||||
):
|
||||
slice_1: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, 0, primals_10)
|
||||
slice_2: "f64[s64, s55]" = torch.ops.aten.slice.Tensor(tangents_1, 1, primals_10, add_2); tangents_1 = add_2 = None
|
||||
|
||||
add_4: "f64[s64, s55]" = torch.ops.aten.add.Tensor(slice_1, slice_2); slice_1 = slice_2 = None
|
||||
return (None, None, None, add_4, tangents_2, tangents_3, tangents_4, primals_8, primals_10, primals_10)
|
||||
return (
|
||||
None, # None
|
||||
None, # None
|
||||
None, # None
|
||||
add_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_values')
|
||||
tangents_2, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_offsets')
|
||||
tangents_3, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_min_seqlen_tensor')
|
||||
tangents_4, # SubclassGetAttrAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), attr='_max_seqlen_tensor')
|
||||
primals_8, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=0)
|
||||
primals_10, # SubclassSizeAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=2)
|
||||
primals_10, # SubclassStrideAOTOutput(base=GradAOTOutput(grad_of=PlainAOTInput(idx=3)), idx=1)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
|
||||
|
|
@ -2850,10 +3205,22 @@ class GraphModule(torch.nn.Module):
|
|||
)
|
||||
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(fw[0].print_readable(print_output=False)),
|
||||
normalize_gm(fw[0].print_readable(print_output=False, expanded_def=True)),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "Sym(s51)", arg1_1: "Sym(s71)", arg2_1: "Sym(s55)", arg3_1: "f64[9, s55]", arg4_1: "i64[s51 + 1]", arg5_1: "f32[s0, 0]", arg6_1: "f32[s83, 0]", arg7_1: "Sym(s51)", arg8_1: "Sym(s55)", arg9_1: "Sym(s55)"):
|
||||
def forward(
|
||||
self,
|
||||
arg0_1: "Sym(s51)", # PlainAOTInput(idx=0)
|
||||
arg1_1: "Sym(s71)", # PlainAOTInput(idx=1)
|
||||
arg2_1: "Sym(s55)", # PlainAOTInput(idx=2)
|
||||
arg3_1: "f64[9, s55]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_values')
|
||||
arg4_1: "i64[s51 + 1]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_offsets')
|
||||
arg5_1: "f32[s0, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_min_seqlen_tensor')
|
||||
arg6_1: "f32[s83, 0]", # SubclassGetAttrAOTInput(base=PlainAOTInput(idx=3), attr='_max_seqlen_tensor')
|
||||
arg7_1: "Sym(s51)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=0)
|
||||
arg8_1: "Sym(s55)", # SubclassSizeAOTInput(base=PlainAOTInput(idx=3), idx=2)
|
||||
arg9_1: "Sym(s55)", # SubclassStrideAOTInput(base=PlainAOTInput(idx=3), idx=1)
|
||||
):
|
||||
randn: "f64[2, 5]" = torch.ops.aten.randn.default([2, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
||||
randn_1: "f64[3, 5]" = torch.ops.aten.randn.default([3, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
||||
randn_2: "f64[4, 5]" = torch.ops.aten.randn.default([4, 5], dtype = torch.float64, device = device(type='cpu'), pin_memory = False)
|
||||
|
|
@ -2874,7 +3241,14 @@ class <lambda>(torch.nn.Module):
|
|||
|
||||
sym_size_int: "Sym(s55 + 5)" = torch.ops.aten.sym_size.int(cat_2, 1); cat_2 = None
|
||||
sym_stride_int: "Sym(s55 + 5)" = torch.ops.aten.sym_stride.int(mul, 0)
|
||||
return (mul, cat_1, zeros_1, zeros_2, sym_size_int, sym_stride_int)
|
||||
return (
|
||||
mul, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_values')
|
||||
cat_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_offsets')
|
||||
zeros_1, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_min_seqlen_tensor')
|
||||
zeros_2, # SubclassGetAttrAOTOutput(base=PlainAOTOutput(idx=0), attr='_max_seqlen_tensor')
|
||||
sym_size_int, # SubclassSizeAOTOutput(base=PlainAOTOutput(idx=0), idx=2)
|
||||
sym_stride_int, # SubclassStrideAOTOutput(base=PlainAOTOutput(idx=0), idx=1)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
@ -23,7 +23,7 @@ torch.fx.graph.Graph.node_copy(self, node: torch.fx.node.Node, arg_transform: Ca
torch.fx.graph.Graph.output(self, result: 'Argument', type_expr: Optional[Any] = None)
torch.fx.graph.Graph.placeholder(self, name: str, type_expr: Optional[Any] = None, default_value: Any) -> torch.fx.node.Node
torch.fx.graph.Graph.print_tabular(self)
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph.Graph.python_code(self, root_module: str, verbose: bool = False, include_stride: bool = False, include_device: bool = False, colored: bool = False, expanded_def: bool = False) -> torch.fx.graph.PythonCode
torch.fx.graph_module.GraphModule.__init__(self, root: Union[torch.nn.modules.module.Module, Dict[str, Any]], graph: torch.fx.graph.Graph, class_name: str = 'GraphModule')
torch.fx.graph_module.GraphModule.add_submodule(self, target: str, m: torch.nn.modules.module.Module) -> bool
torch.fx.graph_module.GraphModule.delete_all_unused_submodules(self) -> None
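To make the extended signature above concrete, a hedged sketch (the traced function and the root_module value are illustrative; GraphModule normally calls python_code internally):

import torch
from torch import fx

def f(a, b):
    return (a + b).relu()

g = fx.symbolic_trace(f).graph

# python_code returns a PythonCode object; .src holds the generated source text.
code_compact = g.python_code(root_module="self")
code_expanded = g.python_code(root_module="self", expanded_def=True)
print(code_expanded.src)  # forward(...) signature rendered one parameter per line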
@ -4837,10 +4837,14 @@ def forward(self, arg0_1):
|
|||
inps = [torch.randn(2, 2), torch.ones(2)]
|
||||
gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(gm.print_readable(False)),
|
||||
normalize_gm(gm.print_readable(False, expanded_def=True)),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):
|
||||
def forward(
|
||||
self,
|
||||
arg0_1: "f32[2, 2]", # PlainAOTInput(idx=0)
|
||||
arg1_1: "f32[2]", # PlainAOTInput(idx=1)
|
||||
):
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(arg0_1)
|
||||
gt: "b8[]" = torch.ops.aten.gt.Scalar(sum_1, 4); sum_1 = None
|
||||
|
||||
|
|
@ -4851,7 +4855,10 @@ class <lambda>(torch.nn.Module):
|
|||
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem, 3)
|
||||
add_1: "f32[2, 2]" = torch.ops.aten.add.Tensor(getitem, 4); getitem = None
|
||||
return (add, add_1)
|
||||
return (
|
||||
add, # PlainAOTOutput(idx=0)
|
||||
add_1, # PlainAOTOutput(idx=1)
|
||||
)
|
||||
|
||||
class true_graph_0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):
|
||||
|
|
@ -4925,10 +4932,14 @@ class <lambda>(torch.nn.Module):
|
|||
inps = [torch.randn(2, 2), torch.ones(2)]
|
||||
gm, _ = aot_export_module(M(), inps, trace_joint=False, pre_dispatch=True)
|
||||
self.assertExpectedInline(
|
||||
normalize_gm(gm.print_readable(False)),
|
||||
normalize_gm(gm.print_readable(False, expanded_def=True)),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2, 2]", arg1_1: "f32[2]"):
|
||||
def forward(
|
||||
self,
|
||||
arg0_1: "f32[2, 2]", # PlainAOTInput(idx=0)
|
||||
arg1_1: "f32[2]", # PlainAOTInput(idx=1)
|
||||
):
|
||||
cos: "f32[2, 2]" = torch.ops.aten.cos.default(arg0_1); arg0_1 = None
|
||||
|
||||
_set_grad_enabled = torch._C._set_grad_enabled(True); _set_grad_enabled = None
|
||||
|
|
@ -4939,7 +4950,9 @@ class <lambda>(torch.nn.Module):
|
|||
|
||||
sum_1: "f32[]" = torch.ops.aten.sum.default(getitem_2); getitem_2 = None
|
||||
add: "f32[2, 2]" = torch.ops.aten.add.Tensor(cos, sum_1); cos = sum_1 = None
|
||||
return (add,)
|
||||
return (
|
||||
add, # PlainAOTOutput(idx=0)
|
||||
)
|
||||
|
||||
class body_graph_0(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[2]", arg1_1: "f32[2]"):
|
||||
|
|
@ -5098,10 +5111,20 @@ def forward(self, arg0_1):
|
|||
for node in fx_g.graph.nodes:
|
||||
node.meta.pop("stack_trace", None)
|
||||
self.assertExpectedInline(
|
||||
fx_g.print_readable(print_output=False),
|
||||
fx_g.print_readable(print_output=False, expanded_def=True),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"):
|
||||
def forward(
|
||||
self,
|
||||
arg0_1: "f32[3, 1, 1, 1]",
|
||||
arg1_1: "f32[3]",
|
||||
arg2_1: "f32[3]",
|
||||
arg3_1: "f32[3]",
|
||||
arg4_1: "f32[3]",
|
||||
arg5_1: "f32[3]",
|
||||
arg6_1: "i64[]",
|
||||
arg7_1: "f32[1, 1, 3, 3]",
|
||||
):
|
||||
# No stacktrace found for following nodes
|
||||
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg1_1 = None
|
||||
add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None
|
||||
|
|
@ -5183,10 +5206,20 @@ class <lambda>(torch.nn.Module):
|
|||
for node in fx_g_inference.graph.nodes:
|
||||
node.meta.pop("stack_trace", None)
|
||||
self.assertExpectedInline(
|
||||
fx_g_inference.print_readable(print_output=False),
|
||||
fx_g_inference.print_readable(print_output=False, expanded_def=True),
|
||||
"""\
|
||||
class <lambda>(torch.nn.Module):
|
||||
def forward(self, arg0_1: "f32[3, 1, 1, 1]", arg1_1: "f32[3]", arg2_1: "f32[3]", arg3_1: "f32[3]", arg4_1: "f32[3]", arg5_1: "f32[3]", arg6_1: "i64[]", arg7_1: "f32[1, 1, 3, 3]"):
|
||||
def forward(
|
||||
self,
|
||||
arg0_1: "f32[3, 1, 1, 1]", # PlainAOTInput(idx=0)
|
||||
arg1_1: "f32[3]", # PlainAOTInput(idx=1)
|
||||
arg2_1: "f32[3]", # PlainAOTInput(idx=2)
|
||||
arg3_1: "f32[3]", # PlainAOTInput(idx=3)
|
||||
arg4_1: "f32[3]", # PlainAOTInput(idx=4)
|
||||
arg5_1: "f32[3]", # PlainAOTInput(idx=5)
|
||||
arg6_1: "i64[]", # PlainAOTInput(idx=6)
|
||||
arg7_1: "f32[1, 1, 3, 3]", # PlainAOTInput(idx=7)
|
||||
):
|
||||
# No stacktrace found for following nodes
|
||||
convolution: "f32[1, 3, 3, 3]" = torch.ops.aten.convolution.default(arg7_1, arg0_1, arg1_1, [1, 1], [0, 0], [1, 1], False, [0, 0], 1); arg7_1 = arg0_1 = arg1_1 = None
|
||||
add: "i64[]" = torch.ops.aten.add.Tensor(arg6_1, 1); arg6_1 = None
|
||||
|
|
@ -5199,7 +5232,13 @@ class <lambda>(torch.nn.Module):
|
|||
detach: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(relu); relu = None
|
||||
detach_1: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach); detach = None
|
||||
detach_2: "f32[1, 3, 3, 3]" = torch.ops.aten.detach.default(detach_1); detach_1 = None
|
||||
return (getitem_3, getitem_4, add, sum_1, detach_2)
|
||||
return (
|
||||
getitem_3, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=4))
|
||||
getitem_4, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=5))
|
||||
add, # InputMutationAOTOutput(mutated_input=PlainAOTInput(idx=6))
|
||||
sum_1, # PlainAOTOutput(idx=0)
|
||||
detach_2, # PlainAOTOutput(idx=1)
|
||||
)
|
||||
""", # noqa: B950
|
||||
)
|
||||
# Some important characteristics of the exported graph below:
@@ -739,3 +739,11 @@ class DummyAOTOutput(AOTOutput):

    def expr(self) -> str:
        return f"__dummy{self.idx}"


@dataclasses.dataclass(frozen=True)
class SavedForBackwardsAOTOutput(AOTOutput):
    idx: int

    def expr(self) -> str:
        return f"__saved_for_backwards_{self.idx}"
@@ -19,7 +19,7 @@ from torch.fx.experimental.proxy_tensor import make_fx
from torchgen.utils import dataclass_repr

from .. import config
from .descriptors import AOTInput
from .descriptors import AOTInput, BackwardTokenAOTInput
from .functional_utils import (
    assert_functional_graph,
    propagate_input_mutation_stacktraces,

@@ -36,6 +36,7 @@ from .schemas import AOTConfig, FxValue, SubclassMeta, TraceFn, ViewAndMutationM
from .utils import (
    call_and_expect_output_descs,
    copy_fwd_metadata_to_bw_nodes,
    fn_wrappers,
    register_buffer_assignment_hook,
    root_module_when_exporting_non_strict,
    unlift_tokens,

@@ -65,11 +66,10 @@ def _create_graph(
    @functools.wraps(f)
    def inner_f(*args):
        nonlocal out_descs
        assert out_descs is None
        out, out_descs = call_and_expect_output_descs(f, args)
        return out

    # TODO: save args_descs/out_descs to the produced FX graph

    with (
        enable_python_dispatcher(),
        FunctionalTensorMode(

@@ -86,6 +86,49 @@ def _create_graph(
            pre_dispatch=aot_config.pre_dispatch,
        )(*args)

    if args_descs is not None:
        flat_args_descs, _ = pytree.tree_flatten(args_descs)
        flat_out_descs, _ = pytree.tree_flatten(out_descs)

        # Unfortunately, flat_args_descs is not guaranteed to match the
        # number of actual arguments that show up on the FX graph.
        # Specifically, allow_token_discovery=True means that we will
        # silently add extra token arguments to the backwards graph.
        #
        # Although there are a few ways to detect what these tokens are,
        # we are going to settle for something dodgy but simple to
        # implement: match tangents_token placeholders specifically,
        # as these are the only placeholders that are created by token
        # discovery (NB: there is NO other code that treats this name
        # as load bearing, so this is a bit naughty!)
        #
        # I originally wanted to detect tokens in exactly the same way
        # that they are detected at normal runtime, but to be honest
        # the normal runtime detection is pretty strange: it seems the
        # backward tokens are not reliably at the end of the argument list
        # but *precede* the RNG arguments (I don't understand why this is
        # the case). And in unlift_tokens, token arguments are detected
        # by seeing if they feed into an effects call! Dastardly. Why
        # didn't we just introduce a new type.

        i = 0
        j = 0
        for n in fx_g.graph.nodes:
            if n.op == "placeholder":
                if n.name.startswith("tangents_token"):
                    n.meta["desc"] = BackwardTokenAOTInput(j)
                    j += 1
                else:
                    assert i < len(flat_args_descs), (
                        (fn_wrappers(inner_f)),
                        [n for n in fx_g.graph.nodes if n.op == "placeholder"],
                        flat_args_descs,
                    )
                    n.meta["desc"] = flat_args_descs[i]
                    i += 1
            elif n.op == "output":
                n.meta["desc"] = flat_out_descs

    return fx_g
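As a rough sketch of what this bookkeeping enables, the descriptors can later be read straight back off any FX graph whose placeholders carry a "desc" meta entry as set above (hypothetical helper, for illustration only):

def placeholder_descs(fx_g):
    # Collect (placeholder name, descriptor) pairs in graph order; nodes
    # without a "desc" entry (e.g. graphs traced elsewhere) map to None.
    return [
        (n.name, n.meta.get("desc"))
        for n in fx_g.graph.nodes
        if n.op == "placeholder"
    ]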
@@ -276,7 +319,10 @@ def aot_dispatch_base_graph(
    trace_structured(
        "aot_inference_graph",
        payload_fn=lambda: fw_module.print_readable(
            print_output=False, include_stride=True, include_device=True
            print_output=False,
            include_stride=True,
            include_device=True,
            expanded_def=True,
        ),
    )

@@ -266,6 +266,7 @@ def aot_stage2_inference(
        include_stride=True,
        include_device=True,
        fast_sympy_print=True,
        expanded_def=True,
    )

    fakified_out_wrapper = FakifiedOutWrapper()

@@ -877,7 +878,10 @@ def maybe_log_graph(

    def gm_str_fn() -> str:
        return gm.print_readable(
            print_output=False, include_stride=True, include_device=True
            print_output=False,
            include_stride=True,
            include_device=True,
            expanded_def=True,
        )

    if out_structured_logs is not None:

@@ -1347,7 +1351,10 @@ def aot_stage2_autograd(
            ),
        )
        joint_graph_str = fx_g.print_readable(
            print_output=False, include_stride=True, include_device=True
            print_output=False,
            include_stride=True,
            include_device=True,
            expanded_def=True,
        )
        trace_structured(
            "aot_joint_graph",

@@ -1596,10 +1603,16 @@ def aot_stage2_autograd(
            ),
        )
        fw_module_str = fw_module.print_readable(
            print_output=False, include_stride=True, include_device=True
            print_output=False,
            include_stride=True,
            include_device=True,
            expanded_def=True,
        )
        bw_module_str = bw_module.print_readable(
            print_output=False, include_stride=True, include_device=True
            print_output=False,
            include_stride=True,
            include_device=True,
            expanded_def=True,
        )

        trace_structured(

@@ -544,3 +544,12 @@ def call_and_expect_output_descs(fn, args):
        outs_descs,
    )
    return outs_pair


def fn_wrappers(fn):
    fns = [fn]
    f = fn
    while hasattr(f, "__wrapped__"):
        f = f.__wrapped__
        fns.append(f)
    return fns
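fn_wrappers simply walks the chain that functools.wraps records via __wrapped__; a small usage sketch (hypothetical functions, shown only to illustrate the assertion context it feeds above):

import functools

def base(x):
    return x + 1

@functools.wraps(base)
def wrapper(x):
    return base(x)

# fn_wrappers(wrapper) would return [wrapper, base]: the outermost callable,
# followed by each function reachable through successive __wrapped__ attributes.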
@@ -48,6 +48,7 @@ from ._activation_checkpointing.knapsack import (
    ilp_knapsack,
)
from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
from ._aot_autograd.logging_utils import get_aot_graph_name
from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems

@@ -175,6 +176,7 @@ def _extract_graph_with_inputs_outputs(
    joint_graph: fx.Graph,
    inputs: list[fx.Node],
    outputs: list[fx.Node],
    outputs_descs: list[AOTOutput],
    subgraph: Optional[str] = None,
) -> fx.Graph:
    """

@@ -239,7 +241,8 @@ def _extract_graph_with_inputs_outputs(
            output_values.append(env[x])
        else:
            output_values.append(x)
    new_graph.output(tuple(output_values))
    out = new_graph.output(tuple(output_values))
    out.meta["desc"] = outputs_descs

    new_graph.eliminate_dead_code()
    new_graph.lint()

@@ -299,13 +302,20 @@ def _must_be_in_backward(node: fx.Node) -> bool:

def _extract_fwd_bwd_outputs(
    joint_module: fx.GraphModule, *, num_fwd_outputs
) -> tuple[list[fx.Node], list[fx.Node]]:
) -> tuple[list[fx.Node], list[fx.Node], list[AOTOutput], list[AOTOutput]]:
    outputs = pytree.arg_tree_leaves(
        *(node.args for node in joint_module.graph.find_nodes(op="output"))
    )
    outputs_descs = pytree.arg_tree_leaves(
        next(iter(joint_module.graph.find_nodes(op="output"))).meta.get(
            "desc", [None] * len(outputs)
        )
    )
    fwd_outputs = outputs[:num_fwd_outputs]
    bwd_outputs = outputs[num_fwd_outputs:]
    return fwd_outputs, bwd_outputs
    fwd_outputs_descs = outputs_descs[:num_fwd_outputs]
    bwd_outputs_descs = outputs_descs[num_fwd_outputs:]
    return fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs


def _remove_by_name(saved_values: list[fx.Node], name: str):
@@ -827,8 +837,8 @@ def _extract_fwd_bwd_modules(
    num_fwd_outputs: int,
    static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
        joint_module, num_fwd_outputs=num_fwd_outputs
    fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
        _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    )
    placeholders = joint_module.graph.find_nodes(op="placeholder")
    primal_inputs = [*filter(_is_primal, placeholders)]

@@ -841,6 +851,7 @@ def _extract_fwd_bwd_modules(
        joint_module.graph,
        saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
        bwd_outputs,
        bwd_outputs_descs,
        "backward",
    )

@@ -914,6 +925,11 @@ def _extract_fwd_bwd_modules(
        joint_module.graph,
        primal_inputs + fwd_seed_offset_inputs,
        fwd_outputs + saved_values + saved_sym_nodes,
        fwd_outputs_descs
        + [
            SavedForBackwardsAOTOutput(i)
            for i in range(len(saved_values) + len(saved_sym_nodes))
        ],
        "forward",
    )
    bwd_graph = _extract_graph_with_inputs_outputs(

@@ -924,6 +940,7 @@ def _extract_fwd_bwd_modules(
        + bwd_seed_offset_inputs
        + backward_state_inputs,
        bwd_outputs,
        bwd_outputs_descs,
        "backward",
    )

@@ -976,11 +993,11 @@ def default_partition(
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
    inputs = primal_inputs + fwd_seed_offset_inputs
    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
        joint_module, num_fwd_outputs=num_fwd_outputs
    fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
        _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    )
    forward_only_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, inputs, fwd_outputs, "forward"
        joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
    )
    forward_node_names = OrderedSet(
        node.name for node in forward_only_graph.nodes if node.op != "output"

@@ -2651,14 +2668,14 @@ def min_cut_rematerialization_partition(
        filter(_is_fwd_seed_offset, joint_module.graph.nodes)
    )
    inputs = primal_inputs + fwd_seed_offset_inputs
    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
        joint_module, num_fwd_outputs=num_fwd_outputs
    fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
        _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
    )
    required_bw_nodes.update(
        o for o in bwd_outputs if o is not None and o.op != "output"
    )
    forward_only_graph = _extract_graph_with_inputs_outputs(
        joint_module.graph, inputs, fwd_outputs, "forward"
        joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
    )
    required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
        name_to_node[node.name]

@@ -880,6 +880,8 @@ def _export_to_aten_ir(
    new_output_node = list(new_gm.graph.nodes)[-1]
    assert old_output_node.op == "output" and new_output_node.op == "output"
    # make sure we don't override any meta
    if "desc" in new_output_node.meta:
        del new_output_node.meta["desc"]
    assert len(new_output_node.meta) == 0
    new_output_node.meta.update(old_output_node.meta)
@@ -324,7 +324,55 @@ class CodeGen:
        self._body_transformer: Optional[TransformCodeFunc] = None
        self._func_name: str = "forward"

    def gen_fn_def(self, free_vars: list[str], maybe_return_annotation: str) -> str:
    def _format_multiline_args(self, args: list[str]) -> str:
        """Helper to format function arguments in expanded multiline format."""
        return "".join(self._format_single_arg(arg) for arg in args)

    def _format_single_arg(self, arg: str) -> str:
        """Helper to format a single argument with optional comment."""
        if "#" in arg:
            arg_part, comment_part = arg.split("#", 1)
            return f" {arg_part.rstrip()}, # {comment_part.lstrip()}\n"
        else:
            return f" {arg},\n"

    def _get_delimiters(self, container) -> tuple[str, str]:
        """Helper to get opening and closing delimiters for containers."""
        return ("(", ")") if isinstance(container, tuple) else ("[", "]")

    def _format_multiline_container(self, items, descs=None, prefix="") -> str:
        """Helper to format containers (lists/tuples) in multiline format."""
        ldelim, rdelim = self._get_delimiters(items)
        desc_trailers = self._get_desc_trailers(items, descs)

        return (
            f"{prefix}{ldelim}\n"
            + "".join(
                f" {item},{trailer}\n" for item, trailer in zip(items, desc_trailers)
            )
            + f"{rdelim}"
        )

    def _get_desc_trailers(self, items, descs):
        """Helper to generate description trailers for items."""
        if descs is None:
            return [""] * len(items)
        return [f" # {desc}" for desc in descs]

    def _call_method_with_signature_check(self, method, *args, **kwargs):
        """Helper to call a method with optional parameters based on signature."""
        sig = inspect.signature(method)
        # Filter kwargs to only include parameters that exist in the method signature
        filtered_kwargs = {k: v for k, v in kwargs.items() if k in sig.parameters}
        return method(*args, **filtered_kwargs)

    def gen_fn_def(
        self,
        free_vars: list[str],
        maybe_return_annotation: str,
        *,
        expanded_def: bool = False,
    ) -> str:
        """
        Given the free variables and a return annotation, generates the beginning of the FX function.
        By default, `gen_fn_def(['a', 'b'], '') == 'def {self._func_name}(a, b):'`
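A rough, standalone sketch of what the new formatting helpers produce (logic mirrored from the methods above; the exact whitespace in the real helpers may differ slightly):

def format_multiline_container(items, descs=None, prefix=""):
    # Render each item on its own line, with an optional "# <descriptor>" trailer.
    ldelim, rdelim = ("(", ")") if isinstance(items, tuple) else ("[", "]")
    trailers = [""] * len(items) if descs is None else [f"  # {d}" for d in descs]
    return (
        f"{prefix}{ldelim}\n"
        + "".join(f"    {item},{trailer}\n" for item, trailer in zip(items, trailers))
        + rdelim
    )

print(format_multiline_container(("mul", "add"), ["PlainAOTOutput(idx=0)", "PlainAOTOutput(idx=1)"], "return "))
# return (
#     mul,  # PlainAOTOutput(idx=0)
#     add,  # PlainAOTOutput(idx=1)
# )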
@@ -333,16 +381,26 @@ class CodeGen:
        # would have added it.
        if len(free_vars) == 0 or free_vars[0] != "self":
            free_vars.insert(0, "self")
        return (
            f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"
        )

    def generate_output(self, output_args: Argument) -> str:
        if expanded_def:
            args_formatted = self._format_multiline_args(free_vars)
            return (
                f"def {self._func_name}(\n{args_formatted}){maybe_return_annotation}:"
            )
        else:
            return f"def {self._func_name}({', '.join(free_vars)}){maybe_return_annotation}:"

    def generate_output(
        self, output_args: Argument, *, descs: Optional[Any] = None
    ) -> str:
        """
        Given the output arguments, generates the return statement of the FX function.
        Note: The returned statement should not be indented.
        """
        return f"return {repr(output_args)}"
        if descs is not None and isinstance(output_args, (list, tuple)):
            return self._format_multiline_container(output_args, descs, "return ")
        else:
            return f"return {repr(output_args)}"

    def process_inputs(self, *args: Any) -> Any:
        """

@@ -380,6 +438,8 @@ class CodeGen:
        include_stride: bool = False,
        include_device: bool = False,
        colored: bool = False,
        # Render each argument on its own line
        expanded_def: bool = False,
    ) -> PythonCode:
        free_vars: list[str] = []
        body: list[str] = []

@@ -586,6 +646,7 @@ class CodeGen:
            maybe_type_annotation = (
                "" if node.type is None else f" : {type_repr(node.type)}"
            )
            maybe_comment = ""

            if verbose:
                # override annotation with more detailed information

@@ -617,13 +678,20 @@ class CodeGen:
                elif isinstance(meta_val, TensorMetadata):
                    maybe_type_annotation = f': "{dtype_abbrs[meta_val.dtype]}{stringify_shape(meta_val.shape)}"'

            desc = None
            if expanded_def:
                desc = node.meta.get("desc", None)
                if desc is not None and node.op == "placeholder":
                    maybe_comment += f" # {desc}"
                # output is handled specially

            if node.op == "placeholder":
                assert isinstance(node.target, str)
                maybe_default_arg = (
                    "" if not node.args else f" = {_get_repr(node.args[0])}"
                )
                free_vars.append(
                    f"{node.target}{maybe_type_annotation}{maybe_default_arg}"
                    f"{node.target}{maybe_type_annotation}{maybe_default_arg}{maybe_comment}"
                )
                raw_name = node.target.replace("*", "")
                if raw_name != repr(node):

@@ -699,7 +767,13 @@ class CodeGen:
            elif node.op == "output":
                if node.type is not None:
                    maybe_return_annotation[0] = f" -> {type_repr(node.type)}"
                body.append(self.generate_output(node.args[0]))
                body.append(
                    self._call_method_with_signature_check(
                        self.generate_output,
                        node.args[0],
                        descs=desc if expanded_def else None,
                    )
                )
                return
            raise NotImplementedError(f"node: {node.op} {node.target}")

@@ -733,7 +807,12 @@ class CodeGen:
        for name, value in self.additional_globals():
            add_global(name, value)

        prologue = self.gen_fn_def(free_vars, maybe_return_annotation[0])
        prologue = self._call_method_with_signature_check(
            self.gen_fn_def,
            free_vars,
            maybe_return_annotation[0],
            expanded_def=expanded_def,
        )

        # remove counter and generate lineno to node index mapping
        lineno_map: dict[int, Optional[int]] = {}
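The _call_method_with_signature_check indirection above exists so that CodeGen subclasses that still override gen_fn_def or generate_output with the old signatures keep working; a minimal sketch of the pattern (hypothetical classes, independent of torch):

import inspect

class Base:
    def render(self, value, *, expanded_def=False):
        return f"{value} (expanded={expanded_def})"

class LegacySubclass(Base):
    # Overrides with the *old* signature: no expanded_def keyword.
    def render(self, value):
        return f"legacy:{value}"

def call_with_signature_check(method, *args, **kwargs):
    # Drop keyword arguments the target method does not accept.
    params = inspect.signature(method).parameters
    return method(*args, **{k: v for k, v in kwargs.items() if k in params})

print(call_with_signature_check(Base().render, "x", expanded_def=True))            # x (expanded=True)
print(call_with_signature_check(LegacySubclass().render, "x", expanded_def=True))  # legacy:x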
@@ -782,7 +861,23 @@ class _PyTreeCodeGen(CodeGen):
            assert self.pytree_info.out_spec is not None
            return pytree.tree_unflatten(out, self.pytree_info.out_spec)

    def gen_fn_def(self, free_vars, maybe_return_annotation):
    def _format_annotations(self, free_vars: list[str], expanded_def: bool) -> str:
        """Helper to format annotations for variables in pytree codegen."""
        if not free_vars:
            return ""

        has_annotation = [x for x in free_vars if ":" in x]
        if not has_annotation:
            return ""

        if expanded_def:
            return "\n " + "\n ".join(has_annotation)
        else:
            return "\n " + "".join(x + "; " for x in has_annotation) + "\n"

    def gen_fn_def(
        self, free_vars, maybe_return_annotation, *, expanded_def: bool = False
    ):
        # Given a user function/model:
        # myargs = [myargs0, myargs1]
        # mykwargs = {'mykwargs0': ..., 'mykwargs1': ...}

@@ -799,13 +894,17 @@ class _PyTreeCodeGen(CodeGen):
        # If the user function/model does not have keywords, the dict is suppressed from tree_flatten_spec
        # e.g. tree_flatten_spec([mypos, myargs0, myargs1]), self._in_spec)
        if self.pytree_info is None:
            return super().gen_fn_def(free_vars, maybe_return_annotation)
            return super().gen_fn_def(
                free_vars, maybe_return_annotation, expanded_def=expanded_def
            )

        fn_args = self.pytree_info.orig_args
        has_orig_self = (fn_args[0] == "self") if len(fn_args) > 0 else False
        if has_orig_self:
            free_vars.insert(0, "self")
        fn_definition = super().gen_fn_def(fn_args[:], maybe_return_annotation)
        fn_definition = super().gen_fn_def(
            fn_args[:], maybe_return_annotation, expanded_def=expanded_def
        )

        if len(free_vars) > 0: # pytree has placeholders in it
            # when kwargs is present, in_spec is tuple(args, kwargs)

@@ -837,19 +936,27 @@ class _PyTreeCodeGen(CodeGen):
            # we need to split it to two lines:
            # one for annotation: `var1: annotation1; var2: annotation2;` (note the semicolon)
            # one for code: `var1, var2, = function_call()`
            without_annotation = [x.split(":")[0] for x in free_vars]
            has_annotation = [x + "; " for x in free_vars if ":" in x]
            if len(has_annotation) > 0:
                fn_definition += "\n " + "".join(has_annotation) + "\n"
            without_annotation = [x.split(":")[0].split("#")[0] for x in free_vars]
            fn_definition += self._format_annotations(free_vars, expanded_def)
            fn_definition += f"""
{", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})"""
        return fn_definition

    def generate_output(self, output_args):
    def generate_output(self, output_args, *, descs: Optional[Any] = None):
        if self.pytree_info and self.pytree_info.out_spec:
            return f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)"
            if descs is not None and isinstance(output_args, (list, tuple)):
                return (
                    self._format_multiline_container(
                        output_args, descs, "return pytree.tree_unflatten("
                    )
                    + ", self._out_spec)"
                )
            else:
                return (
                    f"return pytree.tree_unflatten({repr(output_args)}, self._out_spec)"
                )
        else:
            return super().generate_output(output_args)
            return super().generate_output(output_args, descs=descs)


class _FindNodesLookupTable:
@@ -1534,6 +1641,7 @@ class Graph:
        include_stride: bool = False,
        include_device: bool = False,
        colored: bool = False,
        expanded_def: bool = False,
    ) -> PythonCode:
        """
        Turn this ``Graph`` into valid Python code.

@@ -1600,6 +1708,7 @@ class Graph:
            include_stride=include_stride,
            include_device=include_device,
            colored=colored,
            expanded_def=expanded_def,
        )

    def _python_code(

@@ -1611,6 +1720,7 @@ class Graph:
        include_stride: bool = False,
        include_device: bool = False,
        colored: bool = False,
        expanded_def: bool = False,
    ) -> PythonCode:
        return self._codegen._gen_python_code(
            self.nodes,

@@ -1620,6 +1730,7 @@ class Graph:
            include_stride=include_stride,
            include_device=include_device,
            colored=colored,
            expanded_def=expanded_def,
        )

    def __str__(self) -> str:

@@ -309,6 +309,7 @@ def _print_readable(
    include_stride=False,
    include_device=False,
    colored=False,
    expanded_def=False,
):
    graph = module.graph
    assert graph is not None and isinstance(graph, torch.fx.Graph), (

@@ -321,6 +322,7 @@ def _print_readable(
        include_stride=include_stride,
        include_device=include_device,
        colored=colored,
        expanded_def=expanded_def,
    )
    module_code = verbose_python_code.src
    module_code = module_code.lstrip("\n")

@@ -935,6 +937,7 @@ class {module_name}(torch.nn.Module):
        # If `fast_sympy_print` is True then we use a sympy printer which is faster
        # but may result in less-readable output.
        fast_sympy_print: bool = False,
        expanded_def: bool = False,
    ):
        """
        Return the Python code generated for current GraphModule and its children GraphModules

@@ -956,6 +959,7 @@ class {module_name}(torch.nn.Module):
            include_stride,
            include_device,
            colored,
            expanded_def,
        )
        return r
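Putting it together, the flag is opt-in at the public entry points; a typical call might look like this (a sketch, assuming gm is any torch.fx.GraphModule):

import torch

class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y, x * y

gm = torch.fx.symbolic_trace(M())
# Each placeholder and each returned value is rendered on its own line;
# descriptor comments only appear when nodes carry a "desc" meta entry.
print(gm.print_readable(print_output=False, expanded_def=True))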