mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/75808 Just as it is often difficult to write a single kernel that can handle both CPU and CUDA, so can it be difficult to do the same for NestedTensor. ghstack-source-id: 154171542 (Note: this ignores all push blocking failures!) Test Plan: CI? Reviewed By: bdhirsh Differential Revision: D35603836 fbshipit-source-id: fb0ebb19d34531ed96ce176aca325f8e2b5f90e6 (cherry picked from commit 0bcd753f93c04256c1b745f84a74ecccf0dceef5)
296 lines
11 KiB
Python
296 lines
11 KiB
Python
# Owner(s): ["module: nestedtensor"]
|
|
|
|
import torch
|
|
import torch.nn
|
|
import unittest
|
|
from torch.testing._internal.common_device_type import (
|
|
dtypes,
|
|
dtypesIfCUDA,
|
|
instantiate_device_type_tests,
|
|
skipMeta,
|
|
)
|
|
from torch.testing._internal.common_utils import TestCase, IS_FBCODE
|
|
from torch import nested_tensor
|
|
|
|
# Tests are ported from pytorch/nestedtensor.
|
|
# This makes porting as_nested_tensor easier in the future.
|
|
def _iter_constructors():
|
|
# yield as_nested_tensor
|
|
yield nested_tensor
|
|
|
|
|
|
class TestNestedTensor(TestCase):
|
|
@torch.inference_mode()
|
|
def _test_unbind_case(self, a, b):
|
|
nt = nested_tensor([a, b])
|
|
a1, b1 = nt.unbind()
|
|
self.assertTrue(a is not a1)
|
|
self.assertTrue(b is not b1)
|
|
|
|
nt = nested_tensor([a, b], dtype=a.dtype)
|
|
a1, b1 = nt.unbind(0)
|
|
self.assertEqual(a, a1)
|
|
self.assertEqual(b, b1)
|
|
|
|
a = torch.randn((2, 3)).add_(1)
|
|
nt = nested_tensor([a])
|
|
self.assertEqual(a, nt.unbind(0)[0])
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_0(self):
|
|
self._test_unbind_case(
|
|
torch.tensor([1, 2]), torch.tensor([7, 8]),
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_1(self):
|
|
self._test_unbind_case(
|
|
torch.tensor([1]), torch.tensor([7]),
|
|
)
|
|
|
|
# @torch.inference_mode()
|
|
# def test_unbind_2(self):
|
|
# self._test_unbind_case(
|
|
# torch.tensor(1), torch.tensor(7),
|
|
# )
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_3(self):
|
|
self._test_unbind_case(
|
|
torch.tensor([1.0]), torch.tensor([]),
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_4(self):
|
|
self._test_unbind_case(
|
|
torch.tensor([]), torch.tensor([]),
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def test_unbind_dim(self):
|
|
def _test_fn(unbind_fn):
|
|
a = torch.rand(3, 2)
|
|
b = torch.rand(2, 3)
|
|
nt = nested_tensor([a, b])
|
|
self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))
|
|
|
|
# Both of these tests are necessary, because we're using
|
|
# torch_function.
|
|
_test_fn(lambda x, dim: x.unbind(dim))
|
|
# TODO: Re-enable this once using torch_dispatch
|
|
# _test_fn(lambda x, dim: torch.unbind(x, dim))
|
|
|
|
@torch.inference_mode()
|
|
def test_nested_tensor(self):
|
|
self.assertRaises(TypeError, lambda: nested_tensor([3.0]))
|
|
self.assertRaises(TypeError, lambda: nested_tensor(torch.tensor([3.0])))
|
|
self.assertRaises(TypeError, lambda: nested_tensor(4.0))
|
|
|
|
@torch.inference_mode()
|
|
def test_nested_tensor_matching_dim(self):
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
|
|
lambda: nested_tensor([torch.tensor(1.0), torch.tensor([])]),
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
|
|
lambda: nested_tensor(
|
|
[torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
|
|
),
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def test_default_nested_tensor(self):
|
|
self.assertRaises(TypeError, lambda: nested_tensor())
|
|
default_nested_tensor = nested_tensor([])
|
|
default_tensor = torch.tensor([])
|
|
# self.assertEqual(default_nested_tensor.nested_dim(), 1)
|
|
# self.assertEqual(default_nested_tensor.nested_size(), ())
|
|
self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
|
|
self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
|
|
self.assertEqual(default_nested_tensor.device, default_tensor.device)
|
|
self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
|
|
self.assertEqual(
|
|
default_nested_tensor.requires_grad, default_tensor.requires_grad
|
|
)
|
|
self.assertIsNone(default_tensor.grad)
|
|
# TODO: Re-enable once we have a performance driven
|
|
# use case and implementation.
|
|
# self.assertEqual(default_nested_tensor.is_pinned(),
|
|
# default_tensor.is_pinned())
|
|
|
|
@torch.inference_mode()
|
|
def test_dim(self):
|
|
for constructor in _iter_constructors():
|
|
a1 = constructor([])
|
|
self.assertEqual(a1.dim(), 1)
|
|
a1 = constructor([torch.tensor(3.0)])
|
|
self.assertEqual(a1.dim(), 1)
|
|
a1 = constructor([torch.tensor([1, 2, 3, 4])])
|
|
self.assertEqual(a1.dim(), 2)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
|
|
@torch.inference_mode()
|
|
def test_numel(self):
|
|
for constructor in _iter_constructors():
|
|
a1 = constructor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "numel is disabled", lambda: a1.numel(),
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def test_size(self):
|
|
for constructor in _iter_constructors():
|
|
a1 = constructor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Tensors of type NestedTensorImpl do not have sizes"
|
|
if IS_FBCODE
|
|
else "NestedTensorImpl doesn't support sizes",
|
|
lambda: a1.size(),
|
|
)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
|
|
@torch.inference_mode()
|
|
def test_stride(self):
|
|
for constructor in _iter_constructors():
|
|
a1 = constructor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"NestedTensorImpl doesn't support strides",
|
|
lambda: a1.stride(),
|
|
)
|
|
|
|
@unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
|
|
@torch.inference_mode()
|
|
def test_is_contiguous(self):
|
|
for constructor in _iter_constructors():
|
|
a1 = constructor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "is_contiguous is disabled", lambda: a1.is_contiguous()
|
|
)
|
|
|
|
@torch.inference_mode()
|
|
def test_repr_string(self):
|
|
a = nested_tensor([])
|
|
expected = "nested_tensor([" "\n\n])"
|
|
self.assertEqual(str(a), expected)
|
|
self.assertEqual(repr(a), expected)
|
|
|
|
a = nested_tensor([torch.tensor(1.0)])
|
|
expected = "nested_tensor([" "\n tensor(1.)" "\n])"
|
|
self.assertEqual(str(a), expected)
|
|
self.assertEqual(repr(a), expected)
|
|
|
|
a = nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
|
|
expected = (
|
|
"nested_tensor([" "\n tensor([[1, 2]])" "," "\n tensor([[4, 5]])" "\n])"
|
|
)
|
|
self.assertEqual(str(a), expected)
|
|
self.assertEqual(repr(a), expected)
|
|
|
|
@torch.inference_mode()
|
|
def test_activations(self):
|
|
for func in (torch.nn.functional.relu, torch.nn.functional.relu_, torch.nn.functional.gelu, torch._C._nn.gelu_):
|
|
t = torch.tensor([-1, 0, 1], dtype=torch.float)
|
|
nt = nested_tensor([t])
|
|
nested_result = func(nt)
|
|
self.assertTrue(nested_result.is_nested)
|
|
self.assertEqual(func(t), nested_result.unbind()[0])
|
|
|
|
def test_to_padded_tensor_on_empty_tensor(self):
|
|
nt = torch.nested_tensor([])
|
|
empty = nt.to_padded_tensor(4)
|
|
self.assertEqual(empty, torch.tensor([]))
|
|
|
|
class TestNestedTensorDeviceType(TestCase):
|
|
@dtypes(torch.float)
|
|
@skipMeta
|
|
def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
|
|
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
|
|
ts = list(torch.unbind(t))
|
|
ts[0] = ts[0][:-1]
|
|
nt = torch.nested_tensor(ts, device=device, dtype=dtype)
|
|
padded = nt.to_padded_tensor(0)
|
|
|
|
nt_to = torch._nested_from_padded_and_nested_example(padded, nt)
|
|
|
|
for (t1, t2) in zip(nt.unbind(), nt_to.unbind()):
|
|
self.assertEqual(t1, t2)
|
|
self.assertEqual(nt.device, nt_to.device)
|
|
|
|
@dtypes(torch.float)
|
|
@dtypesIfCUDA(torch.float, torch.half)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_layer_norm(self, device, dtype):
|
|
def _test(size):
|
|
t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
|
|
t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
|
|
ts = [t0, t1, t0, t1]
|
|
nt = torch.nested_tensor(ts, device=device, dtype=dtype)
|
|
layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
|
|
nt_result = nt._nested_tensor_layer_norm(
|
|
layer_norm.weight, layer_norm.bias, 1e-5
|
|
)
|
|
for (nt_subresult, t) in zip(nt_result.unbind(), ts):
|
|
t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
|
|
self.assertEqual(nt_subresult, t_result)
|
|
|
|
for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
|
|
_test(size)
|
|
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_embedding(self, device):
|
|
inputs = [
|
|
torch.randint(100, (L,), device=device, dtype=torch.int64)
|
|
for L in torch.randint(5, 50, (8,))
|
|
]
|
|
x = torch.nested_tensor(inputs, device=device, dtype=torch.int64)
|
|
emb = torch.nn.Embedding(100, 8, device=device)
|
|
y = emb(x)
|
|
ys = y.unbind()
|
|
for i, inp in enumerate(inputs):
|
|
self.assertEqual(emb(inp), ys[i])
|
|
|
|
def test_to_padded_tensor_simple(self, device):
|
|
t = torch.randn(4, 4, 4, device=device)
|
|
ts = list(torch.unbind(t))
|
|
ts[0] = ts[0][:-1]
|
|
nt = torch.nested_tensor(ts, device=device)
|
|
for padding_value in (0, 1):
|
|
padded = nt.to_padded_tensor(padding_value)
|
|
|
|
correct_output = t.clone()
|
|
if padding_value == 0:
|
|
correct_output[0][-1] = torch.zeros_like(correct_output[0][-1])
|
|
else:
|
|
correct_output[0][-1] = torch.ones_like(correct_output[0][-1])
|
|
|
|
self.assertEqual(padded, correct_output)
|
|
self.assertEqual(padded.device, torch.device(device))
|
|
|
|
def test_to_padded_tensor_unrelated_shapes(self, device):
|
|
ts = [
|
|
torch.randn(1, 2, 3, device=device),
|
|
torch.randn(2, 3, 4, device=device),
|
|
torch.randn(4, 5, 6, device=device),
|
|
]
|
|
nt = torch.nested_tensor(ts, device=device)
|
|
pad = 42
|
|
correct_output = torch.cat(
|
|
[torch.nn.ConstantPad3d((0, 6 - x.shape[2], 0, 5 - x.shape[1], 0, 4 - x.shape[0]), pad)(x.unsqueeze(0)) for x in ts])
|
|
padded = nt.to_padded_tensor(pad)
|
|
self.assertEqual(padded, correct_output)
|
|
|
|
@skipMeta
|
|
def test_device_checks(self, device):
|
|
nt = torch.nested_tensor([], device=device)
|
|
is_cuda = 'cuda' in str(device)
|
|
self.assertEqual(nt.is_cuda, is_cuda)
|
|
|
|
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
|