Add WriteEvent method to SummaryWriterInterface

Another change will follow that adds an op for this method. It will be useful
for loading event logs into other types of summary writer implementations, like
a database.

This change might also make the new summary file writer go faster, due to less
memory copying.

PiperOrigin-RevId: 173640116
This commit is contained in:
Justine Tunney 2017-10-27 00:15:55 -07:00 committed by TensorFlower Gardener
parent a494558127
commit 9c8a520b07
3 changed files with 69 additions and 38 deletions

View File

@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/kernels/summary_interface.h"
#include <utility>
#include "tensorflow/compiler/xla/ptr_util.h" #include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/op_kernel.h"
@ -19,12 +22,10 @@ limitations under the License.
#include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/kernels/summary_interface.h"
#include "tensorflow/core/lib/histogram/histogram.h" #include "tensorflow/core/lib/histogram/histogram.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/png/png_io.h" #include "tensorflow/core/lib/png/png_io.h"
#include "tensorflow/core/lib/wav/wav_io.h" #include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/util/event.pb.h"
#include "tensorflow/core/util/events_writer.h" #include "tensorflow/core/util/events_writer.h"
namespace tensorflow { namespace tensorflow {
@ -250,28 +251,34 @@ class SummaryWriterImpl : public SummaryWriterInterface {
Status WriteTensor(int64 global_step, Tensor t, const string& tag, Status WriteTensor(int64 global_step, Tensor t, const string& tag,
const string& serialized_metadata) override { const string& serialized_metadata) override {
Summary s; std::unique_ptr<Event> e{new Event};
Summary::Value* v = s.add_value(); e->set_step(global_step);
e->set_wall_time(GetWallTime());
Summary::Value* v = e->mutable_summary()->add_value();
t.AsProtoTensorContent(v->mutable_tensor()); t.AsProtoTensorContent(v->mutable_tensor());
v->set_tag(tag); v->set_tag(tag);
v->mutable_metadata()->ParseFromString(serialized_metadata); v->mutable_metadata()->ParseFromString(serialized_metadata);
return Enqueue(global_step, s); return WriteEvent(std::move(e));
} }
Status WriteScalar(int64 global_step, Tensor t, const string& tag) override { Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
Summary s; std::unique_ptr<Event> e{new Event};
Summary::Value* v = s.add_value(); e->set_step(global_step);
e->set_wall_time(GetWallTime());
Summary::Value* v = e->mutable_summary()->add_value();
v->set_tag(tag); v->set_tag(tag);
float value; float value;
TF_RETURN_IF_ERROR(TensorValueAt<float>(t, 0, &value)); TF_RETURN_IF_ERROR(TensorValueAt<float>(t, 0, &value));
v->set_simple_value(value); v->set_simple_value(value);
return Enqueue(global_step, s); return WriteEvent(std::move(e));
} }
Status WriteHistogram(int64 global_step, Tensor t, Status WriteHistogram(int64 global_step, Tensor t,
const string& tag) override { const string& tag) override {
Summary s; std::unique_ptr<Event> e{new Event};
Summary::Value* v = s.add_value(); e->set_step(global_step);
e->set_wall_time(GetWallTime());
Summary::Value* v = e->mutable_summary()->add_value();
v->set_tag(tag); v->set_tag(tag);
histogram::Histogram histo; histogram::Histogram histo;
for (int64 i = 0; i < t.NumElements(); i++) { for (int64 i = 0; i < t.NumElements(); i++) {
@ -287,7 +294,7 @@ class SummaryWriterImpl : public SummaryWriterInterface {
} }
histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */); histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
return Enqueue(global_step, s); return WriteEvent(std::move(e));
} }
Status WriteImage(int64 global_step, Tensor tensor, const string& tag, Status WriteImage(int64 global_step, Tensor tensor, const string& tag,
@ -306,7 +313,10 @@ class SummaryWriterImpl : public SummaryWriterInterface {
return errors::InvalidArgument("Tensor too large for summary ", return errors::InvalidArgument("Tensor too large for summary ",
tensor.shape().DebugString()); tensor.shape().DebugString());
} }
Summary s; std::unique_ptr<Event> e{new Event};
e->set_step(global_step);
e->set_wall_time(GetWallTime());
Summary* s = e->mutable_summary();
// The casts and h * w cannot overflow because of the limits above. // The casts and h * w cannot overflow because of the limits above.
const int batch_size = static_cast<int>(tensor.dim_size(0)); const int batch_size = static_cast<int>(tensor.dim_size(0));
const int h = static_cast<int>(tensor.dim_size(1)); const int h = static_cast<int>(tensor.dim_size(1));
@ -321,20 +331,20 @@ class SummaryWriterImpl : public SummaryWriterInterface {
&values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth)); &values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
}; };
TF_RETURN_IF_ERROR( TF_RETURN_IF_ERROR(
AddImages(tag, max_images, batch_size, w, h, depth, ith_image, &s)); AddImages(tag, max_images, batch_size, w, h, depth, ith_image, s));
} else if (tensor.dtype() == DT_HALF) { } else if (tensor.dtype() == DT_HALF) {
TF_RETURN_IF_ERROR(NormalizeAndAddImages<Eigen::half>( TF_RETURN_IF_ERROR(NormalizeAndAddImages<Eigen::half>(
tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, &s)); tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s));
} else if (tensor.dtype() == DT_FLOAT) { } else if (tensor.dtype() == DT_FLOAT) {
TF_RETURN_IF_ERROR(NormalizeAndAddImages<float>( TF_RETURN_IF_ERROR(NormalizeAndAddImages<float>(
tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, &s)); tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s));
} else { } else {
return errors::InvalidArgument( return errors::InvalidArgument(
"Only DT_INT8, DT_HALF, and DT_FLOAT images are supported. Got ", "Only DT_INT8, DT_HALF, and DT_FLOAT images are supported. Got ",
DataTypeString(tensor.dtype())); DataTypeString(tensor.dtype()));
} }
return Enqueue(global_step, s); return WriteEvent(std::move(e));
} }
Status WriteAudio(int64 global_step, Tensor tensor, const string& tag, Status WriteAudio(int64 global_step, Tensor tensor, const string& tag,
@ -346,10 +356,13 @@ class SummaryWriterImpl : public SummaryWriterInterface {
const int64 length_frames = tensor.dim_size(1); const int64 length_frames = tensor.dim_size(1);
const int64 num_channels = const int64 num_channels =
tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1); tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1);
Summary s; std::unique_ptr<Event> e{new Event};
e->set_step(global_step);
e->set_wall_time(GetWallTime());
Summary* s = e->mutable_summary();
const int N = std::min<int>(max_outputs, batch_size); const int N = std::min<int>(max_outputs, batch_size);
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
Summary::Value* v = s.add_value(); Summary::Value* v = s->add_value();
if (max_outputs > 1) { if (max_outputs > 1) {
v->set_tag(strings::StrCat(tag, "/audio/", i)); v->set_tag(strings::StrCat(tag, "/audio/", i));
} else { } else {
@ -375,16 +388,12 @@ class SummaryWriterImpl : public SummaryWriterInterface {
channels_by_frames.data(), sample_rate_truncated, num_channels, channels_by_frames.data(), sample_rate_truncated, num_channels,
length_frames, sa->mutable_encoded_audio_string())); length_frames, sa->mutable_encoded_audio_string()));
} }
return WriteEvent(std::move(e));
return Enqueue(global_step, s);
} }
string DebugString() override { return "SummaryWriterImpl"; } Status WriteEvent(std::unique_ptr<Event> event) override {
private:
Status Enqueue(int64 global_step, const Summary& summary) {
mutex_lock ml(mu_); mutex_lock ml(mu_);
queue_.emplace_back(global_step, summary, env_->NowMicros()); queue_.emplace_back(std::move(event));
if (queue_.size() >= max_queue_ || if (queue_.size() >= max_queue_ ||
env_->NowMicros() - last_flush_ > 1000 * flush_millis_) { env_->NowMicros() - last_flush_ > 1000 * flush_millis_) {
return InternalFlush(); return InternalFlush();
@ -392,13 +401,16 @@ class SummaryWriterImpl : public SummaryWriterInterface {
return Status::OK(); return Status::OK();
} }
string DebugString() override { return "SummaryWriterImpl"; }
private:
double GetWallTime() {
return static_cast<double>(env_->NowMicros()) / 1.0e6;
}
Status InternalFlush() EXCLUSIVE_LOCKS_REQUIRED(mu_) { Status InternalFlush() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
for (const EventInfo& e : queue_) { for (const std::unique_ptr<Event>& e : queue_) {
Event event; events_writer_->WriteEvent(*e);
event.set_step(std::get<0>(e));
*event.mutable_summary() = std::get<1>(e);
event.set_wall_time(static_cast<double>(std::get<2>(e)) / 1.0e6);
events_writer_->WriteEvent(event);
} }
queue_.clear(); queue_.clear();
if (!events_writer_->Flush()) { if (!events_writer_->Flush()) {
@ -413,9 +425,8 @@ class SummaryWriterImpl : public SummaryWriterInterface {
const int flush_millis_; const int flush_millis_;
uint64 last_flush_; uint64 last_flush_;
Env* env_; Env* env_;
using EventInfo = std::tuple<int64, Summary, int64>;
mutex mu_; mutex mu_;
std::vector<EventInfo> queue_ GUARDED_BY(mu_); std::vector<std::unique_ptr<Event>> queue_ GUARDED_BY(mu_);
// A pointer to allow deferred construction. // A pointer to allow deferred construction.
std::unique_ptr<EventsWriter> events_writer_ GUARDED_BY(mu_); std::unique_ptr<EventsWriter> events_writer_ GUARDED_BY(mu_);
std::vector<std::pair<string, SummaryMetadata>> registered_summaries_ std::vector<std::pair<string, SummaryMetadata>> registered_summaries_

View File

@ -15,8 +15,10 @@ limitations under the License.
#ifndef TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ #ifndef TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_
#define TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_ #define TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_
#include <memory>
#include "tensorflow/core/framework/resource_mgr.h" #include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/util/event.pb.h"
namespace tensorflow { namespace tensorflow {
@ -43,6 +45,8 @@ class SummaryWriterInterface : public ResourceBase {
virtual Status WriteAudio(int64 global_step, Tensor t, const string& tag, virtual Status WriteAudio(int64 global_step, Tensor t, const string& tag,
int max_outputs_, float sample_rate) = 0; int max_outputs_, float sample_rate) = 0;
virtual Status WriteEvent(std::unique_ptr<Event> e) = 0;
}; };
// Creates a SummaryWriterInterface instance which writes to a file. It will // Creates a SummaryWriterInterface instance which writes to a file. It will

View File

@ -12,11 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
#include "tensorflow/core/kernels/summary_interface.h"
#include <vector>
#include "tensorflow/core/framework/summary.pb.h" #include "tensorflow/core/framework/summary.pb.h"
#include "tensorflow/core/kernels/summary_interface.h"
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/io/path.h"
@ -43,8 +41,8 @@ class SummaryInterfaceTest : public ::testing::Test {
protected: protected:
Status SummaryTestHelper( Status SummaryTestHelper(
const string& test_name, const string& test_name,
std::function<Status(SummaryWriterInterface*)> writer_fn, const std::function<Status(SummaryWriterInterface*)>& writer_fn,
std::function<void(const Event&)> test_fn) { const std::function<void(const Event&)>& test_fn) {
static std::set<string>* tests = new std::set<string>(); static std::set<string>* tests = new std::set<string>();
CHECK(tests->insert(test_name).second) << ": " << test_name; CHECK(tests->insert(test_name).second) << ": " << test_name;
@ -182,6 +180,24 @@ TEST_F(SummaryInterfaceTest, WriteAudio) {
})); }));
} }
TEST_F(SummaryInterfaceTest, WriteEvent) {
TF_CHECK_OK(
SummaryTestHelper("event_test",
[](SummaryWriterInterface* writer) {
std::unique_ptr<Event> e{new Event};
e->set_step(7);
e->mutable_summary()->add_value()->set_tag("hi");
TF_RETURN_IF_ERROR(writer->WriteEvent(std::move(e)));
TF_RETURN_IF_ERROR(writer->Flush());
return Status::OK();
},
[](const Event& e) {
EXPECT_EQ(e.step(), 7);
CHECK_EQ(e.summary().value_size(), 1);
EXPECT_EQ(e.summary().value(0).tag(), "hi");
}));
}
TEST_F(SummaryInterfaceTest, WallTime) { TEST_F(SummaryInterfaceTest, WallTime) {
env_.AdvanceByMillis(7023); env_.AdvanceByMillis(7023);
TF_CHECK_OK(SummaryTestHelper( TF_CHECK_OK(SummaryTestHelper(