pytorch/torch/lib/c10d/Store.hpp
Luca Wehrstedt a1780432fa Move c10d to libtorch(_cuda) (#59563)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/59563

ghstack-source-id: 131331264

Test Plan: CI

Reviewed By: malfet

Differential Revision: D28932239

fbshipit-source-id: 5df6cdfa5253b15cbbc97039fe672d6d97321e34
2021-06-15 02:01:31 -07:00

85 lines
2.5 KiB
C++

#pragma once
#include <chrono>
#include <cstdint>
#include <functional>
#include <stdexcept>
#include <string>
#include <vector>

#include <c10/macros/Macros.h>
#include <torch/custom_class.h>
namespace c10d {
// Callback type accepted by Store::watchKey(). It is invoked with the key's
// old value and its new value, in that order. Either argument may be empty
// (c10::nullopt): the old value is empty when the key is created, and the new
// value is empty when the key is deleted.
using WatchKeyCallback = std::function<void(
    c10::optional<std::string> /* oldValue */,
    c10::optional<std::string> /* newValue */)>;
// Abstract key-value store interface used by c10d. Subclasses implement the
// storage primitives (set/get/add/deleteKey/check/getNumKeys/wait); the only
// state held by this base class is a default timeout (`timeout_`).
class TORCH_API Store : public torch::CustomClassHolder {
 public:
  // Timeout used when a Store is default-constructed (300 seconds).
  static constexpr std::chrono::milliseconds kDefaultTimeout =
      std::chrono::seconds(300);
  // Zero-length timeout constant. NOTE(review): presumably treated by
  // subclasses as "no deadline" -- confirm against the implementations.
  static constexpr std::chrono::milliseconds kNoTimeout =
      std::chrono::milliseconds::zero();

  // Constructs a store whose timeout is kDefaultTimeout.
  Store() : timeout_(kDefaultTimeout) {}

  // Constructs a store with the given timeout.
  explicit Store(const std::chrono::milliseconds& timeout)
      : timeout_(timeout) {}

  // Defined out of line.
  virtual ~Store();

  // Stores `value` under `key`.
  virtual void set(
      const std::string& key,
      const std::vector<uint8_t>& value) = 0;

  // Optional compare-and-set primitive; subclasses that support it must
  // override this. The base implementation always raises a c10::Error.
  // TORCH_CHECK is used rather than TORCH_INTERNAL_ASSERT because reaching
  // this is a limitation of the concrete Store, not a violated PyTorch
  // invariant (the same convention watchKey() follows below). Parameters
  // are left unnamed to avoid unused-parameter warnings.
  virtual std::vector<uint8_t> compareSet(
      const std::string& /* key */,
      const std::vector<uint8_t>& /* currentValue */,
      const std::vector<uint8_t>& /* newValue */) {
    TORCH_CHECK(false, "Not implemented.");
  }

  // Returns the value stored under `key`. Behavior when the key is absent is
  // up to the subclass (e.g. blocking vs. raising -- NOTE(review): confirm).
  virtual std::vector<uint8_t> get(const std::string& key) = 0;

  // Adds `value` to the integer counter under `key`. NOTE(review): the
  // returned int64_t is presumably the updated counter -- confirm.
  virtual int64_t add(const std::string& key, int64_t value) = 0;

  // Removes `key`. NOTE(review): the result presumably reports whether a key
  // was actually removed -- confirm against subclasses.
  virtual bool deleteKey(const std::string& key) = 0;

  // NOTE(review): presumably returns true iff all `keys` are present.
  virtual bool check(const std::vector<std::string>& keys) = 0;

  // NOTE(review): presumably the number of keys currently stored.
  virtual int64_t getNumKeys() = 0;

  // Waits for `keys`. NOTE(review): presumably bounded by getTimeout().
  virtual void wait(const std::vector<std::string>& keys) = 0;

  // Waits for `keys`, bounded by the explicitly supplied `timeout`.
  virtual void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) = 0;

  // Accessors for the store's timeout; defined out of line.
  virtual const std::chrono::milliseconds& getTimeout() const noexcept;
  virtual void setTimeout(const std::chrono::milliseconds& timeout);

  // watchKey() takes two arguments: key and callback function. The callback
  // should be run whenever the key is changed (create, update, or delete). The
  // callback function takes two parameters: currentValue and newValue, which
  // are optional depending on how the key is changed. These key updates should
  // trigger the callback as follows:
  // CREATE: callback(c10::nullopt, newValue) // null currentValue
  // UPDATE: callback(currentValue, newValue)
  // DELETE: callback(currentValue, c10::nullopt) // null newValue
  // The base implementation always raises a c10::Error.
  virtual void watchKey(
      const std::string& /* unused */,
      WatchKeyCallback /* unused */) {
    TORCH_CHECK(
        false,
        "watchKey only implemented for TCPStore and PrefixStore that wraps TCPStore.");
  }

 protected:
  // Timeout set at construction or via setTimeout().
  std::chrono::milliseconds timeout_;
};
} // namespace c10d