[REFACTOR] Implement AOTDispatchCompiler wrapper (#142205)
This implements a new wrapper class, AOTDispatchCompiler, which wraps a callable that returns an OutputCode. AOTDispatch can then use it to decide whether or not to use the cache: if fw_compiler, bw_compiler, and inference_compiler are all AOTDispatchCompilers, caching is enabled.

This type is pretty close to _CompiledFxGraphCallable, except that it is not allowed to take any kwargs. It is not yet clear how to consolidate the two ideas, since there is no way to properly annotate the types to make them related; a lot of the time, though, the input to this wrapper will be a partially applied _CompiledFxGraphCallable.

This allows the PR stacked above this one to enable AOTAutogradCache everywhere without increasing instruction count or enabling the cache in unit tests that use aot_eager or other non-Inductor compilers.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142205
Approved by: https://github.com/oulgen, https://github.com/bdhirsh
commit 6e203ae6de (parent 5663ad99e7)
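For orientation, here is a minimal sketch of the pattern this commit introduces, assuming the post-commit import locations shown in the diff below; my_compiler_fn and the local/remote placeholder flags are hypothetical stand-ins, and the gating condition paraphrases the check added to aot_module_simplified:

    from torch._functorch.aot_autograd import SerializableAOTDispatchCompiler
    from torch._inductor.output_code import OutputCode


    # Hypothetical compiler with the (gm, example_inputs) -> OutputCode shape;
    # in compile_fx this role is played by functools.partial(fw_compiler_base, ...).
    def my_compiler_fn(gm, example_inputs):
        raise NotImplementedError("stand-in for a real Inductor compile")


    # Wrapping the callable marks it as one that returns an OutputCode,
    # which is what makes the result serializable and therefore cacheable.
    fw_compiler = SerializableAOTDispatchCompiler(OutputCode, my_compiler_fn)

    # The cache gate keys off the wrapper type: backends like aot_eager pass
    # plain callables, so they skip AOTAutogradCache entirely.
    local, remote = True, False  # placeholders for the autograd-cache config checks
    if (local or remote) and isinstance(fw_compiler, SerializableAOTDispatchCompiler):
        pass  # aot_module_simplified consults AOTAutogradCache.load(...) here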
--- a/test/functorch/test_aotdispatch.py
+++ b/test/functorch/test_aotdispatch.py
@@ -41,7 +41,11 @@ from functorch.compile import (
 from functorch.experimental import control_flow
 from torch._decomp import decomposition_table
 from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
-from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
+from torch._functorch.aot_autograd import (
+    aot_export_joint_simple,
+    aot_export_module,
+    SerializableAOTDispatchCompiler,
+)
 from torch._higher_order_ops.out_dtype import out_dtype
 from torch._inductor.codecache import compiled_fx_graph_hash
 from torch._inductor.output_code import MockFXGraphCacheOutput
@@ -6831,12 +6835,13 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
     def make_compiler(self, fw_graph_cell):
         mock_inductor_cache = self.inductor_cache
 
-        def compiler(gm, inputs):
+        def compiler(gm, example_inputs):
             nonlocal mock_inductor_cache, fw_graph_cell
-            result = mock_inductor_cache.load(gm, inputs)
+            result = mock_inductor_cache.load(gm, example_inputs)
             fw_graph_cell[0] = gm
             return result
 
+        compiler = SerializableAOTDispatchCompiler(MockFXGraphCacheOutput, compiler)
         return compiler
 
     def run_autograd(
--- a/torch/_dynamo/backends/common.py
+++ b/torch/_dynamo/backends/common.py
@@ -9,7 +9,10 @@ import torch
 from torch._dynamo import disable
 from torch._dynamo.exc import TensorifyScalarRestartAnalysis
 from torch._dynamo.utils import counters, defake, flatten_graph_inputs
-from torch._functorch.aot_autograd import aot_module_simplified
+from torch._functorch.aot_autograd import (
+    aot_module_simplified,
+    SerializableAOTDispatchCompiler,
+)
 from torch.utils._python_dispatch import _disable_current_modes
 
 
@@ -45,14 +48,21 @@ class AotAutograd:
             counters["aot_autograd"]["not_ok"] += 1
             return gm
 
-        # OK attempt to compile
-
-        def _wrapped_bw_compiler(*args, **kwargs):
-            # stop TorchDynamo from trying to compile our generated backwards pass
-            return disable(disable(bw_compiler)(*args, **kwargs))
+        def wrap_bw_compiler(bw_compiler_fn):
+            def _wrapped_bw_compiler(*args, **kwargs):
+                # stop TorchDynamo from trying to compile our generated backwards pass
+                return disable(disable(bw_compiler_fn)(*args, **kwargs))
+
+            return _wrapped_bw_compiler
 
         bw_compiler = self.kwargs.get("bw_compiler") or self.kwargs["fw_compiler"]
-        self.kwargs["bw_compiler"] = _wrapped_bw_compiler
+
+        if isinstance(bw_compiler, SerializableAOTDispatchCompiler):
+            bw_compiler.compiler_fn = wrap_bw_compiler(bw_compiler.compiler_fn)
+        else:
+            bw_compiler = wrap_bw_compiler(bw_compiler)
+
+        self.kwargs["bw_compiler"] = bw_compiler
         self.kwargs["inference_compiler"] = (
             self.kwargs.get("inference_compiler") or self.kwargs["fw_compiler"]
         )
--- a/torch/_functorch/_aot_autograd/autograd_cache.py
+++ b/torch/_functorch/_aot_autograd/autograd_cache.py
@@ -665,6 +665,7 @@ class AOTAutogradCache:
         """
         Load a result from the cache, and reconstruct a runtime wrapper around the object
         """
+
         gm = mod.gm if isinstance(mod, torch._dynamo.utils.GmWrapper) else mod
         with sanitize_gm_for_cache(gm):
             compiled_fn = None
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -3,7 +3,19 @@
 import itertools
 from contextlib import contextmanager, nullcontext
 from functools import partial, wraps
-from typing import Any, Callable, Dict, List, NewType, Optional, Tuple
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    NewType,
+    Optional,
+    Protocol,
+    Sequence,
+    Tuple,
+    Type,
+    TypeVar,
+)
 from unittest.mock import patch
 
 import torch
@@ -21,7 +33,8 @@ from torch._dynamo.utils import (
     preserve_rng_state,
 )
 from torch._guards import detect_fake_mode
-from torch._inductor.utils import BoxedBool
+from torch._inductor.output_code import OutputCode
+from torch._inductor.utils import BoxedBool, InputType
 from torch._subclasses import FakeTensor, FakeTensorMode
 from torch.fx.experimental.proxy_tensor import make_fx
 from torch.fx.experimental.symbolic_shapes import ShapeEnv
@@ -435,6 +448,47 @@ aot_autograd_decompositions = {}
 FakifiedFlatArgs = NewType("FakifiedFlatArgs", List[Any])
 
 
+TOutputCode = TypeVar("TOutputCode", bound=OutputCode)
+
+
+class AOTDispatchCompiler(Protocol):
+    """
+    Represents a fw or bw_compiler passed to AOTAutograd.
+    """
+
+    def __call__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: Sequence[InputType],
+    ) -> Any:
+        ...
+
+
+# TODO: bikeshed on this name
+class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
+    """
+    Represents an AOTDispatchCompiler that returns an OutputCode, and is
+    therefore cacheable. SerializableAOTDispatchCompiler always returns an OutputCode.
+    A _CompileFxCallable usually gets converted into an AOTDispatchCompiler after binding all of
+    the kwargs in _CompileFxKwargs.
+    """
+
+    def __init__(
+        self,
+        output_code_ty: Type[TOutputCode],
+        compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
+    ):
+        self.output_code_ty = output_code_ty
+        self.compiler_fn = compiler_fn
+
+    def __call__(
+        self,
+        gm: torch.fx.GraphModule,
+        example_inputs: Sequence[InputType],
+    ) -> OutputCode:
+        return self.compiler_fn(gm, example_inputs)
+
+
 def process_inputs(
     flat_args: List[Any],
     aot_config: AOTConfig,
@@ -953,12 +1007,12 @@ def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
 def aot_module_simplified(
     mod: nn.Module,
     args,
-    fw_compiler: Callable,
-    bw_compiler: Optional[Callable] = None,
+    fw_compiler: AOTDispatchCompiler,
+    bw_compiler: Optional[AOTDispatchCompiler] = None,
     partition_fn: Callable = default_partition,
     decompositions: Optional[Dict] = None,
     keep_inference_input_mutations=False,
-    inference_compiler: Optional[Callable] = None,
+    inference_compiler: Optional[AOTDispatchCompiler] = None,
     cudagraphs: Optional[BoxedBool] = None,
 ) -> nn.Module:
     """
@@ -1086,8 +1140,8 @@ def aot_module_simplified(
     # Autograd cache stuff
     remote = should_use_remote_autograd_cache()
     local = should_use_local_autograd_cache()
-
-    if local or remote:
+    # We only care if the forward will return an OutputCode.
+    if (local or remote) and isinstance(fw_compiler, SerializableAOTDispatchCompiler):
         compiled_fn = AOTAutogradCache.load(
             dispatch_and_compile,
             mod,
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -52,7 +52,11 @@ from torch._dynamo.utils import (
     set_feature_use,
 )
 from torch._functorch import config as functorch_config
-from torch._functorch.aot_autograd import aot_export_module, make_boxed_func
+from torch._functorch.aot_autograd import (
+    aot_export_module,
+    make_boxed_func,
+    SerializableAOTDispatchCompiler,
+)
 from torch._inductor.codecache import code_hash, FxGraphCache, output_code_log
 from torch._inductor.cudagraph_utils import BoxedDeviceIndex, PlaceholderInfo
 from torch._inductor.debug import save_args_for_compile_fx_inner
@@ -1663,20 +1667,20 @@ def compile_fx(
     )
 
     def fw_compiler_base(
-        model: GraphModule,
-        example_inputs: List[InputType],
+        gm: GraphModule,
+        example_inputs: Sequence[InputType],
         is_inference: bool,
     ) -> OutputCode:
         with dynamo_utils.dynamo_timed("compile_fx.<locals>.fw_compiler_base"):
             if is_inference:
                 # partition_fn won't be called
-                _recursive_joint_graph_passes(model)
+                _recursive_joint_graph_passes(gm)
 
             fixed = torch._inductor.utils.num_fw_fixed_arguments(
                 num_example_inputs, len(example_inputs)
             )
 
-            model_outputs_node = output_node(model)
+            model_outputs_node = output_node(gm)
             if config.keep_output_stride:
                 model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
                 num_model_outputs = len(model_outputs)
@@ -1733,7 +1737,7 @@ def compile_fx(
                 model_outputs_node.meta["user_visible_output_idxs"] = []
 
             return inner_compile(
-                model,
+                gm,
                 example_inputs,
                 static_input_idxs=get_static_input_idxs(fixed),
                 cudagraphs=cudagraphs,
@@ -1742,7 +1746,10 @@ def compile_fx(
                 boxed_forward_device_index=forward_device,
             )
 
-    fw_compiler = functools.partial(fw_compiler_base, is_inference=False)
+    fw_compiler: Callable[
+        [GraphModule, Sequence[InputType]], OutputCode
+    ] = functools.partial(fw_compiler_base, is_inference=False)
+    fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)
 
     if config.freezing and not torch.is_grad_enabled():
         inference_compiler: Callable[..., Any] = functools.partial(
@@ -1756,6 +1763,9 @@ def compile_fx(
         )
     else:
         inference_compiler = functools.partial(fw_compiler_base, is_inference=True)
+    inference_compiler = SerializableAOTDispatchCompiler(
+        OutputCode, inference_compiler
+    )
 
     def partition_fn(
         gm: GraphModule,
@@ -1771,14 +1781,14 @@ def compile_fx(
 
     @compile_time_strobelight_meta(phase_name="backward")
     def bw_compiler(
-        model: GraphModule, example_inputs: List[InputType]
+        gm: GraphModule, example_inputs: Sequence[InputType]
     ) -> OutputCode:
         from torch._dynamo.convert_frame import compile_lock
 
         with dynamo_utils.dynamo_timed(
             "compile_fx.<locals>.bw_compiler"
         ), compile_lock:
-            model_outputs_node = output_node(model)
+            model_outputs_node = output_node(gm)
             if config.bw_outputs_user_visible:
                 model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
                 model_outputs_node.meta["user_visible_output_idxs"] = [
@@ -1789,12 +1799,12 @@ def compile_fx(
             else:
                 model_outputs_node.meta["user_visible_output_idxs"] = []
 
-            fixed = count_tangents(model)
+            fixed = count_tangents(gm)
             with config.patch(
                 get_cpp_wrapper_config()
             ) if config.cpp_wrapper else contextlib.nullcontext():
                 return inner_compile(
-                    model,
+                    gm,
                     example_inputs,
                     static_input_idxs=list(range(fixed)),
                     cudagraphs=cudagraphs,
@@ -1803,6 +1813,8 @@ def compile_fx(
                     boxed_forward_device_index=forward_device,
                 )
 
+    bw_compiler = SerializableAOTDispatchCompiler(OutputCode, bw_compiler)
+
     fake_mode = detect_fake_mode(
         example_inputs_
     ) or torch._subclasses.FakeTensorMode(allow_non_fake_inputs=True)