[AO] fix per token block size calculation (#150890)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150890
Approved by: https://github.com/jerryzh168
This commit is contained in:
Max Ren 2025-04-08 17:06:57 -07:00 committed by PyTorch MergeBot
parent c59aaa03ff
commit 6fb089f2a2

View File

@@ -493,6 +493,7 @@ class MinMaxObserver(UniformQuantizationObserverBase):
    .. note:: If the running minimum equals to the running maximum, the scale
              and zero_point are set to 1.0 and 0.
    """

    min_val: torch.Tensor
    max_val: torch.Tensor
@@ -702,6 +703,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """

    min_val: torch.Tensor
    max_val: torch.Tensor
@@ -997,6 +999,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
    3. Compute the scale and zero point the same way as in the
       :class:`~torch.ao.quantization.MinMaxObserver`
    """

    histogram: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor
@@ -1524,6 +1527,7 @@ class RecordingObserver(ObserverBase):
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit
    """

    __annotations__ = {"tensor_val": list[Optional[torch.Tensor]]}

    def __init__(self, dtype=torch.quint8):
@@ -1790,7 +1794,7 @@ def get_block_size(
        ), f"Expecting input shape dim to be 2 for per group quantization, gotinput shape: {input_shape}"
        return (1, granularity.group_size)
    elif isinstance(granularity, PerToken):
-        block_size = list(input_shape)
+        block_size = [1] * len(input_shape)
        block_size[-1] = input_shape[-1]
        return tuple(block_size)
    raise ValueError(f"Unsupported Granularity: {granularity}")