Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[Code Clean] Clean asserts in torch/ao/quantization/fx/* (#165420)
Replace assert statements with explicit if/raise patterns in:
- torch/ao/quantization/fx/* (177 errors)

Partially fixes #164878

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165420
Approved by: https://github.com/RohitRathore1, https://github.com/fffrog, https://github.com/albanD
This commit is contained in:
parent df71b70727
commit 7692fa09cd
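The change is one mechanical pattern applied across the package: every `assert cond, msg` becomes an explicit `if not cond: raise AssertionError(msg)`. A minimal sketch of the before/after shape (the `check_dtype_*` helpers are hypothetical names for illustration, not code from this commit):

import torch

# Before: `assert` statements are stripped when Python runs with -O,
# so the dtype check would silently disappear in optimized mode.
def check_dtype_before(input: torch.Tensor) -> None:
    assert input.dtype == torch.float32, (
        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
    )

# After: the if/raise pair always executes, regardless of interpreter flags.
def check_dtype_after(input: torch.Tensor) -> None:
    if input.dtype != torch.float32:
        raise AssertionError(
            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
        )

Raising `AssertionError` (rather than, say, `ValueError`) keeps the exception type identical to what the old asserts produced, so existing callers that catch `AssertionError` keep working; the practical difference is that the checks now survive `python -O`, which removes `assert` statements.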
@@ -29,12 +29,14 @@ def _quant_min_max_bounds_check(quant_min, quant_max, dtype):
         raise ValueError(f"Unsupported dtype: {dtype}")
     quant_min_lower_bound, quant_max_upper_bound = _DTYPE_TO_QVALUE_BOUNDS[dtype]

-    assert quant_min >= quant_min_lower_bound, (
-        "quant_min out of bound for dtype, "
-        f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
-    )
+    if quant_min < quant_min_lower_bound:
+        raise AssertionError(
+            "quant_min out of bound for dtype, "
+            f"quant_min_lower_bound: {quant_min_lower_bound} quant_min: {quant_min}"
+        )

-    assert quant_max <= quant_max_upper_bound, (
-        "quant_max out of bound for dtype, "
-        f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
-    )
+    if quant_max > quant_max_upper_bound:
+        raise AssertionError(
+            "quant_max out of bound for dtype, "
+            f"quant_max_upper_bound: {quant_max_upper_bound} quant_max: {quant_max}"
+        )
@@ -72,7 +74,8 @@ def quantize_per_tensor(
     """
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
@@ -94,7 +97,8 @@ def quantize_per_tensor_meta(
 ) -> torch.Tensor:
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
     return torch.empty_like(input, dtype=dtype)
@@ -122,10 +126,12 @@ def quantize_per_tensor_tensor(
     Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
     scalar values
     """
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
     return quantize_per_tensor(
@@ -149,13 +155,16 @@ def quantize_per_tensor_tensor_meta(
 ) -> torch.Tensor:
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
     return torch.empty_like(input, dtype=dtype)
@@ -184,10 +193,12 @@ def quantize_per_tensor_tensor2(
     Same as `quantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
     scalar values
     """
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
     return quantize_per_tensor(
@@ -266,7 +277,8 @@ def dequantize_per_tensor(
     Returns:
        dequantized float32 Tensor
     """
-    assert input.dtype == dtype, (
-        f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
-    )
+    if input.dtype != dtype:
+        raise AssertionError(
+            f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
+        )
     if out_dtype is None:
@@ -322,10 +334,12 @@ def dequantize_per_tensor_tensor(
     Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
     scalar values
     """
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
     return dequantize_per_tensor(
@@ -352,13 +366,18 @@ def dequantize_per_tensor_tensor_meta(
 ) -> torch.Tensor:
     if out_dtype is None:
         out_dtype = torch.float32
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
-    assert input.dtype == dtype, f"Expecting input to have dtype: {dtype}"
+    if input.dtype != dtype:
+        raise AssertionError(
+            f"Expecting input to have dtype: {dtype}, but got {input.dtype}"
+        )
     if dtype in _DTYPE_TO_QVALUE_BOUNDS:
         return torch.empty_like(input, dtype=out_dtype)
     else:
@@ -392,10 +411,12 @@ def dequantize_per_tensor_tensor2(
     Same as `dequantize_per_tensor` but scale and zero_point are Scalar Tensor instead of
     scalar values
     """
-    assert zero_point.numel() == 1, (
-        f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
-    )
+    if zero_point.numel() != 1:
+        raise AssertionError(
+            f"Expecting zero_point tensor to be one element, but received : {zero_point.numel()}"
+        )
-    assert scale.numel() == 1, (
-        f"Expecting scale tensor to be one element, but received : {scale.numel()}"
-    )
+    if scale.numel() != 1:
+        raise AssertionError(
+            f"Expecting scale tensor to be one element, but received : {scale.numel()}"
+        )
     return dequantize_per_tensor(
@@ -448,14 +469,16 @@ def choose_qparams_tensor(
        scale (float): quantization parameter for the target quantized Tensor
        zero_point (int): quantization parameter for the target quantized Tensor
     """
-    assert input.dtype in [
+    if input.dtype not in [
         torch.float32,
         torch.float16,
         torch.bfloat16,
-    ], (
-        f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
-    )
+    ]:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
+        )
-    assert dtype in _DTYPE_TO_QVALUE_BOUNDS, (
-        f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
-    )
+    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
+        raise AssertionError(
+            f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
+        )
     validate_qmin_qmax(qmin, qmax)
@@ -500,14 +523,16 @@ def choose_qparams_symmetric_tensor(
        scale (float): quantization parameter for the target quantized Tensor
        zero_point (int): quantization parameter for the target quantized Tensor
     """
-    assert input.dtype in [
+    if input.dtype not in [
         torch.float32,
         torch.float16,
         torch.bfloat16,
-    ], (
-        f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
-    )
+    ]:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
+        )
-    assert dtype in _DTYPE_TO_QVALUE_BOUNDS, (
-        f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
-    )
+    if dtype not in _DTYPE_TO_QVALUE_BOUNDS:
+        raise AssertionError(
+            f"Expecting target dtype to be one of {_DTYPE_TO_QVALUE_BOUNDS.keys()}, but got: {dtype}"
+        )
     validate_qmin_qmax(qmin, qmax)
@@ -529,16 +554,17 @@ def choose_qparams_symmetric_tensor(
 def choose_qparams_tensor_meta(
     input: torch.Tensor, quant_min: int, quant_max: int, eps: float, dtype: torch.dtype
 ) -> tuple[torch.Tensor, torch.Tensor]:
-    assert input.dtype in [
+    if input.dtype not in [
         torch.float32,
         torch.float16,
         torch.bfloat16,
-    ], (
-        f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
-    )
+    ]:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32/16/b16, but got dtype: {input.dtype}"
+        )
-    assert quant_min < quant_max, (
-        f"Expecting quant_min to be smaller than quant_max but received min: \
-        {quant_min} max: {quant_max}"
-    )
+    if quant_min >= quant_max:
+        raise AssertionError(
+            f"Expecting quant_min to be smaller than quant_max but received min: {quant_min} max: {quant_max}"
+        )
     return torch.empty(1, dtype=torch.double, device=input.device), torch.empty(
         1, dtype=torch.int64, device=input.device
@@ -598,10 +624,12 @@ def quantize_per_channel(
     """
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
-    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    if axis >= input.dim():
+        raise AssertionError(f"Expecting axis to be < {input.dim()}")
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     input, permute_axis_list = _permute_to_axis_zero(input, axis)

@@ -629,10 +657,12 @@ def quantize_per_channel_meta(
 ) -> torch.Tensor:
     if input.dtype in [torch.float16, torch.bfloat16]:
         input = input.to(torch.float32)
-    assert input.dtype == torch.float32, (
-        f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-    )
+    if input.dtype != torch.float32:
+        raise AssertionError(
+            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+        )
-    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    if axis >= input.dim():
+        raise AssertionError(f"Expecting axis to be < {input.dim()}")
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     return torch.empty_like(input, dtype=dtype)

@@ -687,12 +717,14 @@ def dequantize_per_channel(
     Returns:
        dequantized float32 Tensor
     """
-    assert input.dtype == dtype, (
-        f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
-    )
+    if input.dtype != dtype:
+        raise AssertionError(
+            f"Expecting input to have dtype: {dtype}, but got dtype: {input.dtype}"
+        )
     if out_dtype is None:
         out_dtype = torch.float32
-    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    if axis >= input.dim():
+        raise AssertionError(f"Expecting axis to be < {input.dim()}")
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     input, permute_axis_list = _permute_to_axis_zero(input, axis)

@@ -722,12 +754,14 @@ def dequantize_per_channel_meta(
     *,
     out_dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
-    assert input.dtype == dtype, (
-        f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
-    )
+    if input.dtype != dtype:
+        raise AssertionError(
+            f"Expecting input to have dtype {dtype}, but got dtype: {input.dtype}"
+        )
     if out_dtype is None:
         out_dtype = torch.float32
-    assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+    if axis >= input.dim():
+        raise AssertionError(f"Expecting axis to be < {input.dim()}")
     _quant_min_max_bounds_check(quant_min, quant_max, dtype)
     return torch.empty_like(input, dtype=out_dtype)

@@ -879,10 +913,10 @@ def choose_qparams_per_token_asymmetric_meta(

 def _per_token_quant_qparam_dim_check(input, scales, zero_points):
     num_tokens = math.prod(list(input.size())[:-1])
-    assert num_tokens == scales.numel(), (
-        f"num_tokens: {num_tokens} scales: {scales.size()}"
-    )
-    assert num_tokens == zero_points.numel(), (
-        f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
-    )
+    if num_tokens != scales.numel():
+        raise AssertionError(f"num_tokens: {num_tokens} scales: {scales.size()}")
+    if num_tokens != zero_points.numel():
+        raise AssertionError(
+            f"num_tokens: {num_tokens} zero_points: {zero_points.size()}"
+        )

@@ -1019,17 +1053,21 @@ def quantize_per_channel_group(
     dtype: torch.dtype,
     group_size=128,
 ):
-    assert group_size > 1
+    if group_size <= 1:
+        raise AssertionError("group_size must be > 1")
     # needed for GPTQ single column quantize
     if group_size > input.shape[-1] and scales.shape[-1] == 1:
         group_size = input.shape[-1]

-    assert input.shape[-1] % group_size == 0
-    assert input.dim() == 2
+    if input.shape[-1] % group_size != 0:
+        raise AssertionError("input.shape[-1] must be divisible by group_size")
+    if input.dim() != 2:
+        raise AssertionError("input must be 2-dimensional")

     # TODO: check for dtype, currently we can't express torch.int4 so it's omitted
     to_quant = input.reshape(-1, group_size)
-    assert torch.isnan(to_quant).sum() == 0
+    if torch.isnan(to_quant).sum() != 0:
+        raise AssertionError("to_quant must not contain NaNs")

     scales = scales.reshape(-1, 1)
     zero_points = zero_points.reshape(-1, 1)
@@ -1074,13 +1112,16 @@ def quantize_per_channel_group_meta(
        Tensor with requested dtype (e.g. torch.uint8), note the quantization parameters
        are not stored in the Tensor, we are storing them in function arguments instead
     """
-    assert group_size > 1
+    if group_size <= 1:
+        raise AssertionError("group_size must be > 1")
     # needed for GPTQ single column quantize
     if group_size > input.shape[-1] and scales.shape[-1] == 1:
         group_size = input.shape[-1]

-    assert input.shape[-1] % group_size == 0
-    assert input.dim() == 2
+    if input.shape[-1] % group_size != 0:
+        raise AssertionError("input.shape[-1] must be divisible by group_size")
+    if input.dim() != 2:
+        raise AssertionError("input must be 2-dimensional")
     return torch.empty_like(input, dtype=dtype)


@@ -1124,12 +1165,15 @@ def dequantize_per_channel_group(
        dequantized Tensor with dtype `output_dtype`
     """

-    assert group_size > 1
+    if group_size <= 1:
+        raise AssertionError("group_size must be > 1")
     # needed for GPTQ single column dequantize
     if group_size > w_int8.shape[-1] and scales.shape[-1] == 1:
         group_size = w_int8.shape[-1]
-    assert w_int8.shape[-1] % group_size == 0
-    assert w_int8.dim() == 2
+    if w_int8.shape[-1] % group_size != 0:
+        raise AssertionError("w_int8.shape[-1] must be divisible by group_size")
+    if w_int8.dim() != 2:
+        raise AssertionError("w_int8 must be 2-dimensional")

     w_int8_grouped = w_int8.reshape(-1, group_size)
     scales = scales.reshape(-1, 1)
@@ -1155,10 +1199,12 @@ class FakeQuantPerChannel(torch.autograd.Function):
             scales = scales.to(torch.float32)
         if zero_points.dtype != torch.int32:
             zero_points = zero_points.to(torch.int32)
-        assert input.dtype == torch.float32, (
-            f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
-        )
+        if input.dtype != torch.float32:
+            raise AssertionError(
+                f"Expecting input to have dtype torch.float32, but got dtype: {input.dtype}"
+            )
-        assert axis < input.dim(), f"Expecting axis to be < {input.dim()}"
+        if axis >= input.dim():
+            raise AssertionError(f"Expecting axis to be < {input.dim()}")
         broadcast_dims = list(range(axis)) + list(range(axis + 1, input.ndim))
         unsqueeze_scales = _unsqueeze_multiple(scales, broadcast_dims)
         unsqueeze_zero_points = _unsqueeze_multiple(zero_points, broadcast_dims)
@@ -90,7 +90,7 @@ class _InputEqualizationObserver(nn.Module):
         self.equalization_shape: list[int] = []

     def forward(self, x_orig):
-        if not (x_orig.ndim >= 2 and x_orig.ndim <= 5):
+        if x_orig.ndim < 2 or x_orig.ndim > 5:
             raise ValueError(
                 "InputEqualizationObserver only supports Linear and Conv layers"
             )
@@ -191,7 +191,7 @@ class _WeightEqualizationObserver(nn.Module):
         self.equalization_scale = torch.tensor(1)

     def forward(self, w_orig):
-        if not (w_orig.ndim >= 2 and w_orig.ndim <= 5):
+        if w_orig.ndim < 2 or w_orig.ndim > 5:
             raise ValueError(
                 "InputEqualizationObserver only supports Linear and Conv layers"
             )
@@ -232,7 +232,7 @@ def calculate_equalization_scale(
         )
         return torch.tensor(1)

-    if not (min_inputs.shape == min_weights.shape):
+    if min_inputs.shape != min_weights.shape:
         raise ValueError(
             "Input and Weight must have the same column dimension. "
             + f"Found {min_inputs.shape} and {min_weights.shape} shapes instead."
@@ -355,30 +355,45 @@ def get_op_node_and_weight_eq_obs(
             op_node = user
             break

-    assert op_node is not None
+    if op_node is None:
+        raise AssertionError(
+            "Expected an operation node after the input equalization observer"
+        )
     if op_node.op == "call_module":
         # If the op_node is a nn.Linear layer, then it must have a
         # WeightEqualizationObserver configuration
         maybe_equalization_node_name_to_config = _get_observed_graph_module_attr(
             model, "equalization_node_name_to_qconfig"
         )
-        assert maybe_equalization_node_name_to_config is not None
+        if maybe_equalization_node_name_to_config is None:
+            raise AssertionError(
+                "Expected 'equalization_node_name_to_qconfig' attribute in observed graph module"
+            )
         equalization_node_name_to_qconfig: dict[str, Any] = (
             maybe_equalization_node_name_to_config  # type: ignore[assignment]
         )
-        assert equalization_node_name_to_qconfig.get(op_node.name, None) is not None
+        if equalization_node_name_to_qconfig.get(op_node.name, None) is None:
+            raise AssertionError(
+                f"No equalization qconfig found for op node {op_node.name}"
+            )
         weight_eq_obs = equalization_node_name_to_qconfig.get(  # type: ignore[union-attr]
             op_node.name, None
         ).weight()

-        assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+        if not isinstance(weight_eq_obs, _WeightEqualizationObserver):
+            raise AssertionError(
+                "Expected weight equalization observer to be a _WeightEqualizationObserver"
+            )
         return op_node, weight_eq_obs

     elif op_node.op == "call_function":
         weight_node = maybe_get_weight_eq_obs_node(op_node, modules)
         if weight_node is not None:
             weight_eq_obs = modules[str(weight_node.target)]
-            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+            if not isinstance(weight_eq_obs, _WeightEqualizationObserver):
+                raise AssertionError(
+                    "Expected weight equalization observer to be a _WeightEqualizationObserver"
+                )
             return op_node, weight_eq_obs

     return None, None
@@ -388,16 +403,19 @@ def maybe_get_weight_eq_obs_node(
     op_node: Node, modules: dict[str, nn.Module]
 ) -> Optional[Node]:
     """Gets the weight equalization observer node if it exists."""
-    assert op_node.op == "call_function"
+    if op_node.op != "call_function":
+        raise AssertionError(
+            "maybe_get_weight_eq_obs_node expects a call_function op_node"
+        )
     for node_arg in op_node.args:
         if node_arg_is_weight(op_node, node_arg):
-            assert (
+            if (
                 isinstance(node_arg, Node)
                 and node_arg.op == "call_module"
                 and isinstance(
                     modules[str(node_arg.target)], _WeightEqualizationObserver
                 )
-            )
-            return node_arg
+            ):
+                return node_arg
     return None

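Note that the hunk above is not a pure assert-to-raise rewrite: the failing condition no longer raises at all. The old code asserted that a weight argument's module is a _WeightEqualizationObserver and then returned the argument unconditionally; the new code returns it only when the check passes and otherwise falls through to `return None`. A generic sketch of that behavioral shift (the `find_first_positive_*` functions are hypothetical illustrations, not code from this commit):

from typing import Optional

def find_first_positive_old(items: list[int]) -> Optional[int]:
    for x in items:
        assert x > 0  # a non-positive first match crashed the search
        return x
    return None

def find_first_positive_new(items: list[int]) -> Optional[int]:
    for x in items:
        if x > 0:  # a non-positive match is now skipped instead of raising
            return x
    return None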
@@ -422,7 +440,8 @@ def maybe_get_next_input_eq_obs(
     the following equalization observer for linear2.
     """

-    assert node_supports_equalization(node, modules)
+    if not node_supports_equalization(node, modules):
+        raise AssertionError("Node does not support equalization")

     # Locate the following nn.ReLU or F.relu node if it exists
     maybe_relu_node = maybe_get_next_module(node, modules, nn.ReLU)
@@ -448,7 +467,10 @@ def maybe_get_next_input_eq_obs(
         return None

     maybe_eq_obs = modules[str(maybe_eq_obs_node)]
-    assert isinstance(maybe_eq_obs, _InputEqualizationObserver)
+    if not isinstance(maybe_eq_obs, _InputEqualizationObserver):
+        raise AssertionError(
+            "Expected the following equalization observer to be an _InputEqualizationObserver"
+        )
     return maybe_eq_obs


@@ -480,10 +502,16 @@ def scale_input_observer(node: Node, modules: dict[str, nn.Module]) -> None:
     equalization observer
     """
     input_eq_obs = modules[str(node.target)]
-    assert isinstance(input_eq_obs, _InputEqualizationObserver)
+    if not isinstance(input_eq_obs, _InputEqualizationObserver):
+        raise AssertionError(
+            "Expected the module at node.target to be an _InputEqualizationObserver"
+        )

     input_quant_obs_node = node.args[0]
-    assert isinstance(input_quant_obs_node, Node)
+    if not isinstance(input_quant_obs_node, Node):
+        raise AssertionError(
+            "Expected the input quantization observer node to be a Node"
+        )

     input_quant_obs = modules[str(input_quant_obs_node.target)]
     if not isinstance(input_quant_obs, ObserverBase):
@@ -518,14 +546,19 @@ def scale_weight_node(
         op_module = modules[str(node.target)][0]  # type: ignore[index]
     else:
         op_module = modules[str(node.target)]
-    assert nn_module_supports_equalization(
-        op_module
-    ) or custom_module_supports_equalization(op_module)
+    if not (
+        nn_module_supports_equalization(op_module)
+        or custom_module_supports_equalization(op_module)
+    ):
+        raise AssertionError(
+            "Expected operation module to support equalization (nn or custom)"
+        )

     # Scale the weights for input-weight equalization
     # If the following layer needs to be equalized then we will multiply its scale
     weight = op_module.weight
-    assert isinstance(weight, torch.Tensor)
+    if not isinstance(weight, torch.Tensor):
+        raise AssertionError("Expected op_module.weight to be a torch.Tensor")

     # Scale the weights by the reciprocal of the equalization scale
     # Reshape the equalization scale so that we can multiply it to the weight along axis=1
@@ -547,7 +580,8 @@ def scale_weight_node(
     bias = op_module.bias
     if bias is None:
         return
-    assert isinstance(bias, torch.Tensor)
+    if not isinstance(bias, torch.Tensor):
+        raise AssertionError("Expected op_module.bias to be a torch.Tensor")

     # Reshape the equalization scale so that we can multiply it element-wise to the bias
     next_equalization_scale_reshaped = reshape_scale(next_equalization_scale, 0, bias)
@@ -581,15 +615,20 @@ def scale_weight_functional(
     weight_quant_obs_node = weight_eq_obs_node.args[0]
     if weight_quant_obs_node is None:
         return
-    assert isinstance(weight_quant_obs_node, Node) and isinstance(
-        modules[str(weight_quant_obs_node.target)], ObserverBase
-    )
+    if not (
+        isinstance(weight_quant_obs_node, Node)
+        and isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
+    ):
+        raise AssertionError(
+            "Expected weight_quant_obs_node to be a Node whose module is an ObserverBase"
+        )

     # Get the get_attr(weight) node
     weight_node = weight_quant_obs_node.args[0]
     if weight_node is None:
         return
-    assert isinstance(weight_node, Node) and weight_node.op == "get_attr"
+    if not (isinstance(weight_node, Node) and weight_node.op == "get_attr"):
+        raise AssertionError("Expected weight node to be a 'get_attr' Node")

     weight_parent_name, weight_name = _parent_name(weight_node.target)
     weight = getattr(modules[weight_parent_name], weight_name)
@@ -612,7 +651,8 @@ def scale_weight_functional(
     scaled_weight = torch.mul(scaled_weight, next_equalization_scale_reshaped)

     setattr(modules[weight_parent_name], weight_name, scaled_weight)
-    assert torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight)
+    if not torch.allclose(model.get_buffer(str(weight_node.target)), scaled_weight):
+        raise AssertionError("Model buffer for weight does not match the scaled weight")

     # Multiply the bias element wise by the next equalization scale
     bias_node = None
@@ -644,10 +684,14 @@ def clear_weight_quant_obs_node(op_node: Node, modules: dict[str, nn.Module]) ->
     weight_quant_obs_node = weight_eq_obs_node.args[0]
     if weight_quant_obs_node is None:
         return
-    assert isinstance(weight_quant_obs_node, Node)
+    if not isinstance(weight_quant_obs_node, Node):
+        raise AssertionError("Expected weight_quant_obs_node to be a Node")

     weight_quant_obs = modules[str(weight_quant_obs_node.target)]
-    assert isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase)
+    if not isinstance(modules[str(weight_quant_obs_node.target)], ObserverBase):
+        raise AssertionError(
+            "Expected the module at weight_quant_obs_node to be an ObserverBase"
+        )
     weight_quant_obs.reset_min_max_vals()  # type: ignore[operator]


@@ -682,7 +726,10 @@ def update_obs_for_equalization(
             modules[node.target], _InputEqualizationObserver
         ):
             input_eq_obs = modules[node.target]
-            assert isinstance(input_eq_obs, _InputEqualizationObserver)
+            if not isinstance(input_eq_obs, _InputEqualizationObserver):
+                raise AssertionError(
+                    "Expected module at node.target to be an _InputEqualizationObserver"
+                )
             op_node, weight_eq_obs = get_op_node_and_weight_eq_obs(node, model, modules)

             if op_node is None or weight_eq_obs is None:
@@ -693,7 +740,10 @@ def update_obs_for_equalization(
             # been created
             if fused_module_supports_equalization(modules[str(op_node.target)]):
                 module = modules[str(op_node.target)][0]  # type: ignore[index]
-                assert nn_module_supports_equalization(module)
+                if not nn_module_supports_equalization(module):
+                    raise AssertionError(
+                        "Expected fused module to support equalization"
+                    )
                 weight_eq_obs(module.weight)
             else:
                 weight_eq_obs(modules[str(op_node.target)].weight)
@@ -810,7 +860,10 @@ def convert_eq_obs(

         elif weight_eq_obs_dict.get(node.name, None) is not None:
             weight_eq_obs = weight_eq_obs_dict.get(node.name)
-            assert isinstance(weight_eq_obs, _WeightEqualizationObserver)
+            if not isinstance(weight_eq_obs, _WeightEqualizationObserver):
+                raise AssertionError(
+                    "Expected weight equalization observer to be a _WeightEqualizationObserver"
+                )
             equalization_scale = weight_eq_obs.equalization_scale

             if (
@@ -844,8 +897,11 @@ def convert_eq_obs(
     weight_eq_obs_node = maybe_get_weight_eq_obs_node(node, modules)
     if weight_eq_obs_node is None:
         return
-    assert isinstance(
+    if not isinstance(
         modules[str(weight_eq_obs_node.target)], _WeightEqualizationObserver
-    )
+    ):
+        raise AssertionError(
+            "Expected weight equalization observer to be a _WeightEqualizationObserver"
+        )

     # Clear the quantization observer's min/max values so that they
@@ -585,7 +585,8 @@ def _match_static_pattern(
         return SKIP_LOWERING_VALUE
     q_node = node
     ref_node = q_node.args[0]
-    assert isinstance(ref_node, Node)
+    if not isinstance(ref_node, Node):
+        raise AssertionError("Expected the reference node to be a torch.fx Node")

     # Handle cases where the node is wrapped in a ReLU
     if (ref_node.op == "call_function" and ref_node.target in (F.relu, torch.relu)) or (
@@ -593,7 +594,10 @@ def _match_static_pattern(
     ):
         relu_node = ref_node
         ref_node = relu_node.args[0]
-        assert isinstance(ref_node, Node)
+        if not isinstance(ref_node, Node):
+            raise AssertionError(
+                "Expected the reference node after ReLU to be a torch.fx Node"
+            )
     else:
         relu_node = None
     if should_skip_lowering(ref_node, qconfig_map):
|
@ -616,7 +620,8 @@ def _match_static_pattern(
|
|||
# (2) There must be at least one dequantize node
|
||||
matched_dequantize = False
|
||||
for i in dequantize_node_arg_indices:
|
||||
assert i < len(ref_node.args), (
|
||||
if i >= len(ref_node.args):
|
||||
raise AssertionError(
|
||||
f"Dequantize index {i} exceeded reference node's arg length {len(ref_node.args)}"
|
||||
)
|
||||
arg = ref_node.args[i]
|
||||
|
|
@@ -660,7 +665,8 @@ def _match_static_pattern_with_two_inputs(
         return SKIP_LOWERING_VALUE
     q_node = node
     ref_node = q_node.args[0]
-    assert isinstance(ref_node, Node)
+    if not isinstance(ref_node, Node):
+        raise AssertionError("Expected the reference node to be a torch.fx Node")

     if should_skip_lowering(ref_node, qconfig_map):
         return SKIP_LOWERING_VALUE
@@ -711,13 +717,21 @@ def _lower_static_weighted_ref_module(
         )
         if q_node is None:
             continue
-        assert ref_node is not None
+        if ref_node is None:
+            raise AssertionError(
+                "Expected a reference node when matching static pattern"
+            )
         (_, scale_node, zero_point_node, _) = q_node.args
         ref_module = _get_module(ref_node, modules)
         ref_class = type(ref_module)
-        assert isinstance(scale_node, Node)
-        assert isinstance(zero_point_node, Node)
-        assert issubclass(ref_class, nn.Module)
+        if not isinstance(scale_node, Node):
+            raise AssertionError("Expected scale_node to be a Node")
+        if not isinstance(zero_point_node, Node):
+            raise AssertionError("Expected zero_point_node to be a Node")
+        if not issubclass(ref_class, nn.Module):
+            raise AssertionError(
+                "Expected reference module class to be a subclass of nn.Module"
+            )

         # Step 1: Change this pattern to use the corresponding quantized module
         # For fused modules, we also check whether the inner module is a reference module
@@ -736,9 +750,11 @@ def _lower_static_weighted_ref_module(
         setattr(modules[parent_name], module_name, q_module)

         # Step 2: Reroute around dq_node, and remove q_node and its args
-        assert len(ref_node.args) == 1
+        if len(ref_node.args) != 1:
+            raise AssertionError("Expected reference node to have exactly 1 arg")
         dq_node = ref_node.args[0]
-        assert isinstance(dq_node, Node)
+        if not isinstance(dq_node, Node):
+            raise AssertionError("Expected dq_node to be a Node")
         ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]
         q_node.replace_all_uses_with(ref_node)
         model.graph.erase_node(q_node)
@@ -771,13 +787,21 @@ def _lower_static_weighted_ref_module_with_two_inputs(
         )
         if q_node is None:
             continue
-        assert ref_node is not None
+        if ref_node is None:
+            raise AssertionError(
+                "Expected a reference node when matching static pattern with two inputs"
+            )
         (_, scale_node, zero_point_node, _) = q_node.args
         ref_module = _get_module(ref_node, modules)
         ref_class = type(ref_module)
-        assert isinstance(scale_node, Node)
-        assert isinstance(zero_point_node, Node)
-        assert issubclass(ref_class, nn.Module)
+        if not isinstance(scale_node, Node):
+            raise AssertionError("Expected scale_node to be a Node")
+        if not isinstance(zero_point_node, Node):
+            raise AssertionError("Expected zero_point_node to be a Node")
+        if not issubclass(ref_class, nn.Module):
+            raise AssertionError(
+                "Expected reference module class to be a subclass of nn.Module"
+            )

         # Step 1: Change this pattern to use the corresponding quantized module
         # For fused modules, we also check whether the inner module is a reference module
@@ -798,12 +822,14 @@ def _lower_static_weighted_ref_module_with_two_inputs(
         setattr(modules[parent_name], module_name, q_module)

         # Step 2: Reroute around dq_node, and remove q_node and its args
-        assert len(ref_node.args) == 2
+        if len(ref_node.args) != 2:
+            raise AssertionError("Expected reference node to have exactly 2 args")
         for arg in ref_node.args:
             if not is_dequantize_node(arg):
                 continue
             dq_node = arg
-            assert isinstance(dq_node, Node)
+            if not isinstance(dq_node, Node):
+                raise AssertionError("Expected dq_node to be a Node")
             ref_node.replace_input_with(dq_node, dq_node.args[0])  # type: ignore[arg-type]

         q_node.replace_all_uses_with(ref_node)
@@ -900,14 +926,21 @@ def _lower_static_weighted_ref_functional(
         )
         if q_node is None:
             continue
-        assert func_node is not None
+        if func_node is None:
+            raise AssertionError(
+                "Expected a function node when matching static functional pattern"
+            )
         (_, output_scale_node, output_zp_node, _) = q_node.args
         (input_dq_node, weight_dq_node, *remaining_func_args) = func_node.args
-        assert isinstance(output_zp_node, Node)
-        assert isinstance(input_dq_node, Node)
-        assert isinstance(weight_dq_node, Node)
+        if not isinstance(output_zp_node, Node):
+            raise AssertionError("Expected output_zp_node to be a Node")
+        if not isinstance(input_dq_node, Node):
+            raise AssertionError("Expected input_dq_node to be a Node")
+        if not isinstance(weight_dq_node, Node):
+            raise AssertionError("Expected weight_dq_node to be a Node")
         quantized_weight = weight_dq_node.args[0]
-        assert isinstance(quantized_weight, Node)
+        if not isinstance(quantized_weight, Node):
+            raise AssertionError("Expected quantized_weight to be a Node")
         if quantized_weight.op != "call_function" or quantized_weight.target not in (
             torch.quantize_per_tensor,
             torch.quantize_per_channel,
@@ -1135,7 +1168,10 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfi
         )
         if q_node is None:
             continue
-        assert bop_node is not None
+        if bop_node is None:
+            raise AssertionError(
+                "Expected a binary op node when matching quantized binary op pattern"
+            )
         (_, scale_node, zero_point_node, _) = q_node.args

         # Step 1: Remove dequant nodes
@@ -1144,14 +1180,21 @@ def _lower_quantized_binary_op(model: GraphModule, qconfig_map: dict[str, QConfi
             if not is_dequantize_node(arg):
                 continue
             dq_node = arg
-            assert isinstance(dq_node, Node)
+            if not isinstance(dq_node, Node):
+                raise AssertionError("Expected dq_node to be a Node")
             dn_input = dq_node.args[0]
             bop_node.replace_input_with(dq_node, dn_input)  # type: ignore[arg-type]
             num_dq_nodes += 1
-        assert num_dq_nodes > 0
+        if num_dq_nodes <= 0:
+            raise AssertionError(
+                "Expected at least one dequantize node in binary op args"
+            )

         # Step 2: Swap binary op to quantized binary op
-        assert bop_node.target in QBIN_OP_MAPPING
+        if bop_node.target not in QBIN_OP_MAPPING:
+            raise AssertionError(
+                f"Unsupported binary op {bop_node.target} for lowering"
+            )
         binop_to_qbinop = QBIN_OP_MAPPING if relu_node is None else QBIN_RELU_OP_MAPPING
         qbin_op = binop_to_qbinop[bop_node.target]
         # prepare the args for quantized binary op
@@ -1188,7 +1231,8 @@ def special_pattern_replacement(model: GraphModule):
             and len(q_node.args) == 2
             and q_node.args[1] == torch.float16
         )
-        if not (is_quantize or is_to_fp16):
+        # Only continue when neither quantize nor to_fp16
+        if not is_quantize and not is_to_fp16:
             continue
         ref_node = q_node.args[0]
         # get output scale/zero_point/dtype from the quantize node
@@ -1217,13 +1261,17 @@ def special_pattern_replacement(model: GraphModule):
         )
         if not (is_call_module or is_call_function or is_call_method):
             continue
-        assert len(ref_node.args) > 0 or len(ref_node.kwargs) > 0
+        if len(ref_node.args) <= 0 and len(ref_node.kwargs) <= 0:
+            raise AssertionError("Expected ref_node to have args or kwargs")
         dq_node_or_nodes = (
             ref_node.args[0]
             if len(ref_node.args) > 0
             else next(iter(ref_node.kwargs.values()))
         )
-        assert isinstance(dq_node_or_nodes, (Node, tuple, list))
+        if not isinstance(dq_node_or_nodes, (Node, tuple, list)):
+            raise AssertionError(
+                "Expected dq_node_or_nodes to be a Node, tuple, or list"
+            )
         is_dequantize = False
         if isinstance(dq_node_or_nodes, Node):
             is_dequantize = (
@@ -362,11 +362,15 @@ class PerChannelDetector(DetectorBase):

             # assert statement for MyPy
             q_config_file = module.qconfig
-            assert isinstance(q_config_file, QConfig)
+            if not isinstance(q_config_file, QConfig):
+                raise AssertionError("module.qconfig must be a QConfig")

             # this object should either be fake quant or observer
             q_or_s_obj = module.qconfig.weight.p.func()
-            assert isinstance(q_or_s_obj, (FakeQuantize, ObserverBase))
+            if not isinstance(q_or_s_obj, (FakeQuantize, ObserverBase)):
+                raise AssertionError(
+                    "module.qconfig.weight must be a FakeQuantize or ObserverBase"
+                )

             per_channel_used = False  # will be true if found in qconfig

@@ -1160,7 +1164,8 @@ class InputWeightEqualizationDetector(DetectorBase):
         input_channels = len(input_ratio)
         if weight_channels != input_channels:
             # we try to replicate
-            assert input_channels % weight_channels == 0, (
-                "input channels should be divisible by weight channels."
-            )
+            if input_channels % weight_channels != 0:
+                raise AssertionError(
+                    "input channels should be divisible by weight channels."
+                )
             # get replication factor
@@ -1418,10 +1423,14 @@ class OutlierDetector(DetectorBase):
         self.ratio_threshold = ratio_threshold

         # make sure passed in percentile is valid
-        assert reference_percentile >= 0 and reference_percentile <= 1
-        assert (
+        if reference_percentile < 0 or reference_percentile > 1:
+            raise AssertionError("reference_percentile must be between 0 and 1")
+        if not (
             fraction_batches_used_threshold >= 0
             and fraction_batches_used_threshold <= 1
-        )
+        ):
+            raise AssertionError(
+                "fraction_batches_used_threshold must be between 0 and 1"
+            )
         self.reference_percentile = reference_percentile
         self.fraction_batches_used_threshold = fraction_batches_used_threshold
@@ -261,7 +261,8 @@ class ModelReport:
             raise ValueError("The node_fqn is was not found within the module.")

         # assert for MyPy
-        assert isinstance(node_to_return, torch.fx.node.Node)
+        if not isinstance(node_to_return, torch.fx.node.Node):
+            raise AssertionError("node_to_return must be a torch.fx.node.Node")

         return node_to_return

@@ -112,8 +112,12 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
     or quantize_per_channel and dequantize_per_channel
     """
     graph = model.graph
-    assert modules is not None
-    assert isinstance(node.target, str)
+    if modules is None:
+        raise AssertionError("modules must not be None")
+    if not isinstance(node.target, str):
+        raise AssertionError(
+            f"Expected node.target to be a str, but got {type(node.target)}"
+        )
     module_path, prefix = _get_module_path_and_prefix(
         node, node_name_to_scope, node_name_to_qconfig
     )
@@ -260,9 +264,9 @@ def _replace_observer_with_quantize_dequantize_node_decomposed(
         # and that can be done after we remove reduce_range flag
         # 1. extract qparams from activation_post_process module
         dtype_ = to_underlying_dtype(dtype)
-        assert dtype_ in [torch.uint8, torch.int8], (
-            "only uint8 and int8 are supported in reference flow for "
-            "dynamic quantization right now"
-        )
+        if dtype_ not in [torch.uint8, torch.int8]:
+            raise AssertionError(
+                "only uint8 and int8 are supported in reference flow for dynamic quantization right now"
+            )
         quant_min = activation_post_process.quant_min  # type: ignore[attr-defined]
         quant_max = activation_post_process.quant_max  # type: ignore[attr-defined]
@@ -379,8 +383,12 @@ def _replace_observer_with_quantize_dequantize_node(
     After:
     ... -> torch.quantize_per_tensor(x, ...) -> x.dequantize() -> ...
     """
-    assert modules is not None
-    assert isinstance(node.target, str)
+    if modules is None:
+        raise AssertionError("modules must not be None")
+    if not isinstance(node.target, str):
+        raise AssertionError(
+            f"Expected node.target to be a str, but got {type(node.target)}"
+        )
     graph = model.graph
     module_path, prefix = _get_module_path_and_prefix(
         node, node_name_to_scope, node_name_to_qconfig
@@ -521,7 +529,8 @@ def _replace_observer_or_dequant_stub_with_dequantize_node(
     node: Node, graph: Graph
 ) -> None:
     call_custom_module_node = node.args[0]
-    assert isinstance(call_custom_module_node, Node), (
-        f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
-    )
+    if not isinstance(call_custom_module_node, Node):
+        raise AssertionError(
+            f"Expecting the for call custom module node to be a Node, but got {call_custom_module_node}"
+        )
     node.replace_all_uses_with(call_custom_module_node)
@@ -617,7 +626,8 @@ def _get_module_path_and_prefix(
     # operator (they can be the same)
     # this flag identifies if the observer is inserted only because the observed node is
     # the input of the next operator
-    assert isinstance(observed_node, Node), (
-        f"Expecting observed node to be a Node, but got {observed_node}"
-    )
+    if not isinstance(observed_node, Node):
+        raise AssertionError(
+            f"Expecting observed node to be a Node, but got {observed_node}"
+        )
     is_input_observer_only = (
@@ -727,8 +737,10 @@ def convert_standalone_module(
         "_observed_graph_module_attrs"
     ].standalone_module_output_quantized_idxs
     if len(sm_output_quantized_idxs) > 0:
-        assert sm_output_quantized_idxs[0] == 0, "Currently only quantized"
-        "output idxs = [0] is supported"
+        if sm_output_quantized_idxs[0] != 0:
+            raise AssertionError(
+                "Currently only quantized output idxs = [0] is supported"
+            )

         # if it's non-empty, then it means the output is kept in quantized form
         # we'll just add a dequantize node after this node
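The hunk above also fixes a latent bug: in the old code the intended message was split across two statements, so the second string was a no-op expression rather than part of the assert message. A minimal reproduction of the pitfall (the `check_idxs_*` helpers are hypothetical illustrations, not code from this commit):

def check_idxs_buggy(idxs: list[int]) -> None:
    # The trailing string literal is its own (discarded) expression statement;
    # the assert message silently ends at "Currently only quantized".
    assert idxs[0] == 0, "Currently only quantized"
    "output idxs = [0] is supported"

def check_idxs_fixed(idxs: list[int]) -> None:
    # The if/raise form carries the full, single-string message.
    if idxs[0] != 0:
        raise AssertionError("Currently only quantized output idxs = [0] is supported")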
@@ -882,7 +894,8 @@ def convert_weighted_module(
     ref_qmodule_cls = root_module_to_quantized_reference_module.get(
         type_before_parametrizations(float_module), None
     )
-    assert ref_qmodule_cls is not None, (
-        f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
-    )
+    if ref_qmodule_cls is None:
+        raise AssertionError(
+            f"No reference quantized module class configured for {type_before_parametrizations(float_module)}"
+        )
     ref_qmodule = ref_qmodule_cls.from_float(float_module, wq_or_wq_dict)  # type: ignore[attr-defined]
@@ -904,7 +917,8 @@ def _remove_previous_dequantize_in_custom_module(
         \\ - dequantize
     """
     # expecting the input node for a custom module node to be a Node
-    assert isinstance(prev_node, Node), (
-        f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
-    )
+    if not isinstance(prev_node, Node):
+        raise AssertionError(
+            f"Expecting the argument for custom module node to be a Node, but got {prev_node}"
+        )
     if prev_node.op == "call_method" and prev_node.target == "dequantize":
@@ -952,15 +966,21 @@ def convert_custom_module(
     if _is_custom_module_lstm(node, modules):
         # The inputs are tuples in the form (input, (hidden0, hidden1))
         # Ensure all three input nodes are quantized
-        assert (
+        if not (
             len(node.args) == 2
             and isinstance(node.args[1], tuple)
             and len(node.args[1]) == 2
-        )
+        ):
+            raise AssertionError(
+                "Expected LSTM custom module inputs to be (input, (hidden0, hidden1))"
+            )
         (inputs, (hidden0, hidden1)) = node.args  # type: ignore[misc]
-        assert isinstance(inputs, Node)
-        assert isinstance(hidden0, Node)
-        assert isinstance(hidden1, Node)
+        if not isinstance(inputs, Node):
+            raise AssertionError("Expected inputs to be a Node")
+        if not isinstance(hidden0, Node):
+            raise AssertionError("Expected hidden0 to be a Node")
+        if not isinstance(hidden1, Node):
+            raise AssertionError("Expected hidden1 to be a Node")
         _remove_previous_dequantize_in_custom_module(node, inputs, graph)
         _remove_previous_dequantize_in_custom_module(node, hidden0, graph)
         _remove_previous_dequantize_in_custom_module(node, hidden1, graph)
@@ -971,22 +991,32 @@ def convert_custom_module(
         # to the module.
         # Additional handling is yet to be implemented for the outputs, similar
         # to LSTM custom module
-        assert len(node.args) == 3
+        if len(node.args) != 3:
+            raise AssertionError(
+                "Expected MHA custom module inputs to be (query, key, value)"
+            )
         query, key, value = node.args
-        assert isinstance(query, Node)
-        assert isinstance(key, Node)
-        assert isinstance(value, Node)
+        if not isinstance(query, Node):
+            raise AssertionError("Expected query to be a Node")
+        if not isinstance(key, Node):
+            raise AssertionError("Expected key to be a Node")
+        if not isinstance(value, Node):
+            raise AssertionError("Expected value to be a Node")
         _remove_previous_dequantize_in_custom_module(node, query, graph)
         _remove_previous_dequantize_in_custom_module(node, key, graph)
         _remove_previous_dequantize_in_custom_module(node, value, graph)
     else:
         # remove the previous dequant node to ensure the inputs are quantized
         arg = node.args[0]
-        assert isinstance(arg, Node)
+        if not isinstance(arg, Node):
+            raise AssertionError("Expected arg to be a Node")
         _remove_previous_dequantize_in_custom_module(node, arg, graph)
         # absorb the following observer into the module conversion
         activation_post_process = _maybe_get_observer_for_node(node, modules)
-        assert activation_post_process is not None
+        if activation_post_process is None:
+            raise AssertionError(
+                "Expected activation_post_process to be present for observed custom module"
+            )
         observed_custom_module.activation_post_process = activation_post_process

     # swap the observed custom module to quantized custom module
@@ -1061,7 +1091,8 @@ def convert(
         QConfigMapping.from_dict(qconfig_mapping) if qconfig_mapping else None
     )
     qconfig_mapping = copy.deepcopy(qconfig_mapping)
-    assert qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)
+    if not (qconfig_mapping is None or isinstance(qconfig_mapping, QConfigMapping)):
+        raise AssertionError("qconfig_mapping must be None or a QConfigMapping")

     if isinstance(backend_config, dict):
         warnings.warn(
@@ -1075,7 +1106,8 @@ def convert(
     if backend_config is None:
         backend_config = get_native_backend_config()

-    assert _is_observed_module(model), "incoming model must be produced by prepare_fx"
+    if not _is_observed_module(model):
+        raise AssertionError("incoming model must be produced by prepare_fx")
     observed_graph_module_attrs = model.meta["_observed_graph_module_attrs"]
     node_name_to_scope: dict[str, tuple[str, type]] = (
         observed_graph_module_attrs.node_name_to_scope
@@ -1121,11 +1153,13 @@ def convert(
     # all the values either match what was set in prepare node_name_to_qconfig
     # or are set to None in the convert_node_name_to_qconfig.
     for k, v in node_name_to_qconfig.items():
-        assert k in convert_node_name_to_qconfig, (
-            f"Expected key {k} in convert node_name_to_qconfig"
-        )
+        if k not in convert_node_name_to_qconfig:
+            raise AssertionError(
+                f"Expected key {k} in convert node_name_to_qconfig"
+            )
         if convert_node_name_to_qconfig[k] is not None:
-            assert qconfig_equals(v, convert_node_name_to_qconfig[k]), (
-                f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
-                f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
-            )
+            if not qconfig_equals(v, convert_node_name_to_qconfig[k]):
+                raise AssertionError(
+                    f"Expected k {k} to have the same value in prepare and convert QConfigMappings, "
+                    f"but {v} was updated to {convert_node_name_to_qconfig[k]}"
+                )
@@ -1201,7 +1235,10 @@ def convert(
             )
         elif node.op == "call_module":
             mod = _get_module(node, modules)
-            assert mod is not None
+            if mod is None:
+                raise AssertionError(
+                    "Expected module for call_module node to be present in modules mapping"
+                )
             if _is_activation_post_process(mod):
                 observed_node = node.args[0]
                 if observed_node in statically_quantized_custom_module_nodes:
@@ -102,7 +102,10 @@ def fuse(
         else:
             node_subpattern = None
         if maybe_last_node is node:
-            assert obj is not None
+            if obj is None:
+                raise AssertionError(
+                    "fuse handler object must not be None for matched root node"
+                )
             root_node_getter = fusion_pattern_to_root_node_getter.get(
                 pattern, default_root_node_getter
            )
@@ -65,9 +65,8 @@ class DefaultFuseHandler(FuseHandler):
         fuser_method_mapping: dict[Pattern, Union[torch.nn.Sequential, Callable]],
         is_qat: bool,
     ) -> Node:
-        assert root_node.op == "call_module", (
-            "Expecting module node to be a call_module Node"
-        )
+        if root_node.op != "call_module":
+            raise AssertionError("Expecting module node to be a call_module Node")
         root_module = named_modules[str(root_node.target)]

         def get_modules(pattern):
@@ -109,7 +109,8 @@ def _get_lstm_with_individually_observed_parts(
     # TODO: maybe make this work for layer_bw as well
     for layer in quantizable_lstm.layers:
         cell = layer.layer_fw.cell  # type: ignore[union-attr]
-        assert isinstance(cell, torch.nn.Module), "cell should be a nn.Module"
+        if not isinstance(cell, torch.nn.Module):
+            raise AssertionError("cell should be a nn.Module")
         cell = prepare_fx(cell, cell_qm, example_inputs, backend_config=backend_config)
         # HACK: Manually replace the activation_post_process following these ops.
         # This is needed for FloatFunctional ops because there is currently no way
@@ -150,7 +151,8 @@ def _get_lstm_with_individually_observed_parts(
                 continue
             if op_index not in op_index_to_activation_post_process_ctr:
                 continue
-            assert len(node.users) == 1
+            if len(node.users) != 1:
+                raise AssertionError("expected exactly one user for the node")
             activation_post_process_name = next(iter(node.users.keys())).name
             activation_post_process_ctr = op_index_to_activation_post_process_ctr[
                 op_index
@@ -195,7 +197,8 @@ def _get_reference_quantized_lstm_module(
     for i, layer in enumerate(quantized_lstm.layers):
         cell = copy.deepcopy(observed_lstm.layers.get_submodule(str(i)).layer_fw.cell)  # type: ignore[union-attr]
         cell = convert_to_reference_fx(cell, backend_config=backend_config)  # type: ignore[arg-type]
-        assert isinstance(cell, torch.fx.GraphModule)
+        if not isinstance(cell, torch.fx.GraphModule):
+            raise AssertionError("cell must be converted to a torch.fx.GraphModule")
         # HACK: Manually remove input quantize nodes and output dequantize nodes,
         # since custom modules expect quint8 inputs and outputs for now. Note that
         # this functionality is supposedly handled through PrepareCustomConfig's
@@ -33,7 +33,8 @@ def _is_match(modules, node, pattern, max_uses=sys.maxsize):
     if isinstance(pattern, tuple):
         self_match, *arg_matches = pattern
         if self_match is getattr:
-            assert len(pattern) == 2, "Expecting getattr pattern to have two elements"
+            if len(pattern) != 2:
+                raise AssertionError("Expecting getattr pattern to have two elements")
             arg_matches = []
     else:
         self_match = pattern
@@ -190,7 +191,8 @@ def _find_matches(
                 break

     # add custom module instances to the match result
-    assert modules is not None
+    if modules is None:
+        raise AssertionError("modules must not be None")
     for node in graph.nodes:
         if (
             node.op == "call_module"
@@ -204,7 +206,8 @@ def _find_matches(
         )

     def is_standalone_module(node_target: str, modules: dict[str, torch.nn.Module]):
-        assert modules is not None
+        if modules is None:
+            raise AssertionError("modules must not be None")
         return (
             node_target in standalone_module_names
             or type(modules[node_target])  # type: ignore[operator]
@@ -149,7 +149,8 @@ def _create_obs_or_fq_from_qspec(
         return None
     if isinstance(quantization_spec, SharedQuantizationSpec):
         edge_or_node = quantization_spec.edge_or_node
-        assert edge_or_node in obs_or_fq_map, (
-            "please make sure only refer to edge or node that has "
-            f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
-        )
+        if edge_or_node not in obs_or_fq_map:
+            raise AssertionError(
+                "please make sure only refer to edge or node that has "
+                f"observer/fake_quant inserted: '{edge_or_node}' not in\n{obs_or_fq_map.keys()}"
+            )
@@ -177,7 +178,8 @@ def _create_obs_or_fq_from_qspec(
         else:
             return observer_ctr()

-    assert isinstance(quantization_spec, QuantizationSpec)
+    if not isinstance(quantization_spec, QuantizationSpec):
+        raise AssertionError("quantization_spec must be a QuantizationSpec")
     observer_or_fake_quant_ctr = quantization_spec.observer_or_fake_quant_ctr
     kwargs = _get_observer_kwargs(quantization_spec)
     kwargs.pop("observer_or_fake_quant_ctr")
@@ -214,10 +216,14 @@ def _needs_obs_or_fq(
     # need to insert placeholder observer for dynamic quantization so that it can
     # be converted to choose_qparams -> q -> dq in convert step
     if cur_target_is_dynamic:
-        assert cur_target_dtype in _OBS_DTYPE_LIST, (
-            f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}"
-        )
+        if cur_target_dtype not in _OBS_DTYPE_LIST:
+            raise AssertionError(
+                f"Expected cur_target_dtype to be torch.float, but got: {cur_target_dtype}"
+            )
-        assert prev_output_dtype not in _DO_NOT_OBS_DTYPE_LIST
+        if prev_output_dtype in _DO_NOT_OBS_DTYPE_LIST:
+            raise AssertionError(
+                "prev_output_dtype must not be in _DO_NOT_OBS_DTYPE_LIST"
+            )
         return is_zeroth_arg
     if reuse_input_obs_or_fq:
         return False
@@ -398,7 +404,8 @@ def _is_pattern_dtype_config_and_qconfig_supported_by_backend(
     """
     if backend_config is None or pattern is None:
         return True
-    assert matched_node_pattern is not None and len(matched_node_pattern) >= 1
+    if matched_node_pattern is None or len(matched_node_pattern) < 1:
+        raise AssertionError("matched_node_pattern must be non-empty")
     pattern_to_dtype_configs = get_pattern_to_dtype_configs(backend_config)
     dtype_configs: list[DTypeConfig] = pattern_to_dtype_configs.get(pattern, [])
     pattern_to_root_node_getter = get_fusion_pattern_to_root_node_getter(backend_config)
@@ -535,7 +542,8 @@ def _set_target_dtype_info_for_matched_node_pattern(
     # other types of matched object, e.g. int, float literals, are ignored
     elif isinstance(matched_node_pattern, Node):
         # for pyre
-        assert isinstance(matched_node_pattern, Node)
+        if not isinstance(matched_node_pattern, Node):
+            raise AssertionError("matched_node_pattern must be a Node")
         node = matched_node_pattern
         if node in processed_nodes:
             return
@ -674,7 +682,8 @@ def _get_output_act_obs_or_fq(
|
|||
We are assuming that the observers are inserted correctly, and the dtype for
|
||||
argument in quantized graph will match what is specified by the qconfig
|
||||
"""
|
||||
assert isinstance(arg, Node)
|
||||
if not isinstance(arg, Node):
|
||||
raise AssertionError("arg must be a Node")
|
||||
if "quantization_annotation" in arg.meta:
|
||||
return _create_obs_or_fq_from_qspec(
|
||||
arg.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
|
||||
|
|
@ -698,9 +707,8 @@ def _get_output_act_obs_or_fq(
|
|||
)
|
||||
elif _is_activation_post_process_node(arg, named_modules):
|
||||
observed_arg = arg.args[0]
|
||||
assert isinstance(observed_arg, Node), (
|
||||
"Currently we only support observing Node"
|
||||
)
|
||||
if not isinstance(observed_arg, Node):
|
||||
raise AssertionError("Currently we only support observing Node")
|
||||
if "quantization_annotation" in observed_arg.meta:
|
||||
output_act_obs_or_fq = _create_obs_or_fq_from_qspec(
|
||||
observed_arg.meta["quantization_annotation"].output_qspec,
|
||||
|
|
@ -708,7 +716,10 @@ def _get_output_act_obs_or_fq(
|
|||
is_qat,
|
||||
)
|
||||
else:
|
||||
assert "target_dtype_info" in observed_arg.meta
|
||||
if "target_dtype_info" not in observed_arg.meta:
|
||||
raise AssertionError(
|
||||
"expected 'target_dtype_info' in observed_arg.meta"
|
||||
)
|
||||
output_act_obs_or_fq_ctr = observed_arg.meta["target_dtype_info"][
|
||||
"output_act_obs_or_fq_ctr"
|
||||
]
|
||||
|
|
@ -754,7 +765,8 @@ def _get_arg_as_input_act_obs_or_fq(
|
|||
"""Get the observer or fake quant constructor for the Argument `arg`, as input
|
||||
to Node `node`
|
||||
"""
|
||||
assert isinstance(arg, Node)
|
||||
if not isinstance(arg, Node):
|
||||
raise AssertionError("arg must be a Node")
|
||||
# "input_qspec_map" is the more general design we'll use for pt2e path
|
||||
# it is a map from input argument node to observer or fake quant constructor, for example
|
||||
# for the following graph:
|
||||
|
|
@ -838,7 +850,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
|
|||
|
||||
if not isinstance(arg, Node):
|
||||
return arg
|
||||
assert isinstance(arg, Node)
|
||||
if not isinstance(arg, Node):
|
||||
raise AssertionError("arg must be a Node")
|
||||
# default (no observer)
|
||||
new_arg = arg
|
||||
|
||||
|
|
@ -854,7 +867,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
|
|||
"quantization_annotation"
|
||||
]._reuse_input_obs_or_fq
|
||||
else:
|
||||
assert "target_dtype_info" in node.meta
|
||||
if "target_dtype_info" not in node.meta:
|
||||
raise AssertionError("expected 'target_dtype_info' in node.meta")
|
||||
# TODO: we are assuming "target_dtype_info" exists here, maybe
|
||||
# a default value also need to be provided here
|
||||
target_dtype_info = node.meta["target_dtype_info"]
|
||||
|
|
@ -889,7 +903,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
|
|||
)
|
||||
|
||||
else:
|
||||
assert qconfig is not None
|
||||
if qconfig is None:
|
||||
raise AssertionError("qconfig must not be None")
|
||||
# custom flow for standalone modules
|
||||
_, _, sm_prepare_custom_config, _ = _get_standalone_module_configs(
|
||||
node, named_modules, prepare_custom_config, qconfig, backend_config
|
||||
|
|
@ -946,7 +961,8 @@ def _maybe_insert_input_observer_for_arg_or_kwarg(
|
|||
existing_obs_node = maybe_obs_node
|
||||
break
|
||||
|
||||
assert arg_as_input_act_obs_or_fq is not None
|
||||
if arg_as_input_act_obs_or_fq is None:
|
||||
raise AssertionError("arg_as_input_act_obs_or_fq must not be None")
|
||||
obs_or_fq_map[(arg, node)] = arg_as_input_act_obs_or_fq
|
||||
if existing_obs_node is None:
|
||||
new_obs_node = _insert_obs_or_fq(
|
||||
|
|
@ -1102,7 +1118,8 @@ def _maybe_insert_output_observer_for_node(
|
|||
Note: inserting dynamic quantization ops for output is not supported in fx graph mode
|
||||
quantization code path right now
|
||||
"""
|
||||
assert node.op != "output", "observer insertion for outputs is handled elsewhere"
|
||||
if node.op == "output":
|
||||
raise AssertionError("observer insertion for outputs is handled elsewhere")
|
||||
|
||||
is_standalone_module = False
|
||||
if "quantization_annotation" in node.meta:
|
||||
|
|
@ -1110,7 +1127,8 @@ def _maybe_insert_output_observer_for_node(
|
|||
node.meta["quantization_annotation"].output_qspec, obs_or_fq_map, is_qat
|
||||
)
|
||||
else:
|
||||
assert "target_dtype_info" in node.meta
|
||||
if "target_dtype_info" not in node.meta:
|
||||
raise AssertionError("expected 'target_dtype_info' in node.meta")
|
||||
is_standalone_module = node.meta["target_dtype_info"].get(
|
||||
"_is_standalone_module", False
|
||||
)
|
||||
|
|
@ -1222,7 +1240,10 @@ def _maybe_insert_observers_before_graph_output(
|
|||
and arg_as_input_target_dtype != torch.float
|
||||
)
|
||||
if need_obs:
|
||||
assert observer_mod is not None
|
||||
if observer_mod is None:
|
||||
raise AssertionError(
|
||||
"observer_mod must not be None when need_obs is True"
|
||||
)
|
||||
# insert observer
|
||||
observer_node = _insert_obs_or_fq(
|
||||
maybe_node, observer_mod, model, named_modules, graph
|
||||
|
|
@ -1393,9 +1414,11 @@ def _maybe_make_input_output_share_observers(
|
|||
if iteration_guard > 10000:
|
||||
raise AssertionError("Unable to find observer of previous node")
|
||||
|
||||
assert isinstance(first_arg_arg, Node)
|
||||
if not isinstance(first_arg_arg, Node):
|
||||
raise AssertionError("first_arg_arg must be a Node")
|
||||
target_to_use = first_arg_arg.target
|
||||
assert isinstance(target_to_use, str)
|
||||
if not isinstance(target_to_use, str):
|
||||
raise AssertionError("target_to_use must be a string")
|
||||
obs_mod_to_use = named_modules[target_to_use]
|
||||
|
||||
if isinstance(first_arg, (list, tuple)):
|
||||
|
|
@ -1418,7 +1441,10 @@ def _maybe_make_input_output_share_observers(
|
|||
|
||||
# set the output observer node to use that module
|
||||
for output_obs_node in node.users.keys():
|
||||
assert _is_activation_post_process_node(output_obs_node, named_modules)
|
||||
if not _is_activation_post_process_node(output_obs_node, named_modules):
|
||||
raise AssertionError(
|
||||
"output_obs_node must be an activation post process node"
|
||||
)
|
||||
parent_name, name = _parent_name(output_obs_node.target)
|
||||
setattr(named_modules[parent_name], name, obs_mod_to_use)
|
||||
|
||||
|
|
@ -1431,7 +1457,10 @@ def _remove_output_observer(
|
|||
):
|
||||
items = list(node.users.items())
|
||||
for output_obs_node, _ in items:
|
||||
assert _is_activation_post_process_node(output_obs_node, named_modules)
|
||||
if not _is_activation_post_process_node(output_obs_node, named_modules):
|
||||
raise AssertionError(
|
||||
"output_obs_node must be an activation post process node"
|
||||
)
|
||||
output_obs_node.replace_all_uses_with(node)
|
||||
model.graph.erase_node(output_obs_node) # type: ignore[union-attr, operator]
|
||||
|
||||
|
|
@ -1554,7 +1583,8 @@ def insert_observers_for_model(
|
|||
qhandler,
|
||||
qconfig,
|
||||
) = match_res_with_qconfig
|
||||
assert qhandler is not None
|
||||
if qhandler is None:
|
||||
raise AssertionError("qhandler must not be None")
|
||||
_set_target_dtype_info_for_matched_node_pattern(
|
||||
matched_node_pattern,
|
||||
last_node,
|
||||
|
|
@ -1632,7 +1662,8 @@ def insert_observers_for_model(
|
|||
pattern, matched_node_pattern, qconfig, backend_config
|
||||
)
|
||||
)
|
||||
assert qhandler is not None
|
||||
if qhandler is None:
|
||||
raise AssertionError("qhandler must not be None")
|
||||
|
||||
# get output_act_dtype so that we don't also reset the special typed nodes
|
||||
# TODO: we might want to handle these more uniformly with the default path
|
||||
|
|
@ -1726,7 +1757,8 @@ def insert_observers_for_model(
|
|||
if not skip_inserting_observers and is_supported_by_backend:
|
||||
named_modules = dict(model.named_modules(remove_duplicate=False))
|
||||
if node.op != "output":
|
||||
assert matched_node_pattern is not None
|
||||
if matched_node_pattern is None:
|
||||
raise AssertionError("matched_node_pattern must not be None")
|
||||
# add matched nodes to the observed node name set
|
||||
_add_matched_node_name_to_set(
|
||||
matched_node_pattern, observed_node_names
|
||||
|
|
@ -2064,8 +2096,10 @@ def prepare(
|
|||
)
|
||||
backend_config = BackendConfig.from_dict(backend_config)
|
||||
|
||||
assert isinstance(qconfig_mapping, QConfigMapping)
|
||||
assert isinstance(_equalization_config, QConfigMapping)
|
||||
if not isinstance(qconfig_mapping, QConfigMapping):
|
||||
raise AssertionError("qconfig_mapping must be a QConfigMapping")
|
||||
if not isinstance(_equalization_config, QConfigMapping):
|
||||
raise AssertionError("_equalization_config must be a QConfigMapping")
|
||||
qconfig_mapping = copy.deepcopy(qconfig_mapping)
|
||||
_equalization_config = copy.deepcopy(_equalization_config)
|
||||
|
||||
|
|
@ -2194,10 +2228,11 @@ def prepare(
|
|||
)
|
||||
|
||||
if is_standalone_module:
|
||||
assert result_node is not None
|
||||
assert isinstance(result_node.args[0], Node), (
|
||||
"standalone module only supports returning simple value currently"
|
||||
"(not tuple, dict etc.)"
|
||||
if result_node is None:
|
||||
raise AssertionError("result_node must not be None for standalone modules")
|
||||
if not isinstance(result_node.args[0], Node):
|
||||
raise AssertionError(
|
||||
"standalone module only supports returning simple value currently (not tuple, dict etc.)"
|
||||
)
|
||||
# these inputs are observed in parent
|
||||
# converting List[int] to Tensor since module attribute is
|
||||
|
|
|
|||
|
|
@@ -228,9 +228,10 @@ def _compare_prepare_convert_qconfig_mappings(
     `prepare_qconfig_mapping`: configuration for prepare quantization step
     `convert_qconfig_mapping`: configuration for convert quantization step
     """
-    assert qconfig_equals(
+    if not qconfig_equals(
         prepare_qconfig_mapping.global_qconfig, convert_qconfig_mapping.global_qconfig
-    ), (
+    ):
+        raise AssertionError(
             "Expected global qconfigs to be the same in the prepare and convert quantization configs"
         )
     prepare_dicts: list[OrderedDict] = [

@@ -250,15 +251,16 @@ def _compare_prepare_convert_qconfig_mappings(
     ]
     for i in range(len(prepare_dicts)):
         for name in prepare_dicts[i].keys():
-            assert name in convert_dicts[i], (
-                f"Missing key {dict_names[i]} {name} in convert QConfigMapping \
-                when it was present in prepare"
+            if name not in convert_dicts[i]:
+                raise AssertionError(
+                    f"Missing key {dict_names[i]} {name} in convert QConfigMapping when it was present in prepare"
                 )
-            assert convert_dicts[i][name] is None or qconfig_equals(
+            if convert_dicts[i][name] is not None and not qconfig_equals(
                 prepare_dicts[i][name], convert_dicts[i][name]
-            ), (
-                f"Expected convert QConfigMapping to have the same qconfig as prepare for key {dict_names[i]} {name}; \
-                prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"
+            ):
+                raise AssertionError(
+                    "Expected convert QConfigMapping to have the same qconfig as prepare for key "
+                    f"{dict_names[i]} {name}; prepare: {prepare_dicts[i][name]}; convert: {convert_dicts[i][name]}"
                 )
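One subtlety in the hunks above: a compound assert condition must be negated with De Morgan's laws when converted, so "A or B" becomes "not A and not B". A small sketch of the transformation (illustrative; check_entry and its arguments are hypothetical names, not from the PR):

    def check_entry(convert_value, prepare_value, qconfig_equals):
        # Original: assert convert_value is None or qconfig_equals(prepare_value, convert_value)
        # Negated per De Morgan: not (A or B) == (not A) and (not B)
        if convert_value is not None and not qconfig_equals(prepare_value, convert_value):
            raise AssertionError("prepare/convert qconfig mismatch")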
@@ -119,7 +119,8 @@ def _get_quantize_handler_cls(
         ):
             super().__init__(node_pattern, modules, root_node_getter)
             if num_tensor_args_to_observation_type:
-                assert self.num_tensor_args in num_tensor_args_to_observation_type, (
+                if self.num_tensor_args not in num_tensor_args_to_observation_type:
+                    raise AssertionError(
                         f"Must provide observation_type config for tensor number {self.num_tensor_args}"
                         f" in num_tensor_args_to_observation_type for {node_pattern}"
                     )
@@ -165,7 +165,8 @@ def get_qconv_prepack_op(conv_op: Callable) -> Callable:
         torch.nn.functional.conv_transpose3d: torch.ops.quantized.conv_transpose3d_prepack,
     }
     prepack_op = prepack_ops.get(conv_op)
-    assert prepack_op, f"Didn't find prepack op for {conv_op}"
+    if prepack_op is None:
+        raise AssertionError(f"Didn't find prepack op for {conv_op}")
     return prepack_op

@@ -230,7 +231,8 @@ def graph_module_from_producer_nodes(
     Return:
         A graph module constructed from the producer nodes
     """
-    assert len(producer_nodes) > 0, "list of producer nodes can not be empty"
+    if len(producer_nodes) == 0:
+        raise AssertionError("list of producer nodes can not be empty")
     # since we traced back from node to getattr
     producer_nodes.reverse()
     graph = Graph()

@@ -300,7 +302,8 @@ def all_node_args_have_no_tensors(
     elif node.op == "placeholder":
         result = False
     elif node.op == "call_module":
-        assert isinstance(node.target, str)
+        if not isinstance(node.target, str):
+            raise AssertionError("node.target must be a string for call_module nodes")
         if _is_activation_post_process(modules[node.target]):
             result = all_node_args_have_no_tensors(node.args[0], modules, cache)  # type: ignore[arg-type]
     elif node.op == "call_module":

@@ -503,9 +506,10 @@ def _is_custom_module_lstm(
     """
     mod = _get_module(node, named_modules)
     if qconfig is not None and qhandler is not None:
-        assert isinstance(
+        if not isinstance(
             qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
-        )  # type: ignore[attr-defined]
+        ):  # type: ignore[attr-defined]
+            raise AssertionError("qhandler must be a QuantizeHandler when provided")
         return (
             isinstance(mod, torch.nn.LSTM)
             and activation_is_statically_quantized(qconfig)

@@ -527,9 +531,10 @@ def _is_custom_module_mha(
     """
     mod = _get_module(node, named_modules)
     if qconfig is not None and qhandler is not None:
-        assert isinstance(
+        if not isinstance(
             qhandler, torch.ao.quantization.fx.quantize_handler.QuantizeHandler
-        )  # type: ignore[attr-defined]
+        ):  # type: ignore[attr-defined]
+            raise AssertionError("qhandler must be a QuantizeHandler when provided")
         return (
             isinstance(mod, torch.nn.MultiheadAttention)
             and activation_is_statically_quantized(qconfig)

@@ -826,10 +831,16 @@ def _reroute_tuple_getitem_pattern(graph: Graph):
     for pattern in matched_patterns:
         first_tuple = pattern[0]
         last_getitem = pattern[-1]
-        assert first_tuple.op == "call_function" and first_tuple.target is tuple
-        assert (
+        if not (first_tuple.op == "call_function" and first_tuple.target is tuple):
+            raise AssertionError(
+                "first tuple node must be a call_function with target tuple"
+            )
+        if not (
             last_getitem.op == "call_function"
             and last_getitem.target == operator.getitem
-        )
+        ):
+            raise AssertionError(
+                "last getitem node must be a call_function with target operator.getitem"
+            )
         last_getitem_index = last_getitem.args[1]
         new_input = first_tuple.args[0][last_getitem_index]  # type: ignore[index]

@@ -847,7 +858,10 @@ def _get_observer_from_activation_post_process(
     if isinstance(activation_post_process, ObserverBase):
         return activation_post_process
     else:
-        assert isinstance(activation_post_process, FakeQuantizeBase)
+        if not isinstance(activation_post_process, FakeQuantizeBase):
+            raise AssertionError(
+                "activation_post_process must be an ObserverBase or FakeQuantizeBase"
+            )
         return activation_post_process.activation_post_process  # type: ignore[return-value]

@@ -966,7 +980,10 @@ def _qconfig_satisfies_dtype_config_constraints(
     satisfies_constraints = True
     if activation_post_process_ctr is not None:
         activation_post_process = activation_post_process_ctr()
-        assert _is_activation_post_process(activation_post_process)
+        if not _is_activation_post_process(activation_post_process):
+            raise AssertionError(
+                "activation_post_process must be an activation post process"
+            )
         # If dtypes don't match, don't check the activation_post_process and return True early
         if activation_post_process.dtype != dtype_with_constraints.dtype:
             return True
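A note on the get_qconv_prepack_op hunk above: "assert prepack_op" rejects every falsy value, whereas the replacement tests "prepack_op is None" specifically. The two agree here because dict.get returns None on a missing key and the stored prepack ops are never falsy, and the explicit form states the intent more precisely. A tiny sketch (illustrative; the dictionary values are hypothetical stand-ins, not real quantized ops):

    prepack_ops = {"conv2d": "conv2d_prepack"}  # stand-in values for illustration

    def lookup_prepack_op(op_name):
        prepack_op = prepack_ops.get(op_name)
        if prepack_op is None:  # equivalent to `assert prepack_op` for this mapping
            raise AssertionError(f"Didn't find prepack op for {op_name}")
        return prepack_op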