remove allow-untyped-defs from torch/distributed/_shard/sharded_tensor/shard.py (#144623)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144623
Approved by: https://github.com/Skylion007
This commit is contained in:
bobrenjc93 2025-01-11 10:21:40 -08:00 committed by PyTorch MergeBot
parent b8aae2773f
commit f6688ac81d

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
from dataclasses import dataclass
from typing import List
@ -23,7 +22,7 @@ class Shard:
tensor: torch.Tensor
metadata: ShardMetadata
def __post_init__(self):
def __post_init__(self) -> None:
# verification between local tensor and metadata
if list(self.tensor.size()) != self.metadata.shard_sizes:
raise ValueError(
@ -45,7 +44,7 @@ class Shard:
@classmethod
def from_tensor_and_offsets(
cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int
):
) -> "Shard":
"""
Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank.