mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
remove allow-untyped-defs from torch/distributed/_shard/sharded_tensor/shard.py (#144623)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144623 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
b8aae2773f
commit
f6688ac81d
|
|
@ -1,4 +1,3 @@
|
|||
# mypy: allow-untyped-defs
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
|
||||
|
|
@ -23,7 +22,7 @@ class Shard:
|
|||
tensor: torch.Tensor
|
||||
metadata: ShardMetadata
|
||||
|
||||
def __post_init__(self):
|
||||
def __post_init__(self) -> None:
|
||||
# verification between local tensor and metadata
|
||||
if list(self.tensor.size()) != self.metadata.shard_sizes:
|
||||
raise ValueError(
|
||||
|
|
@ -45,7 +44,7 @@ class Shard:
|
|||
@classmethod
|
||||
def from_tensor_and_offsets(
|
||||
cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int
|
||||
):
|
||||
) -> "Shard":
|
||||
"""
|
||||
Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user