[ROCm] [TunableOp] Track top solutions during tuning process (#147243)

For each set of GEMM parameters that is evaluated by TunableOp, keep track of the top 5 solutions. Print the top 5 solutions when `PYTORCH_TUNABLEOP_VERBOSE=2`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147243
Approved by: https://github.com/jeffdaily
This commit is contained in:
Nichols A. Romero 2025-03-05 09:34:59 +00:00 committed by PyTorch MergeBot
parent 6c3492b491
commit 0ef2e938d0

View File

@ -21,6 +21,7 @@
#include <string>
#include <unordered_map>
#include <vector>
#include <deque>
namespace at::cuda::tunable {
@ -84,6 +85,25 @@ class Stats {
double _max;
};
// Bounded history of strings. Keeps only the `max_size` most recently pushed
// entries; once full, each push evicts the oldest entry. Iterating from
// rbegin() to rend() visits the retained entries newest-first.
class FixedSizeStack {
 private:
  std::deque<std::string> stack; // oldest entry at the front, newest at the back
  const size_t max_size;         // maximum number of entries retained

 public:
  // `size` is the maximum number of entries to retain. A size of 0 yields a
  // container that never stores anything.
  FixedSizeStack(size_t size) : max_size(size) {}

  // Records `value` as the newest entry, dropping the oldest one if the
  // container is already at capacity.
  void push(const std::string& value) {
    if (max_size == 0) {
      // Zero capacity: nothing can be retained. Returning early also avoids
      // calling pop_front() on an empty deque, which is undefined behavior.
      return;
    }
    if (stack.size() >= max_size) {
      stack.pop_front(); // Remove the oldest entry
    }
    stack.push_back(value); // Add new entry
  }

  // Reverse iteration over the retained entries (newest first).
  auto rbegin() { return stack.rbegin(); }
  auto rend() { return stack.rend(); }

  // Const overloads so the entries can be read through a const reference.
  auto rbegin() const { return stack.rbegin(); }
  auto rend() const { return stack.rend(); }
};
} // anonymous namespace
template <typename ParamsT>
@ -208,6 +228,7 @@ class TunableOp {
auto min_duration_ms = std::numeric_limits<double>::infinity();
std::string id_name = "Default";
ParamsT* reference_params = nullptr;
auto top_solns = FixedSizeStack(5);
// numeric check option is controlled by non-static env var, so check it once per tuned operator
bool do_numerics_check = ctx->IsNumericsCheckEnabled();
@ -349,6 +370,8 @@ class TunableOp {
" std ", s_stddev);
min_duration_ms = s._mean;
id_name = op_names_[i];
std::string current_soln = std::to_string(s._mean) + " " + op_names_[i];
top_solns.push(current_soln);
}
else {
TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
@ -367,6 +390,10 @@ class TunableOp {
}
TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
TUNABLE_LOG2("└──top five solutions for ", op_sig, '(', params_sig, ") ");
for (auto it = top_solns.rbegin(); it != top_solns.rend(); ++it) {
TUNABLE_LOG2(" ", *it);
}
return ResultEntry(id_name, min_duration_ms);
}