mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Quant][CPU] Enable fp8 qconv (#157076)
**Summary** Enable fp8 qconv on CPU. It's part of the plan to enable fp8 static quantization on CPU. This PR only adds FP8 support of the existing int8 qconv op. It does not add a new op nor does it affect frontend or quantization flow. The schema of the qconv op is not changed either. So, the FP8 qconv shares the same op as INT8 qconv and the difference is that src/wei dtype is fp8 instead of int8. The output dtype can be fp8/float32/bfloat16. The implementation uses the oneDNN library. Note: OneDNN does not support quantized fp8 convolution until v3.9 but the version used in PyTorch is v3.7.2. So, the op goes to the reference kernel for now. And we have also update the oneDNN path so that it's compatible with the fp8 dtype. Once oneDNN is upgraded to v3.9 or newer, minimum changes are needed to enable the oneDNN path. And we have ensured that the behavior of the reference kernel is the same as the new oneDNN's implementation. - oneDNN version < 3.9 (now) - Always go to the reference kernel - oneDNN version >= 3.9 (future) - Go to reference kernel on old platforms (without AMX) - Use oneDNN on new platforms (with AMX) **Test plan** ``` pytest test/quantization/core/test_quantized_op.py -k "qconv and fp8" ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/157076 Approved by: https://github.com/leslie-fang-intel, https://github.com/jerryzh168
This commit is contained in:
parent
ed508cc018
commit
e1a20988f3
|
|
@ -34,6 +34,15 @@
|
|||
#include <ATen/ops/quantize_per_channel_native.h>
|
||||
#include <ATen/ops/quantize_per_tensor_native.h>
|
||||
#include <ATen/ops/zeros.h>
|
||||
#include <ATen/ops/convolution.h>
|
||||
#include <ATen/ops/linear.h>
|
||||
#include <ATen/ops/relu.h>
|
||||
#include <ATen/ops/leaky_relu.h>
|
||||
#include <ATen/ops/tanh.h>
|
||||
#include <ATen/ops/gelu.h>
|
||||
#include <ATen/ops/hardtanh.h>
|
||||
#include <ATen/ops/hardswish.h>
|
||||
#include <ATen/ops/sigmoid.h>
|
||||
#endif
|
||||
|
||||
#include <c10/util/irange.h>
|
||||
|
|
@ -1384,6 +1393,116 @@ template at::Tensor PackedConvWeightsOnednn<3>::apply_relu(
|
|||
double output_scale,
|
||||
int64_t output_zero_point);
|
||||
|
||||
static at::Tensor _fp8_convolution_onednn_ref(
|
||||
at::Tensor act, // contains quantized values but not QTensor
|
||||
double act_scale,
|
||||
at::Tensor weight, // MKLDNN tensor with quantized values
|
||||
at::Tensor weight_scales,
|
||||
std::optional<at::Tensor> bias, // Bias is not packed into MKLDNN tensor
|
||||
torch::List<int64_t> stride,
|
||||
torch::List<int64_t> padding,
|
||||
torch::List<int64_t> dilation,
|
||||
int64_t groups,
|
||||
double output_scale,
|
||||
std::optional<at::Tensor> accum, // accum to fused with conv add
|
||||
double accum_scale,
|
||||
std::optional<c10::ScalarType> output_dtype,
|
||||
std::optional<std::string_view> binary_attr,
|
||||
std::optional<at::Scalar> binary_alpha,
|
||||
std::optional<std::string_view> unary_attr,
|
||||
torch::List<std::optional<at::Scalar>> unary_scalars,
|
||||
std::optional<std::string_view> unary_algorithm) {
|
||||
TORCH_CHECK(
|
||||
act.scalar_type() == at::ScalarType::Float8_e4m3fn && weight.scalar_type() == at::ScalarType::Float8_e4m3fn,
|
||||
"FP8 qconv: Unexpected dtype of input and weight:", act.scalar_type(), ", ", weight.scalar_type());
|
||||
int kSpatialDim = act.dim() - 2;
|
||||
// conv1d is converted to conv2d before calling this function
|
||||
TORCH_CHECK(kSpatialDim != 1, "Expect 2D or 3D convolution, but got 1D convolution.");
|
||||
auto act_contig = act.contiguous(kSpatialDim == 2 ?
|
||||
c10::MemoryFormat::ChannelsLast :
|
||||
c10::MemoryFormat::ChannelsLast3d);
|
||||
auto dqx = act_contig.to(at::kFloat) * act_scale;
|
||||
std::vector<int64_t> w_scales_new_shape(weight.dim(), 1);
|
||||
w_scales_new_shape[0] = -1;
|
||||
auto dqw = weight.to(at::kFloat) * weight_scales.reshape(w_scales_new_shape);
|
||||
auto output_padding = std::vector<int64_t>(kSpatialDim, 0);
|
||||
auto y_f32 = at::convolution(
|
||||
dqx, dqw, bias, stride.vec(), padding.vec(), dilation.vec(), /* transposed */false, output_padding, groups
|
||||
);
|
||||
if (!binary_attr.has_value() || binary_attr == "none") {
|
||||
if (unary_attr == "relu") {
|
||||
at::relu_(y_f32);
|
||||
} else if (unary_attr == "leaky_relu") {
|
||||
TORCH_CHECK(
|
||||
unary_scalars.size() == 1,
|
||||
"onednn qconv: expect one argument for post op leaky_relu but got ", unary_scalars.size(), " args");
|
||||
auto element = unary_scalars.get(0);
|
||||
auto alpha = element.value().to<float>();
|
||||
at::leaky_relu_(y_f32, alpha);
|
||||
} else if (unary_attr == "tanh") {
|
||||
at::tanh_(y_f32);
|
||||
} else if (unary_attr == "gelu") {
|
||||
TORCH_CHECK(
|
||||
unary_algorithm == "none" || unary_algorithm == "tanh",
|
||||
"onednn qconv: algorithm for post op gelu must be none or tanh but got ", unary_algorithm);
|
||||
at::gelu_(y_f32, unary_algorithm.value());
|
||||
} else if (unary_attr == "hardtanh") {
|
||||
TORCH_CHECK(
|
||||
unary_scalars.size() == 2 &&
|
||||
unary_scalars.get(0).has_value() &&
|
||||
unary_scalars.get(1).has_value(),
|
||||
"hardtanh is expected to have two scalar input: min_val and max_val");
|
||||
auto lower_bound_value =
|
||||
unary_scalars.get(0).value().to<float>();
|
||||
auto upper_bound_value =
|
||||
unary_scalars.get(1).value().to<float>();
|
||||
at::hardtanh_(y_f32, lower_bound_value, upper_bound_value);
|
||||
} else if (unary_attr == "hardswish") {
|
||||
at::hardswish_(y_f32);
|
||||
} else if (unary_attr == "swish") {
|
||||
y_f32 = y_f32 * at::sigmoid(y_f32);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
!unary_attr.has_value() || unary_attr == "none",
|
||||
"onednn qconv: unsupported unary post op ", unary_attr);
|
||||
}
|
||||
} else if (binary_attr == "sum") {
|
||||
TORCH_CHECK(accum.has_value(), "onednn qconv: the extra input is missing for post op sum");
|
||||
auto x1 = accum.value();
|
||||
TORCH_CHECK(x1.sizes() == y_f32.sizes());
|
||||
auto x1_f32 = x1.to(at::kFloat) * accum_scale;
|
||||
x1_f32 = x1_f32.view(y_f32.sizes());
|
||||
if (!unary_attr.has_value() || unary_attr == "none") {
|
||||
y_f32.add_(x1_f32);
|
||||
} else if (unary_attr == "relu") {
|
||||
y_f32.add_(x1_f32).relu_();
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"onednn qconv: unsupported unary post op ", unary_attr, " with binary post op sum");
|
||||
}
|
||||
y_f32.div_(output_scale);
|
||||
if (x1.scalar_type() == at::kFloat8_e4m3fn) {
|
||||
// Align with oneDNN: convert fp32 to fp8 by fp32 -> fp16 -> fp8
|
||||
y_f32 = y_f32.to(at::kHalf);
|
||||
}
|
||||
x1.copy_(y_f32.to(x1.scalar_type()).view(x1.sizes()));
|
||||
return x1;
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
"onednn qconv: unsupported binary post op ", binary_attr);
|
||||
}
|
||||
|
||||
y_f32.div_(output_scale);
|
||||
auto out_dtype = output_dtype.has_value() ? output_dtype.value() : at::kFloat8_e4m3fn;
|
||||
if (out_dtype == at::kFloat8_e4m3fn) {
|
||||
// Align with oneDNN: convert fp32 to fp8 by fp32 -> fp16 -> fp8
|
||||
return y_f32.to(at::kHalf).to(out_dtype);
|
||||
}
|
||||
return y_f32.to(out_dtype);
|
||||
}
|
||||
|
||||
static at::Tensor _quantized_convolution_onednn(
|
||||
at::Tensor act, // contains quantized values but not QTensor
|
||||
double act_scale,
|
||||
|
|
@ -1408,6 +1527,7 @@ static at::Tensor _quantized_convolution_onednn(
|
|||
std::optional<std::string_view> unary_attr,
|
||||
torch::List<std::optional<at::Scalar>> unary_scalars,
|
||||
std::optional<std::string_view> unary_algorithm) {
|
||||
using ideep::tensor;
|
||||
/*********************************/
|
||||
/* Checks */
|
||||
/*********************************/
|
||||
|
|
@ -1464,10 +1584,6 @@ static at::Tensor _quantized_convolution_onednn(
|
|||
if (kSpatialDim == 1) {
|
||||
kSpatialDim += 1;
|
||||
}
|
||||
TORCH_CHECK(
|
||||
weight.is_mkldnn(),
|
||||
func_name, ": Weight should be prepacked as an MKLDNN tensor"
|
||||
);
|
||||
if (transposed) {
|
||||
TORCH_CHECK(
|
||||
false,
|
||||
|
|
@ -1481,12 +1597,13 @@ static at::Tensor _quantized_convolution_onednn(
|
|||
padding = quant_utils::MakeArgForConv1d(padding, 0);
|
||||
dilation = quant_utils::MakeArgForConv1d(dilation, 1);
|
||||
}
|
||||
auto act_dtype = act.scalar_type();
|
||||
TORCH_CHECK(
|
||||
act.scalar_type() == c10::ScalarType::Byte,
|
||||
func_name, ": Input tensor should have uint8 (unsigned char) data type");
|
||||
act_dtype == c10::ScalarType::Byte || act_dtype == c10::ScalarType::Float8_e4m3fn,
|
||||
func_name, ": Input tensor should have uint8 (unsigned char) or fp8 data type");
|
||||
TORCH_CHECK(
|
||||
weight.scalar_type() == c10::ScalarType::Char,
|
||||
func_name, ": Weight tensor should have int8 (char) data type");
|
||||
weight.scalar_type() == c10::ScalarType::Char || weight.scalar_type() == c10::ScalarType::Float8_e4m3fn,
|
||||
func_name, ": Weight tensor should have int8 (char) or fp8 data type");
|
||||
TORCH_CHECK(
|
||||
weight.ndimension() == kSpatialDim + 2,
|
||||
func_name, ": Weights are expected to have ", kSpatialDim + 2, " dimensions");
|
||||
|
|
@ -1502,6 +1619,30 @@ static at::Tensor _quantized_convolution_onednn(
|
|||
dilation.size() == (decltype(dilation.size()))kSpatialDim,
|
||||
func_name, ": dilation should contain ", kSpatialDim, " elements for ",
|
||||
kSpatialDim, "D convolution.");
|
||||
bool is_fp8 = weight.scalar_type() == c10::ScalarType::Float8_e4m3fn;
|
||||
if (is_fp8) {
|
||||
TORCH_CHECK(act_dtype == c10::ScalarType::Float8_e4m3fn,
|
||||
func_name, ": expect input tensor to have fp8 data type, but got ", act_dtype);
|
||||
TORCH_CHECK(act_zero_point == 0,
|
||||
func_name, ": fp8 input should not have zero point.");
|
||||
// the current version of oneDNN does not fp8 conv yet
|
||||
// TODO(weiwen) Refine this part when oneDNN supports fp8 conv
|
||||
auto out = _fp8_convolution_onednn_ref(
|
||||
act, act_scale, weight, weight_scales,
|
||||
bias, stride, padding, dilation, groups,
|
||||
output_scale, accum, accum_scale,
|
||||
output_dtype, binary_attr, binary_alpha, unary_attr,
|
||||
unary_scalars, unary_algorithm);
|
||||
if (is_1d) {
|
||||
out.squeeze_(quant_utils::kConv1dSqueezeDim + 2);
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
weight.is_mkldnn(),
|
||||
func_name, ": Weight should be prepacked as an MKLDNN tensor"
|
||||
);
|
||||
|
||||
// Parameters
|
||||
#if IDEEP_PREREQ(3, 1, 0, 1)
|
||||
|
|
@ -1577,7 +1718,7 @@ static at::Tensor _quantized_convolution_onednn(
|
|||
c10::MemoryFormat::ChannelsLast :
|
||||
c10::MemoryFormat::ChannelsLast3d);
|
||||
auto src_dims = act_contig.sizes().vec();
|
||||
auto src_data_type = dnnl::memory::data_type::u8;
|
||||
auto src_data_type = at::native::get_mkldnn_dtype(act.scalar_type());
|
||||
auto src_desc = ideep::tensor::desc(src_dims, src_data_type,
|
||||
kSpatialDim == 2 ? ideep::format_tag::nhwc : ideep::format_tag::ndhwc);
|
||||
ideep::tensor src;
|
||||
|
|
@ -1594,7 +1735,7 @@ static at::Tensor _quantized_convolution_onednn(
|
|||
at::empty(
|
||||
dst_dims,
|
||||
at::device(c10::kCPU)
|
||||
.dtype(fp32_output ? c10::kFloat : (bfloat16_output ? c10::kBFloat16 : c10::kByte))
|
||||
.dtype(fp32_output ? c10::kFloat : (bfloat16_output ? c10::kBFloat16 : act_dtype))
|
||||
.memory_format(kSpatialDim == 2 ?
|
||||
c10::MemoryFormat::ChannelsLast :
|
||||
c10::MemoryFormat::ChannelsLast3d)
|
||||
|
|
@ -1619,9 +1760,8 @@ static at::Tensor _quantized_convolution_onednn(
|
|||
// Use oneDNN's APIs instead of prepare/compute from ideep to reduce integration overhead.
|
||||
// The functions from ideep are heavy because they have complex data structures for unified API
|
||||
// oneDNN version >= 3.1.0 is required.
|
||||
using ideep::tensor;
|
||||
auto weight_grouped = packed_weight.make_grouped_weights(groups, /* is_deconv */false);
|
||||
auto weights_desc = tensor::desc(weight_grouped.get_dims(), ideep::data_type::s8, ideep::format_tag::any);
|
||||
auto weights_desc = tensor::desc(weight_grouped.get_dims(), packed_weight.get_data_type(), ideep::format_tag::any);
|
||||
if (groups > 1) {
|
||||
weights_desc = weights_desc.to_grouped(groups);
|
||||
}
|
||||
|
|
@ -2090,5 +2230,12 @@ TORCH_LIBRARY_IMPL(onednn, MkldnnCPU, m) {
|
|||
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary_tensor"), at::native::QConvoneDNN::run_pointwise_binary_tensor);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(onednn, CPU, m) {
|
||||
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise"), at::native::QConvoneDNN::run_pointwise);
|
||||
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv_pointwise.tensor"), at::native::QConvoneDNN::run_pointwise_tensor);
|
||||
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary"), at::native::QConvoneDNN::run_pointwise_binary);
|
||||
m.impl(TORCH_SELECTIVE_NAME("onednn::qconv2d_pointwise.binary_tensor"), at::native::QConvoneDNN::run_pointwise_binary_tensor);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace at::native
|
||||
|
|
|
|||
|
|
@ -519,6 +519,10 @@ at::Tensor _qconv_prepack_onednn(
|
|||
dilation.size() == (decltype(dilation.size()))kSpatialDim,
|
||||
"dilation should contain ", kSpatialDim, " elements for ",
|
||||
kSpatialDim, "D convolution.");
|
||||
TORCH_CHECK(
|
||||
weight.scalar_type() == at::kChar || weight.scalar_type() == at::kFloat8_e4m3fn,
|
||||
"Weight should have dtype int8 or fp8_e4m3fn but got ", weight.scalar_type());
|
||||
bool is_fp8 = weight.scalar_type() == at::kFloat8_e4m3fn;
|
||||
|
||||
bool is_1d = (1 == kSpatialDim);
|
||||
auto x_dims = input_shape.has_value()?input_shape.value().vec():ideep::dims();
|
||||
|
|
@ -535,6 +539,12 @@ at::Tensor _qconv_prepack_onednn(
|
|||
dilation = quant_utils::MakeArgForConv1d(dilation, 1);
|
||||
kSpatialDim += 1;
|
||||
}
|
||||
if (is_fp8) {
|
||||
// The current version of oneDNN does not support fp8 conv
|
||||
// TODO(weiwen) Remove this when oneDNN supports fp8 conv
|
||||
// FP8 convolution is not supported by oneDNN until v3.9
|
||||
return weight;
|
||||
}
|
||||
auto w_dims = weight.sizes().vec();
|
||||
auto strides = stride.vec();
|
||||
auto padding_l = padding.vec();
|
||||
|
|
@ -581,11 +591,13 @@ at::Tensor _qconv_prepack_onednn(
|
|||
ideep::dims dims_iohw, dims_giohw;
|
||||
ideep::tag w_tag = ideep::tag::any;
|
||||
const bool with_groups = groups > 1;
|
||||
auto w_dnnl_dtype = at::native::get_mkldnn_dtype(weight.scalar_type());
|
||||
auto x_dnnl_dtype = is_fp8 ? dnnl::memory::data_type::f8_e4m3 : dnnl::memory::data_type::u8;
|
||||
w_desc = ideep::convolution_forward::expected_weights_desc(
|
||||
w_dims, dnnl::memory::data_type::s8,
|
||||
w_dims, w_dnnl_dtype,
|
||||
strides, padding_l, padding_r, dilates, groups,
|
||||
dnnl::algorithm::convolution_direct, dnnl::prop_kind::forward_inference,
|
||||
dnnl::memory::data_type::u8, x_dims, op_attr, /*is_channels_last=*/true);
|
||||
x_dnnl_dtype, x_dims, op_attr, /*is_channels_last=*/true);
|
||||
|
||||
// Note: Weight in Conv1D will unsqueeze into Conv2D in previous step
|
||||
weight_copy = weight.clone(c10::MemoryFormat::Contiguous);
|
||||
|
|
@ -598,7 +610,7 @@ at::Tensor _qconv_prepack_onednn(
|
|||
ideep::dims wei_dims = with_groups ? ideep::utils::group_dims(w_desc.get_dims(), groups)
|
||||
: w_desc.get_dims();
|
||||
ideep::tensor wgt = ideep::tensor(
|
||||
ideep::tensor::desc({wei_dims, dnnl::memory::data_type::s8, w_tag}, groups),
|
||||
ideep::tensor::desc({wei_dims, w_dnnl_dtype, w_tag}, groups),
|
||||
weight_copy.data_ptr());
|
||||
|
||||
wgt.set_scale(weights_scales); // Scales are needed for feed_from().
|
||||
|
|
|
|||
|
|
@ -987,7 +987,6 @@ static at::Tensor fp8_qlinear_onednn_ref(
|
|||
} else if (unary_post_op == "hardswish") {
|
||||
at::hardswish_(y_f32);
|
||||
} else if (unary_post_op == "swish") {
|
||||
// return ideep::attr_t::fuse_swish();
|
||||
y_f32 = y_f32 * at::sigmoid(y_f32);
|
||||
} else {
|
||||
TORCH_CHECK(
|
||||
|
|
|
|||
|
|
@ -154,6 +154,32 @@ def _get_random_tensor_and_q_params(shapes, rand_scale, torch_type):
|
|||
X_scale = 1e-10
|
||||
return X, X_scale, X_zero_point
|
||||
|
||||
def _quantize_fp8e4m3(t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None):
|
||||
quant_max = torch.finfo(torch.float8_e4m3fn).max
|
||||
eps = torch.Tensor([torch.finfo(torch.float32).eps])
|
||||
if channelwise:
|
||||
scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max
|
||||
scale = torch.max(scale, eps)
|
||||
scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1))
|
||||
qt = t / scale_reshape
|
||||
else:
|
||||
scale = scale or t.abs().max().reshape([1]) / quant_max
|
||||
scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item())
|
||||
qt = t / scale
|
||||
qt = qt.to(torch.float8_e4m3fn)
|
||||
return qt, scale
|
||||
|
||||
def _dequantize_fp8e4m3(qt: torch.Tensor, scale: torch.Tensor):
|
||||
dqt = qt.float()
|
||||
if scale.numel() == 1:
|
||||
# per tensor
|
||||
dqt = dqt * scale
|
||||
else:
|
||||
# per channel
|
||||
scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1))
|
||||
dqt = dqt * scale_reshape
|
||||
return dqt
|
||||
|
||||
class TestQuantizedOps(TestCase):
|
||||
|
||||
"""Helper function to test quantized activation functions."""
|
||||
|
|
@ -4678,32 +4704,6 @@ class TestQuantizedLinear(TestCase):
|
|||
qlinear = torch.ops.onednn.qlinear_pointwise.binary
|
||||
self._test_qlinear_pt2e_helper(qlinear, "add_relu")
|
||||
|
||||
def _quantize_fp8e4m3(self, t: torch.Tensor, channelwise: bool, scale: Optional[torch.Tensor] = None):
|
||||
quant_max = torch.finfo(torch.float8_e4m3fn).max
|
||||
eps = torch.Tensor([torch.finfo(torch.float32).eps])
|
||||
if channelwise:
|
||||
scale = scale or t.reshape(t.shape[0], -1).abs().max(-1)[0] / quant_max
|
||||
scale = torch.max(scale, eps)
|
||||
scale_reshape = scale.reshape((-1,) + (1,) * (t.dim() - 1))
|
||||
qt = t / scale_reshape
|
||||
else:
|
||||
scale = scale or t.abs().max().reshape([1]) / quant_max
|
||||
scale = torch.max(scale, eps) if isinstance(scale, torch.Tensor) else max(scale, eps.item())
|
||||
qt = t / scale
|
||||
qt = qt.to(torch.float8_e4m3fn)
|
||||
return qt, scale
|
||||
|
||||
def _dequantize_fp8e4m3(self, qt: torch.Tensor, scale: torch.Tensor):
|
||||
dqt = qt.float()
|
||||
if scale.numel() == 1:
|
||||
# per tensor
|
||||
dqt = dqt * scale
|
||||
else:
|
||||
# per channel
|
||||
scale_reshape = scale.reshape((-1,) + (1,) * (qt.dim() - 1))
|
||||
dqt = dqt * scale_reshape
|
||||
return dqt
|
||||
|
||||
def _test_qlinear_fp8_helper(
|
||||
self,
|
||||
qlinear_op,
|
||||
|
|
@ -4737,16 +4737,16 @@ class TestQuantizedLinear(TestCase):
|
|||
x2_scale, x2_zp = 0.3, 0
|
||||
x = torch.rand(batch_size, (ic + 1), ic) * 10 if input_dim == 3 else torch.rand(batch_size, ic) * 10
|
||||
w = torch.rand(oc, ic) * 10
|
||||
qx, x_scale = self._quantize_fp8e4m3(x, channelwise=False)
|
||||
qw, w_scales = self._quantize_fp8e4m3(w, channelwise=weight_quant_per_channel)
|
||||
qx, x_scale = _quantize_fp8e4m3(x, channelwise=False)
|
||||
qw, w_scales = _quantize_fp8e4m3(w, channelwise=weight_quant_per_channel)
|
||||
if use_bias:
|
||||
b = torch.rand(oc) * 10
|
||||
else:
|
||||
b = None
|
||||
|
||||
# compute reference result
|
||||
x_ref = self._dequantize_fp8e4m3(qx, x_scale)
|
||||
w_ref = self._dequantize_fp8e4m3(qw, w_scales)
|
||||
x_ref = _dequantize_fp8e4m3(qx, x_scale)
|
||||
w_ref = _dequantize_fp8e4m3(qw, w_scales)
|
||||
y_ref = linear_op(x_ref, w_ref, b)
|
||||
|
||||
# compute fp8 linear
|
||||
|
|
@ -4766,8 +4766,8 @@ class TestQuantizedLinear(TestCase):
|
|||
y_ref = F.gelu(y_ref, approximate=post_op_algo)
|
||||
elif post_op in ("sum", "sum_relu"):
|
||||
x2 = torch.rand_like(y_ref)
|
||||
x2_q, x2_scale = self._quantize_fp8e4m3(x2, channelwise=False)
|
||||
x2_dq = self._dequantize_fp8e4m3(x2_q, x2_scale)
|
||||
x2_q, x2_scale = _quantize_fp8e4m3(x2, channelwise=False)
|
||||
x2_dq = _dequantize_fp8e4m3(x2_q, x2_scale)
|
||||
unary_post_op = "relu" if post_op == "sum_relu" else "none"
|
||||
binary_alpha = 1.0 # we only support alpha=1.0 now
|
||||
# if output_dtype is fp32 or bf16, accumulate on x2
|
||||
|
|
@ -4806,7 +4806,7 @@ class TestQuantizedLinear(TestCase):
|
|||
|
||||
# Compare results
|
||||
if output_dtype is None:
|
||||
y_ref = self._quantize_fp8e4m3(y_ref, False, used_y_scale)[0]
|
||||
y_ref = _quantize_fp8e4m3(y_ref, False, used_y_scale)[0]
|
||||
else:
|
||||
y_ref = y_ref.to(output_dtype)
|
||||
|
||||
|
|
@ -7490,10 +7490,10 @@ class TestQuantizedConv(TestCase):
|
|||
qconv_output_dtype=output_dtype,
|
||||
)
|
||||
|
||||
# Test qconv with post op silu
|
||||
# Test qconv with post op swish
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_silu_pt2e(self):
|
||||
def test_qconv2d_swish_pt2e(self):
|
||||
input_channels_per_group = 2
|
||||
output_channels_per_group = 2
|
||||
groups_list = [1, 10]
|
||||
|
|
@ -7815,6 +7815,330 @@ class TestQuantizedConv(TestCase):
|
|||
qconv_output_dtype=output_dtype,
|
||||
)
|
||||
|
||||
def _make_qconv_tensors_fp8(
|
||||
self, batch_size, input_channels_per_group, input_feature_map_shape,
|
||||
output_channels_per_group, groups, kernels, strides, pads, dilations,
|
||||
use_bias, use_channelwise, use_transpose,
|
||||
device=torch.device("cpu"),
|
||||
):
|
||||
assert not (use_channelwise and use_transpose), \
|
||||
"Cannot generate channelwise qconv_transpose_tensors "
|
||||
input_channels = input_channels_per_group * groups
|
||||
output_channels = output_channels_per_group * groups
|
||||
# Padded input size should be at least as big as dilated kernel
|
||||
kernels = _single(kernels)
|
||||
strides = _single(strides)
|
||||
pads = _single(pads)
|
||||
dilations = _single(dilations)
|
||||
for i in range(len(kernels)):
|
||||
assume(input_feature_map_shape[i] + 2 * pads[i]
|
||||
>= dilations[i] * (kernels[i] - 1) + 1)
|
||||
# the operator expects them in the format
|
||||
# (output_channels, input_channels/groups, kernel_d, kernel_h, kernel_w)
|
||||
# (input_channels, output_channels/groups, kernel_d, kernel_h, kernel_w)
|
||||
if use_transpose:
|
||||
output_shape = (input_channels, output_channels_per_group,)
|
||||
else:
|
||||
output_shape = (output_channels, input_channels_per_group,)
|
||||
|
||||
X = torch.rand(
|
||||
(batch_size, input_channels,) + input_feature_map_shape,
|
||||
device=device,
|
||||
)
|
||||
X_q, X_scale = _quantize_fp8e4m3(X, channelwise=False)
|
||||
W = torch.randn(output_shape + kernels, device=device) * 0.1
|
||||
W_q, W_scale = _quantize_fp8e4m3(W, channelwise=use_channelwise)
|
||||
bias_float = torch.randn((output_channels,), device=device) if use_bias else None
|
||||
|
||||
return X, W, X_q, W_q, X_scale, W_scale, bias_float
|
||||
|
||||
def _test_qconv_impl_cpu_tensor_fp8(
|
||||
self,
|
||||
qconv,
|
||||
qconv_prepack,
|
||||
conv_op,
|
||||
input_channels_per_group=2,
|
||||
input_feature_map_shape=(),
|
||||
output_channels_per_group=2,
|
||||
groups=1,
|
||||
kernels=3,
|
||||
strides=(),
|
||||
pads=(),
|
||||
dilations=(),
|
||||
Y_scale=0.02,
|
||||
use_bias=True,
|
||||
post_op=PointwisePostOp(),
|
||||
use_channelwise=True,
|
||||
X2_scale=0.02,
|
||||
qconv_output_dtype=None, # None, torch.float32, torch.bfloat16
|
||||
weight_in_channel_last_format=False,
|
||||
):
|
||||
# We assume FP8 quantization is always symmetric
|
||||
fp32_output = True if qconv_output_dtype is torch.float32 else False
|
||||
bfloat16_output = True if qconv_output_dtype is torch.bfloat16 else False
|
||||
if fp32_output or bfloat16_output:
|
||||
Y_scale = 1.0
|
||||
X2_scale = 1.0
|
||||
batch_size = 3
|
||||
device = torch.device("cpu")
|
||||
use_transpose = False
|
||||
X, W, X_q, W_q, X_scale, W_scale, bias_float = self._make_qconv_tensors_fp8(
|
||||
batch_size,
|
||||
input_channels_per_group,
|
||||
input_feature_map_shape,
|
||||
output_channels_per_group,
|
||||
groups,
|
||||
kernels,
|
||||
strides,
|
||||
pads,
|
||||
dilations,
|
||||
use_bias,
|
||||
use_channelwise,
|
||||
use_transpose,
|
||||
device=device,
|
||||
)
|
||||
# Assign weights
|
||||
dqW = _dequantize_fp8e4m3(W_q, W_scale)
|
||||
dqX = _dequantize_fp8e4m3(X_q, X_scale)
|
||||
conv_op.weight = torch.nn.Parameter(dqW, requires_grad=False)
|
||||
conv_op.bias = (
|
||||
torch.nn.Parameter(bias_float, requires_grad=False) if use_bias else None
|
||||
)
|
||||
result_ref = conv_op(dqX)
|
||||
X2 = None
|
||||
X2_q = None
|
||||
X2_scale = 1.0
|
||||
|
||||
if post_op.binary_attr == "sum":
|
||||
X2_dtype = qconv_output_dtype if qconv_output_dtype else torch.float32
|
||||
X2 = torch.rand_like(result_ref, device=device, dtype=X2_dtype)
|
||||
if qconv_output_dtype is None:
|
||||
X2_q, X2_scale = _quantize_fp8e4m3(X2, channelwise=False)
|
||||
X2_dq = _dequantize_fp8e4m3(X2_q, X2_scale)
|
||||
X2_scale = X2_scale.item()
|
||||
else:
|
||||
X2_dq = X2
|
||||
result_ref = result_ref + X2_dq
|
||||
if post_op.unary_attr == "relu":
|
||||
relu = torch.nn.ReLU()
|
||||
result_ref = relu(result_ref)
|
||||
elif post_op.unary_attr == "relu":
|
||||
assert not use_transpose, "Cannot fuse ReLU with ConvTranspose"
|
||||
relu = torch.nn.ReLU()
|
||||
result_ref = relu(result_ref)
|
||||
elif post_op.unary_attr == "hardtanh":
|
||||
assert not use_transpose, "Cannot fuse hardtanh with ConvTranspose"
|
||||
assert len(post_op.scalars) == 2, "For post op hardtanh, expect 2 parameters passed in"
|
||||
hardtanh = torch.nn.Hardtanh(min_val=post_op.scalars[0], max_val=post_op.scalars[1])
|
||||
result_ref = hardtanh(result_ref)
|
||||
elif post_op.unary_attr == "hardswish":
|
||||
assert not use_transpose, "Cannot fuse hardswish with ConvTranspose"
|
||||
hardswish = torch.nn.Hardswish()
|
||||
result_ref = hardswish(result_ref)
|
||||
elif post_op.unary_attr == "swish":
|
||||
assert not use_transpose, "Cannot fuse silu with ConvTranspose"
|
||||
silu = torch.nn.SiLU()
|
||||
result_ref = silu(result_ref)
|
||||
|
||||
# Quantize reference results for comparison
|
||||
if qconv_output_dtype is None:
|
||||
Y_scale_t = torch.Tensor([Y_scale]).to(device)
|
||||
# Align with oneDNN: convert fp32 to fp8 by fp32 -> fp16 -> fp8
|
||||
result_ref = result_ref.div(Y_scale_t).half().to(torch.float8_e4m3fn)
|
||||
else:
|
||||
result_ref = result_ref.to(qconv_output_dtype)
|
||||
|
||||
# Calculate the result for PT2E path
|
||||
if weight_in_channel_last_format:
|
||||
if W_q.dim() == 5:
|
||||
W_q = W_q.to(memory_format=torch.channels_last_3d)
|
||||
elif W_q.dim() == 4:
|
||||
W_q = W_q.to(memory_format=torch.channels_last)
|
||||
|
||||
X_scale_scalar = X_scale.item()
|
||||
packed_weight = qconv_prepack(
|
||||
W_q,
|
||||
W_scale,
|
||||
X_scale_scalar,
|
||||
0, # X_zero_point
|
||||
strides,
|
||||
pads,
|
||||
dilations,
|
||||
groups,
|
||||
X_q.size(),
|
||||
)
|
||||
|
||||
if post_op.binary_attr == "sum":
|
||||
accum = (
|
||||
X2_q.contiguous(memory_format=torch.channels_last)
|
||||
if X2_q is not None
|
||||
else X2.contiguous(memory_format=torch.channels_last)
|
||||
)
|
||||
result = qconv(
|
||||
X_q,
|
||||
X_scale_scalar,
|
||||
0, # X_zero_point
|
||||
packed_weight,
|
||||
W_scale,
|
||||
torch.zeros([], dtype=torch.int8), # W_zero_point
|
||||
accum,
|
||||
bias_float,
|
||||
strides,
|
||||
pads,
|
||||
dilations,
|
||||
groups,
|
||||
Y_scale,
|
||||
0, # Y_zero_point
|
||||
qconv_output_dtype,
|
||||
X2_scale,
|
||||
0, # X2_zero_point
|
||||
post_op.binary_attr,
|
||||
post_op.alpha,
|
||||
post_op.unary_attr,
|
||||
post_op.scalars,
|
||||
post_op.algorithm,
|
||||
)
|
||||
else:
|
||||
result = qconv(
|
||||
X_q,
|
||||
X_scale_scalar,
|
||||
0, # X_zero_point
|
||||
packed_weight,
|
||||
W_scale,
|
||||
torch.zeros([], dtype=torch.int8), # W_zero_point
|
||||
bias_float,
|
||||
strides,
|
||||
pads,
|
||||
dilations,
|
||||
groups,
|
||||
Y_scale,
|
||||
0, # Y_zero_point
|
||||
qconv_output_dtype,
|
||||
post_op.unary_attr,
|
||||
post_op.scalars,
|
||||
post_op.algorithm,
|
||||
)
|
||||
if fp32_output or bfloat16_output:
|
||||
self.assertTrue(result.dtype == qconv_output_dtype)
|
||||
|
||||
assert torch.allclose(result.float(), result_ref.float(), atol=1e-6)
|
||||
|
||||
def _test_qconv_fp8_helper(self, nd, pointwise_post_op):
|
||||
# nd = 1,2,3 -> conv1d/2d/3d
|
||||
if pointwise_post_op.binary_attr != "none":
|
||||
# Only conv2d supports binary post op
|
||||
assert nd == 2
|
||||
groups_list = [1, 3]
|
||||
input_channels_per_group = 2
|
||||
output_channels_per_group = 2
|
||||
length = 4
|
||||
kernel = 3
|
||||
stride = 1
|
||||
pad = 1
|
||||
dilation = 1
|
||||
use_bias_list = [False, True]
|
||||
use_channelwise_list = [False, True]
|
||||
output_dtype_list = [None, torch.float32, torch.bfloat16]
|
||||
options = itertools.product(groups_list, use_bias_list, use_channelwise_list, output_dtype_list)
|
||||
for groups, use_bias, use_channelwise, output_dtype in options:
|
||||
if output_dtype is not None and not (use_bias and use_channelwise):
|
||||
# Remove some test combination to reduce UT test time
|
||||
continue
|
||||
conv_mod = getattr(torch.nn, f"Conv{nd}d")(
|
||||
input_channels_per_group * groups,
|
||||
output_channels_per_group * groups,
|
||||
kernel,
|
||||
stride,
|
||||
pad,
|
||||
dilation,
|
||||
groups,
|
||||
)
|
||||
qconv = (
|
||||
torch.ops.onednn.qconv_pointwise
|
||||
if pointwise_post_op.binary_attr == "none"
|
||||
else torch.ops.onednn.qconv2d_pointwise.binary
|
||||
)
|
||||
qconv_prepack = torch.ops.onednn.qconv_prepack
|
||||
self._test_qconv_impl_cpu_tensor_fp8(
|
||||
qconv,
|
||||
qconv_prepack,
|
||||
conv_mod,
|
||||
input_channels_per_group=input_channels_per_group,
|
||||
input_feature_map_shape=(length,) * nd,
|
||||
output_channels_per_group=output_channels_per_group,
|
||||
groups=groups,
|
||||
kernels=[kernel] * nd,
|
||||
strides=[stride] * nd,
|
||||
pads=[pad] * nd,
|
||||
dilations=[dilation] * nd,
|
||||
use_bias=use_bias,
|
||||
post_op=pointwise_post_op,
|
||||
use_channelwise=use_channelwise,
|
||||
qconv_output_dtype=output_dtype,
|
||||
)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv1d_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp()
|
||||
self._test_qconv_fp8_helper(1, pointwise_post_op)
|
||||
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv1d_relu_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp(unary_attr="relu")
|
||||
self._test_qconv_fp8_helper(1, pointwise_post_op)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp()
|
||||
self._test_qconv_fp8_helper(2, pointwise_post_op)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_relu_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp(unary_attr="relu")
|
||||
self._test_qconv_fp8_helper(2, pointwise_post_op)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_hardtanh_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp(unary_attr="hardtanh", scalars=[0.0, 6.0])
|
||||
self._test_qconv_fp8_helper(2, pointwise_post_op)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_swish_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp(unary_attr="swish")
|
||||
self._test_qconv_fp8_helper(2, pointwise_post_op)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_hardswish_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp(unary_attr="hardswish")
|
||||
self._test_qconv_fp8_helper(2, pointwise_post_op)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_sum_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp(binary_attr="sum")
|
||||
self._test_qconv_fp8_helper(2, pointwise_post_op)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv2d_sum_relu_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp(binary_attr="sum", unary_attr="relu")
|
||||
self._test_qconv_fp8_helper(2, pointwise_post_op)
|
||||
|
||||
@unittest.skipIf(IS_FBCODE, "Skip pt2e ops in fbcode")
|
||||
@skipIfNoONEDNN
|
||||
def test_qconv3d_fp8(self):
|
||||
pointwise_post_op = PointwisePostOp()
|
||||
self._test_qconv_fp8_helper(3, pointwise_post_op)
|
||||
|
||||
|
||||
|
||||
class TestPadding(TestCase):
|
||||
@given(batch_size=st.integers(1, 64),
|
||||
|
|
|
|||
|
|
@ -2720,10 +2720,24 @@ if torch._C._has_mkldnn:
|
|||
groups,
|
||||
None,
|
||||
)
|
||||
assert output_dtype in [torch.float32, torch.bfloat16, torch.uint8, torch.int8]
|
||||
if output_dtype is None:
|
||||
output_dtype = x.dtype
|
||||
assert output_dtype in [
|
||||
torch.float32,
|
||||
torch.bfloat16,
|
||||
torch.uint8,
|
||||
torch.int8,
|
||||
torch.float8_e4m3fn,
|
||||
]
|
||||
out = x.new_empty(shape_out, dtype=output_dtype)
|
||||
assert len(shape_out) in [3, 4], "only conv1d/2d are supported"
|
||||
format = torch.channels_last if len(shape_out) == 4 else torch.contiguous_format
|
||||
assert len(shape_out) in [3, 4, 5], (
|
||||
"Expect output to be 3d/4d/5d for conv1d/2d/3d"
|
||||
)
|
||||
format = {
|
||||
3: torch.contiguous_format,
|
||||
4: torch.channels_last,
|
||||
5: torch.channels_last_3d,
|
||||
}[len(shape_out)]
|
||||
out = out.to(memory_format=format)
|
||||
return out
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user