pytorch/torch/distributed/rpc/_utils.py
Aaron Gokaslan 4f4ecc583e [BE]: Enable RUFF TRY400 rule - log.exception (#153473)
Change logging.error to logging.exception to log additional information when relevant. A few logging.error calls have slipped into try/except blocks since I last did a cleanup here, and the rule is now stabilized, so I am enabling it codebase-wide. I have NOQA'd much of our custom exception stack-trace handling for RPC calls and distributed, and tried to fix a few errors based on whether we immediately re-raised the exception, or whether we failed to log any exception information where it would be useful.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/153473
Approved by: https://github.com/albanD, https://github.com/cyyever
2025-05-15 13:36:59 +00:00

48 lines
1.6 KiB
Python

# mypy: allow-untyped-defs
import logging
from contextlib import contextmanager
from typing import cast
from . import api, TensorPipeAgent
# Per-module logger, keyed by this module's import path.
logger = logging.getLogger(name=__name__)
@contextmanager
def _group_membership_management(store, name, is_join):
token_key = "RpcGroupManagementToken"
join_or_leave = "join" if is_join else "leave"
my_token = f"Token_for_{name}_{join_or_leave}"
while True:
# Retrieve token from store to signal start of rank join/leave critical section
returned = store.compare_set(token_key, "", my_token).decode()
if returned == my_token:
# Yield to the function this context manager wraps
yield
# Finished, now exit and release token
# Update from store to signal end of rank join/leave critical section
store.set(token_key, "")
# Other will wait for this token to be set before they execute
store.set(my_token, "Done")
break
else:
# Store will wait for the token to be released
try:
store.wait([returned])
except RuntimeError:
logger.error( # noqa: TRY400
"Group membership token %s timed out waiting for %s to be released.",
my_token,
returned,
)
raise
def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
    """Forward a group-membership update to the current RPC agent.

    Passes all four arguments unchanged to the agent's
    ``_update_group_membership`` and returns whatever it returns. The
    ``cast`` only informs type checkers that the agent is a
    ``TensorPipeAgent``; it performs no runtime check (the agent is assumed
    to be a TensorPipeAgent — TODO confirm for other backends).
    """
    current_agent = api._get_current_rpc_agent()
    tensorpipe_agent = cast(TensorPipeAgent, current_agent)
    return tensorpipe_agent._update_group_membership(
        worker_info, my_devices, reverse_device_map, is_join
    )