Fix error suppression syntax in utils and nn (#166242)
Fixes the syntax of pyrefly: ignore suppressions so that each one ignores only a specific error category. No functional changes. Verified with pyrefly check and lintrunner.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166242
Approved by: https://github.com/oulgen, https://github.com/cyyever
This commit is contained in:
parent 5121499f6b
commit 84b14f3a10
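The change is purely a suppression-syntax swap: the error code that used to sit in a trailing comment after the ignore is moved into brackets so that, per the commit message, only that category is suppressed on the annotated line. A minimal sketch of the two styles (the functions below are invented for illustration; only the two comment forms come from this diff):

def old_style(flag: bool) -> int:
    # pyrefly: ignore # bad-return
    return 1 if flag else None  # annotated as int, but may return None


def new_style(flag: bool) -> int:
    # pyrefly: ignore [bad-return]
    return 1 if flag else None


if __name__ == "__main__":
    # Both run identically; only the checker-facing comment differs.
    print(old_style(True), new_style(False))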
@@ -153,7 +153,7 @@ class CausalBias(torch.Tensor):
diagonal=diagonal_offset,
)

-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
def _materialize(self, device: Optional[torch.device] = None) -> torch.Tensor:
"""
Materializes the causal bias into a tensor form.
@@ -84,7 +84,7 @@ _score_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor, Tensor], Tensor
_mask_mod_signature = Callable[[Tensor, Tensor, Tensor, Tensor], Tensor]


-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
class FlexKernelOptions(TypedDict, total=False):
"""Options for controlling the behavior of FlexAttention kernels.

@@ -128,97 +128,97 @@ class FlexKernelOptions(TypedDict, total=False):
"""

# Performance tuning options
-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
num_warps: NotRequired[int]
"""Number of warps to use in the CUDA kernel. Higher values may improve performance
but increase register pressure. Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
num_stages: NotRequired[int]
"""Number of pipeline stages in the CUDA kernel. Higher values may improve performance
but increase shared memory usage. Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_M: NotRequired[int]
"""Thread block size for the sequence length dimension of Q in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_N: NotRequired[int]
"""Thread block size for the sequence length dimension of K/V in forward pass.
Must be a power of 2. Common values: 16, 32, 64, 128. Default is determined by autotuning."""

# Backward-specific block sizes (when prefixed with 'bwd_')
-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_M1: NotRequired[int]
"""Thread block size for Q dimension in backward pass. Use as 'bwd_BLOCK_M1'.
Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_N1: NotRequired[int]
"""Thread block size for K/V dimension in backward pass. Use as 'bwd_BLOCK_N1'.
Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_M2: NotRequired[int]
"""Thread block size for second Q dimension in backward pass. Use as 'bwd_BLOCK_M2'.
Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCK_N2: NotRequired[int]
"""Thread block size for second K/V dimension in backward pass. Use as 'bwd_BLOCK_N2'.
Default is determined by autotuning."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
PRESCALE_QK: NotRequired[bool]
"""Whether to pre-scale QK by 1/sqrt(d) and change of base. This is slightly faster but
may have more numerical error. Default: False."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
ROWS_GUARANTEED_SAFE: NotRequired[bool]
"""If True, guarantees that at least one value in each row is not masked out.
Allows skipping safety checks for better performance. Only set this if you are certain
your mask guarantees this property. For example, causal attention is guaranteed safe
because each query has at least 1 key-value to attend to. Default: False."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
BLOCKS_ARE_CONTIGUOUS: NotRequired[bool]
"""If True, guarantees that all blocks in the mask are contiguous.
Allows optimizing block traversal. For example, causal masks would satisfy this,
but prefix_lm + sliding window would not. Default: False."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
WRITE_DQ: NotRequired[bool]
"""Controls whether gradient scatters are done in the DQ iteration loop of the backward pass.
Setting this to False will force this to happen in the DK loop which depending on your
specific score_mod and mask_mod might be faster. Default: True."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
FORCE_USE_FLEX_ATTENTION: NotRequired[bool]
"""If True, forces the use of the flex attention kernel instead of potentially using
the more optimized flex-decoding kernel for short sequences. This can be a helpful
option for debugging. Default: False."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
USE_TMA: NotRequired[bool]
"""Whether to use Tensor Memory Accelerator (TMA) on supported hardware.
This is experimental and may not work on all hardware, currently specific
to NVIDIA GPUs Hopper+. Default: False."""

# ROCm-specific options
-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
kpack: NotRequired[int]
"""ROCm-specific kernel packing parameter."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
matrix_instr_nonkdim: NotRequired[int]
"""ROCm-specific matrix instruction non-K dimension."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
waves_per_eu: NotRequired[int]
"""ROCm-specific waves per execution unit."""

-# pyrefly: ignore # invalid-annotation
+# pyrefly: ignore [invalid-annotation]
force_flash: NotRequired[bool]
""" If True, forces use of the cute-dsl flash attention kernel.
@@ -644,7 +644,7 @@ class BlockMask:
block_size = (self.BLOCK_SIZE,) # type: ignore[assignment]
seq_lengths = (self.seq_lengths,) # type: ignore[assignment]

-# pyrefly: ignore # not-iterable
+# pyrefly: ignore [not-iterable]
return (
*seq_lengths,
self.kv_num_blocks,

@@ -817,7 +817,7 @@ class BlockMask:
partial_dense = _ordered_to_dense(self.kv_num_blocks, self.kv_indices)
if self.full_kv_num_blocks is not None:
assert self.full_kv_indices is not None
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return partial_dense | _ordered_to_dense(
self.full_kv_num_blocks, self.full_kv_indices
)
@@ -78,7 +78,7 @@ class ModuleWrapper(nn.Module):

# nn.Module defines training as a boolean
@property # type: ignore[override]
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def training(self):
return self.cpp_module.training
@@ -1275,7 +1275,7 @@ def adaptive_max_pool2d_with_indices(
output_size,
return_indices=return_indices,
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool2d(input, output_size)

@@ -1333,7 +1333,7 @@ def adaptive_max_pool3d_with_indices(
output_size,
return_indices=return_indices,
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_max_pool3d(input, output_size)

@@ -1392,7 +1392,7 @@ def adaptive_avg_pool2d(input: Tensor, output_size: BroadcastingList2[int]) -> T
"""
if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool2d, (input,), input, output_size)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool2d(input, _output_size)

@@ -1408,7 +1408,7 @@ def adaptive_avg_pool3d(input: Tensor, output_size: BroadcastingList3[int]) -> T
"""
if has_torch_function_unary(input):
return handle_torch_function(adaptive_avg_pool3d, (input,), input, output_size)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_output_size = _list_with_default(output_size, input.size())
return torch._C._nn.adaptive_avg_pool3d(input, _output_size)

@@ -2444,7 +2444,7 @@ def _no_grad_embedding_renorm_(
input: Tensor,
max_norm: float,
norm_type: float,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> tuple[Tensor, Tensor]:
torch.embedding_renorm_(weight.detach(), input, max_norm, norm_type)
@@ -2698,7 +2698,7 @@ def embedding_bag(

if not torch.jit.is_scripting() and input.dim() == 2 and input.is_nested:
include_last_offset = True
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
offsets = input.offsets()
input = input.values().reshape(-1)
if per_sample_weights is not None:

@@ -2833,7 +2833,7 @@ def batch_norm(
eps=eps,
)
if training:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_verify_batch_size(input.size())

return torch.batch_norm(

@@ -2889,7 +2889,7 @@ def instance_norm(
eps=eps,
)
if use_input_stats:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_verify_spatial_size(input.size())
return torch.instance_norm(
input,

@@ -3015,13 +3015,13 @@ def local_response_norm(
div = input.mul(input)
if dim == 3:
div = div.unsqueeze(1)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
div = pad(div, (0, 0, size // 2, (size - 1) // 2))
div = avg_pool2d(div, (size, 1), stride=1).squeeze(1)
else:
sizes = input.size()
div = div.view(sizes[0], 1, sizes[1], sizes[2], -1)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
div = pad(div, (0, 0, 0, 0, size // 2, (size - 1) // 2))
div = avg_pool3d(div, (size, 1, 1), stride=1).squeeze(1)
div = div.view(sizes)

@@ -3173,7 +3173,7 @@ def nll_loss(
input,
target,
weight,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
ignore_index,
)
@@ -3320,7 +3320,7 @@ def gaussian_nll_loss(
var.clamp_(min=eps)

# Calculate the loss
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
loss = 0.5 * (torch.log(var) + (input - target) ** 2 / var)
if full:
loss += 0.5 * math.log(2 * math.pi)

@@ -3496,7 +3496,7 @@ def cross_entropy(
input,
target,
weight,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
ignore_index,
label_smoothing,

@@ -3561,7 +3561,7 @@ def binary_cross_entropy(
new_size = _infer_size(target.size(), weight.size())
weight = weight.expand(new_size)

-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._nn.binary_cross_entropy(input, target, weight, reduction_enum)

@@ -3692,14 +3692,14 @@ def smooth_l1_loss(
return torch._C._nn.l1_loss(
expanded_input,
expanded_target,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
)
else:
return torch._C._nn.smooth_l1_loss(
expanded_input,
expanded_target,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
beta,
)

@@ -3761,7 +3761,7 @@ def huber_loss(
return torch._C._nn.huber_loss(
expanded_input,
expanded_target,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
delta,
)

@@ -3773,7 +3773,7 @@ def huber_loss(
unweighted_loss = torch._C._nn.huber_loss(
expanded_input,
expanded_target,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum("none"),
delta,
)
@@ -3864,7 +3864,7 @@ def l1_loss(
return torch._C._nn.l1_loss(
expanded_input,
expanded_target,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
)

@@ -3942,7 +3942,7 @@ def mse_loss(
return torch._C._nn.mse_loss(
expanded_input,
expanded_target,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
_Reduction.get_enum(reduction),
)

@@ -4080,7 +4080,7 @@ def multilabel_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
reduction_enum = _Reduction.get_enum(reduction)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._nn.multilabel_margin_loss(input, target, reduction_enum)

@@ -4122,7 +4122,7 @@ def soft_margin_loss(
reduction_enum = _Reduction.legacy_get_enum(size_average, reduce)
else:
reduction_enum = _Reduction.get_enum(reduction)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._nn.soft_margin_loss(input, target, reduction_enum)

@@ -4292,7 +4292,7 @@ def multi_margin_loss(
p,
margin,
weight,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
reduction_enum,
)
@@ -4439,7 +4439,7 @@ def upsample( # noqa: F811
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass

@@ -4451,7 +4451,7 @@ def upsample( # noqa: F811
scale_factor: Optional[float] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass

@@ -4554,7 +4554,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass

@@ -4568,7 +4568,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass

@@ -4582,7 +4582,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor: # noqa: B950
pass

@@ -4596,7 +4596,7 @@ def interpolate( # noqa: F811
align_corners: Optional[bool] = None,
recompute_scale_factor: Optional[bool] = None,
antialias: bool = False,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor:
pass
@@ -4771,7 +4771,7 @@ def interpolate( # noqa: F811
(
torch.floor(
(
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
input.size(i + 2).float()
* torch.tensor(scale_factors[i], dtype=torch.float32)
).float()

@@ -4796,28 +4796,28 @@ def interpolate( # noqa: F811
)

if input.dim() == 3 and mode == "nearest":
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._nn.upsample_nearest1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest":
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._nn.upsample_nearest2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest":
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._nn.upsample_nearest3d(input, output_size, scale_factors)

if input.dim() == 3 and mode == "nearest-exact":
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._nn._upsample_nearest_exact1d(input, output_size, scale_factors)
if input.dim() == 4 and mode == "nearest-exact":
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._nn._upsample_nearest_exact2d(input, output_size, scale_factors)
if input.dim() == 5 and mode == "nearest-exact":
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._nn._upsample_nearest_exact3d(input, output_size, scale_factors)

if input.dim() == 3 and mode == "area":
assert output_size is not None
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return adaptive_avg_pool1d(input, output_size)
if input.dim() == 4 and mode == "area":
assert output_size is not None

@@ -4830,7 +4830,7 @@ def interpolate( # noqa: F811
assert align_corners is not None
return torch._C._nn.upsample_linear1d(
input,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,
@@ -4840,7 +4840,7 @@ def interpolate( # noqa: F811
if antialias:
return torch._C._nn._upsample_bilinear2d_aa(
input,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,

@@ -4857,7 +4857,7 @@ def interpolate( # noqa: F811
)._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_bilinear2d(
input,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,

@@ -4876,7 +4876,7 @@ def interpolate( # noqa: F811
)._upsample_linear_vec(input, output_size, align_corners, scale_factors)
return torch._C._nn.upsample_trilinear3d(
input,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,

@@ -4886,14 +4886,14 @@ def interpolate( # noqa: F811
if antialias:
return torch._C._nn._upsample_bicubic2d_aa(
input,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,
)
return torch._C._nn.upsample_bicubic2d(
input,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
output_size,
align_corners,
scale_factors,
@@ -4928,7 +4928,7 @@ def upsample_nearest( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[float] = None,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor:
pass

@@ -4938,7 +4938,7 @@ def upsample_nearest( # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor:
pass

@@ -4980,7 +4980,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[float] = None,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor:
pass

@@ -4990,7 +4990,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[float] = None,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor:
pass

@@ -5000,7 +5000,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[int] = None,
scale_factor: Optional[list[float]] = None,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor:
pass

@@ -5010,7 +5010,7 @@ def upsample_bilinear( # noqa: F811
input: Tensor,
size: Optional[list[int]] = None,
scale_factor: Optional[list[float]] = None,
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
) -> Tensor:
pass
@@ -5817,7 +5817,7 @@ def _in_projection_packed(
.squeeze(-2)
.contiguous()
)
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return proj[0], proj[1], proj[2]
else:
# encoder-decoder attention

@@ -5836,7 +5836,7 @@ def _in_projection_packed(
.squeeze(-2)
.contiguous()
)
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return (q_proj, kv_proj[0], kv_proj[1])
else:
w_q, w_k, w_v = w.chunk(3)

@@ -5844,7 +5844,7 @@ def _in_projection_packed(
b_q = b_k = b_v = None
else:
b_q, b_k, b_v = b.chunk(3)
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

@@ -6475,10 +6475,10 @@ def multi_head_attention_forward(
k = torch.cat([k, bias_k.repeat(1, bsz, 1)])
v = torch.cat([v, bias_v.repeat(1, bsz, 1)])
if attn_mask is not None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
key_padding_mask = pad(key_padding_mask, (0, 1))
else:
assert bias_k is None
@@ -6487,10 +6487,10 @@ def multi_head_attention_forward(
#
# reshape q, k, v for multihead attention and make them batch first
#
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
q = q.view(tgt_len, bsz * num_heads, head_dim).transpose(0, 1)
if static_k is None:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed

@@ -6502,7 +6502,7 @@ def multi_head_attention_forward(
)
k = static_k
if static_v is None:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed

@@ -6518,20 +6518,20 @@ def multi_head_attention_forward(
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat(
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)],
dim=1,
)
v = torch.cat(
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)],
dim=1,
)
if attn_mask is not None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
key_padding_mask = pad(key_padding_mask, (0, 1))

# update source sequence length after adjustments
@@ -6581,7 +6581,7 @@ def multi_head_attention_forward(
attn_output = torch.bmm(attn_output_weights, v)

attn_output = (
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)

@@ -6608,16 +6608,16 @@ def multi_head_attention_forward(
attn_mask = attn_mask.view(bsz, num_heads, -1, src_len)

q = q.view(bsz, num_heads, tgt_len, head_dim)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
k = k.view(bsz, num_heads, src_len, head_dim)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
v = v.view(bsz, num_heads, src_len, head_dim)

attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
@@ -500,7 +500,7 @@ def xavier_normal_(


def _calculate_correct_fan(tensor: Tensor, mode: _FanMode) -> int:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
mode = mode.lower()
valid_modes = ["fan_in", "fan_out"]
if mode not in valid_modes:
@@ -6,7 +6,7 @@ from torch.autograd.function import Function

class SyncBatchNorm(Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(
self,
input,

@@ -211,7 +211,7 @@ class SyncBatchNorm(Function):

class CrossMapLRN2d(Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, input, size, alpha=1e-4, beta=0.75, k=1):
ctx.size = size
ctx.alpha = alpha

@@ -267,7 +267,7 @@ class CrossMapLRN2d(Function):
return output

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
input, output = ctx.saved_tensors
grad_input = grad_output.new()

@@ -309,7 +309,7 @@ class CrossMapLRN2d(Function):

class BackwardHookFunction(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, *args):
ctx.mark_non_differentiable(*[arg for arg in args if not arg.requires_grad])
return args
@@ -109,7 +109,7 @@ class Sequential(Module):
def __init__(self, *args: Module) -> None: ...

@overload
-# pyrefly: ignore # inconsistent-overload
+# pyrefly: ignore [inconsistent-overload]
def __init__(self, arg: OrderedDict[str, Module]) -> None: ...

def __init__(self, *args):

@@ -624,11 +624,11 @@ class ModuleDict(Module):
"ModuleDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(m).__name__
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
if not len(m) == 2:
raise ValueError(
"ModuleDict update sequence element "
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
"#" + str(j) + " has length " + str(len(m)) + "; 2 is required"
)
# modules can be Mapping (what it's typed at), or a list: [(name1, module1), (name2, module2)]

@@ -687,7 +687,7 @@ class ParameterList(Module):
def __getitem__(self, idx: int) -> Any: ...

@overload
-# pyrefly: ignore # inconsistent-overload
+# pyrefly: ignore [inconsistent-overload]
def __getitem__(self: T, idx: slice) -> T: ...

def __getitem__(self, idx):

@@ -773,11 +773,11 @@ class ParameterList(Module):
size_str,
device_str,
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
child_lines.append(" (" + str(k) + "): " + parastr)
else:
child_lines.append(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
" (" + str(k) + "): Object of type: " + type(p).__name__
)

@@ -985,11 +985,11 @@ class ParameterDict(Module):
"ParameterDict update sequence element "
"#" + str(j) + " should be Iterable; is" + type(p).__name__
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
if not len(p) == 2:
raise ValueError(
"ParameterDict update sequence element "
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
"#" + str(j) + " has length " + str(len(p)) + "; 2 is required"
)
# parameters as length-2 list too cumbersome to type, see ModuleDict.update comment

@@ -1010,11 +1010,11 @@ class ParameterDict(Module):
size_str,
device_str,
)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
child_lines.append(" (" + str(k) + "): " + parastr)
else:
child_lines.append(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
" (" + str(k) + "): Object of type: " + type(p).__name__
)
tmpstr = "\n".join(child_lines)
@@ -1514,7 +1514,7 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,

@@ -1529,11 +1529,11 @@ class LazyConv1d(_LazyConvXdMixin, Conv1d): # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@@ -1586,7 +1586,7 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,

@@ -1601,11 +1601,11 @@ class LazyConv2d(_LazyConvXdMixin, Conv2d): # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@@ -1659,7 +1659,7 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,

@@ -1674,11 +1674,11 @@ class LazyConv3d(_LazyConvXdMixin, Conv3d): # type: ignore[misc]
padding_mode,
**factory_kwargs,
)
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@@ -1730,7 +1730,7 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,

@@ -1746,11 +1746,11 @@ class LazyConvTranspose1d(_LazyConvXdMixin, ConvTranspose1d): # type: ignore[mi
padding_mode,
**factory_kwargs,
)
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@@ -1802,7 +1802,7 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,

@@ -1818,11 +1818,11 @@ class LazyConvTranspose2d(_LazyConvXdMixin, ConvTranspose2d): # type: ignore[mi
padding_mode,
**factory_kwargs,
)
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:

@@ -1874,7 +1874,7 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(
0,
0,

@@ -1890,11 +1890,11 @@ class LazyConvTranspose3d(_LazyConvXdMixin, ConvTranspose3d): # type: ignore[mi
padding_mode,
**factory_kwargs,
)
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_channels = out_channels
if bias:
-# pyrefly: ignore # bad-override, bad-argument-type
+# pyrefly: ignore [bad-override, bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)

def _get_num_spatial_dims(self) -> int:
@@ -172,9 +172,9 @@ class LazyModuleMixin:
def __init__(self: _LazyProtocol, *args, **kwargs):
# Mypy doesn't like this super call in a mixin
super().__init__(*args, **kwargs) # type: ignore[misc]
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self._load_hook = self._register_load_state_dict_pre_hook(self._lazy_load_hook)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self._initialize_hook = self.register_forward_pre_hook(
self._infer_parameters, with_kwargs=True
)
@@ -286,7 +286,7 @@ class LazyLinear(LazyModuleMixin, Linear):
"""

cls_to_become = Linear # type: ignore[assignment]
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
weight: UninitializedParameter
bias: UninitializedParameter # type: ignore[assignment]

@@ -296,20 +296,20 @@ class LazyLinear(LazyModuleMixin, Linear):
factory_kwargs = {"device": device, "dtype": dtype}
# bias is hardcoded to False to avoid creating tensor
# that will soon be overwritten.
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(0, 0, False)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
self.out_features = out_features
if bias:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)

def reset_parameters(self) -> None:
"""
Resets parameters based on their initialization used in ``__init__``.
"""
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
if not self.has_uninitialized_params() and self.in_features != 0:
super().reset_parameters()

@@ -317,7 +317,7 @@ class LazyLinear(LazyModuleMixin, Linear):
"""
Infers ``in_features`` based on ``input`` and initializes parameters.
"""
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
if self.has_uninitialized_params():
with torch.no_grad():
self.in_features = input.shape[-1]
@@ -84,7 +84,7 @@ class CircularPad1d(_CircularPadNd):
[5., 6., 7., 4., 5., 6., 7., 4.]]])
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int]

def __init__(self, padding: _size_2_t) -> None:

@@ -145,7 +145,7 @@ class CircularPad2d(_CircularPadNd):
[8., 6., 7., 8., 6.]]]])
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int]

def __init__(self, padding: _size_4_t) -> None:

@@ -196,7 +196,7 @@ class CircularPad3d(_CircularPadNd):
>>> output = m(input)
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int, int, int]

def __init__(self, padding: _size_6_t) -> None:

@@ -268,7 +268,7 @@ class ConstantPad1d(_ConstantPadNd):
[ 3.5000, 3.5000, 3.5000, -3.6372, 0.1182, -1.8652, 3.5000]]])
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int]

def __init__(self, padding: _size_2_t, value: float) -> None:

@@ -320,7 +320,7 @@ class ConstantPad2d(_ConstantPadNd):
"""

__constants__ = ["padding", "value"]
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int]

def __init__(self, padding: _size_4_t, value: float) -> None:

@@ -361,7 +361,7 @@ class ConstantPad3d(_ConstantPadNd):
>>> output = m(input)
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int, int, int]

def __init__(self, padding: _size_6_t, value: float) -> None:

@@ -415,7 +415,7 @@ class ReflectionPad1d(_ReflectionPadNd):
[7., 6., 5., 4., 5., 6., 7., 6.]]])
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int]

def __init__(self, padding: _size_2_t) -> None:

@@ -469,7 +469,7 @@ class ReflectionPad2d(_ReflectionPadNd):
[7., 6., 7., 8., 7.]]]])
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int]

def __init__(self, padding: _size_4_t) -> None:

@@ -525,7 +525,7 @@ class ReflectionPad3d(_ReflectionPadNd):
[1., 0., 1., 0.]]]]])
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int, int, int]

def __init__(self, padding: _size_6_t) -> None:

@@ -579,7 +579,7 @@ class ReplicationPad1d(_ReplicationPadNd):
[4., 4., 4., 4., 5., 6., 7., 7.]]])
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int]

def __init__(self, padding: _size_2_t) -> None:

@@ -633,7 +633,7 @@ class ReplicationPad2d(_ReplicationPadNd):
[6., 6., 7., 8., 8.]]]])
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int]

def __init__(self, padding: _size_4_t) -> None:

@@ -676,7 +676,7 @@ class ReplicationPad3d(_ReplicationPadNd):
>>> output = m(input)
"""

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
padding: tuple[int, int, int, int, int, int]

def __init__(self, padding: _size_6_t) -> None:
@@ -138,7 +138,7 @@ class Transformer(Module):
d_model,
eps=layer_norm_eps,
bias=bias,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
**factory_kwargs,
)
self.encoder = TransformerEncoder(

@@ -164,7 +164,7 @@ class Transformer(Module):
d_model,
eps=layer_norm_eps,
bias=bias,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
**factory_kwargs,
)
self.decoder = TransformerDecoder(

@@ -768,9 +768,9 @@ class TransformerEncoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

self.norm_first = norm_first
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)

@@ -1062,11 +1062,11 @@ class TransformerDecoderLayer(Module):
self.linear2 = Linear(dim_feedforward, d_model, bias=bias, **factory_kwargs)

self.norm_first = norm_first
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
self.norm1 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
self.norm2 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
self.norm3 = LayerNorm(d_model, eps=layer_norm_eps, bias=bias, **factory_kwargs)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
@@ -36,7 +36,7 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
import torch

if isinstance(out_size, (int, torch.SymInt)):
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return out_size
if len(defaults) <= len(out_size):
raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")
@@ -43,7 +43,7 @@ def broadcast(tensor, devices=None, *, out=None):
devices = [_get_device_index(d) for d in devices]
return torch._C._broadcast(tensor, devices)
else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
return torch._C._broadcast_out(tensor, out)

@@ -201,7 +201,7 @@ def scatter(tensor, devices=None, chunk_sizes=None, dim=0, streams=None, *, out=
"""
tensor = _handle_complex(tensor)
if out is None:
-# pyrefly: ignore # not-iterable
+# pyrefly: ignore [not-iterable]
devices = [_get_device_index(d) for d in devices]
return tuple(torch._C._scatter(tensor, devices, chunk_sizes, dim, streams))
else:
@@ -160,7 +160,7 @@ class DataParallel(Module, Generic[T]):
self.module = module
self.device_ids = [_get_device_index(x, True) for x in device_ids]
self.output_device = _get_device_index(output_device, True)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.src_device_obj = torch.device(device_type, self.device_ids[0])

if device_type == "cuda":

@@ -174,7 +174,7 @@ class DataParallel(Module, Generic[T]):
if not self.device_ids:
return self.module(*inputs, **kwargs)

-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
for t in chain(self.module.parameters(), self.module.buffers()):
if t.device != self.src_device_obj:
raise RuntimeError(

@@ -261,10 +261,10 @@ def data_parallel(

device_ids = [_get_device_index(x, True) for x in device_ids]
output_device = _get_device_index(output_device, True)
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
src_device_obj = torch.device(device_type, device_ids[0])

-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
for t in chain(module.parameters(), module.buffers()):
if t.device != src_device_obj:
raise RuntimeError(
@@ -241,7 +241,7 @@ class _BufferCommHook:
# is completed.
class _DDPSink(Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, ddp_weakref, *inputs):
# set_materialize_grads(False) will ensure that None gradients stay as
# None and are not filled with zeros.

@@ -692,7 +692,7 @@ class DistributedDataParallel(Module, Joinable):
elif process_group is None and device_mesh is None:
self.process_group = _get_default_group()
elif device_mesh is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.process_group = process_group
else:
if device_mesh.ndim != 1:

@@ -780,13 +780,13 @@ class DistributedDataParallel(Module, Joinable):
self.device_ids = None
self.output_device = None
else:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.device_ids = [_get_device_index(x, True) for x in device_ids]

if output_device is None:
output_device = device_ids[0]

-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
self.output_device = _get_device_index(output_device, True)

self.static_graph = False

@@ -936,7 +936,7 @@ class DistributedDataParallel(Module, Joinable):
# enabled.
self._accum_grad_hooks: list[RemovableHandle] = []
if self._use_python_reducer:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
torch._inductor.config._fuse_ddp_communication = True
torch._inductor.config._fuse_ddp_bucket_size = bucket_cap_mb
# Directly adding this to the trace rule will disturb the users
@@ -56,16 +56,16 @@ def scatter(inputs, target_gpus, dim=0):
if isinstance(obj, torch.Tensor):
return Scatter.apply(target_gpus, None, dim, obj)
if _is_namedtuple(obj):
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
if isinstance(obj, tuple) and len(obj) > 0:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return list(zip(*map(scatter_map, obj)))
if isinstance(obj, list) and len(obj) > 0:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return [list(i) for i in zip(*map(scatter_map, obj))]
if isinstance(obj, dict) and len(obj) > 0:
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
return [obj for _ in target_gpus]

@@ -127,12 +127,12 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0)
if isinstance(out, dict):
if not all(len(out) == len(d) for d in outputs):
raise ValueError("All dicts must have the same number of keys")
-# pyrefly: ignore # not-callable
+# pyrefly: ignore [not-callable]
return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
if _is_namedtuple(out):
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return type(out)._make(map(gather_map, zip(*outputs)))
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
return type(out)(map(gather_map, zip(*outputs)))

# Recursive function calls like this create reference cycles.
@@ -81,7 +81,7 @@ class Parameter(torch.Tensor, metaclass=_ParameterMeta):
memo[id(self)] = result
return result

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def __repr__(self):
return "Parameter containing:\n" + super().__repr__()

@@ -144,7 +144,7 @@ class UninitializedTensorMixin:
if dtype is None:
dtype = self.data.dtype
self.data = torch.empty(shape, device=device, dtype=dtype)
-# pyrefly: ignore # bad-override, missing-attribute
+# pyrefly: ignore [bad-override, missing-attribute]
self.__class__ = self.cls_to_become

@property

@@ -168,7 +168,7 @@ class UninitializedTensorMixin:

def __reduce_ex__(self, proto):
# See Note [Don't serialize hooks]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
return (self.__class__, (self.requires_grad,))

@classmethod

@@ -178,7 +178,7 @@ class UninitializedTensorMixin:
if func in cls._allowed_methods or func.__class__.__name__ == "method-wrapper":
if kwargs is None:
kwargs = {}
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
return super().__torch_function__(func, types, args, kwargs)
raise ValueError(
f"Attempted to use an uninitialized parameter in {func}. "

@@ -220,7 +220,7 @@ class UninitializedParameter(UninitializedTensorMixin, Parameter):
def __new__(cls, requires_grad=True, device=None, dtype=None) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return torch.Tensor._make_subclass(cls, data, requires_grad)

def __deepcopy__(self, memo):

@@ -266,9 +266,9 @@ class Buffer(torch.Tensor, metaclass=_BufferMeta):
data = torch.empty(0)

t = data.detach().requires_grad_(data.requires_grad)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
t.persistent = persistent
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
t._is_buffer = True
return t

@@ -299,9 +299,9 @@ class UninitializedBuffer(UninitializedTensorMixin, torch.Tensor):
factory_kwargs = {"device": device, "dtype": dtype}
data = torch.empty(0, **factory_kwargs)
ret = torch.Tensor._make_subclass(cls, data, requires_grad)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
ret.persistent = persistent
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
ret._is_buffer = True
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
return ret
@@ -24,7 +24,7 @@ from .expanded_weights_utils import forward_helper
@implements_per_sample_grads(F.conv3d)
class ConvPerSampleGrad(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(
ctx: Any,
kwarg_names: list[str],

@@ -57,7 +57,7 @@ class ConvPerSampleGrad(torch.autograd.Function):
f"unbatched input of dim {input.dim()}, expected input of dim {batched_dim_size}"
)

-# pyrefly: ignore # invalid-type-var
+# pyrefly: ignore [invalid-type-var]
ctx.conv_fn = conv_fn

ctx.batch_size = orig_input.shape[0]
@@ -237,7 +237,7 @@ def conv_unfold_weight_grad_sample(
# n=batch_sz; o=num_out_channels; p=(num_in_channels/groups)*kernel_sz
weight_grad_sample = torch.einsum("noq,npq->nop", grad_output, input)
# rearrange the above tensor and extract diagonals.
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
weight_grad_sample = weight_grad_sample.view(
n,
groups,
@@ -14,7 +14,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.embedding)
class EmbeddingPerSampleGrad(torch.autograd.Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(
ctx: Any, kwarg_names: list[str], _: Any, *expanded_args_and_kwargs: Any
) -> torch.Tensor:

@@ -35,7 +35,7 @@ class EmbeddingPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(
ctx: Any, grad_output: torch.Tensor
) -> tuple[Optional[torch.Tensor], ...]:
@ -131,7 +131,7 @@ class ExpandedWeight(torch.Tensor):
|
|||
# in aten, choosing the input or data variants is done by parsing logic. This mimics some of that
|
||||
decomp_opts = expanded_weights_rnn_decomps[func]
|
||||
use_input_variant = isinstance(
|
||||
# pyrefly: ignore # index-error
|
||||
# pyrefly: ignore [index-error]
|
||||
args[2],
|
||||
list,
|
||||
) # data variant uses a list here
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from .expanded_weights_impl import ExpandedWeight

def is_batch_first(expanded_args_and_kwargs):
batch_first = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for arg in expanded_args_and_kwargs:
if not isinstance(arg, ExpandedWeight):
continue

@ -18,7 +18,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.group_norm)
class GroupNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs

@ -47,7 +47,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
input, num_groups = ctx.input, ctx.num_groups
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps

@ -97,7 +97,7 @@ class GroupNormPerSampleGrad(torch.autograd.Function):
weight,
lambda _: torch.einsum(
"ni...->ni",
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
F.group_norm(input, num_groups, eps=eps) * grad_output,
),
)

@ -17,7 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.instance_norm)
class InstanceNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
instance_norm = partial(torch.instance_norm, cudnn_enabled=True)
expanded_args, expanded_kwargs = standard_kwargs(

@ -37,7 +37,7 @@ class InstanceNormPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
input, running_mean, running_var = ctx.input, ctx.running_mean, ctx.running_var
weight, bias, eps = ctx.weight, ctx.bias, ctx.eps

@ -17,7 +17,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.layer_norm)
class LayerNormPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, kwarg_names, _, *expanded_args_and_kwargs):
expanded_args, expanded_kwargs = standard_kwargs(
kwarg_names, expanded_args_and_kwargs

@ -43,7 +43,7 @@ class LayerNormPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
def weight_per_sample_grad(weight):
return sum_over_all_but_batch_and_last_n(
@ -16,7 +16,7 @@ from .expanded_weights_utils import (
@implements_per_sample_grads(F.linear)
class LinearPerSampleGrad(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, _, __, *expanded_args_and_kwargs):
if len(expanded_args_and_kwargs[0].shape) <= 1:
raise RuntimeError(

@ -36,7 +36,7 @@ class LinearPerSampleGrad(torch.autograd.Function):
return output

@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
input, weight = ctx.args
bias = ctx.kwargs["bias"]

@ -77,7 +77,7 @@ def swap_tensor(
setattr(module, name, tensor)
elif hasattr(module, name):
delattr(module, name)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return orig_tensor

@ -41,11 +41,11 @@ def _no_grad(func: Callable[_P, _R]) -> Callable[_P, _R]:

def _no_grad_wrapper(*args, **kwargs):
with torch.no_grad():
# pyrefly: ignore # invalid-param-spec
# pyrefly: ignore [invalid-param-spec]
return func(*args, **kwargs)

functools.update_wrapper(_no_grad_wrapper, func)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return _no_grad_wrapper

@ -283,7 +283,7 @@ def clip_grad_value_(
clip_value = float(clip_value)

grads = [p.grad for p in parameters if p.grad is not None]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
grouped_grads = _group_tensors_by_device_and_dtype([grads])

for (device, _), ([grads], _) in grouped_grads.items():

@ -84,7 +84,7 @@ def convert_conv2d_weight_memory_format(
)
for child in module.children():
convert_conv2d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return module

@ -164,7 +164,7 @@ def convert_conv3d_weight_memory_format(
)
for child in module.children():
convert_conv3d_weight_memory_format(child, memory_format)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return module
@ -98,7 +98,7 @@ class _Orthogonal(Module):
)
# Q is now orthogonal (or unitary) of size (..., n, n)
if n != k:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
Q = Q[..., :k]
# Q is now the size of the X (albeit perhaps transposed)
else:

@ -111,10 +111,10 @@ class _Orthogonal(Module):
Q = Q * X.diagonal(dim1=-2, dim2=-1).int().unsqueeze(-2)

if hasattr(self, "base"):
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
Q = self.base @ Q
if transposed:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
Q = Q.mT
return Q # type: ignore[possibly-undefined]

@ -179,28 +179,28 @@ class ParametrizationList(ModuleList):

# Register the tensor(s)
if self.is_tensor:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if original.dtype != new.dtype:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the dtype.\n"
f"original.dtype: {original.dtype}\n"
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
f"right_inverse(original).dtype: {new.dtype}"
)

# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if original.device != new.device:
raise ValueError(
"When `right_inverse` outputs one tensor, it may not change the device.\n"
f"original.device: {original.device}\n"
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
f"right_inverse(original).device: {new.device}"
)

# Set the original to original so that the user does not need to re-register the parameter
# manually in the optimiser
with torch.no_grad():
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_maybe_set(original, new)
_register_parameter_or_buffer(self, "original", original)
else:

@ -401,7 +401,7 @@ def _inject_property(module: Module, tensor_name: str) -> None:
if torch.jit.is_scripting():
raise RuntimeError("Parametrization is not working with scripting.")
parametrization = self.parametrizations[tensor_name]
# pyrefly: ignore # redundant-condition
# pyrefly: ignore [redundant-condition]
if _cache_enabled:
if torch.jit.is_scripting():
# Scripting

@ -701,7 +701,7 @@ def remove_parametrizations(
# Fetch the original tensor
assert isinstance(module.parametrizations, ModuleDict) # Make mypy happy
parametrizations = module.parametrizations[tensor_name]
# pyrefly: ignore # invalid-argument
# pyrefly: ignore [invalid-argument]
if parametrizations.is_tensor:
original = parametrizations.original
assert isinstance(original, torch.Tensor), "is_tensor promised us a Tensor"
@ -274,11 +274,11 @@ class PruningContainer(BasePruningMethod):
if not isinstance(args, Iterable): # only 1 item
self._tensor_name = args._tensor_name
self.add_pruning_method(args)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
elif len(args) == 1: # only 1 item in a tuple
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
self._tensor_name = args[0]._tensor_name
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
self.add_pruning_method(args[0])
else: # manual construction from list or other iterable (or no args)
for method in args:

@ -1100,7 +1100,7 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw

# flatten importance scores to consider them all at once in global pruning
relevant_importance_scores = torch.nn.utils.parameters_to_vector(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
[
importance_scores.get((module, name), getattr(module, name))
for (module, name) in parameters

@ -332,7 +332,7 @@ def spectral_norm(
else:
dim = 0
SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return module

@ -117,7 +117,7 @@ def context_decorator(ctx, func):

@functools.wraps(func)
def decorate_context(*args, **kwargs):
# pyrefly: ignore # bad-context-manager
# pyrefly: ignore [bad-context-manager]
with ctx_factory():
return func(*args, **kwargs)

@ -1074,7 +1074,7 @@ def key_get(obj: Any, kp: KeyPath) -> Any:

with python_pytree._NODE_REGISTRY_LOCK:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
python_pytree._cxx_pytree_imported = True
args, kwargs = (), {} # type: ignore[var-annotated]
for args, kwargs in python_pytree._cxx_pytree_pending_imports:

@ -262,7 +262,7 @@ class DebugMode(TorchDispatchMode):
self.module_tracker.__enter__() # type: ignore[attribute, union-attr]
return self

# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def __exit__(self, *args):
super().__exit__(*args)
if self.record_nn_module:
@ -60,7 +60,7 @@ def _device_constructors():
# NB: This is directly called from C++ in torch/csrc/Device.cpp
class DeviceContext(TorchFunctionMode):
def __init__(self, device):
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.device = torch.device(device)

def __enter__(self):

@ -35,12 +35,12 @@ def cache_method(
if not (cache := getattr(self, cache_name, None)):
cache = {}
setattr(self, cache_name, cache)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
cached_value = cache.get(args, _cache_sentinel)
if cached_value is not _cache_sentinel:
return cached_value
value = f(self, *args, **kwargs)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
cache[args] = value
return value

@ -708,7 +708,7 @@ class structseq(tuple[_T_co, ...]):
def __new__(
cls: type[Self],
sequence: Iterable[_T_co],
# pyrefly: ignore # bad-function-definition
# pyrefly: ignore [bad-function-definition]
dict: dict[str, Any] = ...,
) -> Self:
raise NotImplementedError

@ -755,7 +755,7 @@ def _tuple_flatten_with_keys(
d: tuple[T, ...],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _tuple_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(SequenceKey(i), v) for i, v in enumerate(values)], context

@ -769,7 +769,7 @@ def _list_flatten(d: list[T]) -> tuple[list[T], Context]:

def _list_flatten_with_keys(d: list[T]) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _list_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(SequenceKey(i), v) for i, v in enumerate(values)], context

@ -785,7 +785,7 @@ def _dict_flatten_with_keys(
d: dict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _dict_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(MappingKey(k), v) for k, v in zip(context, values)], context

@ -801,7 +801,7 @@ def _namedtuple_flatten_with_keys(
d: NamedTuple,
) -> tuple[list[tuple[KeyEntry, Any]], Context]:
values, context = _namedtuple_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return (
[(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
context,

@ -851,7 +851,7 @@ def _ordereddict_flatten_with_keys(
d: OrderedDict[Any, T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _ordereddict_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(MappingKey(k), v) for k, v in zip(context, values)], context

@ -876,7 +876,7 @@ def _defaultdict_flatten_with_keys(
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _defaultdict_flatten(d)
_, dict_context = context
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context

@ -925,7 +925,7 @@ def _deque_flatten_with_keys(
d: deque[T],
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _deque_flatten(d)
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return [(SequenceKey(i), v) for i, v in enumerate(values)], context
@ -1827,7 +1827,7 @@ def enum_object_hook(obj: dict[str, Any]) -> Union[Enum, dict[str, Any]]:
for attr in classname.split("."):
enum_cls = getattr(enum_cls, attr)
enum_cls = cast(type[Enum], enum_cls)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return enum_cls[obj["name"]]
return obj

@ -305,7 +305,7 @@ def strobelight(
) -> Callable[_P, Optional[_R]]:
@functools.wraps(work_function)
def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> Optional[_R]:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return profiler.profile(work_function, *args, **kwargs)

return wrapper_function

@ -105,7 +105,7 @@ def _keep_float(
) -> Callable[[Unpack[_Ts]], Union[_T, sympy.Float]]:
@functools.wraps(f)
def inner(*args: Unpack[_Ts]) -> Union[_T, sympy.Float]:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
r: Union[_T, sympy.Float] = f(*args)
if any(isinstance(a, sympy.Float) for a in args) and not isinstance(
r, sympy.Float

@ -113,7 +113,7 @@ def _keep_float(
r = sympy.Float(float(r))
return r

# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return inner

@ -200,12 +200,12 @@ class FloorDiv(sympy.Function):

@property
def base(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.args[0]

@property
def divisor(self) -> sympy.Basic:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.args[1]

def _sympystr(self, printer: sympy.printing.StrPrinter) -> str:
@ -374,7 +374,7 @@ class ModularIndexing(sympy.Function):
return None

def _eval_is_nonnegative(self) -> Optional[bool]:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
p, q = self.args[:2]
return fuzzy_eq(p.is_nonnegative, q.is_nonnegative) # type: ignore[attr-defined]

@ -455,7 +455,7 @@ class PythonMod(sympy.Function):
# - floor(p / q) = 0
# - p % q = p - floor(p / q) * q = p
less = p < q
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if less.is_Boolean and bool(less) and r.is_positive:
return p

@ -472,11 +472,11 @@ class PythonMod(sympy.Function):
return True if self.args[1].is_negative else None # type: ignore[attr-defined]

def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
p = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
q = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
abs_q = str(q) if self.args[1].is_positive else f"abs({q})"
return f"({p} % {q}) < 0 ? {p} % {q} + {abs_q} : {p} % {q}"

@ -559,7 +559,7 @@ class CeilToInt(sympy.Function):
return sympy.Integer(math.ceil(float(number)))

def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
number = printer.parenthesize(self.args[0], self.args[0].precedence - 0.5)
return f"ceil({number})"

@ -830,7 +830,7 @@ class MinMaxBase(Expr, LatticeOp): # type: ignore[misc]
if not cond:
return ai.func(*[do(i, a) for i in ai.args], evaluate=False)
if isinstance(ai, cls):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return ai.func(*[do(i, a) for i in ai.args if i != a], evaluate=False)
return a

@ -1008,7 +1008,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_or(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]

def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return fuzzy_and(a.is_negative for a in self.args)

@ -1027,7 +1027,7 @@ class Min(MinMaxBase, Application): # type: ignore[misc]
return fuzzy_and(a.is_nonnegative for a in self.args) # type: ignore[attr-defined]

def _eval_is_negative(self): # type:ignore[override]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return fuzzy_or(a.is_negative for a in self.args)

@ -1165,9 +1165,9 @@ class IntTrueDiv(sympy.Function):
return sympy.Float(int(base) / int(divisor))

def _ccode(self, printer):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
base = printer.parenthesize(self.args[0], PRECEDENCE["Atom"] - 0.5)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
divisor = printer.parenthesize(self.args[1], PRECEDENCE["Atom"] - 0.5)
return f"((int){base}/(int){divisor})"
@ -1331,16 +1331,16 @@ class Identity(sympy.Function):
precedence = 10

def __repr__(self): # type: ignore[override]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return f"Identity({self.args[0]})"

def _sympystr(self, printer):
"""Controls how sympy's StrPrinter prints this"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return f"({printer.doprint(self.args[0])})"

def _eval_is_real(self):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.args[0].is_real

def _eval_is_integer(self):

@ -1348,15 +1348,15 @@ class Identity(sympy.Function):

def _eval_expand_identity(self, **hints):
# Removes the identity op.
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.args[0]

def __int__(self) -> int:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return int(self.args[0])

def __float__(self) -> float:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return float(self.args[0])

@ -9,7 +9,7 @@ from sympy.core.parameters import global_parameters
from sympy.core.singleton import S, Singleton

# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class IntInfinity(Number, metaclass=Singleton):
r"""Positive integer infinite quantity.

@ -204,7 +204,7 @@ class IntInfinity(Number, metaclass=Singleton):
int_oo = S.IntInfinity

# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class NegativeIntInfinity(Number, metaclass=Singleton):
"""Negative integer infinite quantity.
@ -66,7 +66,7 @@ class ExprPrinter(StrPrinter):
# NB: this pow by natural, you should never have used builtin sympy.pow
# for FloatPow, and a symbolic exponent should be PowByNatural. These
# means exp is guaranteed to be integer.
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _print_Pow(self, expr: sympy.Expr) -> str:
base, exp = expr.args
if exp != int(exp):

@ -176,7 +176,7 @@ class ReferenceAnalysis:

@staticmethod
def pow(a, b):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return _keep_float(FloatPow)(a, b)

@staticmethod

@ -126,9 +126,9 @@ AllFn2 = Union[ExprFn2, BoolFn2]
class ValueRanges(Generic[_T]):
if TYPE_CHECKING:
# ruff doesn't understand circular references but mypy does
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
ExprVR = ValueRanges[sympy.Expr] # noqa: F821
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
BoolVR = ValueRanges[SympyBoolean] # noqa: F821
AllVR = Union[ExprVR, BoolVR]

@ -484,7 +484,7 @@ class SymPyValueRangeAnalysis:
@staticmethod
def to_dtype(a, dtype, src_dtype=None):
if dtype == torch.float64:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ValueRanges.increasing_map(a, ToFloat)
elif dtype == torch.bool:
return ValueRanges.unknown_bool()

@ -494,7 +494,7 @@ class SymPyValueRangeAnalysis:

@staticmethod
def trunc_to_int(a, dtype):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ValueRanges.increasing_map(a, TruncToInt)

@staticmethod

@ -652,7 +652,7 @@ class SymPyValueRangeAnalysis:
return ValueRanges.coordinatewise_monotone_map(
a,
b,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_keep_float(IntTrueDiv),
)

@ -668,7 +668,7 @@ class SymPyValueRangeAnalysis:
return ValueRanges.coordinatewise_monotone_map(
a,
b,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_keep_float(FloatTrueDiv),
)

@ -748,7 +748,7 @@ class SymPyValueRangeAnalysis:
# We should know that b >= 0 but we may have forgotten this fact due
# to replacements, so don't assert it, but DO clamp it to prevent
# degenerate problems
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
return ValueRanges.coordinatewise_increasing_map(
a, b & ValueRanges(0, int_oo), PowByNatural
)
@ -915,7 +915,7 @@ class SymPyValueRangeAnalysis:

@classmethod
def round_to_int(cls, number, dtype):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ValueRanges.increasing_map(number, RoundToInt)

# It's used in some models on symints

@ -1032,7 +1032,7 @@ class SymPyValueRangeAnalysis:

@staticmethod
def trunc(x):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return ValueRanges.increasing_map(x, TruncToFloat)

@ -63,7 +63,7 @@ def generate_coo_data(size, sparse_dim, nnz, dtype, device):
indices = torch.rand(sparse_dim, nnz, device=device)
indices.mul_(torch.tensor(size[:sparse_dim]).unsqueeze(1).to(indices))
indices = indices.to(torch.long)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
values = torch.rand([nnz, ], dtype=dtype, device=device)
return indices, values

@ -170,7 +170,7 @@ if HAS_TABULATE:
_disable_tensor_cores()
table.append([
("Training" if optimizer else "Inference"),
# pyrefly: ignore # redundant-condition
# pyrefly: ignore [redundant-condition]
backend if backend else "-",
mode if mode is not None else "-",
f"{compilation_time} ms " if compilation_time else "-",

@ -191,5 +191,5 @@ if HAS_TABULATE:
])

# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
return tabulate(table, headers=field_names, tablefmt="github")
@ -35,7 +35,7 @@ def _get_build_root() -> str:
global _BUILD_ROOT
if _BUILD_ROOT is None:
_BUILD_ROOT = _make_temp_dir(prefix="benchmark_utils_jit_build")
# pyrefly: ignore # missing-argument
# pyrefly: ignore [missing-argument]
atexit.register(shutil.rmtree, _BUILD_ROOT)
return _BUILD_ROOT

@ -92,7 +92,7 @@ class FuzzedSparseTensor(FuzzedTensor):
return x

def _make_tensor(self, params, state):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
size, _, _ = self._get_size_and_steps(params)
density = params['density']
nnz = math.ceil(sum(size) * density)

@ -102,10 +102,10 @@ class FuzzedSparseTensor(FuzzedTensor):
is_coalesced = params['coalesced']
sparse_dim = params['sparse_dim'] if self._sparse_dim else len(size)
sparse_dim = min(sparse_dim, len(size))
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
tensor = self.sparse_tensor_constructor(size, self._dtype, sparse_dim, nnz, is_coalesced)

# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if self._cuda:
tensor = tensor.cuda()
sparse_dim = tensor.sparse_dim()

@ -121,7 +121,7 @@ class FuzzedSparseTensor(FuzzedTensor):
"sparse_dim": sparse_dim,
"dense_dim": dense_dim,
"is_hybrid": is_hybrid,
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
"dtype": str(self._dtype),
}
return tensor, properties

@ -234,7 +234,7 @@ class Timer:
setup = textwrap.dedent(setup)
setup = (setup[1:] if setup and setup[0] == "\n" else setup).rstrip()

# pyrefly: ignore # bad-instantiation
# pyrefly: ignore [bad-instantiation]
self._timer = self._timer_cls(
stmt=stmt,
setup=setup,

@ -449,13 +449,13 @@ class GlobalsBridge:
load_lines = []
for name, wrapped_value in self._globals.items():
if wrapped_value.setup is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
load_lines.append(textwrap.dedent(wrapped_value.setup))

if wrapped_value.serialization == Serialization.PICKLE:
path = os.path.join(self._data_dir, f"{name}.pkl")
load_lines.append(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
f"with open({repr(path)}, 'rb') as f:\n    {name} = pickle.load(f)")
with open(path, "wb") as f:
pickle.dump(wrapped_value.value, f)

@ -465,13 +465,13 @@ class GlobalsBridge:
# TODO: Figure out if we can use torch.serialization.add_safe_globals here
# Using weights_only=False after the change in
# https://dev-discuss.pytorch.org/t/bc-breaking-change-torch-load-is-being-flipped-to-use-weights-only-true-by-default-in-the-nightlies-after-137602/2573
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
load_lines.append(f"{name} = torch.load({repr(path)}, weights_only=False)")
torch.save(wrapped_value.value, path)

elif wrapped_value.serialization == Serialization.TORCH_JIT:
path = os.path.join(self._data_dir, f"{name}.pt")
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
load_lines.append(f"{name} = torch.jit.load({repr(path)})")
with open(path, "wb") as f:
torch.jit.save(wrapped_value.value, f) # type: ignore[no-untyped-call]
@ -222,7 +222,7 @@ def _get_autocast_kwargs(device_type="cuda"):

class CheckpointFunction(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, run_function, preserve_rng_state, *args):
check_backward_validity(args)
ctx.run_function = run_function

@ -785,7 +785,7 @@ class _Holder:

class _NoopSaveInputs(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(*args):
return torch.empty((0,))

@ -1008,7 +1008,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
def logging_mode():
with LoggingTensorMode(), \
capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.logs, self.tbs = logs_and_tb
yield logs_and_tb
return logging_mode()

@ -788,7 +788,7 @@ class BuildExtension(build_ext):

# Use absolute path for output_dir so that the object file paths
# (`objects`) get generated with absolute paths.
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
output_dir = os.path.abspath(output_dir)

# See Note [Absolute include_dirs]

@ -979,7 +979,7 @@ class BuildExtension(build_ext):
is_standalone=False):
if not self.compiler.initialized:
self.compiler.initialize()
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
output_dir = os.path.abspath(output_dir)

# Note [Absolute include_dirs]

@ -2573,7 +2573,7 @@ def _get_num_workers(verbose: bool) -> Optional[int]:
def _get_vc_env(vc_arch: str) -> dict[str, str]:
try:
from setuptools import distutils # type: ignore[attr-defined]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return distutils._msvccompiler._get_vc_env(vc_arch)
except AttributeError:
try:

@ -204,7 +204,7 @@ def collate(
# check to make sure that the elements in batch have consistent size
it = iter(batch)
elem_size = len(next(it))
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError("each element in list of batch should be of equal size")
transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.

@ -70,7 +70,7 @@ def pin_memory(data, device=None):
return clone
else:
return type(data)(
# pyrefly: ignore # bad-argument-count
# pyrefly: ignore [bad-argument-count]
{k: pin_memory(sample, device) for k, sample in data.items()}
) # type: ignore[call-arg]
except TypeError:
@ -265,7 +265,7 @@ class _DataPipeType:

# Default type for DataPipe without annotation
_T_co = TypeVar("_T_co", covariant=True)
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
_DEFAULT_TYPE = _DataPipeType(Generic[_T_co])

@ -284,7 +284,7 @@ class _DataPipeMeta(GenericMeta):
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]

# TODO: the statements below are not reachable by design as there is a bug and typing is low priority for now.
# pyrefly: ignore # no-access
# pyrefly: ignore [no-access]
cls.__origin__ = None
if "type" in namespace:
return super().__new__(cls, name, bases, namespace, **kwargs) # type: ignore[call-overload]

@ -80,7 +80,7 @@ class Capture:

def _ops_str(self):
res = ""
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for op in self.ctx["operations"]:
if len(res) > 0:
res += "\n"

@ -90,7 +90,7 @@ class Capture:
def __getstate__(self):
# TODO(VitalyFedyunin): Currently can't pickle (why?)
self.ctx["schema_df"] = None
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for var in self.ctx["variables"]:
var.calculated_value = None
state = {}

@ -114,13 +114,13 @@ class Capture:
return CaptureGetItem(self, key, ctx=self.ctx)

def __setitem__(self, key, value):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(CaptureSetItem(self, key, value, ctx=self.ctx))

def __add__(self, add_val):
res = CaptureAdd(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)

@ -129,7 +129,7 @@ class Capture:
def __sub__(self, add_val):
res = CaptureSub(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(
CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
)

@ -139,19 +139,19 @@ class Capture:
res = CaptureMul(self, add_val, ctx=self.ctx)
var = CaptureVariable(res, ctx=self.ctx)
t = CaptureVariableAssign(variable=var, value=res, ctx=self.ctx)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(t)
return var

def _is_context_empty(self):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.ctx["operations"]) == 0 and len(self.ctx["variables"]) == 0

def apply_ops_2(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for op in self.ctx["operations"]:
op.execute()
@ -184,7 +184,7 @@ class Capture:
res = CaptureCall(self, ctx=self.ctx, args=args, kwargs=kwargs)
var = CaptureVariable(None, ctx=self.ctx)
t = CaptureVariableAssign(ctx=self.ctx, variable=var, value=res)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.ctx["operations"].append(t)
return var

@ -283,9 +283,9 @@ class CaptureVariable(Capture):

def apply_ops(self, dataframe):
# TODO(VitalyFedyunin): Make this calculation thread safe (as currently it updates pointer)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
self.ctx["variables"][0].calculated_value = dataframe
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for op in self.ctx["operations"]:
op.execute()
return self.calculated_value

@ -385,7 +385,7 @@ def get_val(capture):

class CaptureInitial(CaptureVariable):
def __init__(self, schema_df=None):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
new_ctx: dict[str, list[Any]] = {
"operations": [],
"variables": [],

@ -401,7 +401,7 @@ class CaptureDataFrame(CaptureInitial):

class CaptureDataFrameWithDataPipeOps(CaptureDataFrame):
def as_datapipe(self):
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return DataFrameTracedOps(self.ctx["variables"][0].source_datapipe, self)

def raw_iterator(self):

@ -92,7 +92,7 @@ class FilterDataFramesPipe(DFIterDataPipe):
size = None
all_buffer = []
filter_res = []
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for df in self.source_datapipe:
if size is None:
size = len(df.index)

@ -135,7 +135,7 @@ class IterDataPipe(IterableDataset[_T_co], metaclass=_IterDataPipeMeta):
_fast_forward_iterator: Optional[Iterator] = None

def __iter__(self) -> Iterator[_T_co]:
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return self

def __getattr__(self, attribute_name):

@ -380,7 +380,7 @@ class _DataPipeSerializationWrapper:
value = pickle.dumps(self._datapipe)
except Exception:
if HAS_DILL:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
value = dill.dumps(self._datapipe)
use_dill = True
else:

@ -390,7 +390,7 @@ class _DataPipeSerializationWrapper:
def __setstate__(self, state):
value, use_dill = state
if use_dill:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self._datapipe = dill.loads(value)
else:
self._datapipe = pickle.loads(value)

@ -407,7 +407,7 @@ class _DataPipeSerializationWrapper:
class _IterDataPipeSerializationWrapper(_DataPipeSerializationWrapper, IterDataPipe):
def __init__(self, datapipe: IterDataPipe[_T_co]):
super().__init__(datapipe)
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self._datapipe_iter: Optional[Iterator[_T_co]] = None

def __iter__(self) -> "_IterDataPipeSerializationWrapper":
@ -118,7 +118,7 @@ class MapperIterDataPipe(IterDataPipe[_T_co]):
for idx in sorted(self.input_col[1:], reverse=True):
del data[idx]
else:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
data[self.input_col] = res
else:
if self.output_col == -1:

@ -43,7 +43,7 @@ class SamplerIterDataPipe(IterDataPipe[_T_co]):
"Sampler class requires input datapipe implemented `__len__`"
)
super().__init__()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.datapipe = datapipe
self.sampler_args = () if sampler_args is None else sampler_args
self.sampler_kwargs = {} if sampler_kwargs is None else sampler_kwargs

@ -59,7 +59,7 @@ class ConcaterIterDataPipe(IterDataPipe):

def __len__(self) -> int:
if all(isinstance(dp, Sized) for dp in self.datapipes):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return sum(len(dp) for dp in self.datapipes)
else:
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

@ -180,7 +180,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
self._child_stop: list[bool] = [True for _ in range(num_instances)]

def __len__(self):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.main_datapipe)

def get_next_element_by_instance(self, instance_id: int):

@ -240,7 +240,7 @@ class _ForkerIterDataPipe(IterDataPipe, _ContainerTemplate):
return self.end_ptr is not None and all(self._child_stop)

def get_length_by_instance(self, instance_id: int) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.main_datapipe)

def reset(self) -> None:

@ -327,7 +327,7 @@ class _ChildDataPipe(IterDataPipe):
if not isinstance(main_datapipe, _ContainerTemplate):
raise AssertionError("main_datapipe must implement _ContainerTemplate")

# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.main_datapipe: IterDataPipe = main_datapipe
self.instance_id = instance_id

@ -454,7 +454,7 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
drop_none: bool,
buffer_size: int,
):
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.main_datapipe = datapipe
self._datapipe_iterator: Optional[Iterator[Any]] = None
self.num_instances = num_instances

@ -466,9 +466,9 @@ class _DemultiplexerIterDataPipe(IterDataPipe, _ContainerTemplate):
UserWarning,
)
self.current_buffer_usage = 0
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.child_buffers: list[deque[_T_co]] = [deque() for _ in range(num_instances)]
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.classifier_fn = classifier_fn
self.drop_none = drop_none
self.main_datapipe_exhausted = False
@ -706,7 +706,7 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):

def __len__(self) -> int:
if all(isinstance(dp, Sized) for dp in self.datapipes):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return min(len(dp) for dp in self.datapipes)
else:
raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

@ -204,9 +204,9 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
drop_remaining: bool = False,
):
_check_unpickable_fn(group_key_fn)
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.datapipe = datapipe
# pyrefly: ignore # invalid-type-var
# pyrefly: ignore [invalid-type-var]
self.group_key_fn = group_key_fn

self.keep_key = keep_key

@ -218,14 +218,14 @@ class GrouperIterDataPipe(IterDataPipe[DataChunk]):
if group_size is not None and buffer_size is not None:
if not (0 < group_size <= buffer_size):
raise AssertionError("group_size must be > 0 and <= buffer_size")
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.guaranteed_group_size = group_size
if guaranteed_group_size is not None:
if group_size is None or not (0 < guaranteed_group_size <= group_size):
raise AssertionError(
"guaranteed_group_size must be > 0 and <= group_size and group_size must be set"
)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.guaranteed_group_size = guaranteed_group_size
self.drop_remaining = drop_remaining
self.wrapper_class = DataChunk

@ -60,7 +60,7 @@ class MapperMapDataPipe(MapDataPipe[_T_co]):
self.fn = fn # type: ignore[assignment]

def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.datapipe)

def __getitem__(self, index) -> _T_co:

@ -64,7 +64,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]):
) -> None:
super().__init__()
self.datapipe = datapipe
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.indices = list(range(len(datapipe))) if indices is None else indices
self._enabled = True
self._seed = None

@ -96,7 +96,7 @@ class ShufflerIterDataPipe(IterDataPipe[_T_co]):
self._shuffled_indices = self._rng.sample(self.indices, len(self.indices))

def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return len(self.datapipe)

def __getstate__(self):
@ -49,16 +49,16 @@ class ConcaterMapDataPipe(MapDataPipe):
def __getitem__(self, index) -> _T_co: # type: ignore[type-var]
offset = 0
for dp in self.datapipes:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if index - offset < len(dp):
return dp[index - offset]
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
offset += len(dp)
raise IndexError(f"Index {index} is out of range.")

def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return sum(len(dp) for dp in self.datapipes)

@ -105,5 +105,5 @@ class ZipperMapDataPipe(MapDataPipe[tuple[_T_co, ...]]):
return tuple(res)

def __len__(self) -> int:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return min(len(dp) for dp in self.datapipes)

@ -196,7 +196,7 @@ def get_file_pathnames_from_root(
if match_masks(fname, masks):
yield path
else:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for path, dirs, files in os.walk(root, onerror=onerror):
if abspath:
path = os.path.abspath(path)

@ -43,7 +43,7 @@ def _simple_graph_snapshot_restoration(
# simple fast-forwarding. Therefore, we need to call `reset` twice, because if `SnapshotState` is `Restored`,
# the first reset will not actually reset.
datapipe.reset() # This ensures `SnapshotState` is `Iterating` by this point, even if it was `Restored`.
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
apply_random_seed(datapipe, rng)

remainder = n_iterations

@ -137,7 +137,7 @@ class DistributedSampler(Sampler[_T_co]):
f"Number of subsampled indices ({len(indices)}) does not match num_samples ({self.num_samples})"
)

# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return iter(indices)

def __len__(self) -> int:

@ -72,7 +72,7 @@ def _list_connected_datapipes(
p.dump(scan_obj)
except (pickle.PickleError, AttributeError, TypeError):
if dill_available():
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
d.dump(scan_obj)
else:
raise

@ -31,7 +31,7 @@ class FileBaton:
True if the file could be created, else False.
"""
try:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.fd = os.open(self.lock_file_path, os.O_CREAT | os.O_EXCL)
return True
except FileExistsError:
@ -149,7 +149,7 @@ def conv_flop_count(
@register_flop_formula([aten.convolution, aten._convolution, aten.cudnn_convolution, aten._slow_conv2d_forward])
def conv_flop(x_shape, w_shape, _bias, _stride, _padding, _dilation, transposed, *args, out_shape=None, **kwargs) -> int:
"""Count flops for convolution."""
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
return conv_flop_count(x_shape, w_shape, out_shape, transposed=transposed)

@ -145,7 +145,7 @@ class BackwardHook:

res = out

# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.grad_outputs = None

return self._unpack_none(self.input_tensors_index, res)

@ -237,7 +237,7 @@ def get_model_info(
with zipfile.ZipFile(path_or_file) as zf:
path_prefix = None
zip_files = []
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for zi in zf.infolist():
prefix = re.sub("/.*", "", zi.filename)
if path_prefix is None:

@ -392,12 +392,12 @@ def get_inline_skeleton():

import importlib.resources

# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
skeleton = importlib.resources.read_text(__package__, "skeleton.html")
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
js_code = importlib.resources.read_text(__package__, "code.js")
for js_module in ["preact", "htm"]:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
js_lib = importlib.resources.read_binary(__package__, f"{js_module}.mjs")
js_url = "data:application/javascript," + urllib.parse.quote(js_lib)
js_code = js_code.replace(f"https://unpkg.com/{js_module}?module", js_url)

@ -31,7 +31,7 @@ def make_np(x: torch.Tensor) -> np.ndarray:
def _prepare_pytorch(x: torch.Tensor) -> np.ndarray:
if x.dtype == torch.bfloat16:
x = x.to(torch.float16)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
x = x.detach().cpu().numpy()
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return x
@ -188,7 +188,7 @@ class GraphPy:

for key, node in self.nodes_io.items():
if type(node) is NodeBase:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if hasattr(node, "input_or_output"):
self.unique_name_to_scoped_name[key] = (

@ -199,7 +199,7 @@ class GraphPy:
self.unique_name_to_scoped_name[key] = node.scope + "/" + node.debugName
if node.scope == "" and self.shallowest_scope_name:
self.unique_name_to_scoped_name[node.debugName] = (
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
self.shallowest_scope_name + "/" + node.debugName
)

@ -57,14 +57,14 @@ def _prepare_video(V):
return num != 0 and ((num & (num - 1)) == 0)

# pad to nearest power of 2, all at once
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
if not is_power2(V.shape[0]):
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
len_addition = int(2 ** V.shape[0].bit_length() - V.shape[0])
V = np.concatenate((V, np.zeros(shape=(len_addition, t, c, h, w))), axis=0)

n_rows = 2 ** ((b.bit_length() - 1) // 2)
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
n_cols = V.shape[0] // n_rows

V = np.reshape(V, newshape=(n_rows, n_cols, t, c, h, w))

@ -498,7 +498,7 @@ def make_histogram(values, bins, max_bins=None):
subsampling = num_bins // max_bins
subsampling_remainder = num_bins % subsampling
if subsampling_remainder != 0:
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
counts = np.pad(
counts,
pad_width=[[0, subsampling - subsampling_remainder]],

@ -838,21 +838,21 @@ def compute_curve(labels, predictions, num_thresholds=None, weights=None):
weights = 1.0

# Compute bins of true positives and false positives.
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
bucket_indices = np.int32(np.floor(predictions * (num_thresholds - 1)))
float_labels = labels.astype(np.float64)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
histogram_range = (0, num_thresholds - 1)
tp_buckets, _ = np.histogram(
bucket_indices,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
bins=num_thresholds,
range=histogram_range,
weights=float_labels * weights,
)
fp_buckets, _ = np.histogram(
bucket_indices,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
bins=num_thresholds,
range=histogram_range,
weights=(1.0 - float_labels) * weights,
@ -254,9 +254,9 @@ class SummaryWriter:
buckets = []
neg_buckets = []
while v < 1e20:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
buckets.append(v)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
neg_buckets.append(-v)
v *= 1.1
self.default_bins = neg_buckets[::-1] + [0] + buckets

@ -264,19 +264,19 @@ class SummaryWriter:
def _get_file_writer(self):
"""Return the default FileWriter instance. Recreates it if closed."""
if self.all_writers is None or self.file_writer is None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.file_writer = FileWriter(
self.log_dir, self.max_queue, self.flush_secs, self.filename_suffix
)
# pyrefly: ignore # bad-assignment, missing-attribute
# pyrefly: ignore [bad-assignment, missing-attribute]
self.all_writers = {self.file_writer.get_logdir(): self.file_writer}
if self.purge_step is not None:
most_recent_step = self.purge_step
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.file_writer.add_event(
Event(step=most_recent_step, file_version="brain.Event:2")
)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.file_writer.add_event(
Event(
step=most_recent_step,

@ -1207,7 +1207,7 @@ class SummaryWriter:
for writer in self.all_writers.values():
writer.flush()
writer.close()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.file_writer = self.all_writers = None

def __enter__(self):

@ -461,7 +461,7 @@ def to_html(nodes):
if n.context is None:
continue
s = _listener_template.format(id=str(i + 1), stack=escape(f'{n.label}:\n{n.context}'))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
listeners.append(s)
dot = to_dot(nodes)
return _template.replace('$DOT', repr(dot)).replace('$LISTENERS', '\n'.join(listeners))
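
Every hunk above applies the same mechanical rewrite: a trailing `# pyrefly: ignore # <error-code>` comment becomes `# pyrefly: ignore [<error-code>]`, with the suppression kept on the line directly above the statement it covers. A minimal sketch of the before/after pattern, using a hypothetical function that is not part of this patch:

```python
# Hypothetical illustration only; this function is not part of the patch.

# Before (old suppression syntax, left side of each hunk):
# pyrefly: ignore # bad-return
def head_old(values: list[int]) -> int:
    return values[0]

# After (bracketed suppression syntax, right side of each hunk):
# pyrefly: ignore [bad-return]
def head_new(values: list[int]) -> int:
    return values[0]
```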