CUDA support in the CSR layout: constructors (#59010)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/59010

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D28719287

Pulled By: bhosmer

fbshipit-source-id: fbb5784ccb5ce19dcca1f2f95c4ee16f9b7680c4
This commit is contained in:
Alexander 2021-05-26 16:38:44 -07:00 committed by Facebook GitHub Bot
parent 7c17e1dd90
commit b435a27fb7
13 changed files with 692 additions and 308 deletions

View File

@ -4,6 +4,7 @@
#include <ATen/SparseTensorImpl.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/native/Resize.h>
namespace at {
namespace {
@ -56,21 +57,6 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
col_indices_(std::move(col_indices)),
values_(std::move(values)) {}
void SparseCsrTensorImpl::resize_and_clear_(
const int64_t nnz_size,
IntArrayRef size) {
// call crow_indices().options() here since the struct contructor calls the
// tensor constructor with args for device specific init.
auto empty_crow_indices = at::empty(size[0] + 1, crow_indices().options());
auto empty_col_indices = at::empty(nnz_size, col_indices().options());
auto empty_values = at::empty(nnz_size, values().options());
crow_indices_ = empty_crow_indices;
col_indices_ = empty_col_indices;
values_ = empty_values;
sizes_and_strides_.set_sizes(size);
}
void SparseCsrTensorImpl::resize_as_sparse_csr_tensor_(const Tensor& src) {
crow_indices_ = at::empty_like(
src.crow_indices(),
@ -85,22 +71,16 @@ void SparseCsrTensorImpl::resize_as_sparse_csr_tensor_(const Tensor& src) {
src.values().options(),
src.values().suggest_memory_format());
sizes_and_strides_.set_sizes(src.sizes());
refresh_numel();
}
void SparseCsrTensorImpl::set_member_tensors(
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values) {
auto crow_indices_type = crow_indices.scalar_type();
auto col_indices_type = col_indices.scalar_type();
const Tensor& values,
IntArrayRef size) {
TORCH_CHECK(
crow_indices_type == col_indices_type,
"both crow_indices and col_indices should have the same type.");
TORCH_CHECK(
crow_indices_type == kInt || crow_indices_type == kLong,
"crow_indices and col_indices must be an int32 or int64 type, but got: ",
crow_indices_type);
// CSR Type Invariants
TORCH_CHECK(
values.scalar_type() == typeMetaToScalarType(dtype()),
"dtype of values (",
@ -109,45 +89,11 @@ void SparseCsrTensorImpl::set_member_tensors(
typeMetaToScalarType(dtype()),
")");
TORCH_CHECK(
col_indices.layout() == kStrided,
"expected col_indices to be a strided tensor, but got indices of layout ",
col_indices.layout());
TORCH_CHECK(
crow_indices.layout() == kStrided,
"expected crow_indices to be a strided tensor, but got crow_indices of layout ",
crow_indices.layout());
TORCH_CHECK(
values.layout() == kStrided && values.is_contiguous(),
"expected values to be a strided and contiguous tensor, but got values of layout ",
values.layout());
TORCH_CHECK(
values.device().type() == device().type(),
"device type of values (",
values.device().type(),
") must match device type of device().type()",
device().type(),
")");
TORCH_CHECK(
values.is_cuda() || col_indices.get_device() == crow_indices.get_device(),
"crow_indices and col_indices devices (",
crow_indices.get_device(),
", ",
col_indices.get_device(),
") must match with the (non-cuda) device of values (",
values.get_device(),
")");
TORCH_CHECK(
col_indices.size(0) == values.size(0),
"col_indices and values must have equal sizes, but got col_indices.size(0): ",
col_indices.size(0),
", values.size(0): ",
values.size(0));
crow_indices_ = crow_indices;
col_indices_ = col_indices;
values_ = values;
sizes_and_strides_.set_sizes(size);
refresh_numel();
}
} // namespace at

View File

@ -32,12 +32,12 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
public:
explicit SparseCsrTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta);
void resize_and_clear_(const int64_t nnz_size, IntArrayRef size);
void resize_as_sparse_csr_tensor_(const Tensor& src);
void set_member_tensors(
const Tensor& crow_indices,
const Tensor& col_indices,
const Tensor& values);
const Tensor& values,
IntArrayRef size);
const Tensor& crow_indices() const { return crow_indices_; }
const Tensor& col_indices() const { return col_indices_; }

View File

@ -119,6 +119,7 @@ _(aten, _sparse_addmm) \
_(aten, _sparse_coo_tensor_with_dims) \
_(aten, _sparse_coo_tensor_with_dims_and_tensors) \
_(aten, _sparse_coo_tensor_unsafe) \
_(aten, _sparse_csr_tensor_unsafe) \
_(aten, _sparse_dense_add) \
_(aten, _sparse_div_scalar) \
_(aten, _sparse_div_zerodim) \
@ -655,6 +656,7 @@ _(aten, softshrink_forward) \
_(aten, solve) \
_(aten, sort) \
_(aten, sparse_coo_tensor) \
_(aten, sparse_csr_tensor) \
_(aten, sparse_mask) \
_(aten, sparse_resize) \
_(aten, sparse_resize_and_clear) \

View File

@ -4820,6 +4820,8 @@
- func: sparse_csr_tensor.crow_col_value(Tensor crow_indices, Tensor col_indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
- func: _sparse_csr_tensor_unsafe(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
- func: sparse_coo_tensor.size(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
- func: sparse_coo_tensor.indices(Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
@ -4830,6 +4832,8 @@
- func: _validate_sparse_coo_tensor_args(Tensor indices, Tensor values, int[] size) -> ()
- func: _validate_sparse_csr_tensor_args(Tensor crow_indices, Tensor col_indices, Tensor values, int[] size) -> ()
- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
dispatch:
SparseCPU, SparseCUDA: new_with_dims_sparse
@ -4901,7 +4905,7 @@
variants: method
dispatch:
SparseCPU, SparseCUDA: _nnz_sparse
SparseCsrCPU: _nnz_sparse_csr
SparseCsrCPU, SparseCsrCUDA: _nnz_sparse_csr
device_check: NoCheck
device_guard: False
@ -4960,21 +4964,21 @@
variants: method
dispatch:
SparseCPU, SparseCUDA: values_sparse
SparseCsrCPU: values_sparse_csr
SparseCsrCPU, SparseCsrCUDA: values_sparse_csr
device_check: NoCheck
device_guard: False
- func: crow_indices(Tensor(a) self) -> Tensor(a)
variants: method
dispatch:
SparseCsrCPU: crow_indices_sparse_csr
SparseCsrCPU, SparseCsrCUDA: crow_indices_sparse_csr
device_check: NoCheck
device_guard: False
- func: col_indices(Tensor(a) self) -> Tensor(a)
variants: method
dispatch:
SparseCsrCPU: col_indices_sparse_csr
SparseCsrCPU, SparseCsrCUDA: col_indices_sparse_csr
device_check: NoCheck
device_guard: False

View File

@ -8,12 +8,116 @@
#include <ATen/SparseCsrTensorImpl.h>
#include <ATen/SparseCsrTensorUtils.h>
#include <ATen/SparseTensorImpl.h>
#include <ATen/InitialTensorOptions.h>
namespace at {
namespace native {
using namespace at::sparse_csr;
namespace {
} // end anonymous namespace
void _validate_sparse_csr_tensor_args(const Tensor& crow_indices, const Tensor& col_indices, const Tensor& values, IntArrayRef size) {
// Layout Invariants
TORCH_CHECK(
col_indices.layout() == kStrided && col_indices.is_contiguous(),
"expected col_indices to be a strided and contiguous tensor");
TORCH_CHECK(
crow_indices.layout() == kStrided && crow_indices.is_contiguous(),
"expected crow_indices to be a strided and contiguous tensor");
TORCH_CHECK(
values.layout() == kStrided && values.is_contiguous(),
"expected values to be a strided and contiguous tensor");
// Shape and Strides invariants
TORCH_CHECK(
size.size() == 2,
"size of a CSR tensor must be of length 2, but got: ",
size.size());
TORCH_CHECK(
crow_indices.dim() == 1,
"crow_indices must have dim=1 but got crow_indices.dim()=",
crow_indices.dim());
TORCH_CHECK(
col_indices.dim() == 1,
"col_indices must have dim=1 but got col_indices.dim()=",
col_indices.dim());
TORCH_CHECK(
values.dim() == 1,
"values must have dim=1 but got values.dim()=",
values.dim());
// Note, this check also enforces `crow_indices.numel() >= 1`
TORCH_CHECK(
crow_indices.numel() == (size[0] + 1),
"crow_indices.numel() must be size(0) + 1, but got: ",
crow_indices.numel());
TORCH_CHECK(
col_indices.numel() == values.numel(),
"col_indices and values must have equal sizes, but got col_indices.numel(): ",
col_indices.numel(),
", values.numel(): ",
values.numel());
// Indices invariants
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
Tensor crow_indices_cpu = crow_indices.to(kCPU);
auto crow_indices_accessor = crow_indices_cpu.accessor<index_t, 1>();
TORCH_CHECK(
crow_indices_accessor[0] == 0, "0th value of crow_indices must be 0.");
TORCH_CHECK(
crow_indices_accessor[crow_indices.numel() - 1] == col_indices.numel(),
"last value of crow_indices should be equal to the length of col_indices.");
for (int i = 1; i <= size[0]; i++) {
TORCH_CHECK(
crow_indices_accessor[i - 1] <= crow_indices_accessor[i],
"at position i = ", i, ", this condition crow_indices[i - 1] <= crow_indices[i] fails");
}
if (col_indices.numel() > 0) {
TORCH_CHECK(0 <= col_indices.min().item<index_t>(), "col_indices.min() should be greater or equal to zero");
TORCH_CHECK(size[1] > col_indices.max().item<index_t>(), "size(1) should be greater than col_indices.max()");
}
});
// CSR Type Invariants
auto crow_indices_type = crow_indices.scalar_type();
auto col_indices_type = col_indices.scalar_type();
TORCH_CHECK(
crow_indices_type == col_indices_type,
"both crow_indices and col_indices should have the same type.");
TORCH_CHECK(
crow_indices_type == kInt || crow_indices_type == kLong,
"crow_indices and col_indices must be an int32 or int64 type, but got: ",
crow_indices_type);
// CSR Device Invariants
TORCH_CHECK(
col_indices.get_device() == crow_indices.get_device(),
"crow_indices and col_indices devices (",
crow_indices.get_device(),
", ",
col_indices.get_device(),
") must match");
TORCH_CHECK(
crow_indices.get_device() == values.get_device(),
"device of crow_indices (",
crow_indices.get_device(),
") must match device of values (",
values.get_device(),
")");
TORCH_CHECK(
values.device().type() == kCPU || values.device().type() == kCUDA,
"device type of values (",
values.device().type(),
") must be CPU or CUDA");
}
// Construction of CSR tensors.
SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
// TODO: remove this comment after enabling autograd support for CSR tensor
@ -22,10 +126,13 @@ SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
TORCH_INTERNAL_ASSERT(options.layout() == kSparseCsr);
DispatchKey dispatch_key;
TORCH_CHECK_NOT_IMPLEMENTED(
options.device().type() == kCPU || options.device().type() == kCUDA,
"Could not run '", "sparse_csr_tensor", "' from the '", options.device(), "' device.)");
if (options.device().is_cuda()) {
dispatch_key = DispatchKey::SparseCsrCUDA;
} else {
TORCH_INTERNAL_ASSERT(options.device().is_cpu());
dispatch_key = DispatchKey::SparseCsrCPU;
}
@ -33,6 +140,21 @@ SparseCsrTensor new_csr_tensor(const TensorOptions& options) {
DispatchKeySet(dispatch_key), options.dtype());
}
Tensor _sparse_csr_tensor_unsafe(const Tensor& crow_indices, const Tensor& col_indices,
const Tensor& values,
IntArrayRef size,
c10::optional<ScalarType> dtype,
c10::optional<Layout> layout,
c10::optional<Device> device,
c10::optional<bool> pin_memory) {
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
SparseCsrTensor self = new_csr_tensor(options);
get_sparse_csr_impl(self)->set_member_tensors(crow_indices, col_indices, values, size);
return self;
}
// TODO: This constructor should probably use an ATen abstract method in order
// to make autograd dispatch available for the CSR constructor. See the relevant
// note in native_functions.yaml.
@ -47,43 +169,18 @@ Tensor sparse_csr_tensor(
c10::optional<bool> pin_memory) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
TORCH_CHECK(
options.layout() == kSparseCsr,
"expected sparse CSR layout, but got layout ",
options.layout());
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
auto crow_indices_accessor = crow_indices.accessor<index_t, 1>();
TORCH_CHECK(
crow_indices_accessor[crow_indices.numel() - 1] <= col_indices.numel(),
"last value of crow_indices should be less than length of col_indices.");
TORCH_CHECK(
crow_indices_accessor[0] == 0, "0th value of crow_indices must be 0.");
});
at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size);
TORCH_CHECK(
crow_indices.dim() == 1,
"crow_indices must have dim=1 but got crow_indices.dim()=",
crow_indices.dim());
TORCH_CHECK(
col_indices.dim() == 1,
"col_indices must have dim=1 but got col_indices.dim()=",
col_indices.dim());
TORCH_CHECK(
values.dim() == 1,
"values must have dim=1 but got values.dim()=",
values.dim());
TORCH_CHECK(
(crow_indices.numel() - 1) == size[0],
"crow_indices.numel() must be size(0) + 1, but got: ",
crow_indices.numel());
SparseCsrTensor self = new_csr_tensor(options);
get_sparse_csr_impl(self)->resize_and_clear_(values.numel(), size);
get_sparse_csr_impl(self)->set_member_tensors(
crow_indices, col_indices, values);
return self;
return at::native::_sparse_csr_tensor_unsafe(
crow_indices,
col_indices,
values,
size,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
options.device_opt(),
options.pinned_memory_opt());
}
Tensor sparse_csr_tensor(
@ -96,37 +193,28 @@ Tensor sparse_csr_tensor(
c10::optional<bool> pin_memory) {
// See [Note: hacky wrapper removal for TensorOptions]
TensorOptions options = TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory);
TORCH_CHECK(
options.layout() == kSparseCsr,
"expected sparse CSR layout, but got layout ",
options.layout());
TORCH_CHECK(crow_indices.numel() >= 1, "expected crow_indices.numel() >= 1, but got ",
crow_indices.numel());
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
std::array<int64_t, 2> size;
if (col_indices.numel() > 0) {
size[0] = crow_indices.numel() - 1;
Tensor max_col_indices = std::get<0>(col_indices.max(0, false));
AT_DISPATCH_INDEX_TYPES(crow_indices.scalar_type(), "csr_construct_check", [&] {
auto crow_indices_accessor = crow_indices.accessor<index_t, 1>();
TORCH_CHECK(
crow_indices_accessor[crow_indices.numel() - 1] <= col_indices.numel(),
"last value of crow_indices should be less than length of col_indices.");
TORCH_CHECK(
crow_indices_accessor[0] == 0, "0th value of crow_indices must be 0.");
size[1] = *max_col_indices.data_ptr<index_t>() + 1;
AT_DISPATCH_INDEX_TYPES(col_indices.scalar_type(), "csr_construct_check", [&] {
size[0] = crow_indices.numel() - 1;
size[1] = col_indices.max().item<index_t>() + 1;
});
} else {
size[0] = 0;
size[1] = 0;
}
return at::sparse_csr_tensor(
crow_indices, col_indices, values, size, options);
at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, size);
return at::native::_sparse_csr_tensor_unsafe(
crow_indices,
col_indices,
values,
size,
optTypeMetaToScalarType(options.dtype_opt()),
options.layout_opt(),
options.device_opt(),
options.pinned_memory_opt());
}
// Access members of CSR tensors.

View File

@ -5,64 +5,56 @@
# values_shape: torch.Size([10])
########## torch.float32/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
col_indices=tensor([5, 1, 6, 5, 6, 4, 2, 5, 5, 9]),
values=tensor([ 0.5674, 0.1261, 0.5497, 0.6416, -0.4414, 0.3634,
-0.4327, 0.3135, -0.5225, 0.4626]), size=(10, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
tensor([0, 2, 4], dtype=torch.int32)
# _col_indices
tensor([5, 1, 6, 5, 6, 4, 2, 5, 5, 9])
tensor([0, 1, 0, 1], dtype=torch.int32)
# _values
tensor([ 0.5674, 0.1261, 0.5497, 0.6416, -0.4414, 0.3634, -0.4327, 0.3135,
-0.5225, 0.4626])
tensor([1., 2., 3., 4.])
########## torch.float64/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
col_indices=tensor([8, 2, 0, 4, 9, 2, 1, 9, 2, 2]),
values=tensor([ 0.3324, -0.3314, 0.5786, -0.3567, 0.0494, 0.3377,
0.6872, -0.1470, 0.9123, -0.8460]), size=(10, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
tensor([0, 2, 4], dtype=torch.int32)
# _col_indices
tensor([8, 2, 0, 4, 9, 2, 1, 9, 2, 2])
tensor([0, 1, 0, 1], dtype=torch.int32)
# _values
tensor([ 0.3324, -0.3314, 0.5786, -0.3567, 0.0494, 0.3377, 0.6872, -0.1470,
0.9123, -0.8460])
tensor([1., 2., 3., 4.], dtype=torch.float64)
########## torch.float32/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
col_indices=tensor([1, 5, 2, 1, 7, 4, 3, 0, 7, 6]),
values=tensor([ 0.5056, 0.7977, 0.3677, 0.5317, 0.8298, -0.2015,
-0.7799, -0.4918, -0.1335, -0.1099]), size=(10, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
tensor([0, 2, 4])
# _col_indices
tensor([1, 5, 2, 1, 7, 4, 3, 0, 7, 6])
tensor([0, 1, 0, 1])
# _values
tensor([ 0.5056, 0.7977, 0.3677, 0.5317, 0.8298, -0.2015, -0.7799, -0.4918,
-0.1335, -0.1099])
tensor([1., 2., 3., 4.])
########## torch.float64/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]),
col_indices=tensor([6, 3, 4, 8, 5, 1, 5, 6, 4, 2]),
values=tensor([-0.2544, -0.2462, -0.9784, 0.8910, 0.5322, -0.4732,
-0.6239, 0.0348, 0.5698, -0.7176]), size=(10, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
tensor([0, 2, 4])
# _col_indices
tensor([6, 3, 4, 8, 5, 1, 5, 6, 4, 2])
tensor([0, 1, 0, 1])
# _values
tensor([-0.2544, -0.2462, -0.9784, 0.8910, 0.5322, -0.4732, -0.6239, 0.0348,
0.5698, -0.7176])
tensor([1., 2., 3., 4.], dtype=torch.float64)
# shape: torch.Size([100, 10])
@ -72,100 +64,56 @@ tensor([-0.2544, -0.2462, -0.9784, 0.8910, 0.5322, -0.4732, -0.6239, 0.0348,
# values_shape: torch.Size([10])
########## torch.float32/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
col_indices=tensor([7, 4, 5, 2, 8, 8, 8, 8, 4, 4]),
values=tensor([ 0.0548, 0.2650, -0.8181, -0.5354, 0.4537, -0.7625,
-0.2098, 0.4398, 0.5190, 0.0622]), size=(100, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0])
tensor([0, 2, 4], dtype=torch.int32)
# _col_indices
tensor([7, 4, 5, 2, 8, 8, 8, 8, 4, 4])
tensor([0, 1, 0, 1], dtype=torch.int32)
# _values
tensor([ 0.0548, 0.2650, -0.8181, -0.5354, 0.4537, -0.7625, -0.2098, 0.4398,
0.5190, 0.0622])
tensor([1., 2., 3., 4.])
########## torch.float64/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
col_indices=tensor([8, 9, 2, 9, 3, 4, 9, 2, 6, 2]),
values=tensor([ 0.0069, -0.3837, -0.2516, -0.1406, 0.9457, 0.9479,
-0.0935, -0.3003, 0.4856, -0.0798]), size=(100, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0])
tensor([0, 2, 4], dtype=torch.int32)
# _col_indices
tensor([8, 9, 2, 9, 3, 4, 9, 2, 6, 2])
tensor([0, 1, 0, 1], dtype=torch.int32)
# _values
tensor([ 0.0069, -0.3837, -0.2516, -0.1406, 0.9457, 0.9479, -0.0935, -0.3003,
0.4856, -0.0798])
tensor([1., 2., 3., 4.], dtype=torch.float64)
########## torch.float32/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
col_indices=tensor([1, 2, 3, 2, 1, 2, 7, 4, 7, 6]),
values=tensor([ 0.5833, 0.0894, 0.2440, -0.6665, -0.2136, 0.6597,
0.4587, -0.2891, 0.1230, 0.7656]), size=(100, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0])
tensor([0, 2, 4])
# _col_indices
tensor([1, 2, 3, 2, 1, 2, 7, 4, 7, 6])
tensor([0, 1, 0, 1])
# _values
tensor([ 0.5833, 0.0894, 0.2440, -0.6665, -0.2136, 0.6597, 0.4587, -0.2891,
0.1230, 0.7656])
tensor([1., 2., 3., 4.])
########## torch.float64/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]),
col_indices=tensor([6, 1, 5, 9, 0, 8, 6, 1, 0, 9]),
values=tensor([-0.2178, 0.7886, 0.3778, 0.6779, -0.6440, 0.2883,
0.1788, 0.1743, 0.9286, 0.5536]), size=(100, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0])
tensor([0, 2, 4])
# _col_indices
tensor([6, 1, 5, 9, 0, 8, 6, 1, 0, 9])
tensor([0, 1, 0, 1])
# _values
tensor([-0.2178, 0.7886, 0.3778, 0.6779, -0.6440, 0.2883, 0.1788, 0.1743,
0.9286, 0.5536])
tensor([1., 2., 3., 4.], dtype=torch.float64)
# shape: torch.Size([1000, 10])
@ -175,62 +123,54 @@ tensor([-0.2178, 0.7886, 0.3778, 0.6779, -0.6440, 0.2883, 0.1788, 0.1743,
# values_shape: torch.Size([10])
########## torch.float32/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 0, 0, ..., 0, 0, 0]),
col_indices=tensor([5, 4, 4, 0, 5, 6, 8, 0, 2, 8]),
values=tensor([-0.2851, -0.7618, 0.9845, 0.7515, 0.4756, 0.9898,
-0.5324, -0.5695, -0.5853, -0.0484]), size=(1000, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 0, 0, ..., 0, 0, 0])
tensor([0, 2, 4], dtype=torch.int32)
# _col_indices
tensor([5, 4, 4, 0, 5, 6, 8, 0, 2, 8])
tensor([0, 1, 0, 1], dtype=torch.int32)
# _values
tensor([-0.2851, -0.7618, 0.9845, 0.7515, 0.4756, 0.9898, -0.5324, -0.5695,
-0.5853, -0.0484])
tensor([1., 2., 3., 4.])
########## torch.float64/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 0, 0, ..., 0, 0, 0]),
col_indices=tensor([3, 6, 2, 3, 7, 8, 6, 7, 7, 2]),
values=tensor([ 0.3105, -0.6785, -0.1184, -0.2653, 0.4315, 0.6985,
0.2432, -0.0908, -0.2561, 0.7840]), size=(1000, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 0, 0, ..., 0, 0, 0])
tensor([0, 2, 4], dtype=torch.int32)
# _col_indices
tensor([3, 6, 2, 3, 7, 8, 6, 7, 7, 2])
tensor([0, 1, 0, 1], dtype=torch.int32)
# _values
tensor([ 0.3105, -0.6785, -0.1184, -0.2653, 0.4315, 0.6985, 0.2432, -0.0908,
-0.2561, 0.7840])
tensor([1., 2., 3., 4.], dtype=torch.float64)
########## torch.float32/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 0, 0, ..., 0, 0, 0]),
col_indices=tensor([2, 3, 1, 1, 0, 2, 5, 9, 3, 0]),
values=tensor([ 0.3443, -0.2613, 0.1793, 0.5857, -0.9265, -0.9102,
-0.5984, 0.1220, -0.1854, 0.2155]), size=(1000, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 0, 0, ..., 0, 0, 0])
tensor([0, 2, 4])
# _col_indices
tensor([2, 3, 1, 1, 0, 2, 5, 9, 3, 0])
tensor([0, 1, 0, 1])
# _values
tensor([ 0.3443, -0.2613, 0.1793, 0.5857, -0.9265, -0.9102, -0.5984, 0.1220,
-0.1854, 0.2155])
tensor([1., 2., 3., 4.])
########## torch.float64/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 0, 0, ..., 0, 0, 0]),
col_indices=tensor([3, 7, 7, 9, 7, 7, 6, 6, 9, 2]),
values=tensor([ 0.3393, -0.9329, -0.8195, 0.5085, 0.4854, -0.9112,
0.7196, -0.1944, 0.7424, -0.5868]), size=(1000, 10),
nnz=10)
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 0, 0, ..., 0, 0, 0])
tensor([0, 2, 4])
# _col_indices
tensor([3, 7, 7, 9, 7, 7, 6, 6, 9, 2])
tensor([0, 1, 0, 1])
# _values
tensor([ 0.3393, -0.9329, -0.8195, 0.5085, 0.4854, -0.9112, 0.7196, -0.1944,
0.7424, -0.5868])
tensor([1., 2., 3., 4.], dtype=torch.float64)

