diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp
index 10f34a685f3..d2049a93672 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp
@@ -3,6 +3,7 @@
 #include
 #include
 #include
+#include
 #include
 #include
 #include
@@ -27,6 +28,14 @@
 #include // for quantize_per_te...
 #include
 #include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
 #endif
 
 #include
@@ -918,6 +927,118 @@ at::Tensor PackedLinearWeightsOnednn::apply_tanh(
       std::move(input), output_scale, output_zero_point);
 }
 
+static at::Tensor fp8_qlinear_onednn_ref(
+    at::Tensor input,
+    double input_scale,
+    at::Tensor weight, // expect plain weight
+    at::Tensor weight_scales,
+    std::optional<at::Tensor> bias, // plain tensor
+    double output_scale,
+    std::optional<c10::ScalarType> output_dtype,
+    std::optional<at::Tensor> other, // extra input for binary post-op
+    double other_scale,
+    const std::string_view& binary_post_op, // e.g. "none", "sum", "add"
+    double binary_alpha,
+    const std::string_view& unary_post_op, // e.g. "none", "relu"
+    torch::List<std::optional<at::Scalar>>& unary_post_op_args,
+    std::string_view& unary_post_op_algorithm) {
+  TORCH_CHECK(
+      input.scalar_type() == at::ScalarType::Float8_e4m3fn && weight.scalar_type() == at::ScalarType::Float8_e4m3fn,
+      "FP8 qlinear: Unexpected dtype of input and weight:", input.scalar_type(), ", ", weight.scalar_type());
+  const int64_t dim = input.dim();
+  auto input_contig =
+      dim == 2 ? input.contiguous() : input.reshape({-1, input.size(dim - 1)}).contiguous();
+  auto N = weight.size(0);
+  auto output_size = input.sizes().vec();
+  output_size[dim - 1] = N;
+  auto dqx = input_contig.to(at::kFloat) * input_scale;
+  std::vector<int64_t> w_scales_new_shape(weight.dim(), 1);
+  w_scales_new_shape[0] = -1;
+  auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
+  auto y_f32 = at::linear(dqx, dqw, bias);
+  if (binary_post_op == "none") {
+    if (unary_post_op == "relu") {
+      at::relu_(y_f32);
+    } else if (unary_post_op == "leaky_relu") {
+      TORCH_CHECK(
+          unary_post_op_args.size() == 1,
+          "onednn qlinear: expect one argument for post op leaky_relu but got ", unary_post_op_args.size(), " args");
+      auto element = unary_post_op_args.get(0);
+      auto alpha = element.value().to<float>();
+      at::leaky_relu_(y_f32, alpha);
+    } else if (unary_post_op == "tanh") {
+      at::tanh_(y_f32);
+    } else if (unary_post_op == "gelu") {
+      TORCH_CHECK(
+          unary_post_op_algorithm == "none" || unary_post_op_algorithm == "tanh",
+          "onednn qlinear: algorithm for post op gelu must be none or tanh but got ", unary_post_op_algorithm);
+      at::gelu_(y_f32, unary_post_op_algorithm);
+    } else if (unary_post_op == "hardtanh") {
+      TORCH_CHECK(
+          unary_post_op_args.size() == 2 &&
+              unary_post_op_args.get(0).has_value() &&
+              unary_post_op_args.get(1).has_value(),
+          "hardtanh is expected to have two scalar inputs: min_val and max_val");
+      auto lower_bound_value =
+          unary_post_op_args.get(0).value().to<float>();
+      auto upper_bound_value =
+          unary_post_op_args.get(1).value().to<float>();
+      at::hardtanh_(y_f32, lower_bound_value, upper_bound_value);
+    } else if (unary_post_op == "hardswish") {
+      at::hardswish_(y_f32);
+    } else if (unary_post_op == "swish") {
+      // Matches ideep::attr_t::fuse_swish() in the oneDNN path
+      y_f32 = y_f32 * at::sigmoid(y_f32);
+    } else {
+      TORCH_CHECK(
+          unary_post_op == "none",
+          "onednn qlinear: unsupported unary post op ", unary_post_op);
+    }
+  } else if (binary_post_op == "sum") {
+    TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op sum");
+    auto x1 = other.value();
+    TORCH_CHECK(x1.sizes().vec() == output_size);
+    auto x1_f32 = x1.to(at::kFloat) * other_scale;
+    x1_f32 = x1_f32.view(y_f32.sizes());
+    if (unary_post_op == "none") {
+      y_f32.add_(x1_f32);
+    } else if (unary_post_op == "relu") {
+      y_f32.add_(x1_f32).relu_();
+    } else {
+      TORCH_CHECK(
+          false,
+          "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op sum");
+    }
+    y_f32.div_(output_scale);
+    x1.copy_(y_f32.to(x1.scalar_type()).view(x1.sizes()));
+    return x1;
+  } else if (binary_post_op == "add") {
+    TORCH_CHECK(other.has_value(), "onednn qlinear: the extra input is missing for post op add");
+    auto x1 = other.value();
+    TORCH_CHECK(x1.sizes().vec() == output_size);
+    auto x1_f32 = x1.to(at::kFloat) * other_scale;
+    x1_f32 = x1_f32.view(y_f32.sizes());
+    if (unary_post_op == "none") {
+      y_f32.add_(x1_f32);
+    } else if (unary_post_op == "relu") {
+      y_f32.add_(x1_f32).relu_();
+    } else {
+      TORCH_CHECK(
+          false,
+          "onednn qlinear: unsupported unary post op ", unary_post_op, " with binary post op add");
+    }
+  } else {
+    TORCH_CHECK(
+        false,
+        "onednn qlinear: unsupported binary post op ", binary_post_op);
+  }
+
+  y_f32.div_(output_scale);
+  y_f32 = y_f32.view(output_size);
+  auto out_dtype = output_dtype.has_value() ? output_dtype.value() : at::kFloat8_e4m3fn;
+  return y_f32.to(out_dtype);
+}
+
 static at::Tensor linear_int8_with_onednn_weight(
     at::Tensor input, // int8 CPU Tensor, not QTensor
     double input_scale,
@@ -939,10 +1060,18 @@ static at::Tensor linear_int8_with_onednn_weight(
     std::string_view& unary_post_op_algorithm) {
   using ideep::tensor;
   const int64_t dim = input.dim();
-  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char,
-      "qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char).");
-  TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
-      "qlinear with mkldnn tensor: data type of weight should be int8 (char).");
+  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char || input.scalar_type() == c10::ScalarType::Float8_e4m3fn,
+      "qlinear with mkldnn tensor: data type of input should be uint8, int8 or float8_e4m3fn.");
+  TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char || onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
+      "qlinear with mkldnn tensor: data type of weight should be int8 or float8_e4m3fn.");
+  bool is_fp8 = false;
+  if (input.scalar_type() == c10::ScalarType::Float8_e4m3fn || onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn) {
+    TORCH_CHECK(
+        input.scalar_type() == c10::ScalarType::Float8_e4m3fn && onednn_weight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
+        "qlinear with mkldnn tensor: data type of input and weight should be the same for fp8, but got ",
+        input.scalar_type(), " and ", onednn_weight.scalar_type());
+    is_fp8 = true;
+  }
   TORCH_CHECK(
       weight_scales.scalar_type() == c10::ScalarType::Float, "weight scales should be dtype c10::ScalarType::Float.");
   TORCH_CHECK(
@@ -976,7 +1105,7 @@ static at::Tensor linear_int8_with_onednn_weight(
       );
     }
     if (binary_post_op == "sum") {
-      auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : c10::kByte;
+      auto expected_dtype = output_dtype.has_value() ? output_dtype.value() : input.scalar_type();
       TORCH_CHECK(
           other.value().scalar_type() == expected_dtype,
           "onednn qlinear: the dtype of extra input for binary post op should be ", expected_dtype,
@@ -984,6 +1113,14 @@ static at::Tensor linear_int8_with_onednn_weight(
       );
     }
   }
