mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
parent
752a654e9e
commit
4618f903c4
|
|
@ -19,7 +19,6 @@ limitations under the License.
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
|
||||||
#include "absl/status/statusor.h"
|
#include "absl/status/statusor.h"
|
||||||
#include "xla/hlo/ir/hlo_instructions.h"
|
|
||||||
#include "xla/hlo/ir/hlo_module.h"
|
#include "xla/hlo/ir/hlo_module.h"
|
||||||
#include "xla/hlo/pass/hlo_pass_interface.h"
|
#include "xla/hlo/pass/hlo_pass_interface.h"
|
||||||
|
|
||||||
|
|
@ -31,8 +30,7 @@ class ReduceScatterDecomposer : public HloModulePass {
|
||||||
public:
|
public:
|
||||||
explicit ReduceScatterDecomposer(
|
explicit ReduceScatterDecomposer(
|
||||||
std::function<void(Shape&)> update_layout = nullptr,
|
std::function<void(Shape&)> update_layout = nullptr,
|
||||||
std::function<bool(HloReduceScatterInstruction*)> should_decompose =
|
std::function<bool(const HloInstruction*)> should_decompose = nullptr)
|
||||||
nullptr)
|
|
||||||
: update_layout_(update_layout), should_decompose_(should_decompose) {}
|
: update_layout_(update_layout), should_decompose_(should_decompose) {}
|
||||||
absl::string_view name() const override {
|
absl::string_view name() const override {
|
||||||
return "reduce-scatter-decomposer";
|
return "reduce-scatter-decomposer";
|
||||||
|
|
@ -43,7 +41,7 @@ class ReduceScatterDecomposer : public HloModulePass {
|
||||||
HloModule* module,
|
HloModule* module,
|
||||||
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
|
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
|
||||||
std::function<void(Shape&)> update_layout_;
|
std::function<void(Shape&)> update_layout_;
|
||||||
std::function<bool(HloReduceScatterInstruction*)> should_decompose_;
|
std::function<bool(const HloInstruction*)> should_decompose_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace xla
|
} // namespace xla
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user