An initial step toward eliminating all implicit broadcasts at the HLO level.

Guard the new explicit-broadcast lowering for binary ops behind a flag.

PiperOrigin-RevId: 157373647
A. Unique TensorFlower 2017-05-29 01:02:53 -07:00 committed by TensorFlower Gardener
parent e78e5ec8a8
commit 5f097217f4
10 changed files with 320 additions and 4 deletions

View File

@@ -240,6 +240,18 @@ cc_library(
],
)
cc_library(
name = "user_computation_flags",
srcs = ["user_computation_flags.cc"],
hdrs = ["user_computation_flags.h"],
deps = [
":parse_flags_from_env",
"//tensorflow/compiler/xla:types",
"//tensorflow/core:framework_internal",
"//tensorflow/core:lib",
],
)
# -----------------------------------------------------------------------------
filegroup(

View File

@@ -0,0 +1,64 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
#include <vector>
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Pointers to the parsed value of the flags and flag descriptors, initialized
// via flags_init.
static UserComputationFlags* flags;
static std::vector<tensorflow::Flag>* flag_list;
static std::once_flag flags_init;
// Allocate *flags. Called via call_once(&flags_init,...).
static void AllocateFlags() {
flags = new UserComputationFlags;
flags->xla_eliminate_hlo_implicit_broadcast = false;
flag_list = new std::vector<tensorflow::Flag>({
tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast",
&flags->xla_eliminate_hlo_implicit_broadcast,
"Eliminate implicit broadcast on when lowering user "
"computation to HLO instructions, use explicit "
"broadcast instead."),
});
ParseFlagsFromEnv(*flag_list);
}
// Append to *append_to flag definitions associated with XLA's user_computation
// module.
void AppendUserComputationFlags(std::vector<tensorflow::Flag>* append_to) {
std::call_once(flags_init, &AllocateFlags);
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
}
// Return a pointer to the UserComputationFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
UserComputationFlags* GetUserComputationFlags() {
std::call_once(flags_init, &AllocateFlags);
return flags;
}
} // namespace legacy_flags
} // namespace xla
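For orientation, callers read the flag through the accessor defined above. A minimal sketch of a call site (the helper name here is hypothetical) mirroring the check added to user_computation.cc later in this diff:

#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"

// Hypothetical helper: the first call to GetUserComputationFlags() allocates
// the struct and parses the flag from the environment via ParseFlagsFromEnv();
// later calls return the same pointer.
bool ShouldLowerImplicitBroadcastExplicitly() {
  return xla::legacy_flags::GetUserComputationFlags()
      ->xla_eliminate_hlo_implicit_broadcast;
}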

View File

@@ -0,0 +1,48 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
// Legacy flags for XLA's user_computation module.
#include <vector>
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/command_line_flags.h"
namespace xla {
namespace legacy_flags {
// Append to *flag_list flag definitions associated with XLA's user_computation
// module.
void AppendUserComputationFlags(std::vector<tensorflow::Flag>* flag_list);
typedef struct {
// Eliminate implicit broadcast when lowering user computation to HLO
// instructions; use explicit broadcast instead.
bool xla_eliminate_hlo_implicit_broadcast;
} UserComputationFlags;
// Return a pointer to the UserComputationFlags struct;
// repeated calls return the same pointer.
// This should be called only after Flags::Parse() has returned.
UserComputationFlags* GetUserComputationFlags();
} // namespace legacy_flags
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_

View File

@@ -282,6 +282,7 @@ cc_library(
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
"//tensorflow/core:lib",
],
)
@@ -298,6 +299,7 @@ cc_test(
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
"//tensorflow/compiler/xla/service:hlo",
"//tensorflow/core:test",
"//tensorflow/core:test_main",

View File

@@ -547,7 +547,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
} else {
// Ranks do not match, so perform InDim broadcasting using
// broadcast_dimensions. Scalar broadcasting is a special case of this).
// broadcast_dimensions. Scalar broadcasting is a special case of this.
const Shape& larger_shape =
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
const Shape& smaller_shape =

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
@@ -1887,6 +1888,12 @@ class ComputationLowerer {
const ComputationHandle& handle,
VersionedComputationHandle::Version version);
// This function takes an input value which is being implicitly broadcast into
// an output shape and figures out the right kBroadcast instruction(s)
// necessary to replicate the implicit broadcast semantics explicitly.
HloInstruction* ImplicitBroadcastToExplicitBroadcast(
HloInstruction* operand, const Shape& output_shape);
HloComputation::Builder hlo_builder_;
const SessionComputation& session_computation_;
const VersionedComputationHandle::Version version_;
@@ -2204,6 +2211,37 @@ HloComputation* ComputationLowerer::ResolveComputation(
return hlo_resolver_(checked_handle);
}
HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
HloInstruction* operand, const Shape& output_shape) {
CHECK(ShapeUtil::IsScalar(operand->shape()) ||
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
Shape broadcast_shape = ShapeUtil::MakeShape(
operand->shape().element_type(), AsInt64Slice(output_shape.dimensions()));
// Do explicit broadcast for scalar.
if (ShapeUtil::IsScalar(operand->shape())) {
return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
broadcast_shape, operand, AsInt64Slice(broadcast_shape.dimensions())));
}
// Do explicit broadcast for degenerate broadcast.
std::vector<int64> broadcast_dimensions;
std::vector<int64> reshaped_dimensions;
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
if (operand->shape().dimensions(i) > 1) {
broadcast_dimensions.push_back(i);
reshaped_dimensions.push_back(operand->shape().dimensions(i));
}
}
// Eliminate the size one dimensions.
HloInstruction* reshaped_operand =
hlo_builder_.AddInstruction(HloInstruction::CreateReshape(
ShapeUtil::MakeShape(operand->shape().element_type(),
reshaped_dimensions),
operand));
// Broadcast 'reshaped_operand' up to the output shape.
return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
broadcast_shape, reshaped_operand, broadcast_dimensions));
}
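To make the dimension bookkeeping above concrete: an operand of shape [2, 1, 4] implicitly broadcast to an output of shape [2, 3, 4] keeps dimensions 0 and 2, so it is reshaped to [2, 4] and then broadcast back with dimensions {0, 2}. The following standalone sketch (plain int64_t vectors instead of XLA's Shape; not part of the commit) models only that selection step:

#include <cstdint>
#include <iostream>
#include <vector>

// Simplified model of the loop in ImplicitBroadcastToExplicitBroadcast:
// keep only the non-degenerate (size > 1) operand dimensions and record
// where each kept dimension lives; those indices become the kBroadcast's
// broadcast_dimensions.
int main() {
  const std::vector<int64_t> operand_dims = {2, 1, 4};  // implicitly broadcast operand
  std::vector<int64_t> broadcast_dimensions;            // indices of kept dimensions
  std::vector<int64_t> reshaped_dimensions;             // shape after dropping size-1 dims
  for (int i = 0; i < static_cast<int>(operand_dims.size()); ++i) {
    if (operand_dims[i] > 1) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand_dims[i]);
    }
  }
  // Prints "reshape to: 2 4" and "broadcast dims: 0 2"; the reshaped operand
  // is then broadcast into the full output shape, e.g. {2, 3, 4}.
  std::cout << "reshape to:";
  for (int64_t d : reshaped_dimensions) std::cout << " " << d;
  std::cout << "\nbroadcast dims:";
  for (int64_t d : broadcast_dimensions) std::cout << " " << d;
  std::cout << std::endl;
}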
void ComputationLowerer::Visit(
const ComputationDataHandle& handle,
std::unordered_map<int64, HloInstruction*>* instructions) {
@@ -2629,6 +2667,19 @@ void ComputationLowerer::Visit(
lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
}
if (legacy_flags::GetUserComputationFlags()
->xla_eliminate_hlo_implicit_broadcast) {
if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
// The lhs is being implicitly broadcast; change it to an explicit broadcast.
lhs =
ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
}
if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
rhs =
ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
}
}
hlo_instruction = add_instruction(HloInstruction::CreateBinary(
request.output_shape(), hlo_opcode, lhs, rhs));
break;

View File

@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/service/user_computation.h"
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
@@ -143,5 +144,138 @@ TEST_F(UserComputationTest, SimpleComputation) {
}
}
TEST_F(UserComputationTest, EliminateScalarBroadcast) {
if (!legacy_flags::GetUserComputationFlags()
->xla_eliminate_hlo_implicit_broadcast) {
return;
}
// Build a binary computation with scalar broadcast.
//
// %a = Constant({123, 42})
// %b = Constant(1)
// %add = Add(%a, %b)
ComputationHandle handle;
handle.set_handle(123);
UserComputation computation("TheComputation", handle);
ConstantRequest a_request;
*a_request.mutable_literal() = *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
computation.AddConstantInstruction(a_request));
ConstantRequest b_request;
*b_request.mutable_literal() = *LiteralUtil::CreateR0<float>(1.0f);
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
computation.AddConstantInstruction(b_request));
BinaryOpRequest add;
add.set_binop(BINOP_ADD);
*add.mutable_lhs() = a_handle;
*add.mutable_rhs() = b_handle;
TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
auto hlo_resolver = [](const VersionedComputationHandle& handle) {
return nullptr;
};
VersionedComputationHandle latest_version = computation.GetVersionedHandle();
// Build the HLO computation.
TF_ASSIGN_OR_ASSERT_OK(
std::unique_ptr<HloComputation> hlo_computation,
computation.BuildHloComputation(latest_version.version, hlo_resolver));
// The binary operation has an implicit scalar broadcast and should be
// converted to an explicit broadcast instruction plus a binary instruction.
EXPECT_EQ(4, hlo_computation->instruction_count());
EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
const auto& operands = hlo_computation->root_instruction()->operands();
ASSERT_EQ(2, operands.size());
EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast ||
operands[1]->opcode() == HloOpcode::kBroadcast);
}
TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
if (!legacy_flags::GetUserComputationFlags()
->xla_eliminate_hlo_implicit_broadcast) {
return;
}
// Build a binary computation with in-dim broadcast and degenerate broadcast.
//
// %a = Param({2, 3});
// %b = Param({2, 1, 4});
// %add = Add(%a, %b, {0, 1});
ComputationHandle handle;
handle.set_handle(123);
UserComputation computation("TheComputation", handle);
ParameterRequest a_request;
*a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3});
a_request.set_name("a");
a_request.set_parameter(0);
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
computation.AddParameterInstruction(a_request));
ParameterRequest b_request;
*b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4});
b_request.set_name("b");
b_request.set_parameter(1);
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
computation.AddParameterInstruction(b_request));
BinaryOpRequest add;
add.set_binop(BINOP_ADD);
*add.mutable_lhs() = a_handle;
*add.mutable_rhs() = b_handle;
add.add_broadcast_dimensions(0);
add.add_broadcast_dimensions(1);
TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
auto hlo_resolver = [](const VersionedComputationHandle& handle) {
return nullptr;
};
VersionedComputationHandle latest_version = computation.GetVersionedHandle();
// Build the HLO computation.
TF_ASSIGN_OR_ASSERT_OK(
std::unique_ptr<HloComputation> hlo_computation,
computation.BuildHloComputation(latest_version.version, hlo_resolver));
// The binary operation has an in-dim broadcast and a degenerate broadcast; it
// should first do the in-dim broadcast and then convert the degenerate
// broadcast into a reshape and a broadcast.
//
//       b           a
//       |           |
//   broadcast    reshape
//       |           |
//       |       broadcast
//        \         /
//           add
EXPECT_EQ(6, hlo_computation->instruction_count());
EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
const auto& operands = hlo_computation->root_instruction()->operands();
ASSERT_EQ(2, operands.size());
EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast &&
operands[1]->opcode() == HloOpcode::kBroadcast);
}
} // namespace
} // namespace xla
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendUserComputationFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {
LOG(ERROR) << "\n" << usage;
return 2;
}
testing::InitGoogleTest(&argc, argv);
if (argc > 1) {
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
return 2;
}
return RUN_ALL_TESTS();
}
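Note that both tests above return early when the flag is left at its default of false. A hedged sketch (not part of the commit; the helper name is hypothetical) of flipping it in-process for local experimentation; passing --xla_eliminate_hlo_implicit_broadcast=true to a binary whose main() registers the flag, as this one does, should have the same effect:

#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"

// Hypothetical test-only toggle: GetUserComputationFlags() hands back a
// pointer to the single process-wide flag struct, so setting the field
// before BuildHloComputation() runs enables the explicit-broadcast lowering
// for the rest of the process.
void EnableExplicitBroadcastForTest() {
  xla::legacy_flags::GetUserComputationFlags()
      ->xla_eliminate_hlo_implicit_broadcast = true;
}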

View File

@@ -459,16 +459,17 @@ xla_test(
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:xla_data_proto",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:global_data",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:lib",
"//tensorflow/core:test",
],
)
@@ -966,13 +967,13 @@ xla_test(
"//tensorflow/compiler/xla:array4d",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:statusor",
"//tensorflow/compiler/xla:test_helpers",
"//tensorflow/compiler/xla:test",
"//tensorflow/compiler/xla/client:computation_builder",
"//tensorflow/compiler/xla/client:local_client",
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
"//tensorflow/compiler/xla/tests:client_library_test_base",
"//tensorflow/compiler/xla/tests:literal_test_util",
"//tensorflow/core:test",
],
)

View File

@@ -27,6 +27,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -1858,6 +1859,7 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
xla::legacy_flags::AppendUserComputationFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {

View File

@@ -22,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/test.h"
@@ -702,6 +703,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
int main(int argc, char** argv) {
std::vector<tensorflow::Flag> flag_list;
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
xla::legacy_flags::AppendUserComputationFlags(&flag_list);
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
if (!parse_result) {