Fix pyrelfy ignore syntax in distributions and ao (#166248)

Ensures existing pyrefly ignores only ignore the intended error code

pyrefly check
lintrunner

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166248
Approved by: https://github.com/oulgen
This commit is contained in:
Maggie Moss 2025-10-26 22:13:48 +00:00 committed by PyTorch MergeBot
parent a2b6afeac5
commit 154e4d36e9
84 changed files with 251 additions and 251 deletions

View File

@ -620,7 +620,7 @@ class ConvReLU1d(nnqat.Conv1d, nni._FusedModule):
dilation=dilation,
groups=groups,
bias=bias,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
padding_mode=padding_mode,
qconfig=qconfig,
)
@ -821,7 +821,7 @@ class ConvReLU2d(nnqat.Conv2d, nni._FusedModule):
dilation=dilation,
groups=groups,
bias=bias,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
padding_mode=padding_mode,
qconfig=qconfig,
)
@ -1023,7 +1023,7 @@ class ConvReLU3d(nnqat.Conv3d, nni._FusedModule):
dilation=dilation,
groups=groups,
bias=bias,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
padding_mode=padding_mode,
qconfig=qconfig,
)

View File

@ -36,7 +36,7 @@ class LinearReLU(nnqat.Linear, _FusedModule):
torch.Size([128, 30])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
_FLOAT_MODULE = nni.LinearReLU
def __init__(

View File

@ -30,7 +30,7 @@ class LinearReLU(nnqd.Linear):
torch.Size([128, 30])
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
_FLOAT_MODULE = nni.LinearReLU
def __init__(

View File

@ -54,7 +54,7 @@ class ConvReLU1d(nnq.Conv1d):
dilation=dilation,
groups=groups,
bias=bias,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
padding_mode=padding_mode,
device=device,
dtype=dtype,

View File

@ -114,7 +114,7 @@ class _ConvNd(nn.modules.conv._ConvNd):
assert hasattr(cls, "_FLOAT_RELU_MODULE")
relu = cls._FLOAT_RELU_MODULE()
modules.append(relu)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
fused = cls._FLOAT_MODULE(*modules)
fused.train(self.training)
return fused

View File

@ -50,7 +50,7 @@ class Embedding(nn.Embedding):
scale_grad_by_freq,
sparse,
_weight,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
**factory_kwargs,
)
assert qconfig, "qconfig must be provided for QAT module"

View File

@ -170,11 +170,11 @@ class MultiheadAttention(nn.MultiheadAttention):
observed.linear_K.weight = nn.Parameter(other.k_proj_weight)
observed.linear_V.weight = nn.Parameter(other.v_proj_weight)
if other.in_proj_bias is None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
observed.linear_Q.bias = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
observed.linear_K.bias = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
observed.linear_V.bias = None
else:
observed.linear_Q.bias = nn.Parameter(
@ -237,7 +237,7 @@ class MultiheadAttention(nn.MultiheadAttention):
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wQ
if fp.in_proj_bias is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
assert all(bQ == 0)
fp.in_proj_bias[_start:_end] = bQ
@ -245,14 +245,14 @@ class MultiheadAttention(nn.MultiheadAttention):
_end = _start + fp.embed_dim
fp.in_proj_weight[_start:_end, :] = wK
if fp.in_proj_bias is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
assert all(bK == 0)
fp.in_proj_bias[_start:_end] = bK
_start = _end
fp.in_proj_weight[_start:, :] = wV
if fp.in_proj_bias is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
assert all(bV == 0)
fp.in_proj_bias[_start:] = bV
else:
@ -260,11 +260,11 @@ class MultiheadAttention(nn.MultiheadAttention):
fp.k_proj_weight = nn.Parameter(wK)
fp.v_proj_weight = nn.Parameter(wV)
if fp.in_proj_bias is None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.linear_Q.bias = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.linear_K.bias = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.linear_V.bias = None
else:
fp.in_proj_bias[0 : fp.embed_dim] = bQ
@ -472,7 +472,7 @@ class MultiheadAttention(nn.MultiheadAttention):
assert static_v.size(2) == head_dim
v = static_v
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
src_len = k.size(1)
if key_padding_mask is not None:
@ -481,35 +481,35 @@ class MultiheadAttention(nn.MultiheadAttention):
if self.add_zero_attn:
src_len += 1
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
k_zeros = torch.zeros((k.size(0), 1) + k.size()[2:])
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if k.is_quantized:
k_zeros = torch.quantize_per_tensor(
k_zeros,
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
k.q_scale(),
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
k.q_zero_point(),
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
k.dtype,
)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
k = torch.cat([k, k_zeros], dim=1)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
v_zeros = torch.zeros((v.size(0), 1) + k.size()[2:])
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if v.is_quantized:
v_zeros = torch.quantize_per_tensor(
v_zeros,
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
v.q_scale(),
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
v.q_zero_point(),
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
v.dtype,
)
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
v = torch.cat([v, v_zeros], dim=1)
if attn_mask is not None:

View File

@ -376,7 +376,7 @@ class _LSTMLayer(torch.nn.Module):
bidirectional,
split_gates=split_gates,
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
layer.qconfig = getattr(other, "qconfig", qconfig)
wi = getattr(other, f"weight_ih_l{layer_idx}")
wh = getattr(other, f"weight_hh_l{layer_idx}")
@ -455,7 +455,7 @@ class LSTM(torch.nn.Module):
if (
not isinstance(dropout, numbers.Number)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
or not 0 <= dropout <= 1
or isinstance(dropout, bool)
):
@ -464,7 +464,7 @@ class LSTM(torch.nn.Module):
"representing the probability of an element being "
"zeroed"
)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if dropout > 0:
warnings.warn(
"dropout option for quantizable LSTM is ignored. "
@ -578,7 +578,7 @@ class LSTM(torch.nn.Module):
other.bidirectional,
split_gates=split_gates,
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
observed.qconfig = getattr(other, "qconfig", qconfig)
for idx in range(other.num_layers):
observed.layers[idx] = _LSTMLayer.from_float(

View File

@ -74,7 +74,7 @@ class Conv1d(nnq.Conv1d):
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _single(kernel_size)
stride = _single(stride)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
padding = padding if isinstance(padding, str) else _single(padding)
dilation = _single(dilation)

View File

@ -119,9 +119,9 @@ class Linear(nnq.Linear):
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
if type(mod) is nni.LinearReLU:
mod = mod[0]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if mod.qconfig is not None and mod.qconfig.weight is not None:
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
weight_observer = mod.qconfig.weight()
else:
# We have the circular import issues if we import the qconfig in the beginning of this file:
@ -145,7 +145,7 @@ class Linear(nnq.Linear):
"Unsupported dtype specified for dynamic quantized Linear!"
)
qlinear = cls(mod.in_features, mod.out_features, dtype=dtype)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
qlinear.set_weight_bias(qweight, mod.bias)
return qlinear

View File

@ -522,7 +522,7 @@ class LSTM(RNNBase):
>>> output, (hn, cn) = rnn(input, (h0, c0))
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
_FLOAT_MODULE = nn.LSTM
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}
@ -808,7 +808,7 @@ class GRU(RNNBase):
>>> output, hn = rnn(input, h0)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
_FLOAT_MODULE = nn.GRU
__overloads__ = {"forward": ["forward_packed", "forward_tensor"]}

View File

@ -67,9 +67,9 @@ class Hardswish(torch.nn.Hardswish):
def __init__(self, scale, zero_point, device=None, dtype=None):
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__()
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -140,9 +140,9 @@ class LeakyReLU(torch.nn.LeakyReLU):
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(negative_slope, inplace)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -230,7 +230,7 @@ class Softmax(torch.nn.Softmax):
class MultiheadAttention(torch.ao.nn.quantizable.MultiheadAttention):
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
_FLOAT_MODULE = torch.ao.nn.quantizable.MultiheadAttention
def _get_name(self):

View File

@ -12,9 +12,9 @@ class _BatchNorm(torch.nn.modules.batchnorm._BatchNorm):
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
super().__init__(num_features, eps, momentum, True, True, **factory_kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("scale", torch.tensor(1.0, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("zero_point", torch.tensor(0, **factory_kwargs))
@staticmethod

View File

@ -408,7 +408,7 @@ class Conv1d(_ConvNd):
factory_kwargs = {"device": device, "dtype": dtype}
kernel_size = _single(kernel_size)
stride = _single(stride)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
padding = padding if isinstance(padding, str) else _single(padding)
dilation = _single(dilation)

View File

@ -310,7 +310,7 @@ class Linear(WeightedQuantizedModule):
# the type mismatch in assignment. Also, mypy has an issue with
# iterables not being implemented, so we are ignoring those too.
if not isinstance(cls._FLOAT_MODULE, Iterable):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
cls._FLOAT_MODULE = [cls._FLOAT_MODULE]
supported_modules = ", ".join(
[float_mod.__name__ for float_mod in cls._FLOAT_MODULE]

View File

@ -37,14 +37,14 @@ class LayerNorm(torch.nn.LayerNorm):
normalized_shape,
eps=eps,
elementwise_affine=elementwise_affine,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
**factory_kwargs,
)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -116,9 +116,9 @@ class GroupNorm(torch.nn.GroupNorm):
super().__init__(num_groups, num_channels, eps, affine, **factory_kwargs)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -180,9 +180,9 @@ class InstanceNorm1d(torch.nn.InstanceNorm1d):
)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -249,9 +249,9 @@ class InstanceNorm2d(torch.nn.InstanceNorm2d):
)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):
@ -318,9 +318,9 @@ class InstanceNorm3d(torch.nn.InstanceNorm3d):
)
self.weight = weight
self.bias = bias
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("scale", torch.tensor(scale, **factory_kwargs))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.register_buffer("zero_point", torch.tensor(zero_point, **factory_kwargs))
def forward(self, input):

View File

@ -141,7 +141,7 @@ class Conv2d(_ConvNd, nn.Conv2d):
dilation,
groups,
bias,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
padding_mode,
device,
dtype,
@ -206,7 +206,7 @@ class Conv3d(_ConvNd, nn.Conv3d):
dilation,
groups,
bias,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
padding_mode,
device,
dtype,
@ -383,7 +383,7 @@ class ConvTranspose2d(_ConvTransposeNd, nn.ConvTranspose2d):
groups,
bias,
dilation,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
padding_mode,
device,
dtype,
@ -465,7 +465,7 @@ class ConvTranspose3d(_ConvTransposeNd, nn.ConvTranspose3d):
groups,
bias,
dilation,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
padding_mode,
device,
dtype,

View File

@ -664,7 +664,7 @@ class LSTM(RNNBase):
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
batch_sizes,
sorted_indices,
unsorted_indices,
@ -828,7 +828,7 @@ class GRU(RNNBase):
if isinstance(orig_input, PackedSequence):
output_packed = PackedSequence(
output,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
batch_sizes,
sorted_indices,
unsorted_indices,

View File

@ -42,7 +42,7 @@ class Embedding(nn.Embedding, ReferenceQuantizedModule):
scale_grad_by_freq,
sparse,
_weight,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
device,
dtype,
)

View File

@ -18,7 +18,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
"scale": 1.0,
"zero_point": 0,
}
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.weight_qscheme: torch.qscheme = weight_qparams["qscheme"]
self.weight_dtype = weight_qparams["dtype"]
assert self.weight_qscheme in [
@ -81,14 +81,14 @@ class ReferenceQuantizedModule(torch.nn.Module):
self.register_buffer(
"weight_axis", torch.tensor(0, dtype=torch.int, device=device)
)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.is_decomposed: bool = weight_qparams.get("is_decomposed", False)
# store weight_axis as weight_axis_int due to some constraints of torchdynamo.export
# for capturing `.item` operations
self.weight_axis_int: int = self.weight_axis.item() # type: ignore[operator, assignment]
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.weight_quant_min: typing.Optional[int] = weight_qparams.get("quant_min")
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.weight_quant_max: typing.Optional[int] = weight_qparams.get("quant_max")
def get_weight(self):
@ -105,7 +105,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
return _quantize_and_dequantize_weight_decomposed(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
@ -117,7 +117,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
return _quantize_and_dequantize_weight(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
@ -133,7 +133,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
return _quantize_weight_decomposed(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,
@ -145,7 +145,7 @@ class ReferenceQuantizedModule(torch.nn.Module):
return _quantize_weight(
self.weight, # type: ignore[arg-type]
self.weight_qscheme,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.weight_dtype,
self.weight_scale,
self.weight_zero_point,

View File

@ -151,9 +151,9 @@ class Linear(torch.nn.Module):
assert hasattr(mod, "qconfig"), "Input float module must have qconfig defined"
if type(mod) is nni.LinearReLU:
mod = mod[0]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if mod.qconfig is not None and mod.qconfig.weight is not None:
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
weight_observer = mod.qconfig.weight()
else:
# We have the circular import issues if we import the qconfig in the beginning of this file:
@ -187,6 +187,6 @@ class Linear(torch.nn.Module):
col_block_size,
dtype=dtype,
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
qlinear.set_weight_bias(qweight, mod.bias, row_block_size, col_block_size)
return qlinear

View File

@ -959,7 +959,7 @@ def create_a_shadows_b(
if should_log_inputs:
# skip the input logger when inserting a dtype cast
if isinstance(prev_node_c, Node):
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
prev_node_c = get_normalized_nth_input(node_c, gm_b, 0)
elif isinstance(prev_node_c, list):
prev_node_c = [
@ -968,7 +968,7 @@ def create_a_shadows_b(
]
dtype_cast_node = _insert_dtype_cast_after_node(
subgraph_a.start_node,
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
node_c,
prev_node_c,
gm_a,
@ -1052,7 +1052,7 @@ def create_a_shadows_b(
if num_non_param_args_node_a == 2:
# node_c_second_non_param_arg = node_c.args[1]
node_c_second_non_param_arg = get_normalized_nth_input(
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
node_c,
gm_b,
1,
@ -1063,7 +1063,7 @@ def create_a_shadows_b(
subgraph_a,
gm_a,
gm_b,
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
node_c.name + "_shadow_copy_",
)
env_c[node_a_shadows_c.name] = node_a_shadows_c
@ -1086,19 +1086,19 @@ def create_a_shadows_b(
cur_node = node_a_shadows_c
while get_normalized_nth_input(cur_node, gm_b, 0) != input_logger: # type: ignore[possibly-undefined]
cur_node = get_normalized_nth_input(cur_node, gm_b, 0) # type: ignore[assignment]
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if isinstance(input_logger, Node):
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
input_logger_mod = getattr(gm_b, input_logger.name)
input_logger_mod.ref_node_name = cur_node.name
else:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if not isinstance(input_logger, list):
raise AssertionError(
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
f"Expected list, got {type(input_logger)}"
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
for input_logger_inner in input_logger:
input_logger_mod = getattr(gm_b, input_logger_inner.name)
input_logger_mod.ref_node_name = cur_node.name

View File

@ -419,7 +419,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
target2,
) in _lower_to_native_backend.STATIC_LOWER_FUNCTIONAL_MAP.items():
new_connections.append((source, target1))
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
new_connections.append((source, target2))
for source_to_target in (
@ -428,7 +428,7 @@ def get_base_name_to_sets_of_related_ops() -> dict[str, set[NSNodeTargetType]]:
quantization_mappings.DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS,
):
for source, target in source_to_target.items(): # type:ignore[assignment]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
new_connections.append((source, target))
#

View File

@ -94,10 +94,10 @@ class OutputProp:
)
if isinstance(result, torch.Tensor): # type: ignore[possibly-undefined]
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
node.traced_result = result
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
env[node.name] = result
return None
@ -403,10 +403,10 @@ def create_submodule_from_subgraph(
cur_name_idx += 1
setattr(gm, mod_name, new_arg)
new_arg_placeholder = gm.placeholder(mod_name) # type: ignore[operator]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
cur_args_copy.append(new_arg_placeholder)
elif isinstance(arg, (float, int, torch.dtype)):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
cur_args_copy.append(arg)
else:
raise AssertionError(f"arg of type {type(arg)} not handled yet")
@ -818,7 +818,7 @@ def create_add_loggers_graph(
model,
cur_subgraph_idx,
match_name,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
maybe_subgraph,
[qconfig_mapping],
[node_name_to_qconfig],
@ -879,7 +879,7 @@ def create_add_loggers_graph(
cur_node_orig = first_node
cur_node_copy = None
first_node_copy = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
while cur_node_orig in subgraph_to_use:
# TODO(future PR): make this support all possible args/kwargs
if cur_node_orig is first_node:

View File

@ -496,7 +496,7 @@ def compute_normalized_l2_error(x: torch.Tensor, y: torch.Tensor) -> torch.Tenso
Return:
float or tuple of floats
"""
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return torch.sqrt(((x - y) ** 2).sum() / (x**2).sum())

View File

@ -23,17 +23,17 @@ class SparseDLRM(DLRM_Net):
super().__init__(**args)
def forward(self, dense_x, lS_o, lS_i):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
x = self.apply_mlp(dense_x, self.bot_l) # dense features
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
ly = self.apply_emb(lS_o, lS_i, self.emb_l, self.v_W_l) # apply embedding bag
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
z = self.interact_features(x, ly)
z = z.to_sparse_coo()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
z = torch.mm(z, self.top_l[0].weight.T).add(self.top_l[0].bias)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for layer in self.top_l[1:]:
z = layer(z)

View File

@ -27,7 +27,7 @@ def run_forward(model, **batch):
model(X, lS_o, lS_i)
end = time.time()
time_taken = end - start
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
time_list.append(time_taken)
avg_time = np.mean(time_list[1:])
return avg_time

View File

@ -72,7 +72,7 @@ class FPGMPruner(BaseStructuredSparsifier):
dist_matrix = self.dist_fn(t_flatten)
# more similar with other filter indicates large in the sum of row
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
distance = torch.sum(torch.abs(dist_matrix), 1)
return distance

View File

@ -260,7 +260,7 @@ class BaseStructuredSparsifier(BaseSparsifier):
module.register_parameter(
"_bias", nn.Parameter(module.bias.detach())
)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
module.bias = None
module.prune_bias = prune_bias

View File

@ -97,7 +97,7 @@ def _propagate_module_bias(module: nn.Module, mask: Tensor) -> Optional[Tensor]:
if module.bias is not None:
module.bias = nn.Parameter(cast(Tensor, module.bias)[mask])
elif getattr(module, "_bias", None) is not None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
module.bias = nn.Parameter(cast(Tensor, module._bias)[mask])
# get pruned biases to propagate to subsequent layer
@ -127,7 +127,7 @@ def _prune_linear_helper(linear: nn.Linear) -> Tensor:
linear.out_features = linear.weight.shape[0]
_remove_bias_handles(linear)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return mask
@ -186,7 +186,7 @@ def _prune_conv2d_helper(conv2d: nn.Conv2d) -> Tensor:
conv2d.out_channels = conv2d.weight.shape[0]
_remove_bias_handles(conv2d)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return mask
@ -207,7 +207,7 @@ def prune_conv2d_padded(conv2d_1: nn.Conv2d) -> None:
new_bias = torch.zeros(conv2d_1.bias.shape)
new_bias[mask] = conv2d_1.bias[mask] # type: ignore[possibly-undefined]
# adjusted bias that to keep in conv2d_1
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
new_bias[~mask] = cast(Tensor, conv2d_1._bias)[~mask]
# pruned biases that are kept instead of propagated
conv2d_1.bias = nn.Parameter(new_bias)

View File

@ -170,7 +170,7 @@ class BaseSparsifier(abc.ABC):
self.make_config_from_model(model)
# TODO: Remove the configuration by reference ('module')
# pyrefly: ignore # not-iterable
# pyrefly: ignore [not-iterable]
for module_config in self.config:
assert isinstance(module_config, dict), (
"config elements should be dicts not modules i.e.:"

View File

@ -51,7 +51,7 @@ def swap_module(
new_mod.register_forward_hook(hook_fn)
# respect device affinity when swapping modules
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
devices = {p.device for p in chain(mod.parameters(), mod.buffers())}
assert len(devices) <= 1, (
f"swap_module only works with cpu or single-device CUDA modules, but got devices {devices}"

View File

@ -235,7 +235,7 @@ class WeightNormSparsifier(BaseSparsifier):
ww = self.norm_fn(getattr(module, tensor_name))
tensor_mask = self._make_tensor_mask(
data=ww,
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
input_shape=ww.shape,
sparsity_level=sparsity_level,
sparse_block_shape=sparse_block_shape,

View File

@ -25,7 +25,7 @@ from .pt2e.export_utils import (
_move_exported_model_to_train as move_exported_model_to_train,
)
# pyrefly: ignore # deprecated
# pyrefly: ignore [deprecated]
from .qconfig import * # noqa: F403
from .qconfig_mapping import * # noqa: F403
from .quant_type import * # noqa: F403

View File

@ -185,9 +185,9 @@ class FakeQuantize(FakeQuantizeBase):
dtype = getattr(getattr(observer, "p", {}), "keywords", {}).get(
"dtype", dtype
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
assert torch.iinfo(dtype).min <= quant_min, "quant_min out of bound"
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
assert quant_max <= torch.iinfo(dtype).max, "quant_max out of bound"
observer_kwargs.update({"quant_min": quant_min, "quant_max": quant_max})
observer_kwargs["is_dynamic"] = is_dynamic

View File

@ -1149,7 +1149,7 @@ quantized_decomposed_lib.define(
class FakeQuantPerChannel(torch.autograd.Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, input, scales, zero_points, axis, quant_min, quant_max):
if scales.dtype != torch.float32:
scales = scales.to(torch.float32)
@ -1172,7 +1172,7 @@ class FakeQuantPerChannel(torch.autograd.Function):
return out
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, gy):
(mask,) = ctx.saved_tensors
return gy * mask, None, None, None, None, None

View File

@ -248,7 +248,7 @@ def calculate_equalization_scale(
class EqualizationQConfig(
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
namedtuple("EqualizationQConfig", ["input_activation", "weight"])
):
"""
@ -463,7 +463,7 @@ def maybe_get_next_equalization_scale(
In this case, the node given is linear1 and we want to locate the InputEqObs.
"""
next_inp_eq_obs = maybe_get_next_input_eq_obs(node, modules)
# pyrefly: ignore # invalid-argument
# pyrefly: ignore [invalid-argument]
if next_inp_eq_obs:
if (
next_inp_eq_obs.equalization_scale.nelement() == 1
@ -827,7 +827,7 @@ def convert_eq_obs(
scale_weight_node(
node,
modules,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
equalization_scale,
maybe_next_equalization_scale,
)
@ -836,7 +836,7 @@ def convert_eq_obs(
node,
model,
modules,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
equalization_scale,
maybe_next_equalization_scale,
)

View File

@ -223,7 +223,7 @@ class ModelReportVisualizer:
feature_val = feature_val.item()
# we add to our list of values
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
tensor_table_row.append(feature_val)
tensor_table.append(tensor_table_row)
@ -284,7 +284,7 @@ class ModelReportVisualizer:
feature_val = feature_val.item()
# add value to channel specific row
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
new_channel_row.append(feature_val)
# add to table and increment row index counter

View File

@ -166,7 +166,7 @@ def _create_obs_or_fq_from_qspec(
}
edge_or_nodes = quantization_spec.derived_from
obs_or_fqs = [obs_or_fq_map[k] for k in edge_or_nodes]
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
kwargs["obs_or_fqs"] = obs_or_fqs
return _DerivedObserverOrFakeQuantize.with_args(**kwargs)()
elif isinstance(quantization_spec, FixedQParamsQuantizationSpec):
@ -2088,11 +2088,11 @@ def prepare(
root_node_getter_mapping = get_fusion_pattern_to_root_node_getter(backend_config)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_update_qconfig_for_fusion(model, qconfig_mapping)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_update_qconfig_for_fusion(model, _equalization_config)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
flattened_qconfig_dict = _get_flattened_qconfig_dict(qconfig_mapping)
# TODO: support regex as well
propagate_qconfig_(model, flattened_qconfig_dict, prepare_custom_config.to_dict())
@ -2100,7 +2100,7 @@ def prepare(
if is_qat:
module_to_qat_module = get_module_to_qat_module(backend_config)
_qat_swap_modules(model, module_to_qat_module)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_update_qconfig_for_qat(qconfig_mapping, backend_config)
# mapping from fully qualified module name to module instance
@ -2117,7 +2117,7 @@ def prepare(
model,
named_modules,
model.graph,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_equalization_config,
node_name_to_scope,
)
@ -2125,7 +2125,7 @@ def prepare(
model,
named_modules,
model.graph,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
qconfig_mapping,
node_name_to_scope,
)
@ -2187,7 +2187,7 @@ def prepare(
node_name_to_scope,
prepare_custom_config,
equalization_node_name_to_qconfig,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
qconfig_mapping,
is_qat,
observed_node_names,

View File

@ -721,7 +721,7 @@ def _maybe_get_custom_module_lstm_from_node_arg(
a = a.args[0][0] # type: ignore[assignment,index]
else:
a = a.args[0] # type: ignore[assignment]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return a
all_match_patterns = [

View File

@ -281,12 +281,12 @@ class UniformQuantizationObserverBase(ObserverBase):
)
self.has_customized_qrange = (quant_min is not None) and (quant_max is not None)
if self.has_customized_qrange:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
validate_qmin_qmax(quant_min, quant_max)
self.quant_min, self.quant_max = calculate_qmin_qmax(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
quant_min,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
quant_max,
self.has_customized_qrange,
self.dtype,

View File

@ -83,7 +83,7 @@ __all__ = [
]
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class QConfig(namedtuple("QConfig", ["activation", "weight"])):
"""
Describes how to quantize a layer or a part of the network by providing
@ -121,7 +121,7 @@ class QConfig(namedtuple("QConfig", ["activation", "weight"])):
"`QConfigDynamic` is going to be deprecated in PyTorch 1.12, please use `QConfig` instead",
category=FutureWarning,
)
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
class QConfigDynamic(namedtuple("QConfigDynamic", ["activation", "weight"])):
"""
Describes how to dynamically quantize a layer or a part of the network by providing

View File

@ -418,7 +418,7 @@ class X86InductorQuantizer(Quantizer):
# As we use `_need_skip_config` to skip all invalid configurations,
# we can safely assume that the all existing non-None configurations
# have the same quantization mode.
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for qconfig in (
list(self.module_name_qconfig.values())
+ list(self.operator_type_qconfig.values())
@ -818,7 +818,7 @@ class X86InductorQuantizer(Quantizer):
)
binary_node.meta[QUANT_ANNOTATION_KEY] = (
_X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
)
@ -889,7 +889,7 @@ class X86InductorQuantizer(Quantizer):
)
binary_node.meta[QUANT_ANNOTATION_KEY] = (
_X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
input_qspec_map=binary_node_input_qspec_map,
# TODO<leslie> Remove the annotate of output in QAT when qat util support pattern matcher.
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
@ -1097,7 +1097,7 @@ class X86InductorQuantizer(Quantizer):
quantization_config
)
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
)
@ -1152,7 +1152,7 @@ class X86InductorQuantizer(Quantizer):
quantization_config
)
binary_node.meta[QUANT_ANNOTATION_KEY] = _X86InductorQuantizationAnnotation(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
input_qspec_map=binary_node_input_qspec_map,
_annotated=True,
_is_output_of_quantized_pattern=True,
@ -1508,7 +1508,7 @@ class X86InductorQuantizer(Quantizer):
has_unary = unary_op is not None
seq_partition = [torch.nn.Linear, binary_op]
if has_unary:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
seq_partition.append(unary_op)
fused_partitions = find_sequential_partitions(gm, seq_partition)
for fused_partition in fused_partitions:

View File

@ -376,11 +376,11 @@ def _do_annotate_conv_relu(
input_qspec_map[bias] = get_bias_qspec(quantization_config)
partition.append(bias)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if _is_annotated(partition):
continue
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if filter_fn and any(not filter_fn(n) for n in partition):
continue
@ -391,7 +391,7 @@ def _do_annotate_conv_relu(
output_qspec=get_output_act_qspec(quantization_config), # type: ignore[arg-type]
_annotated=True,
)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
_mark_nodes_as_annotated(partition)
annotated_partitions.append(partition)
return annotated_partitions

View File

@ -39,7 +39,7 @@ class Bernoulli(ExponentialFamily):
validate_args (bool, optional): whether to validate arguments, None by default
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.boolean
has_enumerate_support = True
@ -57,12 +57,12 @@ class Bernoulli(ExponentialFamily):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@ -140,6 +140,6 @@ class Bernoulli(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (torch.logit(self.probs),)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
return torch.log1p(torch.exp(x))

View File

@ -31,7 +31,7 @@ class Beta(ExponentialFamily):
(often referred to as beta)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"concentration1": constraints.positive,
"concentration0": constraints.positive,
@ -114,6 +114,6 @@ class Beta(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration1, self.concentration0)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y):
return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)

View File

@ -45,7 +45,7 @@ class Binomial(Distribution):
logits (Tensor): Event log-odds
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"total_count": constraints.nonnegative_integer,
"probs": constraints.unit_interval,
@ -67,7 +67,7 @@ class Binomial(Distribution):
if probs is not None:
(
self.total_count,
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.probs,
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
@ -75,7 +75,7 @@ class Binomial(Distribution):
assert logits is not None # helps mypy
(
self.total_count,
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.logits,
) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)
@ -102,7 +102,7 @@ class Binomial(Distribution):
return self._param.new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=0)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
return constraints.integer_interval(0, self.total_count)

View File

@ -50,7 +50,7 @@ class Categorical(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
has_enumerate_support = True
@ -67,14 +67,14 @@ class Categorical(Distribution):
if probs is not None:
if probs.dim() < 1:
raise ValueError("`probs` parameter must be at least one-dimensional.")
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.probs = probs / probs.sum(-1, keepdim=True)
else:
assert logits is not None # helps mypy
if logits.dim() < 1:
raise ValueError("`logits` parameter must be at least one-dimensional.")
# Normalize
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.logits = logits - logits.logsumexp(dim=-1, keepdim=True)
self._param = self.probs if probs is not None else self.logits
self._num_events = self._param.size()[-1]
@ -102,7 +102,7 @@ class Categorical(Distribution):
return self._param.new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=0)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
return constraints.integer_interval(0, self._num_events - 1)

View File

@ -31,7 +31,7 @@ class Cauchy(Distribution):
scale (float or Tensor): half width at half maximum.
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True

View File

@ -47,7 +47,7 @@ class ContinuousBernoulli(ExponentialFamily):
https://arxiv.org/abs/1907.06845
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.unit_interval
_mean_carrier_measure = 0
@ -66,19 +66,19 @@ class ContinuousBernoulli(ExponentialFamily):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
# validate 'probs' here if necessary as it is later clamped for numerical stability
# close to 0 and 1, later on; otherwise the clamped 'probs' would always pass
if validate_args is not None:
if not self.arg_constraints["probs"].check(self.probs).all():
raise ValueError("The parameter probs has invalid values")
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.probs = clamp_probs(self.probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@ -234,7 +234,7 @@ class ContinuousBernoulli(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (self.logits,)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
"""computes the log normalizing constant as a function of the natural parameter"""
out_unst_reg = torch.max(

View File

@ -22,7 +22,7 @@ def _Dirichlet_backward(x, concentration, grad_output):
class _Dirichlet(Function):
@staticmethod
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def forward(ctx, concentration):
x = torch._sample_dirichlet(concentration)
ctx.save_for_backward(x, concentration)
@ -30,7 +30,7 @@ class _Dirichlet(Function):
@staticmethod
@once_differentiable
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def backward(ctx, grad_output):
x, concentration = ctx.saved_tensors
return _Dirichlet_backward(x, concentration, grad_output)
@ -52,7 +52,7 @@ class Dirichlet(ExponentialFamily):
(often referred to as alpha)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"concentration": constraints.independent(constraints.positive, 1)
}
@ -133,6 +133,6 @@ class Dirichlet(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (self.concentration,)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
return x.lgamma().sum(-1) - torch.lgamma(x.sum(-1))

View File

@ -27,7 +27,7 @@ class Exponential(ExponentialFamily):
rate (float or Tensor): rate = 1 / scale of the distribution
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"rate": constraints.positive}
support = constraints.nonnegative
has_rsample = True
@ -90,6 +90,6 @@ class Exponential(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (-self.rate,)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
return -torch.log(-x)

View File

@ -29,7 +29,7 @@ class FisherSnedecor(Distribution):
df2 (float or Tensor): degrees of freedom parameter 2
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"df1": constraints.positive, "df2": constraints.positive}
support = constraints.positive
has_rsample = True

View File

@ -34,7 +34,7 @@ class Gamma(ExponentialFamily):
(often referred to as beta), rate = 1 / scale
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"concentration": constraints.positive,
"rate": constraints.positive,
@ -110,7 +110,7 @@ class Gamma(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.concentration - 1, -self.rate)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y):
return torch.lgamma(x + 1) + (x + 1) * torch.log(-y.reciprocal())

View File

@ -35,7 +35,7 @@ class GeneralizedPareto(Distribution):
concentration (float or Tensor): Concentration parameter of the distribution
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"loc": constraints.real,
"scale": constraints.positive,
@ -131,7 +131,7 @@ class GeneralizedPareto(Distribution):
concentration = self.concentration
valid = concentration < 0.5
safe_conc = torch.where(valid, concentration, 0.25)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
result = self.scale**2 / ((1 - safe_conc) ** 2 * (1 - 2 * safe_conc))
return torch.where(valid, result, nan)
@ -144,7 +144,7 @@ class GeneralizedPareto(Distribution):
return self.loc
@constraints.dependent_property(is_discrete=False, event_dim=0)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
lower = self.loc
upper = torch.where(

View File

@ -44,7 +44,7 @@ class Geometric(Distribution):
logits (Number, Tensor): the log-odds of sampling `1`.
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.nonnegative_integer
@ -59,11 +59,11 @@ class Geometric(Distribution):
"Either `probs` or `logits` must be specified, but not both."
)
if probs is not None:
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
probs_or_logits = probs if probs is not None else logits
if isinstance(probs_or_logits, _Number):

View File

@ -32,7 +32,7 @@ class Gumbel(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.real
def __init__(

View File

@ -32,10 +32,10 @@ class HalfCauchy(TransformedDistribution):
"""
arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.nonnegative
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Cauchy
def __init__(

View File

@ -32,10 +32,10 @@ class HalfNormal(TransformedDistribution):
"""
arg_constraints = {"scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.nonnegative
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Normal
def __init__(

View File

@ -91,7 +91,7 @@ class Independent(Distribution, Generic[D]):
return self.base_dist.has_enumerate_support
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
result = self.base_dist.support
if self.reinterpreted_batch_ndims:

View File

@ -38,10 +38,10 @@ class InverseGamma(TransformedDistribution):
"concentration": constraints.positive,
"rate": constraints.positive,
}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.positive
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Gamma
def __init__(

View File

@ -44,7 +44,7 @@ class Kumaraswamy(TransformedDistribution):
"concentration1": constraints.positive,
"concentration0": constraints.positive,
}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.unit_interval
has_rsample = True
@ -67,7 +67,7 @@ class Kumaraswamy(TransformedDistribution):
AffineTransform(loc=1.0, scale=-1.0),
PowerTransform(exponent=self.concentration1.reciprocal()),
]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

View File

@ -28,7 +28,7 @@ class Laplace(Distribution):
scale (float or Tensor): scale of the distribution
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True

View File

@ -60,7 +60,7 @@ class LKJCholesky(Distribution):
Journal of Multivariate Analysis. 100. 10.1016/j.jmva.2009.04.008
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"concentration": constraints.positive}
support = constraints.corr_cholesky

View File

@ -32,10 +32,10 @@ class LogNormal(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.positive
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Normal
def __init__(

View File

@ -36,10 +36,10 @@ class LogisticNormal(TransformedDistribution):
"""
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.simplex
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: Independent[Normal]
def __init__(

View File

@ -86,7 +86,7 @@ class LowRankMultivariateNormal(Distribution):
capacitance = I + cov_factor.T @ inv(cov_diag) @ cov_factor
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"loc": constraints.real_vector,
"cov_factor": constraints.independent(constraints.real, 2),

View File

@ -124,7 +124,7 @@ class MixtureSameFamily(Distribution):
return new
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
return MixtureSameFamilyConstraint(self._component_distribution.support)

View File

@ -50,7 +50,7 @@ class Multinomial(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
total_count: int
@ -93,7 +93,7 @@ class Multinomial(Distribution):
return self._categorical._new(*args, **kwargs)
@constraints.dependent_property(is_discrete=True, event_dim=1)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
return constraints.multinomial(self.total_count)

View File

@ -123,7 +123,7 @@ class MultivariateNormal(Distribution):
the corresponding lower triangular matrices using a Cholesky decomposition.
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"loc": constraints.real_vector,
"covariance_matrix": constraints.positive_definite,
@ -157,7 +157,7 @@ class MultivariateNormal(Distribution):
"with optional leading batch dimensions"
)
batch_shape = torch.broadcast_shapes(scale_tril.shape[:-2], loc.shape[:-1])
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.scale_tril = scale_tril.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
if covariance_matrix.dim() < 2:
@ -168,7 +168,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
covariance_matrix.shape[:-2], loc.shape[:-1]
)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.covariance_matrix = covariance_matrix.expand(batch_shape + (-1, -1))
else:
assert precision_matrix is not None # helps mypy
@ -180,7 +180,7 @@ class MultivariateNormal(Distribution):
batch_shape = torch.broadcast_shapes(
precision_matrix.shape[:-2], loc.shape[:-1]
)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.precision_matrix = precision_matrix.expand(batch_shape + (-1, -1))
self.loc = loc.expand(batch_shape + (-1,))

View File

@ -33,7 +33,7 @@ class NegativeBinomial(Distribution):
logits (Tensor): Event log-odds for probabilities of success
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"total_count": constraints.greater_than_eq(0),
"probs": constraints.half_open_interval(0.0, 1.0),
@ -55,7 +55,7 @@ class NegativeBinomial(Distribution):
if probs is not None:
(
self.total_count,
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.probs,
) = broadcast_all(total_count, probs)
self.total_count = self.total_count.type_as(self.probs)
@ -63,7 +63,7 @@ class NegativeBinomial(Distribution):
assert logits is not None # helps mypy
(
self.total_count,
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.logits,
) = broadcast_all(total_count, logits)
self.total_count = self.total_count.type_as(self.logits)

View File

@ -31,7 +31,7 @@ class Normal(ExponentialFamily):
(often referred to as sigma)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
support = constraints.real
has_rsample = True
@ -89,7 +89,7 @@ class Normal(ExponentialFamily):
if self._validate_args:
self._validate_sample(value)
# compute the variance
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
var = self.scale**2
log_scale = (
math.log(self.scale)
@ -119,6 +119,6 @@ class Normal(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor, Tensor]:
return (self.loc / self.scale.pow(2), -0.5 * self.scale.pow(2).reciprocal())
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y):
return -0.25 * x.pow(2) / y + 0.5 * torch.log(-math.pi / y)

View File

@ -42,7 +42,7 @@ class OneHotCategorical(Distribution):
logits (Tensor): event log probabilities (unnormalized)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = constraints.one_hot
has_enumerate_support = True

View File

@ -39,7 +39,7 @@ class Pareto(TransformedDistribution):
self.scale, self.alpha = broadcast_all(scale, alpha)
base_dist = Exponential(self.alpha, validate_args=validate_args)
transforms = [ExpTransform(), AffineTransform(loc=0, scale=self.scale)]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(

View File

@ -32,7 +32,7 @@ class Poisson(ExponentialFamily):
rate (Number, Tensor): the rate parameter
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"rate": constraints.nonnegative}
support = constraints.nonnegative_integer
@ -83,6 +83,6 @@ class Poisson(ExponentialFamily):
def _natural_params(self) -> tuple[Tensor]:
return (torch.log(self.rate),)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x):
return torch.exp(x)

View File

@ -40,7 +40,7 @@ class LogitRelaxedBernoulli(Distribution):
(Jang et al., 2017)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
support = constraints.real
@ -58,12 +58,12 @@ class LogitRelaxedBernoulli(Distribution):
)
if probs is not None:
is_scalar = isinstance(probs, _Number)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.probs,) = broadcast_all(probs)
else:
assert logits is not None # helps mypy
is_scalar = isinstance(logits, _Number)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
(self.logits,) = broadcast_all(logits)
self._param = self.probs if probs is not None else self.logits
if is_scalar:
@ -141,10 +141,10 @@ class RelaxedBernoulli(TransformedDistribution):
"""
arg_constraints = {"probs": constraints.unit_interval, "logits": constraints.real}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.unit_interval
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: LogitRelaxedBernoulli
def __init__(

View File

@ -38,7 +38,7 @@ class ExpRelaxedCategorical(Distribution):
(Jang et al., 2017)
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
support = (
constraints.real_vector
@ -128,10 +128,10 @@ class RelaxedOneHotCategorical(TransformedDistribution):
"""
arg_constraints = {"probs": constraints.simplex, "logits": constraints.real_vector}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.simplex
has_rsample = True
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
base_dist: ExpRelaxedCategorical
def __init__(

View File

@ -31,7 +31,7 @@ class StudentT(Distribution):
scale (float or Tensor): scale of the distribution
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {
"df": constraints.positive,
"loc": constraints.real,

View File

@ -123,7 +123,7 @@ class TransformedDistribution(Distribution):
return new
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
if not self.transforms:
return self.base_dist.support

View File

@ -226,13 +226,13 @@ class _InverseTransform(Transform):
self._inv: Transform = transform # type: ignore[assignment]
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
assert self._inv is not None
return self._inv.codomain
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
assert self._inv is not None
return self._inv.domain
@ -302,7 +302,7 @@ class ComposeTransform(Transform):
return self.parts == other.parts
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
if not self.parts:
return constraints.real
@ -318,7 +318,7 @@ class ComposeTransform(Transform):
return domain
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
if not self.parts:
return constraints.real
@ -438,14 +438,14 @@ class IndependentTransform(Transform):
)
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
return constraints.independent(
self.base_transform.domain, self.reinterpreted_batch_ndims
)
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.independent(
self.base_transform.codomain, self.reinterpreted_batch_ndims
@ -513,12 +513,12 @@ class ReshapeTransform(Transform):
super().__init__(cache_size=cache_size)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
return constraints.independent(constraints.real, len(self.in_shape))
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.independent(constraints.real, len(self.out_shape))
@ -772,14 +772,14 @@ class AffineTransform(Transform):
return self._event_dim
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
if self.event_dim == 0:
return constraints.real
return constraints.independent(constraints.real, self.event_dim)
@constraints.dependent_property(is_discrete=False)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
if self.event_dim == 0:
return constraints.real
@ -877,7 +877,7 @@ class CorrCholeskyTransform(Transform):
# apply stick-breaking on the squared values
# Note that y = sign(r) * sqrt(z * z1m_cumprod)
# = (sign(r) * sqrt(z)) * sqrt(z1m_cumprod) = r * sqrt(z1m_cumprod)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
z = r**2
z1m_cumprod_sqrt = (1 - z).sqrt().cumprod(-1)
# Diagonal elements must be 1.
@ -1166,14 +1166,14 @@ class CatTransform(Transform):
return all(t.bijective for t in self.transforms)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
return constraints.cat(
[t.domain for t in self.transforms], self.dim, self.lengths
)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.cat(
[t.codomain for t in self.transforms], self.dim, self.lengths
@ -1246,12 +1246,12 @@ class StackTransform(Transform):
return all(t.bijective for t in self.transforms)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def domain(self):
return constraints.stack([t.domain for t in self.transforms], self.dim)
@constraints.dependent_property
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def codomain(self):
return constraints.stack([t.codomain for t in self.transforms], self.dim)

View File

@ -79,7 +79,7 @@ class Uniform(Distribution):
return new
@constraints.dependent_property(is_discrete=False, event_dim=0)
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def support(self):
return constraints.interval(self.low, self.high)

View File

@ -92,7 +92,7 @@ def _log_modified_bessel_fn(x, order=0):
@torch.jit.script_if_tracing
def _rejection_sample(loc, concentration, proposal_r, x):
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
while not done.all():
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
u1, u2, u3 = u.unbind()
@ -101,7 +101,7 @@ def _rejection_sample(loc, concentration, proposal_r, x):
c = concentration * (proposal_r - f)
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
if accept.any():
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
done = done | accept
return (x + math.pi + loc) % (2 * math.pi) - math.pi
@ -125,7 +125,7 @@ class VonMises(Distribution):
:param torch.Tensor concentration: concentration parameter
"""
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
support = constraints.real
has_rsample = False
@ -163,10 +163,10 @@ class VonMises(Distribution):
@lazy_property
def _proposal_r(self) -> Tensor:
kappa = self._concentration
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
tau = 1 + (1 + 4 * kappa**2).sqrt()
rho = (tau - (2 * tau).sqrt()) / (2 * kappa)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
_proposal_r = (1 + rho**2) / (2 * rho)
# second order Taylor expansion around 0 for small kappa
_proposal_r_taylor = 1 / kappa + kappa

View File

@ -35,7 +35,7 @@ class Weibull(TransformedDistribution):
"scale": constraints.positive,
"concentration": constraints.positive,
}
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
support = constraints.positive
def __init__(
@ -53,7 +53,7 @@ class Weibull(TransformedDistribution):
PowerTransform(exponent=self.concentration_reciprocal),
AffineTransform(loc=0, scale=self.scale),
]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(base_dist, transforms, validate_args=validate_args)
def expand(self, batch_shape, _instance=None):

View File

@ -116,13 +116,13 @@ class Wishart(ExponentialFamily):
)
if scale_tril is not None:
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.scale_tril = param.expand(batch_shape + (-1, -1))
elif covariance_matrix is not None:
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.covariance_matrix = param.expand(batch_shape + (-1, -1))
elif precision_matrix is not None:
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.precision_matrix = param.expand(batch_shape + (-1, -1))
if self.df.lt(event_shape[-1]).any():
@ -339,7 +339,7 @@ class Wishart(ExponentialFamily):
p = self._event_shape[-1] # has singleton shape
return -self.precision_matrix / 2, (nu - p - 1) / 2
# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def _log_normalizer(self, x, y):
p = self._event_shape[-1]
return (y + (p + 1) / 2) * (