mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
New export implementation with flat inp/out (#162167)
This is my first attempt of building new export API. The main thing it addresses is correctly getting input and output relations. Subsequent diffs willl add functionality for dynamic shapes, nn_module_stack etc. Differential Revision: [D81793205](https://our.internmc.facebook.com/intern/diff/D81793205) Pull Request resolved: https://github.com/pytorch/pytorch/pull/162167 Approved by: https://github.com/zhxchen17, https://github.com/avikchaudhuri
This commit is contained in:
parent
ae0edc133e
commit
047603d35b
|
|
@ -1,5 +1,6 @@
|
|||
# Owner(s): ["oncall: export"]
|
||||
# flake8: noqa
|
||||
import copy
|
||||
import types
|
||||
import unittest
|
||||
from typing import Dict, List, Tuple
|
||||
|
|
@ -351,6 +352,62 @@ def forward(self, x):
|
|||
res2 = p.generate(input_tensor=inp, input_tensor2=inp2)
|
||||
self.assertTrue(torch.allclose(res, res2))
|
||||
|
||||
def test_export_add_in_out_info(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, dct, lst, bleh):
|
||||
x = dct["a"] * lst[1][0]
|
||||
y = dct["b"] * lst[0]
|
||||
out_dict = {}
|
||||
# Mutate and get a new entry in there
|
||||
lst_copy = lst.copy()
|
||||
lst_copy.append(lst[0])
|
||||
out_dict["a"] = x
|
||||
out_dict["b"] = y
|
||||
return (
|
||||
dct["a"],
|
||||
out_dict["b"],
|
||||
bleh,
|
||||
lst_copy[-1],
|
||||
out_dict["a"],
|
||||
[5, 6],
|
||||
)
|
||||
|
||||
dct = {"a": torch.randn(2, 3), "b": torch.randn(2, 3)}
|
||||
lst = [torch.randn(2, 3), [torch.randn(2, 3), torch.randn(2, 3)]]
|
||||
|
||||
export_inputs = ((dct, lst, 56), {})
|
||||
eager_inputs = copy.deepcopy(export_inputs)
|
||||
|
||||
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
||||
|
||||
graph_module = _dynamo_graph_capture_for_export(Foo())(
|
||||
*export_inputs[0], **export_inputs[1]
|
||||
)
|
||||
|
||||
res_export = graph_module(*export_inputs[0], **export_inputs[1])
|
||||
res_eager = Foo()(*eager_inputs[0], **eager_inputs[1])
|
||||
|
||||
self.assertEqual(res_export, res_eager)
|
||||
|
||||
def test_export_leaf(self):
|
||||
class Foo(torch.nn.Module):
|
||||
def forward(self, x):
|
||||
return x.sin()
|
||||
|
||||
export_inputs = ((torch.randn(4, 4),), {})
|
||||
eager_inputs = copy.deepcopy(export_inputs)
|
||||
|
||||
from torch._dynamo.functional_export import _dynamo_graph_capture_for_export
|
||||
|
||||
graph_module = _dynamo_graph_capture_for_export(Foo())(
|
||||
*export_inputs[0], **export_inputs[1]
|
||||
)
|
||||
|
||||
res_export = graph_module(*export_inputs[0], **export_inputs[1])
|
||||
res_eager = Foo()(*eager_inputs[0], **eager_inputs[1])
|
||||
|
||||
self.assertEqual(res_export, res_eager)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -10,7 +10,14 @@ seamlessly optimize PyTorch programs, including those using modern Python featur
|
|||
|
||||
import torch
|
||||
|
||||
from . import aot_compile, config, convert_frame, eval_frame, resume_execution
|
||||
from . import (
|
||||
aot_compile,
|
||||
config,
|
||||
convert_frame,
|
||||
eval_frame,
|
||||
functional_export,
|
||||
resume_execution,
|
||||
)
|
||||
from .backends.registry import list_backends, lookup_backend, register_backend
|
||||
from .callback import callback_handler, on_compile_end, on_compile_start
|
||||
from .code_context import code_context
|
||||
|
|
|
|||
|
|
@ -110,6 +110,12 @@ automatic_dynamic_shapes = True
|
|||
# Valid options: "dynamic", "unbacked"
|
||||
automatic_dynamic_shapes_mark_as: Literal["dynamic", "unbacked"] = "dynamic"
|
||||
|
||||
# log graph in/out metadata
|
||||
# This is only turned on for export today since we
|
||||
# know we are tracing a flat callable. later, this
|
||||
# can extended to other use cases as well.
|
||||
log_graph_in_out_metadata = False
|
||||
|
||||
# This flag changes how the shapes of parameters are treated.
|
||||
# If this flag is set to True, then the shapes of torch.nn.Parameter as well as of torch.Tensor are attempted to be dynamic
|
||||
# If this flag is set to False, then the shapes of torch.nn.Parameter are assumed to be static,
|
||||
|
|
|
|||
|
|
@ -906,7 +906,9 @@ class FrameInfo:
|
|||
closure: tuple[CellType]
|
||||
|
||||
|
||||
def fullgraph_capture(frame: FrameInfo) -> CaptureOutput:
|
||||
def fullgraph_capture(
|
||||
frame: FrameInfo, *, _is_export_deprecated_do_not_use: bool = False
|
||||
) -> CaptureOutput:
|
||||
"""
|
||||
A standalone function which takes a frame and returns dynamo captured graph
|
||||
plus other important compile information. This should serve as the common
|
||||
|
|
@ -948,6 +950,7 @@ def fullgraph_capture(frame: FrameInfo) -> CaptureOutput:
|
|||
frame.builtins,
|
||||
frame.closure,
|
||||
compiler_fn=fullgraph_compiler,
|
||||
export=_is_export_deprecated_do_not_use,
|
||||
one_graph=True,
|
||||
restart_reasons=set(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1124,6 +1124,89 @@ class _NullDecorator(contextlib.nullcontext): # type: ignore[type-arg]
|
|||
return fn
|
||||
|
||||
|
||||
# Make dynamo graph to have same input/output spec as user code
|
||||
def argument_names(
|
||||
f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
|
||||
) -> list[str]:
|
||||
def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
|
||||
# Get a list of Parameter objects from the Signature object
|
||||
params = list(sig.parameters.values())
|
||||
# Separate positional arguments, keyword-only arguments and varargs/varkw
|
||||
args = [
|
||||
p.name for p in params if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
]
|
||||
kwonlyargs = [
|
||||
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
]
|
||||
varargs = next(
|
||||
(p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
|
||||
None,
|
||||
)
|
||||
varkw = next(
|
||||
(p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
|
||||
None,
|
||||
)
|
||||
# Get default values for positional arguments and keyword-only arguments
|
||||
defaults = tuple(
|
||||
p.default
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
and p.default is not inspect.Parameter.empty
|
||||
)
|
||||
kwonlydefaults = {
|
||||
p.name: p.default
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
and p.default is not inspect.Parameter.empty
|
||||
}
|
||||
# Get annotations for parameters and return value
|
||||
annotations = {}
|
||||
if sig.return_annotation:
|
||||
annotations = {"return": sig.return_annotation}
|
||||
for parameter in params:
|
||||
annotations[parameter.name] = parameter.annotation
|
||||
# Return a FullArgSpec object with the extracted attributes
|
||||
return inspect.FullArgSpec(
|
||||
args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
|
||||
)
|
||||
|
||||
fullargspec = signature_to_fullargspec(f_sig)
|
||||
|
||||
# 1. Map `args` 1-to-1 to positional arguments in original signature.
|
||||
input_strs = fullargspec.args[: len(args)]
|
||||
|
||||
if len(args) > len(fullargspec.args):
|
||||
# 2. If there are more arguments left in `args`, they map to varargs in original
|
||||
# signature. Assign names as {varargs}_0, {varargs}_1, ...
|
||||
assert fullargspec.varargs is not None, "More arguments than expected"
|
||||
input_strs += [
|
||||
f"{fullargspec.varargs}_{i}" for i in range(0, len(args) - len(input_strs))
|
||||
]
|
||||
elif len(args) < len(fullargspec.args):
|
||||
# 3. If there are fewer arguments in `args` than `fullargspec.args`,
|
||||
# it implies these are arguments either with default values, or provided in
|
||||
# `kwargs`. The former can be safely ignored. Because Dynamo.export does not
|
||||
# export them as part of the function signature. The latter will be handled
|
||||
# in the next step.
|
||||
for unprovided_arg in fullargspec.args[
|
||||
len(args) : -len(fullargspec.defaults or [])
|
||||
]:
|
||||
assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
|
||||
|
||||
# 4. Keyword arguments provided in `kwargs`.
|
||||
input_strs += list(kwargs.keys())
|
||||
|
||||
# 5. Keyword-only arguments with default values if not provided are not exported
|
||||
# as part of the function signature.
|
||||
for kwonly_arg in fullargspec.kwonlyargs:
|
||||
kwonlydefaults = fullargspec.kwonlydefaults or {}
|
||||
assert kwonly_arg in kwargs or kwonly_arg in kwonlydefaults, (
|
||||
f"Missing keyword only argument {kwonly_arg}"
|
||||
)
|
||||
|
||||
return input_strs
|
||||
|
||||
|
||||
def check_if_dynamo_supported() -> None:
|
||||
if sys.version_info >= (3, 14):
|
||||
raise RuntimeError("Python 3.14+ not yet supported for torch.compile")
|
||||
|
|
@ -1650,91 +1733,6 @@ def rewrite_signature(
|
|||
fake_mode,
|
||||
).transform()
|
||||
|
||||
# Make dynamo graph to have same input/output spec as user code
|
||||
def argument_names(
|
||||
f_sig: inspect.Signature, args: list[Any], kwargs: dict[str, Any]
|
||||
) -> list[str]:
|
||||
def signature_to_fullargspec(sig: inspect.Signature) -> inspect.FullArgSpec:
|
||||
# Get a list of Parameter objects from the Signature object
|
||||
params = list(sig.parameters.values())
|
||||
# Separate positional arguments, keyword-only arguments and varargs/varkw
|
||||
args = [
|
||||
p.name
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
]
|
||||
kwonlyargs = [
|
||||
p.name for p in params if p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
]
|
||||
varargs = next(
|
||||
(p.name for p in params if p.kind == inspect.Parameter.VAR_POSITIONAL),
|
||||
None,
|
||||
)
|
||||
varkw = next(
|
||||
(p.name for p in params if p.kind == inspect.Parameter.VAR_KEYWORD),
|
||||
None,
|
||||
)
|
||||
# Get default values for positional arguments and keyword-only arguments
|
||||
defaults = tuple(
|
||||
p.default
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
|
||||
and p.default is not inspect.Parameter.empty
|
||||
)
|
||||
kwonlydefaults = {
|
||||
p.name: p.default
|
||||
for p in params
|
||||
if p.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
and p.default is not inspect.Parameter.empty
|
||||
}
|
||||
# Get annotations for parameters and return value
|
||||
annotations = {}
|
||||
if sig.return_annotation:
|
||||
annotations = {"return": sig.return_annotation}
|
||||
for parameter in params:
|
||||
annotations[parameter.name] = parameter.annotation
|
||||
# Return a FullArgSpec object with the extracted attributes
|
||||
return inspect.FullArgSpec(
|
||||
args, varargs, varkw, defaults, kwonlyargs, kwonlydefaults, annotations
|
||||
)
|
||||
|
||||
fullargspec = signature_to_fullargspec(f_sig)
|
||||
|
||||
# 1. Map `args` 1-to-1 to positional arguments in original signature.
|
||||
input_strs = fullargspec.args[: len(args)]
|
||||
|
||||
if len(args) > len(fullargspec.args):
|
||||
# 2. If there are more arguments left in `args`, they map to varargs in original
|
||||
# signature. Assign names as {varargs}_0, {varargs}_1, ...
|
||||
assert fullargspec.varargs is not None, "More arguments than expected"
|
||||
input_strs += [
|
||||
f"{fullargspec.varargs}_{i}"
|
||||
for i in range(0, len(args) - len(input_strs))
|
||||
]
|
||||
elif len(args) < len(fullargspec.args):
|
||||
# 3. If there are fewer arguments in `args` than `fullargspec.args`,
|
||||
# it implies these are arguments either with default values, or provided in
|
||||
# `kwargs`. The former can be safely ignored. Because Dynamo.export does not
|
||||
# export them as part of the function signature. The latter will be handled
|
||||
# in the next step.
|
||||
for unprovided_arg in fullargspec.args[
|
||||
len(args) : -len(fullargspec.defaults or [])
|
||||
]:
|
||||
assert unprovided_arg in kwargs, f"Missing argument {unprovided_arg}"
|
||||
|
||||
# 4. Keyword arguments provided in `kwargs`.
|
||||
input_strs += list(kwargs.keys())
|
||||
|
||||
# 5. Keyword-only arguments with default values if not provided are not exported
|
||||
# as part of the function signature.
|
||||
for kwonly_arg in fullargspec.kwonlyargs:
|
||||
kwonlydefaults = fullargspec.kwonlydefaults or {}
|
||||
assert kwonly_arg in kwargs or kwonly_arg in kwonlydefaults, (
|
||||
f"Missing keyword only argument {kwonly_arg}"
|
||||
)
|
||||
|
||||
return input_strs
|
||||
|
||||
new_graph.graph._codegen = _PyTreeCodeGen(
|
||||
_PyTreeInfo(
|
||||
argument_names(f_sig, orig_args, orig_kwargs),
|
||||
|
|
|
|||
142
torch/_dynamo/functional_export.py
Normal file
142
torch/_dynamo/functional_export.py
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
import builtins
|
||||
import inspect
|
||||
from collections import namedtuple
|
||||
from typing import Any, Callable
|
||||
|
||||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
from torch._dynamo.convert_frame import FrameInfo, fullgraph_capture, get_compile_id
|
||||
from torch._dynamo.eval_frame import argument_names
|
||||
from torch._dynamo.utils import dynamo_timed, get_metrics_context
|
||||
from torch._guards import compile_context, CompileContext
|
||||
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||
|
||||
|
||||
class ModuleToTrace(torch.nn.Module):
|
||||
def __init__(self, foo: Any, in_spec: Any) -> None:
|
||||
super().__init__()
|
||||
self._export_root = foo
|
||||
self.in_spec = in_spec
|
||||
|
||||
def forward(self, *flat_args: Any) -> "ExportTracerOutput":
|
||||
args, kwargs = pytree.tree_unflatten(flat_args, self.in_spec)
|
||||
res = self._export_root(*args, **kwargs)
|
||||
out_flat, out_spec = pytree.tree_flatten(res)
|
||||
return ExportTracerOutput(out_flat, out_spec)
|
||||
|
||||
|
||||
ExportTracerOutput = namedtuple("ExportTracerOutput", ["flat_args", "out_spec"])
|
||||
|
||||
|
||||
def _dynamo_graph_capture_for_export(
|
||||
mod: torch.nn.Module,
|
||||
) -> Callable[..., torch.fx.GraphModule]:
|
||||
"""
|
||||
This is lower level API that is used for export to capture dynamo level
|
||||
torch IR.
|
||||
|
||||
Notable TODOs:
|
||||
1. Are we actually gonna run the bytecode?
|
||||
2. Need to attach guards
|
||||
"""
|
||||
|
||||
def inner(*args: Any, **kwargs: Any) -> torch.fx.GraphModule:
|
||||
flat_inputs, in_spec = pytree.tree_flatten((args, kwargs))
|
||||
module_to_trace = ModuleToTrace(mod, in_spec)
|
||||
|
||||
signature = inspect.signature(module_to_trace.forward)
|
||||
|
||||
bound_arguments = signature.bind(*flat_inputs)
|
||||
bound_arguments.apply_defaults()
|
||||
|
||||
f_locals = {"self": module_to_trace, **bound_arguments.arguments}
|
||||
|
||||
frame = FrameInfo(
|
||||
module_to_trace.forward.__func__.__code__, # type: ignore[attr-defined]
|
||||
module_to_trace.forward.__func__.__globals__, # type: ignore[attr-defined]
|
||||
f_locals,
|
||||
builtins, # type: ignore[arg-type]
|
||||
closure=(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
dynamo_config_ctx = torch._dynamo.config.patch(
|
||||
"log_graph_in_out_metadata", True
|
||||
)
|
||||
|
||||
with (
|
||||
compile_context(CompileContext(get_compile_id({}))),
|
||||
get_metrics_context(),
|
||||
dynamo_timed("fullgraph_capture"),
|
||||
dynamo_config_ctx,
|
||||
):
|
||||
out = fullgraph_capture(frame, _is_export_deprecated_do_not_use=True)
|
||||
|
||||
assert out.dynamo_output.tracer_output.output_graph is not None
|
||||
|
||||
export_metadata = (
|
||||
out.dynamo_output.tracer_output.output_graph.export_metadata
|
||||
)
|
||||
graph_inputs = export_metadata.graph_input_idx_to_local_source
|
||||
output_return_type = export_metadata.output_return_type
|
||||
# We need to extract out_spec here because we are not actually running the bytecode
|
||||
out_spec = export_metadata.out_spec
|
||||
|
||||
graph = out.backend_input.graph_module
|
||||
|
||||
# It is not guaranteed that dynamo puts inputs in right order, so we need to
|
||||
# map the actual user order to the dynamo order.
|
||||
graph_input_order: dict[int, int] = {}
|
||||
for inp in graph_inputs:
|
||||
source = graph_inputs[inp]
|
||||
assert isinstance(source, torch._dynamo.source.GetItemSource)
|
||||
graph_input_order[source.index] = len(graph_input_order)
|
||||
|
||||
placeholders = [n for n in list(graph.graph.nodes) if n.op == "placeholder"]
|
||||
output = next(n for n in list(graph.graph.nodes) if n.op == "output")
|
||||
# Sometimes there can be empty inputs
|
||||
anchor = placeholders[0] if len(placeholders) > 0 else output
|
||||
inp_to_node = {}
|
||||
|
||||
with graph.graph.inserting_before(anchor):
|
||||
for i in range(len(flat_inputs)):
|
||||
node_new = graph.graph.placeholder(f"arg_{i}")
|
||||
if i in graph_input_order:
|
||||
placeholders[graph_input_order[i]]
|
||||
node_new.meta = placeholders[graph_input_order[i]].meta.copy()
|
||||
inp_to_node[i] = node_new
|
||||
|
||||
new_args = []
|
||||
for i in output_return_type:
|
||||
type, val = output_return_type[i]
|
||||
if type == "graph_out":
|
||||
new_args.append(output.args[0][val])
|
||||
if type == "input":
|
||||
input_idx = val.index
|
||||
new_args.append(inp_to_node[input_idx])
|
||||
if type == "constant":
|
||||
new_args.append(val)
|
||||
output.args = (tuple(new_args),)
|
||||
|
||||
for src_idx, i in graph_input_order.items():
|
||||
old = placeholders[src_idx]
|
||||
new = inp_to_node[i]
|
||||
old.replace_all_uses_with(new)
|
||||
graph.graph.erase_node(old)
|
||||
|
||||
# Dynamo uses _lazyGraphModule, so we need to force recompile
|
||||
from torch.fx._lazy_graph_module import _LazyGraphModule
|
||||
|
||||
_LazyGraphModule.force_recompile(graph)
|
||||
|
||||
graph.graph._codegen = _PyTreeCodeGen(
|
||||
_PyTreeInfo(
|
||||
argument_names(signature, args, kwargs), # type: ignore[arg-type]
|
||||
in_spec,
|
||||
out_spec,
|
||||
)
|
||||
)
|
||||
|
||||
graph.recompile()
|
||||
return graph
|
||||
|
||||
return inner
|
||||
|
|
@ -363,6 +363,24 @@ class StackLocalsMetadata:
|
|||
locals_ctx_args: list[tuple[str, tuple[Any, ...]]] = dc_field(default_factory=list)
|
||||
|
||||
|
||||
# TODO we should expand this to make it work for atribtrary in/out
|
||||
@dataclass
|
||||
class ExportMetaData:
|
||||
# maps graph input index to its' source which is later
|
||||
# used in export to map to correct user input. In its' flat form,
|
||||
# just looks like GetItem(base=LocalSource("foo", idx=0))
|
||||
graph_input_idx_to_local_source: dict[int, Source] = dc_field(default_factory=dict)
|
||||
# maps user output idx to what type of output it is. There are 3 options:
|
||||
# 1) graph out
|
||||
# 2) user input
|
||||
# 3) constants
|
||||
output_return_type: dict[int, tuple[str, Any]] = dc_field(default_factory=dict)
|
||||
# output spec of the traced function
|
||||
out_spec: Union[torch.utils._pytree.TreeSpec, torch.utils._pytree.LeafSpec] = (
|
||||
torch.utils._pytree._LEAF_SPEC
|
||||
)
|
||||
|
||||
|
||||
def get_builtins_dict(global_scope: Scope) -> dict[str, Any]:
|
||||
# f_globals["__builtins__"] can be a dict or a module. This is an
|
||||
# implementation detail -
|
||||
|
|
@ -598,6 +616,8 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
# mangled alias -> module fqn name
|
||||
self.import_sources: dict[str, str] = {}
|
||||
|
||||
self.export_metadata = ExportMetaData()
|
||||
|
||||
def mark_bytecode_tracing_start(self) -> None:
|
||||
self.compiler_trace_stack.enter_context(
|
||||
dynamo_timed(
|
||||
|
|
@ -1494,6 +1514,54 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
)
|
||||
self.codegen_suffix(tx, stack_values_flat, pass2)
|
||||
|
||||
if (
|
||||
torch._dynamo.config.log_graph_in_out_metadata
|
||||
and stack_values_flat
|
||||
and len(stack_values_flat) == 1
|
||||
):
|
||||
vt = stack_values_flat[0]
|
||||
if (
|
||||
isinstance(vt, torch._dynamo.variables.NamedTupleVariable)
|
||||
and vt.tuple_cls
|
||||
is torch._dynamo.functional_export.ExportTracerOutput
|
||||
):
|
||||
flat_returns = vt.items[0]
|
||||
out_spec = vt.items[1]
|
||||
assert isinstance(
|
||||
flat_returns, torch._dynamo.variables.ListVariable
|
||||
)
|
||||
|
||||
vt_to_graph_out_idx: dict[VariableTracker, int] = {}
|
||||
for value in pass2.graph_outputs.values():
|
||||
assert isinstance(value, torch._dynamo.codegen.GraphOutputEntry)
|
||||
variable: VariableTracker = value.variable
|
||||
vt_to_graph_out_idx[variable] = value.index
|
||||
|
||||
for idx, vt in enumerate(flat_returns.items):
|
||||
if vt in vt_to_graph_out_idx:
|
||||
self.export_metadata.output_return_type[idx] = (
|
||||
"graph_out",
|
||||
vt_to_graph_out_idx[vt],
|
||||
)
|
||||
elif (
|
||||
vt.source is not None
|
||||
and (source := getattr(vt.source, "base", None))
|
||||
and source.is_input
|
||||
):
|
||||
self.export_metadata.output_return_type[idx] = (
|
||||
"input",
|
||||
vt.source,
|
||||
)
|
||||
elif isinstance(vt, torch._dynamo.variables.ConstantVariable):
|
||||
self.export_metadata.output_return_type[idx] = (
|
||||
"constant",
|
||||
vt.as_python_constant(),
|
||||
)
|
||||
else:
|
||||
assert f"Encountered unrecognized type {vt} at output {idx}" # noqa: PLW0129
|
||||
|
||||
self.export_metadata.out_spec = out_spec.as_python_constant()
|
||||
|
||||
output = []
|
||||
if count_calls(self.graph) != 0 or len(pass2.graph_outputs) != 0:
|
||||
output.extend(
|
||||
|
|
@ -2039,6 +2107,10 @@ class OutputGraph(OutputGraphGuardsState):
|
|||
|
||||
assert self.root_tx is not None
|
||||
cg = PyCodegen(self.root_tx)
|
||||
|
||||
for idx, arg in enumerate(self.graphargs):
|
||||
self.export_metadata.graph_input_idx_to_local_source[idx] = arg.source
|
||||
|
||||
cg.make_call_generated_code(name)
|
||||
return cg.get_instructions()
|
||||
|
||||
|
|
|
|||
|
|
@ -1748,6 +1748,21 @@ class FrozenDataClassVariable(UserDefinedObjectVariable):
|
|||
ctor = self.python_type()
|
||||
return ctor(*args, **kwargs)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# Handle specific pytree classes
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
if self.value_type is pytree.LeafSpec:
|
||||
# Create a new LeafSpec instance by calling the constructor
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.load_import_from("torch.utils._pytree", "LeafSpec")
|
||||
)
|
||||
codegen.extend_output(create_call_function(0, False))
|
||||
return
|
||||
|
||||
# For other frozen dataclasses, fall back to the base class behavior
|
||||
super().reconstruct(codegen)
|
||||
|
||||
# NB: This is called during __init__ for a frozen dataclass
|
||||
# use this to accumulate the most up-to-date field values
|
||||
def method_setattr_standard(self, tx: "InstructionTranslator", name, value):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user