MKL-DNN open source integration. (#13135)

* MKL-DNN conv and build integration

* Adding new files that were mistakenly missing from the PR

* Minor change in the pip package build file

* Added missing #include

* Fixed a linking failure when running the bazel test

* Fixing BUILD file format

* Using -fopenmp for building mkl_dnn only when running on linux

* Fixing build rule attribute value

* Removing unnecessary deps from mkl test rule

* Removed deps on mkl-dnn when not building with --config=mkl
This commit is contained in:
Mahmoud Abuzaina 2017-10-03 20:59:45 -07:00 committed by gunan
parent f1a1099232
commit 6af7ab97ac
16 changed files with 1423 additions and 137 deletions

View File

@ -1772,6 +1772,7 @@ tf_cuda_library(
) + if_mkl( ) + if_mkl(
[ [
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
"@mkl_dnn//:mkl_dnn",
], ],
), ),
alwayslink = 1, alwayslink = 1,
@ -1932,7 +1933,7 @@ CORE_CPU_LIB_HEADERS = CORE_CPU_BASE_HDRS + [
"common_runtime/visitable_allocator.h", "common_runtime/visitable_allocator.h",
"graph/gradients.h", "graph/gradients.h",
"graph/quantize_training.h", "graph/quantize_training.h",
] ] + if_mkl(["graph/mkl_graph_util.h"])
tf_cuda_library( tf_cuda_library(
name = "core_cpu_impl", name = "core_cpu_impl",
@ -2033,7 +2034,10 @@ tf_cuda_library(
"//third_party/eigen3", "//third_party/eigen3",
"//tensorflow/core/kernels:required", "//tensorflow/core/kernels:required",
] + if_mkl( ] + if_mkl(
["//third_party/mkl:intel_binary_blob"], [
"//third_party/mkl:intel_binary_blob",
"@mkl_dnn//:mkl_dnn",
],
) + tf_additional_core_deps() + if_static([":core_cpu_impl"]), ) + tf_additional_core_deps() + if_static([":core_cpu_impl"]),
alwayslink = 1, alwayslink = 1,
) )
@ -2669,7 +2673,7 @@ tf_cc_test_mkl(
"graph/mkl_layout_pass_test.cc", "graph/mkl_layout_pass_test.cc",
"graph/mkl_tfconversion_pass_test.cc", "graph/mkl_tfconversion_pass_test.cc",
], ],
linkstatic = tf_kernel_tests_linkstatic(), linkstatic = 1,
deps = [ deps = [
":core", ":core",
":core_cpu", ":core_cpu",
@ -2687,18 +2691,6 @@ tf_cc_test_mkl(
"//tensorflow/cc:cc_ops", "//tensorflow/cc:cc_ops",
"//tensorflow/cc:scope", "//tensorflow/cc:scope",
"//tensorflow/cc:sendrecv_ops", "//tensorflow/cc:sendrecv_ops",
"//tensorflow/core/kernels:mkl_aggregate_ops",
"//tensorflow/core/kernels:mkl_concat_op",
"//tensorflow/core/kernels:mkl_conv_op",
"//tensorflow/core/kernels:mkl_cwise_ops_common",
"//tensorflow/core/kernels:mkl_fused_batch_norm_op",
"//tensorflow/core/kernels:mkl_identity_op",
"//tensorflow/core/kernels:mkl_input_conversion_op",
"//tensorflow/core/kernels:mkl_lrn_op",
"//tensorflow/core/kernels:mkl_pooling_ops",
"//tensorflow/core/kernels:mkl_relu_op",
"//tensorflow/core/kernels:mkl_reshape_op",
"//tensorflow/core/kernels:mkl_tfconv_op",
"//tensorflow/core/kernels:ops_util", "//tensorflow/core/kernels:ops_util",
"//third_party/eigen3", "//third_party/eigen3",
], ],

View File

@ -0,0 +1,129 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
#define TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_
#ifdef INTEL_MKL
#include <string>
#include "tensorflow/core/framework/op_kernel.h"
namespace tensorflow {
// Since our ops are going to produce and also consume N addition tensors
// (Mkl) for N Tensorflow tensors, we can have following different
// orderings among these 2N tensors.
//
// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
// consume A_m, B_m, and C_m additionally.
//
// INTERLEAVED: in this case 2N tensors are interleaved. So for above
// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
//
// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
// by N Mkl tensors. So for above example, the ordering looks
// like: A, B, C, A_m, B_m, C_m
//
// Following APIs map index of original Tensorflow tensors to their
// appropriate position based on selected ordering. For contiguous ordering,
// we need to know the total number of tensors (parameter total).
//
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
// NOTE: Currently, we use contiguous ordering. If you change this, then you
// would need to change Mkl op definitions in nn_ops.cc.
static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
// Get index of MetaData tensor from index 'n' of Data tensor.
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
// For interleaved ordering, Mkl tensor follows immediately after
// Tensorflow tensor.
return n + 1;
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
// For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
return n + total_tensors / 2;
}
}
int inline GetTensorDataIndex(int n, int total_tensors) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
return 2 * n; // index corresponding to nth input/output tensor
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
return n;
}
}
int inline GetTensorMetaDataIndex(int n, int total_tensors) {
// Get index for TensorData first and then use mapping function
// to get TensorMetaData index from TensorData index.
int tidx = GetTensorDataIndex(n, total_tensors);
return DataIndexToMetaDataIndex(tidx, total_tensors);
}
namespace mkl_op_registry {
static const char* kMklOpLabel = "MklOp";
static const char* kMklOpLabelPattern = "label='MklOp'";
// Get the name of Mkl op from original TensorFlow op
// We prefix 'Mkl' to the original op to get Mkl op.
inline string GetMklOpName(const string& name) {
// Prefix that we add to Tensorflow op name to construct Mkl op name.
const char* const kMklOpPrefix = "_Mkl";
return string(kMklOpPrefix) + name;
}
// Check whether opname with type T is registered as MKL-compliant.
//
// @input: name of the op
// @input: T datatype to be used for checking op
// @return: true if opname is registered as Mkl op; false otherwise
static inline bool IsMklOp(const std::string& op_name, DataType T) {
string kernel = KernelsRegisteredForOp(op_name);
bool result =
kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT);
if (result) {
VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel;
}
return result;
}
// Check whether opname with type T is registered as MKL-compliant and
// is element-wise.
//
// @input: name of the op
// @input: T datatype to be used for checking op
// @return: true if opname is registered as element-wise Mkl op;
// false otherwise
static inline bool IsMklElementWiseOp(const std::string& op_name,
DataType T) {
if (!IsMklOp(op_name, T)) {
return false;
}
bool result = (0 == op_name.compare(GetMklOpName("Add")) ||
0 == op_name.compare(GetMklOpName("Sub")) ||
0 == op_name.compare(GetMklOpName("Mul")) ||
0 == op_name.compare(GetMklOpName("Maximum")) ||
0 == op_name.compare(GetMklOpName("SquaredDifference")));
VLOG(1) << "mkl_op_registry::" << op_name
<< " is elementwise MKL op: " << result;
return result;
}
} // namespace mkl_op_registry
} // namespace tensorflow
#endif // INTEL_MKL
#endif // TENSORFLOW_CORE_GRAPH_MKL_GRAPH_UTIL_H_

View File

