mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[AO] fix per token block size calculation (#150890)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150890 Approved by: https://github.com/jerryzh168
This commit is contained in:
parent
c59aaa03ff
commit
6fb089f2a2
|
|
@ -493,6 +493,7 @@ class MinMaxObserver(UniformQuantizationObserverBase):
|
||||||
.. note:: If the running minimum equals the running maximum, the scale
|
.. note:: If the running minimum equals the running maximum, the scale
|
||||||
and zero_point are set to 1.0 and 0.
|
and zero_point are set to 1.0 and 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
min_val: torch.Tensor
|
min_val: torch.Tensor
|
||||||
max_val: torch.Tensor
|
max_val: torch.Tensor
|
||||||
|
|
||||||
|
|
@ -702,6 +703,7 @@ class PerChannelMinMaxObserver(UniformQuantizationObserverBase):
|
||||||
.. note:: If the running minimum equals the running maximum, the scales
|
.. note:: If the running minimum equals the running maximum, the scales
|
||||||
and zero_points are set to 1.0 and 0.
|
and zero_points are set to 1.0 and 0.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
min_val: torch.Tensor
|
min_val: torch.Tensor
|
||||||
max_val: torch.Tensor
|
max_val: torch.Tensor
|
||||||
|
|
||||||
|
|
@ -997,6 +999,7 @@ class HistogramObserver(UniformQuantizationObserverBase):
|
||||||
3. Compute the scale and zero point the same way as in the
|
3. Compute the scale and zero point the same way as in the
|
||||||
:class:`~torch.ao.quantization.MinMaxObserver`
|
:class:`~torch.ao.quantization.MinMaxObserver`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
histogram: torch.Tensor
|
histogram: torch.Tensor
|
||||||
min_val: torch.Tensor
|
min_val: torch.Tensor
|
||||||
max_val: torch.Tensor
|
max_val: torch.Tensor
|
||||||
|
|
@ -1524,6 +1527,7 @@ class RecordingObserver(ObserverBase):
|
||||||
qscheme: Quantization scheme to be used
|
qscheme: Quantization scheme to be used
|
||||||
reduce_range: Reduces the range of the quantized data type by 1 bit
|
reduce_range: Reduces the range of the quantized data type by 1 bit
|
||||||
"""
|
"""
|
||||||
|
|
||||||
__annotations__ = {"tensor_val": list[Optional[torch.Tensor]]}
|
__annotations__ = {"tensor_val": list[Optional[torch.Tensor]]}
|
||||||
|
|
||||||
def __init__(self, dtype=torch.quint8):
|
def __init__(self, dtype=torch.quint8):
|
||||||
|
|
@ -1790,7 +1794,7 @@ def get_block_size(
|
||||||
), f"Expecting input shape dim to be 2 for per group quantization, got input shape: {input_shape}"
|
), f"Expecting input shape dim to be 2 for per group quantization, got input shape: {input_shape}"
|
||||||
return (1, granularity.group_size)
|
return (1, granularity.group_size)
|
||||||
elif isinstance(granularity, PerToken):
|
elif isinstance(granularity, PerToken):
|
||||||
block_size = list(input_shape)
|
block_size = [1] * len(input_shape)
|
||||||
block_size[-1] = input_shape[-1]
|
block_size[-1] = input_shape[-1]
|
||||||
return tuple(block_size)
|
return tuple(block_size)
|
||||||
raise ValueError(f"Unsupported Granularity: {granularity}")
|
raise ValueError(f"Unsupported Granularity: {granularity}")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user