mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "[Ez][BE]: Remove accidental classvar (#153540)"
This reverts commit e0dece510b.
Reverted https://github.com/pytorch/pytorch/pull/153540 on behalf of https://github.com/jeanschmidt due to Broken internal tests, @albanD may you help the author get his PR merged? D74804063 ([comment](https://github.com/pytorch/pytorch/pull/153540#issuecomment-2886011101))
This commit is contained in:
parent
4d073af58c
commit
86c6f71ddb
|
|
@ -1,7 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from dataclasses import dataclass
|
||||
from typing import cast, Optional, TYPE_CHECKING, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
|
@ -23,8 +22,6 @@ if TYPE_CHECKING:
|
|||
# from run-time to resolve circular dependency.
|
||||
from torch.distributed._shard.sharded_tensor import ShardedTensor
|
||||
|
||||
_ShardingDim: TypeAlias = Union[int, str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ChunkShardingSpec(ShardingSpec):
|
||||
|
|
@ -53,7 +50,9 @@ class ChunkShardingSpec(ShardingSpec):
|
|||
:class:`torch.distributed._remote_device`
|
||||
"""
|
||||
|
||||
dim: _ShardingDim
|
||||
ShardingDim = Union[int, str]
|
||||
|
||||
dim: ShardingDim
|
||||
placements: list[Union[torch.distributed._remote_device, str]]
|
||||
|
||||
def __post_init__(self):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user