@ -38,7 +38,7 @@ limitations under the License.
#include "tensorflow/core/util/tensor_format.h" #include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/graph/mkl_graph_util.h"
namespace tensorflow { namespace tensorflow {

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifdef INTEL_MKL #ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_layout_pass.h" #include "tensorflow/core/graph/mkl_layout_pass.h"
#include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm> #include <algorithm>
#include <string> #include <string>

View File

@ -34,7 +34,7 @@ limitations under the License.
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/graph/mkl_tfconversion_pass.h" #include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/graph/mkl_graph_util.h"
namespace tensorflow { namespace tensorflow {

View File

@ -16,7 +16,7 @@ limitations under the License.
#ifdef INTEL_MKL #ifdef INTEL_MKL
#include "tensorflow/core/graph/mkl_tfconversion_pass.h" #include "tensorflow/core/graph/mkl_tfconversion_pass.h"
#include "tensorflow/core/util/mkl_util.h" #include "tensorflow/core/graph/mkl_graph_util.h"
#include <algorithm> #include <algorithm>
#include <string> #include <string>

View File

@ -820,6 +820,7 @@ tf_kernel_library(
hdrs = ["transpose_op.h"], hdrs = ["transpose_op.h"],
deps = ARRAY_DEPS + if_mkl([ deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
"@mkl_dnn//:mkl_dnn",
]), ]),
) )
@ -2596,6 +2597,7 @@ tf_kernel_library(
"//conditions:default": [], "//conditions:default": [],
}) + if_mkl([ }) + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
"@mkl_dnn//:mkl_dnn",
]) + if_cuda([ ]) + if_cuda([
"//tensorflow/core/platform/default/build_config:cublas_plugin", "//tensorflow/core/platform/default/build_config:cublas_plugin",
]), ]),
@ -5501,8 +5503,10 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:nn_ops_op_lib",
] + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
], "@mkl_dnn//:mkl_dnn",
]),
) )
tf_mkl_kernel_library( tf_mkl_kernel_library(
@ -5516,8 +5520,10 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:nn_ops_op_lib",
] + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
], "@mkl_dnn//:mkl_dnn",
]),
) )
tf_mkl_kernel_library( tf_mkl_kernel_library(
@ -5566,16 +5572,19 @@ tf_mkl_kernel_library(
"//tensorflow/core:lib", "//tensorflow/core:lib",
"//tensorflow/core:lib_internal", "//tensorflow/core:lib_internal",
"//tensorflow/core:nn_ops_op_lib", "//tensorflow/core:nn_ops_op_lib",
] + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
], "@mkl_dnn//:mkl_dnn",
]),
) )
tf_mkl_kernel_library( tf_mkl_kernel_library(
name = "mkl_fused_batch_norm_op", name = "mkl_fused_batch_norm_op",
srcs = ["mkl_fused_batch_norm_op.cc"], srcs = ["mkl_fused_batch_norm_op.cc"],
deps = NN_DEPS + [ deps = NN_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
], "@mkl_dnn//:mkl_dnn",
]),
) )
tf_mkl_kernel_library( tf_mkl_kernel_library(
@ -5589,9 +5598,10 @@ tf_mkl_kernel_library(
tf_mkl_kernel_library( tf_mkl_kernel_library(
name = "mkl_concat_op", name = "mkl_concat_op",
prefix = "mkl_concat_op", prefix = "mkl_concat_op",
deps = ARRAY_DEPS + [ deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
], "@mkl_dnn//:mkl_dnn",
]),
) )
tf_mkl_kernel_library( tf_mkl_kernel_library(
@ -5605,17 +5615,19 @@ tf_mkl_kernel_library(
tf_mkl_kernel_library( tf_mkl_kernel_library(
name = "mkl_identity_op", name = "mkl_identity_op",
prefix = "mkl_identity_op", prefix = "mkl_identity_op",
deps = ARRAY_DEPS + [ deps = ARRAY_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
], "@mkl_dnn//:mkl_dnn",
]),
) )
tf_mkl_kernel_library( tf_mkl_kernel_library(
name = "mkl_lrn_op", name = "mkl_lrn_op",
prefix = "mkl_lrn_op", prefix = "mkl_lrn_op",
deps = NN_DEPS + [ deps = NN_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob", "//third_party/mkl:intel_binary_blob",
], "@mkl_dnn//:mkl_dnn",
]),
) )
tf_mkl_kernel_library( tf_mkl_kernel_library(

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
@ -41,10 +42,24 @@ limitations under the License.
#include "mkl_dnn.h" #include "mkl_dnn.h"
#include "mkl_dnn_types.h" #include "mkl_dnn_types.h"
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
using mkldnn::stream;
using mkldnn::prop_kind;
using mkldnn::convolution_forward;
using mkldnn::convolution_backward_weights;
using mkldnn::convolution_direct;
#endif
namespace tensorflow { namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice;
#ifndef INTEL_MKL_DNN
template <typename Device, class T> template <typename Device, class T>
class MklConv2DCustomBackpropFilterOp : public OpKernel { class MklConv2DCustomBackpropFilterOp : public OpKernel {
public: public:
@ -411,6 +426,174 @@ class MklConv2DCustomBackpropFilterOp : public OpKernel {
TensorFormat data_format_; TensorFormat data_format_;
}; };
#else
template <typename Device, class T>
class MklConv2DCustomBackpropFilterOp : public OpKernel {
public:
explicit MklConv2DCustomBackpropFilterOp(OpKernelConstruction* context)
: OpKernel(context) {
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
int stride_n = GetTensorDim(strides_, data_format_, 'N');
int stride_c = GetTensorDim(strides_, data_format_, 'C');
OP_REQUIRES(
context, (stride_n == 1 && stride_c == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
MklDnnData<T> input(&cpu_engine);
MklDnnData<T> outbackprop(&cpu_engine);
MklDnnData<T> output(&cpu_engine);
// Input tensors
const Tensor& input_tensor = MklGetInput(context, 0);
const Tensor& filter_tensor = MklGetInput(context, 1);
const Tensor& obp_tensor = MklGetInput(context, 2); // Outbackprop
// Generate input shapes.
TensorShape filter_shape;
OP_REQUIRES(context, TensorShapeUtils::IsVector(filter_tensor.shape()),
errors::InvalidArgument(
"Conv2DBackpropFilter: filter_sizes input must be 1-dim, not ",
filter_tensor.dims()));
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
filter_tensor.vec<int32>(), &filter_shape));
TensorShape input_shape = input_tensor.shape();
TensorShape obp_shape = obp_tensor.shape();
// By default, all dims are in MKL order. Only dims in TF order
// are those with prefix tf_order.
memory::dims obp_dims, fwd_input_dims, fwd_filter_dims;
memory::dims padding_l, padding_r, strides, fwd_output_dims;
memory::dims fwd_output_dims_tf_order;
// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape,
&fwd_input_dims, &fwd_filter_dims,
&strides,
&fwd_output_dims_tf_order,
&fwd_output_dims,
&padding_l, &padding_r);
if (!context->status().ok()) return;
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(),
mkl_data_format);
auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(),
memory::format::hwio);
auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(),
mkl_data_format);
auto fwd_desc = convolution_forward::desc(prop_kind::forward,
convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md,
strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
// Allocate output tensor and shape
// TODO(nhasabni): Update this when support for MKL layout is added.
// Shape of output of Conv2DBackpropInput is same as 'input' of Conv2D.
TensorShape tf_output_shape(filter_shape);
MklShape mkl_output_mkl_shape;
mkl_output_mkl_shape.SetMklTensor(false);
Tensor* output_tensor = nullptr;
AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
mkl_output_mkl_shape);
// Create memory for user data.
// Describe how the inputs and outputs of Convolution look like. Also
// specify buffers containing actual input and output data.
// Although input shape required is in MKL-DNN order, the layout is
// Tensorflow's layout (NHWC or NCHW depending on data format).
input.SetUsrMem(fwd_input_dims, mkl_data_format, &input_tensor);
// Outbackprop shape is NHWC or NCHW depending on data format. Since
// GetInputSizeInMklOrder function returns size in that order we just use
// use that function directly.
conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims);
if (!context->status().ok()) return;
outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor);
// Although output shape required is in MKL-DNN order,
// layout is Tensorflow's filter layout (HWIO)
// Shape of output of Conv2DBackpropInput is same as shape of filter.
memory::dims bwd_output_dims = fwd_filter_dims;
output.SetUsrMem(bwd_output_dims, memory::format::hwio, output_tensor);
// Create memory descriptors for convolution data w/ no specified format.
input.SetOpMemDesc(fwd_input_dims, memory::format::any);
outbackprop.SetOpMemDesc(obp_dims, memory::format::any);
output.SetOpMemDesc(bwd_output_dims, memory::format::any);
// Create convolution backward weights primitive.
auto bwd_desc = convolution_backward_weights::desc(convolution_direct,
input.GetOpMemDesc(), output.GetOpMemDesc(),
outbackprop.GetOpMemDesc(), strides, padding_l,
padding_r, TFPaddingToMklDnnPadding(padding_));
auto bwd_pd = convolution_backward_weights::primitive_desc(bwd_desc,
cpu_engine,
fwd_pd);
PrepareAndExecutePrimitive(bwd_pd, &input, &outbackprop, &output);
} catch (mkldnn::error &e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) +
", in file " + string(__FILE__) + ":" +
std::to_string(__LINE__);
OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:",
error_msg));
}
}
private:
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecutePrimitive(
const convolution_backward_weights::primitive_desc& conv_pd,
MklDnnData<T>* input, MklDnnData<T>* obp,
MklDnnData<T>* output) {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution.
std::vector<primitive> net;
input->CheckReorderToOpMem(conv_pd.src_primitive_desc(), &net);
obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
// Memory for output of convolution. Since we may need reorder on the
// output side, we will prepare reorder primitive in case output
// reorder to user memory is required.
bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
conv_pd.diff_weights_primitive_desc());
net.push_back(convolution_backward_weights(conv_pd, input->GetOpMem(),
obp->GetOpMem(), output->GetOpMem()));
// Insert reorder primitive in the net for output reorder if reorder is
// required.
if (output_reorder_required) {
output->InsertReorderToUserMem(&net);
}
// Handle output reorder
stream(stream::kind::eager).submit(net).wait();
}
};
#endif
#define REGISTER_MKL_FILTER_KERNELS(T) \ #define REGISTER_MKL_FILTER_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropFilter") \
.Device(DEVICE_CPU) \ .Device(DEVICE_CPU) \

