mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-08 07:38:39 +01:00
An initial step of eliminating all implicit broadcast at the HLO level.
Guard the shape inference for binary ops behind a flag. PiperOrigin-RevId: 157373647
This commit is contained in:
parent
e78e5ec8a8
commit
5f097217f4
|
|
@ -240,6 +240,18 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "user_computation_flags",
|
||||||
|
srcs = ["user_computation_flags.cc"],
|
||||||
|
hdrs = ["user_computation_flags.h"],
|
||||||
|
deps = [
|
||||||
|
":parse_flags_from_env",
|
||||||
|
"//tensorflow/compiler/xla:types",
|
||||||
|
"//tensorflow/core:framework_internal",
|
||||||
|
"//tensorflow/core:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# -----------------------------------------------------------------------------
|
# -----------------------------------------------------------------------------
|
||||||
|
|
||||||
filegroup(
|
filegroup(
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,64 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||||
|
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace legacy_flags {
|
||||||
|
|
||||||
|
// Pointers to the parsed value of the flags and flag descriptors, initialized
|
||||||
|
// via flags_init.
|
||||||
|
static UserComputationFlags* flags;
|
||||||
|
static std::vector<tensorflow::Flag>* flag_list;
|
||||||
|
static std::once_flag flags_init;
|
||||||
|
|
||||||
|
// Allocate *flags. Called via call_once(&flags_init,...).
|
||||||
|
static void AllocateFlags() {
|
||||||
|
flags = new UserComputationFlags;
|
||||||
|
flags->xla_eliminate_hlo_implicit_broadcast = false;
|
||||||
|
flag_list = new std::vector<tensorflow::Flag>({
|
||||||
|
tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast",
|
||||||
|
&flags->xla_eliminate_hlo_implicit_broadcast,
|
||||||
|
"Eliminate implicit broadcast on when lowering user "
|
||||||
|
"computation to HLO instructions, use explicit "
|
||||||
|
"broadcast instead."),
|
||||||
|
});
|
||||||
|
ParseFlagsFromEnv(*flag_list);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline
|
||||||
|
// module.
|
||||||
|
void AppendUserComputationFlags(std::vector<tensorflow::Flag>* append_to) {
|
||||||
|
std::call_once(flags_init, &AllocateFlags);
|
||||||
|
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return a pointer to the UserComputationFlags struct;
|
||||||
|
// repeated calls return the same pointer.
|
||||||
|
// This should be called only after Flags::Parse() has returned.
|
||||||
|
UserComputationFlags* GetUserComputationFlags() {
|
||||||
|
std::call_once(flags_init, &AllocateFlags);
|
||||||
|
return flags;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace legacy_flags
|
||||||
|
} // namespace xla
|
||||||
|
|
@ -0,0 +1,48 @@
|
||||||
|
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
|
||||||
|
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
|
||||||
|
|
||||||
|
// Legacy flags for XLA's user_computation module.
|
||||||
|
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/types.h"
|
||||||
|
#include "tensorflow/core/platform/types.h"
|
||||||
|
#include "tensorflow/core/util/command_line_flags.h"
|
||||||
|
|
||||||
|
namespace xla {
|
||||||
|
namespace legacy_flags {
|
||||||
|
|
||||||
|
// Append to *flag_list flags definitions associated with XLA's user_computation
|
||||||
|
// module.
|
||||||
|
void AppendUserComputationFlags(std::vector<tensorflow::Flag>* flag_list);
|
||||||
|
|
||||||
|
typedef struct {
|
||||||
|
// Eliminate implicit broadcast on when lowering user computation to HLO
|
||||||
|
// instructions, use explicit broadcast instead.
|
||||||
|
bool xla_eliminate_hlo_implicit_broadcast;
|
||||||
|
} UserComputationFlags;
|
||||||
|
|
||||||
|
// Return a pointer to the UserComputationFlags struct;
|
||||||
|
// repeated calls return the same pointer.
|
||||||
|
// This should be called only after Flags::Parse() has returned.
|
||||||
|
UserComputationFlags* GetUserComputationFlags();
|
||||||
|
|
||||||
|
} // namespace legacy_flags
|
||||||
|
} // namespace xla
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
|
||||||
|
|
@ -282,6 +282,7 @@ cc_library(
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:util",
|
"//tensorflow/compiler/xla:util",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
|
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -298,6 +299,7 @@ cc_test(
|
||||||
"//tensorflow/compiler/xla:test",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla:test_helpers",
|
"//tensorflow/compiler/xla:test_helpers",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
|
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
|
||||||
"//tensorflow/compiler/xla/service:hlo",
|
"//tensorflow/compiler/xla/service:hlo",
|
||||||
"//tensorflow/core:test",
|
"//tensorflow/core:test",
|
||||||
"//tensorflow/core:test_main",
|
"//tensorflow/core:test_main",
|
||||||
|
|
|
||||||
|
|
@ -547,7 +547,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
||||||
return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
|
return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
|
||||||
} else {
|
} else {
|
||||||
// Ranks do not match, so perform InDim broadcasting using
|
// Ranks do not match, so perform InDim broadcasting using
|
||||||
// broadcast_dimensions. Scalar broadcasting is a special case of this).
|
// broadcast_dimensions. Scalar broadcasting is a special case of this.
|
||||||
const Shape& larger_shape =
|
const Shape& larger_shape =
|
||||||
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
|
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
|
||||||
const Shape& smaller_shape =
|
const Shape& smaller_shape =
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/layout_util.h"
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
|
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
|
|
@ -1887,6 +1888,12 @@ class ComputationLowerer {
|
||||||
const ComputationHandle& handle,
|
const ComputationHandle& handle,
|
||||||
VersionedComputationHandle::Version version);
|
VersionedComputationHandle::Version version);
|
||||||
|
|
||||||
|
// This function takes an input value which is being implicitly broadcast into
|
||||||
|
// an output shape and figures out the right kBroadcast instruction(s)
|
||||||
|
// necessary to replicate the implicit broadcast semantics explicitly.
|
||||||
|
HloInstruction* ImplicitBroadcastToExplicitBroadcast(
|
||||||
|
HloInstruction* operand, const Shape& output_shape);
|
||||||
|
|
||||||
HloComputation::Builder hlo_builder_;
|
HloComputation::Builder hlo_builder_;
|
||||||
const SessionComputation& session_computation_;
|
const SessionComputation& session_computation_;
|
||||||
const VersionedComputationHandle::Version version_;
|
const VersionedComputationHandle::Version version_;
|
||||||
|
|
@ -2204,6 +2211,37 @@ HloComputation* ComputationLowerer::ResolveComputation(
|
||||||
return hlo_resolver_(checked_handle);
|
return hlo_resolver_(checked_handle);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
|
||||||
|
HloInstruction* operand, const Shape& output_shape) {
|
||||||
|
CHECK(ShapeUtil::IsScalar(operand->shape()) ||
|
||||||
|
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
|
||||||
|
Shape broadcast_shape = ShapeUtil::MakeShape(
|
||||||
|
operand->shape().element_type(), AsInt64Slice(output_shape.dimensions()));
|
||||||
|
// Do explicit broadcast for scalar.
|
||||||
|
if (ShapeUtil::IsScalar(operand->shape())) {
|
||||||
|
return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
|
broadcast_shape, operand, AsInt64Slice(broadcast_shape.dimensions())));
|
||||||
|
}
|
||||||
|
// Do explicit broadcast for degenerate broadcast.
|
||||||
|
std::vector<int64> broadcast_dimensions;
|
||||||
|
std::vector<int64> reshaped_dimensions;
|
||||||
|
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
|
||||||
|
if (operand->shape().dimensions(i) > 1) {
|
||||||
|
broadcast_dimensions.push_back(i);
|
||||||
|
reshaped_dimensions.push_back(operand->shape().dimensions(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Eliminate the size one dimensions.
|
||||||
|
HloInstruction* reshaped_operand =
|
||||||
|
hlo_builder_.AddInstruction(HloInstruction::CreateReshape(
|
||||||
|
ShapeUtil::MakeShape(operand->shape().element_type(),
|
||||||
|
reshaped_dimensions),
|
||||||
|
operand));
|
||||||
|
// Broadcast 'reshape' up to the larger size.
|
||||||
|
return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
|
||||||
|
broadcast_shape, reshaped_operand, broadcast_dimensions));
|
||||||
|
}
|
||||||
|
|
||||||
void ComputationLowerer::Visit(
|
void ComputationLowerer::Visit(
|
||||||
const ComputationDataHandle& handle,
|
const ComputationDataHandle& handle,
|
||||||
std::unordered_map<int64, HloInstruction*>* instructions) {
|
std::unordered_map<int64, HloInstruction*>* instructions) {
|
||||||
|
|
@ -2629,6 +2667,19 @@ void ComputationLowerer::Visit(
|
||||||
lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
|
lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
|
||||||
rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
|
rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
|
||||||
}
|
}
|
||||||
|
if (legacy_flags::GetUserComputationFlags()
|
||||||
|
->xla_eliminate_hlo_implicit_broadcast) {
|
||||||
|
if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
|
||||||
|
// lhs side is being implicitly broadcast. Change to explicit.
|
||||||
|
lhs =
|
||||||
|
ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
|
||||||
|
rhs =
|
||||||
|
ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
|
||||||
|
}
|
||||||
|
}
|
||||||
hlo_instruction = add_instruction(HloInstruction::CreateBinary(
|
hlo_instruction = add_instruction(HloInstruction::CreateBinary(
|
||||||
request.output_shape(), hlo_opcode, lhs, rhs));
|
request.output_shape(), hlo_opcode, lhs, rhs));
|
||||||
break;
|
break;
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||||
|
|
||||||
#include "tensorflow/compiler/xla/service/user_computation.h"
|
#include "tensorflow/compiler/xla/service/user_computation.h"
|
||||||
|
|
||||||
|
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||||
|
|
@ -143,5 +144,138 @@ TEST_F(UserComputationTest, SimpleComputation) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST_F(UserComputationTest, EliminateScalarBroadcast) {
|
||||||
|
if (!legacy_flags::GetUserComputationFlags()
|
||||||
|
->xla_eliminate_hlo_implicit_broadcast) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a binary computation with scalar broadcast.
|
||||||
|
//
|
||||||
|
// %a = Constant({123, 42})
|
||||||
|
// %b = Constant(1)
|
||||||
|
// %add = Add(%a, %b)
|
||||||
|
ComputationHandle handle;
|
||||||
|
handle.set_handle(123);
|
||||||
|
UserComputation computation("TheComputation", handle);
|
||||||
|
|
||||||
|
ConstantRequest a_request;
|
||||||
|
*a_request.mutable_literal() = *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
|
||||||
|
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
|
||||||
|
computation.AddConstantInstruction(a_request));
|
||||||
|
|
||||||
|
ConstantRequest b_request;
|
||||||
|
*b_request.mutable_literal() = *LiteralUtil::CreateR0<float>(1.0f);
|
||||||
|
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
|
||||||
|
computation.AddConstantInstruction(b_request));
|
||||||
|
|
||||||
|
BinaryOpRequest add;
|
||||||
|
add.set_binop(BINOP_ADD);
|
||||||
|
*add.mutable_lhs() = a_handle;
|
||||||
|
*add.mutable_rhs() = b_handle;
|
||||||
|
TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
|
||||||
|
|
||||||
|
auto hlo_resolver = [](const VersionedComputationHandle& handle) {
|
||||||
|
return nullptr;
|
||||||
|
};
|
||||||
|
VersionedComputationHandle latest_version = computation.GetVersionedHandle();
|
||||||
|
|
||||||
|
// Build the HLO computation.
|
||||||
|
TF_ASSIGN_OR_ASSERT_OK(
|
||||||
|
std::unique_ptr<HloComputation> hlo_computation,
|
||||||
|
computation.BuildHloComputation(latest_version.version, hlo_resolver));
|
||||||
|
// The binary operation has implicit scalar broadcast, should be converted
|
||||||
|
// to an explicit broadcast intruction and a binary instruction.
|
||||||
|
EXPECT_EQ(4, hlo_computation->instruction_count());
|
||||||
|
EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
|
||||||
|
const auto& operands = hlo_computation->root_instruction()->operands();
|
||||||
|
ASSERT_EQ(2, operands.size());
|
||||||
|
EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast ||
|
||||||
|
operands[1]->opcode() == HloOpcode::kBroadcast);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
|
||||||
|
if (!legacy_flags::GetUserComputationFlags()
|
||||||
|
->xla_eliminate_hlo_implicit_broadcast) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build a binary computation with in-dim broadcast and degenerate broadcast.
|
||||||
|
//
|
||||||
|
// %a = Param({2, 3});
|
||||||
|
// %b = Param({2, 1, 4});
|
||||||
|
// %add = Add(%a, %b, {0, 1});
|
||||||
|
ComputationHandle handle;
|
||||||
|
handle.set_handle(123);
|
||||||
|
UserComputation computation("TheComputation", handle);
|
||||||
|
|
||||||
|
ParameterRequest a_request;
|
||||||
|
*a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3});
|
||||||
|
a_request.set_name("a");
|
||||||
|
a_request.set_parameter(0);
|
||||||
|
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
|
||||||
|
computation.AddParameterInstruction(a_request));
|
||||||
|
|
||||||
|
ParameterRequest b_request;
|
||||||
|
*b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4});
|
||||||
|
b_request.set_name("b");
|
||||||
|
b_request.set_parameter(1);
|
||||||
|
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
|
||||||
|
computation.AddParameterInstruction(b_request));
|
||||||
|
|
||||||
|
BinaryOpRequest add;
|
||||||
|
add.set_binop(BINOP_ADD);
|
||||||
|
*add.mutable_lhs() = a_handle;
|
||||||
|
*add.mutable_rhs() = b_handle;
|
||||||
|
add.add_broadcast_dimensions(0);
|
||||||
|
add.add_broadcast_dimensions(1);
|
||||||
|
TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
|
||||||
|
|
||||||
|
auto hlo_resolver = [](const VersionedComputationHandle& handle) {
|
||||||
|
return nullptr;
|
||||||
|
};
|
||||||
|
VersionedComputationHandle latest_version = computation.GetVersionedHandle();
|
||||||
|
|
||||||
|
// Build the HLO computation.
|
||||||
|
TF_ASSIGN_OR_ASSERT_OK(
|
||||||
|
std::unique_ptr<HloComputation> hlo_computation,
|
||||||
|
computation.BuildHloComputation(latest_version.version, hlo_resolver));
|
||||||
|
|
||||||
|
// The binary operation has in-dim broadcast and degenerate broadcast, should
|
||||||
|
// first do the in-dim broadcast then convert the degnerate broadcast into a
|
||||||
|
// reshape and a broadcast.
|
||||||
|
//
|
||||||
|
// b a
|
||||||
|
// | |
|
||||||
|
// broadcast reshape
|
||||||
|
// | |
|
||||||
|
// | broadcast
|
||||||
|
// \ /
|
||||||
|
// add
|
||||||
|
EXPECT_EQ(6, hlo_computation->instruction_count());
|
||||||
|
EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
|
||||||
|
const auto& operands = hlo_computation->root_instruction()->operands();
|
||||||
|
ASSERT_EQ(2, operands.size());
|
||||||
|
EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast &&
|
||||||
|
operands[1]->opcode() == HloOpcode::kBroadcast);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
||||||
|
int main(int argc, char** argv) {
|
||||||
|
std::vector<tensorflow::Flag> flag_list;
|
||||||
|
xla::legacy_flags::AppendUserComputationFlags(&flag_list);
|
||||||
|
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||||
|
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||||
|
if (!parse_result) {
|
||||||
|
LOG(ERROR) << "\n" << usage;
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
testing::InitGoogleTest(&argc, argv);
|
||||||
|
if (argc > 1) {
|
||||||
|
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
|
||||||
|
return 2;
|
||||||
|
}
|
||||||
|
return RUN_ALL_TESTS();
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -459,16 +459,17 @@ xla_test(
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
"//tensorflow/compiler/xla:shape_util",
|
"//tensorflow/compiler/xla:shape_util",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla:types",
|
"//tensorflow/compiler/xla:types",
|
||||||
"//tensorflow/compiler/xla:xla_data_proto",
|
"//tensorflow/compiler/xla:xla_data_proto",
|
||||||
"//tensorflow/compiler/xla/client:computation_builder",
|
"//tensorflow/compiler/xla/client:computation_builder",
|
||||||
"//tensorflow/compiler/xla/client:global_data",
|
"//tensorflow/compiler/xla/client:global_data",
|
||||||
"//tensorflow/compiler/xla/client:local_client",
|
"//tensorflow/compiler/xla/client:local_client",
|
||||||
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
|
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
|
||||||
|
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
|
||||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||||
"//tensorflow/core:lib",
|
"//tensorflow/core:lib",
|
||||||
"//tensorflow/core:test",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -966,13 +967,13 @@ xla_test(
|
||||||
"//tensorflow/compiler/xla:array4d",
|
"//tensorflow/compiler/xla:array4d",
|
||||||
"//tensorflow/compiler/xla:literal_util",
|
"//tensorflow/compiler/xla:literal_util",
|
||||||
"//tensorflow/compiler/xla:statusor",
|
"//tensorflow/compiler/xla:statusor",
|
||||||
"//tensorflow/compiler/xla:test_helpers",
|
"//tensorflow/compiler/xla:test",
|
||||||
"//tensorflow/compiler/xla/client:computation_builder",
|
"//tensorflow/compiler/xla/client:computation_builder",
|
||||||
"//tensorflow/compiler/xla/client:local_client",
|
"//tensorflow/compiler/xla/client:local_client",
|
||||||
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
|
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
|
||||||
|
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
|
||||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||||
"//tensorflow/core:test",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
#include "tensorflow/compiler/xla/layout_util.h"
|
#include "tensorflow/compiler/xla/layout_util.h"
|
||||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
|
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
|
||||||
|
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
|
@ -1858,6 +1859,7 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
std::vector<tensorflow::Flag> flag_list;
|
std::vector<tensorflow::Flag> flag_list;
|
||||||
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
|
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
|
||||||
|
xla::legacy_flags::AppendUserComputationFlags(&flag_list);
|
||||||
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||||
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||||
if (!parse_result) {
|
if (!parse_result) {
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
|
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
|
||||||
|
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||||
#include "tensorflow/compiler/xla/literal_util.h"
|
#include "tensorflow/compiler/xla/literal_util.h"
|
||||||
#include "tensorflow/compiler/xla/statusor.h"
|
#include "tensorflow/compiler/xla/statusor.h"
|
||||||
#include "tensorflow/compiler/xla/test.h"
|
#include "tensorflow/compiler/xla/test.h"
|
||||||
|
|
@ -702,6 +703,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
|
||||||
int main(int argc, char** argv) {
|
int main(int argc, char** argv) {
|
||||||
std::vector<tensorflow::Flag> flag_list;
|
std::vector<tensorflow::Flag> flag_list;
|
||||||
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
|
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
|
||||||
|
xla::legacy_flags::AppendUserComputationFlags(&flag_list);
|
||||||
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||||
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||||
if (!parse_result) {
|
if (!parse_result) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user