+  if (is_fp8 && !cpuinfo_has_x86_amx_int8()) {
+    // Fall back to the reference implementation on platforms without AMX support
+    return fp8_qlinear_onednn_ref(
+        input, input_scale, onednn_weight, weight_scales, bias,
+        output_scale, output_dtype, other, other_scale,
+        binary_post_op, binary_alpha, unary_post_op,
+        unary_post_op_args, unary_post_op_algorithm);
+  }
 
   // If the input has more than two dimensions, we will reshape it to a 2-dimensional form
   // for calculation and subsequently reshape the output back.
@@ -1016,7 +1153,7 @@ static at::Tensor linear_int8_with_onednn_weight(
       at::empty(
         dst_dims,
         at::device(c10::kCPU)
-            .dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : c10::kByte))
+            .dtype(fp32_output ? c10::kFloat : (bf16_output ? c10::kBFloat16 : input.scalar_type()))
       );
   if (output.numel() == 0) {
     return output;
   }
@@ -1029,7 +1166,7 @@ static at::Tensor linear_int8_with_onednn_weight(
       empty_tensor;
 
   // Create onednn primitive
-  auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8;
+  auto src_dtype = at::native::get_mkldnn_dtype(input.scalar_type());
   auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
   auto weights_desc = packed_weight.get_desc();
   auto dst_dtype = dst.get_data_type();
@@ -1463,5 +1600,16 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
       TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor));
 }
 
