pytorch/test/test_nestedtensor.py
Joel Benjamin Schlosser f15c5bf133 Initial private SDP interface and naive composite impl (#81956)
Adds an initial private API version of the scaled dot product (SDP) attention interface.

Signature:
```
_scaled_dot_product_attention(Tensor query, Tensor key, Tensor value, Tensor? attn_mask=None,
    float dropout_p=0.0, bool need_attn_weights=True, bool is_causal=False) -> (Tensor, Tensor)
```

Returns a tuple of `(output, attn_weights)`.

Note the following (a brief usage sketch follows these notes):
* `need_attn_weights`: flag indicating whether attention weights should be computed. It is useful to toggle this off for flash attention, which does not materialize the weights by default, making them more expensive to return.
* Only boolean attention masks are supported; `True` values within `attn_mask` indicate that the element should take part in attention (notably, this is the reverse of MHA, which uses `True` to mask *out* values). The mask is optional.
* `is_causal`: Temporary flag indicating whether to use a causal attention weighting. If this is set to `True`, it takes precedence over any value passed in for `attn_mask`. Longer term, the `is_causal` flag can be subsumed into the `attn_mask` arg via tensor subclassing (see e.g. [CausalTensor](https://github.com/facebookresearch/xformers/blob/sparse_cleanup/xformers/sparse/causal_tensor.py) in xFormers).
* Testing is currently done by comparing against the existing Python impl of `F._scaled_dot_product_attention` as a reference.
* This PR does not yet drop the new SDP in anywhere; a future PR can hook it up in BT or MHA.
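
For illustration, a minimal usage sketch mirroring the call form exercised in `test_scaled_dot_product_attention` below (dense inputs shown; nested tensor inputs follow the same form):
```python
import torch

N, L, S, E = 1, 3, 4, 8
query = torch.randn(N, L, E)  # (N, L, E)
key = torch.randn(N, S, E)    # (N, S, E)
value = torch.randn(N, S, E)  # (N, S, E)

# Returns (output, attn_weights); toggling need_attn_weights off avoids
# materializing the weights (relevant for flash attention).
output, attn_weights = torch.ops.aten._scaled_dot_product_attention(
    query, key, value, attn_mask=None, dropout_p=0.0,
    need_attn_weights=True, is_causal=False)
```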
Pull Request resolved: https://github.com/pytorch/pytorch/pull/81956
Approved by: https://github.com/drisspg, https://github.com/erichan1
2022-07-27 15:41:45 +00:00

# Owner(s): ["module: nestedtensor"]
import torch
import torch.nn
import unittest
from torch.testing._internal.common_device_type import (
dtypes,
dtypesIfCUDA,
instantiate_device_type_tests,
skipMeta,
)
from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state
from torch import nested_tensor
# Tests are ported from pytorch/nestedtensor.
# This makes porting as_nested_tensor easier in the future.
def _iter_constructors():
# yield as_nested_tensor
yield nested_tensor
class TestNestedTensor(TestCase):
@torch.inference_mode()
def _test_unbind_case(self, a, b):
nt = nested_tensor([a, b])
a1, b1 = nt.unbind()
self.assertTrue(a is not a1)
self.assertTrue(b is not b1)
nt = nested_tensor([a, b], dtype=a.dtype)
a1, b1 = nt.unbind(0)
self.assertEqual(a, a1)
self.assertEqual(b, b1)
a = torch.randn((2, 3)).add_(1)
nt = nested_tensor([a])
self.assertEqual(a, nt.unbind(0)[0])
@torch.inference_mode()
def test_unbind_0(self):
self._test_unbind_case(
torch.tensor([1, 2]), torch.tensor([7, 8]),
)
@torch.inference_mode()
def test_unbind_1(self):
self._test_unbind_case(
torch.tensor([1]), torch.tensor([7]),
)
# @torch.inference_mode()
# def test_unbind_2(self):
# self._test_unbind_case(
# torch.tensor(1), torch.tensor(7),
# )
@torch.inference_mode()
def test_unbind_3(self):
self._test_unbind_case(
torch.tensor([1.0]), torch.tensor([]),
)
@torch.inference_mode()
def test_unbind_4(self):
self._test_unbind_case(
torch.tensor([]), torch.tensor([]),
)
@torch.inference_mode()
def test_unbind_dim(self):
def _test_fn(unbind_fn):
a = torch.rand(3, 2)
b = torch.rand(2, 3)
nt = nested_tensor([a, b])
self.assertRaises(RuntimeError, lambda: unbind_fn(nt, 1))
# Both of these tests are necessary because we're using
# torch_function.
_test_fn(lambda x, dim: x.unbind(dim))
# TODO: Re-enable this once using torch_dispatch
# _test_fn(lambda x, dim: torch.unbind(x, dim))
@torch.inference_mode()
def test_nested_tensor(self):
self.assertRaises(TypeError, lambda: nested_tensor([3.0]))
self.assertRaises(TypeError, lambda: nested_tensor(torch.tensor([3.0])))
self.assertRaises(TypeError, lambda: nested_tensor(4.0))
@torch.inference_mode()
def test_nested_tensor_matching_dim(self):
self.assertRaisesRegex(
RuntimeError,
"Found dimension 1 for Tensor at index 1 and dimension 0 for Tensor at index 0.",
lambda: nested_tensor([torch.tensor(1.0), torch.tensor([])]),
)
self.assertRaisesRegex(
RuntimeError,
"Found dimension 1 for Tensor at index 2 and dimension 0 for Tensor at index 1.",
lambda: nested_tensor(
[torch.tensor(1.0), torch.tensor(2.0), torch.tensor([])]
),
)
@torch.inference_mode()
def test_default_nested_tensor(self):
self.assertRaises(TypeError, lambda: nested_tensor())
default_nested_tensor = nested_tensor([])
default_tensor = torch.tensor([])
# self.assertEqual(default_nested_tensor.nested_dim(), 1)
# self.assertEqual(default_nested_tensor.nested_size(), ())
self.assertEqual(default_nested_tensor.dim(), default_tensor.dim())
self.assertEqual(default_nested_tensor.layout, default_tensor.layout)
self.assertEqual(default_nested_tensor.device, default_tensor.device)
self.assertEqual(default_nested_tensor.dtype, default_tensor.dtype)
self.assertEqual(
default_nested_tensor.requires_grad, default_tensor.requires_grad
)
self.assertIsNone(default_tensor.grad)
# TODO: Re-enable once we have a performance driven
# use case and implementation.
# self.assertEqual(default_nested_tensor.is_pinned(),
# default_tensor.is_pinned())
@torch.inference_mode()
def test_dim(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertEqual(a1.dim(), 1)
a1 = constructor([torch.tensor(3.0)])
self.assertEqual(a1.dim(), 1)
a1 = constructor([torch.tensor([1, 2, 3, 4])])
self.assertEqual(a1.dim(), 2)
@unittest.skipIf(IS_FBCODE, "numel is not virtual in fbcode.")
@torch.inference_mode()
def test_numel(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertEqual(a1.numel(), 0)
a1 = constructor([torch.tensor(3.0), torch.tensor(4.0)])
self.assertEqual(a1.numel(), 2)
a1 = constructor([torch.randn(2, 2, 2)])
self.assertEqual(a1.numel(), 8)
a1 = constructor([torch.randn([1, 2, 3]), torch.randn(3, 2, 1)])
self.assertEqual(a1.numel(), 12)
a1 = constructor([torch.randn([1, 1, 3]), torch.randn(3, 2, 4)])
self.assertEqual(a1.numel(), 27)
a1 = constructor([torch.randn([5, 5, 5]), torch.randn(6, 6, 6)])
self.assertEqual(a1.numel(), 341)
# Interesting edge case: a component with a zero-sized dim contributes no elements
a1 = constructor([torch.randn([1, 2, 3]), torch.randn(1, 2, 0)])
self.assertEqual(a1.numel(), 6)
@torch.inference_mode()
def test_size(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertRaisesRegex(
RuntimeError,
"Tensors of type NestedTensorImpl do not have sym sizes"
if IS_FBCODE
else "NestedTensorImpl doesn't support sizes",
lambda: a1.size(),
)
@unittest.skipIf(IS_FBCODE, "stride is not virtual in fbcode.")
@torch.inference_mode()
def test_stride(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertRaisesRegex(
RuntimeError,
"NestedTensorImpl doesn't support strides",
lambda: a1.stride(),
)
@unittest.skipIf(IS_FBCODE, "is_contiguous is not virtual in fbcode.")
@torch.inference_mode()
def test_is_contiguous(self):
for constructor in _iter_constructors():
a1 = constructor([])
self.assertRaisesRegex(
RuntimeError, "is_contiguous is disabled", lambda: a1.is_contiguous()
)
@torch.inference_mode()
def test_repr_string(self):
a = nested_tensor([])
expected = "nested_tensor([" "\n\n])"
self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected)
a = nested_tensor([torch.tensor(1.0)])
expected = "nested_tensor([" "\n tensor(1.)" "\n])"
self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected)
a = nested_tensor([torch.tensor([[1, 2]]), torch.tensor([[4, 5]])])
expected = (
"nested_tensor([" "\n tensor([[1, 2]])" "," "\n tensor([[4, 5]])" "\n])"
)
self.assertEqual(str(a), expected)
self.assertEqual(repr(a), expected)
@torch.inference_mode()
def test_activations(self):
for func in (torch.nn.functional.relu, torch.nn.functional.relu_, torch.nn.functional.gelu, torch._C._nn.gelu_):
t = torch.tensor([-1, 0, 1], dtype=torch.float)
nt = nested_tensor([t])
nested_result = func(nt)
self.assertTrue(nested_result.is_nested)
self.assertEqual(func(t), nested_result.unbind()[0])
def test_to_padded_tensor_on_empty_tensor(self):
nt = torch.nested_tensor([])
empty = nt.to_padded_tensor(4)
self.assertEqual(empty, torch.tensor([]))
class TestNestedTensorDeviceType(TestCase):
@dtypes(torch.float)
@skipMeta
def test_to_then_from_padded_tensor_no_transform0213(self, device, dtype):
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
ts = list(torch.unbind(t))
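# truncate the first component so the nested tensor is ragged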
ts[0] = ts[0][:-1]
nt = torch.nested_tensor(ts, device=device, dtype=dtype)
padded = nt.to_padded_tensor(0)
nt_to = torch._nested_from_padded_and_nested_example(padded, nt)
for (t1, t2) in zip(nt.unbind(), nt_to.unbind()):
self.assertEqual(t1, t2)
self.assertEqual(nt.device, nt_to.device)
@dtypes(torch.float)
@dtypesIfCUDA(torch.float, torch.half)
@skipMeta
@torch.inference_mode()
def test_layer_norm(self, device, dtype):
def _test(size):
t0 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
t1 = torch.randn(2, size, device=device, dtype=dtype, requires_grad=False)
ts = [t0, t1, t0, t1]
nt = torch.nested_tensor(ts, device=device, dtype=dtype)
layer_norm = torch.nn.LayerNorm(size, device=device, dtype=dtype)
nt_result = nt._nested_tensor_layer_norm(
layer_norm.weight, layer_norm.bias, 1e-5
)
for (nt_subresult, t) in zip(nt_result.unbind(), ts):
t_result = layer_norm(t.reshape(1, -1, size).squeeze(0))
self.assertEqual(nt_subresult, t_result)
for size in (1024, 1023, 513, 512, 256, 128, 2, 4, 32):
_test(size)
@skipMeta
@torch.inference_mode()
def test_embedding(self, device):
inputs = [
torch.randint(100, (L,), device=device, dtype=torch.int64)
for L in torch.randint(5, 50, (8,))
]
x = torch.nested_tensor(inputs, device=device, dtype=torch.int64)
emb = torch.nn.Embedding(100, 8, device=device)
y = emb(x)
ys = y.unbind()
for i, inp in enumerate(inputs):
self.assertEqual(emb(inp), ys[i])
@dtypes(torch.float, torch.float16)
def test_to_padded_tensor_simple(self, device, dtype):
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
ts = list(torch.unbind(t))
ts[0] = ts[0][:-1]
nt = torch.nested_tensor(ts, device=device, dtype=dtype)
for padding_value in (0, 1):
padded = nt.to_padded_tensor(padding_value)
correct_output = t.clone()
if padding_value == 0:
correct_output[0][-1] = torch.zeros_like(correct_output[0][-1])
else:
correct_output[0][-1] = torch.ones_like(correct_output[0][-1])
self.assertEqual(padded, correct_output)
self.assertEqual(padded.device, torch.device(device))
self.assertEqual(padded.dtype, dtype)
@dtypes(torch.float, torch.float16)
def test_to_padded_tensor_output_size(self, device, dtype):
t = torch.randn(4, 4, 4, device=device, dtype=dtype)
output_size = (4, 6, 5)
ts = list(torch.unbind(t))
ts[0] = ts[0][:-1]
nt = torch.nested_tensor(ts, device=device, dtype=dtype)
for padding_value in (0, 1):
padded = nt.to_padded_tensor(padding_value, output_size=output_size)
correct_output = torch.ones(output_size, device=device, dtype=dtype) * padding_value
correct_output[:4, :4, :4] = t.clone()
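# the first component was truncated to 3 rows, so row 3 of the first entry holds the padding value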
if padding_value == 0:
correct_output[0][3] = torch.zeros_like(correct_output[0][3])
else:
correct_output[0][3] = torch.ones_like(correct_output[0][3])
self.assertEqual(padded, correct_output)
self.assertEqual(padded.device, torch.device(device))
self.assertEqual(padded.dtype, dtype)
@dtypes(torch.float, torch.float16, torch.double)
def test_to_padded_tensor_dim2(self, device, dtype):
ts = [
torch.randn(160, device=device, dtype=dtype),
torch.randn(1240, device=device, dtype=dtype),
torch.randn(2400, device=device, dtype=dtype),
]
nt = torch.nested_tensor(ts, device=device, dtype=dtype)
pad = 42
correct_output = []
for t in ts:
next_output = torch.ones_like(ts[2]) * pad
correct_output.append(next_output)
next_output[:t.size(0)].copy_(t)
correct_output = torch.stack(correct_output)
padded = nt.to_padded_tensor(pad)
self.assertEqual(padded, correct_output)
@dtypes(torch.float, torch.float16, torch.double)
def test_to_padded_tensor_dim3(self, device, dtype):
ts = [
torch.randn(16, 21, device=device, dtype=dtype),
torch.randn(24, 32, device=device, dtype=dtype),
torch.randn(40, 53, device=device, dtype=dtype),
]
nt = torch.nested_tensor(ts, device=device, dtype=dtype)
pad = 42
correct_output = []
for t in ts:
next_output = torch.ones_like(ts[2]) * pad
correct_output.append(next_output)
next_output[:t.size(0), :t.size(1)].copy_(t)
correct_output = torch.stack(correct_output)
padded = nt.to_padded_tensor(pad)
self.assertEqual(padded, correct_output)
@dtypes(torch.float, torch.float16, torch.double)
def test_to_padded_tensor_dim4(self, device, dtype):
ts = [
torch.randn(16, 21, 13, device=device, dtype=dtype),
torch.randn(24, 32, 14, device=device, dtype=dtype),
torch.randn(40, 53, 16, device=device, dtype=dtype),
]
nt = torch.nested_tensor(ts, device=device, dtype=dtype)
pad = 42
correct_output = []
for t in ts:
next_output = torch.ones_like(ts[2]) * pad
correct_output.append(next_output)
next_output[:t.size(0), :t.size(1), :t.size(2)].copy_(t)
correct_output = torch.stack(correct_output)
padded = nt.to_padded_tensor(pad)
self.assertEqual(padded, correct_output)
@skipMeta
def test_device_checks(self, device):
nt = torch.nested_tensor([], device=device)
is_cuda = 'cuda' in str(device)
self.assertEqual(nt.is_cuda, is_cuda)
@dtypes(torch.float, torch.float16, torch.double)
def test_nested_tensor_indexing(self, device, dtype):
# edge case: empty nested tensor
nt0 = torch.nested_tensor([])
self.assertRaises(IndexError, lambda: nt0[0])
# normal case
x0 = torch.randn((2, 5), device=device, dtype=dtype)
x1 = torch.randn((3, 4), device=device, dtype=dtype)
nt = torch.nested_tensor([x0, x1])
# single index: only support integer in the batch dimension
self.assertEqual(nt[0], x0)
self.assertEqual(nt[-1], x1)
self.assertRaises(IndexError, lambda: nt[2])
self.assertRaises(IndexError, lambda: nt[-3])
self.assertRaises(NotImplementedError, lambda: nt[:])
self.assertRaises(NotImplementedError, lambda: nt[None])
self.assertRaises(NotImplementedError, lambda: nt[...])
# tuple of indices: only support integer in the batch dimension
# + all possible indexing in the original tensor dimensions
self.assertEqual(nt[0, 0, 0], x0[0, 0])
self.assertEqual(nt[0, 1, :], x0[1, :])
self.assertEqual(nt[1, ...], x1)
self.assertRaises(IndexError, lambda: nt[1, 4, 2])
self.assertRaises(NotImplementedError, lambda: nt[:, 1, 1])
# make sure indexing returns a view
nt[0].fill_(100.0)
answer = torch.tensor(100.0, device=device, dtype=dtype).expand((2, 5))
self.assertEqual(nt[0], answer)
nt[1, 1, :].fill_(200.0)
answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
self.assertEqual(nt[1, 1, :], answer)
# Helper function for testing elementwise ops: build a random nested tensor
def random_nt(self, device, dtype, num_tensors, max_dims, min_dims=None):
if min_dims is None:
min_dims = tuple([0] * len(max_dims))
ts1 = []
for _ in range(num_tensors):
tensor_dims = tuple([torch.randint(low=min_dim, high=max_dim, size=(1,)).item()
for (min_dim, max_dim) in zip(min_dims, max_dims)])
t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
ts1.append(t1)
return torch.nested_tensor(ts1, device=device, dtype=dtype)
# Helper function for testing elementwise ops: build a pair of random nested tensors with matching component shapes
def random_nt_pair(self, device, dtype, num_tensors, max_dims):
ts1 = []
ts2 = []
for _ in range(num_tensors):
tensor_dims = tuple([torch.randint(low=0, high=max_dim, size=(1,)).item() for max_dim in max_dims])
t1 = torch.randn(tensor_dims, device=device, dtype=dtype)
t2 = torch.randn(tensor_dims, device=device, dtype=dtype)
ts1.append(t1)
ts2.append(t2)
return (torch.nested_tensor(ts1, device=device, dtype=dtype),
torch.nested_tensor(ts2, device=device, dtype=dtype))
def nt_equal(self, nt1, nt2):
self.assertEqual(nt1.dtype, nt2.dtype)
self.assertEqual(nt1.device, nt2.device)
ub1 = nt1.unbind()
ub2 = nt2.unbind()
self.assertEqual(len(ub1), len(ub2))
n = len(ub1)
for i in range(n):
self.assertEqual(ub1[i], ub2[i])
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_nested_tensor_add(self, device, dtype):
(nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
ref = torch.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
out = nt1 + nt2
self.nt_equal(ref, out)
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_nested_tensor_mul(self, device, dtype):
# nested tensor * nested tensor
(nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
ref = torch.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
out = nt1 * nt2
self.nt_equal(ref, out)
# nested tensor * scalar
number = 10.0
scalar = torch.tensor(number).to(dtype).to(device)
ref = torch.nested_tensor([t * number for t in nt1.unbind()])
out_number0 = nt1 * number
out_number1 = number * nt1
out_scalar0 = nt1 * scalar
out_scalar1 = scalar * nt1
self.nt_equal(out_number0, ref)
self.nt_equal(out_number1, ref)
self.nt_equal(out_scalar0, ref)
self.nt_equal(out_scalar1, ref)
# error case: numel == 1 but dim > 0
vector = torch.tensor([number]).to(dtype).to(device)
self.assertRaisesRegex(
RuntimeError,
"Expected both self and other to be nested, but got a nested self and non-nested other",
lambda: nt1.mul(vector)
)
self.assertRaisesRegex(
RuntimeError,
"Expected both self and other to be nested, but got a non-nested self and nested other",
lambda: vector.mul(nt1)
)
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_nested_tensor_add_in_place(self, device, dtype):
(nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
ref = torch.nested_tensor([t1 + t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
nt1 += nt2
self.nt_equal(ref, nt1)
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_nested_tensor_mul_in_place(self, device, dtype):
# nested tensor * nested tensor
(nt1, nt2) = self.random_nt_pair(device, dtype, 4, (4, 4))
ref = torch.nested_tensor([t1 * t2 for (t1, t2) in zip(nt1.unbind(), nt2.unbind())])
nt1 *= nt2
self.nt_equal(ref, nt1)
# nested tensor * scalar
number = 10.0
scalar = torch.tensor(number).to(dtype).to(device)
ref = torch.nested_tensor([t * number for t in nt1.unbind()])
out_number = nt1.clone()
out_number *= number
out_scalar = nt1.clone()
out_scalar *= scalar
self.nt_equal(out_number, ref)
self.nt_equal(out_scalar, ref)
self.assertRaisesRegex(
RuntimeError,
r"output with shape \[.*\] doesn't match the broadcast shape \[.*\]",
lambda: scalar.mul_(nt1)
)
# error case: numel == 1 but dim > 0
vector = torch.tensor([number]).to(dtype).to(device)
self.assertRaisesRegex(
RuntimeError,
"Expected both self and other to be nested, but got a nested self and non-nested other",
lambda: nt1.mul_(vector)
)
self.assertRaisesRegex(
RuntimeError,
"Expected both self and other to be nested, but got a non-nested self and nested other",
lambda: vector.mul_(nt1)
)
@dtypes(torch.float, torch.float16)
@skipMeta
@torch.inference_mode()
def test_clone(self, device, dtype):
nt1 = self.random_nt(device, dtype, 4, (4, 4), (1, 1))
nt2 = nt1.clone()
# Verify the values match
self.nt_equal(nt1, nt2)
# Verify modifying nt2 doesn't affect nt1
nt2.mul_(nt1)
ub1 = nt1.unbind()
ub2 = nt2.unbind()
for i in range(len(ub1)):
self.assertNotEqual(ub1[i], ub2[i])
nt1.clone(memory_format=torch.preserve_format)
msg = "clone_nested only supports memory format Preserve, but got ChannelsLast instead."
with self.assertRaisesRegex(RuntimeError, msg):
nt1.clone(memory_format=torch.channels_last)
# cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
@dtypes(torch.float, torch.double)
@torch.inference_mode()
def test_dropout(self, device, dtype):
# edge case: empty nested tensor
nt0 = torch.nested_tensor([])
y = torch.nn.functional.dropout(nt0, 0.5)
self.nt_equal(nt0, y)
# normal nested tensor
ntensors = 4
nt = self.random_nt(device, dtype, ntensors, (4, 4))
# edge case: invalid dropout
self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
# edge case: no dropout
dropouter = torch.nn.Dropout(0.0)
y0 = dropouter(nt)
y1 = torch.nn.functional.dropout(nt, 0.0)
self.nt_equal(nt, y0)
self.nt_equal(nt, y1)
# edge case: all dropout
dropouter = torch.nn.Dropout(1.0)
y0 = dropouter(nt)
y1 = torch.nn.functional.dropout(nt, 1.0)
nt0 = nt.clone()
for i in range(ntensors):
nt0[i].fill_(0.0)
self.nt_equal(nt0, y0)
self.nt_equal(nt0, y1)
# normal case: normal dropout
p = 0.2
y = torch.nn.functional.dropout(nt, p)
expect = nt.clone()
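# reconstruct the expected result: dropped elements are zeroed, survivors are scaled by 1 / (1 - p) (inverted dropout)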
for i in range(ntensors):
actual_tensor = y[i].view(-1)
expect_tensor = expect[i].view(-1)
for j in range(actual_tensor.shape[0]):
if actual_tensor[j].item() == 0.0:
expect_tensor[j] = 0.0
else:
expect_tensor[j] /= 1.0 - p
self.nt_equal(y, expect)
with freeze_rng_state():
dropouter = torch.nn.Dropout(p)
y0 = dropouter(nt)
with freeze_rng_state():
y1 = torch.nn.functional.dropout(nt, p)
self.nt_equal(y0, y1)
# inplace
# in principle, since we have established the correctness of functional, we could simply compare inplace vs functional
# in practice, cuda functional has its own implementation to skip `bernoulli_`
# so cuda functional will differ from cuda inplace causing test failure
# in `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
# on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`
expect = nt.clone()
torch.nn.functional.dropout(nt, p, inplace=True)
for i in range(ntensors):
actual_tensor = nt[i].view(-1)
expect_tensor = expect[i].view(-1)
for j in range(actual_tensor.shape[0]):
if actual_tensor[j].item() == 0.0:
expect_tensor[j] = 0.0
else:
expect_tensor[j] /= 1.0 - p
self.nt_equal(nt, expect)
# cannot test torch.float16 because: RuntimeError: "softmax_kernel_impl" not implemented for 'Half'
@dtypes(torch.float, torch.double)
@torch.inference_mode()
def test_softmax(self, device, dtype):
# normal nested tensor
ntensors = 4
nt = self.random_nt(device, dtype, ntensors, (4, 4))
# error case: softmax across nested dimension
self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt, 0))
self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt, -3))
# error case: dimension out of range
self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, 3))
self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt, -4))
# normal case: should equal to padding -inf
softmaxer = torch.nn.Softmax(1)
y0 = softmaxer(nt)
y1 = torch.nn.functional.softmax(nt, 1)
self.nt_equal(y0, y1)
pt = nt.to_padded_tensor(float("-inf"))
# if an entire slice is padded, then softmax returns 0.0 / 0.0 = nan;
# for this comparison it should be treated as 0.0
expect = torch.nn.functional.softmax(pt, 1).nan_to_num_(0.0)
self.assertEqual(y0.to_padded_tensor(0.0), expect)
# edge case: empty nested tensor
nt0 = torch.nested_tensor([])
y = torch.nn.functional.softmax(nt0, 1)
self.nt_equal(nt0, y)
# edge case: nesting scalars
nt1 = torch.nested_tensor([torch.tensor(0.0), torch.tensor(1.0)])
self.assertRaises(RuntimeError, lambda: torch.nn.functional.softmax(nt1, 0))
self.assertRaises(IndexError, lambda: torch.nn.functional.softmax(nt1, 1))
@dtypes(torch.float, torch.float16, torch.double)
@torch.inference_mode()
def test_bmm(self, device, dtype):
# error case: not 3D tensors
nt0 = torch.nested_tensor([])
nt1 = torch.nested_tensor([torch.randn(2), torch.randn(3)])
nt2 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))])
self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt0))
self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt1))
self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt0.bmm(nt2))
self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt0))
self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt1))
self.assertRaisesRegex(RuntimeError, "batch1 must be a 3D tensor", lambda: nt1.bmm(nt2))
self.assertRaisesRegex(RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt0))
self.assertRaisesRegex(RuntimeError, "batch2 must be a 3D tensor", lambda: nt2.bmm(nt1))
# error case: incompatible batch size
nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))])
nt1 = torch.nested_tensor([torch.randn((4, 6)), torch.randn((4, 5)), torch.randn((4, 7))])
self.assertRaisesRegex(
RuntimeError,
"Expected size for the 1st dimension of batch2 tensor to be: 2 but got: 3.",
lambda: nt0.bmm(nt1)
)
self.assertRaisesRegex(
RuntimeError,
"Expected size for the 1st dimension of batch2 tensor to be: 3 but got: 2.",
lambda: nt1.bmm(nt0)
)
# error case: underlying matrices cannot be multiplied
nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 4))])
self.assertRaisesRegex(
RuntimeError,
r"0-th nested matrices in batch cannot be multiplied \(2x4 and 2x4\)",
lambda: nt0.bmm(nt0)
)
# normal nested tensor
nt0 = torch.nested_tensor([torch.randn((2, 4)), torch.randn((3, 7))])
nt1 = torch.nested_tensor([torch.randn((4, 6)), torch.randn((7, 5))])
actual = nt0.bmm(nt1)
expect = nt0.to_padded_tensor(0.0).bmm(nt1.to_padded_tensor(0.0))
self.assertEqual(actual.to_padded_tensor(0.0), expect)
@dtypes(torch.float, torch.double)
def test_linear(self, device, dtype):
a = torch.randn(1, 2, device=device, dtype=dtype)
b = torch.randn(2, 2, device=device, dtype=dtype)
c = torch.randn(3, 2, device=device, dtype=dtype)
nt = torch.nested_tensor([a, b, c])
weight = torch.randn(2, 2, device=device, dtype=dtype)
bias = torch.randn(2, device=device, dtype=dtype)
# success case
torch.functional.F.linear(nt, weight, bias)
# invalid nested tensor dimension
msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 2. Dense tensor dim: 2'
nt1 = torch.nested_tensor([torch.randn(1, device=device, dtype=dtype),
torch.randn(2, device=device, dtype=dtype)])
with self.assertRaisesRegex(RuntimeError, msg):
torch.functional.F.linear(nt1, weight, bias)
# invalid weight shape
msg = r'Linear requires nested_tensor.dim == 3 and dense_matrix.dim == 2. Nested tensor dim: 3. Dense tensor dim: 3'
weight1 = torch.randn(2, 2, 3, device=device, dtype=dtype)
with self.assertRaisesRegex(RuntimeError, msg):
torch.functional.F.linear(nt, weight1, bias)
# inconsistent last dim of nested tensor
msg = r"all tensors in NestedTensor must have the same trailing dim"
nt2 = torch.nested_tensor([torch.randn(1, 2, device=device, dtype=dtype),
torch.randn(2, 3, device=device, dtype=dtype)])
with self.assertRaisesRegex(RuntimeError, msg):
torch.functional.F.linear(nt2, weight, bias)
# Mismatch of nested tensor last dim and weight dimension
weight2 = torch.randn(2, 4, device=device, dtype=dtype)
msg = r"Shape mismatch for NestedTensor Linear: Expected input's \(a nested tensor\) 'last_dim'" \
r" to equal 'weight.size\(1\), but got: last_dim = 2, and weight.size\(1\) = 4"
with self.assertRaisesRegex(RuntimeError, msg):
torch.functional.F.linear(nt, weight2, bias)
# Nested tensor input and nested weight
nt_weight = nt.clone()
msg = r"Linear does not support nested weight when input is a nested tensor."
with self.assertRaisesRegex(RuntimeError, msg):
torch.functional.F.linear(nt, nt_weight, bias)
class TestNestedTensorAutograd(TestCase):
def nt_equal(self, nt1, nt2):
self.assertEqual(nt1.dtype, nt2.dtype)
self.assertEqual(nt1.device, nt2.device)
ub1 = nt1.unbind()
ub2 = nt2.unbind()
self.assertEqual(len(ub1), len(ub2))
n = len(ub1)
for i in range(n):
self.assertEqual(ub1[i], ub2[i])
def _create_nested_tensor_from_list(self, requires_grad=False):
return torch.nested_tensor([torch.randn(1, 2, requires_grad=requires_grad),
torch.randn(7, 8, requires_grad=requires_grad)])
def _create_nested_tensor_from_mask(self, requires_grad=False):
data = torch.randn(2, 3, 4, requires_grad=requires_grad)
mask = torch.ones_like(data[:, :, 0]).bool()
return torch._nested_tensor_from_mask(data, mask)
def test_set_requires_grad_from_list(self):
nt = self._create_nested_tensor_from_list()
nt.requires_grad_()
assert nt.requires_grad
def test_set_requires_grad_from_mask(self):
nt = self._create_nested_tensor_from_mask()
nt.requires_grad_()
assert nt.requires_grad
def test_backward_for_add_op(self):
nt_1 = self._create_nested_tensor_from_mask()
nt_2 = self._create_nested_tensor_from_mask()
nt_1.requires_grad_()
c = nt_1 + nt_2
assert nt_1.requires_grad
assert c.requires_grad
grad_output = self._create_nested_tensor_from_mask()
c.backward(grad_output)
# Grad check doesn't work with nested yet.
# d/dnt_1 (nt + nt_1) = 1*grad_output
self.nt_equal(nt_1.grad, grad_output)
# Test Factory Functions
def test_nested_tensor_to_padded_tensor(self):
for padding_val in [0, 1]:
nt = torch.nested_tensor([torch.randn(1, 2), torch.randn(7, 8)])
nt.requires_grad_()
out = nt.to_padded_tensor(padding_val)
grad_output = torch.ones(out.shape)
out.backward(grad_output)
self.nt_equal(nt.grad, torch.nested_tensor([torch.ones(1, 2), torch.ones(7, 8)]))
def test_nested_tensor_from_mask_and_to_padded(self):
N, L, D = 2, 4, 4
mask = torch.ones(N, L)
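# make rows 1..N-1 ragged by zeroing a random tail; row 0 keeps full length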
for i in range(1, N):
end = torch.randint(1, L - 1, (1,))
mask[i, end:] = 0
mask[0, :] = 1
mask = mask.bool()
data = torch.randn(N, L, D, requires_grad=True, dtype=torch.float64)
def grad_test_func(inpt):
nt = torch._nested_tensor_from_mask(inpt, mask)
# This implicitly tests to_padded_tensor grads
return nt.to_padded_tensor(0)
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
def test_nested_tensor_from_padded(self):
nested_size = torch.tensor([[1, 2], [2, 2]])
padded_tensor = torch.randn(2, 2, 2, dtype=torch.float64)
padded_tensor[0, 1, :] = 0
padded_tensor.requires_grad_()
def grad_test_func(tensor, nested_size):
nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=False)
# This implicitly tests to_padded_tensor grads
return nt.to_padded_tensor(0)
data = (padded_tensor, nested_size)
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
def test_nested_tensor_from_padded_fused(self):
nested_size = torch.tensor([[1, 8], [2, 8]])
padded_tensor = torch.randn(2, 2, 2, 4, dtype=torch.float64)
padded_tensor[0, 1, :] = 0
padded_tensor.requires_grad_()
def grad_test_func(tensor, nested_size):
nt = torch._nested_from_padded(tensor, nested_size, fuse_transform_0213=True)
# This implicitly tests to_padded_tensor grads
return nt.to_padded_tensor(0)
data = (padded_tensor, nested_size)
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
def test_nested_tensor_from_list(self):
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64)
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64)
c = torch.randn(10, 2, requires_grad=True, dtype=torch.float64)
def grad_test_func(a, b, c):
c = torch.nested_tensor([a, b, c])
# This implicitly tests to_padded_tensor grads
return c.to_padded_tensor(0)
data = (a, b, c)
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
def test_size_dim(self):
a = torch.nested_tensor([])
self.assertEqual(a.size(0), 0)
a = torch.nested_tensor([torch.tensor(1)])
self.assertEqual(a.size(0), 1)
a = torch.nested_tensor([torch.tensor(1), torch.tensor(2)])
self.assertEqual(a.size(0), 2)
a = torch.nested_tensor([torch.rand(1, 2),
torch.rand(1, 8)])
self.assertEqual(a.size(0), 2)
self.assertEqual(a.size(1), 1)
self.assertRaisesRegex(
RuntimeError, "Given dimension 2 is irregular and does not have a size", lambda: a.size(2))
a = torch.nested_tensor([torch.rand(3, 4),
torch.rand(5, 4)])
self.assertEqual(a.size(0), 2)
self.assertRaisesRegex(
RuntimeError, "Given dimension 1 is irregular and does not have a size", lambda: a.size(1))
self.assertEqual(a.size(2), 4)
def test_nested_tensor_linear(self):
a = torch.randn(1, 2, requires_grad=True, dtype=torch.float64)
b = torch.randn(2, 2, requires_grad=True, dtype=torch.float64)
c = torch.randn(3, 2, requires_grad=True, dtype=torch.float64)
weight = torch.randn(2, 2, requires_grad=True, dtype=torch.float64)
bias = torch.randn(2, requires_grad=True, dtype=torch.float64)
def grad_test_func(a, b, c, weight, bias=None):
nt = torch.nested_tensor([a, b, c])
# This implicitly tests to_padded_tensor grads
d = torch.functional.F.linear(nt, weight, bias)
return d.to_padded_tensor(0)
data = (a, b, c, weight, bias)
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
# Test linear with no bias added
data = (a, b, c, weight)
assert torch.autograd.gradcheck(grad_test_func, inputs=data)
def test_nested_tensor_linear_backward(self):
a = torch.randn(1, 2, requires_grad=False)
b = torch.randn(2, 2, requires_grad=False)
c = torch.randn(3, 2, requires_grad=False)
weight = torch.randn(2, 2, requires_grad=True)
bias = torch.randn(2, requires_grad=True)
nt = torch.nested_tensor([a, b, c])
out = torch.functional.F.linear(nt, weight, bias)
out.backward(out.clone())
assert weight.grad is not None
assert bias.grad is not None
assert a.grad is None
assert b.grad is None
assert c.grad is None
def test_scaled_dot_product_attention(self):
# Shape: (N, L, E); ragged L
E = 10
query = torch.nested_tensor([torch.randn(2, E), torch.randn(3, E), torch.randn(4, E)])
# Shape: (N, S, E); ragged S
key = torch.nested_tensor([torch.randn(3, E), torch.randn(4, E), torch.randn(5, E)])
value = torch.nested_tensor([torch.randn(3, E), torch.randn(4, E), torch.randn(5, E)])
def rand_mask(size):
return torch.randint(0, 2, size=size, dtype=torch.bool)
# Shape: (N, L, S); ragged L and S matching above
attn_mask = torch.nested_tensor([rand_mask((2, 3)), rand_mask((3, 4)), rand_mask((4, 5))])
dropout_p = 0.0 # no dropout for reproducibility
need_attn_weights: bool = True
# Success case: no attn_mask set and is_causal=False.
actual = torch.ops.aten._scaled_dot_product_attention(
query, key, value, attn_mask=None, dropout_p=dropout_p, need_attn_weights=need_attn_weights)
expected_outputs = []
expected_attn_weights = []
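# compute the reference component-wise, treating each (q, k, v) triple as a batch of size 1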
for q, k, v in zip(query.unbind(), key.unbind(), value.unbind()):
(output, attn_weights) = torch.ops.aten._scaled_dot_product_attention(
q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0), attn_mask=None, dropout_p=dropout_p,
need_attn_weights=need_attn_weights)
expected_outputs.append(output.squeeze(0))
expected_attn_weights.append(attn_weights.squeeze(0))
expected_output_nested = torch.nested_tensor(expected_outputs)
expected_attn_weight_nested = torch.nested_tensor(expected_attn_weights)
self.nt_equal(actual[0], expected_output_nested)
self.nt_equal(actual[1], expected_attn_weight_nested)
# Error case: explicit attn_mask set.
with self.assertRaisesRegex(RuntimeError, "not supported when an explicit attn_mask is set"):
torch.ops.aten._scaled_dot_product_attention(
query, key, value, attn_mask=attn_mask, dropout_p=dropout_p, need_attn_weights=need_attn_weights)
# Error case: is_causal=True.
with self.assertRaisesRegex(RuntimeError, "not supported when is_causal=True"):
torch.ops.aten._scaled_dot_product_attention(
query, key, value, dropout_p=dropout_p, need_attn_weights=need_attn_weights, is_causal=True)
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
if __name__ == '__main__':
run_tests()