mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/53749 Split up tests into cases that cover specific functionality. Goals: 1. Avoid the omnibus test file mess (see: test_jit.py) by imposing early structure and deliberately avoiding a generic TestPackage test case. 2. Encourage testing of individual APIs and components by example. 3. Hide the fake modules we created for these tests in their own folder. You can either run the test files individually or still use test/test_package.py as before. This change also formats all the tests with isort + black. Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D26958535 Pulled By: suo fbshipit-source-id: 8a63048b95ca71f4f1aa94e53c48442686076034
232 lines
7.7 KiB
Python
232 lines
7.7 KiB
Python
from sys import version_info
|
|
from textwrap import dedent
|
|
from unittest import skipIf
|
|
|
|
from torch.package import (
|
|
DeniedModuleError,
|
|
EmptyMatchError,
|
|
PackageExporter,
|
|
PackageImporter,
|
|
)
|
|
from torch.testing._internal.common_utils import run_tests
|
|
|
|
try:
|
|
from .common import PackageTestCase
|
|
except ImportError:
|
|
# Support the case where we run this file directly.
|
|
from common import PackageTestCase # type: ignore
|
|
|
|
|
|
class TestDependencyAPI(PackageTestCase):
    """Dependency management API tests.

    Covers the PackageExporter dependency actions:
    - mock()
    - extern()
    - deny()
    """

    def _check_externed(self, filename):
        """Verify a package where ``module_a`` and ``package_a.subpackage`` were
        externed and ``package_a`` itself was saved into the archive.

        Shared by test_extern (explicit names) and test_extern_glob (glob
        patterns), which must produce packages with identical behavior.
        """
        hi = PackageImporter(filename)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        # An externed module must be the very same object the regular import
        # system returns.
        self.assertIs(module_a, module_a_im)
        # package_a was saved into the archive, so the importer loads its own
        # copy instead of reusing the regularly imported one.
        self.assertIsNot(package_a, package_a_im)
        # The packaged package_a's `subpackage` attribute refers to the
        # externed (shared) module object.
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def _check_mocked_subpackage(self, filename):
        """Verify that ``package_a.subpackage`` in the package at ``filename``
        was mocked: attribute lookup succeeds, but calling the result raises.

        Shared by test_mock (explicit names) and test_mock_glob (glob patterns).
        """
        hi = PackageImporter(filename)
        # Also import the real modules into this process. Presumably this
        # guards against the real modules leaking into the importer's mocked
        # versions (TODO confirm). The `_ =` lines only mark the imports as used.
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        # Looking up an attribute on the mock works; using it does not.
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    def test_extern(self):
        """Extern modules listed by exact name."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.extern(["package_a.subpackage", "module_a"])
            he.require_module("package_a.subpackage")
            he.require_module("module_a")
            he.save_module("package_a")
        self._check_externed(filename)

    def test_extern_glob(self):
        """Extern modules selected by glob patterns.

        "package_a.*" does not match "package_a" itself (see test_module_glob),
        so package_a is still saved while its subpackage is externed.
        """
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        self._check_externed(filename)

    def test_extern_glob_allow_empty(self):
        """
        Test that an error is thrown when an extern glob is specified with
        allow_empty=False and no matching module is required during packaging.
        """
        filename = self.temp()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(filename, verbose=False) as exporter:
                exporter.extern(include=["package_a.*"], allow_empty=False)
                # Only package_b is packaged, so the package_a.* glob never matches.
                exporter.save_module("package_b.subpackage")

    def test_deny(self):
        """
        Test marking packages as "deny" during export.
        """
        filename = self.temp()

        with self.assertRaisesRegex(
            DeniedModuleError,
            "required during packaging but has been explicitly blocklisted",
        ):
            with PackageExporter(filename, verbose=False) as exporter:
                exporter.deny(["package_a.subpackage", "module_a"])
                exporter.require_module("package_a.subpackage")

    def test_deny_glob(self):
        """
        Test marking packages as "deny" using globs instead of package names.
        """
        filename = self.temp()
        with self.assertRaisesRegex(
            DeniedModuleError,
            "required during packaging but has been explicitly blocklisted",
        ):
            with PackageExporter(filename, verbose=False) as exporter:
                exporter.deny(["package_a.*", "module_*"])
                # The saved source imports denied modules, which must be rejected.
                exporter.save_source_string(
                    "test_module",
                    dedent(
                        """\
                        import package_a.subpackage
                        import module_a
                        """
                    ),
                )

    # Mocked modules rely on __getattr__ (per the skip reason below);
    # module-level __getattr__ (PEP 562) is only available from Python 3.7.
    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock(self):
        """Mock modules listed by exact name."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock(["package_a.subpackage", "module_a"])
            he.save_module("package_a")
            he.require_module("package_a.subpackage")
            he.require_module("module_a")
        self._check_mocked_subpackage(filename)

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock_glob(self):
        """Mock modules selected by glob patterns."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock(["package_a.*", "module*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )
        self._check_mocked_subpackage(filename)

    def test_mock_glob_allow_empty(self):
        """
        Test that an error is thrown when a mock glob is specified with
        allow_empty=False and no matching module is required during packaging.
        """
        filename = self.temp()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(filename, verbose=False) as exporter:
                exporter.mock(include=["package_a.*"], allow_empty=False)
                # Only package_b is packaged, so the package_a.* glob never matches.
                exporter.save_module("package_b.subpackage")

    def test_module_glob(self):
        """Check module-name glob matching used by extern/mock/deny.

        The cases below pin down these semantics:
        - ``*`` matches within a single dotted segment ("torch.*" matches
          "torch.foo", but neither "torch" nor "torch.foo.bar"; "torch*" matches
          "torchvision", but not "torch.f").
        - ``**`` matches zero or more whole segments ("torch.**" matches "torch"
          and "torch.foo.bar"; "**.torch" matches "torch").
        - Exclude patterns remove modules that the include patterns matched.
        """
        from torch.package.package_exporter import _GlobGroup

        def check(include, exclude, should_match, should_not_match):
            x = _GlobGroup(include, exclude)
            for e in should_match:
                self.assertTrue(
                    x.matches(e),
                    "%r should match include=%r exclude=%r" % (e, include, exclude),
                )
            for e in should_not_match:
                self.assertFalse(
                    x.matches(e),
                    "%r should not match include=%r exclude=%r"
                    % (e, include, exclude),
                )

        check(
            "torch.*",
            [],
            ["torch.foo", "torch.bar"],
            ["tor.foo", "torch.foo.bar", "torch"],
        )
        check(
            "torch.**",
            [],
            ["torch.foo", "torch.bar", "torch.foo.bar", "torch"],
            ["what.torch", "torchvision"],
        )
        check("torch.*.foo", [], ["torch.w.foo"], ["torch.hi.bar.baz"])
        check(
            "torch.**.foo", [], ["torch.w.foo", "torch.hi.bar.foo"], ["torch.f.foo.z"]
        )
        check("torch*", [], ["torch", "torchvision"], ["torch.f"])
        check(
            "torch.**",
            ["torch.**.foo"],
            ["torch", "torch.bar", "torch.barfoo"],
            ["torch.foo", "torch.some.foo"],
        )
        check("**.torch", [], ["torch", "bar.torch"], ["visiontorch"])

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked(self):
        """Loading a pickle that references a mocked module's class must fail."""
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        # obj2 (from package_a) wraps an object whose class lives in the
        # subpackage that gets mocked below.
        obj2 = package_a.PackageAObject(obj)

        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock(include="package_a.subpackage")
            he.save_pickle("obj", "obj.pkl", obj2)

        hi = PackageImporter(filename)
        with self.assertRaises(NotImplementedError):
            hi.load_pickle("obj", "obj.pkl")
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly; run_tests is the shared test runner
    # imported from torch.testing._internal.common_utils.
    run_tests()
|