mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-07 12:20:24 +01:00
An initial step of eliminating all implicit broadcast at the HLO level.
Guard the shape inference for binary ops behind a flag. PiperOrigin-RevId: 157373647
This commit is contained in:
parent
e78e5ec8a8
commit
5f097217f4
|
|
@ -240,6 +240,18 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "user_computation_flags",
|
||||
srcs = ["user_computation_flags.cc"],
|
||||
hdrs = ["user_computation_flags.h"],
|
||||
deps = [
|
||||
":parse_flags_from_env",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/core:framework_internal",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
filegroup(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,64 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <mutex> // NOLINT(build/c++11): only using std::call_once, not mutex.
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/parse_flags_from_env.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Pointers to the parsed value of the flags and flag descriptors, initialized
|
||||
// via flags_init.
|
||||
static UserComputationFlags* flags;
|
||||
static std::vector<tensorflow::Flag>* flag_list;
|
||||
static std::once_flag flags_init;
|
||||
|
||||
// Allocate *flags. Called via call_once(&flags_init,...).
|
||||
static void AllocateFlags() {
|
||||
flags = new UserComputationFlags;
|
||||
flags->xla_eliminate_hlo_implicit_broadcast = false;
|
||||
flag_list = new std::vector<tensorflow::Flag>({
|
||||
tensorflow::Flag("xla_eliminate_hlo_implicit_broadcast",
|
||||
&flags->xla_eliminate_hlo_implicit_broadcast,
|
||||
"Eliminate implicit broadcast on when lowering user "
|
||||
"computation to HLO instructions, use explicit "
|
||||
"broadcast instead."),
|
||||
});
|
||||
ParseFlagsFromEnv(*flag_list);
|
||||
}
|
||||
|
||||
// Append to *append_to flag definitions associated with XLA's hlo_pass_pipeline
|
||||
// module.
|
||||
void AppendUserComputationFlags(std::vector<tensorflow::Flag>* append_to) {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
append_to->insert(append_to->end(), flag_list->begin(), flag_list->end());
|
||||
}
|
||||
|
||||
// Return a pointer to the UserComputationFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
UserComputationFlags* GetUserComputationFlags() {
|
||||
std::call_once(flags_init, &AllocateFlags);
|
||||
return flags;
|
||||
}
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
|
||||
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
|
||||
|
||||
// Legacy flags for XLA's user_computation module.
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/compiler/xla/types.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
#include "tensorflow/core/util/command_line_flags.h"
|
||||
|
||||
namespace xla {
|
||||
namespace legacy_flags {
|
||||
|
||||
// Append to *flag_list flags definitions associated with XLA's user_computation
|
||||
// module.
|
||||
void AppendUserComputationFlags(std::vector<tensorflow::Flag>* flag_list);
|
||||
|
||||
typedef struct {
|
||||
// Eliminate implicit broadcast on when lowering user computation to HLO
|
||||
// instructions, use explicit broadcast instead.
|
||||
bool xla_eliminate_hlo_implicit_broadcast;
|
||||
} UserComputationFlags;
|
||||
|
||||
// Return a pointer to the UserComputationFlags struct;
|
||||
// repeated calls return the same pointer.
|
||||
// This should be called only after Flags::Parse() has returned.
|
||||
UserComputationFlags* GetUserComputationFlags();
|
||||
|
||||
} // namespace legacy_flags
|
||||
} // namespace xla
|
||||
|
||||
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_LEGACY_FLAGS_USER_COMPUTATION_FLAGS_H_
|
||||
|
|
@ -282,6 +282,7 @@ cc_library(
|
|||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
|
||||
"//tensorflow/core:lib",
|
||||
],
|
||||
)
|
||||
|
|
@ -298,6 +299,7 @@ cc_test(
|
|||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
|
||||
"//tensorflow/compiler/xla/service:hlo",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
|
|
|
|||
|
|
@ -547,7 +547,7 @@ ShapeInference::InferDegenerateDimensionBroadcastShape(
|
|||
return InferDegenerateDimensionBroadcastShape(operation, lhs, rhs);
|
||||
} else {
|
||||
// Ranks do not match, so perform InDim broadcasting using
|
||||
// broadcast_dimensions. Scalar broadcasting is a special case of this).
|
||||
// broadcast_dimensions. Scalar broadcasting is a special case of this.
|
||||
const Shape& larger_shape =
|
||||
ShapeUtil::Rank(lhs) > ShapeUtil::Rank(rhs) ? lhs : rhs;
|
||||
const Shape& smaller_shape =
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include <utility>
|
||||
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/ptr_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
|
|
@ -1887,6 +1888,12 @@ class ComputationLowerer {
|
|||
const ComputationHandle& handle,
|
||||
VersionedComputationHandle::Version version);
|
||||
|
||||
// This function takes an input value which is being implicitly broadcast into
|
||||
// an output shape and figures out the right kBroadcast instruction(s)
|
||||
// necessary to replicate the implicit broadcast semantics explicitly.
|
||||
HloInstruction* ImplicitBroadcastToExplicitBroadcast(
|
||||
HloInstruction* operand, const Shape& output_shape);
|
||||
|
||||
HloComputation::Builder hlo_builder_;
|
||||
const SessionComputation& session_computation_;
|
||||
const VersionedComputationHandle::Version version_;
|
||||
|
|
@ -2204,6 +2211,37 @@ HloComputation* ComputationLowerer::ResolveComputation(
|
|||
return hlo_resolver_(checked_handle);
|
||||
}
|
||||
|
||||
HloInstruction* ComputationLowerer::ImplicitBroadcastToExplicitBroadcast(
|
||||
HloInstruction* operand, const Shape& output_shape) {
|
||||
CHECK(ShapeUtil::IsScalar(operand->shape()) ||
|
||||
ShapeUtil::Rank(operand->shape()) == ShapeUtil::Rank(output_shape));
|
||||
Shape broadcast_shape = ShapeUtil::MakeShape(
|
||||
operand->shape().element_type(), AsInt64Slice(output_shape.dimensions()));
|
||||
// Do explicit broadcast for scalar.
|
||||
if (ShapeUtil::IsScalar(operand->shape())) {
|
||||
return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
broadcast_shape, operand, AsInt64Slice(broadcast_shape.dimensions())));
|
||||
}
|
||||
// Do explicit broadcast for degenerate broadcast.
|
||||
std::vector<int64> broadcast_dimensions;
|
||||
std::vector<int64> reshaped_dimensions;
|
||||
for (int i = 0; i < ShapeUtil::Rank(operand->shape()); i++) {
|
||||
if (operand->shape().dimensions(i) > 1) {
|
||||
broadcast_dimensions.push_back(i);
|
||||
reshaped_dimensions.push_back(operand->shape().dimensions(i));
|
||||
}
|
||||
}
|
||||
// Eliminate the size one dimensions.
|
||||
HloInstruction* reshaped_operand =
|
||||
hlo_builder_.AddInstruction(HloInstruction::CreateReshape(
|
||||
ShapeUtil::MakeShape(operand->shape().element_type(),
|
||||
reshaped_dimensions),
|
||||
operand));
|
||||
// Broadcast 'reshape' up to the larger size.
|
||||
return hlo_builder_.AddInstruction(HloInstruction::CreateBroadcast(
|
||||
broadcast_shape, reshaped_operand, broadcast_dimensions));
|
||||
}
|
||||
|
||||
void ComputationLowerer::Visit(
|
||||
const ComputationDataHandle& handle,
|
||||
std::unordered_map<int64, HloInstruction*>* instructions) {
|
||||
|
|
@ -2629,6 +2667,19 @@ void ComputationLowerer::Visit(
|
|||
lhs = (lhs == operand_to_broadcast) ? broadcasted_operand : lhs;
|
||||
rhs = (rhs == operand_to_broadcast) ? broadcasted_operand : rhs;
|
||||
}
|
||||
if (legacy_flags::GetUserComputationFlags()
|
||||
->xla_eliminate_hlo_implicit_broadcast) {
|
||||
if (!ShapeUtil::SameDimensions(request.output_shape(), lhs->shape())) {
|
||||
// lhs side is being implicitly broadcast. Change to explicit.
|
||||
lhs =
|
||||
ImplicitBroadcastToExplicitBroadcast(lhs, request.output_shape());
|
||||
}
|
||||
|
||||
if (!ShapeUtil::SameDimensions(request.output_shape(), rhs->shape())) {
|
||||
rhs =
|
||||
ImplicitBroadcastToExplicitBroadcast(rhs, request.output_shape());
|
||||
}
|
||||
}
|
||||
hlo_instruction = add_instruction(HloInstruction::CreateBinary(
|
||||
request.output_shape(), hlo_opcode, lhs, rhs));
|
||||
break;
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/compiler/xla/service/user_computation.h"
|
||||
|
||||
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_computation.h"
|
||||
#include "tensorflow/compiler/xla/service/hlo_matchers.h"
|
||||
|
|
@ -143,5 +144,138 @@ TEST_F(UserComputationTest, SimpleComputation) {
|
|||
}
|
||||
}
|
||||
|
||||
TEST_F(UserComputationTest, EliminateScalarBroadcast) {
|
||||
if (!legacy_flags::GetUserComputationFlags()
|
||||
->xla_eliminate_hlo_implicit_broadcast) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a binary computation with scalar broadcast.
|
||||
//
|
||||
// %a = Constant({123, 42})
|
||||
// %b = Constant(1)
|
||||
// %add = Add(%a, %b)
|
||||
ComputationHandle handle;
|
||||
handle.set_handle(123);
|
||||
UserComputation computation("TheComputation", handle);
|
||||
|
||||
ConstantRequest a_request;
|
||||
*a_request.mutable_literal() = *LiteralUtil::CreateR1<float>({123.0f, 42.0f});
|
||||
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
|
||||
computation.AddConstantInstruction(a_request));
|
||||
|
||||
ConstantRequest b_request;
|
||||
*b_request.mutable_literal() = *LiteralUtil::CreateR0<float>(1.0f);
|
||||
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
|
||||
computation.AddConstantInstruction(b_request));
|
||||
|
||||
BinaryOpRequest add;
|
||||
add.set_binop(BINOP_ADD);
|
||||
*add.mutable_lhs() = a_handle;
|
||||
*add.mutable_rhs() = b_handle;
|
||||
TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
|
||||
|
||||
auto hlo_resolver = [](const VersionedComputationHandle& handle) {
|
||||
return nullptr;
|
||||
};
|
||||
VersionedComputationHandle latest_version = computation.GetVersionedHandle();
|
||||
|
||||
// Build the HLO computation.
|
||||
TF_ASSIGN_OR_ASSERT_OK(
|
||||
std::unique_ptr<HloComputation> hlo_computation,
|
||||
computation.BuildHloComputation(latest_version.version, hlo_resolver));
|
||||
// The binary operation has implicit scalar broadcast, should be converted
|
||||
// to an explicit broadcast intruction and a binary instruction.
|
||||
EXPECT_EQ(4, hlo_computation->instruction_count());
|
||||
EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
|
||||
const auto& operands = hlo_computation->root_instruction()->operands();
|
||||
ASSERT_EQ(2, operands.size());
|
||||
EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast ||
|
||||
operands[1]->opcode() == HloOpcode::kBroadcast);
|
||||
}
|
||||
|
||||
TEST_F(UserComputationTest, EliminateDegenerateBroadcastAfterIndimBroadcast) {
|
||||
if (!legacy_flags::GetUserComputationFlags()
|
||||
->xla_eliminate_hlo_implicit_broadcast) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Build a binary computation with in-dim broadcast and degenerate broadcast.
|
||||
//
|
||||
// %a = Param({2, 3});
|
||||
// %b = Param({2, 1, 4});
|
||||
// %add = Add(%a, %b, {0, 1});
|
||||
ComputationHandle handle;
|
||||
handle.set_handle(123);
|
||||
UserComputation computation("TheComputation", handle);
|
||||
|
||||
ParameterRequest a_request;
|
||||
*a_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 3});
|
||||
a_request.set_name("a");
|
||||
a_request.set_parameter(0);
|
||||
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle a_handle,
|
||||
computation.AddParameterInstruction(a_request));
|
||||
|
||||
ParameterRequest b_request;
|
||||
*b_request.mutable_shape() = ShapeUtil::MakeShape(F32, {2, 1, 4});
|
||||
b_request.set_name("b");
|
||||
b_request.set_parameter(1);
|
||||
TF_ASSIGN_OR_ASSERT_OK(ComputationDataHandle b_handle,
|
||||
computation.AddParameterInstruction(b_request));
|
||||
|
||||
BinaryOpRequest add;
|
||||
add.set_binop(BINOP_ADD);
|
||||
*add.mutable_lhs() = a_handle;
|
||||
*add.mutable_rhs() = b_handle;
|
||||
add.add_broadcast_dimensions(0);
|
||||
add.add_broadcast_dimensions(1);
|
||||
TF_ASSERT_OK(computation.AddBinaryInstruction(add).status());
|
||||
|
||||
auto hlo_resolver = [](const VersionedComputationHandle& handle) {
|
||||
return nullptr;
|
||||
};
|
||||
VersionedComputationHandle latest_version = computation.GetVersionedHandle();
|
||||
|
||||
// Build the HLO computation.
|
||||
TF_ASSIGN_OR_ASSERT_OK(
|
||||
std::unique_ptr<HloComputation> hlo_computation,
|
||||
computation.BuildHloComputation(latest_version.version, hlo_resolver));
|
||||
|
||||
// The binary operation has in-dim broadcast and degenerate broadcast, should
|
||||
// first do the in-dim broadcast then convert the degnerate broadcast into a
|
||||
// reshape and a broadcast.
|
||||
//
|
||||
// b a
|
||||
// | |
|
||||
// broadcast reshape
|
||||
// | |
|
||||
// | broadcast
|
||||
// \ /
|
||||
// add
|
||||
EXPECT_EQ(6, hlo_computation->instruction_count());
|
||||
EXPECT_THAT(hlo_computation->root_instruction(), op::Add());
|
||||
const auto& operands = hlo_computation->root_instruction()->operands();
|
||||
ASSERT_EQ(2, operands.size());
|
||||
EXPECT_TRUE(operands[0]->opcode() == HloOpcode::kBroadcast &&
|
||||
operands[1]->opcode() == HloOpcode::kBroadcast);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
xla::legacy_flags::AppendUserComputationFlags(&flag_list);
|
||||
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||
if (!parse_result) {
|
||||
LOG(ERROR) << "\n" << usage;
|
||||
return 2;
|
||||
}
|
||||
testing::InitGoogleTest(&argc, argv);
|
||||
if (argc > 1) {
|
||||
LOG(ERROR) << "Unknown argument " << argv[1] << "\n" << usage;
|
||||
return 2;
|
||||
}
|
||||
return RUN_ALL_TESTS();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -459,16 +459,17 @@ xla_test(
|
|||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:xla_data_proto",
|
||||
"//tensorflow/compiler/xla/client:computation_builder",
|
||||
"//tensorflow/compiler/xla/client:global_data",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -966,13 +967,13 @@ xla_test(
|
|||
"//tensorflow/compiler/xla:array4d",
|
||||
"//tensorflow/compiler/xla:literal_util",
|
||||
"//tensorflow/compiler/xla:statusor",
|
||||
"//tensorflow/compiler/xla:test_helpers",
|
||||
"//tensorflow/compiler/xla:test",
|
||||
"//tensorflow/compiler/xla/client:computation_builder",
|
||||
"//tensorflow/compiler/xla/client:local_client",
|
||||
"//tensorflow/compiler/xla/legacy_flags:cpu_compiler_flags",
|
||||
"//tensorflow/compiler/xla/legacy_flags:user_computation_flags",
|
||||
"//tensorflow/compiler/xla/tests:client_library_test_base",
|
||||
"//tensorflow/compiler/xla/tests:literal_test_util",
|
||||
"//tensorflow/core:test",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/layout_util.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
|
|
@ -1858,6 +1859,7 @@ INSTANTIATE_TEST_CASE_P(ArrayElementwiseOpTestParamCount,
|
|||
int main(int argc, char** argv) {
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
|
||||
xla::legacy_flags::AppendUserComputationFlags(&flag_list);
|
||||
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||
if (!parse_result) {
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ limitations under the License.
|
|||
#include "tensorflow/compiler/xla/client/computation_builder.h"
|
||||
#include "tensorflow/compiler/xla/client/local_client.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/cpu_compiler_flags.h"
|
||||
#include "tensorflow/compiler/xla/legacy_flags/user_computation_flags.h"
|
||||
#include "tensorflow/compiler/xla/literal_util.h"
|
||||
#include "tensorflow/compiler/xla/statusor.h"
|
||||
#include "tensorflow/compiler/xla/test.h"
|
||||
|
|
@ -702,6 +703,7 @@ XLA_TEST_F(BroadcastSimpleTest, InvalidDegenerateBroadcasting) {
|
|||
int main(int argc, char** argv) {
|
||||
std::vector<tensorflow::Flag> flag_list;
|
||||
xla::legacy_flags::AppendCpuCompilerFlags(&flag_list);
|
||||
xla::legacy_flags::AppendUserComputationFlags(&flag_list);
|
||||
xla::string usage = tensorflow::Flags::Usage(argv[0], flag_list);
|
||||
const bool parse_result = tensorflow::Flags::Parse(&argc, argv, flag_list);
|
||||
if (!parse_result) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user