[XLA:GPU] Move buffer sharing logic into a separate target (NFC).

This will allow to use the same CanShareBuffer function in gpu_compiler and
hlo_to_llvm_ir testing tool.

PiperOrigin-RevId: 566615614
This commit is contained in:
Adrian Kuegel 2023-09-19 06:40:18 -07:00 committed by TensorFlower Gardener
parent b9209839ea
commit f4529e80ab
7 changed files with 270 additions and 198 deletions

View File

@ -1,7 +1,21 @@
# Description:
# GPU-specific components in XLA service implementation.
load("//xla/tests:build_defs.bzl", "xla_test")
load("@bazel_skylib//rules:common_settings.bzl", "bool_flag")
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("//xla:xla.bzl", "xla_cc_test")
load(
"//xla/stream_executor:build_defs.bzl",
"if_gpu_is_configured",
)
load(
"@local_config_rocm//rocm:build_defs.bzl",
"if_rocm_hipblaslt",
"if_rocm_is_configured",
)
load("@local_tsl//tsl:tsl.bzl", "if_google", "if_nccl", "tsl_copts", "tsl_gpu_library")
load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable")
load(
"@local_tsl//tsl/platform:build_config.bzl",
"tf_proto_library",
@ -11,24 +25,10 @@ load(
"if_static",
"tf_cuda_tests_tags",
)
load("@local_tsl//tsl:tsl.bzl", "if_google", "if_nccl", "tsl_copts", "tsl_gpu_library")
load(
"@local_config_rocm//rocm:build_defs.bzl",
"if_rocm_hipblaslt",
"if_rocm_is_configured",
)
load("//xla:xla.bzl", "xla_cc_test")
load("//xla/tests:build_defs.bzl", "xla_test")
load(
"//xla/stream_executor:build_defs.bzl",
"if_gpu_is_configured",
)
load(
"@local_tsl//tsl/platform/default:cuda_build_defs.bzl",
"if_cuda_is_configured",
)
load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda")
load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable")
package(
default_visibility = ["//visibility:public"],
@ -183,7 +183,7 @@ xla_cc_test(
srcs = if_gpu_is_configured(["gpu_copy_insertion_test.cc"]),
tags = tf_cuda_tests_tags(),
deps = if_gpu_is_configured([
":gpu_compiler",
":buffer_sharing",
"//xla:test",
"//xla/hlo/ir:hlo",
"//xla/service:copy_insertion",
@ -2562,6 +2562,7 @@ cc_library(
":alias_passthrough_params",
":all_reduce_blueconnect",
":autotuner_util",
":buffer_sharing",
":compile_module_to_llvm_ir",
":conv_layout_normalization",
":copy_fusion",
@ -2801,6 +2802,7 @@ cc_library(
visibility = ["//visibility:public"],
deps = if_cuda_is_configured([
":autotuner_util",
":buffer_sharing",
":cublas_cudnn",
":cublas_pad_for_gemms",
":cublas_padding_requirements",
@ -3442,6 +3444,21 @@ xla_cc_test(
]),
)
cc_library(
name = "buffer_sharing",
srcs = ["buffer_sharing.cc"],
hdrs = ["buffer_sharing.h"],
visibility = ["//visibility:public"],
deps = [
":backend_configs_cc",
":cublas_cudnn",
":ir_emission_utils",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"@com_google_absl//absl/container:flat_hash_set",
],
)
cc_library(
name = "gpu_fusible",
srcs = ["gpu_fusible.cc"],

View File

@ -0,0 +1,182 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/service/gpu/buffer_sharing.h"
#include <cstdint>
#include <optional>
#include <queue>
#include <utility>
#include "absl/container/flat_hash_set.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/ir_emission_utils.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
namespace xla {
namespace gpu {
std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
const HloInstruction* operand,
const ShapeIndex& user_index) {
if (user->opcode() != HloOpcode::kFusion) {
return std::nullopt;
}
// First, do the trivial check: if the fusion operand and the fusion output
// have a different number of elements or have a different element byte size,
// the buffer cannot be shared.
const Shape& user_subshape =
ShapeUtil::GetSubshape(user->shape(), user_index);
const Shape& operand_shape = operand->shape();
const bool shapes_equal = ShapeUtil::Equal(operand_shape, user_subshape);
if (!shapes_equal) {
if (!operand_shape.IsArray() || !user_subshape.IsArray()) {
return false;
}
// We cannot share the buffer if the iteration space is not the same.
if (ShapeUtil::ElementsIn(operand_shape) !=
ShapeUtil::ElementsIn(user_subshape)) {
return false;
}
// The buffers needed for 'user_subshape' and 'operand_shape' need to have
// the same size, otherwise they cannot be shared. We already checked that
// the number of elements are the same, so now we check the number of bytes
// needed for the element types.
if (ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type()) !=
ShapeUtil::ByteSizeOfPrimitiveType(user_subshape.element_type())) {
return false;
}
}
// We need to make sure that the fusion parameter is accessed in the same
// iteration order as the fusion output. Also, there should not be two fusion
// outputs that consume the fusion parameter, because we do not want to share
// the same fusion operand with two different fusion outputs. To make sure
// that the iteration order is the same, we only allow ops on the path from
// fusion parameter to fusion output which are elementwise (no copy) or
// bitcast or an elementwise dynamic update slice (i.e. with the first operand
// being on this path).
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
HloInstruction* output = user->fused_expression_root();
for (int64_t o : user_index) {
output = output->mutable_operand(o);
}
const HloInstruction* non_bitcast_root = output;
if (non_bitcast_root->opcode() == HloOpcode::kBitcast) {
non_bitcast_root = non_bitcast_root->operand(0);
}
std::queue<HloInstruction*> q;
absl::flat_hash_set<HloInstruction*> visited;
q.push(fusion_param);
visited.insert(fusion_param);
bool found_path_to_output = false;
while (!q.empty()) {
HloInstruction* hlo_operand = q.front();
q.pop();
if (hlo_operand == output) {
found_path_to_output = true;
// The output should have at most 1 user: the tuple op (in case of a
// multi-output fusion)
if (hlo_operand->user_count() > 1) {
return false;
}
continue;
}
for (HloInstruction* hlo : hlo_operand->users()) {
if (non_bitcast_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
hlo->opcode() == HloOpcode::kDynamicSlice &&
non_bitcast_root->operand(0) == hlo->operand(0) &&
hlo->shape() == non_bitcast_root->operand(1)->shape()) {
// We can still share the buffer in this case if the same slice is
// accessed by the DUS and the DS. If they don't access the same slice,
// the two slices might partially overlap and read/write the same index
// at different times, and then we cannot guarantee that we read before
// it is overwritten. However if both access only a single element,
// there also can be no race condition.
if (!ShapeUtil::IsEffectiveScalar(hlo->shape()) ||
!ShapeUtil::IsEffectiveScalar(
non_bitcast_root->operand(1)->shape())) {
// Now compare all the slice start operands of 'hlo' and
// 'non_bitcast_root'.
for (int64_t i = 1; i < hlo->operand_count(); ++i) {
if (hlo->operand(i) != non_bitcast_root->operand(i + 1)) {
return false;
}
}
}
} else if ((!hlo->IsElementwiseOnOperand(
hlo->operand_index(hlo_operand)) ||
hlo->opcode() == HloOpcode::kCopy) &&
hlo->opcode() != HloOpcode::kBitcast) {
// This check also catches the case that we reach a different fusion
// output, as that fusion output would have a tuple op as user, which we
// do not allow here.
// Even if 'hlo' is not elementwise on the operand, it is ok if we are
// coming from the second operand and 'hlo' is a DynamicUpdateSlice
// which is the non_bitcast_root. This corresponds to the special case
// above, where we allow a DynamicSlice if it accesses the exact same
// slice than the DynamicUpdateSlice. When we are coming from the first
// operand, IsElementwiseOnOperand() will return true for a
// DynamicUpdateSlice.
if (hlo != non_bitcast_root ||
hlo->opcode() != HloOpcode::kDynamicUpdateSlice ||
hlo->operand_index(hlo_operand) != 1) {
return false;
}
}
if (visited.insert(hlo).second) {
q.push(hlo);
}
}
}
return found_path_to_output;
}
std::optional<bool> CanShareBufferHint(const HloInstruction* user,
const HloInstruction* operand,
const ShapeIndex& user_index) {
switch (user->opcode()) {
case HloOpcode::kAllReduce:
// NCCL all-reduce can be performed in-place.
return user->operand_count() == 1 ||
(user_index.size() == 1 &&
user->operand(user_index[0]) == operand);
case HloOpcode::kCustomCall:
// The matrix bias operand can be overwritten in-place.
if (user->custom_call_target() == kCublasLtMatmulCallTarget) {
GemmBackendConfig config =
std::move(user->backend_config<GemmBackendConfig>()).value();
return (config.beta() != 0.) && user->operand(2) == operand;
}
// The operand of cholesky can be shared with the first output.
if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
return user_index.size() == 1 && user_index[0] == 0;
}
return false;
case HloOpcode::kFusion:
return FusionCanShareBufferHint(user, operand, user_index);
default:
return std::nullopt;
}
}
} // namespace gpu
} // namespace xla

View File

@ -0,0 +1,36 @@
/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef XLA_SERVICE_GPU_BUFFER_SHARING_H_
#define XLA_SERVICE_GPU_BUFFER_SHARING_H_
#include <optional>
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/shape_util.h"
namespace xla {
namespace gpu {
std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
const HloInstruction* operand,
const ShapeIndex& user_index);
std::optional<bool> CanShareBufferHint(const HloInstruction* user,
const HloInstruction* operand,
const ShapeIndex& user_index);
} // namespace gpu
} // namespace xla
#endif // XLA_SERVICE_GPU_BUFFER_SHARING_H_

View File

@ -21,14 +21,12 @@ limitations under the License.
#include <functional>
#include <memory>
#include <optional>
#include <queue>
#include <string>
#include <tuple>
#include <utility>
#include <variant>
#include <vector>
#include "absl/container/flat_hash_set.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/strings/str_cat.h"
@ -1693,123 +1691,5 @@ StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
return result;
}
std::optional<bool> GpuCompiler::FusionCanShareBufferHint(
const HloInstruction* user, const HloInstruction* operand,
const ShapeIndex& user_index) {
if (user->opcode() != HloOpcode::kFusion) {
return std::nullopt;
}
// First, do the trivial check: if the fusion operand and the fusion output
// have a different number of elements or have a different element byte size,
// the buffer cannot be shared.
const Shape& user_subshape =
ShapeUtil::GetSubshape(user->shape(), user_index);
const Shape& operand_shape = operand->shape();
const bool shapes_equal = ShapeUtil::Equal(operand_shape, user_subshape);
if (!shapes_equal) {
if (!operand_shape.IsArray() || !user_subshape.IsArray()) {
return false;
}
// We cannot share the buffer if the iteration space is not the same.
if (ShapeUtil::ElementsIn(operand_shape) !=
ShapeUtil::ElementsIn(user_subshape)) {
return false;
}
// The buffers needed for 'user_subshape' and 'operand_shape' need to have
// the same size, otherwise they cannot be shared. We already checked that
// the number of elements are the same, so now we check the number of bytes
// needed for the element types.
if (ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type()) !=
ShapeUtil::ByteSizeOfPrimitiveType(user_subshape.element_type())) {
return false;
}
}
// We need to make sure that the fusion parameter is accessed in the same
// iteration order as the fusion output. Also, there should not be two fusion
// outputs that consume the fusion parameter, because we do not want to share
// the same fusion operand with two different fusion outputs. To make sure
// that the iteration order is the same, we only allow ops on the path from
// fusion parameter to fusion output which are elementwise (no copy) or
// bitcast or an elementwise dynamic update slice (i.e. with the first operand
// being on this path).
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
HloInstruction* output = user->fused_expression_root();
for (int64_t o : user_index) {
output = output->mutable_operand(o);
}
const HloInstruction* non_bitcast_root = output;
if (non_bitcast_root->opcode() == HloOpcode::kBitcast) {
non_bitcast_root = non_bitcast_root->operand(0);
}
std::queue<HloInstruction*> q;
absl::flat_hash_set<HloInstruction*> visited;
q.push(fusion_param);
visited.insert(fusion_param);
bool found_path_to_output = false;
while (!q.empty()) {
HloInstruction* hlo_operand = q.front();
q.pop();
if (hlo_operand == output) {
found_path_to_output = true;
// The output should have at most 1 user: the tuple op (in case of a
// multi-output fusion)
if (hlo_operand->user_count() > 1) {
return false;
}
continue;
}
for (HloInstruction* hlo : hlo_operand->users()) {
if (non_bitcast_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
hlo->opcode() == HloOpcode::kDynamicSlice &&
non_bitcast_root->operand(0) == hlo->operand(0) &&
hlo->shape() == non_bitcast_root->operand(1)->shape()) {
// We can still share the buffer in this case if the same slice is
// accessed by the DUS and the DS. If they don't access the same slice,
// the two slices might partially overlap and read/write the same index
// at different times, and then we cannot guarantee that we read before
// it is overwritten. However if both access only a single element,
// there also can be no race condition.
if (!ShapeUtil::IsEffectiveScalar(hlo->shape()) ||
!ShapeUtil::IsEffectiveScalar(
non_bitcast_root->operand(1)->shape())) {
// Now compare all the slice start operands of 'hlo' and
// 'non_bitcast_root'.
for (int64_t i = 1; i < hlo->operand_count(); ++i) {
if (hlo->operand(i) != non_bitcast_root->operand(i + 1)) {
return false;
}
}
}
} else if ((!hlo->IsElementwiseOnOperand(
hlo->operand_index(hlo_operand)) ||
hlo->opcode() == HloOpcode::kCopy) &&
hlo->opcode() != HloOpcode::kBitcast) {
// This check also catches the case that we reach a different fusion
// output, as that fusion output would have a tuple op as user, which we
// do not allow here.
// Even if 'hlo' is not elementwise on the operand, it is ok if we are
// coming from the second operand and 'hlo' is a DynamicUpdateSlice
// which is the non_bitcast_root. This corresponds to the special case
// above, where we allow a DynamicSlice if it accesses the exact same
// slice than the DynamicUpdateSlice. When we are coming from the first
// operand, IsElementwiseOnOperand() will return true for a
// DynamicUpdateSlice.
if (hlo != non_bitcast_root ||
hlo->opcode() != HloOpcode::kDynamicUpdateSlice ||
hlo->operand_index(hlo_operand) != 1) {
return false;
}
}
if (visited.insert(hlo).second) {
q.push(hlo);
}
}
}
return found_path_to_output;
}
} // namespace gpu
} // namespace xla

