[REFACTOR] Implement AOTDispatchCompiler wrapper (#142205)
This implements a new wrapper class, AOTDispatchCompiler, which wraps a callable that returns an OutputCode. AOTDispatch can then use it to decide whether or not to use the cache: if fw_compiler, bw_compiler, and inference_compiler are all AOTDispatchCompilers, caching is enabled.

This type is close to _CompiledFxGraphCallable, except that it is not allowed to take any kwargs. It is not clear how to consolidate the two just yet: unfortunately, there is no way to annotate the types so that they are properly related. Much of the time, though, the input to this wrapper will be a partially applied _CompiledFxGraphCallable.

This lets the PR stacked above this one enable AOTAutogradCache everywhere without increasing instruction count or enabling the cache on unit tests that use aot_eager or other non-Inductor compilers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142205
Approved by: https://github.com/oulgen, https://github.com/bdhirsh
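A rough sketch of that gating rule (the cache_allowed helper is hypothetical; in the diff below the check lives inline in aot_module_simplified and only inspects fw_compiler, since compile_fx wraps all three compilers):

# Hypothetical helper; illustrates the rule described above, not the shipped code path.
from torch._functorch.aot_autograd import SerializableAOTDispatchCompiler

def cache_allowed(fw_compiler, bw_compiler, inference_compiler) -> bool:
    # Caching is only safe when every compiler is known to return an OutputCode.
    return all(
        isinstance(c, SerializableAOTDispatchCompiler)
        for c in (fw_compiler, bw_compiler, inference_compiler)
    )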
parent 5663ad99e7
commit 6e203ae6de
@@ -41,7 +41,11 @@ from functorch.compile import (
 from functorch.experimental import control_flow
 from torch._decomp import decomposition_table
 from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
-from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
+from torch._functorch.aot_autograd import (
+    aot_export_joint_simple,
+    aot_export_module,
+    SerializableAOTDispatchCompiler,
+)
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch._inductor.codecache import compiled_fx_graph_hash
 from torch._inductor.output_code import MockFXGraphCacheOutput
@@ -6831,12 +6835,13 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
     def make_compiler(self, fw_graph_cell):
         mock_inductor_cache = self.inductor_cache

-        def compiler(gm, inputs):
+        def compiler(gm, example_inputs):
             nonlocal mock_inductor_cache, fw_graph_cell
-            result = mock_inductor_cache.load(gm, inputs)
+            result = mock_inductor_cache.load(gm, example_inputs)
             fw_graph_cell[0] = gm
             return result

+        compiler = SerializableAOTDispatchCompiler(MockFXGraphCacheOutput, compiler)
         return compiler

     def run_autograd(
@@ -9,7 +9,10 @@ import torch
 from torch._dynamo import disable
 from torch._dynamo.exc import TensorifyScalarRestartAnalysis
 from torch._dynamo.utils import counters, defake, flatten_graph_inputs
-from torch._functorch.aot_autograd import aot_module_simplified
+from torch._functorch.aot_autograd import (
+    aot_module_simplified,
+    SerializableAOTDispatchCompiler,
+)
 from torch.utils._python_dispatch import _disable_current_modes
@@ -45,14 +48,21 @@ class AotAutograd:
             counters["aot_autograd"]["not_ok"] += 1
             return gm

        # OK attempt to compile
+        def wrap_bw_compiler(bw_compiler_fn):
+            def _wrapped_bw_compiler(*args, **kwargs):
+                # stop TorchDynamo from trying to compile our generated backwards pass
+                return disable(disable(bw_compiler_fn)(*args, **kwargs))

-        def _wrapped_bw_compiler(*args, **kwargs):
-            # stop TorchDynamo from trying to compile our generated backwards pass
-            return disable(disable(bw_compiler)(*args, **kwargs))
+            return _wrapped_bw_compiler

        bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
-        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
+
+        if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
+            bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
+        else:
+            bw_compiler = wrap_bw_compiler(bw_compiler)
+
+        self.kwargs["bw_compiler"] = bw_compiler
        self.kwargs["inference_compiler"] = (
            self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
        )
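Note that the isinstance branch above mutates compiler_fn rather than wrapping the SerializableAOTDispatchCompiler itself: wrapping the object in a plain function would erase the type that aot_module_simplified later isinstance-checks to enable caching. A minimal illustration of that property (the noop_compiler stand-in is hypothetical, and the single disable call is a simplification of the double-disable in the diff):

from torch._dynamo import disable
from torch._functorch.aot_autograd import SerializableAOTDispatchCompiler
from torch._inductor.output_code import OutputCode

def noop_compiler(gm, example_inputs):  # hypothetical stand-in compiler
    ...

bw = SerializableAOTDispatchCompiler(OutputCode, noop_compiler)
bw.compiler_fn = disable(bw.compiler_fn)  # wrap only the inner callable
assert isinstance(bw, SerializableAOTDispatchCompiler)  # outer type survives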
@@ -665,6 +665,7 @@ class AOTAutogradCache:
         """
         Load a result from the cache, and reconstruct a runtime wrapper around the object
         """
+        gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
         with sanitize_gm_for_cache(gm):
             compiled_fn = None
@@ -3,7 +3,19 @@
 import itertools
 from contextlib import contextmanager, nullcontext
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, NewType, Optional, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NewType,
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+)
 from unittest.mock import patch

 import torch
@@ -21,7 +33,8 @@ from torch._dynamo.utils import (
     preserve_rng_state,
 )
 from torch._guards import detect_fake_mode
-from torch._inductor.utils import BoxedBool
+from torch._inductor.output_code import OutputCode
+from torch._inductor.utils import BoxedBool, InputType
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -435,6 +448,47 @@ aot_autograd_decompositions = {}
 FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any])


+TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
+
+
+class AOTDispatchCompiler(Protocol):
+    """
+    Represents a fw or bw_compiler passed to AOTAutograd.
+    """
+
+    def __call__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: Sequence[InputType],
+    ) -> Any:
+        ...
+
+
+# TODO: bikeshed on this name
+class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
+    """
+    Represents an AOTDispatchCompiler that always returns an OutputCode and is
+    therefore cacheable. A _CompileFxCallable usually gets converted into an
+    AOTDispatchCompiler after binding all of the kwargs in _CompileFxKwargs.
+    """
+
+    def __init__(
+        self,
+        output_code_ty: Type[TOutputCode],
+        compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
+    ):
+        self.output_code_ty = output_code_ty
+        self.compiler_fn = compiler_fn
+
+    def __call__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: Sequence[InputType],
+    ) -> OutputCode:
+        return self.compiler_fn(gm, example_inputs)
+
+
 def process_inputs(
     flat_args: List[Any],
     aot_config: AOTConfig,
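Usage follows the shape that compile_fx adopts later in this diff: construct the wrapper from an OutputCode type plus the compiler callable, and the wrapper records output_code_ty alongside the callable. A self-contained sketch (toy_compiler and its body are hypothetical):

import torch
from torch._functorch.aot_autograd import SerializableAOTDispatchCompiler
from torch._inductor.output_code import OutputCode

def toy_compiler(gm: torch.fx.GraphModule, example_inputs):
    ...  # hypothetical: a real compiler would return an OutputCode here

fw_compiler = SerializableAOTDispatchCompiler(OutputCode, toy_compiler)
# Calling fw_compiler(gm, example_inputs) simply delegates to toy_compiler,
# while the wrapper's type signals cache eligibility to AOTAutograd.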
@@ -953,12 +1007,12 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
 def aot_module_simplified(
     mod: nn.Module,
     args,
-    fw_compiler: Callable,
-    bw_compiler: Optional[Callable] = None,
+    fw_compiler: AOTDispatchCompiler,
+    bw_compiler: Optional[AOTDispatchCompiler] = None,
     partition_fn: Callable = default_partition,
     decompositions: Optional[Dict] = None,
     keep_inference_input_mutations=False,
-    inference_compiler: Optional[Callable] = None,
+    inference_compiler: Optional[AOTDispatchCompiler] = None,
     cudagraphs: Optional[BoxedBool] = None,
 ) -> nn.Module:
     """
@@ -1086,8 +1140,8 @@ def aot_module_simplified(
     # Autograd cache stuff
     remote = should_use_remote_autograd_cache()
     local = should_use_local_autograd_cache()

-    if local or remote:
+    # We only care if the forward will return an OutputCode.
+    if (local or remote) and isinstance(fw_compiler, SerializableAOTDispatchCompiler):
         compiled_fn = AOTAutogradCache.load(
             dispatch_and_compile,
             mod,
@@ -52,7 +52,11 @@ from torch._dynamo.utils import (
     set_feature_use,
 )
 from torch._functorch import config as functorch_config
-from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
+from torch._functorch.aot_autograd import (
+    aot_export_module,
+    make_boxed_func,
+    SerializableAOTDispatchCompiler,
+)
 from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
 from torch._inductor.debug import save_args_for_compile_fx_inner
@@ -1663,20 +1667,20 @@ def compile_fx(
     )

     def fw_compiler_base(
-        model: GraphModule,
-        example_inputs: List[InputType],
+        gm: GraphModule,
+        example_inputs: Sequence[InputType],
         is_inference: bool,
     ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
             if is_inference:
                 # partition_fn won't be called
-                _recursive_joint_graph_passes(model)
+                _recursive_joint_graph_passes(gm)

             fixed = torch._inductor.utils.num_fw_fixed_arguments(
                 num_example_inputs, len(example_inputs)
             )

-            model_outputs_node = output_node(model)
+            model_outputs_node = output_node(gm)
             if config.keep_output_stride:
                 model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
                 num_model_outputs = len(model_outputs)
@@ -1733,7 +1737,7 @@ def compile_fx(
                 model_outputs_node.meta["user_visible_output_idxs"] = []

             return inner_compile(
-                model,
+                gm,
                 example_inputs,
                 static_input_idxs=get_static_input_idxs(fixed),
                 cudagraphs=cudagraphs,
@@ -1742,7 +1746,10 @@ def compile_fx(
                 boxed_forward_device_index=forward_device,
             )

-    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
+    fw_compiler: Callable[
+        [GraphModule, Sequence[InputType]], OutputCode
+    ] = functools.partial(fw_compiler_base, is_inference=False)
+    fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)

     if config.freezing and not torch.is_grad_enabled():
         inference_compiler: Callable[..., Any] = functools.partial(
@@ -1756,6 +1763,9 @@ def compile_fx(
         )
     else:
         inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
+        inference_compiler = SerializableAOTDispatchCompiler(
+            OutputCode, inference_compiler
+        )

     def partition_fn(
         gm: GraphModule,
@@ -1771,14 +1781,14 @@ def compile_fx(

     @compile_time_strobelight_meta(phase_name="backward")
     def bw_compiler(
-        model: GraphModule, example_inputs: List[InputType]
+        gm: GraphModule, example_inputs: Sequence[InputType]
     ) -> OutputCode:
         from torch._dynamo.convert_frame import compile_lock

         with dynamo_utils.dynamo_timed(
             "compile_fx.<locals>.bw_compiler"
         ), compile_lock:
-            model_outputs_node = output_node(model)
+            model_outputs_node = output_node(gm)
             if config.bw_outputs_user_visible:
                 model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
                 model_outputs_node.meta["user_visible_output_idxs"] = [
@@ -1789,12 +1799,12 @@ def compile_fx(
             else:
                 model_outputs_node.meta["user_visible_output_idxs"] = []

-            fixed = count_tangents(model)
+            fixed = count_tangents(gm)
             with config.patch(
                 get_cpp_wrapper_config()
             ) if config.cpp_wrapper else contextlib.nullcontext():
                 return inner_compile(
-                    model,
+                    gm,
                     example_inputs,
                     static_input_idxs=list(range(fixed)),
                     cudagraphs=cudagraphs,
@@ -1803,6 +1813,8 @@ def compile_fx(
                     boxed_forward_device_index=forward_device,
                 )

+    bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler)
+
     fake_mode = detect_fake_mode(
         example_inputs_
     ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)