[AO] fix per token block size calculation (#150890)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150890
Approved by: https://github.com/jerryzh168
This commit is contained in:
Max Ren 2025-04-08 17:06:57 -07:00 committed by PyTorch MergeBot
parent c59aaa03ff
commit 6fb089f2a2

View File

@ -493,6 +493,7 @@ class MinMaxObserver(UniformQuantizationObserverBase):
.. note:: If the running minimum equals to the running maximum, the scale
and zero_point are set to 1.0 and 0.
"""
min_val: torch.Tensor
max_val: torch.Tensor
@ -702,6 +703,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
.. note:: If the running minimum equals to the running maximum, the scales
and zero_points are set to 1.0 and 0.
"""
min_val: torch.Tensor
max_val: torch.Tensor
@ -997,6 +999,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
3. Compute the scale and zero point the same way as in the
:class:`~torch.ao.quantization.MinMaxObserver`
"""
histogram: torch.Tensor
min_val: torch.Tensor
max_val: torch.Tensor
@ -1524,6 +1527,7 @@ class RecordingObserver(ObserverBase):
qscheme: Quantization scheme to be used
reduce_range: Reduces the range of the quantized data type by 1 bit
"""
__annotations__ = {"tensor_val": list[Optional[torch.Tensor]]}
def __init__(self, dtype=torch.quint8):
@ -1790,7 +1794,7 @@ def get_block_size(
), f"Expecting input shape dim to be 2 for per group quantization, got input shape: {input_shape}"
return (1, granularity.group_size)
elif isinstance(granularity, PerToken):
block_size = list(input_shape)
block_size = [1] * len(input_shape)
block_size[-1] = input_shape[-1]
return tuple(block_size)
raise ValueError(f"Unsupported Granularity: {granularity}")