mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: We have tests that check package-level migration correctness for the torch AO migration. After reading the code, I noticed that these tests do not test anything beyond the function-level tests we already have. An upcoming user-warning PR would break these tests, and fixing them does not seem worthwhile. As long as the function-level tests pass, 100% of user functionality is tested. The removal is done in a separate PR to keep PRs small. Test plan: `python test/test_quantization.py -k AOMigration` Pull Request resolved: https://github.com/pytorch/pytorch/pull/94422 Approved by: https://github.com/jcaip
43 lines
2.1 KiB
Python
43 lines
2.1 KiB
Python
from torch.testing._internal.common_utils import TestCase
|
|
|
|
import importlib
|
|
from typing import List, Optional
|
|
|
|
class AOMigrationTestCase(TestCase):
    r"""Base class for tests of the ``torch.<base>`` -> ``torch.ao.<base>`` migration.

    Each helper imports a module from the legacy ``torch.<base>`` location and
    its counterpart under ``torch.ao.<base>``, then checks that the requested
    attribute names resolve to matching objects in both places.

    Checks are made with ``self.assertTrue`` instead of bare ``assert`` so they
    are not stripped when Python runs with ``-O``.
    """

    @staticmethod
    def _import_migration_modules(package_name: str, base: Optional[str] = None,
                                  new_package_name: Optional[str] = None):
        r"""Import and return ``(old_module, new_module)`` for a migrated package.

        Args:
            package_name: module path relative to ``torch.<base>``.
            base: sub-package of ``torch`` being migrated; defaults to
                ``'quantization'``.
            new_package_name: module path relative to ``torch.ao.<base>``;
                defaults to ``package_name``.
        """
        if base is None:
            base = 'quantization'
        if new_package_name is None:
            new_package_name = package_name
        old_location = importlib.import_module(f'torch.{base}.{package_name}')
        new_location = importlib.import_module(f'torch.ao.{base}.{new_package_name}')
        return old_location, new_location

    def _test_function_import(self, package_name: str, function_list: List[str],
                              base: Optional[str] = None, new_package_name: Optional[str] = None):
        r"""Tests individual function list import by comparing the functions
        and their hashes.

        Args:
            package_name: module path relative to ``torch.<base>``.
            function_list: attribute names to look up in both modules.
            base: sub-package of ``torch`` being migrated; defaults to
                ``'quantization'``.
            new_package_name: module path relative to ``torch.ao.<base>``;
                defaults to ``package_name``.

        Raises:
            AssertionError: if a name resolves to unequal objects, or to
                objects with different hashes, in the two locations.
        """
        old_location, new_location = self._import_migration_modules(
            package_name, base, new_package_name)
        for fn_name in function_list:
            old_function = getattr(old_location, fn_name)
            new_function = getattr(new_location, fn_name)
            self.assertTrue(old_function == new_function,
                            f"Functions don't match: {fn_name}")
            # hash() raises TypeError for unhashable values such as dicts;
            # use _test_dict_import for those.
            self.assertTrue(
                hash(old_function) == hash(new_function),
                f"Hashes don't match: {old_function}({hash(old_function)}) vs. "
                f"{new_function}({hash(new_function)})")

    def _test_dict_import(self, package_name: str, dict_list: List[str],
                          base: Optional[str] = None, new_package_name: Optional[str] = None):
        r"""Tests individual dict list import by comparing the dicts key by key.

        Args:
            package_name: module path relative to ``torch.<base>``.
            dict_list: names of dict attributes to look up in both modules.
            base: sub-package of ``torch`` being migrated; defaults to
                ``'quantization'``.
            new_package_name: module path relative to ``torch.ao.<base>``;
                defaults to ``package_name``.

        Raises:
            AssertionError: if the two dicts have different keys, differ in the
                value for some key, or otherwise compare unequal.
        """
        old_location, new_location = self._import_migration_modules(
            package_name, base, new_package_name)
        for dict_name in dict_list:
            old_dict = getattr(old_location, dict_name)
            new_dict = getattr(new_location, dict_name)
            # Compare keys and individual values first so a mismatch names the
            # offending key; a whole-dict equality check alone only reports
            # that the dicts differ.
            only_old = old_dict.keys() - new_dict.keys()
            only_new = new_dict.keys() - old_dict.keys()
            self.assertTrue(
                not only_old and not only_new,
                f"Dicts don't match: {dict_name} "
                f"(keys only in old: {sorted(map(repr, only_old))}, "
                f"keys only in new: {sorted(map(repr, only_new))})")
            for key, new_value in new_dict.items():
                self.assertTrue(old_dict[key] == new_value,
                                f"Dicts don't match: {dict_name} for key {key}")
            # Catch-all using the mapping's own equality.
            self.assertTrue(old_dict == new_dict, f"Dicts don't match: {dict_name}")
|