Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[CI][CUDA][Distributed][FSDP] Remove hardcoded world size of 2 (#145195)
As these unit tests would fail if run on a single GPU (skip_if_lt_x_gpu(2) seems to view the world size as 2 even on platforms with only 1 GPU), the hardcoded world size of 2 is removed.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145195
Approved by: https://github.com/Skylion007, https://github.com/atalman
This commit is contained in:
parent 505ade7471
commit df67ac4c86
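Below is a minimal, standalone sketch (not part of the commit) of the world-size clamping pattern that each hunk in the diff applies. The helper name resolve_world_size is hypothetical; the commit inlines this logic directly into each test class's world_size property:

    import torch

    def resolve_world_size(default: int = 2) -> int:
        # Cap the requested world size at the number of visible CUDA devices,
        # so a single-GPU machine gets a world size of 1 instead of 2.
        if torch.cuda.is_available():
            gpu_cnt = torch.cuda.device_count()
            if gpu_cnt < default:
                return gpu_cnt
        return default

    print(resolve_world_size())  # 1 on a single-GPU machine, otherwise 2

With this change, tests in these classes that are not guarded by a decorator run with a world size of 1 on a single-GPU machine, while tests still decorated with @skip_if_lt_x_gpu(2) are skipped there rather than failing.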
@@ -40,6 +40,10 @@ _DISTRIBUTED_STATE_DICT_IMPLS = {
 class TestDistributedCheckpoint(FSDPTest):
     @property
     def world_size(self):
+        if torch.cuda.is_available():
+            gpu_cnt = torch.cuda.device_count()
+            if gpu_cnt < 2:
+                return gpu_cnt
         return 2

     @skip_if_lt_x_gpu(2)
@@ -36,6 +36,10 @@ device_type = torch.device(get_devtype())
 class TestApply(FSDPTest):
     @property
     def world_size(self):
+        if torch.cuda.is_available():
+            gpu_cnt = torch.cuda.device_count()
+            if gpu_cnt < 2:
+                return gpu_cnt
         return 2

     @torch.no_grad()
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 import sys

+import torch
 from torch import distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -28,6 +29,10 @@ if TEST_WITH_DEV_DBG_ASAN:
 class TestTraversal(FSDPTest):
     @property
     def world_size(self):
+        if torch.cuda.is_available():
+            gpu_cnt = torch.cuda.device_count()
+            if gpu_cnt < 2:
+                return gpu_cnt
         return 2

     @skip_if_lt_x_gpu(2)