Fix pyrefly ignores 1/n (#166239)

First diff adjusting the syntax for pyrefly: ignore suppressions so they only hide one class of type error.

Test:
lintrunner
pyrefly check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166239
Approved by: https://github.com/oulgen
This commit is contained in:
Maggie Moss 2025-10-26 00:44:07 +00:00 committed by PyTorch MergeBot
parent 621ba05107
commit c7eee49525
55 changed files with 282 additions and 184 deletions

View File

@ -130,6 +130,7 @@ errors.bad-param-name-override = false
# Mypy doesn't require that imports are explicitly imported, so be compatible with that.
# Might be a good idea to turn this on in future.
errors.implicit-import = false
errors.deprecated = false # re-enable after we've fix import formatting
permissive-ignores = true
replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
search-path = ["tools/experimental"]

View File

@ -2,7 +2,8 @@ from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
from dataclasses_json import DataClassJsonMixin
# pyrefly: ignore [missing-import]
from dataclasses_json import DataClassJsonMixin # type: ignore[import-not-found]
_DATA_MODEL_VERSION = 1.5
@ -17,7 +18,7 @@ class UtilizationStats:
@dataclass
class UtilizationMetadata(DataClassJsonMixin):
class UtilizationMetadata(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
level: str
workflow_id: str
job_id: str
@ -33,7 +34,7 @@ class UtilizationMetadata(DataClassJsonMixin):
@dataclass
class GpuUsage(DataClassJsonMixin):
class GpuUsage(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
uuid: Optional[str] = None
util_percent: Optional[UtilizationStats] = None
mem_util_percent: Optional[UtilizationStats] = None
@ -43,14 +44,14 @@ class GpuUsage(DataClassJsonMixin):
@dataclass
class RecordData(DataClassJsonMixin):
class RecordData(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
cpu: Optional[UtilizationStats] = None
memory: Optional[UtilizationStats] = None
gpu_usage: Optional[list[GpuUsage]] = None
@dataclass
class UtilizationRecord(DataClassJsonMixin):
class UtilizationRecord(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
level: str
timestamp: int
data: Optional[RecordData] = None
@ -63,7 +64,7 @@ class UtilizationRecord(DataClassJsonMixin):
# the db schema related to this is:
# https://github.com/pytorch/test-infra/blob/main/clickhouse_db_schema/oss_ci_utilization/oss_ci_utilization_metadata_schema.sql
@dataclass
class OssCiSegmentV1(DataClassJsonMixin):
class OssCiSegmentV1(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
level: str
name: str
start_at: int

View File

@ -1703,7 +1703,7 @@ def _check(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(RuntimeError, cond, message) # pyrefly: ignore # bad-argument-type
_check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]
# TODO add deprecation annotation
@ -1753,7 +1753,7 @@ def _check_index(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(IndexError, cond, message) # pyrefly: ignore # bad-argument-type
_check_with(IndexError, cond, message) # pyrefly: ignore [bad-argument-type]
def _check_value(cond, message=None): # noqa: F811
@ -1771,7 +1771,7 @@ def _check_value(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(ValueError, cond, message) # pyrefly: ignore # bad-argument-type
_check_with(ValueError, cond, message) # pyrefly: ignore [bad-argument-type]
def _check_type(cond, message=None): # noqa: F811
@ -1789,7 +1789,7 @@ def _check_type(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(TypeError, cond, message) # pyrefly: ignore # bad-argument-type
_check_with(TypeError, cond, message) # pyrefly: ignore [bad-argument-type]
def _check_not_implemented(cond, message=None): # noqa: F811

View File

@ -101,7 +101,7 @@ def custom_op(
lib, ns, function_schema, name, ophandle, _private_access=True
)
result.__name__ = func.__name__ # pyrefly: ignore # bad-assignment
result.__name__ = func.__name__ # pyrefly: ignore [bad-assignment]
result.__module__ = func.__module__
result.__doc__ = func.__doc__

View File

@ -154,7 +154,7 @@ def make_crossref_functionalize(
maybe_detach, (f_args, f_kwargs)
)
with fake_mode:
f_r = op(*f_args, **f_kwargs) # pyrefly: ignore # invalid-param-spec
f_r = op(*f_args, **f_kwargs) # pyrefly: ignore [invalid-param-spec]
r = op._op_dk(final_key, *args, **kwargs)
def desc():

View File

@ -1029,7 +1029,7 @@ class BuiltinVariable(VariableTracker):
def call_self_handler(tx: "InstructionTranslator", args, kwargs):
try:
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
result = self_handler(tx, *args, **kwargs)
if result is not None:
return result
@ -1037,7 +1037,7 @@ class BuiltinVariable(VariableTracker):
# Check if binding is bad. inspect signature bind is expensive.
# So check only when handler call fails.
try:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
inspect.signature(self_handler).bind(tx, *args, **kwargs)
except TypeError as e:
has_constant_handler = obj.has_constant_handler(args, kwargs)
@ -1090,7 +1090,7 @@ class BuiltinVariable(VariableTracker):
hints=[*graph_break_hints.DYNAMO_BUG],
from_exc=exc,
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return VariableTracker.build(tx, res)
else:
@ -1119,7 +1119,7 @@ class BuiltinVariable(VariableTracker):
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return VariableTracker.build(tx, res)
handlers.append(constant_fold_handler)
@ -1442,7 +1442,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name)
if resolved_fn in dict_methods:
if isinstance(args[0], variables.UserDefinedDictVariable):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.ConstDictVariable):
return args[0].call_method(tx, name, args[1:], kwargs)
@ -1451,7 +1451,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name)
if resolved_fn in set_methods:
if isinstance(args[0], variables.UserDefinedSetVariable):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return args[0]._set_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.SetVariable):
return args[0].call_method(tx, name, args[1:], kwargs)
@ -1540,12 +1540,12 @@ class BuiltinVariable(VariableTracker):
if type(arg.value).__str__ is object.__str__:
# Rely on the object str method
try:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return variables.ConstantVariable.create(value=str_method())
except AttributeError:
# Graph break
return
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
elif is_wrapper_or_member_descriptor(str_method):
unimplemented_v2(
gb_type="Attempted to a str() method implemented in C/C++",
@ -1662,10 +1662,10 @@ class BuiltinVariable(VariableTracker):
else:
raw_b = b.raw_value
if self.fn is max:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
raw_res = max(a.raw_value, raw_b)
else:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
raw_res = min(a.raw_value, raw_b)
need_unwrap = any(
@ -1980,12 +1980,16 @@ class BuiltinVariable(VariableTracker):
if isinstance(arg, dict):
arg = [ConstantVariable.create(k) for k in arg.keys()]
return DictVariableType(
dict.fromkeys(arg, value), user_cls, mutation_type=ValueMutationNew()
# pyrefly: ignore [bad-argument-type]
dict.fromkeys(arg, value),
user_cls,
mutation_type=ValueMutationNew(),
)
elif arg.has_force_unpack_var_sequence(tx):
keys = arg.force_unpack_var_sequence(tx)
if all(is_hashable(v) for v in keys):
return DictVariableType(
# pyrefly: ignore [bad-argument-type]
dict.fromkeys(keys, value),
user_cls,
mutation_type=ValueMutationNew(),
@ -2152,7 +2156,7 @@ class BuiltinVariable(VariableTracker):
)
if isinstance(arg, variables.UserDefinedExceptionClassVariable):
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return ConstantVariable.create(isinstance(arg_type, isinstance_type))
isinstance_type_tuple: tuple[type, ...]
@ -2185,10 +2189,10 @@ class BuiltinVariable(VariableTracker):
# through it. This is a limitation of the current implementation.
# Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it
# might not be a big issue and we trade off it for performance.
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
val = issubclass(arg_type, isinstance_type_tuple)
except TypeError:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
val = arg_type in isinstance_type_tuple
return variables.ConstantVariable.create(val)
@ -2210,7 +2214,7 @@ class BuiltinVariable(VariableTracker):
# WARNING: This might run arbitrary user code `__subclasscheck__`.
# See the comment in call_isinstance above.
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))
def call_super(self, tx: "InstructionTranslator", a, b):
@ -2256,9 +2260,9 @@ class BuiltinVariable(VariableTracker):
value = getattr(self.fn, name)
except AttributeError:
raise_observed_exception(AttributeError, tx)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if not callable(value):
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return VariableTracker.build(tx, value, source)
return variables.GetAttrVariable(self, name, source=source)

View File

@ -34,6 +34,7 @@ class LazyCache:
self.vt = builder.VariableBuilder(tx, self.source)(self.value)
if self.name_hint is not None:
# pyrefly: ignore [missing-attribute]
self.vt.set_name_hint(self.name_hint)
del self.value

View File

@ -1138,11 +1138,13 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
if isinstance(m, FakeTensorMode):
# pyrefly: ignore [bad-argument-type]
fake_modes.append((m, "active fake mode", i))
flat_inputs = pytree.tree_leaves(inputs)
for i, flat_input in enumerate(flat_inputs):
if isinstance(flat_input, FakeTensor):
# pyrefly: ignore [bad-argument-type]
fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
if is_traceable_wrapper_subclass(flat_input):
out: list[Union[torch.Tensor, int, torch.SymInt]] = []
@ -1151,6 +1153,7 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
x for x in out if isinstance(x, FakeTensor)
]
fake_modes.extend(
# pyrefly: ignore [bad-argument-type]
[
(tensor.fake_mode, f"subclass input {i}", ix)
for ix, tensor in enumerate(fake_tensors)
@ -1162,9 +1165,12 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
for m, desc2, i2 in fake_modes[1:]:
assert fake_mode is m, (
f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
# pyrefly: ignore [missing-attribute]
f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
# pyrefly: ignore [missing-attribute]
f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
)
# pyrefly: ignore [bad-return]
return fake_mode
else:
return None

View File

@ -114,6 +114,7 @@ def _cancel_all_tasks(
for task in to_cancel:
task.cancel()
# pyrefly: ignore [bad-argument-type]
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))
for task in to_cancel:
@ -149,7 +150,7 @@ def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[
task_factory = task_factories[0]
if task_factory is None:
if sys.version_info >= (3, 11):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
task = asyncio.Task(coro, loop=loop, context=context)
else:
task = asyncio.Task(coro, loop=loop)

View File

@ -590,11 +590,11 @@ class CKGemmTemplate(CKTemplate):
arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
else: # tile shape
arg = f"/* {field_name} */ S<{tuple_elements}>"
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
template_params.append(arg)
else:
if field_value is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
template_params.append(f"/* {field_name} */ {field_value}")
operation_name = op.name().replace("(", "").replace(",", "").replace(")", "")
return self._template_from_string(template_definition).render(
@ -939,6 +939,7 @@ class CKGemmTemplate(CKTemplate):
for o in rops:
kBatches = self._get_kBatch(o)
for kBatch in kBatches:
# pyrefly: ignore [bad-argument-type]
ops.append(InductorROCmOp(op=o, kBatch=kBatch))
filtered_instances = list(filter(lambda op: self.filter_op(op), ops))

View File

@ -273,7 +273,7 @@ def record_original_output_strides(gm: GraphModule) -> None:
):
output_strides.append(val.stride())
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_strides.append(None)
output_node.meta["original_output_strides"] = output_strides
@ -1110,6 +1110,7 @@ def _compile_fx_inner(
)
log.info("-" * 130)
for row in mm_table_data:
# pyrefly: ignore [not-iterable]
log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001
log.info("-" * 130)
@ -1551,7 +1552,7 @@ class _InProcessFxCompile(FxCompile):
node_runtimes = None
if inductor_metrics_log.isEnabledFor(logging.INFO):
num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
metrics.num_bytes_accessed += num_bytes
metrics.node_runtimes += node_runtimes
metrics.nodes_num_elem += nodes_num_elem
@ -1595,10 +1596,10 @@ class _InProcessFxCompile(FxCompile):
disable = f"{disable} Found from {stack_trace}\n"
else:
disable = f"{disable}\n"
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.disable_cudagraphs_reason = disable
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if cudagraphs and not V.graph.disable_cudagraphs_reason:
maybe_incompat_node = get_first_incompatible_cudagraph_node(gm)
if maybe_incompat_node:
@ -1607,29 +1608,29 @@ class _InProcessFxCompile(FxCompile):
"stack_trace", None
):
disable = f"{disable} Found from {stack_trace}\n"
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.disable_cudagraphs_reason = disable
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if V.aot_compilation:
assert isinstance(
compiled_fn,
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(compiled_fn)
# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if cudagraphs and not V.graph.disable_cudagraphs_reason:
from torch._inductor.cudagraph_utils import (
check_lowering_disable_cudagraph,
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.disable_cudagraphs_reason = (
check_lowering_disable_cudagraph(
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.device_node_mapping
)
)
@ -1637,29 +1638,29 @@ class _InProcessFxCompile(FxCompile):
self._compile_stats[type(self)].codegen_and_compile += 1
if (
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
torch._inductor.debug.RECORD_GRAPH_EXECUTION
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
and torch._inductor.debug.GRAPH_COMPILE_IDS is not None
):
compile_id = str(
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
torch._guards.CompileContext.current_compile_id()
)
graph_id = graph_kwargs.get("graph_id")
if graph_id is not None:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
torch._inductor.debug.GRAPH_COMPILE_IDS[graph_id] = (
compile_id
)
return CompiledFxGraph(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
compiled_fn,
graph,
gm,
output_strides,
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.disable_cudagraphs_reason,
metrics_helper.get_deltas(),
counters["inductor"] - inductor_counters,
@ -1701,18 +1702,18 @@ def fx_codegen_and_compile(
from .compile_fx_async import _AsyncFxCompile
from .compile_fx_ext import _OutOfProcessFxCompile
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
assert isinstance(scheme, _OutOfProcessFxCompile), (
"async is only valid with an out-of-process compile mode"
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
scheme = _AsyncFxCompile(scheme)
if fx_compile_progressive:
from .compile_fx_async import _ProgressiveFxCompile
from .compile_fx_ext import _OutOfProcessFxCompile
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
assert isinstance(scheme, _OutOfProcessFxCompile), (
"progressive is only valid with an out-of-process compile mode"
)
@ -1722,10 +1723,10 @@ def fx_codegen_and_compile(
# Use in-process compile for the fast version
fast_scheme = _InProcessFxCompile()
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)
@ -1835,7 +1836,7 @@ def cudagraphify_impl(
Assumes inputs[static_input_idxs[i]] are always the same memory address
"""
check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type]
# pyrefly: ignore # annotation-mismatch
# pyrefly: ignore [annotation-mismatch]
static_input_idxs: OrderedSet[int] = OrderedSet(
remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type]
)
@ -1902,7 +1903,7 @@ def cudagraphify_impl(
index_expanded_dims_and_copy_(dst, src, expanded_dims)
new_inputs.clear()
graph.replay()
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return static_outputs
else:
@ -1918,7 +1919,7 @@ def cudagraphify_impl(
index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims)
new_inputs.clear()
graph.replay()
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return static_outputs
return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet())
@ -1935,7 +1936,7 @@ def compile_fx_aot(
# [See NOTE] Unwrapping subclasses AOT
unwrap_tensor_subclass_parameters(model_)
# pyrefly: ignore # annotation-mismatch
# pyrefly: ignore [annotation-mismatch]
config_patches: dict[str, Any] = copy.deepcopy(config_patches or {})
if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper):
@ -2878,7 +2879,7 @@ def _aoti_flatten_inputs(
Flatten the inputs to the graph module and return the flat inputs and options.
Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options.
"""
# pyrefly: ignore # missing-module-attribute
# pyrefly: ignore [missing-module-attribute]
from .compile_fx import graph_returns_tuple
assert graph_returns_tuple(gm), (

View File

@ -291,7 +291,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
log.debug("example value absent for node: %s", input)
return
ndim = input.meta["example_value"].ndim
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if dim < 0: # Normalize unbind dim
dim += ndim
with graph.inserting_after(node):
@ -341,7 +341,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if cat_dim < 0: # Normalize cat dim
cat_dim += ndim
@ -949,7 +949,7 @@ class SplitCatSimplifier:
if isinstance(user_input, tuple):
# Find the correct new getitem (present in split_items)
new_user_inputs.append(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
split_items[
split_ranges.index(
(
@ -1000,7 +1000,7 @@ class SplitCatSimplifier:
for user_input_new, transform_param in zip(
user_inputs_new, transform_params
):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if not is_node_meta_valid(user_input_new):
log.debug("example value absent for node: %s", user_input_new)
return
@ -1015,7 +1015,7 @@ class SplitCatSimplifier:
stack_dim is None or stack_dim == unsqueeze_params[0]
):
to_stack.append(user_input_new)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
to_stack_meta.append(user_input_new.meta["example_value"])
stack_dim = unsqueeze_params[0]
continue
@ -1036,12 +1036,12 @@ class SplitCatSimplifier:
if unsqueeze_params:
to_stack.append(user_input_new)
stack_dim = unsqueeze_params[0]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
to_stack_meta.append(user_input_new.meta["example_value"])
continue
if unflatten_params:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.unflatten, args=(user_input_new, *unflatten_params)
@ -1051,7 +1051,7 @@ class SplitCatSimplifier:
*unflatten_params, # type: ignore[arg-type]
)
if movedim_params:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.movedim, args=(user_input_new, *movedim_params)
@ -1061,7 +1061,7 @@ class SplitCatSimplifier:
*movedim_params, # type: ignore[arg-type]
)
if flatten_params:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.flatten, args=(user_input_new, *flatten_params)
@ -1072,7 +1072,7 @@ class SplitCatSimplifier:
)
user_inputs_new_transformed.append(user_input_new)
user_inputs_new_transformed_meta.append(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
user_input_new.meta["example_value"]
)
if to_stack:
@ -1432,7 +1432,7 @@ def simplify_split_cat(match: Match, split_sections: list[int], dim: int):
if not isinstance(split_sections, (list, tuple)): # Unnormalized split
return
split_node = next(node for node in match.nodes if node.target == torch.split)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
SplitCatSimplifier().simplify(match.graph, split_node, split_sections)
@ -1501,7 +1501,7 @@ def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) -
for i in range(len(split_node.args[1])): # type: ignore[arg-type]
if i in indices:
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return fused_tensor_size
@ -1978,7 +1978,7 @@ def normalize_cat_default_aten(match: Match, *args, **kwargs):
assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if cat_dim < 0: # Normalize cat dim
cat_dim += ndim
@ -2512,7 +2512,8 @@ def reshape_cat_node_to_stack(
args=(cat_node, tuple(reshape_list)),
)
reshape_node.meta["example_value"] = torch.reshape(
cat_node.meta["example_value"], tuple(reshape_list)
cat_node.meta["example_value"],
tuple(reshape_list), # pyrefly: ignore [bad-argument-type]
)
permute_list = list(range(len(stack_shape)))
permute_list[stack_dim], permute_list[split_or_unbind_dim] = (
@ -3044,6 +3045,6 @@ def replace_einsum_to_pointwise(match: Match, *args, **kwargs):
einsum_node = match.nodes[0]
input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2)
if should_replace_einsum(einsum_node):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
match.replace_by_example(repl, [input, weights])
counters[backend]["einsum_to_pointwise_pass"] += 1

View File

@ -147,7 +147,7 @@ def _qualified_name(obj, mangle_name=True) -> str:
# If the module is actually a torchbind module, then we should short circuit
if module_name == "torch._classes":
return obj.qualified_name # pyrefly: ignore # missing-attribute
return obj.qualified_name # pyrefly: ignore [missing-attribute]
# The Python docs are very clear that `__module__` can be None, but I can't
# figure out when it actually would be.
@ -759,7 +759,7 @@ def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
)
return prop # pyrefly: ignore # bad-return
return prop # pyrefly: ignore [bad-return]
fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
return fn

View File

@ -65,8 +65,8 @@ class AsyncClosureHandler(ClosureHandler):
self._closure_event_loop = threading.Thread(
target=event_loop
) # pyrefly: ignore # bad-assignment
self._closure_event_loop.start() # pyrefly: ignore # missing-attribute
) # pyrefly: ignore [bad-assignment]
self._closure_event_loop.start() # pyrefly: ignore [missing-attribute]
def run(self, closure):
with self._closure_lock:

View File

@ -515,8 +515,11 @@ def meta_copy_(self, src, non_blocking=False):
def inferUnsqueezeGeometry(tensor, dim):
result_sizes = list(tensor.size())
result_strides = list(tensor.stride())
# pyrefly: ignore [unsupported-operation]
new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
# pyrefly: ignore [bad-argument-type]
result_sizes.insert(dim, 1)
# pyrefly: ignore [bad-argument-type]
result_strides.insert(dim, new_stride)
return result_sizes, result_strides
@ -2341,19 +2344,19 @@ def calc_conv_nd_return_shape(
ret_shape = [input_tensor.shape[0], out_channels]
if isinstance(stride, IntLike):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
stride = [stride] * len(dims)
elif len(stride) == 1:
stride = [stride[0]] * len(dims)
if isinstance(padding, IntLike):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
padding = [padding] * len(dims)
elif len(padding) == 1:
padding = [padding[0]] * len(dims)
if isinstance(dilation, IntLike):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
dilation = [dilation] * len(dims)
elif len(dilation) == 1:
dilation = [dilation[0]] * len(dims)
@ -2361,7 +2364,7 @@ def calc_conv_nd_return_shape(
output_padding_list: Optional[list[int]] = None
if output_padding:
if isinstance(output_padding, IntLike):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
output_padding_list = [output_padding] * len(dims)
elif len(output_padding) == 1:
output_padding_list = [output_padding[0]] * len(dims)
@ -2374,19 +2377,19 @@ def calc_conv_nd_return_shape(
ret_shape.append(
_formula_transposed(
dims[i],
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
padding[i],
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
dilation[i],
kernel_size[i],
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
stride[i],
output_padding_list[i],
)
)
else:
ret_shape.append(
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
)
from torch.fx.experimental.symbolic_shapes import sym_or
@ -3454,7 +3457,7 @@ def meta_index_Tensor(self, indices):
"""
shape = before_shape + replacement_shape + after_shape
strides = list(self.stride())
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
replacement_shape
)
@ -5311,6 +5314,7 @@ def full(size, fill_value, *args, **kwargs):
if not dtype:
dtype = utils.get_dtype(fill_value)
kwargs["dtype"] = dtype
# pyrefly: ignore [not-iterable]
return torch.empty(size, *args, **kwargs)
@ -6668,7 +6672,7 @@ def rnn_cell_checkSizes(
)
torch._check(
all(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
x.device == input_gates.device
for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
),

View File

@ -880,7 +880,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
return self._op_dk(dk, *args, **kwargs)
else:
return NotImplemented # pyrefly: ignore # bad-return
return NotImplemented # pyrefly: ignore [bad-return]
# Remove a dispatch key from the dispatch cache. This will force it to get
# recomputed the next time. Does nothing
@ -985,9 +985,9 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
r = self.py_kernels.get(final_key, final_key)
if cache_result:
self._dispatch_cache[key] = r # pyrefly: ignore # unsupported-operation
self._dispatch_cache[key] = r # pyrefly: ignore [unsupported-operation]
add_cached_op(self)
return r # pyrefly: ignore # bad-return
return r # pyrefly: ignore [bad-return]
def name(self):
return self._name
@ -1117,7 +1117,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
)
assert isinstance(handler, Callable) # type: ignore[arg-type]
return handler(*args, **kwargs) # pyrefly: ignore # bad-return
return handler(*args, **kwargs) # pyrefly: ignore [bad-return]
def _must_dispatch_in_python(args, kwargs):

View File

@ -267,6 +267,7 @@ class FunctionalTensor(torch.Tensor):
device=self.device,
layout=self.layout,
)
# pyrefly: ignore [not-iterable]
return super().to(*args, **kwargs)
def cuda(self, device=None, *args, **kwargs):

View File

@ -551,6 +551,7 @@ class Tensor(torch._C.TensorBase):
raise RuntimeError("__setstate__ can be only called on leaf Tensors")
if len(state) == 4:
# legacy serialization of Tensor
# pyrefly: ignore [not-iterable]
self.set_(*state)
return
elif len(state) == 5:
@ -758,7 +759,7 @@ class Tensor(torch._C.TensorBase):
)
if self._post_accumulate_grad_hooks is None:
self._post_accumulate_grad_hooks: dict[Any, Any] = (
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
OrderedDict()
)
@ -1062,7 +1063,7 @@ class Tensor(torch._C.TensorBase):
else:
return torch._VF.split_with_sizes(
self,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
split_size,
dim,
)
@ -1119,7 +1120,7 @@ class Tensor(torch._C.TensorBase):
__rtruediv__ = __rdiv__
__itruediv__ = _C.TensorBase.__idiv__
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
__pow__ = cast(
Callable[
["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]],

View File

@ -686,8 +686,8 @@ def _take_tensors(tensors, size_limit):
if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
yield buf_and_size[0]
buf_and_size = buf_dict[t] = [[], 0]
buf_and_size[0].append(tensor) # pyrefly: ignore # missing-attribute
buf_and_size[1] += size # pyrefly: ignore # unsupported-operation
buf_and_size[0].append(tensor) # pyrefly: ignore [missing-attribute]
buf_and_size[1] += size # pyrefly: ignore [unsupported-operation]
for buf, _ in buf_dict.values():
if len(buf) > 0:
yield buf
@ -744,6 +744,7 @@ class ExceptionWrapper:
if exc_info is None:
exc_info = sys.exc_info()
self.exc_type = exc_info[0]
# pyrefly: ignore [not-iterable]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
self.where = where
@ -751,7 +752,7 @@ class ExceptionWrapper:
r"""Reraises the wrapped exception in the current thread"""
# Format a message such as: "Caught ValueError in DataLoader worker
# process 2. Original Traceback:", followed by the traceback.
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore [missing-attribute]
if self.exc_type is KeyError:
# KeyError calls repr() on its argument (usually a dict key). This
# makes stack traces unreadable. It will not be changed in Python
@ -760,13 +761,13 @@ class ExceptionWrapper:
elif getattr(self.exc_type, "message", None):
# Some exceptions have first argument as non-str but explicitly
# have message field
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
raise self.exc_type(
# pyrefly: ignore # unexpected-keyword
# pyrefly: ignore [unexpected-keyword]
message=msg
)
try:
exception = self.exc_type(msg) # pyrefly: ignore # not-callable
exception = self.exc_type(msg) # pyrefly: ignore [not-callable]
except Exception:
# If the exception takes multiple arguments or otherwise can't
# be constructed, don't try to instantiate since we don't know how to
@ -1018,12 +1019,12 @@ class _LazySeedTracker:
self.call_order = []
def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
self.manual_seed_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

View File

@ -419,6 +419,7 @@ class Unpickler:
inst = self.stack[-1]
if type(inst) is torch.Tensor:
# Legacy unpickling
# pyrefly: ignore [not-iterable]
inst.set_(*state)
elif type(inst) is torch.nn.Parameter:
inst.__setstate__(state)

View File

@ -104,6 +104,7 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
flatten("", stats)
flat_stats.sort()
# pyrefly: ignore [no-matching-overload]
return OrderedDict(flat_stats)

View File

@ -525,7 +525,7 @@ def custom_fwd(
args[0]._dtype = torch.get_autocast_dtype(device_type)
if cast_inputs is None:
args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]
else:
autocast_context = torch.is_autocast_enabled(device_type)
args[0]._fwd_used_autocast = False
@ -536,7 +536,7 @@ def custom_fwd(
**_cast(kwargs, device_type, cast_inputs),
)
else:
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]
return decorate_fwd
@ -571,6 +571,6 @@ def custom_bwd(bwd=None, *, device_type: str):
enabled=args[0]._fwd_used_autocast,
dtype=args[0]._dtype,
):
return bwd(*args, **kwargs) # pyrefly: ignore # not-callable
return bwd(*args, **kwargs) # pyrefly: ignore [not-callable]
return decorate_bwd

View File

@ -84,7 +84,7 @@ class _NSGraphMatchableSubgraphsIterator:
if is_match:
# navigate to the base node
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.seen_nodes.add(cur_start_node)
# for now, assume that there are no other nodes
# which need to be added to the stack
@ -95,10 +95,10 @@ class _NSGraphMatchableSubgraphsIterator:
cur_base_op_node = cur_start_node
break
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.seen_nodes.add(cur_start_node)
# add args of previous nodes to stack
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for arg in cur_start_node.all_input_nodes:
self._recursively_add_node_arg_to_stack(arg)
@ -106,7 +106,7 @@ class _NSGraphMatchableSubgraphsIterator:
# note: this check is done on the start_node, i.e.
# if we are matching linear-relu in reverse, this would do the matchable
# check on the linear
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if not self._is_matchable(cur_base_op_node):
continue
@ -120,10 +120,10 @@ class _NSGraphMatchableSubgraphsIterator:
continue
return NSSubgraph(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
start_node=cur_start_node,
end_node=cur_end_node,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
base_op_node=cur_base_op_node,
)
@ -481,4 +481,5 @@ of subgraphs."""
# subgraphs in their order of execution.
results = collections.OrderedDict(reversed(results.items()))
# pyrefly: ignore [bad-return]
return results

View File

@ -30,6 +30,7 @@ class EventList(list):
use_device = kwargs.pop("use_device", None)
profile_memory = kwargs.pop("profile_memory", False)
with_flops = kwargs.pop("with_flops", False)
# pyrefly: ignore [not-iterable]
super().__init__(*args, **kwargs)
self._use_device = use_device
self._profile_memory = profile_memory
@ -505,9 +506,9 @@ class FunctionEvent(FormattedTimesMixin):
self.id: int = id
self.node_id: int = node_id
self.name: str = name
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.overload_name: str = overload_name
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.trace_name: str = trace_name
self.time_range: Interval = Interval(start_us, end_us)
self.thread: int = thread
@ -516,13 +517,13 @@ class FunctionEvent(FormattedTimesMixin):
self.count: int = 1
self.cpu_children: list[FunctionEvent] = []
self.cpu_parent: Optional[FunctionEvent] = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.input_shapes: tuple[int, ...] = input_shapes
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.concrete_inputs: list[Any] = concrete_inputs
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.kwinputs: dict[str, Any] = kwinputs
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.stack: list = stack
self.scope: int = scope
self.use_device: Optional[str] = use_device
@ -766,7 +767,7 @@ class FunctionEventAvg(FormattedTimesMixin):
self.self_device_memory_usage += other.self_device_memory_usage
self.count += other.count
if self.flops is None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.flops = other.flops
elif other.flops is not None:
self.flops += other.flops
@ -1003,7 +1004,7 @@ def _build_table(
]
if flops <= 0:
raise AssertionError(f"Expected flops to be positive, but got {flops}")
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
if not (log_flops >= 0 and log_flops < len(flop_headers)):
raise AssertionError(

View File

@ -50,6 +50,7 @@ def compile(*args, **kwargs):
"""
See :func:`torch.compile` for details on the arguments for this function.
"""
# pyrefly: ignore [not-iterable]
return torch.compile(*args, **kwargs)

View File

@ -198,6 +198,7 @@ def _for_each_rank_run_func(
rr_val = flat_rank_rets[rr_key]
if isinstance(rr_val, Tensor):
# pyrefly: ignore [bad-argument-type, bad-argument-count]
ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)})
elif isinstance(rr_val, (list, tuple)):
ret_list = []
@ -206,6 +207,7 @@ def _for_each_rank_run_func(
v_it = iter(rets.values())
v = next(v_it)
if isinstance(v, Tensor):
# pyrefly: ignore [bad-argument-type, bad-argument-count]
ret_list.append(LocalTensor(rets))
elif isinstance(v, int) and not all(v == v2 for v2 in v_it):
ret_list.append(torch.SymInt(LocalIntNode(rets)))
@ -468,7 +470,7 @@ class LocalTensor(torch.Tensor):
def __repr__(self) -> str: # type: ignore[override]
parts = []
for k, v in self._local_tensors.items():
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
parts.append(f" {k}: {v}")
tensors_str = ",\n".join(parts)
return f"LocalTensor(\n{tensors_str}\n)"
@ -491,6 +493,7 @@ class LocalTensor(torch.Tensor):
"Expecting spec to be not None from `__tensor_flatten__` return value!"
)
local_tensors = inner_tensors["_local_tensors"]
# pyrefly: ignore [bad-argument-type, bad-argument-count]
return LocalTensor(local_tensors)
@classmethod
@ -751,6 +754,7 @@ class LocalTensorMode(TorchDispatchMode):
"""
with self.disable():
# pyrefly: ignore [bad-argument-type, bad-argument-count]
return LocalTensor({r: cb(r) for r in self.ranks})
def _patch_device_mesh(self) -> None:
@ -761,7 +765,7 @@ class LocalTensorMode(TorchDispatchMode):
def _unpatch_device_mesh(self) -> None:
assert self._old_get_coordinate is not None
DeviceMesh.get_coordinate = self._old_get_coordinate
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._old_get_coordinate = None

View File

@ -79,6 +79,7 @@ def _flatten_tensor_size(size) -> torch.Size:
Checks if tensor size is valid, then flatten/return a torch.Size object.
"""
if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
# pyrefly: ignore [not-iterable]
dims = list(*size)
else:
dims = list(size)
@ -208,7 +209,7 @@ def build_global_metadata(
global_sharded_tensor_metadata = None
global_metadata_rank = 0
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for rank, rank_metadata in enumerate(gathered_metadatas):
if rank_metadata is None:
continue

View File

@ -227,7 +227,7 @@ class PGTransport:
self._work: list[Work] = []
self._pg = pg
self._timeout = timeout
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self._device = device
self._state_dict = state_dict
@ -345,6 +345,7 @@ class PGTransport:
values.append(recv(path, v))
elif isinstance(v, _DTensorMeta):
tensor = recv(path, v.local)
# pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword]
values.append(DTensor(tensor, v.spec, requires_grad=False))
elif isinstance(v, _ShardedTensorMeta):
# Receive all local shards that were sent to us

View File

@ -565,7 +565,7 @@ class FlatParamHandle:
# Only align addresses for `use_orig_params=True` (for now)
align_addresses = use_orig_params
self._init_get_unflat_views_fn(align_addresses)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.device = device
self._device_handle = _FSDPDeviceHandle.from_device(self.device)
self.process_group = process_group
@ -2495,6 +2495,7 @@ class FlatParamHandle:
###########
def flat_param_to(self, *args, **kwargs):
"""Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
# pyrefly: ignore [not-iterable]
self.flat_param.data = self.flat_param.to(*args, **kwargs)
if self._use_orig_params:
# Refresh the views because their storage may have changed

View File

@ -139,11 +139,14 @@ def _from_local_no_grad(
"""
if not compiled_autograd_enabled():
# pyrefly: ignore [bad-argument-type]
return DTensor(
# Use the local tensor directly instead of constructing a new tensor
# variable, e.g. with `view_as()`, since this is not differentiable
# pyrefly: ignore [bad-argument-count]
local_tensor,
sharding_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=local_tensor.requires_grad,
)
else:

View File

@ -107,9 +107,12 @@ class _ToTorchTensor(torch.autograd.Function):
)
return (
# pyrefly: ignore [bad-argument-type]
DTensor(
# pyrefly: ignore [bad-argument-count]
grad_output,
grad_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=grad_output.requires_grad,
),
None,
@ -175,11 +178,14 @@ class _FromTorchTensor(torch.autograd.Function):
)
# We want a fresh Tensor object that shares memory with the input tensor
# pyrefly: ignore [bad-argument-type]
dist_tensor = DTensor(
# pyrefly: ignore [bad-argument-count]
input.view_as(input),
dist_spec,
# requires_grad of the dist tensor depends on if input
# requires_grad or not
# pyrefly: ignore [unexpected-keyword]
requires_grad=input.requires_grad,
)
return dist_tensor
@ -304,9 +310,12 @@ class DTensor(torch.Tensor):
spec.placements,
tensor_meta=unflatten_tensor_meta,
)
# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
local_tensor,
unflatten_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=requires_grad,
)
@ -820,9 +829,12 @@ def distribute_tensor(
dtype=tensor.dtype,
),
)
# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
local_tensor.requires_grad_(tensor.requires_grad),
spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=tensor.requires_grad,
)
@ -1077,9 +1089,12 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
),
)
# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
local_tensor,
spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=kwargs["requires_grad"],
)

View File

@ -78,8 +78,11 @@ def found_inf_reduce_handler(
dtype=target_tensor.dtype,
),
)
# pyrefly: ignore [bad-argument-type]
found_inf_dtensor = dtensor.DTensor(
local_tensor=target_tensor, spec=spec, requires_grad=False
local_tensor=target_tensor, # pyrefly: ignore [unexpected-keyword]
spec=spec, # pyrefly: ignore [unexpected-keyword]
requires_grad=False, # pyrefly: ignore [unexpected-keyword]
)
found_inf = found_inf_dtensor.full_tensor()
target_tensor.copy_(found_inf)
@ -189,7 +192,7 @@ class OpDispatcher:
local_tensor_args = (
pytree.tree_unflatten(
cast(list[object], op_info.local_args),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
op_info.args_tree_spec,
)
if op_info.args_tree_spec
@ -366,7 +369,7 @@ class OpDispatcher:
resharded_local_tensor = redistribute_local_tensor(
local_tensor,
arg_spec,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
reshard_arg_spec,
)
new_local_args.append(resharded_local_tensor)
@ -439,7 +442,7 @@ class OpDispatcher:
kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
op_call,
v,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
compute_mesh,
)
local_kwargs[k] = v
@ -456,7 +459,7 @@ class OpDispatcher:
OpSchema(
op_call,
(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
pytree.tree_unflatten(args_schema, args_spec)
if args_spec
else tuple(args_schema)
@ -478,6 +481,7 @@ class OpDispatcher:
assert isinstance(spec, DTensorSpec), (
f"output spec does not match with output! Expected DTensorSpec, got {spec}."
)
# pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword]
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
else:
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor

View File

@ -883,9 +883,12 @@ class Redistribute(torch.autograd.Function):
output = local_tensor
target_spec = current_spec
# pyrefly: ignore [bad-argument-type]
return dtensor.DTensor(
# pyrefly: ignore [bad-argument-count]
output,
target_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=input.requires_grad,
)
@ -944,9 +947,12 @@ class Redistribute(torch.autograd.Function):
dtype=output.dtype,
),
)
# pyrefly: ignore [bad-argument-type]
output_dtensor = dtensor.DTensor(
# pyrefly: ignore [bad-argument-count]
output,
spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=grad_output.requires_grad,
)

View File

@ -174,9 +174,12 @@ def _log_softmax_handler(
tensor_meta=output_tensor_meta,
)
# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
res,
res_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=res.requires_grad,
)
@ -251,7 +254,7 @@ def _nll_loss_forward(
if weight is not None:
new_shape = list(x.shape)
new_shape[channel_dim] = -1
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
w = w.expand(new_shape)
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
wsum = torch.where(target != ignore_index, wsum, 0)
@ -309,9 +312,9 @@ def _nll_loss_forward_handler(
output_placements = all_replicate_placements
# tensor inputs to _propagate_tensor_meta need to be DTensors
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
args = list(args)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
args[1], args[2] = target, weight
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
@ -330,9 +333,12 @@ def _nll_loss_forward_handler(
out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta)
return (
# pyrefly: ignore [bad-argument-type]
DTensor(
# pyrefly: ignore [bad-argument-count]
result,
out_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=result.requires_grad,
),
total_weight,
@ -442,11 +448,11 @@ def _nll_loss_backward_handler(
weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
# tensor inputs to _propagate_tensor_meta need to be DTensors
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
args = list(args)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
args[2], args[3] = target, weight
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
@ -470,9 +476,12 @@ def _nll_loss_backward_handler(
tensor_meta=output_tensor_meta,
)
# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
result,
out_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=result.requires_grad,
)

View File

@ -949,7 +949,7 @@ def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
]
target = node.target if node.op in ("call_function", "get_attr") else ""
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
nodes_idx[id(node)] = i
return "\n".join(ret)
@ -1206,6 +1206,7 @@ class _ModuleFrame:
for k in kwargs_spec.context
}
assert self.parent_call_module is not None
# pyrefly: ignore [bad-assignment]
self.parent_call_module.args = tuple(arg_nodes)
self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment]
@ -1393,6 +1394,7 @@ class _ModuleFrame:
def print(self, *args, **kwargs):
if self.verbose:
# pyrefly: ignore [not-iterable]
print(*args, **kwargs)
def run_from(self, node_idx):
@ -1486,7 +1488,7 @@ class _ModuleFrame:
self.seen_attrs[self.child_fqn].add(node.target)
self.copy_node(node)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
node_idx += 1

View File

@ -1952,7 +1952,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
)
if isinstance(shape, (int, torch.SymInt)):
shape = torch.Size([shape]) # pyrefly: ignore # bad-argument-type
shape = torch.Size([shape]) # pyrefly: ignore [bad-argument-type]
else:
for dim in shape:
torch._check_type(

View File

@ -87,6 +87,7 @@ def _reify_object_slots(o, s):
@dispatch(slice, dict)
def _reify(o, s):
"""Reify a Python ``slice`` object"""
# pyrefly: ignore [not-iterable]
return slice(*reify((o.start, o.stop, o.step), s))

View File

@ -59,7 +59,7 @@ Argument = Optional[
BaseArgumentTypes,
]
]
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
ArgumentT = TypeVar("ArgumentT", bound=Argument)
_P = ParamSpec("_P")
_R = TypeVar("_R")
@ -385,7 +385,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self._prepend(x)
@compatibility(is_backward_compatible=True)
@ -397,7 +397,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self._next._prepend(x)
@property
@ -698,7 +698,8 @@ class Node(_NodeBase):
if replace_hooks:
for replace_hook in replace_hooks:
replace_hook(old=self, new=replace_with.name, user=use_node)
use_node._replace_input_with(self, replace_with)
# pyrefly: ignore [missing-attribute]
use_node._replace_input_with(self, replace_with) # type: ignore[attr-defined]
return result
@compatibility(is_backward_compatible=False)
@ -835,7 +836,8 @@ class Node(_NodeBase):
for replace_hook in m._replace_hooks:
replace_hook(old=old_input, new=new_input.name, user=self)
self._replace_input_with(old_input, new_input)
# pyrefly: ignore [missing-attribute]
self._replace_input_with(old_input, new_input) # type: ignore[attr-defined]
def _rename(self, candidate: str) -> None:
if candidate == self.name:

View File

@ -303,7 +303,7 @@ class StreamContext:
self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting():
if self.idx is None:
self.idx = -1 # pyrefly: ignore # bad-assignment
self.idx = -1 # pyrefly: ignore [bad-assignment]
self.src_prev_stream = (
None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)

View File

@ -46,7 +46,7 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
if canonicalize:
dim = canonicalize_dims(ndim, dim)
assert dim >= 0 and dim < ndim # pyrefly: ignore # unsupported-operation
assert dim >= 0 and dim < ndim # pyrefly: ignore [unsupported-operation]
# Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
# For other dims, subtract 1 to convert to inner space.

View File

@ -72,7 +72,7 @@ class _NormBase(Module):
torch.tensor(
0,
dtype=torch.long,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
),
)
@ -222,7 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(
# affine and track_running_stats are hardcoded to False to
# avoid creating tensors that will soon be overwritten.
@ -236,29 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)
if self.track_running_stats:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.running_mean = UninitializedBuffer(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.running_var = UninitializedBuffer(**factory_kwargs)
self.num_batches_tracked = torch.tensor(
0,
dtype=torch.long,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
)
def reset_parameters(self) -> None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if not self.has_uninitialized_params() and self.num_features != 0:
super().reset_parameters()
def initialize_parameters(self, input) -> None: # type: ignore[override]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if self.has_uninitialized_params():
self.num_features = input.shape[1]
if self.affine:
@ -352,6 +352,7 @@ class BatchNorm1d(_BatchNorm):
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
# pyrefly: ignore [inconsistent-inheritance]
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.
@ -463,6 +464,7 @@ class BatchNorm2d(_BatchNorm):
raise ValueError(f"expected 4D input (got {input.dim()}D input)")
# pyrefly: ignore [inconsistent-inheritance]
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.
@ -574,6 +576,7 @@ class BatchNorm3d(_BatchNorm):
raise ValueError(f"expected 5D input (got {input.dim()}D input)")
# pyrefly: ignore [inconsistent-inheritance]
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

View File

@ -38,13 +38,13 @@ T = TypeVar("T", bound="Module")
class _IncompatibleKeys(
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
):
__slots__ = ()
def __repr__(self) -> str:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"
return super().__repr__()
@ -93,7 +93,7 @@ class _WrappedHook:
def __getstate__(self) -> dict:
result = {"hook": self.hook, "with_module": self.with_module}
if self.with_module:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
result["module"] = self.module()
return result
@ -979,7 +979,7 @@ class Module:
# Decrement use count of the gradient by setting to None
param.grad = None
param_applied = torch.nn.Parameter(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
param_applied,
requires_grad=param.requires_grad,
)
@ -992,13 +992,13 @@ class Module:
) from e
out_param = param
elif p_should_use_set_data:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
param.data = param_applied
out_param = param
else:
assert isinstance(param, Parameter)
assert param.is_leaf
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
out_param = Parameter(param_applied, param.requires_grad)
self._parameters[key] = out_param
@ -1337,7 +1337,9 @@ class Module:
"""
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
# pyrefly: ignore [not-iterable]
*args,
**kwargs,
)
if dtype is not None:
@ -2256,7 +2258,7 @@ class Module:
if destination is None:
destination = OrderedDict()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
destination._metadata = OrderedDict()
local_metadata = dict(version=self._version)
@ -2407,7 +2409,7 @@ class Module:
}
local_name_params = itertools.chain(
self._parameters.items(),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
persistent_buffers.items(),
)
local_state = {k: v for k, v in local_name_params if v is not None}

View File

@ -27,6 +27,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
"""Prints messages based on `verbose`."""
if verbose is False:
return lambda *_, **__: None
# pyrefly: ignore [not-iterable]
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)
@ -47,7 +48,7 @@ def _patch_dynamo_unsupported_functions():
# Replace torch.jit.isinstance with isinstance
jit_isinstance = torch.jit.isinstance
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
torch.jit.isinstance = isinstance
logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing")
try:

View File

@ -132,10 +132,10 @@ class TorchTensor(ir.Tensor):
# view the tensor as that dtype so that it is convertible to NumPy,
# and then view it back to the proper dtype (using ml_dtypes obtained by
# calling dtype.numpy()).
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if self.dtype == ir.DataType.BFLOAT16:
return (
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
)
if self.dtype in {
@ -144,11 +144,11 @@ class TorchTensor(ir.Tensor):
ir.DataType.FLOAT8E5M2,
ir.DataType.FLOAT8E5M2FNUZ,
}:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
if self.dtype == ir.DataType.FLOAT4E2M1:
return _type_casting.unpack_float4x2_as_uint8(self.raw).view(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.dtype.numpy()
)
@ -170,7 +170,7 @@ class TorchTensor(ir.Tensor):
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
raise TypeError(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
"with a tensor backed by real data using ONNXProgram.apply_weights() "
"or save the model without initializers by setting include_initializers=False."
@ -251,7 +251,7 @@ def _set_shape_type(
if isinstance(dim, int):
dims.append(dim)
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
dims.append(str(dim.node))
# If the dtype is set already (e.g. by the onnx_symbolic ops),
@ -1232,7 +1232,7 @@ def _exported_program_to_onnx_program(
# so we need to get them from the name_* apis.
for name, torch_tensor in itertools.chain(
exported_program.named_parameters(),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
exported_program.named_buffers(),
exported_program.constants.items(),
):
@ -1265,6 +1265,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
"""Prints messages based on `verbose`."""
if verbose is False:
return lambda *_, **__: None
# pyrefly: ignore [not-iterable]
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)

View File

@ -239,7 +239,7 @@ def _compare_onnx_pytorch_outputs_in_np(
if acceptable_error_percentage:
error_percentage = 1 - np.sum(
np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol)
) / np.prod(ort_out.shape) # pyrefly: ignore # missing-attribute
) / np.prod(ort_out.shape) # pyrefly: ignore [missing-attribute]
if error_percentage <= acceptable_error_percentage:
warnings.warn(
f"Suppressed AssertionError:\n{e}.\n"

View File

@ -13,6 +13,7 @@ from torch import optim
def partialclass(cls, *args, **kwargs): # noqa: D103
class NewCls(cls):
# pyrefly: ignore [not-iterable]
__init__ = partialmethod(cls.__init__, *args, **kwargs)
return NewCls

View File

@ -326,7 +326,7 @@ def gaussian(
requires_grad=requires_grad,
)
return torch.exp(-(k**2)) # pyrefly: ignore # unsupported-operation
return torch.exp(-(k**2)) # pyrefly: ignore [unsupported-operation]
@_add_docstr(

View File

@ -618,7 +618,7 @@ def _get_storage_from_sequence(sequence, dtype, device):
def _isint(x):
if HAS_NUMPY:
return isinstance(x, (int, np.integer)) # pyrefly: ignore # missing-attribute
return isinstance(x, (int, np.integer)) # pyrefly: ignore [missing-attribute]
else:
return isinstance(x, int)

View File

@ -77,6 +77,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
def pop(self) -> T:
if not self:
raise KeyError("pop from an empty set")
# pyrefly: ignore [bad-return]
return self._dict.popitem()[0]
def copy(self) -> OrderedSet[T]:
@ -158,7 +159,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
# MutableSet impl will iterate over other, iter over smaller of two sets
if isinstance(other, OrderedSet) and len(self) < len(other):
# pyrefly: ignore # unsupported-operation, bad-return
# pyrefly: ignore [unsupported-operation, bad-return]
return other & self
return cast(OrderedSet[T], super().__and__(other))

View File

@ -202,7 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
Args:
device (int, optional): if specified, all parameters will be copied to that device
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self._apply(lambda t: getattr(t, custom_backend_name)(device))
_check_register_once(torch.nn.Module, custom_backend_name)
@ -252,11 +252,15 @@ def _generate_packed_sequence_methods_for_privateuse1_backend(
device (int, optional): if specified, all parameters will be copied to that device
"""
ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
*args, **kwargs
# pyrefly: ignore [not-iterable]
*args,
**kwargs,
)
if ex.device.type == custom_backend_name:
# pyrefly: ignore [not-iterable]
return self.to(*args, **kwargs)
kwargs.update({"device": custom_backend_name})
# pyrefly: ignore [not-iterable]
return self.to(*args, **kwargs)
_check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)

