Enable UFMT on test/test_public_bindings.py (#128389)

Part of: https://github.com/pytorch/pytorch/issues/123062

Ran lintrunner on:
> test/test_public_bindings.py

Detail:
```
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

Co-authored-by: Edward Z. Yang <ezyang@fb.com>
Co-authored-by: Xuehai Pan <XuehaiPan@pku.edu.cn>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128389
Approved by: https://github.com/ezyang
This commit is contained in:
dilililiwhy 2024-06-30 08:49:51 +00:00 committed by PyTorch MergeBot
parent 4ee1cb9b95
commit fe5424d0f8
2 changed files with 114 additions and 62 deletions

View File

@ -1101,7 +1101,6 @@ exclude_patterns = [
'test/test_prims.py', 'test/test_prims.py',
'test/test_proxy_tensor.py', 'test/test_proxy_tensor.py',
'test/test_pruning_op.py', 'test/test_pruning_op.py',
'test/test_public_bindings.py',
'test/test_quantization.py', 'test/test_quantization.py',
'test/test_reductions.py', 'test/test_reductions.py',
'test/test_scatter_gather_ops.py', 'test/test_scatter_gather_ops.py',

View File

@ -30,8 +30,7 @@ def _find_all_importables(pkg):
return sorted( return sorted(
set( set(
chain.from_iterable( chain.from_iterable(
_discover_path_importables(Path(p), pkg.__name__) _discover_path_importables(Path(p), pkg.__name__) for p in pkg.__path__
for p in pkg.__path__
), ),
), ),
) )
@ -46,18 +45,17 @@ def _discover_path_importables(pkg_pth, pkg_name):
for dir_path, _d, file_names in os.walk(pkg_pth): for dir_path, _d, file_names in os.walk(pkg_pth):
pkg_dir_path = Path(dir_path) pkg_dir_path = Path(dir_path)
if pkg_dir_path.parts[-1] == '__pycache__': if pkg_dir_path.parts[-1] == "__pycache__":
continue continue
if all(Path(_).suffix != ".py" for _ in file_names):
if all(Path(_).suffix != '.py' for _ in file_names):
continue continue
rel_pt = pkg_dir_path.relative_to(pkg_pth) rel_pt = pkg_dir_path.relative_to(pkg_pth)
pkg_pref = '.'.join((pkg_name, ) + rel_pt.parts) pkg_pref = ".".join((pkg_name,) + rel_pt.parts)
yield from ( yield from (
pkg_path pkg_path
for _, pkg_path, _ in pkgutil.walk_packages( for _, pkg_path, _ in pkgutil.walk_packages(
(str(pkg_dir_path), ), prefix=f'{pkg_pref}.', (str(pkg_dir_path),),
prefix=f"{pkg_pref}.",
) )
) )
@ -72,9 +70,11 @@ class TestPublicBindings(TestCase):
reexported_callables = sorted( reexported_callables = sorted(
k k
for k, v in vars(torch).items() for k, v in vars(torch).items()
if callable(v) and not v.__module__.startswith('torch') if callable(v) and not v.__module__.startswith("torch")
)
self.assertTrue(
all(k.startswith("_") for k in reexported_callables), reexported_callables
) )
self.assertTrue(all(k.startswith('_') for k in reexported_callables), reexported_callables)
def test_no_new_bindings(self): def test_no_new_bindings(self):
""" """
@ -89,11 +89,11 @@ class TestPublicBindings(TestCase):
If you have removed a binding, remove it from the allowlist as well. If you have removed a binding, remove it from the allowlist as well.
""" """
# This allowlist contains every binding in torch._C that is copied into torch at # This allowlist contains every binding in torch._C that is copied into torch at
# the time of writing. It was generated with # the time of writing. It was generated with
# #
# {elem for elem in dir(torch._C) if not elem.startswith("_")} # {elem for elem in dir(torch._C) if not elem.startswith("_")}
#
torch_C_allowlist_superset = { torch_C_allowlist_superset = {
"AggregationType", "AggregationType",
"AliasDb", "AliasDb",
@ -264,8 +264,8 @@ class TestPublicBindings(TestCase):
"UnionType", "UnionType",
"Use", "Use",
"Value", "Value",
'set_autocast_gpu_dtype', "set_autocast_gpu_dtype",
'get_autocast_gpu_dtype', "get_autocast_gpu_dtype",
"vitals_enabled", "vitals_enabled",
"wait", "wait",
"Tag", "Tag",
@ -274,12 +274,14 @@ class TestPublicBindings(TestCase):
"get_autocast_xla_dtype", "get_autocast_xla_dtype",
"is_autocast_xla_enabled", "is_autocast_xla_enabled",
} }
torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")} torch_C_bindings = {elem for elem in dir(torch._C) if not elem.startswith("_")}
# torch.TensorBase is explicitly removed in torch/__init__.py, so included here (#109940) # torch.TensorBase is explicitly removed in torch/__init__.py, so included here (#109940)
explicitly_removed_torch_C_bindings = { explicitly_removed_torch_C_bindings = {
"TensorBase", "TensorBase",
} }
torch_C_bindings = torch_C_bindings - explicitly_removed_torch_C_bindings torch_C_bindings = torch_C_bindings - explicitly_removed_torch_C_bindings
# Check that the torch._C bindings are all in the allowlist. Since # Check that the torch._C bindings are all in the allowlist. Since
@ -292,13 +294,16 @@ class TestPublicBindings(TestCase):
@staticmethod @staticmethod
def _is_mod_public(modname): def _is_mod_public(modname):
split_strs = modname.split('.') split_strs = modname.split(".")
for elem in split_strs: for elem in split_strs:
if elem.startswith("_"): if elem.startswith("_"):
return False return False
return True return True
@unittest.skipIf(IS_WINDOWS or IS_MACOS, "Inductor/Distributed modules hard fail on windows and macos") @unittest.skipIf(
IS_WINDOWS or IS_MACOS,
"Inductor/Distributed modules hard fail on windows and macos",
)
@skipIfTorchDynamo("Broken and not relevant for now") @skipIfTorchDynamo("Broken and not relevant for now")
def test_modules_can_be_imported(self): def test_modules_can_be_imported(self):
failures = [] failures = []
@ -311,6 +316,7 @@ class TestPublicBindings(TestCase):
importlib.import_module(modname) importlib.import_module(modname)
except Exception as e: except Exception as e:
# Some current failures are not ImportError # Some current failures are not ImportError
failures.append((modname, type(e))) failures.append((modname, type(e)))
# It is ok to add new entries here but please be careful that these modules # It is ok to add new entries here but please be careful that these modules
@ -452,20 +458,18 @@ class TestPublicBindings(TestCase):
for mod, excep_type in failures: for mod, excep_type in failures:
if mod in public_allowlist: if mod in public_allowlist:
# TODO: Ensure this is the right error type # TODO: Ensure this is the right error type
continue
continue
if mod in private_allowlist: if mod in private_allowlist:
continue continue
errors.append(f"{mod} failed to import with error {excep_type}") errors.append(f"{mod} failed to import with error {excep_type}")
self.assertEqual("", "\n".join(errors)) self.assertEqual("", "\n".join(errors))
# AttributeError: module 'torch.distributed' has no attribute '_shard' # AttributeError: module 'torch.distributed' has no attribute '_shard'
@unittest.skipIf(IS_WINDOWS or IS_JETSON or IS_MACOS, "Distributed Attribute Error") @unittest.skipIf(IS_WINDOWS or IS_JETSON or IS_MACOS, "Distributed Attribute Error")
@skipIfTorchDynamo("Broken and not relevant for now") @skipIfTorchDynamo("Broken and not relevant for now")
def test_correct_module_names(self): def test_correct_module_names(self):
''' """
An API is considered public, if its `__module__` starts with `torch.` An API is considered public, if its `__module__` starts with `torch.`
and there is no name in `__module__` or the object itself that starts with "_". and there is no name in `__module__` or the object itself that starts with "_".
Each public package should either: Each public package should either:
@ -474,18 +478,25 @@ class TestPublicBindings(TestCase):
NOT have their `__module__` start with the current submodule. NOT have their `__module__` start with the current submodule.
- (for simple python-only modules) Not define `__all__` and all the elements in `dir(submod)` must have their - (for simple python-only modules) Not define `__all__` and all the elements in `dir(submod)` must have their
`__module__` that start with the current submodule. `__module__` that start with the current submodule.
''' """
failure_list = [] failure_list = []
with open(get_file_path_2(os.path.dirname(__file__), 'allowlist_for_publicAPI.json')) as json_file: with open(
get_file_path_2(os.path.dirname(__file__), "allowlist_for_publicAPI.json")
) as json_file:
# no new entries should be added to this allow_dict. # no new entries should be added to this allow_dict.
# New APIs must follow the public API guidelines. # New APIs must follow the public API guidelines.
allow_dict = json.load(json_file) allow_dict = json.load(json_file)
# Because we want minimal modifications to the `allowlist_for_publicAPI.json`, # Because we want minimal modifications to the `allowlist_for_publicAPI.json`,
# we are adding the entries for the migrated modules here from the original # we are adding the entries for the migrated modules here from the original
# locations. # locations.
for modname in allow_dict["being_migrated"]: for modname in allow_dict["being_migrated"]:
if modname in allow_dict: if modname in allow_dict:
allow_dict[allow_dict["being_migrated"][modname]] = allow_dict[modname] allow_dict[allow_dict["being_migrated"][modname]] = allow_dict[
modname
]
def test_module(modname): def test_module(modname):
try: try:
@ -495,96 +506,137 @@ class TestPublicBindings(TestCase):
except Exception: except Exception:
# It is ok to ignore here as we have a test above that ensures # It is ok to ignore here as we have a test above that ensures
# this should never happen # this should never happen
return
return
if not self._is_mod_public(modname): if not self._is_mod_public(modname):
return return
# verifies that each public API has the correct module name and naming semantics # verifies that each public API has the correct module name and naming semantics
def check_one_element(elem, modname, mod, *, is_public, is_all): def check_one_element(elem, modname, mod, *, is_public, is_all):
obj = getattr(mod, elem) obj = getattr(mod, elem)
# torch.dtype is not a class nor callable, so we need to check for it separately # torch.dtype is not a class nor callable, so we need to check for it separately
if not (isinstance(obj, (Callable, torch.dtype)) or inspect.isclass(obj)): if not (
isinstance(obj, (Callable, torch.dtype)) or inspect.isclass(obj)
):
return return
elem_module = getattr(obj, '__module__', None) elem_module = getattr(obj, "__module__", None)
# Only used for nice error message below # Only used for nice error message below
why_not_looks_public = "" why_not_looks_public = ""
if elem_module is None: if elem_module is None:
why_not_looks_public = "because it does not have a `__module__` attribute" why_not_looks_public = (
"because it does not have a `__module__` attribute"
)
# If a module is being migrated from foo.a to bar.a (that is entry {"foo": "bar"}), # If a module is being migrated from foo.a to bar.a (that is entry {"foo": "bar"}),
# the module's starting package would be referred to as the new location even # the module's starting package would be referred to as the new location even
# if there is a "from foo import a" inside the "bar.py". # if there is a "from foo import a" inside the "bar.py".
modname = allow_dict["being_migrated"].get(modname, modname) modname = allow_dict["being_migrated"].get(modname, modname)
elem_modname_starts_with_mod = elem_module is not None and \ elem_modname_starts_with_mod = (
elem_module.startswith(modname) and \ elem_module is not None
'._' not in elem_module and elem_module.startswith(modname)
and "._" not in elem_module
)
if not why_not_looks_public and not elem_modname_starts_with_mod: if not why_not_looks_public and not elem_modname_starts_with_mod:
why_not_looks_public = f"because its `__module__` attribute (`{elem_module}`) is not within the " \ why_not_looks_public = (
f"because its `__module__` attribute (`{elem_module}`) is not within the "
f"torch library or does not start with the submodule where it is defined (`{modname}`)" f"torch library or does not start with the submodule where it is defined (`{modname}`)"
)
# elem's name must NOT begin with an `_` and it's module name # elem's name must NOT begin with an `_` and it's module name
# SHOULD start with it's current module since it's a public API # SHOULD start with it's current module since it's a public API
looks_public = not elem.startswith('_') and elem_modname_starts_with_mod looks_public = not elem.startswith("_") and elem_modname_starts_with_mod
if not why_not_looks_public and not looks_public: if not why_not_looks_public and not looks_public:
why_not_looks_public = f"because it starts with `_` (`{elem}`)" why_not_looks_public = f"because it starts with `_` (`{elem}`)"
if is_public != looks_public: if is_public != looks_public:
if modname in allow_dict and elem in allow_dict[modname]: if modname in allow_dict and elem in allow_dict[modname]:
return return
if is_public: if is_public:
why_is_public = f"it is inside the module's (`{modname}`) `__all__`" if is_all else \ why_is_public = (
"it is an attribute that does not start with `_` on a module that " \ f"it is inside the module's (`{modname}`) `__all__`"
if is_all
else "it is an attribute that does not start with `_` on a module that "
"does not have `__all__` defined" "does not have `__all__` defined"
fix_is_public = f"remove it from the modules's (`{modname}`) `__all__`" if is_all else \ )
f"either define a `__all__` for `{modname}` or add a `_` at the beginning of the name" fix_is_public = (
f"remove it from the modules's (`{modname}`) `__all__`"
if is_all
else f"either define a `__all__` for `{modname}` or add a `_` at the beginning of the name"
)
else: else:
assert is_all assert is_all
why_is_public = f"it is not inside the module's (`{modname}`) `__all__`" why_is_public = (
fix_is_public = f"add it from the modules's (`{modname}`) `__all__`" f"it is not inside the module's (`{modname}`) `__all__`"
)
fix_is_public = (
f"add it from the modules's (`{modname}`) `__all__`"
)
if looks_public: if looks_public:
why_looks_public = "it does look public because it follows the rules from the doc above " \ why_looks_public = (
"it does look public because it follows the rules from the doc above "
"(does not start with `_` and has a proper `__module__`)." "(does not start with `_` and has a proper `__module__`)."
)
fix_looks_public = "make its name start with `_`" fix_looks_public = "make its name start with `_`"
else: else:
why_looks_public = why_not_looks_public why_looks_public = why_not_looks_public
if not elem_modname_starts_with_mod: if not elem_modname_starts_with_mod:
fix_looks_public = "make sure the `__module__` is properly set and points to a submodule "\ fix_looks_public = (
"make sure the `__module__` is properly set and points to a submodule "
f"of `{modname}`" f"of `{modname}`"
)
else: else:
fix_looks_public = "remove the `_` at the beginning of the name" fix_looks_public = (
"remove the `_` at the beginning of the name"
)
failure_list.append(f"# {modname}.{elem}:") failure_list.append(f"# {modname}.{elem}:")
is_public_str = "" if is_public else " NOT" is_public_str = "" if is_public else " NOT"
failure_list.append(f" - Is{is_public_str} public: {why_is_public}") failure_list.append(
f" - Is{is_public_str} public: {why_is_public}"
)
looks_public_str = "" if looks_public else " NOT" looks_public_str = "" if looks_public else " NOT"
failure_list.append(f" - Does{looks_public_str} look public: {why_looks_public}") failure_list.append(
f" - Does{looks_public_str} look public: {why_looks_public}"
)
# Swap the str below to avoid having to create the NOT again # Swap the str below to avoid having to create the NOT again
failure_list.append(" - You can do either of these two things to fix this problem:") failure_list.append(
failure_list.append(f" - To make it{looks_public_str} public: {fix_is_public}") " - You can do either of these two things to fix this problem:"
failure_list.append(f" - To make it{is_public_str} look public: {fix_looks_public}") )
failure_list.append(
f" - To make it{looks_public_str} public: {fix_is_public}"
)
failure_list.append(
f" - To make it{is_public_str} look public: {fix_looks_public}"
)
if hasattr(mod, '__all__'): if hasattr(mod, "__all__"):
public_api = mod.__all__ public_api = mod.__all__
all_api = dir(mod) all_api = dir(mod)
for elem in all_api: for elem in all_api:
check_one_element(elem, modname, mod, is_public=elem in public_api, is_all=True) check_one_element(
elem, modname, mod, is_public=elem in public_api, is_all=True
)
else: else:
all_api = dir(mod) all_api = dir(mod)
for elem in all_api: for elem in all_api:
if not elem.startswith('_'): if not elem.startswith("_"):
check_one_element(elem, modname, mod, is_public=True, is_all=False) check_one_element(
elem, modname, mod, is_public=True, is_all=False
)
for modname in _find_all_importables(torch): for modname in _find_all_importables(torch):
test_module(modname) test_module(modname)
test_module("torch")
test_module('torch') msg = (
"All the APIs below do not meet our guidelines for public API from "
msg = "All the APIs below do not meet our guidelines for public API from " \
"https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation.\n" "https://github.com/pytorch/pytorch/wiki/Public-API-definition-and-documentation.\n"
msg += "Make sure that everything that is public is expected (in particular that the module " \ )
"has a properly populated `__all__` attribute) and that everything that is supposed to be public " \ msg += (
"Make sure that everything that is public is expected (in particular that the module "
"has a properly populated `__all__` attribute) and that everything that is supposed to be public "
"does look public (it does not start with `_` and has a `__module__` that is properly populated)." "does look public (it does not start with `_` and has a `__module__` that is properly populated)."
)
msg += "\n\nFull list:\n" msg += "\n\nFull list:\n"
msg += "\n".join(map(str, failure_list)) msg += "\n".join(map(str, failure_list))
@ -592,5 +644,6 @@ class TestPublicBindings(TestCase):
# empty lists are considered false in python # empty lists are considered false in python
self.assertTrue(not failure_list, msg) self.assertTrue(not failure_list, msg)
if __name__ == '__main__':
if __name__ == "__main__":
run_tests() run_tests()