bf16 support for fused_moving_avg_obs_fake_quant() op (#162620)
Enables bf16 support for the `torch.fused_moving_avg_obs_fake_quant()` op on CUDA.

**Testing:** `python test/quantization/pt2e/test_quantize_pt2e.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162620
Approved by: https://github.com/andrewor14, https://github.com/jerryzh168
This commit is contained in:
parent c230ac7300
commit 9494b09549
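For context, the new CUDA path can be exercised through the module-level API. The sketch below is not part of the commit; it simply mirrors the new `test_fused_moving_avg_obs_fake_quant` test added in this PR and assumes a CUDA device is available (it falls back to fp32 on CPU):

```python
import torch
from torch.ao.quantization import FusedMovingAvgObsFakeQuantize

device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16 if device == "cuda" else torch.float32

try:
    # Observer buffers and the activation inherit the default dtype, so both
    # end up in bf16 on CUDA, which is the path this commit enables.
    torch.set_default_dtype(dtype)
    with torch.device(device):
        fake_quantize = FusedMovingAvgObsFakeQuantize()
        out = fake_quantize(torch.rand((256, 512)))
finally:
    torch.set_default_dtype(torch.float32)
```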
@@ -1,5 +1,6 @@
 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
 #include <ATen/core/Tensor.h>
+#include <ATen/Dispatch.h>
 #include <ATen/ceil_div.h>
 #include <ATen/native/cuda/Loops.cuh>
 #include <c10/cuda/CUDAGuard.h>
@@ -21,10 +22,11 @@
 namespace at::native {

 namespace {
+template <typename T>
 __global__ void ChooseQuantizationParamsKernelImpl(
     const int64_t* fake_quant_on,
-    const float* x_min,
-    const float* x_max,
+    const T* x_min,
+    const T* x_max,
     int32_t qmin,
     int32_t qmax,
     int size,
@@ -93,34 +95,44 @@ __global__ void ChooseQuantizationParamsKernelImpl(
   }
 }

+__device__ inline bool isinf_device(float v) {
+  return ::isinf(v);
+}
+__device__ inline bool isinf_device(c10::BFloat16 v) {
+  return ::isinf(static_cast<float>(v));
+}
+
 // CUDA kernel to compute Moving Average Min/Max of the tensor.
 // It uses the running_min and running_max along with averaging const, c.
 // The formula used to compute the new min/max is as follows
 //
 // running_min = (1 - c) * running_min + c * x_min, if running_min != inf
 // running_min = x_min, if running_min == inf
+template <typename T>
 __global__ void MovingAverageMinMax(
     const int64_t* observer_on,
-    const float* x_min,
-    const float* x_max,
-    float* running_min,
-    float* running_max,
+    const T* x_min,
+    const T* x_max,
+    T* running_min,
+    T* running_max,
     const float averaging_const,
     const int size) {
   int i = blockIdx.x * blockDim.x + threadIdx.x;

   if (*observer_on == 1) {
     if (i < size) {
-      float curr_min = x_min[i];
-      float curr_max = x_max[i];
+      T curr_min = x_min[i];
+      T curr_max = x_max[i];

-      float adjusted_min = ::isinf(running_min[i])
-          ? curr_min
-          : (running_min[i]) + averaging_const * (curr_min - (running_min[i]));
+      T averaging_const_t = static_cast<T>(averaging_const);

-      float adjusted_max = ::isinf(running_max[i])
-          ? curr_max
-          : (running_max[i]) + averaging_const * (curr_max - (running_max[i]));
+      T adjusted_min = isinf_device(running_min[i]) ? curr_min
+                                                    : (running_min[i]) +
+              averaging_const_t * (curr_min - (running_min[i]));
+
+      T adjusted_max = isinf_device(running_max[i]) ? curr_max
+                                                    : (running_max[i]) +
+              averaging_const_t * (curr_max - (running_max[i]));

       running_min[i] = adjusted_min;
       running_max[i] = adjusted_max;
@@ -142,40 +154,51 @@ void _calculate_moving_average(
   at::Tensor x_min, x_max;

   int64_t* observer_on_data = observer_on.data_ptr<int64_t>();
-  float* running_min_data = running_min.data_ptr<float>();
-  float* running_max_data = running_max.data_ptr<float>();
   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();

   if (per_row_fq) {
     std::tie(x_min, x_max) = at::aminmax(x, 1);
-    float* x_min_data = x_min.data_ptr<float>();
-    float* x_max_data = x_max.data_ptr<float>();
     int num_threads = std::min(size, (int64_t)512);
     const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
-
-    // Moving Average Min/Max observer for activations
-    MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>(
-        observer_on_data,
-        x_min_data,
-        x_max_data,
-        running_min_data,
-        running_max_data,
-        averaging_const,
-        size);
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
+          scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
+          scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
+
+          scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
+          scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
+
+          // Moving Average Min/Max observer for activations
+          MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>(
+              observer_on_data,
+              x_min_data,
+              x_max_data,
+              running_min_data,
+              running_max_data,
+              averaging_const,
+              size);
+        });
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
     std::tie(x_min, x_max) = at::aminmax(x);
-    float* x_min_data = x_min.data_ptr<float>();
-    float* x_max_data = x_max.data_ptr<float>();
-    // Moving Average Min/Max observer for activations
-    MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>(
-        observer_on_data,
-        x_min_data,
-        x_max_data,
-        running_min_data,
-        running_max_data,
-        averaging_const,
-        1 /*size*/);
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
+          scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
+          scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
+
+          scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
+          scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
+
+          // Moving Average Min/Max observer for activations
+          MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>(
+              observer_on_data,
+              x_min_data,
+              x_max_data,
+              running_min_data,
+              running_max_data,
+              averaging_const,
+              1 /*size*/);
+        });
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 }
@@ -198,34 +221,44 @@ void _calc_moving_avg_qparams_helper(
   cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
   int64_t* fake_quant_on_data = fake_quant_on.data_ptr<int64_t>();
   if (per_row_fq) {
-    float* running_min_data = running_min.data_ptr<float>();
-    float* running_max_data = running_max.data_ptr<float>();
-    int num_threads = std::min(size, (int64_t)512);
-    const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
-    ChooseQuantizationParamsKernelImpl<<<num_blocks, num_threads, 0, cuda_stream>>>(
-        fake_quant_on_data,
-        running_min_data,
-        running_max_data,
-        qmin,
-        qmax,
-        size,
-        symmetric_quant,
-        scale_ptr,
-        zp_ptr);
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
+          scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
+          scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
+          int num_threads = std::min(size, (int64_t)512);
+          const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
+          ChooseQuantizationParamsKernelImpl<<<
+              num_blocks,
+              num_threads,
+              0,
+              cuda_stream>>>(
+              fake_quant_on_data,
+              running_min_data,
+              running_max_data,
+              qmin,
+              qmax,
+              size,
+              symmetric_quant,
+              scale_ptr,
+              zp_ptr);
+        });
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   } else {
-    float* running_min_data = running_min.data_ptr<float>();
-    float* running_max_data = running_max.data_ptr<float>();
-    ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
-        fake_quant_on_data,
-        running_min_data,
-        running_max_data,
-        qmin,
-        qmax,
-        1, // size
-        symmetric_quant, // preserve_sparsity
-        scale_ptr,
-        zp_ptr);
+    AT_DISPATCH_FLOATING_TYPES_AND(
+        at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
+          scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
+          scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
+          ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
+              fake_quant_on_data,
+              running_min_data,
+              running_max_data,
+              qmin,
+              qmax,
+              1, // size
+              symmetric_quant, // preserve_sparsity
+              scale_ptr,
+              zp_ptr);
+        });
     C10_CUDA_KERNEL_LAUNCH_CHECK();
   }
 }
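The two kernels above now run in the activation's dtype. As a plain-PyTorch reference for the exponential-moving-average update described in the `MovingAverageMinMax` comment (a sketch for illustration only, not part of this commit; the helper name and the use of 0-dim tensors for the running stats are assumptions):

```python
import torch

def moving_avg_min_max(x, running_min, running_max, c=0.01):
    """Reference sketch of the update the MovingAverageMinMax kernel performs.

    running_min = (1 - c) * running_min + c * x_min   if running_min != inf
    running_min = x_min                               otherwise
    (and symmetrically for running_max), carried out in the dtype of `x`
    (fp32 or, with this change, bf16). running_min/running_max are 0-dim tensors.
    """
    x_min, x_max = torch.aminmax(x)
    # Mirrors the kernel's static_cast<T>(averaging_const): round c to x's dtype.
    c = torch.tensor(c, dtype=x.dtype)
    new_min = x_min if torch.isinf(running_min) else running_min + c * (x_min - running_min)
    new_max = x_max if torch.isinf(running_max) else running_max + c * (x_max - running_max)
    return new_min, new_max

r_min = torch.tensor(float("inf"), dtype=torch.bfloat16)
r_max = torch.tensor(float("-inf"), dtype=torch.bfloat16)
x = torch.randn(5, 5, dtype=torch.bfloat16)
r_min, r_max = moving_avg_min_max(x, r_min, r_max)
```

Casting the averaging constant to bf16 before the update is also why the Python reference `_get_tensor_min_max` below rounds `averaging_const` through a tensor of the requested dtype.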
@@ -908,6 +908,19 @@ class TestFakeQuantize(TestCase):
         self.assertEqual(fq_module.activation_post_process.quant_min, 0)
         self.assertEqual(fq_module.activation_post_process.quant_max, 127)

+    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']))
+    def test_fused_moving_avg_obs_fake_quant(self, device):
+        try:
+            sampled_dtype = st.sampled_from(["bf16", "fp32"]) if device == "cuda" else "fp32"
+            dtype = torch.bfloat16 if sampled_dtype == "bf16" else torch.float32
+            torch.set_default_dtype(dtype)
+
+            with torch.device(device):
+                fake_quantize = FusedMovingAvgObsFakeQuantize()
+                fake_quantize.forward(torch.rand((256, 512)))
+        finally:
+            torch.set_default_dtype(torch.float32)
+
 def _get_buffer_ids(module):
     """
     Object addresses stay constant if and only if all modifications are in-place
@@ -3,6 +3,7 @@

 import torch
 import math
+from typing import Union
 from torch.ao.quantization import (
     FakeQuantize,
     MovingAverageMinMaxObserver,
@@ -155,18 +156,25 @@ def _fake_quantize_learnable_per_channel_affine_grad_reference(

 def _get_tensor_min_max(
         X: torch.Tensor,
-        running_min: float = float("inf"),
-        running_max: float = float("-inf"),
-        averaging_const: float = 0.01) -> tuple[float, float]:
-    min_val = X.min().to(dtype=torch.float32).item()
-    max_val = X.max().to(dtype=torch.float32).item()
+        running_min: Union[float, torch.Tensor] = float("inf"),
+        running_max: Union[float, torch.Tensor] = float("-inf"),
+        averaging_const: float = 0.01,
+        dtype: torch.dtype = torch.float32) -> tuple[float, float]:
+    min_val_tensor = X.min().to(dtype=dtype)
+    max_val_tensor = X.max().to(dtype=dtype)
+    averaging_const_tensor = torch.tensor(averaging_const, dtype=dtype).item()

-    if not math.isinf(running_min):
-        min_val = running_min + averaging_const * (min_val - running_min)
-    if not math.isinf(running_max):
-        max_val = running_max + averaging_const * (max_val - running_max)
+    if not isinstance(running_min, torch.Tensor):
+        running_min = torch.tensor(running_min, dtype=dtype)
+    if not isinstance(running_max, torch.Tensor):
+        running_max = torch.tensor(running_max, dtype=dtype)

-    return min_val, max_val
+    if not torch.isinf(running_min):
+        min_val_tensor = running_min + averaging_const_tensor * (min_val_tensor - running_min)
+    if not torch.isinf(running_max):
+        max_val_tensor = running_max + averaging_const_tensor * (max_val_tensor - running_max)
+
+    return min_val_tensor.item(), max_val_tensor.item()

 def _get_per_row_min_max(
         x: torch.Tensor,
@@ -1064,10 +1072,13 @@ class TestFusedObsFakeQuant(TestCase):
        Tests the case where we call the fused_obs_fake_quant op multiple times
        and update the running_min and max of the activation tensors.
        """
-        in_running_min_ref = out_running_min_ref = float("inf")
-        in_running_min_op = torch.tensor(float("inf"), device=device)
-        in_running_max_ref = out_running_max_ref = float("-inf")
-        in_running_max_op = torch.tensor(float("-inf"), device=device)
+        sampled_dtype = st.sampled_from(["bf16", "fp32"]) if device == "cuda" else "fp32"
+        dtype = torch.bfloat16 if sampled_dtype == "bf16" else torch.float32
+
+        in_running_min_ref = out_running_min_ref = torch.tensor(float("inf"), dtype=dtype)
+        in_running_min_op = torch.tensor(float("inf"), dtype=dtype, device=device)
+        in_running_max_ref = out_running_max_ref = torch.tensor(float("-inf"), dtype=dtype)
+        in_running_max_op = torch.tensor(float("-inf"), dtype=dtype, device=device)
         avg_const = 0.01
         scale = torch.tensor([1.0], device=device)
         zero_point = torch.tensor([0], dtype=torch.int, device=device)
@@ -1080,8 +1091,7 @@
            observer_on = True if use_bool else 1
            if i > 4:
                fake_quant_on = True if use_bool else 1
-
-            x = torch.randn(5, 5, device=device)
+            x = torch.randn(5, 5, dtype=dtype, device=device)
            out = pt_op(
                x,
                torch.tensor(observer_on, device=device),
@@ -1106,6 +1116,7 @@ class TestFusedObsFakeQuant(TestCase):
                running_min=in_running_min_ref,
                running_max=in_running_max_ref,
                averaging_const=0.01,
+                dtype=dtype,
            )

            if fake_quant_on:
@@ -1128,7 +1139,7 @@ class TestFusedObsFakeQuant(TestCase):
            torch.testing.assert_close(out, x_in)

        # Test empty input works
-        x = torch.empty(0, 5, device=device)
+        x = torch.empty(0, 5, dtype=dtype, device=device)
        out = pt_op(
            x,
            torch.tensor(1, device=device),
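The tests drive the ATen op directly (`pt_op` above is `torch.fused_moving_avg_obs_fake_quant`). Below is a standalone sketch of such a call with bf16 running stats; it is not taken from the PR, and the flag and range values (observer and fake-quant enabled, quant range 0-255, `ch_axis=0`, per-tensor) are assumptions chosen to mirror the per-tensor path exercised by the test:

```python
import torch

device = "cuda"  # the bf16 path added by this commit is CUDA-only
dtype = torch.bfloat16

x = torch.randn(5, 5, dtype=dtype, device=device)
running_min = torch.tensor(float("inf"), dtype=dtype, device=device)
running_max = torch.tensor(float("-inf"), dtype=dtype, device=device)
scale = torch.tensor([1.0], device=device)
zero_point = torch.tensor([0], dtype=torch.int, device=device)

out = torch.fused_moving_avg_obs_fake_quant(
    x,
    torch.tensor(1, device=device),  # observer_on
    torch.tensor(1, device=device),  # fake_quant_on
    running_min,
    running_max,
    scale,
    zero_point,
    0.01,                            # averaging_const
    0, 255,                          # quant_min, quant_max
    0,                               # ch_axis
    False,                           # per_row_fake_quant
)
# running_min / running_max now hold the bf16 moving averages, scale and
# zero_point hold the recomputed qparams, and out is the fake-quantized x.
```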