View File

@ -0,0 +1,176 @@
# shape: torch.Size([10, 10])
# nnz: 10
# crow_indices shape: torch.Size([11])
# col_indices shape: torch.Size([10])
# values_shape: torch.Size([10])
########## torch.float32/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0', dtype=torch.int32)
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32)
# _values
tensor([1., 2., 3., 4.], device='cuda:0')
########## torch.float64/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0', dtype=torch.int32)
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32)
# _values
tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64)
########## torch.float32/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0')
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0')
# _values
tensor([1., 2., 3., 4.], device='cuda:0')
########## torch.float64/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0')
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0')
# _values
tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64)
# shape: torch.Size([100, 10])
# nnz: 10
# crow_indices shape: torch.Size([101])
# col_indices shape: torch.Size([10])
# values_shape: torch.Size([10])
########## torch.float32/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0', dtype=torch.int32)
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32)
# _values
tensor([1., 2., 3., 4.], device='cuda:0')
########## torch.float64/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0', dtype=torch.int32)
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32)
# _values
tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64)
########## torch.float32/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0')
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0')
# _values
tensor([1., 2., 3., 4.], device='cuda:0')
########## torch.float64/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0')
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0')
# _values
tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64)
# shape: torch.Size([1000, 10])
# nnz: 10
# crow_indices shape: torch.Size([1001])
# col_indices shape: torch.Size([10])
# values_shape: torch.Size([10])
########## torch.float32/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0', dtype=torch.int32)
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32)
# _values
tensor([1., 2., 3., 4.], device='cuda:0')
########## torch.float64/torch.int32 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0', dtype=torch.int32)
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0', dtype=torch.int32)
# _values
tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64)
########## torch.float32/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0')
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0')
# _values
tensor([1., 2., 3., 4.], device='cuda:0')
########## torch.float64/torch.int64 ##########
# sparse tensor
tensor(crow_indices=tensor([0, 2, 4]),
col_indices=tensor([0, 1, 0, 1]),
values=tensor([1., 2., 3., 4.]), device='cuda:0', size=(2, 2), nnz=4,
dtype=torch.float64, layout=torch.sparse_csr)
# _crow_indices
tensor([0, 2, 4], device='cuda:0')
# _col_indices
tensor([0, 1, 0, 1], device='cuda:0')
# _values
tensor([1., 2., 3., 4.], device='cuda:0', dtype=torch.float64)

