[Quant][CPU] Enable fp8 qconv (#157076)

**Summary**
Enable fp8 qconv on CPU. It is part of the plan to enable fp8 static quantization on CPU. This PR only adds FP8 support to the existing int8 qconv op; it does not add a new op, and it does not affect the frontend or the quantization flow. The schema of the qconv op is unchanged.

The FP8 qconv therefore shares the same op as the INT8 qconv; the only difference is that the src/wei dtype is fp8 instead of int8. The output dtype can be fp8, float32, or bfloat16. The implementation is based on the oneDNN library.
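
For illustration, here is a minimal usage sketch of the shared op with fp8 inputs. It mirrors the new tests in `test/quantization/core/test_quantized_op.py`; the shapes, scales, and post-op arguments are illustrative only.

```python
import torch

# Quantize an fp32 activation and weight to fp8 with per-tensor scales (symmetric).
x = torch.rand(1, 3, 8, 8)
w = torch.randn(6, 3, 3, 3) * 0.1
fp8_max = torch.finfo(torch.float8_e4m3fn).max
x_scale = (x.abs().max() / fp8_max).item()
w_scale = (w.abs().max() / fp8_max).reshape([1])
x_fp8 = (x / x_scale).to(torch.float8_e4m3fn)
w_fp8 = (w / w_scale).to(torch.float8_e4m3fn)

# Same prepack/compute ops as the INT8 path; only the dtypes differ.
packed_w = torch.ops.onednn.qconv_prepack(
    w_fp8, w_scale, x_scale, 0,         # weight, weight scales, act scale, act zero point
    [1, 1], [1, 1], [1, 1], 1,          # stride, padding, dilation, groups
    x_fp8.size(),
)
y_bf16 = torch.ops.onednn.qconv_pointwise(
    x_fp8, x_scale, 0,                  # activation, act scale, act zero point
    packed_w, w_scale,
    torch.zeros([], dtype=torch.int8),  # weight zero point (unused, fp8 is symmetric)
    None,                               # bias
    [1, 1], [1, 1], [1, 1], 1,          # stride, padding, dilation, groups
    1.0, 0,                             # output scale, output zero point
    torch.bfloat16,                     # output dtype: None (fp8) / float32 / bfloat16
    "none", [], "",                     # unary post-op: attr, scalars, algorithm
)
```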

Note:
oneDNN does not support quantized fp8 convolution until v3.9, while the version currently used in PyTorch is v3.7.2, so the op falls back to a reference kernel for now (see the sketch after the list below). We have also updated the oneDNN path so that it is compatible with the fp8 dtype; once oneDNN is upgraded to v3.9 or newer, only minimal changes are needed to enable that path. We have verified that the reference kernel behaves the same as the new oneDNN implementation.
- oneDNN version < 3.9 (now)
  - Always go to the reference kernel
- oneDNN version >= 3.9 (future)
  - Go to the reference kernel on older platforms (without AMX)
  - Use the oneDNN implementation on newer platforms (with AMX)
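
For the basic case (no binary post-op), the reference path is numerically equivalent to the Python sketch below. The actual implementation is the new C++ helper `_fp8_convolution_onednn_ref`; this sketch is for illustration only and omits the binary post-ops, unary fusions, and memory-format handling.

```python
import torch
import torch.nn.functional as F

def fp8_qconv2d_ref(x_fp8, x_scale, w_fp8, w_scales, bias,
                    stride, padding, dilation, groups,
                    output_scale, output_dtype=None):
    # Dequantize the fp8 activation and weight to fp32.
    dqx = x_fp8.float() * x_scale
    dqw = w_fp8.float() * w_scales.reshape((-1,) + (1,) * (w_fp8.dim() - 1))
    # Run a regular fp32 convolution (fused post-ops would be applied here).
    y = F.conv2d(dqx, dqw, bias, stride, padding, dilation, groups)
    y = y / output_scale
    if output_dtype is None:
        # fp8 output: align with oneDNN by rounding fp32 -> fp16 -> fp8.
        return y.half().to(torch.float8_e4m3fn)
    return y.to(output_dtype)  # torch.float32 or torch.bfloat16
```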

**Test plan**
```
pytest test/quantization/core/test_quantized_op.py -k "qconv and fp8"
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157076
Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
Xia, Weiwen 2025-07-11 10:00:53 +00:00 committed by PyTorch MergeBot
parent ed508cc018
commit e1a20988f3
5 changed files with 550 additions and 54 deletions

@@ -34,6 +34,15 @@
#include <ATen/ops/quantize_per_channel_native.h>
#include <ATen/ops/quantize_per_tensor_native.h>
#include <ATen/ops/zeros.h>
#include <ATen/ops/convolution.h>
#include <ATen/ops/linear.h>
#include <ATen/ops/relu.h>
#include <ATen/ops/leaky_relu.h>
#include <ATen/ops/tanh.h>
#include <ATen/ops/gelu.h>
#include <ATen/ops/hardtanh.h>
#include <ATen/ops/hardswish.h>
#include <ATen/ops/sigmoid.h>
#endif
#include <c10/util/irange.h>
@@ -1384,6 +1393,116 @@ template at::Tensor PackedConvWeightsOnednn<3>::apply_relu(
double output_scale,
int64_t output_zero_point);
static at::Tensor _fp8_convolution_onednn_ref(
at::Tensor act, // contains quantized values but not QTensor
double act_scale,
at::Tensor weight, // MKLDNN tensor with quantized values
at::Tensor weight_scales,
std::optional<at::Tensor> bias, // Bias is not packed into MKLDNN tensor
torch::List<int64_t> stride,
torch::List<int64_t> padding,
torch::List<int64_t> dilation,
int64_t groups,
double output_scale,
std::optional<at::Tensor> accum, // accum to be fused with conv add
double accum_scale,
std::optional<c10::ScalarType> output_dtype,
std::optional<std::string_view> binary_attr,
std::optional<at::Scalar> binary_alpha,
std::optional<std::string_view> unary_attr,
torch::List<std::optional<at::Scalar>> unary_scalars,
std::optional<std::string_view> unary_algorithm) {
TORCH_CHECK(
act.scalar_type() == at::ScalarType::Float8_e4m3fn && weight.scalar_type() == at::ScalarType::Float8_e4m3fn,
"FP8 qconv: Unexpected dtype of input and weight:", act.scalar_type(), ", ", weight.scalar_type());
int kSpatialDim = act.dim() - 2;
// conv1d is converted to conv2d before calling this function
TORCH_CHECK(kSpatialDim != 1, "Expect 2D or 3D convolution, but got 1D convolution.");
auto act_contig = act.contiguous(kSpatialDim == 2 ?
c10::MemoryFormat::ChannelsLast :
c10::MemoryFormat::ChannelsLast3d);
auto dqx = act_contig.to(at::kFloat) * act_scale;
std::vector<int64_t> w_scales_new_shape(weight.dim(), 1);
w_scales_new_shape[0] = -1;
auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
auto output_padding = std::vector<int64_t>(kSpatialDim, 0);
auto y_f32 = at::convolution(
dqx, dqw, bias, stride.vec(), padding.vec(), dilation.vec(), /* transposed */false, output_padding, groups
);
if (!binary_attr.has_value() || binary_attr == "none") {
if (unary_attr == "relu") {
at::relu_(y_f32);
} else if (unary_attr == "leaky_relu") {
TORCH_CHECK(
unary_scalars.size() == 1,
"onednn qconv: expect one argument for post op leaky_relu but got ", unary_scalars.size(), " args");
auto element = unary_scalars.get(0);
auto alpha = element.value().to<float>();
at::leaky_relu_(y_f32, alpha);
} else if (unary_attr == "tanh") {
at::tanh_(y_f32);
} else if (unary_attr == "gelu") {
TORCH_CHECK(
unary_algorithm == "none" || unary_algorithm == "tanh",
"onednn qconv: algorithm for post op gelu must be none or tanh but got ", unary_algorithm);
at::gelu_(y_f32, unary_algorithm.value());
} else if (unary_attr == "hardtanh") {
TORCH_CHECK(
unary_scalars.size() == 2 &&
unary_scalars.get(0).has_value() &&
unary_scalars.get(1).has_value(),
"hardtanh is expected to have two scalar input: min_val and max_val");
auto lower_bound_value =
unary_scalars.get(0).value().to<float>();
auto upper_bound_value =
unary_scalars.get(1).value().to<float>();
at::hardtanh_(y_f32, lower_bound_value, upper_bound_value);
} else if (unary_attr == "hardswish") {
at::hardswish_(y_f32);
} else if (unary_attr == "swish") {
y_f32 = y_f32 * at::sigmoid(y_f32);
} else {
TORCH_CHECK(
!unary_attr.has_value() || unary_attr == "none",
"onednn qconv: unsupported unary post op ", unary_attr);
}
} else if (binary_attr == "sum") {
TORCH_CHECK(accum.has_value(), "onednn qconv: the extra input is missing for post op sum");
auto x1 = accum.value();
TORCH_CHECK(x1.sizes() == y_f32.sizes());
auto x1_f32 = x1.to(at::kFloat) * accum_scale;
x1_f32 = x1_f32.view(y_f32.sizes());
if (!unary_attr.has_value() || unary_attr == "none") {
y_f32.add_(x1_f32);
} else if (unary_attr == "relu") {
y_f32.add_(x1_f32).relu_();
} else {
TORCH_CHECK(
false,
"onednn qconv: unsupported unary post op ", unary_attr, " with binary post op sum");
}
y_f32.div_(output_scale);
if (x1.scalar_type() == at::kFloat8_e4m3fn) {
// Align with oneDNN: convert fp32 to fp8 by fp32 -> fp16 -> fp8
y_f32 = y_f32.to(at::kHalf);
}
x1.copy_(y_f32.to(x1.scalar_type()).view(x1.sizes()));
return x1;
} else {
TORCH_CHECK(
false,
"onednn qconv: unsupported binary post op ", binary_attr);
}
y_f32.div_(output_scale);
auto out_dtype = output_dtype.has_value() ? output_dtype.value() : at::kFloat8_e4m3fn;
if (out_dtype == at::kFloat8_e4m3fn) {
// Align with oneDNN: convert fp32 to fp8 by fp32 -> fp16 -> fp8
return y_f32.to(at::kHalf).to(out_dtype);
}
return y_f32.to(out_dtype);
}
static at::Tensor _quantized_convolution_onednn(
at::Tensor act, // contains quantized values but not QTensor
double act_scale,
@@ -1408,6 +1527,7 @@ static at::Tensor _quantized_convolution_onednn(
std::optional<std::string_view> unary_attr,
torch::List<std::optional<at::Scalar>> unary_scalars,
std::optional<std::string_view> unary_algorithm) {
using ideep::tensor;
/*********************************/
/* Checks */
/*********************************/
@@ -1464,10 +1584,6 @@ static at::Tensor _quantized_convolution_onednn(
if (kSpatialDim == 1) {
kSpatialDim += 1;
}
TORCH_CHECK(
weight.is_mkldnn(),
func_name, ": Weight should be prepacked as an MKLDNN tensor"
);
if (transposed) {
TORCH_CHECK(
false,
@@ -1481,12 +1597,13 @@ static at::Tensor _quantized_convolution_onednn(
padding = quant_utils::MakeArgForConv1d(padding, 0);
dilation = quant_utils::MakeArgForConv1d(dilation, 1);
}
auto act_dtype = act.scalar_type();
TORCH_CHECK(
act.scalar_type() == c10::ScalarType::Byte,
func_name, ": Input tensor should have uint8 (unsigned char) data type");
act_dtype == c10::ScalarType::Byte || act_dtype == c10::ScalarType::Float8_e4m3fn,
func_name, ": Input tensor should have uint8 (unsigned char) or fp8 data type");
TORCH_CHECK(
weight.scalar_type() == c10::ScalarType::Char,
func_name, ": Weight tensor should have int8 (char) data type");
weight.scalar_type() == c10::ScalarType::Char || weight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
func_name, ": Weight tensor should have int8 (char) or fp8 data type");
TORCH_CHECK(
weight.ndimension() == kSpatialDim + 2,
func_name, ": Weights are expected to have ", kSpatialDim + 2, " dimensions");
@@ -1502,6 +1619,30 @@ static at::Tensor _quantized_convolution_onednn(
dilation.size() == (decltype(dilation.size()))kSpatialDim,
func_name, ": dilation should contain ", kSpatialDim, " elements for ",
kSpatialDim, "D convolution.");
bool is_fp8 = weight.scalar_type() == c10::ScalarType::Float8_e4m3fn;
if (is_fp8) {
TORCH_CHECK(act_dtype == c10::ScalarType::Float8_e4m3fn,
func_name, ": expect input tensor to have fp8 data type, but got ", act_dtype);
TORCH_CHECK(act_zero_point == 0,
func_name, ": fp8 input should not have zero point.");
// the current version of oneDNN does not support fp8 conv yet
// TODO(weiwen) Refine this part when oneDNN supports fp8 conv
auto out = _fp8_convolution_onednn_ref(
act, act_scale, weight, weight_scales,
bias, stride, padding, dilation, groups,
output_scale, accum, accum_scale,
output_dtype, binary_attr, binary_alpha, unary_attr,
unary_scalars, unary_algorithm);
if (is_1d) {
out.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
}
return out;
}
TORCH_CHECK(
weight.is_mkldnn(),
func_name, ": Weight should be prepacked as an MKLDNN tensor"
);
// Parameters
#if IDEEP_PREREQ(3, 1, 0, 1)
@@ -1577,7 +1718,7 @@ static at::Tensor _quantized_convolution_onednn(
c10::MemoryFormat::ChannelsLast :
c10::MemoryFormat::ChannelsLast3d);
auto src_dims = act_contig.sizes().vec();
auto src_data_type = dnnl::memory::data_type::u8;
auto src_data_type = at::native::get_mkldnn_dtype(act.scalar_type());
auto src_desc = ideep::tensor::desc(src_dims, src_data_type,
kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
ideep::tensor src;
@@ -1594,7 +1735,7 @@ static at::Tensor _quantized_convolution_onednn(
at::empty(
dst_dims,
at::device(c10::kCPU)
.dtype(fp32_output ? c10::kFloat : (bfloat16_output ? c10::kBFloat16 : c10::kByte))
.dtype(fp32_output ? c10::kFloat : (bfloat16_output ? c10::kBFloat16 : act_dtype))
.memory_format(kSpatialDim == 2 ?
c10::MemoryFormat::ChannelsLast :
c10::MemoryFormat::ChannelsLast3d)
@@ -1619,9 +1760,8 @@ static at::Tensor _quantized_convolution_onednn(
// Use oneDNN's APIs instead of prepare/compute from ideep to reduce integration overhead.
// The functions from ideep are heavy because they have complex data structures for unified API
// oneDNN version >= 3.1.0 is required.
using ideep::tensor;
auto weight_grouped = packed_weight.make_grouped_weights(groups, /* is_deconv */false);
auto weights_desc = tensor::desc(weight_grouped.get_dims(), ideep::data_type::s8, ideep::format_tag::any);
auto weights_desc = tensor::desc(weight_grouped.get_dims(), packed_weight.get_data_type(), ideep::format_tag::any);
if (groups > 1) {
weights_desc = weights_desc.to_grouped(groups);
}
@@ -2090,5 +2230,12 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary_tensor"), at::native::QConvoneDNN::run_pointwise_binary_tensor);
}
TORCH_LIBRARY_IMPL(onednn, CPU, m) {
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise"), at::native::QConvoneDNN::run_pointwise);
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise.tensor"), at::native::QConvoneDNN::run_pointwise_tensor);
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary"), at::native::QConvoneDNN::run_pointwise_binary);
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary_tensor"), at::native::QConvoneDNN::run_pointwise_binary_tensor);
}
} // namespace
} // namespace at::native

@@ -519,6 +519,10 @@ at::Tensor _qconv_prepack_onednn(
dilation.size() == (decltype(dilation.size()))kSpatialDim,
"dilation should contain ", kSpatialDim, " elements for ",
kSpatialDim, "D convolution.");
TORCH_CHECK(
weight.scalar_type() == at::kChar || weight.scalar_type() == at::kFloat8_e4m3fn,
"Weight should have dtype int8 or fp8_e4m3fn but got ", weight.scalar_type());
bool is_fp8 = weight.scalar_type() == at::kFloat8_e4m3fn;
bool is_1d = (1 == kSpatialDim);
auto x_dims = input_shape.has_value()?input_shape.value().vec():ideep::dims();
@@ -535,6 +539,12 @@ at::Tensor _qconv_prepack_onednn(
dilation = quant_utils::MakeArgForConv1d(dilation, 1);
kSpatialDim += 1;
}
if (is_fp8) {
// FP8 convolution is not supported until oneDNN v3.9, which is newer than the version currently in use
// TODO(weiwen) Remove this early return when oneDNN supports fp8 conv
return weight;
}
auto w_dims = weight.sizes().vec();
auto strides = stride.vec();
auto padding_l = padding.vec();
@@ -581,11 +591,13 @@ at::Tensor _qconv_prepack_onednn(
ideep::dims dims_iohw, dims_giohw;
ideep::tag w_tag = ideep::tag::any;
const bool with_groups = groups > 1;
auto w_dnnl_dtype = at::native::get_mkldnn_dtype(weight.scalar_type());
auto x_dnnl_dtype = is_fp8 ? dnnl::memory::data_type::f8_e4m3 : dnnl::memory::data_type::u8;
w_desc = ideep::convolution_forward::expected_weights_desc(
w_dims, dnnl::memory::data_type::s8,
w_dims, w_dnnl_dtype,
strides, padding_l, padding_r, dilates, groups,
dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
dnnl::memory::data_type::u8, x_dims, op_attr, /*is_channels_last=*/true);
x_dnnl_dtype, x_dims, op_attr, /*is_channels_last=*/true);
// Note: Weight in Conv1D will unsqueeze into Conv2D in previous step
weight_copy = weight.clone(c10::MemoryFormat::Contiguous);
@@ -598,7 +610,7 @@ at::Tensor _qconv_prepack_onednn(
ideep::dims wei_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups)
: w_desc.get_dims();
ideep::tensor wgt = ideep::tensor(
ideep::tensor::desc({wei_dims, dnnl::memory::data_type::s8, w_tag}, groups),
ideep::tensor::desc({wei_dims, w_dnnl_dtype, w_tag}, groups),
weight_copy.data_ptr());
wgt.set_scale(weights_scales); // Scales are needed for feed_from().

@@ -987,7 +987,6 @@ static at::Tensor fp8_qlinear_onednn_ref(
} else if (unary_post_op == "hardswish") {
at::hardswish_(y_f32);
} else if (unary_post_op == "swish") {
// return ideep::attr_t::fuse_swish();
y_f32 = y_f32 * at::sigmoid(y_f32);
} else {
TORCH_CHECK(

@@ -154,6 +154,32 @@ def _get_random_tensor_and_q_params(shapes, rand_scale, torch_type):
X_scale = 1e-10
return X, X_scale, X_zero_point
def _quantize_fp8e4m3(t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None):
quant_max = torch.finfo(torch.float8_e4m3fn).max
eps = torch.Tensor([torch.finfo(torch.float32).eps])
if channelwise:
scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max
scale = torch.max(scale, eps)
scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1))
qt = t / scale_reshape
else:
scale = scale or t.abs().max().reshape([1]) / quant_max
scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item())
qt = t / scale
qt = qt.to(torch.float8_e4m3fn)
return qt, scale
def _dequantize_fp8e4m3(qt: torch.Tensor, scale: torch.Tensor):
dqt = qt.float()
if scale.numel() == 1:
# per tensor
dqt = dqt * scale
else:
# per channel
scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1))
dqt = dqt * scale_reshape
return dqt
class TestQuantizedOps(TestCase):
"""Helper function to test quantized activation functions."""
@@ -4678,32 +4704,6 @@ class TestQuantizedLinear(TestCase):
qlinear = torch.ops.onednn.qlinear_pointwise.binary
self._test_qlinear_pt2e_helper(qlinear, "add_relu")
def _quantize_fp8e4m3(self, t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None):
quant_max = torch.finfo(torch.float8_e4m3fn).max
eps = torch.Tensor([torch.finfo(torch.float32).eps])
if channelwise:
scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max
scale = torch.max(scale, eps)
scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1))
qt = t / scale_reshape
else:
scale = scale or t.abs().max().reshape([1]) / quant_max
scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item())
qt = t / scale
qt = qt.to(torch.float8_e4m3fn)
return qt, scale
def _dequantize_fp8e4m3(self, qt: torch.Tensor, scale: torch.Tensor):
dqt = qt.float()
if scale.numel() == 1:
# per tensor
dqt = dqt * scale
else:
# per channel
scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1))
dqt = dqt * scale_reshape
return dqt
def _test_qlinear_fp8_helper(
self,
qlinear_op,
@@ -4737,16 +4737,16 @@ class TestQuantizedLinear(TestCase):
x2_scale, x2_zp = 0.3, 0
x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10
w = torch.rand(oc, ic) * 10
qx, x_scale = self._quantize_fp8e4m3(x, channelwise=False)
qw, w_scales = self._quantize_fp8e4m3(w, channelwise=weight_quant_per_channel)
qx, x_scale = _quantize_fp8e4m3(x, channelwise=False)
qw, w_scales = _quantize_fp8e4m3(w, channelwise=weight_quant_per_channel)
if use_bias:
b = torch.rand(oc) * 10
else:
b = None
# compute reference result
x_ref = self._dequantize_fp8e4m3(qx, x_scale)
w_ref = self._dequantize_fp8e4m3(qw, w_scales)
x_ref = _dequantize_fp8e4m3(qx, x_scale)
w_ref = _dequantize_fp8e4m3(qw, w_scales)
y_ref = linear_op(x_ref, w_ref, b)
# compute fp8 linear
@@ -4766,8 +4766,8 @@ class TestQuantizedLinear(TestCase):
y_ref = F.gelu(y_ref, approximate=post_op_algo)
elif post_op in ("sum", "sum_relu"):
x2 = torch.rand_like(y_ref)
x2_q, x2_scale = self._quantize_fp8e4m3(x2, channelwise=False)
x2_dq = self._dequantize_fp8e4m3(x2_q, x2_scale)
x2_q, x2_scale = _quantize_fp8e4m3(x2, channelwise=False)
x2_dq = _dequantize_fp8e4m3(x2_q, x2_scale)
unary_post_op = "relu" if post_op == "sum_relu" else "none"
binary_alpha = 1.0 # we only support alpha=1.0 now
# if output_dtype is fp32 or bf16, accumulate on x2
@@ -4806,7 +4806,7 @@ class TestQuantizedLinear(TestCase):
# Compare results
if output_dtype is None:
y_ref = self._quantize_fp8e4m3(y_ref, False, used_y_scale)[0]
y_ref = _quantize_fp8e4m3(y_ref, False, used_y_scale)[0]
else:
y_ref = y_ref.to(output_dtype)
@@ -7490,10 +7490,10 @@ class TestQuantizedConv(TestCase):
qconv_output_dtype=output_dtype,
)
# Test qconv with post op silu
# Test qconv with post op swish
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv2d_silu_pt2e(self):
def test_qconv2d_swish_pt2e(self):
input_channels_per_group = 2
output_channels_per_group = 2
groups_list = [1, 10]
@@ -7815,6 +7815,330 @@ class TestQuantizedConv(TestCase):
qconv_output_dtype=output_dtype,
)
def _make_qconv_tensors_fp8(
self, batch_size, input_channels_per_group, input_feature_map_shape,
output_channels_per_group, groups, kernels, strides, pads, dilations,
use_bias, use_channelwise, use_transpose,
device=torch.device("cpu"),
):
assert not (use_channelwise and use_transpose), \
"Cannot generate channelwise qconv_transpose_tensors "
input_channels = input_channels_per_group * groups
output_channels = output_channels_per_group * groups
# Padded input size should be at least as big as dilated kernel
kernels = _single(kernels)
strides = _single(strides)
pads = _single(pads)
dilations = _single(dilations)
for i in range(len(kernels)):
assume(input_feature_map_shape[i] + 2 * pads[i]
>= dilations[i] * (kernels[i] - 1) + 1)
# the operator expects them in the format
# (output_channels, input_channels/groups, kernel_d, kernel_h, kernel_w)
# (input_channels, output_channels/groups, kernel_d, kernel_h, kernel_w)
if use_transpose:
output_shape = (input_channels, output_channels_per_group,)
else:
output_shape = (output_channels, input_channels_per_group,)
X = torch.rand(
(batch_size, input_channels,) + input_feature_map_shape,
device=device,
)
X_q, X_scale = _quantize_fp8e4m3(X, channelwise=False)
W = torch.randn(output_shape + kernels, device=device) * 0.1
W_q, W_scale = _quantize_fp8e4m3(W, channelwise=use_channelwise)
bias_float = torch.randn((output_channels,), device=device) if use_bias else None
return X, W, X_q, W_q, X_scale, W_scale, bias_float
def _test_qconv_impl_cpu_tensor_fp8(
self,
qconv,
qconv_prepack,
conv_op,
input_channels_per_group=2,
input_feature_map_shape=(),
output_channels_per_group=2,
groups=1,
kernels=3,
strides=(),
pads=(),
dilations=(),
Y_scale=0.02,
use_bias=True,
post_op=PointwisePostOp(),
use_channelwise=True,
X2_scale=0.02,
qconv_output_dtype=None, # None, torch.float32, torch.bfloat16
weight_in_channel_last_format=False,
):
# We assume FP8 quantization is always symmetric
fp32_output = True if qconv_output_dtype is torch.float32 else False
bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False
if fp32_output or bfloat16_output:
Y_scale = 1.0
X2_scale = 1.0
batch_size = 3
device = torch.device("cpu")
use_transpose = False
X, W, X_q, W_q, X_scale, W_scale, bias_float = self._make_qconv_tensors_fp8(
batch_size,
input_channels_per_group,
input_feature_map_shape,
output_channels_per_group,
groups,
kernels,
strides,
pads,
dilations,
use_bias,
use_channelwise,
use_transpose,
device=device,
)
# Assign weights
dqW = _dequantize_fp8e4m3(W_q, W_scale)
dqX = _dequantize_fp8e4m3(X_q, X_scale)
conv_op.weight = torch.nn.Parameter(dqW, requires_grad=False)
conv_op.bias = (
torch.nn.Parameter(bias_float, requires_grad=False) if use_bias else None
)
result_ref = conv_op(dqX)
X2 = None
X2_q = None
X2_scale = 1.0
if post_op.binary_attr == "sum":
X2_dtype = qconv_output_dtype if qconv_output_dtype else torch.float32
X2 = torch.rand_like(result_ref, device=device, dtype=X2_dtype)
if qconv_output_dtype is None:
X2_q, X2_scale = _quantize_fp8e4m3(X2, channelwise=False)
X2_dq = _dequantize_fp8e4m3(X2_q, X2_scale)
X2_scale = X2_scale.item()
else:
X2_dq = X2
result_ref = result_ref + X2_dq
if post_op.unary_attr == "relu":
relu = torch.nn.ReLU()
result_ref = relu(result_ref)
elif post_op.unary_attr == "relu":
assert not use_transpose, "Cannot fuse ReLU with ConvTranspose"
relu = torch.nn.ReLU()
result_ref = relu(result_ref)
elif post_op.unary_attr == "hardtanh":
assert not use_transpose, "Cannot fuse hardtanh with ConvTranspose"
assert len(post_op.scalars) == 2, "For post op hardtanh, expect 2 parameters passed in"
hardtanh = torch.nn.Hardtanh(min_val=post_op.scalars[0], max_val=post_op.scalars[1])
result_ref = hardtanh(result_ref)
elif post_op.unary_attr == "hardswish":
assert not use_transpose, "Cannot fuse hardswish with ConvTranspose"
hardswish = torch.nn.Hardswish()
result_ref = hardswish(result_ref)
elif post_op.unary_attr == "swish":
assert not use_transpose, "Cannot fuse silu with ConvTranspose"
silu = torch.nn.SiLU()
result_ref = silu(result_ref)
# Quantize reference results for comparison
if qconv_output_dtype is None:
Y_scale_t = torch.Tensor([Y_scale]).to(device)
# Align with oneDNN: convert fp32 to fp8 by fp32 -> fp16 -> fp8
result_ref = result_ref.div(Y_scale_t).half().to(torch.float8_e4m3fn)
else:
result_ref = result_ref.to(qconv_output_dtype)
# Calculate the result for PT2E path
if weight_in_channel_last_format:
if W_q.dim() == 5:
W_q = W_q.to(memory_format=torch.channels_last_3d)
elif W_q.dim() == 4:
W_q = W_q.to(memory_format=torch.channels_last)
X_scale_scalar = X_scale.item()
packed_weight = qconv_prepack(
W_q,
W_scale,
X_scale_scalar,
0, # X_zero_point
strides,
pads,
dilations,
groups,
X_q.size(),
)
if post_op.binary_attr == "sum":
accum = (
X2_q.contiguous(memory_format=torch.channels_last)
if X2_q is not None
else X2.contiguous(memory_format=torch.channels_last)
)
result = qconv(
X_q,
X_scale_scalar,
0, # X_zero_point
packed_weight,
W_scale,
torch.zeros([], dtype=torch.int8), # W_zero_point
accum,
bias_float,
strides,
pads,
dilations,
groups,
Y_scale,
0, # Y_zero_point
qconv_output_dtype,
X2_scale,
0, # X2_zero_point
post_op.binary_attr,
post_op.alpha,
post_op.unary_attr,
post_op.scalars,
post_op.algorithm,
)
else:
result = qconv(
X_q,
X_scale_scalar,
0, # X_zero_point
packed_weight,
W_scale,
torch.zeros([], dtype=torch.int8), # W_zero_point
bias_float,
strides,
pads,
dilations,
groups,
Y_scale,
0, # Y_zero_point
qconv_output_dtype,
post_op.unary_attr,
post_op.scalars,
post_op.algorithm,
)
if fp32_output or bfloat16_output:
self.assertTrue(result.dtype == qconv_output_dtype)
assert torch.allclose(result.float(), result_ref.float(), atol=1e-6)
def _test_qconv_fp8_helper(self, nd, pointwise_post_op):
# nd = 1,2,3 -> conv1d/2d/3d
if pointwise_post_op.binary_attr != "none":
# Only conv2d supports binary post op
assert nd == 2
groups_list = [1, 3]
input_channels_per_group = 2
output_channels_per_group = 2
length = 4
kernel = 3
stride = 1
pad = 1
dilation = 1
use_bias_list = [False, True]
use_channelwise_list = [False, True]
output_dtype_list = [None, torch.float32, torch.bfloat16]
options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
for groups, use_bias, use_channelwise, output_dtype in options:
if output_dtype is not None and not (use_bias and use_channelwise):
# Skip some test combinations to reduce UT run time
continue
conv_mod = getattr(torch.nn, f"Conv{nd}d")(
input_channels_per_group * groups,
output_channels_per_group * groups,
kernel,
stride,
pad,
dilation,
groups,
)
qconv = (
torch.ops.onednn.qconv_pointwise
if pointwise_post_op.binary_attr == "none"
else torch.ops.onednn.qconv2d_pointwise.binary
)
qconv_prepack = torch.ops.onednn.qconv_prepack
self._test_qconv_impl_cpu_tensor_fp8(
qconv,
qconv_prepack,
conv_mod,
input_channels_per_group=input_channels_per_group,
input_feature_map_shape=(length,) * nd,
output_channels_per_group=output_channels_per_group,
groups=groups,
kernels=[kernel] * nd,
strides=[stride] * nd,
pads=[pad] * nd,
dilations=[dilation] * nd,
use_bias=use_bias,
post_op=pointwise_post_op,
use_channelwise=use_channelwise,
qconv_output_dtype=output_dtype,
)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv1d_fp8(self):
pointwise_post_op = PointwisePostOp()
self._test_qconv_fp8_helper(1, pointwise_post_op)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv1d_relu_fp8(self):
pointwise_post_op = PointwisePostOp(unary_attr="relu")
self._test_qconv_fp8_helper(1, pointwise_post_op)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv2d_fp8(self):
pointwise_post_op = PointwisePostOp()
self._test_qconv_fp8_helper(2, pointwise_post_op)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv2d_relu_fp8(self):
pointwise_post_op = PointwisePostOp(unary_attr="relu")
self._test_qconv_fp8_helper(2, pointwise_post_op)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv2d_hardtanh_fp8(self):
pointwise_post_op = PointwisePostOp(unary_attr="hardtanh", scalars=[0.0, 6.0])
self._test_qconv_fp8_helper(2, pointwise_post_op)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv2d_swish_fp8(self):
pointwise_post_op = PointwisePostOp(unary_attr="swish")
self._test_qconv_fp8_helper(2, pointwise_post_op)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv2d_hardswish_fp8(self):
pointwise_post_op = PointwisePostOp(unary_attr="hardswish")
self._test_qconv_fp8_helper(2, pointwise_post_op)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv2d_sum_fp8(self):
pointwise_post_op = PointwisePostOp(binary_attr="sum")
self._test_qconv_fp8_helper(2, pointwise_post_op)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv2d_sum_relu_fp8(self):
pointwise_post_op = PointwisePostOp(binary_attr="sum", unary_attr="relu")
self._test_qconv_fp8_helper(2, pointwise_post_op)
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
@skipIfNoONEDNN
def test_qconv3d_fp8(self):
pointwise_post_op = PointwisePostOp()
self._test_qconv_fp8_helper(3, pointwise_post_op)
class TestPadding(TestCase):
@given(batch_size=st.integers(1, 64),

@@ -2720,10 +2720,24 @@ if torch._C._has_mkldnn:
groups,
None,
)
assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8]
if output_dtype is None:
output_dtype = x.dtype
assert output_dtype in [
torch.float32,
torch.bfloat16,
torch.uint8,
torch.int8,
torch.float8_e4m3fn,
]
out = x.new_empty(shape_out, dtype=output_dtype)
assert len(shape_out) in [3, 4], "only conv1d/2d are supported"
format = torch.channels_last if len(shape_out) == 4 else torch.contiguous_format
assert len(shape_out) in [3, 4, 5], (
"Expect output to be 3d/4d/5d for conv1d/2d/3d"
)
format = {
3: torch.contiguous_format,
4: torch.channels_last,
5: torch.channels_last_3d,
}[len(shape_out)]
out = out.to(memory_format=format)
return out