Fix pyrefly ignore syntax (#166438)

Reformats pyrefly ignore suppressions so they only ignore one error code.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166438
Approved by: https://github.com/Skylion007
This commit is contained in:
Maggie Moss 2025-10-29 00:02:16 +00:00 committed by PyTorch MergeBot
parent a9b29caeae
commit 31e42eb732
50 changed files with 178 additions and 178 deletions

View File

@ -1810,7 +1810,7 @@ def _check_not_implemented(cond, message=None): # noqa: F811
_check_with(
NotImplementedError,
cond,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
message,
)
@ -2669,7 +2669,7 @@ def compile(
dynamic=dynamic,
disable=disable,
guard_filter_fn=guard_filter_fn,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
)(model)(*args, **kwargs)
return export_wrapped_fn

View File

@ -382,7 +382,7 @@ def to_real_dtype(dtype: torch.dtype):
def mse_loss(
self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
loss = (self - target) ** 2
return apply_loss_reduction(loss, reduction)
@ -416,7 +416,7 @@ def smooth_l1_loss(
beta: float = 1.0,
):
loss = (self - target).abs()
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
return apply_loss_reduction(loss, reduction)
@ -4079,7 +4079,7 @@ def _nll_loss_forward(
return result, total_weight
if weight is not None:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
w = w.expand(self.shape)
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
wsum = torch.where(target != ignore_index, wsum, 0)
@ -4896,9 +4896,9 @@ def _reflection_pad_backward(grad_output, x, padding):
@register_decomposition(aten.aminmax)
@out_wrapper("min", "max")
def aminmax(self, *, dim=None, keepdim=False):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
amin = torch.amin(self, dim=dim, keepdim=keepdim)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
amax = torch.amax(self, dim=dim, keepdim=keepdim)
return amin, amax
@ -5143,7 +5143,7 @@ def baddbmm(self, batch1, batch2, beta=1, alpha=1):
alpha = int(alpha)
result = torch.bmm(batch1, batch2)
if not isinstance(alpha, numbers.Number) or alpha != 1:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
result = result * alpha
if beta == 0:
return result

View File

@ -132,7 +132,7 @@ def _make_export_case(m, name, configs):
m.__doc__ is not None
), f"Could not find description or docstring for export case: {m}"
configs = {**configs, "description": m.__doc__}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ExportCase(**{**configs, "model": m, "name": name})

View File

@ -3,12 +3,12 @@ import torch
class MyAutogradFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, x):
return x.clone()
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
return grad_output + 1

View File

@ -39,7 +39,7 @@ def get_class_if_classified_error(e: Exception) -> Optional[str]:
TorchRuntimeError: None,
}
if type(e) in _ALLOW_LIST:
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
attr_name = _ALLOW_LIST[type(e)]
if attr_name is None:
return ALWAYS_CLASSIFIED

View File

@ -101,7 +101,7 @@ class _KeyPathTrie:
assert len(kp) > 0
k, *kp = kp # type: ignore[assignment]
node = node[k]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return node, kp
@ -356,12 +356,12 @@ def _override_builtin_ops():
original_min = builtins.min
original_pow = math.pow
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
builtins.max = functools.partial(
_tensor_min_max, real_callable=original_max, tensor_callable=torch.maximum
)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
builtins.min = functools.partial(
_tensor_min_max, real_callable=original_min, tensor_callable=torch.minimum
)
@ -1087,7 +1087,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):
def run():
# Run sequence.
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
t = args[0]
for _method, _args in sequence:
t = _method(t, *_args)

View File

