# Owner(s): ["module: primTorch"]

import torch
import os
from enum import Enum
from torch.overrides import resolve_name
from torch.utils._pytree import tree_map, tree_flatten
from torch._subclasses.meta_utils import MetaConverter
import torch.utils._python_dispatch
from torch.testing._internal.common_utils import (
    TestCase,
    skipIfCrossRef,
    suppress_warnings,
    TEST_WITH_ASAN,
    run_tests,
)
from torch.testing._internal.common_device_type import (
    ops,
    instantiate_device_type_tests,
    onlyCUDA,
)
from torch.testing._internal.common_methods_invocations import op_db
from torchgen.utils import YamlLoader
from torchgen.model import OperatorName

import sys
import yaml
import atexit
import re
from collections import defaultdict
import unittest
import warnings

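# Short dtype aliases: these keep the expected-failure/skip tables below
# compact, and dtype_abbrs maps back from dtype to abbreviation when the
# collected lists are printed.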
bf16 = torch.bfloat16
f64 = torch.float64
f32 = torch.float32
f16 = torch.float16
c32 = torch.complex32
c64 = torch.complex64
c128 = torch.complex128
i8 = torch.int8
i16 = torch.int16
i32 = torch.int32
i64 = torch.int64
b8 = torch.bool
u8 = torch.uint8

dtype_abbrs = {
    torch.bfloat16: 'bf16',
    torch.float64: 'f64',
    torch.float32: 'f32',
    torch.float16: 'f16',
    torch.complex32: 'c32',
    torch.complex64: 'c64',
    torch.complex128: 'c128',
    torch.int8: 'i8',
    torch.int16: 'i16',
    torch.int32: 'i32',
    torch.int64: 'i64',
    torch.bool: 'b8',
    torch.uint8: 'u8',
}


class TestMetaConverter(TestCase):
    def assertSameVersionCounter(self, m1, m2):
        # Cannot easily test m1 and m2 have same storage due to
        # lack of Storage bindings.  Use version counter.
        vc = m1._version
        self.assertEqual(m2._version, vc)
        # Doing it this way ensures that we get VC bump even with leaves
        with torch.no_grad():
            m1._base.add_(3)
        self.assertNotEqual(m1._version, vc)
        self.assertEqual(m2._version, m1._version)

    def test_view_of_non_leaf(self):
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        z1 = y[:]
        z2 = y[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)
        self.assertEqual(m1.shape, z1.shape)
        self.assertTrue(m1._is_view())
        self.assertFalse(m1._base.is_leaf)
        self.assertSameVersionCounter(m1, m2)

    def test_view_of_leaf(self):
        x = torch.randn(4, requires_grad=True)
        z1 = x[:]
        z2 = x[:]
        to_meta = MetaConverter()
        m1 = to_meta(z1)
        m2 = to_meta(z2)
        self.assertEqual(m1.shape, z1.shape)
        self.assertTrue(m1._is_view())
        self.assertTrue(m1._base.is_leaf)
        self.assertSameVersionCounter(m1, m2)

    def test_leaf(self):
        x = torch.randn(4, requires_grad=True)
        to_meta = MetaConverter()
        m = to_meta(x)
        self.assertEqual(m.shape, x.shape)
        self.assertTrue(m.is_leaf)
        self.assertTrue(m.requires_grad)

    def test_non_leaf(self):
        x = torch.randn(4, requires_grad=True)
        y = x.neg()
        to_meta = MetaConverter()
        m = to_meta(y)
        self.assertEqual(m.shape, y.shape)
        self.assertFalse(m.is_leaf)
        self.assertTrue(m.requires_grad)

    def test_requires_grad_false(self):
        x = torch.randn(4, requires_grad=False)
        to_meta = MetaConverter()
        m = to_meta(x)
        self.assertEqual(m.shape, x.shape)
        self.assertFalse(m.requires_grad)

    # NB: complex stuff is not actually exercised right now because
    # we have a blanket exclusion for complex conversion

    def test_view_as_real(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = torch.view_as_real(x)
        m = MetaConverter()(y)
        self.assertEqual(m.shape, y.shape)
        self.assertEqual(m.stride(), y.stride())
        self.assertEqual(m.dtype, y.dtype)

    def test_complex_noncontiguous_bug(self):
        x = torch.randn((2, 2, 4, 9), dtype=torch.complex32)[:, 0, :, :]
        m = MetaConverter()(x)
        self.assertEqual(m.shape, x.shape)
        self.assertEqual(m.stride(), x.stride())
        self.assertEqual(m.dtype, x.dtype)

    def test_view_as_complex(self):
        x = torch.randn((4, 2), dtype=torch.float32)
        y = torch.view_as_complex(x)
        m = MetaConverter()(y)
        self.assertEqual(m.shape, y.shape)
        self.assertEqual(m.stride(), y.stride())
        self.assertEqual(m.dtype, y.dtype)

    def test_view_dtype(self):
        x = torch.randn(4, dtype=torch.float32)
        y = x.view(dtype=torch.int32)
        m = MetaConverter()(y)
        self.assertEqual(m.shape, y.shape)
        self.assertEqual(m.stride(), y.stride())
        self.assertEqual(m.dtype, y.dtype)

    def test_imag(self):
        x = torch.randn(4, dtype=torch.complex64)
        y = x.imag
        m = MetaConverter()(y)
        self.assertEqual(m.shape, y.shape)
        self.assertEqual(m.dtype, y.dtype)
        self.assertEqual(m.stride(), y.stride())
        self.assertEqual(m.storage_offset(), y.storage_offset())


def assert_ref_meta_equal(test_case, meta_rs, rs, msg_callable):
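    """
    Structurally compare meta results against real results: corresponding
    outputs must agree on dtype, shape, storage_offset, requires_grad and
    the conj/neg bits.  Values are not compared, since meta tensors carry
    no data.  (Stride checking is currently disabled; see the note below.)
    """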
    flat_meta_rs, _ = tree_flatten(meta_rs)
    flat_rs, _ = tree_flatten(rs)
    test_case.assertEqual(len(flat_meta_rs), len(flat_rs))
    for i, meta_r, r in zip(range(len(flat_rs)), flat_meta_rs, flat_rs):
        def test_assert(cond, msg):
            if not cond:
                raise RuntimeError(f"output {i}: {msg_callable(msg)}")
        if not isinstance(r, torch.Tensor):
            continue
        test_assert(isinstance(meta_r, torch.Tensor), f"but meta result {i} is not a Tensor")
        test_assert(meta_r.dtype == r.dtype, f"but real dtype was {r.dtype}")
        test_assert(meta_r.shape == r.shape, f"but real shape was {r.shape}")
        # NOTE: stride checking is currently disabled
        # See https://github.com/pytorch/pytorch/issues/78050
        # same_strides, _ = prims.utils.check_significant_strides(meta_r, r)
        # test_assert(same_strides, f"but real stride was {r.stride()}")
        test_assert(
            meta_r.storage_offset() == r.storage_offset(),
            f"but real storage_offset was {r.storage_offset()}")
        test_assert(meta_r.requires_grad == r.requires_grad, f"but real requires_grad was {r.requires_grad}")
        test_assert(meta_r.is_conj() == r.is_conj(), f"but real is_conj was {r.is_conj()}")
        test_assert(meta_r.is_neg() == r.is_neg(), f"but real is_neg was {r.is_neg()}")


# This environment variable controls whether or not we print expected failure
# lists at the end of a test suite run.  The intended usage looks like this:
#
# 1. Run `PYTORCH_COLLECT_EXPECT=1 python test/test_meta.py` on a CUDA build
#    of PyTorch that has LAPACK/MAGMA installed.  You can filter `-k test_meta`
#    or `-k test_dispatch_meta` to focus on only one list or the other.
# 2. Given the printed skip/xfail list, add the entries to the corresponding
#    lists; torch.* entries go in meta_function and aten.* entries go in
#    meta_dispatch.  If there are preexisting entries, you need to merge in
#    the entries.
#
# This is somewhat manual but typically you shouldn't need to do this, unless
# you've made a major change (e.g., added a new dtype to PyTorch) and need to
# refresh the lists.  If you want to do it from scratch, just clear out the
# preexisting lists before running.
#
# WARNING: Python dict literals will silently ignore duplicate keys
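# (e.g. `{torch.add: {f32}, torch.add: {f64}}` quietly keeps only the second
# entry, so a pasted-in refresh can clobber an existing line without any error)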
COLLECT_EXPECT = os.getenv('PYTORCH_COLLECT_EXPECT', '0') == '1'

seen_succeeded = {}
seen_failed = {}
failed_reasons = defaultdict(set)


def print_seen():
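    """
    Print the collected results as copy-pasteable dict literals.  A dtype
    that failed at least once and never succeeded becomes an expected
    failure; a dtype that both failed and succeeded (i.e. is flaky) becomes
    a skip.  The output looks roughly like this (entries are illustrative):

        expected_failures = {
            torch.histc: {bf16, f32, f64},  # aten::histc
        }

        skips = {
            torch.cummax: {f16},
        }
    """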
    expected_failures = []
    skips = []

    def fmt_dtypes(dtypes):
        r = ', '.join(sorted(dtype_abbrs[d] for d in dtypes))
        return '{' + r + '}'

    for op, failed_dtypes in seen_failed.items():
        ops = resolve_name(op)
        succeeded_dtypes = seen_succeeded.get(op, set())
        expected_failures_dtypes = failed_dtypes - succeeded_dtypes
        skips_dtypes = failed_dtypes & succeeded_dtypes
        reasons = ""
        if failed_reasons[op]:
            reasons = "  # " + ", ".join(sorted(failed_reasons[op]))
        if expected_failures_dtypes:
            expected_failures.append(f"    {ops}: {fmt_dtypes(expected_failures_dtypes)},{reasons}")
        if skips_dtypes:
            skips.append(f"    {ops}: {fmt_dtypes(skips_dtypes)},")
    expected_failures.sort()
    skips.sort()
    nl = '\n'
    print(f"""\
expected_failures = {{
{nl.join(expected_failures)}
}}

skips = {{
{nl.join(skips)}
}}
""")

if COLLECT_EXPECT:
    atexit.register(print_seen)

# Success forces pass; failure forces fail; skip unconditionally skips testing
TestExpect = Enum("TestExpect", ("SUCCESS", "XFAILURE", "SKIP"))

# unlike plain repr, this also shows the strides of tensors
def verbose_print(e):
    class Lit:
        def __init__(self, s):
            self.s = s

        def __repr__(self):
            return self.s

    def go(t):
        if isinstance(t, torch.Tensor):
            return Lit(f"{t} stride={t.stride()}")
        else:
            return t

    return repr(tree_map(go, e))

def run_meta_crossref(
    test_case,
    test_expect,
    func,
    args,
    kwargs,
    *,
    dtype,
    device_type,
):
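    """
    Run func once on the real args, convert every tensor argument to a meta
    tensor with MetaConverter, re-run func on the meta args, and check the
    meta outputs against the real outputs with assert_ref_meta_equal.
    Returns the real results.  test_expect controls whether a meta failure
    is fatal (SUCCESS), tolerated (XFAILURE), or not attempted (SKIP).
    """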
    to_meta = MetaConverter()
    do_meta = test_expect is not TestExpect.SKIP

    if do_meta:
        try:
            meta_args = tree_map(to_meta, args)
            meta_kwargs = tree_map(to_meta, kwargs)
        except Exception as e:
            raise RuntimeError(
                f"failed to convert args to meta; "
                f"originally (*{args}, **{kwargs})") from e

    rs = func(*args, **kwargs)

    # TODO: also handle cases where func raises an exception

    # For now, only attempt if we managed to convert all tensor types
    # (if any of them failed, we're in a mixed device situation and
    # this isn't well supported)
    if do_meta and to_meta.successful():
        # Special cases
        if func is torch.tensor_split:
            # Use original indices_or_sections, this argument is data dependent
            meta_args = (meta_args[0], args[1]) + meta_args[2:]
        elif func is torch.ops.aten.repeat_interleave.Tensor:
            if kwargs.get("output_size", None) is None:
                meta_args = args
        elif func is torch.ops.aten.index.Tensor:
            # Don't convert boolean tensors to meta as they will have nonzero
            # called on them
            indices = []
            for meta_index, real_index in zip(meta_args[1], args[1]):
                if meta_index is not None and meta_index.dtype in [torch.int8, torch.bool]:
                    indices.append(real_index)
                else:
                    indices.append(meta_index)
            meta_args = (meta_args[0], indices)
        try:
            # Suppress warnings, this doesn't matter for test_meta.py
            # but it does matter if you want to use this decorator
            # for cross-ref testing, as some tests may be looking at
            # errors
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                meta_rs = func(*meta_args, **meta_kwargs)
        except Exception as e:
            if test_expect is TestExpect.XFAILURE:
                return rs
            seen_failed.setdefault(func, set()).add(dtype)
            if isinstance(e, NotImplementedError):
                m = RE_NOT_IMPLEMENTED_MSG.search(e.args[0])
                if m:
                    failed_reasons[func].add(m.group(1))
            if COLLECT_EXPECT:
                return rs
            raise RuntimeError(f"""\
failed to run: {resolve_name(func)}(
*{verbose_print(meta_args)},
**{verbose_print(meta_kwargs)}
)""") from e
        else:
            try:
                delim = ',\n  '
                assert_ref_meta_equal(test_case, meta_rs, rs, lambda msg: f"""\
meta disagrees with real impl:
{resolve_name(func)}(
  {delim.join(map(verbose_print, meta_args))},
  {delim.join(k + ": " + verbose_print(v) for k, v in meta_kwargs.items())}
) = (
  {verbose_print(meta_rs)}
)
{msg}
""")
            except Exception:
                if test_expect is TestExpect.XFAILURE:
                    return rs
                seen_failed.setdefault(func, set()).add(dtype)
                if COLLECT_EXPECT:
                    return rs
                raise
            else:
                seen_succeeded.setdefault(func, set()).add(dtype)
                if test_expect is TestExpect.XFAILURE and not COLLECT_EXPECT:
                    raise RuntimeError(f"unexpected success {resolve_name(func)}")

    return rs


RE_NOT_IMPLEMENTED_MSG = re.compile(r"Could not run '([^']+)' with arguments ")
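# Matches the dispatcher's NotImplementedError text, which typically looks
# like "Could not run 'aten::histc' with arguments from the 'Meta' backend
# ...", capturing the operator name ('aten::histc' here) as the reason.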

meta_function_expected_failures = {
    torch.Tensor.item: {b8, bf16, c128, c64, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::_local_scalar_dense
    torch.Tensor.to_sparse: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::to_sparse, aten::to_sparse.sparse_dim
    torch.allclose: {bf16, f16, f32, f64},  # aten::_local_scalar_dense
    torch.argwhere: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::nonzero
    torch.bincount: {i16, i32, i64, i8, u8},  # aten::bincount
    torch.bucketize: {bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::bucketize.Tensor, aten::bucketize.Tensor_out
    torch.combinations: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::masked_select
    torch.complex: {f16, f32, f64},  # aten::complex.out
    torch.corrcoef: {bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::_local_scalar_dense
    torch.count_nonzero: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::count_nonzero.dim_IntList
    torch.cov: {bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::_local_scalar_dense
    torch.fft.hfft2: {b8, f32, f64, i16, i32, i64, i8, u8},  # aten::_fft_c2c
    torch.fft.hfft: {b8, f32, f64, i16, i32, i64, i8, u8},
    torch.fft.hfftn: {b8, f32, f64, i16, i32, i64, i8, u8},  # aten::_fft_c2c
    torch.floor_divide: {bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::floor_divide, aten::floor_divide.out
    torch.frexp: {bf16, f16, f32, f64},  # aten::frexp.Tensor_out
    torch.functional.istft: {f32, f64},  # aten::view_as_complex
    torch.functional.unique: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::_unique2, aten::unique_dim
    torch.functional.unique_consecutive: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::unique_consecutive
    torch.histc: {bf16, f32, f64},  # aten::histc, aten::histc.out
    torch.histogram: {f32, f64},  # aten::histogram.bin_ct, aten::histogram.bins_tensor
    torch.histogramdd: {f32, f64},  # aten::_histogramdd_bin_edges, aten::_histogramdd_from_bin_tensors
    torch.kthvalue: {bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::kthvalue.values
    torch.logcumsumexp: {bf16, f32, f64},  # aten::_logcumsumexp, aten::_logcumsumexp.out
    torch.masked_select: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::masked_select, aten::masked_select.out
    torch.matrix_exp: {bf16, f32, f64},  # aten::linalg_matrix_exp
    torch.median: {bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::median, aten::median.dim_values
    torch.mode: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::mode
    torch.multinomial: {bf16, f32, f64},  # aten::multinomial, aten::multinomial.out
    torch.mvlgamma: {bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::_local_scalar_dense, aten::mvlgamma.out
    torch.nanmean: {bf16, f16, f32, f64},
    torch.nanmedian: {bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::nanmedian, aten::nanmedian.dim_values
    torch.nanquantile: {f32, f64},
    torch.nansum: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::nansum, aten::nansum.out
    torch.nn.functional.conv1d: {bf16, f32, f64, i64},
    torch.nn.functional.conv2d: {bf16, f32, f64, i64},
    torch.nn.functional.conv_transpose1d: {f32, f64, i64},
    torch.nn.functional.conv_transpose2d: {f32, f64, i64},
    torch.nn.functional.conv_transpose3d: {f32, f64, i64},
    torch.nn.functional.ctc_loss: {f32, f64},
    torch.nn.functional.gaussian_nll_loss: {bf16, f32, f64},  # aten::_local_scalar_dense
    torch.nn.functional.grid_sample: {f32, f64},  # aten::grid_sampler_2d, aten::grid_sampler_3d
    torch.nn.functional.max_pool3d: {f32, f64},  # aten::max_pool3d_with_indices
    torch.nn.functional.max_pool3d_with_indices: {f32, f64},  # aten::max_pool3d_with_indices
    torch.nn.functional.max_unpool1d: {f32, f64},  # aten::max_unpool2d
    torch.nn.functional.max_unpool2d: {f32, f64},  # aten::max_unpool2d
    torch.nn.functional.max_unpool3d: {f32, f64},  # aten::max_unpool3d
    torch.nn.functional.multi_margin_loss: {f32, f64},  # aten::multi_margin_loss
    torch.nn.functional.multilabel_margin_loss: {f32, f64},  # aten::multilabel_margin_loss_forward
    torch.nn.functional.one_hot: {i64},  # aten::_local_scalar_dense
    torch.nn.functional.pdist: {f32, f64},  # aten::_pdist_forward
    torch.nn.functional.prelu: {bf16, f32, f64},  # aten::prelu
    torch.nn.functional.rrelu: {bf16, f32, f64},  # aten::rrelu_with_noise
    torch.nn.functional.unfold: {bf16, f16, f32, f64},  # aten::im2col
    torch.nonzero: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::nonzero, aten::nonzero.out
    torch.polar: {f32, f64},  # aten::polar.out
    torch.repeat_interleave: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::repeat_interleave.Tensor
    torch.segment_reduce: {bf16, f16, f32, f64},  # aten::segment_reduce
    torch.searchsorted: {bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::searchsorted.Tensor, aten::searchsorted.Tensor_out
    torch.symeig: {f32, f64},
    torch.take: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},  # aten::take, aten::take.out
    torch.vdot: {bf16, f32, f64, i16, i32, i64, i8, u8},  # aten::vdot
    torch.ormqr: {f32, f64},
    torch.cholesky: {f32, f64},  # aten::cholesky, aten::cholesky.out
    torch.cholesky_inverse: {f32, f64},  # aten::cholesky_inverse, aten::cholesky_inverse.out
    torch.cholesky_solve: {f32, f64},  # aten::_cholesky_solve_helper
    torch.eig: {f32, f64},  # aten::_local_scalar_dense
    torch.geqrf: {f32, f64},  # aten::geqrf
    torch.linalg.det: {f32, f64},  # aten::_det_lu_based_helper
    torch.linalg.eig: {f32, f64},  # aten::linalg_eig
    torch.linalg.eigh: {f32, f64},
    torch.linalg.eigvals: {f32, f64},
    torch.linalg.eigvalsh: {f32, f64},  # aten::linalg_eigvalsh.out
    torch.linalg.householder_product: {f32, f64},  # aten::linalg_householder_product
    torch.linalg.lstsq: {f32, f64},  # aten::linalg_lstsq.out
    torch.linalg.slogdet: {f32, f64},  # aten::linalg_slogdet
    torch.linalg.solve_triangular: {f32, f64},  # aten::linalg_solve_triangular
    torch.logdet: {f32, f64},  # aten::_local_scalar_dense, aten::nonzero
}

"""
|
|
# This is some sample code for how we could dump these dicts into YAML
|
|
# file for easier reading/writing
|
|
import yaml
|
|
print(yaml.dump(
|
|
{resolve_name(k): [dtype_abbrs[d] for d in v]
|
|
for k, v in meta_function_expected_failures.items()}, default_flow_style=None))
|
|
import sys
|
|
sys.exit()
|
|
"""
|
|
|
|
meta_function_skips = {
    torch.aminmax: {b8, f32, f64, i16, i32, i64, i8, u8},
    torch.cummax: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},
    torch.cummin: {b8, bf16, f32, f64, i16, i32, i64, i8, u8},
    torch.diff: {b8},
    torch.functional.cdist: {f32, f64},
    torch.functional.tensordot: {bf16, f32, f64, i16, i32, i64, i8, u8},
    torch.inner: {bf16, f32, f64, i16, i32, i64, i8, u8},
    torch.logical_not: {b8, bf16, f16, f32, f64, i16, i32, i64, i8, u8},
    torch.nn.functional.cross_entropy: {bf16, f32, f64},
    torch.nn.functional.interpolate: {bf16, f32, f64, u8},
    torch.nn.functional.nll_loss: {bf16, f32, f64},  # TODO
    torch.linalg.pinv: {f32, f64},
    torch.empty: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
}

meta_function_device_expected_failures = defaultdict(dict)
meta_function_device_skips = defaultdict(dict)

meta_function_device_expected_failures['cpu'] = {
}

meta_function_device_expected_failures['cuda'] = {
    torch.corrcoef: {bf16, f16},  # aten::_local_scalar_dense
    torch.cov: {f16},  # aten::_local_scalar_dense
    torch.fft.fft2: {c32, f16},  # aten::_fft_c2c, aten::_fft_c2c.out
    torch.fft.fft: {c32, f16},  # aten::_fft_c2c, aten::_fft_c2c.out
    torch.fft.fftn: {c32, f16},  # aten::_fft_c2c, aten::_fft_c2c.out
    torch.fft.hfft2: {c32, f16},  # aten::_fft_c2c
    torch.fft.hfft: {c32, f16},
    torch.fft.hfftn: {c32, f16},  # aten::_fft_c2c
    torch.fft.ifft2: {c32, f16},  # aten::_fft_c2c, aten::_fft_c2c.out
    torch.fft.ifft: {c32, f16},  # aten::_fft_c2c, aten::_fft_c2c.out
    torch.fft.ifftn: {c32, f16},  # aten::_fft_c2c, aten::_fft_c2c.out
    torch.fft.ihfft2: {f16},
    torch.fft.ihfft: {f16},
    torch.fft.ihfftn: {f16},
    torch.fft.irfft2: {c32, f16},  # aten::_fft_c2r, aten::_fft_c2r.out
    torch.fft.irfft: {c32, f16},  # aten::_fft_c2r, aten::_fft_c2r.out
    torch.fft.irfftn: {c32, f16},  # aten::_fft_c2r, aten::_fft_c2r.out
    torch.fft.rfft2: {f16},
    torch.fft.rfft: {f16},
    torch.fft.rfftn: {f16},
    torch.functional.unique: {f16},  # aten::_unique2, aten::unique_dim
    torch.functional.unique_consecutive: {f16},  # aten::unique_consecutive
    torch.geqrf: {f32, f64},  # aten::geqrf
    torch.histc: {i16, i32, i64, i8},  # aten::histc, aten::histc.out
    torch.kthvalue: {f16},  # aten::kthvalue.values
    torch.linalg.householder_product: {f32, f64},  # aten::linalg_householder_product, aten::linalg_householder_product.out
    torch.linalg.solve_triangular: {f32, f64},  # aten::linalg_solve_triangular, aten::linalg_solve_triangular.out
    torch.logcumsumexp: {bf16, f16},  # aten::_logcumsumexp, aten::_logcumsumexp.out
    torch.matrix_exp: {f16},  # aten::linalg_matrix_exp
    torch.median: {f16},  # aten::median, aten::median.dim_values
    torch.multinomial: {f16},  # aten::multinomial, aten::multinomial.out
    torch.mvlgamma: {f16},  # aten::_local_scalar_dense, aten::mvlgamma.out
    torch.nanmedian: {f16},  # aten::nanmedian, aten::nanmedian.dim_values
    torch.nn.functional.conv1d: {f16, c32},
    torch.nn.functional.conv2d: {f16, c32},
    torch.nn.functional.conv_transpose1d: {bf16, f16},
    torch.nn.functional.conv_transpose2d: {bf16, f16},
    torch.nn.functional.conv_transpose3d: {bf16, f16},
    torch.nn.functional.gaussian_nll_loss: {f16},  # aten::_local_scalar_dense
    torch.nn.functional.grid_sample: {f16},  # aten::grid_sampler_2d, aten::grid_sampler_3d
    torch.nn.functional.max_pool3d: {bf16, f16},  # aten::max_pool3d_with_indices
    torch.nn.functional.max_pool3d_with_indices: {bf16, f16},  # aten::max_pool3d_with_indices
    torch.nn.functional.max_unpool1d: {f16},  # aten::max_unpool2d
    torch.nn.functional.max_unpool2d: {f16},  # aten::max_unpool2d
    torch.nn.functional.max_unpool3d: {f16},  # aten::max_unpool3d
    torch.nn.functional.multi_margin_loss: {bf16, f16},  # aten::multi_margin_loss
    torch.nn.functional.multilabel_margin_loss: {bf16, f16},  # aten::multilabel_margin_loss_forward
    torch.nn.functional.prelu: {f16},  # aten::prelu
    torch.nn.functional.rrelu: {f16},  # aten::rrelu_with_noise
    torch.ormqr: {f32, f64},  # aten::ormqr, aten::ormqr.out
    torch.vdot: {f16},  # aten::vdot
}

meta_function_device_skips['cuda'] = {
    torch.cummax: {f16},
    torch.cummin: {f16},
    torch.functional.tensordot: {f16},
    torch.inner: {f16},
    torch.linalg.matrix_power: {f32, f64},
    torch.linalg.matrix_rank: {f32, f64},
    torch.linalg.svd: {f32, f64},
    torch.nn.functional.cross_entropy: {f16},
    torch.nn.functional.interpolate: {f16},
    torch.nn.functional.nll_loss: {f16},
    torch.svd: {f32, f64},
}

# This is a __torch_function__ mode that, when enabled, interposes every
# Torch API call, runs the operator as normal, then reruns it with meta
# inputs, and then checks that everything about the output agrees.
# Most of the logic deals with faithfully replicating the original tensor
# as a meta tensor, which is nontrivial because there are a lot of subsystems
# that may potentially be exercised.
#
# That being said, this class is a little overkill for what it is doing in
# this test file (since I could have just inlined __torch_function__ on the
# OpInfo call, and OpInfos generally have very regular inputs), but it will be
# useful for more comprehensive testing, e.g., as seen in
# https://github.com/pytorch/pytorch/pull/75994  The big benefit is that it is
# A LOT more efficient than torch dispatch mode (at the cost of less coverage).
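#
# A minimal sketch of how the mode is driven (this mirrors test_meta below,
# assuming the TorchFunctionMode.push API used in this file):
#
#   with MetaCrossRefFunctionMode.push(test_case, dtype=dtype, device=device):
#       out = op(*args, **kwargs)  # runs for real, then re-runs on meta inputs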
class MetaCrossRefFunctionMode(torch.overrides.TorchFunctionMode):
    test_case: TestCase
    device_type: str
    dtype: torch.dtype

    def __init__(self, test_case, *, device, dtype):
        self.test_case = test_case
        self.device_type = torch.device(device).type
        self.dtype = dtype

    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        if torch.jit.is_tracing() or isinstance(func, torch.ScriptMethod):
            return func(*args, **kwargs)

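        # NB: skips are consulted before expected failures, so an op that
        # sometimes passes and sometimes fails is skipped rather than xfailed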
        if self.dtype in meta_function_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_function_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_function_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_function_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS

        return run_meta_crossref(
            self.test_case, test_expect, func, args,
            kwargs, dtype=self.dtype, device_type=self.device_type
        )


aten = torch.ops.aten

# these always fail
meta_dispatch_expected_failures = {
    aten._convolution.default: {c64, i64, f64, c128, bf16, f32},
    aten._ctc_loss.default: {f64, f32},
    aten._histogramdd_bin_edges.default: {f64, f32},
    aten._histogramdd_from_bin_cts.default: {f64, f32},
    aten._histogramdd_from_bin_tensors.default: {f64, f32},
    aten._local_scalar_dense.default: {c64, i64, c128, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten._pdist_forward.default: {f64, f32},
    aten._unique2.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32},
    aten.bincount.default: {i8, i64, i16, u8, i32},
    aten.bucketize.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
    aten.bucketize.Tensor_out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
    aten.col2im.default: {c64, f32, f64, c128},
    aten.complex.default: {c64, f64, c128, f16, f32},
    aten.complex.out: {f16},
    aten.convolution.default: {c64, i64, f64, c128, bf16, f32},
    aten.count_nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.count_nonzero.dim_IntList: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.floor_divide.default: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
    aten.floor_divide.out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
    aten.frexp.Tensor: {bf16, f16, f64, f32},
    aten.grid_sampler_2d.default: {f64, f32},
    aten.grid_sampler_3d.default: {f64, f32},
    aten.histc.default: {bf16, f64, f32},
    aten.histc.out: {bf16, f64, f32},
    aten.histogram.bin_ct: {f64, f32},
    aten.histogram.bins_tensor: {f64, f32},
    aten.im2col.default: {bf16, f16, f64, f32},
    aten.kthvalue.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten.linalg_matrix_exp.default: {bf16, f64, f32},
    aten.log_sigmoid_forward.output: {bf16, f64, f32},
    aten.logcumsumexp.default: {bf16, f64, f32},
    aten.logcumsumexp.out: {bf16, f64, f32},
    aten.logical_not.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.logical_not_.default: {bf16, f16, f64, f32},
    aten.masked_select.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.masked_select.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.max_pool3d_with_indices.default: {f64, f32},
    aten.max_unpool2d.default: {f64, f32},
    aten.max_unpool3d.default: {f64, f32},
    aten.median.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten.median.dim: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten.mode.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.multi_margin_loss.default: {f64, f32},
    aten.multilabel_margin_loss_forward.default: {f64, f32},
    aten.multinomial.default: {bf16, f64, f32},
    aten.multinomial.out: {bf16, f64, f32},
    aten.mvlgamma.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten.mvlgamma.out: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten.nanmedian.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten.nanmedian.dim: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten.nansum.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.nansum.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.nll_loss2d_forward.default: {bf16, f64, f32},
    aten.nonzero.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.nonzero.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.polar.default: {f64, f32},
    aten.prelu.default: {bf16, f64, f32},
    aten.rrelu_with_noise.default: {bf16, f64, f32},
    aten.searchsorted.Tensor: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
    aten.searchsorted.Tensor_out: {i64, bf16, f16, u8, f32, i8, f64, i16, i32},
    aten.segment_reduce.default: {bf16, f16, f32, f64},
    aten.take.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.take.out: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.tensordot.out: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten.to_sparse.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.to_sparse.sparse_dim: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.unique_consecutive.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32},
    aten.unique_dim.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32},
    aten.upsample_nearest3d.vec: {bf16, u8, f64, f32},
    aten.vdot.default: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten.vdot.out: {i64, bf16, u8, f32, i8, f64, i16, i32},
    aten._det_lu_based_helper.default: {f32, f64},  # aten::_det_lu_based_helper
    aten.cholesky.default: {f32, f64},  # aten::cholesky
    aten.cholesky.out: {f32, f64},  # aten::cholesky.out
    aten.cholesky_inverse.default: {f32, f64},  # aten::cholesky_inverse
    aten.cholesky_inverse.out: {f32, f64},  # aten::cholesky_inverse.out
    aten.cholesky_solve.default: {f32, f64},  # aten::_cholesky_solve_helper
    aten.cholesky_solve.out: {f32, f64},  # aten::_cholesky_solve_helper
    aten.eig.default: {f32, f64},  # aten::_local_scalar_dense
    aten.geqrf.default: {f32, f64},  # aten::geqrf
    aten.linalg_eig.default: {f32, f64},  # aten::linalg_eig
    aten.linalg_eigh.default: {f32, f64},
    aten.linalg_eigvalsh.out: {f32, f64},  # aten::linalg_eigvalsh.out
    aten.linalg_householder_product.default: {f32, f64},  # aten::linalg_householder_product
    aten.linalg_householder_product.out: {f32, f64},  # aten::linalg_householder_product.out
    aten.linalg_lstsq.default: {f32, f64},  # aten::linalg_lstsq.out
    aten.linalg_slogdet.default: {f32, f64},  # aten::linalg_slogdet
    aten.linalg_solve_triangular.default: {f32, f64},  # aten::linalg_solve_triangular
    aten.linalg_solve_triangular.out: {f32, f64},  # aten::linalg_solve_triangular.out
    aten.logdet.default: {f32, f64},  # aten::_local_scalar_dense, aten::nonzero
    aten.ormqr.default: {f32, f64},  # aten::ormqr
    aten.ormqr.out: {f32, f64},  # aten::ormqr.out
    aten.symeig.default: {f32, f64},  # aten::_symeig_helper
}

# these sometimes pass and sometimes fail
meta_dispatch_skips = {
    aten.index.Tensor: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32, c32},  # at::nonzero doesn't have a Meta function
    aten._to_copy.default: {i64, bf16, f16, u8, b8, f32, i8, f64, i16, i32},
    aten.aminmax.default: {i64, u8, b8, f32, i8, f64, i16, i32},
    aten.cummax.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32},
    aten.cummin.default: {i64, bf16, u8, b8, f32, i8, f64, i16, i32},
    aten.linalg_pinv.atol_rtol_tensor: {f32, f64},
    aten.linalg_pinv.atol_rtol_tensor_out: {f32, f64},
    aten.empty.memory_format: {b8, bf16, c128, c64, c32, f16, f32, f64, i16, i32, i64, i8, u8},
}

meta_dispatch_device_expected_failures = defaultdict(dict)
meta_dispatch_device_skips = defaultdict(dict)

meta_dispatch_device_expected_failures['cuda'] = {
    aten._convolution.default: {f16, c32},
    aten._unique2.default: {f16},  # aten::_unique2
    aten._use_cudnn_ctc_loss.default: {f32, f64},  # aten::_use_cudnn_ctc_loss
    aten.convolution.default: {f16, c32},
    aten.cudnn_grid_sampler.default: {f16, f32, f64},  # aten::cudnn_grid_sampler
    aten.geqrf.default: {f32, f64},  # aten::geqrf
    aten.grid_sampler_2d.default: {f16},  # aten::grid_sampler_2d
    aten.grid_sampler_3d.default: {f16},  # aten::grid_sampler_3d
    aten.histc.default: {i16, i32, i64, i8},  # aten::histc
    aten.histc.out: {i16, i32, i64, i8},  # aten::histc.out
    aten.kthvalue.default: {f16},  # aten::kthvalue.values
    aten.linalg_eigvalsh.out: {f32, f64},  # aten::linalg_eigvalsh.out
    aten.linalg_householder_product.default: {f32, f64},  # aten::linalg_householder_product
    aten.linalg_householder_product.out: {f32, f64},  # aten::linalg_householder_product.out
    aten.linalg_matrix_exp.default: {f16},  # aten::linalg_matrix_exp
    aten.linalg_solve_triangular.default: {f32, f64},  # aten::linalg_solve_triangular
    aten.linalg_solve_triangular.out: {f32, f64},  # aten::linalg_solve_triangular.out
    aten.log_sigmoid_forward.default: {bf16, f16, f64, f32},
    aten.log_sigmoid_forward.output: {f16},  # aten::log_sigmoid_forward.output
    aten.logcumsumexp.default: {bf16, f16},  # aten::_logcumsumexp
    aten.logcumsumexp.out: {bf16, f16},  # aten::_logcumsumexp.out
    aten.max_pool3d_with_indices.default: {bf16, f16},  # aten::max_pool3d_with_indices
    aten.max_unpool2d.default: {f16},  # aten::max_unpool2d
    aten.max_unpool3d.default: {f16},  # aten::max_unpool3d
    aten.median.default: {f16},  # aten::median
    aten.median.dim: {f16},  # aten::median.dim_values
    aten.multi_margin_loss.default: {bf16, f16},  # aten::multi_margin_loss
    aten.multilabel_margin_loss_forward.default: {bf16, f16},  # aten::multilabel_margin_loss_forward
    aten.multinomial.default: {f16},  # aten::multinomial
    aten.multinomial.out: {f16},  # aten::multinomial.out
    aten.mvlgamma.default: {f16},  # aten::_local_scalar_dense
    aten.mvlgamma.out: {f16},  # aten::mvlgamma.out
    aten.nanmedian.default: {f16},  # aten::nanmedian
    aten.nanmedian.dim: {f16},  # aten::nanmedian.dim_values
    aten.native_group_norm.default: {bf16, f16},
    aten.nll_loss2d_forward.default: {f16},  # aten::nll_loss2d_forward
    aten.ormqr.default: {f32, f64},  # aten::ormqr
    aten.ormqr.out: {f32, f64},  # aten::ormqr.out
    aten.prelu.default: {f16},  # aten::prelu
    aten.rrelu_with_noise.default: {f16},  # aten::rrelu_with_noise
    aten.tensordot.out: {f16},  # aten::tensordot.out
    aten.unique_consecutive.default: {f16},  # aten::unique_consecutive
    aten.unique_dim.default: {f16},  # aten::unique_dim
    aten.upsample_nearest3d.vec: {f16},  # aten::upsample_nearest3d.vec
    aten.vdot.default: {f16},  # aten::vdot
    aten.vdot.out: {f16},  # aten::vdot
}

meta_dispatch_device_skips['cuda'] = {
    aten._conj.default: {c32, f16},
    aten.cudnn_batch_norm.default: {f32, f64},
    aten.cummax.default: {f16},
    aten.cummin.default: {f16},
    # ROCm stuff; technically this should be expected failure but it's
    # not worth it; these should get unified anyway
    aten.miopen_batch_norm.default: {f32},
}


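# The dispatch-layer analogue of MetaCrossRefFunctionMode: because it hooks
# __torch_dispatch__ rather than __torch_function__, it sees the individual
# aten.* overloads an operator decomposes into, instead of the public
# torch.* API call.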
class MetaCrossRefDispatchMode(torch.utils._python_dispatch.TorchDispatchMode):
    test_case: TestCase
    device: torch.device
    dtype: torch.dtype

    def __init__(self, test_case, *, device, dtype):
        self.test_case = test_case
        # save TLS
        self.precision = test_case.precision
        self.rel_tol = test_case.rel_tol
        self.device_type = torch.device(device).type
        self.dtype = dtype

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}

        self.test_case.precision = self.precision
        self.test_case.rel_tol = self.rel_tol

        if self.dtype in meta_dispatch_skips.get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_device_skips[self.device_type].get(func, set()):
            test_expect = TestExpect.SKIP
        elif self.dtype in meta_dispatch_expected_failures.get(func, set()):
            test_expect = TestExpect.XFAILURE
        elif self.dtype in meta_dispatch_device_expected_failures[self.device_type].get(func, set()):
            test_expect = TestExpect.XFAILURE
        else:
            test_expect = TestExpect.SUCCESS

        return run_meta_crossref(
            self.test_case,
            test_expect,
            func,
            args,
            kwargs,
            dtype=self.dtype,
            device_type=self.device_type,
        )


# NB: we're running these tests only on CUDA because there are some
# inconsistencies between CUDA and CPU, and running on CUDA makes it easier
# to ignore the CPU case when inconsistencies arise.  Ideally we'd deal
# with the inconsistencies, but that takes time.
class TestMeta(TestCase):
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCUDA
    @skipIfCrossRef
    @suppress_warnings
    @ops(op_db)
    def test_meta(self, device, dtype, op):
        # Run the OpInfo sample inputs, cross-referencing them with the
        # meta implementation and checking that the results are the same.
        # All the heavy lifting happens in MetaCrossRefFunctionMode.
        func = op.get_op()
        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in samples:
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            with MetaCrossRefFunctionMode.push(self, dtype=dtype, device=device):
                expected = func(*args, **kwargs)
                if isinstance(expected, torch.Tensor) and op.supports_out:
                    func(*args, **kwargs, out=expected)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyCUDA
    @skipIfCrossRef
    @suppress_warnings
    @ops(op_db)
    def test_dispatch_meta(self, device, dtype, op):
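        # Same as test_meta, but interpose at the __torch_dispatch__ layer
        # instead, so the cross-ref runs against the decomposed aten.*
        # overloads rather than the public torch.* API.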
        func = op.get_op()
        samples = op.sample_inputs(device, dtype, requires_grad=False)
        for sample_input in samples:
            args = [sample_input.input] + list(sample_input.args)
            kwargs = sample_input.kwargs
            with MetaCrossRefDispatchMode.push(self, dtype=dtype, device=device):
                expected = func(*args, **kwargs)
                if isinstance(expected, torch.Tensor) and op.supports_out:
                    func(*args, **kwargs, out=expected)


instantiate_device_type_tests(TestMeta, globals())


def print_op_str_if_not_supported(op_str):
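    """
    Parse an operator string of the form "name" or "name.overload"
    (e.g. "histc.out"; the example is illustrative) and print the resolved
    overload if it appears in the dispatch skip or expected-failure lists.
    Used by the PYTORCH_COMPARE_XLA / PYTORCH_COMPARE_TEXT modes below.
    """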
    op = OperatorName.parse(op_str)
    packet = getattr(torch.ops.aten, str(op.name))
    overload = getattr(packet, op.overload_name if op.overload_name else "default")
    if any(overload in d for d in [meta_dispatch_skips, meta_dispatch_device_skips['cuda']]):
        print(f"{overload}  # SKIP")
    if any(overload in d for d in [meta_dispatch_expected_failures, meta_dispatch_device_expected_failures['cuda']]):
        print(overload)


if __name__ == "__main__":
|
|
COMPARE_XLA = os.getenv('PYTORCH_COMPARE_XLA', None)
|
|
if COMPARE_XLA is not None:
|
|
with open(COMPARE_XLA, "r") as f:
|
|
d = yaml.load(f, Loader=YamlLoader)
|
|
ops = d.get("full_codegen", []) + d.get("supported", []) + d.get("autograd", [])
|
|
for op_str in ops:
|
|
print_op_str_if_not_supported(op_str)
|
|
sys.exit(0)
|
|
|
|
COMPARE_TEXT = os.getenv('PYTORCH_COMPARE_TEXT', None)
|
|
if COMPARE_TEXT is not None:
|
|
with open(COMPARE_TEXT, "r") as f:
|
|
for op_str in f:
|
|
print_op_str_if_not_supported(op_str.strip())
|
|
sys.exit(0)
|
|
|
|
run_tests()
|