PiperOrigin-RevId: 826722506
This commit is contained in:
A. Unique TensorFlower 2025-10-31 20:10:47 -07:00 committed by TensorFlower Gardener
parent 752a654e9e
commit 4618f903c4

View File

@ -19,7 +19,6 @@ limitations under the License.
#include <functional> #include <functional>
#include "absl/status/statusor.h" #include "absl/status/statusor.h"
#include "xla/hlo/ir/hlo_instructions.h"
#include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/ir/hlo_module.h"
#include "xla/hlo/pass/hlo_pass_interface.h" #include "xla/hlo/pass/hlo_pass_interface.h"
@ -31,8 +30,7 @@ class ReduceScatterDecomposer : public HloModulePass {
public: public:
explicit ReduceScatterDecomposer( explicit ReduceScatterDecomposer(
std::function<void(Shape&)> update_layout = nullptr, std::function<void(Shape&)> update_layout = nullptr,
std::function<bool(HloReduceScatterInstruction*)> should_decompose = std::function<bool(const HloInstruction*)> should_decompose = nullptr)
nullptr)
: update_layout_(update_layout), should_decompose_(should_decompose) {} : update_layout_(update_layout), should_decompose_(should_decompose) {}
absl::string_view name() const override { absl::string_view name() const override {
return "reduce-scatter-decomposer"; return "reduce-scatter-decomposer";
@ -43,7 +41,7 @@ class ReduceScatterDecomposer : public HloModulePass {
HloModule* module, HloModule* module,
const absl::flat_hash_set<absl::string_view>& execution_threads) override; const absl::flat_hash_set<absl::string_view>& execution_threads) override;
std::function<void(Shape&)> update_layout_; std::function<void(Shape&)> update_layout_;
std::function<bool(HloReduceScatterInstruction*)> should_decompose_; std::function<bool(const HloInstruction*)> should_decompose_;
}; };
} // namespace xla } // namespace xla