Add pyrefly suppressions 2/n (#164513)
Adds suppressions so that pyrefly typechecks clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

---

Step 1: uncomment lines in the `pyrefly.toml` file

Before: https://gist.github.com/maggiemoss/911b4d0bc88bf8cf3ab91f67184e9d46

After:
```
INFO Checking project configured at `/Users/maggiemoss/python_projects/pytorch/pyrefly.toml`
INFO 0 errors (1,152 ignored)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164513
Approved by: https://github.com/oulgen
This commit is contained in:
parent d1cbb74fb1
commit 1051c1de5c

pyrefly.toml: 10 lines changed
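For context before the hunks: the change works by adding inline suppression comments of the form `# pyrefly: ignore # <error-code>` on the line directly above each flagged statement. Below is a minimal sketch of that pattern; the toy function and the `unsupported-operation` finding it suppresses are hypothetical, only the comment format mirrors this commit:

```
from typing import Optional


def scale(value: Optional[float], factor: float) -> float:
    # "value" may be None, so a strict checker can flag the multiplication.
    # The comment on the preceding line suppresses exactly that one
    # diagnostic, naming its error code, just like the hunks below do.
    # pyrefly: ignore # unsupported-operation
    return value * factor


print(scale(2.0, 3.0))  # prints 6.0
```

The same shape appears below with codes such as `unknown-name`, `bad-assignment`, `missing-attribute`, and `bad-return`.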
```
@@ -40,16 +40,6 @@ project-excludes = [
    "torch/autograd/**",
    "torch/cuda/**",
    "torch/export/**",
    "torch/profiler/**",
    "torch/_prims_common/**",
    "torch/backends/**",
    # "torch/testing/**",
    "torch/_C/**",
    "torch/sparse/**",
    "torch/_library/**",
    "torch/_prims/**",
    "torch/_decomp/**",
    "torch/_meta_registrations.py",
    # formatting issues
    "torch/linalg/__init__.py",
    "torch/package/importer.py",
```
```
@@ -112,17 +112,28 @@ class DebugLevel(Enum):
    DETAIL = ...

class ReduceOp:
    # pyrefly: ignore # unknown-name
    def __init__(self, op: RedOpType) -> None: ...

    # pyrefly: ignore # unknown-name
    SUM: RedOpType = ...
    # pyrefly: ignore # unknown-name
    AVG: RedOpType = ...
    # pyrefly: ignore # unknown-name
    PRODUCT: RedOpType = ...
    # pyrefly: ignore # unknown-name
    MIN: RedOpType = ...
    # pyrefly: ignore # unknown-name
    MAX: RedOpType = ...
    # pyrefly: ignore # unknown-name
    BAND: RedOpType = ...
    # pyrefly: ignore # unknown-name
    BOR: RedOpType = ...
    # pyrefly: ignore # unknown-name
    BXOR: RedOpType = ...
    # pyrefly: ignore # unknown-name
    PREMUL_SUM: RedOpType = ...
    # pyrefly: ignore # unknown-name
    UNUSED: RedOpType = ...

    # mypy error being ignored:
```
```
@@ -240,6 +240,7 @@ def get_decompositions(

    registry = global_decomposition_table[type]
    packets_to_overloads = defaultdict(list)
    # pyrefly: ignore # bad-assignment
    for opo in registry:
        if isinstance(opo, (OpOverload, OpOverloadPacket)):
            packets_to_overloads[opo.overloadpacket].append(opo)
```
```
@@ -382,6 +382,7 @@ def to_real_dtype(dtype: torch.dtype):
def mse_loss(
    self: Tensor, target: Tensor, reduction: int = Reduction.MEAN.value
) -> Tensor:
    # pyrefly: ignore # unsupported-operation
    loss = (self - target) ** 2
    return apply_loss_reduction(loss, reduction)

@@ -415,6 +416,7 @@ def smooth_l1_loss(
    beta: float = 1.0,
):
    loss = (self - target).abs()
    # pyrefly: ignore # unsupported-operation
    loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
    return apply_loss_reduction(loss, reduction)

@@ -4893,7 +4895,9 @@ def _reflection_pad_backward(grad_output, x, padding):
@register_decomposition(aten.aminmax)
@out_wrapper("min", "max")
def aminmax(self, *, dim=None, keepdim=False):
    # pyrefly: ignore # bad-argument-type
    amin = torch.amin(self, dim=dim, keepdim=keepdim)
    # pyrefly: ignore # bad-argument-type
    amax = torch.amax(self, dim=dim, keepdim=keepdim)
    return amin, amax

@@ -5138,6 +5142,7 @@ def baddbmm(self, batch1, batch2, beta=1, alpha=1):
    alpha = int(alpha)
    result = torch.bmm(batch1, batch2)
    if not isinstance(alpha, numbers.Number) or alpha != 1:
        # pyrefly: ignore # unsupported-operation
        result = result * alpha
    if beta == 0:
        return result
```
```
@@ -245,6 +245,7 @@ def save_op_profiles(op_profiles: dict[str, set[OpProfile]], f: FileLike) -> None:
    yaml_str = generate_yaml_from_profiles(op_profiles)

    if isinstance(f, (str, os.PathLike)):
        # pyrefly: ignore # no-matching-overload
        f = os.fspath(f)

    with open(f, "w") as file:

@@ -309,6 +310,7 @@ def load_op_profiles(f: FileLike) -> dict[str, set[OpProfile]]:
    Loads the saved operator profiles from `save_op_profiles`.
    """
    if isinstance(f, (str, os.PathLike)):
        # pyrefly: ignore # no-matching-overload
        f = os.fspath(f)

    with open(f) as file:
```
```
@@ -159,6 +159,7 @@ def infer_schema(
            schema_type = f"Tensor(a{idx}!){schema_type[len('Tensor') :]}"
        seen_args.add(name)
        if param.default is inspect.Parameter.empty:
            # pyrefly: ignore # bad-argument-type
            params.append(f"{schema_type} {name}")
        else:
            default_repr = None

@@ -176,6 +177,7 @@ def infer_schema(
                    f"Parameter {name} has an unsupported default value type {type(param.default)}. "
                    f"Please file an issue on GitHub so we can prioritize this."
                )
            # pyrefly: ignore # bad-argument-type
            params.append(f"{schema_type} {name}={default_repr}")
    if mutates_args != UNKNOWN_MUTATES:
        mutates_args_not_seen = set(mutates_args) - seen_args

@@ -202,6 +204,7 @@ def derived_types(
):
    result: list[tuple[Union[type, typing._SpecialForm, GenericAlias], str]] = [
        (base_type, cpp_type),
        # pyrefly: ignore # not-a-type
        (typing.Optional[base_type], f"{cpp_type}?"),
    ]

@@ -220,6 +223,7 @@ def derived_types(
    if optional_base_list:
        result.extend(
            (seq_typ, f"{cpp_type}?[]")
            # pyrefly: ignore # not-a-type
            for seq_typ in derived_seq_types(typing.Optional[base_type])
        )
    if optional_list_base:

@@ -273,6 +277,7 @@ def parse_return(annotation, error_fn):
            f"Return has unsupported type {annotation}. "
            f"The valid types are: {SUPPORTED_RETURN_TYPES}."
        )
        # pyrefly: ignore # index-error
        return SUPPORTED_RETURN_TYPES[annotation]

    args = typing.get_args(annotation)
```
```
@@ -2341,16 +2341,19 @@ def calc_conv_nd_return_shape(

    ret_shape = [input_tensor.shape[0], out_channels]
    if isinstance(stride, IntLike):
        # pyrefly: ignore # bad-assignment
        stride = [stride] * len(dims)
    elif len(stride) == 1:
        stride = [stride[0]] * len(dims)

    if isinstance(padding, IntLike):
        # pyrefly: ignore # bad-assignment
        padding = [padding] * len(dims)
    elif len(padding) == 1:
        padding = [padding[0]] * len(dims)

    if isinstance(dilation, IntLike):
        # pyrefly: ignore # bad-assignment
        dilation = [dilation] * len(dims)
    elif len(dilation) == 1:
        dilation = [dilation[0]] * len(dims)

@@ -2358,6 +2361,7 @@ def calc_conv_nd_return_shape(
    output_padding_list: Optional[list[int]] = None
    if output_padding:
        if isinstance(output_padding, IntLike):
            # pyrefly: ignore # bad-assignment
            output_padding_list = [output_padding] * len(dims)
        elif len(output_padding) == 1:
            output_padding_list = [output_padding[0]] * len(dims)

@@ -2370,15 +2374,19 @@ def calc_conv_nd_return_shape(
            ret_shape.append(
                _formula_transposed(
                    dims[i],
                    # pyrefly: ignore # index-error
                    padding[i],
                    # pyrefly: ignore # index-error
                    dilation[i],
                    kernel_size[i],
                    # pyrefly: ignore # index-error
                    stride[i],
                    output_padding_list[i],
                )
            )
        else:
            ret_shape.append(
                # pyrefly: ignore # index-error
                _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
            )
    from torch.fx.experimental.symbolic_shapes import sym_or

@@ -3444,6 +3452,7 @@ def meta_index_Tensor(self, indices):
    """
    shape = before_shape + replacement_shape + after_shape
    strides = list(self.stride())
    # pyrefly: ignore # unsupported-operation
    strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
        replacement_shape
    )

@@ -6655,6 +6664,7 @@ def rnn_cell_checkSizes(
    )
    torch._check(
        all(
            # pyrefly: ignore # missing-attribute
            x.device == input_gates.device
            for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
        ),
```
```
@@ -427,6 +427,7 @@ def _prim_elementwise_meta(
    # Acquires the device (if it exists) or number
    device = None
    number = None
    # pyrefly: ignore # bad-assignment
    for arg in args_:
        if isinstance(arg, TensorLike):
            if utils.is_cpu_scalar_tensor(arg):

@@ -1015,8 +1016,10 @@ def _div_aten(a, b):
    )

    if is_integral:
        # pyrefly: ignore # bad-argument-type
        return torch.div(a, b, rounding_mode="trunc")
    else:
        # pyrefly: ignore # bad-argument-type
        return torch.true_divide(a, b)
```
```
@@ -125,6 +125,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
        # Unless we are in prims_mode, in which case we want to use nvprims
        if orig_func in torch_function_passthrough or orig_func in all_prims():
            with self.prims_mode_cls():
                # pyrefly: ignore # invalid-param-spec
                return orig_func(*args, **kwargs)
        mapping = torch_to_refs_map()
        func = mapping.get(orig_func, None)

@@ -147,6 +148,7 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
        if func is not None:
            # If the ref exists query whether we should use it or not
            if self.should_fallback_fn(self, orig_func, func, args, kwargs):
                # pyrefly: ignore # invalid-param-spec
                return orig_func(*args, **kwargs)
            # torch calls inside func should be interpreted as refs calls
            with self:

@@ -155,4 +157,5 @@ class TorchRefsMode(torch.overrides.TorchFunctionMode):
            raise RuntimeError(
                f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
            )
        # pyrefly: ignore # invalid-param-spec
        return orig_func(*args, **kwargs)
```
```
@@ -29,6 +29,7 @@ def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
    rngprim_def = torch.library.custom_op(
        "rngprims::" + name, impl_aten, mutates_args=(), schema=schema
    )
    # pyrefly: ignore # missing-attribute
    rngprim_def.register_fake(impl_meta)

    prim_packet = getattr(torch._ops.ops.rngprims, name)

@@ -329,9 +330,11 @@ def register_graphsafe_run_with_rng_state_op():

    @graphsafe_run_with_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(op, *args, rng_state=None, **kwargs):
        # pyrefly: ignore # missing-attribute
        device_idx = rng_state.device.index
        generator = torch.cuda.default_generators[device_idx]
        current_state = generator.graphsafe_get_state()
        # pyrefly: ignore # bad-argument-type
        generator.graphsafe_set_state(rng_state)
        out = op(*args, **kwargs)
        generator.graphsafe_set_state(current_state)
```
```
@@ -113,6 +113,7 @@ def same_shape(a: ShapeType, b: ShapeType, *, allow_rhs_unbacked=False) -> bool:
    if len(a) != len(b):
        return False

    # pyrefly: ignore # bad-assignment
    for x, y in zip(a, b):
        if allow_rhs_unbacked:
            if isinstance(y, torch.SymInt):

@@ -389,7 +390,11 @@ def validate_memory_format(memory_format: torch.memory_format):


def is_contiguous_for_memory_format(  # type: ignore[return]
    a: Tensor, *, memory_format: torch.memory_format, false_if_dde=False
    a: Tensor,
    *,
    memory_format: torch.memory_format,
    false_if_dde=False,
    # pyrefly: ignore # bad-return
) -> bool:
    validate_memory_format(memory_format)

@@ -810,12 +815,16 @@ def canonicalize_dim(rank: int, idx: int, wrap_scalar: bool = True) -> int:
# mapping negative offsets to positive ones
@overload
def canonicalize_dims(
    rank: int, indices: Sequence[int], wrap_scalar: bool = True
    rank: int,
    indices: Sequence[int],
    wrap_scalar: bool = True,
    # pyrefly: ignore # bad-return
) -> tuple[int, ...]:
    pass


@overload
# pyrefly: ignore # bad-return
def canonicalize_dims(rank: int, indices: int, wrap_scalar: bool = True) -> int:
    pass

@@ -862,6 +871,7 @@ def check_same_device(*args, allow_cpu_scalar_tensors):

    # Note: cannot initialize device to the first arg's device (it may not have one)
    device = None
    # pyrefly: ignore # bad-assignment
    for arg in args:
        if isinstance(arg, Number):
            continue

@@ -909,6 +919,7 @@ def check_same_shape(*args, allow_cpu_scalar_tensors: bool):
    """
    shape = None

    # pyrefly: ignore # bad-assignment
    for arg in args:
        if isinstance(arg, Number):
            continue

@@ -935,6 +946,7 @@ def extract_shape(*args, allow_cpu_scalar_tensors: bool) -> Optional[ShapeType]:
    shape = None
    scalar_shape = None

    # pyrefly: ignore # bad-assignment
    for arg in args:
        if isinstance(arg, Number):
            continue

@@ -991,6 +1003,7 @@ def extract_shape_from_varargs(

    # Handles tuple unwrapping
    if len(shape) == 1 and isinstance(shape[0], Sequence):
        # pyrefly: ignore # bad-assignment
        shape = shape[0]

    if validate:

@@ -1292,6 +1305,7 @@ def get_higher_dtype(

        raise RuntimeError("Unexpected type given to _extract_dtype!")

    # pyrefly: ignore # bad-argument-type
    a, b = _extract_dtype(a), _extract_dtype(b)

    if a is b:

@@ -1387,6 +1401,7 @@ def check_same_dtype(*args):
    full_dtype = None
    scalar_type = None

    # pyrefly: ignore # bad-assignment
    for arg in args:
        if isinstance(arg, Number):
            # Scalar type checking is disabled (and may be removed in the future)

@@ -1657,8 +1672,10 @@ def elementwise_dtypes(

        # Prefers dtype of tensors with one or more dimensions
        if one_plus_dim_tensor_dtype is not None:
            # pyrefly: ignore # bad-return
            return one_plus_dim_tensor_dtype

        # pyrefly: ignore # bad-return
        return zero_dim_tensor_dtype

    if highest_type is float:
```
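The `bad-return` suppressions above and in the next set of hunks sit on `@overload` stubs whose bodies are just `pass`. A hedged, self-contained sketch of that shape follows; whether pyrefly flags this exact toy example is an assumption, but the placement of the comment mirrors the hunks:

```
from typing import Union, overload


@overload
# pyrefly: ignore # bad-return
def double(x: int) -> int: ...
@overload
# pyrefly: ignore # bad-return
def double(x: str) -> str: ...
def double(x: Union[int, str]) -> Union[int, str]:
    # Only the overload stubs above have empty bodies; the implementation
    # actually returns a value, so the suppressions are confined to the stubs.
    return x + x


print(double(21), double("ab"))  # prints: 42 abab
```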
```
@@ -28,16 +28,19 @@ _P = ParamSpec("_P")


@overload
# pyrefly: ignore # bad-return
def _maybe_convert_to_dtype(a: TensorLikeType, dtype: torch.dtype) -> TensorLikeType:
    pass


@overload
# pyrefly: ignore # bad-return
def _maybe_convert_to_dtype(a: NumberType, dtype: torch.dtype) -> NumberType:
    pass


@overload
# pyrefly: ignore # bad-return
def _maybe_convert_to_dtype(a: Sequence, dtype: torch.dtype) -> Sequence:
    pass

@@ -276,7 +279,10 @@ def out_wrapper(
            TensorLikeType
            if is_tensor
            else NamedTuple(
                f"return_types_{fn.__name__}", [(o, TensorLikeType) for o in out_names]
                # pyrefly: ignore # bad-argument-count
                f"return_types_{fn.__name__}",
                # pyrefly: ignore # bad-argument-count
                [(o, TensorLikeType) for o in out_names],
            )
        )

@@ -294,6 +300,7 @@ def out_wrapper(
            kwargs[k] = out_attr

        def maybe_check_copy_devices(out):
            # pyrefly: ignore # unsupported-operation
            if isinstance(out, TensorLike) and isinstance(args[0], TensorLike):
                check_copy_devices(copy_from=args[0], copy_to=out)

@@ -429,6 +436,7 @@ def backwards_not_supported(prim):

    class BackwardsNotSupported(torch.autograd.Function):
        @staticmethod
        # pyrefly: ignore # bad-override
        def forward(ctx, args_spec, *flat_args):
            args, kwargs = tree_unflatten(flat_args, args_spec)  # type: ignore[arg-type]
            return redispatch_prim(args, kwargs)

@@ -477,11 +485,14 @@ def elementwise_unary_scalar_wrapper(
            dtype = utils.type_to_dtype(type(args[0]))
            args_ = list(args)
            args_[0] = torch.tensor(args[0], dtype=dtype)
            # pyrefly: ignore # invalid-param-spec
            result = fn(*args_, **kwargs)
            assert isinstance(result, torch.Tensor)
            return result.item()

        # pyrefly: ignore # invalid-param-spec
        return fn(*args, **kwargs)

    _fn.__signature__ = sig  # type: ignore[attr-defined]
    # pyrefly: ignore # bad-return
    return _fn
```
```
@@ -414,6 +414,7 @@ class _NnapiSerializer:
            )  # noqa: TRY002
        return Operand(
            shape=tuple(tensor.shape),
            # pyrefly: ignore # bad-argument-type
            op_type=op_type,
            dim_order=dim_order,
            scale=scale,

@@ -1734,11 +1735,13 @@ class _NnapiSerializer:
        for dim in (2, 3):  # h, w indices
            if image_oper.shape[dim] == 0:
                if size_ctype.kind() != "NoneType":
                    # pyrefly: ignore # unsupported-operation
                    self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
                elif scale_ctype.kind() != "NoneType":
                    self.compute_operand_shape(
                        out_id,
                        dim,
                        # pyrefly: ignore # unsupported-operation
                        f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})",
                    )
                else:
```
```
@@ -34,8 +34,11 @@ if _cudnn is not None:
    def _init():
        global __cudnn_version
        if __cudnn_version is None:
            # pyrefly: ignore # missing-attribute
            __cudnn_version = _cudnn.getVersionInt()
            # pyrefly: ignore # missing-attribute
            runtime_version = _cudnn.getRuntimeVersion()
            # pyrefly: ignore # missing-attribute
            compile_version = _cudnn.getCompileVersion()
            runtime_major, runtime_minor, _ = runtime_version
            compile_major, compile_minor, _ = compile_version

@@ -44,6 +47,7 @@ if _cudnn is not None:
            # Not sure about MIOpen (ROCm), so always do a strict check
            if runtime_major != compile_major:
                cudnn_compatible = False
            # pyrefly: ignore # missing-attribute
            elif runtime_major < 7 or not _cudnn.is_cuda:
                cudnn_compatible = runtime_minor == compile_minor
            else:
```
```
@@ -12,12 +12,16 @@ except ImportError:

def get_cudnn_mode(mode):
    if mode == "RNN_RELU":
        # pyrefly: ignore # missing-attribute
        return int(_cudnn.RNNMode.rnn_relu)
    elif mode == "RNN_TANH":
        # pyrefly: ignore # missing-attribute
        return int(_cudnn.RNNMode.rnn_tanh)
    elif mode == "LSTM":
        # pyrefly: ignore # missing-attribute
        return int(_cudnn.RNNMode.lstm)
    elif mode == "GRU":
        # pyrefly: ignore # missing-attribute
        return int(_cudnn.RNNMode.gru)
    else:
        raise Exception(f"Unknown mode: {mode}")  # noqa: TRY002

@@ -56,6 +60,7 @@ def init_dropout_state(dropout, train, dropout_seed, dropout_state):
        dropout_p,
        train,
        dropout_seed,
        # pyrefly: ignore # unexpected-keyword
        self_ty=torch.uint8,
        device=torch.device("cuda"),
    )
```
```
@@ -23,6 +23,7 @@ if _cusparselt is not None:
        global __cusparselt_version
        global __MAX_ALG_ID
        if __cusparselt_version is None:
            # pyrefly: ignore # missing-attribute
            __cusparselt_version = _cusparselt.getVersionInt()
            if __cusparselt_version == 400:
                __MAX_ALG_ID = 4
```
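The `missing-attribute` suppressions in the backends hunks above share one structure: an optional native module is imported with a try/except fallback to `None`, and functions defined under a module-level `is not None` guard still access its attributes. Below is a hedged sketch of that structure using a stand-in stdlib module rather than `_cudnn`/`_cusparselt`; whether pyrefly reports this exact toy case is an assumption:

```
try:
    import array as _ext  # stand-in for an optional native extension module
except ImportError:
    _ext = None  # checkers now treat _ext as possibly None

if _ext is not None:

    def ext_typecode() -> str:
        # The module-level guard does not narrow the global inside this
        # function body, so a strict checker can still report the attribute
        # access; the suppression mirrors the ones in the hunks above.
        # pyrefly: ignore # missing-attribute
        return _ext.array("i").typecode

    print(ext_typecode())  # prints: i
```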
```
@@ -70,6 +70,7 @@ def _set_strategy(_strategy: str) -> None:


def _get_strategy() -> str:
    # pyrefly: ignore # bad-return
    return strategy
```
```
@@ -835,6 +835,7 @@ def create_args(parser=None):

    @retval ArgumentParser
    """
    # pyrefly: ignore # missing-attribute
    parser.add_argument(
        "--multi-instance",
        "--multi_instance",

@@ -843,6 +844,7 @@ def create_args(parser=None):
        help="Enable multi-instance, by default one instance per node",
    )

    # pyrefly: ignore # missing-attribute
    parser.add_argument(
        "-m",
        "--module",

@@ -853,6 +855,7 @@ def create_args(parser=None):
        '"python -m".',
    )

    # pyrefly: ignore # missing-attribute
    parser.add_argument(
        "--no-python",
        "--no_python",

@@ -867,6 +870,7 @@ def create_args(parser=None):

    _add_multi_instance_params(parser)
    # positional
    # pyrefly: ignore # missing-attribute
    parser.add_argument(
        "program",
        type=str,

@@ -875,6 +879,7 @@ def create_args(parser=None):
    )

    # rest from the training program
    # pyrefly: ignore # missing-attribute
    parser.add_argument("program_args", nargs=REMAINDER)
```
```
@@ -230,6 +230,7 @@ class SchemaMatcher:
        for schema in cls.match_schemas(t):
            mutable = mutable or [False for _ in schema.arguments]
            for i, arg in enumerate(schema.arguments):
                # pyrefly: ignore # unsupported-operation
                mutable[i] |= getattr(arg.alias_info, "is_write", False)

        return tuple(mutable or (None for _ in t.inputs))

@@ -671,6 +672,7 @@ class MemoryProfile:
        output: list[tuple[int, Action, KeyAndID, int]] = []
        allocation_times: dict[tuple[TensorKey, bool], int] = {}
        live_unknown: dict[tuple[int, torch.device], Literal[True]] = {}
        # pyrefly: ignore # bad-assignment
        for event in self._op_tree.dfs():
            if event.typed[0] == _EventType.Allocation:
                alloc_fields = event.typed[1]

@@ -772,11 +774,14 @@ class MemoryProfile:
                for key, (_, version) in node.inputs.items()
                if self._categories.get(key, version)
                in (Category.GRADIENT, Category.PARAMETER)
                # pyrefly: ignore # unsupported-operation
                or key.id in depends_on_gradient
            )

            if ids:
                # pyrefly: ignore # missing-attribute
                depends_on_gradient.update(ids)
                # pyrefly: ignore # missing-attribute
                depends_on_gradient.update(key.id for key in node.outputs)

        # We are guaranteed to exit because there is a finite set of

@@ -785,6 +790,7 @@ class MemoryProfile:
        # once to fold the first step into that loop, and a third time
        # where no new elements are added.
        if len(depends_on_gradient) == start_size:
            # pyrefly: ignore # bad-return
            return depends_on_gradient

    def _set_gradients_and_temporaries(self) -> None:

@@ -1081,6 +1087,7 @@ class MemoryProfileTimeline:

            if action in (Action.PREEXISTING, Action.CREATE):
                raw_events.append(
                    # pyrefly: ignore # bad-argument-type
                    (
                        t,
                        _ACTION_TO_INDEX[action],

@@ -1091,6 +1098,7 @@ class MemoryProfileTimeline:

            elif action == Action.INCREMENT_VERSION:
                raw_events.append(
                    # pyrefly: ignore # bad-argument-type
                    (
                        t,
                        _ACTION_TO_INDEX[action],

@@ -1099,6 +1107,7 @@ class MemoryProfileTimeline:
                    )
                )
                raw_events.append(
                    # pyrefly: ignore # bad-argument-type
                    (
                        t,
                        _ACTION_TO_INDEX[action],

@@ -1109,6 +1118,7 @@ class MemoryProfileTimeline:

            elif action == Action.DESTROY:
                raw_events.append(
                    # pyrefly: ignore # bad-argument-type
                    (
                        t,
                        _ACTION_TO_INDEX[action],
```
```
@@ -211,6 +211,7 @@ class BasicEvaluation:
            # Find latest cuda kernel event
            if hasattr(event, "start_us"):
                start_time = event.start_us() * 1000
                # pyrefly: ignore # missing-attribute
                end_time = (event.start_us() + event.duration_us()) * 1000
                # Find current spawned cuda kernel event
                if event in kernel_mapping and kernel_mapping[event] is not None:
```
```
@@ -161,14 +161,19 @@ class _KinetoProfile:
        self.mem_tl: Optional[MemoryProfileTimeline] = None
        self.use_device = None
        if ProfilerActivity.CUDA in self.activities:
            # pyrefly: ignore # bad-assignment
            self.use_device = "cuda"
        elif ProfilerActivity.XPU in self.activities:
            # pyrefly: ignore # bad-assignment
            self.use_device = "xpu"
        elif ProfilerActivity.MTIA in self.activities:
            # pyrefly: ignore # bad-assignment
            self.use_device = "mtia"
        elif ProfilerActivity.HPU in self.activities:
            # pyrefly: ignore # bad-assignment
            self.use_device = "hpu"
        elif ProfilerActivity.PrivateUse1 in self.activities:
            # pyrefly: ignore # bad-assignment
            self.use_device = _get_privateuse1_backend_name()

        # user-defined metadata to be amended to the trace

@@ -380,6 +385,7 @@ class _KinetoProfile:
        }
        if backend == "nccl":
            nccl_version = torch.cuda.nccl.version()
            # pyrefly: ignore # unsupported-operation
            dist_info["nccl_version"] = ".".join(str(v) for v in nccl_version)
        return dist_info
```
```
@@ -623,17 +623,20 @@ def as_sparse_gradcheck(gradcheck):
                    )
                    obj = obj.to_dense().sparse_mask(full_mask)
                if obj.layout is torch.sparse_coo:
                    # pyrefly: ignore # no-matching-overload
                    d.update(
                        indices=obj._indices(), is_coalesced=obj.is_coalesced()
                    )
                    values = obj._values()
                elif obj.layout in {torch.sparse_csr, torch.sparse_bsr}:
                    # pyrefly: ignore # no-matching-overload
                    d.update(
                        compressed_indices=obj.crow_indices(),
                        plain_indices=obj.col_indices(),
                    )
                    values = obj.values()
                else:
                    # pyrefly: ignore # no-matching-overload
                    d.update(
                        compressed_indices=obj.ccol_indices(),
                        plain_indices=obj.row_indices(),
```
|
|||
g1 = c_offsets[r + 1]
|
||||
for g in range(g0, g1):
|
||||
p, q = pq[g]
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
accumulators[r] += blocks[p] @ others[q]
|
||||
else:
|
||||
_scatter_mm2(blocks, others, c_offsets, pq, accumulators)
|
||||
|
|
@ -1296,6 +1297,7 @@ def bsr_dense_addmm(
|
|||
assert alpha != 0
|
||||
|
||||
def kernel(grid, *sliced_tensors):
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
_bsr_strided_addmm_kernel[grid](
|
||||
*ptr_stride_extractor(*sliced_tensors),
|
||||
beta,
|
||||
|
|
@ -1425,6 +1427,7 @@ if has_triton():
|
|||
|
||||
mat1_block = tl.load(
|
||||
mat1_block_ptrs + mat1_col_block_stride * k_offsets[None, :],
|
||||
# pyrefly: ignore # index-error
|
||||
mask=mask_k[None, :],
|
||||
other=0.0,
|
||||
)
|
||||
|
|
@ -1433,6 +1436,7 @@ if has_triton():
|
|||
mat2_block_ptrs
|
||||
+ mat2_tiled_col_stride * col_block
|
||||
+ mat2_row_block_stride * k_offsets[:, None],
|
||||
# pyrefly: ignore # index-error
|
||||
mask=mask_k[:, None],
|
||||
other=0.0,
|
||||
)
|
||||
|
|
@ -1970,6 +1974,7 @@ if has_triton():
|
|||
if attn_mask.dtype is not torch.bool:
|
||||
check_dtype(f_name, attn_mask, query.dtype)
|
||||
|
||||
# pyrefly: ignore # not-callable
|
||||
sdpa = sampled_addmm(
|
||||
attn_mask, query, key.transpose(-2, -1), beta=0.0, skip_checks=False
|
||||
)
|
||||
|
|
@ -1981,8 +1986,10 @@ if has_triton():
|
|||
)
|
||||
scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
|
||||
sdpa.values().mul_(scale_factor)
|
||||
# pyrefly: ignore # not-callable
|
||||
sdpa = bsr_softmax(sdpa)
|
||||
torch.nn.functional.dropout(sdpa.values(), p=dropout_p, inplace=True)
|
||||
# pyrefly: ignore # not-callable
|
||||
sdpa = bsr_dense_mm(sdpa, value)
|
||||
return sdpa
|
||||
|
||||
|
|
|
|||
|
|
```
@@ -232,8 +232,10 @@ def dump():
    part2 = current_content[end_data_index:]
    data_part = []
    for op_key in sorted(_operation_device_version_data, key=sort_key):
        # pyrefly: ignore # bad-argument-type
        data_part.append(" " + repr(op_key).replace("'", '"') + ": {")
        op_data = _operation_device_version_data[op_key]
        # pyrefly: ignore # bad-argument-type
        data_part.extend(f" {key}: {op_data[key]}," for key in sorted(op_data))
        data_part.append(" },")
    new_content = part1 + "\n".join(data_part) + "\n" + part2

@@ -367,6 +369,7 @@ def minimize(
            if next_target < minimal_target:
                minimal_target = next_target
                parameters = next_parameters
                # pyrefly: ignore # unsupported-operation
                pbar.total += i + 1
                break
        else:
```
```
@@ -1,5 +1,7 @@
from torch._C import FileCheck as FileCheck

from . import _utils

# pyrefly: ignore # deprecated
from ._comparison import assert_allclose, assert_close as assert_close
from ._creation import make_tensor as make_tensor
```
```
@@ -241,6 +241,7 @@ def make_scalar_mismatch_msg(
            Defaults to "Scalars".
    """
    abs_diff = abs(actual - expected)
    # pyrefly: ignore # bad-argument-type
    rel_diff = float("inf") if expected == 0 else abs_diff / abs(expected)
    return _make_mismatch_msg(
        default_identifier="Scalars",

@@ -484,6 +485,7 @@ class BooleanPair(Pair):
    def _supported_types(self) -> tuple[type, ...]:
        cls: list[type] = [bool]
        if HAS_NUMPY:
            # pyrefly: ignore # missing-attribute
            cls.append(np.bool_)
        return tuple(cls)

@@ -499,6 +501,7 @@ class BooleanPair(Pair):
    def _to_bool(self, bool_like: Any, *, id: tuple[Any, ...]) -> bool:
        if isinstance(bool_like, bool):
            return bool_like
        # pyrefly: ignore # missing-attribute
        elif isinstance(bool_like, np.bool_):
            return bool_like.item()
        else:

@@ -578,6 +581,7 @@ class NumberPair(Pair):
    def _supported_types(self) -> tuple[type, ...]:
        cls = list(self._NUMBER_TYPES)
        if HAS_NUMPY:
            # pyrefly: ignore # missing-attribute
            cls.append(np.number)
        return tuple(cls)

@@ -593,6 +597,7 @@ class NumberPair(Pair):
    def _to_number(
        self, number_like: Any, *, id: tuple[Any, ...]
    ) -> Union[int, float, complex]:
        # pyrefly: ignore # missing-attribute
        if HAS_NUMPY and isinstance(number_like, np.number):
            return number_like.item()
        elif isinstance(number_like, self._NUMBER_TYPES):

@@ -1115,6 +1120,7 @@ def originate_pairs(
    mapping_types: tuple[type, ...] = (collections.abc.Mapping,),
    id: tuple[Any, ...] = (),
    **options: Any,
    # pyrefly: ignore # bad-return
) -> list[Pair]:
    """Originates pairs from the individual inputs.

@@ -1310,7 +1316,9 @@ def not_close_error_metas(
    # would not get freed until cycle collection, leaking cuda memory in tests.
    # We break the cycle by removing the reference to the error_meta objects
    # from this frame as it returns.
    # pyrefly: ignore # bad-assignment
    error_metas = [error_metas]
    # pyrefly: ignore # bad-return
    return error_metas.pop()
```