diff --git a/torch/distributed/__init__.py b/torch/distributed/__init__.py index f8b5a7a75b2..f0beed8f4d4 100644 --- a/torch/distributed/__init__.py +++ b/torch/distributed/__init__.py @@ -133,7 +133,7 @@ if is_available(): # Variables prefixed with underscore are not auto imported # See the comment in `distributed_c10d.py` above `_backend` on why we expose # this. - # pyrefly: ignore # deprecated + # pyrefly: ignore [deprecated] from .distributed_c10d import * # noqa: F403 from .distributed_c10d import ( # pyrefly: ignore # deprecated _all_gather_base, diff --git a/torch/distributed/_composable/contract.py b/torch/distributed/_composable/contract.py index 507edafff18..6dedc5d4600 100644 --- a/torch/distributed/_composable/contract.py +++ b/torch/distributed/_composable/contract.py @@ -107,7 +107,7 @@ def contract( # If the user passes a sequence of modules, then we assume that # we only need to insert the state object on the root modules # (i.e. those without a parent) among the passed-in modules. - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] modules = _get_root_modules(list(module)) state = state_cls() # shared across all modules registry_item = RegistryItem() # shared across all modules @@ -119,7 +119,7 @@ def contract( all_orig_named_buffers: list[dict[str, torch.Tensor]] = [] all_orig_named_modules: list[dict[str, nn.Module]] = [] - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for module in modules: default_all_state: dict[Callable, _State] = OrderedDict() default_registry: dict[str, RegistryItem] = OrderedDict() @@ -146,11 +146,11 @@ def contract( all_state.setdefault(func, state) registry.setdefault(func.__name__, registry_item) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] all_orig_named_params.append(OrderedDict(module.named_parameters())) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] all_orig_named_buffers.append(OrderedDict(module.named_buffers())) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] all_orig_named_modules.append(OrderedDict(module.named_modules())) updated = func(inp_module, *args, **kwargs) @@ -165,13 +165,13 @@ def contract( all_new_named_params: list[dict[str, nn.Parameter]] = [] all_new_named_buffers: list[dict[str, torch.Tensor]] = [] all_new_named_modules: list[dict[str, nn.Module]] = [] - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for module in updated_modules: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] all_new_named_params.append(OrderedDict(module.named_parameters())) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] all_new_named_buffers.append(OrderedDict(module.named_buffers())) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] all_new_named_modules.append(OrderedDict(module.named_modules())) num_orig_modules = len(all_orig_named_modules) @@ -234,7 +234,7 @@ def contract( # TODO: verify that installed distributed paradigms are compatible with # each other. 
- # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return updated def get_state(module: nn.Module) -> _State: diff --git a/torch/distributed/_composable/replicate_with_fsdp.py b/torch/distributed/_composable/replicate_with_fsdp.py index 405e3381145..36802680117 100644 --- a/torch/distributed/_composable/replicate_with_fsdp.py +++ b/torch/distributed/_composable/replicate_with_fsdp.py @@ -100,7 +100,7 @@ class _ReplicateState(FSDPState): for module in modules: _insert_module_state(module, self) self._modules = modules - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self._device = device self._device_handle = _get_device_handle(device.type) self._mp_policy = mp_policy @@ -151,7 +151,7 @@ class _ReplicateState(FSDPState): ) state._is_root = False self._state_ctx.all_states.append(state) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] visited_states.add(state) if self._fsdp_param_group and self._auto_reshard_after_forward: # For the root, do not reshard after forward since for training, diff --git a/torch/distributed/_composable_state.py b/torch/distributed/_composable_state.py index 507db1bf7fc..4f2808b5452 100644 --- a/torch/distributed/_composable_state.py +++ b/torch/distributed/_composable_state.py @@ -31,7 +31,7 @@ def _get_module_state(module: nn.Module) -> Optional[_State]: """ global _module_state_mapping if isinstance(module, _State): - # pyrefly: ignore # redundant-cast + # pyrefly: ignore [redundant-cast] return cast(_State, module) else: # https://github.com/pytorch/pytorch/issues/107054 diff --git a/torch/distributed/_functional_collectives.py b/torch/distributed/_functional_collectives.py index 70dc50f1591..e760a1a0744 100644 --- a/torch/distributed/_functional_collectives.py +++ b/torch/distributed/_functional_collectives.py @@ -633,7 +633,7 @@ class AsyncCollectiveTensor(torch.Tensor): if func == torch.ops.aten.view.default: # Fast handle aten.view as a lot of view related op goes to aten.view # eventually, this avoids pytree slowdown - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] res = func(args[0].elem, args[1]) wrapper_res = AsyncCollectiveTensor(res) return wrapper_res @@ -787,7 +787,7 @@ def _resolve_group_name(group: RANK_TYPES, tag: str = "") -> str: FutureWarning, stacklevel=3, ) - # pyrefly: ignore # redundant-cast + # pyrefly: ignore [redundant-cast] return c10d._resolve_group_name_by_ranks_and_tag(cast(list[int], group), tag) else: raise ValueError(f"Unsupported group type: {type(group)}, {group}") @@ -1166,10 +1166,10 @@ def all_gather_inplace( for t in tensor_list: is_scalar = t.dim() == 0 t_offset = 1 if is_scalar else t.size(0) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] out = output[offset] if is_scalar else output[offset : offset + t_offset] output_splits.append(out) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] offset += t_offset for dst, src in zip(tensor_list, output_splits): dst.copy_(src) diff --git a/torch/distributed/_local_tensor/_c10d.py b/torch/distributed/_local_tensor/_c10d.py index 53b17a9d450..6c70e2e230f 100644 --- a/torch/distributed/_local_tensor/_c10d.py +++ b/torch/distributed/_local_tensor/_c10d.py @@ -316,7 +316,7 @@ def _local_all_gather_( assert len(input_tensors) == 1 input_tensor = input_tensors[0] - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] output_tensors = output_tensors[0] ranks, group_offsets, _offset = 
_prepare_collective_groups(process_group_so) @@ -337,12 +337,12 @@ def _local_all_gather_( source_tensor = input_tensor if isinstance(input_tensor, LocalTensor): source_tensor = input_tensor._local_tensors[rank_i] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] output_tensors[i].copy_(source_tensor) work = FakeWork() work_so = Work.boxed(work) - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return ([output_tensors], work_so) @@ -429,7 +429,7 @@ def _local_scatter_( assert len(output_tensors) == 1 assert len(input_tensors) == 1 output_tensor = output_tensors[0] - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] input_tensors = input_tensors[0] ranks, group_offsets, offset = _prepare_collective_groups(process_group_so) diff --git a/torch/distributed/_mesh_layout.py b/torch/distributed/_mesh_layout.py index 2a8355fb26c..38026b7d3d5 100644 --- a/torch/distributed/_mesh_layout.py +++ b/torch/distributed/_mesh_layout.py @@ -39,9 +39,9 @@ class _MeshLayout(Layout): different from that of PyCute's. """ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] shape: IntTuple - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] stride: IntTuple def __post_init__(self) -> None: diff --git a/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/torch/distributed/_shard/sharded_tensor/_ops/_common.py index cc965a2ab71..0a356e524a4 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/_common.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/_common.py @@ -43,7 +43,7 @@ def _sharded_op_common(op, early_stop_func, extra_check): def wrapper(types, args=(), kwargs=None, pg=None): _basic_validation(op, args, kwargs) - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] st = args[0] if kwargs is None: kwargs = {} @@ -93,7 +93,7 @@ def _register_sharded_op_on_local_shards( @_sharded_op_impl(op) @_sharded_op_common(op, early_stop_func, extra_check) def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None): - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] st = args[0] st_metadata = st.metadata() local_shards = st.local_shards() diff --git a/torch/distributed/_shard/sharded_tensor/_ops/init.py b/torch/distributed/_shard/sharded_tensor/_ops/init.py index 6c7255bb7c6..d0e576b45eb 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/init.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/init.py @@ -20,13 +20,13 @@ def uniform_(types, args=(), kwargs=None, pg=None): b: the upper bound of the uniform distribution """ validate_param(kwargs, "kwargs") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] a = kwargs["a"] validate_param(a, "a") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] b = kwargs["b"] validate_param(b, "b") @@ -46,13 +46,13 @@ def normal_(types, args=(), kwargs=None, pg=None): std: the standard deviation of the normal distribution """ validate_param(kwargs, "kwargs") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] mean = kwargs["mean"] validate_param(mean, "mean") - # pyrefly: ignore # 
unsupported-operation + # pyrefly: ignore [unsupported-operation] std = kwargs["std"] validate_param(std, "std") @@ -84,16 +84,16 @@ def kaiming_uniform_(types, args=(), kwargs=None, pg=None): recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default). """ validate_param(kwargs, "kwargs") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] a = kwargs["a"] validate_param(a, "a") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] mode = kwargs["mode"] validate_param(mode, "mode") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] nonlinearity = kwargs["nonlinearity"] validate_param(nonlinearity, "nonlinearity") @@ -113,10 +113,10 @@ def constant_(types, args=(), kwargs=None, pg=None): val: the value to fill the tensor with """ validate_param(kwargs, "kwargs") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] sharded_tensor = kwargs["tensor"] validate_param(sharded_tensor, "tensor") - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] val = kwargs["val"] validate_param(val, "val") for shard in sharded_tensor.local_shards(): @@ -149,7 +149,7 @@ def register_tensor_creation_op(op): if kwargs is None: kwargs = {} - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] st = args[0] new_st = creation_op(st.sharding_spec(), st.size(), *args[1:], **kwargs) # type: ignore[operator] diff --git a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py index a41a0bf9b15..d5b7ad7c77b 100644 --- a/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py +++ b/torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py @@ -40,7 +40,7 @@ _register_default_op(torch.Tensor.is_leaf.__get__, _sharded_op_impl) # type: ig # the device property on each rank @_sharded_op_impl(torch.Tensor.device.__get__) def tensor_device(types, args=(), kwargs=None, pg=None): - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] self_st = args[0] # Validate types if not isinstance(self_st, ShardedTensor): @@ -57,7 +57,7 @@ def tensor_device(types, args=(), kwargs=None, pg=None): @_sharded_op_impl(torch.Tensor.is_meta.__get__) # type: ignore[attr-defined] def st_is_meta(types, args=(), kwargs=None, pg=None): - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] return args[0].local_tensor().is_meta @@ -198,7 +198,7 @@ _register_sharded_op_on_local_shards( @_sharded_op_impl(torch.Tensor.requires_grad_) def tensor_requires_grad_set(types, args=(), kwargs=None, pg=None): - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] self_st = args[0] # Validate types if not isinstance(self_st, ShardedTensor): diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py index 9e2b8a5712b..7b709a2965c 100644 --- a/torch/distributed/_shard/sharded_tensor/api.py +++ b/torch/distributed/_shard/sharded_tensor/api.py @@ -299,9 +299,9 @@ class ShardedTensor(ShardedTensorBase): if self._init_rrefs: with _sharded_tensor_lock: global _sharded_tensor_current_id, _sharded_tensor_map - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._sharded_tensor_id = _sharded_tensor_current_id - # pyrefly: ignore # 
unsupported-operation + # pyrefly: ignore [unsupported-operation] _sharded_tensor_map[self._sharded_tensor_id] = weakref.ref(self) _sharded_tensor_current_id += 1 diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py index a8d8e422d1f..d4cd5728b2a 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py @@ -167,7 +167,7 @@ class ChunkShardingSpec(ShardingSpec): ) tensors_to_scatter[ - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] dist.get_group_rank(process_group, remote_global_rank) ] = tensor_to_scatter diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py index 3083a61163a..af4f4f890e9 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py @@ -58,7 +58,7 @@ def _register_sharded_op_on_local_tensor( @custom_sharding_spec_op(ChunkShardingSpec, op) @_sharded_op_common(op, early_stop_func, extra_check) def sharded_tensor_op_on_local_tensor(types, args=(), kwargs=None, pg=None): - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] st = args[0] sharding_spec = st.sharding_spec() if len(st.local_shards()) != 1: diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py index dfdefa9373f..f1581575f5f 100644 --- a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py +++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py @@ -425,9 +425,9 @@ def _handle_row_wise_sharding( else: split_sizes = torch.cat( ( - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] offsets[1 : offsets.size(0)] - offsets[0:-1], - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] (input.size(0) - offsets[-1]).unsqueeze(0), ), dim=-1, diff --git a/torch/distributed/_state_dict_utils.py b/torch/distributed/_state_dict_utils.py index 06aa9db81e9..bcf56748334 100644 --- a/torch/distributed/_state_dict_utils.py +++ b/torch/distributed/_state_dict_utils.py @@ -195,13 +195,13 @@ def _iterate_state_dict( ret.local_shards()[idx].tensor, non_blocking=non_blocking ) else: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] companion_obj.copy_(ret, non_blocking=non_blocking) ret = companion_obj else: ret = {} if isinstance(ret, dict) else None - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return ret @@ -799,7 +799,7 @@ def _set_element(root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: Any) -> None CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) ) else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] extend_list(cur_container, prev_key) if cur_container[prev_key] is None: cur_container[prev_key] = def_val diff --git a/torch/distributed/_symmetric_memory/__init__.py b/torch/distributed/_symmetric_memory/__init__.py index 8fb1677345a..6aa4584b981 100644 --- a/torch/distributed/_symmetric_memory/__init__.py +++ b/torch/distributed/_symmetric_memory/__init__.py @@ -1848,7 +1848,7 @@ def empty( @overload -# pyrefly: ignore # inconsistent-overload +# pyrefly: ignore 
[inconsistent-overload] def empty( size: Sequence[_int], *, diff --git a/torch/distributed/_tools/fsdp2_mem_tracker.py b/torch/distributed/_tools/fsdp2_mem_tracker.py index a6d20c69ecf..60ff77d0d49 100644 --- a/torch/distributed/_tools/fsdp2_mem_tracker.py +++ b/torch/distributed/_tools/fsdp2_mem_tracker.py @@ -231,7 +231,7 @@ class FSDPMemTracker(MemTracker): " or file a github issue if you need this feature." ) - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] args, kwargs = orig_fsdp_state_pre_fw(*args, **kwargs) fsdp_state = fsdp_mod._get_fsdp_state() @@ -365,7 +365,7 @@ class FSDPMemTracker(MemTracker): # `FSDPParamGroup.post_forward` because during AC these won't be called. # TODO(@sanketpurandare): This will need to be modified after this PR (https://github.com/pytorch/pytorch/pull/127786) # lands. For backward we monkey-patch the `FSDPParamGroup.pre_backward` and `FSDPParamGroup.post_backward`. - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for module in self._root_mod.modules(): if isinstance(module, FSDPModule): fsdp_state = module._get_fsdp_state() @@ -374,7 +374,7 @@ class FSDPMemTracker(MemTracker): fsdp_state._pre_forward_hook_handle.remove() fsdp_state._post_forward_hook_handle.remove() fsdp_state._pre_forward_hook_handle = ( - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] module.register_forward_pre_hook( self._fsdp_state_pre_forward( module, fsdp_state._pre_forward @@ -383,7 +383,7 @@ class FSDPMemTracker(MemTracker): with_kwargs=True, ) ) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] fsdp_state._post_forward_hook_handle = module.register_forward_hook( self._fsdp_state_post_forward(module, fsdp_state._post_forward), prepend=False, @@ -402,7 +402,7 @@ class FSDPMemTracker(MemTracker): ) ) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for buffer in self._root_mod.buffers(): self._update_and_maybe_create_winfos( buffer, @@ -512,7 +512,7 @@ class FSDPMemTracker(MemTracker): ): # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns # a new tensor which does not happen in eager mode, when a wait_tensor is called. 
- # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] res = args[0] else: res = func(*args, **kwargs or {}) @@ -529,7 +529,7 @@ class FSDPMemTracker(MemTracker): _FSDPState.PRE_FW, _FSDPState.PRE_BW, ]: - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] output_tensor = args[0] self._update_and_maybe_create_winfos( output_tensor, @@ -540,7 +540,7 @@ class FSDPMemTracker(MemTracker): func == c10d._reduce_scatter_base_.default and self._fsdp_state == _FSDPState.POST_BW ): - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] input_tensor = args[1] self._update_and_maybe_create_winfos( input_tensor, diff --git a/torch/distributed/_tools/mem_tracker.py b/torch/distributed/_tools/mem_tracker.py index 2736ca0a2f3..04f5482d7d1 100644 --- a/torch/distributed/_tools/mem_tracker.py +++ b/torch/distributed/_tools/mem_tracker.py @@ -143,7 +143,7 @@ class _WeakRefInfo: self.size = size self.element_size = element_size self.reftype = reftype - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self.device = device self.mem_consumed = self._calculate_mem_consumed() @@ -405,7 +405,7 @@ class MemTracker(TorchDispatchMode): # Initialize a flag to track if the total memory might drop to zero after updates. maybe_zero = False # Ensure the device entry exists in the current memory snapshot, initializing if necessary. - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] dev_snap = self._curr_mem_snap.setdefault( winfo.device, dict.fromkeys(self._ref_class, 0) ) @@ -917,7 +917,7 @@ class MemTracker(TorchDispatchMode): self._depth += 1 return self - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def __exit__(self, *args: Any) -> None: self._depth -= 1 if self._depth == 0: @@ -935,7 +935,7 @@ class MemTracker(TorchDispatchMode): ): # N.B: This is a hacky way to override the Meta IMPL of wait_tensor. The original impl returns # a new tensor which does not happen in eager mode, when a wait_tensor is called. 
- # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] res = args[0] else: res = func(*args, **kwargs or {}) diff --git a/torch/distributed/_tools/memory_tracker.py b/torch/distributed/_tools/memory_tracker.py index 1ee9817c95a..1dc01f62d94 100644 --- a/torch/distributed/_tools/memory_tracker.py +++ b/torch/distributed/_tools/memory_tracker.py @@ -232,9 +232,9 @@ class MemoryTracker: def _pre_forward_hook(module: nn.Module, inputs: Any) -> None: self._cur_module_name = f"{name}.forward" if ( - # pyrefly: ignore # invalid-argument + # pyrefly: ignore [invalid-argument] hasattr(module, "_memory_tracker_is_root") - # pyrefly: ignore # not-callable + # pyrefly: ignore [not-callable] and module._memory_tracker_is_root ): self._add_marker("fw_start") @@ -250,9 +250,9 @@ class MemoryTracker: outputs: Sequence[torch.Tensor], ) -> None: if ( - # pyrefly: ignore # invalid-argument + # pyrefly: ignore [invalid-argument] hasattr(module, "_memory_tracker_is_root") - # pyrefly: ignore # not-callable + # pyrefly: ignore [not-callable] and module._memory_tracker_is_root ): self._add_marker("fw_bw_boundary") diff --git a/torch/distributed/_tools/mod_tracker.py b/torch/distributed/_tools/mod_tracker.py index 32a76062ec5..3d5c1783d8a 100644 --- a/torch/distributed/_tools/mod_tracker.py +++ b/torch/distributed/_tools/mod_tracker.py @@ -178,7 +178,7 @@ class ModTracker: def custom_formatwarning(msg, category, filename, lineno, line=None): return f"{filename}:{lineno}: {category.__name__}: {msg} \n" - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] warnings.formatwarning = custom_formatwarning warnings.warn( "The module hierarchy tracking maybe be messed up." diff --git a/torch/distributed/_tools/runtime_estimator.py b/torch/distributed/_tools/runtime_estimator.py index d739d789f4a..b897e51cac9 100644 --- a/torch/distributed/_tools/runtime_estimator.py +++ b/torch/distributed/_tools/runtime_estimator.py @@ -519,7 +519,7 @@ class RuntimeEstimator(TorchDispatchMode): super().__enter__() return self - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def __exit__(self, *args: Any) -> None: print( f"Estimated ({self._estimate_mode_type})" diff --git a/torch/distributed/_tools/sac_estimator.py b/torch/distributed/_tools/sac_estimator.py index d14d8c9ae92..eaad4d26aa3 100644 --- a/torch/distributed/_tools/sac_estimator.py +++ b/torch/distributed/_tools/sac_estimator.py @@ -429,7 +429,7 @@ class SACEstimator(TorchDispatchMode): # sdpa has non-deterministic seed, but might be deterministic # if no dropout is applied if func.overloadpacket.__name__ == "_scaled_dot_product_flash_attention": - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] is_rand_op = kwargs.get("dropout_p", 0) != 0 # 5. 
Create metadata information per active non-leaf module for mod_fqn in self._mod_tracker.parents: diff --git a/torch/distributed/algorithms/_quantization/quantization.py b/torch/distributed/algorithms/_quantization/quantization.py index 23c08e63331..a1fa1fd64c0 100644 --- a/torch/distributed/algorithms/_quantization/quantization.py +++ b/torch/distributed/algorithms/_quantization/quantization.py @@ -65,7 +65,7 @@ def _dequantize_tensor(tensor, qtype, quant_loss=None): elif tensor.dtype == torch.float16 and quant_loss is None: return tensor.float() else: - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] return tensor.float() / quant_loss elif qtype == DQuantType.BFP16: if tensor.dtype != torch.float16: diff --git a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py index a1febff0a6f..20a0de7ef31 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/default_hooks.py @@ -22,7 +22,7 @@ def _allreduce_fut( group_to_use = process_group if process_group is not None else dist.group.WORLD # Apply the division first to avoid overflow, especially for FP16. - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] tensor.div_(group_to_use.size()) return ( @@ -60,7 +60,7 @@ def _compress_hook( bucket: dist.GradBucket, ) -> torch.futures.Future[torch.Tensor]: group_to_use = process_group if process_group is not None else dist.group.WORLD - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] world_size = group_to_use.size() buffer = ( @@ -82,7 +82,7 @@ def _compress_hook( grad = dist._functional_collectives.all_reduce( compressed_tensor, "sum", - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] group_to_use, ) return decompress(grad) diff --git a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py index 5224decc5ee..886155908e1 100644 --- a/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py +++ b/torch/distributed/algorithms/ddp_comm_hooks/quantization_hooks.py @@ -66,7 +66,7 @@ def quantization_pertensor_hook( """ group_to_use = process_group if process_group is not None else dist.group.WORLD rank = process_group.rank() if process_group is not None else dist.get_rank() - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] world_size = group_to_use.size() tensor = bucket.buffer() @@ -148,7 +148,7 @@ def quantization_perchannel_hook( """ group_to_use = process_group if process_group is not None else dist.group.WORLD rank = process_group.rank() if process_group is not None else dist.get_rank() - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] world_size = group_to_use.size() tensor = bucket.buffer() diff --git a/torch/distributed/algorithms/join.py b/torch/distributed/algorithms/join.py index 3b6ea6e40e3..ee07c75f7ee 100644 --- a/torch/distributed/algorithms/join.py +++ b/torch/distributed/algorithms/join.py @@ -210,7 +210,7 @@ class Join: """ process_group = None device = None - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for joinable in self._joinables: if process_group is None: process_group = joinable.join_process_group diff --git a/torch/distributed/benchmarks/benchmark_ddp_rpc.py b/torch/distributed/benchmarks/benchmark_ddp_rpc.py index 2ad8df5834f..e0279966695 100644 
--- a/torch/distributed/benchmarks/benchmark_ddp_rpc.py +++ b/torch/distributed/benchmarks/benchmark_ddp_rpc.py @@ -74,7 +74,7 @@ class HybridModel(torch.nn.Module): assert NUM_PS * EMBEDDING_DIM >= 512 dim_normalizer = int(NUM_PS * EMBEDDING_DIM / 512) emb_lookups_reshaped = emb_lookups_cat.reshape( # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] [emb_lookups_cat.shape[0] * dim_normalizer, 512] ) diff --git a/torch/distributed/c10d_logger.py b/torch/distributed/c10d_logger.py index 446682b804f..1dfae5b9296 100644 --- a/torch/distributed/c10d_logger.py +++ b/torch/distributed/c10d_logger.py @@ -45,7 +45,7 @@ def _get_logging_handler( return (log_handler, log_handler_name) -# pyrefly: ignore # unknown-name +# pyrefly: ignore [unknown-name] global _c10d_logger _c10d_logger = _get_or_create_logger() diff --git a/torch/distributed/checkpoint/__init__.py b/torch/distributed/checkpoint/__init__.py index 6f67d225998..8104a8df99f 100644 --- a/torch/distributed/checkpoint/__init__.py +++ b/torch/distributed/checkpoint/__init__.py @@ -13,9 +13,9 @@ from .optimizer import load_sharded_optimizer_state_dict from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem from .quantized_hf_storage import QuantizedHuggingFaceStorageReader -# pyrefly: ignore # deprecated +# pyrefly: ignore [deprecated] from .state_dict_loader import load, load_state_dict -# pyrefly: ignore # deprecated +# pyrefly: ignore [deprecated] from .state_dict_saver import async_save, save, save_state_dict from .storage import StorageReader, StorageWriter diff --git a/torch/distributed/checkpoint/_async_process_executor.py b/torch/distributed/checkpoint/_async_process_executor.py index 7c8aa6b6398..f7c045cdd27 100644 --- a/torch/distributed/checkpoint/_async_process_executor.py +++ b/torch/distributed/checkpoint/_async_process_executor.py @@ -313,7 +313,7 @@ class _ProcessBasedAsyncCheckpointExecutor(_AsyncCheckpointExecutor): @_dcp_method_logger(**ckpt_kwargs) def create_checkpoint_daemon_process() -> None: global _CHECKPOINT_PROCESS - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _CHECKPOINT_PROCESS = _AsyncCheckpointProcess(pg_init_info=pg_init_info) create_checkpoint_daemon_process() diff --git a/torch/distributed/checkpoint/_experimental/checkpoint_process.py b/torch/distributed/checkpoint/_experimental/checkpoint_process.py index 5fde55053ee..c71210aaa54 100644 --- a/torch/distributed/checkpoint/_experimental/checkpoint_process.py +++ b/torch/distributed/checkpoint/_experimental/checkpoint_process.py @@ -322,7 +322,7 @@ class CheckpointProcess: subprocess_pid = self.process.processes[0].pid # send graceful termination to sub process try: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self._parent_end.send( WorkerRequest( request_type=RequestType.TERMINATE_PROCESS, diff --git a/torch/distributed/checkpoint/_experimental/checkpoint_reader.py b/torch/distributed/checkpoint/_experimental/checkpoint_reader.py index fb1bcf46198..7be55938cfd 100644 --- a/torch/distributed/checkpoint/_experimental/checkpoint_reader.py +++ b/torch/distributed/checkpoint/_experimental/checkpoint_reader.py @@ -176,7 +176,7 @@ class CheckpointReader: # create a new map with all the keys present in source_value target_value = dict.fromkeys(source_value.keys()) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for key in list(target_value.keys()): current_path = f"{key_path}.{key}" if 
key_path else key if key in source_value: diff --git a/torch/distributed/checkpoint/_experimental/staging.py b/torch/distributed/checkpoint/_experimental/staging.py index 199532e2d11..2d83278e131 100644 --- a/torch/distributed/checkpoint/_experimental/staging.py +++ b/torch/distributed/checkpoint/_experimental/staging.py @@ -147,14 +147,14 @@ class DefaultStager(CheckpointStager): self._staging_stream = None if self._config.use_async_staging: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._staging_executor = ThreadPoolExecutor(max_workers=1) if torch.accelerator.is_available(): # Note: stream needs to be initialized on the main thread after default cuda # stream is setup/used to avoid the risk of accidentally reusing the main # compute stream or in other cases kernels actually launching from the # main thread. - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._staging_stream = torch.Stream() if self._config.use_non_blocking_copy: diff --git a/torch/distributed/checkpoint/_extension.py b/torch/distributed/checkpoint/_extension.py index 2bde1cfb10b..663caa8a857 100644 --- a/torch/distributed/checkpoint/_extension.py +++ b/torch/distributed/checkpoint/_extension.py @@ -94,7 +94,7 @@ class ZStandard(StreamTransformExtension): return zstandard is not None or pyzstd is not None @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def from_descriptor(version: str) -> "ZStandard": if version.partition(".")[0] != "1": raise ValueError(f"Unknown extension {version=}") @@ -217,7 +217,7 @@ class ExtensionRegistry: ext = self.extensions.get(name) if not ext: raise ValueError(f"Unknown extension {name=}") - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return ext.from_descriptor(version) return [from_descriptor(desc) for desc in descriptors] diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py index cfd605a2bfb..48eb67b4f76 100644 --- a/torch/distributed/checkpoint/_traverse.py +++ b/torch/distributed/checkpoint/_traverse.py @@ -128,7 +128,7 @@ def set_element( CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val) ) else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] extend_list(cur_container, prev_key) if cur_container[prev_key] is None: cur_container[prev_key] = def_val @@ -155,7 +155,7 @@ def get_element( elif not isinstance(cur_value, Mapping) or part not in cur_value: return default_value - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] cur_value = cast(CONTAINER_TYPE, cur_value[part]) return cast(Optional[T], cur_value) diff --git a/torch/distributed/checkpoint/examples/async_checkpointing_example.py b/torch/distributed/checkpoint/examples/async_checkpointing_example.py index c3375c37543..41bc9a28812 100644 --- a/torch/distributed/checkpoint/examples/async_checkpointing_example.py +++ b/torch/distributed/checkpoint/examples/async_checkpointing_example.py @@ -60,7 +60,7 @@ def _init_model(rank, world_size): optim = torch.optim.Adam(model.parameters(), lr=0.0001) _patch_model_state_dict(model) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] _patch_optimizer_state_dict(model, optimizers=optim) return model, optim @@ -93,7 +93,7 @@ def run(rank, world_size): loss_calc = torch.nn.BCELoss() f = None - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for epoch in range(NUM_EPOCHS): try: torch.manual_seed(epoch) diff --git 
a/torch/distributed/checkpoint/format_utils.py b/torch/distributed/checkpoint/format_utils.py index b61474f675d..129b7cf570c 100644 --- a/torch/distributed/checkpoint/format_utils.py +++ b/torch/distributed/checkpoint/format_utils.py @@ -64,7 +64,7 @@ class BroadcastingTorchSaveReader(StorageReader): self.checkpoint_id = checkpoint_id self.coordinator_rank = coordinator_rank - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def read_metadata(self) -> Metadata: """Extends the default StorageReader to support building the metadata file""" # Metadata is built in planner.set_up_planner, since we are not actually reading metadata from @@ -104,7 +104,7 @@ class BroadcastingTorchSaveReader(StorageReader): # Broadcast the tensor from the coordinator rank if self.is_coordinator: pg_device = dist.distributed_c10d._get_pg_default_device() - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] tensor = torch_state_dict[req.storage_index.fqn].to(pg_device) else: tensor = torch.empty_like(planner.state_dict[req.storage_index.fqn]) @@ -125,7 +125,7 @@ class BroadcastingTorchSaveReader(StorageReader): fut.set_result(None) return fut - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None: """Implementation of the StorageReader method""" self.is_coordinator = is_coordinator diff --git a/torch/distributed/checkpoint/hf_storage.py b/torch/distributed/checkpoint/hf_storage.py index c769565229b..52f9209da0e 100644 --- a/torch/distributed/checkpoint/hf_storage.py +++ b/torch/distributed/checkpoint/hf_storage.py @@ -309,7 +309,7 @@ class HuggingFaceStorageReader(FileSystemReader): fut.set_result(None) return fut - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def read_metadata(self) -> Metadata: from safetensors import safe_open # type: ignore[import] from safetensors.torch import _getdtype # type: ignore[import] diff --git a/torch/distributed/checkpoint/logger.py b/torch/distributed/checkpoint/logger.py index 68ad0009c44..677cac0339c 100644 --- a/torch/distributed/checkpoint/logger.py +++ b/torch/distributed/checkpoint/logger.py @@ -16,7 +16,7 @@ logger = logging.getLogger() __all__: list[str] = [] -# pyrefly: ignore # unknown-name +# pyrefly: ignore [unknown-name] global _dcp_logger _dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME) @@ -40,7 +40,7 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]: checkpoint_id = getattr(serializer, "checkpoint_id", None) msg_dict["checkpoint_id"] = ( - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] str(checkpoint_id) if checkpoint_id is not None else checkpoint_id ) diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py index 7d72633b6a9..343497da0aa 100644 --- a/torch/distributed/checkpoint/optimizer.py +++ b/torch/distributed/checkpoint/optimizer.py @@ -30,7 +30,7 @@ from torch.distributed.checkpoint.planner_helpers import ( create_read_items_for_chunk_list, ) -# pyrefly: ignore # deprecated +# pyrefly: ignore [deprecated] from torch.distributed.checkpoint.state_dict_loader import load_state_dict from torch.distributed.checkpoint.storage import StorageReader from torch.distributed.checkpoint.utils import ( @@ -157,7 +157,7 @@ def _get_state_dict_2d_layout( class _ReaderWithOffset(DefaultLoadPlanner): translation: dict[MetadataIndex, MetadataIndex] state_dict: STATE_DICT_TYPE - # pyrefly: 
ignore # bad-override + # pyrefly: ignore [bad-override] metadata: Metadata def __init__(self, fqn_to_offset: dict[str, Sequence[int]]) -> None: diff --git a/torch/distributed/checkpoint/staging.py b/torch/distributed/checkpoint/staging.py index d3ea5334d68..4bbacc66aaa 100644 --- a/torch/distributed/checkpoint/staging.py +++ b/torch/distributed/checkpoint/staging.py @@ -182,14 +182,14 @@ class DefaultStager(AsyncStager): self._staging_executor = None self._staging_stream = None if self._config.use_async_staging: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._staging_executor = ThreadPoolExecutor(max_workers=1) if torch.accelerator.is_available(): # Note: stream needs to be initialized on the main thread after default cuda # stream is setup/used to avoid the risk of accidentally reusing the main # compute stream or in other cases kernels actually launching from the # main thread. - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._staging_stream = torch.Stream() if self._config.use_non_blocking_copy: @@ -355,7 +355,7 @@ class _ReplicationStager(AsyncStager): ): self._pg = pg self._timeout = timeout - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self._device = device self._transport = PGTransport(pg, timeout, device, None) diff --git a/torch/distributed/checkpoint/state_dict.py b/torch/distributed/checkpoint/state_dict.py index f023eb949ce..f50a0ee8e60 100644 --- a/torch/distributed/checkpoint/state_dict.py +++ b/torch/distributed/checkpoint/state_dict.py @@ -200,7 +200,7 @@ def _get_fqns( return {f"{prefix}{fqn}" for fqn in flat_param._fqns} curr_obj = getattr(curr_obj, FSDP_WRAPPED_MODULE) if curr_obj_name != FSDP_WRAPPED_MODULE: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] fqn_obj_names.append(curr_obj_name) curr_obj = getattr(curr_obj, curr_obj_name) elif isinstance(curr_obj, torch._dynamo.eval_frame.OptimizedModule): @@ -218,7 +218,7 @@ def _get_fqns( ): if hasattr(curr_obj, removed_fqn): curr_obj = getattr(curr_obj, removed_fqn) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] fqn_obj_names.append(curr_obj_name) if curr_obj_name == nn.modules.module._EXTRA_STATE_KEY_SUFFIX: if i != len(obj_names) - 1: diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py index 58a4bd0e85e..ef0be9f9309 100644 --- a/torch/distributed/checkpoint/state_dict_saver.py +++ b/torch/distributed/checkpoint/state_dict_saver.py @@ -328,7 +328,7 @@ def async_save( upload_future: Future = upload_executor.execute_save( staging_future_or_state_dict, checkpoint_id=checkpoint_id, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] storage_writer=storage_writer, planner=planner, process_group=process_group, diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py index c06c5022383..94844812b52 100644 --- a/torch/distributed/checkpoint/utils.py +++ b/torch/distributed/checkpoint/utils.py @@ -257,7 +257,7 @@ class _DistWrapper: if len(node_failures) > 0: result = CheckpointException(step, node_failures) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] final_result = self.broadcast_object(result) if isinstance(final_result, CheckpointException): raise final_result @@ -306,7 +306,7 @@ class _DistWrapper: result = map_fun() except BaseException as e: # noqa: B036 result = CheckpointException(step, {self.rank: _wrap_exception(e)}) - # 
pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] final_result = self.broadcast_object(result) if isinstance(final_result, CheckpointException): raise final_result diff --git a/torch/distributed/collective_utils.py b/torch/distributed/collective_utils.py index b61155274bc..b45f0b5cbb4 100644 --- a/torch/distributed/collective_utils.py +++ b/torch/distributed/collective_utils.py @@ -114,7 +114,7 @@ def broadcast( error_msg += f": stage {sync_obj.stage_name}" if sync_obj.exception is not None: error_msg += f": exception {sync_obj.exception}" - # pyrefly: ignore # invalid-inheritance + # pyrefly: ignore [invalid-inheritance] raise RuntimeError(error_msg) from sync_obj.exception return cast(T, sync_obj.payload) @@ -186,14 +186,14 @@ def all_gather( raise RuntimeError( # type: ignore[misc] error_msg, exception_list, - # pyrefly: ignore # invalid-inheritance + # pyrefly: ignore [invalid-inheritance] ) from exception_list[0] return ret_list else: if not sync_obj.success: raise RuntimeError( f"all_gather failed with exception {sync_obj.exception}", - # pyrefly: ignore # invalid-inheritance + # pyrefly: ignore [invalid-inheritance] ) from sync_obj.exception return [sync_obj.payload] # type: ignore[list-item] @@ -270,13 +270,13 @@ def _summarize_ranks(ranks: Iterable[int]) -> str: result = [] for r in ranges: if len(r) == 1: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] result.append(f"{r.start}") elif r.step == 1: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] result.append(f"{r.start}:{r.stop}") else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] result.append(f"{r.start}:{r.stop}:{r.step}") return ",".join(result) diff --git a/torch/distributed/device_mesh.py b/torch/distributed/device_mesh.py index 5c8969091d6..052b74ba479 100644 --- a/torch/distributed/device_mesh.py +++ b/torch/distributed/device_mesh.py @@ -253,7 +253,7 @@ else: ) if is_initialized() and get_backend() == "threaded": - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._thread_id = threading.get_ident() if _rank is None: @@ -443,7 +443,7 @@ else: # We temporarily revert the reuse subgroup, since it breaks two internal tests. # Temporarily reverting to resolve test timeout while root-causing. # TODO: Add two tests to cover internal tests scenarios and re-enable reuse subgroup if exists. - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] if bound_device_id is None or not has_split_group: dim_group = new_group( ranks=subgroup_ranks, diff --git a/torch/distributed/distributed_c10d.py b/torch/distributed/distributed_c10d.py index 96b2eeb7ef2..52370a4545f 100644 --- a/torch/distributed/distributed_c10d.py +++ b/torch/distributed/distributed_c10d.py @@ -372,7 +372,7 @@ class BackendConfig: def __init__(self, backend: Backend): """Init.""" self.device_backend_map: dict[str, Backend] = {} - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] backend = str(backend) if backend == Backend.UNDEFINED: @@ -412,7 +412,7 @@ class BackendConfig: f"Invalid device:backend pairing: \ {device_backend_pair_str}. 
{backend_str_error_message}" ) - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] device, backend = device_backend_pair if device in self.device_backend_map: raise ValueError( @@ -1185,7 +1185,7 @@ def _as_iterable(obj) -> collections.abc.Iterable: def _ensure_all_tensors_same_dtype(*tensors) -> None: last_dtype = None - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for tensor in itertools.chain.from_iterable(map(_as_iterable, tensors)): tensor_dtype = tensor.dtype # Mixing complex and its element type is allowed @@ -1858,7 +1858,7 @@ def _get_split_source(pg): split_from = pg._get_backend(pg.bound_device_id) elif pg is _world.default_pg: try: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] split_from = pg._get_backend(torch.device("cuda")) except RuntimeError: # no cuda device associated with this backend @@ -2022,7 +2022,7 @@ def _new_process_group_helper( backend_prefix_store, group_rank, group_size, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] timeout=timeout, ) backend_class.options.global_ranks_in_group = global_ranks_in_group @@ -2044,7 +2044,7 @@ def _new_process_group_helper( # default backend_options for NCCL backend_options = ProcessGroupNCCL.Options() backend_options.is_high_priority_stream = False - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] backend_options._timeout = timeout if split_from: @@ -2067,7 +2067,7 @@ def _new_process_group_helper( backend_prefix_store, group_rank, group_size, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] timeout=timeout, ) backend_type = ProcessGroup.BackendType.UCC @@ -2077,7 +2077,7 @@ def _new_process_group_helper( backend_options = ProcessGroupXCCL.Options() backend_options.global_ranks_in_group = global_ranks_in_group backend_options.group_name = group_name - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] backend_options._timeout = timeout backend_class = ProcessGroupXCCL( backend_prefix_store, group_rank, group_size, backend_options @@ -2102,7 +2102,7 @@ def _new_process_group_helper( dist_backend_opts.store = backend_prefix_store dist_backend_opts.group_rank = group_rank dist_backend_opts.group_size = group_size - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] dist_backend_opts.timeout = timeout dist_backend_opts.group_id = group_name dist_backend_opts.global_ranks_in_group = global_ranks_in_group @@ -2146,7 +2146,7 @@ def _new_process_group_helper( store=backend_prefix_store, rank=group_rank, world_size=group_size, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] timeout=timeout, ) @@ -3356,7 +3356,7 @@ def gather_object( return assert object_gather_list is not None, "Must provide object_gather_list on dst rank" - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] for i, tensor in enumerate(output_tensors): tensor = tensor.type(torch.uint8) tensor_size = object_size_list[i] @@ -3733,10 +3733,10 @@ def broadcast_object_list( # has only one element, we can skip the copy. 
if my_group_rank == group_src: if len(tensor_list) == 1: # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] object_tensor = tensor_list[0] else: - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] object_tensor = torch.cat(tensor_list) else: object_tensor = torch.empty( # type: ignore[call-overload] @@ -3865,7 +3865,7 @@ def scatter_object_list( broadcast(max_tensor_size, group_src=group_src, group=group) # Scatter actual serialized objects - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] output_tensor = torch.empty( max_tensor_size.item(), dtype=torch.uint8, device=pg_device ) @@ -4902,19 +4902,19 @@ def barrier( if isinstance(device_ids, list): opts.device_ids = device_ids # use only the first device id - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] opts.device = torch.device(device.type, device_ids[0]) elif getattr(group, "bound_device_id", None) is not None: # Use device id from `init_process_group(device_id=...)` opts.device = group.bound_device_id # type: ignore[assignment] elif device.type == "cpu" or _get_object_coll_device(group) == "cpu": - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] opts.device = torch.device("cpu") else: # Use the current device set by the user. If user did not set any, this # may use default device 0, causing issues like hang or all processes # creating context on device 0. - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] opts.device = device if group.rank() == 0: warnings.warn( # warn only once @@ -5045,7 +5045,7 @@ def _hash_ranks_to_str(ranks: list[int]) -> str: # Takes a list of ranks and computes an integer color def _process_group_color(ranks: list[int]) -> int: # Convert list to tuple to make it hashable - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] ranks = tuple(ranks) hash_value = hash(ranks) # Split color must be: diff --git a/torch/distributed/elastic/agent/server/local_elastic_agent.py b/torch/distributed/elastic/agent/server/local_elastic_agent.py index 42d46a23fac..f643de5f9b2 100644 --- a/torch/distributed/elastic/agent/server/local_elastic_agent.py +++ b/torch/distributed/elastic/agent/server/local_elastic_agent.py @@ -333,10 +333,10 @@ class LocalElasticAgent(SimpleElasticAgent): rank=worker.global_rank, local_rank=local_rank, ) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] log_line_prefixes[local_rank] = log_line_prefix - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] envs[local_rank] = worker_env worker_args = list(spec.args) worker_args = macros.substitute(worker_args, str(local_rank)) diff --git a/torch/distributed/elastic/events/api.py b/torch/distributed/elastic/events/api.py index 2e340c47afa..939ab0793f6 100644 --- a/torch/distributed/elastic/events/api.py +++ b/torch/distributed/elastic/events/api.py @@ -54,7 +54,7 @@ class Event: if isinstance(data, str): data_dict = json.loads(data) data_dict["source"] = EventSource[data_dict["source"]] # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return Event(**data_dict) def serialize(self) -> str: @@ -109,7 +109,7 @@ class RdzvEvent: if isinstance(data, str): data_dict = json.loads(data) data_dict["node_state"] = NodeState[data_dict["node_state"]] # type: ignore[possibly-undefined] - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return 
RdzvEvent(**data_dict) def serialize(self) -> str: diff --git a/torch/distributed/elastic/metrics/api.py b/torch/distributed/elastic/metrics/api.py index 0bfa255174d..07d0f9fc43c 100644 --- a/torch/distributed/elastic/metrics/api.py +++ b/torch/distributed/elastic/metrics/api.py @@ -168,15 +168,15 @@ def profile(group=None): try: start_time = time.time() result = func(*args, **kwargs) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] publish_metric(group, f"{func.__name__}.success", 1) except Exception: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] publish_metric(group, f"{func.__name__}.failure", 1) raise finally: publish_metric( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] group, f"{func.__name__}.duration.ms", get_elapsed_time_ms(start_time), # type: ignore[possibly-undefined] diff --git a/torch/distributed/elastic/multiprocessing/tail_log.py b/torch/distributed/elastic/multiprocessing/tail_log.py index 7aaf115625f..7ad35115cd3 100644 --- a/torch/distributed/elastic/multiprocessing/tail_log.py +++ b/torch/distributed/elastic/multiprocessing/tail_log.py @@ -105,7 +105,7 @@ class TailLog: n = len(log_files) self._threadpool = None if n > 0: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._threadpool = ThreadPoolExecutor( max_workers=n, thread_name_prefix=f"{self.__class__.__qualname__}_{name}", diff --git a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py index 9a69dff151a..a0012607ce3 100644 --- a/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py +++ b/torch/distributed/elastic/rendezvous/etcd_rendezvous_backend.py @@ -126,7 +126,7 @@ class EtcdRendezvousBackend(RendezvousBackend): return tmp def _decode_state(self, result: etcd.EtcdResult) -> tuple[bytes, Token]: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] base64_state = result.value.encode() try: @@ -136,7 +136,7 @@ class EtcdRendezvousBackend(RendezvousBackend): "The state object is corrupt. See inner exception for details." 
) from exc - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return state, result.modifiedIndex diff --git a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py index 8cf489cd18f..a10d49ae489 100644 --- a/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py +++ b/torch/distributed/elastic/utils/data/elastic_distributed_sampler.py @@ -53,7 +53,7 @@ class ElasticDistributedSampler(DistributedSampler[T]): raise TypeError("Dataset must be an instance of collections.abc.Sized") # Cast to Sized for mypy - # pyrefly: ignore # redundant-cast + # pyrefly: ignore [redundant-cast] sized_dataset = cast(Sized, dataset) if start_index >= len(sized_dataset): diff --git a/torch/distributed/fsdp/_common_utils.py b/torch/distributed/fsdp/_common_utils.py index 8e63d881838..a995e567bba 100644 --- a/torch/distributed/fsdp/_common_utils.py +++ b/torch/distributed/fsdp/_common_utils.py @@ -65,7 +65,7 @@ class _FSDPDeviceHandle: if backend is None: try: self.__backend = getattr(torch, device.type) - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self.__device = device except AttributeError as exc: raise AttributeError( diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_init.py b/torch/distributed/fsdp/_fully_shard/_fsdp_init.py index 5239c1add11..01d196795c3 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_init.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_init.py @@ -220,7 +220,7 @@ def _move_states_to_device( the future. """ # Follow the logic in `nn.Module._apply` - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] for tensor in itertools.chain(params, buffers): if tensor.device == device or tensor.device.type == "meta": # Keep meta-device tensors on meta device for deferred init diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py index 376898d519f..3fd2ec5373d 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param.py @@ -232,7 +232,7 @@ class FSDPParam: self._module_info: ParamModuleInfo = module_info self.mesh_info = mesh_info self.post_forward_mesh_info = post_forward_mesh_info - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self.device = device self.mp_policy = mp_policy self.offload_to_cpu: bool = isinstance(offload_policy, CPUOffloadPolicy) @@ -579,7 +579,7 @@ class FSDPParam: f"world size ({shard_world_size})" ) shard_rank = self.post_forward_mesh_info.shard_mesh_rank - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] sharded_numel = numel // shard_world_size self._sharded_post_forward_param_data = ( self.all_gather_outputs[0].narrow( @@ -713,7 +713,7 @@ class FSDPParam: self.device, non_blocking=True ) pre_all_gather_signature = inspect.signature( - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] sharded_local_tensor.fsdp_pre_all_gather ) num_fn_params = len(pre_all_gather_signature.parameters) @@ -729,7 +729,7 @@ class FSDPParam: ( all_gather_inputs, self._extensions_data.all_gather_metadata, - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] ) = sharded_local_tensor.fsdp_pre_all_gather( self.shard_mesh_from_root ) @@ -737,7 +737,7 @@ class FSDPParam: ( all_gather_inputs, self._extensions_data.all_gather_metadata, - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] ) = 
sharded_local_tensor.fsdp_pre_all_gather( self.shard_mesh_from_root, self._orig_size, @@ -865,7 +865,7 @@ class FSDPParam: f"instead of {self.sharded_param}" ) self.sharded_param = new_param - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] local_tensor = new_param._local_tensor if local_tensor.is_meta: return diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py index 32939a55450..f2eac802bb6 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_param_group.py @@ -151,7 +151,7 @@ class FSDPParamGroup: ] self.mesh_info = mesh_info self.post_forward_mesh_info = post_forward_mesh_info - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self.device = device self.device_handle = _get_device_handle(device.type) self.mp_policy = mp_policy @@ -621,7 +621,7 @@ class FSDPParamGroup: # Prefetch naively using the reverse post-forward order, which may # have mistargeted prefetches if not all modules used in forward # are used in this backward - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] target_fsdp_param_group = self.comm_ctx.post_forward_order[target_index] self._prefetch_unshard(target_fsdp_param_group, "backward") @@ -868,7 +868,7 @@ compile the forward part if you want to use Traceable FSDP2.""" raise RuntimeError(msg) @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, param_group: FSDPParamGroup, *inputs: torch.Tensor): # All tensors in `inputs` should require gradient RegisterPostBackwardFunction._assert_not_tracing_fsdp() diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py index 6484c94d3ca..d68dfbf2ddc 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_state.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_state.py @@ -96,7 +96,7 @@ class FSDPState(_State): for module in modules: _insert_module_state(module, self) self._modules = modules - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self._device = device self._device_handle = _get_device_handle(device.type) self._mp_policy = mp_policy diff --git a/torch/distributed/fsdp/_fully_shard/_fully_shard.py b/torch/distributed/fsdp/_fully_shard/_fully_shard.py index 54541656206..998a33746f9 100644 --- a/torch/distributed/fsdp/_fully_shard/_fully_shard.py +++ b/torch/distributed/fsdp/_fully_shard/_fully_shard.py @@ -51,7 +51,7 @@ def get_cls_to_fsdp_cls() -> dict[type, type]: @overload -# pyrefly: ignore # inconsistent-overload +# pyrefly: ignore [inconsistent-overload] def fully_shard( module: nn.Module, *, @@ -65,7 +65,7 @@ def fully_shard( @overload -# pyrefly: ignore # inconsistent-overload +# pyrefly: ignore [inconsistent-overload] def fully_shard( module: list[nn.Module], *, diff --git a/torch/distributed/fsdp/_init_utils.py b/torch/distributed/fsdp/_init_utils.py index f157bbd565a..74cc12dc889 100644 --- a/torch/distributed/fsdp/_init_utils.py +++ b/torch/distributed/fsdp/_init_utils.py @@ -509,7 +509,7 @@ def _init_prefetching_state( @no_type_check -# pyrefly: ignore # bad-function-definition +# pyrefly: ignore [bad-function-definition] def _init_extension(state: _FSDPState, device_mesh: DeviceMesh = None) -> _FSDPState: # TODO: we need to add additional check once we support FSDP + PiPPy. # This check is currently sufficient, since we only support FSDP + TP. 
@@ -918,7 +918,7 @@ def _materialize_meta_module( # the module has directly managed parameters/buffers module_state_iter = itertools.chain( module.parameters(recurse=False), - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] module.buffers(recurse=False), ) has_module_states = len(list(module_state_iter)) > 0 diff --git a/torch/distributed/fsdp/_optim_utils.py b/torch/distributed/fsdp/_optim_utils.py index 300be17b6ab..3c64bfbf2f6 100644 --- a/torch/distributed/fsdp/_optim_utils.py +++ b/torch/distributed/fsdp/_optim_utils.py @@ -612,7 +612,7 @@ def _flatten_optim_state( ] # Check that the unflattened parameters have the same state names state_names = None - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for unflat_param_state in unflat_param_states: if unflat_param_state is None: continue @@ -936,7 +936,7 @@ def _rekey_sharded_optim_state_dict( flat_param_key = unflat_param_names_to_flat_param_key.get( key.unflat_param_names, key.unflat_param_names ) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] rekeyed_osd_state[flat_param_key] = param_state # Only process param_groups if it exists in sharded_osd @@ -999,7 +999,7 @@ def _get_param_id_to_param_from_optim_input( if optim_input is None: return dict(enumerate(model.parameters())) try: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] params = cast(list[nn.Parameter], list(optim_input)) except TypeError as e: raise TypeError( diff --git a/torch/distributed/nn/api/remote_module.py b/torch/distributed/nn/api/remote_module.py index f4313d4cbee..066197fad24 100644 --- a/torch/distributed/nn/api/remote_module.py +++ b/torch/distributed/nn/api/remote_module.py @@ -356,7 +356,7 @@ class _RemoteModule(nn.Module): def register_backward_hook( # type: ignore[return] self, hook: Callable[[Module, _grad_t, _grad_t], Union[None, _grad_t]], - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> RemovableHandle: _raise_not_supported(self.register_backward_hook.__name__) @@ -371,7 +371,7 @@ class _RemoteModule(nn.Module): ], prepend: bool = False, with_kwargs: bool = False, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> RemovableHandle: _raise_not_supported(self.register_forward_pre_hook.__name__) @@ -383,7 +383,7 @@ class _RemoteModule(nn.Module): ], prepend: bool = False, with_kwargs: bool = False, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> RemovableHandle: _raise_not_supported(self.register_forward_hook.__name__) @@ -408,7 +408,7 @@ class _RemoteModule(nn.Module): prefix: str = "", recurse: bool = True, remove_duplicate: bool = True, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Iterator[tuple[str, Parameter]]: _raise_not_supported(self.named_parameters.__name__) @@ -420,7 +420,7 @@ class _RemoteModule(nn.Module): prefix: str = "", recurse: bool = True, remove_duplicate: bool = True, - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] ) -> Iterator[tuple[str, Tensor]]: _raise_not_supported(self.named_buffers.__name__) @@ -584,31 +584,31 @@ class _RemoteModule(nn.Module): remote_module = object.__new__(RemoteModule) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] enable_moving_cpu_tensors_to_cuda = remote_module._prepare_init(remote_device) if _module_interface_cls is not None: # Users reply on this field to know if this generated RemoteModule is TorchScript-able. 
- # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] remote_module.is_scriptable = True - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] remote_module._init_template( _module_interface_cls, enable_moving_cpu_tensors_to_cuda ) else: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] remote_module.is_scriptable = False - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] remote_module.generated_methods = ( _NON_SCRIPTABLE_REMOTE_MODULE_MODULE._generated_methods ) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] remote_module.module_rref = module_rref - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] remote_module._install_generated_methods() - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] remote_module._check_attribute_picklability() return remote_module @@ -711,11 +711,11 @@ def _remote_module_receiver( m.__dict__.update(serialized_remote_module._asdict()) # Unpickling the attribute `module_rref` must invoke RRef's `_deserialize()` method. - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] m.module_rref = rpc.PyRRef._deserialize(m.module_rref) # Install generated methods when unpickled. - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for method in m.generated_methods: method_name = method.__name__ method = torch.jit.export(method) diff --git a/torch/distributed/nn/functional.py b/torch/distributed/nn/functional.py index ff2c776348e..287775be924 100644 --- a/torch/distributed/nn/functional.py +++ b/torch/distributed/nn/functional.py @@ -225,7 +225,7 @@ def all_reduce(tensor, op=ReduceOp.SUM, group=group.WORLD): class _Broadcast(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, src, group, tensor): ctx.src = src ctx.group = group @@ -237,7 +237,7 @@ class _Broadcast(Function): return tensor @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): gx = _Reduce.apply(ctx.src, ReduceOp.SUM, ctx.group, grad_output) if ctx.src != ctx.rank: @@ -247,7 +247,7 @@ class _Broadcast(Function): class _Gather(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, dst, group, tensor): ctx.dst = dst ctx.group = group @@ -273,7 +273,7 @@ class _Gather(Function): class _Scatter(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, src, group, *tensors): ctx.src = src ctx.group = group @@ -286,14 +286,14 @@ class _Scatter(Function): return output @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): return (None, None) + _Gather.apply(ctx.src, ctx.group, grad_output) class _Reduce(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, src, op, group, tensor): ctx.src = src ctx.group = group @@ -302,14 +302,14 @@ class _Reduce(Function): return tensor @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): return (None, None, None) + (_Broadcast.apply(ctx.src, ctx.group, grad_output),) class _Reduce_Scatter(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, op, group, tensor, 
*input_tensor_list): ctx.group = group # Need contiguous tensors for collectives. @@ -319,14 +319,14 @@ class _Reduce_Scatter(Function): return tensor @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): return (None, None, None) + _AllGather.apply(ctx.group, grad_output) class _AllGather(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, group, tensor): # Need contiguous tensors for collectives. tensor = tensor.contiguous() @@ -356,14 +356,14 @@ class _AllGather(Function): class _AllGatherBase(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, output_tensor, input_tensor, group): ctx.group = group dist._all_gather_base(output_tensor, input_tensor.contiguous(), group=group) return output_tensor @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): if dist.get_backend(group=ctx.group) is dist.Backend.NCCL: world_size = dist.get_world_size(group=ctx.group) @@ -385,7 +385,7 @@ class _AllGatherBase(Function): class _AlltoAll(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, group, out_tensor_list, *tensors): ctx.group = group ctx.input_tensor_size_list = [ @@ -421,7 +421,7 @@ class _AlltoAll(Function): class _AlltoAllSingle(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, group, output, output_split_sizes, input_split_sizes, input): ctx.group = group ctx.input_size = input.size() @@ -437,7 +437,7 @@ class _AlltoAllSingle(Function): return output @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): tensor = torch.empty( ctx.input_size, device=grad_output.device, dtype=grad_output.dtype @@ -455,7 +455,7 @@ class _AlltoAllSingle(Function): class _AllReduce(Function): @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(ctx, op, group, tensor): ctx.group = group ctx.op = op @@ -464,6 +464,6 @@ class _AllReduce(Function): return tensor @staticmethod - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def backward(ctx, grad_output): return (None, None) + (_AllReduce.apply(ctx.op, ctx.group, grad_output),) diff --git a/torch/distributed/optim/zero_redundancy_optimizer.py b/torch/distributed/optim/zero_redundancy_optimizer.py index 18f31ade189..8c82b53eff7 100644 --- a/torch/distributed/optim/zero_redundancy_optimizer.py +++ b/torch/distributed/optim/zero_redundancy_optimizer.py @@ -100,19 +100,19 @@ def _broadcast_object( data = bytearray(buffer.getbuffer()) length_tensor = torch.LongTensor([len(data)]).to(device) data_send_tensor = torch.ByteTensor(data).to(device) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] dist.broadcast(data_send_tensor, src=src_rank, group=group, async_op=False) else: # Receive the object length_tensor = torch.LongTensor([0]).to(device) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] dist.broadcast(length_tensor, src=src_rank, group=group, async_op=False) data_recv_tensor = torch.empty( [int(length_tensor.item())], dtype=torch.uint8, device=device ) - # pyrefly: ignore # 
bad-argument-type + # pyrefly: ignore [bad-argument-type] dist.broadcast(data_recv_tensor, src=src_rank, group=group, async_op=False) buffer = io.BytesIO(data_recv_tensor.cpu().numpy()) obj = torch.load(buffer, map_location=device, weights_only=False) @@ -171,7 +171,7 @@ class _DDPBucketAssignment: if len(self.parameters) == 0: raise ValueError("Empty bucket assignment") # DDP guarantees all parameters in the bucket have the same device - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self.device: torch.device = self.parameters[0].device self.tensor: Optional[torch.Tensor] = None @@ -420,7 +420,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self.world_size: int = dist.get_world_size(self.process_group) self.rank: int = dist.get_rank(self.process_group) self.global_rank: int = dist.distributed_c10d.get_global_rank( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.process_group, self.rank, ) @@ -542,7 +542,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): self._all_state_dicts = [] for rank in range(self.world_size): global_rank = dist.distributed_c10d.get_global_rank( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.process_group, rank, ) @@ -776,7 +776,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): for dev_i_buckets in self._buckets: bucket = dev_i_buckets[rank] global_rank = dist.distributed_c10d.get_global_rank( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.process_group, rank, ) @@ -791,7 +791,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): else: param_groups = self._partition_parameters()[rank] global_rank = dist.distributed_c10d.get_global_rank( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.process_group, rank, ) @@ -992,12 +992,12 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): for param_index, param in enumerate(bucket_params): param_numel = param.numel() if ( - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] assignment_size + param_numel >= threshold and param_index > bucket_offset ): assigned_rank = self._get_min_index( - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] size_per_rank, assigned_ranks_per_bucket[bucket_index], ) @@ -1010,7 +1010,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): assigned_rank, assigned_ranks_per_bucket, ) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] size_per_rank[assigned_rank] += assignment_size bucket_offset = param_index assignment_size = 0 @@ -1018,7 +1018,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): # Assign the remainder of the bucket so that no assignment # spans across two buckets assigned_rank = self._get_min_index( - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] size_per_rank, assigned_ranks_per_bucket[bucket_index], ) @@ -1029,7 +1029,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): assigned_rank, assigned_ranks_per_bucket, ) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] size_per_rank[assigned_rank] += assignment_size return self._bucket_assignments_per_rank_cache @@ -1108,7 +1108,7 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable): return loss - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def step( self, closure: Optional[Callable[[], float]] = None, diff --git a/torch/distributed/pipelining/_IR.py b/torch/distributed/pipelining/_IR.py index 45e90c4f3aa..52e56dd3f95 100644 
--- a/torch/distributed/pipelining/_IR.py +++ b/torch/distributed/pipelining/_IR.py @@ -282,7 +282,7 @@ class LossWrapper(torch.nn.Module): class TrivialLossWrapper(LossWrapper): - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def forward(self, x, targets): model_out = self.module(x) return self.loss_fn(model_out, targets) diff --git a/torch/distributed/pipelining/_backward.py b/torch/distributed/pipelining/_backward.py index 5410c9b9448..38d30c793e8 100644 --- a/torch/distributed/pipelining/_backward.py +++ b/torch/distributed/pipelining/_backward.py @@ -245,7 +245,7 @@ def stage_backward_weight( if non_none_grads: summed_grad = sum(non_none_grads) valid_edges.append(GradientEdge(intermediate, 0)) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] valid_grad_outputs.append(summed_grad) # Break a reference cycle caused inside stage_backward_input->get_hook->hook diff --git a/torch/distributed/pipelining/_schedule_visualizer.py b/torch/distributed/pipelining/_schedule_visualizer.py index 365cdd246b3..e5891c775a6 100644 --- a/torch/distributed/pipelining/_schedule_visualizer.py +++ b/torch/distributed/pipelining/_schedule_visualizer.py @@ -81,7 +81,7 @@ def get_schedule_ops( raise ValueError(f"Invalid schedule: {schedule_class}") # Instantiate the schedule class - # pyrefly: ignore # bad-instantiation, bad-argument-type + # pyrefly: ignore [bad-instantiation, bad-argument-type] schedule_instance = schedule_class(stages, num_microbatches) assert schedule_instance.pipeline_order is not None diff --git a/torch/distributed/pipelining/schedules.py b/torch/distributed/pipelining/schedules.py index ff8e19d4f7e..39da483fe00 100644 --- a/torch/distributed/pipelining/schedules.py +++ b/torch/distributed/pipelining/schedules.py @@ -246,7 +246,7 @@ def _format_pipeline_order( pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order) ] # Transpose the list of lists (rows to columns) - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue="")) # Generate column labels for ranks num_ranks = len(pipeline_order) diff --git a/torch/distributed/pipelining/stage.py b/torch/distributed/pipelining/stage.py index 3e9aabedd0a..c18c4d6f678 100644 --- a/torch/distributed/pipelining/stage.py +++ b/torch/distributed/pipelining/stage.py @@ -155,7 +155,7 @@ class _PipelineStageBase(ABC): self.submod = submodule self.stage_index = stage_index self.num_stages = num_stages - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self.device = device self.group = group diff --git a/torch/distributed/remote_device.py b/torch/distributed/remote_device.py index fbd14faa4de..a71e15c9c34 100644 --- a/torch/distributed/remote_device.py +++ b/torch/distributed/remote_device.py @@ -36,14 +36,14 @@ class _remote_device: elif isinstance(remote_device, str): fields = remote_device.split("/") if len(fields) == 2: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._worker_name, self._device = fields elif len(fields) == 1: # Check if this is a valid device. 
if _remote_device._is_valid_local_device(fields[0]): self._device = fields[0] else: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._worker_name = fields[0] self._device = "cpu" else: @@ -65,7 +65,7 @@ class _remote_device: # rank:/device format, extract rank if fields[0] == "rank" and fields[1].isdigit(): self._rank = int(fields[1]) # type: ignore[assignment] - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._worker_name = None else: raise ValueError(PARSE_ERROR) diff --git a/torch/distributed/rendezvous.py b/torch/distributed/rendezvous.py index 4d5e5877816..0c9ebc468be 100644 --- a/torch/distributed/rendezvous.py +++ b/torch/distributed/rendezvous.py @@ -93,7 +93,7 @@ def _rendezvous_helper(url: str, rank: int, world_size_opt: Optional[int], **kwa result = result._replace( query=f"{'&'.join([f'{k}={v}' for k, v in query_dict.items()])}" ) - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] url = urlunparse(result) if result.scheme not in _rendezvous_handlers: @@ -111,7 +111,7 @@ def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs): if not isinstance(world_size, numbers.Integral): raise RuntimeError(f"`world_size` must be an integer. {world_size}") - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return _rendezvous_helper(url, rank, world_size, **kwargs) diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index 883b6b324f9..845ce0b7faf 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -471,7 +471,7 @@ def _rref_typeof_on_user( T = TypeVar("T") -# pyrefly: ignore # invalid-annotation +# pyrefly: ignore [invalid-annotation] GenericWithOneTypeVar = Generic[T] @@ -718,7 +718,7 @@ def _invoke_rpc( is_async_exec = hasattr(func, "_wrapped_async_rpc_function") if is_async_exec: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] wrapped = func._wrapped_async_rpc_function if isinstance(wrapped, torch.jit.ScriptFunction): func = wrapped diff --git a/torch/distributed/rpc/backend_registry.py b/torch/distributed/rpc/backend_registry.py index 2eea49a0803..16299404c6b 100644 --- a/torch/distributed/rpc/backend_registry.py +++ b/torch/distributed/rpc/backend_registry.py @@ -95,7 +95,7 @@ def register_backend( BackendType.__repr__ = _backend_type_repr # type: ignore[assignment] if BackendType.__doc__: BackendType.__doc__ = _backend_type_doc - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] return BackendType[backend_name] diff --git a/torch/distributed/rpc/options.py b/torch/distributed/rpc/options.py index 506c7bfd6ad..7c1e3d4b5a0 100644 --- a/torch/distributed/rpc/options.py +++ b/torch/distributed/rpc/options.py @@ -48,7 +48,7 @@ else: _TensorPipeRpcBackendOptionsBase = object # type: ignore[assignment, misc] -# pyrefly: ignore # invalid-inheritance +# pyrefly: ignore [invalid-inheritance] class TensorPipeRpcBackendOptions(_TensorPipeRpcBackendOptionsBase): r""" The backend options for diff --git a/torch/distributed/rpc/server_process_global_profiler.py b/torch/distributed/rpc/server_process_global_profiler.py index dfe8c02aef2..29a916772d3 100644 --- a/torch/distributed/rpc/server_process_global_profiler.py +++ b/torch/distributed/rpc/server_process_global_profiler.py @@ -5,7 +5,7 @@ import itertools import torch -# pyrefly: ignore # deprecated +# pyrefly: ignore [deprecated] from torch.autograd.profiler_legacy import profile from . 
import ( @@ -176,13 +176,13 @@ class _server_process_global_profile(profile): flattened_function_events = list( itertools.chain.from_iterable(process_global_function_events) ) - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.function_events = torch.autograd.profiler_util.EventList( flattened_function_events, use_device="cuda" if self.use_cuda else None, profile_memory=self.profile_memory, ) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.function_events._build_tree() self.process_global_function_events = process_global_function_events diff --git a/torch/distributed/run.py b/torch/distributed/run.py index fbd234eb0ee..a076c8d5798 100644 --- a/torch/distributed/run.py +++ b/torch/distributed/run.py @@ -869,7 +869,7 @@ def config_from_args(args) -> tuple[LaunchConfig, Union[Callable, str], list[str ) from e logs_specs_cls: type[LogsSpecs] = _get_logs_specs_class(args.logs_specs) - # pyrefly: ignore # bad-instantiation + # pyrefly: ignore [bad-instantiation] logs_specs = logs_specs_cls( log_dir=args.log_dir, redirects=Std.from_str(args.redirects), diff --git a/torch/distributed/tensor/_dtensor_spec.py b/torch/distributed/tensor/_dtensor_spec.py index 42cb7fcd7c3..5e7d7b3c842 100644 --- a/torch/distributed/tensor/_dtensor_spec.py +++ b/torch/distributed/tensor/_dtensor_spec.py @@ -90,7 +90,7 @@ class DTensorSpec: if not isinstance(self.placements, tuple): self.placements = tuple(self.placements) if self.shard_order is None: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.shard_order = DTensorSpec.compute_default_shard_order(self.placements) self._hash: int | None = None diff --git a/torch/distributed/tensor/_ops/_common_rules.py b/torch/distributed/tensor/_ops/_common_rules.py index 65d72c09e7a..1e7ff648f7f 100644 --- a/torch/distributed/tensor/_ops/_common_rules.py +++ b/torch/distributed/tensor/_ops/_common_rules.py @@ -171,7 +171,7 @@ def einop_rule( global_shape, input_spec.mesh, input_spec.placements ) cost += prod(local_shape) * input_spec.mesh.size(mesh_dim) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] costs.append(cost) d_to_keep_sharding = dims[costs.index(max(costs))] for d in dims: diff --git a/torch/distributed/tensor/_ops/_math_ops.py b/torch/distributed/tensor/_ops/_math_ops.py index 0d2d68c9923..91f6f7d0265 100644 --- a/torch/distributed/tensor/_ops/_math_ops.py +++ b/torch/distributed/tensor/_ops/_math_ops.py @@ -138,7 +138,7 @@ class _NormPartial(Partial): f"Expected int or float, got {type(self.norm_type)}" ) if self.norm_type != 0 and self.norm_type != 1: - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] return tensor**self.norm_type return tensor @@ -149,7 +149,7 @@ class _NormPartial(Partial): f"Expected int or float, got {type(self.norm_type)}" ) if self.norm_type != 0 and self.norm_type != 1: - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] return tensor ** (1.0 / self.norm_type) return tensor diff --git a/torch/distributed/tensor/_ops/_matrix_ops.py b/torch/distributed/tensor/_ops/_matrix_ops.py index 0005acf0cd7..cd6ba48d983 100644 --- a/torch/distributed/tensor/_ops/_matrix_ops.py +++ b/torch/distributed/tensor/_ops/_matrix_ops.py @@ -1094,9 +1094,9 @@ def grouped_mm_strategy(op_schema: OpSchema) -> OpStrategy: ) return TensorMeta(torch.Size(local_shape), local_stride, meta.dtype) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] 
mat1_meta = local_meta(mat1_strategy.strategies[0], input_specs[0].placements) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] mat2_meta = local_meta(mat2_strategy.strategies[0], input_specs[1].placements) def check_valid_strides(meta: TensorMeta) -> bool: diff --git a/torch/distributed/tensor/_ops/utils.py b/torch/distributed/tensor/_ops/utils.py index 2b6f30bded9..9a4ce12ed82 100644 --- a/torch/distributed/tensor/_ops/utils.py +++ b/torch/distributed/tensor/_ops/utils.py @@ -336,7 +336,7 @@ def expand_to_full_mesh_op_strategy( for specs in zip(*strategy_comb): if specs[0] is not None: # TODO: we should fill in tensor_meta here. If nothing else, it helps the filter strategy callback - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] spec_list.append(DTensorSpec(mesh, specs)) else: spec_list.append(None) diff --git a/torch/distributed/tensor/_random.py b/torch/distributed/tensor/_random.py index 19e69cff218..d81f58520aa 100644 --- a/torch/distributed/tensor/_random.py +++ b/torch/distributed/tensor/_random.py @@ -150,7 +150,7 @@ class _RNGStateTracker: """ def __init__(self, device: torch.device): - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self._device = device self._device_handle = _get_device_handle(self._device.type) if not (self._device_handle and self._device_handle.is_available()): diff --git a/torch/distributed/tensor/_tp_conv.py b/torch/distributed/tensor/_tp_conv.py index 2fa1848d399..2b3f126c7e5 100644 --- a/torch/distributed/tensor/_tp_conv.py +++ b/torch/distributed/tensor/_tp_conv.py @@ -256,10 +256,10 @@ def convolution_backward_handler( kwargs: dict[str, object], ) -> object: # Redistribute grad_output tensor to the same placement as input tensor - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] args = list(args) assert isinstance(args[0], dtensor.DTensor) and isinstance(args[1], dtensor.DTensor) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] args[0] = args[0].redistribute(args[1].device_mesh, args[1].placements) args = tuple(args) diff --git a/torch/distributed/tensor/debug/_comm_mode.py b/torch/distributed/tensor/debug/_comm_mode.py index 31f091fe31b..e494b07d96b 100644 --- a/torch/distributed/tensor/debug/_comm_mode.py +++ b/torch/distributed/tensor/debug/_comm_mode.py @@ -594,7 +594,7 @@ class CommDebugMode(TorchDispatchMode): self.advanced_module_tracker.__enter__() return self - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def __exit__(self, *args): self.advanced_module_tracker.__exit__() super().__exit__(*args) diff --git a/torch/distributed/tensor/debug/_op_coverage.py b/torch/distributed/tensor/debug/_op_coverage.py index fa17430bd94..7315d64d697 100644 --- a/torch/distributed/tensor/debug/_op_coverage.py +++ b/torch/distributed/tensor/debug/_op_coverage.py @@ -90,7 +90,7 @@ def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=Fals op_infos.sort(key=itemgetter(count_idx), reverse=True) headers = ["Operator", "Schema", "Total Count", "Supported"] - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] print(tabulate(op_infos, headers=headers)) if output_csv: @@ -102,5 +102,5 @@ def print_op_coverage_summary(model: nn.Module, args, kwargs, *, output_csv=Fals csv_writer.writerow(headers) # Write each table row to the CSV file for row in op_infos: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] csv_writer.writerow(row) diff 
--git a/torch/distributed/tensor/examples/torchrec_sharding_example.py b/torch/distributed/tensor/examples/torchrec_sharding_example.py index 9647b4bb93e..713dba994e7 100644 --- a/torch/distributed/tensor/examples/torchrec_sharding_example.py +++ b/torch/distributed/tensor/examples/torchrec_sharding_example.py @@ -91,10 +91,10 @@ class LocalShardsWrapper(torch.Tensor): if func in supported_ops: res_shards_list = [ func(shard, *args[1:], **kwargs) - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] for shard in args[0].shards ] - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] return LocalShardsWrapper(res_shards_list, args[0].shard_offsets) else: raise NotImplementedError( @@ -144,7 +144,7 @@ def run_torchrec_row_wise_even_sharding_example(rank, world_size): local_tensor = torch.randn(local_shard_shape, device=device) # row-wise sharding: one shard per rank # create the local shards wrapper - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] local_shards_wrapper = LocalShardsWrapper( local_shards=[local_tensor], offsets=[local_shard_offset], @@ -223,7 +223,7 @@ def run_torchrec_row_wise_uneven_sharding_example(rank, world_size): # local shards # row-wise sharding: one shard per rank # create the local shards wrapper - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] local_shards_wrapper = LocalShardsWrapper( local_shards=[local_tensor], offsets=[local_shard_offset], @@ -302,7 +302,7 @@ def run_torchrec_table_wise_sharding_example(rank, world_size): local_shard_offset = torch.Size((0, 0)) # wrap local shards into a wrapper local_shards_wrapper = ( - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] LocalShardsWrapper( local_shards=[local_tensor], offsets=[local_shard_offset], diff --git a/torch/distributed/tensor/experimental/_attention.py b/torch/distributed/tensor/experimental/_attention.py index 5939a247c2f..566390b8a03 100644 --- a/torch/distributed/tensor/experimental/_attention.py +++ b/torch/distributed/tensor/experimental/_attention.py @@ -465,7 +465,7 @@ def _templated_ring_attention( ) sdpa_merger.step(out, logsumexp, partial) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return *sdpa_merger.results(), *rest @@ -632,7 +632,7 @@ def _templated_ring_attention_backward( grad_query, grad_key, grad_value, - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] *rest, ) diff --git a/torch/distributed/tensor/experimental/_func_map.py b/torch/distributed/tensor/experimental/_func_map.py index 922d5238cab..cf0e9df1ab3 100644 --- a/torch/distributed/tensor/experimental/_func_map.py +++ b/torch/distributed/tensor/experimental/_func_map.py @@ -238,7 +238,7 @@ def _local_map_wrapped( flat_local_args.append(arg) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] local_args = pytree.tree_unflatten(flat_local_args, args_spec) out = func(*local_args, **kwargs) @@ -272,7 +272,7 @@ def _local_map_wrapped( flat_dist_out.append(out) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] return pytree.tree_unflatten(flat_dist_out, out_spec) else: return out diff --git a/torch/distributed/tensor/experimental/_tp_transform.py b/torch/distributed/tensor/experimental/_tp_transform.py index f84c6a10139..f66ab2b2e39 100644 --- a/torch/distributed/tensor/experimental/_tp_transform.py +++ b/torch/distributed/tensor/experimental/_tp_transform.py @@ -237,11 +237,11 @@ def _mark_sharding( 
op_schema, ) placement_strategies[node] = OpSpec( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_specs=_get_output_spec_from_output_sharding(output_sharding), - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] input_specs=output_sharding.redistribute_schema.args_spec - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if output_sharding.redistribute_schema is not None else _get_input_node_specs(node, placement_strategies), ) diff --git a/torch/distributed/tensor/parallel/fsdp.py b/torch/distributed/tensor/parallel/fsdp.py index dd62d5fc171..f491624b5aa 100644 --- a/torch/distributed/tensor/parallel/fsdp.py +++ b/torch/distributed/tensor/parallel/fsdp.py @@ -134,11 +134,11 @@ def _rewrite_spec_if_needed( break if rewrite: spec = copy.deepcopy(spec) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for i, placement in enumerate(spec.placements): placement = cast(_remote_device, placement) if placement.rank() == rank and placement.device() != tensor.device: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] spec.placements[i] = _remote_device(f"rank:{rank}/{tensor.device}") return spec diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py index aae098056bb..2b1a88f1a12 100644 --- a/torch/distributed/utils.py +++ b/torch/distributed/utils.py @@ -134,16 +134,16 @@ def _recursive_to(inputs, target_device, use_side_stream_for_tensor_copies): from torch.nn.parallel.scatter_gather import _is_namedtuple if _is_namedtuple(obj): - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return [type(obj)(*args) for args in zip(*map(to_map, obj))] if isinstance(obj, tuple) and len(obj) > 0: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return list(zip(*map(to_map, obj))) if isinstance(obj, list) and len(obj) > 0: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return [list(i) for i in zip(*map(to_map, obj))] if isinstance(obj, dict) and len(obj) > 0: - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] return [type(obj)(i) for i in zip(*map(to_map, obj.items()))] return [obj]
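
Editor's note (illustrative only, not part of the patch): every hunk above applies the same mechanical rewrite of pyrefly suppression comments, replacing the trailing `# <error-code>` spelling with a bracketed `[<error-code>]` spelling while leaving the suppressed code untouched. A minimal before/after sketch in Python is shown below; `obj.maybe_missing()` is a hypothetical call used only to give the suppression something to sit above, and the exact scoping rules are pyrefly's, not asserted here.

    # Old spelling used in the removed lines: the error code follows a second hash.
    # pyrefly: ignore # missing-attribute
    value = obj.maybe_missing()  # hypothetical attribute access flagged by pyrefly

    # New spelling used in the added lines: the error code is bracketed.
    # pyrefly: ignore [missing-attribute]
    value = obj.maybe_missing()

Only the comment text changes, so the patch has no runtime effect; presumably the bracketed form is the syntax pyrefly recognizes for restricting a suppression to a specific error code.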