mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
reland "[sigmoid] Test OSS model runner with test_export.py" (#147535)
Summary: There are ~260 tests for all the corner cases of export from test_export.py. utitlizing to test sigmoid in the OSS setting. Test Plan: buck test mode/opt caffe2/test:test_export -- -r _sigmoid Differential Revision: D69937387 Pull Request resolved: https://github.com/pytorch/pytorch/pull/147535 Approved by: https://github.com/yiming0416
This commit is contained in:
parent
87e6e2924e
commit
fdb1305ace
|
|
@ -5090,22 +5090,22 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
|
||||||
return x0, x1, x2
|
return x0, x1, x2
|
||||||
|
|
||||||
inps = (
|
inps = (
|
||||||
[
|
(
|
||||||
{"data": torch.randn(4, 4)},
|
{"data": torch.randn(4, 4)},
|
||||||
torch.randn(4, 4),
|
torch.randn(4, 4),
|
||||||
torch.randn(6, 4),
|
torch.randn(6, 4),
|
||||||
],
|
),
|
||||||
{
|
{
|
||||||
"a": torch.randn(8, 4),
|
"a": torch.randn(8, 4),
|
||||||
"b": torch.randn(9, 6),
|
"b": torch.randn(9, 6),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
dynamic_shapes = {
|
dynamic_shapes = {
|
||||||
"x": [
|
"x": (
|
||||||
{"data": (Dim("dx00"), Dim("dx01"))},
|
{"data": (Dim("dx00"), Dim("dx01"))},
|
||||||
(Dim("dx10"), Dim("dx11")),
|
(Dim("dx10"), Dim("dx11")),
|
||||||
(Dim("dx20"), Dim("dx21")),
|
(Dim("dx20"), Dim("dx21")),
|
||||||
],
|
),
|
||||||
"y": {
|
"y": {
|
||||||
"a": (Dim("dya0"), Dim("dya1")),
|
"a": (Dim("dya0"), Dim("dya1")),
|
||||||
"b": (Dim("dyb0"), Dim("dyb1")),
|
"b": (Dim("dyb0"), Dim("dyb1")),
|
||||||
|
|
@ -5984,7 +5984,7 @@ def forward(self, x):
|
||||||
a = x.item()
|
a = x.item()
|
||||||
torch._check(a >= 4)
|
torch._check(a >= 4)
|
||||||
torch._check(a <= 7)
|
torch._check(a <= 7)
|
||||||
return torch.empty((a, 4))
|
return torch.randn((a, 4))
|
||||||
|
|
||||||
f = Module()
|
f = Module()
|
||||||
ep = export(f, (torch.tensor([5]),))
|
ep = export(f, (torch.tensor([5]),))
|
||||||
|
|
@ -6012,9 +6012,9 @@ def forward(self, x):
|
||||||
a = x.item()
|
a = x.item()
|
||||||
torch._check(a >= 4)
|
torch._check(a >= 4)
|
||||||
torch._check(a <= 7)
|
torch._check(a <= 7)
|
||||||
empty = torch.empty((a, 4))
|
randn = torch.randn((a, 4))
|
||||||
|
|
||||||
return torch.cat((empty.transpose(0, 1), torch.zeros(6, a)), 0)
|
return torch.cat((randn.transpose(0, 1), torch.zeros(6, a)), 0)
|
||||||
|
|
||||||
f = Module()
|
f = Module()
|
||||||
ep = export(f, (torch.tensor([6]),))
|
ep = export(f, (torch.tensor([6]),))
|
||||||
|
|
@ -6397,6 +6397,7 @@ def forward(self, b_a_buffer, x):
|
||||||
)
|
)
|
||||||
|
|
||||||
@requires_cuda
|
@requires_cuda
|
||||||
|
@testing.expectedFailureCppRuntime
|
||||||
def test_export_associative_scan_symbol_dim(self):
|
def test_export_associative_scan_symbol_dim(self):
|
||||||
dim1 = torch.export.Dim("dim0", min=5, max=15)
|
dim1 = torch.export.Dim("dim0", min=5, max=15)
|
||||||
xs = torch.ones(3, 10, 2, device=torch.device("cuda"))
|
xs = torch.ones(3, 10, 2, device=torch.device("cuda"))
|
||||||
|
|
@ -6415,6 +6416,7 @@ def forward(self, b_a_buffer, x):
|
||||||
self.assertTrue(torch.allclose(ep.module()(xs), Foo()(xs)))
|
self.assertTrue(torch.allclose(ep.module()(xs), Foo()(xs)))
|
||||||
|
|
||||||
@requires_cuda
|
@requires_cuda
|
||||||
|
@testing.expectedFailureCppRuntime
|
||||||
def test_export_associative_scan_symbol_scandim(self):
|
def test_export_associative_scan_symbol_scandim(self):
|
||||||
dim1 = torch.export.Dim("dim0", min=5, max=15)
|
dim1 = torch.export.Dim("dim0", min=5, max=15)
|
||||||
xs = torch.ones(3, 10, 2, device=torch.device("cuda"))
|
xs = torch.ones(3, 10, 2, device=torch.device("cuda"))
|
||||||
|
|
@ -12016,6 +12018,7 @@ class GraphModule(torch.nn.Module):
|
||||||
self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))
|
self.assertTrue(torch.allclose(ep.module()(*inp), m(*inp)))
|
||||||
|
|
||||||
@unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
|
@unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
|
||||||
|
@testing.expectedFailureCppRuntime
|
||||||
def test_distributed_all_to_all_single(self):
|
def test_distributed_all_to_all_single(self):
|
||||||
class Foo(torch.nn.Module):
|
class Foo(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
@ -12033,6 +12036,7 @@ class GraphModule(torch.nn.Module):
|
||||||
self.assertEqual(len(nodes), 1)
|
self.assertEqual(len(nodes), 1)
|
||||||
|
|
||||||
@unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
|
@unittest.skipIf(IS_MACOS, "Distributed not packaged in macos")
|
||||||
|
@testing.expectedFailureCppRuntime
|
||||||
def test_distributed_reduce_scatter_tensor(self):
|
def test_distributed_reduce_scatter_tensor(self):
|
||||||
class Foo(torch.nn.Module):
|
class Foo(torch.nn.Module):
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
|
|
|
||||||
|
|
@ -196,7 +196,12 @@ _COMPOSITE_OPS_THAT_CAN_BE_PRESERVED_TESTING_ONLY = [
|
||||||
|
|
||||||
|
|
||||||
def make_test_cls_with_mocked_export(
|
def make_test_cls_with_mocked_export(
|
||||||
cls, cls_prefix, fn_suffix, mocked_export_fn, xfail_prop=None
|
cls,
|
||||||
|
cls_prefix,
|
||||||
|
fn_suffix,
|
||||||
|
mocked_export_fn,
|
||||||
|
xfail_prop=None,
|
||||||
|
test_only_if_no_xfail=False,
|
||||||
):
|
):
|
||||||
MockedTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
|
MockedTestClass = type(f"{cls_prefix}{cls.__name__}", cls.__bases__, {})
|
||||||
MockedTestClass.__qualname__ = MockedTestClass.__name__
|
MockedTestClass.__qualname__ = MockedTestClass.__name__
|
||||||
|
|
@ -212,6 +217,12 @@ def make_test_cls_with_mocked_export(
|
||||||
new_fn.__name__ = new_name
|
new_fn.__name__ = new_name
|
||||||
if xfail_prop is not None and hasattr(fn, xfail_prop):
|
if xfail_prop is not None and hasattr(fn, xfail_prop):
|
||||||
new_fn = unittest.expectedFailure(new_fn)
|
new_fn = unittest.expectedFailure(new_fn)
|
||||||
|
elif test_only_if_no_xfail and any(
|
||||||
|
x.startswith("_expected_failure") for x in dir(fn)
|
||||||
|
):
|
||||||
|
new_fn = unittest.skip(
|
||||||
|
"Will only be tested if no other tests are failing"
|
||||||
|
)(new_fn)
|
||||||
setattr(MockedTestClass, new_name, new_fn)
|
setattr(MockedTestClass, new_name, new_fn)
|
||||||
# NB: Doesn't handle slots correctly, but whatever
|
# NB: Doesn't handle slots correctly, but whatever
|
||||||
elif not hasattr(MockedTestClass, name):
|
elif not hasattr(MockedTestClass, name):
|
||||||
|
|
@ -291,6 +302,11 @@ def expectedFailureCppSerDes(fn):
|
||||||
return fn
|
return fn
|
||||||
|
|
||||||
|
|
||||||
|
def expectedFailureCppRuntime(fn):
|
||||||
|
fn._expected_failure_cpp_runtime = True
|
||||||
|
return fn
|
||||||
|
|
||||||
|
|
||||||
# Controls tests generated in test/export/test_export_legacy.py
|
# Controls tests generated in test/export/test_export_legacy.py
|
||||||
def expectedFailureLegacyExportStrict(fn):
|
def expectedFailureLegacyExportStrict(fn):
|
||||||
fn._expected_failure_legacy_export = True
|
fn._expected_failure_legacy_export = True
|
||||||
|
|
|
||||||
|
|
@ -1530,7 +1530,13 @@ class ExportedProgram:
|
||||||
|
|
||||||
# TODO(zhxchen17) Formalize this.
|
# TODO(zhxchen17) Formalize this.
|
||||||
def _update(
|
def _update(
|
||||||
self, graph_module, graph_signature, *, state_dict=None, verifiers=None
|
self,
|
||||||
|
graph_module,
|
||||||
|
graph_signature,
|
||||||
|
*,
|
||||||
|
state_dict=None,
|
||||||
|
constants=None,
|
||||||
|
verifiers=None,
|
||||||
) -> "ExportedProgram":
|
) -> "ExportedProgram":
|
||||||
return ExportedProgram(
|
return ExportedProgram(
|
||||||
root=graph_module,
|
root=graph_module,
|
||||||
|
|
@ -1540,7 +1546,7 @@ class ExportedProgram:
|
||||||
range_constraints=copy.deepcopy(self.range_constraints),
|
range_constraints=copy.deepcopy(self.range_constraints),
|
||||||
module_call_graph=copy.deepcopy(self._module_call_graph),
|
module_call_graph=copy.deepcopy(self._module_call_graph),
|
||||||
example_inputs=self.example_inputs,
|
example_inputs=self.example_inputs,
|
||||||
constants=self.constants,
|
constants=constants if constants is not None else self.constants,
|
||||||
verifiers=verifiers if verifiers is not None else self.verifiers,
|
verifiers=verifiers if verifiers is not None else self.verifiers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user