Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
[standalone_compile] Dynamic shape handling (#151788)
standalone_compile needs to get dynamic shape information from somewhere. We add a new `dynamic_shapes` argument with three options:

1. From the passed-in graph (`dynamic_shapes="from_graph"`). This is the default.
2. From the example inputs, thereby specializing on them (`dynamic_shapes="from_example_inputs"`).
3. From the current tracing context (`dynamic_shapes="from_tracing_context"`).

Options 1 and 3 are not exactly the same. Option 2 can also be used for more advanced things, such as specializing on one input but not another. Most of this PR is tests.

Test Plan:
- a lot of new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151788
Approved by: https://github.com/oulgen
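For orientation, here is a minimal sketch of the intended flow with the default option, modeled on the tests added in this commit. The `capture_backend` helper and the `captured` dict are illustrative only, not part of the PR:

import torch
import torch._inductor

def f(x):
    return x.shape[0] * x

captured = {}

def capture_backend(gm, example_inputs):
    # Stash the FX graph and its example inputs; compile nothing here.
    captured["gm"], captured["inputs"] = gm, example_inputs
    return gm

x = torch.ones(3)
torch._dynamo.mark_dynamic(x, 0)  # dim 0 is dynamic, so the captured graph takes (s0, x)
torch.compile(f, fullgraph=True, backend=capture_backend)(x)

# "from_graph" (the default) reuses the dynamic shapes recorded in the graph,
# so the compiled artifact still works for a different size of the dynamic dim.
artifact = torch._inductor.standalone_compile(
    captured["gm"], captured["inputs"], dynamic_shapes="from_graph"
)
y = torch.ones(4)
(result,) = artifact(4, y)
assert torch.equal(result, y * 4)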
This commit is contained in:
parent 7e4b89ac6c
commit 596296fb0b
@@ -1448,7 +1448,7 @@ class TestStandaloneCompile(TestCase):
         torch._dynamo.reset()
         clear_inductor_caches()
 
-    def capture(self, fn):
+    def capture(self, fn, dynamic=None):
         def inner(*args):
             gm = None
             actual_args = None
@@ -1463,7 +1463,9 @@ class TestStandaloneCompile(TestCase):
                 kwargs = kwargs_
                 return gm
 
-            _ = torch.compile(fn, fullgraph=True, backend=backend)(*args)
+            _ = torch.compile(fn, fullgraph=True, backend=backend, dynamic=dynamic)(
+                *args
+            )
             return gm, actual_args, kwargs
 
         return inner
@@ -1506,7 +1508,13 @@ class TestStandaloneCompile(TestCase):
 
         with fresh_inductor_cache():
             loaded = torch._inductor.CompiledArtifact.load(path=path, format=format)
-            compiled_out = loaded(*args)
+            if dynamic:
+                concrete_args = [
+                    4 if isinstance(a, torch.SymInt) else a for a in args
+                ]
+            else:
+                concrete_args = args
+            compiled_out = loaded(*concrete_args)
             self.assertEqual(eager_out, compiled_out)
 
         self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)
@@ -1540,7 +1548,6 @@ class TestStandaloneCompile(TestCase):
     def test_save_in_new_path(self) -> None:
         mod = torch.nn.Linear(1, 3)
         x = torch.randn(4, 1)
-        torch._dynamo.mark_dynamic(x, 0)
 
         def f(x):
             with torch.no_grad():
@@ -1675,6 +1682,152 @@ if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
             )
         )
 
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_dynamic_shapes_from_graph(self):
+        def f(x):
+            return x.shape[0] * x
+
+        x = torch.ones(3)
+        torch._dynamo.mark_dynamic(x, 0)
+        with fresh_inductor_cache():
+            # captured graph is lambda s0, x: x * s0
+            gm, args, kwargs = self.capture(f)(x)
+            assert not kwargs
+
+            compiled_artifact = torch._inductor.standalone_compile(
+                gm, args, dynamic_shapes="from_graph"
+            )
+            x = torch.ones(4)
+            (result,) = compiled_artifact(4, x)
+            self.assertEqual(result, x * 4)
+
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_dynamic_shapes_from_example_inputs(self):
+        def f(x):
+            return x.shape[0] * x
+
+        x = torch.ones(3)
+        torch._dynamo.mark_dynamic(x, 0)
+        with fresh_inductor_cache():
+            # captured graph is lambda s0, x: x * s0
+            gm, args, kwargs = self.capture(f)(x)
+            assert not kwargs
+
+            # specialized on example inputs
+            compiled_artifact = torch._inductor.standalone_compile(
+                gm, (5, torch.ones(4)), dynamic_shapes="from_example_inputs"
+            )
+            x = torch.ones(4)
+            (result,) = compiled_artifact(3, x)
+            # int 5 was baked in!
+            self.assertEqual(result, x * 5)
+
+            # size 4 was baked in
+            with self.assertRaisesRegex(AssertionError, "expected size 5==4"):
+                x = torch.randn(5)
+                (result,) = compiled_artifact(4, x)
+
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @parametrize("dynamic_shapes", ["from_graph", "from_example_inputs"])
+    def test_static_shapes(self, dynamic_shapes):
+        def f(x):
+            return x.shape[0] * x
+
+        static_x = torch.randn(3)
+        with fresh_inductor_cache():
+            # static_gm is lambda x: x * 3
+            static_gm, args, kwargs = self.capture(f, dynamic=False)(static_x)
+            assert not kwargs
+            compiled_artifact = torch._inductor.standalone_compile(
+                static_gm, [static_x], dynamic_shapes=dynamic_shapes
+            )
+            x = torch.randn(3)
+            (result,) = compiled_artifact(x)
+            self.assertEqual(result, x * 3)
+            with self.assertRaisesRegex(AssertionError, "expected size 4==3"):
+                x = torch.randn(4)
+                (result,) = compiled_artifact(x)
+
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @parametrize("dynamic_shapes", ["from_tracing_context", "from_graph"])
+    def test_backend(self, dynamic_shapes):
+        def f(x):
+            return x.shape[0] * x
+
+        x = torch.randn(3)
+        torch._dynamo.mark_dynamic(x, 0)
+
+        def backend(gm, args, **kwargs):
+            compiled_artifact = torch._inductor.standalone_compile(
+                gm, args, dynamic_shapes=dynamic_shapes
+            )
+            y = torch.randn(4)
+            (result,) = compiled_artifact(4, y)
+            self.assertEqual(result, y * 4)
+            return compiled_artifact
+
+        torch._dynamo.reset()
+        _ = torch.compile(f, backend=backend)(x)
+
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    def test_backend_dynamic_shapes_from_example_inputs(self):
+        def f(x):
+            return x.shape[0] * x
+
+        x = torch.ones(4)
+        torch._dynamo.mark_dynamic(x, 0)
+
+        def backend(gm, args, **kwargs):
+            compiled_artifact = torch._inductor.standalone_compile(
+                gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
+            )
+            y = torch.ones(4)
+            (result,) = compiled_artifact(4, y)
+            # 5 was baked in
+            self.assertEqual(result, y * 5)
+
+            # shape of y was baked in
+            with self.assertRaisesRegex(AssertionError, "expected size 5==4"):
+                y = torch.ones(5)
+                (result,) = compiled_artifact(4, y)
+
+            return compiled_artifact
+
+        torch._dynamo.reset()
+        _ = torch.compile(f, backend=backend)(x)
+
+    @config.patch({"fx_graph_cache": True})
+    @config.patch({"fx_graph_remote_cache": False})
+    @functorch_config.patch({"enable_autograd_cache": True})
+    @parametrize(
+        "dynamic_shapes", ["from_tracing_context", "from_graph", "from_example_inputs"]
+    )
+    def test_backend_static_shapes(self, dynamic_shapes):
+        # on static_x, all of these options should produce a static graph,
+        # but it's a bit hard to tell, so these are just smoke tests.
+        static_x = torch.randn(3)
+
+        def f(x):
+            return x.shape[0] * x
+
+        def backend(gm, args, **kwargs):
+            return torch._inductor.standalone_compile(
+                gm, args, dynamic_shapes=dynamic_shapes
+            )
+
+        result = torch.compile(f, backend=backend)(static_x)
+        self.assertEqual(result, static_x * 3)
 
 
 class TestFxGraphCacheHashing(TestCase):
     def test_parameter_constants(self):
@@ -940,6 +940,11 @@ class AOTConfig:
     pre_dispatch: bool = False
     # Key to use for AOTAutogradCache
    cache_info: Optional[AOTAutogradCacheInfo] = None
+    # If we should ignore the shape_env in the ambient tracing_context.
+    # The net effect is that if dynamic shapes are on, we end up
+    # specializing on example_inputs.
+    # Used only by standalone_compile.
+    ignore_shape_env: bool = False
 
     def __post_init__(self):
         if self.pre_dispatch:
@@ -488,11 +488,12 @@ def process_inputs(
     aot_config: AOTConfig,
     fake_mode: FakeTensorMode,
     shape_env: Optional[ShapeEnv],
+    ignore_shape_env: bool = False,
 ) -> FakifiedFlatArgs:
     with fake_mode:
 
         def convert(idx, x):
-            if shape_env is not None:
+            if shape_env is not None and not ignore_shape_env:
                 from torch._dynamo.source import ConstantSource
 
                 if isinstance(x, int):
@@ -540,13 +541,14 @@ def process_inputs(
                 # Dynamo
                 return fake_mode.from_tensor(x, static_shapes=True)
 
-            return fake_mode.from_tensor(
+            result = fake_mode.from_tensor(
                 x,
-                static_shapes=False,
+                static_shapes=ignore_shape_env,
                 symbolic_context=symbolic_context,
                 source=source,
                 trace=trace,
             )
+            return result
 
     return FakifiedFlatArgs([convert(idx, x) for idx, x in enumerate(flat_args)])
 
@@ -1073,6 +1075,7 @@ def aot_module_simplified(
     inference_compiler: Optional[AOTDispatchCompiler] = None,
     cudagraphs: Optional[BoxedBool] = None,
     boxed_forward_device_index: Optional[BoxedDeviceIndex] = None,
+    ignore_shape_env: bool = False,
 ) -> nn.Module:
     """
     This is the simplified or low overhead version of aot_module. For frontends
@@ -1140,9 +1143,12 @@ def aot_module_simplified(
         is_export=False,
         no_tangents=False,
         cache_info=None,
+        ignore_shape_env=ignore_shape_env,
     )
     fake_mode, shape_env = construct_fake_mode(full_args, aot_config)
-    fake_flat_args = process_inputs(full_args, aot_config, fake_mode, shape_env)
+    fake_flat_args = process_inputs(
+        full_args, aot_config, fake_mode, shape_env, ignore_shape_env
+    )
 
     def dispatch_and_compile():
         functional_call = create_functional_call(mod, params_spec, params_len)
@@ -4,7 +4,7 @@ from __future__ import annotations
 import io
 import logging
 import os
-from typing import Any, IO, Optional, TYPE_CHECKING, Union
+from typing import Any, IO, Literal, Optional, TYPE_CHECKING, Union
 
 import torch._inductor.config
 import torch.fx
@@ -366,6 +366,10 @@ def cudagraph_mark_step_begin():
 def standalone_compile(
     gm: torch.fx.GraphModule,
     example_inputs: list[InputType],
+    *,
+    dynamic_shapes: Literal[
+        "from_example_inputs", "from_tracing_context", "from_graph"
+    ] = "from_graph",
     options: Optional[dict[str, Any]] = None,
 ) -> CompiledArtifact:
     """
@@ -383,6 +387,12 @@ def standalone_compile(
     Args:
         gm: Graph Module
         example_inputs: Inputs for the graph module
+        dynamic_shapes: If "from_graph" (default), we will use the dynamic
+            shapes in the passed-in graph module.
+            If "from_tracing_context", we use the dynamic shape info in the
+            ambient tracing context.
+            If "from_example_inputs", we will specialize the graph on the
+            example_inputs.
         options: Inductor compilation options
 
     Returns:
@@ -391,4 +401,6 @@ def standalone_compile(
     from .standalone_compile import standalone_compile
 
     options = options if options else {}
-    return standalone_compile(gm, example_inputs, **options)
+    return standalone_compile(
+        gm, example_inputs, dynamic_shapes=dynamic_shapes, options=options
+    )
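The `dynamic_shapes` options documented above differ mainly in where the shape information comes from. As a concrete illustration of "from_example_inputs", here is a small sketch adapted from the new test_backend_dynamic_shapes_from_example_inputs test; names are kept close to the test and the assertions mirror its checks:

import torch
import torch._inductor

def f(x):
    return x.shape[0] * x

def backend(gm, example_inputs, **kwargs):
    # Ignore the dynamic shapes recorded in gm and in the tracing context;
    # specialize on these example inputs instead (the int 5 and the size-4 tensor).
    artifact = torch._inductor.standalone_compile(
        gm, [5, torch.ones(4)], dynamic_shapes="from_example_inputs"
    )
    y = torch.ones(4)
    (result,) = artifact(4, y)
    # The baked-in 5 wins over the 4 passed at call time.
    assert torch.equal(result, y * 5)
    return artifact

x = torch.ones(4)
torch._dynamo.mark_dynamic(x, 0)
torch.compile(f, backend=backend)(x)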
@@ -1868,6 +1868,7 @@ def compile_fx(
     inner_compile: Callable[..., OutputCode] = compile_fx_inner,
     config_patches: Optional[dict[str, Any]] = None,
     decompositions: Optional[dict[OpOverload, Callable[..., Any]]] = None,
+    ignore_shape_env: bool = False,
 ) -> Union[Callable[[list[object]], Sequence[torch.Tensor]], str, list[str]]:
     """
     Main entry point for compiling given FX graph. Despite the fact that this
@@ -2279,6 +2280,7 @@ def compile_fx(
                 keep_inference_input_mutations=True,
                 cudagraphs=cudagraphs,
                 boxed_forward_device_index=forward_device,
+                ignore_shape_env=ignore_shape_env,
             )(model_, example_inputs_)
         except ShortenTraceback as e:
             # We will also shorten the traceback inside dynamo.
@@ -9,7 +9,7 @@ from contextlib import AbstractContextManager, nullcontext
 from typing import Any, Callable, Literal, Optional, TYPE_CHECKING
 
 import torch.fx
-from torch._dynamo.utils import detect_fake_mode, dynamo_timed
+from torch._dynamo.utils import dynamo_timed
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex
 from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir
 from torch._inductor.utils import BoxedBool, InputType
@@ -178,22 +178,54 @@ class CompiledArtifact:
 
 
 def standalone_compile(
-    gm: GraphModule, example_inputs: Sequence[InputType], **kwargs: Any
+    gm: GraphModule,
+    example_inputs: Sequence[InputType],
+    *,
+    dynamic_shapes: Any,
+    options: Any,
 ) -> CompiledArtifact:
     from torch.compiler._cache import CacheArtifactManager
 
     from .compile_fx import compile_fx
 
-    fake_mode = detect_fake_mode(example_inputs)
-    if fake_mode is None:
+    ignore_shape_env = False
+    if dynamic_shapes == "from_example_inputs":
         fake_mode = FakeTensorMode(shape_env=ShapeEnv())
+        # tells compile_fx to ignore the shape_envs on the ambient context
+        # and the graph_module.
+        ignore_shape_env = True
+    elif dynamic_shapes == "from_tracing_context":
+        # Reuse fake_mode from the TracingContext.
+        # NB: The TracingContext only exists if we're currently in a torch.compile backend.
+        context = torch._guards.TracingContext.get()
+        fake_mode = context.fake_mode
+    elif dynamic_shapes == "from_graph":
+        fake_mode = FakeTensorMode(shape_env=ShapeEnv())
+        # Strategy: find a FakeTensor in the graph output, grab its FakeTensorMode.
+        # The graph passed to standalone_compile must be an Inductor-approved graph,
+        # which means that there is at least one Tensor output and the output node
+        # contains a flat list of Tensors.
+        last_node = next(iter(reversed(gm.graph.nodes)))
+        assert last_node.op == "output"
+        assert len(last_node.args) == 1
+        for node in last_node.args[0]:
+            if "example_value" in node.meta:
+                maybe_tensor = node.meta["example_value"]
+                if isinstance(maybe_tensor, torch._subclasses.fake_tensor.FakeTensor):
+                    fake_mode = maybe_tensor.fake_mode
+    else:
+        raise ValueError(
+            f"standalone_compile got unsupported `dynamic_shapes` value: dynamic_shapes={dynamic_shapes}."
+        )
+
     context = torch._guards.TracingContext(fake_mode)
     with torch._guards.tracing(context):
         with CacheArtifactManager.with_fresh_cache():
             # compile_fx can mutate gm
             gm = copy.deepcopy(gm)
-            compiled_fn = compile_fx(gm, example_inputs, **kwargs)
+            compiled_fn = compile_fx(
+                gm, example_inputs, ignore_shape_env=ignore_shape_env, **options
+            )
             assert callable(compiled_fn)
 
             artifacts = torch.compiler.save_cache_artifacts()
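As a reading aid for the hunk above: the "from_graph" strategy amounts to recovering a FakeTensorMode (and hence the ShapeEnv with the graph's symbolic shapes) from the graph's output metadata. A self-contained restatement of that logic follows; the fake_mode_from_graph helper name is ours, not part of the PR:

import torch
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

def fake_mode_from_graph(gm: torch.fx.GraphModule) -> FakeTensorMode:
    # Fall back to a fresh FakeTensorMode if no FakeTensor metadata is found.
    fake_mode = FakeTensorMode(shape_env=ShapeEnv())
    output_node = next(iter(reversed(gm.graph.nodes)))
    assert output_node.op == "output" and len(output_node.args) == 1
    for node in output_node.args[0]:
        maybe_tensor = node.meta.get("example_value")
        if isinstance(maybe_tensor, FakeTensor):
            # The FakeTensor carries the mode whose ShapeEnv records the
            # graph's dynamic shapes, which standalone_compile then reuses.
            fake_mode = maybe_tensor.fake_mode
    return fake_mode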