Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
Fix syntax for pyrefly errors (#166496)
Last one! This ensures all existing suppressions match the expected syntax and will silence only one error code. Verified with pyrefly check and lintrunner.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166496
Approved by: https://github.com/Skylion007, https://github.com/mlazos
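The change is mechanical: every suppression that spelled the error code after a second "#" is rewritten so the code sits in square brackets, the form pyrefly expects when the comment should silence only the named error code. A minimal sketch of the two forms on a hypothetical line (the variable and call are illustrative, not taken from the diff):

# old form, replaced throughout this commit
x = compute()  # pyrefly: ignore # bad-assignment
# new form, scoped to the single named error code
x = compute()  # pyrefly: ignore [bad-assignment]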
This commit is contained in:
parent fa560e1158
commit d1a6e006e0
@@ -841,7 +841,7 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr:
factor = functools.reduce(math.gcd, map(integer_coefficient, atoms))
if factor == 1:
return expr
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
atoms = [div_by_factor(x, factor) for x in atoms]
return _sympy_from_args(
sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative
@@ -2207,7 +2207,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext):
def __post_init__(self) -> None:
super().__post_init__()
if self.inner_contexts is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.inner_contexts = {}

@@ -2296,12 +2296,12 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
# only re-create the objects if any of the args changed to avoid expensive
# checks when re-creating objects.
new_args = [_fast_expand(arg) for arg in expr.args]  # type: ignore[arg-type]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
return _fast_expand(expr.func(*new_args))

-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if expr.is_Pow:
base: sympy.Expr
exp: sympy.Expr
@@ -2311,11 +2311,11 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
return sympy.expand_multinomial(expr, deep=False)
elif exp < 0:
return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
elif expr.is_Mul:
num: list[sympy.Expr] = []
den: list[sympy.Expr] = []
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for arg in expr.args:
if arg.is_Pow and arg.args[1] == -1:
den.append(S.One / arg)  # type: ignore[operator, arg-type]
@@ -2437,7 +2437,7 @@ def _maybe_evaluate_static_worker(

# TODO: remove this try catch (esp for unbacked_only)
try:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
new_expr = expr.xreplace(new_shape_env)
except RecursionError:
log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
@@ -2975,19 +2975,19 @@ class DimConstraints:
# is_integer tests though haha
return (base - mod_reduced) / divisor

-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if expr.has(Mod):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
expr = expr.replace(Mod, mod_handler)
# 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative
# arguments should be OK.
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if expr.has(PythonMod):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
expr = expr.replace(PythonMod, mod_handler)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if expr.has(FloorDiv):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
expr = expr.replace(FloorDiv, floor_div_handler)
return expr

@@ -5106,7 +5106,7 @@ class ShapeEnv:

if duck:
# Make sure to reuse this symbol for subsequent duck shaping
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
self.val_to_var[val] = sympy_expr

if isinstance(val, int):
@@ -5338,9 +5338,9 @@ class ShapeEnv:

# Expand optional inputs, or verify invariants are upheld
if input_contexts is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
input_contexts = [
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_create_no_constraints_context(t) if isinstance(t, Tensorlike) else None
for t in placeholders
]
@@ -5350,7 +5350,7 @@ class ShapeEnv:
for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
if isinstance(t, Tensorlike):
if context is None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
input_contexts[i] = _create_no_constraints_context(t)
else:
assert isinstance(t, (SymInt, int, SymFloat, float))
@@ -5636,7 +5636,7 @@ class ShapeEnv:
s = sympy.Float(val)
input_guards.append((source, s))

-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
for t, source, context in zip(placeholders, sources, input_contexts):
if isinstance(source, str):
from torch._dynamo.source import LocalSource
@@ -5999,7 +5999,7 @@ class ShapeEnv:
else:
str_msg = f" - {msg_cb()}"
error_msgs.append(str_msg)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
debug_names.add(debug_name)
if len(error_msgs) > 0:
debug_names_str = ", ".join(sorted(debug_names))
@@ -6133,7 +6133,7 @@ class ShapeEnv:
Get a list of guards, but pruned so it only provides guards that
reference symints from the passed in input
"""
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
symints = {
s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
}
@@ -6396,7 +6396,7 @@ class ShapeEnv:
Apply symbol replacements to any symbols in the given expression.
"""
replacements = {}
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for s in expr.free_symbols:
r = self._find(s)

@@ -6406,7 +6406,7 @@ class ShapeEnv:
if not r.is_Symbol or r != s:
replacements[s] = r
if replacements:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
return safe_expand(expr.xreplace(replacements))
else:
return expr
@@ -7181,7 +7181,7 @@ class ShapeEnv:
instructions = list(dis.Bytecode(frame.f_code))
co_lines, offset = inspect.getsourcelines(frame.f_code)
start, end, cur = None, None, None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
for i, instr in enumerate(instructions):
if instr.starts_line is not None:
cur = instr.starts_line
@@ -238,7 +238,7 @@ class Dispatcher:
"To use a variadic union type place the desired types "
"inside of a tuple, e.g., [(int, str)]"
)
-# pyrefly: ignore # bad-specialization
+# pyrefly: ignore [bad-specialization]
new_signature.append(Variadic[typ[0]])
else:
new_signature.append(typ)
@@ -407,7 +407,7 @@ class MethodDispatcher(Dispatcher):
Dispatcher
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
__slots__ = ("obj", "cls")

@classmethod
@@ -298,7 +298,7 @@ def update_in(d, keys, func, default=None, factory=dict):
rv = inner = factory()
rv.update(d)

-# pyrefly: ignore # not-iterable
+# pyrefly: ignore [not-iterable]
for key in ks:
if k in d:
d = d[k]
@@ -1380,7 +1380,7 @@ class Graph:
f(to_erase)

self._find_nodes_lookup_table.remove(to_erase)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
to_erase._remove_from_list()
to_erase._erased = True  # iterators may retain handles to erased nodes
self._len -= 1
@@ -1941,7 +1941,7 @@ class Graph:
"a str is expected"
)
if node.op in ["get_attr", "call_module"]:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
target_atoms = node.target.split(".")
m_itr = self.owning_module
for i, atom in enumerate(target_atoms):
@@ -535,7 +535,7 @@ class GraphModule(torch.nn.Module):
self.graph._tracer_cls
and "<locals>" not in self.graph._tracer_cls.__qualname__
):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self._tracer_cls = self.graph._tracer_cls

self._tracer_extras = {}
@@ -165,12 +165,12 @@ def tensorify_python_scalars(

node = graph.call_function(
torch.ops.aten.scalar_tensor.default,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
(c,),
{"dtype": dtype},
)
with fake_mode:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
node.meta["val"] = torch.ops.aten.scalar_tensor.default(c, dtype=dtype)
expr_to_tensor_proxy[expr] = MetaProxy(
node,
@@ -223,13 +223,13 @@ def tensorify_python_scalars(
expr_to_sym_proxy[s] = MetaProxy(
node, tracer=tracer, fake_mode=fake_mode
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
elif (sym_expr := _get_sym_val(node)) is not None:
if sym_expr not in expr_to_sym_proxy and not isinstance(
sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
):
expr_to_sym_proxy[sym_expr] = MetaProxy(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
node,
tracer=tracer,
fake_mode=fake_mode,
@@ -238,7 +238,7 @@ def tensorify_python_scalars(
# Specialize all dimensions that contain symfloats. Here's
# an example test that requires this:
# PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 # noqa: B950
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
val = node.meta.get("val")
if isinstance(val, FakeTensor):
for dim in val.shape:
@@ -257,17 +257,17 @@ def tensorify_python_scalars(
should_restart = True

# Look for functions to convert
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if node.op == "call_function" and (
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
replacement_op := SUPPORTED_OPS.get(node.target)
):
args: list[Any] = []
transform = False
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
compute_dtype = get_computation_dtype(node.meta["val"].dtype)

-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for a in node.args:
if (
isinstance(a, fx.Node)
@@ -304,7 +304,7 @@ def tensorify_python_scalars(
if transform:
replacement_proxy = replacement_op(*args)

-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if compute_dtype != node.meta["val"].dtype:
replacement_proxy = (
torch.ops.prims.convert_element_type.default(
@@ -313,9 +313,9 @@ def tensorify_python_scalars(
)
)

-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
node.replace_all_uses_with(replacement_proxy.node)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
graph.erase_node(node)

metrics_context = get_metrics_context()
@@ -324,16 +324,16 @@ def tensorify_python_scalars(
"tensorify_float_success", True, overwrite=True
)
else:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for a in node.args:
if (
isinstance(a, fx.Node)
and "val" in a.meta
and isinstance(zf := a.meta["val"], torch.SymFloat)
):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
failed_tensorify_ops.update(str(node.target))
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
log.info("Failed to tensorify %s", str(node.target))

# Now do one more pass that specializes all symfloats we didn't manage
@@ -437,13 +437,13 @@ if HAS_PYDOT:
)
current_graph = buf_name_to_subgraph.get(buf_name)  # type: ignore[assignment]

-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
current_graph.add_node(dot_node)

def get_module_params_or_buffers():
for pname, ptensor in chain(
leaf_module.named_parameters(),
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
leaf_module.named_buffers(),
):
pname1 = node.name + "." + pname
@@ -11,7 +11,7 @@ __all__ = ["PassResult", "PassBase"]


@compatibility(is_backward_compatible=False)
-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
"""
Result of a pass:
@@ -31,7 +31,7 @@ def pass_result_wrapper(fn: Callable) -> Callable:
wrapped_fn (Callable[Module, PassResult])
"""
if fn is None:
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return None

@wraps(fn)
@@ -396,25 +396,25 @@ class _MinimizerBase:
report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
if self.module_exporter:
if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
result_key = result_key[-1]
# If the result is still a tuple (happens in non-sequential mode),
# we only use the first element as name.
if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
result_key = str(result_key[0])
# pyre-ignore[29]: not a function
self.module_exporter(
a_input,
submodule,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
result_key + "_cpu",
)
# pyre-ignore[29]: not a function
self.module_exporter(
b_input,
submodule,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
result_key + "_acc",
)
raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
@@ -360,7 +360,7 @@ def insert_deferred_runtime_asserts(
):
# this guards against deleting calls like item() that produce new untracked symbols
def has_new_untracked_symbols():
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for symbol in sym_expr.free_symbols:
if symbol not in expr_to_proxy:
return True
@@ -376,7 +376,7 @@ def insert_deferred_runtime_asserts(
assert resolved_unbacked_bindings is not None

def has_new_unbacked_bindings():
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
for key in resolved_unbacked_bindings.keys():
if key not in expr_to_proxy:
return True
@@ -351,9 +351,9 @@ def split_module(

assert all(v is not None for v in autocast_exits.values()), "autocast must exit"

-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
grad_regions = {k: sorted(v) for k, v in grad_regions.items()}

if _LOGGER.isEnabledFor(logging.DEBUG):
@@ -418,9 +418,9 @@ def split_module(
for regions_mapping in [autocast_regions, grad_regions]:
for node, regions in regions_mapping.items():
assert len(regions) > 0
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
partitions[str(regions[0])].environment[node] = node
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
for r in regions[1:]:
partition = partitions[str(r)]
new_node = partition.graph.create_node(
@@ -520,7 +520,7 @@ def split_module(
for node in reversed(regions_mapping):
regions = regions_mapping[node]
assert len(regions) > 0
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
for r in regions[:-1]:
partition = partitions[str(r)]
exit_node = autocast_exits[node]
@@ -64,7 +64,7 @@ def lift_subgraph_as_module(

for name in target_name_parts[:-1]:
if not hasattr(curr, name):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
curr.add_module(name, HolderModule({}))

curr = getattr(curr, name)
@@ -242,7 +242,7 @@ class Library:

if dispatch_key == "":
dispatch_key = self.dispatch_key
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)

if isinstance(op_name, str):
@@ -484,7 +484,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
raise IndexError(
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
dims.append(d % ndim)
return tuple(sorted(dims))

@@ -1017,7 +1017,7 @@ def _combine_input_and_mask(

class Combine(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input, mask):
"""Return input with masked-out elements eliminated for the given operations."""
ctx.save_for_backward(mask)
@@ -1028,7 +1028,7 @@ def _combine_input_and_mask(
return helper(input, mask)

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
(mask,) = ctx.saved_tensors
grad_data = (
@@ -1403,18 +1403,18 @@ elements, have ``nan`` values.
if input.layout == torch.strided:
if mask is None:
# TODO: compute count analytically
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
count = sum(
torch.ones(input.shape, dtype=torch.int64, device=input.device),
dim,
keepdim=keepdim,
)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
total = sum(input, dim, keepdim=keepdim, dtype=dtype)
else:
inmask = _input_mask(input, mask=mask)
count = inmask.sum(dim=dim, keepdim=bool(keepdim))
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
return total / count
elif input.layout == torch.sparse_csr:
@@ -1625,18 +1625,18 @@ def _std_var(
if input.layout == torch.strided:
if mask is None:
# TODO: compute count analytically
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
count = sum(
torch.ones(input.shape, dtype=torch.int64, device=input.device),
dim,
keepdim=True,
)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
sample_total = sum(input, dim, keepdim=True, dtype=dtype)
else:
inmask = _input_mask(input, mask=mask)
count = inmask.sum(dim=dim, keepdim=True)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
# TODO: replace torch.subtract/divide/square/maximum with
# masked subtract/divide/square/maximum when these will be
@@ -1644,7 +1644,7 @@ def _std_var(
sample_mean = torch.divide(sample_total, count)
x = torch.subtract(input, sample_mean)
if mask is None:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
else:
total = sum(
@@ -47,7 +47,7 @@ def _check_args_kwargs_length(

class _MaskedContiguous(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
@@ -61,14 +61,14 @@ class _MaskedContiguous(torch.autograd.Function):
return MaskedTensor(data.contiguous(), mask.contiguous())

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return grad_output


class _MaskedToDense(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
@@ -83,7 +83,7 @@ class _MaskedToDense(torch.autograd.Function):
return MaskedTensor(data.to_dense(), mask.to_dense())

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
layout = ctx.layout

@@ -98,7 +98,7 @@ class _MaskedToDense(torch.autograd.Function):

class _MaskedToSparse(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
@@ -115,14 +115,14 @@ class _MaskedToSparse(torch.autograd.Function):
return MaskedTensor(sparse_data, sparse_mask)

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return grad_output.to_dense()


class _MaskedToSparseCsr(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input):
if not is_masked_tensor(input):
raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
@@ -143,21 +143,21 @@ class _MaskedToSparseCsr(torch.autograd.Function):
return MaskedTensor(sparse_data, sparse_mask)

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return grad_output.to_dense()


class _MaskedWhere(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, cond, self, other):
ctx.mark_non_differentiable(cond)
ctx.save_for_backward(cond)
return torch.ops.aten.where(cond, self, other)

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
(cond,) = ctx.saved_tensors

@@ -174,7 +174,7 @@ class MaskedTensor(torch.Tensor):
UserWarning,
stacklevel=2,
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)

def _preprocess_data(self, data, mask):
@@ -244,12 +244,12 @@ class MaskedTensor(torch.Tensor):

class Constructor(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, data, mask):
return MaskedTensor(data, mask)

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return grad_output, None

@@ -336,12 +336,12 @@ class MaskedTensor(torch.Tensor):
def get_data(self):
class GetData(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, self):
return self._masked_data.detach()

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
if is_masked_tensor(grad_output):
return grad_output
@@ -114,7 +114,7 @@ class ProcessContext:
"""Attempt to join all processes with a shared timeout."""
end = time.monotonic() + timeout
for process in self.processes:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
time_to_wait = max(0, end - time.monotonic())
process.join(time_to_wait)
@@ -51,7 +51,7 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
# Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
# For other dims, subtract 1 to convert to inner space.
return (
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
ragged_dim - 1 if dim == 0 else dim - 1
)

@@ -2008,7 +2008,7 @@ def index_put_(func, *args, **kwargs):
else:
lengths = inp.lengths()
torch._assert_async(
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
torch.all(indices[inp._ragged_idx] < lengths),
"Some indices in the ragged dimension are out of bounds!",
)
@@ -668,5 +668,5 @@ def _get_numa_node_indices_for_socket_index(*, socket_index: int) -> set[int]:

def _get_allowed_cpu_indices_for_current_thread() -> set[int]:
# 0 denotes current thread
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
return os.sched_getaffinity(0) # type:ignore[attr-defined]
@@ -251,7 +251,7 @@ def _compare_onnx_pytorch_outputs_in_np(
# pyrefly: ignore [missing-attribute]
if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8:
warnings.warn("ONNX output is quantized", stacklevel=2)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8:
warnings.warn("PyTorch output is quantized", stacklevel=2)
raise
@@ -78,7 +78,7 @@ def _adjust_lr(
A, B = param_shape[:2]

if adjust_lr_fn is None or adjust_lr_fn == "original":
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
adjusted_ratio = math.sqrt(max(1, A / B))
elif adjust_lr_fn == "match_rms_adamw":
adjusted_ratio = 0.2 * math.sqrt(max(A, B))
@@ -423,7 +423,7 @@ def _single_tensor_adam(
if weight_decay.requires_grad:
grad = grad.addcmul_(param.clone(), weight_decay)
else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
@@ -264,7 +264,7 @@ def _single_tensor_asgd(
ax.copy_(param)

if capturable:
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
else:
@@ -113,16 +113,16 @@ def _strong_wolfe(

# compute new trial value
t = _cubic_interpolate(
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
bracket[0],
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[0],
bracket_gtd[0], # type: ignore[possibly-undefined]
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
bracket[1],
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[1],
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_gtd[1],
)

@@ -133,20 +133,20 @@ def _strong_wolfe(
# + `t` is at one of the boundary,
# we will move `t` to a position which is `0.1 * len(bracket)`
# away from the nearest boundary point.
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
eps = 0.1 * (max(bracket) - min(bracket))
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
if min(max(bracket) - t, t - min(bracket)) < eps:
# interpolation close to boundary
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
if insuf_progress or t >= max(bracket) or t <= min(bracket):
# evaluate at 0.1 away from boundary
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
if abs(t - max(bracket)) < abs(t - min(bracket)):
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
t = max(bracket) - eps
else:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
t = min(bracket) + eps
insuf_progress = False
else:
@@ -160,45 +160,45 @@ def _strong_wolfe(
gtd_new = g_new.dot(d)
ls_iter += 1

-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
# Armijo condition not satisfied or not lower than lowest point
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
bracket[high_pos] = t
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[high_pos] = f_new
bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_gtd[high_pos] = gtd_new
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
else:
if abs(gtd_new) <= -c2 * gtd:
# Wolfe conditions satisfied
done = True
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
# old high becomes new low
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
bracket[high_pos] = bracket[low_pos]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[high_pos] = bracket_f[low_pos]
bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_gtd[high_pos] = bracket_gtd[low_pos]

# new point becomes new low
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
bracket[low_pos] = t
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_f[low_pos] = f_new
bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
bracket_gtd[low_pos] = gtd_new

# return stuff
t = bracket[low_pos] # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
f_new = bracket_f[low_pos]
g_new = bracket_g[low_pos] # type: ignore[possibly-undefined]
return f_new, g_new, t, ls_func_evals
@@ -276,7 +276,7 @@ class LBFGS(Optimizer):

def _numel(self):
if self._numel_cache is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self._numel_cache = sum(
2 * p.numel() if torch.is_complex(p) else p.numel()
for p in self._params
@@ -422,7 +422,7 @@ class LambdaLR(LRScheduler):

for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()

return state_dict
@@ -542,7 +542,7 @@ class MultiplicativeLR(LRScheduler):

for idx, fn in enumerate(self.lr_lambdas):
if not isinstance(fn, types.FunctionType):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
state_dict["lr_lambdas"][idx] = fn.__dict__.copy()

return state_dict
@@ -1219,7 +1219,7 @@ class SequentialLR(LRScheduler):
state_dict["_schedulers"] = [None] * len(self._schedulers)

for idx, s in enumerate(self._schedulers):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
state_dict["_schedulers"][idx] = s.state_dict()

return state_dict
@@ -1562,7 +1562,7 @@ class ChainedScheduler(LRScheduler):
state_dict["_schedulers"] = [None] * len(self._schedulers)

for idx, s in enumerate(self._schedulers):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
state_dict["_schedulers"][idx] = s.state_dict()

return state_dict
@@ -1671,7 +1671,7 @@ class ReduceLROnPlateau(LRScheduler):
self.default_min_lr = None
self.min_lrs = list(min_lr)
else:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.default_min_lr = min_lr
self.min_lrs = [min_lr] * len(optimizer.param_groups)

@@ -1731,7 +1731,7 @@ class ReduceLROnPlateau(LRScheduler):
"of the `optimizer` param groups."
)
else:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)

for i, param_group in enumerate(self.optimizer.param_groups):
@@ -1911,13 +1911,13 @@ class CyclicLR(LRScheduler):

self.max_lrs = _format_param("max_lr", optimizer, max_lr)

-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
step_size_up = float(step_size_up)
step_size_down = (
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
float(step_size_down) if step_size_down is not None else step_size_up
)
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
self.total_size = step_size_up + step_size_down
self.step_ratio = step_size_up / self.total_size
@@ -62,7 +62,7 @@ def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
import torch._dynamo

-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
self = cast(Optimizer, args[0]) # assume first positional arg is `self`
prev_grad = torch.is_grad_enabled()
try:
@@ -136,13 +136,13 @@ def _disable_dynamo_if_unsupported(
if torch.compiler.is_compiling() and (
not kwargs.get("capturable", False)
and has_state_steps
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
and (arg := args[state_steps_ind])
and isinstance(arg, Sequence)
and arg[0].is_cuda
or (
"state_steps" in kwargs
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
and (kwarg := kwargs["state_steps"])
and isinstance(kwarg, Sequence)
and kwarg[0].is_cuda
@@ -362,18 +362,18 @@ class Optimizer:

_optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: dict[int, OptimizerPostHook]
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: (
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_pre_hooks: (
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_post_hooks: (
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
'OrderedDict[int, Callable[["Optimizer"], None]]'
)

@@ -522,7 +522,7 @@ class Optimizer:
f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
)

-# pyrefly: ignore # invalid-param-spec
+# pyrefly: ignore [invalid-param-spec]
out = func(*args, **kwargs)
self._optimizer_step_code()

@@ -961,9 +961,9 @@ class Optimizer:
return Optimizer._process_value_according_to_param_policy(
param,
value,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
param_id,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
param_groups,
key,
)
@@ -976,7 +976,7 @@ class Optimizer:
}
elif isinstance(value, Iterable):
return type(value)(
-# pyrefly: ignore # bad-argument-count
+# pyrefly: ignore [bad-argument-count]
_cast(param, v, param_id=param_id, param_groups=param_groups)
for v in value
) # type: ignore[call-arg]
@@ -323,7 +323,7 @@ def _single_tensor_radam(
rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2

def _compute_rect():
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
return (
(rho_t - 4)
* (rho_t - 2)
@@ -338,7 +338,7 @@ def _single_tensor_radam(
else:
exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)

-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
return (bias_correction2**0.5) / exp_avg_sq_sqrt

# Compute the variance rectification term and update parameters accordingly
@@ -348,7 +348,7 @@ def _single_tensor_sgd(
# usually this is the differentiable path, which is why the param.clone() is needed
grad = grad.addcmul_(param.clone(), weight_decay)
else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
grad = grad.add(param, alpha=weight_decay)
else:
grad = grad.add(param, alpha=weight_decay)
@@ -372,7 +372,7 @@ def _single_tensor_sgd(
if lr.requires_grad:
param.addcmul_(grad, lr, value=-1)
else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
param.add_(grad, alpha=-lr)
else:
param.add_(grad, alpha=-lr)
@@ -250,13 +250,13 @@ class AveragedModel(Module):
def update_parameters(self, model: Module):
"""Update model parameters."""
self_param = (
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
itertools.chain(self.module.parameters(), self.module.buffers())
if self.use_buffers
else self.parameters()
)
model_param = (
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
itertools.chain(model.parameters(), model.buffers())
if self.use_buffers
else model.parameters()
@@ -298,17 +298,17 @@ class AveragedModel(Module):
avg_fn = get_swa_avg_fn()
n_averaged = self.n_averaged.to(device)
for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
else:
for p_averaged, p_model in zip( # type: ignore[assignment]
self_param_detached, model_param_detached
):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
n_averaged = self.n_averaged.to(p_averaged.device)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
p_averaged.detach().copy_(
-# pyrefly: ignore # missing-attribute, bad-argument-type
+# pyrefly: ignore [missing-attribute, bad-argument-type]
self.avg_fn(p_averaged.detach(), p_model, n_averaged)
)

@@ -497,14 +497,14 @@ class SWALR(LRScheduler):
step = self._step_count - 1
if self.anneal_epochs == 0:
step = max(1, step)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
prev_alpha = self.anneal_func(prev_t)
prev_lrs = [
self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
for group in self.optimizer.param_groups
]
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
t = max(0, min(1, step / max(1, self.anneal_epochs)))
alpha = self.anneal_func(t)
return [
@@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
-# pyrefly: ignore # missing-module-attribute
+# pyrefly: ignore [missing-module-attribute]
from pickle import ( # type: ignore[attr-defined]
_compat_pickle,
_extension_registry,
@@ -3,7 +3,7 @@ import importlib
import logging
from abc import ABC, abstractmethod

-# pyrefly: ignore # missing-module-attribute
+# pyrefly: ignore [missing-module-attribute]
from pickle import ( # type: ignore[attr-defined]
_getattribute,
_Pickler,
@@ -652,7 +652,7 @@ class PackageExporter:
memo: defaultdict[int, str] = defaultdict(None)
memo_count = 0
# pickletools.dis(data_value)
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
for opcode, arg, _pos in pickletools.genops(data_value):
if pickle_protocol == 4:
if (
@@ -230,7 +230,7 @@ class SchemaMatcher:
for schema in cls.match_schemas(t):
mutable = mutable or [False for _ in schema.arguments]
for i, arg in enumerate(schema.arguments):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
mutable[i] |= getattr(arg.alias_info, "is_write", False)

return tuple(mutable or (None for _ in t.inputs))
@@ -1084,7 +1084,7 @@ class MemoryProfileTimeline:

if action in (Action.PREEXISTING, Action.CREATE):
raw_events.append(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
(
t,
_ACTION_TO_INDEX[action],
@@ -1095,7 +1095,7 @@ class MemoryProfileTimeline:

elif action == Action.INCREMENT_VERSION:
raw_events.append(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
(
t,
_ACTION_TO_INDEX[action],
@@ -1104,7 +1104,7 @@ class MemoryProfileTimeline:
)
)
raw_events.append(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
(
t,
_ACTION_TO_INDEX[action],
@@ -1115,7 +1115,7 @@ class MemoryProfileTimeline:

elif action == Action.DESTROY:
raw_events.append(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
(
t,
_ACTION_TO_INDEX[action],
@@ -211,7 +211,7 @@ class BasicEvaluation:
# Find latest cuda kernel event
if hasattr(event, "start_us"):
start_time = event.start_us() * 1000
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
end_time = (event.start_us() + event.duration_us()) * 1000
# Find current spawned cuda kernel event
if event in kernel_mapping and kernel_mapping[event] is not None:
@@ -161,19 +161,19 @@ class _KinetoProfile:
self.mem_tl: Optional[MemoryProfileTimeline] = None
self.use_device = None
if ProfilerActivity.CUDA in self.activities:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.use_device = "cuda"
elif ProfilerActivity.XPU in self.activities:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.use_device = "xpu"
elif ProfilerActivity.MTIA in self.activities:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.use_device = "mtia"
elif ProfilerActivity.HPU in self.activities:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.use_device = "hpu"
elif ProfilerActivity.PrivateUse1 in self.activities:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.use_device = _get_privateuse1_backend_name()

# user-defined metadata to be amended to the trace
@@ -385,7 +385,7 @@ class _KinetoProfile:
}
if backend == "nccl":
nccl_version = torch.cuda.nccl.version()
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
return dist_info
@@ -71,7 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
nrows // 16, 16
)
).view(-1)
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
outp = outp.index_copy(1, cols_permuted, outp)

# interleave_column_major_tensor
@@ -790,7 +790,7 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
# PyTorchFileWriter only supports ascii filename.
# For filenames with non-ascii characters, we rely on Python
# for writing out the file.
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.file_stream = io.FileIO(self.name, mode="w")
super().__init__(
torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload
@@ -397,15 +397,15 @@ def kaiser(
)

# Avoid NaNs by casting `beta` to the appropriate dtype.
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
beta = torch.tensor(beta, dtype=dtype, device=device)

start = -beta
constant = 2.0 * beta / (M if not sym else M - 1)
end = torch.minimum(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
beta,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
start + (M - 1) * constant,
)

@@ -420,7 +420,7 @@ def kaiser(
)

return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
beta
)
@@ -623,20 +623,20 @@ def as_sparse_gradcheck(gradcheck):
)
obj = obj.to_dense().sparse_mask(full_mask)
if obj.layout is torch.sparse_coo:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
d.update(
indices=obj._indices(), is_coalesced=obj.is_coalesced()
)
values = obj._values()
elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
d.update(
compressed_indices=obj.crow_indices(),
plain_indices=obj.col_indices(),
)
values = obj.values()
else:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
d.update(
compressed_indices=obj.ccol_indices(),
plain_indices=obj.row_indices(),
@@ -140,7 +140,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):

if dense.dtype != torch.float:
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
else:
@@ -173,7 +173,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
meta_offsets = _calculate_meta_reordering_scatter_offsets(
m, meta_ncols, meta_dtype, device
)
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))

return (sparse, meta_reordered.view(m, meta_ncols))
@@ -67,7 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
# Because we cannot go from the compressed representation back to the dense representation currently,
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return self.__class__(
torch.Size([self.shape[-1], self.shape[0]]),
packed=self.packed_t,
@@ -1297,7 +1297,7 @@ def bsr_dense_addmm(
assert alpha != 0

def kernel(grid, *sliced_tensors):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
_bsr_strided_addmm_kernel[grid](
*ptr_stride_extractor(*sliced_tensors),
beta,
@@ -1427,7 +1427,7 @@ if has_triton():

mat1_block = tl.load(
mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
mask=mask_k[None, :],
other=0.0,
)
@@ -1436,7 +1436,7 @@ if has_triton():
mat2_block_ptrs
+ mat2_tiled_col_stride * col_block
+ mat2_row_block_stride * k_offsets[:, None],
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
mask=mask_k[:, None],
other=0.0,
)
@@ -1974,7 +1974,7 @@ if has_triton():
if attn_mask.dtype is not torch.bool:
check_dtype(f_name, attn_mask, query.dtype)

-# pyrefly: ignore # not-callable
+# pyrefly: ignore [not-callable]
sdpa = sampled_addmm(
attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False
)
@@ -1986,10 +1986,10 @@ if has_triton():
)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
sdpa.values().mul_(scale_factor)
-# pyrefly: ignore # not-callable
+# pyrefly: ignore [not-callable]
sdpa = bsr_softmax(sdpa)
torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
-# pyrefly: ignore # not-callable
+# pyrefly: ignore [not-callable]
sdpa = bsr_dense_mm(sdpa, value)
return sdpa
@@ -234,10 +234,10 @@ def dump():
part2 = current_content[end_data_index:]
data_part = []
for op_key in sorted(_operation_device_version_data, key=sort_key):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
data_part.append(" " + repr(op_key).replace("'", '"') + ": {")
op_data = _operation_device_version_data[op_key]
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
data_part.extend(f" {key}: {op_data[key]}," for key in sorted(op_data))
data_part.append(" },")
new_content = part1 + "\n".join(data_part) + "\n" + part2
@@ -371,7 +371,7 @@ def minimize(
if next_target < minimal_target:
minimal_target = next_target
parameters = next_parameters
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
pbar.total += i + 1
break
else:
@@ -185,7 +185,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
outer_stride,
) -> torch.Tensor:
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return cls(
shape=shape,
packed=inner_tensors.get("packed", None),
@@ -415,7 +415,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
sparse_tensor_cutlass,
meta_tensor_cutlass,
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return cls(
original_tensor.shape,
packed=sparse_tensor_cutlass,
@@ -502,7 +502,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
original_tensor, algorithm=algorithm, use_cutlass=True
)

-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return cls(
original_tensor.shape,
packed=packed,
@@ -564,7 +564,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return cls(
shape=original_tensor.shape,
packed=torch._cslt_compress(original_tensor),
@@ -631,7 +631,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
packed = packed.view(original_tensor.shape[0], -1)
packed_t = packed_t.view(original_tensor.shape[1], -1)

-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return cls(
original_tensor.shape,
packed=packed,
@@ -2,6 +2,6 @@ from torch._C import FileCheck as FileCheck

from . import _utils

-# pyrefly: ignore # deprecated
+# pyrefly: ignore [deprecated]
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor
@@ -243,7 +243,7 @@ def make_scalar_mismatch_msg(
Defaults to "Scalars".
"""
abs_diff = abs(actual - expected)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
return _make_mismatch_msg(
default_identifier="Scalars",
@@ -487,7 +487,7 @@ class BooleanPair(Pair):
def _supported_types(self) -> tuple[type, ...]:
cls: list[type] = [bool]
if HAS_NUMPY:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
cls.append(np.bool_)
return tuple(cls)

@@ -503,7 +503,7 @@ class BooleanPair(Pair):
def _to_bool(self, bool_like: Any, *, id: tuple[Any, ...]) -> bool:
if isinstance(bool_like, bool):
return bool_like
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
elif isinstance(bool_like, np.bool_):
return bool_like.item()
else:
@@ -583,7 +583,7 @@ class NumberPair(Pair):
def _supported_types(self) -> tuple[type, ...]:
cls = list(self._NUMBER_TYPES)
if HAS_NUMPY:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
cls.append(np.number)
return tuple(cls)

@@ -599,7 +599,7 @@ class NumberPair(Pair):
def _to_number(
self, number_like: Any, *, id: tuple[Any, ...]
) -> Union[int, float, complex]:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
if HAS_NUMPY and isinstance(number_like, np.number):
return number_like.item()
elif isinstance(number_like, self._NUMBER_TYPES):
@@ -1122,7 +1122,7 @@ def originate_pairs(
mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
id: tuple[Any, ...] = (),
**options: Any,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> list[Pair]:
"""Originates pairs from the individual inputs.

@@ -1221,7 +1221,7 @@ def originate_pairs(
else:
for pair_type in pair_types:
try:
-# pyrefly: ignore # bad-instantiation
+# pyrefly: ignore [bad-instantiation]
return [pair_type(actual, expected, id=id, **options)]
# Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
# inputs. Thus, we try the next pair type.
@@ -1319,9 +1319,9 @@ def not_close_error_metas(
# would not get freed until cycle collection, leaking cuda memory in tests.
# We break the cycle by removing the reference to the error_meta objects
# from this frame as it returns.
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
error_metas = [error_metas]
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return error_metas.pop()