pytorch/test/functorch/test_ops.py
drisspg fb26b84390 Update fused kernels and call _safe_softmax from SDPA (#133882)
# UPDATE:
This is take 3 of https://github.com/pytorch/pytorch/pull/131863, which was landed via codev but did not apply correctly.

# Summary
Changes the stance of SDPA on what to do for fully masked out rows

## Current Behavior
Several PyTorch users have expressed frustration over this issue:
- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963

These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated in
https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617,
can be paraphrased as follows:

When passing in fully masked-out rows, attention becomes ambiguous. We have two main options:

1. Uniformly attend to all values:
   ```python
   scores[masked_out_rows] = 1 / len(row)
   out[masked_out_rows] = 1 / len(row) * value
   ```

2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:
   ```python
   output[fully_masked_rows] = NaN
   ```

We went with option 2, partly because it was easier to implement, but also because people argued that users can slice the output to remove the NaNs:
```Python
>fill_value = -float("inf")
>row0 = torch.randn(4)
>row1 = torch.tensor([fill_value for _ in range(4)])
>matrix = torch.stack([row0, row1]).requires_grad_(True)
>out = torch.softmax(matrix, 1)
>out = out[0]
>print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```
Cool, problem solved. But what happens when you call backward?
```Python
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
        [       nan,        nan,        nan,        nan]])
```
Those pesky NaNs are back!

## Why do we see NaNs today?

The core of the problem is the use of the softmax function inside SDPA:

```python
> row = torch.tensor([(-float("inf")) for _ in range(4)])
> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```

## Quick Aside: Masking in Attention

Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details ([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs.

We use a large negative number (typically -inf) to decrease the attention weight, as softmax assigns more weight to larger values.
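
To make the aside concrete, here is a minimal, illustrative sketch (not the SDPA internals; the names `scores`, `keep`, and `attn_bias` are made up for this example) of how masking is really an additive bias, and why a fully masked-out row produces NaNs:

```python
import torch

q_len, kv_len = 2, 4
scores = torch.randn(q_len, kv_len)  # raw attention logits
keep = torch.tensor([[True, True, True, True],
                     [False, False, False, False]])  # query 1 is fully masked out

# "Masking" is an additive bias: -inf wherever attention is disallowed.
attn_bias = torch.zeros(q_len, kv_len).masked_fill(~keep, float("-inf"))
weights = torch.softmax(scores + attn_bias, dim=-1)
print(weights)  # the fully masked-out row comes out as all NaN
```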

## Alternative Approaches

If we use a very large negative number instead of -inf:

```python
> row = torch.tensor([(-1e6) for _ in range(4)])
> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```
However, if users always remembered to "slice" out their outputs, i.e.:
```Python
>fill_value = -1e6
>...
>out.backward(torch.ones_like(out))
>print(matrix.grad)
tensor([[-0.0563, -0.0564,  0.1613, -0.0486],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
```
This would bring us back into a better state.

## A Third Option

We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation.

This PR implements a new semantic for attention masking in fully masked-out rows:
```python
out[masked_out_rows] = 0
```
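
For illustration only (shapes and values are arbitrary, and this assumes a PyTorch build that includes this change), a caller would see something like the following:

```python
import torch
import torch.nn.functional as F

# Shapes are arbitrary: (batch, heads, seq_len, head_dim).
q = torch.randn(1, 1, 2, 8)
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)
# Boolean mask, broadcast over batch/heads; the second query row is fully masked out.
attn_mask = torch.tensor([[True, True, True, True],
                          [False, False, False, False]])

out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out[0, 0, 1])  # under the new semantic described here: all zeros rather than NaNs
```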

**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.

## Details
This PR stack does three things:
1. Adds a PRIVATE `_safe_softmax` op
2. Updates the semantics of the flash_cpu fused kernel
3. Updates the semantics of the efficient_cuda fused kernel

`_safe_softmax` is not supposed to be used generically and is only meant to be used within the context of SDPA. Because of this, instead of decomposing softmax and checking for `-inf` rows, we "cheat" and use `nan_to_num`.
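
A minimal sketch of that "cheat" (not the actual kernel or op implementation; `_safe_softmax_sketch` and `scores` are placeholder names for this example):

```python
import torch

def _safe_softmax_sketch(scores: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Regular softmax: a fully -inf row comes out as all NaN (0 / 0).
    out = torch.softmax(scores, dim=dim)
    # The "cheat": map those NaNs back to 0.
    return torch.nan_to_num(out, nan=0.0)

scores = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                       [float("-inf")] * 4])
print(_safe_softmax_sketch(scores))
# row 0: ordinary softmax weights; row 1: all zeros instead of NaNs
```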

Why do I think this is okay? (Please find a counterpoint if available.)
There are multiple ways NaNs can emerge. For the fully masked-out rows case, `nan_to_num` works. But what if there were other NaNs; wouldn't this silently remove them?

The only case where this can happen is if the input itself contains a NaN or an Inf.
For example:
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```
This returns
`tensor([0., 1., 0., 0.], dtype=torch.float16)`

whereas
```Python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```
returns:
`tensor([nan, nan, nan, nan], dtype=torch.float16)`

If we don't want to allow even the possibility of "inf" or "NaN" attention scores being converted to 0, we could implement it with something like this:

```Python
# `a` holds attention scores, where fully masked-out rows are all -inf
max_vals = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max_vals.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
# Explicitly zero out rows whose max is -inf, i.e. fully masked-out rows
softmax = torch.where(max_vals.values == float("-inf"), 0.0, softmax)
```
However, we would be paying for this in the performance of the math backend.

## Why Now
I think one point that has substantially changed where PyTorch should land on this question is that we now have fused implementations for SDPA, and these fused implementations allow us to support the new semantic easily and performantly.

Differential Revision: [D61418679](https://our.internmc.facebook.com/intern/diff/D61418679)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133882
Approved by: https://github.com/soulitzer
2024-08-19 18:53:11 +00:00


# Owner(s): ["module: functorch"]
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import functools
import itertools
import unittest
from common_utils import (
check_vmap_fallback,
decorate,
expectedFailureIf,
generate_vmap_inputs,
get_fallback_and_vmap_exhaustive,
is_batch_norm_training,
is_valid_inplace_sample_input,
loop,
loop2,
opsToleranceOverride,
skip,
skipOps,
tol1,
tol2,
xfail,
)
from functorch_additional_op_db import additional_op_db
import torch
import torch.autograd.forward_ad as fwAD
from functorch import grad, jacfwd, jacrev, vjp, vmap
from torch import Tensor
from torch._functorch.eager_transforms import _as_tuple, jvp
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests,
ops,
tol,
toleranceOverride,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
is_iterable_of_tensors,
IS_MACOS,
IS_X86,
noncontiguous_like,
parametrize,
run_tests,
runOnRocm,
skipIfRocm,
TEST_WITH_ASAN,
TEST_WITH_ROCM,
TestCase,
unMarkDynamoStrictTest,
)
from torch.testing._internal.opinfo.core import SampleInput
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten
aten = torch.ops.aten
# Version of autograd.grad with some differences:
# - pytree inputs is allowed (but leaves of the pytree have to all
# be tensors)
# - if an input is not used as part of derivatives, we will return a
# zero-filled tensor for the result
def _autograd_grad(
outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
inputs, inputs_spec = tree_flatten(inputs)
diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
if grad_outputs is None:
diff_outputs = tuple(out for out in outputs if out.requires_grad)
else:
diff_grad_outputs = [
(out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
]
if len(diff_grad_outputs) == 0:
diff_outputs, grad_outputs = (), ()
else:
diff_outputs, grad_outputs = zip(*diff_grad_outputs)
grad_inputs = torch.autograd.grad(
diff_outputs,
diff_inputs,
grad_outputs,
retain_graph=retain_graph,
create_graph=create_graph,
allow_unused=True,
)
result = []
grad_inputs_iter = iter(grad_inputs)
for inp in inputs:
if inp.requires_grad:
grad_input = next(grad_inputs_iter)
if grad_input is None:
result.append(torch.zeros_like(inp))
else:
result.append(grad_input)
else:
result.append(torch.zeros_like(inp))
return tree_unflatten(result, inputs_spec)
def diff_arg(arg, requires_grad=True):
def is_differentiable_arg(arg):
if requires_grad:
return arg.requires_grad
else:
return arg.is_floating_point() or arg.is_complex()
if is_iterable_of_tensors(arg):
if all(is_differentiable_arg(a) for a in arg):
return True
if all(not is_differentiable_arg(a) for a in arg):
return False
raise RuntimeError("NYI: The test runner can't handle this")
return isinstance(arg, Tensor) and is_differentiable_arg(arg)
# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point Tensors
# - All outputs of f' are floating-point Tensors
def normalize_op_input_output2(
f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
flat_args, args_spec = tree_flatten(args)
diff_argnums = tuple(
i
for i, arg in enumerate(flat_args)
if diff_arg(arg, requires_grad=requires_grad)
)
assert len(diff_argnums) > 0
primals = tuple(flat_args[i] for i in diff_argnums)
@functools.wraps(f)
def wrapped(*primals):
_args = list(flat_args)
for num, arg in zip(diff_argnums, primals):
_args[num] = arg
_args = tree_unflatten(_args, args_spec)
result = f(*_args, **kwargs)
if output_process_fn_grad is not None:
result = output_process_fn_grad(result)
if isinstance(result, tuple):
result = tuple(r for r in result if torch.is_floating_point(r))
assert len(result) > 0
return result
return wrapped, primals
# TODO: consolidate with normalize_op_input_output2
def normalize_op_input_output3(
f, args, kwargs, sample_args, output_process_fn_grad=None
):
flat_args, args_spec = tree_flatten(args)
flat_sample_args = pytree.tree_leaves(sample_args)
diff_argnums = tuple(
i
for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args))
if diff_arg(sample, requires_grad=True)
)
assert len(diff_argnums) > 0
primals = tuple(flat_args[i] for i in diff_argnums)
@functools.wraps(f)
def wrapped(*primals):
_args = list(flat_args)
for num, arg in zip(diff_argnums, primals):
_args[num] = arg
_args = tree_unflatten(_args, args_spec)
result = f(*_args, **kwargs)
if output_process_fn_grad is not None:
result = output_process_fn_grad(result)
if isinstance(result, tuple):
result = tuple(r for r in result if torch.is_floating_point(r))
assert len(result) > 0
return result
return wrapped, primals
def normalize_op_input_output(f, sample, requires_grad=True):
args = tuple([sample.input] + list(sample.args))
return normalize_op_input_output2(
f,
args,
sample.kwargs,
sample.output_process_fn_grad,
requires_grad=requires_grad,
)
def ref_vjp(f, *primals):
result = f(*primals)
def wrapped(cotangents):
return _autograd_grad(_as_tuple(result), primals, _as_tuple(cotangents))
return result, wrapped
def simulate_jvp(f, primals, tangents):
primals_out, tangents_out = torch.autograd.functional.jvp(f, primals, tangents)
return primals_out, tangents_out
def ref_jvp(f, primals, tangents):
with fwAD.dual_level():
duals = tuple(fwAD.make_dual(p, t) for p, t in zip(primals, tangents))
result_duals = f(*duals)
result_duals, spec = tree_flatten(result_duals)
primals_out, tangents_out = zip(*(fwAD.unpack_dual(d) for d in result_duals))
return tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec)
def get_sample_cotangents(f, sample):
fn, primals = normalize_op_input_output(f, sample)
output = fn(*primals)
return tree_map(torch.randn_like, output)
# returns a new function g(*args, *cotangents)
# that computes vjps and (*args, cotangents)
def get_vjp_fn_and_args_with_cotangents(f, sample, cotangents):
args = tuple([sample.input] + list(sample.args))
kwargs = sample.kwargs
flat_args, args_spec = tree_flatten(args)
flat_cotangents, cotangents_spec = tree_flatten(cotangents)
@functools.wraps(f)
def wrapped(*args):
assert len(args) == len(flat_args) + len(flat_cotangents)
actual_args = args[: len(flat_args)]
cotangents = args[len(flat_args) :]
actual_args = tree_unflatten(actual_args, args_spec)
cotangents = tree_unflatten(cotangents, cotangents_spec)
fn, primals = normalize_op_input_output3(
f, actual_args, kwargs, flat_args, sample.output_process_fn_grad
)
_, vjp_fn = vjp(fn, *primals)
return vjp_fn(cotangents)
return wrapped, tuple(flat_args + flat_cotangents)
# Returns a new function g(*args, *cotangents) that computes vjps and
# sample (*args, *cotangents)
def get_vjpfull_variant(f, sample):
fn, primals = normalize_op_input_output(f, sample)
return _get_vjpfull_variant(fn, primals)
def get_vjpfull_variant2(f, args, kwargs):
fn, primals = normalize_op_input_output2(f, args, kwargs)
return _get_vjpfull_variant(fn, primals)
def _get_vjpfull_variant(fn, primals):
result = fn(*primals)
cotangents = _as_tuple(
tree_map(lambda x: torch.randn_like(x, requires_grad=True), result)
)
num_primals = len(primals)
args = (*primals, *cotangents)
@functools.wraps(fn)
def wrapped(*args):
primals = args[:num_primals]
cotangents = args[num_primals:]
result, vjp_fn = vjp(fn, *primals)
if isinstance(result, torch.Tensor):
assert len(cotangents) == 1
cotangents = cotangents[0]
return vjp_fn(cotangents)
return wrapped, args
def get_jvp_variant(f, sample):
# We want this higher-order variant of jvp, so that it can
# be used to wrap vmap
fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals))
@functools.wraps(f)
def wrapped(*args):
tangents = args
primals_out, tangents_out = jvp(fn, primals, tangents)
if isinstance(primals_out, torch.Tensor):
return (primals_out, tangents_out)
else:
flat_primals_out = pytree.tree_leaves(primals_out)
flat_tangents_out = pytree.tree_leaves(tangents_out)
return tuple(flat_primals_out + flat_tangents_out)
return wrapped, tangents
def get_jvp_variant_primals_tangents2(
f, args, kwargs, output_process_fn_grad=None, requires_grad=False
):
fn, primals = normalize_op_input_output2(
f, args, kwargs, output_process_fn_grad, requires_grad
)
tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals))
return _get_jvp_variant(fn, primals, tangents)
def get_jvp_variant_primals_tangents(f, sample):
# We want this higher-order variant of jvp, so that it can
# be used to wrap vmap
fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals))
return _get_jvp_variant(fn, primals, tangents)
def _get_jvp_variant(fn, primals, tangents):
@functools.wraps(fn)
def wrapped(*args):
primals_in = args[: len(primals)]
tangents_in = args[len(primals) :]
primals_out, tangents_out = jvp(fn, primals_in, tangents_in)
if isinstance(primals_out, torch.Tensor):
return (primals_out, tangents_out)
else:
flat_primals_out = pytree.tree_leaves(primals_out)
flat_tangents_out = pytree.tree_leaves(tangents_out)
return tuple(flat_primals_out + flat_tangents_out)
return wrapped, primals + tangents
def is_inplace(op, variant):
if hasattr(variant, "__wrapped__"):
return variant.__wrapped__ is op.get_inplace()
return variant is op.get_inplace()
vjp_fail = {
xfail("tensor_split"), # data_ptr composite compliance
# Very minor accuracy issue on ROCm
decorate("nn.functional.scaled_dot_product_attention", decorator=skipIfRocm),
}
aliasing_ops = {
"T",
"broadcast_to",
"conj",
"contiguous",
"diagonal", # linalg.diagonal is an alias
"expand",
"flatten",
"imag",
"mH", # adjoint is an alias
"mT",
"movedim", # moveaxis is an alias
"narrow",
"permute",
"positive",
# 'ravel', is composite implicit autograd and may call clone
"real",
"reshape",
"resolve_conj",
"resolve_neg",
"select",
"squeeze",
"transpose", # swapdims and swapaxes are aliases
"unflatten",
"unfold",
"unsqueeze",
"view",
"view_as",
"view_as_complex",
"view_as_real",
}
aliasing_ops_list_return = {
"chunks",
"dsplit",
"hsplit",
"split",
"unbind",
"vsplit",
# 'tensor_split' not composite compliant, see vjp_fail
}
skip_noncontig = {
"_batch_norm_with_update",
"as_strided_copy",
}
@unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant")
@unMarkDynamoStrictTest
class TestOperators(TestCase):
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@skipOps(
"TestOperators",
"test_grad",
vjp_fail.union(
{
xfail(
"chalf", "", device_type="cpu"
), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
xfail(
"sparse.sampled_addmm", ""
), # RuntimeError: Sparse CSR tensors do not have strides
xfail(
"sparse.mm", "reduce"
), # RuntimeError: Sparse CSR tensors do not have strides
# Non-contiguous Bugs
#
# AssertionError: Tensor-likes are not close!
xfail("_softmax_backward_data", device_type="cpu"),
xfail("as_strided"),
xfail("as_strided", "partial_views"),
# RuntimeError: !self.requires_grad() || self.is_contiguous()
xfail("as_strided_scatter"),
# RuntimeError: Tensor must have a last dimension with stride 1
xfail("view_as_complex"),
# query: last dimension must be contiguous
# Fused attention kernels require last dim to be contiguous
decorate(
"nn.functional.scaled_dot_product_attention",
decorator=expectedFailureIf(not TEST_WITH_ROCM),
), # Works on ROCm
xfail("torch.ops.aten._flash_attention_forward"),
xfail("torch.ops.aten._efficient_attention_forward"),
# RuntimeError: Expected contiguous tensor, but got
# non-contiguous tensor for argument #2 'grad_output'
decorate(
"_batch_norm_with_update",
decorator=expectedFailureIf(TEST_WITH_ROCM),
device_type="cuda",
),
}
),
)
@opsToleranceOverride(
"TestOperators",
"test_grad",
(
tol1(
"nn.functional.binary_cross_entropy_with_logits",
{torch.float32: tol(atol=1e-04, rtol=1e-04)},
),
tol1("masked.cumprod", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
tol1("svd_lowrank", {torch.float32: tol(atol=3e-04, rtol=3e-04)}),
tol1(
"linalg.multi_dot",
{torch.float32: tol(atol=1e-05, rtol=8e-04)},
device_type="cuda",
),
tol1(
"linalg.tensorsolve",
{torch.float32: tol(atol=3e-04, rtol=3e-04)},
device_type="cuda",
),
tol1(
"nn.functional.multi_head_attention_forward",
{torch.float32: tol(atol=8e-04, rtol=1e-03)},
),
tol1(
"__rmatmul__",
{torch.float32: tol(atol=3e-04, rtol=3e-04)},
device_type="cuda",
),
tol1(
"matmul",
{torch.float32: tol(atol=3e-04, rtol=3e-04)},
device_type="cuda",
),
tol1(
"pca_lowrank",
{torch.float32: tol(atol=3e-05, rtol=4e-06)},
device_type="cpu",
),
),
)
def test_grad(self, device, dtype, op):
if op.name in vjp_fail:
self.skipTest("Skipped; Expected failures")
return
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
if is_inplace(op, op.get_op()):
self.skipTest("Skipped for redundancy. test_vjp handles in-place testing.")
return
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
if op.name not in skip_noncontig:
noncontig_sample = sample.noncontiguous()
noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
noncontig_kwargs = noncontig_sample.kwargs
diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg))
assert len(diff_argnums) > 0
diff_args = tuple(args[i] for i in diff_argnums)
def wrapped_fn(*args, **kwargs):
result = op(*args, **kwargs)
if sample.output_process_fn_grad is not None:
result = sample.output_process_fn_grad(result)
def abs_if_complex(t):
if t.dtype.is_complex:
return t.abs()
return t
# Reduce into single value for grad
if isinstance(result, torch.Tensor):
return abs_if_complex(result.sum())
result = sum(abs_if_complex(res.sum()) for res in result)
return result
result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args)
self.assertEqual(result, expected)
if op.name not in skip_noncontig:
result_noncontig = grad(wrapped_fn, diff_argnums)(
*noncontig_args, **noncontig_kwargs
)
self.assertEqual(result_noncontig, expected)
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@skipOps(
"TestOperators",
"test_jvp",
set(
{
# Composite ops that do bad things. Need to be fixed in PyTorch core.
# RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
xfail("tensor_split"),
# BUG: silent incorrectness: runs and produces numerical differences
skip("nn.functional.max_unpool1d"), # fails everywhere except on mac
skip(
"nn.functional.max_unpool2d"
), # fails everywhere except on windows
skip("nn.functional.max_unpool3d"), # fails everywhere except on mac
xfail(
"native_batch_norm"
), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
xfail(
"_native_batch_norm_legit"
), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
xfail(
"_batch_norm_with_update"
), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
xfail("nn.functional.scaled_dot_product_attention"),
xfail("torch.ops.aten._flash_attention_forward"),
xfail("torch.ops.aten._efficient_attention_forward"),
xfail(
"nn.functional.rrelu"
), # in-place test errors out with no formula implemented
xfail(
"NumpyExpMarkDirtyAutogradFunction"
), # TODO: https://github.com/pytorch/pytorch/issues/91280
# --- Non-Contiguous Failures! ---
# This is expected to fail as the operator
# expects last dim to have stride=1
xfail("view_as_complex"),
# BUG
# AssertionError: Tensor-likes are not close!
xfail("as_strided"),
xfail("as_strided", "partial_views"),
xfail("as_strided_scatter"),
decorate(
"linalg.det",
"singular",
decorator=expectedFailureIf(IS_MACOS and IS_X86),
),
}
),
)
@opsToleranceOverride(
"TestOperators",
"test_jvp",
(
tol1(
"nn.functional.conv_transpose3d",
{torch.float32: tol(atol=1e-04, rtol=1.3e-06)},
device_type="cuda",
),
tol1(
"linalg.tensorsolve",
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)},
device_type="cuda",
),
tol1(
"masked.prod",
{torch.float32: tol(atol=1e-05, rtol=1.3e-05)},
device_type="cuda",
),
tol1(
"nn.functional.binary_cross_entropy_with_logits",
{torch.float32: tol(atol=4e-04, rtol=4e-04)},
),
tol1(
"nn.functional.batch_norm", {torch.float32: tol(atol=4e-05, rtol=5e-05)}
),
tol1("nn.functional.conv2d", {torch.float32: tol(atol=4e-05, rtol=5e-05)}),
tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
tol1(
"nn.functional.multi_head_attention_forward",
{torch.float32: tol(atol=6e-05, rtol=2e-05)},
),
tol2(
"linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-5, rtol=2e-5)}
),
),
)
def test_jvp(self, device, dtype, op):
# TODO: get rid of vjp_decomp when we add decomposition support to
# PyTorch's forward-mode ad. Currently the decomposition support only
# works for functorch.jvp
VJP_DECOMP = {
"nn.functional.logsigmoid",
}
if op.name in VJP_DECOMP:
fixme_ref_jvp_local = simulate_jvp
else:
fixme_ref_jvp_local = ref_jvp
if not op.supports_forward_ad and op.name not in VJP_DECOMP:
self.skipTest("Skipped! Forward AD not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
outplace_variant = op if not is_inplace(op, op.get_op()) else None
inplace_variant = op.inplace_variant if op.supports_inplace_autograd else None
for sample in samples:
if outplace_variant:
self.jvp_opinfo_test(
outplace_variant,
sample,
sample.output_process_fn_grad,
clone_inputs=False,
fixme_ref_jvp_local=fixme_ref_jvp_local,
test_noncontig=op.name not in skip_noncontig,
)
if is_valid_inplace_sample_input(sample, op, inplace_variant):
self.jvp_opinfo_test(
inplace_variant,
sample,
sample.output_process_fn_grad,
clone_inputs=True,
fixme_ref_jvp_local=fixme_ref_jvp_local,
test_noncontig=op.name not in skip_noncontig,
)
def jvp_opinfo_test(
self,
fn,
sample,
output_process_fn,
clone_inputs,
fixme_ref_jvp_local,
test_noncontig,
):
# NB: we used requires_grad=True to determine where the primals are,
# but don't need that information otherwise
args = (sample.input,) + sample.args
kwargs = sample.kwargs
contig_fn, primals = normalize_op_input_output2(
fn, args, kwargs, output_process_fn, requires_grad=True
)
orig_primals = tree_map(lambda x: x.detach(), primals)
orig_tangents = tree_map(lambda x: torch.randn_like(x), primals)
def maybe_clone_inputs():
if clone_inputs:
primals = tree_map(torch.clone, orig_primals)
tangents = tree_map(torch.clone, orig_tangents)
return primals, tangents
return orig_primals, orig_tangents
primals, tangents = maybe_clone_inputs()
expected_primal_outs, expected_tangent_outs = fixme_ref_jvp_local(
contig_fn, primals, tangents
)
primals, tangents = maybe_clone_inputs()
primal_outs, tangent_outs = jvp(contig_fn, primals, tangents)
self.assertEqual(primal_outs, expected_primal_outs)
self.assertEqual(tangent_outs, expected_tangent_outs)
if test_noncontig:
noncontig_sample = sample.noncontiguous()
noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
noncontig_kwargs = sample.kwargs
noncontig_fn, primals = normalize_op_input_output2(
fn,
noncontig_args,
noncontig_kwargs,
output_process_fn,
requires_grad=True,
)
noncontig_primals = tree_map(lambda x: x.detach(), primals)
noncontig_tangents = tree_map(
lambda x: noncontiguous_like(x), orig_tangents
)
noncontig_primal_outs, noncontig_tangent_outs = jvp(
noncontig_fn, noncontig_primals, noncontig_tangents
)
self.assertEqual(noncontig_primal_outs, expected_primal_outs)
self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@skipOps(
"TestOperators",
"test_vjp",
vjp_fail.union(
{
xfail("sparse.sampled_addmm", ""),
xfail("sparse.mm", "reduce"),
# ---- Non-Contiguous Failures ----
# This is expected to fail as the operator
# expects last dim to have stride=1
xfail("view_as_complex"),
# RuntimeError: query: last dimension must be contiguous
# The fused attention kernels require the last dim to be contiguous
decorate(
"nn.functional.scaled_dot_product_attention",
decorator=expectedFailureIf(not TEST_WITH_ROCM),
), # Works on ROCm
xfail("torch.ops.aten._flash_attention_forward"),
xfail("torch.ops.aten._efficient_attention_forward"),
# BUG
# AssertionError: Tensor-likes are not close!
xfail("as_strided"),
xfail("as_strided_scatter"),
xfail("_softmax_backward_data", device_type="cpu"),
xfail("as_strided", "partial_views"),
}
),
)
@opsToleranceOverride(
"TestOperators",
"test_vjp",
(
tol1(
"nn.functional.conv_transpose3d",
{torch.float32: tol(atol=5e-05, rtol=9e-05)},
device_type="cuda",
),
tol1(
"nn.functional.binary_cross_entropy_with_logits",
{torch.float32: tol(atol=1e-04, rtol=1e-04)},
),
tol1(
"nn.functional.multi_head_attention_forward",
{torch.float32: tol(atol=2e-03, rtol=2e-04)},
),
tol1("__rmatmul__", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
tol1("matmul", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
tol2(
"linalg.pinv", "hermitian", {torch.float32: tol(atol=1e-05, rtol=1e-05)}
),
tol1("linalg.tensorsolve", {torch.float32: tol(atol=9e-03, rtol=2e-04)}),
tol1("linalg.multi_dot", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
tol1("svd_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
tol1("pca_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
),
)
def test_vjp(self, device, dtype, op):
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
def _test(_op, inplace=False):
for sample in samples:
if inplace and not is_valid_inplace_sample_input(
sample, op, op.inplace_variant
):
continue
fn, primals = normalize_op_input_output(_op, sample)
result = fn(*primals)
cotangents = tree_map(lambda x: torch.randn_like(x), result)
out, vjp_fn = vjp(fn, *primals)
self.assertEqual(out, result)
result_vjps = vjp_fn(cotangents)
_, vjp_fn = ref_vjp(fn, *primals)
expected_vjps = vjp_fn(cotangents)
self.assertEqual(result_vjps, expected_vjps)
if op.name not in skip_noncontig:
noncontig_fn, noncontig_primals = normalize_op_input_output(
_op, sample.noncontiguous()
)
noncontig_cotangents = tree_map(
lambda x: noncontiguous_like(x), cotangents
)
out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
self.assertEqual(out_noncontig, result)
noncontig_result_vjps = vjp_fn(noncontig_cotangents)
self.assertEqual(noncontig_result_vjps, expected_vjps)
_test(op)
for a_op in op.aliases:
_test(a_op)
if op.inplace_variant:
def f(inp, *args, **kwargs):
return op.inplace_variant(inp.clone(), *args, **kwargs)
_test(f, inplace=True)
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@skipOps(
"TestOperators",
"test_vjpvjp",
vjp_fail.union(
{
skip("nn.functional.max_unpool1d"), # silent incorrectness; Flaky
skip("nn.functional.max_unpool2d"), # silent incorrectness; Flaky
xfail("nn.functional.ctc_loss"), # Not Implemented
xfail(
"native_layer_norm", ""
), # Expected a proper Tensor but got None for argument #1 'other'
xfail("sparse.sampled_addmm", ""), # sparse tensors have no strides
xfail("sparse.mm", "reduce"), # sparse tensors have no strides
skip("nn.functional.scaled_dot_product_attention"),
xfail("torch.ops.aten._flash_attention_forward"),
xfail("torch.ops.aten._efficient_attention_forward"),
# AssertionError: Tensor-likes are not close!
# Mismatched elements: 1 / 15 (6.7%)
# Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
# Greatest relative difference: 1.7933241714393998e-06 at index (2, 4) (up to 1.3e-06 allowed)
# The failure occurred for item [0]
xfail("masked.prod"),
}
),
)
@opsToleranceOverride(
"TestOperators",
"test_vjpvjp",
(
tol1(
"nn.functional.conv_transpose3d",
{torch.float32: tol(atol=5e-05, rtol=9e-05)},
device_type="cuda",
),
tol1("prod", {torch.float32: tol(atol=2e-05, rtol=1e-04)}),
tol1("masked.cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1("cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1("linalg.vander", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol2(
"linalg.det", "singular", {torch.float32: tol(atol=2e-05, rtol=2e-05)}
),
),
)
def test_vjpvjp(self, device, dtype, op):
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return
if not op.supports_gradgrad:
self.skipTest("Skipped! Operation does not support gradgrad")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
def test(_op, inplace=False):
for sample in samples:
if inplace and not is_valid_inplace_sample_input(
sample, op, op.inplace_variant
):
continue
fn, args = get_vjpfull_variant(_op, sample)
result = fn(*args)
cotangents = tree_map(lambda x: torch.randn_like(x), result)
# Compute vjp of vjp
_, vjp_fn = vjp(fn, *args)
result_vjps = vjp_fn(cotangents)
# Compute ref_vjp of vjp. We could have done ref_vjp of ref_vjp,
# but since we're confident that vjp works by itself, this is
# an equivalent way to test that.
_, vjp_fn = ref_vjp(fn, *args)
expected_vjps = vjp_fn(cotangents)
self.assertEqual(result_vjps, expected_vjps)
test(op)
if op.inplace_variant:
def fn(inp, *args, **kwargs):
return op.inplace_variant(inp.clone(), *args, **kwargs)
test(fn, inplace=True)
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@skipOps(
"TestOperators",
"test_vmapvjpvjp",
vjp_fail.union(
{
skip("atleast_1d"), # Takes too long
skip("atleast_2d"), # Takes too long
skip("atleast_3d"), # Takes too long
skip("ormqr"), # Takes too long
xfail("as_strided"), # incorrect output
xfail("as_strided", "partial_views"), # incorrect output
xfail("as_strided_scatter"), # incorrect output
skip("bernoulli"), # calls random op
xfail("bfloat16"), # rank 4 tensor for channels_last
xfail("cdouble"), # rank 4 tensor for channels_last
xfail("cfloat"), # rank 4 tensor for channels_last
xfail("chalf"), # rank 4 tensor for channels_last
xfail("double"), # rank 4 tensor for channels_last
xfail("float"), # rank 4 tensor for channels_last
xfail("half"), # rank 4 tensor for channels_last
xfail(
"NumpyCubeNotComposableAutogradFunction"
), # Not composable autograd.Function
# It looks like you're either (1) calling .item() on a Tensor or
# (2) attempting to use a Tensor in some data-dependent control flow or
# (3) encountering this error in PyTorch internals.
xfail("index_reduce", "prod"),
decorate(
"linalg.householder_product", decorator=runOnRocm
), # works on ROCm
xfail(
# nans
"masked.softmax",
device_type="cpu",
),
xfail(
"nanquantile", device_type="cpu"
), # vmap not implemented for at::equal.
xfail("native_layer_norm"), # vmap: inplace into a regular tensor
# got a batched tensor as input while the running_mean or running_var,
# which will be updated in place, were not batched.
xfail("nn.functional.batch_norm"),
xfail(
"nn.functional.binary_cross_entropy"
), # vmap: inplace into a regular tensor
xfail(
"nn.functional.ctc_loss"
), # derivative not implemented for _ctc_loss_backward
# flaky on ROCM needs investigation
decorate("nn.functional.conv_transpose2d", decorator=skipIfRocm),
skip("nn.functional.dropout"), # calls random op
skip("nn.functional.dropout2d"), # calls random op
skip("nn.functional.dropout3d"), # calls random op
skip("nn.functional.alpha_dropout"), # calls random op
skip(
"nn.functional.feature_alpha_dropout", "with_train"
), # calls random op
skip("nn.functional.fractional_max_pool2d"), # calls random op
skip("nn.functional.fractional_max_pool3d"), # calls random op
xfail("nn.functional.scaled_dot_product_attention"), # randomness
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
xfail("nn.functional.multi_head_attention_forward"), # randomness
# It looks like you're either (1) calling .item() on a Tensor or
# (2) attempting to use a Tensor in some data-dependent control flow or
# (3) encountering this error in PyTorch internals.
xfail("nn.functional.gaussian_nll_loss"),
# got a batched tensor as input while the running_mean or running_var,
# which will be updated in place, were not batched.
xfail("nn.functional.instance_norm"),
xfail(
"nn.functional.layer_norm"
), # vmap: inplace into a regular tensor
# RuntimeError: NYI: querying is_contiguous inside of vmap
# for memory_format other than torch.contiguous_formats
xfail("nn.functional.max_pool2d"),
# RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only
# supported with memory_format torch.preserve_format or
# torch.contiguous_format (got ChannelsLast)
xfail("nn.functional.max_unpool2d"),
# RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only
# supported with memory_format torch.preserve_format
# or torch.contiguous_format (got ChannelsLast)
xfail("nn.functional.max_unpool2d", "grad"),
xfail(
"nn.functional.rrelu"
), # RuntimeError: vmap: we do not yet support aten::rrelu_with_noise.
xfail("normal"), # calls random op
xfail("normal", "number_mean"), # calls random op
xfail("pca_lowrank"), # calls random op
xfail(
"quantile", device_type="cpu"
), # Batching rule not implemented for `at::equal`
xfail(
"scatter_reduce", "prod"
), # vmap (looks like you are calling item/data-dependent)
xfail(
"sparse.sampled_addmm"
), # RuntimeError: Sparse CSR tensors do not have strides
xfail(
"sparse.mm", "reduce"
), # RuntimeError: Sparse CSR tensors do not have strides
xfail("svd_lowrank"), # calls random op
xfail("to"), # rank 4 tensor for channels_last
xfail(
"view_as_complex"
), # RuntimeError: Tensor must have a last dimension with stride 1
# got a batched tensor as input while the running_mean or running_var,
# which will be updated in place, were not batched.
xfail("nn.functional.batch_norm", "without_cudnn"),
# view doesn't work on sparse
xfail("to_sparse"),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
# TODO: implement batching rule
xfail("_batch_norm_with_update"),
}
),
)
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride(
"TestOperators",
"test_vmapvjpvjp",
(
tol1("linalg.svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
tol1("linalg.lu", {torch.float32: tol(atol=5e-04, rtol=7e-04)}),
tol1("linalg.lu_factor", {torch.float32: tol(atol=2e-03, rtol=2e-02)}),
tol1("linalg.multi_dot", {torch.float32: tol(atol=2e-03, rtol=2e-04)}),
tol1("svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
tol1("matrix_exp", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
tol1("masked.prod", {torch.float32: tol(atol=2e-03, rtol=2e-04)}),
),
)
@skipOps(
"TestOperators",
"test_vmapvjpvjp",
{
xfail("as_strided", "partial_views"),
xfail("as_strided_copy"),
},
)
def test_vmapvjpvjp(self, device, dtype, op):
# Since, we test `vjpvjp` independently,
# for this test, we just verify that vmap
# of `vjpvjp` is correct.
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return
if not op.supports_gradgrad:
self.skipTest("Skipped! Operation does not support gradgrad")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
# TODO: test in-place
if is_inplace(op, op.get_op()):
self.skipTest("Skipped! NYI: inplace-testing not supported.")
return
for sample in samples:
fn, args = get_vjpfull_variant(op, sample)
result = fn(*args)
cotangents = tree_map(lambda x: torch.randn_like(x), result)
cotangents = pytree.tree_leaves(cotangents)
num_args = len(args)
args_and_cotangents = tuple(args) + tuple(cotangents)
def vjp_of_vjp(*args_and_cotangents):
args = args_and_cotangents[:num_args]
cotangents = args_and_cotangents[num_args:]
result, vjp_fn = vjp(fn, *args)
result_vjps = vjp_fn(cotangents)
result = pytree.tree_leaves(result)
result_vjps = pytree.tree_leaves(result_vjps)
return (*result, *result_vjps)
is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
generator = get_fallback_and_vmap_exhaustive(
vjp_of_vjp,
args_and_cotangents,
{},
is_batch_norm_and_training=is_batch_norm_and_training,
)
for loop_out, batched_out in generator:
self.assertEqual(loop_out, batched_out)
vmapvjp_fail = vjp_fail.union(
{
# -------------------- ALLOWED FAILURES --------------------------------
# The following are not bugs and are expected behavior
xfail("masked_select"), # Not possible due to dynamic shapes
skip("bernoulli"), # randomness
skip("normal", ""), # randomness
skip("normal", "number_mean"), # randomness
skip("nn.functional.rrelu"), # randomness
skip("nn.functional.feature_alpha_dropout", "with_train"), # randomness
skip("nn.functional.feature_alpha_dropout", "without_train"), # randomness
skip("nn.functional.dropout"), # randomness
skip("nn.functional.dropout2d"), # randomness
skip("nn.functional.dropout3d", ""), # randomness
skip("nn.functional.alpha_dropout"), # randomness
skip("nn.functional.scaled_dot_product_attention"), # randomness
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
skip("nn.functional.multi_head_attention_forward"), # randomness
xfail(
"index_put", ""
), # not possible due to dynamic shapes; we support a subset
xfail("nn.functional.fractional_max_pool2d"), # random
xfail("nn.functional.fractional_max_pool3d"), # random
xfail("pca_lowrank", ""), # randomness
xfail("svd_lowrank", ""), # randomness
xfail("to_sparse", ""), # non-dense output
skip(
"to"
), # RuntimeError: required rank 4 tensor to use channels_last format
xfail("as_strided", "partial_views"),
xfail(
"NumpyCubeNotComposableAutogradFunction"
), # Not composable autograd.Function
# ----------------------------------------------------------------------
# ---------------------------- BUGS ------------------------------------
# All of the following are bugs and need to be fixed
skip(
"linalg.svdvals"
), # really annoying thing where it passes correctness check but not has_batch_rule
skip("native_batch_norm"),
skip("_native_batch_norm_legit"),
# TODO: implement batching rule
skip("_batch_norm_with_update"),
xfail("__getitem__", ""), # dynamic error
xfail("nanquantile", device_type="cpu"), # checks q via a .item() call
xfail("nn.functional.gaussian_nll_loss"), # checks var for if any value < 0
xfail("narrow"), # .item() call
xfail("quantile", device_type="cpu"), # checks q via a .item() call
xfail("view_as_complex"), # Tensor must have a last dimension with stride 1
# required rank 4 tensor to use channels_last format
xfail("bfloat16"),
xfail("double"),
xfail("float"),
xfail("half"),
xfail("cdouble", ""),
xfail("cfloat", ""),
xfail("chalf", ""),
xfail("scatter_reduce", "prod"), # item call
# Batching rule not implemented for aten::_use_cudnn_ctc_loss.Tensor
xfail("nn.functional.ctc_loss", device_type="cuda"),
# NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format
xfail("nn.functional.max_unpool2d"),
xfail("nn.functional.max_unpool2d", "grad"),
xfail("sparse.sampled_addmm", ""),
xfail("sparse.mm", "reduce"),
xfail("as_strided_scatter", ""), # calls as_strided
xfail("index_reduce", "prod"), # .item() call
# ---------------------------------------------------------------------
}
)
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride(
"TestOperators",
"test_vmapvjp",
(
tol1(
"linalg.svd",
{torch.float32: tol(atol=5e-04, rtol=1e-04)},
device_type="cuda",
),
tol1(
"svd", {torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda"
),
tol1(
"linalg.householder_product",
{torch.float32: tol(atol=3e-04, rtol=9e-04)},
),
tol1(
"matrix_exp",
{torch.float32: tol(atol=5e-04, rtol=1e-04)},
device_type="cuda",
),
tol1(
"nn.functional.layer_norm",
{torch.float32: tol(atol=3e-4, rtol=1e-4)},
device_type="cpu",
),
tol1(
"native_layer_norm",
{torch.float32: tol(atol=3e-4, rtol=1e-4)},
device_type="cpu",
),
),
)
@skipOps(
"TestOperators",
"test_vmapvjp",
vmapvjp_fail.union(
{
xfail("as_strided"),
xfail("as_strided_copy"),
xfail("as_strided", "partial_views"),
}
),
)
def test_vmapvjp(self, device, dtype, op):
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
# TODO: test in-place
if is_inplace(op, op.get_op()):
self.skipTest("Skipped! NYI: inplace-testing not supported.")
return
for sample in samples:
cotangents = get_sample_cotangents(op, sample)
fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
generator = get_fallback_and_vmap_exhaustive(
fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training
)
for loop_out, batched_out in generator:
self.assertEqual(loop_out, batched_out)
vmapjvpall_fail = {
# -------------------- ALLOWED FAILURES --------------------------------
# The following are expected (not a bug)
skip("bernoulli", ""), # randomness
skip("nn.functional.dropout"), # randomness
skip("nn.functional.rrelu"), # randomness
skip("nn.functional.dropout2d", ""),
skip("nn.functional.dropout3d", ""),
skip("nn.functional.scaled_dot_product_attention"), # randomness
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
skip("nn.functional.multi_head_attention_forward"), # randomness
skip("nn.functional.alpha_dropout"), # randomness
skip("nn.functional.feature_alpha_dropout", "without_train"),
skip("nn.functional.feature_alpha_dropout", "with_train"),
xfail(
"nn.functional.fractional_max_pool2d"
), # Cannot access data pointer of Tensor that doesn't have storage
xfail(
"nn.functional.fractional_max_pool3d"
), # Cannot access data pointer of Tensor that doesn't have storage
# Not actually a problem: embedding with max_norm mutates the weight
# and causes different runs to produce different results.
# skip because this is flaky depending on what the max_norm is!
skip("nn.functional.embedding", ""),
skip("to"), # RuntimeError: required rank 4 tensor to use channels_last format
xfail(
"NumpyExpMarkDirtyAutogradFunction"
), # vmap: inplace into a regular tensor
# ----------------------------------------------------------------------
# ---------------------------- BUGS ------------------------------------
# The following are bugs that we should fix
xfail("masked.mean"), # silent incorrectness (nan difference)
xfail("as_strided", "partial_views"), # Tensor-likes are not close!
xfail(
"nn.functional.soft_margin_loss", ""
), # soft_margin_loss_backward does not support forward-ad
xfail("tensor_split"), # data_ptr composite compliance
xfail("quantile"), # at::equal batching rule (cpu), also, in-place vmap (cuda)
skip("as_strided"), # Test runner cannot handle this
# requires special handling, and does not yet have a batching rule. Feel free to file a github issue!
xfail("as_strided_scatter"),
xfail(
"nn.functional.gaussian_nll_loss"
), # .item or data-dependent control flow
xfail("scatter"), # forward-mode AD does not support at::scatter
xfail(
"nanquantile"
), # at::equal batching rule (cpu), also, in-place vmap (cuda)
xfail("view_as_complex"), # Tensor must have a last dimension with stride 1
skip("pca_lowrank", ""), # randomness
skip("svd_lowrank", ""), # randomness
xfail("double"), # required rank 4 tensor to use channels_last format
xfail("cdouble"), # required rank 4 tensor to use channels_last format
# potential silent incorrectness
skip(
"nn.functional.max_unpool1d"
), # Flaky, seems to sometimes hit max_unpool2d
skip("nn.functional.max_unpool2d"), # fails everywhere except on mac
skip("nn.functional.max_unpool3d"), # fails everywhere except on mac
# erroring because running_mean and running_var aren't differentiable
xfail("nn.functional.batch_norm"),
xfail("nn.functional.batch_norm", "without_cudnn"),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
# TODO: implement batching rule
xfail("_batch_norm_with_update"),
# ----------------------------------------------------------------------
}
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride(
"TestOperators",
"test_vmapjvpall",
(
tol1(
"nn.functional.conv_transpose3d",
{torch.float32: tol(atol=2e-04, rtol=9e-3)},
device_type="cuda",
),
tol1(
"linalg.householder_product",
{torch.float32: tol(atol=2e-04, rtol=9e-3)},
),
),
)
@skipOps(
"TestOperators",
"test_vmapjvpall",
vmapjvpall_fail.union(
{
xfail("as_strided_copy"),
decorate(
"linalg.det",
"singular",
decorator=expectedFailureIf(IS_MACOS and IS_X86),
),
}
),
)
# This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp
# or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact
# because that corresponds to "batched forward-mode AD" testing in PyTorch core
def test_vmapjvpall(self, device, dtype, op):
if is_inplace(op, op.get_op()):
# TODO: test in-place
self.skipTest("Skipped! NYI: inplace-testing not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=False)
if not op.supports_forward_ad:
self.skipTest("Skipped! Forward AD not supported.")
return
for sample in samples:
arg_values = [sample.input] + list(sample.args)
kwarg_values = sample.kwargs
args = tuple(arg_values) + tuple(kwarg_values)
fn, args = get_jvp_variant_primals_tangents(op, sample)
is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
generator = get_fallback_and_vmap_exhaustive(
fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training
)
for loop_out, batched_out in generator:
self.assertEqual(loop_out, batched_out)
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@skipOps(
"TestOperators",
"test_vmapjvpall_has_batch_rule",
vmapjvpall_fail.union(
{
skip(
"to"
), # RuntimeError: required rank 4 tensor to use channels_last format
xfail(
"cdouble"
), # RuntimeError: required rank 4 tensor to use channels_last format
xfail("cumprod"),
xfail("masked_fill"),
xfail("fill"),
skip("masked.mean"), # ???
xfail("masked_scatter"),
xfail("put"),
xfail("take"),
xfail("nn.functional.feature_alpha_dropout", "without_train"),
xfail("nn.functional.dropout2d", ""),
xfail("pca_lowrank", ""),
xfail("svd_lowrank", ""),
xfail("nn.functional.feature_alpha_dropout", "with_train"),
xfail("special.log_ndtr", ""),
xfail("fft.ihfft2"), # conj_physical fallback
xfail("fft.ihfftn"), # conj_physical fallback
xfail("nn.functional.max_unpool3d", "grad"),
xfail("nn.functional.max_unpool2d", "grad"),
xfail("nn.functional.soft_margin_loss", ""),
xfail("nn.functional.max_unpool1d", "grad"),
xfail("nn.functional.embedding", ""),
xfail(
"scatter_reduce", "sum"
), # aten::scatter_reduce.two hit the vmap fallback
xfail(
"scatter_reduce", "mean"
), # aten::scatter_reduce.two hit the vmap fallback
xfail(
"scatter_reduce", "amin"
), # aten::scatter_reduce.two hit the vmap fallback
xfail(
"scatter_reduce", "amax"
), # aten::scatter_reduce.two hit the vmap fallback
xfail("nn.functional.glu"),
xfail("nn.functional.bilinear"), # trilinear doesn't have batching rule
xfail("linalg.lu", ""),
xfail("nn.functional.dropout3d", ""),
xfail("as_strided_scatter", ""),
xfail("masked.cumprod", ""),
xfail("renorm"), # hit vmap fallback, which is disabled
xfail("t_copy"),
xfail("unsqueeze_copy"),
}
),
)
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
if is_inplace(op, op.get_op()):
# TODO: test in-place
self.skipTest("Skipped! NYI: inplace-testing not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=False)
if not op.supports_forward_ad:
self.skipTest("Skipped! Forward AD not supported.")
return
def test():
for sample in samples:
arg_values = [sample.input] + list(sample.args)
kwarg_values = sample.kwargs
args = tuple(arg_values) + tuple(kwarg_values)
fn, args = get_jvp_variant_primals_tangents(op, sample)
is_batch_norm_and_training = is_batch_norm_training(
op.name, kwarg_values
)
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
fn,
args,
{},
is_batch_norm_and_training=is_batch_norm_and_training,
compute_loop_out=False,
):
pass
check_vmap_fallback(self, test, op, dry_run=False)
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@skipOps(
"TestOperators",
"test_vmapvjp_has_batch_rule",
vmapvjp_fail.union(
{
skip(
"to"
), # RuntimeError: required rank 4 tensor to use channels_last format
xfail("view_as_complex"),
xfail("cummax"),
xfail("cummin"),
xfail("fill"),
xfail(
"narrow"
), # Batching rule not implemented for `narrow.Tensor` (and view op)
xfail("special.log_ndtr"),
xfail("linalg.householder_product"),
xfail("masked_fill"),
xfail("masked_scatter"),
xfail("masked_select"),
xfail("nanquantile"),
xfail("ormqr"),
xfail("put"),
xfail(
"scatter_reduce", "sum"
), # aten::scatter_reduce.two hit the vmap fallback
xfail(
"scatter_reduce", "mean"
), # aten::scatter_reduce.two hit the vmap fallback
xfail(
"scatter_reduce", "amin"
), # aten::scatter_reduce.two hit the vmap fallback
xfail(
"scatter_reduce", "amax"
), # aten::scatter_reduce.two hit the vmap fallback
xfail("quantile"),
xfail("renorm"),
xfail("take"),
xfail("tensor_split"),
xfail("to_sparse"),
xfail("unfold"),
xfail("unfold_copy"),
xfail("nn.functional.dropout"),
xfail("fft.ihfft2"),
xfail("fft.ihfftn"),
xfail("nn.functional.gaussian_nll_loss"),
xfail("nn.functional.bilinear"),
xfail("nn.functional.fractional_max_pool3d"),
xfail("nn.functional.ctc_loss"),
xfail("nn.functional.rrelu"),
xfail("nn.functional.embedding_bag"),
xfail("nn.functional.fractional_max_pool2d"),
xfail("nn.functional.feature_alpha_dropout", "with_train"),
xfail("pca_lowrank", ""),
xfail("nn.functional.dropout2d", ""),
xfail("nn.functional.feature_alpha_dropout", "without_train"),
xfail("svd_lowrank", ""),
xfail("nn.functional.max_unpool2d", ""),
xfail("nn.functional.multi_margin_loss", ""),
xfail("nn.functional.multilabel_margin_loss", ""),
xfail("nn.functional.pdist", ""),
xfail("scatter_reduce", "prod"),
xfail("nn.functional.max_unpool1d", ""),
xfail("nn.functional.max_unpool3d", ""),
xfail("nn.functional.max_unpool3d", "grad"),
xfail("nn.functional.soft_margin_loss", ""),
xfail("nn.functional.max_unpool1d", "grad"),
xfail("nn.functional.max_unpool2d", "grad"),
xfail("linalg.lu", ""),
xfail("cdouble", ""),
xfail("cfloat", ""),
xfail("chalf", ""),
xfail(
"index_reduce", "prod"
), # aten::index_reduce hit the vmap fallback which is currently disabled
xfail(
"index_reduce", "mean"
), # aten::index_reduce hit the vmap fallback which is currently disabled
xfail(
"index_reduce", "amax"
), # aten::index_reduce hit the vmap fallback which is currently disabled
xfail(
"index_reduce", "amin"
), # aten::index_reduce hit the vmap fallback which is currently disabled
xfail("nn.functional.dropout3d", ""),
xfail("as_strided_scatter", ""),
xfail("_segment_reduce", "offsets"),
xfail("_segment_reduce", "lengths"),
xfail("sparse.sampled_addmm", ""),
xfail("sparse.mm", "reduce"),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
# TODO: implement batching rule
xfail("_batch_norm_with_update"),
xfail("native_dropout_backward"),
xfail(
"index_fill"
), # aten::_unique hit the vmap fallback which is currently disabled
xfail("t_copy"),
xfail("unsqueeze_copy"),
}
),
)
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
# TODO: test in-place
if is_inplace(op, op.get_op()):
self.skipTest("Skipped! NYI: inplace-testing not supported.")
return
def test():
for sample in samples:
cotangents = get_sample_cotangents(op, sample)
fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
is_batch_norm_and_training = is_batch_norm_training(
op.name, sample.kwargs
)
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
fn,
args,
{},
is_batch_norm_and_training=is_batch_norm_and_training,
compute_loop_out=False,
):
pass
for a_op in op.aliases:
fn, args = get_vjp_fn_and_args_with_cotangents(
a_op, sample, cotangents
)
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
fn,
args,
{},
is_batch_norm_and_training=is_batch_norm_and_training,
compute_loop_out=False,
):
pass
check_vmap_fallback(self, test, op, dry_run=False)
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@skipOps(
"TestOperators",
"test_vjpvmap",
vjp_fail.union(
{
skip("bernoulli", ""), # vjpvmap testing can't handle randomness
skip("normal", ""), # vjpvmap testing can't handle randomness
skip(
"normal", "number_mean"
), # vjpvmap testing can't handle randomness
skip("nn.functional.rrelu"), # randomness
skip("nn.functional.feature_alpha_dropout", "with_train"), # randomness
skip(
"nn.functional.feature_alpha_dropout", "without_train"
), # randomness
skip("nn.functional.scaled_dot_product_attention"),
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
skip("nn.functional.multi_head_attention_forward"), # randomness
skip("nn.functional.alpha_dropout"), # randomness
skip(
"to"
), # RuntimeError: required rank 4 tensor to use channels_last format
skip("to_sparse", ""), # non-dense output
skip("ormqr", ""), # takes too long
xfail(
"NumpyCubeNotComposableAutogradFunction"
), # Not composable autograd.Function
# fallback path doesn't work
# All of the following are bugs and need to be fixed
xfail("__getitem__", ""),
xfail("index_put", ""),
xfail("view_as_complex"),
xfail("nn.functional.gaussian_nll_loss"),
xfail("masked_select"),
xfail(
"narrow"
), # Batching rule not implemented for `narrow.Tensor` (and view op)
skip(
"nn.functional.fractional_max_pool3d"
), # generator works on cpu, fails on cuda
skip(
"nn.functional.fractional_max_pool2d"
), # generator works on cpu, fails on cuda
xfail("column_stack", ""),
xfail("nn.functional.dropout2d", ""),
xfail("svd_lowrank", ""),
xfail("pca_lowrank", ""),
xfail("clamp"),
# something weird happening with channels_last
xfail("bfloat16"),
xfail("double"),
xfail("float"),
xfail("half"),
xfail("cdouble"),
xfail("cfloat"),
xfail("nn.functional.dropout3d", ""),
xfail("as_strided_scatter", ""),
xfail("sparse.sampled_addmm", ""),
xfail("sparse.mm", "reduce"),
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
# TODO: implement batching rule
xfail("_batch_norm_with_update"),
xfail("as_strided", "partial_views"),
}
),
)
def test_vjpvmap(self, device, dtype, op):
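"""Check vjp(vmap(op)): the vjp of a vmapped op should match a reference
vjp (ref_vjp) computed on the same batched inputs."""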
# NB: there is no vjpvmap_has_batch_rule test because that is almost
# certainly redundant with the vmap_has_batch_rule test in test_vmap.py
# one-off skip
if op.name == "nn.functional.dropout":
self.skipTest("Skipped!")
if not op.supports_autograd:
# If the op doesn't support autograd, vmap(op) won't either
self.skipTest("Skipped! Autograd not supported.")
return
# TODO: test in-place
if is_inplace(op, op.get_op()):
self.skipTest("Skipped! NYI: inplace-testing not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
batch_norm_fns = (
"nn.functional.batch_norm",
"nn.functional.instance_norm",
) # instance norm calls batch norm
is_batch_norm = op.name in batch_norm_fns
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
is_batch_norm_and_training = is_batch_norm and is_batch_norm_training(
op.name, kwargs
)
generator = generate_vmap_inputs(
args, kwargs, is_batch_norm_and_training=is_batch_norm_and_training
)
for batched_args, in_dims, kwargs in generator:
vmapped_op = vmap(op, in_dims)
fn, primals = normalize_op_input_output2(
vmapped_op, batched_args, kwargs, sample.output_process_fn_grad
)
result = fn(*primals)
cotangents = tree_map(lambda x: torch.randn_like(x), result)
_, vjp_fn = vjp(fn, *primals)
result_vjps = vjp_fn(cotangents)
_, vjp_fn = ref_vjp(fn, *primals)
expected_vjps = vjp_fn(cotangents)
self.assertEqual(result_vjps, expected_vjps)
def _compare_jacobians_of_vjp(
self, fn, cotangents_and_primals, argnums=None, atol_rtol=None
):
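"""Cross-check the jacobian of `fn`'s vjp computed with forward mode
(jacfwd) against reverse mode (jacrev). Both jacobians are cast to float
because dtype-changing ops produce jacobians of different dtypes;
atol_rtol optionally loosens the comparison."""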
if argnums is None:
argnums = tuple(range(len(cotangents_and_primals)))
def get_vjp(cotangents, *primals):
_, vjp_fn = vjp(fn, *primals)
return vjp_fn(cotangents)
jacobian_jvp = jacfwd(get_vjp, argnums)(*cotangents_and_primals)
jacobian_vjp = jacrev(get_vjp, argnums)(*cotangents_and_primals)
# For dtype changing operations, the jacobians have different dtype.
jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)
if atol_rtol is not None:
(atol, rtol) = atol_rtol
self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol)
else:
self.assertEqual(jacobian_jvp, jacobian_vjp)
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@skipOps(
"TestOperators",
"test_jvpvjp",
vjp_fail.union(
{
xfail("to_sparse", ""), # NYI
# RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
# this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
xfail("normal", ""),
xfail("cdist", ""), # NYI: forward-AD for _cdist_forward
xfail("cholesky", ""), # NYI: forward-AD for cholesky
xfail(
"nn.functional.embedding_bag", ""
), # NYI: forward-AD for _embedding_bag
xfail(
"nn.functional.grid_sample", ""
), # NYI: forward AD for grid_sampler_2d
xfail("grid_sampler_2d", ""), # NYI: forward AD for grid_sampler_2d
xfail(
"nn.functional.hardsigmoid", ""
), # NYI: forward AD for hardsigmoid_backward
xfail(
"nn.functional.huber_loss", ""
), # NYI: forward AD for huber_loss_backward
xfail("NumpyCubeNotComposableAutogradFunction"), # not composable
xfail("ormqr", ""), # NYI: forward AD for ormqr
xfail(
"nn.functional.multilabel_margin_loss", ""
), # NYI: multilabel_margin_loss_forward
xfail(
"nn.functional.soft_margin_loss", ""
), # NYI: forward-AD for soft_margin_loss_backward
xfail("nn.functional.ctc_loss", ""), # NYI: forward-AD for _ctc_loss
xfail("nn.functional.pdist", ""), # NYI: forward-AD with _pdist_forward
skip("nn.functional.scaled_dot_product_attention"),
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
xfail(
"nn.functional.multi_margin_loss", ""
), # NYI: forward AD with multi_margin_loss
skip(
"linalg.householder_product", "", device_type="cuda"
), # flaky, I'm not sure why
xfail("sparse.sampled_addmm", ""), # Sparse tensors have no strides
xfail(
"_segment_reduce", "offsets"
), # NYI: forward-AD for _segment_reduce
xfail("sparse.mm", "reduce"), # Sparse tensors have no strides
xfail("index_reduce", "prod"), # NYI: forward-AD for index_reduce
xfail("index_reduce", "mean"), # NYI: forward-AD for index_reduce
xfail("index_reduce", "amax"), # NYI: forward-AD for index_reduce
xfail("index_reduce", "amin"), # NYI: forward-AD for index_reduce
xfail(
"_segment_reduce", "lengths"
), # NYI: forward-AD for _segment_reduce
xfail("native_dropout_backward"), # NYI
}
),
)
@opsToleranceOverride(
"TestOperators",
"test_jvpvjp",
(
tol1("masked.prod", {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}),
tol1("masked.cumprod", {torch.float32: tol(atol=1e-04, rtol=5e-04)}),
tol1(
"cumprod",
{torch.float32: tol(atol=1e-03, rtol=5e-04)},
device_type="cuda",
),
tol1(
"linalg.det",
{torch.float32: tol(atol=3e-05, rtol=5e-06)},
device_type="cuda",
),
tol1(
"linalg.vander",
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)},
device_type="cuda",
),
tol1(
"nn.functional.group_norm", {torch.float32: tol(atol=1e-03, rtol=1e-03)}
),
tol2(
"linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-03, rtol=5e-03)}
),
),
)
def test_jvpvjp(self, device, dtype, op):
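"""Check jvp(vjp(op)) against a reference that evaluates ref_vjp under a
forward-AD dual level (fwAD) and unpacks the duals into primal and
tangent outputs."""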
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
# TODO: test in-place
if is_inplace(op, op.get_op()):
self.skipTest("Skipped! NYI: inplace-testing not supported.")
return
for sample in samples:
fn, primals = normalize_op_input_output(op, sample)
result = fn(*primals)
cotangents = tree_map(lambda x: torch.randn_like(x), result)
primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)
def push_vjp(primals, cotangents):
_, vjp_fn = vjp(fn, *primals)
return vjp_fn(cotangents)
result = jvp(
push_vjp, (primals, cotangents), (primals_tangents, cotangents_tangents)
)
self.assertEqual(len(result), 2)
def tree_map2(fn, first, second):
flat_first, spec_first = tree_flatten(first)
flat_second, spec_second = tree_flatten(second)
assert spec_first == spec_second
flat_result = [fn(f, s) for f, s in zip(flat_first, flat_second)]
return tree_unflatten(flat_result, spec_first)
def reference(primals, cotangents, primals_tangents, cotangents_tangents):
with fwAD.dual_level():
primal_duals = tree_map2(fwAD.make_dual, primals, primals_tangents)
_, vjp_fn = ref_vjp(fn, *primal_duals)
cotangent_duals = tree_map2(
fwAD.make_dual, cotangents, cotangents_tangents
)
result = vjp_fn(cotangent_duals)
flat_result, spec = tree_flatten(result)
primals_out, tangents_out = zip(
*[fwAD.unpack_dual(r) for r in flat_result]
)
tangents_out = [
t if t is not None else torch.zeros_like(p)
for p, t in zip(primals_out, tangents_out)
]
expected = (
tree_unflatten(primals_out, spec),
tree_unflatten(tangents_out, spec),
)
return expected
expected = reference(
primals, cotangents, primals_tangents, cotangents_tangents
)
self.assertEqual(result, expected)
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@skipOps(
"TestOperators",
"test_vmapjvpvjp",
vjp_fail.union(
{
# The following operators take too long, hence they are skipped
skip("atleast_1d"),
skip("atleast_2d"),
skip("atleast_3d"),
skip("meshgrid", "list_of_tensors"),
skip("meshgrid", "variadic_tensors"),
skip("broadcast_tensors"),
skip("linalg.lstsq"),
skip("nn.functional.bilinear"),
skip("native_layer_norm"),
skip("ormqr"),
# Not actually a problem
xfail("NumpyCubeNotComposableAutogradFunction"), # not composable
xfail(
"NumpyExpMarkDirtyAutogradFunction"
), # vmap: inplace into a regular tensor
# Potential bugs/errors
xfail("as_strided"), # AssertionError: Tensor-likes are not close!
xfail(
"as_strided", "partial_views"
), # AssertionError: Tensor-likes are not close!
xfail("as_strided_copy"), # AssertionError: Tensor-likes are not close!
xfail(
"as_strided_scatter"
), # AssertionError: Tensor-likes are not close!
xfail("bernoulli"), # calls random op
xfail("bfloat16"), # required rank 4 tensor to use channels_last format
xfail("cdist"), # Forward AD not implemented and no decomposition
xfail("cdouble"), # required rank 4 tensor to use channels_last format
xfail("cfloat"), # required rank 4 tensor to use channels_last format
xfail("chalf"), # required rank 4 tensor to use channels_last format
xfail("cholesky"), # Forward AD not implemented and no decomposition
xfail("ormqr"), # Forward AD not implemented and no decomposition
xfail("double"), # required rank 4 tensor to use channels_last format
xfail("float"), # required rank 4 tensor to use channels_last format
xfail("half"), # required rank 4 tensor to use channels_last format
xfail("index_reduce", "prod"), # NYI: forward AD for index_reduce
xfail("index_reduce", "mean"), # NYI: forward AD for index_reduce
xfail("index_reduce", "amax"), # NYI: forward AD for index_reduce
xfail("index_reduce", "amin"), # NYI: forward AD for index_reduce
xfail(
"mvlgamma", "mvlgamma_p_1"
), # vmap: inplace into a regular tensor
xfail(
"mvlgamma", "mvlgamma_p_3"
), # vmap: inplace into a regular tensor
xfail(
"mvlgamma", "mvlgamma_p_5"
), # vmap: inplace into a regular tensor
xfail("nanquantile"), # Batching rule not implemented for aten::equal
# RuntimeError: Batch norm got a batched tensor as input while the
# running_mean or running_var, which will be updated in place,
# were not batched.
xfail("nn.functional.batch_norm"),
xfail("nn.functional.batch_norm", "without_cudnn"),
xfail(
"nn.functional.ctc_loss"
), # ForwardAD not implemented and no decomposition
xfail("nn.functional.dropout2d"), # calls random op
xfail("nn.functional.dropout3d"), # calls random op
xfail("nn.functional.dropout"), # calls random op
xfail("nn.functional.scaled_dot_product_attention"), # randomness
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
xfail("nn.functional.multi_head_attention_forward"), # randomness
xfail(
"nn.functional.embedding_bag"
), # Forward AD not implemented and no decomposition
xfail("nn.functional.alpha_dropout"), # calls randomn op
xfail(
"nn.functional.feature_alpha_dropout", "with_train"
), # calls random op
xfail("nn.functional.fractional_max_pool2d"), # calls random op
xfail("nn.functional.fractional_max_pool3d"), # calls random op
xfail("nn.functional.gaussian_nll_loss"), # data depenedant flow
xfail(
"nn.functional.grid_sample"
), # Forward AD not implemented and no decomposition
xfail(
"grid_sampler_2d"
), # Forward AD not implemented and no decomposition
xfail(
"nn.functional.hardsigmoid"
), # Forward AD not implemented and no decomposition
xfail(
"nn.functional.hinge_embedding_loss"
), # vmap: inplace into a regular tensor
xfail(
"nn.functional.huber_loss"
), # Forward AD not implemented and no decomposition
# RuntimeError: Batch norm got a batched tensor as input while the
# running_mean or running_var, which will be updated in place,
# were not batched.
xfail("nn.functional.instance_norm"),
# NYI: Tensor.clone(memory_format) inside vmap is only supported with
# memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast)
xfail("nn.functional.max_unpool2d"),
xfail("nn.functional.max_unpool2d", "grad"),
xfail(
"nn.functional.multi_margin_loss"
), # Forward AD not implemented and no decomposition
xfail(
"nn.functional.multilabel_margin_loss"
), # Forward AD not implemented and no decomposition
xfail(
"nn.functional.pdist"
), # Forward AD not implemented and no decomposition
xfail(
"nn.functional.rrelu"
), # vmap: we do not yet support aten::rrelu_with_noise.
xfail(
"nn.functional.soft_margin_loss"
), # Forward AD not implemented and no decomposition
xfail("normal"), # calls random op
xfail("normal", "number_mean"), # calls random op
xfail("pca_lowrank"), # calls random op
xfail("quantile"), # Batching rule not implemented for aten::equal
xfail(
"scatter_reduce", "prod"
), # Forward AD not implemented and no decomposition
xfail(
"_segment_reduce", "lengths"
), # Forward AD not implemented and no decomposition
xfail(
"_segment_reduce", "offsets"
), # Forward AD not implemented and no decomposition
xfail(
"sparse.sampled_addmm"
), # RuntimeError: Sparse CSR tensors do not have strides
xfail(
"sparse.mm", "reduce"
), # RuntimeError: Sparse CSR tensors do not have strides
xfail("svd_lowrank"), # calls random op
xfail(
"to"
), # RuntimeError: required rank 4 tensor to use channels_last format
xfail("to_sparse"), # Forward AD not implemented and no decomposition
xfail(
"view_as_complex"
), # RuntimeError: Tensor must have a last dimension with stride 1
# RuntimeError: Batch norm got a batched tensor as
# input while the running_mean or running_var, which will be updated in
# place, were not batched.
xfail("native_batch_norm"),
xfail("_native_batch_norm_legit"),
# TODO: implement batching rule
xfail("_batch_norm_with_update"),
xfail("native_dropout_backward"),
}
),
)
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@opsToleranceOverride(
"TestOperators",
"test_vmapjvpvjp",
(
tol1("linalg.svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1(
"linalg.householder_product",
{torch.float32: tol(atol=5e-03, rtol=5e-03)},
),
tol1("linalg.multi_dot", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol2(
"linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-04, rtol=5e-04)}
),
tol1(
"nn.functional.conv_transpose2d",
{torch.float32: tol(atol=5e-04, rtol=5e-04)},
),
tol1("svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
tol1("matrix_exp", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
),
)
def test_vmapjvpvjp(self, device, dtype, op):
# Since we test `jvpvjp` separately, here we only check
# that vmap of `jvpvjp` is correct.
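# Mechanically: flatten (primals, cotangents) and their tangents, wrap
# jvp(push_vjp) as a flat function, and compare the vmapped outputs from
# get_fallback_and_vmap_exhaustive against the unbatched loop outputs.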
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return
samples = op.sample_inputs(device, dtype, requires_grad=True)
# TODO: test in-place
if is_inplace(op, op.get_op()):
self.skipTest("Skipped! NYI: inplace-testing not supported.")
return
for sample in samples:
fn, primals = normalize_op_input_output(op, sample)
result = fn(*primals)
cotangents = tree_map(lambda x: torch.randn_like(x), result)
primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)
def push_vjp(primals, cotangents):
_, vjp_fn = vjp(fn, *primals)
return vjp_fn(cotangents)
args, spec = tree_flatten(
((primals, cotangents), (primals_tangents, cotangents_tangents))
)
def jvp_of_vjp(*args):
(primals, tangents) = tree_unflatten(args, spec)
primals_out, tangents_out = jvp(push_vjp, primals, tangents)
flat_primals_out = pytree.tree_leaves(primals_out)
flat_tangents_out = pytree.tree_leaves(tangents_out)
return tuple(flat_primals_out + flat_tangents_out)
is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs)
generator = get_fallback_and_vmap_exhaustive(
jvp_of_vjp,
args,
{},
is_batch_norm_and_training=is_batch_norm_and_training,
)
for loop_out, batched_out in generator:
self.assertEqual(loop_out, batched_out)
def _make_extremal_inputs(self, shape, device):
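"""Return (None,) when shape is None; otherwise a tuple of tensors filled
with -1000, 0, and 1000 to stress the numerics of backward formulas."""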
if shape is None:
return (None,)
return (
torch.full(shape, -1000.0, device=device),
torch.zeros(shape, device=device),
torch.full(shape, 1000.0, device=device),
)
def _arg_and_kwarg_options(self, args_options, kwargs_options):
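"""Cartesian product of the per-argument option tuples with each kwargs
dict."""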
return itertools.product(*args_options, kwargs_options)
def test_extremal_numerics_nll_loss(self, device):
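"""Compare forward- and reverse-mode jacobians of nll_loss's vjp on
extreme-valued inputs across several shapes, reductions, and optional
class weights."""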
N, C = 3, 4
d1, d2, d3 = 5, 6, 7
shapes = (
((N, C), (N,), (C,)),
((N, C), (N,), None),
((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
((N, C, d1, d2, d3), (N, d1, d2, d3), None),
)
kwargs_options = (
{"ignore_index": 0, "reduction": "mean"},
{"reduction": "sum"},
{"reduction": "none"},
{},
)
for input_shape, target_shape, weight_shape in shapes:
input_options = self._make_extremal_inputs(input_shape, device)
for input, kwargs in self._arg_and_kwarg_options(
(input_options,), kwargs_options
):
if weight_shape is None:
weight = None
else:
weight = torch.randn(weight_shape, device=device)
target = torch.randint(0, C, target_shape, device=device)
target[
0
] = 1 # since we're ignoring index 0, at least one element must be non-zero
fn = functools.partial(
torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs
)
result = fn(input)
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(fn, (cotangents, input))
def test_extremal_numerics_l1_loss(self, device):
N, C, H, W = 3, 4, 5, 6
shapes = ((N, C), (N, C, H), (N, C, H, W))
kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {})
for shape in shapes:
input_options = self._make_extremal_inputs(shape, device)
target_options = self._make_extremal_inputs(shape, device)
for input, target, kwargs in self._arg_and_kwarg_options(
(input_options, target_options), kwargs_options
):
result = torch.nn.functional.l1_loss(input, target)
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(
torch.nn.functional.l1_loss, (cotangents, input, target)
)
def test_extremal_numerics_mse_loss(self, device):
N, C, H, W = 3, 4, 5, 6
shapes = ((N, C), (N, C, H), (N, C, H, W))
kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {})
for shape in shapes:
input_options = self._make_extremal_inputs(shape, device)
target_options = self._make_extremal_inputs(shape, device)
for input, target, kwargs in self._arg_and_kwarg_options(
(input_options, target_options), kwargs_options
):
result = torch.nn.functional.mse_loss(input, target)
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(
torch.nn.functional.mse_loss, (cotangents, input, target)
)
def test_extremal_numerics_softmax(self, device):
N, C, H, W = 3, 4, 5, 6
shapes = ((N, C), (N, C, H), (N, C, H, W))
kwargs_options = ({"dim": 1}, {})
for shape in shapes:
input_options = self._make_extremal_inputs(shape, device)
for input, kwargs in self._arg_and_kwarg_options(
(input_options,), kwargs_options
):
result = torch.nn.functional.softmax(input)
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(
torch.nn.functional.softmax, (cotangents, input)
)
def test_extremal_numerics_log_softmax(self, device):
N, C, H, W = 3, 4, 5, 6
shapes = ((N, C), (N, C, H), (N, C, H, W))
kwargs_options = ({"dim": 1}, {})
for shape in shapes:
input_options = self._make_extremal_inputs(shape, device)
for input, kwargs in self._arg_and_kwarg_options(
(input_options,), kwargs_options
):
result = torch.nn.functional.log_softmax(input)
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(
torch.nn.functional.log_softmax, (cotangents, input)
)
def test_extremal_numerics_cross_entropy(self, device):
N, C = 3, 4
d1, d2, d3 = 5, 6, 7
shapes = (
((N, C), (N,), (C,)),
((N, C), (N,), None),
((N, C), (N, C), (C,)),
((N, C), (N, C), None),
((C,), (), (C,)),
((C,), (), None),
((C,), (C,), (C,)),
((C,), (C,), None),
((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
((N, C, d1, d2, d3), (N, d1, d2, d3), None),
((N, C, d1, d2, d3), (N, C, d1, d2, d3), (C,)),
((N, C, d1, d2, d3), (N, C, d1, d2, d3), None),
)
for input_shape, target_shape, weight_shape in shapes:
input_options = self._make_extremal_inputs(input_shape, device)
kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}]
if input_shape != target_shape:
kwargs_options.append({"ignore_index": 0, "reduction": "mean"})
for input, kwargs in self._arg_and_kwarg_options(
(input_options,), kwargs_options
):
if weight_shape is None:
weight = None
else:
weight = torch.randn(weight_shape, device=device)
if input_shape == target_shape:
target = torch.rand(target_shape, device=device)
elif len(target_shape) == 0:
target = torch.tensor(
1, device=device
) # must be non-zero since ignore_index may be 0
else:
target = torch.randint(0, C, target_shape, device=device)
fn = functools.partial(
torch.nn.functional.cross_entropy,
target=target,
weight=weight,
**kwargs,
)
result = fn(input)
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(
fn, (cotangents, input), atol_rtol=(1e-4, 1e-5)
)
def test_extremal_numerics_binary_cross_entropy(self, device):
N, C, H, W = 3, 4, 5, 6
shapes = ((N, C), (N, C, H), (N, C, H, W))
for shape in shapes:
weight_options = self._make_extremal_inputs(shape, device)
kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}]
for weight, kwargs in self._arg_and_kwarg_options(
(weight_options,), kwargs_options
):
input = torch.rand(shape, device=device)
target = torch.rand(shape, device=device)
fn = functools.partial(
torch.nn.functional.binary_cross_entropy,
target=target,
weight=weight,
**kwargs,
)
result = fn(input)
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(
fn, (cotangents, input), atol_rtol=(1e-4, 2e-5)
)
def test_extremal_numerics_layer_norm(self, device):
N, C, H, W = 3, 4, 5, 6
shapes = ((N, C), (N, C, H), (N, C, H, W))
for shape in shapes:
input_options = self._make_extremal_inputs(shape, device)
normalized_shape = shape[1:]
weight_options = self._make_extremal_inputs(normalized_shape, device)
bias_options = self._make_extremal_inputs(normalized_shape, device)
for input, bias, weight in self._arg_and_kwarg_options(
(input_options, bias_options, weight_options), ()
):
def fn(input, weight, bias):
return torch.nn.functional.layer_norm(
input, normalized_shape, weight=weight, bias=bias
)
result = fn(input, weight, bias)
cotangents = torch.randn_like(result, device=device)
self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(
op_db + additional_op_db + autograd_function_db,
allowed_dtypes=(torch.float32, torch.double),
)
@skipOps(
"TestOperators",
"test_vmap_autograd_grad",
{
# The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0
xfail("masked_select"),
xfail("nn.functional.max_unpool2d", "grad"), # contiguous call
xfail("nn.functional.max_unpool2d"), # contiguous call
xfail("to_sparse"), # dispatch key issue
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
# https://github.com/pytorch/pytorch/issues/96560#issuecomment-2151063723
# ** minor accuracy issue for float32 on ROCm
decorate("xlogy", decorator=skipIfRocm),
# numerical inconsistencies, look like bugs
skip(
"matrix_exp", dtypes=(torch.float32,), device_type="cuda"
), # fails on linux, passes on windows
skip(
"ldexp", dtypes=(torch.float32,), device_type="cpu"
), # fails on all but mac
skip("__rmatmul__"), # flaky needs investigation
skip("matmul"), # flaky needs investigation
skip("nn.functional.conv_transpose3d"), # flaky needs investigation
skip("nn.functional.conv_transpose2d"), # flaky needs investigation
skip("nn.functional.conv_transpose1d"), # flaky needs investigation
skip(
"nn.functional.layer_norm", dtypes=(torch.float32,), device_type="cpu"
), # fails on windows
skip(
"linalg.lu_factor", dtypes=(torch.float32,), device_type="cuda"
), # fails on all but windows
skip(
"linalg.lu_factor_ex", dtypes=(torch.float32,), device_type="cuda"
), # fails on all but windows
skip("linalg.multi_dot", "", device_type="cpu"),
skip("sparse.sampled_addmm", ""),
skip("sparse.mm", "reduce"),
skip("native_layer_norm", "", device_type="cpu"),
# RuntimeError: Expected contiguous tensor, but got
# non-contiguous tensor for argument #2 'grad_output'
decorate(
"_batch_norm_with_update",
decorator=expectedFailureIf(TEST_WITH_ROCM),
device_type="cuda",
),
},
)
@opsToleranceOverride(
"TestOperators",
"test_vmap_autograd_grad",
(
tol1(
"ldexp",
{torch.float32: tol(atol=3e-04, rtol=1.6e-06)},
device_type="cuda",
),
tol1(
"linalg.householder_product",
{torch.float32: tol(atol=5e-04, rtol=9e-03)},
device_type="cuda",
),
tol1(
"linalg.householder_product",
{torch.float32: tol(atol=6e-03, rtol=1e-03)},
device_type="cpu",
),
tol1(
"linalg.multi_dot",
{torch.float32: tol(atol=2e-04, rtol=1e-04)},
device_type="cuda",
),
tol2(
"linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-06, rtol=5e-06)}
),
tol1("nn.functional.conv3d", {torch.float32: tol(atol=5e-04, rtol=9e-03)}),
tol1(
"nn.functional.conv2d",
{torch.float32: tol(atol=3e-05, rtol=5e-06)},
device_type="cuda",
),
tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
),
)
def test_vmap_autograd_grad(self, device, dtype, op):
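"""vmap a torch.autograd.grad call over batched cotangents and compare
the batched gradients against a per-example loop, keeping only
differentiable outputs/inputs and dropping None gradients."""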
def is_differentiable(inp):
return isinstance(inp, Tensor) and (
inp.grad_fn is not None or inp.requires_grad
)
def get_flat_differentiable(tree):
flattened = pytree.tree_leaves(tree)
return tuple(i for i in flattened if is_differentiable(i))
def get_differentiable_linked(list1, list2):
paired_list = zip(list1, list2)
paired_list = tuple(
(first, second)
for (first, second) in paired_list
if is_differentiable(first)
)
return zip(*paired_list)
def filter_none(out):
flattened = pytree.tree_leaves(out)
return tuple(o for o in flattened if o is not None)
if not op.supports_autograd:
self.skipTest("Skipped! Autograd not supported.")
return
sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
for sample_input in sample_inputs:
fn, primals = normalize_op_input_output(op, sample_input)
out = fn(*primals)
cotangents = tree_map(torch.randn_like, out)
def compute_grad(cotangents):
out_flattened = out
cotangents_flattened = cotangents
if not isinstance(out_flattened, torch.Tensor):
out_flattened = pytree.tree_leaves(out)
cotangents_flattened = pytree.tree_leaves(cotangents)
out_flattened, cotangents_flattened = get_differentiable_linked(
out_flattened, cotangents_flattened
)
return filter_none(
torch.autograd.grad(
out_flattened,
get_flat_differentiable(primals),
cotangents_flattened,
retain_graph=True,
allow_unused=True,
)
)
is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs)
generator = get_fallback_and_vmap_exhaustive(
compute_grad,
(cotangents,),
{},
is_batch_norm_and_training=is_batch_norm_and_training,
)
for loop_out, batched_out in generator:
self.assertEqual(loop_out, batched_out)
def test_vmapvmapjvp_linalg_solve(self):
ops = [op for op in op_db if op.name == "linalg.solve"]
assert len(ops) > 0
# This specializes a lot of code from the get_fallback_and_vmap_exhaustive
# test. If we need this more generally, it could be refactored.
B0 = 2
B1 = 3
# We want to check the case where A is seen as contiguous by jvp but becomes
# non-contiguous during the vmap calls because vmap expands it. This happens
# at both levels of vmap.
A = torch.randn(4, 4)
k = torch.randn(4, 5, B1, B0)
fn, args = get_jvp_variant_primals_tangents(
torch.linalg.solve, SampleInput(A, args=(k,))
)
in_dims_all = (None, -1, None, -1)
batched_out = vmap(vmap(fn, in_dims=in_dims_all), in_dims=in_dims_all)(*args)
loop_out = loop2(fn, in_dims_all, in_dims_all, 0, 0, B0, B1, *args)
self.assertEqual(loop_out, batched_out)
@ops(
filter(lambda op: op.name in aliasing_ops, op_db + additional_op_db),
allowed_dtypes=(torch.float,),
)
@parametrize("grad_op", ["jvp", "vjp"])
def test_view_then_inplace(self, device, dtype, op, grad_op):
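"""Copying into a view produced by an aliasing op must raise
"During a grad ... attempted to call in-place operation" under both jvp
and vjp."""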
for sample_input in op.sample_inputs(device, dtype):
def f(x):
op(sample_input.input, *sample_input.args, **sample_input.kwargs).copy_(
x
)
return x
without_grad = op(
sample_input.input, *sample_input.args, **sample_input.kwargs
)
if grad_op == "jvp":
with self.assertRaisesRegex(
RuntimeError,
"During a grad .* attempted to call in-place operation",
):
jvp(
f,
(torch.randn_like(without_grad),),
(torch.randn_like(without_grad),),
)
else:
assert grad_op == "vjp"
with self.assertRaisesRegex(
RuntimeError,
"During a grad .* attempted to call in-place operation",
):
vjp(f, torch.randn_like(without_grad))
@ops(
filter(
lambda op: op.name in aliasing_ops_list_return, op_db + additional_op_db
),
allowed_dtypes=(torch.float,),
)
@parametrize("grad_op", ["jvp", "vjp"])
def test_view_then_inplace_list_return(self, device, dtype, op, grad_op):
for sample_input in op.sample_inputs(device, dtype):
def f(x):
op(sample_input.input, *sample_input.args, **sample_input.kwargs)[
0
].copy_(x)
return x
without_grad = op(
sample_input.input, *sample_input.args, **sample_input.kwargs
)[0]
with self.assertRaisesRegex(
RuntimeError, "During a grad .* attempted to call in-place operation"
):
if grad_op == "jvp":
jvp(
f,
(torch.randn_like(without_grad),),
(torch.randn_like(without_grad),),
)
else:
assert grad_op == "vjp"
vjp(f, torch.randn_like(without_grad))
@parametrize("grad_op", ["jvp", "vjp"])
def test_view_then_inplace_special(self, grad_op):
# Some __getitem__ cases use at::index, which doesn't alias, so this tests a subset of indexing ops that do alias (i.e. return views)
ops = [
lambda x: x[0],
lambda x: x[0, 0, 0],
lambda x: x[:1],
lambda x: x[:, :1],
lambda x: x[:, :1, :],
]
for op in ops:
def f(x):
op(captured).copy_(x)
return x
captured = torch.randn(4, 3, 3)
without_grad = op(captured)
if grad_op == "jvp":
with self.assertRaisesRegex(
RuntimeError,
"During a grad .* attempted to call in-place operation",
):
jvp(
f,
(torch.randn_like(without_grad),),
(torch.randn_like(without_grad),),
)
else:
assert grad_op == "vjp"
with self.assertRaisesRegex(
RuntimeError,
"During a grad .* attempted to call in-place operation",
):
vjp(f, torch.randn_like(without_grad))
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
# NOTE: [three-transform testing]
# We only test the autograd_function_db tests here.
#
# Usually testing the composition of two transforms is sufficient to convince
# ourselves that an operator is correctly implemented. For the following cases,
# we want to be extra sure, so we send those through some three-transform tests:
# - autograd.Function. The mechanism is via PyDispatcher/HigherOrderOperator, not the
# regular PyTorch dispatcher, so it's good to exercise more caution.
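# For example, test_vmapvjpvmap below checks vmap(vjp(vmap(op))) against a
# loop-based map(vjp(map(op))) reference.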
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
@skipOps(
"TestOperators",
"test_vmapvjpvmap",
{
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
},
)
def test_vmapvjpvmap(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
B = 2
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
for batched_args, in_dims, kwargs in generator:
inner_vmapped_op = vmap(op, in_dims)
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
inner_vmapped_fn, primals = normalize_op_input_output2(
inner_vmapped_op,
batched_args,
kwargs,
sample.output_process_fn_grad,
)
inner_mapped_fn, _ = normalize_op_input_output2(
inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
)
result = inner_mapped_fn(*primals)
cotangents = tree_map(lambda x: torch.rand_like(x), result)
def apply_vjp(fn):
def inner(primals, cotangents):
_, vjp_fn = vjp(fn, *primals)
return vjp_fn(cotangents)
return inner
vjpvmap_fn = apply_vjp(inner_vmapped_fn)
vjpmap_fn = apply_vjp(inner_mapped_fn)
batched_args = (primals, cotangents)
generator = generate_vmap_inputs(batched_args, {})
for batched_args, in_dims, _ in generator:
# strategy: compare vmap(vjp(vmap(op))) vs map(vjp(map(op)))
vmapvjpvmap_fn = vmap(vjpvmap_fn, in_dims)
mapvjpmap_fn = functools.partial(loop, vjpmap_fn, in_dims, 0, B)
result = vmapvjpvmap_fn(*batched_args)
expected = mapvjpmap_fn(*batched_args)
self.assertEqual(result, expected)
# See NOTE: [three-transform testing]
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
@skipOps(
"TestOperators",
"test_vjpvmapvmap",
{
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
},
)
def test_vjpvmapvmap(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
B = 2
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
for batched_args, inner_in_dims, kwargs in generator:
inner_vmapped_op = vmap(op, inner_in_dims)
inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B)
generator = generate_vmap_inputs(batched_args, kwargs)
for batched_args, in_dims, kwargs in generator:
# strategy: compare vjp(vmap(vmap(op))) vs vjp(map(map(op)))
vmapped_op = vmap(inner_vmapped_op, in_dims)
mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B)
vmapped_fn, primals = normalize_op_input_output2(
vmapped_op, batched_args, kwargs, sample.output_process_fn_grad
)
mapped_fn, _ = normalize_op_input_output2(
mapped_op, batched_args, kwargs, sample.output_process_fn_grad
)
result = mapped_fn(*primals)
cotangents = tree_map(lambda x: torch.rand_like(x), result)
_, vjp_fn = vjp(mapped_fn, *primals)
expected_vjps = vjp_fn(cotangents)
_, vjp_fn = vjp(vmapped_fn, *primals)
result_vjps = vjp_fn(cotangents)
self.assertEqual(result_vjps, expected_vjps)
# See NOTE: [three-transform testing]
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
@skipOps(
"TestOperators",
"test_vjpvjpvmap",
{
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
},
)
def test_vjpvjpvmap(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
B = 2
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
for batched_args, in_dims, kwargs in generator:
inner_vmapped_op = vmap(op, in_dims)
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
vjpmap_fn, args = get_vjpfull_variant2(
inner_mapped_op, batched_args, kwargs
)
vjpvmap_fn, _ = get_vjpfull_variant2(
inner_vmapped_op, batched_args, kwargs
)
vjpvjpvmap_fn, new_args = get_vjpfull_variant2(vjpvmap_fn, args, {})
vjpvjpmap_fn, _ = get_vjpfull_variant2(vjpmap_fn, args, {})
expected = vjpvjpmap_fn(*new_args)
result = vjpvjpvmap_fn(*new_args)
self.assertEqual(result, expected)
# We're generally convinced that jvp x vmap works (vmap turns an operator
# into another operator and we test jvp support for operators). So
# we only test it on the things we're not sure about:
# - the autograd.Function <> functorch interaction
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
@skipOps(
"TestOperators",
"test_jvpvmap",
{
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
},
)
def test_jvpvmap(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
B = 2
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
for batched_args, in_dims, kwargs in generator:
inner_vmapped_op = vmap(op, in_dims)
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
jvpvmap_op, primals = get_jvp_variant_primals_tangents2(
inner_vmapped_op,
batched_args,
kwargs,
sample.output_process_fn_grad,
)
jvpmap_op, _ = get_jvp_variant_primals_tangents2(
inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
)
expected = jvpmap_op(*primals)
result = jvpvmap_op(*primals)
self.assertEqual(result, expected)
# See NOTE: [three-transform testing]
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
@skipOps(
"TestOperators",
"test_jvpvmapvmap",
{
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
},
)
def test_jvpvmapvmap(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
B = 2
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
for batched_args, inner_in_dims, kwargs in generator:
inner_vmapped_op = vmap(op, inner_in_dims)
inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B)
generator = generate_vmap_inputs(batched_args, kwargs)
for batched_args, in_dims, kwargs in generator:
# strategy: compare jvp(vmap(vmap(op))) vs jvp(map(map(op)))
vmapped_op = vmap(inner_vmapped_op, in_dims)
mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B)
jvpvmapvmap_fn, primals = get_jvp_variant_primals_tangents2(
vmapped_op, batched_args, kwargs, sample.output_process_fn_grad
)
jvpmapmap_fn, _ = get_jvp_variant_primals_tangents2(
mapped_op, batched_args, kwargs, sample.output_process_fn_grad
)
expected = jvpmapmap_fn(*primals)
result = jvpvmapvmap_fn(*primals)
self.assertEqual(result, expected)
# See NOTE: [three-transform testing]
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
@skipOps(
"TestOperators",
"test_vmapjvpvmap",
{
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
},
)
def test_vmapjvpvmap(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
B = 2
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
for batched_args, in_dims, kwargs in generator:
inner_vmapped_op = vmap(op, in_dims)
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
jvpvmap_fn, primals = get_jvp_variant_primals_tangents2(
inner_vmapped_op,
batched_args,
kwargs,
sample.output_process_fn_grad,
)
jvpmap_fn, _ = get_jvp_variant_primals_tangents2(
inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
)
generator = generate_vmap_inputs(primals, {})
for batched_args, in_dims, _ in generator:
# strategy: compare vmap(jvp(vmap(op))) vs map(jvp(map(op)))
vmapjvpvmap_fn = vmap(jvpvmap_fn, in_dims)
mapjvpmap_fn = functools.partial(loop, jvpmap_fn, in_dims, 0, B)
result = vmapjvpvmap_fn(*batched_args)
expected = mapjvpmap_fn(*batched_args)
self.assertEqual(result, expected)
# See NOTE: [three-transform testing]
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
@skipOps(
"TestOperators",
"test_jvpjvpvmap",
{
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
},
)
def test_jvpjvpvmap(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
B = 2
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
for batched_args, in_dims, kwargs in generator:
inner_vmapped_op = vmap(op, in_dims)
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
jvpmap_fn, args = get_jvp_variant_primals_tangents2(
inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
)
jvpvmap_fn, _ = get_jvp_variant_primals_tangents2(
inner_vmapped_op,
batched_args,
kwargs,
sample.output_process_fn_grad,
)
jvpjvpvmap_fn, new_args = get_jvp_variant_primals_tangents2(
jvpvmap_fn, args, {}
)
jvpjvpmap_fn, _ = get_jvp_variant_primals_tangents2(jvpmap_fn, args, {})
expected = jvpjvpmap_fn(*new_args)
result = jvpjvpvmap_fn(*new_args)
self.assertEqual(result, expected)
# See NOTE: [three-transform testing]
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
@skipOps(
"TestOperators",
"test_jvpvjpvmap",
{
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
},
)
def test_jvpvjpvmap(self, device, dtype, op):
samples = op.sample_inputs(device, dtype, requires_grad=True)
B = 2
for sample in samples:
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
for batched_args, in_dims, kwargs in generator:
inner_vmapped_op = vmap(op, in_dims)
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
vjpmap_fn, args = get_vjpfull_variant2(
inner_mapped_op, batched_args, kwargs
)
vjpvmap_fn, _ = get_vjpfull_variant2(
inner_vmapped_op, batched_args, kwargs
)
jvpvjpvmap_fn, new_args = get_jvp_variant_primals_tangents2(
vjpvmap_fn, args, {}
)
jvpvjpmap_fn, _ = get_jvp_variant_primals_tangents2(vjpmap_fn, args, {})
expected = jvpvjpmap_fn(*new_args)
result = jvpvjpvmap_fn(*new_args)
self.assertEqual(result, expected)
def test_data_write_errors_under_transform(self, device):
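"""Assigning to `.data` inside a function being transformed by grad, vjp,
or jvp should raise instead of silently mutating the tensor."""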
t = torch.randn(3, 3, device=device)
def fn(t):
t.data = torch.randn(3, 3)
return t.sum()
msg = "mutating directly with `.data` inside functorch transform"
with self.assertRaisesRegex(RuntimeError, msg):
grad(fn)(t)
with self.assertRaisesRegex(RuntimeError, msg):
vjp(fn, t)
with self.assertRaisesRegex(RuntimeError, msg):
jvp(fn, (t,), (torch.randn_like(t),))
def test_tensor_with_scalar_list(self, device):
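"""torch.tensor([x]) built from a 0-dim tensor should have the same output
and vjp as torch.tensor(x).view(1)."""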
x = torch.randn((), device=device)
def func_list_of_scalar(x):
return torch.tensor([x], device=device)
def func(x):
return torch.tensor(x, device=device).view(1)
actual_o, actual_fn = vjp(func_list_of_scalar, x)
expected_o, expected_fn = vjp(func, x)
self.assertEqual(actual_o, expected_o)
self.assertEqual(
expected_fn(torch.ones_like(expected_o)),
actual_fn(torch.ones_like(actual_o)),
)
only_for = ("cpu", "cuda")
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
if __name__ == "__main__":
run_tests()