#!/usr/bin/env python3 import sys import io import unittest import torch import torch.utils.model_dump import torch.utils.mobile_optimizer from torch.testing._internal.common_utils import TestCase, run_tests from torch.testing._internal.common_quantized import supported_qengines class SimpleModel(torch.nn.Module): def __init__(self): super().__init__() self.layer1 = torch.nn.Linear(16, 64) self.relu1 = torch.nn.ReLU() self.layer2 = torch.nn.Linear(64, 8) self.relu2 = torch.nn.ReLU() def forward(self, features): act = features act = self.layer1(act) act = self.relu1(act) act = self.layer2(act) act = self.relu2(act) return act class QuantModel(torch.nn.Module): def __init__(self): super().__init__() self.quant = torch.quantization.QuantStub() self.dequant = torch.quantization.DeQuantStub() self.core = SimpleModel() def forward(self, x): x = self.quant(x) x = self.core(x) x = self.dequant(x) return x class ModelWithLists(torch.nn.Module): def __init__(self): super().__init__() self.rt = [torch.zeros(1)] self.ot = [torch.zeros(1), None] def forward(self, arg): arg = arg + self.rt[0] o = self.ot[0] if o is not None: arg = arg + o return arg class TestModelDump(TestCase): @unittest.skipIf(sys.version_info < (3, 7), "importlib.resources was new in 3.7") def test_inline_skeleton(self): skel = torch.utils.model_dump.get_inline_skeleton() assert "unpkg.org" not in skel assert "src=" not in skel def do_dump_model(self, model, extra_files=None): # Just check that we're able to run successfully. buf = io.BytesIO() torch.jit.save(model, buf, _extra_files=extra_files) info = torch.utils.model_dump.get_model_info(buf) assert info is not None def test_scripted_model(self): model = torch.jit.script(SimpleModel()) self.do_dump_model(model) def test_traced_model(self): model = torch.jit.trace(SimpleModel(), torch.zeros(2, 16)) self.do_dump_model(model) def get_quant_model(self): fmodel = QuantModel().eval() fmodel = torch.quantization.fuse_modules(fmodel, [ ["core.layer1", "core.relu1"], ["core.layer2", "core.relu2"], ]) fmodel.qconfig = torch.quantization.get_default_qconfig("qnnpack") prepped = torch.quantization.prepare(fmodel) prepped(torch.randn(2, 16)) qmodel = torch.quantization.convert(prepped) return qmodel @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available") def test_quantized_model(self): qmodel = self.get_quant_model() self.do_dump_model(torch.jit.script(qmodel)) @unittest.skipUnless("qnnpack" in supported_qengines, "QNNPACK not available") def test_optimized_quantized_model(self): qmodel = self.get_quant_model() smodel = torch.jit.trace(qmodel, torch.zeros(2, 16)) omodel = torch.utils.mobile_optimizer.optimize_for_mobile(smodel) self.do_dump_model(omodel) def test_model_with_lists(self): model = torch.jit.script(ModelWithLists()) self.do_dump_model(model) def test_invalid_json(self): model = torch.jit.script(SimpleModel()) self.do_dump_model(model, extra_files={"foo.json": "{"}) if __name__ == '__main__': run_tests()