Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Fix syntax for pyrefly errors (#166496)
Last one! This ensures all existing suppressions match the expected syntax and will silence only one error code.
pyrefly check
lintrunner
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166496
Approved by: https://github.com/Skylion007, https://github.com/mlazos
This commit is contained in:
parent fa560e1158
commit d1a6e006e0
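For readers unfamiliar with pyrefly suppression comments, here is a minimal sketch contrasting the two forms this PR touches. It is illustrative only, not code from the diff: takes_int and its call sites are hypothetical, and only the comment syntax mirrors the change. Per the commit message, the bracketed form is the syntax pyrefly expects and silences exactly one error code on that line.

def takes_int(x: int) -> int:
    return x + 1

# Old form replaced throughout this PR: the error code trails a second "#",
# which is not the syntax pyrefly expects for targeting a single error code.
a = takes_int(1.5)  # pyrefly: ignore # bad-argument-type

# New form: the error code is given in brackets, so only bad-argument-type
# is suppressed here and any other diagnostic on this line is still reported.
b = takes_int(1.5)  # pyrefly: ignore [bad-argument-type]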
@@ -841,7 +841,7 @@ def _reduce_to_lowest_terms(expr: sympy.Expr) -> sympy.Expr:
 factor = functools.reduce(math.gcd, map(integer_coefficient, atoms))
 if factor == 1:
 return expr
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 atoms = [div_by_factor(x, factor) for x in atoms]
 return _sympy_from_args(
 sympy.Add, atoms, sort=True, is_commutative=expr.is_commutative

@@ -2207,7 +2207,7 @@ class SubclassSymbolicContext(StatefulSymbolicContext):
 def __post_init__(self) -> None:
 super().__post_init__()
 if self.inner_contexts is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.inner_contexts = {}

@@ -2296,12 +2296,12 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
 # only re-create the objects if any of the args changed to avoid expensive
 # checks when re-creating objects.
 new_args = [_fast_expand(arg) for arg in expr.args] # type: ignore[arg-type]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if any(arg is not new_arg for arg, new_arg in zip(expr.args, new_args)):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 return _fast_expand(expr.func(*new_args))
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if expr.is_Pow:
 base: sympy.Expr
 exp: sympy.Expr

@@ -2311,11 +2311,11 @@ def _fast_expand(expr: _SympyT) -> _SympyT:
 return sympy.expand_multinomial(expr, deep=False)
 elif exp < 0:
 return S.One / sympy.expand_multinomial(S.One / expr, deep=False)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 elif expr.is_Mul:
 num: list[sympy.Expr] = []
 den: list[sympy.Expr] = []
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 for arg in expr.args:
 if arg.is_Pow and arg.args[1] == -1:
 den.append(S.One / arg) # type: ignore[operator, arg-type]

@@ -2437,7 +2437,7 @@ def _maybe_evaluate_static_worker(
 # TODO: remove this try catch (esp for unbacked_only)
 try:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 new_expr = expr.xreplace(new_shape_env)
 except RecursionError:
 log.warning("RecursionError in sympy.xreplace(%s, %s)", expr, new_shape_env)
@@ -2975,19 +2975,19 @@ class DimConstraints:
 # is_integer tests though haha
 return (base - mod_reduced) / divisor
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if expr.has(Mod):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 expr = expr.replace(Mod, mod_handler)
 # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative
 # arguments should be OK.
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if expr.has(PythonMod):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 expr = expr.replace(PythonMod, mod_handler)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if expr.has(FloorDiv):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 expr = expr.replace(FloorDiv, floor_div_handler)
 return expr

@@ -5106,7 +5106,7 @@ class ShapeEnv:
 if duck:
 # Make sure to reuse this symbol for subsequent duck shaping
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 self.val_to_var[val] = sympy_expr
 if isinstance(val, int):

@@ -5338,9 +5338,9 @@ class ShapeEnv:
 # Expand optional inputs, or verify invariants are upheld
 if input_contexts is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 input_contexts = [
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 _create_no_constraints_context(t) if isinstance(t, Tensorlike) else None
 for t in placeholders
 ]

@@ -5350,7 +5350,7 @@ class ShapeEnv:
 for i, (t, context) in enumerate(zip(placeholders, input_contexts)):
 if isinstance(t, Tensorlike):
 if context is None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 input_contexts[i] = _create_no_constraints_context(t)
 else:
 assert isinstance(t, (SymInt, int, SymFloat, float))

@@ -5636,7 +5636,7 @@ class ShapeEnv:
 s = sympy.Float(val)
 input_guards.append((source, s))
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 for t, source, context in zip(placeholders, sources, input_contexts):
 if isinstance(source, str):
 from torch._dynamo.source import LocalSource
@@ -5999,7 +5999,7 @@ class ShapeEnv:
 else:
 str_msg = f" - {msg_cb()}"
 error_msgs.append(str_msg)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 debug_names.add(debug_name)
 if len(error_msgs) > 0:
 debug_names_str = ", ".join(sorted(debug_names))

@@ -6133,7 +6133,7 @@ class ShapeEnv:
 Get a list of guards, but pruned so it only provides guards that
 reference symints from the passed in input
 """
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 symints = {
 s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
 }

@@ -6396,7 +6396,7 @@ class ShapeEnv:
 Apply symbol replacements to any symbols in the given expression.
 """
 replacements = {}
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 for s in expr.free_symbols:
 r = self._find(s)

@@ -6406,7 +6406,7 @@ class ShapeEnv:
 if not r.is_Symbol or r != s:
 replacements[s] = r
 if replacements:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 return safe_expand(expr.xreplace(replacements))
 else:
 return expr

@@ -7181,7 +7181,7 @@ class ShapeEnv:
 instructions = list(dis.Bytecode(frame.f_code))
 co_lines, offset = inspect.getsourcelines(frame.f_code)
 start, end, cur = None, None, None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 for i, instr in enumerate(instructions):
 if instr.starts_line is not None:
 cur = instr.starts_line
@@ -238,7 +238,7 @@ class Dispatcher:
 "To use a variadic union type place the desired types "
 "inside of a tuple, e.g., [(int, str)]"
 )
-# pyrefly: ignore # bad-specialization
+# pyrefly: ignore [bad-specialization]
 new_signature.append(Variadic[typ[0]])
 else:
 new_signature.append(typ)

@@ -407,7 +407,7 @@ class MethodDispatcher(Dispatcher):
 Dispatcher
 """
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 __slots__ = ("obj", "cls")
 @classmethod

@@ -298,7 +298,7 @@ def update_in(d, keys, func, default=None, factory=dict):
 rv = inner = factory()
 rv.update(d)
-# pyrefly: ignore # not-iterable
+# pyrefly: ignore [not-iterable]
 for key in ks:
 if k in d:
 d = d[k]

@@ -1380,7 +1380,7 @@ class Graph:
 f(to_erase)
 self._find_nodes_lookup_table.remove(to_erase)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 to_erase._remove_from_list()
 to_erase._erased = True # iterators may retain handles to erased nodes
 self._len -= 1

@@ -1941,7 +1941,7 @@ class Graph:
 "a str is expected"
 )
 if node.op in ["get_attr", "call_module"]:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 target_atoms = node.target.split(".")
 m_itr = self.owning_module
 for i, atom in enumerate(target_atoms):
@@ -535,7 +535,7 @@ class GraphModule(torch.nn.Module):
 self.graph._tracer_cls
 and "<locals>" not in self.graph._tracer_cls.__qualname__
 ):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self._tracer_cls = self.graph._tracer_cls
 self._tracer_extras = {}

@@ -165,12 +165,12 @@ def tensorify_python_scalars(
 node = graph.call_function(
 torch.ops.aten.scalar_tensor.default,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 (c,),
 {"dtype": dtype},
 )
 with fake_mode:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 node.meta["val"] = torch.ops.aten.scalar_tensor.default(c, dtype=dtype)
 expr_to_tensor_proxy[expr] = MetaProxy(
 node,

@@ -223,13 +223,13 @@ def tensorify_python_scalars(
 expr_to_sym_proxy[s] = MetaProxy(
 node, tracer=tracer, fake_mode=fake_mode
 )
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 elif (sym_expr := _get_sym_val(node)) is not None:
 if sym_expr not in expr_to_sym_proxy and not isinstance(
 sym_expr, (sympy.Number, sympy.logic.boolalg.BooleanAtom)
 ):
 expr_to_sym_proxy[sym_expr] = MetaProxy(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 node,
 tracer=tracer,
 fake_mode=fake_mode,

@@ -238,7 +238,7 @@ def tensorify_python_scalars(
 # Specialize all dimensions that contain symfloats. Here's
 # an example test that requires this:
 # PYTORCH_OPINFO_SAMPLE_INPUT_INDEX=4 python test/inductor/test_torchinductor_opinfo.py TestInductorOpInfoCUDA.test_comprehensive_nn_functional_interpolate_bicubic_cuda_float32 # noqa: B950
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 val = node.meta.get("val")
 if isinstance(val, FakeTensor):
 for dim in val.shape:
@@ -257,17 +257,17 @@ def tensorify_python_scalars(
 should_restart = True
 # Look for functions to convert
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if node.op == "call_function" and (
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 replacement_op := SUPPORTED_OPS.get(node.target)
 ):
 args: list[Any] = []
 transform = False
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 compute_dtype = get_computation_dtype(node.meta["val"].dtype)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 for a in node.args:
 if (
 isinstance(a, fx.Node)

@@ -304,7 +304,7 @@ def tensorify_python_scalars(
 if transform:
 replacement_proxy = replacement_op(*args)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if compute_dtype != node.meta["val"].dtype:
 replacement_proxy = (
 torch.ops.prims.convert_element_type.default(

@@ -313,9 +313,9 @@ def tensorify_python_scalars(
 )
 )
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 node.replace_all_uses_with(replacement_proxy.node)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 graph.erase_node(node)
 metrics_context = get_metrics_context()

@@ -324,16 +324,16 @@ def tensorify_python_scalars(
 "tensorify_float_success", True, overwrite=True
 )
 else:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 for a in node.args:
 if (
 isinstance(a, fx.Node)
 and "val" in a.meta
 and isinstance(zf := a.meta["val"], torch.SymFloat)
 ):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 failed_tensorify_ops.update(str(node.target))
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 log.info("Failed to tensorify %s", str(node.target))
 # Now do one more pass that specializes all symfloats we didn't manage
@@ -437,13 +437,13 @@ if HAS_PYDOT:
 )
 current_graph = buf_name_to_subgraph.get(buf_name) # type: ignore[assignment]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 current_graph.add_node(dot_node)
 def get_module_params_or_buffers():
 for pname, ptensor in chain(
 leaf_module.named_parameters(),
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 leaf_module.named_buffers(),
 ):
 pname1 = node.name + "." + pname

@@ -11,7 +11,7 @@ __all__ = ["PassResult", "PassBase"]
 @compatibility(is_backward_compatible=False)
-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
 class PassResult(namedtuple("PassResult", ["graph_module", "modified"])):
 """
 Result of a pass:

@@ -31,7 +31,7 @@ def pass_result_wrapper(fn: Callable) -> Callable:
 wrapped_fn (Callable[Module, PassResult])
 """
 if fn is None:
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return None
 @wraps(fn)

@@ -396,25 +396,25 @@ class _MinimizerBase:
 report.append(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
 if self.module_exporter:
 if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 result_key = result_key[-1]
 # If the result is still a tuple (happens in non-sequential mode),
 # we only use the first element as name.
 if isinstance(result_key, tuple): # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 result_key = str(result_key[0])
 # pyre-ignore[29]: not a function
 self.module_exporter(
 a_input,
 submodule,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 result_key + "_cpu",
 )
 # pyre-ignore[29]: not a function
 self.module_exporter(
 b_input,
 submodule,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 result_key + "_acc",
 )
 raise FxNetMinimizerResultMismatchError(f"Result mismatch for {result_key}") # type: ignore[possibly-undefined]
@@ -360,7 +360,7 @@ def insert_deferred_runtime_asserts(
 ):
 # this guards against deleting calls like item() that produce new untracked symbols
 def has_new_untracked_symbols():
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 for symbol in sym_expr.free_symbols:
 if symbol not in expr_to_proxy:
 return True

@@ -376,7 +376,7 @@ def insert_deferred_runtime_asserts(
 assert resolved_unbacked_bindings is not None
 def has_new_unbacked_bindings():
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 for key in resolved_unbacked_bindings.keys():
 if key not in expr_to_proxy:
 return True

@@ -351,9 +351,9 @@ def split_module(
 assert all(v is not None for v in autocast_exits.values()), "autocast must exit"
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 grad_regions = {k: sorted(v) for k, v in grad_regions.items()}
 if _LOGGER.isEnabledFor(logging.DEBUG):

@@ -418,9 +418,9 @@ def split_module(
 for regions_mapping in [autocast_regions, grad_regions]:
 for node, regions in regions_mapping.items():
 assert len(regions) > 0
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 partitions[str(regions[0])].environment[node] = node
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 for r in regions[1:]:
 partition = partitions[str(r)]
 new_node = partition.graph.create_node(

@@ -520,7 +520,7 @@ def split_module(
 for node in reversed(regions_mapping):
 regions = regions_mapping[node]
 assert len(regions) > 0
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 for r in regions[:-1]:
 partition = partitions[str(r)]
 exit_node = autocast_exits[node]

@@ -64,7 +64,7 @@ def lift_subgraph_as_module(
 for name in target_name_parts[:-1]:
 if not hasattr(curr, name):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 curr.add_module(name, HolderModule({}))
 curr = getattr(curr, name)
@@ -242,7 +242,7 @@ class Library:
 if dispatch_key == "":
 dispatch_key = self.dispatch_key
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 assert torch.DispatchKeySet(dispatch_key).has(torch._C.DispatchKey.Dense)
 if isinstance(op_name, str):

@@ -484,7 +484,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
 raise IndexError(
 f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
 )
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 dims.append(d % ndim)
 return tuple(sorted(dims))

@@ -1017,7 +1017,7 @@ def _combine_input_and_mask(
 class Combine(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, input, mask):
 """Return input with masked-out elements eliminated for the given operations."""
 ctx.save_for_backward(mask)

@@ -1028,7 +1028,7 @@ def _combine_input_and_mask(
 return helper(input, mask)
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, grad_output):
 (mask,) = ctx.saved_tensors
 grad_data = (

@@ -1403,18 +1403,18 @@ elements, have ``nan`` values.
 if input.layout == torch.strided:
 if mask is None:
 # TODO: compute count analytically
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 count = sum(
 torch.ones(input.shape, dtype=torch.int64, device=input.device),
 dim,
 keepdim=keepdim,
 )
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 total = sum(input, dim, keepdim=keepdim, dtype=dtype)
 else:
 inmask = _input_mask(input, mask=mask)
 count = inmask.sum(dim=dim, keepdim=bool(keepdim))
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
 return total / count
 elif input.layout == torch.sparse_csr:
@@ -1625,18 +1625,18 @@ def _std_var(
 if input.layout == torch.strided:
 if mask is None:
 # TODO: compute count analytically
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 count = sum(
 torch.ones(input.shape, dtype=torch.int64, device=input.device),
 dim,
 keepdim=True,
 )
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 sample_total = sum(input, dim, keepdim=True, dtype=dtype)
 else:
 inmask = _input_mask(input, mask=mask)
 count = inmask.sum(dim=dim, keepdim=True)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
 # TODO: replace torch.subtract/divide/square/maximum with
 # masked subtract/divide/square/maximum when these will be

@@ -1644,7 +1644,7 @@ def _std_var(
 sample_mean = torch.divide(sample_total, count)
 x = torch.subtract(input, sample_mean)
 if mask is None:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
 else:
 total = sum(

@@ -47,7 +47,7 @@ def _check_args_kwargs_length(
 class _MaskedContiguous(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, input):
 if not is_masked_tensor(input):
 raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")

@@ -61,14 +61,14 @@ class _MaskedContiguous(torch.autograd.Function):
 return MaskedTensor(data.contiguous(), mask.contiguous())
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, grad_output):
 return grad_output
 class _MaskedToDense(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, input):
 if not is_masked_tensor(input):
 raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")

@@ -83,7 +83,7 @@ class _MaskedToDense(torch.autograd.Function):
 return MaskedTensor(data.to_dense(), mask.to_dense())
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, grad_output):
 layout = ctx.layout
@@ -98,7 +98,7 @@ class _MaskedToDense(torch.autograd.Function):
 class _MaskedToSparse(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, input):
 if not is_masked_tensor(input):
 raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")

@@ -115,14 +115,14 @@ class _MaskedToSparse(torch.autograd.Function):
 return MaskedTensor(sparse_data, sparse_mask)
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, grad_output):
 return grad_output.to_dense()
 class _MaskedToSparseCsr(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, input):
 if not is_masked_tensor(input):
 raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")

@@ -143,21 +143,21 @@ class _MaskedToSparseCsr(torch.autograd.Function):
 return MaskedTensor(sparse_data, sparse_mask)
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, grad_output):
 return grad_output.to_dense()
 class _MaskedWhere(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, cond, self, other):
 ctx.mark_non_differentiable(cond)
 ctx.save_for_backward(cond)
 return torch.ops.aten.where(cond, self, other)
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, grad_output):
 (cond,) = ctx.saved_tensors

@@ -174,7 +174,7 @@ class MaskedTensor(torch.Tensor):
 UserWarning,
 stacklevel=2,
 )
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
 def _preprocess_data(self, data, mask):

@@ -244,12 +244,12 @@ class MaskedTensor(torch.Tensor):
 class Constructor(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, data, mask):
 return MaskedTensor(data, mask)
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, grad_output):
 return grad_output, None

@@ -336,12 +336,12 @@ class MaskedTensor(torch.Tensor):
 def get_data(self):
 class GetData(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, self):
 return self._masked_data.detach()
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, grad_output):
 if is_masked_tensor(grad_output):
 return grad_output
@@ -114,7 +114,7 @@ class ProcessContext:
 """Attempt to join all processes with a shared timeout."""
 end = time.monotonic() + timeout
 for process in self.processes:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 time_to_wait = max(0, end - time.monotonic())
 process.join(time_to_wait)

@@ -51,7 +51,7 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
 # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
 # For other dims, subtract 1 to convert to inner space.
 return (
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 ragged_dim - 1 if dim == 0 else dim - 1
 )

@@ -2008,7 +2008,7 @@ def index_put_(func, *args, **kwargs):
 else:
 lengths = inp.lengths()
 torch._assert_async(
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 torch.all(indices[inp._ragged_idx] < lengths),
 "Some indices in the ragged dimension are out of bounds!",
 )

@@ -668,5 +668,5 @@ def _get_numa_node_indices_for_socket_index(*, socket_index: int) -> set[int]:
 def _get_allowed_cpu_indices_for_current_thread() -> set[int]:
 # 0 denotes current thread
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 return os.sched_getaffinity(0) # type:ignore[attr-defined]

@@ -251,7 +251,7 @@ def _compare_onnx_pytorch_outputs_in_np(
 # pyrefly: ignore [missing-attribute]
 if ort_out.dtype == np.uint8 or ort_out.dtype == np.int8:
 warnings.warn("ONNX output is quantized", stacklevel=2)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if pt_out.dtype == np.uint8 or pt_out.dtype == np.int8:
 warnings.warn("PyTorch output is quantized", stacklevel=2)
 raise

@@ -78,7 +78,7 @@ def _adjust_lr(
 A, B = param_shape[:2]
 if adjust_lr_fn is None or adjust_lr_fn == "original":
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 adjusted_ratio = math.sqrt(max(1, A / B))
 elif adjust_lr_fn == "match_rms_adamw":
 adjusted_ratio = 0.2 * math.sqrt(max(A, B))
@@ -423,7 +423,7 @@ def _single_tensor_adam(
 if weight_decay.requires_grad:
 grad = grad.addcmul_(param.clone(), weight_decay)
 else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 grad = grad.add(param, alpha=weight_decay)
 else:
 grad = grad.add(param, alpha=weight_decay)

@@ -264,7 +264,7 @@ def _single_tensor_asgd(
 ax.copy_(param)
 if capturable:
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
 mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
 else:

@@ -113,16 +113,16 @@ def _strong_wolfe(
 # compute new trial value
 t = _cubic_interpolate(
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 bracket[0],
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 bracket_f[0],
 bracket_gtd[0], # type: ignore[possibly-undefined]
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 bracket[1],
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 bracket_f[1],
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 bracket_gtd[1],
 )

@@ -133,20 +133,20 @@ def _strong_wolfe(
 # + `t` is at one of the boundary,
 # we will move `t` to a position which is `0.1 * len(bracket)`
 # away from the nearest boundary point.
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 eps = 0.1 * (max(bracket) - min(bracket))
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if min(max(bracket) - t, t - min(bracket)) < eps:
 # interpolation close to boundary
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if insuf_progress or t >= max(bracket) or t <= min(bracket):
 # evaluate at 0.1 away from boundary
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if abs(t - max(bracket)) < abs(t - min(bracket)):
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 t = max(bracket) - eps
 else:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 t = min(bracket) + eps
 insuf_progress = False
 else:
@@ -160,45 +160,45 @@ def _strong_wolfe(
 gtd_new = g_new.dot(d)
 ls_iter += 1
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if f_new > (f + c1 * t * gtd) or f_new >= bracket_f[low_pos]:
 # Armijo condition not satisfied or not lower than lowest point
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 bracket[high_pos] = t
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 bracket_f[high_pos] = f_new
 bracket_g[high_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 bracket_gtd[high_pos] = gtd_new
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 low_pos, high_pos = (0, 1) if bracket_f[0] <= bracket_f[1] else (1, 0)
 else:
 if abs(gtd_new) <= -c2 * gtd:
 # Wolfe conditions satisfied
 done = True
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 elif gtd_new * (bracket[high_pos] - bracket[low_pos]) >= 0:
 # old high becomes new low
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 bracket[high_pos] = bracket[low_pos]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 bracket_f[high_pos] = bracket_f[low_pos]
 bracket_g[high_pos] = bracket_g[low_pos] # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 bracket_gtd[high_pos] = bracket_gtd[low_pos]
 # new point becomes new low
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 bracket[low_pos] = t
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 bracket_f[low_pos] = f_new
 bracket_g[low_pos] = g_new.clone(memory_format=torch.contiguous_format) # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 bracket_gtd[low_pos] = gtd_new
 # return stuff
 t = bracket[low_pos] # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 f_new = bracket_f[low_pos]
 g_new = bracket_g[low_pos] # type: ignore[possibly-undefined]
 return f_new, g_new, t, ls_func_evals

@@ -276,7 +276,7 @@ class LBFGS(Optimizer):
 def _numel(self):
 if self._numel_cache is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self._numel_cache = sum(
 2 * p.numel() if torch.is_complex(p) else p.numel()
 for p in self._params
@@ -422,7 +422,7 @@ class LambdaLR(LRScheduler):
 for idx, fn in enumerate(self.lr_lambdas):
 if not isinstance(fn, types.FunctionType):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
 return state_dict

@@ -542,7 +542,7 @@ class MultiplicativeLR(LRScheduler):
 for idx, fn in enumerate(self.lr_lambdas):
 if not isinstance(fn, types.FunctionType):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 state_dict["lr_lambdas"][idx] = fn.__dict__.copy()
 return state_dict

@@ -1219,7 +1219,7 @@ class SequentialLR(LRScheduler):
 state_dict["_schedulers"] = [None] * len(self._schedulers)
 for idx, s in enumerate(self._schedulers):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 state_dict["_schedulers"][idx] = s.state_dict()
 return state_dict

@@ -1562,7 +1562,7 @@ class ChainedScheduler(LRScheduler):
 state_dict["_schedulers"] = [None] * len(self._schedulers)
 for idx, s in enumerate(self._schedulers):
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 state_dict["_schedulers"][idx] = s.state_dict()
 return state_dict

@@ -1671,7 +1671,7 @@ class ReduceLROnPlateau(LRScheduler):
 self.default_min_lr = None
 self.min_lrs = list(min_lr)
 else:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.default_min_lr = min_lr
 self.min_lrs = [min_lr] * len(optimizer.param_groups)

@@ -1731,7 +1731,7 @@ class ReduceLROnPlateau(LRScheduler):
 "of the `optimizer` param groups."
 )
 else:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.min_lrs = [self.default_min_lr] * len(self.optimizer.param_groups)
 for i, param_group in enumerate(self.optimizer.param_groups):

@@ -1911,13 +1911,13 @@ class CyclicLR(LRScheduler):
 self.max_lrs = _format_param("max_lr", optimizer, max_lr)
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 step_size_up = float(step_size_up)
 step_size_down = (
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 float(step_size_down) if step_size_down is not None else step_size_up
 )
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 self.total_size = step_size_up + step_size_down
 self.step_ratio = step_size_up / self.total_size
@@ -62,7 +62,7 @@ def _use_grad_for_differentiable(func: Callable[_P, _T]) -> Callable[_P, _T]:
 def _use_grad(*args: _P.args, **kwargs: _P.kwargs) -> _T:
 import torch._dynamo
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 self = cast(Optimizer, args[0]) # assume first positional arg is `self`
 prev_grad = torch.is_grad_enabled()
 try:

@@ -136,13 +136,13 @@ def _disable_dynamo_if_unsupported(
 if torch.compiler.is_compiling() and (
 not kwargs.get("capturable", False)
 and has_state_steps
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 and (arg := args[state_steps_ind])
 and isinstance(arg, Sequence)
 and arg[0].is_cuda
 or (
 "state_steps" in kwargs
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 and (kwarg := kwargs["state_steps"])
 and isinstance(kwarg, Sequence)
 and kwarg[0].is_cuda

@@ -362,18 +362,18 @@ class Optimizer:
 _optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
 _optimizer_step_post_hooks: dict[int, OptimizerPostHook]
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
 _optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
 _optimizer_state_dict_post_hooks: (
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
 )
 _optimizer_load_state_dict_pre_hooks: (
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
 )
 _optimizer_load_state_dict_post_hooks: (
-# pyrefly: ignore # not-a-type
+# pyrefly: ignore [not-a-type]
 'OrderedDict[int, Callable[["Optimizer"], None]]'
 )

@@ -522,7 +522,7 @@ class Optimizer:
 f"{func} must return None or a tuple of (new_args, new_kwargs), but got {result}."
 )
-# pyrefly: ignore # invalid-param-spec
+# pyrefly: ignore [invalid-param-spec]
 out = func(*args, **kwargs)
 self._optimizer_step_code()

@@ -961,9 +961,9 @@ class Optimizer:
 return Optimizer._process_value_according_to_param_policy(
|
||||||
param,
|
param,
|
||||||
value,
|
value,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
param_id,
|
param_id,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
param_groups,
|
param_groups,
|
||||||
key,
|
key,
|
||||||
)
|
)
|
||||||
|
|
@ -976,7 +976,7 @@ class Optimizer:
|
||||||
}
|
}
|
||||||
elif isinstance(value, Iterable):
|
elif isinstance(value, Iterable):
|
||||||
return type(value)(
|
return type(value)(
|
||||||
# pyrefly: ignore # bad-argument-count
|
# pyrefly: ignore [bad-argument-count]
|
||||||
_cast(param, v, param_id=param_id, param_groups=param_groups)
|
_cast(param, v, param_id=param_id, param_groups=param_groups)
|
||||||
for v in value
|
for v in value
|
||||||
) # type: ignore[call-arg]
|
) # type: ignore[call-arg]
|
||||||
|
|
|
||||||
|
|
@ -323,7 +323,7 @@ def _single_tensor_radam(
|
||||||
rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2
|
rho_t = rho_inf - 2 * step * (beta2**step) / bias_correction2
|
||||||
|
|
||||||
def _compute_rect():
|
def _compute_rect():
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
return (
|
return (
|
||||||
(rho_t - 4)
|
(rho_t - 4)
|
||||||
* (rho_t - 2)
|
* (rho_t - 2)
|
||||||
|
|
@ -338,7 +338,7 @@ def _single_tensor_radam(
|
||||||
else:
|
else:
|
||||||
exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)
|
exp_avg_sq_sqrt = exp_avg_sq_sqrt.add_(eps)
|
||||||
|
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
return (bias_correction2**0.5) / exp_avg_sq_sqrt
|
return (bias_correction2**0.5) / exp_avg_sq_sqrt
|
||||||
|
|
||||||
# Compute the variance rectification term and update parameters accordingly
|
# Compute the variance rectification term and update parameters accordingly
|
||||||
|
|
|
||||||
|
|
@ -348,7 +348,7 @@ def _single_tensor_sgd(
|
||||||
# usually this is the differentiable path, which is why the param.clone() is needed
|
# usually this is the differentiable path, which is why the param.clone() is needed
|
||||||
grad = grad.addcmul_(param.clone(), weight_decay)
|
grad = grad.addcmul_(param.clone(), weight_decay)
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
grad = grad.add(param, alpha=weight_decay)
|
grad = grad.add(param, alpha=weight_decay)
|
||||||
else:
|
else:
|
||||||
grad = grad.add(param, alpha=weight_decay)
|
grad = grad.add(param, alpha=weight_decay)
|
||||||
|
|
@ -372,7 +372,7 @@ def _single_tensor_sgd(
|
||||||
if lr.requires_grad:
|
if lr.requires_grad:
|
||||||
param.addcmul_(grad, lr, value=-1)
|
param.addcmul_(grad, lr, value=-1)
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
param.add_(grad, alpha=-lr)
|
param.add_(grad, alpha=-lr)
|
||||||
else:
|
else:
|
||||||
param.add_(grad, alpha=-lr)
|
param.add_(grad, alpha=-lr)
|
||||||
|
|
|
||||||
|
|
@ -250,13 +250,13 @@ class AveragedModel(Module):
|
||||||
def update_parameters(self, model: Module):
|
def update_parameters(self, model: Module):
|
||||||
"""Update model parameters."""
|
"""Update model parameters."""
|
||||||
self_param = (
|
self_param = (
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
itertools.chain(self.module.parameters(), self.module.buffers())
|
itertools.chain(self.module.parameters(), self.module.buffers())
|
||||||
if self.use_buffers
|
if self.use_buffers
|
||||||
else self.parameters()
|
else self.parameters()
|
||||||
)
|
)
|
||||||
model_param = (
|
model_param = (
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
itertools.chain(model.parameters(), model.buffers())
|
itertools.chain(model.parameters(), model.buffers())
|
||||||
if self.use_buffers
|
if self.use_buffers
|
||||||
else model.parameters()
|
else model.parameters()
|
||||||
|
|
@ -298,17 +298,17 @@ class AveragedModel(Module):
|
||||||
avg_fn = get_swa_avg_fn()
|
avg_fn = get_swa_avg_fn()
|
||||||
n_averaged = self.n_averaged.to(device)
|
n_averaged = self.n_averaged.to(device)
|
||||||
for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment]
|
for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment]
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
|
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
|
||||||
else:
|
else:
|
||||||
for p_averaged, p_model in zip( # type: ignore[assignment]
|
for p_averaged, p_model in zip( # type: ignore[assignment]
|
||||||
self_param_detached, model_param_detached
|
self_param_detached, model_param_detached
|
||||||
):
|
):
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
n_averaged = self.n_averaged.to(p_averaged.device)
|
n_averaged = self.n_averaged.to(p_averaged.device)
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
p_averaged.detach().copy_(
|
p_averaged.detach().copy_(
|
||||||
# pyrefly: ignore # missing-attribute, bad-argument-type
|
# pyrefly: ignore [missing-attribute, bad-argument-type]
|
||||||
self.avg_fn(p_averaged.detach(), p_model, n_averaged)
|
self.avg_fn(p_averaged.detach(), p_model, n_averaged)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -497,14 +497,14 @@ class SWALR(LRScheduler):
|
||||||
step = self._step_count - 1
|
step = self._step_count - 1
|
||||||
if self.anneal_epochs == 0:
|
if self.anneal_epochs == 0:
|
||||||
step = max(1, step)
|
step = max(1, step)
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
|
prev_t = max(0, min(1, (step - 1) / max(1, self.anneal_epochs)))
|
||||||
prev_alpha = self.anneal_func(prev_t)
|
prev_alpha = self.anneal_func(prev_t)
|
||||||
prev_lrs = [
|
prev_lrs = [
|
||||||
self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
|
self._get_initial_lr(group["lr"], group["swa_lr"], prev_alpha)
|
||||||
for group in self.optimizer.param_groups
|
for group in self.optimizer.param_groups
|
||||||
]
|
]
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
t = max(0, min(1, step / max(1, self.anneal_epochs)))
|
t = max(0, min(1, step / max(1, self.anneal_epochs)))
|
||||||
alpha = self.anneal_func(t)
|
alpha = self.anneal_func(t)
|
||||||
return [
|
return [
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
# pyrefly: ignore # missing-module-attribute
|
# pyrefly: ignore [missing-module-attribute]
|
||||||
from pickle import ( # type: ignore[attr-defined]
|
from pickle import ( # type: ignore[attr-defined]
|
||||||
_compat_pickle,
|
_compat_pickle,
|
||||||
_extension_registry,
|
_extension_registry,
|
||||||
|
|
|
||||||
|
|
@ -3,7 +3,7 @@ import importlib
|
||||||
import logging
|
import logging
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
# pyrefly: ignore # missing-module-attribute
|
# pyrefly: ignore [missing-module-attribute]
|
||||||
from pickle import ( # type: ignore[attr-defined]
|
from pickle import ( # type: ignore[attr-defined]
|
||||||
_getattribute,
|
_getattribute,
|
||||||
_Pickler,
|
_Pickler,
|
||||||
|
|
|
||||||
|
|
@ -652,7 +652,7 @@ class PackageExporter:
|
||||||
memo: defaultdict[int, str] = defaultdict(None)
|
memo: defaultdict[int, str] = defaultdict(None)
|
||||||
memo_count = 0
|
memo_count = 0
|
||||||
# pickletools.dis(data_value)
|
# pickletools.dis(data_value)
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
for opcode, arg, _pos in pickletools.genops(data_value):
|
for opcode, arg, _pos in pickletools.genops(data_value):
|
||||||
if pickle_protocol == 4:
|
if pickle_protocol == 4:
|
||||||
if (
|
if (
|
||||||
|
|
|
||||||
|
|
@ -230,7 +230,7 @@ class SchemaMatcher:
|
||||||
for schema in cls.match_schemas(t):
|
for schema in cls.match_schemas(t):
|
||||||
mutable = mutable or [False for _ in schema.arguments]
|
mutable = mutable or [False for _ in schema.arguments]
|
||||||
for i, arg in enumerate(schema.arguments):
|
for i, arg in enumerate(schema.arguments):
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
mutable[i] |= getattr(arg.alias_info, "is_write", False)
|
mutable[i] |= getattr(arg.alias_info, "is_write", False)
|
||||||
|
|
||||||
return tuple(mutable or (None for _ in t.inputs))
|
return tuple(mutable or (None for _ in t.inputs))
|
||||||
|
|
@ -1084,7 +1084,7 @@ class MemoryProfileTimeline:
|
||||||
|
|
||||||
if action in (Action.PREEXISTING, Action.CREATE):
|
if action in (Action.PREEXISTING, Action.CREATE):
|
||||||
raw_events.append(
|
raw_events.append(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
(
|
(
|
||||||
t,
|
t,
|
||||||
_ACTION_TO_INDEX[action],
|
_ACTION_TO_INDEX[action],
|
||||||
|
|
@ -1095,7 +1095,7 @@ class MemoryProfileTimeline:
|
||||||
|
|
||||||
elif action == Action.INCREMENT_VERSION:
|
elif action == Action.INCREMENT_VERSION:
|
||||||
raw_events.append(
|
raw_events.append(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
(
|
(
|
||||||
t,
|
t,
|
||||||
_ACTION_TO_INDEX[action],
|
_ACTION_TO_INDEX[action],
|
||||||
|
|
@ -1104,7 +1104,7 @@ class MemoryProfileTimeline:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
raw_events.append(
|
raw_events.append(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
(
|
(
|
||||||
t,
|
t,
|
||||||
_ACTION_TO_INDEX[action],
|
_ACTION_TO_INDEX[action],
|
||||||
|
|
@ -1115,7 +1115,7 @@ class MemoryProfileTimeline:
|
||||||
|
|
||||||
elif action == Action.DESTROY:
|
elif action == Action.DESTROY:
|
||||||
raw_events.append(
|
raw_events.append(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
(
|
(
|
||||||
t,
|
t,
|
||||||
_ACTION_TO_INDEX[action],
|
_ACTION_TO_INDEX[action],
|
||||||
|
|
|
||||||
|
|
@ -211,7 +211,7 @@ class BasicEvaluation:
|
||||||
# Find latest cuda kernel event
|
# Find latest cuda kernel event
|
||||||
if hasattr(event, "start_us"):
|
if hasattr(event, "start_us"):
|
||||||
start_time = event.start_us() * 1000
|
start_time = event.start_us() * 1000
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
end_time = (event.start_us() + event.duration_us()) * 1000
|
end_time = (event.start_us() + event.duration_us()) * 1000
|
||||||
# Find current spawned cuda kernel event
|
# Find current spawned cuda kernel event
|
||||||
if event in kernel_mapping and kernel_mapping[event] is not None:
|
if event in kernel_mapping and kernel_mapping[event] is not None:
|
||||||
|
|
|
||||||
|
|
@ -161,19 +161,19 @@ class _KinetoProfile:
|
||||||
self.mem_tl: Optional[MemoryProfileTimeline] = None
|
self.mem_tl: Optional[MemoryProfileTimeline] = None
|
||||||
self.use_device = None
|
self.use_device = None
|
||||||
if ProfilerActivity.CUDA in self.activities:
|
if ProfilerActivity.CUDA in self.activities:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.use_device = "cuda"
|
self.use_device = "cuda"
|
||||||
elif ProfilerActivity.XPU in self.activities:
|
elif ProfilerActivity.XPU in self.activities:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.use_device = "xpu"
|
self.use_device = "xpu"
|
||||||
elif ProfilerActivity.MTIA in self.activities:
|
elif ProfilerActivity.MTIA in self.activities:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.use_device = "mtia"
|
self.use_device = "mtia"
|
||||||
elif ProfilerActivity.HPU in self.activities:
|
elif ProfilerActivity.HPU in self.activities:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.use_device = "hpu"
|
self.use_device = "hpu"
|
||||||
elif ProfilerActivity.PrivateUse1 in self.activities:
|
elif ProfilerActivity.PrivateUse1 in self.activities:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.use_device = _get_privateuse1_backend_name()
|
self.use_device = _get_privateuse1_backend_name()
|
||||||
|
|
||||||
# user-defined metadata to be amended to the trace
|
# user-defined metadata to be amended to the trace
|
||||||
|
|
@ -385,7 +385,7 @@ class _KinetoProfile:
|
||||||
}
|
}
|
||||||
if backend == "nccl":
|
if backend == "nccl":
|
||||||
nccl_version = torch.cuda.nccl.version()
|
nccl_version = torch.cuda.nccl.version()
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
|
dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
|
||||||
return dist_info
|
return dist_info
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -71,7 +71,7 @@ def quantized_weight_reorder_for_mixed_dtypes_linear_cutlass(
|
||||||
nrows // 16, 16
|
nrows // 16, 16
|
||||||
)
|
)
|
||||||
).view(-1)
|
).view(-1)
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
outp = outp.index_copy(1, cols_permuted, outp)
|
outp = outp.index_copy(1, cols_permuted, outp)
|
||||||
|
|
||||||
# interleave_column_major_tensor
|
# interleave_column_major_tensor
|
||||||
|
|
|
||||||
|
|
@ -790,7 +790,7 @@ class _open_zipfile_writer_file(_opener[torch._C.PyTorchFileWriter]):
|
||||||
# PyTorchFileWriter only supports ascii filename.
|
# PyTorchFileWriter only supports ascii filename.
|
||||||
# For filenames with non-ascii characters, we rely on Python
|
# For filenames with non-ascii characters, we rely on Python
|
||||||
# for writing out the file.
|
# for writing out the file.
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.file_stream = io.FileIO(self.name, mode="w")
|
self.file_stream = io.FileIO(self.name, mode="w")
|
||||||
super().__init__(
|
super().__init__(
|
||||||
torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload
|
torch._C.PyTorchFileWriter( # pyrefly: ignore # no-matching-overload
|
||||||
|
|
|
||||||
|
|
@ -397,15 +397,15 @@ def kaiser(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Avoid NaNs by casting `beta` to the appropriate dtype.
|
# Avoid NaNs by casting `beta` to the appropriate dtype.
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
beta = torch.tensor(beta, dtype=dtype, device=device)
|
beta = torch.tensor(beta, dtype=dtype, device=device)
|
||||||
|
|
||||||
start = -beta
|
start = -beta
|
||||||
constant = 2.0 * beta / (M if not sym else M - 1)
|
constant = 2.0 * beta / (M if not sym else M - 1)
|
||||||
end = torch.minimum(
|
end = torch.minimum(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
beta,
|
beta,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
start + (M - 1) * constant,
|
start + (M - 1) * constant,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -420,7 +420,7 @@ def kaiser(
|
||||||
)
|
)
|
||||||
|
|
||||||
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(
|
return torch.i0(torch.sqrt(beta * beta - torch.pow(k, 2))) / torch.i0(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
beta
|
beta
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -623,20 +623,20 @@ def as_sparse_gradcheck(gradcheck):
|
||||||
)
|
)
|
||||||
obj = obj.to_dense().sparse_mask(full_mask)
|
obj = obj.to_dense().sparse_mask(full_mask)
|
||||||
if obj.layout is torch.sparse_coo:
|
if obj.layout is torch.sparse_coo:
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
d.update(
|
d.update(
|
||||||
indices=obj._indices(), is_coalesced=obj.is_coalesced()
|
indices=obj._indices(), is_coalesced=obj.is_coalesced()
|
||||||
)
|
)
|
||||||
values = obj._values()
|
values = obj._values()
|
||||||
elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
d.update(
|
d.update(
|
||||||
compressed_indices=obj.crow_indices(),
|
compressed_indices=obj.crow_indices(),
|
||||||
plain_indices=obj.col_indices(),
|
plain_indices=obj.col_indices(),
|
||||||
)
|
)
|
||||||
values = obj.values()
|
values = obj.values()
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
d.update(
|
d.update(
|
||||||
compressed_indices=obj.ccol_indices(),
|
compressed_indices=obj.ccol_indices(),
|
||||||
plain_indices=obj.row_indices(),
|
plain_indices=obj.row_indices(),
|
||||||
|
|
|
||||||
|
|
@ -140,7 +140,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
|
||||||
|
|
||||||
if dense.dtype != torch.float:
|
if dense.dtype != torch.float:
|
||||||
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
|
sparse0 = dense_4.gather(-1, idxs0.unsqueeze(-1)) # type: ignore[possibly-undefined]
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
|
sparse1 = dense_4.gather(-1, idxs1.unsqueeze(-1))
|
||||||
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
|
sparse = torch.stack((sparse0, sparse1), dim=-1).view(m, k // 2)
|
||||||
else:
|
else:
|
||||||
|
|
@ -173,7 +173,7 @@ def sparse_semi_structured_from_dense_cutlass(dense):
|
||||||
meta_offsets = _calculate_meta_reordering_scatter_offsets(
|
meta_offsets = _calculate_meta_reordering_scatter_offsets(
|
||||||
m, meta_ncols, meta_dtype, device
|
m, meta_ncols, meta_dtype, device
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
|
meta_reordered.scatter_(0, meta_offsets, meta.view(-1))
|
||||||
|
|
||||||
return (sparse, meta_reordered.view(m, meta_ncols))
|
return (sparse, meta_reordered.view(m, meta_ncols))
|
||||||
|
|
|
||||||
|
|
@ -67,7 +67,7 @@ def semi_sparse_t(func, types, args=(), kwargs=None) -> torch.Tensor:
|
||||||
# Because we cannot go from the compressed representation back to the dense representation currently,
|
# Because we cannot go from the compressed representation back to the dense representation currently,
|
||||||
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
|
# we just keep track of how many times we have been transposed. Depending on whether the sparse matrix
|
||||||
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
|
# is the first or second argument, we expect an even / odd number of calls to transpose respectively.
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
return self.__class__(
|
return self.__class__(
|
||||||
torch.Size([self.shape[-1], self.shape[0]]),
|
torch.Size([self.shape[-1], self.shape[0]]),
|
||||||
packed=self.packed_t,
|
packed=self.packed_t,
|
||||||
|
|
|
||||||
|
|
@ -1297,7 +1297,7 @@ def bsr_dense_addmm(
|
||||||
assert alpha != 0
|
assert alpha != 0
|
||||||
|
|
||||||
def kernel(grid, *sliced_tensors):
|
def kernel(grid, *sliced_tensors):
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
_bsr_strided_addmm_kernel[grid](
|
_bsr_strided_addmm_kernel[grid](
|
||||||
*ptr_stride_extractor(*sliced_tensors),
|
*ptr_stride_extractor(*sliced_tensors),
|
||||||
beta,
|
beta,
|
||||||
|
|
@ -1427,7 +1427,7 @@ if has_triton():
|
||||||
|
|
||||||
mat1_block = tl.load(
|
mat1_block = tl.load(
|
||||||
mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],
|
mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],
|
||||||
# pyrefly: ignore # index-error
|
# pyrefly: ignore [index-error]
|
||||||
mask=mask_k[None, :],
|
mask=mask_k[None, :],
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
|
|
@ -1436,7 +1436,7 @@ if has_triton():
|
||||||
mat2_block_ptrs
|
mat2_block_ptrs
|
||||||
+ mat2_tiled_col_stride * col_block
|
+ mat2_tiled_col_stride * col_block
|
||||||
+ mat2_row_block_stride * k_offsets[:, None],
|
+ mat2_row_block_stride * k_offsets[:, None],
|
||||||
# pyrefly: ignore # index-error
|
# pyrefly: ignore [index-error]
|
||||||
mask=mask_k[:, None],
|
mask=mask_k[:, None],
|
||||||
other=0.0,
|
other=0.0,
|
||||||
)
|
)
|
||||||
|
|
@ -1974,7 +1974,7 @@ if has_triton():
|
||||||
if attn_mask.dtype is not torch.bool:
|
if attn_mask.dtype is not torch.bool:
|
||||||
check_dtype(f_name, attn_mask, query.dtype)
|
check_dtype(f_name, attn_mask, query.dtype)
|
||||||
|
|
||||||
# pyrefly: ignore # not-callable
|
# pyrefly: ignore [not-callable]
|
||||||
sdpa = sampled_addmm(
|
sdpa = sampled_addmm(
|
||||||
attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False
|
attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False
|
||||||
)
|
)
|
||||||
|
|
@ -1986,10 +1986,10 @@ if has_triton():
|
||||||
)
|
)
|
||||||
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
||||||
sdpa.values().mul_(scale_factor)
|
sdpa.values().mul_(scale_factor)
|
||||||
# pyrefly: ignore # not-callable
|
# pyrefly: ignore [not-callable]
|
||||||
sdpa = bsr_softmax(sdpa)
|
sdpa = bsr_softmax(sdpa)
|
||||||
torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
|
torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
|
||||||
# pyrefly: ignore # not-callable
|
# pyrefly: ignore [not-callable]
|
||||||
sdpa = bsr_dense_mm(sdpa, value)
|
sdpa = bsr_dense_mm(sdpa, value)
|
||||||
return sdpa
|
return sdpa
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -234,10 +234,10 @@ def dump():
|
||||||
part2 = current_content[end_data_index:]
|
part2 = current_content[end_data_index:]
|
||||||
data_part = []
|
data_part = []
|
||||||
for op_key in sorted(_operation_device_version_data, key=sort_key):
|
for op_key in sorted(_operation_device_version_data, key=sort_key):
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
data_part.append(" " + repr(op_key).replace("'", '"') + ": {")
|
data_part.append(" " + repr(op_key).replace("'", '"') + ": {")
|
||||||
op_data = _operation_device_version_data[op_key]
|
op_data = _operation_device_version_data[op_key]
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
data_part.extend(f" {key}: {op_data[key]}," for key in sorted(op_data))
|
data_part.extend(f" {key}: {op_data[key]}," for key in sorted(op_data))
|
||||||
data_part.append(" },")
|
data_part.append(" },")
|
||||||
new_content = part1 + "\n".join(data_part) + "\n" + part2
|
new_content = part1 + "\n".join(data_part) + "\n" + part2
|
||||||
|
|
@ -371,7 +371,7 @@ def minimize(
|
||||||
if next_target < minimal_target:
|
if next_target < minimal_target:
|
||||||
minimal_target = next_target
|
minimal_target = next_target
|
||||||
parameters = next_parameters
|
parameters = next_parameters
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
pbar.total += i + 1
|
pbar.total += i + 1
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -185,7 +185,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
|
||||||
outer_stride,
|
outer_stride,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
|
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
return cls(
|
return cls(
|
||||||
shape=shape,
|
shape=shape,
|
||||||
packed=inner_tensors.get("packed", None),
|
packed=inner_tensors.get("packed", None),
|
||||||
|
|
@ -415,7 +415,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||||
sparse_tensor_cutlass,
|
sparse_tensor_cutlass,
|
||||||
meta_tensor_cutlass,
|
meta_tensor_cutlass,
|
||||||
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
|
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
return cls(
|
return cls(
|
||||||
original_tensor.shape,
|
original_tensor.shape,
|
||||||
packed=sparse_tensor_cutlass,
|
packed=sparse_tensor_cutlass,
|
||||||
|
|
@ -502,7 +502,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
|
||||||
original_tensor, algorithm=algorithm, use_cutlass=True
|
original_tensor, algorithm=algorithm, use_cutlass=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
return cls(
|
return cls(
|
||||||
original_tensor.shape,
|
original_tensor.shape,
|
||||||
packed=packed,
|
packed=packed,
|
||||||
|
|
@ -564,7 +564,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
||||||
cls, original_tensor: torch.Tensor
|
cls, original_tensor: torch.Tensor
|
||||||
) -> "SparseSemiStructuredTensorCUSPARSELT":
|
) -> "SparseSemiStructuredTensorCUSPARSELT":
|
||||||
cls._validate_device_dim_dtype_shape(original_tensor)
|
cls._validate_device_dim_dtype_shape(original_tensor)
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
return cls(
|
return cls(
|
||||||
shape=original_tensor.shape,
|
shape=original_tensor.shape,
|
||||||
packed=torch._cslt_compress(original_tensor),
|
packed=torch._cslt_compress(original_tensor),
|
||||||
|
|
@ -631,7 +631,7 @@ class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
|
||||||
packed = packed.view(original_tensor.shape[0], -1)
|
packed = packed.view(original_tensor.shape[0], -1)
|
||||||
packed_t = packed_t.view(original_tensor.shape[1], -1)
|
packed_t = packed_t.view(original_tensor.shape[1], -1)
|
||||||
|
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
return cls(
|
return cls(
|
||||||
original_tensor.shape,
|
original_tensor.shape,
|
||||||
packed=packed,
|
packed=packed,
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,6 @@ from torch._C import FileCheck as FileCheck
|
||||||
|
|
||||||
from . import _utils
|
from . import _utils
|
||||||
|
|
||||||
# pyrefly: ignore # deprecated
|
# pyrefly: ignore [deprecated]
|
||||||
from ._comparison import assert_allclose, assert_close as assert_close
|
from ._comparison import assert_allclose, assert_close as assert_close
|
||||||
from ._creation import make_tensor as make_tensor
|
from ._creation import make_tensor as make_tensor
|
||||||
|
|
|
||||||
|
|
@ -243,7 +243,7 @@ def make_scalar_mismatch_msg(
|
||||||
Defaults to "Scalars".
|
Defaults to "Scalars".
|
||||||
"""
|
"""
|
||||||
abs_diff = abs(actual - expected)
|
abs_diff = abs(actual - expected)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
|
rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
|
||||||
return _make_mismatch_msg(
|
return _make_mismatch_msg(
|
||||||
default_identifier="Scalars",
|
default_identifier="Scalars",
|
||||||
|
|
@ -487,7 +487,7 @@ class BooleanPair(Pair):
|
||||||
def _supported_types(self) -> tuple[type, ...]:
|
def _supported_types(self) -> tuple[type, ...]:
|
||||||
cls: list[type] = [bool]
|
cls: list[type] = [bool]
|
||||||
if HAS_NUMPY:
|
if HAS_NUMPY:
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
cls.append(np.bool_)
|
cls.append(np.bool_)
|
||||||
return tuple(cls)
|
return tuple(cls)
|
||||||
|
|
||||||
|
|
@ -503,7 +503,7 @@ class BooleanPair(Pair):
|
||||||
def _to_bool(self, bool_like: Any, *, id: tuple[Any, ...]) -> bool:
|
def _to_bool(self, bool_like: Any, *, id: tuple[Any, ...]) -> bool:
|
||||||
if isinstance(bool_like, bool):
|
if isinstance(bool_like, bool):
|
||||||
return bool_like
|
return bool_like
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
elif isinstance(bool_like, np.bool_):
|
elif isinstance(bool_like, np.bool_):
|
||||||
return bool_like.item()
|
return bool_like.item()
|
||||||
else:
|
else:
|
||||||
|
|
@ -583,7 +583,7 @@ class NumberPair(Pair):
|
||||||
def _supported_types(self) -> tuple[type, ...]:
|
def _supported_types(self) -> tuple[type, ...]:
|
||||||
cls = list(self._NUMBER_TYPES)
|
cls = list(self._NUMBER_TYPES)
|
||||||
if HAS_NUMPY:
|
if HAS_NUMPY:
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
cls.append(np.number)
|
cls.append(np.number)
|
||||||
return tuple(cls)
|
return tuple(cls)
|
||||||
|
|
||||||
|
|
@ -599,7 +599,7 @@ class NumberPair(Pair):
|
||||||
def _to_number(
|
def _to_number(
|
||||||
self, number_like: Any, *, id: tuple[Any, ...]
|
self, number_like: Any, *, id: tuple[Any, ...]
|
||||||
) -> Union[int, float, complex]:
|
) -> Union[int, float, complex]:
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
if HAS_NUMPY and isinstance(number_like, np.number):
|
if HAS_NUMPY and isinstance(number_like, np.number):
|
||||||
return number_like.item()
|
return number_like.item()
|
||||||
elif isinstance(number_like, self._NUMBER_TYPES):
|
elif isinstance(number_like, self._NUMBER_TYPES):
|
||||||
|
|
@ -1122,7 +1122,7 @@ def originate_pairs(
|
||||||
mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
|
mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
|
||||||
id: tuple[Any, ...] = (),
|
id: tuple[Any, ...] = (),
|
||||||
**options: Any,
|
**options: Any,
|
||||||
# pyrefly: ignore # bad-return
|
# pyrefly: ignore [bad-return]
|
||||||
) -> list[Pair]:
|
) -> list[Pair]:
|
||||||
"""Originates pairs from the individual inputs.
|
"""Originates pairs from the individual inputs.
|
||||||
|
|
||||||
|
|
@ -1221,7 +1221,7 @@ def originate_pairs(
|
||||||
else:
|
else:
|
||||||
for pair_type in pair_types:
|
for pair_type in pair_types:
|
||||||
try:
|
try:
|
||||||
# pyrefly: ignore # bad-instantiation
|
# pyrefly: ignore [bad-instantiation]
|
||||||
return [pair_type(actual, expected, id=id, **options)]
|
return [pair_type(actual, expected, id=id, **options)]
|
||||||
# Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
|
# Raising an `UnsupportedInputs` during origination indicates that the pair type is not able to handle the
|
||||||
# inputs. Thus, we try the next pair type.
|
# inputs. Thus, we try the next pair type.
|
||||||
|
|
@ -1319,9 +1319,9 @@ def not_close_error_metas(
|
||||||
# would not get freed until cycle collection, leaking cuda memory in tests.
|
# would not get freed until cycle collection, leaking cuda memory in tests.
|
||||||
# We break the cycle by removing the reference to the error_meta objects
|
# We break the cycle by removing the reference to the error_meta objects
|
||||||
# from this frame as it returns.
|
# from this frame as it returns.
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
error_metas = [error_metas]
|
error_metas = [error_metas]
|
||||||
# pyrefly: ignore # bad-return
|
# pyrefly: ignore [bad-return]
|
||||||
return error_metas.pop()
|
return error_metas.pop()
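
For orientation only, a minimal sketch of the bracketed suppression form that the hunks above switch to. The module, function, and variable below are invented for illustration and are not part of this change; only the comment syntax and the bad-assignment error code are taken from the diff itself:

# hypothetical example, not from the PyTorch sources
from typing import Optional

def maybe_count() -> Optional[int]:
    return None

count: int = 0
# pyrefly: ignore [bad-assignment]
count = maybe_count()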