pytorch/torch/_export/utils.py
angelayi ff35e1e45b [pytree] Add custom treespec fqn field (#112428)
Custom classes that are serialized with pytree are serialized by default as `f"{class.__module__}.{class.__name__}"`. This makes the serialized program depend directly on the outer Python environment: if a user moves the class to a different directory, the serialized program can no longer be loaded. So we will require users to pass in an FQN if they want to serialize their custom treespec type.

Differential Revision: [D50886366](https://our.internmc.facebook.com/intern/diff/D50886366)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112428
Approved by: https://github.com/suo
2023-11-02 00:26:41 +00:00
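
For illustration, a minimal sketch of registering a user-defined dataclass with an explicit FQN via the function in this file; the dataclass InputBundle and the FQN string "my_project.inputs.InputBundle" are hypothetical:

import dataclasses

import torch
from torch._export.utils import register_dataclass_as_pytree_node

@dataclasses.dataclass
class InputBundle:  # hypothetical user-defined input container
    tokens: torch.Tensor
    mask: torch.Tensor

# The explicit serialized_type_name decouples the serialized program from
# InputBundle's current module path, so the class can move without breaking
# deserialization.
register_dataclass_as_pytree_node(
    InputBundle,
    serialized_type_name="my_project.inputs.InputBundle",
)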


import dataclasses
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type

import torch
from torch._export import ExportedProgram
from torch.utils._pytree import (
    _register_pytree_node,
    Context,
    DumpableContext,
    FlattenFunc,
    FromDumpableContextFn,
    ToDumpableContextFn,
    UnflattenFunc,
)


SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS: Dict[str, Type[Any]] = {}


def register_dataclass_as_pytree_node(
    typ: Any,
    flatten_fn: Optional[FlattenFunc] = None,
    unflatten_fn: Optional[UnflattenFunc] = None,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    return_none_fields: bool = False,
) -> None:
    """
    Registers a dataclass as a pytree node.

    By default the dataclass fields are flattened in declaration order, and
    fields whose value is None are dropped from the flattened values unless
    return_none_fields=True. Custom flatten_fn/unflatten_fn and
    to_dumpable_context/from_dumpable_context functions may be supplied to
    override the defaults; the two dumpable-context functions must be provided
    together. Pass serialized_type_name to give the type a stable fully
    qualified name for serialization.
    """
    assert dataclasses.is_dataclass(
        typ
    ), f"Only dataclasses can be registered with this function: {typ}"

    serialized_type = f"{typ.__module__}.{typ.__name__}"
    SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[serialized_type] = typ

    def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
        flattened = []
        flat_names = []
        none_names = []
        for f in dataclasses.fields(obj):
            name, val = f.name, getattr(obj, f.name)
            if val is not None or return_none_fields:
                flattened.append(val)
                flat_names.append(name)
            else:
                none_names.append(name)
        return flattened, (typ, flat_names, none_names)

    def default_unflatten_fn(values: Iterable[Any], context: Context) -> Any:
        typ, flat_names, none_names = context
        return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names})

    def default_to_dumpable_context(context: Context) -> DumpableContext:
        return (serialized_type, context[1], context[2])

    def default_from_dumpable_context(dumpable_context: DumpableContext) -> Context:
        return (
            SERIALIZED_DATACLASS_TO_PYTHON_DATACLASS[dumpable_context[0]],
            dumpable_context[1],
            dumpable_context[2],
        )

    flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
    unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn

    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
        raise ValueError(
            f"Both to_dumpable_context and from_dumpable_context for {typ} must "
            "be None or registered."
        )

    to_dumpable_context = (
        to_dumpable_context
        if to_dumpable_context is not None
        else default_to_dumpable_context
    )
    from_dumpable_context = (
        from_dumpable_context
        if from_dumpable_context is not None
        else default_from_dumpable_context
    )

    _register_pytree_node(
        typ,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
    )
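
# Illustrative note (not part of the original file): with the default flatten
# and unflatten functions above, a registered dataclass such as a hypothetical
# InputBundle(tokens=t, mask=None) flattens to the leaves [t] and the context
# (InputBundle, ["tokens"], ["mask"]); the None-valued field name is recorded
# so unflattening can restore mask=None. Registering with
# return_none_fields=True keeps None values as leaves instead.
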

def is_param(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a parameter within the exported program
    """

    return node.name in program.graph_signature.inputs_to_parameters


def get_param(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.nn.Parameter]:
    """
    Returns the parameter associated with the given node in the exported program.
    Returns None if the node is not a parameter within the exported program
    """

    if is_param(program, node):
        parameter_name = program.graph_signature.inputs_to_parameters[node.name]
        return program.state_dict[parameter_name]

    return None


def is_buffer(program: ExportedProgram, node: torch.fx.Node) -> bool:
    """
    Checks if the given node is a buffer within the exported program
    """

    return node.name in program.graph_signature.inputs_to_buffers


def get_buffer(
    program: ExportedProgram,
    node: torch.fx.Node,
) -> Optional[torch.Tensor]:
    """
    Returns the buffer associated with the given node in the exported program.
    Returns None if the node is not a buffer within the exported program
    """

    if is_buffer(program, node):
        buffer_name = program.graph_signature.inputs_to_buffers[node.name]
        return program.state_dict[buffer_name]

    return None
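
# Usage sketch (illustrative, not part of the original file): `ep` is assumed
# to be an ExportedProgram produced by torch.export, whose placeholder nodes
# can be classified with the helpers above.
#
#   for node in ep.graph.nodes:
#       if node.op != "placeholder":
#           continue
#       if is_param(ep, node):
#           weight = get_param(ep, node)   # torch.nn.Parameter from state_dict
#       elif is_buffer(ep, node):
#           buf = get_buffer(ep, node)     # torch.Tensor from state_dict
#       # remaining placeholders typically correspond to user inputs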