Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Previous discussion: https://github.com/pytorch/pytorch/pull/109476

In this PR, I made the following additions to the original PR:

1) The unlifted graph module now runs the runtime assertions in its forward call.
2) When we retrace, we run the assertions to check that the user is tracing the module with inputs that satisfy the assumptions we made during the first tracing. I do this by creating a new graph module type with a modified call method. The runtime assertions happen under torchdynamo.disable so that they run directly in eager mode, because we don't want them to become a traced part of the graph.
3) Both ep.module and capture_pre_autograd now return _UnliftedGraphModule.

Differential Revision: [D51078056](https://our.internmc.facebook.com/intern/diff/D51078056)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110222
Approved by: https://github.com/zhxchen17
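A minimal sketch of the behavior in item 1 (hedged: it assumes the torch.export API of this era; MyModule is a hypothetical example, and the exact failure mode on mismatched inputs may differ):

    import torch

    class MyModule(torch.nn.Module):
        def forward(self, x):
            return x + 1

    ep = torch.export.export(MyModule(), (torch.randn(3),))
    m = ep.module()            # unlifted graph module with the assertion pre-hook
    m(torch.randn(3))          # matches the export-time input assumptions
    # m(torch.randn(4, 4))     # expected to trip the eager-mode input checks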
181 lines
5.6 KiB
Python
import dataclasses

from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

import torch

from torch._export import ExportedProgram

from torch.utils._pytree import (
    _register_pytree_node,
    Context,
    DumpableContext,
    FlattenFunc,
    FromDumpableContextFn,
    ToDumpableContextFn,
    tree_flatten,
    UnflattenFunc,
)


SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}
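

# Note (from the PR description above): the input checks run under
# torch._dynamo.disable so they execute directly in eager mode and never
# become a traced part of the graph.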
@torch._dynamo.disable
def _check_input_constraints_pre_hook(self, *args, **kwargs):
    flat_args, _ = tree_flatten(args)
    return _check_input_constraints_for_graph(
        self.graph,
        range_constraints=self.range_constraints,
        equality_constraints=self.equality_constraints,
    )(*flat_args)


def _check_input_constraints_for_graph(
    graph: torch.fx.Graph, range_constraints, equality_constraints
):
    from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
        _AddRuntimeAssertionsForConstraintsPass,
    )

    def inner(*args):
        # TODO(zhxchen17) Don't generate a runtime graph on the fly.
        _assertion_graph = torch.fx.GraphModule({}, torch.fx.Graph())
        for p in graph.nodes:
            if p.op != "placeholder":
                continue
            new_p = _assertion_graph.graph.placeholder(p.name)
            new_p.meta = p.meta
        _assertion_graph.graph.output(())
        _assertion_graph_res = _AddRuntimeAssertionsForConstraintsPass(
            range_constraints,
            equality_constraints,
        )(_assertion_graph)
        assert _assertion_graph_res is not None
        _assertion_graph = _assertion_graph_res.graph_module
        _assertion_graph(*args)

    return inner
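
# Usage sketch (hedged): this mirrors the pre-hook above, where ``ep`` stands
# for a hypothetical ExportedProgram carrying a graph and its constraints:
#
#   checker = _check_input_constraints_for_graph(
#       ep.graph, ep.range_constraints, ep.equality_constraints
#   )
#   checker(*flat_args)  # raises if flat_args violate export-time constraints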


def register_dataclass_as_pytree_node(
    cls: Any,
    flatten_fn: Optional[FlattenFunc] = None,
    unflatten_fn: Optional[UnflattenFunc] = None,
    *,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    serialized_type_name: Optional[str] = None,
    return_none_fields: bool = False,
) -> None:
    assert dataclasses.is_dataclass(
        cls
    ), f"Only dataclasses can be registered with this function: {cls}"

    serialized_type = f"{cls.__module__}.{cls.__name__}"
    SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = cls

    def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
        flattened = []
        flat_names = []
        none_names = []
        for f in dataclasses.fields(obj):
            name, val = f.name, getattr(obj, f.name)
            if val is not None or return_none_fields:
                flattened.append(val)
                flat_names.append(name)
            else:
                none_names.append(name)
        return flattened, (cls, flat_names, none_names)

    def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
        typ, flat_names, none_names = context
        return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names})

    def default_to_dumpable_context(context: Context) -> DumpableContext:
        return (serialized_type, context[1], context[2])

    def default_from_dumpable_context(dumpable_context: DumpableContext) -> Context:
        return (
            SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[dumpable_context[0]],
            dumpable_context[1],
            dumpable_context[2],
        )

    flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
    unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn

    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
        raise ValueError(
            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
            "be None or registered."
        )

    to_dumpable_context = (
        to_dumpable_context
        if to_dumpable_context is not None
        else default_to_dumpable_context
    )
    from_dumpable_context = (
        from_dumpable_context
        if from_dumpable_context is not None
        else default_from_dumpable_context
    )

    _register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )
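
# Usage sketch (hedged; ``MyConfig`` is a hypothetical example type, not part
# of this module):
#
#   @dataclasses.dataclass
#   class MyConfig:
#       scale: float
#       bias: Optional[float] = None
#
#   register_dataclass_as_pytree_node(MyConfig, serialized_type_name="test.MyConfig")
#   leaves, spec = tree_flatten(MyConfig(scale=2.0))
#   # leaves == [2.0]; ``bias`` is dropped because it is None and
#   # return_none_fields defaults to False.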


def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a parameter within the exported program.
    """
    return node.name in program.graph_signature.inputs_to_parameters


def get_param(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.nn.Parameter]:
    """
    Returns the parameter associated with the given node in the exported program.
    Returns None if the node is not a parameter within the exported program.
    """
    if is_param(program, node):
        parameter_name = program.graph_signature.inputs_to_parameters[node.name]
        return program.state_dict[parameter_name]
    return None


def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a buffer within the exported program.
    """
    return node.name in program.graph_signature.inputs_to_buffers


def get_buffer(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.Tensor]:
    """
    Returns the buffer associated with the given node in the exported program.
    Returns None if the node is not a buffer within the exported program.
    """
    if is_buffer(program, node):
        buffer_name = program.graph_signature.inputs_to_buffers[node.name]
        return program.state_dict[buffer_name]
    return None
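
# Usage sketch (hedged; ``ep`` is a hypothetical ExportedProgram):
#
#   for node in ep.graph.nodes:
#       if node.op == "placeholder":
#           if is_param(ep, node):
#               print(node.name, get_param(ep, node).shape)
#           elif is_buffer(ep, node):
#               print(node.name, get_buffer(ep, node).shape)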