Fix pyrefly ignore syntax in distributions and ao (#166248)

Ensures existing pyrefly ignores only ignore the intended error code
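For illustration, the behavioral difference between the two comment forms (a minimal sketch; the toy assignments below are hypothetical and not part of this PR, but both suppression forms are the ones touched here):

    # Old form: everything after the second "#" is an ordinary trailing comment,
    # so this suppresses every pyrefly error reported on the next line.
    # pyrefly: ignore # bad-assignment
    x: int = "oops"

    # New form: the bracketed code scopes the suppression to that error alone;
    # any other error on the same line is still reported.
    # pyrefly: ignore [bad-assignment]
    y: int = "oops"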

Test plan:
pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166248
Approved by: https://github.com/oulgen
commit 154e4d36e9 (parent a2b6afeac5)
Maggie Moss, 2025-10-26 22:13:48 +00:00, committed by PyTorch MergeBot
84 changed files with 251 additions and 251 deletions

View File

@@ -620,7 +620,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
             dilation=dilation,
             groups=groups,
             bias=bias,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             padding_mode=padding_mode,
             qconfig=qconfig,
         )
@@ -821,7 +821,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
             dilation=dilation,
             groups=groups,
             bias=bias,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             padding_mode=padding_mode,
             qconfig=qconfig,
         )
@@ -1023,7 +1023,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
             dilation=dilation,
             groups=groups,
             bias=bias,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             padding_mode=padding_mode,
             qconfig=qconfig,
         )

View File

@@ -36,7 +36,7 @@ class LinearReLU(nnqat.Linear, _FusedModule):
         torch.Size([128, 30])
     """
-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     _FLOAT_MODULE = nni.LinearReLU
     def __init__(

View File

@@ -30,7 +30,7 @@ class LinearReLU(nnqd.Linear):
         torch.Size([128, 30])
     """
-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     _FLOAT_MODULE = nni.LinearReLU
     def __init__(

View File

@@ -54,7 +54,7 @@ class ConvReLU1d(nnq.Conv1d):
             dilation=dilation,
             groups=groups,
             bias=bias,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             padding_mode=padding_mode,
             device=device,
             dtype=dtype,

View File

@@ -114,7 +114,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
             assert hasattr(cls, "_FLOAT_RELU_MODULE")
             relu = cls._FLOAT_RELU_MODULE()
             modules.append(relu)
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         fused = cls._FLOAT_MODULE(*modules)
         fused.train(self.training)
         return fused

View File

@@ -50,7 +50,7 @@ class Embedding(nn.Embedding):
             scale_grad_by_freq,
             sparse,
             _weight,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             **factory_kwargs,
         )
         assert qconfig, "qconfig must be provided for QAT module"

View File

@@ -170,11 +170,11 @@ class MultiheadAttention(nn.MultiheadAttention):
         observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
         observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
         if other.in_proj_bias is None:
-            # pyrefly: ignore # bad-assignment
+            # pyrefly: ignore [bad-assignment]
             observed.linear_Q.bias = None
-            # pyrefly: ignore # bad-assignment
+            # pyrefly: ignore [bad-assignment]
             observed.linear_K.bias = None
-            # pyrefly: ignore # bad-assignment
+            # pyrefly: ignore [bad-assignment]
             observed.linear_V.bias = None
         else:
             observed.linear_Q.bias = nn.Parameter(
@@ -237,7 +237,7 @@ class MultiheadAttention(nn.MultiheadAttention):
             _end = _start + fp.embed_dim
             fp.in_proj_weight[_start:_end, :] = wQ
             if fp.in_proj_bias is not None:
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 assert all(bQ == 0)
                 fp.in_proj_bias[_start:_end] = bQ
@@ -245,14 +245,14 @@ class MultiheadAttention(nn.MultiheadAttention):
             _end = _start + fp.embed_dim
             fp.in_proj_weight[_start:_end, :] = wK
             if fp.in_proj_bias is not None:
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 assert all(bK == 0)
                 fp.in_proj_bias[_start:_end] = bK
             _start = _end
             fp.in_proj_weight[_start:, :] = wV
             if fp.in_proj_bias is not None:
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 assert all(bV == 0)
                 fp.in_proj_bias[_start:] = bV
         else:
@@ -260,11 +260,11 @@ class MultiheadAttention(nn.MultiheadAttention):
             fp.k_proj_weight = nn.Parameter(wK)
             fp.v_proj_weight = nn.Parameter(wV)
             if fp.in_proj_bias is None:
-                # pyrefly: ignore # bad-assignment
+                # pyrefly: ignore [bad-assignment]
                 self.linear_Q.bias = None
-                # pyrefly: ignore # bad-assignment
+                # pyrefly: ignore [bad-assignment]
                 self.linear_K.bias = None
-                # pyrefly: ignore # bad-assignment
+                # pyrefly: ignore [bad-assignment]
                 self.linear_V.bias = None
             else:
                 fp.in_proj_bias[0 : fp.embed_dim] = bQ
@@ -472,7 +472,7 @@ class MultiheadAttention(nn.MultiheadAttention):
             assert static_v.size(2) == head_dim
             v = static_v
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         src_len = k.size(1)
         if key_padding_mask is not None:
@@ -481,35 +481,35 @@ class MultiheadAttention(nn.MultiheadAttention):
         if self.add_zero_attn:
             src_len += 1
-            # pyrefly: ignore # missing-attribute
+            # pyrefly: ignore [missing-attribute]
             k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
-            # pyrefly: ignore # missing-attribute
+            # pyrefly: ignore [missing-attribute]
             if k.is_quantized:
                 k_zeros = torch.quantize_per_tensor(
                     k_zeros,
-                    # pyrefly: ignore # missing-attribute
+                    # pyrefly: ignore [missing-attribute]
                     k.q_scale(),
-                    # pyrefly: ignore # missing-attribute
+                    # pyrefly: ignore [missing-attribute]
                     k.q_zero_point(),
-                    # pyrefly: ignore # missing-attribute
+                    # pyrefly: ignore [missing-attribute]
                     k.dtype,
                 )
-            # pyrefly: ignore # no-matching-overload
+            # pyrefly: ignore [no-matching-overload]
             k = torch.cat([k, k_zeros], dim=1)
-            # pyrefly: ignore # missing-attribute
+            # pyrefly: ignore [missing-attribute]
             v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
-            # pyrefly: ignore # missing-attribute
+            # pyrefly: ignore [missing-attribute]
             if v.is_quantized:
                 v_zeros = torch.quantize_per_tensor(
                     v_zeros,
-                    # pyrefly: ignore # missing-attribute
+                    # pyrefly: ignore [missing-attribute]
                     v.q_scale(),
-                    # pyrefly: ignore # missing-attribute
+                    # pyrefly: ignore [missing-attribute]
                     v.q_zero_point(),
-                    # pyrefly: ignore # missing-attribute
+                    # pyrefly: ignore [missing-attribute]
                     v.dtype,
                 )
-            # pyrefly: ignore # no-matching-overload
+            # pyrefly: ignore [no-matching-overload]
             v = torch.cat([v, v_zeros], dim=1)
         if attn_mask is not None:

View File

@@ -376,7 +376,7 @@ class _LSTMLayer(torch.nn.Module):
             bidirectional,
             split_gates=split_gates,
         )
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         layer.qconfig = getattr(other, "qconfig", qconfig)
         wi = getattr(other, f"weight_ih_l{layer_idx}")
         wh = getattr(other, f"weight_hh_l{layer_idx}")
@@ -455,7 +455,7 @@ class LSTM(torch.nn.Module):
         if (
             not isinstance(dropout, numbers.Number)
-            # pyrefly: ignore # unsupported-operation
+            # pyrefly: ignore [unsupported-operation]
             or not 0 <= dropout <= 1
             or isinstance(dropout, bool)
         ):
@@ -464,7 +464,7 @@ class LSTM(torch.nn.Module):
                 "representing the probability of an element being "
                 "zeroed"
             )
-        # pyrefly: ignore # unsupported-operation
+        # pyrefly: ignore [unsupported-operation]
         if dropout > 0:
             warnings.warn(
                 "dropout option for quantizable LSTM is ignored. "
@@ -578,7 +578,7 @@ class LSTM(torch.nn.Module):
             other.bidirectional,
             split_gates=split_gates,
         )
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         observed.qconfig = getattr(other, "qconfig", qconfig)
         for idx in range(other.num_layers):
             observed.layers[idx] = _LSTMLayer.from_float(

View File

@@ -74,7 +74,7 @@ class Conv1d(nnq.Conv1d):
         factory_kwargs = {"device": device, "dtype": dtype}
         kernel_size = _single(kernel_size)
         stride = _single(stride)
-        # pyrefly: ignore # bad-assignment
+        # pyrefly: ignore [bad-assignment]
         padding = padding if isinstance(padding, str) else _single(padding)
         dilation = _single(dilation)

View File

@@ -119,9 +119,9 @@ class Linear(nnq.Linear):
         assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
         if type(mod) is nni.LinearReLU:
             mod = mod[0]
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         if mod.qconfig is not None and mod.qconfig.weight is not None:
-            # pyrefly: ignore # not-callable
+            # pyrefly: ignore [not-callable]
             weight_observer = mod.qconfig.weight()
         else:
             # We have the circular import issues if we import the qconfig in the beginning of this file:
@@ -145,7 +145,7 @@ class Linear(nnq.Linear):
                 "Unsupported dtype specified for dynamic quantized Linear!"
             )
         qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         qlinear.set_weight_bias(qweight, mod.bias)
         return qlinear

View File

@@ -522,7 +522,7 @@ class LSTM(RNNBase):
         >>> output, (hn, cn) = rnn(input, (h0, c0))
     """
-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     _FLOAT_MODULE = nn.LSTM
     __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
@@ -808,7 +808,7 @@ class GRU(RNNBase):
         >>> output, hn = rnn(input, h0)
     """
-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     _FLOAT_MODULE = nn.GRU
     __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

View File

@@ -67,9 +67,9 @@ class Hardswish(torch.nn.Hardswish):
     def __init__(self, scale, zero_point, device=None, dtype=None):
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__()
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
     def forward(self, input):
@@ -140,9 +140,9 @@ class LeakyReLU(torch.nn.LeakyReLU):
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__(negative_slope, inplace)
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
     def forward(self, input):
@@ -230,7 +230,7 @@ class Softmax(torch.nn.Softmax):
 class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
     def _get_name(self):

View File

@@ -12,9 +12,9 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
     ) -> None:
         factory_kwargs = {"device": device, "dtype": dtype}
         super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
     @staticmethod

View File

@@ -408,7 +408,7 @@ class Conv1d(_ConvNd):
         factory_kwargs = {"device": device, "dtype": dtype}
         kernel_size = _single(kernel_size)
         stride = _single(stride)
-        # pyrefly: ignore # bad-assignment
+        # pyrefly: ignore [bad-assignment]
         padding = padding if isinstance(padding, str) else _single(padding)
         dilation = _single(dilation)

View File

@@ -310,7 +310,7 @@ class Linear(WeightedQuantizedModule):
         # the type mismatch in assignment. Also, mypy has an issue with
         # iterables not being implemented, so we are ignoring those too.
         if not isinstance(cls._FLOAT_MODULE, Iterable):
-            # pyrefly: ignore # bad-assignment
+            # pyrefly: ignore [bad-assignment]
            cls._FLOAT_MODULE = [cls._FLOAT_MODULE]
         supported_modules = ", ".join(
             [float_mod.__name__ for float_mod in cls._FLOAT_MODULE]

View File

@@ -37,14 +37,14 @@ class LayerNorm(torch.nn.LayerNorm):
             normalized_shape,
             eps=eps,
             elementwise_affine=elementwise_affine,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             **factory_kwargs,
         )
         self.weight = weight
         self.bias = bias
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
     def forward(self, input):
@@ -116,9 +116,9 @@ class GroupNorm(torch.nn.GroupNorm):
         super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
         self.weight = weight
         self.bias = bias
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
     def forward(self, input):
@@ -180,9 +180,9 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
         )
         self.weight = weight
         self.bias = bias
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
     def forward(self, input):
@@ -249,9 +249,9 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
         )
         self.weight = weight
         self.bias = bias
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
     def forward(self, input):
@@ -318,9 +318,9 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
         )
         self.weight = weight
         self.bias = bias
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
     def forward(self, input):

View File

@@ -141,7 +141,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
             dilation,
             groups,
             bias,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             padding_mode,
             device,
             dtype,
@@ -206,7 +206,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
             dilation,
             groups,
             bias,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             padding_mode,
             device,
             dtype,
@@ -383,7 +383,7 @@ class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
             groups,
             bias,
             dilation,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             padding_mode,
             device,
             dtype,
@@ -465,7 +465,7 @@ class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
             groups,
             bias,
             dilation,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             padding_mode,
             device,
             dtype,

View File

@@ -664,7 +664,7 @@ class LSTM(RNNBase):
         if isinstance(orig_input, PackedSequence):
             output_packed = PackedSequence(
                 output,
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 batch_sizes,
                 sorted_indices,
                 unsorted_indices,
@@ -828,7 +828,7 @@ class GRU(RNNBase):
         if isinstance(orig_input, PackedSequence):
             output_packed = PackedSequence(
                 output,
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 batch_sizes,
                 sorted_indices,
                 unsorted_indices,

View File

@@ -42,7 +42,7 @@ class Embedding(nn.Embedding, ReferenceQuantizedModule):
             scale_grad_by_freq,
             sparse,
             _weight,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             device,
             dtype,
         )

View File

@@ -18,7 +18,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
                 "scale": 1.0,
                 "zero_point": 0,
             }
-        # pyrefly: ignore # bad-assignment
+        # pyrefly: ignore [bad-assignment]
         self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
         self.weight_dtype = weight_qparams["dtype"]
         assert self.weight_qscheme in [
@@ -81,14 +81,14 @@ class ReferenceQuantizedModule(torch.nn.Module):
             self.register_buffer(
                 "weight_axis", torch.tensor(0, dtype=torch.int, device=device)
             )
-        # pyrefly: ignore # bad-assignment
+        # pyrefly: ignore [bad-assignment]
         self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
         # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
         # for capturing `.item` operations
         self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
-        # pyrefly: ignore # bad-assignment
+        # pyrefly: ignore [bad-assignment]
         self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min")
-        # pyrefly: ignore # bad-assignment
+        # pyrefly: ignore [bad-assignment]
         self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max")
     def get_weight(self):
@@ -105,7 +105,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
             return _quantize_and_dequantize_weight_decomposed(
                 self.weight, # type: ignore[arg-type]
                 self.weight_qscheme,
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 self.weight_dtype,
                 self.weight_scale,
                 self.weight_zero_point,
@@ -117,7 +117,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
             return _quantize_and_dequantize_weight(
                 self.weight, # type: ignore[arg-type]
                 self.weight_qscheme,
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 self.weight_dtype,
                 self.weight_scale,
                 self.weight_zero_point,
@@ -133,7 +133,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
            return _quantize_weight_decomposed(
                 self.weight, # type: ignore[arg-type]
                 self.weight_qscheme,
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 self.weight_dtype,
                 self.weight_scale,
                 self.weight_zero_point,
@@ -145,7 +145,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
             return _quantize_weight(
                 self.weight, # type: ignore[arg-type]
                 self.weight_qscheme,
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 self.weight_dtype,
                 self.weight_scale,
                 self.weight_zero_point,

View File

@@ -151,9 +151,9 @@ class Linear(torch.nn.Module):
         assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
         if type(mod) is nni.LinearReLU:
             mod = mod[0]
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         if mod.qconfig is not None and mod.qconfig.weight is not None:
-            # pyrefly: ignore # not-callable
+            # pyrefly: ignore [not-callable]
             weight_observer = mod.qconfig.weight()
         else:
             # We have the circular import issues if we import the qconfig in the beginning of this file:
@@ -187,6 +187,6 @@ class Linear(torch.nn.Module):
             col_block_size,
             dtype=dtype,
         )
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
         return qlinear

View File

@@ -959,7 +959,7 @@ def create_a_shadows_b(
                     if should_log_inputs:
                         # skip the input logger when inserting a dtype cast
                         if isinstance(prev_node_c, Node):
-                            # pyrefly: ignore # unbound-name
+                            # pyrefly: ignore [unbound-name]
                             prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
                         elif isinstance(prev_node_c, list):
                             prev_node_c = [
@@ -968,7 +968,7 @@ def create_a_shadows_b(
                             ]
                         dtype_cast_node = _insert_dtype_cast_after_node(
                             subgraph_a.start_node,
-                            # pyrefly: ignore # unbound-name
+                            # pyrefly: ignore [unbound-name]
                             node_c,
                             prev_node_c,
                             gm_a,
@@ -1052,7 +1052,7 @@ def create_a_shadows_b(
                 if num_non_param_args_node_a == 2:
                     # node_c_second_non_param_arg = node_c.args[1]
                     node_c_second_non_param_arg = get_normalized_nth_input(
-                        # pyrefly: ignore # unbound-name
+                        # pyrefly: ignore [unbound-name]
                         node_c,
                         gm_b,
                         1,
@@ -1063,7 +1063,7 @@ def create_a_shadows_b(
                     subgraph_a,
                     gm_a,
                     gm_b,
-                    # pyrefly: ignore # unbound-name
+                    # pyrefly: ignore [unbound-name]
                     node_c.name + "_shadow_copy_",
                 )
                 env_c[node_a_shadows_c.name] = node_a_shadows_c
@@ -1086,19 +1086,19 @@ def create_a_shadows_b(
                     cur_node = node_a_shadows_c
                     while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined]
                         cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
-                    # pyrefly: ignore # unbound-name
+                    # pyrefly: ignore [unbound-name]
                     if isinstance(input_logger, Node):
-                        # pyrefly: ignore # unbound-name
+                        # pyrefly: ignore [unbound-name]
                         input_logger_mod = getattr(gm_b, input_logger.name)
                         input_logger_mod.ref_node_name = cur_node.name
                     else:
-                        # pyrefly: ignore # unbound-name
+                        # pyrefly: ignore [unbound-name]
                        if not isinstance(input_logger, list):
                            raise AssertionError(
-                                # pyrefly: ignore # unbound-name
+                                # pyrefly: ignore [unbound-name]
                                f"Expected list, got {type(input_logger)}"
                            )
-                        # pyrefly: ignore # unbound-name
+                        # pyrefly: ignore [unbound-name]
                        for input_logger_inner in input_logger:
                            input_logger_mod = getattr(gm_b, input_logger_inner.name)
                            input_logger_mod.ref_node_name = cur_node.name

View File

@@ -419,7 +419,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
         target2,
     ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
         new_connections.append((source, target1))
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         new_connections.append((source, target2))
     for source_to_target in (
@@ -428,7 +428,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
         quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
     ):
         for source, target in source_to_target.items(): # type:ignore[assignment]
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             new_connections.append((source, target))
     #

View File

@@ -94,10 +94,10 @@ class OutputProp:
                 )
             if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
-                # pyrefly: ignore # unbound-name
+                # pyrefly: ignore [unbound-name]
                 node.traced_result = result
-            # pyrefly: ignore # unsupported-operation
+            # pyrefly: ignore [unsupported-operation]
             env[node.name] = result
         return None
@@ -403,10 +403,10 @@ def create_submodule_from_subgraph(
                 cur_name_idx += 1
             setattr(gm, mod_name, new_arg)
             new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator]
-            # pyrefly: ignore # missing-attribute
+            # pyrefly: ignore [missing-attribute]
             cur_args_copy.append(new_arg_placeholder)
         elif isinstance(arg, (float, int, torch.dtype)):
-            # pyrefly: ignore # missing-attribute
+            # pyrefly: ignore [missing-attribute]
             cur_args_copy.append(arg)
         else:
             raise AssertionError(f"arg of type {type(arg)} not handled yet")
@@ -818,7 +818,7 @@ def create_add_loggers_graph(
             model,
             cur_subgraph_idx,
             match_name,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             maybe_subgraph,
             [qconfig_mapping],
             [node_name_to_qconfig],
@@ -879,7 +879,7 @@ def create_add_loggers_graph(
         cur_node_orig = first_node
         cur_node_copy = None
        first_node_copy = None
-        # pyrefly: ignore # bad-assignment
+        # pyrefly: ignore [bad-assignment]
         while cur_node_orig in subgraph_to_use:
             # TODO(future PR): make this support all possible args/kwargs
             if cur_node_orig is first_node:

View File

@@ -496,7 +496,7 @@ def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
     Return:
         float or tuple of floats
     """
-    # pyrefly: ignore # unsupported-operation
+    # pyrefly: ignore [unsupported-operation]
     return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())

View File

@@ -23,17 +23,17 @@ class SparseDLRM(DLRM_Net):
         super().__init__(**args)
     def forward(self, dense_x, lS_o, lS_i):
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         x = self.apply_mlp(dense_x, self.bot_l) # dense features
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         z = self.interact_features(x, ly)
         z = z.to_sparse_coo()
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias)
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         for layer in self.top_l[1:]:
             z = layer(z)

View File

@@ -27,7 +27,7 @@ def run_forward(model, **batch):
         model(X, lS_o, lS_i)
         end = time.time()
         time_taken = end - start
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         time_list.append(time_taken)
     avg_time = np.mean(time_list[1:])
     return avg_time

View File

@@ -72,7 +72,7 @@ class FPGMPruner(BaseStructuredSparsifier):
         dist_matrix = self.dist_fn(t_flatten)
         # more similar with other filter indicates large in the sum of row
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         distance = torch.sum(torch.abs(dist_matrix), 1)
         return distance

View File

@@ -260,7 +260,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
                     module.register_parameter(
                         "_bias", nn.Parameter(module.bias.detach())
                     )
-                # pyrefly: ignore # bad-assignment
+                # pyrefly: ignore [bad-assignment]
                 module.bias = None
                 module.prune_bias = prune_bias

View File

@@ -97,7 +97,7 @@ def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
     if module.bias is not None:
         module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
     elif getattr(module, "_bias", None) is not None:
-        # pyrefly: ignore # bad-assignment
+        # pyrefly: ignore [bad-assignment]
         module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])
     # get pruned biases to propagate to subsequent layer
@@ -127,7 +127,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor:
     linear.out_features = linear.weight.shape[0]
     _remove_bias_handles(linear)
-    # pyrefly: ignore # unbound-name
+    # pyrefly: ignore [unbound-name]
     return mask
@@ -186,7 +186,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
     conv2d.out_channels = conv2d.weight.shape[0]
     _remove_bias_handles(conv2d)
-    # pyrefly: ignore # unbound-name
+    # pyrefly: ignore [unbound-name]
     return mask
@@ -207,7 +207,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
         new_bias = torch.zeros(conv2d_1.bias.shape)
         new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined]
         # adjusted bias that to keep in conv2d_1
-        # pyrefly: ignore # unbound-name
+        # pyrefly: ignore [unbound-name]
         new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
         # pruned biases that are kept instead of propagated
         conv2d_1.bias = nn.Parameter(new_bias)

View File

@@ -170,7 +170,7 @@ class BaseSparsifier(abc.ABC):
             self.make_config_from_model(model)
         # TODO: Remove the configuration by reference ('module')
-        # pyrefly: ignore # not-iterable
+        # pyrefly: ignore [not-iterable]
         for module_config in self.config:
             assert isinstance(module_config, dict), (
                 "config elements should be dicts not modules i.e.:"

View File

@@ -51,7 +51,7 @@ def swap_module(
             new_mod.register_forward_hook(hook_fn)
         # respect device affinity when swapping modules
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
         assert len(devices) <= 1, (
             f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"

View File

@@ -235,7 +235,7 @@ class WeightNormSparsifier(BaseSparsifier):
                 ww = self.norm_fn(getattr(module, tensor_name))
                 tensor_mask = self._make_tensor_mask(
                     data=ww,
-                    # pyrefly: ignore # missing-attribute
+                    # pyrefly: ignore [missing-attribute]
                     input_shape=ww.shape,
                     sparsity_level=sparsity_level,
                     sparse_block_shape=sparse_block_shape,

View File

@@ -25,7 +25,7 @@ from .pt2e.export_utils import (
     _move_exported_model_to_train as move_exported_model_to_train,
 )
-# pyrefly: ignore # deprecated
+# pyrefly: ignore [deprecated]
 from .qconfig import * # noqa: F403
 from .qconfig_mapping import * # noqa: F403
 from .quant_type import * # noqa: F403

View File

@@ -185,9 +185,9 @@ class FakeQuantize(FakeQuantizeBase):
                 dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
                     "dtype", dtype
                 )
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
                 observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
         observer_kwargs["is_dynamic"] = is_dynamic

View File

@@ -1149,7 +1149,7 @@ quantized_decomposed_lib.define(
 class FakeQuantPerChannel(torch.autograd.Function):
     @staticmethod
-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
         if scales.dtype != torch.float32:
             scales = scales.to(torch.float32)
@@ -1172,7 +1172,7 @@ class FakeQuantPerChannel(torch.autograd.Function):
         return out
     @staticmethod
-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     def backward(ctx, gy):
         (mask,) = ctx.saved_tensors
         return gy * mask, None, None, None, None, None

View File

@@ -248,7 +248,7 @@ def calculate_equalization_scale(
 class EqualizationQConfig(
-    # pyrefly: ignore # invalid-inheritance
+    # pyrefly: ignore [invalid-inheritance]
     namedtuple("EqualizationQConfig", ["input_activation", "weight"])
 ):
     """
@@ -463,7 +463,7 @@ def maybe_get_next_equalization_scale(
     In this case, the node given is linear1 and we want to locate the InputEqObs.
     """
     next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
-    # pyrefly: ignore # invalid-argument
+    # pyrefly: ignore [invalid-argument]
     if next_inp_eq_obs:
         if (
             next_inp_eq_obs.equalization_scale.nelement() == 1
@@ -827,7 +827,7 @@ def convert_eq_obs(
             scale_weight_node(
                 node,
                 modules,
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 equalization_scale,
                 maybe_next_equalization_scale,
             )
@@ -836,7 +836,7 @@ def convert_eq_obs(
                 node,
                 model,
                 modules,
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 equalization_scale,
                 maybe_next_equalization_scale,
             )

View File

@@ -223,7 +223,7 @@ class ModelReportVisualizer:
                     feature_val = feature_val.item()
                 # we add to our list of values
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 tensor_table_row.append(feature_val)
             tensor_table.append(tensor_table_row)
@@ -284,7 +284,7 @@ class ModelReportVisualizer:
                     feature_val = feature_val.item()
                 # add value to channel specific row
-                # pyrefly: ignore # bad-argument-type
+                # pyrefly: ignore [bad-argument-type]
                 new_channel_row.append(feature_val)
             # add to table and increment row index counter

View File

@@ -166,7 +166,7 @@ def _create_obs_or_fq_from_qspec(
         }
         edge_or_nodes = quantization_spec.derived_from
         obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
-        # pyrefly: ignore # unsupported-operation
+        # pyrefly: ignore [unsupported-operation]
         kwargs["obs_or_fqs"] = obs_or_fqs
         return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
     elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
@@ -2088,11 +2088,11 @@ def prepare(
     root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
-    # pyrefly: ignore # bad-argument-type
+    # pyrefly: ignore [bad-argument-type]
     _update_qconfig_for_fusion(model, qconfig_mapping)
-    # pyrefly: ignore # bad-argument-type
+    # pyrefly: ignore [bad-argument-type]
     _update_qconfig_for_fusion(model, _equalization_config)
-    # pyrefly: ignore # bad-argument-type
+    # pyrefly: ignore [bad-argument-type]
     flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
     # TODO: support regex as well
     propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
@@ -2100,7 +2100,7 @@ def prepare(
     if is_qat:
         module_to_qat_module = get_module_to_qat_module(backend_config)
         _qat_swap_modules(model, module_to_qat_module)
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         _update_qconfig_for_qat(qconfig_mapping, backend_config)
     # mapping from fully qualified module name to module instance
@@ -2117,7 +2117,7 @@ def prepare(
         model,
         named_modules,
         model.graph,
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         _equalization_config,
         node_name_to_scope,
     )
@@ -2125,7 +2125,7 @@ def prepare(
         model,
         named_modules,
         model.graph,
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         qconfig_mapping,
         node_name_to_scope,
     )
@@ -2187,7 +2187,7 @@ def prepare(
         node_name_to_scope,
         prepare_custom_config,
         equalization_node_name_to_qconfig,
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         qconfig_mapping,
         is_qat,
         observed_node_names,

View File

@@ -721,7 +721,7 @@ def _maybe_get_custom_module_lstm_from_node_arg(
             a = a.args[0][0] # type: ignore[assignment,index]
         else:
             a = a.args[0] # type: ignore[assignment]
-        # pyrefly: ignore # bad-return
+        # pyrefly: ignore [bad-return]
         return a
     all_match_patterns = [

View File

@@ -281,12 +281,12 @@ class UniformQuantizationObserverBase(ObserverBase):
         )
         self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
         if self.has_customized_qrange:
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             validate_qmin_qmax(quant_min, quant_max)
         self.quant_min, self.quant_max = calculate_qmin_qmax(
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             quant_min,
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             quant_max,
             self.has_customized_qrange,
             self.dtype,

View File

@@ -83,7 +83,7 @@ __all__ = [
 ]
-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
 class QConfig(namedtuple("QConfig", ["activation", "weight"])):
     """
     Describes how to quantize a layer or a part of the network by providing
@@ -121,7 +121,7 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
     "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead",
     category=FutureWarning,
 )
-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
 class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
     """
     Describes how to dynamically quantize a layer or a part of the network by providing

View File

@@ -418,7 +418,7 @@ class X86InductorQuantizer(Quantizer):
         # As we use `_need_skip_config` to skip all invalid configurations,
         # we can safely assume that the all existing non-None configurations
         # have the same quantization mode.
-        # pyrefly: ignore # bad-assignment
+        # pyrefly: ignore [bad-assignment]
         for qconfig in (
             list(self.module_name_qconfig.values())
             + list(self.operator_type_qconfig.values())
@@ -818,7 +818,7 @@ class X86InductorQuantizer(Quantizer):
                     )
                     binary_node.meta[QUANT_ANNOTATION_KEY] = (
                         _X86InductorQuantizationAnnotation(
-                            # pyrefly: ignore # bad-argument-type
+                            # pyrefly: ignore [bad-argument-type]
                             input_qspec_map=binary_node_input_qspec_map,
                             _annotated=True,
                         )
@@ -889,7 +889,7 @@ class X86InductorQuantizer(Quantizer):
                    )
                    binary_node.meta[QUANT_ANNOTATION_KEY] = (
                        _X86InductorQuantizationAnnotation(
-                            # pyrefly: ignore # bad-argument-type
+                            # pyrefly: ignore [bad-argument-type]
                            input_qspec_map=binary_node_input_qspec_map,
                            # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
                            output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
@@ -1097,7 +1097,7 @@ class X86InductorQuantizer(Quantizer):
                     quantization_config
                 )
                 binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
-                    # pyrefly: ignore # bad-argument-type
+                    # pyrefly: ignore [bad-argument-type]
                     input_qspec_map=binary_node_input_qspec_map,
                     _annotated=True,
                 )
@@ -1152,7 +1152,7 @@ class X86InductorQuantizer(Quantizer):
                     quantization_config
                 )
                 binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
-                    # pyrefly: ignore # bad-argument-type
+                    # pyrefly: ignore [bad-argument-type]
                     input_qspec_map=binary_node_input_qspec_map,
                     _annotated=True,
                     _is_output_of_quantized_pattern=True,
@@ -1508,7 +1508,7 @@ class X86InductorQuantizer(Quantizer):
         has_unary = unary_op is not None
         seq_partition = [torch.nn.Linear, binary_op]
         if has_unary:
-            # pyrefly: ignore # bad-argument-type
+            # pyrefly: ignore [bad-argument-type]
             seq_partition.append(unary_op)
         fused_partitions = find_sequential_partitions(gm, seq_partition)
         for fused_partition in fused_partitions:

View File

@@ -376,11 +376,11 @@ def _do_annotate_conv_relu(
             input_qspec_map[bias] = get_bias_qspec(quantization_config)
             partition.append(bias)
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         if _is_annotated(partition):
             continue
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         if filter_fn and any(not filter_fn(n) for n in partition):
             continue
@@ -391,7 +391,7 @@ def _do_annotate_conv_relu(
             output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
             _annotated=True,
         )
-        # pyrefly: ignore # bad-argument-type
+        # pyrefly: ignore [bad-argument-type]
         _mark_nodes_as_annotated(partition)
         annotated_partitions.append(partition)
     return annotated_partitions

View File

@@ -39,7 +39,7 @@ class Bernoulli(ExponentialFamily):
         validate_args (bool, optional): whether to validate arguments, None by default
     """
-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
     support = constraints.boolean
     has_enumerate_support = True
@@ -57,12 +57,12 @@ class Bernoulli(ExponentialFamily):
             )
         if probs is not None:
             is_scalar = isinstance(probs, _Number)
-            # pyrefly: ignore # read-only
+            # pyrefly: ignore [read-only]
             (self.probs,) = broadcast_all(probs)
         else:
             assert logits is not None # helps mypy
             is_scalar = isinstance(logits, _Number)
-            # pyrefly: ignore # read-only
+            # pyrefly: ignore [read-only]
             (self.logits,) = broadcast_all(logits)
         self._param = self.probs if probs is not None else self.logits
         if is_scalar:
@@ -140,6 +140,6 @@ class Bernoulli(ExponentialFamily):
     def _natural_params(self) -> tuple[Tensor]:
         return (torch.logit(self.probs),)
-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     def _log_normalizer(self, x):
         return torch.log1p(torch.exp(x))

View File

@ -31,7 +31,7 @@ class Beta(ExponentialFamily):
(often referred to as beta) (often referred to as beta)
""" """
# pyrefly: ignore # bad-override # pyrefly: ignore [bad-override]
arg_constraints = { arg_constraints = {
"concentration1": constraints.positive, "concentration1": constraints.positive,
"concentration0": constraints.positive, "concentration0": constraints.positive,
@ -114,6 +114,6 @@ class Beta(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]: def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration1, self.concentration0) return (self.concentration1, self.concentration0)
# pyrefly: ignore # bad-override # pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y): def _log_normalizer(self, x, y):
return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y) return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)

View File

@ -45,7 +45,7 @@ class Binomial(Distribution):
logits (Tensor): Event log-odds logits (Tensor): Event log-odds
""" """
# pyrefly: ignore # bad-override # pyrefly: ignore [bad-override]
arg_constraints = { arg_constraints = {
"total_count": constraints.nonnegative_integer, "total_count": constraints.nonnegative_integer,
"probs": constraints.unit_interval, "probs": constraints.unit_interval,
@ -67,7 +67,7 @@ class Binomial(Distribution):
if probs is not None: if probs is not None:
( (
self.total_count, self.total_count,
# pyrefly: ignore # read-only # pyrefly: ignore [read-only]
self.probs, self.probs,
) = broadcast_all(total_count, probs) ) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs) self.total_count = self.total_count.type_as(self.probs)
@ -75,7 +75,7 @@ class Binomial(Distribution):
assert logits is not None # helps mypy assert logits is not None # helps mypy
( (
self.total_count, self.total_count,
# pyrefly: ignore # read-only # pyrefly: ignore [read-only]
self.logits, self.logits,
) = broadcast_all(total_count, logits) ) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits) self.total_count = self.total_count.type_as(self.logits)
@ -102,7 +102,7 @@ class Binomial(Distribution):
return self._param.new(*args, **kwargs) return self._param.new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=0) @constraints.dependent_property(is_discrete=True, event_dim=0)
# pyrefly: ignore # bad-override # pyrefly: ignore [bad-override]
def support(self): def support(self):
return constraints.integer_interval(0, self.total_count) return constraints.integer_interval(0, self.total_count)

View File

@ -50,7 +50,7 @@ class Categorical(Distribution):
logits (Tensor): event log probabilities (unnormalized) logits (Tensor): event log probabilities (unnormalized)
""" """
# pyrefly: ignore # bad-override # pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector} arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
has_enumerate_support = True has_enumerate_support = True
@ -67,14 +67,14 @@ class Categorical(Distribution):
if probs is not None: if probs is not None:
if probs.dim() < 1: if probs.dim() < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.") raise ValueError("`probs` parameter must be at least one-dimensional.")
# pyrefly: ignore # read-only # pyrefly: ignore [read-only]
self.probs = probs / probs.sum(-1, keepdim=True) self.probs = probs / probs.sum(-1, keepdim=True)
else: else:
assert logits is not None # helps mypy assert logits is not None # helps mypy
if logits.dim() < 1: if logits.dim() < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.") raise ValueError("`logits` parameter must be at least one-dimensional.")
# Normalize # Normalize
# pyrefly: ignore # read-only # pyrefly: ignore [read-only]
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True) self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
self._param = self.probs if probs is not None else self.logits self._param = self.probs if probs is not None else self.logits
self._num_events = self._param.size()[-1] self._num_events = self._param.size()[-1]
@ -102,7 +102,7 @@ class Categorical(Distribution):
return self._param.new(*args, **kwargs) return self._param.new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=0) @constraints.dependent_property(is_discrete=True, event_dim=0)
# pyrefly: ignore # bad-override # pyrefly: ignore [bad-override]
def support(self): def support(self):
return constraints.integer_interval(0, self._num_events - 1) return constraints.integer_interval(0, self._num_events - 1)

View File

@ -31,7 +31,7 @@ class Cauchy(Distribution):
scale (float or Tensor): half width at half maximum. scale (float or Tensor): half width at half maximum.
""" """
# pyrefly: ignore # bad-override # pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "scale": constraints.positive} arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real support = constraints.real
has_rsample = True has_rsample = True

View File

@@ -47,7 +47,7 @@ class ContinuousBernoulli(ExponentialFamily):
https://arxiv.org/abs/1907.06845
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.unit_interval
_mean_carrier_measure = 0
@@ -66,19 +66,19 @@ class ContinuousBernoulli(ExponentialFamily):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
# validate 'probs' here if necessary as it is later clamped for numerical stability
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
if validate_args is not None:
if not self.arg_constraints["probs"].check(self.probs).all():
raise ValueError("The parameter probs has invalid values")
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.probs = clamp_probs(self.probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@@ -234,7 +234,7 @@ class ContinuousBernoulli(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (self.logits,)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
"""computes the log normalizing constant as a function of the natural parameter"""
out_unst_reg = torch.max(

View File

@@ -22,7 +22,7 @@ def _Dirichlet_backward(x, concentration, grad_output):
class _Dirichlet(Function):
@staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def forward(ctx, concentration):
x = torch._sample_dirichlet(concentration)
ctx.save_for_backward(x, concentration)
@@ -30,7 +30,7 @@ class _Dirichlet(Function):
@staticmethod
@once_differentiable
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
x, concentration = ctx.saved_tensors
return _Dirichlet_backward(x, concentration, grad_output)
@@ -52,7 +52,7 @@ class Dirichlet(ExponentialFamily):
(often referred to as alpha)
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1)
}
@@ -133,6 +133,6 @@ class Dirichlet(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (self.concentration,)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))

View File

@@ -27,7 +27,7 @@ class Exponential(ExponentialFamily):
rate (float or Tensor): rate = 1 / scale of the distribution
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"rate": constraints.positive}
support = constraints.nonnegative
has_rsample = True
@@ -90,6 +90,6 @@ class Exponential(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (-self.rate,)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
return -torch.log(-x)

View File

@@ -29,7 +29,7 @@ class FisherSnedecor(Distribution):
df2 (float or Tensor): degrees of freedom parameter 2
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
support = constraints.positive
has_rsample = True

View File

@@ -34,7 +34,7 @@ class Gamma(ExponentialFamily):
(often referred to as beta), rate = 1 / scale
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {
"concentration": constraints.positive,
"rate": constraints.positive,
@@ -110,7 +110,7 @@ class Gamma(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration - 1, -self.rate)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y):
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())

View File

@@ -35,7 +35,7 @@ class GeneralizedPareto(Distribution):
concentration (float or Tensor): Concentration parameter of the distribution
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {
"loc": constraints.real,
"scale": constraints.positive,
@@ -131,7 +131,7 @@ class GeneralizedPareto(Distribution):
concentration = self.concentration
valid = concentration < 0.5
safe_conc = torch.where(valid, concentration, 0.25)
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
return torch.where(valid, result, nan)
@@ -144,7 +144,7 @@ class GeneralizedPareto(Distribution):
return self.loc
@constraints.dependent_property(is_discrete=False, event_dim=0)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def support(self):
lower = self.loc
upper = torch.where(

View File

@@ -44,7 +44,7 @@ class Geometric(Distribution):
logits (Number, Tensor): the log-odds of sampling `1`.
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.nonnegative_integer
@@ -59,11 +59,11 @@ class Geometric(Distribution):
"Either `probs` or `logits` must be specified, but not both."
)
if probs is not None:
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
probs_or_logits = probs if probs is not None else logits
if isinstance(probs_or_logits, _Number):

View File

@@ -32,7 +32,7 @@ class Gumbel(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.real
def __init__(

View File

@@ -32,10 +32,10 @@ class HalfCauchy(TransformedDistribution):
"""
arg_constraints = {"scale": constraints.positive}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.nonnegative
has_rsample = True
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
base_dist: Cauchy
def __init__(

View File

@@ -32,10 +32,10 @@ class HalfNormal(TransformedDistribution):
"""
arg_constraints = {"scale": constraints.positive}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.nonnegative
has_rsample = True
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
base_dist: Normal
def __init__(

View File

@@ -91,7 +91,7 @@ class Independent(Distribution, Generic[D]):
return self.base_dist.has_enumerate_support
@constraints.dependent_property
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def support(self):
result = self.base_dist.support
if self.reinterpreted_batch_ndims:

View File

@@ -38,10 +38,10 @@ class InverseGamma(TransformedDistribution):
"concentration": constraints.positive,
"rate": constraints.positive,
}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.positive
has_rsample = True
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
base_dist: Gamma
def __init__(

View File

@@ -44,7 +44,7 @@ class Kumaraswamy(TransformedDistribution):
"concentration1": constraints.positive,
"concentration0": constraints.positive,
}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.unit_interval
has_rsample = True
@@ -67,7 +67,7 @@ class Kumaraswamy(TransformedDistribution):
AffineTransform(loc=1.0, scale=-1.0),
PowerTransform(exponent=self.concentration1.reciprocal()),
]
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

View File

@@ -28,7 +28,7 @@ class Laplace(Distribution):
scale (float or Tensor): scale of the distribution
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True

View File

@@ -60,7 +60,7 @@ class LKJCholesky(Distribution):
Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"concentration": constraints.positive}
support = constraints.corr_cholesky

View File

@@ -32,10 +32,10 @@ class LogNormal(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.positive
has_rsample = True
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
base_dist: Normal
def __init__(

View File

@@ -36,10 +36,10 @@ class LogisticNormal(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.simplex
has_rsample = True
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
base_dist: Independent[Normal]
def __init__(

View File

@@ -86,7 +86,7 @@ class LowRankMultivariateNormal(Distribution):
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {
"loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2),

View File

@@ -124,7 +124,7 @@ class MixtureSameFamily(Distribution):
return new
@constraints.dependent_property
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def support(self):
return MixtureSameFamilyConstraint(self._component_distribution.support)

View File

@@ -50,7 +50,7 @@ class Multinomial(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
total_count: int
@@ -93,7 +93,7 @@ class Multinomial(Distribution):
return self._categorical._new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=1)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def support(self):
return constraints.multinomial(self.total_count)

View File

@@ -123,7 +123,7 @@ class MultivariateNormal(Distribution):
the corresponding lower triangular matrices using a Cholesky decomposition.
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {
"loc": constraints.real_vector,
"covariance_matrix": constraints.positive_definite,
@@ -157,7 +157,7 @@ class MultivariateNormal(Distribution):
"with optional leading batch dimensions"
)
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
if covariance_matrix.dim() < 2:
@@ -168,7 +168,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
covariance_matrix.shape[:-2], loc.shape[:-1]
)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
else:
assert precision_matrix is not None # helps mypy
@@ -180,7 +180,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
precision_matrix.shape[:-2], loc.shape[:-1]
)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
self.loc = loc.expand(batch_shape + (-1,))

View File

@@ -33,7 +33,7 @@ class NegativeBinomial(Distribution):
logits (Tensor): Event log-odds for probabilities of success
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {
"total_count": constraints.greater_than_eq(0),
"probs": constraints.half_open_interval(0.0, 1.0),
@@ -55,7 +55,7 @@ class NegativeBinomial(Distribution):
if probs is not None:
(
self.total_count,
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.probs,
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
@@ -63,7 +63,7 @@ class NegativeBinomial(Distribution):
assert logits is not None # helps mypy
(
self.total_count,
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.logits,
) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)

View File

@@ -31,7 +31,7 @@ class Normal(ExponentialFamily):
(often referred to as sigma)
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True
@@ -89,7 +89,7 @@ class Normal(ExponentialFamily):
if self._validate_args:
self._validate_sample(value)
# compute the variance
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
var = self.scale**2
log_scale = (
math.log(self.scale)
@@ -119,6 +119,6 @@ class Normal(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y):
return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)

View File

@@ -42,7 +42,7 @@ class OneHotCategorical(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True

View File

@@ -39,7 +39,7 @@ class Pareto(TransformedDistribution):
self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(

View File

@@ -32,7 +32,7 @@ class Poisson(ExponentialFamily):
rate (Number, Tensor): the rate parameter
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"rate": constraints.nonnegative}
support = constraints.nonnegative_integer
@@ -83,6 +83,6 @@ class Poisson(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (torch.log(self.rate),)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
return torch.exp(x)

View File

@@ -40,7 +40,7 @@ class LogitRelaxedBernoulli(Distribution):
(Jang et al., 2017)
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.real
@@ -58,12 +58,12 @@ class LogitRelaxedBernoulli(Distribution):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@@ -141,10 +141,10 @@ class RelaxedBernoulli(TransformedDistribution):
"""
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.unit_interval
has_rsample = True
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
base_dist: LogitRelaxedBernoulli
def __init__(

View File

@@ -38,7 +38,7 @@ class ExpRelaxedCategorical(Distribution):
(Jang et al., 2017)
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = (
constraints.real_vector
@@ -128,10 +128,10 @@ class RelaxedOneHotCategorical(TransformedDistribution):
"""
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.simplex
has_rsample = True
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
base_dist: ExpRelaxedCategorical
def __init__(

View File

@@ -31,7 +31,7 @@ class StudentT(Distribution):
scale (float or Tensor): scale of the distribution
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {
"df": constraints.positive,
"loc": constraints.real,

View File

@@ -123,7 +123,7 @@ class TransformedDistribution(Distribution):
return new
@constraints.dependent_property(is_discrete=False)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def support(self):
if not self.transforms:
return self.base_dist.support

View File

@@ -226,13 +226,13 @@ class _InverseTransform(Transform):
self._inv: Transform = transform # type: ignore[assignment]
@constraints.dependent_property(is_discrete=False)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def domain(self):
assert self._inv is not None
return self._inv.codomain
@constraints.dependent_property(is_discrete=False)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def codomain(self):
assert self._inv is not None
return self._inv.domain
@@ -302,7 +302,7 @@ class ComposeTransform(Transform):
return self.parts == other.parts
@constraints.dependent_property(is_discrete=False)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def domain(self):
if not self.parts:
return constraints.real
@@ -318,7 +318,7 @@ class ComposeTransform(Transform):
return domain
@constraints.dependent_property(is_discrete=False)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def codomain(self):
if not self.parts:
return constraints.real
@@ -438,14 +438,14 @@ class IndependentTransform(Transform):
)
@constraints.dependent_property(is_discrete=False)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def domain(self):
return constraints.independent(
self.base_transform.domain, self.reinterpreted_batch_ndims
)
@constraints.dependent_property(is_discrete=False)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.independent(
self.base_transform.codomain, self.reinterpreted_batch_ndims
@@ -513,12 +513,12 @@ class ReshapeTransform(Transform):
super().__init__(cache_size=cache_size)
@constraints.dependent_property
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def domain(self):
return constraints.independent(constraints.real, len(self.in_shape))
@constraints.dependent_property
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.independent(constraints.real, len(self.out_shape))
@@ -772,14 +772,14 @@ class AffineTransform(Transform):
return self._event_dim
@constraints.dependent_property(is_discrete=False)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def domain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
@constraints.dependent_property(is_discrete=False)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def codomain(self):
if self.event_dim == 0:
return constraints.real
@@ -877,7 +877,7 @@ class CorrCholeskyTransform(Transform):
# apply stick-breaking on the squared values
# Note that y = sign(r) * sqrt(z * z1m_cumprod)
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
z = r**2
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
# Diagonal elements must be 1.
@@ -1166,14 +1166,14 @@ class CatTransform(Transform):
return all(t.bijective for t in self.transforms)
@constraints.dependent_property
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def domain(self):
return constraints.cat(
[t.domain for t in self.transforms], self.dim, self.lengths
)
@constraints.dependent_property
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.cat(
[t.codomain for t in self.transforms], self.dim, self.lengths
@@ -1246,12 +1246,12 @@ class StackTransform(Transform):
return all(t.bijective for t in self.transforms)
@constraints.dependent_property
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def domain(self):
return constraints.stack([t.domain for t in self.transforms], self.dim)
@constraints.dependent_property
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.stack([t.codomain for t in self.transforms], self.dim)

View File

@@ -79,7 +79,7 @@ class Uniform(Distribution):
return new
@constraints.dependent_property(is_discrete=False, event_dim=0)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def support(self):
return constraints.interval(self.low, self.high)

View File

@@ -92,7 +92,7 @@ def _log_modified_bessel_fn(x, order=0):
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
while not done.all():
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
u1, u2, u3 = u.unbind()
@@ -101,7 +101,7 @@ def _rejection_sample(loc, concentration, proposal_r, x):
c = concentration * (proposal_r - f)
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
if accept.any():
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
done = done | accept
return (x + math.pi + loc) % (2 * math.pi) - math.pi
@@ -125,7 +125,7 @@ class VonMises(Distribution):
:param torch.Tensor concentration: concentration parameter
"""
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
support = constraints.real
has_rsample = False
@@ -163,10 +163,10 @@ class VonMises(Distribution):
@lazy_property
def _proposal_r(self) -> Tensor:
kappa = self._concentration
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
tau = 1 + (1 + 4 * kappa**2).sqrt()
rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
_proposal_r = (1 + rho**2) / (2 * rho)
# second order Taylor expansion around 0 for small kappa
_proposal_r_taylor = 1 / kappa + kappa

View File

@@ -35,7 +35,7 @@ class Weibull(TransformedDistribution):
"scale": constraints.positive,
"concentration": constraints.positive,
}
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
support = constraints.positive
def __init__(
@@ -53,7 +53,7 @@ class Weibull(TransformedDistribution):
PowerTransform(exponent=self.concentration_reciprocal),
AffineTransform(loc=0, scale=self.scale),
]
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

View File

@@ -116,13 +116,13 @@ class Wishart(ExponentialFamily):
)
if scale_tril is not None:
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.scale_tril = param.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.covariance_matrix = param.expand(batch_shape + (-1, -1))
elif precision_matrix is not None:
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
self.precision_matrix = param.expand(batch_shape + (-1, -1))
if self.df.lt(event_shape[-1]).any():
@@ -339,7 +339,7 @@ class Wishart(ExponentialFamily):
p = self._event_shape[-1] # has singleton shape
return -self.precision_matrix / 2, (nu - p - 1) / 2
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y):
p = self._event_shape[-1]
return (y + (p + 1) / 2) * (