mirror of
https://github.com/zebrajr/opencv.git
synced 2025-12-06 12:19:50 +01:00
Merge pull request #27698 from abhishek-gola:add_cast_layer
Added cast and castlike layers support in new DNN engine #27698 ### Pull Request Readiness Checklist See details at https://github.com/opencv/opencv/wiki/How_to_contribute#making-a-good-pull-request - [x] I agree to contribute to the project under Apache 2 License. - [x] To the best of my knowledge, the proposed patch is not based on a code under GPL or another license that is incompatible with OpenCV - [x] The PR is proposed to the proper branch - [x] There is a reference to the original bug report and related work - [x] There is accuracy test, performance test and test data in opencv_extra repository, if applicable Patch to opencv_extra has the same branch name. - [x] The feature is well documented and sample code can be built with the project CMake
This commit is contained in:
parent
9c75140b1b
commit
2470c07f1b
|
|
@ -1414,6 +1414,11 @@ CV__DNN_INLINE_NS_BEGIN
|
|||
static Ptr<CastLayer> create(const LayerParams ¶ms);
|
||||
};
|
||||
|
||||
/** @brief Type-conversion layer registered under the "Cast2" layer type.
 *
 * The implementation accepts one input (ONNX Cast; the target type comes from the
 * "to" or "outputType" layer parameter) or two inputs (ONNX CastLike; the target
 * type follows the second input).
 */
