Merge pull request #27698 from abhishek-gola:add_cast_layer

Added Cast and CastLike layer support to the new DNN engine #27698

### Pull Request Readiness Checklist

See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request

- [x] I agree to contribute to the project under Apache 2 License.
- [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV
- [x] The PR is proposed to the proper branch
- [x] There is a reference to the original bug report and related work
- [x] There are accuracy tests, performance tests, and test data in the opencv_extra repository, if applicable.
      The patch to opencv_extra has the same branch name.
- [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
Abhishek Gola 2025-10-04 15:52:30 +05:30 committed by GitHub
parent 9c75140b1b
commit 2470c07f1b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 609 additions and 84 deletions

View File

@ -1414,6 +1414,11 @@ CV__DNN_INLINE_NS_BEGIN
static Ptr<CastLayer> create(const LayerParams &params);
};
/** @brief Type-conversion layer registered in the layer factory under the name "Cast2".
 *
 * Instances are obtained through create(); the concrete implementation is hidden
 * behind this interface.
 */
class CV_EXPORTS Cast2Layer : public Layer
{
public:
    /** @brief Creates a Cast2 layer instance configured from @p params. */
    static Ptr<Cast2Layer> create(const LayerParams& params);
};
class CV_EXPORTS NonMaxSuppressionLayer : public Layer
{
public:

View File

@ -63,6 +63,12 @@ ov::element::Type cvTypeToOvType(MatType cvType)
return ov::element::i64;
case CV_Bool:
return ov::element::boolean;
case CV_16F:
return ov::element::f16;
case CV_16BF:
return ov::element::bf16;
case CV_64F:
return ov::element::f64;
default:
CV_Error(Error::StsNotImplemented, format("Unsupported data type %s", typeToString(cvType).c_str()));
}
@ -88,6 +94,12 @@ MatType ovTypeToCvType(ov::element::Type ovType)
return CV_64S;
case ov::element::boolean:
return CV_Bool;
case ov::element::f16:
return CV_16F;
case ov::element::bf16:
return CV_16BF;
case ov::element::f64:
return CV_64F;
default:
CV_Error(Error::StsNotImplemented, format("Unsupported data type %s", ovType.get_type_name().c_str()));
}

View File

@ -193,6 +193,7 @@ void initializeLayerFactory()
CV_DNN_REGISTER_LAYER_CLASS(Attention, AttentionLayer);
CV_DNN_REGISTER_LAYER_CLASS(GroupNormalization, GroupNormLayer);
CV_DNN_REGISTER_LAYER_CLASS(Cast, CastLayer);
CV_DNN_REGISTER_LAYER_CLASS(Cast2, Cast2Layer);
CV_DNN_REGISTER_LAYER_CLASS(DepthToSpace, DepthToSpaceLayer)
CV_DNN_REGISTER_LAYER_CLASS(SpaceToDepth, SpaceToDepthLayer)
CV_DNN_REGISTER_LAYER_CLASS(DepthToSpaceInt8, DepthToSpaceLayer)

View File

