mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 12:20:11 +01:00
Populate the cost for async collective in both async-start and the computation root op.
PiperOrigin-RevId: 822223031
This commit is contained in:
parent
633c3efcf9
commit
2de2bb8581
2
third_party/xla/xla/service/gpu/model/BUILD
vendored
2
third_party/xla/xla/service/gpu/model/BUILD
vendored
|
|
@ -906,7 +906,9 @@ xla_cc_test(
|
|||
deps = [
|
||||
":sol_gpu_cost_model_stats_collection",
|
||||
"//xla:shape_util",
|
||||
"//xla/hlo/ir:hlo",
|
||||
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
|
||||
"//xla/service:hlo_cost_analysis",
|
||||
"//xla/service/gpu:backend_configs_cc",
|
||||
"//xla/service/gpu:gpu_device_info_for_tests",
|
||||
"//xla/service/gpu/model/experimental:symbolic_expr",
|
||||
|
|
|
|||
|
|
@ -60,8 +60,18 @@ bool SetReificationCost(HloInstruction* instr, double cost_us) {
|
|||
return false;
|
||||
}
|
||||
auto reification_cost = gpu_config->add_reification_cost();
|
||||
VLOG(3) << "Setting exec_time_us=" << cost_us << " for " << instr->name()
|
||||
<< " in SolGpuCostModelStatsCollection";
|
||||
reification_cost->set_exec_time_us(cost_us);
|
||||
reification_cost->set_name("sol");
|
||||
if (instr->opcode() == HloOpcode::kAsyncStart &&
|
||||
instr->async_wrapped_instruction() != nullptr) {
|
||||
VLOG(9) << "AsyncStart: Setting reification cost for async start "
|
||||
<< instr->ToString() << " computation:"
|
||||
<< instr->async_wrapped_computation()->ToString();
|
||||
return SetReificationCost(
|
||||
instr->async_wrapped_computation()->root_instruction(), cost_us);
|
||||
}
|
||||
return instr->set_backend_config(*gpu_config).ok();
|
||||
}
|
||||
|
||||
|
|
@ -72,9 +82,13 @@ bool RecordReificationCost(HloInstruction& instr,
|
|||
HloGraphNode from(&instr, /*original_position=*/-1);
|
||||
HloGraphNode to(instr.users()[0], /*original_position=*/-1);
|
||||
if (estimator.IsAsyncPair(from, to)) {
|
||||
VLOG(10) << "Recording reification cost for async pair from: "
|
||||
<< instr.ToString() << " to: " << instr.users()[0]->ToString();
|
||||
return SetReificationCost(&instr, estimator.GetLatencyBetween(from, to));
|
||||
}
|
||||
}
|
||||
VLOG(10) << "Recording reification cost for single node: "
|
||||
<< instr.ToString();
|
||||
return SetReificationCost(&instr, estimator.NodeCost(&instr));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,10 +23,12 @@ limitations under the License.
|
|||
#include "absl/log/log.h"
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "xla/hlo/ir/hlo_instruction.h"
|
||||
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
|
||||
#include "xla/service/gpu/backend_configs.pb.h"
|
||||
#include "xla/service/gpu/gpu_device_info_for_tests.h"
|
||||
#include "xla/service/gpu/model/experimental/symbolic_expr.h"
|
||||
#include "xla/service/hlo_cost_analysis.h"
|
||||
#include "xla/shape.h"
|
||||
#include "xla/shape_util.h"
|
||||
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
|
||||
|
|
@ -43,26 +45,9 @@ using ShapeSizeFn = std::function<int64_t(const Shape&)>;
|
|||
|
||||
class SolGpuCostModelStatsCollectionTest
|
||||
: public HloHardwareIndependentTestBase {
|
||||
public:
|
||||
explicit SolGpuCostModelStatsCollectionTest() {
|
||||
ShapeSizeFn shape_size_bytes =
|
||||
[&shape_size_bytes](const Shape& shape) -> int64_t {
|
||||
int64_t shape_size = 0;
|
||||
if (shape.IsTuple()) {
|
||||
for (auto& sub_shape : shape.tuple_shapes()) {
|
||||
shape_size += shape_size_bytes(sub_shape);
|
||||
}
|
||||
return shape_size;
|
||||
}
|
||||
return ShapeUtil::ByteSizeOfElements(shape);
|
||||
};
|
||||
shape_size_fn_ = shape_size_bytes;
|
||||
}
|
||||
|
||||
protected:
|
||||
se::DeviceDescription device_info_ =
|
||||
TestGpuDeviceInfo::RTXA6000DeviceInfo(se::CudaComputeCapability(9, 0));
|
||||
ShapeSizeFn shape_size_fn_;
|
||||
int pointer_size_ = 8;
|
||||
mlir::MLIRContext mlir_context_;
|
||||
SymbolicExprContext symbolic_expr_context_{&mlir_context_};
|
||||
|
|
@ -89,11 +74,11 @@ TEST_F(SolGpuCostModelStatsCollectionTest,
|
|||
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
|
||||
|
||||
TF_ASSERT_OK_AND_ASSIGN(
|
||||
bool changed,
|
||||
SolGpuCostModelStatsCollection(device_info_, shape_size_fn_,
|
||||
pointer_size_, &symbolic_expr_context_)
|
||||
.Run(module.get()));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
SolGpuCostModelStatsCollection(
|
||||
device_info_, HloCostAnalysis::DefaultShapeSize,
|
||||
pointer_size_, &symbolic_expr_context_)
|
||||
.Run(module.get()));
|
||||
|
||||
VLOG(1) << module->ToString();
|
||||
|
||||
|
|
@ -105,6 +90,42 @@ TEST_F(SolGpuCostModelStatsCollectionTest,
|
|||
->reification_cost(),
|
||||
ElementsAre(Property(&ReificationCost::exec_time_us, Gt(0))));
|
||||
}
|
||||
TEST_F(SolGpuCostModelStatsCollectionTest,
|
||||
RecordsRuntimeInfoForAsyncStartReduceScatter) {
|
||||
constexpr absl::string_view kHloText = R"(
|
||||
HloModule async_rs_test
|
||||
%add.f32 (x: f32[], y: f32[]) -> f32[] {
|
||||
%x = f32[] parameter(0)
|
||||
%y = f32[] parameter(1)
|
||||
ROOT %add = f32[] add(%x, %y)
|
||||
}
|
||||
%async_rs {
|
||||
%p0 = f32[4096,128256] parameter(0)
|
||||
ROOT %rs = f32[512,128256] reduce-scatter(%p0), channel_id=1,
|
||||
replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, to_apply=%add.f32
|
||||
}
|
||||
ENTRY main {
|
||||
%param = f32[4096,128256] parameter(0)
|
||||
%rs_start = ((f32[4096,128256]), f32[512,128256], u32[])
|
||||
async-start(%param), calls=%async_rs
|
||||
ROOT %rs_done = f32[512,128256] async-done(%rs_start)
|
||||
})";
|
||||
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
|
||||
TF_ASSERT_OK_AND_ASSIGN(bool changed,
|
||||
SolGpuCostModelStatsCollection(
|
||||
device_info_, HloCostAnalysis::DefaultShapeSize,
|
||||
pointer_size_, &symbolic_expr_context_)
|
||||
.Run(module.get()));
|
||||
VLOG(1) << module->ToString();
|
||||
EXPECT_FALSE(changed);
|
||||
HloInstruction* rs_start = FindInstruction(module.get(), "rs_start");
|
||||
ASSERT_NE(rs_start, nullptr);
|
||||
HloComputation* async_comp = rs_start->async_wrapped_computation();
|
||||
ASSERT_NE(async_comp, nullptr);
|
||||
HloInstruction* rs_instr = async_comp->root_instruction();
|
||||
|
||||
EXPECT_THAT(rs_instr->backend_config<GpuBackendConfig>()->reification_cost(),
|
||||
ElementsAre(Property(&ReificationCost::exec_time_us, Gt(0))));
|
||||
}
|
||||
} // namespace
|
||||
} // namespace xla::gpu
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user