add type annotations to torch._utils (#49705)
Summary: closes gh-49704
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49705
Reviewed By: mruberry
Differential Revision: D25725352
Pulled By: malfet
fbshipit-source-id: 05a7041c9caffde4a5c1eb8af0d13697075103af
This commit is contained in:
parent ce370398cc
commit 870ab04b64

mypy.ini: 17 changed lines
mypy.ini
@@ -106,7 +106,22 @@ ignore_errors = True
 [mypy-torch._appdirs]
 ignore_errors = True
 
-[mypy-torch._utils]
+[mypy-torch._overrides]
 ignore_errors = True
 
+[mypy-torch.utils.tensorboard._caffe2_graph]
+ignore_errors = True
+
+[mypy-torch.contrib._tensorboard_vis]
+ignore_errors = True
+
+[mypy-torch.nn.utils.prune]
+ignore_errors = True
+
+[mypy-torch.utils.show_pickle]
+ignore_errors = True
+
+[mypy-torch.utils.hipify.hipify_python]
+ignore_errors = True
+
 [mypy-torch.utils.benchmark.examples.*]
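With the [mypy-torch._utils] override gone from mypy.ini, torch/_utils.py is checked like any other module. Below is a minimal sketch of verifying that locally, assuming mypy is installed and the script is run from the repository root so mypy.ini is picked up; check_module is a hypothetical helper, not part of the repository.

# A minimal sketch: type-check torch/_utils.py on its own with the mypy Python API.
# Assumes mypy is installed and this is run from the repository root so that
# mypy.ini (and its remaining per-module sections) is picked up automatically.
# check_module is a hypothetical helper, not part of the repository.
from mypy import api

def check_module(path: str) -> int:
    stdout, stderr, exit_status = api.run([path])  # equivalent to `mypy <path>` on the CLI
    if stdout:
        print(stdout, end="")
    if stderr:
        print(stderr, end="")
    return exit_status

if __name__ == "__main__":
    raise SystemExit(check_module("torch/_utils.py"))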
@@ -302,6 +302,9 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, out: str) -> None:
     'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
                           ' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
                           ' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
+    '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
+                                  ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
+                                  ' requires_grad: bool = False) -> Tensor: ...'],
     'range': ['def range(start: Number, end: Number,'
               ' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
               .format(FACTORY_PARAMS)],
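The hunk above (presumably tools/pyi/gen_pyi.py) adds a stub signature for _sparse_coo_tensor_unsafe to the hints dict that the generator renders into the .pyi stubs. Below is an illustrative sketch of that rendering idea, not the real gen_pyi code; render_hints and example_hints are made-up names.

# A minimal sketch (illustrative, not the real gen_pyi code) of how such a hints
# dict can be turned into .pyi text. Note that the adjacent string literals in each
# list entry are joined by implicit concatenation, so every entry is already one
# complete `def ...: ...` stub line by the time it is rendered.
from typing import Dict, List

def render_hints(hints: Dict[str, List[str]]) -> str:
    lines = []
    for name in sorted(hints):      # stable ordering for reproducible output
        lines.extend(hints[name])   # each entry may carry several overloads
    return "\n".join(lines) + "\n"

example_hints = {
    '_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
                                  ' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
                                  ' requires_grad: bool = False) -> Tensor: ...'],
}
print(render_hints(example_hints), end="")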
torch/_utils.py
@@ -1,6 +1,6 @@
 import torch
 import torch._six
-from typing import Optional
+from typing import Optional, List, DefaultDict
 import warnings
 from collections import defaultdict
 import sys
@@ -37,9 +37,9 @@ def _type(self, dtype=None, non_blocking=False, **kwargs):
             raise RuntimeError("Cannot cast sparse tensor to dense tensor")
         new_module_name = dtype.__module__.replace('.sparse', '')
         new_values_type_name = new_module_name + '.' + dtype.__name__
-        new_values = torch._values(self).type(new_values_type_name, non_blocking)
+        new_values = torch.Tensor._values(self).type(new_values_type_name, non_blocking)
         new_indices_type_name = new_module_name + '.LongTensor'
-        new_indices = torch._indices(self).type(new_indices_type_name, non_blocking)
+        new_indices = torch.Tensor._indices(self).type(new_indices_type_name, non_blocking)
         return dtype(new_indices, new_values, self.size())
     if dtype.is_sparse:
         raise RuntimeError("Cannot cast dense tensor to sparse tensor")
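The torch._values(self) to torch.Tensor._values(self) rewrites above (and the matching ones in the later hunks) leave runtime behaviour unchanged: the same method runs, it is just called through the Tensor class rather than through the torch module, presumably because the class-qualified form is an attribute the type checker can resolve from the stubs. A small self-contained sketch of the pattern follows; Box is made up for illustration and is not a torch API.

# A small self-contained sketch of the pattern: calling a method through the class
# (`Cls.method(obj)`) does the same thing at runtime as `obj.method()`, but gives a
# static checker an explicit attribute it can resolve. Box is made up for
# illustration; it is not a torch API.
class Box:
    def __init__(self, payload: int) -> None:
        self._payload = payload

    def _value(self) -> int:
        return self._payload


b = Box(7)
assert b._value() == Box._value(b) == 7  # bound call and class-qualified call agree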
@@ -72,8 +72,8 @@ def _cuda(self, device=None, non_blocking=False, **kwargs):
         with torch.cuda.device(device):
             if self.is_sparse:
                 new_type = getattr(torch.cuda.sparse, self.__class__.__name__)
-                indices = torch._indices(self).cuda(device, non_blocking)
-                values = torch._values(self).cuda(device, non_blocking)
+                indices = torch.Tensor._indices(self).cuda(device, non_blocking)
+                values = torch.Tensor._values(self).cuda(device, non_blocking)
                 return new_type(indices, values, self.size())
             else:
                 new_type = getattr(torch.cuda, self.__class__.__name__)
@@ -144,7 +144,7 @@ def _rebuild_tensor_v2(storage, storage_offset, size, stride, requires_grad, bac
     return tensor


-_sparse_tensors_to_validate = []
+_sparse_tensors_to_validate: List["torch.Tensor"] = []

 # In _legacy_load() in serialization.py we unpickle storages after the sparse
 # tensors have been already unpickled. Those storages contain data necessary for
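The _sparse_tensors_to_validate change above annotates an empty module-level list so the checker knows its element type, writing the element type as the string "torch.Tensor", i.e. a forward reference. Below is a minimal sketch of the same idiom with made-up names (_pending, Record).

# A minimal sketch of the idiom above: an empty module-level container needs an
# explicit annotation because nothing can be inferred from `[]`, and the element
# type can be written as a string (a forward reference) when the class is defined
# later or is awkward to import here. Record and _pending are made-up names.
from typing import List

_pending: List["Record"] = []  # forward reference: Record is defined below


class Record:
    def __init__(self, key: str) -> None:
        self.key = key


def register(item: Record) -> None:
    _pending.append(item)


register(Record("example"))
assert _pending[0].key == "example"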
@@ -271,8 +271,8 @@ def _flatten_sparse_tensors(tensors):
         A tuple of two contiguous 1D buffers, one containing input tensors'
         indices and the other containing the values.
     """
-    flat_indices = _flatten_dense_tensors([torch._indices(t) for t in tensors])
-    flat_values = _flatten_dense_tensors([torch._values(t) for t in tensors])
+    flat_indices = _flatten_dense_tensors([torch.Tensor._indices(t) for t in tensors])
+    flat_values = _flatten_dense_tensors([torch.Tensor._values(t) for t in tensors])
     return flat_indices, flat_values

@@ -314,8 +314,8 @@ def _unflatten_sparse_tensors(flat, tensors):
         flat.
     """
     flat_indices, flat_values = flat
-    indices = _unflatten_dense_tensors(flat_indices, [torch._indices(t) for t in tensors])
-    values = _unflatten_dense_tensors(flat_values, [torch._values(t) for t in tensors])
+    indices = _unflatten_dense_tensors(flat_indices, [torch.Tensor._indices(t) for t in tensors])
+    values = _unflatten_dense_tensors(flat_values, [torch.Tensor._values(t) for t in tensors])
     outputs = []
     for t, i, v in zip(tensors, indices, values):
         outputs.append(t.new(i, v, t.size()))
@@ -340,8 +340,8 @@ def _reorder_tensors_as(tensors, ordered_tensors):
     type_dict = defaultdict(list)
     for tensor in tensors:
         type_dict[tensor.type()].append(tensor)
-    type_dict = {t: iter(coll) for t, coll in type_dict.items()}
-    return tuple(next(type_dict[tensor.type()]) for tensor in ordered_tensors)
+    type_dict_ = {t: iter(coll) for t, coll in type_dict.items()}
+    return tuple(next(type_dict_[tensor.type()]) for tensor in ordered_tensors)


 def _take_tensors(tensors, size_limit):
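In _reorder_tensors_as above, the dict of iterators now gets its own name (type_dict_) instead of rebinding type_dict, since mypy infers one type per variable in a scope and would flag the re-assignment to a differently-typed value; runtime behaviour is unchanged. A small sketch of the same situation follows, using the illustrative names counts and counts_.

# A small sketch of why the second binding got its own name: mypy infers one type
# per variable, so reusing `counts` for a dict of iterators would be flagged as an
# incompatible assignment, while a fresh name keeps both inferred types precise.
from collections import defaultdict
from typing import DefaultDict, Dict, Iterator, List

counts: DefaultDict[str, List[int]] = defaultdict(list)
counts["a"].append(1)
counts["a"].append(2)

# counts = {k: iter(v) for k, v in counts.items()}   # mypy: incompatible types in assignment
counts_: Dict[str, Iterator[int]] = {k: iter(v) for k, v in counts.items()}
assert next(counts_["a"]) == 1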
@@ -356,12 +356,12 @@ def _take_tensors(tensors, size_limit):
         Blocks of tensors of same type and within size_limit. The yielded
         tensors are only ordered as the original sequence within its types.
     """
-    buf_dict = defaultdict(lambda: [[], 0])
+    buf_dict: DefaultDict[str, List] = defaultdict(lambda: [[], 0])
     for tensor in tensors:
         t = tensor.type()
         if tensor.is_sparse:
-            indices = torch._indices(tensor)
-            values = torch._values(tensor)
+            indices = torch.Tensor._indices(tensor)
+            values = torch.Tensor._values(tensor)
             size = indices.numel() * indices.element_size() + values.numel() * values.element_size()
         else:
             size = tensor.numel() * tensor.element_size()
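buf_dict in _take_tensors is a defaultdict whose default value comes from a lambda, so mypy cannot infer useful key/value types on its own; the explicit DefaultDict[str, List] annotation spells them out. A minimal sketch of the same annotation follows, with an illustrative buckets mapping that mirrors buf_dict's two-slot value.

# A minimal sketch of the annotation above: a defaultdict built from a lambda gives
# mypy nothing to infer key/value types from, so they are spelled out explicitly.
# `buckets` mirrors buf_dict's two-slot [items, running_size] value; the key and
# sizes below are illustrative only.
from collections import defaultdict
from typing import DefaultDict, List

buckets: DefaultDict[str, List] = defaultdict(lambda: [[], 0])

buckets["torch.FloatTensor"][0].append("tensor-a")  # collect an item under its type key
buckets["torch.FloatTensor"][1] += 4                # and keep a running byte count
assert buckets["torch.FloatTensor"] == [["tensor-a"], 4]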