# Owner(s): ["module: nestedtensor"]

import io
import itertools
import sys
from typing import Optional, Tuple
import unittest
from functools import partial

import numpy as np
import torch
import torch.nn
import torch.nn.functional as F
from torch.testing._internal.common_cuda import SM80OrLater
from torch.testing._internal.common_device_type import (
    dtypes,
    dtypesIfCUDA,
    instantiate_device_type_tests,
    onlyCPU,
    onlyCUDA,
    skipMeta,
    PYTORCH_CUDA_MEMCHECK,
)
from torch.testing._internal.common_dtype import floating_types_and_half
from torch.testing._internal.common_utils import (
    decorateIf,
    freeze_rng_state,
    gradcheck,
    instantiate_parametrized_tests,
    IS_FBCODE,
    parametrize,
    run_tests,
    skipIfSlowGradcheckEnv,
    markDynamoStrictTest,
    xfailIfTorchDynamo,
    subtest,
    TEST_WITH_ROCM,
    TestCase,
)

from torch.nested._internal.nested_tensor import (
    buffer_from_jagged,
    jagged_from_list,
    NestedTensor,
)

# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.


def _iter_constructors():
    # yield as_nested_tensor
    yield torch.nested.nested_tensor

# Helper function to generate a pair of random nested tensors
# one is contiguous, the other is not, but they appear to have the same entries
# an output nested tensor consists of
# * `len(ragged_sizes)` matrices
# * matrices[i].shape == (20, ragged_sizes[i])


def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
    xs = []
    for size in ragged_sizes:
        xs.append(torch.randn((size, 20), device=device, dtype=dtype))
    # contiguous nested tensor
    ys = []
    for x in xs:
        ys.append(x.transpose(-1, -2))
    nt_contiguous = torch.nested.nested_tensor(ys)
    # noncontiguous nested tensor
    n = len(ragged_sizes)
    nt_noncontiguous = torch.nested.nested_tensor(xs).transpose(-1, -2)
    return nt_contiguous, nt_noncontiguous
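
# Illustrative sketch of the contract of random_nt_noncontiguous_pair above
# (not executed by the tests; it follows directly from the helper's definition):
#
#   nt_c, nt_nc = random_nt_noncontiguous_pair((2, 3))
#   nt_c.is_contiguous()   # True  -- built directly from transposed copies
#   nt_nc.is_contiguous()  # False -- a transposed view of a contiguous NT
#   # components hold the same values: nt_c[i] equals nt_nc[i], with shape (20, ragged_sizes[i])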

# Helper function to pad a noncontiguous nested tensor
# can be replaced once to_padded_tensor supports noncontiguous memory


def noncontiguous_to_padded_tensor(input, shape=None):
    tensors = input.unbind()
    ntensors = len(tensors)
    assert ntensors > 0
    if shape is None:
        shape = []
        for size in tensors[0].shape:
            shape.append(size)
        for i in range(1, ntensors):
            new_shape = tensors[i].shape
            for j in range(len(shape)):
                shape[j] = max(shape[j], new_shape[j])
        shape = [ntensors] + shape
    result = tensors[0].new_zeros(shape)
    for itensor in range(ntensors):
        tensor = tensors[itensor]
        view = result[itensor]
        for idim in range(tensor.dim()):
            view = view.narrow(idim, 0, tensor.size(idim))
        view.copy_(tensor)
    return result
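
# Illustrative sketch of noncontiguous_to_padded_tensor above (not executed by
# the tests): each component is copied into the top-left corner of a zero-filled
# dense tensor whose trailing sizes are the per-dim maxima over all components.
#
#   nt = torch.nested.nested_tensor([torch.ones(2, 3), torch.ones(4, 1)])
#   noncontiguous_to_padded_tensor(nt).shape  # torch.Size([2, 4, 3])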

# Helper function to generate a random nested tensor


def random_nt(device, dtype, num_tensors, max_dims, min_dims=None, layout=torch.strided, require_non_empty=True):
    if min_dims is None:
        min_dims = tuple([0] * len(max_dims))

    assert len(max_dims) == len(min_dims)
    for min_dim, max_dim in zip(min_dims, max_dims):
        assert max_dim > min_dim, "random_nt: max_dim must be greater than min_dim"
        assert min_dim >= 0, "random_nt: min_dim must be non-negative"
        if require_non_empty:
            assert not (min_dim == 0 and max_dim == 1), (
                "random_nt: zero cannot be the only possible value if require_non_empty is True"
            )

    if require_non_empty:
        # Select a random idx that will be required to be non-empty
        non_zero_idx = torch.randint(low=0, high=num_tensors, size=(1,)).item()

    ts1 = []
    for i, _ in enumerate(range(num_tensors)):
        tensor_dims = []
        for min_dim, max_dim in zip(min_dims, max_dims):
            new_min_dim = min_dim
            if require_non_empty and i == non_zero_idx and min_dim == 0:
                new_min_dim = 1
            tensor_dims.append(torch.randint(low=new_min_dim, high=max_dim, size=(1,)).item())
        t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
        ts1.append(t1)

    return torch.nested.nested_tensor(ts1, device=device, dtype=dtype, layout=layout)
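
# Illustrative sketch of random_nt above (not executed by the tests):
#
#   nt = random_nt(torch.device("cpu"), torch.float32, 3, (4, 4))
#   # -> strided-layout NT of 3 tensors; each dim size is drawn from [0, 4) per tensor,
#   #    with one randomly chosen tensor guaranteed non-empty (require_non_empty=True).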

# Alternate approach to generating a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
def random_nt_from_dims(dims, device=None, dtype=None, layout=torch.strided, requires_grad=False):
    sizes = [
        [d if d is not None else torch.randint(2, 10, size=(1,)).item() for d in dims[1:]]
        for d in range(dims[0])
    ]
    return torch.nested.nested_tensor([
        torch.randn(*size) for size in sizes
    ], device=device, dtype=dtype, layout=layout, requires_grad=requires_grad)
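
# Illustrative sketch of random_nt_from_dims above (not executed by the tests):
#
#   nt = random_nt_from_dims([3, None, 5])
#   # -> NT with 3 components of shape (r_i, 5), each r_i drawn independently from [2, 10)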

# Creates an NT matching another NT's number of components and
# shape / ragged structure for all dims specified to be -1.
def random_nt_from_similar(other, dims=None):
    if dims is None:
        return torch.randn_like(other)
    assert len(dims) == other.dim()
    assert dims[0] == -1 or dims[0] == other.size(0)

    ret_sizes = []
    for t in other.unbind():
        other_size = t.shape
        ret_size = []
        for i, d in enumerate(dims[1:]):
            if d == -1:
                ret_size.append(other_size[i])
            else:
                ret_size.append(d)
        ret_sizes.append(ret_size)

    return torch.nested.nested_tensor([
        torch.randn(*size) for size in ret_sizes
    ], device=other.device)
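
# Illustrative sketch of random_nt_from_similar above (not executed by the tests):
#
#   x = random_nt_from_dims([5, None, 10])
#   y = random_nt_from_similar(x, dims=[-1, -1, 8])
#   # -> y has x's batch size and ragged structure, but a trailing dim of 8 instead of 10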

# makes naming nice for tests that parametrize over layout.
def layout_name(layout):
    # e.g. "torch.jagged" -> "jagged"
    return layout.__repr__().split(".")[-1]
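
# Illustrative: layout_name(torch.jagged) -> "jagged" and layout_name(torch.strided) -> "strided",
# so parametrized test names get a short layout suffix instead of the full repr.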


@markDynamoStrictTest
class TestNestedTensor(TestCase):
    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id],
                nested_tensor_ref_list[id].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id],
                nested_tensor_ref_list[id].type(torch.int64)
            )

    @parametrize("batch_size", [2, 4])
    @parametrize("max_seq_len", [3, 5])
    @parametrize("vocab_size", [10, 20])
    def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size):
        data = []
        nested_tensor_ref_list = []
        for _ in range(batch_size):
            if max_seq_len == 0:
                length = 0
            else:
                length = np.random.randint(low=1, high=max_seq_len)
            row = list(
                np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float)
            )
            row = [list(item * np.arange(max_seq_len)) for item in row]
            data.append(row)
            nested_tensor_ref_list.append(torch.Tensor(row))
        nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float)
        nested_tensor_list = nested_tensor.unbind()
        for id in range(batch_size):
            self.assertEqual(
                nested_tensor_list[id],
                nested_tensor_ref_list[id].type(torch.float)
            )

    @torch.inference_mode()
    def _test_unbind_case(self, a, b):
        nt = torch.nested.nested_tensor([a, b])
        a1, b1 = nt.unbind()
        self.assertTrue(a is not a1)
        self.assertTrue(b is not b1)

        nt = torch.nested.nested_tensor([a, b], dtype=a.dtype)
        a1, b1 = nt.unbind(0)
        self.assertEqual(a, a1)
        self.assertEqual(b, b1)

        a = torch.randn((2, 3)).add_(1)
        nt = torch.nested.nested_tensor([a])
        self.assertEqual(a, nt.unbind(0)[0])

    @torch.inference_mode()
    def test_unbind_0(self):
        self._test_unbind_case(
            torch.tensor([1, 2]), torch.tensor([7, 8]),
        )

    @torch.inference_mode()
    def test_unbind_1(self):
        self._test_unbind_case(
            torch.tensor([1]), torch.tensor([7]),
        )

    @torch.inference_mode()
    def test_unbind_3(self):
        self._test_unbind_case(
            torch.tensor([1.0]), torch.tensor([]),
        )

    @torch.inference_mode()
    def test_unbind_4(self):
        self._test_unbind_case(
            torch.tensor([]), torch.tensor([]),
        )

    @torch.inference_mode()
    def test_unbind_dim(self):
        def _test_fn(unbind_fn):
            a = torch.rand(3, 2)
            b = torch.rand(2, 3)
            nt = torch.nested.nested_tensor([a, b])
            self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))

        # Both of these tests are necessary, because we're using
        # torch_function.
        _test_fn(lambda x, dim: x.unbind(dim))
        # TODO: Re-enable this once using torch_dispatch
        # _test_fn(lambda x, dim: torch.unbind(x, dim))

    @torch.inference_mode()
    def test_nested_tensor(self):
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0])))
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0))

    @torch.inference_mode()
    def test_nested_tensor_matching_dim(self):
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
            lambda: torch.nested.nested_tensor([torch.tensor(1.0), torch.tensor([])]),
        )
        self.assertRaisesRegex(
            RuntimeError,
            "Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
            lambda: torch.nested.nested_tensor(
                [torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
            ),
        )

    @torch.inference_mode()
    def test_default_nested_tensor(self):
        self.assertRaises(TypeError, lambda: torch.nested.nested_tensor())
        default_nested_tensor = torch.nested.nested_tensor([])
        default_tensor = torch.tensor([])
        # self.assertEqual(default_nested_tensor.nested_dim(), 1)
        # self.assertEqual(default_nested_tensor.nested_size(), ())
        self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
        self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
        self.assertEqual(default_nested_tensor.device, default_tensor.device)
        self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
        self.assertEqual(
            default_nested_tensor.requires_grad, default_tensor.requires_grad
        )
        self.assertIsNone(default_tensor.grad)
        # TODO: Re-enable once we have a performance driven
        # use case and implementation.
        # self.assertEqual(default_nested_tensor.is_pinned(),
        #                  default_tensor.is_pinned())

    @torch.inference_mode()
    def test_dim(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor(3.0)])
            self.assertEqual(a1.dim(), 1)
            a1 = constructor([torch.tensor([1, 2, 3, 4])])
            self.assertEqual(a1.dim(), 2)

    @unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
    @torch.inference_mode()
    def test_numel(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertEqual(a1.numel(), 0)
            a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)])
            self.assertEqual(a1.numel(), 2)
            a1 = constructor([torch.randn(2, 2, 2)])
            self.assertEqual(a1.numel(), 8)
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)])
            self.assertEqual(a1.numel(), 12)
            a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)])
            self.assertEqual(a1.numel(), 27)
            a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)])
            self.assertEqual(a1.numel(), 341)

            # Interesting edge case
            a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)])
            self.assertEqual(a1.numel(), 6)

    @torch.inference_mode()
    def test_size(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support sizes",
                lambda: a1.size(),
            )

    def test_size_dim(self):
        a = torch.nested.nested_tensor([])
        self.assertEqual(a.size(0), 0)

        a = torch.nested.nested_tensor([torch.tensor(1)])
        self.assertEqual(a.size(0), 1)

        a = torch.nested.nested_tensor([torch.tensor(1), torch.tensor(2)])
        self.assertEqual(a.size(0), 2)

        a = torch.nested.nested_tensor([torch.rand(1, 2),
                                        torch.rand(1, 8)])
        self.assertEqual(a.size(0), 2)
        self.assertEqual(a.size(1), 1)
        self.assertRaisesRegex(
            RuntimeError, "Given dimension 2 is irregular and does not have a size", lambda: a.size(2))

        a = torch.nested.nested_tensor([torch.rand(3, 4),
                                        torch.rand(5, 4)])
        self.assertEqual(a.size(0), 2)
        self.assertRaisesRegex(
            RuntimeError, "Given dimension 1 is irregular and does not have a size", lambda: a.size(1))
        self.assertEqual(a.size(2), 4)

    @unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
    @torch.inference_mode()
    def test_stride(self):
        for constructor in _iter_constructors():
            a1 = constructor([])
            self.assertRaisesRegex(
                RuntimeError,
                "NestedTensorImpl doesn't support strides",
                lambda: a1.stride(),
            )

    @unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
    @torch.inference_mode()
    def test_is_contiguous(self):
        # Test empty case
        nt_empty = torch.nested.nested_tensor([])
        assert nt_empty.is_contiguous()
        self.assertEqual(nt_empty, nt_empty.contiguous())

        nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))

        # Test contiguous case
        assert nt_contiguous.is_contiguous()
        self.assertEqual(nt_contiguous, nt_contiguous.contiguous())

        # Test non_contiguous case
        assert not nt_noncontiguous.is_contiguous()
        self.assertEqual(nt_contiguous, nt_noncontiguous.contiguous())

        # Test querying by memory_format
        self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format))
        self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format))

    @torch.inference_mode()
    def test_repr_string(self):
        a = torch.nested.nested_tensor([])
        expected = "nested_tensor([\n\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor(1.0)])
        expected = "nested_tensor([\n  tensor(1.)\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

        a = torch.nested.nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
        expected = "nested_tensor([\n  tensor([[1, 2]]),\n  tensor([[4, 5]])\n])"
        self.assertEqual(str(a), expected)
        self.assertEqual(repr(a), expected)

    def test_to_padded_tensor_on_empty_tensor(self):

        nt = torch.nested.nested_tensor([])
        empty = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(empty, torch.tensor([]))

    def test_nested_namespace(self):
        nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 5)])
        result = nt.to_padded_tensor(4)
        nested_namespace_result = torch.nested.to_padded_tensor(nt, 4)
        self.assertEqual(result, nested_namespace_result)
def test_to(self):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4))
|
|
|
|
def test_copy_behavior(t, non_blocking=False):
|
|
self.assertIs(t, t.to(t, non_blocking=non_blocking))
|
|
self.assertIs(t, t.to(t.dtype, non_blocking=non_blocking))
|
|
self.assertIs(t, t.to(torch.empty_like(t), non_blocking=non_blocking))
|
|
self.assertIsNot(t, t.to(t, non_blocking=non_blocking, copy=True))
|
|
self.assertIsNot(t, t.to(t.dtype, non_blocking=non_blocking, copy=True))
|
|
self.assertIsNot(t, t.to(torch.empty_like(t), non_blocking=non_blocking, copy=True))
|
|
|
|
devices = [t.device]
|
|
if t.device.type == 'cuda':
|
|
if t.device.index == -1:
|
|
devices.append(f'cuda:{torch.cuda.current_device()}')
|
|
elif t.device.index == torch.cuda.current_device():
|
|
devices.append('cuda')
|
|
for device in devices:
|
|
self.assertIs(t, t.to(device, non_blocking=non_blocking))
|
|
self.assertIs(t, t.to(device, t.dtype, non_blocking=non_blocking))
|
|
self.assertIsNot(t, t.to(device, non_blocking=non_blocking, copy=True))
|
|
self.assertIsNot(t, t.to(device, t.dtype, non_blocking=non_blocking, copy=True))
|
|
|
|
test_copy_behavior(nt)
|
|
self.assertEqual(nt.device, nt.to('cpu').device)
|
|
self.assertEqual(nt.device, nt.to('cpu', dtype=torch.float32).device)
|
|
self.assertIs(torch.float32, nt.to('cpu', dtype=torch.float32).dtype)
|
|
self.assertEqual(nt.device, nt.to(torch.float32).device)
|
|
self.assertIs(torch.float32, nt.to(dtype=torch.float32).dtype)
|
|
|
|
def test_data_ptr(getter):
|
|
self.assertEqual(getter(nt), getter(nt.to('cpu')))
|
|
self.assertEqual(getter(nt), getter(nt.to(dtype=nt.dtype, device=nt.device, copy=False)))
|
|
self.assertEqual(getter(nt), getter(nt.to('cpu', copy=False)))
|
|
self.assertNotEqual(getter(nt), getter(nt.to('cpu', copy=True)))
|
|
|
|
test_data_ptr(lambda nt: nt.data_ptr())
|
|
|
|
if torch.cuda.is_available():
|
|
for non_blocking in [True, False]:
|
|
for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
|
|
nt2 = random_nt(cuda, torch.float32, ntensors, (4, 4))
|
|
test_copy_behavior(nt2, non_blocking)
|
|
self.assertEqual(nt2.device, nt2.to(cuda, non_blocking=non_blocking).device)
|
|
self.assertEqual(nt.device, nt2.to('cpu', non_blocking=non_blocking).device)
|
|
self.assertEqual(nt2.device, nt.to(cuda, non_blocking=non_blocking).device)
|
|
self.assertIs(torch.int32, nt2.to('cpu', dtype=torch.int32, non_blocking=non_blocking).dtype)
|
|
self.assertEqual(nt.device, nt2.to('cpu', dtype=torch.int32, non_blocking=non_blocking).device)
|
|
self.assertIs(torch.int32, nt2.to(dtype=torch.int32).dtype)
|
|
self.assertEqual(nt2.device, nt2.to(dtype=torch.int32).device)
|
|
|
|
def test_copy_(self):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4))
|
|
nt_copy = torch.empty_like(nt)
|
|
nt_copy.copy_(nt)
|
|
|
|
for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy):
|
|
self.assertEqual(nt_ub, nt_copy_ub)
|
|
|
|
nt_error = torch.nested.nested_tensor([torch.tensor([0, 0])])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"copy_ only supports tensors that are the same size for Nested implementations",
|
|
lambda: nt_error.copy_(nt)
|
|
)
|
|
|
|
if torch.cuda.is_available():
|
|
nt = random_nt(torch.device('cuda'), torch.float32, ntensors, (4, 4))
|
|
nt_copy = torch.empty_like(nt, device=torch.device('cpu'))
|
|
nt_copy.copy_(nt, non_blocking=True)
|
|
torch.cuda.current_stream(torch.cuda.current_device()).synchronize()
|
|
for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy):
|
|
self.assertEqual(nt_ub, nt_copy_ub)
|
|
|
|
nt_copy = torch.empty_like(nt, device=torch.device('cpu'))
|
|
nt_copy.copy_(nt, non_blocking=False)
|
|
for (nt_ub, nt_copy_ub) in zip(nt.unbind(), nt_copy):
|
|
self.assertEqual(nt_ub, nt_copy_ub)
|
|
|
|
def test_fill_(self):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4))
|
|
nt.fill_(10.)
|
|
for nt_ub in nt.unbind():
|
|
t = torch.empty_like(nt_ub)
|
|
t.fill_(10.)
|
|
self.assertEqual(nt_ub, t)
|
|
|
|
fill_tensor = torch.tensor([11.])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"fill_ only supports 0-dimension value tensor",
|
|
lambda: nt.fill_(fill_tensor)
|
|
)
|
|
|
|
nt.fill_(fill_tensor[0])
|
|
for nt_ub in nt.unbind():
|
|
t = torch.empty_like(nt_ub)
|
|
t.fill_(11.)
|
|
self.assertEqual(nt_ub, t)
|
|
|
|
def test_zero_(self):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4))
|
|
nt.zero_()
|
|
for nt_ub in nt.unbind():
|
|
t = torch.empty_like(nt_ub)
|
|
t.fill_(0.)
|
|
self.assertEqual(nt_ub, t)
|
|
|
|
@parametrize("func", [torch.ones_like, torch.zeros_like, torch.randn_like],
|
|
name_fn=lambda f: f.__name__)
|
|
def test_like_functions(self, func):
|
|
ntensors = 4
|
|
nt = random_nt(torch.device('cpu'), torch.float32, ntensors, (4, 4))
|
|
torch.manual_seed(1)
|
|
nt_like = func(nt)
|
|
|
|
torch.manual_seed(1)
|
|
for nt_ub in nt_like.unbind():
|
|
t_like = func(nt_ub)
|
|
self.assertEqual(nt_ub, t_like)
|
|
|
|
def test_cat(self):
|
|
# dim=0 success case
|
|
# No constraints on ragged structures matching.
|
|
x = random_nt_from_dims([5, None, 10])
|
|
y = random_nt_from_dims([3, 4, None])
|
|
output = torch.cat([x, y], dim=0)
|
|
for out_component, xy_component in zip(
|
|
output.unbind(), itertools.chain(x.unbind(), y.unbind())):
|
|
self.assertEqual(out_component, xy_component)
|
|
|
|
# dim=-1 success case
|
|
# shape (B, *, D)
|
|
x = random_nt_from_dims([5, None, 10])
|
|
# shape (B, *, D'); same structure as x but dim=-1 differs
|
|
y = random_nt_from_similar(x, dims=[-1, -1, 8])
|
|
# should be shape (B, *, D + D') when supported
|
|
output = torch.cat([x, y], dim=-1)
|
|
for out_component, x_component, y_component in zip(output.unbind(), x.unbind(), y.unbind()):
|
|
self.assertEqual(out_component, torch.cat([x_component, y_component], dim=-1))
|
|
|
|
# dim between 0 and -1 success case
|
|
x = random_nt_from_dims([5, None, 2, 3])
|
|
# same structure as x but dim=2 differs
|
|
y = random_nt_from_similar(x, dims=[-1, -1, 4, -1])
|
|
output = torch.cat([x, y], dim=2)
|
|
for out_component, x_component, y_component in zip(output.unbind(), x.unbind(), y.unbind()):
|
|
self.assertEqual(out_component, torch.cat([x_component, y_component], dim=1))
|
|
|
|
# error case: mixed NT / dense inputs
|
|
x = random_nt_from_dims([5, None, 2])
|
|
y = torch.randn(5, 3, 2)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "expected each tensor in given list to be nested"):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: NTs with different dims
|
|
x = random_nt_from_dims([5, None, 2])
|
|
y = random_nt_from_dims([5, None, 2, 3])
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: non-contiguous NT
|
|
x, y = random_nt_noncontiguous_pair((2, 3, 4), dtype=torch.float32)
|
|
# transpose to put ragged dim next to batch dim
|
|
x, y = x.transpose(-2, -1), y.transpose(-2, -1)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "only contiguous nested tensors are supported"):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: multiple ragged dims in inputs
|
|
x = random_nt_from_dims([5, None, None, 2])
|
|
y = random_nt_from_similar(x)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "only nested tensors with a single ragged dim next to the batch dim are supported"):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: ragged dim not next to batch dim
|
|
x = random_nt_from_dims([5, 2, None])
|
|
y = random_nt_from_similar(x)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "only nested tensors with a single ragged dim next to the batch dim are supported"):
|
|
torch.cat([x, y], dim=1)
|
|
|
|
# error case: NTs with different batch sizes
|
|
x = random_nt_from_dims([5, None, 2])
|
|
y = random_nt_from_dims([3, None, 2])
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
# error case: NTs with different ragged structures
|
|
x = torch.nested.nested_tensor([
|
|
torch.randn(2, 6),
|
|
torch.randn(4, 6),
|
|
torch.randn(5, 6),
|
|
])
|
|
y = torch.nested.nested_tensor([
|
|
torch.randn(5, 6),
|
|
torch.randn(4, 6),
|
|
torch.randn(2, 6),
|
|
])
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "expected all nested tensors to have matching ragged structures outside of the concatenated dim"):
|
|
torch.cat([x, y], dim=-1)
|
|
|
|
|
|
@markDynamoStrictTest
|
|
class TestNestedTensorDeviceType(TestCase):
|
|
# Helper function to generate a pair of random nested tensors
|
|
# the 2 nested tensors have same shapes
|
|
def random_nt_pair(self, device, dtype, num_tensors, max_dims):
|
|
ts1 = []
|
|
ts2 = []
|
|
for _ in range(num_tensors):
|
|
tensor_dims = tuple([torch.randint(low=0, high=max_dim, size=(1,)).item() for max_dim in max_dims])
|
|
t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
|
|
t2 = torch.randn(tensor_dims, device=device, dtype=dtype)
|
|
ts1.append(t1)
|
|
ts2.append(t2)
|
|
return (torch.nested.nested_tensor(ts1, device=device, dtype=dtype),
|
|
torch.nested.nested_tensor(ts2, device=device, dtype=dtype))
|
|
|
|
@dtypes(*floating_types_and_half())
|
|
def test_detach(self, device, dtype):
|
|
a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=False)
|
|
b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=False)
|
|
x = torch.nested.nested_tensor([a, b], requires_grad=True)
|
|
|
|
x_detach = x.detach()
|
|
|
|
z = x_detach * 4
|
|
self.assertFalse(x_detach.requires_grad)
|
|
self.assertFalse(z.requires_grad)
|
|
|
|
a = torch.randn(2, 4, device=device, dtype=dtype, requires_grad=True)
|
|
b = torch.randn(5, 4, device=device, dtype=dtype, requires_grad=True)
|
|
x = torch.nested.as_nested_tensor([a, b])
|
|
|
|
y = x * 2
|
|
y = y.detach()
|
|
self.assertFalse(y.requires_grad)
|
|
self.assertIsNone(y.grad_fn)
|
|
|
|
z = x + y
|
|
torch.nested.to_padded_tensor(z, 0).sum().backward()
|
|
# This is an incorrect gradient, but we assume that's what the user
|
|
# wanted. detach() is an advanced option.
|
|
self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
|
|
self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_unbind_noncontiguous(self, device, dtype):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
|
|
ub_contiguous = nt_contiguous.unbind()
|
|
ub_noncontiguous = nt_noncontiguous.unbind()
|
|
self.assertEqual(len(ub_contiguous), len(ub_noncontiguous))
|
|
n = len(ub_contiguous)
|
|
for i in range(n):
|
|
self.assertEqual(ub_contiguous[i], ub_noncontiguous[i])
|
|
|
|
@dtypes(torch.float)
|
|
@skipMeta
|
|
def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
|
|
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
|
|
ts = list(torch.unbind(t))
|
|
ts[0] = ts[0][:-1]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
padded = torch.nested.to_padded_tensor(nt, 0)
|
|
|
|
nt_to = torch._nested_from_padded_and_nested_example(padded, nt)
|
|
|
|
for (t1, t2) in zip(nt.unbind(), nt_to.unbind()):
|
|
self.assertEqual(t1, t2)
|
|
self.assertEqual(nt.device, nt_to.device)
|
|
|
|
@dtypes(torch.float)
|
|
@dtypesIfCUDA(torch.float, torch.half)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_layer_norm(self, device, dtype):
|
|
def _test(size):
|
|
# Simple shapes test
|
|
t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
|
|
t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
|
|
ts = [t0, t1, t0, t1]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
|
|
nt_result = layer_norm(nt)
|
|
for (nt_subresult, t) in zip(nt_result.unbind(), ts):
|
|
t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
|
|
self.assertEqual(nt_subresult, t_result)
|
|
|
|
# More complex nt test with different lengths for each tensor
|
|
t0 = torch.randn(4, size, device=device, dtype=dtype, requires_grad=False)
|
|
t1 = torch.randn(10, size, device=device, dtype=dtype, requires_grad=False)
|
|
t2 = torch.randn(7, size, device=device, dtype=dtype, requires_grad=False)
|
|
ts = [t0, t1, t2, t0, t2]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
|
|
nt_result = layer_norm(nt)
|
|
for (nt_subresult, t) in zip(nt_result.unbind(), ts):
|
|
t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
|
|
self.assertEqual(nt_subresult, t_result)
|
|
|
|
if size <= 128:
|
|
# Test with multidimensional tensors after irregular dim
|
|
# (run only with smaller dimensions to ensure fast execution)
|
|
t0 = torch.randn(4, size, size, 4, device=device, dtype=dtype, requires_grad=False)
|
|
t1 = torch.randn(10, size, size, 4, device=device, dtype=dtype, requires_grad=False)
|
|
t2 = torch.randn(7, size, size, 4, device=device, dtype=dtype, requires_grad=False)
|
|
ts = [t0, t1, t2, t0, t2]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
layer_norm = torch.nn.LayerNorm((size, size, 4), device=device, dtype=dtype)
|
|
nt_result = layer_norm(nt)
|
|
for (nt_subresult, t) in zip(nt_result.unbind(), ts):
|
|
t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
|
|
self.assertEqual(nt_subresult, t_result)
|
|
|
|
# Test where the normalizing dimensions are not all
|
|
layer_norm = torch.nn.LayerNorm((size, 4), device=device, dtype=dtype)
|
|
nt_result = layer_norm(nt)
|
|
for (nt_subresult, t) in zip(nt_result.unbind(), ts):
|
|
t_result = layer_norm(t.reshape(1, -1, size, size, 4).squeeze(0))
|
|
self.assertEqual(nt_subresult, t_result)
|
|
|
|
for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
|
|
_test(size)
|
|
|
|
@dtypes(torch.float)
|
|
@dtypesIfCUDA(torch.float, torch.half)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_layer_norm_breaking(self, device, dtype):
|
|
size = 128
|
|
t0 = torch.randn(4, size, size, 4, device=device, dtype=dtype, requires_grad=False)
|
|
t1 = torch.randn(10, size, size, 4, device=device, dtype=dtype, requires_grad=False)
|
|
t2 = torch.randn(7, size, size, 4, device=device, dtype=dtype, requires_grad=False)
|
|
ts = [t0, t1, t2, t0, t2]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
layer_norm = torch.nn.LayerNorm((4, size, size, 4), device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"normalized_shape extends into irregular dimensions for the nested tensor",
|
|
lambda: layer_norm(nt),
|
|
)
|
|
layer_norm = torch.nn.LayerNorm((size + 1, size, 4), device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"The shape at dimension 0",
|
|
lambda: layer_norm(nt),
|
|
)
|
|
|
|
@decorateIf(
|
|
xfailIfTorchDynamo,
|
|
# only fails in python 3.11. TODO: Ensure this is fixed once views work!
|
|
lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11)
|
|
)
|
|
@parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
|
|
def test_embedding(self, device, layout):
|
|
inputs = [
|
|
torch.randint(100, (L,), device=device, dtype=torch.int64)
|
|
for L in torch.randint(5, 50, (8,))
|
|
]
|
|
x = torch.nested.nested_tensor(inputs, device=device, dtype=torch.int64, layout=layout)
|
|
emb = torch.nn.Embedding(100, 8, device=device)
|
|
y = emb(x)
|
|
ys = y.unbind()
|
|
for i, inp in enumerate(inputs):
|
|
self.assertEqual(emb(inp), ys[i])
|
|
|
|
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
@dtypes(*floating_types_and_half())
|
|
def test_masked_fill(self, device, dtype):
|
|
# nested tensor * nested tensor
|
|
(nt, mask) = self.random_nt_pair(device, dtype, 4, (4, 4))
|
|
mask = torch.nested.nested_tensor([m < 0 for m in mask.unbind()])
|
|
ref = torch.nested.nested_tensor([t.masked_fill(m, 0) for (t, m) in zip(nt.unbind(), mask.unbind())])
|
|
out = nt.masked_fill(mask, 0)
|
|
self.assertEqual(ref, out)
|
|
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
def test_to_padded_tensor_simple(self, device, dtype):
|
|
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
|
|
ts = list(torch.unbind(t))
|
|
ts[0] = ts[0][:-1]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
for padding_value in (0, 1):
|
|
padded = torch.nested.to_padded_tensor(nt, padding_value)
|
|
|
|
correct_output = t.clone()
|
|
if padding_value == 0:
|
|
correct_output[0][-1] = torch.zeros_like(correct_output[0][-1])
|
|
else:
|
|
correct_output[0][-1] = torch.ones_like(correct_output[0][-1])
|
|
|
|
self.assertEqual(padded, correct_output)
|
|
self.assertEqual(padded.device, torch.device(device))
|
|
self.assertEqual(padded.dtype, dtype)
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
def test_to_padded_tensor_output_size(self, device, dtype):
|
|
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
|
|
output_size = (4, 6, 5)
|
|
ts = list(torch.unbind(t))
|
|
ts[0] = ts[0][:-1]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
for padding_value in (0, 1):
|
|
padded = torch.nested.to_padded_tensor(nt, padding_value, output_size=output_size)
|
|
correct_output = torch.ones(output_size, device=device, dtype=dtype) * padding_value
|
|
correct_output[:4:, :4, :4] = t.clone()
|
|
if padding_value == 0:
|
|
correct_output[0][3] = torch.zeros_like(correct_output[0][3])
|
|
else:
|
|
correct_output[0][3] = torch.ones_like(correct_output[0][3])
|
|
|
|
self.assertEqual(padded, correct_output)
|
|
self.assertEqual(padded.device, torch.device(device))
|
|
self.assertEqual(padded.dtype, dtype)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_to_padded_tensor_dim2(self, device, dtype):
|
|
ts = [
|
|
torch.randn(160, device=device, dtype=dtype),
|
|
torch.randn(1240, device=device, dtype=dtype),
|
|
torch.randn(2400, device=device, dtype=dtype),
|
|
]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
pad = 42
|
|
correct_output = []
|
|
for t in ts:
|
|
next_output = torch.ones_like(ts[2]) * pad
|
|
correct_output.append(next_output)
|
|
next_output[:t.size(0)].copy_(t)
|
|
correct_output = torch.stack(correct_output)
|
|
padded = torch.nested.to_padded_tensor(nt, pad)
|
|
self.assertEqual(padded, correct_output)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_to_padded_tensor_dim3(self, device, dtype):
|
|
ts = [
|
|
torch.randn(16, 21, device=device, dtype=dtype),
|
|
torch.randn(24, 32, device=device, dtype=dtype),
|
|
torch.randn(40, 53, device=device, dtype=dtype),
|
|
]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
pad = 42
|
|
correct_output = []
|
|
for t in ts:
|
|
next_output = torch.ones_like(ts[2]) * pad
|
|
correct_output.append(next_output)
|
|
next_output[:t.size(0), :t.size(1)].copy_(t)
|
|
correct_output = torch.stack(correct_output)
|
|
padded = torch.nested.to_padded_tensor(nt, pad)
|
|
self.assertEqual(padded, correct_output)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_to_padded_tensor_dim4(self, device, dtype):
|
|
ts = [
|
|
torch.randn(16, 21, 13, device=device, dtype=dtype),
|
|
torch.randn(24, 32, 14, device=device, dtype=dtype),
|
|
torch.randn(40, 53, 16, device=device, dtype=dtype),
|
|
]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
pad = 42
|
|
correct_output = []
|
|
for t in ts:
|
|
next_output = torch.ones_like(ts[2]) * pad
|
|
correct_output.append(next_output)
|
|
next_output[:t.size(0), :t.size(1), :t.size(2)].copy_(t)
|
|
correct_output = torch.stack(correct_output)
|
|
padded = torch.nested.to_padded_tensor(nt, pad)
|
|
self.assertEqual(padded, correct_output)
|
|
|
|
# TODO: test noncontiguous to_padded_tensor
|
|
# For now this tests the functionality of noncontiguous_to_padded_tensor
|
|
# and the error message of to_padded_tensor
|
|
# since to_padded_tensor does not support noncontiguous buffer yet
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
@torch.inference_mode()
|
|
def test_to_padded_tensor_noncontiguous(self, device, dtype):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
|
|
# test noncontiguous_to_padded_tensor functionality
|
|
self.assertEqual(
|
|
torch.nested.to_padded_tensor(nt_contiguous, 0.0),
|
|
noncontiguous_to_padded_tensor(nt_noncontiguous))
|
|
# test to_padded_tensor error message
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"for now to_padded_tensor only supports contiguous nested tensor",
|
|
lambda: torch.nested.to_padded_tensor(nt_noncontiguous, 0.0)
|
|
)
|
|
|
|
@skipMeta
|
|
def test_device_checks(self, device):
|
|
nt = torch.nested.nested_tensor([], device=device)
|
|
is_cuda = 'cuda' in str(device)
|
|
self.assertEqual(nt.is_cuda, is_cuda)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_nested_tensor_indexing(self, device, dtype):
|
|
# edge case: empty nested tensor
|
|
nt0 = torch.nested.nested_tensor([])
|
|
self.assertRaises(IndexError, lambda: nt0[0])
|
|
# normal case
|
|
x0 = torch.randn((2, 5), device=device, dtype=dtype)
|
|
x1 = torch.randn((3, 4), device=device, dtype=dtype)
|
|
nt = torch.nested.nested_tensor([x0, x1])
|
|
# single index: only support integer in the batch dimension
|
|
self.assertEqual(nt[0], x0)
|
|
self.assertEqual(nt[-1], x1)
|
|
self.assertRaises(IndexError, lambda: nt[2])
|
|
self.assertRaises(IndexError, lambda: nt[-3])
|
|
self.assertRaises(NotImplementedError, lambda: nt[:])
|
|
self.assertRaises(NotImplementedError, lambda: nt[...])
|
|
# tuple of indices: only support integer in the batch dimension
|
|
# + all possible indexing in the original tensor dimensions
|
|
self.assertEqual(nt[0, 0, 0], x0[0, 0])
|
|
self.assertEqual(nt[0, 1, :], x0[1, :])
|
|
self.assertEqual(nt[1, ...], x1)
|
|
self.assertRaises(IndexError, lambda: nt[1, 4, 2])
|
|
self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1])
|
|
# test select on non-batch dimensions
|
|
self.assertEqual(nt.select(1, 0)[0], x0.select(0, 0))
|
|
self.assertEqual(nt.select(1, 0)[1], x1.select(0, 0))
|
|
self.assertRaises(IndexError, lambda: nt.select(1, 3))
|
|
self.assertEqual(nt.select(2, 0)[0], x0.select(1, 0))
|
|
self.assertEqual(nt.select(2, 0)[1], x1.select(1, 0))
|
|
self.assertRaises(IndexError, lambda: nt.select(2, 5))
|
|
# make sure indexing returns a view
|
|
nt[0].fill_(100.0)
|
|
answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
|
|
self.assertEqual(nt[0], answer)
|
|
nt[1, 1, :].fill_(200.0)
|
|
answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
|
|
self.assertEqual(nt[1, 1, :], answer)
|
|
|
|
# Test that indexing works when requires_grad_(True)
|
|
# previously this was failing because the backward kernel for select.int uses .sizes()
|
|
nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True)
|
|
self.assertEqual(nt[0], x0)
|
|
self.assertEqual(nt[-1], x1)
|
|
grad_x0 = torch.randn((2, 5), device=device, dtype=dtype)
|
|
nt[0].backward(grad_x0)
|
|
expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)])
|
|
self.assertEqual(nt.grad, expected_grad)
|
|
|
|
@parametrize("func", [subtest(torch.nn.functional.relu, name='relu'),
|
|
subtest(torch.nn.functional.relu_, name='relu_'),
|
|
subtest(torch.nn.functional.gelu, name='gelu'),
|
|
subtest(torch._C._nn.gelu_, name='gelu_'),
|
|
subtest(torch.tanh, name='tanh'),
|
|
subtest(torch.tanh_, name='tanh_'),
|
|
subtest(torch.neg, name='neg'),
|
|
subtest(torch.nn.functional.silu, name='silu'),
|
|
subtest(partial(torch.nn.functional.silu, inplace=True), name='silu_'),
|
|
subtest(torch.abs, name="abs"),
|
|
subtest(torch.abs_, name="abs_"),
|
|
subtest(torch.sgn, name="sgn"),
|
|
subtest(torch.logical_not, name='logical_not'),
|
|
subtest(torch.sin, name='sin'),
|
|
subtest(torch.cos, name='cos')])
|
|
def test_activations(self, device, func):
|
|
nt, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device=device, dtype=torch.float32)
|
|
nested_result = func(nt)
|
|
self.assertTrue(nested_result.is_nested)
|
|
for t, t_res in zip(nt.unbind(), nested_result.unbind()):
|
|
self.assertEqual(func(t), t_res)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"NestedTensor must be contiguous to get buffer.",
|
|
lambda: func(nt_noncontiguous))
|
|
|
|
@parametrize("func", [subtest(torch.ge, name='ge'),
|
|
subtest(torch.eq, name='eq')])
|
|
def test_binary_ops_with_scalar(self, device, func):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair(
|
|
(2, 3, 6, 7), device=device, dtype=torch.float32)
|
|
scalar = 0.0
|
|
|
|
# should work regardless of contiguity
|
|
for nt in (nt_contiguous, nt_noncontiguous):
|
|
nested_result = func(nt, scalar)
|
|
self.assertTrue(nested_result.is_nested)
|
|
for t, t_res in zip(nt.unbind(), nested_result.unbind()):
|
|
self.assertEqual(func(t, scalar), t_res)
|
|
|
|
@dtypes(*floating_types_and_half())
|
|
def test_nested_tensor_chunk(self, device, dtype):
|
|
# Transformer use case
|
|
a = torch.randn(3, 3 * 4, device=device, dtype=dtype)
|
|
b = torch.randn(2, 3 * 4, device=device, dtype=dtype)
|
|
c = torch.randn(1, 3 * 4, device=device, dtype=dtype)
|
|
a_chunks = a.chunk(3, dim=-1)
|
|
b_chunks = b.chunk(3, dim=-1)
|
|
c_chunks = c.chunk(3, dim=-1)
|
|
|
|
a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]]
|
|
b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]]
|
|
c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]]
|
|
|
|
nt = torch.nested.nested_tensor([a, b, c])
|
|
chunked = nt.chunk(3, dim=-1)
|
|
|
|
self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt))
|
|
self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt))
|
|
self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt))
|
|
|
|
for chunk in chunked:
|
|
self.assertFalse(chunk.is_contiguous())
|
|
|
|
# Failure chunking on ragged dimensions
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.",
|
|
lambda: torch.chunk(nt, 5, dim=1))
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.",
|
|
lambda: torch.chunk(nt, 5, dim=0))
|
|
|
|
# Failure on non-contiguous nt
|
|
_, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "chunk expects `self` to be contiguous.", lambda: torch.chunk(nt_noncontiguous, 5, dim=-1))
|
|
|
|
# Failure when calling non divisible n_chunks
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "Chunk for nested tensors is only supported for "
|
|
"nested tensors with trailing dimension divisible by chunks.",
|
|
lambda: torch.chunk(nt, 5, dim=-1))
|
|
|
|
# Failure when calling backward on a chunk
|
|
a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True)
|
|
b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True)
|
|
nt_grad = torch.nested.as_nested_tensor([a, b])
|
|
chunked = torch.chunk(nt_grad, 2, dim=-1)
|
|
self.assertRaisesRegex(RuntimeError, "derivative for aten::chunk is not implemented",
|
|
lambda: chunked[0].backward(chunked[0].clone()))
|
|
|
|
@dtypes(*floating_types_and_half())
|
|
def test_nested_tensor_split_with_sizes(self, device, dtype):
|
|
a = torch.randn(3, 20, device=device, dtype=dtype)
|
|
b = torch.randn(2, 20, device=device, dtype=dtype)
|
|
c = torch.randn(1, 20, device=device, dtype=dtype)
|
|
|
|
split_sizes = [4, 6, 10]
|
|
a_splits = a.split_with_sizes(split_sizes, dim=-1)
|
|
b_splits = b.split_with_sizes(split_sizes, dim=-1)
|
|
c_splits = c.split_with_sizes(split_sizes, dim=-1)
|
|
|
|
nt = torch.nested.nested_tensor([a, b, c])
|
|
nt_splits = nt.split_with_sizes(split_sizes, dim=-1)
|
|
|
|
for i, nt_split in enumerate(nt_splits):
|
|
self.assertEqual(nt_split, torch.nested.nested_tensor(
|
|
[a_splits[i], b_splits[i], c_splits[i]]))
|
|
dense_strides = torch.stack([
|
|
torch.tensor(a_splits[i].stride()),
|
|
torch.tensor(b_splits[i].stride()),
|
|
torch.tensor(c_splits[i].stride())
|
|
])
|
|
self.assertEqual(nt_split._nested_tensor_strides(), dense_strides)
|
|
self.assertFalse(nt_split.is_contiguous())
|
|
|
|
# Failure calling on ragged dimensions
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "split_with_sizes for nested tensors is currently only supported for the last dimension.",
|
|
lambda: torch.split_with_sizes(nt, split_sizes, dim=1))
|
|
|
|
# Failure calling on non-last dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "split_with_sizes for nested tensors is currently only supported for the last dimension.",
|
|
lambda: torch.split_with_sizes(nt, split_sizes, dim=0))
|
|
|
|
# Failure on non-contiguous nt
|
|
_, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "split_with_sizes expects `self` to be contiguous.",
|
|
lambda: torch.split_with_sizes(nt_noncontiguous, split_sizes, dim=-1))
|
|
|
|
# Failure when calling with split_sizes that don't cover the full dim size
|
|
bad_split_sizes = [4, 6, 9] # don't add up to 20
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "split_with_sizes expects split_sizes to sum exactly to 20",
|
|
lambda: torch.split_with_sizes(nt, bad_split_sizes, dim=-1))
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
@torch.inference_mode()
|
|
def test_nested_tensor_indexing_noncontiguous(self, device, dtype):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
|
|
self.assertEqual(nt_contiguous.size(0), nt_noncontiguous.size(0))
|
|
n = nt_contiguous.size(0)
|
|
for i in range(n):
|
|
self.assertEqual(nt_contiguous[i], nt_noncontiguous[i])
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
@parametrize("transpose", [True, False])
|
|
def test_nested_tensor_add(self, device, dtype, transpose):
|
|
if transpose:
|
|
a = torch.randn(2, 2, 2, device=device, dtype=dtype)
|
|
b = torch.rand(2, 2, 2, device=device, dtype=dtype)
|
|
c = a.transpose(-1, -2).contiguous()
|
|
d = b.transpose(-1, -2).contiguous()
|
|
nt1 = torch.nested.nested_tensor([a, b, a, b])
|
|
nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
|
|
else:
|
|
(nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
|
|
ref = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
|
|
out = nt1 + nt2
|
|
self.assertEqual(ref, out)
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
@parametrize("transpose", [True, False])
|
|
def test_nested_tensor_sub(self, device, dtype, transpose):
|
|
if transpose:
|
|
a = torch.randn(2, 2, 2, device=device, dtype=dtype)
|
|
b = torch.rand(2, 2, 2, device=device, dtype=dtype)
|
|
c = a.transpose(-1, -2).contiguous()
|
|
d = b.transpose(-1, -2).contiguous()
|
|
nt1 = torch.nested.nested_tensor([a, b, a, b])
|
|
nt2 = torch.nested.nested_tensor([c, d, c, d]).transpose(-1, -2)
|
|
else:
|
|
(nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
|
|
ref = torch.nested.nested_tensor([t1 - t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
|
|
out = nt1 - nt2
|
|
self.assertEqual(ref, out)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float, torch.float16)
|
|
@torch.inference_mode()
|
|
@parametrize("embedding_dim", [8, 128, 256, 384])
|
|
def test_nested_tensor_dense_elementwise(self, device, dtype, embedding_dim):
|
|
def _test_add_mul(nt, t):
|
|
ref_add = torch.nested.nested_tensor(
|
|
[t1 + t2 for (t1, t2) in zip(nt.unbind(), t.unbind())])
|
|
ref_mul = torch.nested.nested_tensor(
|
|
[t1 * t2 for (t1, t2) in zip(nt.unbind(), t.unbind())])
|
|
self.assertEqual(nt.add(t), ref_add)
|
|
self.assertEqual(nt.mul(t), ref_mul)
|
|
|
|
batch_size = 32
|
|
seq_lens = torch.randint(low=0, high=10, size=(batch_size,))
|
|
|
|
# [B, *, D], [B, 1, D] case
|
|
ts = [torch.randn((seq_len, embedding_dim)) for seq_len in seq_lens]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
t = torch.randn((batch_size, 1, embedding_dim), device=device, dtype=dtype)
|
|
_test_add_mul(nt, t)
|
|
|
|
# [B, *], [B, 1] case
|
|
ts = [torch.randn(seq_len) for seq_len in seq_lens]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype)
|
|
t = torch.randn((batch_size, 1), device=device, dtype=dtype)
|
|
_test_add_mul(nt, t)
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_nested_tensor_mul(self, device, dtype):
|
|
# nested tensor * nested tensor
|
|
(nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
|
|
ref = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
|
|
out = nt1 * nt2
|
|
self.assertEqual(ref, out)
|
|
# nested tensor * scalar
|
|
number = 10.0
|
|
scalar = torch.tensor(number).to(dtype).to(device)
|
|
ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
|
|
out_number0 = nt1 * number
|
|
out_number1 = number * nt1
|
|
out_scalar0 = nt1 * scalar
|
|
out_scalar1 = scalar * nt1
|
|
self.assertEqual(out_number0, ref)
|
|
self.assertEqual(out_number1, ref)
|
|
self.assertEqual(out_scalar0, ref)
|
|
self.assertEqual(out_scalar1, ref)
|
|
# error case: numel == 1 but dim > 0
|
|
vector = torch.tensor([number]).to(dtype).to(device)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both self and other to be nested, but got a nested self and non-nested other",
|
|
lambda: nt1.mul(vector)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both self and other to be nested, but got a non-nested self and nested other",
|
|
lambda: vector.mul(nt1)
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_nested_tensor_div(self, device, dtype):
|
|
nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
|
|
scale = 4.0
|
|
ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
|
|
out = nt / 4.0
|
|
self.assertEqual(ref, out)
|
|
ref_transposed = ref.transpose(1, 2)
|
|
out = nt.transpose(1, 2) / 4.0
|
|
self.assertEqual(ref_transposed, out)
|
|
|
|
ref = torch.nested.nested_tensor([t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())])
|
|
out = nt / nt2
|
|
self.assertEqual(ref, out)
|
|
|
|
out = nt.transpose(1, 2) / nt2.transpose(1, 2)
|
|
self.assertEqual(ref.transpose(1, 2), out)
|
|
|
|
nt_transpose_copy = torch.nested.nested_tensor([t.transpose(0, 1) for t in nt.unbind()])
|
|
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "div requires strides to match when given NestedTensors",
|
|
lambda: nt_transpose_copy.transpose(1, 2) / nt2)
|
|
|
|
nt = torch.nested.nested_tensor([torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype)
|
|
nt_chunks = nt.chunk(2, -1)
|
|
self.assertRaisesRegex(
|
|
RuntimeError, "div requires offsets to match when given NestedTensors",
|
|
lambda: nt_chunks[0] / nt_chunks[1])
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_nested_tensor_add_in_place(self, device, dtype):
|
|
(nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
|
|
ref = torch.nested.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
|
|
nt1 += nt2
|
|
self.assertEqual(ref, nt1)
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
@skipMeta
|
|
@torch.inference_mode()
|
|
def test_nested_tensor_mul_in_place(self, device, dtype):
|
|
# nested tensor * nested tensor
|
|
(nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
|
|
ref = torch.nested.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
|
|
nt1 *= nt2
|
|
self.assertEqual(ref, nt1)
|
|
# nested tensor * scalar
|
|
number = 10.0
|
|
scalar = torch.tensor(number).to(dtype).to(device)
|
|
ref = torch.nested.nested_tensor([t * number for t in nt1.unbind()])
|
|
out_number = nt1.clone()
|
|
out_number *= number
|
|
out_scalar = nt1.clone()
|
|
out_scalar *= scalar
|
|
self.assertEqual(out_number, ref)
|
|
self.assertEqual(out_scalar, ref)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]",
|
|
lambda: scalar.mul_(nt1)
|
|
)
|
|
# error case: numel == 1 but dim > 0
|
|
vector = torch.tensor([number]).to(dtype).to(device)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both self and other to be nested, but got a nested self and non-nested other",
|
|
lambda: nt1.mul_(vector)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both self and other to be nested, but got a non-nested self and nested other",
|
|
lambda: vector.mul_(nt1)
|
|
)
|
|
|
|
@onlyCPU
|
|
@skipMeta
|
|
@dtypes(torch.float)
|
|
def test_nested_tensor_sum_dim(self, device, dtype):
|
|
params = ((2, (1, 1)), ((4), (4, 4)), (10, (3, 5, 7)))
|
|
|
|
def test_sum(device, dtype, ntensors, max_sizes, dim, keepdim=True):
|
|
nt = random_nt(device, dtype, ntensors, max_sizes, require_non_empty=False)
|
|
nt2 = nt.clone()
|
|
ub2 = nt2.unbind()
|
|
nt.requires_grad_(True)
|
|
[t.requires_grad_(True) for t in ub2]
|
|
nt_sum = nt.sum(dim=dim, keepdim=keepdim)
|
|
ub2_sum = [t.sum(-1, keepdim=keepdim) for t in ub2]
|
|
self.assertEqual(nt_sum, torch.nested.nested_tensor(ub2_sum))
|
|
|
|
# test backward
|
|
# generate gradient tensor that has the same size as the output
|
|
size = nt_sum._nested_tensor_size()
|
|
gt2 = []
|
|
for i in range(ntensors):
|
|
gt2.append(torch.randn(size[i].tolist(), device=device, dtype=dtype))
|
|
gt = torch.nested.nested_tensor(gt2).clone()
|
|
nt_sum.backward(gt)
|
|
for t2, g2 in zip(ub2_sum, gt2):
|
|
t2.backward(g2)
|
|
self.assertEqual(nt.grad, torch.nested.nested_tensor([t.grad for t in ub2]))
|
|
return
|
|
|
|
for ntensors, max_sizes in params:
|
|
test_sum(device, dtype, ntensors, max_sizes, len(max_sizes))
|
|
|
|
# Test error inputs
|
|
with self.assertRaisesRegex(RuntimeError, "NestedTensor can only be reduced across the last"):
|
|
torch.nested.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(0, keepdim=True)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "NestedTensor only allows reduction of a single"):
|
|
torch.nested.nested_tensor([torch.tensor([[3, 4, 5]]), torch.tensor([[1, 2]])]).sum([0, 1], keepdim=True)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "NestedTensor always requires keepdim=True for now."):
|
|
torch.nested.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(-1)
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
def test_contiguous(self, device, dtype):
|
|
# Since we don't have access to the buffer in python this is harder to show what
|
|
# we are testing for. When we call chunk on a consistent dim of a NT
|
|
# for chunk_size > 1 the resulting tensors are views of the original NT
|
|
# whose numels is now less than the size of the buffer. Clone was
|
|
# previously creating a new NT with a buffer that was the same size as the
|
|
# original.
|
|
nt_contiguous = torch.nested.nested_tensor([torch.randn(2, 20, device=device, dtype=dtype),
|
|
torch.randn(4, 20, device=device, dtype=dtype)])
|
|
# Split up the last dimension which has a consistent size of 20 into 5 chunks
|
|
chunks = nt_contiguous.chunk(5, dim=-1)
|
|
|
|
# # Check chunks are contiguous after calling contiguous
|
|
for chunk in chunks:
|
|
self.assertFalse(chunk.is_contiguous())
|
|
self.assertTrue(chunk.contiguous().is_contiguous())
|
|
|
|
@dtypes(torch.float, torch.float16)
|
|
@skipMeta
|
|
def test_clone(self, device, dtype):
|
|
nt1 = random_nt(device, dtype, 4, (4, 4), (1, 1))
|
|
nt2 = nt1.clone()
|
|
# Verify the values match
|
|
self.assertEqual(nt1, nt2)
|
|
# Verify modifying nt2 doesn't affect nt1
|
|
nt2.mul_(nt1)
|
|
ub1 = nt1.unbind()
|
|
ub2 = nt2.unbind()
|
|
for i in range(len(ub1)):
|
|
self.assertNotEqual(ub1[i], ub2[i])
|
|
|
|
nt1.clone(memory_format=torch.preserve_format)
|
|
msg = "Nested tensor clone supports Preserve and Contiguous memory formats, called clone with memory format: ChannelsLast"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
nt1.clone(memory_format=torch.channels_last)
|
|
|
|
# cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
|
|
@decorateIf(xfailIfTorchDynamo, lambda params: params["layout"] == torch.jagged)
|
|
@dtypes(torch.float, torch.double)
|
|
@parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
|
|
def test_dropout(self, device, dtype, layout):
|
|
# edge case: empty nested tensor
|
|
# TODO: support empty NT in jagged layout
|
|
if layout == torch.strided:
|
|
nt0 = torch.nested.nested_tensor([], layout=layout)
|
|
y = torch.nn.functional.dropout(nt0, 0.5)
|
|
self.assertEqual(nt0, y)
|
|
# normal nested tensor
|
|
ntensors = 4
|
|
if layout == torch.jagged:
|
|
nt = random_nt(device, dtype, ntensors, (4, 4), (0, 3), layout=layout)
|
|
else:
|
|
nt = random_nt(device, dtype, ntensors, (4, 4), layout=layout)
|
|
# edge case: invalid dropout
|
|
self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
|
|
self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
|
|
self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
|
|
self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
|
|
# edge case: no dropout
|
|
dropouter = torch.nn.Dropout(0.0)
|
|
y0 = dropouter(nt)
|
|
y1 = torch.nn.functional.dropout(nt, 0.0)
|
|
self.assertEqual(nt, y0)
|
|
self.assertEqual(nt, y1)
|
|
# edge case: all dropout
|
|
dropouter = torch.nn.Dropout(1.0)
|
|
y0 = dropouter(nt)
|
|
y1 = torch.nn.functional.dropout(nt, 1.0)
|
|
nt0 = torch.zeros_like(nt)
|
|
self.assertEqual(nt0, y0)
|
|
self.assertEqual(nt0, y1)
|
|
# normal case: normal dropout
|
|
p = 0.2
|
|
y = torch.nn.functional.dropout(nt, p)
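# Expected result: dropped entries are zero and surviving entries are rescaled
# by 1 / (1 - p), matching inverted dropout.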
expect = nt.clone()
|
|
if layout == torch.jagged:
|
|
expect = torch.where(y == 0.0, y, nt)
|
|
expect /= 1.0 - p
|
|
self.assertEqual(y, expect)
|
|
else:
|
|
expect = nt.clone()
|
|
for i in range(ntensors):
|
|
actual_tensor = y[i].view(-1)
|
|
expect_tensor = expect[i].view(-1)
|
|
for j in range(actual_tensor.shape[0]):
|
|
if actual_tensor[j].item() == 0.0:
|
|
expect_tensor[j] = 0.0
|
|
else:
|
|
expect_tensor[j] /= 1.0 - p
|
|
self.assertEqual(y, expect)
|
|
with freeze_rng_state():
|
|
dropouter = torch.nn.Dropout(p)
|
|
y0 = dropouter(nt)
|
|
with freeze_rng_state():
|
|
y1 = torch.nn.functional.dropout(nt, p)
|
|
self.assertEqual(y0, y1)
|
|
|
|
@dtypes(torch.float, torch.double)
|
|
def test_dropout_noncontiguous(self, device, dtype):
|
|
ntensors = 4
|
|
nt0 = random_nt(device, dtype, ntensors, (4, 4))
|
|
nt1 = nt0.transpose(-1, -2)
|
|
p = 0.3
|
|
with freeze_rng_state():
|
|
dropouter = torch.nn.Dropout(p)
|
|
y0 = dropouter(nt0)
|
|
with freeze_rng_state():
|
|
y1 = torch.nn.functional.dropout(nt1, p).transpose(-1, -2)
|
|
self.assertEqual(y0, y1)
|
|
|
|
# cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_softmax(self, device, dtype):
|
|
# normal nested tensor
|
|
ntensors = 4
|
|
nt = random_nt(device, dtype, ntensors, (4, 4))
|
|
# error case: softmax across nested dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Cannot apply softmax across nested dimension 0",
|
|
lambda: torch.nn.functional.softmax(nt, 0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Cannot apply softmax across nested dimension 0",
|
|
lambda: torch.nn.functional.softmax(nt, -3)
|
|
)
|
|
# error case: dimension out of range
|
|
self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
|
|
self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
|
|
# normal case: should equal to padding -inf
|
|
softmaxer = torch.nn.Softmax(1)
|
|
y0 = softmaxer(nt)
|
|
y1 = torch.nn.functional.softmax(nt, 1)
|
|
self.assertEqual(y0, y1)
|
|
pt = torch.nested.to_padded_tensor(nt, float("-inf"))
|
|
# if an entire slice is padded, then softmax will return 0.0 / 0.0 = nan
|
|
# however, physically speaking that should be 0.0
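# e.g. a fully padded row [-inf, -inf] has exp values [0, 0], so the softmax
# denominator is 0 and every entry becomes 0.0 / 0.0 = nan; nan_to_num_ maps
# those entries back to 0.0.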
expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
|
|
self.assertEqual(torch.nested.to_padded_tensor(y0, 0.0), expect)
|
|
# edge case: empty nested tensor
|
|
nt0 = torch.nested.nested_tensor([])
|
|
y = torch.nn.functional.softmax(nt0, 1)
|
|
self.assertEqual(nt0, y)
|
|
# edge case: nesting scalars
|
|
nt1 = torch.nested.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
|
|
self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
|
|
self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))
|
|
|
|
@dtypes(torch.float, torch.double)
|
|
@torch.inference_mode()
|
|
def test_softmax_noncontiguous(self, device, dtype):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
|
|
self.assertEqual(
|
|
torch.nn.functional.softmax(nt_contiguous, -1),
|
|
torch.nn.functional.softmax(nt_noncontiguous, -1))
|
|
|
|
def _test_bmm(self, device, dtype):
|
|
# error case: one is nested but the other is not
|
|
nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype)
|
|
t = torch.randn(4, device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both to be nested, but got a nested self and non-nested other",
|
|
lambda: nt.bmm(t)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both to be nested, but got a non-nested self and nested other",
|
|
lambda: t.bmm(nt)
|
|
)
|
|
# error case: not 3D tensors
|
|
nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype)
|
|
nt2 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"batch1 must be a 3D tensor",
|
|
lambda: nt0.bmm(nt0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"batch1 must be a 3D tensor",
|
|
lambda: nt0.bmm(nt1)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"batch1 must be a 3D tensor",
|
|
lambda: nt0.bmm(nt2)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"batch1 must be a 3D tensor",
|
|
lambda: nt1.bmm(nt0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"batch1 must be a 3D tensor",
|
|
lambda: nt1.bmm(nt1)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"batch1 must be a 3D tensor",
|
|
lambda: nt1.bmm(nt2)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"batch2 must be a 3D tensor",
|
|
lambda: nt2.bmm(nt0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"batch2 must be a 3D tensor",
|
|
lambda: nt2.bmm(nt1)
|
|
)
|
|
# error case: incompatible batch size
|
|
nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((4, 6)),
|
|
torch.randn((4, 5)),
|
|
torch.randn((4, 7))],
|
|
device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
|
|
lambda: nt0.bmm(nt1)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
|
|
lambda: nt1.bmm(nt0)
|
|
)
|
|
# error case: underlying matrices cannot be multiplied
|
|
nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
|
|
lambda: nt0.bmm(nt0)
|
|
)
|
|
# normal nested tensor
|
|
nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype)
|
|
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
|
|
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0))
|
|
if dtype == torch.float16:
|
|
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
|
|
else:
|
|
self.assertEqual(actual, expect)
|
|
|
|
# test tensorcore path
|
|
nt0 = torch.nested.nested_tensor([torch.randn((2, 8)), torch.randn((3, 16))], device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((8, 8)), torch.randn((16, 8))], device=device, dtype=dtype)
|
|
actual = torch.nested.to_padded_tensor(nt0.bmm(nt1), 0.0)
|
|
expect = torch.nested.to_padded_tensor(nt0, 0.0).bmm(torch.nested.to_padded_tensor(nt1, 0.0))
|
|
if dtype == torch.float16:
|
|
self.assertEqual(actual, expect, rtol=1e-3, atol=1e-3)
|
|
else:
|
|
self.assertEqual(actual, expect)
|
|
|
|
@onlyCUDA
|
|
@dtypes(torch.float, torch.double, torch.float16)
|
|
def test_bmm_cuda(self, device, dtype):
|
|
self._test_bmm(device, dtype)
|
|
|
|
@onlyCPU
|
|
# cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_bmm_cpu(self, device, dtype):
|
|
self._test_bmm(device, dtype)
|
|
|
|
# cannot test torch.float16 because: RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_bmm_noncontiguous(self, device, dtype):
|
|
nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
|
|
nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
|
|
self.assertEqual(
|
|
nt0_contiguous.transpose(-1, -2).bmm(nt1_contiguous),
|
|
nt0_noncontiguous.transpose(-1, -2).bmm(nt1_noncontiguous))
|
|
|
|
@dtypes(torch.float, torch.double)
|
|
def test_matmul_with_bmm_path(self, device, dtype):
|
|
def unbind_rebind_matmul(nt1, nt2):
|
|
t1s = nt1.unbind()
|
|
t2s = nt2.unbind()
|
|
out_ts = [t1.matmul(t2) for t1, t2 in zip(t1s, t2s)]
|
|
return torch.nested.nested_tensor(out_ts)
|
|
|
|
# [N, n_head, *, head_dim], [N, n_head, head_dim, *]
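# Each component pair multiplies as (n_heads, seq_len1, head_dim) @ (n_heads, head_dim, seq_len2)
# -> (n_heads, seq_len1, seq_len2), which is what unbind_rebind_matmul computes per component.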
Ns = [1, 2, 5]
|
|
n_heads = np.random.randint(2, 5)
|
|
head_dim = 3
|
|
t1s = []
|
|
t2s = []
|
|
for N in Ns:
|
|
for _ in range(N):
|
|
seq_len1 = np.random.randint(2, 5)
|
|
seq_len2 = np.random.randint(2, 5)
|
|
t1s.append(torch.randn(n_heads, seq_len1, head_dim))
|
|
t2s.append(torch.randn(n_heads, head_dim, seq_len2))
|
|
nt1 = torch.nested.nested_tensor(t1s, device=device, dtype=dtype)
|
|
nt2 = torch.nested.nested_tensor(t2s, device=device, dtype=dtype)
|
|
self.assertEqual(torch.matmul(nt1, nt2), unbind_rebind_matmul(nt1, nt2))
|
|
|
|
# test with noncontiguous
|
|
t3s = []
|
|
t4s = []
|
|
for _ in range(N):
|
|
seq_len = np.random.randint(2, 5)
|
|
t3s.append(torch.randn(seq_len, n_heads, head_dim))
|
|
t4s.append(torch.randn(seq_len, n_heads, head_dim))
|
|
nt3 = torch.nested.nested_tensor(t3s, device=device, dtype=dtype).transpose(1, 2)
|
|
nt4 = torch.nested.nested_tensor(t4s, device=device, dtype=dtype).transpose(1, 2).transpose(2, 3)
|
|
self.assertEqual(torch.matmul(nt3, nt4), unbind_rebind_matmul(nt3, nt4))
|
|
|
|
# cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_matmul(self, device, dtype):
|
|
# error case: one is nested but the other is not
|
|
nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype)
|
|
t = torch.randn(4, device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both to be nested, but got a nested self and non-nested other",
|
|
lambda: torch.matmul(nt, t)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Expected both to be nested, but got a non-nested self and nested other",
|
|
lambda: torch.matmul(t, nt)
|
|
)
|
|
# error case: not 3+D tensors
|
|
nt0 = torch.nested.nested_tensor([], device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)], device=device, dtype=dtype)
|
|
nt2 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt0, nt0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt0, nt1)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt0, nt2)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt1, nt0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt1, nt1)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 1st input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt1, nt2)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt2, nt0)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: For nested tensors, only inputs with >= 3 dims are currently supported. 2nd input has rank: [0-9]+",
|
|
lambda: torch.matmul(nt2, nt1)
|
|
)
|
|
# error case: incompatible batch size
|
|
nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((4, 6)),
|
|
torch.randn((4, 5)),
|
|
torch.randn((4, 7))],
|
|
device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
|
|
lambda: torch.matmul(nt0, nt1)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"matmul: Expected size for the 1st dimension of 2nd input tensor to be: [0-9]+ but got: [0-9]+.",
|
|
lambda: torch.matmul(nt1, nt0)
|
|
)
|
|
# error case: incompatible batch sizes that would not even broadcast
nt0 = torch.nested.nested_tensor([torch.randn((2, 2, 4)),
|
|
torch.randn((2, 3, 4))],
|
|
device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((3, 4, 6)),
|
|
torch.randn((3, 4, 5))],
|
|
device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"matmul(): For nested tensors, batch dimensions must have the same sizes,",
|
|
lambda: torch.matmul(nt0, nt1)
|
|
)
|
|
# error case: incompatible batch sizes that should technically broadcast
|
|
nt0 = torch.nested.nested_tensor([torch.randn((2, 2, 4)),
|
|
torch.randn((1, 3, 4))],
|
|
device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((1, 4, 6)),
|
|
torch.randn((3, 4, 5))],
|
|
device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"matmul(): For nested tensors, batch dimensions must have the same sizes,",
|
|
lambda: torch.matmul(nt0, nt1)
|
|
)
|
|
# error case: underlying matrices cannot be multiplied
|
|
nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))], device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"matmul(): Nested tensors cannot be matrix multiplied",
|
|
lambda: torch.matmul(nt0, nt0)
|
|
)
|
|
# normal nested tensor: 3D
|
|
nt0 = torch.nested.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))], device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))], device=device, dtype=dtype)
|
|
actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
|
|
expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0))
|
|
self.assertEqual(actual, expect)
|
|
# normal nested tensor: 4D (with testing for batch_size=1)
|
|
nt0 = torch.nested.nested_tensor([torch.randn((1, 2, 4)),
|
|
torch.randn((8, 3, 7))],
|
|
device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((1, 4, 6)),
|
|
torch.randn((8, 7, 5))],
|
|
device=device, dtype=dtype)
|
|
actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
|
|
expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0))
|
|
self.assertEqual(actual, expect)
|
|
# normal nested tensor: 5D
|
|
nt0 = torch.nested.nested_tensor([torch.randn((8, 9, 2, 4)),
|
|
torch.randn((8, 9, 3, 7))],
|
|
device=device, dtype=dtype)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((8, 9, 4, 6)),
|
|
torch.randn((8, 9, 7, 5))],
|
|
device=device, dtype=dtype)
|
|
actual = torch.nested.to_padded_tensor(torch.matmul(nt0, nt1), 0.0)
|
|
expect = torch.matmul(torch.nested.to_padded_tensor(nt0, 0.0), torch.nested.to_padded_tensor(nt1, 0.0))
|
|
self.assertEqual(actual, expect)
|
|
|
|
# only supported on CUDA for now
|
|
@dtypes(torch.float, torch.double)
|
|
def test_matmul_nt_with_broadcasted_t(self, device, dtype):
|
|
# NT (B, *, C, D) with T (D, E) broadcasting case
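# The dense (D, E) matrix is broadcast over the batch and ragged dims, so the
# result keeps the NT structure: here (3, *, 4, 5) @ (5, 6) -> (3, *, 4, 6).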
nt = random_nt_from_dims([3, None, 4, 5], device=device, dtype=dtype)
|
|
t = torch.randn(5, 6, device=device, dtype=dtype)
|
|
output = torch.matmul(nt, t)
|
|
|
|
# should be equivalent to matmul-ing each component with the dense tensor
|
|
self.assertEqual(nt.size(0), output.size(0))
|
|
for component, out_component in zip(nt, output):
|
|
self.assertEqual(out_component, torch.matmul(component, t))
|
|
|
|
# cannot test torch.float16 because: RuntimeError: "bmm" not implemented for 'Half'
|
|
@dtypes(torch.float, torch.double)
|
|
def test_matmul_noncontiguous(self, device, dtype):
|
|
nt0_contiguous, nt0_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
|
|
nt1_contiguous, nt1_noncontiguous = random_nt_noncontiguous_pair((6, 7), device, dtype)
|
|
self.assertEqual(
|
|
torch.matmul(nt0_contiguous.transpose(-1, -2), nt1_contiguous),
|
|
torch.matmul(nt0_noncontiguous.transpose(-1, -2), nt1_noncontiguous))
|
|
|
|
@dtypes(torch.float, torch.double)
|
|
def test_linear(self, device, dtype):
|
|
a = torch.randn(1, 2, device=device, dtype=dtype)
|
|
b = torch.randn(2, 2, device=device, dtype=dtype)
|
|
c = torch.randn(3, 2, device=device, dtype=dtype)
|
|
nt = torch.nested.nested_tensor([a, b, c])
|
|
|
|
weight = torch.randn(2, 2, device=device, dtype=dtype)
|
|
bias = torch.randn(2, device=device, dtype=dtype)
|
|
# success case
|
|
torch.functional.F.linear(nt, weight, bias)
|
|
|
|
# invalid nested tensor dimension
|
|
msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2'
|
|
nt1 = torch.nested.nested_tensor([torch.randn(1, device=device, dtype=dtype),
|
|
torch.randn(2, device=device, dtype=dtype)])
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt1, weight, bias)
|
|
|
|
# invalid weight shape
|
|
msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3'
|
|
weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt, weight1, bias)
|
|
|
|
# inconsistent last dim of nested tensor
|
|
msg = r"Expected all tensors in nested tensor to have the same trailing dimension, instead last dimension equals:"
|
|
nt2 = torch.nested.nested_tensor([torch.randn(1, 2, device=device, dtype=dtype),
|
|
torch.randn(2, 3, device=device, dtype=dtype)])
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt2, weight, bias)
|
|
|
|
# Mismatch of nested tensor last dim and weight dimension
|
|
weight2 = torch.randn(2, 4, device=device, dtype=dtype)
|
|
msg = r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" \
|
|
r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt, weight2, bias)
|
|
|
|
# Nested tensor input and nested weight
|
|
nt_weight = nt.clone()
|
|
msg = r"Linear does not support nested weight when input is a nested tensor."
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
torch.functional.F.linear(nt, nt_weight, bias)
|
|
|
|
# TODO: test noncontiguous linear
|
|
# For now this tests the error message of linear
|
|
# since linear does not support noncontiguous buffer yet
|
|
@dtypes(torch.float, torch.double)
|
|
def test_linear_noncontiguous(self, device, dtype):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
|
|
weight = torch.randn((8, 5), device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"for now linear only supports contiguous nested tensor",
|
|
lambda: torch.nn.functional.linear(nt_noncontiguous, weight)
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_to_padded_tensor_zero_numel_errors(self, device, dtype):
|
|
ts = [torch.ones(1, 0), torch.ones(0, 0)]
|
|
nt = torch.nested.nested_tensor(ts, device=device, dtype=dtype, layout=torch.strided)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"at least one constituent tensor should have non-zero numel",
|
|
lambda: torch.nested.to_padded_tensor(nt, 0.0)
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_transpose(self, device, dtype):
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
# error case: transpose nested dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Nested tensor dimension 0 cannot be transposed",
|
|
lambda: nt.transpose(0, 1)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Nested tensor dimension 0 cannot be transposed",
|
|
lambda: nt.transpose(1, -3)
|
|
)
|
|
# error case: dimension out of range
|
|
self.assertRaises(IndexError, lambda: nt.transpose(1, 3))
|
|
self.assertRaises(IndexError, lambda: nt.transpose(-4, -1))
|
|
# normal case
|
|
ntT = nt.transpose(-1, -2)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.transpose(-1, -2)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_squeeze_unsqueeze(self, device, dtype):
|
|
a = torch.arange(6).reshape(2, 3)
|
|
b = torch.arange(15).reshape(5, 3)
|
|
nt = torch.nested.nested_tensor([a, b], device=device, dtype=dtype)
|
|
# error case: squeeze no dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"For nested tensors, squeeze without the dim argument",
|
|
lambda: nt.squeeze()
|
|
)
|
|
# error case: squeeze nested dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"For nested tensors, squeezing dimension 0",
|
|
lambda: nt.squeeze(0)
|
|
)
|
|
# error case: dimension out of range
|
|
self.assertRaises(IndexError, lambda: nt.squeeze(3))
|
|
# error case: squeeze nested tensor of singleton tensors
|
|
c = torch.ones(1)
|
|
nt_singleton = torch.nested.nested_tensor([c, c], device=device, dtype=dtype)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"For nested tensors, squeezing a nested tensor of singleton",
|
|
lambda: nt_singleton.squeeze(1)
|
|
)
|
|
|
|
# squeezing a dim which does not have size 1 should be a no-op
|
|
nt2 = nt.squeeze(-1)
|
|
self.assertEqual(nt, nt2)
|
|
|
|
# test cases that should work
|
|
nt_sizes = nt._nested_tensor_size()
|
|
nt_strides = nt._nested_tensor_strides()
|
|
for i in range(-2, 4):
|
|
if (i == 0):
|
|
# cannot unsqueeze batch dim
|
|
continue
|
|
nt_unsqueezed = nt.unsqueeze(i)
|
|
# negative dim will correspond to unsqueeze() applied at dim = dim + nt.dim() + 1
|
|
wrapped_i = i + nt.dim() + 1 if i < 0 else i
|
|
# col_index into nt size tensor is requires subtraction of 1 to ignore batch dim
|
|
size_idx = wrapped_i - 1
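# e.g. for this 3-dim NT, i = -1 wraps to -1 + 3 + 1 = 3 and size_idx = 2,
# i.e. the new singleton dim is appended after the last existing dim.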
self.assertEqual(nt_unsqueezed._nested_tensor_size()[:, size_idx], torch.ones(2, dtype=torch.long))
|
|
unsqueezed_stride = nt_unsqueezed._nested_tensor_strides()[:, size_idx]
|
|
if (i == nt.ndim or i == -1):
|
|
self.assertEqual(unsqueezed_stride, torch.ones(2, dtype=torch.long))
|
|
else:
|
|
stride_col_after = nt_strides[:, size_idx]
|
|
size_col_after = nt_sizes[:, size_idx]
|
|
self.assertEqual(unsqueezed_stride, stride_col_after * size_col_after)
|
|
nt_squeezed = nt_unsqueezed.squeeze(i)
|
|
self.assertEqual(nt_squeezed, nt)
|
|
self.assertEqual(nt_squeezed._nested_tensor_size(), nt_sizes)
|
|
self.assertEqual(nt_squeezed._nested_tensor_strides(), nt_strides)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_transpose_inference_mode_interaction(self, device, dtype):
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
# Construct in default mode and transpose while in inference mode
|
|
with torch.inference_mode():
|
|
ntT = nt.transpose(-1, -2)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.transpose(-1, -2)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
|
|
# Construct and transpose while in inference mode
|
|
with torch.inference_mode():
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
ntT = nt.transpose(-1, -2)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.transpose(-1, -2)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_view(self, device, dtype):
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
# error case: empty shape
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"shape '\[\]' is invalid for a nested tensor",
|
|
lambda: nt.view(())
|
|
)
|
|
# error case: empty nested tensor
|
|
nt_empty = torch.nested.nested_tensor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"empty nested tensor cannot be reshaped",
|
|
lambda: nt_empty.view(-1)
|
|
)
|
|
# error case: -1 for batch size
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"view: For now nested view cannot change or infer the implicit batch dimension",
|
|
lambda: nt.view(-1, 2, 3)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"shape '\[.*\]' is invalid for input of size [0-9]+",
|
|
lambda: nt.view(4, 2, 3)
|
|
)
|
|
# normal case
|
|
x0 = torch.randn((2, 20), device=device, dtype=dtype)
|
|
x1 = torch.randn((3, 20), device=device, dtype=dtype)
|
|
nt = torch.nested.nested_tensor([x0, x1])
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
# error case, trying to reshape batch dim to a legit shape
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"For now nested view cannot change or infer the implicit batch dimension",
|
|
lambda: nt.transpose(-1, -2).view(40, -1)
|
|
)
|
|
# inherit only the ragged dimension
|
|
# (2, 20) -> (2, 5, 4)
|
|
# (3, 20) -> (3, 5, 4)
|
|
nt1 = nt.view(2, -1, 5, 4)
|
|
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
|
|
pt1 = pt.view(2, -1, 5, 4)
|
|
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
|
|
|
|
# more than one -1 (even for "old" dims), should fail
|
|
# this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
|
|
# but we ban "inherit old behavior" for >1 dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"only one dimension can be inferred",
|
|
lambda: nt1.view(2, -1, -1, 2, 2)
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_view_inference_mode_interaction(self, device, dtype):
|
|
# Construct in default mode and view while in inference mode
|
|
nt = torch.nested.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype)
|
|
with torch.inference_mode():
|
|
ntT = nt.view(2, -1, 4, 5)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.view(2, -1, 4, 5)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
# Construct and view while in inference mode
|
|
with torch.inference_mode():
|
|
nt = torch.nested.nested_tensor([torch.randn((2, 20)), torch.randn((3, 20))], device=device, dtype=dtype)
|
|
ntT = nt.view(2, -1, 4, 5)
|
|
ptT_from_ntT = noncontiguous_to_padded_tensor(ntT)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
ptT = pt.view(2, -1, 4, 5)
|
|
self.assertEqual(ptT, ptT_from_ntT)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_reshape(self, device, dtype):
|
|
nt = random_nt(device, dtype, 4, (4, 4))
|
|
# error case: empty shape
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"shape '\[\]' is invalid for a nested tensor",
|
|
lambda: nt.reshape(())
|
|
)
|
|
# error case: empty nested tensor
|
|
nt_empty = torch.nested.nested_tensor([])
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"empty nested tensor cannot be reshaped",
|
|
lambda: nt_empty.reshape(-1)
|
|
)
|
|
# error case: -1 for batch size
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
|
|
lambda: nt.reshape(-1, 2, 3)
|
|
)
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"shape '\[.*\]' is invalid for input of size [0-9]+",
|
|
lambda: nt.reshape(4, 2, 3)
|
|
)
|
|
# normal case
|
|
x0 = torch.randn((2, 20), device=device, dtype=dtype)
|
|
x1 = torch.randn((3, 20), device=device, dtype=dtype)
|
|
nt = torch.nested.nested_tensor([x0, x1]) # (2, (2, 3), 20)
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0)
|
|
# error case, trying to reshape batch dim to a legit shape
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"reshape: For now nested reshape cannot change or infer the implicit batch dimension",
|
|
lambda: nt.transpose(-1, -2).reshape(40, -1)
|
|
)
|
|
# inherit only the ragged dimension
|
|
# (2, 20) -> (2, 5, 4)
|
|
# (3, 20) -> (3, 5, 4)
|
|
nt1 = nt.reshape(2, -1, 5, 4)
|
|
# (2, 3, 20) -> (2, 3, 5, 4) -> (2, 4, 5, 4)
|
|
pt1 = pt.reshape(2, -1, 5, 4)
|
|
self.assertEqual(noncontiguous_to_padded_tensor(nt1), pt1)
|
|
|
|
# more than one -1 (even for "old" dims), should fail
|
|
# this attempts to do # (2, (2, 3), 5, 4) -> (2, (2, 3), 5, 2, 2)
|
|
# but we ban "inherit old behavior" for >1 dimension
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"only one dimension can be inferred",
|
|
lambda: nt1.reshape(2, -1, -1, 2, 2)
|
|
)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_narrow(self, device, dtype):
|
|
nt = random_nt_from_dims([5, None, None, None], device=device, dtype=dtype)
|
|
|
|
# narrow on dim=0 from start to end
|
|
bounds = [(0, 5), (0, 3), (1, 2), (1, 5), (2, 4)]
|
|
for start, end in bounds:
|
|
length = end - start
|
|
narrowed = nt.narrow(dim=0, start=start, length=length)
|
|
# ensure output is a view
|
|
self.assertTrue(narrowed._base is nt)
|
|
for nc, c in zip(narrowed.unbind(), nt.unbind()[start:end]):
|
|
self.assertEqual(nc, c)
|
|
|
|
# dim != 0 is not supported
|
|
for dim in range(1, nt.dim()):
|
|
with self.assertRaisesRegex(RuntimeError, "only dim=0 supported for nested tensors"):
|
|
nt.narrow(dim=dim, start=0, length=1)
|
|
|
|
# error case: non-contiguous NT
|
|
_, nt_noncont = random_nt_noncontiguous_pair((2, 3, 4))
|
|
with self.assertRaisesRegex(RuntimeError, "only contiguous nested tensors supported"):
|
|
nt_noncont.narrow(dim=0, start=0, length=1)
|
|
|
|
@parametrize("input_dim", [3, 4])
|
|
def test_scaled_dot_product_attention(self, device, input_dim):
|
|
|
|
def rand_tensor(*shape):
|
|
return torch.randn(shape, device=device)
|
|
|
|
E = 8
|
|
if input_dim == 3:
|
|
# Shape: (N, L, E); ragged L
|
|
query = torch.nested.nested_tensor([rand_tensor(2, E), rand_tensor(3, E), rand_tensor(4, E)])
|
|
|
|
# Shape: (N, S, E); ragged S
|
|
key = torch.nested.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)])
|
|
value = torch.nested.nested_tensor([rand_tensor(3, E), rand_tensor(4, E), rand_tensor(5, E)])
|
|
elif input_dim == 4:
|
|
# In the 4D case both L and S are ragged
# Shape: (N, N', L, E); ragged N' and L
|
|
query = torch.nested.nested_tensor([rand_tensor(2, 2, E), rand_tensor(3, 3, E), rand_tensor(4, 4, E)])
|
|
# Shape: (N, N', S, E); ragged N' and S
|
|
key = torch.nested.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)])
|
|
value = torch.nested.nested_tensor([rand_tensor(2, 3, E), rand_tensor(3, 4, E), rand_tensor(4, 5, E)])
|
|
else:
|
|
self.fail(f"Invalid input_dim {input_dim} encountered in SDP test")
|
|
|
|
def rand_mask(size):
|
|
return torch.randint(0, 2, size=size, dtype=torch.bool, device=device)
|
|
|
|
# Shape: (N, L, S); ragged L and S matching above
|
|
attn_mask = torch.nested.nested_tensor([rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))])
|
|
|
|
dropout_p = 0.0 # no dropout for reproducibility
|
|
|
|
# Success case: no attn_mask set and is_causal=False.
|
|
actual = torch.nn.functional.scaled_dot_product_attention(
|
|
query, key, value, attn_mask=None, is_causal=False, dropout_p=dropout_p)
|
|
|
|
expected_outputs = []
|
|
for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()):
|
|
output = torch.nn.functional.scaled_dot_product_attention(
|
|
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_mask=None, dropout_p=dropout_p)
|
|
expected_outputs.append(output.squeeze(0))
|
|
expected_output_nested = torch.nested.nested_tensor(expected_outputs)
|
|
self.assertEqual(actual, expected_output_nested)
|
|
|
|
# Error case: explicit attn_mask set.
|
|
with self.assertRaisesRegex(RuntimeError, "not supported when an explicit attn_mask is set"):
|
|
torch.nn.functional.scaled_dot_product_attention(
|
|
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p)
|
|
|
|
# Error case: is_causal=True.
|
|
with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"):
|
|
torch.nn.functional.scaled_dot_product_attention(
|
|
query, key, value, dropout_p=dropout_p, is_causal=True)
|
|
|
|
@dtypes(torch.float, torch.float16, torch.double)
|
|
def test_empty_like(self, device, dtype):
|
|
ntensors = 4
|
|
nt = random_nt(device, dtype, ntensors, (4, 4))
|
|
|
|
# Create empty on same device as original nested tensor
|
|
nt_empty = torch.empty_like(nt)
|
|
assert nt.is_same_size(nt_empty)
|
|
self.assertEqual(nt.dtype, nt_empty.dtype)
|
|
self.assertEqual(nt.device, nt_empty.device)
|
|
self.assertEqual(nt.layout, nt_empty.layout)
|
|
|
|
if torch.cuda.is_available():
|
|
if device == "cpu":
|
|
nt_cuda = torch.empty_like(nt, device='cuda')
|
|
self.assertEqual(torch.device("cuda").type, nt_cuda.device.type)
|
|
else:
|
|
nt_cpu = torch.empty_like(nt, device='cpu')
|
|
self.assertEqual(torch.device("cpu").type, nt_cpu.device.type)
|
|
|
|
# Check changing dtype of empty_like nested tensor output
|
|
dtype_set = {torch.float, torch.float16, torch.double}
|
|
for other_dtype in dtype_set - {dtype}:
|
|
nt_empty_other_dtype = torch.empty_like(nt, dtype=other_dtype)
|
|
self.assertEqual(nt.dtype, dtype)
|
|
self.assertEqual(nt_empty_other_dtype.dtype, other_dtype)
|
|
self.assertEqual(nt.device, nt_empty.device)
|
|
self.assertEqual(nt.layout, nt_empty.layout)
|
|
|
|
# Create tensor for autograd
|
|
nt_empty_req_grad = torch.empty_like(nt, requires_grad=True)
|
|
self.assertEqual(nt_empty_req_grad.requires_grad, True)
|
|
|
|
# Test noncontiguous tensor does not fail to copy
|
|
nt_cont, nt_noncont = random_nt_noncontiguous_pair((2, 3, 6, 7))
|
|
nt_empty = torch.empty_like(nt_cont)
|
|
assert nt_cont.is_same_size(nt_empty)
|
|
nt_empty_non_contig = torch.empty_like(nt_noncont)
|
|
assert nt_noncont.is_same_size(nt_empty_non_contig)
|
|
|
|
# Test the contiguous memory format option
|
|
nt_empty_contig = torch.empty_like(nt_cont, memory_format=torch.contiguous_format)
|
|
assert nt_cont.is_same_size(nt_empty_contig)
|
|
assert nt_empty_contig.is_contiguous()
|
|
|
|
nt_empty_non_contig = torch.empty_like(nt_noncont, memory_format=torch.contiguous_format)
|
|
assert nt_noncont.is_same_size(nt_empty_non_contig)
|
|
assert nt_empty_non_contig.is_contiguous()
|
|
|
|
# Test other memory formats fail
|
|
self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last))
|
|
self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last))
|
|
self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_cont, memory_format=torch.channels_last_3d))
|
|
self.assertRaises(RuntimeError, lambda: torch.empty_like(nt_noncont, memory_format=torch.channels_last_3d))
|
|
|
|
@markDynamoStrictTest
|
|
class TestNestedTensorAutograd(TestCase):
|
|
# Note [Gradcheck args check_batched_grad=False] the common_utils testing version of gradcheck
|
|
# includes the default parameters used for testing ops with gradcheck. However nested tensor
|
|
# does not support the stack op therefore we turn it off for these tests
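# i.e. most tests in this class call gradcheck(grad_test_func, inputs=data, check_batched_grad=False)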
def _create_leaf_nested_tensor_from_list(self, tensor_device, requires_grad=False):
|
|
return torch.nested.nested_tensor([torch.randn(1, 2,),
|
|
torch.randn(7, 8)], requires_grad=requires_grad, device=tensor_device)
|
|
|
|
def _create_nested_tensor_from_list(self, tensor_device, requires_grad=False):
|
|
return torch.nested.as_nested_tensor([torch.randn(1, 2, requires_grad=requires_grad),
|
|
torch.randn(7, 8, requires_grad=requires_grad)], device=tensor_device)
|
|
|
|
def _create_nested_tensor_from_mask(self, tensor_device, requires_grad=False):
|
|
data = torch.randn(2, 3, 4, requires_grad=requires_grad, device=tensor_device)
|
|
mask = torch.ones_like(data[:, :, 0]).bool()
|
|
return torch._nested_tensor_from_mask(data, mask)
|
|
|
|
def test_as_nested_tensor_propagates_gradients(self, device):
|
|
a = torch.arange(3, dtype=torch.float, device=device)
|
|
b = torch.arange(5, dtype=torch.float, device=device)
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
# tensors with requires_grad=False are leaves
|
|
self.assertTrue(nt.is_leaf)
|
|
self.assertTrue(not nt.requires_grad)
|
|
|
|
a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
|
|
b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
|
|
nt2 = torch.nested.as_nested_tensor([a, b])
|
|
fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device)
|
|
nt2.backward(fake_grad)
|
|
self.assertEqual(a.grad, fake_grad[0])
|
|
self.assertEqual(b.grad, fake_grad[1])
|
|
|
|
def test_nested_tensor_generates_leaf(self, device):
|
|
a = torch.arange(3, dtype=torch.float, requires_grad=True, device=device)
|
|
b = torch.arange(5, dtype=torch.float, requires_grad=True, device=device)
|
|
|
|
nt = torch.nested.nested_tensor([a, b], requires_grad=False)
|
|
self.assertTrue(nt.is_leaf)
|
|
self.assertTrue(not nt.requires_grad)
|
|
|
|
nt2 = torch.nested.nested_tensor([a, b], requires_grad=True)
|
|
self.assertTrue(nt2.is_leaf)
|
|
self.assertTrue(nt2.requires_grad)
|
|
|
|
fake_grad = torch.nested.nested_tensor([torch.ones_like(a), torch.zeros_like(b)], device=device)
|
|
nt2.backward(fake_grad)
|
|
self.assertEqual(nt2.grad, fake_grad)
|
|
self.assertEqual(a.grad, None)
|
|
self.assertEqual(b.grad, None)
|
|
|
|
def test_set_requires_grad_from_list(self, device):
|
|
nt = self._create_nested_tensor_from_list(device)
|
|
nt.requires_grad_()
|
|
assert nt.requires_grad
|
|
|
|
def test_set_requires_grad_from_mask(self, device):
|
|
nt = self._create_nested_tensor_from_mask(device)
|
|
nt.requires_grad_()
|
|
assert nt.requires_grad
|
|
|
|
def test_backward_for_add_op(self, device):
|
|
nt_1 = self._create_nested_tensor_from_mask(device)
|
|
nt_2 = self._create_nested_tensor_from_mask(device)
|
|
|
|
nt_1.requires_grad_()
|
|
c = nt_1 + nt_2
|
|
|
|
assert nt_1.requires_grad
|
|
assert c.requires_grad
|
|
grad_output = self._create_nested_tensor_from_mask(device)
|
|
c.backward(grad_output)
|
|
|
|
# Grad check doesn't work with nested yet.
|
|
# d/dnt_1 (nt + nt_1) = 1*grad_output
|
|
self.assertEqual(nt_1.grad, grad_output)
|
|
|
|
def test_backward_for_sub_op(self, device):
|
|
nt_1 = self._create_nested_tensor_from_mask(device)
|
|
nt_2 = self._create_nested_tensor_from_mask(device)
|
|
|
|
nt_1.requires_grad_()
|
|
nt_2.requires_grad_()
|
|
c = nt_1 - nt_2
|
|
|
|
assert nt_1.requires_grad
|
|
assert nt_2.requires_grad
|
|
assert c.requires_grad
|
|
grad_output = self._create_nested_tensor_from_mask(device)
|
|
c.backward(grad_output)
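# d/dnt_1 (nt_1 - nt_2) = grad_output and d/dnt_2 (nt_1 - nt_2) = -grad_output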
self.assertEqual(nt_1.grad, grad_output)
|
|
self.assertEqual(nt_2.grad, -1 * grad_output)
|
|
|
|
def test_backward_sub_strided(self, device):
|
|
a = torch.nested.nested_tensor([torch.randn(9, 2, 4), torch.randn(12, 2, 4)], requires_grad=True, device=device)
|
|
b = torch.nested.nested_tensor([torch.randn(9, 4, 2), torch.randn(12, 4, 2)], requires_grad=True, device=device)
|
|
c = a - b.transpose(-1, -2)
|
|
grad_output = c.clone()
|
|
c.backward(grad_output)
|
|
self.assertEqual(a.grad, grad_output)
|
|
self.assertEqual(b.grad, -1 * grad_output.transpose(-1, -2))
|
|
|
|
def test_backward_add_strided(self, device):
|
|
a = torch.nested.nested_tensor([torch.randn(9, 2, 4), torch.randn(12, 2, 4)], requires_grad=True, device=device)
|
|
b = torch.nested.nested_tensor([torch.randn(9, 4, 2), torch.randn(12, 4, 2)], requires_grad=True, device=device)
|
|
c = a + b.transpose(-1, -2)
|
|
grad_output = c.clone()
|
|
c.backward(grad_output)
|
|
self.assertEqual(a.grad, grad_output)
|
|
self.assertEqual(b.grad, grad_output.transpose(-1, -2))
|
|
|
|
# Test Factory Functions
|
|
def test_nested_tensor_to_padded_tensor(self, device):
|
|
for padding_val in [0, 1]:
|
|
nt = self._create_leaf_nested_tensor_from_list(tensor_device=device, requires_grad=True)
|
|
|
|
out = torch.nested.to_padded_tensor(nt, padding_val)
|
|
grad_output = torch.ones(out.shape, device=device)
|
|
out.backward(grad_output)
|
|
|
|
self.assertEqual(nt.grad, torch.nested.nested_tensor([torch.ones(1, 2), torch.ones(7, 8)], device=device))
|
|
|
|
def test_nested_tensor_from_mask_and_to_padded(self, device):
|
|
N, L, D = 2, 4, 4
|
|
mask = torch.ones(N, L, device=device)
|
|
for i in range(1, N):
|
|
end = torch.randint(1, L - 1, (1,), device=device)
|
|
mask[i, end:] = 0
|
|
|
|
mask[0, :] = 1
|
|
mask = mask.bool()
|
|
|
|
data = torch.randn(N, L, D, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(inpt):
|
|
nt = torch._nested_tensor_from_mask(inpt, mask)
|
|
# This implicitly tests to_padded_tensor grads
|
|
return torch.nested.to_padded_tensor(nt, 0)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_from_padded(self, device):
|
|
nested_size = torch.tensor([[1, 2], [2, 2]])
|
|
padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64, device=device)
|
|
padded_tensor[0, 1, :] = 0
|
|
padded_tensor.requires_grad_()
|
|
|
|
def grad_test_func(tensor, nested_size):
|
|
nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=False)
|
|
# This implicitly tests to_padded_tensor grads
|
|
return torch.nested.to_padded_tensor(nt, 0)
|
|
|
|
data = (padded_tensor, nested_size)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_from_padded_fused(self, device):
|
|
nested_size = torch.tensor([[1, 8], [2, 8]])
|
|
padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64, device=device)
|
|
padded_tensor[0, 1, :] = 0
|
|
padded_tensor.requires_grad_()
|
|
|
|
def grad_test_func(tensor, nested_size):
|
|
nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=True)
|
|
# This implicitly tests to_padded_tensor grads
|
|
return torch.nested.to_padded_tensor(nt, 0)
|
|
data = (padded_tensor, nested_size)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_from_list(self, device):
|
|
|
|
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
c = torch.nested.as_nested_tensor([a, b, c])
|
|
# This implicitly tests to_padded_tensor grads
return torch.nested.to_padded_tensor(c, 0)
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
@decorateIf(
|
|
xfailIfTorchDynamo,
|
|
# only fails in python 3.11. TODO: Debug this!
|
|
lambda params: params["layout"] == torch.jagged and sys.version_info >= (3, 11)
|
|
)
|
|
@parametrize("layout", [torch.strided, torch.jagged], name_fn=layout_name)
|
|
def test_dropout_backward(self, layout):
|
|
if layout == torch.jagged:
|
|
nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 5))], requires_grad=True, layout=layout)
|
|
else:
|
|
nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True, layout=layout)
|
|
p = 0.2
|
|
y = torch.nn.functional.dropout(nt, p)
|
|
y.backward(nt.clone().detach())
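# dropout's backward applies the same mask and 1 / (1 - p) scaling to grad_output;
# since grad_output here is nt itself, the resulting grad equals the forward output y.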
self.assertEqual(nt.grad, y)
|
|
|
|
def test_nested_tensor_bmm_gradcheck(self, device):
|
|
a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, d):
|
|
nt0 = torch.nested.as_nested_tensor([a, b])
|
|
nt1 = torch.nested.as_nested_tensor([c, d])
|
|
result = nt0.bmm(nt1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
data = (a, b, c, d)
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
|
|
|
|
def test_nested_tensor_bmm_backward(self, device):
|
|
nt0 = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((6, 4)), torch.randn((6, 5))], requires_grad=True, device=device)
|
|
with torch.no_grad():
|
|
pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
|
|
pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt0.bmm(nt1)
|
|
ypt = pt0.bmm(pt1)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)
|
|
|
|
def test_nested_tensor_matmul_gradcheck(self, device):
|
|
a = torch.randn(2, 6, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 6, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(6, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
d = torch.randn(6, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, d):
|
|
nt0 = torch.nested.as_nested_tensor([a, b])
|
|
nt1 = torch.nested.as_nested_tensor([c, d])
|
|
result = torch.matmul(nt0, nt1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
data = (a, b, c, d)
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
|
|
|
|
def test_nested_tensor_matmul_backward(self, device):
|
|
nt0 = torch.nested.nested_tensor([torch.randn((7, 2, 6)), torch.randn((7, 3, 6))], requires_grad=True, device=device)
|
|
nt1 = torch.nested.nested_tensor([torch.randn((7, 6, 4)), torch.randn((7, 6, 5))], requires_grad=True, device=device)
|
|
with torch.no_grad():
|
|
pt0 = torch.nested.to_padded_tensor(nt0, 0.0).requires_grad_(True)
|
|
pt1 = torch.nested.to_padded_tensor(nt1, 0.0).requires_grad_(True)
|
|
|
|
ynt = torch.matmul(nt0, nt1)
|
|
ypt = torch.matmul(pt0, pt1)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt0.grad, 0.0), pt0.grad)
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt1.grad, 0.0), pt1.grad)
|
|
|
|
def test_nested_tensor_transpose_gradcheck(self, device):
|
|
a = torch.randn(2, 5, requires_grad=True, device=device)
|
|
b = torch.randn(3, 4, requires_grad=True, device=device)
|
|
|
|
def grad_test_func(a, b):
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
result = nt.transpose(-2, -1).transpose(-2, -1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
data = (a, b)
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
|
|
|
|
def test_nested_tensor_transpose_backward(self, device):
|
|
nt = torch.nested.nested_tensor([torch.randn((2, 5)), torch.randn((3, 4))], requires_grad=True, device=device)
|
|
with torch.no_grad():
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt.transpose(-2, -1)
|
|
ypt = pt.transpose(-2, -1)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
|
|
|
|
def test_nested_tensor_reshape_gradcheck(self, device):
|
|
a = torch.randn(2, 6, requires_grad=True, device=device)
|
|
b = torch.randn(3, 6, requires_grad=True, device=device)
|
|
|
|
def grad_test_func(a, b):
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
result = nt.reshape(2, -1, 2, 3)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
data = (a, b)
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=data, eps=1e-3)
|
|
|
|
def test_nested_tensor_reshape_backward(self):
|
|
nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True)
|
|
with torch.no_grad():
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt.reshape(2, -1, 2, 3)
|
|
ypt = pt.reshape(2, -1, 2, 3)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
|
|
|
|
def test_nested_tensor_squeeze_backward(self, device):
|
|
nt = torch.nested.nested_tensor([torch.randn((2, 6, 1)), torch.randn((3, 6, 1))], requires_grad=True, device=device)
|
|
with torch.no_grad():
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt.squeeze(-1)
|
|
ypt = pt.squeeze(-1)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
|
|
|
|
def test_nested_tensor_squeeze_gradcheck(self, device):
|
|
a = torch.randn((2, 6, 1), dtype=torch.float64, requires_grad=True, device=device)
|
|
b = torch.randn((3, 6, 1), dtype=torch.float64, requires_grad=True, device=device)
|
|
|
|
def grad_test_func(a, b):
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
result = nt.squeeze(-1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
|
|
|
|
def test_nested_tensor_unsqueeze_backward(self, device):
|
|
nt = torch.nested.nested_tensor([torch.randn((2, 6)), torch.randn((3, 6))], requires_grad=True, device=device)
|
|
with torch.no_grad():
|
|
pt = torch.nested.to_padded_tensor(nt, 0.0).requires_grad_(True)
|
|
|
|
ynt = nt.unsqueeze(2)
|
|
ypt = pt.unsqueeze(2)
|
|
ynt.backward(ynt.clone())
|
|
ypt.backward(ypt.clone())
|
|
|
|
self.assertEqual(torch.nested.to_padded_tensor(nt.grad, 0.0), pt.grad)
|
|
|
|
def test_nested_tensor_unsqueeze_gradcheck(self, device):
|
|
a = torch.randn((2, 6), dtype=torch.float64, requires_grad=True, device=device)
|
|
b = torch.randn((3, 6), dtype=torch.float64, requires_grad=True, device=device)
|
|
|
|
def grad_test_func(a, b):
|
|
nt = torch.nested.as_nested_tensor([a, b])
|
|
result = nt.unsqueeze(-1)
|
|
return torch.nested.to_padded_tensor(result, 0.0)
|
|
|
|
assert torch.autograd.gradcheck(grad_test_func, inputs=(a, b), eps=1e-3)
|
|
|
|
def test_nested_tensor_linear(self, device):
|
|
|
|
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, weight, bias=None):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
# This implicitly tests to_padded_tensor grads
|
|
d = torch.functional.F.linear(nt, weight, bias)
|
|
return torch.nested.to_padded_tensor(d, 0)
|
|
data = (a, b, c, weight, bias)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# Test linear with no bias added
|
|
data = (a, b, c, weight)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_linear_plus_transpose(self, device):
|
|
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
bias = torch.randn(2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, weight, bias=None):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
# This implicitly tests to_padded_tensor grads
|
|
d = torch.functional.F.linear(nt, weight, bias)
|
|
d = d.transpose(-1, -2).contiguous()
|
|
return torch.nested.to_padded_tensor(d, 0)
|
|
data = (a, b, c, weight, bias)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# Test linear with no bias added
|
|
data = (a, b, c, weight)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_softmax(self, device):
|
|
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, dim):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
# This implicitly tests to_padded_tensor grads
|
|
d = torch.functional.F.softmax(nt, dim=dim)
|
|
return torch.nested.to_padded_tensor(d, 0)
|
|
|
|
# softmax over last dim
|
|
data = (a, b, c, -1)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_nested_tensor_linear_backward(self, device):
|
|
a = torch.randn(1, 2, requires_grad=False, device=device)
|
|
b = torch.randn(2, 2, requires_grad=False, device=device)
|
|
c = torch.randn(3, 2, requires_grad=False, device=device)
|
|
|
|
weight = torch.randn(2, 2, requires_grad=True, device=device)
|
|
bias = torch.randn(2, requires_grad=True, device=device)
|
|
nt = torch.nested.as_nested_tensor([a, b, c], device=device)
|
|
|
|
out = torch.functional.F.linear(nt, weight, bias)
|
|
|
|
out.backward(out.clone())
|
|
|
|
assert weight.grad is not None
|
|
assert bias.grad is not None
|
|
|
|
assert a.grad is None
|
|
assert b.grad is None
|
|
assert c.grad is None
|
|
|
|
def test_values_grad_with_broadcast(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
buffer = nt.values()
|
|
return buffer.sum()
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_to_buffer_series_ops_grad_with_broadcast(self, device):
|
|
a = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(1, 1, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
buffer = nt.values()
|
|
buffer = buffer * 2
|
|
return buffer.exp()
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_unbind_flow_through(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
ntT = nt.transpose(-1, -2)
|
|
unbound = ntT.unbind()
|
|
d = unbound[0]
|
|
d = torch.pow(d, 2)
|
|
return d
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_split_with_sizes_flow_through(self, device):
|
|
a = torch.randn(2, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 5, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
splits = nt.split_with_sizes([2, 3], dim=-1)
|
|
unbound = splits[1].unbind()
|
|
d = unbound[0]
|
|
d = torch.pow(d, 2)
|
|
return d
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_indexing_backward(self, device):
|
|
x0 = torch.randn((2, 5))
|
|
x1 = torch.randn((3, 4))
|
|
nt = torch.nested.nested_tensor([x0, x1], device=device, requires_grad=True)
|
|
self.assertEqual(nt[0], x0)
|
|
self.assertEqual(nt[-1], x1)
|
|
grad_x0 = torch.randn((2, 5), device=device)
|
|
nt[0].backward(grad_x0)
|
|
expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device)])
|
|
self.assertEqual(nt.grad, expected_grad)
|
|
|
|
def test_masked_fill_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
mask = nt.detach().clone().to(bool)
|
|
out = nt.masked_fill(mask, 0)
|
|
out = torch.nested.to_padded_tensor(out, 0)
|
|
return out
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_gelu_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
nt_gelu = torch.nn.functional.gelu(nt)
|
|
return torch.nested.to_padded_tensor(nt_gelu, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_relu_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
nt_relu = torch.nn.functional.relu(nt)
|
|
return torch.nested.to_padded_tensor(nt_relu, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_selu_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
            nt = torch.nested.as_nested_tensor([a, b, c])
            # NB: despite the test name, this exercises silu.
            nt_silu = torch.nn.functional.silu(nt)
            return torch.nested.to_padded_tensor(nt_silu, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
def test_abs_backward(self, device):
|
|
a = torch.randn(1, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
nt_abs = torch.abs(nt)
|
|
return torch.nested.to_padded_tensor(nt_abs, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# Previously would error when input NT doesn't require grad
|
|
# NotImplementedError: Cannot access storage of UndefinedTensorImpl
|
|
def test_layer_norm_backward_edge_case(self, device):
|
|
size = 4
|
|
a = torch.randn(1, 2, size, requires_grad=False, dtype=torch.float64, device=device)
|
|
nt = torch.nested.nested_tensor([a])
|
|
nt_layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64)
|
|
out = nt_layer_norm(nt)
|
|
out.backward(out.clone())
|
|
|
|
def test_accumulate_grad_different_strides(self, device):
|
|
a = torch.rand(1, 4, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.rand(1, 8, 2, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b):
|
|
nt_1 = torch.nested.as_nested_tensor([a, b])
|
|
nt_2 = nt_1.clone()
|
|
out = torch.nn.functional.scaled_dot_product_attention(nt_1, nt_2, nt_2)
|
|
return torch.nested.to_padded_tensor(out, 0)
|
|
|
|
data = (a, b)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# TODO: OOM https://github.com/pytorch/pytorch/issues/95562
|
|
@skipIfSlowGradcheckEnv
|
|
@parametrize("size", [1024, 1023, 513, 512, 256, 128, 32, 4, 2])
|
|
def test_layer_norm_backward(self, device, size):
|
|
a = torch.randn(1, 2, size, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(2, 2, size, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(3, 2, size, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
layer_norm = torch.nn.LayerNorm(nt.size(-1), device=device, dtype=torch.float64)
|
|
nt_layer_norm = layer_norm(nt)
|
|
return torch.nested.to_padded_tensor(nt_layer_norm, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# TODO: OOM https://github.com/pytorch/pytorch/issues/95562
|
|
@skipIfSlowGradcheckEnv
|
|
# Could either mark slow or reduce size
|
|
@parametrize("size", [128, 32, 4, 2])
|
|
def test_layer_norm_backward_5d(self, device, size):
|
|
a = torch.randn(4, size, size, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(7, size, size, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(10, size, size, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt = torch.nested.as_nested_tensor([a, b, c])
|
|
layer_norm = torch.nn.LayerNorm((size, size, nt.size(-1)), device=device, dtype=torch.float64)
|
|
nt_layer_norm = layer_norm(nt)
|
|
return torch.nested.to_padded_tensor(nt_layer_norm, 0)
|
|
|
|
data = (a, b, c)
|
|
assert gradcheck(grad_test_func, inputs=data, check_batched_grad=False)
|
|
|
|
# Found in torch/testing/_comparison.py
|
|
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
|
|
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
|
|
|
|
def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
|
|
deviation = true_value - computed_value
|
|
deviation = torch.abs(deviation / true_value)
|
|
# Fill in the nans with the default rtol
|
|
torch.nan_to_num_(deviation, nan=default_rtol[computed_value.dtype])
|
|
return deviation.max().item()
|
|
|
|
|
|
def get_atol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
|
|
deviation = true_value - computed_value
|
|
atol = torch.abs(deviation).max().item()
|
|
return atol
|
|
|
|
|
|
def get_tolerances(
|
|
true_value: torch.Tensor,
|
|
computed_value: torch.Tensor,
|
|
fudge_factor: Optional[float] = None,
|
|
) -> Tuple[float, float]:
|
|
"""Returns the absolute and relative tolerances for comparing two tensors."""
|
|
fudge_factor = fudge_factor if fudge_factor is not None else 1.0
|
|
atol = get_atol(true_value, computed_value)
|
|
rtol = get_rtol(true_value, computed_value)
|
|
|
|
atol = fudge_factor * max(atol, default_atol[computed_value.dtype])
|
|
rtol = fudge_factor * max(rtol, default_rtol[computed_value.dtype])
|
|
    # torch.isclose() has weird behavior in some edge cases; see:
    # https://github.com/pytorch/pytorch/issues/102400
    if rtol > 1e30:
        rtol = default_rtol[computed_value.dtype]
|
|
return atol, rtol
|
|
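# Illustrative usage note (not part of the original test logic): the helpers above
# derive comparison tolerances from a high-precision reference and a low-precision
# computation of the same quantity. The SDPA tests below use them roughly like:
#
#   out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_f32, k_f32, v_f32)[0]
#   out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q, k, v)[0]
#   atol, rtol = get_tolerances(out_ref, out_lp_ref)
#   self.assertEqual(actual, expected, atol=atol, rtol=rtol)
#
# The names q_f32, q, actual, expected here are placeholders for the tensors built
# in those tests.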
|
|
# We can probably parametrize existing tests instead of having a separate
# test class as we begin to support more ops. Also consider rewriting with OpInfos.
|
|
@markDynamoStrictTest
|
|
class TestNestedTensorSubclass(TestCase):
|
|
# TODO: consolidate with the below
|
|
def _get_list_for_jagged_tensor(self, nested_size, device, requires_grad=True):
|
|
Ds = nested_size[1:]
|
|
out = []
|
|
for s in nested_size[0]:
|
|
out.append(
|
|
torch.randn(s, *Ds, requires_grad=requires_grad, device=device, dtype=torch.float64)
|
|
)
|
|
return out
|
|
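    # Note on the nested_size convention used by the helper above: nested_size[0] is a
    # tuple of per-component lengths along the ragged dim and nested_size[1:] are the
    # constant trailing dims. For example, nested_size=((2, 3, 4), 3, 4) (as used in the
    # broadcasting and sum tests below) yields three tensors of shapes (2, 3, 4),
    # (3, 3, 4), and (4, 3, 4).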
|
|
def _get_example_tensor_lists(self, include_list_of_lists=True, include_requires_grad=True):
|
|
|
|
        def _make_tensor(*shape, include_requires_grad=include_requires_grad, requires_grad=True):
            return torch.randn(
                *shape,
                requires_grad=(requires_grad if include_requires_grad else False)
            )

        # Purposefully introduce mixed requires_grad settings for the components
        # when include_requires_grad=True.
        example_lists = [
            # (B, *, D) with B=4
            [
                _make_tensor(2, 5),
                _make_tensor(3, 5, requires_grad=False),
                _make_tensor(4, 5, requires_grad=False),
                _make_tensor(6, 5)
            ],
            # (B, *, D_0, D_1) with B=5
            [
                _make_tensor(2, 5, 6),
                _make_tensor(3, 5, 6),
                _make_tensor(4, 5, 6, requires_grad=False),
                _make_tensor(5, 5, 6),
                _make_tensor(6, 5, 6),
            ],
        ]

        if include_list_of_lists:
            example_lists.append(
                # (B, *, D) with B=3 in list form
                [
                    _make_tensor(2, 5, requires_grad=False).tolist(),
                    _make_tensor(3, 5).tolist(),
                    _make_tensor(4, 5).tolist(),
                ])

        return example_lists
|
|
|
|
def test_tensor_attributes(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
nt, _offsets = jagged_from_list([a, b, c], None)
|
|
|
|
        for op in (
            torch.ops.aten.is_non_overlapping_and_dense.default,
            torch.ops.aten.sym_size.default,
            torch.ops.aten.dim.default,
            torch.ops.aten.sym_numel.default,
            torch.ops.aten.sym_stride.default,
            torch.ops.aten.sym_storage_offset.default,
        ):
            op(nt)

        with self.assertRaisesRegex(RuntimeError,
                                    "directly calling torch.ops.aten.size"):
            torch.ops.aten.size.default(nt)
|
|
|
|
singleton_int = torch.nested._internal.nested_tensor.get_tensor_symint(_offsets, coeff=1)
|
|
self.assertEqual(nt.size(), (3, singleton_int, 3))
|
|
self.assertEqual(nt.shape, (3, singleton_int, 3))
|
|
self.assertEqual(nt.dim(), 3)
|
|
self.assertEqual(nt.numel(), 27)
|
|
|
|
def test_linear(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
weight = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, weight):
|
|
nt, _ = jagged_from_list([a, b, c], None)
|
|
out = torch.nn.functional.linear(nt, weight)
|
|
return buffer_from_jagged(out)
|
|
|
|
gradcheck(grad_test_func, inputs=(a, b, c, weight), check_batched_grad=False)
|
|
|
|
def test_unary_pointwise(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c):
|
|
nt, _ = jagged_from_list([a, b, c], None)
|
|
out = torch.nn.functional.silu(nt.sin().cos())
|
|
return buffer_from_jagged(out)
|
|
|
|
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
|
|
|
|
def test_binary_pointwise(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
        # Incorrect usage: the shape check will fail if the offsets tensors are not
        # the exact same tensor object
        nt1, _ = jagged_from_list([a, b, c], None)
        nt2, _ = jagged_from_list([a, b, c], None)
|
|
|
|
self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"cannot call binary pointwise function .* with inputs of shapes",
|
|
lambda: nt1 * nt2)
|
|
|
|
# Correct usage: chain the calls using the same offsets tensor object
|
|
def grad_test_func(a, b, c):
|
|
nt1, offsets = jagged_from_list([a, b, c], None)
|
|
nt2, offsets = jagged_from_list([a, b, c], offsets)
|
|
out = nt1 * nt2
|
|
return buffer_from_jagged(out)
|
|
|
|
gradcheck(grad_test_func, inputs=(a, b, c), check_batched_grad=False)
|
|
|
|
def test_split(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
nt, _ = jagged_from_list([a, b, c], None)
|
|
out = torch.split(nt, 2, -1)
|
|
self.assertEqual(len(out), 2)
|
|
self.assertEqual(
|
|
out[0], jagged_from_list([a[:, 0:2], b[:, 0:2], c[:, 0:2]], None)[0]
|
|
)
|
|
self.assertEqual(
|
|
out[1], jagged_from_list([a[:, 2:], b[:, 2:], c[:, 2:]], None)[0]
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"split\(\): not supported for NestedTensor on dim=0 or dim=1",
|
|
):
|
|
torch.split(nt, 2, 1)
|
|
|
|
def test_split_with_sizes(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
|
|
nt, _ = jagged_from_list([a, b, c], None)
|
|
out = torch.split(nt, [1, 2], -1)
|
|
self.assertEqual(len(out), 2)
|
|
self.assertEqual(
|
|
out[0], jagged_from_list([a[:, 0:1], b[:, 0:1], c[:, 0:1]], None)[0]
|
|
)
|
|
self.assertEqual(
|
|
out[1], jagged_from_list([a[:, 1:], b[:, 1:], c[:, 1:]], None)[0]
|
|
)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"split_with_sizes\(\): not supported for NestedTensor on dim=0 or dim=1",
|
|
):
|
|
torch.split(nt, [1, 2], 1)
|
|
|
|
def test_views_inherit_ragged_dim(self, device):
|
|
# view
|
|
nt = random_nt_from_dims(
|
|
[4, None, 8, 10], device=device, dtype=torch.float32, layout=torch.jagged)
|
|
# inherit ragged dim via -1
|
|
view = nt.view(4, -1, 80)
|
|
self.assertEqual(nt.shape[1], view.shape[1])
|
|
# inherit batch and ragged dims via -1
|
|
view2 = nt.view(-1, -1, 80)
|
|
self.assertEqual(nt.shape[:2], view2.shape[:2])
|
|
|
|
# expand
|
|
nt = random_nt_from_dims(
|
|
[3, None, 1], device=device, dtype=torch.float32, layout=torch.jagged)
|
|
# inherit batch and ragged dims via -1
|
|
view = nt.expand(-1, -1, 5)
|
|
self.assertEqual(nt.shape[:2], view.shape[:2])
|
|
|
|
@xfailIfTorchDynamo
|
|
@parametrize("requires_grad", [False, True])
|
|
def test_reshape_decomp(self, device, requires_grad):
|
|
# contiguous NT should result in view
|
|
nt = random_nt_from_dims(
|
|
[3, None, 10],
|
|
device=device,
|
|
dtype=torch.float32,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad
|
|
)
|
|
view = nt.reshape(-1, -1, 5, 2)
|
|
self.assertEqual(view.shape[:2], nt.shape[:2])
|
|
self.assertTrue(view._is_view() and view._base is nt)
|
|
# make sure gradients flow back
|
|
if requires_grad:
|
|
view.backward(torch.ones_like(view))
|
|
self.assertEqual(nt.grad, torch.ones_like(nt))
|
|
|
|
# non-contiguous NT should result in contiguous copy
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5, 2],
|
|
device=device,
|
|
dtype=torch.float32,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad
|
|
)
|
|
nt_noncontig = nt.transpose(-1, -2)
|
|
self.assertFalse(nt_noncontig.is_contiguous())
|
|
copy = nt_noncontig.reshape(-1, -1, 10)
|
|
self.assertTrue(copy.is_contiguous())
|
|
self.assertEqual(copy.shape[:2], nt.shape[:2])
|
|
# make sure gradients flow back
|
|
if requires_grad:
|
|
copy.backward(torch.ones_like(copy))
|
|
self.assertEqual(nt.grad, torch.ones_like(nt))
|
|
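    # Summary of the expectations exercised above: reshape() on a contiguous jagged NT
    # decomposes to a view that aliases its input, while on a non-contiguous NT it falls
    # back to producing a contiguous copy; gradients flow back through both paths.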
|
|
def test_flatten_decomp(self, device):
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5, 2], device=device, dtype=torch.float32, layout=torch.jagged)
|
|
flattened = nt.flatten(-2, -1)
|
|
self.assertEqual(flattened.shape, nt.view(3, -1, 10).shape)
|
|
|
|
nt = random_nt_from_dims(
|
|
[3, None, 5, 2, 6], device=device, dtype=torch.float32, layout=torch.jagged)
|
|
flattened = nt.flatten(-3, -2)
|
|
self.assertEqual(flattened.shape, nt.view(3, -1, 10, 6).shape)
|
|
|
|
def test_chunk(self, device):
|
|
# normal case
|
|
D = 30
|
|
nt = random_nt_from_dims(
|
|
[4, None, D], device=device, dtype=torch.float32, layout=torch.jagged)
|
|
NUM_CHUNKS = 3
|
|
chunks = nt.chunk(NUM_CHUNKS, dim=-1)
|
|
self.assertEqual(len(chunks), NUM_CHUNKS)
|
|
for i in range(NUM_CHUNKS):
|
|
self.assertEqual(chunks[i].shape[-1], D // NUM_CHUNKS)
|
|
|
|
# chunk on batch dim not supported
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "chunk.* not supported for NestedTensor on dim=0 or dim=1"):
|
|
nt.chunk(2, dim=0)
|
|
|
|
# chunk on ragged dim not supported
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "chunk.* not supported for NestedTensor on dim=0 or dim=1"):
|
|
nt.chunk(2, dim=1)
|
|
|
|
def test_squeeze(self, device):
|
|
B = 4
|
|
D = 6
|
|
# squeeze middle dim
|
|
nt = random_nt_from_dims(
|
|
[B, None, 1, D], device=device, dtype=torch.float32, layout=torch.jagged)
|
|
j0 = nt.shape[1]
|
|
|
|
for dim_arg in [-2, 2]:
|
|
out = nt.squeeze(dim_arg)
|
|
self.assertEqual(out.shape, (B, j0, D))
|
|
self.assertEqual(out.unsqueeze(-2), nt)
|
|
|
|
# squeeze last dim
|
|
nt = random_nt_from_dims(
|
|
[B, None, 1], device=device, dtype=torch.float32, layout=torch.jagged)
|
|
j1 = nt.shape[1]
|
|
|
|
for dim_arg in [-1, 2]:
|
|
out = nt.squeeze(dim_arg)
|
|
self.assertEqual(out.shape, (B, j1))
|
|
self.assertEqual(out.unsqueeze(-1), nt)
|
|
|
|
# squeeze on batch dim not supported
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "squeeze.* not supported for NestedTensor on dim=0 or dim=1"):
|
|
nt.squeeze(0)
|
|
|
|
# squeeze on ragged dim not supported
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "squeeze.* not supported for NestedTensor on dim=0 or dim=1"):
|
|
nt.squeeze(1)
|
|
|
|
def test_binary_pointwise_broadcasting(self, device):
|
|
# (B, j0, 3, 4)
|
|
ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device, requires_grad=True)
|
|
# (B, j0, ?, ?) + (?) -> (B, j0, ?, ?)
|
|
# (B, j0, ?, ?) + (?, ?) -> (B, j0, ?, ?)
|
|
# (B, j0, ?, ?) + (1, ?, ?) -> (B, j0, ?, ?)
|
|
# Unsupported: (B, j0, ?, ?) + (1, 1, 1, ?, ?) -> (1, B, j0, ?, ?)
|
|
t_sizes = (
|
|
(4,),
|
|
(1, 4),
|
|
(3, 1),
|
|
(1, 3, 1),
|
|
(1, 1, 1, 4),
|
|
# (1, 1, 1, 1, 4), (unsupported today)
|
|
)
|
|
|
|
def grad_test_func(t, *ts):
|
|
nt, _ = jagged_from_list(ts, None)
|
|
out = nt + t
|
|
return buffer_from_jagged(out)
|
|
|
|
for t_size in t_sizes:
|
|
t = torch.rand(t_size, requires_grad=True, device=device, dtype=torch.float64)
|
|
gradcheck(grad_test_func, inputs=(t, *ts), check_batched_grad=False)
|
|
|
|
@parametrize("keepdim", [False, True])
|
|
def test_sum_int_DimList(self, device, keepdim):
|
|
# (B, j0, 3, 4)
|
|
ts = self._get_list_for_jagged_tensor(((2, 3, 4), 3, 4), device=device, requires_grad=True)
|
|
|
|
# Check shape correctness
|
|
        reduce_dims = (
            # dims, expected shape, expected keepdim shape
            # j0 is represented as None
            ((0, 1), (3, 4), (1, 1, 3, 4)),
            ((1, 2), None, None),
            ((2, 3), (3, None), (3, None, 1, 1)),
            ((0, 1, 3), (3,), (1, 1, 3, 1)),
            ((0, 1, 2), (4,), (1, 1, 1, 4)),
            ((0, 1, 2, 3), tuple(), (1, 1, 1, 1)),
        )
|
|
for rd, ref_shape_no_keepdim, ref_shape_keepdim in reduce_dims:
|
|
if (0 in rd) ^ (1 in rd):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"applying over the ragged dimension, but not the batch dimension"):
|
|
nt, _ = jagged_from_list(ts, None)
|
|
out = torch.sum(nt, dim=rd, keepdim=keepdim)
|
|
continue
|
|
|
|
nt, _ = jagged_from_list(ts, None)
|
|
out = torch.sum(nt, dim=rd, keepdim=keepdim)
|
|
ref_shape = ref_shape_keepdim if keepdim else ref_shape_no_keepdim
|
|
self.assertEqual(len(out.shape), len(ref_shape))
|
|
for o, r in zip(out.shape, ref_shape):
|
|
if r is not None:
|
|
self.assertEqual(o, r)
|
|
else:
|
|
self.assertTrue(isinstance(o, torch.SymInt))
|
|
|
|
# Check values correctness
|
|
# raggedness not reduced
|
|
nt, _ = jagged_from_list(ts, None)
|
|
out = torch.sum(nt, dim=(2, 3), keepdim=keepdim)
|
|
out_ref = torch.sum(nt.values(), dim=(1, 2))
|
|
self.assertIsInstance(out, NestedTensor)
|
|
# flatten to avoid having to replicate unsqueeze logic depending on keepdim
|
|
self.assertTrue(torch.allclose(out.values().view(-1), out_ref.view(-1)))
|
|
|
|
# raggedness reduced away
|
|
nt, _ = jagged_from_list(ts, None)
|
|
out = torch.sum(nt, dim=(0, 1), keepdim=keepdim)
|
|
out_ref = torch.sum(nt.values(), dim=(0,))
|
|
self.assertNotIsInstance(out, NestedTensor)
|
|
self.assertTrue(torch.allclose(out, out_ref))
|
|
|
|
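    # Summary of the sum() semantics exercised above: the batch and ragged dims can only
    # be reduced together (reducing one without the other raises); when both are reduced
    # the result is a regular (non-nested) tensor, otherwise the result stays a
    # NestedTensor.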
|
|
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("weights_only", [False, True])
|
|
def test_serialization(self, device, dtype, requires_grad, weights_only):
|
|
|
|
def compare_metadata(nt1, nt2):
|
|
self.assertEqual(nt1._nested_tensor_size(), nt2._nested_tensor_size())
|
|
self.assertEqual(nt1._nested_tensor_strides(), nt2._nested_tensor_strides())
|
|
self.assertEqual(nt1._nested_tensor_storage_offsets(),
|
|
nt2._nested_tensor_storage_offsets())
|
|
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
|
|
for a in [nt_contiguous, nt_noncontiguous]:
|
|
buffer = io.BytesIO()
|
|
            torch.save(a, buffer)
|
|
buffer.seek(0)
|
|
b = torch.load(buffer, weights_only=weights_only)
|
|
# should be both conceptually equal and metadata equivalent
|
|
self.assertEqual(a, b)
|
|
compare_metadata(a, b)
|
|
# should be conceptually equal but not necessarily metadata equivalent
|
|
self.assertEqual(b, nt_contiguous)
|
|
self.assertEqual(b, nt_noncontiguous)
|
|
|
|
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
|
|
@onlyCUDA
|
|
def test_pin_memory(self, device):
|
|
nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7))
|
|
for nt in [nt_contiguous, nt_noncontiguous]:
|
|
self.assertFalse(nt.is_pinned())
|
|
pinned = nt.pin_memory(device)
|
|
self.assertTrue(pinned.is_pinned())
|
|
self.assertEqual(nt, pinned)
|
|
self.assertNotEqual(nt.data_ptr(), pinned.data_ptr())
|
|
            # test that pin_memory on an already-pinned tensor has no effect
|
|
self.assertIs(pinned, pinned.pin_memory())
|
|
self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
|
|
|
|
def _validate_nt(self, nt, tensor_list, device, dtype, requires_grad):
|
|
# Validate a bunch of properties after NT construction.
|
|
device = torch.device(device)
|
|
first_t = torch.as_tensor(tensor_list[0])
|
|
expected_dim = first_t.dim() + 1
|
|
batch_size = len(tensor_list)
|
|
self.assertEqual(nt.dim(), expected_dim)
|
|
self.assertEqual(nt.device, device)
|
|
self.assertEqual(nt.dtype, dtype)
|
|
self.assertEqual(nt.layout, torch.jagged)
|
|
self.assertEqual(nt.requires_grad, requires_grad)
|
|
self.assertEqual(nt.values().device, device)
|
|
self.assertEqual(nt.offsets().device, device)
|
|
self.assertEqual(nt.shape[0], batch_size)
|
|
self.assertTrue(isinstance(nt.shape[1], torch.SymInt))
|
|
self.assertEqual(nt.shape[2:], first_t.shape[1:])
|
|
|
|
@xfailIfTorchDynamo
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@parametrize("requires_grad", [False, True])
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_jagged_layout_construction_nested_tensor(
|
|
self, device, dtype, requires_grad, components_require_grad):
|
|
for tensor_list in self._get_example_tensor_lists(
|
|
include_list_of_lists=True, include_requires_grad=components_require_grad):
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged,
|
|
requires_grad=requires_grad)
|
|
self._validate_nt(nt, tensor_list, device, dtype, requires_grad)
|
|
|
|
# Make sure grads -don't- flow back into original tensors for nested_tensor()
|
|
if requires_grad:
|
|
(nt * 2).backward(torch.ones_like(nt))
|
|
for t in tensor_list:
|
|
t = t if isinstance(t, torch.Tensor) else torch.as_tensor(t)
|
|
self.assertTrue(t.grad is None)
|
|
|
|
@xfailIfTorchDynamo
|
|
@dtypes(torch.float, torch.double, torch.half)
|
|
@parametrize("components_require_grad", [False, True])
|
|
def test_jagged_layout_construction_as_nested_tensor(
|
|
self, device, dtype, components_require_grad):
|
|
# NB: as_nested_tensor(tensor_list) doesn't support lists of lists for tensor_list
|
|
for tensor_list in self._get_example_tensor_lists(
|
|
include_list_of_lists=False, include_requires_grad=components_require_grad):
|
|
nt = torch.nested.as_nested_tensor(
|
|
tensor_list,
|
|
device=device,
|
|
dtype=dtype,
|
|
layout=torch.jagged)
|
|
|
|
# nt.requires_grad=True should be set if at least one component requires grad
|
|
self._validate_nt(nt, tensor_list, device, dtype, components_require_grad)
|
|
|
|
# Make sure grads flow back into original tensors for as_nested_tensor()
|
|
if components_require_grad:
|
|
(nt * 2).backward(torch.ones_like(nt))
|
|
for t in tensor_list:
|
|
if t.requires_grad:
|
|
self.assertEqual(t.grad, torch.ones_like(t) * 2)
|
|
else:
|
|
self.assertTrue(t.grad is None)
|
|
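    # Summary of the two construction tests above: torch.nested.nested_tensor() copies
    # its inputs, so gradients do not flow back to the original component tensors, while
    # torch.nested.as_nested_tensor() preserves the autograd connection to any components
    # that require grad.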
|
|
@xfailIfTorchDynamo
|
|
@unittest.skipIf(PYTORCH_CUDA_MEMCHECK, "is_pinned uses failure to detect pointer property")
|
|
@onlyCUDA
|
|
def test_jagged_layout_construction_with_pinned_memory(self, device):
|
|
for tensor_list in self._get_example_tensor_lists():
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
layout=torch.jagged,
|
|
device="cpu",
|
|
pin_memory=True)
|
|
|
|
self._validate_nt(nt, tensor_list, "cpu", torch.float32, requires_grad=False)
|
|
self.assertTrue(nt.is_pinned())
|
|
|
|
@dtypes(torch.double, torch.half)
|
|
@onlyCUDA
|
|
def test_device_dtype_transfer_maintains_offsets(self, device, dtype):
|
|
for tensor_list in self._get_example_tensor_lists():
|
|
orig_device = torch.device("cpu")
|
|
orig_dtype = torch.float32
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
layout=torch.jagged,
|
|
device=orig_device,
|
|
dtype=orig_dtype)
|
|
|
|
self.assertEqual(torch.int64, nt.offsets().dtype)
|
|
nt = nt.to(device=device).to(dtype=dtype)
|
|
|
|
# offsets should still be int64 on the original device
|
|
self.assertEqual(orig_device, nt.offsets().device)
|
|
self.assertEqual(torch.int64, nt.offsets().dtype)
|
|
|
|
def test_unbind(self, device):
|
|
for tensor_list in self._get_example_tensor_lists():
|
|
nt = torch.nested.nested_tensor(
|
|
tensor_list,
|
|
layout=torch.jagged,
|
|
device=device)
|
|
out = nt.unbind()
|
|
self.assertEqual(len(out), len(tensor_list))
|
|
for i, t in enumerate(out):
|
|
self.assertEqual(t, tensor_list[i])
|
|
|
|
@xfailIfTorchDynamo
|
|
def test_layer_norm_2(self, device):
|
|
test_tensor_list = self._get_list_for_jagged_tensor(
|
|
((2, 3, 4), 3), device=device, requires_grad=True
|
|
)
|
|
bias = torch.randn(3, requires_grad=False, dtype=torch.float64, device=device)
|
|
|
|
def grad_test_func(a, b, c, bias):
|
|
nt, _ = jagged_from_list([a, b, c], None)
|
|
out = torch.nn.functional.layer_norm(nt, (nt.shape[-1],), bias=bias)
|
|
return buffer_from_jagged(out)
|
|
|
|
gradcheck(
|
|
grad_test_func, inputs=(*test_tensor_list, bias), check_batched_grad=False
|
|
)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"layer_norm\(\): normalizing over ragged dim not supported for nested tensors",
|
|
):
|
|
nt, _ = jagged_from_list(test_tensor_list, None)
|
|
_ = torch.nn.functional.layer_norm(nt, (nt.shape[-2], nt.shape[-1]))
|
|
|
|
def test_narrow(self, device):
|
|
starts = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
|
|
lengths = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
|
|
        nt = torch.nested.narrow(
            torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone().detach(),
            1,
            starts,
            lengths,
            layout=torch.jagged
        )
|
|
|
|
        # TODO: Use this approach when unbind is functional
        # unbound_nt = nt.unbind()
        # for i in range(starts.shape[0]):
        #     self.assertEqual(
        #         torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64),
        #         unbound_nt[i],
        #     )
|
|
for i in range(starts.shape[0]):
|
|
self.assertEqual(
|
|
torch.arange(starts[i], starts[i] + lengths[i], device=device, dtype=torch.int64),
|
|
nt.values()[nt.offsets()[i]:(nt.offsets()[i] + nt.lengths()[i])]
|
|
)
|
|
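    # Illustrative note for the check above: for a jagged-layout NT built via narrow(),
    # component i occupies values()[offsets()[i] : offsets()[i] + lengths()[i]] along the
    # ragged dim. With starts=[0, 1, 2, 3, 4] and lengths=[3, 2, 2, 1, 5] over rows of
    # arange(0, 10), component 0 is arange(0, 3), component 1 is arange(1, 3), and so on.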
|
|
def test_is_contiguous(self, device):
|
|
a = torch.randn(2, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, requires_grad=True, dtype=torch.float64, device=device)
|
|
nt_contiguous, _ = jagged_from_list([a, b, c], None)
|
|
|
|
starts_nc = torch.tensor([0, 1, 2, 3, 4], device=device, dtype=torch.int64)
|
|
lengths_nc = torch.tensor([3, 2, 2, 1, 5], device=device, dtype=torch.int64)
|
|
narrow_base = torch.arange(0, 10, device=device, dtype=torch.int64).unsqueeze(0).expand(5, -1).clone()
|
|
        nt_noncontiguous = torch.nested.narrow(
            narrow_base,
            1,
            starts_nc,
            lengths_nc,
            layout=torch.jagged
        )

        starts_c = torch.tensor([1, 0, 0, 0, 0], device=device, dtype=torch.int64)
        lengths_c = torch.tensor([9, 10, 10, 10, 8], device=device, dtype=torch.int64)
        nt_contiguous_narrow = torch.nested.narrow(
            narrow_base,
            1,
            starts_c,
            lengths_c,
            layout=torch.jagged
        )
|
|
|
|
# Test contiguous case
|
|
assert nt_contiguous.is_contiguous()
|
|
|
|
# Test narrow case
|
|
assert not nt_noncontiguous.is_contiguous()
|
|
assert nt_contiguous_narrow.is_contiguous()
|
|
|
|
# Test querying by memory_format
|
|
self.assertTrue(nt_contiguous.is_contiguous(memory_format=torch.contiguous_format))
|
|
self.assertTrue(not nt_noncontiguous.is_contiguous(memory_format=torch.contiguous_format))
|
|
self.assertTrue(nt_contiguous_narrow.is_contiguous(memory_format=torch.contiguous_format))
|
|
|
|
def test_noncontiguous_pointwise(self, device):
|
|
a = torch.randn(2, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
b = torch.randn(3, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
c = torch.randn(4, 3, 4, requires_grad=True, dtype=torch.float64, device=device)
|
|
nt, _ = jagged_from_list([a, b, c], None)
|
|
# transpose ragged dim
|
|
transposed = nt.transpose(1, 2)
|
|
# pointwise ops are not supported on ragged dim transposed jagged layout NTs
|
|
with self.assertRaisesRegex(ValueError, "expected .* to be a contiguous jagged layout"):
|
|
clone = transposed.clone()
|
|
|
|
# Note 1: Math fallback doesn't work with bfloat16 on CUDA
|
|
# Note 2: ROCm doesn't support flash attention or mem_efficient attention for NT
|
|
@xfailIfTorchDynamo
|
|
    @unittest.skipIf(
        TEST_WITH_ROCM,
        "ROCm doesn't support flash attention or mem_efficient attention for NT",
    )
    @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if
                 SM80OrLater else [torch.float16, torch.float32])
|
|
def test_sdpa(self, device, dtype):
|
|
batch_size = 1
|
|
emb_dims = 128
|
|
n_heads = 8
|
|
head_dims = emb_dims // n_heads
|
|
|
|
sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
|
|
sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
|
|
|
|
query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)
|
|
key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)
|
|
value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)
|
|
|
|
# Simplest case: 1 sentence, no batching
|
|
x_d1 = sen1.unsqueeze(0)
|
|
x_nt = torch.nested.as_nested_tensor([sen1], layout=torch.jagged)
|
|
|
|
# See note below for why we detach here.
|
|
q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True)
|
|
q_d1_t = q_d1.transpose(1, 2)
|
|
k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True)
|
|
k_d1_t = k_d1.transpose(1, 2)
|
|
v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True)
|
|
v_d1_t = v_d1.transpose(1, 2)
|
|
|
|
q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True)
|
|
q_nt_t = q_nt.transpose(1, 2)
|
|
k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True)
|
|
k_nt_t = k_nt.transpose(1, 2)
|
|
v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True)
|
|
v_nt_t = v_nt.transpose(1, 2)
|
|
|
|
# High Precision Math Reference
|
|
q_d1_f32 = q_d1.to(torch.float32)
|
|
k_d1_f32 = k_d1.to(torch.float32)
|
|
v_d1_f32 = v_d1.to(torch.float32)
|
|
q_d1_f32_t = q_d1_f32.transpose(1, 2)
|
|
k_d1_f32_t = k_d1_f32.transpose(1, 2)
|
|
v_d1_f32_t = v_d1_f32.transpose(1, 2)
|
|
out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32_t, k_d1_f32_t, v_d1_f32_t)[0]
|
|
grads_ref = torch.autograd.grad(out_ref.sum(), (q_d1_f32, k_d1_f32, v_d1_f32))
|
|
|
|
# Low Precision Math Reference
|
|
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_t, k_d1_t, v_d1_t)[0]
|
|
grads_lp_ref = torch.autograd.grad(out_lp_ref.sum(), (q_d1, k_d1, v_d1))
|
|
|
|
# Compute tolerances
|
|
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
|
|
grad_q_ref_atol, grad_q_ref_rtol = get_tolerances(grads_ref[0], grads_lp_ref[0])
|
|
grad_k_ref_atol, grad_k_ref_rtol = get_tolerances(grads_ref[1], grads_lp_ref[1])
|
|
grad_v_ref_atol, grad_v_ref_rtol = get_tolerances(grads_ref[2], grads_lp_ref[2])
|
|
grad_atols = [grad_q_ref_atol, grad_k_ref_atol, grad_v_ref_atol]
|
|
grad_rtols = [grad_q_ref_rtol, grad_k_ref_rtol, grad_v_ref_rtol]
|
|
|
|
attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1_t, k_d1_t, v_d1_t).transpose(1, 2)
|
|
attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2)
|
|
|
|
self.assertEqual(attn_d1, attn_nt.unbind()[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol)
|
|
|
|
# Simple case: 2 sentences, no extra params
|
|
x_d2 = sen2.unsqueeze(0)
|
|
x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
|
|
|
|
        # NB: we make sure the leaf tensor we compute gradients for is the viewed tensor before
        # it is transposed. This is because today we cannot backward through view or unbind of a
        # transposed tensor.
|
|
q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True)
|
|
q_d2_t = q_d2.transpose(1, 2)
|
|
k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True)
|
|
k_d2_t = k_d2.transpose(1, 2)
|
|
v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).detach().requires_grad_(True)
|
|
v_d2_t = v_d2.transpose(1, 2)
|
|
|
|
q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True)
|
|
q_nt_t = q_nt.transpose(1, 2)
|
|
k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True)
|
|
k_nt_t = k_nt.transpose(1, 2)
|
|
v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).detach().requires_grad_(True)
|
|
v_nt_t = v_nt.transpose(1, 2)
|
|
|
|
attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2_t, k_d2_t, v_d2_t).transpose(1, 2)
|
|
d1_grads = torch.autograd.grad(attn_d1.sum(), (q_d1, k_d1, v_d1))
|
|
d2_grads = torch.autograd.grad(attn_d2.sum(), (q_d2, k_d2, v_d2))
|
|
|
|
def check_forward_backward():
|
|
attn_nt = torch.nn.functional.scaled_dot_product_attention(q_nt_t, k_nt_t, v_nt_t).transpose(1, 2)
|
|
|
|
attn_nts = attn_nt.unbind()
|
|
self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol)
|
|
self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol)
|
|
|
|
nt_grads = torch.autograd.grad(buffer_from_jagged(attn_nt).sum(), (q_nt, k_nt, v_nt))
|
|
for nt_grad, d1_grad, d2_grad, grad_atol, grad_rtol in zip(nt_grads, d1_grads, d2_grads, grad_atols, grad_rtols):
|
|
unbound_nt_grads = nt_grad.unbind()
|
|
self.assertEqual(d1_grad, unbound_nt_grads[0].unsqueeze(0), atol=grad_atol, rtol=grad_rtol)
|
|
self.assertEqual(d2_grad, unbound_nt_grads[1].unsqueeze(0), atol=grad_atol, rtol=grad_rtol)
|
|
|
|
# Default
|
|
check_forward_backward()
|
|
|
|
        # Test that the dispatcher works by calling only mem-efficient and math (as they are safe for all devices)
|
|
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=True):
|
|
check_forward_backward()
|
|
|
|
# Test math fallback
|
|
with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
|
|
# Math fallback doesn't work with bfloat16 on CUDA because
|
|
# "group_gemm_dispatch" not implemented for 'BFloat16'
|
|
if not (str(device).startswith("cuda") and dtype == torch.bfloat16):
|
|
check_forward_backward()
|
|
|
|
|
|
# This requires NT -> NT views to work in inductor, which is a TODO
|
|
@unittest.expectedFailure # noqa: E301
|
|
@onlyCUDA
|
|
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32] if
|
|
SM80OrLater else [torch.float16, torch.float32])
|
|
def test_sdpa_compile(self, device, dtype):
|
|
batch_size = 1
|
|
emb_dims = 1024
|
|
n_heads = 8
|
|
head_dims = emb_dims // n_heads
|
|
|
|
sen1 = torch.randn(11, emb_dims, dtype=dtype, device=device)
|
|
sen2 = torch.randn(13, emb_dims, dtype=dtype, device=device)
|
|
|
|
query = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)
|
|
key = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)
|
|
value = torch.nn.Linear(emb_dims, emb_dims, bias=False, device=device, dtype=dtype)
|
|
|
|
# Simplest case: 1 sentence, no batching
|
|
x_d1 = sen1.unsqueeze(0)
|
|
x_d2 = sen2.unsqueeze(0)
|
|
x_nt = torch.nested.as_nested_tensor([sen1, sen2], layout=torch.jagged)
|
|
|
|
q_d1 = query(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
k_d1 = key(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
v_d1 = value(x_d1).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
q_d2 = query(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
k_d2 = key(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
v_d2 = value(x_d2).view(batch_size, -1, n_heads, head_dims).transpose(1, 2)
|
|
|
|
q_nt = query(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
|
|
k_nt = key(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
|
|
v_nt = value(x_nt).view(*x_nt.size()[0:2], n_heads, head_dims).transpose(1, 2)
|
|
|
|
# High Precision Math Reference
|
|
q_d1_f32 = q_d1.to(torch.float32)
|
|
k_d1_f32 = k_d1.to(torch.float32)
|
|
v_d1_f32 = v_d1.to(torch.float32)
|
|
out_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1_f32, k_d1_f32, v_d1_f32)[0]
|
|
# Low Precision Math Reference
|
|
out_lp_ref = torch.ops.aten._scaled_dot_product_attention_math(q_d1, k_d1, v_d1)[0]
|
|
output_ref_atol, output_ref_rtol = get_tolerances(out_ref, out_lp_ref)
|
|
|
|
attn_d1 = torch.nn.functional.scaled_dot_product_attention(q_d1, k_d1, v_d1).transpose(1, 2)
|
|
attn_d2 = torch.nn.functional.scaled_dot_product_attention(q_d2, k_d2, v_d2).transpose(1, 2)
|
|
|
|
compiled_sdpa = torch.compile(torch.nn.functional.scaled_dot_product_attention)
|
|
attn_nt = compiled_sdpa(q_nt, k_nt, v_nt).transpose(1, 2)
|
|
|
|
attn_nts = attn_nt.unbind()
|
|
self.assertEqual(attn_d1, attn_nts[0].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol)
|
|
self.assertEqual(attn_d2, attn_nts[1].unsqueeze(0), atol=output_ref_atol, rtol=output_ref_rtol)
|
|
|
|
@dtypes(torch.float32, torch.double, torch.half)
|
|
def test_sdpa_with_constant_sequence_length(self, device, dtype):
|
|
# shape (B, P*, S, D)
|
|
# B: batch size
|
|
# P*: ragged number of prompts
|
|
# S: (constant) sequence length
|
|
# D: embedding size
|
|
query = random_nt_from_dims(
|
|
[4, None, 8, 10], device=device, dtype=dtype, layout=torch.jagged)
|
|
key = random_nt_from_similar(query)
|
|
value = random_nt_from_similar(query)
|
|
output = F.scaled_dot_product_attention(query, key, value)
|
|
self.assertTrue(isinstance(output, NestedTensor))
|
|
|
|
# should be equivalent to just running the buffers through
|
|
output_dense = F.scaled_dot_product_attention(query._values, key._values, value._values)
|
|
self.assertEqual(output._values, output_dense)
|
|
|
|
|
|
instantiate_parametrized_tests(TestNestedTensor)
|
|
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
|
|
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
|
|
instantiate_device_type_tests(TestNestedTensorSubclass, globals())
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|