Fix error suppression syntax in utils and nn (#166242)

Fixes the syntax of pyrefly: ignore suppressions so that each one ignores only a specific error category. No functional changes.
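A minimal before/after sketch, using the _materialize suppression from the first file below (signature abbreviated). Per the description above, in the old form the error code after the second "#" is just a trailing comment, so the ignore is unscoped; the bracketed form restricts the suppression to the named category:

    # Before: unscoped suppression; "bad-return" is only a trailing comment
    # pyrefly: ignore # bad-return
    def _materialize(self, device=None): ...

    # After: suppression applies only to the bad-return category
    # pyrefly: ignore [bad-return]
    def _materialize(self, device=None): ...

    # Multiple categories can be listed, as in the LazyConv* changes below:
    # pyrefly: ignore [bad-override, bad-argument-type]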

Test plan:
pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166242
Approved by: https://github.com/oulgen, https://github.com/cyyever
Maggie Moss 2025-10-26 05:21:04 +00:00 committed by PyTorch MergeBot
parent 5121499f6b
commit 84b14f3a10
81 changed files with 358 additions and 358 deletions

View File

@ -153,7 +153,7 @@ class CausalBias(torch.Tensor):
diagonal=diagonal_offset,
)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
"""
Materializes the causal bias into a tensor form.

View File

@ -84,7 +84,7 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class FlexKernelOptions(TypedDict, total=False):
"""Options for controlling the behavior of FlexAttention kernels.
@ -128,97 +128,97 @@ class FlexKernelOptions(TypedDict, total=False):
"""
# Performance tuning options
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
num_warps: NotRequired[int]
"""Number of warps to use in the CUDA kernel. Higher values may improve performance
but increase register pressure. Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
num_stages: NotRequired[int]
"""Number of pipeline stages in the CUDA kernel. Higher values may improve performance
but increase shared memory usage. Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
BLOCK_M: NotRequired[int]
"""Thread block size for the sequence length dimension of Q in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
BLOCK_N: NotRequired[int]
"""Thread block size for the sequence length dimension of K/V in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""
# Backward-specific block sizes (when prefixed with 'bwd_')
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
BLOCK_M1: NotRequired[int]
"""Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'.
Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
BLOCK_N1: NotRequired[int]
"""Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'.
Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
BLOCK_M2: NotRequired[int]
"""Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'.
Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
BLOCK_N2: NotRequired[int]
"""Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'.
Default is determined by autotuning."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
PRESCALE_QK: NotRequired[bool]
"""Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but
may have more numerical error. Default: False."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
ROWS_GUARANTEED_SAFE: NotRequired[bool]
"""If True, guarantees that at least one value in each row is not masked out.
Allows skipping safety checks for better performance. Only set this if you are certain
your mask guarantees this property. For example, causal attention is guaranteed safe
because each query has at least 1 key-value to attend to. Default: False."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]
"""If True, guarantees that all blocks in the mask are contiguous.
Allows optimizing block traversal. For example, causal masks would satisfy this,
but prefix_lm + sliding window would not. Default: False."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
WRITE_DQ: NotRequired[bool]
"""Controls whether gradient scatters are done in the DQ iteration loop of the backward pass.
Setting this to False will force this to happen in the DK loop which depending on your
specific score_mod and mask_mod might be faster. Default: True."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
FORCE_USE_FLEX_ATTENTION: NotRequired[bool]
"""If True, forces the use of the flex attention kernel instead of potentially using
the more optimized flex-decoding kernel for short sequences. This can be a helpful
option for debugging. Default: False."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
USE_TMA: NotRequired[bool]
"""Whether to use Tensor Memory Accelerator (TMA) on supported hardware.
This is experimental and may not work on all hardware, currently specific
to NVIDIA GPUs Hopper+. Default: False."""
# ROCm-specific options
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
kpack: NotRequired[int]
"""ROCm-specific kernel packing parameter."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
matrix_instr_nonkdim: NotRequired[int]
"""ROCm-specific matrix instruction non-K dimension."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
waves_per_eu: NotRequired[int]
"""ROCm-specific waves per execution unit."""
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
force_flash: NotRequired[bool]
""" If True, forces use of the cute-dsl flash attention kernel.
@ -644,7 +644,7 @@ class BlockMask:
block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
seq_lengths = (self.seq_lengths,) # type: ignore[assignment]
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
return (
*seq_lengths,
self.kv_num_blocks,
@ -817,7 +817,7 @@ class BlockMask:
partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
if self.full_kv_num_blocks is not None:
assert self.full_kv_indices is not None
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return partial_dense | _ordered_to_dense(
self.full_kv_num_blocks, self.full_kv_indices
)
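As context for the FlexKernelOptions docstrings above: these keys are passed to flex_attention through its kernel_options argument. A minimal usage sketch, assuming a CUDA device and a PyTorch build with FlexAttention; the option values are illustrative placeholders, not tuned settings:

    import torch
    from torch.nn.attention.flex_attention import flex_attention

    q = torch.randn(2, 8, 1024, 64, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # flex_attention is normally wrapped in torch.compile; kernel_options tune the
    # generated kernel (defaults come from autotuning, per the docstrings above).
    flex_attention_compiled = torch.compile(flex_attention)
    out = flex_attention_compiled(
        q, k, v, kernel_options={"num_warps": 4, "BLOCK_M": 64, "BLOCK_N": 64}
    )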

View File

@ -78,7 +78,7 @@ class ModuleWrapper(nn.Module):
# nn.Module defines training as a boolean
@property # type: ignore[override]
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def training(self):
return self.cpp_module.training

View File

@ -1275,7 +1275,7 @@ def adaptive_max_pool2d_with_indices(
output_size,
return_indices=return_indices,
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool2d(input, output_size)
@ -1333,7 +1333,7 @@ def adaptive_max_pool3d_with_indices(
output_size,
return_indices=return_indices,
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool3d(input, output_size)
@ -1392,7 +1392,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> T
"""
if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)
@ -1408,7 +1408,7 @@ def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> T
"""
if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool3d(input, _output_size)
@ -2444,7 +2444,7 @@ def _no_grad_embedding_renorm_(
input: Tensor,
max_norm: float,
norm_type: float,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> tuple[Tensor, Tensor]:
torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)
@ -2698,7 +2698,7 @@ def embedding_bag(
if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested:
include_last_offset = True
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
offsets = input.offsets()
input = input.values().reshape(-1)
if per_sample_weights is not None:
@ -2833,7 +2833,7 @@ def batch_norm(
eps=eps,
)
if training:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_verify_batch_size(input.size())
return torch.batch_norm(
@ -2889,7 +2889,7 @@ def instance_norm(
eps=eps,
)
if use_input_stats:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_verify_spatial_size(input.size())
return torch.instance_norm(
input,
@ -3015,13 +3015,13 @@ def local_response_norm(
div = input.mul(input)
if dim == 3:
div = div.unsqueeze(1)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
div = pad(div, (0, 0, size // 2, (size - 1) // 2))
div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
else:
sizes = input.size()
div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
div = div.view(sizes)
@ -3173,7 +3173,7 @@ def nll_loss(
input,
target,
weight,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
ignore_index,
)
@ -3320,7 +3320,7 @@ def gaussian_nll_loss(
var.clamp_(min=eps)
# Calculate the loss
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
if full:
loss += 0.5 * math.log(2 * math.pi)
@ -3496,7 +3496,7 @@ def cross_entropy(
input,
target,
weight,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
ignore_index,
label_smoothing,
@ -3561,7 +3561,7 @@ def binary_cross_entropy(
new_size = _infer_size(target.size(), weight.size())
weight = weight.expand(new_size)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)
@ -3692,14 +3692,14 @@ def smooth_l1_loss(
return torch._C._nn.l1_loss(
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
)
else:
return torch._C._nn.smooth_l1_loss(
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
beta,
)
@ -3761,7 +3761,7 @@ def huber_loss(
return torch._C._nn.huber_loss(
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
delta,
)
@ -3773,7 +3773,7 @@ def huber_loss(
unweighted_loss = torch._C._nn.huber_loss(
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum("none"),
delta,
)
@ -3864,7 +3864,7 @@ def l1_loss(
return torch._C._nn.l1_loss(
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
)
@ -3942,7 +3942,7 @@ def mse_loss(
return torch._C._nn.mse_loss(
expanded_input,
expanded_target,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
)
@ -4080,7 +4080,7 @@ def multilabel_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
reduction_enum = _Reduction.get_enum(reduction)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)
@ -4122,7 +4122,7 @@ def soft_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
reduction_enum = _Reduction.get_enum(reduction)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._nn.soft_margin_loss(input, target, reduction_enum)
@ -4292,7 +4292,7 @@ def multi_margin_loss(
p,
margin,
weight,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
reduction_enum,
)
@ -4439,7 +4439,7 @@ def upsample( # noqa: F811
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass
@ -4451,7 +4451,7 @@ def upsample( # noqa: F811
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass
@ -4554,7 +4554,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass
@ -4568,7 +4568,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass
@ -4582,7 +4582,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass
@ -4596,7 +4596,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor:
pass
@ -4771,7 +4771,7 @@ def interpolate( # noqa: F811
(
torch.floor(
(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
input.size(i + 2).float()
* torch.tensor(scale_factors[i], dtype=torch.float32)
).float()
@ -4796,28 +4796,28 @@ def interpolate( # noqa: F811
)
if input.dim() == 3 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest":
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)
if input.dim() == 3 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest-exact":
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)
if input.dim() == 3 and mode == "area":
assert output_size is not None
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return adaptive_avg_pool1d(input, output_size)
if input.dim() == 4 and mode == "area":
assert output_size is not None
@ -4830,7 +4830,7 @@ def interpolate( # noqa: F811
assert align_corners is not None
return torch._C._nn.upsample_linear1d(
input,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,
@ -4840,7 +4840,7 @@ def interpolate( # noqa: F811
if antialias:
return torch._C._nn._upsample_bilinear2d_aa(
input,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,
@ -4857,7 +4857,7 @@ def interpolate( # noqa: F811
)._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_bilinear2d(
input,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,
@ -4876,7 +4876,7 @@ def interpolate( # noqa: F811
)._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_trilinear3d(
input,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,
@ -4886,14 +4886,14 @@ def interpolate( # noqa: F811
if antialias:
return torch._C._nn._upsample_bicubic2d_aa(
input,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,
)
return torch._C._nn.upsample_bicubic2d(
input,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,
@ -4928,7 +4928,7 @@ def upsample_nearest( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor:
pass
@ -4938,7 +4938,7 @@ def upsample_nearest( # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor:
pass
@ -4980,7 +4980,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor:
pass
@ -4990,7 +4990,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor:
pass
@ -5000,7 +5000,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[list[float]] = None,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor:
pass
@ -5010,7 +5010,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[list[float]] = None,
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
) -> Tensor:
pass
@ -5817,7 +5817,7 @@ def _in_projection_packed(
.squeeze(-2)
.contiguous()
)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return proj[0], proj[1], proj[2]
else:
# encoder-decoder attention
@ -5836,7 +5836,7 @@ def _in_projection_packed(
.squeeze(-2)
.contiguous()
)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return (q_proj, kv_proj[0], kv_proj[1])
else:
w_q, w_k, w_v = w.chunk(3)
@ -5844,7 +5844,7 @@ def _in_projection_packed(
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
@ -6475,10 +6475,10 @@ def multi_head_attention_forward(
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
@ -6487,10 +6487,10 @@ def multi_head_attention_forward(
#
# reshape q, k, v for multihead attention and make them batch first
#
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
@ -6502,7 +6502,7 @@ def multi_head_attention_forward(
)
k = static_k
if static_v is None:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
@ -6518,20 +6518,20 @@ def multi_head_attention_forward(
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat(
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
dim=1,
)
v = torch.cat(
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
dim=1,
)
if attn_mask is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
key_padding_mask = pad(key_padding_mask, (0, 1))
# update source sequence length after adjustments
@ -6581,7 +6581,7 @@ def multi_head_attention_forward(
attn_output = torch.bmm(attn_output_weights, v)
attn_output = (
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
@ -6608,16 +6608,16 @@ def multi_head_attention_forward(
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)
q = q.view(bsz, num_heads, tgt_len, head_dim)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
k = k.view(bsz, num_heads, src_len, head_dim)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
v = v.view(bsz, num_heads, src_len, head_dim)
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)

View File

@ -500,7 +500,7 @@ def xavier_normal_(
def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
mode = mode.lower()
valid_modes = ["fan_in", "fan_out"]
if mode not in valid_modes:

View File

@ -6,7 +6,7 @@ from torch.autograd.function import Function
class SyncBatchNorm(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
self,
input,
@ -211,7 +211,7 @@ class SyncBatchNorm(Function):
class CrossMapLRN2d(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
ctx.size = size
ctx.alpha = alpha
@ -267,7 +267,7 @@ class CrossMapLRN2d(Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
input, output = ctx.saved_tensors
grad_input = grad_output.new()
@ -309,7 +309,7 @@ class CrossMapLRN2d(Function):
class BackwardHookFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, *args):
ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
return args

View File

@ -109,7 +109,7 @@ class Sequential(Module):
def __init__(self, *args: Module) -> None: ...
@overload
# pyrefly: ignore # inconsistent-overload
# pyrefly: ignore [inconsistent-overload]
def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
def __init__(self, *args):
@ -624,11 +624,11 @@ class ModuleDict(Module):
"ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(m).__name__
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if not len(m) == 2:
raise ValueError(
"ModuleDict update sequence element "
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
"#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
)
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]
@ -687,7 +687,7 @@ class ParameterList(Module):
def __getitem__(self, idx: int) -> Any: ...
@overload
# pyrefly: ignore # inconsistent-overload
# pyrefly: ignore [inconsistent-overload]
def __getitem__(self: T, idx: slice) -> T: ...
def __getitem__(self, idx):
@ -773,11 +773,11 @@ class ParameterList(Module):
size_str,
device_str,
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
child_lines.append(" (" + str(k) + "): " + parastr)
else:
child_lines.append(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
" (" + str(k) + "): Object of type: " + type(p).__name__
)
@ -985,11 +985,11 @@ class ParameterDict(Module):
"ParameterDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(p).__name__
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if not len(p) == 2:
raise ValueError(
"ParameterDict update sequence element "
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
"#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
)
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment
@ -1010,11 +1010,11 @@ class ParameterDict(Module):
size_str,
device_str,
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
child_lines.append(" (" + str(k) + "): " + parastr)
else:
child_lines.append(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
" (" + str(k) + "): Object of type: " + type(p).__name__
)
tmpstr = "\n".join(child_lines)

View File

@ -1514,7 +1514,7 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,
@ -1529,11 +1529,11 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1586,7 +1586,7 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,
@ -1601,11 +1601,11 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1659,7 +1659,7 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,
@ -1674,11 +1674,11 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1730,7 +1730,7 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,
@ -1746,11 +1746,11 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1802,7 +1802,7 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,
@ -1818,11 +1818,11 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:
@ -1874,7 +1874,7 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,
@ -1890,11 +1890,11 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
padding_mode,
**factory_kwargs,
)
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
# pyrefly: ignore # bad-override, bad-argument-type
# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)
def _get_num_spatial_dims(self) -> int:

View File

@ -172,9 +172,9 @@ class LazyModuleMixin:
def __init__(self: _LazyProtocol, *args, **kwargs):
# Mypy doesn't like this super call in a mixin
super().__init__(*args, **kwargs) # type: ignore[misc]
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self._initialize_hook = self.register_forward_pre_hook(
self._infer_parameters, with_kwargs=True
)

View File

@ -286,7 +286,7 @@ class LazyLinear(LazyModuleMixin, Linear):
"""
cls_to_become = Linear # type: ignore[assignment]
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
weight: UninitializedParameter
bias: UninitializedParameter # type: ignore[assignment]
@ -296,20 +296,20 @@ class LazyLinear(LazyModuleMixin, Linear):
factory_kwargs = {"device": device, "dtype": dtype}
# bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten.
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(0, 0, False)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_features = out_features
if bias:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)
def reset_parameters(self) -> None:
"""
Resets parameters based on their initialization used in ``__init__``.
"""
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if not self.has_uninitialized_params() and self.in_features != 0:
super().reset_parameters()
@ -317,7 +317,7 @@ class LazyLinear(LazyModuleMixin, Linear):
"""
Infers ``in_features`` based on ``input`` and initializes parameters.
"""
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if self.has_uninitialized_params():
with torch.no_grad():
self.in_features = input.shape[-1]

View File

@ -84,7 +84,7 @@ class CircularPad1d(_CircularPadNd):
[5., 6., 7., 4., 5., 6., 7., 4.]]])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
@ -145,7 +145,7 @@ class CircularPad2d(_CircularPadNd):
[8., 6., 7., 8., 6.]]]])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
@ -196,7 +196,7 @@ class CircularPad3d(_CircularPadNd):
>>> output = m(input)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:
@ -268,7 +268,7 @@ class ConstantPad1d(_ConstantPadNd):
[ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int]
def __init__(self, padding: _size_2_t, value: float) -> None:
@ -320,7 +320,7 @@ class ConstantPad2d(_ConstantPadNd):
"""
__constants__ = ["padding", "value"]
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t, value: float) -> None:
@ -361,7 +361,7 @@ class ConstantPad3d(_ConstantPadNd):
>>> output = m(input)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t, value: float) -> None:
@ -415,7 +415,7 @@ class ReflectionPad1d(_ReflectionPadNd):
[7., 6., 5., 4., 5., 6., 7., 6.]]])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
@ -469,7 +469,7 @@ class ReflectionPad2d(_ReflectionPadNd):
[7., 6., 7., 8., 7.]]]])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
@ -525,7 +525,7 @@ class ReflectionPad3d(_ReflectionPadNd):
[1., 0., 1., 0.]]]]])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:
@ -579,7 +579,7 @@ class ReplicationPad1d(_ReplicationPadNd):
[4., 4., 4., 4., 5., 6., 7., 7.]]])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int]
def __init__(self, padding: _size_2_t) -> None:
@ -633,7 +633,7 @@ class ReplicationPad2d(_ReplicationPadNd):
[6., 6., 7., 8., 8.]]]])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int]
def __init__(self, padding: _size_4_t) -> None:
@ -676,7 +676,7 @@ class ReplicationPad3d(_ReplicationPadNd):
>>> output = m(input)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int, int, int]
def __init__(self, padding: _size_6_t) -> None:

View File

@ -138,7 +138,7 @@ class Transformer(Module):
d_model,
eps=layer_norm_eps,
bias=bias,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
**factory_kwargs,
)
self.encoder = TransformerEncoder(
@ -164,7 +164,7 @@ class Transformer(Module):
d_model,
eps=layer_norm_eps,
bias=bias,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
**factory_kwargs,
)
self.decoder = TransformerDecoder(
@ -768,9 +768,9 @@ class TransformerEncoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
@ -1062,11 +1062,11 @@ class TransformerDecoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)
self.norm_first = norm_first
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)

View File

@ -36,7 +36,7 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
import torch
if isinstance(out_size, (int, torch.SymInt)):
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return out_size
if len(defaults) <= len(out_size):
raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")

View File

@ -43,7 +43,7 @@ def broadcast(tensor, devices=None, *, out=None):
devices = [_get_device_index(d) for d in devices]
return torch._C._broadcast(tensor, devices)
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return torch._C._broadcast_out(tensor, out)
@ -201,7 +201,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=
"""
tensor = _handle_complex(tensor)
if out is None:
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
devices = [_get_device_index(d) for d in devices]
return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
else:

View File

@ -160,7 +160,7 @@ class DataParallel(Module, Generic[T]):
self.module = module
self.device_ids = [_get_device_index(x, True) for x in device_ids]
self.output_device = _get_device_index(output_device, True)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.src_device_obj = torch.device(device_type, self.device_ids[0])
if device_type == "cuda":
@ -174,7 +174,7 @@ class DataParallel(Module, Generic[T]):
if not self.device_ids:
return self.module(*inputs, **kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(
@ -261,10 +261,10 @@ def data_parallel(
device_ids = [_get_device_index(x, True) for x in device_ids]
output_device = _get_device_index(output_device, True)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
src_device_obj = torch.device(device_type, device_ids[0])
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
for t in chain(module.parameters(), module.buffers()):
if t.device != src_device_obj:
raise RuntimeError(

View File

@ -241,7 +241,7 @@ class _BufferCommHook:
# is completed.
class _DDPSink(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, ddp_weakref, *inputs):
# set_materialize_grads(False) will ensure that None gradients stay as
# None and are not filled with zeros.
@ -692,7 +692,7 @@ class DistributedDataParallel(Module, Joinable):
elif process_group is None and device_mesh is None:
self.process_group = _get_default_group()
elif device_mesh is None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.process_group = process_group
else:
if device_mesh.ndim != 1:
@ -780,13 +780,13 @@ class DistributedDataParallel(Module, Joinable):
self.device_ids = None
self.output_device = None
else:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.device_ids = [_get_device_index(x, True) for x in device_ids]
if output_device is None:
output_device = device_ids[0]
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.output_device = _get_device_index(output_device, True)
self.static_graph = False
@ -936,7 +936,7 @@ class DistributedDataParallel(Module, Joinable):
# enabled.
self._accum_grad_hooks: list[RemovableHandle] = []
if self._use_python_reducer:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
torch._inductor.config._fuse_ddp_communication = True
torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
# Directly adding this to the trace rule will disturb the users

View File

@ -56,16 +56,16 @@ def scatter(inputs, target_gpus, dim=0):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if _is_namedtuple(obj):
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return [list(i) for i in zip(*map(scatter_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
return [obj for _ in target_gpus]
@ -127,12 +127,12 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0)
if isinstance(out, dict):
if not all(len(out) == len(d) for d in outputs):
raise ValueError("All dicts must have the same number of keys")
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
if _is_namedtuple(out):
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return type(out)._make(map(gather_map, zip(*outputs)))
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return type(out)(map(gather_map, zip(*outputs)))
# Recursive function calls like this create reference cycles.

View File

@ -81,7 +81,7 @@ class Parameter(torch.Tensor, metaclass=_ParameterMeta):
memo[id(self)] = result
return result
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def __repr__(self):
return "Parameter containing:\n" + super().__repr__()
@ -144,7 +144,7 @@ class UninitializedTensorMixin:
if dtype is None:
dtype = self.data.dtype
self.data = torch.empty(shape, device=device, dtype=dtype)
# pyrefly: ignore # bad-override, missing-attribute
# pyrefly: ignore [bad-override, missing-attribute]
self.__class__ = self.cls_to_become
@property
@ -168,7 +168,7 @@ class UninitializedTensorMixin:
def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return (self.__class__, (self.requires_grad,))
@classmethod
@ -178,7 +178,7 @@ class UninitializedTensorMixin:
if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
if kwargs is None:
kwargs = {}
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return super().__torch_function__(func, types, args, kwargs)
raise ValueError(
f"Attempted to use an uninitialized parameter in {func}. "
@ -220,7 +220,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
def __new__(cls, requires_grad=True, device=None, dtype=None) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return torch.Tensor._make_subclass(cls, data, requires_grad)
def __deepcopy__(self, memo):
@ -266,9 +266,9 @@ class Buffer(torch.Tensor, metaclass=_BufferMeta):
data = torch.empty(0)
t = data.detach().requires_grad_(data.requires_grad)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
t.persistent = persistent
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
t._is_buffer = True
return t
@ -299,9 +299,9 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
ret = torch.Tensor._make_subclass(cls, data, requires_grad)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
ret.persistent = persistent
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
ret._is_buffer = True
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return ret

View File

@ -24,7 +24,7 @@ from .expanded_weights_utils import forward_helper
@implements_per_sample_grads(F.conv3d)
class ConvPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
ctx: Any,
kwarg_names: list[str],
@ -57,7 +57,7 @@ class ConvPerSampleGrad(torch.autograd.Function):
f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}"
)
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
ctx.conv_fn = conv_fn
ctx.batch_size = orig_input.shape[0]

View File

@ -237,7 +237,7 @@ def conv_unfold_weight_grad_sample(
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
# rearrange the above tensor and extract diagonals.
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
weight_grad_sample = weight_grad_sample.view(
n,
groups,

View File

@ -14,7 +14,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.embedding)
class EmbeddingPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(
ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
) -> torch.Tensor:
@ -35,7 +35,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:

View File

@ -131,7 +131,7 @@ class ExpandedWeight(torch.Tensor):
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
decomp_opts = expanded_weights_rnn_decomps[func]
use_input_variant = isinstance(
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
args[2],
list,
) # data variant uses a list here

View File

@ -8,7 +8,7 @@ from .expanded_weights_impl import ExpandedWeight
def is_batch_first(expanded_args_and_kwargs):
batch_first = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for arg in expanded_args_and_kwargs:
if not isinstance(arg, ExpandedWeight):
continue

View File

@ -18,7 +18,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.group_norm)
class GroupNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
@ -47,7 +47,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
input, num_groups = ctx.input, ctx.num_groups
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps
@ -97,7 +97,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
weight,
lambda _: torch.einsum(
"ni...->ni",
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
F.group_norm(input, num_groups, eps=eps) * grad_output,
),
)

View File

@ -17,7 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.instance_norm)
class InstanceNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
instance_norm = partial(torch.instance_norm, cudnn_enabled=True)
expanded_args, expanded_kwargs = standard_kwargs(
@ -37,7 +37,7 @@ class InstanceNormPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps

View File

@ -17,7 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.layer_norm)
class LayerNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs
@ -43,7 +43,7 @@ class LayerNormPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
def weight_per_sample_grad(weight):
return sum_over_all_but_batch_and_last_n(

View File

@ -16,7 +16,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.linear)
class LinearPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, _, __, *expanded_args_and_kwargs):
if len(expanded_args_and_kwargs[0].shape) <= 1:
raise RuntimeError(
@ -36,7 +36,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
return output
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
input, weight = ctx.args
bias = ctx.kwargs["bias"]

View File

@ -77,7 +77,7 @@ def swap_tensor(
setattr(module, name, tensor)
elif hasattr(module, name):
delattr(module, name)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return orig_tensor

View File

@ -41,11 +41,11 @@ def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:
def _no_grad_wrapper(*args, **kwargs):
with torch.no_grad():
# pyrefly: ignore # invalid-param-spec
# pyrefly: ignore [invalid-param-spec]
return func(*args, **kwargs)
functools.update_wrapper(_no_grad_wrapper, func)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return _no_grad_wrapper
@ -283,7 +283,7 @@ def clip_grad_value_(
clip_value = float(clip_value)
grads = [p.grad for p in parameters if p.grad is not None]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
grouped_grads = _group_tensors_by_device_and_dtype([grads])
for (device, _), ([grads], _) in grouped_grads.items():

View File

@ -84,7 +84,7 @@ def convert_conv2d_weight_memory_format(
)
for child in module.children():
convert_conv2d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return module
@ -164,7 +164,7 @@ def convert_conv3d_weight_memory_format(
)
for child in module.children():
convert_conv3d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return module

View File

@ -98,7 +98,7 @@ class _Orthogonal(Module):
)
# Q is now orthogonal (or unitary) of size (..., n, n)
if n != k:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
Q = Q[..., :k]
# Q is now the size of the X (albeit perhaps transposed)
else:
@ -111,10 +111,10 @@ class _Orthogonal(Module):
Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)
if hasattr(self, "base"):
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
Q = self.base @ Q
if transposed:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
Q = Q.mT
return Q # type: ignore[possibly-undefined]

View File

@ -179,28 +179,28 @@ class ParametrizationList(ModuleList):
# Register the tensor(s)
if self.is_tensor:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if original.dtype != new.dtype:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the dtype.\n"
f"original.dtype: {original.dtype}\n"
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
f"right_inverse(original).dtype: {new.dtype}"
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if original.device != new.device:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the device.\n"
f"original.device: {original.device}\n"
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
f"right_inverse(original).device: {new.device}"
)
# Set the original to original so that the user does not need to re-register the parameter
# manually in the optimiser
with torch.no_grad():
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_maybe_set(original, new)
_register_parameter_or_buffer(self, "original", original)
else:
@ -401,7 +401,7 @@ def _inject_property(module: Module, tensor_name: str) -> None:
if torch.jit.is_scripting():
raise RuntimeError("Parametrization is not working with scripting.")
parametrization = self.parametrizations[tensor_name]
# pyrefly: ignore # redundant-condition
# pyrefly: ignore [redundant-condition]
if _cache_enabled:
if torch.jit.is_scripting():
# Scripting
@ -701,7 +701,7 @@ def remove_parametrizations(
# Fetch the original tensor
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
parametrizations = module.parametrizations[tensor_name]
# pyrefly: ignore # invalid-argument
# pyrefly: ignore [invalid-argument]
if parametrizations.is_tensor:
original = parametrizations.original
assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"

View File

@ -274,11 +274,11 @@ class PruningContainer(BasePruningMethod):
if not isinstance(args, Iterable): # only 1 item
self._tensor_name = args._tensor_name
self.add_pruning_method(args)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
elif len(args) == 1: # only 1 item in a tuple
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
self._tensor_name = args[0]._tensor_name
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
self.add_pruning_method(args[0])
else: # manual construction from list or other iterable (or no args)
for method in args:
@ -1100,7 +1100,7 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw
# flatten importance scores to consider them all at once in global pruning
relevant_importance_scores = torch.nn.utils.parameters_to_vector(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
[
importance_scores.get((module, name), getattr(module, name))
for (module, name) in parameters

View File

@ -332,7 +332,7 @@ def spectral_norm(
else:
dim = 0
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return module

View File

@ -117,7 +117,7 @@ def context_decorator(ctx, func):
@functools.wraps(func)
def decorate_context(*args, **kwargs):
# pyrefly: ignore # bad-context-manager
# pyrefly: ignore [bad-context-manager]
with ctx_factory():
return func(*args, **kwargs)

View File

@ -1074,7 +1074,7 @@ def key_get(obj: Any, kp: KeyPath) -> Any:
with python_pytree._NODE_REGISTRY_LOCK:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
python_pytree._cxx_pytree_imported = True
args, kwargs = (), {} # type: ignore[var-annotated]
for args, kwargs in python_pytree._cxx_pytree_pending_imports:

View File

@ -262,7 +262,7 @@ class DebugMode(TorchDispatchMode):
self.module_tracker.__enter__() # type: ignore[attribute, union-attr]
return self
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def __exit__(self, *args):
super().__exit__(*args)
if self.record_nn_module:


@ -60,7 +60,7 @@ def _device_constructors():
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
def __init__(self, device):
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.device = torch.device(device)
def __enter__(self):


@ -35,12 +35,12 @@ def cache_method(
if not (cache := getattr(self, cache_name, None)):
cache = {}
setattr(self, cache_name, cache)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
cached_value = cache.get(args, _cache_sentinel)
if cached_value is not _cache_sentinel:
return cached_value
value = f(self, *args, **kwargs)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
cache[args] = value
return value


@ -708,7 +708,7 @@ class structseq(tuple[_T_co, ...]):
def __new__(
cls: type[Self],
sequence: Iterable[_T_co],
# pyrefly: ignore # bad-function-definition
# pyrefly: ignore [bad-function-definition]
dict: dict[str, Any] = ...,
) -> Self:
raise NotImplementedError
@ -755,7 +755,7 @@ def _tuple_flatten_with_keys(
d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _tuple_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -769,7 +769,7 @@ def _list_flatten(d: list[T]) -> tuple[list[T], Context]:
def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _list_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -785,7 +785,7 @@ def _dict_flatten_with_keys(
d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _dict_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -801,7 +801,7 @@ def _namedtuple_flatten_with_keys(
d: NamedTuple,
) -> tuple[list[tuple[KeyEntry, Any]], Context]:
values, context = _namedtuple_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return (
[(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
context,
@ -851,7 +851,7 @@ def _ordereddict_flatten_with_keys(
d: OrderedDict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _ordereddict_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(MappingKey(k), v) for k, v in zip(context, values)], context
@ -876,7 +876,7 @@ def _defaultdict_flatten_with_keys(
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _defaultdict_flatten(d)
_, dict_context = context
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context
@ -925,7 +925,7 @@ def _deque_flatten_with_keys(
d: deque[T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _deque_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -1827,7 +1827,7 @@ def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
for attr in classname.split("."):
enum_cls = getattr(enum_cls, attr)
enum_cls = cast(type[Enum], enum_cls)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return enum_cls[obj["name"]]
return obj


@ -305,7 +305,7 @@ def strobelight(
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return profiler.profile(work_function, *args, **kwargs)
return wrapper_function


@ -105,7 +105,7 @@ def _keep_float(
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
@functools.wraps(f)
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
r: Union[_T, sympy.Float] = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float
@ -113,7 +113,7 @@ def _keep_float(
r = sympy.Float(float(r))
return r
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return inner
@ -200,12 +200,12 @@ class FloorDiv(sympy.Function):
@property
def base(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.args[0]
@property
def divisor(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.args[1]
def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
@ -374,7 +374,7 @@ class ModularIndexing(sympy.Function):
return None
def _eval_is_nonnegative(self) -> Optional[bool]:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
p, q = self.args[:2]
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]
@ -455,7 +455,7 @@ class PythonMod(sympy.Function):
# - floor(p / q) = 0
# - p % q = p - floor(p / q) * q = p
less = p < q
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if less.is_Boolean and bool(less) and r.is_positive:
return p
@ -472,11 +472,11 @@ class PythonMod(sympy.Function):
return True if self.args[1].is_negative else None # type: ignore[attr-defined]
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
abs_q = str(q) if self.args[1].is_positive else f"abs({q})"
return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}"
@ -559,7 +559,7 @@ class CeilToInt(sympy.Function):
return sympy.Integer(math.ceil(float(number)))
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5)
return f"ceil({number})"
@ -830,7 +830,7 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
if not cond:
return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
if isinstance(ai, cls):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
return a
@ -1008,7 +1008,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return fuzzy_and(a.is_negative for a in self.args)
@ -1027,7 +1027,7 @@ class Min(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]
def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return fuzzy_or(a.is_negative for a in self.args)
@ -1165,9 +1165,9 @@ class IntTrueDiv(sympy.Function):
return sympy.Float(int(base) / int(divisor))
def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
return f"((int){base}/(int){divisor})"
@ -1331,16 +1331,16 @@ class Identity(sympy.Function):
precedence = 10
def __repr__(self): # type: ignore[override]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return f"Identity({self.args[0]})"
def _sympystr(self, printer):
"""Controls how sympy's StrPrinter prints this"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return f"({printer.doprint(self.args[0])})"
def _eval_is_real(self):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.args[0].is_real
def _eval_is_integer(self):
@ -1348,15 +1348,15 @@ class Identity(sympy.Function):
def _eval_expand_identity(self, **hints):
# Removes the identity op.
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.args[0]
def __int__(self) -> int:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return int(self.args[0])
def __float__(self) -> float:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return float(self.args[0])


@ -9,7 +9,7 @@ from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class IntInfinity(Number, metaclass=Singleton):
r"""Positive integer infinite quantity.
@ -204,7 +204,7 @@ class IntInfinity(Number, metaclass=Singleton):
int_oo = S.IntInfinity
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class NegativeIntInfinity(Number, metaclass=Singleton):
"""Negative integer infinite quantity.


@ -66,7 +66,7 @@ class ExprPrinter(StrPrinter):
# NB: this pow by natural, you should never have used builtin sympy.pow
# for FloatPow, and a symbolic exponent should be PowByNatural. These
# means exp is guaranteed to be integer.
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _print_Pow(self, expr: sympy.Expr) -> str:
base, exp = expr.args
if exp != int(exp):


@ -176,7 +176,7 @@ class ReferenceAnalysis:
@staticmethod
def pow(a, b):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return _keep_float(FloatPow)(a, b)
@staticmethod


@ -126,9 +126,9 @@ AllFn2 = Union[ExprFn2, BoolFn2]
class ValueRanges(Generic[_T]):
if TYPE_CHECKING:
# ruff doesn't understand circular references but mypy does
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
ExprVR = ValueRanges[sympy.Expr] # noqa: F821
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
BoolVR = ValueRanges[SympyBoolean] # noqa: F821
AllVR = Union[ExprVR, BoolVR]
@ -484,7 +484,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def to_dtype(a, dtype, src_dtype=None):
if dtype == torch.float64:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ValueRanges.increasing_map(a, ToFloat)
elif dtype == torch.bool:
return ValueRanges.unknown_bool()
@ -494,7 +494,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def trunc_to_int(a, dtype):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ValueRanges.increasing_map(a, TruncToInt)
@staticmethod
@ -652,7 +652,7 @@ class SymPyValueRangeAnalysis:
return ValueRanges.coordinatewise_monotone_map(
a,
b,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_keep_float(IntTrueDiv),
)
@ -668,7 +668,7 @@ class SymPyValueRangeAnalysis:
return ValueRanges.coordinatewise_monotone_map(
a,
b,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_keep_float(FloatTrueDiv),
)
@ -748,7 +748,7 @@ class SymPyValueRangeAnalysis:
# We should know that b >= 0 but we may have forgotten this fact due
# to replacements, so don't assert it, but DO clamp it to prevent
# degenerate problems
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return ValueRanges.coordinatewise_increasing_map(
a, b & ValueRanges(0, int_oo), PowByNatural
)
@ -915,7 +915,7 @@ class SymPyValueRangeAnalysis:
@classmethod
def round_to_int(cls, number, dtype):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ValueRanges.increasing_map(number, RoundToInt)
# It's used in some models on symints
@ -1032,7 +1032,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def trunc(x):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ValueRanges.increasing_map(x, TruncToFloat)


@ -63,7 +63,7 @@ def generate_coo_data(size, sparse_dim, nnz, dtype, device):
indices = torch.rand(sparse_dim, nnz, device=device)
indices.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(indices))
indices = indices.to(torch.long)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
values = torch.rand([nnz, ], dtype=dtype, device=device)
return indices, values


@ -170,7 +170,7 @@ if HAS_TABULATE:
_disable_tensor_cores()
table.append([
("Training" if optimizer else "Inference"),
# pyrefly: ignore # redundant-condition
# pyrefly: ignore [redundant-condition]
backend if backend else "-",
mode if mode is not None else "-",
f"{compilation_time} ms " if compilation_time else "-",
@ -191,5 +191,5 @@ if HAS_TABULATE:
])
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
return tabulate(table, headers=field_names, tablefmt="github")


@ -35,7 +35,7 @@ def _get_build_root() -> str:
global _BUILD_ROOT
if _BUILD_ROOT is None:
_BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
# pyrefly: ignore # missing-argument
# pyrefly: ignore [missing-argument]
atexit.register(shutil.rmtree, _BUILD_ROOT)
return _BUILD_ROOT


@ -92,7 +92,7 @@ class FuzzedSparseTensor(FuzzedTensor):
return x
def _make_tensor(self, params, state):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
size, _, _ = self._get_size_and_steps(params)
density = params['density']
nnz = math.ceil(sum(size) * density)
@ -102,10 +102,10 @@ class FuzzedSparseTensor(FuzzedTensor):
is_coalesced = params['coalesced']
sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size)
sparse_dim = min(sparse_dim, len(size))
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if self._cuda:
tensor = tensor.cuda()
sparse_dim = tensor.sparse_dim()
@ -121,7 +121,7 @@ class FuzzedSparseTensor(FuzzedTensor):
"sparse_dim": sparse_dim,
"dense_dim": dense_dim,
"is_hybrid": is_hybrid,
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
"dtype": str(self._dtype),
}
return tensor, properties


@ -234,7 +234,7 @@ class Timer:
setup = textwrap.dedent(setup)
setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()
# pyrefly: ignore # bad-instantiation
# pyrefly: ignore [bad-instantiation]
self._timer = self._timer_cls(
stmt=stmt,
setup=setup,


@ -449,13 +449,13 @@ class GlobalsBridge:
load_lines = []
for name, wrapped_value in self._globals.items():
if wrapped_value.setup is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
load_lines.append(textwrap.dedent(wrapped_value.setup))
if wrapped_value.serialization == Serialization.PICKLE:
path = os.path.join(self._data_dir, f"{name}.pkl")
load_lines.append(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
f"with open({repr(path)}, 'rb') as f:\n {name} = pickle.load(f)")
with open(path, "wb") as f:
pickle.dump(wrapped_value.value, f)
@ -465,13 +465,13 @@ class GlobalsBridge:
# TODO: Figure out if we can use torch.serialization.add_safe_globals here
# Using weights_only=False after the change in
# https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)")
torch.save(wrapped_value.value, path)
elif wrapped_value.serialization == Serialization.TORCH_JIT:
path = os.path.join(self._data_dir, f"{name}.pt")
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
load_lines.append(f"{name} = torch.jit.load({repr(path)})")
with open(path, "wb") as f:
torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call]


@ -222,7 +222,7 @@ def _get_autocast_kwargs(device_type="cuda"):
class CheckpointFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function
@ -785,7 +785,7 @@ class _Holder:
class _NoopSaveInputs(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(*args):
return torch.empty((0,))
@ -1008,7 +1008,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
def logging_mode():
with LoggingTensorMode(), \
capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.logs, self.tbs = logs_and_tb
yield logs_and_tb
return logging_mode()


@ -788,7 +788,7 @@ class BuildExtension(build_ext):
# Use absolute path for output_dir so that the object file paths
# (`objects`) get generated with absolute paths.
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
output_dir = os.path.abspath(output_dir)
# See Note [Absolute include_dirs]
@ -979,7 +979,7 @@ class BuildExtension(build_ext):
is_standalone=False):
if not self.compiler.initialized:
self.compiler.initialize()
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
output_dir = os.path.abspath(output_dir)
# Note [Absolute include_dirs]
@ -2573,7 +2573,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]:
def _get_vc_env(vc_arch: str) -> dict[str, str]:
try:
from setuptools import distutils # type: ignore[attr-defined]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return distutils._msvccompiler._get_vc_env(vc_arch)
except AttributeError:
try:


@ -204,7 +204,7 @@ def collate(
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError("each element in list of batch should be of equal size")
transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.


@ -70,7 +70,7 @@ def pin_memory(data, device=None):
return clone
else:
return type(data)(
# pyrefly: ignore # bad-argument-count
# pyrefly: ignore [bad-argument-count]
{k: pin_memory(sample, device) for k, sample in data.items()}
) # type: ignore[call-arg]
except TypeError:


@ -265,7 +265,7 @@ class _DataPipeType:
# Default type for DataPipe without annotation
_T_co = TypeVar("_T_co", covariant=True)
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])
@ -284,7 +284,7 @@ class _DataPipeMeta(GenericMeta):
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]
# TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
# pyrefly: ignore # no-access
# pyrefly: ignore [no-access]
cls.__origin__ = None
if "type" in namespace:
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]


@ -80,7 +80,7 @@ class Capture:
def _ops_str(self):
res = ""
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for op in self.ctx["operations"]:
if len(res) > 0:
res += "\n"
@ -90,7 +90,7 @@ class Capture:
def __getstate__(self):
# TODO(VitalyFedyunin): Currently can't pickle (why?)
self.ctx["schema_df"] = None
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for var in self.ctx["variables"]:
var.calculated_value = None
state = {}
@ -114,13 +114,13 @@ class Capture:
return CaptureGetItem(self, key, ctx=self.ctx)
def __setitem__(self, key, value):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))
def __add__(self, add_val):
res = CaptureAdd(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)
@ -129,7 +129,7 @@ class Capture:
def __sub__(self, add_val):
res = CaptureSub(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)
@ -139,19 +139,19 @@ class Capture:
res = CaptureMul(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(t)
return var
def _is_context_empty(self):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0
def apply_ops_2(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for op in self.ctx["operations"]:
op.execute()
@ -184,7 +184,7 @@ class Capture:
res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
var = CaptureVariable(None, ctx=self.ctx)
t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(t)
return var
@ -283,9 +283,9 @@ class CaptureVariable(Capture):
def apply_ops(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for op in self.ctx["operations"]:
op.execute()
return self.calculated_value
@ -385,7 +385,7 @@ def get_val(capture):
class CaptureInitial(CaptureVariable):
def __init__(self, schema_df=None):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
new_ctx: dict[str, list[Any]] = {
"operations": [],
"variables": [],
@ -401,7 +401,7 @@ class CaptureDataFrame(CaptureInitial):
class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
def as_datapipe(self):
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)
def raw_iterator(self):


@ -92,7 +92,7 @@ class FilterDataFramesPipe(DFIterDataPipe):
size = None
all_buffer = []
filter_res = []
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for df in self.source_datapipe:
if size is None:
size = len(df.index)


@ -135,7 +135,7 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
_fast_forward_iterator: Optional[Iterator] = None
def __iter__(self) -> Iterator[_T_co]:
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return self
def __getattr__(self, attribute_name):
@ -380,7 +380,7 @@ class _DataPipeSerializationWrapper:
value = pickle.dumps(self._datapipe)
except Exception:
if HAS_DILL:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
value = dill.dumps(self._datapipe)
use_dill = True
else:
@ -390,7 +390,7 @@ class _DataPipeSerializationWrapper:
def __setstate__(self, state):
value, use_dill = state
if use_dill:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self._datapipe = dill.loads(value)
else:
self._datapipe = pickle.loads(value)
@ -407,7 +407,7 @@ class _DataPipeSerializationWrapper:
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
def __init__(self, datapipe: IterDataPipe[_T_co]):
super().__init__(datapipe)
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self._datapipe_iter: Optional[Iterator[_T_co]] = None
def __iter__(self) -> "_IterDataPipeSerializationWrapper":


@ -118,7 +118,7 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
for idx in sorted(self.input_col[1:], reverse=True):
del data[idx]
else:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
data[self.input_col] = res
else:
if self.output_col == -1:


@ -43,7 +43,7 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
"Sampler class requires input datapipe implemented `__len__`"
)
super().__init__()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.datapipe = datapipe
self.sampler_args = () if sampler_args is None else sampler_args
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs


@ -59,7 +59,7 @@ class ConcaterIterDataPipe(IterDataPipe):
def __len__(self) -> int:
if all(isinstance(dp, Sized) for dp in self.datapipes):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return sum(len(dp) for dp in self.datapipes)
else:
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
@ -180,7 +180,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
self._child_stop: list[bool] = [True for _ in range(num_instances)]
def __len__(self):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.main_datapipe)
def get_next_element_by_instance(self, instance_id: int):
@ -240,7 +240,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
return self.end_ptr is not None and all(self._child_stop)
def get_length_by_instance(self, instance_id: int) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.main_datapipe)
def reset(self) -> None:
@ -327,7 +327,7 @@ class _ChildDataPipe(IterDataPipe):
if not isinstance(main_datapipe, _ContainerTemplate):
raise AssertionError("main_datapipe must implement _ContainerTemplate")
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.main_datapipe: IterDataPipe = main_datapipe
self.instance_id = instance_id
@ -454,7 +454,7 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
drop_none: bool,
buffer_size: int,
):
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.main_datapipe = datapipe
self._datapipe_iterator: Optional[Iterator[Any]] = None
self.num_instances = num_instances
@ -466,9 +466,9 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
UserWarning,
)
self.current_buffer_usage = 0
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.child_buffers: list[deque[_T_co]] = [deque() for _ in range(num_instances)]
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.classifier_fn = classifier_fn
self.drop_none = drop_none
self.main_datapipe_exhausted = False
@ -706,7 +706,7 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
def __len__(self) -> int:
if all(isinstance(dp, Sized) for dp in self.datapipes):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return min(len(dp) for dp in self.datapipes)
else:
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


@ -204,9 +204,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
drop_remaining: bool = False,
):
_check_unpickable_fn(group_key_fn)
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.datapipe = datapipe
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.group_key_fn = group_key_fn
self.keep_key = keep_key
@ -218,14 +218,14 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
if group_size is not None and buffer_size is not None:
if not (0 < group_size <= buffer_size):
raise AssertionError("group_size must be > 0 and <= buffer_size")
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.guaranteed_group_size = group_size
if guaranteed_group_size is not None:
if group_size is None or not (0 < guaranteed_group_size <= group_size):
raise AssertionError(
"guaranteed_group_size must be > 0 and <= group_size and group_size must be set"
)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.guaranteed_group_size = guaranteed_group_size
self.drop_remaining = drop_remaining
self.wrapper_class = DataChunk


@ -60,7 +60,7 @@ class MapperMapDataPipe(MapDataPipe[_T_co]):
self.fn = fn # type: ignore[assignment]
def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.datapipe)
def __getitem__(self, index) -> _T_co:


@ -64,7 +64,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]):
) -> None:
super().__init__()
self.datapipe = datapipe
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.indices = list(range(len(datapipe))) if indices is None else indices
self._enabled = True
self._seed = None
@ -96,7 +96,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]):
self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))
def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.datapipe)
def __getstate__(self):


@ -49,16 +49,16 @@ class ConcaterMapDataPipe(MapDataPipe):
def __getitem__(self, index) -> _T_co: # type: ignore[type-var]
offset = 0
for dp in self.datapipes:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if index - offset < len(dp):
return dp[index - offset]
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
offset += len(dp)
raise IndexError(f"Index {index} is out of range.")
def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return sum(len(dp) for dp in self.datapipes)
@ -105,5 +105,5 @@ class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]):
return tuple(res)
def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return min(len(dp) for dp in self.datapipes)


@ -196,7 +196,7 @@ def get_file_pathnames_from_root(
if match_masks(fname, masks):
yield path
else:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for path, dirs, files in os.walk(root, onerror=onerror):
if abspath:
path = os.path.abspath(path)


@ -43,7 +43,7 @@ def _simple_graph_snapshot_restoration(
# simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`,
# the first reset will not actually reset.
datapipe.reset() # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`.
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
apply_random_seed(datapipe, rng)
remainder = n_iterations


@ -137,7 +137,7 @@ class DistributedSampler(Sampler[_T_co]):
f"Number of subsampled indices ({len(indices)}) does not match num_samples ({self.num_samples})"
)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return iter(indices)
def __len__(self) -> int:


@ -72,7 +72,7 @@ def _list_connected_datapipes(
p.dump(scan_obj)
except (pickle.PickleError, AttributeError, TypeError):
if dill_available():
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
d.dump(scan_obj)
else:
raise


@ -31,7 +31,7 @@ class FileBaton:
True if the file could be created, else False.
"""
try:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
return True
except FileExistsError:


@ -149,7 +149,7 @@ def conv_flop_count(
@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward])
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
"""Count flops for convolution."""
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)


@ -145,7 +145,7 @@ class BackwardHook:
res = out
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.grad_outputs = None
return self._unpack_none(self.input_tensors_index, res)


@ -237,7 +237,7 @@ def get_model_info(
with zipfile.ZipFile(path_or_file) as zf:
path_prefix = None
zip_files = []
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for zi in zf.infolist():
prefix = re.sub("/.*", "", zi.filename)
if path_prefix is None:
@ -392,12 +392,12 @@ def get_inline_skeleton():
import importlib.resources
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
skeleton = importlib.resources.read_text(__package__, "skeleton.html")
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
js_code = importlib.resources.read_text(__package__, "code.js")
for js_module in ["preact", "htm"]:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs")
js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)


@ -31,7 +31,7 @@ def make_np(x: torch.Tensor) -> np.ndarray:
def _prepare_pytorch(x: torch.Tensor) -> np.ndarray:
if x.dtype == torch.bfloat16:
x = x.to(torch.float16)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
x = x.detach().cpu().numpy()
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return x


@ -188,7 +188,7 @@ class GraphPy:
for key, node in self.nodes_io.items():
if type(node) is NodeBase:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if hasattr(node, "input_or_output"):
self.unique_name_to_scoped_name[key] = (
@ -199,7 +199,7 @@ class GraphPy:
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if node.scope == "" and self.shallowest_scope_name:
self.unique_name_to_scoped_name[node.debugName] = (
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
self.shallowest_scope_name + "/" + node.debugName
)


@ -57,14 +57,14 @@ def _prepare_video(V):
return num != 0 and ((num & (num - 1)) == 0)
# pad to nearest power of 2, all at once
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
if not is_power2(V.shape[0]):
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0])
V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0)
n_rows = 2 ** ((b.bit_length() - 1) // 2)
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
n_cols = V.shape[0] // n_rows
V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w))


@ -498,7 +498,7 @@ def make_histogram(values, bins, max_bins=None):
subsampling = num_bins // max_bins
subsampling_remainder = num_bins % subsampling
if subsampling_remainder != 0:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
counts = np.pad(
counts,
pad_width=[[0, subsampling - subsampling_remainder]],
@ -838,21 +838,21 @@ def compute_curve(labels, predictions, num_thresholds=None, weights=None):
weights = 1.0
# Compute bins of true positives and false positives.
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
float_labels = labels.astype(np.float64)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
histogram_range = (0, num_thresholds - 1)
tp_buckets, _ = np.histogram(
bucket_indices,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
bins=num_thresholds,
range=histogram_range,
weights=float_labels * weights,
)
fp_buckets, _ = np.histogram(
bucket_indices,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
bins=num_thresholds,
range=histogram_range,
weights=(1.0 - float_labels) * weights,


@ -254,9 +254,9 @@ class SummaryWriter:
buckets = []
neg_buckets = []
while v < 1e20:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
buckets.append(v)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
neg_buckets.append(-v)
v *= 1.1
self.default_bins = neg_buckets[::-1] + [0] + buckets
@ -264,19 +264,19 @@ class SummaryWriter:
def _get_file_writer(self):
"""Return the default FileWriter instance. Recreates it if closed."""
if self.all_writers is None or self.file_writer is None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.file_writer = FileWriter(
self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
)
# pyrefly: ignore # bad-assignment, missing-attribute
# pyrefly: ignore [bad-assignment, missing-attribute]
self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
if self.purge_step is not None:
most_recent_step = self.purge_step
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.file_writer.add_event(
Event(step=most_recent_step, file_version="brain.Event:2")
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.file_writer.add_event(
Event(
step=most_recent_step,
@ -1207,7 +1207,7 @@ class SummaryWriter:
for writer in self.all_writers.values():
writer.flush()
writer.close()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.file_writer = self.all_writers = None
def __enter__(self):


@ -461,7 +461,7 @@ def to_html(nodes):
if n.context is None:
continue
s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}'))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
listeners.append(s)
dot = to_dot(nodes)
return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners))