[standalone_compile] Dynamic shape handling (#151788)

standalone_compile needs to get dynamic shape information from
somewhere. We add a new `dynamic_shapes` argument with three options:

1. from the passed-in graph (`dynamic_shapes="from_graph"`). This is the default.
2. from the example inputs, thereby specializing on them (`dynamic_shapes="from_example_inputs"`).
3. from the current tracing context (`dynamic_shapes="from_tracing_context"`).

Options 1 and 3 are not exactly the same: option 1 recovers shape information from the fake tensors recorded on the graph module itself, while option 3 reuses the fake mode of the ambient tracing context. Option 2 can also be used for more advanced cases, such as specializing on one input but not another (see the sketch below).

Most of this PR is tests.

Test Plan:
- a lot of new tests; the backend-usage pattern they exercise is sketched below.
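
The backend-style tests exercise calling standalone_compile from inside a torch.compile backend, where "from_tracing_context" can reuse the ambient tracing context. A minimal sketch along the lines of those tests:

    import torch

    def f(x):
        return x.shape[0] * x

    def backend(gm, example_inputs):
        # Inside a torch.compile backend the TracingContext is live, so its
        # fake mode / shape env can be reused directly.
        return torch._inductor.standalone_compile(
            gm, example_inputs, dynamic_shapes="from_tracing_context"
        )

    x = torch.randn(3)
    out = torch.compile(f, backend=backend)(x)  # out == x * 3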

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151788
Approved by: https://github.com/oulgen
rzou 2025-04-21 20:13:19 -07:00 committed by PyTorch MergeBot
parent 7e4b89ac6c
commit 596296fb0b
6 changed files with 225 additions and 15 deletions


@@ -1448,7 +1448,7 @@ class TestStandaloneCompile(TestCase):
torch._dynamo.reset()
clear_inductor_caches()
def capture(self, fn):
def capture(self, fn, dynamic=None):
def inner(*args):
gm = None
actual_args = None
@@ -1463,7 +1463,9 @@ class TestStandaloneCompile(TestCase):
kwargs = kwargs_
return gm
_ = torch.compile(fn, fullgraph=True, backend=backend)(*args)
_ = torch.compile(fn, fullgraph=True, backend=backend, dynamic=dynamic)(
*args
)
return gm, actual_args, kwargs
return inner
@@ -1506,7 +1508,13 @@ class TestStandaloneCompile(TestCase):
with fresh_inductor_cache():
loaded = torch._inductor.CompiledArtifact.load(path=path, format=format)
compiled_out = loaded(*args)
if dynamic:
concrete_args = [
4 if isinstance(a, torch.SymInt) else a for a in args
]
else:
concrete_args = args
compiled_out = loaded(*concrete_args)
self.assertEqual(eager_out, compiled_out)
self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@@ -1540,7 +1548,6 @@ class TestStandaloneCompile(TestCase):
def test_save_in_new_path(self) -> None:
mod = torch.nn.Linear(1, 3)
x = torch.randn(4, 1)
torch._dynamo.mark_dynamic(x, 0)
def f(x):
with torch.no_grad():
@@ -1675,6 +1682,152 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
)
)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_dynamic_shapes_from_graph(self):
def f(x):
return x.shape[0] * x
x = torch.ones(3)
torch._dynamo.mark_dynamic(x, 0)
with fresh_inductor_cache():
# captured graph is lambda s0, x: x * s0
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes="from_graph"
)
x = torch.ones(4)
(result,) = compiled_artifact(4, x)
self.assertEqual(result, x * 4)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_dynamic_shapes_from_example_inputs(self):
def f(x):
return x.shape[0] * x
x = torch.ones(3)
torch._dynamo.mark_dynamic(x, 0)
with fresh_inductor_cache():
# captured graph is lambda s0, x: x * s0
gm, args, kwargs = self.capture(f)(x)
assert not kwargs
# specialized on example inputs
compiled_artifact = torch._inductor.standalone_compile(
gm, (5, torch.ones(4)), dynamic_shapes="from_example_inputs"
)
x = torch.ones(4)
(result,) = compiled_artifact(3, x)
# int 5 was baked in!
self.assertEqual(result, x * 5)
# size 4 was baked in
with self.assertRaisesRegex(AssertionError, "expected size 5==4"):
x = torch.randn(5)
(result,) = compiled_artifact(4, x)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
def test_static_shapes(self, dynamic_shapes):
def f(x):
return x.shape[0] * x
static_x = torch.randn(3)
with fresh_inductor_cache():
# static_gm is lambda x: x * 3
static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
assert not kwargs
compiled_artifact = torch._inductor.standalone_compile(
static_gm, [static_x], dynamic_shapes=dynamic_shapes
)
x = torch.randn(3)
(result,) = compiled_artifact(x)
self.assertEqual(result, x * 3)
with self.assertRaisesRegex(AssertionError, "expected size 4==3"):
x = torch.randn(4)
(result,) = compiled_artifact(x)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
def test_backend(self, dynamic_shapes):
def f(x):
return x.shape[0] * x
x = torch.randn(3)
torch._dynamo.mark_dynamic(x, 0)
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes
)
y = torch.randn(4)
(result,) = compiled_artifact(4, y)
self.assertEqual(result, y * 4)
return compiled_artifact
torch._dynamo.reset()
_ = torch.compile(f, backend=backend)(x)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
def test_backend_dynamic_shapes_from_example_inputs(self):
def f(x):
return x.shape[0] * x
x = torch.ones(4)
torch._dynamo.mark_dynamic(x, 0)
def backend(gm, args, **kwargs):
compiled_artifact = torch._inductor.standalone_compile(
gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
)
y = torch.ones(4)
(result,) = compiled_artifact(4, y)
# 5 was baked in
self.assertEqual(result, y * 5)
# shape of y was baked in
with self.assertRaisesRegex(AssertionError, "expected size 5==4"):
y = torch.ones(5)
(result,) = compiled_artifact(4, y)
return compiled_artifact
torch._dynamo.reset()
_ = torch.compile(f, backend=backend)(x)
@config.patch({"fx_graph_cache": True})
@config.patch({"fx_graph_remote_cache": False})
@functorch_config.patch({"enable_autograd_cache": True})
@parametrize(
"dynamic_shapes", ["from_tracing_context", "from_graph", "from_example_inputs"]
)
def test_backend_static_shapes(self, dynamic_shapes):
# on static_x, all of these options should produce a static graph,
# but it's a bit hard to tell, so these are just smoke tests.
static_x = torch.randn(3)
def f(x):
return x.shape[0] * x
def backend(gm, args, **kwargs):
return torch._inductor.standalone_compile(
gm, args, dynamic_shapes=dynamic_shapes
)
result = torch.compile(f, backend=backend)(static_x)
self.assertEqual(result, static_x * 3)
class TestFxGraphCacheHashing(TestCase):
def test_parameter_constants(self):


@@ -940,6 +940,11 @@ class AOTConfig:
pre_dispatch: bool = False
# Key to use for AOTAutogradCache
cache_info: Optional[AOTAutogradCacheInfo] = None
# If we should ignore the shape_env in the ambient tracing_context.
# The net effect is that if dynamic shapes are on, we end up
# specializing on example_inputs.
# Used only by standalone_compile.
ignore_shape_env: bool = False
def __post_init__(self):
if self.pre_dispatch:


@@ -488,11 +488,12 @@ def process_inputs(
aot_config: AOTConfig,
fake_mode: FakeTensorMode,
shape_env: Optional[ShapeEnv],
ignore_shape_env: bool = False,
) -> FakifiedFlatArgs:
with fake_mode:
def convert(idx, x):
if shape_env is not None:
if shape_env is not None and not ignore_shape_env:
from torch._dynamo.source import ConstantSource
if isinstance(x, int):
@@ -540,13 +541,14 @@ def process_inputs(
# Dynamo
return fake_mode.from_tensor(x, static_shapes=True)
return fake_mode.from_tensor(
result = fake_mode.from_tensor(
x,
static_shapes=False,
static_shapes=ignore_shape_env,
symbolic_context=symbolic_context,
source=source,
trace=trace,
)
return result
return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)])
@@ -1073,6 +1075,7 @@ def aot_module_simplified(
inference_compiler: Optional[AOTDispatchCompiler] = None,
cudagraphs: Optional[BoxedBool] = None,
boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
ignore_shape_env: bool = False,
) -> nn.Module:
"""
This is the simplified or low overhead version of aot_module. For frontends
@@ -1140,9 +1143,12 @@
is_export=False,
no_tangents=False,
cache_info=None,
ignore_shape_env=ignore_shape_env,
)
fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
fake_flat_args = process_inputs(full_args, aot_config, fake_mode, shape_env)
fake_flat_args = process_inputs(
full_args, aot_config, fake_mode, shape_env, ignore_shape_env
)
def dispatch_and_compile():
functional_call = create_functional_call(mod, params_spec, params_len)


@@ -4,7 +4,7 @@ from __future__ import annotations
import io
import logging
import os
from typing import Any, IO, Optional, TYPE_CHECKING, Union
from typing import Any, IO, Literal, Optional, TYPE_CHECKING, Union
import torch._inductor.config
import torch.fx
@@ -366,6 +366,10 @@ def cudagraph_mark_step_begin():
def standalone_compile(
gm: torch.fx.GraphModule,
example_inputs: list[InputType],
*,
dynamic_shapes: Literal[
"from_example_inputs", "from_tracing_context", "from_graph"
] = "from_graph",
options: Optional[dict[str, Any]] = None,
) -> CompiledArtifact:
"""
@@ -383,6 +387,12 @@
Args:
gm: Graph Module
example_inputs: Inputs for the graph module
dynamic_shapes: If "from_graph" (default), we will use the dynamic
shapes in the passed-in graph module.
If "from_tracing_context", we use the dynamic shape info in the
ambient tracing context.
If "from_example_inputs", we will specialize the graph on the
example_inputs.
options: Inductor compilation options
Returns:
@@ -391,4 +401,6 @@
from .standalone_compile import standalone_compile
options = options if options else {}
return standalone_compile(gm, example_inputs, **options)
return standalone_compile(
gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options
)


@@ -1868,6 +1868,7 @@ def compile_fx(
inner_compile: Callable[..., OutputCode] = compile_fx_inner,
config_patches: Optional[dict[str, Any]] = None,
decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
ignore_shape_env: bool = False,
) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
"""
Main entry point for compiling given FX graph. Despite the fact that this
@@ -2279,6 +2280,7 @@
keep_inference_input_mutations=True,
cudagraphs=cudagraphs,
boxed_forward_device_index=forward_device,
ignore_shape_env=ignore_shape_env,
)(model_, example_inputs_)
except ShortenTraceback as e:
# We will also shorten the traceback inside dynamo.


@@ -9,7 +9,7 @@ from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
import torch.fx
from torch._dynamo.utils import detect_fake_mode, dynamo_timed
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir
from torch._inductor.utils import BoxedBool, InputType
@@ -178,22 +178,54 @@ class CompiledArtifact:
def standalone_compile(
gm: GraphModule, example_inputs: Sequence[InputType], **kwargs: Any
gm: GraphModule,
example_inputs: Sequence[InputType],
*,
dynamic_shapes: Any,
options: Any,
) -> CompiledArtifact:
from torch.compiler._cache import CacheArtifactManager
from .compile_fx import compile_fx
fake_mode = detect_fake_mode(example_inputs)
if fake_mode is None:
ignore_shape_env = False
if dynamic_shapes == "from_example_inputs":
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
# tells compile_fx to ignore the shape_envs on the ambient context
# and the graph_module.
ignore_shape_env = True
elif dynamic_shapes == "from_tracing_context":
# Reuse fake_mode from the TracingContext.
# NB: The TracingContext only exists if we're currently in a torch.compile backend.
context = torch._guards.TracingContext.get()
fake_mode = context.fake_mode
elif dynamic_shapes == "from_graph":
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
# Strategy: find a FakeTensor in the graph output, grab its FakeTensorMode.
# The graph passed to standalone_compile must be an Inductor-approved graph,
# which means that there is at least one Tensor output and the output node
# contains a flat list of Tensors.
last_node = next(iter(reversed(gm.graph.nodes)))
assert last_node.op == "output"
assert len(last_node.args) == 1
for node in last_node.args[0]:
if "example_value" in node.meta:
maybe_tensor = node.meta["example_value"]
if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor):
fake_mode = maybe_tensor.fake_mode
else:
raise ValueError(
f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}."
)
context = torch._guards.TracingContext(fake_mode)
with torch._guards.tracing(context):
with CacheArtifactManager.with_fresh_cache():
# compile_fx can mutate gm
gm = copy.deepcopy(gm)
compiled_fn = compile_fx(gm, example_inputs, **kwargs)
compiled_fn = compile_fx(
gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
)
assert callable(compiled_fn)
artifacts = torch.compiler.save_cache_artifacts()