mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Avoid parsing a rendezvous key for Send/Recv ops outside a loop.
For such ops, the rendezvous key will be constant, because
`ctx->frame_iter()` will always evaluate to `{0, 0}`. Benchmarking
reveals that this can save between 1 and 2 microseconds per Send or
Recv op execution. The optimization applies to all cross-process,
inter-device, and intra-device (host-to/from-device memory) Send/Recv
ops.
PiperOrigin-RevId: 158032522
This commit is contained in:
parent
cc2dd4ac85
commit
9f932e6ce6
|
|
@ -39,7 +39,8 @@ namespace tensorflow {
|
|||
namespace test {
|
||||
|
||||
Benchmark::Benchmark(const string& device, Graph* g,
|
||||
const SessionOptions* options, Graph* init) {
|
||||
const SessionOptions* options, Graph* init,
|
||||
Rendezvous* rendez) {
|
||||
SessionOptions default_options;
|
||||
if (!options) {
|
||||
options = &default_options;
|
||||
|
|
@ -61,7 +62,11 @@ Benchmark::Benchmark(const string& device, Graph* g,
|
|||
pool_->Schedule(closure);
|
||||
};
|
||||
|
||||
rendez_ = NewLocalRendezvous();
|
||||
if (rendez == nullptr) {
|
||||
rendez_ = NewLocalRendezvous();
|
||||
} else {
|
||||
rendez_ = rendez;
|
||||
}
|
||||
|
||||
const int graph_def_version = g->versions().producer();
|
||||
|
||||
|
|
|
|||
|
|
@ -35,10 +35,11 @@ namespace test {
|
|||
|
||||
class Benchmark {
|
||||
public:
|
||||
// "device" must be either "cpu" or "gpu". Takes ownership of "g"
|
||||
// and "init".
|
||||
// "device" must be either "cpu" or "gpu". Takes ownership of "g",
|
||||
// "init", and one reference on "rendez" (if not null).
|
||||
Benchmark(const string& device, Graph* g,
|
||||
const SessionOptions* options = nullptr, Graph* init = nullptr);
|
||||
const SessionOptions* options = nullptr, Graph* init = nullptr,
|
||||
Rendezvous* rendez = nullptr);
|
||||
~Benchmark();
|
||||
|
||||
// Executes the graph for "iters" times.
|
||||
|
|
|
|||
|
|
@ -3159,6 +3159,21 @@ tf_kernel_library(
|
|||
deps = REQUIRED_DEPS,
|
||||
)
|
||||
|
||||
# Test/benchmark target built from sendrecv_ops_test.cc; it links the
# ":sendrecv_ops" kernel library listed in deps below.
tf_cc_test(
|
||||
name = "sendrecv_ops_test",
|
||||
srcs = ["sendrecv_ops_test.cc"],
|
||||
linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
|
||||
deps = [
|
||||
":ops_testutil",
|
||||
":ops_util",
|
||||
":sendrecv_ops",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core:testlib",
|
||||
|
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "sparse",
|
||||
deps = [
|
||||
|
|
|
|||
|
|
@ -52,17 +52,16 @@ SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
|||
OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
|
||||
key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
|
||||
send_device_incarnation, tensor_name);
|
||||
// The vast majority of Send nodes are outside any loop context, so
|
||||
// proactively cache the rendezvous key for the top-level.
|
||||
GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
|
||||
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
|
||||
}
|
||||
|
||||
void SendOp::Compute(OpKernelContext* ctx) {
|
||||
OP_REQUIRES(
|
||||
ctx, ctx->rendezvous() != nullptr,
|
||||
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
||||
Rendezvous::ParsedKey parsed;
|
||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &parsed.buf_);
|
||||
VLOG(2) << "Send " << parsed.buf_;
|
||||
|
||||
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed.buf_, &parsed));
|
||||
|
||||
// The device context may be passed between the Send/Recv
|
||||
// boundary, so that the device context used to produce the Tensor
|
||||
|
|
@ -71,8 +70,24 @@ void SendOp::Compute(OpKernelContext* ctx) {
|
|||
Rendezvous::Args args;
|
||||
args.device_context = ctx->op_device_context();
|
||||
args.alloc_attrs = ctx->input_alloc_attr(0);
|
||||
OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed, args, ctx->input(0),
|
||||
ctx->is_input_dead()));
|
||||
|
||||
if (ctx->frame_iter() == FrameAndIter(0, 0)) {
|
||||
// Use the cached rendezvous key.
|
||||
VLOG(2) << "Send " << parsed_key_.buf_;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0),
|
||||
ctx->is_input_dead()));
|
||||
} else {
|
||||
Rendezvous::ParsedKey in_loop_parsed;
|
||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
|
||||
VLOG(2) << "Send " << in_loop_parsed.buf_;
|
||||
OP_REQUIRES_OK(ctx,
|
||||
Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed));
|
||||
|
||||
OP_REQUIRES_OK(ctx,
|
||||
ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0),
|
||||
ctx->is_input_dead()));
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
|
||||
|
|
@ -101,17 +116,16 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
|
|||
OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
|
||||
key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
|
||||
send_device_incarnation, tensor_name);
|
||||
// The vast majority of Recv nodes are outside any loop context, so
|
||||
// proactively cache the rendezvous key for the top-level.
|
||||
GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
|
||||
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
|
||||
}
|
||||
|
||||
void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||
OP_REQUIRES(
|
||||
ctx, ctx->rendezvous() != nullptr,
|
||||
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
||||
Rendezvous::ParsedKey parsed;
|
||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &parsed.buf_);
|
||||
VLOG(2) << "Recv " << parsed.buf_;
|
||||
|
||||
OP_REQUIRES_OK_ASYNC(ctx, Rendezvous::ParseKey(parsed.buf_, &parsed), done);
|
||||
|
||||
Rendezvous::Args args;
|
||||
args.device_context = ctx->op_device_context();
|
||||
|
|
@ -136,7 +150,19 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
|||
done();
|
||||
},
|
||||
std::move(done), _1, _2, _3, _4, _5);
|
||||
ctx->rendezvous()->RecvAsync(parsed, args, std::move(done_cb));
|
||||
|
||||
if (ctx->frame_iter() == FrameAndIter(0, 0)) {
|
||||
VLOG(2) << "Recv " << parsed_key_.buf_;
|
||||
ctx->rendezvous()->RecvAsync(parsed_key_, args, std::move(done_cb));
|
||||
} else {
|
||||
Rendezvous::ParsedKey in_loop_parsed;
|
||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
|
||||
VLOG(2) << "Recv " << in_loop_parsed.buf_;
|
||||
OP_REQUIRES_OK_ASYNC(
|
||||
ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed), done);
|
||||
|
||||
ctx->rendezvous()->RecvAsync(in_loop_parsed, args, std::move(done_cb));
|
||||
}
|
||||
}
|
||||
|
||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ class SendOp : public OpKernel {
|
|||
|
||||
private:
|
||||
string key_prefix_;
|
||||
Rendezvous::ParsedKey parsed_key_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(SendOp);
|
||||
};
|
||||
|
|
@ -39,6 +40,7 @@ class RecvOp : public AsyncOpKernel {
|
|||
|
||||
private:
|
||||
string key_prefix_;
|
||||
Rendezvous::ParsedKey parsed_key_;
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(RecvOp);
|
||||
};
|
||||
|
|
|
|||
74
tensorflow/core/kernels/sendrecv_ops_test.cc
Normal file
74
tensorflow/core/kernels/sendrecv_ops_test.cc
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/test_benchmark.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
|
||||
// Implement a trivial version of the Rendezvous interface, to avoid
|
||||
// clouding the benchmark results with the time spent in the various
|
||||
// implementations, and to avoid the duplicate-send or duplicate-recv
|
||||
// errors that would arise from running either benchmark in a loop.
|
||||
// Rendezvous stub used by the benchmarks in this file: Send() discards its
// input, RecvAsync() completes immediately with a fixed tensor, and
// StartAbort() does nothing. No access specifier is given, so all three
// overrides are private members of DummyRendezvous.
class DummyRendezvous : public Rendezvous {
|
||||
// Ignores every argument (key, args, tensor, dead flag) and reports success.
Status Send(const ParsedKey& key, const Args& args, const Tensor& val,
|
||||
const bool is_dead) override {
|
||||
return Status::OK();
|
||||
}
|
||||
// Invokes `done` before returning, on the calling thread; `key` is unused.
void RecvAsync(const ParsedKey& key, const Args& args,
|
||||
DoneCallback done) override {
|
||||
// Function-local static: one DT_FLOAT tensor of shape {0} is allocated on
// the first call, reused by every later call, and never deleted.
static Tensor* t = new Tensor(DT_FLOAT, TensorShape({0}));
|
||||
// The caller's `args` is passed for both Args parameters of the callback.
// NOTE(review): the trailing `false` is presumably the is_dead flag --
// confirm against the Rendezvous::DoneCallback signature.
done(Status::OK(), args, args, *t, false);
|
||||
}
|
||||
// Intentionally empty: the abort status is ignored.
void StartAbort(const Status& status) override {}
|
||||
};
|
||||
|
||||
// Builds the graph used by BM_Send: a Constant node holding a DT_FLOAT tensor
// of shape {0}, a Send node fed by that constant, and a Recv node, with both
// the Send and the Recv using the tensor name "T".
// Returns a heap-allocated Graph; the caller is responsible for it.
static Graph* Send() {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
Tensor in0(DT_FLOAT, TensorShape({0}));
|
||||
// NOTE(review): the trailing arguments ("/cpu:0", 1, "/cpu:0") are presumably
// sender device, sender incarnation and receiver device -- confirm against
// the test::graph::Send declaration.
test::graph::Send(g, test::graph::Constant(g, in0), "T", "/cpu:0", 1,
|
||||
"/cpu:0");
|
||||
test::graph::Recv(g, "T", "float", "/cpu:0", 1, "/cpu:0");
|
||||
return g;
|
||||
}
|
||||
|
||||
// Builds the graph used by BM_Recv: a single Recv node for the tensor named
// "T" with type string "float".
// Returns a heap-allocated Graph; the caller is responsible for it.
static Graph* Recv() {
|
||||
Graph* g = new Graph(OpRegistry::Global());
|
||||
// NOTE(review): ("/cpu:0", 1, "/cpu:0") are presumably sender device, sender
// incarnation and receiver device -- confirm against test::graph::Recv.
test::graph::Recv(g, "T", "float", "/cpu:0", 1, "/cpu:0");
|
||||
return g;
|
||||
}
|
||||
|
||||
// Benchmark body: runs the graph built by Send() `iters` times on "cpu",
// using a DummyRendezvous instead of the default rendezvous.
static void BM_Send(int iters) {
|
||||
testing::UseRealTime();
|
||||
// Registers `iters` as the number of items processed by this run.
testing::ItemsProcessed(static_cast<int64>(iters));
|
||||
// Benchmark arguments: device, graph, options (null), init graph (null),
// rendezvous. Per the Benchmark constructor comment, it takes ownership of
// the graph and one reference on the rendezvous.
test::Benchmark("cpu", Send(), nullptr, nullptr, new DummyRendezvous)
|
||||
.Run(iters);
|
||||
}
|
||||
BENCHMARK(BM_Send);
|
||||
|
||||
// Benchmark body: runs the graph built by Recv() `iters` times on "cpu",
// using a DummyRendezvous instead of the default rendezvous.
static void BM_Recv(int iters) {
|
||||
testing::UseRealTime();
|
||||
// Registers `iters` as the number of items processed by this run.
testing::ItemsProcessed(static_cast<int64>(iters));
|
||||
// Benchmark arguments: device, graph, options (null), init graph (null),
// rendezvous. Per the Benchmark constructor comment, it takes ownership of
// the graph and one reference on the rendezvous.
test::Benchmark("cpu", Recv(), nullptr, nullptr, new DummyRendezvous)
|
||||
.Run(iters);
|
||||
}
|
||||
BENCHMARK(BM_Recv);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
Loading…
Reference in New Issue
Block a user