# -*- coding: utf-8 -*-
# Owner(s): ["module: scatter & gather ops"]

import random

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import \
    (run_tests, TestCase,)
from torch.testing._internal.common_device_type import \
    (instantiate_device_type_tests, dtypes, dtypesIfCUDA,
     toleranceOverride, tol)

# Protects against includes accidentally setting the default dtype
assert torch.get_default_dtype() is torch.float32


# Note: test_scatter_gather_ops.py
# This test file tests scatter and gather operations,
#   like torch.scatter and torch.gather.
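
# For reference, the loop-based checks below mirror the indexing rules in the
#   torch.gather documentation. For a 3-D input:
#     out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
#     out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
#     out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
#   scatter_ writes in the opposite direction, e.g. for dim == 0:
#     self[index[i][j][k]][j][k] = src[i][j][k]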

class TestScatterGather(TestCase):
    # Fills an index tensor with valid indices
    def _fill_indices(self, idx, dim, dim_size, elems_per_row, m, n, o):
        for i in range(1 if dim == 0 else m):
            for j in range(1 if dim == 1 else n):
                for k in range(1 if dim == 2 else o):
                    ii = [i, j, k]
                    ii[dim] = slice(0, idx.size(dim) + 1)
                    idx[tuple(ii)] = torch.randperm(dim_size)[0:elems_per_row]
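    # Note: each line along dim is filled from torch.randperm, i.e. sampled
    #   without replacement, so indices never repeat along the scatter/gather
    #   dimension and the scatter checks below do not depend on the order in
    #   which colliding writes would land.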

    @dtypes(torch.float32, torch.complex64)
    def test_gather(self, device, dtype):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        src = make_tensor((m, n, o), device=device, dtype=dtype)
        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = make_tensor(idx_size, device=device, dtype=torch.long)
        self._fill_indices(idx, dim, src.size(dim), elems_per_row, m, n, o)

        actual = torch.gather(src, dim, idx)
        expected = torch.zeros(idx_size, device=device, dtype=dtype)
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    expected[i, j, k] = src[tuple(ii)]
        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Guarded because torch.max isn't defined for complex types
        if not dtype.is_complex:
            src = make_tensor((3, 4, 5), device=device, dtype=dtype)
            expected, idx = src.max(2, True)
            actual = torch.gather(src, 2, idx)
            self.assertEqual(actual, expected, atol=0, rtol=0)
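            # The check above relies on max(2, True) returning both the max
            #   values and their indices along dim 2, so gathering with those
            #   indices must reproduce the max values exactly.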

    @dtypes(torch.bool)
    def test_gather_bool(self, device, dtype):
        src = torch.tensor(((False, True), (True, True)), device=device, dtype=dtype)
        idx = torch.tensor(((0, 0), (1, 0)), device=device, dtype=torch.long)
        actual = torch.gather(src, 1, idx)
        expected = torch.tensor(((False, False), (True, True)), device=device, dtype=dtype)
        self.assertEqual(actual, expected, atol=0, rtol=0)

    def _test_scatter_base(self, fn, *, device, dtype, is_scalar, reduction):
        m, n, o = random.randint(10, 20), random.randint(10, 20), random.randint(10, 20)
        elems_per_row = random.randint(1, 10)
        dim = random.randrange(3)

        idx_size = [m, n, o]
        idx_size[dim] = elems_per_row
        idx = torch.empty(tuple(idx_size), device=device, dtype=torch.long)
        self._fill_indices(idx, dim, ([m, n, o])[dim], elems_per_row, m, n, o)

        if is_scalar:
            src = random.random()
        else:
            src_size = [random.randint(1, 5) + s for s in idx_size]
            src = make_tensor(tuple(src_size), device=device, dtype=dtype)

        base = make_tensor((m, n, o), device=device, dtype=dtype)
        if reduction is not None:
            actual = fn(base.clone(), dim, idx, src, reduce=reduction)
        else:
            actual = fn(base.clone(), dim, idx, src)

        expected = base.clone()
        for i in range(idx_size[0]):
            for j in range(idx_size[1]):
                for k in range(idx_size[2]):
                    ii = [i, j, k]
                    ii[dim] = idx[i, j, k]
                    if fn is torch.Tensor.scatter_add_:
                        expected[tuple(ii)] += src[i, j, k]
                    else:
                        # method may be 'scatter_' or 'scatter'
                        #   both might have a reduction argument
                        value = src if is_scalar else src[i, j, k]

                        if reduction == "add":
                            expected[tuple(ii)] += value
                        elif reduction == "multiply":
                            expected[tuple(ii)] *= value
                        else:
                            expected[tuple(ii)] = value

        self.assertEqual(actual, expected, atol=0, rtol=0)

        # Tests empty index
        dst = make_tensor((2, 2), device=device, dtype=dtype)
        idx = torch.tensor((), device=device, dtype=torch.long)
        src = make_tensor((2, 2), device=device, dtype=dtype)
        if reduction is not None:
            actual = fn(dst, 0, idx, src, reduce=reduction)
        else:
            actual = fn(dst, 0, idx, src)
        self.assertEqual(actual, dst, atol=0, rtol=0)
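        # An empty index tensor means there is nothing to scatter, so the
        #   destination must come back unchanged regardless of the reduction,
        #   which is what the empty-index check above asserts.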

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                is_scalar=False, reduction=None)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__scalar(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                is_scalar=True, reduction=None)

    # FIXME: RuntimeError: "cuda_scatter_gather_base_kernel_reduce_multiply" not implemented for 'ComplexFloat'
    @toleranceOverride({torch.float16: tol(atol=1e-2, rtol=0)})
    @dtypesIfCUDA(torch.float16, torch.float32)
    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter__reductions(self, device, dtype):
        for reduction in ("add", "multiply"):
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=False, reduction=reduction)
            self._test_scatter_base(torch.Tensor.scatter_, device=device, dtype=dtype,
                                    is_scalar=True, reduction=reduction)

    @dtypes(torch.float16, torch.float32, torch.complex64)
    def test_scatter_add_(self, device, dtype):
        self._test_scatter_base(torch.Tensor.scatter_add_, device=device, dtype=dtype,
                                is_scalar=False, reduction=None)

    @dtypes(torch.float32)
    def test_scatter_add_mult_index_base(self, device, dtype):
        m, n = 30, 40
        idx = torch.zeros(m, n, device=device, dtype=torch.long)
        src = torch.ones(m, n, device=device, dtype=dtype)
        res0 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(0, idx, src)
        res1 = torch.zeros(m, n, device=device, dtype=dtype).scatter_add_(1, idx, src)

        self.assertEqual(res0[0, :], m * torch.ones(n, device=device, dtype=dtype), atol=0, rtol=0)
        self.assertEqual(res1[:, 0], n * torch.ones(m, device=device, dtype=dtype), atol=0, rtol=0)
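        # All indices are zero, so every source element accumulates into
        #   position 0 along the scatter dim: row 0 of res0 collects the m ones
        #   in each column, and column 0 of res1 collects the n ones in each row.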


# Generic Device Test Framework instantiation, see
#   https://github.com/pytorch/pytorch/wiki/Running-and-writing-tests
#   for details.
instantiate_device_type_tests(TestScatterGather, globals())
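# Note: instantiate_device_type_tests replaces the generic TestScatterGather
#   class in this scope with per-device variants (e.g. TestScatterGatherCPU
#   and TestScatterGatherCUDA), so run_tests() discovers a concrete test
#   class for each available device type.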

if __name__ == '__main__':
    run_tests()