mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Add WriteEvent method to SummaryWriterInterface
Another change will follow that adds an op for this method. It will be useful for loading event logs into other types of summary writer implementations, like a database. This change might also make the new summary file writer go faster, due to less memory copying. PiperOrigin-RevId: 173640116
This commit is contained in:
parent
a494558127
commit
9c8a520b07
|
|
@ -12,6 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/core/kernels/summary_interface.h"
|
||||
|
||||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
|
|
@ -19,12 +22,10 @@ limitations under the License.
|
|||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/kernels/summary_interface.h"
|
||||
#include "tensorflow/core/lib/histogram/histogram.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
#include "tensorflow/core/lib/png/png_io.h"
|
||||
#include "tensorflow/core/lib/wav/wav_io.h"
|
||||
#include "tensorflow/core/util/event.pb.h"
|
||||
#include "tensorflow/core/util/events_writer.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
|
@ -250,28 +251,34 @@ class SummaryWriterImpl : public SummaryWriterInterface {
|
|||
|
||||
Status WriteTensor(int64 global_step, Tensor t, const string& tag,
|
||||
const string& serialized_metadata) override {
|
||||
Summary s;
|
||||
Summary::Value* v = s.add_value();
|
||||
std::unique_ptr<Event> e{new Event};
|
||||
e->set_step(global_step);
|
||||
e->set_wall_time(GetWallTime());
|
||||
Summary::Value* v = e->mutable_summary()->add_value();
|
||||
t.AsProtoTensorContent(v->mutable_tensor());
|
||||
v->set_tag(tag);
|
||||
v->mutable_metadata()->ParseFromString(serialized_metadata);
|
||||
return Enqueue(global_step, s);
|
||||
return WriteEvent(std::move(e));
|
||||
}
|
||||
|
||||
Status WriteScalar(int64 global_step, Tensor t, const string& tag) override {
|
||||
Summary s;
|
||||
Summary::Value* v = s.add_value();
|
||||
std::unique_ptr<Event> e{new Event};
|
||||
e->set_step(global_step);
|
||||
e->set_wall_time(GetWallTime());
|
||||
Summary::Value* v = e->mutable_summary()->add_value();
|
||||
v->set_tag(tag);
|
||||
float value;
|
||||
TF_RETURN_IF_ERROR(TensorValueAt<float>(t, 0, &value));
|
||||
v->set_simple_value(value);
|
||||
return Enqueue(global_step, s);
|
||||
return WriteEvent(std::move(e));
|
||||
}
|
||||
|
||||
Status WriteHistogram(int64 global_step, Tensor t,
|
||||
const string& tag) override {
|
||||
Summary s;
|
||||
Summary::Value* v = s.add_value();
|
||||
std::unique_ptr<Event> e{new Event};
|
||||
e->set_step(global_step);
|
||||
e->set_wall_time(GetWallTime());
|
||||
Summary::Value* v = e->mutable_summary()->add_value();
|
||||
v->set_tag(tag);
|
||||
histogram::Histogram histo;
|
||||
for (int64 i = 0; i < t.NumElements(); i++) {
|
||||
|
|
@ -287,7 +294,7 @@ class SummaryWriterImpl : public SummaryWriterInterface {
|
|||
}
|
||||
|
||||
histo.EncodeToProto(v->mutable_histo(), false /* Drop zero buckets */);
|
||||
return Enqueue(global_step, s);
|
||||
return WriteEvent(std::move(e));
|
||||
}
|
||||
|
||||
Status WriteImage(int64 global_step, Tensor tensor, const string& tag,
|
||||
|
|
@ -306,7 +313,10 @@ class SummaryWriterImpl : public SummaryWriterInterface {
|
|||
return errors::InvalidArgument("Tensor too large for summary ",
|
||||
tensor.shape().DebugString());
|
||||
}
|
||||
Summary s;
|
||||
std::unique_ptr<Event> e{new Event};
|
||||
e->set_step(global_step);
|
||||
e->set_wall_time(GetWallTime());
|
||||
Summary* s = e->mutable_summary();
|
||||
// The casts and h * w cannot overflow because of the limits above.
|
||||
const int batch_size = static_cast<int>(tensor.dim_size(0));
|
||||
const int h = static_cast<int>(tensor.dim_size(1));
|
||||
|
|
@ -321,20 +331,20 @@ class SummaryWriterImpl : public SummaryWriterInterface {
|
|||
&values(i, 0, 0), Eigen::DSizes<Eigen::DenseIndex, 2>(hw, depth));
|
||||
};
|
||||
TF_RETURN_IF_ERROR(
|
||||
AddImages(tag, max_images, batch_size, w, h, depth, ith_image, &s));
|
||||
AddImages(tag, max_images, batch_size, w, h, depth, ith_image, s));
|
||||
} else if (tensor.dtype() == DT_HALF) {
|
||||
TF_RETURN_IF_ERROR(NormalizeAndAddImages<Eigen::half>(
|
||||
tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, &s));
|
||||
tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s));
|
||||
} else if (tensor.dtype() == DT_FLOAT) {
|
||||
TF_RETURN_IF_ERROR(NormalizeAndAddImages<float>(
|
||||
tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, &s));
|
||||
tensor, max_images, h, w, hw, depth, batch_size, tag, bad_color, s));
|
||||
} else {
|
||||
return errors::InvalidArgument(
|
||||
"Only DT_INT8, DT_HALF, and DT_FLOAT images are supported. Got ",
|
||||
DataTypeString(tensor.dtype()));
|
||||
}
|
||||
|
||||
return Enqueue(global_step, s);
|
||||
return WriteEvent(std::move(e));
|
||||
}
|
||||
|
||||
Status WriteAudio(int64 global_step, Tensor tensor, const string& tag,
|
||||
|
|
@ -346,10 +356,13 @@ class SummaryWriterImpl : public SummaryWriterInterface {
|
|||
const int64 length_frames = tensor.dim_size(1);
|
||||
const int64 num_channels =
|
||||
tensor.dims() == 2 ? 1 : tensor.dim_size(tensor.dims() - 1);
|
||||
Summary s;
|
||||
std::unique_ptr<Event> e{new Event};
|
||||
e->set_step(global_step);
|
||||
e->set_wall_time(GetWallTime());
|
||||
Summary* s = e->mutable_summary();
|
||||
const int N = std::min<int>(max_outputs, batch_size);
|
||||
for (int i = 0; i < N; ++i) {
|
||||
Summary::Value* v = s.add_value();
|
||||
Summary::Value* v = s->add_value();
|
||||
if (max_outputs > 1) {
|
||||
v->set_tag(strings::StrCat(tag, "/audio/", i));
|
||||
} else {
|
||||
|
|
@ -375,16 +388,12 @@ class SummaryWriterImpl : public SummaryWriterInterface {
|
|||
channels_by_frames.data(), sample_rate_truncated, num_channels,
|
||||
length_frames, sa->mutable_encoded_audio_string()));
|
||||
}
|
||||
|
||||
return Enqueue(global_step, s);
|
||||
return WriteEvent(std::move(e));
|
||||
}
|
||||
|
||||
string DebugString() override { return "SummaryWriterImpl"; }
|
||||
|
||||
private:
|
||||
Status Enqueue(int64 global_step, const Summary& summary) {
|
||||
Status WriteEvent(std::unique_ptr<Event> event) override {
|
||||
mutex_lock ml(mu_);
|
||||
queue_.emplace_back(global_step, summary, env_->NowMicros());
|
||||
queue_.emplace_back(std::move(event));
|
||||
if (queue_.size() >= max_queue_ ||
|
||||
env_->NowMicros() - last_flush_ > 1000 * flush_millis_) {
|
||||
return InternalFlush();
|
||||
|
|
@ -392,13 +401,16 @@ class SummaryWriterImpl : public SummaryWriterInterface {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
string DebugString() override { return "SummaryWriterImpl"; }
|
||||
|
||||
private:
|
||||
double GetWallTime() {
|
||||
return static_cast<double>(env_->NowMicros()) / 1.0e6;
|
||||
}
|
||||
|
||||
Status InternalFlush() EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
for (const EventInfo& e : queue_) {
|
||||
Event event;
|
||||
event.set_step(std::get<0>(e));
|
||||
*event.mutable_summary() = std::get<1>(e);
|
||||
event.set_wall_time(static_cast<double>(std::get<2>(e)) / 1.0e6);
|
||||
events_writer_->WriteEvent(event);
|
||||
for (const std::unique_ptr<Event>& e : queue_) {
|
||||
events_writer_->WriteEvent(*e);
|
||||
}
|
||||
queue_.clear();
|
||||
if (!events_writer_->Flush()) {
|
||||
|
|
@ -413,9 +425,8 @@ class SummaryWriterImpl : public SummaryWriterInterface {
|
|||
const int flush_millis_;
|
||||
uint64 last_flush_;
|
||||
Env* env_;
|
||||
using EventInfo = std::tuple<int64, Summary, int64>;
|
||||
mutex mu_;
|
||||
std::vector<EventInfo> queue_ GUARDED_BY(mu_);
|
||||
std::vector<std::unique_ptr<Event>> queue_ GUARDED_BY(mu_);
|
||||
// A pointer to allow deferred construction.
|
||||
std::unique_ptr<EventsWriter> events_writer_ GUARDED_BY(mu_);
|
||||
std::vector<std::pair<string, SummaryMetadata>> registered_summaries_
|
||||
|
|
|
|||
|
|
@ -15,8 +15,10 @@ limitations under the License.
|
|||
#ifndef TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_
|
||||
#define TENSORFLOW_CORE_KERNELS_SUMMARY_INTERFACE_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/core/framework/resource_mgr.h"
|
||||
#include "tensorflow/core/util/event.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
|
|
@ -43,6 +45,8 @@ class SummaryWriterInterface : public ResourceBase {
|
|||
|
||||
virtual Status WriteAudio(int64 global_step, Tensor t, const string& tag,
|
||||
int max_outputs_, float sample_rate) = 0;
|
||||
|
||||
virtual Status WriteEvent(std::unique_ptr<Event> e) = 0;
|
||||
};
|
||||
|
||||
// Creates a SummaryWriterInterface instance which writes to a file. It will
|
||||
|
|
|
|||
|
|
@ -12,11 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <vector>
|
||||
#include "tensorflow/core/kernels/summary_interface.h"
|
||||
|
||||
#include "tensorflow/core/framework/summary.pb.h"
|
||||
#include "tensorflow/core/kernels/summary_interface.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "tensorflow/core/lib/core/refcount.h"
|
||||
#include "tensorflow/core/lib/io/path.h"
|
||||
|
|
@ -43,8 +41,8 @@ class SummaryInterfaceTest : public ::testing::Test {
|
|||
protected:
|
||||
Status SummaryTestHelper(
|
||||
const string& test_name,
|
||||
std::function<Status(SummaryWriterInterface*)> writer_fn,
|
||||
std::function<void(const Event&)> test_fn) {
|
||||
const std::function<Status(SummaryWriterInterface*)>& writer_fn,
|
||||
const std::function<void(const Event&)>& test_fn) {
|
||||
static std::set<string>* tests = new std::set<string>();
|
||||
CHECK(tests->insert(test_name).second) << ": " << test_name;
|
||||
|
||||
|
|
@ -182,6 +180,24 @@ TEST_F(SummaryInterfaceTest, WriteAudio) {
|
|||
}));
|
||||
}
|
||||
|
||||
TEST_F(SummaryInterfaceTest, WriteEvent) {
|
||||
TF_CHECK_OK(
|
||||
SummaryTestHelper("event_test",
|
||||
[](SummaryWriterInterface* writer) {
|
||||
std::unique_ptr<Event> e{new Event};
|
||||
e->set_step(7);
|
||||
e->mutable_summary()->add_value()->set_tag("hi");
|
||||
TF_RETURN_IF_ERROR(writer->WriteEvent(std::move(e)));
|
||||
TF_RETURN_IF_ERROR(writer->Flush());
|
||||
return Status::OK();
|
||||
},
|
||||
[](const Event& e) {
|
||||
EXPECT_EQ(e.step(), 7);
|
||||
CHECK_EQ(e.summary().value_size(), 1);
|
||||
EXPECT_EQ(e.summary().value(0).tag(), "hi");
|
||||
}));
|
||||
}
|
||||
|
||||
TEST_F(SummaryInterfaceTest, WallTime) {
|
||||
env_.AdvanceByMillis(7023);
|
||||
TF_CHECK_OK(SummaryTestHelper(
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user