[REFACTOR] Inline FxGraphCache.post_compile into sole call site (#141877)

I am going to break apart the arguments passed to the constituent helpers so that each receives only what it actually needs, and having easy access to the inlined body makes that easier.

This also moves two helper functions (cudagraph_post_compile and maybe_realign_inputs) to output_code.py.

Also set _boxed_call in the constructor.
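
Roughly, the call-site change looks like this (a simplified sketch; names follow the diff below but are not verbatim):

    # Before: post-compile steps lived on the cache as a static method.
    FxGraphCache.post_compile(compiled_graph, example_inputs, cudagraphs, gm)

    # After: the same steps are a method on the output code object itself,
    # so callers no longer reach into FxGraphCache.
    compiled_graph.post_compile(example_inputs, cudagraphs, gm)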

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141877
Approved by: https://github.com/jamesjwu, https://github.com/jansel

Co-authored-by: James Wu <jjwu@meta.com>
Authored by Edward Z. Yang on 2024-12-03 18:06:09 -08:00; committed by PyTorch MergeBot
parent f85e238186
commit 7666c8263a
4 changed files with 191 additions and 188 deletions

File 1 of 4

@@ -44,6 +44,7 @@ from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._functorch.aot_autograd import aot_export_joint_simple, aot_export_module
from torch._higher_order_ops.out_dtype import out_dtype
from torch._inductor.codecache import compiled_fx_graph_hash
from torch._inductor.output_code import MockFXGraphCacheOutput
from torch._subclasses.fake_tensor import DynamicOutputShapeException, FakeTensorMode
from torch.fx.experimental.proxy_tensor import is_sym_node
from torch.fx.experimental.symbolic_shapes import GuardOnDataDependentSymNode, ShapeEnv
@@ -6768,26 +6769,21 @@ class MockFXGraphCache:
self.cache[key] = gm
def load(self, gm, inputs):
key, _ = compiled_fx_graph_hash(gm, inputs, {}, {})
if key in self.cache:
gm = make_boxed_func(gm)
gm._fx_graph_cache_key = key
return gm
else:
self.save(key, gm)
gm = make_boxed_func(gm)
gm._fx_graph_cache_key = key
key, _ = compiled_fx_graph_hash(gm, inputs, {}, [])
if key not in self.cache:
self.cache[key] = gm
gm, _ = self.load_with_key(key, [], inputs, None, None, None)
return gm
def load_with_key(self, key, debug_lines, inputs, local, remote_cache, is_backward):
gm = self.cache.get(key)
if gm is not None:
gm = make_boxed_func(gm)
gm = MockFXGraphCacheOutput(gm)
gm._fx_graph_cache_key = key
gm._time_taken_ns = 0
return gm, {}
def post_compile(self, gm, inputs, cudagraphs):
return gm
# The following tests fail in strict caching mode (i.e. they bypass or
# cache miss instead of cache hitting). They will be fixed in the PRs above this.
@@ -6859,9 +6855,6 @@ class TestAOTAutogradWithCache(TestAOTAutogradWithDynamo):
with patch(
"torch._inductor.codecache.FxGraphCache.load_with_key",
new=self.inductor_cache.load_with_key,
), patch(
"torch._inductor.codecache.FxGraphCache.post_compile",
new=self.inductor_cache.post_compile,
):
return super().verify_aot_autograd(
f,

File 2 of 4

@@ -354,8 +354,9 @@ class FXGraphCacheLoadable:
payload_fn=lambda: json.dumps(cache_info),
)
FxGraphCache.post_compile(result, example_inputs, fx_config["cudagraphs"]) # type: ignore[arg-type]
result._boxed_call = True
# TODO: How come cudagraphs could be None here?
# TODO: How come gm is None here?
result.post_compile(example_inputs, fx_config["cudagraphs"], None) # type: ignore[arg-type]
return result

File 3 of 4

@@ -100,7 +100,6 @@ from torch._inductor.cpp_builder import (
normalize_path_separator,
)
from torch._inductor.cpu_vec_isa import pick_vec_isa
from torch._inductor.cudagraph_utils import log_cudagraph_skip_and_bump_counter
from torch._inductor.runtime.compile_tasks import (
_module_to_triton_kernel,
_reload_python_module,
@@ -109,12 +108,9 @@ from torch._inductor.runtime.compile_tasks import (
from torch._inductor.runtime.runtime_utils import cache_dir, default_cache_dir
from torch._inductor.utils import (
ALIGN_BYTES,
align_inputs_from_check_idxs,
BoxedBool,
clear_on_fresh_inductor_cache,
is_linux,
is_windows,
set_tracing_context_output_strides,
)
from torch._logging import trace_structured
from torch._subclasses.fake_tensor import (
@@ -908,117 +904,6 @@ def compiled_fx_graph_hash(
return key, debug_lines
def cudagraph_post_compile(
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
cudagraphs: BoxedBool,
gm: Optional[torch.fx.GraphModule],
) -> None:
"""
Checks for any reasons not to run cudagraphs and then
runs it on compiled_graph.
Mutates the `compiled_graph.current_callable` and `cudagraphs`
"""
assert compiled_graph.current_callable is not None
assert compiled_graph.cudagraph_info is not None
cached_info = compiled_graph.cudagraph_info
cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
inputs_to_check = compiled_graph.inputs_to_check
boxed_forward_device_index = compiled_graph.boxed_forward_device_index
is_inference = compiled_graph.fx_kwargs["is_inference"]
is_backward = compiled_graph.fx_kwargs["is_backward"]
if not cudagraph_fail_reasons:
fx_kwargs = compiled_graph.fx_kwargs
static_input_idxs = fx_kwargs["static_input_idxs"]
placeholders = cached_info.placeholders
stack_traces = cached_info.stack_traces
if not config.triton.cudagraph_trees:
# Force specialize all inputs so that CUDA graphs will work
for t in example_inputs:
if isinstance(t, torch.SymInt):
int(t) # guard
if (
boxed_forward_device_index is not None
and not is_inference
and not is_backward
):
boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
from .compile_fx import cudagraphify
current_callable = compiled_graph.current_callable
assert current_callable is not None
compiled_graph.current_callable = cudagraphify(
current_callable,
static_input_idxs=static_input_idxs or (),
device_index=next(iter(compiled_graph.device_idxs)),
stack_traces=stack_traces,
is_backward=is_backward,
is_inference=is_inference,
constants=tuple(compiled_graph.get_constants(gm).values()),
placeholders=placeholders,
mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
)
else:
BoxedBool.disable(cudagraphs)
# See [Backward Generation Handling]
# if cudagraph'd the forward and set the device, we need to let the cudagraph manager
# know we are running the backward even if we will not run it in cudagraphs
if is_backward and config.triton.cudagraph_trees:
assert boxed_forward_device_index is not None
assert boxed_forward_device_index.value is not None
compiled_graph_callable = compiled_graph.current_callable
manager = torch._inductor.cudagraph_trees.get_manager(
boxed_forward_device_index.value, create_if_none_exists=False
)
# should already exist from forward
assert manager is not None
def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
manager.set_to_running_backward() # type: ignore[union-attr]
return compiled_graph_callable(new_inputs)
compiled_graph.current_callable = compiled_artifact
if "cuda" in compiled_graph.device_types:
# prefer better disable_cudagraphs_reason bc stack trace
# TODO: migrate all disable reasons to stack trace, refactor
if compiled_graph.disabled_cudagraphs_reason:
log_cudagraph_skip_and_bump_counter(
compiled_graph.disabled_cudagraphs_reason
)
else:
log_cudagraph_skip_and_bump_counter(
f"skipping cudagraphs due to {cudagraph_fail_reasons}"
)
def maybe_realign_inputs(
ran_cudagraphs: BoxedBool,
compiled_graph: CompiledFxGraph,
inputs_to_check: Sequence[int],
) -> None:
"""
Realigns input strides from inputs_to_check if
we didn't end up running cudagraphs. Mutates
`compiled_graph.current_callable` in that case;
otherwise, does nothing.
"""
if not ran_cudagraphs:
assert compiled_graph.current_callable is not None
new_callable = align_inputs_from_check_idxs(
compiled_graph.current_callable, inputs_to_check
)
if new_callable is not compiled_graph.current_callable:
compiled_graph.current_callable = new_callable
def add_ephemeral_timeout_increase_for_distributed(time_saved_ns: int) -> int:
"""
Ephemerally increases the NCCL timeout when compiling for a distributed job
@@ -1236,54 +1121,6 @@ class FxGraphCache:
)
return graph, cache_info
@staticmethod
def post_compile(
compiled_graph: CompiledFxGraph,
example_inputs: Sequence[InputType],
cudagraphs: BoxedBool,
gm: Optional[torch.fx.GraphModule] = None,
) -> CompiledFxGraph:
"""
Run a set of post processing steps after loading from the cache. These involve:
- Setting the tracing context output strides
- Running cudagraphs if enabled
- Realigning inputs
This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph.
The results of this function are *not* saved in the cache itself.
"""
set_tracing_context_output_strides(example_inputs, compiled_graph)
if cudagraphs:
# It's possible that cudagraphs is enabled, but was disabled
# during a previous compilation we're loading from the cache.
# If so, we need to disable it on this new process too.
if compiled_graph.disabled_cudagraphs_reason:
if "cuda" in compiled_graph.device_types:
log_cudagraph_skip_and_bump_counter(
f"skipping cudagraphs due to {compiled_graph.disabled_cudagraphs_reason}"
)
else:
counters["inductor"]["cudagraph_skips"] += 1
BoxedBool.disable(cudagraphs)
else:
cudagraph_post_compile(
example_inputs,
compiled_graph,
cudagraphs,
gm,
)
inputs_to_check = compiled_graph.inputs_to_check
# cudagraphs could have been disabled from the earlier conditions
# so we still need to realign inputs if that happens
maybe_realign_inputs(
cudagraphs,
compiled_graph,
inputs_to_check,
)
return compiled_graph
@staticmethod
def _save_graph(
key: str,

File 4 of 4

@@ -50,10 +50,16 @@ from torch._inductor.cudagraph_utils import (
get_placeholder_info,
log_cudagraph_skip_and_bump_counter,
)
from torch._inductor.utils import (
align_inputs_from_check_idxs,
BoxedBool,
InputType,
output_node,
set_tracing_context_output_strides,
)
from . import config
from .runtime.autotune_cache import AutotuneCacheBundler
from .utils import BoxedBool, InputType, output_node
if TYPE_CHECKING:
@@ -138,6 +144,117 @@ def complex_memory_overlap(t: torch.Tensor) -> bool:
return False
def cudagraph_post_compile(
example_inputs: Sequence[InputType],
compiled_graph: CompiledFxGraph,
cudagraphs: BoxedBool,
gm: Optional[torch.fx.GraphModule],
) -> None:
"""
Checks for any reasons not to run cudagraphs and then
runs it on compiled_graph.
Mutates the `compiled_graph.current_callable` and `cudagraphs`
"""
assert compiled_graph.current_callable is not None
assert compiled_graph.cudagraph_info is not None
cached_info = compiled_graph.cudagraph_info
cudagraph_fail_reasons = cached_info.cudagraph_fail_reasons
inputs_to_check = compiled_graph.inputs_to_check
boxed_forward_device_index = compiled_graph.boxed_forward_device_index
is_inference = compiled_graph.fx_kwargs["is_inference"]
is_backward = compiled_graph.fx_kwargs["is_backward"]
if not cudagraph_fail_reasons:
fx_kwargs = compiled_graph.fx_kwargs
static_input_idxs = fx_kwargs["static_input_idxs"]
placeholders = cached_info.placeholders
stack_traces = cached_info.stack_traces
if not config.triton.cudagraph_trees:
# Force specialize all inputs so that CUDA graphs will work
for t in example_inputs:
if isinstance(t, torch.SymInt):
int(t) # guard
if (
boxed_forward_device_index is not None
and not is_inference
and not is_backward
):
boxed_forward_device_index.set(next(iter(compiled_graph.device_idxs)))
from .compile_fx import cudagraphify
current_callable = compiled_graph.current_callable
assert current_callable is not None
compiled_graph.current_callable = cudagraphify(
current_callable,
static_input_idxs=static_input_idxs or (),
device_index=next(iter(compiled_graph.device_idxs)),
stack_traces=stack_traces,
is_backward=is_backward,
is_inference=is_inference,
constants=tuple(compiled_graph.get_constants(gm).values()),
placeholders=placeholders,
mutated_input_idxs=tuple(compiled_graph.mutated_input_idxs),
)
else:
BoxedBool.disable(cudagraphs)
# See [Backward Generation Handling]
# if cudagraph'd the forward and set the device, we need to let the cudagraph manager
# know we are running the backward even if we will not run it in cudagraphs
if is_backward and config.triton.cudagraph_trees:
assert boxed_forward_device_index is not None
assert boxed_forward_device_index.value is not None
compiled_graph_callable = compiled_graph.current_callable
manager = torch._inductor.cudagraph_trees.get_manager(
boxed_forward_device_index.value, create_if_none_exists=False
)
# should already exist from forward
assert manager is not None
def compiled_artifact(new_inputs: List[Any]) -> Callable[..., Any]:
manager.set_to_running_backward() # type: ignore[union-attr]
return compiled_graph_callable(new_inputs)
compiled_graph.current_callable = compiled_artifact
if "cuda" in compiled_graph.device_types:
# prefer better disable_cudagraphs_reason bc stack trace
# TODO: migrate all disable reasons to stack trace, refactor
if compiled_graph.disabled_cudagraphs_reason:
log_cudagraph_skip_and_bump_counter(
compiled_graph.disabled_cudagraphs_reason
)
else:
log_cudagraph_skip_and_bump_counter(
f"skipping cudagraphs due to {cudagraph_fail_reasons}"
)
def maybe_realign_inputs(
ran_cudagraphs: BoxedBool,
compiled_graph: CompiledFxGraph,
inputs_to_check: Sequence[int],
) -> None:
"""
Realigns input strides from inputs_to_check if
we didn't end up running cudagraphs. Mutates
`compiled_graph.current_callable` in that case;
otherwise, does nothing.
"""
if not ran_cudagraphs:
assert compiled_graph.current_callable is not None
new_callable = align_inputs_from_check_idxs(
compiled_graph.current_callable, inputs_to_check
)
if new_callable is not compiled_graph.current_callable:
compiled_graph.current_callable = new_callable
@dataclasses.dataclass
class CompiledFxGraph(OutputCode):
"""
@@ -295,6 +412,9 @@ class CompiledFxGraph(OutputCode):
# TODO: should this be part of fx_kwargs
self.boxed_forward_device_index = boxed_forward_device_index
# aot autograd needs to know to pass in inputs as a list
self._boxed_call = True
def __call__(self, inputs: Sequence[Any]) -> Any:
assert self.current_callable is not None
try:
@@ -308,14 +428,44 @@ class CompiledFxGraph(OutputCode):
cudagraphs: BoxedBool,
gm: GraphModule,
) -> None:
# TODO: maybe move this here? Not sure.
from torch._inductor.codecache import FxGraphCache
"""
Run a set of post processing steps after loading from the cache. These involve:
- Setting the tracing context output strides
- Running cudagraphs if enabled
- Realigning inputs
FxGraphCache.post_compile(self, example_inputs, cudagraphs, gm)
This runs whether or not we have a cache hit, and always runs directly after we get a CompiledFxGraph.
The results of this function are *not* saved in the cache itself.
"""
set_tracing_context_output_strides(example_inputs, self)
# aot autograd needs to know to pass in inputs as a list
# TODO: Not sure why this isn't just set by default on CompiledFxGraph
self._boxed_call = True
if cudagraphs:
# It's possible that cudagraphs is enabled, but was disabled
# during a previous compilation we're loading from the cache.
# If so, we need to disable it on this new process too.
if self.disabled_cudagraphs_reason:
if "cuda" in self.device_types:
log_cudagraph_skip_and_bump_counter(
f"skipping cudagraphs due to {self.disabled_cudagraphs_reason}"
)
else:
counters["inductor"]["cudagraph_skips"] += 1
BoxedBool.disable(cudagraphs)
else:
cudagraph_post_compile(
example_inputs,
self,
cudagraphs,
gm,
)
inputs_to_check = self.inputs_to_check
# cudagraphs could have been disabled from the earlier conditions
# so we still need to realign inputs if that happens
maybe_realign_inputs(
cudagraphs,
self,
inputs_to_check,
)
def set_triton_bundle(self, triton_bundle: Any) -> None:
self._triton_bundle = triton_bundle
@@ -428,3 +578,25 @@ class CompiledAOTI(OutputCode):
def _typecheck_CompiledAOTI(h: CompiledAOTI) -> OutputCode:
return h
@dataclasses.dataclass
class MockFXGraphCacheOutput(OutputCode):
gm: Any = None
def __post_init__(self) -> None:
self._boxed_call = True
def post_compile(
self,
example_inputs: Sequence[InputType],
cudagraphs: BoxedBool,
gm: GraphModule,
) -> None:
pass
def __call__(self, inputs: Sequence[Any]) -> Any:
return self.gm(inputs)
def set_triton_bundle(self, triton_bundle: Any) -> None:
pass
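
For context on the boxed-call convention mentioned above ("aot autograd needs to know to pass in inputs as a list"), here is a rough, hypothetical usage sketch of the mock output object; the input names are illustrative and not taken from the diff:

    # The wrapped callable receives its inputs as a single list rather than as
    # positional arguments; _boxed_call = True (set in __post_init__) signals
    # this convention to AOTAutograd.
    boxed = make_boxed_func(gm)                      # boxed([a, b]) calls gm(a, b)
    out = MockFXGraphCacheOutput(boxed)
    result = out([inp0, inp1])                       # inputs passed as one list
    out.post_compile([inp0, inp1], cudagraphs, gm)   # no-op for the mock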