View File

@ -27,6 +27,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/service/executable.h"
#include "xla/service/gpu/autotuner_util.h"
#include "xla/service/gpu/buffer_sharing.h"
#include "xla/service/gpu/executable.pb.h"
#include "xla/service/gpu/gpu_device_info.h"
#include "xla/service/gpu/gpu_executable.h"
@ -170,10 +171,6 @@ class GpuCompiler : public LLVMCompiler {
StatusOr<std::unique_ptr<AotCompilationResult>> Export(
Executable* executable) const override;
static std::optional<bool> FusionCanShareBufferHint(
const HloInstruction* user, const HloInstruction* operand,
const ShapeIndex& user_index);
protected:
// During compilation with device, stream_exec != null and autotune_results
// == null. During deviceless AOT compilation, stream_exec == null and

View File

@ -20,7 +20,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/ir/hlo_opcode.h"
#include "xla/service/copy_insertion.h"
#include "xla/service/gpu/gpu_compiler.h"
#include "xla/service/gpu/buffer_sharing.h"
#include "xla/test.h"
#include "xla/tests/hlo_test_base.h"
@ -110,7 +110,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
CopyInsertion copy_insertion(GpuCompiler::FusionCanShareBufferHint,
CopyInsertion copy_insertion(FusionCanShareBufferHint,
/*use_region_based_live_range_analysis=*/0);
ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status());
VLOG(2) << module->ToString();
@ -142,8 +142,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalTrue(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedBitcastedShape) {
@ -166,8 +165,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalTrue(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest,
@ -191,8 +189,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalTrue(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedMultiOutputFusion) {
@ -217,16 +214,15 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalTrue(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {0}));
ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {0}));
// The second operand cannot share the buffer with the second fusion output,
// because the 'neg' op is also used on the path to the first fusion output.
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(1), {1}));
FusionCanShareBufferHint(fusion, fusion->operand(1), {1}));
// The first operand cannot share the buffer with the second fusion output,
// because there is no path between them.
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {1}));
FusionCanShareBufferHint(fusion, fusion->operand(0), {1}));
}
TEST_F(FusionCanShareBufferHintTest,
@ -250,8 +246,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest, BufferCannotBeSharedShapeBitcastConvert) {
@ -274,8 +269,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest, BufferCannotBeSharedDueToCopy) {
@ -297,8 +291,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest, BufferCannotBeSharedDueToTranspose) {
@ -320,8 +313,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest,
@ -351,8 +343,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest,
@ -383,8 +374,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalTrue(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest,
@ -418,7 +408,7 @@ ENTRY main {
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {0}));
FusionCanShareBufferHint(fusion, fusion->operand(0), {0}));
}
TEST_F(FusionCanShareBufferHintTest,
@ -449,8 +439,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalTrue(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest,
@ -481,8 +470,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
TEST_F(FusionCanShareBufferHintTest,
@ -513,8 +501,7 @@ ENTRY main {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<xla::HloModule> module,
ParseAndReturnVerifiedModule(kModuleString));
HloInstruction* fusion = module->entry_computation()->root_instruction();
ExpectOptionalFalse(
GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {}));
}
} // namespace

