[Code Clean] Clean asserts in torch/ao/quantization/fx/* (#165420)

Replace assert statements with explicit if/raise patterns in:

- torch/ao/quantization/fx/* (177 errors)

Partially fixes #164878
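
For illustration, the before/after shape of the transformation, mirroring one of the scale checks changed below (a standalone sketch, not a literal excerpt from the diff):

    # Before: the check is silently stripped when Python runs with -O / PYTHONOPTIMIZE
    assert scale.numel() == 1, (
        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
    )

    # After: the check always runs and still raises AssertionError
    if scale.numel() != 1:
        raise AssertionError(
            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
        )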

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165420
Approved by: https://github.com/RohitRathore1, https://github.com/fffrog, https://github.com/albanD
Author: zhudada
Date: 2025-10-30 20:53:31 +00:00
Committed by: PyTorch MergeBot
Parent: df71b70727
Commit: 7692fa09cd
14 changed files with 567 additions and 307 deletions

View File

@@ -29,15 +29,17 @@ def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
         raise ValueError(f"Unsupported dtype: {dtype}")
     quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]
-    assert quant_min >= quant_min_lower_bound, (
-        "quant_min out of bound for dtype, "
-        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
-    )
-    assert quant_max <= quant_max_upper_bound, (
-        "quant_max out of bound for dtype, "
-        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
-    )
+    if quant_min < quant_min_lower_bound:
+        raise AssertionError(
+            "quant_min out of bound for dtype, "
+            f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
+        )
+    if quant_max > quant_max_upper_bound:
+        raise AssertionError(
+            "quant_max out of bound for dtype, "
+            f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
+        )
 quantized_decomposed_lib.define(
@@ -72,9 +74,10 @@ def quantize_per_tensor(
     """
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     inv_scale = 1.0 / scale
@@ -94,9 +97,10 @@ def quantize_per_tensor_meta(
 ) -> torch.Tensor:
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
     return torch.empty_like(input, dtype=dtype)
@@ -122,12 +126,14 @@ def quantize_per_tensor_tensor(
     Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
     scalar values
     """
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
     return quantize_per_tensor(
         input,
         scale.item(),
@@ -149,15 +155,18 @@ def quantize_per_tensor_tensor_meta(
 ) -> torch.Tensor:
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
     return torch.empty_like(input, dtype=dtype)
@@ -184,12 +193,14 @@ def quantize_per_tensor_tensor2(
     Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
     scalar values
     """
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
     return quantize_per_tensor(
         input,
         scale.item(),
@@ -266,9 +277,10 @@ def dequantize_per_tensor(
     Returns:
         dequantized float32 Tensor
     """
-    assert input.dtype == dtype, (
-        f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
-    )
+    if input.dtype != dtype:
+        raise AssertionError(
+            f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
+        )
     if out_dtype is None:
         out_dtype = torch.float32
     if dtype in _DTYPE_TO_QVALUE_BOUNDS:
@@ -322,12 +334,14 @@ def dequantize_per_tensor_tensor(
     Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
     scalar values
     """
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
     return dequantize_per_tensor(
         input,
         scale.item(),
@@ -352,13 +366,18 @@ def dequantize_per_tensor_tensor_meta(
 ) -> torch.Tensor:
     if out_dtype is None:
         out_dtype = torch.float32
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
-    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
+    if input.dtype != dtype:
+        raise AssertionError(
+            f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
+        )
     if dtype in _DTYPE_TO_QVALUE_BOUNDS:
         return torch.empty_like(input, dtype=out_dtype)
     else:
@@ -392,12 +411,14 @@ def dequantize_per_tensor_tensor2(
     Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
     scalar values
     """
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
     return dequantize_per_tensor(
         input,
         scale.item(),
@@ -448,16 +469,18 @@ def choose_qparams_tensor(
         scale (float): quantization parameter for the target quantized Tensor
         zero_point (int): quantization parameter for the target quantized Tensor
     """
-    assert input.dtype in [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ], (
-        f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
-    )
-    assert dtype in _DTYPE_TO_QVALUE_BOUNDS, (
-        f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
-    )
+    if input.dtype not in [
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ]:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
+        )
+    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
+        raise AssertionError(
+            f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
+        )
     validate_qmin_qmax(qmin, qmax)
     min_val, max_val = torch.aminmax(input)
@@ -500,16 +523,18 @@ def choose_qparams_symmetric_tensor(
         scale (float): quantization parameter for the target quantized Tensor
         zero_point (int): quantization parameter for the target quantized Tensor
     """
-    assert input.dtype in [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ], (
-        f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
-    )
-    assert dtype in _DTYPE_TO_QVALUE_BOUNDS, (
-        f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
-    )
+    if input.dtype not in [
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ]:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
+        )
+    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
+        raise AssertionError(
+            f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
+        )
     validate_qmin_qmax(qmin, qmax)
     min_val, max_val = torch.aminmax(input)
@@ -529,17 +554,18 @@ def choose_qparams_symmetric_tensor(
 def choose_qparams_tensor_meta(
     input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    assert input.dtype in [
-        torch.float32,
-        torch.float16,
-        torch.bfloat16,
-    ], (
-        f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
-    )
-    assert quant_min < quant_max, (
-        f"Expecting quant_min to be smaller than quant_max but received min: \
-        {quant_min} max: {quant_max}"
-    )
+    if input.dtype not in [
+        torch.float32,
+        torch.float16,
+        torch.bfloat16,
+    ]:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
+        )
+    if quant_min >= quant_max:
+        raise AssertionError(
+            f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}"
+        )
     return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(
         1, dtype=torch.int64, device=input.device
     )
@@ -598,10 +624,12 @@ def quantize_per_channel(
     """
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
-    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
+    if axis >= input.dim():
+        raise AssertionError(f"Expecting axis to be < {input.dim()}")
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     input, permute_axis_list = _permute_to_axis_zero(input, axis)
@@ -629,10 +657,12 @@ def quantize_per_channel_meta(
 ) -> torch.Tensor:
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
-    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
+    if axis >= input.dim():
+        raise AssertionError(f"Expecting axis to be < {input.dim()}")
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     return torch.empty_like(input, dtype=dtype)
@@ -687,12 +717,14 @@ def dequantize_per_channel(
     Returns:
         dequantized float32 Tensor
     """
-    assert input.dtype == dtype, (
-        f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
-    )
+    if input.dtype != dtype:
+        raise AssertionError(
+            f"Expecting input to have dtype: {dtype}, but got dtype: {input.dtype}"
+        )
     if out_dtype is None:
         out_dtype = torch.float32
-    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    if axis >= input.dim():
+        raise AssertionError(f"Expecting axis to be < {input.dim()}")
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     input, permute_axis_list = _permute_to_axis_zero(input, axis)
@@ -722,12 +754,14 @@ def dequantize_per_channel_meta(
     *,
     out_dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    assert input.dtype == dtype, (
-        f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
-    )
+    if input.dtype != dtype:
+        raise AssertionError(
+            f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+        )
     if out_dtype is None:
         out_dtype = torch.float32
-    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    if axis >= input.dim():
+        raise AssertionError(f"Expecting axis to be < {input.dim()}")
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     return torch.empty_like(input, dtype=out_dtype)
@@ -879,12 +913,12 @@ def choose_qparams_per_token_asymmetric_meta(
 def _per_token_quant_qparam_dim_check(input, scales, zero_points):
     num_tokens = math.prod(list(input.size())[:-1])
-    assert num_tokens == scales.numel(), (
-        f"num_tokens: {num_tokens} scales: {scales.size()}"
-    )
-    assert num_tokens == zero_points.numel(), (
-        f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
-    )
+    if num_tokens != scales.numel():
+        raise AssertionError(f"num_tokens: {num_tokens} scales: {scales.size()}")
+    if num_tokens != zero_points.numel():
+        raise AssertionError(
+            f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
+        )
 quantized_decomposed_lib.define(
@@ -1019,17 +1053,21 @@ def quantize_per_channel_group(
     dtype: torch.dtype,
     group_size=128,
 ):
-    assert group_size > 1
+    if group_size <= 1:
+        raise AssertionError("group_size must be > 1")
     # needed for GPTQ single column quantize
     if group_size > input.shape[-1] and scales.shape[-1] == 1:
         group_size = input.shape[-1]
-    assert input.shape[-1] % group_size == 0
-    assert input.dim() == 2
+    if input.shape[-1] % group_size != 0:
+        raise AssertionError("input.shape[-1] must be divisible by group_size")
+    if input.dim() != 2:
+        raise AssertionError("input must be 2-dimensional")
     # TODO: check for dtype, currently we can't express torch.int4 so it's omitted
     to_quant = input.reshape(-1, group_size)
-    assert torch.isnan(to_quant).sum() == 0
+    if torch.isnan(to_quant).sum() != 0:
+        raise AssertionError("to_quant must not contain NaNs")
     scales = scales.reshape(-1, 1)
     zero_points = zero_points.reshape(-1, 1)
@@ -1074,13 +1112,16 @@ def quantize_per_channel_group_meta(
         Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
         are not stored in the Tensor, we are storing them in function arguments instead
     """
-    assert group_size > 1
+    if group_size <= 1:
+        raise AssertionError("group_size must be > 1")
     # needed for GPTQ single column quantize
     if group_size > input.shape[-1] and scales.shape[-1] == 1:
         group_size = input.shape[-1]
-    assert input.shape[-1] % group_size == 0
-    assert input.dim() == 2
+    if input.shape[-1] % group_size != 0:
+        raise AssertionError("input.shape[-1] must be divisible by group_size")
+    if input.dim() != 2:
+        raise AssertionError("input must be 2-dimensional")
     return torch.empty_like(input, dtype=dtype)
@@ -1124,12 +1165,15 @@ def dequantize_per_channel_group(
         dequantized Tensor with dtype `output_dtype`
     """
-    assert group_size > 1
+    if group_size <= 1:
+        raise AssertionError("group_size must be > 1")
     # needed for GPTQ single column dequantize
     if group_size > w_int8.shape[-1] and scales.shape[-1] == 1:
         group_size = w_int8.shape[-1]
-    assert w_int8.shape[-1] % group_size == 0
-    assert w_int8.dim() == 2
+    if w_int8.shape[-1] % group_size != 0:
+        raise AssertionError("w_int8.shape[-1] must be divisible by group_size")
+    if w_int8.dim() != 2:
+        raise AssertionError("w_int8 must be 2-dimensional")
     w_int8_grouped = w_int8.reshape(-1, group_size)
     scales = scales.reshape(-1, 1)
@@ -1155,10 +1199,12 @@ class FakeQuantPerChannel(torch.autograd.Function):
             scales = scales.to(torch.float32)
         if zero_points.dtype != torch.int32:
             zero_points = zero_points.to(torch.int32)
-        assert input.dtype == torch.float32, (
-            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-        )
-        assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+        if input.dtype != torch.float32:
+            raise AssertionError(
+                f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+            )
+        if axis >= input.dim():
+            raise AssertionError(f"Expecting axis to be < {input.dim()}")
         broadcast_dims = list(range(axis)) + list(range(axis + 1, input.ndim))
         unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
         unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)

View File

@@ -90,7 +90,7 @@ class _InputEqualizationObserver(nn.Module):
         self.equalization_shape: list[int] = []
     def forward(self, x_orig):
-        if not (x_orig.ndim >= 2 and x_orig.ndim <= 5):
+        if x_orig.ndim < 2 or x_orig.ndim > 5:
             raise ValueError(
                 "InputEqualizationObserver only supports Linear and Conv layers"
             )
@@ -191,7 +191,7 @@ class _WeightEqualizationObserver(nn.Module):
         self.equalization_scale = torch.tensor(1)
     def forward(self, w_orig):
-        if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
+        if w_orig.ndim < 2 or w_orig.ndim > 5:
             raise ValueError(
                 "InputEqualizationObserver only supports Linear and Conv layers"
             )
@@ -232,7 +232,7 @@ def calculate_equalization_scale(
         )
         return torch.tensor(1)
-    if not (min_inputs.shape == min_weights.shape):
+    if min_inputs.shape != min_weights.shape:
         raise ValueError(
             "Input and Weight must have the same column dimension. "
             + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
@@ -355,30 +355,45 @@ def get_op_node_and_weight_eq_obs(
             op_node = user
             break
-    assert op_node is not None
+    if op_node is None:
+        raise AssertionError(
+            "Expected an operation node after the input equalization observer"
+        )
     if op_node.op == "call_module":
         # If the op_node is a nn.Linear layer, then it must have a
         # WeightEqualizationObserver configuration
         maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(
             model, "equalization_node_name_to_qconfig"
         )
-        assert maybe_equalization_node_name_to_config is not None
+        if maybe_equalization_node_name_to_config is None:
+            raise AssertionError(
+                "Expected 'equalization_node_name_to_qconfig' attribute in observed graph module"
+            )
         equalization_node_name_to_qconfig: dict[str, Any] = (
             maybe_equalization_node_name_to_config  # type: ignore[assignment]
         )
-        assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
+        if equalization_node_name_to_qconfig.get(op_node.name, None) is None:
+            raise AssertionError(
+                f"No equalization qconfig found for op node {op_node.name}"
+            )
        weight_eq_obs = equalization_node_name_to_qconfig.get(  # type: ignore[union-attr]
            op_node.name, None
        ).weight()
-        assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+        if not isinstance(weight_eq_obs, _WeightEqualizationObserver):
+            raise AssertionError(
+                "Expected weight equalization observer to be a _WeightEqualizationObserver"
+            )
         return op_node, weight_eq_obs
     elif op_node.op == "call_function":
         weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
         if weight_node is not None:
             weight_eq_obs = modules[str(weight_node.target)]
-            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+            if not isinstance(weight_eq_obs, _WeightEqualizationObserver):
+                raise AssertionError(
+                    "Expected weight equalization observer to be a _WeightEqualizationObserver"
                )
             return op_node, weight_eq_obs
     return None, None
@@ -388,17 +403,20 @@ def maybe_get_weight_eq_obs_node(
     op_node: Node, modules: dict[str, nn.Module]
 ) -> Optional[Node]:
     """Gets the weight equalization observer node if it exists."""
-    assert op_node.op == "call_function"
+    if op_node.op != "call_function":
+        raise AssertionError(
+            "maybe_get_weight_eq_obs_node expects a call_function op_node"
+        )
     for node_arg in op_node.args:
         if node_arg_is_weight(op_node, node_arg):
-            assert (
-                isinstance(node_arg, Node)
-                and node_arg.op == "call_module"
-                and isinstance(
-                    modules[str(node_arg.target)], _WeightEqualizationObserver
-                )
-            )
-            return node_arg
+            if (
+                isinstance(node_arg, Node)
+                and node_arg.op == "call_module"
+                and isinstance(
+                    modules[str(node_arg.target)], _WeightEqualizationObserver
+                )
+            ):
+                return node_arg
     return None
@@ -422,7 +440,8 @@ def maybe_get_next_input_eq_obs(
     the following equalization observer for linear2.
     """
-    assert node_supports_equalization(node, modules)
+    if not node_supports_equalization(node, modules):
+        raise AssertionError("Node does not support equalization")
     # Locate the following nn.ReLU or F.relu node if it exists
     maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
@@ -448,7 +467,10 @@ def maybe_get_next_input_eq_obs(
             return None
     maybe_eq_obs = modules[str(maybe_eq_obs_node)]
-    assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
+    if not isinstance(maybe_eq_obs, _InputEqualizationObserver):
+        raise AssertionError(
+            "Expected the following equalization observer to be an _InputEqualizationObserver"
+        )
     return maybe_eq_obs
@@ -480,10 +502,16 @@ def scale_input_observer(node: Node, modules: dict[str, nn.Module]) -> None:
         equalization observer
     """
     input_eq_obs = modules[str(node.target)]
-    assert isinstance(input_eq_obs, _InputEqualizationObserver)
+    if not isinstance(input_eq_obs, _InputEqualizationObserver):
+        raise AssertionError(
+            "Expected the module at node.target to be an _InputEqualizationObserver"
+        )
     input_quant_obs_node = node.args[0]
-    assert isinstance(input_quant_obs_node, Node)
+    if not isinstance(input_quant_obs_node, Node):
+        raise AssertionError(
+            "Expected the input quantization observer node to be a Node"
+        )
     input_quant_obs = modules[str(input_quant_obs_node.target)]
     if not isinstance(input_quant_obs, ObserverBase):
@@ -518,14 +546,19 @@ def scale_weight_node(
         op_module = modules[str(node.target)][0]  # type: ignore[index]
     else:
         op_module = modules[str(node.target)]
-    assert nn_module_supports_equalization(
-        op_module
-    ) or custom_module_supports_equalization(op_module)
+    if not (
+        nn_module_supports_equalization(op_module)
+        or custom_module_supports_equalization(op_module)
+    ):
+        raise AssertionError(
+            "Expected operation module to support equalization (nn or custom)"
+        )
     # Scale the weights for input-weight equalization
     # If the following layer needs to be equalized then we will multiply its scale
     weight = op_module.weight
-    assert isinstance(weight, torch.Tensor)
+    if not isinstance(weight, torch.Tensor):
+        raise AssertionError("Expected op_module.weight to be a torch.Tensor")
     # Scale the weights by the reciprocal of the equalization scale
     # Reshape the equalization scale so that we can multiply it to the weight along axis=1
@@ -547,7 +580,8 @@ def scale_weight_node(
     bias = op_module.bias
     if bias is None:
         return
-    assert isinstance(bias, torch.Tensor)
+    if not isinstance(bias, torch.Tensor):
+        raise AssertionError("Expected op_module.bias to be a torch.Tensor")
     # Reshape the equalization scale so that we can multiply it element-wise to the bias
     next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
@@ -581,15 +615,20 @@ def scale_weight_functional(
     weight_quant_obs_node = weight_eq_obs_node.args[0]
     if weight_quant_obs_node is None:
         return
-    assert isinstance(weight_quant_obs_node, Node) and isinstance(
-        modules[str(weight_quant_obs_node.target)], ObserverBase
-    )
+    if not (
+        isinstance(weight_quant_obs_node, Node)
+        and isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
+    ):
+        raise AssertionError(
+            "Expected weight_quant_obs_node to be a Node whose module is an ObserverBase"
+        )
     # Get the get_attr(weight) node
     weight_node = weight_quant_obs_node.args[0]
     if weight_node is None:
         return
-    assert isinstance(weight_node, Node) and weight_node.op == "get_attr"
+    if not (isinstance(weight_node, Node) and weight_node.op == "get_attr"):
+        raise AssertionError("Expected weight node to be a 'get_attr' Node")
     weight_parent_name, weight_name = _parent_name(weight_node.target)
     weight = getattr(modules[weight_parent_name], weight_name)
@@ -612,7 +651,8 @@ def scale_weight_functional(
     scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)
     setattr(modules[weight_parent_name], weight_name, scaled_weight)
-    assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
+    if not torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight):
+        raise AssertionError("Model buffer for weight does not match the scaled weight")
     # Multiply the bias element wise by the next equalization scale
     bias_node = None
@@ -644,10 +684,14 @@ def clear_weight_quant_obs_node(op_node: Node, modules: dict[str, nn.Module]) ->
     weight_quant_obs_node = weight_eq_obs_node.args[0]
     if weight_quant_obs_node is None:
         return
-    assert isinstance(weight_quant_obs_node, Node)
+    if not isinstance(weight_quant_obs_node, Node):
+        raise AssertionError("Expected weight_quant_obs_node to be a Node")
     weight_quant_obs = modules[str(weight_quant_obs_node.target)]
-    assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
+    if not isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase):
+        raise AssertionError(
+            "Expected the module at weight_quant_obs_node to be an ObserverBase"
+        )
     weight_quant_obs.reset_min_max_vals()  # type: ignore[operator]
@@ -682,7 +726,10 @@ def update_obs_for_equalization(
             modules[node.target], _InputEqualizationObserver
         ):
             input_eq_obs = modules[node.target]
-            assert isinstance(input_eq_obs, _InputEqualizationObserver)
+            if not isinstance(input_eq_obs, _InputEqualizationObserver):
+                raise AssertionError(
+                    "Expected module at node.target to be an _InputEqualizationObserver"
+                )
             op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)
             if op_node is None or weight_eq_obs is None:
@@ -693,7 +740,10 @@ def update_obs_for_equalization(
             # been created
             if fused_module_supports_equalization(modules[str(op_node.target)]):
                 module = modules[str(op_node.target)][0]  # type: ignore[index]
-                assert nn_module_supports_equalization(module)
+                if not nn_module_supports_equalization(module):
+                    raise AssertionError(
+                        "Expected fused module to support equalization"
+                    )
                 weight_eq_obs(module.weight)
             else:
                 weight_eq_obs(modules[str(op_node.target)].weight)
@@ -810,7 +860,10 @@ def convert_eq_obs(
         elif weight_eq_obs_dict.get(node.name, None) is not None:
             weight_eq_obs = weight_eq_obs_dict.get(node.name)
-            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+            if not isinstance(weight_eq_obs, _WeightEqualizationObserver):
+                raise AssertionError(
+                    "Expected weight equalization observer to be a _WeightEqualizationObserver"
+                )
             equalization_scale = weight_eq_obs.equalization_scale
             if (
@@ -844,9 +897,12 @@ def convert_eq_obs(
             weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
             if weight_eq_obs_node is None:
                 return
-            assert isinstance(
-                modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver
-            )
+            if not isinstance(
+                modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver
+            ):
+                raise AssertionError(
+                    "Expected weight equalization observer to be a _WeightEqualizationObserver"
+                )
             # Clear the quantization observer's min/max values so that they
             # can get updated later based on the new scale values

View File

@@ -585,7 +585,8 @@ def _match_static_pattern(
         return SKIP_LOWERING_VALUE
     q_node = node
     ref_node = q_node.args[0]
-    assert isinstance(ref_node, Node)
+    if not isinstance(ref_node, Node):
+        raise AssertionError("Expected the reference node to be a torch.fx Node")
     # Handle cases where the node is wrapped in a ReLU
     if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or (
@@ -593,7 +594,10 @@ def _match_static_pattern(
     ):
         relu_node = ref_node
         ref_node = relu_node.args[0]
-        assert isinstance(ref_node, Node)
+        if not isinstance(ref_node, Node):
+            raise AssertionError(
+                "Expected the reference node after ReLU to be a torch.fx Node"
+            )
     else:
         relu_node = None
     if should_skip_lowering(ref_node, qconfig_map):
@@ -616,9 +620,10 @@ def _match_static_pattern(
     # (2) There must be at least one dequantize node
     matched_dequantize = False
     for i in dequantize_node_arg_indices:
-        assert i < len(ref_node.args), (
-            f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
-        )
+        if i >= len(ref_node.args):
+            raise AssertionError(
+                f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
+            )
         arg = ref_node.args[i]
         if is_dequantize_node(arg):
             matched_dequantize = True
@@ -660,7 +665,8 @@ def _match_static_pattern_with_two_inputs(
         return SKIP_LOWERING_VALUE
     q_node = node
     ref_node = q_node.args[0]
-    assert isinstance(ref_node, Node)
+    if not isinstance(ref_node, Node):
+        raise AssertionError("Expected the reference node to be a torch.fx Node")
     if should_skip_lowering(ref_node, qconfig_map):
         return SKIP_LOWERING_VALUE
@@ -711,13 +717,21 @@ def _lower_static_weighted_ref_module(
         )
         if q_node is None:
             continue
-        assert ref_node is not None
+        if ref_node is None:
+            raise AssertionError(
+                "Expected a reference node when matching static pattern"
+            )
         (_, scale_node, zero_point_node, _) = q_node.args
         ref_module = _get_module(ref_node, modules)
         ref_class = type(ref_module)
-        assert isinstance(scale_node, Node)
-        assert isinstance(zero_point_node, Node)
-        assert issubclass(ref_class, nn.Module)
+        if not isinstance(scale_node, Node):
+            raise AssertionError("Expected scale_node to be a Node")
+        if not isinstance(zero_point_node, Node):
+            raise AssertionError("Expected zero_point_node to be a Node")
+        if not issubclass(ref_class, nn.Module):
+            raise AssertionError(
+                "Expected reference module class to be a subclass of nn.Module"
+            )
         # Step 1: Change this pattern to use the corresponding quantized module
         # For fused modules, we also check whether the inner module is a reference module
@@ -736,9 +750,11 @@ def _lower_static_weighted_ref_module(
         setattr(modules[parent_name], module_name, q_module)
         # Step 2: Reroute around dq_node, and remove q_node and its args
-        assert len(ref_node.args) == 1
+        if len(ref_node.args) != 1:
+            raise AssertionError("Expected reference node to have exactly 1 arg")
         dq_node = ref_node.args[0]
-        assert isinstance(dq_node, Node)
+        if not isinstance(dq_node, Node):
+            raise AssertionError("Expected dq_node to be a Node")
         ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
         q_node.replace_all_uses_with(ref_node)
         model.graph.erase_node(q_node)
@@ -771,13 +787,21 @@ def _lower_static_weighted_ref_module_with_two_inputs(
         )
         if q_node is None:
             continue
-        assert ref_node is not None
+        if ref_node is None:
+            raise AssertionError(
+                "Expected a reference node when matching static pattern with two inputs"
+            )
         (_, scale_node, zero_point_node, _) = q_node.args
         ref_module = _get_module(ref_node, modules)
         ref_class = type(ref_module)
-        assert isinstance(scale_node, Node)
-        assert isinstance(zero_point_node, Node)
-        assert issubclass(ref_class, nn.Module)
+        if not isinstance(scale_node, Node):
+            raise AssertionError("Expected scale_node to be a Node")
+        if not isinstance(zero_point_node, Node):
+            raise AssertionError("Expected zero_point_node to be a Node")
+        if not issubclass(ref_class, nn.Module):
+            raise AssertionError(
+                "Expected reference module class to be a subclass of nn.Module"
+            )
         # Step 1: Change this pattern to use the corresponding quantized module
         # For fused modules, we also check whether the inner module is a reference module
@@ -798,12 +822,14 @@ def _lower_static_weighted_ref_module_with_two_inputs(
         setattr(modules[parent_name], module_name, q_module)
         # Step 2: Reroute around dq_node, and remove q_node and its args
-        assert len(ref_node.args) == 2
+        if len(ref_node.args) != 2:
+            raise AssertionError("Expected reference node to have exactly 2 args")
         for arg in ref_node.args:
             if not is_dequantize_node(arg):
                 continue
             dq_node = arg
-            assert isinstance(dq_node, Node)
+            if not isinstance(dq_node, Node):
+                raise AssertionError("Expected dq_node to be a Node")
             ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
         q_node.replace_all_uses_with(ref_node)
@@ -900,14 +926,21 @@ def _lower_static_weighted_ref_functional(
         )
         if q_node is None:
             continue
-        assert func_node is not None
+        if func_node is None:
+            raise AssertionError(
+                "Expected a function node when matching static functional pattern"
+            )
         (_, output_scale_node, output_zp_node, _) = q_node.args
         (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
-        assert isinstance(output_zp_node, Node)
-        assert isinstance(input_dq_node, Node)
-        assert isinstance(weight_dq_node, Node)
+        if not isinstance(output_zp_node, Node):
+            raise AssertionError("Expected output_zp_node to be a Node")
+        if not isinstance(input_dq_node, Node):
+            raise AssertionError("Expected input_dq_node to be a Node")
+        if not isinstance(weight_dq_node, Node):
+            raise AssertionError("Expected weight_dq_node to be a Node")
         quantized_weight = weight_dq_node.args[0]
-        assert isinstance(quantized_weight, Node)
+        if not isinstance(quantized_weight, Node):
+            raise AssertionError("Expected quantized_weight to be a Node")
         if quantized_weight.op != "call_function" or quantized_weight.target not in (
             torch.quantize_per_tensor,
             torch.quantize_per_channel,
@@ -1135,7 +1168,10 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfi
         )
         if q_node is None:
             continue
-        assert bop_node is not None
+        if bop_node is None:
+            raise AssertionError(
+                "Expected a binary op node when matching quantized binary op pattern"
+            )
         (_, scale_node, zero_point_node, _) = q_node.args
         # Step 1: Remove dequant nodes
@@ -1144,14 +1180,21 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfi
             if not is_dequantize_node(arg):
                 continue
             dq_node = arg
-            assert isinstance(dq_node, Node)
+            if not isinstance(dq_node, Node):
+                raise AssertionError("Expected dq_node to be a Node")
             dn_input = dq_node.args[0]
             bop_node.replace_input_with(dq_node, dn_input)  # type: ignore[arg-type]
             num_dq_nodes += 1
-        assert num_dq_nodes > 0
+        if num_dq_nodes <= 0:
+            raise AssertionError(
+                "Expected at least one dequantize node in binary op args"
+            )
         # Step 2: Swap binary op to quantized binary op
-        assert bop_node.target in QBIN_OP_MAPPING
+        if bop_node.target not in QBIN_OP_MAPPING:
+            raise AssertionError(
+                f"Unsupported binary op {bop_node.target} for lowering"
+            )
         binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING
         qbin_op = binop_to_qbinop[bop_node.target]
         # prepare the args for quantized binary op
@@ -1188,7 +1231,8 @@ def special_pattern_replacement(model: GraphModule):
             and len(q_node.args) == 2
             and q_node.args[1] == torch.float16
         )
-        if not (is_quantize or is_to_fp16):
+        # Only continue when neither quantize nor to_fp16
+        if not is_quantize and not is_to_fp16:
             continue
         ref_node = q_node.args[0]
         # get output scale/zero_point/dtype from the quantize node
@@ -1217,13 +1261,17 @@ def special_pattern_replacement(model: GraphModule):
         )
         if not (is_call_module or is_call_function or is_call_method):
             continue
-        assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0
+        if len(ref_node.args) <= 0 and len(ref_node.kwargs) <= 0:
+            raise AssertionError("Expected ref_node to have args or kwargs")
         dq_node_or_nodes = (
             ref_node.args[0]
             if len(ref_node.args) > 0
             else next(iter(ref_node.kwargs.values()))
         )
-        assert isinstance(dq_node_or_nodes, (Node, tuple, list))
+        if not isinstance(dq_node_or_nodes, (Node, tuple, list)):
+            raise AssertionError(
+                "Expected dq_node_or_nodes to be a Node, tuple, or list"
+            )
         is_dequantize = False
         if isinstance(dq_node_or_nodes, Node):
             is_dequantize = (

View File

@@ -362,11 +362,15 @@ class PerChannelDetector(DetectorBase):
                 # assert statement for MyPy
                 q_config_file = module.qconfig
-                assert isinstance(q_config_file, QConfig)
+                if not isinstance(q_config_file, QConfig):
+                    raise AssertionError("module.qconfig must be a QConfig")
                 # this object should either be fake quant or observer
                 q_or_s_obj = module.qconfig.weight.p.func()
-                assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))
+                if not isinstance(q_or_s_obj, (FakeQuantize, ObserverBase)):
+                    raise AssertionError(
+                        "module.qconfig.weight must be a FakeQuantize or ObserverBase"
+                    )
                 per_channel_used = False  # will be true if found in qconfig
@@ -1160,9 +1164,10 @@ class InputWeightEqualizationDetector(DetectorBase):
         input_channels = len(input_ratio)
         if weight_channels != input_channels:
             # we try to replicate
-            assert input_channels % weight_channels == 0, (
-                "input channels should be divisible by weight channels."
-            )
+            if input_channels % weight_channels != 0:
+                raise AssertionError(
+                    "input channels should be divisible by weight channels."
+                )
             # get replication factor
             rep_factor: int = input_channels // weight_channels
@@ -1418,11 +1423,15 @@ class OutlierDetector(DetectorBase):
         self.ratio_threshold = ratio_threshold
         # make sure passed in percentile is valid
-        assert reference_percentile >= 0 and reference_percentile <= 1
-        assert (
-            fraction_batches_used_threshold >= 0
-            and fraction_batches_used_threshold <= 1
-        )
+        if reference_percentile < 0 or reference_percentile > 1:
+            raise AssertionError("reference_percentile must be between 0 and 1")
+        if not (
+            fraction_batches_used_threshold >= 0
+            and fraction_batches_used_threshold <= 1
+        ):
+            raise AssertionError(
+                "fraction_batches_used_threshold must be between 0 and 1"
+            )
         self.reference_percentile = reference_percentile
         self.fraction_batches_used_threshold = fraction_batches_used_threshold
         self.ch_axis = ch_axis

View File

@@ -261,7 +261,8 @@ class ModelReport:
             raise ValueError("The node_fqn is was not found within the module.")
         # assert for MyPy
-        assert isinstance(node_to_return, torch.fx.node.Node)
+        if not isinstance(node_to_return, torch.fx.node.Node):
+            raise AssertionError("node_to_return must be a torch.fx.node.Node")
         return node_to_return

View File

@ -112,8 +112,12 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
or quantize_per_channel and dequantize_per_channel or quantize_per_channel and dequantize_per_channel
""" """
graph = model.graph graph = model.graph
assert modules is not None if modules is None:
assert isinstance(node.target, str) raise AssertionError("modules must not be None")
if not isinstance(node.target, str):
raise AssertionError(
f"Expected node.target to be a str, but got {type(node.target)}"
)
module_path, prefix = _get_module_path_and_prefix( module_path, prefix = _get_module_path_and_prefix(
node, node_name_to_scope, node_name_to_qconfig node, node_name_to_scope, node_name_to_qconfig
) )
@ -260,10 +264,10 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
# and that can be done after we remove reduce_range flag # and that can be done after we remove reduce_range flag
# 1. extract qparams from activation_post_process module # 1. extract qparams from activation_post_process module
dtype_ = to_underlying_dtype(dtype) dtype_ = to_underlying_dtype(dtype)
assert dtype_ in [torch.uint8, torch.int8], ( if dtype_ not in [torch.uint8, torch.int8]:
"only uint8 and int8 are supported in reference flow for " raise AssertionError(
"dynamic quantization right now" "only uint8 and int8 are supported in reference flow for dynamic quantization right now"
) )
quant_min = activation_post_process.quant_min # type: ignore[attr-defined] quant_min = activation_post_process.quant_min # type: ignore[attr-defined]
quant_max = activation_post_process.quant_max # type: ignore[attr-defined] quant_max = activation_post_process.quant_max # type: ignore[attr-defined]
qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined] qscheme = getattr(activation_post_process, "qscheme", torch.per_tensor_affine) # type: ignore[attr-defined]
@ -379,8 +383,12 @@ def _replace_observer_with_quantize_dequantize_node(
After: After:
... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ... ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
""" """
assert modules is not None if modules is None:
assert isinstance(node.target, str) raise AssertionError("modules must not be None")
if not isinstance(node.target, str):
raise AssertionError(
f"Expected node.target to be a str, but got {type(node.target)}"
)
graph = model.graph graph = model.graph
module_path, prefix = _get_module_path_and_prefix( module_path, prefix = _get_module_path_and_prefix(
node, node_name_to_scope, node_name_to_qconfig node, node_name_to_scope, node_name_to_qconfig
@ -521,9 +529,10 @@ def _replace_observer_or_dequant_stub_with_dequantize_node(
node: Node, graph: Graph node: Node, graph: Graph
) -> None: ) -> None:
call_custom_module_node = node.args[0] call_custom_module_node = node.args[0]
assert isinstance(call_custom_module_node, Node), ( if not isinstance(call_custom_module_node, Node):
f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}" raise AssertionError(
) f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
)
node.replace_all_uses_with(call_custom_module_node) node.replace_all_uses_with(call_custom_module_node)
graph.erase_node(node) graph.erase_node(node)
_insert_dequantize_node(call_custom_module_node, graph) _insert_dequantize_node(call_custom_module_node, graph)
@ -617,9 +626,10 @@ def _get_module_path_and_prefix(
# operator (they can be the same) # operator (they can be the same)
# this flag identifies if the observer is inserted only because the observed node is # this flag identifies if the observer is inserted only because the observed node is
# the input of the next operator # the input of the next operator
assert isinstance(observed_node, Node), ( if not isinstance(observed_node, Node):
f"Expecting observed node to be a Node, but got {observed_node}" raise AssertionError(
) f"Expecting observed node to be a Node, but got {observed_node}"
)
is_input_observer_only = ( is_input_observer_only = (
node_name_to_qconfig[observed_node.name] is None node_name_to_qconfig[observed_node.name] is None
if observed_node.name in node_name_to_qconfig if observed_node.name in node_name_to_qconfig
@ -727,8 +737,10 @@ def convert_standalone_module(
"_observed_graph_module_attrs" "_observed_graph_module_attrs"
].standalone_module_output_quantized_idxs ].standalone_module_output_quantized_idxs
if len(sm_output_quantized_idxs) > 0: if len(sm_output_quantized_idxs) > 0:
assert sm_output_quantized_idxs[0] == 0, "Currently only quantized" if sm_output_quantized_idxs[0] != 0:
"output idxs = [0] is supported" raise AssertionError(
"Currently only quantized output idxs = [0] is supported"
)
# if it's non-empty, then it means the output is kept in quantized form # if it's non-empty, then it means the output is kept in quantized form
# we'll just add a dequantize node after this node # we'll just add a dequantize node after this node
@ -882,9 +894,10 @@ def convert_weighted_module(
ref_qmodule_cls = root_module_to_quantized_reference_module.get( ref_qmodule_cls = root_module_to_quantized_reference_module.get(
type_before_parametrizations(float_module), None type_before_parametrizations(float_module), None
) )
assert ref_qmodule_cls is not None, ( if ref_qmodule_cls is None:
f"No reference quantized module class configured for {type_before_parametrizations(float_module)}" raise AssertionError(
) f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
)
ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined] ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict) # type: ignore[attr-defined]
if fused_module is not None: if fused_module is not None:
fused_module[0] = ref_qmodule # type: ignore[operator] fused_module[0] = ref_qmodule # type: ignore[operator]
@ -904,9 +917,10 @@ def _remove_previous_dequantize_in_custom_module(
\\ - dequantize \\ - dequantize
""" """
# expecting the input node for a custom module node to be a Node # expecting the input node for a custom module node to be a Node
assert isinstance(prev_node, Node), ( if not isinstance(prev_node, Node):
f"Expecting the argument for custom module node to be a Node, but got {prev_node}" raise AssertionError(
) f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
)
if prev_node.op == "call_method" and prev_node.target == "dequantize": if prev_node.op == "call_method" and prev_node.target == "dequantize":
node.replace_input_with(prev_node, prev_node.args[0]) node.replace_input_with(prev_node, prev_node.args[0])
# Remove the dequantize node if it doesn't have other users # Remove the dequantize node if it doesn't have other users
@ -952,15 +966,21 @@ def convert_custom_module(
if _is_custom_module_lstm(node, modules): if _is_custom_module_lstm(node, modules):
# The inputs are tuples in the form (input, (hidden0, hidden1)) # The inputs are tuples in the form (input, (hidden0, hidden1))
# Ensure all three input nodes are quantized # Ensure all three input nodes are quantized
assert ( if not (
len(node.args) == 2 len(node.args) == 2
and isinstance(node.args[1], tuple) and isinstance(node.args[1], tuple)
and len(node.args[1]) == 2 and len(node.args[1]) == 2
) ):
raise AssertionError(
"Expected LSTM custom module inputs to be (input, (hidden0, hidden1))"
)
(inputs, (hidden0, hidden1)) = node.args # type: ignore[misc] (inputs, (hidden0, hidden1)) = node.args # type: ignore[misc]
assert isinstance(inputs, Node) if not isinstance(inputs, Node):
assert isinstance(hidden0, Node) raise AssertionError("Expected inputs to be a Node")
assert isinstance(hidden1, Node) if not isinstance(hidden0, Node):
raise AssertionError("Expected hidden0 to be a Node")
if not isinstance(hidden1, Node):
raise AssertionError("Expected hidden1 to be a Node")
_remove_previous_dequantize_in_custom_module(node, inputs, graph) _remove_previous_dequantize_in_custom_module(node, inputs, graph)
_remove_previous_dequantize_in_custom_module(node, hidden0, graph) _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
_remove_previous_dequantize_in_custom_module(node, hidden1, graph) _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
@@ -971,22 +991,32 @@ def convert_custom_module(
# to the module. # to the module.
# Additional handling is yet to be implemented for the outputs, similar # Additional handling is yet to be implemented for the outputs, similar
# to LSTM custom module # to LSTM custom module
assert len(node.args) == 3 if len(node.args) != 3:
raise AssertionError(
"Expected MHA custom module inputs to be (query, key, value)"
)
query, key, value = node.args query, key, value = node.args
assert isinstance(query, Node) if not isinstance(query, Node):
assert isinstance(key, Node) raise AssertionError("Expected query to be a Node")
assert isinstance(value, Node) if not isinstance(key, Node):
raise AssertionError("Expected key to be a Node")
if not isinstance(value, Node):
raise AssertionError("Expected value to be a Node")
_remove_previous_dequantize_in_custom_module(node, query, graph) _remove_previous_dequantize_in_custom_module(node, query, graph)
_remove_previous_dequantize_in_custom_module(node, key, graph) _remove_previous_dequantize_in_custom_module(node, key, graph)
_remove_previous_dequantize_in_custom_module(node, value, graph) _remove_previous_dequantize_in_custom_module(node, value, graph)
else: else:
# remove the previous dequant node to ensure the inputs are quantized # remove the previous dequant node to ensure the inputs are quantized
arg = node.args[0] arg = node.args[0]
assert isinstance(arg, Node) if not isinstance(arg, Node):
raise AssertionError("Expected arg to be a Node")
_remove_previous_dequantize_in_custom_module(node, arg, graph) _remove_previous_dequantize_in_custom_module(node, arg, graph)
# absorb the following observer into the module conversion # absorb the following observer into the module conversion
activation_post_process = _maybe_get_observer_for_node(node, modules) activation_post_process = _maybe_get_observer_for_node(node, modules)
assert activation_post_process is not None if activation_post_process is None:
raise AssertionError(
"Expected activation_post_process to be present for observed custom module"
)
observed_custom_module.activation_post_process = activation_post_process observed_custom_module.activation_post_process = activation_post_process
# swap the observed custom module to quantized custom module # swap the observed custom module to quantized custom module
@@ -1061,7 +1091,8 @@ def convert(
QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
) )
qconfig_mapping = copy.deepcopy(qconfig_mapping) qconfig_mapping = copy.deepcopy(qconfig_mapping)
assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping) if not (qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)):
raise AssertionError("qconfig_mapping must be None or a QConfigMapping")
if isinstance(backend_config, dict): if isinstance(backend_config, dict):
warnings.warn( warnings.warn(
@@ -1075,7 +1106,8 @@ def convert(
if backend_config is None: if backend_config is None:
backend_config = get_native_backend_config() backend_config = get_native_backend_config()
assert _is_observed_module(model), "incoming model must be produced by prepare_fx" if not _is_observed_module(model):
raise AssertionError("incoming model must be produced by prepare_fx")
observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"] observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
node_name_to_scope: dict[str, tuple[str, type]] = ( node_name_to_scope: dict[str, tuple[str, type]] = (
observed_graph_module_attrs.node_name_to_scope observed_graph_module_attrs.node_name_to_scope
@@ -1121,14 +1153,16 @@ def convert(
# all the values either match what was set in prepare node_name_to_qconfig # all the values either match what was set in prepare node_name_to_qconfig
# or are set to None in the convert_node_name_to_qconfig. # or are set to None in the convert_node_name_to_qconfig.
for k, v in node_name_to_qconfig.items(): for k, v in node_name_to_qconfig.items():
assert k in convert_node_name_to_qconfig, ( if k not in convert_node_name_to_qconfig:
f"Expected key {k} in convert node_name_to_qconfig" raise AssertionError(
) f"Expected key {k} in convert node_name_to_qconfig"
if convert_node_name_to_qconfig[k] is not None:
assert qconfig_equals(v, convert_node_name_to_qconfig[k]), (
f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
) )
if convert_node_name_to_qconfig[k] is not None:
if not qconfig_equals(v, convert_node_name_to_qconfig[k]):
raise AssertionError(
f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
)
node_name_to_qconfig = convert_node_name_to_qconfig node_name_to_qconfig = convert_node_name_to_qconfig
custom_module_classes = get_custom_module_class_keys( custom_module_classes = get_custom_module_class_keys(
@@ -1201,7 +1235,10 @@ def convert(
) )
elif node.op == "call_module": elif node.op == "call_module":
mod = _get_module(node, modules) mod = _get_module(node, modules)
assert mod is not None if mod is None:
raise AssertionError(
"Expected module for call_module node to be present in modules mapping"
)
if _is_activation_post_process(mod): if _is_activation_post_process(mod):
observed_node = node.args[0] observed_node = node.args[0]
if observed_node in statically_quantized_custom_module_nodes: if observed_node in statically_quantized_custom_module_nodes:
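A note on why the bare assert statements above are being replaced: Python strips assert statements entirely under the -O flag (when __debug__ is False), so the old checks could silently vanish in optimized builds, while the explicit if/raise form always runs. The old code in convert_standalone_module also shows a second pitfall the patch removes: the message was split across two statements, so the bare string on the following line was a no-op and never part of the assert message. A minimal sketch of the difference, using made-up helper names rather than anything from the patch:

# Illustrative sketch only; check_with_assert / check_with_raise are hypothetical names.
def check_with_assert(idxs):
    # Removed entirely when Python runs with -O (__debug__ is False).
    assert idxs[0] == 0, "Currently only quantized output idxs = [0] is supported"

def check_with_raise(idxs):
    # Always executed, regardless of interpreter optimization flags.
    if idxs[0] != 0:
        raise AssertionError(
            "Currently only quantized output idxs = [0] is supported"
        )

try:
    check_with_raise([1])
except AssertionError as e:
    print(f"caught: {e}")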
@@ -102,7 +102,10 @@ def fuse(
else: else:
node_subpattern = None node_subpattern = None
if maybe_last_node is node: if maybe_last_node is node:
assert obj is not None if obj is None:
raise AssertionError(
"fuse handler object must not be None for matched root node"
)
root_node_getter = fusion_pattern_to_root_node_getter.get( root_node_getter = fusion_pattern_to_root_node_getter.get(
pattern, default_root_node_getter pattern, default_root_node_getter
) )
@@ -65,9 +65,8 @@ class DefaultFuseHandler(FuseHandler):
fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]], fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]],
is_qat: bool, is_qat: bool,
) -> Node: ) -> Node:
assert root_node.op == "call_module", ( if root_node.op != "call_module":
"Expecting module node to be a call_module Node" raise AssertionError("Expecting module node to be a call_module Node")
)
root_module = named_modules[str(root_node.target)] root_module = named_modules[str(root_node.target)]
def get_modules(pattern): def get_modules(pattern):
@@ -109,7 +109,8 @@ def _get_lstm_with_individually_observed_parts(
# TODO: maybe make this work for layer_bw as well # TODO: maybe make this work for layer_bw as well
for layer in quantizable_lstm.layers: for layer in quantizable_lstm.layers:
cell = layer.layer_fw.cell # type: ignore[union-attr] cell = layer.layer_fw.cell # type: ignore[union-attr]
assert isinstance(cell, torch.nn.Module), "cell should be a nn.Module" if not isinstance(cell, torch.nn.Module):
raise AssertionError("cell should be a nn.Module")
cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config) cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
# HACK: Manually replace the activation_post_process following these ops. # HACK: Manually replace the activation_post_process following these ops.
# This is needed for FloatFunctional ops because there is currently no way # This is needed for FloatFunctional ops because there is currently no way
@@ -150,7 +151,8 @@ def _get_lstm_with_individually_observed_parts(
continue continue
if op_index not in op_index_to_activation_post_process_ctr: if op_index not in op_index_to_activation_post_process_ctr:
continue continue
assert len(node.users) == 1 if len(node.users) != 1:
raise AssertionError("expected exactly one user for the node")
activation_post_process_name = next(iter(node.users.keys())).name activation_post_process_name = next(iter(node.users.keys())).name
activation_post_process_ctr = op_index_to_activation_post_process_ctr[ activation_post_process_ctr = op_index_to_activation_post_process_ctr[
op_index op_index
@@ -195,7 +197,8 @@ def _get_reference_quantized_lstm_module(
for i, layer in enumerate(quantized_lstm.layers): for i, layer in enumerate(quantized_lstm.layers):
cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell) # type: ignore[union-attr] cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell) # type: ignore[union-attr]
cell = convert_to_reference_fx(cell, backend_config=backend_config) # type: ignore[arg-type] cell = convert_to_reference_fx(cell, backend_config=backend_config) # type: ignore[arg-type]
assert isinstance(cell, torch.fx.GraphModule) if not isinstance(cell, torch.fx.GraphModule):
raise AssertionError("cell must be converted to a torch.fx.GraphModule")
# HACK: Manually remove input quantize nodes and output dequantize nodes, # HACK: Manually remove input quantize nodes and output dequantize nodes,
# since custom modules expect quint8 inputs and outputs for now. Note that # since custom modules expect quint8 inputs and outputs for now. Note that
# this functionality is supposedly handled through PrepareCustomConfig's # this functionality is supposedly handled through PrepareCustomConfig's
@@ -33,7 +33,8 @@ def _is_match(modules, node, pattern, max_uses=sys.maxsize):
if isinstance(pattern, tuple): if isinstance(pattern, tuple):
self_match, *arg_matches = pattern self_match, *arg_matches = pattern
if self_match is getattr: if self_match is getattr:
assert len(pattern) == 2, "Expecting getattr pattern to have two elements" if len(pattern) != 2:
raise AssertionError("Expecting getattr pattern to have two elements")
arg_matches = [] arg_matches = []
else: else:
self_match = pattern self_match = pattern
@@ -190,7 +191,8 @@ def _find_matches(
break break
# add custom module instances to the match result # add custom module instances to the match result
assert modules is not None if modules is None:
raise AssertionError("modules must not be None")
for node in graph.nodes: for node in graph.nodes:
if ( if (
node.op == "call_module" node.op == "call_module"
@@ -204,7 +206,8 @@ def _find_matches(
) )
def is_standalone_module(node_target: str, modules: dict[str, torch.nn.Module]): def is_standalone_module(node_target: str, modules: dict[str, torch.nn.Module]):
assert modules is not None if modules is None:
raise AssertionError("modules must not be None")
return ( return (
node_target in standalone_module_names node_target in standalone_module_names
or type(modules[node_target]) # type: ignore[operator] or type(modules[node_target]) # type: ignore[operator]
@@ -149,10 +149,11 @@ def _create_obs_or_fq_from_qspec(
return None return None
if isinstance(quantization_spec, SharedQuantizationSpec): if isinstance(quantization_spec, SharedQuantizationSpec):
edge_or_node = quantization_spec.edge_or_node edge_or_node = quantization_spec.edge_or_node
assert edge_or_node in obs_or_fq_map, ( if edge_or_node not in obs_or_fq_map:
"please make sure only refer to edge or node that has " raise AssertionError(
f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}" "please make sure only refer to edge or node that has "
) f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
)
return obs_or_fq_map[edge_or_node] return obs_or_fq_map[edge_or_node]
elif isinstance(quantization_spec, DerivedQuantizationSpec): elif isinstance(quantization_spec, DerivedQuantizationSpec):
# can't use asdict, so not calling get_observer_kwargs here # can't use asdict, so not calling get_observer_kwargs here
@@ -177,7 +178,8 @@ def _create_obs_or_fq_from_qspec(
else: else:
return observer_ctr() return observer_ctr()
assert isinstance(quantization_spec, QuantizationSpec) if not isinstance(quantization_spec, QuantizationSpec):
raise AssertionError("quantization_spec must be a QuantizationSpec")
observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
kwargs = _get_observer_kwargs(quantization_spec) kwargs = _get_observer_kwargs(quantization_spec)
kwargs.pop("observer_or_fake_quant_ctr") kwargs.pop("observer_or_fake_quant_ctr")
@@ -214,10 +216,14 @@ def _needs_obs_or_fq(
# need to insert placeholder observer for dynamic quantization so that it can # need to insert placeholder observer for dynamic quantization so that it can
# be converted to choose_qparams -> q -> dq in convert step # be converted to choose_qparams -> q -> dq in convert step
if cur_target_is_dynamic: if cur_target_is_dynamic:
assert cur_target_dtype in _OBS_DTYPE_LIST, ( if cur_target_dtype not in _OBS_DTYPE_LIST:
f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}" raise AssertionError(
) f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}"
assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST )
if prev_output_dtype in _DO_NOT_OBS_DTYPE_LIST:
raise AssertionError(
"prev_output_dtype must not be in _DO_NOT_OBS_DTYPE_LIST"
)
return is_zeroth_arg return is_zeroth_arg
if reuse_input_obs_or_fq: if reuse_input_obs_or_fq:
return False return False
@@ -398,7 +404,8 @@ def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
""" """
if backend_config is None or pattern is None: if backend_config is None or pattern is None:
return True return True
assert matched_node_pattern is not None and len(matched_node_pattern) >= 1 if matched_node_pattern is None or len(matched_node_pattern) < 1:
raise AssertionError("matched_node_pattern must be non-empty")
pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config) pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, []) dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config) pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
@@ -535,7 +542,8 @@ def _set_target_dtype_info_for_matched_node_pattern(
# other types of matched object, e.g. int, float literals, are ignored # other types of matched object, e.g. int, float literals, are ignored
elif isinstance(matched_node_pattern, Node): elif isinstance(matched_node_pattern, Node):
# for pyre # for pyre
assert isinstance(matched_node_pattern, Node) if not isinstance(matched_node_pattern, Node):
raise AssertionError("matched_node_pattern must be a Node")
node = matched_node_pattern node = matched_node_pattern
if node in processed_nodes: if node in processed_nodes:
return return
@@ -674,7 +682,8 @@ def _get_output_act_obs_or_fq(
We are assuming that the observers are inserted correctly, and the dtype for We are assuming that the observers are inserted correctly, and the dtype for
argument in quantized graph will match what is specified by the qconfig argument in quantized graph will match what is specified by the qconfig
""" """
assert isinstance(arg, Node) if not isinstance(arg, Node):
raise AssertionError("arg must be a Node")
if "quantization_annotation" in arg.meta: if "quantization_annotation" in arg.meta:
return _create_obs_or_fq_from_qspec( return _create_obs_or_fq_from_qspec(
arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
@@ -698,9 +707,8 @@ def _get_output_act_obs_or_fq(
) )
elif _is_activation_post_process_node(arg, named_modules): elif _is_activation_post_process_node(arg, named_modules):
observed_arg = arg.args[0] observed_arg = arg.args[0]
assert isinstance(observed_arg, Node), ( if not isinstance(observed_arg, Node):
"Currently we only support observing Node" raise AssertionError("Currently we only support observing Node")
)
if "quantization_annotation" in observed_arg.meta: if "quantization_annotation" in observed_arg.meta:
output_act_obs_or_fq = _create_obs_or_fq_from_qspec( output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
observed_arg.meta["quantization_annotation"].output_qspec, observed_arg.meta["quantization_annotation"].output_qspec,
@@ -708,7 +716,10 @@ def _get_output_act_obs_or_fq(
is_qat, is_qat,
) )
else: else:
assert "target_dtype_info" in observed_arg.meta if "target_dtype_info" not in observed_arg.meta:
raise AssertionError(
"expected 'target_dtype_info' in observed_arg.meta"
)
output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][ output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][
"output_act_obs_or_fq_ctr" "output_act_obs_or_fq_ctr"
] ]
@@ -754,7 +765,8 @@ def _get_arg_as_input_act_obs_or_fq(
"""Get the observer or fake quant constructor for the Argument `arg`, as input """Get the observer or fake quant constructor for the Argument `arg`, as input
to Node `node` to Node `node`
""" """
assert isinstance(arg, Node) if not isinstance(arg, Node):
raise AssertionError("arg must be a Node")
# "input_qspec_map" is the more general design we'll use for pt2e path # "input_qspec_map" is the more general design we'll use for pt2e path
# it is a map from input argument node to observer or fake quant constructor, for example # it is a map from input argument node to observer or fake quant constructor, for example
# for the following graph: # for the following graph:
@@ -838,7 +850,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
if not isinstance(arg, Node): if not isinstance(arg, Node):
return arg return arg
assert isinstance(arg, Node) if not isinstance(arg, Node):
raise AssertionError("arg must be a Node")
# default (no observer) # default (no observer)
new_arg = arg new_arg = arg
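The isinstance guards in this file keep the type-narrowing behaviour that the original asserts gave static checkers: after either form, a checker such as mypy treats arg as a Node, because the raise makes the non-Node branch unreachable. A small sketch under that assumption, with a local stand-in class rather than torch.fx.Node:

# Sketch only; this local Node class is a stand-in, not torch.fx.Node.
class Node:
    def __init__(self, name: str) -> None:
        self.name = name

def old_style(arg: object) -> str:
    assert isinstance(arg, Node)  # narrows arg to Node for type checkers
    return arg.name

def new_style(arg: object) -> str:
    if not isinstance(arg, Node):  # same narrowing: the non-Node path always raises
        raise AssertionError("arg must be a Node")
    return arg.name

print(old_style(Node("x")), new_style(Node("y")))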
@@ -854,7 +867,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
"quantization_annotation" "quantization_annotation"
]._reuse_input_obs_or_fq ]._reuse_input_obs_or_fq
else: else:
assert "target_dtype_info" in node.meta if "target_dtype_info" not in node.meta:
raise AssertionError("expected 'target_dtype_info' in node.meta")
# TODO: we are assuming "target_dtype_info" exists here, maybe # TODO: we are assuming "target_dtype_info" exists here, maybe
# a default value also need to be provided here # a default value also need to be provided here
target_dtype_info = node.meta["target_dtype_info"] target_dtype_info = node.meta["target_dtype_info"]
@@ -889,7 +903,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
) )
else: else:
assert qconfig is not None if qconfig is None:
raise AssertionError("qconfig must not be None")
# custom flow for standalone modules # custom flow for standalone modules
_, _, sm_prepare_custom_config, _ = _get_standalone_module_configs( _, _, sm_prepare_custom_config, _ = _get_standalone_module_configs(
node, named_modules, prepare_custom_config, qconfig, backend_config node, named_modules, prepare_custom_config, qconfig, backend_config
@@ -946,7 +961,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
existing_obs_node = maybe_obs_node existing_obs_node = maybe_obs_node
break break
assert arg_as_input_act_obs_or_fq is not None if arg_as_input_act_obs_or_fq is None:
raise AssertionError("arg_as_input_act_obs_or_fq must not be None")
obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
if existing_obs_node is None: if existing_obs_node is None:
new_obs_node = _insert_obs_or_fq( new_obs_node = _insert_obs_or_fq(
@@ -1102,7 +1118,8 @@ def _maybe_insert_output_observer_for_node(
Note: inserting dynamic quantization ops for output is not supported in fx graph mode Note: inserting dynamic quantization ops for output is not supported in fx graph mode
quantization code path right now quantization code path right now
""" """
assert node.op != "output", "observer insertion for outputs is handled elsewhere" if node.op == "output":
raise AssertionError("observer insertion for outputs is handled elsewhere")
is_standalone_module = False is_standalone_module = False
if "quantization_annotation" in node.meta: if "quantization_annotation" in node.meta:
@@ -1110,7 +1127,8 @@ def _maybe_insert_output_observer_for_node(
node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
) )
else: else:
assert "target_dtype_info" in node.meta if "target_dtype_info" not in node.meta:
raise AssertionError("expected 'target_dtype_info' in node.meta")
is_standalone_module = node.meta["target_dtype_info"].get( is_standalone_module = node.meta["target_dtype_info"].get(
"_is_standalone_module", False "_is_standalone_module", False
) )
@@ -1222,7 +1240,10 @@ def _maybe_insert_observers_before_graph_output(
and arg_as_input_target_dtype != torch.float and arg_as_input_target_dtype != torch.float
) )
if need_obs: if need_obs:
assert observer_mod is not None if observer_mod is None:
raise AssertionError(
"observer_mod must not be None when need_obs is True"
)
# insert observer # insert observer
observer_node = _insert_obs_or_fq( observer_node = _insert_obs_or_fq(
maybe_node, observer_mod, model, named_modules, graph maybe_node, observer_mod, model, named_modules, graph
@@ -1393,9 +1414,11 @@ def _maybe_make_input_output_share_observers(
if iteration_guard > 10000: if iteration_guard > 10000:
raise AssertionError("Unable to find observer of previous node") raise AssertionError("Unable to find observer of previous node")
assert isinstance(first_arg_arg, Node) if not isinstance(first_arg_arg, Node):
raise AssertionError("first_arg_arg must be a Node")
target_to_use = first_arg_arg.target target_to_use = first_arg_arg.target
assert isinstance(target_to_use, str) if not isinstance(target_to_use, str):
raise AssertionError("target_to_use must be a string")
obs_mod_to_use = named_modules[target_to_use] obs_mod_to_use = named_modules[target_to_use]
if isinstance(first_arg, (list, tuple)): if isinstance(first_arg, (list, tuple)):
@@ -1418,7 +1441,10 @@ def _maybe_make_input_output_share_observers(
# set the output observer node to use that module # set the output observer node to use that module
for output_obs_node in node.users.keys(): for output_obs_node in node.users.keys():
assert _is_activation_post_process_node(output_obs_node, named_modules) if not _is_activation_post_process_node(output_obs_node, named_modules):
raise AssertionError(
"output_obs_node must be an activation post process node"
)
parent_name, name = _parent_name(output_obs_node.target) parent_name, name = _parent_name(output_obs_node.target)
setattr(named_modules[parent_name], name, obs_mod_to_use) setattr(named_modules[parent_name], name, obs_mod_to_use)
@@ -1431,7 +1457,10 @@ def _remove_output_observer(
): ):
items = list(node.users.items()) items = list(node.users.items())
for output_obs_node, _ in items: for output_obs_node, _ in items:
assert _is_activation_post_process_node(output_obs_node, named_modules) if not _is_activation_post_process_node(output_obs_node, named_modules):
raise AssertionError(
"output_obs_node must be an activation post process node"
)
output_obs_node.replace_all_uses_with(node) output_obs_node.replace_all_uses_with(node)
model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator] model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator]
@@ -1554,7 +1583,8 @@ def insert_observers_for_model(
qhandler, qhandler,
qconfig, qconfig,
) = match_res_with_qconfig ) = match_res_with_qconfig
assert qhandler is not None if qhandler is None:
raise AssertionError("qhandler must not be None")
_set_target_dtype_info_for_matched_node_pattern( _set_target_dtype_info_for_matched_node_pattern(
matched_node_pattern, matched_node_pattern,
last_node, last_node,
@@ -1632,7 +1662,8 @@ def insert_observers_for_model(
pattern, matched_node_pattern, qconfig, backend_config pattern, matched_node_pattern, qconfig, backend_config
) )
) )
assert qhandler is not None if qhandler is None:
raise AssertionError("qhandler must not be None")
# get output_act_dtype so that we don't also reset the special typed nodes # get output_act_dtype so that we don't also reset the special typed nodes
# TODO: we might want to handle these more uniformly with the default path # TODO: we might want to handle these more uniformly with the default path
@@ -1726,7 +1757,8 @@ def insert_observers_for_model(
if not skip_inserting_observers and is_supported_by_backend: if not skip_inserting_observers and is_supported_by_backend:
named_modules = dict(model.named_modules(remove_duplicate=False)) named_modules = dict(model.named_modules(remove_duplicate=False))
if node.op != "output": if node.op != "output":
assert matched_node_pattern is not None if matched_node_pattern is None:
raise AssertionError("matched_node_pattern must not be None")
# add matched nodes to the observed node name set # add matched nodes to the observed node name set
_add_matched_node_name_to_set( _add_matched_node_name_to_set(
matched_node_pattern, observed_node_names matched_node_pattern, observed_node_names
@@ -2064,8 +2096,10 @@ def prepare(
) )
backend_config = BackendConfig.from_dict(backend_config) backend_config = BackendConfig.from_dict(backend_config)
assert isinstance(qconfig_mapping, QConfigMapping) if not isinstance(qconfig_mapping, QConfigMapping):
assert isinstance(_equalization_config, QConfigMapping) raise AssertionError("qconfig_mapping must be a QConfigMapping")
if not isinstance(_equalization_config, QConfigMapping):
raise AssertionError("_equalization_config must be a QConfigMapping")
qconfig_mapping = copy.deepcopy(qconfig_mapping) qconfig_mapping = copy.deepcopy(qconfig_mapping)
_equalization_config = copy.deepcopy(_equalization_config) _equalization_config = copy.deepcopy(_equalization_config)
@@ -2194,11 +2228,12 @@ def prepare(
) )
if is_standalone_module: if is_standalone_module:
assert result_node is not None if result_node is None:
assert isinstance(result_node.args[0], Node), ( raise AssertionError("result_node must not be None for standalone modules")
"standalone module only supports returning simple value currently" if not isinstance(result_node.args[0], Node):
"(not tuple, dict etc.)" raise AssertionError(
) "standalone module only supports returning simple value currently (not tuple, dict etc.)"
)
# these inputs are observed in parent # these inputs are observed in parent
# converting List[int] to Tensor since module attribute is # converting List[int] to Tensor since module attribute is
# Union[Tensor, Module] # Union[Tensor, Module]
@@ -228,11 +228,12 @@ def _compare_prepare_convert_qconfig_mappings(
`prepare_qconfig_mapping`: configuration for prepare quantization step `prepare_qconfig_mapping`: configuration for prepare quantization step
`convert_qconfig_mapping`: configuration for convert quantization step `convert_qconfig_mapping`: configuration for convert quantization step
""" """
assert qconfig_equals( if not qconfig_equals(
prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig
), ( ):
"Expected global qconfigs to be the same in the prepare and convert quantization configs" raise AssertionError(
) "Expected global qconfigs to be the same in the prepare and convert quantization configs"
)
prepare_dicts: list[OrderedDict] = [ prepare_dicts: list[OrderedDict] = [
prepare_qconfig_mapping.object_type_qconfigs, prepare_qconfig_mapping.object_type_qconfigs,
prepare_qconfig_mapping.module_name_qconfigs, prepare_qconfig_mapping.module_name_qconfigs,
@@ -250,16 +251,17 @@ def _compare_prepare_convert_qconfig_mappings(
] ]
for i in range(len(prepare_dicts)): for i in range(len(prepare_dicts)):
for name in prepare_dicts[i].keys(): for name in prepare_dicts[i].keys():
assert name in convert_dicts[i], ( if name not in convert_dicts[i]:
f"Missing key {dict_names[i]} {name} in convert QConfigMapping \ raise AssertionError(
when it was present in prepare" f"Missing key {dict_names[i]} {name} in convert QConfigMapping when it was present in prepare"
) )
assert convert_dicts[i][name] is None or qconfig_equals( if convert_dicts[i][name] is not None and not qconfig_equals(
prepare_dicts[i][name], convert_dicts[i][name] prepare_dicts[i][name], convert_dicts[i][name]
), ( ):
f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \ raise AssertionError(
prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}" "Expected convert QConfigMapping to have the same qconfig as prepare for key "
) f"{dict_names[i]} {name}; prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"
)
def _is_qconfig_supported_by_dtype_configs( def _is_qconfig_supported_by_dtype_configs(
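The rewritten checks in _compare_prepare_convert_qconfig_mappings also tidy the messages themselves: the old asserts used a backslash continuation inside one long f-string, which can bake the continuation line's leading spaces into the message, while the new code relies on adjacent string literals being joined into a single string. A hedged sketch of that message-building style, with illustrative values rather than real qconfigs:

# Sketch only; compare_entry and its arguments are illustrative, not PyTorch APIs.
def compare_entry(dict_name, key, prepare_qc, convert_qc):
    if convert_qc is not None and prepare_qc != convert_qc:
        # Adjacent literals are concatenated by the parser, so the message stays on
        # one clean line with no stray indentation from a backslash continuation.
        raise AssertionError(
            "Expected convert QConfigMapping to have the same qconfig as prepare for key "
            f"{dict_name} {key}; prepare: {prepare_qc}; convert: {convert_qc}"
        )

compare_entry("object_type_qconfigs", "torch.nn.Linear", "qconfig_A", "qconfig_A")  # passes silently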
@@ -119,10 +119,11 @@ def _get_quantize_handler_cls(
): ):
super().__init__(node_pattern, modules, root_node_getter) super().__init__(node_pattern, modules, root_node_getter)
if num_tensor_args_to_observation_type: if num_tensor_args_to_observation_type:
assert self.num_tensor_args in num_tensor_args_to_observation_type, ( if self.num_tensor_args not in num_tensor_args_to_observation_type:
f"Must provide observation_type config for tensor number {self.num_tensor_args}" raise AssertionError(
f" in num_tensor_args_to_observation_type for {node_pattern}" f"Must provide observation_type config for tensor number {self.num_tensor_args}"
) f" in num_tensor_args_to_observation_type for {node_pattern}"
)
self.observation_type = num_tensor_args_to_observation_type[ self.observation_type = num_tensor_args_to_observation_type[
self.num_tensor_args self.num_tensor_args
] ]
@@ -165,7 +165,8 @@ def get_qconv_prepack_op(conv_op: Callable) -> Callable:
torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack, torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
} }
prepack_op = prepack_ops.get(conv_op) prepack_op = prepack_ops.get(conv_op)
assert prepack_op, f"Didn't find prepack op for {conv_op}" if prepack_op is None:
raise AssertionError(f"Didn't find prepack op for {conv_op}")
return prepack_op return prepack_op
@@ -230,7 +231,8 @@ def graph_module_from_producer_nodes(
Return: Return:
A graph module constructed from the producer nodes A graph module constructed from the producer nodes
""" """
assert len(producer_nodes) > 0, "list of producer nodes can not be empty" if len(producer_nodes) == 0:
raise AssertionError("list of producer nodes can not be empty")
# since we traced back from node to getattr # since we traced back from node to getattr
producer_nodes.reverse() producer_nodes.reverse()
graph = Graph() graph = Graph()
@@ -300,7 +302,8 @@ def all_node_args_have_no_tensors(
elif node.op == "placeholder": elif node.op == "placeholder":
result = False result = False
elif node.op == "call_module": elif node.op == "call_module":
assert isinstance(node.target, str) if not isinstance(node.target, str):
raise AssertionError("node.target must be a string for call_module nodes")
if _is_activation_post_process(modules[node.target]): if _is_activation_post_process(modules[node.target]):
result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type] result = all_node_args_have_no_tensors(node.args[0], modules, cache) # type: ignore[arg-type]
elif node.op == "call_module": elif node.op == "call_module":
@@ -503,9 +506,10 @@ def _is_custom_module_lstm(
""" """
mod = _get_module(node, named_modules) mod = _get_module(node, named_modules)
if qconfig is not None and qhandler is not None: if qconfig is not None and qhandler is not None:
assert isinstance( if not isinstance(
qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
) # type: ignore[attr-defined] ): # type: ignore[attr-defined]
raise AssertionError("qhandler must be a QuantizeHandler when provided")
return ( return (
isinstance(mod, torch.nn.LSTM) isinstance(mod, torch.nn.LSTM)
and activation_is_statically_quantized(qconfig) and activation_is_statically_quantized(qconfig)
@@ -527,9 +531,10 @@ def _is_custom_module_mha(
""" """
mod = _get_module(node, named_modules) mod = _get_module(node, named_modules)
if qconfig is not None and qhandler is not None: if qconfig is not None and qhandler is not None:
assert isinstance( if not isinstance(
qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
) # type: ignore[attr-defined] ): # type: ignore[attr-defined]
raise AssertionError("qhandler must be a QuantizeHandler when provided")
return ( return (
isinstance(mod, torch.nn.MultiheadAttention) isinstance(mod, torch.nn.MultiheadAttention)
and activation_is_statically_quantized(qconfig) and activation_is_statically_quantized(qconfig)
@@ -826,11 +831,17 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
for pattern in matched_patterns: for pattern in matched_patterns:
first_tuple = pattern[0] first_tuple = pattern[0]
last_getitem = pattern[-1] last_getitem = pattern[-1]
assert first_tuple.op == "call_function" and first_tuple.target is tuple if not (first_tuple.op == "call_function" and first_tuple.target is tuple):
assert ( raise AssertionError(
"first tuple node must be a call_function with target tuple"
)
if not (
last_getitem.op == "call_function" last_getitem.op == "call_function"
and last_getitem.target == operator.getitem and last_getitem.target == operator.getitem
) ):
raise AssertionError(
"last getitem node must be a call_function with target operator.getitem"
)
last_getitem_index = last_getitem.args[1] last_getitem_index = last_getitem.args[1]
new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index] new_input = first_tuple.args[0][last_getitem_index] # type: ignore[index]
for user in list(last_getitem.users.keys()): for user in list(last_getitem.users.keys()):
@@ -847,7 +858,10 @@ def _get_observer_from_activation_post_process(
if isinstance(activation_post_process, ObserverBase): if isinstance(activation_post_process, ObserverBase):
return activation_post_process return activation_post_process
else: else:
assert isinstance(activation_post_process, FakeQuantizeBase) if not isinstance(activation_post_process, FakeQuantizeBase):
raise AssertionError(
"activation_post_process must be an ObserverBase or FakeQuantizeBase"
)
return activation_post_process.activation_post_process # type: ignore[return-value] return activation_post_process.activation_post_process # type: ignore[return-value]
@@ -966,7 +980,10 @@ def _qconfig_satisfies_dtype_config_constraints(
satisfies_constraints = True satisfies_constraints = True
if activation_post_process_ctr is not None: if activation_post_process_ctr is not None:
activation_post_process = activation_post_process_ctr() activation_post_process = activation_post_process_ctr()
assert _is_activation_post_process(activation_post_process) if not _is_activation_post_process(activation_post_process):
raise AssertionError(
"activation_post_process must be an activation post process"
)
# If dtypes don't match, don't check the activation_post_process and return True early # If dtypes don't match, don't check the activation_post_process and return True early
if activation_post_process.dtype != dtype_with_constraints.dtype: if activation_post_process.dtype != dtype_with_constraints.dtype:
return True return True
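Throughout these files the replacement exception type is AssertionError rather than TypeError or ValueError, so callers and tests that already catch AssertionError see no behavioural change beyond the checks now surviving -O. A minimal caller-side sketch with a made-up helper, not a PyTorch API:

# Illustrative only; _lookup_prepack_op is a hypothetical helper, not a PyTorch API.
def _lookup_prepack_op(prepack_ops, conv_op):
    op = prepack_ops.get(conv_op)
    if op is None:
        raise AssertionError(f"Didn't find prepack op for {conv_op}")
    return op

try:
    _lookup_prepack_op({}, "conv2d")
except AssertionError as e:
    # Existing `except AssertionError` blocks written against the assert-based
    # code still match, because the exception type is unchanged.
    print(e)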