Clean up unused Pyrefly suppressions (#166178)

Cleans up ignores that are no longer needed in the repo and adds select suppressions so that the main branch is clean.

test plan:
`lintrunner -a`
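
For context, here is a minimal, hypothetical sketch (the function and dict below are invented, not taken from the PyTorch tree) of the kind of suppression this cleanup removes: Pyrefly used to flag names bound by a walrus operator inside a compound condition as `unbound-name`, so such lines carried `# pyrefly: ignore` comments; once the checker tracks those bindings, the comments are reported as unused and can be deleted.

```python
# Hypothetical illustration (not from the PyTorch repo) of an ignore that
# becomes unnecessary once the checker understands walrus bindings used in
# compound conditions.
def describe(config: dict[str, str]) -> str:
    if (mode := config.get("mode")) is not None and mode != "":
        # Previously the next line might have carried
        # "# pyrefly: ignore  # unbound-name" because `mode` was reported
        # as possibly unbound at this point.
        return f"mode={mode}"
    return "mode unset"


if __name__ == "__main__":
    print(describe({"mode": "fast"}))  # -> mode=fast
```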

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166178
Approved by: https://github.com/oulgen
Author: Maggie Moss
Date: 2025-10-25 05:32:21 +00:00
Committed by: PyTorch MergeBot
parent 7924e3aacf
commit eb83c3ca23
85 changed files with 66 additions and 188 deletions

View File

@@ -2103,11 +2103,10 @@ def export(
)
and not trace_rules.check(call_to_inspect)
):
- # pyrefly: ignore # unbound-name
dim_constraints.solve()
- # pyrefly: ignore # unbound-name
forced_specializations = dim_constraints.forced_specializations()
- # pyrefly: ignore # unbound-name
msg = dim_constraints.prettify_results(
original_signature,
dynamic_shapes,
@@ -2128,11 +2127,10 @@ def export(
)
# Error if we have any constraints on static values
- # pyrefly: ignore # unbound-name
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
- # pyrefly: ignore # unbound-name
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "

View File

@@ -408,11 +408,10 @@ def _suggest_or_raise_constraint_violation(
torch._ops.OpOverloadPacket | torch._ops.OpOverload,
)
):
- # pyrefly: ignore # unbound-name
dim_constraints.solve()
- # pyrefly: ignore # unbound-name
forced_specializations = dim_constraints.forced_specializations()
- # pyrefly: ignore # unbound-name
msg = dim_constraints.prettify_results(
inspect.signature(orig_callable), # type: ignore[attr-defined]
dynamic_shapes,
@@ -433,11 +432,10 @@ def _suggest_or_raise_constraint_violation(
)
# Error if we have any constraints on static values
- # pyrefly: ignore # unbound-name
for k in shape_env.var_to_range.keys():
if isinstance(k, sympy.Integer):
constraint_violation_error = ConstraintViolationError(
- # pyrefly: ignore # unbound-name
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
"It appears that you're trying to set a constraint on a "
f"value which we evaluated to have a static value of {k}. "

View File

@@ -456,8 +456,10 @@ def _add_mutation_dependencies(
for user in mutated_arg.users:
if user is node:
continue
+ # pyrefly: ignore # unsupported-operation
elif user < node:
node_to_additional_deps[node].add(user)
+ # pyrefly: ignore # unsupported-operation
elif user > node:
node_to_additional_deps[user].add(node)

View File

@@ -4100,13 +4100,12 @@ class CheckFunctionManager:
and (cache_entry := self.guard_manager.cache_entry) is not None
and (extra_state := self.guard_manager.extra_state) is not None
):
- # pyrefly: ignore # unbound-name
assert isinstance(cache_entry, CacheEntry)
- # pyrefly: ignore # unbound-name
assert isinstance(extra_state, ExtraState)
reason = f"Cache line invalidated because {obj_str} got deallocated"
deleted_guard_manager = DeletedGuardManagerWrapper(reason)
- # pyrefly: ignore # unbound-name
extra_state.invalidate(cache_entry, deleted_guard_manager)
self.guard_manager = deleted_guard_manager

View File

@@ -2048,9 +2048,8 @@ class OutputGraph(OutputGraphCommon):
tx = self.root_tx
assert tx is not None
if (ds := tx.distributed_state) is not None and ds.all_states is None:
- # pyrefly: ignore # unbound-name
compile_pg = ds.compile_pg
- # pyrefly: ignore # unbound-name
log.info("compiler_collective %s", ds.local_state)
torch._logging.trace_structured(
"artifact",
@@ -2058,7 +2057,6 @@ class OutputGraph(OutputGraphCommon):
"name": "compiler_collective",
"encoding": "string",
},
- # pyrefly: ignore # unbound-name
payload_fn=lambda: ds.local_state.render(),
)
device_types = compile_pg._device_types
@@ -2072,9 +2070,9 @@ class OutputGraph(OutputGraphCommon):
dynamo_timed("compiler_collective", log_pt2_compile_event=True),
):
all_states: list[Any] = [None] * compile_pg.size()
- # pyrefly: ignore # unbound-name
dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
- # pyrefly: ignore # unbound-name
ds.all_states = all_states
# Clear speculation log, because are tracing may diverge due to
# this information from the compiler collective
@@ -2468,7 +2466,6 @@ class OutputGraph(OutputGraphCommon):
isinstance(b, torch.SymBool)
and (r := b.node.maybe_as_bool()) is not None
):
- # pyrefly: ignore # unbound-name
return r
# TODO: We can also technically remove all cases when the input
# doesn't have unbacked inputs, since it's all in the ShapeEnv

View File

@@ -876,7 +876,6 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
not _CODE_STATE
and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
):
- # pyrefly: ignore # unbound-name
extra_read_key = get_extra_cache_key(sticky_read)
if extra_read_key is not None:
get_extra_remote_code_state(extra_read_key)

View File

@@ -4410,7 +4410,6 @@ class InstructionTranslator(InstructionTranslatorBase):
and isinstance(tos, LocalGeneratorObjectVariable)
):
self.stack[-1] = ListIteratorVariable(
- # pyrefly: ignore # unbound-name
tos.force_unpack_var_sequence(self),
mutation_type=ValueMutationNew(),
)

View File

@@ -4214,7 +4214,6 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
# (x) + (y)
# ~~^~~~~~~
while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
- # pyrefly: ignore # unbound-name
if ch in "\\#":
cur_lineno, cur_col = nextline(cur_lineno, cur_col)
else:

View File

@@ -317,7 +317,6 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
)
res_proxy.node.meta.update(meta.data)
if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
- # pyrefly: ignore # unbound-name
if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
self.tracer.set_metadata(res_proxy.node, res_data)

View File

@@ -2183,7 +2183,6 @@ class GraphModuleDeserializer(metaclass=Final):
simplify=True,
)
):
- # pyrefly: ignore # unbound-name
node.meta["unbacked_bindings"] = unbacked_bindings
assert len(self.unbacked_symbols) == 0

View File

@@ -204,7 +204,6 @@ def run_functionalized_fw_and_collect_metadata(
suppress_pending = contextlib.nullcontext()
fake_mode = detect_fake_mode()
if fake_mode and (shape_env := fake_mode.shape_env):
- # pyrefly: ignore # unbound-name
suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
with disable_above, mode, suppress_pending:
# precondition: The passed in function already handles unflattening inputs + flattening outputs

View File

@@ -746,7 +746,6 @@ class WhileLoopAutogradOp(torch.autograd.Function):
and (shape_env := loop_count.node.shape_env)
and loop_count in shape_env.pending_fresh_unbacked_symbols
):
- # pyrefly: ignore # unbound-name
shape_env.pending_fresh_unbacked_symbols.remove(loop_count)
# Even when body function is not executed, we clone and unsqueeze the input

View File

@@ -132,7 +132,6 @@ def aoti_compile_and_package(
)
or (
isinstance(package_path, (str, os.PathLike))
- # pyrefly: ignore # no-matching-overload
and os.fspath(package_path).endswith(".pt2")
)
), (

View File

@@ -557,7 +557,6 @@ class GPUDeviceBenchmarkMixin:
res = benchmarker.benchmark_gpu(fn)
device_interface.synchronize() # shake out any CUDA errors
- # pyrefly: ignore # bad-return
return res

View File

@@ -1737,7 +1737,6 @@ class KernelArgs:
)
)
for outer, inner in chain(
- # pyrefly: ignore # bad-argument-type
self.input_buffers.items(),
# pyrefly: ignore # bad-argument-type
self.output_buffers.items(),

View File

@@ -1478,7 +1478,6 @@ class CppGemmTemplate(CppTemplate):
assert isinstance(template_buffer, ir.IRNode)
gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
gemm_output_buffer = ir.Buffer(
- # pyrefly: ignore # missing-attribute
name=gemm_output_name,
# pyrefly: ignore # missing-attribute
layout=template_buffer.layout,
@@ -1502,7 +1501,6 @@ class CppGemmTemplate(CppTemplate):
reindexers.append(None)
if i < len(epilogue_creators) - 1:
current_input_buffer = ir.Buffer(
- # pyrefly: ignore # missing-attribute
name=buffer_name,
# pyrefly: ignore # missing-attribute
layout=template_buffer.layout,

View File

@@ -822,7 +822,6 @@ class CppWrapperGpu(CppWrapperCpu):
if triton:
call_args, arg_types = self.prepare_triton_wrapper_args(
- # pyrefly: ignore # bad-argument-type
call_args,
# pyrefly: ignore # bad-argument-type
arg_types,

View File

@@ -680,7 +680,6 @@ class MetalKernel(SIMDKernel):
)
idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False) # type: ignore[assignment]
idx_var = next(
- # pyrefly: ignore # missing-argument
t
for t in self.range_tree_nodes.values()
# pyrefly: ignore # missing-argument
@@ -863,7 +862,6 @@ class MetalKernel(SIMDKernel):
if self.inside_reduction:
total_reduction_size = math.prod(
- # pyrefly: ignore # missing-argument
t.numel
for t in self.range_trees
# pyrefly: ignore # missing-argument

View File

@@ -965,7 +965,6 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
def active_range_trees(self) -> list[IterationRangesRoot]:
return [
- # pyrefly: ignore # missing-argument
t
for t in self.range_trees
# pyrefly: ignore # missing-argument

View File

@@ -1036,7 +1036,6 @@ class FxConverter:
# Add constants stored as Triton metadata, in signature order.
call_kwargs |= constants
new_call_args = [
- # pyrefly: ignore # missing-attribute
call_kwargs[key]
for key in signature
# pyrefly: ignore # missing-attribute

View File

@@ -826,11 +826,10 @@ def _schedule_for_comm(
collective_cost > 0
and (candidate := get_overlapping_candidate()) is not None
):
- # pyrefly: ignore # unbound-name
ready.remove(candidate)
- # pyrefly: ignore # unbound-name
schedule(candidate.snode)
- # pyrefly: ignore # unbound-name
collective_cost -= snode_to_cost[candidate.snode]
heapq.heapify(ready)
@@ -1098,7 +1097,7 @@ def _sink_waits_iterative_internal(
info.grouped_info = _group_names(gns)
candidate = _next[candidate]
continue
- # pyrefly: ignore # unbound-name
elif (data_dep is None) and both_contain_comms:
info.limiting_factor = (
f"collective ordering {_group_names(gns)}"

View File

@@ -271,7 +271,6 @@ def record_original_output_strides(gm: GraphModule) -> None:
and (val := output.meta.get("val")) is not None
and isinstance(val, torch.Tensor)
):
- # pyrefly: ignore # unbound-name
output_strides.append(val.stride())
else:
# pyrefly: ignore # bad-argument-type

View File

@@ -620,7 +620,6 @@ class _OutOfProcessFxCompile(_SerializedFxCompile):
if output.warning_replay:
for w in output.warning_replay:
- # pyrefly: ignore # no-matching-overload
warnings.warn_explicit(
message=w.message,
category=w.category,

View File

@@ -544,7 +544,6 @@ def amax(
keepdim: bool = False,
) -> torch.Tensor:
if self.dtype == torch.bool:
- # pyrefly: ignore # no-matching-overload
return torch.any(self, dim=dim, keepdim=keepdim)
return NotImplemented
@@ -556,7 +555,6 @@ def amin(
keepdim: bool = False,
) -> torch.Tensor:
if self.dtype == torch.bool:
- # pyrefly: ignore # no-matching-overload
return torch.all(self, dim=dim, keepdim=keepdim)
return NotImplemented

View File

@@ -238,7 +238,7 @@ class FakeTensorUpdater:
symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
):
# Refresh the bindings to the new symbols
- # pyrefly: ignore # unbound-name
node.meta["unbacked_bindings"] = symbol_to_path
existing_storages[get_node_storage(node)] += 1

View File

@@ -6500,12 +6500,10 @@ def div_prim(a, b):
# see https://github.com/pytorch/pytorch/issues/157959
if (divisor := get_constant_value(b)) is not None and a.get_device().type != "cpu":
# Replace divide by constant with multiply by reciprocal
- # pyrefly: ignore # unbound-name
if divisor.value == 0:
- # pyrefly: ignore # unbound-name
reciprocal = math.copysign(float("inf"), divisor.value)
else:
- # pyrefly: ignore # unbound-name
reciprocal = 1.0 / divisor.value
return mul(a, reciprocal)

View File

@@ -131,7 +131,6 @@ def load_package(
)
return AOTICompiledModel(loader)
- # pyrefly: ignore # no-matching-overload
path = os.fspath(path) # AOTIModelPackageLoader expects (str, str)
loader = torch._C._aoti.AOTIModelPackageLoader(
path, model_name, run_single_threaded, num_runners, device_index

View File

@@ -2676,7 +2676,6 @@ class Scheduler:
and (dep := next(iter(node.read_writes.writes)))
and isinstance(dep, MemoryDep)
):
- # pyrefly: ignore # unbound-name
node_mode = dep.mode
else:
node_mode = None
@@ -4360,7 +4359,6 @@ class Scheduler:
if config.expand_dimension_for_pointwise_nodes and (
expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)
):
- # pyrefly: ignore # unbound-name
(expand_dim, smaller_node, expand_size) = expand_analysis
smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size)
shared_data_score = self.score_fusion_memory(node1, node2)
@@ -4669,7 +4667,6 @@ class Scheduler:
device.type == "cuda"
and (device_props := torch.cuda.get_device_properties(device)).major < 7
):
- # pyrefly: ignore # unbound-name
raise GPUTooOldForTriton(device_props, inspect.currentframe())
elif is_gpu(device.type) and not device.type == "mps":
raise TritonMissing(inspect.currentframe())
@@ -4967,7 +4964,6 @@ class Scheduler:
if isinstance(buf.node, ir.MutationOutput) and (
real_name := self.mutation_real_name.get(buf_name, None)
):
- # pyrefly: ignore # unbound-name
return is_none_layout(real_name)
return True

View File

@@ -3681,8 +3681,8 @@ class AlgorithmSelectorCache(PersistentCache):
),
node.get_device(),
node.get_dtype(),
- # pyrefly: ignore # missing-attribute
V.graph.sizevars.atomically_apply_size_hint(
+ # pyrefly: ignore # missing-attribute
node.layout.offset,
fallback=config.unbacked_symint_fallback,
hint_override=hint_override,

View File

@@ -1652,7 +1652,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
)
# Build options dict
- # pyrefly: ignore # no-matching-overload
options_dict = dict(
EVEN_K=even_k_symbolic,
USE_FAST_ACCUM=False, # Option for _scaled_mm

View File

@@ -3764,7 +3764,6 @@ def maybe_log_cudagraph_partition(
and (fx_node := ir_node.get_origin_node())
and (stack_trace := fx_node.meta.get("stack_trace", None))
):
- # pyrefly: ignore # unbound-name
warning_msg = f"{warning_msg}. Found from : \n {stack_trace}"
perf_hint_log.warning(warning_msg)

View File

@@ -144,7 +144,6 @@ def benchmark_all_kernels(
launcher = triton_kernel.launchers[0]
print(
get_info_str(
- # pyrefly: ignore # bad-argument-type
ms,
launcher.n_regs,
launcher.n_spills,

View File

@@ -246,7 +246,6 @@ def save_op_profiles(op_profiles: dict[str, set[OpProfile]], f: FileLike) -> None:
yaml_str = generate_yaml_from_profiles(op_profiles)
if isinstance(f, (str, os.PathLike)):
- # pyrefly: ignore # no-matching-overload
f = os.fspath(f)
with open(f, "w") as file:
@@ -312,7 +311,6 @@ def load_op_profiles(f: FileLike) -> dict[str, set[OpProfile]]:
Loads the saved operator profiles from `save_op_profiles`.
"""
if isinstance(f, (str, os.PathLike)):
- # pyrefly: ignore # no-matching-overload
f = os.fspath(f)
with open(f) as file:

View File

@@ -173,7 +173,7 @@ def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
)
_OPAQUE_TYPES[cls] = name
- # pyrefly: ignore # missing-attribute
torch._C._register_opaque_type(name)
@@ -183,5 +183,5 @@ def is_opaque_type(cls: Any) -> bool:
"""
if cls not in _OPAQUE_TYPES:
return False
- # pyrefly: ignore # missing-attribute
return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])

View File

@@ -914,7 +914,6 @@ class TorchLogsFormatter(logging.Formatter):
and (trace_id := torch._guards.CompileContext.current_trace_id())
is not None
):
- # pyrefly: ignore # unbound-name
record.traceid = f" [{trace_id}]"
glog_level_to_abbr = {

View File

@@ -1336,9 +1336,9 @@ def float_power(
# Float power has the following contiguous cast behavior to be
# consistent with its C++ impl
- # pyrefly: ignore # no-matching-overload
a = _maybe_convert_to_dtype(a, dtype)
- # pyrefly: ignore # no-matching-overload
b = _maybe_convert_to_dtype(b, dtype)
a, b = _maybe_broadcast(a, b)
@@ -2348,7 +2348,6 @@ def all(
dim: Optional[DimsType] = None,
keepdim: bool = False,
) -> TensorLikeType:
- # pyrefly: ignore # no-matching-overload
result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim))
if a.dtype == torch.uint8:
@@ -3245,7 +3244,7 @@ def _normalize(
mean (Tensor): mean of the tensor along norm_dims.
rstd (Tensor): 1/std of the tensor along norm_dims.
"""
- # pyrefly: ignore # no-matching-overload
norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
computation_dtype = utils.get_computation_dtype(a.dtype)
a_acc = _maybe_convert_to_dtype(a, computation_dtype)
@@ -3975,7 +3974,7 @@ def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
@out_wrapper()
def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType:
"""Reference implementation of :func:`torch.roll`."""
- # pyrefly: ignore # no-matching-overload
dims = utils.canonicalize_dims(a.ndim, dims)
# ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
if not isinstance(shifts, Iterable):
@@ -4286,7 +4285,7 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType
return prims.squeeze(a, dims) if dims else prims.view_of(a)
ndim = a.ndim
- # pyrefly: ignore # no-matching-overload
dim = utils.canonicalize_dims(ndim, dim)
dims = (dim,) if isinstance(dim, Dim) else dim
# Short-circuits if the tensor has no dimensions

View File

@@ -216,7 +216,7 @@ def matrix_norm(
# shape
check_is_matrix(A, "linalg.matrix_norm")
# dim
- # pyrefly: ignore # no-matching-overload
dim = utils.canonicalize_dims(A.ndim, dim)
if isinstance(dim, Dim):
dim = (dim,) # type: ignore[assignment]

View File

@@ -2620,7 +2620,7 @@ class FakeTensorMode(TorchDispatchMode):
and s.rhs == 1
):
assert self.shape_env is not None
- # pyrefly: ignore # unbound-name
self.shape_env.set_unbacked_var_to_val(s, int(real_t))
if real_out is not nil:

View File

@@ -1110,7 +1110,6 @@ class Tensor(torch._C.TensorBase):
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
- # pyrefly: ignore # no-matching-overload
return _C._VariableFunctions.rsub(self, other)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
@@ -1137,7 +1136,7 @@ class Tensor(torch._C.TensorBase):
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
- return torch.remainder(other, self) # pyrefly: ignore # no-matching-overload
+ return torch.remainder(other, self)
def __format__(self, format_spec):
if has_torch_function_unary(self):
@@ -1150,7 +1149,7 @@ class Tensor(torch._C.TensorBase):
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
- return torch.pow(other, self) # pyrefly: ignore # no-matching-overload
+ return torch.pow(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
@@ -1166,14 +1165,12 @@ class Tensor(torch._C.TensorBase):
def __rlshift__(
self, other: Union["Tensor", int, float, bool, complex]
) -> "Tensor":
- # pyrefly: ignore # no-matching-overload
return torch.bitwise_left_shift(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented
def __rrshift__(
self, other: Union["Tensor", int, float, bool, complex]
) -> "Tensor":
- # pyrefly: ignore # no-matching-overload
return torch.bitwise_right_shift(other, self)
@_handle_torch_function_and_wrap_type_error_to_not_implemented

View File

@@ -744,10 +744,7 @@ class ExceptionWrapper:
if exc_info is None:
exc_info = sys.exc_info()
self.exc_type = exc_info[0]
- self.exc_msg = "".join(
- # pyrefly: ignore # no-matching-overload
- traceback.format_exception(*exc_info)
- )
+ self.exc_msg = "".join(traceback.format_exception(*exc_info))
self.where = where
def reraise(self):

View File

@@ -89,7 +89,6 @@ def compile_time_strobelight_meta(
skip := kwargs["skip"],
int,
):
- # pyrefly: ignore # unbound-name
kwargs["skip"] = skip + 1
# This is not needed but we have it here to avoid having profile_compile_time

View File

@@ -95,7 +95,7 @@ class Conv1d(_ConvNd, nn.Conv1d):
and the backend should be able to fuse the ops with `*` into a quantized conv1d
"""
weight_quant_dequant = self.get_weight()
- # pyrefly: ignore # no-matching-overload
result = F.conv1d(
x,
weight_quant_dequant,
@@ -160,7 +160,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
and the backend should be able to fuse the ops with `*` into a quantized conv2d
"""
weight_quant_dequant = self.get_weight()
- # pyrefly: ignore # no-matching-overload
result = F.conv2d(
x,
weight_quant_dequant,
@@ -225,7 +225,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
and the backend should be able to fuse the ops with `*` into a quantized conv3d
"""
weight_quant_dequant = self.get_weight()
- # pyrefly: ignore # no-matching-overload
result = F.conv3d(
x,
weight_quant_dequant,

View File

@@ -1095,6 +1095,7 @@ def create_a_shadows_b(
# pyrefly: ignore # unbound-name
if not isinstance(input_logger, list):
raise AssertionError(
+ # pyrefly: ignore # unbound-name
f"Expected list, got {type(input_logger)}"
)
# pyrefly: ignore # unbound-name

View File

@@ -127,7 +127,6 @@ class AdaptiveRoundingOptimizer:
@torch.no_grad()
def feed_forward(self, x, weight, module):
if isinstance(module, torch.nn.Conv1d):
- # pyrefly: ignore # no-matching-overload
out = torch.nn.functional.conv1d(
x,
weight,

View File

@@ -90,7 +90,6 @@ def _find_q_dq_node_for_user(
and arg.op == "call_function"
and arg.target in _QUANTIZE_OPS
):
- # pyrefly: ignore # unbound-name
q_node = arg
return (q_node, dq_node)

View File

@@ -92,7 +92,7 @@ def _make_grads(
is_grads_batched: bool,
) -> tuple[_OptionalTensor, ...]:
new_grads: list[_OptionalTensor] = []
- # pyrefly: ignore # no-matching-overload
for out, grad in zip(outputs, grads):
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
out_size = None

View File

@@ -155,25 +155,21 @@ class cuBLASModule:
if name == "allow_tf32":
return torch._C._get_cublas_allow_tf32()
elif name == "allow_fp16_reduced_precision_reduction":
- # pyrefly: ignore # not-iterable
allow_reduced_precision, _ = (
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
)
return allow_reduced_precision
elif name == "allow_fp16_reduced_precision_reduction_split_k":
- # pyrefly: ignore # not-iterable
_, allow_splitk = (
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
)
return allow_splitk
elif name == "allow_bf16_reduced_precision_reduction":
- # pyrefly: ignore # not-iterable
allow_reduced_precision, _ = (
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
)
return allow_reduced_precision
elif name == "allow_bf16_reduced_precision_reduction_split_k":
- # pyrefly: ignore # not-iterable
_, allow_splitk = (
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
)
@@ -193,7 +189,6 @@ class cuBLASModule:
)
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
allow_reduced_precision,
- # pyrefly: ignore # bad-argument-count
allow_splitk,
)
elif name == "allow_bf16_reduced_precision_reduction":
@@ -202,7 +197,6 @@ class cuBLASModule:
)
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
allow_reduced_precision,
- # pyrefly: ignore # bad-argument-count
allow_splitk,
)
elif name == "allow_fp16_accumulation":

View File

@@ -10,6 +10,7 @@ if hasattr(torch._C, "_CUDAGreenContext"):
# Python shim helps Sphinx process docstrings more reliably.
+ # pyrefly: ignore # invalid-inheritance
class GreenContext(_GreenContext):
r"""Wrapper around a CUDA green context.

View File

@@ -37,7 +37,6 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
checkpoint_id = kwargs.get("checkpoint_id")
if not checkpoint_id and (serializer := storage_writer or storage_reader):
- # pyrefly: ignore # unbound-name
checkpoint_id = getattr(serializer, "checkpoint_id", None)
msg_dict["checkpoint_id"] = (

View File

@@ -1227,7 +1227,6 @@ def _unflatten_model_state_dict(
if not state_dict:
return {}
- # pyrefly: ignore # no-matching-overload
if isinstance(next(iter(state_dict.keys())), nn.Module):
warnings.warn(
"Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``"

View File

@@ -393,7 +393,7 @@ class BackendConfig:
# e.g. "nccl", "gloo", "ucc", "mpi"
supported_devices = Backend.backend_capability[backend.lower()]
backend_val = Backend(backend)
- # pyrefly: ignore # bad-assignment
self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
elif ":" in backend.lower():
# Backend specified in "device:backend" format

View File

@@ -290,7 +290,6 @@ def _shard_dict_of_args(
f"Unsupported chunk spec: {spec} and value: {v} combination."
)
- # pyrefly: ignore # no-matching-overload
for _flat_split_result, _v_split in zip(
flat_split_results, v_splits, strict=True
):

View File

@@ -1327,7 +1327,6 @@ class _ContextParallel(ParallelStyle):
placement = [Shard(self.seq_dim)]
all_args = []
- # pyrefly: ignore # bad-assignment, bad-argument-type
for arg in itertools.chain(args, kwargs.values()):
if isinstance(arg, torch.Tensor):
if isinstance(arg, DTensor):

View File

@@ -548,7 +548,7 @@ class PrepareModuleInput(ParallelStyle):
assert self.desired_input_layouts is not None, (
"desired module inputs should not be None!"
)
- # pyrefly: ignore # no-matching-overload
for inp, input_layout, desired_layout in zip(
inputs, self.input_layouts, self.desired_input_layouts
):
@@ -664,7 +664,7 @@ class PrepareModuleOutput(ParallelStyle):
raise ValueError(
"module outputs and output_layouts should have same length!"
)
- # pyrefly: ignore # no-matching-overload
for out, out_layout, desired_out_layout in zip(
outputs, self.output_layouts, self.desired_output_layouts
):

View File

@@ -59,7 +59,7 @@ def _cast_forward_inputs(
def cast_fn(x: torch.Tensor) -> torch.Tensor:
if not torch.is_floating_point(x) or x.dtype == dtype:
return x
- # pyrefly: ignore # no-matching-overload
return x.to(dtype)
return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))

View File

@@ -436,7 +436,6 @@ def load(
print(ep(torch.randn(5)))
"""
if isinstance(f, (str, os.PathLike)):
- # pyrefly: ignore # no-matching-overload
f = os.fspath(f)
extra_files = extra_files or {}

View File

@@ -514,7 +514,6 @@ def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None:
simplify=True,
)
):
- # pyrefly: ignore # unbound-name
node.meta["unbacked_bindings"] = unbacked_bindings

View File

@@ -683,7 +683,6 @@ def package_pt2(
if not (
(isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable())
- # pyrefly: ignore # no-matching-overload
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2"))
):
@@ -695,7 +694,6 @@ def package_pt2(
)
if isinstance(f, (str, os.PathLike)):
- # pyrefly: ignore # no-matching-overload
f = os.fspath(f)
# pyrefly: ignore # bad-argument-type
@@ -1086,7 +1084,6 @@ def load_pt2(
if not (
(isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable())
- # pyrefly: ignore # no-matching-overload
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
):
# TODO: turn this into an error in 2.9
@@ -1097,7 +1094,6 @@ def load_pt2(
)
if isinstance(f, (str, os.PathLike)):
- # pyrefly: ignore # no-matching-overload
f = os.fspath(f)
weights = {}
@@ -1167,7 +1163,6 @@ def load_pt2(
else:
aoti_runners = {
model_name: _load_aoti(
- # pyrefly: ignore # bad-argument-type
f,
model_name,
run_single_threaded,

View File

@@ -916,7 +916,6 @@ def fetch_object_proxy(
def fetch_object_proxy(
tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType]
) -> object:
- # pyrefly: ignore # no-matching-overload
return get_proxy_slot(t, tracer, t)
@@ -965,7 +964,6 @@ def _fetch_proxies_and_all_constant_flag(
"""
f_flat_args_kwargs = [
(
- # pyrefly: ignore # no-matching-overload
fetch_object_proxy(tracer, x)
if isinstance(x, (Tensor, _AnyScriptObject))
else x
@@ -2497,7 +2495,6 @@ class _MakefxTracer:
):
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
- # pyrefly: ignore # unbound-name
insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx")
t.recompile()
# TODO: kind of a bad way to do it, should maybe figure out a better way

View File

@@ -621,13 +621,12 @@ def rebind_unbacked(
):
# This is what the pattern match above is testing
repacked = _sympy_cast_symbool_to_symint_guardless(
- # pyrefly: ignore # unbound-name
sympy.Eq(new_raw_u1, 1)
)
assert repacked == raw_u1, f"{repacked} != {raw_u1}"
# Cancel the to_int(to_bool(x)). This is sound because x in
# [0, 1]
- # pyrefly: ignore # unbound-name
raw_u1 = new_raw_u1
if not isinstance(raw_u1, sympy.Symbol):
@@ -1055,7 +1054,6 @@ def find_symbol_binding_fx_nodes(
# NB: Prefer first occurrence of symbol
for node in graph.nodes:
if (s := is_symbol_binding_fx_node(node)) is not None and s not in r:
- # pyrefly: ignore # unbound-name
r[s] = node
return r
@@ -1226,13 +1224,12 @@ def _free_unbacked_symbols_with_path(
and isinstance(s := expr(a), sympy.Symbol)
and s in pending
):
- # pyrefly: ignore # unbound-name
r[s] = path
if shape_env and real is not None:
assert isinstance(real, (int, float))
- # pyrefly: ignore # unbound-name
shape_env.set_unbacked_var_to_val(s, real)
- # pyrefly: ignore # unbound-name
pending.remove(s)
# When an unbacked SymInt is perfectly divisible by an integer
# constant, we replace it with the integer constant to improve
@@ -1262,14 +1259,10 @@ def _free_unbacked_symbols_with_path(
source=shape_env.var_to_sources.get(s, [None])[0], # type: ignore[union-attr]
)
- # pyrefly: ignore # unbound-name
unbacked = lhs if lhs in pending else rhs
divisor: IntLikeType = (
- # pyrefly: ignore # unbound-name
int(coeff)
- # pyrefly: ignore # unbound-name
if shape_env and isinstance(coeff, sympy.Integer)
- # pyrefly: ignore # unbound-name
else _symint_wrap(coeff)
)
# TODO: DivideByKey needs to test divisibility at runtime!
@@ -1278,11 +1271,8 @@ def _free_unbacked_symbols_with_path(
if real is not None:
assert isinstance(real, int)
val = (
- # pyrefly: ignore # unbound-name
real // int(coeff)
- # pyrefly: ignore # unbound-name
if isinstance(coeff, sympy.Integer)
- # pyrefly: ignore # unbound-name
else CleanDiv(real, coeff)
)
if shape_env:
@@ -1299,14 +1289,12 @@ def _free_unbacked_symbols_with_path(
and s.rhs == 1
and s.lhs in pending
):
- # pyrefly: ignore # unsupported-operation
r[s.lhs] = path + (ConvertIntKey(),)
if real is not None:
assert type(real) is bool
if shape_env:
- # pyrefly: ignore # unbound-name
shape_env.set_unbacked_var_to_val(s, int(real))
- # pyrefly: ignore # unbound-name
pending.remove(s.lhs)
return r
@@ -1382,7 +1370,6 @@ def compute_unbacked_bindings(
):
if (
isinstance(old_sym, SymTypes)
- # pyrefly: ignore # unbound-name
and (old_s := old_sym.node.expr) != new_s
):
# If old_s is not an unbacked_symbol,
@@ -1392,15 +1379,12 @@ def compute_unbacked_bindings(
# and the original symbol gets replaced by the backed symbol.
# When this happens we just replace new_s by the old_s
# because we know the value is the same.
- # pyrefly: ignore # unbound-name
if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s):
- # pyrefly: ignore # unbound-name
shape_env._rename_unbacked_to(new_s, old_s)
else:
- # pyrefly: ignore # unbound-name
shape_env._eliminate_unbacked(new_s, old_s)
elif not isinstance(old_sym, SymTypes):
- # pyrefly: ignore # unbound-name
shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym))
return symbol_to_path
@@ -3365,7 +3349,7 @@ class DimConstraints:
and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
): # derived dim with root = old_root
new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1
- # pyrefly: ignore # unbound-name
new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1
c["eq"] = new_expr
@@ -7630,10 +7614,9 @@ class ShapeEnv:
log.info(
"oblivious_size %s -> %s (passed counterfactual)",
orig_expr,
- # pyrefly: ignore # unbound-name
correct_hint,
)
- # pyrefly: ignore # unbound-name
concrete_val = correct_hint
# NB: do NOT transmute into runtime assert
ok = True
@@ -7650,10 +7633,9 @@ class ShapeEnv:
).xreplace(self.var_to_val)
).free_symbols
):
- # pyrefly: ignore # unbound-name
self._log_real_tensor_propagation(orig_expr, unsound_result)
transmute_into_runtime_assert = True
- # pyrefly: ignore # unbound-name
concrete_val = unsound_result
ok = True

View File

@@ -1314,6 +1314,7 @@ class Graph:
f(to_erase)
self._find_nodes_lookup_table.remove(to_erase)
+ # pyrefly: ignore # missing-attribute
to_erase._remove_from_list()
to_erase._erased = True # iterators may retain handles to erased nodes
self._len -= 1

View File

@@ -385,6 +385,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
+ # pyrefly: ignore # missing-attribute
self._prepend(x)
@compatibility(is_backward_compatible=True)
@@ -396,6 +397,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
+ # pyrefly: ignore # missing-attribute
self._next._prepend(x)
@property

View File

@@ -276,7 +276,6 @@ def tensorify_python_scalars(
):
transform = True
try:
- # pyrefly: ignore # unbound-name
proxy = _sympy_interp(zf.node.expr)
except NotImplementedError:
transform = False
@@ -303,7 +302,6 @@ def tensorify_python_scalars(
args.append(a)
if transform:
- # pyrefly: ignore # unbound-name
replacement_proxy = replacement_op(*args)
# pyrefly: ignore # missing-attribute

View File

@@ -93,7 +93,6 @@ class FakeTensorProp(torch.fx.Interpreter):
if (shape_env := self._mode.shape_env) and (
symbol_to_path := compute_unbacked_bindings(shape_env, result)
):
- # pyrefly: ignore # unbound-name
n.meta["unbacked_bindings"] = symbol_to_path
return result

View File

@@ -298,14 +298,12 @@ def insert_deferred_runtime_asserts(
and s not in expr_to_proxy
):
with _set_node_metadata_hook(gm, _node_metadata_hook):
- # pyrefly: ignore # unbound-name
expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
- # pyrefly: ignore # unbound-name
log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
- # pyrefly: ignore # unbound-name
match_symbol(example_value, lambda: node)
- # pyrefly: ignore # unbound-name
if isinstance(t := example_value, torch.Tensor):
for i, s in enumerate(t.size()):
match_symbol(
@@ -386,7 +384,6 @@ def insert_deferred_runtime_asserts(
# maybe re-reify expression, replace current node
if (
- # pyrefly: ignore # unbound-name
sym_expr in expr_to_proxy
or ( # example value is redundant
_is_intermediate_tensor_sym_call(node)
@@ -405,10 +402,8 @@ def insert_deferred_runtime_asserts(
nn_module_stack=node.meta.get("nn_module_stack"),
),
):
- # pyrefly: ignore # unbound-name
expr_to_proxy[sym_expr] = _sympy_interp(
expr_to_proxy,
- # pyrefly: ignore # unbound-name
sym_expr,
) # type: ignore[arg-type]
# won't try DCE-ing tensor compute here
@@ -419,14 +414,12 @@ def insert_deferred_runtime_asserts(
"CSE node %s -> %s for expr %s",
node,
hash_node,
- # pyrefly: ignore # unbound-name
sym_expr,
)
# store node in hash cons, don't delete/replace
- # pyrefly: ignore # unbound-name
elif sym_expr not in expr_to_proxy and not isinstance(
- # pyrefly: ignore # unbound-name
sym_expr,
(sympy.Number, sympy.logic.boolalg.BooleanAtom),
): # don't hash cons primitives

View File

@ -318,7 +318,6 @@ def split_module(
and isinstance(s0 := val.node.expr, sympy.Symbol) and isinstance(s0 := val.node.expr, sympy.Symbol)
and s0 not in symbol_to_node and s0 not in symbol_to_node
): ):
# pyrefly: ignore # unbound-name
symbol_to_node[val.node.expr] = node symbol_to_node[val.node.expr] = node
if node.op in ["placeholder", "get_attr", "output"]: if node.op in ["placeholder", "get_attr", "output"]:

View File

@ -85,7 +85,6 @@ def get_source_partitions(
if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and ( if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and (
torch_fn := node.meta.get("torch_fn", None) torch_fn := node.meta.get("torch_fn", None)
) is not None: ) is not None:
# pyrefly: ignore # unbound-name
node_fqn, source_fn = torch_fn node_fqn, source_fn = torch_fn
source_fn_name = source_fn.split(".")[1] source_fn_name = source_fn.split(".")[1]
if source_fn_name in wanted_sources: if source_fn_name in wanted_sources:

View File

@ -421,7 +421,7 @@ def set_dir(d: Union[str, os.PathLike]) -> None:
d (str): path to a local folder to save downloaded models & weights. d (str): path to a local folder to save downloaded models & weights.
""" """
global _hub_dir global _hub_dir
_hub_dir = os.path.expanduser(d) # pyrefly: ignore # no-matching-overload _hub_dir = os.path.expanduser(d)
def list( def list(
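A usage note for the set_dir change above (a sketch; the cache path is arbitrary): the argument may contain "~", which is expanded with os.path.expanduser.

import torch

torch.hub.set_dir("~/.cache/example_torch_hub")  # "~" is expanded to the home directory
print(torch.hub.get_dir())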

View File

@ -167,7 +167,6 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
if isinstance(f, (str, os.PathLike)): if isinstance(f, (str, os.PathLike)):
cpp_module = torch._C.import_ir_module( cpp_module = torch._C.import_ir_module(
cu, cu,
# pyrefly: ignore # no-matching-overload
os.fspath(f), os.fspath(f),
map_location, map_location,
_extra_files, _extra_files,
@ -208,7 +207,6 @@ def validate_map_location(map_location=None):
def jit_module_from_flatbuffer(f): def jit_module_from_flatbuffer(f):
if isinstance(f, (str, os.PathLike)): if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f) f = os.fspath(f)
return wrap_cpp_module(torch._C._load_jit_module_from_file(f)) return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
else: else:
@ -258,7 +256,6 @@ def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
extra_files = {} extra_files = {}
if isinstance(f, (str, os.PathLike)): if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f) f = os.fspath(f)
torch._C._save_jit_module(m._c, f, extra_files) torch._C._save_jit_module(m._c, f, extra_files)
else: else:
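The no-matching-overload suppressions removed in this file (and in the mobile and serialization files below) all guard the same idiom: calling os.fspath after an isinstance check has narrowed the argument to a path type. A generic sketch of that idiom, independent of the torch._C bindings:

import os
from typing import IO, Union

def as_path_or_buffer(f: Union[str, os.PathLike, IO[bytes]]) -> Union[str, bytes, IO[bytes]]:
    if isinstance(f, (str, os.PathLike)):
        # Narrowed to str | os.PathLike here, which matches os.fspath's overloads.
        return os.fspath(f)
    return f

print(as_path_or_buffer("model.pt"))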

View File

@ -44,7 +44,6 @@ def _load_for_lite_interpreter(f, map_location=None):
map_location = validate_map_location(map_location) map_location = validate_map_location(map_location)
if isinstance(f, (str, os.PathLike)): if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location) cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
else: else:
cpp_module = torch._C._load_for_lite_interpreter_from_buffer( cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
@ -106,7 +105,6 @@ def _get_model_bytecode_version(f_input) -> int:
raise ValueError(f"The provided filename {f_input} is a directory") raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)): if isinstance(f_input, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
return torch._C._get_model_bytecode_version(os.fspath(f_input)) return torch._C._get_model_bytecode_version(os.fspath(f_input))
else: else:
# pyrefly: ignore # missing-attribute # pyrefly: ignore # missing-attribute
@ -140,7 +138,6 @@ def _get_mobile_model_contained_types(f_input) -> int:
raise ValueError(f"The provided filename {f_input} is a directory") raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)): if isinstance(f_input, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
return torch._C._get_mobile_model_contained_types(os.fspath(f_input)) return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
else: else:
# pyrefly: ignore # missing-attribute # pyrefly: ignore # missing-attribute
@ -168,9 +165,7 @@ def _backport_for_mobile(f_input, f_output, to_version):
isinstance(f_output, (str, os.PathLike)) isinstance(f_output, (str, os.PathLike))
): ):
return torch._C._backport_for_mobile( return torch._C._backport_for_mobile(
# pyrefly: ignore # no-matching-overload
os.fspath(f_input), os.fspath(f_input),
# pyrefly: ignore # no-matching-overload
os.fspath(f_output), os.fspath(f_output),
to_version, to_version,
) )
@ -198,7 +193,6 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
raise ValueError(f"The provided filename {f_input} is a directory") raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)): if isinstance(f_input, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version) return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
else: else:
return torch._C._backport_for_mobile_from_buffer_to_buffer( return torch._C._backport_for_mobile_from_buffer_to_buffer(
@ -244,7 +238,6 @@ def _get_model_ops_and_info(f_input):
raise ValueError(f"The provided filename {f_input} is a directory") raise ValueError(f"The provided filename {f_input} is a directory")
if isinstance(f_input, (str, os.PathLike)): if isinstance(f_input, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
return torch._C._get_model_ops_and_info(os.fspath(f_input)) return torch._C._get_model_ops_and_info(os.fspath(f_input))
else: else:
# pyrefly: ignore # missing-attribute # pyrefly: ignore # missing-attribute

View File

@ -644,7 +644,7 @@ def impl(
>>> y2 = torch.sin(x) + 1 >>> y2 = torch.sin(x) + 1
>>> assert torch.allclose(y1, y2) >>> assert torch.allclose(y1, y2)
""" """
# pyrefly: ignore # no-matching-overload
return _impl(qualname, types, func, lib=lib, disable_dynamo=False) return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
@ -831,7 +831,6 @@ def register_kernel(
if device_types is None: if device_types is None:
device_types = "CompositeExplicitAutograd" device_types = "CompositeExplicitAutograd"
# pyrefly: ignore # no-matching-overload
return _impl(op, device_types, func, lib=lib, disable_dynamo=True) return _impl(op, device_types, func, lib=lib, disable_dynamo=True)

View File

@ -642,7 +642,6 @@ def _sparse_coo_scatter_reduction_helper(
# promote dtype if specified # promote dtype if specified
if values.dtype != output_dtype: if values.dtype != output_dtype:
# pyrefly: ignore # no-matching-overload
values = values.to(output_dtype) values = values.to(output_dtype)
if keepdim: if keepdim:
@ -767,7 +766,6 @@ def _sparse_csr_segment_reduction_helper(
# promote dtype if specified # promote dtype if specified
if values.dtype != output_dtype: if values.dtype != output_dtype:
# pyrefly: ignore # no-matching-overload
values = values.to(output_dtype) values = values.to(output_dtype)
if len(dims) == 0: if len(dims) == 0:

View File

@ -473,7 +473,6 @@ class ModuleList(Module):
return self return self
def pop(self, key: Union[int, slice]) -> Module: def pop(self, key: Union[int, slice]) -> Module:
# pyrefly: ignore # index-error
v = self[key] v = self[key]
del self[key] del self[key]
return v return v
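A small usage example for the pop method shown above (a sketch with an arbitrary three-layer list):

import torch.nn as nn

layers = nn.ModuleList([nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2)])
relu = layers.pop(1)  # removes and returns the ReLU module
print(len(layers), type(relu).__name__)  # 2 ReLU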

View File

@ -363,7 +363,7 @@ class Conv1d(_ConvNd):
self.dilation, self.dilation,
self.groups, self.groups,
) )
# pyrefly: ignore # no-matching-overload
return F.conv1d( return F.conv1d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups input, weight, bias, self.stride, self.padding, self.dilation, self.groups
) )
@ -541,7 +541,7 @@ class Conv2d(_ConvNd):
self.dilation, self.dilation,
self.groups, self.groups,
) )
# pyrefly: ignore # no-matching-overload
return F.conv2d( return F.conv2d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups input, weight, bias, self.stride, self.padding, self.dilation, self.groups
) )
@ -711,7 +711,7 @@ class Conv3d(_ConvNd):
self.dilation, self.dilation,
self.groups, self.groups,
) )
# pyrefly: ignore # no-matching-overload
return F.conv3d( return F.conv3d(
input, weight, bias, self.stride, self.padding, self.dilation, self.groups input, weight, bias, self.stride, self.padding, self.dilation, self.groups
) )
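A short sketch exercising Conv2d.forward (shapes arbitrary); with the default padding_mode="zeros" it goes straight to the F.conv2d call shown in the hunk above:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)  # default padding_mode="zeros"
x = torch.randn(1, 3, 16, 16)
y = conv(x)  # dispatches to F.conv2d(input, weight, bias, stride, padding, dilation, groups)
print(y.shape)  # torch.Size([1, 8, 16, 16])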

View File

@ -364,7 +364,6 @@ def parse_args(
fn_name = None fn_name = None
args = [ args = [
_parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign] _parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign]
# pyrefly: ignore # no-matching-overload
for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names) for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)
] ]
# only support _outputs in kwargs # only support _outputs in kwargs

View File

@ -453,7 +453,7 @@ def _single_tensor_adam(
device_beta1 = beta1 device_beta1 = beta1
# Decay the first and second moment running average coefficient # Decay the first and second moment running average coefficient
# pyrefly: ignore # no-matching-overload
exp_avg.lerp_(grad, 1 - device_beta1) exp_avg.lerp_(grad, 1 - device_beta1)
# Nested if is necessary to bypass jitscript rules # Nested if is necessary to bypass jitscript rules
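The lerp_ call above implements Adam's first-moment update, exp_avg = beta1 * exp_avg + (1 - beta1) * grad. A tiny numeric check (values arbitrary):

import torch

beta1 = 0.9
exp_avg = torch.zeros(3)
grad = torch.tensor([1.0, 2.0, 3.0])
exp_avg.lerp_(grad, 1 - beta1)  # exp_avg += (1 - beta1) * (grad - exp_avg)
print(exp_avg)  # tensor([0.1000, 0.2000, 0.3000])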

View File

@ -398,7 +398,6 @@ class Optimizer:
self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict) self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
self.param_groups: list[dict[str, Any]] = [] self.param_groups: list[dict[str, Any]] = []
# pyrefly: ignore # no-matching-overload
param_groups = list(params) param_groups = list(params)
if len(param_groups) == 0: if len(param_groups) == 0:
raise ValueError("optimizer got an empty parameter list") raise ValueError("optimizer got an empty parameter list")

View File

@ -219,7 +219,7 @@ class PackageExporter:
torch._C._log_api_usage_once("torch.package.PackageExporter") torch._C._log_api_usage_once("torch.package.PackageExporter")
self.debug = debug self.debug = debug
if isinstance(f, (str, os.PathLike)): if isinstance(f, (str, os.PathLike)):
f = os.fspath(f) # pyrefly: ignore # no-matching-overload f = os.fspath(f)
self.buffer: Optional[IO[bytes]] = None self.buffer: Optional[IO[bytes]] = None
else: # is a byte buffer else: # is a byte buffer
self.buffer = f self.buffer = f

View File

@ -108,7 +108,6 @@ class PackageImporter(Importer):
self.filename = "<pytorch_file_reader>" self.filename = "<pytorch_file_reader>"
self.zip_reader = file_or_buffer self.zip_reader = file_or_buffer
elif isinstance(file_or_buffer, (os.PathLike, str)): elif isinstance(file_or_buffer, (os.PathLike, str)):
# pyrefly: ignore # no-matching-overload
self.filename = os.fspath(file_or_buffer) self.filename = os.fspath(file_or_buffer)
if not os.path.isdir(self.filename): if not os.path.isdir(self.filename):
self.zip_reader = torch._C.PyTorchFileReader(self.filename) self.zip_reader = torch._C.PyTorchFileReader(self.filename)

View File

@ -774,10 +774,7 @@ def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]): class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None: def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None:
super().__init__( super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
# pyrefly: ignore # no-matching-overload
torch._C.PyTorchFileReader(name_or_buffer)
)
class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]): class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
@ -970,7 +967,7 @@ def save(
_check_save_filelike(f) _check_save_filelike(f)
if isinstance(f, (str, os.PathLike)): if isinstance(f, (str, os.PathLike)):
f = os.fspath(f) # pyrefly: ignore # no-matching-overload f = os.fspath(f)
if _use_new_zipfile_serialization: if _use_new_zipfile_serialization:
with _open_zipfile_writer(f) as opened_zipfile: with _open_zipfile_writer(f) as opened_zipfile:
@ -1524,7 +1521,6 @@ def load(
else: else:
shared = False shared = False
overall_storage = torch.UntypedStorage.from_file( overall_storage = torch.UntypedStorage.from_file(
# pyrefly: ignore # no-matching-overload
os.fspath(f), os.fspath(f),
shared, shared,
size, size,
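The serialization hunks above cover the os.fspath conversion in torch.save and the UntypedStorage.from_file path used when loading with mmap=True. A minimal round-trip sketch (temporary path chosen arbitrarily):

import pathlib
import tempfile
import torch

path = pathlib.Path(tempfile.mkdtemp()) / "tensor.pt"
torch.save(torch.arange(4), path)      # PathLike input is converted via os.fspath
loaded = torch.load(path, mmap=True)   # backed by a memory-mapped UntypedStorage
print(loaded)                          # tensor([0, 1, 2, 3])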

View File

@ -701,7 +701,6 @@ def tree_map_only(
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -762,7 +761,6 @@ def tree_map_only_(
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
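A short usage sketch for tree_map_only from the private pytree module shown above (the nested container is arbitrary); only leaves matching the given type are transformed:

import torch
from torch.utils._pytree import tree_map_only

data = {"a": torch.ones(2), "b": [3, torch.zeros(2)]}
doubled = tree_map_only(torch.Tensor, lambda t: t * 2, data)
print(doubled)  # {'a': tensor([2., 2.]), 'b': [3, tensor([0., 0.])]}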

View File

@ -1555,7 +1555,6 @@ def tree_map_only(
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
@ -1616,7 +1615,6 @@ def tree_map_only_(
tree: PyTree, tree: PyTree,
is_leaf: Optional[Callable[[PyTree], bool]] = None, is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree: ) -> PyTree:
# pyrefly: ignore # no-matching-overload
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf) return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)

View File

@ -1531,7 +1531,7 @@ def include_paths(device_type: str = "cpu", torch_include_dirs=True) -> list[str
# Support CUDA_INC_PATH env variable supported by CMake files # Support CUDA_INC_PATH env variable supported by CMake files
if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \ if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
cuda_inc_path != '/usr/include': cuda_inc_path != '/usr/include':
# pyrefly: ignore # unbound-name
paths.append(cuda_inc_path) paths.append(cuda_inc_path)
if CUDNN_HOME is not None: if CUDNN_HOME is not None:
paths.append(os.path.join(CUDNN_HOME, 'include')) paths.append(os.path.join(CUDNN_HOME, 'include'))

View File

@ -678,7 +678,6 @@ class _BaseDataLoaderIter:
# Set pin memory device based on the current accelerator. # Set pin memory device based on the current accelerator.
self._pin_memory_device = ( self._pin_memory_device = (
# pyrefly: ignore # unbound-name
acc.type acc.type
if self._pin_memory if self._pin_memory
and (acc := torch.accelerator.current_accelerator()) is not None and (acc := torch.accelerator.current_accelerator()) is not None
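For the _BaseDataLoaderIter change above: pin-memory only takes effect when an accelerator is present, and the loader derives the device type from torch.accelerator.current_accelerator(). A small sketch (dataset contents arbitrary):

import torch
from torch.utils.data import DataLoader, TensorDataset

ds = TensorDataset(torch.randn(8, 4), torch.arange(8))
loader = DataLoader(
    ds,
    batch_size=4,
    pin_memory=torch.accelerator.current_accelerator() is not None,
)
for xb, yb in loader:
    pass  # with pin_memory enabled, host batches are allocated in pinned memory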

View File

@ -251,7 +251,6 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]:
# Only keep attributes that are safe for dictionary serialization. # Only keep attributes that are safe for dictionary serialization.
serializable_types = (int, float, bool, str, type(None), list, tuple, dict) serializable_types = (int, float, bool, str, type(None), list, tuple, dict)
return { return {
# pyrefly: ignore # unbound-name
key: value key: value
for key in dir(props) for key in dir(props)
if not key.startswith("__") if not key.startswith("__")