mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[ROCm] [TunableOp] Track top solutions during tuning process (#147243)
For each set of GEMM parameters that are evaluated by TunableOp, keep track of the top 5 solutions. Print the top 5 solutions when `PYTORCH_TUNABLEOP_VERBOSE=2`. Pull Request resolved: https://github.com/pytorch/pytorch/pull/147243 Approved by: https://github.com/jeffdaily
This commit is contained in:
parent
6c3492b491
commit
0ef2e938d0
|
|
@ -21,6 +21,7 @@
|
|||
#include <string>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
#include <deque>
|
||||
|
||||
namespace at::cuda::tunable {
|
||||
|
||||
|
|
@ -84,6 +85,25 @@ class Stats {
|
|||
double _max;
|
||||
};
|
||||
|
||||
// Bounded history of strings: remembers at most `max_size` of the most
// recently pushed values. Once the bound is reached, every push evicts the
// oldest stored value. Iterating from rbegin() to rend() visits the stored
// values from the most recently pushed to the oldest one still retained.
class FixedSizeStack {
 private:
  std::deque<std::string> stack;
  const size_t max_size;

 public:
  // `size` is the maximum number of values retained. A size of 0 yields a
  // stack that never stores anything.
  explicit FixedSizeStack(size_t size) : max_size(size) {}

  // Appends `value` as the newest entry, evicting the oldest entry first if
  // the stack is already full.
  void push(const std::string& value) {
    if (max_size == 0) {
      // Nothing may be retained. Without this guard the eviction below would
      // call pop_front() on an empty deque, which is undefined behaviour.
      return;
    }
    if (stack.size() >= max_size) {
      stack.pop_front(); // Remove the oldest entry
    }
    stack.push_back(value); // Add new entry
  }

  // Reverse iteration: newest entry first, oldest retained entry last.
  auto rbegin() { return stack.rbegin(); }
  auto rend() { return stack.rend(); }

  // Read-only counterparts so the contents can be walked through a const
  // reference as well.
  auto rbegin() const { return stack.rbegin(); }
  auto rend() const { return stack.rend(); }
};
|
||||
|
||||
} // anonymous namespace
|
||||
|
||||
template <typename ParamsT>
|
||||
|
|
@ -208,6 +228,7 @@ class TunableOp {
|
|||
auto min_duration_ms = std::numeric_limits<double>::infinity();
|
||||
std::string id_name = "Default";
|
||||
ParamsT* reference_params = nullptr;
|
||||
auto top_solns = FixedSizeStack(5);
|
||||
|
||||
// numeric check option is controlled by non-static env var, so check it once per tuned operator
|
||||
bool do_numerics_check = ctx->IsNumericsCheckEnabled();
|
||||
|
|
@ -349,6 +370,8 @@ class TunableOp {
|
|||
" std ", s_stddev);
|
||||
min_duration_ms = s._mean;
|
||||
id_name = op_names_[i];
|
||||
std::string current_soln = std::to_string(s._mean) + " " + op_names_[i];
|
||||
top_solns.push(current_soln);
|
||||
}
|
||||
else {
|
||||
TUNABLE_LOG3("├──found slower instance id=", i, ". " , s._mean, "ms. ", op_names_[i],
|
||||
|
|
@ -367,6 +390,10 @@ class TunableOp {
|
|||
}
|
||||
|
||||
TUNABLE_LOG2("└──found fastest for ", op_sig, '(', params_sig, ") ", id_name);
|
||||
TUNABLE_LOG2("└──top five solutions for ", op_sig, '(', params_sig, ") ");
|
||||
for (auto it = top_solns.rbegin(); it != top_solns.rend(); ++it) {
|
||||
TUNABLE_LOG2(" ", *it);
|
||||
}
|
||||
return ResultEntry(id_name, min_duration_ms);
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user