[dynamo] Refactor convert_frame.compile_frame to be self contained function. [5/n] (#160900)

convert_frame.compile_frame used to take a `transform` callback which closed over the frame object, so the frame information was never passed directly into the compile_frame function.

This PR changes the signature of compile_frame so that the frame information is passed in directly, without going through a callback. This makes it easier to build a fullgraph capture API on top of compile_frame.
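In short: `compile_frame(code, transform, restart_reasons)` becomes `compile_frame(code, globals, locals, builtins, closure, compiler_fn, one_graph, restart_reasons, ...)`, and `fullgraph_capture` builds a `FrameInfo` and calls it directly. A sketch of the new call pattern, condensed from the new test in the diff below (the function `foo`, the `FrameInfo` fields, and the context managers are all taken from that test; this is illustrative, not an additional API guarantee):

```python
# Sketch only, condensed from the new test in the diff below.
# fullgraph_capture applies no context managers itself, so the caller
# (here, mirroring the test) sets up compile id, timing, and metrics.
import builtins
import torch
from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture, get_compile_id
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext

def foo(x):
    return x + x.shape[0]

x = torch.randn(4, 3)
f_locals = {"x": x}

with (
    compile_context(CompileContext(get_compile_id({}))),
    dynamo_timed(""),
    get_metrics_context(),
):
    capture_output = fullgraph_capture(
        FrameInfo(
            foo.__code__,     # code object of the frame being captured
            foo.__globals__,  # frame globals
            f_locals,         # frame locals
            builtins,         # builtins namespace (passed this way in the test)
            (),               # closure cells (empty for a top-level function)
        )
    )

dynamo_output = capture_output.dynamo_output    # guards + generated bytecode
backend_input = capture_output.backend_input    # fx graph + example inputs
```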
@exported-using-ghexport

Differential Revision: [D80469801](https://our.internmc.facebook.com/intern/diff/D80469801/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160900
Approved by: https://github.com/tugsbayasgalan, https://github.com/anijain2305
zhxchen17 2025-08-24 19:54:48 -07:00 committed by PyTorch MergeBot
parent 40c0e700a4
commit 1113e7de30
5 changed files with 236 additions and 165 deletions

View File

@@ -1,6 +1,7 @@
# Owner(s): ["module: dynamo"]
# ruff: noqa: F841
import abc
import builtins
import collections
import collections.abc
import copy
@@ -8549,47 +8550,52 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)")
def test_fullgraph_capture(self):
from torch._dynamo.convert_frame import (
FrameInfo,
fullgraph_capture,
get_compile_id,
)
from torch._dynamo.utils import dynamo_timed, get_metrics_context
from torch._guards import compile_context, CompileContext
def foo(x):
return x + x.shape[0]
compiled_foo = torch._dynamo.eval_frame.fullgraph_capture(foo)
compiled_foo(torch.randn(3, 2))
compiled_foo(torch.randn(4))
artifacts = compiled_foo.get_artifacts()
guarded_codes = artifacts.dynamo_artifacts.guarded_codes
backend_ids = list(artifacts.backend_inputs.keys())
gms = [b.graph_module for b in artifacts.backend_inputs.values()]
def _convert_to_ep_demo(code, backend_id, gm, args):
# Inject compiled function as the original gm
new_globals = copy.copy(globals())
new_globals[backend_id] = gm
# Minimal boilerplate to setup a callable.
SerializedCode = type(code.dynamo_code)
dynamo_bytecode = SerializedCode.to_code_object(code.dynamo_code)
guards_state = pickle.loads(code.guards_state)
guard_manager = torch._dynamo.guards.CheckFunctionManager(
foo.__code__,
guards_state.output_graph,
shape_code_parts=guards_state.shape_code_parts,
runtime_global_scope=new_globals,
).guard_manager
class ModuleForExport(torch.nn.Module):
def forward(self, x):
return types.FunctionType(dynamo_bytecode, new_globals)(x)
m = ModuleForExport()
return guard_manager, torch.export.export(m, args)
guards0, ep0 = _convert_to_ep_demo(
guarded_codes[0], backend_ids[0], gms[0], (torch.randn(3, 2),)
x = torch.randn(4, 3)
f_locals = {"x": x}
with (
compile_context(CompileContext(get_compile_id({}))),
dynamo_timed(""),
get_metrics_context(),
):
capture_output = fullgraph_capture(
FrameInfo(
foo.__code__,
foo.__globals__,
f_locals,
builtins,
(),
)
)
dynamo_output = capture_output.dynamo_output
backend_input = capture_output.backend_input
self.assertTrue(
dynamo_output.build_guards(foo.__code__).guard_manager.check(f_locals)
)
import_sources = {
alias: importlib.import_module(module_name)
for alias, module_name in dynamo_output.tracer_output.output_graph.import_sources.items()
}
self.assertEqual(
foo(x),
types.FunctionType(
dynamo_output.bytecode,
{
**import_sources,
backend_input.backend_id: backend_input.graph_module,
},
)(x),
)
self.assertTrue(guards0.check({"x": torch.randn(3, 2)}))
self.assertFalse(guards0.check({"x": torch.randn(4)}))
input0 = torch.randn(3, 2)
self.assertEqual(ep0.module()(input0), foo(input0))
def test_torch_guards_stack_frame_register_inlining_deep(self):
x = torch.tensor([0.5, 0.5])

View File

@@ -836,16 +836,180 @@ def trace_frame(
@dataclass
class DynamoOutput:
"""
Represents the core data returned from a single dynamo run, including:
- Guards, wrapped inside tracer_output.output_graph.guards
- Generated bytecode
- Other information needed for compilation.
This data structure should capture all the "interesting" information dynamo
produces on the frontend side before it enters the user backend.
"""
tracer_output: DynamoTracerOutput
bytecode: types.CodeType
last_attempt_start_time: Optional[float]
def build_guards(
self,
code: types.CodeType,
hooks: Optional[Hooks] = None,
save: bool = False,
cache_entry: Optional[CacheEntry] = None,
) -> CheckFunctionManager:
assert self.tracer_output.output_graph is not None
return CheckFunctionManager(
code,
self.tracer_output.output_graph,
cache_entry,
hooks.guard_fail_fn if hooks else None,
hooks.guard_filter_fn if hooks else None,
save_guards=save,
)
@dataclass
class BackendInput:
"""
Represents the core data structure that dynamo will pass to a backend, including:
- Graph module
- Example inputs
- The FakeTensorMode used for compiling graph.
This data structure should capture all the information dynamo produces
for the user backend.
"""
backend_id: str
graph_module: torch.fx.GraphModule
example_inputs: Any
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
@dataclass
class CaptureOutput:
"""
CaptureOutput should represent all the information produced by the torch
compiler for a single graph capture. It is intended to be consumed by
various compiler frontends so that we can share as much of the compiler
internals as possible and avoid divergence between different stacks.
This data structure should eventually contain all the information the
compiler produces, as more refactors happen to converge the different
compiler frontends.
"""
dynamo_output: DynamoOutput
backend_input: BackendInput
@dataclass
class FrameInfo:
code: types.CodeType
globals: dict[str, object]
locals: dict[str, object]
builtins: dict[str, object]
closure: tuple[CellType]
def fullgraph_capture(frame: FrameInfo) -> CaptureOutput:
"""
A standalone function which takes a frame and returns the dynamo-captured
graph plus other important compile information. This should serve as the common
interface for different torch compiler AOT frontends (e.g. precompile, export).
Note that this function doesn't apply context managers like metrics context
or compile id, and the expectation is that the caller will apply them depending
on the use case.
The CaptureOutput is separated into two parts:
1. Dynamo specific information from DynamoOutput, which includes:
- guards
- generated bytecode
- other information tracked by OutputGraph.
2. Backend specific information (indexed by unique backend id) such as:
- fx graph
- example inputs
"""
from torch._guards import TracingContext
backend_input: Optional[BackendInput] = None
def fullgraph_compiler(
gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
) -> torch.fx.GraphModule:
nonlocal backend_input
fake_mode = TracingContext.get().fake_mode
assert fake_mode is not None
assert isinstance(gm.meta["backend_id"], str)
backend_input = BackendInput(
gm.meta["backend_id"], gm, example_inputs, fake_mode
)
return gm
dynamo_output = compile_frame(
frame.code,
frame.globals,
frame.locals,
frame.builtins,
frame.closure,
compiler_fn=fullgraph_compiler,
one_graph=True,
restart_reasons=set(),
)
assert backend_input is not None
return CaptureOutput(dynamo_output, backend_input)
def compile_frame( # type: ignore[return]
code: types.CodeType,
transform: Callable[[list[Instruction], dict[str, Any]], DynamoTracerOutput],
globals: dict[str, object],
locals: dict[str, object],
builtins: dict[str, object],
closure: tuple[CellType],
compiler_fn: CompilerFn,
one_graph: bool,
restart_reasons: set[str],
*,
export: bool = False,
export_constraints: Optional[typing.Never] = None,
frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
distributed_state: Optional[DistributedState] = None,
package: Optional[CompilePackage] = None,
) -> DynamoOutput:
"""
A helper function that takes a frame and a backend, then returns the
generated bytecode and guards as a common data structure.
This is a shared interface for the multiple compiler frontends (e.g. torch.compile,
torch.export) that need to capture a graph out of Python code.
"""
# This is shared across restarts
speculation_log = SpeculationLog()
def transform(
instructions: list[Instruction], code_options: dict[str, object]
) -> DynamoTracerOutput:
tf_mode_stack: list[torch.overrides.TorchFunctionMode] = (
torch.overrides._get_current_function_mode_stack()
)
tracer_output = trace_frame(
code,
globals,
locals,
builtins,
closure,
compiler_fn,
tf_mode_stack,
one_graph,
speculation_log,
instructions,
code_options,
export=export,
export_constraints=export_constraints,
frame_state=frame_state,
distributed_state=distributed_state,
package=package,
)
assert tracer_output is not None
return tracer_output
last_attempt_start_time = None
for attempt in itertools.count():
CompileContext.get().attempt = attempt
@@ -926,40 +1090,9 @@ def _compile(
# Time spent compiling this frame before restarting or failing analysis
dynamo_time_before_restart: float = 0.0
def transform(
instructions: list[Instruction], code_options: dict[str, object]
) -> DynamoTracerOutput:
tf_mode_stack: list[torch.overrides.TorchFunctionMode] = (
torch.overrides._get_current_function_mode_stack()
)
tracer_output = trace_frame(
code,
globals,
locals,
builtins,
closure,
compiler_fn,
tf_mode_stack,
one_graph,
speculation_log,
instructions,
code_options,
export=export,
export_constraints=export_constraints,
frame_state=frame_state,
distributed_state=distributed_state,
package=package,
)
assert tracer_output is not None
return tracer_output
@compile_time_strobelight_meta(phase_name="compile_inner")
def compile_inner(
code: CodeType,
one_graph: bool,
hooks: Hooks,
transform: Callable[[list[Instruction], dict[str, Any]], Any],
code: CodeType, one_graph: bool, hooks: Hooks
) -> tuple[ConvertFrameReturn, Optional[DynamoTracerOutput]]:
with contextlib.ExitStack() as stack:
stack.enter_context(
@@ -968,7 +1101,7 @@ def _compile(
)
)
stack.enter_context(CompileTimeInstructionCounter.record())
return _compile_inner(code, one_graph, hooks, transform)
return _compile_inner(code, one_graph, hooks)
return (
ConvertFrameReturn(),
@@ -980,7 +1113,6 @@ def _compile(
code: CodeType,
one_graph: bool,
hooks: Hooks,
transform: Callable[[list[Instruction], dict[str, Any]], Any],
) -> tuple[ConvertFrameReturn, DynamoTracerOutput]:
nonlocal dynamo_time_before_restart
last_attempt_start_time = start_time = time.time()
@@ -1003,7 +1135,21 @@ def _compile(
out_code = None
try:
dynamo_output = compile_frame(code, transform, restart_reasons)
dynamo_output = compile_frame(
code,
globals,
locals,
builtins,
closure,
compiler_fn,
one_graph,
restart_reasons,
export=export,
export_constraints=export_constraints,
frame_state=frame_state,
distributed_state=distributed_state,
package=package,
)
except exc.SkipFrame as e:
if one_graph or _is_error_on_graph_break(e._torch_dynamo_tracer_output):
log.debug(
@@ -1091,13 +1237,11 @@ def _compile(
CleanupManager.instance[out_code] = output.cleanups
nonlocal cache_entry
with dynamo_timed("build_guards", log_pt2_compile_event=True):
check_fn = CheckFunctionManager(
check_fn = dynamo_output.build_guards(
code,
output,
cache_entry,
hooks.guard_fail_fn if hooks else None,
hooks.guard_filter_fn if hooks else None,
save_guards=True if package else False,
hooks=hooks,
save=package is not None,
cache_entry=cache_entry,
)
if package is not None:
@@ -1145,8 +1289,6 @@ def _compile(
code_context,
):
restart_reasons: set[str] = set()
# This is shared across restarts
speculation_log = SpeculationLog()
if compile_pg := get_compile_pg():
distributed_state = DistributedState(compile_pg, LocalState())
else:
@@ -1278,9 +1420,7 @@ def _compile(
torch._dynamo.utils.ReinplaceCounters.clear()
guarded_code = None
try:
guarded_code, tracer_output = compile_inner(
code, one_graph, hooks, transform
)
guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
# NB: We only put_code_state in success case. Success case here
# does include graph breaks; specifically, if a graph break still

View File

@@ -113,7 +113,7 @@ from .utils import (
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from torch._dynamo.package import CompilePackage, DynamoCaptureOutput
from torch._dynamo.package import CompilePackage
from torch._dynamo.repro.after_dynamo import WrapBackendDebug
from torch._subclasses import fake_tensor
from torch.fx.node import Argument, Node, Target
@@ -2288,83 +2288,3 @@ def skip_code(code: types.CodeType) -> None:
set_code_exec_strategy(
code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
)
@dataclass
class BackendInput:
graph_module: torch.fx.GraphModule
example_inputs: tuple[Any, ...]
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
@dataclass
class CaptureOutput:
"""
Core data structure that contains all the information dynamo generates
from fullgraph=True. Ideally, this should be the "return" type if dynamo
has a standard API to return compilation artifacts.
"""
dynamo_artifacts: DynamoCaptureOutput
backend_inputs: dict[str, BackendInput]
def fullgraph_capture(model: Callable[..., Any]) -> Callable[..., Any]:
"""
A helper function which wraps a model and returns a callable like optimize().
The callable can be called with normal inputs like torch.compile()-ed functions
and the user can dump dynamo compilation artifacts through the `get_artifacts()` call.
The CaptureOutput is separated into two parts:
1. Dynamo specific information from DynamoCaptureOutput, which includes:
- guards
- generated bytecode
- python source information
2. Backend specific information (indexed by unique backend id) such as:
- fx graph
- example inputs
Example:
def fn(*args):
...
compiled_fn = fullgraph_capture(fn)
compiled_fn(args)
compiled_fn(another_args)
artifacts = compiled_fn.get_artifacts()
"""
from torch._dynamo.package import CompilePackage
package = CompilePackage(model)
backend_inputs: dict[str, BackendInput] = {}
def _backend(
gm: torch.fx.GraphModule, example_inputs: tuple[Any, ...]
) -> torch.fx.GraphModule:
from torch._guards import TracingContext
fake_mode = TracingContext.get().fake_mode
assert fake_mode is not None
backend_id = gm._backend_id
assert isinstance(backend_id, str)
backend_inputs[backend_id] = BackendInput(gm, example_inputs, fake_mode)
return gm
# TODO For now we use eval_frame to give us the frame. This can be simplified to
# a manual frame creation helper.
optimized_model = optimize(nopython=True, backend=_backend, package=package)(model)
@functools.wraps(model)
def capture_context(*args: Any, **kwargs: Any) -> Any:
return optimized_model(*args, **kwargs)
def get_artifacts() -> CaptureOutput:
cache_entry = package.cache_entry()
assert len(cache_entry.codes) == 1
return CaptureOutput(
dynamo_artifacts=cache_entry.codes[0], backend_inputs=backend_inputs
)
capture_context.get_artifacts = get_artifacts # type: ignore[attr-defined]
return capture_context

View File

@@ -588,6 +588,9 @@ class OutputGraph(OutputGraphGuardsState):
self.maybe_install_saved_tensors_hooks_subgraphs()
)
# mangled alias -> module fqn name
self.import_sources: dict[str, str] = {}
def mark_bytecode_tracing_start(self) -> None:
self.compiler_trace_stack.enter_context(
dynamo_timed(
@@ -1785,6 +1788,7 @@ class OutputGraph(OutputGraphGuardsState):
self.dynamo_flat_name_to_original_fqn.copy()
)
gm.meta["dynamo_compile_id"] = self.dynamo_compile_id
gm.meta["backend_id"] = name
graph_code_log.debug(
"%s",

View File

@@ -1633,6 +1633,7 @@ class InstructionTranslatorBase(
if self.package is not None:
self.package.add_import_source(alias, module_name)
self.output.import_sources[alias] = module_name
f_globals = self.output.global_scope
assert alias not in f_globals or f_globals[alias] is value
f_globals[alias] = value