@ -188,7 +188,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
self.callback = callback
self.node: torch.fx.Node = next(iter(gm.graph.nodes))
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def placeholder(
self,
target: str, # type: ignore[override]
@ -440,7 +440,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
)
self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
interpreter = self.ExportInterpreter(self, graph_module)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
prev_interpreter, self.interpreter = (
self.interpreter,
torch.fx.Interpreter( # type: ignore[assignment]

View File

@ -32,7 +32,7 @@ def _node_metadata_hook(
that nodes being added are only call_function nodes, and copies over the
first argument node's nn_module_stack.
"""
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
fake_mode = fake_mode or contextlib.nullcontext()
assert node.op == "call_function" and callable(node.target), (
@ -48,7 +48,7 @@ def _node_metadata_hook(
fake_args, fake_kwargs = pytree.tree_map_only(
torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs)
)
# pyrefly: ignore # bad-context-manager
# pyrefly: ignore [bad-context-manager]
with fake_mode, enable_python_dispatcher():
fake_res = node.target(*fake_args, **fake_kwargs)
node.meta["val"] = fake_res
@ -84,7 +84,7 @@ def _node_metadata_hook(
"torch_fn",
(
f"{node.target.__name__}_0",
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
f"{node.target.__class__.__name__}.{node.target.__name__}",
),
)

View File

@ -567,7 +567,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
quantized = False
last_quantized_node = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for node in gm.graph.nodes:
if isinstance(node.target, OpOverload):
with gm.graph.inserting_before(node):
@ -630,7 +630,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
attr_names_to_clean.add(k)
if k == "_buffers":
buffer_name_to_clean = set()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for b_name, b_value in v.items():
if isinstance(b_value, torch.Tensor) and b_value.dtype in [
torch.qint8,
@ -638,7 +638,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
]:
buffer_name_to_clean.add(b_name)
for b_name in buffer_name_to_clean:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
v.pop(b_name, None)
for attr_name in attr_names_to_clean:
delattr(submod, attr_name)

View File

@ -35,7 +35,7 @@ def _replace_with_hop_helper(
)
call_func_node.meta["torch_fn"] = (
f"{wrap_hoo.__name__}",
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
)
if isinstance(output_args, (tuple, list)):

View File

@ -54,7 +54,7 @@ def _postprocess_serialized_shapes(
)
for k, v in sorted(dims.items())
}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
if to_dict:
return _dataclass_to_dict(spec)
@ -184,7 +184,7 @@ def _dump_dynamic_shapes(
kwargs = kwargs or {}
if isinstance(dynamic_shapes, dict):
dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment]
# pyrefly: ignore # bad-assignment, bad-argument-type
# pyrefly: ignore [bad-assignment, bad-argument-type]
dynamic_shapes = tuple(dynamic_shapes)
combined_args = tuple(args) + tuple(kwargs.values())

View File

@ -623,9 +623,9 @@ class _Commit:
def update_schema():
import importlib.resources
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if importlib.resources.is_resource(__package__, "schema.yaml"):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
content = importlib.resources.read_text(__package__, "schema.yaml")
match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
_check(match is not None, "checksum not found in schema.yaml")
@ -633,7 +633,7 @@ def update_schema():
checksum_head = match.group(1)
thrift_content = importlib.resources.read_text(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
__package__,
"export_schema.thrift",
)
@ -658,9 +658,9 @@ def update_schema():
src, cpp_header, thrift_schema = _staged_schema()
additions, subtractions = _diff_schema(dst, src)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
yaml_path = __package__.replace(".", "/") + "/schema.yaml"
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
thrift_schema_path = __package__.replace(".", "/") + "/export_schema.thrift"
torch_prefix = "torch/"
assert yaml_path.startswith(torch_prefix) # sanity check

View File

@ -383,7 +383,7 @@ def _reconstruct_fake_tensor(
fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta)
if is_parameter:
fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return fake_tensor
@ -2741,7 +2741,7 @@ class GraphModuleDeserializer(metaclass=Final):
serialized_node.metadata
)
assert arg is not None
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
fx_node.meta["val"] = tuple(meta_val)
self.serialized_name_to_node[fx_node.name] = fx_node
@ -3167,7 +3167,7 @@ def _dict_to_dataclass(cls, data):
_value = next(iter(data.values()))
assert isinstance(_type, str)
field_type = cls.__annotations__[_type]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
elif dataclasses.is_dataclass(cls):
fields = {}
@ -3474,23 +3474,23 @@ def _canonicalize_graph(
n.metadata.clear()
# Stage 4: Aggregate values.
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
sorted_tensor_values = dict(
sorted(graph.tensor_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
sorted_sym_int_values = dict(
sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
sorted_sym_float_values = dict(
sorted(graph.sym_float_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
sorted_sym_bool_values = dict(
sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
sorted_custom_obj_values = dict(
sorted(graph.custom_obj_values.items(), key=operator.itemgetter(0))
)
@ -3547,7 +3547,7 @@ def canonicalize(
ExportedProgram: The canonicalized exported program.
"""
ep = copy.deepcopy(ep)
# pyrefly: ignore # annotation-mismatch
# pyrefly: ignore [annotation-mismatch]
constants: set[str] = constants or set()
opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0)))

View File