View File

@ -2,10 +2,11 @@ import torch
import warnings
import unittest
import random
import itertools
from torch.testing._internal.common_utils import \
(IS_MACOS, IS_WINDOWS, TestCase, run_tests, load_tests, coalescedonoff)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, onlyCPU)
(instantiate_device_type_tests, dtypes, onlyCPU, onlyCUDA)
# load_tests from torch.testing._internal.common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
@ -18,7 +19,6 @@ class TestSparseCSR(TestCase):
self.assertEqual(str(torch.sparse_csr), 'torch.sparse_csr')
self.assertEqual(type(torch.sparse_csr), torch.layout)
@onlyCPU
@dtypes(torch.double)
def test_sparse_csr_constructor_shape_inference(self, device, dtype):
crow_indices = [0, 2, 4]
@ -32,7 +32,6 @@ class TestSparseCSR(TestCase):
self.assertEqual(dtype, sparse.dtype)
self.assertEqual(torch.device(device), sparse.device)
@onlyCPU
@dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False,
include_bfloat16=False, include_complex=False))
def test_sparse_csr_constructor(self, device, dtype):
@ -51,40 +50,178 @@ class TestSparseCSR(TestCase):
self.assertEqual(torch.tensor(col_indices, dtype=index_dtype), sparse.col_indices())
self.assertEqual(torch.tensor(values, dtype=dtype), sparse.values())
with self.assertRaises(RuntimeError):
torch.sparse_csr_tensor(crow_indices, torch.tensor(col_indices), values, size=(2, 10))
@dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False,
include_bfloat16=False, include_complex=False))
def test_sparse_csr_constructor_from_lists(self, device, dtype):
# without size
sparse = torch.sparse_csr_tensor([0, 2, 4],
[0, 1, 0, 1],
[1, 2, 3, 4],
dtype=dtype,
device=device)
@onlyCPU
@dtypes(torch.double)
def test_factory_size_check(self, device, dtype):
self.assertEqual((2, 2), sparse.shape)
self.assertEqual(4, sparse.numel())
self.assertEqual(torch.tensor([0, 2, 4], dtype=torch.int64, device=device), sparse.crow_indices())
self.assertEqual(torch.tensor([0, 1, 0, 1], dtype=torch.int64, device=device), sparse.col_indices())
self.assertEqual(torch.tensor([1, 2, 3, 4], dtype=dtype, device=device), sparse.values())
# with size
for sparse_csr_tensor in [torch.sparse_csr_tensor, torch._sparse_csr_tensor_unsafe]:
sparse = sparse_csr_tensor([0, 2, 4],
[0, 1, 0, 1],
[1, 2, 3, 4],
size=(2, 10),
dtype=dtype,
device=device)
self.assertEqual((2, 10), sparse.shape)
self.assertEqual(torch.tensor([0, 2, 4], dtype=torch.int64, device=device), sparse.crow_indices())
self.assertEqual(torch.tensor([0, 1, 0, 1], dtype=torch.int64, device=device), sparse.col_indices())
self.assertEqual(torch.tensor([1, 2, 3, 4], dtype=dtype, device=device), sparse.values())
def test_factory_type_invariants_check(self, device):
with self.assertRaisesRegex(RuntimeError, "both crow_indices and col_indices should have the same type."):
torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int64),
torch.tensor([0, 1, 0, 1], dtype=torch.int32),
torch.tensor([1, 2, 3, 4]),
device=device)
with self.assertRaisesRegex(RuntimeError, r"\"csr_construct_check\" not implemented for 'Short'"):
torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=torch.int16),
torch.tensor([0, 1, 0, 1], dtype=torch.int16),
torch.tensor([1, 2, 3, 4]),
device=device)
def test_factory_layout_invariants_check(self, device):
with self.assertRaisesRegex(RuntimeError, "expected values to be a strided and contiguous tensor"):
values = torch.tensor([1.], device=device).expand(4,)
torch.sparse_csr_tensor(torch.tensor([0, 2, 4], device=device),
torch.tensor([0, 1, 0, 1], device=device),
values)
with self.assertRaisesRegex(RuntimeError, "expected col_indices to be a strided and contiguous tensor"):
col_indices = torch.tensor([0], device=device).expand(4,)
torch.sparse_csr_tensor(torch.tensor([0, 2, 4]),
col_indices,
torch.tensor([1, 2, 3, 4]))
with self.assertRaisesRegex(RuntimeError, "expected crow_indices to be a strided and contiguous tensor"):
crow_indices = torch.arange(6, device=device)
torch.sparse_csr_tensor(crow_indices[::2],
torch.tensor([0, 1, 0, 1], device=device),
torch.tensor([1, 2, 3, 4]))
def test_factory_shape_invariants_check(self, device):
crow_indices = [0, 2, 4]
col_indices = [0, 1, 0, 1]
values = [1, 2, 3, 4]
size = (2, 10)
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), size,
dtype=dtype, device=device)
device=device)
with self.assertRaisesRegex(RuntimeError, r"size of a CSR tensor must be of length 2, but got: 3"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values),
size=(2, 10, 2),
device=device)
with self.assertRaisesRegex(RuntimeError, r"crow_indices must have dim\=1 but got crow_indices\.dim\(\)\=2"):
torch.sparse_csr_tensor(torch.tensor(crow_indices).repeat(2, 1),
torch.tensor(col_indices),
torch.tensor(values),
size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"col_indices must have dim\=1 but got col_indices\.dim\(\)\=2"):
torch.sparse_csr_tensor(torch.tensor(crow_indices),
torch.tensor(col_indices).repeat(2, 1),
torch.tensor(values),
size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"values must have dim\=1 but got values\.dim\(\)\=2"):
torch.sparse_csr_tensor(torch.tensor(crow_indices),
torch.tensor(col_indices),
torch.tensor(values).repeat(2, 1),
size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"crow_indices\.numel\(\) must be size\(0\) \+ 1, but got: 3"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor(values), (1, 1),
dtype=dtype, device=device)
device=device)
with self.assertRaisesRegex(RuntimeError, "0th value of crow_indices must be 0"):
torch.sparse_csr_tensor(torch.tensor([-1, -1, -1]), torch.tensor(col_indices), torch.tensor(values), size,
dtype=dtype, device=device)
with self.assertRaisesRegex(RuntimeError, "last value of crow_indices should be less than length of col_indices."):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 0, 0]), torch.tensor(values), size,
dtype=dtype, device=device)
with self.assertRaisesRegex(RuntimeError,
r"col_indices and values must have equal sizes, " +
r"but got col_indices\.size\(0\): 4, values\.size\(0\): 5"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor(col_indices), torch.tensor([0, 0, 0, 0, 0]),
size, dtype=dtype, device=device)
r"but got col_indices\.numel\(\): 3, values\.numel\(\): 4"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 1, 0]), torch.tensor(values), size,
device=device)
def test_factory_indices_invariants_check(self, device):
crow_indices = [0, 2, 4]
col_indices = [0, 1, 0, 1]
values = [1, 2, 3, 4]
size = (2, 10)
with self.assertRaisesRegex(RuntimeError, "0th value of crow_indices must be 0."):
torch.sparse_csr_tensor(torch.tensor([-1, 0, 4]), torch.tensor(col_indices), torch.tensor(values), size,
device=device)
with self.assertRaisesRegex(RuntimeError,
"last value of crow_indices should be equal to the length of col_indices."):
torch.sparse_csr_tensor(torch.tensor([0, 2, 5]), torch.tensor(col_indices), torch.tensor(values), size,
device=device)
with self.assertRaisesRegex(RuntimeError,
r"at position i \= 2," +
r" this condition crow_indices\[i - 1\] <\= crow_indices\[i\] fails"):
torch.sparse_csr_tensor(torch.tensor([0, 5, 4]), torch.tensor(col_indices), torch.tensor(values), size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"col_indices\.min\(\) should be greater or equal to zero"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, -1, 0, 1]), torch.tensor(values), size,
device=device)
with self.assertRaisesRegex(RuntimeError, r"size\(1\) should be greater than col_indices\.max\(\)"):
torch.sparse_csr_tensor(torch.tensor(crow_indices), torch.tensor([0, 11, 0, 1]), torch.tensor(values), size,
device=device)
@onlyCUDA
@dtypes(*torch.testing.get_all_dtypes(include_bool=False, include_half=False,
include_bfloat16=False, include_complex=False))
def test_factory_device_type_inference(self, device, dtype):
cpu_cuda = ('cpu', 'cuda')
cpu_cuda_none = cpu_cuda + (None,)
for crow_indices_device, col_indices_device, values_device, device in itertools.product(cpu_cuda,
cpu_cuda,
cpu_cuda,
cpu_cuda_none):
for index_dtype in [torch.int32, torch.int64]:
crow_indices = torch.tensor([0, 2, 4], dtype=index_dtype, device=crow_indices_device)
col_indices = torch.tensor([0, 1, 0, 1], dtype=index_dtype, device=col_indices_device)
values = torch.tensor([1, 2, 3, 4], dtype=dtype, device=values_device)
if device is None and (crow_indices_device != col_indices_device or
crow_indices_device != values_device):
with self.assertRaises(RuntimeError):
torch.sparse_csr_tensor(crow_indices,
col_indices,
values,
size=(2, 10),
device=device)
else:
t = torch.sparse_csr_tensor(crow_indices,
col_indices,
values,
size=(2, 10),
device=device)
should_be_cuda = (device == 'cuda' or (device is None and values_device == 'cuda'))
self.assertEqual(should_be_cuda, t.is_cuda)
t.crow_indices().dtype == index_dtype
t.col_indices().dtype == index_dtype
t.values().dtype == dtype
t.crow_indices().device == t.values().device
t.col_indices().device == t.values().device
@onlyCPU
@unittest.skip("see: https://github.com/pytorch/pytorch/issues/58762")
def test_sparse_csr_print(self, device):
orig_maxDiff = self.maxDiff
self.maxDiff = None
@ -106,7 +243,9 @@ class TestSparseCSR(TestCase):
for index_dtype in [torch.int32, torch.int64]:
for dtype in torch.testing.floating_types():
printed.append("########## {}/{} ##########".format(dtype, index_dtype))
x = self.genSparseCSRTensor(shape, nnz, device=device, dtype=torch.float32, index_dtype=torch.int64)
x = torch.sparse_csr_tensor(torch.tensor([0, 2, 4], dtype=index_dtype),
torch.tensor([0, 1, 0, 1], dtype=index_dtype),
torch.tensor([1, 2, 3, 4]), dtype=dtype, device=device)
printed.append("# sparse tensor")
printed.append(str(x))
printed.append("# _crow_indices")
@ -120,7 +259,6 @@ class TestSparseCSR(TestCase):
self.assertExpected('\n'.join(printed))
self.maxDiff = orig_maxDiff
@onlyCPU
def test_sparse_csr_from_dense(self, device):
dense = torch.tensor([[4, 5, 0], [0, 0, 0], [1, 0, 0]], device=device)
sparse = dense.to_sparse_csr()
@ -161,8 +299,8 @@ class TestSparseCSR(TestCase):
dense = torch.tensor([[1, 2, 1], [3, 4, 0]], dtype=dtype, device=device)
self.assertEqual(csr.to_dense(), dense)
@coalescedonoff
@onlyCPU
@coalescedonoff
@dtypes(torch.double)
def test_coo_to_csr_convert(self, device, dtype, coalesced):
size = (5, 5)

