mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Revert D25542799: [PyTorch] Merge CoinflipTLS into RecordFunctionTLS
Test Plan: revert-hammer
Differential Revision:
D25542799 (9ce1df079f)
Original commit changeset: 310f9fd15710
fbshipit-source-id: 51777914422a560e94430a786c86f5de4007a00b
This commit is contained in:
parent
625bc40def
commit
c78fd76f18
|
|
@ -32,20 +32,29 @@ thread_local uint64_t current_thread_id_ = 0;
|
|||
|
||||
// Low probability constant
|
||||
static const double kLowProb = 0.001;
|
||||
struct CoinflipTLS {
|
||||
int tries_left_;
|
||||
std::mt19937 genGeo_;
|
||||
std::mt19937 genZeroOne_;
|
||||
std::geometric_distribution<int> distGeo_;
|
||||
std::uniform_real_distribution<double> distZeroOne_;
|
||||
CoinflipTLS();
|
||||
};
|
||||
|
||||
CoinflipTLS::CoinflipTLS()
|
||||
: tries_left_(0), genGeo_(std::random_device()()), genZeroOne_(std::random_device()()), distGeo_(kLowProb), distZeroOne_(0.0, 1.0) {}
|
||||
thread_local CoinflipTLS coinflip_tls_;
|
||||
|
||||
int sample_geometric() {
|
||||
return rf_tls_.coinflip_state_.distGeo_(rf_tls_.coinflip_state_.genGeo_);
|
||||
return coinflip_tls_.distGeo_(coinflip_tls_.genGeo_);
|
||||
}
|
||||
|
||||
double sample_zero_one() {
|
||||
return rf_tls_.coinflip_state_.distZeroOne_(rf_tls_.coinflip_state_.genZeroOne_);
|
||||
return coinflip_tls_.distZeroOne_(coinflip_tls_.genZeroOne_);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
RecordFunctionTLS::CoinflipTLS::CoinflipTLS()
|
||||
: tries_left_(0), genGeo_(std::random_device()()), genZeroOne_(std::random_device()()), distGeo_(kLowProb), distZeroOne_(0.0, 1.0) {}
|
||||
|
||||
const RecordFunctionTLS& get_record_function_tls_() {
|
||||
return rf_tls_;
|
||||
}
|
||||
|
|
@ -160,11 +169,11 @@ class CallbackManager {
|
|||
// flip for kLowProb with a thread local number of tries tries_left_
|
||||
// sampled from the geometric distribution.
|
||||
if (sampling_prob < kLowProb) {
|
||||
if (rf_tls_.coinflip_state_.tries_left_ == 0) {
|
||||
rf_tls_.coinflip_state_.tries_left_ = sample_geometric();
|
||||
if (coinflip_tls_.tries_left_ == 0) {
|
||||
coinflip_tls_.tries_left_ = sample_geometric();
|
||||
return (sample_zero_one() < sampling_prob / kLowProb);
|
||||
} else {
|
||||
--rf_tls_.coinflip_state_.tries_left_;
|
||||
--coinflip_tls_.tries_left_;
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
|
|
@ -503,11 +512,12 @@ bool shouldRunRecordFunction(bool* pre_sampled) {
|
|||
}
|
||||
|
||||
*pre_sampled = true;
|
||||
if (rf_tls_ptr->coinflip_state_.tries_left_ == 0) {
|
||||
rf_tls_ptr->coinflip_state_.tries_left_ = sample_geometric();
|
||||
auto* coinflip_tls_ptr = &coinflip_tls_;
|
||||
if (coinflip_tls_ptr->tries_left_ == 0) {
|
||||
coinflip_tls_ptr->tries_left_ = sample_geometric();
|
||||
return true;
|
||||
} else {
|
||||
--rf_tls_ptr->coinflip_state_.tries_left_;
|
||||
--coinflip_tls_ptr->tries_left_;
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,9 +5,9 @@
|
|||
#include <c10/macros/Export.h>
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/SmallVector.h>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
|
||||
#include <functional>
|
||||
|
||||
namespace c10 {
|
||||
class CAFFE2_API OperatorHandle;
|
||||
|
|
@ -544,15 +544,8 @@ struct TORCH_API RecordFunctionTLS {
|
|||
|
||||
bool tls_record_function_enabled_ = true;
|
||||
|
||||
struct CoinflipTLS {
|
||||
int tries_left_;
|
||||
std::mt19937 genGeo_;
|
||||
std::mt19937 genZeroOne_;
|
||||
std::geometric_distribution<int> distGeo_;
|
||||
std::uniform_real_distribution<double> distZeroOne_;
|
||||
CoinflipTLS();
|
||||
};
|
||||
CoinflipTLS coinflip_state_;
|
||||
// Stores the number of coin flips before the next successful coin flip
|
||||
int tries_left_ = 0;
|
||||
};
|
||||
|
||||
TORCH_API const RecordFunctionTLS& get_record_function_tls_();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user