mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Revert "[dynamo] Refactor convert_frame.compile_frame to be self contained function. [5/n] (#160900)"
This reverts commit 447d34b5f8.
Reverted https://github.com/pytorch/pytorch/pull/160900 on behalf of https://github.com/atalman due to reverting since can't land existing diff internally, will need to reland it ([comment](https://github.com/pytorch/pytorch/pull/160900#issuecomment-3224029031))
This commit is contained in:
parent
8c506e6310
commit
e795450a35
|
|
@ -1,7 +1,6 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
# ruff: noqa: F841
|
||||
import abc
|
||||
import builtins
|
||||
import collections
|
||||
import collections.abc
|
||||
import copy
|
||||
|
|
@ -8563,52 +8562,47 @@ utils_device.CURRENT_DEVICE == None""".split("\n"):
|
|||
self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)")
|
||||
|
||||
def test_fullgraph_capture(self):
|
||||
from torch._dynamo.convert_frame import (
|
||||
FrameInfo,
|
||||
fullgraph_capture,
|
||||
get_compile_id,
|
||||
)
|
||||
from torch._dynamo.utils import dynamo_timed, get_metrics_context
|
||||
from torch._guards import compile_context, CompileContext
|
||||
|
||||
def foo(x):
|
||||
return x + x.shape[0]
|
||||
|
||||
x = torch.randn(4, 3)
|
||||
f_locals = {"x": x}
|
||||
with (
|
||||
compile_context(CompileContext(get_compile_id({}))),
|
||||
dynamo_timed(""),
|
||||
get_metrics_context(),
|
||||
):
|
||||
capture_output = fullgraph_capture(
|
||||
FrameInfo(
|
||||
foo.__code__,
|
||||
foo.__globals__,
|
||||
f_locals,
|
||||
builtins,
|
||||
(),
|
||||
)
|
||||
)
|
||||
dynamo_output = capture_output.dynamo_output
|
||||
backend_input = capture_output.backend_input
|
||||
self.assertTrue(
|
||||
dynamo_output.build_guards(foo.__code__).guard_manager.check(f_locals)
|
||||
)
|
||||
import_sources = {
|
||||
alias: importlib.import_module(module_name)
|
||||
for alias, module_name in dynamo_output.tracer_output.output_graph.import_sources.items()
|
||||
}
|
||||
self.assertEqual(
|
||||
foo(x),
|
||||
types.FunctionType(
|
||||
dynamo_output.bytecode,
|
||||
{
|
||||
**import_sources,
|
||||
backend_input.backend_id: backend_input.graph_module,
|
||||
},
|
||||
)(x),
|
||||
compiled_foo = torch._dynamo.eval_frame.fullgraph_capture(foo)
|
||||
compiled_foo(torch.randn(3, 2))
|
||||
compiled_foo(torch.randn(4))
|
||||
artifacts = compiled_foo.get_artifacts()
|
||||
|
||||
guarded_codes = artifacts.dynamo_artifacts.guarded_codes
|
||||
backend_ids = list(artifacts.backend_inputs.keys())
|
||||
gms = [b.graph_module for b in artifacts.backend_inputs.values()]
|
||||
|
||||
def _convert_to_ep_demo(code, backend_id, gm, args):
|
||||
# Inject compiled function as the original gm
|
||||
new_globals = copy.copy(globals())
|
||||
new_globals[backend_id] = gm
|
||||
# Minimal boilerplate to setup a callable.
|
||||
SerializedCode = type(code.dynamo_code)
|
||||
dynamo_bytecode = SerializedCode.to_code_object(code.dynamo_code)
|
||||
guards_state = pickle.loads(code.guards_state)
|
||||
guard_manager = torch._dynamo.guards.CheckFunctionManager(
|
||||
foo.__code__,
|
||||
guards_state.output_graph,
|
||||
shape_code_parts=guards_state.shape_code_parts,
|
||||
runtime_global_scope=new_globals,
|
||||
).guard_manager
|
||||
|
||||
class ModuleForExport(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return types.FunctionType(dynamo_bytecode, new_globals)(x)
|
||||
|
||||
m = ModuleForExport()
|
||||
return guard_manager, torch.export.export(m, args)
|
||||
|
||||
guards0, ep0 = _convert_to_ep_demo(
|
||||
guarded_codes[0], backend_ids[0], gms[0], (torch.randn(3, 2),)
|
||||
)
|
||||
self.assertTrue(guards0.check({"x": torch.randn(3, 2)}))
|
||||
self.assertFalse(guards0.check({"x": torch.randn(4)}))
|
||||
input0 = torch.randn(3, 2)
|
||||
self.assertEqual(ep0.module()(input0), foo(input0))
|
||||
|
||||
def test_torch_guards_stack_frame_register_inlining_deep(self):
|
||||
x = torch.tensor([0.5, 0.5])
|
||||
|
|
|
|||
|
|
@ -836,180 +836,16 @@ def trace_frame(
|
|||
|
||||
@dataclass
class DynamoOutput:
    """
    Represents the core data returned from a single dynamo run, including:
      - Guards, wrapped inside tracer_output.output_graph.guards
      - Generated bytecode
      - Other information needed for compilation.
    This data structure should capture all the "interesting" information dynamo
    produces on the frontend side before it enters the user backend.
    """

    # Tracing result; build_guards() reads its ``output_graph`` attribute.
    tracer_output: DynamoTracerOutput
    # Bytecode generated by dynamo for the traced frame.
    bytecode: types.CodeType
    # NOTE(review): presumably the start timestamp of the last tracing
    # attempt -- confirm against the producer (compile_frame).
    last_attempt_start_time: Optional[float]

    def build_guards(
        self,
        code: types.CodeType,
        hooks: Optional[Hooks] = None,
        save: bool = False,
        cache_entry: Optional[CacheEntry] = None,
    ) -> CheckFunctionManager:
        """
        Build a CheckFunctionManager for ``code`` from the traced output graph.

        ``hooks`` is optional; when given, its ``guard_fail_fn`` and
        ``guard_filter_fn`` are forwarded, otherwise ``None`` is passed for
        both. ``save`` is forwarded as ``save_guards`` and ``cache_entry`` is
        forwarded unchanged.
        """
        assert self.tracer_output.output_graph is not None
        return CheckFunctionManager(
            code,
            self.tracer_output.output_graph,
            cache_entry,
            hooks.guard_fail_fn if hooks else None,
            hooks.guard_filter_fn if hooks else None,
            save_guards=save,
        )
|
||||
|
||||
|
||||
@dataclass
class BackendInput:
    """
    Represents the core data structure that dynamo will pass to a backend, including:
      - Graph module
      - Example inputs
      - The FakeTensorMode used for compiling the graph.
    This data structure should capture all the information dynamo produces
    for the user backend.
    """

    # Identifier of the backend call this graph belongs to.
    backend_id: str
    # The captured FX graph.
    graph_module: torch.fx.GraphModule
    # Example inputs handed to the backend together with the graph.
    example_inputs: Any
    # Fake tensor mode that was active while the graph was captured.
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
|
||||
|
||||
|
||||
@dataclass
class CaptureOutput:
    """
    CaptureOutput should represent all the information produced from torch
    compiler for a single graph capture. This is intended to be consumed by
    various compiler frontends so that we can share as much compiler internals
    as possible and avoid great divergence between different stacks.
    This data structure should eventually contain all the information the
    compiler produces as more refactors happen to converge different compiler
    frontends.
    """

    # Frontend (dynamo) side of the capture.
    dynamo_output: DynamoOutput
    # Data handed to the backend for the captured graph.
    backend_input: BackendInput
|
||||
|
||||
|
||||
@dataclass
class FrameInfo:
    """
    Plain description of a Python frame to capture: its code object together
    with the globals, locals, builtins and closure cells it runs with.
    """

    code: types.CodeType
    globals: dict[str, object]
    locals: dict[str, object]
    builtins: dict[str, object]
    # Variable-length: a function may have zero or many closure cells, so this
    # must not be annotated as the 1-tuple ``tuple[CellType]``.
    closure: tuple[CellType, ...]
|
||||
|
||||
|
||||
def fullgraph_capture(frame: FrameInfo) -> CaptureOutput:
    """
    A standalone function which takes a frame and returns the dynamo captured
    graph plus other important compile information. This should serve as the
    common interface for different torch compiler AOT frontends (e.g.
    precompile, export).
    Note that this function doesn't apply context managers like metrics context
    or compile id, and the expectation is that the caller will apply them
    depending on the use case.

    The CaptureOutput is separated into two parts:
    1. Dynamo specific information from DynamoOutput, which includes:
        - guards
        - generated bytecode
        - other information tracked by OutputGraph.
    2. Backend specific information (tagged with its unique backend id) such as:
        - fx graph
        - example inputs
    """
    from torch._guards import TracingContext

    # Filled in by fullgraph_compiler when the backend is invoked.
    backend_input: Optional[BackendInput] = None

    def fullgraph_compiler(
        gm: torch.fx.GraphModule, example_inputs: list[torch.Tensor]
    ) -> torch.fx.GraphModule:
        # Record what the backend would have received and hand the graph
        # module back unchanged.
        nonlocal backend_input
        fake_mode = TracingContext.get().fake_mode
        assert fake_mode is not None
        assert isinstance(gm.meta["backend_id"], str)
        backend_input = BackendInput(
            gm.meta["backend_id"], gm, example_inputs, fake_mode
        )
        return gm

    dynamo_output = compile_frame(
        frame.code,
        frame.globals,
        frame.locals,
        frame.builtins,
        frame.closure,
        compiler_fn=fullgraph_compiler,
        one_graph=True,
        restart_reasons=set(),
    )
    # The backend must have been called for a capture to exist.
    assert backend_input is not None
    return CaptureOutput(dynamo_output, backend_input)
|
||||
|
||||
|
||||
def compile_frame( # type: ignore[return]
|
||||
code: types.CodeType,
|
||||
globals: dict[str, object],
|
||||
locals: dict[str, object],
|
||||
builtins: dict[str, object],
|
||||
closure: tuple[CellType],
|
||||
compiler_fn: CompilerFn,
|
||||
one_graph: bool,
|
||||
transform: Callable[[list[Instruction], dict[str, Any]], DynamoTracerOutput],
|
||||
restart_reasons: set[str],
|
||||
*,
|
||||
export: bool = False,
|
||||
export_constraints: Optional[typing.Never] = None,
|
||||
frame_state: Optional[dict[str, Union[int, FrameStateSizeEntry]]] = None,
|
||||
distributed_state: Optional[DistributedState] = None,
|
||||
package: Optional[CompilePackage] = None,
|
||||
) -> DynamoOutput:
|
||||
"""
|
||||
A helper function taking a frame and backend, then return the generated bytecode
|
||||
and guards as a common data structure.
|
||||
This is a shared interface for multiple compiler frontends (e.g. torch.compile,
|
||||
torch.export) that needs to capture a graph out of python code.
|
||||
"""
|
||||
# This is shared across restarts
|
||||
speculation_log = SpeculationLog()
|
||||
|
||||
def transform(
|
||||
instructions: list[Instruction], code_options: dict[str, object]
|
||||
) -> DynamoTracerOutput:
|
||||
tf_mode_stack: list[torch.overrides.TorchFunctionMode] = (
|
||||
torch.overrides._get_current_function_mode_stack()
|
||||
)
|
||||
tracer_output = trace_frame(
|
||||
code,
|
||||
globals,
|
||||
locals,
|
||||
builtins,
|
||||
closure,
|
||||
compiler_fn,
|
||||
tf_mode_stack,
|
||||
one_graph,
|
||||
speculation_log,
|
||||
instructions,
|
||||
code_options,
|
||||
export=export,
|
||||
export_constraints=export_constraints,
|
||||
frame_state=frame_state,
|
||||
distributed_state=distributed_state,
|
||||
package=package,
|
||||
)
|
||||
|
||||
assert tracer_output is not None
|
||||
return tracer_output
|
||||
|
||||
last_attempt_start_time = None
|
||||
for attempt in itertools.count():
|
||||
CompileContext.get().attempt = attempt
|
||||
|
|
@ -1090,9 +926,40 @@ def _compile(
|
|||
# Time spent compiling this frame before restarting or failing analysis
|
||||
dynamo_time_before_restart: float = 0.0
|
||||
|
||||
def transform(
    instructions: list[Instruction], code_options: dict[str, object]
) -> DynamoTracerOutput:
    """
    Trace the frame for the given instructions and return the tracer output.

    Every argument passed to trace_frame other than ``instructions``,
    ``code_options`` and the torch function mode stack is read from the
    enclosing scope (the hunk header places this definition inside
    ``_compile``).
    """
    # Snapshot of the current torch function mode stack, forwarded to the tracer.
    tf_mode_stack: list[torch.overrides.TorchFunctionMode] = (
        torch.overrides._get_current_function_mode_stack()
    )
    tracer_output = trace_frame(
        code,
        globals,
        locals,
        builtins,
        closure,
        compiler_fn,
        tf_mode_stack,
        one_graph,
        speculation_log,
        instructions,
        code_options,
        export=export,
        export_constraints=export_constraints,
        frame_state=frame_state,
        distributed_state=distributed_state,
        package=package,
    )

    assert tracer_output is not None
    return tracer_output
|
||||
|
||||
@compile_time_strobelight_meta(phase_name="compile_inner")
|
||||
def compile_inner(
|
||||
code: CodeType, one_graph: bool, hooks: Hooks
|
||||
code: CodeType,
|
||||
one_graph: bool,
|
||||
hooks: Hooks,
|
||||
transform: Callable[[list[Instruction], dict[str, Any]], Any],
|
||||
) -> tuple[ConvertFrameReturn, Optional[DynamoTracerOutput]]:
|
||||
with contextlib.ExitStack() as stack:
|
||||
stack.enter_context(
|
||||
|
|
@ -1101,7 +968,7 @@ def _compile(
|
|||
)
|
||||
)
|
||||
stack.enter_context(CompileTimeInstructionCounter.record())
|
||||
return _compile_inner(code, one_graph, hooks)
|
||||
return _compile_inner(code, one_graph, hooks, transform)
|
||||
|
||||
return (
|
||||
ConvertFrameReturn(),
|
||||
|
|
@ -1113,6 +980,7 @@ def _compile(
|
|||
code: CodeType,
|
||||
one_graph: bool,
|
||||
hooks: Hooks,
|
||||
transform: Callable[[list[Instruction], dict[str, Any]], Any],
|
||||
) -> tuple[ConvertFrameReturn, DynamoTracerOutput]:
|
||||
nonlocal dynamo_time_before_restart
|
||||
last_attempt_start_time = start_time = time.time()
|
||||
|
|
@ -1135,21 +1003,7 @@ def _compile(
|
|||
|
||||
out_code = None
|
||||
try:
|
||||
dynamo_output = compile_frame(
|
||||
code,
|
||||
globals,
|
||||
locals,
|
||||
builtins,
|
||||
closure,
|
||||
compiler_fn,
|
||||
one_graph,
|
||||
restart_reasons,
|
||||
export=export,
|
||||
export_constraints=export_constraints,
|
||||
frame_state=frame_state,
|
||||
distributed_state=distributed_state,
|
||||
package=package,
|
||||
)
|
||||
dynamo_output = compile_frame(code, transform, restart_reasons)
|
||||
except exc.SkipFrame as e:
|
||||
if one_graph or _is_error_on_graph_break(e._torch_dynamo_tracer_output):
|
||||
log.debug(
|
||||
|
|
@ -1237,11 +1091,13 @@ def _compile(
|
|||
CleanupManager.instance[out_code] = output.cleanups
|
||||
nonlocal cache_entry
|
||||
with dynamo_timed("build_guards", log_pt2_compile_event=True):
|
||||
check_fn = dynamo_output.build_guards(
|
||||
check_fn = CheckFunctionManager(
|
||||
code,
|
||||
hooks=hooks,
|
||||
save=package is not None,
|
||||
cache_entry=cache_entry,
|
||||
output,
|
||||
cache_entry,
|
||||
hooks.guard_fail_fn if hooks else None,
|
||||
hooks.guard_filter_fn if hooks else None,
|
||||
save_guards=True if package else False,
|
||||
)
|
||||
|
||||
if package is not None:
|
||||
|
|
@ -1289,6 +1145,8 @@ def _compile(
|
|||
code_context,
|
||||
):
|
||||
restart_reasons: set[str] = set()
|
||||
# This is shared across restarts
|
||||
speculation_log = SpeculationLog()
|
||||
if compile_pg := get_compile_pg():
|
||||
distributed_state = DistributedState(compile_pg, LocalState())
|
||||
else:
|
||||
|
|
@ -1420,7 +1278,9 @@ def _compile(
|
|||
torch._dynamo.utils.ReinplaceCounters.clear()
|
||||
guarded_code = None
|
||||
try:
|
||||
guarded_code, tracer_output = compile_inner(code, one_graph, hooks)
|
||||
guarded_code, tracer_output = compile_inner(
|
||||
code, one_graph, hooks, transform
|
||||
)
|
||||
|
||||
# NB: We only put_code_state in success case. Success case here
|
||||
# does include graph breaks; specifically, if a graph break still
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ from .utils import (
|
|||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Sequence
|
||||
|
||||
from torch._dynamo.package import CompilePackage
|
||||
from torch._dynamo.package import CompilePackage, DynamoCaptureOutput
|
||||
from torch._dynamo.repro.after_dynamo import WrapBackendDebug
|
||||
from torch._subclasses import fake_tensor
|
||||
from torch.fx.node import Argument, Node, Target
|
||||
|
|
@ -2295,3 +2295,83 @@ def skip_code(code: types.CodeType) -> None:
|
|||
set_code_exec_strategy(
|
||||
code, FrameExecStrategy(FrameAction.SKIP, FrameAction.DEFAULT)
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class BackendInput:
    """
    What dynamo handed to the backend for one graph: the FX graph module, the
    example inputs, and the fake tensor mode active during capture.
    """

    graph_module: torch.fx.GraphModule
    example_inputs: tuple[Any, ...]
    fake_mode: torch._subclasses.fake_tensor.FakeTensorMode
|
||||
|
||||
|
||||
@dataclass
class CaptureOutput:
    """
    Core data structure that contains all the information dynamo generates
    from fullgraph=True. Ideally, this should be the "return" type if dynamo
    had a standard API to return compilation artifacts.
    """

    # Dynamo-side artifacts for the captured code.
    dynamo_artifacts: DynamoCaptureOutput
    # Backend inputs, keyed by backend id.
    backend_inputs: dict[str, BackendInput]
|
||||
|
||||
|
||||
def fullgraph_capture(model: Callable[..., Any]) -> Callable[..., Any]:
    """
    A helper function which wraps a model and returns a callable like optimize().
    The callable can be called with normal inputs like torch.compile()-ed functions
    and the user can dump dynamo compilation artifacts through the
    `get_artifacts()` call.

    The CaptureOutput is separated into two parts:
    1. Dynamo specific information from DynamoCaptureOutput, which includes:
        - guards
        - generated bytecode
        - python source information
    2. Backend specific information (indexed by unique backend id) such as:
        - fx graph
        - example inputs

    Example:
        def fn(*args):
            ...

        compiled_fn = fullgraph_capture(fn)
        compiled_fn(args)
        compiled_fn(another_args)
        artifacts = compiled_fn.get_artifacts()
    """
    from torch._dynamo.package import CompilePackage

    package = CompilePackage(model)

    # Populated by _backend every time dynamo calls it, keyed by backend id.
    backend_inputs: dict[str, BackendInput] = {}

    def _backend(
        gm: torch.fx.GraphModule, example_inputs: tuple[Any, ...]
    ) -> torch.fx.GraphModule:
        # Record the backend input and return the graph module unchanged.
        from torch._guards import TracingContext

        fake_mode = TracingContext.get().fake_mode
        assert fake_mode is not None
        backend_id = gm._backend_id
        assert isinstance(backend_id, str)
        backend_inputs[backend_id] = BackendInput(gm, example_inputs, fake_mode)
        return gm

    # TODO For now we use eval_frame to give us the frame. This can be simplified to
    # a manual frame creation helper.
    optimized_model = optimize(nopython=True, backend=_backend, package=package)(model)

    @functools.wraps(model)
    def capture_context(*args: Any, **kwargs: Any) -> Any:
        return optimized_model(*args, **kwargs)

    def get_artifacts() -> CaptureOutput:
        # Exactly one entry is expected in the package's cache entry codes.
        cache_entry = package.cache_entry()
        assert len(cache_entry.codes) == 1
        return CaptureOutput(
            dynamo_artifacts=cache_entry.codes[0], backend_inputs=backend_inputs
        )

    capture_context.get_artifacts = get_artifacts  # type: ignore[attr-defined]
    return capture_context
|
||||
|
|
|
|||
|
|
@ -596,9 +596,6 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
self.maybe_install_saved_tensors_hooks_subgraphs()
|
||||
)
|
||||
|
||||
# mangled alias -> module fqn name
|
||||
self.import_sources: dict[str, str] = {}
|
||||
|
||||
def mark_bytecode_tracing_start(self) -> None:
|
||||
self.compiler_trace_stack.enter_context(
|
||||
dynamo_timed(
|
||||
|
|
@ -1906,7 +1903,6 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
self.dynamo_flat_name_to_original_fqn.copy()
|
||||
)
|
||||
gm.meta["dynamo_compile_id"] = self.dynamo_compile_id
|
||||
gm.meta["backend_id"] = name
|
||||
|
||||
graph_code_log.debug(
|
||||
"%s",
|
||||
|
|
|
|||
|
|
@ -1685,7 +1685,6 @@ class InstructionTranslatorBase(
|
|||
|
||||
if self.package is not None:
|
||||
self.package.add_import_source(alias, module_name)
|
||||
self.output.import_sources[alias] = module_name
|
||||
f_globals = self.output.global_scope
|
||||
assert alias not in f_globals or f_globals[alias] is value
|
||||
f_globals[alias] = value
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user