mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
Minor cleanup
PiperOrigin-RevId: 163685423
This commit is contained in:
parent
f9c758719e
commit
15e928d51e
|
|
@ -18,10 +18,7 @@
|
|||
|
||||
namespace tensorflow {
|
||||
|
||||
using shape_inference::DimensionHandle;
|
||||
using shape_inference::InferenceContext;
|
||||
using shape_inference::ShapeHandle;
|
||||
|
||||
|
||||
REGISTER_OP("ScatterAddNdim")
|
||||
.Input("input: Ref(float)")
|
||||
|
|
|
|||
|
|
@ -28,7 +28,6 @@ limitations under the License.
|
|||
#include "tensorflow/core/distributed_runtime/worker_interface.h"
|
||||
#include "tensorflow/core/framework/allocation_description.pb.h"
|
||||
#include "tensorflow/core/framework/cost_graph.pb.h"
|
||||
#include "tensorflow/core/framework/function.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/framework/node_def_util.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ limitations under the License.
|
|||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
TEST(RendezvousTest, Key) {
|
||||
const string key = Rendezvous::CreateKey(
|
||||
|
|
@ -64,22 +65,22 @@ TEST(RendezvousTest, Key) {
|
|||
|
||||
class LocalRendezvousTest : public ::testing::Test {
|
||||
public:
|
||||
LocalRendezvousTest()
|
||||
: threads_(new thread::ThreadPool(Env::Default(), "test", 16)) {
|
||||
LocalRendezvousTest() : threads_(Env::Default(), "test", 16) {
|
||||
rendez_ = NewLocalRendezvous();
|
||||
}
|
||||
|
||||
~LocalRendezvousTest() override {
|
||||
rendez_->Unref();
|
||||
delete threads_;
|
||||
}
|
||||
|
||||
void SchedClosure(std::function<void()> fn) { threads_->Schedule(fn); }
|
||||
void SchedClosure(std::function<void()> fn) {
|
||||
threads_.Schedule(std::move(fn));
|
||||
}
|
||||
|
||||
Rendezvous* rendez_;
|
||||
|
||||
private:
|
||||
thread::ThreadPool* threads_;
|
||||
thread::ThreadPool threads_;
|
||||
};
|
||||
|
||||
// string -> Tensor<string>
|
||||
|
|
@ -96,9 +97,6 @@ string V(const Tensor& tensor) {
|
|||
return tensor.scalar<string>()();
|
||||
}
|
||||
|
||||
const char* kFoo = "/cpu:0;1;/cpu:1;foo;1;2";
|
||||
const char* kBar = "/gpu:0;2;/gpu:1;bar;1;2";
|
||||
|
||||
Rendezvous::ParsedKey MakeKey(const string& name) {
|
||||
string s = Rendezvous::CreateKey("/job:mnist/replica:1/task:2/CPU:0", 7890,
|
||||
"/job:mnist/replica:1/task:2/GPU:0", name,
|
||||
|
|
@ -213,7 +211,7 @@ TEST_F(LocalRendezvousTest, RandomSendRecv) {
|
|||
state.done.WaitForNotification();
|
||||
}
|
||||
|
||||
static void RandomSleep() {
|
||||
void RandomSleep() {
|
||||
if (std::rand() % 10 == 0) {
|
||||
Env::Default()->SleepForMicroseconds(1000);
|
||||
}
|
||||
|
|
@ -310,7 +308,7 @@ TEST_F(LocalRendezvousTest, TransferDummyDeviceContext) {
|
|||
args1.device_context->Unref();
|
||||
}
|
||||
|
||||
static void BM_SendRecv(int iters) {
|
||||
void BM_SendRecv(int iters) {
|
||||
Rendezvous* rendez = NewLocalRendezvous();
|
||||
Tensor orig = V("val");
|
||||
Tensor val(DT_STRING, TensorShape({}));
|
||||
|
|
@ -328,7 +326,7 @@ static void BM_SendRecv(int iters) {
|
|||
}
|
||||
BENCHMARK(BM_SendRecv);
|
||||
|
||||
static void BM_PingPong(int iters) {
|
||||
void BM_PingPong(int iters) {
|
||||
CHECK_GT(iters, 0);
|
||||
thread::ThreadPool* pool = new thread::ThreadPool(Env::Default(), "test", 1);
|
||||
|
||||
|
|
@ -362,4 +360,5 @@ static void BM_PingPong(int iters) {
|
|||
}
|
||||
BENCHMARK(BM_PingPong);
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user