mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Differential Revision: [D65362160](https://our.internmc.facebook.com/intern/diff/D65362160) State after this PR: 1. Tests that require the inference IR now use ep.run_decomp({}), so export_for_training_run_decomp is somewhat redundant, but I guess it is still nice that multiple rounds of retracing still work. In general, we need some auditing to reduce our redundant test coverage. 2. Once this PR has landed and has not been reverted for a week or so, I will replace the export_for_training calls with export, as they are the same thing now. 3. Added more tests to also cover the now-"deprecated" old IR by patching export to use the old export. For reviewers, please look at the internal version. Pull Request resolved: https://github.com/pytorch/pytorch/pull/139511 Approved by: https://github.com/ydwu4, https://github.com/angelayi, https://github.com/avikchaudhuri
303 lines
8.3 KiB
Python
303 lines
8.3 KiB
Python
import functools
import inspect
import unittest
from unittest.mock import patch

import torch
|
|
|
|
|
|
aten = torch.ops.aten
|
|
|
|
# This list is not meant to be comprehensive
|
|
_COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [
|
|
aten.arctan2.default,
|
|
aten.divide.Tensor,
|
|
aten.divide.Scalar,
|
|
aten.divide.Tensor_mode,
|
|
aten.divide.Scalar_mode,
|
|
aten.multiply.Tensor,
|
|
aten.multiply.Scalar,
|
|
aten.subtract.Tensor,
|
|
aten.subtract.Scalar,
|
|
aten.true_divide.Tensor,
|
|
aten.true_divide.Scalar,
|
|
aten.greater.Tensor,
|
|
aten.greater.Scalar,
|
|
aten.greater_equal.Tensor,
|
|
aten.greater_equal.Scalar,
|
|
aten.less_equal.Tensor,
|
|
aten.less_equal.Scalar,
|
|
aten.less.Tensor,
|
|
aten.less.Scalar,
|
|
aten.not_equal.Tensor,
|
|
aten.not_equal.Scalar,
|
|
aten.cat.names,
|
|
aten.sum.dim_DimnameList,
|
|
aten.mean.names_dim,
|
|
aten.prod.dim_Dimname,
|
|
aten.all.dimname,
|
|
aten.norm.names_ScalarOpt_dim,
|
|
aten.norm.names_ScalarOpt_dim_dtype,
|
|
aten.var.default,
|
|
aten.var.dim,
|
|
aten.var.names_dim,
|
|
aten.var.correction_names,
|
|
aten.std.default,
|
|
aten.std.dim,
|
|
aten.std.names_dim,
|
|
aten.std.correction_names,
|
|
aten.absolute.default,
|
|
aten.arccos.default,
|
|
aten.arccosh.default,
|
|
aten.arcsin.default,
|
|
aten.arcsinh.default,
|
|
aten.arctan.default,
|
|
aten.arctanh.default,
|
|
aten.clip.default,
|
|
aten.clip.Tensor,
|
|
aten.fix.default,
|
|
aten.negative.default,
|
|
aten.square.default,
|
|
aten.size.int,
|
|
aten.size.Dimname,
|
|
aten.stride.int,
|
|
aten.stride.Dimname,
|
|
aten.repeat_interleave.self_Tensor,
|
|
aten.repeat_interleave.self_int,
|
|
aten.sym_size.int,
|
|
aten.sym_stride.int,
|
|
aten.atleast_1d.Sequence,
|
|
aten.atleast_2d.Sequence,
|
|
aten.atleast_3d.Sequence,
|
|
aten.linear.default,
|
|
aten.conv2d.default,
|
|
aten.conv2d.padding,
|
|
aten.mish_backward.default,
|
|
aten.silu_backward.default,
|
|
aten.index_add.dimname,
|
|
aten.pad_sequence.default,
|
|
aten.index_copy.dimname,
|
|
aten.upsample_nearest1d.vec,
|
|
aten.upsample_nearest2d.vec,
|
|
aten.upsample_nearest3d.vec,
|
|
aten._upsample_nearest_exact1d.vec,
|
|
aten._upsample_nearest_exact2d.vec,
|
|
aten._upsample_nearest_exact3d.vec,
|
|
aten.rnn_tanh.input,
|
|
aten.rnn_tanh.data,
|
|
aten.rnn_relu.input,
|
|
aten.rnn_relu.data,
|
|
aten.lstm.input,
|
|
aten.lstm.data,
|
|
aten.gru.input,
|
|
aten.gru.data,
|
|
aten._upsample_bilinear2d_aa.vec,
|
|
aten._upsample_bicubic2d_aa.vec,
|
|
aten.upsample_bilinear2d.vec,
|
|
aten.upsample_trilinear3d.vec,
|
|
aten.upsample_linear1d.vec,
|
|
aten.matmul.default,
|
|
aten.upsample_bicubic2d.vec,
|
|
aten.__and__.Scalar,
|
|
aten.__and__.Tensor,
|
|
aten.__or__.Tensor,
|
|
aten.__or__.Scalar,
|
|
aten.__xor__.Tensor,
|
|
aten.__xor__.Scalar,
|
|
aten.scatter.dimname_src,
|
|
aten.scatter.dimname_value,
|
|
aten.scatter_add.dimname,
|
|
aten.is_complex.default,
|
|
aten.logsumexp.names,
|
|
aten.where.ScalarOther,
|
|
aten.where.ScalarSelf,
|
|
aten.where.Scalar,
|
|
aten.where.default,
|
|
aten.item.default,
|
|
aten.any.dimname,
|
|
aten.std_mean.default,
|
|
aten.std_mean.dim,
|
|
aten.std_mean.names_dim,
|
|
aten.std_mean.correction_names,
|
|
aten.var_mean.default,
|
|
aten.var_mean.dim,
|
|
aten.var_mean.names_dim,
|
|
aten.var_mean.correction_names,
|
|
aten.broadcast_tensors.default,
|
|
aten.stft.default,
|
|
aten.stft.center,
|
|
aten.istft.default,
|
|
aten.index_fill.Dimname_Scalar,
|
|
aten.index_fill.Dimname_Tensor,
|
|
aten.index_select.dimname,
|
|
aten.diag.default,
|
|
aten.cumsum.dimname,
|
|
aten.cumprod.dimname,
|
|
aten.meshgrid.default,
|
|
aten.meshgrid.indexing,
|
|
aten.fft_fft.default,
|
|
aten.fft_ifft.default,
|
|
aten.fft_rfft.default,
|
|
aten.fft_irfft.default,
|
|
aten.fft_hfft.default,
|
|
aten.fft_ihfft.default,
|
|
aten.fft_fftn.default,
|
|
aten.fft_ifftn.default,
|
|
aten.fft_rfftn.default,
|
|
aten.fft_ihfftn.default,
|
|
aten.fft_irfftn.default,
|
|
aten.fft_hfftn.default,
|
|
aten.fft_fft2.default,
|
|
aten.fft_ifft2.default,
|
|
aten.fft_rfft2.default,
|
|
aten.fft_irfft2.default,
|
|
aten.fft_hfft2.default,
|
|
aten.fft_ihfft2.default,
|
|
aten.fft_fftshift.default,
|
|
aten.fft_ifftshift.default,
|
|
aten.selu.default,
|
|
aten.margin_ranking_loss.default,
|
|
aten.hinge_embedding_loss.default,
|
|
aten.nll_loss.default,
|
|
aten.prelu.default,
|
|
aten.relu6.default,
|
|
aten.pairwise_distance.default,
|
|
aten.pdist.default,
|
|
aten.special_ndtr.default,
|
|
aten.cummax.dimname,
|
|
aten.cummin.dimname,
|
|
aten.logcumsumexp.dimname,
|
|
aten.max.other,
|
|
aten.max.names_dim,
|
|
aten.min.other,
|
|
aten.min.names_dim,
|
|
aten.linalg_eigvals.default,
|
|
aten.median.names_dim,
|
|
aten.nanmedian.names_dim,
|
|
aten.mode.dimname,
|
|
aten.gather.dimname,
|
|
aten.sort.dimname,
|
|
aten.sort.dimname_stable,
|
|
aten.argsort.default,
|
|
aten.argsort.dimname,
|
|
aten.rrelu.default,
|
|
aten.conv_transpose1d.default,
|
|
aten.conv_transpose2d.input,
|
|
aten.conv_transpose3d.input,
|
|
aten.conv1d.default,
|
|
aten.conv1d.padding,
|
|
aten.conv3d.default,
|
|
aten.conv3d.padding,
|
|
aten.float_power.Tensor_Tensor,
|
|
aten.float_power.Tensor_Scalar,
|
|
aten.float_power.Scalar,
|
|
aten.ldexp.Tensor,
|
|
aten._version.default,
|
|
]
|
|
|
|
|
|
def make_test_cls_with_mocked_export(
    cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None
):
    """Clone test class ``cls`` into a variant whose tests run with a mocked ``export``.

    A new class named ``f"{cls_prefix}{cls.__name__}"`` is created with the
    same bases as ``cls``. Each callable ``test_*`` attribute of ``cls`` is
    re-registered as ``f"{name}{fn_suffix}"``, wrapped by
    ``_make_fn_with_mocked_export`` so that ``test_export.export`` is patched
    with ``mocked_export_fn`` while the test runs. Non-callable ``test_*``
    attributes are copied under their original name. Every other attribute
    that is not already reachable through the shared bases (e.g. helpers and
    setUp defined on ``cls`` itself) is copied over as-is.

    Args:
        cls: The test class to clone.
        cls_prefix: Prefix for the generated class name.
        fn_suffix: Suffix appended to each generated test method name.
        mocked_export_fn: Object patched in place of ``test_export.export``.
        xfail_prop: Optional attribute name. A test function that carries an
            attribute with this name is wrapped in ``unittest.expectedFailure``.

    Returns:
        The newly created test class.
    """
    MockedTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    MockedTestClass.__qualname__ = MockedTestClass.__name__

    for name in dir(cls):
        if name.startswith("test_"):
            fn = getattr(cls, name)
            if not callable(fn):
                # Non-callable ``test_*`` attributes are copied unchanged (no
                # suffix, no export mock).
                setattr(MockedTestClass, name, inspect.getattr_static(cls, name))
                continue
            new_name = f"{name}{fn_suffix}"
            new_fn = _make_fn_with_mocked_export(fn, mocked_export_fn)
            # The wrapper inherits fn's __name__/__qualname__ through
            # functools.wraps; point both at the generated class and method
            # so that reprs and tracebacks name the test that actually ran.
            new_fn.__name__ = new_name
            new_fn.__qualname__ = f"{MockedTestClass.__qualname__}.{new_name}"
            if xfail_prop is not None and hasattr(fn, xfail_prop):
                new_fn = unittest.expectedFailure(new_fn)
            setattr(MockedTestClass, new_name, new_fn)
        # NB: Doesn't handle slots correctly, but whatever
        elif not hasattr(MockedTestClass, name):
            # Copy the raw attribute rather than ``getattr(cls, name)``:
            # ``getattr`` unwraps a staticmethod into a plain function (which
            # would then be called with ``self``) and binds a classmethod to
            # ``cls`` instead of the generated class.
            setattr(MockedTestClass, name, inspect.getattr_static(cls, name))

    return MockedTestClass
|
|
|
|
|
|
def _make_fn_with_mocked_export(fn, mocked_export_fn):
|
|
@functools.wraps(fn)
|
|
def _fn(*args, **kwargs):
|
|
try:
|
|
from . import test_export
|
|
except ImportError:
|
|
import test_export # @manual=fbcode//caffe2/test:test_export-library
|
|
|
|
with patch(f"{test_export.__name__}.export", mocked_export_fn):
|
|
return fn(*args, **kwargs)
|
|
|
|
return _fn
|
|
|
|
|
|
def expectedFailureTrainingIRToRunDecomp(fn):
    """Flag ``fn`` as an expected failure for the training-IR-to-run-decomp tests.

    Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py.
    Sets ``fn._expected_failure_training_ir_to_run_decomp = True`` and returns
    ``fn`` itself, so it can be applied as a plain decorator.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp", True)
    return fn
|
|
|
|
|
|
def expectedFailureTrainingIRToRunDecompNonStrict(fn):
    """Flag ``fn`` as an expected failure for the non-strict training-IR-to-run-decomp tests.

    Controls tests generated in test/export/test_export_training_ir_to_run_decomp.py.
    Sets ``fn._expected_failure_training_ir_to_run_decomp_non_strict = True``
    and returns ``fn`` itself, so it can be applied as a plain decorator.
    """
    setattr(fn, "_expected_failure_training_ir_to_run_decomp_non_strict", True)
    return fn
|
|
|
|
|
|
def expectedFailureNonStrict(fn):
    """Flag ``fn`` as an expected failure for the non-strict export tests.

    Controls tests generated in test/export/test_export_nonstrict.py.
    Sets ``fn._expected_failure_non_strict = True`` and returns ``fn`` itself,
    so it can be applied as a plain decorator.
    """
    setattr(fn, "_expected_failure_non_strict", True)
    return fn
|
|
|
|
|
|
def expectedFailureRetraceability(fn):
    """Flag ``fn`` as an expected failure for the retraceability tests.

    Controls tests generated in test/export/test_retraceability.py.
    Sets ``fn._expected_failure_retrace = True`` and returns ``fn`` itself,
    so it can be applied as a plain decorator.
    """
    setattr(fn, "_expected_failure_retrace", True)
    return fn
|
|
|
|
|
|
def expectedFailureRetraceabilityNonStrict(fn):
    """Flag ``fn`` as an expected failure for the non-strict retraceability tests.

    Controls tests generated in test/export/test_retraceability.py.
    Sets ``fn._expected_failure_retrace_non_strict = True`` and returns ``fn``
    itself, so it can be applied as a plain decorator.
    """
    setattr(fn, "_expected_failure_retrace_non_strict", True)
    return fn
|
|
|
|
|
|
def expectedFailureSerDer(fn):
    """Flag ``fn`` as an expected failure for the serialize/deserialize tests.

    Controls tests generated in test/export/test_serdes.py.
    Sets ``fn._expected_failure_serdes = True`` and returns ``fn`` itself,
    so it can be applied as a plain decorator.
    """
    setattr(fn, "_expected_failure_serdes", True)
    return fn
|
|
|
|
|
|
def expectedFailureSerDerNonStrict(fn):
    """Flag ``fn`` as an expected failure for the non-strict serialize/deserialize tests.

    Controls tests generated in test/export/test_serdes.py.
    Sets ``fn._expected_failure_serdes_non_strict = True`` and returns ``fn``
    itself, so it can be applied as a plain decorator.
    """
    setattr(fn, "_expected_failure_serdes_non_strict", True)
    return fn
|
|
|
|
|
|
def expectedFailureSerDerPreDispatch(fn):
    """Set ``fn._expected_failure_serdes_pre_dispatch = True`` and return ``fn``.

    NOTE(review): presumably consumed as ``xfail_prop`` by a pre-dispatch
    serdes test variant built with ``make_test_cls_with_mocked_export``;
    confirm which test file reads this flag.
    """
    setattr(fn, "_expected_failure_serdes_pre_dispatch", True)
    return fn
|
|
|
|
|
|
def expectedFailurePreDispatchRunDecomp(fn):
    """Set ``fn._expected_failure_pre_dispatch = True`` and return ``fn``.

    Note that the flag is ``_expected_failure_pre_dispatch``; it has no
    ``run_decomp`` suffix despite this function's name.
    NOTE(review): confirm which generated test variant reads this flag.
    """
    setattr(fn, "_expected_failure_pre_dispatch", True)
    return fn
|
|
|
|
|
|
def expectedFailureCppSerDes(fn):
    """Set ``fn._expected_failure_cpp_serdes = True`` and return ``fn``.

    NOTE(review): presumably read by a C++ serdes test variant; confirm which
    test file consumes this flag.
    """
    setattr(fn, "_expected_failure_cpp_serdes", True)
    return fn
|
|
|
|
|
|
def expectedFailureLegacyExportStrict(fn):
    """Flag ``fn`` as an expected failure for the strict legacy-export tests.

    Controls tests generated in test/export/test_export_legacy.py.
    Sets ``fn._expected_failure_legacy_export = True`` and returns ``fn``
    itself, so it can be applied as a plain decorator.
    """
    setattr(fn, "_expected_failure_legacy_export", True)
    return fn
|
|
|
|
|
|
def expectedFailureLegacyExportNonStrict(fn):
    """Set ``fn._expected_failure_legacy_export_non_strict = True`` and return ``fn``.

    NOTE(review): presumably the non-strict counterpart consumed by
    test/export/test_export_legacy.py; confirm against that file.
    """
    setattr(fn, "_expected_failure_legacy_export_non_strict", True)
    return fn
|