mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
[XLA] Remove HLO unstacker.
The pass is not used. PiperOrigin-RevId: 824274493
This commit is contained in:
parent
cc5ee2577c
commit
2de4be94aa
40
third_party/xla/xla/service/BUILD
vendored
40
third_party/xla/xla/service/BUILD
vendored
|
|
@ -2743,46 +2743,6 @@ xla_cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "hlo_unstacker",
|
||||
srcs = ["hlo_unstacker.cc"],
|
||||
hdrs = ["hlo_unstacker.h"],
|
||||
deps = [
|
||||
":hlo_creation_utils",
|
||||
":pattern_matcher",
|
||||
":tuple_util",
|
||||
":while_loop_unroller",
|
||||
"//xla:shape_util",
|
||||
"//xla:util",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/hlo/pass:hlo_pass",
|
||||
"@com_google_absl//absl/algorithm:container",
|
||||
"@com_google_absl//absl/container:flat_hash_map",
|
||||
"@com_google_absl//absl/container:flat_hash_set",
|
||||
"@com_google_absl//absl/log",
|
||||
"@com_google_absl//absl/log:check",
|
||||
"@com_google_absl//absl/status",
|
||||
"@com_google_absl//absl/status:statusor",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/strings:string_view",
|
||||
"@local_tsl//tsl/platform:errors",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
xla_cc_test(
|
||||
name = "hlo_unstacker_test",
|
||||
srcs = ["hlo_unstacker_test.cc"],
|
||||
tags = if_google(["requires-net:external"]),
|
||||
deps = [
|
||||
":hlo_unstacker",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/tests:hlo_test_base",
|
||||
"@com_google_googletest//:gtest_main",
|
||||
"@local_tsl//tsl/platform:statusor",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "while_loop_unroller",
|
||||
srcs = ["while_loop_unroller.cc"],
|
||||
|
|
|
|||
1503
third_party/xla/xla/service/hlo_unstacker.cc
vendored
1503
third_party/xla/xla/service/hlo_unstacker.cc
vendored
File diff suppressed because it is too large
Load Diff
100
third_party/xla/xla/service/hlo_unstacker.h
vendored
100
third_party/xla/xla/service/hlo_unstacker.h
vendored
|
|
@ -1,100 +0,0 @@
|
|||
/* Copyright 2024 The OpenXLA Authors.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef XLA_SERVICE_HLO_UNSTACKER_H_
|
||||
#define XLA_SERVICE_HLO_UNSTACKER_H_
|
||||
|
||||
#include <stdbool.h>
|
||||
|
||||
#include <functional>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/status/statusor.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/hlo/pass/hlo_pass_interface.h"
|
||||
|
||||
namespace xla {
|
||||
// This pass implements unstacking for loop operands. Generally speaking,
|
||||
// unstacking is the act of breaking a rank n tensor into n smaller n-1 rank
|
||||
// tensors without changing the semantics of the program. There are different
|
||||
// patterns that can benefit from unstacking. This pass aims to implement such
|
||||
// patterns. The patterns implemented are not exhaustive by any means. Lets
|
||||
// consider a simple example:
|
||||
// In the pattern below, `I` (the most-major dimension in the stacked tensor),
|
||||
// is equal to the trip count of the while loop and `i` is the iteration
|
||||
// variable of the loop. The stacked input is used only as input to a
|
||||
// shape-covering dynamic-slice (check the definition of a shape-covering
|
||||
// dynamic-slice: `tensorflow/compiler/xla/service/while_loop_unroller.h`)
|
||||
//
|
||||
// +-while----------------------------------------------------+
|
||||
// | param = tuple(..., [I,x1,...,xn]stacked, ...) |
|
||||
// | ... |
|
||||
// | [1,x1,...,xn]slice = ds([I,x1,...,xn]stacked, i, 0, ...) |
|
||||
// | ... |
|
||||
// | ops using the slice |
|
||||
// | ... |
|
||||
// | ROOT = tuple(..., stacked, ...) |
|
||||
// +----------------------------------------------------------+
|
||||
//
|
||||
// This pattern can be unstacked and rewritten as following:
|
||||
//
|
||||
// +-while-----------------------------------------------------------------+
|
||||
// | param = tuple(..., ([1,x1,...,xn], ..., [1,x1,...,xn])unstacked, ...) |
|
||||
// | ... |
|
||||
//. | slice_1 = get_tuple_element(unstacked), index=i |
|
||||
// | ops using the slice_i |
|
||||
// | ... |
|
||||
// | ROOT = tuple(..., unstacked, ...) |
|
||||
// +-----------------------------------------------------------------------+
|
||||
//
|
||||
// where the unstacked input is initialized with the slices outside of the loop:
|
||||
// unstacked = tuple(slice_1, ..., slice_n)
|
||||
// To get each slice, the pass introduces a dynamic version of the
|
||||
// kGetTupleElement instruction using a custom-call. This custom-call is then
|
||||
// replaced with a normal get-tuple-element during loop unrolling.
|
||||
//
|
||||
// Below is a high-level overview of the unstacking algorithm:
|
||||
// We unstack a module by unstacking inputs to the while loops within the entry
|
||||
// computation for every index. Given a while loop and a candidate for
|
||||
// unstacking, the algorithm performs the following two steps:
|
||||
// 1. The first step is to determine if unstacking is possible by checking if
|
||||
// the unstacking of the while operand at the given index can be propagated
|
||||
// through the body (and nested bodies if any). Unstacking is possible
|
||||
// if a pair of pattern and handler is provided that can identify and handle
|
||||
// such pattern that involves all the uses of the stacked operand at the given
|
||||
// index.
|
||||
// 2. Apply the unstacking by executing the changes gathered in the first phase.
|
||||
class HloUnstacker : public HloModulePass {
|
||||
public:
|
||||
~HloUnstacker() override = default;
|
||||
|
||||
explicit HloUnstacker(std::function<bool(HloInstruction*)> unfuse_slice =
|
||||
[](HloInstruction* instr) { return true; })
|
||||
: unfuse_slice_(unfuse_slice) {}
|
||||
|
||||
absl::string_view name() const override { return "hlo_unstacker"; }
|
||||
using HloPassInterface::Run;
|
||||
absl::StatusOr<bool> Run(
|
||||
HloModule* module,
|
||||
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
|
||||
|
||||
private:
|
||||
std::function<bool(HloInstruction*)> unfuse_slice_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
||||
#endif // XLA_SERVICE_HLO_UNSTACKER_H_
|
||||
1503
third_party/xla/xla/service/hlo_unstacker_test.cc
vendored
1503
third_party/xla/xla/service/hlo_unstacker_test.cc
vendored
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user