Fix pyrefly ignores 1/n (#166239)

First diff in a series adjusting the syntax of pyrefly: ignore suppressions so that each suppression hides only one class of type error.
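For illustration, a minimal before/after sketch of the suppression syntax, taken from one of the hunks below (per the description above, the bracketed form limits the suppression to the named error class, whereas the old trailing-comment form did not):

Before:
    _check_with(RuntimeError, cond, message)  # pyrefly: ignore  # bad-argument-type
After:
    _check_with(RuntimeError, cond, message)  # pyrefly: ignore [bad-argument-type]

Where a line already carries another trailing comment, several hunks below instead add the bracketed suppression on its own line directly above the flagged statement.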

Test:
lintrunner
pyrefly check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166239
Approved by: https://github.com/oulgen
Maggie Moss authored on 2025-10-26 00:44:07 +00:00; committed by PyTorch MergeBot
parent 621ba05107
commit c7eee49525
55 changed files with 282 additions and 184 deletions

View File

@@ -130,6 +130,7 @@ errors.bad-param-name-override = false
 # Mypy doesn't require that imports are explicitly imported, so be compatible with that.
 # Might be a good idea to turn this on in future.
 errors.implicit-import = false
+errors.deprecated = false # re-enable after we've fix import formatting
 permissive-ignores = true
 replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
 search-path = ["tools/experimental"]

View File

@@ -2,7 +2,8 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Optional
-from dataclasses_json import DataClassJsonMixin
+# pyrefly: ignore [missing-import]
+from dataclasses_json import DataClassJsonMixin # type: ignore[import-not-found]
 _DATA_MODEL_VERSION = 1.5
@@ -17,7 +18,7 @@ class UtilizationStats:
 @dataclass
-class UtilizationMetadata(DataClassJsonMixin):
+class UtilizationMetadata(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 level: str
 workflow_id: str
 job_id: str
@@ -33,7 +34,7 @@ class UtilizationMetadata(DataClassJsonMixin):
 @dataclass
-class GpuUsage(DataClassJsonMixin):
+class GpuUsage(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 uuid: Optional[str] = None
 util_percent: Optional[UtilizationStats] = None
 mem_util_percent: Optional[UtilizationStats] = None
@@ -43,14 +44,14 @@ class GpuUsage(DataClassJsonMixin):
 @dataclass
-class RecordData(DataClassJsonMixin):
+class RecordData(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 cpu: Optional[UtilizationStats] = None
 memory: Optional[UtilizationStats] = None
 gpu_usage: Optional[list[GpuUsage]] = None
 @dataclass
-class UtilizationRecord(DataClassJsonMixin):
+class UtilizationRecord(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 level: str
 timestamp: int
 data: Optional[RecordData] = None
@@ -63,7 +64,7 @@ class UtilizationRecord(DataClassJsonMixin):
 # the db schema related to this is:
 # https://github.com/pytorch/test-infra/blob/main/clickhouse_db_schema/oss_ci_utilization/oss_ci_utilization_metadata_schema.sql
 @dataclass
-class OssCiSegmentV1(DataClassJsonMixin):
+class OssCiSegmentV1(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 level: str
 name: str
 start_at: int

View File

@@ -1703,7 +1703,7 @@ def _check(cond, message=None): # noqa: F811
 an object that has a ``__str__()`` method to be used as the error
 message. Default: ``None``
 """
-_check_with(RuntimeError, cond, message) # pyrefly: ignore # bad-argument-type
+_check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]
 # TODO add deprecation annotation
@@ -1753,7 +1753,7 @@ def _check_index(cond, message=None): # noqa: F811
 an object that has a ``__str__()`` method to be used as the error
 message. Default: ``None``
 """
-_check_with(IndexError, cond, message) # pyrefly: ignore # bad-argument-type
+_check_with(IndexError, cond, message) # pyrefly: ignore [bad-argument-type]
 def _check_value(cond, message=None): # noqa: F811
@@ -1771,7 +1771,7 @@ def _check_value(cond, message=None): # noqa: F811
 an object that has a ``__str__()`` method to be used as the error
 message. Default: ``None``
 """
-_check_with(ValueError, cond, message) # pyrefly: ignore # bad-argument-type
+_check_with(ValueError, cond, message) # pyrefly: ignore [bad-argument-type]
 def _check_type(cond, message=None): # noqa: F811
@@ -1789,7 +1789,7 @@ def _check_type(cond, message=None): # noqa: F811
 an object that has a ``__str__()`` method to be used as the error
 message. Default: ``None``
 """
-_check_with(TypeError, cond, message) # pyrefly: ignore # bad-argument-type
+_check_with(TypeError, cond, message) # pyrefly: ignore [bad-argument-type]
 def _check_not_implemented(cond, message=None): # noqa: F811

View File

@@ -101,7 +101,7 @@ def custom_op(
 lib, ns, function_schema, name, ophandle, _private_access=True
 )
-result.__name__ = func.__name__ # pyrefly: ignore # bad-assignment
+result.__name__ = func.__name__ # pyrefly: ignore [bad-assignment]
 result.__module__ = func.__module__
 result.__doc__ = func.__doc__

View File

@@ -154,7 +154,7 @@ def make_crossref_functionalize(
 maybe_detach, (f_args, f_kwargs)
 )
 with fake_mode:
-f_r = op(*f_args, **f_kwargs) # pyrefly: ignore # invalid-param-spec
+f_r = op(*f_args, **f_kwargs) # pyrefly: ignore [invalid-param-spec]
 r = op._op_dk(final_key, *args, **kwargs)
 def desc():

View File

@@ -1029,7 +1029,7 @@ class BuiltinVariable(VariableTracker):
 def call_self_handler(tx: "InstructionTranslator", args, kwargs):
 try:
-# pyrefly: ignore # not-callable
+# pyrefly: ignore [not-callable]
 result = self_handler(tx, *args, **kwargs)
 if result is not None:
 return result
@@ -1037,7 +1037,7 @@ class BuiltinVariable(VariableTracker):
 # Check if binding is bad. inspect signature bind is expensive.
 # So check only when handler call fails.
 try:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 inspect.signature(self_handler).bind(tx, *args, **kwargs)
 except TypeError as e:
 has_constant_handler = obj.has_constant_handler(args, kwargs)
@@ -1090,7 +1090,7 @@ class BuiltinVariable(VariableTracker):
 hints=[*graph_break_hints.DYNAMO_BUG],
 from_exc=exc,
 )
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return VariableTracker.build(tx, res)
 else:
@@ -1119,7 +1119,7 @@ class BuiltinVariable(VariableTracker):
 tx,
 args=list(map(ConstantVariable.create, exc.args)),
 )
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return VariableTracker.build(tx, res)
 handlers.append(constant_fold_handler)
@@ -1442,7 +1442,7 @@ class BuiltinVariable(VariableTracker):
 resolved_fn = getattr(self.fn, name)
 if resolved_fn in dict_methods:
 if isinstance(args[0], variables.UserDefinedDictVariable):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
 elif isinstance(args[0], variables.ConstDictVariable):
 return args[0].call_method(tx, name, args[1:], kwargs)
@@ -1451,7 +1451,7 @@ class BuiltinVariable(VariableTracker):
 resolved_fn = getattr(self.fn, name)
 if resolved_fn in set_methods:
 if isinstance(args[0], variables.UserDefinedSetVariable):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 return args[0]._set_vt.call_method(tx, name, args[1:], kwargs)
 elif isinstance(args[0], variables.SetVariable):
 return args[0].call_method(tx, name, args[1:], kwargs)
@@ -1540,12 +1540,12 @@ class BuiltinVariable(VariableTracker):
 if type(arg.value).__str__ is object.__str__:
 # Rely on the object str method
 try:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return variables.ConstantVariable.create(value=str_method())
 except AttributeError:
 # Graph break
 return
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 elif is_wrapper_or_member_descriptor(str_method):
 unimplemented_v2(
 gb_type="Attempted to a str() method implemented in C/C++",
@@ -1662,10 +1662,10 @@ class BuiltinVariable(VariableTracker):
 else:
 raw_b = b.raw_value
 if self.fn is max:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 raw_res = max(a.raw_value, raw_b)
 else:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 raw_res = min(a.raw_value, raw_b)
 need_unwrap = any(
@@ -1980,12 +1980,16 @@ class BuiltinVariable(VariableTracker):
 if isinstance(arg, dict):
 arg = [ConstantVariable.create(k) for k in arg.keys()]
 return DictVariableType(
-dict.fromkeys(arg, value), user_cls, mutation_type=ValueMutationNew()
+# pyrefly: ignore [bad-argument-type]
+dict.fromkeys(arg, value),
+user_cls,
+mutation_type=ValueMutationNew(),
 )
 elif arg.has_force_unpack_var_sequence(tx):
 keys = arg.force_unpack_var_sequence(tx)
 if all(is_hashable(v) for v in keys):
 return DictVariableType(
+# pyrefly: ignore [bad-argument-type]
 dict.fromkeys(keys, value),
 user_cls,
 mutation_type=ValueMutationNew(),
@@ -2152,7 +2156,7 @@ class BuiltinVariable(VariableTracker):
 )
 if isinstance(arg, variables.UserDefinedExceptionClassVariable):
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return ConstantVariable.create(isinstance(arg_type, isinstance_type))
 isinstance_type_tuple: tuple[type, ...]
@@ -2185,10 +2189,10 @@ class BuiltinVariable(VariableTracker):
 # through it. This is a limitation of the current implementation.
 # Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it
 # might not be a big issue and we trade off it for performance.
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 val = issubclass(arg_type, isinstance_type_tuple)
 except TypeError:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 val = arg_type in isinstance_type_tuple
 return variables.ConstantVariable.create(val)
@@ -2210,7 +2214,7 @@ class BuiltinVariable(VariableTracker):
 # WARNING: This might run arbitrary user code `__subclasscheck__`.
 # See the comment in call_isinstance above.
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))
 def call_super(self, tx: "InstructionTranslator", a, b):
@@ -2256,9 +2260,9 @@ class BuiltinVariable(VariableTracker):
 value = getattr(self.fn, name)
 except AttributeError:
 raise_observed_exception(AttributeError, tx)
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if not callable(value):
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return VariableTracker.build(tx, value, source)
 return variables.GetAttrVariable(self, name, source=source)

View File

@@ -34,6 +34,7 @@ class LazyCache:
 self.vt = builder.VariableBuilder(tx, self.source)(self.value)
 if self.name_hint is not None:
+# pyrefly: ignore [missing-attribute]
 self.vt.set_name_hint(self.name_hint)
 del self.value

View File

@@ -1138,11 +1138,13 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
 for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
 if isinstance(m, FakeTensorMode):
+# pyrefly: ignore [bad-argument-type]
 fake_modes.append((m, "active fake mode", i))
 flat_inputs = pytree.tree_leaves(inputs)
 for i, flat_input in enumerate(flat_inputs):
 if isinstance(flat_input, FakeTensor):
+# pyrefly: ignore [bad-argument-type]
 fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
 if is_traceable_wrapper_subclass(flat_input):
 out: list[Union[torch.Tensor, int, torch.SymInt]] = []
@@ -1151,6 +1153,7 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
 x for x in out if isinstance(x, FakeTensor)
 ]
 fake_modes.extend(
+# pyrefly: ignore [bad-argument-type]
 [
 (tensor.fake_mode, f"subclass input {i}", ix)
 for ix, tensor in enumerate(fake_tensors)
@@ -1162,9 +1165,12 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
 for m, desc2, i2 in fake_modes[1:]:
 assert fake_mode is m, (
 f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
+# pyrefly: ignore [missing-attribute]
 f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
+# pyrefly: ignore [missing-attribute]
 f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
 )
+# pyrefly: ignore [bad-return]
 return fake_mode
 else:
 return None

View File

@@ -114,6 +114,7 @@ def _cancel_all_tasks(
 for task in to_cancel:
 task.cancel()
+# pyrefly: ignore [bad-argument-type]
 loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
 for task in to_cancel:
@@ -149,7 +150,7 @@ def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[
 task_factory = task_factories[0]
 if task_factory is None:
 if sys.version_info >= (3, 11):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 task = asyncio.Task(coro, loop=loop, context=context)
 else:
 task = asyncio.Task(coro, loop=loop)

View File

@@ -590,11 +590,11 @@ class CKGemmTemplate(CKTemplate):
 arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
 else: # tile shape
 arg = f"/* {field_name} */ S<{tuple_elements}>"
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 template_params.append(arg)
 else:
 if field_value is not None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 template_params.append(f"/* {field_name} */ {field_value}")
 operation_name = op.name().replace("(", "").replace(",", "").replace(")", "")
 return self._template_from_string(template_definition).render(
@@ -939,6 +939,7 @@ class CKGemmTemplate(CKTemplate):
 for o in rops:
 kBatches = self._get_kBatch(o)
 for kBatch in kBatches:
+# pyrefly: ignore [bad-argument-type]
 ops.append(InductorROCmOp(op=o, kBatch=kBatch))
 filtered_instances = list(filter(lambda op: self.filter_op(op), ops))

View File

@@ -273,7 +273,7 @@ def record_original_output_strides(gm: GraphModule) -> None:
 ):
 output_strides.append(val.stride())
 else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 output_strides.append(None)
 output_node.meta["original_output_strides"] = output_strides
@@ -1110,6 +1110,7 @@ def _compile_fx_inner(
 )
 log.info("-" * 130)
 for row in mm_table_data:
+# pyrefly: ignore [not-iterable]
 log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001
 log.info("-" * 130)
@@ -1551,7 +1552,7 @@ class _InProcessFxCompile(FxCompile):
 node_runtimes = None
 if inductor_metrics_log.isEnabledFor(logging.INFO):
 num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 metrics.num_bytes_accessed += num_bytes
 metrics.node_runtimes += node_runtimes
 metrics.nodes_num_elem += nodes_num_elem
@@ -1595,10 +1596,10 @@ class _InProcessFxCompile(FxCompile):
 disable = f"{disable} Found from {stack_trace}\n"
 else:
 disable = f"{disable}\n"
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.disable_cudagraphs_reason = disable
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if cudagraphs and not V.graph.disable_cudagraphs_reason:
 maybe_incompat_node = get_first_incompatible_cudagraph_node(gm)
 if maybe_incompat_node:
@@ -1607,29 +1608,29 @@ class _InProcessFxCompile(FxCompile):
 "stack_trace", None
 ):
 disable = f"{disable} Found from {stack_trace}\n"
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.disable_cudagraphs_reason = disable
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if V.aot_compilation:
 assert isinstance(
 compiled_fn,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 (str, list, torch.fx.GraphModule),
 ), type(compiled_fn)
 return CompiledAOTI(compiled_fn)
 # TODO: Hoist this above V.aot_compilation
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if cudagraphs and not V.graph.disable_cudagraphs_reason:
 from torch._inductor.cudagraph_utils import (
 check_lowering_disable_cudagraph,
 )
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.disable_cudagraphs_reason = (
 check_lowering_disable_cudagraph(
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.device_node_mapping
 )
 )
@@ -1637,29 +1638,29 @@ class _InProcessFxCompile(FxCompile):
 self._compile_stats[type(self)].codegen_and_compile += 1
 if (
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 torch._inductor.debug.RECORD_GRAPH_EXECUTION
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 and torch._inductor.debug.GRAPH_COMPILE_IDS is not None
 ):
 compile_id = str(
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 torch._guards.CompileContext.current_compile_id()
 )
 graph_id = graph_kwargs.get("graph_id")
 if graph_id is not None:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 torch._inductor.debug.GRAPH_COMPILE_IDS[graph_id] = (
 compile_id
 )
 return CompiledFxGraph(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 compiled_fn,
 graph,
 gm,
 output_strides,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.disable_cudagraphs_reason,
 metrics_helper.get_deltas(),
 counters["inductor"] - inductor_counters,
@@ -1701,18 +1702,18 @@ def fx_codegen_and_compile(
 from .compile_fx_async import _AsyncFxCompile
 from .compile_fx_ext import _OutOfProcessFxCompile
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 assert isinstance(scheme, _OutOfProcessFxCompile), (
 "async is only valid with an out-of-process compile mode"
 )
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 scheme = _AsyncFxCompile(scheme)
 if fx_compile_progressive:
 from .compile_fx_async import _ProgressiveFxCompile
 from .compile_fx_ext import _OutOfProcessFxCompile
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 assert isinstance(scheme, _OutOfProcessFxCompile), (
 "progressive is only valid with an out-of-process compile mode"
 )
@@ -1722,10 +1723,10 @@ def fx_codegen_and_compile(
 # Use in-process compile for the fast version
 fast_scheme = _InProcessFxCompile()
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs)
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
@@ -1835,7 +1836,7 @@ def cudagraphify_impl(
 Assumes inputs[static_input_idxs[i]] are always the same memory address
 """
 check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type]
-# pyrefly: ignore # annotation-mismatch
+# pyrefly: ignore [annotation-mismatch]
 static_input_idxs: OrderedSet[int] = OrderedSet(
 remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type]
 )
@@ -1902,7 +1903,7 @@ def cudagraphify_impl(
 index_expanded_dims_and_copy_(dst, src, expanded_dims)
 new_inputs.clear()
 graph.replay()
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return static_outputs
 else:
@@ -1918,7 +1919,7 @@ def cudagraphify_impl(
 index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims)
 new_inputs.clear()
 graph.replay()
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return static_outputs
 return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet())
@@ -1935,7 +1936,7 @@ def compile_fx_aot(
 # [See NOTE] Unwrapping subclasses AOT
 unwrap_tensor_subclass_parameters(model_)
-# pyrefly: ignore # annotation-mismatch
+# pyrefly: ignore [annotation-mismatch]
 config_patches: dict[str, Any] = copy.deepcopy(config_patches or {})
 if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper):
@@ -2878,7 +2879,7 @@ def _aoti_flatten_inputs(
 Flatten the inputs to the graph module and return the flat inputs and options.
 Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options.
 """
-# pyrefly: ignore # missing-module-attribute
+# pyrefly: ignore [missing-module-attribute]
 from .compile_fx import graph_returns_tuple
 assert graph_returns_tuple(gm), (

View File

@@ -291,7 +291,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
 log.debug("example value absent for node: %s", input)
 return
 ndim = input.meta["example_value"].ndim
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 if dim < 0: # Normalize unbind dim
 dim += ndim
 with graph.inserting_after(node):
@@ -341,7 +341,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
 ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
 )
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 if cat_dim < 0: # Normalize cat dim
 cat_dim += ndim
@@ -949,7 +949,7 @@ class SplitCatSimplifier:
 if isinstance(user_input, tuple):
 # Find the correct new getitem (present in split_items)
 new_user_inputs.append(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 split_items[
 split_ranges.index(
 (
@@ -1000,7 +1000,7 @@ class SplitCatSimplifier:
 for user_input_new, transform_param in zip(
 user_inputs_new, transform_params
 ):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 if not is_node_meta_valid(user_input_new):
 log.debug("example value absent for node: %s", user_input_new)
 return
@@ -1015,7 +1015,7 @@ class SplitCatSimplifier:
 stack_dim is None or stack_dim == unsqueeze_params[0]
 ):
 to_stack.append(user_input_new)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 to_stack_meta.append(user_input_new.meta["example_value"])
 stack_dim = unsqueeze_params[0]
 continue
@@ -1036,12 +1036,12 @@ class SplitCatSimplifier:
 if unsqueeze_params:
 to_stack.append(user_input_new)
 stack_dim = unsqueeze_params[0]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 to_stack_meta.append(user_input_new.meta["example_value"])
 continue
 if unflatten_params:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 user_input_new_meta = user_input_new.meta["example_value"]
 user_input_new = graph.call_function(
 torch.unflatten, args=(user_input_new, *unflatten_params)
@@ -1051,7 +1051,7 @@ class SplitCatSimplifier:
 *unflatten_params, # type: ignore[arg-type]
 )
 if movedim_params:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 user_input_new_meta = user_input_new.meta["example_value"]
 user_input_new = graph.call_function(
 torch.movedim, args=(user_input_new, *movedim_params)
@@ -1061,7 +1061,7 @@ class SplitCatSimplifier:
 *movedim_params, # type: ignore[arg-type]
 )
 if flatten_params:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 user_input_new_meta = user_input_new.meta["example_value"]
 user_input_new = graph.call_function(
 torch.flatten, args=(user_input_new, *flatten_params)
@@ -1072,7 +1072,7 @@ class SplitCatSimplifier:
 )
 user_inputs_new_transformed.append(user_input_new)
 user_inputs_new_transformed_meta.append(
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 user_input_new.meta["example_value"]
 )
 if to_stack:
@@ -1432,7 +1432,7 @@ def simplify_split_cat(match: Match, split_sections: list[int], dim: int):
 if not isinstance(split_sections, (list, tuple)): # Unnormalized split
 return
 split_node = next(node for node in match.nodes if node.target == torch.split)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 SplitCatSimplifier().simplify(match.graph, split_node, split_sections)
@@ -1501,7 +1501,7 @@ def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) -
 for i in range(len(split_node.args[1])): # type: ignore[arg-type]
 if i in indices:
 fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return fused_tensor_size
@@ -1978,7 +1978,7 @@ def normalize_cat_default_aten(match: Match, *args, **kwargs):
 assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors)
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 if cat_dim < 0: # Normalize cat dim
 cat_dim += ndim
@@ -2512,7 +2512,8 @@ def reshape_cat_node_to_stack(
 args=(cat_node, tuple(reshape_list)),
 )
 reshape_node.meta["example_value"] = torch.reshape(
-cat_node.meta["example_value"], tuple(reshape_list)
+cat_node.meta["example_value"],
+tuple(reshape_list), # pyrefly: ignore [bad-argument-type]
 )
 permute_list = list(range(len(stack_shape)))
 permute_list[stack_dim], permute_list[split_or_unbind_dim] = (
@@ -3044,6 +3045,6 @@ def replace_einsum_to_pointwise(match: Match, *args, **kwargs):
 einsum_node = match.nodes[0]
 input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2)
 if should_replace_einsum(einsum_node):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 match.replace_by_example(repl, [input, weights])
 counters[backend]["einsum_to_pointwise_pass"] += 1

View File

@@ -147,7 +147,7 @@ def _qualified_name(obj, mangle_name=True) -> str:
 # If the module is actually a torchbind module, then we should short circuit
 if module_name == "torch._classes":
-return obj.qualified_name # pyrefly: ignore # missing-attribute
+return obj.qualified_name # pyrefly: ignore [missing-attribute]
 # The Python docs are very clear that `__module__` can be None, but I can't
 # figure out when it actually would be.
@@ -759,7 +759,7 @@ def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
 )
-return prop # pyrefly: ignore # bad-return
+return prop # pyrefly: ignore [bad-return]
 fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
 return fn

View File

@@ -65,8 +65,8 @@ class AsyncClosureHandler(ClosureHandler):
 self._closure_event_loop = threading.Thread(
 target=event_loop
-) # pyrefly: ignore # bad-assignment
-self._closure_event_loop.start() # pyrefly: ignore # missing-attribute
+) # pyrefly: ignore [bad-assignment]
+self._closure_event_loop.start() # pyrefly: ignore [missing-attribute]
 def run(self, closure):
 with self._closure_lock:

View File

@@ -515,8 +515,11 @@ def meta_copy_(self, src, non_blocking=False):
 def inferUnsqueezeGeometry(tensor, dim):
 result_sizes = list(tensor.size())
 result_strides = list(tensor.stride())
+# pyrefly: ignore [unsupported-operation]
 new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
+# pyrefly: ignore [bad-argument-type]
 result_sizes.insert(dim, 1)
+# pyrefly: ignore [bad-argument-type]
 result_strides.insert(dim, new_stride)
 return result_sizes, result_strides
@@ -2341,19 +2344,19 @@ def calc_conv_nd_return_shape(
 ret_shape = [input_tensor.shape[0], out_channels]
 if isinstance(stride, IntLike):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 stride = [stride] * len(dims)
 elif len(stride) == 1:
 stride = [stride[0]] * len(dims)
 if isinstance(padding, IntLike):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 padding = [padding] * len(dims)
 elif len(padding) == 1:
 padding = [padding[0]] * len(dims)
 if isinstance(dilation, IntLike):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 dilation = [dilation] * len(dims)
 elif len(dilation) == 1:
 dilation = [dilation[0]] * len(dims)
@@ -2361,7 +2364,7 @@ def calc_conv_nd_return_shape(
 output_padding_list: Optional[list[int]] = None
 if output_padding:
 if isinstance(output_padding, IntLike):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 output_padding_list = [output_padding] * len(dims)
 elif len(output_padding) == 1:
 output_padding_list = [output_padding[0]] * len(dims)
@@ -2374,19 +2377,19 @@ def calc_conv_nd_return_shape(
 ret_shape.append(
 _formula_transposed(
 dims[i],
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 padding[i],
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 dilation[i],
 kernel_size[i],
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 stride[i],
 output_padding_list[i],
 )
 )
 else:
 ret_shape.append(
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
 )
 from torch.fx.experimental.symbolic_shapes import sym_or
@@ -3454,7 +3457,7 @@ def meta_index_Tensor(self, indices):
 """
 shape = before_shape + replacement_shape + after_shape
 strides = list(self.stride())
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
 replacement_shape
 )
@@ -5311,6 +5314,7 @@ def full(size, fill_value, *args, **kwargs):
 if not dtype:
 dtype = utils.get_dtype(fill_value)
 kwargs["dtype"] = dtype
+# pyrefly: ignore [not-iterable]
 return torch.empty(size, *args, **kwargs)
@@ -6668,7 +6672,7 @@ def rnn_cell_checkSizes(
 )
 torch._check(
 all(
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 x.device == input_gates.device
 for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
 ),

View File

@@ -880,7 +880,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
 elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
 return self._op_dk(dk, *args, **kwargs)
 else:
-return NotImplemented # pyrefly: ignore # bad-return
+return NotImplemented # pyrefly: ignore [bad-return]
 # Remove a dispatch key from the dispatch cache. This will force it to get
 # recomputed the next time. Does nothing
@@ -985,9 +985,9 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
 r = self.py_kernels.get(final_key, final_key)
 if cache_result:
-self._dispatch_cache[key] = r # pyrefly: ignore # unsupported-operation
+self._dispatch_cache[key] = r # pyrefly: ignore [unsupported-operation]
 add_cached_op(self)
-return r # pyrefly: ignore # bad-return
+return r # pyrefly: ignore [bad-return]
 def name(self):
 return self._name
@@ -1117,7 +1117,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
 )
 assert isinstance(handler, Callable) # type: ignore[arg-type]
-return handler(*args, **kwargs) # pyrefly: ignore # bad-return
+return handler(*args, **kwargs) # pyrefly: ignore [bad-return]
 def _must_dispatch_in_python(args, kwargs):

View File

@@ -267,6 +267,7 @@ class FunctionalTensor(torch.Tensor):
 device=self.device,
 layout=self.layout,
 )
+# pyrefly: ignore [not-iterable]
 return super().to(*args, **kwargs)
 def cuda(self, device=None, *args, **kwargs):

View File

@@ -551,6 +551,7 @@ class Tensor(torch._C.TensorBase):
 raise RuntimeError("__setstate__ can be only called on leaf Tensors")
 if len(state) == 4:
 # legacy serialization of Tensor
+# pyrefly: ignore [not-iterable]
 self.set_(*state)
 return
 elif len(state) == 5:
@@ -758,7 +759,7 @@ class Tensor(torch._C.TensorBase):
 )
 if self._post_accumulate_grad_hooks is None:
 self._post_accumulate_grad_hooks: dict[Any, Any] = (
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 OrderedDict()
 )
@@ -1062,7 +1063,7 @@ class Tensor(torch._C.TensorBase):
 else:
 return torch._VF.split_with_sizes(
 self,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 split_size,
 dim,
 )
@@ -1119,7 +1120,7 @@ class Tensor(torch._C.TensorBase):
 __rtruediv__ = __rdiv__
 __itruediv__ = _C.TensorBase.__idiv__
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 __pow__ = cast(
 Callable[
 ["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]],

View File

@@ -686,8 +686,8 @@ def _take_tensors(tensors, size_limit):
 if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
 yield buf_and_size[0]
 buf_and_size = buf_dict[t] = [[], 0]
-buf_and_size[0].append(tensor) # pyrefly: ignore # missing-attribute
-buf_and_size[1] += size # pyrefly: ignore # unsupported-operation
+buf_and_size[0].append(tensor) # pyrefly: ignore [missing-attribute]
+buf_and_size[1] += size # pyrefly: ignore [unsupported-operation]
 for buf, _ in buf_dict.values():
 if len(buf) > 0:
 yield buf
@@ -744,6 +744,7 @@ class ExceptionWrapper:
 if exc_info is None:
 exc_info = sys.exc_info()
 self.exc_type = exc_info[0]
+# pyrefly: ignore [not-iterable]
 self.exc_msg = "".join(traceback.format_exception(*exc_info))
 self.where = where
@@ -751,7 +752,7 @@ class ExceptionWrapper:
 r"""Reraises the wrapped exception in the current thread"""
 # Format a message such as: "Caught ValueError in DataLoader worker
 # process 2. Original Traceback:", followed by the traceback.
-msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute
+msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore [missing-attribute]
 if self.exc_type is KeyError:
 # KeyError calls repr() on its argument (usually a dict key). This
 # makes stack traces unreadable. It will not be changed in Python
@@ -760,13 +761,13 @@ class ExceptionWrapper:
 elif getattr(self.exc_type, "message", None):
 # Some exceptions have first argument as non-str but explicitly
 # have message field
-# pyrefly: ignore # not-callable
+# pyrefly: ignore [not-callable]
 raise self.exc_type(
-# pyrefly: ignore # unexpected-keyword
+# pyrefly: ignore [unexpected-keyword]
 message=msg
 )
 try:
-exception = self.exc_type(msg) # pyrefly: ignore # not-callable
+exception = self.exc_type(msg) # pyrefly: ignore [not-callable]
 except Exception:
 # If the exception takes multiple arguments or otherwise can't
 # be constructed, don't try to instantiate since we don't know how to
@@ -1018,12 +1019,12 @@ class _LazySeedTracker:
 self.call_order = []
 def queue_seed_all(self, cb, traceback):
-self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
+self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
 # update seed_all to be latest
 self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
 def queue_seed(self, cb, traceback):
-self.manual_seed_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
+self.manual_seed_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
 # update seed to be latest
 self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

View File

@@ -419,6 +419,7 @@ class Unpickler:
 inst = self.stack[-1]
 if type(inst) is torch.Tensor:
 # Legacy unpickling
+# pyrefly: ignore [not-iterable]
 inst.set_(*state)
 elif type(inst) is torch.nn.Parameter:
 inst.__setstate__(state)

View File

@@ -104,6 +104,7 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
 flatten("", stats)
 flat_stats.sort()
+# pyrefly: ignore [no-matching-overload]
 return OrderedDict(flat_stats)

View File

@@ -525,7 +525,7 @@ def custom_fwd(
 args[0]._dtype = torch.get_autocast_dtype(device_type)
 if cast_inputs is None:
 args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
-return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
+return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]
 else:
 autocast_context = torch.is_autocast_enabled(device_type)
 args[0]._fwd_used_autocast = False
@@ -536,7 +536,7 @@ def custom_fwd(
 **_cast(kwargs, device_type, cast_inputs),
 )
 else:
-return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
+return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]
 return decorate_fwd
@@ -571,6 +571,6 @@ def custom_bwd(bwd=None, *, device_type: str):
 enabled=args[0]._fwd_used_autocast,
 dtype=args[0]._dtype,
 ):
-return bwd(*args, **kwargs) # pyrefly: ignore # not-callable
+return bwd(*args, **kwargs) # pyrefly: ignore [not-callable]
 return decorate_bwd

View File

@@ -84,7 +84,7 @@ class _NSGraphMatchableSubgraphsIterator:
 if is_match:
 # navigate to the base node
 for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.seen_nodes.add(cur_start_node)
 # for now, assume that there are no other nodes
 # which need to be added to the stack
@@ -95,10 +95,10 @@ class _NSGraphMatchableSubgraphsIterator:
 cur_base_op_node = cur_start_node
 break
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.seen_nodes.add(cur_start_node)
 # add args of previous nodes to stack
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 for arg in cur_start_node.all_input_nodes:
 self._recursively_add_node_arg_to_stack(arg)
@@ -106,7 +106,7 @@ class _NSGraphMatchableSubgraphsIterator:
 # note: this check is done on the start_node, i.e.
 # if we are matching linear-relu in reverse, this would do the matchable
 # check on the linear
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 if not self._is_matchable(cur_base_op_node):
 continue
@@ -120,10 +120,10 @@ class _NSGraphMatchableSubgraphsIterator:
 continue
 return NSSubgraph(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 start_node=cur_start_node,
 end_node=cur_end_node,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 base_op_node=cur_base_op_node,
 )
@@ -481,4 +481,5 @@ of subgraphs."""
 # subgraphs in their order of execution.
 results = collections.OrderedDict(reversed(results.items()))
+# pyrefly: ignore [bad-return]
 return results

View File

@@ -30,6 +30,7 @@ class EventList(list):
 use_device = kwargs.pop("use_device", None)
 profile_memory = kwargs.pop("profile_memory", False)
 with_flops = kwargs.pop("with_flops", False)
+# pyrefly: ignore [not-iterable]
 super().__init__(*args, **kwargs)
 self._use_device = use_device
 self._profile_memory = profile_memory
@@ -505,9 +506,9 @@ class FunctionEvent(FormattedTimesMixin):
 self.id: int = id
 self.node_id: int = node_id
 self.name: str = name
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.overload_name: str = overload_name
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.trace_name: str = trace_name
 self.time_range: Interval = Interval(start_us, end_us)
 self.thread: int = thread
@@ -516,13 +517,13 @@ class FunctionEvent(FormattedTimesMixin):
 self.count: int = 1
 self.cpu_children: list[FunctionEvent] = []
 self.cpu_parent: Optional[FunctionEvent] = None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.input_shapes: tuple[int, ...] = input_shapes
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.concrete_inputs: list[Any] = concrete_inputs
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.kwinputs: dict[str, Any] = kwinputs
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.stack: list = stack
 self.scope: int = scope
 self.use_device: Optional[str] = use_device
@@ -766,7 +767,7 @@ class FunctionEventAvg(FormattedTimesMixin):
 self.self_device_memory_usage += other.self_device_memory_usage
 self.count += other.count
 if self.flops is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.flops = other.flops
 elif other.flops is not None:
 self.flops += other.flops
@@ -1003,7 +1004,7 @@ def _build_table(
 ]
 if flops <= 0:
 raise AssertionError(f"Expected flops to be positive, but got {flops}")
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
 if not (log_flops >= 0 and log_flops < len(flop_headers)):
 raise AssertionError(


@ -50,6 +50,7 @@ def compile(*args, **kwargs):
""" """
See :func:`torch.compile` for details on the arguments for this function. See :func:`torch.compile` for details on the arguments for this function.
""" """
# pyrefly: ignore [not-iterable]
return torch.compile(*args, **kwargs) return torch.compile(*args, **kwargs)


@ -198,6 +198,7 @@ def _for_each_rank_run_func(
rr_val = flat_rank_rets[rr_key] rr_val = flat_rank_rets[rr_key]
if isinstance(rr_val, Tensor): if isinstance(rr_val, Tensor):
# pyrefly: ignore [bad-argument-type, bad-argument-count]
ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)}) ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)})
elif isinstance(rr_val, (list, tuple)): elif isinstance(rr_val, (list, tuple)):
ret_list = [] ret_list = []
@ -206,6 +207,7 @@ def _for_each_rank_run_func(
v_it = iter(rets.values()) v_it = iter(rets.values())
v = next(v_it) v = next(v_it)
if isinstance(v, Tensor): if isinstance(v, Tensor):
# pyrefly: ignore [bad-argument-type, bad-argument-count]
ret_list.append(LocalTensor(rets)) ret_list.append(LocalTensor(rets))
elif isinstance(v, int) and not all(v == v2 for v2 in v_it): elif isinstance(v, int) and not all(v == v2 for v2 in v_it):
ret_list.append(torch.SymInt(LocalIntNode(rets))) ret_list.append(torch.SymInt(LocalIntNode(rets)))
@ -468,7 +470,7 @@ class LocalTensor(torch.Tensor):
def __repr__(self) -> str: # type: ignore[override] def __repr__(self) -> str: # type: ignore[override]
parts = [] parts = []
for k, v in self._local_tensors.items(): for k, v in self._local_tensors.items():
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
parts.append(f" {k}: {v}") parts.append(f" {k}: {v}")
tensors_str = ",\n".join(parts) tensors_str = ",\n".join(parts)
return f"LocalTensor(\n{tensors_str}\n)" return f"LocalTensor(\n{tensors_str}\n)"
@ -491,6 +493,7 @@ class LocalTensor(torch.Tensor):
"Expecting spec to be not None from `__tensor_flatten__` return value!" "Expecting spec to be not None from `__tensor_flatten__` return value!"
) )
local_tensors = inner_tensors["_local_tensors"] local_tensors = inner_tensors["_local_tensors"]
# pyrefly: ignore [bad-argument-type, bad-argument-count]
return LocalTensor(local_tensors) return LocalTensor(local_tensors)
@classmethod @classmethod
@ -751,6 +754,7 @@ class LocalTensorMode(TorchDispatchMode):
""" """
with self.disable(): with self.disable():
# pyrefly: ignore [bad-argument-type, bad-argument-count]
return LocalTensor({r: cb(r) for r in self.ranks}) return LocalTensor({r: cb(r) for r in self.ranks})
def _patch_device_mesh(self) -> None: def _patch_device_mesh(self) -> None:
@ -761,7 +765,7 @@ class LocalTensorMode(TorchDispatchMode):
def _unpatch_device_mesh(self) -> None: def _unpatch_device_mesh(self) -> None:
assert self._old_get_coordinate is not None assert self._old_get_coordinate is not None
DeviceMesh.get_coordinate = self._old_get_coordinate DeviceMesh.get_coordinate = self._old_get_coordinate
# pyrefly: ignore # bad-assignment # pyrefly: ignore [bad-assignment]
self._old_get_coordinate = None self._old_get_coordinate = None


@ -79,6 +79,7 @@ def _flatten_tensor_size(size) -> torch.Size:
Checks if tensor size is valid, then flatten/return a torch.Size object. Checks if tensor size is valid, then flatten/return a torch.Size object.
""" """
if len(size) == 1 and isinstance(size[0], collections.abc.Sequence): if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
# pyrefly: ignore [not-iterable]
dims = list(*size) dims = list(*size)
else: else:
dims = list(size) dims = list(size)
@ -208,7 +209,7 @@ def build_global_metadata(
global_sharded_tensor_metadata = None global_sharded_tensor_metadata = None
global_metadata_rank = 0 global_metadata_rank = 0
# pyrefly: ignore # bad-assignment # pyrefly: ignore [bad-assignment]
for rank, rank_metadata in enumerate(gathered_metadatas): for rank, rank_metadata in enumerate(gathered_metadatas):
if rank_metadata is None: if rank_metadata is None:
continue continue


@ -227,7 +227,7 @@ class PGTransport:
self._work: list[Work] = [] self._work: list[Work] = []
self._pg = pg self._pg = pg
self._timeout = timeout self._timeout = timeout
# pyrefly: ignore # read-only # pyrefly: ignore [read-only]
self._device = device self._device = device
self._state_dict = state_dict self._state_dict = state_dict
@ -345,6 +345,7 @@ class PGTransport:
values.append(recv(path, v)) values.append(recv(path, v))
elif isinstance(v, _DTensorMeta): elif isinstance(v, _DTensorMeta):
tensor = recv(path, v.local) tensor = recv(path, v.local)
# pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword]
values.append(DTensor(tensor, v.spec, requires_grad=False)) values.append(DTensor(tensor, v.spec, requires_grad=False))
elif isinstance(v, _ShardedTensorMeta): elif isinstance(v, _ShardedTensorMeta):
# Receive all local shards that were sent to us # Receive all local shards that were sent to us


@ -565,7 +565,7 @@ class FlatParamHandle:
# Only align addresses for `use_orig_params=True` (for now) # Only align addresses for `use_orig_params=True` (for now)
align_addresses = use_orig_params align_addresses = use_orig_params
self._init_get_unflat_views_fn(align_addresses) self._init_get_unflat_views_fn(align_addresses)
# pyrefly: ignore # read-only # pyrefly: ignore [read-only]
self.device = device self.device = device
self._device_handle = _FSDPDeviceHandle.from_device(self.device) self._device_handle = _FSDPDeviceHandle.from_device(self.device)
self.process_group = process_group self.process_group = process_group
@ -2495,6 +2495,7 @@ class FlatParamHandle:
########### ###########
def flat_param_to(self, *args, **kwargs): def flat_param_to(self, *args, **kwargs):
"""Wrap an in-place call to ``.to()`` for ``self.flat_param``.""" """Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
# pyrefly: ignore [not-iterable]
self.flat_param.data = self.flat_param.to(*args, **kwargs) self.flat_param.data = self.flat_param.to(*args, **kwargs)
if self._use_orig_params: if self._use_orig_params:
# Refresh the views because their storage may have changed # Refresh the views because their storage may have changed


@ -139,11 +139,14 @@ def _from_local_no_grad(
""" """
if not compiled_autograd_enabled(): if not compiled_autograd_enabled():
# pyrefly: ignore [bad-argument-type]
return DTensor( return DTensor(
# Use the local tensor directly instead of constructing a new tensor # Use the local tensor directly instead of constructing a new tensor
# variable, e.g. with `view_as()`, since this is not differentiable # variable, e.g. with `view_as()`, since this is not differentiable
# pyrefly: ignore [bad-argument-count]
local_tensor, local_tensor,
sharding_spec, sharding_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=local_tensor.requires_grad, requires_grad=local_tensor.requires_grad,
) )
else: else:


@ -107,9 +107,12 @@ class _ToTorchTensor(torch.autograd.Function):
) )
return ( return (
# pyrefly: ignore [bad-argument-type]
DTensor( DTensor(
# pyrefly: ignore [bad-argument-count]
grad_output, grad_output,
grad_spec, grad_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=grad_output.requires_grad, requires_grad=grad_output.requires_grad,
), ),
None, None,
@ -175,11 +178,14 @@ class _FromTorchTensor(torch.autograd.Function):
) )
# We want a fresh Tensor object that shares memory with the input tensor # We want a fresh Tensor object that shares memory with the input tensor
# pyrefly: ignore [bad-argument-type]
dist_tensor = DTensor( dist_tensor = DTensor(
# pyrefly: ignore [bad-argument-count]
input.view_as(input), input.view_as(input),
dist_spec, dist_spec,
# requires_grad of the dist tensor depends on if input # requires_grad of the dist tensor depends on if input
# requires_grad or not # requires_grad or not
# pyrefly: ignore [unexpected-keyword]
requires_grad=input.requires_grad, requires_grad=input.requires_grad,
) )
return dist_tensor return dist_tensor
@ -304,9 +310,12 @@ class DTensor(torch.Tensor):
spec.placements, spec.placements,
tensor_meta=unflatten_tensor_meta, tensor_meta=unflatten_tensor_meta,
) )
# pyrefly: ignore [bad-argument-type]
return DTensor( return DTensor(
# pyrefly: ignore [bad-argument-count]
local_tensor, local_tensor,
unflatten_spec, unflatten_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=requires_grad, requires_grad=requires_grad,
) )
@ -820,9 +829,12 @@ def distribute_tensor(
dtype=tensor.dtype, dtype=tensor.dtype,
), ),
) )
# pyrefly: ignore [bad-argument-type]
return DTensor( return DTensor(
# pyrefly: ignore [bad-argument-count]
local_tensor.requires_grad_(tensor.requires_grad), local_tensor.requires_grad_(tensor.requires_grad),
spec, spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=tensor.requires_grad, requires_grad=tensor.requires_grad,
) )
@ -1077,9 +1089,12 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
), ),
) )
# pyrefly: ignore [bad-argument-type]
return DTensor( return DTensor(
# pyrefly: ignore [bad-argument-count]
local_tensor, local_tensor,
spec, spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=kwargs["requires_grad"], requires_grad=kwargs["requires_grad"],
) )


@ -78,8 +78,11 @@ def found_inf_reduce_handler(
dtype=target_tensor.dtype, dtype=target_tensor.dtype,
), ),
) )
# pyrefly: ignore [bad-argument-type]
found_inf_dtensor = dtensor.DTensor( found_inf_dtensor = dtensor.DTensor(
local_tensor=target_tensor, spec=spec, requires_grad=False local_tensor=target_tensor, # pyrefly: ignore [unexpected-keyword]
spec=spec, # pyrefly: ignore [unexpected-keyword]
requires_grad=False, # pyrefly: ignore [unexpected-keyword]
) )
found_inf = found_inf_dtensor.full_tensor() found_inf = found_inf_dtensor.full_tensor()
target_tensor.copy_(found_inf) target_tensor.copy_(found_inf)
@ -189,7 +192,7 @@ class OpDispatcher:
local_tensor_args = ( local_tensor_args = (
pytree.tree_unflatten( pytree.tree_unflatten(
cast(list[object], op_info.local_args), cast(list[object], op_info.local_args),
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
op_info.args_tree_spec, op_info.args_tree_spec,
) )
if op_info.args_tree_spec if op_info.args_tree_spec
@ -366,7 +369,7 @@ class OpDispatcher:
resharded_local_tensor = redistribute_local_tensor( resharded_local_tensor = redistribute_local_tensor(
local_tensor, local_tensor,
arg_spec, arg_spec,
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
reshard_arg_spec, reshard_arg_spec,
) )
new_local_args.append(resharded_local_tensor) new_local_args.append(resharded_local_tensor)
@ -439,7 +442,7 @@ class OpDispatcher:
kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor( kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
op_call, op_call,
v, v,
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
compute_mesh, compute_mesh,
) )
local_kwargs[k] = v local_kwargs[k] = v
@ -456,7 +459,7 @@ class OpDispatcher:
OpSchema( OpSchema(
op_call, op_call,
( (
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
pytree.tree_unflatten(args_schema, args_spec) pytree.tree_unflatten(args_schema, args_spec)
if args_spec if args_spec
else tuple(args_schema) else tuple(args_schema)
@ -478,6 +481,7 @@ class OpDispatcher:
assert isinstance(spec, DTensorSpec), ( assert isinstance(spec, DTensorSpec), (
f"output spec does not match with output! Expected DTensorSpec, got {spec}." f"output spec does not match with output! Expected DTensorSpec, got {spec}."
) )
# pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword]
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad) return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
else: else:
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor # if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor


@ -883,9 +883,12 @@ class Redistribute(torch.autograd.Function):
output = local_tensor output = local_tensor
target_spec = current_spec target_spec = current_spec
# pyrefly: ignore [bad-argument-type]
return dtensor.DTensor( return dtensor.DTensor(
# pyrefly: ignore [bad-argument-count]
output, output,
target_spec, target_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=input.requires_grad, requires_grad=input.requires_grad,
) )
@ -944,9 +947,12 @@ class Redistribute(torch.autograd.Function):
dtype=output.dtype, dtype=output.dtype,
), ),
) )
# pyrefly: ignore [bad-argument-type]
output_dtensor = dtensor.DTensor( output_dtensor = dtensor.DTensor(
# pyrefly: ignore [bad-argument-count]
output, output,
spec, spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=grad_output.requires_grad, requires_grad=grad_output.requires_grad,
) )


@ -174,9 +174,12 @@ def _log_softmax_handler(
tensor_meta=output_tensor_meta, tensor_meta=output_tensor_meta,
) )
# pyrefly: ignore [bad-argument-type]
return DTensor( return DTensor(
# pyrefly: ignore [bad-argument-count]
res, res,
res_spec, res_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=res.requires_grad, requires_grad=res.requires_grad,
) )
@ -251,7 +254,7 @@ def _nll_loss_forward(
if weight is not None: if weight is not None:
new_shape = list(x.shape) new_shape = list(x.shape)
new_shape[channel_dim] = -1 new_shape[channel_dim] = -1
# pyrefly: ignore # unbound-name # pyrefly: ignore [unbound-name]
w = w.expand(new_shape) w = w.expand(new_shape)
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim) wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
wsum = torch.where(target != ignore_index, wsum, 0) wsum = torch.where(target != ignore_index, wsum, 0)
@ -309,9 +312,9 @@ def _nll_loss_forward_handler(
output_placements = all_replicate_placements output_placements = all_replicate_placements
# tensor inputs to _propagate_tensor_meta need to be DTensors # tensor inputs to _propagate_tensor_meta need to be DTensors
# pyrefly: ignore # bad-assignment # pyrefly: ignore [bad-assignment]
args = list(args) args = list(args)
# pyrefly: ignore # unsupported-operation # pyrefly: ignore [unsupported-operation]
args[1], args[2] = target, weight args[1], args[2] = target, weight
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
@ -330,9 +333,12 @@ def _nll_loss_forward_handler(
out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta) out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta)
return ( return (
# pyrefly: ignore [bad-argument-type]
DTensor( DTensor(
# pyrefly: ignore [bad-argument-count]
result, result,
out_spec, out_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=result.requires_grad, requires_grad=result.requires_grad,
), ),
total_weight, total_weight,
@ -442,11 +448,11 @@ def _nll_loss_backward_handler(
weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh) weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
# tensor inputs to _propagate_tensor_meta need to be DTensors # tensor inputs to _propagate_tensor_meta need to be DTensors
# pyrefly: ignore # bad-assignment # pyrefly: ignore [bad-assignment]
args = list(args) args = list(args)
# pyrefly: ignore # unsupported-operation # pyrefly: ignore [unsupported-operation]
args[2], args[3] = target, weight args[2], args[3] = target, weight
# pyrefly: ignore # unsupported-operation # pyrefly: ignore [unsupported-operation]
args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh) args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs) output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
@ -470,9 +476,12 @@ def _nll_loss_backward_handler(
tensor_meta=output_tensor_meta, tensor_meta=output_tensor_meta,
) )
# pyrefly: ignore [bad-argument-type]
return DTensor( return DTensor(
# pyrefly: ignore [bad-argument-count]
result, result,
out_spec, out_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=result.requires_grad, requires_grad=result.requires_grad,
) )


@ -949,7 +949,7 @@ def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
for key, value in pytree.tree_map(arg_dump, node.kwargs).items() for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
] ]
target = node.target if node.op in ("call_function", "get_attr") else "" target = node.target if node.op in ("call_function", "get_attr") else ""
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})") ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
nodes_idx[id(node)] = i nodes_idx[id(node)] = i
return "\n".join(ret) return "\n".join(ret)
@ -1206,6 +1206,7 @@ class _ModuleFrame:
for k in kwargs_spec.context for k in kwargs_spec.context
} }
assert self.parent_call_module is not None assert self.parent_call_module is not None
# pyrefly: ignore [bad-assignment]
self.parent_call_module.args = tuple(arg_nodes) self.parent_call_module.args = tuple(arg_nodes)
self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment] self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment]
@ -1393,6 +1394,7 @@ class _ModuleFrame:
def print(self, *args, **kwargs): def print(self, *args, **kwargs):
if self.verbose: if self.verbose:
# pyrefly: ignore [not-iterable]
print(*args, **kwargs) print(*args, **kwargs)
def run_from(self, node_idx): def run_from(self, node_idx):
@ -1486,7 +1488,7 @@ class _ModuleFrame:
self.seen_attrs[self.child_fqn].add(node.target) self.seen_attrs[self.child_fqn].add(node.target)
self.copy_node(node) self.copy_node(node)
# pyrefly: ignore # unsupported-operation # pyrefly: ignore [unsupported-operation]
node_idx += 1 node_idx += 1


@ -1952,7 +1952,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
) )
if isinstance(shape, (int, torch.SymInt)): if isinstance(shape, (int, torch.SymInt)):
shape = torch.Size([shape]) # pyrefly: ignore # bad-argument-type shape = torch.Size([shape]) # pyrefly: ignore [bad-argument-type]
else: else:
for dim in shape: for dim in shape:
torch._check_type( torch._check_type(


@ -87,6 +87,7 @@ def _reify_object_slots(o, s):
@dispatch(slice, dict) @dispatch(slice, dict)
def _reify(o, s): def _reify(o, s):
"""Reify a Python ``slice`` object""" """Reify a Python ``slice`` object"""
# pyrefly: ignore [not-iterable]
return slice(*reify((o.start, o.stop, o.step), s)) return slice(*reify((o.start, o.stop, o.step), s))


@ -59,7 +59,7 @@ Argument = Optional[
BaseArgumentTypes, BaseArgumentTypes,
] ]
] ]
# pyrefly: ignore # invalid-annotation # pyrefly: ignore [invalid-annotation]
ArgumentT = TypeVar("ArgumentT", bound=Argument) ArgumentT = TypeVar("ArgumentT", bound=Argument)
_P = ParamSpec("_P") _P = ParamSpec("_P")
_R = TypeVar("_R") _R = TypeVar("_R")
@ -385,7 +385,7 @@ class Node(_NodeBase):
Args: Args:
x (Node): The node to put before this node. Must be a member of the same graph. x (Node): The node to put before this node. Must be a member of the same graph.
""" """
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
self._prepend(x) self._prepend(x)
@compatibility(is_backward_compatible=True) @compatibility(is_backward_compatible=True)
@ -397,7 +397,7 @@ class Node(_NodeBase):
Args: Args:
x (Node): The node to put after this node. Must be a member of the same graph. x (Node): The node to put after this node. Must be a member of the same graph.
""" """
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
self._next._prepend(x) self._next._prepend(x)
@property @property
@ -698,7 +698,8 @@ class Node(_NodeBase):
if replace_hooks: if replace_hooks:
for replace_hook in replace_hooks: for replace_hook in replace_hooks:
replace_hook(old=self, new=replace_with.name, user=use_node) replace_hook(old=self, new=replace_with.name, user=use_node)
use_node._replace_input_with(self, replace_with) # pyrefly: ignore [missing-attribute]
use_node._replace_input_with(self, replace_with) # type: ignore[attr-defined]
return result return result
@compatibility(is_backward_compatible=False) @compatibility(is_backward_compatible=False)
@ -835,7 +836,8 @@ class Node(_NodeBase):
for replace_hook in m._replace_hooks: for replace_hook in m._replace_hooks:
replace_hook(old=old_input, new=new_input.name, user=self) replace_hook(old=old_input, new=new_input.name, user=self)
self._replace_input_with(old_input, new_input) # pyrefly: ignore [missing-attribute]
self._replace_input_with(old_input, new_input) # type: ignore[attr-defined]
def _rename(self, candidate: str) -> None: def _rename(self, candidate: str) -> None:
if candidate == self.name: if candidate == self.name:


@ -303,7 +303,7 @@ class StreamContext:
self.idx = _get_device_index(None, True) self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting(): if not torch.jit.is_scripting():
if self.idx is None: if self.idx is None:
self.idx = -1 # pyrefly: ignore # bad-assignment self.idx = -1 # pyrefly: ignore [bad-assignment]
self.src_prev_stream = ( self.src_prev_stream = (
None if not torch.jit.is_scripting() else torch.mtia.default_stream(None) None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)


@ -46,7 +46,7 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
if canonicalize: if canonicalize:
dim = canonicalize_dims(ndim, dim) dim = canonicalize_dims(ndim, dim)
assert dim >= 0 and dim < ndim # pyrefly: ignore # unsupported-operation assert dim >= 0 and dim < ndim # pyrefly: ignore [unsupported-operation]
# Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1. # Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
# For other dims, subtract 1 to convert to inner space. # For other dims, subtract 1 to convert to inner space.


@ -72,7 +72,7 @@ class _NormBase(Module):
torch.tensor( torch.tensor(
0, 0,
dtype=torch.long, dtype=torch.long,
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
**{k: v for k, v in factory_kwargs.items() if k != "dtype"}, **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
), ),
) )
@ -222,7 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
dtype=None, dtype=None,
) -> None: ) -> None:
factory_kwargs = {"device": device, "dtype": dtype} factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
super().__init__( super().__init__(
# affine and track_running_stats are hardcoded to False to # affine and track_running_stats are hardcoded to False to
# avoid creating tensors that will soon be overwritten. # avoid creating tensors that will soon be overwritten.
@ -236,29 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
self.affine = affine self.affine = affine
self.track_running_stats = track_running_stats self.track_running_stats = track_running_stats
if self.affine: if self.affine:
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs) self.weight = UninitializedParameter(**factory_kwargs)
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs) self.bias = UninitializedParameter(**factory_kwargs)
if self.track_running_stats: if self.track_running_stats:
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
self.running_mean = UninitializedBuffer(**factory_kwargs) self.running_mean = UninitializedBuffer(**factory_kwargs)
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
self.running_var = UninitializedBuffer(**factory_kwargs) self.running_var = UninitializedBuffer(**factory_kwargs)
self.num_batches_tracked = torch.tensor( self.num_batches_tracked = torch.tensor(
0, 0,
dtype=torch.long, dtype=torch.long,
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
**{k: v for k, v in factory_kwargs.items() if k != "dtype"}, **{k: v for k, v in factory_kwargs.items() if k != "dtype"},
) )
def reset_parameters(self) -> None: def reset_parameters(self) -> None:
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
if not self.has_uninitialized_params() and self.num_features != 0: if not self.has_uninitialized_params() and self.num_features != 0:
super().reset_parameters() super().reset_parameters()
def initialize_parameters(self, input) -> None: # type: ignore[override] def initialize_parameters(self, input) -> None: # type: ignore[override]
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
if self.has_uninitialized_params(): if self.has_uninitialized_params():
self.num_features = input.shape[1] self.num_features = input.shape[1]
if self.affine: if self.affine:
@ -352,6 +352,7 @@ class BatchNorm1d(_BatchNorm):
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)") raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
# pyrefly: ignore [inconsistent-inheritance]
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm): class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization. r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.
@ -463,6 +464,7 @@ class BatchNorm2d(_BatchNorm):
raise ValueError(f"expected 4D input (got {input.dim()}D input)") raise ValueError(f"expected 4D input (got {input.dim()}D input)")
# pyrefly: ignore [inconsistent-inheritance]
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm): class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization. r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.
@ -574,6 +576,7 @@ class BatchNorm3d(_BatchNorm):
raise ValueError(f"expected 5D input (got {input.dim()}D input)") raise ValueError(f"expected 5D input (got {input.dim()}D input)")
# pyrefly: ignore [inconsistent-inheritance]
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm): class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization. r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.


@ -38,13 +38,13 @@ T = TypeVar("T", bound="Module")
class _IncompatibleKeys( class _IncompatibleKeys(
# pyrefly: ignore # invalid-inheritance # pyrefly: ignore [invalid-inheritance]
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]), namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
): ):
__slots__ = () __slots__ = ()
def __repr__(self) -> str: def __repr__(self) -> str:
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
if not self.missing_keys and not self.unexpected_keys: if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>" return "<All keys matched successfully>"
return super().__repr__() return super().__repr__()
@ -93,7 +93,7 @@ class _WrappedHook:
def __getstate__(self) -> dict: def __getstate__(self) -> dict:
result = {"hook": self.hook, "with_module": self.with_module} result = {"hook": self.hook, "with_module": self.with_module}
if self.with_module: if self.with_module:
# pyrefly: ignore # unsupported-operation # pyrefly: ignore [unsupported-operation]
result["module"] = self.module() result["module"] = self.module()
return result return result
@ -979,7 +979,7 @@ class Module:
# Decrement use count of the gradient by setting to None # Decrement use count of the gradient by setting to None
param.grad = None param.grad = None
param_applied = torch.nn.Parameter( param_applied = torch.nn.Parameter(
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
param_applied, param_applied,
requires_grad=param.requires_grad, requires_grad=param.requires_grad,
) )
@ -992,13 +992,13 @@ class Module:
) from e ) from e
out_param = param out_param = param
elif p_should_use_set_data: elif p_should_use_set_data:
# pyrefly: ignore # bad-assignment # pyrefly: ignore [bad-assignment]
param.data = param_applied param.data = param_applied
out_param = param out_param = param
else: else:
assert isinstance(param, Parameter) assert isinstance(param, Parameter)
assert param.is_leaf assert param.is_leaf
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
out_param = Parameter(param_applied, param.requires_grad) out_param = Parameter(param_applied, param.requires_grad)
self._parameters[key] = out_param self._parameters[key] = out_param
@ -1337,7 +1337,9 @@ class Module:
""" """
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to( device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs # pyrefly: ignore [not-iterable]
*args,
**kwargs,
) )
if dtype is not None: if dtype is not None:
@ -2256,7 +2258,7 @@ class Module:
if destination is None: if destination is None:
destination = OrderedDict() destination = OrderedDict()
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
destination._metadata = OrderedDict() destination._metadata = OrderedDict()
local_metadata = dict(version=self._version) local_metadata = dict(version=self._version)
@ -2407,7 +2409,7 @@ class Module:
} }
local_name_params = itertools.chain( local_name_params = itertools.chain(
self._parameters.items(), self._parameters.items(),
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
persistent_buffers.items(), persistent_buffers.items(),
) )
local_state = {k: v for k, v in local_name_params if v is not None} local_state = {k: v for k, v in local_name_params if v is not None}


@ -27,6 +27,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
"""Prints messages based on `verbose`.""" """Prints messages based on `verbose`."""
if verbose is False: if verbose is False:
return lambda *_, **__: None return lambda *_, **__: None
# pyrefly: ignore [not-iterable]
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)
@ -47,7 +48,7 @@ def _patch_dynamo_unsupported_functions():
# Replace torch.jit.isinstance with isinstance # Replace torch.jit.isinstance with isinstance
jit_isinstance = torch.jit.isinstance jit_isinstance = torch.jit.isinstance
# pyrefly: ignore # bad-assignment # pyrefly: ignore [bad-assignment]
torch.jit.isinstance = isinstance torch.jit.isinstance = isinstance
logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing") logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing")
try: try:


@ -132,10 +132,10 @@ class TorchTensor(ir.Tensor):
# view the tensor as that dtype so that it is convertible to NumPy, # view the tensor as that dtype so that it is convertible to NumPy,
# and then view it back to the proper dtype (using ml_dtypes obtained by # and then view it back to the proper dtype (using ml_dtypes obtained by
# calling dtype.numpy()). # calling dtype.numpy()).
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
if self.dtype == ir.DataType.BFLOAT16: if self.dtype == ir.DataType.BFLOAT16:
return ( return (
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy()) self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
) )
if self.dtype in { if self.dtype in {
@ -144,11 +144,11 @@ class TorchTensor(ir.Tensor):
ir.DataType.FLOAT8E5M2, ir.DataType.FLOAT8E5M2,
ir.DataType.FLOAT8E5M2FNUZ, ir.DataType.FLOAT8E5M2FNUZ,
}: }:
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy()) return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
if self.dtype == ir.DataType.FLOAT4E2M1: if self.dtype == ir.DataType.FLOAT4E2M1:
return _type_casting.unpack_float4x2_as_uint8(self.raw).view( return _type_casting.unpack_float4x2_as_uint8(self.raw).view(
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
self.dtype.numpy() self.dtype.numpy()
) )
@ -170,7 +170,7 @@ class TorchTensor(ir.Tensor):
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor): if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
raise TypeError( raise TypeError(
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor " f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
"with a tensor backed by real data using ONNXProgram.apply_weights() " "with a tensor backed by real data using ONNXProgram.apply_weights() "
"or save the model without initializers by setting include_initializers=False." "or save the model without initializers by setting include_initializers=False."
@ -251,7 +251,7 @@ def _set_shape_type(
if isinstance(dim, int): if isinstance(dim, int):
dims.append(dim) dims.append(dim)
else: else:
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
dims.append(str(dim.node)) dims.append(str(dim.node))
# If the dtype is set already (e.g. by the onnx_symbolic ops), # If the dtype is set already (e.g. by the onnx_symbolic ops),
@ -1232,7 +1232,7 @@ def _exported_program_to_onnx_program(
# so we need to get them from the name_* apis. # so we need to get them from the name_* apis.
for name, torch_tensor in itertools.chain( for name, torch_tensor in itertools.chain(
exported_program.named_parameters(), exported_program.named_parameters(),
# pyrefly: ignore # bad-argument-type # pyrefly: ignore [bad-argument-type]
exported_program.named_buffers(), exported_program.named_buffers(),
exported_program.constants.items(), exported_program.constants.items(),
): ):
@ -1265,6 +1265,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
"""Prints messages based on `verbose`.""" """Prints messages based on `verbose`."""
if verbose is False: if verbose is False:
return lambda *_, **__: None return lambda *_, **__: None
# pyrefly: ignore [not-iterable]
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs) return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)


@ -239,7 +239,7 @@ def _compare_onnx_pytorch_outputs_in_np(
if acceptable_error_percentage: if acceptable_error_percentage:
error_percentage = 1 - np.sum( error_percentage = 1 - np.sum(
np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol) np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol)
) / np.prod(ort_out.shape) # pyrefly: ignore # missing-attribute ) / np.prod(ort_out.shape) # pyrefly: ignore [missing-attribute]
if error_percentage <= acceptable_error_percentage: if error_percentage <= acceptable_error_percentage:
warnings.warn( warnings.warn(
f"Suppressed AssertionError:\n{e}.\n" f"Suppressed AssertionError:\n{e}.\n"


@ -13,6 +13,7 @@ from torch import optim
def partialclass(cls, *args, **kwargs): # noqa: D103 def partialclass(cls, *args, **kwargs): # noqa: D103
class NewCls(cls): class NewCls(cls):
# pyrefly: ignore [not-iterable]
__init__ = partialmethod(cls.__init__, *args, **kwargs) __init__ = partialmethod(cls.__init__, *args, **kwargs)
return NewCls return NewCls


@ -326,7 +326,7 @@ def gaussian(
requires_grad=requires_grad, requires_grad=requires_grad,
) )
return torch.exp(-(k**2)) # pyrefly: ignore # unsupported-operation return torch.exp(-(k**2)) # pyrefly: ignore [unsupported-operation]
@_add_docstr( @_add_docstr(


@ -618,7 +618,7 @@ def _get_storage_from_sequence(sequence, dtype, device):
def _isint(x): def _isint(x):
if HAS_NUMPY: if HAS_NUMPY:
return isinstance(x, (int, np.integer)) # pyrefly: ignore # missing-attribute return isinstance(x, (int, np.integer)) # pyrefly: ignore [missing-attribute]
else: else:
return isinstance(x, int) return isinstance(x, int)


@ -77,6 +77,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
def pop(self) -> T: def pop(self) -> T:
if not self: if not self:
raise KeyError("pop from an empty set") raise KeyError("pop from an empty set")
# pyrefly: ignore [bad-return]
return self._dict.popitem()[0] return self._dict.popitem()[0]
def copy(self) -> OrderedSet[T]: def copy(self) -> OrderedSet[T]:
@ -158,7 +159,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]: def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
# MutableSet impl will iterate over other, iter over smaller of two sets # MutableSet impl will iterate over other, iter over smaller of two sets
if isinstance(other, OrderedSet) and len(self) < len(other): if isinstance(other, OrderedSet) and len(self) < len(other):
# pyrefly: ignore # unsupported-operation, bad-return # pyrefly: ignore [unsupported-operation, bad-return]
return other & self return other & self
return cast(OrderedSet[T], super().__and__(other)) return cast(OrderedSet[T], super().__and__(other))


@ -202,7 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
Args: Args:
device (int, optional): if specified, all parameters will be copied to that device device (int, optional): if specified, all parameters will be copied to that device
""" """
# pyrefly: ignore # missing-attribute # pyrefly: ignore [missing-attribute]
return self._apply(lambda t: getattr(t, custom_backend_name)(device)) return self._apply(lambda t: getattr(t, custom_backend_name)(device))
_check_register_once(torch.nn.Module, custom_backend_name) _check_register_once(torch.nn.Module, custom_backend_name)
@ -252,11 +252,15 @@ def _generate_packed_sequence_methods_for_privateuse1_backend(
device (int, optional): if specified, all parameters will be copied to that device device (int, optional): if specified, all parameters will be copied to that device
""" """
ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to( ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
*args, **kwargs # pyrefly: ignore [not-iterable]
*args,
**kwargs,
) )
if ex.device.type == custom_backend_name: if ex.device.type == custom_backend_name:
# pyrefly: ignore [not-iterable]
return self.to(*args, **kwargs) return self.to(*args, **kwargs)
kwargs.update({"device": custom_backend_name}) kwargs.update({"device": custom_backend_name})
# pyrefly: ignore [not-iterable]
return self.to(*args, **kwargs) return self.to(*args, **kwargs)
_check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name) _check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)


@ -48,7 +48,7 @@ MATH_TRANSPILATIONS = collections.OrderedDict(
] ]
) )
# pyrefly: ignore # no-matching-overload # pyrefly: ignore [no-matching-overload]
CUDA_TYPE_NAME_MAP = collections.OrderedDict( CUDA_TYPE_NAME_MAP = collections.OrderedDict(
[ [
("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)), ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)),
@ -587,6 +587,7 @@ CUDA_TYPE_NAME_MAP = collections.OrderedDict(
] ]
) )
# pyrefly: ignore [no-matching-overload]
CUDA_INCLUDE_MAP = collections.OrderedDict( CUDA_INCLUDE_MAP = collections.OrderedDict(
[ [
# since pytorch uses "\b{pattern}\b" as the actual re pattern, # since pytorch uses "\b{pattern}\b" as the actual re pattern,
@ -676,7 +677,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
] ]
) )
# pyrefly: ignore # no-matching-overload # pyrefly: ignore [no-matching-overload]
CUDA_IDENTIFIER_MAP = collections.OrderedDict( CUDA_IDENTIFIER_MAP = collections.OrderedDict(
[ [
("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)), ("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)),
@ -8370,6 +8371,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
] ]
) )
# pyrefly: ignore [no-matching-overload]
CUDA_SPECIAL_MAP = collections.OrderedDict( CUDA_SPECIAL_MAP = collections.OrderedDict(
[ [
# SPARSE # SPARSE
@ -8852,6 +8854,7 @@ CUDA_SPECIAL_MAP = collections.OrderedDict(
] ]
) )
# pyrefly: ignore [no-matching-overload]
PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict( PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
[ [
("USE_CUDA", ("USE_ROCM", API_PYTORCH)), ("USE_CUDA", ("USE_ROCM", API_PYTORCH)),
@ -9316,6 +9319,7 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
] ]
) )
# pyrefly: ignore [no-matching-overload]
CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict( CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
[ [
("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", API_CAFFE2)), ("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", API_CAFFE2)),
@ -9401,6 +9405,7 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
# #
# NB: if you want a transformation to ONLY apply to the c10/ directory, # NB: if you want a transformation to ONLY apply to the c10/ directory,
# put it as API_CAFFE2 # put it as API_CAFFE2
# pyrefly: ignore [no-matching-overload]
C10_MAPPINGS = collections.OrderedDict( C10_MAPPINGS = collections.OrderedDict(
[ [
("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)), ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)),


@ -120,6 +120,7 @@ class GeneratedFileCleaner:
def open(self, fn, *args, **kwargs): def open(self, fn, *args, **kwargs):
if not os.path.exists(fn): if not os.path.exists(fn):
self.files_to_clean.add(os.path.abspath(fn)) self.files_to_clean.add(os.path.abspath(fn))
# pyrefly: ignore [not-iterable]
return open(fn, *args, **kwargs) return open(fn, *args, **kwargs)
def makedirs(self, dn, exist_ok=False): def makedirs(self, dn, exist_ok=False):
@ -669,7 +670,7 @@ def is_caffe2_gpu_file(rel_filepath):
return True return True
filename = os.path.basename(rel_filepath) filename = os.path.basename(rel_filepath)
_, ext = os.path.splitext(filename) _, ext = os.path.splitext(filename)
# pyrefly: ignore # unsupported-operation # pyrefly: ignore [unsupported-operation]
return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename) return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
class TrieNode: class TrieNode:
@ -1145,7 +1146,7 @@ def hipify(
out_of_place_only=out_of_place_only, out_of_place_only=out_of_place_only,
is_pytorch_extension=is_pytorch_extension)) is_pytorch_extension=is_pytorch_extension))
all_files_set = set(all_files) all_files_set = set(all_files)
# pyrefly: ignore # bad-assignment # pyrefly: ignore [bad-assignment]
for f in extra_files: for f in extra_files:
if not os.path.isabs(f): if not os.path.isabs(f):
f = os.path.join(output_directory, f) f = os.path.join(output_directory, f)


@ -292,9 +292,10 @@ class WeakIdKeyDictionary(MutableMapping):
if o is not None: if o is not None:
return o, value return o, value
# pyrefly: ignore # bad-override # pyrefly: ignore [bad-override]
def pop(self, key, *args): def pop(self, key, *args):
self._dirty_len = True self._dirty_len = True
# pyrefly: ignore [not-iterable]
return self.data.pop(self.ref_type(key), *args) # CHANGED return self.data.pop(self.ref_type(key), *args) # CHANGED
def setdefault(self, key, default=None): def setdefault(self, key, default=None):


@ -328,7 +328,7 @@ class StreamContext:
self.stream = stream self.stream = stream
self.idx = _get_device_index(None, True) self.idx = _get_device_index(None, True)
if self.idx is None: if self.idx is None:
self.idx = -1 # pyrefly: ignore # bad-assignment self.idx = -1 # pyrefly: ignore [bad-assignment]
def __enter__(self): def __enter__(self):
cur_stream = self.stream cur_stream = self.stream


@ -126,7 +126,7 @@ class Event(torch._C._XpuEventBase):
""" """
if stream is None: if stream is None:
stream = torch.xpu.current_stream() stream = torch.xpu.current_stream()
super().record(stream) # pyrefly: ignore # bad-argument-type super().record(stream) # pyrefly: ignore [bad-argument-type]
def wait(self, stream=None) -> None: def wait(self, stream=None) -> None:
r"""Make all future work submitted to the given stream wait for this event. r"""Make all future work submitted to the given stream wait for this event.