#tf-data-service register_dataset accepts a dataset_id.

If provided, the tf.data service uses the given ID for the dataset.
If the dataset ID already exists, no new dataset will be registered.
This is useful when multiple training jobs need to share the same
dataset for training: in that case, each job should call
`register_dataset` with the same `dataset_id`.
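
For example, a minimal usage sketch based on the Python tests in this
change (`dispatcher_address` is a placeholder for a running dispatcher's
target, e.g. "grpc://localhost:5000"):

    import tensorflow as tf

    dataset = tf.data.Dataset.range(10)
    # Each trainer passes the same explicit ID, so the dataset is
    # registered only once; later calls return the existing registration.
    dataset_id = tf.data.experimental.service.register_dataset(
        service=dispatcher_address,
        dataset=dataset,
        dataset_id="shared_dataset")
    ds = tf.data.experimental.service.from_dataset_id(
        processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
        service=dispatcher_address,
        dataset_id=dataset_id,
        element_spec=dataset.element_spec)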

PiperOrigin-RevId: 462690057
Author: Yang Chen, 2022-07-22 13:21:18 -07:00, committed by TensorFlower Gardener
parent 2347ee9015
commit 76aeaa7b99
26 changed files with 474 additions and 78 deletions

View File

@@ -95,6 +95,12 @@
same dataset. See
https://www.tensorflow.org/api_docs/python/tf/data/experimental/service#sharing_tfdata_service_with_concurrent_trainers
for more details.
* Added `dataset_id` to `tf.data.experimental.service.register_dataset`.
If provided, tf.data service will use the provided ID for the dataset.
If the dataset ID already exists, no new dataset will be registered.
This is useful if multiple training jobs need to use the same dataset
for training. In this case, users should call `register_dataset` with
the same `dataset_id`.
* Added a new field, `inject_prefetch`, to
`tf.data.experimental.OptimizationOptions`. If it is set to `True`,
tf.data will now automatically add a `prefetch` transformation to

View File

@@ -356,7 +356,6 @@ tf_cc_test(
":dispatcher_client",
":test_cluster",
":test_util",
"@com_google_absl//absl/types:optional",
"//tensorflow/core/platform:status_matchers",
"//tensorflow/core:framework",
"//tensorflow/core:protos_all_cc",
@@ -407,6 +406,7 @@ cc_library(
"//tensorflow/core/platform:mutex",
"//tensorflow/core/platform:path",
"//tensorflow/core/platform:status",
"//tensorflow/core/platform:statusor",
"//tensorflow/core/platform:strcat",
"//tensorflow/core/platform:thread_annotations",
] + tf_grpc_cc_dependencies(),

View File

@@ -70,12 +70,17 @@ message GetVersionResponse {
int64 version = 1;
}
// Next tag: 4
// Next tag: 5
message GetOrRegisterDatasetRequest {
// The dataset to register.
DatasetDef dataset = 1;
// Metadata related to tf.data service.
DataServiceMetadata metadata = 3;
oneof optional_dataset_id {
// If provided, tf.data service will register the dataset with the specified
// ID. Otherwise, it will generate a unique dataset ID.
string dataset_id = 4;
}
reserved 2;
}

View File

@@ -16,6 +16,7 @@ limitations under the License.
#include <limits>
#include <memory>
#include <optional>
#include <string>
#include <vector>
@@ -116,11 +117,15 @@ Status DataServiceDispatcherClient::GetSplit(int64_t iteration_id,
Status DataServiceDispatcherClient::RegisterDataset(
const DatasetDef& dataset, const DataServiceMetadata& metadata,
const std::optional<std::string>& requested_dataset_id,
std::string& dataset_id) {
TF_RETURN_IF_ERROR(EnsureInitialized());
GetOrRegisterDatasetRequest req;
*req.mutable_dataset() = dataset;
*req.mutable_metadata() = metadata;
if (requested_dataset_id.has_value()) {
req.set_dataset_id(*requested_dataset_id);
}
GetOrRegisterDatasetResponse resp;
grpc::ClientContext client_ctx;

View File

@@ -70,6 +70,7 @@ class DataServiceDispatcherClient : public DataServiceClientBase {
// dataset id in `dataset_id`.
Status RegisterDataset(const DatasetDef& dataset,
const DataServiceMetadata& metadata,
const std::optional<std::string>& requested_dataset_id,
std::string& dataset_id);
// If `job_name` is set, looks up a job matching `job_name`.

View File

@@ -19,7 +19,6 @@ limitations under the License.
#include <optional>
#include <string>
#include "absl/types/optional.h"
#include "tensorflow/core/data/service/common.pb.h"
#include "tensorflow/core/data/service/data_transfer.h"
#include "tensorflow/core/data/service/test_cluster.h"
@@ -38,6 +37,8 @@ namespace data {
namespace {
using ::tensorflow::data::testing::EqualsProto;
using ::tensorflow::data::testing::InfiniteDataset;
using ::tensorflow::data::testing::RangeDataset;
using ::tensorflow::testing::StatusIs;
using ::testing::AllOf;
using ::testing::HasSubstr;
@@ -54,11 +55,12 @@ class DispatcherClientTest : public ::testing::Test {
}
// Creates a dataset and returns the dataset ID.
StatusOr<std::string> RegisterDataset(const DataServiceMetadata& metadata) {
const auto dataset_def = testing::RangeDataset(10);
StatusOr<std::string> RegisterDataset(
const DatasetDef& dataset, const DataServiceMetadata& metadata,
const std::optional<std::string>& requested_dataset_id = std::nullopt) {
std::string dataset_id;
TF_RETURN_IF_ERROR(
dispatcher_client_->RegisterDataset(dataset_def, metadata, dataset_id));
TF_RETURN_IF_ERROR(dispatcher_client_->RegisterDataset(
dataset, metadata, requested_dataset_id, dataset_id));
return dataset_id;
}
@@ -70,9 +72,9 @@ TEST_F(DispatcherClientTest, GetDataServiceMetadata) {
DataServiceMetadata metadata;
metadata.set_element_spec("encoded_element_spec");
metadata.set_compression(DataServiceMetadata::COMPRESSION_SNAPPY);
metadata.set_cardinality(kInfiniteCardinality);
metadata.set_cardinality(10);
TF_ASSERT_OK_AND_ASSIGN(const std::string dataset_id,
RegisterDataset(metadata));
RegisterDataset(RangeDataset(10), metadata));
DataServiceMetadata result;
TF_ASSERT_OK(dispatcher_client_->GetDataServiceMetadata(dataset_id, result));
@@ -93,13 +95,54 @@ TEST_F(DispatcherClientTest, GetDataServiceConfig) {
EXPECT_EQ(config.deployment_mode(), DEPLOYMENT_MODE_COLOCATED);
}
TEST_F(DispatcherClientTest, RegisterDatasetWithExplicitId) {
DataServiceMetadata metadata;
metadata.set_element_spec("encoded_element_spec");
metadata.set_compression(DataServiceMetadata::COMPRESSION_SNAPPY);
metadata.set_cardinality(10);
TF_ASSERT_OK_AND_ASSIGN(
const std::string dataset_id1,
RegisterDataset(RangeDataset(10), metadata,
/*requested_dataset_id=*/"dataset_id"));
EXPECT_EQ(dataset_id1, "dataset_id");
// Registers a dataset with the same dataset ID.
TF_ASSERT_OK_AND_ASSIGN(
const std::string dataset_id2,
RegisterDataset(RangeDataset(10), metadata,
/*requested_dataset_id=*/"dataset_id"));
EXPECT_EQ(dataset_id1, dataset_id2);
}
TEST_F(DispatcherClientTest, DatasetsDoNotMatch) {
DataServiceMetadata metadata;
metadata.set_element_spec("encoded_element_spec");
metadata.set_compression(DataServiceMetadata::COMPRESSION_SNAPPY);
metadata.set_cardinality(10);
TF_ASSERT_OK_AND_ASSIGN(
const std::string dataset_id1,
RegisterDataset(RangeDataset(10), metadata,
/*requested_dataset_id=*/"dataset_id"));
EXPECT_EQ(dataset_id1, "dataset_id");
// Registers a dataset with the same dataset ID but different metadata.
metadata.set_cardinality(kInfiniteCardinality);
EXPECT_THAT(
RegisterDataset(InfiniteDataset(), metadata,
/*requested_dataset_id=*/"dataset_id"),
StatusIs(
error::INVALID_ARGUMENT,
HasSubstr(
"Datasets with the same ID should have the same structure")));
}
TEST_F(DispatcherClientTest, EnableCrossTrainerCache) {
DataServiceMetadata metadata;
metadata.set_element_spec("encoded_element_spec");
metadata.set_compression(DataServiceMetadata::COMPRESSION_SNAPPY);
metadata.set_cardinality(kInfiniteCardinality);
TF_ASSERT_OK_AND_ASSIGN(const std::string dataset_id,
RegisterDataset(metadata));
RegisterDataset(InfiniteDataset(), metadata));
ProcessingModeDef processing_mode;
processing_mode.set_sharding_policy(ProcessingModeDef::OFF);
@@ -126,9 +169,9 @@ TEST_F(DispatcherClientTest, CreateNamedJob) {
DataServiceMetadata metadata;
metadata.set_element_spec("encoded_element_spec");
metadata.set_compression(DataServiceMetadata::COMPRESSION_SNAPPY);
metadata.set_cardinality(kInfiniteCardinality);
metadata.set_cardinality(10);
TF_ASSERT_OK_AND_ASSIGN(const std::string dataset_id,
RegisterDataset(metadata));
RegisterDataset(RangeDataset(10), metadata));
ProcessingModeDef processing_mode;
processing_mode.set_sharding_policy(ProcessingModeDef::OFF);
@@ -136,14 +179,14 @@ TEST_F(DispatcherClientTest, CreateNamedJob) {
int64_t job_id_1 = -1;
TF_ASSERT_OK(dispatcher_client_->GetOrCreateJob(
dataset_id, processing_mode, job_name,
/*num_consumers=*/absl::nullopt,
/*num_consumers=*/std::nullopt,
/*use_cross_trainer_cache=*/true, TARGET_WORKERS_AUTO, job_id_1));
int64_t job_id_2 = -2;
// Creating the same job should succeed and receive the same job id.
TF_ASSERT_OK(dispatcher_client_->GetOrCreateJob(
dataset_id, processing_mode, job_name,
/*num_consumers=*/absl::nullopt,
/*num_consumers=*/std::nullopt,
/*use_cross_trainer_cache=*/true, TARGET_WORKERS_AUTO, job_id_2));
ASSERT_EQ(job_id_1, job_id_2);
}
@@ -152,9 +195,9 @@ TEST_F(DispatcherClientTest, NamedJobsDoNotMatch) {
DataServiceMetadata metadata;
metadata.set_element_spec("encoded_element_spec");
metadata.set_compression(DataServiceMetadata::COMPRESSION_SNAPPY);
metadata.set_cardinality(kInfiniteCardinality);
metadata.set_cardinality(10);
TF_ASSERT_OK_AND_ASSIGN(const std::string dataset_id,
RegisterDataset(metadata));
RegisterDataset(RangeDataset(10), metadata));
int64_t job_id = 0;
ProcessingModeDef processing_mode;
@@ -162,7 +205,7 @@ TEST_F(DispatcherClientTest, NamedJobsDoNotMatch) {
std::string job_name = "job";
TF_ASSERT_OK(dispatcher_client_->GetOrCreateJob(
dataset_id, processing_mode, job_name,
/*num_consumers=*/absl::nullopt,
/*num_consumers=*/std::nullopt,
/*use_cross_trainer_cache=*/false, TARGET_WORKERS_AUTO, job_id));
// Creating the same iteration with a different argument should fail.

View File

@@ -17,6 +17,7 @@ limitations under the License.
#include <algorithm>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>
@@ -57,6 +58,7 @@ limitations under the License.
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/random.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
@@ -483,36 +485,83 @@ Status DataServiceDispatcherImpl::GetOrRegisterDataset(
absl::StrCat("Registering dataset graph: ", graph->DebugString()));
mutex_lock l(mu_);
std::shared_ptr<const Dataset> dataset;
Status s = state_.DatasetFromFingerprint(fingerprint, dataset);
if (s.ok()) {
std::string dataset_id = dataset->dataset_id;
VLOG(3) << "Received duplicate RegisterDataset request with fingerprint "
<< fingerprint << ". Returning id " << dataset_id;
response->set_dataset_id(dataset_id);
return OkStatus();
} else if (!errors::IsNotFound(s)) {
return s;
TF_ASSIGN_OR_RETURN(std::optional<std::string> dataset_id,
FindDataset(*request, fingerprint));
if (dataset_id.has_value()) {
VLOG(3) << "RegisterDataset returns an existing dataset with ID = "
<< *dataset_id << ", fingerprint = " << fingerprint << ".";
response->set_dataset_id(*dataset_id);
return Status::OK();
}
std::string dataset_id;
std::string new_dataset_id;
TF_RETURN_IF_ERROR(RegisterDataset(fingerprint, dataset_def,
request->metadata(), dataset_id));
response->set_dataset_id(dataset_id);
VLOG(3) << "Registered new dataset with id " << dataset_id;
request->metadata(), request->dataset_id(),
new_dataset_id));
response->set_dataset_id(new_dataset_id);
VLOG(3) << "Registered new dataset with id " << new_dataset_id;
return OkStatus();
}
StatusOr<std::optional<std::string>> DataServiceDispatcherImpl::FindDataset(
const GetOrRegisterDatasetRequest& request, uint64 fingerprint)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
std::shared_ptr<const Dataset> existing_dataset;
Status status;
// TODO(b/236725000): Stop supporting fingerprint-based deduping. This becomes
// unreliable due to nondeterminism in the dataset graphdef generation. The
// users should provide a `dataset_id` to dedupe the dataset instead.
if (request.dataset_id().empty()) {
status = state_.DatasetFromFingerprint(fingerprint, existing_dataset);
} else {
status = state_.DatasetFromId(request.dataset_id(), existing_dataset);
}
if (errors::IsNotFound(status)) {
return std::optional<std::string>();
}
TF_RETURN_IF_ERROR(status);
if (!request.dataset_id().empty()) {
TF_RETURN_IF_ERROR(ValidateMatchingDataset(
request.dataset_id(), request.metadata(), existing_dataset->metadata));
}
return std::optional<std::string>(existing_dataset->dataset_id);
}
Status DataServiceDispatcherImpl::ValidateMatchingDataset(
const std::string& dataset_id, const DataServiceMetadata& new_metadata,
const DataServiceMetadata& old_metadata) TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
MessageDifferencer differ;
differ.set_message_field_comparison(MessageDifferencer::EQUIVALENT);
differ.set_repeated_field_comparison(MessageDifferencer::AS_SET);
std::string diff;
differ.ReportDifferencesToString(&diff);
bool equivalent = differ.Compare(new_metadata, old_metadata);
if (!equivalent) {
return errors::InvalidArgument(
"Datasets with the same ID should have the same structure, got ",
"diff for dataset ID ", dataset_id, ": ", diff, ". To fix this error, ",
"make sure you're registering the same dataset with the same ID.");
}
return Status::OK();
}
Status DataServiceDispatcherImpl::RegisterDataset(
uint64 fingerprint, const DatasetDef& dataset,
const DataServiceMetadata& metadata, std::string& dataset_id)
const DataServiceMetadata& metadata,
const std::string& requested_dataset_id, std::string& dataset_id)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
dataset_id = state_.NextAvailableDatasetId();
dataset_id = requested_dataset_id;
if (dataset_id.empty()) {
dataset_id = state_.NextAvailableDatasetId();
}
Update update;
RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
register_dataset->set_dataset_id(dataset_id);
register_dataset->set_fingerprint(fingerprint);
*register_dataset->mutable_metadata() = metadata;
register_dataset->set_dedupe_by_dataset_id(!requested_dataset_id.empty());
TF_RETURN_IF_ERROR(
dataset_store_->Put(DatasetKey(dataset_id, fingerprint), dataset));
return Apply(update);

View File

@@ -17,6 +17,7 @@ limitations under the License.
#define TENSORFLOW_CORE_DATA_SERVICE_DISPATCHER_IMPL_H_
#include <memory>
#include <optional>
#include <string>
#include <vector>
@@ -37,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/protobuf/data_service.pb.h"
#include "tensorflow/core/protobuf/service_config.pb.h"
@@ -195,8 +197,19 @@ class DataServiceDispatcherImpl {
// id in `dataset_id`.
Status RegisterDataset(uint64 fingerprint, const DatasetDef& dataset,
const DataServiceMetadata& metadata,
const std::string& requested_dataset_id,
std::string& dataset_id)
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_);
// Finds the dataset with the requested dataset ID (if specified), or with
// the matching fingerprint if no ID was requested. Returns `std::nullopt`
// if no such dataset exists.
StatusOr<std::optional<std::string>> FindDataset(
const GetOrRegisterDatasetRequest& request, uint64 fingerprint);
// Verifies the datasets with the same ID have the same metadata. If the
// metadata differs, returns an invalid argument error.
Status ValidateMatchingDataset(const std::string& dataset_id,
const DataServiceMetadata& new_metadata,
const DataServiceMetadata& old_metadata);
// Gets a worker's stub from `worker_stubs_`, or if none exists, creates a
// stub and stores it in `worker_stubs_`. A borrowed pointer to the stub is
// stored in `out_stub`.

View File

@@ -98,8 +98,14 @@ void DispatcherState::RegisterDataset(
register_dataset.metadata());
DCHECK(!datasets_by_id_.contains(dataset_id));
datasets_by_id_[dataset_id] = dataset;
DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
datasets_by_fingerprint_[fingerprint] = dataset;
if (!register_dataset.dedupe_by_dataset_id()) {
// Only stores the fingerprint if the user has not requested a dataset ID.
// If the user has requested a dataset ID, we will look up datasets by their
// IDs, not by fingerprints. Otherwise, an anonymous dataset can refer to
// a dataset with an explicit dataset ID.
DCHECK(!datasets_by_fingerprint_.contains(fingerprint));
datasets_by_fingerprint_[fingerprint] = dataset;
}
UpdateNextAvailableDatasetId();
}

View File

@@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/data/service/dispatcher_state.h"
#include <cstdint>
#include <memory>
#include <string>
@@ -54,19 +55,28 @@ Status RegisterDataset(const std::string& dataset_id, uint64 fingerprint,
RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
register_dataset->set_dataset_id(dataset_id);
register_dataset->set_fingerprint(fingerprint);
TF_RETURN_IF_ERROR(state.Apply(update));
return OkStatus();
register_dataset->set_dedupe_by_dataset_id(false);
return state.Apply(update);
}
Status RegisterDataset(const std::string& dataset_id, DispatcherState& state) {
return RegisterDataset(dataset_id, /*fingerprint=*/1, state);
}
Status RegisterDataset(const std::string& dataset_id, uint64_t fingerprint,
bool dedupe_by_dataset_id, DispatcherState& state) {
Update update;
RegisterDatasetUpdate* register_dataset = update.mutable_register_dataset();
register_dataset->set_dataset_id(dataset_id);
register_dataset->set_fingerprint(fingerprint);
register_dataset->set_dedupe_by_dataset_id(dedupe_by_dataset_id);
return state.Apply(update);
}
Status RegisterWorker(std::string worker_address, DispatcherState& state) {
Update update;
update.mutable_register_worker()->set_worker_address(worker_address);
TF_RETURN_IF_ERROR(state.Apply(update));
return OkStatus();
return state.Apply(update);
}
Status CreateJob(int64_t job_id, const std::string& dataset_id,
@@ -76,8 +86,7 @@ Status CreateJob(int64_t job_id, const std::string& dataset_id,
create_job->set_job_id(job_id);
create_job->set_dataset_id(dataset_id);
create_job->set_job_name(job_name);
TF_RETURN_IF_ERROR(state.Apply(update));
return Status::OK();
return state.Apply(update);
}
Status CreateIteration(int64_t iteration_id, const std::string& dataset_id,
@@ -91,8 +100,7 @@ Status CreateIteration(int64_t iteration_id, const std::string& dataset_id,
create_iteration->set_job_id(job_id);
create_iteration->set_iteration_id(iteration_id);
create_iteration->set_repetition(named_iteration_key.repetition);
TF_RETURN_IF_ERROR(state.Apply(update));
return OkStatus();
return state.Apply(update);
}
Status CreateIteration(int64_t iteration_id, const std::string& dataset_id,
@@ -109,8 +117,7 @@ Status AcquireIterationClientId(int64_t iteration_id,
update.mutable_acquire_iteration_client();
acquire_iteration_client->set_iteration_id(iteration_id);
acquire_iteration_client->set_iteration_client_id(iteration_client_id);
TF_RETURN_IF_ERROR(state.Apply(update));
return OkStatus();
return state.Apply(update);
}
Status ReleaseIterationClientId(int64_t iteration_client_id,
@@ -120,8 +127,7 @@ Status ReleaseIterationClientId(int64_t iteration_client_id,
update.mutable_release_iteration_client();
release_iteration_client->set_iteration_client_id(iteration_client_id);
release_iteration_client->set_time_micros(release_time);
TF_RETURN_IF_ERROR(state.Apply(update));
return OkStatus();
return state.Apply(update);
}
Status CreateTask(int64_t task_id, int64_t iteration_id,
@@ -131,17 +137,16 @@ Status CreateTask(int64_t task_id, int64_t iteration_id,
create_task->set_task_id(task_id);
create_task->set_iteration_id(iteration_id);
create_task->set_worker_address(worker_address);
TF_RETURN_IF_ERROR(state.Apply(update));
return OkStatus();
return state.Apply(update);
}
Status FinishTask(int64_t task_id, DispatcherState& state) {
Update update;
FinishTaskUpdate* finish_task = update.mutable_finish_task();
finish_task->set_task_id(task_id);
TF_RETURN_IF_ERROR(state.Apply(update));
return OkStatus();
return state.Apply(update);
}
} // namespace
TEST(DispatcherState, RegisterDataset) {
@@ -171,6 +176,51 @@ TEST(DispatcherState, RegisterDataset) {
}
}
TEST(DispatcherState, RegisterDatasetWithExplicitID) {
const uint64_t fingerprint = 20;
DispatcherState state;
TF_EXPECT_OK(RegisterDataset("dataset_id", fingerprint,
/*dedupe_by_dataset_id=*/true, state));
std::shared_ptr<const Dataset> dataset;
TF_EXPECT_OK(state.DatasetFromId("dataset_id", dataset));
EXPECT_EQ(dataset->dataset_id, "dataset_id");
// The fingerprint is not registered if the user requests an explicit ID.
EXPECT_THAT(state.DatasetFromFingerprint(fingerprint, dataset),
StatusIs(error::NOT_FOUND));
}
TEST(DispatcherState, RegisterDatasetsWithDifferentIDs) {
const uint64_t fingerprint = 20;
DispatcherState state;
TF_EXPECT_OK(RegisterDataset("dataset_id1", fingerprint,
/*dedupe_by_dataset_id=*/true, state));
TF_EXPECT_OK(RegisterDataset("dataset_id2", fingerprint,
/*dedupe_by_dataset_id=*/true, state));
std::shared_ptr<const Dataset> dataset;
TF_EXPECT_OK(state.DatasetFromId("dataset_id1", dataset));
EXPECT_EQ(dataset->dataset_id, "dataset_id1");
TF_EXPECT_OK(state.DatasetFromId("dataset_id2", dataset));
EXPECT_EQ(dataset->dataset_id, "dataset_id2");
}
TEST(DispatcherState, RegisterDatasetsWithExplicitAndAnonymousIDs) {
const uint64_t fingerprint = 20;
DispatcherState state;
TF_EXPECT_OK(RegisterDataset("dataset_id1", fingerprint,
/*dedupe_by_dataset_id=*/true, state));
TF_EXPECT_OK(RegisterDataset("dataset_id2", fingerprint,
/*dedupe_by_dataset_id=*/false, state));
std::shared_ptr<const Dataset> dataset;
TF_EXPECT_OK(state.DatasetFromId("dataset_id1", dataset));
EXPECT_EQ(dataset->dataset_id, "dataset_id1");
TF_EXPECT_OK(state.DatasetFromId("dataset_id2", dataset));
EXPECT_EQ(dataset->dataset_id, "dataset_id2");
// The fingerprint is not registered if the user requests an explicit ID. So
// the following query returns "dataset_id2".
TF_EXPECT_OK(state.DatasetFromFingerprint(fingerprint, dataset));
EXPECT_EQ(dataset->dataset_id, "dataset_id2");
}
TEST(DispatcherState, RegisterDatasetCompression) {
DispatcherState state;
const std::string dataset_id = state.NextAvailableDatasetId();

View File

@@ -28,11 +28,12 @@ message Update {
reserved 13;
}
// Next tag: 4
// Next tag: 5
message RegisterDatasetUpdate {
string dataset_id = 1;
uint64 fingerprint = 2;
DataServiceMetadata metadata = 3;
bool dedupe_by_dataset_id = 4;
}
// Next tag: 5

View File

@@ -181,7 +181,8 @@ StatusOr<std::string> DatasetClient<T>::RegisterDataset(
const DatasetDef& dataset) {
std::string dataset_id;
TF_RETURN_IF_ERROR(dispatcher_client_->RegisterDataset(
dataset, DataServiceMetadata(), dataset_id));
dataset, DataServiceMetadata(), /*requested_dataset_id=*/std::nullopt,
dataset_id));
return dataset_id;
}

View File

@@ -151,6 +151,29 @@ DatasetDef RangeDatasetWithShardHint(const int64_t range) {
return dataset_def;
}
DatasetDef InfiniteDataset() {
DatasetDef dataset_def;
*dataset_def.mutable_graph() = GDef(
{NDef("start", "Const", /*inputs=*/{},
{{"value", AsScalar<int64_t>(0)}, {"dtype", DT_INT64}}),
NDef("stop", "Const", /*inputs=*/{},
{{"value", AsScalar<int64_t>(100000000)}, {"dtype", DT_INT64}}),
NDef("step", "Const", /*inputs=*/{},
{{"value", AsScalar<int64_t>(1)}, {"dtype", DT_INT64}}),
NDef("range", "RangeDataset", /*inputs=*/{"start", "stop", "step"},
{{"output_shapes", gtl::ArraySlice<TensorShape>{TensorShape()}},
{"output_types", gtl::ArraySlice<DataType>{DT_INT64}}}),
NDef("count", "Const", /*inputs=*/{},
{{"value", AsScalar<int64_t>(-1)}, {"dtype", DT_INT64}}),
NDef("repeat", "RepeatDataset", /*inputs=*/{"range", "count"},
{{"output_shapes", gtl::ArraySlice<TensorShape>{TensorShape()}},
{"output_types", gtl::ArraySlice<DataType>{DT_INT64}}}),
NDef("dataset", "_Retval", /*inputs=*/{"repeat"},
{{"T", DT_VARIANT}, {"index", 0}})},
{});
return dataset_def;
}
StatusOr<DatasetDef> InterleaveTextlineDataset(
const std::vector<tstring>& filenames,
const std::vector<tstring>& contents) {

View File

@@ -42,6 +42,10 @@ DatasetDef RangeSquareDataset(int64_t range);
// tf.data.Dataset.range(range).shard(SHARD_HINT, SHARD_HINT).
DatasetDef RangeDatasetWithShardHint(int64_t range);
// Returns a test dataset representing
// tf.data.Dataset.range(100000000).repeat().
DatasetDef InfiniteDataset();
// Returns a test dataset representing
// tf.data.Dataset.from_tensor_slices(["filenames"]).interleave(
// lambda filepath: tf.data.TextLineDataset(filepath),

View File

@@ -63,6 +63,23 @@ StatusOr<std::vector<std::vector<Tensor>>> GetIteratorOutput(
return result;
}
TEST(TestUtilTest, RangeDataset) {
const auto dataset_def = RangeDataset(/*range=*/10);
standalone::Dataset::Params params;
std::unique_ptr<standalone::Dataset> dataset;
TF_ASSERT_OK(
standalone::Dataset::FromGraph(params, dataset_def.graph(), &dataset));
std::unique_ptr<standalone::Iterator> iterator;
TF_ASSERT_OK(dataset->MakeIterator(&iterator));
TF_ASSERT_OK_AND_ASSIGN(std::vector<std::vector<Tensor>> result,
GetIteratorOutput(*iterator));
ASSERT_EQ(result.size(), 10);
for (int i = 0; i < result.size(); ++i) {
test::ExpectEqual(result[i][0], Tensor(int64_t{i}));
}
}
TEST(TestUtilTest, RangeSquareDataset) {
const auto dataset_def = RangeSquareDataset(/*range=*/10);
standalone::Dataset::Params params;
@@ -80,6 +97,24 @@ TEST(TestUtilTest, RangeSquareDataset) {
}
}
TEST(TestUtilTest, InfiniteDataset) {
const auto dataset_def = InfiniteDataset();
standalone::Dataset::Params params;
std::unique_ptr<standalone::Dataset> dataset;
TF_ASSERT_OK(
standalone::Dataset::FromGraph(params, dataset_def.graph(), &dataset));
std::unique_ptr<standalone::Iterator> iterator;
TF_ASSERT_OK(dataset->MakeIterator(&iterator));
// Verifies the first 10 elements.
for (int64_t i = 0; i < 10; ++i) {
std::vector<tensorflow::Tensor> outputs;
bool end_of_sequence;
TF_ASSERT_OK(iterator->GetNext(&outputs, &end_of_sequence));
test::ExpectEqual(outputs[0], Tensor(i));
}
}
TEST(TestUtilTest, EmptyDataset) {
const auto dataset_def = RangeSquareDataset(/*range=*/0);
standalone::Dataset::Params params;

View File

@@ -68,7 +68,8 @@ class WorkerClientTest : public ::testing::Test {
const auto dataset_def = RangeSquareDataset(range);
std::string dataset_id;
TF_RETURN_IF_ERROR(dispatcher_client_->RegisterDataset(
dataset_def, DataServiceMetadata(), dataset_id));
dataset_def, DataServiceMetadata(),
/*requested_dataset_id=*/std::nullopt, dataset_id));
return dataset_id;
}

View File

@@ -275,7 +275,6 @@ tf_kernel_library(
"//tensorflow/core/data/service:dispatcher_client",
"//tensorflow/core/data/service:grpc_util",
"//tensorflow/core/framework:graph_proto_cc",
"//tensorflow/core/kernels/data:iterator_ops",
"//tensorflow/core/platform:env_time",
"//tensorflow/core/platform:errors",
"//tensorflow/core/platform:status",

View File

@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/experimental/data_service_ops.h"
#include <optional>
#include <string>
#include <utility>
@@ -75,6 +76,13 @@ RegisterDatasetOp::RegisterDatasetOp(OpKernelConstruction* ctx)
OP_REQUIRES_OK(ctx, ctx->GetAttr(kMetadata, &serialized_metadata));
serialized_metadata_ = serialized_metadata;
}
if (ctx->HasAttr(kRequestedDatasetId)) {
tstring requested_dataset_id;
OP_REQUIRES_OK(ctx,
ctx->GetAttr(kRequestedDatasetId, &requested_dataset_id));
requested_dataset_id_ = requested_dataset_id;
}
}
void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
@@ -124,14 +132,19 @@ void RegisterDatasetOp::Compute(OpKernelContext* ctx) {
}
metadata.set_cardinality(dataset->Cardinality());
std::optional<std::string> requested_dataset_id;
if (!requested_dataset_id_.empty()) {
requested_dataset_id = requested_dataset_id_;
}
DataServiceDispatcherClient client(address, protocol);
std::string dataset_id;
int64_t deadline_micros = EnvTime::NowMicros() + kRetryTimeoutMicros;
OP_REQUIRES_OK(
ctx, grpc_util::Retry(
[&]() {
return client.RegisterDataset(dataset_def, metadata,
dataset_id);
return client.RegisterDataset(
dataset_def, metadata, requested_dataset_id, dataset_id);
},
/*description=*/
strings::StrCat("register dataset with dispatcher at ", address),

View File

@@ -20,7 +20,6 @@ limitations under the License.
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/kernels/data/iterator_ops.h"
namespace tensorflow {
namespace data {
@@ -39,6 +38,8 @@ class RegisterDatasetOp : public OpKernel {
"external_state_policy";
static constexpr const char* const kElementSpec = "element_spec";
static constexpr const char* const kMetadata = "metadata";
static constexpr const char* const kRequestedDatasetId =
"requested_dataset_id";
static constexpr const char* const kTimeoutMs = "timeout_ms";
explicit RegisterDatasetOp(OpKernelConstruction* ctx);
@@ -50,6 +51,7 @@ class RegisterDatasetOp : public OpKernel {
SerializationContext::ExternalStatePolicy external_state_policy_;
std::string element_spec_;
std::string serialized_metadata_;
std::string requested_dataset_id_;
};
} // namespace data

View File

@@ -1509,6 +1509,7 @@ REGISTER_OP("RegisterDatasetV2")
.Output("dataset_id: string")
.Attr("external_state_policy: int")
.Attr("element_spec: string = ''")
.Attr("requested_dataset_id: string = ''")
.Attr("metadata: string = ''")
.SetShapeFn(shape_inference::ScalarShape);

View File

@@ -830,6 +830,129 @@ class DataServiceOpsTest(data_service_test_base.TestBase,
id_2 = self.register_dataset(cluster.dispatcher_address(), ds_2)
self.assertNotEqual(self.evaluate(id_1), self.evaluate(id_2))
@combinations.generate(test_base.default_test_combinations())
def testRegisterWithExplicitDatasetId(self):
cluster = data_service_test_base.TestCluster(num_workers=1)
dataset = dataset_ops.Dataset.range(10)
dataset_id = data_service_ops.register_dataset(
cluster.dispatcher.target, dataset, dataset_id="dataset_id")
dataset = data_service_ops.from_dataset_id(
dataset_id=dataset_id,
element_spec=dataset.element_spec,
processing_mode=data_service_ops.ShardingPolicy.OFF,
service=cluster.dispatcher.target)
self.assertDatasetProduces(dataset, list(range(10)))
# Verifies the dataset ID is indeed "dataset_id".
dataset = data_service_ops.from_dataset_id(
dataset_id="dataset_id",
element_spec=dataset.element_spec,
processing_mode=data_service_ops.ShardingPolicy.OFF,
service=cluster.dispatcher.target)
self.assertDatasetProduces(dataset, list(range(10)))
# Eager mode only: In the graph mode, `register_dataset` may not run before
# `from_dataset_id` if `from_dataset_id` does not use its return value.
@combinations.generate(test_base.eager_only_combinations())
def testFromRegisteredStringDatasetId(self):
cluster = data_service_test_base.TestCluster(num_workers=1)
dataset = dataset_ops.Dataset.range(10)
_ = data_service_ops.register_dataset(
cluster.dispatcher.target, dataset, dataset_id="dataset_id")
dataset = data_service_ops.from_dataset_id(
dataset_id="dataset_id",
element_spec=dataset.element_spec,
processing_mode=data_service_ops.ShardingPolicy.OFF,
service=cluster.dispatcher.target)
self.assertDatasetProduces(dataset, list(range(10)))
@combinations.generate(test_base.default_test_combinations())
def testRegisterSameDatasetIds(self):
cluster = data_service_test_base.TestCluster(num_workers=1)
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10)
dataset_id1 = data_service_ops.register_dataset(
cluster.dispatcher.target, dataset1, dataset_id="dataset_id")
dataset_id2 = data_service_ops.register_dataset(
cluster.dispatcher.target, dataset2, dataset_id="dataset_id")
dataset1 = data_service_ops.from_dataset_id(
dataset_id=dataset_id1,
element_spec=dataset1.element_spec,
processing_mode=data_service_ops.ShardingPolicy.OFF,
service=cluster.dispatcher.target,
job_name="job_name")
dataset2 = data_service_ops.from_dataset_id(
dataset_id=dataset_id2,
element_spec=dataset2.element_spec,
processing_mode=data_service_ops.ShardingPolicy.OFF,
service=cluster.dispatcher.target,
job_name="job_name")
# `dataset2` is empty because the datasets share the same job and `dataset1`
# has exhausted the dataset.
self.assertDatasetProduces(dataset1, list(range(10)))
self.assertDatasetProduces(dataset2, list())
@combinations.generate(
combinations.times(
test_base.default_test_combinations(),
combinations.combine(
different_dataset_id=[None, "another_dataset_id"])))
def testRegisterDifferentDatasetIds(self, different_dataset_id):
cluster = data_service_test_base.TestCluster(num_workers=1)
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10)
dataset_id1 = data_service_ops.register_dataset(
cluster.dispatcher.target, dataset1, dataset_id="dataset_id")
dataset_id2 = data_service_ops.register_dataset(
cluster.dispatcher.target, dataset2, dataset_id=different_dataset_id)
dataset1 = data_service_ops.from_dataset_id(
dataset_id=dataset_id1,
element_spec=dataset1.element_spec,
processing_mode=data_service_ops.ShardingPolicy.OFF,
service=cluster.dispatcher.target,
job_name="job_name")
dataset2 = data_service_ops.from_dataset_id(
dataset_id=dataset_id2,
element_spec=dataset2.element_spec,
processing_mode=data_service_ops.ShardingPolicy.OFF,
service=cluster.dispatcher.target,
job_name="job_name")
# `dataset1` and `dataset2` are different datasets.
self.assertDatasetProduces(dataset1, list(range(10)))
self.assertDatasetProduces(dataset2, list(range(10)))
@combinations.generate(test_base.default_test_combinations())
def testDatasetsDoNotMatch(self):
cluster = data_service_test_base.TestCluster(num_workers=1)
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.from_tensor_slices(list("Test dataset."))
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"Datasets with the same ID should have the same structure"):
dataset_id1 = data_service_ops.register_dataset(
cluster.dispatcher.target,
dataset1,
compression=None,
dataset_id="dataset_id")
dataset_id2 = data_service_ops.register_dataset(
cluster.dispatcher.target,
dataset2,
compression=None,
dataset_id="dataset_id")
dataset1 = data_service_ops.from_dataset_id(
dataset_id=dataset_id1,
element_spec=dataset1.element_spec,
processing_mode=data_service_ops.ShardingPolicy.OFF,
service=cluster.dispatcher.target)
dataset2 = data_service_ops.from_dataset_id(
dataset_id=dataset_id2,
element_spec=dataset2.element_spec,
processing_mode=data_service_ops.ShardingPolicy.OFF,
service=cluster.dispatcher.target)
self.getDatasetOutput(dataset1)
self.getDatasetOutput(dataset2)
@combinations.generate(test_base.default_test_combinations())
def testDoubleDistribute(self):
cluster = data_service_test_base.TestCluster(num_workers=1)

View File

@@ -800,7 +800,7 @@ def distribute(processing_mode,
target_workers=target_workers)
def _register_dataset(service, dataset, compression):
def _register_dataset(service, dataset, compression, dataset_id=None):
"""Registers a dataset with the tf.data service.
This transformation is similar to `register_dataset`, but supports additional
@@ -816,9 +816,15 @@ def _register_dataset(service, dataset, compression):
compression: How to compress the dataset's elements before transferring them
over the network. "AUTO" leaves the decision of how to compress up to the
tf.data service runtime. `None` indicates not to compress.
dataset_id: (Optional.) By default, tf.data service generates a unique
(string) ID for each registered dataset. If a `dataset_id` is provided, it
will use the specified ID. If a dataset with a matching ID already exists,
no new dataset is registered. This is useful if multiple training jobs
want to (re)use the same dataset for training. In this case, they can
register the dataset with the same dataset ID.
Returns:
A scalar tensor of the registered dataset's id.
A scalar string tensor representing the dataset ID.
"""
_validate_compression(compression)
if isinstance(service, tuple):
@@ -845,22 +851,25 @@ def _register_dataset(service, dataset, compression):
element_spec=encoded_spec,
compression=_get_compression_proto(compression))
if compat.forward_compatible(2022, 8, 31):
register_dataset_op = gen_experimental_dataset_ops.register_dataset_v2
if compat.forward_compatible(2022, 8, 31) or dataset_id:
return gen_experimental_dataset_ops.register_dataset_v2(
dataset._variant_tensor, # pylint: disable=protected-access
address=address,
protocol=protocol,
external_state_policy=external_state_policy.value,
requested_dataset_id=dataset_id,
metadata=metadata.SerializeToString())
else:
register_dataset_op = gen_experimental_dataset_ops.register_dataset
dataset_id = register_dataset_op(
dataset._variant_tensor, # pylint: disable=protected-access
address=address,
protocol=protocol,
external_state_policy=external_state_policy.value,
metadata=metadata.SerializeToString())
return dataset_id
return gen_experimental_dataset_ops.register_dataset(
dataset._variant_tensor, # pylint: disable=protected-access
address=address,
protocol=protocol,
external_state_policy=external_state_policy.value,
metadata=metadata.SerializeToString())
@tf_export("data.experimental.service.register_dataset")
def register_dataset(service, dataset, compression="AUTO"):
def register_dataset(service, dataset, compression="AUTO", dataset_id=None):
"""Registers a dataset with the tf.data service.
`register_dataset` registers a dataset with the tf.data service so that
@@ -901,11 +910,17 @@ def register_dataset(service, dataset, compression="AUTO"):
transferring them over the network. "AUTO" leaves the decision of how to
compress up to the tf.data service runtime. `None` indicates not to
compress.
dataset_id: (Optional.) By default, tf.data service generates a unique
(string) ID for each registered dataset. If a `dataset_id` is provided, it
will use the specified ID. If a dataset with a matching ID already exists,
no new dataset is registered. This is useful if multiple training jobs
want to (re)use the same dataset for training. In this case, they can
register the dataset with the same dataset ID.
Returns:
A scalar tensor of the registered dataset's id.
A scalar string tensor representing the dataset ID.
"""
return _register_dataset(service, dataset, compression)
return _register_dataset(service, dataset, compression, dataset_id)
def _from_dataset_id(processing_mode,

View File

@@ -26,6 +26,6 @@ tf_module {
}
member_method {
name: "register_dataset"
argspec: "args=[\'service\', \'dataset\', \'compression\'], varargs=None, keywords=None, defaults=[\'AUTO\'], "
argspec: "args=[\'service\', \'dataset\', \'compression\', \'dataset_id\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\'], "
}
}

View File

@@ -3530,7 +3530,7 @@ tf_module {
}
member_method {
name: "RegisterDatasetV2"
argspec: "args=[\'dataset\', \'address\', \'protocol\', \'external_state_policy\', \'element_spec\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
argspec: "args=[\'dataset\', \'address\', \'protocol\', \'external_state_policy\', \'element_spec\', \'requested_dataset_id\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
}
member_method {
name: "Relu"

View File

@@ -34,6 +34,6 @@ tf_module {
}
member_method {
name: "register_dataset"
argspec: "args=[\'service\', \'dataset\', \'compression\'], varargs=None, keywords=None, defaults=[\'AUTO\'], "
argspec: "args=[\'service\', \'dataset\', \'compression\', \'dataset_id\'], varargs=None, keywords=None, defaults=[\'AUTO\', \'None\'], "
}
}

View File

@@ -3530,7 +3530,7 @@ tf_module {
}
member_method {
name: "RegisterDatasetV2"
argspec: "args=[\'dataset\', \'address\', \'protocol\', \'external_state_policy\', \'element_spec\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
argspec: "args=[\'dataset\', \'address\', \'protocol\', \'external_state_policy\', \'element_spec\', \'requested_dataset_id\', \'metadata\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'\', \'None\'], "
}
member_method {
name: "Relu"