Fix pyrefly error syntax (2/n) (#166448)

Ensures pyrefly ignores silence only one error code. After this, only ~40 files left to clean up.

pyrefly check lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166448
Approved by: https://github.com/Skylion007
This commit is contained in:
parent 56afad4eb3
commit 84fe848503
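For illustration only, a minimal sketch of the two suppression spellings this change standardizes. The `scale` function and its call sites below are invented and are not part of the PR; the behavioral difference described in the comments follows the summary above.

def scale(values: list[float], factor: int) -> list[float]:
    return [v * factor for v in values]

# Old spelling: the trailing "# bad-argument-type" is ordinary comment text,
# so (per the PR summary) the ignore is not limited to a single error code.
# pyrefly: ignore # bad-argument-type
a = scale([1.0, 2.0], 2.5)  # float passed where the annotation asks for int

# New spelling adopted throughout this diff: only bad-argument-type is
# silenced; any other pyrefly error on this line would still be reported.
# pyrefly: ignore [bad-argument-type]
b = scale([1.0, 2.0], 2.5)

The `pyrefly check` and `lintrunner` commands listed above are presumably how a change like this is validated.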
@@ -880,14 +880,14 @@ def logsumexp(
     if not isinstance(dim, Iterable):
         dim = (dim,)
     if self.numel() == 0:
-        # pyrefly: ignore # no-matching-overload
+        # pyrefly: ignore [no-matching-overload]
         return torch.sum(torch.exp(self), dim, keepdim).log()
-    # pyrefly: ignore # bad-argument-type
+    # pyrefly: ignore [bad-argument-type]
     maxes = torch.amax(torch.real(self), dim, keepdim=True)
     maxes = torch.masked_fill(maxes, maxes.abs() == float("inf"), 0)
-    # pyrefly: ignore # no-matching-overload
+    # pyrefly: ignore [no-matching-overload]
     maxes_squeezed = maxes if keepdim else torch.squeeze(maxes, dim)
-    # pyrefly: ignore # no-matching-overload
+    # pyrefly: ignore [no-matching-overload]
     result = torch.sum(torch.exp(self - maxes), dim, keepdim)
     return result.log().add(maxes_squeezed)

@@ -1245,12 +1245,12 @@ def copysign(
     a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberType]
 ):
     if isinstance(b, Number) and isinstance(a, Tensor):
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         b = scalar_tensor(b, dtype=a.dtype, device=a.device)
     elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
         msg = f"Expected divisor (b) to be on the same device ({a.device}) as dividend (a), but it is found on {b.device}!"
         raise RuntimeError(msg)
-    # pyrefly: ignore # bad-argument-type
+    # pyrefly: ignore [bad-argument-type]
     return where(signbit(b), neg(abs(a)), abs(a))

@@ -1342,7 +1342,7 @@ def float_power(
     b = _maybe_convert_to_dtype(b, dtype)

     a, b = _maybe_broadcast(a, b)
-    # pyrefly: ignore # bad-return
+    # pyrefly: ignore [bad-return]
     return pow(a, b)

@@ -1384,15 +1384,15 @@ def floor_divide(
 ):
     # Wrap scalars because some references only accept tensor arguments.
     if isinstance(a, Number) and isinstance(b, Number):
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         a = scalar_tensor(a)
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         b = scalar_tensor(b)
     elif isinstance(b, Number) and isinstance(a, Tensor):
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         b = scalar_tensor(b, dtype=a.dtype, device=a.device)
     elif isinstance(a, Number) and isinstance(b, Tensor):
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         a = scalar_tensor(a, dtype=b.dtype, device=b.device)
     elif isinstance(a, Tensor) and isinstance(b, Tensor) and a.device != b.device:
         if a.device == torch.device("cpu"):

@@ -1869,10 +1869,10 @@ def xlogy(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, NumberT

     # Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
     if isinstance(b, TensorLike) and isinstance(a, Number):
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         a = scalar_tensor(a, dtype=b.dtype, device=b.device)
     elif isinstance(a, TensorLike) and isinstance(b, Number):
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         b = scalar_tensor(b, dtype=a.dtype, device=a.device)

     # mypy: expected "Tensor"
@@ -2865,7 +2865,7 @@ def cat(tensors: TensorSequenceType, dim: int = 0) -> TensorLikeType:
     # SymInts

     example = None
-    # pyrefly: ignore # bad-assignment
+    # pyrefly: ignore [bad-assignment]
     for i, t in enumerate(tensors):
         if example is None:
             if t.ndim != 1:

@@ -3358,7 +3358,7 @@ def native_layer_norm(
     # while torch.Size([1, 2, 3]) == (1, 2, 3) is True
     # therefore we use tuple(normalized_shape)
     torch._check(
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         weight is None or sym_eq(weight.shape, tuple(normalized_shape)),
         lambda: "Expected weight to be of same shape as normalized_shape, but got "
         + "weight of shape "

@@ -3367,7 +3367,7 @@ def native_layer_norm(
         + str(normalized_shape),
     )
     torch._check(
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         bias is None or sym_eq(bias.shape, tuple(normalized_shape)),
         lambda: "Expected bias to be of same shape as normalized_shape, but got "
         + "bias of shape "

@@ -3379,7 +3379,7 @@ def native_layer_norm(
         input.ndim >= normalized_ndim
         and sym_eq(
             input.shape[(input.ndim - normalized_ndim) :],
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             tuple(normalized_shape),
         ),
         lambda: "Given normalized_shape="

@@ -3988,16 +3988,16 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike
         # Keeping this as ref for now as FakeTensor runs into some issues with complex tensors
         return a.clone()

-    # pyrefly: ignore # bad-argument-type
+    # pyrefly: ignore [bad-argument-type]
     if a.dim() == 0 and len(dims) > 0:
         raise IndexError(
-            # pyrefly: ignore # index-error
+            # pyrefly: ignore [index-error]
             f"Dimension specified as {dims[0]} but tensor has no dimensions"
         )

-    # pyrefly: ignore # bad-argument-type
+    # pyrefly: ignore [bad-argument-type]
     len_shifts = len(shifts)
-    # pyrefly: ignore # bad-argument-type
+    # pyrefly: ignore [bad-argument-type]
     len_dims = len(dims)
     if len_shifts != 1 or len_dims != 1:
         if len_shifts == 0:

@@ -4005,27 +4005,27 @@ def roll(a: TensorLikeType, shifts: DimsType, dims: DimsType = ()) -> TensorLike
         # Takes care of the case when dims is not specified (default)
         # By default, the tensor is flattened before shifting, after which the original shape is restored
         if len_dims == 0 and len_shifts == 1:
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             return torch.roll(torch.flatten(a), shifts, 0).view(a.shape)
         if len_shifts != len_dims:
             raise RuntimeError(
                 f"shifts and dimensions must align. shifts: {len_shifts}, dims: {len_dims}"
             )
         assert len_dims > 1
-        # pyrefly: ignore # index-error
+        # pyrefly: ignore [index-error]
         tail_shifts = shifts[1:]
-        # pyrefly: ignore # index-error
+        # pyrefly: ignore [index-error]
         tail_dims = dims[1:]
-        # pyrefly: ignore # index-error
+        # pyrefly: ignore [index-error]
         first_dim_rolled = torch.roll(a, (shifts[0],), dims[0])
         return torch.roll(first_dim_rolled, tail_shifts, tail_dims)

     # This path is taken when only one dimension is rolled
     # For example to get `first_dim_rolled` above
-    # pyrefly: ignore # index-error
+    # pyrefly: ignore [index-error]
     dim = dims[0]
     size = a.shape[dim]
-    # pyrefly: ignore # index-error
+    # pyrefly: ignore [index-error]
     start = (size - shifts[0]) % size
     idx = torch.arange(size, device=a.device)
     return a.index_select(dim, torch.fmod(start + idx, size))
@@ -4107,7 +4107,7 @@ def softmax(
     a_max = amax(a_, dim, keepdim=True)
     a_exp = exp(a_ - a_max)
     return _maybe_convert_to_dtype(
-        # pyrefly: ignore # no-matching-overload
+        # pyrefly: ignore [no-matching-overload]
         true_divide(a_exp, sum(a_exp, dim, keepdim=True)),
         result_dtype,
     ) # type: ignore[return-value]

@@ -4427,7 +4427,7 @@ def hsplit(
     if isinstance(indices_or_sections, IntLike):
         split_size = indices_or_sections
         torch._check(
-            # pyrefly: ignore # unsupported-operation
+            # pyrefly: ignore [unsupported-operation]
             (split_size != 0 and a.shape[dim] % split_size == 0),
             lambda: (
                 "torch.hsplit attempted to split along dimension "

@@ -4439,7 +4439,7 @@ def hsplit(
                 + "!"
             ),
         )
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         return tensor_split(a, split_size, dim)

     torch._check_type(

@@ -4470,7 +4470,7 @@ def vsplit(
     if isinstance(indices_or_sections, IntLike):
         split_size = indices_or_sections
         torch._check(
-            # pyrefly: ignore # unsupported-operation
+            # pyrefly: ignore [unsupported-operation]
             (split_size != 0 and a.shape[0] % split_size == 0),
             lambda: (
                 f"torch.vsplit attempted to split along dimension 0"

@@ -4481,7 +4481,7 @@ def vsplit(
                 f"!"
             ),
         )
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         return tensor_split(a, split_size, 0)

     torch._check_type(

@@ -4686,7 +4686,7 @@ def dsplit(a: TensorLikeType, sections: DimsType) -> TensorSequenceType:
         raise RuntimeError(
             f"torch.dsplit requires a tensor with at least 3 dimension, but got a tensor with {a.ndim} dimensions!"
         )
-    # pyrefly: ignore # unsupported-operation
+    # pyrefly: ignore [unsupported-operation]
     if isinstance(sections, IntLike) and (sections == 0 or a.shape[2] % sections != 0):
         raise RuntimeError(
             "torch.dsplit attempted to split along dimension 2, "

@@ -5460,7 +5460,7 @@ def logspace(


 @overload
-# pyrefly: ignore # inconsistent-overload
+# pyrefly: ignore [inconsistent-overload]
 def meshgrid(tensors: Sequence[TensorLikeType], indexing: str):
     pass

@@ -5887,7 +5887,7 @@ def masked_fill(a: TensorLikeType, mask: TensorLikeType, value: TensorOrNumberLi

     # Since `where` allows type-promotion,
     # cast value to correct type before passing to `where`
-    # pyrefly: ignore # no-matching-overload
+    # pyrefly: ignore [no-matching-overload]
     value = _maybe_convert_to_dtype(value, a.dtype)
     r = torch.where(mask, value, a) # type: ignore[arg-type]

@@ -6720,7 +6720,7 @@ def _recursive_build(
         # torch.Size([1, 2])
         return obj.detach().to(dtype=scalarType, device="cpu", copy=True)
     elif isinstance(obj, Number):
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         return torch.scalar_tensor(obj, dtype=scalarType)

     # seq can be a list of tensors
@@ -224,9 +224,9 @@ def matrix_norm(
         len(dim) == 2, lambda: f"linalg.matrix_norm: dim must be a 2-tuple. Got {dim}"
     )
     torch._check(
-        # pyrefly: ignore # index-error
+        # pyrefly: ignore [index-error]
         dim[0] != dim[1],
-        # pyrefly: ignore # index-error
+        # pyrefly: ignore [index-error]
         lambda: f"linalg.matrix_norm: dims must be different. Got ({dim[0]}, {dim[1]})",
     )
     # dtype arg

@@ -248,7 +248,7 @@ def matrix_norm(
         else: # ord == "nuc"
             if dtype is not None:
                 A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
-            # pyrefly: ignore # index-error
+            # pyrefly: ignore [index-error]
             perm = _backshift_permutation(dim[0], dim[1], A.ndim)
             result = torch.sum(svdvals(prims.transpose(A, perm)), -1, keepdim)
             if keepdim:

@@ -272,7 +272,7 @@ def matrix_norm(
         if abs_ord == 2.0:
             if dtype is not None:
                 A = _maybe_convert_to_dtype(A, dtype) # type: ignore[assignment]
-            # pyrefly: ignore # index-error
+            # pyrefly: ignore [index-error]
             perm = _backshift_permutation(dim[0], dim[1], A.ndim)
             result = max_min(svdvals(prims.transpose(A, perm)), dim=-1)
             if keepdim:

@@ -280,7 +280,7 @@ def matrix_norm(
                 result = prims.transpose(torch.unsqueeze(result, -1), inv_perm)
             return result
         else: # 1, -1, inf, -inf
-            # pyrefly: ignore # bad-unpacking
+            # pyrefly: ignore [bad-unpacking]
             dim0, dim1 = dim
             if abs_ord == float("inf"):
                 dim0, dim1 = dim1, dim0
@ -142,11 +142,11 @@ def _inplace_wrapper(fn: Callable[_P, _T]) -> Callable[_P, _T]:
|
|||
# nb. We use the name of the first argument used in the unary references
|
||||
@wraps(fn)
|
||||
def _fn(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
a = args[0]
|
||||
if "inplace" not in kwargs:
|
||||
kwargs["inplace"] = False
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
if kwargs["inplace"]:
|
||||
torch._check(
|
||||
"out" not in kwargs,
|
||||
|
|
@ -627,7 +627,7 @@ def smooth_l1_loss(
|
|||
)
|
||||
else:
|
||||
loss = torch.abs(input - target)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
loss = torch.where(loss < beta, 0.5 * loss**2 / beta, loss - 0.5 * beta)
|
||||
return _apply_loss_reduction(loss, reduction)
|
||||
|
||||
|
|
|
|||
|
|
@ -155,10 +155,10 @@ def xlog1py(a: Union[TensorLikeType, NumberType], b: Union[TensorLikeType, Numbe
|
|||
|
||||
# Operations like eq and log do not handle scalar values, so we convert them to scalar_tensors.
|
||||
if isinstance(a, TensorLike) and isinstance(b, Number):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
b = refs.scalar_tensor(b, dtype=a.dtype, device=a.device)
|
||||
elif isinstance(b, TensorLike) and isinstance(a, Number):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
a = refs.scalar_tensor(a, dtype=b.dtype, device=b.device)
|
||||
|
||||
# mypy: expected "Tensor"
|
||||
|
|
|
|||
|
|
@ -314,7 +314,7 @@ def strobelight(
|
|||
) -> Callable[_P, Optional[_R]]:
|
||||
@functools.wraps(work_function)
|
||||
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
return profiler.profile(work_function, *args, **kwargs)
|
||||
|
||||
return wrapper_function
|
||||
|
|
|
|||
|
|
@ -37,7 +37,7 @@ class _DeconstructedSymNode:
|
|||
node.pytype,
|
||||
node._hint,
|
||||
node.constant,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
node.fx_node,
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -404,9 +404,9 @@ class FakeTensorConverter:
|
|||
with no_dispatch():
|
||||
return FakeTensor(
|
||||
fake_mode,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
make_meta_t(),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
device,
|
||||
# TODO: callback might be used in recursive contexts, in
|
||||
# which case using t is wrong! BUG!
|
||||
|
|
@ -681,7 +681,7 @@ class FakeTensor(Tensor):
|
|||
_mode_key = torch._C._TorchDispatchModeKey.FAKE
|
||||
|
||||
@property
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def device(self) -> torch.device:
|
||||
if self.fake_mode.in_kernel_invocation:
|
||||
return torch.device("meta")
|
||||
|
|
@ -709,7 +709,7 @@ class FakeTensor(Tensor):
|
|||
|
||||
# We don't support named tensors; graph break
|
||||
@property
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def names(self) -> list[str]:
|
||||
raise UnsupportedFakeTensorException(
|
||||
"torch.compile doesn't support named tensors"
|
||||
|
|
@ -768,7 +768,7 @@ class FakeTensor(Tensor):
|
|||
)
|
||||
else:
|
||||
device = torch.device(f"{device.type}:0")
|
||||
# pyrefly: ignore # read-only
|
||||
# pyrefly: ignore [read-only]
|
||||
self.fake_device = device
|
||||
self.fake_mode = fake_mode
|
||||
self.constant = constant
|
||||
|
|
@ -1374,7 +1374,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
return self._stack
|
||||
|
||||
@count
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def __torch_dispatch__(
|
||||
self,
|
||||
func: OpOverload,
|
||||
|
|
@ -1499,7 +1499,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
# Do this dispatch outside the above except handler so if it
|
||||
# generates its own exception there won't be a __context__ caused by
|
||||
# the caching mechanism.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
return self._dispatch_impl(func, types, args, kwargs)
|
||||
|
||||
assert state is not None
|
||||
|
|
@ -1517,27 +1517,27 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
# This represents a negative cache entry - we already saw that the
|
||||
# output is uncachable. Compute it from first principals.
|
||||
FakeTensorMode.cache_bypasses[entry.reason] += 1
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
return self._dispatch_impl(func, types, args, kwargs)
|
||||
|
||||
# We have a cache entry.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
output = self._output_from_cache_entry(state, entry, key, func, args)
|
||||
FakeTensorMode.cache_hits += 1
|
||||
if self.cache_crosscheck_enabled:
|
||||
# For debugging / testing: Validate that the output synthesized
|
||||
# from the cache matches the output created by normal dispatch.
|
||||
with disable_fake_tensor_cache(self):
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
self._crosscheck_cache_output(output, func, types, args, kwargs)
|
||||
return output
|
||||
|
||||
# We don't have a cache entry.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
output = self._dispatch_impl(func, types, args, kwargs)
|
||||
|
||||
try:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
entry = self._make_cache_entry(state, key, func, args, kwargs, output)
|
||||
except _BypassDispatchCache as e:
|
||||
# We ran "extra" checks on the cache key and determined that it's no
|
||||
|
|
@ -1595,16 +1595,16 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
if state.known_symbols:
|
||||
# If there are symbols then include the epoch - this is really more
|
||||
# of a Shape env var which lives on the FakeTensorMode.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
key_values.append(self.epoch)
|
||||
# Collect the id_hashed objects to attach a weakref finalize later
|
||||
id_hashed_objects: list[object] = []
|
||||
# Translate any FakeTensor args to metadata.
|
||||
if args:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
self._prep_args_for_hash(key_values, args, state, id_hashed_objects)
|
||||
if kwargs:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
self._prep_args_for_hash(key_values, kwargs, state, id_hashed_objects)
|
||||
key = _DispatchCacheKey(tuple(key_values))
|
||||
|
||||
|
|
@ -1922,7 +1922,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
self._validate_output_for_cache_entry(
|
||||
state,
|
||||
key,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
|
|
@ -1932,7 +1932,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
self._validate_output_for_cache_entry(
|
||||
state,
|
||||
key,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
|
|
@ -1944,7 +1944,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
self._get_output_info_for_cache_entry(
|
||||
state,
|
||||
key,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
|
|
@ -1953,7 +1953,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
for out_elem in output
|
||||
]
|
||||
return _DispatchCacheValidEntry(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
output_infos=tuple(output_infos),
|
||||
is_output_tuple=True,
|
||||
)
|
||||
|
|
@ -1962,7 +1962,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
output_info = self._get_output_info_for_cache_entry(
|
||||
state,
|
||||
key,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
func,
|
||||
args,
|
||||
kwargs,
|
||||
|
|
@ -2509,7 +2509,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
)
|
||||
|
||||
with self, maybe_ignore_fresh_unbacked_symbols():
|
||||
# pyrefly: ignore # index-error
|
||||
# pyrefly: ignore [index-error]
|
||||
return registered_hop_fake_fns[func](*args, **kwargs)
|
||||
|
||||
self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)
|
||||
|
|
@ -2664,7 +2664,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
# TODO: Is this really needed?
|
||||
compute_unbacked_bindings(self.shape_env, fake_out, peek=True)
|
||||
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return fake_out
|
||||
|
||||
# Try for fastpath
|
||||
|
|
@ -2946,7 +2946,7 @@ class FakeTensorMode(TorchDispatchMode):
|
|||
self, e, device or common_device
|
||||
)
|
||||
else:
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return e
|
||||
|
||||
return tree_map(wrap, r)
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:
|
|||
|
||||
def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
|
||||
with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return t.grad
|
||||
|
||||
|
||||
|
|
@ -416,7 +416,7 @@ class MetaTensorDescriber:
|
|||
device=t.device,
|
||||
size=t.size(),
|
||||
stride=stride,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
storage_offset=storage_offset,
|
||||
dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
|
||||
dynamo_hint_overrides=getattr(t, "_dynamo_hint_overrides", {}),
|
||||
|
|
@ -541,7 +541,7 @@ class _FakeTensorViewFunc(ViewFunc["FakeTensor"]):
|
|||
tensor_visitor_fn: Optional[Callable[[torch.Tensor], FakeTensor]] = None,
|
||||
) -> FakeTensor:
|
||||
return torch._subclasses.fake_tensor.FakeTensor._view_func_unsafe(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
t,
|
||||
new_base,
|
||||
symint_visitor_fn,
|
||||
|
|
@ -1019,7 +1019,7 @@ class MetaConverter(Generic[_TensorT]):
|
|||
# Morally, the code here is same as transform_subclass, but we've
|
||||
# written it from scratch to read EmptyCreateSubclass
|
||||
outer_size = outer_size if outer_size is not None else t.size
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
outer_stride = outer_stride if outer_stride is not None else t.stride
|
||||
|
||||
assert symbolic_context is None or isinstance(
|
||||
|
|
@ -1276,7 +1276,7 @@ class MetaConverter(Generic[_TensorT]):
|
|||
) -> torch.Tensor:
|
||||
# It's possible to close over an undefined tensor (e.g. NJT's lengths).
|
||||
if visited_t is None:
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return None
|
||||
|
||||
# NB: visited_t being a Tensor here is very naughty! Should
|
||||
|
|
@ -1407,7 +1407,7 @@ class MetaConverter(Generic[_TensorT]):
|
|||
if t.requires_grad:
|
||||
r.requires_grad = True
|
||||
if t.requires_grad and not is_leaf:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
r = self._backward_error(r)
|
||||
elif t.is_nested and not t.is_traceable_wrapper_subclass:
|
||||
# TODO: Handle this better in Dynamo?
|
||||
|
|
@ -1446,7 +1446,7 @@ class MetaConverter(Generic[_TensorT]):
|
|||
if t.requires_grad:
|
||||
r.requires_grad = True
|
||||
if t.requires_grad and not is_leaf:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
r = self._backward_error(r)
|
||||
elif t.is_functorch_wrapped:
|
||||
if t.is_view:
|
||||
|
|
@ -1543,7 +1543,7 @@ class MetaConverter(Generic[_TensorT]):
|
|||
)
|
||||
assert t.data is not None
|
||||
_safe_copy(r.real_tensor, t.data) # type: ignore[attr-defined]
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return r
|
||||
|
||||
r = _to_fake_tensor(t)
|
||||
|
|
@ -1693,7 +1693,7 @@ class MetaConverter(Generic[_TensorT]):
|
|||
not (t.is_batchedtensor or t.is_gradtrackingtensor)
|
||||
and t.is_functorch_wrapped
|
||||
) or t.is_legacy_batchedtensor:
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return NotImplemented
|
||||
|
||||
(
|
||||
|
|
@ -1740,7 +1740,7 @@ class MetaConverter(Generic[_TensorT]):
|
|||
# the metadata of the inner tensor.
|
||||
# So instead, we now have a dedicated fn to set autograd history,
|
||||
# without inadvertently changing other metadata.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
r = self._backward_error(r)
|
||||
|
||||
s = t.storage
|
||||
|
|
@ -1820,7 +1820,7 @@ class MetaConverter(Generic[_TensorT]):
|
|||
|
||||
# TODO: Use a valid grad-specific symbolic context instead of recycling
|
||||
# the one from t. This isn't correct if e.g. t._is_view() != t.grad._is_view().
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
r.grad = self.meta_tensor(
|
||||
t.grad,
|
||||
shape_env,
|
||||
|
|
@ -1828,15 +1828,15 @@ class MetaConverter(Generic[_TensorT]):
|
|||
AttrSource(source, "grad"),
|
||||
symbolic_context,
|
||||
)
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
torch._C._set_conj(r, t.is_conj)
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
torch._C._set_neg(r, t.is_neg)
|
||||
# This can be skipped if necessary for performance reasons
|
||||
skip_leaf = (
|
||||
t.is_gradtrackingtensor and t.level == GRAD_TENSOR_SENTINEL_VALUE
|
||||
)
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
assert_metadata_eq(assert_eq, t, r, skip_symbolic=True, skip_leaf=skip_leaf)
|
||||
# Thanks to storage resizing, it's possible to end up with a tensor
|
||||
# that advertises a real size, but has a storage that actually has zero bytes.
|
||||
|
|
@ -1844,23 +1844,23 @@ class MetaConverter(Generic[_TensorT]):
|
|||
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
||||
|
||||
if t.storage is not None and guard_or_false(t.storage.size == 0):
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
r.untyped_storage().resize_(0)
|
||||
|
||||
if t.is_parameter:
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
r._is_param = True
|
||||
|
||||
# See Note: [Creating symbolic nested int]
|
||||
if t.nested_int is not None:
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
assert _is_fake_tensor(r)
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
r.nested_int_memo = r.fake_mode.create_symbolic_nested_int(
|
||||
nt_tensor_id=t.nested_int
|
||||
)
|
||||
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
self.set_tensor_memo(t, r)
|
||||
|
||||
return self._checked_get_tensor_memo(t)
|
||||
|
|
@ -1904,13 +1904,13 @@ class MetaConverter(Generic[_TensorT]):
|
|||
(t._is_view() and t._base is not None and t._base.is_sparse)
|
||||
):
|
||||
self.miss += 1
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return NotImplemented
|
||||
else:
|
||||
self.hit += 1
|
||||
elif torch.overrides.is_tensor_like(t):
|
||||
self.miss += 1
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return NotImplemented
|
||||
else:
|
||||
# non-Tensor types don't count as hit or miss
|
||||
|
|
|
|||
|
|
@ -657,10 +657,10 @@ def _str_intern(inp, *, tensor_contents=None):
|
|||
grad_fn_name = "Invalid"
|
||||
|
||||
if grad_fn_name is None and grad_fn is not None: # type: ignore[possibly-undefined]
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
grad_fn_name = type(grad_fn).__name__
|
||||
if grad_fn_name == "CppFunction":
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]
|
||||
|
||||
if grad_fn_name is not None:
|
||||
|
|
|
|||
|
|
@ -85,7 +85,7 @@ def compile_time_strobelight_meta(
|
|||
@functools.wraps(function)
|
||||
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
|
||||
if "skip" in kwargs and isinstance(
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
skip := kwargs["skip"],
|
||||
int,
|
||||
):
|
||||
|
|
@ -331,7 +331,7 @@ def deprecated():
|
|||
|
||||
# public deprecated alias
|
||||
alias = typing_extensions.deprecated(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
warning_msg,
|
||||
category=UserWarning,
|
||||
stacklevel=1,
|
||||
|
|
|
|||
|
|
@ -415,7 +415,7 @@ def _cast(value, device_type: str, dtype: _dtype):
|
|||
return value
|
||||
elif HAS_NUMPY and isinstance(
|
||||
value,
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
np.ndarray,
|
||||
):
|
||||
return value
|
||||
|
|
|
|||
|
|
@ -350,7 +350,7 @@ def backward(
|
|||
Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,)
|
||||
)
|
||||
else:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
tensors = tuple(tensors)
|
||||
|
||||
grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))
|
||||
|
|
@ -450,12 +450,12 @@ def grad(
|
|||
Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
|
||||
)
|
||||
else:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
outputs = tuple(outputs)
|
||||
if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
|
||||
inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
|
||||
else:
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
inputs = tuple(inputs)
|
||||
t_outputs = tuple(i for i in outputs if is_tensor_like(i))
|
||||
t_inputs = tuple(i for i in inputs if is_tensor_like(i))
|
||||
|
|
|
|||
|
|
@ -15,14 +15,14 @@ class Type(Function):
|
|||
"please use `torch.tensor.to(dtype=dtype)` instead.",
|
||||
category=FutureWarning,
|
||||
)
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx, i, dest_type):
|
||||
ctx.input_type = type(i)
|
||||
ctx.input_device = -1 if not i.is_cuda else i.get_device()
|
||||
return i.type(dest_type)
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def backward(ctx, grad_output):
|
||||
if ctx.input_device == -1:
|
||||
return grad_output.type(ctx.input_type), None
|
||||
|
|
@ -34,7 +34,7 @@ class Type(Function):
|
|||
# TODO: deprecate this
|
||||
class Resize(Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx, tensor, sizes):
|
||||
ctx.sizes = sizes
|
||||
ctx.numel = reduce(operator.mul, sizes, 1)
|
||||
|
|
@ -63,7 +63,7 @@ class Resize(Function):
|
|||
return tensor.contiguous().view(*sizes)
|
||||
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def backward(ctx, grad_output):
|
||||
if grad_output.numel() != ctx.numel:
|
||||
raise AssertionError(
|
||||
|
|
|
|||
|
|
@ -415,6 +415,6 @@ class _unsafe_preserve_version_counter(_DecoratorContextManager):
|
|||
def __enter__(self) -> None:
|
||||
pass
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def __exit__(self, *args) -> None:
|
||||
torch._C._autograd._unsafe_set_version_counter(self.tensors, self.prev_versions)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from typing_extensions import deprecated
|
|||
import torch
|
||||
import torch.testing
|
||||
|
||||
# pyrefly: ignore # deprecated
|
||||
# pyrefly: ignore [deprecated]
|
||||
from torch._vmap_internals import _vmap, vmap
|
||||
from torch.overrides import is_tensor_like
|
||||
from torch.types import _TensorOrTensors
|
||||
|
|
|
|||
|
|
@ -230,7 +230,7 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:
|
|||
|
||||
# Note that output_nr default to 0 which is the right value
|
||||
# for the AccumulateGrad node.
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
return GradientEdge(grad_fn, tensor.output_nr, ownership_token=token)
|
||||
|
||||
|
||||
|
|
@ -534,7 +534,7 @@ def register_multi_grad_hook(
|
|||
"expected this hook to be called inside a backward call"
|
||||
)
|
||||
count[id] = count.get(id, 0)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
buffer[id] = buffer.get(id, [None] * len_tensors)
|
||||
|
||||
with lock:
|
||||
|
|
|
|||
|
|
@ -744,7 +744,7 @@ class profile:
|
|||
return all_function_events
|
||||
|
||||
|
||||
# pyrefly: ignore # invalid-inheritance
|
||||
# pyrefly: ignore [invalid-inheritance]
|
||||
class record_function(_ContextDecorator):
|
||||
"""Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
|
||||
Label will only appear if CPU activity tracing is enabled.
|
||||
|
|
@ -792,7 +792,7 @@ class record_function(_ContextDecorator):
|
|||
# TODO: TorchScript ignores standard type annotation here
|
||||
# self.record: Optional["torch.classes.profiler._RecordFunction"] = None
|
||||
self.record = torch.jit.annotate(
|
||||
# pyrefly: ignore # not-a-type
|
||||
# pyrefly: ignore [not-a-type]
|
||||
Optional["torch.classes.profiler._RecordFunction"],
|
||||
None,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -101,14 +101,14 @@ class profile:
|
|||
|
||||
records = _disable_profiler_legacy()
|
||||
parsed_results = _parse_legacy_records(records)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
self.function_events = EventList(
|
||||
parsed_results,
|
||||
use_device="cuda" if self.use_cuda else None,
|
||||
profile_memory=self.profile_memory,
|
||||
with_flops=self.with_flops,
|
||||
)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
self.function_events._build_tree()
|
||||
return False
|
||||
|
||||
|
|
|
|||
|
|
@ -414,7 +414,7 @@ class _NnapiSerializer:
|
|||
) # noqa: TRY002
|
||||
return Operand(
|
||||
shape=tuple(tensor.shape),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
op_type=op_type,
|
||||
dim_order=dim_order,
|
||||
scale=scale,
|
||||
|
|
@ -1735,13 +1735,13 @@ class _NnapiSerializer:
|
|||
for dim in (2, 3): # h, w indices
|
||||
if image_oper.shape[dim] == 0:
|
||||
if size_ctype.kind() != "NoneType":
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
self.compute_operand_shape(out_id, dim, size_arg[dim - 2])
|
||||
elif scale_ctype.kind() != "NoneType":
|
||||
self.compute_operand_shape(
|
||||
out_id,
|
||||
dim,
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
f"int({scale_arg[dim - 2]} * {flex_name(image_id, dim)})",
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -34,11 +34,11 @@ if _cudnn is not None:
|
|||
def _init():
|
||||
global __cudnn_version
|
||||
if __cudnn_version is None:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
__cudnn_version = _cudnn.getVersionInt()
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
runtime_version = _cudnn.getRuntimeVersion()
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
compile_version = _cudnn.getCompileVersion()
|
||||
runtime_major, runtime_minor, _ = runtime_version
|
||||
compile_major, compile_minor, _ = compile_version
|
||||
|
|
@ -47,7 +47,7 @@ if _cudnn is not None:
|
|||
# Not sure about MIOpen (ROCm), so always do a strict check
|
||||
if runtime_major != compile_major:
|
||||
cudnn_compatible = False
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
elif runtime_major < 7 or not _cudnn.is_cuda:
|
||||
cudnn_compatible = runtime_minor == compile_minor
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -12,16 +12,16 @@ except ImportError:
|
|||
|
||||
def get_cudnn_mode(mode):
|
||||
if mode == "RNN_RELU":
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return int(_cudnn.RNNMode.rnn_relu)
|
||||
elif mode == "RNN_TANH":
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return int(_cudnn.RNNMode.rnn_tanh)
|
||||
elif mode == "LSTM":
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return int(_cudnn.RNNMode.lstm)
|
||||
elif mode == "GRU":
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return int(_cudnn.RNNMode.gru)
|
||||
else:
|
||||
raise Exception(f"Unknown mode: {mode}") # noqa: TRY002
|
||||
|
|
@ -60,7 +60,7 @@ def init_dropout_state(dropout, train, dropout_seed, dropout_state):
|
|||
dropout_p,
|
||||
train,
|
||||
dropout_seed,
|
||||
# pyrefly: ignore # unexpected-keyword
|
||||
# pyrefly: ignore [unexpected-keyword]
|
||||
self_ty=torch.uint8,
|
||||
device=torch.device("cuda"),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ if _cusparselt is not None:
|
|||
global __cusparselt_version
|
||||
global __MAX_ALG_ID
|
||||
if __cusparselt_version is None:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
__cusparselt_version = _cusparselt.getVersionInt()
|
||||
if __cusparselt_version == 400:
|
||||
__MAX_ALG_ID = 4
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ def is_available():
|
|||
|
||||
def is_acl_available():
|
||||
r"""Return whether PyTorch is built with MKL-DNN + ACL support."""
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return torch._C._has_mkldnn_acl
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ def _set_strategy(_strategy: str) -> None:
|
|||
|
||||
|
||||
def _get_strategy() -> str:
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return strategy
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -835,7 +835,7 @@ def create_args(parser=None):
|
|||
|
||||
@retval ArgumentParser
|
||||
"""
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
parser.add_argument(
|
||||
"--multi-instance",
|
||||
"--multi_instance",
|
||||
|
|
@ -844,7 +844,7 @@ def create_args(parser=None):
|
|||
help="Enable multi-instance, by default one instance per node",
|
||||
)
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
parser.add_argument(
|
||||
"-m",
|
||||
"--module",
|
||||
|
|
@ -855,7 +855,7 @@ def create_args(parser=None):
|
|||
'"python -m".',
|
||||
)
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
parser.add_argument(
|
||||
"--no-python",
|
||||
"--no_python",
|
||||
|
|
@ -870,7 +870,7 @@ def create_args(parser=None):
|
|||
|
||||
_add_multi_instance_params(parser)
|
||||
# positional
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
parser.add_argument(
|
||||
"program",
|
||||
type=str,
|
||||
|
|
@ -879,7 +879,7 @@ def create_args(parser=None):
|
|||
)
|
||||
|
||||
# rest from the training program
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
parser.add_argument("program_args", nargs=REMAINDER)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class CacheArtifactFactory:
|
|||
@classmethod
|
||||
def create(cls, artifact_type_key: str, key: str, content: bytes) -> CacheArtifact:
|
||||
artifact_cls = cls._get_artifact_type(artifact_type_key)
|
||||
# pyrefly: ignore # bad-instantiation
|
||||
# pyrefly: ignore [bad-instantiation]
|
||||
return artifact_cls(key, content)
|
||||
|
||||
@classmethod
|
||||
|
|
@ -98,7 +98,7 @@ class CacheArtifactFactory:
|
|||
cls, artifact_type_key: str, key: str, content: Any
|
||||
) -> CacheArtifact:
|
||||
artifact_cls = cls._get_artifact_type(artifact_type_key)
|
||||
# pyrefly: ignore # bad-instantiation
|
||||
# pyrefly: ignore [bad-instantiation]
|
||||
return artifact_cls(key, artifact_cls.encode(content))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,3 @@
|
|||
# pyrefly: ignore # deprecated
|
||||
# pyrefly: ignore [deprecated]
|
||||
from .autocast_mode import autocast
|
||||
from .grad_scaler import GradScaler
|
||||
|
|
|
|||
|
|
@ -501,14 +501,14 @@ class cudaStatus:
|
|||
|
||||
class CudaError(RuntimeError):
|
||||
def __init__(self, code: int) -> None:
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
|
||||
super().__init__(f"{msg} ({code})")
|
||||
|
||||
|
||||
def check_error(res: int) -> None:
|
||||
r"""Raise an error if the result of a CUDA runtime API call is not success."""
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
if res != _cudart.cudaError.success:
|
||||
raise CudaError(res)
|
||||
|
||||
|
|
@ -608,7 +608,7 @@ def get_device_capability(device: "Device" = None) -> tuple[int, int]:
|
|||
return prop.major, prop.minor
|
||||
|
||||
|
||||
# pyrefly: ignore # not-a-type
|
||||
# pyrefly: ignore [not-a-type]
|
||||
def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties:
|
||||
r"""Get the properties of a device.
|
||||
|
||||
|
|
@ -659,7 +659,7 @@ class StreamContext:
|
|||
self.idx = _get_device_index(None, True)
|
||||
if not torch.jit.is_scripting():
|
||||
if self.idx is None:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
self.idx = -1
|
||||
|
||||
self.src_prev_stream = (
|
||||
|
|
@ -964,9 +964,9 @@ def _device_count_amdsmi() -> int:
|
|||
if raw_cnt <= 0:
|
||||
return raw_cnt
|
||||
# Trim the list up to a maximum available device
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
for idx, val in enumerate(visible_devices):
|
||||
# pyrefly: ignore # redundant-cast
|
||||
# pyrefly: ignore [redundant-cast]
|
||||
if cast(int, val) >= raw_cnt:
|
||||
return idx
|
||||
except OSError:
|
||||
|
|
@ -1000,9 +1000,9 @@ def _device_count_nvml() -> int:
|
|||
if raw_cnt <= 0:
|
||||
return raw_cnt
|
||||
# Trim the list up to a maximum available device
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
for idx, val in enumerate(visible_devices):
|
||||
# pyrefly: ignore # redundant-cast
|
||||
# pyrefly: ignore [redundant-cast]
|
||||
if cast(int, val) >= raw_cnt:
|
||||
return idx
|
||||
except OSError:
|
||||
|
|
@ -1218,9 +1218,9 @@ def _get_pynvml_handler(device: "Device" = None):
|
|||
if not _HAS_PYNVML:
|
||||
raise ModuleNotFoundError(
|
||||
"nvidia-ml-py does not seem to be installed or it can't be imported."
|
||||
# pyrefly: ignore # invalid-inheritance
|
||||
# pyrefly: ignore [invalid-inheritance]
|
||||
) from _PYNVML_ERR
|
||||
# pyrefly: ignore # import-error
|
||||
# pyrefly: ignore [import-error]
|
||||
from pynvml import NVMLError_DriverNotLoaded
|
||||
|
||||
try:
|
||||
|
|
@ -1237,7 +1237,7 @@ def _get_amdsmi_handler(device: "Device" = None):
|
|||
if not _HAS_PYNVML:
|
||||
raise ModuleNotFoundError(
|
||||
"amdsmi does not seem to be installed or it can't be imported."
|
||||
# pyrefly: ignore # invalid-inheritance
|
||||
# pyrefly: ignore [invalid-inheritance]
|
||||
) from _PYNVML_ERR
|
||||
try:
|
||||
amdsmi.amdsmi_init()
|
||||
|
|
@ -1501,7 +1501,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int
|
|||
return default_generator.get_offset()
|
||||
|
||||
|
||||
# pyrefly: ignore # deprecated
|
||||
# pyrefly: ignore [deprecated]
|
||||
from .memory import * # noqa: F403
|
||||
from .random import * # noqa: F403
|
||||
|
||||
|
|
@ -1718,7 +1718,7 @@ def _register_triton_kernels():
|
|||
def kernel_impl(*args, **kwargs):
|
||||
from torch.sparse._triton_ops import bsr_dense_mm
|
||||
|
||||
# pyrefly: ignore # not-callable
|
||||
# pyrefly: ignore [not-callable]
|
||||
return bsr_dense_mm(*args, skip_checks=True, **kwargs)
|
||||
|
||||
@_WrappedTritonKernel
|
||||
|
|
|
|||
|
|
@ -279,7 +279,7 @@ class _CudaModule:
|
|||
return self._kernels[name]
|
||||
|
||||
# Import the CUDA library inside the method
|
||||
# pyrefly: ignore # missing-module-attribute
|
||||
# pyrefly: ignore [missing-module-attribute]
|
||||
from torch.cuda._utils import _get_gpu_runtime_library
|
||||
|
||||
libcuda = _get_gpu_runtime_library()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
# pyrefly: ignore # deprecated
|
||||
# pyrefly: ignore [deprecated]
|
||||
from .autocast_mode import autocast, custom_bwd, custom_fwd
|
||||
from .common import amp_definitely_not_available
|
||||
from .grad_scaler import GradScaler
|
||||
|
|
|
|||
|
|
@ -259,7 +259,7 @@ class graph:
|
|||
self.cuda_graph.capture_begin(
|
||||
# type: ignore[misc]
|
||||
*self.pool,
|
||||
# pyrefly: ignore # bad-keyword-argument
|
||||
# pyrefly: ignore [bad-keyword-argument]
|
||||
capture_error_mode=self.capture_error_mode,
|
||||
)
|
||||
|
||||
|
|
@ -525,7 +525,7 @@ def make_graphed_callables(
|
|||
) -> Callable[..., object]:
|
||||
class Graphed(torch.autograd.Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
|
||||
# At this stage, only the user args may (potentially) be new tensors.
|
||||
for i in range(len_user_args):
|
||||
|
|
@ -537,7 +537,7 @@ def make_graphed_callables(
|
|||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
|
||||
assert len(grads) == len(static_grad_outputs)
|
||||
for g, grad in zip(static_grad_outputs, grads):
|
||||
|
|
@ -551,7 +551,7 @@ def make_graphed_callables(
|
|||
# Input args that didn't require grad expect a None gradient.
|
||||
assert isinstance(static_grad_inputs, tuple)
|
||||
return tuple(
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
b.detach() if b is not None else b
|
||||
for b in static_grad_inputs
|
||||
)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ if hasattr(torch._C, "_CUDAGreenContext"):
|
|||
|
||||
|
||||
# Python shim helps Sphinx process docstrings more reliably.
|
||||
# pyrefly: ignore # invalid-inheritance
|
||||
# pyrefly: ignore [invalid-inheritance]
|
||||
class GreenContext(_GreenContext):
|
||||
r"""Wrapper around a CUDA green context.
|
||||
|
||||
|
|
|
|||
|
|
@ -772,7 +772,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
|
|||
import pynvml # type: ignore[import]
|
||||
except ModuleNotFoundError:
|
||||
return "pynvml module not found, please install nvidia-ml-py"
|
||||
# pyrefly: ignore # import-error
|
||||
# pyrefly: ignore [import-error]
|
||||
from pynvml import NVMLError_DriverNotLoaded
|
||||
|
||||
try:
|
||||
|
|
@ -855,7 +855,7 @@ def _record_memory_history_legacy(
|
|||
_C._cuda_record_memory_history_legacy( # type: ignore[call-arg]
|
||||
enabled,
|
||||
record_context,
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
trace_alloc_max_entries,
|
||||
trace_alloc_record_context,
|
||||
record_context_cpp,
|
||||
|
|
@ -1076,7 +1076,7 @@ def _set_memory_metadata(metadata: str):
|
|||
metadata (str): Custom metadata string to attach to allocations.
|
||||
Pass an empty string to clear the metadata.
|
||||
"""
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
torch._C._cuda_setMemoryMetadata(metadata)
|
||||
|
||||
|
||||
|
|
@ -1087,7 +1087,7 @@ def _get_memory_metadata() -> str:
|
|||
Returns:
|
||||
str: The current metadata string, or empty string if no metadata is set.
|
||||
"""
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return torch._C._cuda_getMemoryMetadata()
|
||||
|
||||
|
||||
|
|
@ -1110,7 +1110,7 @@ def _save_memory_usage(filename="output.svg", snapshot=None):
|
|||
category=FutureWarning,
|
||||
)
|
||||
def _set_allocator_settings(env: str):
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return torch._C._accelerator_setAllocatorSettings(env)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -53,7 +53,7 @@ def range_start(msg) -> int:
|
|||
Args:
|
||||
msg (str): ASCII message to associate with the range.
|
||||
"""
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return _nvtx.rangeStartA(msg)
|
||||
|
||||
|
||||
|
|
@ -64,7 +64,7 @@ def range_end(range_id) -> None:
|
|||
Args:
|
||||
range_id (int): an unique handle for the start range.
|
||||
"""
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
_nvtx.rangeEnd(range_id)
|
||||
|
||||
|
||||
|
|
@ -85,7 +85,7 @@ def _device_range_start(msg: str, stream: int = 0) -> object:
|
|||
msg (str): ASCII message to associate with the range.
|
||||
stream (int): CUDA stream id.
|
||||
"""
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
return _nvtx.deviceRangeStart(msg, stream)
|
||||
|
||||
|
||||
|
|
@ -98,7 +98,7 @@ def _device_range_end(range_handle: object, stream: int = 0) -> None:
|
|||
range_handle: an unique handle for the start range.
|
||||
stream (int): CUDA stream id.
|
||||
"""
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
_nvtx.deviceRangeEnd(range_handle, stream)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -295,7 +295,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
|
|||
|
||||
self.logger.addHandler(self)
|
||||
self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
torch._logging._internal.GET_DTRACE_STRUCTURED = True
|
||||
return self
|
||||
|
||||
|
|
@ -303,7 +303,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
|
|||
self.log_record = LogRecord()
|
||||
self.expression_created_logs = {}
|
||||
self.logger.removeHandler(self)
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
|
||||
self.prev_get_dtrace = False
|
||||
|
||||
|
|
|
|||
|
|
@ -107,11 +107,11 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
|
|||
return
|
||||
|
||||
if not (
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
arg.op == "call_function"
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
and arg.target == operator.getitem
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
and arg.args[1] == i
|
||||
):
|
||||
log.debug(
|
||||
|
|
|
|||
|
|
@ -186,7 +186,7 @@ def _ignore_backend_decomps():
|
|||
def _disable_custom_triton_op_functional_decomposition():
|
||||
old = torch._functorch.config.decompose_custom_triton_ops
|
||||
try:
|
||||
# pyrefly: ignore # bad-assignment
|
||||
# pyrefly: ignore [bad-assignment]
|
||||
torch._functorch.config.decompose_custom_triton_ops = False
|
||||
yield torch._functorch.config.decompose_custom_triton_ops
|
||||
finally:
|
||||
|
|
@ -365,7 +365,7 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
|
|||
|
||||
nn_module_stack = {
|
||||
root_key: (root, root_cls.__module__ + "." + root_cls.__qualname__),
|
||||
# pyrefly: ignore # unbound-name
|
||||
# pyrefly: ignore [unbound-name]
|
||||
**nn_module_stack,
|
||||
}
|
||||
node.meta["nn_module_stack"] = {
|
||||
|
|
@ -687,7 +687,7 @@ def _restore_state_dict(
|
|||
for name, _ in list(
|
||||
chain(
|
||||
original_module.named_parameters(remove_duplicate=False),
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
original_module.named_buffers(remove_duplicate=False),
|
||||
)
|
||||
):
|
||||
|
|
|
|||
|
|
@ -293,7 +293,7 @@ class _ExportPackage:
|
|||
if isinstance(fn, torch.nn.Module):
|
||||
dynamic_shapes = v(fn, *args, **kwargs) # type: ignore[arg-type]
|
||||
else:
|
||||
# pyrefly: ignore # invalid-param-spec
|
||||
# pyrefly: ignore [invalid-param-spec]
|
||||
dynamic_shapes = v(*args, **kwargs)
|
||||
except AssertionError:
|
||||
continue
|
||||
|
|
@ -341,7 +341,7 @@ class _ExportPackage:
|
|||
assert not hasattr(fn, "_define_overload")
|
||||
_exporter_context._define_overload = _define_overload # type: ignore[attr-defined]
|
||||
|
||||
# pyrefly: ignore # bad-return
|
||||
# pyrefly: ignore [bad-return]
|
||||
return _exporter_context
|
||||
|
||||
@property
|
||||
|
|
@ -378,7 +378,7 @@ class _ExportPackage:
|
|||
kwargs=ep.example_inputs[1],
|
||||
options=options,
|
||||
)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
aoti_files_map[name] = aoti_files
|
||||
|
||||
from torch._inductor.package import package
|
||||
|
|
|
|||
|
|
@ -1500,7 +1500,7 @@ class ExportedProgram:
|
|||
transformed_gm = res.graph_module if res is not None else self.graph_module
|
||||
assert transformed_gm is not None
|
||||
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
if transformed_gm is self.graph_module and not res.modified:
|
||||
return self
|
||||
|
||||
|
|
@ -1579,7 +1579,7 @@ class ExportedProgram:
|
|||
verifiers=self.verifiers,
|
||||
)
|
||||
transformed_ep.graph_module.meta.update(self.graph_module.meta)
|
||||
# pyrefly: ignore # missing-attribute
|
||||
# pyrefly: ignore [missing-attribute]
|
||||
transformed_ep.graph_module.meta.update(res.graph_module.meta)
|
||||
return transformed_ep
|
||||
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ def move_to_device_pass(
|
|||
and node.target == torch.ops.aten.to.device
|
||||
):
|
||||
args = list(node.args)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
args[1] = _get_new_device(args[1], location)
|
||||
node.args = tuple(args)
|
||||
|
||||
|
|
|
|||
|
|
@ -173,10 +173,10 @@ class PT2ArchiveWriter:
|
|||
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
|
||||
)
|
||||
for file_path in file_paths:
|
||||
# pyrefly: ignore # no-matching-overload
|
||||
# pyrefly: ignore [no-matching-overload]
|
||||
filename = os.path.relpath(file_path, folder_dir)
|
||||
archive_path = os.path.join(archive_dir, filename)
|
||||
# pyrefly: ignore # bad-argument-type
|
||||
# pyrefly: ignore [bad-argument-type]
|
||||
self.write_file(archive_path, file_path)
|
||||
|
||||
def close(self) -> None:
|
||||
|
|
@ -696,7 +696,7 @@ def package_pt2(
if isinstance(f, (str, os.PathLike)):
f = os.fspath(f)

# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
with PT2ArchiveWriter(f) as archive_writer:
_package_exported_programs(
archive_writer, exported_programs, pickle_protocol=pickle_protocol
@ -711,7 +711,7 @@ def package_pt2(

if isinstance(f, (io.IOBase, IO)):
f.seek(0)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return f
@ -1098,7 +1098,7 @@ def load_pt2(

weights = {}
weight_maps = {}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
with PT2ArchiveReader(f) as archive_reader:
version = archive_reader.read_string(ARCHIVE_VERSION_PATH)
if version != ARCHIVE_VERSION_VALUE:
@ -25,15 +25,15 @@ class TensorProperties:

if not self.is_fake:
# only get the storage pointer for real tensors
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.storage_ptr = tensor.untyped_storage().data_ptr()
if self.is_contiguous:
# only get storage size and start/end pointers for contiguous tensors
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.storage_size = tensor.untyped_storage().nbytes()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.start = tensor.data_ptr()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.end = _end_ptr(tensor)

# info to recover tensor
@ -67,7 +67,7 @@ class GraphPickler(pickle.Pickler):
self._meta_tensor_describer = MetaTensorDescriber(copy_data=False)

@override
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def reducer_override(
self, obj: object
) -> tuple[Callable[..., Any], tuple[Any, ...]]:
@ -204,7 +204,7 @@ class _SymNodePickleData:
]:
args = (cls(obj.node), pickler._unpickle_state)
if isinstance(obj, torch.SymInt):
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return _SymNodePickleData.unpickle_sym_int, args
else:
raise NotImplementedError(f"Unhandled SymNode type {type(obj)}")
@ -281,7 +281,7 @@ class _TensorPickleData:
return FakeTensor(
unpickle_state.fake_mode,
make_meta_t(),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
device,
)
@ -334,9 +334,9 @@ class _TorchNumpyPickleData:
if not (name := getattr(np, "__name__", None)):
return None

# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
assert np == getattr(importlib.import_module(mod), name)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return cls(mod, name)
@ -603,7 +603,7 @@ class Tracer(TracerBase):
in inspect.signature(self.create_proxy).parameters
):
kwargs["proxy_factory_fn"] = (
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
None
if not self.param_shapes_constant
else lambda node: ParameterProxy(
@ -659,7 +659,7 @@ class Partitioner:
find_combination, partitions = find_partition_to_combine_based_on_size(
sorted_partitions,
available_mem_bytes,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
partitions,
)
return
@ -705,7 +705,7 @@ class Partitioner:
non_embedding_partitions.append(partition)
if new_partition:
partition = self.create_partition()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
partition.left_mem_bytes = available_mem_bytes
return partition
return None
@ -1001,7 +1001,7 @@ class Partitioner:
node, n1, p0, p1, node_to_latency_mapping, transfer_rate_per_sec
)
if cost < min_cost:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
node_pair = [node, n1]
min_cost = cost
return cost, node_pair # type: ignore[possibly-undefined]
@ -30,7 +30,7 @@ def split_result_tensors(
else:
splits = [x.shape[0] for x in inputs]

# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch.split(result, splits)
@ -178,7 +178,7 @@ class MetaTracer(torch.fx.Tracer):
kwargs,
name,
type_expr,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
proxy_factory_fn,
)
@ -201,7 +201,7 @@ class MetaTracer(torch.fx.Tracer):

if kind == "call_function":
meta_target = manual_meta_overrides.get(target, target)
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
meta_out = meta_target(*args_metas, **kwargs_metas)
elif kind == "call_method":
meta_target = getattr(args_metas[0], target) # type: ignore[index]
@ -528,11 +528,11 @@ def view_inference_rule(n: Node, symbols, constraints, counter):
if t == -1:
var, counter = gen_dvar(counter)
t2_type.append(var)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
num_constraints.append(BinConstraintD(var, Dyn, op_neq))

else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
num_constraints.append(BinConstraintD(t, Dyn, op_neq))
t2_type.append(t) # type: ignore[arg-type]
@ -1477,7 +1477,7 @@ class ConstraintGenerator:

all_constraints = []

# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for n in graph.nodes:
(constraints, counter) = self.generate_constraints_node(n, counter)
all_constraints += constraints
@ -193,7 +193,7 @@ def modules_to_mkldnn(nodes: list[fx.Node], modules: dict[str, nn.Module]):
assert isinstance(node.target, str)
cur_module = modules[node.target]
if type(cur_module) in mkldnn_map:
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
new_module = mkldnn_map[type(cur_module)](cur_module, torch.float)
assert isinstance(new_module, nn.Module)
old_modules[new_module] = copy.deepcopy(cur_module)
@ -266,7 +266,7 @@ def gen_mkl_autotuner(example_inputs, iters=10, warmup=1):
reset_modules(
submodule.graph.nodes,
dict(submodule.named_modules()),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
old_modules,
)
no_mkl_time = benchmark(lambda: submodule(*sample_inputs))
@ -124,7 +124,7 @@ pytree.register_pytree_node(
torch.Size,
lambda xs: (list(xs), None),
lambda xs, _: tuple(xs),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
flatten_with_keys_fn=lambda xs: (
[(pytree.SequenceKey(i), x) for i, x in enumerate(xs)],
None,
@ -310,7 +310,7 @@ def set_proxy_slot( # type: ignore[no-redef]

def has_proxy_slot(obj: Tensor, tracer: _ProxyTracer) -> bool:
assert isinstance(obj, (Tensor, SymNode)), type(obj)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return bool(get_proxy_slot(obj, tracer, False, lambda _: True))
@ -407,7 +407,7 @@ def get_proxy_slot(
assert isinstance(obj, py_sym_types), type(obj)
tracker = tracer.symnode_tracker

# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
value = tracker.get(obj)

if value is None and isinstance(obj, py_sym_types):
@ -420,7 +420,7 @@ def get_proxy_slot(
else:
# Attempt to build it from first principles.
_build_proxy_for_sym_expr(tracer, obj.node.expr, obj)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
value = tracker.get(obj)

if value is None:
@ -1552,7 +1552,7 @@ class TorchFunctionMetadataMode(TorchFunctionMode):
kwargs: Optional[dict[str, object]] = None,
) -> object:
kwargs = kwargs or {}
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.tracer.torch_fn_metadata = func
self.tracer.torch_fn_counts[func] = self.tracer.torch_fn_counts.get(func, 0) + 1
return func(*args, **kwargs)
@ -1602,7 +1602,7 @@ class PreDispatchTorchFunctionMode(TorchFunctionMode):
# For autocast, the python APIs run so we don't have to run them again
# here.
if func is torch._C._set_grad_enabled:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
func(*args, **kwargs)
return node
@ -1821,7 +1821,7 @@ class DecompositionInterpreter(fx.Interpreter):
self.decomposition_table = decomposition_table or {}
self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real")

# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def placeholder(
self,
target: str, # type: ignore[override]
@ -1834,7 +1834,7 @@ class DecompositionInterpreter(fx.Interpreter):
# TODO handle case where the first character of target is '*'
return out

# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def get_attr(
self,
target: str, # type: ignore[override]
@ -1848,7 +1848,7 @@ class DecompositionInterpreter(fx.Interpreter):

# call_function, call_method, call_module get traced automatically by the outer mode.

# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def output(
self,
target: str, # type: ignore[override]
@ -1967,7 +1967,7 @@ class _ModuleStackTracer(PythonKeyTracer):
# Class is modified to be a subclass of torch.nn.Module
# Warning: We blow away our own attributes here to mimic the base class
# - so don't expect `self.x` to do anything useful.
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
self.__class__ = type(
base.__class__.__name__,
(self.__class__, base.__class__),
@ -1990,7 +1990,7 @@ class _ModuleStackTracer(PythonKeyTracer):
if not isinstance(attr_val, Module):
return attr_val

# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return AttrProxy(attr_val, tracer.proxy_paths[self] + "." + name)

def get_base(self) -> Module:
@ -2003,12 +2003,12 @@ class _ModuleStackTracer(PythonKeyTracer):
res = torch.nn.Sequential(
OrderedDict(list(self._modules.items())[idx])
)
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")
elif isinstance(self, torch.nn.ModuleList):
# Copied from nn/modules/container.py
res = torch.nn.ModuleList(list(self._modules.values())[idx])
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
return AttrProxy(res, f"{tracer.proxy_paths[self]}.{idx}")

return super().__getitem__(idx) # type: ignore[misc]