Fix syntax for pyrefly errors (#166496)

Last one! This ensures all existing suppressions match the syntax pyrefly expects and silence only the named error code.
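For illustration, a minimal, hypothetical sketch of the two spellings (the helper and variable names below are made up, not code from this diff). In the old form the error code sits in a trailing free-text comment, so the suppression is not scoped to that code; the bracketed form silences only the listed code(s), which is what this change standardizes on.

from typing import Any

def _coerce(raw: Any) -> int:  # hypothetical helper, illustration only
    # Old syntax: the trailing "# bad-assignment" is plain comment text, so the
    # suppression is not tied to a single error code.
    # pyrefly: ignore # bad-assignment
    old_style: int = raw

    # New syntax: only the bracketed code(s) are silenced on the next line; several
    # codes can be listed at once, e.g. [missing-attribute, bad-argument-type].
    # pyrefly: ignore [bad-assignment]
    new_style: int = raw
    return old_style + new_style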

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166496
Approved by: https://github.com/Skylion007, https://github.com/mlazos
Maggie Moss 2025-10-29 20:00:21 +00:00 committed by PyTorch MergeBot
parent fa560e1158
commit d1a6e006e0
47 changed files with 203 additions and 203 deletions


@@ -841,7 +841,7 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr:
factor = functools.reduce(math.gcd, map(integer_coefficient, atoms))
if factor == 1:
return expr
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
atoms = [div_by_factor(x, factor) for x in atoms]
return _sympy_from_args(
sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative
@@ -2207,7 +2207,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext):
def __post_init__(self) -> None:
super().__post_init__()
if self.inner_contexts is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.inner_contexts = {}
@@ -2296,12 +2296,12 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
# only re-create the objects if any of the args changed to avoid expensive
# checks when re-creating objects.
new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
return _fast_expand(expr.func(*new_args))
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if expr.is_Pow:
base: sympy.Expr
exp: sympy.Expr
@@ -2311,11 +2311,11 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
return sympy.expand_multinomial(expr, deep=False)
elif exp < 0:
return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
elif expr.is_Mul:
num: list[sympy.Expr] = []
den: list[sympy.Expr] = []
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for arg in expr.args:
if arg.is_Pow and arg.args[1] == -1:
den.append(S.One / arg) # type: ignore[operator, arg-type]
@@ -2437,7 +2437,7 @@ def _maybe_evaluate_static_worker(
# TODO: remove this try catch (esp for unbacked_only)
try:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
new_expr = expr.xreplace(new_shape_env)
except RecursionError:
log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
@@ -2975,19 +2975,19 @@ class DimConstraints:
# is_integer tests though haha
return (base - mod_reduced) / divisor
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if expr.has(Mod):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
expr = expr.replace(Mod, mod_handler)
# 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative
# arguments should be OK.
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if expr.has(PythonMod):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
expr = expr.replace(PythonMod, mod_handler)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if expr.has(FloorDiv):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
expr = expr.replace(FloorDiv, floor_div_handler)
return expr
@@ -5106,7 +5106,7 @@ class ShapeEnv:
if duck:
# Make sure to reuse this symbol for subsequent duck shaping
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
self.val_to_var[val] = sympy_expr
if isinstance(val, int):
@@ -5338,9 +5338,9 @@ class ShapeEnv:
# Expand optional inputs, or verify invariants are upheld
if input_contexts is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
input_contexts = [
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_create_no_constraints_context(t) if isinstance(t, Tensorlike) else None
for t in placeholders
]
@@ -5350,7 +5350,7 @@ class ShapeEnv:
for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
if isinstance(t, Tensorlike):
if context is None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
input_contexts[i] = _create_no_constraints_context(t)
else:
assert isinstance(t, (SymInt, int, SymFloat, float))
@@ -5636,7 +5636,7 @@ class ShapeEnv:
s = sympy.Float(val)
input_guards.append((source, s))
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
for t, source, context in zip(placeholders, sources, input_contexts):
if isinstance(source, str):
from torch._dynamo.source import LocalSource
@@ -5999,7 +5999,7 @@ class ShapeEnv:
else:
str_msg = f" - {msg_cb()}"
error_msgs.append(str_msg)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
debug_names.add(debug_name)
if len(error_msgs) > 0:
debug_names_str = ", ".join(sorted(debug_names))
@@ -6133,7 +6133,7 @@ class ShapeEnv:
Get a list of guards, but pruned so it only provides guards that
reference symints from the passed in input
"""
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
symints = {
s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
}
@@ -6396,7 +6396,7 @@ class ShapeEnv:
Apply symbol replacements to any symbols in the given expression.
"""
replacements = {}
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for s in expr.free_symbols:
r = self._find(s)
@@ -6406,7 +6406,7 @@ class ShapeEnv:
if not r.is_Symbol or r != s:
replacements[s] = r
if replacements:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
return safe_expand(expr.xreplace(replacements))
else:
return expr
@@ -7181,7 +7181,7 @@ class ShapeEnv:
instructions = list(dis.Bytecode(frame.f_code))
co_lines, offset = inspect.getsourcelines(frame.f_code)
start, end, cur = None, None, None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
for i, instr in enumerate(instructions):
if instr.starts_line is not None:
cur = instr.starts_line


@@ -238,7 +238,7 @@ class Dispatcher:
"To use a variadic union type place the desired types "
"inside of a tuple, e.g., [(int, str)]"
)
-# pyrefly: ignore # bad-specialization
+# pyrefly: ignore [bad-specialization]
new_signature.append(Variadic[typ[0]])
else:
new_signature.append(typ)
@@ -407,7 +407,7 @@ class MethodDispatcher(Dispatcher):
Dispatcher
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
__slots__ = ("obj", "cls")
@classmethod


@@ -298,7 +298,7 @@ def update_in(d, keys, func, default=None, factory=dict):
rv = inner = factory()
rv.update(d)
-# pyrefly: ignore # not-iterable
+# pyrefly: ignore [not-iterable]
for key in ks:
if k in d:
d = d[k]


@@ -1380,7 +1380,7 @@ class Graph:
f(to_erase)
self._find_nodes_lookup_table.remove(to_erase)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
to_erase._remove_from_list()
to_erase._erased = True # iterators may retain handles to erased nodes
self._len -= 1
@@ -1941,7 +1941,7 @@ class Graph:
"a str is expected"
)
if node.op in ["get_attr", "call_module"]:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
target_atoms = node.target.split(".")
m_itr = self.owning_module
for i, atom in enumerate(target_atoms):


@@ -535,7 +535,7 @@ class GraphModule(torch.nn.Module):
self.graph._tracer_cls
and "<locals>" not in self.graph._tracer_cls.__qualname__
):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self._tracer_cls = self.graph._tracer_cls
self._tracer_extras = {}


@@ -165,12 +165,12 @@ def tensorify_python_scalars(
node = graph.call_function(
torch.ops.aten.scalar_tensor.default,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
(c,),
{"dtype": dtype},
)
with fake_mode:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
node.meta["val"] = torch.ops.aten.scalar_tensor.default(c, dtype=dtype)
expr_to_tensor_proxy[expr] = MetaProxy(
node,
@@ -223,13 +223,13 @@ def tensorify_python_scalars(
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
elif (sym_expr := _get_sym_val(node)) is not None:
if sym_expr not in expr_to_sym_proxy and not isinstance(
sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
):
expr_to_sym_proxy[sym_expr] = MetaProxy(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
node,
tracer=tracer,
fake_mode=fake_mode,
@@ -238,7 +238,7 @@ def tensorify_python_scalars(
# Specialize all dimensions that contain symfloats. Here's
# an example test that requires this:
# PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 # noqa: B950
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
val = node.meta.get("val")
if isinstance(val, FakeTensor):
for dim in val.shape:
@@ -257,17 +257,17 @@ def tensorify_python_scalars(
should_restart = True
# Look for functions to convert
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if node.op == "call_function" and (
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
replacement_op := SUPPORTED_OPS.get(node.target)
):
args: list[Any] = []
transform = False
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
compute_dtype = get_computation_dtype(node.meta["val"].dtype)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for a in node.args:
if (
isinstance(a, fx.Node)
@@ -304,7 +304,7 @@ def tensorify_python_scalars(
if transform:
replacement_proxy = replacement_op(*args)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if compute_dtype != node.meta["val"].dtype:
replacement_proxy = (
torch.ops.prims.convert_element_type.default(
@@ -313,9 +313,9 @@ def tensorify_python_scalars(
)
)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
node.replace_all_uses_with(replacement_proxy.node)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
graph.erase_node(node)
metrics_context = get_metrics_context()
@@ -324,16 +324,16 @@ def tensorify_python_scalars(
"tensorify_float_success", True, overwrite=True
)
else:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for a in node.args:
if (
isinstance(a, fx.Node)
and "val" in a.meta
and isinstance(zf := a.meta["val"], torch.SymFloat)
):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
failed_tensorify_ops.update(str(node.target))
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
log.info("Failed to tensorify %s", str(node.target))
# Now do one more pass that specializes all symfloats we didn't manage


@@ -437,13 +437,13 @@ if HAS_PYDOT:
)
current_graph = buf_name_to_subgraph.get(buf_name) # type: ignore[assignment]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
current_graph.add_node(dot_node)
def get_module_params_or_buffers():
for pname, ptensor in chain(
leaf_module.named_parameters(),
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
leaf_module.named_buffers(),
):
pname1 = node.name + "." + pname


@@ -11,7 +11,7 @@ __all__ = ["PassResult", "PassBase"]
@compatibility(is_backward_compatible=False)
-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
"""
Result of a pass:


@@ -31,7 +31,7 @@ def pass_result_wrapper(fn: Callable) -> Callable:
wrapped_fn (Callable[Module, PassResult])
"""
if fn is None:
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return None
@wraps(fn)


@@ -396,25 +396,25 @@ class _MinimizerBase:
report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
if self.module_exporter:
if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
result_key = result_key[-1]
# If the result is still a tuple (happens in non-sequential mode),
# we only use the first element as name.
if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
result_key = str(result_key[0])
# pyre-ignore[29]: not a function
self.module_exporter(
a_input,
submodule,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
result_key + "_cpu",
)
# pyre-ignore[29]: not a function
self.module_exporter(
b_input,
submodule,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
result_key + "_acc",
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]


@@ -360,7 +360,7 @@ def insert_deferred_runtime_asserts(
):
# this guards against deleting calls like item() that produce new untracked symbols
def has_new_untracked_symbols():
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for symbol in sym_expr.free_symbols:
if symbol not in expr_to_proxy:
return True
@@ -376,7 +376,7 @@ def insert_deferred_runtime_asserts(
assert resolved_unbacked_bindings is not None
def has_new_unbacked_bindings():
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for key in resolved_unbacked_bindings.keys():
if key not in expr_to_proxy:
return True


@@ -351,9 +351,9 @@ def split_module(
assert all(v is not None for v in autocast_exits.values()), "autocast must exit"
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
if _LOGGER.isEnabledFor(logging.DEBUG):
@@ -418,9 +418,9 @@ def split_module(
for regions_mapping in [autocast_regions, grad_regions]:
for node, regions in regions_mapping.items():
assert len(regions) > 0
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
partitions[str(regions[0])].environment[node] = node
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
for r in regions[1:]:
partition = partitions[str(r)]
new_node = partition.graph.create_node(
@@ -520,7 +520,7 @@ def split_module(
for node in reversed(regions_mapping):
regions = regions_mapping[node]
assert len(regions) > 0
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
for r in regions[:-1]:
partition = partitions[str(r)]
exit_node = autocast_exits[node]


@@ -64,7 +64,7 @@ def lift_subgraph_as_module(
for name in target_name_parts[:-1]:
if not hasattr(curr, name):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
curr.add_module(name, HolderModule({}))
curr = getattr(curr, name)


@@ -242,7 +242,7 @@ class Library:
if dispatch_key == "":
dispatch_key = self.dispatch_key
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
if isinstance(op_name, str):


@@ -484,7 +484,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
raise IndexError(
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
dims.append(d % ndim)
return tuple(sorted(dims))
@@ -1017,7 +1017,7 @@ def _combine_input_and_mask(
class Combine(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input, mask):
"""Return input with masked-out elements eliminated for the given operations."""
ctx.save_for_backward(mask)
@@ -1028,7 +1028,7 @@ def _combine_input_and_mask(
return helper(input, mask)
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
(mask,) = ctx.saved_tensors
grad_data = (
@@ -1403,18 +1403,18 @@ elements, have ``nan`` values.
if input.layout == torch.strided:
if mask is None:
# TODO: compute count analytically
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
count = sum(
torch.ones(input.shape, dtype=torch.int64, device=input.device),
dim,
keepdim=keepdim,
)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
total = sum(input, dim, keepdim=keepdim, dtype=dtype)
else:
inmask = _input_mask(input, mask=mask)
count = inmask.sum(dim=dim, keepdim=bool(keepdim))
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
return total / count
elif input.layout == torch.sparse_csr:
@@ -1625,18 +1625,18 @@ def _std_var(
if input.layout == torch.strided:
if mask is None:
# TODO: compute count analytically
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
count = sum(
torch.ones(input.shape, dtype=torch.int64, device=input.device),
dim,
keepdim=True,
)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
sample_total = sum(input, dim, keepdim=True, dtype=dtype)
else:
inmask = _input_mask(input, mask=mask)
count = inmask.sum(dim=dim, keepdim=True)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
# TODO: replace torch.subtract/divide/square/maximum with
# masked subtract/divide/square/maximum when these will be
@@ -1644,7 +1644,7 @@ def _std_var(
sample_mean = torch.divide(sample_total, count)
x = torch.subtract(input, sample_mean)
if mask is None:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
else:
total = sum(


@@ -47,7 +47,7 @@ def _check_args_kwargs_length(
class _MaskedContiguous(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
@@ -61,14 +61,14 @@ class _MaskedContiguous(torch.autograd.Function):
return MaskedTensor(data.contiguous(), mask.contiguous())
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return grad_output
class _MaskedToDense(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
@@ -83,7 +83,7 @@ class _MaskedToDense(torch.autograd.Function):
return MaskedTensor(data.to_dense(), mask.to_dense())
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
layout = ctx.layout
@@ -98,7 +98,7 @@ class _MaskedToDense(torch.autograd.Function):
class _MaskedToSparse(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
@@ -115,14 +115,14 @@ class _MaskedToSparse(torch.autograd.Function):
return MaskedTensor(sparse_data, sparse_mask)
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return grad_output.to_dense()
class _MaskedToSparseCsr(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
@@ -143,21 +143,21 @@ class _MaskedToSparseCsr(torch.autograd.Function):
return MaskedTensor(sparse_data, sparse_mask)
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return grad_output.to_dense()
class _MaskedWhere(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, cond, self, other):
ctx.mark_non_differentiable(cond)
ctx.save_for_backward(cond)
return torch.ops.aten.where(cond, self, other)
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
(cond,) = ctx.saved_tensors


@@ -174,7 +174,7 @@ class MaskedTensor(torch.Tensor):
UserWarning,
stacklevel=2,
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
def _preprocess_data(self, data, mask):
@@ -244,12 +244,12 @@ class MaskedTensor(torch.Tensor):
class Constructor(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, data, mask):
return MaskedTensor(data, mask)
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return grad_output, None
@@ -336,12 +336,12 @@ class MaskedTensor(torch.Tensor):
def get_data(self):
class GetData(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, self):
return self._masked_data.detach()
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
if is_masked_tensor(grad_output):
return grad_output


@@ -114,7 +114,7 @@ class ProcessContext:
"""Attempt to join all processes with a shared timeout."""
end = time.monotonic() + timeout
for process in self.processes:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
time_to_wait = max(0, end - time.monotonic())
process.join(time_to_wait)


@@ -51,7 +51,7 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
# Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
# For other dims, subtract 1 to convert to inner space.
return (
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
ragged_dim - 1 if dim == 0 else dim - 1
)
@@ -2008,7 +2008,7 @@ def index_put_(func, *args, **kwargs):
else:
lengths = inp.lengths()
torch._assert_async(
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
torch.all(indices[inp._ragged_idx] < lengths),
"Some indices in the ragged dimension are out of bounds!",
)


@@ -668,5 +668,5 @@ def _get_numa_node_indices_for_socket_index(*, socket_index: int) -> set[int]:
def _get_allowed_cpu_indices_for_current_thread() -> set[int]:
# 0 denotes current thread
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
return os.sched_getaffinity(0) # type:ignore[attr-defined]


@@ -251,7 +251,7 @@ def _compare_onnx_pytorch_outputs_in_np(
# pyrefly: ignore [missing-attribute]
if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8:
warnings.warn("ONNX output is quantized", stacklevel=2)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8:
warnings.warn("PyTorch output is quantized", stacklevel=2)
raise


@@ -78,7 +78,7 @@ def _adjust_lr(
A, B = param_shape[:2]
if adjust_lr_fn is None or adjust_lr_fn == "original":
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
adjusted_ratio = math.sqrt(max(1, A / B))
elif adjust_lr_fn == "match_rms_adamw":
adjusted_ratio = 0.2 * math.sqrt(max(A, B))


@@ -423,7 +423,7 @@ def _single_tensor_adam(
if weight_decay.requires_grad:
grad = grad.addcmul_(param.clone(), weight_decay)
else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)


@@ -264,7 +264,7 @@ def _single_tensor_asgd(
ax.copy_(param)
if capturable:
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
else:


@@ -113,16 +113,16 @@ def _strong_wolfe(
# compute new trial value
t = _cubic_interpolate(
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
bracket[0],
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[0],
bracket_gtd[0], # type: ignore[possibly-undefined]
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
bracket[1],
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[1],
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_gtd[1],
)
@@ -133,20 +133,20 @@ def _strong_wolfe(
# + `t` is at one of the boundary,
# we will move `t` to a position which is `0.1 * len(bracket)`
# away from the nearest boundary point.
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
eps = 0.1 * (max(bracket) - min(bracket))
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
if min(max(bracket) - t, t - min(bracket)) < eps:
# interpolation close to boundary
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
if insuf_progress or t >= max(bracket) or t <= min(bracket):
# evaluate at 0.1 away from boundary
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
if abs(t - max(bracket)) < abs(t - min(bracket)):
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
t = max(bracket) - eps
else:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
t = min(bracket) + eps
insuf_progress = False
else:
@@ -160,45 +160,45 @@ def _strong_wolfe(
gtd_new = g_new.dot(d)
ls_iter += 1
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
# Armijo condition not satisfied or not lower than lowest point
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
bracket[high_pos] = t
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[high_pos] = f_new
bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_gtd[high_pos] = gtd_new
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
else:
if abs(gtd_new) <= -c2 * gtd:
# Wolfe conditions satisfied
done = True
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
# old high becomes new low
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
bracket[high_pos] = bracket[low_pos]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[high_pos] = bracket_f[low_pos]
bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_gtd[high_pos] = bracket_gtd[low_pos]
# new point becomes new low
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
bracket[low_pos] = t
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[low_pos] = f_new
bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_gtd[low_pos] = gtd_new
# return stuff
t = bracket[low_pos] # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
f_new = bracket_f[low_pos]
g_new = bracket_g[low_pos] # type: ignore[possibly-undefined]
return f_new, g_new, t, ls_func_evals
@@ -276,7 +276,7 @@ class LBFGS(Optimizer):
def _numel(self):
if self._numel_cache is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self._numel_cache = sum(
2 * p.numel() if torch.is_complex(p) else p.numel()
for p in self._params


@@ -422,7 +422,7 @@ class LambdaLR(LRScheduler):
for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
return state_dict
@@ -542,7 +542,7 @@ class MultiplicativeLR(LRScheduler):
for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
return state_dict
@@ -1219,7 +1219,7 @@ class SequentialLR(LRScheduler):
state_dict["_schedulers"] = [None] * len(self._schedulers)
for idx, s in enumerate(self._schedulers):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
state_dict["_schedulers"][idx] = s.state_dict()
return state_dict
@@ -1562,7 +1562,7 @@ class ChainedScheduler(LRScheduler):
state_dict["_schedulers"] = [None] * len(self._schedulers)
for idx, s in enumerate(self._schedulers):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
state_dict["_schedulers"][idx] = s.state_dict()
return state_dict
@@ -1671,7 +1671,7 @@ class ReduceLROnPlateau(LRScheduler):
self.default_min_lr = None
self.min_lrs = list(min_lr)
else:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.default_min_lr = min_lr
self.min_lrs = [min_lr] * len(optimizer.param_groups)
@@ -1731,7 +1731,7 @@ class ReduceLROnPlateau(LRScheduler):
"of the `optimizer` param groups."
)
else:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)
for i, param_group in enumerate(self.optimizer.param_groups):
@@ -1911,13 +1911,13 @@ class CyclicLR(LRScheduler):
self.max_lrs = _format_param("max_lr", optimizer, max_lr)
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
step_size_up = float(step_size_up)
step_size_down = (
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
float(step_size_down) if step_size_down is not None else step_size_up
)
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
self.total_size = step_size_up + step_size_down
self.step_ratio = step_size_up / self.total_size


@@ -62,7 +62,7 @@ def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
import torch._dynamo
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
self = cast(Optimizer, args[0]) # assume first positional arg is `self`
prev_grad = torch.is_grad_enabled()
try:
@@ -136,13 +136,13 @@ def _disable_dynamo_if_unsupported(
if torch.compiler.is_compiling() and (
not kwargs.get("capturable", False)
and has_state_steps
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
and (arg := args[state_steps_ind])
and isinstance(arg, Sequence)
and arg[0].is_cuda
or (
"state_steps" in kwargs
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
and (kwarg := kwargs["state_steps"])
and isinstance(kwarg, Sequence)
and kwarg[0].is_cuda
@@ -362,18 +362,18 @@ class Optimizer:
_optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: dict[int, OptimizerPostHook]
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: (
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_pre_hooks: (
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_post_hooks: (
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
'OrderedDict[int, Callable[["Optimizer"], None]]'
)
@@ -522,7 +522,7 @@ class Optimizer:
f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
)
-# pyrefly: ignore # invalid-param-spec
+# pyrefly: ignore [invalid-param-spec]
out = func(*args, **kwargs)
self._optimizer_step_code()
@@ -961,9 +961,9 @@ class Optimizer:
return Optimizer._process_value_according_to_param_policy(
param,
value,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
param_id,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
param_groups,
key,
)
@@ -976,7 +976,7 @@ class Optimizer:
}
elif isinstance(value, Iterable):
return type(value)(
-# pyrefly: ignore # bad-argument-count
+# pyrefly: ignore [bad-argument-count]
_cast(param, v, param_id=param_id, param_groups=param_groups)
for v in value
) # type: ignore[call-arg]


@@ -323,7 +323,7 @@ def _single_tensor_radam(
rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2
def _compute_rect():
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
return (
(rho_t - 4)
* (rho_t - 2)
@@ -338,7 +338,7 @@ def _single_tensor_radam(
else:
exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
return (bias_correction2**0.5) / exp_avg_sq_sqrt
# Compute the variance rectification term and update parameters accordingly


@@ -348,7 +348,7 @@ def _single_tensor_sgd(
# usually this is the differentiable path, which is why the param.clone() is needed
grad = grad.addcmul_(param.clone(), weight_decay)
else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
@@ -372,7 +372,7 @@ def _single_tensor_sgd(
if lr.requires_grad:
param.addcmul_(grad, lr, value=-1)
else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
param.add_(grad, alpha=-lr)
else:
param.add_(grad, alpha=-lr)


@@ -250,13 +250,13 @@ class AveragedModel(Module):
def update_parameters(self, model: Module):
"""Update model parameters."""
self_param = (
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
itertools.chain(self.module.parameters(), self.module.buffers())
if self.use_buffers
else self.parameters()
)
model_param = (
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
itertools.chain(model.parameters(), model.buffers())
if self.use_buffers
else model.parameters()
@@ -298,17 +298,17 @@ class AveragedModel(Module):
avg_fn = get_swa_avg_fn()
n_averaged = self.n_averaged.to(device)
for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
else:
for p_averaged, p_model in zip( # type: ignore[assignment]
self_param_detached, model_param_detached
):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
n_averaged = self.n_averaged.to(p_averaged.device)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
p_averaged.detach().copy_(
-# pyrefly: ignore # missing-attribute, bad-argument-type
+# pyrefly: ignore [missing-attribute, bad-argument-type]
self.avg_fn(p_averaged.detach(), p_model, n_averaged)
)
@@ -497,14 +497,14 @@ class SWALR(LRScheduler):
step = self._step_count - 1
if self.anneal_epochs == 0:
step = max(1, step)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
prev_alpha = self.anneal_func(prev_t)
prev_lrs = [
self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
for group in self.optimizer.param_groups
]
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
t = max(0, min(1, step / max(1, self.anneal_epochs)))
alpha = self.anneal_func(t)
return [


@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
-# pyrefly: ignore # missing-module-attribute
+# pyrefly: ignore [missing-module-attribute]
from pickle import ( # type: ignore[attr-defined]
_compat_pickle,
_extension_registry,

View File

@@ -3,7 +3,7 @@ import importlib
import logging
from abc import ABC, abstractmethod
- # pyrefly: ignore # missing-module-attribute
+ # pyrefly: ignore [missing-module-attribute]
from pickle import ( # type: ignore[attr-defined]
_getattribute,
_Pickler,

View File

@@ -652,7 +652,7 @@ class PackageExporter:
memo: defaultdict[int, str] = defaultdict(None)
memo_count = 0
# pickletools.dis(data_value)
- # pyrefly: ignore # bad-assignment
+ # pyrefly: ignore [bad-assignment]
for opcode, arg, _pos in pickletools.genops(data_value):
if pickle_protocol == 4:
if (

View File

@@ -230,7 +230,7 @@ class SchemaMatcher:
for schema in cls.match_schemas(t):
mutable = mutable or [False for _ in schema.arguments]
for i, arg in enumerate(schema.arguments):
- # pyrefly: ignore # unsupported-operation
+ # pyrefly: ignore [unsupported-operation]
mutable[i] |= getattr(arg.alias_info, "is_write", False)
return tuple(mutable or (None for _ in t.inputs))
@@ -1084,7 +1084,7 @@ class MemoryProfileTimeline:
if action in (Action.PREEXISTING, Action.CREATE):
raw_events.append(
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
(
t,
_ACTION_TO_INDEX[action],
@@ -1095,7 +1095,7 @@ class MemoryProfileTimeline:
elif action == Action.INCREMENT_VERSION:
raw_events.append(
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
(
t,
_ACTION_TO_INDEX[action],
@@ -1104,7 +1104,7 @@ class MemoryProfileTimeline:
)
)
raw_events.append(
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
(
t,
_ACTION_TO_INDEX[action],
@@ -1115,7 +1115,7 @@ class MemoryProfileTimeline:
elif action == Action.DESTROY:
raw_events.append(
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
(
t,
_ACTION_TO_INDEX[action],

View File

@@ -211,7 +211,7 @@ class BasicEvaluation:
# Find latest cuda kernel event
if hasattr(event, "start_us"):
start_time = event.start_us() * 1000
- # pyrefly: ignore # missing-attribute
+ # pyrefly: ignore [missing-attribute]
end_time = (event.start_us() + event.duration_us()) * 1000
# Find current spawned cuda kernel event
if event in kernel_mapping and kernel_mapping[event] is not None:

View File

@@ -161,19 +161,19 @@ class _KinetoProfile:
self.mem_tl: Optional[MemoryProfileTimeline] = None
self.use_device = None
if ProfilerActivity.CUDA in self.activities:
- # pyrefly: ignore # bad-assignment
+ # pyrefly: ignore [bad-assignment]
self.use_device = "cuda"
elif ProfilerActivity.XPU in self.activities:
- # pyrefly: ignore # bad-assignment
+ # pyrefly: ignore [bad-assignment]
self.use_device = "xpu"
elif ProfilerActivity.MTIA in self.activities:
- # pyrefly: ignore # bad-assignment
+ # pyrefly: ignore [bad-assignment]
self.use_device = "mtia"
elif ProfilerActivity.HPU in self.activities:
- # pyrefly: ignore # bad-assignment
+ # pyrefly: ignore [bad-assignment]
self.use_device = "hpu"
elif ProfilerActivity.PrivateUse1 in self.activities:
- # pyrefly: ignore # bad-assignment
+ # pyrefly: ignore [bad-assignment]
self.use_device = _get_privateuse1_backend_name()
# user-defined metadata to be amended to the trace
@@ -385,7 +385,7 @@ class _KinetoProfile:
}
if backend == "nccl":
nccl_version = torch.cuda.nccl.version()
- # pyrefly: ignore # unsupported-operation
+ # pyrefly: ignore [unsupported-operation]
dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
return dist_info

View File

@@ -71,7 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
nrows // 16, 16
)
).view(-1)
- # pyrefly: ignore # unbound-name
+ # pyrefly: ignore [unbound-name]
outp = outp.index_copy(1, cols_permuted, outp)
# interleave_column_major_tensor

View File

@@ -790,7 +790,7 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
# PyTorchFileWriter only supports ascii filename.
# For filenames with non-ascii characters, we rely on Python
# for writing out the file.
- # pyrefly: ignore # bad-assignment
+ # pyrefly: ignore [bad-assignment]
self.file_stream = io.FileIO(self.name, mode="w")
super().__init__(
torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload

View File

@@ -397,15 +397,15 @@ def kaiser(
)
# Avoid NaNs by casting `beta` to the appropriate dtype.
- # pyrefly: ignore # bad-assignment
+ # pyrefly: ignore [bad-assignment]
beta = torch.tensor(beta, dtype=dtype, device=device)
start = -beta
constant = 2.0 * beta / (M if not sym else M - 1)
end = torch.minimum(
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
beta,
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
start + (M - 1) * constant,
)
@@ -420,7 +420,7 @@ def kaiser(
)
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
beta
)

View File

@@ -623,20 +623,20 @@ def as_sparse_gradcheck(gradcheck):
)
obj = obj.to_dense().sparse_mask(full_mask)
if obj.layout is torch.sparse_coo:
- # pyrefly: ignore # no-matching-overload
+ # pyrefly: ignore [no-matching-overload]
d.update(
indices=obj._indices(), is_coalesced=obj.is_coalesced()
)
values = obj._values()
elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
- # pyrefly: ignore # no-matching-overload
+ # pyrefly: ignore [no-matching-overload]
d.update(
compressed_indices=obj.crow_indices(),
plain_indices=obj.col_indices(),
)
values = obj.values()
else:
- # pyrefly: ignore # no-matching-overload
+ # pyrefly: ignore [no-matching-overload]
d.update(
compressed_indices=obj.ccol_indices(),
plain_indices=obj.row_indices(),

View File

@@ -140,7 +140,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
- # pyrefly: ignore # unbound-name
+ # pyrefly: ignore [unbound-name]
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
@@ -173,7 +173,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device
)
- # pyrefly: ignore # unbound-name
+ # pyrefly: ignore [unbound-name]
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
return (sparse, meta_reordered.view(m, meta_ncols))

View File

@@ -67,7 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
# Because we cannot go from the compressed representation back to the dense representation currently,
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
- # pyrefly: ignore # no-matching-overload
+ # pyrefly: ignore [no-matching-overload]
return self.__class__(
torch.Size([self.shape[-1], self.shape[0]]),
packed=self.packed_t,

View File

@@ -1297,7 +1297,7 @@ def bsr_dense_addmm(
assert alpha != 0
def kernel(grid, *sliced_tensors):
- # pyrefly: ignore # unsupported-operation
+ # pyrefly: ignore [unsupported-operation]
_bsr_strided_addmm_kernel[grid](
*ptr_stride_extractor(*sliced_tensors),
beta,
@@ -1427,7 +1427,7 @@ if has_triton():
mat1_block = tl.load(
mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],
- # pyrefly: ignore # index-error
+ # pyrefly: ignore [index-error]
mask=mask_k[None, :],
other=0.0,
)
@@ -1436,7 +1436,7 @@ if has_triton():
mat2_block_ptrs
+ mat2_tiled_col_stride * col_block
+ mat2_row_block_stride * k_offsets[:, None],
- # pyrefly: ignore # index-error
+ # pyrefly: ignore [index-error]
mask=mask_k[:, None],
other=0.0,
)
@@ -1974,7 +1974,7 @@ if has_triton():
if attn_mask.dtype is not torch.bool:
check_dtype(f_name, attn_mask, query.dtype)
- # pyrefly: ignore # not-callable
+ # pyrefly: ignore [not-callable]
sdpa = sampled_addmm(
attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False
)
@@ -1986,10 +1986,10 @@ if has_triton():
)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
sdpa.values().mul_(scale_factor)
- # pyrefly: ignore # not-callable
+ # pyrefly: ignore [not-callable]
sdpa = bsr_softmax(sdpa)
torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
- # pyrefly: ignore # not-callable
+ # pyrefly: ignore [not-callable]
sdpa = bsr_dense_mm(sdpa, value)
return sdpa

View File

@@ -234,10 +234,10 @@ def dump():
part2 = current_content[end_data_index:]
data_part = []
for op_key in sorted(_operation_device_version_data, key=sort_key):
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
data_part.append(" " + repr(op_key).replace("'", '"') + ": {")
op_data = _operation_device_version_data[op_key]
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
data_part.extend(f" {key}: {op_data[key]}," for key in sorted(op_data))
data_part.append(" },")
new_content = part1 + "\n".join(data_part) + "\n" + part2
@@ -371,7 +371,7 @@ def minimize(
if next_target < minimal_target:
minimal_target = next_target
parameters = next_parameters
- # pyrefly: ignore # unsupported-operation
+ # pyrefly: ignore [unsupported-operation]
pbar.total += i + 1
break
else:

View File

@@ -185,7 +185,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
outer_stride,
) -> torch.Tensor:
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
- # pyrefly: ignore # no-matching-overload
+ # pyrefly: ignore [no-matching-overload]
return cls(
shape=shape,
packed=inner_tensors.get("packed", None),
@@ -415,7 +415,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
sparse_tensor_cutlass,
meta_tensor_cutlass,
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
- # pyrefly: ignore # no-matching-overload
+ # pyrefly: ignore [no-matching-overload]
return cls(
original_tensor.shape,
packed=sparse_tensor_cutlass,
@@ -502,7 +502,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
original_tensor, algorithm=algorithm, use_cutlass=True
)
- # pyrefly: ignore # no-matching-overload
+ # pyrefly: ignore [no-matching-overload]
return cls(
original_tensor.shape,
packed=packed,
@@ -564,7 +564,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
- # pyrefly: ignore # no-matching-overload
+ # pyrefly: ignore [no-matching-overload]
return cls(
shape=original_tensor.shape,
packed=torch._cslt_compress(original_tensor),
@@ -631,7 +631,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
packed = packed.view(original_tensor.shape[0], -1)
packed_t = packed_t.view(original_tensor.shape[1], -1)
- # pyrefly: ignore # no-matching-overload
+ # pyrefly: ignore [no-matching-overload]
return cls(
original_tensor.shape,
packed=packed,

View File

@@ -2,6 +2,6 @@ from torch._C import FileCheck as FileCheck
from . import _utils
- # pyrefly: ignore # deprecated
+ # pyrefly: ignore [deprecated]
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor

View File

@@ -243,7 +243,7 @@ def make_scalar_mismatch_msg(
Defaults to "Scalars".
"""
abs_diff = abs(actual - expected)
- # pyrefly: ignore # bad-argument-type
+ # pyrefly: ignore [bad-argument-type]
rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
return _make_mismatch_msg(
default_identifier="Scalars",
@@ -487,7 +487,7 @@ class BooleanPair(Pair):
def _supported_types(self) -> tuple[type, ...]:
cls: list[type] = [bool]
if HAS_NUMPY:
- # pyrefly: ignore # missing-attribute
+ # pyrefly: ignore [missing-attribute]
cls.append(np.bool_)
return tuple(cls)
@@ -503,7 +503,7 @@ class BooleanPair(Pair):
def _to_bool(self, bool_like: Any, *, id: tuple[Any, ...]) -> bool:
if isinstance(bool_like, bool):
return bool_like
- # pyrefly: ignore # missing-attribute
+ # pyrefly: ignore [missing-attribute]
elif isinstance(bool_like, np.bool_):
return bool_like.item()
else:
@@ -583,7 +583,7 @@ class NumberPair(Pair):
def _supported_types(self) -> tuple[type, ...]:
cls = list(self._NUMBER_TYPES)
if HAS_NUMPY:
- # pyrefly: ignore # missing-attribute
+ # pyrefly: ignore [missing-attribute]
cls.append(np.number)
return tuple(cls)
@@ -599,7 +599,7 @@ class NumberPair(Pair):
def _to_number(
self, number_like: Any, *, id: tuple[Any, ...]
) -> Union[int, float, complex]:
- # pyrefly: ignore # missing-attribute
+ # pyrefly: ignore [missing-attribute]
if HAS_NUMPY and isinstance(number_like, np.number):
return number_like.item()
elif isinstance(number_like, self._NUMBER_TYPES):
@@ -1122,7 +1122,7 @@ def originate_pairs(
mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
id: tuple[Any, ...] = (),
**options: Any,
- # pyrefly: ignore # bad-return
+ # pyrefly: ignore [bad-return]
) -> list[Pair]:
"""Originates pairs from the individual inputs.
@@ -1221,7 +1221,7 @@ def originate_pairs(
else:
for pair_type in pair_types:
try:
- # pyrefly: ignore # bad-instantiation
+ # pyrefly: ignore [bad-instantiation]
return [pair_type(actual, expected, id=id, **options)]
# Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
# inputs. Thus, we try the next pair type.
@@ -1319,9 +1319,9 @@ def not_close_error_metas(
# would not get freed until cycle collection, leaking cuda memory in tests.
# We break the cycle by removing the reference to the error_meta objects
# from this frame as it returns.
- # pyrefly: ignore # bad-assignment
+ # pyrefly: ignore [bad-assignment]
error_metas = [error_metas]
- # pyrefly: ignore # bad-return
+ # pyrefly: ignore [bad-return]
return error_metas.pop()