[hierarchical-compilation][invoke_subgraph] Use tracing context to cache artifacts of dispatch keys (#137965)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137965
Approved by: https://github.com/zou3519
ghstack dependencies: #137538, #138036
Animesh Jain 2024-10-21 16:33:44 -07:00 committed by PyTorch MergeBot
parent e045e8f0df
commit 2e48788a35
5 changed files with 150 additions and 22 deletions
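
In outline, this commit moves invoke_subgraph caching off OutputGraph and onto the ambient TracingContext, so that Dynamo, the Autograd dispatch key, and ProxyTorchDispatchMode all read and write one per-compilation cache keyed by the subgraph identifier. A minimal sketch of that shape (simplified, hypothetical names; the real classes are the HopDispatchSetCache / InvokeSubgraphCache additions in the torch/_guards.py hunk below):

from typing import Callable, Dict


class InvokeSubgraphCacheSketch:
    # One instance lives on the tracing context for the whole compilation,
    # so every dispatch key sees the same maps.
    def __init__(self) -> None:
        self.dynamo_identifiers: Dict[str, str] = {}        # graph/input hash -> installed subgraph name
        self.autograd_cache: Dict[str, Callable] = {}       # identifier -> autograd callable
        self.proxy_dispatch_cache: Dict[str, object] = {}   # identifier -> traced GraphModule


class TracingContextSketch:
    # Stand-in for torch._guards.TracingContext; the commit adds a
    # hop_dispatch_set_cache field mapping each higher-order op to its cache.
    def __init__(self) -> None:
        self.hop_dispatch_set_cache: Dict[str, InvokeSubgraphCacheSketch] = {
            "invoke_subgraph": InvokeSubgraphCacheSketch(),
        }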

View File

@@ -125,6 +125,13 @@ class TestInvokeSubgraph(TestCase):
@skipIfTorchDynamo("Not a torch._dynamo test")
class TestInvokeSubgraphCompile(TestCase):
def count_unique_get_attr_nodes(self, gm, args, expected):
subgraph_attr_names = set()
for node in gm.graph.nodes:
if node.op == "get_attr":
subgraph_attr_names.add(node.target)
self.assertEqual(len(subgraph_attr_names), expected)
def test_simple(self):
def gn(x, y):
return (torch.mul(x, y),)
@@ -148,7 +155,7 @@ class TestInvokeSubgraphCompile(TestCase):
self.assertEqual(x.grad, x_clone.grad)
self.assertEqual(y.grad, y_clone.grad)
def test_multiple(self):
def test_dedupe(self):
def gn(x, y):
return (torch.mul(x, y),)
@@ -173,13 +180,13 @@ class TestInvokeSubgraphCompile(TestCase):
self.assertEqual(x.grad, x_clone.grad)
self.assertEqual(y.grad, y_clone.grad)
# Check that the Dynamo graph has just one subgraph module
# Check that the Dynamo and AOT graphs have just one subgraph module
self.assertEqual(len(backend.graphs), 1)
subgraph_attr_names = set()
for node in backend.graphs[0].graph.nodes:
if node.op == "get_attr":
subgraph_attr_names.add(node.target)
self.assertEqual(len(subgraph_attr_names), 1)
self.assertEqual(len(backend.fw_graphs), 1)
self.assertEqual(len(backend.bw_graphs), 1)
self.count_unique_get_attr_nodes(backend.graphs[0], [], 1)
self.count_unique_get_attr_nodes(backend.fw_graphs[0], [], 1)
self.count_unique_get_attr_nodes(backend.bw_graphs[0], [], 1)
if not TEST_WITH_CROSSREF:
self.assertExpectedInline(
@@ -206,6 +213,27 @@ class GraphModule(torch.nn.Module):
""",
)
self.assertExpectedInline(
normalize_gm(backend.fw_graphs[0].print_readable(print_output=False)),
"""\
class GraphModule(torch.nn.Module):
def forward(self, primals_1: "f32[8]", primals_2: "f32[8]"):
repeated_subgraph0 = self.repeated_subgraph0
invoke_subgraph = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0, '___forward_invoke_subgraph_0', (primals_1, primals_2)); repeated_subgraph0 = None
getitem: "f32[8]" = invoke_subgraph[0]; invoke_subgraph = None
repeated_subgraph0_1 = self.repeated_subgraph0
invoke_subgraph_1 = torch.ops.higher_order.invoke_subgraph(repeated_subgraph0_1, '___forward_invoke_subgraph_0', (getitem, primals_2)); repeated_subgraph0_1 = None
getitem_1: "f32[8]" = invoke_subgraph_1[0]; invoke_subgraph_1 = None
return (getitem_1, primals_1, primals_2, getitem)
class repeated_subgraph0(torch.nn.Module):
def forward(self, arg0_1: "f32[8]", arg1_1: "f32[8]"):
mul: "f32[8]" = torch.ops.aten.mul.Tensor(arg0_1, arg1_1); arg0_1 = arg1_1 = None
return (mul,)
""",
)
def test_nonlocal_update(self):
counter = 2
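
The dedupe test above relies on the small count_unique_get_attr_nodes helper: with caching in place, the Dynamo, forward, and backward graphs each register the repeated subgraph only once, so a single get_attr target appears even though invoke_subgraph is called twice. A self-contained restatement of that check (the fw_gm usage is hypothetical):

import torch


def count_unique_get_attr_nodes(gm: torch.fx.GraphModule) -> int:
    # Distinct get_attr targets == number of distinct submodules the graph references.
    return len({node.target for node in gm.graph.nodes if node.op == "get_attr"})


# Hypothetical usage against a recorded forward graph:
#   assert count_unique_get_attr_nodes(fw_gm) == 1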

View File

@@ -417,7 +417,6 @@ class OutputGraph:
)
self.guard_on_key_order: Set[str] = set()
self.seen_invoke_subgraphs: Dict[str, str] = {}
def install_builtins_dict_in_fglobals(self):
# f_globals["__builtins__"] can be a dict or a module. This is an

View File

@@ -2629,13 +2629,22 @@ class InvokeSubgraphHigherOrderVariable(WrapHigherOrderVariable):
key = hash_graph_and_inputs(tx, body_gmod, fake_inputs)
if key in tx.output.seen_invoke_subgraphs:
return tx.output.seen_invoke_subgraphs[key]
invoke_subgraph_cache = (
tx.output.tracing_context.hop_dispatch_set_cache.get_cache(
torch._higher_order_ops.invoke_subgraph
)
)
if invoke_subgraph_cache:
if identifier := invoke_subgraph_cache.get_dynamo_identifier(key):
return identifier
body_name = super().install_subgraph_in_output_graph(
tx, fn_vt, fn_args_vt, kwargs, body_gmod, attr_name
)
tx.output.seen_invoke_subgraphs[key] = body_name
if invoke_subgraph_cache:
invoke_subgraph_cache.add_dynamo_identifier(key, body_name)
return body_name
def call_function(
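
The Dynamo-side change above is a check-then-record pattern against the TracingContext-owned cache instead of the old seen_invoke_subgraphs dict on OutputGraph. A schematic sketch of that flow (names simplified and hypothetical, not the real Dynamo internals):

from typing import Callable, Dict, Optional


def install_subgraph_cached(
    cache: Optional[Dict[str, str]],   # stands in for InvokeSubgraphCache.dynamo_identifiers
    key: str,                          # would come from hash_graph_and_inputs(...)
    install: Callable[[], str],        # performs the real install_subgraph_in_output_graph call
) -> str:
    if cache is not None and key in cache:
        return cache[key]              # reuse the already-installed subgraph name
    name = install()
    if cache is not None:
        cache[key] = name              # remember it for later invoke_subgraph calls
    return name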

View File

@@ -570,6 +570,66 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
self.dynamo_guards = GuardsSet(state.dynamo_guards)
class HopSubgraphCache:
@abstractmethod
def add_dynamo_identifier(self, cache_key: str, identifier: str): ...
@abstractmethod
def get_dynamo_identifier(self, cache_key: str) -> Optional[str]: ...
@abstractmethod
def add_autograd_key_entry(self, identifier: str, key: Callable): ...
@abstractmethod
def get_autograd_key_entry(self, identifier: str): ...
@abstractmethod
def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ...
@abstractmethod
def get_proxy_dispatch_entry(self, identifier: str): ...
class InvokeSubgraphCache(HopSubgraphCache):
def __init__(self) -> None:
self.autograd_cache: Dict[str, Callable] = {}
self.proxy_dispatch_cache: Dict[str, Callable] = {}
self.dynamo_identifiers: Dict[str, str] = {}
def add_dynamo_identifier(self, cache_key: str, identifier: str):
self.dynamo_identifiers[cache_key] = identifier
def get_dynamo_identifier(self, cache_key: str) -> Optional[str]:
return self.dynamo_identifiers.get(cache_key, None)
def add_autograd_key_entry(self, identifier: str, key: Callable):
self.autograd_cache[identifier] = key
def get_autograd_key_entry(self, identifier: str):
return self.autograd_cache.get(identifier, None)
def add_proxy_dispatch_entry(self, identifier: str, key: Callable):
self.proxy_dispatch_cache[identifier] = key
def get_proxy_dispatch_entry(self, identifier: str):
return self.proxy_dispatch_cache.get(identifier, None)
class HopDispatchSetCache:
def __init__(self) -> None:
# Delayed import to avoid circular dependency
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph
self.hop_cache_map = {invoke_subgraph: InvokeSubgraphCache()}
def get_cache(
self, op: torch._ops.HigherOrderOperator
) -> Optional[HopSubgraphCache]:
if op not in self.hop_cache_map:
return None
return self.hop_cache_map[op] # type: ignore[index]
_TLS = threading.local()
"""
@@ -686,6 +746,7 @@ class TracingContext:
# meta on the first invocation
# see note: [Returning Fake Tensors on First AOT Autograd Call]
self.fakify_first_call = False
self.hop_dispatch_set_cache = HopDispatchSetCache()
def clear(self):
# Look at the note in output_graph.py in function `save_global_state`
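
Taken together, the additions above give TracingContext a HopDispatchSetCache whose per-op InvokeSubgraphCache can be fetched from any dispatch key. A usage sketch, assuming a torch build that already contains this commit (the string values are placeholders):

import torch
from torch._guards import TracingContext
from torch._higher_order_ops.invoke_subgraph import invoke_subgraph

ctx = TracingContext.try_get()          # None outside of a compilation
if ctx is not None:
    cache = ctx.hop_dispatch_set_cache.get_cache(invoke_subgraph)
    if cache is not None:
        # Record and retrieve per-identifier artifacts shared across dispatch keys.
        cache.add_dynamo_identifier("graph_hash_key", "subgraph_0")      # placeholder values
        assert cache.get_dynamo_identifier("graph_hash_key") == "subgraph_0"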

View File

@@ -62,6 +62,13 @@ class InvokeSubgraphHOP(HigherOrderOperator):
invoke_subgraph = InvokeSubgraphHOP()
def get_invoke_subgraph_cache():
cache = None
if tracing_ctx := torch._guards.TracingContext.try_get():
cache = tracing_ctx.hop_dispatch_set_cache.get_cache(invoke_subgraph)
return cache
def trace_joint_graph(fn, fw_inputs, fw_outputs):
"""
Naively trace out a joint graph. This simplifies the reconstruction of joint
@@ -191,11 +198,26 @@ def _(subgraph, identifier, operands):
with torch._C._AutoDispatchBelowAutograd():
return invoke_subgraph(subgraph, identifier, operands)
# Check if we have already traced the subgraph.
invoke_subgraph_cache = get_invoke_subgraph_cache()
if invoke_subgraph_cache:
if saved_autograd_fn := invoke_subgraph_cache.get_autograd_key_entry(
identifier
):
return saved_autograd_fn(*operands)
fw_graph, bw_graph, num_fw_outs = create_fw_bw_graph(subgraph, operands)
# TODO(anijain2305) - Implement caching of autograd function op.
return InvokeSubgraphAutogradOp.apply(
fw_graph, bw_graph, identifier, num_fw_outs, *operands
)
def autograd_fn_callable(*args):
return InvokeSubgraphAutogradOp.apply(
fw_graph, bw_graph, identifier, num_fw_outs, *args
)
# Save the autograd_fn_callable in the dispatch set cache.
if invoke_subgraph_cache:
invoke_subgraph_cache.add_autograd_key_entry(identifier, autograd_fn_callable)
return autograd_fn_callable(*operands)
@invoke_subgraph.py_functionalize_impl
@@ -218,18 +240,27 @@ def _(mode, subgraph, identifier, operands):
@invoke_subgraph.py_impl(ProxyTorchDispatchMode)
def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, operands):
# TODO(anijain2305) - Implement proxy tensor caching.
example_out = invoke_subgraph(subgraph, identifier, operands)
graph = reenter_make_fx(subgraph)(*operands)
assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph")
proxy_mode.tracer.root.register_module(qualname, graph)
# Check if we have already traced the subgraph.
graph = None
invoke_subgraph_cache = get_invoke_subgraph_cache()
if invoke_subgraph_cache:
graph = invoke_subgraph_cache.get_proxy_dispatch_entry(identifier)
if graph is None:
graph = reenter_make_fx(subgraph)(*operands)
assert isinstance(proxy_mode.tracer, torch.fx.Tracer)
qualname = proxy_mode.tracer.get_fresh_qualname("repeated_subgraph")
proxy_mode.tracer.root.register_module(qualname, graph)
if invoke_subgraph_cache:
invoke_subgraph_cache.add_proxy_dispatch_entry(identifier, graph)
node_args = (graph, identifier, operands)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args) # type: ignore[union-attr]
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", invoke_subgraph, proxy_args, {}
)
example_out = invoke_subgraph(graph, identifier, operands)
return track_tensor_tree(
example_out, out_proxy, constant=None, tracer=proxy_mode.tracer
)
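
The same check-cache-else-trace shape is used for both the autograd key (caching the InvokeSubgraphAutogradOp callable per identifier) and the ProxyTorchDispatchMode handler above (caching and reusing the traced GraphModule), which is why the expected forward graph earlier shows two invoke_subgraph calls sharing a single repeated_subgraph0 submodule. A reduced sketch of the proxy-mode branch, with a plain dict standing in for InvokeSubgraphCache:

from typing import Callable, Dict


def trace_once_per_identifier(
    proxy_cache: Dict[str, object],
    identifier: str,
    trace_subgraph: Callable[[], object],          # e.g. reenter_make_fx(subgraph)(*operands)
    register_submodule: Callable[[object], None],  # e.g. tracer.root.register_module(qualname, graph)
) -> object:
    graph = proxy_cache.get(identifier)
    if graph is None:
        graph = trace_subgraph()               # only the first call for this identifier traces
        register_submodule(graph)              # and registers the submodule on the root
        proxy_cache[identifier] = graph        # add_proxy_dispatch_entry(identifier, graph)
    return graph                               # later calls reuse the cached GraphModule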