mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Clean up unused Pyrefly suppressions (#166178)
Cleaning up ignores that are no longer needed in the repo and adding select suppressions so the main branch is clean. test plan: `lintrunner -a` Pull Request resolved: https://github.com/pytorch/pytorch/pull/166178 Approved by: https://github.com/oulgen
This commit is contained in:
parent
7924e3aacf
commit
eb83c3ca23
|
|
@ -2103,11 +2103,10 @@ def export(
|
||||||
)
|
)
|
||||||
and not trace_rules.check(call_to_inspect)
|
and not trace_rules.check(call_to_inspect)
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
dim_constraints.solve()
|
dim_constraints.solve()
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
forced_specializations = dim_constraints.forced_specializations()
|
forced_specializations = dim_constraints.forced_specializations()
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
msg = dim_constraints.prettify_results(
|
msg = dim_constraints.prettify_results(
|
||||||
original_signature,
|
original_signature,
|
||||||
dynamic_shapes,
|
dynamic_shapes,
|
||||||
|
|
@ -2128,11 +2127,10 @@ def export(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Error if we have any constraints on static values
|
# Error if we have any constraints on static values
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
for k in shape_env.var_to_range.keys():
|
for k in shape_env.var_to_range.keys():
|
||||||
if isinstance(k, sympy.Integer):
|
if isinstance(k, sympy.Integer):
|
||||||
constraint_violation_error = ConstraintViolationError(
|
constraint_violation_error = ConstraintViolationError(
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
||||||
"It appears that you're trying to set a constraint on a "
|
"It appears that you're trying to set a constraint on a "
|
||||||
f"value which we evaluated to have a static value of {k}. "
|
f"value which we evaluated to have a static value of {k}. "
|
||||||
|
|
|
||||||
|
|
@ -408,11 +408,10 @@ def _suggest_or_raise_constraint_violation(
|
||||||
torch._ops.OpOverloadPacket | torch._ops.OpOverload,
|
torch._ops.OpOverloadPacket | torch._ops.OpOverload,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
dim_constraints.solve()
|
dim_constraints.solve()
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
forced_specializations = dim_constraints.forced_specializations()
|
forced_specializations = dim_constraints.forced_specializations()
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
msg = dim_constraints.prettify_results(
|
msg = dim_constraints.prettify_results(
|
||||||
inspect.signature(orig_callable), # type: ignore[attr-defined]
|
inspect.signature(orig_callable), # type: ignore[attr-defined]
|
||||||
dynamic_shapes,
|
dynamic_shapes,
|
||||||
|
|
@ -433,11 +432,10 @@ def _suggest_or_raise_constraint_violation(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Error if we have any constraints on static values
|
# Error if we have any constraints on static values
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
for k in shape_env.var_to_range.keys():
|
for k in shape_env.var_to_range.keys():
|
||||||
if isinstance(k, sympy.Integer):
|
if isinstance(k, sympy.Integer):
|
||||||
constraint_violation_error = ConstraintViolationError(
|
constraint_violation_error = ConstraintViolationError(
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
f"{''.join(traceback.format_list(shape_env.var_to_stack[k]))}\n"
|
||||||
"It appears that you're trying to set a constraint on a "
|
"It appears that you're trying to set a constraint on a "
|
||||||
f"value which we evaluated to have a static value of {k}. "
|
f"value which we evaluated to have a static value of {k}. "
|
||||||
|
|
|
||||||
|
|
@ -456,8 +456,10 @@ def _add_mutation_dependencies(
|
||||||
for user in mutated_arg.users:
|
for user in mutated_arg.users:
|
||||||
if user is node:
|
if user is node:
|
||||||
continue
|
continue
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
elif user < node:
|
elif user < node:
|
||||||
node_to_additional_deps[node].add(user)
|
node_to_additional_deps[node].add(user)
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
elif user > node:
|
elif user > node:
|
||||||
node_to_additional_deps[user].add(node)
|
node_to_additional_deps[user].add(node)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4100,13 +4100,12 @@ class CheckFunctionManager:
|
||||||
and (cache_entry := self.guard_manager.cache_entry) is not None
|
and (cache_entry := self.guard_manager.cache_entry) is not None
|
||||||
and (extra_state := self.guard_manager.extra_state) is not None
|
and (extra_state := self.guard_manager.extra_state) is not None
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
assert isinstance(cache_entry, CacheEntry)
|
assert isinstance(cache_entry, CacheEntry)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
assert isinstance(extra_state, ExtraState)
|
assert isinstance(extra_state, ExtraState)
|
||||||
reason = f"Cache line invalidated because {obj_str} got deallocated"
|
reason = f"Cache line invalidated because {obj_str} got deallocated"
|
||||||
deleted_guard_manager = DeletedGuardManagerWrapper(reason)
|
deleted_guard_manager = DeletedGuardManagerWrapper(reason)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
extra_state.invalidate(cache_entry, deleted_guard_manager)
|
extra_state.invalidate(cache_entry, deleted_guard_manager)
|
||||||
self.guard_manager = deleted_guard_manager
|
self.guard_manager = deleted_guard_manager
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2048,9 +2048,8 @@ class OutputGraph(OutputGraphCommon):
|
||||||
tx = self.root_tx
|
tx = self.root_tx
|
||||||
assert tx is not None
|
assert tx is not None
|
||||||
if (ds := tx.distributed_state) is not None and ds.all_states is None:
|
if (ds := tx.distributed_state) is not None and ds.all_states is None:
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
compile_pg = ds.compile_pg
|
compile_pg = ds.compile_pg
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
log.info("compiler_collective %s", ds.local_state)
|
log.info("compiler_collective %s", ds.local_state)
|
||||||
torch._logging.trace_structured(
|
torch._logging.trace_structured(
|
||||||
"artifact",
|
"artifact",
|
||||||
|
|
@ -2058,7 +2057,6 @@ class OutputGraph(OutputGraphCommon):
|
||||||
"name": "compiler_collective",
|
"name": "compiler_collective",
|
||||||
"encoding": "string",
|
"encoding": "string",
|
||||||
},
|
},
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
payload_fn=lambda: ds.local_state.render(),
|
payload_fn=lambda: ds.local_state.render(),
|
||||||
)
|
)
|
||||||
device_types = compile_pg._device_types
|
device_types = compile_pg._device_types
|
||||||
|
|
@ -2072,9 +2070,9 @@ class OutputGraph(OutputGraphCommon):
|
||||||
dynamo_timed("compiler_collective", log_pt2_compile_event=True),
|
dynamo_timed("compiler_collective", log_pt2_compile_event=True),
|
||||||
):
|
):
|
||||||
all_states: list[Any] = [None] * compile_pg.size()
|
all_states: list[Any] = [None] * compile_pg.size()
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
|
dist.all_gather_object(all_states, ds.local_state, group=compile_pg)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
ds.all_states = all_states
|
ds.all_states = all_states
|
||||||
# Clear speculation log, because are tracing may diverge due to
|
# Clear speculation log, because are tracing may diverge due to
|
||||||
# this information from the compiler collective
|
# this information from the compiler collective
|
||||||
|
|
@ -2468,7 +2466,6 @@ class OutputGraph(OutputGraphCommon):
|
||||||
isinstance(b, torch.SymBool)
|
isinstance(b, torch.SymBool)
|
||||||
and (r := b.node.maybe_as_bool()) is not None
|
and (r := b.node.maybe_as_bool()) is not None
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
return r
|
return r
|
||||||
# TODO: We can also technically remove all cases when the input
|
# TODO: We can also technically remove all cases when the input
|
||||||
# doesn't have unbacked inputs, since it's all in the ShapeEnv
|
# doesn't have unbacked inputs, since it's all in the ShapeEnv
|
||||||
|
|
|
||||||
|
|
@ -876,7 +876,6 @@ def get_code_state() -> defaultdict[CodeId, CodeState]:
|
||||||
not _CODE_STATE
|
not _CODE_STATE
|
||||||
and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
|
and (sticky_read := torch.compiler.config.pgo_extra_read_key) is not None
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
extra_read_key = get_extra_cache_key(sticky_read)
|
extra_read_key = get_extra_cache_key(sticky_read)
|
||||||
if extra_read_key is not None:
|
if extra_read_key is not None:
|
||||||
get_extra_remote_code_state(extra_read_key)
|
get_extra_remote_code_state(extra_read_key)
|
||||||
|
|
|
||||||
|
|
@ -4410,7 +4410,6 @@ class InstructionTranslator(InstructionTranslatorBase):
|
||||||
and isinstance(tos, LocalGeneratorObjectVariable)
|
and isinstance(tos, LocalGeneratorObjectVariable)
|
||||||
):
|
):
|
||||||
self.stack[-1] = ListIteratorVariable(
|
self.stack[-1] = ListIteratorVariable(
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
tos.force_unpack_var_sequence(self),
|
tos.force_unpack_var_sequence(self),
|
||||||
mutation_type=ValueMutationNew(),
|
mutation_type=ValueMutationNew(),
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -4214,7 +4214,6 @@ def _extract_anchors_from_expr(segment: str) -> Optional[_Anchors]:
|
||||||
# (x) + (y)
|
# (x) + (y)
|
||||||
# ~~^~~~~~~
|
# ~~^~~~~~~
|
||||||
while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
|
while (ch := lines[cur_lineno][cur_col]).isspace() or ch in ")\\#":
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
if ch in "\\#":
|
if ch in "\\#":
|
||||||
cur_lineno, cur_col = nextline(cur_lineno, cur_col)
|
cur_lineno, cur_col = nextline(cur_lineno, cur_col)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -317,7 +317,6 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
|
||||||
)
|
)
|
||||||
res_proxy.node.meta.update(meta.data)
|
res_proxy.node.meta.update(meta.data)
|
||||||
if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
|
if self.fake_tensor_mode and (shape_env := self.fake_tensor_mode.shape_env):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
|
if symbol_to_path := compute_unbacked_bindings(shape_env, res_data):
|
||||||
res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
|
res_proxy.node.meta["unbacked_bindings"] = symbol_to_path
|
||||||
self.tracer.set_metadata(res_proxy.node, res_data)
|
self.tracer.set_metadata(res_proxy.node, res_data)
|
||||||
|
|
|
||||||
|
|
@ -2183,7 +2183,6 @@ class GraphModuleDeserializer(metaclass=Final):
|
||||||
simplify=True,
|
simplify=True,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
node.meta["unbacked_bindings"] = unbacked_bindings
|
node.meta["unbacked_bindings"] = unbacked_bindings
|
||||||
|
|
||||||
assert len(self.unbacked_symbols) == 0
|
assert len(self.unbacked_symbols) == 0
|
||||||
|
|
|
||||||
|
|
@ -204,7 +204,6 @@ def run_functionalized_fw_and_collect_metadata(
|
||||||
suppress_pending = contextlib.nullcontext()
|
suppress_pending = contextlib.nullcontext()
|
||||||
fake_mode = detect_fake_mode()
|
fake_mode = detect_fake_mode()
|
||||||
if fake_mode and (shape_env := fake_mode.shape_env):
|
if fake_mode and (shape_env := fake_mode.shape_env):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
|
suppress_pending = shape_env.ignore_fresh_unbacked_symbols()
|
||||||
with disable_above, mode, suppress_pending:
|
with disable_above, mode, suppress_pending:
|
||||||
# precondition: The passed in function already handles unflattening inputs + flattening outputs
|
# precondition: The passed in function already handles unflattening inputs + flattening outputs
|
||||||
|
|
|
||||||
|
|
@ -746,7 +746,6 @@ class WhileLoopAutogradOp(torch.autograd.Function):
|
||||||
and (shape_env := loop_count.node.shape_env)
|
and (shape_env := loop_count.node.shape_env)
|
||||||
and loop_count in shape_env.pending_fresh_unbacked_symbols
|
and loop_count in shape_env.pending_fresh_unbacked_symbols
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
shape_env.pending_fresh_unbacked_symbols.remove(loop_count)
|
shape_env.pending_fresh_unbacked_symbols.remove(loop_count)
|
||||||
|
|
||||||
# Even when body function is not executed, we clone and unsqueeze the input
|
# Even when body function is not executed, we clone and unsqueeze the input
|
||||||
|
|
|
||||||
|
|
@ -132,7 +132,6 @@ def aoti_compile_and_package(
|
||||||
)
|
)
|
||||||
or (
|
or (
|
||||||
isinstance(package_path, (str, os.PathLike))
|
isinstance(package_path, (str, os.PathLike))
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
and os.fspath(package_path).endswith(".pt2")
|
and os.fspath(package_path).endswith(".pt2")
|
||||||
)
|
)
|
||||||
), (
|
), (
|
||||||
|
|
|
||||||
|
|
@ -557,7 +557,6 @@ class GPUDeviceBenchmarkMixin:
|
||||||
res = benchmarker.benchmark_gpu(fn)
|
res = benchmarker.benchmark_gpu(fn)
|
||||||
device_interface.synchronize() # shake out any CUDA errors
|
device_interface.synchronize() # shake out any CUDA errors
|
||||||
|
|
||||||
# pyrefly: ignore # bad-return
|
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1737,7 +1737,6 @@ class KernelArgs:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
for outer, inner in chain(
|
for outer, inner in chain(
|
||||||
# pyrefly: ignore # bad-argument-type
|
|
||||||
self.input_buffers.items(),
|
self.input_buffers.items(),
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.output_buffers.items(),
|
self.output_buffers.items(),
|
||||||
|
|
|
||||||
|
|
@ -1478,7 +1478,6 @@ class CppGemmTemplate(CppTemplate):
|
||||||
assert isinstance(template_buffer, ir.IRNode)
|
assert isinstance(template_buffer, ir.IRNode)
|
||||||
gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
|
gemm_output_name = f"{template_buffer.get_name()}_GemmOut"
|
||||||
gemm_output_buffer = ir.Buffer(
|
gemm_output_buffer = ir.Buffer(
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
name=gemm_output_name,
|
name=gemm_output_name,
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
layout=template_buffer.layout,
|
layout=template_buffer.layout,
|
||||||
|
|
@ -1502,7 +1501,6 @@ class CppGemmTemplate(CppTemplate):
|
||||||
reindexers.append(None)
|
reindexers.append(None)
|
||||||
if i < len(epilogue_creators) - 1:
|
if i < len(epilogue_creators) - 1:
|
||||||
current_input_buffer = ir.Buffer(
|
current_input_buffer = ir.Buffer(
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
name=buffer_name,
|
name=buffer_name,
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
layout=template_buffer.layout,
|
layout=template_buffer.layout,
|
||||||
|
|
|
||||||
|
|
@ -822,7 +822,6 @@ class CppWrapperGpu(CppWrapperCpu):
|
||||||
|
|
||||||
if triton:
|
if triton:
|
||||||
call_args, arg_types = self.prepare_triton_wrapper_args(
|
call_args, arg_types = self.prepare_triton_wrapper_args(
|
||||||
# pyrefly: ignore # bad-argument-type
|
|
||||||
call_args,
|
call_args,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore # bad-argument-type
|
||||||
arg_types,
|
arg_types,
|
||||||
|
|
|
||||||
|
|
@ -680,7 +680,6 @@ class MetalKernel(SIMDKernel):
|
||||||
)
|
)
|
||||||
idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False) # type: ignore[assignment]
|
idx_val = self._new_idxvar(dtype, default_value=0, is_threadgroup=False) # type: ignore[assignment]
|
||||||
idx_var = next(
|
idx_var = next(
|
||||||
# pyrefly: ignore # missing-argument
|
|
||||||
t
|
t
|
||||||
for t in self.range_tree_nodes.values()
|
for t in self.range_tree_nodes.values()
|
||||||
# pyrefly: ignore # missing-argument
|
# pyrefly: ignore # missing-argument
|
||||||
|
|
@ -863,7 +862,6 @@ class MetalKernel(SIMDKernel):
|
||||||
|
|
||||||
if self.inside_reduction:
|
if self.inside_reduction:
|
||||||
total_reduction_size = math.prod(
|
total_reduction_size = math.prod(
|
||||||
# pyrefly: ignore # missing-argument
|
|
||||||
t.numel
|
t.numel
|
||||||
for t in self.range_trees
|
for t in self.range_trees
|
||||||
# pyrefly: ignore # missing-argument
|
# pyrefly: ignore # missing-argument
|
||||||
|
|
|
||||||
|
|
@ -965,7 +965,6 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||||
|
|
||||||
def active_range_trees(self) -> list[IterationRangesRoot]:
|
def active_range_trees(self) -> list[IterationRangesRoot]:
|
||||||
return [
|
return [
|
||||||
# pyrefly: ignore # missing-argument
|
|
||||||
t
|
t
|
||||||
for t in self.range_trees
|
for t in self.range_trees
|
||||||
# pyrefly: ignore # missing-argument
|
# pyrefly: ignore # missing-argument
|
||||||
|
|
|
||||||
|
|
@ -1036,7 +1036,6 @@ class FxConverter:
|
||||||
# Add constants stored as Triton metadata, in signature order.
|
# Add constants stored as Triton metadata, in signature order.
|
||||||
call_kwargs |= constants
|
call_kwargs |= constants
|
||||||
new_call_args = [
|
new_call_args = [
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
call_kwargs[key]
|
call_kwargs[key]
|
||||||
for key in signature
|
for key in signature
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
|
|
|
||||||
|
|
@ -826,11 +826,10 @@ def _schedule_for_comm(
|
||||||
collective_cost > 0
|
collective_cost > 0
|
||||||
and (candidate := get_overlapping_candidate()) is not None
|
and (candidate := get_overlapping_candidate()) is not None
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
ready.remove(candidate)
|
ready.remove(candidate)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
schedule(candidate.snode)
|
schedule(candidate.snode)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
collective_cost -= snode_to_cost[candidate.snode]
|
collective_cost -= snode_to_cost[candidate.snode]
|
||||||
heapq.heapify(ready)
|
heapq.heapify(ready)
|
||||||
|
|
||||||
|
|
@ -1098,7 +1097,7 @@ def _sink_waits_iterative_internal(
|
||||||
info.grouped_info = _group_names(gns)
|
info.grouped_info = _group_names(gns)
|
||||||
candidate = _next[candidate]
|
candidate = _next[candidate]
|
||||||
continue
|
continue
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
elif (data_dep is None) and both_contain_comms:
|
elif (data_dep is None) and both_contain_comms:
|
||||||
info.limiting_factor = (
|
info.limiting_factor = (
|
||||||
f"collective ordering {_group_names(gns)}"
|
f"collective ordering {_group_names(gns)}"
|
||||||
|
|
|
||||||
|
|
@ -271,7 +271,6 @@ def record_original_output_strides(gm: GraphModule) -> None:
|
||||||
and (val := output.meta.get("val")) is not None
|
and (val := output.meta.get("val")) is not None
|
||||||
and isinstance(val, torch.Tensor)
|
and isinstance(val, torch.Tensor)
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
output_strides.append(val.stride())
|
output_strides.append(val.stride())
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
|
|
||||||
|
|
@ -620,7 +620,6 @@ class _OutOfProcessFxCompile(_SerializedFxCompile):
|
||||||
|
|
||||||
if output.warning_replay:
|
if output.warning_replay:
|
||||||
for w in output.warning_replay:
|
for w in output.warning_replay:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
warnings.warn_explicit(
|
warnings.warn_explicit(
|
||||||
message=w.message,
|
message=w.message,
|
||||||
category=w.category,
|
category=w.category,
|
||||||
|
|
|
||||||
|
|
@ -544,7 +544,6 @@ def amax(
|
||||||
keepdim: bool = False,
|
keepdim: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.dtype == torch.bool:
|
if self.dtype == torch.bool:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return torch.any(self, dim=dim, keepdim=keepdim)
|
return torch.any(self, dim=dim, keepdim=keepdim)
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
|
|
@ -556,7 +555,6 @@ def amin(
|
||||||
keepdim: bool = False,
|
keepdim: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
if self.dtype == torch.bool:
|
if self.dtype == torch.bool:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return torch.all(self, dim=dim, keepdim=keepdim)
|
return torch.all(self, dim=dim, keepdim=keepdim)
|
||||||
return NotImplemented
|
return NotImplemented
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -238,7 +238,7 @@ class FakeTensorUpdater:
|
||||||
symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
|
symbol_to_path := compute_unbacked_bindings(shape_env, new_fake_tensor)
|
||||||
):
|
):
|
||||||
# Refresh the bindings to the new symbols
|
# Refresh the bindings to the new symbols
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
node.meta["unbacked_bindings"] = symbol_to_path
|
node.meta["unbacked_bindings"] = symbol_to_path
|
||||||
|
|
||||||
existing_storages[get_node_storage(node)] += 1
|
existing_storages[get_node_storage(node)] += 1
|
||||||
|
|
|
||||||
|
|
@ -6500,12 +6500,10 @@ def div_prim(a, b):
|
||||||
# see https://github.com/pytorch/pytorch/issues/157959
|
# see https://github.com/pytorch/pytorch/issues/157959
|
||||||
if (divisor := get_constant_value(b)) is not None and a.get_device().type != "cpu":
|
if (divisor := get_constant_value(b)) is not None and a.get_device().type != "cpu":
|
||||||
# Replace divide by constant with multiply by reciprocal
|
# Replace divide by constant with multiply by reciprocal
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
if divisor.value == 0:
|
if divisor.value == 0:
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
reciprocal = math.copysign(float("inf"), divisor.value)
|
reciprocal = math.copysign(float("inf"), divisor.value)
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
reciprocal = 1.0 / divisor.value
|
reciprocal = 1.0 / divisor.value
|
||||||
return mul(a, reciprocal)
|
return mul(a, reciprocal)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -131,7 +131,6 @@ def load_package(
|
||||||
)
|
)
|
||||||
return AOTICompiledModel(loader)
|
return AOTICompiledModel(loader)
|
||||||
|
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
path = os.fspath(path) # AOTIModelPackageLoader expects (str, str)
|
path = os.fspath(path) # AOTIModelPackageLoader expects (str, str)
|
||||||
loader = torch._C._aoti.AOTIModelPackageLoader(
|
loader = torch._C._aoti.AOTIModelPackageLoader(
|
||||||
path, model_name, run_single_threaded, num_runners, device_index
|
path, model_name, run_single_threaded, num_runners, device_index
|
||||||
|
|
|
||||||
|
|
@ -2676,7 +2676,6 @@ class Scheduler:
|
||||||
and (dep := next(iter(node.read_writes.writes)))
|
and (dep := next(iter(node.read_writes.writes)))
|
||||||
and isinstance(dep, MemoryDep)
|
and isinstance(dep, MemoryDep)
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
node_mode = dep.mode
|
node_mode = dep.mode
|
||||||
else:
|
else:
|
||||||
node_mode = None
|
node_mode = None
|
||||||
|
|
@ -4360,7 +4359,6 @@ class Scheduler:
|
||||||
if config.expand_dimension_for_pointwise_nodes and (
|
if config.expand_dimension_for_pointwise_nodes and (
|
||||||
expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)
|
expand_analysis := self.get_expand_dim_for_pointwise_nodes(node1, node2)
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
(expand_dim, smaller_node, expand_size) = expand_analysis
|
(expand_dim, smaller_node, expand_size) = expand_analysis
|
||||||
smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size)
|
smaller_node.expand_dimension_for_pointwise_node(expand_dim, expand_size)
|
||||||
shared_data_score = self.score_fusion_memory(node1, node2)
|
shared_data_score = self.score_fusion_memory(node1, node2)
|
||||||
|
|
@ -4669,7 +4667,6 @@ class Scheduler:
|
||||||
device.type == "cuda"
|
device.type == "cuda"
|
||||||
and (device_props := torch.cuda.get_device_properties(device)).major < 7
|
and (device_props := torch.cuda.get_device_properties(device)).major < 7
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
raise GPUTooOldForTriton(device_props, inspect.currentframe())
|
raise GPUTooOldForTriton(device_props, inspect.currentframe())
|
||||||
elif is_gpu(device.type) and not device.type == "mps":
|
elif is_gpu(device.type) and not device.type == "mps":
|
||||||
raise TritonMissing(inspect.currentframe())
|
raise TritonMissing(inspect.currentframe())
|
||||||
|
|
@ -4967,7 +4964,6 @@ class Scheduler:
|
||||||
if isinstance(buf.node, ir.MutationOutput) and (
|
if isinstance(buf.node, ir.MutationOutput) and (
|
||||||
real_name := self.mutation_real_name.get(buf_name, None)
|
real_name := self.mutation_real_name.get(buf_name, None)
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
return is_none_layout(real_name)
|
return is_none_layout(real_name)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
|
||||||
|
|
@ -3681,8 +3681,8 @@ class AlgorithmSelectorCache(PersistentCache):
|
||||||
),
|
),
|
||||||
node.get_device(),
|
node.get_device(),
|
||||||
node.get_dtype(),
|
node.get_dtype(),
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
V.graph.sizevars.atomically_apply_size_hint(
|
V.graph.sizevars.atomically_apply_size_hint(
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
node.layout.offset,
|
node.layout.offset,
|
||||||
fallback=config.unbacked_symint_fallback,
|
fallback=config.unbacked_symint_fallback,
|
||||||
hint_override=hint_override,
|
hint_override=hint_override,
|
||||||
|
|
|
||||||
|
|
@ -1652,7 +1652,7 @@ class MMTemplateConfigMixin(GemmMaxAutotuneTemplateConfigHeuristics):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Build options dict
|
# Build options dict
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
options_dict = dict(
|
options_dict = dict(
|
||||||
EVEN_K=even_k_symbolic,
|
EVEN_K=even_k_symbolic,
|
||||||
USE_FAST_ACCUM=False, # Option for _scaled_mm
|
USE_FAST_ACCUM=False, # Option for _scaled_mm
|
||||||
|
|
|
||||||
|
|
@ -3764,7 +3764,6 @@ def maybe_log_cudagraph_partition(
|
||||||
and (fx_node := ir_node.get_origin_node())
|
and (fx_node := ir_node.get_origin_node())
|
||||||
and (stack_trace := fx_node.meta.get("stack_trace", None))
|
and (stack_trace := fx_node.meta.get("stack_trace", None))
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
warning_msg = f"{warning_msg}. Found from : \n {stack_trace}"
|
warning_msg = f"{warning_msg}. Found from : \n {stack_trace}"
|
||||||
|
|
||||||
perf_hint_log.warning(warning_msg)
|
perf_hint_log.warning(warning_msg)
|
||||||
|
|
|
||||||
|
|
@ -144,7 +144,6 @@ def benchmark_all_kernels(
|
||||||
launcher = triton_kernel.launchers[0]
|
launcher = triton_kernel.launchers[0]
|
||||||
print(
|
print(
|
||||||
get_info_str(
|
get_info_str(
|
||||||
# pyrefly: ignore # bad-argument-type
|
|
||||||
ms,
|
ms,
|
||||||
launcher.n_regs,
|
launcher.n_regs,
|
||||||
launcher.n_spills,
|
launcher.n_spills,
|
||||||
|
|
|
||||||
|
|
@ -246,7 +246,6 @@ def save_op_profiles(op_profiles: dict[str, set[OpProfile]], f: FileLike) -> Non
|
||||||
yaml_str = generate_yaml_from_profiles(op_profiles)
|
yaml_str = generate_yaml_from_profiles(op_profiles)
|
||||||
|
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
|
|
||||||
with open(f, "w") as file:
|
with open(f, "w") as file:
|
||||||
|
|
@ -312,7 +311,6 @@ def load_op_profiles(f: FileLike) -> dict[str, set[OpProfile]]:
|
||||||
Loads the saved operator profiles from `save_op_profiles`.
|
Loads the saved operator profiles from `save_op_profiles`.
|
||||||
"""
|
"""
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
|
|
||||||
with open(f) as file:
|
with open(f) as file:
|
||||||
|
|
|
||||||
|
|
@ -173,7 +173,7 @@ def register_opaque_type(cls: Any, name: Optional[str] = None) -> None:
|
||||||
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
|
f"Unable to accept name, {name}, for this opaque type as it contains a '.'"
|
||||||
)
|
)
|
||||||
_OPAQUE_TYPES[cls] = name
|
_OPAQUE_TYPES[cls] = name
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
torch._C._register_opaque_type(name)
|
torch._C._register_opaque_type(name)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -183,5 +183,5 @@ def is_opaque_type(cls: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
if cls not in _OPAQUE_TYPES:
|
if cls not in _OPAQUE_TYPES:
|
||||||
return False
|
return False
|
||||||
# pyrefly: ignore # missing-attribute
|
|
||||||
return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])
|
return torch._C._is_opaque_type_registered(_OPAQUE_TYPES[cls])
|
||||||
|
|
|
||||||
|
|
@ -914,7 +914,6 @@ class TorchLogsFormatter(logging.Formatter):
|
||||||
and (trace_id := torch._guards.CompileContext.current_trace_id())
|
and (trace_id := torch._guards.CompileContext.current_trace_id())
|
||||||
is not None
|
is not None
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
record.traceid = f" [{trace_id}]"
|
record.traceid = f" [{trace_id}]"
|
||||||
|
|
||||||
glog_level_to_abbr = {
|
glog_level_to_abbr = {
|
||||||
|
|
|
||||||
|
|
@ -1336,9 +1336,9 @@ def float_power(
|
||||||
|
|
||||||
# Float power has the following contiguous cast behavior to be
|
# Float power has the following contiguous cast behavior to be
|
||||||
# consistent with its C++ impl
|
# consistent with its C++ impl
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
a = _maybe_convert_to_dtype(a, dtype)
|
a = _maybe_convert_to_dtype(a, dtype)
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
b = _maybe_convert_to_dtype(b, dtype)
|
b = _maybe_convert_to_dtype(b, dtype)
|
||||||
|
|
||||||
a, b = _maybe_broadcast(a, b)
|
a, b = _maybe_broadcast(a, b)
|
||||||
|
|
@ -2348,7 +2348,6 @@ def all(
|
||||||
dim: Optional[DimsType] = None,
|
dim: Optional[DimsType] = None,
|
||||||
keepdim: bool = False,
|
keepdim: bool = False,
|
||||||
) -> TensorLikeType:
|
) -> TensorLikeType:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim))
|
result = torch.logical_not(torch.any(torch.logical_not(a), dim, keepdim=keepdim))
|
||||||
|
|
||||||
if a.dtype == torch.uint8:
|
if a.dtype == torch.uint8:
|
||||||
|
|
@ -3245,7 +3244,7 @@ def _normalize(
|
||||||
mean (Tensor): mean of the tensor along norm_dims.
|
mean (Tensor): mean of the tensor along norm_dims.
|
||||||
rstd (Tensor): 1/std of the tensor along norm_dims.
|
rstd (Tensor): 1/std of the tensor along norm_dims.
|
||||||
"""
|
"""
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
|
norm_dims = utils.canonicalize_dims(a.ndim, norm_dims)
|
||||||
computation_dtype = utils.get_computation_dtype(a.dtype)
|
computation_dtype = utils.get_computation_dtype(a.dtype)
|
||||||
a_acc = _maybe_convert_to_dtype(a, computation_dtype)
|
a_acc = _maybe_convert_to_dtype(a, computation_dtype)
|
||||||
|
|
@ -3975,7 +3974,7 @@ def reshape_as(self: TensorLikeType, other: TensorLikeType) -> TensorLikeType:
|
||||||
@out_wrapper()
|
@out_wrapper()
|
||||||
def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType:
|
def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLikeType:
|
||||||
"""Reference implementation of :func:`torch.roll`."""
|
"""Reference implementation of :func:`torch.roll`."""
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
dims = utils.canonicalize_dims(a.ndim, dims)
|
dims = utils.canonicalize_dims(a.ndim, dims)
|
||||||
# ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
|
# ATen specifies int[1] type for shifts and dims which expands integers to tuples of length 1
|
||||||
if not isinstance(shifts, Iterable):
|
if not isinstance(shifts, Iterable):
|
||||||
|
|
@ -4286,7 +4285,7 @@ def squeeze(a: TensorLikeType, dim: Optional[DimsType] = None) -> TensorLikeType
|
||||||
return prims.squeeze(a, dims) if dims else prims.view_of(a)
|
return prims.squeeze(a, dims) if dims else prims.view_of(a)
|
||||||
|
|
||||||
ndim = a.ndim
|
ndim = a.ndim
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
dim = utils.canonicalize_dims(ndim, dim)
|
dim = utils.canonicalize_dims(ndim, dim)
|
||||||
dims = (dim,) if isinstance(dim, Dim) else dim
|
dims = (dim,) if isinstance(dim, Dim) else dim
|
||||||
# Short-circuits if the tensor has no dimensions
|
# Short-circuits if the tensor has no dimensions
|
||||||
|
|
|
||||||
|
|
@ -216,7 +216,7 @@ def matrix_norm(
|
||||||
# shape
|
# shape
|
||||||
check_is_matrix(A, "linalg.matrix_norm")
|
check_is_matrix(A, "linalg.matrix_norm")
|
||||||
# dim
|
# dim
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
dim = utils.canonicalize_dims(A.ndim, dim)
|
dim = utils.canonicalize_dims(A.ndim, dim)
|
||||||
if isinstance(dim, Dim):
|
if isinstance(dim, Dim):
|
||||||
dim = (dim,) # type: ignore[assignment]
|
dim = (dim,) # type: ignore[assignment]
|
||||||
|
|
|
||||||
|
|
@ -2620,7 +2620,7 @@ class FakeTensorMode(TorchDispatchMode):
|
||||||
and s.rhs == 1
|
and s.rhs == 1
|
||||||
):
|
):
|
||||||
assert self.shape_env is not None
|
assert self.shape_env is not None
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
self.shape_env.set_unbacked_var_to_val(s, int(real_t))
|
self.shape_env.set_unbacked_var_to_val(s, int(real_t))
|
||||||
|
|
||||||
if real_out is not nil:
|
if real_out is not nil:
|
||||||
|
|
|
||||||
|
|
@ -1110,7 +1110,6 @@ class Tensor(torch._C.TensorBase):
|
||||||
|
|
||||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||||
def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
|
def __rsub__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return _C._VariableFunctions.rsub(self, other)
|
return _C._VariableFunctions.rsub(self, other)
|
||||||
|
|
||||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||||
|
|
@ -1137,7 +1136,7 @@ class Tensor(torch._C.TensorBase):
|
||||||
|
|
||||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||||
def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
|
def __rmod__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
|
||||||
return torch.remainder(other, self) # pyrefly: ignore # no-matching-overload
|
return torch.remainder(other, self)
|
||||||
|
|
||||||
def __format__(self, format_spec):
|
def __format__(self, format_spec):
|
||||||
if has_torch_function_unary(self):
|
if has_torch_function_unary(self):
|
||||||
|
|
@ -1150,7 +1149,7 @@ class Tensor(torch._C.TensorBase):
|
||||||
|
|
||||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||||
def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
|
def __rpow__(self, other: Union["Tensor", int, float, bool, complex]) -> "Tensor":
|
||||||
return torch.pow(other, self) # pyrefly: ignore # no-matching-overload
|
return torch.pow(other, self)
|
||||||
|
|
||||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||||
def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
|
def __floordiv__(self, other: Union["Tensor", int, float, bool]) -> "Tensor": # type: ignore[override]
|
||||||
|
|
@ -1166,14 +1165,12 @@ class Tensor(torch._C.TensorBase):
|
||||||
def __rlshift__(
|
def __rlshift__(
|
||||||
self, other: Union["Tensor", int, float, bool, complex]
|
self, other: Union["Tensor", int, float, bool, complex]
|
||||||
) -> "Tensor":
|
) -> "Tensor":
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return torch.bitwise_left_shift(other, self)
|
return torch.bitwise_left_shift(other, self)
|
||||||
|
|
||||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||||
def __rrshift__(
|
def __rrshift__(
|
||||||
self, other: Union["Tensor", int, float, bool, complex]
|
self, other: Union["Tensor", int, float, bool, complex]
|
||||||
) -> "Tensor":
|
) -> "Tensor":
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return torch.bitwise_right_shift(other, self)
|
return torch.bitwise_right_shift(other, self)
|
||||||
|
|
||||||
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
@_handle_torch_function_and_wrap_type_error_to_not_implemented
|
||||||
|
|
|
||||||
|
|
@ -744,10 +744,7 @@ class ExceptionWrapper:
|
||||||
if exc_info is None:
|
if exc_info is None:
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
self.exc_type = exc_info[0]
|
self.exc_type = exc_info[0]
|
||||||
self.exc_msg = "".join(
|
self.exc_msg = "".join(traceback.format_exception(*exc_info))
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
traceback.format_exception(*exc_info)
|
|
||||||
)
|
|
||||||
self.where = where
|
self.where = where
|
||||||
|
|
||||||
def reraise(self):
|
def reraise(self):
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,6 @@ def compile_time_strobelight_meta(
|
||||||
skip := kwargs["skip"],
|
skip := kwargs["skip"],
|
||||||
int,
|
int,
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
kwargs["skip"] = skip + 1
|
kwargs["skip"] = skip + 1
|
||||||
|
|
||||||
# This is not needed but we have it here to avoid having profile_compile_time
|
# This is not needed but we have it here to avoid having profile_compile_time
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,7 @@ class Conv1d(_ConvNd, nn.Conv1d):
|
||||||
and the backend should be able to fuse the ops with `*` into a quantized conv1d
|
and the backend should be able to fuse the ops with `*` into a quantized conv1d
|
||||||
"""
|
"""
|
||||||
weight_quant_dequant = self.get_weight()
|
weight_quant_dequant = self.get_weight()
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
result = F.conv1d(
|
result = F.conv1d(
|
||||||
x,
|
x,
|
||||||
weight_quant_dequant,
|
weight_quant_dequant,
|
||||||
|
|
@ -160,7 +160,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
|
||||||
and the backend should be able to fuse the ops with `*` into a quantized conv2d
|
and the backend should be able to fuse the ops with `*` into a quantized conv2d
|
||||||
"""
|
"""
|
||||||
weight_quant_dequant = self.get_weight()
|
weight_quant_dequant = self.get_weight()
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
result = F.conv2d(
|
result = F.conv2d(
|
||||||
x,
|
x,
|
||||||
weight_quant_dequant,
|
weight_quant_dequant,
|
||||||
|
|
@ -225,7 +225,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
|
||||||
and the backend should be able to fuse the ops with `*` into a quantized conv3d
|
and the backend should be able to fuse the ops with `*` into a quantized conv3d
|
||||||
"""
|
"""
|
||||||
weight_quant_dequant = self.get_weight()
|
weight_quant_dequant = self.get_weight()
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
result = F.conv3d(
|
result = F.conv3d(
|
||||||
x,
|
x,
|
||||||
weight_quant_dequant,
|
weight_quant_dequant,
|
||||||
|
|
|
||||||
|
|
@ -1095,6 +1095,7 @@ def create_a_shadows_b(
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore # unbound-name
|
||||||
if not isinstance(input_logger, list):
|
if not isinstance(input_logger, list):
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
|
# pyrefly: ignore # unbound-name
|
||||||
f"Expected list, got {type(input_logger)}"
|
f"Expected list, got {type(input_logger)}"
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore # unbound-name
|
||||||
|
|
|
||||||
|
|
@ -127,7 +127,6 @@ class AdaptiveRoundingOptimizer:
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def feed_forward(self, x, weight, module):
|
def feed_forward(self, x, weight, module):
|
||||||
if isinstance(module, torch.nn.Conv1d):
|
if isinstance(module, torch.nn.Conv1d):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
out = torch.nn.functional.conv1d(
|
out = torch.nn.functional.conv1d(
|
||||||
x,
|
x,
|
||||||
weight,
|
weight,
|
||||||
|
|
|
||||||
|
|
@ -90,7 +90,6 @@ def _find_q_dq_node_for_user(
|
||||||
and arg.op == "call_function"
|
and arg.op == "call_function"
|
||||||
and arg.target in _QUANTIZE_OPS
|
and arg.target in _QUANTIZE_OPS
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
q_node = arg
|
q_node = arg
|
||||||
return (q_node, dq_node)
|
return (q_node, dq_node)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ def _make_grads(
|
||||||
is_grads_batched: bool,
|
is_grads_batched: bool,
|
||||||
) -> tuple[_OptionalTensor, ...]:
|
) -> tuple[_OptionalTensor, ...]:
|
||||||
new_grads: list[_OptionalTensor] = []
|
new_grads: list[_OptionalTensor] = []
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
for out, grad in zip(outputs, grads):
|
for out, grad in zip(outputs, grads):
|
||||||
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
|
out = cast(Union[torch.Tensor, graph.GradientEdge], out)
|
||||||
out_size = None
|
out_size = None
|
||||||
|
|
|
||||||
|
|
@ -155,25 +155,21 @@ class cuBLASModule:
|
||||||
if name == "allow_tf32":
|
if name == "allow_tf32":
|
||||||
return torch._C._get_cublas_allow_tf32()
|
return torch._C._get_cublas_allow_tf32()
|
||||||
elif name == "allow_fp16_reduced_precision_reduction":
|
elif name == "allow_fp16_reduced_precision_reduction":
|
||||||
# pyrefly: ignore # not-iterable
|
|
||||||
allow_reduced_precision, _ = (
|
allow_reduced_precision, _ = (
|
||||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
||||||
)
|
)
|
||||||
return allow_reduced_precision
|
return allow_reduced_precision
|
||||||
elif name == "allow_fp16_reduced_precision_reduction_split_k":
|
elif name == "allow_fp16_reduced_precision_reduction_split_k":
|
||||||
# pyrefly: ignore # not-iterable
|
|
||||||
_, allow_splitk = (
|
_, allow_splitk = (
|
||||||
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
torch._C._get_cublas_allow_fp16_reduced_precision_reduction()
|
||||||
)
|
)
|
||||||
return allow_splitk
|
return allow_splitk
|
||||||
elif name == "allow_bf16_reduced_precision_reduction":
|
elif name == "allow_bf16_reduced_precision_reduction":
|
||||||
# pyrefly: ignore # not-iterable
|
|
||||||
allow_reduced_precision, _ = (
|
allow_reduced_precision, _ = (
|
||||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
||||||
)
|
)
|
||||||
return allow_reduced_precision
|
return allow_reduced_precision
|
||||||
elif name == "allow_bf16_reduced_precision_reduction_split_k":
|
elif name == "allow_bf16_reduced_precision_reduction_split_k":
|
||||||
# pyrefly: ignore # not-iterable
|
|
||||||
_, allow_splitk = (
|
_, allow_splitk = (
|
||||||
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
torch._C._get_cublas_allow_bf16_reduced_precision_reduction()
|
||||||
)
|
)
|
||||||
|
|
@ -193,7 +189,6 @@ class cuBLASModule:
|
||||||
)
|
)
|
||||||
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
|
return torch._C._set_cublas_allow_fp16_reduced_precision_reduction(
|
||||||
allow_reduced_precision,
|
allow_reduced_precision,
|
||||||
# pyrefly: ignore # bad-argument-count
|
|
||||||
allow_splitk,
|
allow_splitk,
|
||||||
)
|
)
|
||||||
elif name == "allow_bf16_reduced_precision_reduction":
|
elif name == "allow_bf16_reduced_precision_reduction":
|
||||||
|
|
@ -202,7 +197,6 @@ class cuBLASModule:
|
||||||
)
|
)
|
||||||
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
|
return torch._C._set_cublas_allow_bf16_reduced_precision_reduction(
|
||||||
allow_reduced_precision,
|
allow_reduced_precision,
|
||||||
# pyrefly: ignore # bad-argument-count
|
|
||||||
allow_splitk,
|
allow_splitk,
|
||||||
)
|
)
|
||||||
elif name == "allow_fp16_accumulation":
|
elif name == "allow_fp16_accumulation":
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ if hasattr(torch._C, "_CUDAGreenContext"):
|
||||||
|
|
||||||
|
|
||||||
# Python shim helps Sphinx process docstrings more reliably.
|
# Python shim helps Sphinx process docstrings more reliably.
|
||||||
|
# pyrefly: ignore # invalid-inheritance
|
||||||
class GreenContext(_GreenContext):
|
class GreenContext(_GreenContext):
|
||||||
r"""Wrapper around a CUDA green context.
|
r"""Wrapper around a CUDA green context.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,6 @@ def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
|
||||||
|
|
||||||
checkpoint_id = kwargs.get("checkpoint_id")
|
checkpoint_id = kwargs.get("checkpoint_id")
|
||||||
if not checkpoint_id and (serializer := storage_writer or storage_reader):
|
if not checkpoint_id and (serializer := storage_writer or storage_reader):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
checkpoint_id = getattr(serializer, "checkpoint_id", None)
|
checkpoint_id = getattr(serializer, "checkpoint_id", None)
|
||||||
|
|
||||||
msg_dict["checkpoint_id"] = (
|
msg_dict["checkpoint_id"] = (
|
||||||
|
|
|
||||||
|
|
@ -1227,7 +1227,6 @@ def _unflatten_model_state_dict(
|
||||||
if not state_dict:
|
if not state_dict:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
if isinstance(next(iter(state_dict.keys())), nn.Module):
|
if isinstance(next(iter(state_dict.keys())), nn.Module):
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``"
|
"Passing model_state_dict as a ``Dict[nn.Module, Dict[str, Any]]``"
|
||||||
|
|
|
||||||
|
|
@ -393,7 +393,7 @@ class BackendConfig:
|
||||||
# e.g. "nccl", "gloo", "ucc", "mpi"
|
# e.g. "nccl", "gloo", "ucc", "mpi"
|
||||||
supported_devices = Backend.backend_capability[backend.lower()]
|
supported_devices = Backend.backend_capability[backend.lower()]
|
||||||
backend_val = Backend(backend)
|
backend_val = Backend(backend)
|
||||||
# pyrefly: ignore # bad-assignment
|
|
||||||
self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
|
self.device_backend_map = dict.fromkeys(supported_devices, backend_val)
|
||||||
elif ":" in backend.lower():
|
elif ":" in backend.lower():
|
||||||
# Backend specified in "device:backend" format
|
# Backend specified in "device:backend" format
|
||||||
|
|
|
||||||
|
|
@ -290,7 +290,6 @@ def _shard_dict_of_args(
|
||||||
f"Unsupported chunk spec: {spec} and value: {v} combination."
|
f"Unsupported chunk spec: {spec} and value: {v} combination."
|
||||||
)
|
)
|
||||||
|
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
for _flat_split_result, _v_split in zip(
|
for _flat_split_result, _v_split in zip(
|
||||||
flat_split_results, v_splits, strict=True
|
flat_split_results, v_splits, strict=True
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -1327,7 +1327,6 @@ class _ContextParallel(ParallelStyle):
|
||||||
placement = [Shard(self.seq_dim)]
|
placement = [Shard(self.seq_dim)]
|
||||||
all_args = []
|
all_args = []
|
||||||
|
|
||||||
# pyrefly: ignore # bad-assignment, bad-argument-type
|
|
||||||
for arg in itertools.chain(args, kwargs.values()):
|
for arg in itertools.chain(args, kwargs.values()):
|
||||||
if isinstance(arg, torch.Tensor):
|
if isinstance(arg, torch.Tensor):
|
||||||
if isinstance(arg, DTensor):
|
if isinstance(arg, DTensor):
|
||||||
|
|
|
||||||
|
|
@ -548,7 +548,7 @@ class PrepareModuleInput(ParallelStyle):
|
||||||
assert self.desired_input_layouts is not None, (
|
assert self.desired_input_layouts is not None, (
|
||||||
"desired module inputs should not be None!"
|
"desired module inputs should not be None!"
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
for inp, input_layout, desired_layout in zip(
|
for inp, input_layout, desired_layout in zip(
|
||||||
inputs, self.input_layouts, self.desired_input_layouts
|
inputs, self.input_layouts, self.desired_input_layouts
|
||||||
):
|
):
|
||||||
|
|
@ -664,7 +664,7 @@ class PrepareModuleOutput(ParallelStyle):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"module outputs and output_layouts should have same length!"
|
"module outputs and output_layouts should have same length!"
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
for out, out_layout, desired_out_layout in zip(
|
for out, out_layout, desired_out_layout in zip(
|
||||||
outputs, self.output_layouts, self.desired_output_layouts
|
outputs, self.output_layouts, self.desired_output_layouts
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ def _cast_forward_inputs(
|
||||||
def cast_fn(x: torch.Tensor) -> torch.Tensor:
|
def cast_fn(x: torch.Tensor) -> torch.Tensor:
|
||||||
if not torch.is_floating_point(x) or x.dtype == dtype:
|
if not torch.is_floating_point(x) or x.dtype == dtype:
|
||||||
return x
|
return x
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return x.to(dtype)
|
return x.to(dtype)
|
||||||
|
|
||||||
return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
|
return (_apply_to_tensors(cast_fn, args), _apply_to_tensors(cast_fn, kwargs))
|
||||||
|
|
|
||||||
|
|
@ -436,7 +436,6 @@ def load(
|
||||||
print(ep(torch.randn(5)))
|
print(ep(torch.randn(5)))
|
||||||
"""
|
"""
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
|
|
||||||
extra_files = extra_files or {}
|
extra_files = extra_files or {}
|
||||||
|
|
|
||||||
|
|
@ -514,7 +514,6 @@ def _replace_unbacked_bindings(gm: torch.fx.GraphModule) -> None:
|
||||||
simplify=True,
|
simplify=True,
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
node.meta["unbacked_bindings"] = unbacked_bindings
|
node.meta["unbacked_bindings"] = unbacked_bindings
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -683,7 +683,6 @@ def package_pt2(
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
(isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable())
|
(isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable())
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
|
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
|
||||||
or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2"))
|
or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2"))
|
||||||
):
|
):
|
||||||
|
|
@ -695,7 +694,6 @@ def package_pt2(
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
|
@ -1086,7 +1084,6 @@ def load_pt2(
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
(isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable())
|
(isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable())
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
|
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
|
||||||
):
|
):
|
||||||
# TODO: turn this into an error in 2.9
|
# TODO: turn this into an error in 2.9
|
||||||
|
|
@ -1097,7 +1094,6 @@ def load_pt2(
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
|
|
||||||
weights = {}
|
weights = {}
|
||||||
|
|
@ -1167,7 +1163,6 @@ def load_pt2(
|
||||||
else:
|
else:
|
||||||
aoti_runners = {
|
aoti_runners = {
|
||||||
model_name: _load_aoti(
|
model_name: _load_aoti(
|
||||||
# pyrefly: ignore # bad-argument-type
|
|
||||||
f,
|
f,
|
||||||
model_name,
|
model_name,
|
||||||
run_single_threaded,
|
run_single_threaded,
|
||||||
|
|
|
||||||
|
|
@ -916,7 +916,6 @@ def fetch_object_proxy(
|
||||||
def fetch_object_proxy(
|
def fetch_object_proxy(
|
||||||
tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType]
|
tracer: _ProxyTracer, t: Union[Tensor, _AnyScriptObjectType, PySymType]
|
||||||
) -> object:
|
) -> object:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return get_proxy_slot(t, tracer, t)
|
return get_proxy_slot(t, tracer, t)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -965,7 +964,6 @@ def _fetch_proxies_and_all_constant_flag(
|
||||||
"""
|
"""
|
||||||
f_flat_args_kwargs = [
|
f_flat_args_kwargs = [
|
||||||
(
|
(
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
fetch_object_proxy(tracer, x)
|
fetch_object_proxy(tracer, x)
|
||||||
if isinstance(x, (Tensor, _AnyScriptObject))
|
if isinstance(x, (Tensor, _AnyScriptObject))
|
||||||
else x
|
else x
|
||||||
|
|
@ -2497,7 +2495,6 @@ class _MakefxTracer:
|
||||||
):
|
):
|
||||||
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
|
from torch.fx.passes.runtime_assert import insert_deferred_runtime_asserts
|
||||||
|
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx")
|
insert_deferred_runtime_asserts(t, fake_mode.shape_env, "reenter_make_fx")
|
||||||
t.recompile()
|
t.recompile()
|
||||||
# TODO: kind of a bad way to do it, should maybe figure out a better way
|
# TODO: kind of a bad way to do it, should maybe figure out a better way
|
||||||
|
|
|
||||||
|
|
@ -621,13 +621,12 @@ def rebind_unbacked(
|
||||||
):
|
):
|
||||||
# This is what the pattern match above is testing
|
# This is what the pattern match above is testing
|
||||||
repacked = _sympy_cast_symbool_to_symint_guardless(
|
repacked = _sympy_cast_symbool_to_symint_guardless(
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
sympy.Eq(new_raw_u1, 1)
|
sympy.Eq(new_raw_u1, 1)
|
||||||
)
|
)
|
||||||
assert repacked == raw_u1, f"{repacked} != {raw_u1}"
|
assert repacked == raw_u1, f"{repacked} != {raw_u1}"
|
||||||
# Cancel the to_int(to_bool(x)). This is sound because x in
|
# Cancel the to_int(to_bool(x)). This is sound because x in
|
||||||
# [0, 1]
|
# [0, 1]
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
raw_u1 = new_raw_u1
|
raw_u1 = new_raw_u1
|
||||||
|
|
||||||
if not isinstance(raw_u1, sympy.Symbol):
|
if not isinstance(raw_u1, sympy.Symbol):
|
||||||
|
|
@ -1055,7 +1054,6 @@ def find_symbol_binding_fx_nodes(
|
||||||
# NB: Prefer first occurrence of symbol
|
# NB: Prefer first occurrence of symbol
|
||||||
for node in graph.nodes:
|
for node in graph.nodes:
|
||||||
if (s := is_symbol_binding_fx_node(node)) is not None and s not in r:
|
if (s := is_symbol_binding_fx_node(node)) is not None and s not in r:
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
r[s] = node
|
r[s] = node
|
||||||
return r
|
return r
|
||||||
|
|
||||||
|
|
@ -1226,13 +1224,12 @@ def _free_unbacked_symbols_with_path(
|
||||||
and isinstance(s := expr(a), sympy.Symbol)
|
and isinstance(s := expr(a), sympy.Symbol)
|
||||||
and s in pending
|
and s in pending
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
r[s] = path
|
r[s] = path
|
||||||
if shape_env and real is not None:
|
if shape_env and real is not None:
|
||||||
assert isinstance(real, (int, float))
|
assert isinstance(real, (int, float))
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
shape_env.set_unbacked_var_to_val(s, real)
|
shape_env.set_unbacked_var_to_val(s, real)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
pending.remove(s)
|
pending.remove(s)
|
||||||
# When an unbacked SymInt is perfectly divisible by an integer
|
# When an unbacked SymInt is perfectly divisible by an integer
|
||||||
# constant, we replace it with the integer constant to improve
|
# constant, we replace it with the integer constant to improve
|
||||||
|
|
@ -1262,14 +1259,10 @@ def _free_unbacked_symbols_with_path(
|
||||||
source=shape_env.var_to_sources.get(s, [None])[0], # type: ignore[union-attr]
|
source=shape_env.var_to_sources.get(s, [None])[0], # type: ignore[union-attr]
|
||||||
)
|
)
|
||||||
|
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
unbacked = lhs if lhs in pending else rhs
|
unbacked = lhs if lhs in pending else rhs
|
||||||
divisor: IntLikeType = (
|
divisor: IntLikeType = (
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
int(coeff)
|
int(coeff)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
if shape_env and isinstance(coeff, sympy.Integer)
|
if shape_env and isinstance(coeff, sympy.Integer)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
else _symint_wrap(coeff)
|
else _symint_wrap(coeff)
|
||||||
)
|
)
|
||||||
# TODO: DivideByKey needs to test divisibility at runtime!
|
# TODO: DivideByKey needs to test divisibility at runtime!
|
||||||
|
|
@ -1278,11 +1271,8 @@ def _free_unbacked_symbols_with_path(
|
||||||
if real is not None:
|
if real is not None:
|
||||||
assert isinstance(real, int)
|
assert isinstance(real, int)
|
||||||
val = (
|
val = (
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
real // int(coeff)
|
real // int(coeff)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
if isinstance(coeff, sympy.Integer)
|
if isinstance(coeff, sympy.Integer)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
else CleanDiv(real, coeff)
|
else CleanDiv(real, coeff)
|
||||||
)
|
)
|
||||||
if shape_env:
|
if shape_env:
|
||||||
|
|
@ -1299,14 +1289,12 @@ def _free_unbacked_symbols_with_path(
|
||||||
and s.rhs == 1
|
and s.rhs == 1
|
||||||
and s.lhs in pending
|
and s.lhs in pending
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unsupported-operation
|
|
||||||
r[s.lhs] = path + (ConvertIntKey(),)
|
r[s.lhs] = path + (ConvertIntKey(),)
|
||||||
if real is not None:
|
if real is not None:
|
||||||
assert type(real) is bool
|
assert type(real) is bool
|
||||||
if shape_env:
|
if shape_env:
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
shape_env.set_unbacked_var_to_val(s, int(real))
|
shape_env.set_unbacked_var_to_val(s, int(real))
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
pending.remove(s.lhs)
|
pending.remove(s.lhs)
|
||||||
|
|
||||||
return r
|
return r
|
||||||
|
|
@ -1382,7 +1370,6 @@ def compute_unbacked_bindings(
|
||||||
):
|
):
|
||||||
if (
|
if (
|
||||||
isinstance(old_sym, SymTypes)
|
isinstance(old_sym, SymTypes)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
and (old_s := old_sym.node.expr) != new_s
|
and (old_s := old_sym.node.expr) != new_s
|
||||||
):
|
):
|
||||||
# If old_s is not an unbacked_symbol,
|
# If old_s is not an unbacked_symbol,
|
||||||
|
|
@ -1392,15 +1379,12 @@ def compute_unbacked_bindings(
|
||||||
# and the original symbol gets replaced by the backed symbol.
|
# and the original symbol gets replaced by the backed symbol.
|
||||||
# When this happens we just replace new_s by the old_s
|
# When this happens we just replace new_s by the old_s
|
||||||
# because we know the value is the same.
|
# because we know the value is the same.
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s):
|
if isinstance(old_s, sympy.Symbol) and free_unbacked_symbols(old_s):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
shape_env._rename_unbacked_to(new_s, old_s)
|
shape_env._rename_unbacked_to(new_s, old_s)
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
shape_env._eliminate_unbacked(new_s, old_s)
|
shape_env._eliminate_unbacked(new_s, old_s)
|
||||||
elif not isinstance(old_sym, SymTypes):
|
elif not isinstance(old_sym, SymTypes):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym))
|
shape_env._eliminate_unbacked(new_s, sympy.sympify(old_sym))
|
||||||
|
|
||||||
return symbol_to_path
|
return symbol_to_path
|
||||||
|
|
@ -3365,7 +3349,7 @@ class DimConstraints:
|
||||||
and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
|
and str(symbol := next(iter(c["eq"].free_symbols))) == old_root
|
||||||
): # derived dim with root = old_root
|
): # derived dim with root = old_root
|
||||||
new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1
|
new_root_expr = results[str(old_root)]["eq"] # dx=3*_dx+1
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1
|
new_expr = c["eq"].subs({symbol: new_root_expr}) # dy=(3*_dx+1)+1
|
||||||
c["eq"] = new_expr
|
c["eq"] = new_expr
|
||||||
|
|
||||||
|
|
@ -7630,10 +7614,9 @@ class ShapeEnv:
|
||||||
log.info(
|
log.info(
|
||||||
"oblivious_size %s -> %s (passed counterfactual)",
|
"oblivious_size %s -> %s (passed counterfactual)",
|
||||||
orig_expr,
|
orig_expr,
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
correct_hint,
|
correct_hint,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
concrete_val = correct_hint
|
concrete_val = correct_hint
|
||||||
# NB: do NOT transmute into runtime assert
|
# NB: do NOT transmute into runtime assert
|
||||||
ok = True
|
ok = True
|
||||||
|
|
@ -7650,10 +7633,9 @@ class ShapeEnv:
|
||||||
).xreplace(self.var_to_val)
|
).xreplace(self.var_to_val)
|
||||||
).free_symbols
|
).free_symbols
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
self._log_real_tensor_propagation(orig_expr, unsound_result)
|
self._log_real_tensor_propagation(orig_expr, unsound_result)
|
||||||
transmute_into_runtime_assert = True
|
transmute_into_runtime_assert = True
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
concrete_val = unsound_result
|
concrete_val = unsound_result
|
||||||
ok = True
|
ok = True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1314,6 +1314,7 @@ class Graph:
|
||||||
f(to_erase)
|
f(to_erase)
|
||||||
|
|
||||||
self._find_nodes_lookup_table.remove(to_erase)
|
self._find_nodes_lookup_table.remove(to_erase)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
to_erase._remove_from_list()
|
to_erase._remove_from_list()
|
||||||
to_erase._erased = True # iterators may retain handles to erased nodes
|
to_erase._erased = True # iterators may retain handles to erased nodes
|
||||||
self._len -= 1
|
self._len -= 1
|
||||||
|
|
|
||||||
|
|
@ -385,6 +385,7 @@ class Node(_NodeBase):
|
||||||
Args:
|
Args:
|
||||||
x (Node): The node to put before this node. Must be a member of the same graph.
|
x (Node): The node to put before this node. Must be a member of the same graph.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
self._prepend(x)
|
self._prepend(x)
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
|
|
@ -396,6 +397,7 @@ class Node(_NodeBase):
|
||||||
Args:
|
Args:
|
||||||
x (Node): The node to put after this node. Must be a member of the same graph.
|
x (Node): The node to put after this node. Must be a member of the same graph.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
self._next._prepend(x)
|
self._next._prepend(x)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
|
||||||
|
|
@ -276,7 +276,6 @@ def tensorify_python_scalars(
|
||||||
):
|
):
|
||||||
transform = True
|
transform = True
|
||||||
try:
|
try:
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
proxy = _sympy_interp(zf.node.expr)
|
proxy = _sympy_interp(zf.node.expr)
|
||||||
except NotImplementedError:
|
except NotImplementedError:
|
||||||
transform = False
|
transform = False
|
||||||
|
|
@ -303,7 +302,6 @@ def tensorify_python_scalars(
|
||||||
args.append(a)
|
args.append(a)
|
||||||
|
|
||||||
if transform:
|
if transform:
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
replacement_proxy = replacement_op(*args)
|
replacement_proxy = replacement_op(*args)
|
||||||
|
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,6 @@ class FakeTensorProp(torch.fx.Interpreter):
|
||||||
if (shape_env := self._mode.shape_env) and (
|
if (shape_env := self._mode.shape_env) and (
|
||||||
symbol_to_path := compute_unbacked_bindings(shape_env, result)
|
symbol_to_path := compute_unbacked_bindings(shape_env, result)
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
n.meta["unbacked_bindings"] = symbol_to_path
|
n.meta["unbacked_bindings"] = symbol_to_path
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
|
||||||
|
|
@ -298,14 +298,12 @@ def insert_deferred_runtime_asserts(
|
||||||
and s not in expr_to_proxy
|
and s not in expr_to_proxy
|
||||||
):
|
):
|
||||||
with _set_node_metadata_hook(gm, _node_metadata_hook):
|
with _set_node_metadata_hook(gm, _node_metadata_hook):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
|
expr_to_proxy[s] = fx.Proxy(cb(), tracer=tracer)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
|
log.debug("expr_to_proxy[%s] = %s", s, expr_to_proxy[s])
|
||||||
|
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
match_symbol(example_value, lambda: node)
|
match_symbol(example_value, lambda: node)
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
if isinstance(t := example_value, torch.Tensor):
|
if isinstance(t := example_value, torch.Tensor):
|
||||||
for i, s in enumerate(t.size()):
|
for i, s in enumerate(t.size()):
|
||||||
match_symbol(
|
match_symbol(
|
||||||
|
|
@ -386,7 +384,6 @@ def insert_deferred_runtime_asserts(
|
||||||
|
|
||||||
# maybe re-reify expression, replace current node
|
# maybe re-reify expression, replace current node
|
||||||
if (
|
if (
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
sym_expr in expr_to_proxy
|
sym_expr in expr_to_proxy
|
||||||
or ( # example value is redundant
|
or ( # example value is redundant
|
||||||
_is_intermediate_tensor_sym_call(node)
|
_is_intermediate_tensor_sym_call(node)
|
||||||
|
|
@ -405,10 +402,8 @@ def insert_deferred_runtime_asserts(
|
||||||
nn_module_stack=node.meta.get("nn_module_stack"),
|
nn_module_stack=node.meta.get("nn_module_stack"),
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
expr_to_proxy[sym_expr] = _sympy_interp(
|
expr_to_proxy[sym_expr] = _sympy_interp(
|
||||||
expr_to_proxy,
|
expr_to_proxy,
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
sym_expr,
|
sym_expr,
|
||||||
) # type: ignore[arg-type]
|
) # type: ignore[arg-type]
|
||||||
# won't try DCE-ing tensor compute here
|
# won't try DCE-ing tensor compute here
|
||||||
|
|
@ -419,14 +414,12 @@ def insert_deferred_runtime_asserts(
|
||||||
"CSE node %s -> %s for expr %s",
|
"CSE node %s -> %s for expr %s",
|
||||||
node,
|
node,
|
||||||
hash_node,
|
hash_node,
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
sym_expr,
|
sym_expr,
|
||||||
)
|
)
|
||||||
|
|
||||||
# store node in hash cons, don't delete/replace
|
# store node in hash cons, don't delete/replace
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
elif sym_expr not in expr_to_proxy and not isinstance(
|
elif sym_expr not in expr_to_proxy and not isinstance(
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
sym_expr,
|
sym_expr,
|
||||||
(sympy.Number, sympy.logic.boolalg.BooleanAtom),
|
(sympy.Number, sympy.logic.boolalg.BooleanAtom),
|
||||||
): # don't hash cons primitives
|
): # don't hash cons primitives
|
||||||
|
|
|
||||||
|
|
@ -318,7 +318,6 @@ def split_module(
|
||||||
and isinstance(s0 := val.node.expr, sympy.Symbol)
|
and isinstance(s0 := val.node.expr, sympy.Symbol)
|
||||||
and s0 not in symbol_to_node
|
and s0 not in symbol_to_node
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
symbol_to_node[val.node.expr] = node
|
symbol_to_node[val.node.expr] = node
|
||||||
|
|
||||||
if node.op in ["placeholder", "get_attr", "output"]:
|
if node.op in ["placeholder", "get_attr", "output"]:
|
||||||
|
|
|
||||||
|
|
@ -85,7 +85,6 @@ def get_source_partitions(
|
||||||
if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and (
|
if (source_fn_st := node.meta.get("source_fn_stack", None)) is None and (
|
||||||
torch_fn := node.meta.get("torch_fn", None)
|
torch_fn := node.meta.get("torch_fn", None)
|
||||||
) is not None:
|
) is not None:
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
node_fqn, source_fn = torch_fn
|
node_fqn, source_fn = torch_fn
|
||||||
source_fn_name = source_fn.split(".")[1]
|
source_fn_name = source_fn.split(".")[1]
|
||||||
if source_fn_name in wanted_sources:
|
if source_fn_name in wanted_sources:
|
||||||
|
|
|
||||||
|
|
@ -421,7 +421,7 @@ def set_dir(d: Union[str, os.PathLike]) -> None:
|
||||||
d (str): path to a local folder to save downloaded models & weights.
|
d (str): path to a local folder to save downloaded models & weights.
|
||||||
"""
|
"""
|
||||||
global _hub_dir
|
global _hub_dir
|
||||||
_hub_dir = os.path.expanduser(d) # pyrefly: ignore # no-matching-overload
|
_hub_dir = os.path.expanduser(d)
|
||||||
|
|
||||||
|
|
||||||
def list(
|
def list(
|
||||||
|
|
|
||||||
|
|
@ -167,7 +167,6 @@ def load(f, map_location=None, _extra_files=None, _restore_shapes=False):
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
cpp_module = torch._C.import_ir_module(
|
cpp_module = torch._C.import_ir_module(
|
||||||
cu,
|
cu,
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
os.fspath(f),
|
os.fspath(f),
|
||||||
map_location,
|
map_location,
|
||||||
_extra_files,
|
_extra_files,
|
||||||
|
|
@ -208,7 +207,6 @@ def validate_map_location(map_location=None):
|
||||||
|
|
||||||
def jit_module_from_flatbuffer(f):
|
def jit_module_from_flatbuffer(f):
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
|
return wrap_cpp_module(torch._C._load_jit_module_from_file(f))
|
||||||
else:
|
else:
|
||||||
|
|
@ -258,7 +256,6 @@ def save_jit_module_to_flatbuffer(m, f, _extra_files=None):
|
||||||
extra_files = {}
|
extra_files = {}
|
||||||
|
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
torch._C._save_jit_module(m._c, f, extra_files)
|
torch._C._save_jit_module(m._c, f, extra_files)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,6 @@ def _load_for_lite_interpreter(f, map_location=None):
|
||||||
map_location = validate_map_location(map_location)
|
map_location = validate_map_location(map_location)
|
||||||
|
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
|
cpp_module = torch._C._load_for_lite_interpreter(os.fspath(f), map_location)
|
||||||
else:
|
else:
|
||||||
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
|
cpp_module = torch._C._load_for_lite_interpreter_from_buffer(
|
||||||
|
|
@ -106,7 +105,6 @@ def _get_model_bytecode_version(f_input) -> int:
|
||||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||||
|
|
||||||
if isinstance(f_input, (str, os.PathLike)):
|
if isinstance(f_input, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return torch._C._get_model_bytecode_version(os.fspath(f_input))
|
return torch._C._get_model_bytecode_version(os.fspath(f_input))
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
|
|
@ -140,7 +138,6 @@ def _get_mobile_model_contained_types(f_input) -> int:
|
||||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||||
|
|
||||||
if isinstance(f_input, (str, os.PathLike)):
|
if isinstance(f_input, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
|
return torch._C._get_mobile_model_contained_types(os.fspath(f_input))
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
|
|
@ -168,9 +165,7 @@ def _backport_for_mobile(f_input, f_output, to_version):
|
||||||
isinstance(f_output, (str, os.PathLike))
|
isinstance(f_output, (str, os.PathLike))
|
||||||
):
|
):
|
||||||
return torch._C._backport_for_mobile(
|
return torch._C._backport_for_mobile(
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
os.fspath(f_input),
|
os.fspath(f_input),
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
os.fspath(f_output),
|
os.fspath(f_output),
|
||||||
to_version,
|
to_version,
|
||||||
)
|
)
|
||||||
|
|
@ -198,7 +193,6 @@ def _backport_for_mobile_to_buffer(f_input, to_version):
|
||||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||||
|
|
||||||
if isinstance(f_input, (str, os.PathLike)):
|
if isinstance(f_input, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
|
return torch._C._backport_for_mobile_to_buffer(os.fspath(f_input), to_version)
|
||||||
else:
|
else:
|
||||||
return torch._C._backport_for_mobile_from_buffer_to_buffer(
|
return torch._C._backport_for_mobile_from_buffer_to_buffer(
|
||||||
|
|
@ -244,7 +238,6 @@ def _get_model_ops_and_info(f_input):
|
||||||
raise ValueError(f"The provided filename {f_input} is a directory")
|
raise ValueError(f"The provided filename {f_input} is a directory")
|
||||||
|
|
||||||
if isinstance(f_input, (str, os.PathLike)):
|
if isinstance(f_input, (str, os.PathLike)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return torch._C._get_model_ops_and_info(os.fspath(f_input))
|
return torch._C._get_model_ops_and_info(os.fspath(f_input))
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore # missing-attribute
|
||||||
|
|
|
||||||
|
|
@ -644,7 +644,7 @@ def impl(
|
||||||
>>> y2 = torch.sin(x) + 1
|
>>> y2 = torch.sin(x) + 1
|
||||||
>>> assert torch.allclose(y1, y2)
|
>>> assert torch.allclose(y1, y2)
|
||||||
"""
|
"""
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
|
return _impl(qualname, types, func, lib=lib, disable_dynamo=False)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -831,7 +831,6 @@ def register_kernel(
|
||||||
if device_types is None:
|
if device_types is None:
|
||||||
device_types = "CompositeExplicitAutograd"
|
device_types = "CompositeExplicitAutograd"
|
||||||
|
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return _impl(op, device_types, func, lib=lib, disable_dynamo=True)
|
return _impl(op, device_types, func, lib=lib, disable_dynamo=True)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -642,7 +642,6 @@ def _sparse_coo_scatter_reduction_helper(
|
||||||
|
|
||||||
# promote dtype if specified
|
# promote dtype if specified
|
||||||
if values.dtype != output_dtype:
|
if values.dtype != output_dtype:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
values = values.to(output_dtype)
|
values = values.to(output_dtype)
|
||||||
|
|
||||||
if keepdim:
|
if keepdim:
|
||||||
|
|
@ -767,7 +766,6 @@ def _sparse_csr_segment_reduction_helper(
|
||||||
|
|
||||||
# promote dtype if specified
|
# promote dtype if specified
|
||||||
if values.dtype != output_dtype:
|
if values.dtype != output_dtype:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
values = values.to(output_dtype)
|
values = values.to(output_dtype)
|
||||||
|
|
||||||
if len(dims) == 0:
|
if len(dims) == 0:
|
||||||
|
|
|
||||||
|
|
@ -473,7 +473,6 @@ class ModuleList(Module):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def pop(self, key: Union[int, slice]) -> Module:
|
def pop(self, key: Union[int, slice]) -> Module:
|
||||||
# pyrefly: ignore # index-error
|
|
||||||
v = self[key]
|
v = self[key]
|
||||||
del self[key]
|
del self[key]
|
||||||
return v
|
return v
|
||||||
|
|
|
||||||
|
|
@ -363,7 +363,7 @@ class Conv1d(_ConvNd):
|
||||||
self.dilation,
|
self.dilation,
|
||||||
self.groups,
|
self.groups,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return F.conv1d(
|
return F.conv1d(
|
||||||
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
||||||
)
|
)
|
||||||
|
|
@ -541,7 +541,7 @@ class Conv2d(_ConvNd):
|
||||||
self.dilation,
|
self.dilation,
|
||||||
self.groups,
|
self.groups,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return F.conv2d(
|
return F.conv2d(
|
||||||
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
||||||
)
|
)
|
||||||
|
|
@ -711,7 +711,7 @@ class Conv3d(_ConvNd):
|
||||||
self.dilation,
|
self.dilation,
|
||||||
self.groups,
|
self.groups,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return F.conv3d(
|
return F.conv3d(
|
||||||
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
input, weight, bias, self.stride, self.padding, self.dilation, self.groups
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -364,7 +364,6 @@ def parse_args(
|
||||||
fn_name = None
|
fn_name = None
|
||||||
args = [
|
args = [
|
||||||
_parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign]
|
_parse_arg(arg, arg_desc, arg_name, fn_name) # type: ignore[method-assign]
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)
|
for arg, arg_desc, arg_name in zip(args, arg_descriptors, arg_names)
|
||||||
]
|
]
|
||||||
# only support _outputs in kwargs
|
# only support _outputs in kwargs
|
||||||
|
|
|
||||||
|
|
@ -453,7 +453,7 @@ def _single_tensor_adam(
|
||||||
device_beta1 = beta1
|
device_beta1 = beta1
|
||||||
|
|
||||||
# Decay the first and second moment running average coefficient
|
# Decay the first and second moment running average coefficient
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
exp_avg.lerp_(grad, 1 - device_beta1)
|
exp_avg.lerp_(grad, 1 - device_beta1)
|
||||||
|
|
||||||
# Nested if is necessary to bypass jitscript rules
|
# Nested if is necessary to bypass jitscript rules
|
||||||
|
|
|
||||||
|
|
@ -398,7 +398,6 @@ class Optimizer:
|
||||||
self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
|
self.state: defaultdict[torch.Tensor, Any] = defaultdict(dict)
|
||||||
self.param_groups: list[dict[str, Any]] = []
|
self.param_groups: list[dict[str, Any]] = []
|
||||||
|
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
param_groups = list(params)
|
param_groups = list(params)
|
||||||
if len(param_groups) == 0:
|
if len(param_groups) == 0:
|
||||||
raise ValueError("optimizer got an empty parameter list")
|
raise ValueError("optimizer got an empty parameter list")
|
||||||
|
|
|
||||||
|
|
@ -219,7 +219,7 @@ class PackageExporter:
|
||||||
torch._C._log_api_usage_once("torch.package.PackageExporter")
|
torch._C._log_api_usage_once("torch.package.PackageExporter")
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
f = os.fspath(f) # pyrefly: ignore # no-matching-overload
|
f = os.fspath(f)
|
||||||
self.buffer: Optional[IO[bytes]] = None
|
self.buffer: Optional[IO[bytes]] = None
|
||||||
else: # is a byte buffer
|
else: # is a byte buffer
|
||||||
self.buffer = f
|
self.buffer = f
|
||||||
|
|
|
||||||
|
|
@ -108,7 +108,6 @@ class PackageImporter(Importer):
|
||||||
self.filename = "<pytorch_file_reader>"
|
self.filename = "<pytorch_file_reader>"
|
||||||
self.zip_reader = file_or_buffer
|
self.zip_reader = file_or_buffer
|
||||||
elif isinstance(file_or_buffer, (os.PathLike, str)):
|
elif isinstance(file_or_buffer, (os.PathLike, str)):
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
self.filename = os.fspath(file_or_buffer)
|
self.filename = os.fspath(file_or_buffer)
|
||||||
if not os.path.isdir(self.filename):
|
if not os.path.isdir(self.filename):
|
||||||
self.zip_reader = torch._C.PyTorchFileReader(self.filename)
|
self.zip_reader = torch._C.PyTorchFileReader(self.filename)
|
||||||
|
|
|
||||||
|
|
@ -774,10 +774,7 @@ def _open_file_like(name_or_buffer: FileLike, mode: str) -> _opener[IO[bytes]]:
|
||||||
|
|
||||||
class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
|
class _open_zipfile_reader(_opener[torch._C.PyTorchFileReader]):
|
||||||
def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None:
|
def __init__(self, name_or_buffer: Union[str, IO[bytes]]) -> None:
|
||||||
super().__init__(
|
super().__init__(torch._C.PyTorchFileReader(name_or_buffer))
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
torch._C.PyTorchFileReader(name_or_buffer)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
|
class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
|
||||||
|
|
@ -970,7 +967,7 @@ def save(
|
||||||
_check_save_filelike(f)
|
_check_save_filelike(f)
|
||||||
|
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
f = os.fspath(f) # pyrefly: ignore # no-matching-overload
|
f = os.fspath(f)
|
||||||
|
|
||||||
if _use_new_zipfile_serialization:
|
if _use_new_zipfile_serialization:
|
||||||
with _open_zipfile_writer(f) as opened_zipfile:
|
with _open_zipfile_writer(f) as opened_zipfile:
|
||||||
|
|
@ -1524,7 +1521,6 @@ def load(
|
||||||
else:
|
else:
|
||||||
shared = False
|
shared = False
|
||||||
overall_storage = torch.UntypedStorage.from_file(
|
overall_storage = torch.UntypedStorage.from_file(
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
os.fspath(f),
|
os.fspath(f),
|
||||||
shared,
|
shared,
|
||||||
size,
|
size,
|
||||||
|
|
|
||||||
|
|
@ -701,7 +701,6 @@ def tree_map_only(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -762,7 +761,6 @@ def tree_map_only_(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1555,7 +1555,6 @@ def tree_map_only(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
return tree_map(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1616,7 +1615,6 @@ def tree_map_only_(
|
||||||
tree: PyTree,
|
tree: PyTree,
|
||||||
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
is_leaf: Optional[Callable[[PyTree], bool]] = None,
|
||||||
) -> PyTree:
|
) -> PyTree:
|
||||||
# pyrefly: ignore # no-matching-overload
|
|
||||||
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
return tree_map_(map_only(type_or_types_or_pred)(func), tree, is_leaf=is_leaf)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1531,7 +1531,7 @@ def include_paths(device_type: str = "cpu", torch_include_dirs=True) -> list[str
|
||||||
# Support CUDA_INC_PATH env variable supported by CMake files
|
# Support CUDA_INC_PATH env variable supported by CMake files
|
||||||
if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
|
if (cuda_inc_path := os.environ.get("CUDA_INC_PATH", None)) and \
|
||||||
cuda_inc_path != '/usr/include':
|
cuda_inc_path != '/usr/include':
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
paths.append(cuda_inc_path)
|
paths.append(cuda_inc_path)
|
||||||
if CUDNN_HOME is not None:
|
if CUDNN_HOME is not None:
|
||||||
paths.append(os.path.join(CUDNN_HOME, 'include'))
|
paths.append(os.path.join(CUDNN_HOME, 'include'))
|
||||||
|
|
|
||||||
|
|
@ -678,7 +678,6 @@ class _BaseDataLoaderIter:
|
||||||
|
|
||||||
# Set pin memory device based on the current accelerator.
|
# Set pin memory device based on the current accelerator.
|
||||||
self._pin_memory_device = (
|
self._pin_memory_device = (
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
acc.type
|
acc.type
|
||||||
if self._pin_memory
|
if self._pin_memory
|
||||||
and (acc := torch.accelerator.current_accelerator()) is not None
|
and (acc := torch.accelerator.current_accelerator()) is not None
|
||||||
|
|
|
||||||
|
|
@ -251,7 +251,6 @@ def get_device_capability(device: Optional[_device_t] = None) -> dict[str, Any]:
|
||||||
# Only keep attributes that are safe for dictionary serialization.
|
# Only keep attributes that are safe for dictionary serialization.
|
||||||
serializable_types = (int, float, bool, str, type(None), list, tuple, dict)
|
serializable_types = (int, float, bool, str, type(None), list, tuple, dict)
|
||||||
return {
|
return {
|
||||||
# pyrefly: ignore # unbound-name
|
|
||||||
key: value
|
key: value
|
||||||
for key in dir(props)
|
for key in dir(props)
|
||||||
if not key.startswith("__")
|
if not key.startswith("__")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user