mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
[TF:XLA] Use HloEvaluator for ComputeConstant, remove the need of a dedicated
compute constant backend. PiperOrigin-RevId: 164940970
This commit is contained in:
parent
eeacdcdb14
commit
87605f3d6a
|
|
@ -119,22 +119,15 @@ Status XlaOpKernelContext::ConstantInputReshaped(
|
||||||
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
|
xla::Layout layout = xla::LayoutUtil::MakeLayout(layout_indices);
|
||||||
|
|
||||||
// Ask the XLA compiler to evaluate the data handle to a literal.
|
// Ask the XLA compiler to evaluate the data handle to a literal.
|
||||||
xla::StatusOr<std::unique_ptr<xla::GlobalData>> computed =
|
xla::StatusOr<std::unique_ptr<xla::Literal>> computed =
|
||||||
builder()->ComputeConstant(handle, &layout);
|
builder()->ComputeConstant(handle, &layout);
|
||||||
if (!computed.ok()) {
|
if (!computed.ok()) {
|
||||||
return errors::InvalidArgument(
|
return errors::InvalidArgument(
|
||||||
"Error evaluating ", context_->op_kernel().name(), " input ", index,
|
"Error evaluating ", context_->op_kernel().name(), " input ", index,
|
||||||
": ", computed.status().error_message());
|
": ", computed.status().error_message());
|
||||||
}
|
}
|
||||||
// Fetch the literal from the compiler service.
|
constant_literal->Swap(computed.ValueOrDie().get());
|
||||||
xla::StatusOr<std::unique_ptr<xla::Literal>> constant =
|
|
||||||
builder()->client()->Transfer(*computed.ValueOrDie());
|
|
||||||
if (!constant.ok()) {
|
|
||||||
return errors::InvalidArgument(
|
|
||||||
"Error evaluating ", context_->op_kernel().name(), " input ", index,
|
|
||||||
": ", constant.status().error_message());
|
|
||||||
}
|
|
||||||
constant_literal->Swap(constant.ValueOrDie().get());
|
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -111,13 +111,12 @@ bool ComputationBuilder::MakeWindow(
|
||||||
return true;
|
return true;
|
||||||
} else {
|
} else {
|
||||||
NoteError(InvalidArgument(
|
NoteError(InvalidArgument(
|
||||||
"%s",
|
"%s", tensorflow::strings::StrCat(
|
||||||
tensorflow::strings::StrCat(
|
"Window has different number of window dimensions than of ",
|
||||||
"Window has different number of window dimensions than of ",
|
x_name, "\nNumber of window dimensions: ",
|
||||||
x_name, "\nNumber of window dimensions: ",
|
window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
|
||||||
window_dimensions.size(), "\nNumber of ", x_name, ": ", x,
|
"\n")
|
||||||
"\n")
|
.c_str())); //
|
||||||
.c_str())); //
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -663,24 +662,26 @@ bool ComputationBuilder::VerifyConvolution(
|
||||||
}
|
}
|
||||||
int num_spatial_dims = num_dims - 2;
|
int num_spatial_dims = num_dims - 2;
|
||||||
|
|
||||||
const auto check_spatial_dimensions = [&](
|
const auto check_spatial_dimensions =
|
||||||
const char* const field_name,
|
[&](const char* const field_name,
|
||||||
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
|
const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
|
||||||
numbers) {
|
numbers) {
|
||||||
if (numbers.size() != num_spatial_dims) {
|
if (numbers.size() != num_spatial_dims) {
|
||||||
NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
|
NoteError(InvalidArgument("Expected %d elements for %s, but got %d.",
|
||||||
num_spatial_dims, field_name, numbers.size()));
|
num_spatial_dims, field_name,
|
||||||
return false;
|
numbers.size()));
|
||||||
}
|
return false;
|
||||||
for (int i = 0; i < numbers.size(); ++i) {
|
}
|
||||||
if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
|
for (int i = 0; i < numbers.size(); ++i) {
|
||||||
NoteError(InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
|
if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
|
||||||
field_name, i, numbers.Get(i)));
|
NoteError(
|
||||||
return false;
|
InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
|
||||||
}
|
field_name, i, numbers.Get(i)));
|
||||||
}
|
return false;
|
||||||
return true;
|
}
|
||||||
};
|
}
|
||||||
|
return true;
|
||||||
|
};
|
||||||
return check_spatial_dimensions("spatial_dimensions",
|
return check_spatial_dimensions("spatial_dimensions",
|
||||||
dimension_numbers.spatial_dimensions()) &&
|
dimension_numbers.spatial_dimensions()) &&
|
||||||
check_spatial_dimensions(
|
check_spatial_dimensions(
|
||||||
|
|
@ -1268,7 +1269,7 @@ StatusOr<bool> ComputationBuilder::IsConstant(
|
||||||
return response.is_constant();
|
return response.is_constant();
|
||||||
}
|
}
|
||||||
|
|
||||||
StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
|
StatusOr<std::unique_ptr<Literal>> ComputationBuilder::ComputeConstant(
|
||||||
const ComputationDataHandle& operand, const Layout* output_layout) {
|
const ComputationDataHandle& operand, const Layout* output_layout) {
|
||||||
if (!first_error_.ok()) {
|
if (!first_error_.ok()) {
|
||||||
return first_error_;
|
return first_error_;
|
||||||
|
|
@ -1291,8 +1292,14 @@ StatusOr<std::unique_ptr<GlobalData>> ComputationBuilder::ComputeConstant(
|
||||||
return s;
|
return s;
|
||||||
}
|
}
|
||||||
|
|
||||||
TF_RET_CHECK(response.output().handle() != 0);
|
VLOG(3) << "ComputeConstant: {" << response.DebugString() << "}";
|
||||||
return MakeUnique<GlobalData>(client_->stub(), response.output());
|
|
||||||
|
if (!response.has_literal()) {
|
||||||
|
return InternalError(
|
||||||
|
"no computed literal in the provided response in ComputeConstant "
|
||||||
|
"request");
|
||||||
|
}
|
||||||
|
return MakeUnique<Literal>(response.literal());
|
||||||
}
|
}
|
||||||
|
|
||||||
ComputationDataHandle ComputationBuilder::Map(
|
ComputationDataHandle ComputationBuilder::Map(
|
||||||
|
|
|
||||||
|
|
@ -679,12 +679,12 @@ class ComputationBuilder {
|
||||||
// Computes the value of a constant indicated by a
|
// Computes the value of a constant indicated by a
|
||||||
// ComputationDataHandle.
|
// ComputationDataHandle.
|
||||||
//
|
//
|
||||||
// The handle must be from the computation currently being built -
|
// The operand must be from the computation currently being built -
|
||||||
// i.e., returned from this builder with no intervening call to
|
// i.e., returned from this builder with no intervening call to
|
||||||
// Build(). This happens to currently work regardless of that, but
|
// Build(). This happens to currently work regardless of that, but
|
||||||
// that may stop working at any time.
|
// that may stop working at any time.
|
||||||
//
|
//
|
||||||
// The handle must represent a constant value, which in this case
|
// The operand must represent a constant value, which in this case
|
||||||
// means that it must not statically depend on a parameter to the
|
// means that it must not statically depend on a parameter to the
|
||||||
// computation that is being built.
|
// computation that is being built.
|
||||||
//
|
//
|
||||||
|
|
@ -702,8 +702,8 @@ class ComputationBuilder {
|
||||||
//
|
//
|
||||||
// If output_layout is non-null, then the output of the computation
|
// If output_layout is non-null, then the output of the computation
|
||||||
// will be stored using that layout.
|
// will be stored using that layout.
|
||||||
StatusOr<std::unique_ptr<GlobalData>> ComputeConstant(
|
StatusOr<std::unique_ptr<Literal>> ComputeConstant(
|
||||||
const ComputationDataHandle& handle,
|
const ComputationDataHandle& operand,
|
||||||
const Layout* output_layout = nullptr);
|
const Layout* output_layout = nullptr);
|
||||||
|
|
||||||
// Returns a new ComputationBuilder whose resultant Computation is used only
|
// Returns a new ComputationBuilder whose resultant Computation is used only
|
||||||
|
|
|
||||||
|
|
@ -428,6 +428,7 @@ cc_library(
|
||||||
":gpu_transfer_manager",
|
":gpu_transfer_manager",
|
||||||
":hlo",
|
":hlo",
|
||||||
":hlo_cost_analysis",
|
":hlo_cost_analysis",
|
||||||
|
":hlo_evaluator",
|
||||||
":hlo_execution_profile",
|
":hlo_execution_profile",
|
||||||
":hlo_module_config",
|
":hlo_module_config",
|
||||||
":platform_util",
|
":platform_util",
|
||||||
|
|
|
||||||
|
|
@ -50,19 +50,14 @@ CompileOnlyService::NewService(const ServiceOptions& options) {
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
|
TF_ASSIGN_OR_RETURN(auto compiler, Compiler::GetForPlatform(platform));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
std::unique_ptr<CompileOnlyService> service(
|
||||||
CreateComputeConstantBackend());
|
new CompileOnlyService(options, compiler));
|
||||||
std::unique_ptr<CompileOnlyService> service(new CompileOnlyService(
|
|
||||||
options, compiler, std::move(compute_constant_backend)));
|
|
||||||
return std::move(service);
|
return std::move(service);
|
||||||
}
|
}
|
||||||
|
|
||||||
CompileOnlyService::CompileOnlyService(
|
CompileOnlyService::CompileOnlyService(const ServiceOptions& options,
|
||||||
const ServiceOptions& options, Compiler* compiler,
|
Compiler* compiler)
|
||||||
std::unique_ptr<Backend> compute_constant_backend)
|
: Service(options, /*execute_backend=*/nullptr), compiler_(compiler) {}
|
||||||
: Service(options, /*backend=*/nullptr,
|
|
||||||
std::move(compute_constant_backend)),
|
|
||||||
compiler_(compiler) {}
|
|
||||||
|
|
||||||
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
StatusOr<std::vector<std::unique_ptr<AotCompilationResult>>>
|
||||||
CompileOnlyService::CompileAheadOfTime(
|
CompileOnlyService::CompileAheadOfTime(
|
||||||
|
|
|
||||||
|
|
@ -102,9 +102,8 @@ class CompileOnlyService : public Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
explicit CompileOnlyService(
|
explicit CompileOnlyService(const ServiceOptions& options,
|
||||||
const ServiceOptions& options, Compiler* compiler,
|
Compiler* compiler);
|
||||||
std::unique_ptr<Backend> compute_constant_backend);
|
|
||||||
CompileOnlyService(const CompileOnlyService&) = delete;
|
CompileOnlyService(const CompileOnlyService&) = delete;
|
||||||
void operator=(const CompileOnlyService&) = delete;
|
void operator=(const CompileOnlyService&) = delete;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -54,23 +54,19 @@ namespace xla {
|
||||||
}
|
}
|
||||||
|
|
||||||
BackendOptions backend_options;
|
BackendOptions backend_options;
|
||||||
backend_options.set_platform(platform)
|
backend_options.set_platform(platform).set_intra_op_parallelism_threads(
|
||||||
.set_intra_op_parallelism_threads(options.intra_op_parallelism_threads());
|
options.intra_op_parallelism_threads());
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
|
||||||
Backend::CreateBackend(backend_options));
|
Backend::CreateBackend(backend_options));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
std::unique_ptr<LocalService> service(
|
||||||
CreateComputeConstantBackend());
|
new LocalService(options, std::move(backend)));
|
||||||
std::unique_ptr<LocalService> service(new LocalService(
|
|
||||||
options, std::move(backend), std::move(compute_constant_backend)));
|
|
||||||
return std::move(service);
|
return std::move(service);
|
||||||
}
|
}
|
||||||
|
|
||||||
LocalService::LocalService(const ServiceOptions& options,
|
LocalService::LocalService(const ServiceOptions& options,
|
||||||
std::unique_ptr<Backend> execute_backend,
|
std::unique_ptr<Backend> execute_backend)
|
||||||
std::unique_ptr<Backend> compute_constant_backend)
|
: Service(options, std::move(execute_backend)) {}
|
||||||
: Service(options, std::move(execute_backend),
|
|
||||||
std::move(compute_constant_backend)) {}
|
|
||||||
|
|
||||||
namespace {
|
namespace {
|
||||||
// Returns the space required to allocate a shape. If
|
// Returns the space required to allocate a shape. If
|
||||||
|
|
@ -161,7 +157,6 @@ StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
|
||||||
std::vector<perftools::gputools::DeviceMemoryBase> argument_buffers(
|
std::vector<perftools::gputools::DeviceMemoryBase> argument_buffers(
|
||||||
argument_layouts.size());
|
argument_layouts.size());
|
||||||
return BuildExecutable(versioned_handle, std::move(module_config),
|
return BuildExecutable(versioned_handle, std::move(module_config),
|
||||||
/*executable_for_compute_constant=*/false,
|
|
||||||
argument_buffers, execute_backend_.get(), executor);
|
argument_buffers, execute_backend_.get(), executor);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -57,8 +57,7 @@ class LocalService : public Service {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
explicit LocalService(const ServiceOptions& options,
|
explicit LocalService(const ServiceOptions& options,
|
||||||
std::unique_ptr<Backend> backend,
|
std::unique_ptr<Backend> backend);
|
||||||
std::unique_ptr<Backend> compute_constant_backend);
|
|
||||||
LocalService(const LocalService&) = delete;
|
LocalService(const LocalService&) = delete;
|
||||||
void operator=(const LocalService&) = delete;
|
void operator=(const LocalService&) = delete;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/xla/service/executable.h"
|
#include "tensorflow/compiler/xla/service/executable.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
|
||||||
|
#include "tensorflow/compiler/xla/service/hlo_evaluator.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
#include "tensorflow/compiler/xla/service/hlo_module.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
|
||||||
|
|
@ -144,36 +145,15 @@ int ServiceOptions::intra_op_parallelism_threads() const {
|
||||||
backend_options.set_platform(platform);
|
backend_options.set_platform(platform);
|
||||||
TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
|
TF_ASSIGN_OR_RETURN(execute_backend, Backend::CreateBackend(backend_options));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> compute_constant_backend,
|
|
||||||
CreateComputeConstantBackend());
|
|
||||||
std::unique_ptr<Service> service(
|
std::unique_ptr<Service> service(
|
||||||
new Service(options, std::move(execute_backend),
|
new Service(options, std::move(execute_backend)));
|
||||||
std::move(compute_constant_backend)));
|
|
||||||
return std::move(service);
|
return std::move(service);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* static */ StatusOr<std::unique_ptr<Backend>>
|
|
||||||
Service::CreateComputeConstantBackend() {
|
|
||||||
TF_ASSIGN_OR_RETURN(std::vector<se::Platform*> platforms,
|
|
||||||
PlatformUtil::GetSupportedPlatforms());
|
|
||||||
for (auto* platform : platforms) {
|
|
||||||
if (platform->id() == se::host::kHostPlatformId) {
|
|
||||||
BackendOptions backend_options;
|
|
||||||
backend_options.set_platform(platform);
|
|
||||||
return Backend::CreateBackend(backend_options);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return NotFound("CPU platform not found");
|
|
||||||
}
|
|
||||||
|
|
||||||
Service::Service(const ServiceOptions& options,
|
Service::Service(const ServiceOptions& options,
|
||||||
std::unique_ptr<Backend> execute_backend,
|
std::unique_ptr<Backend> execute_backend)
|
||||||
std::unique_ptr<Backend> compute_constant_backend)
|
: options_(options), execute_backend_(std::move(execute_backend)) {
|
||||||
: options_(options),
|
|
||||||
execute_backend_(std::move(execute_backend)),
|
|
||||||
compute_constant_backend_(std::move(compute_constant_backend)) {
|
|
||||||
CHECK(options_.number_of_replicas() > 0);
|
CHECK(options_.number_of_replicas() > 0);
|
||||||
|
|
||||||
if (execute_backend_) {
|
if (execute_backend_) {
|
||||||
if (execute_backend_->device_count() > 0) {
|
if (execute_backend_->device_count() > 0) {
|
||||||
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
|
CHECK_GE(execute_backend_->device_count(), options_.number_of_replicas())
|
||||||
|
|
@ -418,7 +398,6 @@ StatusOr<std::vector<std::unique_ptr<Executable>>> Service::BuildExecutables(
|
||||||
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||||
const VersionedComputationHandle& versioned_handle,
|
const VersionedComputationHandle& versioned_handle,
|
||||||
std::unique_ptr<HloModuleConfig> module_config,
|
std::unique_ptr<HloModuleConfig> module_config,
|
||||||
bool executable_for_compute_constant,
|
|
||||||
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||||
arguments,
|
arguments,
|
||||||
Backend* backend, se::StreamExecutor* executor) {
|
Backend* backend, se::StreamExecutor* executor) {
|
||||||
|
|
@ -431,8 +410,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||||
module_config->debug_options().xla_dump_computations_to();
|
module_config->debug_options().xla_dump_computations_to();
|
||||||
const string& other_directory_path =
|
const string& other_directory_path =
|
||||||
module_config->debug_options().xla_dump_executions_to();
|
module_config->debug_options().xla_dump_executions_to();
|
||||||
if (!executable_for_compute_constant &&
|
if (!directory_path.empty() || !other_directory_path.empty()) {
|
||||||
(!directory_path.empty() || !other_directory_path.empty())) {
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
session_module,
|
session_module,
|
||||||
computation_tracker_.SnapshotComputation(versioned_handle.handle));
|
computation_tracker_.SnapshotComputation(versioned_handle.handle));
|
||||||
|
|
@ -450,7 +428,7 @@ StatusOr<std::unique_ptr<Executable>> Service::BuildExecutable(
|
||||||
std::unique_ptr<HloModule> module,
|
std::unique_ptr<HloModule> module,
|
||||||
computation_tracker_.BuildHloModule(versioned_handle, *module_config,
|
computation_tracker_.BuildHloModule(versioned_handle, *module_config,
|
||||||
/*include_unreachable_instructions=*/
|
/*include_unreachable_instructions=*/
|
||||||
!executable_for_compute_constant));
|
true));
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<Executable> executable,
|
std::unique_ptr<Executable> executable,
|
||||||
|
|
@ -490,8 +468,7 @@ StatusOr<std::shared_ptr<Executable>> Service::BuildAndCacheExecutable(
|
||||||
HloModuleConfig original_module_config = *module_config;
|
HloModuleConfig original_module_config = *module_config;
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::unique_ptr<Executable> executable_unique_ptr,
|
std::unique_ptr<Executable> executable_unique_ptr,
|
||||||
BuildExecutable(versioned_handle, std::move(module_config),
|
BuildExecutable(versioned_handle, std::move(module_config), arguments,
|
||||||
/*executable_for_compute_constant=*/false, arguments,
|
|
||||||
backend, executor));
|
backend, executor));
|
||||||
|
|
||||||
if (profile != nullptr) {
|
if (profile != nullptr) {
|
||||||
|
|
@ -1098,7 +1075,6 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(bool is_constant,
|
TF_ASSIGN_OR_RETURN(bool is_constant,
|
||||||
user_computation->IsConstant(arg->operand()));
|
user_computation->IsConstant(arg->operand()));
|
||||||
|
|
||||||
if (!is_constant) {
|
if (!is_constant) {
|
||||||
return InvalidArgument("Operand to ComputeConstant depends on parameter.");
|
return InvalidArgument("Operand to ComputeConstant depends on parameter.");
|
||||||
}
|
}
|
||||||
|
|
@ -1114,8 +1090,8 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
|
||||||
|
|
||||||
ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions();
|
ExecutionOptions execution_options = xla::CreateDefaultExecutionOptions();
|
||||||
execution_options.mutable_debug_options()->set_xla_enable_fast_math(false);
|
execution_options.mutable_debug_options()->set_xla_enable_fast_math(false);
|
||||||
execution_options.mutable_debug_options()->set_xla_backend_optimization_level(
|
execution_options.mutable_debug_options()
|
||||||
0);
|
->set_xla_eliminate_hlo_implicit_broadcast(true);
|
||||||
*execution_options.mutable_shape_with_output_layout() =
|
*execution_options.mutable_shape_with_output_layout() =
|
||||||
program_shape.result();
|
program_shape.result();
|
||||||
|
|
||||||
|
|
@ -1130,20 +1106,22 @@ tensorflow::Status Service::ComputeConstant(const ComputeConstantRequest* arg,
|
||||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
|
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloModuleConfig> module_config,
|
||||||
CreateModuleConfig(program_shape, {}, execution_options));
|
CreateModuleConfig(program_shape, {}, execution_options));
|
||||||
|
|
||||||
|
// Exclude dead parameter instructions for the purpose of computing constants.
|
||||||
TF_ASSIGN_OR_RETURN(
|
TF_ASSIGN_OR_RETURN(
|
||||||
std::shared_ptr<Executable> executable,
|
std::unique_ptr<HloModule> module,
|
||||||
BuildExecutable(versioned_handle, std::move(module_config),
|
computation_tracker_.BuildHloModule(versioned_handle, *module_config,
|
||||||
/*executable_for_compute_constant=*/true,
|
/*include_unreachable_instructions=*/
|
||||||
/*arguments=*/{}, compute_constant_backend_.get(),
|
false));
|
||||||
compute_constant_backend_->default_stream_executor()));
|
|
||||||
|
HloEvaluator evaluator;
|
||||||
|
TF_ASSIGN_OR_RETURN(auto result_literal, evaluator.Evaluate(*module, {}));
|
||||||
|
// Since the shape_with_output_layout option in ExecutionOption is
|
||||||
|
// non-effective to the Evaluator results, explicit relayout here.
|
||||||
|
if (arg->has_output_layout()) {
|
||||||
|
result_literal = result_literal->Relayout(arg->output_layout());
|
||||||
|
}
|
||||||
|
*result->mutable_literal() = result_literal->ToProto();
|
||||||
|
|
||||||
TF_ASSIGN_OR_RETURN(
|
|
||||||
*result->mutable_output(),
|
|
||||||
ExecuteAndRegisterResult(
|
|
||||||
executable.get(), /*arguments=*/{}, compute_constant_backend_.get(),
|
|
||||||
compute_constant_backend_->default_stream_executor(),
|
|
||||||
"constant computed from " + user_computation->name(),
|
|
||||||
/*profile=*/nullptr));
|
|
||||||
return tensorflow::Status::OK();
|
return tensorflow::Status::OK();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -71,9 +71,9 @@ class ServiceOptions {
|
||||||
int intra_op_parallelism_threads_ = -1;
|
int intra_op_parallelism_threads_ = -1;
|
||||||
};
|
};
|
||||||
|
|
||||||
// The XLA service object, which is the same across all
|
// The XLA service object, which is the same across all platforms. It maintains
|
||||||
// platforms. It maintains the service state of computations and allocations,
|
// the service state of computations and allocations, and delegates
|
||||||
// and delegates target-specific requests to the target-specific infrastructure
|
// target-specific requests to the target-specific infrastructure
|
||||||
// (target-specific compiler, StreamExecutor).
|
// (target-specific compiler, StreamExecutor).
|
||||||
class Service : public ServiceInterface {
|
class Service : public ServiceInterface {
|
||||||
public:
|
public:
|
||||||
|
|
@ -258,8 +258,8 @@ class Service : public ServiceInterface {
|
||||||
|
|
||||||
// The constructor is private. Use the NewService factory to create new
|
// The constructor is private. Use the NewService factory to create new
|
||||||
// service objects.
|
// service objects.
|
||||||
Service(const ServiceOptions& options, std::unique_ptr<Backend> backend,
|
Service(const ServiceOptions& options,
|
||||||
std::unique_ptr<Backend> compute_constant_backend);
|
std::unique_ptr<Backend> execute_backend);
|
||||||
|
|
||||||
static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
|
static StatusOr<std::unique_ptr<Backend>> CreateComputeConstantBackend();
|
||||||
|
|
||||||
|
|
@ -280,16 +280,10 @@ class Service : public ServiceInterface {
|
||||||
const ExecutionOptions* execution_options,
|
const ExecutionOptions* execution_options,
|
||||||
bool has_hybrid_result = false);
|
bool has_hybrid_result = false);
|
||||||
|
|
||||||
// Builds an Executable for the given parameters. If
|
// Builds an Executable for the given parameters.
|
||||||
// executable_for_compute_constant is true, then the executable is intended to
|
|
||||||
// be used for ComputeConstant which means dead parameter instructions are not
|
|
||||||
// included in the executable.The parameter "profile" can optionally point to
|
|
||||||
// an ExecutionProfile object which will be filled in with profile data
|
|
||||||
// relevant to compilation.
|
|
||||||
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
|
StatusOr<std::unique_ptr<Executable>> BuildExecutable(
|
||||||
const VersionedComputationHandle& versioned_handle,
|
const VersionedComputationHandle& versioned_handle,
|
||||||
std::unique_ptr<HloModuleConfig> module_config,
|
std::unique_ptr<HloModuleConfig> module_config,
|
||||||
bool executable_for_compute_constant,
|
|
||||||
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
const tensorflow::gtl::ArraySlice<perftools::gputools::DeviceMemoryBase>
|
||||||
arguments,
|
arguments,
|
||||||
Backend* backend, perftools::gputools::StreamExecutor* executor);
|
Backend* backend, perftools::gputools::StreamExecutor* executor);
|
||||||
|
|
@ -381,9 +375,6 @@ class Service : public ServiceInterface {
|
||||||
// TODO(b/28616830): Support multiple backends for execution.
|
// TODO(b/28616830): Support multiple backends for execution.
|
||||||
std::unique_ptr<Backend> execute_backend_;
|
std::unique_ptr<Backend> execute_backend_;
|
||||||
|
|
||||||
// Backend to use when executing ComputeConstant.
|
|
||||||
std::unique_ptr<Backend> compute_constant_backend_;
|
|
||||||
|
|
||||||
TF_DISALLOW_COPY_AND_ASSIGN(Service);
|
TF_DISALLOW_COPY_AND_ASSIGN(Service);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -72,9 +72,8 @@ class ComputeConstantTest : public ::testing::Test {
|
||||||
StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
|
StatusOr<std::unique_ptr<Literal>> ComputeConstantLiteral(
|
||||||
Client* client, const ComputationDataHandle& operand,
|
Client* client, const ComputationDataHandle& operand,
|
||||||
ComputationBuilder* builder, Layout* output_layout = nullptr) {
|
ComputationBuilder* builder, Layout* output_layout = nullptr) {
|
||||||
TF_ASSIGN_OR_RETURN(auto remote_computed,
|
TF_ASSIGN_OR_RETURN(auto computed,
|
||||||
builder->ComputeConstant(operand, output_layout));
|
builder->ComputeConstant(operand, output_layout));
|
||||||
TF_ASSIGN_OR_RETURN(auto computed, client->Transfer(*remote_computed));
|
|
||||||
return std::move(computed);
|
return std::move(computed);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -253,35 +252,5 @@ XLA_TEST_F(ComputeConstantTest, Layout) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// This test is permanently disabled on CPU because it requires that the
|
|
||||||
// backend used for execution is different than the backend used for
|
|
||||||
// ComputeConstant which is always cpu.
|
|
||||||
TEST_F(ComputeConstantTest, DISABLED_ON_CPU(ReuseComputedConstant)) {
|
|
||||||
// Compute a trivial constant, then try to use the value in an Execute
|
|
||||||
// call. This should fail because the constant resides on the CPU and the
|
|
||||||
// Execute call is executed on a different backend. This test only makes
|
|
||||||
// sense with LocalClient, since CompileOnlyClient does not support
|
|
||||||
// execution.
|
|
||||||
Client* client = ClientOrDie(platform_, ClientType::kLocal);
|
|
||||||
ComputationBuilder constant_b(client, TestName());
|
|
||||||
auto constant = constant_b.ConstantR0<int32>(42);
|
|
||||||
auto handle = constant_b.ComputeConstant(constant).ConsumeValueOrDie();
|
|
||||||
auto literal = client->Transfer(*handle).ConsumeValueOrDie();
|
|
||||||
LiteralTestUtil::ExpectR0Equal(42, *literal);
|
|
||||||
|
|
||||||
// Build trivial computation which takes one parameter.
|
|
||||||
ComputationBuilder b(client, TestName());
|
|
||||||
b.Neg(b.Parameter(0, ShapeUtil::MakeShape(S32, {}), "param0"));
|
|
||||||
auto computation = b.Build().ConsumeValueOrDie();
|
|
||||||
|
|
||||||
// Try to use value from ComputeConstant in Execute.
|
|
||||||
auto execute_status = client->Execute(computation, {handle.get()});
|
|
||||||
EXPECT_FALSE(execute_status.ok());
|
|
||||||
EXPECT_THAT(
|
|
||||||
execute_status.status().error_message(),
|
|
||||||
::testing::ContainsRegex("argument 0 is on device Host:0 but computation "
|
|
||||||
"will be executed on device"));
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
|
|
@ -350,7 +350,9 @@ message ComputeConstantRequest {
|
||||||
}
|
}
|
||||||
|
|
||||||
message ComputeConstantResponse {
|
message ComputeConstantResponse {
|
||||||
GlobalDataHandle output = 1;
|
// A LiteralProto is returned directly for this request, instead of a
|
||||||
|
// ComputationDataHandle.
|
||||||
|
LiteralProto literal = 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
message DeconstructTupleRequest {
|
message DeconstructTupleRequest {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user