@ -0,0 +1,373 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level directory
// of this distribution and at http://opencv.org/license.html.
#include "../precomp.hpp"
#include "../op_inf_engine.hpp"
#include "../ie_ngraph.hpp"
#include "layers_common.hpp"
#include "../net_impl.hpp"
#include "opencv-onnx.pb.h"
namespace cv { namespace dnn {
// ONNX Cast operator
// Spec: https://onnx.ai/onnx/operators/onnx__Cast.html
// Supported opsets: 1-24
// ONNX CastLike operator
// Spec: https://onnx.ai/onnx/operators/onnx__CastLike.html
// Supported opsets: 15-24
namespace
{
inline void castQuantized(const Mat& src, Mat& dst, int targetDepth)
{
if (targetDepth == CV_16F)
{
CV_Assert(dst.depth() == CV_32F);
if (src.depth() == CV_32F)
{
MatConstIterator_<float> sIt = src.begin<float>(), sEnd = src.end<float>();
MatIterator_<float> dIt = dst.begin<float>();
for (; sIt != sEnd; ++sIt, ++dIt)
{
*dIt = (float)hfloat(*sIt);
}
}
else if (src.depth() == CV_64F)
{
MatConstIterator_<double> sIt = src.begin<double>(), sEnd = src.end<double>();
MatIterator_<float> dIt = dst.begin<float>();
for (; sIt != sEnd; ++sIt, ++dIt)
{
float v = (float)*sIt;
*dIt = (float)hfloat(v);
}
}
else
{
Mat src32; src.convertTo(src32, CV_32F);
MatConstIterator_<float> sIt = src32.begin<float>(), sEnd = src32.end<float>();
MatIterator_<float> dIt = dst.begin<float>();
for (; sIt != sEnd; ++sIt, ++dIt)
{
*dIt = (float)hfloat(*sIt);
}
}
return;
}
if (targetDepth == CV_16BF)
{
const int ddepth = dst.depth();
if (!(ddepth == CV_16BF || ddepth == CV_16U))
{
CV_Error(Error::StsNotImplemented, "Unsupported destination depth for BF16 cast");
}
Mat dst_bits(dst.size(), CV_MAKETYPE(CV_16U, dst.channels()), dst.data, dst.step);
const Mat* src32p;
Mat src32;
if (src.depth() == CV_32F)
src32p = &src;
else
{
src.convertTo(src32, CV_32F);
src32p = &src32;
}
const int rows = src32p->rows;
const int cols_x_cn = src32p->cols * src32p->channels();
for (int r = 0; r < rows; ++r)
{
const float* in = src32p->ptr<float>(r);
ushort* out = dst_bits.ptr<ushort>(r);
for (int i = 0; i < cols_x_cn; ++i)
{
Cv32suf u; u.f = in[i];
out[i] = (ushort)(u.u >> 16);
}
}
return;
}
src.convertTo(dst, dst.depth());
}
}
class Cast2LayerImpl CV_FINAL : public Cast2Layer
{
public:
Cast2LayerImpl(const LayerParams& params)
{
setParamsFrom(params);
hasToParam = false;
toCvDepth_ = -1;
if (params.has("to"))
{
hasToParam = true;
toCvDepth_ = mapToCvDepth(params.get<int>("to"));
}
else if (params.has("outputType"))
{
const int v = params.get<int>("outputType");
if (v == CV_Bool || v == CV_8U || v == CV_8S || v == CV_16U || v == CV_16S ||
v == CV_32S || v == CV_64S || v == CV_32F || v == CV_64F || v == CV_16F || v == CV_16BF)
{
hasToParam = true;
toCvDepth_ = v;
}
else
{
CV_Error(Error::StsNotImplemented, "Cast: unsupported 'outputType' value");
}
}
}
// Reports the backends this layer can run on: the built-in OpenCV backend and
// the nGraph-based inference engine backend.
virtual bool supportBackend(int backendId) CV_OVERRIDE
{
    switch (backendId)
    {
    case DNN_BACKEND_OPENCV:
    case DNN_BACKEND_INFERENCE_ENGINE_NGRAPH:
        return true;
    default:
        return false;
    }
}
// Shape inference: the single output has exactly the shape of the first input.
// A second input (CastLike) is accepted, but its shape is not used.
// Always returns false.
virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
                             const int requiredOutputs,
                             std::vector<MatShape> &outputs,
                             std::vector<MatShape> &internals) const CV_OVERRIDE
{
    CV_Check(inputs.size(), inputs.size() == 1 || inputs.size() == 2, "Cast expects 1 (Cast) or 2 (CastLike) inputs");

    const MatShape& dataShape = inputs.front();
    outputs.assign(1, dataShape);
    return false;
}
// Type inference for the single output.
//
// The target depth is resolved in this order:
//   1. the explicit "to"/"outputType" parameter;
//   2. the type recorded in the network for the second layer argument (CastLike);
//   3. the type of the second input as passed in `inputs`;
//   4. the type of the first input;
//   5. CV_32F as the last resort.
// Steps 3 and 4 mirror the fallback chain used by forward(), so the planned
// output type and the type produced at run time cannot diverge.
//
// FP16 results are planned as CV_32F storage and BF16 results as CV_16U storage
// (raw 16-bit patterns). The output keeps the channel count of the first input.
virtual void getTypes(const std::vector<MatType>& inputs,
                      const int requiredOutputs,
                      const int requiredInternals,
                      std::vector<MatType>& outputs,
                      std::vector<MatType>& internals) const CV_OVERRIDE
{
    CV_Check(inputs.size(), !inputs.empty(), "Cast expects at least 1 input");

    const int in0Type = inputs[0];

    int targetDepth = -1;
    if (hasToParam)
    {
        targetDepth = toCvDepth_;
    }
    else
    {
        Net::Impl* netimpl_ = getNetImpl(const_cast<Cast2LayerImpl*>(this));
        if (netimpl_ && this->inputs.size() >= 2)
        {
            const Arg& in1_arg = this->inputs[1];
            if (in1_arg.idx >= 0)
            {
                const ArgData& ad = netimpl_->argData(in1_arg);
                if (ad.type >= 0)
                    targetDepth = CV_MAT_DEPTH(ad.type);
            }
        }

        // Same fallbacks as forward(): "like" input first, then the data input.
        if (targetDepth < 0 && inputs.size() >= 2)
        {
            const int likeType = inputs[1];
            if (likeType >= 0)
                targetDepth = CV_MAT_DEPTH(likeType);
        }
        if (targetDepth < 0 && in0Type >= 0)
            targetDepth = CV_MAT_DEPTH(in0Type);
    }
    if (targetDepth < 0)
    {
        targetDepth = CV_32F;
    }

    const int in0CN = in0Type >= 0 ? CV_MAT_CN(in0Type) : 1;

    // Storage depth actually allocated for the output.
    int planDepth = targetDepth;
    if (planDepth == CV_16F) planDepth = CV_32F;
    if (planDepth == CV_16BF) planDepth = CV_16U;

    const int outType = CV_MAKETYPE(planDepth, in0CN);
    outputs.assign(1, outType);
}
#ifdef HAVE_OPENCL
// OpenCL implementation of the cast.
//
// Returns true when the conversion was done on UMat data, false to request the
// CPU fallback in forward(). The fallback is requested for targets that are only
// emulated in wider/other storage:
//   * FP16 kept in CV_32F storage (needs explicit rounding to half precision);
//   * BF16 kept as raw 16-bit patterns (a numeric convertTo() into CV_16U would
//     produce saturated integers instead of BF16 bit patterns).
//
// Both Cast (1 input) and CastLike (2 inputs) are accepted; previously the input
// count was checked to be exactly 1, which made the CastLike branch unreachable
// and raised an error for two-input layers.
bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_)
{
    std::vector<UMat> inputs, outputs;
    inputs_.getUMatVector(inputs);
    outputs_.getUMatVector(outputs);

    CV_Check(inputs.size(), inputs.size() == 1 || inputs.size() == 2, "Cast expects 1 (Cast) or 2 (CastLike) inputs");
    CV_CheckEQ(outputs.size(), (size_t)1, "");

    int runtimeTargetDepth = -1;
    if (hasToParam)
    {
        runtimeTargetDepth = toCvDepth_;
    }
    else
    {
        // CastLike: take the type of the second input; plain Cast without an
        // explicit target keeps the input type.
        if (inputs.size() >= 2 && !inputs[1].empty())
            runtimeTargetDepth = inputs[1].depth();
        else
            runtimeTargetDepth = inputs[0].depth();
    }

    const int outDepth = outputs[0].depth();
    if (runtimeTargetDepth == CV_16F && outDepth == CV_32F)
    {
        return false;
    }
    if (runtimeTargetDepth == CV_16BF)
    {
        return false;
    }

    if (inputs[0].depth() == outDepth)
        inputs[0].copyTo(outputs[0]);
    else
        inputs[0].convertTo(outputs[0], outDepth);
    return true;
}
#endif
// CPU implementation of Cast / CastLike.
//
// The run-time target depth is resolved from (in order): the explicit
// "to"/"outputType" parameter, the network-recorded type of the second layer
// argument, the actual depth of the second input, and finally the depth of the
// first input. FP16 targets are stored as CV_32F values rounded to half
// precision; BF16 targets are stored as raw 16-bit patterns in CV_16U.
void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
{
    CV_TRACE_FUNCTION();
    CV_TRACE_ARG_VALUE(name, "name", name.c_str());

    CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
               forward_ocl(inputs_arr, outputs_arr, internals_arr));

