diff --git a/third_party/xla/xla/service/gpu/BUILD b/third_party/xla/xla/service/gpu/BUILD index 4c10efd9148..02f2eb5deda 100644 --- a/third_party/xla/xla/service/gpu/BUILD +++ b/third_party/xla/xla/service/gpu/BUILD @@ -1,7 +1,21 @@ # Description: # GPU-specific components in XLA service implementation. +load("//xla/tests:build_defs.bzl", "xla_test") load("@bazel_skylib//rules:common_settings.bzl", "bool_flag") +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("//xla:xla.bzl", "xla_cc_test") +load( + "//xla/stream_executor:build_defs.bzl", + "if_gpu_is_configured", +) +load( + "@local_config_rocm//rocm:build_defs.bzl", + "if_rocm_hipblaslt", + "if_rocm_is_configured", +) +load("@local_tsl//tsl:tsl.bzl", "if_google", "if_nccl", "tsl_copts", "tsl_gpu_library") +load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") load( "@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library", @@ -11,24 +25,10 @@ load( "if_static", "tf_cuda_tests_tags", ) -load("@local_tsl//tsl:tsl.bzl", "if_google", "if_nccl", "tsl_copts", "tsl_gpu_library") -load( - "@local_config_rocm//rocm:build_defs.bzl", - "if_rocm_hipblaslt", - "if_rocm_is_configured", -) -load("//xla:xla.bzl", "xla_cc_test") -load("//xla/tests:build_defs.bzl", "xla_test") -load( - "//xla/stream_executor:build_defs.bzl", - "if_gpu_is_configured", -) load( "@local_tsl//tsl/platform/default:cuda_build_defs.bzl", "if_cuda_is_configured", ) -load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") -load("@local_tsl//tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable") package( default_visibility = ["//visibility:public"], @@ -183,7 +183,7 @@ xla_cc_test( srcs = if_gpu_is_configured(["gpu_copy_insertion_test.cc"]), tags = tf_cuda_tests_tags(), deps = if_gpu_is_configured([ - ":gpu_compiler", + ":buffer_sharing", "//xla:test", "//xla/hlo/ir:hlo", "//xla/service:copy_insertion", @@ -2562,6 +2562,7 @@ cc_library( ":alias_passthrough_params", 
":all_reduce_blueconnect", ":autotuner_util", + ":buffer_sharing", ":compile_module_to_llvm_ir", ":conv_layout_normalization", ":copy_fusion", @@ -2801,6 +2802,7 @@ cc_library( visibility = ["//visibility:public"], deps = if_cuda_is_configured([ ":autotuner_util", + ":buffer_sharing", ":cublas_cudnn", ":cublas_pad_for_gemms", ":cublas_padding_requirements", @@ -3442,6 +3444,21 @@ xla_cc_test( ]), ) +cc_library( + name = "buffer_sharing", + srcs = ["buffer_sharing.cc"], + hdrs = ["buffer_sharing.h"], + visibility = ["//visibility:public"], + deps = [ + ":backend_configs_cc", + ":cublas_cudnn", + ":ir_emission_utils", + "//xla:shape_util", + "//xla/hlo/ir:hlo", + "@com_google_absl//absl/container:flat_hash_set", + ], +) + cc_library( name = "gpu_fusible", srcs = ["gpu_fusible.cc"], diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.cc b/third_party/xla/xla/service/gpu/buffer_sharing.cc new file mode 100644 index 00000000000..6cc4e6f962e --- /dev/null +++ b/third_party/xla/xla/service/gpu/buffer_sharing.cc @@ -0,0 +1,182 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+
+#include "xla/service/gpu/buffer_sharing.h"
+
+#include <cstdint>
+#include <optional>
+#include <queue>
+#include <utility>
+
+#include "absl/container/flat_hash_set.h"
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/hlo/ir/hlo_opcode.h"
+#include "xla/service/gpu/backend_configs.pb.h"
+#include "xla/service/gpu/cublas_cudnn.h"
+#include "xla/service/gpu/ir_emission_utils.h"
+#include "xla/shape.h"
+#include "xla/shape_util.h"
+
+namespace xla {
+namespace gpu {
+
+std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
+                                             const HloInstruction* operand,
+                                             const ShapeIndex& user_index) {
+  if (user->opcode() != HloOpcode::kFusion) {
+    return std::nullopt;
+  }
+
+  // First, do the trivial check: if the fusion operand and the fusion output
+  // have a different number of elements or have a different element byte size,
+  // the buffer cannot be shared.
+  const Shape& user_subshape =
+      ShapeUtil::GetSubshape(user->shape(), user_index);
+  const Shape& operand_shape = operand->shape();
+  const bool shapes_equal = ShapeUtil::Equal(operand_shape, user_subshape);
+  if (!shapes_equal) {
+    if (!operand_shape.IsArray() || !user_subshape.IsArray()) {
+      return false;
+    }
+    // We cannot share the buffer if the iteration space is not the same.
+    if (ShapeUtil::ElementsIn(operand_shape) !=
+        ShapeUtil::ElementsIn(user_subshape)) {
+      return false;
+    }
+    // The buffers needed for 'user_subshape' and 'operand_shape' need to have
+    // the same size, otherwise they cannot be shared. We already checked that
+    // the number of elements are the same, so now we check the number of bytes
+    // needed for the element types.
+    if (ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type()) !=
+        ShapeUtil::ByteSizeOfPrimitiveType(user_subshape.element_type())) {
+      return false;
+    }
+  }
+
+  // We need to make sure that the fusion parameter is accessed in the same
+  // iteration order as the fusion output. Also, there should not be two fusion
+  // outputs that consume the fusion parameter, because we do not want to share
+  // the same fusion operand with two different fusion outputs. To make sure
+  // that the iteration order is the same, we only allow ops on the path from
+  // fusion parameter to fusion output which are elementwise (no copy) or
+  // bitcast or an elementwise dynamic update slice (i.e. with the first operand
+  // being on this path).
+  HloInstruction* fusion_param =
+      user->fused_parameter(user->operand_index(operand));
+  HloInstruction* output = user->fused_expression_root();
+  for (int64_t o : user_index) {
+    output = output->mutable_operand(o);
+  }
+  const HloInstruction* non_bitcast_root = output;
+  if (non_bitcast_root->opcode() == HloOpcode::kBitcast) {
+    non_bitcast_root = non_bitcast_root->operand(0);
+  }
+  std::queue<HloInstruction*> q;
+  absl::flat_hash_set<HloInstruction*> visited;
+  q.push(fusion_param);
+  visited.insert(fusion_param);
+  bool found_path_to_output = false;
+  while (!q.empty()) {
+    HloInstruction* hlo_operand = q.front();
+    q.pop();
+    if (hlo_operand == output) {
+      found_path_to_output = true;
+      // The output should have at most 1 user: the tuple op (in case of a
+      // multi-output fusion)
+      if (hlo_operand->user_count() > 1) {
+        return false;
+      }
+      continue;
+    }
+    for (HloInstruction* hlo : hlo_operand->users()) {
+      if (non_bitcast_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
+          hlo->opcode() == HloOpcode::kDynamicSlice &&
+          non_bitcast_root->operand(0) == hlo->operand(0) &&
+          hlo->shape() == non_bitcast_root->operand(1)->shape()) {
+        // We can still share the buffer in this case if the same slice is
+        // accessed by the DUS and the DS. If they don't access the same slice,
+        // the two slices might partially overlap and read/write the same index
+        // at different times, and then we cannot guarantee that we read before
+        // it is overwritten. However if both access only a single element,
+        // there also can be no race condition.
+        if (!ShapeUtil::IsEffectiveScalar(hlo->shape()) ||
+            !ShapeUtil::IsEffectiveScalar(
+                non_bitcast_root->operand(1)->shape())) {
+          // Now compare all the slice start operands of 'hlo' and
+          // 'non_bitcast_root'.
+          for (int64_t i = 1; i < hlo->operand_count(); ++i) {
+            if (hlo->operand(i) != non_bitcast_root->operand(i + 1)) {
+              return false;
+            }
+          }
+        }
+      } else if ((!hlo->IsElementwiseOnOperand(
+                      hlo->operand_index(hlo_operand)) ||
+                  hlo->opcode() == HloOpcode::kCopy) &&
+                 hlo->opcode() != HloOpcode::kBitcast) {
+        // This check also catches the case that we reach a different fusion
+        // output, as that fusion output would have a tuple op as user, which we
+        // do not allow here.
+        // Even if 'hlo' is not elementwise on the operand, it is ok if we are
+        // coming from the second operand and 'hlo' is a DynamicUpdateSlice
+        // which is the non_bitcast_root. This corresponds to the special case
+        // above, where we allow a DynamicSlice if it accesses the exact same
+        // slice than the DynamicUpdateSlice. When we are coming from the first
+        // operand, IsElementwiseOnOperand() will return true for a
+        // DynamicUpdateSlice.
+        if (hlo != non_bitcast_root ||
+            hlo->opcode() != HloOpcode::kDynamicUpdateSlice ||
+            hlo->operand_index(hlo_operand) != 1) {
+          return false;
+        }
+      }
+      if (visited.insert(hlo).second) {
+        q.push(hlo);
+      }
+    }
+  }
+  return found_path_to_output;
+}
+
+std::optional<bool> CanShareBufferHint(const HloInstruction* user,
+                                       const HloInstruction* operand,
+                                       const ShapeIndex& user_index) {
+  switch (user->opcode()) {
+    case HloOpcode::kAllReduce:
+      // NCCL all-reduce can be performed in-place.
+      return user->operand_count() == 1 ||
+             (user_index.size() == 1 &&
+              user->operand(user_index[0]) == operand);
+    case HloOpcode::kCustomCall:
+      // The matrix bias operand can be overwritten in-place.
+      if (user->custom_call_target() == kCublasLtMatmulCallTarget) {
+        GemmBackendConfig config =
+            std::move(user->backend_config<GemmBackendConfig>()).value();
+        return (config.beta() != 0.) && user->operand(2) == operand;
+      }
+      // The operand of cholesky can be shared with the first output.
+      if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
+        return user_index.size() == 1 && user_index[0] == 0;
+      }
+      return false;
+    case HloOpcode::kFusion:
+      return FusionCanShareBufferHint(user, operand, user_index);
+    default:
+      return std::nullopt;
+  }
+}
+
+}  // namespace gpu
+}  // namespace xla
diff --git a/third_party/xla/xla/service/gpu/buffer_sharing.h b/third_party/xla/xla/service/gpu/buffer_sharing.h
new file mode 100644
index 00000000000..5867dd8d031
--- /dev/null
+++ b/third_party/xla/xla/service/gpu/buffer_sharing.h
@@ -0,0 +1,36 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef XLA_SERVICE_GPU_BUFFER_SHARING_H_
+#define XLA_SERVICE_GPU_BUFFER_SHARING_H_
+
+#include <optional>
+
+#include "xla/hlo/ir/hlo_instruction.h"
+#include "xla/shape_util.h"
+
+namespace xla {
+namespace gpu {
+std::optional<bool> FusionCanShareBufferHint(const HloInstruction* user,
+                                             const HloInstruction* operand,
+                                             const ShapeIndex& user_index);
+
+std::optional<bool> CanShareBufferHint(const HloInstruction* user,
+                                       const HloInstruction* operand,
+                                       const ShapeIndex& user_index);
+}  // namespace gpu
+}  // namespace xla
+
+#endif  // XLA_SERVICE_GPU_BUFFER_SHARING_H_
diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.cc b/third_party/xla/xla/service/gpu/gpu_compiler.cc
index 4517edc2219..8476ee029d1 100644
--- a/third_party/xla/xla/service/gpu/gpu_compiler.cc
+++ b/third_party/xla/xla/service/gpu/gpu_compiler.cc
@@ -21,14 +21,12 @@ limitations under the License.
 #include
 #include
 #include
