Support broadcast_to on sparse COO tensors (#71073)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71073

cc nikitaved pearu cpuhrsch

Test Plan: Imported from OSS

Reviewed By: mikaylagawarecki

Differential Revision: D33645744

Pulled By: cpuhrsch

fbshipit-source-id: 4775c9636c4e868022a8c1bbfec93e351d1cf885
This commit is contained in: parent d1e72b144a, commit 640f21e09a
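For context, here is a minimal usage sketch of the operator this commit adds; the tensor shapes and values are illustrative and not taken from the commit itself:

import torch

# A (3, 1) sparse COO tensor with two nonzeros; dim 1 has size 1, so it can be broadcast.
i = torch.tensor([[0, 1], [0, 0]])
v = torch.tensor([1.0, 2.0])
s = torch.sparse_coo_tensor(i, v, (3, 1))

# The new op broadcasts a sparse COO tensor to a larger shape, mirroring dense broadcast_to.
b = torch._sparse_broadcast_to(s, (3, 4))
assert b.shape == (3, 4)
assert torch.equal(b.to_dense(), s.to_dense().broadcast_to((3, 4)))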
@@ -203,6 +203,11 @@ Tensor FunctionalInverses::values_inverse(const Tensor& base, const Tensor& muta
  return Tensor();
}

Tensor FunctionalInverses::_sparse_broadcast_to_inverse(const Tensor& base, const Tensor& mutated_view, at::IntArrayRef size) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call _sparse_broadcast_to() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();
}

Tensor FunctionalInverses::crow_indices_inverse(const at::Tensor& base, const at::Tensor& mutated_view) {
  TORCH_INTERNAL_ASSERT(false, "Attempted to call crow_indices() during the functionalization pass. For now, sparse tensors aren't supported during functionalization");
  return Tensor();

@@ -87,6 +87,82 @@ Tensor& set_cpu_(Tensor& result) {
  return result;
}

Tensor sparse_broadcast_to(const Tensor& self, IntArrayRef size) {
  TORCH_CHECK(self.is_sparse(), "input must be sparse tensor");
  int64_t sparse_extra_ndim = size.size() - self.dim();
  int64_t sparse_ndim = size.size() - self.dense_dim();
  TORCH_CHECK(sparse_extra_ndim >= 0, "input not broadcastable to size with smaller dimensionality");
  Tensor indices = self._indices();
  Tensor values = self._values();
  auto nnz = values.size(0);

  std::vector<int64_t> broadcast_sizes;
  std::vector<int64_t> broadcast_dense_sizes;
  std::vector<int64_t> broadcast_dims;
  std::vector<int64_t> unchanged_dims;
  broadcast_sizes.reserve(sparse_ndim);
  broadcast_dense_sizes.reserve(self.dense_dim() + 1);
  broadcast_dims.reserve(self.sparse_dim());
  unchanged_dims.reserve(self.sparse_dim());
  int64_t nnz_factor = 1;
  int64_t min_broadcast_dim = (sparse_extra_ndim > 0 ? 0 : -1);
  int64_t max_unchanged_dim = -1;
  for (int64_t i = 0; i < sparse_extra_ndim; i++) {
    auto d = size[i];
    nnz_factor *= d;
    broadcast_sizes.emplace_back(d);
  }
  for (int64_t i = 0; i < self.sparse_dim(); i++) {
    auto d = size[sparse_extra_ndim + i];
    if (self.size(i) != d) {
      TORCH_CHECK(self.size(i) == 1,
                  "The expanded size of the tensor (", size[sparse_extra_ndim + i], ") ",
                  "must match the existing size (", self.size(i), ")");
      nnz_factor *= d;
      broadcast_sizes.emplace_back(d);
      if (min_broadcast_dim == -1) {
        min_broadcast_dim = sparse_extra_ndim + i;
      }
      broadcast_dims.emplace_back(i);
    } else {
      unchanged_dims.emplace_back(i);
      max_unchanged_dim = sparse_extra_ndim + i;
    }
  }
  // sparse_broadcast_to preserves the is_coalesced property iff only the last
  // sparse dimensions are expanded. A possible expansion of dense dimensions
  // can be ignored as it does not affect the is_coalesced property.
  bool is_coalesced = self.dim() == 0 || (self.is_coalesced() && (max_unchanged_dim < min_broadcast_dim || min_broadcast_dim == -1));

  broadcast_dense_sizes.emplace_back(nnz);
  for (int64_t i = 0; i < self.dense_dim(); i++) {
    broadcast_dense_sizes.emplace_back(size[sparse_extra_ndim + self.sparse_dim() + i]);
  }

  std::vector<int64_t> new_indices_size{sparse_ndim, nnz * nnz_factor};
  std::vector<int64_t> new_values_size(values.sizes().vec());
  new_values_size[0] = new_indices_size[1];

  Tensor new_values = values.expand(broadcast_dense_sizes).repeat_interleave(nnz_factor, 0);
  Tensor new_indices = at::native::new_empty(indices, new_indices_size);
  if (broadcast_sizes.size() > 0) {
    // ones(broadcast_sizes).nonzero() is equivalent to
    // product(map(arange, broadcast_sizes)) but avoids creating
    // auxiliary arange tensors
    Tensor broadcast_indices = at::native::new_ones(indices, broadcast_sizes).nonzero().transpose(0, 1).tile(nnz);
    new_indices.narrow(0, 0, sparse_extra_ndim).copy_(broadcast_indices.narrow(0, 0, sparse_extra_ndim));
    for (size_t i = 0; i < broadcast_dims.size(); i++) {
      int64_t j = broadcast_dims[i];
      new_indices.select(0, sparse_extra_ndim + j).copy_(broadcast_indices.select(0, sparse_extra_ndim + i));
    }
  }
  for (int64_t j : unchanged_dims) {
    new_indices.select(0, sparse_extra_ndim + j).copy_(indices.select(0, j).repeat_interleave(nnz_factor));
  }
  return at::sparse_coo_tensor(new_indices, new_values, size)._coalesced_(is_coalesced);
}

Tensor broadcast_to(const Tensor& self, IntArrayRef size) {
  return self.expand(size);
}

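The index construction above relies on nonzero() of an all-ones tensor enumerating every coordinate of the broadcast shape in row-major order, i.e. the Cartesian product of arange(d) over the broadcast sizes. A small Python sketch of that equivalence, with hypothetical sizes:

import itertools
import torch

broadcast_sizes = [2, 3]  # hypothetical sizes of the broadcast dimensions

# Coordinates via the trick used in sparse_broadcast_to: nonzero() of an all-ones tensor.
coords_nonzero = torch.ones(broadcast_sizes).nonzero()

# The same coordinates via an explicit Cartesian product of aranges.
coords_product = torch.tensor(list(itertools.product(*(range(d) for d in broadcast_sizes))))

assert torch.equal(coords_nonzero, coords_product)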
@@ -1070,6 +1070,11 @@
- func: broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
  variants: function, method

- func: _sparse_broadcast_to(Tensor(a) self, int[] size) -> Tensor(a)
  variants: function
  dispatch:
    SparseCPU, SparseCUDA: sparse_broadcast_to

- func: cat(Tensor[] tensors, int dim=0) -> Tensor
  dispatch:
    CompositeExplicitAutograd: cat

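Since the schema registers only a function variant with SparseCPU/SparseCUDA kernels, the op is reachable from the torch namespace and only for sparse inputs. A rough sketch of the resulting surface, inferred from the yaml entry rather than spelled out in the commit:

import torch

s = torch.zeros(3, 1).to_sparse()

# Dispatches to sparse_broadcast_to via the SparseCPU key.
torch._sparse_broadcast_to(s, (3, 4))

# No Tensor method is generated ("variants: function" only), so this would raise AttributeError:
# s._sparse_broadcast_to((3, 4))

# No dense kernel is registered, so a dense input would be rejected by the dispatcher:
# torch._sparse_broadcast_to(torch.zeros(3, 1), (3, 4))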
@@ -3359,6 +3359,51 @@ class TestSparse(TestCase):
        with self.assertRaisesRegex(NotImplementedError, "CUDA"):
            t23 * s

    @dtypes(torch.double, torch.cdouble)
    def test_full_broadcast_to(self, device, dtype):
        def can_broadcast(s0, s1):
            s0 = tuple(reversed(s0))
            s1 = tuple(reversed(s1))
            for i in range(len(s0)):
                if s0[i] != 1 and s0[i] != s1[i]:
                    return False
            return True
        sizes = (
            (), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)
        )
        for s0, s1 in itertools.combinations(sizes, r=2):
            t = make_tensor(s0, device, dtype, low=-9, high=9)
            for sparse_dims in range(1, len(s0) + 1):
                s = t.to_sparse(sparse_dims)
                if can_broadcast(s0, s1):
                    t_res = torch.broadcast_to(t, s1)
                    s_res = torch._sparse_broadcast_to(s, s1)
                    torch._validate_sparse_coo_tensor_args(s_res._indices(), s_res._values(), s_res.shape)
                    if s_res.is_coalesced():
                        # ensure that is_coalesced is estimated correctly
                        self.assertEqual(s_res, torch.sparse_coo_tensor(s_res._indices(), s_res._values(), s_res.shape).coalesce())
                    self.assertEqual(s_res.to_dense(), t_res)
                else:
                    with self.assertRaisesRegex(RuntimeError,
                                                r"The expanded size of the tensor \(\d\) "
                                                r"must match the existing size \(\d\)"):
                        torch._sparse_broadcast_to(s, s1)

    @coalescedonoff
    @dtypes(torch.double, torch.cdouble)
    def test_sparse_broadcast_to(self, device, dtype, coalesced):
        def test(sparse_dims, nnz, with_size, new_size):
            x = self._gen_sparse(sparse_dims, nnz, with_size, dtype, device, coalesced)[0]
            y = self.safeToDense(x)
            x1 = torch._sparse_broadcast_to(x, new_size)
            y1 = y.broadcast_to(new_size)
            self.assertEqual(self.safeToDense(x1), y1)

        test(4, 6, [7, 3, 1, 3, 0], [7, 3, 4, 3, 0])
        test(4, 6, [7, 3, 1, 3, 0], [2, 7, 3, 1, 3, 0])
        test(4, 6, [7, 3, 1, 3, 1, 3], [7, 3, 1, 3, 2, 3])
        test(4, 6, [7, 3, 1, 3, 2, 1], [7, 3, 1, 3, 2, 3])


class TestSparseOneOff(TestCase):
    @unittest.skipIf(not TEST_CUDA, 'CUDA not available')