mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This reverts commit b0c7dd47cd.
Reverted https://github.com/pytorch/pytorch/pull/115402 on behalf of https://github.com/atalman due to OSSCI oncall, broke CI tests ([comment](https://github.com/pytorch/pytorch/pull/115402#issuecomment-1853864075))
44 lines
1.4 KiB
Python
44 lines
1.4 KiB
Python
import functools
|
|
import unittest
|
|
from unittest.mock import patch
|
|
|
|
|
|
def make_test_cls_with_mocked_export(
    cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None
):
    """Clone ``cls`` into a new class whose test methods run with a mocked export.

    The new class is named ``f"{cls_prefix}{cls.__name__}"`` (its
    ``__qualname__`` is set to the same string) and shares ``cls.__bases__``.
    Every name reported by ``dir(cls)`` is then transferred:

    * callable ``test_*`` attributes are wrapped with
      ``_make_fn_with_mocked_export(fn, mocked_export_fn)`` and stored under
      ``f"{name}{fn_suffix}"`` (the unsuffixed name is not copied). When
      ``xfail_prop`` is given and the original function has an attribute of
      that name, the wrapper is passed through ``unittest.expectedFailure``.
    * non-callable ``test_*`` attributes are copied under their own name.
    * every other attribute is copied only if the new class does not
      already expose it, e.g. through the shared bases.

    Args:
        cls: Test class to clone.
        cls_prefix: Prefix prepended to ``cls.__name__`` for the new class.
        fn_suffix: Suffix appended to the name of each wrapped test method.
        mocked_export_fn: Replacement handed to ``_make_fn_with_mocked_export``.
        xfail_prop: Optional attribute name; test functions carrying it are
            marked as expected failures in the new class.

    Returns:
        The newly created class.
    """
    new_cls = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
    new_cls.__qualname__ = new_cls.__name__

    for attr_name in dir(cls):
        if not attr_name.startswith("test_"):
            # NB: Doesn't handle slots correctly, but whatever
            if not hasattr(new_cls, attr_name):
                setattr(new_cls, attr_name, getattr(cls, attr_name))
            continue

        test_fn = getattr(cls, attr_name)
        if not callable(test_fn):
            # A test_-prefixed attribute that is not a test: copy it verbatim.
            setattr(new_cls, attr_name, test_fn)
            continue

        wrapped_name = f"{attr_name}{fn_suffix}"
        wrapped = _make_fn_with_mocked_export(test_fn, mocked_export_fn)
        wrapped.__name__ = wrapped_name
        should_xfail = xfail_prop is not None and hasattr(test_fn, xfail_prop)
        if should_xfail:
            wrapped = unittest.expectedFailure(wrapped)
        setattr(new_cls, wrapped_name, wrapped)

    return new_cls
|
|
|
|
|
|
def _make_fn_with_mocked_export(fn, mocked_export_fn):
    """Return a wrapper that calls ``fn`` with ``test_export.export`` patched.

    On every call, ``unittest.mock.patch`` replaces the ``export`` attribute
    of the ``test_export`` module with ``mocked_export_fn``. It restores the
    original when the call returns or raises. ``functools.wraps`` copies
    ``fn``'s metadata (``__name__``, ``__qualname__``, ``__doc__``, ...) onto
    the wrapper and exposes ``fn`` as ``__wrapped__``.

    Args:
        fn: Function to wrap; called with whatever arguments the wrapper gets.
        mocked_export_fn: Object installed as ``test_export.export`` during
            each call.

    Returns:
        The wrapping function.
    """
    patch_target = "test_export.export"

    @functools.wraps(fn)
    def _call_with_mocked_export(*call_args, **call_kwargs):
        with patch(patch_target, new=mocked_export_fn):
            return fn(*call_args, **call_kwargs)

    return _call_with_mocked_export
|
|
|
|
|
|
# Controls tests generated in test/export/test_export_nonstrict.py
def expectedFailureNonStrict(fn):
    """Flag ``fn`` as an expected failure for the generated non-strict tests.

    Sets ``fn._expected_failure_non_strict = True`` and returns the same
    object, so the function works as a decorator. The flag is only a marker
    and has no effect on its own. Presumably the non-strict test module
    passes ``"_expected_failure_non_strict"`` as ``xfail_prop`` to
    ``make_test_cls_with_mocked_export``; verify against that file.

    Args:
        fn: Test function to mark.

    Returns:
        ``fn`` itself.
    """
    setattr(fn, "_expected_failure_non_strict", True)
    return fn
|