Add torch.Tensor._make_wrapper_subclass to torch/_C/__init__.pyi (#154022)

Fixes #153790

Pull Request resolved: https://github.com/pytorch/pytorch/pull/154022
Approved by: https://github.com/Skylion007
Yuanhao Ji 2025-05-27 14:09:55 +00:00 committed by PyTorch MergeBot
parent d88699308f
commit 0a7eef140b
13 changed files with 44 additions and 19 deletions
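
With this stub in place, wrapper-subclass constructors such as the call sites updated below type-check without a "# type: ignore[attr-defined]" suppression. A minimal sketch of the pattern (WrapperTensor is an illustrative class, not part of this change, and a real subclass would also implement __torch_dispatch__):

import torch


class WrapperTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem: torch.Tensor):
        # Allocate a wrapper that advertises elem's metadata (shape, strides,
        # dtype, device) without holding any storage of its own.
        return torch.Tensor._make_wrapper_subclass(
            cls,
            elem.size(),
            strides=elem.stride(),
            storage_offset=elem.storage_offset(),
            dtype=elem.dtype,
            layout=elem.layout,
            device=elem.device,
            requires_grad=elem.requires_grad,
        )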

@@ -1238,6 +1238,30 @@ def gen_pyi(
                     "S",
                 )
             ],
+            "_make_wrapper_subclass": [
+                "@staticmethod\n"
+                + defs(
+                    "_make_wrapper_subclass",
+                    [
+                        "cls: type[S]",
+                        "size: Sequence[_int | SymInt]",
+                        "strides: Sequence[_int | SymInt] | None = None",
+                        "storage_offset: _int | SymInt | None = None",
+                        "memory_format: torch.memory_format | None = None",
+                        "dtype: _dtype | None = None",
+                        "layout: _layout = strided",
+                        "device: _device | None = None",
+                        "pin_memory: _bool = False",
+                        "requires_grad: _bool = False",
+                        "dispatch_sizes_strides_policy: str | None = None",
+                        "dispatch_device: _bool = False",
+                        "dispatch_layout: _bool = False",
+                        "_extra_dispatch_keys: torch.DispatchKeySet | None = None",
+                        "storage_size: _int | SymInt | None = None",
+                    ],
+                    "S",
+                )
+            ],
             "__contains__": [defs("__contains__", ["self", "item: Any", "/"], "_bool")],
             "__getitem__": [defs("__getitem__", ["self", INDICES, "/"], "Tensor")],
             "__setitem__": [

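For reference, the entry above should cause gen_pyi to emit roughly the following staticmethod stub on class Tensor in torch/_C/__init__.pyi (a sketch of the expected output; the exact text produced by defs may differ):

    @staticmethod
    def _make_wrapper_subclass(
        cls: type[S],
        size: Sequence[_int | SymInt],
        strides: Sequence[_int | SymInt] | None = None,
        storage_offset: _int | SymInt | None = None,
        memory_format: torch.memory_format | None = None,
        dtype: _dtype | None = None,
        layout: _layout = strided,
        device: _device | None = None,
        pin_memory: _bool = False,
        requires_grad: _bool = False,
        dispatch_sizes_strides_policy: str | None = None,
        dispatch_device: _bool = False,
        dispatch_layout: _bool = False,
        _extra_dispatch_keys: torch.DispatchKeySet | None = None,
        storage_size: _int | SymInt | None = None,
    ) -> S: ...
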
@@ -118,7 +118,7 @@ class FunctionalTensor(torch.Tensor):
             FunctionalTensor._extra_dispatch_keys & torch._C._dispatch_keys(elem)
         )
-        out = torch.Tensor._make_wrapper_subclass(  # type: ignore[arg-type, attr-defined]
+        out = torch.Tensor._make_wrapper_subclass(
             # TODO: right now, _make_wrapper_subclass's dynamic shape interaction is not great.
             # Calling the overload that has kwargs causes us to go down the first overload path,
             # which will **always** specialize sizes.

@@ -391,7 +391,7 @@ def _rebuild_wrapper_subclass(
     requires_grad,
 ):
     device = _get_restore_location(device)
-    return torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+    return torch.Tensor._make_wrapper_subclass(
         cls,
         size,
         strides=stride,

@@ -582,7 +582,7 @@ class AsyncCollectiveTensor(torch.Tensor):
     @staticmethod
     def __new__(cls, elem: torch.Tensor):
-        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+        r = torch.Tensor._make_wrapper_subclass(
             cls,
             elem.size(),
             strides=elem.stride(),

@@ -104,7 +104,7 @@ class ShardedTensorBase(torch.Tensor):
             sizes, tensor_properties=tensor_properties
         )
-        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+        r = torch.Tensor._make_wrapper_subclass(
             cls,
             sizes,
             dtype=dtype,

@@ -269,7 +269,7 @@ class DTensor(torch.Tensor):
         # new method instruct wrapper tensor from local_tensor and add
         # placement spec, it does not do actual distribution
         assert spec.tensor_meta is not None, "TensorMeta should not be None!"
-        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+        r = torch.Tensor._make_wrapper_subclass(
             cls,
             spec.tensor_meta.shape,
             strides=spec.tensor_meta.stride,

@@ -45,7 +45,7 @@ class LocalShardsWrapper(torch.Tensor):
         # if empty shard, we create a empty tensor
         if len(local_shards) == 0:
-            r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            r = torch.Tensor._make_wrapper_subclass(
                 cls,
                 torch.Size([0, 0]),
             )
@@ -82,7 +82,7 @@ class LocalShardsWrapper(torch.Tensor):
             for shard, offset in zip(local_shards, local_offsets)
         ]
-        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+        r = torch.Tensor._make_wrapper_subclass(
             cls,
             torch.Size(cat_tensor_shape),
         )

@@ -69,7 +69,7 @@ class LocalShardsWrapper(torch.Tensor):
             ChunkStorageMetadata(o, s.shape) for s, o in zip(local_shards, offsets)
         ]
-        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+        r = torch.Tensor._make_wrapper_subclass(
             cls,
             wrapper_shape,
         )

@@ -174,7 +174,7 @@ class MaskedTensor(torch.Tensor):
                 UserWarning,
                 stacklevel=2,
             )
-        return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)  # type: ignore[attr-defined]
+        return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)

     def _preprocess_data(self, data, mask):
         from .._ops import _sparse_coo_where, _sparse_csr_where

@@ -107,7 +107,7 @@ class NestedTensor(torch.Tensor):
         stride = values.stride()
         _strides = (ragged_size * stride[r], *stride)
-        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+        r = torch.Tensor._make_wrapper_subclass(
             cls,
             _size,
             _strides,

@@ -138,13 +138,14 @@ class SparseSemiStructuredTensor(torch.Tensor):
         else:
             raise ValueError("At least one of packed or packed_t must be provided")
-        kwargs = {
-            "device": previous_tensor.device,
-            "dtype": previous_tensor.dtype,
-            "layout": previous_tensor.layout,
-            "requires_grad": requires_grad,
-        }
-        tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs)  # type: ignore[attr-defined]
+        tensor = torch.Tensor._make_wrapper_subclass(
+            cls,
+            shape,
+            device=previous_tensor.device,
+            dtype=previous_tensor.dtype,
+            layout=previous_tensor.layout,
+            requires_grad=requires_grad,
+        )
         tensor.packed = packed
         tensor.meta = meta

@@ -127,7 +127,7 @@ def generate_cct_and_mode(autograd_view_consistency=True):
             # by a Composite operation; if the Composite
             # operator attempts to read from the storage without dispatching then it'll
             # raise a RuntimeError due to it being a meta storage.
-            r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            r = torch.Tensor._make_wrapper_subclass(
                 cls, elem.size(),
                 dtype=elem.dtype, layout=elem.layout,
                 device=elem.device, requires_grad=elem.requires_grad,

@@ -42,7 +42,7 @@ class LoggingTensor(torch.Tensor):
         # The wrapping tensor (LoggingTensor) shouldn't hold any
         # memory for the class in question, but it should still
         # advertise the same device as before
-        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+        r = torch.Tensor._make_wrapper_subclass(
             cls, elem.size(),
             strides=elem.stride(), storage_offset=elem.storage_offset(),
             # TODO: clone storage aliasing