Revert "[c10d] Fix extra CUDA context created by barrier (#149144)"
This reverts commit 457fa820ad.
Reverted https://github.com/pytorch/pytorch/pull/149144 on behalf of https://github.com/huydhn because the internal failure looks legit ([comment](https://github.com/pytorch/pytorch/pull/149144#issuecomment-2852564660))
This commit is contained in:
Parent: 2ce6d169fc
Commit: cc954848d4
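Background, for readers scanning the diff below: the PR being reverted aimed to stop `barrier()` from falling back to GPU 0 and creating an extra CUDA context on every rank (see the removed comment about "creating context on device 0" in the second `barrier` hunk). A minimal, hedged sketch of the call pattern that keeps the barrier pinned to the caller's GPU; `local_rank` is a placeholder supplied by the launcher, not something defined in this commit:

    import torch
    import torch.distributed as dist

    def pinned_barrier(local_rank: int) -> None:
        # Assumes an NCCL default process group has already been initialized.
        torch.cuda.set_device(local_rank)        # keep CUDA work off device 0
        dist.barrier(device_ids=[local_rank])    # device_ids must be a List[int]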
@@ -3516,6 +3516,17 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase):
         c10d.barrier(device_ids=[self.rank])
 
+    @requires_nccl()
+    @skip_if_lt_x_gpu(2)
+    def test_nccl_barrier_device_ids_function_argument(self):
+        store = c10d.FileStore(self.file_name, self.world_size)
+        c10d.init_process_group(
+            backend="nccl", rank=self.rank, world_size=self.world_size, store=store
+        )
+
+        with self.assertRaisesRegex(TypeError, "Invalid function argument"):
+            c10d.barrier(device_ids=self.rank)
+
     @requires_nccl()
     @skip_if_lt_x_gpu(2)
     def test_unwaited(self) -> None:
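The hunk above restores `test_nccl_barrier_device_ids_function_argument`, which pins down the argument contract that the reverted change had relaxed. A hedged sketch of the same contract outside the test harness; it assumes a process group is already initialized on every rank, and `rank` is a placeholder:

    import torch.distributed as dist

    def demo_device_ids_contract(rank: int) -> None:
        # Assumes dist.init_process_group(backend="nccl", ...) has already run on every rank.
        try:
            dist.barrier(device_ids=rank)      # int instead of List[int]
        except TypeError as exc:
            # Restored message: "Invalid function argument: device_ids type should be List[int]"
            print(f"rejected: {exc}")
        dist.barrier(device_ids=[rank])        # accepted form, as in the test above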
@@ -4730,7 +4730,7 @@ def barrier(
         group (ProcessGroup, optional): The process group to work on. If None,
             the default process group will be used.
         async_op (bool, optional): Whether this op should be an async op
-        device_ids ([int], optional): List of device/GPU ids. Only one id is expected.
+        device_ids ([int], optional): List of device/GPU ids.
 
     Returns:
         Async work handle, if async_op is set to True.
@@ -4738,35 +4738,22 @@ def barrier(
     .. note:: `ProcessGroupNCCL` now blocks the cpu thread till the completion of the barrier collective.
     """
-    group = group or _get_default_group()
-
     if _rank_not_in_group(group):
         _warn_not_in_group("barrier")
         return
 
     opts = BarrierOptions()
+    opts.device = torch.device(_get_object_coll_device(group))
     opts.asyncOp = async_op
-    # Detect the accelerator on the machine. If no accelerator is available, it
-    # returns CPU.
-    device = torch._C._get_accelerator()
-    if isinstance(device_ids, list):
-        opts.device_ids = device_ids
-        # use only the first device id
-        opts.device = torch.device(device.type, device_ids[0])
-    elif getattr(group, "bound_device_id", None) is not None:
-        # Use device id from `init_process_group(device_id=...)`
-        opts.device = group.bound_device_id  # type: ignore[assignment]
-    elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
-        opts.device = torch.device("cpu")
-    else:
-        # Use the current device set by the user. If user did not set any, this
-        # may use default device 0, causing issues like hang or all processes
-        # creating context on device 0.
-        opts.device = device
-        warnings.warn(  # warn only once
-            "No device id is provided via `init_process_group` or `barrier `. Using the current device set by the user. "
-        )
+    if device_ids is not None:
+        if isinstance(device_ids, list):
+            opts.device_ids = device_ids
+        else:
+            raise TypeError(
+                "Invalid function argument: device_ids type should be List[int]"
+            )
 
+    group = group or _get_default_group()
    work = group.barrier(opts=opts)
 
     if async_op:
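The removed (`-`) branch above chose the barrier device from the detected accelerator, the group's bound device, or the caller's current device; the restored (`+`) code goes back to deriving the device from `_get_object_coll_device(group)` and only validating `device_ids`. A condensed, hedged paraphrase of that restored flow, using a plain dict in place of the real `BarrierOptions`:

    from typing import List, Optional

    def resolve_barrier_opts(
        device_ids: Optional[List[int]], async_op: bool, coll_device: str
    ) -> dict:
        # Stand-in for BarrierOptions from torch._C._distributed_c10d; illustration only.
        opts = {"device": coll_device, "asyncOp": async_op, "device_ids": []}
        if device_ids is not None:
            if isinstance(device_ids, list):
                opts["device_ids"] = device_ids
            else:
                raise TypeError(
                    "Invalid function argument: device_ids type should be List[int]"
                )
        return opts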