View File

@ -414,6 +414,14 @@ static PyObject * THPVariable_sparse_csr_tensor(PyObject* self, PyObject* args,
END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable__sparse_csr_tensor_unsafe(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
jit::tracer::warn("torch._sparse_csr_tensor_unsafe", jit::tracer::WARN_CONSTRUCTOR);
return THPVariable_Wrap(torch::utils::_sparse_csr_tensor_unsafe_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs));
END_HANDLE_TH_ERRORS
}
static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, PyObject* kwargs)
{
HANDLE_TH_ERRORS
@ -493,9 +501,11 @@ static PyMethodDef torch_functions[] = {
{"range", castPyCFunctionWithKeywords(THPVariable_range), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"saddmm", castPyCFunctionWithKeywords(THPVariable_sspaddmm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"sparse_coo_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_coo_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_sparse_coo_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_coo_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_validate_sparse_coo_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_coo_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"sparse_csr_tensor", castPyCFunctionWithKeywords(THPVariable_sparse_csr_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_sparse_csr_tensor_unsafe", castPyCFunctionWithKeywords(THPVariable__sparse_csr_tensor_unsafe), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"_validate_sparse_csr_tensor_args", castPyCFunctionWithKeywords(THPVariable__validate_sparse_csr_tensor_args), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"spmm", castPyCFunctionWithKeywords(THPVariable_mm), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"tensor", castPyCFunctionWithKeywords(THPVariable_tensor), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},
{"get_device", castPyCFunctionWithKeywords(THPVariable_get_device), METH_VARARGS | METH_KEYWORDS | METH_STATIC, NULL},

View File

@ -291,13 +291,19 @@ def gen_pyi(native_yaml_path: str, deprecated_yaml_path: str, fm: FileManager) -
'sparse_coo_tensor': ['def sparse_coo_tensor(indices: Tensor, values: Union[Tensor,List],'
' size: Optional[_size]=None, *, dtype: Optional[_dtype]=None,'
' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
'sparse_csr_tensor' : ['def sparse_csr_tensor(crow_indices: Tensor, col_indices: Tensor,'
' values: Tensor, size: Optional[_size]=None,'
'sparse_csr_tensor' : ['def sparse_csr_tensor(crow_indices: Union[Tensor, List],'
'col_indices: Union[Tensor, List],'
' values: Union[Tensor, List], size: Optional[_size]=None,'
' *, dtype: Optional[_dtype]=None,'
' device: Union[_device, str, None]=None, requires_grad:_bool=False) -> Tensor: ...'],
'_sparse_coo_tensor_unsafe': ['def _sparse_coo_tensor_unsafe(indices: Tensor, values: Tensor, size: List[int],'
' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
' requires_grad: bool = False) -> Tensor: ...'],
'_sparse_csr_tensor_unsafe': ['def _sparse_csr_tensor_unsafe(crow_indices: Union[Tensor, List],'
'col_indices: Union[Tensor, List],'
' values: Union[Tensor, List], size: List[int],'
' dtype: Optional[_dtype] = None, device: Optional[_device] = None,'
' requires_grad: bool = False) -> Tensor: ...'],
'range': ['def range(start: Number, end: Number,'
' step: Number=1, *, out: Optional[Tensor]=None, {}) -> Tensor: ...'
.format(FACTORY_PARAMS)],

View File

@ -953,10 +953,14 @@ class Tensor(torch._C._TensorBase):
while i < row_indices.size()[0] and row_indices[i] == irow:
i += 1
ro.append(i)
return torch.sparse_csr_tensor(torch.tensor(ro, dtype=row_indices.dtype),
coalesced_self.indices()[1], coalesced_self.values(),
size=coalesced_self.shape, dtype=coalesced_self.dtype)
device = coalesced_self.values().device
crow_indices = torch.tensor(ro, dtype=row_indices.dtype, device=device)
return torch.sparse_csr_tensor(crow_indices,
coalesced_self.indices()[1].contiguous(),
coalesced_self.values(),
size=coalesced_self.shape,
dtype=coalesced_self.dtype,
device=device)
elif self.is_sparse_csr:
return self
else:

View File

@ -605,6 +605,7 @@ Tensor indexing_tensor_from_data(
Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
static PythonArgParser parser({
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
"sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, *, ScalarType dtype=None, Layout? layout=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
@ -618,6 +619,10 @@ Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scal
// See https://github.com/pytorch/pytorch/issues/58520 for more details
auto rc = PyObject_GetAttrString(o, attr_name);
if (!rc) {
if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
throw python_error();
}
// Warning: a wrong attribute error may be suppressed here
PyErr_Clear();
}
return rc;
@ -642,11 +647,11 @@ Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scal
Tensor crow_indices = internal_new_from_data(values.options(),
crow_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG), r.pyobject(CROW_INDICES_ARG),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/false);
/*type_inference=*/true);
Tensor col_indices = internal_new_from_data(values.options(),
col_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG), r.pyobject(COL_INDICES_ARG),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/false);
/*type_inference=*/true);
return at::sparse_csr_tensor(crow_indices, col_indices, values, r.intlist(SIZE_ARRAY_ARG),
values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
@ -663,16 +668,54 @@ Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scal
Tensor crow_indices = internal_new_from_data(values.options(),
crow_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
r.pyobject(CROW_INDICES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/false);
/*type_inference=*/true);
Tensor col_indices = internal_new_from_data(values.options(), col_indices_scalar_type, r.deviceOptional(DEVICE_TYPE_ARG),
r.pyobject(COL_INDICES_ARG), /*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/false);
/*type_inference=*/true);
return at::sparse_csr_tensor(crow_indices, col_indices, values,
values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(REQ_GRAD_ARG));
}
throw std::runtime_error("sparse_csr_tensor(): invalid arguments");
}
Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
enum {
ARG_CROW_INDICES = 0,
ARG_COL_INDICES,
ARG_VALUES,
ARG_SIZE,
ARG_TYPE,
ARG_DEVICE,
ARG_REQUIRES_GRAD,
ARGS_COUNT
};
static PythonArgParser parser({
"_sparse_csr_tensor_unsafe(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
});
ParsedArgs<ARGS_COUNT> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
bool type_inference = r.isNone(ARG_TYPE);
const auto inferred_options = typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
const auto inferred_scalar_type = r.scalartypeWithDefault(ARG_TYPE, scalar_type);
at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
Tensor values = internal_new_from_data(inferred_options, inferred_scalar_type, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_VALUES),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/type_inference);
Tensor crow_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_CROW_INDICES),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/true);
Tensor col_indices = internal_new_from_data(values.options(), kInt, r.deviceOptional(ARG_DEVICE), r.pyobject(ARG_COL_INDICES),
/*copy_variables=*/false, /*copy_numpy=*/true,
/*type_inference=*/true);
return at::_sparse_csr_tensor_unsafe(crow_indices, col_indices, values, r.intlist(ARG_SIZE), values.options().layout(at::kSparseCsr)).set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
}
// Note [Ensuring sparse values and indices match devices]
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
// In all places where we construct indices, we read out options from values
@ -696,6 +739,7 @@ Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scal
Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
static PythonArgParser parser({
"sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
"sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
@ -742,6 +786,7 @@ Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scal
Tensor _sparse_coo_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
enum {
ARG_INDICES = 0,
ARG_VALUES,
@ -789,6 +834,29 @@ void _validate_sparse_coo_tensor_args(c10::DispatchKey dispatch_key, at::ScalarT
at::native::_validate_sparse_coo_tensor_args(indices, values, r.intlist(2));
}
void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
auto options = dispatchKeyToTensorOptions(dispatch_key);
static PythonArgParser parser({
"_validate_sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)",
});
ParsedArgs<4> parsed_args;
auto r = parser.parse(args, kwargs, parsed_args);
Tensor values = internal_new_from_data(
options, scalar_type, c10::nullopt, r.pyobject(2),
/*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true);
// See Note [Ensuring sparse values and indices match devices]
Tensor crow_indices = internal_new_from_data(
values.options(), kInt, c10::nullopt, r.pyobject(0),
/*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true);
Tensor col_indices = internal_new_from_data(
values.options(), kInt, c10::nullopt, r.pyobject(1),
/*copy_variables=*/false, /*copy_numpy=*/true, /*type_inference=*/true);
at::native::_validate_sparse_csr_tensor_args(crow_indices, col_indices, values, r.intlist(3));
}
Tensor tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) {
static PythonArgParser parser({
"tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)",

View File

@ -13,10 +13,12 @@ at::Tensor indexing_tensor_from_data(
at::ScalarType scalar_type,
c10::optional<at::Device> device,
PyObject* data);
at::Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor _sparse_coo_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
void _validate_sparse_coo_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor sparse_csr_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor _sparse_csr_tensor_unsafe_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
void _validate_sparse_csr_tensor_args(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor as_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);
at::Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs);