[Code Clean] Clean asserts in torch/ao/quantization/fx/* (#165420)

Replace assert statements with explicit if/raise patterns in:

- torch/ao/quantization/fx/* (177 errors)

Partially fixes #164878.
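A minimal sketch of the conversion pattern (the `_check_axis` helper and its message are illustrative placeholders, not lifted verbatim from the diff):

import torch

def _check_axis(input: torch.Tensor, axis: int) -> None:
    # Old form: silently skipped when Python runs with -O, which strips asserts.
    #   assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
    # New form: an explicit check that always executes and raises the same error type.
    if axis >= input.dim():
        raise AssertionError(f"Expecting axis to be < {input.dim()}")

_check_axis(torch.zeros(2, 3), axis=1)    # passes
# _check_axis(torch.zeros(2, 3), axis=5)  # raises AssertionError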

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165420
Approved by: https://github.com/RohitRathore1, https://github.com/fffrog, https://github.com/albanD
Authored by zhudada on 2025-10-30 20:53:31 +00:00; committed by PyTorch MergeBot
parent df71b70727
commit 7692fa09cd
14 changed files with 567 additions and 307 deletions

View File

@ -29,15 +29,17 @@ def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
raise ValueError(f"Unsupported dtype: {dtype}")
quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
assert quant_min >= quant_min_lower_bound, (
"quant_min out of bound for dtype, "
f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
)
if quant_min < quant_min_lower_bound:
raise AssertionError(
"quant_min out of bound for dtype, "
f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
)
assert quant_max <= quant_max_upper_bound, (
"quant_max out of bound for dtype, "
f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
)
if quant_max > quant_max_upper_bound:
raise AssertionError(
"quant_max out of bound for dtype, "
f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
)
quantized_decomposed_lib.define(
@ -72,9 +74,10 @@ def quantize_per_tensor(
"""
if input.dtype in [torch.float16, torch.bfloat16]:
input = input.to(torch.float32)
assert input.dtype == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
if input.dtype != torch.float32:
raise AssertionError(
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
inv_scale = 1.0 / scale
@ -94,9 +97,10 @@ def quantize_per_tensor_meta(
) -> torch.Tensor:
if input.dtype in [torch.float16, torch.bfloat16]:
input = input.to(torch.float32)
assert input.dtype == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
if input.dtype != torch.float32:
raise AssertionError(
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
return torch.empty_like(input, dtype=dtype)
@ -122,12 +126,14 @@ def quantize_per_tensor_tensor(
Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
scalar values
"""
assert zero_point.numel() == 1, (
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
assert scale.numel() == 1, (
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
if zero_point.numel() != 1:
raise AssertionError(
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
if scale.numel() != 1:
raise AssertionError(
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
return quantize_per_tensor(
input,
scale.item(),
@ -149,15 +155,18 @@ def quantize_per_tensor_tensor_meta(
) -> torch.Tensor:
if input.dtype in [torch.float16, torch.bfloat16]:
input = input.to(torch.float32)
assert zero_point.numel() == 1, (
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
assert scale.numel() == 1, (
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
assert input.dtype == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
if zero_point.numel() != 1:
raise AssertionError(
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
if scale.numel() != 1:
raise AssertionError(
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
if input.dtype != torch.float32:
raise AssertionError(
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
return torch.empty_like(input, dtype=dtype)
@ -184,12 +193,14 @@ def quantize_per_tensor_tensor2(
Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
scalar values
"""
assert zero_point.numel() == 1, (
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
assert scale.numel() == 1, (
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
if zero_point.numel() != 1:
raise AssertionError(
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
if scale.numel() != 1:
raise AssertionError(
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
return quantize_per_tensor(
input,
scale.item(),
@ -266,9 +277,10 @@ def dequantize_per_tensor(
Returns:
dequantized float32 Tensor
"""
assert input.dtype == dtype, (
f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
)
if input.dtype != dtype:
raise AssertionError(
f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
)
if out_dtype is None:
out_dtype = torch.float32
if dtype in _DTYPE_TO_QVALUE_BOUNDS:
@ -322,12 +334,14 @@ def dequantize_per_tensor_tensor(
Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
scalar values
"""
assert zero_point.numel() == 1, (
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
assert scale.numel() == 1, (
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
if zero_point.numel() != 1:
raise AssertionError(
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
if scale.numel() != 1:
raise AssertionError(
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
return dequantize_per_tensor(
input,
scale.item(),
@ -352,13 +366,18 @@ def dequantize_per_tensor_tensor_meta(
) -> torch.Tensor:
if out_dtype is None:
out_dtype = torch.float32
assert zero_point.numel() == 1, (
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
assert scale.numel() == 1, (
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
if zero_point.numel() != 1:
raise AssertionError(
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
if scale.numel() != 1:
raise AssertionError(
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
if input.dtype != dtype:
raise AssertionError(
f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
)
if dtype in _DTYPE_TO_QVALUE_BOUNDS:
return torch.empty_like(input, dtype=out_dtype)
else:
@ -392,12 +411,14 @@ def dequantize_per_tensor_tensor2(
Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
scalar values
"""
assert zero_point.numel() == 1, (
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
assert scale.numel() == 1, (
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
if zero_point.numel() != 1:
raise AssertionError(
f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
)
if scale.numel() != 1:
raise AssertionError(
f"Expecting scale tensor to be one element, but received : {scale.numel()}"
)
return dequantize_per_tensor(
input,
scale.item(),
@ -448,16 +469,18 @@ def choose_qparams_tensor(
scale (float): quantization parameter for the target quantized Tensor
zero_point (int): quantization parameter for the target quantized Tensor
"""
assert input.dtype in [
if input.dtype not in [
torch.float32,
torch.float16,
torch.bfloat16,
], (
f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
)
assert dtype in _DTYPE_TO_QVALUE_BOUNDS, (
f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
)
]:
raise AssertionError(
f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
)
if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
raise AssertionError(
f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
)
validate_qmin_qmax(qmin, qmax)
min_val, max_val = torch.aminmax(input)
@ -500,16 +523,18 @@ def choose_qparams_symmetric_tensor(
scale (float): quantization parameter for the target quantized Tensor
zero_point (int): quantization parameter for the target quantized Tensor
"""
assert input.dtype in [
if input.dtype not in [
torch.float32,
torch.float16,
torch.bfloat16,
], (
f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
)
assert dtype in _DTYPE_TO_QVALUE_BOUNDS, (
f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
)
]:
raise AssertionError(
f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
)
if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
raise AssertionError(
f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
)
validate_qmin_qmax(qmin, qmax)
min_val, max_val = torch.aminmax(input)
@ -529,17 +554,18 @@ def choose_qparams_symmetric_tensor(
def choose_qparams_tensor_meta(
input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
) -> tuple[torch.Tensor, torch.Tensor]:
assert input.dtype in [
if input.dtype not in [
torch.float32,
torch.float16,
torch.bfloat16,
], (
f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
)
assert quant_min < quant_max, (
f"Expecting quant_min to be smaller than quant_max but received min: \
{quant_min} max: {quant_max}"
)
]:
raise AssertionError(
f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
)
if quant_min >= quant_max:
raise AssertionError(
f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}"
)
return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(
1, dtype=torch.int64, device=input.device
)
@ -598,10 +624,12 @@ def quantize_per_channel(
"""
if input.dtype in [torch.float16, torch.bfloat16]:
input = input.to(torch.float32)
assert input.dtype == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
if input.dtype != torch.float32:
raise AssertionError(
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
if axis >= input.dim():
raise AssertionError(f"Expecting axis to be < {input.dim()}")
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
input, permute_axis_list = _permute_to_axis_zero(input, axis)
@ -629,10 +657,12 @@ def quantize_per_channel_meta(
) -> torch.Tensor:
if input.dtype in [torch.float16, torch.bfloat16]:
input = input.to(torch.float32)
assert input.dtype == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
if input.dtype != torch.float32:
raise AssertionError(
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
if axis >= input.dim():
raise AssertionError(f"Expecting axis to be < {input.dim()}")
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
return torch.empty_like(input, dtype=dtype)
@ -687,12 +717,14 @@ def dequantize_per_channel(
Returns:
dequantized float32 Tensor
"""
assert input.dtype == dtype, (
f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
)
if input.dtype != dtype:
raise AssertionError(
f"Expecting input to have dtype: {dtype}, but got dtype: {input.dtype}"
)
if out_dtype is None:
out_dtype = torch.float32
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
if axis >= input.dim():
raise AssertionError(f"Expecting axis to be < {input.dim()}")
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
input, permute_axis_list = _permute_to_axis_zero(input, axis)
@ -722,12 +754,14 @@ def dequantize_per_channel_meta(
*,
out_dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
assert input.dtype == dtype, (
f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
)
if input.dtype != dtype:
raise AssertionError(
f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
)
if out_dtype is None:
out_dtype = torch.float32
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
if axis >= input.dim():
raise AssertionError(f"Expecting axis to be < {input.dim()}")
_quant_min_max_bounds_check(quant_min, quant_max, dtype)
return torch.empty_like(input, dtype=out_dtype)
@ -879,12 +913,12 @@ def choose_qparams_per_token_asymmetric_meta(
def _per_token_quant_qparam_dim_check(input, scales, zero_points):
num_tokens = math.prod(list(input.size())[:-1])
assert num_tokens == scales.numel(), (
f"num_tokens: {num_tokens} scales: {scales.size()}"
)
assert num_tokens == zero_points.numel(), (
f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
)
if num_tokens != scales.numel():
raise AssertionError(f"num_tokens: {num_tokens} scales: {scales.size()}")
if num_tokens != zero_points.numel():
raise AssertionError(
f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
)
quantized_decomposed_lib.define(
@ -1019,17 +1053,21 @@ def quantize_per_channel_group(
dtype: torch.dtype,
group_size=128,
):
assert group_size > 1
if group_size <= 1:
raise AssertionError("group_size must be > 1")
# needed for GPTQ single column quantize
if group_size > input.shape[-1] and scales.shape[-1] == 1:
group_size = input.shape[-1]
assert input.shape[-1] % group_size == 0
assert input.dim() == 2
if input.shape[-1] % group_size != 0:
raise AssertionError("input.shape[-1] must be divisible by group_size")
if input.dim() != 2:
raise AssertionError("input must be 2-dimensional")
# TODO: check for dtype, currently we can't express torch.int4 so it's omitted
to_quant = input.reshape(-1, group_size)
assert torch.isnan(to_quant).sum() == 0
if torch.isnan(to_quant).sum() != 0:
raise AssertionError("to_quant must not contain NaNs")
scales = scales.reshape(-1, 1)
zero_points = zero_points.reshape(-1, 1)
@ -1074,13 +1112,16 @@ def quantize_per_channel_group_meta(
Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
are not stored in the Tensor, we are storing them in function arguments instead
"""
assert group_size > 1
if group_size <= 1:
raise AssertionError("group_size must be > 1")
# needed for GPTQ single column quantize
if group_size > input.shape[-1] and scales.shape[-1] == 1:
group_size = input.shape[-1]
assert input.shape[-1] % group_size == 0
assert input.dim() == 2
if input.shape[-1] % group_size != 0:
raise AssertionError("input.shape[-1] must be divisible by group_size")
if input.dim() != 2:
raise AssertionError("input must be 2-dimensional")
return torch.empty_like(input, dtype=dtype)
@ -1124,12 +1165,15 @@ def dequantize_per_channel_group(
dequantized Tensor with dtype `output_dtype`
"""
assert group_size > 1
if group_size <= 1:
raise AssertionError("group_size must be > 1")
# needed for GPTQ single column dequantize
if group_size > w_int8.shape[-1] and scales.shape[-1] == 1:
group_size = w_int8.shape[-1]
assert w_int8.shape[-1] % group_size == 0
assert w_int8.dim() == 2
if w_int8.shape[-1] % group_size != 0:
raise AssertionError("w_int8.shape[-1] must be divisible by group_size")
if w_int8.dim() != 2:
raise AssertionError("w_int8 must be 2-dimensional")
w_int8_grouped = w_int8.reshape(-1, group_size)
scales = scales.reshape(-1, 1)
@ -1155,10 +1199,12 @@ class FakeQuantPerChannel(torch.autograd.Function):
scales = scales.to(torch.float32)
if zero_points.dtype != torch.int32:
zero_points = zero_points.to(torch.int32)
assert input.dtype == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
if input.dtype != torch.float32:
raise AssertionError(
f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
)
if axis >= input.dim():
raise AssertionError(f"Expecting axis to be < {input.dim()}")
broadcast_dims = list(range(axis)) + list(range(axis + 1, input.ndim))
unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)

View File

@ -90,7 +90,7 @@ class _InputEqualizationObserver(nn.Module):
self.equalization_shape: list[int] = []
def forward(self, x_orig):
if not (x_orig.ndim >= 2 and x_orig.ndim <= 5):
if x_orig.ndim < 2 or x_orig.ndim > 5:
raise ValueError(
"InputEqualizationObserver only supports Linear and Conv layers"
)
@ -191,7 +191,7 @@ class _WeightEqualizationObserver(nn.Module):
self.equalization_scale = torch.tensor(1)
def forward(self, w_orig):
if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
if w_orig.ndim < 2 or w_orig.ndim > 5:
raise ValueError(
"InputEqualizationObserver only supports Linear and Conv layers"
)
@ -232,7 +232,7 @@ def calculate_equalization_scale(
)
return torch.tensor(1)
if not (min_inputs.shape == min_weights.shape):
if min_inputs.shape != min_weights.shape:
raise ValueError(
"Input and Weight must have the same column dimension. "
+ f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
@ -355,30 +355,45 @@ def get_op_node_and_weight_eq_obs(
op_node = user
break
assert op_node is not None
if op_node is None:
raise AssertionError(
"Expected an operation node after the input equalization observer"
)
if op_node.op == "call_module":
# If the op_node is a nn.Linear layer, then it must have a
# WeightEqualizationObserver configuration
maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(
model, "equalization_node_name_to_qconfig"
)
assert maybe_equalization_node_name_to_config is not None
if maybe_equalization_node_name_to_config is None:
raise AssertionError(
"Expected 'equalization_node_name_to_qconfig' attribute in observed graph module"
)
equalization_node_name_to_qconfig: dict[str, Any] = (
maybe_equalization_node_name_to_config # type: ignore[assignment]
)
assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
if equalization_node_name_to_qconfig.get(op_node.name, None) is None:
raise AssertionError(
f"No equalization qconfig found for op node {op_node.name}"
)
weight_eq_obs = equalization_node_name_to_qconfig.get( # type: ignore[union-attr]
op_node.name, None
).weight()
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
if not isinstance(weight_eq_obs, _WeightEqualizationObserver):
raise AssertionError(
"Expected weight equalization observer to be a _WeightEqualizationObserver"
)
return op_node, weight_eq_obs
elif op_node.op == "call_function":
weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
if weight_node is not None:
weight_eq_obs = modules[str(weight_node.target)]
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
if not isinstance(weight_eq_obs, _WeightEqualizationObserver):
raise AssertionError(
"Expected weight equalization observer to be a _WeightEqualizationObserver"
)
return op_node, weight_eq_obs
return None, None
@ -388,17 +403,20 @@ def maybe_get_weight_eq_obs_node(
op_node: Node, modules: dict[str, nn.Module]
) -> Optional[Node]:
"""Gets the weight equalization observer node if it exists."""
assert op_node.op == "call_function"
if op_node.op != "call_function":
raise AssertionError(
"maybe_get_weight_eq_obs_node expects a call_function op_node"
)
for node_arg in op_node.args:
if node_arg_is_weight(op_node, node_arg):
assert (
if (
isinstance(node_arg, Node)
and node_arg.op == "call_module"
and isinstance(
modules[str(node_arg.target)], _WeightEqualizationObserver
)
)
return node_arg
):
return node_arg
return None
@ -422,7 +440,8 @@ def maybe_get_next_input_eq_obs(
the following equalization observer for linear2.
"""
assert node_supports_equalization(node, modules)
if not node_supports_equalization(node, modules):
raise AssertionError("Node does not support equalization")
# Locate the following nn.ReLU or F.relu node if it exists
maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
@ -448,7 +467,10 @@ def maybe_get_next_input_eq_obs(
return None
maybe_eq_obs = modules[str(maybe_eq_obs_node)]
assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
if not isinstance(maybe_eq_obs, _InputEqualizationObserver):
raise AssertionError(
"Expected the following equalization observer to be an _InputEqualizationObserver"
)
return maybe_eq_obs
@ -480,10 +502,16 @@ def scale_input_observer(node: Node, modules: dict[str, nn.Module]) -> None:
equalization observer
"""
input_eq_obs = modules[str(node.target)]
assert isinstance(input_eq_obs, _InputEqualizationObserver)
if not isinstance(input_eq_obs, _InputEqualizationObserver):
raise AssertionError(
"Expected the module at node.target to be an _InputEqualizationObserver"
)
input_quant_obs_node = node.args[0]
assert isinstance(input_quant_obs_node, Node)
if not isinstance(input_quant_obs_node, Node):
raise AssertionError(
"Expected the input quantization observer node to be a Node"
)
input_quant_obs = modules[str(input_quant_obs_node.target)]
if not isinstance(input_quant_obs, ObserverBase):
@ -518,14 +546,19 @@ def scale_weight_node(
op_module = modules[str(node.target)][0] # type: ignore[index]
else:
op_module = modules[str(node.target)]
assert nn_module_supports_equalization(
op_module
) or custom_module_supports_equalization(op_module)
if not (
nn_module_supports_equalization(op_module)
or custom_module_supports_equalization(op_module)
):
raise AssertionError(
"Expected operation module to support equalization (nn or custom)"
)
# Scale the weights for input-weight equalization
# If the following layer needs to be equalized then we will multiply its scale
weight = op_module.weight
assert isinstance(weight, torch.Tensor)
if not isinstance(weight, torch.Tensor):
raise AssertionError("Expected op_module.weight to be a torch.Tensor")
# Scale the weights by the reciprocal of the equalization scale
# Reshape the equalization scale so that we can multiply it to the weight along axis=1
@ -547,7 +580,8 @@ def scale_weight_node(
bias = op_module.bias
if bias is None:
return
assert isinstance(bias, torch.Tensor)
if not isinstance(bias, torch.Tensor):
raise AssertionError("Expected op_module.bias to be a torch.Tensor")
# Reshape the equalization scale so that we can multiply it element-wise to the bias
next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
@ -581,15 +615,20 @@ def scale_weight_functional(
weight_quant_obs_node = weight_eq_obs_node.args[0]
if weight_quant_obs_node is None:
return
assert isinstance(weight_quant_obs_node, Node) and isinstance(
modules[str(weight_quant_obs_node.target)], ObserverBase
)
if not (
isinstance(weight_quant_obs_node, Node)
and isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
):
raise AssertionError(
"Expected weight_quant_obs_node to be a Node whose module is an ObserverBase"
)
# Get the get_attr(weight) node
weight_node = weight_quant_obs_node.args[0]
if weight_node is None:
return
assert isinstance(weight_node, Node) and weight_node.op == "get_attr"
if not (isinstance(weight_node, Node) and weight_node.op == "get_attr"):
raise AssertionError("Expected weight node to be a 'get_attr' Node")
weight_parent_name, weight_name = _parent_name(weight_node.target)
weight = getattr(modules[weight_parent_name], weight_name)
@ -612,7 +651,8 @@ def scale_weight_functional(
scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
setattr(modules[weight_parent_name], weight_name, scaled_weight)
assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
if not torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight):
raise AssertionError("Model buffer for weight does not match the scaled weight")
# Multiply the bias element wise by the next equalization scale
bias_node = None
@ -644,10 +684,14 @@ def clear_weight_quant_obs_node(op_node: Node, modules: dict[str, nn.Module]) ->
weight_quant_obs_node = weight_eq_obs_node.args[0]
if weight_quant_obs_node is None:
return
assert isinstance(weight_quant_obs_node, Node)
if not isinstance(weight_quant_obs_node, Node):
raise AssertionError("Expected weight_quant_obs_node to be a Node")
weight_quant_obs = modules[str(weight_quant_obs_node.target)]
assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
if not isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase):
raise AssertionError(
"Expected the module at weight_quant_obs_node to be an ObserverBase"
)
weight_quant_obs.reset_min_max_vals() # type: ignore[operator]
@ -682,7 +726,10 @@ def update_obs_for_equalization(
modules[node.target], _InputEqualizationObserver
):
input_eq_obs = modules[node.target]
assert isinstance(input_eq_obs, _InputEqualizationObserver)
if not isinstance(input_eq_obs, _InputEqualizationObserver):
raise AssertionError(
"Expected module at node.target to be an _InputEqualizationObserver"
)
op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
if op_node is None or weight_eq_obs is None:
@ -693,7 +740,10 @@ def update_obs_for_equalization(
# been created
if fused_module_supports_equalization(modules[str(op_node.target)]):
module = modules[str(op_node.target)][0] # type: ignore[index]
assert nn_module_supports_equalization(module)
if not nn_module_supports_equalization(module):
raise AssertionError(
"Expected fused module to support equalization"
)
weight_eq_obs(module.weight)
else:
weight_eq_obs(modules[str(op_node.target)].weight)
@ -810,7 +860,10 @@ def convert_eq_obs(
elif weight_eq_obs_dict.get(node.name, None) is not None:
weight_eq_obs = weight_eq_obs_dict.get(node.name)
assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
if not isinstance(weight_eq_obs, _WeightEqualizationObserver):
raise AssertionError(
"Expected weight equalization observer to be a _WeightEqualizationObserver"
)
equalization_scale = weight_eq_obs.equalization_scale
if (
@ -844,9 +897,12 @@ def convert_eq_obs(
weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
if weight_eq_obs_node is None:
return
assert isinstance(
if not isinstance(
modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver
)
):
raise AssertionError(
"Expected weight equalization observer to be a _WeightEqualizationObserver"
)
# Clear the quantization observer's min/max values so that they
# can get updated later based on the new scale values

View File

@ -585,7 +585,8 @@ def _match_static_pattern(
return SKIP_LOWERING_VALUE
q_node = node
ref_node = q_node.args[0]
assert isinstance(ref_node, Node)
if not isinstance(ref_node, Node):
raise AssertionError("Expected the reference node to be a torch.fx Node")
# Handle cases where the node is wrapped in a ReLU
if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or (
@ -593,7 +594,10 @@ def _match_static_pattern(
):
relu_node = ref_node
ref_node = relu_node.args[0]
assert isinstance(ref_node, Node)
if not isinstance(ref_node, Node):
raise AssertionError(
"Expected the reference node after ReLU to be a torch.fx Node"
)
else:
relu_node = None
if should_skip_lowering(ref_node, qconfig_map):
@ -616,9 +620,10 @@ def _match_static_pattern(
# (2) There must be at least one dequantize node
matched_dequantize = False
for i in dequantize_node_arg_indices:
assert i < len(ref_node.args), (
f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
)
if i >= len(ref_node.args):
raise AssertionError(
f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
)
arg = ref_node.args[i]
if is_dequantize_node(arg):
matched_dequantize = True
@ -660,7 +665,8 @@ def _match_static_pattern_with_two_inputs(
return SKIP_LOWERING_VALUE
q_node = node
ref_node = q_node.args[0]
assert isinstance(ref_node, Node)
if not isinstance(ref_node, Node):
raise AssertionError("Expected the reference node to be a torch.fx Node")
if should_skip_lowering(ref_node, qconfig_map):
return SKIP_LOWERING_VALUE
@ -711,13 +717,21 @@ def _lower_static_weighted_ref_module(
)
if q_node is None:
continue
assert ref_node is not None
if ref_node is None:
raise AssertionError(
"Expected a reference node when matching static pattern"
)
(_, scale_node, zero_point_node, _) = q_node.args
ref_module = _get_module(ref_node, modules)
ref_class = type(ref_module)
assert isinstance(scale_node, Node)
assert isinstance(zero_point_node, Node)
assert issubclass(ref_class, nn.Module)
if not isinstance(scale_node, Node):
raise AssertionError("Expected scale_node to be a Node")
if not isinstance(zero_point_node, Node):
raise AssertionError("Expected zero_point_node to be a Node")
if not issubclass(ref_class, nn.Module):
raise AssertionError(
"Expected reference module class to be a subclass of nn.Module"
)
# Step 1: Change this pattern to use the corresponding quantized module
# For fused modules, we also check whether the inner module is a reference module
@ -736,9 +750,11 @@ def _lower_static_weighted_ref_module(
setattr(modules[parent_name], module_name, q_module)
# Step 2: Reroute around dq_node, and remove q_node and its args
assert len(ref_node.args) == 1
if len(ref_node.args) != 1:
raise AssertionError("Expected reference node to have exactly 1 arg")
dq_node = ref_node.args[0]
assert isinstance(dq_node, Node)
if not isinstance(dq_node, Node):
raise AssertionError("Expected dq_node to be a Node")
ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type]
q_node.replace_all_uses_with(ref_node)
model.graph.erase_node(q_node)
@ -771,13 +787,21 @@ def _lower_static_weighted_ref_module_with_two_inputs(
)
if q_node is None:
continue
assert ref_node is not None
if ref_node is None:
raise AssertionError(
"Expected a reference node when matching static pattern with two inputs"
)
(_, scale_node, zero_point_node, _) = q_node.args
ref_module = _get_module(ref_node, modules)
ref_class = type(ref_module)
assert isinstance(scale_node, Node)
assert isinstance(zero_point_node, Node)
assert issubclass(ref_class, nn.Module)
if not isinstance(scale_node, Node):
raise AssertionError("Expected scale_node to be a Node")
if not isinstance(zero_point_node, Node):
raise AssertionError("Expected zero_point_node to be a Node")
if not issubclass(ref_class, nn.Module):
raise AssertionError(
"Expected reference module class to be a subclass of nn.Module"
)
# Step 1: Change this pattern to use the corresponding quantized module
# For fused modules, we also check whether the inner module is a reference module
@ -798,12 +822,14 @@ def _lower_static_weighted_ref_module_with_two_inputs(
setattr(modules[parent_name], module_name, q_module)
# Step 2: Reroute around dq_node, and remove q_node and its args
assert len(ref_node.args) == 2
if len(ref_node.args) != 2:
raise AssertionError("Expected reference node to have exactly 2 args")
for arg in ref_node.args:
if not is_dequantize_node(arg):
continue
dq_node = arg
assert isinstance(dq_node, Node)
if not isinstance(dq_node, Node):
raise AssertionError("Expected dq_node to be a Node")
ref_node.replace_input_with(dq_node, dq_node.args[0]) # type: ignore[arg-type]
q_node.replace_all_uses_with(ref_node)
@ -900,14 +926,21 @@ def _lower_static_weighted_ref_functional(
)
if q_node is None:
continue
assert func_node is not None
if func_node is None:
raise AssertionError(
"Expected a function node when matching static functional pattern"
)
(_, output_scale_node, output_zp_node, _) = q_node.args
(input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
assert isinstance(output_zp_node, Node)
assert isinstance(input_dq_node, Node)
assert isinstance(weight_dq_node, Node)
if not isinstance(output_zp_node, Node):
raise AssertionError("Expected output_zp_node to be a Node")
if not isinstance(input_dq_node, Node):
raise AssertionError("Expected input_dq_node to be a Node")
if not isinstance(weight_dq_node, Node):
raise AssertionError("Expected weight_dq_node to be a Node")
quantized_weight = weight_dq_node.args[0]
assert isinstance(quantized_weight, Node)
if not isinstance(quantized_weight, Node):
raise AssertionError("Expected quantized_weight to be a Node")
if quantized_weight.op != "call_function" or quantized_weight.target not in (
torch.quantize_per_tensor,
torch.quantize_per_channel,
@ -1135,7 +1168,10 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfi
)
if q_node is None:
continue
assert bop_node is not None
if bop_node is None:
raise AssertionError(
"Expected a binary op node when matching quantized binary op pattern"
)
(_, scale_node, zero_point_node, _) = q_node.args
# Step 1: Remove dequant nodes
@ -1144,14 +1180,21 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfi
if not is_dequantize_node(arg):
continue
dq_node = arg
assert isinstance(dq_node, Node)
if not isinstance(dq_node, Node):
raise AssertionError("Expected dq_node to be a Node")
dn_input = dq_node.args[0]
bop_node.replace_input_with(dq_node, dn_input) # type: ignore[arg-type]
num_dq_nodes += 1
assert num_dq_nodes > 0
if num_dq_nodes <= 0:
raise AssertionError(
"Expected at least one dequantize node in binary op args"
)
# Step 2: Swap binary op to quantized binary op
assert bop_node.target in QBIN_OP_MAPPING
if bop_node.target not in QBIN_OP_MAPPING:
raise AssertionError(
f"Unsupported binary op {bop_node.target} for lowering"
)
binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING
qbin_op = binop_to_qbinop[bop_node.target]
# prepare the args for quantized binary op
@ -1188,7 +1231,8 @@ def special_pattern_replacement(model: GraphModule):
and len(q_node.args) == 2
and q_node.args[1] == torch.float16
)
if not (is_quantize or is_to_fp16):
# Only continue when neither quantize nor to_fp16
if not is_quantize and not is_to_fp16:
continue
ref_node = q_node.args[0]
# get output scale/zero_point/dtype from the quantize node
@ -1217,13 +1261,17 @@ def special_pattern_replacement(model: GraphModule):
)
if not (is_call_module or is_call_function or is_call_method):
continue
assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0
if len(ref_node.args) <= 0 and len(ref_node.kwargs) <= 0:
raise AssertionError("Expected ref_node to have args or kwargs")
dq_node_or_nodes = (
ref_node.args[0]
if len(ref_node.args) > 0
else next(iter(ref_node.kwargs.values()))
)
assert isinstance(dq_node_or_nodes, (Node, tuple, list))
if not isinstance(dq_node_or_nodes, (Node, tuple, list)):
raise AssertionError(
"Expected dq_node_or_nodes to be a Node, tuple, or list"
)
is_dequantize = False
if isinstance(dq_node_or_nodes, Node):
is_dequantize = (

View File

@ -362,11 +362,15 @@ class PerChannelDetector(DetectorBase):
# assert statement for MyPy
q_config_file = module.qconfig
assert isinstance(q_config_file, QConfig)
if not isinstance(q_config_file, QConfig):
raise AssertionError("module.qconfig must be a QConfig")
# this object should either be fake quant or observer
q_or_s_obj = module.qconfig.weight.p.func()
assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))
if not isinstance(q_or_s_obj, (FakeQuantize, ObserverBase)):
raise AssertionError(
"module.qconfig.weight must be a FakeQuantize or ObserverBase"
)
per_channel_used = False # will be true if found in qconfig
@ -1160,9 +1164,10 @@ class InputWeightEqualizationDetector(DetectorBase):
input_channels = len(input_ratio)
if weight_channels != input_channels:
# we try to replicate
assert input_channels % weight_channels == 0, (
"input channels should be divisible by weight channels."
)
if input_channels % weight_channels != 0:
raise AssertionError(
"input channels should be divisible by weight channels."
)
# get replication factor
rep_factor: int = input_channels // weight_channels
@ -1418,11 +1423,15 @@ class OutlierDetector(DetectorBase):
self.ratio_threshold = ratio_threshold
# make sure passed in percentile is valid
assert reference_percentile >= 0 and reference_percentile <= 1
assert (
if reference_percentile < 0 or reference_percentile > 1:
raise AssertionError("reference_percentile must be between 0 and 1")
if not (
fraction_batches_used_threshold >= 0
and fraction_batches_used_threshold <= 1
)
):
raise AssertionError(
"fraction_batches_used_threshold must be between 0 and 1"
)
self.reference_percentile = reference_percentile
self.fraction_batches_used_threshold = fraction_batches_used_threshold
self.ch_axis = ch_axis

View File

@ -261,7 +261,8 @@ class ModelReport:
raise ValueError("The node_fqn is was not found within the module.")
# assert for MyPy
assert isinstance(node_to_return, torch.fx.node.Node)
if not isinstance(node_to_return, torch.fx.node.Node):
raise AssertionError("node_to_return must be a torch.fx.node.Node")
return node_to_return

View File

@ -112,8 +112,12 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
or quantize_per_channel and dequantize_per_channel
"""
graph = model.graph
assert modules is not None
assert isinstance(node.target, str)
if modules is None:
raise AssertionError("modules must not be None")
if not isinstance(node.target, str):
raise AssertionError(
f"Expected node.target to be a str, but got {type(node.target)}"
)
module_path, prefix = _get_module_path_and_prefix(
node, node_name_to_scope, node_name_to_qconfig
)
@ -260,10 +264,10 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
# and that can be done after we remove reduce_range flag
# 1. extract qparams from activation_post_process module
dtype_ = to_underlying_dtype(dtype)
assert dtype_ in [torch.uint8, torch.int8], (
"only uint8 and int8 are supported in reference flow for "
"dynamic quantization right now"
)
if dtype_ not in [torch.uint8, torch.int8]:
raise AssertionError(
"only uint8 and int8 are supported in reference flow for dynamic quantization right now"
)
quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined]
@ -379,8 +383,12 @@ def _replace_observer_with_quantize_dequantize_node(
After:
... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
"""
assert modules is not None
assert isinstance(node.target, str)
if modules is None:
raise AssertionError("modules must not be None")
if not isinstance(node.target, str):
raise AssertionError(
f"Expected node.target to be a str, but got {type(node.target)}"
)
graph = model.graph
module_path, prefix = _get_module_path_and_prefix(
node, node_name_to_scope, node_name_to_qconfig
@ -521,9 +529,10 @@ def _replace_observer_or_dequant_stub_with_dequantize_node(
node: Node, graph: Graph
) -> None:
call_custom_module_node = node.args[0]
assert isinstance(call_custom_module_node, Node), (
f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
)
if not isinstance(call_custom_module_node, Node):
raise AssertionError(
f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
)
node.replace_all_uses_with(call_custom_module_node)
graph.erase_node(node)
_insert_dequantize_node(call_custom_module_node, graph)
@ -617,9 +626,10 @@ def _get_module_path_and_prefix(
# operator (they can be the same)
# this flag identifies if the observer is inserted only because the observed node is
# the input of the next operator
assert isinstance(observed_node, Node), (
f"Expecting observed node to be a Node, but got {observed_node}"
)
if not isinstance(observed_node, Node):
raise AssertionError(
f"Expecting observed node to be a Node, but got {observed_node}"
)
is_input_observer_only = (
node_name_to_qconfig[observed_node.name] is None
if observed_node.name in node_name_to_qconfig
@ -727,8 +737,10 @@ def convert_standalone_module(
"_observed_graph_module_attrs"
].standalone_module_output_quantized_idxs
if len(sm_output_quantized_idxs) > 0:
assert sm_output_quantized_idxs[0] == 0, "Currently only quantized"
"output idxs = [0] is supported"
if sm_output_quantized_idxs[0] != 0:
raise AssertionError(
"Currently only quantized output idxs = [0] is supported"
)
# if it's non-empty, then it means the output is kept in quantized form
# we'll just add a dequantize node after this node
@ -882,9 +894,10 @@ def convert_weighted_module(
ref_qmodule_cls = root_module_to_quantized_reference_module.get(
type_before_parametrizations(float_module), None
)
assert ref_qmodule_cls is not None, (
f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
)
if ref_qmodule_cls is None:
raise AssertionError(
f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
)
ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined]
if fused_module is not None:
fused_module[0] = ref_qmodule # type: ignore[operator]
@ -904,9 +917,10 @@ def _remove_previous_dequantize_in_custom_module(
\\ - dequantize
"""
# expecting the input node for a custom module node to be a Node
assert isinstance(prev_node, Node), (
f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
)
if not isinstance(prev_node, Node):
raise AssertionError(
f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
)
if prev_node.op == "call_method" and prev_node.target == "dequantize":
node.replace_input_with(prev_node, prev_node.args[0])
# Remove the dequantize node if it doesn't have other users
@ -952,15 +966,21 @@ def convert_custom_module(
if _is_custom_module_lstm(node, modules):
# The inputs are tuples in the form (input, (hidden0, hidden1))
# Ensure all three input nodes are quantized
assert (
if not (
len(node.args) == 2
and isinstance(node.args[1], tuple)
and len(node.args[1]) == 2
)
):
raise AssertionError(
"Expected LSTM custom module inputs to be (input, (hidden0, hidden1))"
)
(inputs, (hidden0, hidden1)) = node.args # type: ignore[misc]
assert isinstance(inputs, Node)
assert isinstance(hidden0, Node)
assert isinstance(hidden1, Node)
if not isinstance(inputs, Node):
raise AssertionError("Expected inputs to be a Node")
if not isinstance(hidden0, Node):
raise AssertionError("Expected hidden0 to be a Node")
if not isinstance(hidden1, Node):
raise AssertionError("Expected hidden1 to be a Node")
_remove_previous_dequantize_in_custom_module(node, inputs, graph)
_remove_previous_dequantize_in_custom_module(node, hidden0, graph)
_remove_previous_dequantize_in_custom_module(node, hidden1, graph)
@ -971,22 +991,32 @@ def convert_custom_module(
# to the module.
# Additional handling is yet to be implemented for the outputs, similar
# to LSTM custom module
assert len(node.args) == 3
if len(node.args) != 3:
raise AssertionError(
"Expected MHA custom module inputs to be (query, key, value)"
)
query, key, value = node.args
assert isinstance(query, Node)
assert isinstance(key, Node)
assert isinstance(value, Node)
if not isinstance(query, Node):
raise AssertionError("Expected query to be a Node")
if not isinstance(key, Node):
raise AssertionError("Expected key to be a Node")
if not isinstance(value, Node):
raise AssertionError("Expected value to be a Node")
_remove_previous_dequantize_in_custom_module(node, query, graph)
_remove_previous_dequantize_in_custom_module(node, key, graph)
_remove_previous_dequantize_in_custom_module(node, value, graph)
else:
# remove the previous dequant node to ensure the inputs are quantized
arg = node.args[0]
assert isinstance(arg, Node)
if not isinstance(arg, Node):
raise AssertionError("Expected arg to be a Node")
_remove_previous_dequantize_in_custom_module(node, arg, graph)
# absorb the following observer into the module conversion
activation_post_process = _maybe_get_observer_for_node(node, modules)
assert activation_post_process is not None
if activation_post_process is None:
raise AssertionError(
"Expected activation_post_process to be present for observed custom module"
)
observed_custom_module.activation_post_process = activation_post_process
# swap the observed custom module to quantized custom module
@ -1061,7 +1091,8 @@ def convert(
QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
)
qconfig_mapping = copy.deepcopy(qconfig_mapping)
assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)
if not (qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)):
raise AssertionError("qconfig_mapping must be None or a QConfigMapping")
if isinstance(backend_config, dict):
warnings.warn(
@ -1075,7 +1106,8 @@ def convert(
if backend_config is None:
backend_config = get_native_backend_config()
assert _is_observed_module(model), "incoming model must be produced by prepare_fx"
if not _is_observed_module(model):
raise AssertionError("incoming model must be produced by prepare_fx")
observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
node_name_to_scope: dict[str, tuple[str, type]] = (
observed_graph_module_attrs.node_name_to_scope
@ -1121,14 +1153,16 @@ def convert(
# all the values either match what was set in prepare node_name_to_qconfig
# or are set to None in the convert_node_name_to_qconfig.
for k, v in node_name_to_qconfig.items():
assert k in convert_node_name_to_qconfig, (
f"Expected key {k} in convert node_name_to_qconfig"
)
if convert_node_name_to_qconfig[k] is not None:
assert qconfig_equals(v, convert_node_name_to_qconfig[k]), (
f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
if k not in convert_node_name_to_qconfig:
raise AssertionError(
f"Expected key {k} in convert node_name_to_qconfig"
)
if convert_node_name_to_qconfig[k] is not None:
if not qconfig_equals(v, convert_node_name_to_qconfig[k]):
raise AssertionError(
f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
)
node_name_to_qconfig = convert_node_name_to_qconfig
custom_module_classes = get_custom_module_class_keys(
@ -1201,7 +1235,10 @@ def convert(
)
elif node.op == "call_module":
mod = _get_module(node, modules)
assert mod is not None
if mod is None:
raise AssertionError(
"Expected module for call_module node to be present in modules mapping"
)
if _is_activation_post_process(mod):
observed_node = node.args[0]
if observed_node in statically_quantized_custom_module_nodes:

View File

@ -102,7 +102,10 @@ def fuse(
else:
node_subpattern = None
if maybe_last_node is node:
assert obj is not None
if obj is None:
raise AssertionError(
"fuse handler object must not be None for matched root node"
)
root_node_getter = fusion_pattern_to_root_node_getter.get(
pattern, default_root_node_getter
)

View File

@ -65,9 +65,8 @@ class DefaultFuseHandler(FuseHandler):
fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]],
is_qat: bool,
) -> Node:
assert root_node.op == "call_module", (
"Expecting module node to be a call_module Node"
)
if root_node.op != "call_module":
raise AssertionError("Expecting module node to be a call_module Node")
root_module = named_modules[str(root_node.target)]
def get_modules(pattern):

View File

@ -109,7 +109,8 @@ def _get_lstm_with_individually_observed_parts(
# TODO: maybe make this work for layer_bw as well
for layer in quantizable_lstm.layers:
cell = layer.layer_fw.cell # type: ignore[union-attr]
assert isinstance(cell, torch.nn.Module), "cell should be a nn.Module"
if not isinstance(cell, torch.nn.Module):
raise AssertionError("cell should be a nn.Module")
cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
# HACK: Manually replace the activation_post_process following these ops.
# This is needed for FloatFunctional ops because there is currently no way
@ -150,7 +151,8 @@ def _get_lstm_with_individually_observed_parts(
continue
if op_index not in op_index_to_activation_post_process_ctr:
continue
assert len(node.users) == 1
if len(node.users) != 1:
raise AssertionError("expected exactly one user for the node")
activation_post_process_name = next(iter(node.users.keys())).name
activation_post_process_ctr = op_index_to_activation_post_process_ctr[
op_index
@ -195,7 +197,8 @@ def _get_reference_quantized_lstm_module(
for i, layer in enumerate(quantized_lstm.layers):
cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell) # type: ignore[union-attr]
cell = convert_to_reference_fx(cell, backend_config=backend_config) # type: ignore[arg-type]
assert isinstance(cell, torch.fx.GraphModule)
if not isinstance(cell, torch.fx.GraphModule):
raise AssertionError("cell must be converted to a torch.fx.GraphModule")
# HACK: Manually remove input quantize nodes and output dequantize nodes,
# since custom modules expect quint8 inputs and outputs for now. Note that
# this functionality is supposedly handled through PrepareCustomConfig's

View File

@ -33,7 +33,8 @@ def _is_match(modules, node, pattern, max_uses=sys.maxsize):
if isinstance(pattern, tuple):
self_match, *arg_matches = pattern
if self_match is getattr:
assert len(pattern) == 2, "Expecting getattr pattern to have two elements"
if len(pattern) != 2:
raise AssertionError("Expecting getattr pattern to have two elements")
arg_matches = []
else:
self_match = pattern
@ -190,7 +191,8 @@ def _find_matches(
break
# add custom module instances to the match result
assert modules is not None
if modules is None:
raise AssertionError("modules must not be None")
for node in graph.nodes:
if (
node.op == "call_module"
@ -204,7 +206,8 @@ def _find_matches(
)
def is_standalone_module(node_target: str, modules: dict[str, torch.nn.Module]):
assert modules is not None
if modules is None:
raise AssertionError("modules must not be None")
return (
node_target in standalone_module_names
or type(modules[node_target]) # type: ignore[operator]

View File

@ -149,10 +149,11 @@ def _create_obs_or_fq_from_qspec(
return None
if isinstance(quantization_spec, SharedQuantizationSpec):
edge_or_node = quantization_spec.edge_or_node
assert edge_or_node in obs_or_fq_map, (
"please make sure only refer to edge or node that has "
f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
)
if edge_or_node not in obs_or_fq_map:
raise AssertionError(
"please make sure only refer to edge or node that has "
f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
)
return obs_or_fq_map[edge_or_node]
elif isinstance(quantization_spec, DerivedQuantizationSpec):
# can't use asdict, so not calling get_observer_kwargs here
@ -177,7 +178,8 @@ def _create_obs_or_fq_from_qspec(
else:
return observer_ctr()
assert isinstance(quantization_spec, QuantizationSpec)
if not isinstance(quantization_spec, QuantizationSpec):
raise AssertionError("quantization_spec must be a QuantizationSpec")
observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
kwargs = _get_observer_kwargs(quantization_spec)
kwargs.pop("observer_or_fake_quant_ctr")
@ -214,10 +216,14 @@ def _needs_obs_or_fq(
# need to insert placeholder observer for dynamic quantization so that it can
# be converted to choose_qparams -> q -> dq in convert step
if cur_target_is_dynamic:
assert cur_target_dtype in _OBS_DTYPE_LIST, (
f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}"
)
assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST
if cur_target_dtype not in _OBS_DTYPE_LIST:
raise AssertionError(
f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}"
)
if prev_output_dtype in _DO_NOT_OBS_DTYPE_LIST:
raise AssertionError(
"prev_output_dtype must not be in _DO_NOT_OBS_DTYPE_LIST"
)
return is_zeroth_arg
if reuse_input_obs_or_fq:
return False
@ -398,7 +404,8 @@ def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
"""
if backend_config is None or pattern is None:
return True
assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
if matched_node_pattern is None or len(matched_node_pattern) < 1:
raise AssertionError("matched_node_pattern must be non-empty")
pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
@ -535,7 +542,8 @@ def _set_target_dtype_info_for_matched_node_pattern(
# other types of matched object, e.g. int, float literals, are ignored
elif isinstance(matched_node_pattern, Node):
# for pyre
assert isinstance(matched_node_pattern, Node)
if not isinstance(matched_node_pattern, Node):
raise AssertionError("matched_node_pattern must be a Node")
node = matched_node_pattern
if node in processed_nodes:
return
@ -674,7 +682,8 @@ def _get_output_act_obs_or_fq(
We are assuming that the observers are inserted correctly, and the dtype for
argument in quantized graph will match what is specified by the qconfig
"""
assert isinstance(arg, Node)
if not isinstance(arg, Node):
raise AssertionError("arg must be a Node")
if "quantization_annotation" in arg.meta:
return _create_obs_or_fq_from_qspec(
arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
@ -698,9 +707,8 @@ def _get_output_act_obs_or_fq(
)
elif _is_activation_post_process_node(arg, named_modules):
observed_arg = arg.args[0]
assert isinstance(observed_arg, Node), (
"Currently we only support observing Node"
)
if not isinstance(observed_arg, Node):
raise AssertionError("Currently we only support observing Node")
if "quantization_annotation" in observed_arg.meta:
output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
observed_arg.meta["quantization_annotation"].output_qspec,
@ -708,7 +716,10 @@ def _get_output_act_obs_or_fq(
is_qat,
)
else:
assert "target_dtype_info" in observed_arg.meta
if "target_dtype_info" not in observed_arg.meta:
raise AssertionError(
"expected 'target_dtype_info' in observed_arg.meta"
)
output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][
"output_act_obs_or_fq_ctr"
]
@ -754,7 +765,8 @@ def _get_arg_as_input_act_obs_or_fq(
"""Get the observer or fake quant constructor for the Argument `arg`, as input
to Node `node`
"""
assert isinstance(arg, Node)
if not isinstance(arg, Node):
raise AssertionError("arg must be a Node")
# "input_qspec_map" is the more general design we'll use for pt2e path
# it is a map from input argument node to observer or fake quant constructor, for example
# for the following graph:
@ -838,7 +850,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
if not isinstance(arg, Node):
return arg
assert isinstance(arg, Node)
if not isinstance(arg, Node):
raise AssertionError("arg must be a Node")
# default (no observer)
new_arg = arg
@ -854,7 +867,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
"quantization_annotation"
]._reuse_input_obs_or_fq
else:
assert "target_dtype_info" in node.meta
if "target_dtype_info" not in node.meta:
raise AssertionError("expected 'target_dtype_info' in node.meta")
# TODO: we are assuming "target_dtype_info" exists here, maybe
# a default value also need to be provided here
target_dtype_info = node.meta["target_dtype_info"]
@ -889,7 +903,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
)
else:
assert qconfig is not None
if qconfig is None:
raise AssertionError("qconfig must not be None")
# custom flow for standalone modules
_, _, sm_prepare_custom_config, _ = _get_standalone_module_configs(
node, named_modules, prepare_custom_config, qconfig, backend_config
@ -946,7 +961,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
existing_obs_node = maybe_obs_node
break
assert arg_as_input_act_obs_or_fq is not None
if arg_as_input_act_obs_or_fq is None:
raise AssertionError("arg_as_input_act_obs_or_fq must not be None")
obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
if existing_obs_node is None:
new_obs_node = _insert_obs_or_fq(
@ -1102,7 +1118,8 @@ def _maybe_insert_output_observer_for_node(
Note: inserting dynamic quantization ops for output is not supported in fx graph mode
quantization code path right now
"""
assert node.op != "output", "observer insertion for outputs is handled elsewhere"
if node.op == "output":
raise AssertionError("observer insertion for outputs is handled elsewhere")
is_standalone_module = False
if "quantization_annotation" in node.meta:
@ -1110,7 +1127,8 @@ def _maybe_insert_output_observer_for_node(
node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
)
else:
assert "target_dtype_info" in node.meta
if "target_dtype_info" not in node.meta:
raise AssertionError("expected 'target_dtype_info' in node.meta")
is_standalone_module = node.meta["target_dtype_info"].get(
"_is_standalone_module", False
)
@ -1222,7 +1240,10 @@ def _maybe_insert_observers_before_graph_output(
and arg_as_input_target_dtype != torch.float
)
if need_obs:
assert observer_mod is not None
if observer_mod is None:
raise AssertionError(
"observer_mod must not be None when need_obs is True"
)
# insert observer
observer_node = _insert_obs_or_fq(
maybe_node, observer_mod, model, named_modules, graph
@ -1393,9 +1414,11 @@ def _maybe_make_input_output_share_observers(
if iteration_guard > 10000:
raise AssertionError("Unable to find observer of previous node")
assert isinstance(first_arg_arg, Node)
if not isinstance(first_arg_arg, Node):
raise AssertionError("first_arg_arg must be a Node")
target_to_use = first_arg_arg.target
assert isinstance(target_to_use, str)
if not isinstance(target_to_use, str):
raise AssertionError("target_to_use must be a string")
obs_mod_to_use = named_modules[target_to_use]
if isinstance(first_arg, (list, tuple)):
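
A side note on the isinstance guards above (sketch only; Node is the real torch.fx type, but the helper name is made up): static type checkers narrow the variable the same way for both forms, because the rewritten branch unconditionally raises, so no type information is lost by the conversion.

from torch.fx import Node

def _target_name(first_arg_arg: object) -> str:
    if not isinstance(first_arg_arg, Node):
        raise AssertionError("first_arg_arg must be a Node")
    # From here on, mypy treats first_arg_arg as Node, exactly as it did
    # after the old assert isinstance(first_arg_arg, Node).
    target_to_use = first_arg_arg.target
    if not isinstance(target_to_use, str):
        raise AssertionError("target_to_use must be a string")
    return target_to_use
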
@ -1418,7 +1441,10 @@ def _maybe_make_input_output_share_observers(
# set the output observer node to use that module
for output_obs_node in node.users.keys():
assert _is_activation_post_process_node(output_obs_node, named_modules)
if not _is_activation_post_process_node(output_obs_node, named_modules):
raise AssertionError(
"output_obs_node must be an activation post process node"
)
parent_name, name = _parent_name(output_obs_node.target)
setattr(named_modules[parent_name], name, obs_mod_to_use)
@ -1431,7 +1457,10 @@ def _remove_output_observer(
):
items = list(node.users.items())
for output_obs_node, _ in items:
assert _is_activation_post_process_node(output_obs_node, named_modules)
if not _is_activation_post_process_node(output_obs_node, named_modules):
raise AssertionError(
"output_obs_node must be an activation post process node"
)
output_obs_node.replace_all_uses_with(node)
model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator]
@ -1554,7 +1583,8 @@ def insert_observers_for_model(
qhandler,
qconfig,
) = match_res_with_qconfig
assert qhandler is not None
if qhandler is None:
raise AssertionError("qhandler must not be None")
_set_target_dtype_info_for_matched_node_pattern(
matched_node_pattern,
last_node,
@ -1632,7 +1662,8 @@ def insert_observers_for_model(
pattern, matched_node_pattern, qconfig, backend_config
)
)
assert qhandler is not None
if qhandler is None:
raise AssertionError("qhandler must not be None")
# get output_act_dtype so that we don't also reset the special typed nodes
# TODO: we might want to handle these more uniformly with the default path
@ -1726,7 +1757,8 @@ def insert_observers_for_model(
if not skip_inserting_observers and is_supported_by_backend:
named_modules = dict(model.named_modules(remove_duplicate=False))
if node.op != "output":
assert matched_node_pattern is not None
if matched_node_pattern is None:
raise AssertionError("matched_node_pattern must not be None")
# add matched nodes to the observed node name set
_add_matched_node_name_to_set(
matched_node_pattern, observed_node_names
@ -2064,8 +2096,10 @@ def prepare(
)
backend_config = BackendConfig.from_dict(backend_config)
assert isinstance(qconfig_mapping, QConfigMapping)
assert isinstance(_equalization_config, QConfigMapping)
if not isinstance(qconfig_mapping, QConfigMapping):
raise AssertionError("qconfig_mapping must be a QConfigMapping")
if not isinstance(_equalization_config, QConfigMapping):
raise AssertionError("_equalization_config must be a QConfigMapping")
qconfig_mapping = copy.deepcopy(qconfig_mapping)
_equalization_config = copy.deepcopy(_equalization_config)
@ -2194,11 +2228,12 @@ def prepare(
)
if is_standalone_module:
assert result_node is not None
assert isinstance(result_node.args[0], Node), (
"standalone module only supports returning simple value currently"
"(not tuple, dict etc.)"
)
if result_node is None:
raise AssertionError("result_node must not be None for standalone modules")
if not isinstance(result_node.args[0], Node):
raise AssertionError(
"standalone module only supports returning simple value currently (not tuple, dict etc.)"
)
# these inputs are observed in parent
# converting List[int] to Tensor since module attribute is
# Union[Tensor, Module]

View File

@ -228,11 +228,12 @@ def _compare_prepare_convert_qconfig_mappings(
`prepare_qconfig_mapping`: configuration for prepare quantization step
`convert_qconfig_mapping`: configuration for convert quantization step
"""
assert qconfig_equals(
    prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig
), (
    "Expected global qconfigs to be the same in the prepare and convert quantization configs"
)
if not qconfig_equals(
    prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig
):
    raise AssertionError(
        "Expected global qconfigs to be the same in the prepare and convert quantization configs"
    )
prepare_dicts: list[OrderedDict] = [
prepare_qconfig_mapping.object_type_qconfigs,
prepare_qconfig_mapping.module_name_qconfigs,
@ -250,16 +251,17 @@ def _compare_prepare_convert_qconfig_mappings(
]
for i in range(len(prepare_dicts)):
for name in prepare_dicts[i].keys():
assert name in convert_dicts[i], (
    f"Missing key {dict_names[i]} {name} in convert QConfigMapping \
when it was present in prepare"
)
assert convert_dicts[i][name] is None or qconfig_equals(
    prepare_dicts[i][name], convert_dicts[i][name]
), (
    f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \
prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"
)
if name not in convert_dicts[i]:
    raise AssertionError(
        f"Missing key {dict_names[i]} {name} in convert QConfigMapping when it was present in prepare"
    )
if convert_dicts[i][name] is not None and not qconfig_equals(
    prepare_dicts[i][name], convert_dicts[i][name]
):
    raise AssertionError(
        "Expected convert QConfigMapping to have the same qconfig as prepare for key "
        f"{dict_names[i]} {name}; prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"
    )
def _is_qconfig_supported_by_dtype_configs(
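
The disjunctive asserts in this file are the only conversions that need a boolean rewrite: assert A or B, msg becomes if not A and not B (De Morgan), not if not A or not B. A reduced sketch, where qconfig_equals is the real helper used above and the remaining names are illustrative:

from torch.ao.quantization.qconfig import qconfig_equals

def _check_entry(prepare_qconfig, convert_qconfig) -> None:
    # Old form:
    #   assert convert_qconfig is None or qconfig_equals(prepare_qconfig, convert_qconfig), msg
    # New form: raise only when *both* disjuncts are false.
    if convert_qconfig is not None and not qconfig_equals(
        prepare_qconfig, convert_qconfig
    ):
        raise AssertionError(
            "Expected convert QConfigMapping to have the same qconfig as prepare"
        )
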

View File

@ -119,10 +119,11 @@ def _get_quantize_handler_cls(
):
super().__init__(node_pattern, modules, root_node_getter)
if num_tensor_args_to_observation_type:
assert self.num_tensor_args in num_tensor_args_to_observation_type, (
f"Must provide observation_type config for tensor number {self.num_tensor_args}"
f" in num_tensor_args_to_observation_type for {node_pattern}"
)
if self.num_tensor_args not in num_tensor_args_to_observation_type:
raise AssertionError(
f"Must provide observation_type config for tensor number {self.num_tensor_args}"
f" in num_tensor_args_to_observation_type for {node_pattern}"
)
self.observation_type = num_tensor_args_to_observation_type[
self.num_tensor_args
]
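
Since the rewritten guard raises AssertionError with the same message as before, anything written against the old behavior keeps working; a hypothetical pytest check (no tests are added by this PR, and the helper below is made up so the snippet is self-contained) would look like:

import pytest

def _lookup_observation_type(mapping: dict, num_tensor_args: int, node_pattern):
    # Same shape as the guard in _get_quantize_handler_cls above.
    if num_tensor_args not in mapping:
        raise AssertionError(
            f"Must provide observation_type config for tensor number {num_tensor_args}"
            f" in num_tensor_args_to_observation_type for {node_pattern}"
        )
    return mapping[num_tensor_args]

def test_missing_tensor_arg_config_raises():
    with pytest.raises(AssertionError, match="tensor number 2"):
        _lookup_observation_type({1: "share"}, 2, node_pattern="add")
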

View File

@ -165,7 +165,8 @@ def get_qconv_prepack_op(conv_op: Callable) -> Callable:
torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
}
prepack_op = prepack_ops.get(conv_op)
assert prepack_op, f"Didn't find prepack op for {conv_op}"
if prepack_op is None:
raise AssertionError(f"Didn't find prepack op for {conv_op}")
return prepack_op
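
One subtlety in the hunk above: assert prepack_op, msg rejects any falsy value, while the rewrite checks is None specifically. The two are equivalent here because dict.get returns None on a miss and the stored prepack ops are never falsy. A trimmed sketch (the table is abbreviated to one entry; only the shape matters):

import torch

def _get_prepack_op_sketch(conv_op):
    prepack_ops = {
        torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
    }
    prepack_op = prepack_ops.get(conv_op)
    # dict.get yields None on a miss, so checking is None matches the old
    # truthiness assert for every reachable input.
    if prepack_op is None:
        raise AssertionError(f"Didn't find prepack op for {conv_op}")
    return prepack_op
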
@ -230,7 +231,8 @@ def graph_module_from_producer_nodes(
Return:
A graph module constructed from the producer nodes
"""
assert len(producer_nodes) > 0, "list of producer nodes can not be empty"
if len(producer_nodes) == 0:
raise AssertionError("list of producer nodes can not be empty")
# since we traced back from node to getattr
producer_nodes.reverse()
graph = Graph()
@ -300,7 +302,8 @@ def all_node_args_have_no_tensors(
elif node.op == "placeholder":
result = False
elif node.op == "call_module":
assert isinstance(node.target, str)
if not isinstance(node.target, str):
raise AssertionError("node.target must be a string for call_module nodes")
if _is_activation_post_process(modules[node.target]):
result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
elif node.op == "call_module":
@ -503,9 +506,10 @@ def _is_custom_module_lstm(
"""
mod = _get_module(node, named_modules)
if qconfig is not None and qhandler is not None:
assert isinstance(
    qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
) # type: ignore[attr-defined]
if not isinstance(
    qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
): # type: ignore[attr-defined]
    raise AssertionError("qhandler must be a QuantizeHandler when provided")
return (
isinstance(mod, torch.nn.LSTM)
and activation_is_statically_quantized(qconfig)
@ -527,9 +531,10 @@ def _is_custom_module_mha(
"""
mod = _get_module(node, named_modules)
if qconfig is not None and qhandler is not None:
assert isinstance(
    qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
) # type: ignore[attr-defined]
if not isinstance(
    qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
): # type: ignore[attr-defined]
    raise AssertionError("qhandler must be a QuantizeHandler when provided")
return (
isinstance(mod, torch.nn.MultiheadAttention)
and activation_is_statically_quantized(qconfig)
@ -826,11 +831,17 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
for pattern in matched_patterns:
first_tuple = pattern[0]
last_getitem = pattern[-1]
assert first_tuple.op == "call_function" and first_tuple.target is tuple
assert (
    last_getitem.op == "call_function"
    and last_getitem.target == operator.getitem
)
if not (first_tuple.op == "call_function" and first_tuple.target is tuple):
    raise AssertionError(
        "first tuple node must be a call_function with target tuple"
    )
if not (
    last_getitem.op == "call_function"
    and last_getitem.target == operator.getitem
):
    raise AssertionError(
        "last getitem node must be a call_function with target operator.getitem"
    )
last_getitem_index = last_getitem.args[1]
new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index]
for user in list(last_getitem.users.keys()):
@ -847,7 +858,10 @@ def _get_observer_from_activation_post_process(
if isinstance(activation_post_process, ObserverBase):
return activation_post_process
else:
assert isinstance(activation_post_process, FakeQuantizeBase)
if not isinstance(activation_post_process, FakeQuantizeBase):
raise AssertionError(
"activation_post_process must be an ObserverBase or FakeQuantizeBase"
)
return activation_post_process.activation_post_process # type: ignore[return-value]
@ -966,7 +980,10 @@ def _qconfig_satisfies_dtype_config_constraints(
satisfies_constraints = True
if activation_post_process_ctr is not None:
activation_post_process = activation_post_process_ctr()
assert _is_activation_post_process(activation_post_process)
if not _is_activation_post_process(activation_post_process):
raise AssertionError(
"activation_post_process must be an activation post process"
)
# If dtypes don't match, don't check the activation_post_process and return True early
if activation_post_process.dtype != dtype_with_constraints.dtype:
return True