View File

@ -48,7 +48,7 @@ MATH_TRANSPILATIONS = collections.OrderedDict(
]
)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
CUDA_TYPE_NAME_MAP = collections.OrderedDict(
[
("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)),
@ -587,6 +587,7 @@ CUDA_TYPE_NAME_MAP = collections.OrderedDict(
]
)
# pyrefly: ignore [no-matching-overload]
CUDA_INCLUDE_MAP = collections.OrderedDict(
[
# since pytorch uses "\b{pattern}\b" as the actual re pattern,
@ -676,7 +677,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
]
)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
CUDA_IDENTIFIER_MAP = collections.OrderedDict(
[
("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)),
@ -8370,6 +8371,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
]
)
# pyrefly: ignore [no-matching-overload]
CUDA_SPECIAL_MAP = collections.OrderedDict(
[
# SPARSE
@ -8852,6 +8854,7 @@ CUDA_SPECIAL_MAP = collections.OrderedDict(
]
)
# pyrefly: ignore [no-matching-overload]
PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
[
("USE_CUDA", ("USE_ROCM", API_PYTORCH)),
@ -9316,6 +9319,7 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
]
)
# pyrefly: ignore [no-matching-overload]
CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
[
("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", API_CAFFE2)),
@ -9401,6 +9405,7 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
#
# NB: if you want a transformation to ONLY apply to the c10/ directory,
# put it as API_CAFFE2
# pyrefly: ignore [no-matching-overload]
C10_MAPPINGS = collections.OrderedDict(
[
("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)),

View File

@ -120,6 +120,7 @@ class GeneratedFileCleaner:
def open(self, fn, *args, **kwargs):
if not os.path.exists(fn):
self.files_to_clean.add(os.path.abspath(fn))
# pyrefly: ignore [not-iterable]
return open(fn, *args, **kwargs)
def makedirs(self, dn, exist_ok=False):
@ -669,7 +670,7 @@ def is_caffe2_gpu_file(rel_filepath):
return True
filename = os.path.basename(rel_filepath)
_, ext = os.path.splitext(filename)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)
class TrieNode:
@ -1145,7 +1146,7 @@ def hipify(
out_of_place_only=out_of_place_only,
is_pytorch_extension=is_pytorch_extension))
all_files_set = set(all_files)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for f in extra_files:
if not os.path.isabs(f):
f = os.path.join(output_directory, f)

View File

@ -292,9 +292,10 @@ class WeakIdKeyDictionary(MutableMapping):
if o is not None:
return o, value
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def pop(self, key, *args):
self._dirty_len = True
# pyrefly: ignore [not-iterable]
return self.data.pop(self.ref_type(key), *args) # CHANGED
def setdefault(self, key, default=None):

View File

@ -328,7 +328,7 @@ class StreamContext:
self.stream = stream
self.idx = _get_device_index(None, True)
if self.idx is None:
self.idx = -1 # pyrefly: ignore # bad-assignment
self.idx = -1 # pyrefly: ignore [bad-assignment]
def __enter__(self):
cur_stream = self.stream

View File

@ -126,7 +126,7 @@ class Event(torch._C._XpuEventBase):
"""
if stream is None:
stream = torch.xpu.current_stream()
super().record(stream) # pyrefly: ignore # bad-argument-type
super().record(stream) # pyrefly: ignore [bad-argument-type]
def wait(self, stream=None) -> None:
r"""Make all future work submitted to the given stream wait for this event.