[invoke_subgraph] Preserve node meta (#150782)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150782
Approved by: https://github.com/bdhirsh
ghstack dependencies: #150666
This commit is contained in:
parent 4447352e64
commit 173f126068
@@ -112,7 +112,9 @@ class GraphModule(torch.nn.Module):
    class ___forward_subgraph_0_post_graph(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None

            add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None

            sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
            add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
@@ -205,7 +207,9 @@ class GraphModule(torch.nn.Module):
    class ___forward_subgraph_0_post_graph(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]"):
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(primals_0, 7); primals_0 = None

            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(mul, 1); mul = None

            add_1: "f32[10, 10]" = torch.ops.aten.add.Tensor(add, 2); add = None
            return (add_1,)
""",
@@ -349,7 +353,9 @@ class GraphModule(torch.nn.Module):
    class ___forward_subgraph_0_post_graph(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None

            add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 2); primals_1 = None

            sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(add_1); add_1 = None
            add_2: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
@@ -358,7 +364,9 @@ class GraphModule(torch.nn.Module):
    class ___forward_subgraph_1_post_graph(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[10, 20]"):
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 2)

            add_1: "f32[10, 20]" = torch.ops.aten.add.Tensor(primals_1, 3)

            cos: "f32[10, 20]" = torch.ops.aten.cos.default(add_1); add_1 = None
            sum_1: "f32[]" = torch.ops.aten.sum.default(cos); cos = None
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(add, sum_1); add = None
@@ -416,6 +424,7 @@ class GraphModule(torch.nn.Module):
    class ___forward_subgraph_0_post_graph(torch.nn.Module):
        def forward(self, primals_0: "f32[10, 10]", primals_1: "f32[]"):
            add: "f32[10, 10]" = torch.ops.aten.add.Tensor(primals_0, 1); primals_0 = None

            sum_1: "f32[]" = torch.ops.aten.sum.default(add); add = None
            add_1: "f32[]" = torch.ops.aten.add.Tensor(sum_1, primals_1); sum_1 = primals_1 = None
            return (add_1,)
@@ -564,7 +573,9 @@ class <lambda>(torch.nn.Module):
    class repeated_subgraph0(torch.nn.Module):
        def forward(self, arg0_1: "f32[10, 10]", arg1_1: "f32[10, 20]"):
            mul: "f32[10, 10]" = torch.ops.aten.mul.Tensor(arg0_1, 2); arg0_1 = None

            mul_1: "f32[10, 20]" = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None

            sum_1: "f32[]" = torch.ops.aten.sum.default(mul); mul = None
            sum_2: "f32[]" = torch.ops.aten.sum.default(mul_1); mul_1 = None
            add: "f32[]" = torch.ops.aten.add.Tensor(sum_1, sum_2); sum_1 = sum_2 = None
@@ -921,6 +921,7 @@ class GraphModule(torch.nn.Module):
        def forward(self, tangents_0: "f32[8, 8]", tangents_1: "f32[8, 8]"):
            mul_2: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 3)
            mul_3: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_1, 2); tangents_1 = None

            add: "f32[8, 8]" = torch.ops.aten.add.Tensor(mul_2, mul_3); mul_2 = mul_3 = None
            return (add,)
""",
@@ -1030,7 +1031,9 @@ class GraphModule(torch.nn.Module):
    class ___forward_invoke_subgraph_0_post_graph(torch.nn.Module):
        def forward(self, primals_0: "f32[8, 8]", primals_1: "f32[8, 8]"):
            mm: "f32[8, 8]" = torch.ops.aten.mm.default(primals_0, primals_1)

            sin: "f32[8, 8]" = torch.ops.aten.sin.default(mm)

            t: "f32[8, 8]" = torch.ops.aten.t.default(primals_0); primals_0 = None
            t_1: "f32[8, 8]" = torch.ops.aten.t.default(primals_1); primals_1 = None
            return (sin, mm, t, t_1)
@@ -1055,6 +1058,7 @@ class GraphModule(torch.nn.Module):
        def forward(self, mm: "f32[8, 8]", t: "f32[8, 8]", t_1: "f32[8, 8]", tangents_0: "f32[8, 8]"):
            cos: "f32[8, 8]" = torch.ops.aten.cos.default(mm); mm = None
            mul: "f32[8, 8]" = torch.ops.aten.mul.Tensor(tangents_0, cos); tangents_0 = cos = None

            mm_1: "f32[8, 8]" = torch.ops.aten.mm.default(t, mul); t = None
            mm_2: "f32[8, 8]" = torch.ops.aten.mm.default(mul, t_1); mul = t_1 = None
            return (mm_2, mm_1)
@@ -298,7 +298,7 @@ def get_output_metadata(subgraph, operands):
 
 
 def trace_joint_graph_as_bwd(
-    fn, num_primals, joint_operands, include_key_set, exclude_key_set
+    subgraph, num_primals, joint_operands, include_key_set, exclude_key_set
 ):
     """
     Naively trace out a joint graph. This simplifies the reconstruction of joint
@@ -308,6 +308,17 @@ def trace_joint_graph_as_bwd(
     dummy_aot_config = get_dummy_aot_autograd_config()
 
+    if isinstance(subgraph, torch.fx.GraphModule):
+
+        def graph_with_interpreter(*args):
+            # Running graph with interpreter is needed for propagating the stack_trace
+            with torch.fx.traceback.preserve_node_meta():
+                return torch.fx.Interpreter(subgraph).run(*args)
+
+        fn = graph_with_interpreter
+    else:
+        fn = subgraph
+
     # This joint_fn is inserted as the backward graph as is. This simplifies the
     # min-cut partitioner work later on.
     # Input signature - (*primals, *tangents)
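Note: the pattern added in this hunk (re-running the FX subgraph through torch.fx.Interpreter inside torch.fx.traceback.preserve_node_meta()) is what lets a downstream tracer copy node meta such as stack_trace onto the nodes it creates. Below is a minimal standalone sketch of that mechanism, not code from this PR; the function f, the synthetic stack_trace strings, and the use of symbolic_trace/make_fx are illustrative assumptions.

import torch
import torch.fx as fx
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    return torch.sin(x) + 1


# Trace once and attach synthetic stack_trace meta (Dynamo would normally record this).
gm = fx.symbolic_trace(f)
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta["stack_trace"] = f"synthetic trace for {node.name}"


def graph_with_interpreter(*args):
    # Interpreter publishes each node's meta as the "current" meta while it executes,
    # and preserve_node_meta lets the tracer below copy that meta onto new nodes.
    with fx_traceback.preserve_node_meta():
        return fx.Interpreter(gm).run(*args)


# Retrace through make_fx; the new aten nodes should carry the old stack_trace meta.
new_gm = make_fx(graph_with_interpreter)(torch.randn(3))
for node in new_gm.graph.nodes:
    print(node.op, node.name, node.meta.get("stack_trace"))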
@@ -821,4 +821,10 @@ class FunctionalizeCtxWrapper:
         return f"FunctionalizeCtxWrapper on subgraph {self.subgraph})"
 
     def __call__(self, *args, **kwargs):
+        if isinstance(self.subgraph, torch.fx.GraphModule):
+            # Running graph with interpreter is needed for propagating the stack_trace
+            with fx_traceback.preserve_node_meta():
+                return self.ctx.functionalize(torch.fx.Interpreter(self.subgraph).run)(
+                    *args, **kwargs
+                )
        return self.ctx.functionalize(self.subgraph)(*args, **kwargs)
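The same trick is applied on the functionalization path above: FunctionalizeCtxWrapper.__call__ now functionalizes Interpreter(self.subgraph).run under preserve_node_meta instead of calling the GraphModule directly. A rough sketch of that wrapper shape, using the public torch.func.functionalize in place of the internal ctx.functionalize; the graph gm, the in-place function f, and the synthetic meta strings below are illustrative assumptions, not this file's actual call sites.

import torch
import torch.fx as fx
import torch.fx.traceback as fx_traceback
from torch.fx.experimental.proxy_tensor import make_fx


def f(x):
    y = x.clone()
    y.add_(1)  # in-place op that functionalization rewrites out-of-place
    return y.sin()


# Build a graph whose nodes carry (synthetic) stack_trace meta.
gm = make_fx(f)(torch.randn(4))
for node in gm.graph.nodes:
    if node.op == "call_function":
        node.meta["stack_trace"] = f"origin: {node.name}"


def functionalized_run(*args):
    # Mirrors the wrapper above: functionalize Interpreter.run under
    # preserve_node_meta so the functionalized retrace keeps the node meta.
    with fx_traceback.preserve_node_meta():
        return torch.func.functionalize(fx.Interpreter(gm).run)(*args)


functional_gm = make_fx(functionalized_run)(torch.randn(4))
for node in functional_gm.graph.nodes:
    print(node.op, node.name, node.meta.get("stack_trace"))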