mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
parent
752a654e9e
commit
4618f903c4
|
|
@ -19,7 +19,6 @@ limitations under the License.
|
|||
#include <functional>
|
||||
|
||||
#include "absl/status/statusor.h"
|
||||
#include "xla/hlo/ir/hlo_instructions.h"
|
||||
#include "xla/hlo/ir/hlo_module.h"
|
||||
#include "xla/hlo/pass/hlo_pass_interface.h"
|
||||
|
||||
|
|
@ -31,8 +30,7 @@ class ReduceScatterDecomposer : public HloModulePass {
|
|||
public:
|
||||
explicit ReduceScatterDecomposer(
|
||||
std::function<void(Shape&)> update_layout = nullptr,
|
||||
std::function<bool(HloReduceScatterInstruction*)> should_decompose =
|
||||
nullptr)
|
||||
std::function<bool(const HloInstruction*)> should_decompose = nullptr)
|
||||
: update_layout_(update_layout), should_decompose_(should_decompose) {}
|
||||
absl::string_view name() const override {
|
||||
return "reduce-scatter-decomposer";
|
||||
|
|
@ -43,7 +41,7 @@ class ReduceScatterDecomposer : public HloModulePass {
|
|||
HloModule* module,
|
||||
const absl::flat_hash_set<absl::string_view>& execution_threads) override;
|
||||
std::function<void(Shape&)> update_layout_;
|
||||
std::function<bool(HloReduceScatterInstruction*)> should_decompose_;
|
||||
std::function<bool(const HloInstruction*)> should_decompose_;
|
||||
};
|
||||
|
||||
} // namespace xla
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user