Revert D25542799: [PyTorch] Merge CoinflipTLS into RecordFunctionTLS

Test Plan: revert-hammer

Differential Revision:
D25542799 (9ce1df079f)

Original commit changeset: 310f9fd15710

fbshipit-source-id: 51777914422a560e94430a786c86f5de4007a00b
This commit is contained in:
Mike Ruberry 2020-12-17 16:35:11 -08:00 committed by Facebook GitHub Bot
parent 625bc40def
commit c78fd76f18
2 changed files with 25 additions and 22 deletions

View File

@ -32,20 +32,29 @@ thread_local uint64_t current_thread_id_ = 0;
// Low probability constant
static const double kLowProb = 0.001;
struct CoinflipTLS {
int tries_left_;
std::mt19937 genGeo_;
std::mt19937 genZeroOne_;
std::geometric_distribution<int> distGeo_;
std::uniform_real_distribution<double> distZeroOne_;
CoinflipTLS();
};
CoinflipTLS::CoinflipTLS()
: tries_left_(0), genGeo_(std::random_device()()), genZeroOne_(std::random_device()()), distGeo_(kLowProb), distZeroOne_(0.0, 1.0) {}
thread_local CoinflipTLS coinflip_tls_;
int sample_geometric() {
return rf_tls_.coinflip_state_.distGeo_(rf_tls_.coinflip_state_.genGeo_);
return coinflip_tls_.distGeo_(coinflip_tls_.genGeo_);
}
double sample_zero_one() {
return rf_tls_.coinflip_state_.distZeroOne_(rf_tls_.coinflip_state_.genZeroOne_);
return coinflip_tls_.distZeroOne_(coinflip_tls_.genZeroOne_);
}
} // namespace
RecordFunctionTLS::CoinflipTLS::CoinflipTLS()
: tries_left_(0), genGeo_(std::random_device()()), genZeroOne_(std::random_device()()), distGeo_(kLowProb), distZeroOne_(0.0, 1.0) {}
const RecordFunctionTLS& get_record_function_tls_() {
return rf_tls_;
}
@ -160,11 +169,11 @@ class CallbackManager {
// flip for kLowProb with a thread local number of tries tries_left_
// sampled from the geometric distribution.
if (sampling_prob < kLowProb) {
if (rf_tls_.coinflip_state_.tries_left_ == 0) {
rf_tls_.coinflip_state_.tries_left_ = sample_geometric();
if (coinflip_tls_.tries_left_ == 0) {
coinflip_tls_.tries_left_ = sample_geometric();
return (sample_zero_one() < sampling_prob / kLowProb);
} else {
--rf_tls_.coinflip_state_.tries_left_;
--coinflip_tls_.tries_left_;
return false;
}
} else {
@ -503,11 +512,12 @@ bool shouldRunRecordFunction(bool* pre_sampled) {
}
*pre_sampled = true;
if (rf_tls_ptr->coinflip_state_.tries_left_ == 0) {
rf_tls_ptr->coinflip_state_.tries_left_ = sample_geometric();
auto* coinflip_tls_ptr = &coinflip_tls_;
if (coinflip_tls_ptr->tries_left_ == 0) {
coinflip_tls_ptr->tries_left_ = sample_geometric();
return true;
} else {
--rf_tls_ptr->coinflip_state_.tries_left_;
--coinflip_tls_ptr->tries_left_;
return false;
}
}

View File

@ -5,9 +5,9 @@
#include <c10/macros/Export.h>
#include <c10/util/Optional.h>
#include <c10/util/SmallVector.h>
#include <functional>
#include <memory>
#include <random>
#include <functional>
namespace c10 {
class CAFFE2_API OperatorHandle;
@ -544,15 +544,8 @@ struct TORCH_API RecordFunctionTLS {
bool tls_record_function_enabled_ = true;
struct CoinflipTLS {
int tries_left_;
std::mt19937 genGeo_;
std::mt19937 genZeroOne_;
std::geometric_distribution<int> distGeo_;
std::uniform_real_distribution<double> distZeroOne_;
CoinflipTLS();
};
CoinflipTLS coinflip_state_;
// Stores the number of coin flips before the next successful coin flip
int tries_left_ = 0;
};
TORCH_API const RecordFunctionTLS& get_record_function_tls_();