Populate the cost for async collective in both async-start and the computation root op.

PiperOrigin-RevId: 822223031
This commit is contained in:
Felix Wang 2025-10-21 12:07:04 -07:00 committed by TensorFlower Gardener
parent 633c3efcf9
commit 2de2bb8581
3 changed files with 59 additions and 22 deletions

View File

@ -906,7 +906,9 @@ xla_cc_test(
deps = [
":sol_gpu_cost_model_stats_collection",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/testlib:hlo_hardware_independent_test_base",
"//xla/service:hlo_cost_analysis",
"//xla/service/gpu:backend_configs_cc",
"//xla/service/gpu:gpu_device_info_for_tests",
"//xla/service/gpu/model/experimental:symbolic_expr",

View File

@ -60,8 +60,18 @@ bool SetReificationCost(HloInstruction* instr, double cost_us) {
return false;
}
auto reification_cost = gpu_config->add_reification_cost();
VLOG(3) << "Setting exec_time_us=" << cost_us << " for " << instr->name()
<< " in SolGpuCostModelStatsCollection";
reification_cost->set_exec_time_us(cost_us);
reification_cost->set_name("sol");
if (instr->opcode() == HloOpcode::kAsyncStart &&
instr->async_wrapped_instruction() != nullptr) {
VLOG(9) << "AsyncStart: Setting reification cost for async start "
<< instr->ToString() << " computation:"
<< instr->async_wrapped_computation()->ToString();
return SetReificationCost(
instr->async_wrapped_computation()->root_instruction(), cost_us);
}
return instr->set_backend_config(*gpu_config).ok();
}
@ -72,9 +82,13 @@ bool RecordReificationCost(HloInstruction& instr,
HloGraphNode from(&instr, /*original_position=*/-1);
HloGraphNode to(instr.users()[0], /*original_position=*/-1);
if (estimator.IsAsyncPair(from, to)) {
VLOG(10) << "Recording reification cost for async pair from: "
<< instr.ToString() << " to: " << instr.users()[0]->ToString();
return SetReificationCost(&instr, estimator.GetLatencyBetween(from, to));
}
}
VLOG(10) << "Recording reification cost for single node: "
<< instr.ToString();
return SetReificationCost(&instr, estimator.NodeCost(&instr));
}

View File

@ -23,10 +23,12 @@ limitations under the License.
#include "absl/log/log.h"
#include "absl/strings/string_view.h"
#include "mlir/IR/MLIRContext.h"
#include "xla/hlo/ir/hlo_instruction.h"
#include "xla/hlo/testlib/hlo_hardware_independent_test_base.h"
#include "xla/service/gpu/backend_configs.pb.h"
#include "xla/service/gpu/gpu_device_info_for_tests.h"
#include "xla/service/gpu/model/experimental/symbolic_expr.h"
#include "xla/service/hlo_cost_analysis.h"
#include "xla/shape.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
@ -43,26 +45,9 @@ using ShapeSizeFn = std::function<int64_t(const Shape&)>;
class SolGpuCostModelStatsCollectionTest
: public HloHardwareIndependentTestBase {
public:
explicit SolGpuCostModelStatsCollectionTest() {
ShapeSizeFn shape_size_bytes =
[&shape_size_bytes](const Shape& shape) -> int64_t {
int64_t shape_size = 0;
if (shape.IsTuple()) {
for (auto& sub_shape : shape.tuple_shapes()) {
shape_size += shape_size_bytes(sub_shape);
}
return shape_size;
}
return ShapeUtil::ByteSizeOfElements(shape);
};
shape_size_fn_ = shape_size_bytes;
}
protected:
se::DeviceDescription device_info_ =
TestGpuDeviceInfo::RTXA6000DeviceInfo(se::CudaComputeCapability(9, 0));
ShapeSizeFn shape_size_fn_;
int pointer_size_ = 8;
mlir::MLIRContext mlir_context_;
SymbolicExprContext symbolic_expr_context_{&mlir_context_};
@ -89,11 +74,11 @@ TEST_F(SolGpuCostModelStatsCollectionTest,
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(
bool changed,
SolGpuCostModelStatsCollection(device_info_, shape_size_fn_,
pointer_size_, &symbolic_expr_context_)
.Run(module.get()));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
SolGpuCostModelStatsCollection(
device_info_, HloCostAnalysis::DefaultShapeSize,
pointer_size_, &symbolic_expr_context_)
.Run(module.get()));
VLOG(1) << module->ToString();
@ -105,6 +90,42 @@ TEST_F(SolGpuCostModelStatsCollectionTest,
->reification_cost(),
ElementsAre(Property(&ReificationCost::exec_time_us, Gt(0))));
}
TEST_F(SolGpuCostModelStatsCollectionTest,
RecordsRuntimeInfoForAsyncStartReduceScatter) {
constexpr absl::string_view kHloText = R"(
HloModule async_rs_test
%add.f32 (x: f32[], y: f32[]) -> f32[] {
%x = f32[] parameter(0)
%y = f32[] parameter(1)
ROOT %add = f32[] add(%x, %y)
}
%async_rs {
%p0 = f32[4096,128256] parameter(0)
ROOT %rs = f32[512,128256] reduce-scatter(%p0), channel_id=1,
replica_groups={{0,1,2,3,4,5,6,7}}, dimensions={0}, to_apply=%add.f32
}
ENTRY main {
%param = f32[4096,128256] parameter(0)
%rs_start = ((f32[4096,128256]), f32[512,128256], u32[])
async-start(%param), calls=%async_rs
ROOT %rs_done = f32[512,128256] async-done(%rs_start)
})";
TF_ASSERT_OK_AND_ASSIGN(auto module, ParseAndReturnVerifiedModule(kHloText));
TF_ASSERT_OK_AND_ASSIGN(bool changed,
SolGpuCostModelStatsCollection(
device_info_, HloCostAnalysis::DefaultShapeSize,
pointer_size_, &symbolic_expr_context_)
.Run(module.get()));
VLOG(1) << module->ToString();
EXPECT_FALSE(changed);
HloInstruction* rs_start = FindInstruction(module.get(), "rs_start");
ASSERT_NE(rs_start, nullptr);
HloComputation* async_comp = rs_start->async_wrapped_computation();
ASSERT_NE(async_comp, nullptr);
HloInstruction* rs_instr = async_comp->root_instruction();
EXPECT_THAT(rs_instr->backend_config<GpuBackendConfig>()->reification_cost(),
ElementsAre(Property(&ReificationCost::exec_time_us, Gt(0))));
}
} // namespace
} // namespace xla::gpu