[REFACTOR] Implement AOTDispatchCompiler wrapper (#142205)

This implements a new wrapper class, AOTDispatchCompiler, which is just a wrapper around a callable that returns an OutputCode. We can then use it in AOTDispatch to decide whether or not to use the cache: if fw_compiler, bw_compiler, and inference_compiler are all AOTDispatchCompilers, then we enable caching.
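Roughly, the new type and the cache gate look like this (a simplified sketch: the real definitions are in the torch/_functorch/aot_autograd.py hunk below, and the cache_allowed helper is only illustrative — in this PR, aot_module_simplified actually checks just fw_compiler):

```python
from typing import Any, Callable, Protocol, Sequence, Type

import torch


class AOTDispatchCompiler(Protocol):
    # Any fw/bw/inference compiler handed to AOTAutograd: (gm, example_inputs) -> Any.
    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: Sequence[Any]
    ) -> Any: ...


class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
    # Wraps a compiler that is known to return an OutputCode, so AOTAutogradCache
    # can serialize and reuse its results.
    def __init__(self, output_code_ty: Type[Any], compiler_fn: Callable[..., Any]):
        self.output_code_ty = output_code_ty
        self.compiler_fn = compiler_fn

    def __call__(
        self, gm: torch.fx.GraphModule, example_inputs: Sequence[Any]
    ) -> Any:
        return self.compiler_fn(gm, example_inputs)


def cache_allowed(*compilers: AOTDispatchCompiler) -> bool:
    # Only cache when every compiler is known to produce an OutputCode.
    return all(isinstance(c, SerializableAOTDispatchCompiler) for c in compilers)
```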

This type is pretty close to _CompiledFxGraphCallable, except that it's not allowed to take any kwargs. I'm not sure how to consolidate the two ideas just yet: unfortunately, there's no way to properly annotate the types to make them related. But a lot of the time, the input to this function will be a partially applied _CompiledFxGraphCallable.
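For illustration, the partial-application pattern looks roughly like this (compile_fx_like and its kwargs are hypothetical stand-ins for a _CompiledFxGraphCallable, not real APIs):

```python
import functools
from typing import Any, Sequence

import torch


def compile_fx_like(
    gm: torch.fx.GraphModule,
    example_inputs: Sequence[Any],
    *,
    is_inference: bool = False,
    cudagraphs: Any = None,
) -> Any:
    # Stand-in for a _CompiledFxGraphCallable: takes (gm, example_inputs)
    # plus config kwargs and would return an OutputCode.
    raise NotImplementedError


# Binding the kwargs up front leaves exactly the (gm, example_inputs) signature
# that AOTDispatchCompiler expects; the result can then be wrapped in
# SerializableAOTDispatchCompiler, as compile_fx does in the diff below.
fw_compiler_fn = functools.partial(compile_fx_like, is_inference=False)
```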

This allows the PR stacked above this one to enable AOTAutogradCache everywhere without increasing instruction count or enabling the cache on unit tests that use aot_eager or other non-Inductor compilers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142205
Approved by: https://github.com/oulgen, https://github.com/bdhirsh
James Wu 2024-12-06 10:41:15 -08:00 committed by PyTorch MergeBot
parent 5663ad99e7
commit 6e203ae6de
5 changed files with 109 additions and 27 deletions


@ -41,7 +41,11 @@ from functorch.compile import (
from functorch.experimental import control_flow
from torch._decomp import decomposition_table
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
from torch._functorch.aot_autograd import (
aot_export_joint_simple,
aot_export_module,
SerializableAOTDispatchCompiler,
)
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._inductor.output_code import MockFXGraphCacheOutput
@ -6831,12 +6835,13 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
def make_compiler(self, fw_graph_cell):
mock_inductor_cache = self.inductor_cache
def compiler(gm, inputs):
def compiler(gm, example_inputs):
nonlocal mock_inductor_cache, fw_graph_cell
result = mock_inductor_cache.load(gm, inputs)
result = mock_inductor_cache.load(gm, example_inputs)
fw_graph_cell[0] = gm
return result
compiler = SerializableAOTDispatchCompiler(MockFXGraphCacheOutput, compiler)
return compiler
def run_autograd(


@ -9,7 +9,10 @@ import torch
from torch._dynamo import disable
from torch._dynamo.exc import TensorifyScalarRestartAnalysis
from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import aot_module_simplified
from torch._functorch.aot_autograd import (
aot_module_simplified,
SerializableAOTDispatchCompiler,
)
from torch.utils._python_dispatch import _disable_current_modes
@ -45,14 +48,21 @@ class AotAutograd:
counters["aot_autograd"]["not_ok"] += 1
return gm
# OK attempt to compile
def wrap_bw_compiler(bw_compiler_fn):
def _wrapped_bw_compiler(*args, **kwargs):
# stop TorchDynamo from trying to compile our generated backwards pass
return disable(disable(bw_compiler_fn)(*args, **kwargs))
def _wrapped_bw_compiler(*args, **kwargs):
# stop TorchDynamo from trying to compile our generated backwards pass
return disable(disable(bw_compiler)(*args, **kwargs))
return _wrapped_bw_compiler
bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
self.kwargs["bw_compiler"] = _wrapped_bw_compiler
if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
else:
bw_compiler = wrap_bw_compiler(bw_compiler)
self.kwargs["bw_compiler"] = bw_compiler
self.kwargs["inference_compiler"] = (
self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
)


@ -665,6 +665,7 @@ class AOTAutogradCache:
"""
Load a result from the cache, and reconstruct a runtime wrapper around the object
"""
gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
with sanitize_gm_for_cache(gm):
compiled_fn = None


@ -3,7 +3,19 @@
import itertools
from contextlib import contextmanager, nullcontext
from functools import partial, wraps
from typing import Any, Callable, Dict, List, NewType, Optional, Tuple
from typing import (
Any,
Callable,
Dict,
List,
NewType,
Optional,
Protocol,
Sequence,
Tuple,
Type,
TypeVar,
)
from unittest.mock import patch
import torch
@ -21,7 +33,8 @@ from torch._dynamo.utils import (
preserve_rng_state,
)
from torch._guards import detect_fake_mode
from torch._inductor.utils import BoxedBool
from torch._inductor.output_code import OutputCode
from torch._inductor.utils import BoxedBool, InputType
from torch._subclasses import FakeTensor, FakeTensorMode
from torch.fx.experimental.proxy_tensor import make_fx
from torch.fx.experimental.symbolic_shapes import ShapeEnv
@ -435,6 +448,47 @@ aot_autograd_decompositions = {}
FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any])
TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
class AOTDispatchCompiler(Protocol):
"""
Represents a fw or bw_compiler passed to AOTAutograd.
"""
def __call__(
self,
gm: torch.fx.GraphModule,
example_inputs: Sequence[InputType],
) -> Any:
...
# TODO: bikeshed on this name
class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
"""
Represents an AOTDispatchCompiler that returns an OutputCode, and is
therefore cacheable. A SerializableAOTDispatchCompiler always returns an OutputCode.
A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of
the kwargs in _CompileFxKwargs.
"""
def __init__(
self,
output_code_ty: Type[TOutputCode],
compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
):
self.output_code_ty = output_code_ty
self.compiler_fn = compiler_fn
def __call__(
self,
gm: torch.fx.GraphModule,
example_inputs: Sequence[InputType],
) -> OutputCode:
return self.compiler_fn(gm, example_inputs)
def process_inputs(
flat_args: List[Any],
aot_config: AOTConfig,
@ -953,12 +1007,12 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
def aot_module_simplified(
mod: nn.Module,
args,
fw_compiler: Callable,
bw_compiler: Optional[Callable] = None,
fw_compiler: AOTDispatchCompiler,
bw_compiler: Optional[AOTDispatchCompiler] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
keep_inference_input_mutations=False,
inference_compiler: Optional[Callable] = None,
inference_compiler: Optional[AOTDispatchCompiler] = None,
cudagraphs: Optional[BoxedBool] = None,
) -> nn.Module:
"""
@ -1086,8 +1140,8 @@ def aot_module_simplified(
# Autograd cache stuff
remote = should_use_remote_autograd_cache()
local = should_use_local_autograd_cache()
if local or remote:
# We only care if the forward will return an OutputCode.
if (local or remote) and isinstance(fw_compiler, SerializableAOTDispatchCompiler):
compiled_fn = AOTAutogradCache.load(
dispatch_and_compile,
mod,


@ -52,7 +52,11 @@ from torch._dynamo.utils import (
set_feature_use,
)
from torch._functorch import config as functorch_config
from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
from torch._functorch.aot_autograd import (
aot_export_module,
make_boxed_func,
SerializableAOTDispatchCompiler,
)
from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
from torch._inductor.debug import save_args_for_compile_fx_inner
@ -1663,20 +1667,20 @@ def compile_fx(
)
def fw_compiler_base(
model: GraphModule,
example_inputs: List[InputType],
gm: GraphModule,
example_inputs: Sequence[InputType],
is_inference: bool,
) -> OutputCode:
with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
if is_inference:
# partition_fn won't be called
_recursive_joint_graph_passes(model)
_recursive_joint_graph_passes(gm)
fixed = torch._inductor.utils.num_fw_fixed_arguments(
num_example_inputs, len(example_inputs)
)
model_outputs_node = output_node(model)
model_outputs_node = output_node(gm)
if config.keep_output_stride:
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
num_model_outputs = len(model_outputs)
@ -1733,7 +1737,7 @@ def compile_fx(
model_outputs_node.meta["user_visible_output_idxs"] = []
return inner_compile(
model,
gm,
example_inputs,
static_input_idxs=get_static_input_idxs(fixed),
cudagraphs=cudagraphs,
@ -1742,7 +1746,10 @@ def compile_fx(
boxed_forward_device_index=forward_device,
)
fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
fw_compiler: Callable[
[GraphModule, Sequence[InputType]], OutputCode
] = functools.partial(fw_compiler_base, is_inference=False)
fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)
if config.freezing and not torch.is_grad_enabled():
inference_compiler: Callable[..., Any] = functools.partial(
@ -1756,6 +1763,9 @@ def compile_fx(
)
else:
inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
inference_compiler = SerializableAOTDispatchCompiler(
OutputCode, inference_compiler
)
def partition_fn(
gm: GraphModule,
@ -1771,14 +1781,14 @@ def compile_fx(
@compile_time_strobelight_meta(phase_name="backward")
def bw_compiler(
model: GraphModule, example_inputs: List[InputType]
gm: GraphModule, example_inputs: Sequence[InputType]
) -> OutputCode:
from torch._dynamo.convert_frame import compile_lock
with dynamo_utils.dynamo_timed(
"compile_fx.<locals>.bw_compiler"
), compile_lock:
model_outputs_node = output_node(model)
model_outputs_node = output_node(gm)
if config.bw_outputs_user_visible:
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
model_outputs_node.meta["user_visible_output_idxs"] = [
@ -1789,12 +1799,12 @@ def compile_fx(
else:
model_outputs_node.meta["user_visible_output_idxs"] = []
fixed = count_tangents(model)
fixed = count_tangents(gm)
with config.patch(
get_cpp_wrapper_config()
) if config.cpp_wrapper else contextlib.nullcontext():
return inner_compile(
model,
gm,
example_inputs,
static_input_idxs=list(range(fixed)),
cudagraphs=cudagraphs,
@ -1803,6 +1813,8 @@ def compile_fx(
boxed_forward_device_index=forward_device,
)
bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler)
fake_mode = detect_fake_mode(
example_inputs_
) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)