diff --git a/torch/ao/quantization/observer.py b/torch/ao/quantization/observer.py index 6a39bdc0fc3..673d52e8924 100644 --- a/torch/ao/quantization/observer.py +++ b/torch/ao/quantization/observer.py @@ -493,6 +493,7 @@ class MinMaxObserver(UniformQuantizationObserverBase): .. note:: If the running minimum equals to the running maximum, the scale and zero_point are set to 1.0 and 0. """ + min_val: torch.Tensor max_val: torch.Tensor @@ -702,6 +703,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase): .. note:: If the running minimum equals to the running maximum, the scales and zero_points are set to 1.0 and 0. """ + min_val: torch.Tensor max_val: torch.Tensor @@ -997,6 +999,7 @@ class HistogramObserver(UniformQuantizationObserverBase): 3. Compute the scale and zero point the same way as in the :class:`~torch.ao.quantization.MinMaxObserver` """ + histogram: torch.Tensor min_val: torch.Tensor max_val: torch.Tensor @@ -1524,6 +1527,7 @@ class RecordingObserver(ObserverBase): qscheme: Quantization scheme to be used reduce_range: Reduces the range of the quantized data type by 1 bit """ + __annotations__ = {"tensor_val": list[Optional[torch.Tensor]]} def __init__(self, dtype=torch.quint8): @@ -1790,7 +1794,7 @@ def get_block_size( ), f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}" return (1, granularity.group_size) elif isinstance(granularity, PerToken): - block_size = list(input_shape) + block_size = [1] * len(input_shape) block_size[-1] = input_shape[-1] return tuple(block_size) raise ValueError(f"Unsupported Granularity: {granularity}")