mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
When creating an HloModule from an HloProto construct the HloModuleConfig
with a correct ProgramShape which matches the shapes of the entry computation. Previously the module config had a bogus or default constructed ProgramShape. PiperOrigin-RevId: 173741104
This commit is contained in:
parent
09a89ae57d
commit
45c5118f0e
|
|
@ -204,13 +204,93 @@ HloModuleProto HloModule::ToProto() const {
|
|||
return proto;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Construct a ProgramShape matching the shape of the parameters and root of the
|
||||
// given module's entry computation.
|
||||
StatusOr<ProgramShape> ProgramShapeFromProto(const HloModuleProto& module) {
|
||||
const HloComputationProto* entry_computation = nullptr;
|
||||
for (const HloComputationProto& computation : module.computations()) {
|
||||
if (computation.name() == module.entry_computation_name()) {
|
||||
entry_computation = &computation;
|
||||
break;
|
||||
}
|
||||
}
|
||||
TF_RET_CHECK(entry_computation != nullptr)
|
||||
<< "No computation with entry computation name"
|
||||
<< module.entry_computation_name();
|
||||
|
||||
tensorflow::gtl::FlatMap<int64, std::pair<string, const Shape*>> parameters;
|
||||
const HloInstructionProto* root = nullptr;
|
||||
for (const HloInstructionProto& instruction :
|
||||
entry_computation->instructions()) {
|
||||
if (instruction.name() == entry_computation->root_name()) {
|
||||
TF_RET_CHECK(root == nullptr) << "Entry computation has more than "
|
||||
"one instruction with (root) name "
|
||||
<< instruction.name();
|
||||
root = &instruction;
|
||||
}
|
||||
if (instruction.opcode() == HloOpcodeString(HloOpcode::kParameter)) {
|
||||
TF_RET_CHECK(!ContainsKey(parameters, instruction.parameter_number()))
|
||||
<< "Entry computation has more than one parameter instruction "
|
||||
"with parameter number "
|
||||
<< instruction.parameter_number();
|
||||
parameters[instruction.parameter_number()] = {
|
||||
instruction.parameter_name(), &instruction.shape()};
|
||||
}
|
||||
}
|
||||
TF_RET_CHECK(root != nullptr)
|
||||
<< "Entry computation is missing root instruction named "
|
||||
<< entry_computation->root_name();
|
||||
|
||||
ProgramShape program_shape;
|
||||
*program_shape.mutable_result() = root->shape();
|
||||
for (int64 i = 0; i < parameters.size(); ++i) {
|
||||
TF_RET_CHECK(ContainsKey(parameters, i))
|
||||
<< "Entry computation missing parameter number " << i;
|
||||
const string& name = parameters.at(i).first;
|
||||
const Shape& shape = *parameters.at(i).second;
|
||||
*program_shape.add_parameters() = shape;
|
||||
program_shape.add_parameter_names(name);
|
||||
}
|
||||
|
||||
return std::move(program_shape);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
/* static */
|
||||
StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
||||
const HloModuleProto& proto,
|
||||
const VersionedComputationHandle& entry_computation_handle,
|
||||
const HloModuleConfig& config) {
|
||||
auto module =
|
||||
MakeUnique<HloModule>(proto.name(), entry_computation_handle, config);
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||
const VersionedComputationHandle& entry_computation_handle) {
|
||||
// The ProgramShape in the passed in module config must match the shapes of
|
||||
// the entry parameters and root.
|
||||
TF_ASSIGN_OR_RETURN(ProgramShape expected_program_shape,
|
||||
ProgramShapeFromProto(proto));
|
||||
TF_RET_CHECK(expected_program_shape.parameters_size() ==
|
||||
module_config.entry_computation_layout().parameter_count());
|
||||
for (int i = 0; i < expected_program_shape.parameters_size(); ++i) {
|
||||
const Shape& parameter_shape =
|
||||
module_config.entry_computation_layout().parameter_layout(i).shape();
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::Equal(expected_program_shape.parameters(i), parameter_shape))
|
||||
<< "HloModuleConfig has different shape for parameter " << i
|
||||
<< " than the HLO module. Expected: "
|
||||
<< ShapeUtil::HumanStringWithLayout(
|
||||
expected_program_shape.parameters(i))
|
||||
<< ", actual: " << ShapeUtil::HumanStringWithLayout(parameter_shape);
|
||||
}
|
||||
const Shape& result_shape =
|
||||
module_config.entry_computation_layout().result_layout().shape();
|
||||
TF_RET_CHECK(ShapeUtil::Equal(expected_program_shape.result(), result_shape))
|
||||
<< "HloModuleConfig has different result shape than the HLO module. "
|
||||
"Expected: "
|
||||
<< ShapeUtil::HumanStringWithLayout(expected_program_shape.result())
|
||||
<< ", actual: " << ShapeUtil::HumanStringWithLayout(result_shape);
|
||||
|
||||
auto module = MakeUnique<HloModule>(proto.name(), entry_computation_handle,
|
||||
module_config);
|
||||
|
||||
tensorflow::gtl::FlatMap<string, HloComputation*> computation_map;
|
||||
for (const HloComputationProto& computation_proto : proto.computations()) {
|
||||
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloComputation> computation,
|
||||
|
|
@ -250,6 +330,29 @@ StatusOr<std::unique_ptr<HloModule>> HloModule::CreateFromProto(
|
|||
return std::move(module);
|
||||
}
|
||||
|
||||
/* static */
|
||||
StatusOr<HloModuleConfig> HloModule::CreateModuleConfigFromProto(
|
||||
const HloModuleProto& module) {
|
||||
TF_ASSIGN_OR_RETURN(ProgramShape program_shape,
|
||||
ProgramShapeFromProto(module));
|
||||
|
||||
HloModuleConfig module_config(program_shape);
|
||||
|
||||
// The module config is constructed with default layouts regardless of what is
|
||||
// passed in via the ProgramShape. Set the layouts to the appropriate values.
|
||||
ComputationLayout* entry_layout =
|
||||
module_config.mutable_entry_computation_layout();
|
||||
for (int64 i = 0; i < entry_layout->parameter_count(); ++i) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
entry_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
|
||||
program_shape.parameters(i)));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(entry_layout->mutable_result_layout()->CopyLayoutFromShape(
|
||||
program_shape.result()));
|
||||
|
||||
return module_config;
|
||||
}
|
||||
|
||||
namespace {
|
||||
// Returns whether `hlo` is used outside the given subcomputation.
|
||||
// `instructions_in_subcomputation` is the instruction set of the given
|
||||
|
|
|
|||
|
|
@ -144,9 +144,14 @@ class HloModule {
|
|||
// Convert an HloModule to or from a proto.
|
||||
HloModuleProto ToProto() const;
|
||||
static StatusOr<std::unique_ptr<HloModule>> CreateFromProto(
|
||||
const HloModuleProto& proto,
|
||||
const VersionedComputationHandle& entry_computation_handle,
|
||||
const HloModuleConfig& config);
|
||||
const HloModuleProto& proto, const HloModuleConfig& module_config,
|
||||
const VersionedComputationHandle& entry_computation_handle =
|
||||
VersionedComputationHandle());
|
||||
|
||||
// Creates and returns an HloModuleConfig with an appropriate program shape
|
||||
// for the HLO module in the given proto.
|
||||
static StatusOr<HloModuleConfig> CreateModuleConfigFromProto(
|
||||
const HloModuleProto& module);
|
||||
|
||||
// Outlines the given expression from the given computation.
|
||||
// instructions_to_outline contains the instructions that form the expression.
|
||||
|
|
|
|||
|
|
@ -45,11 +45,12 @@ HloRunner::ReadModuleFromHloProtoFile(const char* filename,
|
|||
HloProto proto;
|
||||
TF_RETURN_IF_ERROR(tensorflow::ReadBinaryProto(tensorflow::Env::Default(),
|
||||
filename, &proto));
|
||||
HloModuleConfig config;
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
HloModuleConfig config,
|
||||
HloModule::CreateModuleConfigFromProto(proto.hlo_module()));
|
||||
config.set_debug_options(debug_options);
|
||||
TF_ASSIGN_OR_RETURN(auto module, HloModule::CreateFromProto(
|
||||
proto.hlo_module(),
|
||||
VersionedComputationHandle(), config));
|
||||
TF_ASSIGN_OR_RETURN(auto module,
|
||||
HloModule::CreateFromProto(proto.hlo_module(), config));
|
||||
return std::move(module);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user