[sharded_tensor] fix typing issue for placement (#63426)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63426

placement should either be a string or a _remote_device, this fixes the type to match the behaviors
ghstack-source-id: 136041125

Reviewed By: pritamdamania87

Differential Revision: D30379702

fbshipit-source-id: 34e226494240923b433e3a39cc08c84d42cdad6b
This commit is contained in:
Wanchao Liang 2021-08-17 23:10:48 -07:00 committed by Facebook GitHub Bot
parent 2fd14735d6
commit d431c77d76

View File

@ -1,5 +1,6 @@
from typing import List
from typing import List, Union
from dataclasses import dataclass
from torch.distributed.remote_device import _remote_device
import torch
@ -24,7 +25,7 @@ class ShardMetadata(object):
shard_offsets: List[int]
shard_lengths: List[int]
placement: torch.distributed._remote_device
placement: Union[str, _remote_device]
def __post_init__(self):
if isinstance(self.placement, str):