mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
# UPDATE: This is take 3 of https://github.com/pytorch/pytorch/pull/131863, which was landed via co-dev but did not apply correctly.

# Summary
Changes the stance of SDPA on what to do for fully masked out rows.

## Current Behavior
Several PyTorch users have expressed frustration over this issue:

- https://github.com/pytorch/pytorch/issues/41508
- https://github.com/pytorch/pytorch/issues/103749
- https://github.com/pytorch/pytorch/issues/103963

These are significant issues with extensive discussion but no satisfactory resolution. The PyTorch team's consensus, as stated in https://github.com/pytorch/pytorch/issues/24816#issuecomment-524415617, can be paraphrased as follows: when passing in fully masked out rows, attention becomes ambiguous. We have two main options:

1. Uniformly attend to all values:

   ```python
   scores[masked_out_rows] = 1 / len(row)
   out[masked_out_rows] = 1 / len(row) * value
   ```

2. Decide that attention between no queries (masked) and no keys (masked) is meaningless:

   ```python
   output[fully_masked_rows] = NaN
   ```

We went with option 2, partially because it was easier to implement, but also because people argued that users can slice the output to remove the NaNs:

```python
>>> fill_value = -float("inf")
>>> row0 = torch.randn(4)
>>> row1 = torch.tensor([fill_value for _ in range(4)])
>>> matrix = torch.stack([row0, row1]).requires_grad_(True)
>>> out = torch.softmax(matrix, 1)
>>> out = out[0]
>>> print(out)
tensor([0.5377, 0.2729, 0.0692, 0.1201])
```

Cool, problem solved. But what happens when you call backward?

```python
>>> out.backward(torch.ones_like(out))
>>> print(matrix.grad)
tensor([[3.0957e-08, 1.4157e-08, 7.7802e-10, 1.3713e-08],
        [       nan,        nan,        nan,        nan]])
```

Those pesky NaNs are back!

## Why do we see NaNs today?
The core of the problem revolves around the use of softmax in SDPA:

```python
>>> row = torch.tensor([-float("inf") for _ in range(4)])
>>> torch.softmax(row, 0)
tensor([nan, nan, nan, nan])
```

## Quick Aside: Masking in Attention
Attention itself doesn't have a concept of masking. The `sdpa` function has an argument called `attn_mask`, which would be more accurately named `attn_bias`. This is because we don't actually "mask" entries when computing attention. Instead, due to implementation details ([performance](https://github.com/pytorch/pytorch/issues/25110#issuecomment-524519087)), we add a value to the masked-out query/key pairs. We use a large negative number (typically -inf) to decrease the attention weight, since softmax assigns more weight to larger values.

## Alternative Approaches
If we use a very large negative number instead of -inf:

```python
>>> row = torch.tensor([-1e6 for _ in range(4)])
>>> torch.softmax(row, 0)
tensor([0.2500, 0.2500, 0.2500, 0.2500])
```

And if users always remembered to "slice" out their outputs, i.e.:

```python
>>> fill_value = -1e6
>>> ...
>>> out.backward(torch.ones_like(out))
>>> print(matrix.grad)
tensor([[-0.0563, -0.0564,  0.1613, -0.0486],
        [ 0.0000,  0.0000,  0.0000,  0.0000]])
```

this would bring us back to a better state.

## A Third Option
We don't necessarily need to alter the behavior of softmax for -inf or very large negative numbers. The fundamental goal is to exclude certain query/key pairs from attention, regardless of the underlying implementation. This PR implements the new semantic for masking with attention in fully masked-out rows:

```python
out[masked_out_rows] = 0
```

**Important Note**: This idea isn't entirely new. The [MaskedTensor](https://pytorch.org/tutorials/prototype/maskedtensor_overview#safe-softmax) prototype, a tensor subclass, was designed to handle such cases. However, it remains a prototype feature and hasn't gained widespread adoption.

## Details
This PR stack does 3 things:

1. Adds a PRIVATE `_safe_softmax` op
2. Updates the semantic for the flash_cpu fused kernel
3. Updates the semantic for the efficient_cuda fused kernel

`_safe_softmax` is not supposed to be used generically and is only meant to be used within the context of SDPA. Because of this, instead of decomposing softmax and checking for -inf rows, we "cheat" and use `nan_to_num`.

Why do I think this is okay? (Please find a counterpoint if available.) There are multiple ways NaNs can emerge. For the fully masked-out rows case, `nan_to_num` works. But what if there were other NaNs, wouldn't this silently remove them? The only case where this can happen is if the input itself had a NaN or an Inf. For example:

```python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = torch.finfo(torch.float16).max
print(a.softmax(-1))
```

will return `tensor([0., 1., 0., 0.], dtype=torch.float16)`, whereas

```python
a = torch.ones([4], requires_grad=False, dtype=torch.float16)
a[1] = float("inf")
a.softmax(-1)
```

returns `tensor([nan, nan, nan, nan], dtype=torch.float16)`.

If we don't want to even allow for the possibility of "inf" or "NaN" attention scores being converted to 0, we could implement it something like this:

```python
max = torch.max(a, dim=-1, keepdim=True)
exp = torch.exp(a - max.values)
denom = torch.sum(exp, dim=-1, keepdim=True)
softmax = exp / denom
softmax = torch.where(max.values == float('-inf'), 0.0, softmax)
```

However, we would be paying for this in math performance.

## Why Now
I think one point that has substantially changed where PyTorch should land on this argument is the fact that we now have fused implementations for SDPA, and these fused implementations allow us to easily and performantly support this new semantic.

Differential Revision: [D61418679](https://our.internmc.facebook.com/intern/diff/D61418679)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133882
Approved by: https://github.com/soulitzer
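To make the proposed semantic concrete, the following is a minimal, self-contained sketch of masking with a "safe" softmax in plain PyTorch. It only illustrates the intended forward behavior described above; it is not the fused-kernel code this PR modifies, and the helper names `safe_softmax` and `masked_attention_reference` are invented for this example.

```python
import torch


def safe_softmax(scores, dim=-1):
    # Decomposed softmax, mirroring the snippet in the PR description:
    # rows whose max is -inf (i.e. fully masked rows) return 0 instead of NaN.
    max_vals = scores.amax(dim=dim, keepdim=True)
    exp = torch.exp(scores - max_vals)
    out = exp / exp.sum(dim=dim, keepdim=True)
    return torch.where(max_vals == float("-inf"), torch.zeros_like(out), out)


def masked_attention_reference(q, k, v, attn_mask):
    # attn_mask is boolean: True = attend, False = mask out.
    scores = (q @ k.transpose(-2, -1)) / (q.size(-1) ** 0.5)
    scores = scores.masked_fill(~attn_mask, float("-inf"))
    return safe_softmax(scores, dim=-1) @ v


q = torch.randn(1, 2, 4)
k = torch.randn(1, 3, 4)
v = torch.randn(1, 3, 4)
# Second query row attends to nothing (fully masked).
mask = torch.tensor([[[True, True, False],
                      [False, False, False]]])
out = masked_attention_reference(q, k, v, mask)
print(out[0, 1])  # tensor([0., 0., 0., 0.]) under the new semantic, not NaN
```

The PR itself implements this semantic via the private `_safe_softmax` op and the updated flash_cpu and efficient_cuda fused kernels; the sketch above only shows the intended output for fully masked rows.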
# Owner(s): ["module: functorch"]

# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import functools
import itertools
import unittest

from common_utils import (
    check_vmap_fallback,
    decorate,
    expectedFailureIf,
    generate_vmap_inputs,
    get_fallback_and_vmap_exhaustive,
    is_batch_norm_training,
    is_valid_inplace_sample_input,
    loop,
    loop2,
    opsToleranceOverride,
    skip,
    skipOps,
    tol1,
    tol2,
    xfail,
)
from functorch_additional_op_db import additional_op_db

import torch
import torch.autograd.forward_ad as fwAD
from functorch import grad, jacfwd, jacrev, vjp, vmap
from torch import Tensor
from torch._functorch.eager_transforms import _as_tuple, jvp
from torch.testing._internal.autograd_function_db import autograd_function_db
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    ops,
    tol,
    toleranceOverride,
)
from torch.testing._internal.common_methods_invocations import op_db
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    IS_MACOS,
    IS_X86,
    noncontiguous_like,
    parametrize,
    run_tests,
    runOnRocm,
    skipIfRocm,
    TEST_WITH_ASAN,
    TEST_WITH_ROCM,
    TestCase,
    unMarkDynamoStrictTest,
)
from torch.testing._internal.opinfo.core import SampleInput
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten


aten = torch.ops.aten


# Version of autograd.grad with some differences:
#   - pytree inputs are allowed (but leaves of the pytree have to all
#     be tensors)
#   - if an input is not used as part of derivatives, we will return a
#     zero-filled tensor for the result
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)
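

# Illustrative usage of _autograd_grad (hypothetical values, not part of the
# test suite): for inputs that do not participate in the output,
# torch.autograd.grad with allow_unused=True would return None, whereas
# _autograd_grad returns a zero-filled tensor instead.
#
#   x = torch.randn(3, requires_grad=True)
#   y = torch.randn(3, requires_grad=True)
#   out = (x * 2).sum()
#   gx, gy = _autograd_grad((out,), (x, y))
#   # gx == 2 * torch.ones(3); gy == torch.zeros(3) because y is unused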


def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)


# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point Tensors
# - All outputs of f' are floating-point Tensors
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        i
        for i, arg in enumerate(flat_args)
        if diff_arg(arg, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            result = tuple(r for r in result if torch.is_floating_point(r))
            assert len(result) > 0
        return result

    return wrapped, primals


# TODO: consolidate with normalize_op_input_output2
def normalize_op_input_output3(
    f, args, kwargs, sample_args, output_process_fn_grad=None
):
    flat_args, args_spec = tree_flatten(args)
    flat_sample_args = pytree.tree_leaves(sample_args)
    diff_argnums = tuple(
        i
        for i, (arg, sample) in enumerate(zip(flat_args, flat_sample_args))
        if diff_arg(sample, requires_grad=True)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            result = tuple(r for r in result if torch.is_floating_point(r))
            assert len(result) > 0
        return result

    return wrapped, primals


def normalize_op_input_output(f, sample, requires_grad=True):
    args = tuple([sample.input] + list(sample.args))
    return normalize_op_input_output2(
        f,
        args,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )


def ref_vjp(f, *primals):
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(_as_tuple(result), primals, _as_tuple(cotangents))

    return result, wrapped


def simulate_jvp(f, primals, tangents):
    primals_out, tangents_out = torch.autograd.functional.jvp(f, primals, tangents)
    return primals_out, tangents_out


def ref_jvp(f, primals, tangents):
    with fwAD.dual_level():
        duals = tuple(fwAD.make_dual(p, t) for p, t in zip(primals, tangents))
        result_duals = f(*duals)
        result_duals, spec = tree_flatten(result_duals)
        primals_out, tangents_out = zip(*(fwAD.unpack_dual(d) for d in result_duals))
        return tree_unflatten(primals_out, spec), tree_unflatten(tangents_out, spec)
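

# Illustrative usage of ref_jvp (hypothetical values, not part of the test suite):
#
#   primals = (torch.randn(3),)
#   tangents = (torch.randn(3),)
#   primal_out, tangent_out = ref_jvp(torch.sin, primals, tangents)
#   # primal_out == torch.sin(primals[0]) and
#   # tangent_out == torch.cos(primals[0]) * tangents[0], via forward-mode AD.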
|
|
|
|
|
|
def get_sample_cotangents(f, sample):
|
|
fn, primals = normalize_op_input_output(f, sample)
|
|
output = fn(*primals)
|
|
return tree_map(torch.randn_like, output)
|
|
|
|
|
|
# returns a new function g(*args, *cotangents)
|
|
# that computes vjps and (*args, cotangents)
|
|
def get_vjp_fn_and_args_with_cotangents(f, sample, cotangents):
|
|
args = tuple([sample.input] + list(sample.args))
|
|
kwargs = sample.kwargs
|
|
flat_args, args_spec = tree_flatten(args)
|
|
flat_cotangents, cotangents_spec = tree_flatten(cotangents)
|
|
|
|
@functools.wraps(f)
|
|
def wrapped(*args):
|
|
assert len(args) == len(flat_args) + len(flat_cotangents)
|
|
actual_args = args[: len(flat_args)]
|
|
cotangents = args[len(flat_args) :]
|
|
actual_args = tree_unflatten(actual_args, args_spec)
|
|
cotangents = tree_unflatten(cotangents, cotangents_spec)
|
|
|
|
fn, primals = normalize_op_input_output3(
|
|
f, actual_args, kwargs, flat_args, sample.output_process_fn_grad
|
|
)
|
|
_, vjp_fn = vjp(fn, *primals)
|
|
return vjp_fn(cotangents)
|
|
|
|
return wrapped, tuple(flat_args + flat_cotangents)
|
|
|
|
|
|
# Returns a new function g(*args, *cotangents) that computes vjps and
|
|
# sample (*args, *cotangents)
|
|
def get_vjpfull_variant(f, sample):
|
|
fn, primals = normalize_op_input_output(f, sample)
|
|
return _get_vjpfull_variant(fn, primals)
|
|
|
|
|
|
def get_vjpfull_variant2(f, args, kwargs):
|
|
fn, primals = normalize_op_input_output2(f, args, kwargs)
|
|
return _get_vjpfull_variant(fn, primals)
|
|
|
|
|
|
def _get_vjpfull_variant(fn, primals):
|
|
result = fn(*primals)
|
|
cotangents = _as_tuple(
|
|
tree_map(lambda x: torch.randn_like(x, requires_grad=True), result)
|
|
)
|
|
num_primals = len(primals)
|
|
args = (*primals, *cotangents)
|
|
|
|
@functools.wraps(fn)
|
|
def wrapped(*args):
|
|
primals = args[:num_primals]
|
|
cotangents = args[num_primals:]
|
|
result, vjp_fn = vjp(fn, *primals)
|
|
if isinstance(result, torch.Tensor):
|
|
assert len(cotangents) == 1
|
|
cotangents = cotangents[0]
|
|
return vjp_fn(cotangents)
|
|
|
|
return wrapped, args
|
|
|
|
|
|
def get_jvp_variant(f, sample):
|
|
# We want this higher-order variant of jvp, so that it can
|
|
# be used to wrap vmap
|
|
fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
|
|
tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals))
|
|
|
|
@functools.wraps(f)
|
|
def wrapped(*args):
|
|
tangents = args
|
|
primals_out, tangents_out = jvp(fn, primals, tangents)
|
|
|
|
if isinstance(primals_out, torch.Tensor):
|
|
return (primals_out, tangents_out)
|
|
else:
|
|
flat_primals_out = pytree.tree_leaves(primals_out)
|
|
flat_tangents_out = pytree.tree_leaves(tangents_out)
|
|
return tuple(flat_primals_out + flat_tangents_out)
|
|
|
|
return wrapped, tangents
|
|
|
|
|
|
def get_jvp_variant_primals_tangents2(
|
|
f, args, kwargs, output_process_fn_grad=None, requires_grad=False
|
|
):
|
|
fn, primals = normalize_op_input_output2(
|
|
f, args, kwargs, output_process_fn_grad, requires_grad
|
|
)
|
|
tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals))
|
|
return _get_jvp_variant(fn, primals, tangents)
|
|
|
|
|
|
def get_jvp_variant_primals_tangents(f, sample):
|
|
# We want this higher-order variant of jvp, so that it can
|
|
# be used to wrap vmap
|
|
fn, primals = normalize_op_input_output(f, sample, requires_grad=False)
|
|
tangents = _as_tuple(tree_map(lambda x: torch.randn_like(x), primals))
|
|
return _get_jvp_variant(fn, primals, tangents)
|
|
|
|
|
|
def _get_jvp_variant(fn, primals, tangents):
|
|
@functools.wraps(fn)
|
|
def wrapped(*args):
|
|
primals_in = args[: len(primals)]
|
|
tangents_in = args[len(primals) :]
|
|
primals_out, tangents_out = jvp(fn, primals_in, tangents_in)
|
|
|
|
if isinstance(primals_out, torch.Tensor):
|
|
return (primals_out, tangents_out)
|
|
else:
|
|
flat_primals_out = pytree.tree_leaves(primals_out)
|
|
flat_tangents_out = pytree.tree_leaves(tangents_out)
|
|
return tuple(flat_primals_out + flat_tangents_out)
|
|
|
|
return wrapped, primals + tangents
|
|
|
|
|
|
def is_inplace(op, variant):
|
|
if hasattr(variant, "__wrapped__"):
|
|
return variant.__wrapped__ is op.get_inplace()
|
|
return variant is op.get_inplace()
|
|
|
|
|
|
vjp_fail = {
|
|
xfail("tensor_split"), # data_ptr composite compliance
|
|
# Very minor accuracy issue on ROCm
|
|
decorate("nn.functional.scaled_dot_product_attention", decorator=skipIfRocm),
|
|
}
|
|
|
|
aliasing_ops = {
|
|
"T",
|
|
"broadcast_to",
|
|
"conj",
|
|
"contiguous",
|
|
"diagonal", # linalg.diagonal is an alias
|
|
"expand",
|
|
"flatten",
|
|
"imag",
|
|
"mH", # adjoint is an alias
|
|
"mT",
|
|
"movedim", # moveaxis is an alias
|
|
"narrow",
|
|
"permute",
|
|
"positive",
|
|
# 'ravel', is composite implicit autograd and may call clone
|
|
"real",
|
|
"reshape",
|
|
"resolve_conj",
|
|
"resolve_neg",
|
|
"select",
|
|
"squeeze",
|
|
"transpose", # swapdims and swapaxes are aliases
|
|
"unflatten",
|
|
"unfold",
|
|
"unsqueeze",
|
|
"view",
|
|
"view_as",
|
|
"view_as_complex",
|
|
"view_as_real",
|
|
}
|
|
|
|
aliasing_ops_list_return = {
|
|
"chunks",
|
|
"dsplit",
|
|
"hsplit",
|
|
"split",
|
|
"unbind",
|
|
"vsplit",
|
|
# 'tensor_split' not composite compliant, see vjp_fail
|
|
}
|
|
|
|
skip_noncontig = {
|
|
"_batch_norm_with_update",
|
|
"as_strided_copy",
|
|
}
|
|
|
|
|
|
@unittest.skipIf(TEST_WITH_ASAN, "tests time out with asan, are probably redundant")
|
|
@unMarkDynamoStrictTest
|
|
class TestOperators(TestCase):
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_grad",
|
|
vjp_fail.union(
|
|
{
|
|
xfail(
|
|
"chalf", "", device_type="cpu"
|
|
), # RuntimeError: "sum_cpu" not implemented for 'ComplexHalf'
|
|
xfail(
|
|
"sparse.sampled_addmm", ""
|
|
), # RuntimeError: Sparse CSR tensors do not have strides
|
|
xfail(
|
|
"sparse.mm", "reduce"
|
|
), # RuntimeError: Sparse CSR tensors do not have strides
|
|
# Non-contiguous Bugs
|
|
#
|
|
# AssertionError: Tensor-likes are not close!
|
|
xfail("_softmax_backward_data", device_type="cpu"),
|
|
xfail("as_strided"),
|
|
xfail("as_strided", "partial_views"),
|
|
# RuntimeError: !self.requires_grad() || self.is_contiguous()
|
|
xfail("as_strided_scatter"),
|
|
# RuntimeError: Tensor must have a last dimension with stride 1
|
|
xfail("view_as_complex"),
|
|
# query: last dimension must be contiguous
|
|
# Fused attention kernels require last dim to be contiguous
|
|
decorate(
|
|
"nn.functional.scaled_dot_product_attention",
|
|
decorator=expectedFailureIf(not TEST_WITH_ROCM),
|
|
), # Works on ROCm
|
|
xfail("torch.ops.aten._flash_attention_forward"),
|
|
xfail("torch.ops.aten._efficient_attention_forward"),
|
|
# RuntimeError: Expected contiguous tensor, but got
|
|
# non-contiguous tensor for argument #2 'grad_output'
|
|
decorate(
|
|
"_batch_norm_with_update",
|
|
decorator=expectedFailureIf(TEST_WITH_ROCM),
|
|
device_type="cuda",
|
|
),
|
|
}
|
|
),
|
|
)
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_grad",
|
|
(
|
|
tol1(
|
|
"nn.functional.binary_cross_entropy_with_logits",
|
|
{torch.float32: tol(atol=1e-04, rtol=1e-04)},
|
|
),
|
|
tol1("masked.cumprod", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
|
|
tol1("svd_lowrank", {torch.float32: tol(atol=3e-04, rtol=3e-04)}),
|
|
tol1(
|
|
"linalg.multi_dot",
|
|
{torch.float32: tol(atol=1e-05, rtol=8e-04)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"linalg.tensorsolve",
|
|
{torch.float32: tol(atol=3e-04, rtol=3e-04)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"nn.functional.multi_head_attention_forward",
|
|
{torch.float32: tol(atol=8e-04, rtol=1e-03)},
|
|
),
|
|
tol1(
|
|
"__rmatmul__",
|
|
{torch.float32: tol(atol=3e-04, rtol=3e-04)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"matmul",
|
|
{torch.float32: tol(atol=3e-04, rtol=3e-04)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"pca_lowrank",
|
|
{torch.float32: tol(atol=3e-05, rtol=4e-06)},
|
|
device_type="cpu",
|
|
),
|
|
),
|
|
)
|
|
def test_grad(self, device, dtype, op):
|
|
if op.name in vjp_fail:
|
|
self.skipTest("Skipped; Expected failures")
|
|
return
|
|
|
|
if not op.supports_autograd:
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
if is_inplace(op, op.get_op()):
|
|
self.skipTest("Skipped for redundancy. test_vjp handles in-place testing.")
|
|
return
|
|
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
|
|
if op.name not in skip_noncontig:
|
|
noncontig_sample = sample.noncontiguous()
|
|
noncontig_args = [noncontig_sample.input] + list(noncontig_sample.args)
|
|
noncontig_kwargs = noncontig_sample.kwargs
|
|
|
|
diff_argnums = tuple(i for i, arg in enumerate(args) if diff_arg(arg))
|
|
assert len(diff_argnums) > 0
|
|
diff_args = tuple(args[i] for i in diff_argnums)
|
|
|
|
def wrapped_fn(*args, **kwargs):
|
|
result = op(*args, **kwargs)
|
|
if sample.output_process_fn_grad is not None:
|
|
result = sample.output_process_fn_grad(result)
|
|
|
|
def abs_if_complex(t):
|
|
if t.dtype.is_complex:
|
|
return t.abs()
|
|
return t
|
|
|
|
# Reduce into single value for grad
|
|
if isinstance(result, torch.Tensor):
|
|
return abs_if_complex(result.sum())
|
|
result = sum(abs_if_complex(res.sum()) for res in result)
|
|
return result
|
|
|
|
result = grad(wrapped_fn, diff_argnums)(*args, **kwargs)
|
|
expected = _autograd_grad(_as_tuple(wrapped_fn(*args, **kwargs)), diff_args)
|
|
self.assertEqual(result, expected)
|
|
|
|
if op.name not in skip_noncontig:
|
|
result_noncontig = grad(wrapped_fn, diff_argnums)(
|
|
*noncontig_args, **noncontig_kwargs
|
|
)
|
|
self.assertEqual(result_noncontig, expected)
|
|
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_jvp",
|
|
set(
|
|
{
|
|
# Composite ops that do bad things. Need to be fixed in PyTorch core.
|
|
# RuntimeError: Cannot access data pointer of Tensor that doesn't have storage
|
|
xfail("tensor_split"),
|
|
# BUG: silent incorrectness: runs and produces numerical differences
|
|
skip("nn.functional.max_unpool1d"), # fails everywhere except on mac
|
|
skip(
|
|
"nn.functional.max_unpool2d"
|
|
), # fails everywhere except on windows
|
|
skip("nn.functional.max_unpool3d"), # fails everywhere except on mac
|
|
xfail(
|
|
"native_batch_norm"
|
|
), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
|
|
xfail(
|
|
"_native_batch_norm_legit"
|
|
), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
|
|
xfail(
|
|
"_batch_norm_with_update"
|
|
), # TODO: fails comparing None to tensor of 0s for saved_mean/var tangents
|
|
xfail("nn.functional.scaled_dot_product_attention"),
|
|
xfail("torch.ops.aten._flash_attention_forward"),
|
|
xfail("torch.ops.aten._efficient_attention_forward"),
|
|
xfail(
|
|
"nn.functional.rrelu"
|
|
), # in-place test errors out with no formula implemented
|
|
xfail(
|
|
"NumpyExpMarkDirtyAutogradFunction"
|
|
), # TODO: https://github.com/pytorch/pytorch/issues/91280
|
|
# --- Non-Contiguous Failures! ---
|
|
# This is expected to fail as the operator
|
|
# expects last dim to have stride=1
|
|
xfail("view_as_complex"),
|
|
# BUG
|
|
# AssertionError: Tensor-likes are not close!
|
|
xfail("as_strided"),
|
|
xfail("as_strided", "partial_views"),
|
|
xfail("as_strided_scatter"),
|
|
decorate(
|
|
"linalg.det",
|
|
"singular",
|
|
decorator=expectedFailureIf(IS_MACOS and IS_X86),
|
|
),
|
|
}
|
|
),
|
|
)
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_jvp",
|
|
(
|
|
tol1(
|
|
"nn.functional.conv_transpose3d",
|
|
{torch.float32: tol(atol=1e-04, rtol=1.3e-06)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"linalg.tensorsolve",
|
|
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"masked.prod",
|
|
{torch.float32: tol(atol=1e-05, rtol=1.3e-05)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"nn.functional.binary_cross_entropy_with_logits",
|
|
{torch.float32: tol(atol=4e-04, rtol=4e-04)},
|
|
),
|
|
tol1(
|
|
"nn.functional.batch_norm", {torch.float32: tol(atol=4e-05, rtol=5e-05)}
|
|
),
|
|
tol1("nn.functional.conv2d", {torch.float32: tol(atol=4e-05, rtol=5e-05)}),
|
|
tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
|
|
tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
|
|
tol1(
|
|
"nn.functional.multi_head_attention_forward",
|
|
{torch.float32: tol(atol=6e-05, rtol=2e-05)},
|
|
),
|
|
tol2(
|
|
"linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-5, rtol=2e-5)}
|
|
),
|
|
),
|
|
)
|
|
def test_jvp(self, device, dtype, op):
|
|
# TODO: get rid of vjp_decomp when we add decomposition support to
|
|
# PyTorch's forward-mode ad. Currently the decomposition support only
|
|
# works for functorch.jvp
|
|
VJP_DECOMP = {
|
|
"nn.functional.logsigmoid",
|
|
}
|
|
if op.name in VJP_DECOMP:
|
|
fixme_ref_jvp_local = simulate_jvp
|
|
else:
|
|
fixme_ref_jvp_local = ref_jvp
|
|
|
|
if not op.supports_forward_ad and op.name not in VJP_DECOMP:
|
|
self.skipTest("Skipped! Forward AD not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
outplace_variant = op if not is_inplace(op, op.get_op()) else None
|
|
inplace_variant = op.inplace_variant if op.supports_inplace_autograd else None
|
|
|
|
for sample in samples:
|
|
if outplace_variant:
|
|
self.jvp_opinfo_test(
|
|
outplace_variant,
|
|
sample,
|
|
sample.output_process_fn_grad,
|
|
clone_inputs=False,
|
|
fixme_ref_jvp_local=fixme_ref_jvp_local,
|
|
test_noncontig=op.name not in skip_noncontig,
|
|
)
|
|
if is_valid_inplace_sample_input(sample, op, inplace_variant):
|
|
self.jvp_opinfo_test(
|
|
inplace_variant,
|
|
sample,
|
|
sample.output_process_fn_grad,
|
|
clone_inputs=True,
|
|
fixme_ref_jvp_local=fixme_ref_jvp_local,
|
|
test_noncontig=op.name not in skip_noncontig,
|
|
)
|
|
|
|
def jvp_opinfo_test(
|
|
self,
|
|
fn,
|
|
sample,
|
|
output_process_fn,
|
|
clone_inputs,
|
|
fixme_ref_jvp_local,
|
|
test_noncontig,
|
|
):
|
|
# NB: we used requires_grad=True to determine where the primals are,
|
|
# but don't need that information otherwise
|
|
args = (sample.input,) + sample.args
|
|
kwargs = sample.kwargs
|
|
contig_fn, primals = normalize_op_input_output2(
|
|
fn, args, kwargs, output_process_fn, requires_grad=True
|
|
)
|
|
orig_primals = tree_map(lambda x: x.detach(), primals)
|
|
orig_tangents = tree_map(lambda x: torch.randn_like(x), primals)
|
|
|
|
def maybe_clone_inputs():
|
|
if clone_inputs:
|
|
primals = tree_map(torch.clone, orig_primals)
|
|
tangents = tree_map(torch.clone, orig_tangents)
|
|
return primals, tangents
|
|
return orig_primals, orig_tangents
|
|
|
|
primals, tangents = maybe_clone_inputs()
|
|
expected_primal_outs, expected_tangent_outs = fixme_ref_jvp_local(
|
|
contig_fn, primals, tangents
|
|
)
|
|
|
|
primals, tangents = maybe_clone_inputs()
|
|
primal_outs, tangent_outs = jvp(contig_fn, primals, tangents)
|
|
|
|
self.assertEqual(primal_outs, expected_primal_outs)
|
|
self.assertEqual(tangent_outs, expected_tangent_outs)
|
|
|
|
if test_noncontig:
|
|
noncontig_sample = sample.noncontiguous()
|
|
noncontig_args = (noncontig_sample.input,) + noncontig_sample.args
|
|
noncontig_kwargs = sample.kwargs
|
|
noncontig_fn, primals = normalize_op_input_output2(
|
|
fn,
|
|
noncontig_args,
|
|
noncontig_kwargs,
|
|
output_process_fn,
|
|
requires_grad=True,
|
|
)
|
|
noncontig_primals = tree_map(lambda x: x.detach(), primals)
|
|
noncontig_tangents = tree_map(
|
|
lambda x: noncontiguous_like(x), orig_tangents
|
|
)
|
|
noncontig_primal_outs, noncontig_tangent_outs = jvp(
|
|
noncontig_fn, noncontig_primals, noncontig_tangents
|
|
)
|
|
|
|
self.assertEqual(noncontig_primal_outs, expected_primal_outs)
|
|
self.assertEqual(noncontig_tangent_outs, expected_tangent_outs)
|
|
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vjp",
|
|
vjp_fail.union(
|
|
{
|
|
xfail("sparse.sampled_addmm", ""),
|
|
xfail("sparse.mm", "reduce"),
|
|
# ---- Non-Contiguous Failures ----
|
|
# This is expected to fail as the operator
|
|
# expects last dim to have stride=1
|
|
xfail("view_as_complex"),
|
|
# RuntimeError: query: last dimension must be contiguous
|
|
# The fused attention kernels require the last dim to be contiguous
|
|
decorate(
|
|
"nn.functional.scaled_dot_product_attention",
|
|
decorator=expectedFailureIf(not TEST_WITH_ROCM),
|
|
), # Works on ROCm
|
|
xfail("torch.ops.aten._flash_attention_forward"),
|
|
xfail("torch.ops.aten._efficient_attention_forward"),
|
|
# BUG
|
|
# AssertionError: Tensor-likes are not close!
|
|
xfail("as_strided"),
|
|
xfail("as_strided_scatter"),
|
|
xfail("_softmax_backward_data", device_type="cpu"),
|
|
xfail("as_strided", "partial_views"),
|
|
}
|
|
),
|
|
)
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_vjp",
|
|
(
|
|
tol1(
|
|
"nn.functional.conv_transpose3d",
|
|
{torch.float32: tol(atol=5e-05, rtol=9e-05)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"nn.functional.binary_cross_entropy_with_logits",
|
|
{torch.float32: tol(atol=1e-04, rtol=1e-04)},
|
|
),
|
|
tol1(
|
|
"nn.functional.multi_head_attention_forward",
|
|
{torch.float32: tol(atol=2e-03, rtol=2e-04)},
|
|
),
|
|
tol1("__rmatmul__", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
|
|
tol1("matmul", {torch.float32: tol(atol=1e-05, rtol=1e-05)}),
|
|
tol2(
|
|
"linalg.pinv", "hermitian", {torch.float32: tol(atol=1e-05, rtol=1e-05)}
|
|
),
|
|
tol1("linalg.tensorsolve", {torch.float32: tol(atol=9e-03, rtol=2e-04)}),
|
|
tol1("linalg.multi_dot", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
|
|
tol1("svd_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
|
|
tol1("pca_lowrank", {torch.float32: tol(atol=1e-04, rtol=1e-04)}),
|
|
),
|
|
)
|
|
def test_vjp(self, device, dtype, op):
|
|
if not op.supports_autograd:
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
def _test(_op, inplace=False):
|
|
for sample in samples:
|
|
if inplace and not is_valid_inplace_sample_input(
|
|
sample, op, op.inplace_variant
|
|
):
|
|
continue
|
|
fn, primals = normalize_op_input_output(_op, sample)
|
|
result = fn(*primals)
|
|
cotangents = tree_map(lambda x: torch.randn_like(x), result)
|
|
|
|
out, vjp_fn = vjp(fn, *primals)
|
|
self.assertEqual(out, result)
|
|
result_vjps = vjp_fn(cotangents)
|
|
|
|
_, vjp_fn = ref_vjp(fn, *primals)
|
|
expected_vjps = vjp_fn(cotangents)
|
|
|
|
self.assertEqual(result_vjps, expected_vjps)
|
|
|
|
if op.name not in skip_noncontig:
|
|
noncontig_fn, noncontig_primals = normalize_op_input_output(
|
|
_op, sample.noncontiguous()
|
|
)
|
|
noncontig_cotangents = tree_map(
|
|
lambda x: noncontiguous_like(x), cotangents
|
|
)
|
|
out_noncontig, vjp_fn = vjp(noncontig_fn, *noncontig_primals)
|
|
self.assertEqual(out_noncontig, result)
|
|
noncontig_result_vjps = vjp_fn(noncontig_cotangents)
|
|
self.assertEqual(noncontig_result_vjps, expected_vjps)
|
|
|
|
_test(op)
|
|
for a_op in op.aliases:
|
|
_test(a_op)
|
|
if op.inplace_variant:
|
|
|
|
def f(inp, *args, **kwargs):
|
|
return op.inplace_variant(inp.clone(), *args, **kwargs)
|
|
|
|
_test(f, inplace=True)
|
|
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vjpvjp",
|
|
vjp_fail.union(
|
|
{
|
|
skip("nn.functional.max_unpool1d"), # silent incorrectness; Flaky
|
|
skip("nn.functional.max_unpool2d"), # silent incorrectness; Flaky
|
|
xfail("nn.functional.ctc_loss"), # Not Implemented
|
|
xfail(
|
|
"native_layer_norm", ""
|
|
), # Expected a proper Tensor but got None for argument #1 'other'
|
|
xfail("sparse.sampled_addmm", ""), # sparse tensors have no strides
|
|
xfail("sparse.mm", "reduce"), # sparse tensors have no strides
|
|
skip("nn.functional.scaled_dot_product_attention"),
|
|
xfail("torch.ops.aten._flash_attention_forward"),
|
|
xfail("torch.ops.aten._efficient_attention_forward"),
|
|
# AssertionError: Tensor-likes are not close!
|
|
# Mismatched elements: 1 / 15 (6.7%)
|
|
# Greatest absolute difference: 24.0 at index (2, 4) (up to 1e-05 allowed)
|
|
# Greatest relative difference: 1.7933241714393998e-06 at index (2, 4) (up to 1.3e-06 allowed)
|
|
# The failure occurred for item [0]
|
|
xfail("masked.prod"),
|
|
}
|
|
),
|
|
)
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_vjpvjp",
|
|
(
|
|
tol1(
|
|
"nn.functional.conv_transpose3d",
|
|
{torch.float32: tol(atol=5e-05, rtol=9e-05)},
|
|
device_type="cuda",
|
|
),
|
|
tol1("prod", {torch.float32: tol(atol=2e-05, rtol=1e-04)}),
|
|
tol1("masked.cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
|
|
tol1("cumprod", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
|
|
tol1("linalg.vander", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
|
|
tol2(
|
|
"linalg.det", "singular", {torch.float32: tol(atol=2e-05, rtol=2e-05)}
|
|
),
|
|
),
|
|
)
|
|
def test_vjpvjp(self, device, dtype, op):
|
|
if not op.supports_autograd:
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
if not op.supports_gradgrad:
|
|
self.skipTest("Skipped! Operation does not support gradgrad")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
def test(_op, inplace=False):
|
|
for sample in samples:
|
|
if inplace and not is_valid_inplace_sample_input(
|
|
sample, op, op.inplace_variant
|
|
):
|
|
continue
|
|
fn, args = get_vjpfull_variant(_op, sample)
|
|
result = fn(*args)
|
|
cotangents = tree_map(lambda x: torch.randn_like(x), result)
|
|
|
|
# Compute vjp of vjp
|
|
_, vjp_fn = vjp(fn, *args)
|
|
result_vjps = vjp_fn(cotangents)
|
|
|
|
# Compute ref_vjp of vjp. We could have done ref_vjp of ref_vjp,
|
|
# but since we're confident that vjp works by itself, this is
|
|
# an equivalent way to test that.
|
|
_, vjp_fn = ref_vjp(fn, *args)
|
|
expected_vjps = vjp_fn(cotangents)
|
|
|
|
self.assertEqual(result_vjps, expected_vjps)
|
|
|
|
test(op)
|
|
if op.inplace_variant:
|
|
|
|
def fn(inp, *args, **kwargs):
|
|
return op.inplace_variant(inp.clone(), *args, **kwargs)
|
|
|
|
test(fn, inplace=True)
|
|
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmapvjpvjp",
|
|
vjp_fail.union(
|
|
{
|
|
skip("atleast_1d"), # Takes too long
|
|
skip("atleast_2d"), # Takes too long
|
|
skip("atleast_3d"), # Takes too long
|
|
skip("ormqr"), # Takes too long
|
|
xfail("as_strided"), # incorrect output
|
|
xfail("as_strided", "partial_views"), # incorrect output
|
|
xfail("as_strided_scatter"), # incorrect output
|
|
skip("bernoulli"), # calls random op
|
|
xfail("bfloat16"), # rank 4 tensor for channels_last
|
|
xfail("cdouble"), # rank 4 tensor for channels_last
|
|
xfail("cfloat"), # rank 4 tensor for channels_last
|
|
xfail("chalf"), # rank 4 tensor for channels_last
|
|
xfail("double"), # rank 4 tensor for channels_last
|
|
xfail("float"), # rank 4 tensor for channels_last
|
|
xfail("half"), # rank 4 tensor for channels_last
|
|
xfail(
|
|
"NumpyCubeNotComposableAutogradFunction"
|
|
), # Not composable autograd.Function
|
|
# It looks like you're either (1) calling .item() on a Tensor or
|
|
# (2) attempting to use a Tensor in some data-dependent control flow or
|
|
# (3) encountering this error in PyTorch internals.
|
|
xfail("index_reduce", "prod"),
|
|
decorate(
|
|
"linalg.householder_product", decorator=runOnRocm
|
|
), # works on ROCm
|
|
xfail(
|
|
# nans
|
|
"masked.softmax",
|
|
device_type="cpu",
|
|
),
|
|
xfail(
|
|
"nanquantile", device_type="cpu"
|
|
), # vmap not implemented for at::equal.
|
|
xfail("native_layer_norm"), # vmap: inplace into a regular tensor
|
|
# got a batched tensor as input while the running_mean or running_var,
|
|
# which will be updated in place, were not batched.
|
|
xfail("nn.functional.batch_norm"),
|
|
xfail(
|
|
"nn.functional.binary_cross_entropy"
|
|
), # vmap: inplace into a regular tensor
|
|
xfail(
|
|
"nn.functional.ctc_loss"
|
|
), # derivate not implemented for _ctc_loss_backward
|
|
# flaky on ROCM needs investigation
|
|
decorate("nn.functional.conv_transpose2d", decorator=skipIfRocm),
|
|
skip("nn.functional.dropout"), # calls random op
|
|
skip("nn.functional.dropout2d"), # calls random op
|
|
skip("nn.functional.dropout3d"), # calls random op
|
|
skip("nn.functional.alpha_dropout"), # calls random op
|
|
skip(
|
|
"nn.functional.feature_alpha_dropout", "with_train"
|
|
), # calls random op
|
|
skip("nn.functional.fractional_max_pool2d"), # calls random op
|
|
skip("nn.functional.fractional_max_pool3d"), # calls random op
|
|
xfail("nn.functional.scaled_dot_product_attention"), # randomness
|
|
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
|
|
xfail("nn.functional.multi_head_attention_forward"), # randomness
|
|
# It looks like you're either (1) calling .item() on a Tensor or
|
|
# (2) attempting to use a Tensor in some data-dependent control flow or
|
|
# (3) encountering this error in PyTorch internals.
|
|
xfail("nn.functional.gaussian_nll_loss"),
|
|
# got a batched tensor as input while the running_mean or running_var,
|
|
# which will be updated in place, were not batched.
|
|
xfail("nn.functional.instance_norm"),
|
|
xfail(
|
|
"nn.functional.layer_norm"
|
|
), # vmap: inplace into a regular tensor
|
|
# RuntimeError: NYI: querying is_contiguous inside of vmap
|
|
# for memory_format other than torch.contiguous_formats
|
|
xfail("nn.functional.max_pool2d"),
|
|
# RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only
|
|
# supported with memory_format torch.preserve_format or
|
|
# torch.contiguous_format (got ChannelsLast)
|
|
xfail("nn.functional.max_unpool2d"),
|
|
# RuntimeError: NYI: Tensor.clone(memory_format) inside vmap is only
|
|
# supported with memory_format torch.preserve_format
|
|
# or torch.contiguous_format (got ChannelsLast)
|
|
xfail("nn.functional.max_unpool2d", "grad"),
|
|
xfail(
|
|
"nn.functional.rrelu"
|
|
), # RuntimeError: vmap: we do not yet support aten::rrelu_with_noise.
|
|
xfail("normal"), # calls random op
|
|
xfail("normal", "number_mean"), # calls random op
|
|
xfail("pca_lowrank"), # calls random op
|
|
xfail(
|
|
"quantile", device_type="cpu"
|
|
), # Batching rule not implemented for `at::equal`
|
|
xfail(
|
|
"scatter_reduce", "prod"
|
|
), # vmap (looks like you are calling item/data-dependent)
|
|
xfail(
|
|
"sparse.sampled_addmm"
|
|
), # RuntimeError: Sparse CSR tensors do not have strides
|
|
xfail(
|
|
"sparse.mm", "reduce"
|
|
), # RuntimeError: Sparse CSR tensors do not have strides
|
|
xfail("svd_lowrank"), # calls random op
|
|
xfail("to"), # rank 4 tensor for channels_last
|
|
xfail(
|
|
"view_as_complex"
|
|
), # RuntimeError: Tensor must have a last dimension with stride 1
|
|
# got a batched tensor as input while the running_mean or running_var,
|
|
# which will be updated in place, were not batched.
|
|
xfail("nn.functional.batch_norm", "without_cudnn"),
|
|
# view doesn't work on sparse
|
|
xfail("to_sparse"),
|
|
xfail("native_batch_norm"),
|
|
xfail("_native_batch_norm_legit"),
|
|
# TODO: implement batching rule
|
|
xfail("_batch_norm_with_update"),
|
|
}
|
|
),
|
|
)
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_vmapvjpvjp",
|
|
(
|
|
tol1("linalg.svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
|
|
tol1("linalg.lu", {torch.float32: tol(atol=5e-04, rtol=7e-04)}),
|
|
tol1("linalg.lu_factor", {torch.float32: tol(atol=2e-03, rtol=2e-02)}),
|
|
tol1("linalg.multi_dot", {torch.float32: tol(atol=2e-03, rtol=2e-04)}),
|
|
tol1("svd", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
|
|
tol1("matrix_exp", {torch.float32: tol(atol=1e-03, rtol=5e-04)}),
|
|
tol1("masked.prod", {torch.float32: tol(atol=2e-03, rtol=2e-04)}),
|
|
),
|
|
)
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmapvjpvjp",
|
|
{
|
|
xfail("as_strided", "partial_views"),
|
|
xfail("as_strided_copy"),
|
|
},
|
|
)
|
|
def test_vmapvjpvjp(self, device, dtype, op):
|
|
# Since, we test `vjpvjp` independently,
|
|
# for this test, we just verify that vmap
|
|
# of `vjpvjp` is correct.
|
|
if not op.supports_autograd:
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
if not op.supports_gradgrad:
|
|
self.skipTest("Skipped! Operation does not support gradgrad")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
# TODO: test in-place
|
|
if is_inplace(op, op.get_op()):
|
|
self.skipTest("Skipped! NYI: inplace-testing not supported.")
|
|
return
|
|
|
|
for sample in samples:
|
|
fn, args = get_vjpfull_variant(op, sample)
|
|
result = fn(*args)
|
|
cotangents = tree_map(lambda x: torch.randn_like(x), result)
|
|
cotangents = pytree.tree_leaves(cotangents)
|
|
num_args = len(args)
|
|
|
|
args_and_cotangents = tuple(args) + tuple(cotangents)
|
|
|
|
def vjp_of_vjp(*args_and_cotangents):
|
|
args = args_and_cotangents[:num_args]
|
|
cotangents = args_and_cotangents[num_args:]
|
|
result, vjp_fn = vjp(fn, *args)
|
|
result_vjps = vjp_fn(cotangents)
|
|
result = pytree.tree_leaves(result)
|
|
result_vjps = pytree.tree_leaves(result_vjps)
|
|
return (*result, *result_vjps)
|
|
|
|
is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
|
|
generator = get_fallback_and_vmap_exhaustive(
|
|
vjp_of_vjp,
|
|
args_and_cotangents,
|
|
{},
|
|
is_batch_norm_and_training=is_batch_norm_and_training,
|
|
)
|
|
for loop_out, batched_out in generator:
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
vmapvjp_fail = vjp_fail.union(
|
|
{
|
|
# -------------------- ALLOWED FAILURES --------------------------------
|
|
# The following are not bugs and are expected behavior
|
|
xfail("masked_select"), # Not possible due to dynamic shapes
|
|
skip("bernoulli"), # randomness
|
|
skip("normal", ""), # randomness
|
|
skip("normal", "number_mean"), # randomness
|
|
skip("nn.functional.rrelu"), # randomness
|
|
skip("nn.functional.feature_alpha_dropout", "with_train"), # randomness
|
|
skip("nn.functional.feature_alpha_dropout", "without_train"), # randomness
|
|
skip("nn.functional.dropout"), # randomness
|
|
skip("nn.functional.dropout2d"), # randomness
|
|
skip("nn.functional.dropout3d", ""), # randomness
|
|
skip("nn.functional.alpha_dropout"), # randomness
|
|
skip("nn.functional.scaled_dot_product_attention"), # randomness
|
|
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
|
|
skip("nn.functional.multi_head_attention_forward"), # randomness
|
|
xfail(
|
|
"index_put", ""
|
|
), # not possible due to dynamic shapes; we support a subset
|
|
xfail("nn.functional.fractional_max_pool2d"), # random
|
|
xfail("nn.functional.fractional_max_pool3d"), # random
|
|
xfail("pca_lowrank", ""), # randomness
|
|
xfail("svd_lowrank", ""), # randomness
|
|
xfail("to_sparse", ""), # non-dense output
|
|
skip(
|
|
"to"
|
|
), # RuntimeError: required rank 4 tensor to use channels_last format
|
|
xfail("as_strided", "partial_views"),
|
|
xfail(
|
|
"NumpyCubeNotComposableAutogradFunction"
|
|
), # Not composable autograd.Function
|
|
# ----------------------------------------------------------------------
|
|
# ---------------------------- BUGS ------------------------------------
|
|
# All of the following are bugs and need to be fixed
|
|
skip(
|
|
"linalg.svdvals"
|
|
), # really annoying thing where it passes correctness check but not has_batch_rule
|
|
skip("native_batch_norm"),
|
|
skip("_native_batch_norm_legit"),
|
|
# TODO: implement batching rule
|
|
skip("_batch_norm_with_update"),
|
|
xfail("__getitem__", ""), # dynamic error
|
|
xfail("nanquantile", device_type="cpu"), # checks q via a .item() call
|
|
xfail("nn.functional.gaussian_nll_loss"), # checks var for if any value < 0
|
|
xfail("narrow"), # .item() call
|
|
xfail("quantile", device_type="cpu"), # checks q via a .item() call
|
|
xfail("view_as_complex"), # Tensor must have a last dimension with stride 1
|
|
# required rank 4 tensor to use channels_last format
|
|
xfail("bfloat16"),
|
|
xfail("double"),
|
|
xfail("float"),
|
|
xfail("half"),
|
|
xfail("cdouble", ""),
|
|
xfail("cfloat", ""),
|
|
xfail("chalf", ""),
|
|
xfail("scatter_reduce", "prod"), # item call
|
|
# Batching rule not implemented for aten::_use_cudnn_ctc_loss.Tensor
|
|
xfail("nn.functional.ctc_loss", device_type="cuda"),
|
|
# NYI: querying is_contiguous inside of vmap for memory_format other than torch.contiguous_format
|
|
xfail("nn.functional.max_unpool2d"),
|
|
xfail("nn.functional.max_unpool2d", "grad"),
|
|
xfail("sparse.sampled_addmm", ""),
|
|
xfail("sparse.mm", "reduce"),
|
|
xfail("as_strided_scatter", ""), # calls as_strided
|
|
xfail("index_reduce", "prod"), # .item() call
|
|
# ---------------------------------------------------------------------
|
|
}
|
|
)
|
|
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_vmapvjp",
|
|
(
|
|
tol1(
|
|
"linalg.svd",
|
|
{torch.float32: tol(atol=5e-04, rtol=1e-04)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"svd", {torch.float32: tol(atol=5e-04, rtol=1e-04)}, device_type="cuda"
|
|
),
|
|
tol1(
|
|
"linalg.householder_product",
|
|
{torch.float32: tol(atol=3e-04, rtol=9e-04)},
|
|
),
|
|
tol1(
|
|
"matrix_exp",
|
|
{torch.float32: tol(atol=5e-04, rtol=1e-04)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"nn.functional.layer_norm",
|
|
{torch.float32: tol(atol=3e-4, rtol=1e-4)},
|
|
device_type="cpu",
|
|
),
|
|
tol1(
|
|
"native_layer_norm",
|
|
{torch.float32: tol(atol=3e-4, rtol=1e-4)},
|
|
device_type="cpu",
|
|
),
|
|
),
|
|
)
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmapvjp",
|
|
vmapvjp_fail.union(
|
|
{
|
|
xfail("as_strided"),
|
|
xfail("as_strided_copy"),
|
|
xfail("as_strided", "partial_views"),
|
|
}
|
|
),
|
|
)
|
|
def test_vmapvjp(self, device, dtype, op):
|
|
if not op.supports_autograd:
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
# TODO: test in-place
|
|
if is_inplace(op, op.get_op()):
|
|
self.skipTest("Skipped! NYI: inplace-testing not supported.")
|
|
return
|
|
for sample in samples:
|
|
cotangents = get_sample_cotangents(op, sample)
|
|
fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
|
|
is_batch_norm_and_training = is_batch_norm_training(op.name, sample.kwargs)
|
|
generator = get_fallback_and_vmap_exhaustive(
|
|
fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training
|
|
)
|
|
for loop_out, batched_out in generator:
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
vmapjvpall_fail = {
|
|
# -------------------- ALLOWED FAILURES --------------------------------
|
|
# The following are expected (not a bug)
|
|
skip("bernoulli", ""), # randomness
|
|
skip("nn.functional.dropout"), # randomness
|
|
skip("nn.functional.rrelu"), # randomness
|
|
skip("nn.functional.dropout2d", ""),
|
|
skip("nn.functional.dropout3d", ""),
|
|
skip("nn.functional.scaled_dot_product_attention"), # randomness
|
|
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
|
|
skip("nn.functional.multi_head_attention_forward"), # randomness
|
|
skip("nn.functional.alpha_dropout"), # randomness
|
|
skip("nn.functional.feature_alpha_dropout", "without_train"),
|
|
skip("nn.functional.feature_alpha_dropout", "with_train"),
|
|
xfail(
|
|
"nn.functional.fractional_max_pool2d"
|
|
), # Cannot access data pointer of Tensor that doesn't have storage
|
|
xfail(
|
|
"nn.functional.fractional_max_pool3d"
|
|
), # Cannot access data pointer of Tensor that doesn't have storage
|
|
# Not actually a problem: embedding with max_norm mutates the weight
|
|
# and causes different runs to produce different results.
|
|
# skip because this is flaky depending on what the max_norm is!
|
|
skip("nn.functional.embedding", ""),
|
|
skip("to"), # RuntimeError: required rank 4 tensor to use channels_last format
|
|
xfail(
|
|
"NumpyExpMarkDirtyAutogradFunction"
|
|
), # vmap: inplace into a regular tensor
|
|
# ----------------------------------------------------------------------
|
|
# ---------------------------- BUGS ------------------------------------
|
|
# The following are bugs that we should fix
|
|
xfail("masked.mean"), # silent incorrectness (nan difference)
|
|
xfail("as_strided", "partial_views"), # Tensor-likes are not close!
|
|
xfail(
|
|
"nn.functional.soft_margin_loss", ""
|
|
), # soft_margin_loss_backward does not support forward-ad
|
|
xfail("tensor_split"), # data_ptr composite compliance
|
|
xfail("quantile"), # at::equal batching rule (cpu), also, in-place vmap (cuda)
|
|
skip("as_strided"), # Test runner cannot handle this
|
|
# requires special handling, and does not yet have a batching rule. Feel free to file a github issue!
|
|
xfail("as_strided_scatter"),
|
|
xfail(
|
|
"nn.functional.gaussian_nll_loss"
|
|
), # .item or data-dependent control flow
|
|
xfail("scatter"), # forward-mode AD does not support at::scatter
|
|
xfail(
|
|
"nanquantile"
|
|
), # at::equal batching rule (cpu), also, in-place vmap (cuda)
|
|
xfail("view_as_complex"), # Tensor must have a last dimension with stride 1
|
|
skip("pca_lowrank", ""), # randomness
|
|
skip("svd_lowrank", ""), # randomness
|
|
xfail("double"), # required rank 4 tensor to use channels_last format
|
|
xfail("cdouble"), # required rank 4 tensor to use channels_last format
|
|
# potential silent incorrectness
|
|
skip(
|
|
"nn.functional.max_unpool1d"
|
|
), # Flaky, seems to sometimes hit max_unpool2d
|
|
skip("nn.functional.max_unpool2d"), # fails everywhere except on mac
|
|
skip("nn.functional.max_unpool3d"), # fails everywhere except on mac
|
|
# erroring because running_mean and running_var aren't differentiable
|
|
xfail("nn.functional.batch_norm"),
|
|
xfail("nn.functional.batch_norm", "without_cudnn"),
|
|
xfail("native_batch_norm"),
|
|
xfail("_native_batch_norm_legit"),
|
|
# TODO: implement batching rule
|
|
xfail("_batch_norm_with_update"),
|
|
# ----------------------------------------------------------------------
|
|
}
|
|
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_vmapjvpall",
|
|
(
|
|
tol1(
|
|
"nn.functional.conv_transpose3d",
|
|
{torch.float32: tol(atol=2e-04, rtol=9e-3)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"linalg.householder_product",
|
|
{torch.float32: tol(atol=2e-04, rtol=9e-3)},
|
|
),
|
|
),
|
|
)
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmapjvpall",
|
|
vmapjvpall_fail.union(
|
|
{
|
|
xfail("as_strided_copy"),
|
|
decorate(
|
|
"linalg.det",
|
|
"singular",
|
|
decorator=expectedFailureIf(IS_MACOS and IS_X86),
|
|
),
|
|
}
|
|
),
|
|
)
|
|
# This is technically a superset of test_vmapjvp. We should either delete test_vmapjvp
|
|
# or figure out if we can split vmapjvpall. It's useful to keep test_vmapjvp intact
|
|
# because that corresponds to "batched forward-mode AD" testing in PyTorch core
|
|
def test_vmapjvpall(self, device, dtype, op):
|
|
if is_inplace(op, op.get_op()):
|
|
# TODO: test in-place
|
|
self.skipTest("Skipped! NYI: inplace-testing not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
|
|
|
if not op.supports_forward_ad:
|
|
self.skipTest("Skipped! Forward AD not supported.")
|
|
return
|
|
|
|
for sample in samples:
|
|
arg_values = [sample.input] + list(sample.args)
|
|
kwarg_values = sample.kwargs
|
|
args = tuple(arg_values) + tuple(kwarg_values)
|
|
fn, args = get_jvp_variant_primals_tangents(op, sample)
|
|
is_batch_norm_and_training = is_batch_norm_training(op.name, kwarg_values)
|
|
generator = get_fallback_and_vmap_exhaustive(
|
|
fn, args, {}, is_batch_norm_and_training=is_batch_norm_and_training
|
|
)
|
|
for loop_out, batched_out in generator:
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmapjvpall_has_batch_rule",
|
|
vmapjvpall_fail.union(
|
|
{
|
|
skip(
|
|
"to"
|
|
), # RuntimeError: required rank 4 tensor to use channels_last format
|
|
xfail(
|
|
"cdouble"
|
|
), # RuntimeError: required rank 4 tensor to use channels_last format
|
|
xfail("cumprod"),
|
|
xfail("masked_fill"),
|
|
xfail("fill"),
|
|
skip("masked.mean"), # ???
|
|
xfail("masked_scatter"),
|
|
xfail("put"),
|
|
xfail("take"),
|
|
xfail("nn.functional.feature_alpha_dropout", "without_train"),
|
|
xfail("nn.functional.dropout2d", ""),
|
|
xfail("pca_lowrank", ""),
|
|
xfail("svd_lowrank", ""),
|
|
xfail("nn.functional.feature_alpha_dropout", "with_train"),
|
|
xfail("special.log_ndtr", ""),
|
|
xfail("fft.ihfft2"), # conj_physical fallback
|
|
xfail("fft.ihfftn"), # conj_physical fallback
|
|
xfail("nn.functional.max_unpool3d", "grad"),
|
|
xfail("nn.functional.max_unpool2d", "grad"),
|
|
xfail("nn.functional.soft_margin_loss", ""),
|
|
xfail("nn.functional.max_unpool1d", "grad"),
|
|
xfail("nn.functional.embedding", ""),
|
|
xfail(
|
|
"scatter_reduce", "sum"
|
|
), # aten::scatter_reduce.two hit the vmap fallback
|
|
xfail(
|
|
"scatter_reduce", "mean"
|
|
), # aten::scatter_reduce.two hit the vmap fallback
|
|
xfail(
|
|
"scatter_reduce", "amin"
|
|
), # aten::scatter_reduce.two hit the vmap fallback
|
|
xfail(
|
|
"scatter_reduce", "amax"
|
|
), # aten::scatter_reduce.two hit the vmap fallback
|
|
xfail("nn.functional.glu"),
|
|
xfail("nn.functional.bilinear"), # trilinear doesn't have batching rule
|
|
xfail("linalg.lu", ""),
|
|
xfail("nn.functional.dropout3d", ""),
|
|
xfail("as_strided_scatter", ""),
|
|
xfail("masked.cumprod", ""),
|
|
xfail("renorm"), # hit vmap fallback, which is disabled
|
|
xfail("t_copy"),
|
|
xfail("unsqueeze_copy"),
|
|
}
|
|
),
|
|
)
|
|
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
|
def test_vmapjvpall_has_batch_rule(self, device, dtype, op):
|
|
if is_inplace(op, op.get_op()):
|
|
# TODO: test in-place
|
|
self.skipTest("Skipped! NYI: inplace-testing not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=False)
|
|
|
|
if not op.supports_forward_ad:
|
|
self.skipTest("Skipped! Forward AD not supported.")
|
|
return
|
|
|
|
def test():
|
|
for sample in samples:
|
|
arg_values = [sample.input] + list(sample.args)
|
|
kwarg_values = sample.kwargs
|
|
args = tuple(arg_values) + tuple(kwarg_values)
|
|
fn, args = get_jvp_variant_primals_tangents(op, sample)
|
|
is_batch_norm_and_training = is_batch_norm_training(
|
|
op.name, kwarg_values
|
|
)
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
|
|
fn,
|
|
args,
|
|
{},
|
|
is_batch_norm_and_training=is_batch_norm_and_training,
|
|
compute_loop_out=False,
|
|
):
|
|
pass
|
|
|
|
check_vmap_fallback(self, test, op, dry_run=False)
|
|
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmapvjp_has_batch_rule",
|
|
vmapvjp_fail.union(
|
|
{
|
|
skip(
|
|
"to"
|
|
), # RuntimeError: required rank 4 tensor to use channels_last format
|
|
xfail("view_as_complex"),
|
|
xfail("cummax"),
|
|
xfail("cummin"),
|
|
xfail("fill"),
|
|
xfail(
|
|
"narrow"
|
|
), # Batching rule not implemented for `narrow.Tensor` (and view op)
|
|
xfail("special.log_ndtr"),
|
|
xfail("linalg.householder_product"),
|
|
xfail("masked_fill"),
|
|
xfail("masked_scatter"),
|
|
xfail("masked_select"),
|
|
xfail("nanquantile"),
|
|
xfail("ormqr"),
|
|
xfail("put"),
|
|
xfail(
|
|
"scatter_reduce", "sum"
|
|
), # aten::scatter_reduce.two hit the vmap fallback
|
|
xfail(
|
|
"scatter_reduce", "mean"
|
|
), # aten::scatter_reduce.two hit the vmap fallback
|
|
xfail(
|
|
"scatter_reduce", "amin"
|
|
), # aten::scatter_reduce.two hit the vmap fallback
|
|
xfail(
|
|
"scatter_reduce", "amax"
|
|
), # aten::scatter_reduce.two hit the vmap fallback
|
|
xfail("quantile"),
|
|
xfail("renorm"),
|
|
xfail("take"),
|
|
xfail("tensor_split"),
|
|
xfail("to_sparse"),
|
|
xfail("unfold"),
|
|
xfail("unfold_copy"),
|
|
xfail("nn.functional.dropout"),
|
|
xfail("fft.ihfft2"),
|
|
xfail("fft.ihfftn"),
|
|
xfail("nn.functional.gaussian_nll_loss"),
|
|
xfail("nn.functional.bilinear"),
|
|
xfail("nn.functional.fractional_max_pool3d"),
|
|
xfail("nn.functional.ctc_loss"),
|
|
xfail("nn.functional.rrelu"),
|
|
xfail("nn.functional.embedding_bag"),
|
|
xfail("nn.functional.fractional_max_pool2d"),
|
|
xfail("nn.functional.feature_alpha_dropout", "with_train"),
|
|
xfail("pca_lowrank", ""),
|
|
xfail("nn.functional.dropout2d", ""),
|
|
xfail("nn.functional.feature_alpha_dropout", "without_train"),
|
|
xfail("svd_lowrank", ""),
|
|
xfail("nn.functional.max_unpool2d", ""),
|
|
xfail("nn.functional.multi_margin_loss", ""),
|
|
xfail("nn.functional.multilabel_margin_loss", ""),
|
|
xfail("nn.functional.pdist", ""),
|
|
xfail("scatter_reduce", "prod"),
|
|
xfail("nn.functional.max_unpool1d", ""),
|
|
xfail("nn.functional.max_unpool3d", ""),
|
|
xfail("nn.functional.max_unpool3d", "grad"),
|
|
xfail("nn.functional.soft_margin_loss", ""),
|
|
xfail("nn.functional.max_unpool1d", "grad"),
|
|
xfail("nn.functional.max_unpool2d", "grad"),
|
|
xfail("linalg.lu", ""),
|
|
xfail("cdouble", ""),
|
|
xfail("cfloat", ""),
|
|
xfail("chalf", ""),
|
|
xfail(
|
|
"index_reduce", "prod"
|
|
), # aten::index_reduce hit the vmap fallback which is currently disabled
|
|
xfail(
|
|
"index_reduce", "mean"
|
|
), # aten::index_reduce hit the vmap fallback which is currently disabled
|
|
xfail(
|
|
"index_reduce", "amax"
|
|
), # aten::index_reduce hit the vmap fallback which is currently disabled
|
|
xfail(
|
|
"index_reduce", "amin"
|
|
), # aten::index_reduce hit the vmap fallback which is currently disabled
|
|
xfail("nn.functional.dropout3d", ""),
|
|
xfail("as_strided_scatter", ""),
|
|
xfail("_segment_reduce", "offsets"),
|
|
xfail("_segment_reduce", "lengths"),
|
|
xfail("sparse.sampled_addmm", ""),
|
|
xfail("sparse.mm", "reduce"),
|
|
xfail("native_batch_norm"),
|
|
xfail("_native_batch_norm_legit"),
|
|
# TODO: implement batching rule
|
|
xfail("_batch_norm_with_update"),
|
|
xfail("native_dropout_backward"),
|
|
xfail(
|
|
"index_fill"
|
|
), # aten::_unique hit the vmap fallback which is currently disabled
|
|
xfail("t_copy"),
|
|
xfail("unsqueeze_copy"),
|
|
}
|
|
),
|
|
)
|
|
def test_vmapvjp_has_batch_rule(self, device, dtype, op):
|
|
if not op.supports_autograd:
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
# TODO: test in-place
|
|
if is_inplace(op, op.get_op()):
|
|
self.skipTest("Skipped! NYI: inplace-testing not supported.")
|
|
return
|
|
|
|
def test():
|
|
for sample in samples:
|
|
cotangents = get_sample_cotangents(op, sample)
|
|
fn, args = get_vjp_fn_and_args_with_cotangents(op, sample, cotangents)
|
|
is_batch_norm_and_training = is_batch_norm_training(
|
|
op.name, sample.kwargs
|
|
)
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
|
|
fn,
|
|
args,
|
|
{},
|
|
is_batch_norm_and_training=is_batch_norm_and_training,
|
|
compute_loop_out=False,
|
|
):
|
|
pass
|
|
for a_op in op.aliases:
|
|
fn, args = get_vjp_fn_and_args_with_cotangents(
|
|
a_op, sample, cotangents
|
|
)
|
|
for loop_out, batched_out in get_fallback_and_vmap_exhaustive(
|
|
fn,
|
|
args,
|
|
{},
|
|
is_batch_norm_and_training=is_batch_norm_and_training,
|
|
compute_loop_out=False,
|
|
):
|
|
pass
|
|
|
|
check_vmap_fallback(self, test, op, dry_run=False)
|
|
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vjpvmap",
|
|
vjp_fail.union(
|
|
{
|
|
skip("bernoulli", ""), # vjpvmap testing can't handle randomness
|
|
skip("normal", ""), # vjpvmap testing can't handle randomness
|
|
skip(
|
|
"normal", "number_mean"
|
|
), # vjpvmap testing can't handle randomness
|
|
skip("nn.functional.rrelu"), # randomness
|
|
skip("nn.functional.feature_alpha_dropout", "with_train"), # randomness
|
|
skip(
|
|
"nn.functional.feature_alpha_dropout", "without_train"
|
|
), # randomness
|
|
skip("nn.functional.scaled_dot_product_attention"),
|
|
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
|
|
skip("nn.functional.multi_head_attention_forward"), # randomness
|
|
skip("nn.functional.alpha_dropout"), # randomness
|
|
skip(
|
|
"to"
|
|
), # RuntimeError: required rank 4 tensor to use channels_last format
|
|
skip("to_sparse", ""), # non-dense output
|
|
skip("ormqr", ""), # takes too long
|
|
xfail(
|
|
"NumpyCubeNotComposableAutogradFunction"
|
|
), # Not composable autograd.Function
|
|
# fallback path doesn't work
|
|
# All of the following are bugs and need to be fixed
|
|
xfail("__getitem__", ""),
|
|
xfail("index_put", ""),
|
|
xfail("view_as_complex"),
|
|
xfail("nn.functional.gaussian_nll_loss"),
|
|
xfail("masked_select"),
|
|
xfail(
|
|
"narrow"
|
|
), # Batching rule not implemented for `narrow.Tensor` (and view op)
|
|
skip(
|
|
"nn.functional.fractional_max_pool3d"
|
|
), # generator works on cpu, fails on cuda
|
|
skip(
|
|
"nn.functional.fractional_max_pool2d"
|
|
), # generator works on cpu, fails on cuda
|
|
xfail("column_stack", ""),
|
|
xfail("nn.functional.dropout2d", ""),
|
|
xfail("svd_lowrank", ""),
|
|
xfail("pca_lowrank", ""),
|
|
xfail("clamp"),
|
|
# something weird happening with channels_last
|
|
xfail("bfloat16"),
|
|
xfail("double"),
|
|
xfail("float"),
|
|
xfail("half"),
|
|
xfail("cdouble"),
|
|
xfail("cfloat"),
|
|
xfail("nn.functional.dropout3d", ""),
|
|
xfail("as_strided_scatter", ""),
|
|
xfail("sparse.sampled_addmm", ""),
|
|
xfail("sparse.mm", "reduce"),
|
|
xfail("native_batch_norm"),
|
|
xfail("_native_batch_norm_legit"),
|
|
# TODO: implement batching rule
|
|
xfail("_batch_norm_with_update"),
|
|
xfail("as_strided", "partial_views"),
|
|
}
|
|
),
|
|
)
|
|
def test_vjpvmap(self, device, dtype, op):
|
|
# NB: there is no vjpvmap_has_batch_rule test because that is almost
|
|
# certainly redundant with the vmap_has_batch_rule test in test_vmap.py
|
|
|
|
# one-off skip
|
|
if op.name == "nn.functional.dropout":
|
|
self.skipTest("Skipped!")
|
|
|
|
if not op.supports_autograd:
|
|
# If the op doesn't support autograd, vmap(op) won't either
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
|
|
# TODO: test in-place
|
|
if is_inplace(op, op.get_op()):
|
|
self.skipTest("Skipped! NYI: inplace-testing not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
batch_norm_fns = (
|
|
"nn.functional.batch_norm",
|
|
"nn.functional.instance_norm",
|
|
) # instance norm calls batch norm
|
|
is_batch_norm = op.name in batch_norm_fns
|
|
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
|
|
is_batch_norm_and_training = is_batch_norm and is_batch_norm_training(
|
|
op.name, kwargs
|
|
)
|
|
generator = generate_vmap_inputs(
|
|
args, kwargs, is_batch_norm_and_training=is_batch_norm_and_training
|
|
)
|
|
|
|
for batched_args, in_dims, kwargs in generator:
|
|
vmapped_op = vmap(op, in_dims)
|
|
fn, primals = normalize_op_input_output2(
|
|
vmapped_op, batched_args, kwargs, sample.output_process_fn_grad
|
|
)
|
|
result = fn(*primals)
|
|
cotangents = tree_map(lambda x: torch.randn_like(x), result)
|
|
|
|
_, vjp_fn = vjp(fn, *primals)
|
|
result_vjps = vjp_fn(cotangents)
|
|
|
|
_, vjp_fn = ref_vjp(fn, *primals)
|
|
expected_vjps = vjp_fn(cotangents)
|
|
|
|
self.assertEqual(result_vjps, expected_vjps)
|
|
|
|
def _compare_jacobians_of_vjp(
|
|
self, fn, cotangents_and_primals, argnums=None, atol_rtol=None
|
|
):
|
|
if argnums is None:
|
|
argnums = tuple(range(len(cotangents_and_primals)))
|
|
|
|
def get_vjp(cotangents, *primals):
|
|
_, vjp_fn = vjp(fn, *primals)
|
|
return vjp_fn(cotangents)
|
|
|
|
jacobian_jvp = jacfwd(get_vjp, argnums)(*cotangents_and_primals)
|
|
jacobian_vjp = jacrev(get_vjp, argnums)(*cotangents_and_primals)
|
|
|
|
# For dtype changing operations, the jacobians have different dtype.
|
|
jacobian_jvp = tree_map(lambda x: x.to(torch.float), jacobian_jvp)
|
|
jacobian_vjp = tree_map(lambda x: x.to(torch.float), jacobian_vjp)
|
|
|
|
if atol_rtol is not None:
|
|
(atol, rtol) = atol_rtol
|
|
self.assertEqual(jacobian_jvp, jacobian_vjp, atol=atol, rtol=rtol)
|
|
else:
|
|
self.assertEqual(jacobian_jvp, jacobian_vjp)
|
|
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_jvpvjp",
|
|
vjp_fail.union(
|
|
{
|
|
xfail("to_sparse", ""), # NYI
|
|
# RuntimeError: Trying to set a forward gradient that has a different size than that of the original Tensor,
|
|
# this is not supported. Tensor is of size [5, 2, 3] while the given forward gradient is of size [1, 2, 3].
|
|
xfail("normal", ""),
|
|
xfail("cdist", ""), # NYI: forward-AD for _cdist_forward
|
|
xfail("cholesky", ""), # NYI: forward-AD for cholesky
|
|
xfail(
|
|
"nn.functional.embedding_bag", ""
|
|
), # NYI: forward-AD for _embedding_bag
|
|
xfail(
|
|
"nn.functional.grid_sample", ""
|
|
), # NYI: forward AD for grid_sampler_2d
|
|
xfail("grid_sampler_2d", ""), # NYI: forward AD for grid_sampler_2d
|
|
xfail(
|
|
"nn.functional.hardsigmoid", ""
|
|
), # NYI: forward AD for hardsigmoid_backward
|
|
xfail(
|
|
"nn.functional.huber_loss", ""
|
|
), # NYI: forward AD for huber_loss_backward
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # not composable
|
|
xfail("ormqr", ""), # NYI: forward AD for ormqr
|
|
xfail(
|
|
"nn.functional.multilabel_margin_loss", ""
|
|
), # NYI: multilabel_margin_loss_forward
|
|
xfail(
|
|
"nn.functional.soft_margin_loss", ""
|
|
), # NYI: forward-AD for soft_margin_loss_backward
|
|
xfail("nn.functional.ctc_loss", ""), # NYI: forward-AD for _ctc_loss
|
|
xfail("nn.functional.pdist", ""), # NYI: forward-AD with _pdist_forward
|
|
skip("nn.functional.scaled_dot_product_attention"),
|
|
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
|
|
xfail(
|
|
"nn.functional.multi_margin_loss", ""
|
|
), # NYI: forward AD with multi_margin_loss
|
|
skip(
|
|
"linalg.householder_product", "", device_type="cuda"
|
|
), # flaky, I'm not sure why
|
|
xfail("sparse.sampled_addmm", ""), # Sparse tensors have no strides
|
|
xfail(
|
|
"_segment_reduce", "offsets"
|
|
), # NYI: forward-AD for _segment_reduce
|
|
xfail("sparse.mm", "reduce"), # Sparse tensors have no strides
|
|
xfail("index_reduce", "prod"), # NYI: forward-AD for index_reduce
|
|
xfail("index_reduce", "mean"), # NYI: forward-AD for index_reduce
|
|
xfail("index_reduce", "amax"), # NYI: forward-AD for index_reduce
|
|
xfail("index_reduce", "amin"), # NYI: forward-AD for index_reduce
|
|
xfail(
|
|
"_segment_reduce", "lengths"
|
|
), # NYI: forward-AD for _segment_reduce
|
|
xfail("native_dropout_backward"), # NYI
|
|
}
|
|
),
|
|
)
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_jvpvjp",
|
|
(
|
|
tol1("masked.prod", {torch.float32: tol(atol=1e-04, rtol=1.3e-05)}),
|
|
tol1("masked.cumprod", {torch.float32: tol(atol=1e-04, rtol=5e-04)}),
|
|
tol1(
|
|
"cumprod",
|
|
{torch.float32: tol(atol=1e-03, rtol=5e-04)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"linalg.det",
|
|
{torch.float32: tol(atol=3e-05, rtol=5e-06)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"linalg.vander",
|
|
{torch.float32: tol(atol=1e-04, rtol=1.3e-05)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"nn.functional.group_norm", {torch.float32: tol(atol=1e-03, rtol=1e-03)}
|
|
),
|
|
tol2(
|
|
"linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-03, rtol=5e-03)}
|
|
),
|
|
),
|
|
)
|
|
def test_jvpvjp(self, device, dtype, op):
|
|
if not op.supports_autograd:
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
# TODO: test in-place
|
|
if is_inplace(op, op.get_op()):
|
|
self.skipTest("Skipped! NYI: inplace-testing not supported.")
|
|
return
|
|
|
|
for sample in samples:
|
|
fn, primals = normalize_op_input_output(op, sample)
|
|
result = fn(*primals)
|
|
cotangents = tree_map(lambda x: torch.randn_like(x), result)
|
|
|
|
primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
|
|
cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)
|
|
|
|
def push_vjp(primals, cotangents):
|
|
_, vjp_fn = vjp(fn, *primals)
|
|
return vjp_fn(cotangents)
|
|
|
|
result = jvp(
|
|
push_vjp, (primals, cotangents), (primals_tangents, cotangents_tangents)
|
|
)
|
|
self.assertEqual(len(result), 2)
|
|
|
|
def tree_map2(fn, first, second):
|
|
flat_first, spec_first = tree_flatten(first)
|
|
flat_second, spec_second = tree_flatten(second)
|
|
assert spec_first == spec_second
|
|
flat_result = [fn(f, s) for f, s in zip(flat_first, flat_second)]
|
|
return tree_unflatten(flat_result, spec_first)
|
|
|
|
def reference(primals, cotangents, primals_tangents, cotangents_tangents):
|
|
with fwAD.dual_level():
|
|
primal_duals = tree_map2(fwAD.make_dual, primals, primals_tangents)
|
|
_, vjp_fn = ref_vjp(fn, *primal_duals)
|
|
|
|
cotangent_duals = tree_map2(
|
|
fwAD.make_dual, cotangents, cotangents_tangents
|
|
)
|
|
result = vjp_fn(cotangent_duals)
|
|
|
|
flat_result, spec = tree_flatten(result)
|
|
primals_out, tangents_out = zip(
|
|
*[fwAD.unpack_dual(r) for r in flat_result]
|
|
)
|
|
tangents_out = [
|
|
t if t is not None else torch.zeros_like(p)
|
|
for p, t in zip(primals_out, tangents_out)
|
|
]
|
|
expected = (
|
|
tree_unflatten(primals_out, spec),
|
|
tree_unflatten(tangents_out, spec),
|
|
)
|
|
return expected
|
|
|
|
expected = reference(
|
|
primals, cotangents, primals_tangents, cotangents_tangents
|
|
)
|
|
self.assertEqual(result, expected)
|
|
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmapjvpvjp",
|
|
vjp_fail.union(
|
|
{
|
|
# Following operators take too long, hence skipped
|
|
skip("atleast_1d"),
|
|
skip("atleast_2d"),
|
|
skip("atleast_3d"),
|
|
skip("meshgrid", "list_of_tensors"),
|
|
skip("meshgrid", "variadic_tensors"),
|
|
skip("broadcast_tensors"),
|
|
skip("linalg.lstsq"),
|
|
skip("nn.functional.bilinear"),
|
|
skip("native_layer_norm"),
|
|
skip("ormqr"),
|
|
# Not actually a problem
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # not composable
|
|
xfail(
|
|
"NumpyExpMarkDirtyAutogradFunction"
|
|
), # vmap: inplace into a regular tensor
|
|
# Potential bugs/errors
|
|
xfail("as_strided"), # AssertionError: Tensor-likes are not close!
|
|
xfail(
|
|
"as_strided", "partial_views"
|
|
), # AssertionError: Tensor-likes are not close!
|
|
xfail("as_strided_copy"), # AssertionError: Tensor-likes are not close!
|
|
xfail(
|
|
"as_strided_scatter"
|
|
), # AssertionError: Tensor-likes are not close!
|
|
xfail("bernoulli"), # calls random op
|
|
xfail("bfloat16"), # required rank 4 tensor to use channels_last format
|
|
xfail("cdist"), # Forward AD not implemented and no decomposition
|
|
xfail("cdouble"), # required rank 4 tensor to use channels_last format
|
|
xfail("cfloat"), # required rank 4 tensor to use channels_last format
|
|
xfail("chalf"), # required rank 4 tensor to use channels_last format
|
|
xfail("cholesky"), # Forward AD not implemented and no decomposition
|
|
xfail("ormqr"), # Forward AD not implemented and no decomposition
|
|
xfail("double"), # required rank 4 tensor to use channels_last format
|
|
xfail("float"), # required rank 4 tensor to use channels_last format
|
|
xfail("half"), # required rank 4 tensor to use channels_last format
|
|
xfail("index_reduce", "prod"), # NYI: forward AD for index_reduce
|
|
xfail("index_reduce", "mean"), # NYI: forward AD for index_reduce
|
|
xfail("index_reduce", "amax"), # NYI: forward AD for index_reduce
|
|
xfail("index_reduce", "amin"), # NYI: forward AD for index_reduce
|
|
xfail(
|
|
"mvlgamma", "mvlgamma_p_1"
|
|
), # vmap: inplace into a regular tensor
|
|
xfail(
|
|
"mvlgamma", "mvlgamma_p_3"
|
|
), # vmap: inplace into a regular tensor
|
|
xfail(
|
|
"mvlgamma", "mvlgamma_p_5"
|
|
), # vmap: inplace into a regular tensor
|
|
xfail("nanquantile"), # Batching rule not implemented for aten::equal
|
|
# RuntimeError: Batch norm got a batched tensor as input while the
|
|
# running_mean or running_var, which will be updated in place,
|
|
# were not batched.
|
|
xfail("nn.functional.batch_norm"),
|
|
xfail("nn.functional.batch_norm", "without_cudnn"),
|
|
xfail(
|
|
"nn.functional.ctc_loss"
|
|
), # ForwardAD not implemented and no decomposition
|
|
xfail("nn.functional.dropout2d"), # calls random op
|
|
xfail("nn.functional.dropout3d"), # calls random op
|
|
xfail("nn.functional.dropout"), # calls random op
|
|
xfail("nn.functional.scaled_dot_product_attention"), # randomness
|
|
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
|
|
xfail("nn.functional.multi_head_attention_forward"), # randomness
|
|
xfail(
|
|
"nn.functional.embedding_bag"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail("nn.functional.alpha_dropout"), # calls randomn op
|
|
xfail(
|
|
"nn.functional.feature_alpha_dropout", "with_train"
|
|
), # calls random op
|
|
xfail("nn.functional.fractional_max_pool2d"), # calls random op
|
|
xfail("nn.functional.fractional_max_pool3d"), # calls random op
|
|
xfail("nn.functional.gaussian_nll_loss"), # data depenedant flow
|
|
xfail(
|
|
"nn.functional.grid_sample"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"grid_sampler_2d"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"nn.functional.hardsigmoid"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"nn.functional.hinge_embedding_loss"
|
|
), # vmap: inplace into a regular tensor
|
|
xfail(
|
|
"nn.functional.huber_loss"
|
|
), # Forward AD not implemented and no decomposition
|
|
# RuntimeError: Batch norm got a batched tensor as input while the
|
|
# running_mean or running_var, which will be updated in place,
|
|
# were not batched.
|
|
xfail("nn.functional.instance_norm"),
|
|
# NYI: Tensor.clone(memory_format) inside vmap is only supported with
|
|
# memory_format torch.preserve_format or torch.contiguous_format (got ChannelsLast)
|
|
xfail("nn.functional.max_unpool2d"),
|
|
xfail("nn.functional.max_unpool2d", "grad"),
|
|
xfail(
|
|
"nn.functional.multi_margin_loss"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"nn.functional.multilabel_margin_loss"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"nn.functional.pdist"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"nn.functional.rrelu"
|
|
), # vmap: we do not yet support aten::rrelu_with_noise.
|
|
xfail(
|
|
"nn.functional.soft_margin_loss"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail("normal"), # calls random op
|
|
xfail("normal", "number_mean"), # calls random op
|
|
xfail("pca_lowrank"), # calls random op
|
|
xfail("quantile"), # Batching rule not implemented for aten::equal
|
|
xfail(
|
|
"scatter_reduce", "prod"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"_segment_reduce", "lengths"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"_segment_reduce", "offsets"
|
|
), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"sparse.sampled_addmm"
|
|
), # RuntimeError: Sparse CSR tensors do not have strides
|
|
xfail(
|
|
"sparse.mm", "reduce"
|
|
), # RuntimeError: Sparse CSR tensors do not have strides
|
|
xfail("svd_lowrank"), # calls random op
|
|
xfail(
|
|
"to"
|
|
), # RuntimeError: required rank 4 tensor to use channels_last format
|
|
xfail("to_sparse"), # Forward AD not implemented and no decomposition
|
|
xfail(
|
|
"view_as_complex"
|
|
), # RuntimeError: Tensor must have a last dimension with stride 1
|
|
# RuntimeError: Batch norm got a batched tensor as
|
|
# input while the running_mean or running_var, which will be updated in
|
|
# place, were not batched.
|
|
xfail("native_batch_norm"),
|
|
xfail("_native_batch_norm_legit"),
|
|
# TODO: implement batching rule
|
|
xfail("_batch_norm_with_update"),
|
|
xfail("native_dropout_backward"),
|
|
}
|
|
),
|
|
)
|
|
@ops(op_db + additional_op_db + autograd_function_db, allowed_dtypes=(torch.float,))
|
|
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_vmapjvpvjp",
|
|
(
|
|
tol1("linalg.svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
|
|
tol1(
|
|
"linalg.householder_product",
|
|
{torch.float32: tol(atol=5e-03, rtol=5e-03)},
|
|
),
|
|
tol1("linalg.multi_dot", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
|
|
tol2(
|
|
"linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-04, rtol=5e-04)}
|
|
),
|
|
tol1(
|
|
"nn.functional.conv_transpose2d",
|
|
{torch.float32: tol(atol=5e-04, rtol=5e-04)},
|
|
),
|
|
tol1("svd", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
|
|
tol1("matrix_exp", {torch.float32: tol(atol=5e-04, rtol=5e-04)}),
|
|
),
|
|
)
|
|
def test_vmapjvpvjp(self, device, dtype, op):
|
|
# Since we test `jvpvjp` separately,
|
|
# in this we just check that vmap of `jvpvjp`
|
|
# is correct.
|
|
if not op.supports_autograd:
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
# TODO: test in-place
|
|
if is_inplace(op, op.get_op()):
|
|
self.skipTest("Skipped! NYI: inplace-testing not supported.")
|
|
return
|
|
|
|
for sample in samples:
|
|
fn, primals = normalize_op_input_output(op, sample)
|
|
result = fn(*primals)
|
|
cotangents = tree_map(lambda x: torch.randn_like(x), result)
|
|
|
|
primals_tangents = tree_map(lambda x: torch.randn_like(x), primals)
|
|
cotangents_tangents = tree_map(lambda x: torch.randn_like(x), cotangents)
|
|
|
|
def push_vjp(primals, cotangents):
|
|
_, vjp_fn = vjp(fn, *primals)
|
|
return vjp_fn(cotangents)
|
|
|
|
args, spec = tree_flatten(
|
|
((primals, cotangents), (primals_tangents, cotangents_tangents))
|
|
)
|
|
|
|
def jvp_of_vjp(*args):
|
|
(primals, tangents) = tree_unflatten(args, spec)
|
|
primals_out, tangents_out = jvp(push_vjp, primals, tangents)
|
|
|
|
flat_primals_out = pytree.tree_leaves(primals_out)
|
|
flat_tangents_out = pytree.tree_leaves(tangents_out)
|
|
return tuple(flat_primals_out + flat_tangents_out)
|
|
|
|
is_batch_norm_and_training = is_batch_norm_training(op, sample.kwargs)
|
|
generator = get_fallback_and_vmap_exhaustive(
|
|
jvp_of_vjp,
|
|
args,
|
|
{},
|
|
is_batch_norm_and_training=is_batch_norm_and_training,
|
|
)
|
|
for loop_out, batched_out in generator:
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
def _make_extremal_inputs(self, shape, device):
|
|
if shape is None:
|
|
return (None,)
|
|
return (
|
|
torch.full(shape, -1000.0, device=device),
|
|
torch.zeros(shape, device=device),
|
|
torch.full(shape, 1000.0, device=device),
|
|
)
|
|
|
|
def _arg_and_kwarg_options(self, args_options, kwargs_options):
|
|
return itertools.product(*args_options, kwargs_options)
|
|
|
|
def test_extremal_numerics_nll_loss(self, device):
|
|
N, C = 3, 4
|
|
d1, d2, d3 = 5, 6, 7
|
|
shapes = (
|
|
((N, C), (N,), (C,)),
|
|
((N, C), (N,), None),
|
|
((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
|
|
((N, C, d1, d2, d3), (N, d1, d2, d3), None),
|
|
)
|
|
kwargs_options = (
|
|
{"ignore_index": 0, "reduction": "mean"},
|
|
{"reduction": "sum"},
|
|
{"reduction": "none"},
|
|
{},
|
|
)
|
|
for input_shape, target_shape, weight_shape in shapes:
|
|
input_options = self._make_extremal_inputs(input_shape, device)
|
|
for input, kwargs in self._arg_and_kwarg_options(
|
|
(input_options,), kwargs_options
|
|
):
|
|
if weight_shape is None:
|
|
weight = None
|
|
else:
|
|
weight = torch.randn(weight_shape, device=device)
|
|
target = torch.randint(0, C, target_shape, device=device)
|
|
target[
|
|
0
|
|
] = 1 # since we're ignoring index 0, at least one element must be non-zero
|
|
|
|
fn = functools.partial(
|
|
torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs
|
|
)
|
|
result = fn(input)
|
|
cotangents = torch.randn_like(result, device=device)
|
|
self._compare_jacobians_of_vjp(fn, (cotangents, input))
|
|
|
|
def test_extremal_numerics_l1_loss(self, device):
|
|
N, C, H, W = 3, 4, 5, 6
|
|
shapes = ((N, C), (N, C, H), (N, C, H, W))
|
|
kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {})
|
|
for shape in shapes:
|
|
input_options = self._make_extremal_inputs(shape, device)
|
|
target_options = self._make_extremal_inputs(shape, device)
|
|
for input, target, kwargs in self._arg_and_kwarg_options(
|
|
(input_options, target_options), kwargs_options
|
|
):
|
|
result = torch.nn.functional.l1_loss(input, target)
|
|
cotangents = torch.randn_like(result, device=device)
|
|
self._compare_jacobians_of_vjp(
|
|
torch.nn.functional.l1_loss, (cotangents, input, target)
|
|
)
|
|
|
|
def test_extremal_numerics_mse_loss(self, device):
|
|
N, C, H, W = 3, 4, 5, 6
|
|
shapes = ((N, C), (N, C, H), (N, C, H, W))
|
|
kwargs_options = ({"reduction": "sum"}, {"reduction": "none"}, {})
|
|
for shape in shapes:
|
|
input_options = self._make_extremal_inputs(shape, device)
|
|
target_options = self._make_extremal_inputs(shape, device)
|
|
for input, target, kwargs in self._arg_and_kwarg_options(
|
|
(input_options, target_options), kwargs_options
|
|
):
|
|
result = torch.nn.functional.mse_loss(input, target)
|
|
cotangents = torch.randn_like(result, device=device)
|
|
self._compare_jacobians_of_vjp(
|
|
torch.nn.functional.mse_loss, (cotangents, input, target)
|
|
)
|
|
|
|
def test_extremal_numerics_softmax(self, device):
|
|
N, C, H, W = 3, 4, 5, 6
|
|
shapes = ((N, C), (N, C, H), (N, C, H, W))
|
|
kwargs_options = ({"dim": 1}, {})
|
|
for shape in shapes:
|
|
input_options = self._make_extremal_inputs(shape, device)
|
|
for input, kwargs in self._arg_and_kwarg_options(
|
|
(input_options,), kwargs_options
|
|
):
|
|
result = torch.nn.functional.softmax(input)
|
|
cotangents = torch.randn_like(result, device=device)
|
|
self._compare_jacobians_of_vjp(
|
|
torch.nn.functional.softmax, (cotangents, input)
|
|
)
|
|
|
|
def test_extremal_numerics_log_softmax(self, device):
|
|
N, C, H, W = 3, 4, 5, 6
|
|
shapes = ((N, C), (N, C, H), (N, C, H, W))
|
|
kwargs_options = ({"dim": 1}, {})
|
|
for shape in shapes:
|
|
input_options = self._make_extremal_inputs(shape, device)
|
|
for input, kwargs in self._arg_and_kwarg_options(
|
|
(input_options,), kwargs_options
|
|
):
|
|
result = torch.nn.functional.log_softmax(input)
|
|
cotangents = torch.randn_like(result, device=device)
|
|
self._compare_jacobians_of_vjp(
|
|
torch.nn.functional.log_softmax, (cotangents, input)
|
|
)
|
|
|
|
def test_extremal_numerics_cross_entropy(self, device):
|
|
N, C = 3, 4
|
|
d1, d2, d3 = 5, 6, 7
|
|
shapes = (
|
|
((N, C), (N,), (C,)),
|
|
((N, C), (N,), None),
|
|
((N, C), (N, C), (C,)),
|
|
((N, C), (N, C), None),
|
|
((C,), (), (C,)),
|
|
((C,), (), None),
|
|
((C,), (C,), (C,)),
|
|
((C,), (C,), None),
|
|
((N, C, d1, d2, d3), (N, d1, d2, d3), (C,)),
|
|
((N, C, d1, d2, d3), (N, d1, d2, d3), None),
|
|
((N, C, d1, d2, d3), (N, C, d1, d2, d3), (C,)),
|
|
((N, C, d1, d2, d3), (N, C, d1, d2, d3), None),
|
|
)
|
|
for input_shape, target_shape, weight_shape in shapes:
|
|
input_options = self._make_extremal_inputs(input_shape, device)
|
|
kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}]
|
|
if input_shape != target_shape:
|
|
kwargs_options.append({"ignore_index": 0, "reduction": "mean"})
|
|
|
|
for input, kwargs in self._arg_and_kwarg_options(
|
|
(input_options,), kwargs_options
|
|
):
|
|
if weight_shape is None:
|
|
weight = None
|
|
else:
|
|
weight = torch.randn(weight_shape, device=device)
|
|
|
|
if input_shape == target_shape:
|
|
target = torch.rand(target_shape, device=device)
|
|
elif len(target_shape) == 0:
|
|
target = torch.tensor(
|
|
1, device=device
|
|
) # must be non-zero since ignore_index may be 0
|
|
else:
|
|
target = torch.randint(0, C, target_shape, device=device)
|
|
|
|
fn = functools.partial(
|
|
torch.nn.functional.cross_entropy,
|
|
target=target,
|
|
weight=weight,
|
|
**kwargs,
|
|
)
|
|
result = fn(input)
|
|
cotangents = torch.randn_like(result, device=device)
|
|
self._compare_jacobians_of_vjp(
|
|
fn, (cotangents, input), atol_rtol=(1e-4, 1e-5)
|
|
)
|
|
|
|
def test_extremal_numerics_binary_cross_entropy(self, device):
|
|
N, C, H, W = 3, 4, 5, 6
|
|
shapes = ((N, C), (N, C, H), (N, C, H, W))
|
|
for shape in shapes:
|
|
weight_options = self._make_extremal_inputs(shape, device)
|
|
kwargs_options = [{"reduction": "sum"}, {"reduction": "none"}, {}]
|
|
|
|
for weight, kwargs in self._arg_and_kwarg_options(
|
|
(weight_options,), kwargs_options
|
|
):
|
|
input = torch.rand(shape, device=device)
|
|
target = torch.rand(shape, device=device)
|
|
fn = functools.partial(
|
|
torch.nn.functional.binary_cross_entropy,
|
|
target=target,
|
|
weight=weight,
|
|
**kwargs,
|
|
)
|
|
result = fn(input)
|
|
cotangents = torch.randn_like(result, device=device)
|
|
self._compare_jacobians_of_vjp(
|
|
fn, (cotangents, input), atol_rtol=(1e-4, 2e-5)
|
|
)
|
|
|
|
def test_extremal_numerics_layer_norm(self, device):
|
|
N, C, H, W = 3, 4, 5, 6
|
|
shapes = ((N, C), (N, C, H), (N, C, H, W))
|
|
for shape in shapes:
|
|
input_options = self._make_extremal_inputs(shape, device)
|
|
normalized_shape = shape[1:]
|
|
weight_options = self._make_extremal_inputs(normalized_shape, device)
|
|
bias_options = self._make_extremal_inputs(normalized_shape, device)
|
|
|
|
for input, bias, weight in self._arg_and_kwarg_options(
|
|
(input_options, bias_options, weight_options), ()
|
|
):
|
|
|
|
def fn(input, weight, bias):
|
|
return torch.nn.functional.layer_norm(
|
|
input, normalized_shape, weight=weight, bias=bias
|
|
)
|
|
|
|
result = fn(input, weight, bias)
|
|
cotangents = torch.randn_like(result, device=device)
|
|
self._compare_jacobians_of_vjp(fn, (cotangents, input, weight, bias))
|
|
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@ops(
|
|
op_db + additional_op_db + autograd_function_db,
|
|
allowed_dtypes=(torch.float32, torch.double),
|
|
)
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmap_autograd_grad",
|
|
{
|
|
# The size of tensor a (4) must match the size of tensor b (10) at non-singleton dimension 0
|
|
xfail("masked_select"),
|
|
xfail("nn.functional.max_unpool2d", "grad"), # contiguous call
|
|
xfail("nn.functional.max_unpool2d"), # contiguous call
|
|
xfail("to_sparse"), # dispatch key issue
|
|
xfail("torch.ops.aten._efficient_attention_forward"), # outputs ints
|
|
# https://github.com/pytorch/pytorch/issues/96560#issuecomment-2151063723
|
|
# ** minor accuracy issue for float32 on ROCm
|
|
decorate("xlogy", decorator=skipIfRocm),
|
|
# numerical inconsistencies, look like bugs
|
|
skip(
|
|
"matrix_exp", dtypes=(torch.float32,), device_type="cuda"
|
|
), # fails on linux, passes on windows
|
|
skip(
|
|
"ldexp", dtypes=(torch.float32,), device_type="cpu"
|
|
), # fails on all but mac
|
|
skip("__rmatmul__"), # flaky needs investigation
|
|
skip("matmul"), # flaky needs investigation
|
|
skip("nn.functional.conv_transpose3d"), # flaky needs investigation
|
|
skip("nn.functional.conv_transpose2d"), # flaky needs investigation
|
|
skip("nn.functional.conv_transpose1d"), # flaky needs investigation
|
|
skip(
|
|
"nn.functional.layer_norm", dtypes=(torch.float32,), device_type="cpu"
|
|
), # fails on windows
|
|
skip(
|
|
"linalg.lu_factor", dtypes=(torch.float32,), device_type="cuda"
|
|
), # fails on all but windows
|
|
skip(
|
|
"linalg.lu_factor_ex", dtypes=(torch.float32,), device_type="cuda"
|
|
), # fails on all but windows
|
|
skip("linalg.multi_dot", "", device_type="cpu"),
|
|
skip("sparse.sampled_addmm", ""),
|
|
skip("sparse.mm", "reduce"),
|
|
skip("native_layer_norm", "", device_type="cpu"),
|
|
# RuntimeError: Expected contiguous tensor, but got
|
|
# non-contiguous tensor for argument #2 'grad_output'
|
|
decorate(
|
|
"_batch_norm_with_update",
|
|
decorator=expectedFailureIf(TEST_WITH_ROCM),
|
|
device_type="cuda",
|
|
),
|
|
},
|
|
)
|
|
@opsToleranceOverride(
|
|
"TestOperators",
|
|
"test_vmap_autograd_grad",
|
|
(
|
|
tol1(
|
|
"ldexp",
|
|
{torch.float32: tol(atol=3e-04, rtol=1.6e-06)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"linalg.householder_product",
|
|
{torch.float32: tol(atol=5e-04, rtol=9e-03)},
|
|
device_type="cuda",
|
|
),
|
|
tol1(
|
|
"linalg.householder_product",
|
|
{torch.float32: tol(atol=6e-03, rtol=1e-03)},
|
|
device_type="cpu",
|
|
),
|
|
tol1(
|
|
"linalg.multi_dot",
|
|
{torch.float32: tol(atol=2e-04, rtol=1e-04)},
|
|
device_type="cuda",
|
|
),
|
|
tol2(
|
|
"linalg.pinv", "hermitian", {torch.float32: tol(atol=5e-06, rtol=5e-06)}
|
|
),
|
|
tol1("nn.functional.conv3d", {torch.float32: tol(atol=5e-04, rtol=9e-03)}),
|
|
tol1(
|
|
"nn.functional.conv2d",
|
|
{torch.float32: tol(atol=3e-05, rtol=5e-06)},
|
|
device_type="cuda",
|
|
),
|
|
tol1("svd_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
|
|
tol1("pca_lowrank", {torch.float32: tol(atol=5e-05, rtol=5e-05)}),
|
|
),
|
|
)
|
|
def test_vmap_autograd_grad(self, device, dtype, op):
|
|
def is_differentiable(inp):
|
|
return isinstance(inp, Tensor) and (
|
|
inp.grad_fn is not None or inp.requires_grad
|
|
)
|
|
|
|
def get_flat_differentiable(tree):
|
|
flattened = pytree.tree_leaves(tree)
|
|
return tuple(i for i in flattened if is_differentiable(i))
|
|
|
|
def get_differentiable_linked(list1, list2):
|
|
paired_list = zip(list1, list2)
|
|
paired_list = tuple(
|
|
(first, second)
|
|
for (first, second) in paired_list
|
|
if is_differentiable(first)
|
|
)
|
|
return zip(*paired_list)
|
|
|
|
def filter_none(out):
|
|
flattened = pytree.tree_leaves(out)
|
|
return tuple(o for o in flattened if o is not None)
|
|
|
|
if not op.supports_autograd:
|
|
self.skipTest("Skipped! Autograd not supported.")
|
|
return
|
|
|
|
sample_inputs = op.sample_inputs(device, dtype, requires_grad=True)
|
|
|
|
for sample_input in sample_inputs:
|
|
fn, primals = normalize_op_input_output(op, sample_input)
|
|
out = fn(*primals)
|
|
cotangents = tree_map(torch.randn_like, out)
|
|
|
|
def compute_grad(cotangents):
|
|
out_flattened = out
|
|
cotangents_flattened = cotangents
|
|
if not isinstance(out_flattened, torch.Tensor):
|
|
out_flattened = pytree.tree_leaves(out)
|
|
cotangents_flattened = pytree.tree_leaves(cotangents)
|
|
out_flattened, cotangents_flattened = get_differentiable_linked(
|
|
out_flattened, cotangents_flattened
|
|
)
|
|
|
|
return filter_none(
|
|
torch.autograd.grad(
|
|
out_flattened,
|
|
get_flat_differentiable(primals),
|
|
cotangents_flattened,
|
|
retain_graph=True,
|
|
allow_unused=True,
|
|
)
|
|
)
|
|
|
|
is_batch_norm_and_training = is_batch_norm_training(op, sample_input.kwargs)
|
|
generator = get_fallback_and_vmap_exhaustive(
|
|
compute_grad,
|
|
(cotangents,),
|
|
{},
|
|
is_batch_norm_and_training=is_batch_norm_and_training,
|
|
)
|
|
for loop_out, batched_out in generator:
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
def test_vmapvmapjvp_linalg_solve(self):
|
|
ops = [op for op in op_db if op.name == "linalg.solve"]
|
|
assert len(ops) > 0
|
|
|
|
# this specializes a lot of code from the get_fallback_and_vmap_exhaustive test. If we need this more
|
|
# generally, this could go for a refactor
|
|
|
|
B0 = 2
|
|
B1 = 3
|
|
|
|
# we want to check the case where A will be seen as contiguous by jvp but during the vmap calls will become
|
|
# non-contiguous because vmap will expand. This will happen during both levels of vmap
|
|
A = torch.randn(4, 4)
|
|
k = torch.randn(4, 5, B1, B0)
|
|
fn, args = get_jvp_variant_primals_tangents(
|
|
torch.linalg.solve, SampleInput(A, args=(k,))
|
|
)
|
|
|
|
in_dims_all = (None, -1, None, -1)
|
|
batched_out = vmap(vmap(fn, in_dims=in_dims_all), in_dims=in_dims_all)(*args)
|
|
loop_out = loop2(fn, in_dims_all, in_dims_all, 0, 0, B0, B1, *args)
|
|
self.assertEqual(loop_out, batched_out)
|
|
|
|
@ops(
|
|
filter(lambda op: op.name in aliasing_ops, op_db + additional_op_db),
|
|
allowed_dtypes=(torch.float,),
|
|
)
|
|
@parametrize("grad_op", ["jvp", "vjp"])
|
|
def test_view_then_inplace(self, device, dtype, op, grad_op):
|
|
for sample_input in op.sample_inputs(device, dtype):
|
|
|
|
def f(x):
|
|
op(sample_input.input, *sample_input.args, **sample_input.kwargs).copy_(
|
|
x
|
|
)
|
|
return x
|
|
|
|
without_grad = op(
|
|
sample_input.input, *sample_input.args, **sample_input.kwargs
|
|
)
|
|
if grad_op == "jvp":
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"During a grad .* attempted to call in-place operation",
|
|
):
|
|
jvp(
|
|
f,
|
|
(torch.randn_like(without_grad),),
|
|
(torch.randn_like(without_grad),),
|
|
)
|
|
else:
|
|
assert grad_op == "vjp"
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"During a grad .* attempted to call in-place operation",
|
|
):
|
|
vjp(f, torch.randn_like(without_grad))
|
|
|
|
@ops(
|
|
filter(
|
|
lambda op: op.name in aliasing_ops_list_return, op_db + additional_op_db
|
|
),
|
|
allowed_dtypes=(torch.float,),
|
|
)
|
|
@parametrize("grad_op", ["jvp", "vjp"])
|
|
def test_view_then_inplace_list_return(self, device, dtype, op, grad_op):
|
|
for sample_input in op.sample_inputs(device, dtype):
|
|
|
|
def f(x):
|
|
op(sample_input.input, *sample_input.args, **sample_input.kwargs)[
|
|
0
|
|
].copy_(x)
|
|
return x
|
|
|
|
without_grad = op(
|
|
sample_input.input, *sample_input.args, **sample_input.kwargs
|
|
)[0]
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "During a grad .* attempted to call in-place operation"
|
|
):
|
|
if grad_op == "jvp":
|
|
jvp(
|
|
f,
|
|
(torch.randn_like(without_grad),),
|
|
(torch.randn_like(without_grad),),
|
|
)
|
|
else:
|
|
assert grad_op == "vjp"
|
|
vjp(f, torch.randn_like(without_grad))
|
|
|
|
@parametrize("grad_op", ["jvp", "vjp"])
|
|
def test_view_then_inplace_special(self, grad_op):
|
|
# some things in __getitem__ use at::index, which doesn't alias, so this tests a subset of them that do alias
|
|
ops = [
|
|
lambda x: x[0],
|
|
lambda x: x[0, 0, 0],
|
|
lambda x: x[:1],
|
|
lambda x: x[:, :1],
|
|
lambda x: x[:, :1, :],
|
|
]
|
|
|
|
for op in ops:
|
|
|
|
def f(x):
|
|
op(captured).copy_(x)
|
|
return x
|
|
|
|
captured = torch.randn(4, 3, 3)
|
|
without_grad = op(captured)
|
|
if grad_op == "jvp":
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"During a grad .* attempted to call in-place operation",
|
|
):
|
|
jvp(
|
|
f,
|
|
(torch.randn_like(without_grad),),
|
|
(torch.randn_like(without_grad),),
|
|
)
|
|
else:
|
|
assert grad_op == "vjp"
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"During a grad .* attempted to call in-place operation",
|
|
):
|
|
vjp(f, torch.randn_like(without_grad))
|
|
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
# NOTE: [three-transform testing]
|
|
# We only test the autograd_function_db tests here.
|
|
#
|
|
# Usually testing the composition of two transforms is sufficient to convince
|
|
# ourselves that an operator is correctly implemented. For the following cases,
|
|
# we want to be extra sure, so we send those through some three-transform tests:
|
|
# - autograd.Function. The mechanism is via PyDispatcher/HigherOrderOperator, not the
|
|
# regular PyTorch dispatcher, so it's good to exercise more caution.
|
|
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmapvjpvmap",
|
|
{
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
|
|
},
|
|
)
|
|
def test_vmapvjpvmap(self, device, dtype, op):
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
B = 2
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
|
|
for batched_args, in_dims, kwargs in generator:
|
|
inner_vmapped_op = vmap(op, in_dims)
|
|
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
|
|
|
|
inner_vmapped_fn, primals = normalize_op_input_output2(
|
|
inner_vmapped_op,
|
|
batched_args,
|
|
kwargs,
|
|
sample.output_process_fn_grad,
|
|
)
|
|
inner_mapped_fn, _ = normalize_op_input_output2(
|
|
inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
|
|
)
|
|
result = inner_mapped_fn(*primals)
|
|
cotangents = tree_map(lambda x: torch.rand_like(x), result)
|
|
|
|
def apply_vjp(fn):
|
|
def inner(primals, cotangents):
|
|
_, vjp_fn = vjp(fn, *primals)
|
|
return vjp_fn(cotangents)
|
|
|
|
return inner
|
|
|
|
vjpvmap_fn = apply_vjp(inner_vmapped_fn)
|
|
vjpmap_fn = apply_vjp(inner_mapped_fn)
|
|
batched_args = (primals, cotangents)
|
|
generator = generate_vmap_inputs(batched_args, {})
|
|
|
|
for batched_args, in_dims, _ in generator:
|
|
# strategy: compare vmap(vjp(vmap(op)) vs map(vjp(map(op))
|
|
vmapvjpvmap_fn = vmap(vjpvmap_fn, in_dims)
|
|
mapvjpmap_fn = functools.partial(loop, vjpmap_fn, in_dims, 0, B)
|
|
|
|
result = vmapvjpvmap_fn(*batched_args)
|
|
expected = mapvjpmap_fn(*batched_args)
|
|
self.assertEqual(result, expected)
|
|
|
|
# See NOTE: [three-transform testing]
|
|
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vjpvmapvmap",
|
|
{
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
|
|
},
|
|
)
|
|
def test_vjpvmapvmap(self, device, dtype, op):
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
B = 2
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
|
|
for batched_args, inner_in_dims, kwargs in generator:
|
|
inner_vmapped_op = vmap(op, inner_in_dims)
|
|
inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B)
|
|
generator = generate_vmap_inputs(batched_args, kwargs)
|
|
for batched_args, in_dims, kwargs in generator:
|
|
# strategy: compare vjp(vmap(vmap(op)) vs vjp(map(map(op))
|
|
vmapped_op = vmap(inner_vmapped_op, in_dims)
|
|
mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B)
|
|
|
|
vmapped_fn, primals = normalize_op_input_output2(
|
|
vmapped_op, batched_args, kwargs, sample.output_process_fn_grad
|
|
)
|
|
mapped_fn, _ = normalize_op_input_output2(
|
|
mapped_op, batched_args, kwargs, sample.output_process_fn_grad
|
|
)
|
|
|
|
result = mapped_fn(*primals)
|
|
cotangents = tree_map(lambda x: torch.rand_like(x), result)
|
|
|
|
_, vjp_fn = vjp(mapped_fn, *primals)
|
|
expected_vjps = vjp_fn(cotangents)
|
|
|
|
_, vjp_fn = vjp(vmapped_fn, *primals)
|
|
result_vjps = vjp_fn(cotangents)
|
|
|
|
self.assertEqual(result_vjps, expected_vjps)
|
|
|
|
# See NOTE: [three-transform testing]
|
|
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vjpvjpvmap",
|
|
{
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
|
|
},
|
|
)
|
|
def test_vjpvjpvmap(self, device, dtype, op):
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
B = 2
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
|
|
for batched_args, in_dims, kwargs in generator:
|
|
inner_vmapped_op = vmap(op, in_dims)
|
|
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
|
|
|
|
vjpmap_fn, args = get_vjpfull_variant2(
|
|
inner_mapped_op, batched_args, kwargs
|
|
)
|
|
vjpvmap_fn, _ = get_vjpfull_variant2(
|
|
inner_vmapped_op, batched_args, kwargs
|
|
)
|
|
|
|
vjpvjpvmap_fn, new_args = get_vjpfull_variant2(vjpvmap_fn, args, {})
|
|
vjpvjpmap_fn, _ = get_vjpfull_variant2(vjpmap_fn, args, {})
|
|
|
|
expected = vjpvjpmap_fn(*new_args)
|
|
result = vjpvjpvmap_fn(*new_args)
|
|
self.assertEqual(result, expected)
|
|
|
|
# We're generally convinced that jvp x vmap works (vmap turns an operator
|
|
# into another operator and we test jvp support for operators). So
|
|
# we only test it on the things we're not sure about:
|
|
# - the autograd.Function <> functorch interaction
|
|
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_jvpvmap",
|
|
{
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
|
|
},
|
|
)
|
|
def test_jvpvmap(self, device, dtype, op):
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
B = 2
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
|
|
for batched_args, in_dims, kwargs in generator:
|
|
inner_vmapped_op = vmap(op, in_dims)
|
|
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
|
|
|
|
jvpvmap_op, primals = get_jvp_variant_primals_tangents2(
|
|
inner_vmapped_op,
|
|
batched_args,
|
|
kwargs,
|
|
sample.output_process_fn_grad,
|
|
)
|
|
jvpmap_op, _ = get_jvp_variant_primals_tangents2(
|
|
inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
|
|
)
|
|
|
|
expected = jvpmap_op(*primals)
|
|
result = jvpvmap_op(*primals)
|
|
self.assertEqual(result, expected)
|
|
|
|
# See NOTE: [three-transform testing]
|
|
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_jvpvmapvmap",
|
|
{
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
|
|
},
|
|
)
|
|
def test_jvpvmapvmap(self, device, dtype, op):
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
B = 2
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
|
|
for batched_args, inner_in_dims, kwargs in generator:
|
|
inner_vmapped_op = vmap(op, inner_in_dims)
|
|
inner_mapped_op = functools.partial(loop, op, inner_in_dims, 0, B)
|
|
generator = generate_vmap_inputs(batched_args, kwargs)
|
|
for batched_args, in_dims, kwargs in generator:
|
|
# strategy: compare jvp(vmap(vmap(op)) vs jvp(map(map(op))
|
|
vmapped_op = vmap(inner_vmapped_op, in_dims)
|
|
mapped_op = functools.partial(loop, inner_mapped_op, in_dims, 0, B)
|
|
|
|
jvpvmapvmap_fn, primals = get_jvp_variant_primals_tangents2(
|
|
vmapped_op, batched_args, kwargs, sample.output_process_fn_grad
|
|
)
|
|
jvpmapmap_fn, _ = get_jvp_variant_primals_tangents2(
|
|
mapped_op, batched_args, kwargs, sample.output_process_fn_grad
|
|
)
|
|
|
|
expected = jvpmapmap_fn(*primals)
|
|
result = jvpvmapvmap_fn(*primals)
|
|
self.assertEqual(result, expected)
|
|
|
|
# See NOTE: [three-transform testing]
|
|
@with_tf32_off # https://github.com/pytorch/pytorch/issues/86798
|
|
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_vmapjvpvmap",
|
|
{
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
|
|
},
|
|
)
|
|
def test_vmapjvpvmap(self, device, dtype, op):
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
B = 2
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
|
|
for batched_args, in_dims, kwargs in generator:
|
|
inner_vmapped_op = vmap(op, in_dims)
|
|
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
|
|
|
|
jvpvmap_fn, primals = get_jvp_variant_primals_tangents2(
|
|
inner_vmapped_op,
|
|
batched_args,
|
|
kwargs,
|
|
sample.output_process_fn_grad,
|
|
)
|
|
jvpmap_fn, _ = get_jvp_variant_primals_tangents2(
|
|
inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
|
|
)
|
|
|
|
generator = generate_vmap_inputs(primals, {})
|
|
|
|
for batched_args, in_dims, _ in generator:
|
|
# strategy: compare vmap(jvp(vmap(op)) vs map(jvp(map(op))
|
|
vmapjvpvmap_fn = vmap(jvpvmap_fn, in_dims)
|
|
mapjvpmap_fn = functools.partial(loop, jvpmap_fn, in_dims, 0, B)
|
|
|
|
result = vmapjvpvmap_fn(*batched_args)
|
|
expected = mapjvpmap_fn(*batched_args)
|
|
self.assertEqual(result, expected)
|
|
|
|
# See NOTE: [three-transform testing]
|
|
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_jvpjvpvmap",
|
|
{
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
|
|
},
|
|
)
|
|
def test_jvpjvpvmap(self, device, dtype, op):
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
B = 2
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
|
|
for batched_args, in_dims, kwargs in generator:
|
|
inner_vmapped_op = vmap(op, in_dims)
|
|
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
|
|
|
|
jvpmap_fn, args = get_jvp_variant_primals_tangents2(
|
|
inner_mapped_op, batched_args, kwargs, sample.output_process_fn_grad
|
|
)
|
|
jvpvmap_fn, _ = get_jvp_variant_primals_tangents2(
|
|
inner_vmapped_op,
|
|
batched_args,
|
|
kwargs,
|
|
sample.output_process_fn_grad,
|
|
)
|
|
|
|
jvpjvpvmap_fn, new_args = get_jvp_variant_primals_tangents2(
|
|
jvpvmap_fn, args, {}
|
|
)
|
|
jvpjvpmap_fn, _ = get_jvp_variant_primals_tangents2(jvpmap_fn, args, {})
|
|
|
|
expected = jvpjvpmap_fn(*new_args)
|
|
result = jvpjvpvmap_fn(*new_args)
|
|
self.assertEqual(result, expected)
|
|
|
|
# See NOTE: [three-transform testing]
|
|
@ops(autograd_function_db, allowed_dtypes=(torch.float32,))
|
|
@skipOps(
|
|
"TestOperators",
|
|
"test_jvpvjpvmap",
|
|
{
|
|
xfail("NumpyCubeNotComposableAutogradFunction"), # Not composable
|
|
},
|
|
)
|
|
def test_jvpvjpvmap(self, device, dtype, op):
|
|
samples = op.sample_inputs(device, dtype, requires_grad=True)
|
|
B = 2
|
|
for sample in samples:
|
|
args = [sample.input] + list(sample.args)
|
|
kwargs = sample.kwargs
|
|
generator = generate_vmap_inputs(args, kwargs, batch_size=B)
|
|
for batched_args, in_dims, kwargs in generator:
|
|
inner_vmapped_op = vmap(op, in_dims)
|
|
inner_mapped_op = functools.partial(loop, op, in_dims, 0, B)
|
|
|
|
vjpmap_fn, args = get_vjpfull_variant2(
|
|
inner_mapped_op, batched_args, kwargs
|
|
)
|
|
vjpvmap_fn, _ = get_vjpfull_variant2(
|
|
inner_vmapped_op, batched_args, kwargs
|
|
)
|
|
|
|
jvpvjpvmap_fn, new_args = get_jvp_variant_primals_tangents2(
|
|
vjpvmap_fn, args, {}
|
|
)
|
|
jvpvjpmap_fn, _ = get_jvp_variant_primals_tangents2(vjpmap_fn, args, {})
|
|
|
|
expected = jvpvjpmap_fn(*new_args)
|
|
result = jvpvjpvmap_fn(*new_args)
|
|
self.assertEqual(result, expected)
|
|
|
|
def test_data_write_errors_under_transform(self, device):
|
|
t = torch.randn(3, 3, device=device)
|
|
|
|
def fn(t):
|
|
t.data = torch.randn(3, 3)
|
|
return t.sum()
|
|
|
|
msg = "mutating directly with `.data` inside functorch transform"
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
grad(fn)(t)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
vjp(fn, t)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, msg):
|
|
jvp(fn, (t,), (torch.randn_like(t),))
|
|
|
|
def test_tensor_with_scalar_list(self, device):
|
|
x = torch.randn((), device=device)
|
|
|
|
def func_list_of_scalar(x):
|
|
return torch.tensor([x], device=device)
|
|
|
|
def func(x):
|
|
return torch.tensor(x, device=device).view(1)
|
|
|
|
actual_o, actual_fn = vjp(func_list_of_scalar, x)
|
|
expected_o, expected_fn = vjp(func, x)
|
|
|
|
self.assertEqual(actual_o, expected_o)
|
|
self.assertEqual(
|
|
expected_fn(torch.ones_like(expected_o)),
|
|
actual_fn(torch.ones_like(actual_o)),
|
|
)
|
|
|
|
|
|
only_for = ("cpu", "cuda")
|
|
instantiate_device_type_tests(TestOperators, globals(), only_for=only_for)
|
|
|
|
if __name__ == "__main__":
|
|
run_tests()
|