Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63058
Dedicating separate tests for different quantization methods. Currently supporting the FP16 method.
ghstack-source-id: 136499767

Test Plan: buck test mode/dev //caffe2/test/distributed/algorithms/quantization:quantization_gloo_fork -- name_of_the_test

Reviewed By: wanchaol
Differential Revision: D30142580
fbshipit-source-id: 3aacec1a231a662067d2b48c001f0c69fefcdd60
129 lines · 4.6 KiB · Python
import functools

import torch
import torch.distributed as dist

from enum import Enum


TORCH_HALF_MIN = torch.finfo(torch.float16).min
TORCH_HALF_MAX = torch.finfo(torch.float16).max


class DQuantType(Enum):
    FP16 = "fp16"

    def __str__(self) -> str:
        return self.value


def _fp32_to_fp16_with_clamp(tensor: torch.Tensor) -> torch.Tensor:
    return torch.clamp(tensor, TORCH_HALF_MIN, TORCH_HALF_MAX).half()


def _quantize_tensor(tensor, qtype):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_quantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        return _fp32_to_fp16_with_clamp(tensor)
    else:
        raise RuntimeError(
            f'Quantization type {qtype} is not supported'
        )


def _quantize_tensor_list(tensor_list, qtype):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_quantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    if qtype == DQuantType.FP16:
        quantized_tensor_list = [_quantize_tensor(t, qtype) for t in tensor_list]
        return quantized_tensor_list
    else:
        raise RuntimeError(
            f'Quantization type {qtype} is not supported'
        )


def _dequantize_tensor(tensor, qtype, quant_loss=None):
    if not isinstance(tensor, torch.Tensor):
        raise RuntimeError(
            f"_dequantize_tensor expecting torch.Tensor as input but found {type(tensor)}"
        )
    if qtype == DQuantType.FP16:
        if tensor.dtype != torch.float16:
            raise RuntimeError(
                f"tensor dtype is {tensor.dtype} while expected to be FP16."
            )
        elif tensor.dtype == torch.float16 and quant_loss is None:
            return tensor.float()
        else:
            return tensor.float() / quant_loss
    else:
        raise RuntimeError(
            f'Quantization type {qtype} is not supported'
        )


def _dequantize_tensor_list(tensor_list, qtype, quant_loss=None):
    if not isinstance(tensor_list, list) or not all(
        isinstance(p, torch.Tensor) for p in tensor_list
    ):
        raise RuntimeError(
            f"_dequantize_tensor_list expecting list of torch.Tensor as input but found {type(tensor_list)}"
        )
    elif qtype == DQuantType.FP16:
        dequantized_tensor_list = [_dequantize_tensor(t, qtype) for t in tensor_list]
        return dequantized_tensor_list
    else:
        raise RuntimeError(
            f'Quantization type {qtype} is not supported'
        )

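# Illustrative round trip (a minimal sketch; the sample values are assumptions
# chosen to show the FP16 clamp):
#
#     x = torch.tensor([1.5, 70000.0])            # float32 input
#     q = _quantize_tensor(x, DQuantType.FP16)    # float16; 70000.0 is clamped to 65504.0 (FP16 max)
#     y = _dequantize_tensor(q, DQuantType.FP16)  # back to float32; lossy where clamped or rounded
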
def auto_quantize(func, qtype, quant_loss=None):
    """
    This is a prototype API that automatically quantizes the input tensors, chooses the precision type,
    passes the other necessary arguments, and then dequantizes the output.

    Currently it only supports:
        . FP16 quantization method
        . all_gather, all_to_all collective ops

    Args:
        func (callable): A function representing collective operations.
        qtype (DQuantType): Quantization method
        quant_loss (float, optional): This can be used to improve accuracy in the dequantization.

    Returns:
        (callable): the same collective as func but enables automatic quantization/dequantization.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        group = kwargs.get('group', None)
        async_op = kwargs.get('async_op', False)
        if async_op is True:
            raise RuntimeError(
                'The async_op=True mode is not supported yet.'
            )
        if func == dist.all_gather:
            tensors = args[0]
            input_tensors = _quantize_tensor(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_gather(out_tensors, input_tensors, group=group, async_op=async_op)
            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t

        elif func == dist.all_to_all:
            tensors = args[0]
            input_tensors = _quantize_tensor_list(args[1], qtype)
            out_tensors = _quantize_tensor_list(tensors, qtype)
            dist.all_to_all(out_tensors, input_tensors, group=group, async_op=async_op)
            for i, t in enumerate(_dequantize_tensor_list(out_tensors, qtype, quant_loss=quant_loss)):
                tensors[i] = t

        else:
            raise RuntimeError(
                f"The collective op {func} is not supported yet"
            )

    return wrapper
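A minimal usage sketch for auto_quantize with dist.all_gather, assuming a process group has already been initialized (for example with the Gloo backend under torchrun) and that the module above is importable as torch.distributed.algorithms._quantization.quantization; the helper name run_fp16_all_gather and the tensor shapes are illustrative only. The wrapper quantizes both the input tensor and the output buffers to clamped FP16, runs the collective, and writes the dequantized float32 results back into the output list in place (the wrapped call itself returns None):

import torch
import torch.distributed as dist

# auto_quantize and DQuantType are the objects defined in the module above.
from torch.distributed.algorithms._quantization.quantization import DQuantType, auto_quantize


def run_fp16_all_gather():
    # Assumes dist.init_process_group(...) has already been called on every rank.
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    # Every rank contributes a float32 tensor; the wire format is clamped FP16.
    input_tensor = torch.full((4,), float(rank), dtype=torch.float32)
    output_tensors = [torch.zeros(4, dtype=torch.float32) for _ in range(world_size)]

    fp16_all_gather = auto_quantize(dist.all_gather, DQuantType.FP16)
    fp16_all_gather(output_tensors, input_tensor)

    # output_tensors[i] now holds rank i's tensor, dequantized back to float32.
    return output_tensors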