Fix pyrefly error syntax (2/n) (#166448)

Ensrues pyrefly ignores only silence one error code.

After this, only ~40 files left to clean up .

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166448
Approved by: https://github.com/Skylion007
This commit is contained in:
Maggie Moss 2025-10-29 00:36:37 +00:00 committed by PyTorch MergeBot
parent 56afad4eb3
commit 84fe848503
50 changed files with 217 additions and 217 deletions

View File

@ -880,14 +880,14 @@ def logsumexp(
if not isinstance(dim, Iterable):
dim = (dim,)
if self.numel() == 0:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return torch.sum(torch.exp(self), dim, keepdim).log()
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
maxes = torch.amax(torch.real(self), dim, keepdim=True)
maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
result = torch.sum(torch.exp(self - maxes), dim, keepdim)
return result.log().add(maxes_squeezed)
@ -1245,12 +1245,12 @@ def copysign(
a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
):
if isinstance(b, Number) and isinstance(a, Tensor):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
raise RuntimeError(msg)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return where(signbit(b), neg(abs(a)), abs(a))
@ -1342,7 +1342,7 @@ def float_power(
b = _maybe_convert_to_dtype(b, dtype)
a, b = _maybe_broadcast(a, b)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return pow(a, b)
@ -1384,15 +1384,15 @@ def floor_divide(
):
# Wrap scalars because some references only accept tensor arguments.
if isinstance(a, Number) and isinstance(b, Number):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
a = scalar_tensor(a)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
b = scalar_tensor(b)
elif isinstance(b, Number) and isinstance(a, Tensor):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
elif isinstance(a, Number) and isinstance(b, Tensor):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
if a.device == torch.device("cpu"):
@ -1869,10 +1869,10 @@ def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberT
# Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
if isinstance(b, TensorLike) and isinstance(a, Number):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
a = scalar_tensor(a, dtype=b.dtype, device=b.device)
elif isinstance(a, TensorLike) and isinstance(b, Number):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
b = scalar_tensor(b, dtype=a.dtype, device=a.device)
# mypy: expected "Tensor"
@ -2865,7 +2865,7 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
# SymInts
example = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for i, t in enumerate(tensors):
if example is None:
if t.ndim != 1:
@ -3358,7 +3358,7 @@ def native_layer_norm(
# while torch.Size([1, 2, 3]) == (1, 2, 3) is True
# therefore we use tuple(normalized_shape)
torch._check(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
weight is None or sym_eq(weight.shape, tuple(normalized_shape)),
lambda: "Expected weight to be of same shape as normalized_shape, but got "
+ "weight of shape "
@ -3367,7 +3367,7 @@ def native_layer_norm(
+ str(normalized_shape),
)
torch._check(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
bias is None or sym_eq(bias.shape, tuple(normalized_shape)),
lambda: "Expected bias to be of same shape as normalized_shape, but got "
+ "bias of shape "
@ -3379,7 +3379,7 @@ def native_layer_norm(
input.ndim >= normalized_ndim
and sym_eq(
input.shape[(input.ndim - normalized_ndim) :],
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
tuple(normalized_shape),
),
lambda: "Given normalized_shape="
@ -3988,16 +3988,16 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike
# Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
return a.clone()
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if a.dim() == 0 and len(dims) > 0:
raise IndexError(
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
f"Dimension specified as {dims[0]} but tensor has no dimensions"
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
len_shifts = len(shifts)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
len_dims = len(dims)
if len_shifts != 1 or len_dims != 1:
if len_shifts == 0:
@ -4005,27 +4005,27 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike
# Takes care of the case when dims is not specified (default)
# By default, the tensor is flattened before shifting, after which the original shape is restored
if len_dims == 0 and len_shifts == 1:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
if len_shifts != len_dims:
raise RuntimeError(
f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
)
assert len_dims > 1
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
tail_shifts = shifts[1:]
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
tail_dims = dims[1:]
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
first_dim_rolled = torch.roll(a, (shifts[0],), dims[0])
return torch.roll(first_dim_rolled, tail_shifts, tail_dims)
# This path is taken when only one dimension is rolled
# For example to get `first_dim_rolled` above
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
dim = dims[0]
size = a.shape[dim]
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
start = (size - shifts[0]) % size
idx = torch.arange(size, device=a.device)
return a.index_select(dim, torch.fmod(start + idx, size))
@ -4107,7 +4107,7 @@ def softmax(
a_max = amax(a_, dim, keepdim=True)
a_exp = exp(a_ - a_max)
return _maybe_convert_to_dtype(
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
true_divide(a_exp, sum(a_exp, dim, keepdim=True)),
result_dtype,
) # type: ignore[return-value]
@ -4427,7 +4427,7 @@ def hsplit(
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
torch._check(
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
(split_size != 0 and a.shape[dim] % split_size == 0),
lambda: (
"torch.hsplit attempted to split along dimension "
@ -4439,7 +4439,7 @@ def hsplit(
+ "!"
),
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return tensor_split(a, split_size, dim)
torch._check_type(
@ -4470,7 +4470,7 @@ def vsplit(
if isinstance(indices_or_sections, IntLike):
split_size = indices_or_sections
torch._check(
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
(split_size != 0 and a.shape[0] % split_size == 0),
lambda: (
f"torch.vsplit attempted to split along dimension 0"
@ -4481,7 +4481,7 @@ def vsplit(
f"!"
),
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return tensor_split(a, split_size, 0)
torch._check_type(
@ -4686,7 +4686,7 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
raise RuntimeError(
f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
raise RuntimeError(
"torch.dsplit attempted to split along dimension 2, "
@ -5460,7 +5460,7 @@ def logspace(
@overload
# pyrefly: ignore # inconsistent-overload
# pyrefly: ignore [inconsistent-overload]
def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
pass
@ -5887,7 +5887,7 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi
# Since `where` allows type-promotion,
# cast value to correct type before passing to `where`
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
value = _maybe_convert_to_dtype(value, a.dtype)
r = torch.where(mask, value, a) # type: ignore[arg-type]
@ -6720,7 +6720,7 @@ def _recursive_build(
# torch.Size([1, 2])
return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
elif isinstance(obj, Number):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch.scalar_tensor(obj, dtype=scalarType)
# seq can be a list of tensors

View File

@ -224,9 +224,9 @@ def matrix_norm(
len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
)
torch._check(
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
dim[0] != dim[1],
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
)
# dtype arg
@ -248,7 +248,7 @@ def matrix_norm(
else: # ord == "nuc"
if dtype is not None:
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
if keepdim:
@ -272,7 +272,7 @@ def matrix_norm(
if abs_ord == 2.0:
if dtype is not None:
A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
perm = _backshift_permutation(dim[0], dim[1], A.ndim)
result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
if keepdim:
@ -280,7 +280,7 @@ def matrix_norm(
result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
return result
else: # 1, -1, inf, -inf
# pyrefly: ignore # bad-unpacking
# pyrefly: ignore [bad-unpacking]
dim0, dim1 = dim
if abs_ord == float("inf"):
dim0, dim1 = dim1, dim0

View File

@ -142,11 +142,11 @@ def _inplace_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
# nb. We use the name of the first argument used in the unary references
@wraps(fn)
def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
a = args[0]
if "inplace" not in kwargs:
kwargs["inplace"] = False
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if kwargs["inplace"]:
torch._check(
"out" not in kwargs,
@ -627,7 +627,7 @@ def smooth_l1_loss(
)
else:
loss = torch.abs(input - target)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
return _apply_loss_reduction(loss, reduction)

View File

@ -155,10 +155,10 @@ def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, Numbe
# Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
if isinstance(a, TensorLike) and isinstance(b, Number):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device)
elif isinstance(b, TensorLike) and isinstance(a, Number):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device)
# mypy: expected "Tensor"

View File

@ -314,7 +314,7 @@ def strobelight(
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return profiler.profile(work_function, *args, **kwargs)
return wrapper_function

View File

@ -37,7 +37,7 @@ class _DeconstructedSymNode:
node.pytype,
node._hint,
node.constant,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
node.fx_node,
)

View File

@ -404,9 +404,9 @@ class FakeTensorConverter:
with no_dispatch():
return FakeTensor(
fake_mode,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
make_meta_t(),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
device,
# TODO: callback might be used in recursive contexts, in
# which case using t is wrong! BUG!
@ -681,7 +681,7 @@ class FakeTensor(Tensor):
_mode_key = torch._C._TorchDispatchModeKey.FAKE
@property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def device(self) -> torch.device:
if self.fake_mode.in_kernel_invocation:
return torch.device("meta")
@ -709,7 +709,7 @@ class FakeTensor(Tensor):
# We don't support named tensors; graph break
@property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def names(self) -> list[str]:
raise UnsupportedFakeTensorException(
"torch.compile doesn't support named tensors"
@ -768,7 +768,7 @@ class FakeTensor(Tensor):
)
else:
device = torch.device(f"{device.type}:0")
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.fake_device = device
self.fake_mode = fake_mode
self.constant = constant
@ -1374,7 +1374,7 @@ class FakeTensorMode(TorchDispatchMode):
return self._stack
@count
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def __torch_dispatch__(
self,
func: OpOverload,
@ -1499,7 +1499,7 @@ class FakeTensorMode(TorchDispatchMode):
# Do this dispatch outside the above except handler so if it
# generates its own exception there won't be a __context__ caused by
# the caching mechanism.
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return self._dispatch_impl(func, types, args, kwargs)
assert state is not None
@ -1517,27 +1517,27 @@ class FakeTensorMode(TorchDispatchMode):
# This represents a negative cache entry - we already saw that the
# output is uncachable. Compute it from first principals.
FakeTensorMode.cache_bypasses[entry.reason] += 1
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return self._dispatch_impl(func, types, args, kwargs)
# We have a cache entry.
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output = self._output_from_cache_entry(state, entry, key, func, args)
FakeTensorMode.cache_hits += 1
if self.cache_crosscheck_enabled:
# For debugging / testing: Validate that the output synthesized
# from the cache matches the output created by normal dispatch.
with disable_fake_tensor_cache(self):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self._crosscheck_cache_output(output, func, types, args, kwargs)
return output
# We don't have a cache entry.
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output = self._dispatch_impl(func, types, args, kwargs)
try:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
entry = self._make_cache_entry(state, key, func, args, kwargs, output)
except _BypassDispatchCache as e:
# We ran "extra" checks on the cache key and determined that it's no
@ -1595,16 +1595,16 @@ class FakeTensorMode(TorchDispatchMode):
if state.known_symbols:
# If there are symbols then include the epoch - this is really more
# of a Shape env var which lives on the FakeTensorMode.
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
key_values.append(self.epoch)
# Collect the id_hashed objects to attach a weakref finalize later
id_hashed_objects: list[object] = []
# Translate any FakeTensor args to metadata.
if args:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self._prep_args_for_hash(key_values, args, state, id_hashed_objects)
if kwargs:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self._prep_args_for_hash(key_values, kwargs, state, id_hashed_objects)
key = _DispatchCacheKey(tuple(key_values))
@ -1922,7 +1922,7 @@ class FakeTensorMode(TorchDispatchMode):
self._validate_output_for_cache_entry(
state,
key,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
func,
args,
kwargs,
@ -1932,7 +1932,7 @@ class FakeTensorMode(TorchDispatchMode):
self._validate_output_for_cache_entry(
state,
key,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
func,
args,
kwargs,
@ -1944,7 +1944,7 @@ class FakeTensorMode(TorchDispatchMode):
self._get_output_info_for_cache_entry(
state,
key,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
func,
args,
kwargs,
@ -1953,7 +1953,7 @@ class FakeTensorMode(TorchDispatchMode):
for out_elem in output
]
return _DispatchCacheValidEntry(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_infos=tuple(output_infos),
is_output_tuple=True,
)
@ -1962,7 +1962,7 @@ class FakeTensorMode(TorchDispatchMode):
output_info = self._get_output_info_for_cache_entry(
state,
key,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
func,
args,
kwargs,
@ -2509,7 +2509,7 @@ class FakeTensorMode(TorchDispatchMode):
)
with self, maybe_ignore_fresh_unbacked_symbols():
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return registered_hop_fake_fns[func](*args, **kwargs)
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
@ -2664,7 +2664,7 @@ class FakeTensorMode(TorchDispatchMode):
# TODO: Is this really needed?
compute_unbacked_bindings(self.shape_env, fake_out, peek=True)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return fake_out
# Try for fastpath
@ -2946,7 +2946,7 @@ class FakeTensorMode(TorchDispatchMode):
self, e, device or common_device
)
else:
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return e
return tree_map(wrap, r)

View File

@ -81,7 +81,7 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return t.grad
@ -416,7 +416,7 @@ class MetaTensorDescriber:
device=t.device,
size=t.size(),
stride=stride,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
storage_offset=storage_offset,
dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
dynamo_hint_overrides=getattr(t, "_dynamo_hint_overrides", {}),
@ -541,7 +541,7 @@ class _FakeTensorViewFunc(ViewFunc["FakeTensor"]):
tensor_visitor_fn: Optional[Callable[[torch.Tensor], FakeTensor]] = None,
) -> FakeTensor:
return torch._subclasses.fake_tensor.FakeTensor._view_func_unsafe(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
t,
new_base,
symint_visitor_fn,
@ -1019,7 +1019,7 @@ class MetaConverter(Generic[_TensorT]):
# Morally, the code here is same as transform_subclass, but we've
# written it from scratch to read EmptyCreateSubclass
outer_size = outer_size if outer_size is not None else t.size
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
outer_stride = outer_stride if outer_stride is not None else t.stride
assert symbolic_context is None or isinstance(
@ -1276,7 +1276,7 @@ class MetaConverter(Generic[_TensorT]):
) -> torch.Tensor:
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
if visited_t is None:
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return None
# NB: visited_t being a Tensor here is very naughty! Should
@ -1407,7 +1407,7 @@ class MetaConverter(Generic[_TensorT]):
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
r = self._backward_error(r)
elif t.is_nested and not t.is_traceable_wrapper_subclass:
# TODO: Handle this better in Dynamo?
@ -1446,7 +1446,7 @@ class MetaConverter(Generic[_TensorT]):
if t.requires_grad:
r.requires_grad = True
if t.requires_grad and not is_leaf:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
r = self._backward_error(r)
elif t.is_functorch_wrapped:
if t.is_view:
@ -1543,7 +1543,7 @@ class MetaConverter(Generic[_TensorT]):
)
assert t.data is not None
_safe_copy(r.real_tensor, t.data) # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return r
r = _to_fake_tensor(t)
@ -1693,7 +1693,7 @@ class MetaConverter(Generic[_TensorT]):
not (t.is_batchedtensor or t.is_gradtrackingtensor)
and t.is_functorch_wrapped
) or t.is_legacy_batchedtensor:
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return NotImplemented
(
@ -1740,7 +1740,7 @@ class MetaConverter(Generic[_TensorT]):
# the metadata of the inner tensor.
# So instead, we now have a dedicated fn to set autograd history,
# without inadvertently changing other metadata.
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
r = self._backward_error(r)
s = t.storage
@ -1820,7 +1820,7 @@ class MetaConverter(Generic[_TensorT]):
# TODO: Use a valid grad-specific symbolic context instead of recycling
# the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
r.grad = self.meta_tensor(
t.grad,
shape_env,
@ -1828,15 +1828,15 @@ class MetaConverter(Generic[_TensorT]):
AttrSource(source, "grad"),
symbolic_context,
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
torch._C._set_conj(r, t.is_conj)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
torch._C._set_neg(r, t.is_neg)
# This can be skipped if necessary for performance reasons
skip_leaf = (
t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
# Thanks to storage resizing, it's possible to end up with a tensor
# that advertises a real size, but has a storage that actually has zero bytes.
@ -1844,23 +1844,23 @@ class MetaConverter(Generic[_TensorT]):
from torch.fx.experimental.symbolic_shapes import guard_or_false
if t.storage is not None and guard_or_false(t.storage.size == 0):
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
r.untyped_storage().resize_(0)
if t.is_parameter:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
r._is_param = True
# See Note: [Creating symbolic nested int]
if t.nested_int is not None:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
assert _is_fake_tensor(r)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
nt_tensor_id=t.nested_int
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.set_tensor_memo(t, r)
return self._checked_get_tensor_memo(t)
@ -1904,13 +1904,13 @@ class MetaConverter(Generic[_TensorT]):
(t._is_view() and t._base is not None and t._base.is_sparse)
):
self.miss += 1
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return NotImplemented
else:
self.hit += 1
elif torch.overrides.is_tensor_like(t):
self.miss += 1
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return NotImplemented
else:
# non-Tensor types don't count as hit or miss

View File

@ -657,10 +657,10 @@ def _str_intern(inp, *, tensor_contents=None):
grad_fn_name = "Invalid"
if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
grad_fn_name = type(grad_fn).__name__
if grad_fn_name == "CppFunction":
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]
if grad_fn_name is not None:

View File

@ -85,7 +85,7 @@ def compile_time_strobelight_meta(
@functools.wraps(function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
if "skip" in kwargs and isinstance(
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
skip := kwargs["skip"],
int,
):
@ -331,7 +331,7 @@ def deprecated():
# public deprecated alias
alias = typing_extensions.deprecated(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
warning_msg,
category=UserWarning,
stacklevel=1,

View File

@ -415,7 +415,7 @@ def _cast(value, device_type: str, dtype: _dtype):
return value
elif HAS_NUMPY and isinstance(
value,
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
np.ndarray,
):
return value

View File

@ -350,7 +350,7 @@ def backward(
Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,)
)
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
tensors = tuple(tensors)
grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
@ -450,12 +450,12 @@ def grad(
Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
)
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
outputs = tuple(outputs)
if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
inputs = tuple(inputs)
t_outputs = tuple(i for i in outputs if is_tensor_like(i))
t_inputs = tuple(i for i in inputs if is_tensor_like(i))

View File

@ -15,14 +15,14 @@ class Type(Function):
"please use `torch.tensor.to(dtype=dtype)` instead.",
category=FutureWarning,
)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, i, dest_type):
ctx.input_type = type(i)
ctx.input_device = -1 if not i.is_cuda else i.get_device()
return i.type(dest_type)
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
if ctx.input_device == -1:
return grad_output.type(ctx.input_type), None
@ -34,7 +34,7 @@ class Type(Function):
# TODO: deprecate this
class Resize(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, tensor, sizes):
ctx.sizes = sizes
ctx.numel = reduce(operator.mul, sizes, 1)
@ -63,7 +63,7 @@ class Resize(Function):
return tensor.contiguous().view(*sizes)
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
if grad_output.numel() != ctx.numel:
raise AssertionError(

View File

@ -415,6 +415,6 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):
def __enter__(self) -> None:
pass
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def __exit__(self, *args) -> None:
torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)

View File

@ -10,7 +10,7 @@ from typing_extensions import deprecated
import torch
import torch.testing
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from torch._vmap_internals import _vmap, vmap
from torch.overrides import is_tensor_like
from torch.types import _TensorOrTensors

View File

@ -230,7 +230,7 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
# Note that output_nr default to 0 which is the right value
# for the AccumulateGrad node.
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return GradientEdge(grad_fn, tensor.output_nr, ownership_token=token)
@ -534,7 +534,7 @@ def register_multi_grad_hook(
"expected this hook to be called inside a backward call"
)
count[id] = count.get(id, 0)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
buffer[id] = buffer.get(id, [None] * len_tensors)
with lock:

View File

@ -744,7 +744,7 @@ class profile:
return all_function_events
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class record_function(_ContextDecorator):
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
Label will only appear if CPU activity tracing is enabled.
@ -792,7 +792,7 @@ class record_function(_ContextDecorator):
# TODO: TorchScript ignores standard type annotation here
# self.record: Optional["torch.classes.profiler._RecordFunction"] = None
self.record = torch.jit.annotate(
# pyrefly: ignore # not-a-type
# pyrefly: ignore [not-a-type]
Optional["torch.classes.profiler._RecordFunction"],
None,
)

View File

@ -101,14 +101,14 @@ class profile:
records = _disable_profiler_legacy()
parsed_results = _parse_legacy_records(records)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.function_events = EventList(
parsed_results,
use_device="cuda" if self.use_cuda else None,
profile_memory=self.profile_memory,
with_flops=self.with_flops,
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.function_events._build_tree()
return False

View File

@ -414,7 +414,7 @@ class _NnapiSerializer:
) # noqa: TRY002
return Operand(
shape=tuple(tensor.shape),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
op_type=op_type,
dim_order=dim_order,
scale=scale,
@ -1735,13 +1735,13 @@ class _NnapiSerializer:
for dim in (2, 3): # h, w indices
if image_oper.shape[dim] == 0:
if size_ctype.kind() != "NoneType":
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
elif scale_ctype.kind() != "NoneType":
self.compute_operand_shape(
out_id,
dim,
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})",
)
else:

View File

@ -34,11 +34,11 @@ if _cudnn is not None:
def _init():
global __cudnn_version
if __cudnn_version is None:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
__cudnn_version = _cudnn.getVersionInt()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
runtime_version = _cudnn.getRuntimeVersion()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
compile_version = _cudnn.getCompileVersion()
runtime_major, runtime_minor, _ = runtime_version
compile_major, compile_minor, _ = compile_version
@ -47,7 +47,7 @@ if _cudnn is not None:
# Not sure about MIOpen (ROCm), so always do a strict check
if runtime_major != compile_major:
cudnn_compatible = False
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
elif runtime_major < 7 or not _cudnn.is_cuda:
cudnn_compatible = runtime_minor == compile_minor
else:

View File

@ -12,16 +12,16 @@ except ImportError:
def get_cudnn_mode(mode):
if mode == "RNN_RELU":
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return int(_cudnn.RNNMode.rnn_relu)
elif mode == "RNN_TANH":
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return int(_cudnn.RNNMode.rnn_tanh)
elif mode == "LSTM":
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return int(_cudnn.RNNMode.lstm)
elif mode == "GRU":
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return int(_cudnn.RNNMode.gru)
else:
raise Exception(f"Unknown mode: {mode}") # noqa: TRY002
@ -60,7 +60,7 @@ def init_dropout_state(dropout, train, dropout_seed, dropout_state):
dropout_p,
train,
dropout_seed,
# pyrefly: ignore # unexpected-keyword
# pyrefly: ignore [unexpected-keyword]
self_ty=torch.uint8,
device=torch.device("cuda"),
)

View File

@ -23,7 +23,7 @@ if _cusparselt is not None:
global __cusparselt_version
global __MAX_ALG_ID
if __cusparselt_version is None:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
__cusparselt_version = _cusparselt.getVersionInt()
if __cusparselt_version == 400:
__MAX_ALG_ID = 4

View File

@ -21,7 +21,7 @@ def is_available():
def is_acl_available():
r"""Return whether PyTorch is built with MKL-DNN + ACL support."""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return torch._C._has_mkldnn_acl

View File

@ -70,7 +70,7 @@ def _set_strategy(_strategy: str) -> None:
def _get_strategy() -> str:
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return strategy

View File

@ -835,7 +835,7 @@ def create_args(parser=None):
@retval ArgumentParser
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
parser.add_argument(
"--multi-instance",
"--multi_instance",
@ -844,7 +844,7 @@ def create_args(parser=None):
help="Enable multi-instance, by default one instance per node",
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
parser.add_argument(
"-m",
"--module",
@ -855,7 +855,7 @@ def create_args(parser=None):
'"python -m".',
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
parser.add_argument(
"--no-python",
"--no_python",
@ -870,7 +870,7 @@ def create_args(parser=None):
_add_multi_instance_params(parser)
# positional
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
parser.add_argument(
"program",
type=str,
@ -879,7 +879,7 @@ def create_args(parser=None):
)
# rest from the training program
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
parser.add_argument("program_args", nargs=REMAINDER)

View File

@ -90,7 +90,7 @@ class CacheArtifactFactory:
@classmethod
def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact:
artifact_cls = cls._get_artifact_type(artifact_type_key)
# pyrefly: ignore # bad-instantiation
# pyrefly: ignore [bad-instantiation]
return artifact_cls(key, content)
@classmethod
@ -98,7 +98,7 @@ class CacheArtifactFactory:
cls, artifact_type_key: str, key: str, content: Any
) -> CacheArtifact:
artifact_cls = cls._get_artifact_type(artifact_type_key)
# pyrefly: ignore # bad-instantiation
# pyrefly: ignore [bad-instantiation]
return artifact_cls(key, artifact_cls.encode(content))

View File

@ -1,3 +1,3 @@
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from .autocast_mode import autocast
from .grad_scaler import GradScaler

View File

@ -501,14 +501,14 @@ class cudaStatus:
class CudaError(RuntimeError):
def __init__(self, code: int) -> None:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
super().__init__(f"{msg} ({code})")
def check_error(res: int) -> None:
r"""Raise an error if the result of a CUDA runtime API call is not success."""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if res != _cudart.cudaError.success:
raise CudaError(res)
@ -608,7 +608,7 @@ def get_device_capability(device: "Device" = None) -> tuple[int, int]:
return prop.major, prop.minor
# pyrefly: ignore # not-a-type
# pyrefly: ignore [not-a-type]
def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties:
r"""Get the properties of a device.
@ -659,7 +659,7 @@ class StreamContext:
self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting():
if self.idx is None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.idx = -1
self.src_prev_stream = (
@ -964,9 +964,9 @@ def _device_count_amdsmi() -> int:
if raw_cnt <= 0:
return raw_cnt
# Trim the list up to a maximum available device
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
for idx, val in enumerate(visible_devices):
# pyrefly: ignore # redundant-cast
# pyrefly: ignore [redundant-cast]
if cast(int, val) >= raw_cnt:
return idx
except OSError:
@ -1000,9 +1000,9 @@ def _device_count_nvml() -> int:
if raw_cnt <= 0:
return raw_cnt
# Trim the list up to a maximum available device
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
for idx, val in enumerate(visible_devices):
# pyrefly: ignore # redundant-cast
# pyrefly: ignore [redundant-cast]
if cast(int, val) >= raw_cnt:
return idx
except OSError:
@ -1218,9 +1218,9 @@ def _get_pynvml_handler(device: "Device" = None):
if not _HAS_PYNVML:
raise ModuleNotFoundError(
"nvidia-ml-py does not seem to be installed or it can't be imported."
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
) from _PYNVML_ERR
# pyrefly: ignore # import-error
# pyrefly: ignore [import-error]
from pynvml import NVMLError_DriverNotLoaded
try:
@ -1237,7 +1237,7 @@ def _get_amdsmi_handler(device: "Device" = None):
if not _HAS_PYNVML:
raise ModuleNotFoundError(
"amdsmi does not seem to be installed or it can't be imported."
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
) from _PYNVML_ERR
try:
amdsmi.amdsmi_init()
@ -1501,7 +1501,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int
return default_generator.get_offset()
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from .memory import * # noqa: F403
from .random import * # noqa: F403
@ -1718,7 +1718,7 @@ def _register_triton_kernels():
def kernel_impl(*args, **kwargs):
from torch.sparse._triton_ops import bsr_dense_mm
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
return bsr_dense_mm(*args, skip_checks=True, **kwargs)
@_WrappedTritonKernel

View File

@ -279,7 +279,7 @@ class _CudaModule:
return self._kernels[name]
# Import the CUDA library inside the method
# pyrefly: ignore # missing-module-attribute
# pyrefly: ignore [missing-module-attribute]
from torch.cuda._utils import _get_gpu_runtime_library
libcuda = _get_gpu_runtime_library()

View File

@ -1,4 +1,4 @@
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from .autocast_mode import autocast, custom_bwd, custom_fwd
from .common import amp_definitely_not_available
from .grad_scaler import GradScaler

View File

@ -259,7 +259,7 @@ class graph:
self.cuda_graph.capture_begin(
# type: ignore[misc]
*self.pool,
# pyrefly: ignore # bad-keyword-argument
# pyrefly: ignore [bad-keyword-argument]
capture_error_mode=self.capture_error_mode,
)
@ -525,7 +525,7 @@ def make_graphed_callables(
) -> Callable[..., object]:
class Graphed(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
# At this stage, only the user args may (potentially) be new tensors.
for i in range(len_user_args):
@ -537,7 +537,7 @@ def make_graphed_callables(
@staticmethod
@torch.autograd.function.once_differentiable
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
assert len(grads) == len(static_grad_outputs)
for g, grad in zip(static_grad_outputs, grads):
@ -551,7 +551,7 @@ def make_graphed_callables(
# Input args that didn't require grad expect a None gradient.
assert isinstance(static_grad_inputs, tuple)
return tuple(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
b.detach() if b is not None else b
for b in static_grad_inputs
)

View File

@ -10,7 +10,7 @@ if hasattr(torch._C, "_CUDAGreenContext"):
# Python shim helps Sphinx process docstrings more reliably.
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class GreenContext(_GreenContext):
r"""Wrapper around a CUDA green context.

View File

@ -772,7 +772,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
import pynvml # type: ignore[import]
except ModuleNotFoundError:
return "pynvml module not found, please install nvidia-ml-py"
# pyrefly: ignore # import-error
# pyrefly: ignore [import-error]
from pynvml import NVMLError_DriverNotLoaded
try:
@ -855,7 +855,7 @@ def _record_memory_history_legacy(
_C._cuda_record_memory_history_legacy( # type: ignore[call-arg]
enabled,
record_context,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
trace_alloc_max_entries,
trace_alloc_record_context,
record_context_cpp,
@ -1076,7 +1076,7 @@ def _set_memory_metadata(metadata: str):
metadata (str): Custom metadata string to attach to allocations.
Pass an empty string to clear the metadata.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
torch._C._cuda_setMemoryMetadata(metadata)
@ -1087,7 +1087,7 @@ def _get_memory_metadata() -> str:
Returns:
str: The current metadata string, or empty string if no metadata is set.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return torch._C._cuda_getMemoryMetadata()
@ -1110,7 +1110,7 @@ def _save_memory_usage(filename="output.svg", snapshot=None):
category=FutureWarning,
)
def _set_allocator_settings(env: str):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return torch._C._accelerator_setAllocatorSettings(env)

View File

@ -53,7 +53,7 @@ def range_start(msg) -> int:
Args:
msg (str): ASCII message to associate with the range.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return _nvtx.rangeStartA(msg)
@ -64,7 +64,7 @@ def range_end(range_id) -> None:
Args:
range_id (int): an unique handle for the start range.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
_nvtx.rangeEnd(range_id)
@ -85,7 +85,7 @@ def _device_range_start(msg: str, stream: int = 0) -> object:
msg (str): ASCII message to associate with the range.
stream (int): CUDA stream id.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return _nvtx.deviceRangeStart(msg, stream)
@ -98,7 +98,7 @@ def _device_range_end(range_handle: object, stream: int = 0) -> None:
range_handle: an unique handle for the start range.
stream (int): CUDA stream id.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
_nvtx.deviceRangeEnd(range_handle, stream)

View File

@ -295,7 +295,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
self.logger.addHandler(self)
self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
torch._logging._internal.GET_DTRACE_STRUCTURED = True
return self
@ -303,7 +303,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
self.log_record = LogRecord()
self.expression_created_logs = {}
self.logger.removeHandler(self)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
self.prev_get_dtrace = False

View File

@ -107,11 +107,11 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
return
if not (
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
arg.op == "call_function"
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
and arg.target == operator.getitem
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
and arg.args[1] == i
):
log.debug(

View File

@ -186,7 +186,7 @@ def _ignore_backend_decomps():
def _disable_custom_triton_op_functional_decomposition():
old = torch._functorch.config.decompose_custom_triton_ops
try:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
torch._functorch.config.decompose_custom_triton_ops = False
yield torch._functorch.config.decompose_custom_triton_ops
finally:
@ -365,7 +365,7 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
nn_module_stack = {
root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
**nn_module_stack,
}
node.meta["nn_module_stack"] = {
@ -687,7 +687,7 @@ def _restore_state_dict(
for name, _ in list(
chain(
original_module.named_parameters(remove_duplicate=False),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
original_module.named_buffers(remove_duplicate=False),
)
):

View File

@ -293,7 +293,7 @@ class _ExportPackage:
if isinstance(fn, torch.nn.Module):
dynamic_shapes = v(fn, *args, **kwargs) # type: ignore[arg-type]
else:
# pyrefly: ignore # invalid-param-spec
# pyrefly: ignore [invalid-param-spec]
dynamic_shapes = v(*args, **kwargs)
except AssertionError:
continue
@ -341,7 +341,7 @@ class _ExportPackage:
assert not hasattr(fn, "_define_overload")
_exporter_context._define_overload = _define_overload # type: ignore[attr-defined]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return _exporter_context
@property
@ -378,7 +378,7 @@ class _ExportPackage:
kwargs=ep.example_inputs[1],
options=options,
)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
aoti_files_map[name] = aoti_files
from torch._inductor.package import package

View File

@ -1500,7 +1500,7 @@ class ExportedProgram:
transformed_gm = res.graph_module if res is not None else self.graph_module
assert transformed_gm is not None
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if transformed_gm is self.graph_module and not res.modified:
return self
@ -1579,7 +1579,7 @@ class ExportedProgram:
verifiers=self.verifiers,
)
transformed_ep.graph_module.meta.update(self.graph_module.meta)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
transformed_ep.graph_module.meta.update(res.graph_module.meta)
return transformed_ep

View File

@ -81,7 +81,7 @@ def move_to_device_pass(
and node.target == torch.ops.aten.to.device
):
args = list(node.args)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
args[1] = _get_new_device(args[1], location)
node.args = tuple(args)

View File

@ -173,10 +173,10 @@ class PT2ArchiveWriter:
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
)
for file_path in file_paths:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
filename = os.path.relpath(file_path, folder_dir)
archive_path = os.path.join(archive_dir, filename)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.write_file(archive_path, file_path)
def close(self) -> None:
@ -696,7 +696,7 @@ def package_pt2(
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
with PT2ArchiveWriter(f) as archive_writer:
_package_exported_programs(
archive_writer, exported_programs, pickle_protocol=pickle_protocol
@ -711,7 +711,7 @@ def package_pt2(
if isinstance(f, (io.IOBase, IO)):
f.seek(0)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return f
@ -1098,7 +1098,7 @@ def load_pt2(
weights = {}
weight_maps = {}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
with PT2ArchiveReader(f) as archive_reader:
version = archive_reader.read_string(ARCHIVE_VERSION_PATH)
if version != ARCHIVE_VERSION_VALUE:

View File

@ -25,15 +25,15 @@ class TensorProperties:
if not self.is_fake:
# only get the storage pointer for real tensors
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.storage_ptr = tensor.untyped_storage().data_ptr()
if self.is_contiguous:
# only get storage size and start/end pointers for contiguous tensors
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.storage_size = tensor.untyped_storage().nbytes()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.start = tensor.data_ptr()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.end = _end_ptr(tensor)
# info to recover tensor

View File

@ -67,7 +67,7 @@ class GraphPickler(pickle.Pickler):
self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)
@override
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def reducer_override(
self, obj: object
) -> tuple[Callable[..., Any], tuple[Any, ...]]:
@ -204,7 +204,7 @@ class _SymNodePickleData:
]:
args = (cls(obj.node), pickler._unpickle_state)
if isinstance(obj, torch.SymInt):
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return _SymNodePickleData.unpickle_sym_int, args
else:
raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
@ -281,7 +281,7 @@ class _TensorPickleData:
return FakeTensor(
unpickle_state.fake_mode,
make_meta_t(),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
device,
)
@ -334,9 +334,9 @@ class _TorchNumpyPickleData:
if not (name := getattr(np, "__name__", None)):
return None
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
assert np == getattr(importlib.import_module(mod), name)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return cls(mod, name)

View File

@ -603,7 +603,7 @@ class Tracer(TracerBase):
in inspect.signature(self.create_proxy).parameters
):
kwargs["proxy_factory_fn"] = (
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
None
if not self.param_shapes_constant
else lambda node: ParameterProxy(

View File

@ -659,7 +659,7 @@ class Partitioner:
find_combination, partitions = find_partition_to_combine_based_on_size(
sorted_partitions,
available_mem_bytes,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
partitions,
)
return
@ -705,7 +705,7 @@ class Partitioner:
non_embedding_partitions.append(partition)
if new_partition:
partition = self.create_partition()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
partition.left_mem_bytes = available_mem_bytes
return partition
return None
@ -1001,7 +1001,7 @@ class Partitioner:
node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
)
if cost < min_cost:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
node_pair = [node, n1]
min_cost = cost
return cost, node_pair # type: ignore[possibly-undefined]

View File

@ -30,7 +30,7 @@ def split_result_tensors(
else:
splits = [x.shape[0] for x in inputs]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch.split(result, splits)

View File

@ -178,7 +178,7 @@ class MetaTracer(torch.fx.Tracer):
kwargs,
name,
type_expr,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
proxy_factory_fn,
)
@ -201,7 +201,7 @@ class MetaTracer(torch.fx.Tracer):
if kind == "call_function":
meta_target = manual_meta_overrides.get(target, target)
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_method":
meta_target = getattr(args_metas[0], target) # type: ignore[index]

View File

@ -528,11 +528,11 @@ def view_inference_rule(n: Node, symbols, constraints, counter):
if t == -1:
var, counter = gen_dvar(counter)
t2_type.append(var)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
num_constraints.append(BinConstraintD(var, Dyn, op_neq))
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
num_constraints.append(BinConstraintD(t, Dyn, op_neq))
t2_type.append(t) # type: ignore[arg-type]
@ -1477,7 +1477,7 @@ class ConstraintGenerator:
all_constraints = []
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for n in graph.nodes:
(constraints, counter) = self.generate_constraints_node(n, counter)
all_constraints += constraints

View File

@ -193,7 +193,7 @@ def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in mkldnn_map:
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
assert isinstance(new_module, nn.Module)
old_modules[new_module] = copy.deepcopy(cur_module)
@ -266,7 +266,7 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
reset_modules(
submodule.graph.nodes,
dict(submodule.named_modules()),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
old_modules,
)
no_mkl_time = benchmark(lambda: submodule(*sample_inputs))

View File

@ -124,7 +124,7 @@ pytree.register_pytree_node(
torch.Size,
lambda xs: (list(xs), None),
lambda xs, _: tuple(xs),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
flatten_with_keys_fn=lambda xs: (
[(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
None,
@ -310,7 +310,7 @@ def set_proxy_slot( # type: ignore[no-redef]
def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
assert isinstance(obj, (Tensor, SymNode)), type(obj)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
@ -407,7 +407,7 @@ def get_proxy_slot(
assert isinstance(obj, py_sym_types), type(obj)
tracker = tracer.symnode_tracker
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
value = tracker.get(obj)
if value is None and isinstance(obj, py_sym_types):
@ -420,7 +420,7 @@ def get_proxy_slot(
else:
# Attempt to build it from first principles.
_build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
value = tracker.get(obj)
if value is None:
@ -1552,7 +1552,7 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
kwargs: Optional[dict[str, object]] = None,
) -> object:
kwargs = kwargs or {}
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.tracer.torch_fn_metadata = func
self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1
return func(*args, **kwargs)
@ -1602,7 +1602,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
# For autocast, the python APIs run so we don't have to run them again
# here.
if func is torch._C._set_grad_enabled:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
func(*args, **kwargs)
return node
@ -1821,7 +1821,7 @@ class DecompositionInterpreter(fx.Interpreter):
self.decomposition_table = decomposition_table or {}
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def placeholder(
self,
target: str, # type: ignore[override]
@ -1834,7 +1834,7 @@ class DecompositionInterpreter(fx.Interpreter):
# TODO handle case where the first character of target is '*'
return out
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def get_attr(
self,
target: str, # type: ignore[override]
@ -1848,7 +1848,7 @@ class DecompositionInterpreter(fx.Interpreter):
# call_function, call_method, call_module get traced automatically by the outer mode.
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def output(
self,
target: str, # type: ignore[override]
@ -1967,7 +1967,7 @@ class _ModuleStackTracer(PythonKeyTracer):
# Class is modified to be a subclass of torch.nn.Module
# Warning: We blow away our own attributes here to mimic the base class
# - so don't expect `self.x` to do anything useful.
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
self.__class__ = type(
base.__class__.__name__,
(self.__class__, base.__class__),
@ -1990,7 +1990,7 @@ class _ModuleStackTracer(PythonKeyTracer):
if not isinstance(attr_val, Module):
return attr_val
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name)
def get_base(self) -> Module:
@ -2003,12 +2003,12 @@ class _ModuleStackTracer(PythonKeyTracer):
res = torch.nn.Sequential(
OrderedDict(list(self._modules.items())[idx])
)
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
elif isinstance(self, torch.nn.ModuleList):
# Copied from nn/modules/container.py
res = torch.nn.ModuleList(list(self._modules.values())[idx])
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
return super().__getitem__(idx) # type: ignore[misc]