mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
In this PR, we allow users to register a customized flatten/unflatten/serialization/deserialization for a dataclass. We provide default implementations for flatten/unflatten, and could implement a decorator based on them when needed. ## Motivation: HuggingFace and many internal models return dataclass outputs, and torch.export wants to maintain the invariant that the export result (i.e. exported_program) has the same calling convention and result as the original callable. This is not supported in export yet: we cannot recover the original dataclass from the flattened output produced by the underlying graph module (produced by dynamo and processed further by aot_export). We need a place to store the metadata of the dataclass so that we can reconstruct it. To avoid adding hacky code in export and to allow principled extensibility, we think extending pytree may be a good option. ## Implementation: @zou3519 mentioned https://github.com/pytorch/pytorch/pull/93214/files and [jax-2371](https://github.com/google/jax/issues/2371#issuecomment-805361566), which suggest that it's not a good idea to make dataclass a default pytree node, but that it could be good to provide a default implementation for dataclasses. Since this currently seems to be an export-only feature, we added this extension point in export. We also add a "return_none_fields" flag to control whether None fields are returned after flattening; it is expected to be False in produce_matching of dynamo.export. Also added some tests. Pull Request resolved: https://github.com/pytorch/pytorch/pull/106160 Approved by: https://github.com/zhxchen17
55 lines
1.6 KiB
Python
55 lines
1.6 KiB
Python
import dataclasses
|
|
|
|
from typing import Any, List, Optional, Tuple
|
|
|
|
from torch.utils._pytree import (
|
|
_register_pytree_node,
|
|
Context,
|
|
FlattenFunc,
|
|
MaybeFromStrFunc,
|
|
ToStrFunc,
|
|
UnflattenFunc,
|
|
)
|
|
|
|
|
|
def register_dataclass_as_pytree_node(
|
|
typ: Any,
|
|
flatten_fn: Optional[FlattenFunc] = None,
|
|
unflatten_fn: Optional[UnflattenFunc] = None,
|
|
to_str_fn: Optional[ToStrFunc] = None,
|
|
maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,
|
|
*,
|
|
return_none_fields: bool = False,
|
|
) -> None:
|
|
assert dataclasses.is_dataclass(
|
|
typ
|
|
), f"Only dataclasses can be registered with this function: {typ}"
|
|
|
|
def default_flatten_fn(obj: Any) -> Tuple[List[Any], Context]:
|
|
flattened = []
|
|
flat_names = []
|
|
none_names = []
|
|
for f in dataclasses.fields(obj):
|
|
name, val = f.name, getattr(obj, f.name)
|
|
if val is not None or return_none_fields:
|
|
flattened.append(val)
|
|
flat_names.append(name)
|
|
else:
|
|
none_names.append(name)
|
|
return flattened, (typ, flat_names, none_names)
|
|
|
|
def default_unflatten_fn(values: List[Any], context: Context) -> Any:
|
|
typ, flat_names, none_names = context
|
|
return typ(**dict(zip(flat_names, values)), **{k: None for k in none_names})
|
|
|
|
flatten_fn = flatten_fn if flatten_fn is not None else default_flatten_fn
|
|
unflatten_fn = unflatten_fn if unflatten_fn is not None else default_unflatten_fn
|
|
|
|
_register_pytree_node(
|
|
typ,
|
|
flatten_fn,
|
|
unflatten_fn,
|
|
None,
|
|
None,
|
|
)
|