mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Make public binding test only consider files that are packaged in the wheels (#130497)
In particular, when creating the PyTorch wheel, we use setuptools find_packages 551b3c6dca/setup.py (L1055) which explicitly skips packages without `__init__.py` files (namespace packages) https://setuptools.pypa.io/en/latest/userguide/package_discovery.html#finding-simple-packages.
So this PR is reverting the change to stop skipping these namespace packages as, even though they are in the codebase, they are not in the published binaries and so we're ok relaxing the public API and importability rules for them.
A manual diff of the two traversal methods:
```
torch._inductor.kernel.bmm
torch._inductor.kernel.conv
torch._inductor.kernel.flex_attention
torch._inductor.kernel.mm
torch._inductor.kernel.mm_common
torch._inductor.kernel.mm_plus_mm
torch._inductor.kernel.unpack_mixed_mm
torch._strobelight.examples.cli_function_profiler_example
torch._strobelight.examples.compile_time_profile_example
torch.ao.pruning._experimental.data_sparsifier.benchmarks.dlrm_utils
torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_disk_savings
torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_forward_time
torch.ao.pruning._experimental.data_sparsifier.benchmarks.evaluate_model_metrics
torch.ao.pruning._experimental.data_sparsifier.lightning.tests.test_callbacks
torch.ao.quantization.experimental.APoT_tensor
torch.ao.quantization.experimental.adaround_fake_quantize
torch.ao.quantization.experimental.adaround_loss
torch.ao.quantization.experimental.adaround_optimization
torch.ao.quantization.experimental.apot_utils
torch.ao.quantization.experimental.fake_quantize
torch.ao.quantization.experimental.fake_quantize_function
torch.ao.quantization.experimental.linear
torch.ao.quantization.experimental.observer
torch.ao.quantization.experimental.qconfig
torch.ao.quantization.experimental.quantizer
torch.csrc.jit.tensorexpr.codegen_external
torch.csrc.jit.tensorexpr.scripts.bisect
torch.csrc.lazy.test_mnist
torch.distributed._tensor.examples.checkpoint_example
torch.distributed._tensor.examples.comm_mode_features_example
torch.distributed._tensor.examples.comm_mode_features_example_argparser
torch.distributed._tensor.examples.convnext_example
torch.distributed._tensor.examples.torchrec_sharding_example
torch.distributed._tensor.examples.visualize_sharding_example
torch.distributed.benchmarks.benchmark_ddp_rpc
torch.distributed.checkpoint.examples.async_checkpointing_example
torch.distributed.checkpoint.examples.fsdp_checkpoint_example
torch.distributed.checkpoint.examples.stateful_example
torch.distributed.examples.memory_tracker_example
torch.fx.experimental.shape_inference.infer_shape
torch.fx.experimental.shape_inference.infer_symbol_values
torch.include.fp16.avx
torch.include.fp16.avx2
torch.onnx._internal.fx.analysis.unsupported_nodes
torch.onnx._internal.fx.passes._utils
torch.onnx._internal.fx.passes.decomp
torch.onnx._internal.fx.passes.functionalization
torch.onnx._internal.fx.passes.modularization
torch.onnx._internal.fx.passes.readability
torch.onnx._internal.fx.passes.type_promotion
torch.onnx._internal.fx.passes.virtualization
torch.utils._strobelight.examples.cli_function_profiler_example
torch.utils.benchmark.examples.sparse.compare
torch.utils.benchmark.examples.sparse.fuzzer
torch.utils.benchmark.examples.sparse.op_benchmark
torch.utils.tensorboard._convert_np
torch.utils.tensorboard._embedding
torch.utils.tensorboard._onnx_graph
torch.utils.tensorboard._proto_graph
torch.utils.tensorboard._pytorch_graph
torch.utils.tensorboard._utils
torch.utils.tensorboard.summary
torch.utils.tensorboard.writer
```
These are all either namespace packages (which we want to remove) or package that are not importable (and tagged as such in the test).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130497
Approved by: https://github.com/aorenste
This commit is contained in:
parent
215013daad
commit
354edb232a
|
|
@ -6,8 +6,6 @@ import json
|
|||
import os
|
||||
import pkgutil
|
||||
import unittest
|
||||
from itertools import chain
|
||||
from pathlib import Path
|
||||
from typing import Callable
|
||||
|
||||
import torch
|
||||
|
|
@ -22,44 +20,6 @@ from torch.testing._internal.common_utils import (
|
|||
)
|
||||
|
||||
|
||||
def _find_all_importables(pkg):
|
||||
"""Find all importables in the project.
|
||||
|
||||
Return them in order.
|
||||
"""
|
||||
return sorted(
|
||||
set(
|
||||
chain.from_iterable(
|
||||
_discover_path_importables(Path(p), pkg.__name__) for p in pkg.__path__
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _discover_path_importables(pkg_pth, pkg_name):
|
||||
"""Yield all importables under a given path and package.
|
||||
|
||||
This is like pkgutil.walk_packages, but does *not* skip over namespace
|
||||
packages. Taken from https://stackoverflow.com/questions/41203765/init-py-required-for-pkgutil-walk-packages-in-python3
|
||||
"""
|
||||
for dir_path, _d, file_names in os.walk(pkg_pth):
|
||||
pkg_dir_path = Path(dir_path)
|
||||
|
||||
if pkg_dir_path.parts[-1] == "__pycache__":
|
||||
continue
|
||||
if all(Path(_).suffix != ".py" for _ in file_names):
|
||||
continue
|
||||
rel_pt = pkg_dir_path.relative_to(pkg_pth)
|
||||
pkg_pref = ".".join((pkg_name,) + rel_pt.parts)
|
||||
yield from (
|
||||
pkg_path
|
||||
for _, pkg_path, _ in pkgutil.walk_packages(
|
||||
(str(pkg_dir_path),),
|
||||
prefix=f"{pkg_pref}.",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class TestPublicBindings(TestCase):
|
||||
def test_no_new_reexport_callables(self):
|
||||
"""
|
||||
|
|
@ -307,7 +267,12 @@ class TestPublicBindings(TestCase):
|
|||
@skipIfTorchDynamo("Broken and not relevant for now")
|
||||
def test_modules_can_be_imported(self):
|
||||
failures = []
|
||||
for modname in _find_all_importables(torch):
|
||||
|
||||
def onerror(modname):
|
||||
failures.append((modname, ImportError))
|
||||
|
||||
for mod in pkgutil.walk_packages(torch.__path__, "torch.", onerror=onerror):
|
||||
modname = mod.name
|
||||
try:
|
||||
# TODO: fix "torch/utils/model_dump/__main__.py"
|
||||
# which calls sys.exit() when we try to import it
|
||||
|
|
@ -369,6 +334,10 @@ class TestPublicBindings(TestCase):
|
|||
"torch.testing._internal.distributed.rpc_utils",
|
||||
"torch._inductor.codegen.cuda.cuda_template",
|
||||
"torch._inductor.codegen.cuda.gemm_template",
|
||||
"torch._inductor.codegen.cpp_template",
|
||||
"torch._inductor.codegen.cpp_gemm_template",
|
||||
"torch._inductor.codegen.cpp_micro_gemm",
|
||||
"torch._inductor.codegen.cpp_template_kernel",
|
||||
"torch._inductor.runtime.triton_helpers",
|
||||
"torch.ao.pruning._experimental.data_sparsifier.lightning.callbacks.data_sparsity",
|
||||
"torch.backends._coreml.preprocess",
|
||||
|
|
@ -624,7 +593,8 @@ class TestPublicBindings(TestCase):
|
|||
elem, modname, mod, is_public=True, is_all=False
|
||||
)
|
||||
|
||||
for modname in _find_all_importables(torch):
|
||||
for mod in pkgutil.walk_packages(torch.__path__, "torch."):
|
||||
mod = mod.name
|
||||
test_module(modname)
|
||||
test_module("torch")
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user