When creating an HloModule from an HloProto, construct the HloModuleConfig
with a ProgramShape that matches the parameter and root shapes of the entry
computation. Previously the module config had a bogus or default-constructed
ProgramShape.

PiperOrigin-RevId: 173741104
Mark Heffernan 2017-10-27 18:11:01 -07:00 committed by TensorFlower Gardener
parent 09a89ae57d
commit 45c5118f0e
3 changed files with 121 additions and 12 deletions
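
For reference, a minimal sketch of the call pattern this change enables (it
mirrors the updated hlo_runner.cc below, and assumes the usual XLA headers and
namespace); the wrapper name MakeModuleFromProto is illustrative only:

  StatusOr<std::unique_ptr<HloModule>> MakeModuleFromProto(const HloProto& proto) {
    // Build a config whose ProgramShape matches the proto's entry computation.
    TF_ASSIGN_OR_RETURN(
        HloModuleConfig config,
        HloModule::CreateModuleConfigFromProto(proto.hlo_module()));
    // CreateFromProto verifies the config's parameter and result shapes
    // against the proto and fails if they do not match.
    TF_ASSIGN_OR_RETURN(auto module,
                        HloModule::CreateFromProto(proto.hlo_module(), config));
    return std::move(module);
  }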

@@ -204,13 +204,93 @@ HloModuleProto HloModule::ToProto() const {
return proto;
}
namespace {
// Construct a ProgramShape matching the shape of the parameters and root of the
// given module's entry computation.
StatusOr<ProgramShape> ProgramShapeFromProto(const HloModuleProto& module) {
const HloComputationProto* entry_computation = nullptr;
for (const HloComputationProto& computation : module.computations()) {
if (computation.name() == module.entry_computation_name()) {
entry_computation = &computation;
break;
}
}
TF_RET_CHECK(entry_computation != nullptr)
<< "No computation with entry computation name"
<< module.entry_computation_name();
tensorflow::gtl::FlatMap<int64, std::pair<string, const Shape*>> parameters;
const HloInstructionProto* root = nullptr;
for (const HloInstructionProto& instruction :
entry_computation->instructions()) {
if (instruction.name() == entry_computation->root_name()) {
TF_RET_CHECK(root == nullptr) << "Entry computation has more than "
"one instruction with (root) name "
<< instruction.name();
root = &instruction;
}
if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number()))
<< "Entry computation has more than one parameter instruction "
"with parameter number "
<< instruction.parameter_number();
parameters[instruction.parameter_number()] = {
instruction.parameter_name(), &instruction.shape()};
}
}
TF_RET_CHECK(root != nullptr)
<< "Entry computation is missing root instruction named "
<< entry_computation->root_name();
ProgramShape program_shape;
*program_shape.mutable_result() = root->shape();
for (int64 i = 0; i < parameters.size(); ++i) {
TF_RET_CHECK(ContainsKey(parameters, i))
<< "Entry computation missing parameter number " << i;
const string& name = parameters.at(i).first;
const Shape& shape = *parameters.at(i).second;
*program_shape.add_parameters() = shape;
program_shape.add_parameter_names(name);
}
return std::move(program_shape);
}
} // namespace
/* static */
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
-    const HloModuleProto& proto,
-    const VersionedComputationHandle& entry_computation_handle,
-    const HloModuleConfig& config) {
-  auto module =
-      MakeUnique<HloModule>(proto.name(), entry_computation_handle, config);
+    const HloModuleProto& proto, const HloModuleConfig& module_config,
+    const VersionedComputationHandle& entry_computation_handle) {
// The ProgramShape in the passed in module config must match the shapes of
// the entry parameters and root.
TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape,
ProgramShapeFromProto(proto));
TF_RET_CHECK(expected_program_shape.parameters_size() ==
module_config.entry_computation_layout().parameter_count());
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
const Shape& parameter_shape =
module_config.entry_computation_layout().parameter_layout(i).shape();
TF_RET_CHECK(
ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape))
<< "HloModuleConfig has different shape for parameter " << i
<< " than the HLO module. Expected: "
<< ShapeUtil::HumanStringWithLayout(
expected_program_shape.parameters(i))
<< ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
}
const Shape& result_shape =
module_config.entry_computation_layout().result_layout().shape();
TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape))
<< "HloModuleConfig has different result shape than the HLO module. "
"Expected: "
<< ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
<< ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
module_config);
tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
for (const HloComputationProto& computation_proto : proto.computations()) {
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
@@ -250,6 +330,29 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
return std::move(module);
}
/* static */
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
const HloModuleProto& module) {
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
ProgramShapeFromProto(module));
HloModuleConfig module_config(program_shape);
// The module config is constructed with default layouts regardless of what is
// passed in via the ProgramShape. Set the layouts to the appropriate values.
ComputationLayout* entry_layout =
module_config.mutable_entry_computation_layout();
for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
TF_RETURN_IF_ERROR(
entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
program_shape.parameters(i)));
}
TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
program_shape.result()));
return module_config;
}
namespace {
// Returns whether `hlo` is used outside the given subcomputation.
// `instructions_in_subcomputation` is the instruction set of the given

@@ -144,9 +144,14 @@ class HloModule {
// Convert an HloModule to or from a proto.
HloModuleProto ToProto() const;
static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
-      const HloModuleProto& proto,
-      const VersionedComputationHandle& entry_computation_handle,
-      const HloModuleConfig& config);
+      const HloModuleProto& proto, const HloModuleConfig& module_config,
+      const VersionedComputationHandle& entry_computation_handle =
+          VersionedComputationHandle());
+  // Creates and returns an HloModuleConfig with an appropriate program shape
+  // for the HLO module in the given proto.
+  static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
+      const HloModuleProto& module);
// Outlines the given expression from the given computation.
// instructions_to_outline contains the instructions that form the expression.

@@ -45,11 +45,12 @@ HloRunner::ReadModuleFromHloProtoFile(const char* filename,
HloProto proto;
TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
filename, &proto));
-  HloModuleConfig config;
+  TF_ASSIGN_OR_RETURN(
+      HloModuleConfig config,
+      HloModule::CreateModuleConfigFromProto(proto.hlo_module()));
  config.set_debug_options(debug_options);
-  TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(
-                                       proto.hlo_module(),
-                                       VersionedComputationHandle(), config));
+  TF_ASSIGN_OR_RETURN(auto module,
+                      HloModule::CreateFromProto(proto.hlo_module(), config));
return std::move(module);
}
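
As a usage note, a hedged sketch of calling the updated reader follows. The
static-method call, the StatusOr<std::unique_ptr<HloModule>> return type, the
default-constructed DebugOptions, and the file path are assumptions made for
illustration; the snippet belongs inside a function returning a Status, as
elsewhere in this change:

  xla::DebugOptions debug_options;  // default debug options for the sketch
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModule> module,
      xla::HloRunner::ReadModuleFromHloProtoFile("/tmp/hlo_module.pb",
                                                 debug_options));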