pytorch/test/package/test_dependency_api.py
Michael Suo 741d0f41d6 [package] split tests (#53749)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53749

Split up tests into cases that cover specific functionality. Goals:
1. Avoid the omnibus test file mess (see: test_jit.py) by imposing early
   structure and deliberately avoiding a generic TestPackage test case.
2. Encourage testing of individual APIs and components by example.
3. Hide the fake modules we created for these tests in their own folder.

You can either run the test files individually, or still use
test/test_package.py like before.

This change also formats all the tests with isort + black.

Test Plan: Imported from OSS

Reviewed By: SplitInfinity

Differential Revision: D26958535

Pulled By: suo

fbshipit-source-id: 8a63048b95ca71f4f1aa94e53c48442686076034
2021-03-10 16:07:36 -08:00

232 lines
7.7 KiB
Python

from sys import version_info
from textwrap import dedent
from unittest import skipIf
from torch.package import (
DeniedModuleError,
EmptyMatchError,
PackageExporter,
PackageImporter,
)
from torch.testing._internal.common_utils import run_tests
try:
from .common import PackageTestCase
except ImportError:
# Support the case where we run this file directly.
from common import PackageTestCase # type: ignore
class TestDependencyAPI(PackageTestCase):
    """Dependency management API tests.

    Covers the PackageExporter dependency actions:
    - mock()
    - extern()
    - deny()
    """

    def test_extern(self):
        """
        Test that modules marked "extern" are resolved from the host environment,
        while modules saved into the package are loaded as separate copies.
        """
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.extern(["package_a.subpackage", "module_a"])
            he.require_module("package_a.subpackage")
            he.require_module("module_a")
            he.save_module("package_a")

        hi = PackageImporter(filename)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        # Externed modules are the very same objects the host process imported...
        self.assertIs(module_a, module_a_im)
        # ...whereas package_a was saved into the package, so the importer
        # yields a distinct module object for it.
        self.assertIsNot(package_a, package_a_im)
        # The packaged package_a still sees the externed (host) subpackage.
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob(self):
        """
        Test marking modules as "extern" using globs instead of module names.
        The dependencies are discovered from the imports in a saved source string.
        """
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.extern(["package_a.*", "module_*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )

        hi = PackageImporter(filename)
        import module_a
        import package_a.subpackage

        module_a_im = hi.import_module("module_a")
        hi.import_module("package_a.subpackage")
        package_a_im = hi.import_module("package_a")

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    def test_extern_glob_allow_empty(self):
        """
        Test that an error is thrown when an extern glob is specified with
        allow_empty=False and no matching module is required during packaging.
        """
        filename = self.temp()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(filename, verbose=False) as exporter:
                exporter.extern(include=["package_a.*"], allow_empty=False)
                # package_b.subpackage is expected not to pull in anything
                # matching "package_a.*", so the glob above stays unmatched.
                exporter.save_module("package_b.subpackage")

    def test_deny(self):
        """
        Test marking packages as "deny" during export.
        """
        filename = self.temp()
        with self.assertRaisesRegex(
            DeniedModuleError,
            "required during packaging but has been explicitly blocklisted",
        ):
            with PackageExporter(filename, verbose=False) as exporter:
                exporter.deny(["package_a.subpackage", "module_a"])
                exporter.require_module("package_a.subpackage")

    def test_deny_glob(self):
        """
        Test marking packages as "deny" using globs instead of package names.
        """
        filename = self.temp()
        with self.assertRaisesRegex(
            DeniedModuleError,
            "required during packaging but has been explicitly blocklisted",
        ):
            with PackageExporter(filename, verbose=False) as exporter:
                exporter.deny(["package_a.*", "module_*"])
                exporter.save_source_string(
                    "test_module",
                    dedent(
                        """\
                        import package_a.subpackage
                        import module_a
                        """
                    ),
                )

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock(self):
        """
        Test that attributes of a mocked module raise NotImplementedError
        ("was mocked out") when called after re-import from the package.
        """
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock(["package_a.subpackage", "module_a"])
            he.save_module("package_a")
            he.require_module("package_a.subpackage")
            he.require_module("module_a")

        hi = PackageImporter(filename)
        # The `_ = ...` assignments only mark these host imports as used.
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_mock_glob(self):
        """
        Test marking modules as "mock" using globs instead of module names.
        """
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock(["package_a.*", "module*"])
            he.save_module("package_a")
            he.save_source_string(
                "test_module",
                dedent(
                    """\
                    import package_a.subpackage
                    import module_a
                    """
                ),
            )

        hi = PackageImporter(filename)
        # The `_ = ...` assignments only mark these host imports as used.
        import package_a.subpackage

        _ = package_a.subpackage
        import module_a

        _ = module_a

        m = hi.import_module("package_a.subpackage")
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, "was mocked out"):
            r()

    def test_mock_glob_allow_empty(self):
        """
        Test that an error is thrown when a mock glob is specified with
        allow_empty=False and no matching module is required during packaging.
        """
        filename = self.temp()
        with self.assertRaisesRegex(EmptyMatchError, r"did not match any modules"):
            with PackageExporter(filename, verbose=False) as exporter:
                exporter.mock(include=["package_a.*"], allow_empty=False)
                # package_b.subpackage is expected not to pull in anything
                # matching "package_a.*", so the glob above stays unmatched.
                exporter.save_module("package_b.subpackage")

    def test_module_glob(self):
        """
        Test the module-name glob semantics used by the dependency APIs.

        The cases below pin the behavior: "*" matches exactly one dotted
        segment, "**" matches zero or more segments, a "*" inside a segment
        matches within that segment only, and exclude patterns take
        precedence over include patterns.
        """
        from torch.package.package_exporter import _GlobGroup

        def check(include, exclude, should_match, should_not_match):
            x = _GlobGroup(include, exclude)
            for e in should_match:
                self.assertTrue(x.matches(e))
            for e in should_not_match:
                self.assertFalse(x.matches(e))

        check(
            "torch.*",
            [],
            ["torch.foo", "torch.bar"],
            ["tor.foo", "torch.foo.bar", "torch"],
        )
        check(
            "torch.**",
            [],
            ["torch.foo", "torch.bar", "torch.foo.bar", "torch"],
            ["what.torch", "torchvision"],
        )
        check("torch.*.foo", [], ["torch.w.foo"], ["torch.hi.bar.baz"])
        check(
            "torch.**.foo", [], ["torch.w.foo", "torch.hi.bar.foo"], ["torch.f.foo.z"]
        )
        check("torch*", [], ["torch", "torchvision"], ["torch.f"])
        check(
            "torch.**",
            ["torch.**.foo"],
            ["torch", "torch.bar", "torch.barfoo"],
            ["torch.foo", "torch.some.foo"],
        )
        check("**.torch", [], ["torch", "bar.torch"], ["visiontorch"])

    @skipIf(version_info < (3, 7), "mock uses __getattr__ a 3.7 feature")
    def test_pickle_mocked(self):
        """
        Test that unpickling an object whose class lives in a mocked module
        raises NotImplementedError.
        """
        import package_a.subpackage

        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock(include="package_a.subpackage")
            he.save_pickle("obj", "obj.pkl", obj2)

        hi = PackageImporter(filename)
        with self.assertRaises(NotImplementedError):
            hi.load_pickle("obj", "obj.pkl")
if "__main__" == __name__:
    # Executed directly as a script: delegate to the shared PyTorch test runner.
    run_tests()