    std::vector<Mat> inputs, outputs;
    inputs_arr.getMatVector(inputs);
    outputs_arr.getMatVector(outputs);

    CV_Check(inputs.size(), inputs.size() == 1 || inputs.size() == 2, "Cast expects 1 (Cast) or 2 (CastLike) inputs");
    CV_CheckEQ(outputs.size(), (size_t)1, "");

    const Mat& src0 = inputs[0];
    Mat& dst0 = outputs[0];

    int runtimeTargetDepth = -1;
    if (hasToParam)
    {
        runtimeTargetDepth = toCvDepth_;
    }
    else
    {
        Net::Impl* netimpl_ = getNetImpl(this);
        if (netimpl_ && this->inputs.size() >= 2)
        {
            const Arg& in1_arg = this->inputs[1];
            // Same guard as in getTypes(): only look up valid arguments.
            if (in1_arg.idx >= 0)
            {
                const ArgData& ad = netimpl_->argData(in1_arg);
                if (ad.type >= 0)
                    runtimeTargetDepth = CV_MAT_DEPTH(ad.type);
            }
        }
        if (runtimeTargetDepth < 0 && inputs.size() >= 2 && !inputs[1].empty())
            runtimeTargetDepth = inputs[1].depth();
        if (runtimeTargetDepth < 0)
            runtimeTargetDepth = src0.depth();
    }
    CV_CheckGE(runtimeTargetDepth, 0, "Cast: failed to resolve target data type at runtime");

    // Storage depth used for the result (see getTypes()).
    const int plannedDDepth = (runtimeTargetDepth == CV_16F) ? CV_32F :
                              (runtimeTargetDepth == CV_16BF ? CV_16U : runtimeTargetDepth);

    if (dst0.depth() != plannedDDepth)
    {
        // Re-create the output with the full N-D shape of the input. The previous
        // code used dst0.size(), which is only meaningful for 2-D matrices.
        // NOTE(review): this replaces the local Mat header only; whether the new
        // buffer becomes visible through outputs_arr should be confirmed.
        dst0.create(src0.dims, src0.size.p, CV_MAKETYPE(plannedDDepth, src0.channels()));
    }

    const int sdepth = src0.depth();
    const int ddepth = dst0.depth();  // equals plannedDDepth from here on

    if (sdepth == ddepth && sdepth == runtimeTargetDepth)
    {
        // Source, target and storage agree: plain copy. BF16 -> BF16 is not
        // handled here on purpose: its storage is CV_16U, and copyTo() would
        // retype (and therefore reallocate) the destination.
        src0.copyTo(dst0);
        return;
    }

    if (runtimeTargetDepth == CV_16BF)
    {
        // Packs BF16 bit patterns into the 16-bit storage.
        castQuantized(src0, dst0, CV_16BF);
    }
    else if (runtimeTargetDepth == CV_16F)
    {
        // Rounds through half precision (also for BF16 sources) into CV_32F storage.
        castQuantized(src0, dst0, CV_16F);
    }
    else
    {
        src0.convertTo(dst0, ddepth);
    }
}
#ifdef HAVE_DNN_NGRAPH
// Builds the OpenVINO graph node for this layer: a Convert operation applied to
// the first input node. The destination element type comes from the explicit
// target parameter when available, otherwise from the element type of the second
// node (CastLike), or of the first node when there is only one.
virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
                                    const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
{
    ov::element::Type dstType;
    if (hasToParam)
    {
        dstType = cvTypeToOvType(CV_MAKETYPE(toCvDepth_, 1));
    }
    else
    {
        // Index of the node whose element type defines the result type.
        const size_t typeSourceIdx = (nodes.size() >= 2) ? 1 : 0;
        dstType = nodes[typeSourceIdx].dynamicCast<InfEngineNgraphNode>()->node.get_element_type();
    }

    Ptr<InfEngineNgraphNode> dataNode = nodes[0].dynamicCast<InfEngineNgraphNode>();
    auto cast = std::make_shared<ov::op::v0::Convert>(dataNode->node, dstType);
    return Ptr<BackendNode>(new InfEngineNgraphNode(cast));
}
#endif // HAVE_DNN_NGRAPH
private:
bool hasToParam = false;
int toCvDepth_ = -1;
// Translates an ONNX TensorProto data type code into the matching OpenCV depth.
// Raises Error::StsNotImplemented for every code that has no entry in the table.
static int mapToCvDepth(int v)
{
    struct DTypePair
    {
        int onnxType;
        int cvDepth;
    };
    static const DTypePair kSupported[] = {
        { opencv_onnx::TensorProto_DataType_FLOAT,    CV_32F  },
        { opencv_onnx::TensorProto_DataType_UINT8,    CV_8U   },
        { opencv_onnx::TensorProto_DataType_INT8,     CV_8S   },
        { opencv_onnx::TensorProto_DataType_UINT16,   CV_16U  },
        { opencv_onnx::TensorProto_DataType_INT16,    CV_16S  },
        { opencv_onnx::TensorProto_DataType_INT32,    CV_32S  },
        { opencv_onnx::TensorProto_DataType_INT64,    CV_64S  },
        { opencv_onnx::TensorProto_DataType_BOOL,     CV_Bool },
        { opencv_onnx::TensorProto_DataType_FLOAT16,  CV_16F  },
        { opencv_onnx::TensorProto_DataType_DOUBLE,   CV_64F  },
        { opencv_onnx::TensorProto_DataType_BFLOAT16, CV_16BF },
    };

    for (const DTypePair& entry : kSupported)
    {
        if (entry.onnxType == v)
            return entry.cvDepth;
    }
    CV_Error(Error::StsNotImplemented, "Cast: unsupported 'to' / dtype value");
}
int resolveTargetDepthAtTypeTime(const std::vector<MatType>& inputs) const
{
if (hasToParam)
return toCvDepth_;
if (inputs.size() == 2)
{
int likeType = inputs[1];
if (likeType >= 0)
return CV_MAT_DEPTH(likeType);
return -1;
}
return CV_MAT_DEPTH(inputs[0]);
}
};
// Factory entry point: instantiates the concrete Cast2 implementation and hands
// it out through the public interface type.
Ptr<Cast2Layer> Cast2Layer::create(const LayerParams& params)
{
    Ptr<Cast2LayerImpl> impl = makePtr<Cast2LayerImpl>(params);
    return impl;
}
}} // namespace cv::dnn

