Refactor cudnn version check in smoke test for Windows (#150015)

After https://github.com/pytorch/pytorch/pull/149885

I see failures on Window smoke test:
https://github.com/pytorch/test-infra/actions/runs/14069923716/job/39401550854

Due to fact that pypi packages such as cudnn and nccl are installed only on Linux. Hence this should resolve issue on Windows platform.
On windows cudnn is shipped with PyTorch as opposed to installed dynamically.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150015
Approved by: https://github.com/ZainRizvi
This commit is contained in:
atalman 2025-03-26 15:15:46 +00:00 committed by PyTorch MergeBot
parent 8a40fca9a1
commit 7336b76bcc

View File

@ -259,20 +259,21 @@ def smoke_test_cuda(
) )
print(f"torch cuda: {torch.version.cuda}") print(f"torch cuda: {torch.version.cuda}")
print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
compare_pypi_to_torch_versions(
"cudnn", find_pypi_package_version("nvidia-cudnn"), torch_cudnn_version
)
torch.cuda.init() torch.cuda.init()
print("CUDA initialized successfully") print("CUDA initialized successfully")
print(f"Number of CUDA devices: {torch.cuda.device_count()}") print(f"Number of CUDA devices: {torch.cuda.device_count()}")
for i in range(torch.cuda.device_count()): for i in range(torch.cuda.device_count()):
print(f"Device {i}: {torch.cuda.get_device_name(i)}") print(f"Device {i}: {torch.cuda.get_device_name(i)}")
# nccl is availbale only on Linux print(f"cuDNN enabled? {torch.backends.cudnn.enabled}")
torch_cudnn_version = cudnn_to_version_str(torch.backends.cudnn.version())
print(f"Torch cuDNN version: {torch_cudnn_version}")
# Pypi dependencies are installed on linux ony and nccl is availbale only on Linux.
if sys.platform in ["linux", "linux2"]: if sys.platform in ["linux", "linux2"]:
compare_pypi_to_torch_versions(
"cudnn", find_pypi_package_version("nvidia-cudnn"), torch_cudnn_version
)
torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version()) torch_nccl_version = ".".join(str(v) for v in torch.cuda.nccl.version())
compare_pypi_to_torch_versions( compare_pypi_to_torch_versions(
"nccl", find_pypi_package_version("nvidia-nccl"), torch_nccl_version "nccl", find_pypi_package_version("nvidia-nccl"), torch_nccl_version