bf16 support for fused_moving_avg_obs_fake_quant() op (#162620)

Enables bf16 support for the `torch.fused_moving_avg_obs_fake_quant()` op on CUDA.
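
A minimal sketch of what this enables, modeled on the unit test added in this commit (the default-dtype handling mirrors that test; treat it as illustrative rather than the recommended API):

```python
import torch
from torch.ao.quantization import FusedMovingAvgObsFakeQuantize

# Sketch: run the fused observer + fake-quant module on bf16 activations
# on a CUDA device, which is what this change adds.
if torch.cuda.is_available():
    torch.set_default_dtype(torch.bfloat16)
    try:
        with torch.device("cuda"):
            fake_quantize = FusedMovingAvgObsFakeQuantize()
            out = fake_quantize(torch.rand((256, 512)))  # bf16 input
    finally:
        torch.set_default_dtype(torch.float32)
```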

**Testing**
`python test/quantization/pt2e/test_quantize_pt2e.py`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/162620
Approved by: https://github.com/andrewor14, https://github.com/jerryzh168
Angel Li, 2025-09-16 21:22:44 +00:00 (committed by PyTorch MergeBot)
parent c230ac7300
commit 9494b09549
3 changed files with 138 additions and 81 deletions

File 1 of 3 (CUDA fused observer/fake-quant kernels):

@ -1,5 +1,6 @@
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/ceil_div.h>
#include <ATen/native/cuda/Loops.cuh>
#include <c10/cuda/CUDAGuard.h>
@ -21,10 +22,11 @@
namespace at::native {
namespace {
template <typename T>
__global__ void ChooseQuantizationParamsKernelImpl(
const int64_t* fake_quant_on,
const float* x_min,
const float* x_max,
const T* x_min,
const T* x_max,
int32_t qmin,
int32_t qmax,
int size,
@ -93,34 +95,44 @@ __global__ void ChooseQuantizationParamsKernelImpl(
}
}
__device__ inline bool isinf_device(float v) {
return ::isinf(v);
}
__device__ inline bool isinf_device(c10::BFloat16 v) {
return ::isinf(static_cast<float>(v));
}
// CUDA kernel to compute Moving Average Min/Max of the tensor.
// It uses the running_min and running_max along with averaging const, c.
// The formula used to compute the new min/max is as follows
//
// running_min = (1 - c) * running_min + c * x_min, if running_min != inf
// running_min = x_min, if running_min == inf
template <typename T>
__global__ void MovingAverageMinMax(
const int64_t* observer_on,
const float* x_min,
const float* x_max,
float* running_min,
float* running_max,
const T* x_min,
const T* x_max,
T* running_min,
T* running_max,
const float averaging_const,
const int size) {
int i = blockIdx.x * blockDim.x + threadIdx.x;
if (*observer_on == 1) {
if (i < size) {
float curr_min = x_min[i];
float curr_max = x_max[i];
T curr_min = x_min[i];
T curr_max = x_max[i];
float adjusted_min = ::isinf(running_min[i])
? curr_min
: (running_min[i]) + averaging_const * (curr_min - (running_min[i]));
T averaging_const_t = static_cast<T>(averaging_const);
float adjusted_max = ::isinf(running_max[i])
? curr_max
: (running_max[i]) + averaging_const * (curr_max - (running_max[i]));
T adjusted_min = isinf_device(running_min[i]) ? curr_min
: (running_min[i]) +
averaging_const_t * (curr_min - (running_min[i]));
T adjusted_max = isinf_device(running_max[i]) ? curr_max
: (running_max[i]) +
averaging_const_t * (curr_max - (running_max[i]));
running_min[i] = adjusted_min;
running_max[i] = adjusted_max;
@ -142,40 +154,51 @@ void _calculate_moving_average(
at::Tensor x_min, x_max;
int64_t* observer_on_data = observer_on.data_ptr<int64_t>();
float* running_min_data = running_min.data_ptr<float>();
float* running_max_data = running_max.data_ptr<float>();
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
if (per_row_fq) {
std::tie(x_min, x_max) = at::aminmax(x, 1);
float* x_min_data = x_min.data_ptr<float>();
float* x_max_data = x_max.data_ptr<float>();
int num_threads = std::min(size, (int64_t)512);
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
// Moving Average Min/Max observer for activations
MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>(
observer_on_data,
x_min_data,
x_max_data,
running_min_data,
running_max_data,
averaging_const,
size);
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
// Moving Average Min/Max observer for activations
MovingAverageMinMax<<<num_blocks, num_threads, 0, cuda_stream>>>(
observer_on_data,
x_min_data,
x_max_data,
running_min_data,
running_max_data,
averaging_const,
size);
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
std::tie(x_min, x_max) = at::aminmax(x);
float* x_min_data = x_min.data_ptr<float>();
float* x_max_data = x_max.data_ptr<float>();
// Moving Average Min/Max observer for activations
MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>(
observer_on_data,
x_min_data,
x_max_data,
running_min_data,
running_max_data,
averaging_const,
1 /*size*/);
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* x_min_data = x_min.data_ptr<scalar_t>();
scalar_t* x_max_data = x_max.data_ptr<scalar_t>();
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
// Moving Average Min/Max observer for activations
MovingAverageMinMax<<<1, 1, 0, cuda_stream>>>(
observer_on_data,
x_min_data,
x_max_data,
running_min_data,
running_max_data,
averaging_const,
1 /*size*/);
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
@ -198,34 +221,44 @@ void _calc_moving_avg_qparams_helper(
cudaStream_t cuda_stream = at::cuda::getCurrentCUDAStream();
int64_t* fake_quant_on_data = fake_quant_on.data_ptr<int64_t>();
if (per_row_fq) {
float* running_min_data = running_min.data_ptr<float>();
float* running_max_data = running_max.data_ptr<float>();
int num_threads = std::min(size, (int64_t)512);
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
ChooseQuantizationParamsKernelImpl<<<num_blocks, num_threads, 0, cuda_stream>>>(
fake_quant_on_data,
running_min_data,
running_max_data,
qmin,
qmax,
size,
symmetric_quant,
scale_ptr,
zp_ptr);
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
int num_threads = std::min(size, (int64_t)512);
const uint64_t num_blocks = ceil_div<uint64_t>(size, num_threads);
ChooseQuantizationParamsKernelImpl<<<
num_blocks,
num_threads,
0,
cuda_stream>>>(
fake_quant_on_data,
running_min_data,
running_max_data,
qmin,
qmax,
size,
symmetric_quant,
scale_ptr,
zp_ptr);
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
} else {
float* running_min_data = running_min.data_ptr<float>();
float* running_max_data = running_max.data_ptr<float>();
ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
fake_quant_on_data,
running_min_data,
running_max_data,
qmin,
qmax,
1, // size
symmetric_quant, // preserve_sparsity
scale_ptr,
zp_ptr);
AT_DISPATCH_FLOATING_TYPES_AND(
at::kBFloat16, x.scalar_type(), "aminmax_kernel", [&] {
scalar_t* running_min_data = running_min.data_ptr<scalar_t>();
scalar_t* running_max_data = running_max.data_ptr<scalar_t>();
ChooseQuantizationParamsKernelImpl<<<1, 1, 0, cuda_stream>>>(
fake_quant_on_data,
running_min_data,
running_max_data,
qmin,
qmax,
1, // size
symmetric_quant, // preserve_sparsity
scale_ptr,
zp_ptr);
});
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
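
For reference, the templated `MovingAverageMinMax` kernel keeps the same update rule described in the kernel comment above, just evaluated in the tensor's dtype; a plain-Python sketch of that rule (the helper name is illustrative):

```python
import math

def moving_average_min_max(running_min, running_max, x_min, x_max, c=0.01):
    # running_min = (1 - c) * running_min + c * x_min, unless running_min is still
    # at its +inf initial value, in which case it is seeded with x_min
    # (and symmetrically for running_max, initialized to -inf).
    new_min = x_min if math.isinf(running_min) else running_min + c * (x_min - running_min)
    new_max = x_max if math.isinf(running_max) else running_max + c * (x_max - running_max)
    return new_min, new_max

# First call seeds the stats, later calls move them slowly toward new observations:
print(moving_average_min_max(float("inf"), float("-inf"), -1.2, 3.4))  # (-1.2, 3.4)
print(moving_average_min_max(-1.2, 3.4, -2.0, 2.0))                    # ~(-1.208, 3.386)
```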

File 2 of 3 (fake-quantize module unit tests):

@ -908,6 +908,19 @@ class TestFakeQuantize(TestCase):
        self.assertEqual(fq_module.activation_post_process.quant_min, 0)
        self.assertEqual(fq_module.activation_post_process.quant_max, 127)

    @given(device=st.sampled_from(['cpu', 'cuda'] if torch.cuda.is_available() else ['cpu']))
    def test_fused_moving_avg_obs_fake_quant(self, device):
        try:
            sampled_dtype = st.sampled_from(["bf16", "fp32"]) if device == "cuda" else "fp32"
            dtype = torch.bfloat16 if sampled_dtype == "bf16" else torch.float32
            torch.set_default_dtype(dtype)
            with torch.device(device):
                fake_quantize = FusedMovingAvgObsFakeQuantize()
                fake_quantize.forward(torch.rand((256, 512)))
        finally:
            torch.set_default_dtype(torch.float32)

def _get_buffer_ids(module):
    """
    Object addresses stay constant if and only if all modifications are in-place

File 3 of 3 (fused observer/fake-quant op tests and reference helpers):

@ -3,6 +3,7 @@
import torch
import math
from typing import Union
from torch.ao.quantization import (
FakeQuantize,
MovingAverageMinMaxObserver,
@ -155,18 +156,25 @@ def _fake_quantize_learnable_per_channel_affine_grad_reference(
def _get_tensor_min_max(
X: torch.Tensor,
running_min: float = float("inf"),
running_max: float = float("-inf"),
averaging_const: float = 0.01) -> tuple[float, float]:
min_val = X.min().to(dtype=torch.float32).item()
max_val = X.max().to(dtype=torch.float32).item()
running_min: Union[float, torch.Tensor] = float("inf"),
running_max: Union[float, torch.Tensor] = float("-inf"),
averaging_const: float = 0.01,
dtype: torch.dtype = torch.float32) -> tuple[float, float]:
min_val_tensor = X.min().to(dtype=dtype)
max_val_tensor = X.max().to(dtype=dtype)
averaging_const_tensor = torch.tensor(averaging_const, dtype=dtype).item()
if not math.isinf(running_min):
min_val = running_min + averaging_const * (min_val - running_min)
if not math.isinf(running_max):
max_val = running_max + averaging_const * (max_val - running_max)
if not isinstance(running_min, torch.Tensor):
running_min = torch.tensor(running_min, dtype=dtype)
if not isinstance(running_max, torch.Tensor):
running_max = torch.tensor(running_max, dtype=dtype)
return min_val, max_val
if not torch.isinf(running_min):
min_val_tensor = running_min + averaging_const_tensor * (min_val_tensor - running_min)
if not torch.isinf(running_max):
max_val_tensor = running_max + averaging_const_tensor * (max_val_tensor - running_max)
return min_val_tensor.item(), max_val_tensor.item()
def _get_per_row_min_max(
x: torch.Tensor,
@ -1064,10 +1072,13 @@ class TestFusedObsFakeQuant(TestCase):
Tests the case where we call the fused_obs_fake_quant op multiple times
and update the running_min and max of the activation tensors.
"""
in_running_min_ref = out_running_min_ref = float("inf")
in_running_min_op = torch.tensor(float("inf"), device=device)
in_running_max_ref = out_running_max_ref = float("-inf")
in_running_max_op = torch.tensor(float("-inf"), device=device)
sampled_dtype = st.sampled_from(["bf16", "fp32"]) if device == "cuda" else "fp32"
dtype = torch.bfloat16 if sampled_dtype == "bf16" else torch.float32
in_running_min_ref = out_running_min_ref = torch.tensor(float("inf"), dtype=dtype)
in_running_min_op = torch.tensor(float("inf"), dtype=dtype, device=device)
in_running_max_ref = out_running_max_ref = torch.tensor(float("-inf"), dtype=dtype)
in_running_max_op = torch.tensor(float("-inf"), dtype=dtype, device=device)
avg_const = 0.01
scale = torch.tensor([1.0], device=device)
zero_point = torch.tensor([0], dtype=torch.int, device=device)
@ -1080,8 +1091,7 @@ class TestFusedObsFakeQuant(TestCase):
observer_on = True if use_bool else 1
if i > 4:
fake_quant_on = True if use_bool else 1
x = torch.randn(5, 5, device=device)
x = torch.randn(5, 5, dtype=dtype, device=device)
out = pt_op(
x,
torch.tensor(observer_on, device=device),
@ -1106,6 +1116,7 @@ class TestFusedObsFakeQuant(TestCase):
running_min=in_running_min_ref,
running_max=in_running_max_ref,
averaging_const=0.01,
dtype=dtype,
)
if fake_quant_on:
@ -1128,7 +1139,7 @@ class TestFusedObsFakeQuant(TestCase):
torch.testing.assert_close(out, x_in)
# Test empty input works
x = torch.empty(0, 5, device=device)
x = torch.empty(0, 5, dtype=dtype, device=device)
out = pt_op(
x,
torch.tensor(1, device=device),
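
For completeness, a hedged sketch of driving the low-level op directly with bf16 running-min/max buffers, in the spirit of the `TestFusedObsFakeQuant` changes above; the positional argument order and the 0..255 quantization range are assumptions based on the test rather than taken from this diff:

```python
import torch

# Requires a CUDA device; the running-stat buffers are bf16, which is what this change adds.
x = torch.randn(5, 5, dtype=torch.bfloat16, device="cuda")
running_min = torch.tensor(float("inf"), dtype=torch.bfloat16, device="cuda")
running_max = torch.tensor(float("-inf"), dtype=torch.bfloat16, device="cuda")
scale = torch.tensor([1.0], device="cuda")
zero_point = torch.tensor([0], dtype=torch.int, device="cuda")

out = torch.fused_moving_avg_obs_fake_quant(
    x,
    torch.tensor(1, device="cuda"),  # observer_on: update running_min/max in place
    torch.tensor(0, device="cuda"),  # fake_quant_on: pass the input through unchanged
    running_min,
    running_max,
    scale,
    zero_point,
    0.01,   # averaging_const
    0,      # quant_min (assumed)
    255,    # quant_max (assumed)
    0,      # ch_axis
    False,  # per_row_fake_quant
    False,  # symmetric_quant
)
```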