View File

@ -168,8 +168,32 @@ Mat DiagonalInnermostDims(const Mat& input, bool preserve_innermost_dim_val) {
output_dims[rank - 1] = 1;
}
// TODO: handle different types
Mat output = DiagonalDataAssignment<float>(input);
Mat output;
switch (input.depth())
{
case CV_32F:
output = DiagonalDataAssignment<float>(input);
break;
case CV_64F:
output = DiagonalDataAssignment<double>(input);
break;
case CV_16F:
{
Mat tmp32;
input.convertTo(tmp32, CV_32F);
Mat out32 = DiagonalDataAssignment<float>(tmp32);
out32.convertTo(output, input.type());
break;
}
default:
{
Mat tmp32;
input.convertTo(tmp32, CV_32F);
Mat out32 = DiagonalDataAssignment<float>(tmp32);
out32.convertTo(output, input.type());
break;
}
}
if (output_dims != shape(output)){
CV_Error(Error::StsError, "Output shape does not match with calculated shape");
@ -420,7 +444,26 @@ public:
backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
}
// getMemoryShapes
// Type inference for Einsum: all inputs must share one type, which has to be
// CV_16F, CV_32F or CV_64F. The output and every requested internal buffer use
// that same type.
//
// The accepted set does not depend on the preferable target: the former
// OpenCL-FP16 / default branches checked the same three types, so a single check
// is sufficient.
virtual void getTypes(const std::vector<MatType>& inputs,
                      const int requiredOutputs,
                      const int requiredInternals,
                      std::vector<MatType>& outputs,
                      std::vector<MatType>& internals) const CV_OVERRIDE
{
    CV_Assert(!inputs.empty());
    for (const int t : inputs)
        CV_CheckTypeEQ(t, inputs[0], "All Einsum inputs must have the same type");

    CV_CheckType(inputs[0], inputs[0] == CV_32F || inputs[0] == CV_64F || inputs[0] == CV_16F, "");

    outputs.assign(1, inputs[0]);
    internals.assign(requiredInternals, inputs[0]);
}
// getMemoryShapes
bool getMemoryShapes(const std::vector<MatShape> &inputs,
const int requiredOutputs,
std::vector<MatShape> &outputs,
@ -586,16 +629,27 @@ Mat LayerEinsumImpl::reduceSum(Mat& src, MatShape& reduceAxis)
std::vector<MatShape> outputShapes, internalShapes;
reduce->getMemoryShapes(inputShapes, 1, outputShapes, internalShapes);
int origType = src.type();
Mat src32 = src;
if (src.depth() != CV_32F)
src.convertTo(src32, CV_32F);
Mat output(outputShapes[0], CV_32F);
std::vector<Mat> inputs;
std::vector<Mat> outputs;
std::vector<Mat> internals;
inputs.emplace_back(src);
inputs.emplace_back(src32);
outputs.emplace_back(output);
reduce->forward(inputs, outputs, internals);
return outputs[0];
Mat out = outputs[0];
if (CV_MAT_TYPE(origType) != CV_32F)
{
Mat converted;
out.convertTo(converted, origType);
return converted;
}
return out;
}
void LayerEinsumImpl::preProcessInputs(InputArrayOfArrays& inputs_arr)
@ -1374,48 +1428,61 @@ Mat LayerEinsumImpl::batchwiseMatMul(
Mat reshapedInput1 = input1;
Mat reshapedInput2 = input2;
int origType = input1.type();
Mat a = reshapedInput1, b = reshapedInput2;
if (input1.depth() != CV_32F)
{
reshapedInput1.convertTo(a, CV_32F);
reshapedInput2.convertTo(b, CV_32F);
}
Mat output;
if (batches > 1)
{
// create tmpout with type like input1
output = Mat({batches, M, N}, input1.type());
output = Mat({batches, M, N}, CV_32F);
reshapedInput2 = reshapedInput2.reshape(1, input2ShapeOverride);
reshapedInput1 = reshapedInput1.reshape(1, input1ShapeOverride);
b = b.reshape(1, input2ShapeOverride);
a = a.reshape(1, input1ShapeOverride);
fastGemmBatch(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
fastGemmBatch(false, false, 1.0, a, b, 0.0, output, opt);
} else {
// input1 should of size MxK
// check if input1 needs reshape, if need reshape
if (input1.dims > 2 || input1.size[0] != M || (input1.dims > 1 && input1.size[1] != K) || input1.dims == 1)
if (reshapedInput1.dims > 2 || reshapedInput1.size[0] != M || (reshapedInput1.dims > 1 && reshapedInput1.size[1] != K) || reshapedInput1.dims == 1)
{
int shape[] = {M, K};
reshapedInput1 = input1.reshape(1, 2, shape);
a = a.reshape(1, 2, shape);
}
// input2 should be of size KxN
// check if input2 needs reshape, if needs reshape
if (input2.dims > 2 || input2.size[0] != K || (input2.dims > 1 && input2.size[1] != N) || input2.dims == 1)
if (reshapedInput2.dims > 2 || reshapedInput2.size[0] != K || (reshapedInput2.dims > 1 && reshapedInput2.size[1] != N) || reshapedInput2.dims == 1)
{
int shape2[] = {K, N};
reshapedInput2 = input2.reshape(1, 2, shape2);
b = b.reshape(1, 2, shape2);
}
output = Mat(M, N, reshapedInput1.type());
if ((reshapedInput1.dims == 0 && reshapedInput2.dims == 0) ||
(reshapedInput1.dims == 0 && reshapedInput2.dims != 0) ||
(reshapedInput1.dims != 0 && reshapedInput2.dims == 0))
output = Mat(M, N, CV_32F);
if ((a.dims == 0 && b.dims == 0) ||
(a.dims == 0 && b.dims != 0) ||
(a.dims != 0 && b.dims == 0))
{
output = reshapedInput1.mul(reshapedInput2); // fastGemm does not support 0D * 0D multiplication
output = a.mul(b); // fastGemm does not support 0D * 0D multiplication
} else {
fastGemm(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
fastGemm(false, false, 1.0, a, b, 0.0, output, opt);
}
output = output.reshape(1, {1, M, N});
}
if (CV_MAT_TYPE(origType) != CV_32F)
{
Mat converted;
output.convertTo(converted, origType);
return converted;
}
return output;
};
Ptr<EinsumLayer> EinsumLayer::create(const LayerParams& params)

View File

@ -397,9 +397,9 @@ public:
{
CV_CheckTypeEQ(inputs[0], input, "All inputs should have equal types");
if (preferableTarget == DNN_TARGET_OPENCL_FP16)
CV_CheckType(input, input == CV_16F || input == CV_8S || input == CV_8U || input == CV_16S || input == CV_16U || input == CV_32S || input == CV_32U || input == CV_64S || input == CV_64U, "");
CV_CheckType(input, input == CV_16F || input == CV_32F || input == CV_64F || input == CV_8S || input == CV_8U || input == CV_16S || input == CV_16U || input == CV_32S || input == CV_32U || input == CV_64S || input == CV_64U, "");
else
CV_CheckType(input, input == CV_32F || input == CV_8S || input == CV_8U || input == CV_16S || input == CV_16U || input == CV_32S || input == CV_32U || input == CV_64S || input == CV_64U, "");
CV_CheckType(input, input == CV_32F || input == CV_64F || input == CV_8S || input == CV_8U || input == CV_16S || input == CV_16U || input == CV_32S || input == CV_32U || input == CV_64S || input == CV_64U, "");
}
if (op == OPERATION::EQUAL || op == OPERATION::GREATER || op == OPERATION::GREATER_EQUAL || op == OPERATION::LESS || op == OPERATION::LESS_EQUAL)

View File

@ -559,12 +559,11 @@ void Net::Impl::setGraphInput(Ptr<Graph>& graph, size_t idx, const Mat& m)
int adata_type = adata.type;
if ((adata_type == CV_16F || adata_type == CV_16BF) && !enableFP16)
adata_type = CV_32F;
// [TODO] need to analyze this situation more carefully
if (adata_type == CV_64F)
adata_type = CV_32F;
if (adata_type != mtype &&
!((adata_type == CV_64F || adata_type == CV_32F || adata_type == CV_16F || adata_type == CV_16BF) &&
(mtype == CV_64F || mtype == CV_32F || mtype == CV_16F || mtype == CV_16BF)))
(mtype == CV_64F || mtype == CV_32F || mtype == CV_16F || mtype == CV_16BF)) &&
!(adata.type == CV_16BF && mtype == CV_16U) && !(adata.type == CV_16F && mtype == CV_16U))
{
CV_Error_(Error::StsBadArg, ("incompatible type of input tensor #%zu '%s': %s given, %s expected",
idx, adata.name.c_str(), typeToString(mtype).c_str(),
@ -574,7 +573,21 @@ void Net::Impl::setGraphInput(Ptr<Graph>& graph, size_t idx, const Mat& m)
if (inp_t.shape() != mshape || inp_t.type() != adata_type)
finalizeLayers = true;
inp_t.fit(mshape, adata_type);
m.convertTo(inp_t, adata_type);
if (adata.type == CV_16BF && mtype == CV_16U)
{
Mat tmp(mshape, CV_16BF, (void*)m.data);
tmp.convertTo(inp_t, adata_type);
}
else if (adata.type == CV_16F && mtype == CV_16U)
{
Mat tmp(mshape, CV_16F, (void*)m.data);
tmp.convertTo(inp_t, adata_type);
}
else
{
m.convertTo(inp_t, adata_type);
}
} else if (adata.kind == DNN_ARG_TEMP) {
int bufidx = bufidxs.at(inp.idx);
Mat& buf = buffers.at(bufidx);

View File

@ -1813,12 +1813,35 @@ Mat getMatFromTensor(const opencv_onnx::TensorProto& tensor_proto, bool uint8ToI
Mat(sizes, CV_16FC1, rawdata).convertTo(blob, CV_32FC1);
}
}
else if (datatype == opencv_onnx::TensorProto_DataType_BFLOAT16)
{
if (!tensor_proto.raw_data().empty())
{
blob.create((int)sizes.size(), sizes.data(), CV_16BFC1);
size_t bytes = (size_t)blob.total() * blob.elemSize();
memcpy(blob.data, rawdata, bytes);
}
else if (!tensor_proto.int32_data().empty())
{
const auto& v = tensor_proto.int32_data();
blob.create((int)sizes.size(), sizes.data(), CV_16BFC1);
uint16_t* dst = reinterpret_cast<uint16_t*>(blob.data);
for (size_t i = 0; i < v.size(); ++i)
{
dst[i] = static_cast<uint16_t>(v[i] & 0xFFFF);
}
}
else
{
CV_Error(Error::StsNotImplemented, "BFLOAT16 tensor has no raw_data");
}
}
else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
{
if (!tensor_proto.double_data().empty())
Mat(sizes, CV_64FC1, (void*)tensor_proto.double_data().data()).convertTo(blob, CV_32FC1);
else
Mat(sizes, CV_64FC1, rawdata).convertTo(blob, CV_32FC1);
Mat(sizes, CV_64FC1, rawdata).copyTo(blob);
}
else if (datatype == opencv_onnx::TensorProto_DataType_INT32)
{

View File

@ -74,6 +74,7 @@ static int dataType2cv(int dt)
dt == opencv_onnx::TensorProto_DataType_FLOAT ? CV_32F :
dt == opencv_onnx::TensorProto_DataType_DOUBLE ? CV_64F :
dt == opencv_onnx::TensorProto_DataType_FLOAT16 ? CV_16F :
dt == opencv_onnx::TensorProto_DataType_BFLOAT16 ? CV_16BF :
dt == opencv_onnx::TensorProto_DataType_COMPLEX64 ? CV_32FC2 :
dt == opencv_onnx::TensorProto_DataType_COMPLEX128 ? CV_64FC2 :
dt == opencv_onnx::TensorProto_DataType_BOOL ? CV_Bool : -1;
@ -95,6 +96,7 @@ static std::string dataType2str(int dt)
dt == opencv_onnx::TensorProto_DataType_INT64 ? "INT64" :
dt == opencv_onnx::TensorProto_DataType_FLOAT ? "FLOAT" :
dt == opencv_onnx::TensorProto_DataType_FLOAT16 ? "FLOAT16" :
dt == opencv_onnx::TensorProto_DataType_BFLOAT16 ? "BFLOAT16" :
dt == opencv_onnx::TensorProto_DataType_BOOL ? "BOOL" :
dt == opencv_onnx::TensorProto_DataType_COMPLEX64 ? "COMPLEX64" :
dt == opencv_onnx::TensorProto_DataType_COMPLEX128 ? "COMPLEX128" : nullptr;
@ -174,6 +176,8 @@ protected:
void parseAveragePool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseBatchNormalization (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseCast (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseCast2 (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseCastLike (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseClip (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseConcat (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
void parseIf (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
@ -1478,13 +1482,16 @@ void ONNXImporter2::parseShape(LayerParams& layerParams, const opencv_onnx::Node
addLayer(layerParams, node_proto);
}
void ONNXImporter2::parseCast(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
void ONNXImporter2::parseCast2(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
opencv_onnx::TensorProto_DataType onnx_type = (opencv_onnx::TensorProto_DataType)layerParams.get<int>("to");
int type = dataType2cv(onnx_type);
layerParams.type = "Cast2";
addLayer(layerParams, node_proto);
}
layerParams.type = "Cast";
layerParams.set("outputType", type);
// Imports an ONNX CastLike node as a "Cast2" layer.
// The node must have exactly two inputs; no target-type parameter is set here,
// the layer type is simply switched to "Cast2" before the layer is added.
void ONNXImporter2::parseCastLike(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
{
CV_CheckEQ(node_proto.input_size(), 2, "CastLike requires two inputs");
// The layer type must be assigned before addLayer() consumes layerParams.
layerParams.type = "Cast2";
addLayer(layerParams, node_proto);
}
@ -2539,7 +2546,8 @@ void ONNXImporter2::buildDispatchMap_ONNX_AI(int opset_version)
dispatch["Reshape"] = &ONNXImporter2::parseReshape;
dispatch["Pad"] = &ONNXImporter2::parsePad;
dispatch["Shape"] = &ONNXImporter2::parseShape;
dispatch["Cast"] = &ONNXImporter2::parseCast;
dispatch["Cast"] = &ONNXImporter2::parseCast2;
dispatch["CastLike"] = &ONNXImporter2::parseCastLike;
dispatch["ConstantFill"] = dispatch["ConstantOfShape"] = &ONNXImporter2::parseConstantOfShape;
dispatch["Gather"] = &ONNXImporter2::parseGather;
dispatch["GatherElements"] = &ONNXImporter2::parseGatherElements;

View File

@ -41,6 +41,7 @@
"test_cast_STRING_to_FLOAT",
"test_castlike_FLOAT_to_STRING_expanded",
"test_castlike_STRING_to_FLOAT_expanded",
"test_cast_DOUBLE_to_FLOAT",
"test_concat_1d_axis_negative_1",
"test_conv_with_autopad_same",
"test_conv_with_strides_and_asymmetric_padding",
@ -56,6 +57,10 @@
"test_div_bcast",
"test_div_uint8",
"test_dropout_default_ratio",
"test_einsum_batch_diagonal",
"test_einsum_batch_matmul",
"test_einsum_sum",
"test_einsum_transpose",
"test_flatten_axis0",
"test_flatten_axis2",
"test_flatten_axis3",
@ -71,6 +76,9 @@
"test_maxpool_with_argmax_2d_precomputed_pads",
"test_maxpool_with_argmax_2d_precomputed_strides",
"test_maxunpool_export_with_output_shape",
"test_max_float64",
"test_min_float64",
"test_mod_mixed_sign_float64",
"test_mul_bcast",
"test_mul_uint8",
"test_softmax_default_axis",

View File

@ -1,10 +1,22 @@
"test_basic_conv_with_padding", // (assert failed) !blobs.empty() in initCUDA
"test_basic_conv_without_padding", // (assert failed) !blobs.empty() in initCUDA
"test_cast_DOUBLE_to_FLOAT",
"test_conv_with_autopad_same", // (assert failed) !blobs.empty() in initCUDA
"test_conv_with_strides_and_asymmetric_padding", // (assert failed) !blobs.empty() in initCUDA
"test_conv_with_strides_no_padding", // (assert failed) !blobs.empty() in initCUDA
"test_conv_with_strides_padding", // (assert failed) !blobs.empty() in initCUDA
"test_cumsum_1d",
"test_cumsum_1d_exclusive",
"test_cumsum_1d_reverse",
"test_cumsum_1d_reverse_exclusive",
"test_cumsum_2d_axis_0",
"test_cumsum_2d_axis_1",
"test_cumsum_2d_negative_axis",
"test_dropout_default_ratio",
"test_einsum_batch_diagonal",
"test_einsum_batch_matmul",
"test_einsum_sum",
"test_einsum_transpose",
"test_logsoftmax_large_number", // fp16 accuracy issue
"test_logsoftmax_large_number_expanded", // fp16 accuracy issue
"test_maxpool_with_argmax_2d_precomputed_pads", // assertion failed mat.type() == CV_32F
@ -23,3 +35,6 @@
"test_quantizelinear", // Issue https://github.com/opencv/opencv/issues/25999
"test_quantizelinear_axis", // Issue https://github.com/opencv/opencv/issues/25999
"test_quantizelinear_blocked", // Issue https://github.com/opencv/opencv/issues/25999
"test_max_float64",
"test_min_float64",
"test_mod_mixed_sign_float64",

View File

@ -289,21 +289,21 @@ CASE(test_bitshift_right_uint64)
CASE(test_bitshift_right_uint8)
SKIP;
CASE(test_cast_BFLOAT16_to_FLOAT)
// no filter
SKIP;
CASE(test_cast_DOUBLE_to_FLOAT)
// no filter
SKIP;
CASE(test_cast_DOUBLE_to_FLOAT16)
// no filter
SKIP;
CASE(test_cast_FLOAT16_to_DOUBLE)
// no filter
SKIP;
CASE(test_cast_FLOAT16_to_FLOAT)
// no filter
SKIP;
CASE(test_cast_FLOAT_to_BFLOAT16)
// no filter
SKIP;
CASE(test_cast_FLOAT_to_DOUBLE)
// no filter
SKIP;
CASE(test_cast_FLOAT_to_FLOAT16)
// no filter
SKIP;
CASE(test_cast_FLOAT_to_STRING)
#if SKIP_SET_1
SKIP;
@ -313,37 +313,37 @@ CASE(test_cast_STRING_to_FLOAT)
SKIP;
#endif
CASE(test_castlike_BFLOAT16_to_FLOAT)
// no filter
SKIP;
CASE(test_castlike_BFLOAT16_to_FLOAT_expanded)
// no filter
SKIP;
CASE(test_castlike_DOUBLE_to_FLOAT)
// no filter
SKIP;
CASE(test_castlike_DOUBLE_to_FLOAT16)
// no filter
SKIP;
CASE(test_castlike_DOUBLE_to_FLOAT16_expanded)
// no filter
SKIP;
CASE(test_castlike_DOUBLE_to_FLOAT_expanded)
// no filter
SKIP;
CASE(test_castlike_FLOAT16_to_DOUBLE)
// no filter
SKIP;
CASE(test_castlike_FLOAT16_to_DOUBLE_expanded)
// no filter
SKIP;
CASE(test_castlike_FLOAT16_to_FLOAT)
// no filter
SKIP;
CASE(test_castlike_FLOAT16_to_FLOAT_expanded)
// no filter
SKIP;
CASE(test_castlike_FLOAT_to_BFLOAT16)
// no filter
SKIP;
CASE(test_castlike_FLOAT_to_BFLOAT16_expanded)
// no filter
SKIP;
CASE(test_castlike_FLOAT_to_DOUBLE)
// no filter
SKIP;
CASE(test_castlike_FLOAT_to_DOUBLE_expanded)
// no filter
SKIP;
CASE(test_castlike_FLOAT_to_FLOAT16)
// no filter
SKIP;
CASE(test_castlike_FLOAT_to_FLOAT16_expanded)
// no filter
SKIP;
CASE(test_castlike_FLOAT_to_STRING)
// no filter
CASE(test_castlike_FLOAT_to_STRING_expanded)

View File

@ -280,3 +280,30 @@
"test_div_uint64",
"test_cumsum_1d_int32_exclusive",
"test_cumsum_2d_int32",
"test_cast_BFLOAT16_to_FLOAT",
"test_cast_DOUBLE_to_FLOAT16",
"test_cast_FLOAT16_to_DOUBLE",
"test_cast_FLOAT16_to_FLOAT",
"test_cast_FLOAT_to_BFLOAT16",
"test_cast_FLOAT_to_DOUBLE",
"test_cast_FLOAT_to_FLOAT16",
"test_castlike_BFLOAT16_to_FLOAT",
"test_castlike_BFLOAT16_to_FLOAT_expanded",
"test_castlike_DOUBLE_to_FLOAT",
"test_castlike_DOUBLE_to_FLOAT16",
"test_castlike_DOUBLE_to_FLOAT16_expanded",
"test_castlike_DOUBLE_to_FLOAT_expanded",
"test_castlike_FLOAT16_to_DOUBLE",
"test_castlike_FLOAT16_to_DOUBLE_expanded",
"test_castlike_FLOAT16_to_FLOAT",
"test_castlike_FLOAT16_to_FLOAT_expanded",
"test_castlike_FLOAT_to_BFLOAT16",
"test_castlike_FLOAT_to_BFLOAT16_expanded",
"test_castlike_FLOAT_to_DOUBLE",
"test_castlike_FLOAT_to_DOUBLE_expanded",
"test_castlike_FLOAT_to_FLOAT16",
"test_castlike_FLOAT_to_FLOAT16_expanded",
"test_gelu_default_1_expanded",
"test_gelu_default_2_expanded",
"test_gelu_tanh_1_expanded",
"test_gelu_tanh_2_expanded",

View File

@ -183,10 +183,6 @@
"test_blackmanwindow_expanded",
"test_blackmanwindow_symmetric",
"test_blackmanwindow_symmetric_expanded",
"test_cast_BFLOAT16_to_FLOAT", // Issue::Unsuppoted data type
"test_cast_DOUBLE_to_FLOAT16", // Issue::Unsuppoted data type
"test_cast_FLOAT16_to_DOUBLE", // Issue::Unsuppoted data type
"test_cast_FLOAT16_to_FLOAT", // Issue::Unsuppoted data type
"test_cast_FLOAT16_to_FLOAT4E2M1",
"test_cast_FLOAT16_to_FLOAT8E4M3FN",
"test_cast_FLOAT16_to_FLOAT8E4M3FNUZ",
@ -204,9 +200,6 @@
"test_cast_FLOAT8E5M2FNUZ_to_FLOAT16",
"test_cast_FLOAT8E5M2_to_FLOAT",
"test_cast_FLOAT8E5M2_to_FLOAT16",
"test_cast_FLOAT_to_BFLOAT16", // Issue::Unsuppoted data type
"test_cast_FLOAT_to_DOUBLE", // Issue::Unsuppoted data type
"test_cast_FLOAT_to_FLOAT16", // Issue::Unsuppoted data type
"test_cast_FLOAT_to_FLOAT4E2M1",
"test_cast_FLOAT_to_FLOAT8E4M3FN",
"test_cast_FLOAT_to_FLOAT8E4M3FNUZ",
@ -232,15 +225,6 @@
"test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ",
"test_cast_no_saturate_FLOAT_to_FLOAT8E5M2",
"test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ",
"test_castlike_BFLOAT16_to_FLOAT", // Issue::Unsuppoted data type
"test_castlike_BFLOAT16_to_FLOAT_expanded", // Issue::Unsuppoted data type
"test_castlike_DOUBLE_to_FLOAT", // Issues::Layer::Can't create layer "onnx_node_output_0!output" of type "CastLike" in function 'getLayerInstance'
"test_castlike_DOUBLE_to_FLOAT16", // Issues::Layer::Can't create layer "onnx_node_output_0!output" of type "CastLike" in function 'getLayerInstance'
"test_castlike_DOUBLE_to_FLOAT16_expanded", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
"test_castlike_DOUBLE_to_FLOAT_expanded", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
"test_castlike_FLOAT16_to_DOUBLE", // Issue::Unsuppoted data type
"test_castlike_FLOAT16_to_DOUBLE_expanded", // Issue::Unsuppoted data type
"test_castlike_FLOAT16_to_FLOAT", // Issues::Layer::Can't create layer "onnx_node_output_0!output" of type "CastLike" in function 'getLayerInstance'
"test_castlike_FLOAT16_to_FLOAT4E2M1",
"test_castlike_FLOAT16_to_FLOAT4E2M1_expanded",
"test_castlike_FLOAT16_to_FLOAT8E4M3FN",
@ -251,7 +235,6 @@
"test_castlike_FLOAT16_to_FLOAT8E5M2FNUZ",
"test_castlike_FLOAT16_to_FLOAT8E5M2FNUZ_expanded",
"test_castlike_FLOAT16_to_FLOAT8E5M2_expanded",
"test_castlike_FLOAT16_to_FLOAT_expanded", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
"test_castlike_FLOAT16_to_INT4",
"test_castlike_FLOAT16_to_INT4_expanded",
"test_castlike_FLOAT16_to_UINT4",
@ -276,12 +259,6 @@
"test_castlike_FLOAT8E5M2_to_FLOAT16",
"test_castlike_FLOAT8E5M2_to_FLOAT16_expanded",
"test_castlike_FLOAT8E5M2_to_FLOAT_expanded",
"test_castlike_FLOAT_to_BFLOAT16", // Issue::Unsuppoted data type
"test_castlike_FLOAT_to_BFLOAT16_expanded", // Issue::Unsuppoted data type
"test_castlike_FLOAT_to_DOUBLE", // Issues::Layer::Can't create layer "onnx_node_output_0!output" of type "CastLike" in function 'getLayerInstance'
"test_castlike_FLOAT_to_DOUBLE_expanded", // Issue::Unsuppoted data type
"test_castlike_FLOAT_to_FLOAT16", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
"test_castlike_FLOAT_to_FLOAT16_expanded", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
"test_castlike_FLOAT_to_FLOAT4E2M1",
"test_castlike_FLOAT_to_FLOAT4E2M1_expanded",
"test_castlike_FLOAT_to_FLOAT8E4M3FN",
@ -402,10 +379,6 @@
"test_eyelike_populate_off_main_diagonal", // Issues::Layer::Can't create layer::Can't create layer "onnx_node_output_0!y" of type "EyeLike" in function 'getLayerInstance'
"test_eyelike_with_dtype", // ---- same as above ---
"test_eyelike_without_dtype", // ---- same as above ---
"test_gelu_default_1_expanded", // parser: no corresponding layer for CastLike
"test_gelu_default_2_expanded", // parser: no corresponding layer for CastLike
"test_gelu_tanh_1_expanded", // parser: no corresponding layer for CastLike
"test_gelu_tanh_2_expanded", // parser: no corresponding layer for CastLike
"test_gridsample_bicubic", // ---- same as above ---
"test_gridsample_bicubic_align_corners_0_additional_1",
"test_gridsample_bicubic_align_corners_1_additional_1",