MKL-DNN open source integration. (#13135)
* MKL-DNN conv and build integration
* Adding new files that were mistakenly missing from the PR
* Minor change in the pip package build file
* Added missing #include
* Fixed a linking failure when running the bazel test
* Fixing BUILD file format
* Using -fopenmp for building mkl_dnn only when running on linux
* Fixing build rule attribute value
* Removing unnecessary deps from mkl test rule
* Removed deps on mkl-dnn when not building with --config=mkl
parent f1a1099232, commit 6af7ab97ac
@@ -1772,6 +1772,7 @@ tf_cuda_library(
     ) + if_mkl(
         [
             "//third_party/mkl:intel_binary_blob",
+            "@mkl_dnn//:mkl_dnn",
         ],
     ),
     alwayslink = 1,
@@ -1932,7 +1933,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
     "common_runtime/visitable_allocator.h",
     "graph/gradients.h",
     "graph/quantize_training.h",
-]
+] + if_mkl(["graph/mkl_graph_util.h"])

 tf_cuda_library(
     name = "core_cpu_impl",
@@ -2033,7 +2034,10 @@ tf_cuda_library(
         "//third_party/eigen3",
         "//tensorflow/core/kernels:required",
     ] + if_mkl(
-        ["//third_party/mkl:intel_binary_blob"],
+        [
+            "//third_party/mkl:intel_binary_blob",
+            "@mkl_dnn//:mkl_dnn",
+        ],
     ) + tf_additional_core_deps() + if_static([":core_cpu_impl"]),
     alwayslink = 1,
 )
@@ -2669,7 +2673,7 @@ tf_cc_test_mkl(
         "graph/mkl_layout_pass_test.cc",
         "graph/mkl_tfconversion_pass_test.cc",
     ],
-    linkstatic = tf_kernel_tests_linkstatic(),
+    linkstatic = 1,
     deps = [
         ":core",
         ":core_cpu",
@@ -2687,18 +2691,6 @@ tf_cc_test_mkl(
         "//tensorflow/cc:cc_ops",
         "//tensorflow/cc:scope",
         "//tensorflow/cc:sendrecv_ops",
-        "//tensorflow/core/kernels:mkl_aggregate_ops",
-        "//tensorflow/core/kernels:mkl_concat_op",
-        "//tensorflow/core/kernels:mkl_conv_op",
-        "//tensorflow/core/kernels:mkl_cwise_ops_common",
-        "//tensorflow/core/kernels:mkl_fused_batch_norm_op",
-        "//tensorflow/core/kernels:mkl_identity_op",
-        "//tensorflow/core/kernels:mkl_input_conversion_op",
-        "//tensorflow/core/kernels:mkl_lrn_op",
-        "//tensorflow/core/kernels:mkl_pooling_ops",
-        "//tensorflow/core/kernels:mkl_relu_op",
-        "//tensorflow/core/kernels:mkl_reshape_op",
-        "//tensorflow/core/kernels:mkl_tfconv_op",
         "//tensorflow/core/kernels:ops_util",
         "//third_party/eigen3",
     ],
tensorflow/core/graph/mkl_graph_util.h (new file, 129 lines)

@@ -0,0 +1,129 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
#define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
#ifdef INTEL_MKL

#include <string>
#include "tensorflow/core/framework/op_kernel.h"

namespace tensorflow {
// Since our ops are going to produce and also consume N additional tensors
// (Mkl) for N Tensorflow tensors, we can have the following different
// orderings among these 2N tensors.
//
// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
// consume A_m, B_m, and C_m additionally.
//
// INTERLEAVED: in this case the 2N tensors are interleaved. For the above
//              example, the ordering looks like: A, A_m, B, B_m, C, C_m.
//
// CONTIGUOUS: in this case the N Tensorflow tensors are contiguous, followed
//             by the N Mkl tensors. For the above example, the ordering looks
//             like: A, B, C, A_m, B_m, C_m.
//
// The following APIs map the index of an original Tensorflow tensor to its
// appropriate position based on the selected ordering. For contiguous
// ordering, we need to know the total number of tensors (parameter total).
//
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
// NOTE: Currently, we use contiguous ordering. If you change this, then you
// would need to change Mkl op definitions in nn_ops.cc.
static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;

// Get the index of the MetaData tensor from index 'n' of the Data tensor.
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    // For interleaved ordering, the Mkl tensor follows immediately after
    // the Tensorflow tensor.
    return n + 1;
  } else {
    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
    // For contiguous ordering, the Mkl tensor is total_tensors / 2
    // positions away.
    return n + total_tensors / 2;
  }
}

int inline GetTensorDataIndex(int n, int total_tensors) {
  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    return 2 * n;  // index corresponding to nth input/output tensor
  } else {
    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
    return n;
  }
}

int inline GetTensorMetaDataIndex(int n, int total_tensors) {
  // Get the index for TensorData first, and then use the mapping function
  // to get the TensorMetaData index from the TensorData index.
  int tidx = GetTensorDataIndex(n, total_tensors);
  return DataIndexToMetaDataIndex(tidx, total_tensors);
}

namespace mkl_op_registry {
static const char* kMklOpLabel = "MklOp";
static const char* kMklOpLabelPattern = "label='MklOp'";

// Get the name of the Mkl op from the original TensorFlow op:
// we prefix the original op name with '_Mkl'.
inline string GetMklOpName(const string& name) {
  // Prefix that we add to the Tensorflow op name to construct the Mkl op
  // name.
  const char* const kMklOpPrefix = "_Mkl";
  return string(kMklOpPrefix) + name;
}

// Check whether opname with type T is registered as MKL-compliant.
//
// @input: name of the op
// @input: T datatype to be used for checking op
// @return: true if opname is registered as Mkl op; false otherwise
static inline bool IsMklOp(const std::string& op_name, DataType T) {
  string kernel = KernelsRegisteredForOp(op_name);
  bool result =
      kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
  if (result) {
    VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
  }
  return result;
}

// Check whether opname with type T is registered as an MKL-compliant op
// that is also element-wise.
//
// @input: name of the op
// @input: T datatype to be used for checking op
// @return: true if opname is registered as an element-wise Mkl op;
//          false otherwise
static inline bool IsMklElementWiseOp(const std::string& op_name,
                                      DataType T) {
  if (!IsMklOp(op_name, T)) {
    return false;
  }

  bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
                 0 == op_name.compare(GetMklOpName("Sub")) ||
                 0 == op_name.compare(GetMklOpName("Mul")) ||
                 0 == op_name.compare(GetMklOpName("Maximum")) ||
                 0 == op_name.compare(GetMklOpName("SquaredDifference")));

  VLOG(1) << "mkl_op_registry::" << op_name
          << " is elementwise MKL op: " << result;
  return result;
}
}  // namespace mkl_op_registry
}  // namespace tensorflow
#endif  // INTEL_MKL
#endif  // TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
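For intuition, the contiguous mapping above can be exercised in isolation. The following is a minimal standalone sketch (not part of the commit; the re-declared helpers and the main driver are illustrative) showing where the metadata tensors land for three TensorFlow tensors A, B, C:

#include <cassert>

// Re-implementation of the contiguous-ordering helpers above, for a
// standalone demonstration with 3 data + 3 metadata tensors (total = 6).
static int GetTensorDataIndex(int n, int /*total_tensors*/) { return n; }
static int DataIndexToMetaDataIndex(int n, int total_tensors) {
  return n + total_tensors / 2;
}

int main() {
  // Contiguous layout: A, B, C, A_m, B_m, C_m.
  assert(GetTensorDataIndex(0, 6) == 0);        // A sits in slot 0
  assert(DataIndexToMetaDataIndex(0, 6) == 3);  // A_m sits in slot 3
  assert(DataIndexToMetaDataIndex(GetTensorDataIndex(2, 6), 6) == 5);  // C_m
  return 0;
}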
@@ -38,7 +38,7 @@ limitations under the License.
 #include "tensorflow/core/util/tensor_format.h"

 #include "tensorflow/core/graph/mkl_layout_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"

 namespace tensorflow {
@@ -16,7 +16,7 @@ limitations under the License.
 #ifdef INTEL_MKL

 #include "tensorflow/core/graph/mkl_layout_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"

 #include <algorithm>
 #include <string>
@@ -34,7 +34,7 @@ limitations under the License.
 #include "tensorflow/core/platform/logging.h"

 #include "tensorflow/core/graph/mkl_tfconversion_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"

 namespace tensorflow {
@@ -16,7 +16,7 @@ limitations under the License.
 #ifdef INTEL_MKL

 #include "tensorflow/core/graph/mkl_tfconversion_pass.h"
-#include "tensorflow/core/util/mkl_util.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"

 #include <algorithm>
 #include <string>
@@ -820,6 +820,7 @@ tf_kernel_library(
     hdrs = ["transpose_op.h"],
     deps = ARRAY_DEPS + if_mkl([
         "//third_party/mkl:intel_binary_blob",
+        "@mkl_dnn//:mkl_dnn",
     ]),
 )
@@ -2596,6 +2597,7 @@ tf_kernel_library(
         "//conditions:default": [],
     }) + if_mkl([
         "//third_party/mkl:intel_binary_blob",
+        "@mkl_dnn//:mkl_dnn",
     ]) + if_cuda([
         "//tensorflow/core/platform/default/build_config:cublas_plugin",
     ]),
@@ -5501,8 +5503,10 @@ tf_mkl_kernel_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:nn_ops_op_lib",
+    ] + if_mkl([
         "//third_party/mkl:intel_binary_blob",
-    ],
+        "@mkl_dnn//:mkl_dnn",
+    ]),
 )

 tf_mkl_kernel_library(
@@ -5516,8 +5520,10 @@ tf_mkl_kernel_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:nn_ops_op_lib",
+    ] + if_mkl([
         "//third_party/mkl:intel_binary_blob",
-    ],
+        "@mkl_dnn//:mkl_dnn",
+    ]),
 )

 tf_mkl_kernel_library(
@@ -5566,16 +5572,19 @@ tf_mkl_kernel_library(
         "//tensorflow/core:lib",
         "//tensorflow/core:lib_internal",
         "//tensorflow/core:nn_ops_op_lib",
+    ] + if_mkl([
         "//third_party/mkl:intel_binary_blob",
-    ],
+        "@mkl_dnn//:mkl_dnn",
+    ]),
 )

 tf_mkl_kernel_library(
     name = "mkl_fused_batch_norm_op",
     srcs = ["mkl_fused_batch_norm_op.cc"],
-    deps = NN_DEPS + [
+    deps = NN_DEPS + if_mkl([
         "//third_party/mkl:intel_binary_blob",
-    ],
+        "@mkl_dnn//:mkl_dnn",
+    ]),
 )

 tf_mkl_kernel_library(
@@ -5589,9 +5598,10 @@ tf_mkl_kernel_library(
 tf_mkl_kernel_library(
     name = "mkl_concat_op",
     prefix = "mkl_concat_op",
-    deps = ARRAY_DEPS + [
+    deps = ARRAY_DEPS + if_mkl([
         "//third_party/mkl:intel_binary_blob",
-    ],
+        "@mkl_dnn//:mkl_dnn",
+    ]),
 )

 tf_mkl_kernel_library(
@@ -5605,17 +5615,19 @@ tf_mkl_kernel_library(
 tf_mkl_kernel_library(
     name = "mkl_identity_op",
     prefix = "mkl_identity_op",
-    deps = ARRAY_DEPS + [
+    deps = ARRAY_DEPS + if_mkl([
         "//third_party/mkl:intel_binary_blob",
-    ],
+        "@mkl_dnn//:mkl_dnn",
+    ]),
 )

 tf_mkl_kernel_library(
     name = "mkl_lrn_op",
     prefix = "mkl_lrn_op",
-    deps = NN_DEPS + [
+    deps = NN_DEPS + if_mkl([
         "//third_party/mkl:intel_binary_blob",
-    ],
+        "@mkl_dnn//:mkl_dnn",
+    ]),
 )

 tf_mkl_kernel_library(
@@ -27,6 +27,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_slice.h"
 #include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/mkl_conv_ops.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -41,10 +42,24 @@ limitations under the License.
 #include "mkl_dnn.h"
 #include "mkl_dnn_types.h"

+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::stream;
+using mkldnn::prop_kind;
+
+using mkldnn::convolution_forward;
+using mkldnn::convolution_backward_weights;
+using mkldnn::convolution_direct;
+
+#endif
+
 namespace tensorflow {

 typedef Eigen::ThreadPoolDevice CPUDevice;

+#ifndef INTEL_MKL_DNN
+
 template <typename Device, class T>
 class MklConv2DCustomBackpropFilterOp : public OpKernel {
  public:
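The guards above select one of two implementations at build time. A hedged sketch of the pattern (the macro names come from this commit; the stub classes are illustrative only):

// Illustrative only: how these kernels pick an implementation at compile
// time. INTEL_MKL is defined when building with --config=mkl; INTEL_MKL_DNN
// additionally selects the open-source MKL-DNN path over MKL-ML.
#ifdef INTEL_MKL

#ifndef INTEL_MKL_DNN
// Default today: kernel implemented against the MKL-ML blob (mkl_dnn.h).
class MklKernelSketch { /* MKL-ML version */ };
#else
// Opt-in: kernel implemented against open-source MKL-DNN (mkldnn.hpp).
class MklKernelSketch { /* MKL-DNN version */ };
#endif  // INTEL_MKL_DNN

#endif  // INTEL_MKL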
@@ -411,6 +426,174 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
  TensorFormat data_format_;
};

#else

template <typename Device, class T>
class MklConv2DCustomBackpropFilterOp : public OpKernel {
 public:
  explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));

    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);

      MklDnnData<T> input(&cpu_engine);
      MklDnnData<T> outbackprop(&cpu_engine);
      MklDnnData<T> output(&cpu_engine);

      // Input tensors
      const Tensor& input_tensor = MklGetInput(context, 0);
      const Tensor& filter_tensor = MklGetInput(context, 1);
      const Tensor& obp_tensor = MklGetInput(context, 2);  // Outbackprop

      // Generate input shapes.
      TensorShape filter_shape;
      OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()),
                  errors::InvalidArgument(
                      "Conv2DBackpropFilter: filter_sizes input must be "
                      "1-dim, not ", filter_tensor.dims()));
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  filter_tensor.vec<int32>(), &filter_shape));
      TensorShape input_shape = input_tensor.shape();
      TensorShape obp_shape = obp_tensor.shape();

      // By default, all dims are in MKL order. Only dims in TF order
      // are those with the prefix tf_order.
      memory::dims obp_dims, fwd_input_dims, fwd_filter_dims;
      memory::dims padding_l, padding_r, strides, fwd_output_dims;
      memory::dims fwd_output_dims_tf_order;

      // Get forward convolution parameters.
      MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
      conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape,
                                         &fwd_input_dims, &fwd_filter_dims,
                                         &strides,
                                         &fwd_output_dims_tf_order,
                                         &fwd_output_dims,
                                         &padding_l, &padding_r);
      if (!context->status().ok()) return;

      // Create the Convolution forward descriptor since the Convolution
      // backward API needs it. For that, we first need to create input,
      // filter and output memory descriptors.
      auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
      auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(),
                                     mkl_data_format);
      auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(),
                                        memory::format::hwio);
      auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(),
                                     mkl_data_format);
      auto fwd_desc = convolution_forward::desc(prop_kind::forward,
          convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md,
          strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
      auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);

      // Allocate output tensor and shape.
      // TODO(nhasabni): Update this when support for MKL layout is added.
      // Shape of the output of Conv2DBackpropFilter is same as the shape of
      // the filter.
      TensorShape tf_output_shape(filter_shape);
      MklShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(false);
      Tensor* output_tensor = nullptr;
      AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
                                mkl_output_mkl_shape);

      // Create memory for user data.
      // Describe how the inputs and outputs of Convolution look like. Also
      // specify the buffers containing the actual input and output data.
      // Although the input shape required is in MKL-DNN order, the layout is
      // Tensorflow's layout (NHWC or NCHW depending on data format).
      input.SetUsrMem(fwd_input_dims, mkl_data_format, &input_tensor);
      // Outbackprop shape is NHWC or NCHW depending on data format. Since
      // the GetInputSizeInMklOrder function returns sizes in that order, we
      // use that function directly.
      conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims);
      if (!context->status().ok()) return;
      outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor);
      // Although the output shape required is in MKL-DNN order, the layout
      // is Tensorflow's filter layout (HWIO). Shape of the output of
      // Conv2DBackpropFilter is same as the shape of the filter.
      memory::dims bwd_output_dims = fwd_filter_dims;
      output.SetUsrMem(bwd_output_dims, memory::format::hwio, output_tensor);

      // Create memory descriptors for convolution data w/ no specified
      // format.
      input.SetOpMemDesc(fwd_input_dims, memory::format::any);
      outbackprop.SetOpMemDesc(obp_dims, memory::format::any);
      output.SetOpMemDesc(bwd_output_dims, memory::format::any);

      // Create convolution backward weights primitive.
      auto bwd_desc = convolution_backward_weights::desc(convolution_direct,
          input.GetOpMemDesc(), output.GetOpMemDesc(),
          outbackprop.GetOpMemDesc(), strides, padding_l,
          padding_r, TFPaddingToMklDnnPadding(padding_));

      auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
                                                                 cpu_engine,
                                                                 fwd_pd);

      PrepareAndExecutePrimitive(bwd_pd, &input, &outbackprop, &output);
    } catch (mkldnn::error &e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) +
                         ", in file " + string(__FILE__) + ":" +
                         std::to_string(__LINE__);
      OP_REQUIRES_OK(context,
                     errors::Aborted("Operation received an exception:",
                                     error_msg));
    }
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  // Prepare and execute net - checks for input and output reorders.
  void PrepareAndExecutePrimitive(
      const convolution_backward_weights::primitive_desc& conv_pd,
      MklDnnData<T>* input, MklDnnData<T>* obp, MklDnnData<T>* output) {
    // Create reorders between user layout and MKL layout if needed, and
    // add them to the net before the convolution.
    std::vector<primitive> net;
    input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net);
    obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);

    // Memory for the output of the convolution. Since we may need a reorder
    // on the output side, we prepare a reorder primitive in case a reorder
    // to the user memory is required.
    bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
        conv_pd.diff_weights_primitive_desc());

    net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(),
                                               obp->GetOpMem(),
                                               output->GetOpMem()));

    // Insert a reorder primitive into the net for the output if required.
    if (output_reorder_required) {
      output->InsertReorderToUserMem(&net);
    }

    // Execute the net (including any output reorder).
    stream(stream::kind::eager).submit(net).wait();
  }
};
#endif

#define REGISTER_MKL_FILTER_KERNELS(T)                     \
  REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
                              .Device(DEVICE_CPU)          \
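One detail worth calling out: the kernel builds a forward convolution primitive descriptor even though it only runs the backward pass, because in the MKL-DNN v0.x C++ API the backward primitive descriptor takes the forward one as a hint. A minimal standalone sketch of just that pattern (the concrete 4x4/3x3 shapes are illustrative, not from the commit):

#include "mkldnn.hpp"  // MKL-DNN v0.x C++ API, as used by this commit
using namespace mkldnn;

int main() {
  auto cpu_engine = engine(engine::cpu, 0);

  // NCHW 1x1x4x4 source, OIHW 1x1x3x3 filter, 1x1x2x2 output;
  // stride 1, no padding. Formats left as 'any' so the library chooses.
  memory::desc src_md({1, 1, 4, 4}, memory::data_type::f32,
                      memory::format::any);
  memory::desc filt_md({1, 1, 3, 3}, memory::data_type::f32,
                       memory::format::any);
  memory::desc out_md({1, 1, 2, 2}, memory::data_type::f32,
                      memory::format::any);
  memory::dims strides = {1, 1};
  memory::dims pad = {0, 0};

  // Forward descriptor, needed below only as a hint.
  auto fwd_desc = convolution_forward::desc(
      prop_kind::forward, convolution_direct, src_md, filt_md, out_md,
      strides, pad, pad, padding_kind::zero);
  auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);

  // The backward-weights primitive_desc takes fwd_pd as its hint.
  auto bwd_desc = convolution_backward_weights::desc(
      convolution_direct, src_md, filt_md, out_md, strides, pad, pad,
      padding_kind::zero);
  auto bwd_pd = convolution_backward_weights::primitive_desc(
      bwd_desc, cpu_engine, fwd_pd);
  (void)bwd_pd;
  return 0;
}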
@@ -30,6 +30,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_slice.h"
 #include "tensorflow/core/kernels/conv_grad_ops.h"
+#include "tensorflow/core/kernels/mkl_conv_ops.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -43,10 +44,23 @@ limitations under the License.
 #include "mkl_dnn.h"
 #include "mkl_dnn_types.h"

+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::stream;
+using mkldnn::prop_kind;
+
+using mkldnn::convolution_forward;
+using mkldnn::convolution_direct;
+using mkldnn::convolution_backward_data;
+#endif
+
 namespace tensorflow {

 typedef Eigen::ThreadPoolDevice CPUDevice;

+#ifndef INTEL_MKL_DNN
+
 template <typename Device, class T>
 class MklConv2DCustomBackpropInputOp : public OpKernel {
  public:
@@ -345,6 +359,180 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
  TensorFormat data_format;
};

#else

template <typename Device, class T>
class MklConv2DCustomBackpropInputOp : public OpKernel {
 public:
  ~MklConv2DCustomBackpropInputOp() {}
  explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
      : OpKernel(context) {
    string data_format_str;
    OP_REQUIRES_OK(context,
                   context->GetAttr("data_format", &data_format_str));
    OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    int stride_n = GetTensorDim(strides_, data_format_, 'N');
    int stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, (stride_n == 1 && stride_c == 1),
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));

    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);

      MklDnnData<T> filter(&cpu_engine);
      MklDnnData<T> outbackprop(&cpu_engine);
      MklDnnData<T> output(&cpu_engine);

      // Input tensors
      const Tensor& input_tensor = MklGetInput(context, 0);
      const Tensor& filter_tensor = MklGetInput(context, 1);
      const Tensor& obp_tensor = MklGetInput(context, 2);  // Outbackprop

      // Generate input shape.
      TensorShape input_shape;
      OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
                  errors::InvalidArgument(
                      "Conv2DBackpropInput: input_sizes input must be 1-dim, "
                      "not ", input_tensor.dims()));
      OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
                                  input_tensor.vec<int32>(), &input_shape));
      TensorShape filter_shape = filter_tensor.shape();
      TensorShape obp_shape = obp_tensor.shape();

      // By default, all dims are in MKL order. Only dims in TF order
      // are those with the prefix tf_order.
      memory::dims obp_dims, fwd_input_dims, fwd_filter_dims;
      memory::dims padding_l, padding_r, strides, fwd_output_dims;
      memory::dims fwd_output_dims_tf_order;

      // Get forward convolution parameters.
      MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
      conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape,
                                         &fwd_input_dims, &fwd_filter_dims,
                                         &strides,
                                         &fwd_output_dims_tf_order,
                                         &fwd_output_dims,
                                         &padding_l, &padding_r);
      if (!context->status().ok()) return;

      // Create the Convolution forward descriptor since the Convolution
      // backward API needs it. For that, we first need to create input,
      // filter and output memory descriptors.
      auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
      auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(),
                                     mkl_data_format);
      auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(),
                                        memory::format::hwio);
      auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(),
                                     mkl_data_format);
      auto fwd_desc = convolution_forward::desc(prop_kind::forward,
          convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md,
          strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
      auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);

      // Allocate output tensor and shape.
      // TODO(nhasabni): Update this when support for MKL layout is added.
      // Shape of the output of Conv2DBackpropInput is same as the 'input'
      // of Conv2D.
      TensorShape tf_output_shape(input_shape);
      MklShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(false);
      Tensor* output_tensor = nullptr;
      AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
                                mkl_output_mkl_shape);

      // Create memory for user data.
      // Describe how the inputs and outputs of Convolution look like. Also
      // specify the buffers containing the actual input and output data.
      // Although the input shape required is in MKL-DNN order, the layout is
      // Tensorflow's layout (NHWC or NCHW depending on data format).
      // Although the filter shape (filter_dims) required is in MKL-DNN
      // order, the layout is Tensorflow's layout (HWIO). The shape of
      // Conv2DBackpropInput's filter is same as that of the Conv2D filter.
      filter.SetUsrMem(fwd_filter_dims, memory::format::hwio, &filter_tensor);
      // Outbackprop shape is NHWC or NCHW depending on data format. Since
      // the GetInputSizeInMklOrder function returns sizes in that order, we
      // use that function directly.
      conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims);
      if (!context->status().ok()) return;
      outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor);
      // Although the output shape required is in MKL-DNN order, the layout
      // is Tensorflow's layout (NHWC or NCHW depending on data format).
      // Shape of the output of Conv2DBackpropInput is same as the shape of
      // the 'input' of Conv2D.
      memory::dims bwd_output_dims = fwd_input_dims;
      output.SetUsrMem(bwd_output_dims, mkl_data_format, output_tensor);

      // Create memory descriptors for convolution data w/ no specified
      // format.
      filter.SetOpMemDesc(fwd_filter_dims, memory::format::any);
      outbackprop.SetOpMemDesc(obp_dims, memory::format::any);
      output.SetOpMemDesc(bwd_output_dims, memory::format::any);

      // Create convolution backward data primitive.
      auto bwd_desc = convolution_backward_data::desc(convolution_direct,
          output.GetOpMemDesc(), filter.GetOpMemDesc(),
          outbackprop.GetOpMemDesc(), strides, padding_l,
          padding_r, TFPaddingToMklDnnPadding(padding_));

      auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc,
                                                              cpu_engine,
                                                              fwd_pd);

      PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output);
    } catch (mkldnn::error &e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + string(e.message) +
                         ", in file " + string(__FILE__) + ":" +
                         std::to_string(__LINE__);
      OP_REQUIRES_OK(context,
                     errors::Aborted("Operation received an exception:",
                                     error_msg));
    }
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  // Prepare and execute net - checks for input and output reorders.
  void PrepareAndExecutePrimitive(
      const convolution_backward_data::primitive_desc& conv_pd,
      MklDnnData<T>* filter, MklDnnData<T>* obp, MklDnnData<T>* output) {
    // Create reorders between user layout and MKL layout if needed, and
    // add them to the net before the convolution.
    std::vector<primitive> net;
    filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net);
    obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);

    // Memory for the output of the convolution. Since we may need a reorder
    // on the output side, we prepare a reorder primitive in case a reorder
    // to the user memory is required.
    bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
        conv_pd.diff_src_primitive_desc());

    net.push_back(convolution_backward_data(conv_pd, obp->GetOpMem(),
                                            filter->GetOpMem(),
                                            output->GetOpMem()));

    // Insert a reorder primitive into the net for the output if required.
    if (output_reorder_required) {
      output->InsertReorderToUserMem(&net);
    }

    // Execute the net (including any output reorder).
    stream(stream::kind::eager).submit(net).wait();
  }
};

#endif  // INTEL_MKL_DNN

#define REGISTER_MKL_CPU_KERNELS(T)                       \
  REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
                              .Device(DEVICE_CPU)         \
@@ -19,6 +19,8 @@ limitations under the License.
 #include <string.h>
 #include <map>
 #include <vector>
 #include <string>

 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -26,6 +28,7 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/tensor_slice.h"
 #include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/kernels/mkl_conv_ops.h"
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
@@ -40,10 +43,23 @@ limitations under the License.
 #include "mkl_dnn.h"
 #include "mkl_dnn_types.h"

+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+
+using mkldnn::stream;
+using mkldnn::prop_kind;
+
+using mkldnn::convolution_forward;
+using mkldnn::convolution_direct;
+#endif
+
 namespace tensorflow {

 typedef Eigen::ThreadPoolDevice CPUDevice;

+// For now, MKL-ML is the default, so MKL-DNN is not the default choice.
+#ifndef INTEL_MKL_DNN
+
 template <typename Device, typename T, bool biasEnabled>
 class MklConv2DOp : public OpKernel {
  public:
@@ -461,6 +477,205 @@ class MklConv2DOp : public OpKernel {
  TensorFormat data_format_;
};

#else

template <typename Device, typename T, bool biasEnabled>
class MklConv2DOp : public OpKernel {
 public:
  ~MklConv2DOp() {}

  explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
    OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
    string data_format;
    OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
    OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
                errors::InvalidArgument("Invalid data format"));
    OP_REQUIRES(context, strides_.size() == 4,
                errors::InvalidArgument("Sliding window strides field must "
                                        "specify 4 dimensions"));

    const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
    const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
    OP_REQUIRES(
        context, stride_n == 1 && stride_c == 1,
        errors::InvalidArgument("Current implementation does not yet support "
                                "strides in the batch and depth dimensions."));
    OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
  }

  void Compute(OpKernelContext* context) override {
    try {
      auto cpu_engine = engine(engine::cpu, 0);

      // Input tensors
      size_t src_idx = 0, filter_idx = 1;
      const Tensor& src_tensor = MklGetInput(context, src_idx);
      const Tensor& filter_tensor = MklGetInput(context, filter_idx);

      MklDnnData<T> src(&cpu_engine);
      MklDnnData<T> filter(&cpu_engine);
      MklDnnData<T> output(&cpu_engine);

      memory::dims src_dims, filter_dims, padding_l, padding_r, strides;
      memory::dims output_dims_tf_order, output_dims_mkl_order;

      // Get shapes of input tensors in MKL-DNN order.
      MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
      conv_utl.GetConvFwdSizesInMklOrder(src_tensor.shape(),
                                         filter_tensor.shape(),
                                         &src_dims, &filter_dims, &strides,
                                         &output_dims_tf_order,
                                         &output_dims_mkl_order, &padding_l,
                                         &padding_r);
      if (!context->status().ok()) return;

      // Check for corner case - if there is nothing to compute, return.
      TensorShape tf_output_shape({output_dims_tf_order[0],
                                   output_dims_tf_order[1],
                                   output_dims_tf_order[2],
                                   output_dims_tf_order[3]});
      Tensor* output_tensor = nullptr;
      MklShape mkl_output_mkl_shape;
      mkl_output_mkl_shape.SetMklTensor(false);
      AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
                                mkl_output_mkl_shape);

      // Forward the filter in TF format from input at index 1 to output at
      // index 1.
      ForwardTfTensorInToOut(context, 1, 1);

      if (tf_output_shape.num_elements() == 0) {
        // TODO(jbobba): Verify correctness here
        // Need semantics for Null MKL tensor
        return;
      }

      // Corner case to handle 0 batch size.
      if (output_dims_tf_order[0] == 0) {
        // Nothing to do: allocate the output tensor and return.
        // TODO(nhasabni): remove this code later once serialization
        // in MKL-DNN is supported.
        AllocateOutputSetMklShape(context, 0, &output_tensor,
                                  src_tensor.shape(), mkl_output_mkl_shape);
        return;
      }
      CHECK_NOTNULL(output_tensor);

      // Create memory for user data.
      // Describe how the inputs and outputs of Convolution look like. Also
      // specify the buffers containing the actual input and output data.
      // Although the input shape (src_dims) required is in MKL-DNN order,
      // the layout is Tensorflow's layout (NHWC or NCHW depending on data
      // format).
      src.SetUsrMem(src_dims, TFDataFormatToMklDnnDataFormat(data_format_),
                    const_cast<void*>(static_cast<const void*>(
                        src_tensor.flat<T>().data())));
      // Although the filter shape (filter_dims) required is in MKL-DNN
      // order, the layout is Tensorflow's layout (HWIO).
      filter.SetUsrMem(filter_dims, memory::format::hwio,
                       const_cast<void*>(static_cast<const void*>(
                           filter_tensor.flat<T>().data())));
      // Although the output shape (output_dims) required is in MKL-DNN
      // order, the layout is Tensorflow's layout (NHWC or NCHW depending on
      // data format).
      output.SetUsrMem(output_dims_mkl_order,
                       TFDataFormatToMklDnnDataFormat(data_format_),
                       output_tensor->flat<T>().data());

      // Create memory descriptors for convolution data w/ no specified
      // format.
      src.SetOpMemDesc(src_dims, memory::format::any);
      filter.SetOpMemDesc(filter_dims, memory::format::any);
      output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);

      // If bias is enabled, then do the same steps as above for bias.
      if (biasEnabled) {
        MklDnnData<T> bias(&cpu_engine);
        memory::dims bias_size;
        conv_utl.GetBiasSizeInMklOrder(2 /* bias idx */, &bias_size);
        const Tensor& bias_tensor = MklGetInput(context, 2);
        bias.SetUsrMem(bias_size, memory::format::x,
                       const_cast<void*>(static_cast<const void*>(
                           bias_tensor.flat<T>().data())));
        bias.SetOpMemDesc(bias_size, memory::format::any);

        // Create convolution primitive with Bias.
        auto conv_desc = convolution_forward::desc(prop_kind::forward,
            convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(),
            bias.GetOpMemDesc(), output.GetOpMemDesc(), strides,
            padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));

        auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc,
                                                                  cpu_engine);
        PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output);
      } else {
        // Create convolution primitive without Bias.
        auto conv_desc = convolution_forward::desc(prop_kind::forward,
            convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(),
            output.GetOpMemDesc(), strides, padding_l, padding_r,
            TFPaddingToMklDnnPadding(padding_));

        auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc,
                                                                  cpu_engine);
        PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output);
      }
    } catch (mkldnn::error &e) {
      string error_msg = "Status: " + std::to_string(e.status) +
                         ", message: " + std::string(e.message) +
                         ", in file " + std::string(__FILE__) + ":" +
                         std::to_string(__LINE__);
      OP_REQUIRES_OK(context,
                     errors::Aborted("Operation received an exception:",
                                     error_msg));
    }
  }

 private:
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

  // Prepare and execute net - checks for input and output reorders.
  void PrepareAndExecuteNet(
      const convolution_forward::primitive_desc& conv_prim_desc,
      MklDnnData<T>* src, MklDnnData<T>* filter,
      MklDnnData<T>* bias, MklDnnData<T>* output) {
    // Create reorders between user layout and MKL layout if needed, and
    // add them to the net before the convolution.
    std::vector<primitive> net;
    src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc(), &net);
    filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(),
                                &net);

    // Memory for the output of the convolution. Since we may need a reorder
    // on the output side, we prepare a reorder primitive in case a reorder
    // to the user memory is required.
    bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
        conv_prim_desc.dst_primitive_desc());

    // Create the convolution primitive and add it to the net.
    if (bias) {
      CHECK_EQ(biasEnabled, true);
      net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
                                        filter->GetOpMem(), bias->GetOpMem(),
                                        output->GetOpMem()));
    } else {
      CHECK_EQ(biasEnabled, false);
      net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
                                        filter->GetOpMem(),
                                        output->GetOpMem()));
    }

    // Insert a reorder primitive into the net for the output if required.
    if (output_reorder_required) {
      output->InsertReorderToUserMem(&net);
    }

    // Execute the net (including any output reorder).
    stream(stream::kind::eager).submit(net).wait();
  }
};

#endif

#define REGISTER_MKL_CPU(T)                  \
  REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
                              .Device(DEVICE_CPU) \
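The PrepareAndExecuteNet pattern above (reorder inputs if the primitive prefers a different layout, queue the compute primitive, reorder the output back, then submit everything once) can be reduced to a standalone sketch. Assuming the MKL-DNN v0.x API, with a bare reorder standing in for MklDnnData's helpers:

#include <vector>
#include "mkldnn.hpp"  // MKL-DNN v0.x C++ API
using namespace mkldnn;

// Minimal sketch of the reorder-then-execute pipeline: all primitives are
// queued in a net and submitted to an eager stream exactly once.
int main() {
  auto cpu_engine = engine(engine::cpu, 0);
  float in[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  float out[8] = {0};

  // User data is NCHW; pretend the op wants NHWC, so a reorder is queued.
  memory::desc md_nchw({1, 2, 2, 2}, memory::data_type::f32,
                       memory::format::nchw);
  memory::desc md_nhwc({1, 2, 2, 2}, memory::data_type::f32,
                       memory::format::nhwc);
  auto user_mem = memory(memory::primitive_desc(md_nchw, cpu_engine), in);
  auto op_mem = memory(memory::primitive_desc(md_nhwc, cpu_engine), out);

  // Queue primitives, then submit the whole net -- exactly what
  // stream(stream::kind::eager).submit(net).wait() does in the kernels.
  std::vector<primitive> net;
  net.push_back(reorder(user_mem, op_mem));
  stream(stream::kind::eager).submit(net).wait();
  return 0;
}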
tensorflow/core/kernels/mkl_conv_ops.h (new file, 316 lines)

@@ -0,0 +1,316 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
#define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_

#include <vector>
#include <limits>

#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"

#include "tensorflow/core/util/mkl_util.h"

#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
#endif

namespace tensorflow {

#ifdef INTEL_MKL_DNN

class MklDnnConvUtil {
 protected:
  OpKernelContext* context_;  // We don't own this.
  std::vector<int32> strides_;
  Padding padding_;
  TensorFormat data_format_;

 public:
  MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
                 Padding pad, TensorFormat fm)
      : context_(context), strides_(strides), padding_(pad),
        data_format_(fm) {}

  virtual ~MklDnnConvUtil() { context_ = nullptr; }

  // Calculate Convolution strides.
  virtual inline void GetStridesInMklOrder(memory::dims* strides) {
    // For now we take the stride from the second and third dimensions only
    // (we do not support striding on the batch or depth dimension).
    CHECK_NOTNULL(strides);
    int stride_rows = GetTensorDim(strides_, data_format_, 'H');
    int stride_cols = GetTensorDim(strides_, data_format_, 'W');
    *strides = {stride_rows, stride_cols};
  }

  // Calculate the Convolution input size in MKL-DNN order. MKL-DNN
  // requires the input in NCHW format. The function does not return
  // anything, but errors arising from sanity checks are returned in the
  // context's status.
  virtual inline void GetInputSizeInMklOrder(const TensorShape& input_shape,
                                             memory::dims* input_dims) {
#define CHECK_BOUNDS(val, err_msg)                                     \
  do {                                                                 \
    OP_REQUIRES(context_,                                              \
                FastBoundsCheck(val, std::numeric_limits<int>::max()), \
                errors::InvalidArgument(err_msg));                     \
  } while (0)

    CHECK_NOTNULL(input_dims);

    // Input channel
    int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
    int input_depth = static_cast<int>(input_depth_raw);

    // Input rows/height
    int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
    CHECK_BOUNDS(input_rows_raw, "Input rows too large");
    int input_rows = static_cast<int>(input_rows_raw);

    // Input columns/width
    int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
    CHECK_BOUNDS(input_cols_raw, "Input cols too large");
    int input_cols = static_cast<int>(input_cols_raw);

    // Input batch
    int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
    CHECK_BOUNDS(input_batch_raw, "Input batch too large");
    int input_batch = static_cast<int>(input_batch_raw);

#undef CHECK_BOUNDS

    // MKL-DNN always requires input in NCHW format.
    *input_dims = {input_batch, input_depth, input_rows, input_cols};
  }

  // Calculate the Convolution filter size in MKL-DNN order. MKL-DNN
  // requires the filter in OIHW format. The function does not return
  // anything, but errors arising from sanity checks are returned in the
  // context's status.
  //
  // TODO(nhasabni): Add a similar function for input and filter in MklShape.
  virtual inline void GetFilterSizeInMklOrder(const TensorShape& input_shape,
                                              const TensorShape& filter_shape,
                                              memory::dims* filter_dims) {
    CHECK_NOTNULL(filter_dims);

    OP_REQUIRES(context_, filter_shape.dims() == 4,
                errors::InvalidArgument("filter must be 4-dimensional: ",
                                        filter_shape.DebugString()));

    for (int i = 0; i < 3; i++) {
      OP_REQUIRES(context_,
                  FastBoundsCheck(filter_shape.dim_size(i),
                                  std::numeric_limits<int>::max()),
                  errors::InvalidArgument("filter too large"));
    }

    int input_depth = GetTensorDim(input_shape, data_format_, 'C');

    OP_REQUIRES(
        context_, input_depth == filter_shape.dim_size(2),
        errors::InvalidArgument("input and filter must have the same depth: ",
                                input_depth, " vs ",
                                filter_shape.dim_size(2)));

    // TF filter is always in (rows, cols, in_depth, out_depth) order.
    int filter_rows = static_cast<int>(filter_shape.dim_size(0));
    int filter_cols = static_cast<int>(filter_shape.dim_size(1));
    int in_depth = static_cast<int>(filter_shape.dim_size(2));
    int out_depth = static_cast<int>(filter_shape.dim_size(3));

    // MKL-DNN always needs the filter in OIHW format:
    // OIHW = (out_depth, in_depth, rows, cols).
    *filter_dims = {out_depth, in_depth, filter_rows, filter_cols};
  }

  // Calculate the Convolution filter size in MKL-DNN order, reading the
  // tensors at the given input indices. This overload differs from the one
  // above in its input parameter: it accepts indices rather than shapes,
  // since Convolution Backward Input gets the shape of the input tensor
  // rather than the actual tensor (Convolution forward gets the actual
  // tensor as input). Errors are returned in the context's status.
  virtual inline void GetFilterSizeInMklOrder(size_t src_index,
                                              size_t filter_index,
                                              memory::dims* filter_dims) {
    CHECK_NOTNULL(filter_dims);
    const Tensor& input = MklGetInput(context_, src_index);
    const Tensor& filter = MklGetInput(context_, filter_index);
    GetFilterSizeInMklOrder(input.shape(), filter.shape(), filter_dims);
  }

  // Calculate the Bias size for 2D Convolution. The function does not
  // return anything, but sets errors in the context's status.
  virtual inline void GetBiasSizeInMklOrder(size_t bias_index,
                                            memory::dims* bias_dims) {
    const Tensor& bias = MklGetInput(context_, bias_index);
    OP_REQUIRES(context_, bias.dims() == 1,
                errors::InvalidArgument("bias must be 1-dimensional: ",
                                        bias.shape().DebugString()));

    *bias_dims = {static_cast<int>(bias.dim_size(0))};
  }

  // Function to calculate output and padding size for 2D convolution.
  //
  // Calculates the output shape of the Convolution in MKL-DNN and
  // TensorFlow order. MKL-DNN uses NCHW for the output order, while the
  // TensorFlow output will be in NHWC or NCHW format depending on the data
  // format. The function also calculates the left, right, top and bottom
  // pads. It does not return a status; errors are returned via the
  // context's status.
  //
  // TODO(nhasabni): Add a similar function for input and filter in MklShape.
  virtual inline void GetOutputAndPadSizeInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      const memory::dims& strides, memory::dims* output_dims_tf_order,
      memory::dims* output_dims_mkl_order, memory::dims* pad_l,
      memory::dims* pad_r) {
    CHECK_NOTNULL(output_dims_tf_order);
    CHECK_NOTNULL(output_dims_mkl_order);
    CHECK_NOTNULL(pad_l);
    CHECK_NOTNULL(pad_r);

    int input_rows = GetTensorDim(input_shape, data_format_, 'H');
    int input_cols = GetTensorDim(input_shape, data_format_, 'W');

    // The first dimension of the filter is rows/height.
    int filter_rows = filter_shape.dim_size(0);
    // The second dimension of the filter is cols/width.
    int filter_cols = filter_shape.dim_size(1);

    // Stride is a vector of 2 elements: {s_r, s_c}.
    int stride_rows = strides[0];
    int stride_cols = strides[1];

    // Output batch is same as input batch.
    int out_batch = GetTensorDim(input_shape, data_format_, 'N');
    // Output depth is same as the last dimension of the filter.
    int out_depth = filter_shape.dim_size(3);

    int64 out_rows = 0, out_cols = 0;
    int64 pad_top = 0, pad_bottom = 0, pad_left = 0, pad_right = 0;

    OP_REQUIRES_OK(context_,
                   GetWindowedOutputSizeVerbose(input_rows, filter_rows,
                                                stride_rows, padding_,
                                                &out_rows, &pad_top,
                                                &pad_bottom));
    OP_REQUIRES_OK(context_,
                   GetWindowedOutputSizeVerbose(input_cols, filter_cols,
                                                stride_cols, padding_,
                                                &out_cols, &pad_left,
                                                &pad_right));

    // The Tensorflow output is in data_format order (NHWC or NCHW).
    TensorShape out_shape = ShapeFromFormat(data_format_, out_batch,
                                            out_rows, out_cols, out_depth);
    *output_dims_tf_order = TFShapeToMklDnnDims(out_shape);

    // MKL-DNN always needs the output in NCHW format.
    *output_dims_mkl_order = {out_batch, out_depth,
                              static_cast<int>(out_rows),
                              static_cast<int>(out_cols)};

    // Now handle padding. MKL-DNN uses asymmetric padding.
    *pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
    *pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
  }

  // Calculate the output and pad size of the forward Convolution operator.
  // See the comment on GetOutputAndPadSizeInMklOrder above for the
  // parameters.
  //
  // The function does not return anything, but sets errors in the context's
  // status.
  inline void GetOutputAndPadSizeInMklOrder(
      size_t src_index, size_t filter_index, const memory::dims& strides,
      memory::dims* output_dims_tf_order, memory::dims* output_dims_mkl_order,
      memory::dims* pad_l, memory::dims* pad_r) {
    CHECK_NOTNULL(output_dims_tf_order);
    CHECK_NOTNULL(output_dims_mkl_order);
    CHECK_NOTNULL(pad_l);
    CHECK_NOTNULL(pad_r);

    const Tensor& input = MklGetInput(context_, src_index);
    const Tensor& filter = MklGetInput(context_, filter_index);

    OP_REQUIRES(context_, input.dims() == 4,
                errors::InvalidArgument("input must be 4-dimensional",
                                        input.shape().DebugString()));

    GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(), strides,
                                  output_dims_tf_order,
                                  output_dims_mkl_order, pad_l, pad_r);
  }

  // Wrapper function to calculate the input, filter, and output sizes of a
  // 2D Convolution in MKL order (NCHW for input and output; OIHW for the
  // filter). It also calculates the output shape in Tensorflow order, as
  // well as the strides and paddings for the 2D Convolution.
  //
  // The function does not return anything, but sets errors in the context's
  // status.
  inline void GetConvFwdSizesInMklOrder(
      const TensorShape& input_shape, const TensorShape& filter_shape,
      memory::dims* input_dims, memory::dims* filter_dims,
      memory::dims* strides, memory::dims* output_dims_tf_order,
      memory::dims* output_dims_mkl_order, memory::dims* pad_l,
      memory::dims* pad_r) {
    CHECK_NOTNULL(input_dims);
    CHECK_NOTNULL(filter_dims);
    CHECK_NOTNULL(strides);
    CHECK_NOTNULL(output_dims_tf_order);
    CHECK_NOTNULL(output_dims_mkl_order);
    CHECK_NOTNULL(pad_l);
    CHECK_NOTNULL(pad_r);

    GetInputSizeInMklOrder(input_shape, input_dims);
    if (!context_->status().ok()) return;
    GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims);
    if (!context_->status().ok()) return;
    GetStridesInMklOrder(strides);
    GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides,
                                  output_dims_tf_order,
                                  output_dims_mkl_order, pad_l, pad_r);
    if (!context_->status().ok()) return;
  }
};

#endif  // INTEL_MKL_DNN

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
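To make the padding math in GetOutputAndPadSizeInMklOrder concrete, here is a standalone sketch of the SAME-padding arithmetic that GetWindowedOutputSizeVerbose performs (this mirrors TensorFlow's documented behavior; the function name SameOutputAndPad and the driver are illustrative, and the real function also handles VALID padding and reports errors via Status):

#include <algorithm>
#include <cassert>
#include <cstdint>

// SAME padding: output = ceil(in / stride); the total padding is whatever
// is needed so the window positions cover the input, split top/bottom
// (or left/right) with the extra pixel going to the trailing side.
void SameOutputAndPad(int64_t in, int64_t filter, int64_t stride,
                      int64_t* out, int64_t* pad_before, int64_t* pad_after) {
  *out = (in + stride - 1) / stride;  // ceil(in / stride)
  int64_t pad_needed =
      std::max<int64_t>(0, (*out - 1) * stride + filter - in);
  *pad_before = pad_needed / 2;           // feeds MKL-DNN's pad_l component
  *pad_after = pad_needed - *pad_before;  // feeds MKL-DNN's pad_r component
}

int main() {
  // 7-wide input, 3-wide filter, stride 2, SAME -> 4-wide output, pads (1,1).
  int64_t out = 0, pb = 0, pa = 0;
  SameOutputAndPad(7, 3, 2, &out, &pb, &pa);
  assert(out == 4 && pb == 1 && pa == 1);
  return 0;
}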
@@ -26,13 +26,19 @@ limitations under the License.
 #include "mkl_trans.h"
 #include "tensorflow/core/framework/tensor.h"
 #include "tensorflow/core/framework/tensor_shape.h"
-#include "tensorflow/core/util/tensor_format.h"

 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/macros.h"
 #include "tensorflow/core/util/padding.h"
+#include "tensorflow/core/util/tensor_format.h"
+#include "tensorflow/core/graph/mkl_graph_util.h"

+#ifdef INTEL_MKL_DNN
+#include "mkldnn.hpp"
+#endif

 // The file contains a number of utility classes and functions used by MKL
 // enabled kernels
@ -219,19 +225,19 @@ class MklShape {
// Location from start of buffer where isMklTensor_ is serialized
#define DIMS_OFFSET \
  (IS_MKL_TENSOR_OFFSET + sizeof(size_t))  // Location of dimension_
// Location of sizes. Note dim is not used here; it is left here
// to keep the macros consistent.
#define SIZES_OFFSET(dims) \
  (DIMS_OFFSET + \
   sizeof(size_t))  // Location of sizes. Note dim is not used here, left here
                    // to make macros consistent.
  (DIMS_OFFSET + sizeof(size_t))
#define STRIDES_OFFSET(dims) \
  (SIZES_OFFSET(dims) + dims * sizeof(size_t))  // Location of strides
#define MKL_LAYOUT_OFFSET(dims) \
  (STRIDES_OFFSET(dims) + dims * sizeof(size_t))  // Location of mklLayout_
#define TF_LAYOUT_OFFSET(dims) \
  (MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)  // Location of tfLayout_
// Location of tf_to_mkl_dim_map_
#define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
  (TF_LAYOUT_OFFSET(dims) + \
   SIZE_OF_MKL_DNN_BUF)  // Location of tf_to_mkl_dim_map_
  (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
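Worked offsets (illustrative arithmetic, not from the patch), assuming an
8-byte size_t, dims = 2, and IS_MKL_TENSOR_OFFSET = 0 (its definition is
outside this hunk):

    // DIMS_OFFSET          = 0 + 8       = 8
    // SIZES_OFFSET(2)      = 8 + 8       = 16
    // STRIDES_OFFSET(2)    = 16 + 2 * 8  = 32
    // MKL_LAYOUT_OFFSET(2) = 32 + 2 * 8  = 48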

// TODO(agramesh1) make sure to create a const to share with rewrite pass
// for min size of MKL metadata tensor.
@ -342,58 +348,6 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
  return output_tensor;
}

// Since our ops are going to produce and also consume N additional (Mkl)
// tensors for N TensorFlow tensors, we can have the following different
// orderings among these 2N tensors.
//
// E.g., for TensorFlow tensors A, B, and C, our ops will additionally
// produce and consume A_m, B_m, and C_m.
//
// INTERLEAVED: in this case the 2N tensors are interleaved. For the above
//              example, the ordering looks like: A, A_m, B, B_m, C, C_m.
//
// CONTIGUOUS: in this case the N TensorFlow tensors are contiguous, followed
//             by the N Mkl tensors. For the above example, the ordering
//             looks like: A, B, C, A_m, B_m, C_m.
//
// The following APIs map the index of an original TensorFlow tensor to its
// appropriate position based on the selected ordering. For contiguous
// ordering, we need to know the total number of tensors (parameter total).
//
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
// NOTE: Currently, we use contiguous ordering. If you change this, then you
// would need to change Mkl op definitions in nn_ops.cc.
static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;

// Get index of MetaData tensor from index 'n' of Data tensor.
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    // For interleaved ordering, the Mkl tensor follows immediately after
    // the TensorFlow tensor.
    return n + 1;
  } else {
    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
    // For contiguous ordering, the Mkl tensor is total_tensors / 2 away.
    return n + total_tensors / 2;
  }
}

int inline GetTensorDataIndex(int n, int total_tensors) {
  if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
    return 2 * n;  // index corresponding to nth input/output tensor
  } else {
    CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
    return n;
  }
}

int inline GetTensorMetaDataIndex(int n, int total_tensors) {
  // Get the index for TensorData first and then use the mapping function
  // to get the TensorMetaData index from the TensorData index.
  int tidx = GetTensorDataIndex(n, total_tensors);
  return DataIndexToMetaDataIndex(tidx, total_tensors);
}
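A worked example of the index math (illustrative, not from the patch): with
contiguous ordering and three TensorFlow tensors A, B, C plus their metadata
tensors A_m, B_m, C_m, total_tensors is 6 and the layout is
A, B, C, A_m, B_m, C_m.

    int data_idx = GetTensorDataIndex(1, 6);      // B   -> index 1
    int meta_idx = GetTensorMetaDataIndex(1, 6);  // B_m -> 1 + 6 / 2 = 4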

// Get the MKL shape from the second string tensor
inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
  mklshape->DeSerializeMklShape(
@ -480,6 +434,13 @@ inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
  *buf_out = static_cast<void*>(tensor_out->flat<float>().data());
}

template <typename T>
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
                           TensorShape tf_shape) {
  OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
                                                 tf_shape, tensor_out));
}
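A small usage sketch of the templated overload (assumed op context, not part
of the patch): allocating a temporary float tensor of a given TensorFlow
shape inside an op's Compute().

    Tensor scratch;
    AllocTmpBuffer<float>(context, &scratch, TensorShape({64, 128}));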

inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
                                const size_t* sizes) {
  // MKL requires strides in NCHW
@ -743,56 +704,294 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
  }
}

namespace mkl_op_registry {
static const char* kMklOpLabel = "MklOp";
static const char* kMklOpLabelPattern = "label='MklOp'";
// -------------------------------------------------------------------

// Get the name of the Mkl op from the original TensorFlow op.
// We prefix the original op name with 'Mkl' to get the Mkl op name.
inline string GetMklOpName(const string& name) {
  // Prefix that we add to the TensorFlow op name to construct the Mkl op
  // name.
  const char* const kMklOpPrefix = "_Mkl";
  return string(kMklOpPrefix) + name;
#ifdef INTEL_MKL_DNN

using mkldnn::memory;
using mkldnn::reorder;
using mkldnn::primitive;
using mkldnn::padding_kind;
using mkldnn::engine;

/// Return the MKL-DNN data type (memory::data_type) for input type T
///
/// @input None
/// @return memory::data_type corresponding to type T
template <typename T>
static memory::data_type MklDnnType();

/// Instantiation for the float type. Add similar instantiations for other
/// types if needed.
template <>
memory::data_type MklDnnType<float>() {
  return memory::data_type::f32;
}
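Illustrative only (not in the patch): an instantiation for an integer path
would map TensorFlow's int32 to MKL-DNN's s32 type in the same style.

    template <>
    memory::data_type MklDnnType<int32>() {
      return memory::data_type::s32;
    }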

// Check whether op_name with type T is registered as MKL-compliant.
//
// @input: name of the op
// @input: T - datatype to be used for checking the op
// @return: true if op_name is registered as an Mkl op; false otherwise
static inline bool IsMklOp(const std::string& op_name, DataType T) {
  string kernel = KernelsRegisteredForOp(op_name);
  bool result =
      kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
  if (result) {
    VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
/// Map TensorFlow's data format into the MKL-DNN data format
///
/// @input: TensorFlow data format
/// @return: memory::format corresponding to the TensorFlow data format;
///          fails with an error if the data format is invalid
inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
  if (format == FORMAT_NHWC) return memory::format::nhwc;
  else if (format == FORMAT_NCHW) return memory::format::nchw;
  TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT,
                     "Unsupported data format"));
  // Return to get rid of compiler warning
  return memory::format::format_undef;
}

/// Map a TensorShape object into the memory::dims required by MKL-DNN
///
/// This function simply maps the input TensorShape into MKL-DNN dims
/// naively, so it preserves the order of dimensions. E.g., if the input
/// tensor is in NHWC format, then dims will be in NHWC format as well.
///
/// @input TensorShape object in shape
/// @return memory::dims corresponding to TensorShape
inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
  memory::dims dims(shape.dims());
  for (unsigned int d = 0; d < shape.dims(); ++d) {
    dims[d] = shape.dim_size(d);
  }
  return result;
  return dims;
}

// Check whether op_name with type T is registered as an MKL-compliant,
// element-wise op.
//
// @input: name of the op
// @input: T - datatype to be used for checking the op
// @return: true if op_name is registered as an element-wise Mkl op;
//          false otherwise
static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) {
  if (!IsMklOp(op_name, T)) {
/// Map a TensorShape object into memory::dims in the NCHW format required
/// by MKL-DNN
///
/// This function is more specific than the one above. It maps the input
/// TensorShape into MKL-DNN dims in NCHW format, so it may not preserve
/// the order of dimensions. E.g., if the input tensor is in NHWC format,
/// then dims will be in NCHW format, not in NHWC format.
///
/// @input TensorShape object in shape
/// @return memory::dims in the NCHW format required by MKL-DNN
inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
                                              TensorFormat format) {
  // Check validity of format.
  CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
           memory::format::format_undef);

  int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
  int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
  int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
  int w = shape.dim_size(GetTensorDimIndex(format, 'W'));

  // MKL-DNN requires dimensions in NCHW format.
  return memory::dims({n, c, h, w});
}
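A short example of the two shape helpers (illustrative, not from the patch):

    TensorShape shape({8, 224, 224, 3});  // NHWC: batch, height, width, chan.
    memory::dims naive = TFShapeToMklDnnDims(shape);
    // naive == {8, 224, 224, 3} -- order preserved
    memory::dims nchw = TFShapeToMklDnnDimsInNCHW(shape, FORMAT_NHWC);
    // nchw  == {8, 3, 224, 224} -- reordered to NCHW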

inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
  // MKL-DNN only supports zero padding.
  return padding_kind::zero;
}

/*
 * Class to represent all the resources corresponding to a tensor in
 * TensorFlow that are required to execute an operation (such as
 * Convolution).
 */
template <typename T>
class MklDnnData {
 private:
  /// MKL-DNN memory primitive for input user memory
  memory* user_memory_;

  /// MKL-DNN memory primitive in case an input or output reorder is needed.
  memory* reorder_memory_;

  /// Operation memory descriptor
  memory::desc* op_md_;

  /// CPU engine on which the operation will be executed
  const engine* cpu_engine_;

 public:
  explicit MklDnnData(const engine* e) : user_memory_(nullptr),
                                         reorder_memory_(nullptr),
                                         op_md_(nullptr), cpu_engine_(e) {}

  ~MklDnnData() {
    cpu_engine_ = nullptr;  // We don't own this.
    delete(user_memory_);
    delete(reorder_memory_);
    delete(op_md_);
  }

  void* GetTensorBuffer(const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    return const_cast<void*>(static_cast<const void*>(
        tensor->flat<T>().data()));
  }

  /// Set the user memory primitive using the specified dimensions, memory
  /// format, and data buffer. The function automatically uses the element
  /// data type from the type T the object was created with.
  ///
  /// In a nutshell, the function lets the user describe an input tensor to
  /// an operation. E.g., the filter of Conv2D has shape {1, 2, 3, 4} and
  /// memory format HWIO, and the buffer that contains the actual values is
  /// pointed to by data_buffer.
  void SetUsrMem(memory::dims dim, memory::format fm, void* data_buffer) {
    CHECK_NOTNULL(data_buffer);
    CHECK_NOTNULL(cpu_engine_);
    // TODO(nhasabni): can we remove dynamic memory allocation?
    user_memory_ = new memory(memory::primitive_desc(
                                  memory::desc(dim, MklDnnType<T>(), fm),
                                  *cpu_engine_), data_buffer);
  }

  void SetUsrMem(memory::dims dim, memory::format fm, const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    SetUsrMem(dim, fm, GetTensorBuffer(tensor));
  }

  /// A version of SetUsrMem that accepts a memory descriptor directly,
  /// instead of dimensions and format. This function is more generic than
  /// the one above, but the one above is sufficient in most cases.
  void SetUsrMem(memory::desc md, void* data_buffer) {
    CHECK_NOTNULL(data_buffer);
    CHECK_NOTNULL(cpu_engine_);
    // TODO(nhasabni): can we remove dynamic memory allocation?
    user_memory_ = new memory(memory::primitive_desc(md, *cpu_engine_),
                              data_buffer);
  }

  /// A version of SetUsrMem with memory descriptor and tensor
  void SetUsrMem(memory::desc md, const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    SetUsrMem(md, GetTensorBuffer(tensor));
  }

  /// A version of SetUsrMem that accepts a primitive descriptor directly,
  /// instead of dimensions and format. This function is more generic than
  /// the dimensions-and-format version, but that one is sufficient in most
  /// cases.
  void SetUsrMem(memory::primitive_desc pd, void* data_buffer) {
    CHECK_NOTNULL(data_buffer);
    CHECK_NOTNULL(cpu_engine_);
    // TODO(nhasabni): can we remove dynamic memory allocation?
    user_memory_ = new memory(pd, data_buffer);
  }

  /// A version of SetUsrMem with primitive descriptor and tensor
  void SetUsrMem(memory::primitive_desc pd, const Tensor* tensor) {
    CHECK_NOTNULL(tensor);
    SetUsrMem(pd, GetTensorBuffer(tensor));
  }

  /// Get function for the user memory primitive.
  const memory* GetUsrMem() const { return user_memory_; }

  /// Get function for the primitive descriptor of the user memory primitive.
  const memory::primitive_desc GetUsrMemPrimDesc() const {
    CHECK_NOTNULL(user_memory_);
    return user_memory_->get_primitive_desc();
  }

  /// Get function for the descriptor of user memory.
  memory::desc GetUsrMemDesc() {
    // This is ugly. Why does MKL-DNN not provide a desc() method of const
    // type?
    const memory::primitive_desc pd = GetUsrMemPrimDesc();
    return const_cast<memory::primitive_desc*>(&pd)->desc();
  }

  /// Get function for the data buffer of the user memory primitive.
  void* GetUsrMemDataHandle() const {
    CHECK_NOTNULL(user_memory_);
    return user_memory_->get_data_handle();
  }

  /// Get the memory primitive for the input and output of an op. If inputs
  /// to an op require reorders, then this function returns the memory
  /// primitive for the reorder. Otherwise, it returns the memory primitive
  /// for user memory.
  ///
  /// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
  /// execute Conv2D, we need memory primitives for I and F. But if a
  /// reorder is required for I and F (say I_r is the reorder primitive for
  /// I, and F_r is the reorder primitive for F), then we need I_r and F_r
  /// to perform Conv2D.
  const memory& GetOpMem() const {
    return reorder_memory_ ? *reorder_memory_ : *user_memory_;
  }

  /// Set the memory descriptor of an operation in terms of dimensions and
  /// memory format. E.g., for Conv2D, the dimensions would be the same as
  /// the user dimensions, but memory::format would be mkldnn::any because
  /// we want MKL-DNN to choose the best layout/format for the given input
  /// dimensions.
  void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
    // TODO(nhasabni): can we remove dynamic memory allocation?
    op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
  }

  /// Get function for the memory descriptor of an operation
  const memory::desc& GetOpMemDesc() const { return *op_md_; }

  /// Function to handle input reordering
  ///
  /// Check if we need to reorder this input of an operation.
  /// Return true and allocate a reorder memory primitive if a reorder is
  /// needed. Otherwise, return false and do not allocate one.
  ///
  /// To check if a reorder is needed, this function compares the memory
  /// primitive descriptor of the operation (op_pd) for the given input
  /// with the user-specified memory primitive descriptor.
  ///
  /// @input: op_pd - memory primitive descriptor of the given input of an
  ///                 operation
  /// @input: net - net to which to add the reorder primitive in case it is
  ///               needed
  /// @return: true if a reorder of the input is needed; false otherwise
  bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
                           std::vector<primitive>* net) {
    CHECK_NOTNULL(net);
    CHECK_NOTNULL(user_memory_);
    if (op_pd != user_memory_->get_primitive_desc()) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      reorder_memory_ = new memory(op_pd);
      net->push_back(reorder(*user_memory_, *reorder_memory_));
      return true;
    }
    return false;
  }

  bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
                 0 == op_name.compare(GetMklOpName("Sub")) ||
                 0 == op_name.compare(GetMklOpName("Mul")) ||
                 0 == op_name.compare(GetMklOpName("Maximum")) ||
                 0 == op_name.compare(GetMklOpName("SquaredDifference")));
  /// Function to handle output reordering
  ///
  /// This function performs functionality very similar to the input
  /// reordering function above. The only difference is that this function
  /// does not add the reorder primitive to the net. The reason is that the
  /// reorder primitive for the output needs to be added to the net only
  /// after the operation has executed. But we still need to prepare a
  /// temporary buffer in case an output reorder is needed; this temporary
  /// buffer holds the output of the operation before it is fed to the
  /// reorder primitive.
  ///
  /// @input: op_pd - memory primitive descriptor for the given output of an
  ///                 operation
  /// @return: true if a reorder of the output is needed; false otherwise
  bool PrepareReorderToUserMemIfReq(const memory::primitive_desc& op_pd) {
    CHECK_NOTNULL(user_memory_);
    if (op_pd != user_memory_->get_primitive_desc()) {
      // TODO(nhasabni): can we remove dynamic memory allocation?
      reorder_memory_ = new memory(op_pd);
      return true;
    }
    return false;
  }

  VLOG(1) << "mkl_op_registry::" << op_name
          << " is elementwise MKL op: " << result;
  return result;
}
  /// Function to actually insert a reorder primitive into the net
  ///
  /// This function completes the remaining part of output reordering. It
  /// inserts a reorder primitive from the temporary buffer that holds the
  /// output to the user-specified output buffer.
  ///
  /// @input: net - net to which to add the reorder primitive
  void InsertReorderToUserMem(std::vector<primitive>* net) {
    CHECK_NOTNULL(net);
    CHECK_NOTNULL(user_memory_);
    CHECK_NOTNULL(reorder_memory_);
    net->push_back(reorder(*reorder_memory_, *user_memory_));
  }
};
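A hedged end-to-end sketch of the reorder flow (not part of the patch;
op_pd, out_pd, and input are assumed to come from the surrounding op):

    std::vector<primitive> net;
    engine cpu_engine(engine::cpu, 0);
    MklDnnData<float> src(&cpu_engine);

    // Describe the user's tensor: dims plus the layout it is actually in.
    src.SetUsrMem(TFShapeToMklDnnDimsInNCHW(input.shape(), FORMAT_NHWC),
                  memory::format::nhwc, &input);

    // Queue an input reorder only if the op wants a different layout;
    // GetOpMem() then yields whichever memory the op should consume.
    src.CheckReorderToOpMem(op_pd, &net);
    const memory& op_input = src.GetOpMem();

    // For outputs: prepare a staging buffer before running the op, run the
    // op into dst.GetOpMem(), then insert the output reorder into the net.
    MklDnnData<float> dst(&cpu_engine);
    // dst.SetUsrMem(...);
    // bool needs_reorder = dst.PrepareReorderToUserMemIfReq(out_pd);
    // ... execute op ...; if (needs_reorder) dst.InsertReorderToUserMem(&net);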

}  // namespace mkl_op_registry
#endif  // INTEL_MKL_DNN

}  // namespace tensorflow
#endif  // INTEL_MKL
@ -165,8 +165,8 @@ def tf_copts():
      "-DEIGEN_AVOID_STL_ARRAY",
      "-Iexternal/gemmlowp",
      "-Wno-sign-compare",
      "-fno-exceptions",
      "-ftemplate-depth=900",
      "-fno-exceptions",
  ]) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm(
      ["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + select({
          clean_dep("//tensorflow:android"): [
@ -526,6 +526,7 @@ def tf_cc_test(name,
               extra_copts=[],
               suffix="",
               linkopts=[],
               nocopts=None,
               **kwargs):
  native.cc_test(
      name="%s%s" % (name, suffix),
@ -547,6 +548,7 @@ def tf_cc_test(name,
          clean_dep("//tensorflow:darwin"): 1,
          "//conditions:default": 0,
      }),
      nocopts=nocopts,
      **kwargs)

@ -649,7 +651,8 @@ def tf_cc_tests(srcs,
                tags=[],
                size="medium",
                args=None,
                linkopts=[]):
                linkopts=[],
                nocopts=None):
  for src in srcs:
    tf_cc_test(
        name=src_to_test_name(src),
@ -659,7 +662,8 @@ def tf_cc_tests(srcs,
        tags=tags,
        size=size,
        args=args,
        linkopts=linkopts)
        linkopts=linkopts,
        nocopts=nocopts)


def tf_cc_test_mkl(srcs,
@ -669,7 +673,7 @@ def tf_cc_test_mkl(srcs,
                   tags=[],
                   size="medium",
                   args=None):
  if_mkl(tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args))
  if_mkl(tf_cc_tests(srcs, deps, name, linkstatic=linkstatic, tags=tags, size=size, args=args, nocopts="-fno-exceptions"))


def tf_cc_tests_gpu(srcs,
@ -867,18 +871,29 @@ def tf_mkl_kernel_library(name,
                          deps=None,
                          alwayslink=1,
                          copts=tf_copts(),
                          nocopts="-fno-exceptions",
                          **kwargs):
  if_mkl(
      tf_kernel_library(
          name,
          prefix=prefix,
  if not bool(srcs):
    srcs = []
  if not bool(hdrs):
    hdrs = []

  if prefix:
    srcs = srcs + native.glob(
        [prefix + "*.cc"])
    hdrs = hdrs + native.glob(
        [prefix + "*.h"])

  if_mkl(
      native.cc_library(
          name=name,
          srcs=srcs,
          gpu_srcs=gpu_srcs,
          hdrs=hdrs,
          deps=deps,
          alwayslink=alwayslink,
          copts=copts,
          **kwargs))
          nocopts=nocopts
      ))


# Bazel rules for building swig files.
@ -170,6 +170,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
  print("path_prefix was specified to tf_workspace but is no longer used " +
        "and will be removed in the future.")

  native.new_http_archive(
      name = "mkl_dnn",
      urls = [
          "https://github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz",
          "http://mirror.bazel.build/github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz",
      ],
      sha256 = "0d529ad4c49dc799e6df07c2b88b115d0668735da15fb3b3862d28d33fa68165",
      strip_prefix = "mkl-dnn-b01e3a55a07be62172e713bcd2644c5176360212",
      build_file = str(Label("//third_party/mkl_dnn:mkldnn.BUILD")),
  )

  native.new_http_archive(
      name = "eigen_archive",
      urls = [
1  third_party/mkl_dnn/BUILD  vendored  Normal file

@ -0,0 +1 @@
licenses(["notice"])
25  third_party/mkl_dnn/mkldnn.BUILD  vendored  Normal file

@ -0,0 +1,25 @@
exports_files(["LICENSE"])

cc_library(
    name = "mkl_dnn",
    srcs = glob([
        "src/common/*.cpp",
        "src/cpu/*.cpp",
    ]),
    hdrs = glob(["include/*"]),
    copts = ["-fexceptions"] + select({
        "@org_tensorflow//tensorflow:linux_x86_64": [
            "-fopenmp",
        ],
        "//conditions:default": [],
    }),
    includes = [
        "include",
        "src",
        "src/common",
        "src/cpu",
        "src/cpu/xbyak",
    ],
    nocopts = "-fno-exceptions",
    visibility = ["//visibility:public"],
)