mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add torch.Tensor._make_wrapper_subclass to torch/_C/__init__.pyi (#154022)
Fixes #153790 Pull Request resolved: https://github.com/pytorch/pytorch/pull/154022 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
d88699308f
commit
0a7eef140b
|
|
@ -1238,6 +1238,30 @@ def gen_pyi(
|
|||
"S",
|
||||
)
|
||||
],
|
||||
"_make_wrapper_subclass": [
|
||||
"@staticmethod\n"
|
||||
+ defs(
|
||||
"_make_wrapper_subclass",
|
||||
[
|
||||
"cls: type[S]",
|
||||
"size: Sequence[_int | SymInt]",
|
||||
"strides: Sequence[_int | SymInt] | None = None",
|
||||
"storage_offset: _int | SymInt | None = None",
|
||||
"memory_format: torch.memory_format | None = None",
|
||||
"dtype: _dtype | None = None",
|
||||
"layout: _layout = strided",
|
||||
"device: _device | None = None",
|
||||
"pin_memory: _bool = False",
|
||||
"requires_grad: _bool = False",
|
||||
"dispatch_sizes_strides_policy: str | None = None",
|
||||
"dispatch_device: _bool = False",
|
||||
"dispatch_layout: _bool = False",
|
||||
"_extra_dispatch_keys: torch.DispatchKeySet | None = None",
|
||||
"storage_size: _int | SymInt | None = None",
|
||||
],
|
||||
"S",
|
||||
)
|
||||
],
|
||||
"__contains__": [defs("__contains__", ["self", "item: Any", "/"], "_bool")],
|
||||
"__getitem__": [defs("__getitem__", ["self", INDICES, "/"], "Tensor")],
|
||||
"__setitem__": [
|
||||
|
|
|
|||
|
|
@ -118,7 +118,7 @@ class FunctionalTensor(torch.Tensor):
|
|||
FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem)
|
||||
)
|
||||
|
||||
out = torch.Tensor._make_wrapper_subclass( # type: ignore[arg-type, attr-defined]
|
||||
out = torch.Tensor._make_wrapper_subclass(
|
||||
# TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
|
||||
# Calling the overload that has kwargs causes us to go down the first overload path,
|
||||
# which will **always** specialize sizes.
|
||||
|
|
|
|||
|
|
@ -391,7 +391,7 @@ def _rebuild_wrapper_subclass(
|
|||
requires_grad,
|
||||
):
|
||||
device = _get_restore_location(device)
|
||||
return torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
return torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
size,
|
||||
strides=stride,
|
||||
|
|
|
|||
|
|
@ -582,7 +582,7 @@ class AsyncCollectiveTensor(torch.Tensor):
|
|||
|
||||
@staticmethod
|
||||
def __new__(cls, elem: torch.Tensor):
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
elem.size(),
|
||||
strides=elem.stride(),
|
||||
|
|
|
|||
|
|
@ -104,7 +104,7 @@ class ShardedTensorBase(torch.Tensor):
|
|||
sizes, tensor_properties=tensor_properties
|
||||
)
|
||||
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
sizes,
|
||||
dtype=dtype,
|
||||
|
|
|
|||
|
|
@ -269,7 +269,7 @@ class DTensor(torch.Tensor):
|
|||
# new method instruct wrapper tensor from local_tensor and add
|
||||
# placement spec, it does not do actual distribution
|
||||
assert spec.tensor_meta is not None, "TensorMeta should not be None!"
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
spec.tensor_meta.shape,
|
||||
strides=spec.tensor_meta.stride,
|
||||
|
|
|
|||
|
|
@ -45,7 +45,7 @@ class LocalShardsWrapper(torch.Tensor):
|
|||
|
||||
# if empty shard, we create a empty tensor
|
||||
if len(local_shards) == 0:
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
torch.Size([0, 0]),
|
||||
)
|
||||
|
|
@ -82,7 +82,7 @@ class LocalShardsWrapper(torch.Tensor):
|
|||
for shard, offset in zip(local_shards, local_offsets)
|
||||
]
|
||||
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
torch.Size(cat_tensor_shape),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class LocalShardsWrapper(torch.Tensor):
|
|||
ChunkStorageMetadata(o, s.shape) for s, o in zip(local_shards, offsets)
|
||||
]
|
||||
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
wrapper_shape,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ class MaskedTensor(torch.Tensor):
|
|||
UserWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs) # type: ignore[attr-defined]
|
||||
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
|
||||
|
||||
def _preprocess_data(self, data, mask):
|
||||
from .._ops import _sparse_coo_where, _sparse_csr_where
|
||||
|
|
|
|||
|
|
@ -107,7 +107,7 @@ class NestedTensor(torch.Tensor):
|
|||
stride = values.stride()
|
||||
_strides = (ragged_size * stride[r], *stride)
|
||||
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
_size,
|
||||
_strides,
|
||||
|
|
|
|||
|
|
@ -138,13 +138,14 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
|||
else:
|
||||
raise ValueError("At least one of packed or packed_t must be provided")
|
||||
|
||||
kwargs = {
|
||||
"device": previous_tensor.device,
|
||||
"dtype": previous_tensor.dtype,
|
||||
"layout": previous_tensor.layout,
|
||||
"requires_grad": requires_grad,
|
||||
}
|
||||
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
|
||||
tensor = torch.Tensor._make_wrapper_subclass(
|
||||
cls,
|
||||
shape,
|
||||
device=previous_tensor.device,
|
||||
dtype=previous_tensor.dtype,
|
||||
layout=previous_tensor.layout,
|
||||
requires_grad=requires_grad,
|
||||
)
|
||||
|
||||
tensor.packed = packed
|
||||
tensor.meta = meta
|
||||
|
|
|
|||
|
|
@ -127,7 +127,7 @@ def generate_cct_and_mode(autograd_view_consistency=True):
|
|||
# by a Composite operation; if the Composite
|
||||
# operator attempts to read from the storage without dispatching then it'll
|
||||
# raise a RuntimeError due to it being a meta storage.
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls, elem.size(),
|
||||
dtype=elem.dtype, layout=elem.layout,
|
||||
device=elem.device, requires_grad=elem.requires_grad,
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class LoggingTensor(torch.Tensor):
|
|||
# The wrapping tensor (LoggingTensor) shouldn't hold any
|
||||
# memory for the class in question, but it should still
|
||||
# advertise the same device as before
|
||||
r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
|
||||
r = torch.Tensor._make_wrapper_subclass(
|
||||
cls, elem.size(),
|
||||
strides=elem.stride(), storage_offset=elem.storage_offset(),
|
||||
# TODO: clone storage aliasing
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user