View File

@ -37,6 +37,7 @@ limitations under the License.
#include "xla/service/float_normalization.h"
#include "xla/service/float_support.h"
#include "xla/service/gpu/autotuner_util.h"
#include "xla/service/gpu/buffer_sharing.h"
#include "xla/service/gpu/conv_algorithm_picker.h"
#include "xla/service/gpu/cublas_cudnn.h"
#include "xla/service/gpu/cublas_pad_for_gemms.h"
@ -332,34 +333,6 @@ Status NVPTXCompiler::SerializeAutotuneResultsToFile(
}
namespace {
std::optional<bool> CanShareBufferHint(const HloInstruction* user,
const HloInstruction* operand,
const ShapeIndex& user_index) {
switch (user->opcode()) {
case HloOpcode::kAllReduce:
// NCCL all-reduce can be performed in-place.
return user->operand_count() == 1 ||
(user_index.size() == 1 &&
user->operand(user_index[0]) == operand);
case HloOpcode::kCustomCall:
// The matrix bias operand can be overwritten in-place.
if (user->custom_call_target() == kCublasLtMatmulCallTarget) {
GemmBackendConfig config =
std::move(user->backend_config<GemmBackendConfig>()).value();
return (config.beta() != 0.) && user->operand(2) == operand;
}
// The operand of cholesky can be shared with the first output.
if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
return user_index.size() == 1 && user_index[0] == 0;
}
return false;
case HloOpcode::kFusion:
return GpuCompiler::FusionCanShareBufferHint(user, operand, user_index);
default:
return std::nullopt;
}
}
// Try to load ptx from files defined in the FLAGS. If successful, return true.
bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
const HloModule* module, std::string* ptx) {