@ -1120,14 +1120,14 @@ def placeholder_naming_pass(
if ( # handle targets for custom objects
spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map
):
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
spec.target = name_map[spec.target][4:] # strip obj_ prefix
for spec in export_graph_signature.output_specs:
if spec.arg.name in name_map:
spec.arg.name = name_map[spec.arg.name]
if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map:
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
spec.target = name_map[spec.target]
# rename keys in constants dict for custom objects

View File

@ -384,7 +384,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
class AOTAutogradCachePickler(FxGraphCachePickler):
def __init__(self, gm: torch.fx.GraphModule):
super().__init__(gm)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
self.dispatch_table: dict
self.dispatch_table.update(
{

View File

@ -86,10 +86,10 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
memory_format = MemoryFormatMeta.from_tensor(out)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if memory_format.memory_format is not None:
was = out
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
out = out.contiguous(memory_format=memory_format.memory_format)
updated = was is not out
@ -119,7 +119,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined]
if is_subclass:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
attrs = out.__tensor_flatten__()[0]
for attr in attrs:
@ -129,7 +129,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
new_elem_memory_format,
elem_updated,
) = coerce_tangent_and_suggest_memory_format(elem)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
out_memory_format.append(new_elem_memory_format)
if elem_updated:
setattr(out, attr, new_elem)
@ -493,7 +493,7 @@ def run_functionalized_fw_and_collect_metadata(
curr_storage in inp_storage_refs
and not functional_tensor_storage_changed
):
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
base_idx = inp_storage_refs[curr_storage]
is_input_tensor = id(o) in inp_tensor_ids
num_aliased_outs = out_tensor_alias_counts[curr_storage]
@ -701,7 +701,7 @@ from a multi-output view call"
# Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
# are *regenerated* later, and not used directly in the autograd graph
def _plain_fake_tensor_like_subclass(x):
# pyrefly: ignore # bad-context-manager
# pyrefly: ignore [bad-context-manager]
with detect_fake_mode():
return torch.empty(
x.shape, dtype=x.dtype, device=x.device, layout=x.layout

View File

@ -78,7 +78,7 @@ def get_all_input_and_grad_nodes(
continue
if isinstance(desc, SubclassGetAttrAOTInput):
_raise_autograd_subclass_not_implemented(n, desc)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
input_index[desc] = (n, None)
elif n.op == "output":
assert "desc" in n.meta, (n, n.meta)
@ -130,7 +130,7 @@ def get_all_output_and_tangent_nodes(
continue
if isinstance(sub_d, SubclassGetAttrAOTOutput):
_raise_autograd_subclass_not_implemented(sub_n, sub_d)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
output_index[sub_d] = (sub_n, None)
for n in g.nodes:
if n.op == "placeholder":

View File

@ -1310,12 +1310,12 @@ def aot_dispatch_subclass(
# See Note: [Partitioner handling for Subclasses, Part 2] for more info.
meta_updated = run_functionalized_fw_and_collect_metadata(
without_output_descs(metadata_fn),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
flat_args_descs=primals_unwrapped_descs,
static_input_indices=remapped_static_indices,
keep_input_mutations=meta.keep_input_mutations,
is_train=meta.is_train,
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
)(*primals_unwrapped)
subclass_meta.fw_metadata = meta_updated

View File

@ -538,7 +538,7 @@ def collect_bw_donated_buffer_idxs(
fw_ins,
user_fw_outs,
bw_outs,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
saved_tensors,
)
@ -1523,7 +1523,7 @@ def _aot_stage2a_partition(
# apply joint_gm callback here
if callable(torch._functorch.config.joint_custom_pass):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)
static_lifetime_input_indices = fw_metadata.static_input_indices
@ -1794,7 +1794,7 @@ def _aot_stage2b_bw_compile(
ph_size = ph_arg.size()
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
placeholder_list[i] = ph_arg.as_strided(ph_size, real_stride)
compiled_bw_func = None
if (

View File

@ -225,7 +225,7 @@ def make_output_handler(info, runtime_metadata, trace_joint):
# not sure why AOTDispatcher needs to manually set this
def maybe_mark_dynamic_helper(t: torch.Tensor, dims: set[int]):
if hasattr(t, "_dynamo_weak_dynamic_indices"):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
t._dynamo_weak_dynamic_indices |= dims
else:
t._dynamo_weak_dynamic_indices = dims.copy() # type: ignore[attr-defined]
@ -1143,7 +1143,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]:
f_args_inner = []
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for inner_idx_or_tuple in synthetic_base_info:
if isinstance(inner_idx_or_tuple, int):
f_args_inner.append(primals[inner_idx_or_tuple])
@ -2114,7 +2114,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
return (ctx._autograd_function_id, *ctx.symints)
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, *deduped_flat_tensor_args):
args = deduped_flat_tensor_args
if backward_state_indices:
@ -2151,7 +2151,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
# in the fw output order.
fw_outs = call_func_at_runtime_with_args(
CompiledFunction.compiled_fw,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
args,
disable_amp=disable_amp,
)
@ -2347,7 +2347,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
_aot_id = aot_config.aot_id
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(double_ctx, *unused_args):
return impl_fn(double_ctx)

View File

@ -196,7 +196,7 @@ class MemoryFormatMeta:
if use_memory_format:
return MemoryFormatMeta(
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
memory_format=torch._prims_common.suggest_memory_format(t),
)
@ -893,15 +893,15 @@ class GraphSignature:
parameters_to_mutate = {}
for output_name, mutation_name in outputs_to_mutations.items():
if mutation_name in user_inputs:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
user_inputs_to_mutate[output_name] = mutation_name
else:
assert mutation_name in buffers or mutation_name in parameters
if mutation_name in buffers:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
buffers_to_mutate[output_name] = mutation_name
else:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
parameters_to_mutate[output_name] = mutation_name
start, stop = stop, stop + num_user_outputs
@ -1236,9 +1236,9 @@ class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
output_code_ty: type[TOutputCode],
compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
):
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.output_code_ty = output_code_ty
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.compiler_fn = compiler_fn
def __call__(

View File

@ -90,7 +90,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
"""
for name, tensor in itertools.chain(
list(module.named_parameters(recurse=False)),
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
list(module.named_buffers(recurse=False)),
):
if is_traceable_wrapper_subclass(tensor):

View File

@ -237,7 +237,7 @@ def unwrap_tensor_subclasses(
n_desc: Any = (
SubclassGetAttrAOTInput(desc, attr)
if isinstance(desc, AOTInput)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
else SubclassGetAttrAOTOutput(desc, attr)
)
flatten_subclass(inner_tensor, n_desc, out=out)
@ -258,7 +258,7 @@ def unwrap_tensor_subclasses(
descs_inner: list[AOTDescriptor] = []
for x, desc in zip(wrapped_args, wrapped_args_descs):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
flatten_subclass(typing.cast(Tensor, x), desc, out=(xs_inner, descs_inner))
return xs_inner, descs_inner
@ -283,7 +283,7 @@ def runtime_unwrap_tensor_subclasses(
for attr in attrs:
inner_tensor = getattr(x, attr)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
inner_meta = meta.attrs.get(attr)
flatten_subclass(inner_tensor, inner_meta, out=out)

View File

@ -331,7 +331,7 @@ def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
and out.args[1] == 0
and out.args[0] in with_effect_nodes
):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
output_token_nodes.append(out)
else:
other_output_nodes.append(out)
@ -573,10 +573,10 @@ def without_output_descs(f: Callable[_P, tuple[_T, _S]]) -> Callable[_P, _T]:
@wraps(f)
@simple_wraps(f)
def inner(*args, **kwargs):
# pyrefly: ignore # invalid-param-spec
# pyrefly: ignore [invalid-param-spec]
return f(*args, **kwargs)[0]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return inner

View File

@ -753,7 +753,7 @@ class AutogradFunctionApply(HigherOrderOperator):
class ApplyTemplate(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, *args):
nonlocal saved_values
output, saved_values = fwd(None, *fwd_args)

View File

@ -42,9 +42,9 @@ def create_names_map(
This function creates a mapping from the names in named_params to the
names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
"""
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
named_params = dict(named_params)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
tied_named_params = dict(tied_named_params)
tensors_dict_keys = set(named_params.keys())
@ -53,11 +53,11 @@ def create_names_map(
tensor_to_mapping: dict[Tensor, tuple[str, list[str]]] = {}
for key, tensor in named_params.items():
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
tensor_to_mapping[tensor] = (key, [])
for key, tensor in tied_named_params.items():
assert tensor in tensor_to_mapping
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
tensor_to_mapping[tensor][1].append(key)
return dict(tensor_to_mapping.values())

View File

@ -199,7 +199,7 @@ def _extract_graph_with_inputs_outputs(
new_node = new_graph.placeholder(node.name)
# Can't use node_copy here as we may be turning previous call_function into placeholders
new_node.meta = node.meta
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
env[node] = new_node
for node in joint_graph.nodes:
@ -228,10 +228,10 @@ def _extract_graph_with_inputs_outputs(
if any(all_args):
env[node] = InvalidNode # type: ignore[assignment]
continue
# pyrefly: ignore # unsupported-operation, bad-argument-type
# pyrefly: ignore [unsupported-operation, bad-argument-type]
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "get_attr":
# pyrefly: ignore # unsupported-operation, bad-argument-type
# pyrefly: ignore [unsupported-operation, bad-argument-type]
env[node] = new_graph.node_copy(node, lambda x: env[x])
elif node.op == "output":
pass
@ -628,7 +628,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
]
# add the scale nodes to the output find the first sym_node in the output
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
idx = find_first_sym_node(output_updated_args)
scale_nodes = tensor_scale_nodes + sym_scale_nodes
if scale_nodes:
@ -1234,7 +1234,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
# critical path first.
cur_nodes += node.all_input_nodes
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n])
for node in insertable_nodes:
env[node] = new_graph.node_copy(node, lambda x: env[x])
@ -1463,14 +1463,14 @@ def functionalize_rng_ops(
devices = OrderedSet(
get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values()
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
devices.discard(torch.device("cpu"))
# multiple cuda devices won't work with cudagraphs anyway,
# fallback to non graphsafe rng checkpointing
multi_cuda_devices = len(devices) > 1
# this changes numerics, so if fallback_random is set we will not use it
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
ind_config = torch._inductor.config
use_rng_graphsafe_rng_functionalization = (
config.graphsafe_rng_functionalization
@ -2902,7 +2902,7 @@ def min_cut_rematerialization_partition(
node_info,
memory_budget=memory_budget,
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if config._sync_decision_cross_ranks:
saved_values = _sync_decision_cross_ranks(joint_graph, saved_values)
# save_for_backward on tensors and stashes symints in autograd .ctx
@ -2913,7 +2913,7 @@ def min_cut_rematerialization_partition(
fw_module, bw_module = _extract_fwd_bwd_modules(
joint_module,
saved_values,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
saved_sym_nodes=saved_sym_nodes,
num_fwd_outputs=num_fwd_outputs,
static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,

View File

@ -131,7 +131,7 @@ class VmapInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _cptr(self):
return CVmapInterpreterPtr(self._cdata)
@ -171,7 +171,7 @@ class GradInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _cptr(self):
return CGradInterpreterPtr(self._cdata)
@ -209,7 +209,7 @@ class JvpInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _cptr(self):
return CJvpInterpreterPtr(self._cdata)
@ -246,7 +246,7 @@ class FunctionalizeInterpreter(FuncTorchInterpreter):
self._cdata = cdata
@cached_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _cptr(self):
return CFunctionalizeInterpreterPtr(self._cdata)

View File

@ -57,7 +57,7 @@ def _interleave(a, b, dim=0):
stacked = torch.stack([a, b], dim=dim + 1)
interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if b_trunc:
# TODO: find torch alternative for slice_along dim for torch.jit.script to work
interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1)
@ -97,7 +97,7 @@ class AssociativeScanOp(HigherOrderOperator):
validate_subgraph_args_types(additional_inputs)
return super().__call__(combine_fn, xs, additional_inputs)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def gen_schema(self, combine_fn, xs, additional_inputs):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import materialize_as_graph
@ -650,7 +650,7 @@ class AssociativeScanAutogradOp(torch.autograd.Function):
"""
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
ctx,
combine_fn,

View File

@ -610,7 +610,7 @@ def do_auto_functionalize_v2(
normalized_kwargs = {}
schema = op._schema
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
op = op._op if isinstance(op, HopInstance) else op
assert isinstance(op, get_args(_MutableOpType))

View File

@ -170,7 +170,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
out = self(functionalized_subgraph, *unwrapped_operands, **kwargs)
return ctx.wrap_tensors(out)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def gen_schema(self, subgraph, *operands, **kwargs):
from .schema import HopSchemaGenerator
@ -216,7 +216,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
class BaseHOPFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, hop, subgraph, kwargs, *operands):
ctx.hop = hop
ctx.operands = operands

View File

@ -52,7 +52,7 @@ class CondOp(HigherOrderOperator):
validate_subgraph_args_types(operands)
return super().__call__(pred, true_fn, false_fn, operands)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def gen_schema(self, pred, true_fn, false_fn, operands):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import materialize_as_graph
@ -286,7 +286,7 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
class CondAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
ctx,
pred,

View File

@ -298,5 +298,5 @@ def handle_effects(
assert isinstance(wrapped_token, torch.Tensor)
tokens[key] = wrapped_token
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ctx.wrap_tensors(unwrapped_outs)

View File

@ -354,7 +354,7 @@ def trace_flex_attention(
score_mod_other_buffers,
mask_mod_other_buffers,
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function", flex_attention, proxy_args, {}
@ -363,7 +363,7 @@ def trace_flex_attention(
example_out,
out_proxy,
constant=None,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
tracer=proxy_mode.tracer,
)
@ -626,7 +626,7 @@ def create_fw_bw_graph(
class FlexAttentionAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
ctx: Any,
query: Tensor,
@ -1075,7 +1075,7 @@ def trace_flex_attention_backward(
score_mod_other_buffers,
mask_mod_other_buffers,
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
out_proxy = proxy_mode.tracer.create_proxy(
"call_function",
@ -1088,7 +1088,7 @@ def trace_flex_attention_backward(
example_out,
out_proxy,
constant=None,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
tracer=proxy_mode.tracer,
)

View File

@ -86,7 +86,7 @@ class InvokeSubgraphHOP(HigherOrderOperator):
return super().__call__(subgraph, identifier, *operands)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def gen_schema(self, subgraph, identifier, *operands):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import (
@ -402,7 +402,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
"""
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
ctx,
subgraph,
@ -479,7 +479,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
for tangent in filtered_grad_outs:
metadata = extract_tensor_metadata(tangent)
metadata._flatten_into(tangent_metadata, fake_mode, state)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
tangent_metadata = tuple(tangent_metadata)
# bw_graph is a joint graph with signature (*primals_and_tangents) and

View File

@ -387,7 +387,7 @@ def create_hop_fw_bw(
class LocalMapAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
ctx: Any,
fw_gm: GraphModule,
@ -440,7 +440,7 @@ class LocalMapAutogradOp(torch.autograd.Function):
)
for i, meta in ctx.expected_tangent_metadata.items():
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
grads[i] = coerce_to_expected_memory_format(grads[i], meta)
grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads)

View File

@ -125,7 +125,7 @@ def map(
class MapAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, f, num_mapped_args, *flat_args):
ctx._f = f
ctx._num_mapped_args = num_mapped_args

View File

@ -241,7 +241,7 @@ class ScanOp(HigherOrderOperator):
validate_subgraph_args_types(additional_inputs)
return super().__call__(combine_fn, init, xs, additional_inputs)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def gen_schema(self, combine_fn, init, xs, additional_inputs):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import materialize_as_graph
@ -449,7 +449,7 @@ class ScanAutogradOp(torch.autograd.Function):
"""
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
ctx,
hop_partitioned_graph,

View File

@ -292,7 +292,7 @@ def generate_ttir(
ordered_args[name] = 2
elif (
stable_meta := maybe_unpack_tma_stable_metadata(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
tma_descriptor_metadata.get(name, None)
)
) is not None:
@ -426,7 +426,7 @@ def generate_ttir(
specialize_value=not kp.do_not_specialize,
align=not kp.do_not_specialize_on_alignment,
)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
attrvals.append(spec[1])
attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
@ -445,7 +445,7 @@ def generate_ttir(
def get_signature_value(idx: int, arg: Any) -> str:
if kernel.params[idx].is_constexpr:
return "constexpr"
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
return mangle_type(arg)
else:
@ -819,7 +819,7 @@ def get_tma_stores(
for op in op_list:
if op.name == "tt.call":
assert op.fn_call_name in functions
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
tma_stores = get_tma_stores(functions, op.fn_call_name)
for i, inp in enumerate(op.args):
if Param(idx=i) in tma_stores:
@ -901,7 +901,7 @@ def analyze_kernel_mutations(
assert op.fn_call_name in functions
mutations = analyze_kernel_mutations(
functions,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
op.fn_call_name,
len(op.args),
)
@ -956,7 +956,7 @@ def identify_mutated_tensors(
assert functions is not None
kernel_name = next(iter(functions.keys()))
# Triton codegen modifies the name
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
assert kernel.fn.__name__ in kernel_name
# Reset the cache between top level invocations
# The cache for analyze kernel mutations is mainly used for cycle
@ -1060,9 +1060,9 @@ def triton_kernel_wrapper_mutation_dense(
grid_fn = grid[0]
else:
fn_name, code = user_defined_kernel_grid_fn_code(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
kernel.fn.__name__,
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
kernel.configs,
grid,
)
@ -1113,7 +1113,7 @@ def triton_kernel_wrapper_mutation_dense(
# avoid mutating the original inputs
kwargs = kwargs.copy()
constant_args = constant_args.copy()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for name in kernel.arg_names:
if name in kwargs:
args.append(kwargs.pop(name))
@ -1122,7 +1122,7 @@ def triton_kernel_wrapper_mutation_dense(
else:
break
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
kernel[grid_fn](*args, **kwargs, **constant_args)
@ -1528,7 +1528,7 @@ class TritonHOPifier:
assert kernel_idx is None or variable.kernel_idx == kernel_idx
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
variable.grid = grid
if isinstance(kernel, Autotuner):
@ -2076,7 +2076,7 @@ class TraceableTritonKernelWrapper:
return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
else:
assert self.kernel is not None
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.kernel.run(*args, **kwargs)
def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
@ -2088,7 +2088,7 @@ class TraceableTritonKernelWrapper:
)
else:
assert self.kernel is not None
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return self.kernel[self.grid](*args, **kwargs)
def specialize_symbolic(self, arg: Sequence[Any]) -> Any:

View File

@ -270,7 +270,7 @@ def _set_compilation_env():
# We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo
# once we are confident fx tracing works with dynamo.
torch.fx._symbolic_trace._is_fx_tracing_flag = False
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
torch._dynamo.config.allow_empty_graphs = True
torch._dynamo.config.capture_scalar_outputs = True
yield
@ -441,7 +441,7 @@ def unique_graph_name_with_root(
) -> tuple[int, str]:
next_name = None
i = 0
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
while not next_name:
candidate = f"{prefix}_{i}"
if hasattr(root, candidate):
@ -798,7 +798,7 @@ def create_bw_fn(
from torch._functorch.aot_autograd import AOTConfig, create_joint
# pyrefly: ignore # missing-module-attribute
# pyrefly: ignore [missing-module-attribute]
from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad
dummy_aot_config = AOTConfig(
@ -943,7 +943,7 @@ def check_input_alias_and_mutation(
out_out_alias_map,
mutated_inputs,
) = check_input_alias_and_mutation_return_outputs(gm)[:-1]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs

View File

@ -54,7 +54,7 @@ class WhileLoopOp(HigherOrderOperator):
validate_subgraph_args_types(additional_inputs)
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def gen_schema(self, cond_fn, body_fn, carried_inputs, additional_inputs):
from torch._higher_order_ops.schema import HopSchemaGenerator
from torch._higher_order_ops.utils import materialize_as_graph
@ -431,7 +431,7 @@ def while_loop_tracing(
elif isinstance(x, torch.Tensor):
x = x.clone()
if hasattr(x, "constant") and x.constant is not None:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
x.constant = None
return x
@ -454,7 +454,7 @@ def while_loop_tracing(
next_name = None
i = 0
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
while not next_name:
candidate = f"while_loop_cond_graph_{i}"
if hasattr(proxy_mode.tracer.root, candidate):
@ -699,7 +699,7 @@ class WhileLoopStackOutputOp(HigherOrderOperator):
class WhileLoopAutogradOp(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
ctx,
cond_fn,
@ -729,7 +729,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
ctx.additional_inputs = additional_inputs
ctx.fw_outputs = fw_outputs
loop_count = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for out in fw_outputs:
if isinstance(out, torch.Tensor):
if loop_count is not None:
@ -883,7 +883,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
while_loop_op(
cond_gm,
body_gm,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
(
init_idx,
*init_grad_carries,

View File

@ -844,7 +844,7 @@ def ignore(drop=False, **kwargs):
# @torch.jit.ignore
# def fn(...):
fn = drop
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
fn._torchscript_modifier = FunctionModifiers.IGNORE
return fn
@ -1255,7 +1255,7 @@ def _get_named_tuple_properties(
obj_annotations = inspect.get_annotations(obj)
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
obj_annotations = inspect.get_annotations(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
obj.__base__
)
@ -1447,7 +1447,7 @@ def container_checker(obj, target_type) -> bool:
return False
return True
elif origin_type is Union or issubclass(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
origin_type,
BuiltinUnionType,
): # also handles Optional

View File

@ -176,7 +176,7 @@ def infer_schema(
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
seen_args.add(name)
if param.default is inspect.Parameter.empty:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
params.append(f"{schema_type} {name}")
else:
default_repr = None
@ -194,7 +194,7 @@ def infer_schema(
f"Parameter {name} has an unsupported default value type {type(param.default)}. "
f"Please file an issue on GitHub so we can prioritize this."
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
params.append(f"{schema_type} {name}={default_repr}")
if mutates_args != UNKNOWN_MUTATES:
mutates_args_not_seen = set(mutates_args) - seen_args
@ -221,7 +221,7 @@ def derived_types(
):
result: list[tuple[Union[type, typing._SpecialForm, GenericAlias], str]] = [
(base_type, cpp_type),
# pyrefly: ignore # not-a-type
# pyrefly: ignore [not-a-type]
(typing.Optional[base_type], f"{cpp_type}?"),
]
@ -240,7 +240,7 @@ def derived_types(
if optional_base_list:
result.extend(
(seq_typ, f"{cpp_type}?[]")
# pyrefly: ignore # not-a-type
# pyrefly: ignore [not-a-type]
for seq_typ in derived_seq_types(typing.Optional[base_type])
)
if optional_list_base:
@ -252,7 +252,7 @@ def derived_types(
def get_supported_param_types():
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
data: list[tuple[Union[type, typing._SpecialForm], str, bool, bool, bool]] = [
# (python type, schema type, type[] variant, type?[] variant, type[]? variant
(Tensor, "Tensor", True, True, False),
@ -296,7 +296,7 @@ def parse_return(annotation, error_fn):
f"Return has unsupported type {annotation}. "
f"The valid types are: {SUPPORTED_RETURN_TYPES}."
)
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return SUPPORTED_RETURN_TYPES[annotation]
args = typing.get_args(annotation)

View File

@ -1050,7 +1050,7 @@ class LOBPCG:
return torch.matmul(
U * d_col.mT,
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
Z * E**-0.5,
)

View File

@ -1246,7 +1246,7 @@ class OpOverloadPacket(Generic[_P, _T]):
# the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
# intercept it here and call TorchBindOpverload instead.
if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return _call_overload_packet_from_python(self, *args, **kwargs)
return self._op(*args, **kwargs)

View File

@ -427,7 +427,7 @@ def _prim_elementwise_meta(
# Acquires the device (if it exists) or number
device = None
number = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for arg in args_:
if isinstance(arg, TensorLike):
if utils.is_cpu_scalar_tensor(arg):
@ -1012,10 +1012,10 @@ def _div_aten(a, b):
)
if is_integral:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch.div(a, b, rounding_mode="trunc")
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch.true_divide(a, b)

View File

@ -125,7 +125,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
# Unless we are in prims_mode, in which case we want to use nvprims
if orig_func in torch_function_passthrough or orig_func in all_prims():
with self.prims_mode_cls():
# pyrefly: ignore # invalid-param-spec
# pyrefly: ignore [invalid-param-spec]
return orig_func(*args, **kwargs)
mapping = torch_to_refs_map()
func = mapping.get(orig_func, None)
@ -148,7 +148,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
if func is not None:
# If the ref exists query whether we should use it or not
if self.should_fallback_fn(self, orig_func, func, args, kwargs):
# pyrefly: ignore # invalid-param-spec
# pyrefly: ignore [invalid-param-spec]
return orig_func(*args, **kwargs)
# torch calls inside func should be interpreted as refs calls
with self:
@ -157,5 +157,5 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
raise RuntimeError(
f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
)
# pyrefly: ignore # invalid-param-spec
# pyrefly: ignore [invalid-param-spec]
return orig_func(*args, **kwargs)

View File

@ -29,7 +29,7 @@ def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
rngprim_def = torch.library.custom_op(
"rngprims::" + name, impl_aten, mutates_args=(), schema=schema
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
rngprim_def.register_fake(impl_meta)
prim_packet = getattr(torch._ops.ops.rngprims, name)
@ -330,11 +330,11 @@ def register_graphsafe_run_with_rng_state_op():
@graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA)
def impl_cuda(op, *args, rng_state=None, **kwargs):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
device_idx = rng_state.device.index
generator = torch.cuda.default_generators[device_idx]
current_state = generator.graphsafe_get_state()
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
generator.graphsafe_set_state(rng_state)
out = op(*args, **kwargs)
generator.graphsafe_set_state(current_state)

View File

@ -396,7 +396,7 @@ def is_contiguous_for_memory_format( # type: ignore[return]
*,
memory_format: torch.memory_format,
false_if_dde=False,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> bool:
validate_memory_format(memory_format)
@ -820,13 +820,13 @@ def canonicalize_dims(
rank: int,
indices: Sequence[int],
wrap_scalar: bool = True,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> tuple[int, ...]:
pass
@overload
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
pass
@ -873,7 +873,7 @@ def check_same_device(*args, allow_cpu_scalar_tensors):
# Note: cannot initialize device to the first arg's device (it may not have one)
device = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for arg in args:
if isinstance(arg, Number):
continue
@ -921,7 +921,7 @@ def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
"""
shape = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for arg in args:
if isinstance(arg, Number):
continue
@ -948,7 +948,7 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
shape = None
scalar_shape = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for arg in args:
if isinstance(arg, Number):
continue
@ -1005,7 +1005,7 @@ def extract_shape_from_varargs(
# Handles tuple unwrapping
if len(shape) == 1 and isinstance(shape[0], Sequence):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
shape = shape[0]
if validate:
@ -1301,7 +1301,7 @@ def get_higher_dtype(
raise RuntimeError("Unexpected type given to _extract_dtype!")
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
a, b = _extract_dtype(a), _extract_dtype(b)
if a is b:
@ -1397,7 +1397,7 @@ def check_same_dtype(*args):
full_dtype = None
scalar_type = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for arg in args:
if isinstance(arg, Number):
# Scalar type checking is disabled (and may be removed in the future)
@ -1668,10 +1668,10 @@ def elementwise_dtypes(
# Prefers dtype of tensors with one or more dimensions
if one_plus_dim_tensor_dtype is not None:
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return one_plus_dim_tensor_dtype
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return zero_dim_tensor_dtype
if highest_type is float:

View File

@ -28,19 +28,19 @@ _P = ParamSpec("_P")
@overload
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
pass
@overload
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
pass
@overload
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
pass
@ -280,7 +280,7 @@ def out_wrapper(
if is_tensor
else NamedTuple(
f"return_types_{fn.__name__}",
# pyrefly: ignore # bad-argument-count
# pyrefly: ignore [bad-argument-count]
[(o, TensorLikeType) for o in out_names],
)
)
@ -299,7 +299,7 @@ def out_wrapper(
kwargs[k] = out_attr
def maybe_check_copy_devices(out):
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if isinstance(out, TensorLike) and isinstance(args[0], TensorLike):
check_copy_devices(copy_from=args[0], copy_to=out)
@ -435,7 +435,7 @@ def backwards_not_supported(prim):
class BackwardsNotSupported(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, args_spec, *flat_args):
args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type]
return redispatch_prim(args, kwargs)
@ -484,14 +484,14 @@ def elementwise_unary_scalar_wrapper(
dtype = utils.type_to_dtype(type(args[0]))
args_ = list(args)
args_[0] = torch.tensor(args[0], dtype=dtype)
# pyrefly: ignore # invalid-param-spec
# pyrefly: ignore [invalid-param-spec]
result = fn(*args_, **kwargs)
assert isinstance(result, torch.Tensor)
return result.item()
# pyrefly: ignore # invalid-param-spec
# pyrefly: ignore [invalid-param-spec]
return fn(*args, **kwargs)
_fn.__signature__ = sig # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return _fn