Fix pyrefly ignore syntax (#166438)
Reformats pyrefly ignore suppressions so that each one ignores only a single error code.

Checked with pyrefly check and lintrunner.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166438
Approved by: https://github.com/Skylion007
This commit is contained in:
parent a9b29caeae
commit 31e42eb732
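The change at every site is the same one-line rewrite: the old suppression comment "# pyrefly: ignore # <error-code>" becomes the bracketed form "# pyrefly: ignore [<error-code>]", so each suppression names exactly the code (or comma-separated codes) it is meant to silence. The sketch below is illustrative only: the functions are made-up stand-ins, not code from the PyTorch tree, and the suppressions are placed purely to show the comment syntax. In the hunks that follow, the old form appears as the removed line and the new form as the added line at each location.

# Hypothetical example -- only the suppression-comment syntax mirrors this commit.
def scale_old(values, factor):
    # Old style: the error code follows a second "#", as free-form text.
    # pyrefly: ignore # bad-argument-type
    return [v * factor for v in values]


def scale_new(values, factor):
    # New style: the error code being suppressed is listed in brackets.
    # pyrefly: ignore [bad-argument-type]
    return [v * factor for v in values]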
@@ -1810,7 +1810,7 @@ def _check_not_implemented(cond, message=None): # noqa: F811
 _check_with(
 NotImplementedError,
 cond,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 message,
 )

@@ -2669,7 +2669,7 @@ def compile(
 dynamic=dynamic,
 disable=disable,
 guard_filter_fn=guard_filter_fn,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 )(model)(*args, **kwargs)

 return export_wrapped_fn

@@ -382,7 +382,7 @@ def to_real_dtype(dtype: torch.dtype):
 def mse_loss(
 self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
 ) -> Tensor:
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 loss = (self - target) ** 2
 return apply_loss_reduction(loss, reduction)

@@ -416,7 +416,7 @@ def smooth_l1_loss(
 beta: float = 1.0,
 ):
 loss = (self - target).abs()
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
 return apply_loss_reduction(loss, reduction)

@@ -4079,7 +4079,7 @@ def _nll_loss_forward(
 return result, total_weight

 if weight is not None:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 w = w.expand(self.shape)
 wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
 wsum = torch.where(target != ignore_index, wsum, 0)

@@ -4896,9 +4896,9 @@ def _reflection_pad_backward(grad_output, x, padding):
 @register_decomposition(aten.aminmax)
 @out_wrapper("min", "max")
 def aminmax(self, *, dim=None, keepdim=False):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 amin = torch.amin(self, dim=dim, keepdim=keepdim)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 amax = torch.amax(self, dim=dim, keepdim=keepdim)
 return amin, amax

@@ -5143,7 +5143,7 @@ def baddbmm(self, batch1, batch2, beta=1, alpha=1):
 alpha = int(alpha)
 result = torch.bmm(batch1, batch2)
 if not isinstance(alpha, numbers.Number) or alpha != 1:
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 result = result * alpha
 if beta == 0:
 return result

@@ -132,7 +132,7 @@ def _make_export_case(m, name, configs):
 m.__doc__ is not None
 ), f"Could not find description or docstring for export case: {m}"
 configs = {**configs, "description": m.__doc__}
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 return ExportCase(**{**configs, "model": m, "name": name})

@@ -3,12 +3,12 @@ import torch

 class MyAutogradFunction(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, x):
 return x.clone()

 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, grad_output):
 return grad_output + 1

@@ -39,7 +39,7 @@ def get_class_if_classified_error(e: Exception) -> Optional[str]:
 TorchRuntimeError: None,
 }
 if type(e) in _ALLOW_LIST:
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 attr_name = _ALLOW_LIST[type(e)]
 if attr_name is None:
 return ALWAYS_CLASSIFIED

@@ -101,7 +101,7 @@ class _KeyPathTrie:
 assert len(kp) > 0
 k, *kp = kp # type: ignore[assignment]
 node = node[k]
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return node, kp

@@ -356,12 +356,12 @@ def _override_builtin_ops():
 original_min = builtins.min
 original_pow = math.pow

-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 builtins.max = functools.partial(
 _tensor_min_max, real_callable=original_max, tensor_callable=torch.maximum
 )

-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 builtins.min = functools.partial(
 _tensor_min_max, real_callable=original_min, tensor_callable=torch.minimum
 )

@@ -1087,7 +1087,7 @@ class _NonStrictTorchFunctionHandler(torch.overrides.TorchFunctionMode):

 def run():
 # Run sequence.
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 t = args[0]
 for _method, _args in sequence:
 t = _method(t, *_args)

@@ -188,7 +188,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
 self.callback = callback
 self.node: torch.fx.Node = next(iter(gm.graph.nodes))

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def placeholder(
 self,
 target: str, # type: ignore[override]

@@ -440,7 +440,7 @@ class _ExportPassBaseDeprecatedDoNotUse(PassBase):
 )
 self.tracer.fake_tensor_mode = prev_tracer.fake_tensor_mode
 interpreter = self.ExportInterpreter(self, graph_module)
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 prev_interpreter, self.interpreter = (
 self.interpreter,
 torch.fx.Interpreter( # type: ignore[assignment]

@@ -32,7 +32,7 @@ def _node_metadata_hook(
 that nodes being added are only call_function nodes, and copies over the
 first argument node's nn_module_stack.
 """
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 fake_mode = fake_mode or contextlib.nullcontext()

 assert node.op == "call_function" and callable(node.target), (

@@ -48,7 +48,7 @@ def _node_metadata_hook(
 fake_args, fake_kwargs = pytree.tree_map_only(
 torch.fx.Node, lambda arg: arg.meta["val"], (node.args, node.kwargs)
 )
-# pyrefly: ignore # bad-context-manager
+# pyrefly: ignore [bad-context-manager]
 with fake_mode, enable_python_dispatcher():
 fake_res = node.target(*fake_args, **fake_kwargs)
 node.meta["val"] = fake_res

@@ -84,7 +84,7 @@ def _node_metadata_hook(
 "torch_fn",
 (
 f"{node.target.__name__}_0",
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 f"{node.target.__class__.__name__}.{node.target.__name__}",
 ),
 )

@@ -567,7 +567,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
 quantized = False

 last_quantized_node = None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 for node in gm.graph.nodes:
 if isinstance(node.target, OpOverload):
 with gm.graph.inserting_before(node):

@@ -630,7 +630,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
 attr_names_to_clean.add(k)
 if k == "_buffers":
 buffer_name_to_clean = set()
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 for b_name, b_value in v.items():
 if isinstance(b_value, torch.Tensor) and b_value.dtype in [
 torch.qint8,

@@ -638,7 +638,7 @@ def replace_quantized_ops_with_standard_ops(gm: torch.fx.GraphModule):
 ]:
 buffer_name_to_clean.add(b_name)
 for b_name in buffer_name_to_clean:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 v.pop(b_name, None)
 for attr_name in attr_names_to_clean:
 delattr(submod, attr_name)

@@ -35,7 +35,7 @@ def _replace_with_hop_helper(
 )
 call_func_node.meta["torch_fn"] = (
 f"{wrap_hoo.__name__}",
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 f"{wrap_hoo.__class__.__name__}.{wrap_hoo.__name__}",
 )
 if isinstance(output_args, (tuple, list)):

@@ -54,7 +54,7 @@ def _postprocess_serialized_shapes(
 )
 for k, v in sorted(dims.items())
 }
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 spec = DynamicShapesSpec(dynamic_shapes=dynamic_shapes, dims=dims)
 if to_dict:
 return _dataclass_to_dict(spec)

@@ -184,7 +184,7 @@ def _dump_dynamic_shapes(
 kwargs = kwargs or {}
 if isinstance(dynamic_shapes, dict):
 dynamic_shapes = dynamic_shapes.values() # type: ignore[assignment]
-# pyrefly: ignore # bad-assignment, bad-argument-type
+# pyrefly: ignore [bad-assignment, bad-argument-type]
 dynamic_shapes = tuple(dynamic_shapes)
 combined_args = tuple(args) + tuple(kwargs.values())

@@ -623,9 +623,9 @@ class _Commit:
 def update_schema():
 import importlib.resources

-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 if importlib.resources.is_resource(__package__, "schema.yaml"):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 content = importlib.resources.read_text(__package__, "schema.yaml")
 match = re.search("checksum<<([A-Fa-f0-9]{64})>>", content)
 _check(match is not None, "checksum not found in schema.yaml")

@@ -633,7 +633,7 @@ def update_schema():
 checksum_head = match.group(1)

 thrift_content = importlib.resources.read_text(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 __package__,
 "export_schema.thrift",
 )

@@ -658,9 +658,9 @@ def update_schema():

 src, cpp_header, thrift_schema = _staged_schema()
 additions, subtractions = _diff_schema(dst, src)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 yaml_path = __package__.replace(".", "/") + "/schema.yaml"
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 thrift_schema_path = __package__.replace(".", "/") + "/export_schema.thrift"
 torch_prefix = "torch/"
 assert yaml_path.startswith(torch_prefix) # sanity check

@@ -383,7 +383,7 @@ def _reconstruct_fake_tensor(
 fake_tensor = _CURRENT_DESERIALIZER.deserialize_tensor_meta(tensor_meta)
 if is_parameter:
 fake_tensor = torch.nn.Parameter(fake_tensor) # type: ignore[assignment]
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return fake_tensor

@@ -2741,7 +2741,7 @@ class GraphModuleDeserializer(metaclass=Final):
 serialized_node.metadata
 )
 assert arg is not None
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.generate_getitem(meta_val, fx_node, arg, 0, deserialized_metadata)
 fx_node.meta["val"] = tuple(meta_val)
 self.serialized_name_to_node[fx_node.name] = fx_node

@@ -3167,7 +3167,7 @@ def _dict_to_dataclass(cls, data):
 _value = next(iter(data.values()))
 assert isinstance(_type, str)
 field_type = cls.__annotations__[_type]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 return cls.create(**{_type: _dict_to_dataclass(field_type, _value)})
 elif dataclasses.is_dataclass(cls):
 fields = {}

@@ -3474,23 +3474,23 @@ def _canonicalize_graph(
 n.metadata.clear()

 # Stage 4: Aggregate values.
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 sorted_tensor_values = dict(
 sorted(graph.tensor_values.items(), key=operator.itemgetter(0))
 )
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 sorted_sym_int_values = dict(
 sorted(graph.sym_int_values.items(), key=operator.itemgetter(0))
 )
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 sorted_sym_float_values = dict(
 sorted(graph.sym_float_values.items(), key=operator.itemgetter(0))
 )
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 sorted_sym_bool_values = dict(
 sorted(graph.sym_bool_values.items(), key=operator.itemgetter(0))
 )
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 sorted_custom_obj_values = dict(
 sorted(graph.custom_obj_values.items(), key=operator.itemgetter(0))
 )

@@ -3547,7 +3547,7 @@ def canonicalize(
 ExportedProgram: The canonicalized exported program.
 """
 ep = copy.deepcopy(ep)
-# pyrefly: ignore # annotation-mismatch
+# pyrefly: ignore [annotation-mismatch]
 constants: set[str] = constants or set()

 opset_version = dict(sorted(ep.opset_version.items(), key=operator.itemgetter(0)))

@@ -1120,14 +1120,14 @@ def placeholder_naming_pass(
 if ( # handle targets for custom objects
 spec.kind == InputKind.CUSTOM_OBJ and spec.target in name_map
 ):
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 spec.target = name_map[spec.target][4:] # strip obj_ prefix

 for spec in export_graph_signature.output_specs:
 if spec.arg.name in name_map:
 spec.arg.name = name_map[spec.arg.name]
 if spec.kind == OutputKind.USER_INPUT_MUTATION and spec.target in name_map:
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 spec.target = name_map[spec.target]

 # rename keys in constants dict for custom objects

@ -384,7 +384,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
|
|||
class AOTAutogradCachePickler(FxGraphCachePickler):
|
||||
def __init__(self, gm: torch.fx.GraphModule):
|
||||
super().__init__(gm)
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
self.dispatch_table: dict
|
||||
self.dispatch_table.update(
|
||||
{
|
||||
|
|
|
|||
|
|
@ -86,10 +86,10 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
|
|||
|
||||
memory_format = MemoryFormatMeta.from_tensor(out)
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
if memory_format.memory_format is not None:
|
||||
was = out
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
out = out.contiguous(memory_format=memory_format.memory_format)
|
||||
updated = was is not out
|
||||
|
||||
|
|
@ -119,7 +119,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
|
|||
out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined]
|
||||
|
||||
if is_subclass:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
attrs = out.__tensor_flatten__()[0]
|
||||
|
||||
for attr in attrs:
|
||||
|
|
@ -129,7 +129,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
|
|||
new_elem_memory_format,
|
||||
elem_updated,
|
||||
) = coerce_tangent_and_suggest_memory_format(elem)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
out_memory_format.append(new_elem_memory_format)
|
||||
if elem_updated:
|
||||
setattr(out, attr, new_elem)
|
||||
|
|
@ -493,7 +493,7 @@ def run_functionalized_fw_and_collect_metadata(
|
|||
curr_storage in inp_storage_refs
|
||||
and not functional_tensor_storage_changed
|
||||
):
|
||||
# pyrefly: ignore # index-error
|
||||
# pyrefly: ignore [index-error]
|
||||
base_idx = inp_storage_refs[curr_storage]
|
||||
is_input_tensor = id(o) in inp_tensor_ids
|
||||
num_aliased_outs = out_tensor_alias_counts[curr_storage]
|
||||
|
|
@ -701,7 +701,7 @@ from a multi-output view call"
|
|||
# Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
|
||||
# are *regenerated* later, and not used directly in the autograd graph
|
||||
def _plain_fake_tensor_like_subclass(x):
|
||||
# pyrefly: ignore # bad-context-manager
|
||||
# pyrefly: ignore [bad-context-manager]
|
||||
with detect_fake_mode():
|
||||
return torch.empty(
|
||||
x.shape, dtype=x.dtype, device=x.device, layout=x.layout
|
||||
|
|
|
|||
|
|
@ -78,7 +78,7 @@ def get_all_input_and_grad_nodes(
|
|||
continue
|
||||
if isinstance(desc, SubclassGetAttrAOTInput):
|
||||
_raise_autograd_subclass_not_implemented(n, desc)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
input_index[desc] = (n, None)
|
||||
elif n.op == "output":
|
||||
assert "desc" in n.meta, (n, n.meta)
|
||||
|
|
@ -130,7 +130,7 @@ def get_all_output_and_tangent_nodes(
|
|||
continue
|
||||
if isinstance(sub_d, SubclassGetAttrAOTOutput):
|
||||
_raise_autograd_subclass_not_implemented(sub_n, sub_d)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
output_index[sub_d] = (sub_n, None)
|
||||
for n in g.nodes:
|
||||
if n.op == "placeholder":
|
||||
|
|
|
|||
|
|
@ -1310,12 +1310,12 @@ def aot_dispatch_subclass(
|
|||
# See Note: [Partitioner handling for Subclasses, Part 2] for more info.
|
||||
meta_updated = run_functionalized_fw_and_collect_metadata(
|
||||
without_output_descs(metadata_fn),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
flat_args_descs=primals_unwrapped_descs,
|
||||
static_input_indices=remapped_static_indices,
|
||||
keep_input_mutations=meta.keep_input_mutations,
|
||||
is_train=meta.is_train,
|
||||
# pyrefly: ignore # not-iterable
|
||||
# pyrefly: ignore [not-iterable]
|
||||
)(*primals_unwrapped)
|
||||
|
||||
subclass_meta.fw_metadata = meta_updated
|
||||
|
|
|
|||
|
|
@ -538,7 +538,7 @@ def collect_bw_donated_buffer_idxs(
|
|||
fw_ins,
|
||||
user_fw_outs,
|
||||
bw_outs,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
saved_tensors,
|
||||
)
|
||||
|
||||
|
|
@ -1523,7 +1523,7 @@ def _aot_stage2a_partition(
|
|||
|
||||
# apply joint_gm callback here
|
||||
if callable(torch._functorch.config.joint_custom_pass):
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
fx_g = torch._functorch.config.joint_custom_pass(fx_g, joint_inputs)
|
||||
|
||||
static_lifetime_input_indices = fw_metadata.static_input_indices
|
||||
|
|
@ -1794,7 +1794,7 @@ def _aot_stage2b_bw_compile(
|
|||
|
||||
ph_size = ph_arg.size()
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
placeholder_list[i] = ph_arg.as_strided(ph_size, real_stride)
|
||||
compiled_bw_func = None
|
||||
if (
|
||||
|
|
|
|||
|
|
@ -225,7 +225,7 @@ def make_output_handler(info, runtime_metadata, trace_joint):
|
|||
# not sure why AOTDispatcher needs to manually set this
|
||||
def maybe_mark_dynamic_helper(t: torch.Tensor, dims: set[int]):
|
||||
if hasattr(t, "_dynamo_weak_dynamic_indices"):
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
t._dynamo_weak_dynamic_indices |= dims
|
||||
else:
|
||||
t._dynamo_weak_dynamic_indices = dims.copy() # type: ignore[attr-defined]
|
||||
|
|
@ -1143,7 +1143,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):
|
|||
|
||||
def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]:
|
||||
f_args_inner = []
|
||||
# pyrefly: ignore # not-iterable
|
||||
# pyrefly: ignore [not-iterable]
|
||||
for inner_idx_or_tuple in synthetic_base_info:
|
||||
if isinstance(inner_idx_or_tuple, int):
|
||||
f_args_inner.append(primals[inner_idx_or_tuple])
|
||||
|
|
@ -2114,7 +2114,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
|||
return (ctx._autograd_function_id, *ctx.symints)
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx, *deduped_flat_tensor_args):
|
||||
args = deduped_flat_tensor_args
|
||||
if backward_state_indices:
|
||||
|
|
@ -2151,7 +2151,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
|||
# in the fw output order.
|
||||
fw_outs = call_func_at_runtime_with_args(
|
||||
CompiledFunction.compiled_fw,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
args,
|
||||
disable_amp=disable_amp,
|
||||
)
|
||||
|
|
@ -2347,7 +2347,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
|
|||
_aot_id = aot_config.aot_id
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(double_ctx, *unused_args):
|
||||
return impl_fn(double_ctx)
|
||||
|
||||
|
|
|
|||
|
|
@ -196,7 +196,7 @@ class MemoryFormatMeta:
|
|||
|
||||
if use_memory_format:
|
||||
return MemoryFormatMeta(
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
memory_format=torch._prims_common.suggest_memory_format(t),
|
||||
)
|
||||
|
||||
|
|
@ -893,15 +893,15 @@ class GraphSignature:
|
|||
parameters_to_mutate = {}
|
||||
for output_name, mutation_name in outputs_to_mutations.items():
|
||||
if mutation_name in user_inputs:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
user_inputs_to_mutate[output_name] = mutation_name
|
||||
else:
|
||||
assert mutation_name in buffers or mutation_name in parameters
|
||||
if mutation_name in buffers:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
buffers_to_mutate[output_name] = mutation_name
|
||||
else:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
parameters_to_mutate[output_name] = mutation_name
|
||||
|
||||
start, stop = stop, stop + num_user_outputs
|
||||
|
|
@ -1236,9 +1236,9 @@ class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
|
|||
output_code_ty: type[TOutputCode],
|
||||
compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
|
||||
):
|
||||
# pyrefly: ignore # invalid-type-var
|
||||
# pyrefly: ignore [invalid-type-var]
|
||||
self.output_code_ty = output_code_ty
|
||||
# pyrefly: ignore # invalid-type-var
|
||||
# pyrefly: ignore [invalid-type-var]
|
||||
self.compiler_fn = compiler_fn
|
||||
|
||||
def __call__(
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
|
|||
"""
|
||||
for name, tensor in itertools.chain(
|
||||
list(module.named_parameters(recurse=False)),
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
# pyrefly: ignore [no-matching-overload]
|
||||
list(module.named_buffers(recurse=False)),
|
||||
):
|
||||
if is_traceable_wrapper_subclass(tensor):
|
||||
|
|
|
|||
|
|
@ -237,7 +237,7 @@ def unwrap_tensor_subclasses(
|
|||
n_desc: Any = (
|
||||
SubclassGetAttrAOTInput(desc, attr)
|
||||
if isinstance(desc, AOTInput)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
else SubclassGetAttrAOTOutput(desc, attr)
|
||||
)
|
||||
flatten_subclass(inner_tensor, n_desc, out=out)
|
||||
|
|
@ -258,7 +258,7 @@ def unwrap_tensor_subclasses(
|
|||
descs_inner: list[AOTDescriptor] = []
|
||||
|
||||
for x, desc in zip(wrapped_args, wrapped_args_descs):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
flatten_subclass(typing.cast(Tensor, x), desc, out=(xs_inner, descs_inner))
|
||||
|
||||
return xs_inner, descs_inner
|
||||
|
|
@ -283,7 +283,7 @@ def runtime_unwrap_tensor_subclasses(
|
|||
|
||||
for attr in attrs:
|
||||
inner_tensor = getattr(x, attr)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
inner_meta = meta.attrs.get(attr)
|
||||
flatten_subclass(inner_tensor, inner_meta, out=out)
|
||||
|
||||
|
|
|
|||
|
|
@ -331,7 +331,7 @@ def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
|
|||
and out.args[1] == 0
|
||||
and out.args[0] in with_effect_nodes
|
||||
):
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
output_token_nodes.append(out)
|
||||
else:
|
||||
other_output_nodes.append(out)
|
||||
|
|
@ -573,10 +573,10 @@ def without_output_descs(f: Callable[_P, tuple[_T, _S]]) -> Callable[_P, _T]:
|
|||
@wraps(f)
|
||||
@simple_wraps(f)
|
||||
def inner(*args, **kwargs):
|
||||
# pyrefly: ignore # invalid-param-spec
|
||||
# pyrefly: ignore [invalid-param-spec]
|
||||
return f(*args, **kwargs)[0]
|
||||
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return inner
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -753,7 +753,7 @@ class AutogradFunctionApply(HigherOrderOperator):
|
|||
|
||||
class ApplyTemplate(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx, *args):
|
||||
nonlocal saved_values
|
||||
output, saved_values = fwd(None, *fwd_args)
|
||||
|
|
|
|||
|
|
@ -42,9 +42,9 @@ def create_names_map(
|
|||
This function creates a mapping from the names in named_params to the
|
||||
names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
|
||||
"""
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
# pyrefly: ignore [no-matching-overload]
|
||||
named_params = dict(named_params)
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
# pyrefly: ignore [no-matching-overload]
|
||||
tied_named_params = dict(tied_named_params)
|
||||
|
||||
tensors_dict_keys = set(named_params.keys())
|
||||
|
|
@ -53,11 +53,11 @@ def create_names_map(
|
|||
|
||||
tensor_to_mapping: dict[Tensor, tuple[str, list[str]]] = {}
|
||||
for key, tensor in named_params.items():
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
tensor_to_mapping[tensor] = (key, [])
|
||||
for key, tensor in tied_named_params.items():
|
||||
assert tensor in tensor_to_mapping
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
tensor_to_mapping[tensor][1].append(key)
|
||||
return dict(tensor_to_mapping.values())
|
||||
|
||||
|
|
|
|||
|
|
@ -199,7 +199,7 @@ def _extract_graph_with_inputs_outputs(
|
|||
new_node = new_graph.placeholder(node.name)
|
||||
# Can't use node_copy here as we may be turning previous call_function into placeholders
|
||||
new_node.meta = node.meta
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
env[node] = new_node
|
||||
|
||||
for node in joint_graph.nodes:
|
||||
|
|
@ -228,10 +228,10 @@ def _extract_graph_with_inputs_outputs(
|
|||
if any(all_args):
|
||||
env[node] = InvalidNode # type: ignore[assignment]
|
||||
continue
|
||||
# pyrefly: ignore # unsupported-operation, bad-argument-type
|
||||
# pyrefly: ignore [unsupported-operation, bad-argument-type]
|
||||
env[node] = new_graph.node_copy(node, lambda x: env[x])
|
||||
elif node.op == "get_attr":
|
||||
# pyrefly: ignore # unsupported-operation, bad-argument-type
|
||||
# pyrefly: ignore [unsupported-operation, bad-argument-type]
|
||||
env[node] = new_graph.node_copy(node, lambda x: env[x])
|
||||
elif node.op == "output":
|
||||
pass
|
||||
|
|
@ -628,7 +628,7 @@ def quantize_activation_fw(graph: torch.fx.Graph) -> None:
|
|||
position_to_quant.get(i, node) for i, node in enumerate(fwd_outputs)
|
||||
]
|
||||
# add the scale nodes to the output find the first sym_node in the output
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
idx = find_first_sym_node(output_updated_args)
|
||||
scale_nodes = tensor_scale_nodes + sym_scale_nodes
|
||||
if scale_nodes:
|
||||
|
|
@ -1234,7 +1234,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
|
|||
# critical path first.
|
||||
cur_nodes += node.all_input_nodes
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n])
|
||||
for node in insertable_nodes:
|
||||
env[node] = new_graph.node_copy(node, lambda x: env[x])
|
||||
|
|
@ -1463,14 +1463,14 @@ def functionalize_rng_ops(
|
|||
devices = OrderedSet(
|
||||
get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values()
|
||||
)
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
devices.discard(torch.device("cpu"))
|
||||
# multiple cuda devices won't work with cudagraphs anyway,
|
||||
# fallback to non graphsafe rng checkpointing
|
||||
multi_cuda_devices = len(devices) > 1
|
||||
|
||||
# this changes numerics, so if fallback_random is set we will not use it
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
ind_config = torch._inductor.config
|
||||
use_rng_graphsafe_rng_functionalization = (
|
||||
config.graphsafe_rng_functionalization
|
||||
|
|
@ -2902,7 +2902,7 @@ def min_cut_rematerialization_partition(
|
|||
node_info,
|
||||
memory_budget=memory_budget,
|
||||
)
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
if config._sync_decision_cross_ranks:
|
||||
saved_values = _sync_decision_cross_ranks(joint_graph, saved_values)
|
||||
# save_for_backward on tensors and stashes symints in autograd .ctx
|
||||
|
|
@ -2913,7 +2913,7 @@ def min_cut_rematerialization_partition(
|
|||
fw_module, bw_module = _extract_fwd_bwd_modules(
|
||||
joint_module,
|
||||
saved_values,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
saved_sym_nodes=saved_sym_nodes,
|
||||
num_fwd_outputs=num_fwd_outputs,
|
||||
static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,
|
||||
|
|
|
|||
|
|
@ -131,7 +131,7 @@ class VmapInterpreter(FuncTorchInterpreter):
|
|||
self._cdata = cdata
|
||||
|
||||
@cached_property
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def _cptr(self):
|
||||
return CVmapInterpreterPtr(self._cdata)
|
||||
|
||||
|
|
@ -171,7 +171,7 @@ class GradInterpreter(FuncTorchInterpreter):
|
|||
self._cdata = cdata
|
||||
|
||||
@cached_property
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def _cptr(self):
|
||||
return CGradInterpreterPtr(self._cdata)
|
||||
|
||||
|
|
@ -209,7 +209,7 @@ class JvpInterpreter(FuncTorchInterpreter):
|
|||
self._cdata = cdata
|
||||
|
||||
@cached_property
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def _cptr(self):
|
||||
return CJvpInterpreterPtr(self._cdata)
|
||||
|
||||
|
|
@ -246,7 +246,7 @@ class FunctionalizeInterpreter(FuncTorchInterpreter):
|
|||
self._cdata = cdata
|
||||
|
||||
@cached_property
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def _cptr(self):
|
||||
return CFunctionalizeInterpreterPtr(self._cdata)
|
||||
|
||||
|
|
|
|||
|
|
@ -57,7 +57,7 @@ def _interleave(a, b, dim=0):
|
|||
|
||||
stacked = torch.stack([a, b], dim=dim + 1)
|
||||
interleaved = torch.flatten(stacked, start_dim=dim, end_dim=dim + 1)
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
if b_trunc:
|
||||
# TODO: find torch alternative for slice_along dim for torch.jit.script to work
|
||||
interleaved = aten.slice(interleaved, dim, 0, b.shape[dim] + a.shape[dim] - 1)
|
||||
|
|
@ -97,7 +97,7 @@ class AssociativeScanOp(HigherOrderOperator):
|
|||
validate_subgraph_args_types(additional_inputs)
|
||||
return super().__call__(combine_fn, xs, additional_inputs)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def gen_schema(self, combine_fn, xs, additional_inputs):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import materialize_as_graph
|
||||
|
|
@ -650,7 +650,7 @@ class AssociativeScanAutogradOp(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(
|
||||
ctx,
|
||||
combine_fn,
|
||||
|
|
|
|||
|
|
@ -610,7 +610,7 @@ def do_auto_functionalize_v2(
|
|||
normalized_kwargs = {}
|
||||
|
||||
schema = op._schema
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
op = op._op if isinstance(op, HopInstance) else op
|
||||
assert isinstance(op, get_args(_MutableOpType))
|
||||
|
||||
|
|
|
|||
|
|
@ -170,7 +170,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
|
|||
out = self(functionalized_subgraph, *unwrapped_operands, **kwargs)
|
||||
return ctx.wrap_tensors(out)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def gen_schema(self, subgraph, *operands, **kwargs):
|
||||
from .schema import HopSchemaGenerator
|
||||
|
||||
|
|
@ -216,7 +216,7 @@ class BaseHOP(HigherOrderOperator, abc.ABC):
|
|||
|
||||
class BaseHOPFunction(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx, hop, subgraph, kwargs, *operands):
|
||||
ctx.hop = hop
|
||||
ctx.operands = operands
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ class CondOp(HigherOrderOperator):
|
|||
validate_subgraph_args_types(operands)
|
||||
return super().__call__(pred, true_fn, false_fn, operands)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def gen_schema(self, pred, true_fn, false_fn, operands):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import materialize_as_graph
|
||||
|
|
@ -286,7 +286,7 @@ def cond_op_dense(pred, true_fn, false_fn, operands):
|
|||
|
||||
class CondAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(
|
||||
ctx,
|
||||
pred,
|
||||
|
|
|
|||
|
|
@ -298,5 +298,5 @@ def handle_effects(
|
|||
assert isinstance(wrapped_token, torch.Tensor)
|
||||
tokens[key] = wrapped_token
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
return ctx.wrap_tensors(unwrapped_outs)
|
||||
|
|
|
|||
|
|
@ -354,7 +354,7 @@ def trace_flex_attention(
|
|||
score_mod_other_buffers,
|
||||
mask_mod_other_buffers,
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
|
||||
out_proxy = proxy_mode.tracer.create_proxy(
|
||||
"call_function", flex_attention, proxy_args, {}
|
||||
|
|
@ -363,7 +363,7 @@ def trace_flex_attention(
|
|||
example_out,
|
||||
out_proxy,
|
||||
constant=None,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
tracer=proxy_mode.tracer,
|
||||
)
|
||||
|
||||
|
|
@ -626,7 +626,7 @@ def create_fw_bw_graph(
|
|||
|
||||
class FlexAttentionAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(
|
||||
ctx: Any,
|
||||
query: Tensor,
|
||||
|
|
@ -1075,7 +1075,7 @@ def trace_flex_attention_backward(
|
|||
score_mod_other_buffers,
|
||||
mask_mod_other_buffers,
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
|
||||
out_proxy = proxy_mode.tracer.create_proxy(
|
||||
"call_function",
|
||||
|
|
@ -1088,7 +1088,7 @@ def trace_flex_attention_backward(
|
|||
example_out,
|
||||
out_proxy,
|
||||
constant=None,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
tracer=proxy_mode.tracer,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -86,7 +86,7 @@ class InvokeSubgraphHOP(HigherOrderOperator):
|
|||
|
||||
return super().__call__(subgraph, identifier, *operands)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def gen_schema(self, subgraph, identifier, *operands):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import (
|
||||
|
|
@ -402,7 +402,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(
|
||||
ctx,
|
||||
subgraph,
|
||||
|
|
@ -479,7 +479,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
|
|||
for tangent in filtered_grad_outs:
|
||||
metadata = extract_tensor_metadata(tangent)
|
||||
metadata._flatten_into(tangent_metadata, fake_mode, state)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
tangent_metadata = tuple(tangent_metadata)
|
||||
|
||||
# bw_graph is a joint graph with signature (*primals_and_tangents) and
|
||||
|
|
|
|||
|
|
@ -387,7 +387,7 @@ def create_hop_fw_bw(
|
|||
|
||||
class LocalMapAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(
|
||||
ctx: Any,
|
||||
fw_gm: GraphModule,
|
||||
|
|
@ -440,7 +440,7 @@ class LocalMapAutogradOp(torch.autograd.Function):
|
|||
)
|
||||
|
||||
for i, meta in ctx.expected_tangent_metadata.items():
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
grads[i] = coerce_to_expected_memory_format(grads[i], meta)
|
||||
|
||||
grad_ins = local_map_hop(ctx.bw_gm, *saved_activations, *grads)
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ def map(
|
|||
|
||||
class MapAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx, f, num_mapped_args, *flat_args):
|
||||
ctx._f = f
|
||||
ctx._num_mapped_args = num_mapped_args
|
||||
|
|
|
|||
|
|
@ -241,7 +241,7 @@ class ScanOp(HigherOrderOperator):
|
|||
validate_subgraph_args_types(additional_inputs)
|
||||
return super().__call__(combine_fn, init, xs, additional_inputs)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def gen_schema(self, combine_fn, init, xs, additional_inputs):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import materialize_as_graph
|
||||
|
|
@ -449,7 +449,7 @@ class ScanAutogradOp(torch.autograd.Function):
|
|||
"""
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(
|
||||
ctx,
|
||||
hop_partitioned_graph,
|
||||
|
|
|
|||
|
|
@ -292,7 +292,7 @@ def generate_ttir(
|
|||
ordered_args[name] = 2
|
||||
elif (
|
||||
stable_meta := maybe_unpack_tma_stable_metadata(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
tma_descriptor_metadata.get(name, None)
|
||||
)
|
||||
) is not None:
|
||||
|
|
@ -426,7 +426,7 @@ def generate_ttir(
|
|||
specialize_value=not kp.do_not_specialize,
|
||||
align=not kp.do_not_specialize_on_alignment,
|
||||
)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
attrvals.append(spec[1])
|
||||
|
||||
attrs = find_paths_if(attrvals, lambda _, x: isinstance(x, str))
|
||||
|
|
@ -445,7 +445,7 @@ def generate_ttir(
|
|||
def get_signature_value(idx: int, arg: Any) -> str:
|
||||
if kernel.params[idx].is_constexpr:
|
||||
return "constexpr"
|
||||
# pyrefly: ignore # not-callable
|
||||
# pyrefly: ignore [not-callable]
|
||||
return mangle_type(arg)
|
||||
|
||||
else:
|
||||
|
|
@ -819,7 +819,7 @@ def get_tma_stores(
|
|||
for op in op_list:
|
||||
if op.name == "tt.call":
|
||||
assert op.fn_call_name in functions
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
tma_stores = get_tma_stores(functions, op.fn_call_name)
|
||||
for i, inp in enumerate(op.args):
|
||||
if Param(idx=i) in tma_stores:
|
||||
|
|
@ -901,7 +901,7 @@ def analyze_kernel_mutations(
|
|||
assert op.fn_call_name in functions
|
||||
mutations = analyze_kernel_mutations(
|
||||
functions,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
op.fn_call_name,
|
||||
len(op.args),
|
||||
)
|
||||
|
|
@ -956,7 +956,7 @@ def identify_mutated_tensors(
|
|||
assert functions is not None
|
||||
kernel_name = next(iter(functions.keys()))
|
||||
# Triton codegen modifies the name
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
assert kernel.fn.__name__ in kernel_name
|
||||
# Reset the cache between top level invocations
|
||||
# The cache for analyze kernel mutations is mainly used for cycle
|
||||
|
|
@ -1060,9 +1060,9 @@ def triton_kernel_wrapper_mutation_dense(
|
|||
grid_fn = grid[0]
|
||||
else:
|
||||
fn_name, code = user_defined_kernel_grid_fn_code(
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
kernel.fn.__name__,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
kernel.configs,
|
||||
grid,
|
||||
)
|
||||
|
|
@ -1113,7 +1113,7 @@ def triton_kernel_wrapper_mutation_dense(
|
|||
# avoid mutating the original inputs
|
||||
kwargs = kwargs.copy()
|
||||
constant_args = constant_args.copy()
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
for name in kernel.arg_names:
|
||||
if name in kwargs:
|
||||
args.append(kwargs.pop(name))
|
||||
|
|
@ -1122,7 +1122,7 @@ def triton_kernel_wrapper_mutation_dense(
|
|||
else:
|
||||
break
|
||||
|
||||
# pyrefly: ignore # index-error
|
||||
# pyrefly: ignore [index-error]
|
||||
kernel[grid_fn](*args, **kwargs, **constant_args)
|
||||
|
||||
|
||||
|
|
@ -1528,7 +1528,7 @@ class TritonHOPifier:
|
|||
|
||||
assert kernel_idx is None or variable.kernel_idx == kernel_idx
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
variable.grid = grid
|
||||
|
||||
if isinstance(kernel, Autotuner):
|
||||
|
|
@ -2076,7 +2076,7 @@ class TraceableTritonKernelWrapper:
|
|||
return tracing_triton_hopifier_singleton.call_run(self, args, kwargs, None)
|
||||
else:
|
||||
assert self.kernel is not None
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return self.kernel.run(*args, **kwargs)
|
||||
|
||||
def __call__(self, *args: Sequence[Any], **kwargs: dict[str, Any]) -> Any:
|
||||
|
|
@ -2088,7 +2088,7 @@ class TraceableTritonKernelWrapper:
|
|||
)
|
||||
else:
|
||||
assert self.kernel is not None
|
||||
# pyrefly: ignore # index-error
|
||||
# pyrefly: ignore [index-error]
|
||||
return self.kernel[self.grid](*args, **kwargs)
|
||||
|
||||
def specialize_symbolic(self, arg: Sequence[Any]) -> Any:
|
||||
|
|
|
|||
|
|
@ -270,7 +270,7 @@ def _set_compilation_env():
|
|||
# We need to turn off the is_fx_tracing_flag. Remove this flag check from dyanmo
|
||||
# once we are confident fx tracing works with dynamo.
|
||||
torch.fx._symbolic_trace._is_fx_tracing_flag = False
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
torch._dynamo.config.allow_empty_graphs = True
|
||||
torch._dynamo.config.capture_scalar_outputs = True
|
||||
yield
|
||||
|
|
@ -441,7 +441,7 @@ def unique_graph_name_with_root(
|
|||
) -> tuple[int, str]:
|
||||
next_name = None
|
||||
i = 0
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
while not next_name:
|
||||
candidate = f"{prefix}_{i}"
|
||||
if hasattr(root, candidate):
|
||||
|
|
@ -798,7 +798,7 @@ def create_bw_fn(
|
|||
|
||||
from torch._functorch.aot_autograd import AOTConfig, create_joint
|
||||
|
||||
# pyrefly: ignore # missing-module-attribute
|
||||
# pyrefly: ignore [missing-module-attribute]
|
||||
from torch._higher_order_ops.utils import prepare_fw_with_masks_all_requires_grad
|
||||
|
||||
dummy_aot_config = AOTConfig(
|
||||
|
|
@ -943,7 +943,7 @@ def check_input_alias_and_mutation(
|
|||
out_out_alias_map,
|
||||
mutated_inputs,
|
||||
) = check_input_alias_and_mutation_return_outputs(gm)[:-1]
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return inp_inp_alias_map, inp_out_alias_map, out_out_alias_map, mutated_inputs
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class WhileLoopOp(HigherOrderOperator):
|
|||
validate_subgraph_args_types(additional_inputs)
|
||||
return super().__call__(cond_fn, body_fn, carried_inputs, additional_inputs)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def gen_schema(self, cond_fn, body_fn, carried_inputs, additional_inputs):
|
||||
from torch._higher_order_ops.schema import HopSchemaGenerator
|
||||
from torch._higher_order_ops.utils import materialize_as_graph
|
||||
|
|
@ -431,7 +431,7 @@ def while_loop_tracing(
|
|||
elif isinstance(x, torch.Tensor):
|
||||
x = x.clone()
|
||||
if hasattr(x, "constant") and x.constant is not None:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
x.constant = None
|
||||
return x
|
||||
|
||||
|
|
@ -454,7 +454,7 @@ def while_loop_tracing(
|
|||
|
||||
next_name = None
|
||||
i = 0
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
while not next_name:
|
||||
candidate = f"while_loop_cond_graph_{i}"
|
||||
if hasattr(proxy_mode.tracer.root, candidate):
|
||||
|
|
@ -699,7 +699,7 @@ class WhileLoopStackOutputOp(HigherOrderOperator):
|
|||
|
||||
class WhileLoopAutogradOp(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(
|
||||
ctx,
|
||||
cond_fn,
|
||||
|
|
@ -729,7 +729,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
|
|||
ctx.additional_inputs = additional_inputs
|
||||
ctx.fw_outputs = fw_outputs
|
||||
loop_count = None
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
for out in fw_outputs:
|
||||
if isinstance(out, torch.Tensor):
|
||||
if loop_count is not None:
|
||||
|
|
@ -883,7 +883,7 @@ class WhileLoopAutogradOp(torch.autograd.Function):
|
|||
while_loop_op(
|
||||
cond_gm,
|
||||
body_gm,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
(
|
||||
init_idx,
|
||||
*init_grad_carries,
|
||||
|
|
|
|||
|
|
@ -844,7 +844,7 @@ def ignore(drop=False, **kwargs):
|
|||
# @torch.jit.ignore
|
||||
# def fn(...):
|
||||
fn = drop
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
fn._torchscript_modifier = FunctionModifiers.IGNORE
|
||||
return fn
|
||||
|
||||
|
|
@ -1255,7 +1255,7 @@ def _get_named_tuple_properties(
|
|||
obj_annotations = inspect.get_annotations(obj)
|
||||
if len(obj_annotations) == 0 and hasattr(obj, "__base__"):
|
||||
obj_annotations = inspect.get_annotations(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
obj.__base__
|
||||
)
|
||||
|
||||
|
|
@ -1447,7 +1447,7 @@ def container_checker(obj, target_type) -> bool:
|
|||
return False
|
||||
return True
|
||||
elif origin_type is Union or issubclass(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
origin_type,
|
||||
BuiltinUnionType,
|
||||
): # also handles Optional
|
||||
|
|
|
|||
|
|
@ -176,7 +176,7 @@ def infer_schema(
|
|||
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
|
||||
seen_args.add(name)
|
||||
if param.default is inspect.Parameter.empty:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
params.append(f"{schema_type} {name}")
|
||||
else:
|
||||
default_repr = None
|
||||
|
|
@ -194,7 +194,7 @@ def infer_schema(
|
|||
f"Parameter {name} has an unsupported default value type {type(param.default)}. "
|
||||
f"Please file an issue on GitHub so we can prioritize this."
|
||||
)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
params.append(f"{schema_type} {name}={default_repr}")
|
||||
if mutates_args != UNKNOWN_MUTATES:
|
||||
mutates_args_not_seen = set(mutates_args) - seen_args
|
||||
|
|
@ -221,7 +221,7 @@ def derived_types(
|
|||
):
|
||||
result: list[tuple[Union[type, typing._SpecialForm, GenericAlias], str]] = [
|
||||
(base_type, cpp_type),
|
||||
# pyrefly: ignore # not-a-type
|
||||
# pyrefly: ignore [not-a-type]
|
||||
(typing.Optional[base_type], f"{cpp_type}?"),
|
||||
]
|
||||
|
||||
|
|
@ -240,7 +240,7 @@ def derived_types(
|
|||
if optional_base_list:
|
||||
result.extend(
|
||||
(seq_typ, f"{cpp_type}?[]")
|
||||
# pyrefly: ignore # not-a-type
|
||||
# pyrefly: ignore [not-a-type]
|
||||
for seq_typ in derived_seq_types(typing.Optional[base_type])
|
||||
)
|
||||
if optional_list_base:
|
||||
|
|
@ -252,7 +252,7 @@ def derived_types(
|
|||
|
||||
|
||||
def get_supported_param_types():
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
data: list[tuple[Union[type, typing._SpecialForm], str, bool, bool, bool]] = [
|
||||
# (python type, schema type, type[] variant, type?[] variant, type[]? variant
|
||||
(Tensor, "Tensor", True, True, False),
|
||||
|
|
@ -296,7 +296,7 @@ def parse_return(annotation, error_fn):
|
|||
f"Return has unsupported type {annotation}. "
|
||||
f"The valid types are: {SUPPORTED_RETURN_TYPES}."
|
||||
)
|
||||
# pyrefly: ignore # index-error
|
||||
# pyrefly: ignore [index-error]
|
||||
return SUPPORTED_RETURN_TYPES[annotation]
|
||||
|
||||
args = typing.get_args(annotation)
|
||||
|
|
|
|||
|
|
@ -1050,7 +1050,7 @@ class LOBPCG:
|
|||
|
||||
return torch.matmul(
|
||||
U * d_col.mT,
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
Z * E**-0.5,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1246,7 +1246,7 @@ class OpOverloadPacket(Generic[_P, _T]):
|
|||
# the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
|
||||
# intercept it here and call TorchBindOpverload instead.
|
||||
if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
return _call_overload_packet_from_python(self, *args, **kwargs)
|
||||
return self._op(*args, **kwargs)
|
||||
|
||||
|
|
|
|||
|
|
@ -427,7 +427,7 @@ def _prim_elementwise_meta(
|
|||
# Acquires the device (if it exists) or number
|
||||
device = None
|
||||
number = None
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
for arg in args_:
|
||||
if isinstance(arg, TensorLike):
|
||||
if utils.is_cpu_scalar_tensor(arg):
|
||||
|
|
@ -1012,10 +1012,10 @@ def _div_aten(a, b):
|
|||
)
|
||||
|
||||
if is_integral:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
return torch.div(a, b, rounding_mode="trunc")
|
||||
else:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
return torch.true_divide(a, b)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -125,7 +125,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
|
|||
# Unless we are in prims_mode, in which case we want to use nvprims
|
||||
if orig_func in torch_function_passthrough or orig_func in all_prims():
|
||||
with self.prims_mode_cls():
|
||||
# pyrefly: ignore # invalid-param-spec
|
||||
# pyrefly: ignore [invalid-param-spec]
|
||||
return orig_func(*args, **kwargs)
|
||||
mapping = torch_to_refs_map()
|
||||
func = mapping.get(orig_func, None)
|
||||
|
|
@ -148,7 +148,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
|
|||
if func is not None:
|
||||
# If the ref exists query whether we should use it or not
|
||||
if self.should_fallback_fn(self, orig_func, func, args, kwargs):
|
||||
# pyrefly: ignore # invalid-param-spec
|
||||
# pyrefly: ignore [invalid-param-spec]
|
||||
return orig_func(*args, **kwargs)
|
||||
# torch calls inside func should be interpreted as refs calls
|
||||
with self:
|
||||
|
|
@ -157,5 +157,5 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
|
|||
raise RuntimeError(
|
||||
f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
|
||||
)
|
||||
# pyrefly: ignore # invalid-param-spec
|
||||
# pyrefly: ignore [invalid-param-spec]
|
||||
return orig_func(*args, **kwargs)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
|
|||
rngprim_def = torch.library.custom_op(
|
||||
"rngprims::" + name, impl_aten, mutates_args=(), schema=schema
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
rngprim_def.register_fake(impl_meta)
|
||||
|
||||
prim_packet = getattr(torch._ops.ops.rngprims, name)
|
||||
|
|
@ -330,11 +330,11 @@ def register_graphsafe_run_with_rng_state_op():
|
|||
|
||||
@graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA)
|
||||
def impl_cuda(op, *args, rng_state=None, **kwargs):
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
device_idx = rng_state.device.index
|
||||
generator = torch.cuda.default_generators[device_idx]
|
||||
current_state = generator.graphsafe_get_state()
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
generator.graphsafe_set_state(rng_state)
|
||||
out = op(*args, **kwargs)
|
||||
generator.graphsafe_set_state(current_state)
|
||||
|
|
|
|||
|
|
@ -396,7 +396,7 @@ def is_contiguous_for_memory_format( # type: ignore[return]
|
|||
*,
|
||||
memory_format: torch.memory_format,
|
||||
false_if_dde=False,
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
) -> bool:
|
||||
validate_memory_format(memory_format)
|
||||
|
||||
|
|
@ -820,13 +820,13 @@ def canonicalize_dims(
|
|||
rank: int,
|
||||
indices: Sequence[int],
|
||||
wrap_scalar: bool = True,
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
) -> tuple[int, ...]:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
|
||||
pass
|
||||
|
||||
|
|
@ -873,7 +873,7 @@ def check_same_device(*args, allow_cpu_scalar_tensors):
|
|||
|
||||
# Note: cannot initialize device to the first arg's device (it may not have one)
|
||||
device = None
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
for arg in args:
|
||||
if isinstance(arg, Number):
|
||||
continue
|
||||
|
|
@ -921,7 +921,7 @@ def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
|
|||
"""
|
||||
shape = None
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
for arg in args:
|
||||
if isinstance(arg, Number):
|
||||
continue
|
||||
|
|
@ -948,7 +948,7 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
|
|||
shape = None
|
||||
scalar_shape = None
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
for arg in args:
|
||||
if isinstance(arg, Number):
|
||||
continue
|
||||
|
|
@ -1005,7 +1005,7 @@ def extract_shape_from_varargs(
|
|||
|
||||
# Handles tuple unwrapping
|
||||
if len(shape) == 1 and isinstance(shape[0], Sequence):
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
shape = shape[0]
|
||||
|
||||
if validate:
|
||||
|
|
@ -1301,7 +1301,7 @@ def get_higher_dtype(
|
|||
|
||||
raise RuntimeError("Unexpected type given to _extract_dtype!")
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
a, b = _extract_dtype(a), _extract_dtype(b)
|
||||
|
||||
if a is b:
|
||||
|
|
@ -1397,7 +1397,7 @@ def check_same_dtype(*args):
|
|||
full_dtype = None
|
||||
scalar_type = None
|
||||
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
for arg in args:
|
||||
if isinstance(arg, Number):
|
||||
# Scalar type checking is disabled (and may be removed in the future)
|
||||
|
|
@ -1668,10 +1668,10 @@ def elementwise_dtypes(
|
|||
|
||||
# Prefers dtype of tensors with one or more dimensions
|
||||
if one_plus_dim_tensor_dtype is not None:
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return one_plus_dim_tensor_dtype
|
||||
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return zero_dim_tensor_dtype
|
||||
|
||||
if highest_type is float:
|
||||
|
|
|
|||
|
|
@ -28,19 +28,19 @@ _P = ParamSpec("_P")
|
|||
|
||||
|
||||
@overload
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
|
||||
pass
|
||||
|
||||
|
||||
@overload
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
|
||||
pass
|
||||
|
||||
|
|
@ -280,7 +280,7 @@ def out_wrapper(
|
|||
if is_tensor
|
||||
else NamedTuple(
|
||||
f"return_types_{fn.__name__}",
|
||||
# pyrefly: ignore # bad-argument-count
|
||||
# pyrefly: ignore [bad-argument-count]
|
||||
[(o, TensorLikeType) for o in out_names],
|
||||
)
|
||||
)
|
||||
|
|
@ -299,7 +299,7 @@ def out_wrapper(
|
|||
kwargs[k] = out_attr
|
||||
|
||||
def maybe_check_copy_devices(out):
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
if isinstance(out, TensorLike) and isinstance(args[0], TensorLike):
|
||||
check_copy_devices(copy_from=args[0], copy_to=out)
|
||||
|
||||
|
|
@ -435,7 +435,7 @@ def backwards_not_supported(prim):
|
|||
|
||||
class BackwardsNotSupported(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx, args_spec, *flat_args):
|
||||
args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type]
|
||||
return redispatch_prim(args, kwargs)
|
||||
|
|
@ -484,14 +484,14 @@ def elementwise_unary_scalar_wrapper(
|
|||
dtype = utils.type_to_dtype(type(args[0]))
|
||||
args_ = list(args)
|
||||
args_[0] = torch.tensor(args[0], dtype=dtype)
|
||||
# pyrefly: ignore # invalid-param-spec
|
||||
# pyrefly: ignore [invalid-param-spec]
|
||||
result = fn(*args_, **kwargs)
|
||||
assert isinstance(result, torch.Tensor)
|
||||
return result.item()
|
||||
|
||||
# pyrefly: ignore # invalid-param-spec
|
||||
# pyrefly: ignore [invalid-param-spec]
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
_fn.__signature__ = sig # type: ignore[attr-defined]
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return _fn
|
||||
|
|
|
|||