View File

@ -30,6 +30,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/conv_grad_ops.h" #include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
@ -43,10 +44,23 @@ limitations under the License.
#include "mkl_dnn.h" #include "mkl_dnn.h"
#include "mkl_dnn_types.h" #include "mkl_dnn_types.h"
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
using mkldnn::stream;
using mkldnn::prop_kind;
using mkldnn::convolution_forward;
using mkldnn::convolution_direct;
using mkldnn::convolution_backward_data;
#endif
namespace tensorflow { namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice;
#ifndef INTEL_MKL_DNN
template <typename Device, class T> template <typename Device, class T>
class MklConv2DCustomBackpropInputOp : public OpKernel { class MklConv2DCustomBackpropInputOp : public OpKernel {
public: public:
@ -345,6 +359,180 @@ class MklConv2DCustomBackpropInputOp : public OpKernel {
TensorFormat data_format; TensorFormat data_format;
}; };
#else
template <typename Device, class T>
class MklConv2DCustomBackpropInputOp : public OpKernel {
public:
~MklConv2DCustomBackpropInputOp() {}
explicit MklConv2DCustomBackpropInputOp(OpKernelConstruction* context)
: OpKernel(context) {
string data_format_str;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format_str));
OP_REQUIRES(context, FormatFromString(data_format_str, &data_format_),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
int stride_n = GetTensorDim(strides_, data_format_, 'N');
int stride_c = GetTensorDim(strides_, data_format_, 'C');
OP_REQUIRES(
context, (stride_n == 1 && stride_c == 1),
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> outbackprop(&cpu_engine);
MklDnnData<T> output(&cpu_engine);
// Input tensors
const Tensor& input_tensor = MklGetInput(context, 0);
const Tensor& filter_tensor = MklGetInput(context, 1);
const Tensor& obp_tensor = MklGetInput(context, 2); // Outbackprop
// Generate input shape.
TensorShape input_shape;
OP_REQUIRES(context, TensorShapeUtils::IsVector(input_tensor.shape()),
errors::InvalidArgument(
"Conv2DBackpropInput: input_sizes input must be 1-dim, not ",
input_tensor.dims()));
OP_REQUIRES_OK(context, TensorShapeUtils::MakeShape(
input_tensor.vec<int32>(), &input_shape));
TensorShape filter_shape = filter_tensor.shape();
TensorShape obp_shape = obp_tensor.shape();
// By default, all dims are in MKL order. Only dims in TF order
// are those with prefix tf_order.
memory::dims obp_dims, fwd_input_dims, fwd_filter_dims;
memory::dims padding_l, padding_r, strides, fwd_output_dims;
memory::dims fwd_output_dims_tf_order;
// Get forward convolution parameters.
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
conv_utl.GetConvFwdSizesInMklOrder(input_shape, filter_shape,
&fwd_input_dims, &fwd_filter_dims,
&strides,
&fwd_output_dims_tf_order,
&fwd_output_dims,
&padding_l, &padding_r);
if (!context->status().ok()) return;
// Create Convolution forward descriptor since Convolution backward
// API needs it. For that, we first need to create input, filter
// and output memory descriptors.
auto mkl_data_format = TFDataFormatToMklDnnDataFormat(data_format_);
auto fwd_src_md = memory::desc(fwd_input_dims, MklDnnType<T>(),
mkl_data_format);
auto fwd_filter_md = memory::desc(fwd_filter_dims, MklDnnType<T>(),
memory::format::hwio);
auto fwd_out_md = memory::desc(fwd_output_dims, MklDnnType<T>(),
mkl_data_format);
auto fwd_desc = convolution_forward::desc(prop_kind::forward,
convolution_direct, fwd_src_md, fwd_filter_md, fwd_out_md,
strides, padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
auto fwd_pd = convolution_forward::primitive_desc(fwd_desc, cpu_engine);
// Allocate output tensor and shape
// TODO(nhasabni): Update this when support for MKL layout is added.
// Shape of output of Conv2DBackpropInput is same as 'input' of Conv2D.
TensorShape tf_output_shape(input_shape);
MklShape mkl_output_mkl_shape;
mkl_output_mkl_shape.SetMklTensor(false);
Tensor* output_tensor = nullptr;
AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
mkl_output_mkl_shape);
// Create memory for user data.
// Describe how the inputs and outputs of Convolution look like. Also
// specify buffers containing actual input and output data.
// Although input shape required is in MKL-DNN order, the layout is
// Tensorflow's layout (NHWC or NCHW depending on data format).
// Although filter shape (filter_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (HWIO).
// Shape of Conv2DBackpropInput's filter is same as that of Conv2D filter.
filter.SetUsrMem(fwd_filter_dims, memory::format::hwio, &filter_tensor);
// Outbackprop shape is NHWC or NCHW depending on data format. Since
// GetInputSizeInMklOrder function returns size in that order we just use
// use that function directly.
conv_utl.GetInputSizeInMklOrder(obp_shape, &obp_dims);
if (!context->status().ok()) return;
outbackprop.SetUsrMem(obp_dims, mkl_data_format, &obp_tensor);
// Although output shape required is in MKL-DNN order,
// layout is Tensorflow's layout (NHWC or NCHW depending on data format).
// Shape of output of Conv2DBackpropInput is same as shape of 'input'
// of Conv2D.
memory::dims bwd_output_dims = fwd_input_dims;
output.SetUsrMem(bwd_output_dims, mkl_data_format, output_tensor);
// Create memory descriptors for convolution data w/ no specified format.
filter.SetOpMemDesc(fwd_filter_dims, memory::format::any);
outbackprop.SetOpMemDesc(obp_dims, memory::format::any);
output.SetOpMemDesc(bwd_output_dims, memory::format::any);
// Create convolution backward data primitive.
auto bwd_desc = convolution_backward_data::desc(convolution_direct,
output.GetOpMemDesc(), filter.GetOpMemDesc(),
outbackprop.GetOpMemDesc(), strides, padding_l,
padding_r, TFPaddingToMklDnnPadding(padding_));
auto bwd_pd = convolution_backward_data::primitive_desc(bwd_desc,
cpu_engine,
fwd_pd);
PrepareAndExecutePrimitive(bwd_pd, &filter, &outbackprop, &output);
} catch (mkldnn::error &e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + string(e.message) +
", in file " + string(__FILE__) + ":" +
std::to_string(__LINE__);
OP_REQUIRES_OK(context, errors::Aborted("Operation received an exception:",
error_msg));
}
}
private:
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecutePrimitive(
const convolution_backward_data::primitive_desc& conv_pd,
MklDnnData<T>* filter, MklDnnData<T>* obp,
MklDnnData<T>* output) {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution.
std::vector<primitive> net;
filter->CheckReorderToOpMem(conv_pd.weights_primitive_desc(), &net);
obp->CheckReorderToOpMem(conv_pd.diff_dst_primitive_desc(), &net);
// Memory for output of convolution. Since we may need reorder on the
// output side, we will prepare reorder primitive in case output
// reorder to user memory is required.
bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
conv_pd.diff_src_primitive_desc());
net.push_back(convolution_backward_data(conv_pd, obp->GetOpMem(),
filter->GetOpMem(), output->GetOpMem()));
// Insert reorder primitive in the net for output reorder if reorder is
// required.
if (output_reorder_required) {
output->InsertReorderToUserMem(&net);
}
// Handle output reorder
stream(stream::kind::eager).submit(net).wait();
}
};
#endif // INTEL_MKL_DNN
#define REGISTER_MKL_CPU_KERNELS(T) \ #define REGISTER_MKL_CPU_KERNELS(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \ REGISTER_KERNEL_BUILDER(Name("_MklConv2DBackpropInput") \
.Device(DEVICE_CPU) \ .Device(DEVICE_CPU) \

View File