-#include <queue>
 #include
 #include
 #include
 #include
 #include
-#include "absl/container/flat_hash_set.h"
 #include "absl/log/check.h"
 #include "absl/log/log.h"
 #include "absl/strings/str_cat.h"
@@ -1693,123 +1691,5 @@ StatusOr<std::unique_ptr<AotCompilationResult>> GpuCompiler::Export(
   return result;
 }
 
-std::optional<bool> GpuCompiler::FusionCanShareBufferHint(
-    const HloInstruction* user, const HloInstruction* operand,
-    const ShapeIndex& user_index) {
-  if (user->opcode() != HloOpcode::kFusion) {
-    return std::nullopt;
-  }
-
-  // First, do the trivial check: if the fusion operand and the fusion output
-  // have a different number of elements or have a different element byte size,
-  // the buffer cannot be shared.
- const Shape& user_subshape = - ShapeUtil::GetSubshape(user->shape(), user_index); - const Shape& operand_shape = operand->shape(); - const bool shapes_equal = ShapeUtil::Equal(operand_shape, user_subshape); - if (!shapes_equal) { - if (!operand_shape.IsArray() || !user_subshape.IsArray()) { - return false; - } - // We cannot share the buffer if the iteration space is not the same. - if (ShapeUtil::ElementsIn(operand_shape) != - ShapeUtil::ElementsIn(user_subshape)) { - return false; - } - // The buffers needed for 'user_subshape' and 'operand_shape' need to have - // the same size, otherwise they cannot be shared. We already checked that - // the number of elements are the same, so now we check the number of bytes - // needed for the element types. - if (ShapeUtil::ByteSizeOfPrimitiveType(operand_shape.element_type()) != - ShapeUtil::ByteSizeOfPrimitiveType(user_subshape.element_type())) { - return false; - } - } - - // We need to make sure that the fusion parameter is accessed in the same - // iteration order as the fusion output. Also, there should not be two fusion - // outputs that consume the fusion parameter, because we do not want to share - // the same fusion operand with two different fusion outputs. To make sure - // that the iteration order is the same, we only allow ops on the path from - // fusion parameter to fusion output which are elementwise (no copy) or - // bitcast or an elementwise dynamic update slice (i.e. with the first operand - // being on this path). 
-  HloInstruction* fusion_param =
-      user->fused_parameter(user->operand_index(operand));
-  HloInstruction* output = user->fused_expression_root();
-  for (int64_t o : user_index) {
-    output = output->mutable_operand(o);
-  }
-  const HloInstruction* non_bitcast_root = output;
-  if (non_bitcast_root->opcode() == HloOpcode::kBitcast) {
-    non_bitcast_root = non_bitcast_root->operand(0);
-  }
-  std::queue<HloInstruction*> q;
-  absl::flat_hash_set<HloInstruction*> visited;
-  q.push(fusion_param);
-  visited.insert(fusion_param);
-  bool found_path_to_output = false;
-  while (!q.empty()) {
-    HloInstruction* hlo_operand = q.front();
-    q.pop();
-    if (hlo_operand == output) {
-      found_path_to_output = true;
-      // The output should have at most 1 user: the tuple op (in case of a
-      // multi-output fusion)
-      if (hlo_operand->user_count() > 1) {
-        return false;
-      }
-      continue;
-    }
-    for (HloInstruction* hlo : hlo_operand->users()) {
-      if (non_bitcast_root->opcode() == HloOpcode::kDynamicUpdateSlice &&
-          hlo->opcode() == HloOpcode::kDynamicSlice &&
-          non_bitcast_root->operand(0) == hlo->operand(0) &&
-          hlo->shape() == non_bitcast_root->operand(1)->shape()) {
-        // We can still share the buffer in this case if the same slice is
-        // accessed by the DUS and the DS. If they don't access the same slice,
-        // the two slices might partially overlap and read/write the same index
-        // at different times, and then we cannot guarantee that we read before
-        // it is overwritten. However if both access only a single element,
-        // there also can be no race condition.
-        if (!ShapeUtil::IsEffectiveScalar(hlo->shape()) ||
-            !ShapeUtil::IsEffectiveScalar(
-                non_bitcast_root->operand(1)->shape())) {
-          // Now compare all the slice start operands of 'hlo' and
-          // 'non_bitcast_root'.
- for (int64_t i = 1; i < hlo->operand_count(); ++i) { - if (hlo->operand(i) != non_bitcast_root->operand(i + 1)) { - return false; - } - } - } - } else if ((!hlo->IsElementwiseOnOperand( - hlo->operand_index(hlo_operand)) || - hlo->opcode() == HloOpcode::kCopy) && - hlo->opcode() != HloOpcode::kBitcast) { - // This check also catches the case that we reach a different fusion - // output, as that fusion output would have a tuple op as user, which we - // do not allow here. - // Even if 'hlo' is not elementwise on the operand, it is ok if we are - // coming from the second operand and 'hlo' is a DynamicUpdateSlice - // which is the non_bitcast_root. This corresponds to the special case - // above, where we allow a DynamicSlice if it accesses the exact same - // slice than the DynamicUpdateSlice. When we are coming from the first - // operand, IsElementwiseOnOperand() will return true for a - // DynamicUpdateSlice. - if (hlo != non_bitcast_root || - hlo->opcode() != HloOpcode::kDynamicUpdateSlice || - hlo->operand_index(hlo_operand) != 1) { - return false; - } - } - if (visited.insert(hlo).second) { - q.push(hlo); - } - } - } - return found_path_to_output; -} - } // namespace gpu } // namespace xla diff --git a/third_party/xla/xla/service/gpu/gpu_compiler.h b/third_party/xla/xla/service/gpu/gpu_compiler.h index a03c2e2b1e4..42d219c6c63 100644 --- a/third_party/xla/xla/service/gpu/gpu_compiler.h +++ b/third_party/xla/xla/service/gpu/gpu_compiler.h @@ -27,6 +27,7 @@ limitations under the License. 
 #include "xla/hlo/ir/hlo_module.h"
 #include "xla/service/executable.h"
 #include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/buffer_sharing.h"
 #include "xla/service/gpu/executable.pb.h"
 #include "xla/service/gpu/gpu_device_info.h"
 #include "xla/service/gpu/gpu_executable.h"
@@ -170,10 +171,6 @@ class GpuCompiler : public LLVMCompiler {
   StatusOr<std::unique_ptr<AotCompilationResult>> Export(
       Executable* executable) const override;
 
-  static std::optional<bool> FusionCanShareBufferHint(
-      const HloInstruction* user, const HloInstruction* operand,
-      const ShapeIndex& user_index);
-
  protected:
   // During compilation with device, stream_exec != null and autotune_results
   // == null. During deviceless AOT compilation, stream_exec == null and
diff --git a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc
index 24b098e986b..8fe7547b870 100644
--- a/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc
+++ b/third_party/xla/xla/service/gpu/gpu_copy_insertion_test.cc
@@ -20,7 +20,7 @@ limitations under the License.
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_opcode.h" #include "xla/service/copy_insertion.h" -#include "xla/service/gpu/gpu_compiler.h" +#include "xla/service/gpu/buffer_sharing.h" #include "xla/test.h" #include "xla/tests/hlo_test_base.h" @@ -110,7 +110,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); - CopyInsertion copy_insertion(GpuCompiler::FusionCanShareBufferHint, + CopyInsertion copy_insertion(FusionCanShareBufferHint, /*use_region_based_live_range_analysis=*/0); ASSERT_IS_OK(copy_insertion.Run(module.get(), {"foobar"}).status()); VLOG(2) << module->ToString(); @@ -142,8 +142,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalTrue( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedBitcastedShape) { @@ -166,8 +165,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalTrue( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, @@ -191,8 +189,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalTrue( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, BufferCanBeSharedMultiOutputFusion) { @@ -217,16 +214,15 @@ 
ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalTrue( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {0})); + ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {0})); // The second operand cannot share the buffer with the second fusion output, // because the 'neg' op is also used on the path to the first fusion output. ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(1), {1})); + FusionCanShareBufferHint(fusion, fusion->operand(1), {1})); // The first operand cannot share the buffer with the second fusion output, // because there is no path between them. ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {1})); + FusionCanShareBufferHint(fusion, fusion->operand(0), {1})); } TEST_F(FusionCanShareBufferHintTest, @@ -250,8 +246,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, BufferCannotBeSharedShapeBitcastConvert) { @@ -274,8 +269,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, BufferCannotBeSharedDueToCopy) { @@ -297,8 +291,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, 
ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, BufferCannotBeSharedDueToTranspose) { @@ -320,8 +313,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, @@ -351,8 +343,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, @@ -383,8 +374,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalTrue( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, @@ -418,7 +408,7 @@ ENTRY main { ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {0})); + FusionCanShareBufferHint(fusion, fusion->operand(0), {0})); } TEST_F(FusionCanShareBufferHintTest, @@ -449,8 +439,7 @@ 
ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalTrue( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalTrue(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, @@ -481,8 +470,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } TEST_F(FusionCanShareBufferHintTest, @@ -513,8 +501,7 @@ ENTRY main { TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, ParseAndReturnVerifiedModule(kModuleString)); HloInstruction* fusion = module->entry_computation()->root_instruction(); - ExpectOptionalFalse( - GpuCompiler::FusionCanShareBufferHint(fusion, fusion->operand(0), {})); + ExpectOptionalFalse(FusionCanShareBufferHint(fusion, fusion->operand(0), {})); } } // namespace diff --git a/third_party/xla/xla/service/gpu/nvptx_compiler.cc b/third_party/xla/xla/service/gpu/nvptx_compiler.cc index ebb03ebfed6..01cf4fed6a3 100644 --- a/third_party/xla/xla/service/gpu/nvptx_compiler.cc +++ b/third_party/xla/xla/service/gpu/nvptx_compiler.cc @@ -37,6 +37,7 @@ limitations under the License. 
 #include "xla/service/float_normalization.h"
 #include "xla/service/float_support.h"
 #include "xla/service/gpu/autotuner_util.h"
+#include "xla/service/gpu/buffer_sharing.h"
 #include "xla/service/gpu/conv_algorithm_picker.h"
 #include "xla/service/gpu/cublas_cudnn.h"
 #include "xla/service/gpu/cublas_pad_for_gemms.h"
@@ -332,34 +333,6 @@ Status NVPTXCompiler::SerializeAutotuneResultsToFile(
 }
 
 namespace {
-std::optional<bool> CanShareBufferHint(const HloInstruction* user,
-                                       const HloInstruction* operand,
-                                       const ShapeIndex& user_index) {
-  switch (user->opcode()) {
-    case HloOpcode::kAllReduce:
-      // NCCL all-reduce can be performed in-place.
-      return user->operand_count() == 1 ||
-             (user_index.size() == 1 &&
-              user->operand(user_index[0]) == operand);
-    case HloOpcode::kCustomCall:
-      // The matrix bias operand can be overwritten in-place.
-      if (user->custom_call_target() == kCublasLtMatmulCallTarget) {
-        GemmBackendConfig config =
-            std::move(user->backend_config<GemmBackendConfig>()).value();
-        return (config.beta() != 0.) && user->operand(2) == operand;
-      }
-      // The operand of cholesky can be shared with the first output.
-      if (user->custom_call_target() == kCusolverCholeskyCallTarget) {
-        return user_index.size() == 1 && user_index[0] == 0;
-      }
-      return false;
-    case HloOpcode::kFusion:
-      return GpuCompiler::FusionCanShareBufferHint(user, operand, user_index);
-    default:
-      return std::nullopt;
-  }
-}
-
 // Try to load ptx from files defined in the FLAGS. If successful, return true.
 bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
                           const HloModule* module, std::string* ptx) {