# Owner(s): ["module: ci"]

import pathlib
import subprocess
import sys
import unittest

from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX, IS_CI

REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent

try:
    # Just in case PyTorch was not built in 'develop' mode
    sys.path.append(str(REPO_ROOT))
    from tools.stats.scribe import rds_write, register_rds_schema
except ImportError:
    # The RDS helpers are optional; the tests that need them are skipped
    # (and schema registration is bypassed) when they cannot be imported.
    register_rds_schema = None
    rds_write = None


# These tests could eventually be changed to fail if the import/init
# time is greater than a certain threshold, but for now we just use them
# as a way to track the duration of `import torch` in our ossci-metrics
# S3 bucket (see tools/stats/print_test_stats.py).
class TestImportTime(TestCase):
    """Collect import-time and peak-memory statistics for ``import torch``."""

    def test_time_import_torch(self):
        """Run ``import torch`` via ``TestCase.runWithPytorchAPIUsageStderr``."""
        TestCase.runWithPytorchAPIUsageStderr("import torch")

    def test_time_cuda_device_count(self):
        """Run ``import torch; torch.cuda.device_count()`` via the same helper."""
        TestCase.runWithPytorchAPIUsageStderr(
            "import torch; torch.cuda.device_count()",
        )

    @unittest.skipIf(not IS_LINUX, "Memory test is only implemented for Linux")
    @unittest.skipIf(not IS_CI, "Memory test only runs in CI")
    @unittest.skipIf(rds_write is None, "Cannot import rds_write from tools.stats.scribe")
    def test_peak_memory(self):
        """Record the peak RSS of importing ``torch`` next to a ``sys`` baseline.

        Each measurement runs in a fresh interpreter so the numbers are not
        affected by whatever the test process has already loaded.  Both rows
        are written to the ``import_stats`` table with ``rds_write``.

        Raises:
            unittest.SkipTest: if writing the rows fails; recording stats is
                best-effort and must not fail the test run.
            subprocess.CalledProcessError: if a measurement interpreter exits
                with a non-zero status.
        """

        def profile(module, name):
            """Return a stats row with the max RSS of a process importing *module*."""
            command = f"import {module}; import resource; print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)"
            # check=True surfaces a failed import as CalledProcessError instead
            # of an unrelated ValueError from int("") on the empty stdout.
            result = subprocess.run(
                [sys.executable, "-c", command],
                stdout=subprocess.PIPE,
                check=True,
            )
            # NOTE(review): getrusage(2) reports ru_maxrss in kilobytes on
            # Linux, yet the value is stored under "peak_memory_bytes" --
            # confirm what unit consumers of this table expect.
            max_rss = int(result.stdout.decode().strip())

            return {
                "test_name": name,
                "peak_memory_bytes": max_rss,
            }

        data = profile("torch", "pytorch")
        baseline = profile("sys", "baseline")
        try:
            rds_write("import_stats", [data, baseline])
        except Exception as e:
            raise unittest.SkipTest(f"Failed to record import_stats: {e}") from e


if __name__ == "__main__":
    if register_rds_schema and IS_CI:
        try:
            register_rds_schema(
                "import_stats",
                {
                    "test_name": "string",
                    "peak_memory_bytes": "int",
                    "time_ms": "int",
                },
            )
        except Exception as e:
            # Schema registration is best-effort; still run the tests.
            print(f"Failed to register RDS schema: {e}")

    run_tests()