# mypy: ignore-errors

r"""This module provides common utility methods for checking quantized tensors
and modules.
"""
import numpy as np
import torch
from torch import Tensor
from contextlib import contextmanager
from torch.testing._internal.common_utils import TEST_WITH_TSAN, IS_PPC, IS_MACOS, IS_WINDOWS

supported_qengines = torch.backends.quantized.supported_engines
supported_qengines.remove('none')
# Note: We currently do not run QNNPACK tests on WINDOWS and MACOS as it is flaky. Issue #29326
# QNNPACK is not supported on PPC
if 'qnnpack' in supported_qengines and any([IS_PPC, TEST_WITH_TSAN, IS_MACOS, IS_WINDOWS]):
    supported_qengines.remove('qnnpack')


def _conv_output_shape(input_size, kernel_size, padding, stride,
                       dilation, output_padding=0):
    """Computes the output shape given convolution parameters."""
    return np.floor((input_size + 2 * padding - kernel_size - (kernel_size - 1)
                     * (dilation - 1)) / stride) + 2 * output_padding + 1


# Quantization references
def _quantize(x, scale, zero_point, qmin=None, qmax=None, dtype=np.uint8):
    """Quantizes a numpy array."""
    if qmin is None:
        qmin = np.iinfo(dtype).min
    if qmax is None:
        qmax = np.iinfo(dtype).max
    qx = np.round(x / scale + zero_point).astype(np.int64)
    qx = np.clip(qx, qmin, qmax)
    qx = qx.astype(dtype)
    return qx


def _dequantize(qx, scale, zero_point):
    """Dequantizes a numpy array."""
    x = (qx.astype(float) - zero_point) * scale
    return x


def _requantize(x, multiplier, zero_point, qmin=0, qmax=255, qtype=np.uint8):
    """Requantizes a numpy array, i.e., intermediate int32 or int16 values are
    converted back to the given type"""
    qx = (x * multiplier).round() + zero_point
    qx = np.clip(qx, qmin, qmax).astype(qtype)
    return qx


def _calculate_dynamic_qparams(X, dtype, reduce_range=False, qscheme=torch.per_tensor_affine):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    assert qscheme in (torch.per_tensor_affine, torch.per_tensor_symmetric)
    if qscheme == torch.per_tensor_symmetric:
        assert dtype == torch.qint8
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    if dtype == torch.qint8:
        if reduce_range:
            qmin, qmax = -64, 63
        else:
            qmin, qmax = -128, 127
    else:  # dtype == torch.quint8
        if reduce_range:
            qmin, qmax = 0, 127
        else:
            qmin, qmax = 0, 255
    min_val = X.min()
    max_val = X.max()
    is_symmetric = (qscheme == torch.per_tensor_symmetric)
    if min_val == max_val:
        scale = 1.0
        zero_point = 0
    else:
        if is_symmetric:
            max_val = max(max_val, -min_val)
            min_val = -max_val
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale = (max_val - min_val) / (qmax - qmin)
            scale = max(scale, np.finfo(np.float32).eps)
            zero_point = qmin - round(min_val / scale)
            zero_point = max(qmin, zero_point)
            zero_point = min(qmax, zero_point)
    return [float(scale), int(zero_point)]


def _calculate_dynamic_per_channel_qparams(X, dtype):
    """Calculate the dynamic quantization parameters (scale, zero_point)
    according to the min and max element of the tensor"""
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    qmin, qmax = torch.iinfo(dtype).min, torch.iinfo(dtype).max
    n_levels = qmax - qmin
    scale = np.zeros(X.shape[0], dtype=np.float64)
    zero_point = np.zeros(X.shape[0], dtype=np.int64)
    for i in range(zero_point.shape[0]):
        min_val = X.min()
        max_val = X.max()
        if min_val == max_val:
            scale[i] = 1.0
            zero_point[i] = 0
        else:
            max_val = max(max_val, 0.0)
            min_val = min(min_val, 0.0)
            scale[i] = (max_val - min_val) / n_levels
            scale[i] = max(scale[i], np.finfo(np.float32).eps)
            zero_point[i] = qmin - round(min_val / scale[i])
            zero_point[i] = max(qmin, zero_point[i])
            zero_point[i] = min(qmax, zero_point[i])
    return scale, zero_point
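

# Illustrative sketch (hypothetical helper, not part of the test suite): the
# reference helpers above compose into a simple quantize/dequantize round trip.
# The half-step error bound below is an assumption chosen for illustration,
# not a tolerance used by the real tests.
def _example_quantize_roundtrip():
    X = np.random.randn(4, 4).astype(np.float32)
    scale, zero_point = _calculate_dynamic_qparams(X, torch.quint8)
    qX = _quantize(X, scale, zero_point)          # uint8 codes in [0, 255]
    X_hat = _dequantize(qX, scale, zero_point)    # back to float
    # For values inside [X.min(), X.max()] the round-trip error is bounded by
    # half a quantization step.
    assert np.abs(X - X_hat).max() <= scale / 2 + 1e-6
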

def _snr(x, x_hat):
    """Calculates the signal to noise ratio and returns the signal and noise
    power, as well as the SNR in dB.
    If the input is a list/tuple this function is called recursively on each
    element. The result will have the same nested structure as the inputs.

    Args:
        x, x_hat: Either a tensor or a nested list/tuple of tensors.
    Returns:
        signal, noise, SNR(in dB): Either floats or a nested list of floats
    """
    if isinstance(x, (list, tuple)):
        assert len(x) == len(x_hat)
        res = [_snr(x[idx], x_hat[idx]) for idx in range(len(x))]
        return res
    if x_hat.is_quantized:
        x_hat = x_hat.dequantize()
    if x.is_quantized:
        x = x.dequantize()
    noise = (x - x_hat).norm()
    if noise == 0:
        return 0.0, float('inf'), float('inf')
    signal = x.norm()
    snr = signal / noise
    snr_db = 20 * snr.log10()
    return signal, noise, snr_db


@contextmanager
def override_quantized_engine(qengine):
    previous = torch.backends.quantized.engine
    torch.backends.quantized.engine = qengine
    try:
        yield
    finally:
        torch.backends.quantized.engine = previous


@contextmanager
def override_cpu_allocator_for_qnnpack(qengine_is_qnnpack):
    try:
        if qengine_is_qnnpack:
            torch._C._set_default_mobile_cpu_allocator()
        yield
    finally:
        if qengine_is_qnnpack:
            torch._C._unset_default_mobile_cpu_allocator()


# TODO: Update all quantization tests to use this decorator.
# Currently for some of the tests it seems to have inconsistent params
# for fbgemm vs qnnpack.
def override_qengines(qfunction):
    def test_fn(*args, **kwargs):
        for qengine in supported_qengines:
            with override_quantized_engine(qengine):
                # qfunction should not return anything.
                qfunction(*args, **kwargs)
    return test_fn
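

# Illustrative sketch (hypothetical helper, not part of the test suite): shows
# how ``override_qengines`` is meant to be applied -- the wrapped function runs
# once per engine in ``supported_qengines``, with
# ``torch.backends.quantized.engine`` set accordingly.
def _example_override_qengines_usage():
    seen = []

    @override_qengines
    def _probe():
        seen.append(torch.backends.quantized.engine)

    _probe()
    assert sorted(seen) == sorted(supported_qengines)
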

def qengine_is_fbgemm():
    return torch.backends.quantized.engine == 'fbgemm'


def qengine_is_qnnpack():
    return torch.backends.quantized.engine == 'qnnpack'


def qengine_is_onednn():
    return torch.backends.quantized.engine == 'onednn'


def qengine_is_x86():
    return torch.backends.quantized.engine == 'x86'


# Helper function used to simulate per-channel fake-quant against any axis
def _permute_to_axis_zero(X, axis):
    new_axis_list = list(range(X.dim()))
    new_axis_list[axis] = 0
    new_axis_list[0] = axis
    y = X.permute(tuple(new_axis_list))
    return y, new_axis_list


# Reference method for fake quantize
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_reference(X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    res = torch.zeros_like(X)

    for i in range(X.size()[0]):
        res[i] = (torch.clamp(torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i]),
                              quant_min, quant_max) - per_channel_zero_point[i]) * per_channel_scale[i]

    out = res.permute(tuple(permute_axis_list))
    return out.to(dtype)


# Reference method for the gradient of the fake quantize operator
# Note: because scale/zero_point are left as float in the actual kernel, this mimics how fake_quant works for float16/64
def _fake_quantize_per_channel_affine_grad_reference(dY, X, per_channel_scale, per_channel_zero_point, axis, quant_min, quant_max):
    dtype = X.dtype
    X, permute_axis_list = _permute_to_axis_zero(X.to(torch.float32), axis)
    Xq = torch.zeros_like(X)
    for i in range(X.size()[0]):
        Xq[i] = torch.round(X[i] * (1.0 / per_channel_scale[i]) + per_channel_zero_point[i])
    Xq = Xq.permute(tuple(permute_axis_list))
    mask = (Xq >= quant_min) * (Xq <= quant_max)
    res = torch.zeros_like(dY)
    res[mask] = dY[mask]
    return res.to(dtype)
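

# Illustrative sketch (hypothetical, not exercised by the test suite): the
# reference above is intended to be compared against the fused
# ``torch.fake_quantize_per_channel_affine`` op, roughly like this.  The exact
# shapes and tolerances used by the real tests differ.
def _example_fake_quant_per_channel_vs_reference():
    X = torch.randn(2, 3, 4)
    scale = torch.rand(3) * 0.1 + 0.01
    zero_point = torch.randint(0, 10, (3,), dtype=torch.int32)
    Y_ref = _fake_quantize_per_channel_affine_reference(
        X, scale, zero_point, axis=1, quant_min=0, quant_max=255)
    Y = torch.fake_quantize_per_channel_affine(X, scale, zero_point, 1, 0, 255)
    assert torch.allclose(Y, Y_ref, rtol=1e-5, atol=1e-5)
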

def to_tensor(X, device):
    if not isinstance(X, torch.Tensor):
        X = torch.tensor(X)
    else:
        X = X.detach().clone()
    return X.to(device=torch.device(device), dtype=torch.float32)


# copy-pasted from
# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
def _n_ones(n: int) -> int:
    return (1 << n) - 1


EBITS_F32, MBITS_F32 = 8, 23
F32_EXP_BIAS = _n_ones(EBITS_F32 - 1)


# copy-pasted from
# https://github.com/pytorch/ao/blob/bc4f51da86956275da7db0da6e420c506df97820/torchao/prototype/custom_fp_utils.py#L27C1-L142C29
def _f32_to_floatx_unpacked(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert FP32 numbers to sub-byte floating point numbers with the given
    number of exponent and mantissa bits.

    Input: torch.Tensor of dtype torch.float
    Output: torch.Tensor of dtype torch.uint8, where the bit encoding is stored
    in the least significant bits. e.g.
      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding

    Note: there is no support for special values (NaN, inf) in this code.

    Values outside the representable range of Floatx after rounding are clamped
    to the maximum Floatx magnitude (sign is preserved).

    Code below is an adaptation of https://fburl.com/code/ciwofcg4
    Background 1: last answer in https://stackoverflow.com/q/8981913
    Background 2: Computer Organization and Design, RISC-V edition, Chapter 3.5
    """
    assert x.dtype == torch.float
    assert 1 + ebits + mbits <= 8

    # calculate constants
    exp_bias = _n_ones(ebits - 1)
    max_int = _n_ones(ebits + mbits)
    sign_mask = 1 << (ebits + mbits)

    # TODO document this better
    magic_adder = _n_ones(MBITS_F32 - mbits - 1)

    # all E bits and M bits are 1s
    max_normal = 2 ** (_n_ones(ebits) - exp_bias) * (_n_ones(mbits + 1) / (2**mbits))

    # E bits = 1, M bits = 0
    min_normal = 2 ** (1 - exp_bias)

    denorm_exp = (
        # exp bias conversion between formats
        (F32_EXP_BIAS - exp_bias)
        # mantissa length difference between formats
        + (MBITS_F32 - mbits)
        # add one to encoded exponent for denormalized numbers
        + 1
    )
    denorm_mask_int = denorm_exp << MBITS_F32

    # reinterpret int32 as float32
    denorm_mask_float = torch.tensor(denorm_mask_int, dtype=torch.int32).view(
        torch.float32
    )

    # save the sign
    # Note that we have torch.uint32, but some ops like cpu bit shifts
    # do not work on it. So, we stay in int32.
    x = x.view(torch.int32)
    sign = x & 0x80000000

    # set everything to positive, will add sign back at the end
    x = x ^ sign

    # TODO: can the branch floating point comparisons below be done without
    # converting to float? probably but need to verify
    x = x.view(torch.float)

    # rewrite saturate/denorm/norm branches without explicit data dependent
    # control flow, to be more compiler friendly
    saturate_mask = x >= max_normal
    denormal_mask = torch.logical_and(torch.logical_not(saturate_mask), x < min_normal)
    normal_mask = torch.logical_not(torch.logical_or(saturate_mask, denormal_mask))

    #
    # branch 1: saturate to max val - handled later in the code which combines
    #   the branches
    #

    #
    # branch 2: conversion to denormal as well as rounding up to normal
    #
    denormal_x = x + denorm_mask_float
    denormal_x = denormal_x.view(torch.int32)
    denormal_x -= denorm_mask_int
    denormal_x = denormal_x.to(torch.uint8)

    #
    # branch 3: stay in normal range, adjust the exponent and round
    #
    normal_x = x.view(torch.int32)
    # resulting mantissa is odd
    mant_odd = (normal_x >> (MBITS_F32 - mbits)) & 1
    # update exponent, rounding bias part 1
    val_to_add = ((exp_bias - F32_EXP_BIAS) << MBITS_F32) + magic_adder
    normal_x += val_to_add
    # rounding bias part 2
    normal_x += mant_odd
    # take the bits!
    normal_x = normal_x >> (MBITS_F32 - mbits)
    normal_x = normal_x.to(torch.uint8)

    #
    # combine the branches
    #
    x = torch.full_like(x, max_int, dtype=torch.uint8)
    x = torch.where(denormal_mask, denormal_x, x)
    x = torch.where(normal_mask, normal_x, x)

    # add sign back
    sign_lp = sign >> (MBITS_F32 + EBITS_F32 - mbits - ebits)
    sign_lp = sign_lp.to(torch.uint8)
    # Right shift of a negative signed integer can fill the vacated high-order
    # bits with either 1s or 0s, depending on the implementation. Since PyTorch
    # doesn't have an uint32 dtype, we mask out these bits to get just the
    # f4 sign bit
    sign_lp = sign_lp & sign_mask
    x = x | sign_lp

    return x.to(torch.uint8)
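

# Illustrative sketch (hypothetical, not exercised by the test suite): values
# that are exactly representable in fp4_e2m1 (ebits=2, mbits=1) should survive
# a round trip through ``_f32_to_floatx_unpacked`` and
# ``_floatx_unpacked_to_f32`` (defined below) unchanged.
def _example_fp4_roundtrip():
    exact = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -1.0, -6.0])
    packed = _f32_to_floatx_unpacked(exact, ebits=2, mbits=1)
    unpacked = _floatx_unpacked_to_f32(packed, ebits=2, mbits=1)
    assert torch.equal(unpacked, exact)
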

# copy-pasted from
# https://github.com/pytorch/ao/blob/29488018d99af7f7339f06353c6b5bbeae8a1493/torchao/prototype/custom_fp_utils.py#L147
def _floatx_unpacked_to_f32(x: Tensor, ebits: int, mbits: int) -> Tensor:
    """Convert sub-byte floating point numbers with the given number of exponent
    and mantissa bits to FP32.

    Input: torch.Tensor of dtype uint8, where the bit encoding is stored
    in the least significant bits. e.g.
      fp4: bits 0-3 empty and bits 4-7 in fp4_e2m1 encoding
      fp6: bits 0-1 empty and bits 2-7 in fp6_e2m3 or fp6_e3m2 encoding
    Output: torch.Tensor of dtype fp32 with the dequantized value
    """
    assert x.dtype == torch.uint8
    assert 1 + ebits + mbits <= 8

    sign_mask = 1 << (ebits + mbits)
    exp_bias = _n_ones(ebits - 1)
    mantissa_mask = _n_ones(mbits)

    # save the sign
    sign_lp = x & sign_mask

    # set everything to positive, will add sign back at the end
    x_pos = x ^ sign_lp

    #
    # 1. Calculate zero mask
    #
    zero_mask = x_pos == 0

    #
    # 2. Calculate the denormal path mask
    #
    denormal_mask = torch.logical_and((x_pos > 0), ((x_pos >> mbits) == 0))

    #
    # 3. Calculate the normal path
    #

    # calculate the new exponent and shift it to bits 2:9 of the result
    exp_biased_lp = x_pos >> mbits
    exp_biased_f32 = exp_biased_lp - exp_bias + F32_EXP_BIAS
    exp_biased_f32 = exp_biased_f32.to(torch.int32) << MBITS_F32

    # shift the mantissa to bits 10:32 of the result
    mantissa_lp_int32 = (x_pos & mantissa_mask).to(torch.int32)
    mantissa_f32 = mantissa_lp_int32 << (MBITS_F32 - mbits)
    result = exp_biased_f32 | mantissa_f32

    #
    # 4. Add the zero and denormal casts to the already casted normal path
    #
    result[zero_mask] = 0

    denormal_exp_biased = 1 - exp_bias + F32_EXP_BIAS

    # fast path.
    # without this, performance for FP4_E2M1 is slower by 2x
    if mbits == 1:
        result[denormal_mask] = (denormal_exp_biased - mbits) << MBITS_F32

    else:
        # iterate over all possible values of mantissa
        # i=0, j=1
        # i=1, j=10,11
        # i=2, j=100,101,110,111
        # and so on
        for i in range(mbits):
            for mantissa_cmp in range(1 << i, 1 << (i + 1)):
                # left shift mantissa until it overflows (create an implicit 1)
                # subtract exponent by the same amount
                left_shift = mbits - i
                mantissa_f32 = (mantissa_cmp - (1 << i)) << (
                    left_shift + MBITS_F32 - mbits
                )
                exp_biased_f32 = (denormal_exp_biased - left_shift) << MBITS_F32

                # we can update this in-place since the values won't overlap
                # torch.compile() may complain unsupported operand type(s) for |: 'SymInt' and 'int'
                # thus we use + instead of | here
                mantissa_lp_int32[mantissa_lp_int32 == mantissa_cmp] = (
                    exp_biased_f32 + mantissa_f32
                )

        result = torch.where(denormal_mask, mantissa_lp_int32, result)

    # add sign back
    sign_f32 = sign_lp.to(torch.int32) << (MBITS_F32 - mbits + EBITS_F32 - ebits)
    result = result | sign_f32

    return result.view(torch.float)


# copied from https://github.com/drisspg/transformer_nuggets/blob/main/transformer_nuggets/mx/to_blocked.py
def ceil_div(a, b):
    return (a + b - 1) // b


def to_blocked(input_matrix) -> torch.Tensor:
    """
    Rearrange a large matrix by breaking it into blocks and applying the rearrangement pattern.

    See:
        https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout

    Args:
        input_matrix: Input tensor of shape (H, W)

    Returns:
        Rearranged tensor, flattened to one dimension, with
        32 * ceil_div(H, 128) * 16 * ceil_div(W, 4) elements
    """
    rows, cols = input_matrix.shape
    n_row_blocks = ceil_div(rows, 128)
    n_col_blocks = ceil_div(cols, 4)

    # Calculate the padded shape
    padded_rows = n_row_blocks * 128
    padded_cols = n_col_blocks * 4

    padded = input_matrix
    # Ideally we would use torch.nn.functional.pad but it doesn't support float8_e8m0fnu for now
    if (rows, cols) != (padded_rows, padded_cols):
        padded = torch.zeros((padded_rows, padded_cols), device=input_matrix.device, dtype=input_matrix.dtype)
        padded[:rows, :cols] = input_matrix

    # Rearrange the blocks
    blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)
    rearranged = blocks.reshape(-1, 4, 32, 4).transpose(1, 2).reshape(-1, 32, 16)

    return rearranged.flatten()
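

# Illustrative sketch (hypothetical, not exercised by the test suite): the
# blocked layout produced by ``to_blocked`` is flattened, so an (H, W) scale
# matrix yields 32 * ceil_div(H, 128) * 16 * ceil_div(W, 4) elements.
def _example_to_blocked_shape():
    scales = torch.arange(130 * 6, dtype=torch.float32).reshape(130, 6)
    blocked = to_blocked(scales)
    assert blocked.shape == (32 * ceil_div(130, 128) * 16 * ceil_div(6, 4),)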