mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[pytree] freeze attributes of TreeSpec (#124011)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124011 Approved by: https://github.com/zou3519
This commit is contained in:
parent
6edf989e2f
commit
3b0f6cce5c
|
|
@ -662,7 +662,7 @@ def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -
|
||||||
# context: some context that is useful in unflattening the pytree
|
# context: some context that is useful in unflattening the pytree
|
||||||
# children_specs: specs for each child of the root Node
|
# children_specs: specs for each child of the root Node
|
||||||
# num_leaves: the number of leaves
|
# num_leaves: the number of leaves
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass(init=True, frozen=True, eq=True, repr=False)
|
||||||
class TreeSpec:
|
class TreeSpec:
|
||||||
type: Any
|
type: Any
|
||||||
context: Context
|
context: Context
|
||||||
|
|
@ -673,9 +673,12 @@ class TreeSpec:
|
||||||
num_children: int = dataclasses.field(init=False)
|
num_children: int = dataclasses.field(init=False)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self.num_nodes = 1 + sum(spec.num_nodes for spec in self.children_specs)
|
num_nodes = sum((spec.num_nodes for spec in self.children_specs), start=1)
|
||||||
self.num_leaves = sum(spec.num_leaves for spec in self.children_specs)
|
num_leaves = sum(spec.num_leaves for spec in self.children_specs)
|
||||||
self.num_children = len(self.children_specs)
|
num_children = len(self.children_specs)
|
||||||
|
object.__setattr__(self, "num_nodes", num_nodes)
|
||||||
|
object.__setattr__(self, "num_leaves", num_leaves)
|
||||||
|
object.__setattr__(self, "num_children", num_children)
|
||||||
|
|
||||||
def __repr__(self, indent: int = 0) -> str:
|
def __repr__(self, indent: int = 0) -> str:
|
||||||
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
|
repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
|
||||||
|
|
@ -808,9 +811,9 @@ class LeafSpec(TreeSpec):
|
||||||
super().__init__(None, None, [])
|
super().__init__(None, None, [])
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
self.num_nodes = 1
|
object.__setattr__(self, "num_nodes", 1)
|
||||||
self.num_leaves = 1
|
object.__setattr__(self, "num_leaves", 1)
|
||||||
self.num_children = 0
|
object.__setattr__(self, "num_children", 0)
|
||||||
|
|
||||||
def __repr__(self, indent: int = 0) -> str:
|
def __repr__(self, indent: int = 0) -> str:
|
||||||
return "*"
|
return "*"
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user