mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Fix pyrelfy ignore syntax in distributions and ao (#166248)
Ensures existing pyrefly ignores only ignore the intended error code pyrefly check lintrunner Pull Request resolved: https://github.com/pytorch/pytorch/pull/166248 Approved by: https://github.com/oulgen
This commit is contained in:
parent
a2b6afeac5
commit
154e4d36e9
|
|
@ -620,7 +620,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
padding_mode=padding_mode,
|
padding_mode=padding_mode,
|
||||||
qconfig=qconfig,
|
qconfig=qconfig,
|
||||||
)
|
)
|
||||||
|
|
@ -821,7 +821,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
padding_mode=padding_mode,
|
padding_mode=padding_mode,
|
||||||
qconfig=qconfig,
|
qconfig=qconfig,
|
||||||
)
|
)
|
||||||
|
|
@ -1023,7 +1023,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
padding_mode=padding_mode,
|
padding_mode=padding_mode,
|
||||||
qconfig=qconfig,
|
qconfig=qconfig,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ class LinearReLU(nnqat.Linear, _FusedModule):
|
||||||
torch.Size([128, 30])
|
torch.Size([128, 30])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
_FLOAT_MODULE = nni.LinearReLU
|
_FLOAT_MODULE = nni.LinearReLU
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -30,7 +30,7 @@ class LinearReLU(nnqd.Linear):
|
||||||
torch.Size([128, 30])
|
torch.Size([128, 30])
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
_FLOAT_MODULE = nni.LinearReLU
|
_FLOAT_MODULE = nni.LinearReLU
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ class ConvReLU1d(nnq.Conv1d):
|
||||||
dilation=dilation,
|
dilation=dilation,
|
||||||
groups=groups,
|
groups=groups,
|
||||||
bias=bias,
|
bias=bias,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
padding_mode=padding_mode,
|
padding_mode=padding_mode,
|
||||||
device=device,
|
device=device,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
|
|
|
||||||
|
|
@ -114,7 +114,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
|
||||||
assert hasattr(cls, "_FLOAT_RELU_MODULE")
|
assert hasattr(cls, "_FLOAT_RELU_MODULE")
|
||||||
relu = cls._FLOAT_RELU_MODULE()
|
relu = cls._FLOAT_RELU_MODULE()
|
||||||
modules.append(relu)
|
modules.append(relu)
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
fused = cls._FLOAT_MODULE(*modules)
|
fused = cls._FLOAT_MODULE(*modules)
|
||||||
fused.train(self.training)
|
fused.train(self.training)
|
||||||
return fused
|
return fused
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ class Embedding(nn.Embedding):
|
||||||
scale_grad_by_freq,
|
scale_grad_by_freq,
|
||||||
sparse,
|
sparse,
|
||||||
_weight,
|
_weight,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
assert qconfig, "qconfig must be provided for QAT module"
|
assert qconfig, "qconfig must be provided for QAT module"
|
||||||
|
|
|
||||||
|
|
@ -170,11 +170,11 @@ class MultiheadAttention(nn.MultiheadAttention):
|
||||||
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
|
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
|
||||||
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
|
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
|
||||||
if other.in_proj_bias is None:
|
if other.in_proj_bias is None:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
observed.linear_Q.bias = None
|
observed.linear_Q.bias = None
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
observed.linear_K.bias = None
|
observed.linear_K.bias = None
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
observed.linear_V.bias = None
|
observed.linear_V.bias = None
|
||||||
else:
|
else:
|
||||||
observed.linear_Q.bias = nn.Parameter(
|
observed.linear_Q.bias = nn.Parameter(
|
||||||
|
|
@ -237,7 +237,7 @@ class MultiheadAttention(nn.MultiheadAttention):
|
||||||
_end = _start + fp.embed_dim
|
_end = _start + fp.embed_dim
|
||||||
fp.in_proj_weight[_start:_end, :] = wQ
|
fp.in_proj_weight[_start:_end, :] = wQ
|
||||||
if fp.in_proj_bias is not None:
|
if fp.in_proj_bias is not None:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
assert all(bQ == 0)
|
assert all(bQ == 0)
|
||||||
fp.in_proj_bias[_start:_end] = bQ
|
fp.in_proj_bias[_start:_end] = bQ
|
||||||
|
|
||||||
|
|
@ -245,14 +245,14 @@ class MultiheadAttention(nn.MultiheadAttention):
|
||||||
_end = _start + fp.embed_dim
|
_end = _start + fp.embed_dim
|
||||||
fp.in_proj_weight[_start:_end, :] = wK
|
fp.in_proj_weight[_start:_end, :] = wK
|
||||||
if fp.in_proj_bias is not None:
|
if fp.in_proj_bias is not None:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
assert all(bK == 0)
|
assert all(bK == 0)
|
||||||
fp.in_proj_bias[_start:_end] = bK
|
fp.in_proj_bias[_start:_end] = bK
|
||||||
|
|
||||||
_start = _end
|
_start = _end
|
||||||
fp.in_proj_weight[_start:, :] = wV
|
fp.in_proj_weight[_start:, :] = wV
|
||||||
if fp.in_proj_bias is not None:
|
if fp.in_proj_bias is not None:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
assert all(bV == 0)
|
assert all(bV == 0)
|
||||||
fp.in_proj_bias[_start:] = bV
|
fp.in_proj_bias[_start:] = bV
|
||||||
else:
|
else:
|
||||||
|
|
@ -260,11 +260,11 @@ class MultiheadAttention(nn.MultiheadAttention):
|
||||||
fp.k_proj_weight = nn.Parameter(wK)
|
fp.k_proj_weight = nn.Parameter(wK)
|
||||||
fp.v_proj_weight = nn.Parameter(wV)
|
fp.v_proj_weight = nn.Parameter(wV)
|
||||||
if fp.in_proj_bias is None:
|
if fp.in_proj_bias is None:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.linear_Q.bias = None
|
self.linear_Q.bias = None
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.linear_K.bias = None
|
self.linear_K.bias = None
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.linear_V.bias = None
|
self.linear_V.bias = None
|
||||||
else:
|
else:
|
||||||
fp.in_proj_bias[0 : fp.embed_dim] = bQ
|
fp.in_proj_bias[0 : fp.embed_dim] = bQ
|
||||||
|
|
@ -472,7 +472,7 @@ class MultiheadAttention(nn.MultiheadAttention):
|
||||||
assert static_v.size(2) == head_dim
|
assert static_v.size(2) == head_dim
|
||||||
v = static_v
|
v = static_v
|
||||||
|
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
src_len = k.size(1)
|
src_len = k.size(1)
|
||||||
|
|
||||||
if key_padding_mask is not None:
|
if key_padding_mask is not None:
|
||||||
|
|
@ -481,35 +481,35 @@ class MultiheadAttention(nn.MultiheadAttention):
|
||||||
|
|
||||||
if self.add_zero_attn:
|
if self.add_zero_attn:
|
||||||
src_len += 1
|
src_len += 1
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
|
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
if k.is_quantized:
|
if k.is_quantized:
|
||||||
k_zeros = torch.quantize_per_tensor(
|
k_zeros = torch.quantize_per_tensor(
|
||||||
k_zeros,
|
k_zeros,
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
k.q_scale(),
|
k.q_scale(),
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
k.q_zero_point(),
|
k.q_zero_point(),
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
k.dtype,
|
k.dtype,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
k = torch.cat([k, k_zeros], dim=1)
|
k = torch.cat([k, k_zeros], dim=1)
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
|
v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
if v.is_quantized:
|
if v.is_quantized:
|
||||||
v_zeros = torch.quantize_per_tensor(
|
v_zeros = torch.quantize_per_tensor(
|
||||||
v_zeros,
|
v_zeros,
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
v.q_scale(),
|
v.q_scale(),
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
v.q_zero_point(),
|
v.q_zero_point(),
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
v.dtype,
|
v.dtype,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
v = torch.cat([v, v_zeros], dim=1)
|
v = torch.cat([v, v_zeros], dim=1)
|
||||||
|
|
||||||
if attn_mask is not None:
|
if attn_mask is not None:
|
||||||
|
|
|
||||||
|
|
@ -376,7 +376,7 @@ class _LSTMLayer(torch.nn.Module):
|
||||||
bidirectional,
|
bidirectional,
|
||||||
split_gates=split_gates,
|
split_gates=split_gates,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
layer.qconfig = getattr(other, "qconfig", qconfig)
|
layer.qconfig = getattr(other, "qconfig", qconfig)
|
||||||
wi = getattr(other, f"weight_ih_l{layer_idx}")
|
wi = getattr(other, f"weight_ih_l{layer_idx}")
|
||||||
wh = getattr(other, f"weight_hh_l{layer_idx}")
|
wh = getattr(other, f"weight_hh_l{layer_idx}")
|
||||||
|
|
@ -455,7 +455,7 @@ class LSTM(torch.nn.Module):
|
||||||
|
|
||||||
if (
|
if (
|
||||||
not isinstance(dropout, numbers.Number)
|
not isinstance(dropout, numbers.Number)
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
or not 0 <= dropout <= 1
|
or not 0 <= dropout <= 1
|
||||||
or isinstance(dropout, bool)
|
or isinstance(dropout, bool)
|
||||||
):
|
):
|
||||||
|
|
@ -464,7 +464,7 @@ class LSTM(torch.nn.Module):
|
||||||
"representing the probability of an element being "
|
"representing the probability of an element being "
|
||||||
"zeroed"
|
"zeroed"
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
if dropout > 0:
|
if dropout > 0:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
"dropout option for quantizable LSTM is ignored. "
|
"dropout option for quantizable LSTM is ignored. "
|
||||||
|
|
@ -578,7 +578,7 @@ class LSTM(torch.nn.Module):
|
||||||
other.bidirectional,
|
other.bidirectional,
|
||||||
split_gates=split_gates,
|
split_gates=split_gates,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
observed.qconfig = getattr(other, "qconfig", qconfig)
|
observed.qconfig = getattr(other, "qconfig", qconfig)
|
||||||
for idx in range(other.num_layers):
|
for idx in range(other.num_layers):
|
||||||
observed.layers[idx] = _LSTMLayer.from_float(
|
observed.layers[idx] = _LSTMLayer.from_float(
|
||||||
|
|
|
||||||
|
|
@ -74,7 +74,7 @@ class Conv1d(nnq.Conv1d):
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
kernel_size = _single(kernel_size)
|
kernel_size = _single(kernel_size)
|
||||||
stride = _single(stride)
|
stride = _single(stride)
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
padding = padding if isinstance(padding, str) else _single(padding)
|
padding = padding if isinstance(padding, str) else _single(padding)
|
||||||
dilation = _single(dilation)
|
dilation = _single(dilation)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -119,9 +119,9 @@ class Linear(nnq.Linear):
|
||||||
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
|
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
|
||||||
if type(mod) is nni.LinearReLU:
|
if type(mod) is nni.LinearReLU:
|
||||||
mod = mod[0]
|
mod = mod[0]
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
||||||
# pyrefly: ignore # not-callable
|
# pyrefly: ignore [not-callable]
|
||||||
weight_observer = mod.qconfig.weight()
|
weight_observer = mod.qconfig.weight()
|
||||||
else:
|
else:
|
||||||
# We have the circular import issues if we import the qconfig in the beginning of this file:
|
# We have the circular import issues if we import the qconfig in the beginning of this file:
|
||||||
|
|
@ -145,7 +145,7 @@ class Linear(nnq.Linear):
|
||||||
"Unsupported dtype specified for dynamic quantized Linear!"
|
"Unsupported dtype specified for dynamic quantized Linear!"
|
||||||
)
|
)
|
||||||
qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
|
qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
qlinear.set_weight_bias(qweight, mod.bias)
|
qlinear.set_weight_bias(qweight, mod.bias)
|
||||||
return qlinear
|
return qlinear
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -522,7 +522,7 @@ class LSTM(RNNBase):
|
||||||
>>> output, (hn, cn) = rnn(input, (h0, c0))
|
>>> output, (hn, cn) = rnn(input, (h0, c0))
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
_FLOAT_MODULE = nn.LSTM
|
_FLOAT_MODULE = nn.LSTM
|
||||||
|
|
||||||
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
|
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
|
||||||
|
|
@ -808,7 +808,7 @@ class GRU(RNNBase):
|
||||||
>>> output, hn = rnn(input, h0)
|
>>> output, hn = rnn(input, h0)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
_FLOAT_MODULE = nn.GRU
|
_FLOAT_MODULE = nn.GRU
|
||||||
|
|
||||||
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
|
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
|
||||||
|
|
|
||||||
|
|
@ -67,9 +67,9 @@ class Hardswish(torch.nn.Hardswish):
|
||||||
def __init__(self, scale, zero_point, device=None, dtype=None):
|
def __init__(self, scale, zero_point, device=None, dtype=None):
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
@ -140,9 +140,9 @@ class LeakyReLU(torch.nn.LeakyReLU):
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
super().__init__(negative_slope, inplace)
|
super().__init__(negative_slope, inplace)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
@ -230,7 +230,7 @@ class Softmax(torch.nn.Softmax):
|
||||||
|
|
||||||
|
|
||||||
class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
|
class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
_FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
|
_FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
|
||||||
|
|
||||||
def _get_name(self):
|
def _get_name(self):
|
||||||
|
|
|
||||||
|
|
@ -12,9 +12,9 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
|
super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
|
self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
|
self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
|
||||||
|
|
@ -408,7 +408,7 @@ class Conv1d(_ConvNd):
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
kernel_size = _single(kernel_size)
|
kernel_size = _single(kernel_size)
|
||||||
stride = _single(stride)
|
stride = _single(stride)
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
padding = padding if isinstance(padding, str) else _single(padding)
|
padding = padding if isinstance(padding, str) else _single(padding)
|
||||||
dilation = _single(dilation)
|
dilation = _single(dilation)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -310,7 +310,7 @@ class Linear(WeightedQuantizedModule):
|
||||||
# the type mismatch in assignment. Also, mypy has an issue with
|
# the type mismatch in assignment. Also, mypy has an issue with
|
||||||
# iterables not being implemented, so we are ignoring those too.
|
# iterables not being implemented, so we are ignoring those too.
|
||||||
if not isinstance(cls._FLOAT_MODULE, Iterable):
|
if not isinstance(cls._FLOAT_MODULE, Iterable):
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
cls._FLOAT_MODULE = [cls._FLOAT_MODULE]
|
cls._FLOAT_MODULE = [cls._FLOAT_MODULE]
|
||||||
supported_modules = ", ".join(
|
supported_modules = ", ".join(
|
||||||
[float_mod.__name__ for float_mod in cls._FLOAT_MODULE]
|
[float_mod.__name__ for float_mod in cls._FLOAT_MODULE]
|
||||||
|
|
|
||||||
|
|
@ -37,14 +37,14 @@ class LayerNorm(torch.nn.LayerNorm):
|
||||||
normalized_shape,
|
normalized_shape,
|
||||||
eps=eps,
|
eps=eps,
|
||||||
elementwise_affine=elementwise_affine,
|
elementwise_affine=elementwise_affine,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
**factory_kwargs,
|
**factory_kwargs,
|
||||||
)
|
)
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
@ -116,9 +116,9 @@ class GroupNorm(torch.nn.GroupNorm):
|
||||||
super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
|
super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
@ -180,9 +180,9 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
|
||||||
)
|
)
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
@ -249,9 +249,9 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
|
||||||
)
|
)
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
@ -318,9 +318,9 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
|
||||||
)
|
)
|
||||||
self.weight = weight
|
self.weight = weight
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
|
||||||
|
|
||||||
def forward(self, input):
|
def forward(self, input):
|
||||||
|
|
|
||||||
|
|
@ -141,7 +141,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
|
||||||
dilation,
|
dilation,
|
||||||
groups,
|
groups,
|
||||||
bias,
|
bias,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
padding_mode,
|
padding_mode,
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
|
|
@ -206,7 +206,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
|
||||||
dilation,
|
dilation,
|
||||||
groups,
|
groups,
|
||||||
bias,
|
bias,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
padding_mode,
|
padding_mode,
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
|
|
@ -383,7 +383,7 @@ class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
|
||||||
groups,
|
groups,
|
||||||
bias,
|
bias,
|
||||||
dilation,
|
dilation,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
padding_mode,
|
padding_mode,
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
|
|
@ -465,7 +465,7 @@ class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
|
||||||
groups,
|
groups,
|
||||||
bias,
|
bias,
|
||||||
dilation,
|
dilation,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
padding_mode,
|
padding_mode,
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
|
|
|
||||||
|
|
@ -664,7 +664,7 @@ class LSTM(RNNBase):
|
||||||
if isinstance(orig_input, PackedSequence):
|
if isinstance(orig_input, PackedSequence):
|
||||||
output_packed = PackedSequence(
|
output_packed = PackedSequence(
|
||||||
output,
|
output,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
batch_sizes,
|
batch_sizes,
|
||||||
sorted_indices,
|
sorted_indices,
|
||||||
unsorted_indices,
|
unsorted_indices,
|
||||||
|
|
@ -828,7 +828,7 @@ class GRU(RNNBase):
|
||||||
if isinstance(orig_input, PackedSequence):
|
if isinstance(orig_input, PackedSequence):
|
||||||
output_packed = PackedSequence(
|
output_packed = PackedSequence(
|
||||||
output,
|
output,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
batch_sizes,
|
batch_sizes,
|
||||||
sorted_indices,
|
sorted_indices,
|
||||||
unsorted_indices,
|
unsorted_indices,
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class Embedding(nn.Embedding, ReferenceQuantizedModule):
|
||||||
scale_grad_by_freq,
|
scale_grad_by_freq,
|
||||||
sparse,
|
sparse,
|
||||||
_weight,
|
_weight,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
device,
|
device,
|
||||||
dtype,
|
dtype,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -18,7 +18,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
|
||||||
"scale": 1.0,
|
"scale": 1.0,
|
||||||
"zero_point": 0,
|
"zero_point": 0,
|
||||||
}
|
}
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
|
self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
|
||||||
self.weight_dtype = weight_qparams["dtype"]
|
self.weight_dtype = weight_qparams["dtype"]
|
||||||
assert self.weight_qscheme in [
|
assert self.weight_qscheme in [
|
||||||
|
|
@ -81,14 +81,14 @@ class ReferenceQuantizedModule(torch.nn.Module):
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"weight_axis", torch.tensor(0, dtype=torch.int, device=device)
|
"weight_axis", torch.tensor(0, dtype=torch.int, device=device)
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
|
self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
|
||||||
# store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
|
# store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
|
||||||
# for capturing `.item` operations
|
# for capturing `.item` operations
|
||||||
self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
|
self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min")
|
self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min")
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max")
|
self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max")
|
||||||
|
|
||||||
def get_weight(self):
|
def get_weight(self):
|
||||||
|
|
@ -105,7 +105,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
|
||||||
return _quantize_and_dequantize_weight_decomposed(
|
return _quantize_and_dequantize_weight_decomposed(
|
||||||
self.weight, # type: ignore[arg-type]
|
self.weight, # type: ignore[arg-type]
|
||||||
self.weight_qscheme,
|
self.weight_qscheme,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.weight_dtype,
|
self.weight_dtype,
|
||||||
self.weight_scale,
|
self.weight_scale,
|
||||||
self.weight_zero_point,
|
self.weight_zero_point,
|
||||||
|
|
@ -117,7 +117,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
|
||||||
return _quantize_and_dequantize_weight(
|
return _quantize_and_dequantize_weight(
|
||||||
self.weight, # type: ignore[arg-type]
|
self.weight, # type: ignore[arg-type]
|
||||||
self.weight_qscheme,
|
self.weight_qscheme,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.weight_dtype,
|
self.weight_dtype,
|
||||||
self.weight_scale,
|
self.weight_scale,
|
||||||
self.weight_zero_point,
|
self.weight_zero_point,
|
||||||
|
|
@ -133,7 +133,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
|
||||||
return _quantize_weight_decomposed(
|
return _quantize_weight_decomposed(
|
||||||
self.weight, # type: ignore[arg-type]
|
self.weight, # type: ignore[arg-type]
|
||||||
self.weight_qscheme,
|
self.weight_qscheme,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.weight_dtype,
|
self.weight_dtype,
|
||||||
self.weight_scale,
|
self.weight_scale,
|
||||||
self.weight_zero_point,
|
self.weight_zero_point,
|
||||||
|
|
@ -145,7 +145,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
|
||||||
return _quantize_weight(
|
return _quantize_weight(
|
||||||
self.weight, # type: ignore[arg-type]
|
self.weight, # type: ignore[arg-type]
|
||||||
self.weight_qscheme,
|
self.weight_qscheme,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.weight_dtype,
|
self.weight_dtype,
|
||||||
self.weight_scale,
|
self.weight_scale,
|
||||||
self.weight_zero_point,
|
self.weight_zero_point,
|
||||||
|
|
|
||||||
|
|
@ -151,9 +151,9 @@ class Linear(torch.nn.Module):
|
||||||
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
|
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
|
||||||
if type(mod) is nni.LinearReLU:
|
if type(mod) is nni.LinearReLU:
|
||||||
mod = mod[0]
|
mod = mod[0]
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
if mod.qconfig is not None and mod.qconfig.weight is not None:
|
||||||
# pyrefly: ignore # not-callable
|
# pyrefly: ignore [not-callable]
|
||||||
weight_observer = mod.qconfig.weight()
|
weight_observer = mod.qconfig.weight()
|
||||||
else:
|
else:
|
||||||
# We have the circular import issues if we import the qconfig in the beginning of this file:
|
# We have the circular import issues if we import the qconfig in the beginning of this file:
|
||||||
|
|
@ -187,6 +187,6 @@ class Linear(torch.nn.Module):
|
||||||
col_block_size,
|
col_block_size,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
|
qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
|
||||||
return qlinear
|
return qlinear
|
||||||
|
|
|
||||||
|
|
@ -959,7 +959,7 @@ def create_a_shadows_b(
|
||||||
if should_log_inputs:
|
if should_log_inputs:
|
||||||
# skip the input logger when inserting a dtype cast
|
# skip the input logger when inserting a dtype cast
|
||||||
if isinstance(prev_node_c, Node):
|
if isinstance(prev_node_c, Node):
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
|
prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
|
||||||
elif isinstance(prev_node_c, list):
|
elif isinstance(prev_node_c, list):
|
||||||
prev_node_c = [
|
prev_node_c = [
|
||||||
|
|
@ -968,7 +968,7 @@ def create_a_shadows_b(
|
||||||
]
|
]
|
||||||
dtype_cast_node = _insert_dtype_cast_after_node(
|
dtype_cast_node = _insert_dtype_cast_after_node(
|
||||||
subgraph_a.start_node,
|
subgraph_a.start_node,
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
node_c,
|
node_c,
|
||||||
prev_node_c,
|
prev_node_c,
|
||||||
gm_a,
|
gm_a,
|
||||||
|
|
@ -1052,7 +1052,7 @@ def create_a_shadows_b(
|
||||||
if num_non_param_args_node_a == 2:
|
if num_non_param_args_node_a == 2:
|
||||||
# node_c_second_non_param_arg = node_c.args[1]
|
# node_c_second_non_param_arg = node_c.args[1]
|
||||||
node_c_second_non_param_arg = get_normalized_nth_input(
|
node_c_second_non_param_arg = get_normalized_nth_input(
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
node_c,
|
node_c,
|
||||||
gm_b,
|
gm_b,
|
||||||
1,
|
1,
|
||||||
|
|
@ -1063,7 +1063,7 @@ def create_a_shadows_b(
|
||||||
subgraph_a,
|
subgraph_a,
|
||||||
gm_a,
|
gm_a,
|
||||||
gm_b,
|
gm_b,
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
node_c.name + "_shadow_copy_",
|
node_c.name + "_shadow_copy_",
|
||||||
)
|
)
|
||||||
env_c[node_a_shadows_c.name] = node_a_shadows_c
|
env_c[node_a_shadows_c.name] = node_a_shadows_c
|
||||||
|
|
@ -1086,19 +1086,19 @@ def create_a_shadows_b(
|
||||||
cur_node = node_a_shadows_c
|
cur_node = node_a_shadows_c
|
||||||
while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined]
|
while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined]
|
||||||
cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
|
cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
if isinstance(input_logger, Node):
|
if isinstance(input_logger, Node):
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
input_logger_mod = getattr(gm_b, input_logger.name)
|
input_logger_mod = getattr(gm_b, input_logger.name)
|
||||||
input_logger_mod.ref_node_name = cur_node.name
|
input_logger_mod.ref_node_name = cur_node.name
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
if not isinstance(input_logger, list):
|
if not isinstance(input_logger, list):
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
f"Expected list, got {type(input_logger)}"
|
f"Expected list, got {type(input_logger)}"
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
for input_logger_inner in input_logger:
|
for input_logger_inner in input_logger:
|
||||||
input_logger_mod = getattr(gm_b, input_logger_inner.name)
|
input_logger_mod = getattr(gm_b, input_logger_inner.name)
|
||||||
input_logger_mod.ref_node_name = cur_node.name
|
input_logger_mod.ref_node_name = cur_node.name
|
||||||
|
|
|
||||||
|
|
@ -419,7 +419,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
|
||||||
target2,
|
target2,
|
||||||
) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
|
) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
|
||||||
new_connections.append((source, target1))
|
new_connections.append((source, target1))
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
new_connections.append((source, target2))
|
new_connections.append((source, target2))
|
||||||
|
|
||||||
for source_to_target in (
|
for source_to_target in (
|
||||||
|
|
@ -428,7 +428,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
|
||||||
quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
|
quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
|
||||||
):
|
):
|
||||||
for source, target in source_to_target.items(): # type:ignore[assignment]
|
for source, target in source_to_target.items(): # type:ignore[assignment]
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
new_connections.append((source, target))
|
new_connections.append((source, target))
|
||||||
|
|
||||||
#
|
#
|
||||||
|
|
|
||||||
|
|
@ -94,10 +94,10 @@ class OutputProp:
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
|
if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
node.traced_result = result
|
node.traced_result = result
|
||||||
|
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
env[node.name] = result
|
env[node.name] = result
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
@ -403,10 +403,10 @@ def create_submodule_from_subgraph(
|
||||||
cur_name_idx += 1
|
cur_name_idx += 1
|
||||||
setattr(gm, mod_name, new_arg)
|
setattr(gm, mod_name, new_arg)
|
||||||
new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator]
|
new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator]
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
cur_args_copy.append(new_arg_placeholder)
|
cur_args_copy.append(new_arg_placeholder)
|
||||||
elif isinstance(arg, (float, int, torch.dtype)):
|
elif isinstance(arg, (float, int, torch.dtype)):
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
cur_args_copy.append(arg)
|
cur_args_copy.append(arg)
|
||||||
else:
|
else:
|
||||||
raise AssertionError(f"arg of type {type(arg)} not handled yet")
|
raise AssertionError(f"arg of type {type(arg)} not handled yet")
|
||||||
|
|
@ -818,7 +818,7 @@ def create_add_loggers_graph(
|
||||||
model,
|
model,
|
||||||
cur_subgraph_idx,
|
cur_subgraph_idx,
|
||||||
match_name,
|
match_name,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
maybe_subgraph,
|
maybe_subgraph,
|
||||||
[qconfig_mapping],
|
[qconfig_mapping],
|
||||||
[node_name_to_qconfig],
|
[node_name_to_qconfig],
|
||||||
|
|
@ -879,7 +879,7 @@ def create_add_loggers_graph(
|
||||||
cur_node_orig = first_node
|
cur_node_orig = first_node
|
||||||
cur_node_copy = None
|
cur_node_copy = None
|
||||||
first_node_copy = None
|
first_node_copy = None
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
while cur_node_orig in subgraph_to_use:
|
while cur_node_orig in subgraph_to_use:
|
||||||
# TODO(future PR): make this support all possible args/kwargs
|
# TODO(future PR): make this support all possible args/kwargs
|
||||||
if cur_node_orig is first_node:
|
if cur_node_orig is first_node:
|
||||||
|
|
|
||||||
|
|
@ -496,7 +496,7 @@ def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso
|
||||||
Return:
|
Return:
|
||||||
float or tuple of floats
|
float or tuple of floats
|
||||||
"""
|
"""
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())
|
return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -23,17 +23,17 @@ class SparseDLRM(DLRM_Net):
|
||||||
super().__init__(**args)
|
super().__init__(**args)
|
||||||
|
|
||||||
def forward(self, dense_x, lS_o, lS_i):
|
def forward(self, dense_x, lS_o, lS_i):
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
x = self.apply_mlp(dense_x, self.bot_l) # dense features
|
x = self.apply_mlp(dense_x, self.bot_l) # dense features
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag
|
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
z = self.interact_features(x, ly)
|
z = self.interact_features(x, ly)
|
||||||
|
|
||||||
z = z.to_sparse_coo()
|
z = z.to_sparse_coo()
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias)
|
z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias)
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
for layer in self.top_l[1:]:
|
for layer in self.top_l[1:]:
|
||||||
z = layer(z)
|
z = layer(z)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ def run_forward(model, **batch):
|
||||||
model(X, lS_o, lS_i)
|
model(X, lS_o, lS_i)
|
||||||
end = time.time()
|
end = time.time()
|
||||||
time_taken = end - start
|
time_taken = end - start
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
time_list.append(time_taken)
|
time_list.append(time_taken)
|
||||||
avg_time = np.mean(time_list[1:])
|
avg_time = np.mean(time_list[1:])
|
||||||
return avg_time
|
return avg_time
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ class FPGMPruner(BaseStructuredSparsifier):
|
||||||
dist_matrix = self.dist_fn(t_flatten)
|
dist_matrix = self.dist_fn(t_flatten)
|
||||||
|
|
||||||
# more similar with other filter indicates large in the sum of row
|
# more similar with other filter indicates large in the sum of row
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
distance = torch.sum(torch.abs(dist_matrix), 1)
|
distance = torch.sum(torch.abs(dist_matrix), 1)
|
||||||
|
|
||||||
return distance
|
return distance
|
||||||
|
|
|
||||||
|
|
@ -260,7 +260,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
|
||||||
module.register_parameter(
|
module.register_parameter(
|
||||||
"_bias", nn.Parameter(module.bias.detach())
|
"_bias", nn.Parameter(module.bias.detach())
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
module.bias = None
|
module.bias = None
|
||||||
module.prune_bias = prune_bias
|
module.prune_bias = prune_bias
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -97,7 +97,7 @@ def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
|
||||||
if module.bias is not None:
|
if module.bias is not None:
|
||||||
module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
|
module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
|
||||||
elif getattr(module, "_bias", None) is not None:
|
elif getattr(module, "_bias", None) is not None:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])
|
module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])
|
||||||
|
|
||||||
# get pruned biases to propagate to subsequent layer
|
# get pruned biases to propagate to subsequent layer
|
||||||
|
|
@ -127,7 +127,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor:
|
||||||
linear.out_features = linear.weight.shape[0]
|
linear.out_features = linear.weight.shape[0]
|
||||||
_remove_bias_handles(linear)
|
_remove_bias_handles(linear)
|
||||||
|
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -186,7 +186,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
|
||||||
conv2d.out_channels = conv2d.weight.shape[0]
|
conv2d.out_channels = conv2d.weight.shape[0]
|
||||||
|
|
||||||
_remove_bias_handles(conv2d)
|
_remove_bias_handles(conv2d)
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
return mask
|
return mask
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -207,7 +207,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
|
||||||
new_bias = torch.zeros(conv2d_1.bias.shape)
|
new_bias = torch.zeros(conv2d_1.bias.shape)
|
||||||
new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined]
|
new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined]
|
||||||
# adjusted bias that to keep in conv2d_1
|
# adjusted bias that to keep in conv2d_1
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
|
new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
|
||||||
# pruned biases that are kept instead of propagated
|
# pruned biases that are kept instead of propagated
|
||||||
conv2d_1.bias = nn.Parameter(new_bias)
|
conv2d_1.bias = nn.Parameter(new_bias)
|
||||||
|
|
|
||||||
|
|
@ -170,7 +170,7 @@ class BaseSparsifier(abc.ABC):
|
||||||
self.make_config_from_model(model)
|
self.make_config_from_model(model)
|
||||||
|
|
||||||
# TODO: Remove the configuration by reference ('module')
|
# TODO: Remove the configuration by reference ('module')
|
||||||
# pyrefly: ignore # not-iterable
|
# pyrefly: ignore [not-iterable]
|
||||||
for module_config in self.config:
|
for module_config in self.config:
|
||||||
assert isinstance(module_config, dict), (
|
assert isinstance(module_config, dict), (
|
||||||
"config elements should be dicts not modules i.e.:"
|
"config elements should be dicts not modules i.e.:"
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ def swap_module(
|
||||||
new_mod.register_forward_hook(hook_fn)
|
new_mod.register_forward_hook(hook_fn)
|
||||||
|
|
||||||
# respect device affinity when swapping modules
|
# respect device affinity when swapping modules
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
|
devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
|
||||||
assert len(devices) <= 1, (
|
assert len(devices) <= 1, (
|
||||||
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
|
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"
|
||||||
|
|
|
||||||
|
|
@ -235,7 +235,7 @@ class WeightNormSparsifier(BaseSparsifier):
|
||||||
ww = self.norm_fn(getattr(module, tensor_name))
|
ww = self.norm_fn(getattr(module, tensor_name))
|
||||||
tensor_mask = self._make_tensor_mask(
|
tensor_mask = self._make_tensor_mask(
|
||||||
data=ww,
|
data=ww,
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
input_shape=ww.shape,
|
input_shape=ww.shape,
|
||||||
sparsity_level=sparsity_level,
|
sparsity_level=sparsity_level,
|
||||||
sparse_block_shape=sparse_block_shape,
|
sparse_block_shape=sparse_block_shape,
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from .pt2e.export_utils import (
|
||||||
_move_exported_model_to_train as move_exported_model_to_train,
|
_move_exported_model_to_train as move_exported_model_to_train,
|
||||||
)
|
)
|
||||||
|
|
||||||
# pyrefly: ignore # deprecated
|
# pyrefly: ignore [deprecated]
|
||||||
from .qconfig import * # noqa: F403
|
from .qconfig import * # noqa: F403
|
||||||
from .qconfig_mapping import * # noqa: F403
|
from .qconfig_mapping import * # noqa: F403
|
||||||
from .quant_type import * # noqa: F403
|
from .quant_type import * # noqa: F403
|
||||||
|
|
|
||||||
|
|
@ -185,9 +185,9 @@ class FakeQuantize(FakeQuantizeBase):
|
||||||
dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
|
dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
|
||||||
"dtype", dtype
|
"dtype", dtype
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
|
assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
|
assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
|
||||||
observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
|
observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
|
||||||
observer_kwargs["is_dynamic"] = is_dynamic
|
observer_kwargs["is_dynamic"] = is_dynamic
|
||||||
|
|
|
||||||
|
|
@ -1149,7 +1149,7 @@ quantized_decomposed_lib.define(
|
||||||
|
|
||||||
class FakeQuantPerChannel(torch.autograd.Function):
|
class FakeQuantPerChannel(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
|
def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
|
||||||
if scales.dtype != torch.float32:
|
if scales.dtype != torch.float32:
|
||||||
scales = scales.to(torch.float32)
|
scales = scales.to(torch.float32)
|
||||||
|
|
@ -1172,7 +1172,7 @@ class FakeQuantPerChannel(torch.autograd.Function):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, gy):
|
def backward(ctx, gy):
|
||||||
(mask,) = ctx.saved_tensors
|
(mask,) = ctx.saved_tensors
|
||||||
return gy * mask, None, None, None, None, None
|
return gy * mask, None, None, None, None, None
|
||||||
|
|
|
||||||
|
|
@ -248,7 +248,7 @@ def calculate_equalization_scale(
|
||||||
|
|
||||||
|
|
||||||
class EqualizationQConfig(
|
class EqualizationQConfig(
|
||||||
# pyrefly: ignore # invalid-inheritance
|
# pyrefly: ignore [invalid-inheritance]
|
||||||
namedtuple("EqualizationQConfig", ["input_activation", "weight"])
|
namedtuple("EqualizationQConfig", ["input_activation", "weight"])
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
|
@ -463,7 +463,7 @@ def maybe_get_next_equalization_scale(
|
||||||
In this case, the node given is linear1 and we want to locate the InputEqObs.
|
In this case, the node given is linear1 and we want to locate the InputEqObs.
|
||||||
"""
|
"""
|
||||||
next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
|
next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
|
||||||
# pyrefly: ignore # invalid-argument
|
# pyrefly: ignore [invalid-argument]
|
||||||
if next_inp_eq_obs:
|
if next_inp_eq_obs:
|
||||||
if (
|
if (
|
||||||
next_inp_eq_obs.equalization_scale.nelement() == 1
|
next_inp_eq_obs.equalization_scale.nelement() == 1
|
||||||
|
|
@ -827,7 +827,7 @@ def convert_eq_obs(
|
||||||
scale_weight_node(
|
scale_weight_node(
|
||||||
node,
|
node,
|
||||||
modules,
|
modules,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
equalization_scale,
|
equalization_scale,
|
||||||
maybe_next_equalization_scale,
|
maybe_next_equalization_scale,
|
||||||
)
|
)
|
||||||
|
|
@ -836,7 +836,7 @@ def convert_eq_obs(
|
||||||
node,
|
node,
|
||||||
model,
|
model,
|
||||||
modules,
|
modules,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
equalization_scale,
|
equalization_scale,
|
||||||
maybe_next_equalization_scale,
|
maybe_next_equalization_scale,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -223,7 +223,7 @@ class ModelReportVisualizer:
|
||||||
feature_val = feature_val.item()
|
feature_val = feature_val.item()
|
||||||
|
|
||||||
# we add to our list of values
|
# we add to our list of values
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
tensor_table_row.append(feature_val)
|
tensor_table_row.append(feature_val)
|
||||||
|
|
||||||
tensor_table.append(tensor_table_row)
|
tensor_table.append(tensor_table_row)
|
||||||
|
|
@ -284,7 +284,7 @@ class ModelReportVisualizer:
|
||||||
feature_val = feature_val.item()
|
feature_val = feature_val.item()
|
||||||
|
|
||||||
# add value to channel specific row
|
# add value to channel specific row
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
new_channel_row.append(feature_val)
|
new_channel_row.append(feature_val)
|
||||||
|
|
||||||
# add to table and increment row index counter
|
# add to table and increment row index counter
|
||||||
|
|
|
||||||
|
|
@ -166,7 +166,7 @@ def _create_obs_or_fq_from_qspec(
|
||||||
}
|
}
|
||||||
edge_or_nodes = quantization_spec.derived_from
|
edge_or_nodes = quantization_spec.derived_from
|
||||||
obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
|
obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
kwargs["obs_or_fqs"] = obs_or_fqs
|
kwargs["obs_or_fqs"] = obs_or_fqs
|
||||||
return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
|
return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
|
||||||
elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
|
elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
|
||||||
|
|
@ -2088,11 +2088,11 @@ def prepare(
|
||||||
|
|
||||||
root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
|
root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
_update_qconfig_for_fusion(model, qconfig_mapping)
|
_update_qconfig_for_fusion(model, qconfig_mapping)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
_update_qconfig_for_fusion(model, _equalization_config)
|
_update_qconfig_for_fusion(model, _equalization_config)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
|
flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
|
||||||
# TODO: support regex as well
|
# TODO: support regex as well
|
||||||
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
|
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
|
||||||
|
|
@ -2100,7 +2100,7 @@ def prepare(
|
||||||
if is_qat:
|
if is_qat:
|
||||||
module_to_qat_module = get_module_to_qat_module(backend_config)
|
module_to_qat_module = get_module_to_qat_module(backend_config)
|
||||||
_qat_swap_modules(model, module_to_qat_module)
|
_qat_swap_modules(model, module_to_qat_module)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
_update_qconfig_for_qat(qconfig_mapping, backend_config)
|
_update_qconfig_for_qat(qconfig_mapping, backend_config)
|
||||||
|
|
||||||
# mapping from fully qualified module name to module instance
|
# mapping from fully qualified module name to module instance
|
||||||
|
|
@ -2117,7 +2117,7 @@ def prepare(
|
||||||
model,
|
model,
|
||||||
named_modules,
|
named_modules,
|
||||||
model.graph,
|
model.graph,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
_equalization_config,
|
_equalization_config,
|
||||||
node_name_to_scope,
|
node_name_to_scope,
|
||||||
)
|
)
|
||||||
|
|
@ -2125,7 +2125,7 @@ def prepare(
|
||||||
model,
|
model,
|
||||||
named_modules,
|
named_modules,
|
||||||
model.graph,
|
model.graph,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
qconfig_mapping,
|
qconfig_mapping,
|
||||||
node_name_to_scope,
|
node_name_to_scope,
|
||||||
)
|
)
|
||||||
|
|
@ -2187,7 +2187,7 @@ def prepare(
|
||||||
node_name_to_scope,
|
node_name_to_scope,
|
||||||
prepare_custom_config,
|
prepare_custom_config,
|
||||||
equalization_node_name_to_qconfig,
|
equalization_node_name_to_qconfig,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
qconfig_mapping,
|
qconfig_mapping,
|
||||||
is_qat,
|
is_qat,
|
||||||
observed_node_names,
|
observed_node_names,
|
||||||
|
|
|
||||||
|
|
@ -721,7 +721,7 @@ def _maybe_get_custom_module_lstm_from_node_arg(
|
||||||
a = a.args[0][0] # type: ignore[assignment,index]
|
a = a.args[0][0] # type: ignore[assignment,index]
|
||||||
else:
|
else:
|
||||||
a = a.args[0] # type: ignore[assignment]
|
a = a.args[0] # type: ignore[assignment]
|
||||||
# pyrefly: ignore # bad-return
|
# pyrefly: ignore [bad-return]
|
||||||
return a
|
return a
|
||||||
|
|
||||||
all_match_patterns = [
|
all_match_patterns = [
|
||||||
|
|
|
||||||
|
|
@ -281,12 +281,12 @@ class UniformQuantizationObserverBase(ObserverBase):
|
||||||
)
|
)
|
||||||
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
|
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
|
||||||
if self.has_customized_qrange:
|
if self.has_customized_qrange:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
validate_qmin_qmax(quant_min, quant_max)
|
validate_qmin_qmax(quant_min, quant_max)
|
||||||
self.quant_min, self.quant_max = calculate_qmin_qmax(
|
self.quant_min, self.quant_max = calculate_qmin_qmax(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
quant_min,
|
quant_min,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
quant_max,
|
quant_max,
|
||||||
self.has_customized_qrange,
|
self.has_customized_qrange,
|
||||||
self.dtype,
|
self.dtype,
|
||||||
|
|
|
||||||
|
|
@ -83,7 +83,7 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
# pyrefly: ignore # invalid-inheritance
|
# pyrefly: ignore [invalid-inheritance]
|
||||||
class QConfig(namedtuple("QConfig", ["activation", "weight"])):
|
class QConfig(namedtuple("QConfig", ["activation", "weight"])):
|
||||||
"""
|
"""
|
||||||
Describes how to quantize a layer or a part of the network by providing
|
Describes how to quantize a layer or a part of the network by providing
|
||||||
|
|
@ -121,7 +121,7 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
|
||||||
"`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead",
|
"`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead",
|
||||||
category=FutureWarning,
|
category=FutureWarning,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # invalid-inheritance
|
# pyrefly: ignore [invalid-inheritance]
|
||||||
class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
|
class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
|
||||||
"""
|
"""
|
||||||
Describes how to dynamically quantize a layer or a part of the network by providing
|
Describes how to dynamically quantize a layer or a part of the network by providing
|
||||||
|
|
|
||||||
|
|
@ -418,7 +418,7 @@ class X86InductorQuantizer(Quantizer):
|
||||||
# As we use `_need_skip_config` to skip all invalid configurations,
|
# As we use `_need_skip_config` to skip all invalid configurations,
|
||||||
# we can safely assume that the all existing non-None configurations
|
# we can safely assume that the all existing non-None configurations
|
||||||
# have the same quantization mode.
|
# have the same quantization mode.
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
for qconfig in (
|
for qconfig in (
|
||||||
list(self.module_name_qconfig.values())
|
list(self.module_name_qconfig.values())
|
||||||
+ list(self.operator_type_qconfig.values())
|
+ list(self.operator_type_qconfig.values())
|
||||||
|
|
@ -818,7 +818,7 @@ class X86InductorQuantizer(Quantizer):
|
||||||
)
|
)
|
||||||
binary_node.meta[QUANT_ANNOTATION_KEY] = (
|
binary_node.meta[QUANT_ANNOTATION_KEY] = (
|
||||||
_X86InductorQuantizationAnnotation(
|
_X86InductorQuantizationAnnotation(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
input_qspec_map=binary_node_input_qspec_map,
|
input_qspec_map=binary_node_input_qspec_map,
|
||||||
_annotated=True,
|
_annotated=True,
|
||||||
)
|
)
|
||||||
|
|
@ -889,7 +889,7 @@ class X86InductorQuantizer(Quantizer):
|
||||||
)
|
)
|
||||||
binary_node.meta[QUANT_ANNOTATION_KEY] = (
|
binary_node.meta[QUANT_ANNOTATION_KEY] = (
|
||||||
_X86InductorQuantizationAnnotation(
|
_X86InductorQuantizationAnnotation(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
input_qspec_map=binary_node_input_qspec_map,
|
input_qspec_map=binary_node_input_qspec_map,
|
||||||
# TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
|
# TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
|
||||||
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
|
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
|
||||||
|
|
@ -1097,7 +1097,7 @@ class X86InductorQuantizer(Quantizer):
|
||||||
quantization_config
|
quantization_config
|
||||||
)
|
)
|
||||||
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
|
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
input_qspec_map=binary_node_input_qspec_map,
|
input_qspec_map=binary_node_input_qspec_map,
|
||||||
_annotated=True,
|
_annotated=True,
|
||||||
)
|
)
|
||||||
|
|
@ -1152,7 +1152,7 @@ class X86InductorQuantizer(Quantizer):
|
||||||
quantization_config
|
quantization_config
|
||||||
)
|
)
|
||||||
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
|
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
input_qspec_map=binary_node_input_qspec_map,
|
input_qspec_map=binary_node_input_qspec_map,
|
||||||
_annotated=True,
|
_annotated=True,
|
||||||
_is_output_of_quantized_pattern=True,
|
_is_output_of_quantized_pattern=True,
|
||||||
|
|
@ -1508,7 +1508,7 @@ class X86InductorQuantizer(Quantizer):
|
||||||
has_unary = unary_op is not None
|
has_unary = unary_op is not None
|
||||||
seq_partition = [torch.nn.Linear, binary_op]
|
seq_partition = [torch.nn.Linear, binary_op]
|
||||||
if has_unary:
|
if has_unary:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
seq_partition.append(unary_op)
|
seq_partition.append(unary_op)
|
||||||
fused_partitions = find_sequential_partitions(gm, seq_partition)
|
fused_partitions = find_sequential_partitions(gm, seq_partition)
|
||||||
for fused_partition in fused_partitions:
|
for fused_partition in fused_partitions:
|
||||||
|
|
|
||||||
|
|
@ -376,11 +376,11 @@ def _do_annotate_conv_relu(
|
||||||
input_qspec_map[bias] = get_bias_qspec(quantization_config)
|
input_qspec_map[bias] = get_bias_qspec(quantization_config)
|
||||||
partition.append(bias)
|
partition.append(bias)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
if _is_annotated(partition):
|
if _is_annotated(partition):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
if filter_fn and any(not filter_fn(n) for n in partition):
|
if filter_fn and any(not filter_fn(n) for n in partition):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -391,7 +391,7 @@ def _do_annotate_conv_relu(
|
||||||
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
|
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
|
||||||
_annotated=True,
|
_annotated=True,
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
_mark_nodes_as_annotated(partition)
|
_mark_nodes_as_annotated(partition)
|
||||||
annotated_partitions.append(partition)
|
annotated_partitions.append(partition)
|
||||||
return annotated_partitions
|
return annotated_partitions
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ class Bernoulli(ExponentialFamily):
|
||||||
validate_args (bool, optional): whether to validate arguments, None by default
|
validate_args (bool, optional): whether to validate arguments, None by default
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||||
support = constraints.boolean
|
support = constraints.boolean
|
||||||
has_enumerate_support = True
|
has_enumerate_support = True
|
||||||
|
|
@ -57,12 +57,12 @@ class Bernoulli(ExponentialFamily):
|
||||||
)
|
)
|
||||||
if probs is not None:
|
if probs is not None:
|
||||||
is_scalar = isinstance(probs, _Number)
|
is_scalar = isinstance(probs, _Number)
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
(self.probs,) = broadcast_all(probs)
|
(self.probs,) = broadcast_all(probs)
|
||||||
else:
|
else:
|
||||||
assert logits is not None # helps mypy
|
assert logits is not None # helps mypy
|
||||||
is_scalar = isinstance(logits, _Number)
|
is_scalar = isinstance(logits, _Number)
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
(self.logits,) = broadcast_all(logits)
|
(self.logits,) = broadcast_all(logits)
|
||||||
self._param = self.probs if probs is not None else self.logits
|
self._param = self.probs if probs is not None else self.logits
|
||||||
if is_scalar:
|
if is_scalar:
|
||||||
|
|
@ -140,6 +140,6 @@ class Bernoulli(ExponentialFamily):
|
||||||
def _natural_params(self) -> tuple[Tensor]:
|
def _natural_params(self) -> tuple[Tensor]:
|
||||||
return (torch.logit(self.probs),)
|
return (torch.logit(self.probs),)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def _log_normalizer(self, x):
|
def _log_normalizer(self, x):
|
||||||
return torch.log1p(torch.exp(x))
|
return torch.log1p(torch.exp(x))
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ class Beta(ExponentialFamily):
|
||||||
(often referred to as beta)
|
(often referred to as beta)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {
|
arg_constraints = {
|
||||||
"concentration1": constraints.positive,
|
"concentration1": constraints.positive,
|
||||||
"concentration0": constraints.positive,
|
"concentration0": constraints.positive,
|
||||||
|
|
@ -114,6 +114,6 @@ class Beta(ExponentialFamily):
|
||||||
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
||||||
return (self.concentration1, self.concentration0)
|
return (self.concentration1, self.concentration0)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def _log_normalizer(self, x, y):
|
def _log_normalizer(self, x, y):
|
||||||
return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
|
return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,7 @@ class Binomial(Distribution):
|
||||||
logits (Tensor): Event log-odds
|
logits (Tensor): Event log-odds
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {
|
arg_constraints = {
|
||||||
"total_count": constraints.nonnegative_integer,
|
"total_count": constraints.nonnegative_integer,
|
||||||
"probs": constraints.unit_interval,
|
"probs": constraints.unit_interval,
|
||||||
|
|
@ -67,7 +67,7 @@ class Binomial(Distribution):
|
||||||
if probs is not None:
|
if probs is not None:
|
||||||
(
|
(
|
||||||
self.total_count,
|
self.total_count,
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.probs,
|
self.probs,
|
||||||
) = broadcast_all(total_count, probs)
|
) = broadcast_all(total_count, probs)
|
||||||
self.total_count = self.total_count.type_as(self.probs)
|
self.total_count = self.total_count.type_as(self.probs)
|
||||||
|
|
@ -75,7 +75,7 @@ class Binomial(Distribution):
|
||||||
assert logits is not None # helps mypy
|
assert logits is not None # helps mypy
|
||||||
(
|
(
|
||||||
self.total_count,
|
self.total_count,
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.logits,
|
self.logits,
|
||||||
) = broadcast_all(total_count, logits)
|
) = broadcast_all(total_count, logits)
|
||||||
self.total_count = self.total_count.type_as(self.logits)
|
self.total_count = self.total_count.type_as(self.logits)
|
||||||
|
|
@ -102,7 +102,7 @@ class Binomial(Distribution):
|
||||||
return self._param.new(*args, **kwargs)
|
return self._param.new(*args, **kwargs)
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=True, event_dim=0)
|
@constraints.dependent_property(is_discrete=True, event_dim=0)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def support(self):
|
def support(self):
|
||||||
return constraints.integer_interval(0, self.total_count)
|
return constraints.integer_interval(0, self.total_count)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ class Categorical(Distribution):
|
||||||
logits (Tensor): event log probabilities (unnormalized)
|
logits (Tensor): event log probabilities (unnormalized)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||||
has_enumerate_support = True
|
has_enumerate_support = True
|
||||||
|
|
||||||
|
|
@ -67,14 +67,14 @@ class Categorical(Distribution):
|
||||||
if probs is not None:
|
if probs is not None:
|
||||||
if probs.dim() < 1:
|
if probs.dim() < 1:
|
||||||
raise ValueError("`probs` parameter must be at least one-dimensional.")
|
raise ValueError("`probs` parameter must be at least one-dimensional.")
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.probs = probs / probs.sum(-1, keepdim=True)
|
self.probs = probs / probs.sum(-1, keepdim=True)
|
||||||
else:
|
else:
|
||||||
assert logits is not None # helps mypy
|
assert logits is not None # helps mypy
|
||||||
if logits.dim() < 1:
|
if logits.dim() < 1:
|
||||||
raise ValueError("`logits` parameter must be at least one-dimensional.")
|
raise ValueError("`logits` parameter must be at least one-dimensional.")
|
||||||
# Normalize
|
# Normalize
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
|
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
|
||||||
self._param = self.probs if probs is not None else self.logits
|
self._param = self.probs if probs is not None else self.logits
|
||||||
self._num_events = self._param.size()[-1]
|
self._num_events = self._param.size()[-1]
|
||||||
|
|
@ -102,7 +102,7 @@ class Categorical(Distribution):
|
||||||
return self._param.new(*args, **kwargs)
|
return self._param.new(*args, **kwargs)
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=True, event_dim=0)
|
@constraints.dependent_property(is_discrete=True, event_dim=0)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def support(self):
|
def support(self):
|
||||||
return constraints.integer_interval(0, self._num_events - 1)
|
return constraints.integer_interval(0, self._num_events - 1)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ class Cauchy(Distribution):
|
||||||
scale (float or Tensor): half width at half maximum.
|
scale (float or Tensor): half width at half maximum.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||||
support = constraints.real
|
support = constraints.real
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
|
|
|
||||||
|
|
@ -47,7 +47,7 @@ class ContinuousBernoulli(ExponentialFamily):
|
||||||
https://arxiv.org/abs/1907.06845
|
https://arxiv.org/abs/1907.06845
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||||
support = constraints.unit_interval
|
support = constraints.unit_interval
|
||||||
_mean_carrier_measure = 0
|
_mean_carrier_measure = 0
|
||||||
|
|
@ -66,19 +66,19 @@ class ContinuousBernoulli(ExponentialFamily):
|
||||||
)
|
)
|
||||||
if probs is not None:
|
if probs is not None:
|
||||||
is_scalar = isinstance(probs, _Number)
|
is_scalar = isinstance(probs, _Number)
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
(self.probs,) = broadcast_all(probs)
|
(self.probs,) = broadcast_all(probs)
|
||||||
# validate 'probs' here if necessary as it is later clamped for numerical stability
|
# validate 'probs' here if necessary as it is later clamped for numerical stability
|
||||||
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
|
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
|
||||||
if validate_args is not None:
|
if validate_args is not None:
|
||||||
if not self.arg_constraints["probs"].check(self.probs).all():
|
if not self.arg_constraints["probs"].check(self.probs).all():
|
||||||
raise ValueError("The parameter probs has invalid values")
|
raise ValueError("The parameter probs has invalid values")
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.probs = clamp_probs(self.probs)
|
self.probs = clamp_probs(self.probs)
|
||||||
else:
|
else:
|
||||||
assert logits is not None # helps mypy
|
assert logits is not None # helps mypy
|
||||||
is_scalar = isinstance(logits, _Number)
|
is_scalar = isinstance(logits, _Number)
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
(self.logits,) = broadcast_all(logits)
|
(self.logits,) = broadcast_all(logits)
|
||||||
self._param = self.probs if probs is not None else self.logits
|
self._param = self.probs if probs is not None else self.logits
|
||||||
if is_scalar:
|
if is_scalar:
|
||||||
|
|
@ -234,7 +234,7 @@ class ContinuousBernoulli(ExponentialFamily):
|
||||||
def _natural_params(self) -> tuple[Tensor]:
|
def _natural_params(self) -> tuple[Tensor]:
|
||||||
return (self.logits,)
|
return (self.logits,)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def _log_normalizer(self, x):
|
def _log_normalizer(self, x):
|
||||||
"""computes the log normalizing constant as a function of the natural parameter"""
|
"""computes the log normalizing constant as a function of the natural parameter"""
|
||||||
out_unst_reg = torch.max(
|
out_unst_reg = torch.max(
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ def _Dirichlet_backward(x, concentration, grad_output):
|
||||||
|
|
||||||
class _Dirichlet(Function):
|
class _Dirichlet(Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def forward(ctx, concentration):
|
def forward(ctx, concentration):
|
||||||
x = torch._sample_dirichlet(concentration)
|
x = torch._sample_dirichlet(concentration)
|
||||||
ctx.save_for_backward(x, concentration)
|
ctx.save_for_backward(x, concentration)
|
||||||
|
|
@ -30,7 +30,7 @@ class _Dirichlet(Function):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@once_differentiable
|
@once_differentiable
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
x, concentration = ctx.saved_tensors
|
x, concentration = ctx.saved_tensors
|
||||||
return _Dirichlet_backward(x, concentration, grad_output)
|
return _Dirichlet_backward(x, concentration, grad_output)
|
||||||
|
|
@ -52,7 +52,7 @@ class Dirichlet(ExponentialFamily):
|
||||||
(often referred to as alpha)
|
(often referred to as alpha)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {
|
arg_constraints = {
|
||||||
"concentration": constraints.independent(constraints.positive, 1)
|
"concentration": constraints.independent(constraints.positive, 1)
|
||||||
}
|
}
|
||||||
|
|
@ -133,6 +133,6 @@ class Dirichlet(ExponentialFamily):
|
||||||
def _natural_params(self) -> tuple[Tensor]:
|
def _natural_params(self) -> tuple[Tensor]:
|
||||||
return (self.concentration,)
|
return (self.concentration,)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def _log_normalizer(self, x):
|
def _log_normalizer(self, x):
|
||||||
return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
|
return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ class Exponential(ExponentialFamily):
|
||||||
rate (float or Tensor): rate = 1 / scale of the distribution
|
rate (float or Tensor): rate = 1 / scale of the distribution
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"rate": constraints.positive}
|
arg_constraints = {"rate": constraints.positive}
|
||||||
support = constraints.nonnegative
|
support = constraints.nonnegative
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
|
|
@ -90,6 +90,6 @@ class Exponential(ExponentialFamily):
|
||||||
def _natural_params(self) -> tuple[Tensor]:
|
def _natural_params(self) -> tuple[Tensor]:
|
||||||
return (-self.rate,)
|
return (-self.rate,)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def _log_normalizer(self, x):
|
def _log_normalizer(self, x):
|
||||||
return -torch.log(-x)
|
return -torch.log(-x)
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ class FisherSnedecor(Distribution):
|
||||||
df2 (float or Tensor): degrees of freedom parameter 2
|
df2 (float or Tensor): degrees of freedom parameter 2
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
|
arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
|
||||||
support = constraints.positive
|
support = constraints.positive
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,7 @@ class Gamma(ExponentialFamily):
|
||||||
(often referred to as beta), rate = 1 / scale
|
(often referred to as beta), rate = 1 / scale
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {
|
arg_constraints = {
|
||||||
"concentration": constraints.positive,
|
"concentration": constraints.positive,
|
||||||
"rate": constraints.positive,
|
"rate": constraints.positive,
|
||||||
|
|
@ -110,7 +110,7 @@ class Gamma(ExponentialFamily):
|
||||||
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
||||||
return (self.concentration - 1, -self.rate)
|
return (self.concentration - 1, -self.rate)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def _log_normalizer(self, x, y):
|
def _log_normalizer(self, x, y):
|
||||||
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
|
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ class GeneralizedPareto(Distribution):
|
||||||
concentration (float or Tensor): Concentration parameter of the distribution
|
concentration (float or Tensor): Concentration parameter of the distribution
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {
|
arg_constraints = {
|
||||||
"loc": constraints.real,
|
"loc": constraints.real,
|
||||||
"scale": constraints.positive,
|
"scale": constraints.positive,
|
||||||
|
|
@ -131,7 +131,7 @@ class GeneralizedPareto(Distribution):
|
||||||
concentration = self.concentration
|
concentration = self.concentration
|
||||||
valid = concentration < 0.5
|
valid = concentration < 0.5
|
||||||
safe_conc = torch.where(valid, concentration, 0.25)
|
safe_conc = torch.where(valid, concentration, 0.25)
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
|
result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
|
||||||
return torch.where(valid, result, nan)
|
return torch.where(valid, result, nan)
|
||||||
|
|
||||||
|
|
@ -144,7 +144,7 @@ class GeneralizedPareto(Distribution):
|
||||||
return self.loc
|
return self.loc
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def support(self):
|
def support(self):
|
||||||
lower = self.loc
|
lower = self.loc
|
||||||
upper = torch.where(
|
upper = torch.where(
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ class Geometric(Distribution):
|
||||||
logits (Number, Tensor): the log-odds of sampling `1`.
|
logits (Number, Tensor): the log-odds of sampling `1`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||||
support = constraints.nonnegative_integer
|
support = constraints.nonnegative_integer
|
||||||
|
|
||||||
|
|
@ -59,11 +59,11 @@ class Geometric(Distribution):
|
||||||
"Either `probs` or `logits` must be specified, but not both."
|
"Either `probs` or `logits` must be specified, but not both."
|
||||||
)
|
)
|
||||||
if probs is not None:
|
if probs is not None:
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
(self.probs,) = broadcast_all(probs)
|
(self.probs,) = broadcast_all(probs)
|
||||||
else:
|
else:
|
||||||
assert logits is not None # helps mypy
|
assert logits is not None # helps mypy
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
(self.logits,) = broadcast_all(logits)
|
(self.logits,) = broadcast_all(logits)
|
||||||
probs_or_logits = probs if probs is not None else logits
|
probs_or_logits = probs if probs is not None else logits
|
||||||
if isinstance(probs_or_logits, _Number):
|
if isinstance(probs_or_logits, _Number):
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ class Gumbel(TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.real
|
support = constraints.real
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -32,10 +32,10 @@ class HalfCauchy(TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
arg_constraints = {"scale": constraints.positive}
|
arg_constraints = {"scale": constraints.positive}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.nonnegative
|
support = constraints.nonnegative
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
base_dist: Cauchy
|
base_dist: Cauchy
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -32,10 +32,10 @@ class HalfNormal(TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
arg_constraints = {"scale": constraints.positive}
|
arg_constraints = {"scale": constraints.positive}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.nonnegative
|
support = constraints.nonnegative
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
base_dist: Normal
|
base_dist: Normal
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -91,7 +91,7 @@ class Independent(Distribution, Generic[D]):
|
||||||
return self.base_dist.has_enumerate_support
|
return self.base_dist.has_enumerate_support
|
||||||
|
|
||||||
@constraints.dependent_property
|
@constraints.dependent_property
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def support(self):
|
def support(self):
|
||||||
result = self.base_dist.support
|
result = self.base_dist.support
|
||||||
if self.reinterpreted_batch_ndims:
|
if self.reinterpreted_batch_ndims:
|
||||||
|
|
|
||||||
|
|
@ -38,10 +38,10 @@ class InverseGamma(TransformedDistribution):
|
||||||
"concentration": constraints.positive,
|
"concentration": constraints.positive,
|
||||||
"rate": constraints.positive,
|
"rate": constraints.positive,
|
||||||
}
|
}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.positive
|
support = constraints.positive
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
base_dist: Gamma
|
base_dist: Gamma
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -44,7 +44,7 @@ class Kumaraswamy(TransformedDistribution):
|
||||||
"concentration1": constraints.positive,
|
"concentration1": constraints.positive,
|
||||||
"concentration0": constraints.positive,
|
"concentration0": constraints.positive,
|
||||||
}
|
}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.unit_interval
|
support = constraints.unit_interval
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
|
|
||||||
|
|
@ -67,7 +67,7 @@ class Kumaraswamy(TransformedDistribution):
|
||||||
AffineTransform(loc=1.0, scale=-1.0),
|
AffineTransform(loc=1.0, scale=-1.0),
|
||||||
PowerTransform(exponent=self.concentration1.reciprocal()),
|
PowerTransform(exponent=self.concentration1.reciprocal()),
|
||||||
]
|
]
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
super().__init__(base_dist, transforms, validate_args=validate_args)
|
super().__init__(base_dist, transforms, validate_args=validate_args)
|
||||||
|
|
||||||
def expand(self, batch_shape, _instance=None):
|
def expand(self, batch_shape, _instance=None):
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ class Laplace(Distribution):
|
||||||
scale (float or Tensor): scale of the distribution
|
scale (float or Tensor): scale of the distribution
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||||
support = constraints.real
|
support = constraints.real
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
|
|
|
||||||
|
|
@ -60,7 +60,7 @@ class LKJCholesky(Distribution):
|
||||||
Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
|
Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"concentration": constraints.positive}
|
arg_constraints = {"concentration": constraints.positive}
|
||||||
support = constraints.corr_cholesky
|
support = constraints.corr_cholesky
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,10 +32,10 @@ class LogNormal(TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.positive
|
support = constraints.positive
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
base_dist: Normal
|
base_dist: Normal
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -36,10 +36,10 @@ class LogisticNormal(TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.simplex
|
support = constraints.simplex
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
base_dist: Independent[Normal]
|
base_dist: Independent[Normal]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,7 @@ class LowRankMultivariateNormal(Distribution):
|
||||||
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
|
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {
|
arg_constraints = {
|
||||||
"loc": constraints.real_vector,
|
"loc": constraints.real_vector,
|
||||||
"cov_factor": constraints.independent(constraints.real, 2),
|
"cov_factor": constraints.independent(constraints.real, 2),
|
||||||
|
|
|
||||||
|
|
@ -124,7 +124,7 @@ class MixtureSameFamily(Distribution):
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@constraints.dependent_property
|
@constraints.dependent_property
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def support(self):
|
def support(self):
|
||||||
return MixtureSameFamilyConstraint(self._component_distribution.support)
|
return MixtureSameFamilyConstraint(self._component_distribution.support)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,7 +50,7 @@ class Multinomial(Distribution):
|
||||||
logits (Tensor): event log probabilities (unnormalized)
|
logits (Tensor): event log probabilities (unnormalized)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||||
total_count: int
|
total_count: int
|
||||||
|
|
||||||
|
|
@ -93,7 +93,7 @@ class Multinomial(Distribution):
|
||||||
return self._categorical._new(*args, **kwargs)
|
return self._categorical._new(*args, **kwargs)
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=True, event_dim=1)
|
@constraints.dependent_property(is_discrete=True, event_dim=1)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def support(self):
|
def support(self):
|
||||||
return constraints.multinomial(self.total_count)
|
return constraints.multinomial(self.total_count)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -123,7 +123,7 @@ class MultivariateNormal(Distribution):
|
||||||
the corresponding lower triangular matrices using a Cholesky decomposition.
|
the corresponding lower triangular matrices using a Cholesky decomposition.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {
|
arg_constraints = {
|
||||||
"loc": constraints.real_vector,
|
"loc": constraints.real_vector,
|
||||||
"covariance_matrix": constraints.positive_definite,
|
"covariance_matrix": constraints.positive_definite,
|
||||||
|
|
@ -157,7 +157,7 @@ class MultivariateNormal(Distribution):
|
||||||
"with optional leading batch dimensions"
|
"with optional leading batch dimensions"
|
||||||
)
|
)
|
||||||
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
|
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
|
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
|
||||||
elif covariance_matrix is not None:
|
elif covariance_matrix is not None:
|
||||||
if covariance_matrix.dim() < 2:
|
if covariance_matrix.dim() < 2:
|
||||||
|
|
@ -168,7 +168,7 @@ class MultivariateNormal(Distribution):
|
||||||
batch_shape = torch.broadcast_shapes(
|
batch_shape = torch.broadcast_shapes(
|
||||||
covariance_matrix.shape[:-2], loc.shape[:-1]
|
covariance_matrix.shape[:-2], loc.shape[:-1]
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
|
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
|
||||||
else:
|
else:
|
||||||
assert precision_matrix is not None # helps mypy
|
assert precision_matrix is not None # helps mypy
|
||||||
|
|
@ -180,7 +180,7 @@ class MultivariateNormal(Distribution):
|
||||||
batch_shape = torch.broadcast_shapes(
|
batch_shape = torch.broadcast_shapes(
|
||||||
precision_matrix.shape[:-2], loc.shape[:-1]
|
precision_matrix.shape[:-2], loc.shape[:-1]
|
||||||
)
|
)
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
|
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
|
||||||
self.loc = loc.expand(batch_shape + (-1,))
|
self.loc = loc.expand(batch_shape + (-1,))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ class NegativeBinomial(Distribution):
|
||||||
logits (Tensor): Event log-odds for probabilities of success
|
logits (Tensor): Event log-odds for probabilities of success
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {
|
arg_constraints = {
|
||||||
"total_count": constraints.greater_than_eq(0),
|
"total_count": constraints.greater_than_eq(0),
|
||||||
"probs": constraints.half_open_interval(0.0, 1.0),
|
"probs": constraints.half_open_interval(0.0, 1.0),
|
||||||
|
|
@ -55,7 +55,7 @@ class NegativeBinomial(Distribution):
|
||||||
if probs is not None:
|
if probs is not None:
|
||||||
(
|
(
|
||||||
self.total_count,
|
self.total_count,
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.probs,
|
self.probs,
|
||||||
) = broadcast_all(total_count, probs)
|
) = broadcast_all(total_count, probs)
|
||||||
self.total_count = self.total_count.type_as(self.probs)
|
self.total_count = self.total_count.type_as(self.probs)
|
||||||
|
|
@ -63,7 +63,7 @@ class NegativeBinomial(Distribution):
|
||||||
assert logits is not None # helps mypy
|
assert logits is not None # helps mypy
|
||||||
(
|
(
|
||||||
self.total_count,
|
self.total_count,
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.logits,
|
self.logits,
|
||||||
) = broadcast_all(total_count, logits)
|
) = broadcast_all(total_count, logits)
|
||||||
self.total_count = self.total_count.type_as(self.logits)
|
self.total_count = self.total_count.type_as(self.logits)
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ class Normal(ExponentialFamily):
|
||||||
(often referred to as sigma)
|
(often referred to as sigma)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
|
||||||
support = constraints.real
|
support = constraints.real
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
|
|
@ -89,7 +89,7 @@ class Normal(ExponentialFamily):
|
||||||
if self._validate_args:
|
if self._validate_args:
|
||||||
self._validate_sample(value)
|
self._validate_sample(value)
|
||||||
# compute the variance
|
# compute the variance
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
var = self.scale**2
|
var = self.scale**2
|
||||||
log_scale = (
|
log_scale = (
|
||||||
math.log(self.scale)
|
math.log(self.scale)
|
||||||
|
|
@ -119,6 +119,6 @@ class Normal(ExponentialFamily):
|
||||||
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
def _natural_params(self) -> tuple[Tensor, Tensor]:
|
||||||
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
|
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def _log_normalizer(self, x, y):
|
def _log_normalizer(self, x, y):
|
||||||
return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
|
return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ class OneHotCategorical(Distribution):
|
||||||
logits (Tensor): event log probabilities (unnormalized)
|
logits (Tensor): event log probabilities (unnormalized)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||||
support = constraints.one_hot
|
support = constraints.one_hot
|
||||||
has_enumerate_support = True
|
has_enumerate_support = True
|
||||||
|
|
|
||||||
|
|
@ -39,7 +39,7 @@ class Pareto(TransformedDistribution):
|
||||||
self.scale, self.alpha = broadcast_all(scale, alpha)
|
self.scale, self.alpha = broadcast_all(scale, alpha)
|
||||||
base_dist = Exponential(self.alpha, validate_args=validate_args)
|
base_dist = Exponential(self.alpha, validate_args=validate_args)
|
||||||
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
|
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
super().__init__(base_dist, transforms, validate_args=validate_args)
|
super().__init__(base_dist, transforms, validate_args=validate_args)
|
||||||
|
|
||||||
def expand(
|
def expand(
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,7 @@ class Poisson(ExponentialFamily):
|
||||||
rate (Number, Tensor): the rate parameter
|
rate (Number, Tensor): the rate parameter
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"rate": constraints.nonnegative}
|
arg_constraints = {"rate": constraints.nonnegative}
|
||||||
support = constraints.nonnegative_integer
|
support = constraints.nonnegative_integer
|
||||||
|
|
||||||
|
|
@ -83,6 +83,6 @@ class Poisson(ExponentialFamily):
|
||||||
def _natural_params(self) -> tuple[Tensor]:
|
def _natural_params(self) -> tuple[Tensor]:
|
||||||
return (torch.log(self.rate),)
|
return (torch.log(self.rate),)
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def _log_normalizer(self, x):
|
def _log_normalizer(self, x):
|
||||||
return torch.exp(x)
|
return torch.exp(x)
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,7 @@ class LogitRelaxedBernoulli(Distribution):
|
||||||
(Jang et al., 2017)
|
(Jang et al., 2017)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||||
support = constraints.real
|
support = constraints.real
|
||||||
|
|
||||||
|
|
@ -58,12 +58,12 @@ class LogitRelaxedBernoulli(Distribution):
|
||||||
)
|
)
|
||||||
if probs is not None:
|
if probs is not None:
|
||||||
is_scalar = isinstance(probs, _Number)
|
is_scalar = isinstance(probs, _Number)
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
(self.probs,) = broadcast_all(probs)
|
(self.probs,) = broadcast_all(probs)
|
||||||
else:
|
else:
|
||||||
assert logits is not None # helps mypy
|
assert logits is not None # helps mypy
|
||||||
is_scalar = isinstance(logits, _Number)
|
is_scalar = isinstance(logits, _Number)
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
(self.logits,) = broadcast_all(logits)
|
(self.logits,) = broadcast_all(logits)
|
||||||
self._param = self.probs if probs is not None else self.logits
|
self._param = self.probs if probs is not None else self.logits
|
||||||
if is_scalar:
|
if is_scalar:
|
||||||
|
|
@ -141,10 +141,10 @@ class RelaxedBernoulli(TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.unit_interval
|
support = constraints.unit_interval
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
base_dist: LogitRelaxedBernoulli
|
base_dist: LogitRelaxedBernoulli
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,7 @@ class ExpRelaxedCategorical(Distribution):
|
||||||
(Jang et al., 2017)
|
(Jang et al., 2017)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||||
support = (
|
support = (
|
||||||
constraints.real_vector
|
constraints.real_vector
|
||||||
|
|
@ -128,10 +128,10 @@ class RelaxedOneHotCategorical(TransformedDistribution):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.simplex
|
support = constraints.simplex
|
||||||
has_rsample = True
|
has_rsample = True
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
base_dist: ExpRelaxedCategorical
|
base_dist: ExpRelaxedCategorical
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
|
||||||
|
|
@ -31,7 +31,7 @@ class StudentT(Distribution):
|
||||||
scale (float or Tensor): scale of the distribution
|
scale (float or Tensor): scale of the distribution
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {
|
arg_constraints = {
|
||||||
"df": constraints.positive,
|
"df": constraints.positive,
|
||||||
"loc": constraints.real,
|
"loc": constraints.real,
|
||||||
|
|
|
||||||
|
|
@ -123,7 +123,7 @@ class TransformedDistribution(Distribution):
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False)
|
@constraints.dependent_property(is_discrete=False)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def support(self):
|
def support(self):
|
||||||
if not self.transforms:
|
if not self.transforms:
|
||||||
return self.base_dist.support
|
return self.base_dist.support
|
||||||
|
|
|
||||||
|
|
@ -226,13 +226,13 @@ class _InverseTransform(Transform):
|
||||||
self._inv: Transform = transform # type: ignore[assignment]
|
self._inv: Transform = transform # type: ignore[assignment]
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False)
|
@constraints.dependent_property(is_discrete=False)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def domain(self):
|
def domain(self):
|
||||||
assert self._inv is not None
|
assert self._inv is not None
|
||||||
return self._inv.codomain
|
return self._inv.codomain
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False)
|
@constraints.dependent_property(is_discrete=False)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def codomain(self):
|
def codomain(self):
|
||||||
assert self._inv is not None
|
assert self._inv is not None
|
||||||
return self._inv.domain
|
return self._inv.domain
|
||||||
|
|
@ -302,7 +302,7 @@ class ComposeTransform(Transform):
|
||||||
return self.parts == other.parts
|
return self.parts == other.parts
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False)
|
@constraints.dependent_property(is_discrete=False)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def domain(self):
|
def domain(self):
|
||||||
if not self.parts:
|
if not self.parts:
|
||||||
return constraints.real
|
return constraints.real
|
||||||
|
|
@ -318,7 +318,7 @@ class ComposeTransform(Transform):
|
||||||
return domain
|
return domain
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False)
|
@constraints.dependent_property(is_discrete=False)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def codomain(self):
|
def codomain(self):
|
||||||
if not self.parts:
|
if not self.parts:
|
||||||
return constraints.real
|
return constraints.real
|
||||||
|
|
@ -438,14 +438,14 @@ class IndependentTransform(Transform):
|
||||||
)
|
)
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False)
|
@constraints.dependent_property(is_discrete=False)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def domain(self):
|
def domain(self):
|
||||||
return constraints.independent(
|
return constraints.independent(
|
||||||
self.base_transform.domain, self.reinterpreted_batch_ndims
|
self.base_transform.domain, self.reinterpreted_batch_ndims
|
||||||
)
|
)
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False)
|
@constraints.dependent_property(is_discrete=False)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def codomain(self):
|
def codomain(self):
|
||||||
return constraints.independent(
|
return constraints.independent(
|
||||||
self.base_transform.codomain, self.reinterpreted_batch_ndims
|
self.base_transform.codomain, self.reinterpreted_batch_ndims
|
||||||
|
|
@ -513,12 +513,12 @@ class ReshapeTransform(Transform):
|
||||||
super().__init__(cache_size=cache_size)
|
super().__init__(cache_size=cache_size)
|
||||||
|
|
||||||
@constraints.dependent_property
|
@constraints.dependent_property
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def domain(self):
|
def domain(self):
|
||||||
return constraints.independent(constraints.real, len(self.in_shape))
|
return constraints.independent(constraints.real, len(self.in_shape))
|
||||||
|
|
||||||
@constraints.dependent_property
|
@constraints.dependent_property
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def codomain(self):
|
def codomain(self):
|
||||||
return constraints.independent(constraints.real, len(self.out_shape))
|
return constraints.independent(constraints.real, len(self.out_shape))
|
||||||
|
|
||||||
|
|
@ -772,14 +772,14 @@ class AffineTransform(Transform):
|
||||||
return self._event_dim
|
return self._event_dim
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False)
|
@constraints.dependent_property(is_discrete=False)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def domain(self):
|
def domain(self):
|
||||||
if self.event_dim == 0:
|
if self.event_dim == 0:
|
||||||
return constraints.real
|
return constraints.real
|
||||||
return constraints.independent(constraints.real, self.event_dim)
|
return constraints.independent(constraints.real, self.event_dim)
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False)
|
@constraints.dependent_property(is_discrete=False)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def codomain(self):
|
def codomain(self):
|
||||||
if self.event_dim == 0:
|
if self.event_dim == 0:
|
||||||
return constraints.real
|
return constraints.real
|
||||||
|
|
@ -877,7 +877,7 @@ class CorrCholeskyTransform(Transform):
|
||||||
# apply stick-breaking on the squared values
|
# apply stick-breaking on the squared values
|
||||||
# Note that y = sign(r) * sqrt(z * z1m_cumprod)
|
# Note that y = sign(r) * sqrt(z * z1m_cumprod)
|
||||||
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
|
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
z = r**2
|
z = r**2
|
||||||
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
|
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
|
||||||
# Diagonal elements must be 1.
|
# Diagonal elements must be 1.
|
||||||
|
|
@ -1166,14 +1166,14 @@ class CatTransform(Transform):
|
||||||
return all(t.bijective for t in self.transforms)
|
return all(t.bijective for t in self.transforms)
|
||||||
|
|
||||||
@constraints.dependent_property
|
@constraints.dependent_property
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def domain(self):
|
def domain(self):
|
||||||
return constraints.cat(
|
return constraints.cat(
|
||||||
[t.domain for t in self.transforms], self.dim, self.lengths
|
[t.domain for t in self.transforms], self.dim, self.lengths
|
||||||
)
|
)
|
||||||
|
|
||||||
@constraints.dependent_property
|
@constraints.dependent_property
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def codomain(self):
|
def codomain(self):
|
||||||
return constraints.cat(
|
return constraints.cat(
|
||||||
[t.codomain for t in self.transforms], self.dim, self.lengths
|
[t.codomain for t in self.transforms], self.dim, self.lengths
|
||||||
|
|
@ -1246,12 +1246,12 @@ class StackTransform(Transform):
|
||||||
return all(t.bijective for t in self.transforms)
|
return all(t.bijective for t in self.transforms)
|
||||||
|
|
||||||
@constraints.dependent_property
|
@constraints.dependent_property
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def domain(self):
|
def domain(self):
|
||||||
return constraints.stack([t.domain for t in self.transforms], self.dim)
|
return constraints.stack([t.domain for t in self.transforms], self.dim)
|
||||||
|
|
||||||
@constraints.dependent_property
|
@constraints.dependent_property
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def codomain(self):
|
def codomain(self):
|
||||||
return constraints.stack([t.codomain for t in self.transforms], self.dim)
|
return constraints.stack([t.codomain for t in self.transforms], self.dim)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,7 +79,7 @@ class Uniform(Distribution):
|
||||||
return new
|
return new
|
||||||
|
|
||||||
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
@constraints.dependent_property(is_discrete=False, event_dim=0)
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def support(self):
|
def support(self):
|
||||||
return constraints.interval(self.low, self.high)
|
return constraints.interval(self.low, self.high)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,7 @@ def _log_modified_bessel_fn(x, order=0):
|
||||||
@torch.jit.script_if_tracing
|
@torch.jit.script_if_tracing
|
||||||
def _rejection_sample(loc, concentration, proposal_r, x):
|
def _rejection_sample(loc, concentration, proposal_r, x):
|
||||||
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
|
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
while not done.all():
|
while not done.all():
|
||||||
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
|
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
|
||||||
u1, u2, u3 = u.unbind()
|
u1, u2, u3 = u.unbind()
|
||||||
|
|
@ -101,7 +101,7 @@ def _rejection_sample(loc, concentration, proposal_r, x):
|
||||||
c = concentration * (proposal_r - f)
|
c = concentration * (proposal_r - f)
|
||||||
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
|
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
|
||||||
if accept.any():
|
if accept.any():
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
|
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
|
||||||
done = done | accept
|
done = done | accept
|
||||||
return (x + math.pi + loc) % (2 * math.pi) - math.pi
|
return (x + math.pi + loc) % (2 * math.pi) - math.pi
|
||||||
|
|
@ -125,7 +125,7 @@ class VonMises(Distribution):
|
||||||
:param torch.Tensor concentration: concentration parameter
|
:param torch.Tensor concentration: concentration parameter
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
|
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
|
||||||
support = constraints.real
|
support = constraints.real
|
||||||
has_rsample = False
|
has_rsample = False
|
||||||
|
|
@ -163,10 +163,10 @@ class VonMises(Distribution):
|
||||||
@lazy_property
|
@lazy_property
|
||||||
def _proposal_r(self) -> Tensor:
|
def _proposal_r(self) -> Tensor:
|
||||||
kappa = self._concentration
|
kappa = self._concentration
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
tau = 1 + (1 + 4 * kappa**2).sqrt()
|
tau = 1 + (1 + 4 * kappa**2).sqrt()
|
||||||
rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
|
rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
_proposal_r = (1 + rho**2) / (2 * rho)
|
_proposal_r = (1 + rho**2) / (2 * rho)
|
||||||
# second order Taylor expansion around 0 for small kappa
|
# second order Taylor expansion around 0 for small kappa
|
||||||
_proposal_r_taylor = 1 / kappa + kappa
|
_proposal_r_taylor = 1 / kappa + kappa
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,7 @@ class Weibull(TransformedDistribution):
|
||||||
"scale": constraints.positive,
|
"scale": constraints.positive,
|
||||||
"concentration": constraints.positive,
|
"concentration": constraints.positive,
|
||||||
}
|
}
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
support = constraints.positive
|
support = constraints.positive
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
|
@ -53,7 +53,7 @@ class Weibull(TransformedDistribution):
|
||||||
PowerTransform(exponent=self.concentration_reciprocal),
|
PowerTransform(exponent=self.concentration_reciprocal),
|
||||||
AffineTransform(loc=0, scale=self.scale),
|
AffineTransform(loc=0, scale=self.scale),
|
||||||
]
|
]
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
super().__init__(base_dist, transforms, validate_args=validate_args)
|
super().__init__(base_dist, transforms, validate_args=validate_args)
|
||||||
|
|
||||||
def expand(self, batch_shape, _instance=None):
|
def expand(self, batch_shape, _instance=None):
|
||||||
|
|
|
||||||
|
|
@ -116,13 +116,13 @@ class Wishart(ExponentialFamily):
|
||||||
)
|
)
|
||||||
|
|
||||||
if scale_tril is not None:
|
if scale_tril is not None:
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.scale_tril = param.expand(batch_shape + (-1, -1))
|
self.scale_tril = param.expand(batch_shape + (-1, -1))
|
||||||
elif covariance_matrix is not None:
|
elif covariance_matrix is not None:
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.covariance_matrix = param.expand(batch_shape + (-1, -1))
|
self.covariance_matrix = param.expand(batch_shape + (-1, -1))
|
||||||
elif precision_matrix is not None:
|
elif precision_matrix is not None:
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.precision_matrix = param.expand(batch_shape + (-1, -1))
|
self.precision_matrix = param.expand(batch_shape + (-1, -1))
|
||||||
|
|
||||||
if self.df.lt(event_shape[-1]).any():
|
if self.df.lt(event_shape[-1]).any():
|
||||||
|
|
@ -339,7 +339,7 @@ class Wishart(ExponentialFamily):
|
||||||
p = self._event_shape[-1] # has singleton shape
|
p = self._event_shape[-1] # has singleton shape
|
||||||
return -self.precision_matrix / 2, (nu - p - 1) / 2
|
return -self.precision_matrix / 2, (nu - p - 1) / 2
|
||||||
|
|
||||||
# pyrefly: ignore # bad-override
|
# pyrefly: ignore [bad-override]
|
||||||
def _log_normalizer(self, x, y):
|
def _log_normalizer(self, x, y):
|
||||||
p = self._event_shape[-1]
|
p = self._event_shape[-1]
|
||||||
return (y + (p + 1) / 2) * (
|
return (y + (p + 1) / 2) * (
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user