class CV_EXPORTS Cast2Layer : public Layer {
public:
    /** @brief Creates the layer from the given layer parameters. */
    static Ptr<Cast2Layer> create(const LayerParams &params);
};
|
||||
|
||||
class CV_EXPORTS NonMaxSuppressionLayer : public Layer
|
||||
{
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -63,6 +63,12 @@ ov::element::Type cvTypeToOvType(MatType cvType)
|
|||
return ov::element::i64;
|
||||
case CV_Bool:
|
||||
return ov::element::boolean;
|
||||
case CV_16F:
|
||||
return ov::element::f16;
|
||||
case CV_16BF:
|
||||
return ov::element::bf16;
|
||||
case CV_64F:
|
||||
return ov::element::f64;
|
||||
default:
|
||||
CV_Error(Error::StsNotImplemented, format("Unsupported data type %s", typeToString(cvType).c_str()));
|
||||
}
|
||||
|
|
@ -88,6 +94,12 @@ MatType ovTypeToCvType(ov::element::Type ovType)
|
|||
return CV_64S;
|
||||
case ov::element::boolean:
|
||||
return CV_Bool;
|
||||
case ov::element::f16:
|
||||
return CV_16F;
|
||||
case ov::element::bf16:
|
||||
return CV_16BF;
|
||||
case ov::element::f64:
|
||||
return CV_64F;
|
||||
default:
|
||||
CV_Error(Error::StsNotImplemented, format("Unsupported data type %s", ovType.get_type_name().c_str()));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -193,6 +193,7 @@ void initializeLayerFactory()
|
|||
CV_DNN_REGISTER_LAYER_CLASS(Attention, AttentionLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(GroupNormalization, GroupNormLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Cast, CastLayer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(Cast2, Cast2Layer);
|
||||
CV_DNN_REGISTER_LAYER_CLASS(DepthToSpace, DepthToSpaceLayer)
|
||||
CV_DNN_REGISTER_LAYER_CLASS(SpaceToDepth, SpaceToDepthLayer)
|
||||
CV_DNN_REGISTER_LAYER_CLASS(DepthToSpaceInt8, DepthToSpaceLayer)
|
||||
|
|
|
|||
373
modules/dnn/src/layers/cast2_layer.cpp
Normal file
373
modules/dnn/src/layers/cast2_layer.cpp
Normal file
|
|
@ -0,0 +1,373 @@
|
|||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp"
|
||||
#include "../op_inf_engine.hpp"
|
||||
#include "../ie_ngraph.hpp"
|
||||
#include "layers_common.hpp"
|
||||
#include "../net_impl.hpp"
|
||||
|
||||
#include "opencv-onnx.pb.h"
|
||||
|
||||
namespace cv { namespace dnn {
|
||||
|
||||
// ONNX Cast operator
|
||||
// Spec: https://onnx.ai/onnx/operators/onnx__Cast.html
|
||||
// Supported opsets: 1-24
|
||||
// ONNX CastLike operator
|
||||
// Spec: https://onnx.ai/onnx/operators/onnx__CastLike.html
|
||||
// Supported opsets: 15-24
|
||||
|
||||
namespace
|
||||
{
|
||||
inline void castQuantized(const Mat& src, Mat& dst, int targetDepth)
|
||||
{
|
||||
if (targetDepth == CV_16F)
|
||||
{
|
||||
CV_Assert(dst.depth() == CV_32F);
|
||||
if (src.depth() == CV_32F)
|
||||
{
|
||||
MatConstIterator_<float> sIt = src.begin<float>(), sEnd = src.end<float>();
|
||||
MatIterator_<float> dIt = dst.begin<float>();
|
||||
for (; sIt != sEnd; ++sIt, ++dIt)
|
||||
{
|
||||
*dIt = (float)hfloat(*sIt);
|
||||
}
|
||||
}
|
||||
else if (src.depth() == CV_64F)
|
||||
{
|
||||
MatConstIterator_<double> sIt = src.begin<double>(), sEnd = src.end<double>();
|
||||
MatIterator_<float> dIt = dst.begin<float>();
|
||||
for (; sIt != sEnd; ++sIt, ++dIt)
|
||||
{
|
||||
float v = (float)*sIt;
|
||||
*dIt = (float)hfloat(v);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
Mat src32; src.convertTo(src32, CV_32F);
|
||||
MatConstIterator_<float> sIt = src32.begin<float>(), sEnd = src32.end<float>();
|
||||
MatIterator_<float> dIt = dst.begin<float>();
|
||||
for (; sIt != sEnd; ++sIt, ++dIt)
|
||||
{
|
||||
*dIt = (float)hfloat(*sIt);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (targetDepth == CV_16BF)
|
||||
{
|
||||
const int ddepth = dst.depth();
|
||||
if (!(ddepth == CV_16BF || ddepth == CV_16U))
|
||||
{
|
||||
CV_Error(Error::StsNotImplemented, "Unsupported destination depth for BF16 cast");
|
||||
}
|
||||
|
||||
Mat dst_bits(dst.size(), CV_MAKETYPE(CV_16U, dst.channels()), dst.data, dst.step);
|
||||
|
||||
const Mat* src32p;
|
||||
Mat src32;
|
||||
if (src.depth() == CV_32F)
|
||||
src32p = &src;
|
||||
else
|
||||
{
|
||||
src.convertTo(src32, CV_32F);
|
||||
src32p = &src32;
|
||||
}
|
||||
|
||||
const int rows = src32p->rows;
|
||||
const int cols_x_cn = src32p->cols * src32p->channels();
|
||||
for (int r = 0; r < rows; ++r)
|
||||
{
|
||||
const float* in = src32p->ptr<float>(r);
|
||||
ushort* out = dst_bits.ptr<ushort>(r);
|
||||
for (int i = 0; i < cols_x_cn; ++i)
|
||||
{
|
||||
Cv32suf u; u.f = in[i];
|
||||
out[i] = (ushort)(u.u >> 16);
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
src.convertTo(dst, dst.depth());
|
||||
}
|
||||
}
|
||||
|
||||
// Implementation of Cast2Layer.
//
// One input  (ONNX Cast)    : the target type comes from the "to" parameter (an ONNX
//                             TensorProto data type id) or from "outputType" (an OpenCV depth).
// Two inputs (ONNX CastLike): the target type follows the second input.
//
// 16-bit float targets are emulated: getTypes() plans CV_32F storage for CV_16F and
// CV_16U storage for CV_16BF, and forward() fills that storage through castQuantized().
class Cast2LayerImpl CV_FINAL : public Cast2Layer
{
public:
    Cast2LayerImpl(const LayerParams& params)
    {
        setParamsFrom(params);

        if (params.has("to"))
        {
            // "to" holds an ONNX data type id; unsupported ids raise StsNotImplemented.
            toCvDepth_ = mapToCvDepth(params.get<int>("to"));
            hasToParam = true;
        }
        else if (params.has("outputType"))
        {
            // "outputType" already holds an OpenCV depth constant.
            const int v = params.get<int>("outputType");
            const bool supported =
                v == CV_Bool || v == CV_8U || v == CV_8S || v == CV_16U || v == CV_16S ||
                v == CV_32S || v == CV_64S || v == CV_32F || v == CV_64F || v == CV_16F || v == CV_16BF;
            if (!supported)
            {
                CV_Error(Error::StsNotImplemented, "Cast: unsupported 'outputType' value");
            }
            toCvDepth_ = v;
            hasToParam = true;
        }
    }

    virtual bool supportBackend(int backendId) CV_OVERRIDE
    {
        return backendId == DNN_BACKEND_OPENCV ||
               backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
    }

    // The single output always has the shape of the first input.
    virtual bool getMemoryShapes(const std::vector<MatShape> &inputs,
                                 const int requiredOutputs,
                                 std::vector<MatShape> &outputs,
                                 std::vector<MatShape> &internals) const CV_OVERRIDE
    {
        CV_Check(inputs.size(), inputs.size() == 1 || inputs.size() == 2, "Cast expects 1 (Cast) or 2 (CastLike) inputs");
        outputs.assign(1, inputs[0]);
        return false;
    }

    // Plans the storage type of the output. The target depth is resolved in this order:
    // explicit parameter, type recorded for the second input in the net's argument table,
    // the types passed in `inputs`, and finally CV_32F.
    virtual void getTypes(const std::vector<MatType>& inputs,
                          const int requiredOutputs,
                          const int requiredInternals,
                          std::vector<MatType>& outputs,
                          std::vector<MatType>& internals) const CV_OVERRIDE
    {
        CV_Check(inputs.size(), !inputs.empty(), "Cast expects at least 1 input");

        int targetDepth = hasToParam ? toCvDepth_ : declaredLikeDepth();
        if (targetDepth < 0)
            targetDepth = resolveTargetDepthAtTypeTime(inputs);
        if (targetDepth < 0)
            targetDepth = CV_32F;

        const int in0Type = inputs[0];
        const int in0CN = in0Type >= 0 ? CV_MAT_CN(in0Type) : 1;

        // Emulated storage for the 16-bit float targets (see class comment).
        int planDepth = targetDepth;
        if (planDepth == CV_16F) planDepth = CV_32F;
        if (planDepth == CV_16BF) planDepth = CV_16U;

        outputs.assign(1, CV_MAKETYPE(planDepth, in0CN));
    }

#ifdef HAVE_OPENCL
    // OpenCL path. Returns false to hand the work to the CPU implementation in forward().
    bool forward_ocl(InputArrayOfArrays inputs_, OutputArrayOfArrays outputs_, OutputArrayOfArrays internals_)
    {
        // Without an explicit target (CastLike) the type has to be resolved the way
        // forward() does it, so let the CPU path handle that case.
        if (!hasToParam)
            return false;

        // BF16 output is stored as raw bit patterns; a numeric convertTo() into the
        // CV_16U storage would produce wrong values.
        if (toCvDepth_ == CV_16BF)
            return false;

        std::vector<UMat> inputs, outputs;

        inputs_.getUMatVector(inputs);
        outputs_.getUMatVector(outputs);
        CV_Check(inputs.size(), inputs.size() == 1 || inputs.size() == 2, "Cast expects 1 (Cast) or 2 (CastLike) inputs");
        CV_CheckEQ(outputs.size(), (size_t)1, "");

        // FP16 emulated in CV_32F storage needs the rounding done by castQuantized().
        if (toCvDepth_ == CV_16F && outputs[0].depth() == CV_32F)
        {
            return false;
        }

        if (inputs[0].depth() == outputs[0].depth())
            inputs[0].copyTo(outputs[0]);
        else
            inputs[0].convertTo(outputs[0], outputs[0].depth());
        return true;
    }
#endif

    void forward(InputArrayOfArrays inputs_arr, OutputArrayOfArrays outputs_arr, OutputArrayOfArrays internals_arr) CV_OVERRIDE
    {
        CV_TRACE_FUNCTION();
        CV_TRACE_ARG_VALUE(name, "name", name.c_str());

        CV_OCL_RUN(IS_DNN_OPENCL_TARGET(preferableTarget),
                   forward_ocl(inputs_arr, outputs_arr, internals_arr));

        std::vector<Mat> inputs, outputs;
        inputs_arr.getMatVector(inputs);
        outputs_arr.getMatVector(outputs);

        CV_Check(inputs.size(), inputs.size() == 1 || inputs.size() == 2, "Cast expects 1 (Cast) or 2 (CastLike) inputs");
        CV_CheckEQ(outputs.size(), (size_t)1, "");

        const Mat& src0 = inputs[0];
        Mat& dst0 = outputs[0];

        // Resolve the target depth: explicit parameter, then the type recorded for the
        // second input, then the actual second input, then the source depth (identity).
        int runtimeTargetDepth = -1;
        if (hasToParam)
        {
            runtimeTargetDepth = toCvDepth_;
        }
        else
        {
            runtimeTargetDepth = declaredLikeDepth();
            if (runtimeTargetDepth < 0 && inputs.size() >= 2 && !inputs[1].empty())
                runtimeTargetDepth = inputs[1].depth();
            if (runtimeTargetDepth < 0)
                runtimeTargetDepth = src0.depth();
        }
        CV_CheckGE(runtimeTargetDepth, 0, "Cast: failed to resolve target data type at runtime");

        // Same storage mapping as getTypes().
        const int plannedDDepth = (runtimeTargetDepth == CV_16F) ? CV_32F :
                                  (runtimeTargetDepth == CV_16BF ? CV_16U : runtimeTargetDepth);
        if (dst0.depth() != plannedDDepth)
        {
            // The output shape equals the first input's shape (see getMemoryShapes());
            // taking dims/sizes from src0 keeps this valid for N-dimensional tensors.
            // NOTE(review): `outputs` was filled by getMatVector(), so this reallocation
            // presumably stays local to this function - confirm how the caller picks it up.
            dst0.create(src0.dims, src0.size.p, CV_MAKETYPE(plannedDDepth, src0.channels()));
        }

        const int sdepth = src0.depth();
        const int ddepth = dst0.depth();

        // Nothing to convert when source, target and storage depths all agree.
        if (sdepth == runtimeTargetDepth && sdepth == ddepth)
        {
            src0.copyTo(dst0);
            return;
        }

        if (runtimeTargetDepth == CV_16BF && (ddepth == CV_16BF || ddepth == CV_16U))
        {
            castQuantized(src0, dst0, CV_16BF);
        }
        else if (sdepth == CV_16BF)
        {
            src0.convertTo(dst0, ddepth);
        }
        else if (runtimeTargetDepth == CV_16F && ddepth == CV_32F)
        {
            castQuantized(src0, dst0, CV_16F);
        }
        else
        {
            src0.convertTo(dst0, ddepth);
        }
    }

#ifdef HAVE_DNN_NGRAPH
    // Builds an ov::op::v0::Convert node; the destination element type comes from the
    // explicit parameter, else from the second node, else from the first node.
    virtual Ptr<BackendNode> initNgraph(const std::vector<Ptr<BackendWrapper> >& inputs,
                                        const std::vector<Ptr<BackendNode> >& nodes) CV_OVERRIDE
    {
        ov::element::Type dstType;
        if (hasToParam)
        {
            dstType = cvTypeToOvType(CV_MAKETYPE(toCvDepth_, 1));
        }
        else if (nodes.size() >= 2)
        {
            dstType = nodes[1].dynamicCast<InfEngineNgraphNode>()->node.get_element_type();
        }
        else
        {
            dstType = nodes[0].dynamicCast<InfEngineNgraphNode>()->node.get_element_type();
        }
        auto cast = std::make_shared<ov::op::v0::Convert>(nodes[0].dynamicCast<InfEngineNgraphNode>()->node, dstType);
        return Ptr<BackendNode>(new InfEngineNgraphNode(cast));
    }
#endif // HAVE_DNN_NGRAPH

private:
    bool hasToParam = false; // true when the target depth was given by "to"/"outputType"
    int toCvDepth_ = -1;     // OpenCV depth of the explicit target; valid when hasToParam

    // Maps an ONNX TensorProto data type id to an OpenCV depth; raises StsNotImplemented
    // for ids that are not listed.
    static int mapToCvDepth(int v)
    {
        switch (v)
        {
            case opencv_onnx::TensorProto_DataType_FLOAT:    return CV_32F;
            case opencv_onnx::TensorProto_DataType_UINT8:    return CV_8U;
            case opencv_onnx::TensorProto_DataType_INT8:     return CV_8S;
            case opencv_onnx::TensorProto_DataType_UINT16:   return CV_16U;
            case opencv_onnx::TensorProto_DataType_INT16:    return CV_16S;
            case opencv_onnx::TensorProto_DataType_INT32:    return CV_32S;
            case opencv_onnx::TensorProto_DataType_INT64:    return CV_64S;
            case opencv_onnx::TensorProto_DataType_BOOL:     return CV_Bool;
            case opencv_onnx::TensorProto_DataType_FLOAT16:  return CV_16F;
            case opencv_onnx::TensorProto_DataType_DOUBLE:   return CV_64F;
            case opencv_onnx::TensorProto_DataType_BFLOAT16: return CV_16BF;
            default: break;
        }

        CV_Error(Error::StsNotImplemented, "Cast: unsupported 'to' / dtype value");
    }

    // Depth recorded in the net's argument table for the second ("like") input, or -1
    // when there is no net implementation, no second input, or no recorded type.
    int declaredLikeDepth() const
    {
        Net::Impl* netimpl_ = getNetImpl(const_cast<Cast2LayerImpl*>(this));
        if (!netimpl_ || this->inputs.size() < 2)
            return -1;

        const Arg& in1_arg = this->inputs[1];
        if (in1_arg.idx < 0)
            return -1;

        const ArgData& ad = netimpl_->argData(in1_arg);
        return ad.type >= 0 ? CV_MAT_DEPTH(ad.type) : -1;
    }

    // Target depth derived only from the explicit parameter and the input types handed
    // to getTypes(); returns -1 when the relevant input type is not known.
    int resolveTargetDepthAtTypeTime(const std::vector<MatType>& inputs) const
    {
        if (hasToParam)
            return toCvDepth_;
        if (inputs.size() == 2)
        {
            const int likeType = inputs[1];
            return likeType >= 0 ? CV_MAT_DEPTH(likeType) : -1;
        }
        const int in0Type = inputs[0];
        return in0Type >= 0 ? CV_MAT_DEPTH(in0Type) : -1;
    }
};
|
||||
|
||||
// Factory entry point: builds the concrete Cast2LayerImpl and returns it through the
// public Cast2Layer interface.
Ptr<Cast2Layer> Cast2Layer::create(const LayerParams& params)
{
    Ptr<Cast2Layer> layer = makePtr<Cast2LayerImpl>(params);
    return layer;
}
|
||||
|
||||
}} // namespace cv::dnn
|
||||
|
|
@ -168,8 +168,32 @@ Mat DiagonalInnermostDims(const Mat& input, bool preserve_innermost_dim_val) {
|
|||
output_dims[rank - 1] = 1;
|
||||
}
|
||||
|
||||
// TODO: hande different types
|
||||
Mat output = DiagonalDataAssignment<float>(input);
|
||||
Mat output;
|
||||
switch (input.depth())
|
||||
{
|
||||
case CV_32F:
|
||||
output = DiagonalDataAssignment<float>(input);
|
||||
break;
|
||||
case CV_64F:
|
||||
output = DiagonalDataAssignment<double>(input);
|
||||
break;
|
||||
case CV_16F:
|
||||
{
|
||||
Mat tmp32;
|
||||
input.convertTo(tmp32, CV_32F);
|
||||
Mat out32 = DiagonalDataAssignment<float>(tmp32);
|
||||
out32.convertTo(output, input.type());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
{
|
||||
Mat tmp32;
|
||||
input.convertTo(tmp32, CV_32F);
|
||||
Mat out32 = DiagonalDataAssignment<float>(tmp32);
|
||||
out32.convertTo(output, input.type());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (output_dims != shape(output)){
|
||||
CV_Error(Error::StsError, "Output shape does not match with calculated shape");
|
||||
|
|
@ -420,7 +444,26 @@ public:
|
|||
backendId == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH;
|
||||
}
|
||||
|
||||
// getMemoryShapes
|
||||
// Type inference for Einsum: all inputs must share one type, which must be CV_16F,
// CV_32F or CV_64F. The single output and every requested internal buffer use that type.
virtual void getTypes(const std::vector<MatType>& inputs,
                      const int requiredOutputs,
                      const int requiredInternals,
                      std::vector<MatType>& outputs,
                      std::vector<MatType>& internals) const CV_OVERRIDE
{
    CV_Assert(!inputs.empty());
    for (const int t : inputs)
        CV_CheckTypeEQ(t, inputs[0], "All Einsum inputs must have the same type");

    // The accepted set does not depend on the preferable target, so one check suffices.
    CV_CheckType(inputs[0], inputs[0] == CV_16F || inputs[0] == CV_32F || inputs[0] == CV_64F, "");

    outputs.assign(1, inputs[0]);
    internals.assign(requiredInternals, inputs[0]);
}
|
||||
|
||||
// getMemoryShapes
|
||||
bool getMemoryShapes(const std::vector<MatShape> &inputs,
|
||||
const int requiredOutputs,
|
||||
std::vector<MatShape> &outputs,
|
||||
|
|
@ -586,16 +629,27 @@ Mat LayerEinsumImpl::reduceSum(Mat& src, MatShape& reduceAxis)
|
|||
std::vector<MatShape> outputShapes, internalShapes;
|
||||
reduce->getMemoryShapes(inputShapes, 1, outputShapes, internalShapes);
|
||||
|
||||
int origType = src.type();
|
||||
Mat src32 = src;
|
||||
if (src.depth() != CV_32F)
|
||||
src.convertTo(src32, CV_32F);
|
||||
Mat output(outputShapes[0], CV_32F);
|
||||
|
||||
std::vector<Mat> inputs;
|
||||
std::vector<Mat> outputs;
|
||||
std::vector<Mat> internals;
|
||||
inputs.emplace_back(src);
|
||||
inputs.emplace_back(src32);
|
||||
outputs.emplace_back(output);
|
||||
|
||||
reduce->forward(inputs, outputs, internals);
|
||||
return outputs[0];
|
||||
Mat out = outputs[0];
|
||||
if (CV_MAT_TYPE(origType) != CV_32F)
|
||||
{
|
||||
Mat converted;
|
||||
out.convertTo(converted, origType);
|
||||
return converted;
|
||||
}
|
||||
return out;
|
||||
}
|
||||
|
||||
void LayerEinsumImpl::preProcessInputs(InputArrayOfArrays& inputs_arr)
|
||||
|
|
@ -1374,48 +1428,61 @@ Mat LayerEinsumImpl::batchwiseMatMul(
|
|||
Mat reshapedInput1 = input1;
|
||||
Mat reshapedInput2 = input2;
|
||||
|
||||
int origType = input1.type();
|
||||
Mat a = reshapedInput1, b = reshapedInput2;
|
||||
if (input1.depth() != CV_32F)
|
||||
{
|
||||
reshapedInput1.convertTo(a, CV_32F);
|
||||
reshapedInput2.convertTo(b, CV_32F);
|
||||
}
|
||||
|
||||
Mat output;
|
||||
if (batches > 1)
|
||||
{
|
||||
// create tmpout with type like input1
|
||||
output = Mat({batches, M, N}, input1.type());
|
||||
output = Mat({batches, M, N}, CV_32F);
|
||||
|
||||
reshapedInput2 = reshapedInput2.reshape(1, input2ShapeOverride);
|
||||
reshapedInput1 = reshapedInput1.reshape(1, input1ShapeOverride);
|
||||
b = b.reshape(1, input2ShapeOverride);
|
||||
a = a.reshape(1, input1ShapeOverride);
|
||||
|
||||
fastGemmBatch(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
|
||||
fastGemmBatch(false, false, 1.0, a, b, 0.0, output, opt);
|
||||
} else {
|
||||
|
||||
// input1 should of size MxK
|
||||
// check if input1 needs reshape, if need reshape
|
||||
if (input1.dims > 2 || input1.size[0] != M || (input1.dims > 1 && input1.size[1] != K) || input1.dims == 1)
|
||||
if (reshapedInput1.dims > 2 || reshapedInput1.size[0] != M || (reshapedInput1.dims > 1 && reshapedInput1.size[1] != K) || reshapedInput1.dims == 1)
|
||||
{
|
||||
int shape[] = {M, K};
|
||||
reshapedInput1 = input1.reshape(1, 2, shape);
|
||||
a = a.reshape(1, 2, shape);
|
||||
}
|
||||
|
||||
// input2 should be of size KxN
|
||||
// check if input2 needs reshape, if needs reshape
|
||||
if (input2.dims > 2 || input2.size[0] != K || (input2.dims > 1 && input2.size[1] != N) || input2.dims == 1)
|
||||
if (reshapedInput2.dims > 2 || reshapedInput2.size[0] != K || (reshapedInput2.dims > 1 && reshapedInput2.size[1] != N) || reshapedInput2.dims == 1)
|
||||
{
|
||||
int shape2[] = {K, N};
|
||||
reshapedInput2 = input2.reshape(1, 2, shape2);
|
||||
b = b.reshape(1, 2, shape2);
|
||||
}
|
||||
|
||||
|
||||
output = Mat(M, N, reshapedInput1.type());
|
||||
if ((reshapedInput1.dims == 0 && reshapedInput2.dims == 0) ||
|
||||
(reshapedInput1.dims == 0 && reshapedInput2.dims != 0) ||
|
||||
(reshapedInput1.dims != 0 && reshapedInput2.dims == 0))
|
||||
output = Mat(M, N, CV_32F);
|
||||
if ((a.dims == 0 && b.dims == 0) ||
|
||||
(a.dims == 0 && b.dims != 0) ||
|
||||
(a.dims != 0 && b.dims == 0))
|
||||
{
|
||||
output = reshapedInput1.mul(reshapedInput2); // fastGemm does not support 0D * 0D multiplication
|
||||
output = a.mul(b); // fastGemm does not support 0D * 0D multiplication
|
||||
} else {
|
||||
fastGemm(false, false, 1.0, reshapedInput1, reshapedInput2, 0.0, output, opt);
|
||||
fastGemm(false, false, 1.0, a, b, 0.0, output, opt);
|
||||
}
|
||||
|
||||
output = output.reshape(1, {1, M, N});
|
||||
}
|
||||
if (CV_MAT_TYPE(origType) != CV_32F)
|
||||
{
|
||||
Mat converted;
|
||||
output.convertTo(converted, origType);
|
||||
return converted;
|
||||
}
|
||||
return output;
|
||||
};
|
||||
Ptr<EinsumLayer> EinsumLayer::create(const LayerParams& params)
|
||||
|
|
|
|||
|
|
@ -397,9 +397,9 @@ public:
|
|||
{
|
||||
CV_CheckTypeEQ(inputs[0], input, "All inputs should have equal types");
|
||||
if (preferableTarget == DNN_TARGET_OPENCL_FP16)
|
||||
CV_CheckType(input, input == CV_16F || input == CV_8S || input == CV_8U || input == CV_16S || input == CV_16U || input == CV_32S || input == CV_32U || input == CV_64S || input == CV_64U, "");
|
||||
CV_CheckType(input, input == CV_16F || input == CV_32F || input == CV_64F || input == CV_8S || input == CV_8U || input == CV_16S || input == CV_16U || input == CV_32S || input == CV_32U || input == CV_64S || input == CV_64U, "");
|
||||
else
|
||||
CV_CheckType(input, input == CV_32F || input == CV_8S || input == CV_8U || input == CV_16S || input == CV_16U || input == CV_32S || input == CV_32U || input == CV_64S || input == CV_64U, "");
|
||||
CV_CheckType(input, input == CV_32F || input == CV_64F || input == CV_8S || input == CV_8U || input == CV_16S || input == CV_16U || input == CV_32S || input == CV_32U || input == CV_64S || input == CV_64U, "");
|
||||
}
|
||||
|
||||
if (op == OPERATION::EQUAL || op == OPERATION::GREATER || op == OPERATION::GREATER_EQUAL || op == OPERATION::LESS || op == OPERATION::LESS_EQUAL)
|
||||
|
|
|
|||
|
|
@ -559,12 +559,11 @@ void Net::Impl::setGraphInput(Ptr<Graph>& graph, size_t idx, const Mat& m)
|
|||
int adata_type = adata.type;
|
||||
if ((adata_type == CV_16F || adata_type == CV_16BF) && !enableFP16)
|
||||
adata_type = CV_32F;
|
||||
// [TODO] need to analyze this situation more carefully
|
||||
if (adata_type == CV_64F)
|
||||
adata_type = CV_32F;
|
||||
|
||||
if (adata_type != mtype &&
|
||||
!((adata_type == CV_64F || adata_type == CV_32F || adata_type == CV_16F || adata_type == CV_16BF) &&
|
||||
(mtype == CV_64F || mtype == CV_32F || mtype == CV_16F || mtype == CV_16BF)))
|
||||
(mtype == CV_64F || mtype == CV_32F || mtype == CV_16F || mtype == CV_16BF)) &&
|
||||
!(adata.type == CV_16BF && mtype == CV_16U) && !(adata.type == CV_16F && mtype == CV_16U))
|
||||
{
|
||||
CV_Error_(Error::StsBadArg, ("incompatible type of input tensor #%zu '%s': %s given, %s expected",
|
||||
idx, adata.name.c_str(), typeToString(mtype).c_str(),
|
||||
|
|
@ -574,7 +573,21 @@ void Net::Impl::setGraphInput(Ptr<Graph>& graph, size_t idx, const Mat& m)
|
|||
if (inp_t.shape() != mshape || inp_t.type() != adata_type)
|
||||
finalizeLayers = true;
|
||||
inp_t.fit(mshape, adata_type);
|
||||
m.convertTo(inp_t, adata_type);
|
||||
|
||||
if (adata.type == CV_16BF && mtype == CV_16U)
|
||||
{
|
||||
Mat tmp(mshape, CV_16BF, (void*)m.data);
|
||||
tmp.convertTo(inp_t, adata_type);
|
||||
}
|
||||
else if (adata.type == CV_16F && mtype == CV_16U)
|
||||
{
|
||||
Mat tmp(mshape, CV_16F, (void*)m.data);
|
||||
tmp.convertTo(inp_t, adata_type);
|
||||
}
|
||||
else
|
||||
{
|
||||
m.convertTo(inp_t, adata_type);
|
||||
}
|
||||
} else if (adata.kind == DNN_ARG_TEMP) {
|
||||
int bufidx = bufidxs.at(inp.idx);
|
||||
Mat& buf = buffers.at(bufidx);
|
||||
|
|
|
|||
|
|
@ -1813,12 +1813,35 @@ Mat getMatFromTensor(const opencv_onnx::TensorProto& tensor_proto, bool uint8ToI
|
|||
Mat(sizes, CV_16FC1, rawdata).convertTo(blob, CV_32FC1);
|
||||
}
|
||||
}
|
||||
else if (datatype == opencv_onnx::TensorProto_DataType_BFLOAT16)
|
||||
{
|
||||
if (!tensor_proto.raw_data().empty())
|
||||
{
|
||||
blob.create((int)sizes.size(), sizes.data(), CV_16BFC1);
|
||||
size_t bytes = (size_t)blob.total() * blob.elemSize();
|
||||
memcpy(blob.data, rawdata, bytes);
|
||||
}
|
||||
else if (!tensor_proto.int32_data().empty())
|
||||
{
|
||||
const auto& v = tensor_proto.int32_data();
|
||||
blob.create((int)sizes.size(), sizes.data(), CV_16BFC1);
|
||||
uint16_t* dst = reinterpret_cast<uint16_t*>(blob.data);
|
||||
for (size_t i = 0; i < v.size(); ++i)
|
||||
{
|
||||
dst[i] = static_cast<uint16_t>(v[i] & 0xFFFF);
|
||||
}
|
||||
}
|
||||
else
|
||||
{
|
||||
CV_Error(Error::StsNotImplemented, "BFLOAT16 tensor has no raw_data");
|
||||
}
|
||||
}
|
||||
else if (datatype == opencv_onnx::TensorProto_DataType_DOUBLE)
|
||||
{
|
||||
if (!tensor_proto.double_data().empty())
|
||||
Mat(sizes, CV_64FC1, (void*)tensor_proto.double_data().data()).convertTo(blob, CV_32FC1);
|
||||
else
|
||||
Mat(sizes, CV_64FC1, rawdata).convertTo(blob, CV_32FC1);
|
||||
Mat(sizes, CV_64FC1, rawdata).copyTo(blob);
|
||||
}
|
||||
else if (datatype == opencv_onnx::TensorProto_DataType_INT32)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ static int dataType2cv(int dt)
|
|||
dt == opencv_onnx::TensorProto_DataType_FLOAT ? CV_32F :
|
||||
dt == opencv_onnx::TensorProto_DataType_DOUBLE ? CV_64F :
|
||||
dt == opencv_onnx::TensorProto_DataType_FLOAT16 ? CV_16F :
|
||||
dt == opencv_onnx::TensorProto_DataType_BFLOAT16 ? CV_16BF :
|
||||
dt == opencv_onnx::TensorProto_DataType_COMPLEX64 ? CV_32FC2 :
|
||||
dt == opencv_onnx::TensorProto_DataType_COMPLEX128 ? CV_64FC2 :
|
||||
dt == opencv_onnx::TensorProto_DataType_BOOL ? CV_Bool : -1;
|
||||
|
|
@ -95,6 +96,7 @@ static std::string dataType2str(int dt)
|
|||
dt == opencv_onnx::TensorProto_DataType_INT64 ? "INT64" :
|
||||
dt == opencv_onnx::TensorProto_DataType_FLOAT ? "FLOAT" :
|
||||
dt == opencv_onnx::TensorProto_DataType_FLOAT16 ? "FLOAT16" :
|
||||
dt == opencv_onnx::TensorProto_DataType_BFLOAT16 ? "BFLOAT16" :
|
||||
dt == opencv_onnx::TensorProto_DataType_BOOL ? "BOOL" :
|
||||
dt == opencv_onnx::TensorProto_DataType_COMPLEX64 ? "COMPLEX64" :
|
||||
dt == opencv_onnx::TensorProto_DataType_COMPLEX128 ? "COMPLEX128" : nullptr;
|
||||
|
|
@ -174,6 +176,8 @@ protected:
|
|||
void parseAveragePool (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseBatchNormalization (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseCast (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseCast2 (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseCastLike (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseClip (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseConcat (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
void parseIf (LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto);
|
||||
|
|
@ -1478,13 +1482,16 @@ void ONNXImporter2::parseShape(LayerParams& layerParams, const opencv_onnx::Node
|
|||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
void ONNXImporter2::parseCast(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
void ONNXImporter2::parseCast2(LayerParams& layerParams, const opencv_onnx::NodeProto& node_proto)
|
||||
{
|
||||
opencv_onnx::TensorProto_DataType onnx_type = (opencv_onnx::TensorProto_DataType)layerParams.get<int>("to");
|
||||
int type = dataType2cv(onnx_type);
|
||||
layerParams.type = "Cast2";
|
||||
addLayer(layerParams, node_proto);
|
||||
}
|
||||
|
||||
layerParams.type = "Cast";
|
||||
layerParams.set("outputType", type);
|
||||
// ONNX CastLike: validates that the node has exactly two inputs and registers it as a
// "Cast2" layer.
void ONNXImporter2::parseCastLike(LayerParams& layerParams,
                                  const opencv_onnx::NodeProto& node_proto)
{
    // Reject malformed nodes before the layer parameters are touched.
    CV_CheckEQ(node_proto.input_size(), 2, "CastLike requires two inputs");

    layerParams.type = "Cast2";
    addLayer(layerParams, node_proto);
}
|
||||
|
||||
|
|
@ -2539,7 +2546,8 @@ void ONNXImporter2::buildDispatchMap_ONNX_AI(int opset_version)
|
|||
dispatch["Reshape"] = &ONNXImporter2::parseReshape;
|
||||
dispatch["Pad"] = &ONNXImporter2::parsePad;
|
||||
dispatch["Shape"] = &ONNXImporter2::parseShape;
|
||||
dispatch["Cast"] = &ONNXImporter2::parseCast;
|
||||
dispatch["Cast"] = &ONNXImporter2::parseCast2;
|
||||
dispatch["CastLike"] = &ONNXImporter2::parseCastLike;
|
||||
dispatch["ConstantFill"] = dispatch["ConstantOfShape"] = &ONNXImporter2::parseConstantOfShape;
|
||||
dispatch["Gather"] = &ONNXImporter2::parseGather;
|
||||
dispatch["GatherElements"] = &ONNXImporter2::parseGatherElements;
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@
|
|||
"test_cast_STRING_to_FLOAT",
|
||||
"test_castlike_FLOAT_to_STRING_expanded",
|
||||
"test_castlike_STRING_to_FLOAT_expanded",
|
||||
"test_cast_DOUBLE_to_FLOAT",
|
||||
"test_concat_1d_axis_negative_1",
|
||||
"test_conv_with_autopad_same",
|
||||
"test_conv_with_strides_and_asymmetric_padding",
|
||||
|
|
@ -56,6 +57,10 @@
|
|||
"test_div_bcast",
|
||||
"test_div_uint8",
|
||||
"test_dropout_default_ratio",
|
||||
"test_einsum_batch_diagonal",
|
||||
"test_einsum_batch_matmul",
|
||||
"test_einsum_sum",
|
||||
"test_einsum_transpose",
|
||||
"test_flatten_axis0",
|
||||
"test_flatten_axis2",
|
||||
"test_flatten_axis3",
|
||||
|
|
@ -71,6 +76,9 @@
|
|||
"test_maxpool_with_argmax_2d_precomputed_pads",
|
||||
"test_maxpool_with_argmax_2d_precomputed_strides",
|
||||
"test_maxunpool_export_with_output_shape",
|
||||
"test_max_float64",
|
||||
"test_min_float64",
|
||||
"test_mod_mixed_sign_float64",
|
||||
"test_mul_bcast",
|
||||
"test_mul_uint8",
|
||||
"test_softmax_default_axis",
|
||||
|
|
|
|||
|
|
@ -1,10 +1,22 @@
|
|||
"test_basic_conv_with_padding", // (assert failed) !blobs.empty() in initCUDA
|
||||
"test_basic_conv_without_padding", // (assert failed) !blobs.empty() in initCUDA
|
||||
"test_cast_DOUBLE_to_FLOAT",
|
||||
"test_conv_with_autopad_same", // (assert failed) !blobs.empty() in initCUDA
|
||||
"test_conv_with_strides_and_asymmetric_padding", // (assert failed) !blobs.empty() in initCUDA
|
||||
"test_conv_with_strides_no_padding", // (assert failed) !blobs.empty() in initCUDA
|
||||
"test_conv_with_strides_padding", // (assert failed) !blobs.empty() in initCUDA
|
||||
"test_cumsum_1d",
|
||||
"test_cumsum_1d_exclusive",
|
||||
"test_cumsum_1d_reverse",
|
||||
"test_cumsum_1d_reverse_exclusive",
|
||||
"test_cumsum_2d_axis_0",
|
||||
"test_cumsum_2d_axis_1",
|
||||
"test_cumsum_2d_negative_axis",
|
||||
"test_dropout_default_ratio",
|
||||
"test_einsum_batch_diagonal",
|
||||
"test_einsum_batch_matmul",
|
||||
"test_einsum_sum",
|
||||
"test_einsum_transpose",
|
||||
"test_logsoftmax_large_number", // fp16 accuracy issue
|
||||
"test_logsoftmax_large_number_expanded", // fp16 accuracy issue
|
||||
"test_maxpool_with_argmax_2d_precomputed_pads", // assertion failed mat.type() == CV_32F
|
||||
|
|
@ -23,3 +35,6 @@
|
|||
"test_quantizelinear", // Issue https://github.com/opencv/opencv/issues/25999
|
||||
"test_quantizelinear_axis", // Issue https://github.com/opencv/opencv/issues/25999
|
||||
"test_quantizelinear_blocked", // Issue https://github.com/opencv/opencv/issues/25999
|
||||
"test_max_float64",
|
||||
"test_min_float64",
|
||||
"test_mod_mixed_sign_float64",
|
||||
|
|
|
|||
|
|
@ -289,21 +289,21 @@ CASE(test_bitshift_right_uint64)
|
|||
CASE(test_bitshift_right_uint8)
|
||||
SKIP;
|
||||
CASE(test_cast_BFLOAT16_to_FLOAT)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_cast_DOUBLE_to_FLOAT)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_cast_DOUBLE_to_FLOAT16)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_cast_FLOAT16_to_DOUBLE)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_cast_FLOAT16_to_FLOAT)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_cast_FLOAT_to_BFLOAT16)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_cast_FLOAT_to_DOUBLE)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_cast_FLOAT_to_FLOAT16)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_cast_FLOAT_to_STRING)
|
||||
#if SKIP_SET_1
|
||||
SKIP;
|
||||
|
|
@ -313,37 +313,37 @@ CASE(test_cast_STRING_to_FLOAT)
|
|||
SKIP;
|
||||
#endif
|
||||
CASE(test_castlike_BFLOAT16_to_FLOAT)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_BFLOAT16_to_FLOAT_expanded)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_DOUBLE_to_FLOAT)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_DOUBLE_to_FLOAT16)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_DOUBLE_to_FLOAT16_expanded)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_DOUBLE_to_FLOAT_expanded)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT16_to_DOUBLE)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT16_to_DOUBLE_expanded)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT16_to_FLOAT)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT16_to_FLOAT_expanded)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT_to_BFLOAT16)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT_to_BFLOAT16_expanded)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT_to_DOUBLE)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT_to_DOUBLE_expanded)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT_to_FLOAT16)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT_to_FLOAT16_expanded)
|
||||
// no filter
|
||||
SKIP;
|
||||
CASE(test_castlike_FLOAT_to_STRING)
|
||||
// no filter
|
||||
CASE(test_castlike_FLOAT_to_STRING_expanded)
|
||||
|
|
|
|||
|
|
@ -280,3 +280,30 @@
|
|||
"test_div_uint64",
|
||||
"test_cumsum_1d_int32_exclusive",
|
||||
"test_cumsum_2d_int32",
|
||||
"test_cast_BFLOAT16_to_FLOAT",
|
||||
"test_cast_DOUBLE_to_FLOAT16",
|
||||
"test_cast_FLOAT16_to_DOUBLE",
|
||||
"test_cast_FLOAT16_to_FLOAT",
|
||||
"test_cast_FLOAT_to_BFLOAT16",
|
||||
"test_cast_FLOAT_to_DOUBLE",
|
||||
"test_cast_FLOAT_to_FLOAT16",
|
||||
"test_castlike_BFLOAT16_to_FLOAT",
|
||||
"test_castlike_BFLOAT16_to_FLOAT_expanded",
|
||||
"test_castlike_DOUBLE_to_FLOAT",
|
||||
"test_castlike_DOUBLE_to_FLOAT16",
|
||||
"test_castlike_DOUBLE_to_FLOAT16_expanded",
|
||||
"test_castlike_DOUBLE_to_FLOAT_expanded",
|
||||
"test_castlike_FLOAT16_to_DOUBLE",
|
||||
"test_castlike_FLOAT16_to_DOUBLE_expanded",
|
||||
"test_castlike_FLOAT16_to_FLOAT",
|
||||
"test_castlike_FLOAT16_to_FLOAT_expanded",
|
||||
"test_castlike_FLOAT_to_BFLOAT16",
|
||||
"test_castlike_FLOAT_to_BFLOAT16_expanded",
|
||||
"test_castlike_FLOAT_to_DOUBLE",
|
||||
"test_castlike_FLOAT_to_DOUBLE_expanded",
|
||||
"test_castlike_FLOAT_to_FLOAT16",
|
||||
"test_castlike_FLOAT_to_FLOAT16_expanded",
|
||||
"test_gelu_default_1_expanded",
|
||||
"test_gelu_default_2_expanded",
|
||||
"test_gelu_tanh_1_expanded",
|
||||
"test_gelu_tanh_2_expanded",
|
||||
|
|
|
|||
|
|
@ -183,10 +183,6 @@
|
|||
"test_blackmanwindow_expanded",
|
||||
"test_blackmanwindow_symmetric",
|
||||
"test_blackmanwindow_symmetric_expanded",
|
||||
"test_cast_BFLOAT16_to_FLOAT", // Issue::Unsuppoted data type
|
||||
"test_cast_DOUBLE_to_FLOAT16", // Issue::Unsuppoted data type
|
||||
"test_cast_FLOAT16_to_DOUBLE", // Issue::Unsuppoted data type
|
||||
"test_cast_FLOAT16_to_FLOAT", // Issue::Unsuppoted data type
|
||||
"test_cast_FLOAT16_to_FLOAT4E2M1",
|
||||
"test_cast_FLOAT16_to_FLOAT8E4M3FN",
|
||||
"test_cast_FLOAT16_to_FLOAT8E4M3FNUZ",
|
||||
|
|
@ -204,9 +200,6 @@
|
|||
"test_cast_FLOAT8E5M2FNUZ_to_FLOAT16",
|
||||
"test_cast_FLOAT8E5M2_to_FLOAT",
|
||||
"test_cast_FLOAT8E5M2_to_FLOAT16",
|
||||
"test_cast_FLOAT_to_BFLOAT16", // Issue::Unsuppoted data type
|
||||
"test_cast_FLOAT_to_DOUBLE", // Issue::Unsuppoted data type
|
||||
"test_cast_FLOAT_to_FLOAT16", // Issue::Unsuppoted data type
|
||||
"test_cast_FLOAT_to_FLOAT4E2M1",
|
||||
"test_cast_FLOAT_to_FLOAT8E4M3FN",
|
||||
"test_cast_FLOAT_to_FLOAT8E4M3FNUZ",
|
||||
|
|
@ -232,15 +225,6 @@
|
|||
"test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ",
|
||||
"test_cast_no_saturate_FLOAT_to_FLOAT8E5M2",
|
||||
"test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ",
|
||||
"test_castlike_BFLOAT16_to_FLOAT", // Issue::Unsuppoted data type
|
||||
"test_castlike_BFLOAT16_to_FLOAT_expanded", // Issue::Unsuppoted data type
|
||||
"test_castlike_DOUBLE_to_FLOAT", // Issues::Layer::Can't create layer "onnx_node_output_0!output" of type "CastLike" in function 'getLayerInstance'
|
||||
"test_castlike_DOUBLE_to_FLOAT16", // Issues::Layer::Can't create layer "onnx_node_output_0!output" of type "CastLike" in function 'getLayerInstance'
|
||||
"test_castlike_DOUBLE_to_FLOAT16_expanded", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
|
||||
"test_castlike_DOUBLE_to_FLOAT_expanded", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
|
||||
"test_castlike_FLOAT16_to_DOUBLE", // Issue::Unsuppoted data type
|
||||
"test_castlike_FLOAT16_to_DOUBLE_expanded", // Issue::Unsuppoted data type
|
||||
"test_castlike_FLOAT16_to_FLOAT", // Issues::Layer::Can't create layer "onnx_node_output_0!output" of type "CastLike" in function 'getLayerInstance'
|
||||
"test_castlike_FLOAT16_to_FLOAT4E2M1",
|
||||
"test_castlike_FLOAT16_to_FLOAT4E2M1_expanded",
|
||||
"test_castlike_FLOAT16_to_FLOAT8E4M3FN",
|
||||
|
|
@ -251,7 +235,6 @@
|
|||
"test_castlike_FLOAT16_to_FLOAT8E5M2FNUZ",
|
||||
"test_castlike_FLOAT16_to_FLOAT8E5M2FNUZ_expanded",
|
||||
"test_castlike_FLOAT16_to_FLOAT8E5M2_expanded",
|
||||
"test_castlike_FLOAT16_to_FLOAT_expanded", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
|
||||
"test_castlike_FLOAT16_to_INT4",
|
||||
"test_castlike_FLOAT16_to_INT4_expanded",
|
||||
"test_castlike_FLOAT16_to_UINT4",
|
||||
|
|
@ -276,12 +259,6 @@
|
|||
"test_castlike_FLOAT8E5M2_to_FLOAT16",
|
||||
"test_castlike_FLOAT8E5M2_to_FLOAT16_expanded",
|
||||
"test_castlike_FLOAT8E5M2_to_FLOAT_expanded",
|
||||
"test_castlike_FLOAT_to_BFLOAT16", // Issue::Unsuppoted data type
|
||||
"test_castlike_FLOAT_to_BFLOAT16_expanded", // Issue::Unsuppoted data type
|
||||
"test_castlike_FLOAT_to_DOUBLE", // Issues::Layer::Can't create layer "onnx_node_output_0!output" of type "CastLike" in function 'getLayerInstance'
|
||||
"test_castlike_FLOAT_to_DOUBLE_expanded", // Issue::Unsuppoted data type
|
||||
"test_castlike_FLOAT_to_FLOAT16", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
|
||||
"test_castlike_FLOAT_to_FLOAT16_expanded", // Issues::Layer::mismatch in input and output shapes inputs.size() == requiredOutputs in function 'getMemoryShapes'
|
||||
"test_castlike_FLOAT_to_FLOAT4E2M1",
|
||||
"test_castlike_FLOAT_to_FLOAT4E2M1_expanded",
|
||||
"test_castlike_FLOAT_to_FLOAT8E4M3FN",
|
||||
|
|
@ -402,10 +379,6 @@
|
|||
"test_eyelike_populate_off_main_diagonal", // Issues::Layer::Can't create layer::Can't create layer "onnx_node_output_0!y" of type "EyeLike" in function 'getLayerInstance'
|
||||
"test_eyelike_with_dtype", // ---- same as above ---
|
||||
"test_eyelike_without_dtype", // ---- same as above ---
|
||||
"test_gelu_default_1_expanded", // parser: no corresponding layer for CastLike
|
||||
"test_gelu_default_2_expanded", // parser: no corresponding layer for CastLike
|
||||
"test_gelu_tanh_1_expanded", // parser: no corresponding layer for CastLike
|
||||
"test_gelu_tanh_2_expanded", // parser: no corresponding layer for CastLike
|
||||
"test_gridsample_bicubic", // ---- same as above ---
|
||||
"test_gridsample_bicubic_align_corners_0_additional_1",
|
||||
"test_gridsample_bicubic_align_corners_1_additional_1",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user