mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Avoid parsing a rendezvous key for Send/Recv ops outside a loop.
For such ops, the rendezvous key will be constant, because
`ctx->frame_iter()` will always evaluate to `{0, 0}`. Benchmarking
reveals that this can save between 1 and 2 microseconds per Send or
Recv op execution. The optimization applies to all cross-process,
inter-device, and intra-device (host-to/from-device memory) Send/Recv
ops.
PiperOrigin-RevId: 158032522
This commit is contained in:
parent
cc2dd4ac85
commit
9f932e6ce6
|
|
@ -39,7 +39,8 @@ namespace tensorflow {
|
||||||
namespace test {
|
namespace test {
|
||||||
|
|
||||||
Benchmark::Benchmark(const string& device, Graph* g,
|
Benchmark::Benchmark(const string& device, Graph* g,
|
||||||
const SessionOptions* options, Graph* init) {
|
const SessionOptions* options, Graph* init,
|
||||||
|
Rendezvous* rendez) {
|
||||||
SessionOptions default_options;
|
SessionOptions default_options;
|
||||||
if (!options) {
|
if (!options) {
|
||||||
options = &default_options;
|
options = &default_options;
|
||||||
|
|
@ -61,7 +62,11 @@ Benchmark::Benchmark(const string& device, Graph* g,
|
||||||
pool_->Schedule(closure);
|
pool_->Schedule(closure);
|
||||||
};
|
};
|
||||||
|
|
||||||
rendez_ = NewLocalRendezvous();
|
if (rendez == nullptr) {
|
||||||
|
rendez_ = NewLocalRendezvous();
|
||||||
|
} else {
|
||||||
|
rendez_ = rendez;
|
||||||
|
}
|
||||||
|
|
||||||
const int graph_def_version = g->versions().producer();
|
const int graph_def_version = g->versions().producer();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,10 +35,11 @@ namespace test {
|
||||||
|
|
||||||
class Benchmark {
|
class Benchmark {
|
||||||
public:
|
public:
|
||||||
// "device" must be either "cpu" or "gpu". Takes ownership of "g"
|
// "device" must be either "cpu" or "gpu". Takes ownership of "g",
|
||||||
// and "init".
|
// "init", and one reference on "rendez" (if not null).
|
||||||
Benchmark(const string& device, Graph* g,
|
Benchmark(const string& device, Graph* g,
|
||||||
const SessionOptions* options = nullptr, Graph* init = nullptr);
|
const SessionOptions* options = nullptr, Graph* init = nullptr,
|
||||||
|
Rendezvous* rendez = nullptr);
|
||||||
~Benchmark();
|
~Benchmark();
|
||||||
|
|
||||||
// Executes the graph for "iters" times.
|
// Executes the graph for "iters" times.
|
||||||
|
|
|
||||||
|
|
@ -3159,6 +3159,21 @@ tf_kernel_library(
|
||||||
deps = REQUIRED_DEPS,
|
deps = REQUIRED_DEPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Test target built from sendrecv_ops_test.cc. It links against the
# :sendrecv_ops kernel library plus the shared kernel-test utilities, and sets
# linkstatic through tf_kernel_tests_linkstatic() because the benchmarks in
# the test file require it.
tf_cc_test(
|
||||||
|
name = "sendrecv_ops_test",
|
||||||
|
srcs = ["sendrecv_ops_test.cc"],
|
||||||
|
linkstatic = tf_kernel_tests_linkstatic(), # Required for benchmarking
|
||||||
|
deps = [
|
||||||
|
":ops_testutil",
|
||||||
|
":ops_util",
|
||||||
|
":sendrecv_ops",
|
||||||
|
"//tensorflow/core:framework",
|
||||||
|
"//tensorflow/core:test",
|
||||||
|
"//tensorflow/core:test_main",
|
||||||
|
"//tensorflow/core:testlib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "sparse",
|
name = "sparse",
|
||||||
deps = [
|
deps = [
|
||||||
|
|
|
||||||
|
|
@ -52,17 +52,16 @@ SendOp::SendOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
|
||||||
key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
|
key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
|
||||||
send_device_incarnation, tensor_name);
|
send_device_incarnation, tensor_name);
|
||||||
|
// The vast majority of Send nodes are outside any loop context, so
|
||||||
|
// proactively cache the rendezvous key for the top-level.
|
||||||
|
GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
|
||||||
|
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void SendOp::Compute(OpKernelContext* ctx) {
|
void SendOp::Compute(OpKernelContext* ctx) {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, ctx->rendezvous() != nullptr,
|
ctx, ctx->rendezvous() != nullptr,
|
||||||
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
||||||
Rendezvous::ParsedKey parsed;
|
|
||||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &parsed.buf_);
|
|
||||||
VLOG(2) << "Send " << parsed.buf_;
|
|
||||||
|
|
||||||
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed.buf_, &parsed));
|
|
||||||
|
|
||||||
// The device context may be passed between the Send/Recv
|
// The device context may be passed between the Send/Recv
|
||||||
// boundary, so that the device context used to produce the Tensor
|
// boundary, so that the device context used to produce the Tensor
|
||||||
|
|
@ -71,8 +70,24 @@ void SendOp::Compute(OpKernelContext* ctx) {
|
||||||
Rendezvous::Args args;
|
Rendezvous::Args args;
|
||||||
args.device_context = ctx->op_device_context();
|
args.device_context = ctx->op_device_context();
|
||||||
args.alloc_attrs = ctx->input_alloc_attr(0);
|
args.alloc_attrs = ctx->input_alloc_attr(0);
|
||||||
OP_REQUIRES_OK(ctx, ctx->rendezvous()->Send(parsed, args, ctx->input(0),
|
|
||||||
ctx->is_input_dead()));
|
if (ctx->frame_iter() == FrameAndIter(0, 0)) {
|
||||||
|
// Use the cached rendezvous key.
|
||||||
|
VLOG(2) << "Send " << parsed_key_.buf_;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
ctx->rendezvous()->Send(parsed_key_, args, ctx->input(0),
|
||||||
|
ctx->is_input_dead()));
|
||||||
|
} else {
|
||||||
|
Rendezvous::ParsedKey in_loop_parsed;
|
||||||
|
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
|
||||||
|
VLOG(2) << "Send " << in_loop_parsed.buf_;
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed));
|
||||||
|
|
||||||
|
OP_REQUIRES_OK(ctx,
|
||||||
|
ctx->rendezvous()->Send(in_loop_parsed, args, ctx->input(0),
|
||||||
|
ctx->is_input_dead()));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
|
REGISTER_KERNEL_BUILDER(Name("_Send").Device(DEVICE_CPU), SendOp);
|
||||||
|
|
@ -101,17 +116,16 @@ RecvOp::RecvOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
|
||||||
OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
|
OP_REQUIRES_OK(ctx, ctx->GetAttr("tensor_name", &tensor_name));
|
||||||
key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
|
key_prefix_ = GetRendezvousKeyPrefix(send_device, recv_device,
|
||||||
send_device_incarnation, tensor_name);
|
send_device_incarnation, tensor_name);
|
||||||
|
// The vast majority of Recv nodes are outside any loop context, so
|
||||||
|
// proactively cache the rendezvous key for the top-level.
|
||||||
|
GetRendezvousKey(key_prefix_, {0, 0}, &parsed_key_.buf_);
|
||||||
|
OP_REQUIRES_OK(ctx, Rendezvous::ParseKey(parsed_key_.buf_, &parsed_key_));
|
||||||
}
|
}
|
||||||
|
|
||||||
void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||||
OP_REQUIRES(
|
OP_REQUIRES(
|
||||||
ctx, ctx->rendezvous() != nullptr,
|
ctx, ctx->rendezvous() != nullptr,
|
||||||
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
errors::Internal("Op kernel context needs to provide a rendezvous."));
|
||||||
Rendezvous::ParsedKey parsed;
|
|
||||||
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &parsed.buf_);
|
|
||||||
VLOG(2) << "Recv " << parsed.buf_;
|
|
||||||
|
|
||||||
OP_REQUIRES_OK_ASYNC(ctx, Rendezvous::ParseKey(parsed.buf_, &parsed), done);
|
|
||||||
|
|
||||||
Rendezvous::Args args;
|
Rendezvous::Args args;
|
||||||
args.device_context = ctx->op_device_context();
|
args.device_context = ctx->op_device_context();
|
||||||
|
|
@ -136,7 +150,19 @@ void RecvOp::ComputeAsync(OpKernelContext* ctx, DoneCallback done) {
|
||||||
done();
|
done();
|
||||||
},
|
},
|
||||||
std::move(done), _1, _2, _3, _4, _5);
|
std::move(done), _1, _2, _3, _4, _5);
|
||||||
ctx->rendezvous()->RecvAsync(parsed, args, std::move(done_cb));
|
|
||||||
|
if (ctx->frame_iter() == FrameAndIter(0, 0)) {
|
||||||
|
VLOG(2) << "Recv " << parsed_key_.buf_;
|
||||||
|
ctx->rendezvous()->RecvAsync(parsed_key_, args, std::move(done_cb));
|
||||||
|
} else {
|
||||||
|
Rendezvous::ParsedKey in_loop_parsed;
|
||||||
|
GetRendezvousKey(key_prefix_, ctx->frame_iter(), &in_loop_parsed.buf_);
|
||||||
|
VLOG(2) << "Recv " << in_loop_parsed.buf_;
|
||||||
|
OP_REQUIRES_OK_ASYNC(
|
||||||
|
ctx, Rendezvous::ParseKey(in_loop_parsed.buf_, &in_loop_parsed), done);
|
||||||
|
|
||||||
|
ctx->rendezvous()->RecvAsync(in_loop_parsed, args, std::move(done_cb));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
|
REGISTER_KERNEL_BUILDER(Name("_Recv").Device(DEVICE_CPU), RecvOp);
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ class SendOp : public OpKernel {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string key_prefix_;
|
string key_prefix_;
|
||||||
|
Rendezvous::ParsedKey parsed_key_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(SendOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(SendOp);
|
||||||
};
|
};
|
||||||
|
|
@ -39,6 +40,7 @@ class RecvOp : public AsyncOpKernel {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
string key_prefix_;
|
string key_prefix_;
|
||||||
|
Rendezvous::ParsedKey parsed_key_;
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(RecvOp);
|
TF_DISALLOW_COPY_AND_ASSIGN(RecvOp);
|
||||||
};
|
};
|
||||||
|
|
|
||||||
74
tensorflow/core/kernels/sendrecv_ops_test.cc
Normal file
74
tensorflow/core/kernels/sendrecv_ops_test.cc
Normal file
|
|
@ -0,0 +1,74 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
|
||||||
|
#include "tensorflow/core/framework/tensor.h"
|
||||||
|
#include "tensorflow/core/platform/test.h"
|
||||||
|
#include "tensorflow/core/platform/test_benchmark.h"
|
||||||
|
|
||||||
|
namespace tensorflow {
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
// Implement a trivial version of the Rendezvous interface, to avoid
|
||||||
|
// clouding the benchmark results with the time spent in the various
|
||||||
|
// implementations, and to avoid the duplicate-send or duplicate-recv
|
||||||
|
// errors that would arise from running either benchmark in a loop.
|
||||||
|
class DummyRendezvous : public Rendezvous {
|
||||||
|
Status Send(const ParsedKey& key, const Args& args, const Tensor& val,
|
||||||
|
const bool is_dead) override {
|
||||||
|
return Status::OK();
|
||||||
|
}
|
||||||
|
void RecvAsync(const ParsedKey& key, const Args& args,
|
||||||
|
DoneCallback done) override {
|
||||||
|
static Tensor* t = new Tensor(DT_FLOAT, TensorShape({0}));
|
||||||
|
done(Status::OK(), args, args, *t, false);
|
||||||
|
}
|
||||||
|
void StartAbort(const Status& status) override {}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Builds the graph run by BM_Send: a zero-element float constant fed into
// test::graph::Send under the tensor name "T", plus a test::graph::Recv for
// the same name. Both ends use "/cpu:0". The returned graph is heap-allocated;
// BM_Send passes it straight to test::Benchmark.
static Graph* Send() {
  Graph* const graph = new Graph(OpRegistry::Global());
  const Tensor empty_input(DT_FLOAT, TensorShape({0}));
  auto source = test::graph::Constant(graph, empty_input);
  test::graph::Send(graph, source, "T", "/cpu:0", 1, "/cpu:0");
  test::graph::Recv(graph, "T", "float", "/cpu:0", 1, "/cpu:0");
  return graph;
}
|
||||||
|
|
||||||
|
// Builds the graph run by BM_Recv: a lone test::graph::Recv of the float
// tensor named "T", with "/cpu:0" given for both device arguments. The
// returned graph is heap-allocated; BM_Recv passes it straight to
// test::Benchmark.
static Graph* Recv() {
  Graph* const graph = new Graph(OpRegistry::Global());
  test::graph::Recv(graph, "T", "float", "/cpu:0", 1, "/cpu:0");
  return graph;
}
|
||||||
|
|
||||||
|
static void BM_Send(int iters) {
|
||||||
|
testing::UseRealTime();
|
||||||
|
testing::ItemsProcessed(static_cast<int64>(iters));
|
||||||
|
test::Benchmark("cpu", Send(), nullptr, nullptr, new DummyRendezvous)
|
||||||
|
.Run(iters);
|
||||||
|
}
|
||||||
|
BENCHMARK(BM_Send);
|
||||||
|
|
||||||
|
static void BM_Recv(int iters) {
|
||||||
|
testing::UseRealTime();
|
||||||
|
testing::ItemsProcessed(static_cast<int64>(iters));
|
||||||
|
test::Benchmark("cpu", Recv(), nullptr, nullptr, new DummyRendezvous)
|
||||||
|
.Run(iters);
|
||||||
|
}
|
||||||
|
BENCHMARK(BM_Recv);
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
} // namespace tensorflow
|
||||||
Loading…
Reference in New Issue
Block a user