pytorch/torch/distributed/_composable_state.py
Maggie Moss 7457d139c5 Add pyrefly suppressions to torch/distributed (7/n) (#165002)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

One more PR after this one.

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
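
The suppressions added in step 3 are inline comments in the same format already used in this file (`# pyrefly: ignore # redundant-cast`). A minimal illustration with a hypothetical function, assuming pyrefly would flag the cast as redundant:

```python
from typing import cast


def passthrough(x: int) -> int:
    # pyrefly: ignore # redundant-cast
    return cast(int, x)  # cast is a runtime no-op; a type checker may flag it as redundant
```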
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165002
Approved by: https://github.com/oulgen
2025-10-09 04:08:25 +00:00


import weakref
from typing import cast, Optional

import torch.nn as nn


class _State:
    pass


_module_state_mapping: weakref.WeakKeyDictionary[
    nn.Module, weakref.ReferenceType[_State]
] = weakref.WeakKeyDictionary()

def _insert_module_state(module: nn.Module, state: _State) -> None:
    global _module_state_mapping
    assert module not in _module_state_mapping, f"Inserting {module} more than once."
    _module_state_mapping[module] = weakref.ref(state)

def _get_module_state(module: nn.Module) -> Optional[_State]:
    """
    Return the ``_State`` in ``module``.

    Given a ``module``, this API finds out if the module is also a ``_State``
    instance or if the module is managed by a composable API. If the module
    is also a ``_State``, ``module`` will be cast to ``_State`` and returned.
    If it is managed by a composable API, the corresponding ``_State`` will
    be returned.
    """
    global _module_state_mapping
    if isinstance(module, _State):
        # pyrefly: ignore # redundant-cast
        return cast(_State, module)
    else:
        # https://github.com/pytorch/pytorch/issues/107054
        if module in _module_state_mapping:
            state_ref = _module_state_mapping[module]
            state = state_ref()
            if state is None:
                raise AssertionError("State has already been garbage collected")
            return state
        else:
            return None