Use correct pyrefly syntax in suppressions distributed/... (#166241)

Updates the pyrefly ignore suppressions in the torch/distributed directory to use the correct syntax. No functional changes.
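
For reference, the change only touches how the error code is attached to each suppression comment; the suppressed statements themselves are unchanged. A minimal sketch of the two forms, using an error code that appears in the hunks below (the surrounding module code is illustrative, not taken from the diff):

import torch.nn as nn

module = nn.Linear(4, 4)

# Old form: the error code trails a second "#"
# pyrefly: ignore # missing-attribute
params = dict(module.named_parameters())

# New form used throughout this commit: the error code sits in brackets
# pyrefly: ignore [missing-attribute]
params = dict(module.named_parameters())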

Test plan:
pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166241
Approved by: https://github.com/oulgen
Maggie Moss, 2025-10-26 04:16:38 +00:00; committed by PyTorch MergeBot
commit 8f80892359 (parent cdb60e44eb)
91 changed files with 261 additions and 261 deletions


@ -133,7 +133,7 @@ if is_available():
# Variables prefixed with underscore are not auto imported
# See the comment in `distributed_c10d.py` above `_backend` on why we expose
# this.
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from .distributed_c10d import * # noqa: F403
from .distributed_c10d import ( # pyrefly: ignore # deprecated
_all_gather_base,


@ -107,7 +107,7 @@ def contract(
# If the user passes a sequence of modules, then we assume that
# we only need to insert the state object on the root modules
# (i.e. those without a parent) among the passed-in modules.
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
modules = _get_root_modules(list(module))
state = state_cls() # shared across all modules
registry_item = RegistryItem() # shared across all modules
@ -119,7 +119,7 @@ def contract(
all_orig_named_buffers: list[dict[str, torch.Tensor]] = []
all_orig_named_modules: list[dict[str, nn.Module]] = []
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for module in modules:
default_all_state: dict[Callable, _State] = OrderedDict()
default_registry: dict[str, RegistryItem] = OrderedDict()
@ -146,11 +146,11 @@ def contract(
all_state.setdefault(func, state)
registry.setdefault(func.__name__, registry_item)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
all_orig_named_params.append(OrderedDict(module.named_parameters()))
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
all_orig_named_buffers.append(OrderedDict(module.named_buffers()))
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
all_orig_named_modules.append(OrderedDict(module.named_modules()))
updated = func(inp_module, *args, **kwargs)
@ -165,13 +165,13 @@ def contract(
all_new_named_params: list[dict[str, nn.Parameter]] = []
all_new_named_buffers: list[dict[str, torch.Tensor]] = []
all_new_named_modules: list[dict[str, nn.Module]] = []
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for module in updated_modules:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
all_new_named_params.append(OrderedDict(module.named_parameters()))
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
all_new_named_buffers.append(OrderedDict(module.named_buffers()))
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
all_new_named_modules.append(OrderedDict(module.named_modules()))
num_orig_modules = len(all_orig_named_modules)
@ -234,7 +234,7 @@ def contract(
# TODO: verify that installed distributed paradigms are compatible with
# each other.
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return updated
def get_state(module: nn.Module) -> _State:


@ -100,7 +100,7 @@ class _ReplicateState(FSDPState):
for module in modules:
_insert_module_state(module, self)
self._modules = modules
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self._device = device
self._device_handle = _get_device_handle(device.type)
self._mp_policy = mp_policy
@ -151,7 +151,7 @@ class _ReplicateState(FSDPState):
)
state._is_root = False
self._state_ctx.all_states.append(state)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
visited_states.add(state)
if self._fsdp_param_group and self._auto_reshard_after_forward:
# For the root, do not reshard after forward since for training,


@ -31,7 +31,7 @@ def _get_module_state(module: nn.Module) -> Optional[_State]:
"""
global _module_state_mapping
if isinstance(module, _State):
# pyrefly: ignore # redundant-cast
# pyrefly: ignore [redundant-cast]
return cast(_State, module)
else:
# https://github.com/pytorch/pytorch/issues/107054


@ -633,7 +633,7 @@ class AsyncCollectiveTensor(torch.Tensor):
if func == torch.ops.aten.view.default:
# Fast handle aten.view as a lot of view related op goes to aten.view
# eventually, this avoids pytree slowdown
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
res = func(args[0].elem, args[1])
wrapper_res = AsyncCollectiveTensor(res)
return wrapper_res
@ -787,7 +787,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str:
FutureWarning,
stacklevel=3,
)
# pyrefly: ignore # redundant-cast
# pyrefly: ignore [redundant-cast]
return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag)
else:
raise ValueError(f"Unsupported group type: {type(group)}, {group}")
@ -1166,10 +1166,10 @@ def all_gather_inplace(
for t in tensor_list:
is_scalar = t.dim() == 0
t_offset = 1 if is_scalar else t.size(0)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
out = output[offset] if is_scalar else output[offset : offset + t_offset]
output_splits.append(out)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
offset += t_offset
for dst, src in zip(tensor_list, output_splits):
dst.copy_(src)


@ -316,7 +316,7 @@ def _local_all_gather_(
assert len(input_tensors) == 1
input_tensor = input_tensors[0]
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
output_tensors = output_tensors[0]
ranks, group_offsets, _offset = _prepare_collective_groups(process_group_so)
@ -337,12 +337,12 @@ def _local_all_gather_(
source_tensor = input_tensor
if isinstance(input_tensor, LocalTensor):
source_tensor = input_tensor._local_tensors[rank_i]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
output_tensors[i].copy_(source_tensor)
work = FakeWork()
work_so = Work.boxed(work)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return ([output_tensors], work_so)
@ -429,7 +429,7 @@ def _local_scatter_(
assert len(output_tensors) == 1
assert len(input_tensors) == 1
output_tensor = output_tensors[0]
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
input_tensors = input_tensors[0]
ranks, group_offsets, offset = _prepare_collective_groups(process_group_so)


@ -39,9 +39,9 @@ class _MeshLayout(Layout):
different from that of PyCute's.
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
shape: IntTuple
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
stride: IntTuple
def __post_init__(self) -> None:


@ -43,7 +43,7 @@ def _sharded_op_common(op, early_stop_func, extra_check):
def wrapper(types, args=(), kwargs=None, pg=None):
_basic_validation(op, args, kwargs)
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
st = args[0]
if kwargs is None:
kwargs = {}
@ -93,7 +93,7 @@ def _register_sharded_op_on_local_shards(
@_sharded_op_impl(op)
@_sharded_op_common(op, early_stop_func, extra_check)
def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
st = args[0]
st_metadata = st.metadata()
local_shards = st.local_shards()


@ -20,13 +20,13 @@ def uniform_(types, args=(), kwargs=None, pg=None):
b: the upper bound of the uniform distribution
"""
validate_param(kwargs, "kwargs")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
sharded_tensor = kwargs["tensor"]
validate_param(sharded_tensor, "tensor")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
a = kwargs["a"]
validate_param(a, "a")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
b = kwargs["b"]
validate_param(b, "b")
@ -46,13 +46,13 @@ def normal_(types, args=(), kwargs=None, pg=None):
std: the standard deviation of the normal distribution
"""
validate_param(kwargs, "kwargs")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
sharded_tensor = kwargs["tensor"]
validate_param(sharded_tensor, "tensor")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
mean = kwargs["mean"]
validate_param(mean, "mean")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
std = kwargs["std"]
validate_param(std, "std")
@ -84,16 +84,16 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None):
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
"""
validate_param(kwargs, "kwargs")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
sharded_tensor = kwargs["tensor"]
validate_param(sharded_tensor, "tensor")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
a = kwargs["a"]
validate_param(a, "a")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
mode = kwargs["mode"]
validate_param(mode, "mode")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
nonlinearity = kwargs["nonlinearity"]
validate_param(nonlinearity, "nonlinearity")
@ -113,10 +113,10 @@ def constant_(types, args=(), kwargs=None, pg=None):
val: the value to fill the tensor with
"""
validate_param(kwargs, "kwargs")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
sharded_tensor = kwargs["tensor"]
validate_param(sharded_tensor, "tensor")
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
val = kwargs["val"]
validate_param(val, "val")
for shard in sharded_tensor.local_shards():
@ -149,7 +149,7 @@ def register_tensor_creation_op(op):
if kwargs is None:
kwargs = {}
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
st = args[0]
new_st = creation_op(st.sharding_spec(), st.size(), *args[1:], **kwargs) # type: ignore[operator]


@ -40,7 +40,7 @@ _register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ig
# the device property on each rank
@_sharded_op_impl(torch.Tensor.device.__get__)
def tensor_device(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
self_st = args[0]
# Validate types
if not isinstance(self_st, ShardedTensor):
@ -57,7 +57,7 @@ def tensor_device(types, args=(), kwargs=None, pg=None):
@_sharded_op_impl(torch.Tensor.is_meta.__get__) # type: ignore[attr-defined]
def st_is_meta(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return args[0].local_tensor().is_meta
@ -198,7 +198,7 @@ _register_sharded_op_on_local_shards(
@_sharded_op_impl(torch.Tensor.requires_grad_)
def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
self_st = args[0]
# Validate types
if not isinstance(self_st, ShardedTensor):


@ -299,9 +299,9 @@ class ShardedTensor(ShardedTensorBase):
if self._init_rrefs:
with _sharded_tensor_lock:
global _sharded_tensor_current_id, _sharded_tensor_map
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._sharded_tensor_id = _sharded_tensor_current_id
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
_sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self)
_sharded_tensor_current_id += 1


@ -167,7 +167,7 @@ class ChunkShardingSpec(ShardingSpec):
)
tensors_to_scatter[
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
dist.get_group_rank(process_group, remote_global_rank)
] = tensor_to_scatter


@ -58,7 +58,7 @@ def _register_sharded_op_on_local_tensor(
@custom_sharding_spec_op(ChunkShardingSpec, op)
@_sharded_op_common(op, early_stop_func, extra_check)
def sharded_tensor_op_on_local_tensor(types, args=(), kwargs=None, pg=None):
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
st = args[0]
sharding_spec = st.sharding_spec()
if len(st.local_shards()) != 1:


@ -425,9 +425,9 @@ def _handle_row_wise_sharding(
else:
split_sizes = torch.cat(
(
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
offsets[1 : offsets.size(0)] - offsets[0:-1],
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
(input.size(0) - offsets[-1]).unsqueeze(0),
),
dim=-1,


@ -195,13 +195,13 @@ def _iterate_state_dict(
ret.local_shards()[idx].tensor, non_blocking=non_blocking
)
else:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
companion_obj.copy_(ret, non_blocking=non_blocking)
ret = companion_obj
else:
ret = {} if isinstance(ret, dict) else None
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return ret
@ -799,7 +799,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None
CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
)
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
extend_list(cur_container, prev_key)
if cur_container[prev_key] is None:
cur_container[prev_key] = def_val


@ -1848,7 +1848,7 @@ def empty(
@overload
# pyrefly: ignore # inconsistent-overload
# pyrefly: ignore [inconsistent-overload]
def empty(
size: Sequence[_int],
*,


@ -231,7 +231,7 @@ class FSDPMemTracker(MemTracker):
" or file a github issue if you need this feature."
)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs)
fsdp_state = fsdp_mod._get_fsdp_state()
@ -365,7 +365,7 @@ class FSDPMemTracker(MemTracker):
# `FSDPParamGroup.post_forward` because during AC these won't be called.
# TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786)
# lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`.
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for module in self._root_mod.modules():
if isinstance(module, FSDPModule):
fsdp_state = module._get_fsdp_state()
@ -374,7 +374,7 @@ class FSDPMemTracker(MemTracker):
fsdp_state._pre_forward_hook_handle.remove()
fsdp_state._post_forward_hook_handle.remove()
fsdp_state._pre_forward_hook_handle = (
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
module.register_forward_pre_hook(
self._fsdp_state_pre_forward(
module, fsdp_state._pre_forward
@ -383,7 +383,7 @@ class FSDPMemTracker(MemTracker):
with_kwargs=True,
)
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
fsdp_state._post_forward_hook_handle = module.register_forward_hook(
self._fsdp_state_post_forward(module, fsdp_state._post_forward),
prepend=False,
@ -402,7 +402,7 @@ class FSDPMemTracker(MemTracker):
)
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for buffer in self._root_mod.buffers():
self._update_and_maybe_create_winfos(
buffer,
@ -512,7 +512,7 @@ class FSDPMemTracker(MemTracker):
):
# N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns
# a new tensor which does not happen in eager mode, when a wait_tensor is called.
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
res = args[0]
else:
res = func(*args, **kwargs or {})
@ -529,7 +529,7 @@ class FSDPMemTracker(MemTracker):
_FSDPState.PRE_FW,
_FSDPState.PRE_BW,
]:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
output_tensor = args[0]
self._update_and_maybe_create_winfos(
output_tensor,
@ -540,7 +540,7 @@ class FSDPMemTracker(MemTracker):
func == c10d._reduce_scatter_base_.default
and self._fsdp_state == _FSDPState.POST_BW
):
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
input_tensor = args[1]
self._update_and_maybe_create_winfos(
input_tensor,


@ -143,7 +143,7 @@ class _WeakRefInfo:
self.size = size
self.element_size = element_size
self.reftype = reftype
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.device = device
self.mem_consumed = self._calculate_mem_consumed()
@ -405,7 +405,7 @@ class MemTracker(TorchDispatchMode):
# Initialize a flag to track if the total memory might drop to zero after updates.
maybe_zero = False
# Ensure the device entry exists in the current memory snapshot, initializing if necessary.
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
dev_snap = self._curr_mem_snap.setdefault(
winfo.device, dict.fromkeys(self._ref_class, 0)
)
@ -917,7 +917,7 @@ class MemTracker(TorchDispatchMode):
self._depth += 1
return self
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def __exit__(self, *args: Any) -> None:
self._depth -= 1
if self._depth == 0:
@ -935,7 +935,7 @@ class MemTracker(TorchDispatchMode):
):
# N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns
# a new tensor which does not happen in eager mode, when a wait_tensor is called.
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
res = args[0]
else:
res = func(*args, **kwargs or {})


@ -232,9 +232,9 @@ class MemoryTracker:
def _pre_forward_hook(module: nn.Module, inputs: Any) -> None:
self._cur_module_name = f"{name}.forward"
if (
# pyrefly: ignore # invalid-argument
# pyrefly: ignore [invalid-argument]
hasattr(module, "_memory_tracker_is_root")
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
and module._memory_tracker_is_root
):
self._add_marker("fw_start")
@ -250,9 +250,9 @@ class MemoryTracker:
outputs: Sequence[torch.Tensor],
) -> None:
if (
# pyrefly: ignore # invalid-argument
# pyrefly: ignore [invalid-argument]
hasattr(module, "_memory_tracker_is_root")
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
and module._memory_tracker_is_root
):
self._add_marker("fw_bw_boundary")


@ -178,7 +178,7 @@ class ModTracker:
def custom_formatwarning(msg, category, filename, lineno, line=None):
return f"{filename}:{lineno}: {category.__name__}: {msg} \n"
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
warnings.formatwarning = custom_formatwarning
warnings.warn(
"The module hierarchy tracking maybe be messed up."


@ -519,7 +519,7 @@ class RuntimeEstimator(TorchDispatchMode):
super().__enter__()
return self
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def __exit__(self, *args: Any) -> None:
print(
f"Estimated ({self._estimate_mode_type})"


@ -429,7 +429,7 @@ class SACEstimator(TorchDispatchMode):
# sdpa has non-deterministic seed, but might be deterministic
# if no dropout is applied
if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention":
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
is_rand_op = kwargs.get("dropout_p", 0) != 0
# 5. Create metadata information per active non-leaf module
for mod_fqn in self._mod_tracker.parents:


@ -65,7 +65,7 @@ def _dequantize_tensor(tensor, qtype, quant_loss=None):
elif tensor.dtype == torch.float16 and quant_loss is None:
return tensor.float()
else:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return tensor.float() / quant_loss
elif qtype == DQuantType.BFP16:
if tensor.dtype != torch.float16:


@ -22,7 +22,7 @@ def _allreduce_fut(
group_to_use = process_group if process_group is not None else dist.group.WORLD
# Apply the division first to avoid overflow, especially for FP16.
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
tensor.div_(group_to_use.size())
return (
@ -60,7 +60,7 @@ def _compress_hook(
bucket: dist.GradBucket,
) -> torch.futures.Future[torch.Tensor]:
group_to_use = process_group if process_group is not None else dist.group.WORLD
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
world_size = group_to_use.size()
buffer = (
@ -82,7 +82,7 @@ def _compress_hook(
grad = dist._functional_collectives.all_reduce(
compressed_tensor,
"sum",
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
group_to_use,
)
return decompress(grad)


@ -66,7 +66,7 @@ def quantization_pertensor_hook(
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
rank = process_group.rank() if process_group is not None else dist.get_rank()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
world_size = group_to_use.size()
tensor = bucket.buffer()
@ -148,7 +148,7 @@ def quantization_perchannel_hook(
"""
group_to_use = process_group if process_group is not None else dist.group.WORLD
rank = process_group.rank() if process_group is not None else dist.get_rank()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
world_size = group_to_use.size()
tensor = bucket.buffer()


@ -210,7 +210,7 @@ class Join:
"""
process_group = None
device = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for joinable in self._joinables:
if process_group is None:
process_group = joinable.join_process_group


@ -74,7 +74,7 @@ class HybridModel(torch.nn.Module):
assert NUM_PS * EMBEDDING_DIM >= 512
dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512)
emb_lookups_reshaped = emb_lookups_cat.reshape( # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
[emb_lookups_cat.shape[0] * dim_normalizer, 512]
)


@ -45,7 +45,7 @@ def _get_logging_handler(
return (log_handler, log_handler_name)
# pyrefly: ignore # unknown-name
# pyrefly: ignore [unknown-name]
global _c10d_logger
_c10d_logger = _get_or_create_logger()


@ -13,9 +13,9 @@ from .optimizer import load_sharded_optimizer_state_dict
from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
from .quantized_hf_storage import QuantizedHuggingFaceStorageReader
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from .state_dict_loader import load, load_state_dict
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from .state_dict_saver import async_save, save, save_state_dict
from .storage import StorageReader, StorageWriter


@ -313,7 +313,7 @@ class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor):
@_dcp_method_logger(**ckpt_kwargs)
def create_checkpoint_daemon_process() -> None:
global _CHECKPOINT_PROCESS
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info)
create_checkpoint_daemon_process()


@ -322,7 +322,7 @@ class CheckpointProcess:
subprocess_pid = self.process.processes[0].pid
# send graceful termination to sub process
try:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self._parent_end.send(
WorkerRequest(
request_type=RequestType.TERMINATE_PROCESS,


@ -176,7 +176,7 @@ class CheckpointReader:
# create a new map with all the keys present in source_value
target_value = dict.fromkeys(source_value.keys())
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for key in list(target_value.keys()):
current_path = f"{key_path}.{key}" if key_path else key
if key in source_value:


@ -147,14 +147,14 @@ class DefaultStager(CheckpointStager):
self._staging_stream = None
if self._config.use_async_staging:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._staging_executor = ThreadPoolExecutor(max_workers=1)
if torch.accelerator.is_available():
# Note: stream needs to be initialized on the main thread after default cuda
# stream is setup/used to avoid the risk of accidentally reusing the main
# compute stream or in other cases kernels actually launching from the
# main thread.
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._staging_stream = torch.Stream()
if self._config.use_non_blocking_copy:


@ -94,7 +94,7 @@ class ZStandard(StreamTransformExtension):
return zstandard is not None or pyzstd is not None
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def from_descriptor(version: str) -> "ZStandard":
if version.partition(".")[0] != "1":
raise ValueError(f"Unknown extension {version=}")
@ -217,7 +217,7 @@ class ExtensionRegistry:
ext = self.extensions.get(name)
if not ext:
raise ValueError(f"Unknown extension {name=}")
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ext.from_descriptor(version)
return [from_descriptor(desc) for desc in descriptors]


@ -128,7 +128,7 @@ def set_element(
CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
)
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
extend_list(cur_container, prev_key)
if cur_container[prev_key] is None:
cur_container[prev_key] = def_val
@ -155,7 +155,7 @@ def get_element(
elif not isinstance(cur_value, Mapping) or part not in cur_value:
return default_value
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
cur_value = cast(CONTAINER_TYPE, cur_value[part])
return cast(Optional[T], cur_value)


@ -60,7 +60,7 @@ def _init_model(rank, world_size):
optim = torch.optim.Adam(model.parameters(), lr=0.0001)
_patch_model_state_dict(model)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_patch_optimizer_state_dict(model, optimizers=optim)
return model, optim
@ -93,7 +93,7 @@ def run(rank, world_size):
loss_calc = torch.nn.BCELoss()
f = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for epoch in range(NUM_EPOCHS):
try:
torch.manual_seed(epoch)


@ -64,7 +64,7 @@ class BroadcastingTorchSaveReader(StorageReader):
self.checkpoint_id = checkpoint_id
self.coordinator_rank = coordinator_rank
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def read_metadata(self) -> Metadata:
"""Extends the default StorageReader to support building the metadata file"""
# Metadata is built in planner.set_up_planner, since we are not actually reading metadata from
@ -104,7 +104,7 @@ class BroadcastingTorchSaveReader(StorageReader):
# Broadcast the tensor from the coordinator rank
if self.is_coordinator:
pg_device = dist.distributed_c10d._get_pg_default_device()
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
tensor = torch_state_dict[req.storage_index.fqn].to(pg_device)
else:
tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn])
@ -125,7 +125,7 @@ class BroadcastingTorchSaveReader(StorageReader):
fut.set_result(None)
return fut
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
"""Implementation of the StorageReader method"""
self.is_coordinator = is_coordinator


@ -309,7 +309,7 @@ class HuggingFaceStorageReader(FileSystemReader):
fut.set_result(None)
return fut
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def read_metadata(self) -> Metadata:
from safetensors import safe_open # type: ignore[import]
from safetensors.torch import _getdtype # type: ignore[import]


@ -16,7 +16,7 @@ logger = logging.getLogger()
__all__: list[str] = []
# pyrefly: ignore # unknown-name
# pyrefly: ignore [unknown-name]
global _dcp_logger
_dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
@ -40,7 +40,7 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
checkpoint_id = getattr(serializer, "checkpoint_id", None)
msg_dict["checkpoint_id"] = (
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
str(checkpoint_id) if checkpoint_id is not None else checkpoint_id
)


@ -30,7 +30,7 @@ from torch.distributed.checkpoint.planner_helpers import (
create_read_items_for_chunk_list,
)
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from torch.distributed.checkpoint.state_dict_loader import load_state_dict
from torch.distributed.checkpoint.storage import StorageReader
from torch.distributed.checkpoint.utils import (
@ -157,7 +157,7 @@ def _get_state_dict_2d_layout(
class _ReaderWithOffset(DefaultLoadPlanner):
translation: dict[MetadataIndex, MetadataIndex]
state_dict: STATE_DICT_TYPE
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
metadata: Metadata
def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None:


@ -182,14 +182,14 @@ class DefaultStager(AsyncStager):
self._staging_executor = None
self._staging_stream = None
if self._config.use_async_staging:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._staging_executor = ThreadPoolExecutor(max_workers=1)
if torch.accelerator.is_available():
# Note: stream needs to be initialized on the main thread after default cuda
# stream is setup/used to avoid the risk of accidentally reusing the main
# compute stream or in other cases kernels actually launching from the
# main thread.
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._staging_stream = torch.Stream()
if self._config.use_non_blocking_copy:
@ -355,7 +355,7 @@ class _ReplicationStager(AsyncStager):
):
self._pg = pg
self._timeout = timeout
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self._device = device
self._transport = PGTransport(pg, timeout, device, None)


@ -200,7 +200,7 @@ def _get_fqns(
return {f"{prefix}{fqn}" for fqn in flat_param._fqns}
curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE)
if curr_obj_name != FSDP_WRAPPED_MODULE:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
fqn_obj_names.append(curr_obj_name)
curr_obj = getattr(curr_obj, curr_obj_name)
elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule):
@ -218,7 +218,7 @@ def _get_fqns(
):
if hasattr(curr_obj, removed_fqn):
curr_obj = getattr(curr_obj, removed_fqn)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
fqn_obj_names.append(curr_obj_name)
if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX:
if i != len(obj_names) - 1:


@ -328,7 +328,7 @@ def async_save(
upload_future: Future = upload_executor.execute_save(
staging_future_or_state_dict,
checkpoint_id=checkpoint_id,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
storage_writer=storage_writer,
planner=planner,
process_group=process_group,


@ -257,7 +257,7 @@ class _DistWrapper:
if len(node_failures) > 0:
result = CheckpointException(step, node_failures)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException):
raise final_result
@ -306,7 +306,7 @@ class _DistWrapper:
result = map_fun()
except BaseException as e: # noqa: B036
result = CheckpointException(step, {self.rank: _wrap_exception(e)})
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException):
raise final_result


@ -114,7 +114,7 @@ def broadcast(
error_msg += f": stage {sync_obj.stage_name}"
if sync_obj.exception is not None:
error_msg += f": exception {sync_obj.exception}"
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
raise RuntimeError(error_msg) from sync_obj.exception
return cast(T, sync_obj.payload)
@ -186,14 +186,14 @@ def all_gather(
raise RuntimeError( # type: ignore[misc]
error_msg,
exception_list,
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
) from exception_list[0]
return ret_list
else:
if not sync_obj.success:
raise RuntimeError(
f"all_gather failed with exception {sync_obj.exception}",
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
) from sync_obj.exception
return [sync_obj.payload] # type: ignore[list-item]
@ -270,13 +270,13 @@ def _summarize_ranks(ranks: Iterable[int]) -> str:
result = []
for r in ranges:
if len(r) == 1:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
result.append(f"{r.start}")
elif r.step == 1:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
result.append(f"{r.start}:{r.stop}")
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
result.append(f"{r.start}:{r.stop}:{r.step}")
return ",".join(result)


@ -253,7 +253,7 @@ else:
)
if is_initialized() and get_backend() == "threaded":
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._thread_id = threading.get_ident()
if _rank is None:
@ -443,7 +443,7 @@ else:
# We temporarily revert the reuse subgroup, since it breaks two internal tests.
# Temporarily reverting to resolve test timeout while root-causing.
# TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists.
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if bound_device_id is None or not has_split_group:
dim_group = new_group(
ranks=subgroup_ranks,


@ -372,7 +372,7 @@ class BackendConfig:
def __init__(self, backend: Backend):
"""Init."""
self.device_backend_map: dict[str, Backend] = {}
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
backend = str(backend)
if backend == Backend.UNDEFINED:
@ -412,7 +412,7 @@ class BackendConfig:
f"Invalid device:backend pairing: \
{device_backend_pair_str}. {backend_str_error_message}"
)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
device, backend = device_backend_pair
if device in self.device_backend_map:
raise ValueError(
@ -1185,7 +1185,7 @@ def _as_iterable(obj) -> collections.abc.Iterable:
def _ensure_all_tensors_same_dtype(*tensors) -> None:
last_dtype = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)):
tensor_dtype = tensor.dtype
# Mixing complex and its element type is allowed
@ -1858,7 +1858,7 @@ def _get_split_source(pg):
split_from = pg._get_backend(pg.bound_device_id)
elif pg is _world.default_pg:
try:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
split_from = pg._get_backend(torch.device("cuda"))
except RuntimeError:
# no cuda device associated with this backend
@ -2022,7 +2022,7 @@ def _new_process_group_helper(
backend_prefix_store,
group_rank,
group_size,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
timeout=timeout,
)
backend_class.options.global_ranks_in_group = global_ranks_in_group
@ -2044,7 +2044,7 @@ def _new_process_group_helper(
# default backend_options for NCCL
backend_options = ProcessGroupNCCL.Options()
backend_options.is_high_priority_stream = False
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
backend_options._timeout = timeout
if split_from:
@ -2067,7 +2067,7 @@ def _new_process_group_helper(
backend_prefix_store,
group_rank,
group_size,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
timeout=timeout,
)
backend_type = ProcessGroup.BackendType.UCC
@ -2077,7 +2077,7 @@ def _new_process_group_helper(
backend_options = ProcessGroupXCCL.Options()
backend_options.global_ranks_in_group = global_ranks_in_group
backend_options.group_name = group_name
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
backend_options._timeout = timeout
backend_class = ProcessGroupXCCL(
backend_prefix_store, group_rank, group_size, backend_options
@ -2102,7 +2102,7 @@ def _new_process_group_helper(
dist_backend_opts.store = backend_prefix_store
dist_backend_opts.group_rank = group_rank
dist_backend_opts.group_size = group_size
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
dist_backend_opts.timeout = timeout
dist_backend_opts.group_id = group_name
dist_backend_opts.global_ranks_in_group = global_ranks_in_group
@ -2146,7 +2146,7 @@ def _new_process_group_helper(
store=backend_prefix_store,
rank=group_rank,
world_size=group_size,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
timeout=timeout,
)
@ -3356,7 +3356,7 @@ def gather_object(
return
assert object_gather_list is not None, "Must provide object_gather_list on dst rank"
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
for i, tensor in enumerate(output_tensors):
tensor = tensor.type(torch.uint8)
tensor_size = object_size_list[i]
@ -3733,10 +3733,10 @@ def broadcast_object_list(
# has only one element, we can skip the copy.
if my_group_rank == group_src:
if len(tensor_list) == 1: # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
object_tensor = tensor_list[0]
else:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
object_tensor = torch.cat(tensor_list)
else:
object_tensor = torch.empty( # type: ignore[call-overload]
@ -3865,7 +3865,7 @@ def scatter_object_list(
broadcast(max_tensor_size, group_src=group_src, group=group)
# Scatter actual serialized objects
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
output_tensor = torch.empty(
max_tensor_size.item(), dtype=torch.uint8, device=pg_device
)
@ -4902,19 +4902,19 @@ def barrier(
if isinstance(device_ids, list):
opts.device_ids = device_ids
# use only the first device id
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
opts.device = torch.device(device.type, device_ids[0])
elif getattr(group, "bound_device_id", None) is not None:
# Use device id from `init_process_group(device_id=...)`
opts.device = group.bound_device_id # type: ignore[assignment]
elif device.type == "cpu" or _get_object_coll_device(group) == "cpu":
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
opts.device = torch.device("cpu")
else:
# Use the current device set by the user. If user did not set any, this
# may use default device 0, causing issues like hang or all processes
# creating context on device 0.
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
opts.device = device
if group.rank() == 0:
warnings.warn( # warn only once
@ -5045,7 +5045,7 @@ def _hash_ranks_to_str(ranks: list[int]) -> str:
# Takes a list of ranks and computes an integer color
def _process_group_color(ranks: list[int]) -> int:
# Convert list to tuple to make it hashable
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
ranks = tuple(ranks)
hash_value = hash(ranks)
# Split color must be:


@ -333,10 +333,10 @@ class LocalElasticAgent(SimpleElasticAgent):
rank=worker.global_rank,
local_rank=local_rank,
)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
log_line_prefixes[local_rank] = log_line_prefix
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
envs[local_rank] = worker_env
worker_args = list(spec.args)
worker_args = macros.substitute(worker_args, str(local_rank))


@ -54,7 +54,7 @@ class Event:
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return Event(**data_dict)
def serialize(self) -> str:
@ -109,7 +109,7 @@ class RdzvEvent:
if isinstance(data, str):
data_dict = json.loads(data)
data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return RdzvEvent(**data_dict)
def serialize(self) -> str:


@ -168,15 +168,15 @@ def profile(group=None):
try:
start_time = time.time()
result = func(*args, **kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
publish_metric(group, f"{func.__name__}.success", 1)
except Exception:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
publish_metric(group, f"{func.__name__}.failure", 1)
raise
finally:
publish_metric(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
group,
f"{func.__name__}.duration.ms",
get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined]


@ -105,7 +105,7 @@ class TailLog:
n = len(log_files)
self._threadpool = None
if n > 0:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._threadpool = ThreadPoolExecutor(
max_workers=n,
thread_name_prefix=f"{self.__class__.__qualname__}_{name}",


@ -126,7 +126,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
return tmp
def _decode_state(self, result: etcd.EtcdResult) -> tuple[bytes, Token]:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
base64_state = result.value.encode()
try:
@ -136,7 +136,7 @@ class EtcdRendezvousBackend(RendezvousBackend):
"The state object is corrupt. See inner exception for details."
) from exc
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return state, result.modifiedIndex


@ -53,7 +53,7 @@ class ElasticDistributedSampler(DistributedSampler[T]):
raise TypeError("Dataset must be an instance of collections.abc.Sized")
# Cast to Sized for mypy
# pyrefly: ignore # redundant-cast
# pyrefly: ignore [redundant-cast]
sized_dataset = cast(Sized, dataset)
if start_index >= len(sized_dataset):


@ -65,7 +65,7 @@ class _FSDPDeviceHandle:
if backend is None:
try:
self.__backend = getattr(torch, device.type)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.__device = device
except AttributeError as exc:
raise AttributeError(


@ -220,7 +220,7 @@ def _move_states_to_device(
the future.
"""
# Follow the logic in `nn.Module._apply`
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
for tensor in itertools.chain(params, buffers):
if tensor.device == device or tensor.device.type == "meta":
# Keep meta-device tensors on meta device for deferred init


@ -232,7 +232,7 @@ class FSDPParam:
self._module_info: ParamModuleInfo = module_info
self.mesh_info = mesh_info
self.post_forward_mesh_info = post_forward_mesh_info
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.device = device
self.mp_policy = mp_policy
self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy)
@ -579,7 +579,7 @@ class FSDPParam:
f"world size ({shard_world_size})"
)
shard_rank = self.post_forward_mesh_info.shard_mesh_rank
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
sharded_numel = numel // shard_world_size
self._sharded_post_forward_param_data = (
self.all_gather_outputs[0].narrow(
@ -713,7 +713,7 @@ class FSDPParam:
self.device, non_blocking=True
)
pre_all_gather_signature = inspect.signature(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
sharded_local_tensor.fsdp_pre_all_gather
)
num_fn_params = len(pre_all_gather_signature.parameters)
@ -729,7 +729,7 @@ class FSDPParam:
(
all_gather_inputs,
self._extensions_data.all_gather_metadata,
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
) = sharded_local_tensor.fsdp_pre_all_gather(
self.shard_mesh_from_root
)
@ -737,7 +737,7 @@ class FSDPParam:
(
all_gather_inputs,
self._extensions_data.all_gather_metadata,
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
) = sharded_local_tensor.fsdp_pre_all_gather(
self.shard_mesh_from_root,
self._orig_size,
@ -865,7 +865,7 @@ class FSDPParam:
f"instead of {self.sharded_param}"
)
self.sharded_param = new_param
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
local_tensor = new_param._local_tensor
if local_tensor.is_meta:
return


@ -151,7 +151,7 @@ class FSDPParamGroup:
]
self.mesh_info = mesh_info
self.post_forward_mesh_info = post_forward_mesh_info
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.device = device
self.device_handle = _get_device_handle(device.type)
self.mp_policy = mp_policy
@ -621,7 +621,7 @@ class FSDPParamGroup:
# Prefetch naively using the reverse post-forward order, which may
# have mistargeted prefetches if not all modules used in forward
# are used in this backward
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index]
self._prefetch_unshard(target_fsdp_param_group, "backward")
@ -868,7 +868,7 @@ compile the forward part if you want to use Traceable FSDP2."""
raise RuntimeError(msg)
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor):
# All tensors in `inputs` should require gradient
RegisterPostBackwardFunction._assert_not_tracing_fsdp()


@ -96,7 +96,7 @@ class FSDPState(_State):
for module in modules:
_insert_module_state(module, self)
self._modules = modules
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self._device = device
self._device_handle = _get_device_handle(device.type)
self._mp_policy = mp_policy


@ -51,7 +51,7 @@ def get_cls_to_fsdp_cls() -> dict[type, type]:
@overload
# pyrefly: ignore # inconsistent-overload
# pyrefly: ignore [inconsistent-overload]
def fully_shard(
module: nn.Module,
*,
@ -65,7 +65,7 @@ def fully_shard(
@overload
# pyrefly: ignore # inconsistent-overload
# pyrefly: ignore [inconsistent-overload]
def fully_shard(
module: list[nn.Module],
*,


@ -509,7 +509,7 @@ def _init_prefetching_state(
@no_type_check
# pyrefly: ignore # bad-function-definition
# pyrefly: ignore [bad-function-definition]
def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState:
# TODO: we need to add additional check once we support FSDP + PiPPy.
# This check is currently sufficient, since we only support FSDP + TP.
@ -918,7 +918,7 @@ def _materialize_meta_module(
# the module has directly managed parameters/buffers
module_state_iter = itertools.chain(
module.parameters(recurse=False),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
module.buffers(recurse=False),
)
has_module_states = len(list(module_state_iter)) > 0


@ -612,7 +612,7 @@ def _flatten_optim_state(
]
# Check that the unflattened parameters have the same state names
state_names = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for unflat_param_state in unflat_param_states:
if unflat_param_state is None:
continue
@ -936,7 +936,7 @@ def _rekey_sharded_optim_state_dict(
flat_param_key = unflat_param_names_to_flat_param_key.get(
key.unflat_param_names, key.unflat_param_names
)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
rekeyed_osd_state[flat_param_key] = param_state
# Only process param_groups if it exists in sharded_osd
@ -999,7 +999,7 @@ def _get_param_id_to_param_from_optim_input(
if optim_input is None:
return dict(enumerate(model.parameters()))
try:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
params = cast(list[nn.Parameter], list(optim_input))
except TypeError as e:
raise TypeError(


@ -356,7 +356,7 @@ class _RemoteModule(nn.Module):
def register_backward_hook( # type: ignore[return]
self,
hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]],
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> RemovableHandle:
_raise_not_supported(self.register_backward_hook.__name__)
@ -371,7 +371,7 @@ class _RemoteModule(nn.Module):
],
prepend: bool = False,
with_kwargs: bool = False,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> RemovableHandle:
_raise_not_supported(self.register_forward_pre_hook.__name__)
@ -383,7 +383,7 @@ class _RemoteModule(nn.Module):
],
prepend: bool = False,
with_kwargs: bool = False,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> RemovableHandle:
_raise_not_supported(self.register_forward_hook.__name__)
@ -408,7 +408,7 @@ class _RemoteModule(nn.Module):
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Iterator[tuple[str, Parameter]]:
_raise_not_supported(self.named_parameters.__name__)
@ -420,7 +420,7 @@ class _RemoteModule(nn.Module):
prefix: str = "",
recurse: bool = True,
remove_duplicate: bool = True,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Iterator[tuple[str, Tensor]]:
_raise_not_supported(self.named_buffers.__name__)
@ -584,31 +584,31 @@ class _RemoteModule(nn.Module):
remote_module = object.__new__(RemoteModule)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device)
if _module_interface_cls is not None:
# Users reply on this field to know if this generated RemoteModule is TorchScript-able.
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
remote_module.is_scriptable = True
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
remote_module._init_template(
_module_interface_cls, enable_moving_cpu_tensors_to_cuda
)
else:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
remote_module.is_scriptable = False
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
remote_module.generated_methods = (
_NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
remote_module.module_rref = module_rref
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
remote_module._install_generated_methods()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
remote_module._check_attribute_picklability()
return remote_module
@ -711,11 +711,11 @@ def _remote_module_receiver(
m.__dict__.update(serialized_remote_module._asdict())
# Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method.
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
m.module_rref = rpc.PyRRef._deserialize(m.module_rref)
# Install generated methods when unpickled.
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for method in m.generated_methods:
method_name = method.__name__
method = torch.jit.export(method)


@ -225,7 +225,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD):
class _Broadcast(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, src, group, tensor):
ctx.src = src
ctx.group = group
@ -237,7 +237,7 @@ class _Broadcast(Function):
return tensor
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output)
if ctx.src != ctx.rank:
@ -247,7 +247,7 @@ class _Broadcast(Function):
class _Gather(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, dst, group, tensor):
ctx.dst = dst
ctx.group = group
@ -273,7 +273,7 @@ class _Gather(Function):
class _Scatter(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, src, group, *tensors):
ctx.src = src
ctx.group = group
@ -286,14 +286,14 @@ class _Scatter(Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output)
class _Reduce(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, src, op, group, tensor):
ctx.src = src
ctx.group = group
@ -302,14 +302,14 @@ class _Reduce(Function):
return tensor
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),)
class _Reduce_Scatter(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, op, group, tensor, *input_tensor_list):
ctx.group = group
# Need contiguous tensors for collectives.
@ -319,14 +319,14 @@ class _Reduce_Scatter(Function):
return tensor
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return (None, None, None) + _AllGather.apply(ctx.group, grad_output)
class _AllGather(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, group, tensor):
# Need contiguous tensors for collectives.
tensor = tensor.contiguous()
@ -356,14 +356,14 @@ class _AllGather(Function):
class _AllGatherBase(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, output_tensor, input_tensor, group):
ctx.group = group
dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group)
return output_tensor
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
if dist.get_backend(group=ctx.group) is dist.Backend.NCCL:
world_size = dist.get_world_size(group=ctx.group)
@ -385,7 +385,7 @@ class _AllGatherBase(Function):
class _AlltoAll(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, group, out_tensor_list, *tensors):
ctx.group = group
ctx.input_tensor_size_list = [
@ -421,7 +421,7 @@ class _AlltoAll(Function):
class _AlltoAllSingle(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, group, output, output_split_sizes, input_split_sizes, input):
ctx.group = group
ctx.input_size = input.size()
@ -437,7 +437,7 @@ class _AlltoAllSingle(Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
tensor = torch.empty(
ctx.input_size, device=grad_output.device, dtype=grad_output.dtype
@ -455,7 +455,7 @@ class _AlltoAllSingle(Function):
class _AllReduce(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, op, group, tensor):
ctx.group = group
ctx.op = op
@ -464,6 +464,6 @@ class _AllReduce(Function):
return tensor
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),)


@ -100,19 +100,19 @@ def _broadcast_object(
data = bytearray(buffer.getbuffer())
length_tensor = torch.LongTensor([len(data)]).to(device)
data_send_tensor = torch.ByteTensor(data).to(device)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False)
else:
# Receive the object
length_tensor = torch.LongTensor([0]).to(device)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False)
data_recv_tensor = torch.empty(
[int(length_tensor.item())], dtype=torch.uint8, device=device
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False)
buffer = io.BytesIO(data_recv_tensor.cpu().numpy())
obj = torch.load(buffer, map_location=device, weights_only=False)
@ -171,7 +171,7 @@ class _DDPBucketAssignment:
if len(self.parameters) == 0:
raise ValueError("Empty bucket assignment")
# DDP guarantees all parameters in the bucket have the same device
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.device: torch.device = self.parameters[0].device
self.tensor: Optional[torch.Tensor] = None
@ -420,7 +420,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self.world_size: int = dist.get_world_size(self.process_group)
self.rank: int = dist.get_rank(self.process_group)
self.global_rank: int = dist.distributed_c10d.get_global_rank(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.process_group,
self.rank,
)
@ -542,7 +542,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
self._all_state_dicts = []
for rank in range(self.world_size):
global_rank = dist.distributed_c10d.get_global_rank(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.process_group,
rank,
)
@ -776,7 +776,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
for dev_i_buckets in self._buckets:
bucket = dev_i_buckets[rank]
global_rank = dist.distributed_c10d.get_global_rank(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.process_group,
rank,
)
@ -791,7 +791,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
else:
param_groups = self._partition_parameters()[rank]
global_rank = dist.distributed_c10d.get_global_rank(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.process_group,
rank,
)
@ -992,12 +992,12 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
for param_index, param in enumerate(bucket_params):
param_numel = param.numel()
if (
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
assignment_size + param_numel >= threshold
and param_index > bucket_offset
):
assigned_rank = self._get_min_index(
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
size_per_rank,
assigned_ranks_per_bucket[bucket_index],
)
@ -1010,7 +1010,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
assigned_rank,
assigned_ranks_per_bucket,
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
size_per_rank[assigned_rank] += assignment_size
bucket_offset = param_index
assignment_size = 0
@ -1018,7 +1018,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
# Assign the remainder of the bucket so that no assignment
# spans across two buckets
assigned_rank = self._get_min_index(
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
size_per_rank,
assigned_ranks_per_bucket[bucket_index],
)
@ -1029,7 +1029,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
assigned_rank,
assigned_ranks_per_bucket,
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
size_per_rank[assigned_rank] += assignment_size
return self._bucket_assignments_per_rank_cache
@ -1108,7 +1108,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
return loss
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def step(
self,
closure: Optional[Callable[[], float]] = None,


@ -282,7 +282,7 @@ class LossWrapper(torch.nn.Module):
class TrivialLossWrapper(LossWrapper):
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(self, x, targets):
model_out = self.module(x)
return self.loss_fn(model_out, targets)


@ -245,7 +245,7 @@ def stage_backward_weight(
if non_none_grads:
summed_grad = sum(non_none_grads)
valid_edges.append(GradientEdge(intermediate, 0))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
valid_grad_outputs.append(summed_grad)
# Break a reference cycle caused inside stage_backward_input->get_hook->hook


@ -81,7 +81,7 @@ def get_schedule_ops(
raise ValueError(f"Invalid schedule: {schedule_class}")
# Instantiate the schedule class
# pyrefly: ignore # bad-instantiation, bad-argument-type
# pyrefly: ignore [bad-instantiation, bad-argument-type]
schedule_instance = schedule_class(stages, num_microbatches)
assert schedule_instance.pipeline_order is not None


@ -246,7 +246,7 @@ def _format_pipeline_order(
pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
]
# Transpose the list of lists (rows to columns)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
# Generate column labels for ranks
num_ranks = len(pipeline_order)


@ -155,7 +155,7 @@ class _PipelineStageBase(ABC):
self.submod = submodule
self.stage_index = stage_index
self.num_stages = num_stages
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.device = device
self.group = group


@ -36,14 +36,14 @@ class _remote_device:
elif isinstance(remote_device, str):
fields = remote_device.split("/")
if len(fields) == 2:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._worker_name, self._device = fields
elif len(fields) == 1:
# Check if this is a valid device.
if _remote_device._is_valid_local_device(fields[0]):
self._device = fields[0]
else:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._worker_name = fields[0]
self._device = "cpu"
else:
@ -65,7 +65,7 @@ class _remote_device:
# rank:<rank>/device format, extract rank
if fields[0] == "rank" and fields[1].isdigit():
self._rank = int(fields[1]) # type: ignore[assignment]
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._worker_name = None
else:
raise ValueError(PARSE_ERROR)


@ -93,7 +93,7 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa
result = result._replace(
query=f"{'&'.join([f'{k}={v}' for k, v in query_dict.items()])}"
)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
url = urlunparse(result)
if result.scheme not in _rendezvous_handlers:
@ -111,7 +111,7 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
if not isinstance(world_size, numbers.Integral):
raise RuntimeError(f"`world_size` must be an integer. {world_size}")
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return _rendezvous_helper(url, rank, world_size, **kwargs)


@ -471,7 +471,7 @@ def _rref_typeof_on_user(
T = TypeVar("T")
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
GenericWithOneTypeVar = Generic[T]
@ -718,7 +718,7 @@ def _invoke_rpc(
is_async_exec = hasattr(func, "_wrapped_async_rpc_function")
if is_async_exec:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
wrapped = func._wrapped_async_rpc_function
if isinstance(wrapped, torch.jit.ScriptFunction):
func = wrapped


@ -95,7 +95,7 @@ def register_backend(
BackendType.__repr__ = _backend_type_repr # type: ignore[assignment]
if BackendType.__doc__:
BackendType.__doc__ = _backend_type_doc
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return BackendType[backend_name]


@ -48,7 +48,7 @@ else:
_TensorPipeRpcBackendOptionsBase = object # type: ignore[assignment, misc]
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase):
r"""
The backend options for


@ -5,7 +5,7 @@ import itertools
import torch
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from torch.autograd.profiler_legacy import profile
from . import (
@ -176,13 +176,13 @@ class _server_process_global_profile(profile):
flattened_function_events = list(
itertools.chain.from_iterable(process_global_function_events)
)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.function_events = torch.autograd.profiler_util.EventList(
flattened_function_events,
use_device="cuda" if self.use_cuda else None,
profile_memory=self.profile_memory,
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.function_events._build_tree()
self.process_global_function_events = process_global_function_events


@ -869,7 +869,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str
) from e
logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs)
# pyrefly: ignore # bad-instantiation
# pyrefly: ignore [bad-instantiation]
logs_specs = logs_specs_cls(
log_dir=args.log_dir,
redirects=Std.from_str(args.redirects),


@ -90,7 +90,7 @@ class DTensorSpec:
if not isinstance(self.placements, tuple):
self.placements = tuple(self.placements)
if self.shard_order is None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.shard_order = DTensorSpec.compute_default_shard_order(self.placements)
self._hash: int | None = None


@ -171,7 +171,7 @@ def einop_rule(
global_shape, input_spec.mesh, input_spec.placements
)
cost += prod(local_shape) * input_spec.mesh.size(mesh_dim)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
costs.append(cost)
d_to_keep_sharding = dims[costs.index(max(costs))]
for d in dims:


@ -138,7 +138,7 @@ class _NormPartial(Partial):
f"Expected int or float, got {type(self.norm_type)}"
)
if self.norm_type != 0 and self.norm_type != 1:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return tensor**self.norm_type
return tensor
@ -149,7 +149,7 @@ class _NormPartial(Partial):
f"Expected int or float, got {type(self.norm_type)}"
)
if self.norm_type != 0 and self.norm_type != 1:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return tensor ** (1.0 / self.norm_type)
return tensor


@ -1094,9 +1094,9 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy:
)
return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
mat1_meta = local_meta(mat1_strategy.strategies[0], input_specs[0].placements)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
mat2_meta = local_meta(mat2_strategy.strategies[0], input_specs[1].placements)
def check_valid_strides(meta: TensorMeta) -> bool:


@ -336,7 +336,7 @@ def expand_to_full_mesh_op_strategy(
for specs in zip(*strategy_comb):
if specs[0] is not None:
# TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
spec_list.append(DTensorSpec(mesh, specs))
else:
spec_list.append(None)


@ -150,7 +150,7 @@ class _RNGStateTracker:
"""
def __init__(self, device: torch.device):
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self._device = device
self._device_handle = _get_device_handle(self._device.type)
if not (self._device_handle and self._device_handle.is_available()):


@ -256,10 +256,10 @@ def convolution_backward_handler(
kwargs: dict[str, object],
) -> object:
# Redistribute grad_output tensor to the same placement as input tensor
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
args = list(args)
assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements)
args = tuple(args)


@ -594,7 +594,7 @@ class CommDebugMode(TorchDispatchMode):
self.advanced_module_tracker.__enter__()
return self
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def __exit__(self, *args):
self.advanced_module_tracker.__exit__()
super().__exit__(*args)


@ -90,7 +90,7 @@ def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=Fals
op_infos.sort(key=itemgetter(count_idx), reverse=True)
headers = ["Operator", "Schema", "Total Count", "Supported"]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
print(tabulate(op_infos, headers=headers))
if output_csv:
@ -102,5 +102,5 @@ def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=Fals
csv_writer.writerow(headers)
# Write each table row to the CSV file
for row in op_infos:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
csv_writer.writerow(row)


@ -91,10 +91,10 @@ class LocalShardsWrapper(torch.Tensor):
if func in supported_ops:
res_shards_list = [
func(shard, *args[1:], **kwargs)
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
for shard in args[0].shards
]
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return LocalShardsWrapper(res_shards_list, args[0].shard_offsets)
else:
raise NotImplementedError(
@ -144,7 +144,7 @@ def run_torchrec_row_wise_even_sharding_example(rank, world_size):
local_tensor = torch.randn(local_shard_shape, device=device)
# row-wise sharding: one shard per rank
# create the local shards wrapper
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
local_shards_wrapper = LocalShardsWrapper(
local_shards=[local_tensor],
offsets=[local_shard_offset],
@ -223,7 +223,7 @@ def run_torchrec_row_wise_uneven_sharding_example(rank, world_size):
# local shards
# row-wise sharding: one shard per rank
# create the local shards wrapper
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
local_shards_wrapper = LocalShardsWrapper(
local_shards=[local_tensor],
offsets=[local_shard_offset],
@ -302,7 +302,7 @@ def run_torchrec_table_wise_sharding_example(rank, world_size):
local_shard_offset = torch.Size((0, 0))
# wrap local shards into a wrapper
local_shards_wrapper = (
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
LocalShardsWrapper(
local_shards=[local_tensor],
offsets=[local_shard_offset],


@ -465,7 +465,7 @@ def _templated_ring_attention(
)
sdpa_merger.step(out, logsumexp, partial)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return *sdpa_merger.results(), *rest
@ -632,7 +632,7 @@ def _templated_ring_attention_backward(
grad_query,
grad_key,
grad_value,
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
*rest,
)


@ -238,7 +238,7 @@ def _local_map_wrapped(
flat_local_args.append(arg)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
local_args = pytree.tree_unflatten(flat_local_args, args_spec)
out = func(*local_args, **kwargs)
@ -272,7 +272,7 @@ def _local_map_wrapped(
flat_dist_out.append(out)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return pytree.tree_unflatten(flat_dist_out, out_spec)
else:
return out


@ -237,11 +237,11 @@ def _mark_sharding(
op_schema,
)
placement_strategies[node] = OpSpec(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_specs=_get_output_spec_from_output_sharding(output_sharding),
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
input_specs=output_sharding.redistribute_schema.args_spec
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if output_sharding.redistribute_schema is not None
else _get_input_node_specs(node, placement_strategies),
)


@ -134,11 +134,11 @@ def _rewrite_spec_if_needed(
break
if rewrite:
spec = copy.deepcopy(spec)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for i, placement in enumerate(spec.placements):
placement = cast(_remote_device, placement)
if placement.rank() == rank and placement.device() != tensor.device:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}")
return spec


@ -134,16 +134,16 @@ def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies):
from torch.nn.parallel.scatter_gather import _is_namedtuple
if _is_namedtuple(obj):
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return [type(obj)(*args) for args in zip(*map(to_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return list(zip(*map(to_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return [list(i) for i in zip(*map(to_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return [type(obj)(i) for i in zip(*map(to_map, obj.items()))]
return [obj]