Avoid re-parsing the rendezvous key for Send/Recv ops that execute outside a loop.

For such ops, the rendezvous key will be constant, because
`ctx->frame_iter()` will always evaluate to `{0, 0}`. Benchmarking
reveals that this can save between 1 and 2 microseconds per Send or
Recv op execution. The optimization applies to all cross-process,
inter-device, and intra-device (host-to/from-device memory) Send/Recv
ops.
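
Concretely, the change is: build and parse the top-level rendezvous key once, at kernel construction time, and only fall back to constructing and parsing a fresh key when the op executes inside a loop frame. Below is a minimal, self-contained sketch of that pattern in plain C++; the types, the key format, and names such as `SendLikeOp` and `MakeKey` are illustrative stand-ins, not the real `SendOp`, `GetRendezvousKey`, or `Rendezvous::ParseKey` touched by the diffs that follow.

// Sketch only: cache a parsed key when its dynamic part ({0, 0}) is constant.
#include <cstdio>
#include <string>
#include <utility>

struct FrameAndIter {
  int frame_id;
  int iter_id;
  bool operator==(const FrameAndIter& o) const {
    return frame_id == o.frame_id && iter_id == o.iter_id;
  }
};

// Stand-in for Rendezvous::ParsedKey; the real one also keeps parsed fields.
struct ParsedKey {
  std::string buf;
};

// Stand-ins for GetRendezvousKey / Rendezvous::ParseKey: building and parsing
// the key string is the per-execution cost the commit avoids.
std::string MakeKey(const std::string& prefix, FrameAndIter fi) {
  return prefix + ";" + std::to_string(fi.frame_id) + ":" + std::to_string(fi.iter_id);
}
ParsedKey ParseKey(const std::string& buf) { return ParsedKey{buf}; }

class SendLikeOp {
 public:
  explicit SendLikeOp(std::string key_prefix) : key_prefix_(std::move(key_prefix)) {
    // Runs once per kernel, so the top-level key is built and parsed only once.
    cached_top_level_key_ = ParseKey(MakeKey(key_prefix_, FrameAndIter{0, 0}));
  }

  void Compute(FrameAndIter frame_iter) {
    if (frame_iter == FrameAndIter{0, 0}) {
      // Hot path (op outside any loop): reuse the cached key.
      Use(cached_top_level_key_);
    } else {
      // Inside a loop frame the key changes per iteration, so parse as before.
      Use(ParseKey(MakeKey(key_prefix_, frame_iter)));
    }
  }

 private:
  void Use(const ParsedKey& key) { std::printf("send with key %s\n", key.buf.c_str()); }

  std::string key_prefix_;
  ParsedKey cached_top_level_key_;
};

int main() {
  SendLikeOp op("/cpu:0;1;/cpu:0;T");  // Illustrative prefix, not the real key format.
  op.Compute(FrameAndIter{0, 0});      // Uses the cached key.
  op.Compute(FrameAndIter{3, 7});      // Builds and parses a fresh key.
  return 0;
}

The real kernels in sendrecv_ops.cc below follow the same shape: `parsed_key_` is filled in the SendOp/RecvOp constructors, and `Compute`/`ComputeAsync` only build a per-iteration `in_loop_parsed` key when `ctx->frame_iter()` is not `{0, 0}`.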

PiperOrigin-RevId: 158032522
Author: Derek Murray, 2017-06-05 10:47:38 -07:00 (committed by TensorFlower Gardener)
Parent: cc2dd4ac85
Commit: 9f932e6ce6
6 changed files with 141 additions and 18 deletions

tensorflow/core/common_runtime/kernel_benchmark_testlib.cc

@@ -39,7 +39,8 @@ namespace tensorflow {
 namespace test {
 
 Benchmark::Benchmark(const string& device, Graph* g,
-                     const SessionOptions* options, Graph* init) {
+                     const SessionOptions* options, Graph* init,
+                     Rendezvous* rendez) {
   SessionOptions default_options;
   if (!options) {
     options = &default_options;
@@ -61,7 +62,11 @@ Benchmark::Benchmark(const string& device, Graph* g,
     pool_->Schedule(closure);
   };
-  rendez_ = NewLocalRendezvous();
+  if (rendez == nullptr) {
+    rendez_ = NewLocalRendezvous();
+  } else {
+    rendez_ = rendez;
+  }
   const int graph_def_version = g->versions().producer();

tensorflow/core/common_runtime/kernel_benchmark_testlib.h

@@ -35,10 +35,11 @@ namespace test {
 class Benchmark {
  public:
-  // "device" must be either "cpu" or "gpu". Takes ownership of "g"
-  // and "init".
+  // "device" must be either "cpu" or "gpu". Takes ownership of "g",
+  // "init", and one reference on "rendez" (if not null).
   Benchmark(const string& device, Graph* g,
-            const SessionOptions* options = nullptr, Graph* init = nullptr);
+            const SessionOptions* options = nullptr, Graph* init = nullptr,
+            Rendezvous* rendez = nullptr);
   ~Benchmark();
 
   // Executes the graph for "iters" times.

tensorflow/core/kernels/BUILD

@@ -3159,6 +3159,21 @@ tf_kernel_library(
     deps = REQUIRED_DEPS,
 )
 
+tf_cc_test(
+    name = "sendrecv_ops_test",
+    srcs = ["sendrecv_ops_test.cc"],
+    linkstatic = tf_kernel_tests_linkstatic(),  # Required for benchmarking
+    deps = [
+        ":ops_testutil",
+        ":ops_util",
+        ":sendrecv_ops",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:test",
+        "//tensorflow/core:test_main",
+        "//tensorflow/core:testlib",
+    ],
+)
+
 cc_library(
     name = "sparse",
     deps = [

tensorflow/core/kernels/sendrecv_ops.cc

@@ -52,17 +52,16 @@ SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
   OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
   key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
                                        send_device_incarnation, tensor_name);
+  // The vast majority of Send nodes are outside any loop context, so
+  // proactively cache the rendezvous key for the top-level.
+  GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
+  OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
 }
 
 void SendOp::Compute(OpKernelContext* ctx) {
   OP_REQUIRES(
       ctx, ctx->rendezvous() != nullptr,
       errors::Internal("Op kernel context needs to provide a rendezvous."));
-  Rendezvous::ParsedKey parsed;
-  GetRendezvousKey(key_prefix_, ctx->frame_iter(), &parsed.buf_);
-  VLOG(2) << "Send " << parsed.buf_;
-  OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed.buf_, &parsed));
 
   // The device context may be passed between the Send/Recv
   // boundary, so that the device context used to produce the Tensor
@@ -71,8 +70,24 @@ void SendOp::Compute(OpKernelContext* ctx) {
   Rendezvous::Args args;
   args.device_context = ctx->op_device_context();
   args.alloc_attrs = ctx->input_alloc_attr(0);
-  OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed, args, ctx->input(0),
-                                              ctx->is_input_dead()));
+  if (ctx->frame_iter() == FrameAndIter(0, 0)) {
+    // Use the cached rendezvous key.
+    VLOG(2) << "Send " << parsed_key_.buf_;
+    OP_REQUIRES_OK(ctx,
+                   ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0),
+                                           ctx->is_input_dead()));
+  } else {
+    Rendezvous::ParsedKey in_loop_parsed;
+    GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
+    VLOG(2) << "Send " << in_loop_parsed.buf_;
+    OP_REQUIRES_OK(ctx,
+                   Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed));
+    OP_REQUIRES_OK(ctx,
+                   ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0),
+                                           ctx->is_input_dead()));
+  }
 }
 
 REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
@@ -101,17 +116,16 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
   OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
   key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
                                        send_device_incarnation, tensor_name);
+  // The vast majority of Recv nodes are outside any loop context, so
+  // proactively cache the rendezvous key for the top-level.
+  GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
+  OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
 }
 
 void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
   OP_REQUIRES(
       ctx, ctx->rendezvous() != nullptr,
       errors::Internal("Op kernel context needs to provide a rendezvous."));
-  Rendezvous::ParsedKey parsed;
-  GetRendezvousKey(key_prefix_, ctx->frame_iter(), &parsed.buf_);
-  VLOG(2) << "Recv " << parsed.buf_;
-  OP_REQUIRES_OK_ASYNC(ctx, Rendezvous::ParseKey(parsed.buf_, &parsed), done);
 
   Rendezvous::Args args;
   args.device_context = ctx->op_device_context();
@@ -136,7 +150,19 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
         done();
       },
       std::move(done), _1, _2, _3, _4, _5);
-  ctx->rendezvous()->RecvAsync(parsed, args, std::move(done_cb));
+  if (ctx->frame_iter() == FrameAndIter(0, 0)) {
+    VLOG(2) << "Recv " << parsed_key_.buf_;
+    ctx->rendezvous()->RecvAsync(parsed_key_, args, std::move(done_cb));
+  } else {
+    Rendezvous::ParsedKey in_loop_parsed;
+    GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
+    VLOG(2) << "Recv " << in_loop_parsed.buf_;
+    OP_REQUIRES_OK_ASYNC(
+        ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed), done);
+    ctx->rendezvous()->RecvAsync(in_loop_parsed, args, std::move(done_cb));
+  }
 }
 
 REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);

tensorflow/core/kernels/sendrecv_ops.h

@@ -28,6 +28,7 @@ class SendOp : public OpKernel {
  private:
   string key_prefix_;
+  Rendezvous::ParsedKey parsed_key_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(SendOp);
 };
@@ -39,6 +40,7 @@ class RecvOp : public AsyncOpKernel {
  private:
   string key_prefix_;
+  Rendezvous::ParsedKey parsed_key_;
 
   TF_DISALLOW_COPY_AND_ASSIGN(RecvOp);
 };

tensorflow/core/kernels/sendrecv_ops_test.cc

@@ -0,0 +1,74 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"

namespace tensorflow {
namespace {

// Implement a trivial version of the Rendezvous interface, to avoid
// clouding the benchmark results with the time spent in the various
// implementations, and to avoid the duplicate-send or duplicate-recv
// errors that would arise from running either benchmark in a loop.
class DummyRendezvous : public Rendezvous {
  Status Send(const ParsedKey& key, const Args& args, const Tensor& val,
              const bool is_dead) override {
    return Status::OK();
  }
  void RecvAsync(const ParsedKey& key, const Args& args,
                 DoneCallback done) override {
    static Tensor* t = new Tensor(DT_FLOAT, TensorShape({0}));
    done(Status::OK(), args, args, *t, false);
  }
  void StartAbort(const Status& status) override {}
};

static Graph* Send() {
  Graph* g = new Graph(OpRegistry::Global());
  Tensor in0(DT_FLOAT, TensorShape({0}));
  test::graph::Send(g, test::graph::Constant(g, in0), "T", "/cpu:0", 1,
                    "/cpu:0");
  test::graph::Recv(g, "T", "float", "/cpu:0", 1, "/cpu:0");
  return g;
}

static Graph* Recv() {
  Graph* g = new Graph(OpRegistry::Global());
  test::graph::Recv(g, "T", "float", "/cpu:0", 1, "/cpu:0");
  return g;
}

static void BM_Send(int iters) {
  testing::UseRealTime();
  testing::ItemsProcessed(static_cast<int64>(iters));
  test::Benchmark("cpu", Send(), nullptr, nullptr, new DummyRendezvous)
      .Run(iters);
}
BENCHMARK(BM_Send);

static void BM_Recv(int iters) {
  testing::UseRealTime();
  testing::ItemsProcessed(static_cast<int64>(iters));
  test::Benchmark("cpu", Recv(), nullptr, nullptr, new DummyRendezvous)
      .Run(iters);
}
BENCHMARK(BM_Recv);

}  // namespace
}  // namespace tensorflow
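
Once the test target above builds, the benchmarks would typically be invoked through TensorFlow's benchmark driver from test_benchmark.h (assuming the usual `--benchmarks` flag), e.g. `bazel run -c opt //tensorflow/core/kernels:sendrecv_ops_test -- --benchmarks=BM_Send`, and likewise for `BM_Recv`.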