pytorch/test/export/test_tools.py
Sherlock Huang 0d5ba547ec Tool for scouting exportability in one shot (#126471)
Summary:
Tool for scouting exportability issues in one shot.

- Collect sample inputs for all submodules by running eager inference with forward_pre_hook.
- Starting from the root module, try to export it; if that export fails, recursively try exporting its child modules.

Limitations:
- Only works for an nn.Module with a tree-like submodule structure; it doesn't work for a flattened GraphModule.

TODO: support dynamic_dims

Sample output: https://docs.google.com/spreadsheets/d/1jnixrqBTYbWO_y6AaKA13XqOZmeB1MQAMuWL30dGoOg/edit?usp=sharing

```
exportability_report =
        {
            '': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_1': UnsupportedOperatorException(func=<OpOverload(op='testlib.op_missing_meta', overload='default')>),
            'submod_2': None
        }
```

Test Plan: buck2 run mode/dev-nosan fbcode//caffe2/test:test_export -- -r TestExportTools

Differential Revision: D57466486

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126471
Approved by: https://github.com/zhxchen17
2024-05-18 00:10:46 +00:00

68 lines
1.8 KiB
Python

# Owner(s): ["oncall: export"]
import torch
from torch._dynamo.test_case import TestCase
from torch._export.tools import report_exportability
from torch.testing._internal.common_utils import run_tests
# Register a custom op schema in the "testlib" namespace for the tests below.
# Both arguments are annotated as mutated (`a!`, `b!`) and the op is tagged
# pt2_compliant. However, this file registers only a CPU kernel for it and no
# meta/fake implementation. Exporting a module that calls this op is therefore
# expected to fail, which is the issue the exportability report should surface.
torch.library.define(
    "testlib::op_missing_meta",
    "(Tensor(a!) x, Tensor(b!) z) -> Tensor",
    tags=torch.Tag.pt2_compliant_tag,
)
@torch.library.impl("testlib::op_missing_meta", "cpu")
@torch._dynamo.disable
def op_missing_meta(x, z):
    """CPU kernel for ``testlib::op_missing_meta``.

    Increments both arguments by 5 in place, then returns the sum of the
    already-updated tensors. Dynamo is disabled for this Python kernel.
    No meta kernel is registered alongside it here, which is what makes
    the op unexportable in the tests.
    """
    # Mutate each argument in place, in the same order as the schema (x, then z).
    for tensor in (x, z):
        tensor.add_(5)
    return x + z
class TestExportTools(TestCase):
    """Tests for ``torch._export.tools.report_exportability``."""

    def test_report_exportability_basic(self):
        """A fully exportable module with no children yields one clean root entry."""

        class Module(torch.nn.Module):
            def forward(self, x, y):
                return x[0] + y

        f = Module()
        # The first positional input is a list, so the inputs are a nested structure.
        inp = ([torch.ones(1, 3)], torch.ones(1, 3))
        report = report_exportability(f, inp)
        # Module has no submodules and exports cleanly, so the report holds only
        # the root entry (keyed by "") and its value is None (no error).
        # assertEqual/assertIsNone report the actual values on failure,
        # unlike assertTrue(<bool expression>).
        self.assertEqual(len(report), 1)
        self.assertIsNone(report[""])

    def test_report_exportability_with_issues(self):
        """A failing submodule is reported while its exportable sibling is not."""

        class Unsupported(torch.nn.Module):
            def forward(self, x):
                # Only a CPU kernel is registered for this op in this file, so
                # exporting a module that calls it is expected to fail.
                return torch.ops.testlib.op_missing_meta(x, x.cos())

        class Supported(torch.nn.Module):
            def forward(self, x):
                return x.sin()

        class Module(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.unsupported = Unsupported()
                self.supported = Supported()

            def forward(self, x):
                y = torch.nonzero(x)
                return self.unsupported(y) + self.supported(y)

        f = Module()
        inp = (torch.ones(4, 4),)
        report = report_exportability(f, inp, strict=False, pre_dispatch=True)
        # The root is expected to fail because it calls the unsupported child.
        # The tool then tries each child on its own with the inputs recorded
        # during eager execution.
        self.assertIsNotNone(report[""])
        self.assertIsNotNone(report["unsupported"])
        self.assertIsNone(report["supported"])
# When executed as a script, run the test cases via run_tests from
# torch.testing._internal.common_utils.
if __name__ == "__main__":
    run_tests()