From c7eee495259a5ce2f2f5e8830bcec3b6eca84b31 Mon Sep 17 00:00:00 2001 From: Maggie Moss Date: Sun, 26 Oct 2025 00:44:07 +0000 Subject: [PATCH] Fix pyrefly ignores 1/n (#166239) First diff adjusting the syntax for pyrefly: ignore suppressions so they only hide one class of type error. Test: lintrunner pyrefly check Pull Request resolved: https://github.com/pytorch/pytorch/pull/166239 Approved by: https://github.com/oulgen --- pyrefly.toml | 1 + tools/stats/utilization_stats_lib.py | 13 ++--- torch/__init__.py | 8 +-- torch/_custom_op/impl.py | 2 +- torch/_dispatch/python.py | 2 +- torch/_dynamo/variables/builtin.py | 38 +++++++------ torch/_dynamo/variables/lazy.py | 1 + torch/_guards.py | 6 +++ torch/_inductor/await_utils.py | 3 +- .../rocm/ck_universal_gemm_template.py | 5 +- torch/_inductor/compile_fx.py | 53 ++++++++++--------- torch/_inductor/fx_passes/split_cat.py | 31 +++++------ torch/_jit_internal.py | 4 +- torch/_lazy/closure.py | 4 +- torch/_meta_registrations.py | 24 +++++---- torch/_ops.py | 8 +-- torch/_subclasses/functional_tensor.py | 1 + torch/_tensor.py | 7 +-- torch/_utils.py | 17 +++--- torch/_weights_only_unpickler.py | 1 + torch/accelerator/memory.py | 1 + torch/amp/autocast_mode.py | 6 +-- torch/ao/ns/fx/graph_matcher.py | 13 ++--- torch/autograd/profiler_util.py | 17 +++--- torch/compiler/__init__.py | 1 + torch/distributed/_local_tensor/__init__.py | 8 ++- .../_shard/sharded_tensor/utils.py | 3 +- torch/distributed/checkpoint/_pg_transport.py | 3 +- torch/distributed/fsdp/_flat_param.py | 3 +- .../fsdp/_fully_shard/_fsdp_common.py | 3 ++ torch/distributed/tensor/_api.py | 15 ++++++ torch/distributed/tensor/_dispatch.py | 14 +++-- torch/distributed/tensor/_redistribute.py | 6 +++ torch/distributed/tensor/parallel/loss.py | 21 +++++--- torch/export/unflatten.py | 6 ++- torch/functional.py | 2 +- torch/fx/experimental/unification/more.py | 1 + torch/fx/node.py | 12 +++-- torch/mtia/__init__.py | 2 +- torch/nested/_internal/ops.py | 2 +- torch/nn/modules/batchnorm.py | 21 ++++---- torch/nn/modules/module.py | 20 +++---- .../_internal/exporter/_capture_strategies.py | 3 +- torch/onnx/_internal/exporter/_core.py | 15 +++--- .../torchscript_exporter/verification.py | 2 +- torch/optim/_multi_tensor/__init__.py | 1 + torch/signal/windows/windows.py | 2 +- torch/storage.py | 2 +- torch/utils/_ordered_set.py | 3 +- torch/utils/backend_registration.py | 8 ++- torch/utils/hipify/cuda_to_hip_mappings.py | 9 +++- torch/utils/hipify/hipify_python.py | 5 +- torch/utils/weak.py | 3 +- torch/xpu/__init__.py | 2 +- torch/xpu/streams.py | 2 +- 55 files changed, 282 insertions(+), 184 deletions(-) diff --git a/pyrefly.toml b/pyrefly.toml index 3dd62368184..cca6f5eb78c 100644 --- a/pyrefly.toml +++ b/pyrefly.toml @@ -130,6 +130,7 @@ errors.bad-param-name-override = false # Mypy doesn't require that imports are explicitly imported, so be compatible with that. # Might be a good idea to turn this on in future. 
errors.implicit-import = false +errors.deprecated = false # re-enable after we've fixed import formatting permissive-ignores = true replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"] search-path = ["tools/experimental"] diff --git a/tools/stats/utilization_stats_lib.py b/tools/stats/utilization_stats_lib.py index 33551fd55de..306cd7fe9e1 100644 --- a/tools/stats/utilization_stats_lib.py +++ b/tools/stats/utilization_stats_lib.py @@ -2,7 +2,8 @@ from dataclasses import dataclass, field from datetime import datetime from typing import Optional -from dataclasses_json import DataClassJsonMixin +# pyrefly: ignore [missing-import] +from dataclasses_json import DataClassJsonMixin # type: ignore[import-not-found] _DATA_MODEL_VERSION = 1.5 @@ -17,7 +18,7 @@ class UtilizationStats: @dataclass -class UtilizationMetadata(DataClassJsonMixin): +class UtilizationMetadata(DataClassJsonMixin): # type: ignore[misc, no-any-unimported] level: str workflow_id: str job_id: str @@ -33,7 +34,7 @@ class UtilizationMetadata(DataClassJsonMixin): @dataclass -class GpuUsage(DataClassJsonMixin): +class GpuUsage(DataClassJsonMixin): # type: ignore[misc, no-any-unimported] uuid: Optional[str] = None util_percent: Optional[UtilizationStats] = None mem_util_percent: Optional[UtilizationStats] = None @@ -43,14 +44,14 @@ class GpuUsage(DataClassJsonMixin): @dataclass -class RecordData(DataClassJsonMixin): +class RecordData(DataClassJsonMixin): # type: ignore[misc, no-any-unimported] cpu: Optional[UtilizationStats] = None memory: Optional[UtilizationStats] = None gpu_usage: Optional[list[GpuUsage]] = None @dataclass -class UtilizationRecord(DataClassJsonMixin): +class UtilizationRecord(DataClassJsonMixin): # type: ignore[misc, no-any-unimported] level: str timestamp: int data: Optional[RecordData] = None @@ -63,7 +64,7 @@ class UtilizationRecord(DataClassJsonMixin): # the db schema related to this is: # https://github.com/pytorch/test-infra/blob/main/clickhouse_db_schema/oss_ci_utilization/oss_ci_utilization_metadata_schema.sql @dataclass -class OssCiSegmentV1(DataClassJsonMixin): +class OssCiSegmentV1(DataClassJsonMixin): # type: ignore[misc, no-any-unimported] level: str name: str start_at: int diff --git a/torch/__init__.py b/torch/__init__.py index 18c56ebe46f..95f55ae5878 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -1703,7 +1703,7 @@ def _check(cond, message=None): # noqa: F811 an object that has a ``__str__()`` method to be used as the error message. Default: ``None`` """ - _check_with(RuntimeError, cond, message) # pyrefly: ignore # bad-argument-type + _check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type] # TODO add deprecation annotation @@ -1753,7 +1753,7 @@ def _check_index(cond, message=None): # noqa: F811 an object that has a ``__str__()`` method to be used as the error message. Default: ``None`` """ - _check_with(IndexError, cond, message) # pyrefly: ignore # bad-argument-type + _check_with(IndexError, cond, message) # pyrefly: ignore [bad-argument-type] def _check_value(cond, message=None): # noqa: F811 @@ -1771,7 +1771,7 @@ def _check_value(cond, message=None): # noqa: F811 an object that has a ``__str__()`` method to be used as the error message. 
Default: ``None`` """ - _check_with(ValueError, cond, message) # pyrefly: ignore # bad-argument-type + _check_with(ValueError, cond, message) # pyrefly: ignore [bad-argument-type] def _check_type(cond, message=None): # noqa: F811 @@ -1789,7 +1789,7 @@ def _check_type(cond, message=None): # noqa: F811 an object that has a ``__str__()`` method to be used as the error message. Default: ``None`` """ - _check_with(TypeError, cond, message) # pyrefly: ignore # bad-argument-type + _check_with(TypeError, cond, message) # pyrefly: ignore [bad-argument-type] def _check_not_implemented(cond, message=None): # noqa: F811 diff --git a/torch/_custom_op/impl.py b/torch/_custom_op/impl.py index b445907b5d1..bcc0193fb88 100644 --- a/torch/_custom_op/impl.py +++ b/torch/_custom_op/impl.py @@ -101,7 +101,7 @@ def custom_op( lib, ns, function_schema, name, ophandle, _private_access=True ) - result.__name__ = func.__name__ # pyrefly: ignore # bad-assignment + result.__name__ = func.__name__ # pyrefly: ignore [bad-assignment] result.__module__ = func.__module__ result.__doc__ = func.__doc__ diff --git a/torch/_dispatch/python.py b/torch/_dispatch/python.py index 4cf1d1b5cff..e6b3f09c22f 100644 --- a/torch/_dispatch/python.py +++ b/torch/_dispatch/python.py @@ -154,7 +154,7 @@ def make_crossref_functionalize( maybe_detach, (f_args, f_kwargs) ) with fake_mode: - f_r = op(*f_args, **f_kwargs) # pyrefly: ignore # invalid-param-spec + f_r = op(*f_args, **f_kwargs) # pyrefly: ignore [invalid-param-spec] r = op._op_dk(final_key, *args, **kwargs) def desc(): diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 5be6952195c..aefa2949470 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1029,7 +1029,7 @@ class BuiltinVariable(VariableTracker): def call_self_handler(tx: "InstructionTranslator", args, kwargs): try: - # pyrefly: ignore # not-callable + # pyrefly: ignore [not-callable] result = self_handler(tx, *args, **kwargs) if result is not None: return result @@ -1037,7 +1037,7 @@ class BuiltinVariable(VariableTracker): # Check if binding is bad. inspect signature bind is expensive. # So check only when handler call fails. 
try: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] inspect.signature(self_handler).bind(tx, *args, **kwargs) except TypeError as e: has_constant_handler = obj.has_constant_handler(args, kwargs) @@ -1090,7 +1090,7 @@ class BuiltinVariable(VariableTracker): hints=[*graph_break_hints.DYNAMO_BUG], from_exc=exc, ) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return VariableTracker.build(tx, res) else: @@ -1119,7 +1119,7 @@ class BuiltinVariable(VariableTracker): tx, args=list(map(ConstantVariable.create, exc.args)), ) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return VariableTracker.build(tx, res) handlers.append(constant_fold_handler) @@ -1442,7 +1442,7 @@ class BuiltinVariable(VariableTracker): resolved_fn = getattr(self.fn, name) if resolved_fn in dict_methods: if isinstance(args[0], variables.UserDefinedDictVariable): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs) elif isinstance(args[0], variables.ConstDictVariable): return args[0].call_method(tx, name, args[1:], kwargs) @@ -1451,7 +1451,7 @@ class BuiltinVariable(VariableTracker): resolved_fn = getattr(self.fn, name) if resolved_fn in set_methods: if isinstance(args[0], variables.UserDefinedSetVariable): - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return args[0]._set_vt.call_method(tx, name, args[1:], kwargs) elif isinstance(args[0], variables.SetVariable): return args[0].call_method(tx, name, args[1:], kwargs) @@ -1540,12 +1540,12 @@ class BuiltinVariable(VariableTracker): if type(arg.value).__str__ is object.__str__: # Rely on the object str method try: - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return variables.ConstantVariable.create(value=str_method()) except AttributeError: # Graph break return - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] elif is_wrapper_or_member_descriptor(str_method): unimplemented_v2( gb_type="Attempted to a str() method implemented in C/C++", @@ -1662,10 +1662,10 @@ class BuiltinVariable(VariableTracker): else: raw_b = b.raw_value if self.fn is max: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] raw_res = max(a.raw_value, raw_b) else: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] raw_res = min(a.raw_value, raw_b) need_unwrap = any( @@ -1980,12 +1980,16 @@ class BuiltinVariable(VariableTracker): if isinstance(arg, dict): arg = [ConstantVariable.create(k) for k in arg.keys()] return DictVariableType( - dict.fromkeys(arg, value), user_cls, mutation_type=ValueMutationNew() + # pyrefly: ignore [bad-argument-type] + dict.fromkeys(arg, value), + user_cls, + mutation_type=ValueMutationNew(), ) elif arg.has_force_unpack_var_sequence(tx): keys = arg.force_unpack_var_sequence(tx) if all(is_hashable(v) for v in keys): return DictVariableType( + # pyrefly: ignore [bad-argument-type] dict.fromkeys(keys, value), user_cls, mutation_type=ValueMutationNew(), @@ -2152,7 +2156,7 @@ class BuiltinVariable(VariableTracker): ) if isinstance(arg, variables.UserDefinedExceptionClassVariable): - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return ConstantVariable.create(isinstance(arg_type, isinstance_type)) isinstance_type_tuple: tuple[type, ...] @@ -2185,10 +2189,10 @@ class BuiltinVariable(VariableTracker): # through it. This is a limitation of the current implementation. 
# Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it # might not be a big issue and we trade off it for performance. - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] val = issubclass(arg_type, isinstance_type_tuple) except TypeError: - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] val = arg_type in isinstance_type_tuple return variables.ConstantVariable.create(val) @@ -2210,7 +2214,7 @@ class BuiltinVariable(VariableTracker): # WARNING: This might run arbitrary user code `__subclasscheck__`. # See the comment in call_isinstance above. - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py)) def call_super(self, tx: "InstructionTranslator", a, b): @@ -2256,9 +2260,9 @@ class BuiltinVariable(VariableTracker): value = getattr(self.fn, name) except AttributeError: raise_observed_exception(AttributeError, tx) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] if not callable(value): - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return VariableTracker.build(tx, value, source) return variables.GetAttrVariable(self, name, source=source) diff --git a/torch/_dynamo/variables/lazy.py b/torch/_dynamo/variables/lazy.py index 326a507d1ef..80c630f86ec 100644 --- a/torch/_dynamo/variables/lazy.py +++ b/torch/_dynamo/variables/lazy.py @@ -34,6 +34,7 @@ class LazyCache: self.vt = builder.VariableBuilder(tx, self.source)(self.value) if self.name_hint is not None: + # pyrefly: ignore [missing-attribute] self.vt.set_name_hint(self.name_hint) del self.value diff --git a/torch/_guards.py b/torch/_guards.py index e3d20c9fc51..fc8f88f237c 100644 --- a/torch/_guards.py +++ b/torch/_guards.py @@ -1138,11 +1138,13 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())): if isinstance(m, FakeTensorMode): + # pyrefly: ignore [bad-argument-type] fake_modes.append((m, "active fake mode", i)) flat_inputs = pytree.tree_leaves(inputs) for i, flat_input in enumerate(flat_inputs): if isinstance(flat_input, FakeTensor): + # pyrefly: ignore [bad-argument-type] fake_modes.append((flat_input.fake_mode, "fake tensor input", i)) if is_traceable_wrapper_subclass(flat_input): out: list[Union[torch.Tensor, int, torch.SymInt]] = [] @@ -1151,6 +1153,7 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: x for x in out if isinstance(x, FakeTensor) ] fake_modes.extend( + # pyrefly: ignore [bad-argument-type] [ (tensor.fake_mode, f"subclass input {i}", ix) for ix, tensor in enumerate(fake_tensors) @@ -1162,9 +1165,12 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]: for m, desc2, i2 in fake_modes[1:]: assert fake_mode is m, ( f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n" + # pyrefly: ignore [missing-attribute] f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n" + # pyrefly: ignore [missing-attribute] f"fake mode from {desc2} {i2} allocated at:\n{m.stack}" ) + # pyrefly: ignore [bad-return] return fake_mode else: return None diff --git a/torch/_inductor/await_utils.py b/torch/_inductor/await_utils.py index 036c7e3457d..9d84776b06c 100644 --- a/torch/_inductor/await_utils.py +++ b/torch/_inductor/await_utils.py @@ -114,6 +114,7 @@ def _cancel_all_tasks( for task in to_cancel: task.cancel() + # pyrefly: ignore [bad-argument-type] 
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True)) for task in to_cancel: @@ -149,7 +150,7 @@ def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[ task_factory = task_factories[0] if task_factory is None: if sys.version_info >= (3, 11): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] task = asyncio.Task(coro, loop=loop, context=context) else: task = asyncio.Task(coro, loop=loop) diff --git a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py index 8357e9fba77..e9f8ff54f9f 100644 --- a/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py +++ b/torch/_inductor/codegen/rocm/ck_universal_gemm_template.py @@ -590,11 +590,11 @@ class CKGemmTemplate(CKTemplate): arg = f"/* {field_name} */ Tuple<{tuple_elements}>" else: # tile shape arg = f"/* {field_name} */ S<{tuple_elements}>" - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] template_params.append(arg) else: if field_value is not None: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] template_params.append(f"/* {field_name} */ {field_value}") operation_name = op.name().replace("(", "").replace(",", "").replace(")", "") return self._template_from_string(template_definition).render( @@ -939,6 +939,7 @@ class CKGemmTemplate(CKTemplate): for o in rops: kBatches = self._get_kBatch(o) for kBatch in kBatches: + # pyrefly: ignore [bad-argument-type] ops.append(InductorROCmOp(op=o, kBatch=kBatch)) filtered_instances = list(filter(lambda op: self.filter_op(op), ops)) diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py index 85c7f2884eb..6c4b8270dd7 100644 --- a/torch/_inductor/compile_fx.py +++ b/torch/_inductor/compile_fx.py @@ -273,7 +273,7 @@ def record_original_output_strides(gm: GraphModule) -> None: ): output_strides.append(val.stride()) else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] output_strides.append(None) output_node.meta["original_output_strides"] = output_strides @@ -1110,6 +1110,7 @@ def _compile_fx_inner( ) log.info("-" * 130) for row in mm_table_data: + # pyrefly: ignore [not-iterable] log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001 log.info("-" * 130) @@ -1551,7 +1552,7 @@ class _InProcessFxCompile(FxCompile): node_runtimes = None if inductor_metrics_log.isEnabledFor(logging.INFO): num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes() - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] metrics.num_bytes_accessed += num_bytes metrics.node_runtimes += node_runtimes metrics.nodes_num_elem += nodes_num_elem @@ -1595,10 +1596,10 @@ class _InProcessFxCompile(FxCompile): disable = f"{disable} Found from {stack_trace}\n" else: disable = f"{disable}\n" - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] V.graph.disable_cudagraphs_reason = disable - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] if cudagraphs and not V.graph.disable_cudagraphs_reason: maybe_incompat_node = get_first_incompatible_cudagraph_node(gm) if maybe_incompat_node: @@ -1607,29 +1608,29 @@ class _InProcessFxCompile(FxCompile): "stack_trace", None ): disable = f"{disable} Found from {stack_trace}\n" - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] V.graph.disable_cudagraphs_reason = disable - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] if 
V.aot_compilation: assert isinstance( compiled_fn, - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] (str, list, torch.fx.GraphModule), ), type(compiled_fn) return CompiledAOTI(compiled_fn) # TODO: Hoist this above V.aot_compilation - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] if cudagraphs and not V.graph.disable_cudagraphs_reason: from torch._inductor.cudagraph_utils import ( check_lowering_disable_cudagraph, ) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] V.graph.disable_cudagraphs_reason = ( check_lowering_disable_cudagraph( - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] V.graph.device_node_mapping ) ) @@ -1637,29 +1638,29 @@ class _InProcessFxCompile(FxCompile): self._compile_stats[type(self)].codegen_and_compile += 1 if ( - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] torch._inductor.debug.RECORD_GRAPH_EXECUTION - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] and torch._inductor.debug.GRAPH_COMPILE_IDS is not None ): compile_id = str( - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] torch._guards.CompileContext.current_compile_id() ) graph_id = graph_kwargs.get("graph_id") if graph_id is not None: - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] torch._inductor.debug.GRAPH_COMPILE_IDS[graph_id] = ( compile_id ) return CompiledFxGraph( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] compiled_fn, graph, gm, output_strides, - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] V.graph.disable_cudagraphs_reason, metrics_helper.get_deltas(), counters["inductor"] - inductor_counters, @@ -1701,18 +1702,18 @@ def fx_codegen_and_compile( from .compile_fx_async import _AsyncFxCompile from .compile_fx_ext import _OutOfProcessFxCompile - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] assert isinstance(scheme, _OutOfProcessFxCompile), ( "async is only valid with an out-of-process compile mode" ) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] scheme = _AsyncFxCompile(scheme) if fx_compile_progressive: from .compile_fx_async import _ProgressiveFxCompile from .compile_fx_ext import _OutOfProcessFxCompile - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] assert isinstance(scheme, _OutOfProcessFxCompile), ( "progressive is only valid with an out-of-process compile mode" ) @@ -1722,10 +1723,10 @@ def fx_codegen_and_compile( # Use in-process compile for the fast version fast_scheme = _InProcessFxCompile() - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs) - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs) @@ -1835,7 +1836,7 @@ def cudagraphify_impl( Assumes inputs[static_input_idxs[i]] are always the same memory address """ check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type] - # pyrefly: ignore # annotation-mismatch + # pyrefly: ignore [annotation-mismatch] static_input_idxs: OrderedSet[int] = OrderedSet( remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type] ) @@ -1902,7 +1903,7 @@ def cudagraphify_impl( index_expanded_dims_and_copy_(dst, src, expanded_dims) new_inputs.clear() graph.replay() - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] 
return static_outputs else: @@ -1918,7 +1919,7 @@ def cudagraphify_impl( index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims) new_inputs.clear() graph.replay() - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return static_outputs return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet()) @@ -1935,7 +1936,7 @@ def compile_fx_aot( # [See NOTE] Unwrapping subclasses AOT unwrap_tensor_subclass_parameters(model_) - # pyrefly: ignore # annotation-mismatch + # pyrefly: ignore [annotation-mismatch] config_patches: dict[str, Any] = copy.deepcopy(config_patches or {}) if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper): @@ -2878,7 +2879,7 @@ def _aoti_flatten_inputs( Flatten the inputs to the graph module and return the flat inputs and options. Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options. """ - # pyrefly: ignore # missing-module-attribute + # pyrefly: ignore [missing-module-attribute] from .compile_fx import graph_returns_tuple assert graph_returns_tuple(gm), ( diff --git a/torch/_inductor/fx_passes/split_cat.py b/torch/_inductor/fx_passes/split_cat.py index b6be29506fe..15ea6867dba 100644 --- a/torch/_inductor/fx_passes/split_cat.py +++ b/torch/_inductor/fx_passes/split_cat.py @@ -291,7 +291,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs): log.debug("example value absent for node: %s", input) return ndim = input.meta["example_value"].ndim - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] if dim < 0: # Normalize unbind dim dim += ndim with graph.inserting_after(node): @@ -341,7 +341,7 @@ def normalize_cat_default(match: Match, *args, **kwargs): ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors ) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] if cat_dim < 0: # Normalize cat dim cat_dim += ndim @@ -949,7 +949,7 @@ class SplitCatSimplifier: if isinstance(user_input, tuple): # Find the correct new getitem (present in split_items) new_user_inputs.append( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] split_items[ split_ranges.index( ( @@ -1000,7 +1000,7 @@ class SplitCatSimplifier: for user_input_new, transform_param in zip( user_inputs_new, transform_params ): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] if not is_node_meta_valid(user_input_new): log.debug("example value absent for node: %s", user_input_new) return @@ -1015,7 +1015,7 @@ class SplitCatSimplifier: stack_dim is None or stack_dim == unsqueeze_params[0] ): to_stack.append(user_input_new) - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] to_stack_meta.append(user_input_new.meta["example_value"]) stack_dim = unsqueeze_params[0] continue @@ -1036,12 +1036,12 @@ class SplitCatSimplifier: if unsqueeze_params: to_stack.append(user_input_new) stack_dim = unsqueeze_params[0] - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] to_stack_meta.append(user_input_new.meta["example_value"]) continue if unflatten_params: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] user_input_new_meta = user_input_new.meta["example_value"] user_input_new = graph.call_function( torch.unflatten, args=(user_input_new, *unflatten_params) @@ -1051,7 +1051,7 @@ class SplitCatSimplifier: *unflatten_params, # type: ignore[arg-type] ) if movedim_params: - # pyrefly: ignore # 
missing-attribute + # pyrefly: ignore [missing-attribute] user_input_new_meta = user_input_new.meta["example_value"] user_input_new = graph.call_function( torch.movedim, args=(user_input_new, *movedim_params) @@ -1061,7 +1061,7 @@ class SplitCatSimplifier: *movedim_params, # type: ignore[arg-type] ) if flatten_params: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] user_input_new_meta = user_input_new.meta["example_value"] user_input_new = graph.call_function( torch.flatten, args=(user_input_new, *flatten_params) @@ -1072,7 +1072,7 @@ class SplitCatSimplifier: ) user_inputs_new_transformed.append(user_input_new) user_inputs_new_transformed_meta.append( - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] user_input_new.meta["example_value"] ) if to_stack: @@ -1432,7 +1432,7 @@ def simplify_split_cat(match: Match, split_sections: list[int], dim: int): if not isinstance(split_sections, (list, tuple)): # Unnormalized split return split_node = next(node for node in match.nodes if node.target == torch.split) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] SplitCatSimplifier().simplify(match.graph, split_node, split_sections) @@ -1501,7 +1501,7 @@ def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) - for i in range(len(split_node.args[1])): # type: ignore[arg-type] if i in indices: fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] - # pyrefly: ignore # bad-return + # pyrefly: ignore [bad-return] return fused_tensor_size @@ -1978,7 +1978,7 @@ def normalize_cat_default_aten(match: Match, *args, **kwargs): assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] if cat_dim < 0: # Normalize cat dim cat_dim += ndim @@ -2512,7 +2512,8 @@ def reshape_cat_node_to_stack( args=(cat_node, tuple(reshape_list)), ) reshape_node.meta["example_value"] = torch.reshape( - cat_node.meta["example_value"], tuple(reshape_list) + cat_node.meta["example_value"], + tuple(reshape_list), # pyrefly: ignore [bad-argument-type] ) permute_list = list(range(len(stack_shape))) permute_list[stack_dim], permute_list[split_or_unbind_dim] = ( @@ -3044,6 +3045,6 @@ def replace_einsum_to_pointwise(match: Match, *args, **kwargs): einsum_node = match.nodes[0] input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2) if should_replace_einsum(einsum_node): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] match.replace_by_example(repl, [input, weights]) counters[backend]["einsum_to_pointwise_pass"] += 1 diff --git a/torch/_jit_internal.py b/torch/_jit_internal.py index 56622079c3b..192f969e5c6 100644 --- a/torch/_jit_internal.py +++ b/torch/_jit_internal.py @@ -147,7 +147,7 @@ def _qualified_name(obj, mangle_name=True) -> str: # If the module is actually a torchbind module, then we should short circuit if module_name == "torch._classes": - return obj.qualified_name # pyrefly: ignore # missing-attribute + return obj.qualified_name # pyrefly: ignore [missing-attribute] # The Python docs are very clear that `__module__` can be None, but I can't # figure out when it actually would be. 
@@ -759,7 +759,7 @@ def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]: prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED ) - return prop # pyrefly: ignore # bad-return + return prop # pyrefly: ignore [bad-return] fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined] return fn diff --git a/torch/_lazy/closure.py b/torch/_lazy/closure.py index 864591f84b5..bbe1a43bb12 100644 --- a/torch/_lazy/closure.py +++ b/torch/_lazy/closure.py @@ -65,8 +65,8 @@ class AsyncClosureHandler(ClosureHandler): self._closure_event_loop = threading.Thread( target=event_loop - ) # pyrefly: ignore # bad-assignment - self._closure_event_loop.start() # pyrefly: ignore # missing-attribute + ) # pyrefly: ignore [bad-assignment] + self._closure_event_loop.start() # pyrefly: ignore [missing-attribute] def run(self, closure): with self._closure_lock: diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py index 1ad443ff387..cb3a0c5ae0c 100644 --- a/torch/_meta_registrations.py +++ b/torch/_meta_registrations.py @@ -515,8 +515,11 @@ def meta_copy_(self, src, non_blocking=False): def inferUnsqueezeGeometry(tensor, dim): result_sizes = list(tensor.size()) result_strides = list(tensor.stride()) + # pyrefly: ignore [unsupported-operation] new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim] + # pyrefly: ignore [bad-argument-type] result_sizes.insert(dim, 1) + # pyrefly: ignore [bad-argument-type] result_strides.insert(dim, new_stride) return result_sizes, result_strides @@ -2341,19 +2344,19 @@ def calc_conv_nd_return_shape( ret_shape = [input_tensor.shape[0], out_channels] if isinstance(stride, IntLike): - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] stride = [stride] * len(dims) elif len(stride) == 1: stride = [stride[0]] * len(dims) if isinstance(padding, IntLike): - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] padding = [padding] * len(dims) elif len(padding) == 1: padding = [padding[0]] * len(dims) if isinstance(dilation, IntLike): - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] dilation = [dilation] * len(dims) elif len(dilation) == 1: dilation = [dilation[0]] * len(dims) @@ -2361,7 +2364,7 @@ def calc_conv_nd_return_shape( output_padding_list: Optional[list[int]] = None if output_padding: if isinstance(output_padding, IntLike): - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] output_padding_list = [output_padding] * len(dims) elif len(output_padding) == 1: output_padding_list = [output_padding[0]] * len(dims) @@ -2374,19 +2377,19 @@ def calc_conv_nd_return_shape( ret_shape.append( _formula_transposed( dims[i], - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] padding[i], - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] dilation[i], kernel_size[i], - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] stride[i], output_padding_list[i], ) ) else: ret_shape.append( - # pyrefly: ignore # index-error + # pyrefly: ignore [index-error] _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i]) ) from torch.fx.experimental.symbolic_shapes import sym_or @@ -3454,7 +3457,7 @@ def meta_index_Tensor(self, indices): """ shape = before_shape + replacement_shape + after_shape strides = list(self.stride()) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len( replacement_shape 
) @@ -5311,6 +5314,7 @@ def full(size, fill_value, *args, **kwargs): if not dtype: dtype = utils.get_dtype(fill_value) kwargs["dtype"] = dtype + # pyrefly: ignore [not-iterable] return torch.empty(size, *args, **kwargs) @@ -6668,7 +6672,7 @@ def rnn_cell_checkSizes( ) torch._check( all( - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] x.device == input_gates.device for x in [hidden_gates, input_bias, hidden_bias, prev_hidden] ), diff --git a/torch/_ops.py b/torch/_ops.py index dae34fff729..a0e8060eea0 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -880,7 +880,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]): elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk): return self._op_dk(dk, *args, **kwargs) else: - return NotImplemented # pyrefly: ignore # bad-return + return NotImplemented # pyrefly: ignore [bad-return] # Remove a dispatch key from the dispatch cache. This will force it to get # recomputed the next time. Does nothing @@ -985,9 +985,9 @@ class OpOverload(OperatorBase, Generic[_P, _T]): r = self.py_kernels.get(final_key, final_key) if cache_result: - self._dispatch_cache[key] = r # pyrefly: ignore # unsupported-operation + self._dispatch_cache[key] = r # pyrefly: ignore [unsupported-operation] add_cached_op(self) - return r # pyrefly: ignore # bad-return + return r # pyrefly: ignore [bad-return] def name(self): return self._name @@ -1117,7 +1117,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]): ) assert isinstance(handler, Callable) # type: ignore[arg-type] - return handler(*args, **kwargs) # pyrefly: ignore # bad-return + return handler(*args, **kwargs) # pyrefly: ignore [bad-return] def _must_dispatch_in_python(args, kwargs): diff --git a/torch/_subclasses/functional_tensor.py b/torch/_subclasses/functional_tensor.py index 15ed56ddca3..83d0afb837b 100644 --- a/torch/_subclasses/functional_tensor.py +++ b/torch/_subclasses/functional_tensor.py @@ -267,6 +267,7 @@ class FunctionalTensor(torch.Tensor): device=self.device, layout=self.layout, ) + # pyrefly: ignore [not-iterable] return super().to(*args, **kwargs) def cuda(self, device=None, *args, **kwargs): diff --git a/torch/_tensor.py b/torch/_tensor.py index 23195f720c5..165fd6ba7e1 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -551,6 +551,7 @@ class Tensor(torch._C.TensorBase): raise RuntimeError("__setstate__ can be only called on leaf Tensors") if len(state) == 4: # legacy serialization of Tensor + # pyrefly: ignore [not-iterable] self.set_(*state) return elif len(state) == 5: @@ -758,7 +759,7 @@ class Tensor(torch._C.TensorBase): ) if self._post_accumulate_grad_hooks is None: self._post_accumulate_grad_hooks: dict[Any, Any] = ( - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] OrderedDict() ) @@ -1062,7 +1063,7 @@ class Tensor(torch._C.TensorBase): else: return torch._VF.split_with_sizes( self, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] split_size, dim, ) @@ -1119,7 +1120,7 @@ class Tensor(torch._C.TensorBase): __rtruediv__ = __rdiv__ __itruediv__ = _C.TensorBase.__idiv__ - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] __pow__ = cast( Callable[ ["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]], diff --git a/torch/_utils.py b/torch/_utils.py index 095f256aac3..991e543e7a5 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -686,8 +686,8 @@ def _take_tensors(tensors, size_limit): if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0: yield 
buf_and_size[0] buf_and_size = buf_dict[t] = [[], 0] - buf_and_size[0].append(tensor) # pyrefly: ignore # missing-attribute - buf_and_size[1] += size # pyrefly: ignore # unsupported-operation + buf_and_size[0].append(tensor) # pyrefly: ignore [missing-attribute] + buf_and_size[1] += size # pyrefly: ignore [unsupported-operation] for buf, _ in buf_dict.values(): if len(buf) > 0: yield buf @@ -744,6 +744,7 @@ class ExceptionWrapper: if exc_info is None: exc_info = sys.exc_info() self.exc_type = exc_info[0] + # pyrefly: ignore [not-iterable] self.exc_msg = "".join(traceback.format_exception(*exc_info)) self.where = where @@ -751,7 +752,7 @@ class ExceptionWrapper: r"""Reraises the wrapped exception in the current thread""" # Format a message such as: "Caught ValueError in DataLoader worker # process 2. Original Traceback:", followed by the traceback. - msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute + msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore [missing-attribute] if self.exc_type is KeyError: # KeyError calls repr() on its argument (usually a dict key). This # makes stack traces unreadable. It will not be changed in Python @@ -760,13 +761,13 @@ class ExceptionWrapper: elif getattr(self.exc_type, "message", None): # Some exceptions have first argument as non-str but explicitly # have message field - # pyrefly: ignore # not-callable + # pyrefly: ignore [not-callable] raise self.exc_type( - # pyrefly: ignore # unexpected-keyword + # pyrefly: ignore [unexpected-keyword] message=msg ) try: - exception = self.exc_type(msg) # pyrefly: ignore # not-callable + exception = self.exc_type(msg) # pyrefly: ignore [not-callable] except Exception: # If the exception takes multiple arguments or otherwise can't # be constructed, don't try to instantiate since we don't know how to @@ -1018,12 +1019,12 @@ class _LazySeedTracker: self.call_order = [] def queue_seed_all(self, cb, traceback): - self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore # bad-assignment + self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore [bad-assignment] # update seed_all to be latest self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb] def queue_seed(self, cb, traceback): - self.manual_seed_cb = (cb, traceback) # pyrefly: ignore # bad-assignment + self.manual_seed_cb = (cb, traceback) # pyrefly: ignore [bad-assignment] # update seed to be latest self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb] diff --git a/torch/_weights_only_unpickler.py b/torch/_weights_only_unpickler.py index 5cc8e523406..d33c10ed384 100644 --- a/torch/_weights_only_unpickler.py +++ b/torch/_weights_only_unpickler.py @@ -419,6 +419,7 @@ class Unpickler: inst = self.stack[-1] if type(inst) is torch.Tensor: # Legacy unpickling + # pyrefly: ignore [not-iterable] inst.set_(*state) elif type(inst) is torch.nn.Parameter: inst.__setstate__(state) diff --git a/torch/accelerator/memory.py b/torch/accelerator/memory.py index d34a11a3a02..d98be363211 100644 --- a/torch/accelerator/memory.py +++ b/torch/accelerator/memory.py @@ -104,6 +104,7 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]: flatten("", stats) flat_stats.sort() + # pyrefly: ignore [no-matching-overload] return OrderedDict(flat_stats) diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index 9196cb2de69..c23058dc336 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -525,7 
+525,7 @@ def custom_fwd( args[0]._dtype = torch.get_autocast_dtype(device_type) if cast_inputs is None: args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type) - return fwd(*args, **kwargs) # pyrefly: ignore # not-callable + return fwd(*args, **kwargs) # pyrefly: ignore [not-callable] else: autocast_context = torch.is_autocast_enabled(device_type) args[0]._fwd_used_autocast = False @@ -536,7 +536,7 @@ def custom_fwd( **_cast(kwargs, device_type, cast_inputs), ) else: - return fwd(*args, **kwargs) # pyrefly: ignore # not-callable + return fwd(*args, **kwargs) # pyrefly: ignore [not-callable] return decorate_fwd @@ -571,6 +571,6 @@ def custom_bwd(bwd=None, *, device_type: str): enabled=args[0]._fwd_used_autocast, dtype=args[0]._dtype, ): - return bwd(*args, **kwargs) # pyrefly: ignore # not-callable + return bwd(*args, **kwargs) # pyrefly: ignore [not-callable] return decorate_bwd diff --git a/torch/ao/ns/fx/graph_matcher.py b/torch/ao/ns/fx/graph_matcher.py index 91d8042d5f6..fd7f5cbe552 100644 --- a/torch/ao/ns/fx/graph_matcher.py +++ b/torch/ao/ns/fx/graph_matcher.py @@ -84,7 +84,7 @@ class _NSGraphMatchableSubgraphsIterator: if is_match: # navigate to the base node for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.seen_nodes.add(cur_start_node) # for now, assume that there are no other nodes # which need to be added to the stack @@ -95,10 +95,10 @@ class _NSGraphMatchableSubgraphsIterator: cur_base_op_node = cur_start_node break - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.seen_nodes.add(cur_start_node) # add args of previous nodes to stack - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] for arg in cur_start_node.all_input_nodes: self._recursively_add_node_arg_to_stack(arg) @@ -106,7 +106,7 @@ class _NSGraphMatchableSubgraphsIterator: # note: this check is done on the start_node, i.e. # if we are matching linear-relu in reverse, this would do the matchable # check on the linear - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] if not self._is_matchable(cur_base_op_node): continue @@ -120,10 +120,10 @@ class _NSGraphMatchableSubgraphsIterator: continue return NSSubgraph( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] start_node=cur_start_node, end_node=cur_end_node, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] base_op_node=cur_base_op_node, ) @@ -481,4 +481,5 @@ of subgraphs.""" # subgraphs in their order of execution. 
results = collections.OrderedDict(reversed(results.items())) + # pyrefly: ignore [bad-return] return results diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 530937928b8..b2d6530049e 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -30,6 +30,7 @@ class EventList(list): use_device = kwargs.pop("use_device", None) profile_memory = kwargs.pop("profile_memory", False) with_flops = kwargs.pop("with_flops", False) + # pyrefly: ignore [not-iterable] super().__init__(*args, **kwargs) self._use_device = use_device self._profile_memory = profile_memory @@ -505,9 +506,9 @@ class FunctionEvent(FormattedTimesMixin): self.id: int = id self.node_id: int = node_id self.name: str = name - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.overload_name: str = overload_name - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.trace_name: str = trace_name self.time_range: Interval = Interval(start_us, end_us) self.thread: int = thread @@ -516,13 +517,13 @@ class FunctionEvent(FormattedTimesMixin): self.count: int = 1 self.cpu_children: list[FunctionEvent] = [] self.cpu_parent: Optional[FunctionEvent] = None - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.input_shapes: tuple[int, ...] = input_shapes - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.concrete_inputs: list[Any] = concrete_inputs - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.kwinputs: dict[str, Any] = kwinputs - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.stack: list = stack self.scope: int = scope self.use_device: Optional[str] = use_device @@ -766,7 +767,7 @@ class FunctionEventAvg(FormattedTimesMixin): self.self_device_memory_usage += other.self_device_memory_usage self.count += other.count if self.flops is None: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self.flops = other.flops elif other.flops is not None: self.flops += other.flops @@ -1003,7 +1004,7 @@ def _build_table( ] if flops <= 0: raise AssertionError(f"Expected flops to be positive, but got {flops}") - # pyrefly: ignore # no-matching-overload + # pyrefly: ignore [no-matching-overload] log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1))) if not (log_flops >= 0 and log_flops < len(flop_headers)): raise AssertionError( diff --git a/torch/compiler/__init__.py b/torch/compiler/__init__.py index 52d2645c4b7..1e744f54362 100644 --- a/torch/compiler/__init__.py +++ b/torch/compiler/__init__.py @@ -50,6 +50,7 @@ def compile(*args, **kwargs): """ See :func:`torch.compile` for details on the arguments for this function. 
""" + # pyrefly: ignore [not-iterable] return torch.compile(*args, **kwargs) diff --git a/torch/distributed/_local_tensor/__init__.py b/torch/distributed/_local_tensor/__init__.py index 4c15ffc0195..2415396d87a 100644 --- a/torch/distributed/_local_tensor/__init__.py +++ b/torch/distributed/_local_tensor/__init__.py @@ -198,6 +198,7 @@ def _for_each_rank_run_func( rr_val = flat_rank_rets[rr_key] if isinstance(rr_val, Tensor): + # pyrefly: ignore [bad-argument-type, bad-argument-count] ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)}) elif isinstance(rr_val, (list, tuple)): ret_list = [] @@ -206,6 +207,7 @@ def _for_each_rank_run_func( v_it = iter(rets.values()) v = next(v_it) if isinstance(v, Tensor): + # pyrefly: ignore [bad-argument-type, bad-argument-count] ret_list.append(LocalTensor(rets)) elif isinstance(v, int) and not all(v == v2 for v2 in v_it): ret_list.append(torch.SymInt(LocalIntNode(rets))) @@ -468,7 +470,7 @@ class LocalTensor(torch.Tensor): def __repr__(self) -> str: # type: ignore[override] parts = [] for k, v in self._local_tensors.items(): - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] parts.append(f" {k}: {v}") tensors_str = ",\n".join(parts) return f"LocalTensor(\n{tensors_str}\n)" @@ -491,6 +493,7 @@ class LocalTensor(torch.Tensor): "Expecting spec to be not None from `__tensor_flatten__` return value!" ) local_tensors = inner_tensors["_local_tensors"] + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor(local_tensors) @classmethod @@ -751,6 +754,7 @@ class LocalTensorMode(TorchDispatchMode): """ with self.disable(): + # pyrefly: ignore [bad-argument-type, bad-argument-count] return LocalTensor({r: cb(r) for r in self.ranks}) def _patch_device_mesh(self) -> None: @@ -761,7 +765,7 @@ class LocalTensorMode(TorchDispatchMode): def _unpatch_device_mesh(self) -> None: assert self._old_get_coordinate is not None DeviceMesh.get_coordinate = self._old_get_coordinate - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] self._old_get_coordinate = None diff --git a/torch/distributed/_shard/sharded_tensor/utils.py b/torch/distributed/_shard/sharded_tensor/utils.py index ed65991aeb0..c32fe5f7195 100644 --- a/torch/distributed/_shard/sharded_tensor/utils.py +++ b/torch/distributed/_shard/sharded_tensor/utils.py @@ -79,6 +79,7 @@ def _flatten_tensor_size(size) -> torch.Size: Checks if tensor size is valid, then flatten/return a torch.Size object. 
""" if len(size) == 1 and isinstance(size[0], collections.abc.Sequence): + # pyrefly: ignore [not-iterable] dims = list(*size) else: dims = list(size) @@ -208,7 +209,7 @@ def build_global_metadata( global_sharded_tensor_metadata = None global_metadata_rank = 0 - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for rank, rank_metadata in enumerate(gathered_metadatas): if rank_metadata is None: continue diff --git a/torch/distributed/checkpoint/_pg_transport.py b/torch/distributed/checkpoint/_pg_transport.py index 6a327afd445..b258517bdce 100644 --- a/torch/distributed/checkpoint/_pg_transport.py +++ b/torch/distributed/checkpoint/_pg_transport.py @@ -227,7 +227,7 @@ class PGTransport: self._work: list[Work] = [] self._pg = pg self._timeout = timeout - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self._device = device self._state_dict = state_dict @@ -345,6 +345,7 @@ class PGTransport: values.append(recv(path, v)) elif isinstance(v, _DTensorMeta): tensor = recv(path, v.local) + # pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword] values.append(DTensor(tensor, v.spec, requires_grad=False)) elif isinstance(v, _ShardedTensorMeta): # Receive all local shards that were sent to us diff --git a/torch/distributed/fsdp/_flat_param.py b/torch/distributed/fsdp/_flat_param.py index ce5d29dc166..2d742c30302 100644 --- a/torch/distributed/fsdp/_flat_param.py +++ b/torch/distributed/fsdp/_flat_param.py @@ -565,7 +565,7 @@ class FlatParamHandle: # Only align addresses for `use_orig_params=True` (for now) align_addresses = use_orig_params self._init_get_unflat_views_fn(align_addresses) - # pyrefly: ignore # read-only + # pyrefly: ignore [read-only] self.device = device self._device_handle = _FSDPDeviceHandle.from_device(self.device) self.process_group = process_group @@ -2495,6 +2495,7 @@ class FlatParamHandle: ########### def flat_param_to(self, *args, **kwargs): """Wrap an in-place call to ``.to()`` for ``self.flat_param``.""" + # pyrefly: ignore [not-iterable] self.flat_param.data = self.flat_param.to(*args, **kwargs) if self._use_orig_params: # Refresh the views because their storage may have changed diff --git a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py index 5013ce62cb3..a72c09fd80f 100644 --- a/torch/distributed/fsdp/_fully_shard/_fsdp_common.py +++ b/torch/distributed/fsdp/_fully_shard/_fsdp_common.py @@ -139,11 +139,14 @@ def _from_local_no_grad( """ if not compiled_autograd_enabled(): + # pyrefly: ignore [bad-argument-type] return DTensor( # Use the local tensor directly instead of constructing a new tensor # variable, e.g. 
with `view_as()`, since this is not differentiable + # pyrefly: ignore [bad-argument-count] local_tensor, sharding_spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=local_tensor.requires_grad, ) else: diff --git a/torch/distributed/tensor/_api.py b/torch/distributed/tensor/_api.py index 5fd66b2c5f8..865de11dacc 100644 --- a/torch/distributed/tensor/_api.py +++ b/torch/distributed/tensor/_api.py @@ -107,9 +107,12 @@ class _ToTorchTensor(torch.autograd.Function): ) return ( + # pyrefly: ignore [bad-argument-type] DTensor( + # pyrefly: ignore [bad-argument-count] grad_output, grad_spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=grad_output.requires_grad, ), None, @@ -175,11 +178,14 @@ class _FromTorchTensor(torch.autograd.Function): ) # We want a fresh Tensor object that shares memory with the input tensor + # pyrefly: ignore [bad-argument-type] dist_tensor = DTensor( + # pyrefly: ignore [bad-argument-count] input.view_as(input), dist_spec, # requires_grad of the dist tensor depends on if input # requires_grad or not + # pyrefly: ignore [unexpected-keyword] requires_grad=input.requires_grad, ) return dist_tensor @@ -304,9 +310,12 @@ class DTensor(torch.Tensor): spec.placements, tensor_meta=unflatten_tensor_meta, ) + # pyrefly: ignore [bad-argument-type] return DTensor( + # pyrefly: ignore [bad-argument-count] local_tensor, unflatten_spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=requires_grad, ) @@ -820,9 +829,12 @@ def distribute_tensor( dtype=tensor.dtype, ), ) + # pyrefly: ignore [bad-argument-type] return DTensor( + # pyrefly: ignore [bad-argument-count] local_tensor.requires_grad_(tensor.requires_grad), spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=tensor.requires_grad, ) @@ -1077,9 +1089,12 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def] ), ) + # pyrefly: ignore [bad-argument-type] return DTensor( + # pyrefly: ignore [bad-argument-count] local_tensor, spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=kwargs["requires_grad"], ) diff --git a/torch/distributed/tensor/_dispatch.py b/torch/distributed/tensor/_dispatch.py index 8a293aaaea2..4f91e3444b0 100644 --- a/torch/distributed/tensor/_dispatch.py +++ b/torch/distributed/tensor/_dispatch.py @@ -78,8 +78,11 @@ def found_inf_reduce_handler( dtype=target_tensor.dtype, ), ) + # pyrefly: ignore [bad-argument-type] found_inf_dtensor = dtensor.DTensor( - local_tensor=target_tensor, spec=spec, requires_grad=False + local_tensor=target_tensor, # pyrefly: ignore [unexpected-keyword] + spec=spec, # pyrefly: ignore [unexpected-keyword] + requires_grad=False, # pyrefly: ignore [unexpected-keyword] ) found_inf = found_inf_dtensor.full_tensor() target_tensor.copy_(found_inf) @@ -189,7 +192,7 @@ class OpDispatcher: local_tensor_args = ( pytree.tree_unflatten( cast(list[object], op_info.local_args), - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] op_info.args_tree_spec, ) if op_info.args_tree_spec @@ -366,7 +369,7 @@ class OpDispatcher: resharded_local_tensor = redistribute_local_tensor( local_tensor, arg_spec, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] reshard_arg_spec, ) new_local_args.append(resharded_local_tensor) @@ -439,7 +442,7 @@ class OpDispatcher: kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor( op_call, v, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] compute_mesh, ) local_kwargs[k] = v @@ -456,7 +459,7 @@ class OpDispatcher: OpSchema( op_call, ( - # 
pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] pytree.tree_unflatten(args_schema, args_spec) if args_spec else tuple(args_schema) @@ -478,6 +481,7 @@ class OpDispatcher: assert isinstance(spec, DTensorSpec), ( f"output spec does not match with output! Expected DTensorSpec, got {spec}." ) + # pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword] return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) else: # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor diff --git a/torch/distributed/tensor/_redistribute.py b/torch/distributed/tensor/_redistribute.py index 9dc5f5041ab..463c34c8fb4 100644 --- a/torch/distributed/tensor/_redistribute.py +++ b/torch/distributed/tensor/_redistribute.py @@ -883,9 +883,12 @@ class Redistribute(torch.autograd.Function): output = local_tensor target_spec = current_spec + # pyrefly: ignore [bad-argument-type] return dtensor.DTensor( + # pyrefly: ignore [bad-argument-count] output, target_spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=input.requires_grad, ) @@ -944,9 +947,12 @@ class Redistribute(torch.autograd.Function): dtype=output.dtype, ), ) + # pyrefly: ignore [bad-argument-type] output_dtensor = dtensor.DTensor( + # pyrefly: ignore [bad-argument-count] output, spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=grad_output.requires_grad, ) diff --git a/torch/distributed/tensor/parallel/loss.py b/torch/distributed/tensor/parallel/loss.py index bc9c5486298..addcc0a898b 100644 --- a/torch/distributed/tensor/parallel/loss.py +++ b/torch/distributed/tensor/parallel/loss.py @@ -174,9 +174,12 @@ def _log_softmax_handler( tensor_meta=output_tensor_meta, ) + # pyrefly: ignore [bad-argument-type] return DTensor( + # pyrefly: ignore [bad-argument-count] res, res_spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=res.requires_grad, ) @@ -251,7 +254,7 @@ def _nll_loss_forward( if weight is not None: new_shape = list(x.shape) new_shape[channel_dim] = -1 - # pyrefly: ignore # unbound-name + # pyrefly: ignore [unbound-name] w = w.expand(new_shape) wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) wsum = torch.where(target != ignore_index, wsum, 0) @@ -309,9 +312,9 @@ def _nll_loss_forward_handler( output_placements = all_replicate_placements # tensor inputs to _propagate_tensor_meta need to be DTensors - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] args = list(args) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] args[1], args[2] = target, weight output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) @@ -330,9 +333,12 @@ def _nll_loss_forward_handler( out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta) return ( + # pyrefly: ignore [bad-argument-type] DTensor( + # pyrefly: ignore [bad-argument-count] result, out_spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=result.requires_grad, ), total_weight, @@ -442,11 +448,11 @@ def _nll_loss_backward_handler( weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh) # tensor inputs to _propagate_tensor_meta need to be DTensors - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] args = list(args) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] args[2], args[3] = target, weight - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] args[6] = 
_cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh) output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) @@ -470,9 +476,12 @@ def _nll_loss_backward_handler( tensor_meta=output_tensor_meta, ) + # pyrefly: ignore [bad-argument-type] return DTensor( + # pyrefly: ignore [bad-argument-count] result, out_spec, + # pyrefly: ignore [unexpected-keyword] requires_grad=result.requires_grad, ) diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index de72c8f505d..1e4931f4a19 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -949,7 +949,7 @@ def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module): for key, value in pytree.tree_map(arg_dump, node.kwargs).items() ] target = node.target if node.op in ("call_function", "get_attr") else "" - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") nodes_idx[id(node)] = i return "\n".join(ret) @@ -1206,6 +1206,7 @@ class _ModuleFrame: for k in kwargs_spec.context } assert self.parent_call_module is not None + # pyrefly: ignore [bad-assignment] self.parent_call_module.args = tuple(arg_nodes) self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment] @@ -1393,6 +1394,7 @@ class _ModuleFrame: def print(self, *args, **kwargs): if self.verbose: + # pyrefly: ignore [not-iterable] print(*args, **kwargs) def run_from(self, node_idx): @@ -1486,7 +1488,7 @@ class _ModuleFrame: self.seen_attrs[self.child_fqn].add(node.target) self.copy_node(node) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] node_idx += 1 diff --git a/torch/functional.py b/torch/functional.py index 47e147f8508..3054f54b7cd 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -1952,7 +1952,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor: ) if isinstance(shape, (int, torch.SymInt)): - shape = torch.Size([shape]) # pyrefly: ignore # bad-argument-type + shape = torch.Size([shape]) # pyrefly: ignore [bad-argument-type] else: for dim in shape: torch._check_type( diff --git a/torch/fx/experimental/unification/more.py b/torch/fx/experimental/unification/more.py index f1df562a2dc..42074a46a42 100644 --- a/torch/fx/experimental/unification/more.py +++ b/torch/fx/experimental/unification/more.py @@ -87,6 +87,7 @@ def _reify_object_slots(o, s): @dispatch(slice, dict) def _reify(o, s): """Reify a Python ``slice`` object""" + # pyrefly: ignore [not-iterable] return slice(*reify((o.start, o.stop, o.step), s)) diff --git a/torch/fx/node.py b/torch/fx/node.py index ad848c80970..48f57d58863 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -59,7 +59,7 @@ Argument = Optional[ BaseArgumentTypes, ] ] -# pyrefly: ignore # invalid-annotation +# pyrefly: ignore [invalid-annotation] ArgumentT = TypeVar("ArgumentT", bound=Argument) _P = ParamSpec("_P") _R = TypeVar("_R") @@ -385,7 +385,7 @@ class Node(_NodeBase): Args: x (Node): The node to put before this node. Must be a member of the same graph. """ - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self._prepend(x) @compatibility(is_backward_compatible=True) @@ -397,7 +397,7 @@ class Node(_NodeBase): Args: x (Node): The node to put after this node. Must be a member of the same graph. 
""" - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self._next._prepend(x) @property @@ -698,7 +698,8 @@ class Node(_NodeBase): if replace_hooks: for replace_hook in replace_hooks: replace_hook(old=self, new=replace_with.name, user=use_node) - use_node._replace_input_with(self, replace_with) + # pyrefly: ignore [missing-attribute] + use_node._replace_input_with(self, replace_with) # type: ignore[attr-defined] return result @compatibility(is_backward_compatible=False) @@ -835,7 +836,8 @@ class Node(_NodeBase): for replace_hook in m._replace_hooks: replace_hook(old=old_input, new=new_input.name, user=self) - self._replace_input_with(old_input, new_input) + # pyrefly: ignore [missing-attribute] + self._replace_input_with(old_input, new_input) # type: ignore[attr-defined] def _rename(self, candidate: str) -> None: if candidate == self.name: diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index 092e9f2cc5c..c381d99747c 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -303,7 +303,7 @@ class StreamContext: self.idx = _get_device_index(None, True) if not torch.jit.is_scripting(): if self.idx is None: - self.idx = -1 # pyrefly: ignore # bad-assignment + self.idx = -1 # pyrefly: ignore [bad-assignment] self.src_prev_stream = ( None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) diff --git a/torch/nested/_internal/ops.py b/torch/nested/_internal/ops.py index bdca74c13b1..756dc643baf 100644 --- a/torch/nested/_internal/ops.py +++ b/torch/nested/_internal/ops.py @@ -46,7 +46,7 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False): if canonicalize: dim = canonicalize_dims(ndim, dim) - assert dim >= 0 and dim < ndim # pyrefly: ignore # unsupported-operation + assert dim >= 0 and dim < ndim # pyrefly: ignore [unsupported-operation] # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1. # For other dims, subtract 1 to convert to inner space. diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index dbc32e0ff96..6a78aba2ad7 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -72,7 +72,7 @@ class _NormBase(Module): torch.tensor( 0, dtype=torch.long, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, ), ) @@ -222,7 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase): dtype=None, ) -> None: factory_kwargs = {"device": device, "dtype": dtype} - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] super().__init__( # affine and track_running_stats are hardcoded to False to # avoid creating tensors that will soon be overwritten. 
@@ -236,29 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase): self.affine = affine self.track_running_stats = track_running_stats if self.affine: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.weight = UninitializedParameter(**factory_kwargs) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.bias = UninitializedParameter(**factory_kwargs) if self.track_running_stats: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.running_mean = UninitializedBuffer(**factory_kwargs) - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] self.running_var = UninitializedBuffer(**factory_kwargs) self.num_batches_tracked = torch.tensor( 0, dtype=torch.long, - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] **{k: v for k, v in factory_kwargs.items() if k != "dtype"}, ) def reset_parameters(self) -> None: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] if not self.has_uninitialized_params() and self.num_features != 0: super().reset_parameters() def initialize_parameters(self, input) -> None: # type: ignore[override] - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] if self.has_uninitialized_params(): self.num_features = input.shape[1] if self.affine: @@ -352,6 +352,7 @@ class BatchNorm1d(_BatchNorm): raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") +# pyrefly: ignore [inconsistent-inheritance] class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization. @@ -463,6 +464,7 @@ class BatchNorm2d(_BatchNorm): raise ValueError(f"expected 4D input (got {input.dim()}D input)") +# pyrefly: ignore [inconsistent-inheritance] class LazyBatchNorm2d(_LazyNormBase, _BatchNorm): r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization. @@ -574,6 +576,7 @@ class BatchNorm3d(_BatchNorm): raise ValueError(f"expected 5D input (got {input.dim()}D input)") +# pyrefly: ignore [inconsistent-inheritance] class LazyBatchNorm3d(_LazyNormBase, _BatchNorm): r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization. 
diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index 084e9821781..194e68046e8 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -38,13 +38,13 @@ T = TypeVar("T", bound="Module") class _IncompatibleKeys( - # pyrefly: ignore # invalid-inheritance + # pyrefly: ignore [invalid-inheritance] namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), ): __slots__ = () def __repr__(self) -> str: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if not self.missing_keys and not self.unexpected_keys: return "" return super().__repr__() @@ -93,7 +93,7 @@ class _WrappedHook: def __getstate__(self) -> dict: result = {"hook": self.hook, "with_module": self.with_module} if self.with_module: - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] result["module"] = self.module() return result @@ -979,7 +979,7 @@ class Module: # Decrement use count of the gradient by setting to None param.grad = None param_applied = torch.nn.Parameter( - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] param_applied, requires_grad=param.requires_grad, ) @@ -992,13 +992,13 @@ class Module: ) from e out_param = param elif p_should_use_set_data: - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] param.data = param_applied out_param = param else: assert isinstance(param, Parameter) assert param.is_leaf - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] out_param = Parameter(param_applied, param.requires_grad) self._parameters[key] = out_param @@ -1337,7 +1337,9 @@ class Module: """ device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( - *args, **kwargs + # pyrefly: ignore [not-iterable] + *args, + **kwargs, ) if dtype is not None: @@ -2256,7 +2258,7 @@ class Module: if destination is None: destination = OrderedDict() - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] destination._metadata = OrderedDict() local_metadata = dict(version=self._version) @@ -2407,7 +2409,7 @@ class Module: } local_name_params = itertools.chain( self._parameters.items(), - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] persistent_buffers.items(), ) local_state = {k: v for k, v in local_name_params if v is not None} diff --git a/torch/onnx/_internal/exporter/_capture_strategies.py b/torch/onnx/_internal/exporter/_capture_strategies.py index 88524b45209..63421ff5bb9 100644 --- a/torch/onnx/_internal/exporter/_capture_strategies.py +++ b/torch/onnx/_internal/exporter/_capture_strategies.py @@ -27,6 +27,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]: """Prints messages based on `verbose`.""" if verbose is False: return lambda *_, **__: None + # pyrefly: ignore [not-iterable] return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) @@ -47,7 +48,7 @@ def _patch_dynamo_unsupported_functions(): # Replace torch.jit.isinstance with isinstance jit_isinstance = torch.jit.isinstance - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] torch.jit.isinstance = isinstance logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing") try: diff --git a/torch/onnx/_internal/exporter/_core.py b/torch/onnx/_internal/exporter/_core.py index 5f7872b6749..4458e00d767 100644 --- a/torch/onnx/_internal/exporter/_core.py +++ b/torch/onnx/_internal/exporter/_core.py @@ -132,10 +132,10 @@ class TorchTensor(ir.Tensor): # view 
the tensor as that dtype so that it is convertible to NumPy, # and then view it back to the proper dtype (using ml_dtypes obtained by # calling dtype.numpy()). - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] if self.dtype == ir.DataType.BFLOAT16: return ( - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy()) ) if self.dtype in { @@ -144,11 +144,11 @@ class TorchTensor(ir.Tensor): ir.DataType.FLOAT8E5M2, ir.DataType.FLOAT8E5M2FNUZ, }: - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy()) if self.dtype == ir.DataType.FLOAT4E2M1: return _type_casting.unpack_float4x2_as_uint8(self.raw).view( - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] self.dtype.numpy() ) @@ -170,7 +170,7 @@ class TorchTensor(ir.Tensor): if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): raise TypeError( - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor " "with a tensor backed by real data using ONNXProgram.apply_weights() " "or save the model without initializers by setting include_initializers=False." @@ -251,7 +251,7 @@ def _set_shape_type( if isinstance(dim, int): dims.append(dim) else: - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] dims.append(str(dim.node)) # If the dtype is set already (e.g. by the onnx_symbolic ops), @@ -1232,7 +1232,7 @@ def _exported_program_to_onnx_program( # so we need to get them from the name_* apis. for name, torch_tensor in itertools.chain( exported_program.named_parameters(), - # pyrefly: ignore # bad-argument-type + # pyrefly: ignore [bad-argument-type] exported_program.named_buffers(), exported_program.constants.items(), ): @@ -1265,6 +1265,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]: """Prints messages based on `verbose`.""" if verbose is False: return lambda *_, **__: None + # pyrefly: ignore [not-iterable] return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) diff --git a/torch/onnx/_internal/torchscript_exporter/verification.py b/torch/onnx/_internal/torchscript_exporter/verification.py index f8e2d37ba73..c3cb967c14c 100644 --- a/torch/onnx/_internal/torchscript_exporter/verification.py +++ b/torch/onnx/_internal/torchscript_exporter/verification.py @@ -239,7 +239,7 @@ def _compare_onnx_pytorch_outputs_in_np( if acceptable_error_percentage: error_percentage = 1 - np.sum( np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol) - ) / np.prod(ort_out.shape) # pyrefly: ignore # missing-attribute + ) / np.prod(ort_out.shape) # pyrefly: ignore [missing-attribute] if error_percentage <= acceptable_error_percentage: warnings.warn( f"Suppressed AssertionError:\n{e}.\n" diff --git a/torch/optim/_multi_tensor/__init__.py b/torch/optim/_multi_tensor/__init__.py index b6818e5a50f..face68d0bc5 100644 --- a/torch/optim/_multi_tensor/__init__.py +++ b/torch/optim/_multi_tensor/__init__.py @@ -13,6 +13,7 @@ from torch import optim def partialclass(cls, *args, **kwargs): # noqa: D103 class NewCls(cls): + # pyrefly: ignore [not-iterable] __init__ = partialmethod(cls.__init__, *args, **kwargs) return NewCls diff --git a/torch/signal/windows/windows.py b/torch/signal/windows/windows.py index ed240bda816..96fe6932de8 100644 --- 
a/torch/signal/windows/windows.py +++ b/torch/signal/windows/windows.py @@ -326,7 +326,7 @@ def gaussian( requires_grad=requires_grad, ) - return torch.exp(-(k**2)) # pyrefly: ignore # unsupported-operation + return torch.exp(-(k**2)) # pyrefly: ignore [unsupported-operation] @_add_docstr( diff --git a/torch/storage.py b/torch/storage.py index fbe75b549f2..1b9023121dd 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -618,7 +618,7 @@ def _get_storage_from_sequence(sequence, dtype, device): def _isint(x): if HAS_NUMPY: - return isinstance(x, (int, np.integer)) # pyrefly: ignore # missing-attribute + return isinstance(x, (int, np.integer)) # pyrefly: ignore [missing-attribute] else: return isinstance(x, int) diff --git a/torch/utils/_ordered_set.py b/torch/utils/_ordered_set.py index cea8ea684d3..eea73102223 100644 --- a/torch/utils/_ordered_set.py +++ b/torch/utils/_ordered_set.py @@ -77,6 +77,7 @@ class OrderedSet(MutableSet[T], Reversible[T]): def pop(self) -> T: if not self: raise KeyError("pop from an empty set") + # pyrefly: ignore [bad-return] return self._dict.popitem()[0] def copy(self) -> OrderedSet[T]: @@ -158,7 +159,7 @@ class OrderedSet(MutableSet[T], Reversible[T]): def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]: # MutableSet impl will iterate over other, iter over smaller of two sets if isinstance(other, OrderedSet) and len(self) < len(other): - # pyrefly: ignore # unsupported-operation, bad-return + # pyrefly: ignore [unsupported-operation, bad-return] return other & self return cast(OrderedSet[T], super().__and__(other)) diff --git a/torch/utils/backend_registration.py b/torch/utils/backend_registration.py index d034f22b1e6..3d04d9cd839 100644 --- a/torch/utils/backend_registration.py +++ b/torch/utils/backend_registration.py @@ -202,7 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) - Args: device (int, optional): if specified, all parameters will be copied to that device """ - # pyrefly: ignore # missing-attribute + # pyrefly: ignore [missing-attribute] return self._apply(lambda t: getattr(t, custom_backend_name)(device)) _check_register_once(torch.nn.Module, custom_backend_name) @@ -252,11 +252,15 @@ def _generate_packed_sequence_methods_for_privateuse1_backend( device (int, optional): if specified, all parameters will be copied to that device """ ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to( - *args, **kwargs + # pyrefly: ignore [not-iterable] + *args, + **kwargs, ) if ex.device.type == custom_backend_name: + # pyrefly: ignore [not-iterable] return self.to(*args, **kwargs) kwargs.update({"device": custom_backend_name}) + # pyrefly: ignore [not-iterable] return self.to(*args, **kwargs) _check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name) diff --git a/torch/utils/hipify/cuda_to_hip_mappings.py b/torch/utils/hipify/cuda_to_hip_mappings.py index d1d9a08c71c..82547c8e285 100644 --- a/torch/utils/hipify/cuda_to_hip_mappings.py +++ b/torch/utils/hipify/cuda_to_hip_mappings.py @@ -48,7 +48,7 @@ MATH_TRANSPILATIONS = collections.OrderedDict( ] ) -# pyrefly: ignore # no-matching-overload +# pyrefly: ignore [no-matching-overload] CUDA_TYPE_NAME_MAP = collections.OrderedDict( [ ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)), @@ -587,6 +587,7 @@ CUDA_TYPE_NAME_MAP = collections.OrderedDict( ] ) +# pyrefly: ignore [no-matching-overload] CUDA_INCLUDE_MAP = collections.OrderedDict( [ # since pytorch uses "\b{pattern}\b" as the actual re pattern, @@ -676,7 +677,7 @@ 
CUDA_INCLUDE_MAP = collections.OrderedDict( ] ) -# pyrefly: ignore # no-matching-overload +# pyrefly: ignore [no-matching-overload] CUDA_IDENTIFIER_MAP = collections.OrderedDict( [ ("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)), @@ -8370,6 +8371,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict( ] ) +# pyrefly: ignore [no-matching-overload] CUDA_SPECIAL_MAP = collections.OrderedDict( [ # SPARSE @@ -8852,6 +8854,7 @@ CUDA_SPECIAL_MAP = collections.OrderedDict( ] ) +# pyrefly: ignore [no-matching-overload] PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict( [ ("USE_CUDA", ("USE_ROCM", API_PYTORCH)), @@ -9316,6 +9319,7 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict( ] ) +# pyrefly: ignore [no-matching-overload] CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict( [ ("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", API_CAFFE2)), @@ -9401,6 +9405,7 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict( # # NB: if you want a transformation to ONLY apply to the c10/ directory, # put it as API_CAFFE2 +# pyrefly: ignore [no-matching-overload] C10_MAPPINGS = collections.OrderedDict( [ ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)), diff --git a/torch/utils/hipify/hipify_python.py b/torch/utils/hipify/hipify_python.py index 7e245262ea7..93ce3c50dfc 100755 --- a/torch/utils/hipify/hipify_python.py +++ b/torch/utils/hipify/hipify_python.py @@ -120,6 +120,7 @@ class GeneratedFileCleaner: def open(self, fn, *args, **kwargs): if not os.path.exists(fn): self.files_to_clean.add(os.path.abspath(fn)) + # pyrefly: ignore [not-iterable] return open(fn, *args, **kwargs) def makedirs(self, dn, exist_ok=False): @@ -669,7 +670,7 @@ def is_caffe2_gpu_file(rel_filepath): return True filename = os.path.basename(rel_filepath) _, ext = os.path.splitext(filename) - # pyrefly: ignore # unsupported-operation + # pyrefly: ignore [unsupported-operation] return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename) class TrieNode: @@ -1145,7 +1146,7 @@ def hipify( out_of_place_only=out_of_place_only, is_pytorch_extension=is_pytorch_extension)) all_files_set = set(all_files) - # pyrefly: ignore # bad-assignment + # pyrefly: ignore [bad-assignment] for f in extra_files: if not os.path.isabs(f): f = os.path.join(output_directory, f) diff --git a/torch/utils/weak.py b/torch/utils/weak.py index ed311cd0595..26b7b8c1122 100644 --- a/torch/utils/weak.py +++ b/torch/utils/weak.py @@ -292,9 +292,10 @@ class WeakIdKeyDictionary(MutableMapping): if o is not None: return o, value - # pyrefly: ignore # bad-override + # pyrefly: ignore [bad-override] def pop(self, key, *args): self._dirty_len = True + # pyrefly: ignore [not-iterable] return self.data.pop(self.ref_type(key), *args) # CHANGED def setdefault(self, key, default=None): diff --git a/torch/xpu/__init__.py b/torch/xpu/__init__.py index e95b7015f33..72285b46bcc 100644 --- a/torch/xpu/__init__.py +++ b/torch/xpu/__init__.py @@ -328,7 +328,7 @@ class StreamContext: self.stream = stream self.idx = _get_device_index(None, True) if self.idx is None: - self.idx = -1 # pyrefly: ignore # bad-assignment + self.idx = -1 # pyrefly: ignore [bad-assignment] def __enter__(self): cur_stream = self.stream diff --git a/torch/xpu/streams.py b/torch/xpu/streams.py index 378e71074c1..a1d78305f0a 100644 --- a/torch/xpu/streams.py +++ b/torch/xpu/streams.py @@ -126,7 +126,7 @@ class Event(torch._C._XpuEventBase): """ if stream is None: stream = torch.xpu.current_stream() - super().record(stream) # pyrefly: ignore # bad-argument-type + 
super().record(stream) # pyrefly: ignore [bad-argument-type] def wait(self, stream=None) -> None: r"""Make all future work submitted to the given stream wait for this event.
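
Note for reviewers unfamiliar with the suppression style used above: the sketch below illustrates the two bracketed `pyrefly: ignore [code]` placements that recur throughout this patch, either on the offending line itself or on the line directly above a single argument of a multi-line call; several codes may share one comment, as in `# pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword]`. The `double` function is purely hypothetical and not taken from the PyTorch sources; the example only assumes that pyrefly scopes the bracketed form to the error codes named inside the brackets, leaving other diagnostics on the same line visible.

    def double(x: int) -> int:
        return x * 2


    # Placement 1: on the offending line, hiding exactly one error class.
    # A mypy suppression can sit on the same line, as in several hunks above.
    a = double("ab")  # pyrefly: ignore [bad-argument-type]  # type: ignore[arg-type]

    # Placement 2: on the line directly above one argument of a multi-line
    # call, mirroring the DTensor(...) construction sites in this patch.
    b = double(
        # pyrefly: ignore [bad-argument-type]
        "cd",
    )

    # Both calls run fine at runtime ("abab", "cdcd"); only the static
    # checker's report is silenced, and only for the listed error code.
    print(a, b)
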