[hierarchical-compilation][invoke_subgraph] Use tracing context to cache artifacts of dispatch keys (#137965)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137965
Approved by: https://github.com/zou3519
ghstack dependencies: #137538, #138036
This commit is contained in:
parent: e045e8f0df
commit: 2e48788a35
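For orientation (this sketch is not part of the diff): Dynamo hashes a candidate subgraph together with its fake inputs and, if an identical region was already installed, reuses the existing subgraph attribute; the Autograd and ProxyTorchDispatchMode implementations then reuse per-identifier artifacts instead of re-tracing. The dicts and helper below are illustrative stand-ins for the TracingContext-backed caches added in the hunks that follow.

# Illustrative stand-in for the per-compilation caches this PR introduces.
# The real code hangs these off TracingContext.hop_dispatch_set_cache (see below).
dynamo_identifiers = {}    # graph/input hash -> installed subgraph attribute name
autograd_cache = {}        # identifier -> callable wrapping the fw/bw graphs
proxy_dispatch_cache = {}  # identifier -> traced fx.GraphModule

def install_subgraph(key, build):
    # Dynamo-level dedupe: identical (graph, fake inputs) -> same attribute name.
    if key not in dynamo_identifiers:
        dynamo_identifiers[key] = build()
    return dynamo_identifiers[key]

name_a = install_subgraph("hash(gn, f32[8], f32[8])", lambda: "repeated_subgraph0")
name_b = install_subgraph("hash(gn, f32[8], f32[8])", lambda: "repeated_subgraph1")
assert name_a == name_b == "repeated_subgraph0"  # second region reuses the first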
@@ -125,6 +125,13 @@ class TestInvokeSubgraph(TestCase):
 @skipIfTorchDynamo("Not a torch._dynamo test")
 class TestInvokeSubgraphCompile(TestCase):
+    def count_unique_get_attr_nodes(self, gm, args, expected):
+        subgraph_attr_names = set()
+        for node in gm.graph.nodes:
+            if node.op == "get_attr":
+                subgraph_attr_names.add(node.target)
+        self.assertEqual(len(subgraph_attr_names), expected)
+
     def test_simple(self):
         def gn(x, y):
             return (torch.mul(x, y),)
 
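The new count_unique_get_attr_nodes helper works because, in FX graphs, each installed subgraph (or parameter) is fetched through a get_attr node whose target is the attribute name, so counting distinct targets measures how many separate sub-modules a graph actually references. A small self-contained illustration with an ordinary parameter (the module M here is hypothetical, not from the test file):

import torch
from torch import fx

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.ones(3))

    def forward(self, x):
        # self.w is used twice, but it is a single attribute on the module
        return x * self.w + self.w

gm = fx.symbolic_trace(M())
targets = {n.target for n in gm.graph.nodes if n.op == "get_attr"}
assert targets == {"w"}  # one unique get_attr target, regardless of reuse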
@@ -148,7 +155,7 @@ class TestInvokeSubgraphCompile(TestCase):
         self.assertEqual(x.grad, x_clone.grad)
         self.assertEqual(y.grad, y_clone.grad)
 
-    def test_multiple(self):
+    def test_dedupe(self):
         def gn(x, y):
             return (torch.mul(x, y),)
 
@@ -173,13 +180,13 @@ class TestInvokeSubgraphCompile(TestCase):
         self.assertEqual(x.grad, x_clone.grad)
         self.assertEqual(y.grad, y_clone.grad)
 
-        # Check that the Dynamo graph has just one subgraph module
+        # Check that the Dynamo and AOT graphs have just one subgraph module
         self.assertEqual(len(backend.graphs), 1)
-        subgraph_attr_names = set()
-        for node in backend.graphs[0].graph.nodes:
-            if node.op == "get_attr":
-                subgraph_attr_names.add(node.target)
-        self.assertEqual(len(subgraph_attr_names), 1)
+        self.assertEqual(len(backend.fw_graphs), 1)
+        self.assertEqual(len(backend.bw_graphs), 1)
+        self.count_unique_get_attr_nodes(backend.graphs[0], [], 1)
+        self.count_unique_get_attr_nodes(backend.fw_graphs[0], [], 1)
+        self.count_unique_get_attr_nodes(backend.bw_graphs[0], [], 1)
 
         if not TEST_WITH_CROSSREF:
             self.assertExpectedInline(
@@ -206,6 +213,27 @@ class GraphModule(torch.nn.Module):
 """,
         )
 
+        self.assertExpectedInline(
+            normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
+            """\
+class GraphModule(torch.nn.Module):
+    def forward(self, primals_1: "f32[8]", primals_2: "f32[8]"):
+        repeated_subgraph0 = self.repeated_subgraph0
+        invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, '___forward_invoke_subgraph_0', (primals_1, primals_2)); repeated_subgraph0 = None
+        getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
+
+        repeated_subgraph0_1 = self.repeated_subgraph0
+        invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, '___forward_invoke_subgraph_0', (getitem, primals_2)); repeated_subgraph0_1 = None
+        getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
+        return (getitem_1, primals_1, primals_2, getitem)
+
+    class repeated_subgraph0(torch.nn.Module):
+        def forward(self, arg0_1: "f32[8]", arg1_1: "f32[8]"):
+            mul: "f32[8]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
+            return (mul,)
+""",
+        )
+
     def test_nonlocal_update(self):
         counter = 2
 
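The assertions above depend on a recording backend that keeps the Dynamo graph and the AOTAutograd forward/backward graphs around for inspection. Assuming the test uses something along the lines of AotEagerAndRecordGraphs from torch._dynamo.testing (an assumption; the backend setup is not shown in these hunks), a minimal sketch of that inspection pattern looks like:

import torch
from torch._dynamo.testing import AotEagerAndRecordGraphs, normalize_gm

backend = AotEagerAndRecordGraphs()

@torch.compile(backend=backend)
def fn(x, y):
    return x * y + x

x = torch.randn(8, requires_grad=True)
y = torch.randn(8, requires_grad=True)
fn(x, y).sum().backward()

# backend.graphs holds the Dynamo graphs; backend.fw_graphs / backend.bw_graphs
# hold the AOTAutograd forward and backward graphs that the new assertions inspect.
print(len(backend.graphs), len(backend.fw_graphs), len(backend.bw_graphs))
print(normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)))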
@@ -417,7 +417,6 @@ class OutputGraph:
         )
 
         self.guard_on_key_order: Set[str] = set()
-        self.seen_invoke_subgraphs: Dict[str, str] = {}
 
     def install_builtins_dict_in_fglobals(self):
         # f_globals["__builtins__"] can be a dict or a module. This is an
@@ -2629,13 +2629,22 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
 
         key = hash_graph_and_inputs(tx, body_gmod, fake_inputs)
 
-        if key in tx.output.seen_invoke_subgraphs:
-            return tx.output.seen_invoke_subgraphs[key]
+        invoke_subgraph_cache = (
+            tx.output.tracing_context.hop_dispatch_set_cache.get_cache(
+                torch._higher_order_ops.invoke_subgraph
+            )
+        )
+
+        if invoke_subgraph_cache:
+            if identifier := invoke_subgraph_cache.get_dynamo_identifier(key):
+                return identifier
 
         body_name = super().install_subgraph_in_output_graph(
             tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name
         )
-        tx.output.seen_invoke_subgraphs[key] = body_name
+        if invoke_subgraph_cache:
+            invoke_subgraph_cache.add_dynamo_identifier(key, body_name)
 
         return body_name
 
     def call_function(
@@ -570,6 +570,66 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
         self.dynamo_guards = GuardsSet(state.dynamo_guards)
 
 
+class HopSubgraphCache:
+    @abstractmethod
+    def add_dynamo_identifier(self, cache_key: str, identifier: str): ...
+
+    @abstractmethod
+    def get_dynamo_identifier(self, cache_key: str) -> Optional[str]: ...
+
+    @abstractmethod
+    def add_autograd_key_entry(self, identifier: str, key: Callable): ...
+
+    @abstractmethod
+    def get_autograd_key_entry(self, identifier: str): ...
+
+    @abstractmethod
+    def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ...
+
+    @abstractmethod
+    def get_proxy_dispatch_entry(self, identifier: str): ...
+
+
+class InvokeSubgraphCache(HopSubgraphCache):
+    def __init__(self) -> None:
+        self.autograd_cache: Dict[str, Callable] = {}
+        self.proxy_dispatch_cache: Dict[str, Callable] = {}
+        self.dynamo_identifiers: Dict[str, str] = {}
+
+    def add_dynamo_identifier(self, cache_key: str, identifier: str):
+        self.dynamo_identifiers[cache_key] = identifier
+
+    def get_dynamo_identifier(self, cache_key: str) -> Optional[str]:
+        return self.dynamo_identifiers.get(cache_key, None)
+
+    def add_autograd_key_entry(self, identifier: str, key: Callable):
+        self.autograd_cache[identifier] = key
+
+    def get_autograd_key_entry(self, identifier: str):
+        return self.autograd_cache.get(identifier, None)
+
+    def add_proxy_dispatch_entry(self, identifier: str, key: Callable):
+        self.proxy_dispatch_cache[identifier] = key
+
+    def get_proxy_dispatch_entry(self, identifier: str):
+        return self.proxy_dispatch_cache.get(identifier, None)
+
+
+class HopDispatchSetCache:
+    def __init__(self) -> None:
+        # Delayed import to avoid circular dependency
+        from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
+
+        self.hop_cache_map = {invoke_subgraph: InvokeSubgraphCache()}
+
+    def get_cache(
+        self, op: torch._ops.HigherOrderOperator
+    ) -> Optional[HopSubgraphCache]:
+        if op not in self.hop_cache_map:
+            return None
+        return self.hop_cache_map[op]  # type: ignore[index]
+
+
 _TLS = threading.local()
 
 """
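Assuming these classes land in torch._guards exactly as in the hunk above, the intended usage is one HopDispatchSetCache per TracingContext, holding one InvokeSubgraphCache keyed by the invoke_subgraph HOP, with get_cache returning None for any HOP that has not opted in. A short sketch:

import torch
from torch._guards import HopDispatchSetCache, InvokeSubgraphCache
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph

cache_set = HopDispatchSetCache()

cache = cache_set.get_cache(invoke_subgraph)
assert isinstance(cache, InvokeSubgraphCache)
# HOPs that have not registered a cache simply get None back.
assert cache_set.get_cache(torch.ops.higher_order.cond) is None

# Dynamo layer: graph/input hash -> name of the installed subgraph attribute.
cache.add_dynamo_identifier("some-hash", "repeated_subgraph0")
assert cache.get_dynamo_identifier("some-hash") == "repeated_subgraph0"
assert cache.get_dynamo_identifier("other-hash") is None

# Autograd / proxy layers: identifier -> cached artifact.
cache.add_autograd_key_entry("___forward_invoke_subgraph_0", lambda *a: a)
assert cache.get_autograd_key_entry("___forward_invoke_subgraph_0") is not None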
@@ -686,6 +746,7 @@ class TracingContext:
         # meta on the first invocation
         # see note: [Returning Fake Tensors on First AOT Autograd Call]
         self.fakify_first_call = False
+        self.hop_dispatch_set_cache = HopDispatchSetCache()
 
     def clear(self):
         # Look at the note in output_graph.py in function `save_global_state`
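Because HopDispatchSetCache is instantiated in TracingContext.__init__, every compilation gets a fresh cache, and anything that can see the ambient tracing context can reach it. A small sketch of that lookup, assuming the torch._guards.tracing context manager and the FakeTensorMode constructor shown here behave as in current PyTorch:

import torch
from torch._guards import TracingContext, tracing
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
from torch._subclasses.fake_tensor import FakeTensorMode

ctx = TracingContext(FakeTensorMode())
with tracing(ctx):
    # Anywhere under this context (Dynamo, AOTAutograd, HOP dispatch impls)
    # the same per-compilation cache instance is visible.
    active = TracingContext.try_get()
    assert active is ctx
    cache = active.hop_dispatch_set_cache.get_cache(invoke_subgraph)
    assert cache is ctx.hop_dispatch_set_cache.get_cache(invoke_subgraph)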
@@ -62,6 +62,13 @@ class InvokeSubgraphHOP(HigherOrderOperator):
 invoke_subgraph = InvokeSubgraphHOP()
 
 
+def get_invoke_subgraph_cache():
+    cache = None
+    if tracing_ctx := torch._guards.TracingContext.try_get():
+        cache = tracing_ctx.hop_dispatch_set_cache.get_cache(invoke_subgraph)
+    return cache
+
+
 def trace_joint_graph(fn, fw_inputs, fw_outputs):
     """
     Naively trace out a joint graph. This simplifies the reconstruction of joint
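get_invoke_subgraph_cache is deliberately best-effort: TracingContext.try_get() returns None when no compilation is in flight, so outside of tracing the HOP implementations below simply skip caching rather than erroring. A tiny check of that fallback:

import torch
from torch._guards import TracingContext

# Outside of torch.compile / export there is no ambient TracingContext,
# so a helper like get_invoke_subgraph_cache() would return None and the
# dispatch-key implementations fall back to tracing from scratch.
assert TracingContext.try_get() is None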
@@ -191,11 +198,26 @@ def _(subgraph, identifier, operands):
         with torch._C._AutoDispatchBelowAutograd():
             return invoke_subgraph(subgraph, identifier, operands)
 
+    # Check if we have already traced the subgraph.
+    invoke_subgraph_cache = get_invoke_subgraph_cache()
+    if invoke_subgraph_cache:
+        if saved_autograd_fn := invoke_subgraph_cache.get_autograd_key_entry(
+            identifier
+        ):
+            return saved_autograd_fn(*operands)
+
     fw_graph, bw_graph, num_fw_outs = create_fw_bw_graph(subgraph, operands)
-    # TODO(anijain2305) - Implement caching of autograd function op.
-    return InvokeSubgraphAutogradOp.apply(
-        fw_graph, bw_graph, identifier, num_fw_outs, *operands
-    )
+
+    def autograd_fn_callable(*args):
+        return InvokeSubgraphAutogradOp.apply(
+            fw_graph, bw_graph, identifier, num_fw_outs, *args
+        )
+
+    # Save the autograd_fn_callable in the dispatch set cache.
+    if invoke_subgraph_cache:
+        invoke_subgraph_cache.add_autograd_key_entry(identifier, autograd_fn_callable)
+
+    return autograd_fn_callable(*operands)
 
 
 @invoke_subgraph.py_functionalize_impl
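The Autograd-key change above is a memoize-a-closure pattern: the first time an identifier is seen, the fw/bw graphs are built and captured in autograd_fn_callable; later calls with the same identifier reuse that closure and only feed in new operands. A generic, framework-free sketch of the same pattern (all names here are illustrative, not from the PR):

from typing import Callable, Dict

_traced: Dict[str, Callable] = {}
_trace_count = 0

def expensive_trace(identifier: str) -> Callable:
    # Stands in for create_fw_bw_graph + InvokeSubgraphAutogradOp.apply.
    global _trace_count
    _trace_count += 1
    return lambda *operands: sum(operands)

def run(identifier: str, *operands):
    if (fn := _traced.get(identifier)) is None:
        fn = expensive_trace(identifier)
        _traced[identifier] = fn
    return fn(*operands)

run("___forward_invoke_subgraph_0", 1, 2)
run("___forward_invoke_subgraph_0", 3, 4)  # same identifier -> no re-trace
assert _trace_count == 1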
@@ -218,18 +240,27 @@ def _(mode, subgraph, identifier, operands):
 
 @invoke_subgraph.py_impl(ProxyTorchDispatchMode)
 def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands):
-    # TODO(anijain2305) - Implement proxy tensor caching.
-    example_out = invoke_subgraph(subgraph, identifier, operands)
-    graph = reenter_make_fx(subgraph)(*operands)
-    assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
-    qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph")
-    proxy_mode.tracer.root.register_module(qualname, graph)
+    # Check if we have already traced the subgraph.
+    graph = None
+    invoke_subgraph_cache = get_invoke_subgraph_cache()
+    if invoke_subgraph_cache:
+        graph = invoke_subgraph_cache.get_proxy_dispatch_entry(identifier)
+
+    if graph is None:
+        graph = reenter_make_fx(subgraph)(*operands)
+        assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
+        qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph")
+        proxy_mode.tracer.root.register_module(qualname, graph)
+        if invoke_subgraph_cache:
+            invoke_subgraph_cache.add_proxy_dispatch_entry(identifier, graph)
 
     node_args = (graph, identifier, operands)
-    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
+    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)  # type: ignore[union-attr]
     out_proxy = proxy_mode.tracer.create_proxy(
         "call_function", invoke_subgraph, proxy_args, {}
     )
 
+    example_out = invoke_subgraph(graph, identifier, operands)
     return track_tensor_tree(
         example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
     )
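On the proxy path the cached artifact is the traced fx.GraphModule itself: the first occurrence traces it with reenter_make_fx and installs it on the tracer root under a fresh qualname, and later occurrences point their invoke_subgraph call_function node at that same submodule. A rough standalone analogue using the public make_fx (the root module and names here are illustrative):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def gn(x, y):
    return x * y

# First occurrence: trace the body once and register it on a root module.
graph = make_fx(gn)(torch.ones(8), torch.ones(8))
root = torch.nn.Module()
root.register_module("repeated_subgraph0", graph)

# Second occurrence: reuse the already-registered submodule instead of
# tracing and registering a fresh copy.
assert root.get_submodule("repeated_subgraph0") is graph
print(graph.code)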