mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
As brought up in https://github.com/pytorch/pytorch/issues/86234#issuecomment-1268296036, our heuristic for which SVD backend to choose was not great in some cases. The case in which there could be some improvements is when we have a large batch of very small non-square matrices. This PR, adapts the calling code to gesvdj by creating two temporary square buffers to allow to call gesvdjBatched, and then copies back the result into the output buffers. We then modify the heuristic that chooses between gesvdj and gesvdjBatched. Fixes https://github.com/pytorch/pytorch/issues/86234 Pull Request resolved: https://github.com/pytorch/pytorch/pull/88502 Approved by: https://github.com/IvanYashchuk, https://github.com/nikitaved, https://github.com/mruberry, https://github.com/xwang233
4573 lines
182 KiB
Python
4573 lines
182 KiB
Python
# Owner(s): ["module: functorch"]
|
|
|
|
# Copyright (c) Facebook, Inc. and its affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the BSD-style license found in the
|
|
# LICENSE file in the root directory of this source tree.
|
|
|
|
from typing import OrderedDict
|
|
from unittest.case import skipIf
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
import functools
|
|
import itertools
|
|
import warnings
|
|
import unittest
|
|
from torch.testing._internal.common_methods_invocations import op_db
|
|
from torch.testing._internal.common_cuda import with_tf32_off
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests, \
|
|
skipCUDAIfNoMagma
|
|
from torch.testing._internal.common_device_type import ops
|
|
from torch.testing._internal.common_utils import (
|
|
parametrize,
|
|
instantiate_parametrized_tests,
|
|
subtest
|
|
)
|
|
from torch.testing._internal.common_device_type import \
|
|
toleranceOverride, tol
|
|
from functorch_additional_op_db import additional_op_db
|
|
from common_utils import (
|
|
get_fallback_and_vmap_exhaustive,
|
|
xfail,
|
|
skip,
|
|
skipOps,
|
|
check_vmap_fallback,
|
|
tol1,
|
|
opsToleranceOverride,
|
|
is_batch_norm_training,
|
|
generate_vmap_inputs,
|
|
compute_quantities_for_vmap_test,
|
|
is_valid_inplace_sample_input,
|
|
)
|
|
import types
|
|
from collections import namedtuple
|
|
|
|
import functorch
|
|
from functorch import vmap, grad, grad_and_value, jvp, vjp, jacfwd
|
|
from functorch.experimental import chunk_vmap
|
|
from torch._C._functorch import reshape_dim_into, reshape_dim_outof
|
|
from functorch._src.make_functional import functional_init_with_buffers
|
|
|
|
FALLBACK_REGEX = 'There is a performance drop'
|
|
|
|
|
|
class EnableVmapFallbackWarnings:
|
|
def __enter__(self):
|
|
self.prev_state = torch._C._debug_only_are_vmap_fallback_warnings_enabled()
|
|
torch._C._debug_only_display_vmap_fallback_warnings(True)
|
|
|
|
def __exit__(self, *ignored):
|
|
torch._C._debug_only_display_vmap_fallback_warnings(self.prev_state)
|
|
|
|
|
|
class TestVmapAPI(TestCase):
|
|
def test_non_tensor_output_raises(self):
|
|
with self.assertRaisesRegex(ValueError, "got type <class 'float'> as a return"):
|
|
vmap(lambda x: 3.14)(torch.ones(3))
|
|
|
|
def multiple_outputs(x):
|
|
return x, 3
|
|
|
|
with self.assertRaisesRegex(ValueError, "got type <class 'int'> as a return"):
|
|
vmap(multiple_outputs)(torch.ones(3))
|
|
|
|
def test_different_map_dim_size_raises(self):
|
|
x = torch.randn(2)
|
|
y = torch.randn(3)
|
|
expected_msg = 'Expected all tensors to have the same size in the mapped dimension'
|
|
with self.assertRaisesRegex(ValueError, expected_msg):
|
|
vmap(torch.mul)(x, y)
|
|
with self.assertRaisesRegex(ValueError, expected_msg):
|
|
vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
|
|
with self.assertRaisesRegex(ValueError, expected_msg):
|
|
vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y})
|
|
|
|
def test_func_with_no_inputs(self):
|
|
expected_msg = 'got no inputs'
|
|
|
|
def foo():
|
|
return torch.randn(3)
|
|
|
|
def bar(x):
|
|
return torch.randn(3)
|
|
|
|
with self.assertRaisesRegex(ValueError, expected_msg):
|
|
vmap(foo)()
|
|
|
|
with self.assertRaisesRegex(ValueError, expected_msg):
|
|
vmap(bar)()
|
|
|
|
def test_func_with_no_tensors(self):
|
|
def foo(x):
|
|
return torch.randn(3)
|
|
|
|
with self.assertRaisesRegex(ValueError, 'at least one Tensor'):
|
|
vmap(foo, (None,))(1)
|
|
|
|
def test_constant_function(self):
|
|
output = vmap(lambda x: torch.tensor(3.14))(torch.ones(3))
|
|
self.assertEqual(output, torch.tensor([3.14, 3.14, 3.14]))
|
|
|
|
def test_single_input(self):
|
|
x = torch.randn(2, 3)
|
|
|
|
def square(x):
|
|
return x * x
|
|
|
|
output = vmap(square)(x)
|
|
self.assertEqual(output, x * x)
|
|
|
|
def test_multiple_inputs(self):
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(2, 3)
|
|
output = vmap(torch.mul)(x, y)
|
|
self.assertEqual(output, x * y)
|
|
|
|
def test_multiple_outputs(self):
|
|
def foo(x):
|
|
return x * x, x * x * x
|
|
|
|
x = torch.randn(3)
|
|
outputs = vmap(foo)(x)
|
|
self.assertEqual(outputs[0], x * x)
|
|
self.assertEqual(outputs[1], x * x * x)
|
|
|
|
def test_multiple_outputs2(self):
|
|
# This is the same thing as
|
|
# def returns_tuple_of_tensors(x):
|
|
# return x, x
|
|
def returns_tuple_of_tensors(x):
|
|
return (x, x)
|
|
|
|
def returns_list_of_two_tensors(x):
|
|
return [x, x]
|
|
|
|
def returns_list_of_one_tensor(x):
|
|
return [x]
|
|
|
|
x = torch.randn(3)
|
|
|
|
# should not throw
|
|
vmap(returns_tuple_of_tensors)(x)
|
|
vmap(returns_list_of_two_tensors)(x)
|
|
vmap(returns_list_of_one_tensor)(x)
|
|
|
|
def test_nested_with_same_map_dim(self):
|
|
x = torch.randn(2, 3, 5)
|
|
y = torch.randn(2, 3, 5)
|
|
output = vmap(vmap(torch.mul))(x, y)
|
|
self.assertEqual(output, x * y)
|
|
|
|
output = vmap(vmap(vmap(torch.mul)))(x, y)
|
|
self.assertEqual(output, x * y)
|
|
|
|
def test_nested_with_diag_embed(self):
|
|
# diag_embed requires special testing because it is registered with conditional functionalization.
|
|
x = torch.randn(3, 3, 5)
|
|
output = vmap(vmap(torch.diag_embed))(x)
|
|
self.assertEqual(output, torch.diag_embed(x))
|
|
|
|
def test_nested_with_different_map_dim(self):
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(5, 3)
|
|
output = vmap(lambda x: vmap(lambda y: x * y)(y))(x)
|
|
self.assertEqual(output.shape, (2, 5, 3))
|
|
self.assertEqual(output, x.view(2, 1, 3) * y)
|
|
|
|
z = torch.randn(7, 3)
|
|
output = vmap(lambda x: vmap(lambda y: vmap(lambda z: x * y * z)(z))(y))(x)
|
|
self.assertEqual(output.shape, (2, 5, 7, 3))
|
|
self.assertEqual(output, x.view(2, 1, 1, 3) * y.view(5, 1, 3) * z)
|
|
|
|
def test_noop_in_inner_vmap(self):
|
|
x = torch.randn(3)
|
|
y = torch.randn(5)
|
|
output = vmap(lambda x: vmap(lambda y: x)(y))(x)
|
|
self.assertEqual(output, x.view(3, 1).expand(3, 5))
|
|
|
|
def test_unsupported_op_err_msg(self):
|
|
# Unsupported view op
|
|
tensor = torch.randn(2, 3)
|
|
msg = (
|
|
r"Batching rule not implemented for aten::.+; the "
|
|
r"fallback path doesn't work on out= or view ops"
|
|
)
|
|
# TODO: find a view op
|
|
# with self.assertRaisesRegex(RuntimeError, msg):
|
|
# vmap(torch.ravel)(tensor)
|
|
|
|
def out_op(x, y):
|
|
return torch.abs(x, out=y)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(out_op)(tensor, tensor)
|
|
|
|
# Don't support non-tensor returns. This is a limitation of vmap;
|
|
# functions that don't return tensors must be special cased
|
|
with self.assertRaisesRegex(RuntimeError, 'Batching rule not implemented'):
|
|
vmap(torch.equal)(tensor, tensor)
|
|
|
|
def test_nonzero_out_dims(self):
|
|
# Basic test
|
|
tensor = torch.randn(2, 3)
|
|
result = vmap(lambda x: x, out_dims=1)(tensor)
|
|
self.assertEqual(result, tensor.permute(1, 0))
|
|
self.assertEqual(result.data_ptr(), tensor.data_ptr())
|
|
|
|
# Test that the batch dimension gets permuted to dim 2
|
|
tensor = torch.randn(2, 3, 5, 7)
|
|
result = vmap(lambda x: x, out_dims=2)(tensor)
|
|
self.assertEqual(result, tensor.permute(1, 2, 0, 3))
|
|
self.assertEqual(result.data_ptr(), tensor.data_ptr())
|
|
|
|
# negative out_dim
|
|
tensor = torch.randn(2, 3, 5, 7)
|
|
result = vmap(lambda x: x, out_dims=-1)(tensor)
|
|
self.assertEqual(result, tensor.permute(1, 2, 3, 0))
|
|
self.assertEqual(result.data_ptr(), tensor.data_ptr())
|
|
|
|
# check that out_dims works on ALL outputs
|
|
tensor = torch.randn(2, 3, 5, 7)
|
|
other = torch.randn(2, 3, 5, 7)
|
|
result = vmap(lambda x, y: (x, y), out_dims=2)(tensor, other)
|
|
self.assertEqual(result, (tensor.permute(1, 2, 0, 3), other.permute(1, 2, 0, 3)))
|
|
|
|
# use out_dims with the maximum vmap-able tensor dims (64 dims)
|
|
ndims = 64
|
|
shape = [2] + [1] * (ndims - 1)
|
|
expected_shape = [1, 1, 2] + [1] * (ndims - 3)
|
|
tensor = torch.randn(shape)
|
|
result = vmap(lambda x: x, out_dims=2)(tensor)
|
|
self.assertEqual(result.shape, expected_shape)
|
|
|
|
# test something that is not the identity function
|
|
def foo(x, y):
|
|
return x, x * y, x * y * y
|
|
x = torch.randn(2, 3, 5)
|
|
y = torch.randn(2, 3, 5)
|
|
result = vmap(foo, out_dims=1)(x, y)
|
|
self.assertEqual(
|
|
result,
|
|
(x.permute(1, 0, 2), (x * y).permute(1, 0, 2), (x * y * y).permute(1, 0, 2)))
|
|
|
|
def test_multiple_out_dims(self):
|
|
def foo(x):
|
|
return x, x
|
|
|
|
def bar(x, y):
|
|
return x, x, x, x * y
|
|
|
|
x = torch.randn(2, 3, 5)
|
|
y = torch.randn(2, 3, 5)
|
|
result = vmap(foo, out_dims=(0, 1))(x)
|
|
self.assertEqual(result, (x, x.permute(1, 0, 2)))
|
|
|
|
result = vmap(bar, out_dims=(-1, 0, 1, 2))(x, y)
|
|
expected = (
|
|
x.permute(1, 2, 0),
|
|
x,
|
|
x.permute(1, 0, 2),
|
|
(x * y).permute(1, 2, 0),
|
|
)
|
|
self.assertEqual(result, expected)
|
|
|
|
def test_nested_out_dims(self):
|
|
y = torch.randn(2, 3, 5, 7)
|
|
|
|
# Inner vmap has non-zero out_dim
|
|
result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y))(y)
|
|
self.assertEqual(result.shape, (2, 5, 3, 7))
|
|
self.assertEqual(result, y.permute(0, 2, 1, 3))
|
|
|
|
# all vmaps have non-zero out_dim
|
|
result = vmap(lambda y: vmap(lambda x: x, out_dims=1)(y), out_dims=1)(y)
|
|
self.assertEqual(result.shape, (5, 2, 3, 7))
|
|
self.assertEqual(result, y.permute(2, 0, 1, 3))
|
|
|
|
# throwing in some negative out_dims
|
|
result = vmap(lambda y: vmap(lambda x: x, out_dims=-1)(y), out_dims=-1)(y)
|
|
self.assertEqual(result.shape, (5, 7, 3, 2))
|
|
self.assertEqual(result, y.permute(2, 3, 1, 0))
|
|
|
|
# testing fn that isn't the identity
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(5, 3)
|
|
result = vmap(lambda y: vmap(lambda x: x * y, out_dims=1)(x), out_dims=-1)(y)
|
|
self.assertEqual(result.shape, (3, 2, 5))
|
|
self.assertEqual(result, (y.view(5, 1, 3) * x).permute(2, 1, 0))
|
|
|
|
def test_out_dims_edge_case(self):
|
|
def foo(x):
|
|
return x
|
|
|
|
# Test that we accept out_dims=(1,) for a function with one output.
|
|
tensor = torch.randn(2, 3)
|
|
expected = vmap(foo, out_dims=1)(tensor)
|
|
result = vmap(foo, out_dims=(1,))(tensor)
|
|
self.assertEqual(result, expected)
|
|
|
|
def test_pytree_returns(self):
|
|
x = torch.randn(2, 3)
|
|
|
|
def f(x):
|
|
y = x.sin()
|
|
return y, (y, y), [y, (y, y)]
|
|
|
|
y0, (y1, y2), (y3, (y4, y5)) = vmap(f)(x)
|
|
self.assertEqual(y0, x.sin())
|
|
self.assertEqual(y0, y1)
|
|
self.assertEqual(y2, y1)
|
|
self.assertEqual(y2, y3)
|
|
self.assertEqual(y4, y3)
|
|
self.assertEqual(y5, y4)
|
|
|
|
def test_pytree_odict_returns(self):
|
|
x = torch.randn(2, 3)
|
|
|
|
def f(t):
|
|
y = t.sin()
|
|
return OrderedDict([("sin", y), ("cos", t.cos())])
|
|
|
|
out = vmap(f)(x)
|
|
assert isinstance(out, OrderedDict)
|
|
expected = f(x)
|
|
self.assertEqual(out["sin"], expected["sin"])
|
|
self.assertEqual(out["cos"], expected["cos"])
|
|
|
|
def test_pytree_returns_outdims(self):
|
|
x = torch.randn(2, 3)
|
|
|
|
def f(x):
|
|
y = x.sin()
|
|
return y, (y, y)
|
|
|
|
y0, (y1, y2) = vmap(f, out_dims=(0, (0, 1)))(x)
|
|
self.assertEqual(y0, x.sin())
|
|
self.assertEqual(y1, x.sin())
|
|
self.assertEqual(y2, x.sin().t())
|
|
|
|
def test_pytree_returns_broadcast_simple(self):
|
|
x = torch.randn(2, 3)
|
|
|
|
def f(x):
|
|
y = x.sin()
|
|
return y, (y, y)
|
|
|
|
y0, (y1, y2) = vmap(f, out_dims=1)(x)
|
|
self.assertEqual(y0, x.sin().t())
|
|
self.assertEqual(y1, y0)
|
|
self.assertEqual(y2, y0)
|
|
|
|
def test_pytree_returns_broadcast_nested(self):
|
|
x = torch.randn(2, 3)
|
|
|
|
def f(x):
|
|
y = x.sin()
|
|
return y, (y, y)
|
|
|
|
y0, (y1, y2) = vmap(f, out_dims=(0, 1))(x)
|
|
self.assertEqual(y0, x.sin())
|
|
self.assertEqual(y1, y0.t())
|
|
self.assertEqual(y2, y0.t())
|
|
|
|
def test_out_dims_must_be_int_or_collection_of_int_err_msg(self):
|
|
msg = 'must be an int or a python collection of ints'
|
|
tensor = torch.randn(2, 3)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda x: x, out_dims='lol')(tensor)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda x: x, out_dims=('lol',))(tensor)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda x: x, out_dims=None)(tensor)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda x: x, out_dims=(None,))(tensor)
|
|
|
|
def test_out_dims_and_num_outputs_mismatch_err_msg(self):
|
|
msg = 'not compatible'
|
|
x = torch.randn(2, 3, 5)
|
|
|
|
# Too many out_dims
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda x: x, out_dims=(0, 0))(x)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda x: (x, x, x), out_dims=(0, 0, 0, 0))(x)
|
|
|
|
# Too few out_dims
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda x: (x, x), out_dims=(0,))(x)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda x: (x, x, x), out_dims=(0, 0))(x)
|
|
|
|
def test_out_dim_out_of_bounds_err_msg(self):
|
|
# TODO(rzou): This error message isn't that great. It comes straight
|
|
# from maybe_wrap_dim. Consider doing a try-catch-(add some context) to
|
|
# the error message in the future in C++
|
|
msg = 'Dimension out of range'
|
|
x = torch.randn(2, 3, 5)
|
|
with self.assertRaisesRegex(IndexError, msg):
|
|
vmap(lambda x: x, out_dims=3)(x)
|
|
with self.assertRaisesRegex(IndexError, msg):
|
|
vmap(lambda x: x, out_dims=-4)(x)
|
|
|
|
def test_non_zero_in_dims(self):
|
|
tensor = torch.randn(2, 3, 5)
|
|
|
|
# Implicit out_dims = 0; vmap will move the batch dim to the front.
|
|
output = vmap(lambda x: x, (1,))(tensor)
|
|
self.assertEqual(output, tensor.permute(1, 0, 2))
|
|
self.assertEqual(output.data_ptr(), tensor.data_ptr())
|
|
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(3, 2)
|
|
output = vmap(torch.mul, (0, 1))(x, y)
|
|
self.assertEqual(output, x * y.t())
|
|
output = vmap(torch.mul, (1, 0))(x, y)
|
|
self.assertEqual(output, x.t() * y)
|
|
|
|
def test_none_in_dims(self):
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(2, 3)
|
|
|
|
# None in_dim for a Tensor means we don't map over it
|
|
output = vmap(torch.mul, (0, None))(x, y)
|
|
self.assertEqual(output.shape, (2, 2, 3))
|
|
self.assertEqual(output, x.view(2, 1, 3) * y)
|
|
|
|
# None in_dim for non-tensor arguments
|
|
output = vmap(torch.mul, (0, None))(x, 2)
|
|
self.assertEqual(output, x * 2)
|
|
|
|
def test_nested_non_default_in_dims(self):
|
|
x = torch.rand(5, 2, 3)
|
|
y = torch.rand(3, 5, 2)
|
|
result = vmap(vmap(vmap(torch.mul), (1, 0)), (1, 2))(x, y)
|
|
self.assertEqual(result, x.permute(1, 2, 0) * y.permute(2, 0, 1))
|
|
|
|
def test_nested_negative_in_dims(self):
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(2, 3)
|
|
output = vmap(torch.mul, (-1, -1))(x, y)
|
|
self.assertEqual(output.shape, (3, 2))
|
|
self.assertEqual(output, (x * y).permute(1, 0))
|
|
|
|
def test_non_default_in_dims_out_dims(self):
|
|
x = torch.randn(2, 3, 5)
|
|
|
|
# Same in_dim as out_dim, vmap over identity
|
|
result = vmap(lambda x: x, in_dims=1, out_dims=1)(x)
|
|
self.assertEqual(result, x)
|
|
self.assertEqual(result.data_ptr(), x.data_ptr())
|
|
|
|
# Different in_dim from out_dim, vmap over identity
|
|
result = vmap(lambda x: x, in_dims=2, out_dims=1)(x)
|
|
self.assertEqual(result.shape, (2, 5, 3))
|
|
self.assertEqual(result, x.transpose(1, 2))
|
|
self.assertEqual(result.data_ptr(), x.data_ptr())
|
|
|
|
def foo(x):
|
|
return x * 2
|
|
|
|
# Same in_dim as out_dim, vmap over operation
|
|
result = vmap(foo, in_dims=1, out_dims=1)(x)
|
|
self.assertEqual(result, x * 2)
|
|
|
|
# Different in_dim as out_dim, vmap over operation
|
|
result = vmap(foo, in_dims=2, out_dims=1)(x)
|
|
self.assertEqual(result.shape, (2, 5, 3))
|
|
self.assertEqual(result, (x * 2).transpose(1, 2))
|
|
|
|
# Basic nested test.
|
|
result = vmap(vmap(foo, 1, 1), 1, 1)(x)
|
|
self.assertEqual(result, x * 2)
|
|
|
|
def test_item_throws(self):
|
|
def f(x):
|
|
return x.item()
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r'item\(\) on a Tensor'):
|
|
vmap(f)(torch.randn(3))
|
|
|
|
def test_data_dependent_control_flow_throws(self):
|
|
def f(x):
|
|
if x:
|
|
return x
|
|
return 0
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r'data-dependent control flow'):
|
|
vmap(f)(torch.randn(3))
|
|
|
|
def test_accepts_nested_inputs(self):
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(2, 3)
|
|
|
|
# Single layer of nesting
|
|
out = vmap(lambda z: z[0] + z[1])((x, y))
|
|
self.assertEqual(out, x + y)
|
|
out = vmap(lambda z: z[0] + z[1], in_dims=(0,))((x, y))
|
|
self.assertEqual(out, x + y)
|
|
out = vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))((x, y))
|
|
self.assertEqual(out, x + y)
|
|
|
|
out = vmap(lambda z: z[0] + z[1])([x, y])
|
|
self.assertEqual(out, x + y)
|
|
out = vmap(lambda z: z[0] + z[1], in_dims=(0,))([x, y])
|
|
self.assertEqual(out, x + y)
|
|
out = vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, y])
|
|
self.assertEqual(out, x + y)
|
|
|
|
out = vmap(lambda z: z['x'] + z['y'])({'x': x, 'y': y})
|
|
self.assertEqual(out, x + y)
|
|
out = vmap(lambda z: z['x'] + z['y'], in_dims=(0,))({'x': x, 'y': y})
|
|
self.assertEqual(out, x + y)
|
|
out = vmap(lambda z: z['x'] + z['y'], in_dims=({'x': 0, 'y': 0},))({'x': x, 'y': y})
|
|
self.assertEqual(out, x + y)
|
|
|
|
# Multiple layers of nesting
|
|
out_fn = vmap(lambda z: z['x'][0] + z['x'][1][0] + z['y'][0] + z['y'][1])
|
|
out = out_fn({'x': [x, (x,)], 'y': [y, y]})
|
|
self.assertEqual(out, x + x + y + y)
|
|
|
|
def test_in_dims_wrong_type_err_msg(self):
|
|
x = torch.randn(3)
|
|
y = torch.randn(3)
|
|
msg = r'expected `in_dims` to be int or a \(potentially nested\) tuple'
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(torch.mul, [0, 0])(x, y)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(torch.mul, set({0, 0}))(x, y)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(torch.mul, 'lol')(x, y)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda z: z[0] + z[1], in_dims=[0, 0])([x, y])
|
|
# The following should not throw
|
|
vmap(torch.mul, (0, 0))(x, y)
|
|
|
|
def test_not_enough_in_dims_err_msg(self):
|
|
x = torch.randn(3)
|
|
y = torch.randn(3)
|
|
msg = r'in_dims is not compatible with the structure of `inputs`'
|
|
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(torch.mul, (0,))(x, y)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(torch.mul, (0, 0, 0))(x, y)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda z: z[0] + z[1], in_dims=([0],))([x, y])
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda z: z[0] + z[1], in_dims=((0, 0),))([x, y])
|
|
# The following should not throw
|
|
vmap(torch.mul, (0, 0))(x, y)
|
|
|
|
def test_integer_in_dim_but_not_tensor_input_err_msg(self):
|
|
def foo(xy):
|
|
return xy[0] * xy[1]
|
|
|
|
def bar(x, yz):
|
|
return x * yz[0] * yz[1]
|
|
|
|
x = torch.randn(2, 3)
|
|
|
|
# the following are errors in jax (and will always be errors)
|
|
msg = 'Got in_dim=0 for an input but the input is of type'
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(torch.sum)(x, 0)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(torch.sum, (0, 0))(x, 0)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda z: z[0] + z[1], in_dims=([0, 0],))([x, 1])
|
|
# The following should not throw
|
|
vmap(torch.sum, (0, None))(x, 0)
|
|
|
|
def test_in_dim_not_in_tensor_err_msg(self):
|
|
def foo(x):
|
|
return x * x
|
|
|
|
x = torch.randn(2, 3)
|
|
y = torch.randn(2, 3)
|
|
|
|
msg = r'Got in_dim=-?\w for some input, but that input is a Tensor of dimensionality \w'
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(foo)(torch.randn([]))
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(foo, in_dims=(0,))(torch.randn([]))
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(foo, in_dims=(-3,))(x)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(foo, in_dims=(2,))(y)
|
|
with self.assertRaisesRegex(ValueError, msg):
|
|
vmap(lambda z: z[0] + z[1], in_dims=([3, 0],))([x, y])
|
|
# the following should not throw
|
|
vmap(foo, in_dims=(0,))(torch.randn(2, 3))
|
|
vmap(foo, in_dims=(1,))(torch.randn(2, 3))
|
|
|
|
def test_fallback_does_not_warn_by_default(self):
|
|
# NB: One day we will implement a batching rule for torch.atan2.
|
|
# If/when we do, this test should be replaced to test the fallback
|
|
# path on another operator to avoid bitrot.
|
|
op = torch.copysign
|
|
x = torch.randn(11)
|
|
y = torch.randn(11)
|
|
with warnings.catch_warnings(record=True) as wa:
|
|
vmap(op)(x, y)
|
|
# The single warning here is the "vmap is experimental"
|
|
# warning, not a warning from the vmap fallback path.
|
|
self.assertEqual(len(wa), 1)
|
|
|
|
@unittest.expectedFailure
|
|
def test_fallback_warns_when_warnings_are_enabled(self):
|
|
# NB: One day we will implement a batching rule for torch.atan2.
|
|
# If/when we do, this test should be replaced to test the fallback
|
|
# path on another operator to avoid bitrot.
|
|
op = torch.copysign
|
|
x = torch.randn(11)
|
|
y = torch.randn(11)
|
|
with warnings.catch_warnings(record=True) as wa:
|
|
with EnableVmapFallbackWarnings():
|
|
vmap(op)(x, y)
|
|
self.assertEqual(len(wa), 2)
|
|
self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)
|
|
|
|
def _assert_uses_vmap_fallback(self, vmap_args, inputs):
|
|
return
|
|
# with warnings.catch_warnings(record=True) as wa:
|
|
# with EnableVmapFallbackWarnings():
|
|
# result = vmap(*vmap_args)(*inputs)
|
|
# self.assertEqual(len(wa), 2)
|
|
# self.assertRegex(str(wa[-1].message), FALLBACK_REGEX)
|
|
|
|
def test_fallback_zero_dim(self):
|
|
# NB: One day we will implement a batching rule for torch.atan2.
|
|
# If/when we do, this test should be replaced to test the fallback
|
|
# path on another operator to avoid bitrot.
|
|
op = torch.copysign
|
|
x = torch.randn(11)
|
|
y = torch.randn(11)
|
|
self._assert_uses_vmap_fallback((op,), (x, y))
|
|
|
|
B0, B1 = 0, 3
|
|
x = torch.randn(B0, 11)
|
|
y = torch.randn(11)
|
|
|
|
msg = 'The fallback path does not support vmap over dims of size 0'
|
|
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, (0, None))(x, y)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, (None, 0))(y, x)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op)(x, x)
|
|
|
|
x = torch.randn(B0, B1, 11)
|
|
y = torch.randn(B1, 11)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, (0, None))(x, y)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, (None, 0))(y, x)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op)(x, x)
|
|
|
|
def test_fallback_atan2(self):
|
|
# NB: One day we will implement a batching rule for torch.atan2.
|
|
# If/when we do, this test should be replaced to test the fallback
|
|
# path on another operator to avoid bitrot.
|
|
op = torch.copysign
|
|
|
|
x = torch.randn(5, 7, 11)
|
|
y = torch.randn(5, 7, 11)
|
|
|
|
self._assert_uses_vmap_fallback((op,), (x, y))
|
|
|
|
# fallback on torch.atan2
|
|
x = torch.randn(7, 11, 5)
|
|
y = torch.randn(5, 7, 11)
|
|
result = vmap(op, (2, 0))(x, y)
|
|
self.assertEqual(result, op(x.permute(2, 0, 1), y))
|
|
|
|
# fallback on torch.atan2, nested vmap
|
|
x = torch.randn(7, 11, 5)
|
|
y = torch.randn(5, 7, 11)
|
|
result = vmap(vmap(op), (2, 0))(x, y)
|
|
self.assertEqual(result, op(x.permute(2, 0, 1), y))
|
|
|
|
# big batch size (total 10000)
|
|
x = torch.randn(100, 10, 10, 5)
|
|
y = torch.randn(100, 10, 10)
|
|
result = vmap(vmap(vmap(op)))(x, y)
|
|
self.assertEqual(result, op(x, y.view(100, 10, 10, 1)))
|
|
|
|
# TODO: No clue what is wrong here.
|
|
@unittest.skip
|
|
def test_fallback_masked_fill(self):
|
|
# NB: One day we will implement a batching rule for masked_fill
|
|
# If/when we do, this test should be replaced to test the fallback
|
|
# path on another operator to avoid bitrot.
|
|
def run_test(batch_size):
|
|
B0 = batch_size
|
|
x = torch.randn(B0, 7, 11, 13)
|
|
dim = 0
|
|
index = torch.tensor([0, 4, 2])
|
|
values = torch.randn(B0, 3, 13)
|
|
|
|
self._assert_uses_vmap_fallback((torch.index_add, (0, None, None, 0)), (x, dim, index, values))
|
|
|
|
result = vmap(torch.index_add, (0, None, None, 0))(x, dim, index, values)
|
|
expected = torch.index_add(
|
|
x, dim + 1, index, values.view(B0, 3, 1, 13))
|
|
self.assertEqual(result, expected)
|
|
|
|
run_test(batch_size=5)
|
|
run_test(batch_size=1237)
|
|
|
|
def test_fallback_multiple_returns(self):
|
|
# NB: One day we will implement a batching rule for torch.var_mean
|
|
# If/when we do, this test should be replaced to test the fallback
|
|
# path on another operator to avoid bitrot.
|
|
B0, B1, B2 = 2, 3, 1237
|
|
tensor = torch.randn(B0, 10)
|
|
|
|
self._assert_uses_vmap_fallback((torch.var_mean,), (tensor,))
|
|
|
|
# fallback correctness on torch.var_mean
|
|
result = vmap(torch.var_mean)(tensor)
|
|
expected = torch.var_mean(tensor, dim=1)
|
|
self.assertEqual(result, expected)
|
|
|
|
# nested vmap
|
|
tensor = torch.randn(B0, B1, 10)
|
|
result = vmap(vmap(torch.var_mean))(tensor)
|
|
expected = torch.var_mean(tensor, dim=2)
|
|
self.assertEqual(result, expected)
|
|
|
|
# big batch size, nested vmap
|
|
tensor = torch.randn(B0, B1, B2, 10)
|
|
result = vmap(vmap(vmap(torch.var_mean)))(tensor)
|
|
expected = torch.var_mean(tensor, dim=3)
|
|
self.assertEqual(result, expected)
|
|
|
|
def test_inplace_fallback_unary(self):
|
|
# Test the in-place fallback on an in-place method that takes no
|
|
# additional Tensor arguments. This is the simplest case of the fallback.
|
|
# NB: One day we will implement a batching rule for acos_.
|
|
# If/when we do, this test should be replaced to test the fallback
|
|
# path on another operator to avoid bitrot.
|
|
op = Tensor.acos_
|
|
B0, B1, B2 = 2, 3, 10000
|
|
|
|
x = torch.randn(B0, 5)
|
|
self._assert_uses_vmap_fallback((op,), (x,))
|
|
|
|
# Single vmap
|
|
x_orig = torch.rand(B0, 5)
|
|
x = x_orig.clone()
|
|
result = vmap(op)(x)
|
|
self.assertTrue(result is x)
|
|
self.assertEqual(result, x_orig.acos())
|
|
|
|
# Single vmap + different out_dim produces a view(!)
|
|
x_orig = torch.rand(B0, 5)
|
|
x = x_orig.clone()
|
|
result = vmap(op, out_dims=(1,))(x)
|
|
self.assertTrue(result._base is x)
|
|
self.assertEqual(result, x_orig.t().acos())
|
|
|
|
# Nested vmap
|
|
x_orig = torch.randn(B0, B1, 5)
|
|
x = x_orig.clone()
|
|
result = vmap(vmap(op))(x)
|
|
self.assertTrue(result is x)
|
|
self.assertEqual(result, x_orig.acos())
|
|
|
|
# Nested vmap, large batch size
|
|
x_orig = torch.randn(B0, B1, B2, 5)
|
|
x = x_orig.clone()
|
|
result = vmap(vmap(vmap(op)))(x)
|
|
self.assertTrue(result is x)
|
|
self.assertEqual(result, x_orig.acos())
|
|
|
|
def test_inplace_fallback_nary_same_levels(self):
|
|
# NB: One day we will implement a batching rule for atan2_
|
|
# If/when we do, this test should be replaced to test the fallback
|
|
# path on another operator to avoid bitrot.
|
|
op = Tensor.atan2_
|
|
outplace_op = torch.atan2
|
|
|
|
x = torch.randn(5, 7, 11)
|
|
y = torch.randn(5, 7, 11)
|
|
self._assert_uses_vmap_fallback((op,), (x, y))
|
|
|
|
# Single vmap
|
|
B0 = 5
|
|
x_orig = torch.randn(7, 11, B0)
|
|
x = x_orig.clone()
|
|
y = torch.randn(B0, 7, 11)
|
|
vmap(op, (2, 0))(x, y)
|
|
self.assertEqual(x, outplace_op(x_orig, y.movedim(0, 2)))
|
|
|
|
# Nested vmap
|
|
B0, B1 = 5, 7
|
|
x_orig = torch.randn(B1, 11, B0)
|
|
x = x_orig.clone()
|
|
y = torch.randn(B0, B1, 11)
|
|
vmap(vmap(op), (2, 0))(x, y)
|
|
self.assertEqual(x, outplace_op(x_orig, y.movedim([0, 1], [2, 0])))
|
|
|
|
# big batch size (total 10000)
|
|
B0, B1, B2 = 100, 10, 10
|
|
x_orig = torch.randn(B0, B1, B2, 5)
|
|
x = x_orig.clone()
|
|
y = torch.randn(B0, B1, B2)
|
|
vmap(vmap(vmap(op)))(x, y)
|
|
self.assertEqual(x, outplace_op(x_orig, y.view(B0, B1, B2, 1)))
|
|
|
|
# ("Fallback isInplaceVmapCompatible check is broken")
|
|
@unittest.expectedFailure
|
|
def test_inplace_fallback_nary_different_levels(self):
|
|
# NB: One day we will implement a batching rule for atan2_
|
|
# If/when we do, this test should be replaced to test the fallback
|
|
# path on another operator to avoid bitrot.
|
|
op = Tensor.atan2_
|
|
outplace_op = torch.atan2
|
|
B0, B1 = 2, 3
|
|
|
|
x = torch.rand(B0, 7)
|
|
y = torch.rand(7)
|
|
self._assert_uses_vmap_fallback((op, (0, None)), (x, y))
|
|
|
|
# op(left, right): All of the levels in right are found in left
|
|
x_orig = torch.rand(B0, 7)
|
|
x = x_orig.clone()
|
|
y = torch.rand(7)
|
|
vmap(op, in_dims=(0, None))(x, y)
|
|
self.assertEqual(x, outplace_op(x_orig, y))
|
|
|
|
x_orig = torch.rand(B0, B1, 7)
|
|
x = x_orig.clone()
|
|
y = torch.rand(B0, 7)
|
|
vmap(vmap(op, in_dims=(0, None)))(x, y)
|
|
self.assertEqual(x, outplace_op(x_orig, y.view(B0, 1, 7)))
|
|
|
|
# op(left, right): Some of the levels in right are not found in left
|
|
msg = r'vmap: aten::atan2_\(self, \*extra_args\) is not possible'
|
|
x = torch.rand(7)
|
|
y = torch.rand(B0, 7)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=(None, 0))(x, y)
|
|
|
|
x = torch.rand(B1, 7)
|
|
y = torch.rand(B0, 7)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 0))(x, y)
|
|
|
|
x = torch.rand(B1, 7)
|
|
y = torch.rand(7, B0)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(vmap(op, in_dims=(0, None)), in_dims=(None, 1))(x, y)
|
|
|
|
x = torch.rand(B0, 7)
|
|
y = torch.rand(B0, B1, 7)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(vmap(op, in_dims=(None, 0)))(x, y)
|
|
|
|
def test_backward_unsupported_interaction(self):
|
|
x = torch.randn(3, requires_grad=True)
|
|
y = torch.randn(5)
|
|
grad = torch.randn_like(x)
|
|
err_msg = r'backward\(\) called inside a functorch transform'
|
|
|
|
def backward_on_vmapped_tensor(x):
|
|
x.sum().backward()
|
|
|
|
# FIXME
|
|
return self.skipTest("error: element 0 of tensors does not require grad and does not have a grad_fn")
|
|
with self.assertRaisesRegex(RuntimeError, err_msg):
|
|
vmap(backward_on_vmapped_tensor)(x)
|
|
|
|
def backward_with_vmapped_grad(x, grad):
|
|
x.backward(grad)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, err_msg):
|
|
vmap(backward_with_vmapped_grad)(x, grad)
|
|
|
|
def completely_unrelated_backward(y):
|
|
x.sum().backward()
|
|
return y
|
|
|
|
with self.assertRaisesRegex(RuntimeError, err_msg):
|
|
vmap(completely_unrelated_backward)(y)
|
|
|
|
@unittest.expectedFailure
|
|
def test_grad_unsupported_interaction(self):
|
|
input_tensor = torch.randn(3, requires_grad=True)
|
|
err_msg = 'autograd.grad.* called inside torch.vmap'
|
|
|
|
captured = torch.randn(3, requires_grad=True)
|
|
|
|
def output_to_grad_is_vmapped(input_tensor):
|
|
output = (captured * input_tensor).sum()
|
|
return torch.autograd.grad([output], [captured])[0]
|
|
|
|
with self.assertRaisesRegex(RuntimeError, err_msg):
|
|
vmap(output_to_grad_is_vmapped)(input_tensor)
|
|
|
|
output = (input_tensor ** 2).sum()
|
|
|
|
def input_to_grad_is_vmapped(input_tensor):
|
|
return torch.autograd.grad([output], [input_tensor])[0]
|
|
|
|
with self.assertRaisesRegex(RuntimeError, err_msg):
|
|
vmap(input_to_grad_is_vmapped)(input_tensor)
|
|
|
|
def test_batched_gradient_basic(self):
|
|
N = 3
|
|
x = torch.randn(N, requires_grad=True)
|
|
y = torch.randn(N)
|
|
|
|
def vjp_mul(v):
|
|
return torch.autograd.grad([x * y], [x], grad_outputs=[v])[0]
|
|
|
|
batched_v = torch.eye(N)
|
|
jacobian = vmap(vjp_mul)(batched_v)
|
|
self.assertEqual(jacobian, torch.diagflat(y))
|
|
|
|
def test_functools_partial(self):
|
|
x = torch.randn(3)
|
|
y = torch.randn(2, 3)
|
|
result = vmap(functools.partial(torch.mul, x))(y)
|
|
self.assertEqual(result, x * y)
|
|
|
|
def test_nn_module(self):
|
|
tensor = torch.randn(2, 3)
|
|
model = torch.nn.Linear(3, 3, bias=False)
|
|
result = vmap(model)(tensor)
|
|
self.assertEqual(result, model(tensor))
|
|
|
|
def test_fallback_with_undefined_grad(self):
|
|
B0 = 7
|
|
x = torch.randn(2, 3, 4, 5, requires_grad=True)
|
|
weight = torch.randn(3, 3, 1, 1)
|
|
v = torch.randn(B0, 2, 3, 4, 5)
|
|
|
|
def get_vjp(v):
|
|
result = torch.nn.functional.conv2d(x, weight)
|
|
grad_x, = torch.autograd.grad(result, x, v)
|
|
return grad_x
|
|
|
|
# Runs vmap(get_vjp)(v), which should not error out.
|
|
# The backward formula for convolution returns an undefined
|
|
# Tensor for grad_bias because the original bias does not exist.
|
|
#
|
|
# In the future we'll probably add a batching rule for convolution
|
|
# backward. When this happens, we should modify this test to use a
|
|
# different op (and/or create and use a dummy operator) to avoid bitrot.
|
|
self._assert_uses_vmap_fallback([get_vjp], [v])
|
|
|
|
def test_reshape_dim_into(self):
|
|
x = torch.randn(2, 3, 5, 7)
|
|
|
|
y = reshape_dim_into(0, 0, x)
|
|
self.assertEqual(y, x.reshape(6, 5, 7))
|
|
|
|
y = reshape_dim_into(0, 1, x)
|
|
self.assertEqual(y, x.movedim(0, 1).reshape(3, 2 * 5, 7))
|
|
|
|
y = reshape_dim_into(0, 2, x)
|
|
self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7))
|
|
|
|
y = reshape_dim_into(1, 2, x)
|
|
self.assertEqual(y, x.movedim(1, 2).reshape(2, 5, 3 * 7))
|
|
|
|
y = reshape_dim_into(0, -2, x)
|
|
self.assertEqual(y, x.movedim(0, 1).reshape(3, 2 * 5, 7))
|
|
|
|
y = reshape_dim_into(0, -1, x)
|
|
self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7))
|
|
|
|
y = reshape_dim_into(-4, -1, x)
|
|
self.assertEqual(y, x.movedim(0, 2).reshape(3, 5, 2 * 7))
|
|
|
|
def test_reshape_dim_outof(self):
|
|
x = torch.randn(12, 12, 12).permute(2, 1, 0)
|
|
|
|
y = reshape_dim_outof(0, 2, x)
|
|
self.assertEqual(y, x.reshape(2, 6, 12, 12))
|
|
|
|
y = reshape_dim_outof(1, 4, x)
|
|
self.assertEqual(y, x.reshape(12, 4, 3, 12))
|
|
|
|
y = reshape_dim_outof(2, 6, x)
|
|
self.assertEqual(y, x.reshape(12, 12, 6, 2))
|
|
|
|
y = reshape_dim_outof(-1, 6, x)
|
|
self.assertEqual(y, x.reshape(12, 12, 6, 2))
|
|
|
|
def test_batch_rule_does_not_need_to_handle_no_batched_input(self):
|
|
def f(x, y):
|
|
res = torch.dot(y, torch.ones(2))
|
|
return x + res
|
|
|
|
x = torch.randn(7, 5)
|
|
y = torch.randn(3, 2)
|
|
out = vmap(vmap(f, in_dims=(0, None)), in_dims=(None, 0))(x, y)
|
|
expected = torch.mv(y, torch.ones(2)).view(3, 1, 1) + x
|
|
self.assertEqual(out, expected)
|
|
|
|
def _test_vmap_autocast(self, device):
|
|
|
|
if torch.device(device).type == "cpu":
|
|
amp_dtype = torch.bfloat16
|
|
else:
|
|
amp_dtype = torch.float16
|
|
|
|
a_float32 = torch.rand(4, 2, 3, device=device)
|
|
b_float32 = torch.rand(4, 3, 2, device=device)
|
|
c_float32 = torch.rand(4, 2, 2, device=device)
|
|
d_float32 = torch.rand(4, 3, 2, device=device)
|
|
|
|
# Case 1, autocast inside vmapped function
|
|
def func1(x, y, z, w):
|
|
with torch.autocast(dtype=amp_dtype, device_type=device):
|
|
e_float16 = torch.matmul(x, y)
|
|
assert e_float16.dtype == amp_dtype, e_float16.dtype
|
|
f_float16 = torch.matmul(z, e_float16)
|
|
assert f_float16.dtype == amp_dtype, f_float16.dtype
|
|
return torch.matmul(w, f_float16.float())
|
|
|
|
expected = func1(a_float32, b_float32, c_float32, d_float32)
|
|
out = vmap(func1)(a_float32, b_float32, c_float32, d_float32)
|
|
assert expected.allclose(out)
|
|
|
|
# Case 2, autocast decorator inside vmapped function
|
|
@torch.autocast(dtype=amp_dtype, device_type=device)
|
|
def func2(x, y, z, w):
|
|
e_float16 = torch.matmul(x, y)
|
|
assert e_float16.dtype == amp_dtype, e_float16.dtype
|
|
f_float16 = torch.matmul(z, e_float16)
|
|
assert f_float16.dtype == amp_dtype, f_float16.dtype
|
|
return torch.matmul(w, f_float16)
|
|
|
|
expected = func2(a_float32, b_float32, c_float32, d_float32)
|
|
out = vmap(func2)(a_float32, b_float32, c_float32, d_float32)
|
|
assert expected.allclose(out)
|
|
|
|
# Case 3, autocast is outside vmapped function
|
|
def func3(x, y, z, w):
|
|
e_float16 = torch.matmul(x, y)
|
|
assert e_float16.dtype == amp_dtype, e_float16.dtype
|
|
f_float16 = torch.matmul(z, e_float16)
|
|
assert f_float16.dtype == amp_dtype, f_float16.dtype
|
|
return torch.matmul(w, f_float16)
|
|
|
|
with torch.autocast(dtype=amp_dtype, device_type=device):
|
|
expected = func3(a_float32, b_float32, c_float32, d_float32)
|
|
out = vmap(func3)(a_float32, b_float32, c_float32, d_float32)
|
|
|
|
assert expected.allclose(out)
|
|
|
|
@unittest.skip("Somehow, vmap and autocast do not work on CPU")
|
|
def test_vmap_autocast_cpu(self):
|
|
self._test_vmap_autocast("cpu")
|
|
|
|
@skipIf(not torch.cuda.is_available(), "CUDA is unavailable")
|
|
def test_vmap_autocast_cuda(self):
|
|
self._test_vmap_autocast("cuda")
|
|
|
|
|
|
def slice_inputs(inputs, bdims, i):
|
|
result = []
|
|
for inp, bdim in zip(inputs, bdims):
|
|
if bdim is None:
|
|
result.append(inp)
|
|
else:
|
|
result.append(inp.select(bdim, i))
|
|
return tuple(result)
|
|
|
|
|
|
def reference_vmap(op, inputs, in_dims=0, out_dims=0):
|
|
if isinstance(in_dims, int):
|
|
in_dims = (in_dims,) * len(inputs)
|
|
bdim_sizes = [inp.size(dim) for inp, dim in zip(inputs, in_dims) if dim is not None]
|
|
assert all(bdim_size == bdim_sizes[0] for bdim_size in bdim_sizes)
|
|
bdim_size = bdim_sizes[0]
|
|
results = tuple(op(*slice_inputs(inputs, in_dims, i)) for i in range(bdim_size))
|
|
|
|
assert len(results) > 0
|
|
op_has_single_return = not isinstance(results[0], tuple)
|
|
if op_has_single_return:
|
|
assert all(isinstance(result, torch.Tensor) for result in results)
|
|
if isinstance(out_dims, int):
|
|
out_dims = (out_dims,) * 1
|
|
return torch.stack(results, dim=out_dims[0])
|
|
|
|
assert all(isinstance(result, tuple) for result in results)
|
|
num_returns = len(results[0])
|
|
assert all(len(result) == num_returns for result in results)
|
|
if isinstance(out_dims, int):
|
|
out_dims = (out_dims,) * num_returns
|
|
return tuple(torch.stack(result_shards, out_dim)
|
|
for result_shards, out_dim in zip(zip(*results), out_dims))
|
|
|
|
|
|
class TensorFactory:
|
|
@staticmethod
|
|
def rand(size, device='cpu', dtype=torch.float):
|
|
return torch.rand(size, device=device, dtype=dtype)
|
|
|
|
@staticmethod
|
|
def randn(size, device='cpu', dtype=torch.float):
|
|
return torch.randn(size, device=device, dtype=dtype)
|
|
|
|
@staticmethod
|
|
def randp1(size, device='cpu', dtype=torch.float):
|
|
return torch.rand(size, device=device, dtype=dtype) + 1
|
|
|
|
# Tests vmap(op, in_dims, out_dims)(*inputs) by comparing the output to a
|
|
# (slow) sequential map+stack fallback.
|
|
#
|
|
# check_view: Test if the first returned output is a view of the first input
|
|
# check_propagates_grad: Test if the operation propagates gradients.
|
|
|
|
|
|
def _vmap_test(self, op, inputs, in_dims=0, out_dims=0,
|
|
check_view=False, check_propagates_grad=True):
|
|
result = vmap(op, in_dims, out_dims)(*inputs)
|
|
reference_result = reference_vmap(op, inputs, in_dims, out_dims)
|
|
self.assertEqual(result, reference_result)
|
|
op_has_single_return = not isinstance(result, tuple)
|
|
|
|
if check_view:
|
|
result_as_tuple = (result,) if op_has_single_return else result
|
|
for output in result_as_tuple:
|
|
input0_base = inputs[0] if inputs[0]._base is None else inputs[0]._base
|
|
self.assertTrue(output._base is input0_base,
|
|
msg="result was not a view of the first input!")
|
|
|
|
if not check_propagates_grad:
|
|
return
|
|
# Assuming input[0] is a floating-point tensor. Check if the vmap
|
|
# operation propagates the requires_grad flag to the zeroth output.
|
|
# Some vmap operators are implemented in a way that assumes that
|
|
# they are composite with respect to autograd. If the operator ever is
|
|
# changed to not be composite with respect to autograd, then the
|
|
# following check should fail.
|
|
inputs_clone = list(inputs)
|
|
inputs_clone[0] = inputs[0].clone().requires_grad_()
|
|
result = vmap(op, in_dims, out_dims)(*inputs_clone)
|
|
result_as_tuple = (result,) if op_has_single_return else result
|
|
self.assertTrue(result[0].requires_grad)
|
|
|
|
|
|
def should_allow_vmap_fallback_usage(fn):
|
|
return getattr(fn, '_allow_vmap_fallback_usage', False)
|
|
|
|
|
|
def allowVmapFallbackUsage(fn):
|
|
fn._allow_vmap_fallback_usage = True
|
|
return fn
|
|
|
|
# All tests of TestVmapBase check that the slow vmap fallback is never invoked.
|
|
# This is so that we can incrementally add batching rules for operators to
|
|
# replace the slow vmap fallback path for said operators. To skip this check,
|
|
# please use the allowVmapFallbackUsage decorator.
|
|
#
|
|
# NB: Don't add tests to TestVmapBase directly, unless you want them to run
|
|
# on every subclass of TestVmapBase. Add them to e.g. TestVmapOperators.
|
|
#
|
|
# NB: TestVmapBase is a nested class. This prevents test runners from picking
|
|
# it up and running it.
|
|
|
|
|
|
class Namespace:
|
|
class TestVmapBase(TestCase):
|
|
def __init__(self, method_name='runTest'):
|
|
super().__init__(method_name)
|
|
|
|
test_method = getattr(self, method_name, None)
|
|
if test_method is None:
|
|
return
|
|
|
|
if not should_allow_vmap_fallback_usage(test_method):
|
|
setattr(self, method_name,
|
|
self._wrap_method_with_vmap_fallback_check(test_method))
|
|
|
|
def _wrap_method_with_vmap_fallback_check(self, method):
|
|
# msg = (
|
|
# 'Expected the test to not invoke the vmap fallback path, i.e., '
|
|
# 'all of the operators being tested in this test should have batching '
|
|
# 'rules implemented. If you are intentionally testing something to '
|
|
# 'do with the fallback path, use allowVmapFallbackUsage. Otherwise, '
|
|
# 'please make sure that batching rules are implemented for the '
|
|
# 'operator(s) being tested.'
|
|
# )
|
|
|
|
@functools.wraps(method)
|
|
def wrapper(self, *args, **kwargs):
|
|
with warnings.catch_warnings(record=True):
|
|
warnings.simplefilter('always')
|
|
with EnableVmapFallbackWarnings():
|
|
method(*args, **kwargs)
|
|
# for captured_warning in wa:
|
|
# self.assertNotRegex(str(captured_warning.message), FALLBACK_REGEX, msg)
|
|
return types.MethodType(wrapper, self)
|
|
|
|
@allowVmapFallbackUsage
|
|
def test_vmap_fallback_check_ok(self):
|
|
# One day we'll implement a batching rule for torch.var_mean.
|
|
# When that happens, please change the example to use an
|
|
# operator that doesn't have a batching rule implemented.
|
|
op_using_fallback = torch.var_mean
|
|
vmap(op_using_fallback)(torch.rand(3))
|
|
|
|
@unittest.expectedFailure
|
|
def test_vmap_fallback_check(self):
|
|
@self._wrap_method_with_vmap_fallback_check
|
|
def no_fallback(self):
|
|
pass
|
|
|
|
# One day we'll implement a batching rule for torch.var_mean.
|
|
# When that happens, please change the example to use an
|
|
# operator that doesn't have a batching rule implemented.
|
|
op_using_fallback = torch.var_mean
|
|
|
|
@self._wrap_method_with_vmap_fallback_check
|
|
def uses_fallback(self):
|
|
vmap(op_using_fallback)(torch.rand(3))
|
|
|
|
no_fallback(self)
|
|
|
|
with self.assertRaises(AssertionError):
|
|
uses_fallback(self)
|
|
|
|
|
|
def _make_case(op, input_getter=TensorFactory.randn):
|
|
return (op, input_getter)
|
|
|
|
|
|
class TestVmapOperators(Namespace.TestVmapBase):
|
|
def _vmap_test(self, *args, **kwargs):
|
|
return _vmap_test(self, *args, **kwargs)
|
|
|
|
def _vmap_view_test(self, *args, **kwargs):
|
|
self._vmap_test(*args, **kwargs, check_view=True)
|
|
|
|
def _test_unary(self, op, getter, device, *args, **kwargs):
|
|
test = functools.partial(self._vmap_test, *args, **kwargs)
|
|
B0, B1 = 7, 11
|
|
|
|
# Single vmap, various in_dims / out_dims
|
|
test(op, [getter([B0, 3], device)])
|
|
test(op, [getter([2, 5, B0, 3], device)], in_dims=2)
|
|
test(op, [getter([2, 5, B0, 3], device)], in_dims=2, out_dims=2)
|
|
|
|
# Doubly nested vmap
|
|
test(vmap(op), [getter([B0, B1], device)])
|
|
test(vmap(op), [getter([B1, 2, 5, B0, 3], device)], in_dims=2)
|
|
test(vmap(op, in_dims=2), [getter([2, 5, B0, B1, 3], device)],
|
|
in_dims=2, out_dims=2)
|
|
|
|
@parametrize("case", [
|
|
(torch.abs, TensorFactory.randn),
|
|
(torch.acos, TensorFactory.rand),
|
|
(torch.asin, TensorFactory.rand),
|
|
(torch.atan, TensorFactory.rand),
|
|
(torch.ceil, TensorFactory.randn),
|
|
(torch.cos, TensorFactory.rand),
|
|
(torch.cosh, TensorFactory.rand),
|
|
(torch.digamma, TensorFactory.rand),
|
|
(torch.exp, TensorFactory.randn),
|
|
(torch.expm1, TensorFactory.randn),
|
|
(torch.floor, TensorFactory.randn),
|
|
(torch.frac, TensorFactory.randn),
|
|
(torch.lgamma, TensorFactory.rand),
|
|
(torch.log, TensorFactory.randp1),
|
|
(torch.log10, TensorFactory.randp1),
|
|
(torch.log1p, TensorFactory.randp1),
|
|
(torch.log2, TensorFactory.randp1),
|
|
(torch.neg, TensorFactory.randn),
|
|
(torch.reciprocal, TensorFactory.randp1),
|
|
(torch.relu, TensorFactory.randn),
|
|
(torch.round, TensorFactory.randn),
|
|
(torch.rsqrt, TensorFactory.randp1),
|
|
(torch.sigmoid, TensorFactory.randn),
|
|
(torch.sign, TensorFactory.randn),
|
|
(torch.sin, TensorFactory.rand),
|
|
(torch.sinh, TensorFactory.rand),
|
|
(torch.sqrt, TensorFactory.rand),
|
|
(torch.tan, TensorFactory.rand),
|
|
(torch.tanh, TensorFactory.rand),
|
|
(torch.trunc, TensorFactory.randn),
|
|
], name_fn=lambda x: x[0].__name__)
|
|
def test_unary_pointwise(self, case):
|
|
op, getter = case
|
|
self._test_unary(op, getter, 'cpu')
|
|
|
|
# test in-place
|
|
method = getattr(Tensor, f'{op.__name__ + "_"}')
|
|
self._test_unary(method, getter, 'cpu', check_propagates_grad=False)
|
|
|
|
def test_clone(self):
|
|
# Some basic tests
|
|
self._test_unary(lambda x: x.clone(), TensorFactory.randn, 'cpu')
|
|
self._test_unary(lambda x: x.clone(memory_format=torch.preserve_format),
|
|
TensorFactory.randn, 'cpu')
|
|
self._test_unary(lambda x: x.clone(memory_format=torch.contiguous_format),
|
|
TensorFactory.randn, 'cpu')
|
|
|
|
# Test that the per-examples are contiguous when using torch.contiguous_format
|
|
def clone_contiguous(x):
|
|
return x.clone(memory_format=torch.contiguous_format)
|
|
|
|
B0, B1 = 3, 5
|
|
x = torch.randn(2, B0, 7)
|
|
y = vmap(clone_contiguous, in_dims=1, out_dims=1)(x)
|
|
self.assertTrue(y.movedim(1, 0).is_contiguous())
|
|
self.assertTrue(y[:, 0, :].is_contiguous())
|
|
|
|
x = torch.randn(2, B0, 7, B1)
|
|
y = vmap(vmap(clone_contiguous, in_dims=2), in_dims=1)(x)
|
|
self.assertTrue(y.is_contiguous())
|
|
self.assertTrue(y[0][0].is_contiguous())
|
|
|
|
msg = r'only supported with memory_format torch.preserve_format or torch.contiguous_format'
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(lambda x: x.clone(memory_format=torch.channels_last))(torch.randn(B0))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(lambda x: x.clone(memory_format=torch.channels_last_3d))(torch.randn(B0))
|
|
|
|
def test_weird_matmul_case(self):
|
|
# Check that this doesn't crash.
|
|
# https://github.com/pytorch/functorch/issues/417
|
|
x = torch.randn(5, 2, 2, 2)
|
|
y = torch.randn(5, 7, 2)
|
|
|
|
vmap(vmap(torch.matmul, in_dims=(None, 0)))(x, y)
|
|
|
|
@parametrize("case",
|
|
(
|
|
(torch.clamp_min_, TensorFactory.randn),
|
|
(torch.clamp_max_, TensorFactory.randn),
|
|
), name_fn=lambda x: x[0].__name__)
|
|
def test_clamp_inplace_variant(self, case):
|
|
test = self._vmap_test
|
|
|
|
def get_number(getter):
|
|
return getter([]).item()
|
|
|
|
op, getter = case
|
|
device = 'cpu'
|
|
B0, B1 = 7, 11
|
|
|
|
# Single vmap: op(Tensor, Tensor)
|
|
test(op, (getter([B0, 3], device), getter([B0, 3], device)), check_propagates_grad=False)
|
|
test(op, (getter([B0], device), getter([B0], device)), check_propagates_grad=False)
|
|
test(op, (getter([2, B0, 3], device), getter([2, B0, 3], device)), in_dims=(1, 1), check_propagates_grad=False)
|
|
test(op, (getter([B0, 2, 3], device), getter([2, B0, 3], device)),
|
|
in_dims=(0, 1), out_dims=1, check_propagates_grad=False)
|
|
test(op, (getter([B0, 2, 3], device), getter([1, 1], device)), in_dims=(0, None), check_propagates_grad=False)
|
|
test(op, (getter([B0, 3], device), getter([B0, 3], device)), in_dims=(0, 0), check_propagates_grad=False)
|
|
|
|
# Nested vmap: op(Tensor, Tensor)
|
|
test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 1, 3], device)), check_propagates_grad=False)
|
|
|
|
# Python number overload: op(Tensor, Number)
|
|
number = get_number(getter)
|
|
self._test_unary(lambda t: op(t, number), getter, device, check_propagates_grad=False)
|
|
|
|
@parametrize('case', [
|
|
subtest(_make_case(torch.clamp_min), name='clamp_min'),
|
|
subtest(_make_case(torch.clamp_max), name='clamp_max'),
|
|
])
|
|
def test_clamp_variant(self, case):
|
|
test = self._vmap_test
|
|
|
|
def get_number(getter):
|
|
return getter([]).item()
|
|
|
|
op, getter = case
|
|
device = 'cpu'
|
|
B0, B1 = 7, 11
|
|
|
|
# Single vmap: op(Tensor, Tensor)
|
|
test(op, (getter([B0, 3], device), getter([B0, 3], device)))
|
|
test(op, (getter([B0], device), getter([B0, 2, 3], device)))
|
|
test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
|
|
test(op, (getter([B0], device), getter([2, B0, 3], device)),
|
|
in_dims=(0, 1), out_dims=1)
|
|
test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
|
|
test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(None, 0))
|
|
|
|
# Nested vmap: op(Tensor, Tensor)
|
|
test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
|
|
test(vmap(op, in_dims=(None, 0)),
|
|
(getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))
|
|
|
|
# Python number overload: op(Tensor, Number)
|
|
number = get_number(getter)
|
|
self._test_unary(lambda t: op(t, number), getter, device)
|
|
|
|
def test_copy_(self):
|
|
x = torch.randn(3)
|
|
y = torch.randn(3)
|
|
vmap(Tensor.copy_)(x, y)
|
|
self.assertEqual(x, y)
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(3, 2)
|
|
vmap(Tensor.copy_, in_dims=(1, None))(y, x)
|
|
self.assertEqual(y, x.expand(2, 3).t())
|
|
|
|
x = torch.randn(3)
|
|
y = torch.randn(2, 3)
|
|
with self.assertRaisesRegex(RuntimeError, 'inplace'):
|
|
vmap(Tensor.copy_, in_dims=(None, 0))(x, y)
|
|
|
|
def test_silu_backward(self):
|
|
test = self._vmap_test
|
|
device = 'cpu'
|
|
getter = TensorFactory.randp1
|
|
B0 = 7
|
|
op = torch.ops.aten.silu_backward
|
|
|
|
# Single vmap: op(Tensor, Tensor)
|
|
test(op, (getter([B0, 3], device), getter([B0, 3], device)))
|
|
test(op, (getter([], device), getter([B0], device)), in_dims=(None, 0))
|
|
test(op, (getter([2, B0], device), getter([2], device)), in_dims=(1, None))
|
|
|
|
@parametrize('case', [
|
|
subtest(_make_case(torch.add), name='add'),
|
|
subtest(_make_case(lambda x, y: x + y), name='add_dunder'),
|
|
subtest(_make_case(torch.sub), name='sub'),
|
|
subtest(_make_case(lambda x, y: x - y), name='sub_dunder'),
|
|
subtest(_make_case(torch.mul), name='mul'),
|
|
subtest(_make_case(lambda x, y: x * y), name='mul_dunder'),
|
|
subtest(_make_case(torch.div, input_getter=TensorFactory.randp1), name='div'),
|
|
subtest(_make_case(lambda x, y: x / y, input_getter=TensorFactory.randp1), name='div_dunder'),
|
|
subtest(_make_case(torch.pow, input_getter=TensorFactory.randp1), name='pow'),
|
|
subtest(_make_case(lambda x, y: x ** y, input_getter=TensorFactory.randp1), name='pow_dunder'),
|
|
])
|
|
def test_arithmetic(self, case):
|
|
test = self._vmap_test
|
|
|
|
def get_number(getter):
|
|
return getter([]).item()
|
|
|
|
op, getter = case
|
|
device = 'cpu'
|
|
B0, B1 = 7, 11
|
|
|
|
# Single vmap: op(Tensor, Tensor)
|
|
test(op, (getter([B0, 3], device), getter([B0, 3], device)))
|
|
test(op, (getter([B0], device), getter([B0, 2, 3], device)))
|
|
test(op, (getter([B0], device), getter([2, B0, 3], device)), in_dims=(0, 1))
|
|
test(op, (getter([B0], device), getter([2, B0, 3], device)),
|
|
in_dims=(0, 1), out_dims=1)
|
|
test(op, (getter([B0], device), getter([2, 3], device)), in_dims=(0, None))
|
|
test(op, (getter([2, 3], device), getter([B0, 3], device)), in_dims=(0, None))
|
|
|
|
# Nested vmap: op(Tensor, Tensor)
|
|
test(vmap(op), (getter([B0, B1, 2, 3], device), getter([B0, B1, 3], device)))
|
|
test(vmap(op, in_dims=(None, 0)),
|
|
(getter([B0, 2, 3], device), getter([B1, 3], device)), in_dims=(0, None))
|
|
|
|
# Python number overload: op(Tensor, Number) (and vice-versa)
|
|
number = get_number(getter)
|
|
self._test_unary(lambda t: op(t, number), getter, device)
|
|
number = get_number(getter)
|
|
self._test_unary(lambda t: op(number, t), getter, device)
|
|
|
|
# Type promotion: op(Logical Scalar Tensor, Logical Scalar Tensor)
|
|
test(op, (getter([B0], device), getter([B0], device, dtype=torch.double)))
|
|
test(op, (getter([B0], device, dtype=torch.double), getter([B0], device)))
|
|
test(op, (getter([B0], device), getter([B0], device)))
|
|
|
|
# Type promotion: op(Tensor, Logical Scalar Tensor) (and vice-versa)
|
|
test(op, (getter([B0, 2], device), getter([B0], device, torch.double)))
|
|
test(op, (getter([B0], device, torch.double), getter([B0, 2], device)))
|
|
|
|
if not torch.cuda.is_available():
|
|
return
|
|
|
|
# TODO(rzou): fix the following
|
|
# # Test cross-device scalars
|
|
# number = get_number(getter)
|
|
# self._test_unary(lambda t: op(t, number), getter, device='cuda')
|
|
# self._test_unary(lambda t: op(number, t), getter, device='cuda')
|
|
# self._test_unary(lambda t: op(t, torch.tensor(number)), getter, device='cuda')
|
|
|
|
def test_as_strided(self):
|
|
def _test(sizes, strides, offset, tensor, lambd):
|
|
# bdim at dim 0 test
|
|
result = vmap(lambda t: t.as_strided(sizes, strides, offset))(tensor)
|
|
expected = vmap(lambd)(tensor)
|
|
self.assertTrue(result._base is expected._base)
|
|
self.assertEqual(result, expected)
|
|
|
|
# bdim at dim -1 test
|
|
tensor = tensor.movedim(0, -1)
|
|
result = vmap(lambda t: t.as_strided(sizes, strides, offset), -1)(tensor)
|
|
expected = vmap(lambd, -1)(tensor)
|
|
self.assertTrue(result._base is expected._base)
|
|
self.assertEqual(result, expected)
|
|
|
|
# single vmap test
|
|
B0 = 5
|
|
# Each Tensor has shape [B0, 2, 3]; the expressions below
|
|
# are just to get tensors of different strides that have shape [B0, 2, 3]
|
|
tensors = [
|
|
# contiguous
|
|
torch.randn(B0, 2, 3),
|
|
# non-contiguous
|
|
torch.randn(B0, 3, 2).transpose(1, 2),
|
|
torch.randn(3, 2, B0).movedim(-1, 0).transpose(1, 2),
|
|
# non-zero storage offset
|
|
torch.randn(2, B0, 2, 3)[1],
|
|
torch.randn(2, 2, B0, 3)[1].movedim(1, 0),
|
|
# non-contiguous strides, zero storage offset
|
|
torch.randn(B0, 2, 4, 3, 7)[:, :, 0, :, 0],
|
|
torch.randn(2, 4, B0, 3, 7).movedim(2, 0)[:, :, 0, :, 0],
|
|
# non-contiguous strides, non-zero storage offset
|
|
torch.randn(B0, 2, 4, 3, 7)[:, :, 2, :, 1],
|
|
torch.randn(2, 4, 3, 7, B0).movedim(-1, 0)[:, :, 2, :, 1],
|
|
]
|
|
|
|
for x in tensors:
|
|
S0, S1 = x.stride()[1:]
|
|
offset = x.storage_offset()
|
|
|
|
# Broadcast
|
|
_test([5, 5, 2, 3], [0, 0, S0, S1], offset, x, lambda x: x.expand(5, 5, 2, 3))
|
|
# transpose
|
|
_test([3, 2], [S1, S0], offset, x, lambda x: x.transpose(0, 1))
|
|
# select
|
|
_test([2], [S0], offset + S1, x, lambda x: x[:, 1])
|
|
# diagonal
|
|
_test([2], [S0 + S1], offset, x, lambda x: x.diagonal())
|
|
# strided slice
|
|
_test([2], [S1 * 2], offset, x, lambda x: x[0, ::2])
|
|
|
|
# Nested vmap test
|
|
B1 = 7
|
|
x = torch.randn(B1, B0, 2, 3)
|
|
S0, S1 = x.stride()[2:]
|
|
result = vmap(vmap(lambda t: t.as_strided([5, 5, 2, 3], [0, 0, S0, S1])), in_dims=1)(x)
|
|
expected = vmap(vmap(lambda t: t.expand(5, 5, 2, 3)), in_dims=1)(x)
|
|
self.assertTrue(result._base is expected._base)
|
|
self.assertEqual(result, expected)
|
|
|
|
# Check that mal-formatted size/strides doesn't crash
|
|
with self.assertRaisesRegex(RuntimeError, 'size and stride must have the same length'):
|
|
x = torch.randn(B0, 2, 3).transpose(0, 1)
|
|
vmap(lambda x: x.as_strided([1, 1, 1], [1, 1]))(x)
|
|
|
|
# All the Sanity check #1{a,b,c} cases check that
|
|
# xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
|
|
# doesn't index memory that is out of bounds of xs[i]. This condition
|
|
# is important to the correctness of the as_strided batching rule
|
|
# (see NOTE: [When will the as_strided_batching_rule fail?])
|
|
|
|
# Sanity check #1a: The maximum indexable location of
|
|
# xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
|
|
# is less than or equal to the maximum indexable location of xs[i].
|
|
msg = 'This is not supported inside of vmap'
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
x = torch.randn(B0, 3)
|
|
vmap(lambda x: x.as_strided([3], [1], 1))(x)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
x = torch.randn(B0, 3, 5)
|
|
vmap(lambda x: x.as_strided([4, 4], [4, 1], 0))(x)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
x = torch.randn(B0, B1, 3, 5)
|
|
vmap(vmap(lambda x: x.as_strided([4, 4], [4, 1], 0)))(x)
|
|
|
|
# Sanity check #1b: The min indexable location of
|
|
# xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
|
|
# is greater than or equal to the min indexable location of xs[i].
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
x = torch.randn(2, B0, 3)[1]
|
|
vmap(lambda x: x.as_strided([3], [1], B0 * 3 - 1))(x)
|
|
|
|
# Sanity check #1c:
|
|
# xs[i] is a zero-dim tensor, but
|
|
# xs[i].as_strided(sizes, strides, offset + xs[i].offset() - xs.offset())
|
|
# is not
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
x = torch.randn(B0, 0, 3)
|
|
vmap(lambda x: x.as_strided([3], [1]))(x)
|
|
|
|
def test_nll_loss(self):
|
|
test = self._vmap_test
|
|
op = F.nll_loss
|
|
B = 3
|
|
|
|
y = torch.randn(B, 2, 5)
|
|
t = torch.randint(0, 5, (B, 2))
|
|
test(op, (y, t))
|
|
test(functools.partial(op, reduction='sum'), (y, t))
|
|
test(functools.partial(op, reduction='none'), (y, t))
|
|
|
|
y = torch.randn(B, 2, 5)
|
|
t = torch.randint(0, 5, (2,))
|
|
test(op, (y, t), in_dims=(0, None))
|
|
test(functools.partial(op, reduction='sum'), (y, t), in_dims=(0, None))
|
|
test(functools.partial(op, reduction='none'), (y, t), in_dims=(0, None))
|
|
|
|
def test_adaptive_avg_pool2d(self):
|
|
test = self._vmap_test
|
|
op = functools.partial(F.adaptive_avg_pool2d, output_size=(3, 3))
|
|
|
|
x = torch.randn(3, 5, 7, 9, 11)
|
|
test(op, (x,))
|
|
test(op, (x,), in_dims=(1,))
|
|
test(op, (x,), in_dims=(4,))
|
|
|
|
def test_bmm(self):
|
|
op = torch.bmm
|
|
test = self._vmap_test
|
|
B0, B1 = 7, 11
|
|
|
|
# shape mismatch
|
|
msg = ""
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=(0, None))(torch.randn(B0, 3, 3, 2), torch.randn(2, 2))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
|
|
|
|
# left arg is vmapped
|
|
test(op, (torch.rand(B0, 2, 3, 5), torch.rand(2, 5, 3)), in_dims=(0, None))
|
|
test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 3, 5), torch.rand(2, 5, 3)),
|
|
in_dims=(1, None))
|
|
|
|
# right arg is vmapped
|
|
test(op, (torch.rand(2, 5, 3), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
|
|
test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5, 3), torch.rand(B1, B0, 2, 3, 5)),
|
|
in_dims=(None, 1))
|
|
|
|
# both args are vmapped
|
|
test(op, (torch.rand(B0, 2, 3, 5), torch.rand(B0, 2, 5, 3)))
|
|
test(vmap(op), (torch.rand(B1, B0, 2, 3, 5), torch.rand(B0, B1, 2, 5, 3)), in_dims=(1, 0))
|
|
test(vmap(op, in_dims=(0, None)),
|
|
(torch.rand(B1, 2, 3, 5), torch.rand(B0, 2, 5, 3)), in_dims=(None, 0))
|
|
|
|
def test_cat(self):
|
|
test = self._vmap_test
|
|
B0, B1 = 5, 7
|
|
|
|
# Quick hack b/c vmap can't accept a list of tensors as an argument
|
|
def get_op(dim):
|
|
def op(*tensors):
|
|
return torch.cat(tensors, dim=dim)
|
|
return op
|
|
|
|
test(get_op(0), (torch.rand(B0, 2), torch.rand(B0, 3)))
|
|
test(get_op(0), (torch.rand(B0, 0), torch.rand(B0, 0)))
|
|
test(get_op(0), (torch.rand(2), torch.rand(B0, 0)), in_dims=(None, 0))
|
|
test(get_op(1), (torch.rand(2, 5), torch.rand(B0, 0), torch.rand(2, 3)), in_dims=(None, 0, None))
|
|
test(get_op(1), (torch.rand(B0, 2, 3), torch.rand(B0, 0)))
|
|
test(get_op(1), (torch.rand(B0, 2, 3, 4), torch.rand(0)), in_dims=(0, None))
|
|
test(get_op(0), (torch.rand(0), torch.rand(B0, 2), torch.rand(B0, 0)), in_dims=(None, 0, 0))
|
|
test(get_op(0), (torch.rand(2), torch.rand(B0, 3)), in_dims=(None, 0))
|
|
test(get_op(0), (torch.rand(2, 17), torch.rand(3, 17, B0)), in_dims=(None, 2))
|
|
test(get_op(-1), (torch.rand(17, 2), torch.rand(17, 3, B0)), in_dims=(None, 2))
|
|
test(vmap(get_op(0), in_dims=(0, None)),
|
|
(torch.rand(B1, 2), torch.rand(B0, 3)), in_dims=(None, 0))
|
|
test(vmap(get_op(0), in_dims=(0, 0)),
|
|
(torch.rand(B1, 2), torch.rand(B0, B1, 3)), in_dims=(None, 0))
|
|
|
|
def test_unsafe_view(self):
|
|
# Unsafe view isn't exposed, so we get at it via
|
|
# vmap(grad(matmul))
|
|
test = functools.partial(self._vmap_test, check_propagates_grad=False)
|
|
B = 2
|
|
x = torch.randn(B, 2, 3, 3)
|
|
y = torch.randn(B, 3, 3)
|
|
|
|
def baz(x, y):
|
|
return (x @ y).sum()
|
|
|
|
test(functorch.grad(baz), (x, y))
|
|
|
|
def test_conj(self):
|
|
op = torch.conj
|
|
|
|
def run_test(dtype):
|
|
def get(shape):
|
|
return torch.randn(shape, dtype=dtype)
|
|
B0, B1 = 7, 11
|
|
test = self._vmap_test
|
|
|
|
# Single vmap, various in_dims / out_dims
|
|
test(op, [get([B0, 3])])
|
|
test(op, [get([2, 5, B0, 3])], in_dims=2)
|
|
test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
|
|
|
|
# Doubly nested vmap
|
|
test(vmap(op), [get([B0, B1])])
|
|
test(vmap(op), [get([B1, 2, 5, B0, 3])], in_dims=2)
|
|
test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])],
|
|
in_dims=2, out_dims=2)
|
|
|
|
# correctness tests
|
|
run_test(torch.float)
|
|
run_test(torch.cfloat)
|
|
|
|
# check that torch.conj on a non-complex tensor returns the same tensor
|
|
real_tensor = torch.randn(3)
|
|
result = vmap(op)(real_tensor)
|
|
self.assertEqual(result.data_ptr(), real_tensor.data_ptr())
|
|
|
|
def test_contiguous(self):
|
|
op = Tensor.contiguous
|
|
|
|
self._test_unary(op, TensorFactory.randn, 'cpu')
|
|
|
|
# check that contiguous returns the original tensor if the per-examples
|
|
# are already contiguous
|
|
B0 = 3
|
|
x = torch.randn(B0, 2, 5, 7)
|
|
x = x.movedim(0, 2)
|
|
result = vmap(Tensor.contiguous, in_dims=2, out_dims=2)(x)
|
|
self.assertTrue(result is x)
|
|
|
|
msg = 'NYI: querying is_contiguous inside of vmap for memory_format'
|
|
tensor = torch.randn(B0, 3)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(functools.partial(op, memory_format=torch.channels_last))(tensor)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(functools.partial(op, memory_format=torch.channels_last_3d))(tensor)
|
|
|
|
def test_stride(self):
|
|
B0 = 3
|
|
|
|
x = torch.randn(B0, 2, 5, 7)
|
|
|
|
def foo(x):
|
|
assert x.stride() == (7 * 5, 7, 1)
|
|
return x
|
|
|
|
vmap(foo)(x)
|
|
|
|
x = torch.randn(2, B0, 5, 7).movedim(1, 0)
|
|
|
|
def bar(x):
|
|
assert x.stride() == (7 * 5 * B0, 7, 1)
|
|
return x
|
|
|
|
vmap(bar)(x)
|
|
|
|
def test_chunk(self):
|
|
test = self._vmap_view_test
|
|
op = torch.chunk
|
|
B0, B1, B2 = 7, 11, 13
|
|
|
|
# tests for torch.split(self, split_size: int, dim)
|
|
test(op, (torch.rand(B0, 2, 1024), 15, -1), in_dims=(0, None, None))
|
|
test(op, (torch.rand(2, B0, 1024), 9, 1), in_dims=(1, None, None))
|
|
test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 4, 0),
|
|
in_dims=(2, None, None))
|
|
test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
|
|
(torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
|
|
|
|
def test_clamp(self):
|
|
clamp_cases = (
|
|
(lambda t: t.clamp(min=-0.5), TensorFactory.randn),
|
|
(lambda t: t.clamp(max=0.5), TensorFactory.randn),
|
|
(lambda t: t.clamp(min=-0.5, max=0.5), TensorFactory.randn),
|
|
(lambda t: t.clamp_min(min=-0.5), TensorFactory.randn),
|
|
(lambda t: t.clamp_max(max=0.5), TensorFactory.randn),
|
|
)
|
|
for op, getter in clamp_cases:
|
|
self._test_unary(op, getter, 'cpu')
|
|
|
|
def test_comparison_ops(self):
|
|
test = functools.partial(self._vmap_test, check_propagates_grad=False)
|
|
|
|
getter = TensorFactory.randn
|
|
B0, B1 = 7, 11
|
|
|
|
ops = (
|
|
torch.eq, lambda x, y: x == y,
|
|
torch.gt, lambda x, y: x > y,
|
|
torch.ge, lambda x, y: x >= y,
|
|
torch.le, lambda x, y: x <= y,
|
|
torch.lt, lambda x, y: x < y,
|
|
torch.ne, lambda x, y: x != y,
|
|
)
|
|
|
|
for op in ops:
|
|
# Single vmap: op(Tensor, Tensor)
|
|
test(op, (getter([B0, 3]), getter([B0, 3])))
|
|
test(op, (getter([B0]), getter([B0, 2, 3])))
|
|
test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1))
|
|
test(op, (getter([B0]), getter([2, B0, 3])), in_dims=(0, 1), out_dims=1)
|
|
test(op, (getter([B0]), getter([2, 3])), in_dims=(0, None))
|
|
test(op, (getter([2, 3]), getter([B0, 3])), in_dims=(0, None))
|
|
|
|
# Nested vmap: op(Tensor, Tensor)
|
|
test(vmap(op), (getter([B0, B1, 2, 3]), getter([B0, B1, 3])))
|
|
test(vmap(op, in_dims=(None, 0)),
|
|
(getter([B0, 2, 3]), getter([B1, 3])), in_dims=(0, None))
|
|
|
|
# test number as inputs
|
|
number = getter([]).item()
|
|
self._test_unary(lambda t: op(t, number), getter, 'cpu', check_propagates_grad=False)
|
|
|
|
def test_cross_batch_size_three(self):
|
|
# Let's test corner case when batch_size is 3 and cross' dim argument is not specified
|
|
# According to the cross API, dim will be assigned to the first dim with value 3
|
|
# In this test we ensure that found dim is not batch dim.
|
|
op = torch.cross
|
|
test = self._vmap_test
|
|
B0 = B1 = 3
|
|
test(op, (torch.rand(B0, 2, 3), torch.rand(B0, 2, 3)))
|
|
test(vmap(op, in_dims=(0, None)), (torch.rand(B0, B1, 2, 3), torch.rand(B0, B1, 2, 3)),
|
|
in_dims=(None, 1))
|
|
|
|
def test_diagonal(self):
|
|
tensor = torch.randn(3, 5, 7, 11, 13)
|
|
test = self._vmap_view_test
|
|
op = torch.diagonal
|
|
test(op, (tensor, 1, 0, 1), in_dims=(0, None, None, None))
|
|
test(op, (tensor, 0, 2, -1), in_dims=(0, None, None, None))
|
|
test(op, (tensor, 2, 1, 2), in_dims=(1, None, None, None))
|
|
test(op, (tensor, 0, -2, -1), in_dims=(1, None, None, None), out_dims=1)
|
|
test(vmap(lambda t: op(t, 0, 0, -1)), (tensor,), in_dims=1, out_dims=1)
|
|
test(vmap(vmap(lambda t: op(t, 0, 0, 1), in_dims=1), in_dims=3),
|
|
(tensor,), in_dims=1, out_dims=1)
|
|
|
|
def test_dot(self):
|
|
op = torch.dot
|
|
test = self._vmap_test
|
|
B0, B1 = 7, 11
|
|
|
|
# shape mismatch
|
|
msg = ""
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2))
|
|
|
|
# left arg is vmapped
|
|
test(op, (torch.rand(B0, 5), torch.rand(5)), in_dims=(0, None))
|
|
test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 5), torch.rand(5)),
|
|
in_dims=(1, None))
|
|
|
|
# right arg is vmapped
|
|
test(op, (torch.rand(5), torch.rand(B0, 5)), in_dims=(None, 0))
|
|
test(vmap(op, in_dims=(None, 0)), (torch.rand(5), torch.rand(B1, B0, 5)),
|
|
in_dims=(None, 1))
|
|
|
|
# both args are vmapped
|
|
test(op, (torch.rand(B0, 5), torch.rand(B0, 5)))
|
|
test(vmap(op), (torch.rand(B1, B0, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
|
|
test(vmap(op, in_dims=(0, None)),
|
|
(torch.rand(B1, 5), torch.rand(B0, 5)), in_dims=(None, 0))
|
|
|
|
def test_expand_as(self):
|
|
op = torch.Tensor.expand_as
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
test(op, (torch.rand(B0, 1, 5), torch.rand(B0, 2, 3, 5)))
|
|
test(op, (torch.rand(B0, 1, 5), torch.rand(2, 3, 5)), in_dims=(0, None))
|
|
test(op, (torch.rand(1, 5), torch.rand(B0, 2, 3, 5)), in_dims=(None, 0))
|
|
test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B0, B1, 2, 3, 5)))
|
|
test(vmap(op), (torch.rand(B0, B1, 1, 5), torch.rand(B1, B0, 2, 3, 5)), in_dims=(0, 1))
|
|
test(vmap(op), (torch.rand(B0, B1), torch.rand(B1, 2, 3, 5)), in_dims=(0, None))
|
|
test(vmap(vmap(op)), (torch.rand(B0, B1, B2), torch.rand(B0, B1, B2, 2, 3, 5)))
|
|
|
|
def test_fill_and_zero_inplace(self):
|
|
test = functools.partial(self._vmap_test, check_propagates_grad=False)
|
|
B0, B1 = 7, 11
|
|
ops = (
|
|
lambda t: t.fill_(0.1),
|
|
lambda t: t.fill_(torch.tensor(0.2)),
|
|
lambda t: t.zero_(),
|
|
)
|
|
|
|
for op in ops:
|
|
# Single vmap, various in_dims / out_dims
|
|
test(op, [TensorFactory.randn([B0, 3])])
|
|
test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2)
|
|
test(op, [TensorFactory.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
|
|
|
|
# Doubly nested vmap
|
|
test(vmap(op), [TensorFactory.randn([B0, B1])])
|
|
test(vmap(op), [TensorFactory.randn([B1, 2, 5, B0, 3])], in_dims=2)
|
|
test(vmap(op, in_dims=2), [TensorFactory.randn([2, 5, B0, B1, 3])],
|
|
in_dims=2, out_dims=2)
|
|
|
|
# test when value is a batched tensor for fill_ operator
|
|
B0, B1 = 3, 5
|
|
test(Tensor.fill_, [TensorFactory.randn([B0, B1]), TensorFactory.randn(B0)])
|
|
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
""):
|
|
# Runtime Error is thrown when the tensor being written to isn't being vmapped over
|
|
vmap(Tensor.fill_, (None, 0))(TensorFactory.randn([B0, B1]),
|
|
TensorFactory.randn([B0]))
|
|
|
|
def _test_complex_views(self, op, dtypes):
|
|
test = self._vmap_view_test
|
|
|
|
def run_test(op, dtype):
|
|
def get(shape):
|
|
return torch.randn(shape, dtype=dtype)
|
|
|
|
B0, B1 = 7, 11
|
|
|
|
# Single vmap, various in_dims / out_dims
|
|
test(op, [get([B0, 3])])
|
|
test(op, [get([3, B0])], in_dims=1)
|
|
test(op, [get([2, 5, B0, 3])], in_dims=2)
|
|
test(op, [get([2, 5, B0, 3])], in_dims=2, out_dims=2)
|
|
|
|
# Doubly nested vmap
|
|
test(vmap(op), [get([B0, B1])])
|
|
test(vmap(op), [get([B1, 2, 5, 3, B0])], in_dims=4)
|
|
test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3])],
|
|
in_dims=2, out_dims=2)
|
|
|
|
for dtype in dtypes:
|
|
run_test(op, dtype)
|
|
|
|
def test_real(self):
|
|
self._test_complex_views(torch.real, dtypes=[torch.cfloat, torch.cdouble])
|
|
|
|
def test_imag(self):
|
|
self._test_complex_views(torch.imag, dtypes=[torch.cfloat, torch.cdouble])
|
|
|
|
def test_view_as_real(self):
|
|
self._test_complex_views(torch.view_as_real, dtypes=[torch.cfloat, torch.cdouble])
|
|
|
|
def test_view_as_complex(self):
|
|
def run_test(dtype):
|
|
def get(shape):
|
|
return torch.randn(shape, dtype=dtype)
|
|
|
|
op = torch.view_as_complex
|
|
test = self._vmap_view_test
|
|
B0, B1 = 7, 11
|
|
|
|
# Single vmap, various in_dims / out_dims
|
|
test(op, [get([B0, 3, 2])])
|
|
test(op, [get([2, 5, B0, 3, 2])], in_dims=2)
|
|
test(op, [get([2, 5, B0, 3, 2])], in_dims=2, out_dims=2)
|
|
|
|
# Doubly nested vmap
|
|
test(vmap(op), [get([B0, B1, 2])])
|
|
test(vmap(op), [get([B1, 2, 5, B0, 3, 2])], in_dims=2)
|
|
test(vmap(op, in_dims=2), [get([2, 5, B0, B1, 3, 2])],
|
|
in_dims=2, out_dims=2)
|
|
|
|
# Interesting case #1: Batch dim directly before dim of size 2
|
|
test(op, [get([3, B0, 2])], in_dims=1)
|
|
test(vmap(op, in_dims=1), [get([3, B1, B0, 2])], in_dims=2)
|
|
|
|
# Interesting case #2: Batch dim at end of tensor, success cases
|
|
# view_as_complex requires that the dim with size 2 have stride 1
|
|
# in order for the view to function propertly
|
|
test(op, [get([B0, 2]).transpose(0, 1)], in_dims=1)
|
|
test(vmap(op, in_dims=1), [get([B0, B1, 2]).movedim(1, 2)])
|
|
test(vmap(op, in_dims=2), [get([B0, 3, B1, 2]).movedim(2, 3)])
|
|
|
|
# Interesting case #3: Batch dim at end of tensor, failure cases
|
|
msg = "Tensor must have a last dimension with stride 1"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=1)(get([2, B0]))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(vmap(op, in_dims=1), in_dims=1)(get([2, B0, B1]))
|
|
|
|
# Invalid input: no dimension of size 2
|
|
msg = 'Input tensor must have one or more dimensions'
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op)(get([B0]))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(vmap(op))(get([B0, B1]))
|
|
|
|
# Invalid input: Batch dim has size 2, but the logical last dim does
|
|
# not have size 2
|
|
msg = 'Tensor must have a last dimension of size 2'
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=1)(get([3, 2]))
|
|
|
|
for dtype in [torch.float, torch.double]:
|
|
run_test(dtype)
|
|
|
|
def test_is_complex(self):
|
|
ctensor = torch.randn(3, dtype=torch.cfloat)
|
|
tensor = torch.randn(3)
|
|
|
|
def foo(x):
|
|
if x.is_complex():
|
|
return torch.tensor(1)
|
|
else:
|
|
return torch.tensor(0)
|
|
|
|
self.assertEqual(vmap(foo)(ctensor), torch.tensor([1, 1, 1]))
|
|
self.assertEqual(vmap(foo)(tensor), torch.tensor([0, 0, 0]))
|
|
|
|
def test_is_floating_point(self):
|
|
float_tensor = torch.tensor([1., 2., 3.])
|
|
long_tensor = torch.tensor([1, 2, 3])
|
|
|
|
def foo(x):
|
|
if x.is_floating_point():
|
|
return torch.tensor(1)
|
|
else:
|
|
return torch.tensor(0)
|
|
|
|
self.assertEqual(vmap(foo)(float_tensor), torch.tensor([1, 1, 1]))
|
|
self.assertEqual(vmap(foo)(long_tensor), torch.tensor([0, 0, 0]))
|
|
|
|
def test_is_contiguous(self):
|
|
def foo(x):
|
|
if x.is_contiguous():
|
|
return torch.tensor(1.)
|
|
else:
|
|
return torch.tensor(0.)
|
|
|
|
B0, B1 = 3, 5
|
|
|
|
# Single batch dim
|
|
contig = torch.randn(B0, 2, 7)
|
|
self.assertEqual(vmap(foo)(contig), torch.ones(B0))
|
|
|
|
noncontig = torch.randn(2, B0, 7)
|
|
self.assertEqual(vmap(foo, in_dims=1)(noncontig), torch.zeros(B0))
|
|
|
|
noncontig = torch.randn(2, B0, 7).movedim(1, 0)
|
|
self.assertEqual(vmap(foo)(noncontig), torch.zeros(B0))
|
|
|
|
noncontig = torch.randn(2, 7, B0)
|
|
self.assertEqual(vmap(foo, in_dims=2)(noncontig), torch.zeros(B0))
|
|
|
|
# Multiple batch dims
|
|
contig = torch.randn(B0, B1, 3)
|
|
self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
|
|
|
|
contig = torch.randn(B1, B0, 3)
|
|
self.assertEqual(vmap(vmap(foo), in_dims=1)(contig), torch.ones(B0, B1))
|
|
|
|
contig = torch.randn(B1, B0, 3).movedim(0, 1)
|
|
self.assertEqual(vmap(vmap(foo))(contig), torch.ones(B0, B1))
|
|
|
|
noncontig = torch.randn(B0, 3, B1)
|
|
self.assertEqual(vmap(vmap(foo, in_dims=1))(noncontig), torch.zeros(B0, B1))
|
|
|
|
# is_contiguous on empty tensor is True
|
|
def bar(x):
|
|
assert x.is_contiguous()
|
|
return x
|
|
|
|
vmap(bar)(torch.randn(B0, 0, 3))
|
|
vmap(bar, in_dims=1)(torch.randn(0, B0, 3))
|
|
vmap(bar)(torch.randn(B0, 0, 3).transpose(-1, -2))
|
|
|
|
# is_contiguous with other memory formats
|
|
def baz(x, memory_format):
|
|
x.is_contiguous(memory_format=memory_format)
|
|
return x
|
|
|
|
msg = 'NYI: querying is_contiguous inside of vmap for memory_format'
|
|
tensor = torch.randn(B0, 2, 7, 3)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(functools.partial(baz, memory_format=torch.channels_last))(tensor)
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(functools.partial(baz, memory_format=torch.channels_last_3d))(tensor)
|
|
|
|
def test_unsqueeze(self):
|
|
op = torch.unsqueeze
|
|
test = self._vmap_view_test
|
|
B0, B1 = 7, 11
|
|
|
|
# unsqueeze dim 0
|
|
test(op, (torch.rand(B0, 2, 5), 0), in_dims=(0, None))
|
|
test(op, (torch.rand(2, B0, 5), 0), in_dims=(1, None))
|
|
|
|
# unsqueeze last dim (positive)
|
|
test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
|
|
test(op, (torch.rand(2, B0, 5), 2), in_dims=(1, None))
|
|
|
|
# unsqueeze last dim (negative)
|
|
test(op, (torch.rand(B0, 2, 5), -1), in_dims=(0, None))
|
|
test(op, (torch.rand(2, B0, 5), -1), in_dims=(1, None))
|
|
|
|
# nested vmaps
|
|
def unsqueeze_0(x):
|
|
return torch.unsqueeze(x, 0)
|
|
|
|
def unsqueeze_last(x):
|
|
return torch.unsqueeze(x, -1)
|
|
|
|
# bdims in canonical order
|
|
test(vmap(unsqueeze_0), (torch.rand(B0, B1, 2), ))
|
|
test(vmap(unsqueeze_last), (torch.rand(B0, B1, 2),))
|
|
|
|
# wild bdims
|
|
test(vmap(unsqueeze_0), (torch.rand(B1, 2, B0),), in_dims=2)
|
|
test(vmap(unsqueeze_0, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
|
|
test(vmap(unsqueeze_last), (torch.rand(B1, 2, B0),), in_dims=2)
|
|
test(vmap(unsqueeze_last, in_dims=1), (torch.rand(2, B1, B0),), in_dims=2)
|
|
|
|
def test_movedim(self):
|
|
op = torch.movedim
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
|
|
# movedim(tensor, int, int) variant
|
|
test(op, (torch.rand(B0, 2, 5), 0, 1), in_dims=(0, None, None))
|
|
test(op, (torch.rand(2, B0, 5), 0, 1), in_dims=(1, None, None))
|
|
test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 2, B0, 5), 0, 1), in_dims=(2, None, None))
|
|
test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
|
|
(torch.rand(B1, 2, B0, 5, B2), 0, 1), in_dims=(2, None, None))
|
|
|
|
# movedim(tensor, intlist, intlist) variant
|
|
test(op, (torch.rand(B0, 2, 3, 5), [1, 0], [0, 2]), in_dims=(0, None, None))
|
|
test(op, (torch.rand(2, 3, B0, 5), [1, 0], [0, 2]), in_dims=(1, None, None))
|
|
test(vmap(op, in_dims=(0, None, None)),
|
|
(torch.rand(B1, 2, B0, 5), [0, 1], [1, 0]), in_dims=(2, None, None))
|
|
test(vmap(vmap(op, in_dims=(2, None, None)), in_dims=(0, None, None)),
|
|
(torch.rand(B1, 2, B0, 5, B2), [0, 1], [1, 0]), in_dims=(2, None, None))
|
|
|
|
def test_mm(self):
|
|
op = torch.mm
|
|
test = self._vmap_test
|
|
B0, B1 = 7, 11
|
|
|
|
# shape mismatch
|
|
msg = "Shape mismatch"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=(0, None))(torch.randn(B0, 2), torch.randn(2, 2))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2, 2))
|
|
|
|
# left arg is vmapped
|
|
test(op, (torch.rand(B0, 2, 5), torch.rand(5, 2)), in_dims=(0, None))
|
|
test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5, 2)),
|
|
in_dims=(1, None))
|
|
|
|
# right arg is vmapped
|
|
test(op, (torch.rand(2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
|
|
test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5, 2)),
|
|
in_dims=(None, 1))
|
|
|
|
# both args are vmapped
|
|
test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5, 2)))
|
|
test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5, 2)), in_dims=(1, 0))
|
|
test(vmap(op, in_dims=(0, None)),
|
|
(torch.rand(B1, 2, 5), torch.rand(B0, 5, 2)), in_dims=(None, 0))
|
|
|
|
def test_mv(self):
|
|
op = torch.mv
|
|
test = self._vmap_test
|
|
B0, B1 = 7, 11
|
|
|
|
# shape mismatch
|
|
msg = ""
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op)(torch.randn(B0, 2, 2, 2), torch.randn(B0, 2))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=(0, None))(torch.randn(B0, 2, 2), torch.randn(2, 2))
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vmap(op, in_dims=(None, 0))(torch.randn(2, 2), torch.randn(B0, 2, 2))
|
|
|
|
# left arg is vmapped
|
|
test(op, (torch.rand(B0, 2, 5), torch.rand(5)), in_dims=(0, None))
|
|
test(vmap(op, in_dims=(0, None)), (torch.rand(B1, B0, 2, 5), torch.rand(5)),
|
|
in_dims=(1, None))
|
|
|
|
# right arg is vmapped
|
|
test(op, (torch.rand(2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
|
|
test(vmap(op, in_dims=(None, 0)), (torch.rand(2, 5), torch.rand(B1, B0, 5)),
|
|
in_dims=(None, 1))
|
|
|
|
# both args are vmapped
|
|
test(op, (torch.rand(B0, 2, 5), torch.rand(B0, 5)))
|
|
test(vmap(op), (torch.rand(B1, B0, 2, 5), torch.rand(B0, B1, 5)), in_dims=(1, 0))
|
|
test(vmap(op, in_dims=(0, None)),
|
|
(torch.rand(B1, 2, 5), torch.rand(B0, 5)), in_dims=(None, 0))
|
|
|
|
def test_narrow(self):
|
|
op = torch.narrow
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
|
|
test(op, (torch.rand(B0, 2, 5), -1, 1, 3), in_dims=(0, None, None, None))
|
|
test(op, (torch.rand(2, B0, 5), 1, 1, 3), in_dims=(1, None, None, None))
|
|
test(vmap(op, in_dims=(0, None, None, None)),
|
|
(torch.rand(B1, 2, B0, 5), 1, 0, 0), in_dims=(2, None, None, None))
|
|
test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)),
|
|
(torch.rand(B1, 2, B0, 5, B2), -1, 2, 3), in_dims=(2, None, None, None))
|
|
|
|
def test_new_empty(self):
|
|
# Empty is non-deterministic so we just check that the shape of the
|
|
# output tensor is what we expect and that the vmap fallback isn't used.
|
|
op = Tensor.new_empty
|
|
|
|
B0, B1 = 7, 11
|
|
|
|
result = vmap(lambda x: op(x, [2, 3]))(torch.randn(B0))
|
|
self.assertEqual(result.shape, [B0, 2, 3])
|
|
|
|
result = vmap(lambda x: op(x, []))(torch.randn(B0))
|
|
self.assertEqual(result.shape, [B0])
|
|
|
|
result = vmap(vmap(lambda x: op(x, [2, 3])))(torch.randn(B0, B1))
|
|
self.assertEqual(result.shape, [B0, B1, 2, 3])
|
|
|
|
def test_new_empty_strided(self):
|
|
# Empty is non-deterministic so we just check that the size and shape
|
|
# of the output are what we expect and that the vmap fallback isn't used
|
|
B0, B1 = 7, 11
|
|
|
|
def _test_single_vmap(size, stride, B0):
|
|
x = torch.randn(B0)
|
|
result = vmap(lambda x: x.new_empty_strided(size, stride))(x)
|
|
S = torch.empty_strided(size, stride).storage().size()
|
|
self.assertEqual(result.shape, [B0] + size)
|
|
self.assertEqual(result.stride(), [S] + stride)
|
|
|
|
def _test_double_vmap(size, stride, B0, B1):
|
|
x = torch.randn(B0, B1)
|
|
result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)))(x)
|
|
S = torch.empty_strided(size, stride).storage().size()
|
|
self.assertEqual(result.shape, [B0, B1] + size)
|
|
self.assertEqual(result.stride(), [B1 * S, S] + stride)
|
|
|
|
x = torch.randn(B1, B0)
|
|
result = vmap(vmap(lambda x: x.new_empty_strided(size, stride)), in_dims=1)(x)
|
|
S = x.new_empty_strided(size, stride).storage().size()
|
|
self.assertEqual(result.shape, [B0, B1] + size)
|
|
self.assertEqual(result.stride(), [B1 * S, S] + stride)
|
|
|
|
# contiguous case
|
|
_test_single_vmap([2, 3, 5], [3 * 5, 5, 1], B0)
|
|
_test_double_vmap([2, 3, 5], [3 * 5, 5, 1], B0, B1)
|
|
|
|
# expanded
|
|
_test_single_vmap([2, 3, 5], [0, 5, 1], B0)
|
|
_test_double_vmap([2, 3, 5], [0, 5, 1], B0, B1)
|
|
|
|
# some of these cases are pretty strange, just verifying that if
|
|
# empty_strided allows them then BatchedTensor.new_empty_strided
|
|
# can as well
|
|
for shape in [[2, 3, 4], [0, 2, 0]]:
|
|
for strides in [[12, 4, 1], [2, 4, 6], [0, 0, 0]]:
|
|
_test_single_vmap(shape, strides, B0)
|
|
_test_double_vmap(shape, strides, B0, B1)
|
|
|
|
def test_new_zeros(self):
|
|
op = Tensor.new_zeros
|
|
test = functools.partial(self._vmap_test, check_propagates_grad=False)
|
|
B0, B1 = 7, 11
|
|
|
|
test(lambda x: op(x, 2, 3), (torch.rand(B0),))
|
|
test(lambda x: op(x, []), (torch.rand(B0),))
|
|
test(vmap(lambda x: op(x, 3, 5)), (torch.rand(B0, B1),))
|
|
|
|
def test_select(self):
|
|
op = torch.select
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
test(op, (torch.rand(B0, 2, 5), 0, 0), in_dims=(0, None, None))
|
|
test(op, (torch.rand(2, B0, 5), 1, 1), in_dims=(1, None, None))
|
|
test(vmap(lambda t: op(t, 1, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
|
|
test(vmap(vmap(lambda t: op(t, 1, 1), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)
|
|
|
|
def test_roll_no_dims(self):
|
|
op = torch.roll
|
|
test = self._vmap_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
test(op, (torch.rand(B0, 2, 5), 2), in_dims=(0, None))
|
|
test(op, (torch.rand(2, B0, 5), 3), in_dims=(1, None))
|
|
test(vmap(lambda t: op(t, 3)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
|
|
test(vmap(vmap(lambda t: op(t, 3), in_dims=1)), (torch.rand(B1, 2, B0, B2, 5),), in_dims=2)
|
|
|
|
def test_stack(self):
|
|
test = self._vmap_test
|
|
B0, B1 = 5, 7
|
|
|
|
# Quick hack b/c vmap can't accept a list of tensors as an argument
|
|
def get_op(dim):
|
|
def op(*tensors):
|
|
return torch.stack(tensors, dim=dim)
|
|
return op
|
|
|
|
test(get_op(0), (torch.rand(B0, 3), torch.rand(B0, 3)))
|
|
test(get_op(0), (torch.rand(3), torch.rand(B0, 3)), in_dims=(None, 0))
|
|
test(get_op(0), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
|
|
test(get_op(-1), (torch.rand(2, 17), torch.rand(2, 17, B0)), in_dims=(None, 2))
|
|
test(vmap(get_op(0), in_dims=(0, None)),
|
|
(torch.rand(B1, 2), torch.rand(B0, 2)), in_dims=(None, 0))
|
|
test(vmap(get_op(0), in_dims=(0, 0)),
|
|
(torch.rand(B1, 2), torch.rand(B0, B1, 2)), in_dims=(None, 0))
|
|
|
|
def test_slice(self):
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
test(lambda t: t[0:1], (torch.rand(B0, 3, 5),))
|
|
test(lambda t: t[:, 1:3], (torch.rand(3, 5, B0),), in_dims=2)
|
|
test(vmap(lambda t: t[:, 0:1], in_dims=2), (torch.rand(3, 5, B0, B1),), in_dims=2)
|
|
test(vmap(vmap(lambda t: t[0:1], in_dims=2), in_dims=2),
|
|
(torch.rand(3, 5, B0, B1, B2),), in_dims=2)
|
|
|
|
def test_squeeze(self):
|
|
def verify_behavior(op, min_ndim=1):
|
|
test = self._vmap_view_test
|
|
B0, B1 = 1, 11
|
|
# These tests cannot be used with an operator that requires more
|
|
# than 1 dimension after batching.
|
|
if min_ndim <= 1:
|
|
test(op, (torch.rand(B0),))
|
|
test(op, (torch.rand(B1),))
|
|
test(vmap(op), (torch.rand(B0, B1, 1),))
|
|
test(vmap(op), (torch.rand(B1, 1, B0),), in_dims=2)
|
|
test(op, (torch.rand(B0, 3, 5),))
|
|
test(op, (torch.rand(1, B0, 5),), in_dims=1)
|
|
test(op, (torch.rand(B0, 0, 1, 5, 1),))
|
|
test(op, (torch.rand(B0, 1, 1, 1, 1),))
|
|
test(vmap(op), (torch.rand(B0, B1, 1, 3, 4),))
|
|
test(vmap(op), (torch.rand(B1, 1, B0, 4, 5),), in_dims=2)
|
|
|
|
verify_behavior(torch.squeeze)
|
|
verify_behavior(lambda x: torch.squeeze(x, dim=0), min_ndim=1)
|
|
verify_behavior(lambda x: torch.squeeze(x, dim=1), min_ndim=2)
|
|
verify_behavior(lambda x: torch.squeeze(x, dim=-1), min_ndim=2)
|
|
verify_behavior(lambda x: torch.squeeze(x, dim=-2), min_ndim=3)
|
|
|
|
msg = ""
|
|
try:
|
|
torch.squeeze(torch.rand(10), dim=1)
|
|
except IndexError as err:
|
|
msg = str(err)
|
|
with self.assertRaises(RuntimeError, msg=msg):
|
|
vmap(lambda x: torch.squeeze(x, dim=1))(torch.rand(10))
|
|
|
|
def _test_mean_sum_dim(self, op):
|
|
test = self._vmap_test
|
|
B0, B1 = 5, 7
|
|
|
|
# Single vmap, various in_dims / out_dims
|
|
test(lambda x: op(x, 0), [torch.randn([B0])])
|
|
test(lambda x: op(x, -1), [torch.randn([B0])])
|
|
test(lambda x: op(x, 0), [torch.randn([B0, 3])])
|
|
test(lambda x: op(x, -1), [torch.randn([2, 5, B0, 3])], in_dims=2)
|
|
test(lambda x: op(x, 2), [torch.randn([2, 5, B0, 3])], in_dims=2, out_dims=2)
|
|
|
|
# Doubly nested vmap
|
|
test(vmap(lambda x: op(x, 0)), [torch.randn([B0, B1])])
|
|
test(vmap(lambda x: op(x, -1)), [torch.randn([B0, B1])])
|
|
test(vmap(lambda x: op(x, -2)), [torch.randn([B1, 2, 5, B0, 3])], in_dims=2)
|
|
test(vmap(lambda x: op(x, 2), in_dims=2), [torch.randn([2, 5, B0, B1, 3])],
|
|
in_dims=2, out_dims=2)
|
|
|
|
def test_sum_dim(self):
|
|
self._test_mean_sum_dim(torch.sum)
|
|
|
|
def test_mean_dim(self):
|
|
self._test_mean_sum_dim(torch.mean)
|
|
|
|
def test_argmax_dim(self):
|
|
def test(f, args):
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}):
|
|
self.assertEqual(loop_out, batched_out)
|
|
B0 = 5
|
|
test(lambda x: torch.argmax(x), [torch.randn(B0)])
|
|
test(lambda x: torch.argmax(x), [torch.randn(B0, 2, 3)])
|
|
test(lambda x: torch.argmax(x, 0), [torch.randn(B0, 2, 3)])
|
|
test(lambda x: torch.argmax(x, -1), [torch.randn(B0, 2, 3)])
|
|
test(lambda x: torch.argmax(x, 2), [torch.randn(B0, 2, 3)])
|
|
|
|
def _test_sum_mean(self, op):
|
|
test = self._vmap_test
|
|
B0, B1 = 5, 7
|
|
|
|
# Single vmap, various in_dims / out_dims
|
|
test(op, [torch.randn([B0])])
|
|
test(op, [torch.randn([B0, 3])])
|
|
test(op, [torch.randn([2, 5, B0, 3])], in_dims=2)
|
|
test(op, [torch.randn([2, 5, B0, 3])], in_dims=2)
|
|
|
|
# Doubly nested vmap
|
|
test(vmap(op), [torch.randn([B0, B1])])
|
|
test(vmap(op), [torch.randn([B1, 2, 5, B0, 3])])
|
|
test(vmap(op), [torch.randn([2, 5, B0, B1, 3])], in_dims=2)
|
|
|
|
def test_sum(self):
|
|
self._test_sum_mean(torch.sum)
|
|
|
|
def test_mean(self):
|
|
self._test_sum_mean(torch.mean)
|
|
|
|
def test_repeat(self):
|
|
test = self._vmap_test
|
|
B0 = 7
|
|
op = Tensor.repeat
|
|
test(lambda x: op(x, (2, 3)), (torch.rand(B0, 1, 1),))
|
|
test(lambda x: op(x, (2, 3)), (torch.rand(1, B0, 1),), in_dims=1)
|
|
|
|
def test_slogdet(self):
|
|
test = functools.partial(self._vmap_test, check_propagates_grad=False)
|
|
B0 = 7
|
|
op = torch.linalg.slogdet
|
|
test(op, (torch.rand(B0, 1, 1),))
|
|
test(op, (torch.rand(B0, 2, 2),))
|
|
test(op, (torch.rand(B0, 3, 2, 2),))
|
|
test(op, (torch.rand(3, 2, 2, B0),), in_dims=3)
|
|
|
|
def test_reshape(self):
|
|
test = self._vmap_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
op = torch.reshape
|
|
test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None), check_view=True)
|
|
test(op, (torch.rand(2, B0, 5), [1, 1, 10]), in_dims=(1, None), check_view=False)
|
|
test(vmap(lambda t: t.reshape([-1])), (torch.rand(B0, B1, 2, 5),), check_view=True)
|
|
test(vmap(vmap(lambda t: t.reshape([-1]), in_dims=2), in_dims=1),
|
|
(torch.rand(3, B1, 2, B2, 5, B0),), in_dims=5, check_view=False)
|
|
|
|
def test_reshape_as(self):
|
|
test = self._vmap_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
op = torch.Tensor.reshape_as
|
|
test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)), check_view=True)
|
|
test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0), check_view=True)
|
|
test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None), check_view=True)
|
|
|
|
test(op, (torch.rand(2, B0, 5), torch.rand(1, 1, 10)), in_dims=(1, None), check_view=False)
|
|
|
|
test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)), check_view=True)
|
|
test(vmap(vmap(op, in_dims=(2, None)), in_dims=(1, None)),
|
|
(torch.rand(3, B1, 2, B2, 5, B0), torch.rand(B0, 3 * 2 * 5)),
|
|
in_dims=(5, 0), check_view=False)
|
|
|
|
def test_result_type(self):
|
|
def scalar_tensor_with_dtype(op):
|
|
def wrapped(*args, **kwargs):
|
|
dtype = op(*args, **kwargs)
|
|
return torch.ones([], dtype=dtype)
|
|
return wrapped
|
|
|
|
test = self._vmap_test
|
|
op = scalar_tensor_with_dtype(torch.result_type)
|
|
|
|
B0 = 2
|
|
|
|
test(op, (torch.randn(B0), torch.randn(B0, dtype=torch.float64)),
|
|
check_propagates_grad=False)
|
|
test(op, (torch.randn(B0), torch.randint(10, [B0], dtype=torch.int64)),
|
|
check_propagates_grad=False)
|
|
|
|
test(lambda x: op(x, 1), (torch.randn(B0),), check_propagates_grad=False)
|
|
test(lambda x: op(x, 1.6), (torch.randn(B0),), check_propagates_grad=False)
|
|
|
|
test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0),),
|
|
check_propagates_grad=False)
|
|
test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
|
|
(torch.randn(B0),), check_propagates_grad=False)
|
|
|
|
test(op, (torch.randn(B0, 2), torch.randn(B0, 2, dtype=torch.float64)),
|
|
check_propagates_grad=False)
|
|
test(op, (torch.randn(B0, 2), torch.randint(10, [B0, 2], dtype=torch.int64)),
|
|
check_propagates_grad=False)
|
|
|
|
test(lambda x: op(x, 1), (torch.randn(B0, 2),), check_propagates_grad=False)
|
|
test(lambda x: op(x, 1.6), (torch.randn(B0, 2),), check_propagates_grad=False)
|
|
|
|
test(lambda x: op(x, torch.tensor(1)), (torch.randn(B0, 2),),
|
|
check_propagates_grad=False)
|
|
test(lambda x: op(x, torch.tensor(1.6, dtype=torch.double)),
|
|
(torch.randn(B0, 2),), check_propagates_grad=False)
|
|
|
|
test(op, (torch.randn(B0, 2), torch.randn(B0, dtype=torch.float64)),
|
|
check_propagates_grad=False)
|
|
test(op, (torch.randn(B0, 2), torch.randint(10, [B0], dtype=torch.int64)),
|
|
check_propagates_grad=False)
|
|
|
|
def test_tensor_split(self):
|
|
test = self._vmap_view_test
|
|
op = torch.tensor_split
|
|
B0, B1, B2 = 7, 11, 13
|
|
|
|
# tests for torch.tensor_split(self, indices_or_sections: int, dim)
|
|
test(op, (torch.rand(B0, 2, 1024), 5, -1), in_dims=(0, None, None))
|
|
test(op, (torch.rand(2, B0, 1024), 150, 1), in_dims=(1, None, None))
|
|
test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0),
|
|
in_dims=(2, None, None))
|
|
test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
|
|
(torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
|
|
|
|
# tests for torch.tensor_split(self, indices_or_sections: List[int], dim)
|
|
test(op, (torch.rand(B0, 2, 1024), [50, 100, 378, 890], -1), in_dims=(0, None, None))
|
|
test(op, (torch.rand(2, B0, 1024), [50, 100, 212, 345, 0, 378, 890], 1), in_dims=(1, None, None))
|
|
test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [50, 100, 212, 345, 0, 378, 890], 0),
|
|
in_dims=(2, None, None))
|
|
test(vmap(vmap(lambda t: op(t, [4, 8, 9, 34, 29], 1), in_dims=2)),
|
|
(torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
|
|
|
|
def test_split(self):
|
|
test = self._vmap_view_test
|
|
op = torch.split
|
|
B0, B1, B2 = 7, 11, 13
|
|
|
|
# tests for torch.split(self, split_size: int, dim)
|
|
test(op, (torch.rand(B0, 2, 1024), 101, -1), in_dims=(0, None, None))
|
|
test(op, (torch.rand(2, B0, 1024), 130, 1), in_dims=(1, None, None))
|
|
test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), 256, 0),
|
|
in_dims=(2, None, None))
|
|
test(vmap(vmap(lambda t: op(t, 4, 1), in_dims=2)),
|
|
(torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
|
|
|
|
# tests for torch.split(self, split_size: List[int], dim)
|
|
test(op, (torch.rand(B0, 2, 1024), [1, 1020, 3], -1), in_dims=(0, None, None))
|
|
test(op, (torch.rand(2, B0, 1024), [100] * 10 + [24], 1), in_dims=(1, None, None))
|
|
test(vmap(op, in_dims=(0, None, None)), (torch.rand(B1, 1023, B0, 5), [256] * 3 + [255], 0),
|
|
in_dims=(2, None, None))
|
|
test(vmap(vmap(lambda t: op(t, [4] * 8 + [8] * 4, 1), in_dims=2)),
|
|
(torch.rand(B1, 2, B0, 64, B2),), in_dims=2)
|
|
|
|
def test_trace(self):
|
|
op = torch.trace
|
|
test = self._vmap_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
test(op, (torch.rand(B0, 2, 5),))
|
|
test(op, (torch.rand(2, B0, 5),), in_dims=1)
|
|
test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
|
|
test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
|
|
|
|
def test_transpose(self):
|
|
op = torch.transpose
|
|
test = self._vmap_view_test
|
|
|
|
B0, B1, B2 = 7, 11, 13
|
|
test(lambda x: op(x, 0, 1), (torch.rand(B0, 2, 5),))
|
|
test(lambda x: op(x, -1, -2), (torch.rand(B0, 2, 5),))
|
|
test(lambda x: op(x, 3, 1), (torch.rand(B0, 2, 5, 4, 6),))
|
|
test(lambda x: op(x, 1, 0), (torch.rand(2, B0, 5),), in_dims=1)
|
|
test(vmap(lambda x: op(x, 0, 1)), (torch.rand(B1, 2, B0, 5),), in_dims=2)
|
|
test(vmap(vmap(lambda x: op(x, 0, 1), in_dims=2)),
|
|
(torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
|
|
|
|
# Special case: scalar tensor
|
|
for dim1, dim2 in itertools.product([0, -1], [0, -1]):
|
|
x = torch.rand(B0)
|
|
result = vmap(lambda x: op(x, dim1, dim2))(x)
|
|
self.assertTrue(result is x)
|
|
|
|
def test_t(self):
|
|
op = torch.t
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
test(op, (torch.rand(B0, 2, 5),))
|
|
test(op, (torch.rand(2, B0, 5),), in_dims=1)
|
|
test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
|
|
test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 5, B2),), in_dims=2)
|
|
|
|
def test_T_numpy(self):
|
|
def op(t):
|
|
return t.T
|
|
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
test(op, (torch.rand(B0, 2, 3, 5),))
|
|
test(op, (torch.rand(B0),))
|
|
test(op, (torch.rand(2, B0, 3, 5),), in_dims=1)
|
|
test(vmap(op), (torch.rand(B1, 2, B0, 5),), in_dims=2)
|
|
test(vmap(op), (torch.rand(B1, 2, B0, 3, 5),), in_dims=2)
|
|
test(vmap(vmap(op, in_dims=2)), (torch.rand(B1, 2, B0, 3, B2, 5),), in_dims=2)
|
|
|
|
def test_to(self):
|
|
test = self._vmap_test
|
|
B0, B1 = 7, 11
|
|
|
|
test(lambda t: t.to('cpu'), (torch.rand(B0),))
|
|
test(lambda t: t.to(torch.double), (torch.rand(B0),))
|
|
test(lambda t, o: t.to(o), (torch.rand(B0), torch.randn(B0, dtype=torch.float64)))
|
|
test(lambda t, o: t.to(o),
|
|
(torch.rand(B0), torch.randn(B0, dtype=torch.float64)),
|
|
in_dims=(0, None))
|
|
test(vmap(lambda t: t.to(torch.double)), (torch.rand(B0, B1, 3),))
|
|
|
|
# also test some casting methods
|
|
test(lambda t: t.double(), (torch.rand(B0),))
|
|
test(lambda t: t.float(), (torch.rand(B0),))
|
|
test(lambda t: t.int(), (torch.rand(B0),), check_propagates_grad=False)
|
|
test(lambda t: t.long(), (torch.rand(B0),), check_propagates_grad=False)
|
|
|
|
def test_unfold(self):
|
|
op = torch.Tensor.unfold
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 3, 2, 5
|
|
|
|
test(op, (torch.rand(B0, 7, 11), 0, 2, 1), in_dims=(0, None, None, None))
|
|
test(op, (torch.rand(7, B0, 11), 1, 4, 2), in_dims=(1, None, None, None))
|
|
test(vmap(op, in_dims=(0, None, None, None)),
|
|
(torch.rand(B1, 7, B0, 11), 1, 5, 1), in_dims=(2, None, None, None))
|
|
test(vmap(vmap(op, in_dims=(2, None, None, None)), in_dims=(0, None, None, None)),
|
|
(torch.rand(B1, 7, B0, 11, B2), -1, 2, 4), in_dims=(2, None, None, None))
|
|
|
|
def test_unbind(self):
|
|
test = self._vmap_view_test
|
|
op = torch.unbind
|
|
B0, B1, B2 = 7, 11, 13
|
|
|
|
test(op, (torch.rand(B0, 2, 1024), -1), in_dims=(0, None))
|
|
test(op, (torch.rand(B0, 2, 0),))
|
|
test(op, (torch.rand(2, B0, 7), 0), in_dims=(1, None))
|
|
test(vmap(op, in_dims=(0, None)), (torch.rand(B1, 1023, B0, 5), 1),
|
|
in_dims=(2, None))
|
|
test(vmap(vmap(lambda t: op(t, dim=1), in_dims=2)),
|
|
(torch.rand(B1, 2, B0, 32, B2),), in_dims=2)
|
|
|
|
def test_view(self):
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
op = torch.Tensor.view
|
|
|
|
# We should error out if the view would produce an incorrect result
|
|
with self.assertRaises(RuntimeError):
|
|
vmap(op, in_dims=(1, None))(torch.rand(2, B0, 5), [10])
|
|
|
|
test(op, (torch.rand(B0, 2 * 5), [2, 5]), in_dims=(0, None))
|
|
test(op, (torch.rand(B0, 4, 5), [1, 2, 1, 10]), in_dims=(0, None))
|
|
test(vmap(lambda t: t.view([-1])), (torch.rand(B0, B1, 2, 5, 3),))
|
|
test(vmap(vmap(lambda t: t.reshape([-1])), in_dims=1),
|
|
(torch.rand(B2, B0, B1, 3, 2, 5),), in_dims=1)
|
|
|
|
def test_view_as(self):
|
|
test = self._vmap_view_test
|
|
B0, B1, B2 = 7, 11, 13
|
|
op = torch.Tensor.view_as
|
|
|
|
# We should error out if the view would produce an incorrect result
|
|
with self.assertRaises(RuntimeError):
|
|
vmap(op, in_dims=(1, 0))(torch.rand(2, B0, 5), torch.rand(B0, 10))
|
|
|
|
test(op, (torch.rand(B0, 2 * 5), torch.rand(B0, 2, 5)))
|
|
test(op, (torch.rand(2 * 5), torch.rand(B0, 2, 5)), in_dims=(None, 0))
|
|
test(op, (torch.rand(B0, 2 * 5), torch.rand(2, 5)), in_dims=(0, None))
|
|
|
|
test(op, (torch.rand(B0, 4, 5), torch.rand(2, 1, 1, 10)), in_dims=(0, None))
|
|
|
|
test(vmap(op), (torch.rand(B0, B1, 2, 5), torch.randn(B0, B1, 10)))
|
|
test(vmap(vmap(op, in_dims=(0, None)), in_dims=(0, None)),
|
|
(torch.rand(B1, B2, B0, 3, 2, 5), torch.rand(B0, 3 * 2 * 5)),
|
|
in_dims=(2, 0))
|
|
|
|
def test_conv2d(self):
|
|
conv_setups = [
|
|
(torch.nn.Conv1d, torch.conv1d, [2, 4, 15]),
|
|
(torch.nn.Conv2d, torch.conv2d, [2, 4, 15, 20]),
|
|
(torch.nn.Conv3d, torch.conv3d, [2, 4, 15, 20, 25]),
|
|
# (torch.nn.ConvTranspose2d, torch.conv_transpose2d, [2, 4, 15, 20])
|
|
]
|
|
for conv_mod, conv_fn, inp_shape in conv_setups:
|
|
mod = conv_mod(4, 8, kernel_size=3)
|
|
arg_values = [torch.randn(inp_shape), mod.weight, mod.bias]
|
|
kwarg_values = {}
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(conv_fn, arg_values, kwarg_values):
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
arg_values = [torch.randn(inp_shape), mod.weight, None]
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(conv_fn, arg_values, kwarg_values):
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
mod2 = conv_mod(4, 8, kernel_size=3, groups=2, stride=3, padding=1, dilation=2)
|
|
arg_values = [torch.randn(inp_shape), mod2.weight, mod2.bias]
|
|
kwarg_values = dict(groups=2, stride=3, padding=1, dilation=2)
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(conv_fn, arg_values, kwarg_values):
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
arg_values = [torch.randn(inp_shape), mod2.weight, None]
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(conv_fn, arg_values, kwarg_values):
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
def test_one_hot(self):
|
|
sample_inputs = [
|
|
(torch.randint(0, 3, []), 3),
|
|
(torch.randint(0, 3, [2, 3, 4]), 4),
|
|
]
|
|
for args in sample_inputs:
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(F.one_hot, args, {}):
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
def test_conj_bit(self):
|
|
x = torch.tensor([1 + 1j, 2 + 1j])
|
|
|
|
def foo(x):
|
|
assert not x.is_conj()
|
|
y = x.conj()
|
|
assert y.is_conj()
|
|
return y
|
|
res = vmap(foo)(x)
|
|
self.assertEqual(res, x.conj())
|
|
|
|
def test_mode_key(self):
|
|
def vmap_f(x):
|
|
return x + torch.randn(())
|
|
|
|
def naive_f(x, shape):
|
|
return x + torch.randn(shape)
|
|
|
|
torch.manual_seed(0)
|
|
out1 = vmap(vmap(vmap_f, randomness='different'), randomness='different')(torch.ones(2, 3))
|
|
|
|
torch.manual_seed(0)
|
|
out2 = naive_f(torch.ones(2, 3), (2, 3))
|
|
self.assertEqual(out1, out2)
|
|
|
|
torch.manual_seed(0)
|
|
out1 = vmap(vmap(vmap_f, randomness='different'), randomness='different')(torch.ones(2, 3, 4))
|
|
|
|
torch.manual_seed(0)
|
|
out2 = naive_f(torch.ones(2, 3, 4), (2, 3, 1))
|
|
self.assertEqual(out1, out2)
|
|
|
|
self.assertTrue(torch.randn(()).dim() == 0)
|
|
|
|
@parametrize('in_dim', [0, 1, 2])
|
|
@parametrize('out_dim', [0, 1, 2])
|
|
@parametrize('randomness', ['error', 'same'])
|
|
def test_chunk_vmap(self, in_dim, out_dim, randomness):
|
|
|
|
x = torch.randn(4, 5, 6)
|
|
|
|
def f(x):
|
|
y = x.sin()
|
|
if randomness != "error":
|
|
y = y + torch.rand_like(x)
|
|
return y
|
|
|
|
rs = torch.get_rng_state()
|
|
expected = vmap(f, in_dims=in_dim, out_dims=out_dim, randomness=randomness)(x)
|
|
|
|
for chunks in [1, 2, 3, 4, 7, 10, 16]:
|
|
torch.set_rng_state(rs)
|
|
output = chunk_vmap(
|
|
f, in_dims=in_dim, out_dims=out_dim, randomness=randomness, chunks=chunks
|
|
)(x)
|
|
self.assertEqual(output, expected)
|
|
|
|
|
|
instantiate_parametrized_tests(TestVmapOperators)
|
|
|
|
|
|
def construct_v(output, batch_size, contig=False):
|
|
if contig:
|
|
return torch.randn(batch_size, *output.shape,
|
|
dtype=output.dtype, device=output.device)
|
|
result = torch.randn(*output.shape, batch_size,
|
|
dtype=output.dtype, device=output.device)
|
|
return result.movedim(-1, 0)
|
|
|
|
def as_tuple(x):
|
|
if isinstance(x, tuple):
|
|
return x
|
|
elif isinstance(x, list):
|
|
return tuple(x)
|
|
else:
|
|
return x,
|
|
|
|
|
|
def differentiable(args):
|
|
return tuple(arg for arg in as_tuple(args)
|
|
if isinstance(arg, torch.Tensor) and arg.requires_grad)
|
|
|
|
|
|
def _get_rand_no_zeros(*args, **kwargs):
|
|
requires_grad = kwargs.get('requires_grad', False)
|
|
kwargs_without_requires_grad = kwargs.copy()
|
|
kwargs_without_requires_grad['requires_grad'] = False
|
|
result = torch.rand(*args, **kwargs_without_requires_grad)
|
|
return result.clamp_min_(0.1).requires_grad_(requires_grad)
|
|
|
|
|
|
class TestVmapBatchedGradient(Namespace.TestVmapBase):
|
|
def _vmap_test(self, *args, **kwargs):
|
|
return _vmap_test(self, *args, **kwargs)
|
|
|
|
# Tests batched gradient computation of outputs = op(*args, **kwargs)
|
|
# by comparing it to a sequential map+stack fallback.
|
|
#
|
|
# output_process_fn: a function that maps the outputs to the part
|
|
# that should be differentiated.
|
|
# batch_size: the batch dim size for the batched grad
|
|
def _batched_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
outputs = op(*args, **kwargs)
|
|
outputs = differentiable(output_process_fn(outputs))
|
|
for contig in [True, False]:
|
|
batched_vectors = tuple(construct_v(out, batch_size, contig)
|
|
for out in outputs)
|
|
|
|
def vector_jacobian_product(*vectors):
|
|
return torch.autograd.grad(outputs, differentiable(args), vectors,
|
|
retain_graph=True)
|
|
self._vmap_test(vector_jacobian_product, batched_vectors,
|
|
check_propagates_grad=False)
|
|
|
|
# Tests batched second grad computation of outputs = op(*args, **kwargs).
|
|
# by comparing it to a sequential map+stack fallback.
|
|
#
|
|
# output_process_fn: a function that maps the outputs to the part
|
|
# that should be differentiated.
|
|
# batch_size: the batch dim size for the batched grad
|
|
#
|
|
# NB: we only test computing batched gradients in the second gradient
|
|
# computation. One specific use case that does this is computing the hessian
|
|
# matrix of a scalar-valued function; this is useful in Bayesian Logistic
|
|
# Regression.
|
|
# It might be useful to have a test that computes batched first gradients and
|
|
# then uses those to compute batched second gradients in the future.
|
|
def _batched_grad_grad_test(self, op, args, kwargs=None, output_process_fn=lambda x: x, batch_size=3):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
outputs = op(*args, **kwargs)
|
|
outputs = differentiable(output_process_fn(outputs))
|
|
ones = tuple(torch.ones_like(out) for out in outputs)
|
|
# Same thing as summing together all of the outputs and calling .backward()
|
|
first_grads = torch.autograd.grad(outputs, differentiable(args), ones,
|
|
create_graph=True)
|
|
first_grads = differentiable(first_grads)
|
|
self.assertNotEqual(
|
|
len(first_grads), 0, "None of the first grads depend on the input!")
|
|
|
|
for contig in [True, False]:
|
|
batched_vectors = tuple(construct_v(grad, batch_size, contig)
|
|
for grad in first_grads)
|
|
|
|
def vector_hessian_product(*vectors):
|
|
outputs = torch.autograd.grad(first_grads, differentiable(args), vectors,
|
|
retain_graph=True, allow_unused=True)
|
|
outputs = tuple(out for out in outputs if out is not None)
|
|
assert len(outputs) > 0
|
|
return outputs
|
|
|
|
self._vmap_test(vector_hessian_product, batched_vectors,
|
|
check_propagates_grad=False)
|
|
|
|
def _test_arithmetic(self, op, device, test_grad_grad=True):
|
|
x = torch.randn(2, 3, requires_grad=True, device=device)
|
|
y = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
|
|
scalar = 3.14
|
|
self._batched_grad_test(op, (x, y))
|
|
self._batched_grad_test(op, (scalar, y))
|
|
self._batched_grad_test(op, (x, scalar))
|
|
|
|
if test_grad_grad:
|
|
self._batched_grad_grad_test(op, (x, y))
|
|
|
|
def test_add(self, device):
|
|
self._test_arithmetic(torch.add, device, test_grad_grad=False)
|
|
self._test_arithmetic(lambda x, y: x + y, device, test_grad_grad=False)
|
|
|
|
def test_sub(self, device):
|
|
self._test_arithmetic(torch.sub, device, test_grad_grad=False)
|
|
self._test_arithmetic(lambda x, y: x - y, device, test_grad_grad=False)
|
|
|
|
def test_mul(self, device):
|
|
self._test_arithmetic(torch.mul, device)
|
|
self._test_arithmetic(lambda x, y: x * y, device)
|
|
|
|
def test_div(self, device):
|
|
self._test_arithmetic(torch.div, device)
|
|
self._test_arithmetic(lambda x, y: x / y, device)
|
|
|
|
def test_binary_cross_entropy(self, device):
|
|
x = F.sigmoid(torch.randn(3, 2, device=device, requires_grad=True))
|
|
target = torch.rand(3, 2, device=device)
|
|
|
|
op = functools.partial(F.binary_cross_entropy, target=target)
|
|
|
|
self._batched_grad_test(op, (x,), {})
|
|
self._batched_grad_grad_test(op, (x,), {})
|
|
|
|
def test_log_softmax(self, device):
|
|
op = functools.partial(torch.log_softmax, dim=-1)
|
|
x = torch.randn(3, 2, device=device, requires_grad=True)
|
|
|
|
self._batched_grad_test(op, (x,), {})
|
|
self._batched_grad_grad_test(op, (x,), {})
|
|
|
|
def test_expand(self, device):
|
|
x = torch.randn(2, 3, device=device, requires_grad=True)
|
|
|
|
def op(x):
|
|
return x.expand(5, 5, 2, 3)
|
|
self._batched_grad_test(op, (x,))
|
|
|
|
@allowVmapFallbackUsage
|
|
def test_index(self, device):
|
|
x = torch.randn(2, 3, requires_grad=True, device=device)
|
|
index = torch.tensor([[0, 0], [1, 1]], device=device)
|
|
|
|
def op(x):
|
|
y = x * x
|
|
return y[index]
|
|
|
|
self._batched_grad_test(op, (x,))
|
|
self._batched_grad_grad_test(op, (x,))
|
|
|
|
def test_lgamma(self, device):
|
|
x = torch.randn(2, 3, requires_grad=True, device=device)
|
|
self._batched_grad_test(Tensor.lgamma, (x,))
|
|
self._batched_grad_grad_test(Tensor.lgamma, (x,))
|
|
|
|
def test_log(self, device):
|
|
x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
|
|
self._batched_grad_test(torch.log, (x,))
|
|
self._batched_grad_grad_test(torch.log, (x,))
|
|
|
|
def test_logsumexp(self, device):
|
|
x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
|
|
|
|
def op(x):
|
|
return torch.logsumexp(x, -1)
|
|
|
|
self._batched_grad_test(op, (x,))
|
|
self._batched_grad_grad_test(op, (x,))
|
|
|
|
def test_log1p(self, device):
|
|
x = _get_rand_no_zeros(2, 3, device=device, requires_grad=True)
|
|
self._batched_grad_test(torch.log1p, (x,))
|
|
self._batched_grad_grad_test(torch.log1p, (x,))
|
|
|
|
@allowVmapFallbackUsage
|
|
def test_max(self, device):
|
|
x = torch.randn(2, 3, requires_grad=True, device=device)
|
|
self._batched_grad_test(torch.max, (x,))
|
|
|
|
@allowVmapFallbackUsage
|
|
def test_median(self, device):
|
|
x = torch.randn(2, 3, requires_grad=True, device=device)
|
|
self._batched_grad_test(torch.median, (x,))
|
|
|
|
@allowVmapFallbackUsage
|
|
def test_min(self, device):
|
|
x = torch.randn(2, 3, requires_grad=True, device=device)
|
|
self._batched_grad_test(torch.min, (x,))
|
|
|
|
def test_permute(self, device):
|
|
x = torch.randn(2, 3, 5, requires_grad=True, device=device)
|
|
|
|
def op(x):
|
|
return x.permute(2, 0, 1)
|
|
|
|
self._batched_grad_test(op, (x,))
|
|
|
|
def test_reshape(self, device):
|
|
x = torch.randn(2, 3, 5, requires_grad=True, device=device)
|
|
|
|
def op(x):
|
|
return x.reshape([2 * 3, 5])
|
|
|
|
self._batched_grad_test(op, (x,))
|
|
|
|
def test_sigmoid(self, device):
|
|
x = torch.randn(2, 3, requires_grad=True, device=device)
|
|
self._batched_grad_test(Tensor.sigmoid, (x,))
|
|
self._batched_grad_grad_test(Tensor.sigmoid, (x,))
|
|
|
|
def test_stack(self, device):
|
|
x = torch.randn(2, 3, device=device, requires_grad=True)
|
|
y = torch.randn(2, 3, device=device, requires_grad=True)
|
|
|
|
def op(x, y):
|
|
return torch.stack([x, y])
|
|
self._batched_grad_test(op, (x, y))
|
|
|
|
def test_select(self, device):
|
|
x = torch.randn(2, 3, device=device, requires_grad=True)
|
|
self._batched_grad_test(lambda x: x[1], (x,))
|
|
self._batched_grad_test(lambda x: x.select(1, 2), (x,))
|
|
self._batched_grad_test(lambda x: x.select(-1, 0), (x,))
|
|
|
|
def test_slice(self, device):
|
|
x = torch.randn(2, 3, 5, device=device, requires_grad=True)
|
|
self._batched_grad_test(lambda x: x[0:1], (x,))
|
|
self._batched_grad_test(lambda x: x[:, 1:3], (x,))
|
|
self._batched_grad_test(lambda x: x[..., 1:3], (x,))
|
|
|
|
def test_trace(self, device):
|
|
x = torch.randn(2, 3, device=device, requires_grad=True)
|
|
self._batched_grad_test(Tensor.trace, (x,))
|
|
|
|
x = torch.randn(3, 2, 2, device=device)
|
|
|
|
def sum_grad_trace(x):
|
|
return grad(torch.trace)(x).sum()
|
|
|
|
output = vmap(grad(sum_grad_trace))(x)
|
|
self.assertEqual(output, torch.zeros_like(output))
|
|
|
|
def test_where(self, device):
|
|
x = torch.randn(3, 2, device=device)
|
|
y = torch.ones(3, 2, device=device)
|
|
|
|
def f(x, y):
|
|
return torch.where(x > 0, x, y)
|
|
|
|
# Check that there is no runtime error, exactness tests are done with opinfo
|
|
vmap(f)(x, y)
|
|
|
|
x = torch.randint(0, 2, size=(4, 3), dtype=torch.float)
|
|
|
|
def f(t):
|
|
return torch.where(t)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r"Attempted to vmap over aten::where"):
|
|
vmap(f)(x)
|
|
|
|
@skipCUDAIfNoMagma
|
|
@allowVmapFallbackUsage
|
|
def test_symeig(self, device):
|
|
def op(x):
|
|
return torch.symeig(x, eigenvectors=True)[0]
|
|
|
|
x = torch.randn(3, 3, device=device, requires_grad=True)
|
|
self._batched_grad_test(op, (x,), {})
|
|
self._batched_grad_grad_test(op, (x,), {})
|
|
|
|
def test_threshold(self, device):
|
|
x = torch.randn(2, 3, device=device, requires_grad=True)
|
|
self._batched_grad_test(lambda x: F.threshold(x, 0.5, 0.0), (x,))
|
|
|
|
@allowVmapFallbackUsage
|
|
def test_inplace_view(self, device):
|
|
leaf = torch.randn(4, 5, requires_grad=True)
|
|
|
|
def func(leaf):
|
|
# Make sure the function is non-trivially twice differentiable
|
|
base = leaf * leaf
|
|
view = base[0]
|
|
view.cos_()
|
|
return view
|
|
|
|
self._batched_grad_test(func, (leaf,), {})
|
|
self._batched_grad_grad_test(func, (leaf,), {})
|
|
|
|
@allowVmapFallbackUsage
|
|
def test_inplace_manyview(self, device):
|
|
leaf = torch.randn(4, 4, 5, requires_grad=True)
|
|
|
|
def func(leaf):
|
|
# Make sure the function is non-trivially twice differentiable
|
|
base = leaf * leaf
|
|
view = base.transpose(0, 2)
|
|
view = view[1]
|
|
view = view.diagonal()
|
|
view = view[::2]
|
|
view.cos_()
|
|
return view
|
|
|
|
self._batched_grad_test(func, (leaf,), {})
|
|
self._batched_grad_grad_test(func, (leaf,), {})
|
|
|
|
def test_diagonal(self, device):
|
|
x = torch.randn(4, 5, device=device, requires_grad=True)
|
|
self._batched_grad_test(lambda x: x.diagonal(1, 0, 1), (x,))
|
|
|
|
x = torch.randn(3, 4, 5, device=device, requires_grad=True)
|
|
self._batched_grad_test(lambda x: x.diagonal(0, -1, -2), (x,))
|
|
|
|
@allowVmapFallbackUsage
|
|
def test_unrelated_output(self, device):
|
|
B0 = 3
|
|
x = torch.randn([], requires_grad=True)
|
|
y = torch.randn([], requires_grad=True)
|
|
gy = torch.randn(B0, requires_grad=True)
|
|
|
|
def vjp(v):
|
|
res, = torch.autograd.grad(y, x, v, allow_unused=True)
|
|
return torch.zeros_like(x) if res is None else res
|
|
|
|
result = vmap(vjp)(gy)
|
|
self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
|
|
|
|
@allowVmapFallbackUsage
|
|
def test_unrelated_output_multiple_grad(self, device):
|
|
B0 = 3
|
|
x = torch.randn([], requires_grad=True)
|
|
y = torch.randn([], requires_grad=True)
|
|
gy = torch.randn(B0, requires_grad=True)
|
|
|
|
def vjp(v):
|
|
res, = torch.autograd.grad(y, x, v, allow_unused=True)
|
|
return torch.zeros_like(x) if res is None else res
|
|
|
|
_ = vjp(gy[0])
|
|
result = vmap(vjp)(gy)
|
|
self.assertEqual(result, torch.zeros(B0, *x.shape, device=device))
|
|
|
|
|
|
def discover_variants(opinfo):
|
|
aliases = []
|
|
inplace_variants = []
|
|
|
|
if opinfo.inplace_variant:
|
|
inplace_variants.append(opinfo.inplace_variant)
|
|
|
|
aliases.append(opinfo.op)
|
|
for alias in opinfo.aliases:
|
|
aliases.append(alias.op)
|
|
if alias.inplace_variant:
|
|
inplace_variants.append(alias.inplace_variant)
|
|
return aliases, inplace_variants
|
|
|
|
|
|
class TestVmapOperatorsOpInfo(TestCase):
|
|
|
|
def vmap_outplace_test(self, func, args, kwargs, in_dims, check_shape_only=False,
|
|
postprocess_fn=None):
|
|
for loop_out, vmap_out in compute_quantities_for_vmap_test(func, args, kwargs, in_dims):
|
|
if postprocess_fn is not None:
|
|
loop_out = postprocess_fn(loop_out)
|
|
vmap_out = postprocess_fn(vmap_out)
|
|
if check_shape_only:
|
|
self.assertEqual(vmap_out.shape, loop_out.shape)
|
|
continue
|
|
self.assertEqual(vmap_out, loop_out)
|
|
|
|
def vmap_inplace_test(self, func, args, kwargs, in_dims, postprocess_fn=None):
|
|
# NB: This test assumes that the first argument is being modified.
|
|
# This is OK because it's what every other OpInfo-based test assumes,
|
|
# but it is going to need a more robust solution eventually.
|
|
if in_dims[0] is None:
|
|
# Check that we correctly raise an error when vmap is impossible
|
|
# on the in-place operation
|
|
with self.assertRaises(RuntimeError):
|
|
for _ in compute_quantities_for_vmap_test(
|
|
func, args, kwargs, in_dims, compute_loop_out=False, clone_inputs=True):
|
|
pass
|
|
return
|
|
for loop_out, vmap_out in compute_quantities_for_vmap_test(
|
|
func, args, kwargs, in_dims, clone_inputs=True):
|
|
if postprocess_fn is not None:
|
|
loop_out = postprocess_fn(loop_out)
|
|
vmap_out = postprocess_fn(vmap_out)
|
|
self.assertEqual(vmap_out, loop_out)
|
|
|
|
def opinfo_vmap_test(self, device, dtype, op, check_has_batch_rule,
|
|
skip_inplace=(), postprocess_fn=None):
|
|
def test():
|
|
# Error inputs check
|
|
if op.error_inputs_func is not None:
|
|
error_inputs = op.error_inputs(device)
|
|
for error_input in error_inputs:
|
|
sample_input = error_input.sample_input
|
|
args = (sample_input.input,) + tuple(sample_input.args)
|
|
kwargs = sample_input.kwargs
|
|
for args, in_dims, _ in generate_vmap_inputs(args, {}):
|
|
with self.assertRaises(Exception):
|
|
vmap(op, in_dims)(*args, **kwargs)
|
|
|
|
# Sample inputs check
|
|
sample_inputs_itr = op.sample_inputs(device, dtype, requires_grad=False)
|
|
aliases, inplace_aliases = discover_variants(op)
|
|
check_shape_only = op.name in ('empty_like', 'new_empty')
|
|
for sample_input in sample_inputs_itr:
|
|
args = (sample_input.input,) + sample_input.args
|
|
kwargs = sample_input.kwargs
|
|
is_batch_norm_and_training = is_batch_norm_training(op.name, kwargs)
|
|
for args, in_dims, _ in generate_vmap_inputs(
|
|
args, {}, is_batch_norm_and_training=is_batch_norm_and_training):
|
|
for func in aliases:
|
|
self.vmap_outplace_test(func, args, kwargs, in_dims, check_shape_only, postprocess_fn)
|
|
if op.name in skip_inplace:
|
|
continue
|
|
if not is_valid_inplace_sample_input(sample_input, op, op.inplace_variant):
|
|
continue
|
|
for func in inplace_aliases:
|
|
self.vmap_inplace_test(func, args, kwargs, in_dims, postprocess_fn)
|
|
|
|
if check_has_batch_rule:
|
|
check_vmap_fallback(self, test, op)
|
|
else:
|
|
test()
|
|
|
|
vmap_fail = {
|
|
# -------------------- ALLOWED FAILURES --------------------------------
|
|
# These are things that we either cannot fix or are not actually problems
|
|
xfail('resize_'),
|
|
xfail('resize_as_'),
|
|
xfail('to_sparse'),
|
|
xfail('__getitem__'), # dynamic mask
|
|
xfail('index_put'), # dynamic mask
|
|
xfail('nn.functional.dropout'), # works, can't check against for loop because of randomness inconsistency
|
|
xfail('nn.functional._scaled_dot_product_attention'), # randomness
|
|
xfail('masked_select'), # dynamic op
|
|
xfail('nonzero'), # dynamic op
|
|
xfail('unique', ''), # dynamic op
|
|
xfail('unique_consecutive', ''), # dynamic op
|
|
xfail('allclose'), # returns a boolean
|
|
xfail('uniform'), # randomness is tested separately
|
|
xfail('rand_like'), # randomness is tested separately
|
|
xfail('randint_like'), # randomness is tested separately
|
|
xfail('randn_like'), # randomness is tested separately
|
|
xfail('bernoulli', ''), # randomness is tested separately
|
|
xfail('normal', ''), # randomness is tested separately
|
|
xfail('normal', 'number_mean'), # randomness is tested separately
|
|
xfail('multinomial', ''), # randomness
|
|
xfail('nn.functional.embedding', ''), # we only support some cases
|
|
xfail('nn.functional.rrelu'), # randomness
|
|
xfail('nn.functional.dropout2d', ''), # randomness
|
|
xfail('nn.functional.dropout3d', ''), # randomness
|
|
xfail('nn.functional.feature_alpha_dropout', 'with_train'), # randomness
|
|
xfail('as_strided'), # Our test runner can't handle this; manual test exists
|
|
skip('new_empty_strided'), # empty tensor data is garbage so it's hard to make comparisons with it
|
|
xfail('nn.functional.fractional_max_pool3d'), # randomness
|
|
xfail('nn.functional.fractional_max_pool2d'), # randomness
|
|
xfail('pca_lowrank', ''), # random operation
|
|
xfail('svd_lowrank', ''), # random operation
|
|
xfail('linspace', ''), # test runner can't handle factory functions
|
|
xfail('arange', ''), # test runner can't handle factory functions
|
|
xfail('logspace', ''), # test runner can't handle factory functions
|
|
xfail('empty', ''), # test runner can't handle factory functions
|
|
xfail('ones', ''), # test runner can't handle factory functions
|
|
xfail('zeros', ''), # test runner can't handle factory functions
|
|
xfail('eye', ''), # non-tensor input
|
|
xfail('broadcast_shapes', ''), # test runner can't handle non-Tensor ops
|
|
xfail('sparse.sampled_addmm'), # sparse
|
|
xfail('cross'), # The default value of dim in op is *very* weird. No wonder it doesn't work
|
|
skip('linalg.eigh', ''), # not unique, see test_linalg_eigh for manual test
|
|
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
|
# ----------------------------------------------------------------------
|
|
|
|
# ---------------------------- BUGS ------------------------------------
|
|
# entries in here don't work and need to be fixed.
|
|
# Each one of these is a bug
|
|
xfail('clamp_min', ''), # Exception not raised on error input
|
|
xfail('clamp_max', ''), # Exception not raised on error input
|
|
|
|
xfail('view_as_complex'), # RuntimeError: Tensor must have a last dimension with stride 1
|
|
xfail('tensor_split'), # data_ptr
|
|
xfail('histogramdd'), # expected Tensor as element 0 in argument 0, but got tuple
|
|
xfail('nn.functional.gaussian_nll_loss'), # data-dependent control flow error
|
|
xfail('nn.functional.embedding_bag'), # embedding renorm vmap inplace incompatible
|
|
xfail('__rpow__'), # https://github.com/pytorch/functorch/issues/617
|
|
xfail('column_stack', ''), # Batching rule not implemented for aten::column_stack
|
|
xfail('narrow'), # Batching rule not implemented for aten::narrow.Tensor
|
|
|
|
# required rank 4 tensor to use channels_last format
|
|
xfail('bfloat16'),
|
|
xfail('bool'),
|
|
xfail('byte'),
|
|
xfail('char'),
|
|
xfail('double'),
|
|
xfail('float'),
|
|
xfail('half'),
|
|
xfail('int'),
|
|
xfail('long'),
|
|
xfail('short'),
|
|
xfail('cdouble'),
|
|
xfail('cfloat'),
|
|
|
|
xfail('jiterator_binary', device_type='cuda'), # NYI: querying is_contiguous inside of vmap
|
|
xfail('jiterator_binary_return_by_ref', device_type='cuda'), # NYI: querying is_contiguous inside of vmap
|
|
xfail('jiterator_4inputs_with_extra_args', device_type='cuda'), # NYI: querying is_contiguous inside of vmap
|
|
xfail('equal', ''), # TypeError: object of type 'bool' has no len(); likely testrunner problem
|
|
xfail('jiterator_unary', device_type='cuda'), # NYI: querying is_contiguous inside of vmap
|
|
xfail('jiterator_2inputs_2outputs', device_type='cuda'), # NYI: querying is_contiguous inside of vmap
|
|
# ---------------------------------------------------------------------
|
|
}
|
|
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
|
|
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', (
|
|
tol1('linalg.det',
|
|
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
|
|
# The following is often flaky, but just on windows.
|
|
# We should investigate if it's actually a problem or not.
|
|
tol1('nn.functional.conv_transpose3d',
|
|
{torch.float32: tol(atol=1e-04, rtol=1e-02)}, device_type='cuda'),
|
|
))
|
|
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
|
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
|
|
xfail('native_batch_norm'),
|
|
# The error inputs are vectors, that pass when batched as they are treated as a matrix
|
|
xfail('trace'),
|
|
}))
|
|
def test_vmap_exhaustive(self, device, dtype, op):
|
|
# needs to be fixed
|
|
inplace_failure_list = (
|
|
)
|
|
self.opinfo_vmap_test(device, dtype, op, check_has_batch_rule=False,
|
|
skip_inplace=inplace_failure_list)
|
|
|
|
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
|
|
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', (
|
|
tol1('linalg.det',
|
|
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
|
|
))
|
|
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
|
@skipOps('TestVmapOperatorsOpInfo', 'test_op_has_batch_rule', vmap_fail.union({
|
|
skip('to'), # RuntimeError: required rank 4 tensor to use channels_last format
|
|
xfail('complex'),
|
|
xfail('copysign'),
|
|
xfail('native_batch_norm'),
|
|
xfail('histogram'),
|
|
xfail('index_fill'),
|
|
xfail('nansum'),
|
|
xfail('nanmean'),
|
|
xfail('scatter_reduce', 'sum'),
|
|
xfail('scatter_reduce', 'mean'),
|
|
xfail('scatter_reduce', 'amax'),
|
|
xfail('scatter_reduce', 'amin'),
|
|
# `index_put` OpInfo in pytorch/pytorch has
|
|
# masked index as input which is not supported
|
|
xfail('index_put', ''),
|
|
xfail('isin'),
|
|
xfail('lu_unpack'),
|
|
xfail('masked_fill'),
|
|
xfail('masked_scatter'),
|
|
xfail('masked_select'),
|
|
xfail('nanquantile'),
|
|
xfail('ormqr'),
|
|
xfail('put'),
|
|
xfail('quantile'),
|
|
xfail('renorm'),
|
|
xfail('resize_as_'),
|
|
xfail('take'),
|
|
xfail('tensor_split'),
|
|
xfail('to_sparse'),
|
|
xfail('vdot'),
|
|
xfail('__getitem__', ''),
|
|
xfail('all'),
|
|
xfail('any'),
|
|
xfail('count_nonzero'),
|
|
xfail('nanmean'),
|
|
xfail('nn.functional.dropout'), # works, can't check against for loop because of randomness inconsistency
|
|
xfail('nn.functional._scaled_dot_product_attention'), # randomness
|
|
xfail('resize_'),
|
|
xfail('view_as_complex'),
|
|
xfail('matrix_exp'),
|
|
xfail('trace'), # Does not support batched tensors
|
|
xfail('bucketize'),
|
|
xfail('fft.ihfft2'),
|
|
xfail('fft.ihfftn'),
|
|
xfail('allclose'),
|
|
xfail('argwhere'),
|
|
xfail('unique_consecutive'),
|
|
xfail('unique'),
|
|
xfail('nn.functional.ctc_loss'),
|
|
xfail('nn.functional.gaussian_nll_loss'),
|
|
xfail('nn.functional.huber_loss'),
|
|
# We can get this to work on CUDA through decomposition,
|
|
# but fails on CPU due to max_pool1d_cpu not having a derivative
|
|
xfail('nn.functional.max_pool1d'),
|
|
xfail('nn.functional.max_pool3d'),
|
|
xfail('histc'),
|
|
xfail('as_strided'),
|
|
xfail('istft'),
|
|
xfail('nonzero'),
|
|
xfail('nn.functional.fractional_max_pool2d'),
|
|
xfail('stft'),
|
|
xfail('isclose'),
|
|
xfail('nn.functional.fractional_max_pool3d'),
|
|
xfail('nn.functional.bilinear'),
|
|
xfail('nn.functional.embedding_bag'),
|
|
xfail('linalg.tensorsolve'),
|
|
xfail('bernoulli', ''),
|
|
xfail('linalg.lu_factor', ''),
|
|
xfail('nn.functional.feature_alpha_dropout', 'with_train'),
|
|
xfail('nn.functional.kl_div', ''),
|
|
xfail('multinomial', ''),
|
|
xfail('column_stack', ''),
|
|
xfail('pca_lowrank', ''),
|
|
xfail('normal', ''),
|
|
xfail('nn.functional.dropout2d', ''),
|
|
xfail('normal', 'number_mean'),
|
|
xfail('svd_lowrank', ''),
|
|
xfail('diagflat', ''),
|
|
xfail('special.log_ndtr'),
|
|
xfail('narrow'), # Batching rule not implemented for aten::narrow.Tensor
|
|
xfail('nn.functional.triplet_margin_loss', ''),
|
|
xfail('nn.functional.pdist', ''),
|
|
xfail('scatter_reduce', 'sum'),
|
|
xfail('nn.functional.smooth_l1_loss', ''),
|
|
xfail('scatter_reduce', 'amax'),
|
|
xfail('nn.functional.max_unpool1d', 'grad'),
|
|
xfail('nn.functional.multi_margin_loss', ''),
|
|
xfail('scatter_reduce', 'prod'),
|
|
xfail('nn.functional.multilabel_margin_loss', ''),
|
|
xfail('scatter_reduce', 'amin'),
|
|
xfail('nn.functional.max_unpool3d', 'grad'),
|
|
xfail('nn.functional.max_unpool2d', ''),
|
|
xfail('nn.functional.max_unpool2d', 'grad'),
|
|
xfail('nn.functional.margin_ranking_loss', ''),
|
|
xfail('nn.functional.max_unpool1d', ''),
|
|
xfail('nn.functional.soft_margin_loss', ''),
|
|
xfail('scatter_reduce', 'mean'),
|
|
xfail('nn.functional.max_unpool3d', ''),
|
|
xfail('linalg.ldl_solve', '', device_type='cpu'),
|
|
xfail('chalf', ''),
|
|
xfail('arange', ''),
|
|
xfail('clamp_max', ''),
|
|
xfail('jiterator_binary_return_by_ref', device_type='cuda'),
|
|
xfail('special.spherical_bessel_j0'),
|
|
xfail('jiterator_unary', device_type='cuda'),
|
|
xfail('jiterator_2inputs_2outputs', device_type='cuda'),
|
|
xfail('special.airy_ai'),
|
|
xfail('clamp_min', ''),
|
|
xfail('special.bessel_j0'),
|
|
xfail('sparse.sampled_addmm'),
|
|
xfail('special.bessel_y0'),
|
|
xfail('special.chebyshev_polynomial_u'),
|
|
xfail('special.modified_bessel_k1'),
|
|
xfail('segment_reduce', 'offsets'),
|
|
xfail('special.bessel_j1'),
|
|
xfail('logspace', ''),
|
|
xfail('empty', ''),
|
|
xfail('index_reduce', ''),
|
|
xfail('linspace', ''),
|
|
xfail('special.laguerre_polynomial_l'),
|
|
xfail('special.hermite_polynomial_h'),
|
|
xfail('jiterator_binary', device_type='cuda'),
|
|
xfail('special.modified_bessel_i0'),
|
|
xfail('jiterator_4inputs_with_extra_args', device_type='cuda'),
|
|
xfail('linalg.vander', ''),
|
|
xfail('segment_reduce', 'lengths'),
|
|
xfail('lu_solve', ''),
|
|
xfail('special.bessel_y1'),
|
|
xfail('special.hermite_polynomial_he'),
|
|
xfail('special.scaled_modified_bessel_k0'),
|
|
xfail('nn.functional.dropout3d', ''),
|
|
xfail('special.scaled_modified_bessel_k1'),
|
|
xfail('broadcast_shapes', ''),
|
|
xfail('special.modified_bessel_k0'),
|
|
xfail('linalg.vecdot', ''),
|
|
xfail('linalg.ldl_factor', ''),
|
|
xfail('special.modified_bessel_i1'),
|
|
xfail('special.chebyshev_polynomial_t'),
|
|
xfail('as_strided_scatter', ''),
|
|
xfail('equal', ''),
|
|
xfail('linalg.lu', ''),
|
|
skip('linalg.ldl_solve', ''),
|
|
}))
|
|
def test_op_has_batch_rule(self, device, dtype, op):
|
|
# needs to be fixed
|
|
inplace_failures = (
|
|
'abs',
|
|
'acos',
|
|
'acosh',
|
|
'addbmm',
|
|
'addcdiv',
|
|
'addcmul',
|
|
'addmm',
|
|
'addmv',
|
|
'addr',
|
|
'asin',
|
|
'asinh',
|
|
'atan2',
|
|
'atan',
|
|
'atanh',
|
|
'baddbmm',
|
|
'clamp',
|
|
'conj_physical',
|
|
'cumprod',
|
|
'cumsum',
|
|
'div',
|
|
'div',
|
|
'floor_divide',
|
|
'fmod',
|
|
'heaviside',
|
|
'hypot',
|
|
'igamma',
|
|
'igammac',
|
|
'index_add',
|
|
'index_copy',
|
|
'ldexp',
|
|
'lerp',
|
|
'neg',
|
|
'nextafter',
|
|
'polygamma',
|
|
'pow',
|
|
'remainder',
|
|
'scatter_add',
|
|
'scatter',
|
|
'square',
|
|
'sub',
|
|
'tril',
|
|
'triu',
|
|
'trunc',
|
|
'xlogy',
|
|
)
|
|
self.opinfo_vmap_test(device, dtype, op, check_has_batch_rule=True,
|
|
skip_inplace=inplace_failures)
|
|
|
|
def test_linalg_svd(self, device):
|
|
# linalg_svd returns a tuple of three tensors, (U, S, Vh).
|
|
# Given the same input, it may return different tensors,
|
|
# because svd isn't unique. To test that the svd is correct, we multiply
|
|
# U @ diag(S) @ Vh and check that that the output from vmap matches the
|
|
# output from a for-loop.
|
|
def compute_A(out):
|
|
U, S, Vh = out
|
|
m = U.shape[-1]
|
|
n = Vh.shape[-2]
|
|
diag_S = S.new_zeros(*S.shape[:-1], m, n)
|
|
diag_S.diagonal(offset=0, dim1=-2, dim2=-1).copy_(S)
|
|
return U @ diag_S @ Vh
|
|
|
|
opinfos = [op for op in op_db if op.name == 'linalg.svd']
|
|
assert len(opinfos) > 0
|
|
|
|
for op in opinfos:
|
|
self.opinfo_vmap_test(device, torch.float, op, check_has_batch_rule=True,
|
|
postprocess_fn=compute_A)
|
|
|
|
def test_linalg_eigh(self, device):
|
|
# linalg_svd returns two tensors, (Q, L).
|
|
# Given the same input, it may return different tensors,
|
|
# because the eig decomposition isn't unique.
|
|
# To test that eigh is correct, we multiply
|
|
# Q @ diag(L) @ Qh and check that that the output from vmap matches the
|
|
# output from a for-loop.
|
|
def compute_A(out):
|
|
L, Q = out
|
|
n = Q.shape[-1]
|
|
diag_L = L.new_zeros(*L.shape[:-1], n, n)
|
|
diag_L.diagonal(offset=0, dim1=-2, dim2=-1).copy_(L)
|
|
Qh = Q.transpose(-2, -1).conj()
|
|
return Q @ diag_L @ Qh
|
|
|
|
opinfos = [op for op in op_db if op.name == 'linalg.eigh']
|
|
assert len(opinfos) > 0
|
|
|
|
for op in opinfos:
|
|
self.opinfo_vmap_test(device, torch.float, op, check_has_batch_rule=False,
|
|
postprocess_fn=compute_A)
|
|
|
|
def test_slogdet(self, device):
|
|
# There's no OpInfo for this
|
|
def test():
|
|
B = 2
|
|
x = torch.randn(2, 5, 5, device=device)
|
|
self.vmap_outplace_test(torch.slogdet, (x,), {}, (0,))
|
|
|
|
check_vmap_fallback(self, test, torch.slogdet)
|
|
|
|
def test_fill__Tensor(self, device):
|
|
# There's no OpInfo for fill_.Tensor, so here's an extra test for it.
|
|
def test():
|
|
B = 2
|
|
args = (torch.randn(B, 3, device=device), torch.randn(B))
|
|
self.vmap_inplace_test(Tensor.fill_, args, {}, (0, 0))
|
|
|
|
args = (torch.randn(3, B, device=device), torch.randn(B))
|
|
self.vmap_inplace_test(Tensor.fill_, args, {}, (-1, 0))
|
|
|
|
args = (torch.randn(3, device=device), torch.randn(B))
|
|
self.vmap_inplace_test(Tensor.fill_, args, {}, (None, 0))
|
|
|
|
args = (torch.randn(3, B, device=device), torch.randn([]))
|
|
self.vmap_inplace_test(Tensor.fill_, args, {}, (1, None))
|
|
|
|
check_vmap_fallback(self, test, Tensor.fill_)
|
|
|
|
def test_conv_double_backward(self, device):
|
|
images = torch.randn(2, 1, 5, 5, device=device)
|
|
weight = torch.randn(2, 1, 2, 2, device=device)
|
|
bias = torch.randn(2, device=device)
|
|
ggI = torch.randn_like(images)
|
|
ggW = torch.randn_like(weight)
|
|
ggb = torch.randn_like(bias)
|
|
stride = (1, 1)
|
|
padding = (0, 0)
|
|
dilation = (1, 1)
|
|
transposed = False
|
|
output_padding = (0, 0)
|
|
groups = 1
|
|
output_mask = (True, True, True)
|
|
gO = torch.randn_like(F.conv2d(images, weight, bias, stride, padding, dilation, groups))
|
|
|
|
args = (
|
|
ggI, ggW, ggb, gO, weight, images, stride, padding, dilation,
|
|
transposed, output_padding, groups, output_mask,
|
|
)
|
|
op = torch.ops.aten._convolution_double_backward
|
|
|
|
generator = get_fallback_and_vmap_exhaustive(op, args, {})
|
|
is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6)
|
|
atol, rtol = (1e-3, 1e-3) if is_cuda_sm86 else (1e-4, 1e-4)
|
|
|
|
def test():
|
|
for loop_out, batched_out in generator:
|
|
self.assertEqual(loop_out, batched_out, atol=atol, rtol=rtol)
|
|
|
|
check_vmap_fallback(self, test, op)
|
|
|
|
def test_isnan(self, device):
|
|
test = functools.partial(_vmap_test, check_propagates_grad=False)
|
|
|
|
B, N, C, H, W = 2, 3, 24, 5, 7
|
|
op = torch.isnan
|
|
|
|
x = torch.randn(B, N, C, H, W)
|
|
x[x > 0] = float('nan')
|
|
test(self, op, (x,), in_dims=(0))
|
|
|
|
def test_isinf(self, device):
|
|
test = functools.partial(_vmap_test, check_propagates_grad=False)
|
|
|
|
B, N, C, H, W = 2, 3, 24, 5, 7
|
|
op = torch.isinf
|
|
|
|
x = torch.randn(B, N, C, H, W)
|
|
x[x > 0] = float('inf')
|
|
test(self, op, (x,), in_dims=(0))
|
|
|
|
def test_foo_like(self, device):
|
|
# vfdev-5: Probably, we can remove this line. Flake8 reported as unused
|
|
# test = functools.partial(_vmap_test, check_propagates_grad=False)
|
|
|
|
B, N, C, H, W = 2, 3, 24, 5, 7
|
|
for op in [torch.ones_like, torch.zeros_like]:
|
|
x = torch.randn(B, N, C, H, W)
|
|
# todo(chilli): test these better
|
|
# Not testing correctness, just that they run
|
|
vmap(op, in_dims=(0,))(x,)
|
|
|
|
def test_flatten(self, device):
|
|
test = functools.partial(_vmap_test, check_propagates_grad=False)
|
|
|
|
op = torch.flatten
|
|
|
|
x = torch.randn(2, 3, 4, 5)
|
|
test(self, op, (x, 1, 2), in_dims=(0, None, None))
|
|
|
|
def test_group_norm(self, device):
|
|
test = functools.partial(_vmap_test, check_propagates_grad=False)
|
|
|
|
B, N, C, H, W = 2, 3, 24, 5, 7
|
|
op = F.group_norm
|
|
|
|
x = torch.randn(B, N, C, H, W)
|
|
weight = torch.randn(C)
|
|
bias = torch.randn(C)
|
|
test(self, op, (x, 3, weight, bias), in_dims=(0, None, None, None))
|
|
|
|
x = torch.randn(B, N, C, H, W)
|
|
weight = torch.randn(B, C)
|
|
bias = torch.randn(B, C)
|
|
test(self, op, (x, 4, weight, bias), in_dims=(0, None, 0, 0))
|
|
|
|
def test_index_put(self, device):
|
|
def test(f, t, idx, values):
|
|
base = f(t[0], idx[0], values[0])
|
|
self.assertEqual(vmap(f, in_dims=(0, 0, 0))(t, idx, values)[0], base)
|
|
self.assertEqual(vmap(f, in_dims=(0, None, None))(t, idx[0], values[0])[0], base)
|
|
self.assertEqual(vmap(f, in_dims=(0, None, 0))(t, idx[0], values)[0], base)
|
|
self.assertEqual(vmap(f, in_dims=(0, 0, None))(t, idx, values[0])[0], base)
|
|
|
|
def f(x, y, z):
|
|
x[y] = z
|
|
return x
|
|
|
|
x = torch.randn(3, 4, 5, device=device)
|
|
y = torch.zeros((3, 2), device=device).long()
|
|
z = torch.randn(3, 2, 5, device=device)
|
|
test(f, x, y, z)
|
|
|
|
# indexing innermost dim
|
|
def f(t, idx, values):
|
|
t[:, idx] = values
|
|
return t
|
|
|
|
t = torch.zeros((3, 2, 3))
|
|
values = torch.ones((3, 1, 2))
|
|
idx = torch.tensor([[1, 2]]).expand((3, 2))
|
|
test(f, t, idx, values)
|
|
|
|
# indexing middle dim
|
|
def f(t, idx, values):
|
|
t[:, idx, :] = values
|
|
return t
|
|
|
|
t = torch.zeros((3, 2, 3, 3))
|
|
values = torch.ones((3, 1, 2, 3))
|
|
idx = torch.tensor([[0, 2]]).expand((3, 2))
|
|
test(f, t, idx, values)
|
|
|
|
# indexing with slices
|
|
def f(t, values):
|
|
t[:, :2, :] = values
|
|
return t
|
|
|
|
base = f(t[0], values[0])
|
|
self.assertEqual(vmap(f, in_dims=(0, 0))(t, values)[0], base)
|
|
self.assertEqual(vmap(f, in_dims=(0, None))(t, values[0])[0], base)
|
|
|
|
# index_put_
|
|
tensor = torch.zeros(3, 3, 4)
|
|
value = torch.ones(3, 2)
|
|
idxs = (torch.tensor([[0], [1], [2]]), torch.tensor([[0]]), torch.tensor([1, 2]))
|
|
expected = torch.index_put_(tensor.clone(), idxs, value)
|
|
|
|
def f(t, idx, v):
|
|
torch.index_put_(t, idx, v)
|
|
return t
|
|
|
|
self.assertEqual(vmap(f, in_dims=(0, (None, None), 0))(tensor, idxs[1:], value), expected)
|
|
self.assertEqual(vmap(f, in_dims=(0, (None, None), None))(tensor, idxs[1:], value[0]), expected)
|
|
|
|
@parametrize('training', [True, False])
|
|
@parametrize('track_running_stats', [True, False])
|
|
@parametrize('affine', [True, False])
|
|
def test_batch_norm(self, device, affine, track_running_stats, training):
|
|
if not track_running_stats and not training:
|
|
return
|
|
|
|
test = functools.partial(_vmap_test, check_propagates_grad=False)
|
|
BN = torch.nn.BatchNorm2d
|
|
ensemble_size = 10
|
|
hidden_dim = 3
|
|
|
|
weights, buffers, _, _, _ = \
|
|
functional_init_with_buffers(BN, [ensemble_size])(
|
|
hidden_dim, affine=affine, track_running_stats=track_running_stats)
|
|
|
|
inputs = [torch.randn(ensemble_size, 32, hidden_dim, 16, 16, device=device)]
|
|
in_dims = [0]
|
|
|
|
def append(inp, in_dim):
|
|
inputs.append(inp)
|
|
in_dims.append(in_dim)
|
|
|
|
if track_running_stats:
|
|
running_mean, running_var, _ = buffers
|
|
append(running_mean.to(device), 0)
|
|
append(running_var.to(device), 0)
|
|
else:
|
|
append(None, None)
|
|
append(None, None)
|
|
|
|
if affine:
|
|
weight, bias = weights
|
|
append(weight.to(device), 0)
|
|
append(bias.to(device), 0)
|
|
else:
|
|
append(None, None)
|
|
append(None, None)
|
|
|
|
append(training, None)
|
|
|
|
def op(inp, running_mean, running_var, weight, bias, training):
|
|
res = F.batch_norm(inp, running_mean, running_var, weight, bias, training)
|
|
if track_running_stats:
|
|
return res, running_mean, running_var
|
|
return res
|
|
|
|
test(self, op, tuple(inputs), in_dims=tuple(in_dims))
|
|
|
|
def test_torch_return_types_returns(self, device):
|
|
t = torch.randn(3, 2, 2, device=device)
|
|
self.assertTrue(isinstance(vmap(torch.min, (0, None))(t, 0), torch.return_types.min))
|
|
self.assertTrue(isinstance(vmap(torch.max, (0, None))(t, 0), torch.return_types.max))
|
|
self.assertTrue(isinstance(vmap(torch.topk, (0, None, None))(t, 1, 0), torch.return_types.topk))
|
|
self.assertTrue(isinstance(vmap(torch.linalg.eig, (0))(t), torch.return_types.linalg_eig))
|
|
|
|
def test_namedtuple_returns(self, device):
|
|
Point = namedtuple('Point', ['x', 'y'])
|
|
|
|
def f(x, y):
|
|
return Point(x=x, y=y)
|
|
|
|
x = torch.randn(2, 5, device=device)
|
|
y = torch.randn(2, 3, device=device)
|
|
self.assertTrue(isinstance(vmap(f)(x, y), Point))
|
|
|
|
def test_inplace_on_view(self, device):
|
|
def func(leaf):
|
|
base = leaf * leaf
|
|
view = base.transpose(0, 1)
|
|
view[2:4, 2:4] *= 2
|
|
view[0:2, 0:2].diagonal().sin_()
|
|
view = view[1:3, 1:3]
|
|
view.cos_()
|
|
return view
|
|
|
|
def push_vjp(leaf, gout):
|
|
_, vjp_fn = vjp(func, leaf)
|
|
result, = vjp_fn(gout)
|
|
return result
|
|
|
|
leaf = torch.randn(4, 4, device=device)
|
|
gout = torch.randn(2, 2, device=device)
|
|
args = (leaf, gout)
|
|
|
|
for args, in_dims, _, in generate_vmap_inputs(args, {}):
|
|
if in_dims[1] is None:
|
|
# triggers some composite compliance problem
|
|
continue
|
|
self.vmap_outplace_test(push_vjp, args, {}, in_dims)
|
|
|
|
def test_advanced_indexing(self, device):
|
|
def test(f, args):
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(f, args, {}):
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
def f(x, idx):
|
|
return x[:, idx]
|
|
|
|
def f2(x, idx):
|
|
return x[idx, :]
|
|
|
|
def f3(x, idx):
|
|
return x[:, :, idx]
|
|
|
|
inps = (torch.randn(5, 5, 5, device=device),
|
|
torch.randn(5, 5, 5, 5, device=device),
|
|
torch.randn(5, 5, 5, 5, 5, device=device))
|
|
idxes = (torch.tensor([0, 1, 2], device=device),
|
|
torch.tensor([0, 1, 2], device=device).reshape(3, 1),
|
|
torch.tensor([0, 1, 2], device=device).reshape(3, 1, 1))
|
|
for (inp, idx) in itertools.product(inps, idxes):
|
|
test(f, (inp, idx))
|
|
test(f2, (inp, idx))
|
|
test(f3, (inp, idx))
|
|
|
|
def test_nested_advanced_indexing(self, device):
|
|
e = torch.rand(7, 4, device=device)
|
|
idx = torch.tensor([0, 1], device=device).view(2, 1)
|
|
|
|
# simple reference implementation for comparison
|
|
def _fake_vmap(f, in_dims=0, out_dims=0):
|
|
def w(input):
|
|
r = [f(input.select(in_dims, i)) for i in range(input.size(in_dims))]
|
|
return torch.stack(r, out_dims)
|
|
|
|
return w
|
|
|
|
def with_vmap(_vmap):
|
|
def g(idx_):
|
|
def f(e_):
|
|
return e_[idx_]
|
|
|
|
return _vmap(f, in_dims=1)(e)
|
|
|
|
r = _vmap(g)(idx)
|
|
return r
|
|
|
|
a = with_vmap(vmap)
|
|
b = with_vmap(_fake_vmap)
|
|
self.assertEqual(a, b)
|
|
|
|
@ops(filter(lambda op: "linalg" in op.name, op_db + additional_op_db), allowed_dtypes=(torch.float,))
|
|
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_linalg_failure_1D_input', {
|
|
xfail('linalg.vector_norm'), # can accept vector inputs
|
|
xfail('linalg.norm'), # can accept vector inputs
|
|
xfail('linalg.norm', 'subgradients_at_zero'), # can accept vector inputs
|
|
skip('linalg.multi_dot'), # accepts list of tensor inputs, has its own special test
|
|
xfail('linalg.vander'),
|
|
xfail('linalg.vecdot'),
|
|
skip('linalg.ldl_solve', ''),
|
|
})
|
|
def test_vmap_linalg_failure_1D_input(self, device, dtype, op):
|
|
for sample in op.sample_inputs(device, dtype, requires_grad=False):
|
|
if sample.input.dim() != 2 or sample.input.shape[0] == 0:
|
|
continue
|
|
test_input = sample.input[0] # using the sample input avoids numerical inconsistency issues
|
|
with self.assertRaisesRegex(RuntimeError, "dimension"):
|
|
op(test_input, *sample.args, **sample.kwargs)
|
|
|
|
def op_wrapper(inp):
|
|
return op(inp, *sample.args, **sample.kwargs)
|
|
|
|
# square inputs are more likely to pass linalg checks
|
|
test_input = test_input.expand(test_input.shape[0], test_input.shape[0])
|
|
with self.assertRaisesRegex(RuntimeError, "dimension"):
|
|
return vmap(op_wrapper)(test_input)
|
|
|
|
def test_vmap_multi_dot_failure_1D_input(self):
|
|
# special exception for first and last tensors so making giving 3 items avoids special cases
|
|
inputs = (torch.randn(3, 3), torch.randn(3), torch.randn(3, 3))
|
|
with self.assertRaisesRegex(RuntimeError, "tensor 1 must be 2D but got 1D"):
|
|
torch.linalg.multi_dot(inputs)
|
|
|
|
# square inputs are more likely to pass linalg checks
|
|
inputs = tuple(i.expand(i.shape[0], i.shape[0]) for i in inputs)
|
|
with self.assertRaisesRegex(RuntimeError, "tensor 1 must be 2D but got 1D"):
|
|
return vmap(torch.linalg.multi_dot)(inputs)
|
|
|
|
|
|
class TestRandomness(TestCase):
|
|
def _reset_random(self, generator, orig_state, use_generator, seed):
|
|
return generator.set_state(orig_state) if use_generator else torch.manual_seed(seed)
|
|
|
|
def _get_image(self, batched_input, batch_size, device):
|
|
if batched_input == "first":
|
|
return torch.ones([batch_size, 3, 3, 14, 14], device=device)
|
|
if batched_input == "last":
|
|
return torch.ones([3, 3, 14, 14, batch_size], device=device)
|
|
assert batched_input == "none"
|
|
return torch.ones([3, 3, 14, 14], device=device)
|
|
|
|
def _assert_all_slices_equal(self, tensor):
|
|
expected = tensor[0]
|
|
self.assertTrue((tensor == expected).all())
|
|
|
|
def _assert_all_slices_unique(self, tensor):
|
|
B0 = tensor.shape[0]
|
|
slices_equal = vmap(vmap(lambda x, y: (x == y).all(), (0, None)), (None, 0))(tensor, tensor)
|
|
assert slices_equal.shape == (B0, B0)
|
|
slices_equal.diagonal().zero_()
|
|
self.assertEqual(slices_equal, torch.zeros_like(slices_equal))
|
|
|
|
def _assert_throws_in_error_mode(self, fn, args, in_dims):
|
|
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
|
|
vmap(fn, in_dims=in_dims, randomness="error")(*args)
|
|
|
|
def _assert_throws_in_different_mode_inplace(self, fn, args, in_dims):
|
|
with self.assertRaisesRegex(RuntimeError, r"different inplace randomness on an unbatched tensor"):
|
|
vmap(fn, in_dims=in_dims, randomness="different")(*args)
|
|
|
|
def _assert_throws_in_same_mode_batched(self, fn, args, in_dims):
|
|
with self.assertRaisesRegex(RuntimeError,
|
|
r"Vmap does not currently support same randomness with a batched tensor input"):
|
|
vmap(fn, in_dims=in_dims, randomness="same")(*args)
|
|
|
|
def _in_dims(self, *batched_strings):
|
|
|
|
def get_in_dim(batched_string):
|
|
if batched_string == "first":
|
|
return 0
|
|
if batched_string == "last":
|
|
return -1
|
|
assert batched_string == "none"
|
|
return None
|
|
|
|
batched_strings = batched_strings + ("first",) # for the always batched as first dim dummy argument
|
|
return tuple(get_in_dim(batched_string) for batched_string in batched_strings)
|
|
|
|
@parametrize('randomness', ['same', 'different', 'error'])
|
|
@parametrize('use_generator', [True, False])
|
|
def test_factory_ops(self, device, randomness, use_generator):
|
|
generator = torch.Generator(device=device)
|
|
orig_state = generator.get_state()
|
|
kwargs = {'device': device, 'generator': generator} if use_generator else {'device': device}
|
|
ops = [
|
|
lambda _, shape: torch.randn(shape, **kwargs),
|
|
lambda _, shape: torch.rand(shape, **kwargs),
|
|
lambda _, shape: torch.randint(100, shape, **kwargs),
|
|
lambda _, shape: torch.randint(5, 100, shape, **kwargs),
|
|
lambda _, shape: torch.normal(0., 1., shape, **kwargs),
|
|
]
|
|
B0 = 4
|
|
shape = (3, 3)
|
|
seed = 1234567
|
|
|
|
for op in ops:
|
|
passed = torch.randn(B0, device=device)
|
|
if randomness == 'error':
|
|
self._assert_throws_in_error_mode(op, (passed, shape), in_dims=(0, None))
|
|
return
|
|
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
vmap_result = vmap(op, in_dims=(0, None), randomness=randomness)(passed, shape)
|
|
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
if randomness == "different":
|
|
expected = op(passed, [B0, *shape])
|
|
self._assert_all_slices_unique(vmap_result)
|
|
self.assertEqual(vmap_result, expected)
|
|
else:
|
|
expected = op(passed, shape)
|
|
self._assert_all_slices_equal(vmap_result)
|
|
for i in range(B0):
|
|
self.assertEqual(vmap_result[i], expected)
|
|
|
|
@parametrize('randomness', ['same', 'different', 'error'])
|
|
@parametrize('use_generator', [True, False])
|
|
def test_randperm(self, device, randomness, use_generator):
|
|
# needs a special case because randperm doesn't take a batch size
|
|
B0 = 4
|
|
seed = 1234567
|
|
passed = torch.randn(B0, device=device)
|
|
|
|
torch.manual_seed(seed)
|
|
generator = torch.Generator(device=device)
|
|
orig_state = generator.get_state()
|
|
|
|
kwargs = {'device': device, 'generator': generator} if use_generator else {'device': device}
|
|
|
|
if randomness == 'error':
|
|
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
|
|
vmap(lambda _: torch.randperm(10, **kwargs), randomness=randomness)(passed)
|
|
return
|
|
|
|
vmap_result = vmap(lambda _: torch.randperm(10, **kwargs), randomness=randomness)(passed)
|
|
generator = generator.set_state(orig_state)
|
|
torch.manual_seed(seed)
|
|
if randomness == 'different':
|
|
for i in range(B0):
|
|
expected = torch.randperm(10, **kwargs)
|
|
self.assertEqual(vmap_result[i], expected)
|
|
else:
|
|
expected = torch.randperm(10, **kwargs)
|
|
for i in range(B0):
|
|
self.assertEqual(vmap_result[i], expected)
|
|
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
def test_dropout(self, device, randomness, batched_input):
|
|
def op(t, ignored):
|
|
return torch.nn.functional.dropout(torch.ones_like(t), training=True)
|
|
|
|
B0 = 4
|
|
always_batched = torch.randn((B0,))
|
|
passed = self._get_image(batched_input, B0, device)
|
|
in_dims = self._in_dims(batched_input)
|
|
|
|
if randomness == 'error':
|
|
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
|
|
vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
|
|
return
|
|
|
|
vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
|
|
|
|
# Check that the randomness is within bounds...
|
|
# ideally this is close to 0.5
|
|
p_estimate = vmap_result.mean() / 2
|
|
self.assertTrue(p_estimate < 0.75)
|
|
self.assertTrue(p_estimate > 0.25)
|
|
|
|
if randomness == 'different':
|
|
self._assert_all_slices_unique(vmap_result)
|
|
return
|
|
|
|
assert randomness == 'same'
|
|
self._assert_all_slices_equal(vmap_result)
|
|
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
def test_alpha_dropout(self, device, randomness, batched_input):
|
|
def op(t, ignored):
|
|
return torch.nn.functional.alpha_dropout(torch.ones_like(t), training=True)
|
|
|
|
B0 = 4
|
|
always_batched = torch.randn((B0,))
|
|
passed = self._get_image(batched_input, B0, device)
|
|
in_dims = self._in_dims(batched_input)
|
|
|
|
if randomness == 'error':
|
|
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
|
|
vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
|
|
return
|
|
|
|
# I have no clue how to actually test corectness of alpha dropout because the docs
|
|
# seem wrong: https://github.com/pytorch/pytorch/issues/74004
|
|
vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
|
|
if randomness == 'different':
|
|
self._assert_all_slices_unique(vmap_result)
|
|
return
|
|
|
|
assert randomness == 'same'
|
|
self._assert_all_slices_equal(vmap_result)
|
|
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
@parametrize('dim', [2, 3])
|
|
def test_feature_dropout(self, device, randomness, batched_input, dim):
|
|
def op(t, ignored):
|
|
f = torch.nn.functional.dropout2d if dim == 2 else torch.nn.functional.dropout3d
|
|
return f(torch.ones_like(t), training=True)
|
|
|
|
B0 = 4
|
|
always_batched = torch.randn((B0,))
|
|
passed = self._get_image(batched_input, B0, device)
|
|
if dim == 3:
|
|
unsqueeze_dim = -2 if batched_input == "last" else -1
|
|
passed = passed.unsqueeze(unsqueeze_dim)
|
|
in_dims = self._in_dims(batched_input)
|
|
|
|
if randomness == 'error':
|
|
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
|
|
vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
|
|
return
|
|
|
|
vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
|
|
|
|
# Check that the randomness is within bounds...
|
|
# ideally this is close to 0.5
|
|
p_estimate = vmap_result.mean() / 2
|
|
self.assertTrue(p_estimate < 0.75)
|
|
self.assertTrue(p_estimate > 0.25)
|
|
|
|
# Check the "feature" pattern
|
|
dims = [-1, -2] if dim == 2 else [-1, -2, -3]
|
|
planes_numel = 2 * vmap_result.numel() / (vmap_result.shape[0] * vmap_result.shape[1] * vmap_result.shape[2])
|
|
planes = vmap_result.sum(dims)
|
|
result = (planes == 0) ^ (planes == planes_numel)
|
|
self.assertEqual(result, torch.ones_like(result, dtype=torch.bool))
|
|
|
|
if randomness == 'different':
|
|
self._assert_all_slices_unique(vmap_result)
|
|
return
|
|
|
|
assert randomness == 'same'
|
|
self._assert_all_slices_equal(vmap_result)
|
|
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
def test_feature_alpha_dropout(self, device, randomness, batched_input):
|
|
def op(t, ignored):
|
|
return torch.nn.functional.feature_alpha_dropout(torch.ones_like(t), training=True)
|
|
|
|
B0 = 4
|
|
always_batched = torch.randn((B0,))
|
|
passed = self._get_image(batched_input, B0, device)
|
|
unsqueeze_dim = -2 if batched_input == "last" else -1
|
|
passed = passed.unsqueeze(unsqueeze_dim)
|
|
in_dims = self._in_dims(batched_input)
|
|
|
|
if randomness == 'error':
|
|
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
|
|
vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
|
|
return
|
|
|
|
vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
|
|
|
|
# I have no clue how to actually test corectness of alpha dropout because the docs
|
|
# seem wrong: https://github.com/pytorch/pytorch/issues/74004
|
|
|
|
# Check the "feature" pattern
|
|
dims = [-1, -2, -3]
|
|
planes = vmap_result.sum(dims)
|
|
max_elt = planes.max()
|
|
min_elt = planes.min()
|
|
result = (planes == min_elt) ^ (planes == max_elt)
|
|
self.assertEqual(result, torch.ones_like(result, dtype=torch.bool))
|
|
|
|
if randomness == 'different':
|
|
self._assert_all_slices_unique(vmap_result)
|
|
return
|
|
|
|
assert randomness == 'same'
|
|
self._assert_all_slices_equal(vmap_result)
|
|
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
def test_like_functions(self, device, randomness, batched_input):
|
|
seed = 1234567
|
|
supported_ops = [
|
|
lambda t, _: torch.randint_like(t, 20),
|
|
lambda t, _: torch.randint_like(t, 0, 20),
|
|
lambda t, _: torch.rand_like(t),
|
|
lambda t, _: torch.randn_like(t),
|
|
]
|
|
B0 = 4
|
|
|
|
for op in supported_ops:
|
|
always_batched = torch.randn(B0)
|
|
passed = self._get_image(batched_input, B0, device)
|
|
in_dims = self._in_dims(batched_input)
|
|
|
|
if randomness == 'error':
|
|
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
|
|
vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)
|
|
return
|
|
|
|
torch.manual_seed(seed)
|
|
vmap_result = vmap(op, randomness=randomness, in_dims=in_dims)(passed, always_batched)
|
|
|
|
torch.manual_seed(seed)
|
|
|
|
if batched_input == "last":
|
|
passed = passed.movedim(-1, 0)
|
|
if randomness == 'different':
|
|
if batched_input == "none":
|
|
passed = passed.expand(B0, *passed.shape)
|
|
expected = op(passed, 0)
|
|
|
|
self._assert_all_slices_unique(vmap_result)
|
|
self.assertEqual(expected, vmap_result)
|
|
return
|
|
|
|
assert randomness == 'same'
|
|
if batched_input != "none":
|
|
passed = passed[0]
|
|
expected = op(passed, 0)
|
|
self._assert_all_slices_equal(vmap_result)
|
|
for i in range(B0):
|
|
self.assertEqual(expected, vmap_result[i])
|
|
|
|
@parametrize('use_generator', [True, False])
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
def test_random_unary_inplace(self, device, use_generator, randomness, batched_input):
|
|
generator = torch.Generator(device=device)
|
|
orig_state = generator.get_state()
|
|
kwargs = {'generator': generator} if use_generator else {}
|
|
ops = [
|
|
lambda t, _: t.random_(**kwargs),
|
|
lambda t, _: t.random_(100, **kwargs),
|
|
lambda t, _: t.random_(-5, 100, **kwargs),
|
|
lambda t, _: t.normal_(**kwargs),
|
|
lambda t, _: t.bernoulli_(**kwargs),
|
|
lambda t, _: t.cauchy_(**kwargs),
|
|
lambda t, _: t.exponential_(**kwargs),
|
|
lambda t, _: t.geometric_(0.5, **kwargs),
|
|
lambda t, _: t.log_normal_(**kwargs),
|
|
lambda t, _: t.uniform_(**kwargs),
|
|
]
|
|
B0 = 4
|
|
seed = 1234567
|
|
in_dims = self._in_dims(batched_input)
|
|
|
|
for op in ops:
|
|
# because of in place updates, clone inputs
|
|
always_batched = torch.randn(B0, device=device)
|
|
passed = self._get_image(batched_input, B0, device)
|
|
passed_expected = passed.clone()
|
|
|
|
if randomness == 'error':
|
|
self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims)
|
|
return
|
|
if randomness == 'different' and batched_input == "none":
|
|
self._assert_throws_in_different_mode_inplace(op, (passed, always_batched), in_dims=in_dims)
|
|
return
|
|
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)
|
|
|
|
if batched_input == "last":
|
|
passed_expected = passed_expected.movedim(-1, 0)
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
if randomness == "different":
|
|
expected = op(passed_expected, always_batched)
|
|
self._assert_all_slices_unique(vmap_result)
|
|
self.assertEqual(vmap_result, expected)
|
|
else:
|
|
if batched_input != "none":
|
|
passed_expected = passed_expected[0].clone() # bug in pytorch, normal_ on views doesn't work
|
|
expected = op(passed_expected, always_batched)
|
|
self._assert_all_slices_equal(vmap_result)
|
|
for i in range(B0):
|
|
self.assertEqual(vmap_result[i], expected)
|
|
|
|
@parametrize('use_generator', [True, False])
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
@parametrize('batched_probability', ["first", "last", "none"])
|
|
def test_bernoulli_in_place(self, device, use_generator, randomness, batched_input, batched_probability):
|
|
B0 = 4
|
|
seed = 1234567
|
|
generator = torch.Generator(device=device)
|
|
orig_state = generator.get_state()
|
|
kwargs = {'generator': generator} if use_generator else {}
|
|
in_dims = self._in_dims(batched_input, batched_probability)
|
|
|
|
def op(t, p, ignored):
|
|
return t.bernoulli_(p, **kwargs)
|
|
|
|
# because of in place updates, clone inputs
|
|
always_batched = torch.randn(B0, device=device)
|
|
input = self._get_image(batched_input, B0, device)
|
|
input_expected = input.clone()
|
|
probability = self._get_image(batched_probability, B0, device) - 0.5
|
|
|
|
if randomness == 'error':
|
|
self._assert_throws_in_error_mode(op, (input, probability, always_batched), in_dims=in_dims)
|
|
return
|
|
if randomness == 'same' and batched_probability != "none":
|
|
self._assert_throws_in_same_mode_batched(op, (input, probability, always_batched), in_dims=in_dims)
|
|
return
|
|
if batched_input == "none" and batched_probability != "none":
|
|
regex = r"there exists a Tensor `other` in extra_args that has more elements than `self`"
|
|
with self.assertRaisesRegex(RuntimeError, regex):
|
|
vmap(op, in_dims=in_dims, randomness=randomness)(input, probability, always_batched)
|
|
return
|
|
if randomness == 'different' and batched_input == "none":
|
|
self._assert_throws_in_different_mode_inplace(op, (input, probability, always_batched), in_dims=in_dims)
|
|
return
|
|
|
|
self._reset_random(generator, orig_state, use_generator, seed)
|
|
vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(input, probability, always_batched)
|
|
|
|
self._reset_random(generator, orig_state, use_generator, seed)
|
|
if batched_input == "last":
|
|
input_expected = input_expected.movedim(-1, 0)
|
|
if batched_probability == "last":
|
|
probability = probability.movedim(-1, 0)
|
|
if randomness == "different":
|
|
expected = op(input_expected, probability, always_batched)
|
|
self._assert_all_slices_unique(vmap_result)
|
|
self.assertEqual(vmap_result, expected)
|
|
else:
|
|
if batched_input != "none":
|
|
input_expected = input_expected[0]
|
|
expected = op(input_expected, probability, always_batched)
|
|
self._assert_all_slices_equal(vmap_result)
|
|
for i in range(B0):
|
|
self.assertEqual(vmap_result[i], expected)
|
|
|
|
@parametrize('use_generator', [True, False])
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
@parametrize('batched_other', ["first", "last", "none"])
|
|
def test_random_binary_out_of_place(self, device, use_generator, randomness, batched_input, batched_other):
|
|
generator = torch.Generator(device=device)
|
|
orig_state = generator.get_state()
|
|
kwargs = {'generator': generator} if use_generator else {}
|
|
ops = [
|
|
lambda t, o, _: torch.normal(t, o, **kwargs),
|
|
lambda t, o, _: torch.binomial(t, (o - 0.5), **kwargs),
|
|
]
|
|
|
|
B0 = 4
|
|
seed = 1234567
|
|
in_dims = self._in_dims(batched_input, batched_other)
|
|
|
|
for op in ops:
|
|
always_batched = torch.randn(B0, device=device)
|
|
input = self._get_image(batched_input, B0, device)
|
|
other = self._get_image(batched_other, B0, device)
|
|
|
|
if randomness == 'error':
|
|
self._assert_throws_in_error_mode(op, (input, other, always_batched), in_dims=in_dims)
|
|
return
|
|
if randomness == 'same' and (batched_input != "none" or batched_other != "none"):
|
|
self._assert_throws_in_same_mode_batched(op, (input, other, always_batched), in_dims=in_dims)
|
|
return
|
|
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(input, other, always_batched)
|
|
|
|
if batched_input == "last":
|
|
input = input.movedim(-1, 0)
|
|
if batched_other == "last":
|
|
other = other.movedim(-1, 0)
|
|
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
if randomness == "different":
|
|
if batched_input == "none":
|
|
input = input.expand(B0, *input.shape)
|
|
expected = op(input, other, always_batched)
|
|
self._assert_all_slices_unique(vmap_result)
|
|
self.assertEqual(vmap_result, expected)
|
|
else:
|
|
assert batched_input == "none" and batched_other == "none"
|
|
expected = op(input, other, always_batched)
|
|
self._assert_all_slices_equal(vmap_result)
|
|
for i in range(B0):
|
|
self.assertEqual(vmap_result[i], expected)
|
|
|
|
@parametrize('use_generator', [True, False])
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
def test_random_unary_out_of_place(self, device, use_generator, randomness, batched_input):
|
|
generator = torch.Generator(device=device)
|
|
orig_state = generator.get_state()
|
|
kwargs = {'generator': generator} if use_generator else {}
|
|
ops = [
|
|
lambda t, _: torch.normal(0., torch.abs(t), **kwargs),
|
|
lambda t, _: torch.normal(t, 1., **kwargs),
|
|
lambda t, _: torch.bernoulli(t - 0.5, **kwargs),
|
|
lambda t, _: torch.bernoulli(t, 0.5, **kwargs),
|
|
lambda t, _: torch._standard_gamma(t, **kwargs),
|
|
lambda t, _: torch._sample_dirichlet(t, **kwargs),
|
|
lambda t, _: torch.poisson(t, **kwargs),
|
|
]
|
|
|
|
B0 = 4
|
|
seed = 1234567
|
|
in_dims = self._in_dims(batched_input)
|
|
|
|
for op in ops:
|
|
always_batched = torch.randn(B0, device=device)
|
|
passed = self._get_image(batched_input, B0, device)
|
|
if randomness == 'error':
|
|
self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims)
|
|
return
|
|
if randomness == 'same' and batched_input != "none":
|
|
self._assert_throws_in_same_mode_batched(op, (passed, always_batched), in_dims=in_dims)
|
|
return
|
|
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)
|
|
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
if randomness == "different":
|
|
if batched_input == "none":
|
|
passed = passed.expand(B0, *passed.shape)
|
|
if batched_input == "last":
|
|
passed = passed.movedim(-1, 0)
|
|
expected = op(passed, always_batched)
|
|
self._assert_all_slices_unique(vmap_result)
|
|
self.assertEqual(vmap_result, expected)
|
|
else:
|
|
expected = op(passed, always_batched)
|
|
self._assert_all_slices_equal(vmap_result)
|
|
for i in range(B0):
|
|
self.assertEqual(vmap_result[i], expected)
|
|
|
|
@parametrize('use_generator', [True, False])
|
|
@parametrize('randomness', ['error', 'same', 'different'])
|
|
@parametrize('batched_call', [True, False])
|
|
@parametrize('batched_input', ["first", "last", "none"])
|
|
def test_multinomial(self, device, use_generator, randomness, batched_call, batched_input):
|
|
def flatten_input(input, batch_call, batch_location):
|
|
if batch_call and batch_location != "none":
|
|
final_size = 3 # [B0, B, N]
|
|
elif not batch_call and batch_location == "none":
|
|
final_size = 1 # [N]
|
|
else:
|
|
final_size = 2 # [B0, N] or [B, N]
|
|
|
|
start_idx = final_size - 1
|
|
end_idx = -1
|
|
if batch_location == "last":
|
|
start_idx -= 1
|
|
end_idx -= 1 # gets to correct final size because using negative indices
|
|
|
|
ret = input.flatten(start_idx, end_idx)
|
|
assert ret.dim() == final_size
|
|
return ret
|
|
|
|
def op(input, _):
|
|
return torch.multinomial(input, 10, **kwargs)
|
|
|
|
generator = torch.Generator(device=device)
|
|
orig_state = generator.get_state()
|
|
kwargs = {'generator': generator} if use_generator else {}
|
|
|
|
B0 = 4
|
|
seed = 1234567
|
|
in_dims = self._in_dims(batched_input)
|
|
|
|
always_batched = torch.randn(B0, device=device)
|
|
passed = self._get_image(batched_input, B0, device)
|
|
passed = flatten_input(passed, batched_call, batched_input)
|
|
if randomness == 'error':
|
|
self._assert_throws_in_error_mode(op, (passed, always_batched), in_dims=in_dims)
|
|
return
|
|
if randomness == 'same' and batched_input != "none":
|
|
self._assert_throws_in_same_mode_batched(op, (passed, always_batched), in_dims=in_dims)
|
|
return
|
|
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
vmap_result = vmap(op, in_dims=in_dims, randomness=randomness)(passed, always_batched)
|
|
|
|
generator = self._reset_random(generator, orig_state, use_generator, seed)
|
|
|
|
if randomness == "different":
|
|
if batched_input == "none":
|
|
passed = passed.expand(B0, *passed.shape)
|
|
if batched_input == "last":
|
|
passed = passed.movedim(-1, 0)
|
|
orig_passed_size = passed.shape[:2] if batched_call else passed.shape[:1]
|
|
passed = passed.flatten(0, 1) if batched_call else passed
|
|
expected = op(passed, always_batched)
|
|
expected = expected.reshape(*orig_passed_size, 10)
|
|
self._assert_all_slices_unique(vmap_result)
|
|
self.assertEqual(vmap_result, expected)
|
|
else:
|
|
expected = op(passed, always_batched)
|
|
self._assert_all_slices_equal(vmap_result)
|
|
for i in range(B0):
|
|
self.assertEqual(vmap_result[i], expected)
|
|
|
|
def test_unsupported_random(self, device):
|
|
x = torch.randn(3, device=device)
|
|
y = x.abs()
|
|
z = x.abs()
|
|
with self.assertRaisesRegex(RuntimeError, "calling out variants"):
|
|
def f(x):
|
|
return torch.randn(3, device=device, out=y)
|
|
vmap(f, randomness='same')(x)
|
|
with self.assertRaisesRegex(RuntimeError, "calling out variants"):
|
|
def f(x0, x1):
|
|
return torch.normal(x, y, out=x)
|
|
vmap(f, randomness='same')(z, z)
|
|
with self.assertRaisesRegex(RuntimeError, "do not yet support"):
|
|
def f(z):
|
|
return torch.rrelu(x)
|
|
vmap(f, randomness='same')(z)
|
|
|
|
@parametrize('in_dim', [0, 1, 2])
|
|
@parametrize('out_dim', [0, 1, 2])
|
|
def test_chunk_vmap(self, in_dim, out_dim):
|
|
|
|
randomness = "different"
|
|
|
|
x = torch.randn(4, 5, 6)
|
|
|
|
def f(x):
|
|
y = x.sin() + torch.rand_like(x)
|
|
return y
|
|
|
|
for chunks in [1, 2, 3, 4, 7, 10, 16]:
|
|
output = chunk_vmap(
|
|
f, in_dims=in_dim, out_dims=out_dim, randomness=randomness, chunks=chunks
|
|
)(x)
|
|
self._assert_all_slices_unique(output)
|
|
|
|
|
|
def test_jacfwd_with_random(self):
|
|
# checks on behavior are above, this just checks that jacfwd respects
|
|
# the randomness param
|
|
|
|
x = torch.rand(3, 4)
|
|
with self.assertRaisesRegex(RuntimeError, r"called random operation while in randomness error mode"):
|
|
jacfwd(torch.bernoulli)(x)
|
|
|
|
# x isn't batched so use bernoulli since it doesn't do inplace randomness
|
|
jacfwd(torch.bernoulli, randomness="same")(x)
|
|
jacfwd(torch.bernoulli, randomness="different")(x)
|
|
|
|
|
|
class TestTransformFailure(TestCase):
|
|
@parametrize('transform', ['vmap', 'grad', 'grad_and_value', 'vjp', 'jvp', 'jacrev', 'jacfwd'])
|
|
def test_fails_with_autograd_function(self, device, transform):
|
|
class Test(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(_, input):
|
|
return input
|
|
|
|
@staticmethod
|
|
def backward(_, grad_input):
|
|
return grad_input
|
|
|
|
transform = getattr(functorch, transform)
|
|
|
|
def f(x):
|
|
return Test.apply(x)
|
|
|
|
if transform == grad or transform == grad_and_value:
|
|
input = torch.tensor(4.)
|
|
else:
|
|
input = torch.randn(5)
|
|
|
|
if transform == vjp:
|
|
transform = functools.partial(transform, f)
|
|
elif transform == jvp:
|
|
input = (input,)
|
|
transform = functools.partial(transform, f, input)
|
|
else:
|
|
transform = transform(f)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "autograd.Function"):
|
|
transform(input)
|
|
|
|
only_for = ("cpu", "cuda")
|
|
instantiate_device_type_tests(TestVmapOperatorsOpInfo, globals(), only_for=only_for)
|
|
|
|
instantiate_device_type_tests(
|
|
TestVmapBatchedGradient,
|
|
globals(),
|
|
only_for=only_for,
|
|
)
|
|
instantiate_device_type_tests(TestTransformFailure, globals(), only_for=only_for)
|
|
instantiate_device_type_tests(TestRandomness, globals(), only_for=only_for)
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|