pytorch/test/quantization/ao_migration/common.py
Vasiliy Kuznetsov 216f88d084 ao migration: remove package test as this behavior is tested by other things (#94422)
Summary:

We have tests testing package level migration correctness for torch AO migration.
After reading the code, I noticed that these tests are not testing anything
additional on top of the function level tests we already have.

An upcoming user warning PR will break this test, and it doesn't seem worth fixing.
As long as the function level tests pass, 100% of user functionality will
be tested.  Removing this in a separate PR to keep PRs small.

Test plan:

```
python test/test_quantization.py -k AOMigration
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94422
Approved by: https://github.com/jcaip
2023-02-13 16:33:40 +00:00

43 lines
2.1 KiB
Python

from torch.testing._internal.common_utils import TestCase
import importlib
from typing import List, Optional
class AOMigrationTestCase(TestCase):
def _test_function_import(self, package_name: str, function_list: List[str],
base: Optional[str] = None, new_package_name: Optional[str] = None):
r"""Tests individual function list import by comparing the functions
and their hashes."""
if base is None:
base = 'quantization'
old_base = 'torch.' + base
new_base = 'torch.ao.' + base
if new_package_name is None:
new_package_name = package_name
old_location = importlib.import_module(f'{old_base}.{package_name}')
new_location = importlib.import_module(f'{new_base}.{new_package_name}')
for fn_name in function_list:
old_function = getattr(old_location, fn_name)
new_function = getattr(new_location, fn_name)
assert old_function == new_function, f"Functions don't match: {fn_name}"
assert hash(old_function) == hash(new_function), \
f"Hashes don't match: {old_function}({hash(old_function)}) vs. " \
f"{new_function}({hash(new_function)})"
def _test_dict_import(self, package_name: str, dict_list: List[str],
base: Optional[str] = None):
r"""Tests individual function list import by comparing the functions
and their hashes."""
if base is None:
base = 'quantization'
old_base = 'torch.' + base
new_base = 'torch.ao.' + base
old_location = importlib.import_module(f'{old_base}.{package_name}')
new_location = importlib.import_module(f'{new_base}.{package_name}')
for dict_name in dict_list:
old_dict = getattr(old_location, dict_name)
new_dict = getattr(new_location, dict_name)
assert old_dict == new_dict, f"Dicts don't match: {dict_name}"
for key in new_dict.keys():
assert old_dict[key] == new_dict[key], f"Dicts don't match: {dict_name} for key {key}"