@ -19,6 +19,8 @@ limitations under the License.
#include <string.h> #include <string.h>
#include <map> #include <map>
#include <vector> #include <vector>
#include <string>
#include "tensorflow/core/framework/numeric_op.h" #include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/register_types.h"
@ -26,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h" #include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h" #include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/mkl_conv_ops.h"
#include "tensorflow/core/kernels/ops_util.h" #include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
@ -40,10 +43,23 @@ limitations under the License.
#include "mkl_dnn.h" #include "mkl_dnn.h"
#include "mkl_dnn_types.h" #include "mkl_dnn_types.h"
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
using mkldnn::stream;
using mkldnn::prop_kind;
using mkldnn::convolution_forward;
using mkldnn::convolution_direct;
#endif
namespace tensorflow { namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice; typedef Eigen::ThreadPoolDevice CPUDevice;
// For now, MKL-ML is default. So making MKL-DNN not a default choice.
#ifndef INTEL_MKL_DNN
template <typename Device, typename T, bool biasEnabled> template <typename Device, typename T, bool biasEnabled>
class MklConv2DOp : public OpKernel { class MklConv2DOp : public OpKernel {
public: public:
@ -461,6 +477,205 @@ class MklConv2DOp : public OpKernel {
TensorFormat data_format_; TensorFormat data_format_;
}; };
#else
template <typename Device, typename T, bool biasEnabled>
class MklConv2DOp : public OpKernel {
public:
~MklConv2DOp() {}
explicit MklConv2DOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("strides", &strides_));
string data_format;
OP_REQUIRES_OK(context, context->GetAttr("data_format", &data_format));
OP_REQUIRES(context, FormatFromString(data_format, &data_format_),
errors::InvalidArgument("Invalid data format"));
OP_REQUIRES(context, strides_.size() == 4,
errors::InvalidArgument("Sliding window strides field must "
"specify 4 dimensions"));
const int64 stride_n = GetTensorDim(strides_, data_format_, 'N');
const int64 stride_c = GetTensorDim(strides_, data_format_, 'C');
OP_REQUIRES(
context, stride_n == 1 && stride_c == 1,
errors::InvalidArgument("Current implementation does not yet support "
"strides in the batch and depth dimensions."));
OP_REQUIRES_OK(context, context->GetAttr("padding", &padding_));
}
void Compute(OpKernelContext* context) override {
try {
auto cpu_engine = engine(engine::cpu, 0);
// Input tensors
size_t src_idx = 0, filter_idx = 1;
const Tensor& src_tensor = MklGetInput(context, src_idx);
const Tensor& filter_tensor = MklGetInput(context, filter_idx);
MklDnnData<T> src(&cpu_engine);
MklDnnData<T> filter(&cpu_engine);
MklDnnData<T> output(&cpu_engine);
memory::dims src_dims, filter_dims, padding_l, padding_r, strides;
memory::dims output_dims_tf_order, output_dims_mkl_order;
// Get shapes of input tensors in MKL-DNN order
MklDnnConvUtil conv_utl(context, strides_, padding_, data_format_);
conv_utl.GetConvFwdSizesInMklOrder(src_tensor.shape(),
filter_tensor.shape(),
&src_dims, &filter_dims, &strides,
&output_dims_tf_order,
&output_dims_mkl_order, &padding_l,
&padding_r);
if (!context->status().ok()) return;
// Check for corner case - if there is nothing to compute, return.
TensorShape tf_output_shape({output_dims_tf_order[0],
output_dims_tf_order[1],
output_dims_tf_order[2],
output_dims_tf_order[3]});
Tensor* output_tensor = nullptr;
MklShape mkl_output_mkl_shape;
mkl_output_mkl_shape.SetMklTensor(false);
AllocateOutputSetMklShape(context, 0, &output_tensor, tf_output_shape,
mkl_output_mkl_shape);
// Forward filter in TF format from input at index 1 to output at index 1.
ForwardTfTensorInToOut(context, 1, 1);
if (tf_output_shape.num_elements() == 0) {
// TODO(jbobba): Verify correctness here
// Need semantics for Null MKL tensor
return;
}
// Corner case to handle 0 batch size.
if (output_dims_tf_order[0] == 0) {
// Nothing to do, allocate output tensor and return
// TODO(nhasabni): remove this code later once serialization
// in MKL-DNN is supported.
AllocateOutputSetMklShape(context, 0, &output_tensor,
src_tensor.shape(), mkl_output_mkl_shape);
return;
} else {
// Otherwise regular output tensor allocation
// Allocate output tensor.
}
CHECK_NOTNULL(output_tensor);
// Create memory for user data.
// Describe how the inputs and outputs of Convolution look like. Also
// specify buffers containing actual input and output data.
// Although input shape (src_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (NHWC or NCHW depending on data
// format).
src.SetUsrMem(src_dims, TFDataFormatToMklDnnDataFormat(data_format_),
const_cast<void*>(static_cast<const void*>(
src_tensor.flat<T>().data())));
// Although filter shape (filter_dims) required is in MKL-DNN order,
// the layout is Tensorflow's layout (HWIO).
filter.SetUsrMem(filter_dims, memory::format::hwio,
const_cast<void*>(static_cast<const void*>(
filter_tensor.flat<T>().data())));
// Although output shape (output_dims) required is in MKL-DNN order,
// layout is Tensorflow's layout (NHWC or NCHW depending on data format).
output.SetUsrMem(output_dims_mkl_order,
TFDataFormatToMklDnnDataFormat(data_format_),
output_tensor->flat<T>().data());
// Create memory descriptors for convolution data w/ no specified format.
src.SetOpMemDesc(src_dims, memory::format::any);
filter.SetOpMemDesc(filter_dims, memory::format::any);
output.SetOpMemDesc(output_dims_mkl_order, memory::format::any);
// If bias is enabled, then do the same steps as above for bias.
if (biasEnabled) {
MklDnnData<T> bias(&cpu_engine);
memory::dims bias_size;
conv_utl.GetBiasSizeInMklOrder(2 /* bias idx */, &bias_size);
const Tensor& bias_tensor = MklGetInput(context, 2);
bias.SetUsrMem(bias_size, memory::format::x,
const_cast<void*>(static_cast<const void*>(
bias_tensor.flat<T>().data())));
bias.SetOpMemDesc(bias_size, memory::format::any);
// Create convolution primitive with Bias.
auto conv_desc = convolution_forward::desc(prop_kind::forward,
convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(),
bias.GetOpMemDesc(), output.GetOpMemDesc(), strides,
padding_l, padding_r, TFPaddingToMklDnnPadding(padding_));
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc,
cpu_engine);
PrepareAndExecuteNet(conv_prim_desc, &src, &filter, &bias, &output);
} else {
// Create convolution primitive without Bias.
auto conv_desc = convolution_forward::desc(prop_kind::forward,
convolution_direct, src.GetOpMemDesc(), filter.GetOpMemDesc(),
output.GetOpMemDesc(), strides, padding_l, padding_r,
TFPaddingToMklDnnPadding(padding_));
auto conv_prim_desc = convolution_forward::primitive_desc(conv_desc,
cpu_engine);
PrepareAndExecuteNet(conv_prim_desc, &src, &filter, nullptr, &output);
}
} catch (mkldnn::error &e) {
string error_msg = "Status: " + std::to_string(e.status) +
", message: " + std::string(e.message) +
", in file " + std::string(__FILE__) + ":" +
std::to_string(__LINE__);
OP_REQUIRES_OK(context,
errors::Aborted("Operation received an exception:", error_msg));
}
}
private:
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
// Prepare and execute net - checks for input and output reorders.
void PrepareAndExecuteNet(
const convolution_forward::primitive_desc& conv_prim_desc,
MklDnnData<T>* src, MklDnnData<T>* filter,
MklDnnData<T>* bias, MklDnnData<T>* output) {
// Create reorders between user layout and MKL layout if it is needed and
// add it to the net before convolution.
std::vector<primitive> net;
src->CheckReorderToOpMem(conv_prim_desc.src_primitive_desc(), &net);
filter->CheckReorderToOpMem(conv_prim_desc.weights_primitive_desc(), &net);
// Memory for output of convolution. Since we may need reorder on the
// output side, we will prepare reorder primitive in case output
// reorder to user memory is required.
bool output_reorder_required = output->PrepareReorderToUserMemIfReq(
conv_prim_desc.dst_primitive_desc());
// Create convolution primitive and add it to net.
if (bias) {
CHECK_EQ(biasEnabled, true);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(), bias->GetOpMem(),
output->GetOpMem()));
} else {
CHECK_EQ(biasEnabled, false);
net.push_back(convolution_forward(conv_prim_desc, src->GetOpMem(),
filter->GetOpMem(), output->GetOpMem()));
}
// Insert reorder primitive in the net for output reorder if reorder is
// required.
if (output_reorder_required) {
output->InsertReorderToUserMem(&net);
}
// Handle output reorder
stream(stream::kind::eager).submit(net).wait();
}
};
#endif
#define REGISTER_MKL_CPU(T) \ #define REGISTER_MKL_CPU(T) \
REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \ REGISTER_KERNEL_BUILDER(Name("_MklConv2D") \
.Device(DEVICE_CPU) \ .Device(DEVICE_CPU) \

View File

@ -0,0 +1,316 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
#define TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_
#include <vector>
#include <limits>
#include "tensorflow/core/framework/numeric_op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_slice.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/kernels/conv_grad_ops.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/util/mkl_util.h"
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
#endif
namespace tensorflow {
#ifdef INTEL_MKL_DNN
class MklDnnConvUtil {
protected:
OpKernelContext* context_; // We don't own this.
std::vector<int32> strides_;
Padding padding_;
TensorFormat data_format_;
public:
MklDnnConvUtil(OpKernelContext* context, const std::vector<int32>& strides,
Padding pad, TensorFormat fm) : context_(context),
strides_(strides), padding_(pad), data_format_(fm) {}
virtual ~MklDnnConvUtil() { context_ = nullptr; }
// Calculate Convolution strides
virtual inline void GetStridesInMklOrder(memory::dims *strides) {
// For now we take the stride from the second and third dimensions only
// (we do not support striding on the batch or depth dimension).
CHECK_NOTNULL(strides);
int stride_rows = GetTensorDim(strides_, data_format_, 'H');
int stride_cols = GetTensorDim(strides_, data_format_, 'W');
*strides = {stride_rows, stride_cols};
}
// Calculate Convolution input size in MKL-DNN order. MKL-DNN
// requires input in NCHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
virtual inline void
GetInputSizeInMklOrder(const TensorShape& input_shape,
memory::dims *input_dims) {
#define CHECK_BOUNDS(val, err_msg) do { \
OP_REQUIRES(context_, FastBoundsCheck(val, \
std::numeric_limits<int>::max()), \
errors::InvalidArgument(err_msg)); \
}while(0)
CHECK_NOTNULL(input_dims);
// Input channel
int64 input_depth_raw = GetTensorDim(input_shape, data_format_, 'C');
int input_depth = static_cast<int>(input_depth_raw);
// Input rows/height
int64 input_rows_raw = GetTensorDim(input_shape, data_format_, 'H');
CHECK_BOUNDS(input_rows_raw, "Input rows too large");
int input_rows = static_cast<int>(input_rows_raw);
// Input columns/width
int64 input_cols_raw = GetTensorDim(input_shape, data_format_, 'W');
CHECK_BOUNDS(input_cols_raw, "Input cols too large");
int input_cols = static_cast<int>(input_cols_raw);
// Input batch
int64 input_batch_raw = GetTensorDim(input_shape, data_format_, 'N');
CHECK_BOUNDS(input_batch_raw, "Input batch too large");
int input_batch = static_cast<int>(input_batch_raw);
#undef CHECK_BOUNDS
// MKL-DNN always requires input in NCHW format.
*input_dims = {input_batch, input_depth, input_rows, input_cols};
}
// Calculate Convolution filter size in MKL-DNN order. MKL-DNN
// requires filter in OIHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
//
// Calculate Convolution filter size in MKL-DNN order. MKL-DNN
// requires filter in OIHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status. This function differs from GetConvFilterSizeInMklOrder in
// parameter for input - it accepts src_shape since Convolution Backward
// Input gets shape of input tensor rather than actual tensor (Convolution
// forward gets actual tensor as input).
//
// TODO(nhasabni): Add similar function for input and filter in MklShape.
virtual inline void
GetFilterSizeInMklOrder(const TensorShape& input_shape,
const TensorShape& filter_shape,
memory::dims *filter_dims) {
CHECK_NOTNULL(filter_dims);
OP_REQUIRES(context_, filter_shape.dims() == 4,
errors::InvalidArgument("filter must be 4-dimensional: ",
filter_shape.DebugString()));
for (int i = 0; i < 3; i++) {
OP_REQUIRES(context_, FastBoundsCheck(filter_shape.dim_size(i),
std::numeric_limits<int>::max()),
errors::InvalidArgument("filter too large"));
}
int input_depth = GetTensorDim(input_shape, data_format_, 'C');
OP_REQUIRES(
context_, input_depth == filter_shape.dim_size(2),
errors::InvalidArgument("input and filter must have the same depth: ",
input_depth, " vs ", filter_shape.dim_size(2)));
// TF filter is always in (rows, cols, in_depth, out_depth) order.
int filter_rows = static_cast<int>(filter_shape.dim_size(0));
int filter_cols = static_cast<int>(filter_shape.dim_size(1));
int in_depth = static_cast<int>(filter_shape.dim_size(2));
int out_depth = static_cast<int>(filter_shape.dim_size(3));
// MKL-DNN always needs filter in OIHW format.
// OIHW = (out_depth, in_depth, rows, cols)
*filter_dims = {out_depth, in_depth, filter_rows, filter_cols};
}
// Calculate Convolution filter size in MKL-DNN order. MKL-DNN
// requires filter in OIHW format. Function does not return anything.
// But errors arising from sanity checks are returned in context's
// status.
virtual inline void
GetFilterSizeInMklOrder(size_t src_index, size_t filter_index,
memory::dims *filter_dims) {
CHECK_NOTNULL(filter_dims);
const Tensor& input = MklGetInput(context_, src_index);
const Tensor& filter = MklGetInput(context_, filter_index);
GetFilterSizeInMklOrder(input.shape(), filter.shape(), filter_dims);
}
// Calculate Bias size for 2D Convolution. Function does not return
// anything, but sets error in context status.
virtual inline void
GetBiasSizeInMklOrder(size_t bias_index, memory::dims *bias_dims) {
const Tensor& bias = MklGetInput(context_, bias_index);
OP_REQUIRES(context_, bias.dims() == 1,
errors::InvalidArgument("bias must be 1-dimensional: ",
bias.shape().DebugString()));
*bias_dims = { static_cast<int>(bias.dim_size(0)) };
}
// Function to calculate output and padding size for 2D convolution.
//
// Calculate output shape of Convolution in MKL-DNN and TensorFlow order.
// MKL-DNN uses NCHW for output order. But TensorFlow output will be in
// NHWC or NCHW format depending on data format. Function also calculates
// left, right, top and bottom pads. Function does not return any status -
// status is returned via context status.
//
// TODO(nhasabni): Add similar function for input and filter in MklShape.
virtual inline void
GetOutputAndPadSizeInMklOrder(const TensorShape& input_shape,
const TensorShape& filter_shape,
const memory::dims& strides,
memory::dims *output_dims_tf_order,
memory::dims *output_dims_mkl_order,
memory::dims *pad_l, memory::dims *pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
CHECK_NOTNULL(pad_r);
int input_rows = GetTensorDim(input_shape, data_format_, 'H');
int input_cols = GetTensorDim(input_shape, data_format_, 'W');
// The first dimension for filter is rows/height.
int filter_rows = filter_shape.dim_size(0);
// The second dimension for filter is cols/width.
int filter_cols = filter_shape.dim_size(1);
// Stride is vector of 2 elements: {s_r, s_c}
int stride_rows = strides[0];
int stride_cols = strides[1];
// Output batch is same as input batch.
int out_batch = GetTensorDim(input_shape, data_format_, 'N');
// Output depth is same as last dimension for filter.
int out_depth = filter_shape.dim_size(3);
int64 out_rows = 0, out_cols = 0;
int64 pad_top = 0, pad_bottom = 0, pad_left, pad_right;
OP_REQUIRES_OK(context_,
GetWindowedOutputSizeVerbose(input_rows, filter_rows, stride_rows,
padding_, &out_rows, &pad_top, &pad_bottom));
OP_REQUIRES_OK(context_,
GetWindowedOutputSizeVerbose(input_cols, filter_cols, stride_cols,
padding_, &out_cols, &pad_left, &pad_right));
// Tensorflow output is in data_format order. (NHWC or NCHW)
TensorShape out_shape = ShapeFromFormat(data_format_, out_batch,
out_rows, out_cols, out_depth);
*output_dims_tf_order = TFShapeToMklDnnDims(out_shape);
// MKL-DNN always needs output in NCHW format.
*output_dims_mkl_order = {out_batch, out_depth, static_cast<int>(out_rows),
static_cast<int>(out_cols)};
// Now handle padding. MKL-DNN uses asymetric padding.
*pad_l = {static_cast<int>(pad_top), static_cast<int>(pad_left)};
*pad_r = {static_cast<int>(pad_bottom), static_cast<int>(pad_right)};
}
// Calculate output and pad size of forward Convolution operator.
// See comment on GetConvOutputAndPadSizeInMklOrder for parameters.
//
// Function does not return anything, but sets error in context status.
inline void
GetOutputAndPadSizeInMklOrder(size_t src_index, size_t filter_index,
const memory::dims& strides,
memory::dims *output_dims_tf_order,
memory::dims *output_dims_mkl_order,
memory::dims *pad_l, memory::dims *pad_r) {
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
CHECK_NOTNULL(pad_r);
const Tensor& input = MklGetInput(context_, src_index);
const Tensor& filter = MklGetInput(context_, filter_index);
OP_REQUIRES(context_, input.dims() == 4,
errors::InvalidArgument("input must be 4-dimensional",
input.shape().DebugString()));
GetOutputAndPadSizeInMklOrder(input.shape(), filter.shape(),
strides, output_dims_tf_order,
output_dims_mkl_order, pad_l, pad_r);
}
// Wrapper function to calculate input, filter, and output sizes of
// 2D Convolution in MKL order (NCHW for input and output; OIHW for filter.)
// Function also calculates output shape in Tensorflow order. Additionally, it
// also calculates strides and paddings for 2D Convolution.
//
// Function does not return anything, but sets error in context status.
inline void GetConvFwdSizesInMklOrder(const TensorShape& input_shape,
const TensorShape& filter_shape,
memory::dims *input_dims,
memory::dims *filter_dims,
memory::dims *strides,
memory::dims *output_dims_tf_order,
memory::dims *output_dims_mkl_order,
memory::dims *pad_l,
memory::dims *pad_r) {
CHECK_NOTNULL(input_dims);
CHECK_NOTNULL(filter_dims);
CHECK_NOTNULL(strides);
CHECK_NOTNULL(output_dims_tf_order);
CHECK_NOTNULL(output_dims_mkl_order);
CHECK_NOTNULL(pad_l);
CHECK_NOTNULL(pad_r);
GetInputSizeInMklOrder(input_shape, input_dims);
if (!context_->status().ok()) return;
GetFilterSizeInMklOrder(input_shape, filter_shape, filter_dims);
if (!context_->status().ok()) return;
GetStridesInMklOrder(strides);
GetOutputAndPadSizeInMklOrder(input_shape, filter_shape, *strides,
output_dims_tf_order,
output_dims_mkl_order,
pad_l, pad_r);
if (!context_->status().ok()) return;
}
};
#endif // INTEL_MKL_DNN
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_MKL_CONV_OPS_H_

View File

@ -26,13 +26,19 @@ limitations under the License.
#include "mkl_trans.h" #include "mkl_trans.h"
#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h" #include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/util/padding.h"
#include "tensorflow/core/util/tensor_format.h"
#include "tensorflow/core/graph/mkl_graph_util.h"
#ifdef INTEL_MKL_DNN
#include "mkldnn.hpp"
#endif
// The file contains a number of utility classes and functions used by MKL // The file contains a number of utility classes and functions used by MKL
// enabled kernels // enabled kernels
@ -219,19 +225,19 @@ class MklShape {
// Location from start of buffer where isMklTensor_ is serialized // Location from start of buffer where isMklTensor_ is serialized
#define DIMS_OFFSET \ #define DIMS_OFFSET \
(IS_MKL_TENSOR_OFFSET + sizeof(size_t)) // Location of dimension_ (IS_MKL_TENSOR_OFFSET + sizeof(size_t)) // Location of dimension_
// Location of sizes. Note dim is not used here, left here
// to make macros consistent.
#define SIZES_OFFSET(dims) \ #define SIZES_OFFSET(dims) \
(DIMS_OFFSET + \ (DIMS_OFFSET + sizeof(size_t))
sizeof(size_t)) // Location of sizes. Note dim is not used here, left here
// to make macros consistent.
#define STRIDES_OFFSET(dims) \ #define STRIDES_OFFSET(dims) \
(SIZES_OFFSET(dims) + dims * sizeof(size_t)) // Location of strides (SIZES_OFFSET(dims) + dims * sizeof(size_t)) // Location of strides
#define MKL_LAYOUT_OFFSET(dims) \ #define MKL_LAYOUT_OFFSET(dims) \
(STRIDES_OFFSET(dims) + dims * sizeof(size_t)) // Location of mklLayout_ (STRIDES_OFFSET(dims) + dims * sizeof(size_t)) // Location of mklLayout_
#define TF_LAYOUT_OFFSET(dims) \ #define TF_LAYOUT_OFFSET(dims) \
(MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF) // Location of tfLayout_ (MKL_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF) // Location of tfLayout_
// Location of tf_to_mkl_dim_map_
#define TF_TO_MKL_DIM_MAP_OFFSET(dims) \ #define TF_TO_MKL_DIM_MAP_OFFSET(dims) \
(TF_LAYOUT_OFFSET(dims) + \ (TF_LAYOUT_OFFSET(dims) + SIZE_OF_MKL_DNN_BUF)
SIZE_OF_MKL_DNN_BUF) // Location of tf_to_mkl_dim_map_
// TODO(agramesh1) make sure to create a const to share with rewrite pass // TODO(agramesh1) make sure to create a const to share with rewrite pass
// for min size of MKL metadata tensor. // for min size of MKL metadata tensor.
@ -342,58 +348,6 @@ inline Tensor ConvertMklToTF(OpKernelContext* context, const Tensor& mkl_tensor,
return output_tensor; return output_tensor;
} }
// Since our ops are going to produce and also consume N addition tensors
// (Mkl) for N Tensorflow tensors, we can have following different
// orderings among these 2N tensors.
//
// E.g., for Tensorflow tensors A, B, and C, our ops will produce and
// consume A_m, B_m, and C_m additionally.
//
// INTERLEAVED: in this case 2N tensors are interleaved. So for above
// example, the ordering looks like: A, A_m, B, B_m, C, C_m.
//
// CONTIGUOUS: in thi case N Tensorflow tensors are contiguous followed
// by N Mkl tensors. So for above example, the ordering looks
// like: A, B, C, A_m, B_m, C_m
//
// Following APIs map index of original Tensorflow tensors to their appropriate
// position based on selected ordering. For contiguous ordering, we need to know
// the total number of tensors (parameter total).
//
typedef enum { TENSORS_INTERLEAVED, TENSORS_CONTIGUOUS } MklTfTensorOrdering;
// NOTE: Currently, we use contiguous ordering. If you change this, then you
// would need to change Mkl op definitions in nn_ops.cc.
static MklTfTensorOrdering kTensorOrdering = TENSORS_CONTIGUOUS;
// Get index of MetaData tensor from index 'n' of Data tensor.
inline int DataIndexToMetaDataIndex(int n, int total_tensors) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
// For interleaved ordering, Mkl tensor follows immediately after
// Tensorflow tensor.
return n + 1;
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
// For contiguous ordering, Mkl tensor is n+total_tensors / 2 away.
return n + total_tensors / 2;
}
}
int inline GetTensorDataIndex(int n, int total_tensors) {
if (kTensorOrdering == MklTfTensorOrdering::TENSORS_INTERLEAVED) {
return 2 * n; // index corresponding to nth input/output tensor
} else {
CHECK_EQ(kTensorOrdering, MklTfTensorOrdering::TENSORS_CONTIGUOUS);
return n;
}
}
int inline GetTensorMetaDataIndex(int n, int total_tensors) {
// Get index for TensorData first and then use mapping function
// to get TensorMetaData index from TensorData index.
int tidx = GetTensorDataIndex(n, total_tensors);
return DataIndexToMetaDataIndex(tidx, total_tensors);
}
// Get the MKL shape from the second string tensor // Get the MKL shape from the second string tensor
inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) { inline void GetMklShape(OpKernelContext* ctext, int n, MklShape* mklshape) {
mklshape->DeSerializeMklShape( mklshape->DeSerializeMklShape(
@ -480,6 +434,13 @@ inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
*buf_out = static_cast<void*>(tensor_out->flat<float>().data()); *buf_out = static_cast<void*>(tensor_out->flat<float>().data());
} }
template <typename T>
inline void AllocTmpBuffer(OpKernelContext* context, Tensor* tensor_out,
TensorShape tf_shape) {
OP_REQUIRES_OK(context, context->allocate_temp(DataTypeToEnum<T>::v(),
tf_shape, tensor_out));
}
inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides, inline void GetStridesFromSizes(TensorFormat data_format, size_t* strides,
const size_t* sizes) { const size_t* sizes) {
// MKL requires strides in NCHW // MKL requires strides in NCHW
@ -743,56 +704,294 @@ inline void MklNCHWToNHWC(const Tensor& input, Tensor** output) {
} }
} }
namespace mkl_op_registry { // -------------------------------------------------------------------
static const char* kMklOpLabel = "MklOp";
static const char* kMklOpLabelPattern = "label='MklOp'";
// Get the name of Mkl op from original TensorFlow op #ifdef INTEL_MKL_DNN
// We prefix 'Mkl' to the original op to get Mkl op.
inline string GetMklOpName(const string& name) { using mkldnn::memory;
// Prefix that we add to Tensorflow op name to construct Mkl op name. using mkldnn::reorder;
const char* const kMklOpPrefix = "_Mkl"; using mkldnn::primitive;
return string(kMklOpPrefix) + name; using mkldnn::padding_kind;
using mkldnn::engine;
/// Return MKL-DNN data type (memory::data_type) for input type T
///
/// @input None
/// @return memory::data_type corresponding to type T
template<typename T> static memory::data_type MklDnnType();
/// Instantiation for float type. Add similar instantiations for other
/// type if needed.
template <>
memory::data_type MklDnnType<float>() {
return memory::data_type::f32;
} }
// Check whether opname with type T is registered as MKL-compliant. /// Map TensorFlow's data format into MKL-DNN data format
// ///
// @input: name of the op /// @input: TensorFlow data format
// @input: T datatype to be used for checking op /// @return: memory::format corresponding to TensorFlow data format;
// @return: true if opname is registered as Mkl op; false otherwise /// Fails with an error if invalid data format.
static inline bool IsMklOp(const std::string& op_name, DataType T) { inline memory::format TFDataFormatToMklDnnDataFormat(TensorFormat format) {
string kernel = KernelsRegisteredForOp(op_name); if (format == FORMAT_NHWC) return memory::format::nhwc;
bool result = else if (format == FORMAT_NCHW) return memory::format::nchw;
kernel.find(kMklOpLabelPattern) != string::npos && (T == DT_FLOAT); TF_CHECK_OK(Status(error::Code::INVALID_ARGUMENT,
if (result) { "Unsupported data format"));
VLOG(1) << "mkl_op_registry::" << op_name << " is " << kMklOpLabel; // Return to get rid of compiler warning
return memory::format::format_undef;
}
/// Map TensorShape object into memory::dims required by MKL-DNN
///
/// This function will simply map input TensorShape into MKL-DNN dims
/// naively. So it will preserve the order of dimensions. E.g., if
/// input tensor is in NHWC format, then dims will be in NHWC format
/// also.
///
/// @input TensorShape object in shape
/// @return memory::dims corresponding to TensorShape
inline memory::dims TFShapeToMklDnnDims(const TensorShape& shape) {
memory::dims dims(shape.dims());
for (unsigned int d = 0; d < shape.dims(); ++d) {
dims[d] = shape.dim_size(d);
} }
return result; return dims;
} }
// Check whether opname with type T is registered as MKL-compliant and /// Map TensorShape object into memory::dims in NCHW format required by MKL-DNN
// is element-wise. ///
// /// This function is a specific one than above function. It will map input
// @input: name of the op /// TensorShape into MKL-DNN dims in NCHW format. So it may not preserve the
// @input: T datatype to be used for checking op /// order of dimensions. E.g., if input tensor is in NHWC format, then dims
// @return: true if opname is registered as element-wise Mkl op; false otherwise /// will be in NCHW format, and not in NHWC format.
static inline bool IsMklElementWiseOp(const std::string& op_name, DataType T) { ///
if (!IsMklOp(op_name, T)) { /// @input TensorShape object in shape
/// @return memory::dims in MKL-DNN required NCHW format
inline memory::dims TFShapeToMklDnnDimsInNCHW(const TensorShape& shape,
TensorFormat format) {
// Check validity of format.
CHECK_NE(TFDataFormatToMklDnnDataFormat(format),
memory::format::format_undef);
int n = shape.dim_size(GetTensorDimIndex(format, 'N'));
int c = shape.dim_size(GetTensorDimIndex(format, 'C'));
int h = shape.dim_size(GetTensorDimIndex(format, 'H'));
int w = shape.dim_size(GetTensorDimIndex(format, 'W'));
// MKL-DNN requires dimensions in NCHW format.
return memory::dims({n, c, h, w});
}
inline padding_kind TFPaddingToMklDnnPadding(Padding pad) {
// MKL-DNN only supports zero padding.
return padding_kind::zero;
}
/*
* Class to represent all the resources corresponding to a tensor in TensorFlow
* that are required to execute an operation (such as Convolution).
*/
template <typename T>
class MklDnnData {
private:
/// MKL-DNN memory primitive for input user memory
memory* user_memory_;
/// MKL-DNN memory primitive in case input or output reorder is needed.
memory* reorder_memory_;
/// Operations memory descriptor
memory::desc* op_md_;
/// CPU engine on which operation will be executed
const engine* cpu_engine_;
public:
explicit MklDnnData(const engine* e) : user_memory_(nullptr),
reorder_memory_(nullptr),
op_md_(nullptr), cpu_engine_(e) {}
~MklDnnData() {
cpu_engine_ = nullptr; // We don't own this.
delete(user_memory_);
delete(reorder_memory_);
delete(op_md_);
}
void* GetTensorBuffer(const Tensor* tensor) {
CHECK_NOTNULL(tensor);
return const_cast<void*>(static_cast<const void*>(
tensor->flat<T>().data()));
}
/// Set user memory primitive using specified dimensions, memory format and
/// data_buffer. Function automatically uses element data type by using
/// input type T used for creating call object.
///
/// In a nutshell, function allows user to describe the input tensor to
/// an operation. E.g., filter of Conv2D is of shape {1, 2, 3, 4}, and
/// memory format HWIO, and the buffer that contains actual values is
/// pointed by data_buffer.
void SetUsrMem(memory::dims dim, memory::format fm, void* data_buffer) {
CHECK_NOTNULL(data_buffer);
CHECK_NOTNULL(cpu_engine_);
// TODO(nhasabni): can we remove dynamic memory allocation?
user_memory_ = new memory(memory::primitive_desc(
memory::desc(dim, MklDnnType<T>(), fm),
*cpu_engine_), data_buffer);
}
void SetUsrMem(memory::dims dim, memory::format fm, const Tensor* tensor) {
CHECK_NOTNULL(tensor);
SetUsrMem(dim, fm, GetTensorBuffer(tensor));
}
/// A version of function to set user memory primitive that accepts memory
/// descriptor directly, instead of accepting dimensions and format. This
/// function is more generic that the one above, but the function above is
/// sufficient in most cases.
void SetUsrMem(memory::desc md, void* data_buffer) {
CHECK_NOTNULL(data_buffer);
CHECK_NOTNULL(cpu_engine_);
// TODO(nhasabni): can we remove dynamic memory allocation?
user_memory_ = new memory(memory::primitive_desc(md, *cpu_engine_),
data_buffer);
}
/// A version of SetUsrMem with memory descriptor and tensor
void SetUsrMem(memory::desc md, const Tensor* tensor) {
CHECK_NOTNULL(tensor);
SetUsrMem(md, GetTensorBuffer(tensor));
}
/// A version of function to set user memory primitive that accepts primitive
/// descriptor directly, instead of accepting dimensions and format. This
/// function is more generic that the one above, but the function above is
/// sufficient in most cases.
void SetUsrMem(memory::primitive_desc pd, void* data_buffer) {
CHECK_NOTNULL(data_buffer);
CHECK_NOTNULL(cpu_engine_);
// TODO(nhasabni): can we remove dynamic memory allocation?
user_memory_ = new memory(pd, data_buffer);
}
/// A version of SetUsrMem with primitive descriptor and tensor
void SetUsrMem(memory::primitive_desc pd, const Tensor* tensor) {
CHECK_NOTNULL(tensor);
SetUsrMem(pd, GetTensorBuffer(tensor));
}
/// Get function for user memory primitive.
const memory* GetUsrMem() const { return user_memory_; }
/// Get function for primitive descriptor of user memory primitive.
const memory::primitive_desc GetUsrMemPrimDesc() const {
CHECK_NOTNULL(user_memory_);
return user_memory_->get_primitive_desc();
}
/// Get function for descriptor of user memory.
memory::desc GetUsrMemDesc() {
// This is ugly. Why MKL-DNN does not provide desc() method of const type??
const memory::primitive_desc pd = GetUsrMemPrimDesc();
return const_cast<memory::primitive_desc*>(&pd)->desc();
}
/// Get function for data buffer of user memory primitive.
void* GetUsrMemDataHandle() const {
CHECK_NOTNULL(user_memory_);
return user_memory_->get_data_handle();
}
/// Get the memory primitive for input and output of an op. If inputs
/// to an op require reorders, then this function returns memory primitive
/// for reorder. Otherwise, it will return memory primitive for user memory.
///
/// E.g., Conv2D(I, F) is a primitive with I and F being inputs. Then to
/// execute Conv2D, we need memory primitive for I and F. Buf if reorder is
/// required for I and F (say I_r is reorder primitive for I; F_r is reorder
/// primitive for F), then we need I_r and F_r to perform Conv2D.
const memory& GetOpMem() const {
return reorder_memory_ ? *reorder_memory_ : *user_memory_;
}
/// Set memory descriptor of an operation in terms of dimensions and memory
/// format. E.g., For Conv2D, the dimensions would be same as user dimensions
/// but memory::format would be mkldnn::any because we want MKL-DNN to choose
/// best layout/format for given input dimensions.
void SetOpMemDesc(const memory::dims& dim, memory::format fm) {
// TODO(nhasabni): can we remove dynamic memory allocation?
op_md_ = new memory::desc(dim, MklDnnType<T>(), fm);
}
/// Get function for memory descriptor for an operation
const memory::desc& GetOpMemDesc() const { return *op_md_; }
/// Function to handle input reordering
///
/// Check if we need to reorder this input of an operation.
/// Return true and allocate reorder memory primitive if reorder is needed.
/// Otherwise, return false and do not allocate reorder memory primitive.
///
/// To check if reorder is needed, this function compares memory primitive
/// descriptor of an operation (op_pd) for the given input with the
/// user-specified memory primitive descriptor.
///
/// @input: op_pd - memory primitive descriptor of the given input of an
/// operation
/// @input: net - net to which to add reorder primitive in case it is needed.
/// @return: true in case reorder of input is needed; false, otherwise.
bool CheckReorderToOpMem(const memory::primitive_desc& op_pd,
std::vector<primitive>* net) {
CHECK_NOTNULL(net);
CHECK_NOTNULL(user_memory_);
if (op_pd != user_memory_->get_primitive_desc()) {
// TODO(nhasabni): can we remove dynamic memory allocation?
reorder_memory_ = new memory(op_pd);
net->push_back(reorder(*user_memory_, *reorder_memory_));
return true;
}
return false; return false;
} }
bool result = (0 == op_name.compare(GetMklOpName("Add")) || /// Function to handle output reorder
0 == op_name.compare(GetMklOpName("Sub")) || ///
0 == op_name.compare(GetMklOpName("Mul")) || /// This function performs very similar functionality as input reordering
0 == op_name.compare(GetMklOpName("Maximum")) || /// function above. The only difference is that this function does not add
0 == op_name.compare(GetMklOpName("SquaredDifference"))); /// reorder primitive to the net. The reason for this is: the reorder
/// primitive for output needs to be added to the list only after operation
/// has executed. But we need to prepare a temporary buffer in case output
/// reorder is needed. And this temporary buffer will hold the output of
/// an operation before it is fed to reorder primitive.
///
/// @input memory primitive descriptor for the given output of an operation
/// @return: true in case reorder of output is needed; false, otherwise.
bool PrepareReorderToUserMemIfReq(const memory::primitive_desc& op_pd) {
CHECK_NOTNULL(user_memory_);
if (op_pd != user_memory_->get_primitive_desc()) {
// TODO(nhasabni): can we remove dynamic memory allocation?
reorder_memory_ = new memory(op_pd);
return true;
}
return false;
}
VLOG(1) << "mkl_op_registry::" << op_name /// Function to actually insert reorder primitive in the net
<< " is elementwise MKL op: " << result; ///
return result; /// This function completes remaining part of output reordering. It inserts
} /// a reordering primitive from the temporary buffer that holds the output
/// to the user-specified output buffer.
///
/// @input: net - net to which to add reorder primitive
void InsertReorderToUserMem(std::vector<primitive>* net) {
CHECK_NOTNULL(net);
CHECK_NOTNULL(user_memory_);
CHECK_NOTNULL(reorder_memory_);
net->push_back(reorder(*reorder_memory_, *user_memory_));
}
};
} // namespace mkl_op_registry #endif // INTEL_MKL_DNN
} // namespace tensorflow } // namespace tensorflow
#endif // INTEL_MKL #endif // INTEL_MKL