+TORCH_LIBRARY_IMPL(onednn, CPU, m) {
+  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise"),
+      TORCH_FN(QLinearOnednn::run_pointwise));
+  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.tensor"),
+      TORCH_FN(at::native::QLinearOnednn::run_pointwise_tensor));
+  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary"),
+      TORCH_FN(QLinearOnednn::run_pointwise_binary));
+  m.impl(TORCH_SELECTIVE_NAME("onednn::qlinear_pointwise.binary_tensor"),
+      TORCH_FN(at::native::QLinearOnednn::run_pointwise_binary_tensor));
+}
+
 } // namespace
 } // namespace at::native
diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
index 50af0862aef..d99a336bf37 100644
--- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
+++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp
@@ -297,14 +297,32 @@ c10::intrusive_ptr<LinearPackedParamsBase> PackedLinearWeightsOnednn::prepack(
 static inline at::Tensor pack_weight_to_onednn_tensor(
     const at::Tensor& weight,
     std::optional<torch::List<int64_t>>& input_shape) {
+  at::ScalarType weight_dtype = weight.scalar_type();
+  TORCH_CHECK(
+      weight_dtype == at::kChar || weight_dtype == at::kFloat8_e4m3fn,
+      "Weight should be of type int8 or float8_e4m3fn");
+  bool is_fp8 = weight_dtype == at::kFloat8_e4m3fn;
+  if (is_fp8 && !cpuinfo_has_x86_amx_int8()) {
+    // oneDNN's fp8 support requires AMX.
+    // If AMX is not available, fall back to the reference implementation.
+    return weight;
+  }
   std::vector<int64_t> w_dims = weight.sizes().vec();
-  ideep::tensor wei = ideep::tensor({w_dims, dnnl::memory::data_type::s8}, weight.data_ptr());
+  auto w_data_type = is_fp8
+      ? dnnl::memory::data_type::f8_e4m3
+      : dnnl::memory::data_type::s8;
+  ideep::tensor wei = ideep::tensor({w_dims, w_data_type}, weight.data_ptr());
   wei.transpose_(0, 1); // oneDNN requires transposed weight
   ideep::dims input_dims = input_shape.has_value() ? input_shape.value().vec() : ideep::dims();
   ideep::attr_t op_attr;
-  op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
+  if (!is_fp8) {
+    op_attr.set_zero_points_mask(DNNL_ARG_SRC, 0);
+  }
+  auto x_data_type = is_fp8
+      ? dnnl::memory::data_type::f8_e4m3
+      : dnnl::memory::data_type::u8;
   auto w_desc = ideep::matmul_forward::expected_weights_desc(
-      wei.get_dims(), input_dims, dnnl::memory::data_type::s8, dnnl::memory::data_type::u8, op_attr);
+      wei.get_dims(), input_dims, w_data_type, x_data_type, op_attr);
   ideep::tensor expected_weight(w_desc);
   expected_weight.feed_from(wei);
   auto packed_weight = at::native::new_with_itensor_mkldnn(
diff --git a/test/quantization/core/test_quantized_op.py b/test/quantization/core/test_quantized_op.py
index 3f244b31e54..c01e3c31833 100644
--- a/test/quantization/core/test_quantized_op.py
+++ b/test/quantization/core/test_quantized_op.py
@@ -4511,7 +4511,7 @@ class TestQuantizedLinear(TestCase):
         qlinear_op,
         post_op="none",
         unary_post_op_args=(),
-        post_op_algorithms=("none"),
+        post_op_algorithms=("none",),
     ):
         qlinear_prepack = torch.ops.onednn.qlinear_prepack
         linear_op = F.linear
@@ -4678,6 +4678,184 @@ class TestQuantizedLinear(TestCase):
         qlinear = torch.ops.onednn.qlinear_pointwise.binary
         self._test_qlinear_pt2e_helper(qlinear, "add_relu")
 
+    def _quantize_fp8e4m3(self, t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None):
+        quant_max = torch.finfo(torch.float8_e4m3fn).max
+        eps = torch.Tensor([torch.finfo(torch.float32).eps])
+        if channelwise:
+            scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max
+            scale = torch.max(scale, eps)
+            scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1))
+            qt = t / scale_reshape
+        else:
+            scale = scale or t.abs().max().reshape([1]) / quant_max
+            scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item())
+            qt = t / scale
+        qt = qt.to(torch.float8_e4m3fn)
+        return qt, scale
+
+    def _dequantize_fp8e4m3(self, qt: torch.Tensor, scale: torch.Tensor):
+        dqt = qt.float()
+        if scale.numel() == 1:
+            # per tensor
+            dqt = dqt * scale
+        else:
+            # per channel
+            scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1))
+            dqt = dqt * scale_reshape
+        return dqt
+
+    def _test_qlinear_fp8_helper(
+        self,
+        qlinear_op,
+        post_op="none",
+        unary_post_op_args=(),
+        post_op_algorithms=("none",),
+    ):
+        qlinear_prepack = torch.ops.onednn.qlinear_prepack
+        linear_op = F.linear
+        in_channels_list = [4, 8]
+        out_channels_list = [16, 32]
+        batch_size = 1
+        use_bias_list = [True, False]
+        weight_quant_per_channel_list = [True, False]
+        output_dtype_list = [None, torch.float32, torch.bfloat16]
+        y_scale, y_zp = 0.07, 0
+        input_dim_list = [2, 3]
+        cases = itertools.product(
+            in_channels_list, out_channels_list, use_bias_list,
+            weight_quant_per_channel_list, output_dtype_list, post_op_algorithms, input_dim_list)
+        with override_quantized_engine('onednn'):
+            for ic, oc, use_bias, weight_quant_per_channel, output_dtype, post_op_algo, input_dim in cases:
+                used_y_scale = y_scale
+                used_y_zp = y_zp
+                fp32_out = output_dtype == torch.float32
+                bfloat16_out = output_dtype == torch.bfloat16
+                if fp32_out or bfloat16_out:
+                    used_y_scale = 1.0
+                    x2_scale, x2_zp = 1.0, 0
+                else:
+                    x2_scale, x2_zp = 0.3, 0
+                x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10
+                w = torch.rand(oc, ic) * 10
+                qx, x_scale = self._quantize_fp8e4m3(x, channelwise=False)
+                qw, w_scales = self._quantize_fp8e4m3(w, channelwise=weight_quant_per_channel)
+                if use_bias:
+                    b = torch.rand(oc) * 10
+                else:
+                    b = None
+
+                # compute reference result
+                x_ref = self._dequantize_fp8e4m3(qx, x_scale)
+                w_ref = self._dequantize_fp8e4m3(qw, w_scales)
+                y_ref = linear_op(x_ref, w_ref, b)
+
+                # compute fp8 linear
+                qw_packed = qlinear_prepack(qw, x.shape)
+                x_zp = 0
+                w_zps = torch.zeros_like(w_scales, dtype=torch.int)
+
+                if post_op in ("none", "relu", "gelu"):
+                    qy = qlinear_op(
+                        qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
+                        b, used_y_scale, used_y_zp, output_dtype,
+                        post_op, unary_post_op_args, post_op_algo
+                    )
+                    if post_op == "relu":
+                        y_ref = F.relu(y_ref)
+                    elif post_op == "gelu":
+                        y_ref = F.gelu(y_ref, approximate=post_op_algo)
+                elif post_op in ("sum", "sum_relu"):
+                    x2 = torch.rand_like(y_ref)
+                    x2_q, x2_scale = self._quantize_fp8e4m3(x2, channelwise=False)
+                    x2_dq = self._dequantize_fp8e4m3(x2_q, x2_scale)
+                    unary_post_op = "relu" if post_op == "sum_relu" else "none"
+                    binary_alpha = 1.0  # we only support alpha=1.0 now
+                    # if output_dtype is fp32 or bf16, accumulate on x2;
+                    # if output_dtype is None (fp8), accumulate on x2_q (x2_dq for the reference)
+                    accum = x2_q if output_dtype is None else x2
+                    accum_ref = x2_dq if output_dtype is None else x2.clone()
+                    x2_scale = x2_scale if output_dtype is None else 1.0
+                    if bfloat16_out:
+                        accum = accum.bfloat16()
+                        accum_ref = accum_ref.bfloat16()
+                    qy = qlinear_op(
+                        qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
+                        accum, b, used_y_scale, used_y_zp, output_dtype,
+                        x2_scale, x2_zp, "sum", binary_alpha,
+                        unary_post_op, unary_post_op_args, post_op_algo
+                    )
+                    y_ref = y_ref + accum_ref * binary_alpha
+                    if unary_post_op == "relu":
+                        y_ref = F.relu(y_ref)
+                elif post_op in ("add", "add_relu"):
+                    if output_dtype is not None:
+                        # Only support fp8 output
+                        continue
+                    x2 = torch.rand_like(y_ref)
+                    unary_post_op = "relu" if post_op == "add_relu" else "none"
+                    binary_alpha = 1.0  # we only support alpha=1.0 now
+                    qy = qlinear_op(
+                        qx, x_scale, x_zp, qw_packed, w_scales, w_zps,
+                        x2, b, used_y_scale, used_y_zp, output_dtype,
+                        1.0, 0, "add", binary_alpha,
+                        unary_post_op, unary_post_op_args, post_op_algo
+                    )
+                    y_ref = y_ref + x2 * binary_alpha
+                    if unary_post_op == "relu":
+                        y_ref = F.relu(y_ref)
+
+                # Compare results
+                if output_dtype is None:
+                    y_ref = self._quantize_fp8e4m3(y_ref, False, used_y_scale)[0]
+                else:
+                    y_ref = y_ref.to(output_dtype)
+
+                self.assertEqual(x.dim(), qy.dim())
+                self.assertEqual(y_ref.float(), qy.float())
+
+    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
+    @skipIfNoONEDNN
+    def test_qlinear_fp8(self):
+        qlinear = torch.ops.onednn.qlinear_pointwise
+        self._test_qlinear_fp8_helper(qlinear, "none")
+
+    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
+    @skipIfNoONEDNN
+    def test_qlinear_relu_fp8(self):
+        qlinear = torch.ops.onednn.qlinear_pointwise
+        self._test_qlinear_fp8_helper(qlinear, "relu")
+
+    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
+    @skipIfNoONEDNN
+    def test_qlinear_gelu_fp8(self):
+        qlinear = torch.ops.onednn.qlinear_pointwise
+        post_op_algorithms = ['none', 'tanh']
+        self._test_qlinear_fp8_helper(qlinear, "gelu", post_op_algorithms=post_op_algorithms)
+
+    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
+    @skipIfNoONEDNN
+    def test_qlinear_sum_fp8(self):
+        qlinear = torch.ops.onednn.qlinear_pointwise.binary
+        self._test_qlinear_fp8_helper(qlinear, "sum")
+
+    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
+    @skipIfNoONEDNN
+    def test_qlinear_sum_relu_fp8(self):
+        qlinear = torch.ops.onednn.qlinear_pointwise.binary
+        self._test_qlinear_fp8_helper(qlinear, "sum_relu")
+
+    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
+    @skipIfNoONEDNN
+    def test_qlinear_add_fp8(self):
+        qlinear = torch.ops.onednn.qlinear_pointwise.binary
+        self._test_qlinear_fp8_helper(qlinear, "add")
+
+    @unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
+    @skipIfNoONEDNN
+    def test_qlinear_add_relu_fp8(self):
+        qlinear = torch.ops.onednn.qlinear_pointwise.binary
+        self._test_qlinear_fp8_helper(qlinear, "add_relu")
+
 
 @unittest.skipIf(IS_MACOS, "Known test failure on Mac.")
 class TestQuantizedEmbeddingOps(TestCase):