mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
The conventional env var to set is CI. Both CircleCI and GHA set it, so IN_CI is unnecessary. Pull Request resolved: https://github.com/pytorch/pytorch/pull/79229 Approved by: https://github.com/janeyx99
75 lines
2.4 KiB
Python
75 lines
2.4 KiB
Python
# Owner(s): ["module: ci"]
|
|
|
|
import subprocess
|
|
import sys
|
|
import unittest
|
|
import pathlib
|
|
|
|
from torch.testing._internal.common_utils import TestCase, run_tests, IS_LINUX, IS_CI
|
|
|
|
|
|
# Directory two levels above this file's resolved location (expected to be the
# repository root); it is appended to sys.path so that `tools` can be imported.
REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
|
|
|
|
# Make REPO_ROOT importable so that `tools` resolves even when PyTorch was not
# built in 'develop' mode.
sys.path.append(str(REPO_ROOT))

try:
    from tools.stats.scribe import rds_write, register_rds_schema
except ImportError:
    # The stats helpers are unavailable; code that needs them checks for None.
    rds_write = register_rds_schema = None
|
|
|
|
|
|
# these tests could eventually be changed to fail if the import/init
|
|
# time is greater than a certain threshold, but for now we just use them
|
|
# as a way to track the duration of `import torch` in our ossci-metrics
|
|
# S3 bucket (see tools/stats/print_test_stats.py)
|
|
class TestImportTime(TestCase):
    """Tests that exercise ``import torch`` in a fresh interpreter.

    The timing tests delegate to ``TestCase.runWithPytorchAPIUsageStderr``;
    the memory test measures the peak resident set size of a subprocess and
    records it through ``rds_write`` under the name ``"import_stats"``.
    """

    def test_time_import_torch(self):
        """Run ``import torch`` via ``runWithPytorchAPIUsageStderr``."""
        TestCase.runWithPytorchAPIUsageStderr("import torch")

    def test_time_cuda_device_count(self):
        """Run ``import torch`` followed by ``torch.cuda.device_count()``."""
        TestCase.runWithPytorchAPIUsageStderr(
            "import torch; torch.cuda.device_count()",
        )

    @unittest.skipIf(not IS_LINUX, "Memory test is only implemented for Linux")
    @unittest.skipIf(not IS_CI, "Memory test only runs in CI")
    @unittest.skipIf(rds_write is None, "Cannot import rds_write from tools.stats.scribe")
    def test_peak_memory(self):
        """Record the peak memory of importing ``torch`` next to a baseline.

        Raises:
            unittest.SkipTest: if writing the measurements fails, so that a
                stats-backend problem does not fail the test suite.
            subprocess.CalledProcessError: if the profiled interpreter exits
                with a non-zero status (for example, the import itself fails).
        """

        def profile(module, name):
            """Import ``module`` in a subprocess and return its peak RSS row."""
            command = f"import {module}; import resource; print(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss)"
            result = subprocess.run(
                [sys.executable, "-c", command],
                stdout=subprocess.PIPE,
                # Fail loudly if the child crashes instead of tripping over an
                # empty stdout in the int() conversion below.
                check=True,
            )
            # The measurement is the last thing the child prints; ignore any
            # earlier stdout noise emitted while importing the module.
            last_line = result.stdout.decode().strip().splitlines()[-1]
            # This test only runs on Linux, where ru_maxrss is reported in
            # kilobytes; convert so the value matches the "bytes" field name.
            max_rss = int(last_line) * 1024

            return {
                "test_name": name,
                "peak_memory_bytes": max_rss,
            }

        data = profile("torch", "pytorch")
        # Importing `sys` is effectively free, so this approximates the
        # footprint of a bare interpreter.
        baseline = profile("sys", "baseline")
        try:
            rds_write("import_stats", [data, baseline])
        except Exception as e:
            raise unittest.SkipTest(f"Failed to record import_stats: {e}") from e
|
|
|
|
|
|
if __name__ == "__main__":
    # Schema registration is best-effort: it is attempted only in CI and only
    # when the stats helpers were importable, and a failure is reported
    # without preventing the tests from running.
    if register_rds_schema and IS_CI:
        import_stats_columns = {
            "test_name": "string",
            "peak_memory_bytes": "int",
            "time_ms": "int",
        }
        try:
            register_rds_schema("import_stats", import_stats_columns)
        except Exception as err:
            print(f"Failed to register RDS schema: {err}")

    run_tests()
|