Fix pyrefly ignore syntax in distributions and ao (#166248)
Ensures existing pyrefly ignores only ignore the intended error code.

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166248
Approved by: https://github.com/oulgen
This commit is contained in:
parent a2b6afeac5
commit 154e4d36e9
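Every hunk below applies the same mechanical change to a pyrefly suppression comment: the trailing "# error-code" spelling becomes a bracketed "[error-code]" spelling so that, per the commit message, only the intended error code is ignored. A minimal sketch of the two spellings, assuming a hypothetical scale() function that a type checker would flag for a bad argument type (only the comment syntax mirrors this diff):

def scale(value: int) -> int:
    return value * 2

# Old spelling: according to the commit message, this form did not reliably limit
# the ignore to the named error code.
# pyrefly: ignore # bad-argument-type
old_style = scale("3")

# New spelling used throughout this commit: the bracketed code scopes the
# suppression to bad-argument-type on the following line.
# pyrefly: ignore [bad-argument-type]
new_style = scale("3")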
@@ -620,7 +620,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
 dilation=dilation,
 groups=groups,
 bias=bias,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 padding_mode=padding_mode,
 qconfig=qconfig,
 )
@@ -821,7 +821,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
 dilation=dilation,
 groups=groups,
 bias=bias,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 padding_mode=padding_mode,
 qconfig=qconfig,
 )
@@ -1023,7 +1023,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
 dilation=dilation,
 groups=groups,
 bias=bias,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 padding_mode=padding_mode,
 qconfig=qconfig,
 )
@@ -36,7 +36,7 @@ class LinearReLU(nnqat.Linear, _FusedModule):
 torch.Size([128, 30])
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 _FLOAT_MODULE = nni.LinearReLU

 def __init__(
@@ -30,7 +30,7 @@ class LinearReLU(nnqd.Linear):
 torch.Size([128, 30])
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 _FLOAT_MODULE = nni.LinearReLU

 def __init__(
@@ -54,7 +54,7 @@ class ConvReLU1d(nnq.Conv1d):
 dilation=dilation,
 groups=groups,
 bias=bias,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 padding_mode=padding_mode,
 device=device,
 dtype=dtype,
@@ -114,7 +114,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
 assert hasattr(cls, "_FLOAT_RELU_MODULE")
 relu = cls._FLOAT_RELU_MODULE()
 modules.append(relu)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 fused = cls._FLOAT_MODULE(*modules)
 fused.train(self.training)
 return fused
@@ -50,7 +50,7 @@ class Embedding(nn.Embedding):
 scale_grad_by_freq,
 sparse,
 _weight,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 **factory_kwargs,
 )
 assert qconfig, "qconfig must be provided for QAT module"
@@ -170,11 +170,11 @@ class MultiheadAttention(nn.MultiheadAttention):
 observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
 observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
 if other.in_proj_bias is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 observed.linear_Q.bias = None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 observed.linear_K.bias = None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 observed.linear_V.bias = None
 else:
 observed.linear_Q.bias = nn.Parameter(
@@ -237,7 +237,7 @@ class MultiheadAttention(nn.MultiheadAttention):
 _end = _start + fp.embed_dim
 fp.in_proj_weight[_start:_end, :] = wQ
 if fp.in_proj_bias is not None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 assert all(bQ == 0)
 fp.in_proj_bias[_start:_end] = bQ

@@ -245,14 +245,14 @@ class MultiheadAttention(nn.MultiheadAttention):
 _end = _start + fp.embed_dim
 fp.in_proj_weight[_start:_end, :] = wK
 if fp.in_proj_bias is not None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 assert all(bK == 0)
 fp.in_proj_bias[_start:_end] = bK

 _start = _end
 fp.in_proj_weight[_start:, :] = wV
 if fp.in_proj_bias is not None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 assert all(bV == 0)
 fp.in_proj_bias[_start:] = bV
 else:
@@ -260,11 +260,11 @@ class MultiheadAttention(nn.MultiheadAttention):
 fp.k_proj_weight = nn.Parameter(wK)
 fp.v_proj_weight = nn.Parameter(wV)
 if fp.in_proj_bias is None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.linear_Q.bias = None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.linear_K.bias = None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.linear_V.bias = None
 else:
 fp.in_proj_bias[0 : fp.embed_dim] = bQ
@@ -472,7 +472,7 @@ class MultiheadAttention(nn.MultiheadAttention):
 assert static_v.size(2) == head_dim
 v = static_v

-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 src_len = k.size(1)

 if key_padding_mask is not None:
@@ -481,35 +481,35 @@ class MultiheadAttention(nn.MultiheadAttention):

 if self.add_zero_attn:
 src_len += 1
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if k.is_quantized:
 k_zeros = torch.quantize_per_tensor(
 k_zeros,
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 k.q_scale(),
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 k.q_zero_point(),
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 k.dtype,
 )
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 k = torch.cat([k, k_zeros], dim=1)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if v.is_quantized:
 v_zeros = torch.quantize_per_tensor(
 v_zeros,
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 v.q_scale(),
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 v.q_zero_point(),
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 v.dtype,
 )
-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 v = torch.cat([v, v_zeros], dim=1)

 if attn_mask is not None:
@@ -376,7 +376,7 @@ class _LSTMLayer(torch.nn.Module):
 bidirectional,
 split_gates=split_gates,
 )
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 layer.qconfig = getattr(other, "qconfig", qconfig)
 wi = getattr(other, f"weight_ih_l{layer_idx}")
 wh = getattr(other, f"weight_hh_l{layer_idx}")
@@ -455,7 +455,7 @@ class LSTM(torch.nn.Module):

 if (
 not isinstance(dropout, numbers.Number)
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 or not 0 <= dropout <= 1
 or isinstance(dropout, bool)
 ):
@@ -464,7 +464,7 @@ class LSTM(torch.nn.Module):
 "representing the probability of an element being "
 "zeroed"
 )
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 if dropout > 0:
 warnings.warn(
 "dropout option for quantizable LSTM is ignored. "
@@ -578,7 +578,7 @@ class LSTM(torch.nn.Module):
 other.bidirectional,
 split_gates=split_gates,
 )
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 observed.qconfig = getattr(other, "qconfig", qconfig)
 for idx in range(other.num_layers):
 observed.layers[idx] = _LSTMLayer.from_float(
@@ -74,7 +74,7 @@ class Conv1d(nnq.Conv1d):
 factory_kwargs = {"device": device, "dtype": dtype}
 kernel_size = _single(kernel_size)
 stride = _single(stride)
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 padding = padding if isinstance(padding, str) else _single(padding)
 dilation = _single(dilation)

@@ -119,9 +119,9 @@ class Linear(nnq.Linear):
 assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
 if type(mod) is nni.LinearReLU:
 mod = mod[0]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if mod.qconfig is not None and mod.qconfig.weight is not None:
-# pyrefly: ignore # not-callable
+# pyrefly: ignore [not-callable]
 weight_observer = mod.qconfig.weight()
 else:
 # We have the circular import issues if we import the qconfig in the beginning of this file:
@@ -145,7 +145,7 @@ class Linear(nnq.Linear):
 "Unsupported dtype specified for dynamic quantized Linear!"
 )
 qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 qlinear.set_weight_bias(qweight, mod.bias)
 return qlinear

@@ -522,7 +522,7 @@ class LSTM(RNNBase):
 >>> output, (hn, cn) = rnn(input, (h0, c0))
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 _FLOAT_MODULE = nn.LSTM

 __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
@@ -808,7 +808,7 @@ class GRU(RNNBase):
 >>> output, hn = rnn(input, h0)
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 _FLOAT_MODULE = nn.GRU

 __overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
@@ -67,9 +67,9 @@ class Hardswish(torch.nn.Hardswish):
 def __init__(self, scale, zero_point, device=None, dtype=None):
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__()
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

 def forward(self, input):
@@ -140,9 +140,9 @@ class LeakyReLU(torch.nn.LeakyReLU):
 ) -> None:
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__(negative_slope, inplace)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

 def forward(self, input):
@@ -230,7 +230,7 @@ class Softmax(torch.nn.Softmax):


 class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 _FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention

 def _get_name(self):
@@ -12,9 +12,9 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
 ) -> None:
 factory_kwargs = {"device": device, "dtype": dtype}
 super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))

 @staticmethod
@@ -408,7 +408,7 @@ class Conv1d(_ConvNd):
 factory_kwargs = {"device": device, "dtype": dtype}
 kernel_size = _single(kernel_size)
 stride = _single(stride)
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 padding = padding if isinstance(padding, str) else _single(padding)
 dilation = _single(dilation)

@@ -310,7 +310,7 @@ class Linear(WeightedQuantizedModule):
 # the type mismatch in assignment. Also, mypy has an issue with
 # iterables not being implemented, so we are ignoring those too.
 if not isinstance(cls._FLOAT_MODULE, Iterable):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 cls._FLOAT_MODULE = [cls._FLOAT_MODULE]
 supported_modules = ", ".join(
 [float_mod.__name__ for float_mod in cls._FLOAT_MODULE]
@@ -37,14 +37,14 @@ class LayerNorm(torch.nn.LayerNorm):
 normalized_shape,
 eps=eps,
 elementwise_affine=elementwise_affine,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 **factory_kwargs,
 )
 self.weight = weight
 self.bias = bias
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

 def forward(self, input):
@@ -116,9 +116,9 @@ class GroupNorm(torch.nn.GroupNorm):
 super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
 self.weight = weight
 self.bias = bias
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

 def forward(self, input):
@@ -180,9 +180,9 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
 )
 self.weight = weight
 self.bias = bias
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

 def forward(self, input):
@@ -249,9 +249,9 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
 )
 self.weight = weight
 self.bias = bias
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

 def forward(self, input):
@@ -318,9 +318,9 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
 )
 self.weight = weight
 self.bias = bias
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))

 def forward(self, input):
@@ -141,7 +141,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
 dilation,
 groups,
 bias,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 padding_mode,
 device,
 dtype,
@@ -206,7 +206,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
 dilation,
 groups,
 bias,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 padding_mode,
 device,
 dtype,
@@ -383,7 +383,7 @@ class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
 groups,
 bias,
 dilation,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 padding_mode,
 device,
 dtype,
@@ -465,7 +465,7 @@ class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
 groups,
 bias,
 dilation,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 padding_mode,
 device,
 dtype,
@@ -664,7 +664,7 @@ class LSTM(RNNBase):
 if isinstance(orig_input, PackedSequence):
 output_packed = PackedSequence(
 output,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 batch_sizes,
 sorted_indices,
 unsorted_indices,
@@ -828,7 +828,7 @@ class GRU(RNNBase):
 if isinstance(orig_input, PackedSequence):
 output_packed = PackedSequence(
 output,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 batch_sizes,
 sorted_indices,
 unsorted_indices,
@@ -42,7 +42,7 @@ class Embedding(nn.Embedding, ReferenceQuantizedModule):
 scale_grad_by_freq,
 sparse,
 _weight,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 device,
 dtype,
 )
@@ -18,7 +18,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 "scale": 1.0,
 "zero_point": 0,
 }
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
 self.weight_dtype = weight_qparams["dtype"]
 assert self.weight_qscheme in [
@@ -81,14 +81,14 @@ class ReferenceQuantizedModule(torch.nn.Module):
 self.register_buffer(
 "weight_axis", torch.tensor(0, dtype=torch.int, device=device)
 )
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
 # store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
 # for capturing `.item` operations
 self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min")
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max")

 def get_weight(self):
@@ -105,7 +105,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_and_dequantize_weight_decomposed(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,
@@ -117,7 +117,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_and_dequantize_weight(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,
@@ -133,7 +133,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_weight_decomposed(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,
@@ -145,7 +145,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
 return _quantize_weight(
 self.weight, # type: ignore[arg-type]
 self.weight_qscheme,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 self.weight_dtype,
 self.weight_scale,
 self.weight_zero_point,
@@ -151,9 +151,9 @@ class Linear(torch.nn.Module):
 assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
 if type(mod) is nni.LinearReLU:
 mod = mod[0]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 if mod.qconfig is not None and mod.qconfig.weight is not None:
-# pyrefly: ignore # not-callable
+# pyrefly: ignore [not-callable]
 weight_observer = mod.qconfig.weight()
 else:
 # We have the circular import issues if we import the qconfig in the beginning of this file:
@@ -187,6 +187,6 @@ class Linear(torch.nn.Module):
 col_block_size,
 dtype=dtype,
 )
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
 return qlinear
@@ -959,7 +959,7 @@ def create_a_shadows_b(
 if should_log_inputs:
 # skip the input logger when inserting a dtype cast
 if isinstance(prev_node_c, Node):
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
 elif isinstance(prev_node_c, list):
 prev_node_c = [
@@ -968,7 +968,7 @@ def create_a_shadows_b(
 ]
 dtype_cast_node = _insert_dtype_cast_after_node(
 subgraph_a.start_node,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 node_c,
 prev_node_c,
 gm_a,
@@ -1052,7 +1052,7 @@ def create_a_shadows_b(
 if num_non_param_args_node_a == 2:
 # node_c_second_non_param_arg = node_c.args[1]
 node_c_second_non_param_arg = get_normalized_nth_input(
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 node_c,
 gm_b,
 1,
@@ -1063,7 +1063,7 @@ def create_a_shadows_b(
 subgraph_a,
 gm_a,
 gm_b,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 node_c.name + "_shadow_copy_",
 )
 env_c[node_a_shadows_c.name] = node_a_shadows_c
@@ -1086,19 +1086,19 @@ def create_a_shadows_b(
 cur_node = node_a_shadows_c
 while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined]
 cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if isinstance(input_logger, Node):
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 input_logger_mod = getattr(gm_b, input_logger.name)
 input_logger_mod.ref_node_name = cur_node.name
 else:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if not isinstance(input_logger, list):
 raise AssertionError(
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 f"Expected list, got {type(input_logger)}"
 )
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 for input_logger_inner in input_logger:
 input_logger_mod = getattr(gm_b, input_logger_inner.name)
 input_logger_mod.ref_node_name = cur_node.name
@@ -419,7 +419,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
 target2,
 ) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
 new_connections.append((source, target1))
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 new_connections.append((source, target2))

 for source_to_target in (
@@ -428,7 +428,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
 quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
 ):
 for source, target in source_to_target.items(): # type:ignore[assignment]
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 new_connections.append((source, target))

 #
@@ -94,10 +94,10 @@ class OutputProp:
 )

 if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 node.traced_result = result

-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 env[node.name] = result

 return None
@@ -403,10 +403,10 @@ def create_submodule_from_subgraph(
 cur_name_idx += 1
 setattr(gm, mod_name, new_arg)
 new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 cur_args_copy.append(new_arg_placeholder)
 elif isinstance(arg, (float, int, torch.dtype)):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 cur_args_copy.append(arg)
 else:
 raise AssertionError(f"arg of type {type(arg)} not handled yet")
@@ -818,7 +818,7 @@ def create_add_loggers_graph(
 model,
 cur_subgraph_idx,
 match_name,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 maybe_subgraph,
 [qconfig_mapping],
 [node_name_to_qconfig],
@@ -879,7 +879,7 @@ def create_add_loggers_graph(
 cur_node_orig = first_node
 cur_node_copy = None
 first_node_copy = None
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 while cur_node_orig in subgraph_to_use:
 # TODO(future PR): make this support all possible args/kwargs
 if cur_node_orig is first_node:
@@ -496,7 +496,7 @@ def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 Return:
 float or tuple of floats
 """
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())


@@ -23,17 +23,17 @@ class SparseDLRM(DLRM_Net):
 super().__init__(**args)

 def forward(self, dense_x, lS_o, lS_i):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 x = self.apply_mlp(dense_x, self.bot_l) # dense features
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 z = self.interact_features(x, ly)

 z = z.to_sparse_coo()
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 for layer in self.top_l[1:]:
 z = layer(z)

@@ -27,7 +27,7 @@ def run_forward(model, **batch):
 model(X, lS_o, lS_i)
 end = time.time()
 time_taken = end - start
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 time_list.append(time_taken)
 avg_time = np.mean(time_list[1:])
 return avg_time
@@ -72,7 +72,7 @@ class FPGMPruner(BaseStructuredSparsifier):
 dist_matrix = self.dist_fn(t_flatten)

 # more similar with other filter indicates large in the sum of row
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 distance = torch.sum(torch.abs(dist_matrix), 1)

 return distance
@@ -260,7 +260,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
 module.register_parameter(
 "_bias", nn.Parameter(module.bias.detach())
 )
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 module.bias = None
 module.prune_bias = prune_bias

@@ -97,7 +97,7 @@ def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
 if module.bias is not None:
 module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
 elif getattr(module, "_bias", None) is not None:
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])

 # get pruned biases to propagate to subsequent layer
@@ -127,7 +127,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor:
 linear.out_features = linear.weight.shape[0]
 _remove_bias_handles(linear)

-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return mask


@@ -186,7 +186,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
 conv2d.out_channels = conv2d.weight.shape[0]

 _remove_bias_handles(conv2d)
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return mask


@@ -207,7 +207,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
 new_bias = torch.zeros(conv2d_1.bias.shape)
 new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined]
 # adjusted bias that to keep in conv2d_1
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
 # pruned biases that are kept instead of propagated
 conv2d_1.bias = nn.Parameter(new_bias)
@@ -170,7 +170,7 @@ class BaseSparsifier(abc.ABC):
 self.make_config_from_model(model)

 # TODO: Remove the configuration by reference ('module')
-# pyrefly: ignore # not-iterable
+# pyrefly: ignore [not-iterable]
 for module_config in self.config:
 assert isinstance(module_config, dict), (
 "config elements should be dicts not modules i.e.:"
@@ -51,7 +51,7 @@ def swap_module(
 new_mod.register_forward_hook(hook_fn)

 # respect device affinity when swapping modules
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
 assert len(devices) <= 1, (
 f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
@@ -235,7 +235,7 @@ class WeightNormSparsifier(BaseSparsifier):
 ww = self.norm_fn(getattr(module, tensor_name))
 tensor_mask = self._make_tensor_mask(
 data=ww,
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 input_shape=ww.shape,
 sparsity_level=sparsity_level,
 sparse_block_shape=sparse_block_shape,
@@ -25,7 +25,7 @@ from .pt2e.export_utils import (
 _move_exported_model_to_train as move_exported_model_to_train,
 )

-# pyrefly: ignore # deprecated
+# pyrefly: ignore [deprecated]
 from .qconfig import * # noqa: F403
 from .qconfig_mapping import * # noqa: F403
 from .quant_type import * # noqa: F403
@@ -185,9 +185,9 @@ class FakeQuantize(FakeQuantizeBase):
 dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
 "dtype", dtype
 )
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
 observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
 observer_kwargs["is_dynamic"] = is_dynamic
@@ -1149,7 +1149,7 @@ quantized_decomposed_lib.define(

 class FakeQuantPerChannel(torch.autograd.Function):
 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
 if scales.dtype != torch.float32:
 scales = scales.to(torch.float32)
@@ -1172,7 +1172,7 @@ class FakeQuantPerChannel(torch.autograd.Function):
 return out

 @staticmethod
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def backward(ctx, gy):
 (mask,) = ctx.saved_tensors
 return gy * mask, None, None, None, None, None
@@ -248,7 +248,7 @@ def calculate_equalization_scale(


 class EqualizationQConfig(
-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
 namedtuple("EqualizationQConfig", ["input_activation", "weight"])
 ):
 """
@@ -463,7 +463,7 @@ def maybe_get_next_equalization_scale(
 In this case, the node given is linear1 and we want to locate the InputEqObs.
 """
 next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
-# pyrefly: ignore # invalid-argument
+# pyrefly: ignore [invalid-argument]
 if next_inp_eq_obs:
 if (
 next_inp_eq_obs.equalization_scale.nelement() == 1
@@ -827,7 +827,7 @@ def convert_eq_obs(
 scale_weight_node(
 node,
 modules,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 equalization_scale,
 maybe_next_equalization_scale,
 )
@@ -836,7 +836,7 @@ def convert_eq_obs(
 node,
 model,
 modules,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 equalization_scale,
 maybe_next_equalization_scale,
 )
@@ -223,7 +223,7 @@ class ModelReportVisualizer:
 feature_val = feature_val.item()

 # we add to our list of values
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 tensor_table_row.append(feature_val)

 tensor_table.append(tensor_table_row)
@@ -284,7 +284,7 @@ class ModelReportVisualizer:
 feature_val = feature_val.item()

 # add value to channel specific row
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 new_channel_row.append(feature_val)

 # add to table and increment row index counter
@@ -166,7 +166,7 @@ def _create_obs_or_fq_from_qspec(
 }
 edge_or_nodes = quantization_spec.derived_from
 obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 kwargs["obs_or_fqs"] = obs_or_fqs
 return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
 elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
@@ -2088,11 +2088,11 @@ def prepare(

 root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)

-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 _update_qconfig_for_fusion(model, qconfig_mapping)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 _update_qconfig_for_fusion(model, _equalization_config)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
 # TODO: support regex as well
 propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
@@ -2100,7 +2100,7 @@ def prepare(
 if is_qat:
 module_to_qat_module = get_module_to_qat_module(backend_config)
 _qat_swap_modules(model, module_to_qat_module)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 _update_qconfig_for_qat(qconfig_mapping, backend_config)

 # mapping from fully qualified module name to module instance
@@ -2117,7 +2117,7 @@ def prepare(
 model,
 named_modules,
 model.graph,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 _equalization_config,
 node_name_to_scope,
 )
@@ -2125,7 +2125,7 @@ def prepare(
 model,
 named_modules,
 model.graph,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 qconfig_mapping,
 node_name_to_scope,
 )
@@ -2187,7 +2187,7 @@ def prepare(
 node_name_to_scope,
 prepare_custom_config,
 equalization_node_name_to_qconfig,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 qconfig_mapping,
 is_qat,
 observed_node_names,
@@ -721,7 +721,7 @@ def _maybe_get_custom_module_lstm_from_node_arg(
 a = a.args[0][0] # type: ignore[assignment,index]
 else:
 a = a.args[0] # type: ignore[assignment]
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return a

 all_match_patterns = [
@@ -281,12 +281,12 @@ class UniformQuantizationObserverBase(ObserverBase):
 )
 self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
 if self.has_customized_qrange:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 validate_qmin_qmax(quant_min, quant_max)
 self.quant_min, self.quant_max = calculate_qmin_qmax(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 quant_min,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 quant_max,
 self.has_customized_qrange,
 self.dtype,
@@ -83,7 +83,7 @@ __all__ = [
 ]


-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
 class QConfig(namedtuple("QConfig", ["activation", "weight"])):
 """
 Describes how to quantize a layer or a part of the network by providing
@@ -121,7 +121,7 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
 "`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead",
 category=FutureWarning,
 )
-# pyrefly: ignore # invalid-inheritance
+# pyrefly: ignore [invalid-inheritance]
 class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
 """
 Describes how to dynamically quantize a layer or a part of the network by providing
@@ -418,7 +418,7 @@ class X86InductorQuantizer(Quantizer):
 # As we use `_need_skip_config` to skip all invalid configurations,
 # we can safely assume that the all existing non-None configurations
 # have the same quantization mode.
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 for qconfig in (
 list(self.module_name_qconfig.values())
 + list(self.operator_type_qconfig.values())
@@ -818,7 +818,7 @@ class X86InductorQuantizer(Quantizer):
 )
 binary_node.meta[QUANT_ANNOTATION_KEY] = (
 _X86InductorQuantizationAnnotation(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 input_qspec_map=binary_node_input_qspec_map,
 _annotated=True,
 )
@@ -889,7 +889,7 @@ class X86InductorQuantizer(Quantizer):
 )
 binary_node.meta[QUANT_ANNOTATION_KEY] = (
 _X86InductorQuantizationAnnotation(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 input_qspec_map=binary_node_input_qspec_map,
 # TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
@@ -1097,7 +1097,7 @@ class X86InductorQuantizer(Quantizer):
 quantization_config
 )
 binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 input_qspec_map=binary_node_input_qspec_map,
 _annotated=True,
 )
@@ -1152,7 +1152,7 @@ class X86InductorQuantizer(Quantizer):
 quantization_config
 )
 binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 input_qspec_map=binary_node_input_qspec_map,
 _annotated=True,
 _is_output_of_quantized_pattern=True,
@@ -1508,7 +1508,7 @@ class X86InductorQuantizer(Quantizer):
 has_unary = unary_op is not None
 seq_partition = [torch.nn.Linear, binary_op]
 if has_unary:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 seq_partition.append(unary_op)
 fused_partitions = find_sequential_partitions(gm, seq_partition)
 for fused_partition in fused_partitions:
@@ -376,11 +376,11 @@ def _do_annotate_conv_relu(
 input_qspec_map[bias] = get_bias_qspec(quantization_config)
 partition.append(bias)

-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 if _is_annotated(partition):
 continue

-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 if filter_fn and any(not filter_fn(n) for n in partition):
 continue

@@ -391,7 +391,7 @@ def _do_annotate_conv_relu(
 output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
 _annotated=True,
 )
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 _mark_nodes_as_annotated(partition)
 annotated_partitions.append(partition)
 return annotated_partitions
@@ -39,7 +39,7 @@ class Bernoulli(ExponentialFamily):
 validate_args (bool, optional): whether to validate arguments, None by default
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
 support = constraints.boolean
 has_enumerate_support = True
@@ -57,12 +57,12 @@ class Bernoulli(ExponentialFamily):
 )
 if probs is not None:
 is_scalar = isinstance(probs, _Number)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
 (self.probs,) = broadcast_all(probs)
 else:
 assert logits is not None # helps mypy
 is_scalar = isinstance(logits, _Number)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
 (self.logits,) = broadcast_all(logits)
 self._param = self.probs if probs is not None else self.logits
 if is_scalar:
@@ -140,6 +140,6 @@ class Bernoulli(ExponentialFamily):
 def _natural_params(self) -> tuple[Tensor]:
 return (torch.logit(self.probs),)

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def _log_normalizer(self, x):
 return torch.log1p(torch.exp(x))
@@ -31,7 +31,7 @@ class Beta(ExponentialFamily):
 (often referred to as beta)
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 arg_constraints = {
 "concentration1": constraints.positive,
 "concentration0": constraints.positive,
@@ -114,6 +114,6 @@ class Beta(ExponentialFamily):
 def _natural_params(self) -> tuple[Tensor, Tensor]:
 return (self.concentration1, self.concentration0)

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def _log_normalizer(self, x, y):
 return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
@@ -45,7 +45,7 @@ class Binomial(Distribution):
 logits (Tensor): Event log-odds
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 arg_constraints = {
 "total_count": constraints.nonnegative_integer,
 "probs": constraints.unit_interval,
@@ -67,7 +67,7 @@ class Binomial(Distribution):
 if probs is not None:
 (
 self.total_count,
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
 self.probs,
 ) = broadcast_all(total_count, probs)
 self.total_count = self.total_count.type_as(self.probs)
@@ -75,7 +75,7 @@ class Binomial(Distribution):
 assert logits is not None # helps mypy
 (
 self.total_count,
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
 self.logits,
 ) = broadcast_all(total_count, logits)
 self.total_count = self.total_count.type_as(self.logits)
@@ -102,7 +102,7 @@ class Binomial(Distribution):
 return self._param.new(*args, **kwargs)

 @constraints.dependent_property(is_discrete=True, event_dim=0)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def support(self):
 return constraints.integer_interval(0, self.total_count)

@@ -50,7 +50,7 @@ class Categorical(Distribution):
 logits (Tensor): event log probabilities (unnormalized)
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
 has_enumerate_support = True

@@ -67,14 +67,14 @@ class Categorical(Distribution):
 if probs is not None:
 if probs.dim() < 1:
 raise ValueError("`probs` parameter must be at least one-dimensional.")
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
 self.probs = probs / probs.sum(-1, keepdim=True)
 else:
 assert logits is not None # helps mypy
 if logits.dim() < 1:
 raise ValueError("`logits` parameter must be at least one-dimensional.")
 # Normalize
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
 self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
 self._param = self.probs if probs is not None else self.logits
 self._num_events = self._param.size()[-1]
@@ -102,7 +102,7 @@ class Categorical(Distribution):
 return self._param.new(*args, **kwargs)

 @constraints.dependent_property(is_discrete=True, event_dim=0)
-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def support(self):
 return constraints.integer_interval(0, self._num_events - 1)

@@ -31,7 +31,7 @@ class Cauchy(Distribution):
 scale (float or Tensor): half width at half maximum.
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
 support = constraints.real
 has_rsample = True
@@ -47,7 +47,7 @@ class ContinuousBernoulli(ExponentialFamily):
 https://arxiv.org/abs/1907.06845
 """

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
 support = constraints.unit_interval
 _mean_carrier_measure = 0
@@ -66,19 +66,19 @@ class ContinuousBernoulli(ExponentialFamily):
 )
 if probs is not None:
 is_scalar = isinstance(probs, _Number)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
 (self.probs,) = broadcast_all(probs)
 # validate 'probs' here if necessary as it is later clamped for numerical stability
 # close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
 if validate_args is not None:
 if not self.arg_constraints["probs"].check(self.probs).all():
 raise ValueError("The parameter probs has invalid values")
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
 self.probs = clamp_probs(self.probs)
 else:
 assert logits is not None # helps mypy
 is_scalar = isinstance(logits, _Number)
-# pyrefly: ignore # read-only
+# pyrefly: ignore [read-only]
 (self.logits,) = broadcast_all(logits)
 self._param = self.probs if probs is not None else self.logits
 if is_scalar:
@@ -234,7 +234,7 @@ class ContinuousBernoulli(ExponentialFamily):
 def _natural_params(self) -> tuple[Tensor]:
 return (self.logits,)

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 def _log_normalizer(self, x):
 """computes the log normalizing constant as a function of the natural parameter"""
 out_unst_reg = torch.max(
@ -22,7 +22,7 @@ def _Dirichlet_backward(x, concentration, grad_output):
|
|||
|
||||
class _Dirichlet(Function):
|
||||
@staticmethod
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def forward(ctx, concentration):
|
||||
x = torch._sample_dirichlet(concentration)
|
||||
ctx.save_for_backward(x, concentration)
|
||||
|
|
@ -30,7 +30,7 @@ class _Dirichlet(Function):
|
|||
|
||||
@staticmethod
|
||||
@once_differentiable
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def backward(ctx, grad_output):
|
||||
x, concentration = ctx.saved_tensors
|
||||
return _Dirichlet_backward(x, concentration, grad_output)
|
||||
|
|
@ -52,7 +52,7 @@ class Dirichlet(ExponentialFamily):
|
|||
(often referred to as alpha)
|
||||
"""
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
arg_constraints = {
|
||||
"concentration": constraints.independent(constraints.positive, 1)
|
||||
}
|
||||
|
|
@ -133,6 +133,6 @@ class Dirichlet(ExponentialFamily):
|
|||
def _natural_params(self) -> tuple[Tensor]:
|
||||
return (self.concentration,)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def _log_normalizer(self, x):
|
||||
return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ class Exponential(ExponentialFamily):
|
|||
rate (float or Tensor): rate = 1 / scale of the distribution
|
||||
"""
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
arg_constraints = {"rate": constraints.positive}
|
||||
support = constraints.nonnegative
|
||||
has_rsample = True
|
||||
|
|
@ -90,6 +90,6 @@ class Exponential(ExponentialFamily):
|
|||
def _natural_params(self) -> tuple[Tensor]:
|
||||
return (-self.rate,)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def _log_normalizer(self, x):
|
||||
return -torch.log(-x)
|
||||
|
|
|
|||
|
|
@ -29,7 +29,7 @@ class FisherSnedecor(Distribution):
|
|||
df2 (float or Tensor): degrees of freedom parameter 2
|
||||
"""
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
|
||||
support = constraints.positive
|
||||
has_rsample = True
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ class Gamma(ExponentialFamily):
|
|||
(often referred to as beta), rate = 1 / scale
|
||||
"""
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
arg_constraints = {
|
||||
"concentration": constraints.positive,
|
||||
"rate": constraints.positive,
|
||||
|
|
@ -110,7 +110,7 @@ class Gamma(ExponentialFamily):
|
|||
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
||||
return (self.concentration - 1, -self.rate)
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def _log_normalizer(self, x, y):
|
||||
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
|
||||
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ class GeneralizedPareto(Distribution):
|
|||
concentration (float or Tensor): Concentration parameter of the distribution
|
||||
"""
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
arg_constraints = {
|
||||
"loc": constraints.real,
|
||||
"scale": constraints.positive,
|
||||
|
|
@ -131,7 +131,7 @@ class GeneralizedPareto(Distribution):
|
|||
concentration = self.concentration
|
||||
valid = concentration < 0.5
|
||||
safe_conc = torch.where(valid, concentration, 0.25)
|
||||
# pyrefly: ignore # unsupported-operation
|
||||
# pyrefly: ignore [unsupported-operation]
|
||||
result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
|
||||
return torch.where(valid, result, nan)
|
||||
|
||||
|
|
@ -144,7 +144,7 @@ class GeneralizedPareto(Distribution):
|
|||
return self.loc
|
||||
|
||||
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
def support(self):
|
||||
lower = self.loc
|
||||
upper = torch.where(
|
||||
|
|
|
|||
|
|
@ -44,7 +44,7 @@ class Geometric(Distribution):
|
|||
logits (Number, Tensor): the log-odds of sampling `1`.
|
||||
"""
|
||||
|
||||
# pyrefly: ignore # bad-override
|
||||
# pyrefly: ignore [bad-override]
|
||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||
support = constraints.nonnegative_integer
|
||||
|
||||
|
|
@ -59,11 +59,11 @@ class Geometric(Distribution):
|
|||
"Either `probs` or `logits` must be specified, but not both."
|
||||
)
|
||||
if probs is not None:
|
||||
# pyrefly: ignore # read-only
|
||||
# pyrefly: ignore [read-only]
|
||||
(self.probs,) = broadcast_all(probs)
|
||||
else:
|
||||
assert logits is not None # helps mypy
|
||||
# pyrefly: ignore # read-only
|
||||
# pyrefly: ignore [read-only]
|
||||
(self.logits,) = broadcast_all(logits)
|
||||
probs_or_logits = probs if probs is not None else logits
|
||||
if isinstance(probs_or_logits, _Number):
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ class Gumbel(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.real
def __init__(

@ -32,10 +32,10 @@ class HalfCauchy(TransformedDistribution):
"""
arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.nonnegative
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Cauchy
def __init__(

@ -32,10 +32,10 @@ class HalfNormal(TransformedDistribution):
"""
arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.nonnegative
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Normal
def __init__(

@ -91,7 +91,7 @@ class Independent(Distribution, Generic[D]):
return self.base_dist.has_enumerate_support
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
result = self.base_dist.support
if self.reinterpreted_batch_ndims:

@ -38,10 +38,10 @@ class InverseGamma(TransformedDistribution):
"concentration": constraints.positive,
"rate": constraints.positive,
}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.positive
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Gamma
def __init__(

@ -44,7 +44,7 @@ class Kumaraswamy(TransformedDistribution):
"concentration1": constraints.positive,
"concentration0": constraints.positive,
}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.unit_interval
has_rsample = True

@ -67,7 +67,7 @@ class Kumaraswamy(TransformedDistribution):
AffineTransform(loc=1.0, scale=-1.0),
PowerTransform(exponent=self.concentration1.reciprocal()),
]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

@ -28,7 +28,7 @@ class Laplace(Distribution):
scale (float or Tensor): scale of the distribution
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True

@ -60,7 +60,7 @@ class LKJCholesky(Distribution):
Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"concentration": constraints.positive}
support = constraints.corr_cholesky

@ -32,10 +32,10 @@ class LogNormal(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.positive
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Normal
def __init__(

@ -36,10 +36,10 @@ class LogisticNormal(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.simplex
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Independent[Normal]
def __init__(

@ -86,7 +86,7 @@ class LowRankMultivariateNormal(Distribution):
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2),

@ -124,7 +124,7 @@ class MixtureSameFamily(Distribution):
return new
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
return MixtureSameFamilyConstraint(self._component_distribution.support)

@ -50,7 +50,7 @@ class Multinomial(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
total_count: int

@ -93,7 +93,7 @@ class Multinomial(Distribution):
return self._categorical._new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=1)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
return constraints.multinomial(self.total_count)

@ -123,7 +123,7 @@ class MultivariateNormal(Distribution):
the corresponding lower triangular matrices using a Cholesky decomposition.
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"loc": constraints.real_vector,
"covariance_matrix": constraints.positive_definite,

@ -157,7 +157,7 @@ class MultivariateNormal(Distribution):
"with optional leading batch dimensions"
)
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
if covariance_matrix.dim() < 2:

@ -168,7 +168,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
covariance_matrix.shape[:-2], loc.shape[:-1]
)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
else:
assert precision_matrix is not None # helps mypy

@ -180,7 +180,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
precision_matrix.shape[:-2], loc.shape[:-1]
)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
self.loc = loc.expand(batch_shape + (-1,))

@ -33,7 +33,7 @@ class NegativeBinomial(Distribution):
logits (Tensor): Event log-odds for probabilities of success
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"total_count": constraints.greater_than_eq(0),
"probs": constraints.half_open_interval(0.0, 1.0),

@ -55,7 +55,7 @@ class NegativeBinomial(Distribution):
if probs is not None:
(
self.total_count,
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.probs,
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)

@ -63,7 +63,7 @@ class NegativeBinomial(Distribution):
assert logits is not None # helps mypy
(
self.total_count,
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.logits,
) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)

@ -31,7 +31,7 @@ class Normal(ExponentialFamily):
(often referred to as sigma)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True

@ -89,7 +89,7 @@ class Normal(ExponentialFamily):
if self._validate_args:
self._validate_sample(value)
# compute the variance
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
var = self.scale**2
log_scale = (
math.log(self.scale)

@ -119,6 +119,6 @@ class Normal(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y):
return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)

@ -42,7 +42,7 @@ class OneHotCategorical(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True

@ -39,7 +39,7 @@ class Pareto(TransformedDistribution):
self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(

@ -32,7 +32,7 @@ class Poisson(ExponentialFamily):
rate (Number, Tensor): the rate parameter
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"rate": constraints.nonnegative}
support = constraints.nonnegative_integer

@ -83,6 +83,6 @@ class Poisson(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (torch.log(self.rate),)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
return torch.exp(x)

@ -40,7 +40,7 @@ class LogitRelaxedBernoulli(Distribution):
(Jang et al., 2017)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.real

@ -58,12 +58,12 @@ class LogitRelaxedBernoulli(Distribution):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:

@ -141,10 +141,10 @@ class RelaxedBernoulli(TransformedDistribution):
"""
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.unit_interval
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: LogitRelaxedBernoulli
def __init__(

@ -38,7 +38,7 @@ class ExpRelaxedCategorical(Distribution):
(Jang et al., 2017)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = (
constraints.real_vector

@ -128,10 +128,10 @@ class RelaxedOneHotCategorical(TransformedDistribution):
"""
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.simplex
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: ExpRelaxedCategorical
def __init__(

@ -31,7 +31,7 @@ class StudentT(Distribution):
scale (float or Tensor): scale of the distribution
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"df": constraints.positive,
"loc": constraints.real,

@ -123,7 +123,7 @@ class TransformedDistribution(Distribution):
return new
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
if not self.transforms:
return self.base_dist.support

@ -226,13 +226,13 @@ class _InverseTransform(Transform):
self._inv: Transform = transform # type: ignore[assignment]
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
assert self._inv is not None
return self._inv.codomain
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
assert self._inv is not None
return self._inv.domain

@ -302,7 +302,7 @@ class ComposeTransform(Transform):
return self.parts == other.parts
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
if not self.parts:
return constraints.real

@ -318,7 +318,7 @@ class ComposeTransform(Transform):
return domain
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
if not self.parts:
return constraints.real

@ -438,14 +438,14 @@ class IndependentTransform(Transform):
)
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
return constraints.independent(
self.base_transform.domain, self.reinterpreted_batch_ndims
)
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.independent(
self.base_transform.codomain, self.reinterpreted_batch_ndims

@ -513,12 +513,12 @@ class ReshapeTransform(Transform):
super().__init__(cache_size=cache_size)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
return constraints.independent(constraints.real, len(self.in_shape))
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.independent(constraints.real, len(self.out_shape))

@ -772,14 +772,14 @@ class AffineTransform(Transform):
return self._event_dim
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
if self.event_dim == 0:
return constraints.real

@ -877,7 +877,7 @@ class CorrCholeskyTransform(Transform):
# apply stick-breaking on the squared values
# Note that y = sign(r) * sqrt(z * z1m_cumprod)
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
z = r**2
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
# Diagonal elements must be 1.

@ -1166,14 +1166,14 @@ class CatTransform(Transform):
return all(t.bijective for t in self.transforms)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
return constraints.cat(
[t.domain for t in self.transforms], self.dim, self.lengths
)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.cat(
[t.codomain for t in self.transforms], self.dim, self.lengths

@ -1246,12 +1246,12 @@ class StackTransform(Transform):
return all(t.bijective for t in self.transforms)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
return constraints.stack([t.domain for t in self.transforms], self.dim)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.stack([t.codomain for t in self.transforms], self.dim)

@ -79,7 +79,7 @@ class Uniform(Distribution):
return new
@constraints.dependent_property(is_discrete=False, event_dim=0)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
return constraints.interval(self.low, self.high)

@ -92,7 +92,7 @@ def _log_modified_bessel_fn(x, order=0):
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
while not done.all():
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
u1, u2, u3 = u.unbind()

@ -101,7 +101,7 @@ def _rejection_sample(loc, concentration, proposal_r, x):
c = concentration * (proposal_r - f)
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
if accept.any():
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
done = done | accept
return (x + math.pi + loc) % (2 * math.pi) - math.pi

@ -125,7 +125,7 @@ class VonMises(Distribution):
:param torch.Tensor concentration: concentration parameter
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
support = constraints.real
has_rsample = False

@ -163,10 +163,10 @@ class VonMises(Distribution):
@lazy_property
def _proposal_r(self) -> Tensor:
kappa = self._concentration
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
tau = 1 + (1 + 4 * kappa**2).sqrt()
rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
_proposal_r = (1 + rho**2) / (2 * rho)
# second order Taylor expansion around 0 for small kappa
_proposal_r_taylor = 1 / kappa + kappa

@ -35,7 +35,7 @@ class Weibull(TransformedDistribution):
"scale": constraints.positive,
"concentration": constraints.positive,
}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.positive
def __init__(

@ -53,7 +53,7 @@ class Weibull(TransformedDistribution):
PowerTransform(exponent=self.concentration_reciprocal),
AffineTransform(loc=0, scale=self.scale),
]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

@ -116,13 +116,13 @@ class Wishart(ExponentialFamily):
)
if scale_tril is not None:
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.scale_tril = param.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.covariance_matrix = param.expand(batch_shape + (-1, -1))
elif precision_matrix is not None:
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.precision_matrix = param.expand(batch_shape + (-1, -1))
if self.df.lt(event_shape[-1]).any():

@ -339,7 +339,7 @@ class Wishart(ExponentialFamily):
p = self._event_shape[-1] # has singleton shape
return -self.precision_matrix / 2, (nu - p - 1) / 2
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y):
p = self._event_shape[-1]
return (y + (p + 1) / 2) * (

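A note for readers skimming the hunks above: every change is the same one-line substitution, replacing the bare suppression `# pyrefly: ignore # <error-code>` with the bracketed `# pyrefly: ignore [<error-code>]`. As I understand pyrefly's suppression syntax, the trailing `# <error-code>` in the old spelling is only a comment, so the ignore was not scoped to any particular error, while the bracketed spelling limits the suppression to the named code on the following line. No runtime behavior changes; only comments are touched. A minimal sketch of the two spellings follows; the `Toy` class is hypothetical and does not appear in this diff.

# Hypothetical illustration of the suppression syntax touched by this commit;
# the class below is illustrative only and is not part of torch.distributions.
from torch.distributions import constraints
from torch.distributions.distribution import Distribution


class Toy(Distribution):
    # Old spelling: everything after "ignore" is free text, so any pyrefly
    # error reported for the next line would be silenced.
    # pyrefly: ignore # bad-override
    arg_constraints = {"rate": constraints.positive}

    # New spelling used throughout this commit: only the bracketed error code
    # (here bad-override) is suppressed for the next line.
    # pyrefly: ignore [bad-override]
    support = constraints.positive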