Add pyrefly suppressions 2/n (#164513)

Adds suppressions to pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

---
step 1: uncomment lines in the `pyrefly.toml` file
before: https://gist.github.com/maggiemoss/911b4d0bc88bf8cf3ab91f67184e9d46

after:
```
 INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml`
 INFO 0 errors (1,152 ignored)
 ```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164513
Approved by: https://github.com/oulgen
This commit is contained in:
Maggie Moss 2025-10-03 02:46:09 +00:00 committed by PyTorch MergeBot
parent d1cbb74fb1
commit 1051c1de5c
26 changed files with 133 additions and 13 deletions

View File

@ -40,16 +40,6 @@ project-excludes = [
"torch/autograd/**",
"torch/cuda/**",
"torch/export/**",
"torch/profiler/**",
"torch/_prims_common/**",
"torch/backends/**",
# "torch/testing/**",
"torch/_C/**",
"torch/sparse/**",
"torch/_library/**",
"torch/_prims/**",
"torch/_decomp/**",
"torch/_meta_registrations.py",
# formatting issues
"torch/linalg/__init__.py",
"torch/package/importer.py",

View File

@ -112,17 +112,28 @@ class DebugLevel(Enum):
DETAIL = ...
class ReduceOp:
# pyrefly: ignore # unknown-name
def __init__(self, op: RedOpType) -> None: ...
# pyrefly: ignore # unknown-name
SUM: RedOpType = ...
# pyrefly: ignore # unknown-name
AVG: RedOpType = ...
# pyrefly: ignore # unknown-name
PRODUCT: RedOpType = ...
# pyrefly: ignore # unknown-name
MIN: RedOpType = ...
# pyrefly: ignore # unknown-name
MAX: RedOpType = ...
# pyrefly: ignore # unknown-name
BAND: RedOpType = ...
# pyrefly: ignore # unknown-name
BOR: RedOpType = ...
# pyrefly: ignore # unknown-name
BXOR: RedOpType = ...
# pyrefly: ignore # unknown-name
PREMUL_SUM: RedOpType = ...
# pyrefly: ignore # unknown-name
UNUSED: RedOpType = ...
# mypy error being ignored:

View File

@ -240,6 +240,7 @@ def get_decompositions(
registry = global_decomposition_table[type]
packets_to_overloads = defaultdict(list)
# pyrefly: ignore # bad-assignment
for opo in registry:
if isinstance(opo, (OpOverload, OpOverloadPacket)):
packets_to_overloads[opo.overloadpacket].append(opo)

View File

@ -382,6 +382,7 @@ def to_real_dtype(dtype: torch.dtype):
def mse_loss(
self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
# pyrefly: ignore # unsupported-operation
loss = (self - target) ** 2
return apply_loss_reduction(loss, reduction)
@ -415,6 +416,7 @@ def smooth_l1_loss(
beta: float = 1.0,
):
loss = (self - target).abs()
# pyrefly: ignore # unsupported-operation
loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
return apply_loss_reduction(loss, reduction)
@ -4893,7 +4895,9 @@ def _reflection_pad_backward(grad_output, x, padding):
@register_decomposition(aten.aminmax)
@out_wrapper("min", "max")
def aminmax(self, *, dim=None, keepdim=False):
# pyrefly: ignore # bad-argument-type
amin = torch.amin(self, dim=dim, keepdim=keepdim)
# pyrefly: ignore # bad-argument-type
amax = torch.amax(self, dim=dim, keepdim=keepdim)
return amin, amax
@ -5138,6 +5142,7 @@ def baddbmm(self, batch1, batch2, beta=1, alpha=1):
alpha = int(alpha)
result = torch.bmm(batch1, batch2)
if not isinstance(alpha, numbers.Number) or alpha != 1:
# pyrefly: ignore # unsupported-operation
result = result * alpha
if beta == 0:
return result

View File

@ -245,6 +245,7 @@ def save_op_profiles(op_profiles: dict[str, set[OpProfile]], f: FileLike) -> Non
yaml_str = generate_yaml_from_profiles(op_profiles)
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
with open(f, "w") as file:
@ -309,6 +310,7 @@ def load_op_profiles(f: FileLike) -> dict[str, set[OpProfile]]:
Loads the saved operator profiles from `save_op_profiles`.
"""
if isinstance(f, (str, os.PathLike)):
# pyrefly: ignore # no-matching-overload
f = os.fspath(f)
with open(f) as file:

View File

@ -159,6 +159,7 @@ def infer_schema(
schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
seen_args.add(name)
if param.default is inspect.Parameter.empty:
# pyrefly: ignore # bad-argument-type
params.append(f"{schema_type} {name}")
else:
default_repr = None
@ -176,6 +177,7 @@ def infer_schema(
f"Parameter {name} has an unsupported default value type {type(param.default)}. "
f"Please file an issue on GitHub so we can prioritize this."
)
# pyrefly: ignore # bad-argument-type
params.append(f"{schema_type} {name}={default_repr}")
if mutates_args != UNKNOWN_MUTATES:
mutates_args_not_seen = set(mutates_args) - seen_args
@ -202,6 +204,7 @@ def derived_types(
):
result: list[tuple[Union[type, typing._SpecialForm, GenericAlias], str]] = [
(base_type, cpp_type),
# pyrefly: ignore # not-a-type
(typing.Optional[base_type], f"{cpp_type}?"),
]
@ -220,6 +223,7 @@ def derived_types(
if optional_base_list:
result.extend(
(seq_typ, f"{cpp_type}?[]")
# pyrefly: ignore # not-a-type
for seq_typ in derived_seq_types(typing.Optional[base_type])
)
if optional_list_base:
@ -273,6 +277,7 @@ def parse_return(annotation, error_fn):
f"Return has unsupported type {annotation}. "
f"The valid types are: {SUPPORTED_RETURN_TYPES}."
)
# pyrefly: ignore # index-error
return SUPPORTED_RETURN_TYPES[annotation]
args = typing.get_args(annotation)

View File

@ -2341,16 +2341,19 @@ def calc_conv_nd_return_shape(
ret_shape = [input_tensor.shape[0], out_channels]
if isinstance(stride, IntLike):
# pyrefly: ignore # bad-assignment
stride = [stride] * len(dims)
elif len(stride) == 1:
stride = [stride[0]] * len(dims)
if isinstance(padding, IntLike):
# pyrefly: ignore # bad-assignment
padding = [padding] * len(dims)
elif len(padding) == 1:
padding = [padding[0]] * len(dims)
if isinstance(dilation, IntLike):
# pyrefly: ignore # bad-assignment
dilation = [dilation] * len(dims)
elif len(dilation) == 1:
dilation = [dilation[0]] * len(dims)
@ -2358,6 +2361,7 @@ def calc_conv_nd_return_shape(
output_padding_list: Optional[list[int]] = None
if output_padding:
if isinstance(output_padding, IntLike):
# pyrefly: ignore # bad-assignment
output_padding_list = [output_padding] * len(dims)
elif len(output_padding) == 1:
output_padding_list = [output_padding[0]] * len(dims)
@ -2370,15 +2374,19 @@ def calc_conv_nd_return_shape(
ret_shape.append(
_formula_transposed(
dims[i],
# pyrefly: ignore # index-error
padding[i],
# pyrefly: ignore # index-error
dilation[i],
kernel_size[i],
# pyrefly: ignore # index-error
stride[i],
output_padding_list[i],
)
)
else:
ret_shape.append(
# pyrefly: ignore # index-error
_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
)
from torch.fx.experimental.symbolic_shapes import sym_or
@ -3444,6 +3452,7 @@ def meta_index_Tensor(self, indices):
"""
shape = before_shape + replacement_shape + after_shape
strides = list(self.stride())
# pyrefly: ignore # unsupported-operation
strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
replacement_shape
)
@ -6655,6 +6664,7 @@ def rnn_cell_checkSizes(
)
torch._check(
all(
# pyrefly: ignore # missing-attribute
x.device == input_gates.device
for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
),

View File

@ -427,6 +427,7 @@ def _prim_elementwise_meta(
# Acquires the device (if it exists) or number
device = None
number = None
# pyrefly: ignore # bad-assignment
for arg in args_:
if isinstance(arg, TensorLike):
if utils.is_cpu_scalar_tensor(arg):
@ -1015,8 +1016,10 @@ def _div_aten(a, b):
)
if is_integral:
# pyrefly: ignore # bad-argument-type
return torch.div(a, b, rounding_mode="trunc")
else:
# pyrefly: ignore # bad-argument-type
return torch.true_divide(a, b)

View File

@ -125,6 +125,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
# Unless we are in prims_mode, in which case we want to use nvprims
if orig_func in torch_function_passthrough or orig_func in all_prims():
with self.prims_mode_cls():
# pyrefly: ignore # invalid-param-spec
return orig_func(*args, **kwargs)
mapping = torch_to_refs_map()
func = mapping.get(orig_func, None)
@ -147,6 +148,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
if func is not None:
# If the ref exists query whether we should use it or not
if self.should_fallback_fn(self, orig_func, func, args, kwargs):
# pyrefly: ignore # invalid-param-spec
return orig_func(*args, **kwargs)
# torch calls inside func should be interpreted as refs calls
with self:
@ -155,4 +157,5 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
raise RuntimeError(
f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
)
# pyrefly: ignore # invalid-param-spec
return orig_func(*args, **kwargs)

View File

@ -29,6 +29,7 @@ def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
rngprim_def = torch.library.custom_op(
"rngprims::" + name, impl_aten, mutates_args=(), schema=schema
)
# pyrefly: ignore # missing-attribute
rngprim_def.register_fake(impl_meta)
prim_packet = getattr(torch._ops.ops.rngprims, name)
@ -329,9 +330,11 @@ def register_graphsafe_run_with_rng_state_op():
@graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA)
def impl_cuda(op, *args, rng_state=None, **kwargs):
# pyrefly: ignore # missing-attribute
device_idx = rng_state.device.index
generator = torch.cuda.default_generators[device_idx]
current_state = generator.graphsafe_get_state()
# pyrefly: ignore # bad-argument-type
generator.graphsafe_set_state(rng_state)
out = op(*args, **kwargs)
generator.graphsafe_set_state(current_state)

View File

@ -113,6 +113,7 @@ def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
if len(a) != len(b):
return False
# pyrefly: ignore # bad-assignment
for x, y in zip(a, b):
if allow_rhs_unbacked:
if isinstance(y, torch.SymInt):
@ -389,7 +390,11 @@ def validate_memory_format(memory_format: torch.memory_format):
def is_contiguous_for_memory_format( # type: ignore[return]
a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False
a: Tensor,
*,
memory_format: torch.memory_format,
false_if_dde=False,
# pyrefly: ignore # bad-return
) -> bool:
validate_memory_format(memory_format)
@ -810,12 +815,16 @@ def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
# mapping negative offsets to positive ones
@overload
def canonicalize_dims(
rank: int, indices: Sequence[int], wrap_scalar: bool = True
rank: int,
indices: Sequence[int],
wrap_scalar: bool = True,
# pyrefly: ignore # bad-return
) -> tuple[int, ...]:
pass
@overload
# pyrefly: ignore # bad-return
def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
pass
@ -862,6 +871,7 @@ def check_same_device(*args, allow_cpu_scalar_tensors):
# Note: cannot initialize device to the first arg's device (it may not have one)
device = None
# pyrefly: ignore # bad-assignment
for arg in args:
if isinstance(arg, Number):
continue
@ -909,6 +919,7 @@ def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
"""
shape = None
# pyrefly: ignore # bad-assignment
for arg in args:
if isinstance(arg, Number):
continue
@ -935,6 +946,7 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
shape = None
scalar_shape = None
# pyrefly: ignore # bad-assignment
for arg in args:
if isinstance(arg, Number):
continue
@ -991,6 +1003,7 @@ def extract_shape_from_varargs(
# Handles tuple unwrapping
if len(shape) == 1 and isinstance(shape[0], Sequence):
# pyrefly: ignore # bad-assignment
shape = shape[0]
if validate:
@ -1292,6 +1305,7 @@ def get_higher_dtype(
raise RuntimeError("Unexpected type given to _extract_dtype!")
# pyrefly: ignore # bad-argument-type
a, b = _extract_dtype(a), _extract_dtype(b)
if a is b:
@ -1387,6 +1401,7 @@ def check_same_dtype(*args):
full_dtype = None
scalar_type = None
# pyrefly: ignore # bad-assignment
for arg in args:
if isinstance(arg, Number):
# Scalar type checking is disabled (and may be removed in the future)
@ -1657,8 +1672,10 @@ def elementwise_dtypes(
# Prefers dtype of tensors with one or more dimensions
if one_plus_dim_tensor_dtype is not None:
# pyrefly: ignore # bad-return
return one_plus_dim_tensor_dtype
# pyrefly: ignore # bad-return
return zero_dim_tensor_dtype
if highest_type is float:

View File

@ -28,16 +28,19 @@ _P = ParamSpec("_P")
@overload
# pyrefly: ignore # bad-return
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
pass
@overload
# pyrefly: ignore # bad-return
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
pass
@overload
# pyrefly: ignore # bad-return
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
pass
@ -276,7 +279,10 @@ def out_wrapper(
TensorLikeType
if is_tensor
else NamedTuple(
f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
# pyrefly: ignore # bad-argument-count
f"return_types_{fn.__name__}",
# pyrefly: ignore # bad-argument-count
[(o, TensorLikeType) for o in out_names],
)
)
@ -294,6 +300,7 @@ def out_wrapper(
kwargs[k] = out_attr
def maybe_check_copy_devices(out):
# pyrefly: ignore # unsupported-operation
if isinstance(out, TensorLike) and isinstance(args[0], TensorLike):
check_copy_devices(copy_from=args[0], copy_to=out)
@ -429,6 +436,7 @@ def backwards_not_supported(prim):
class BackwardsNotSupported(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
def forward(ctx, args_spec, *flat_args):
args, kwargs = tree_unflatten(flat_args, args_spec) # type: ignore[arg-type]
return redispatch_prim(args, kwargs)
@ -477,11 +485,14 @@ def elementwise_unary_scalar_wrapper(
dtype = utils.type_to_dtype(type(args[0]))
args_ = list(args)
args_[0] = torch.tensor(args[0], dtype=dtype)
# pyrefly: ignore # invalid-param-spec
result = fn(*args_, **kwargs)
assert isinstance(result, torch.Tensor)
return result.item()
# pyrefly: ignore # invalid-param-spec
return fn(*args, **kwargs)
_fn.__signature__ = sig # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
return _fn

View File

@ -414,6 +414,7 @@ class _NnapiSerializer:
) # noqa: TRY002
return Operand(
shape=tuple(tensor.shape),
# pyrefly: ignore # bad-argument-type
op_type=op_type,
dim_order=dim_order,
scale=scale,
@ -1734,11 +1735,13 @@ class _NnapiSerializer:
for dim in (2, 3): # h, w indices
if image_oper.shape[dim] == 0:
if size_ctype.kind() != "NoneType":
# pyrefly: ignore # unsupported-operation
self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
elif scale_ctype.kind() != "NoneType":
self.compute_operand_shape(
out_id,
dim,
# pyrefly: ignore # unsupported-operation
f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})",
)
else:

View File

@ -34,8 +34,11 @@ if _cudnn is not None:
def _init():
global __cudnn_version
if __cudnn_version is None:
# pyrefly: ignore # missing-attribute
__cudnn_version = _cudnn.getVersionInt()
# pyrefly: ignore # missing-attribute
runtime_version = _cudnn.getRuntimeVersion()
# pyrefly: ignore # missing-attribute
compile_version = _cudnn.getCompileVersion()
runtime_major, runtime_minor, _ = runtime_version
compile_major, compile_minor, _ = compile_version
@ -44,6 +47,7 @@ if _cudnn is not None:
# Not sure about MIOpen (ROCm), so always do a strict check
if runtime_major != compile_major:
cudnn_compatible = False
# pyrefly: ignore # missing-attribute
elif runtime_major < 7 or not _cudnn.is_cuda:
cudnn_compatible = runtime_minor == compile_minor
else:

View File

@ -12,12 +12,16 @@ except ImportError:
def get_cudnn_mode(mode):
if mode == "RNN_RELU":
# pyrefly: ignore # missing-attribute
return int(_cudnn.RNNMode.rnn_relu)
elif mode == "RNN_TANH":
# pyrefly: ignore # missing-attribute
return int(_cudnn.RNNMode.rnn_tanh)
elif mode == "LSTM":
# pyrefly: ignore # missing-attribute
return int(_cudnn.RNNMode.lstm)
elif mode == "GRU":
# pyrefly: ignore # missing-attribute
return int(_cudnn.RNNMode.gru)
else:
raise Exception(f"Unknown mode: {mode}") # noqa: TRY002
@ -56,6 +60,7 @@ def init_dropout_state(dropout, train, dropout_seed, dropout_state):
dropout_p,
train,
dropout_seed,
# pyrefly: ignore # unexpected-keyword
self_ty=torch.uint8,
device=torch.device("cuda"),
)

View File

@ -23,6 +23,7 @@ if _cusparselt is not None:
global __cusparselt_version
global __MAX_ALG_ID
if __cusparselt_version is None:
# pyrefly: ignore # missing-attribute
__cusparselt_version = _cusparselt.getVersionInt()
if __cusparselt_version == 400:
__MAX_ALG_ID = 4

View File

@ -70,6 +70,7 @@ def _set_strategy(_strategy: str) -> None:
def _get_strategy() -> str:
# pyrefly: ignore # bad-return
return strategy

View File

@ -835,6 +835,7 @@ def create_args(parser=None):
@retval ArgumentParser
"""
# pyrefly: ignore # missing-attribute
parser.add_argument(
"--multi-instance",
"--multi_instance",
@ -843,6 +844,7 @@ def create_args(parser=None):
help="Enable multi-instance, by default one instance per node",
)
# pyrefly: ignore # missing-attribute
parser.add_argument(
"-m",
"--module",
@ -853,6 +855,7 @@ def create_args(parser=None):
'"python -m".',
)
# pyrefly: ignore # missing-attribute
parser.add_argument(
"--no-python",
"--no_python",
@ -867,6 +870,7 @@ def create_args(parser=None):
_add_multi_instance_params(parser)
# positional
# pyrefly: ignore # missing-attribute
parser.add_argument(
"program",
type=str,
@ -875,6 +879,7 @@ def create_args(parser=None):
)
# rest from the training program
# pyrefly: ignore # missing-attribute
parser.add_argument("program_args", nargs=REMAINDER)

View File

@ -230,6 +230,7 @@ class SchemaMatcher:
for schema in cls.match_schemas(t):
mutable = mutable or [False for _ in schema.arguments]
for i, arg in enumerate(schema.arguments):
# pyrefly: ignore # unsupported-operation
mutable[i] |= getattr(arg.alias_info, "is_write", False)
return tuple(mutable or (None for _ in t.inputs))
@ -671,6 +672,7 @@ class MemoryProfile:
output: list[tuple[int, Action, KeyAndID, int]] = []
allocation_times: dict[tuple[TensorKey, bool], int] = {}
live_unknown: dict[tuple[int, torch.device], Literal[True]] = {}
# pyrefly: ignore # bad-assignment
for event in self._op_tree.dfs():
if event.typed[0] == _EventType.Allocation:
alloc_fields = event.typed[1]
@ -772,11 +774,14 @@ class MemoryProfile:
for key, (_, version) in node.inputs.items()
if self._categories.get(key, version)
in (Category.GRADIENT, Category.PARAMETER)
# pyrefly: ignore # unsupported-operation
or key.id in depends_on_gradient
)
if ids:
# pyrefly: ignore # missing-attribute
depends_on_gradient.update(ids)
# pyrefly: ignore # missing-attribute
depends_on_gradient.update(key.id for key in node.outputs)
# We are guaranteed to exit because there is a finite set of
@ -785,6 +790,7 @@ class MemoryProfile:
# once to fold the first step into that loop, and a third time
# where no new elements are added.
if len(depends_on_gradient) == start_size:
# pyrefly: ignore # bad-return
return depends_on_gradient
def _set_gradients_and_temporaries(self) -> None:
@ -1081,6 +1087,7 @@ class MemoryProfileTimeline:
if action in (Action.PREEXISTING, Action.CREATE):
raw_events.append(
# pyrefly: ignore # bad-argument-type
(
t,
_ACTION_TO_INDEX[action],
@ -1091,6 +1098,7 @@ class MemoryProfileTimeline:
elif action == Action.INCREMENT_VERSION:
raw_events.append(
# pyrefly: ignore # bad-argument-type
(
t,
_ACTION_TO_INDEX[action],
@ -1099,6 +1107,7 @@ class MemoryProfileTimeline:
)
)
raw_events.append(
# pyrefly: ignore # bad-argument-type
(
t,
_ACTION_TO_INDEX[action],
@ -1109,6 +1118,7 @@ class MemoryProfileTimeline:
elif action == Action.DESTROY:
raw_events.append(
# pyrefly: ignore # bad-argument-type
(
t,
_ACTION_TO_INDEX[action],

View File

@ -211,6 +211,7 @@ class BasicEvaluation:
# Find latest cuda kernel event
if hasattr(event, "start_us"):
start_time = event.start_us() * 1000
# pyrefly: ignore # missing-attribute
end_time = (event.start_us() + event.duration_us()) * 1000
# Find current spawned cuda kernel event
if event in kernel_mapping and kernel_mapping[event] is not None:

View File

@ -161,14 +161,19 @@ class _KinetoProfile:
self.mem_tl: Optional[MemoryProfileTimeline] = None
self.use_device = None
if ProfilerActivity.CUDA in self.activities:
# pyrefly: ignore # bad-assignment
self.use_device = "cuda"
elif ProfilerActivity.XPU in self.activities:
# pyrefly: ignore # bad-assignment
self.use_device = "xpu"
elif ProfilerActivity.MTIA in self.activities:
# pyrefly: ignore # bad-assignment
self.use_device = "mtia"
elif ProfilerActivity.HPU in self.activities:
# pyrefly: ignore # bad-assignment
self.use_device = "hpu"
elif ProfilerActivity.PrivateUse1 in self.activities:
# pyrefly: ignore # bad-assignment
self.use_device = _get_privateuse1_backend_name()
# user-defined metadata to be amended to the trace
@ -380,6 +385,7 @@ class _KinetoProfile:
}
if backend == "nccl":
nccl_version = torch.cuda.nccl.version()
# pyrefly: ignore # unsupported-operation
dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
return dist_info

View File

@ -623,17 +623,20 @@ def as_sparse_gradcheck(gradcheck):
)
obj = obj.to_dense().sparse_mask(full_mask)
if obj.layout is torch.sparse_coo:
# pyrefly: ignore # no-matching-overload
d.update(
indices=obj._indices(), is_coalesced=obj.is_coalesced()
)
values = obj._values()
elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
# pyrefly: ignore # no-matching-overload
d.update(
compressed_indices=obj.crow_indices(),
plain_indices=obj.col_indices(),
)
values = obj.values()
else:
# pyrefly: ignore # no-matching-overload
d.update(
compressed_indices=obj.ccol_indices(),
plain_indices=obj.row_indices(),

View File

@ -385,6 +385,7 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
g1 = c_offsets[r + 1]
for g in range(g0, g1):
p, q = pq[g]
# pyrefly: ignore # unsupported-operation
accumulators[r] += blocks[p] @ others[q]
else:
_scatter_mm2(blocks, others, c_offsets, pq, accumulators)
@ -1296,6 +1297,7 @@ def bsr_dense_addmm(
assert alpha != 0
def kernel(grid, *sliced_tensors):
# pyrefly: ignore # unsupported-operation
_bsr_strided_addmm_kernel[grid](
*ptr_stride_extractor(*sliced_tensors),
beta,
@ -1425,6 +1427,7 @@ if has_triton():
mat1_block = tl.load(
mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],
# pyrefly: ignore # index-error
mask=mask_k[None, :],
other=0.0,
)
@ -1433,6 +1436,7 @@ if has_triton():
mat2_block_ptrs
+ mat2_tiled_col_stride * col_block
+ mat2_row_block_stride * k_offsets[:, None],
# pyrefly: ignore # index-error
mask=mask_k[:, None],
other=0.0,
)
@ -1970,6 +1974,7 @@ if has_triton():
if attn_mask.dtype is not torch.bool:
check_dtype(f_name, attn_mask, query.dtype)
# pyrefly: ignore # not-callable
sdpa = sampled_addmm(
attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False
)
@ -1981,8 +1986,10 @@ if has_triton():
)
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
sdpa.values().mul_(scale_factor)
# pyrefly: ignore # not-callable
sdpa = bsr_softmax(sdpa)
torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
# pyrefly: ignore # not-callable
sdpa = bsr_dense_mm(sdpa, value)
return sdpa

View File

@ -232,8 +232,10 @@ def dump():
part2 = current_content[end_data_index:]
data_part = []
for op_key in sorted(_operation_device_version_data, key=sort_key):
# pyrefly: ignore # bad-argument-type
data_part.append(" " + repr(op_key).replace("'", '"') + ": {")
op_data = _operation_device_version_data[op_key]
# pyrefly: ignore # bad-argument-type
data_part.extend(f" {key}: {op_data[key]}," for key in sorted(op_data))
data_part.append(" },")
new_content = part1 + "\n".join(data_part) + "\n" + part2
@ -367,6 +369,7 @@ def minimize(
if next_target < minimal_target:
minimal_target = next_target
parameters = next_parameters
# pyrefly: ignore # unsupported-operation
pbar.total += i + 1
break
else:

View File

@ -1,5 +1,7 @@
from torch._C import FileCheck as FileCheck
from . import _utils
# pyrefly: ignore # deprecated
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor

View File

@ -241,6 +241,7 @@ def make_scalar_mismatch_msg(
Defaults to "Scalars".
"""
abs_diff = abs(actual - expected)
# pyrefly: ignore # bad-argument-type
rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
return _make_mismatch_msg(
default_identifier="Scalars",
@ -484,6 +485,7 @@ class BooleanPair(Pair):
def _supported_types(self) -> tuple[type, ...]:
cls: list[type] = [bool]
if HAS_NUMPY:
# pyrefly: ignore # missing-attribute
cls.append(np.bool_)
return tuple(cls)
@ -499,6 +501,7 @@ class BooleanPair(Pair):
def _to_bool(self, bool_like: Any, *, id: tuple[Any, ...]) -> bool:
if isinstance(bool_like, bool):
return bool_like
# pyrefly: ignore # missing-attribute
elif isinstance(bool_like, np.bool_):
return bool_like.item()
else:
@ -578,6 +581,7 @@ class NumberPair(Pair):
def _supported_types(self) -> tuple[type, ...]:
cls = list(self._NUMBER_TYPES)
if HAS_NUMPY:
# pyrefly: ignore # missing-attribute
cls.append(np.number)
return tuple(cls)
@ -593,6 +597,7 @@ class NumberPair(Pair):
def _to_number(
self, number_like: Any, *, id: tuple[Any, ...]
) -> Union[int, float, complex]:
# pyrefly: ignore # missing-attribute
if HAS_NUMPY and isinstance(number_like, np.number):
return number_like.item()
elif isinstance(number_like, self._NUMBER_TYPES):
@ -1115,6 +1120,7 @@ def originate_pairs(
mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
id: tuple[Any, ...] = (),
**options: Any,
# pyrefly: ignore # bad-return
) -> list[Pair]:
"""Originates pairs from the individual inputs.
@ -1310,7 +1316,9 @@ def not_close_error_metas(
# would not get freed until cycle collection, leaking cuda memory in tests.
# We break the cycle by removing the reference to the error_meta objects
# from this frame as it returns.
# pyrefly: ignore # bad-assignment
error_metas = [error_metas]
# pyrefly: ignore # bad-return
return error_metas.pop()