mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[sharded_tensor] fix typing issue for placement (#63426)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63426 placement should either be a string or a _remote_device, this fixes the type to match the behaviors ghstack-source-id: 136041125 Reviewed By: pritamdamania87 Differential Revision: D30379702 fbshipit-source-id: 34e226494240923b433e3a39cc08c84d42cdad6b
This commit is contained in:
parent
2fd14735d6
commit
d431c77d76
|
|
@ -1,5 +1,6 @@
|
|||
from typing import List
|
||||
from typing import List, Union
|
||||
from dataclasses import dataclass
|
||||
from torch.distributed.remote_device import _remote_device
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -24,7 +25,7 @@ class ShardMetadata(object):
|
|||
|
||||
shard_offsets: List[int]
|
||||
shard_lengths: List[int]
|
||||
placement: torch.distributed._remote_device
|
||||
placement: Union[str, _remote_device]
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.placement, str):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user