Initial NJT testing over dim type / views (#140161)

This PR introduces `ExtraOpData`, a structure that holds op metadata indicating whether the op is a view and which dim-related args it accepts. It also populates a large database of this info for dim-wise / view ops.
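
For illustration, a minimal usage sketch of the new metadata, assuming `ExtraOpData` is importable from the NJT opinfo definitions module this PR modifies:

```python
from torch.testing._internal.opinfo.definitions.nested import ExtraOpData

# squeeze() is a view op with a single-dim overload and a multi-dim overload,
# both using the argname "dim"
squeeze_data = ExtraOpData(is_view=True, dim_args=[["dim"], ["dim..."]])
print(squeeze_data.get_dim_argnames())  # ('dim', 'dim')

# transpose() only has an overload taking two dim args, so no single dim
# argname can be reported
transpose_data = ExtraOpData(is_view=True, dim_args=[["dim0", "dim1"]])
print(transpose_data.get_dim_argnames())  # (None, None)

# ops without dim-related args just carry the defaults
print(ExtraOpData().get_dim_argnames())  # (None, None)
```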

Test logic (sample input generation, references) has been updated to utilize this data. This allows for a fairly generic set of sample inputs and a shared reference for the class of ops that accept a single NJT and operate dim-wise (AKA "unary dimwise ops").
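
As a rough illustration (hypothetical shapes) of the unbind-based reference these ops share: apply the dense op per unbound component with the dim shifted into "inner" space, then re-nest the results.

```python
import torch

# a (B=2, ragged, 5) jagged-layout NJT
nt = torch.nested.nested_tensor(
    [torch.randn(2, 5), torch.randn(3, 5)], layout=torch.jagged
)

dim = 2  # a "normal" dim (neither batch nor ragged) in outer NJT space
out = nt.unsqueeze(dim)

# reference: outer dim 2 becomes dim 1 on each unbound dense component, since
# unbinding removes the batch dim
ref_components = [comp.unsqueeze(dim - 1) for comp in nt.unbind()]
ref = torch.nested.nested_tensor(ref_components, layout=torch.jagged)

for out_comp, ref_comp in zip(out.unbind(), ref.unbind()):
    torch.testing.assert_close(out_comp, ref_comp)
```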

Testing is added over the following ops:
* `chunk()`
* `narrow()`
* `select()`
* `split()`
* `split_with_sizes()`
* `squeeze()`
* `unflatten()`
* `unsqueeze()`

Most of the above do not support operating on the ragged / batch dims or on non-contiguous NJTs, so the appropriate xfails are added as needed.

I also slipped in a couple of minor fixes (sorry):
1. The `_wrap_jagged_dim()` helper no longer assumes `nt._ragged_idx == 1` and now accepts the batch dim as valid input, disambiguating the converted inner dim where needed via an additional `operating_on_batch` return value (both dim=0 and dim=1 map to dim=0 on the inner values tensor, since that dim represents a packed ragged dim across all batch items); a short sketch follows this list.
2. Padded dense -> NJT conversion requires shape gymnastics to work with the restrictive FBGEMM kernel. The gymnastics were slightly wrong for the transposed NJT case, and this PR fixes that.
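
A condensed sketch of fix 1, assuming the helper keeps the signature introduced in the diff below (it lives in the internal `torch.nested._internal.ops` module):

```python
# Hypothetical sketch: disambiguating batch-dim vs. non-batch-dim operation
# after the outer -> inner dim conversion.
from torch.nested._internal.ops import _wrap_jagged_dim

ndim, ragged_idx = 3, 1  # e.g. a (B, ragged, D) NJT

# dim=0 (batch) and dim=1 (ragged) both map to dim=0 of the packed values
# tensor, so the extra flag tells the caller which one was actually requested
inner_dim, operating_on_batch = _wrap_jagged_dim(
    ndim, 0, ragged_idx, "chunk", allow_batch_dim=True
)
assert (inner_dim, operating_on_batch) == (0, True)

inner_dim, operating_on_batch = _wrap_jagged_dim(
    ndim, 2, ragged_idx, "chunk", allow_batch_dim=True
)
assert (inner_dim, operating_on_batch) == (1, False)
```
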
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140161
Approved by: https://github.com/Skylion007, https://github.com/cpuhrsch
ghstack dependencies: #141500, #140736
Joel Schlosser 2024-11-26 13:31:56 -05:00 committed by PyTorch MergeBot
parent 7671dd436e
commit 9ee5d6f83c
3 changed files with 748 additions and 79 deletions


@ -71,7 +71,7 @@ from torch.testing._internal.opinfo.core import (
XFailRule,
)
from torch.testing._internal.opinfo.definitions.nested import njt_op_db
from torch.utils._pytree import tree_flatten
from torch.utils._pytree import tree_flatten, tree_map_only
from torch.utils.checkpoint import checkpoint, create_selective_checkpoint_contexts
@ -3969,7 +3969,7 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
with self.assertRaisesRegex(
RuntimeError,
r"split\(\): not supported for NestedTensor on dim=1",
r"split\(\): not supported for NestedTensor on ragged dim",
):
torch.split(nt, 2, 1)
@ -3995,7 +3995,7 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
)
with self.assertRaisesRegex(
RuntimeError,
r"split_with_sizes\(\): not supported for NestedTensor on dim=1",
r"split_with_sizes\(\): not supported for NestedTensor on ragged dim",
):
torch.split(nt, [1, 2], 1)
@ -4201,7 +4201,7 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
# chunk on ragged dim not supported
with self.assertRaisesRegex(
RuntimeError, "chunk.* not supported for NestedTensor on dim=1"
RuntimeError, "chunk.* not supported for NestedTensor on ragged dim"
):
nt.chunk(2, dim=1)
@ -4238,7 +4238,7 @@ class TestNestedTensorSubclass(NestedTensorTestCase):
# squeeze on ragged dim not supported
with self.assertRaisesRegex(
RuntimeError, "squeeze.* not supported for NestedTensor on dim=1"
RuntimeError, "squeeze.* not supported for NestedTensor on ragged dim"
):
nt.squeeze(1)
@ -7964,14 +7964,21 @@ FORWARD_SKIPS_AND_XFAILS = [
isinstance(op, ReductionOpInfo) or "reduction_with_dim" in op.full_name
),
sample_match_fn=lambda device, sample: (
sample.name
== (
"3D_noncontig_transposed_with_seqlen_cache: "
"normal dim reduction with keepdim=False"
)
"noncontig_transposed" in sample.name
and "normal dim reduction with keepdim=False" in sample.name
),
name="transposed_reduction_bug",
),
# likely related to previous: similar error when operating on select() with dim=0
XFailRule(
error_type=IndexError,
error_msg="tuple index out of range",
op_match_fn=lambda device, op: (op.full_name == "select"),
sample_match_fn=lambda device, sample: (
"noncontig_transposed" in sample.name and "normal_dim" in sample.name
),
name="select_batch_dim_bug",
),
# nanmean sometimes hits an unimplemented nansum() path and other times hits an
# unimplemented sum() path
XFailRule(
@ -8035,13 +8042,111 @@ FORWARD_SKIPS_AND_XFAILS = [
),
name="index_put_noncontig_holes_no_ragged_dim_indices",
),
# expected: masked_select() doesn't work on non-contiguous NJTs
# select() only supports dim=0 for non-contiguous with holes NJTs for now
XFailRule(
op_match_fn=lambda device, op: (op.full_name == "select"),
sample_match_fn=lambda device, sample: (
sample.kwargs["dim"] != 0 and "noncontig_holes" in sample.name
),
name="unsupported_select_on_non_batch_dim_with_noncontig_holes",
),
# these don't work on non-contiguous NJTs yet
XFailRule(
error_type=ValueError,
error_msg="expected self to be a contiguous jagged layout NestedTensor",
op_match_fn=lambda device, op: (op.full_name == "masked_select"),
sample_match_fn=lambda device, sample: (not sample.input.is_contiguous()),
name="masked_select_noncontig",
op_match_fn=lambda device, op: (
op.full_name
in {
"chunk",
"masked_select",
"narrow",
"split",
"split_with_sizes",
"squeeze",
}
),
sample_match_fn=lambda device, sample: (
sample.input._lengths is not None or sample.input._ragged_idx != 1
),
name="missing_noncontig_support",
),
# these don't work on the ragged dim yet
XFailRule(
error_type=RuntimeError,
error_msg="not supported for NestedTensor on ragged dim",
op_match_fn=lambda device, op: (
op.full_name
in {
"chunk",
"narrow",
"select",
"split",
"unsqueeze",
}
),
sample_match_fn=lambda device, sample: "ragged_dim" in sample.name,
name="ragged_dim_unsupported",
),
# Bug: unsqueeze at the end is wrong
XFailRule(
error_type=AssertionError,
error_msg="The values for attribute 'shape' do not match",
op_match_fn=lambda device, op: (op.full_name == "unsqueeze"),
sample_match_fn=lambda device, sample: "add dim to the end" in sample.name,
name="unsqueeze_end_dim_unsupported",
),
# Bug: inserting a dim before the ragged dim should update ragged_idx!
XFailRule(
op_match_fn=lambda device, op: (op.full_name == "unsqueeze"),
sample_match_fn=lambda device, sample: (
sample.kwargs["dim"] <= sample.input._ragged_idx
),
name="unsqueeze_dim_before_ragged_bug",
),
XFailRule(
error_type=RuntimeError,
# error comes from usage of view() in the decomp
error_msg="does not support ragged_idx != 1 except when",
op_match_fn=lambda device, op: (op.full_name == "unflatten"),
sample_match_fn=lambda device, sample: "noncontig_transposed" in sample.name,
name="unflatten_ragged_dim_unsupported",
),
# these don't work on the batch dim yet
XFailRule(
error_type=RuntimeError,
error_msg="not supported for NestedTensor on dim=0",
op_match_fn=lambda device, op: (
op.full_name
in {
"narrow",
"split",
"split_with_sizes",
"unsqueeze",
}
),
sample_match_fn=lambda device, sample: "batch_dim" in sample.name,
name="batch_dim_unsupported",
),
XFailRule(
error_type=RuntimeError,
# error comes from usage of view() in the decomp
error_msg="cannot view shape",
op_match_fn=lambda device, op: (op.full_name == "unflatten"),
sample_match_fn=lambda device, sample: "batch_dim" in sample.name,
name="unflatten_batch_dim_unsupported",
),
# Bug: chunk calculation on batch dim is completely wrong for NJT. It should
# match what is done for dense tensors wrt chunk size calculation, which can
# be unintuitive.
XFailRule(
op_match_fn=lambda device, op: op.full_name == "chunk",
sample_match_fn=lambda device, sample: (
"batch_dim" in sample.name
and
# this specific case works lol
not (sample.input.size(0) == 3 and sample.kwargs["chunks"] == 2)
),
name="batch_dim_chunk_bug1",
),
# expected: bmm / matmul sometimes use a to_padded_tensor() fallback which isn't
# supported for non-contig NJTs with holes
@ -8175,7 +8280,8 @@ BACKWARD_SKIPS_AND_XFAILS = [
]
COMPILE_FORWARD_SKIPS_AND_XFAILS = [
*FORWARD_SKIPS_AND_XFAILS,
# The unsqueeze bugs don't affect eager vs. compile consistency so they don't fail here
*(x for x in FORWARD_SKIPS_AND_XFAILS if not x.name.startswith("unsqueeze")),
# Bug: cross-device conversions with to() result in new nested ints within compile only
XFailRule(
error_type=AssertionError,
@ -8196,18 +8302,22 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
),
name="clone_unbind_data_dependency",
),
# select on dim=0 currently uses unbind(), leading to data-dependent error in torch.compile
XFailRule(
error_type=torch._dynamo.exc.Unsupported,
error_msg="data dependent operator: aten._local_scalar_dense.default",
op_match_fn=lambda device, op: (op.full_name == "select"),
sample_match_fn=lambda device, sample: (sample.kwargs["dim"] == 0),
name="select_unbind_data_dependency",
),
# Bug: no idea what's going on here; needs investigation within AOTAutograd
XFailRule(
error_type=ValueError,
error_msg="has length 1 but the spec refers to a pytree that holds 3 items",
op_match_fn=lambda device, op: (op.full_name == "nan_to_num"),
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="crazy_aot_autograd_bug1",
),
# Bug: also no idea what's going on here: needs investigation within AOTAutograd
XFailRule(
error_type=AssertionError,
error_msg="Expected 5 == 4",
op_match_fn=lambda device, op: (op.full_name == "isreal"),
sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
name="crazy_aot_autograd_bug2",
@ -8293,7 +8403,8 @@ COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
name="clone_unbind_data_dependency_backward",
),
*COMPILE_FORWARD_SKIPS_AND_XFAILS,
*BACKWARD_SKIPS_AND_XFAILS,
# The unsqueeze bugs don't affect eager vs. compile consistency so they don't fail here
*(x for x in BACKWARD_SKIPS_AND_XFAILS if not x.name.startswith("unsqueeze")),
]
COMPARE_TENSOR_COMPONENT_EQUALITY = {
@ -8330,6 +8441,10 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
out = op.op(sample.input, *sample.args, **sample.kwargs)
out_ref = op.ref(op, sample)
self.assertEqualIgnoringNestedInts(out, out_ref)
if op._extra_op_data.is_view:
tree_map_only(
NestedTensor, lambda x: self.assertTrue(x._is_view()), out
)
# TODO: Revisit once https://github.com/pytorch/pytorch/pull/138369 lands
# TODO: Add xfails for other inplace ops instead of hardcoding
@ -8351,6 +8466,10 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
out = op.op(sample.input, *sample.args, **sample.kwargs)
out_ref = op.ref(op, sample)
self.assertEqualIgnoringNestedInts(out, out_ref)
if op._extra_op_data.is_view:
tree_map_only(
NestedTensor, lambda x: self.assertTrue(x._is_view()), out
)
inps, _ = tree_flatten((sample.input, sample.args, sample.kwargs))
g_inps = [
@ -8395,6 +8514,10 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
out_ref = f(sample.input, *sample.args, **sample.kwargs)
out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
if op._extra_op_data.is_view:
tree_map_only(
NestedTensor, lambda x: self.assertTrue(x._is_view()), out_ref
)
if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
self.assertEqualIgnoringNestedInts(out_compile, out_ref)
@ -8452,6 +8575,10 @@ class TestNestedTensorOpInfo(NestedTensorTestCase):
out_ref = f(sample.input, *sample.args, **sample.kwargs)
out_compile = compiled_f(sample.input, *sample.args, **sample.kwargs)
if op._extra_op_data.is_view:
tree_map_only(
NestedTensor, lambda x: self.assertTrue(x._is_view()), out_ref
)
if op.full_name in COMPARE_TENSOR_COMPONENT_EQUALITY:
self.assertEqualIgnoringNestedInts(out_compile, out_ref)


@ -34,16 +34,28 @@ def _outer_to_inner_dim(ndim, dim, canonicalize=False):
def _wrap_jagged_dim(
ndim, dim, op_name, convert_to_inner_dim=True, allow_batch_dim=False
ndim,
dim,
ragged_dim,
op_name,
convert_to_inner_dim=True,
allow_ragged_dim=False,
allow_batch_dim=False,
):
from torch._prims_common import canonicalize_dims
wrapped = canonicalize_dims(ndim, dim)
if wrapped == 1:
raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=1")
if wrapped == ragged_dim and not allow_ragged_dim:
raise RuntimeError(f"{op_name}(): not supported for NestedTensor on ragged dim")
elif wrapped == 0 and not allow_batch_dim:
raise RuntimeError(f"{op_name}(): not supported for NestedTensor on dim=0")
return _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped
ret = _outer_to_inner_dim(ndim, wrapped) if convert_to_inner_dim else wrapped
if allow_batch_dim:
# Need to disambiguate whether we're operating on the batch dim or not.
# Operating on dim=1 -> dim=0 after the inner dim conversion.
operating_on_batch = wrapped == 0
return (ret, operating_on_batch)
return ret
def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
@ -372,10 +384,18 @@ def jagged_torch_function(func, *args, **kwargs):
# NB: stay in outer dim space because we're going to redispatch on a NT input
start_dim = _wrap_jagged_dim(
inp.dim(), new_kwargs["start_dim"], "flatten", convert_to_inner_dim=False
inp.dim(),
new_kwargs["start_dim"],
inp._ragged_idx,
"flatten",
convert_to_inner_dim=False,
)
end_dim = _wrap_jagged_dim(
inp.dim(), new_kwargs["end_dim"], "flatten", convert_to_inner_dim=False
inp.dim(),
new_kwargs["end_dim"],
inp._ragged_idx,
"flatten",
convert_to_inner_dim=False,
)
if start_dim == end_dim:
@ -823,7 +843,9 @@ def split_tensor(func, *args, **kwargs):
inp = new_kwargs.pop("input")
new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "split")
new_kwargs["dim"] = _wrap_jagged_dim(
inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split"
)
return tuple(
NestedTensor(values=x, **extract_kwargs(inp))
@ -842,7 +864,7 @@ def split_with_sizes_default(func, *args, **kwargs):
inp = new_kwargs.pop("input")
new_kwargs["dim"] = _wrap_jagged_dim(
inp.dim(), new_kwargs["dim"], "split_with_sizes"
inp.dim(), new_kwargs["dim"], inp._ragged_idx, "split_with_sizes"
)
return [
@ -860,7 +882,7 @@ def narrow(func, *args, **kwargs):
)
inp = new_kwargs.pop("input")
dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "narrow")
dim = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], inp._ragged_idx, "narrow")
values = func(
inp._values,
dim=dim,
@ -878,11 +900,11 @@ def chunk_default(func, *args, **kwargs):
inp = new_kwargs.pop("input")
new_kwargs["dim"] = _wrap_jagged_dim(
inp.dim(), new_kwargs["dim"], "chunk", allow_batch_dim=True
new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
inp.dim(), new_kwargs["dim"], inp._ragged_idx, "chunk", allow_batch_dim=True
)
if new_kwargs["dim"] == 0:
if operating_on_batch:
chunks = new_kwargs["chunks"]
dim0_size = inp._size[0]
chunk_size = math.ceil(dim0_size / chunks)
@ -956,7 +978,9 @@ def squeeze_dim(func, *args, **kwargs):
inp = new_kwargs.pop("input")
values = inp._values
new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size), new_kwargs["dim"], "squeeze")
new_kwargs["dim"] = _wrap_jagged_dim(
len(inp._size), new_kwargs["dim"], inp._ragged_idx, "squeeze"
)
return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))
@ -971,7 +995,9 @@ def unsqueeze_default(func, *args, **kwargs):
# Account for collapsed jagged dim
dim = new_kwargs["dim"]
new_kwargs["dim"] = _wrap_jagged_dim(len(inp._size) + 1, dim, "unsqueeze")
new_kwargs["dim"] = _wrap_jagged_dim(
len(inp._size) + 1, dim, inp._ragged_idx, "unsqueeze"
)
return NestedTensor(func(values, **new_kwargs), **extract_kwargs(inp))
@ -991,7 +1017,9 @@ def cat_default(func, *args, **kwargs):
# Account for collapsed jagged dim
dim = new_kwargs["dim"]
new_kwargs["dim"] = _wrap_jagged_dim(len(first.shape), dim, "cat")
new_kwargs["dim"] = _wrap_jagged_dim(
len(first.shape), dim, first._ragged_idx, "cat"
)
return NestedTensor(
func([t._values for t in tensors], **new_kwargs), **extract_kwargs(tensors[0])
@ -1376,8 +1404,12 @@ def transpose_int(func, *args, **kwargs):
**inp_kwargs,
)
new_kwargs["dim0"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim0"], "transpose")
new_kwargs["dim1"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim1"], "transpose")
new_kwargs["dim0"] = _wrap_jagged_dim(
inp.dim(), new_kwargs["dim0"], inp._ragged_idx, "transpose"
)
new_kwargs["dim1"] = _wrap_jagged_dim(
inp.dim(), new_kwargs["dim1"], inp._ragged_idx, "transpose"
)
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
@ -1596,13 +1628,13 @@ def select_int(func, *args, **kwargs):
)
inp = new_kwargs.pop("input")
new_kwargs["dim"] = _wrap_jagged_dim(
inp.dim(), new_kwargs["dim"], "select", allow_batch_dim=True
new_kwargs["dim"], operating_on_batch = _wrap_jagged_dim(
inp.dim(), new_kwargs["dim"], inp._ragged_idx, "select", allow_batch_dim=True
)
# handle batch dim slicing via unbind() for now
# TODO: make this more efficient
if new_kwargs["dim"] == 0:
if operating_on_batch:
return inp.unbind()[new_kwargs["index"]]
if inp._lengths is not None:
@ -1623,7 +1655,9 @@ def slice_tensor(func, *args, **kwargs):
)
inp = new_kwargs.pop("input")
new_kwargs["dim"] = _wrap_jagged_dim(inp.dim(), new_kwargs["dim"], "slice")
new_kwargs["dim"] = _wrap_jagged_dim(
inp.dim(), new_kwargs["dim"], inp._ragged_idx, "slice"
)
return NestedTensor(func(inp._values, **new_kwargs), **extract_kwargs(inp))
@ -1929,7 +1963,7 @@ def stack_default(func, *args, **kwargs):
)
new_kwargs["dim"] = _wrap_jagged_dim(
tensors[0].dim() + 1, new_kwargs["dim"], "stack"
tensors[0].dim() + 1, new_kwargs["dim"], tensors[0]._ragged_idx, "stack"
)
return NestedTensor(
@ -2076,9 +2110,9 @@ def _nested_from_padded_tensor_default(func, *args, **kwargs):
# only 3D padded with ragged packed dim=0 is supported by the underlying FBGEMM
# kernel so do shape gymnastics
padded_shape = padded.shape
if ragged_idx > 1:
padded = padded.transpose(ragged_idx, 1)
padded_ragged_dim1_shape = padded.shape
if padded.dim() > 3:
padded = padded.flatten(start_dim=2)
elif padded.dim() < 3:
@ -2096,9 +2130,9 @@ def _nested_from_padded_tensor_default(func, *args, **kwargs):
values = values.to(torch.bool)
# shape gymnastics part 2
if len(padded_shape) > 3:
values = values.unflatten(-1, padded_shape[2:])
elif len(padded_shape) < 3:
if len(padded_ragged_dim1_shape) > 3:
values = values.unflatten(-1, padded_ragged_dim1_shape[2:])
elif len(padded_ragged_dim1_shape) < 3:
values = values.squeeze(-1)
if ragged_idx > 1:
values = values.transpose(ragged_idx - 1, 0)


@ -1,7 +1,10 @@
# mypy: ignore-errors
import math
from copy import copy
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import torch
from torch.fx.experimental.symbolic_shapes import is_nested_int
@ -12,7 +15,197 @@ from torch.testing._internal.opinfo.core import (
SampleInput,
UnaryUfuncInfo,
)
from torch.utils._pytree import tree_map
from torch.utils._pytree import tree_flatten, tree_map
@dataclass
class ExtraOpData:
"""
Contains info on top of the typical OpInfo data that is useful for NJT test generation.
The process that converts the standard op_db -> an NJT-compatible op_db will attach this
data onto each associated OpInfo entry.
"""
# Indicates whether the associated op is a view op
is_view: bool = False
# Specifies the names of any dim-related args that the op takes in. This is useful
# for NJT tests because there is often asymmetry across the supported set of dims for
# an op; it may make sense to operate over the batch dim but not the ragged dim, for
# example. The length of this list should match the number of relevant overloads.
# Each list item of the outer list should specify dim argnames. Ellipses should be used
# to indicate multi-dim support for a given overload.
#
# For example, squeeze() has both a dim and multi-dim overload, where the argname for
# each is simply "dim". Its entry should be: [["dim"], ["dim..."]].
#
# If no overload of the op accepts dim-related args, this should be None.
dim_args: List[List[str]] = None
# Helper function to extract names of dim-related args.
# Returns: tuple of (single dim argname if available, dim list argname if available)
# If the op doesn't support dim-related args at all OR this op only has overloads
# with multiple dim args (e.g. transpose()), then this returns (None, None).
def get_dim_argnames(self) -> Tuple[Optional[str], Optional[str]]:
if self.dim_args is None:
return (None, None)
# name for the dim arg that supports a single dim
single_dim_argname = None
# name for the dim arg that supports a list of dims
dimlist_argname = None
for overload in self.dim_args:
# only consider overloads with a single dim-related arg
if len(overload) != 1:
continue
if overload[0].endswith("..."):
dimlist_argname = overload[0].replace("...", "")
if single_dim_argname is None:
single_dim_argname = dimlist_argname
else:
single_dim_argname = overload[0]
return (single_dim_argname, dimlist_argname)
# Mapping of OpInfo full names -> extra data to tack onto the OpInfo entry for use
# in test generation.
extra_op_data = {
"_segment_reduce.lengths": ExtraOpData(dim_args=[["axis0"]]),
"_segment_reduce.offsets": ExtraOpData(dim_args=[["axis0"]]),
"all": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
"argmax": ExtraOpData(dim_args=[["dim"]]),
"argmin": ExtraOpData(dim_args=[["dim"]]),
"amax": ExtraOpData(dim_args=[["dim..."]]),
"amin": ExtraOpData(dim_args=[["dim..."]]),
"any": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
"argsort": ExtraOpData(dim_args=[["dim"]]),
"broadcast_to": ExtraOpData(is_view=True),
"cat": ExtraOpData(dim_args=[["dim"]]),
"chunk": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"conj": ExtraOpData(is_view=True),
"contiguous": ExtraOpData(is_view=True),
"count_nonzero": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
"cummax": ExtraOpData(dim_args=[["dim"]]),
"cummin": ExtraOpData(dim_args=[["dim"]]),
"cumprod": ExtraOpData(dim_args=[["dim"]]),
"cumsum": ExtraOpData(dim_args=[["dim"]]),
"cumulative_trapezoid": ExtraOpData(dim_args=[["dim"]]),
"diag_embed": ExtraOpData(dim_args=[["dim1", "dim2"]]),
"diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
"diagonal_copy": ExtraOpData(dim_args=[["dim1", "dim2"]]),
"diagonal_scatter": ExtraOpData(dim_args=[["dim1", "dim2"]]),
"diff": ExtraOpData(dim_args=[["dim"]]),
"expand": ExtraOpData(is_view=True),
"expand_as": ExtraOpData(is_view=True),
"fft.fft": ExtraOpData(dim_args=[["dim"]]),
"fft.hfft": ExtraOpData(dim_args=[["dim"]]),
"fft.ifft": ExtraOpData(dim_args=[["dim"]]),
"fft.ihfft": ExtraOpData(dim_args=[["dim"]]),
"fft.irfft": ExtraOpData(dim_args=[["dim"]]),
"fft.rfft": ExtraOpData(dim_args=[["dim"]]),
"flatten": ExtraOpData(is_view=True, dim_args=[["start_dim", "end_dim"]]),
"flip": ExtraOpData(dim_args=[["dims..."]]),
"gather": ExtraOpData(dim_args=[["dim"]]),
"imag": ExtraOpData(is_view=True),
"index_add": ExtraOpData(dim_args=[["dim"]]),
"index_copy": ExtraOpData(dim_args=[["dim"]]),
"index_fill": ExtraOpData(dim_args=[["dim"]]),
"index_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
"index_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
"index_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
"index_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
"index_select": ExtraOpData(dim_args=[["dim"]]),
"kthvalue": ExtraOpData(dim_args=[["dim"]]),
"linalg.cross": ExtraOpData(dim_args=[["dim"]]),
"linalg.diagonal": ExtraOpData(is_view=True, dim_args=[["dim1", "dim2"]]),
"linalg.tensorsolve": ExtraOpData(dim_args=[["dims..."]]),
"linalg.vecdot": ExtraOpData(dim_args=[["dim"]]),
"linalg.vector_norm": ExtraOpData(dim_args=[["dim..."]]),
"log_softmax": ExtraOpData(dim_args=[["dim"]]),
"logcumsumexp": ExtraOpData(dim_args=[["dim"]]),
"masked.amax": ExtraOpData(dim_args=[["dim"]]),
"masked.amin": ExtraOpData(dim_args=[["dim"]]),
"masked.argmax": ExtraOpData(dim_args=[["dim"]]),
"masked.argmin": ExtraOpData(dim_args=[["dim"]]),
"masked.logsumexp": ExtraOpData(dim_args=[["dim"]]),
"masked.mean": ExtraOpData(dim_args=[["dim"]]),
"masked.norm": ExtraOpData(dim_args=[["dim"]]),
"masked.prod": ExtraOpData(dim_args=[["dim"]]),
"masked.std": ExtraOpData(dim_args=[["dim"]]),
"masked.sum": ExtraOpData(dim_args=[["dim"]]),
"masked.var": ExtraOpData(dim_args=[["dim"]]),
"max.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
"median": ExtraOpData(dim_args=[["dim"]]),
"mean": ExtraOpData(dim_args=[["dim..."]]),
"min.reduction_with_dim": ExtraOpData(dim_args=[["dim"]]),
"mode": ExtraOpData(dim_args=[["dim"]]),
"movedim": ExtraOpData(
dim_args=[["source", "destination"], ["source...", "destination..."]]
),
"nanmean": ExtraOpData(dim_args=[["dim..."]]),
"nanmedian": ExtraOpData(dim_args=[["dim"]]),
"nansum": ExtraOpData(dim_args=[["dim..."]]),
"narrow": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"narrow_copy": ExtraOpData(dim_args=[["dim"]]),
"nn.functional.cosine_similarity": ExtraOpData(dim_args=[["dim"]]),
"nn.functional.glu": ExtraOpData(dim_args=[["dim"]]),
"permute": ExtraOpData(is_view=True, dim_args=[["dims..."]]),
"positive": ExtraOpData(is_view=True),
"prod": ExtraOpData(dim_args=[["dim"]]),
"ravel": ExtraOpData(is_view=True),
"real": ExtraOpData(is_view=True),
"renorm": ExtraOpData(dim_args=[["dim"]]),
"reshape": ExtraOpData(is_view=True),
"reshape_as": ExtraOpData(is_view=True),
"roll": ExtraOpData(dim_args=[["dims..."]]),
"rot90": ExtraOpData(dim_args=[["dims..."]]),
"scatter": ExtraOpData(dim_args=[["dim"]]),
"scatter_add": ExtraOpData(dim_args=[["dim"]]),
"scatter_reduce.amax": ExtraOpData(dim_args=[["dim"]]),
"scatter_reduce.amin": ExtraOpData(dim_args=[["dim"]]),
"scatter_reduce.mean": ExtraOpData(dim_args=[["dim"]]),
"scatter_reduce.prod": ExtraOpData(dim_args=[["dim"]]),
"scatter_reduce.sum": ExtraOpData(dim_args=[["dim"]]),
"select": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"select_scatter": ExtraOpData(dim_args=[["dim"]]),
"slice": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"slice_scatter": ExtraOpData(dim_args=[["dim"]]),
"softmax": ExtraOpData(dim_args=[["dim"]]),
"sort": ExtraOpData(dim_args=[["dim"]]),
"split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"split_with_sizes": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"split_with_sizes_copy": ExtraOpData(dim_args=[["dim"]]),
"squeeze": ExtraOpData(is_view=True, dim_args=[["dim"], ["dim..."]]),
"squeeze_copy": ExtraOpData(dim_args=[["dim"], ["dim..."]]),
"stack": ExtraOpData(dim_args=[["dim"]]),
"std": ExtraOpData(dim_args=[["dim..."]]),
"std.unbiased": ExtraOpData(dim_args=[["dim..."]]),
"sum": ExtraOpData(dim_args=[["dim..."]]),
"t": ExtraOpData(is_view=True),
"tensor_split": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"tensordot": ExtraOpData(dim_args=[["dims..."]]),
"tile": ExtraOpData(dim_args=[["dims..."]]),
"topk": ExtraOpData(dim_args=[["dim"]]),
"transpose": ExtraOpData(is_view=True, dim_args=[["dim0", "dim1"]]),
"transpose_copy": ExtraOpData(dim_args=[["dim0", "dim1"]]),
"trapezoid": ExtraOpData(dim_args=[["dim"]]),
"trapz": ExtraOpData(dim_args=[["dim"]]),
"unbind": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"unflatten": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"unfold": ExtraOpData(is_view=True, dim_args=[["dimension"]]),
"unfold_copy": ExtraOpData(dim_args=[["dimension"]]),
"unsafe_chunk": ExtraOpData(dim_args=[["dim"]]),
"unsafe_split": ExtraOpData(dim_args=[["dim"]]),
"unsqueeze": ExtraOpData(is_view=True, dim_args=[["dim"]]),
"unsqueeze_copy": ExtraOpData(dim_args=[["dim"]]),
"var": ExtraOpData(dim_args=[["dim..."]]),
"var.unbiased": ExtraOpData(dim_args=[["dim..."]]),
"view": ExtraOpData(is_view=True),
"view_as": ExtraOpData(is_view=True),
"view_as_complex": ExtraOpData(is_view=True),
"view_as_real": ExtraOpData(is_view=True),
}
# random integer used for sizes
@ -29,6 +222,19 @@ def _raggedness_matches(nt1, nt2):
)
# Helper function to update a sample with new kwargs / name
def _update_sample(sample, new_kwargs):
all_kwargs = dict(sample.kwargs)
all_kwargs.update(new_kwargs)
full_name = ", ".join([sample.name, *(f"{k}={v}" for (k, v) in new_kwargs.items())])
return SampleInput(
sample.input.clone().detach(),
args=sample.args,
kwargs=all_kwargs,
name=full_name,
)
# Generates a random NT.
# dims should be something like [5, None, 10], with None indicating that a
# random ragged structure should be used
@ -61,6 +267,15 @@ def _describe_njt(njt) -> str:
return f"{njt.dim()}D{contig_type}{cached_data}"
# Helper function to get a reasonable string representation of a given dim wrt an NJT.
def _describe_dim(njt, dim):
if dim == 0:
return "batch_dim"
elif dim == njt._ragged_idx:
return "ragged_dim"
return "normal_dim"
# Helper function for generating a comprehensive set of NJT sample inputs.
def _sample_njts(device, dtype, requires_grad=False, dims=None):
if dims is None:
@ -88,7 +303,7 @@ def _sample_njts(device, dtype, requires_grad=False, dims=None):
# non-contiguous transposed NJT (not possible for 2D)
if dim > 2:
yield nt.transpose(-2, -1)
yield nt.transpose(-1, nt._ragged_idx)
# non-contiguous with holes NJT
values = nt.values().clone().detach()
@ -161,14 +376,19 @@ def unbind_reference(op, sample, wrap_output_as_njt=True):
from torch.nested._internal.ops import _outer_to_inner_dim
# Need to adjust dim to apply on NJT component
if "dim" in kwargs:
kwargs["dim"] = _outer_to_inner_dim(
nt_inp.dim(), kwargs["dim"], canonicalize=True
# Need to adjust dims to apply on NJT component
if op._extra_op_data.dim_args is not None:
# get all possible dim-related argnames that could be encountered for this op
argnames = tree_map(
lambda a: a.replace("...", ""),
tree_flatten(op._extra_op_data.dim_args)[0],
)
# for all dim-related args present, convert from outer -> inner dim space
for argname in {a for a in argnames if a in kwargs}:
kwargs[argname] = _outer_to_inner_dim(
nt_inp.dim(), kwargs[argname], canonicalize=True
)
# TODO: handle this
assert "dims" not in kwargs
out_ref_component = op.op(inp, *args, **kwargs)
out_ref_components.append(out_ref_component)
@ -193,14 +413,36 @@ def unbind_reference(op, sample, wrap_output_as_njt=True):
return out_ref_components
# Computes the reference value for a non-reduction unary op with dim-wise application.
def unary_dimwise_reference(op, sample, batchwise_reference=None):
# extract info about the dim args this op supports
assert op._extra_op_data.dim_args is not None
single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames()
# only support a single non-list dim arg for now
assert dimlist_argname is None
assert single_dim_argname is not None
if sample.kwargs[single_dim_argname] == 0:
# unbind reference won't work for batch-wise operation; handle this case here
assert batchwise_reference is not None
return batchwise_reference(op, sample)
return unbind_reference(op, sample)
# Computes the reference value for a reduction op.
def reduction_reference(op, sample):
assert sample.input.is_nested
dim = sample.kwargs.get("dim", None)
keepdim = sample.kwargs.get("keepdim", False)
assert dim != 0, "reductions over the batch dim are not supported"
assert "dims" not in sample.kwargs
# extract info about the dim args this op supports
assert op._extra_op_data.dim_args is not None
single_dim_argname, dimlist_argname = op._extra_op_data.get_dim_argnames()
assert single_dim_argname is not None
supports_dimlist = dimlist_argname is not None
dim = sample.kwargs.get(
dimlist_argname, sample.kwargs.get(single_dim_argname, None)
)
keepdim = sample.kwargs.get("keepdim", False)
assert dim != 0, "reductions over just the batch dim are not supported"
if isinstance(dim, (tuple, list)):
reduce_on_ragged = sample.input._ragged_idx in dim
reduce_on_batch = 0 in dim
@ -217,7 +459,8 @@ def reduction_reference(op, sample):
from torch.nested._internal.ops import _outer_to_inner_dim
ref_kwargs = dict(sample.kwargs)
ref_kwargs["dim"] = _outer_to_inner_dim(
assert dimlist_argname is not None
ref_kwargs[dimlist_argname] = _outer_to_inner_dim(
sample.input.dim(), dim, canonicalize=True
)
out = op.op(sample.input.values(), *sample.args, **ref_kwargs)
@ -413,7 +656,6 @@ def sample_inputs_njt_reduction(
device,
dtype,
requires_grad,
supports_dimlist=True,
supports_keepdim=True,
op_kwargs=None,
**kwargs,
@ -421,6 +663,15 @@ def sample_inputs_njt_reduction(
if not op_kwargs:
op_kwargs = {}
# extract info about the dim args this op supports
assert op_info._extra_op_data.dim_args is not None
(
single_dim_argname,
dimlist_argname,
) = op_info._extra_op_data.get_dim_argnames()
assert single_dim_argname is not None
supports_dimlist = dimlist_argname is not None
for njt in _sample_njts(
device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
):
@ -437,7 +688,7 @@ def sample_inputs_njt_reduction(
njt.detach().clone(),
kwargs={
**op_kwargs,
"dim": dim,
single_dim_argname: dim,
**({"keepdim": keepdim} if supports_keepdim else {}),
},
name=f"{njt_desc}: {dim_desc} dim reduction{keepdim_suffix}",
@ -449,7 +700,7 @@ def sample_inputs_njt_reduction(
njt.detach().clone(),
kwargs={
**op_kwargs,
"dim": [0, njt._ragged_idx],
dimlist_argname: [0, njt._ragged_idx],
**({"keepdim": keepdim} if supports_keepdim else {}),
},
name=f"{njt_desc}: batch+ragged reduction{keepdim_suffix}",
@ -461,7 +712,7 @@ def sample_inputs_njt_reduction(
njt.detach().clone(),
kwargs={
**op_kwargs,
"dim": [0, njt._ragged_idx, other_dim],
dimlist_argname: [0, njt._ragged_idx, other_dim],
**({"keepdim": keepdim} if supports_keepdim else {}),
},
name=(
@ -476,7 +727,7 @@ def sample_inputs_njt_reduction(
njt.detach().clone(),
kwargs={
**op_kwargs,
"dim": [njt.dim() - 2, njt.dim() - 1],
dimlist_argname: [njt.dim() - 2, njt.dim() - 1],
**({"keepdim": keepdim} if supports_keepdim else {}),
},
name=f"{njt_desc}: two normal dim reduction{keepdim_suffix}",
@ -487,7 +738,7 @@ def sample_inputs_njt_reduction(
njt.detach().clone(),
kwargs={
**op_kwargs,
"dim": list(range(njt.dim())),
dimlist_argname: list(range(njt.dim())),
**({"keepdim": keepdim} if supports_keepdim else {}),
},
name=f"{njt_desc}: all dim reduction{keepdim_suffix}",
@ -524,7 +775,85 @@ def unsupported_reference(op_name):
return _f
# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS ===
# === BEGIN OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===
def sample_inputs_unary_dimwise(
op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
if op_kwargs is None:
op_kwargs = {}
# only support a single non-list dim arg for now
assert op_info._extra_op_data is not None
single_dim_argname, dimlist_argname = op_info._extra_op_data.get_dim_argnames()
assert single_dim_argname is not None
assert dimlist_argname is None
for njt in _sample_njts(
device=device, dtype=dtype, requires_grad=requires_grad, dims=[2, 3, 4]
):
for dim in range(njt.dim()):
kwargs = {single_dim_argname: dim}
kwargs.update(op_kwargs)
yield SampleInput(
njt.clone().detach(),
kwargs=kwargs,
name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
)
def batchwise_reference_chunk(op, sample):
# reference for chunk() over dim=0
kwargs = sample.kwargs
B = sample.input.size(0)
num_chunks = sample.kwargs["chunks"]
chunk_size = math.ceil(B / num_chunks)
num_full_chunks = B // chunk_size
chunk_sizes = [chunk_size for _ in range(num_full_chunks)]
if B % chunk_size != 0:
# final chunk contains the leftovers
chunk_sizes.append(B % chunk_size)
# split unbound components into chunks according to calculated sizes
components = list(sample.input.unbind())
start = 0
chunks = []
for chunk_size in chunk_sizes:
chunks.append(components[start : start + chunk_size])
start += chunk_size
# rejoin into NJT outputs
return [torch.nested.nested_tensor(lst, layout=torch.jagged) for lst in chunks]
def batchwise_reference_narrow(op, sample):
# TODO: write this!
raise NotImplementedError
def batchwise_reference_select(op, sample):
# reference for select() over dim=0
return sample.input.unbind()[sample.kwargs["index"]]
def batchwise_reference_split(op, sample):
# TODO: write this!
raise NotImplementedError
def batchwise_reference_split_with_sizes(op, sample):
# TODO: write this!
raise NotImplementedError
def batchwise_reference_unflatten(op, sample):
# TODO: write this!
raise NotImplementedError
def batchwise_reference_unsqueeze(op, sample):
raise ValueError("unsqueeze() is not intended to operate on the batch dim")
def sample_inputs_clone(op_info, device, dtype, requires_grad, **kwargs):
# non-contiguous NJTs
for njt in _sample_njts(
@ -622,6 +951,20 @@ def reference_bmm(op, sample):
return unbind_reference(matmul_op, modified_sample)
def sample_inputs_chunk(op_info, device, dtype, requires_grad, **kwargs):
for sample_input in sample_inputs_unary_dimwise(
op_info, device, dtype, requires_grad, **kwargs
):
# ragged dim chunking: test a single chunks value
if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
yield _update_sample(sample_input, {"chunks": 3})
# other dim chunking: test different chunks values
else:
D = sample_input.input.size(sample_input.kwargs["dim"])
for chunks in [1, D // 2, D - 1, D]:
yield _update_sample(sample_input, {"chunks": chunks})
def sample_inputs_matmul(
op_info, device, dtype, requires_grad, op_kwargs=None, **kwargs
):
@ -680,6 +1023,20 @@ def sample_inputs_masked_select(
)
def sample_inputs_narrow(op_info, device, dtype, requires_grad, **kwargs):
for sample_input in sample_inputs_unary_dimwise(
op_info, device, dtype, requires_grad, **kwargs
):
# ragged dim narrowing: test a single start, length value
if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
yield _update_sample(sample_input, {"start": 1, "length": 2})
# other dim narrowing: test different start, length values
else:
D = sample_input.input.size(sample_input.kwargs["dim"])
for start, length in [(0, D), (0, D - 1), (1, D - 1), (D - 1, 1)]:
yield _update_sample(sample_input, {"start": start, "length": length})
def sample_inputs_nn_functional_embedding(
op_info, device, dtype, requires_grad, **kwargs
):
@ -885,6 +1242,132 @@ sample_inputs_nn_functional_threshold = partial(
)
def sample_inputs_select(op_info, device, dtype, requires_grad, **kwargs):
for sample_input in sample_inputs_unary_dimwise(
op_info, device, dtype, requires_grad, **kwargs
):
# ragged dim selection: test a single index
if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
yield _update_sample(sample_input, {"index": 0})
# other dim selection: test different indices
else:
D = sample_input.input.size(sample_input.kwargs["dim"])
for index in [0, D // 2, D - 1]:
yield _update_sample(sample_input, {"index": index})
def sample_inputs_split(op_info, device, dtype, requires_grad, **kwargs):
for sample_input in sample_inputs_unary_dimwise(
op_info, device, dtype, requires_grad, **kwargs
):
# ragged dim splitting: test a single split size
if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
yield _update_sample(sample_input, {"split_size_or_sections": 3})
# other dim splitting: test different split sizes
else:
D = sample_input.input.size(sample_input.kwargs["dim"])
for split_size in [1, D // 2, D - 1, D]:
yield _update_sample(
sample_input, {"split_size_or_sections": split_size}
)
def sample_inputs_split_with_sizes(op_info, device, dtype, requires_grad, **kwargs):
for sample_input in sample_inputs_unary_dimwise(
op_info, device, dtype, requires_grad, **kwargs
):
# It will never make sense to operate on the ragged dim.
# TODO: Handle this with error_inputs
if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
continue
D = sample_input.input.size(sample_input.kwargs["dim"])
# splits should add up to D
split1 = torch.randint(0, D - 1, size=()).item()
split2 = D - split1
yield _update_sample(sample_input, {"split_sizes": [split1, split2]})
def sample_inputs_squeeze(op_info, device, dtype, requires_grad, **kwargs):
# squeeze-specific NJT generator (need to ensure there are some 1s in the shape)
def _get_njts():
njt = random_nt_from_dims(
(4, None, 1, 3, 1),
device=device,
dtype=dtype,
requires_grad=requires_grad,
layout=torch.jagged,
)
yield njt
# without min / max seqlen cached
values = njt.values().detach().clone()
offsets = njt.offsets().detach().clone()
yield torch.nested.nested_tensor_from_jagged(values, offsets)
# non-contiguous transposed
yield njt.transpose(1, 3)
# non-contiguous with holes
values = njt.values().clone().detach()
offsets = njt.offsets().clone().detach()
# subtract 1 to cause holes
lengths = (offsets.diff() - 1).clone().detach()
yield torch.nested.nested_tensor_from_jagged(
values=values,
offsets=offsets,
lengths=lengths,
)
for njt in _get_njts():
# single dim operation
for dim in range(njt.dim()):
# Operation on batch / ragged dim is never expected to work.
# TODO: Handle these via error_inputs.
if dim == 0 or dim == njt._ragged_idx:
continue
yield SampleInput(
njt.clone().detach(),
kwargs={"dim": dim},
name=f"{_describe_njt(njt)}: {_describe_dim(njt, dim)}",
)
# multiple dim operation (pass no args)
yield SampleInput(
njt.clone().detach(),
kwargs={"dim": dim},
name=f"{_describe_njt(njt)}: multiple dims",
)
def sample_inputs_unflatten(op_info, device, dtype, requires_grad, **kwargs):
for sample_input in sample_inputs_unary_dimwise(
op_info, device, dtype, requires_grad, **kwargs
):
# It will never make sense to operate on the ragged dim.
# TODO: Handle this with error_inputs
if sample_input.kwargs["dim"] == sample_input.input._ragged_idx:
continue
D = sample_input.input.size(sample_input.kwargs["dim"])
# sizes should multiply to be D
yield _update_sample(sample_input, {"sizes": [D, 1]})
yield _update_sample(sample_input, {"sizes": [1, D]})
if D % 2 == 0:
yield _update_sample(sample_input, {"sizes": [D // 2, 2]})
yield _update_sample(sample_input, {"sizes": [2, D // 2]})
def sample_inputs_unsqueeze(op_info, device, dtype, requires_grad, **kwargs):
for sample_input in sample_inputs_unary_dimwise(
op_info, device, dtype, requires_grad, **kwargs
):
yield sample_input
last_dim_sample = _update_sample(sample_input, {"dim": -1})
last_dim_sample.name = (
f"{_describe_njt(last_dim_sample.input)}: add dim to the end"
)
yield last_dim_sample
def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
for sample in sample_inputs_elementwise_njt_binary(
op_info, device, dtype, requires_grad, **kwargs
@ -897,7 +1380,7 @@ def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
yield sample
# === END OP-SPECIFIC SAMPLE INPUTS FUNCS ===
# === END OP-SPECIFIC SAMPLE INPUTS FUNCS / REFERENCES ===
# Mapping of OpInfo full names -> sample_inputs_funcs, which define the set of sample inputs
@ -906,9 +1389,8 @@ def sample_inputs_where(op_info, device, dtype, requires_grad, **kwargs):
# to specify if they cannot be auto-generated for some reason. Try to keep these sorted
# in alphabetical order!
njt_sample_inputs = {
"argmax": partial(sample_inputs_njt_reduction, supports_dimlist=False),
"argmin": partial(sample_inputs_njt_reduction, supports_dimlist=False),
"bmm": sample_inputs_bmm,
"chunk": sample_inputs_chunk,
"clone": sample_inputs_clone,
"count_nonzero": partial(sample_inputs_njt_reduction, supports_keepdim=False),
**{f"mvlgamma.mvlgamma_p_{p}": sample_inputs_mvl_gamma(p=1) for p in (1, 3, 5)},
@ -922,26 +1404,50 @@ njt_sample_inputs = {
"to": sample_inputs_to,
"matmul": sample_inputs_matmul,
"masked_select": sample_inputs_masked_select,
"narrow": sample_inputs_narrow,
"index_put": sample_inputs_index_put,
"max.reduction_with_dim": partial(
sample_inputs_njt_reduction, supports_dimlist=False
),
"min.reduction_with_dim": partial(
sample_inputs_njt_reduction, supports_dimlist=False
),
"prod": partial(sample_inputs_njt_reduction, supports_dimlist=False),
# these two don't have ReductionOpInfo entries
"max.reduction_with_dim": sample_inputs_njt_reduction,
"min.reduction_with_dim": sample_inputs_njt_reduction,
"select": sample_inputs_select,
"split": sample_inputs_split,
"split_with_sizes": sample_inputs_split_with_sizes,
"squeeze": sample_inputs_squeeze,
"unflatten": sample_inputs_unflatten,
"unsqueeze": sample_inputs_unsqueeze,
"where": sample_inputs_where,
}
njt_references = {
"argmax": reduction_reference,
"argmin": reduction_reference,
"bmm": reference_bmm,
"chunk": partial(
unary_dimwise_reference, batchwise_reference=batchwise_reference_chunk
),
"count_nonzero": reduction_reference,
# these two don't have ReductionOpInfo entries
"max.reduction_with_dim": reduction_reference,
"min.reduction_with_dim": reduction_reference,
"narrow": partial(
unary_dimwise_reference, batchwise_reference=batchwise_reference_narrow
),
"select": partial(
unary_dimwise_reference, batchwise_reference=batchwise_reference_select
),
"split": partial(
unary_dimwise_reference, batchwise_reference=batchwise_reference_split
),
"split_with_sizes": partial(
unary_dimwise_reference,
batchwise_reference=batchwise_reference_split_with_sizes,
),
"squeeze": unbind_reference,
"nn.functional.embedding_bag": reference_nn_functional_embedding_bag,
"prod": reduction_reference,
"unflatten": partial(
unary_dimwise_reference, batchwise_reference=batchwise_reference_unflatten
),
"unsqueeze": partial(
unary_dimwise_reference, batchwise_reference=batchwise_reference_unsqueeze
),
}
@ -949,6 +1455,8 @@ njt_references = {
def translate_opinfo(op):
new_op = copy(op)
new_op.supports_njt = True
# add some extra info for use in generating tests on the right subset of ops
new_op._extra_op_data = extra_op_data.get(op.full_name, ExtraOpData())
if op.full_name in njt_sample_inputs:
new_op.sample_inputs_func = njt_sample_inputs[op.full_name]