Revert "[Ez][BE]: Remove accidental classvar (#153540)"

This reverts commit e0dece510b.

Reverted https://github.com/pytorch/pytorch/pull/153540 on behalf of https://github.com/jeanschmidt due to Broken internal tests, @albanD may you help the author get his PR merged? D74804063 ([comment](https://github.com/pytorch/pytorch/pull/153540#issuecomment-2886011101))
This commit is contained in:
PyTorch MergeBot 2025-05-16 08:26:37 +00:00
parent 4d073af58c
commit 86c6f71ddb

View File

@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import cast, Optional, TYPE_CHECKING, Union
from typing_extensions import TypeAlias
import torch
import torch.distributed as dist
@ -23,8 +22,6 @@ if TYPE_CHECKING:
# from run-time to resolve circular dependency.
from torch.distributed._shard.sharded_tensor import ShardedTensor
_ShardingDim: TypeAlias = Union[int, str]
@dataclass
class ChunkShardingSpec(ShardingSpec):
@ -53,7 +50,9 @@ class ChunkShardingSpec(ShardingSpec):
:class:`torch.distributed._remote_device`
"""
dim: _ShardingDim
ShardingDim = Union[int, str]
dim: ShardingDim
placements: list[Union[torch.distributed._remote_device, str]]
def __post_init__(self):