[REFACTOR] Implement AOTDispatchCompiler wrapper (#142205)

This implements a new wrapper class, AOTDispatchCompiler, which is a simple wrapper around a callable that returns an OutputCode. We can then use it in AOTDispatch to decide whether or not to use the cache: if fw_compiler, bw_compiler, and inference_compiler are all AOTDispatchCompilers, we enable caching.

This type is pretty close to _CompiledFxGraphCallable, except that it is not allowed to take any kwargs. I'm not sure how to consolidate the two ideas just yet: unfortunately, there's no way to annotate the types so that the two are properly related. But much of the time, the input to this wrapper will be a partially applied _CompiledFxGraphCallable.
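To make the relationship concrete, here is a minimal sketch (compile_fx_inner_like is a hypothetical stand-in; only the functools.partial step mirrors what the diff below actually does) of how binding the kwargs up front yields the no-kwargs (gm, example_inputs) signature this type expects:

    import functools
    from typing import Any, Sequence

    import torch

    # Hypothetical stand-in for a _CompiledFxGraphCallable-style entry point
    # that still accepts kwargs such as is_inference.
    def compile_fx_inner_like(
        gm: torch.fx.GraphModule,
        example_inputs: Sequence[Any],
        *,
        is_inference: bool = False,
    ) -> Any:
        ...  # compile gm and return an OutputCode-like object

    # After partial application, the remaining signature is exactly
    # (gm, example_inputs), i.e. it matches AOTDispatchCompiler.
    fw_compiler = functools.partial(compile_fx_inner_like, is_inference=False)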

This allows the PR stacked above this one to enable AOTAutogradCache everywhere without increasing instruction count or enabling the cache on unit tests that use aot_eager or other non-Inductor compilers.
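For reference, the usage pattern in the diff below boils down to this sketch (my_compiler is a hypothetical stand-in; SerializableAOTDispatchCompiler and OutputCode are the real types this PR adds/uses):

    from typing import Any, Sequence

    import torch
    from torch._functorch.aot_autograd import SerializableAOTDispatchCompiler
    from torch._inductor.output_code import OutputCode

    # Hypothetical compiler callable that produces an OutputCode.
    def my_compiler(
        gm: torch.fx.GraphModule, example_inputs: Sequence[Any]
    ) -> OutputCode:
        ...  # compile gm and return an OutputCode

    # Wrapping advertises that the compiler returns an OutputCode, which is
    # what lets AOTAutogradCache engage; a bare callable (e.g. aot_eager)
    # stays uncached.
    cacheable_fw = SerializableAOTDispatchCompiler(OutputCode, my_compiler)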

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142205
Approved by: https://github.com/oulgen, https://github.com/bdhirsh
James Wu 2024-12-06 10:41:15 -08:00 committed by PyTorch MergeBot
parent 5663ad99e7
commit 6e203ae6de
5 changed files with 109 additions and 27 deletions

@@ -41,7 +41,11 @@ from functorch.compile import (
 from functorch.experimental import control_flow
 from torch._decomp import decomposition_table
 from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
-from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
+from torch._functorch.aot_autograd import (
+    aot_export_joint_simple,
+    aot_export_module,
+    SerializableAOTDispatchCompiler,
+)
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch._inductor.codecache import compiled_fx_graph_hash
 from torch._inductor.output_code import MockFXGraphCacheOutput
@@ -6831,12 +6835,13 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
     def make_compiler(self, fw_graph_cell):
         mock_inductor_cache = self.inductor_cache

-        def compiler(gm, inputs):
+        def compiler(gm, example_inputs):
             nonlocal mock_inductor_cache, fw_graph_cell
-            result = mock_inductor_cache.load(gm, inputs)
+            result = mock_inductor_cache.load(gm, example_inputs)
             fw_graph_cell[0] = gm
             return result

+        compiler = SerializableAOTDispatchCompiler(MockFXGraphCacheOutput, compiler)
         return compiler

     def run_autograd(

@@ -9,7 +9,10 @@ import torch
 from torch._dynamo import disable
 from torch._dynamo.exc import TensorifyScalarRestartAnalysis
 from torch._dynamo.utils import counters, defake, flatten_graph_inputs
-from torch._functorch.aot_autograd import aot_module_simplified
+from torch._functorch.aot_autograd import (
+    aot_module_simplified,
+    SerializableAOTDispatchCompiler,
+)
 from torch.utils._python_dispatch import _disable_current_modes
@@ -45,14 +48,21 @@ class AotAutograd:
                 counters["aot_autograd"]["not_ok"] += 1
                 return gm

-        # OK attempt to compile
+        def wrap_bw_compiler(bw_compiler_fn):
+            def _wrapped_bw_compiler(*args, **kwargs):
+                # stop TorchDynamo from trying to compile our generated backwards pass
+                return disable(disable(bw_compiler_fn)(*args, **kwargs))

-        def _wrapped_bw_compiler(*args, **kwargs):
-            # stop TorchDynamo from trying to compile our generated backwards pass
-            return disable(disable(bw_compiler)(*args, **kwargs))
+            return _wrapped_bw_compiler

         bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
-        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
+
+        if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
+            bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
+        else:
+            bw_compiler = wrap_bw_compiler(bw_compiler)
+        self.kwargs["bw_compiler"] = bw_compiler

         self.kwargs["inference_compiler"] = (
             self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
         )

@@ -665,6 +665,7 @@ class AOTAutogradCache:
         """
         Load a result from the cache, and reconstruct a runtime wrapper around the object
        """
+        gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
         with sanitize_gm_for_cache(gm):
             compiled_fn = None

@@ -3,7 +3,19 @@
 import itertools
 from contextlib import contextmanager, nullcontext
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, NewType, Optional, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NewType,
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+)
 from unittest.mock import patch

 import torch
@@ -21,7 +33,8 @@ from torch._dynamo.utils import (
     preserve_rng_state,
 )
 from torch._guards import detect_fake_mode
-from torch._inductor.utils import BoxedBool
+from torch._inductor.output_code import OutputCode
+from torch._inductor.utils import BoxedBool, InputType
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -435,6 +448,47 @@ aot_autograd_decompositions = {}
 FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any])

+TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
+
+
+class AOTDispatchCompiler(Protocol):
+    """
+    Represents a fw or bw_compiler passed to AOTAutograd.
+    """
+
+    def __call__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: Sequence[InputType],
+    ) -> Any:
+        ...
+
+
+# TODO: bikeshed on this name
+class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
+    """
+    Represents an AOTDispatchCompiler that returns an OutputCode and is
+    therefore cacheable; a SerializableAOTDispatchCompiler always returns
+    an OutputCode. A _CompileFxCallable usually gets converted into an
+    AOTDispatchCompiler after binding all of the kwargs in _CompileFxKwargs.
+    """
+
+    def __init__(
+        self,
+        output_code_ty: Type[TOutputCode],
+        compiler_fn: Callable[
+            [torch.fx.GraphModule, Sequence[InputType]], TOutputCode
+        ],
+    ):
+        self.output_code_ty = output_code_ty
+        self.compiler_fn = compiler_fn
+
+    def __call__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: Sequence[InputType],
+    ) -> OutputCode:
+        return self.compiler_fn(gm, example_inputs)
+
+
 def process_inputs(
     flat_args: List[Any],
     aot_config: AOTConfig,
@@ -953,12 +1007,12 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
 def aot_module_simplified(
     mod: nn.Module,
     args,
-    fw_compiler: Callable,
-    bw_compiler: Optional[Callable] = None,
+    fw_compiler: AOTDispatchCompiler,
+    bw_compiler: Optional[AOTDispatchCompiler] = None,
     partition_fn: Callable = default_partition,
     decompositions: Optional[Dict] = None,
     keep_inference_input_mutations=False,
-    inference_compiler: Optional[Callable] = None,
+    inference_compiler: Optional[AOTDispatchCompiler] = None,
     cudagraphs: Optional[BoxedBool] = None,
 ) -> nn.Module:
     """
@@ -1086,8 +1140,8 @@ def aot_module_simplified(
     # Autograd cache stuff
     remote = should_use_remote_autograd_cache()
     local = should_use_local_autograd_cache()
-    if local or remote:
+    # We only care if the forward will return an OutputCode.
+    if (local or remote) and isinstance(fw_compiler, SerializableAOTDispatchCompiler):
         compiled_fn = AOTAutogradCache.load(
             dispatch_and_compile,
             mod,

@@ -52,7 +52,11 @@ from torch._dynamo.utils import (
     set_feature_use,
 )
 from torch._functorch import config as functorch_config
-from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
+from torch._functorch.aot_autograd import (
+    aot_export_module,
+    make_boxed_func,
+    SerializableAOTDispatchCompiler,
+)
 from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
 from torch._inductor.debug import save_args_for_compile_fx_inner
@@ -1663,20 +1667,20 @@ def compile_fx(
     )

     def fw_compiler_base(
-        model: GraphModule,
-        example_inputs: List[InputType],
+        gm: GraphModule,
+        example_inputs: Sequence[InputType],
         is_inference: bool,
     ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
             if is_inference:
                 # partition_fn won't be called
-                _recursive_joint_graph_passes(model)
+                _recursive_joint_graph_passes(gm)

             fixed = torch._inductor.utils.num_fw_fixed_arguments(
                 num_example_inputs, len(example_inputs)
             )

-            model_outputs_node = output_node(model)
+            model_outputs_node = output_node(gm)
             if config.keep_output_stride:
                 model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
                 num_model_outputs = len(model_outputs)
@@ -1733,7 +1737,7 @@ def compile_fx(
                 model_outputs_node.meta["user_visible_output_idxs"] = []

             return inner_compile(
-                model,
+                gm,
                 example_inputs,
                 static_input_idxs=get_static_input_idxs(fixed),
                 cudagraphs=cudagraphs,
@@ -1742,7 +1746,10 @@ def compile_fx(
                 boxed_forward_device_index=forward_device,
             )

-    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
+    fw_compiler: Callable[
+        [GraphModule, Sequence[InputType]], OutputCode
+    ] = functools.partial(fw_compiler_base, is_inference=False)
+    fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)

     if config.freezing and not torch.is_grad_enabled():
         inference_compiler: Callable[..., Any] = functools.partial(
@@ -1756,6 +1763,9 @@ def compile_fx(
         )
     else:
         inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
+        inference_compiler = SerializableAOTDispatchCompiler(
+            OutputCode, inference_compiler
+        )

     def partition_fn(
         gm: GraphModule,
@@ -1771,14 +1781,14 @@ def compile_fx(
     @compile_time_strobelight_meta(phase_name="backward")
     def bw_compiler(
-        model: GraphModule, example_inputs: List[InputType]
+        gm: GraphModule, example_inputs: Sequence[InputType]
     ) -> OutputCode:
         from torch._dynamo.convert_frame import compile_lock

         with dynamo_utils.dynamo_timed(
             "compile_fx.<locals>.bw_compiler"
         ), compile_lock:
-            model_outputs_node = output_node(model)
+            model_outputs_node = output_node(gm)
             if config.bw_outputs_user_visible:
                 model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
                 model_outputs_node.meta["user_visible_output_idxs"] = [
@@ -1789,12 +1799,12 @@ def compile_fx(
             else:
                 model_outputs_node.meta["user_visible_output_idxs"] = []

-        fixed = count_tangents(model)
+        fixed = count_tangents(gm)
         with config.patch(
             get_cpp_wrapper_config()
         ) if config.cpp_wrapper else contextlib.nullcontext():
             return inner_compile(
-                model,
+                gm,
                 example_inputs,
                 static_input_idxs=list(range(fixed)),
                 cudagraphs=cudagraphs,
@@ -1803,6 +1813,8 @@ def compile_fx(
                 boxed_forward_device_index=forward_device,
             )

+    bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler)

     fake_mode = detect_fake_mode(
         example_inputs_
     ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)