[CI][CUDA][Distributed][FSDP] Remove hardcoded world size of 2 (#145195)

These unit tests would fail if run on a single GPU, since skip_if_lt_x_gpu(2) seems to view the world size as 2 even on platforms with only 1 GPU.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145195
Approved by: https://github.com/Skylion007, https://github.com/atalman
Wei Wang, 2025-01-21 20:25:49 +00:00, committed by PyTorch MergeBot
commit df67ac4c86 (parent 505ade7471)
3 changed files with 13 additions and 0 deletions
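For context, the change caps each affected test class's world_size at the number of visible GPUs instead of always returning 2, so single-GPU machines no longer try to spawn a 2-rank job. A minimal sketch of the pattern, assuming the internal FSDP test utilities are importable (the class and test names here are illustrative, not from the patch):

import torch
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest


class ExampleFSDPTest(FSDPTest):  # hypothetical subclass mirroring the patched tests
    @property
    def world_size(self) -> int:
        # Use the actual GPU count when fewer than 2 GPUs are visible,
        # instead of a hardcoded world size of 2.
        if torch.cuda.is_available():
            gpu_cnt = torch.cuda.device_count()
            if gpu_cnt < 2:
                return gpu_cnt
        return 2

    @skip_if_lt_x_gpu(2)  # still skips the test body on machines with fewer than 2 GPUs
    def test_example(self):
        ...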


@@ -40,6 +40,10 @@ _DISTRIBUTED_STATE_DICT_IMPLS = {
 class TestDistributedCheckpoint(FSDPTest):
     @property
     def world_size(self):
+        if torch.cuda.is_available():
+            gpu_cnt = torch.cuda.device_count()
+            if gpu_cnt < 2:
+                return gpu_cnt
         return 2
 
     @skip_if_lt_x_gpu(2)


@@ -36,6 +36,10 @@ device_type = torch.device(get_devtype())
 class TestApply(FSDPTest):
     @property
     def world_size(self):
+        if torch.cuda.is_available():
+            gpu_cnt = torch.cuda.device_count()
+            if gpu_cnt < 2:
+                return gpu_cnt
         return 2
 
     @torch.no_grad()


@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 import sys
+import torch
 from torch import distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -28,6 +29,10 @@ if TEST_WITH_DEV_DBG_ASAN:
 class TestTraversal(FSDPTest):
     @property
     def world_size(self):
+        if torch.cuda.is_available():
+            gpu_cnt = torch.cuda.device_count()
+            if gpu_cnt < 2:
+                return gpu_cnt
         return 2
 
     @skip_if_lt_x_gpu(2)
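The same four-line guard is repeated verbatim in all three test classes. A shared helper (hypothetical, not part of this patch) could express the identical logic once:

import torch


def capped_world_size(default: int = 2) -> int:
    # Hypothetical helper, not in the patch: clamp the default world size
    # to the number of visible CUDA devices when fewer are available.
    if torch.cuda.is_available():
        return min(default, torch.cuda.device_count())
    return default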