mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
s/Facebook, Inc./Meta Platforms, Inc/ s/Confidential and proprietary./This source code is licensed under the BSD-style license/ Per https://www.internalfb.com/intern/wiki/Open_Source/Licenses/Straight_BSD/ Also, add linter that prevents adding those in the future Fixes https://github.com/pytorch/pytorch/issues/90187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/90805 Approved by: https://github.com/zpao
138 lines
3.8 KiB
C++
138 lines
3.8 KiB
C++
// Copyright (c) Meta Platforms, Inc. and its affiliates.
|
|
//
|
|
// This source code is licensed under the BSD-style license found in the
|
|
// LICENSE file in the root directory of this source tree.
|
|
|
|
#include <map>
|
|
|
|
#include <c10/util/string_view.h>
|
|
#include <c10/util/Flags.h>
|
|
#include <c10/util/Logging.h>
|
|
|
|
#include "caffe2/utils/knobs.h"
|
|
#include "caffe2/utils/knob_patcher.h"
|
|
|
|
namespace caffe2 {
|
|
namespace detail {
|
|
std::map<c10::string_view, bool*>& getRegisteredKnobs();
|
|
} // namespace detail
|
|
|
|
namespace {
|
|
class PatchNode {
|
|
public:
|
|
PatchNode(c10::string_view name, bool value);
|
|
~PatchNode();
|
|
|
|
std::string name;
|
|
bool oldValue{false};
|
|
// Nodes to form a linked list of existing PatchState objects for this knob.
|
|
// This allows us to restore state correctly even if KnobPatcher objects
|
|
// are destroyed in any arbitrary order.
|
|
PatchNode* prev{nullptr};
|
|
PatchNode* next{nullptr};
|
|
};
|
|
} // namespace
|
|
|
|
// Definition of the PatchState type that KnobPatcher stores in state_.
// It adds no members of its own: it is a PatchNode whose (name, value)
// constructor is inherited unchanged.
class KnobPatcher::PatchState : public PatchNode {
  using PatchNode::PatchNode;
};
|
|
|
|
// Overrides the knob called `name` with `value` for the lifetime of this
// object. The work is done by the PatchState constructor, which registers the
// override with the patcher; an unknown knob name makes it throw
// std::invalid_argument.
KnobPatcher::KnobPatcher(c10::string_view name, bool value)
    : state_(std::make_unique<PatchState>(name, value)) {}
|
|
|
|
// The special members below are defaulted in this translation unit, after
// PatchState has been fully defined above.
// NOTE(review): presumably the header only forward-declares PatchState, which
// is why these cannot be defaulted there -- confirm against knob_patcher.h.

// Moving transfers the PatchState (and therefore the override) to the new
// owner.
KnobPatcher::KnobPatcher(KnobPatcher&&) noexcept = default;
KnobPatcher& KnobPatcher::operator=(KnobPatcher&&) noexcept = default;

// Destroying state_ runs ~PatchNode(), which removes the override.
KnobPatcher::~KnobPatcher() = default;
|
|
|
|
namespace {
|
|
|
|
class Patcher {
|
|
public:
|
|
void patch(PatchNode* node, bool value) {
|
|
std::lock_guard<std::mutex> lock{mutex_};
|
|
|
|
node->oldValue = setKnobValue(node->name, value);
|
|
auto ret = patches_.emplace(node->name, node);
|
|
if (!ret.second) {
|
|
// There was already another patcher for this knob
|
|
// Append the new node to the linked list.
|
|
node->prev = ret.first->second;
|
|
CHECK(!node->prev->next);
|
|
node->prev->next = node;
|
|
ret.first->second = node;
|
|
}
|
|
}
|
|
|
|
void unpatch(PatchNode* node) {
|
|
std::lock_guard<std::mutex> lock{mutex_};
|
|
|
|
// Remove this PatchNode from the linked list
|
|
if (node->prev) {
|
|
node->prev->next = node->next;
|
|
}
|
|
if (node->next) {
|
|
// There was another patch applied after this one.
|
|
node->next->prev = node->prev;
|
|
node->next->oldValue = node->oldValue;
|
|
} else {
|
|
// This was the most recently applied patch for this knob,
|
|
// so restore the knob value.
|
|
setKnobValue(node->name, node->oldValue);
|
|
|
|
// The patches_ map should point to this node.
|
|
// Update it to point to the previous patch, if there is one.
|
|
auto iter = patches_.find(node->name);
|
|
if (iter == patches_.end()) {
|
|
LOG(FATAL) << "patch node not found when unpatching knob value";
|
|
}
|
|
TORCH_CHECK_EQ(iter->second, node);
|
|
if (node->prev) {
|
|
iter->second = node->prev;
|
|
} else {
|
|
patches_.erase(iter);
|
|
}
|
|
}
|
|
}
|
|
|
|
private:
|
|
bool setKnobValue(c10::string_view name, bool value) {
|
|
auto& knobs = caffe2::detail::getRegisteredKnobs();
|
|
auto iter = knobs.find(name);
|
|
if (iter == knobs.end()) {
|
|
throw std::invalid_argument(
|
|
"attempted to patch unknown knob \"" + std::string(name) + "\"");
|
|
}
|
|
bool oldValue = *(iter->second);
|
|
*iter->second = value;
|
|
return oldValue;
|
|
}
|
|
|
|
std::mutex mutex_;
|
|
std::map<std::string, PatchNode*> patches_;
|
|
};
|
|
|
|
// Returns the single Patcher shared by every PatchNode. It is a
// function-local static, so it is constructed the first time this is called.
Patcher& getPatcher() {
  static Patcher instance;
  return instance;
}
|
|
|
|
// Stores the knob name, then asks the shared Patcher to apply `value` and to
// register this node. Patcher::patch() fills in oldValue and, when another
// override for the same knob is already active, prev.
// An unknown knob name makes patch() throw std::invalid_argument, which
// propagates out of this constructor.
PatchNode::PatchNode(c10::string_view knobName, bool value)
    : name{knobName} {
  Patcher& patcher = getPatcher();
  patcher.patch(this, value);
}
|
|
|
|
// Removes this override via the shared Patcher. Nothing is allowed to escape
// the destructor: any std::exception raised while unpatching is turned into a
// fatal log message.
PatchNode::~PatchNode() {
  try {
    Patcher& patcher = getPatcher();
    patcher.unpatch(this);
  } catch (const std::exception& error) {
    // Only reachable through a programming bug: unpatch() calls
    // setKnobValue(), which throws when it cannot find the knob by name.
    // Having an explicit handler here also keeps clang-tidy happy.
    LOG(FATAL) << "error removing knob patch: " << error.what();
  }
}
|
|
|
|
} // namespace
|
|
} // namespace caffe2
|