mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Merged commit includes the following changes:
646570256 by A. Unique TensorFlower<gardener@tensorflow.org>:
Put back type aliases for some 3p projects until they're migrated off of xla::Status.
--
646567067 by A. Unique TensorFlower<gardener@tensorflow.org>:
[xla:cpu] Optimize KernelThunk alignment checks
--
646562233 by A. Unique TensorFlower<gardener@tensorflow.org>:
Automated rollback of changelist 609005660.
646560125 by A. Unique TensorFlower<gardener@tensorflow.org>:
[XLA:CollectivePipeliner] Add more execution tests (using the HLOs in collective_pipeliner_test.cc).
--
646554714 by A. Unique TensorFlower<gardener@tensorflow.org>:
Instead of copybara rules, use `if_google` to remove extra proto deps
Followup will do the same for TSL
--
646551061 by A. Unique TensorFlower<gardener@tensorflow.org>:
Remove `Array::Reshard`
This CL removes the deprecated `Array::Reshard` API. All existing users have been manually migrated to use `Client::CopyArrays`.
IFRT Proxy is updated such that the client no longer issues `Array::Reshard` and the server emulates the reshard behavior by using `Client::CopyArrays`. Since this does not actually change the wire format, we do not need to update the version number. Once the reshard API passes the compatibility window, we can remove its proto message and handler altogether.
--
646545951 by A. Unique TensorFlower<gardener@tensorflow.org>:
Add license header to `dependabot.yml`
--
646545541 by A. Unique TensorFlower<gardener@tensorflow.org>:
Remove force_synchronous attribute from ParallelMap op in map_parallelization optimizer.
The code reuses the attributes/inputs of the original Map op but just changes it to a ParallelMap op. But the force_synchronous attribute is not supported in ParallelMap and causes log warnings.
The issue was introduced in cl/642418430
--
646534280 by A. Unique TensorFlower<gardener@tensorflow.org>:
Use absl::StatusOr instead of xla::StatusOr.
--
646517068 by A. Unique TensorFlower<gardener@tensorflow.org>:
Add more pattern to HloUnstacker pass + some refactoring.
Added a support for handling slicing fusion pattern:
fusion(stacked_operand, loop_iteration_var), calls=fusion_computation
fusion_computation {
p0 = parameter(0)
p1 = parameter(1)
slice = dynamic_slice(p0, p1, zero, ...)
ROOT bitcast = bitcast(slice)
}
Add "xla_enable_hlo_unstacker" flag to the compiler.
--
646513305 by A. Unique TensorFlower<gardener@tensorflow.org>:
Remove unused deps.
--
646513101 by A. Unique TensorFlower<gardener@tensorflow.org>:
[xla:cpu] Add a fast path for executing thunks sequentially
--
646507520 by A. Unique TensorFlower<gardener@tensorflow.org>:
Added a fingerprint field to PjRtStreamExecutorLoadedExecutable to avoid recalculating fingerprints when FingerprintExecutable() is called. This change significantly reduces idle time before execution when the GPU load tracker enqueues an executable.
--
646505763 by A. Unique TensorFlower<gardener@tensorflow.org>:
Change visibility rules.
--
646505592 by A. Unique TensorFlower<gardener@tensorflow.org>:
[XLA:GPU] Parse block-level parameters from backend config when available.
If block-level parameters are not available, fall back to the SoftMax heuristic.
The original plan was to parse block-level parameters from the config and remove the heuristic, but it turned out that we don't support all "valid" tiling. With this change it will be easier to write tests and verify that we don't have problem, before we could remove the heuristic and fully migrate to fusion backend config.
Also fix strides in ir_emitter_triton.cc. This was not a problem before, because SoftMax heuristic only produces tiles that are contiguous in memory.
--
646505352 by A. Unique TensorFlower<gardener@tensorflow.org>:
[xla:cpu] Add dynamic-update-slice fusion optimization to IrEmitter2
+ enable select-and-scatter test that used to time out without DUS optimization
--
646504512 by A. Unique TensorFlower<gardener@tensorflow.org>:
PR #62472: Hash Pin docker images
Imported from GitHub PR https://github.com/tensorflow/tensorflow/pull/62472
Also related to https://github.com/tensorflow/tensorflow/pull/62471, would you consider hash pin the docker images?
The security benefit of doing so is that it mitigates the risk of typosquatting attacks since the images are public. If there is a need for them to be updated regularly, I can also submit a .github/dependabot file to update the docker images regularly (weekly or monthly for example).
Besides, AFAIUC, the dockerfiles are used for build and tests, which lead to another benefit of hash pinning: reliability and stability.
Let me know your thoughts about i.
Thanks!
Copybara import of the project:
--
8f4589fe583518d3099c98215e5e6bf3858fa24e by Joyce Brum <joycebrum@google.com>:
feat: create dependabot
Signed-off-by: Joyce Brum <joycebrum@google.com>
Merging this change closes #62472
--
PiperOrigin-RevId: 646570256
This commit is contained in:
parent
1b5fc0570e
commit
1081403fee
44
.github/dependabot.yml
vendored
Normal file
44
.github/dependabot.yml
vendored
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
# Copyright 2024 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
version: 2
|
||||
updates:
|
||||
- package-ecosystem: github-actions
|
||||
directory: /
|
||||
schedule:
|
||||
interval: monthly
|
||||
groups:
|
||||
github-actions:
|
||||
patterns:
|
||||
- "*"
|
||||
|
||||
- package-ecosystem: docker
|
||||
directory: /ci/devinfra/docker_windows
|
||||
schedule:
|
||||
interval: monthly
|
||||
|
||||
- package-ecosystem: docker
|
||||
directory: /ci/official/containers/linux_arm64
|
||||
schedule:
|
||||
interval: monthly
|
||||
|
||||
- package-ecosystem: docker
|
||||
directory: /tensorflow/tools/gcs_test
|
||||
schedule:
|
||||
interval: monthly
|
||||
|
||||
- package-ecosystem: docker
|
||||
directory: /tensorflow/tools/tf_sig_build_dockerfiles
|
||||
schedule:
|
||||
interval: monthly
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
FROM mcr.microsoft.com/dotnet/framework/sdk:4.8-windowsservercore-ltsc2019
|
||||
FROM mcr.microsoft.com/dotnet/framework/sdk:4.8-windowsservercore-ltsc2019@sha256:46e393cbb7c915c504a810639e35f40cb516f8e886e4cbcf8a3b49f86705a070
|
||||
|
||||
# Set default powershell policy for this script (ProgressPreference='SilentlyContinue' makes
|
||||
# downloads with Invoke-WebRequest not show the progress bar and is MUCH faster).
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
################################################################################
|
||||
FROM ubuntu:20.04 as builder
|
||||
FROM ubuntu:20.04@sha256:874aca52f79ae5f8258faff03e10ce99ae836f6e7d2df6ecd3da5c1cad3a912b as builder
|
||||
################################################################################
|
||||
|
||||
# Install devtoolset build dependencies
|
||||
|
|
@ -23,7 +23,7 @@ COPY apt.conf /etc/apt/
|
|||
RUN /build_patchelf.sh
|
||||
|
||||
################################################################################
|
||||
FROM nvidia/cuda:12.3.1-devel-ubuntu20.04 as devel
|
||||
FROM nvidia/cuda:12.3.1-devel-ubuntu20.04@sha256:befbdfddbb52727f9ce8d0c574cac0f631c606b1e6f0e523f3a0777fe2720c99 as devel
|
||||
################################################################################
|
||||
COPY --from=builder /dt10 /dt10
|
||||
COPY --from=builder /patchelf/patchelf_0.14.3-1_arm64.deb /patchelf/patchelf_0.14.3-1_arm64.deb
|
||||
|
|
|
|||
|
|
@ -866,7 +866,7 @@ Status RunPjRtExecutable(
|
|||
xla::PjRtLocalDeviceId(pjrt_device_id)));
|
||||
|
||||
gpu::GpuServingDeviceSelectorResource* device_selector_resource = nullptr;
|
||||
if (device_type == DEVICE_GPU && gpu::kUseGpuServingDeviceSelector) {
|
||||
if (device_type == DEVICE_GPU) {
|
||||
auto rm = ctx->resource_manager();
|
||||
TF_RETURN_IF_ERROR(rm->LookupOrCreate<
|
||||
gpu::GpuServingDeviceSelectorResource>(
|
||||
|
|
|
|||
|
|
@ -32,9 +32,6 @@ namespace gpu {
|
|||
class GpuServingDeviceSelector;
|
||||
const char kGpuServingDeviceSelectorResourceName[] =
|
||||
"gpu_serving_device_selector";
|
||||
// TODO(b/335729939): Disable GPU load tracker for performance regression
|
||||
// investigation. Remove when fixed.
|
||||
const bool kUseGpuServingDeviceSelector = false;
|
||||
|
||||
class GpuServingDeviceSelectorResource : public ResourceBase {
|
||||
public:
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ NodeDef MakeParallelMap(const string& name, MutableGraphView* graph) {
|
|||
auto* num_parallel_calls = graph_utils::AddScalarConstNode(
|
||||
static_cast<int64_t>(data::model::kAutotune), graph);
|
||||
parallel_map.add_input(num_parallel_calls->name());
|
||||
parallel_map.mutable_attr()->erase("force_synchronous");
|
||||
AddNodeAttr("deterministic", "true", ¶llel_map);
|
||||
|
||||
return parallel_map;
|
||||
|
|
|
|||
|
|
@ -294,13 +294,10 @@ cc_library(
|
|||
deps = [
|
||||
":macros",
|
||||
"//tensorflow/lite:allocation",
|
||||
"//tensorflow/lite:mutable_op_resolver",
|
||||
"//tensorflow/lite:stderr_reporter",
|
||||
"//tensorflow/lite:string",
|
||||
"//tensorflow/lite/core/api:error_reporter",
|
||||
"//tensorflow/lite/core/api:op_resolver",
|
||||
"//tensorflow/lite/core/api:verifier",
|
||||
"//tensorflow/lite/core/c:common",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@flatbuffers",
|
||||
|
|
|
|||
|
|
@ -33,13 +33,9 @@ limitations under the License.
|
|||
|
||||
#include "tensorflow/lite/allocation.h"
|
||||
#include "tensorflow/lite/core/api/error_reporter.h"
|
||||
#include "tensorflow/lite/core/api/op_resolver.h"
|
||||
#include "tensorflow/lite/core/api/verifier.h"
|
||||
#include "tensorflow/lite/core/c/common.h"
|
||||
#include "tensorflow/lite/mutable_op_resolver.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/stderr_reporter.h"
|
||||
#include "tensorflow/lite/string_type.h"
|
||||
|
||||
namespace tflite {
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
FROM ubuntu:16.04
|
||||
FROM ubuntu:16.04@sha256:1f1a2d56de1d604801a9671f301190704c25d604a416f59e03c04f5c6ffee0d6
|
||||
|
||||
LABEL maintainer="Shanqing Cai <cais@google.com>"
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
################################################################################
|
||||
FROM ubuntu:22.04 as builder
|
||||
FROM ubuntu:22.04@sha256:a6d2b38300ce017add71440577d5b0a90460d0e57fd7aec21dd0d1b0761bbfb2 as builder
|
||||
################################################################################
|
||||
|
||||
# Install devtoolset build dependencies
|
||||
|
|
@ -16,7 +16,7 @@ COPY builder.devtoolset/glibc2.17-inline.patch /glibc2.17-inline.patch
|
|||
RUN /build_devtoolset.sh devtoolset-9 /dt9
|
||||
|
||||
################################################################################
|
||||
FROM nvidia/cuda:12.3.1-base-ubuntu22.04 as devel
|
||||
FROM nvidia/cuda:12.3.1-base-ubuntu22.04@sha256:6a7febf317514458233b87819ce47d5441357dd7763e91800c35f6745f34bbbd as devel
|
||||
################################################################################
|
||||
COPY --from=builder /dt9 /dt9
|
||||
|
||||
|
|
|
|||
|
|
@ -66,10 +66,10 @@ limitations under the License.
|
|||
|
||||
namespace tsl {
|
||||
namespace {
|
||||
constexpr char kGcsUriBase[] = "https://www.googleapis.com./storage/v1/";
|
||||
constexpr char kGcsUriBase[] = "https://www.googleapis.com/storage/v1/";
|
||||
constexpr char kGcsUploadUriBase[] =
|
||||
"https://www.googleapis.com./upload/storage/v1/";
|
||||
constexpr char kStorageHost[] = "storage.googleapis.com.";
|
||||
"https://www.googleapis.com/upload/storage/v1/";
|
||||
constexpr char kStorageHost[] = "storage.googleapis.com";
|
||||
constexpr char kBucketMetadataLocationKey[] = "location";
|
||||
constexpr size_t kReadAppendableFileBufferSize = 1024 * 1024; // In bytes.
|
||||
constexpr int kGetChildrenDefaultPageSize = 1000;
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load Diff
13
third_party/xla/xla/BUILD
vendored
13
third_party/xla/xla/BUILD
vendored
|
|
@ -8,7 +8,7 @@ load("//third_party/compute_library:build_defs.bzl", "if_enable_acl")
|
|||
|
||||
# Placeholder: load py_proto_library
|
||||
load("//xla:xla.bzl", "xla_cc_test", "xla_py_proto_library")
|
||||
load("//xla/tsl:tsl.bzl", "internal_visibility")
|
||||
load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility")
|
||||
load("//xla/tsl:tsl.default.bzl", "filegroup", "get_compatible_with_portable")
|
||||
|
||||
package(
|
||||
|
|
@ -103,7 +103,7 @@ tf_proto_library(
|
|||
protodeps = [
|
||||
":xla_data_proto",
|
||||
"//xla/service:hlo_proto",
|
||||
],
|
||||
] + if_google(["@com_google_protobuf//:any"]),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
|
|
@ -304,9 +304,7 @@ cc_library(
|
|||
deprecation = "Use @com_google_absl//absl/status instead.",
|
||||
visibility = ["//visibility:public"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -1209,7 +1207,12 @@ tf_proto_library(
|
|||
name = "autotuning_proto",
|
||||
srcs = ["autotuning.proto"],
|
||||
make_default_target_header_only = True,
|
||||
protodeps = ["@local_tsl//tsl/protobuf:dnn_proto"],
|
||||
protodeps = [
|
||||
"@local_tsl//tsl/protobuf:dnn_proto",
|
||||
] + if_google([
|
||||
"@com_google_protobuf//:any",
|
||||
"@com_google_protobuf//:duration",
|
||||
]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ cc_library(
|
|||
srcs = ["platform_id.cc"],
|
||||
hdrs = ["platform_id.h"],
|
||||
deps = ["//xla/stream_executor"] + if_static(
|
||||
["@com_google_protobuf//:protobuf"],
|
||||
["@com_google_protobuf//:any_cc_proto"],
|
||||
["@com_google_protobuf//:protobuf_headers"],
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -2287,6 +2287,7 @@ PjRtStreamExecutorLoadedExecutable::PjRtStreamExecutorLoadedExecutable(
|
|||
TransferManager* transfer_manager =
|
||||
client_->client()->backend().transfer_manager();
|
||||
executables_.reserve(executables.size());
|
||||
tsl::Fprint128 fingerprint = tsl::Fingerprint128(fingerprint_);
|
||||
for (auto& executable : executables) {
|
||||
const auto& computation_layout =
|
||||
executable->executable()->module().entry_computation_layout();
|
||||
|
|
@ -2296,10 +2297,14 @@ PjRtStreamExecutorLoadedExecutable::PjRtStreamExecutorLoadedExecutable(
|
|||
parameter_shapes.push_back(transfer_manager->HostShapeToDeviceShape(
|
||||
computation_layout.parameter_shape(i)));
|
||||
}
|
||||
fingerprint = tsl::FingerprintCat128(
|
||||
fingerprint,
|
||||
tsl::Fingerprint128(executable->executable()->module().ToString()));
|
||||
executables_.emplace_back(std::move(executable));
|
||||
on_device_executable_parameter_shapes_.push_back(
|
||||
std::move(parameter_shapes));
|
||||
}
|
||||
fingerprint_ = absl::StrCat(fingerprint.low64, fingerprint.high64);
|
||||
|
||||
int num_partitions;
|
||||
if (device_assignment_ == nullptr) {
|
||||
|
|
@ -3251,22 +3256,6 @@ PjRtStreamExecutorLoadedExecutable::GetOutputMemoryKinds() const {
|
|||
return Unimplemented("GetOutputMemoryKinds is not supported.");
|
||||
}
|
||||
|
||||
absl::StatusOr<std::string>
|
||||
PjRtStreamExecutorLoadedExecutable::FingerprintExecutable() const {
|
||||
if (executables_.size() != 1) {
|
||||
return absl::InternalError(
|
||||
"Fingerprinting multiple executables within one "
|
||||
"PjRtStreamExecutorLoadedExecutable is not supported.");
|
||||
}
|
||||
|
||||
Executable* executable = executables_[0]->executable();
|
||||
if (executable->has_module()) {
|
||||
return executable->module().GetFingerprint128();
|
||||
} else {
|
||||
return absl::InternalError("Executable does not have HLO modules.");
|
||||
}
|
||||
}
|
||||
|
||||
absl::StatusOr<PjRtStreamExecutorClient::ExecutableExtras>
|
||||
PjRtStreamExecutorClient::GetExecutableExtras(CompileOptions* options) {
|
||||
ExecutableExtras extras;
|
||||
|
|
|
|||
|
|
@ -998,7 +998,9 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
|||
return compile_options_;
|
||||
}
|
||||
|
||||
absl::StatusOr<std::string> FingerprintExecutable() const override;
|
||||
absl::StatusOr<std::string> FingerprintExecutable() const override {
|
||||
return fingerprint_;
|
||||
};
|
||||
|
||||
protected:
|
||||
bool parameter_is_tupled_arguments() const {
|
||||
|
|
@ -1077,6 +1079,7 @@ class PjRtStreamExecutorLoadedExecutable : public PjRtLoadedExecutable {
|
|||
// addressable_device_logical_ids_[i] is assigned. shared_ptrs instead of
|
||||
// unique_ptrs to play well with the Python bindings (see xla.cc).
|
||||
std::vector<PjRtDevice*> addressable_devices_;
|
||||
std::string fingerprint_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
|
|
|||
5
third_party/xla/xla/python/BUILD
vendored
5
third_party/xla/xla/python/BUILD
vendored
|
|
@ -410,7 +410,6 @@ cc_library(
|
|||
"@local_tsl//tsl/profiler/lib:profiler_session",
|
||||
"@local_tsl//tsl/profiler/lib:traceme",
|
||||
"@local_tsl//tsl/profiler/protobuf:profiled_instructions_proto_cc",
|
||||
"@com_google_protobuf//:protobuf",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
] + if_cuda([
|
||||
|
|
@ -419,7 +418,9 @@ cc_library(
|
|||
"//xla/stream_executor/cuda:cuda_driver",
|
||||
]) + if_rocm([
|
||||
"@local_config_rocm//rocm:rocm_headers",
|
||||
]) + if_cuda_or_rocm([":py_client_gpu"]), # TODO(b/337876408): remove after migration to plugin
|
||||
]) + if_cuda_or_rocm([
|
||||
":py_client_gpu", # TODO(b/337876408): remove after migration to plugin
|
||||
]) + if_google(["@com_google_protobuf//:any_cc_proto"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
|||
23
third_party/xla/xla/python/ifrt/array.h
vendored
23
third_party/xla/xla/python/ifrt/array.h
vendored
|
|
@ -122,29 +122,6 @@ class Array : public llvm::RTTIExtends<Array, Value> {
|
|||
void* data, std::optional<absl::Span<const int64_t>> byte_strides,
|
||||
ArrayCopySemantics semantics) = 0;
|
||||
|
||||
// Copies the array with a new sharding, creating a new array.
|
||||
//
|
||||
// Resharding falls into one of the three cases:
|
||||
//
|
||||
// * Metadata-only resharding: Use a new sharding for the array that expects
|
||||
// the same physical layout of underlying buffers on the same devices.
|
||||
// * 1-to-1 buffer copy: Copy individual buffers to different devices without
|
||||
// altering their physical layout.
|
||||
// * M-to-N buffer resharding: Shuffle the buffer data across the boundary of
|
||||
// the buffers, changing their physical layout.
|
||||
//
|
||||
// Implementations may return `UNIMPLEMENTED` if they do not know how to copy
|
||||
// or reshuffle the data to match the new sharding.
|
||||
//
|
||||
// It may fail if the buffer data would be sent from/to an unaddressable
|
||||
// device.
|
||||
//
|
||||
// TODO(b/343992694): Remove this API in favor of `Client::CopyArrays`.
|
||||
ABSL_DEPRECATED("Use `Client::CopyArrays` instead")
|
||||
virtual absl::StatusOr<tsl::RCReference<Array>> Reshard(
|
||||
std::shared_ptr<const Sharding> new_sharding,
|
||||
ArrayCopySemantics semantics) = 0;
|
||||
|
||||
static char ID; // NOLINT
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -445,70 +445,6 @@ TEST(ArrayImplTest, AssembleAndDisassembleSingleDeviceArray) {
|
|||
ElementsAreArray(array->sharding().devices().devices()));
|
||||
}
|
||||
|
||||
TEST(ArrayImplTest, ReshardToSameSharding) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
|
||||
|
||||
DType dtype(DType::kF32);
|
||||
Shape shape({2, 3});
|
||||
std::vector<float> data(6);
|
||||
std::iota(data.begin(), data.end(), 0);
|
||||
Device* device = client->addressable_devices().at(0);
|
||||
std::shared_ptr<const Sharding> sharding =
|
||||
SingleDeviceSharding::Create(device, MemoryKind());
|
||||
auto semantics = Client::HostBufferSemantics::kImmutableOnlyDuringCall;
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto array, client->MakeArrayFromHostBuffer(
|
||||
data.data(), dtype, shape,
|
||||
/*byte_strides=*/std::nullopt, sharding, semantics,
|
||||
/*on_done_with_host_buffer=*/{}));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto resharded_array,
|
||||
array->Reshard(sharding, ArrayCopySemantics::kAlwaysCopy));
|
||||
|
||||
std::vector<float> out_data(6);
|
||||
auto future = resharded_array->CopyToHostBuffer(
|
||||
out_data.data(), /*byte_strides=*/std::nullopt,
|
||||
ArrayCopySemantics::kAlwaysCopy);
|
||||
TF_ASSERT_OK(future.Await());
|
||||
EXPECT_THAT(out_data, ElementsAreArray(data));
|
||||
}
|
||||
|
||||
TEST(ArrayImplTest, ReshardToDifferentDevice) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
|
||||
|
||||
DType dtype(DType::kF32);
|
||||
Shape shape({2, 3});
|
||||
std::vector<float> data(6);
|
||||
std::iota(data.begin(), data.end(), 0);
|
||||
Device* device = client->addressable_devices().at(0);
|
||||
std::shared_ptr<const Sharding> sharding =
|
||||
SingleDeviceSharding::Create(device, MemoryKind());
|
||||
auto semantics = Client::HostBufferSemantics::kImmutableOnlyDuringCall;
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto array, client->MakeArrayFromHostBuffer(
|
||||
data.data(), dtype, shape,
|
||||
/*byte_strides=*/std::nullopt, sharding, semantics,
|
||||
/*on_done_with_host_buffer=*/{}));
|
||||
|
||||
Device* new_device = client->addressable_devices().at(1);
|
||||
std::shared_ptr<const Sharding> new_sharding =
|
||||
SingleDeviceSharding::Create(new_device, MemoryKind());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
auto resharded_array,
|
||||
array->Reshard(new_sharding, ArrayCopySemantics::kAlwaysCopy));
|
||||
|
||||
std::vector<float> out_data(6);
|
||||
auto future = resharded_array->CopyToHostBuffer(
|
||||
out_data.data(), /*byte_strides=*/std::nullopt,
|
||||
ArrayCopySemantics::kAlwaysCopy);
|
||||
TF_ASSERT_OK(future.Await());
|
||||
EXPECT_THAT(out_data, ElementsAreArray(data));
|
||||
}
|
||||
|
||||
TEST(ArrayImplTest, CopyToSameDevices) {
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto client, test_util::GetClient());
|
||||
|
||||
|
|
|
|||
5
third_party/xla/xla/python/ifrt/mock.cc
vendored
5
third_party/xla/xla/python/ifrt/mock.cc
vendored
|
|
@ -95,11 +95,6 @@ MockArray::MockArray(tsl::RCReference<xla::ifrt::Array> delegated)
|
|||
ArrayCopySemantics semantics) {
|
||||
return delegated_->CopyToHostBuffer(data, byte_strides, semantics);
|
||||
});
|
||||
ON_CALL(*this, Reshard)
|
||||
.WillByDefault([this](std::shared_ptr<const Sharding> new_sharding,
|
||||
ArrayCopySemantics semantics) {
|
||||
return delegated_->Reshard(std::move(new_sharding), semantics);
|
||||
});
|
||||
}
|
||||
// LINT.ThenChange()
|
||||
|
||||
|
|
|
|||
4
third_party/xla/xla/python/ifrt/mock.h
vendored
4
third_party/xla/xla/python/ifrt/mock.h
vendored
|
|
@ -88,10 +88,6 @@ class MockArray : public llvm::RTTIExtends<MockArray, Array> {
|
|||
std::optional<absl::Span<const int64_t>> byte_strides,
|
||||
ArrayCopySemantics semantics),
|
||||
(final));
|
||||
MOCK_METHOD(absl::StatusOr<tsl::RCReference<Array>>, Reshard,
|
||||
(std::shared_ptr<const Sharding> new_sharding,
|
||||
ArrayCopySemantics semantics),
|
||||
(final));
|
||||
// LINT.ThenChange(mock.cc:MockArrayDelegation)
|
||||
|
||||
tsl::RCReference<xla::ifrt::Array> delegated() const { return delegated_; }
|
||||
|
|
|
|||
|
|
@ -127,9 +127,11 @@ class Array final : public llvm::RTTIExtends<Array, xla::ifrt::Array> {
|
|||
void* data, std::optional<absl::Span<const int64_t>> byte_strides,
|
||||
ArrayCopySemantics semantics) override;
|
||||
|
||||
// This will be deleted once the client requires the minimum version of 3.
|
||||
ABSL_DEPRECATED("Use `Client::CopyArrays` instead")
|
||||
absl::StatusOr<tsl::RCReference<xla::ifrt::Array>> Reshard(
|
||||
std::shared_ptr<const Sharding> new_sharding,
|
||||
ArrayCopySemantics semantics) override;
|
||||
ArrayCopySemantics semantics);
|
||||
|
||||
static char ID; // NOLINT
|
||||
|
||||
|
|
|
|||
|
|
@ -227,8 +227,15 @@ Client::CopyArrays(absl::Span<tsl::RCReference<xla::ifrt::Array>> arrays,
|
|||
TF_ASSIGN_OR_RETURN(
|
||||
auto new_sharding,
|
||||
array->sharding().WithDeviceAssignment(devices, memory_kind));
|
||||
TF_ASSIGN_OR_RETURN(new_arrays.emplace_back(),
|
||||
array->Reshard(std::move(new_sharding), semantics));
|
||||
if (auto* const proxy_array =
|
||||
llvm::dyn_cast<xla::ifrt::proxy::Array>(array.get())) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
new_arrays.emplace_back(),
|
||||
proxy_array->Reshard(std::move(new_sharding), semantics));
|
||||
} else {
|
||||
return absl::InvalidArgumentError(
|
||||
"Unsupported array type for xla::ifrt::proxy::Client::CopyArrays");
|
||||
}
|
||||
}
|
||||
return new_arrays;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -67,7 +67,7 @@ tf_proto_library(
|
|||
srcs = ["ifrt_service.proto"],
|
||||
protodeps = [
|
||||
":types_proto",
|
||||
# copybara:uncomment "//google/protobuf:any",
|
||||
# copybara:uncomment "@com_google_protobuf//:any",
|
||||
"//xla:xla_data_proto",
|
||||
"//xla/pjrt:execute_options_proto",
|
||||
"//xla/python/ifrt:dtype_proto",
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ message IfrtRequest {
|
|||
disassemble_into_single_device_arrays_request = 7;
|
||||
DeleteArrayRequest delete_array_request = 9;
|
||||
CopyArraysRequest copy_arrays_request = 24;
|
||||
ReshardRequest reshard_request = 10;
|
||||
ReshardRequest reshard_request = 10 [deprecated = true];
|
||||
FullyReplicatedShardRequest fully_replicated_shard_request = 20;
|
||||
IsArrayDeletedRequest is_array_deleted_request = 11;
|
||||
DestructArrayRequest destruct_array_request = 12;
|
||||
|
|
@ -102,7 +102,7 @@ message IfrtResponse {
|
|||
disassemble_into_single_device_arrays_response = 7;
|
||||
DeleteArrayResponse delete_array_response = 9;
|
||||
CopyArraysResponse copy_arrays_response = 24;
|
||||
ReshardResponse reshard_response = 10;
|
||||
ReshardResponse reshard_response = 10 [deprecated = true];
|
||||
FullyReplicatedShardResponse fully_replicated_shard_response = 20;
|
||||
IsArrayDeletedResponse is_array_deleted_response = 11;
|
||||
DestructArrayResponse destruct_array_response = 12;
|
||||
|
|
|
|||
|
|
@ -697,13 +697,24 @@ absl::StatusOr<BackendInterface::Response> IfrtBackend::HandleReshardRequest(
|
|||
TF_ASSIGN_OR_RETURN(auto semantics, FromArrayCopySemanticsProto(
|
||||
reshard_request.copy_semantics()));
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto resharded_array,
|
||||
array->Reshard(sharding, semantics));
|
||||
// Emulate the old `Array::Reshard` behavior using `Client::CopyArrays`. No
|
||||
// existing IFRT implementations before `Array::Reshard` was deleted actually
|
||||
// supported resharding, so this should be safe.
|
||||
if (!array->sharding().HasSamePartitioning(*sharding)) {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"IFRT Proxy does not support resharding, but got ",
|
||||
array->sharding().DebugString(), " as the original sharding and ",
|
||||
sharding->DebugString(), " as the target sharding"));
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto copied_arrays,
|
||||
client_->CopyArrays(absl::MakeSpan(&array, 1), sharding->devices(),
|
||||
sharding->memory_kind(), semantics));
|
||||
|
||||
uint64_t resharded_array_handle = handle_generator_.New();
|
||||
{
|
||||
absl::MutexLock lock(&arrays_mutex_);
|
||||
arrays_.insert({resharded_array_handle, std::move(resharded_array)});
|
||||
arrays_.insert({resharded_array_handle, std::move(copied_arrays[0])});
|
||||
}
|
||||
|
||||
auto ifrt_resp = NewIfrtResponse(request->request_metadata().op_id());
|
||||
|
|
|
|||
|
|
@ -700,21 +700,27 @@ TEST_P(IfrtBackendHandlerTest, CopyArrays) {
|
|||
|
||||
TEST_P(IfrtBackendHandlerTest, ReshardSuccess) {
|
||||
auto src_mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
|
||||
auto resharded_mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
|
||||
EXPECT_CALL(*src_mock_array, Reshard(_, _))
|
||||
.WillOnce(Return(std::move(resharded_mock_array)));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto* device,
|
||||
mock_client_->LookupDevice(DeviceId(0)));
|
||||
auto src_sharding = SingleDeviceSharding::Create(device, MemoryKind());
|
||||
ON_CALL(*src_mock_array, sharding()).WillByDefault(ReturnRef(*src_sharding));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto src_array_handle,
|
||||
MakeTestArray(std::move(src_mock_array)));
|
||||
|
||||
auto copied_mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
|
||||
EXPECT_CALL(*mock_client_, CopyArrays(_, _, _, _))
|
||||
.WillOnce(Return(std::vector<tsl::RCReference<xla::ifrt::Array>>(
|
||||
{copied_mock_array})));
|
||||
|
||||
auto ifrt_request = NewIfrtRequest(NewOpId());
|
||||
auto* reshard_request = ifrt_request->mutable_reshard_request();
|
||||
reshard_request->set_array_handle(src_array_handle);
|
||||
reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY);
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto* device,
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto* new_device,
|
||||
mock_client_->LookupDevice(DeviceId(1)));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
*ifrt_request->mutable_reshard_request()->mutable_sharding(),
|
||||
SingleDeviceSharding::Create(device, MemoryKind())->ToProto());
|
||||
SingleDeviceSharding::Create(new_device, MemoryKind())->ToProto());
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto response, CallBackend(std::move(ifrt_request)));
|
||||
|
||||
|
|
@ -723,6 +729,43 @@ TEST_P(IfrtBackendHandlerTest, ReshardSuccess) {
|
|||
EXPECT_NE(response->reshard_response().array_handle(), 0);
|
||||
}
|
||||
|
||||
TEST_P(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) {
|
||||
auto mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto* device,
|
||||
mock_client_->LookupDevice(DeviceId(1)));
|
||||
auto sharding = SingleDeviceSharding::Create(device, MemoryKind());
|
||||
ON_CALL(*mock_array, sharding()).WillByDefault(ReturnRef(*sharding));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto array_handle,
|
||||
MakeTestArray(std::move(mock_array)));
|
||||
|
||||
EXPECT_CALL(*mock_client_, CopyArrays(_, _, _, _))
|
||||
.WillOnce(Return(absl::UnknownError("injected error")));
|
||||
|
||||
auto ifrt_request = NewIfrtRequest(NewOpId());
|
||||
auto* reshard_request = ifrt_request->mutable_reshard_request();
|
||||
reshard_request->set_array_handle(array_handle);
|
||||
reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY);
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto* new_device,
|
||||
mock_client_->LookupDevice(DeviceId(1)));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
*ifrt_request->mutable_reshard_request()->mutable_sharding(),
|
||||
SingleDeviceSharding::Create(new_device, MemoryKind())->ToProto());
|
||||
|
||||
EXPECT_THAT(CallBackend(std::move(ifrt_request)),
|
||||
StatusIs(absl::StatusCode::kUnknown, StrEq("injected error")));
|
||||
}
|
||||
|
||||
TEST_P(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) {
|
||||
auto ifrt_request = NewIfrtRequest(NewOpId());
|
||||
auto* reshard_request = ifrt_request->mutable_reshard_request();
|
||||
reshard_request->set_array_handle(0);
|
||||
reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY);
|
||||
reshard_request->mutable_sharding();
|
||||
|
||||
EXPECT_THAT(CallBackend(std::move(ifrt_request)),
|
||||
StatusIs(absl::StatusCode::kNotFound));
|
||||
}
|
||||
|
||||
TEST_P(IfrtBackendHandlerTest, FullyReplicatedShardSuccess) {
|
||||
auto fully_replicated_mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
|
||||
auto resultant_array = tsl::MakeRef<xla::ifrt::MockArray>();
|
||||
|
|
@ -777,38 +820,6 @@ TEST_P(IfrtBackendHandlerTest,
|
|||
StatusIs(absl::StatusCode::kNotFound));
|
||||
}
|
||||
|
||||
TEST_P(IfrtBackendHandlerTest, ReshardFailsWhenTheBackendFails) {
|
||||
auto mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
|
||||
EXPECT_CALL(*mock_array, Reshard(_, _))
|
||||
.WillOnce(Return(absl::UnknownError("injected error")));
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto array_handle,
|
||||
MakeTestArray(std::move(mock_array)));
|
||||
|
||||
auto ifrt_request = NewIfrtRequest(NewOpId());
|
||||
auto* reshard_request = ifrt_request->mutable_reshard_request();
|
||||
reshard_request->set_array_handle(array_handle);
|
||||
reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY);
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto* device,
|
||||
mock_client_->LookupDevice(DeviceId(1)));
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
*ifrt_request->mutable_reshard_request()->mutable_sharding(),
|
||||
SingleDeviceSharding::Create(device, MemoryKind())->ToProto());
|
||||
|
||||
EXPECT_THAT(CallBackend(std::move(ifrt_request)),
|
||||
StatusIs(absl::StatusCode::kUnknown, StrEq("injected error")));
|
||||
}
|
||||
|
||||
TEST_P(IfrtBackendHandlerTest, ReshardFailsWithNonExistentArrayHandle) {
|
||||
auto ifrt_request = NewIfrtRequest(NewOpId());
|
||||
auto* reshard_request = ifrt_request->mutable_reshard_request();
|
||||
reshard_request->set_array_handle(0);
|
||||
reshard_request->set_copy_semantics(proto::ARRAY_COPY_SEMANTICS_ALWAYS_COPY);
|
||||
reshard_request->mutable_sharding();
|
||||
|
||||
EXPECT_THAT(CallBackend(std::move(ifrt_request)),
|
||||
StatusIs(absl::StatusCode::kNotFound));
|
||||
}
|
||||
|
||||
TEST_P(IfrtBackendHandlerTest,
|
||||
CheckArrayReadyRequestRelaysTheResultFromBackend) {
|
||||
auto mock_array = tsl::MakeRef<xla::ifrt::MockArray>();
|
||||
|
|
|
|||
6
third_party/xla/xla/python/pjrt_ifrt/BUILD
vendored
6
third_party/xla/xla/python/pjrt_ifrt/BUILD
vendored
|
|
@ -1,6 +1,6 @@
|
|||
load("@local_tsl//tsl/platform:build_config.bzl", "tf_proto_library")
|
||||
load("//xla:xla.bzl", "xla_cc_test")
|
||||
load("//xla/tsl:tsl.bzl", "internal_visibility")
|
||||
load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility")
|
||||
load("//xla/tsl:tsl.default.bzl", "get_compatible_with_portable")
|
||||
|
||||
package_group(
|
||||
|
|
@ -70,7 +70,9 @@ tf_proto_library(
|
|||
name = "xla_host_callback_proto",
|
||||
srcs = ["xla_host_callback.proto"],
|
||||
cc_api_version = 2,
|
||||
protodeps = ["//xla:xla_data_proto"],
|
||||
protodeps = [
|
||||
"//xla:xla_data_proto",
|
||||
] + if_google(["@com_google_protobuf//:any"]),
|
||||
)
|
||||
|
||||
tf_proto_library(
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ limitations under the License.
|
|||
#include "absl/types/span.h"
|
||||
#include "xla/pjrt/pjrt_layout.h"
|
||||
#include "xla/python/ifrt/array.h"
|
||||
#include "xla/python/ifrt/device.h"
|
||||
#include "xla/python/ifrt/future.h"
|
||||
#include "xla/python/ifrt/memory.h"
|
||||
#include "xla/python/ifrt/shape.h"
|
||||
|
|
@ -286,8 +287,9 @@ Future<> BasicStringArray::CopyToHostBuffer(
|
|||
return Future<>(absl::UnimplementedError("Not implemented"));
|
||||
}
|
||||
|
||||
absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Reshard(
|
||||
std::shared_ptr<const Sharding> new_sharding,
|
||||
absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Copy(
|
||||
std::optional<xla::ifrt::DeviceList> devices,
|
||||
std::optional<xla::ifrt::MemoryKind> memory_kind,
|
||||
ArrayCopySemantics semantics) {
|
||||
DCHECK(this);
|
||||
absl::MutexLock lock(&mu_);
|
||||
|
|
@ -295,6 +297,8 @@ absl::StatusOr<tsl::RCReference<Array>> BasicStringArray::Reshard(
|
|||
return absl::FailedPreconditionError("Array has already been deleted");
|
||||
}
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto new_sharding,
|
||||
sharding().WithDeviceAssignment(devices, memory_kind));
|
||||
if (new_sharding->devices().size() != sharding_->devices().size()) {
|
||||
return absl::InvalidArgumentError(absl::StrCat(
|
||||
"Number of devices in new sharding: ", new_sharding->devices().size(),
|
||||
|
|
|
|||
|
|
@ -34,8 +34,10 @@ limitations under the License.
|
|||
#include "llvm/Support/ExtensibleRTTI.h"
|
||||
#include "xla/pjrt/pjrt_layout.h"
|
||||
#include "xla/python/ifrt/array.h"
|
||||
#include "xla/python/ifrt/device.h"
|
||||
#include "xla/python/ifrt/dtype.h"
|
||||
#include "xla/python/ifrt/future.h"
|
||||
#include "xla/python/ifrt/memory.h"
|
||||
#include "xla/python/ifrt/shape.h"
|
||||
#include "xla/python/ifrt/sharding.h"
|
||||
#include "xla/tsl/concurrency/ref_count.h"
|
||||
|
|
@ -128,9 +130,10 @@ class BasicStringArray final
|
|||
void* data, std::optional<absl::Span<const int64_t>> byte_strides,
|
||||
ArrayCopySemantics semantics) override;
|
||||
|
||||
absl::StatusOr<tsl::RCReference<Array>> Reshard(
|
||||
std::shared_ptr<const Sharding> new_sharding,
|
||||
ArrayCopySemantics semantics) override;
|
||||
absl::StatusOr<tsl::RCReference<Array>> Copy(
|
||||
std::optional<xla::ifrt::DeviceList> devices,
|
||||
std::optional<xla::ifrt::MemoryKind> memory_kind,
|
||||
ArrayCopySemantics semantics);
|
||||
|
||||
Future<> GetReadyFuture() const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -450,10 +450,13 @@ absl::StatusOr<Memory*> GetMemorySpaceFromMemoryKind(
|
|||
return memory;
|
||||
}
|
||||
|
||||
absl::StatusOr<tsl::RCReference<Array>> PjRtArray::Reshard(
|
||||
std::shared_ptr<const Sharding> new_sharding,
|
||||
absl::StatusOr<tsl::RCReference<Array>> PjRtArray::Copy(
|
||||
std::optional<xla::ifrt::DeviceList> devices,
|
||||
std::optional<xla::ifrt::MemoryKind> memory_kind,
|
||||
ArrayCopySemantics semantics) {
|
||||
DCHECK(this);
|
||||
TF_ASSIGN_OR_RETURN(auto new_sharding,
|
||||
sharding().WithDeviceAssignment(devices, memory_kind));
|
||||
if (new_sharding->devices().size() != sharding_->devices().size()) {
|
||||
return InvalidArgument(
|
||||
"Resharding to a different number of devices: %d; expected %d",
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ limitations under the License.
|
|||
#include "llvm/Support/ExtensibleRTTI.h"
|
||||
#include "xla/python/ifrt/array.h"
|
||||
#include "xla/python/ifrt/client.h"
|
||||
#include "xla/python/ifrt/device.h"
|
||||
#include "xla/python/ifrt/shape.h"
|
||||
#include "xla/python/pjrt_ifrt/pjrt_client.h"
|
||||
#include "xla/tsl/concurrency/ref_count.h"
|
||||
|
|
@ -156,9 +157,10 @@ class PjRtArray final
|
|||
void* data, std::optional<absl::Span<const int64_t>> byte_strides,
|
||||
ArrayCopySemantics semantics) override;
|
||||
|
||||
absl::StatusOr<tsl::RCReference<Array>> Reshard(
|
||||
std::shared_ptr<const Sharding> new_sharding,
|
||||
ArrayCopySemantics semantics) override;
|
||||
absl::StatusOr<tsl::RCReference<Array>> Copy(
|
||||
std::optional<xla::ifrt::DeviceList> devices,
|
||||
std::optional<xla::ifrt::MemoryKind> memory_kind,
|
||||
ArrayCopySemantics semantics);
|
||||
|
||||
Future<> GetReadyFuture() const override;
|
||||
|
||||
|
|
|
|||
|
|
@ -696,11 +696,17 @@ absl::StatusOr<std::vector<tsl::RCReference<Array>>> PjRtClient::CopyArrays(
|
|||
std::vector<tsl::RCReference<Array>> new_arrays;
|
||||
new_arrays.reserve(arrays.size());
|
||||
for (const auto& array : arrays) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto new_sharding,
|
||||
array->sharding().WithDeviceAssignment(devices, memory_kind));
|
||||
TF_ASSIGN_OR_RETURN(new_arrays.emplace_back(),
|
||||
array->Reshard(std::move(new_sharding), semantics));
|
||||
if (auto* const pjrt_array = llvm::dyn_cast<PjRtArray>(array.get())) {
|
||||
TF_ASSIGN_OR_RETURN(new_arrays.emplace_back(),
|
||||
pjrt_array->Copy(devices, memory_kind, semantics));
|
||||
} else if (auto* const string_array =
|
||||
llvm::dyn_cast<BasicStringArray>(array.get())) {
|
||||
TF_ASSIGN_OR_RETURN(new_arrays.emplace_back(),
|
||||
string_array->Copy(devices, memory_kind, semantics));
|
||||
} else {
|
||||
return absl::InvalidArgumentError(
|
||||
"Unsupported array type for PjRtClient::CopyArrays");
|
||||
}
|
||||
}
|
||||
return new_arrays;
|
||||
}
|
||||
|
|
|
|||
2
third_party/xla/xla/python/tools/BUILD
vendored
2
third_party/xla/xla/python/tools/BUILD
vendored
|
|
@ -48,7 +48,7 @@ pytype_strict_library(
|
|||
)
|
||||
|
||||
# NOTE: Copybara detects the `tsl_pybind_extension` rule and automatically
|
||||
# injects the "@com_google_protobuf//:protobuf_python" python dependency
|
||||
# injects the @com_google_protobuf//:protobuf_python python dependency
|
||||
# required by "@pybind11_protobuf//pybind11_protobuf:native_proto_caster".
|
||||
tsl_pybind_extension(
|
||||
name = "_types",
|
||||
|
|
|
|||
13
third_party/xla/xla/service/BUILD
vendored
13
third_party/xla/xla/service/BUILD
vendored
|
|
@ -51,7 +51,9 @@ tf_proto_library(
|
|||
srcs = ["hlo.proto"],
|
||||
cc_api_version = 2,
|
||||
make_default_target_header_only = True,
|
||||
protodeps = ["//xla:xla_data_proto"],
|
||||
protodeps = [
|
||||
"//xla:xla_data_proto",
|
||||
] + if_google(["@com_google_protobuf//:any"]),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
|
|
@ -79,6 +81,11 @@ tf_proto_library(
|
|||
name = "metrics_proto",
|
||||
srcs = ["metrics.proto"],
|
||||
cc_api_version = 2,
|
||||
protodeps = if_google([
|
||||
"@com_google_protobuf//:any",
|
||||
"@com_google_protobuf//:duration",
|
||||
"@com_google_protobuf//:timestamp",
|
||||
]),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
|
|
@ -7566,7 +7573,7 @@ cc_library(
|
|||
"@local_tsl//tsl/platform:protobuf",
|
||||
"@local_tsl//tsl/platform:status",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
],
|
||||
] + if_google(["@com_google_protobuf//:any_cc_proto"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
@ -8154,7 +8161,7 @@ tf_proto_library(
|
|||
protodeps = [
|
||||
":hlo_proto",
|
||||
"@local_tsl//tsl/protobuf:status_proto",
|
||||
],
|
||||
] + if_google(["@com_google_protobuf//:duration"]),
|
||||
visibility = ["//visibility:public"],
|
||||
)
|
||||
|
||||
|
|
|
|||
1
third_party/xla/xla/service/cpu/BUILD
vendored
1
third_party/xla/xla/service/cpu/BUILD
vendored
|
|
@ -633,6 +633,7 @@ cc_library(
|
|||
"//xla/hlo/ir:hlo",
|
||||
"//xla/service:elemental_ir_emitter",
|
||||
"//xla/service/cpu:dot_op_emitter",
|
||||
"//xla/service/llvm_ir:dynamic_update_slice_util",
|
||||
"//xla/service/llvm_ir:fused_ir_emitter",
|
||||
"//xla/service/llvm_ir:ir_array",
|
||||
"//xla/service/llvm_ir:llvm_util",
|
||||
|
|
|
|||
|
|
@ -87,6 +87,33 @@ static void BM_BcastFusionF32(benchmark::State& state) {
|
|||
CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}));
|
||||
}
|
||||
|
||||
static void BM_DynamicUpdateSliceFusionF32(benchmark::State& state) {
|
||||
int64_t d0 = state.range(0);
|
||||
|
||||
std::string_view hlo = R"(
|
||||
HloModule dynamic_update_slice_fusion_f32_$d0
|
||||
|
||||
ENTRY e {
|
||||
p0 = f32[$d0,256] parameter(0)
|
||||
p1 = s32[] parameter(1)
|
||||
p2 = s32[] parameter(2)
|
||||
slice = f32[1,1] dynamic-slice(p0, p1, p2), dynamic_slice_sizes={1,1}
|
||||
add = f32[1,1] add(slice, slice)
|
||||
ROOT update = f32[$d0,256] dynamic-update-slice(p0, add, p1, p2)
|
||||
}
|
||||
)";
|
||||
|
||||
std::minstd_rand0 engine;
|
||||
|
||||
auto shape = ShapeUtil::MakeShape(F32, {d0, 256});
|
||||
auto p0 = *LiteralUtil::CreateRandomLiteral<F32>(shape, &engine, 1.0f, 0.1f);
|
||||
auto p1 = LiteralUtil::CreateR0<int32_t>(0);
|
||||
auto p2 = LiteralUtil::CreateR0<int32_t>(0);
|
||||
|
||||
std::vector<const Literal*> args = {&p0, &p1, &p2};
|
||||
CHECK_OK(RunHloBenchmark(state, hlo, args, {{"$d0", absl::StrCat(d0)}}));
|
||||
}
|
||||
|
||||
BENCHMARK(BM_FusionF32)
|
||||
->MeasureProcessCPUTime()
|
||||
->Arg(128)
|
||||
|
|
@ -105,4 +132,13 @@ BENCHMARK(BM_BcastFusionF32)
|
|||
->Arg(8192)
|
||||
->Arg(16384);
|
||||
|
||||
BENCHMARK(BM_DynamicUpdateSliceFusionF32)
|
||||
->MeasureProcessCPUTime()
|
||||
->Arg(128)
|
||||
->Arg(256)
|
||||
->Arg(512)
|
||||
->Arg(1024)
|
||||
->Arg(8192)
|
||||
->Arg(16384);
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
|
|
|||
2
third_party/xla/xla/service/cpu/ir_emitter.h
vendored
2
third_party/xla/xla/service/cpu/ir_emitter.h
vendored
|
|
@ -165,6 +165,8 @@ class IrEmitter : public DfsHloVisitorWithDefault,
|
|||
return target_machine_features_;
|
||||
}
|
||||
|
||||
const BufferAssignment& assignment() const { return assignment_; }
|
||||
|
||||
protected:
|
||||
friend class IrEmitter2;
|
||||
|
||||
|
|
|
|||
19
third_party/xla/xla/service/cpu/ir_emitter2.cc
vendored
19
third_party/xla/xla/service/cpu/ir_emitter2.cc
vendored
|
|
@ -55,6 +55,7 @@ limitations under the License.
|
|||
#include "xla/service/cpu/parallel_loop_emitter.h"
|
||||
#include "xla/service/cpu/shape_partition.h"
|
||||
#include "xla/service/elemental_ir_emitter.h"
|
||||
#include "xla/service/llvm_ir/dynamic_update_slice_util.h"
|
||||
#include "xla/service/llvm_ir/fused_ir_emitter.h"
|
||||
#include "xla/service/llvm_ir/ir_array.h"
|
||||
#include "xla/service/llvm_ir/llvm_util.h"
|
||||
|
|
@ -285,8 +286,8 @@ absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitFusionHostKernel(
|
|||
|
||||
ElementalIrEmitter elemental_emitter(module_, &b, &hlo_module_,
|
||||
nested_ir_emitter_, fast_min_max());
|
||||
FusedIrEmitter fused_emitter(elemental_emitter);
|
||||
|
||||
FusedIrEmitter fused_emitter(elemental_emitter);
|
||||
for (int i = 0; i < fusion->operand_count(); i++) {
|
||||
fused_emitter.BindGenerator(
|
||||
*fusion->fused_parameter(i), [&, i](llvm_ir::IrArray::Index idx) {
|
||||
|
|
@ -294,6 +295,22 @@ absl::StatusOr<IrEmitter2::KernelInfo> IrEmitter2::EmitFusionHostKernel(
|
|||
});
|
||||
}
|
||||
|
||||
// Check if the fusion can be emitted in-place and skip expensive loop for
|
||||
// all elements in the output array.
|
||||
if (llvm_ir::CanEmitFusedDynamicUpdateSliceInPlace(
|
||||
const_cast<HloFusionInstruction*>(fusion),
|
||||
nested_ir_emitter_->assignment())) {
|
||||
// Delegate to common implementation of fused in-place dynamic-update-slice.
|
||||
TF_RETURN_IF_ERROR(llvm_ir::EmitFusedDynamicUpdateSliceInPlace(
|
||||
const_cast<HloFusionInstruction*>(fusion), kernel_prototype.results[0],
|
||||
&fused_emitter, &b));
|
||||
|
||||
return kernels_.emplace_back(
|
||||
KernelInfo{kernel_prototype.function->getName().str(), se::BlockDim(),
|
||||
se::ThreadDim()});
|
||||
}
|
||||
|
||||
// Emit plain elemental loops for the fusion operation.
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto element_generator,
|
||||
fused_emitter.GetGenerator(*fusion->fused_expression_root()));
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ limitations under the License.
|
|||
#include <optional>
|
||||
|
||||
#include "absl/algorithm/container.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/service/hlo_pass_interface.h"
|
||||
|
|
@ -35,7 +36,7 @@ class OneDnnConvolutionRewriter : public HloModulePass {
|
|||
}
|
||||
|
||||
using HloPassInterface::Run;
|
||||
StatusOr<bool> Run(
|
||||
absl::StatusOr<bool> Run(
|
||||
HloModule* module,
|
||||
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
|
||||
|
||||
|
|
|
|||
|
|
@ -717,8 +717,10 @@ cc_library(
|
|||
"//xla/stream_executor/host:host_kernel",
|
||||
"//xla/stream_executor/host:host_kernel_c_api",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
"@com_google_absl//absl/base:core_headers",
|
||||
"@com_google_absl//absl/container:inlined_vector",
|
||||
"@com_google_absl//absl/memory",
|
||||
"@com_google_absl//absl/numeric:bits",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:str_format",
|
||||
|
|
@ -746,6 +748,7 @@ xla_cc_test(
|
|||
"//xla/stream_executor/host:host_kernel_c_api",
|
||||
"//xla/tsl/concurrency:async_value",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@local_tsl//tsl/lib/core:status_test_util",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
"@local_tsl//tsl/platform:test",
|
||||
|
|
|
|||
|
|
@ -23,8 +23,10 @@ limitations under the License.
|
|||
#include <string>
|
||||
#include <utility>
|
||||
|
||||
#include "absl/base/optimization.h"
|
||||
#include "absl/container/inlined_vector.h"
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/numeric/bits.h"
|
||||
#include "absl/strings/str_format.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive
|
||||
|
|
@ -49,7 +51,12 @@ absl::StatusOr<std::unique_ptr<KernelThunk>> KernelThunk::Create(
|
|||
Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
||||
absl::Span<const BufferAllocation::Slice> results_buffers,
|
||||
std::string kernel_name, se::ThreadDim thread_dim,
|
||||
std::optional<int64_t> min_alignment) {
|
||||
std::optional<uint64_t> min_alignment) {
|
||||
if (min_alignment.has_value() && !absl::has_single_bit(*min_alignment)) {
|
||||
return Internal("Host kernel %s minimum alignment %d is not a power of 2",
|
||||
info.op_name, *min_alignment);
|
||||
}
|
||||
|
||||
return absl::WrapUnique(
|
||||
new KernelThunk(std::move(info), arguments_buffers, results_buffers,
|
||||
std::move(kernel_name), thread_dim, min_alignment));
|
||||
|
|
@ -59,13 +66,14 @@ KernelThunk::KernelThunk(
|
|||
Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
||||
absl::Span<const BufferAllocation::Slice> results_buffers,
|
||||
std::string kernel_name, se::ThreadDim thread_dim,
|
||||
std::optional<int64_t> min_alignment)
|
||||
std::optional<uint64_t> min_alignment)
|
||||
: Thunk(Kind::kKernel, std::move(info)),
|
||||
arguments_buffers_(arguments_buffers.begin(), arguments_buffers.end()),
|
||||
results_buffers_(results_buffers.begin(), results_buffers.end()),
|
||||
kernel_name_(std::move(kernel_name)),
|
||||
thread_dim_(thread_dim),
|
||||
min_alignment_(min_alignment),
|
||||
use_task_runner_(thread_dim != se::ThreadDim()),
|
||||
kernel_ptr_(nullptr) {}
|
||||
|
||||
tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
|
||||
|
|
@ -104,12 +112,12 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
|
|||
// will crash with a segmentation fault, or worse, produce incorrect results.
|
||||
if (min_alignment_.has_value()) {
|
||||
for (int64_t i = 0; i < buffers_data.size(); ++i) {
|
||||
se::DeviceMemoryBase& data = buffers_data[i];
|
||||
if (reinterpret_cast<uintptr_t>(data.opaque()) % *min_alignment_ != 0) {
|
||||
auto ptr = reinterpret_cast<uintptr_t>(buffers_data[i].opaque());
|
||||
if (ABSL_PREDICT_FALSE((ptr & (*min_alignment_ - 1)) != 0)) {
|
||||
return Internal(
|
||||
"Host kernel %s buffer argument #%d (%p) is not aligned to a "
|
||||
"required minimum alignment of %d bytes",
|
||||
info().op_name, i, data.opaque(), *min_alignment_);
|
||||
info().op_name, i, buffers_data[i].opaque(), *min_alignment_);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -130,7 +138,7 @@ tsl::AsyncValueRef<Thunk::ExecuteEvent> KernelThunk::Execute(
|
|||
// If intra-op thread pool is not nullptr, we launch HostKernel in async mode
|
||||
// by scheduling tasks into it. HostKernel launch completion will
|
||||
// automatically signal KernelThunk execute completion.
|
||||
if (params.intra_op_threadpool) {
|
||||
if (params.intra_op_threadpool && use_task_runner_) {
|
||||
return kernel.Launch(thread_dim_, buffers_data,
|
||||
[¶ms](se::host::HostKernel::Task task) {
|
||||
params.intra_op_threadpool->getPool()->Schedule(
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ class KernelThunk final : public Thunk {
|
|||
Info info, absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
||||
absl::Span<const BufferAllocation::Slice> results_buffers,
|
||||
std::string kernel_name, se::ThreadDim thread_dim,
|
||||
std::optional<int64_t> min_alignment = std::nullopt);
|
||||
std::optional<uint64_t> min_alignment = std::nullopt);
|
||||
|
||||
tsl::AsyncValueRef<ExecuteEvent> Execute(const ExecuteParams& params) final;
|
||||
|
||||
|
|
@ -51,13 +51,18 @@ class KernelThunk final : public Thunk {
|
|||
absl::Span<const BufferAllocation::Slice> arguments_buffers,
|
||||
absl::Span<const BufferAllocation::Slice> results_buffers,
|
||||
std::string kernel_name, se::ThreadDim thread_dim,
|
||||
std::optional<int64_t> min_alignment);
|
||||
std::optional<uint64_t> min_alignment);
|
||||
|
||||
std::vector<BufferAllocation::Slice> arguments_buffers_;
|
||||
std::vector<BufferAllocation::Slice> results_buffers_;
|
||||
std::string kernel_name_;
|
||||
se::ThreadDim thread_dim_;
|
||||
std::optional<int64_t> min_alignment_;
|
||||
std::optional<uint64_t> min_alignment_;
|
||||
|
||||
// If `true`, pass a HostKernel::TaskRunner to the kernel launch. If kernel
|
||||
// has a single thread, we skip constructing HostKernel::TaskRunner and
|
||||
// launch the kernel directly in the caller thread.
|
||||
bool use_task_runner_;
|
||||
|
||||
// Pointer to the host kernel corresponding to `kernel_name_`. Initialized
|
||||
// lazily at run time by looking it up in the HostKernels passed via params.
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include <vector>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/match.h"
|
||||
#include "xla/service/buffer_assignment.h"
|
||||
#include "xla/service/cpu/runtime/buffer_allocations.h"
|
||||
#include "xla/service/cpu/runtime/thunk.h"
|
||||
|
|
@ -53,6 +54,13 @@ class AddF32HostKernels : public Thunk::HostKernels {
|
|||
}
|
||||
};
|
||||
|
||||
TEST(KernelThunkTest, CheckAlignment) {
|
||||
auto thunk = KernelThunk::Create({"test"}, {}, {}, "test", se::ThreadDim(),
|
||||
/*min_alignment=*/3);
|
||||
EXPECT_TRUE(absl::StrContains(thunk.status().message(),
|
||||
"minimum alignment 3 is not a power of 2"));
|
||||
}
|
||||
|
||||
TEST(KernelThunkTest, AddF32) {
|
||||
std::vector<MaybeOwningDeviceMemory> buffers;
|
||||
std::vector<float> in = {1.0, 2.0, 3.0, 4.0};
|
||||
|
|
|
|||
|
|
@ -41,7 +41,8 @@ namespace xla::cpu {
|
|||
ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence,
|
||||
std::vector<NodeDef> nodes_defs)
|
||||
: thunk_sequence_(std::move(thunk_sequence)),
|
||||
nodes_defs_(std::move(nodes_defs)) {
|
||||
nodes_defs_(std::move(nodes_defs)),
|
||||
is_sequential_(true) {
|
||||
for (NodeId i = 0; i < nodes_defs_.size(); ++i) {
|
||||
// Mark nodes with empty in-edges as source nodes.
|
||||
if (nodes_defs_[i].in_edges.empty()) {
|
||||
|
|
@ -57,10 +58,17 @@ ThunkExecutor::ThunkExecutor(ThunkSequence thunk_sequence,
|
|||
// Erase redundant edges between nodes.
|
||||
int64_t num_erased_edges = TransitiveReduction();
|
||||
|
||||
// Check if constructed execution DAG is sequential: every node depends on the
|
||||
// completion of the previous node.
|
||||
for (NodeId i = 1; i < nodes_defs_.size() && is_sequential_; ++i) {
|
||||
is_sequential_ &= (absl::c_count(nodes_defs_[i].in_edges, i - 1) != 0);
|
||||
}
|
||||
|
||||
VLOG(2) << absl::StreamFormat(
|
||||
"Constructed ThunkExecutor with %d nodes: #source_nodes=%d "
|
||||
"#sink_nodes=%d, #erased_edges=%d",
|
||||
nodes_defs_.size(), source_.size(), sink_.size(), num_erased_edges);
|
||||
"#sink_nodes=%d, #erased_edges=%d, is_sequential=%v",
|
||||
nodes_defs_.size(), source_.size(), sink_.size(), num_erased_edges,
|
||||
is_sequential_);
|
||||
|
||||
// Sanity check that all vectors are empty or all vectors are non-empty.
|
||||
DCHECK((!source_.empty() && !sink_.empty() && !thunk_sequence_.empty()) ||
|
||||
|
|
@ -123,6 +131,13 @@ tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent> ThunkExecutor::Execute(
|
|||
return thunk_sequence_[0]->Execute(params);
|
||||
}
|
||||
|
||||
// If thunk sequence dependencies form a sequential execution graph, we skip
|
||||
// expensive async execution and simply run thunks one by one.
|
||||
if (is_sequential_) {
|
||||
return ExecuteSequential(params);
|
||||
}
|
||||
|
||||
// Create async execution state on heap and kick-off execution.
|
||||
auto state = std::make_unique<ExecuteState>(this, std::move(runner));
|
||||
Execute(state.get(), params, ReadyQueue(source_.begin(), source_.end()));
|
||||
|
||||
|
|
@ -138,6 +153,70 @@ tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent> ThunkExecutor::Execute(
|
|||
return execute_event;
|
||||
}
|
||||
|
||||
tsl::AsyncValueRef<ThunkExecutor::ExecuteEvent>
|
||||
ThunkExecutor::ExecuteSequential(const Thunk::ExecuteParams& params) {
|
||||
for (int64_t i = 0; i < thunk_sequence_.size(); ++i) {
|
||||
Thunk& thunk = *thunk_sequence_[i];
|
||||
auto execute_event = thunk.Execute(params);
|
||||
|
||||
// If thunk execution is not completed yet, attach a continuation to
|
||||
// resume sequential execution starting from the next thunk.
|
||||
if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
|
||||
auto event = tsl::MakeConstructedAsyncValueRef<ExecuteEvent>();
|
||||
execute_event.AndThen([this, ¶ms, i, event](absl::Status status) {
|
||||
if (ABSL_PREDICT_FALSE(!status.ok())) {
|
||||
event.SetError(std::move(status));
|
||||
} else {
|
||||
ResumeExecuteSequential(i + 1, params, std::move(event));
|
||||
}
|
||||
});
|
||||
return event;
|
||||
}
|
||||
|
||||
// Abort execution if any of the thunks failed.
|
||||
if (ABSL_PREDICT_FALSE(execute_event.IsError())) {
|
||||
return execute_event;
|
||||
}
|
||||
}
|
||||
|
||||
// If we got to the end of the sequence it means that all thunks have
|
||||
// succeeded.
|
||||
return Thunk::OkExecuteEvent();
|
||||
}
|
||||
|
||||
void ThunkExecutor::ResumeExecuteSequential(
|
||||
int64_t index, const Thunk::ExecuteParams& params,
|
||||
tsl::AsyncValueRef<ExecuteEvent> event) {
|
||||
for (int64_t i = index; i < thunk_sequence_.size(); ++i) {
|
||||
Thunk& thunk = *thunk_sequence_[i];
|
||||
auto execute_event = thunk.Execute(params);
|
||||
|
||||
// If thunk execution is not completed yet, attach a continuation to
|
||||
// resume sequential execution starting from the next thunk.
|
||||
if (ABSL_PREDICT_FALSE(!execute_event.IsAvailable())) {
|
||||
execute_event.AndThen(
|
||||
[this, ¶ms, i, event = std::move(event)](absl::Status status) {
|
||||
if (ABSL_PREDICT_FALSE(!status.ok())) {
|
||||
event.SetError(std::move(status));
|
||||
} else {
|
||||
ResumeExecuteSequential(i + 1, params, std::move(event));
|
||||
}
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
// Abort execution if any of the thunks failed.
|
||||
if (ABSL_PREDICT_FALSE(execute_event.IsError())) {
|
||||
event.SetError(execute_event.GetError());
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If we got to the end of the sequence it means that all thunks have
|
||||
// succeeded.
|
||||
event.SetStateConcrete();
|
||||
}
|
||||
|
||||
void ThunkExecutor::Execute(ExecuteState* state,
|
||||
const Thunk::ExecuteParams& params,
|
||||
ReadyQueue ready_queue) {
|
||||
|
|
|
|||
|
|
@ -85,6 +85,8 @@ class ThunkExecutor {
|
|||
|
||||
std::string ToString() const;
|
||||
|
||||
bool is_sequential() const { return is_sequential_; }
|
||||
|
||||
private:
|
||||
using ReadyQueue = absl::InlinedVector<NodeId, 8>;
|
||||
|
||||
|
|
@ -121,6 +123,15 @@ class ThunkExecutor {
|
|||
tsl::AsyncValueRef<ExecuteEvent> execute_event;
|
||||
};
|
||||
|
||||
// Executes thunks sequentially starting from the first thunk in the sequence.
|
||||
tsl::AsyncValueRef<ExecuteEvent> ExecuteSequential(
|
||||
const Thunk::ExecuteParams& params);
|
||||
|
||||
// Resumes sequential thunk execution starting from the given index.
|
||||
void ResumeExecuteSequential(int64_t index,
|
||||
const Thunk::ExecuteParams& params,
|
||||
tsl::AsyncValueRef<ExecuteEvent> event);
|
||||
|
||||
// Executes nodes in the ready queue with given thunk parameters.
|
||||
void Execute(ExecuteState* state, const Thunk::ExecuteParams& params,
|
||||
ReadyQueue ready_queue);
|
||||
|
|
@ -143,6 +154,11 @@ class ThunkExecutor {
|
|||
|
||||
std::vector<NodeId> source_;
|
||||
std::vector<NodeId> sink_;
|
||||
|
||||
// If NodeDef graph dependency structure is sequential and does not have any
|
||||
// opportunities for executing thunks concurrently, we skip the expensive
|
||||
// async execution and simply run thunks in the `thunk_sequence_` one by one.
|
||||
bool is_sequential_;
|
||||
};
|
||||
|
||||
} // namespace xla::cpu
|
||||
|
|
|
|||
|
|
@ -63,13 +63,15 @@ using ::testing::ElementsAre;
|
|||
class AddI32Thunk final : public Thunk {
|
||||
public:
|
||||
AddI32Thunk(std::string name, std::vector<BufferAllocation::Slice> srcs,
|
||||
std::vector<BufferAllocation::Slice> dsts, bool inject_error,
|
||||
std::vector<std::string>* trace);
|
||||
std::vector<BufferAllocation::Slice> dsts,
|
||||
std::vector<std::string>* trace, bool inject_error,
|
||||
bool inject_side_effect);
|
||||
|
||||
static std::unique_ptr<Thunk> Create(
|
||||
std::string name, std::vector<BufferAllocation::Slice> srcs,
|
||||
std::vector<BufferAllocation::Slice> dsts, bool inject_error = false,
|
||||
std::vector<std::string>* trace = nullptr);
|
||||
std::vector<BufferAllocation::Slice> dsts,
|
||||
std::vector<std::string>* trace = nullptr, bool inject_error = false,
|
||||
bool inject_side_effect = false);
|
||||
|
||||
static std::vector<MaybeOwningDeviceMemory> AsDeviceMemory(
|
||||
absl::Span<std::vector<int32_t>* const> data);
|
||||
|
|
@ -86,16 +88,18 @@ class AddI32Thunk final : public Thunk {
|
|||
private:
|
||||
std::vector<BufferAllocation::Slice> srcs_;
|
||||
std::vector<BufferAllocation::Slice> dsts_;
|
||||
bool inject_error_;
|
||||
std::vector<std::string>* trace_;
|
||||
bool inject_error_;
|
||||
bool inject_side_effect_;
|
||||
};
|
||||
|
||||
std::unique_ptr<Thunk> AddI32Thunk::Create(
|
||||
std::string name, std::vector<BufferAllocation::Slice> srcs,
|
||||
std::vector<BufferAllocation::Slice> dsts, bool inject_error,
|
||||
std::vector<std::string>* trace) {
|
||||
std::vector<BufferAllocation::Slice> dsts, std::vector<std::string>* trace,
|
||||
bool inject_error, bool inject_side_effect) {
|
||||
return std::make_unique<AddI32Thunk>(std::move(name), std::move(srcs),
|
||||
std::move(dsts), inject_error, trace);
|
||||
std::move(dsts), trace, inject_error,
|
||||
inject_side_effect);
|
||||
}
|
||||
|
||||
std::vector<MaybeOwningDeviceMemory> AddI32Thunk::AsDeviceMemory(
|
||||
|
|
@ -111,12 +115,14 @@ std::vector<MaybeOwningDeviceMemory> AddI32Thunk::AsDeviceMemory(
|
|||
AddI32Thunk::AddI32Thunk(std::string name,
|
||||
std::vector<BufferAllocation::Slice> srcs,
|
||||
std::vector<BufferAllocation::Slice> dsts,
|
||||
bool inject_error, std::vector<std::string>* trace)
|
||||
std::vector<std::string>* trace, bool inject_error,
|
||||
bool inject_side_effect)
|
||||
: Thunk(Kind::kKernel, Info{name}),
|
||||
srcs_(std::move(srcs)),
|
||||
dsts_(std::move(dsts)),
|
||||
trace_(trace),
|
||||
inject_error_(inject_error),
|
||||
trace_(trace) {}
|
||||
inject_side_effect_(inject_side_effect) {}
|
||||
|
||||
absl::Status AddI32Thunk::Execute(const BufferAllocations* allocations,
|
||||
BufferAllocation::Slice src_slice,
|
||||
|
|
@ -178,10 +184,20 @@ AddI32Thunk::BufferUses AddI32Thunk::buffer_uses() const {
|
|||
BufferUses buffer_uses;
|
||||
for (const auto& src : srcs_) buffer_uses.push_back(BufferUse::Read(src));
|
||||
for (const auto& dst : dsts_) buffer_uses.push_back(BufferUse::Write(dst));
|
||||
|
||||
// TODO(ezhulenev): Add proper side-effect support to Thunks. For now we just
|
||||
// inject a write to a random slice of allocation 0 to emulate a side-effect
|
||||
// and force all thunks to be executed sequentially.
|
||||
if (inject_side_effect_) {
|
||||
static auto* fake_alloc = new BufferAllocation(0, 1, 0);
|
||||
buffer_uses.push_back(
|
||||
BufferUse::Write(BufferAllocation::Slice(fake_alloc, 0, 1)));
|
||||
}
|
||||
|
||||
return buffer_uses;
|
||||
}
|
||||
|
||||
TEST(ThunkExecutorTest, Ordering) {
|
||||
TEST(ThunkExecutorTest, DependencyOrdering) {
|
||||
BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
|
||||
|
||||
BufferAllocation::Slice slice0(&alloc, /*offset=*/0, /*size=*/40);
|
||||
|
|
@ -196,10 +212,28 @@ TEST(ThunkExecutorTest, Ordering) {
|
|||
TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor,
|
||||
ThunkExecutor::Create(std::move(sequence)));
|
||||
|
||||
EXPECT_FALSE(executor.is_sequential());
|
||||
EXPECT_THAT(executor.source(), ElementsAre(0, 1));
|
||||
EXPECT_THAT(executor.sink(), ElementsAre(2));
|
||||
}
|
||||
|
||||
TEST(ThunkExecutorTest, SequentialOrdering) {
|
||||
BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
|
||||
BufferAllocation::Slice slice(&alloc, /*offset=*/0, /*size=*/40);
|
||||
|
||||
ThunkSequence sequence;
|
||||
sequence.push_back(AddI32Thunk::Create("a", {slice}, {slice}));
|
||||
sequence.push_back(AddI32Thunk::Create("b", {slice}, {slice}));
|
||||
sequence.push_back(AddI32Thunk::Create("c", {slice}, {slice}));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor,
|
||||
ThunkExecutor::Create(std::move(sequence)));
|
||||
|
||||
EXPECT_TRUE(executor.is_sequential());
|
||||
EXPECT_THAT(executor.source(), ElementsAre(0));
|
||||
EXPECT_THAT(executor.sink(), ElementsAre(2));
|
||||
}
|
||||
|
||||
TEST(ThunkExecutorTest, TransitiveReduction) {
|
||||
BufferAllocation alloc(/*index=*/0, /*size=*/80, /*color=*/0);
|
||||
BufferAllocation::Slice slice(&alloc, /*offset=*/0, /*size=*/40);
|
||||
|
|
@ -231,12 +265,9 @@ TEST(ThunkExecutorTest, Execute) {
|
|||
std::vector<std::string> trace;
|
||||
|
||||
ThunkSequence sequence;
|
||||
sequence.push_back(AddI32Thunk::Create("a", {slice0}, {slice0},
|
||||
/*inject_error=*/false, &trace));
|
||||
sequence.push_back(AddI32Thunk::Create("b", {slice1}, {slice1},
|
||||
/*inject_error=*/false, &trace));
|
||||
sequence.push_back(AddI32Thunk::Create("c", {slice2}, {slice2},
|
||||
/*inject_error=*/false, &trace));
|
||||
sequence.push_back(AddI32Thunk::Create("a", {slice0}, {slice0}, &trace));
|
||||
sequence.push_back(AddI32Thunk::Create("b", {slice1}, {slice1}, &trace));
|
||||
sequence.push_back(AddI32Thunk::Create("c", {slice2}, {slice2}, &trace));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor,
|
||||
ThunkExecutor::Create(std::move(sequence)));
|
||||
|
|
@ -281,7 +312,7 @@ struct GeneratedThunkSequence {
|
|||
|
||||
static absl::StatusOr<std::unique_ptr<GeneratedThunkSequence>>
|
||||
GenerateThunkSequence(size_t num_elements, size_t num_thunks,
|
||||
bool inject_errors = false) {
|
||||
bool inject_errors, bool inject_side_effects) {
|
||||
auto g = std::make_unique<GeneratedThunkSequence>(GeneratedThunkSequence{
|
||||
BufferAllocation(/*index=*/0, num_elements * sizeof(int32_t), 0),
|
||||
BufferAllocation(/*index=*/1, num_elements * sizeof(int32_t), 0),
|
||||
|
|
@ -316,8 +347,9 @@ GenerateThunkSequence(size_t num_elements, size_t num_thunks,
|
|||
TF_RETURN_IF_ERROR(AddI32Thunk::Execute(&allocations, src, dst));
|
||||
|
||||
bool inject_error = inject_errors && inject_error_dist(engine) == 0;
|
||||
g->sequence.push_back(
|
||||
AddI32Thunk::Create(absl::StrCat(i), {src}, {dst}, inject_error));
|
||||
g->sequence.push_back(AddI32Thunk::Create(absl::StrCat(i), {src}, {dst},
|
||||
/*trace=*/nullptr, inject_error,
|
||||
inject_side_effects));
|
||||
}
|
||||
|
||||
return g;
|
||||
|
|
@ -326,10 +358,12 @@ GenerateThunkSequence(size_t num_elements, size_t num_thunks,
|
|||
// Parameterized thunk executor stress tests that builds a random thunk sequence
|
||||
// and optionally uses a thread pool to execute thunk executor tasks.
|
||||
class ThunkExecutorStressTest
|
||||
: public testing::TestWithParam<std::tuple<int32_t, bool, bool, bool>> {
|
||||
: public testing::TestWithParam<
|
||||
std::tuple<int32_t, bool, bool, bool, bool>> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
auto& [_, use_task_runner, use_device, inject_errors] = GetParam();
|
||||
auto& [_, use_task_runner, use_device, inject_errors, inject_side_effects] =
|
||||
GetParam();
|
||||
|
||||
use_task_runner_ = use_task_runner;
|
||||
use_device_ = use_device;
|
||||
|
|
@ -366,11 +400,13 @@ class ThunkExecutorStressTest
|
|||
};
|
||||
|
||||
TEST_P(ThunkExecutorStressTest, Execute) {
|
||||
auto [num_thunks, use_task_runner, use_device, inject_errors] = GetParam();
|
||||
auto [num_thunks, use_task_runner, use_device, inject_errors,
|
||||
inject_side_effects] = GetParam();
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
std::unique_ptr<GeneratedThunkSequence> g,
|
||||
GenerateThunkSequence(/*num_elements=*/1024, num_thunks, inject_errors));
|
||||
GenerateThunkSequence(/*num_elements=*/1024, num_thunks, inject_errors,
|
||||
inject_side_effects));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(ThunkExecutor executor,
|
||||
ThunkExecutor::Create(std::move(g->sequence)));
|
||||
|
|
@ -393,7 +429,7 @@ TEST_P(ThunkExecutorStressTest, Execute) {
|
|||
INSTANTIATE_TEST_SUITE_P(ThunkExecutor, ThunkExecutorStressTest,
|
||||
testing::Combine(testing::ValuesIn({10, 100, 1000}),
|
||||
testing::Bool(), testing::Bool(),
|
||||
testing::Bool()));
|
||||
testing::Bool(), testing::Bool()));
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Performance benchmarks below
|
||||
|
|
@ -402,7 +438,10 @@ INSTANTIATE_TEST_SUITE_P(ThunkExecutor, ThunkExecutorStressTest,
|
|||
static void BM_SyncThunkExecutor(benchmark::State& state) {
|
||||
const size_t num_thunks = state.range(0);
|
||||
|
||||
auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks).value();
|
||||
auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks,
|
||||
/*inject_errors=*/false,
|
||||
/*inject_side_effects=*/false)
|
||||
.value();
|
||||
auto e = ThunkExecutor::Create(std::move(g->sequence)).value();
|
||||
|
||||
BufferAllocations allocations(g->buffers);
|
||||
|
|
@ -422,7 +461,10 @@ static void BM_AsyncThunkExecutor(benchmark::State& state) {
|
|||
Eigen::ThreadPoolDevice device(thread_pool.AsEigenThreadPool(),
|
||||
thread_pool.NumThreads());
|
||||
|
||||
auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks).value();
|
||||
auto g = GenerateThunkSequence(/*num_elements=*/1024, num_thunks,
|
||||
/*inject_errors=*/false,
|
||||
/*inject_side_effects=*/false)
|
||||
.value();
|
||||
auto e = ThunkExecutor::Create(std::move(g->sequence)).value();
|
||||
|
||||
BufferAllocations allocations(g->buffers);
|
||||
|
|
|
|||
2
third_party/xla/xla/service/gpu/BUILD
vendored
2
third_party/xla/xla/service/gpu/BUILD
vendored
|
|
@ -4797,7 +4797,7 @@ cc_library(
|
|||
"@local_config_cuda//cuda:cuda_headers",
|
||||
"@local_config_cuda//cuda:cudnn_header",
|
||||
]) + if_static([
|
||||
"@com_google_protobuf//:protobuf",
|
||||
"@com_google_protobuf//:wrappers_cc_proto",
|
||||
]),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -171,13 +171,9 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
|
|||
auto launch_config = *this->launch_config();
|
||||
launch_dimensions = launch_config.launch_dimensions;
|
||||
|
||||
// TODO(bchetioui): parse block-level parameters from backend config
|
||||
// where available.
|
||||
BlockLevelParameters block_level_parameters;
|
||||
block_level_parameters.output_tile_sizes = std::vector<int64_t>(
|
||||
hlo_computation->root_instruction()->shape().rank() - 1, 1);
|
||||
block_level_parameters.output_tile_sizes.push_back(
|
||||
hlo_computation->root_instruction()->shape().dimensions().back());
|
||||
block_level_parameters.output_tile_sizes =
|
||||
launch_config.output_tile_sizes;
|
||||
block_level_parameters.num_warps =
|
||||
launch_dimensions.num_threads_per_block() / WarpSize();
|
||||
block_level_parameters.num_ctas = 1;
|
||||
|
|
@ -283,6 +279,29 @@ absl::StatusOr<FusionEmissionResult> TritonFusion::Emit(
|
|||
}
|
||||
|
||||
std::optional<TritonFusion::LaunchConfig> TritonFusion::launch_config() const {
|
||||
if (analysis_.fusion_backend_config().has_block_level_fusion_config()) {
|
||||
BlockLevelParameters block_level_parameters =
|
||||
BlockLevelParameters::FromBlockLevelFusionConfig(
|
||||
analysis_.fusion_backend_config().block_level_fusion_config());
|
||||
|
||||
int64_t num_blocks = 1;
|
||||
for (auto [dim_size, dim_tile_size] :
|
||||
llvm::zip(analysis_.fusion_root(0).shape().dimensions(),
|
||||
block_level_parameters.output_tile_sizes)) {
|
||||
num_blocks *= (dim_size + dim_tile_size - 1) / dim_tile_size;
|
||||
}
|
||||
|
||||
LaunchConfig launch_config;
|
||||
launch_config.launch_dimensions = LaunchDimensions(
|
||||
static_cast<uint64_t>(num_blocks),
|
||||
static_cast<uint64_t>(block_level_parameters.num_warps * WarpSize()));
|
||||
launch_config.output_tile_sizes = block_level_parameters.output_tile_sizes;
|
||||
return launch_config;
|
||||
}
|
||||
|
||||
// TODO(shyshkov): Remove the SoftMax heuristic once the block-level fusion
|
||||
// config is fully rolled out. All tiles size should be set before reaching
|
||||
// this point.
|
||||
if (analysis_.fusion_backend_config().kind() == kTritonFusionKind) {
|
||||
// TODO(b/332649307): Change the line below to something more generic that
|
||||
// can handle different instructions (not just Reduce) and different
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ limitations under the License.
|
|||
#include "xla/stream_executor/device_description.h"
|
||||
#include "xla/stream_executor/device_description.pb.h"
|
||||
#include "xla/tests/hlo_test_base.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
|
||||
namespace xla {
|
||||
namespace gpu {
|
||||
|
|
@ -33,9 +32,66 @@ using ::testing::ElementsAre;
|
|||
|
||||
class TritonFusionTest : public HloTestBase {};
|
||||
|
||||
TEST_F(TritonFusionTest, TritonSoftmaxFusion) {
|
||||
TEST_F(TritonFusionTest,
|
||||
TritonFusionWithBlockLevelFusionConfig_LaunchDimensionsAreCorrect) {
|
||||
#ifndef GOOGLE_CUDA
|
||||
GTEST_SKIP() << "Triton fusion only enable for CUDA devices.";
|
||||
GTEST_SKIP() << "Triton fusion only enabled for CUDA devices.";
|
||||
#endif
|
||||
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
HloModule t
|
||||
|
||||
add {
|
||||
Arg_0 = f32[] parameter(0)
|
||||
Arg_1 = f32[] parameter(1)
|
||||
ROOT add = f32[] add(Arg_0, Arg_1)
|
||||
}
|
||||
|
||||
auxiliary_computation {
|
||||
parameter_0 = f32[125]{0} parameter(0)
|
||||
ROOT broadcast = f32[125,127]{1,0} broadcast(parameter_0), dimensions={0}
|
||||
}
|
||||
|
||||
triton_softmax_computation {
|
||||
parameter_0 = f32[125,127]{1,0} parameter(0)
|
||||
multiply_0 = f32[125,127]{1,0} multiply(parameter_0, parameter_0)
|
||||
constant_0 = f32[] constant(0)
|
||||
reduce_0 = f32[125]{0} reduce(multiply_0, constant_0), dimensions={1}, to_apply=add
|
||||
broadcast_4 = f32[125,127]{1,0} broadcast(reduce_0), dimensions={0}
|
||||
ROOT multiply = f32[125,127]{1,0} multiply(multiply_0, broadcast_4)
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
param_0 = f32[125]{0} parameter(0)
|
||||
auxiliary_fusion = f32[125,127]{1,0} fusion(param_0), kind=kLoop, calls=auxiliary_computation
|
||||
ROOT triton_softmax = f32[125,127]{1,0} fusion(auxiliary_fusion), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config":{"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["3","127"],"num_warps":"4"}}}
|
||||
})")
|
||||
.value();
|
||||
|
||||
stream_executor::GpuDeviceInfoProto device_info_proto;
|
||||
stream_executor::DeviceDescription device_info(device_info_proto);
|
||||
|
||||
auto* root = module->entry_computation()->root_instruction();
|
||||
auto analysis_fused =
|
||||
AnalyzeProducerConsumerFusion(*root->operand(0), *root, device_info);
|
||||
|
||||
auto emitter_fused =
|
||||
GetFusionEmitter(PreBufferAssignmentFusionInfo{analysis_fused});
|
||||
auto triton_fusion = dynamic_cast<TritonFusion*>(emitter_fused.get());
|
||||
ASSERT_NE(triton_fusion, nullptr);
|
||||
auto launch_config = triton_fusion->launch_config();
|
||||
ASSERT_NE(launch_config, std::nullopt);
|
||||
EXPECT_EQ(launch_config->launch_dimensions.num_blocks(),
|
||||
/*ceil(125 / 3)=*/42);
|
||||
EXPECT_EQ(launch_config->launch_dimensions.num_threads_per_block(),
|
||||
/*32 * num_warps=*/128);
|
||||
EXPECT_THAT(launch_config->output_tile_sizes, ElementsAre(3, 127));
|
||||
}
|
||||
|
||||
TEST_F(TritonFusionTest,
|
||||
TritonFusionWithoutBlockLevelFusionConfig_LaunchFromSoftMaxHeuristic) {
|
||||
#ifndef GOOGLE_CUDA
|
||||
GTEST_SKIP() << "Triton fusion only enabled for CUDA devices.";
|
||||
#endif
|
||||
|
||||
auto module = ParseAndReturnVerifiedModule(R"(
|
||||
|
|
|
|||
|
|
@ -102,6 +102,7 @@ limitations under the License.
|
|||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
#include "xla/hlo/ir/hlo_opcode.h"
|
||||
#include "xla/hlo/utils/hlo_query.h"
|
||||
#include "xla/layout_util.h"
|
||||
#include "xla/literal.h"
|
||||
#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
|
||||
#include "xla/mlir_hlo/mhlo/transforms/map_mhlo_to_scalar_op.h"
|
||||
|
|
@ -2493,8 +2494,23 @@ MakeTensorPtrOpAndBoundaryChecks CreateMakeTensorPtrOp(
|
|||
llvm::SmallVector<int32_t> order;
|
||||
llvm::SmallVector<int32_t> boundary_checks;
|
||||
|
||||
const std::vector<int64_t>& tile_strides = tiled_hlo.tile_strides();
|
||||
const Shape& shape = tiled_hlo.hlo()->shape();
|
||||
|
||||
// Compute physical strides of the tile. `tile_strides` contains strides for
|
||||
// individual dimensions. We need to convert them to strides in the buffer
|
||||
// taking into account physical layout.
|
||||
// TODO(b/331332678): Compute indexing maps to physical layout indexing in
|
||||
// SymbolicTileAnalysis.
|
||||
llvm::SmallVector<int64_t> physical_strides(tile_strides.size(), 1);
|
||||
int64_t current_stride = 1;
|
||||
for (int64_t cur_dim : LayoutUtil::MinorToMajor(shape)) {
|
||||
physical_strides[cur_dim] = tile_strides[cur_dim] * current_stride;
|
||||
current_stride *= shape.dimensions(cur_dim);
|
||||
}
|
||||
|
||||
for (auto [size, stride] :
|
||||
llvm::zip(tiled_hlo.tile_sizes(), tiled_hlo.tile_strides())) {
|
||||
llvm::zip(tiled_hlo.tile_sizes(), physical_strides)) {
|
||||
if (size == 1) continue;
|
||||
|
||||
int dimension_index = sizes.size();
|
||||
|
|
|
|||
|
|
@ -161,7 +161,7 @@ TEST_F(TritonMakeTensorPtrTest, BlockProperties) {
|
|||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getShape()), ElementsAre(3, 4));
|
||||
EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4));
|
||||
EXPECT_THAT(ptr.boundary_checks, ElementsAre(0));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(1, 1));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(20, 1));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0));
|
||||
EXPECT_THAT(ptr.op.getOrder(), ElementsAre(1, 0));
|
||||
}
|
||||
|
|
@ -170,7 +170,7 @@ TEST_F(TritonMakeTensorPtrTest, BlockProperties) {
|
|||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getShape()), ElementsAre(4, 4));
|
||||
EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4));
|
||||
EXPECT_TRUE(ptr.boundary_checks.empty());
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(1, 1));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(20, 1));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0));
|
||||
EXPECT_THAT(ptr.op.getOrder(), ElementsAre(1, 0));
|
||||
}
|
||||
|
|
@ -199,7 +199,7 @@ TEST_F(TritonMakeTensorPtrTest, BlockProperties) {
|
|||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getShape()), ElementsAre(3, 4));
|
||||
EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4));
|
||||
EXPECT_THAT(ptr.boundary_checks, ElementsAre(0));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(1, 1));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(20, 1));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0));
|
||||
EXPECT_THAT(ptr.op.getOrder(), ElementsAre(1, 0));
|
||||
}
|
||||
|
|
@ -212,7 +212,8 @@ TEST_F(TritonMakeTensorPtrTest, BlockProperties) {
|
|||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getShape()), ElementsAre(3, 4, 6));
|
||||
EXPECT_THAT(TensorShape(ptr.op), ElementsAre(4, 4, 8));
|
||||
EXPECT_THAT(ptr.boundary_checks, ElementsAre(0, 2));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()), ElementsAre(1, 1, 1));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getStrides()),
|
||||
ElementsAre(3000, 150, 1));
|
||||
EXPECT_THAT(ConstOpValuesToInt(ptr.op.getOffsets()), ElementsAre(0, 0, 0));
|
||||
EXPECT_THAT(ptr.op.getOrder(), ElementsAre(2, 1, 0));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5668,6 +5668,60 @@ CHECK: }
|
|||
)"));
|
||||
}
|
||||
|
||||
TEST_F(TritonTest, TestSoftMaxWithTileElementsNotAllContiguous) {
|
||||
const std::string kHloText = R"(
|
||||
HloModule m
|
||||
|
||||
region {
|
||||
param_0 = f32[] parameter(0)
|
||||
param_1 = f32[] parameter(1)
|
||||
ROOT add.1 = f32[] add(param_0, param_1)
|
||||
}
|
||||
|
||||
triton_softmax_computation {
|
||||
constant.1 = f32[] constant(0)
|
||||
broadcast.2 = f32[4,4,8] broadcast(constant.1), dimensions={}
|
||||
param_0.1 = f32[4,4,8] parameter(0)
|
||||
constant = f32[] constant(0)
|
||||
reduce = f32[4,4] reduce(param_0.1, constant), dimensions={2}, to_apply=region
|
||||
broadcast = f32[4,4,8] broadcast(reduce), dimensions={0,1}
|
||||
multiply = f32[4,4,8] multiply(broadcast.2, broadcast)
|
||||
ROOT add.2 = f32[4,4,8] add(multiply, broadcast)
|
||||
}
|
||||
|
||||
ENTRY entry_computation {
|
||||
param_0.2 = f32[4,4,8] parameter(0)
|
||||
ROOT fusion = f32[4,4,8] fusion(param_0.2), kind=kCustom, calls=triton_softmax_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["2","2","8"],"num_warps":"1"}}}
|
||||
})";
|
||||
EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6,
|
||||
/*arel=*/1e-6}));
|
||||
}
|
||||
|
||||
TEST_F(TritonTest, TestSliceWithTileElementsNotAllContiguous) {
|
||||
const std::string kHloText = R"(
|
||||
HloModule m
|
||||
|
||||
region {
|
||||
param_0 = f32[] parameter(0)
|
||||
param_1 = f32[] parameter(1)
|
||||
ROOT add.2 = f32[] add(param_0, param_1)
|
||||
}
|
||||
|
||||
fused_computation {
|
||||
param_0.1 = f32[16,16,32] parameter(0)
|
||||
slice = f32[4,4,8] slice(param_0.1), slice={[2:10:2], [2:6], [3:11]}
|
||||
slice.1 = f32[4,4,8] slice(param_0.1), slice={[4:8], [8:16:2], [13:21]}
|
||||
ROOT add.3 = f32[4,4,8] add(slice, slice.1)
|
||||
}
|
||||
|
||||
ENTRY entry_computation {
|
||||
param_0.2 = f32[16,16,32] parameter(0)
|
||||
ROOT fusion = f32[4,4,8] fusion(param_0.2), kind=kCustom, calls=fused_computation, backend_config={"fusion_backend_config": {"kind":"__triton","block_level_fusion_config":{"output_tile_sizes":["2","2","8"],"num_warps":"1"}}}
|
||||
})";
|
||||
EXPECT_TRUE(RunAndCompare(kHloText, ErrorSpec{/*aabs=*/1e-6,
|
||||
/*arel=*/1e-6}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace gpu
|
||||
} // namespace xla
|
||||
|
|
|
|||
318
third_party/xla/xla/service/hlo_unstacker.cc
vendored
318
third_party/xla/xla/service/hlo_unstacker.cc
vendored
|
|
@ -15,6 +15,7 @@ limitations under the License.
|
|||
|
||||
#include "xla/service/hlo_unstacker.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
#include <deque>
|
||||
#include <functional>
|
||||
|
|
@ -46,7 +47,7 @@ limitations under the License.
|
|||
namespace xla {
|
||||
namespace {
|
||||
|
||||
// TODO(b/342457472): Remove this struct and move its field to the
|
||||
// TODO: b/342457472 - Remove this struct and move its field to the
|
||||
// UnstackerTransformer as static members. A struct that holds the required
|
||||
// information for unstacking that is fixed across different unstacker
|
||||
// instastances.
|
||||
|
|
@ -63,96 +64,141 @@ struct UnstackerMetadata {
|
|||
WhileLoopUnroller::GetUnrollableLoops(module, {});
|
||||
for (const auto& [instr, while_loop_config] : loops) {
|
||||
metadata.unrollable_loop_bodies[instr->while_body()] = while_loop_config;
|
||||
metadata.bodies[instr->while_body()] = instr;
|
||||
}
|
||||
return metadata;
|
||||
}
|
||||
absl::flat_hash_map<HloComputation*, WhileLoopConfig> unrollable_loop_bodies;
|
||||
// A pair of custom pattern and its handler lambda that describes the
|
||||
// transformation needed to unstack the hlo graph for the pattern.
|
||||
std::pair<std::function<const HloInstruction*(
|
||||
const UnstackerMetadata&, const HloInstruction*, int64_t)>,
|
||||
std::function<absl::Status(HloInstruction*, const Shape&)>>
|
||||
custom_handler;
|
||||
absl::flat_hash_map<const HloComputation*, HloInstruction*> bodies;
|
||||
// Vector containing pairs of custom patterns and their corresponding handler
|
||||
// lambdas. The patterns are checked in the order in which they are inserted
|
||||
// into this vector.
|
||||
std::vector<
|
||||
std::pair<std::function<const HloInstruction*(
|
||||
const UnstackerMetadata&, const HloInstruction*, int64_t)>,
|
||||
std::function<absl::Status(HloInstruction*, const Shape&)>>>
|
||||
custom_handlers;
|
||||
};
|
||||
|
||||
// A struct that holds the required information for two-step unstacking. The
|
||||
// content of each instance differs for each operand of a while loop.
|
||||
struct UnstackerTransformer {
|
||||
UnstackerMetadata metadata;
|
||||
static absl::StatusOr<UnstackerTransformer> Create(
|
||||
const UnstackerMetadata& c) {
|
||||
UnstackerTransformer transformer;
|
||||
transformer.metadata = std::move(c);
|
||||
return transformer;
|
||||
}
|
||||
// Performs the two-step unstacking. Each instance of this class is responsible
|
||||
// for a single operand of a while loop.
|
||||
class UnstackerTransformer {
|
||||
public:
|
||||
// Default unroll_factor of -1 indicates full unrolling
|
||||
explicit UnstackerTransformer(const UnstackerMetadata& metadata)
|
||||
: metadata_(metadata) {}
|
||||
|
||||
// Given an instruction and the index of the its changed operand, it applies
|
||||
// the custom handler and populates body_changes lambdas that unstacks the hlo
|
||||
// graph accordingly.
|
||||
bool HandleInstruction(const HloInstruction* instr, int64_t changed_idx) {
|
||||
// Currently, we only unstack operands that are used within fusion
|
||||
// computations.
|
||||
if (instr->opcode() != HloOpcode::kFusion) {
|
||||
return false;
|
||||
}
|
||||
VLOG(3) << "HandleInstruction(" << instr->shape().ToString()
|
||||
<< instr->name() << ", " << changed_idx << ")";
|
||||
|
||||
auto custom_pattern = metadata.custom_handler.first;
|
||||
auto custom_handler = metadata.custom_handler.second;
|
||||
for (const auto& [custom_pattern, custom_handler] :
|
||||
metadata_.custom_handlers) {
|
||||
const HloInstruction* stacked_user =
|
||||
custom_pattern(metadata_, instr, changed_idx);
|
||||
// Try the next pattern if current pattern is not found.
|
||||
if (stacked_user == nullptr) {
|
||||
continue;
|
||||
}
|
||||
if (unstacking_computation_ != nullptr) {
|
||||
VLOG(3) << "Seen multiple users, cannot handle. \n instr: "
|
||||
<< instr->ToString() << "\n hoisted_computation: "
|
||||
<< unstacking_computation_->ToString(
|
||||
HloPrintOptions::Fingerprint());
|
||||
return false;
|
||||
}
|
||||
|
||||
const HloInstruction* stacked_user =
|
||||
custom_pattern(metadata, instr, changed_idx);
|
||||
if (stacked_user == nullptr) {
|
||||
return false;
|
||||
unstacking_computation_ =
|
||||
stacked_user->fused_instructions_computation()->Clone(
|
||||
"hoisted_unstacking");
|
||||
VLOG(3) << "Unstacking computation: "
|
||||
<< unstacking_computation_->ToString(
|
||||
HloPrintOptions::Fingerprint());
|
||||
|
||||
// TODO: b/342440749 - Currently, we assume the stacked dimension is
|
||||
// always the most major dimension. This condition can be checked and
|
||||
// terminate unstacking if not met.
|
||||
Shape slice_shape = stacked_user->shape();
|
||||
int64_t num_layers = stacked_user->operand(0)->shape().dimensions(0);
|
||||
std::vector<Shape> shapes;
|
||||
for (int64_t i = 0; i < num_layers; ++i) {
|
||||
shapes.push_back(slice_shape);
|
||||
}
|
||||
unstacked_shape_ =
|
||||
std::make_unique<Shape>(ShapeUtil::MakeTupleShape(shapes));
|
||||
|
||||
unstacked_instrs_.push_back(instr);
|
||||
|
||||
// Wrapper function around the unstacker lambda which calls the unstacker.
|
||||
std::function<absl::Status()> unstack_wrapper =
|
||||
[&custom_handler = custom_handler, stacked_user,
|
||||
slice_shape]() mutable -> absl::Status {
|
||||
HloInstruction* mutable_dynamic_slicing_fusion =
|
||||
const_cast<HloInstruction*>(stacked_user);
|
||||
return custom_handler(mutable_dynamic_slicing_fusion, slice_shape);
|
||||
};
|
||||
body_changes_.push_back(unstack_wrapper);
|
||||
return true;
|
||||
}
|
||||
if (unstacking_computation != nullptr) {
|
||||
LOG(ERROR) << "Seen multiple users, cannot handle. \n instr: "
|
||||
<< instr->ToString() << "\n hoisted_computation: "
|
||||
<< unstacking_computation->ToString(
|
||||
HloPrintOptions::Fingerprint());
|
||||
return false;
|
||||
}
|
||||
|
||||
unstacking_computation =
|
||||
stacked_user->fused_instructions_computation()->Clone(
|
||||
"hoisted_unstacking");
|
||||
VLOG(3) << "Unstacking computation: "
|
||||
<< unstacking_computation->ToString(HloPrintOptions::Fingerprint());
|
||||
|
||||
// TODO(b/342440749): Currently, we assume the stacked dimension is always
|
||||
// the most major dimension. This condition can be checked and terminate
|
||||
// unstacking if not met.
|
||||
Shape slice_shape = stacked_user->shape();
|
||||
int64_t num_layers = stacked_user->operand(0)->shape().dimensions(0);
|
||||
std::vector<Shape> shapes;
|
||||
for (int64_t i = 0; i < num_layers; ++i) {
|
||||
shapes.push_back(slice_shape);
|
||||
}
|
||||
unstacked_shape =
|
||||
std::make_unique<Shape>(ShapeUtil::MakeTupleShape(shapes));
|
||||
|
||||
// Wrapper function around the unstacker lambda which calls the unstacker.
|
||||
std::function<absl::Status()> unstack_wrapper =
|
||||
[=]() mutable -> absl::Status {
|
||||
HloInstruction* mutable_dynamic_slicing_fusion =
|
||||
const_cast<HloInstruction*>(stacked_user);
|
||||
return custom_handler(mutable_dynamic_slicing_fusion, slice_shape);
|
||||
};
|
||||
body_changes.push_back(unstack_wrapper);
|
||||
return true;
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<const HloInstruction*>& GetUnstackedInstructions() {
|
||||
return unstacked_instrs_;
|
||||
}
|
||||
|
||||
const Shape* GetUnstackedShape() const { return unstacked_shape_.get(); }
|
||||
|
||||
// The function returns a mutable pointer to the unstacking computation since
|
||||
// the pointer is later used to clone the computation.
|
||||
HloComputation* GetUnstackingComputation() const {
|
||||
return unstacking_computation_.get();
|
||||
}
|
||||
|
||||
std::vector<std::function<void(const Shape*)>>& GetLoopChanges() {
|
||||
return loop_changes_;
|
||||
}
|
||||
|
||||
std::vector<std::function<absl::Status()>>& GetBodyChanges() {
|
||||
return body_changes_;
|
||||
}
|
||||
|
||||
absl::flat_hash_map<HloInstruction*, int64_t>& GetOperandChanges() {
|
||||
return operand_changes_;
|
||||
}
|
||||
|
||||
void AddLoopChange(std::function<void(const Shape*)> loop_change) {
|
||||
loop_changes_.push_back(loop_change);
|
||||
}
|
||||
|
||||
private:
|
||||
const UnstackerMetadata& metadata_;
|
||||
// This pointer is populated if the unstacker finds unstackable loop input.
|
||||
std::unique_ptr<Shape> unstacked_shape = nullptr;
|
||||
std::unique_ptr<Shape> unstacked_shape_ = nullptr;
|
||||
// This is a pointer to the computation that is responsible for unstacking. It
|
||||
// is used to hoist the unstacking computations outside the loop bodies.
|
||||
std::unique_ptr<HloComputation> unstacking_computation = nullptr;
|
||||
std::unique_ptr<HloComputation> unstacking_computation_ = nullptr;
|
||||
// A vector of lambdas that describe necessary changes to the shape of the
|
||||
// loops to unstack. The lambdas accept the pointer to the new unstacked
|
||||
// shape.
|
||||
std::vector<std::function<void(const Shape*)>> loop_changes;
|
||||
std::vector<std::function<void(const Shape*)>> loop_changes_;
|
||||
// a list of lambdas that captures all the changes to the hlo graph needed for
|
||||
// unstacking.
|
||||
std::vector<std::function<absl::Status()>> body_changes;
|
||||
std::vector<std::function<absl::Status()>> body_changes_;
|
||||
// A map that tracks the index of the changed operand for instructions of type
|
||||
// get-tuple-element, tuple, and while during unstacking.
|
||||
absl::flat_hash_map<HloInstruction*, int64_t> operand_changes;
|
||||
absl::flat_hash_map<HloInstruction*, int64_t> operand_changes_;
|
||||
// Holds the list of unstacked instructions that will be used to identify
|
||||
// loops that need to be unrolled.
|
||||
std::vector<const HloInstruction*> unstacked_instrs_;
|
||||
};
|
||||
|
||||
bool CanUnstackWhileOperand(const HloInstruction* while_instr,
|
||||
|
|
@ -169,12 +215,12 @@ bool PropagateGteShapeChange(HloInstruction* gte,
|
|||
UnstackerTransformer& unstacker) {
|
||||
VLOG(5) << "PropagateGteShapeChange(" << gte->ToString() << ")";
|
||||
|
||||
// TODO(b/343457903): Use HloDataflowAnalysis to track the usage of a value
|
||||
// TODO: b/343457903 - Use HloDataflowAnalysis to track the usage of a value
|
||||
// instead of manually applying bfs
|
||||
//
|
||||
// Apply BFS to propagate the index of the changed operand.
|
||||
absl::flat_hash_map<HloInstruction*, int64_t>& visited =
|
||||
unstacker.operand_changes;
|
||||
unstacker.GetOperandChanges();
|
||||
std::deque<HloInstruction*> worklist;
|
||||
worklist.push_back(gte);
|
||||
visited.insert({gte, gte->tuple_index()});
|
||||
|
|
@ -283,11 +329,12 @@ bool CanUnstackWhileOperand(const HloInstruction* while_instr,
|
|||
loop->while_condition()->ReplaceParameter(
|
||||
0, HloInstruction::CreateParameter(0, old_shape, "unstacked"));
|
||||
};
|
||||
auto loop_change_wrapper = [=](const Shape* new_shape) {
|
||||
auto loop_change_wrapper = [&loop_change, while_instr,
|
||||
index](const Shape* new_shape) {
|
||||
HloInstruction* mutable_loop = const_cast<HloInstruction*>(while_instr);
|
||||
loop_change(mutable_loop, new_shape, index);
|
||||
};
|
||||
unstacker.loop_changes.push_back(loop_change_wrapper);
|
||||
unstacker.AddLoopChange(loop_change_wrapper);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
|
|
@ -303,7 +350,7 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker,
|
|||
HloInstruction* old_while_input =
|
||||
while_instr->while_init()->mutable_operand(index);
|
||||
|
||||
// TODO(b/341815540): Instead of creating the unstacked tuple for every input
|
||||
// TODO: b/341815540 - Instead of creating the unstacked tuple for every input
|
||||
// index, we should reuse if the input and unstacking computations are the
|
||||
// same.
|
||||
//
|
||||
|
|
@ -312,15 +359,16 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker,
|
|||
std::vector<HloInstruction*> slices;
|
||||
for (int64_t i = 0; i < new_shape->tuple_shapes_size(); ++i) {
|
||||
std::vector<HloInstruction*> operands = {
|
||||
old_while_input,
|
||||
while_instr->AddInstruction(MakeConstantWithShape(
|
||||
unstacker.unstacking_computation->parameter_instruction(1)->shape(),
|
||||
i))};
|
||||
old_while_input, while_instr->AddInstruction(MakeConstantWithShape(
|
||||
unstacker.GetUnstackingComputation()
|
||||
->parameter_instruction(1)
|
||||
->shape(),
|
||||
i))};
|
||||
HloInstruction* slice =
|
||||
while_instr->AddInstruction(HloInstruction::CreateFusion(
|
||||
slice_shape, HloInstruction::FusionKind::kLoop, operands,
|
||||
while_instr->GetModule()->AddEmbeddedComputation(
|
||||
unstacker.unstacking_computation->Clone()),
|
||||
unstacker.GetUnstackingComputation()->Clone()),
|
||||
"hoisted"));
|
||||
slices.push_back(slice);
|
||||
}
|
||||
|
|
@ -335,10 +383,12 @@ void UnstackWhileInput(const UnstackerTransformer& unstacker,
|
|||
|
||||
// Apply the two-step unstacking algorithm to the given while_instr at the given
|
||||
// index.
|
||||
bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata,
|
||||
HloInstruction* while_instr, int64_t index) {
|
||||
UnstackerTransformer unstacker =
|
||||
UnstackerTransformer::Create(metadata).value();
|
||||
bool UnstackWhileOperandAtIndex(
|
||||
const UnstackerMetadata& metadata, HloInstruction* while_instr,
|
||||
int64_t index, std::vector<const HloInstruction*>& unstacked_instructions) {
|
||||
// UnstackerTransformer unstacker =
|
||||
// UnstackerTransformer::Create(metadata).value();
|
||||
UnstackerTransformer unstacker = UnstackerTransformer(metadata);
|
||||
|
||||
// First step of unstacking to determine whether while_instr at index is
|
||||
// unstackable.
|
||||
|
|
@ -357,7 +407,7 @@ bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata,
|
|||
|
||||
// If unstacker has not found an unstackable shape, there is no point in
|
||||
// applying the unstacker changes.
|
||||
if (unstacker.unstacked_shape == nullptr) {
|
||||
if (unstacker.GetUnstackedShape() == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -366,17 +416,17 @@ bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata,
|
|||
//
|
||||
// Update the shape of get-tuple-element, tuple, and, while instructions
|
||||
// based on the unstacked_shape and the index of the changed operand.
|
||||
for (const auto& [instr, index] : unstacker.operand_changes) {
|
||||
for (const auto& [instr, index] : unstacker.GetOperandChanges()) {
|
||||
switch (instr->opcode()) {
|
||||
case HloOpcode::kGetTupleElement:
|
||||
*instr->mutable_shape() = *unstacker.unstacked_shape;
|
||||
*instr->mutable_shape() = *unstacker.GetUnstackedShape();
|
||||
break;
|
||||
case HloOpcode::kTuple:
|
||||
*instr->mutable_shape()->mutable_tuple_shapes(index) =
|
||||
*unstacker.unstacked_shape;
|
||||
*unstacker.GetUnstackedShape();
|
||||
break;
|
||||
case HloOpcode::kWhile:
|
||||
ShapeUtil::UpdateTupleShape(*unstacker.unstacked_shape, index,
|
||||
ShapeUtil::UpdateTupleShape(*unstacker.GetUnstackedShape(), index,
|
||||
instr->mutable_shape());
|
||||
break;
|
||||
default:
|
||||
|
|
@ -384,22 +434,87 @@ bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata,
|
|||
}
|
||||
}
|
||||
// Apply the changes to the body according to the provided custom handler.
|
||||
for (const auto& body_change : unstacker.body_changes) {
|
||||
for (const auto& body_change : unstacker.GetBodyChanges()) {
|
||||
CHECK_OK(body_change());
|
||||
}
|
||||
// Update the input and output shape of the loop.
|
||||
UnstackWhileInput(unstacker, while_instr, unstacker.unstacked_shape.get(),
|
||||
UnstackWhileInput(unstacker, while_instr, unstacker.GetUnstackedShape(),
|
||||
index);
|
||||
const Shape& new_while_shape = while_instr->while_init()->shape();
|
||||
*while_instr->mutable_shape() = new_while_shape;
|
||||
// Apply the changes to the shape of the loop body and condition
|
||||
// computations.
|
||||
for (auto& loop_change : unstacker.loop_changes) {
|
||||
loop_change(unstacker.unstacked_shape.get());
|
||||
for (auto& loop_change : unstacker.GetLoopChanges()) {
|
||||
loop_change(unstacker.GetUnstackedShape());
|
||||
}
|
||||
for (const HloInstruction* instr : unstacker.GetUnstackedInstructions()) {
|
||||
unstacked_instructions.push_back(instr);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// This function recognizes fusions with the following pattern:
|
||||
// fusion(stacked, loop_iteration_var)
|
||||
// computation {
|
||||
// p0 = parameter(0)
|
||||
// p1 = parameter(1)
|
||||
// slice = dynamic_slice(p0, p1, zero, ...)
|
||||
// ROOT bitcast = bitcast(slice)
|
||||
// }
|
||||
const HloInstruction* IsDynamicSlicingFusion(const UnstackerMetadata& metadata,
|
||||
const HloInstruction* instr,
|
||||
int64_t stacked_operand_idx) {
|
||||
CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
|
||||
if (instr->fused_parameters().size() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
if (!metadata.unrollable_loop_bodies.contains(instr->parent())) {
|
||||
VLOG(5) << "Instruction not inside unrollable while body, "
|
||||
<< instr->ToString() << instr->parent()->ToString();
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
WhileLoopConfig while_instr_config =
|
||||
metadata.unrollable_loop_bodies.at(instr->parent());
|
||||
|
||||
for (HloInstruction* fused_instr :
|
||||
instr->fused_instructions_computation()->MakeInstructionPostOrder()) {
|
||||
if (!Match(fused_instr, match::DynamicSlice())) {
|
||||
continue;
|
||||
}
|
||||
std::optional<int64_t> dynamic_index =
|
||||
MatchShapeCoveringDynamicIndexInstruction(
|
||||
fused_instr,
|
||||
instr->fused_instructions_computation()->parameter_instruction(
|
||||
stacked_operand_idx),
|
||||
HloOpcode::kDynamicSlice, while_instr_config);
|
||||
if (dynamic_index.has_value() && dynamic_index.value() == 0) {
|
||||
HloInstruction* bitcast_operand = nullptr;
|
||||
if (Match(instr->fused_instructions_computation()->root_instruction(),
|
||||
match::Bitcast(match::Op(&bitcast_operand)))) {
|
||||
if (bitcast_operand == fused_instr) {
|
||||
return instr;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
absl::Status UnstackDynamicSlicingFusion(
|
||||
HloInstruction* mutable_dynamic_slicing_fusion, const Shape& slice_shape) {
|
||||
HloComputation* parent_loop = mutable_dynamic_slicing_fusion->parent();
|
||||
|
||||
HloInstruction* stacked = mutable_dynamic_slicing_fusion->mutable_operand(0);
|
||||
HloInstruction* offset = mutable_dynamic_slicing_fusion->mutable_operand(1);
|
||||
|
||||
HloInstruction* new_operand =
|
||||
parent_loop->AddInstruction(HloInstruction::CreateCustomCall(
|
||||
slice_shape, {stacked, offset}, "DynamicGte"));
|
||||
return mutable_dynamic_slicing_fusion->ReplaceAllUsesWithDifferentShape(
|
||||
new_operand);
|
||||
}
|
||||
|
||||
// This method checks if the given instruction is a fusion with the following
|
||||
// properties:
|
||||
// 1. It is inside the body of an unrollable loop
|
||||
|
|
@ -413,9 +528,7 @@ bool UnstackWhileOperandAtIndex(const UnstackerMetadata& metadata,
|
|||
const HloInstruction* GetNestedDynamicSlicingFusion(
|
||||
const UnstackerMetadata& metadata, const HloInstruction* instr,
|
||||
int64_t stacked_operand_idx) {
|
||||
if (!Match(instr, match::Fusion())) {
|
||||
return nullptr;
|
||||
}
|
||||
CHECK_EQ(instr->opcode(), HloOpcode::kFusion);
|
||||
|
||||
if (!metadata.unrollable_loop_bodies.contains(instr->parent())) {
|
||||
VLOG(5) << "Instruction not inside unrollable while body, "
|
||||
|
|
@ -536,13 +649,13 @@ absl::StatusOr<bool> HloUnstacker::Run(
|
|||
const absl::flat_hash_set<absl::string_view>& execution_threads) {
|
||||
TF_ASSIGN_OR_RETURN(auto metadata, UnstackerMetadata::Create(module));
|
||||
|
||||
// Custom handler is a pair of pattern and transformation function that
|
||||
// captures different cases of unstacking. It is decoupled from the unstacking
|
||||
// algorithm for modularity.
|
||||
metadata.custom_handler = std::make_pair(GetNestedDynamicSlicingFusion,
|
||||
UnstackNestedDynamicSlicingFusion);
|
||||
metadata.custom_handlers.push_back(
|
||||
std::make_pair(IsDynamicSlicingFusion, UnstackDynamicSlicingFusion));
|
||||
metadata.custom_handlers.push_back(std::make_pair(
|
||||
GetNestedDynamicSlicingFusion, UnstackNestedDynamicSlicingFusion));
|
||||
|
||||
bool unstacked = false;
|
||||
std::vector<const HloInstruction*> unstacked_instructions;
|
||||
for (HloInstruction* instr :
|
||||
module->entry_computation()->MakeInstructionPostOrder()) {
|
||||
if (instr->opcode() != HloOpcode::kWhile) {
|
||||
|
|
@ -552,7 +665,8 @@ absl::StatusOr<bool> HloUnstacker::Run(
|
|||
VLOG(3) << "Attempting to unstack " << instr->name() << " at " << i
|
||||
<< " with stacked shape "
|
||||
<< instr->shape().tuple_shapes(i).ToString();
|
||||
if (UnstackWhileOperandAtIndex(metadata, instr, i)) {
|
||||
if (UnstackWhileOperandAtIndex(metadata, instr, i,
|
||||
unstacked_instructions)) {
|
||||
VLOG(3) << "Unstacked " << instr->name() << " at " << i
|
||||
<< " with stacked shape "
|
||||
<< instr->shape().tuple_shapes(i).ToString();
|
||||
|
|
@ -566,7 +680,19 @@ absl::StatusOr<bool> HloUnstacker::Run(
|
|||
TF_RETURN_IF_ERROR(module->RemoveUnusedComputations());
|
||||
// We rely on the WhileLoopUnroller pass to unroll loop bodies and rewrite
|
||||
// custom-calls created by unstacker, i.e., DynamicGte and DynamicTuple.
|
||||
TF_RETURN_IF_ERROR(WhileLoopUnroller(-1, true).Run(module).status());
|
||||
std::vector<HloInstruction*> loops_to_unroll;
|
||||
for (const HloInstruction* instr : unstacked_instructions) {
|
||||
HloInstruction* loop = metadata.bodies[instr->parent()];
|
||||
if (std::find(loops_to_unroll.begin(), loops_to_unroll.end(), loop) ==
|
||||
loops_to_unroll.end()) {
|
||||
loops_to_unroll.push_back(loop);
|
||||
}
|
||||
}
|
||||
for (HloInstruction* loop : loops_to_unroll) {
|
||||
TF_ASSIGN_OR_RETURN(bool unrolled,
|
||||
WhileLoopUnroller::Unroll(loop, -1, true, true));
|
||||
CHECK(unrolled);
|
||||
}
|
||||
}
|
||||
return unstacked;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ limitations under the License.
|
|||
#include <utility>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "xla/hlo/ir/hlo_computation.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/tests/hlo_test_base.h"
|
||||
#include "tsl/platform/statusor.h"
|
||||
|
|
@ -30,6 +31,62 @@ namespace {
|
|||
|
||||
using UnstackerTest = HloTestBase;
|
||||
|
||||
TEST_F(UnstackerTest, UnstackLoopSingleFusionUser) {
|
||||
std::string hlo_string = R"(
|
||||
HloModule SimpleLoop
|
||||
%fused_computation.slice (param_0.51117: s8[3,128,128], p1: s32[]) ->
|
||||
s8[128,128] {
|
||||
%param_0.51117 = s8[3,128,128] parameter(0)
|
||||
p1 = s32[] parameter(1)
|
||||
%constant.85694 = s32[] constant(0)
|
||||
%dynamic-slice.22040 = s8[1,128,128] dynamic-slice(s8[3,128,128]
|
||||
%param_0.51117, p1, s32[] %constant.85694, s32[] %constant.85694),
|
||||
dynamic_slice_sizes={1,128,128} ROOT %bitcast.31250 = s8[128,128]
|
||||
bitcast(s8[1,128,128] %dynamic-slice.22040)
|
||||
}
|
||||
|
||||
%while.body (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> (s32[],
|
||||
bf16[8,128], s8[3,128,128]) {
|
||||
wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0)
|
||||
i = s32[] get-tuple-element(wide_p), index=0
|
||||
p0 = bf16[8,128] get-tuple-element(wide_p), index=1
|
||||
p1 = s8[3,128,128] get-tuple-element(wide_p), index=2
|
||||
one = s32[] constant(1)
|
||||
inc = s32[] add(i, one)
|
||||
%fusion.67830 = s8[128,128] fusion(s8[3,128,128] p1, i), kind=kLoop,
|
||||
calls=%fused_computation.slice conv = bf16[8,128] convolution(bf16[8,128]
|
||||
p0, s8[128,128] %fusion.67830), dim_labels=bf_io->bf ROOT out = (s32[],
|
||||
bf16[8,128], s8[3,128,128]) tuple(inc, conv, p1)
|
||||
}
|
||||
|
||||
%while.cond (wide_param: (s32[], bf16[8,128], s8[3,128,128])) -> pred[] {
|
||||
wide_p = (s32[], bf16[8,128], s8[3,128,128]) parameter(0)
|
||||
i = s32[] get-tuple-element(wide_p), index=0
|
||||
%constant.12857 = s32[] constant(3)
|
||||
ROOT %compare.1921 = pred[]{:T(512)} compare(s32[] i, s32[]
|
||||
%constant.12857), direction=LT
|
||||
}
|
||||
|
||||
ENTRY main {
|
||||
p0 = s8[3,128,128] parameter(0)
|
||||
p1 = bf16[8,128] parameter(1)
|
||||
init = s32[] constant(0)
|
||||
while.input = (s32[], bf16[8,128], s8[3,128,128]) tuple(init, p1, p0)
|
||||
while.out = (s32[], bf16[8,128], s8[3,128,128]) while(while.input),
|
||||
condition=%while.cond , body=%while.body while_use = s8[3,128,128]
|
||||
get-tuple-element(while.out), index=2 ROOT out = bf16[8,128]
|
||||
get-tuple-element(while.out), index=1
|
||||
}
|
||||
)";
|
||||
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
|
||||
ParseAndReturnVerifiedModule(hlo_string));
|
||||
auto original = module->Clone();
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool unstacked, HloUnstacker().Run(module.get()));
|
||||
EXPECT_TRUE(unstacked);
|
||||
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(original),
|
||||
std::nullopt));
|
||||
}
|
||||
|
||||
TEST_F(UnstackerTest, UnstackLoopSingleNestedFusionUser) {
|
||||
std::string hlo_string = R"(
|
||||
HloModule SimpleLoop
|
||||
|
|
@ -412,7 +469,7 @@ TEST_F(UnstackerTest, UnstackMultipleLoops) {
|
|||
while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight)
|
||||
while.out = (s32[], bf16[8,128], s8[4,128,128]) while(while.input), condition=%while.cond1 , body=%while.body1
|
||||
second.while.input = (s32[], bf16[8,128], s8[4,128,128]) tuple(init, p1, weight)
|
||||
second.while.output = (s32[], bf16[8,128], s8[4,128,128]) while(second.while.input), condition=%while.cond2 , body=%while.body2
|
||||
second.while.out = (s32[], bf16[8,128], s8[4,128,128]) while(second.while.input), condition=%while.cond2 , body=%while.body2
|
||||
ROOT out = bf16[8,128] get-tuple-element(while.out), index=1
|
||||
}
|
||||
)";
|
||||
|
|
|
|||
8
third_party/xla/xla/status.h
vendored
8
third_party/xla/xla/status.h
vendored
|
|
@ -16,6 +16,14 @@ limitations under the License.
|
|||
#ifndef XLA_STATUS_H_
|
||||
#define XLA_STATUS_H_
|
||||
|
||||
#include "absl/status/status.h"
|
||||
|
||||
// This is an obsolete header. Please use absl/status/status.h instead.
|
||||
namespace xla {
|
||||
// NOLINTBEGIN(misc-unused-using-decls)
|
||||
using absl::OkStatus;
|
||||
using absl::Status;
|
||||
// NOLINTEND(misc-unused-using-decls)
|
||||
} // namespace xla
|
||||
|
||||
#endif // XLA_STATUS_H_
|
||||
|
|
|
|||
9
third_party/xla/xla/stream_executor/BUILD
vendored
9
third_party/xla/xla/stream_executor/BUILD
vendored
|
|
@ -4,7 +4,7 @@ load("@local_tsl//tsl/platform:build_config_root.bzl", "if_static")
|
|||
load("@local_tsl//tsl/platform:rules_cc.bzl", "cc_library")
|
||||
load("//xla:xla.bzl", "xla_cc_test")
|
||||
load("//xla/stream_executor:build_defs.bzl", "stream_executor_build_defs_bzl_deps", "stream_executor_friends", "stream_executor_internal")
|
||||
load("//xla/tsl:tsl.bzl", "internal_visibility")
|
||||
load("//xla/tsl:tsl.bzl", "if_google", "internal_visibility")
|
||||
load("//xla/tsl:tsl.default.bzl", "filegroup")
|
||||
|
||||
package(
|
||||
|
|
@ -137,7 +137,8 @@ cc_library(
|
|||
"@local_tsl//tsl/protobuf:dnn_proto_cc",
|
||||
] + if_static([
|
||||
":stream_executor_impl",
|
||||
"@com_google_protobuf//:protobuf", # indirectly-used by dnn.h
|
||||
]) + if_google([
|
||||
"@com_google_protobuf//:wrappers_cc_proto", # indirectly-used by dnn.h
|
||||
]),
|
||||
)
|
||||
|
||||
|
|
@ -225,7 +226,7 @@ cc_library(
|
|||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/types:span",
|
||||
] + if_static(["@com_google_protobuf//:protobuf"]),
|
||||
] + if_google(["@com_google_protobuf//:wrappers_cc_proto"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
@ -398,7 +399,7 @@ cc_library(
|
|||
"@local_tsl//tsl/platform:status",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
"@local_tsl//tsl/protobuf:dnn_proto_cc",
|
||||
] + if_static(["@com_google_protobuf//:protobuf"]),
|
||||
] + if_google(["@com_google_protobuf//:wrappers_cc_proto"]),
|
||||
)
|
||||
|
||||
cc_library(
|
||||
|
|
|
|||
4
third_party/xla/xla/tests/BUILD
vendored
4
third_party/xla/xla/tests/BUILD
vendored
|
|
@ -1718,6 +1718,7 @@ xla_test(
|
|||
"nomac", # b/194731834
|
||||
"nozapfhahn",
|
||||
"optonly",
|
||||
"test_xla_cpu_thunks",
|
||||
],
|
||||
deps = [
|
||||
":client_library_test_base",
|
||||
|
|
@ -2227,12 +2228,15 @@ xla_test(
|
|||
deps = [
|
||||
":hlo_test_base",
|
||||
":xla_internal_test_main",
|
||||
"//xla:error_spec",
|
||||
"//xla:util",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/service:collective_pipeliner",
|
||||
"//xla/service:hlo_dce",
|
||||
"//xla/service:hlo_parser",
|
||||
"//xla/service:hlo_pass_pipeline",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -20,14 +20,18 @@ limitations under the License.
|
|||
|
||||
#include <gtest/gtest.h>
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "xla/error_spec.h"
|
||||
#include "xla/hlo/ir/hlo_computation.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/hlo/ir/hlo_opcode.h"
|
||||
#include "xla/service/collective_pipeliner.h"
|
||||
#include "xla/service/hlo_dce.h"
|
||||
#include "xla/service/hlo_parser.h"
|
||||
#include "xla/service/hlo_pass_pipeline.h"
|
||||
#include "xla/tests/hlo_test_base.h"
|
||||
#include "xla/util.h"
|
||||
|
||||
namespace xla {
|
||||
namespace {
|
||||
|
|
@ -125,6 +129,71 @@ ENTRY entry {
|
|||
ErrorSpec{0.1, 0.1}));
|
||||
}
|
||||
|
||||
TEST_F(CollectivePipelinerExecutionTest, TransformIncrementIndexByOneNoReuse) {
|
||||
constexpr absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add {
|
||||
lhs = bf16[] parameter(0)
|
||||
rhs = bf16[] parameter(1)
|
||||
ROOT add = bf16[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
while_cond {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
gte = s32[] get-tuple-element(param), index=0
|
||||
constant.1 = s32[] constant(3)
|
||||
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
|
||||
}
|
||||
|
||||
while_body {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
|
||||
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
|
||||
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
|
||||
constant.2557 = s32[] constant(1)
|
||||
add.230 = s32[] add(get-tuple-element.394, constant.2557)
|
||||
constant.2559 = s32[] constant(3)
|
||||
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
|
||||
constant.2560 = s32[] constant(-1)
|
||||
add.231 = s32[] add(subtract.139, constant.2560)
|
||||
constant.2561 = s32[] constant(0)
|
||||
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
|
||||
constant.2562 = s32[] constant(2)
|
||||
add.232 = s32[] add(subtract.139, constant.2562)
|
||||
select.1348 = s32[] select(compare.747, add.232, add.231)
|
||||
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
|
||||
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
|
||||
ar.1 = bf16[1,8,128] negate(mul)
|
||||
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, ar.1, select.1348, constant.2561, constant.2561)
|
||||
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
c0 = s32[] constant(0)
|
||||
p0 = bf16[3,8,128] parameter(0)
|
||||
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
|
||||
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
|
||||
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
EXPECT_TRUE(
|
||||
RunOptimizer(module.get(), /*last_run=*/true, /*level_to_operate_on=*/0,
|
||||
/*should_process=*/HloPredicateIsOp<HloOpcode::kNegate>,
|
||||
/*pipelining_direction=*/
|
||||
CollectivePipeliner::PipeliningDirection::kForward,
|
||||
/*pipeline_use_tree=*/false,
|
||||
/*acceptable_formatting=*/HloPredicateTrue,
|
||||
/*reuse_pipelined_op_buffer=*/HloPredicateFalse)
|
||||
.value());
|
||||
XLA_VLOG_LINES(1, module->ToString());
|
||||
XLA_VLOG_LINES(1, module2->ToString());
|
||||
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
|
||||
ErrorSpec{0.1, 0.1}));
|
||||
}
|
||||
|
||||
TEST_F(CollectivePipelinerExecutionTest, PushAgOver) {
|
||||
constexpr absl::string_view hlo_string = R"(
|
||||
HloModule module, entry_computation_layout={(bf16[3,8,128]{2,1,0})->bf16[3,8,128]{2,1,0}}
|
||||
|
|
@ -953,5 +1022,281 @@ ENTRY entry {
|
|||
ErrorSpec{0.1, 0.1}));
|
||||
}
|
||||
|
||||
TEST_F(CollectivePipelinerExecutionTest, TransformIncrementByTwoFormat) {
|
||||
constexpr absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add {
|
||||
lhs = bf16[] parameter(0)
|
||||
rhs = bf16[] parameter(1)
|
||||
ROOT add = bf16[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
while_cond {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
gte = s32[] get-tuple-element(param), index=0
|
||||
constant.1 = s32[] constant(3)
|
||||
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
|
||||
}
|
||||
|
||||
while_body {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
|
||||
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
|
||||
get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2
|
||||
constant.2557 = s32[] constant(1)
|
||||
add.230 = s32[] add(get-tuple-element.394, constant.2557)
|
||||
constant.2559 = s32[] constant(3)
|
||||
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
|
||||
constant.2560 = s32[] constant(-1)
|
||||
add.231 = s32[] add(subtract.139, constant.2560)
|
||||
constant.2561 = s32[] constant(0)
|
||||
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
|
||||
constant.2562 = s32[] constant(2)
|
||||
add.232 = s32[] add(subtract.139, constant.2562)
|
||||
select.1348 = s32[] select(compare.747, add.232, add.231)
|
||||
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
|
||||
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
|
||||
ar.1 = bf16[1,8,128] negate(mul)
|
||||
c = bf16[] constant(5.0)
|
||||
b = bf16[1,8,128] broadcast(c), dimensions={}
|
||||
a = bf16[1,8,128] add(ar.1, b)
|
||||
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, a, select.1348, constant.2561, constant.2561)
|
||||
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
c0 = s32[] constant(0)
|
||||
p0 = bf16[3,8,128] parameter(0)
|
||||
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
|
||||
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
|
||||
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
|
||||
EXPECT_TRUE(
|
||||
RunOptimizer(module.get(), /*last_run=*/true, 0,
|
||||
/*should_process=*/HloPredicateIsOp<HloOpcode::kNegate>,
|
||||
CollectivePipeliner::PipeliningDirection::kForwardSink,
|
||||
/*pipeline_use_tree=*/true)
|
||||
.value());
|
||||
XLA_VLOG_LINES(1, module->ToString());
|
||||
XLA_VLOG_LINES(1, module2->ToString());
|
||||
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
|
||||
ErrorSpec{0.1, 0.1}));
|
||||
}
|
||||
|
||||
TEST_F(CollectivePipelinerExecutionTest, MultiUsesElementwiseMerge) {
|
||||
constexpr absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add {
|
||||
lhs = bf16[] parameter(0)
|
||||
rhs = bf16[] parameter(1)
|
||||
ROOT add = bf16[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
while_cond {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
gte = s32[] get-tuple-element(param), index=0
|
||||
constant.1 = s32[] constant(3)
|
||||
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
|
||||
}
|
||||
|
||||
while_body {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
|
||||
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
|
||||
get-tuple-element.5 = bf16[3,8,128] get-tuple-element(param), index=2
|
||||
constant.2557 = s32[] constant(1)
|
||||
add.230 = s32[] add(get-tuple-element.394, constant.2557)
|
||||
constant.2559 = s32[] constant(3)
|
||||
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
|
||||
constant.2560 = s32[] constant(-1)
|
||||
add.231 = s32[] add(subtract.139, constant.2560)
|
||||
constant.2561 = s32[] constant(0)
|
||||
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
|
||||
constant.2562 = s32[] constant(2)
|
||||
add.232 = s32[] add(subtract.139, constant.2562)
|
||||
select.1348 = s32[] select(compare.747, add.232, add.231)
|
||||
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.5, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
|
||||
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
|
||||
c2 = bf16[] constant(2.0)
|
||||
bc = bf16[1,8,128] broadcast(c2)
|
||||
ar.1 = bf16[1,8,128] sqrt(mul)
|
||||
ar.2 = bf16[1,8,128] negate(mul)
|
||||
mul2 = bf16[1,8,128] multiply(ar.1, bc)
|
||||
mul3 = bf16[1,8,128] multiply(mul2, ar.2)
|
||||
mul4 = bf16[1,8,128] multiply(mul3, mul)
|
||||
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul4, select.1348, constant.2561, constant.2561)
|
||||
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.5)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
c0 = s32[] constant(0)
|
||||
p0 = bf16[3,8,128] parameter(0)
|
||||
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
|
||||
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
|
||||
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
|
||||
EXPECT_TRUE(
|
||||
RunOptimizer(module.get(), /*last_run=*/true, 0,
|
||||
/*should_process=*/
|
||||
HloPredicateIsOp<HloOpcode::kNegate, HloOpcode::kSqrt>,
|
||||
CollectivePipeliner::PipeliningDirection::kForward,
|
||||
/*pipeline_use_tree=*/true)
|
||||
.value());
|
||||
XLA_VLOG_LINES(1, module->ToString());
|
||||
XLA_VLOG_LINES(1, module2->ToString());
|
||||
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
|
||||
ErrorSpec{0.1, 0.1}));
|
||||
}
|
||||
|
||||
TEST_F(CollectivePipelinerExecutionTest, BroadcastAsFormattingOp) {
|
||||
constexpr absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add.1 {
|
||||
lhs = bf16[] parameter(0)
|
||||
rhs = bf16[] parameter(1)
|
||||
ROOT add = bf16[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
while_cond {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
gte = s32[] get-tuple-element(param), index=0
|
||||
constant.1 = s32[] constant(3)
|
||||
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
|
||||
}
|
||||
|
||||
while_body {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
|
||||
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
|
||||
get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2
|
||||
constant.2557 = s32[] constant(1)
|
||||
add.230 = s32[] add(get-tuple-element.394, constant.2557)
|
||||
constant.2559 = s32[] constant(3)
|
||||
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
|
||||
constant.2560 = s32[] constant(-1)
|
||||
add.231 = s32[] add(subtract.139, constant.2560)
|
||||
constant.2561 = s32[] constant(0)
|
||||
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
|
||||
constant.2562 = s32[] constant(2)
|
||||
add.232 = s32[] add(subtract.139, constant.2562)
|
||||
select.1348 = s32[] select(compare.747, add.232, add.231)
|
||||
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
|
||||
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
|
||||
ar.1 = bf16[1,8,128] negate(mul)
|
||||
b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2}
|
||||
constant = bf16[] constant(0)
|
||||
reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1
|
||||
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, reduce, select.1348, constant.2561, constant.2561)
|
||||
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
c0 = s32[] constant(0)
|
||||
p0 = bf16[3,8,128] parameter(0)
|
||||
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
|
||||
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
|
||||
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
|
||||
EXPECT_TRUE(
|
||||
RunOptimizer(module.get(), /*last_run=*/true, 0,
|
||||
/*should_process=*/HloPredicateIsOp<HloOpcode::kNegate>,
|
||||
CollectivePipeliner::PipeliningDirection::kForwardSink,
|
||||
/*pipeline_use_tree=*/true)
|
||||
.value());
|
||||
XLA_VLOG_LINES(1, module->ToString());
|
||||
XLA_VLOG_LINES(1, module2->ToString());
|
||||
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
|
||||
ErrorSpec{0.1, 0.1}));
|
||||
}
|
||||
|
||||
TEST_F(CollectivePipelinerExecutionTest,
|
||||
ForwardSinkDependentPipelineableCollectives) {
|
||||
constexpr absl::string_view hlo_string = R"(
|
||||
HloModule module
|
||||
|
||||
add.1 {
|
||||
lhs = bf16[] parameter(0)
|
||||
rhs = bf16[] parameter(1)
|
||||
ROOT add = bf16[] add(lhs, rhs)
|
||||
}
|
||||
|
||||
while_cond {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
gte = s32[] get-tuple-element(param), index=0
|
||||
constant.1 = s32[] constant(3)
|
||||
ROOT cmp = pred[] compare(gte, constant.1), direction=LT
|
||||
}
|
||||
|
||||
while_body {
|
||||
param = (s32[], bf16[3,8,128], bf16[3,8,128]) parameter(0)
|
||||
get-tuple-element.394 = s32[] get-tuple-element(param), index=0
|
||||
get-tuple-element.395 = bf16[3,8,128] get-tuple-element(param), index=1
|
||||
get-tuple-element.35 = bf16[3,8,128] get-tuple-element(param), index=2
|
||||
constant.2557 = s32[] constant(1)
|
||||
add.230 = s32[] add(get-tuple-element.394, constant.2557)
|
||||
constant.2559 = s32[] constant(3)
|
||||
subtract.139 = s32[] subtract(constant.2559, get-tuple-element.394)
|
||||
constant.2560 = s32[] constant(-1)
|
||||
add.231 = s32[] add(subtract.139, constant.2560)
|
||||
constant.2561 = s32[] constant(0)
|
||||
compare.747 = pred[] compare(add.231, constant.2561), direction=LT
|
||||
constant.2562 = s32[] constant(2)
|
||||
add.232 = s32[] add(subtract.139, constant.2562)
|
||||
select.1348 = s32[] select(compare.747, add.232, add.231)
|
||||
dynamic-slice.99 = bf16[1,8,128] dynamic-slice(get-tuple-element.35, select.1348, constant.2561, constant.2561), dynamic_slice_sizes={1,8,128}
|
||||
mul = bf16[1,8,128] multiply(dynamic-slice.99, dynamic-slice.99)
|
||||
ar.1 = bf16[1,8,128] negate(mul)
|
||||
b.1 = bf16[1,8,128,32] broadcast(ar.1), dimensions={0,1,2}
|
||||
constant = bf16[] constant(0)
|
||||
reduce = bf16[1,8,128] reduce(b.1, constant), dimensions={3}, to_apply=add.1
|
||||
ar.2 = bf16[1,8,128] negate(reduce)
|
||||
c1 = bf16[] constant(2.0)
|
||||
bc = bf16[1,8,128] broadcast(c1)
|
||||
mul1 = bf16[1,8,128] multiply(ar.2, bc)
|
||||
mul3 = bf16[1,8,128] multiply(mul1, ar.2)
|
||||
dynamic-update-slice.35 = bf16[3,8,128] dynamic-update-slice(get-tuple-element.395, mul3, select.1348, constant.2561, constant.2561)
|
||||
ROOT tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(add.230, dynamic-update-slice.35, get-tuple-element.35)
|
||||
}
|
||||
|
||||
ENTRY entry {
|
||||
c0 = s32[] constant(0)
|
||||
p0 = bf16[3,8,128] parameter(0)
|
||||
tuple = (s32[], bf16[3,8,128], bf16[3,8,128]) tuple(c0, p0, p0)
|
||||
while = (s32[], bf16[3,8,128], bf16[3,8,128]) while(tuple), condition=while_cond, body=while_body
|
||||
ROOT gte1 = bf16[3,8,128] get-tuple-element(while), index=1
|
||||
}
|
||||
)";
|
||||
auto module = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
auto module2 = ParseAndReturnUnverifiedModule(hlo_string).value();
|
||||
|
||||
EXPECT_TRUE(
|
||||
RunOptimizer(
|
||||
module.get(), /*last_run=*/true, 0,
|
||||
/*should_process=*/HloPredicateIsOp<HloOpcode::kNegate>,
|
||||
CollectivePipeliner::PipeliningDirection::kForwardSink,
|
||||
/*pipeline_use_tree=*/true,
|
||||
/*acceptable_formatting=*/HloPredicateIsNotOp<HloOpcode::kNegate>)
|
||||
.value());
|
||||
XLA_VLOG_LINES(1, module->ToString());
|
||||
XLA_VLOG_LINES(1, module2->ToString());
|
||||
EXPECT_TRUE(RunAndCompareTwoModules(std::move(module), std::move(module2),
|
||||
ErrorSpec{0.1, 0.1}));
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace xla
|
||||
|
|
|
|||
4
third_party/xla/xla/tools/BUILD
vendored
4
third_party/xla/xla/tools/BUILD
vendored
|
|
@ -797,7 +797,7 @@ tsl_gpu_library(
|
|||
"//xla/service/gpu:gpu_compiler",
|
||||
"//xla/stream_executor/gpu:gpu_init",
|
||||
"//xla/service/gpu:gpu_symbol_repository",
|
||||
]),
|
||||
]) + if_google(["@com_google_protobuf//:duration_cc_proto"]),
|
||||
)
|
||||
|
||||
xla_test(
|
||||
|
|
@ -840,7 +840,7 @@ xla_test(
|
|||
"@local_tsl//tsl/platform:test",
|
||||
"@local_tsl//tsl/protobuf:error_codes_proto_impl_cc",
|
||||
"@local_tsl//tsl/protobuf:status_proto_cc",
|
||||
],
|
||||
] + if_google(["@com_google_protobuf//:duration_cc_proto"]),
|
||||
)
|
||||
|
||||
xla_test(
|
||||
|
|
|
|||
4
third_party/xla/xla/tsl/util/proto/BUILD
vendored
4
third_party/xla/xla/tsl/util/proto/BUILD
vendored
|
|
@ -2,6 +2,7 @@ load(
|
|||
"@local_tsl//tsl/platform:rules_cc.bzl",
|
||||
"cc_library",
|
||||
)
|
||||
load("//xla/tsl:tsl.bzl", "if_google")
|
||||
|
||||
package(
|
||||
# copybara:uncomment default_applicable_licenses = ["//tensorflow:license"],
|
||||
|
|
@ -16,6 +17,5 @@ cc_library(
|
|||
hdrs = ["proto_utils.h"],
|
||||
deps = [
|
||||
"@com_google_absl//absl/time",
|
||||
"@com_google_protobuf//:protobuf_headers",
|
||||
],
|
||||
] + if_google(["@com_google_protobuf//:duration_cc_proto"]),
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user