Support independent builds for cpp extension tests + apply to libtorch_agnostic tests (#153264)

Related: #148920

This PR:
* Provides a helper `install_cpp_extension(extension_root)` for building C++ extensions. This is intended to be used in `TestMyCppExtension.setUpClass()`
    * Updates libtorch_agnostic tests to use this
* Deletes preexisting libtorch_agnostic tests from `test/test_cpp_extensions_aot.py`
    * Fixes `run_test.py` to actually run the tests in `test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py` so that coverage isn't lost. These tests were previously skipped because of logic excluding test files whose names start with "cpp"

After this PR, it is now possible to run:
```
python test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py
```

and the test file will build the `libtorch_agnostic` extension before running the tests.
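
For illustration, here is a minimal sketch of the `setUpClass()` pattern this helper enables. The `TestMyCppExtension` class, the `my_cpp_extension` module name, and the directory layout are hypothetical, not part of this PR:
```
# Minimal sketch of the setUpClass() pattern enabled by install_cpp_extension.
# The extension name and layout below are hypothetical.
from pathlib import Path

import torch
from torch.testing._internal.common_utils import (
    install_cpp_extension,
    run_tests,
    TestCase,
)


class TestMyCppExtension(TestCase):
    @classmethod
    def setUpClass(cls):
        # Build and install the extension rooted at the given directory,
        # then prepend its install dir to sys.path so it is importable.
        install_cpp_extension(extension_root=Path(__file__).parent)

    def test_identity(self):
        import my_cpp_extension  # importable only after setUpClass() ran

        t = torch.rand(3)
        self.assertEqual(my_cpp_extension.ops.identity(t), t)


if __name__ == "__main__":
    run_tests()
```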
Pull Request resolved: https://github.com/pytorch/pytorch/pull/153264
Approved by: https://github.com/janeyx99
Joel Schlosser 2025-05-20 11:49:06 -04:00 committed by PyTorch MergeBot
parent f1f54c197d
commit 3ecd444004
5 changed files with 160 additions and 174 deletions

test/cpp_extensions/libtorch_agnostic_extension/test/test_libtorch_agnostic.py

@@ -1,6 +1,6 @@
# Owner(s): ["module: cpp"]
import libtorch_agnostic # noqa: F401
from pathlib import Path
import torch
from torch.testing._internal.common_device_type import (
@@ -8,114 +8,144 @@ from torch.testing._internal.common_device_type import (
onlyCPU,
onlyCUDA,
)
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.common_utils import (
install_cpp_extension,
IS_WINDOWS,
run_tests,
TestCase,
xfailIfTorchDynamo,
)
class TestLibtorchAgnostic(TestCase):
@onlyCPU
def test_slow_sgd(self, device):
param = torch.rand(5, device=device)
grad = torch.rand_like(param)
weight_decay = 0.01
lr = 0.001
maximize = False
# TODO: Fix this error in Windows:
# LINK : error LNK2001: unresolved external symbol PyInit__C
if not IS_WINDOWS:
new_param = libtorch_agnostic.ops.sgd_out_of_place(
param, grad, weight_decay, lr, maximize
)
torch._fused_sgd_(
(param,),
(grad,),
(),
weight_decay=weight_decay,
momentum=0.0,
lr=lr,
dampening=0.0,
nesterov=False,
maximize=maximize,
is_first_step=False,
)
self.assertEqual(new_param, param)
class TestLibtorchAgnostic(TestCase):
@classmethod
def setUpClass(cls):
install_cpp_extension(extension_root=Path(__file__).parent.parent)
@onlyCUDA
def test_identity_does_not_hog_memory(self, device):
def _run_identity(prior_mem):
t = torch.rand(32, 32, device=device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
identi_t = libtorch_agnostic.ops.identity(t)
assert identi_t is t
@onlyCPU
def test_slow_sgd(self, device):
import libtorch_agnostic
init_mem = torch.cuda.memory_allocated(device)
param = torch.rand(5, device=device)
grad = torch.rand_like(param)
weight_decay = 0.01
lr = 0.001
maximize = False
for _ in range(3):
_run_identity(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
new_param = libtorch_agnostic.ops.sgd_out_of_place(
param, grad, weight_decay, lr, maximize
)
torch._fused_sgd_(
(param,),
(grad,),
(),
weight_decay=weight_decay,
momentum=0.0,
lr=lr,
dampening=0.0,
nesterov=False,
maximize=maximize,
is_first_step=False,
)
self.assertEqual(new_param, param)
def test_exp_neg_is_leaf(self, device):
t1 = torch.rand(2, 3, device=device)
t2 = torch.rand(3, 2, device=device)
t3 = torch.rand(2, device=device)
@onlyCUDA
def test_identity_does_not_hog_memory(self, device):
import libtorch_agnostic
exp, neg, is_leaf = libtorch_agnostic.ops.exp_neg_is_leaf(t1, t2, t3)
self.assertEqual(exp, torch.exp(t1))
self.assertEqual(neg, torch.neg(t2))
self.assertEqual(is_leaf, t3.is_leaf)
def _run_identity(prior_mem):
t = torch.rand(32, 32, device=device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
identi_t = libtorch_agnostic.ops.identity(t)
assert identi_t is t
def test_my_abs(self, device):
t = torch.rand(32, 16, device=device) - 0.5
cpu_t = libtorch_agnostic.ops.my_abs(t)
self.assertEqual(cpu_t, torch.abs(t))
def _make_cuda_tensors(prior_mem):
cuda_t = libtorch_agnostic.ops.my_abs(t)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
self.assertEqual(cuda_t, torch.abs(t))
if t.is_cuda:
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
_run_identity(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
def test_my_ones_like(self, device):
t = torch.rand(3, 1, device=device) - 0.5
cpu_t = libtorch_agnostic.ops.my_ones_like(t, "cpu")
self.assertEqual(cpu_t, torch.ones_like(t, device="cpu"))
def test_exp_neg_is_leaf(self, device):
import libtorch_agnostic
def _make_cuda_tensors(prior_mem):
cuda_t = libtorch_agnostic.ops.my_ones_like(t, device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
self.assertEqual(cuda_t, torch.ones_like(t, device=device))
t1 = torch.rand(2, 3, device=device)
t2 = torch.rand(3, 2, device=device)
t3 = torch.rand(2, device=device)
if t.is_cuda:
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
exp, neg, is_leaf = libtorch_agnostic.ops.exp_neg_is_leaf(t1, t2, t3)
self.assertEqual(exp, torch.exp(t1))
self.assertEqual(neg, torch.neg(t2))
self.assertEqual(is_leaf, t3.is_leaf)
@onlyCUDA
def test_z_delete_torch_lib(self, device):
# Why the z + CUDA? THIS TEST MUST BE RUN LAST
# We are testing that unloading the library properly deletes the registrations, so running this test
# earlier will cause all other tests in this file to fail
lib = libtorch_agnostic.loaded_lib
def test_my_abs(self, device):
import libtorch_agnostic
# code for unloading a library inspired from
# https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
lib_handle = lib._handle
lib.dlclose(lib_handle)
t = torch.rand(32, 16, device=device) - 0.5
cpu_t = libtorch_agnostic.ops.my_abs(t)
self.assertEqual(cpu_t, torch.abs(t))
t = torch.tensor([-2.0, 0.5])
with self.assertRaises(RuntimeError):
libtorch_agnostic.ops.identity(
t
) # errors as identity shouldn't be registered anymore
def _make_cuda_tensors(prior_mem):
cuda_t = libtorch_agnostic.ops.my_abs(t)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
self.assertEqual(cuda_t, torch.abs(t))
if t.is_cuda:
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
# TODO: Debug this:
# torch._dynamo.exc.TorchRuntimeError: Dynamo failed to run FX node with fake tensors:
# call_function libtorch_agnostic.my_ones_like.default(*(FakeTensor(..., size=(3, 1)), 'cpu'),
# **{}): got AssertionError("tensor's device must be `meta`, got cpu instead")
@xfailIfTorchDynamo
def test_my_ones_like(self, device):
import libtorch_agnostic
t = torch.rand(3, 1, device=device) - 0.5
cpu_t = libtorch_agnostic.ops.my_ones_like(t, "cpu")
self.assertEqual(cpu_t, torch.ones_like(t, device="cpu"))
def _make_cuda_tensors(prior_mem):
cuda_t = libtorch_agnostic.ops.my_ones_like(t, device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
self.assertEqual(cuda_t, torch.ones_like(t, device=device))
if t.is_cuda:
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
@onlyCUDA
def test_z_delete_torch_lib(self, device):
import libtorch_agnostic
# Why the z + CUDA? THIS TEST MUST BE RUN LAST
# We are testing that unloading the library properly deletes the registrations, so running this test
# earlier will cause all other tests in this file to fail
lib = libtorch_agnostic.loaded_lib
# code for unloading a library inspired from
# https://stackoverflow.com/questions/19547084/can-i-explicitly-close-a-ctypes-cdll
lib_handle = lib._handle
lib.dlclose(lib_handle)
t = torch.tensor([-2.0, 0.5])
with self.assertRaises(RuntimeError):
libtorch_agnostic.ops.identity(
t
) # errors as identity shouldn't be registered anymore
instantiate_device_type_tests(TestLibtorchAgnostic, globals(), except_for=None)
if __name__ == "__main__":
run_tests()

test/run_test.py

@@ -396,7 +396,15 @@ AOT_DISPATCH_TESTS = [
]
FUNCTORCH_TESTS = [test for test in TESTS if test.startswith("functorch")]
ONNX_TESTS = [test for test in TESTS if test.startswith("onnx")]
CPP_TESTS = [test for test in TESTS if test.startswith(CPP_TEST_PREFIX)]
def _is_cpp_test(test):
# Note: tests underneath cpp_extensions are different from other cpp tests
# in that they utilize the usual python test infrastructure.
return test.startswith(CPP_TEST_PREFIX) and not test.startswith("cpp_extensions")
CPP_TESTS = [test for test in TESTS if _is_cpp_test(test)]
TESTS_REQUIRING_LAPACK = [
"distributions/test_constraints",
@@ -467,7 +475,7 @@ def run_test(
stepcurrent_key = test_file
is_distributed_test = test_file.startswith(DISTRIBUTED_TEST_PREFIX)
is_cpp_test = test_file.startswith(CPP_TEST_PREFIX)
is_cpp_test = _is_cpp_test(test_file)
# NB: Rerun disabled tests depends on pytest-flakefinder and it doesn't work with
# pytest-cpp atm. We also don't have support to disable C++ test yet, so it's ok
# to just return successfully here
@@ -543,7 +551,7 @@ def run_test(
# case such as coverage for C++ test. So just returning ok makes sense
return 0
if test_file.startswith(CPP_TEST_PREFIX):
if is_cpp_test:
# C++ tests are not the regular test directory
if CPP_TESTS_DIR:
cpp_test = os.path.join(

test/test_cpp_extensions_aot.py

@@ -215,88 +215,6 @@ class TestCppExtensionAOT(common.TestCase):
missing_symbols = subprocess.check_output(["nm", "-u", so_file]).decode("utf-8")
self.assertFalse("Py" in missing_symbols)
@unittest.skipIf(not TEST_CUDA, "some aspects of this test require CUDA")
def test_libtorch_agnostic(self):
import libtorch_agnostic
# (1) first test that SGD CPU kernel works
param = torch.rand(5, device="cpu")
grad = torch.rand_like(param)
weight_decay = 0.01
lr = 0.001
maximize = False
new_param = libtorch_agnostic.ops.sgd_out_of_place(
param, grad, weight_decay, lr, maximize
)
torch._fused_sgd_(
(param,),
(grad,),
(),
weight_decay=weight_decay,
momentum=0.0,
lr=lr,
dampening=0.0,
nesterov=False,
maximize=maximize,
is_first_step=False,
)
self.assertEqual(new_param, param)
# (2) then test that we don't hog unnecessary memory
def _run_identity(prior_mem, device):
t = torch.rand(32, 32, device=device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
identi_t = libtorch_agnostic.ops.identity(t)
assert identi_t is t
device = torch.cuda.current_device()
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_run_identity(init_mem, device)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
# (3a) test calling our dispatcher on easy API like abs
t = torch.rand(32, 16, device=device) - 0.5
def _make_cuda_tensors(prior_mem):
cuda_t = libtorch_agnostic.ops.my_abs(t)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
self.assertEqual(cuda_t, torch.abs(t))
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
# (3b) and on factory API like ones_like
cpu_t = libtorch_agnostic.ops.my_ones_like(t, "cpu")
self.assertEqual(cpu_t, torch.ones_like(t, device="cpu"))
def _make_cuda_tensors(prior_mem):
cuda_t = libtorch_agnostic.ops.my_ones_like(t, t.device)
self.assertGreater(torch.cuda.memory_allocated(device), prior_mem)
self.assertEqual(cuda_t, torch.ones_like(t, device=t.device))
init_mem = torch.cuda.memory_allocated(device)
for _ in range(3):
_make_cuda_tensors(init_mem)
curr_mem = torch.cuda.memory_allocated(device)
self.assertEqual(curr_mem, init_mem)
# (4) test multiple returns
t1 = torch.rand(2, 3, device="cuda")
t2 = torch.rand(3, 2, device="cpu")
t3 = torch.rand(2, device="cpu")
exp, neg, is_leaf = libtorch_agnostic.ops.exp_neg_is_leaf(t1, t2, t3)
self.assertEqual(exp, torch.exp(t1))
self.assertEqual(neg, torch.neg(t2))
self.assertEqual(is_leaf, t3.is_leaf)
@torch.testing._internal.common_utils.markDynamoStrictTest
class TestPybindTypeCasters(common.TestCase):

torch/testing/_internal/common_utils.py

@@ -5578,7 +5578,7 @@ def check_leaked_tensors(limit=1, matched_type=torch.Tensor):
f"{num_garbage_objs} tensors were found in the garbage. Did you introduce a reference cycle?"
)
try:
import objgraph # type: ignore[import-not-found]
import objgraph # type: ignore[import-not-found,import-untyped]
warnings.warn(
f"Dumping first {limit} objgraphs of leaked {matched_type}s rendered to png"
)
@@ -5604,6 +5604,36 @@ def remove_cpp_extensions_build_root():
else:
shutil.rmtree(default_build_root, ignore_errors=True)
def install_cpp_extension(extension_root):
# Wipe the build / install dirs if they exist
build_dir = os.path.join(extension_root, "build")
install_dir = os.path.join(extension_root, "install")
for d in (build_dir, install_dir):
if os.path.exists(d):
shutil.rmtree(d)
# Build the extension
setup_py_path = os.path.join(extension_root, "setup.py")
cmd = [sys.executable, setup_py_path, "install", "--root", install_dir]
return_code = shell(cmd, cwd=extension_root, env=os.environ)
if return_code != 0:
raise RuntimeError(f"build failed for cpp extension at {extension_root}")
mod_install_dir = None
# install directory is the one that is named site-packages
for root, directories, _ in os.walk(install_dir):
for directory in directories:
if "-packages" in directory:
mod_install_dir = os.path.join(root, directory)
if mod_install_dir is None:
raise RuntimeError(f"installation failed for cpp extension at {extension_root}")
if mod_install_dir not in sys.path:
sys.path.insert(0, mod_install_dir)
# Decorator to provide a helper to load inline extensions to a temp directory
def scoped_load_inline(func):