View File

@ -165,8 +165,8 @@ def tf_copts():
"-DEIGEN_AVOID_STL_ARRAY", "-DEIGEN_AVOID_STL_ARRAY",
"-Iexternal/gemmlowp", "-Iexternal/gemmlowp",
"-Wno-sign-compare", "-Wno-sign-compare",
"-fno-exceptions",
"-ftemplate-depth=900", "-ftemplate-depth=900",
"-fno-exceptions",
]) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm( ]) + if_cuda(["-DGOOGLE_CUDA=1"]) + if_mkl(["-DINTEL_MKL=1", "-fopenmp",]) + if_android_arm(
["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + select({ ["-mfpu=neon"]) + if_linux_x86_64(["-msse3"]) + select({
clean_dep("//tensorflow:android"): [ clean_dep("//tensorflow:android"): [
@ -526,6 +526,7 @@ def tf_cc_test(name,
extra_copts=[], extra_copts=[],
suffix="", suffix="",
linkopts=[], linkopts=[],
nocopts=None,
**kwargs): **kwargs):
native.cc_test( native.cc_test(
name="%s%s" % (name, suffix), name="%s%s" % (name, suffix),
@ -547,6 +548,7 @@ def tf_cc_test(name,
clean_dep("//tensorflow:darwin"): 1, clean_dep("//tensorflow:darwin"): 1,
"//conditions:default": 0, "//conditions:default": 0,
}), }),
nocopts=nocopts,
**kwargs) **kwargs)
@ -649,7 +651,8 @@ def tf_cc_tests(srcs,
tags=[], tags=[],
size="medium", size="medium",
args=None, args=None,
linkopts=[]): linkopts=[],
nocopts=None):
for src in srcs: for src in srcs:
tf_cc_test( tf_cc_test(
name=src_to_test_name(src), name=src_to_test_name(src),
@ -659,7 +662,8 @@ def tf_cc_tests(srcs,
tags=tags, tags=tags,
size=size, size=size,
args=args, args=args,
linkopts=linkopts) linkopts=linkopts,
nocopts=nocopts)
def tf_cc_test_mkl(srcs, def tf_cc_test_mkl(srcs,
@ -669,7 +673,7 @@ def tf_cc_test_mkl(srcs,
tags=[], tags=[],
size="medium", size="medium",
args=None): args=None):
if_mkl(tf_cc_tests(srcs, deps, linkstatic, tags=tags, size=size, args=args)) if_mkl(tf_cc_tests(srcs, deps, name, linkstatic=linkstatic, tags=tags, size=size, args=args, nocopts="-fno-exceptions"))
def tf_cc_tests_gpu(srcs, def tf_cc_tests_gpu(srcs,
@ -867,18 +871,29 @@ def tf_mkl_kernel_library(name,
deps=None, deps=None,
alwayslink=1, alwayslink=1,
copts=tf_copts(), copts=tf_copts(),
nocopts="-fno-exceptions",
**kwargs): **kwargs):
if_mkl( if not bool(srcs):
tf_kernel_library( srcs = []
name, if not bool(hdrs):
prefix=prefix, hdrs = []
if prefix:
srcs = srcs + native.glob(
[prefix + "*.cc"])
hdrs = hdrs + native.glob(
[prefix + "*.h"])
if_mkl(
native.cc_library(
name=name,
srcs=srcs, srcs=srcs,
gpu_srcs=gpu_srcs,
hdrs=hdrs, hdrs=hdrs,
deps=deps, deps=deps,
alwayslink=alwayslink, alwayslink=alwayslink,
copts=copts, copts=copts,
**kwargs)) nocopts=nocopts
))
# Bazel rules for building swig files. # Bazel rules for building swig files.

View File

@ -170,6 +170,17 @@ def tf_workspace(path_prefix="", tf_repo_name=""):
print("path_prefix was specified to tf_workspace but is no longer used " + print("path_prefix was specified to tf_workspace but is no longer used " +
"and will be removed in the future.") "and will be removed in the future.")
native.new_http_archive(
name = "mkl_dnn",
urls = [
"https://github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz",
"http://mirror.bazel.build/github.com/01org/mkl-dnn/archive/b01e3a55a07be62172e713bcd2644c5176360212.tar.gz",
],
sha256 = "0d529ad4c49dc799e6df07c2b88b115d0668735da15fb3b3862d28d33fa68165",
strip_prefix = "mkl-dnn-b01e3a55a07be62172e713bcd2644c5176360212",
build_file = str(Label("//third_party/mkl_dnn:mkldnn.BUILD")),
)
native.new_http_archive( native.new_http_archive(
name = "eigen_archive", name = "eigen_archive",
urls = [ urls = [

1
third_party/mkl_dnn/BUILD vendored Normal file
View File

@ -0,0 +1 @@
licenses(["notice"])

25
third_party/mkl_dnn/mkldnn.BUILD vendored Normal file
View File

@ -0,0 +1,25 @@
exports_files(["LICENSE"])
cc_library(
name = "mkl_dnn",
srcs = glob([
"src/common/*.cpp",
"src/cpu/*.cpp",
]),
hdrs = glob(["include/*"]),
copts = ["-fexceptions"] + select({
"@org_tensorflow//tensorflow:linux_x86_64": [
"-fopenmp",
],
"//conditions:default": [],
}),
includes = [
"include",
"src",
"src/common",
"src/cpu",
"src/cpu/xbyak",
],
nocopts = "-fno-exceptions",
visibility = ["//visibility:public"],
)