Add VariableTensorId, store it in TensorTypeSet (#25597)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25597

We now take advantage of the new bitset representation, TensorTypeSet, to store the "Variable-ness" of a tensor directly in its set of dispatch keys. We introduce a new thread-local "excluded" TensorTypeSet and replace the previous thread-local boolean with it; we no longer have to query `is_variable()` to do dispatch (I didn't delete `is_variable()`, because it still has a lot of call sites). The key change is in `dispatchTypeId`.
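
For orientation, a minimal self-contained sketch of the new dispatch rule, using simplified stand-ins rather than the real c10 types (the real change is the one-line edit to `dispatchTypeId` in the diff below):

#include <cstdint>

// Simplified stand-ins for TensorTypeId / TensorTypeSet.
enum class TypeId : uint8_t { Undefined = 0, CPU, CUDA, Variable, NumIds };

struct TypeSet {
  uint64_t repr = 0;  // bit i set <=> TypeId(i) is in the set
  TypeSet add(TypeId t) const {
    TypeSet r = *this;
    r.repr |= (1ull << static_cast<uint8_t>(t));
    return r;
  }
  TypeSet operator-(TypeSet o) const {  // set difference
    TypeSet r;
    r.repr = repr & ~o.repr;
    return r;
  }
  TypeId highestPriorityTypeId() const {  // largest enum value wins
    for (int i = static_cast<int>(TypeId::NumIds) - 1; i > 0; --i) {
      if (repr & (1ull << i)) return static_cast<TypeId>(i);
    }
    return TypeId::Undefined;  // empty set = undefined tensor
  }
};

// Thread-local set of excluded dispatch keys; autograd puts Variable in here
// while it re-dispatches to the underlying backend kernel.
thread_local TypeSet tls_excluded;

TypeId dispatchTypeId(TypeSet ts) {
  // A Variable CPU tensor carries {CPU, Variable}: it dispatches to the
  // Variable kernel unless Variable is currently excluded, in which case it
  // falls through to CPU.
  return (ts - tls_excluded).highestPriorityTypeId();
}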

Knock-on effects:
* Because Variable is now a TensorTypeId, I can eliminate the out-of-line `registerVariableOp` registration path for variables; instead, the registrar takes a TensorTypeId (rather than a Backend), and variable kernels are simply registered under the Variable key.
* Tensors are never really created with their Variable-ness initialized correctly up front; instead, a tensor "becomes" a Variable when we set its `autograd_meta_`. These setters now correctly set up the invariant on the dispatch type set: if `autograd_meta_ != nullptr`, then `type_set().has(TensorTypeId::VariableTensorId)` (see the excerpt after this list).
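
The core of that invariant, excerpted from the TensorImpl changes further down in this diff (the constructor strips the bit, and the setter derives it from the presence of autograd metadata):

void set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
  autograd_meta_ = std::move(autograd_meta);
  if (autograd_meta_) {
    type_set_ = type_set_.add(TensorTypeId::VariableTensorId);
  } else {
    type_set_ = type_set_.remove(TensorTypeId::VariableTensorId);
  }
}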

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D17265919

Pulled By: ezyang

fbshipit-source-id: a90a7ed14f5cb1086137483ae3d0646fcd4c42d0
Edward Yang authored 2019-09-11 08:55:51 -07:00, committed by Facebook Github Bot
parent ba9fda14a7
commit 2080a15860
17 changed files with 632 additions and 582 deletions

View File

@@ -1,12 +1,15 @@
#pragma once
#include <c10/core/Backend.h>
#include <c10/core/TensorTypeSet.h>
#include <c10/core/Backend.h>
#include <c10/core/impl/LocalTensorTypeSet.h>
#include <unordered_map>
#include <c10/util/C++17.h>
#include <memory>
#include <mutex>
// TODO: Rewrite this comment
//
// This dispatch class serves as a replacement for our previous dispatch
// mechanism, in which all functions were members of a Type class. A derived
// class existed for each backend (and Variable), and the vtable was used to
@@ -19,7 +22,7 @@ namespace at {
namespace impl {
// Take a TensorTypeSet for a Tensor, and combine it with the current thread
// local set of inclusions and exclusions (not yet implemented, coming soon!)
// local valid (implemented) and enabled (not implemented) TensorTypeSets
// to determine what the actual dispatch TensorTypeId should be. Unlike
// Tensor::type_set(), the value of this on a tensor can change depending
// on TLS.
@@ -30,8 +33,7 @@ namespace impl {
// question is whether or not we have access to all the relevant TLS at this
// point.
static inline TensorTypeId dispatchTypeId(TensorTypeSet ts) {
// TODO: Account for TLS!
return ts.highestPriorityTypeId();
return (ts - c10::impl::tls_excluded_tensor_type_set()).highestPriorityTypeId();
}
}
@@ -44,70 +46,53 @@ class CAFFE2_API ATenOpTable {
: schema_(std::move(schema)) {}
template<class FuncType>
FuncType* getOp(TensorTypeSet ts, bool is_variable) const {
if (is_variable) {
return reinterpret_cast<FuncType*>(getVariableOp());
}
return reinterpret_cast<FuncType*>(getBaseOp(tensorTypeIdToBackend(impl::dispatchTypeId(ts))));
FuncType* getOp(TensorTypeSet ts) const {
return reinterpret_cast<FuncType*>(getOp(impl::dispatchTypeId(ts)));
}
private:
void registerOp(Backend backend, void* fn) {
TORCH_CHECK(function_table_[static_cast<int64_t>(backend)] == nullptr,
"Attempting to register variable function for schema ", schema_,
" and backend ", toString(backend),
void registerOp(TensorTypeId tid, void* fn) {
TORCH_CHECK(function_table_[static_cast<int64_t>(tid)] == nullptr,
"Attempting to register function for schema ", schema_,
" and tensor type ", toString(tid),
" but there is already a function registered");
function_table_[static_cast<int64_t>(backend)] = fn;
function_table_[static_cast<int64_t>(tid)] = fn;
}
void registerVariableOp(void* fn) {
TORCH_CHECK(variable_function_ == nullptr,
"Attempting to register variable function for schema ", schema_,
" but there is already a function registered");
variable_function_ = fn;
void* getOp(TensorTypeId tid) const {
// You might think we can minorly optimize this further by maintaining a
// bitmask of registered operator keys, so we don't select dispatch ids
// which don't have implementations here. But the net effect is that if you
// get a Variable CPUTensor, if there is no variable registration, you'll
// fall back to the CPU implementation. Is this what you want? Unlikely...
if (function_table_[static_cast<int64_t>(tid)] == nullptr) {
TORCH_CHECK(function_table_[static_cast<int64_t>(TensorTypeId::UndefinedTensorId)] != nullptr,
"No function is registered for schema ", schema_, " on tensor type ", toString(tid));
return function_table_[static_cast<int64_t>(TensorTypeId::UndefinedTensorId)];
}
void* getBaseOp(Backend backend) const {
if (function_table_[static_cast<int64_t>(backend)] == nullptr) {
TORCH_CHECK(function_table_[static_cast<int64_t>(Backend::Undefined)] != nullptr,
"No function is registered for schema ", schema_, " on backend ", toString(backend));
return function_table_[static_cast<int64_t>(Backend::Undefined)];
}
return function_table_[static_cast<int64_t>(backend)];
}
void* getVariableOp() const {
TORCH_CHECK(variable_function_ != nullptr,
"No variable function registered for ", schema_);
return variable_function_;
return function_table_[static_cast<int64_t>(tid)];
}
friend class ATenDispatch;
std::string schema_;
void* function_table_[static_cast<int64_t>(Backend::NumOptions)] = {nullptr};
void* variable_function_ = nullptr;
void* function_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
};
class CAFFE2_API ATenDispatch {
public:
template<class FuncType>
ATenDispatch& registerOp(Backend backend, const char* schema, FuncType* fn) {
ATenDispatch& registerOp(TensorTypeId id, const char* schema, FuncType* fn) {
std::lock_guard<std::mutex> lock(mutex_);
if (op_tables_.find(schema) == op_tables_.end()) {
op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
}
op_tables_.at(schema).registerOp(backend, reinterpret_cast<void*>(fn));
op_tables_.at(schema).registerOp(id, reinterpret_cast<void*>(fn));
return *this;
}
template<class FuncType>
ATenDispatch& registerVariableOp(const char* schema, FuncType* fn) {
std::lock_guard<std::mutex> lock(mutex_);
if (op_tables_.find(schema) == op_tables_.end()) {
op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
}
op_tables_.at(schema).registerVariableOp(reinterpret_cast<void*>(fn));
return *this;
ATenDispatch& registerOp(Backend b, const char* schema, FuncType* fn) {
return registerOp(backendToTensorTypeId(b), schema, fn);
}
const ATenOpTable* getOpTable(const char* schema) const {
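
For orientation, a hedged usage sketch of the reworked table: the schema string and kernel names below are hypothetical, but `registerOp`, `getOpTable`, and `getOp` are the entry points shown above. Backend kernels and the autograd wrapper now go through the same registrar, keyed by TensorTypeId, and lookup falls back to the UndefinedTensorId slot when no kernel is registered for the dispatched id.

// Assumes we are inside namespace at with ATenDispatch.h included;
// kernels and schema are purely illustrative.
Tensor cpu_my_op(const Tensor& self);
Tensor variable_my_op(const Tensor& self);

static auto& registerer = globalATenDispatch()
  .registerOp<Tensor (const Tensor&)>(
      TensorTypeId::CPUTensorId, "aten::my_op(Tensor self) -> Tensor", &cpu_my_op)
  .registerOp<Tensor (const Tensor&)>(
      TensorTypeId::VariableTensorId, "aten::my_op(Tensor self) -> Tensor", &variable_my_op);

Tensor call_my_op(const Tensor& self) {
  static auto table = globalATenDispatch().getOpTable("aten::my_op(Tensor self) -> Tensor");
  // dispatchTypeId(self.type_set()) resolves to VariableTensorId for a Variable
  // CPU tensor, or to CPUTensorId once VariableTensorId is in the thread-local
  // excluded set.
  return table->getOp<Tensor (const Tensor&)>(self.type_set())(self);
}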

View File

@@ -20,11 +20,11 @@ namespace at {
class CAFFE2_API LegacyTypeDispatch {
public:
void initForTensorTypeSet(TensorTypeSet ts) {
// TODO: When Variable gets turned on in TensorTypeSet, this
// will skip initialization when you initially process a
// Variable CUDA tensor, for example (because I'll get Variable
// and it's not gonna have any device type.) Is that OK?
auto b = tensorTypeIdToBackend(impl::dispatchTypeId(ts));
// TODO: Avoid use of legacyExtractTypeId here. The key
// problem is that you may get a TensorTypeSet with
// VariableTensorId set; should you initialize the "underlying"
// type in that case? Hard to say.
auto b = tensorTypeIdToBackend(legacyExtractTypeId(ts));
auto p = backendToDeviceType(b);
static std::once_flag cpu_once;
static std::once_flag cuda_once;

View File

@@ -218,10 +218,7 @@ class CAFFE2_API Tensor {
DeprecatedTypeProperties & type() const {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
// TODO: When we build in Variable here, we need to change the
// signature of getDeprecatedTypeProperties to collapse backend
// and is_variable into TensorTypeSet
tensorTypeIdToBackend(type_set().highestPriorityTypeId()),
tensorTypeIdToBackend(legacyExtractTypeId(type_set())),
scalar_type(),
is_variable());
}
@@ -932,14 +929,6 @@ inline TensorTypeSet infer_tensor_type_set(TensorList tl) {
return tl[0].type_set();
}
inline bool infer_is_variable(const Tensor & t) {
TORCH_CHECK(t.defined(), "undefined Tensor");
return t.is_variable();
}
inline bool infer_is_variable(const TensorList & tl) {
TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors");
return tl[0].is_variable();
}
} // namespace detail
static inline TensorTypeId legacyExtractTypeId(const Tensor& t) {

File diff suppressed because it is too large.

View File

@@ -216,9 +216,11 @@ private:
if (tensor_list.size() == 0) {
throw std::runtime_error("Tried to dispatch operator " + operator_name + " based on an empty tensor list. When the first tensor argument of an operator is a tensor list, then it must not be empty.");
}
return at::impl::dispatchTypeId(tensor_list[0].type_set());
// TODO: Don't use legacy extractor; blocked on c10 understanding
// variable
return c10::legacyExtractTypeId(tensor_list[0].type_set());
} else {
return at::impl::dispatchTypeId(first_tensor_arg.unsafeToTensorImpl()->type_set());
return c10::legacyExtractTypeId(first_tensor_arg.unsafeToTensorImpl()->type_set());
}
}
};

View File

@@ -113,10 +113,10 @@ ${return_type} ${Type}::${api_name}(${type_method_formals}) {
""")
DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\
.registerOp<${return_type} (${formals_types})>(Backend::Undefined, "${schema_string}", &TypeDefault::${api_name})
.registerOp<${return_type} (${formals_types})>(TensorTypeId::UndefinedTensorId, "${schema_string}", &TypeDefault::${api_name})
""")
BACKEND_FUNCTION_REGISTRATION = CodeTemplate("""\
.registerOp<${return_type} (${formals_types})>(Backend::${Backend}, "${schema_string}", &${Type}::${api_name})
.registerOp<${return_type} (${formals_types})>(TensorTypeId::${Backend}TensorId, "${schema_string}", &${Type}::${api_name})
""")
# Generate a file that lists all functions and their schema string. Used for XLA
@@ -136,7 +136,7 @@ inline ${return_type} Tensor::${api_name}(${method_formals}) const {
${static_dispatch_method_body}
#else
static auto table = globalATenDispatch().getOpTable("${schema_string}");
return table->getOp<${return_type} (${formals_types})>(type_set(), is_variable())(${method_actuals});
return table->getOp<${return_type} (${formals_types})>(type_set())(${method_actuals});
#endif
}
""")
@@ -155,7 +155,7 @@ static inline ${return_type} ${api_name}(${formals}) {
${static_dispatch_function_body}
#else
static auto table = globalATenDispatch().getOpTable("${schema_string}");
return table->getOp<${return_type} (${formals_types})>(${inferred_type_set}, ${inferred_is_variable})(${native_actuals});
return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${native_actuals});
#endif
}
""")
@@ -191,7 +191,7 @@ static inline ${return_type} ${api_name}(${formals}) {
#else
globalLegacyTypeDispatch().initForTensorTypeSet(${inferred_type_set});
static auto table = globalATenDispatch().getOpTable("${schema_string}");
return table->getOp<${return_type} (${formals_types})>(${inferred_type_set}, ${inferred_is_variable})(${native_actuals});
return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${native_actuals});
#endif
}
""")
@@ -552,7 +552,6 @@ FunctionOption = TypedDict('FunctionOption', {
'formals': List[str],
'formals_types': List[str],
'inferred_type_set': str,
'inferred_is_variable': str,
'inplace': bool,
'matches_jit_signature': bool,
# This controls whether or not we generate the interface in Type or
@@ -1095,6 +1094,18 @@ def create_generic(top_env, declarations):
# type: (Any) -> FunctionCode
if isinstance(type_method_dispatch, dict):
static_dispatch_function_switches = []
# NB: As this code is currently written, there will NEVER be
# a backend generated for variable dispatch. There is nothing
# stopping us from actually implementing this, however, if you
# really wanted variable on mobile, there's nothing stopping
# you from implementing this (however, you would have an
# annoying phase problem, since code generation for variable
# happens in tools/ which happens later than here.)
#
# If you pass in a variable to the dispatch, and variable is
# enabled, this switch will fail. This is intentional: you
# probably need to disable variable globally in the mobile
# calling code.
for backend in static_dispatch_backends:
if backend in type_method_dispatch:
static_dispatch_function_switches.append(STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT.substitute(
@@ -1104,10 +1115,6 @@ def create_generic(top_env, declarations):
native_arguments=option['method_actuals']))
static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
option,
# TODO: When Variable gets added, this needs to get adjusted
# to avoid picking up the Variable bit. The correct way
# to encode this is probably to just have Variable in the
# disabled set.
type_set='type_set()',
static_dispatch_function_switches=static_dispatch_function_switches)
else:
@@ -1122,15 +1129,12 @@ def create_generic(top_env, declarations):
# type: (Any, Optional[str], Any) -> FunctionCode
if dispatch_tensor:
option['inferred_type_set'] = 'at::detail::infer_tensor_type_set({})'.format(dispatch_tensor)
option['inferred_is_variable'] = 'at::detail::infer_is_variable({})'.format(dispatch_tensor)
elif dispatch_options:
option['inferred_type_set'] = '{}.type_set()'.format(dispatch_options['name'])
option['inferred_is_variable'] = '{}.is_variable()'.format(dispatch_options['name'])
else:
# doesn't depend on a specific backend, use the empty set
# TODO: Does this actually work?
option['inferred_type_set'] = 'TensorTypeSet()'
option['inferred_is_variable'] = 'false'
declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
fn_declaration = declaration.substitute(option)
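
Putting the templates above together, the generated wrapper loses its second `is_variable` argument and dispatches purely on the inferred TensorTypeSet; roughly (op name and schema illustrative, static-dispatch branch omitted):

static inline Tensor my_op(const Tensor & self) {
  static auto table = globalATenDispatch().getOpTable("aten::my_op(Tensor self) -> Tensor");
  // infer_tensor_type_set(self) replaces the old (type_set(), is_variable()) pair.
  return table->getOp<Tensor (const Tensor &)>(at::detail::infer_tensor_type_set(self))(self);
}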

View File

@@ -218,10 +218,7 @@ class CAFFE2_API Tensor {
DeprecatedTypeProperties & type() const {
return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
// TODO: When we build in Variable here, we need to change the
// signature of getDeprecatedTypeProperties to collapse backend
// and is_variable into TensorTypeSet
tensorTypeIdToBackend(type_set().highestPriorityTypeId()),
tensorTypeIdToBackend(legacyExtractTypeId(type_set())),
scalar_type(),
is_variable());
}
@@ -439,14 +436,6 @@ inline TensorTypeSet infer_tensor_type_set(TensorList tl) {
return tl[0].type_set();
}
inline bool infer_is_variable(const Tensor & t) {
TORCH_CHECK(t.defined(), "undefined Tensor");
return t.is_variable();
}
inline bool infer_is_variable(const TensorList & tl) {
TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors");
return tl[0].is_variable();
}
} // namespace detail
static inline TensorTypeId legacyExtractTypeId(const Tensor& t) {

View File

@@ -2,6 +2,7 @@
#include <c10/core/Backend.h>
#include <c10/core/WrapDimMinimal.h>
#include <c10/core/impl/LocalTensorTypeSet.h>
#include <c10/util/Optional.h>
C10_DEFINE_bool(
@@ -57,7 +58,7 @@ TensorImpl::TensorImpl(Storage&& storage, TensorTypeSet type_set, const caffe2::
numel_(0),
data_type_(data_type),
device_opt_(device_opt),
type_set_(type_set) {
type_set_(type_set.remove(TensorTypeId::VariableTensorId)) {
if (!type_set.empty()) {
AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() ||
device_opt_.has_value());
@@ -210,52 +211,12 @@ int64_t NamedTensorMetaInterface::slow_dim() const {
}
#endif
/// NOTE [ Treating Variables as non-Variables in type dispatch ]
///
/// Previously, in VariableType_*.cpp (generated by gen_variable_type.py), when
/// a function is using the 'use_derived' strategy, we call its implementation
/// on the base non-Variable type (`baseType`), passing unwrapped tensors to the
/// call so that any `.dispatch_type()` calls in the implementation can treat the passed
/// tensors as non-Variables and won't dispatch back to functions in VariableType.
///
/// However, after the Variable/Tensor merge, there is no concept of unwrapping
/// a tensor anymore, and directly passing variables to the base type calls will
/// cause the `.dispatch_type()` dispatch in the implementation to treat the tensor as a
/// variable, and any function dispatch based on `.dispatch_type()` will dispatch back to
/// VariableType, which is not what we want.
///
/// The solution to the above problem is to add `at::NonVariableTypeMode`, which
/// when enabled will cause `legacyTensorType()` and `getType()` to always return
/// non-Variable type, even if the tensor being called on is a variable.
///
/// TODO: Since `torch::NoGradGuard` serves the same purpose in libtorch, we should
/// merge these two thread-local guards.
/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting,
/// thread_local is not supported. In that case, we don't provide
/// `at::NonVariableTypeMode`.
#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY
thread_local bool NonVariableTypeMode_enabled = false;
bool NonVariableTypeMode::is_enabled() {
return NonVariableTypeMode_enabled;
return !impl::tls_variable_is_enabled();
}
void NonVariableTypeMode::set_enabled(bool enabled) {
NonVariableTypeMode_enabled = enabled;
impl::tls_variable_set_enabled(!enabled);
}
#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
bool NonVariableTypeMode::is_enabled() {
throw std::runtime_error("NonVariableTypeMode is not supported on mobile");
}
void NonVariableTypeMode::set_enabled(bool enabled) {
throw std::runtime_error("NonVariableTypeMode is not supported on mobile");
}
#endif
} // namespace c10

View File

@@ -826,6 +826,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
*/
void set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
autograd_meta_ = std::move(autograd_meta);
if (autograd_meta_) {
type_set_ = type_set_.add(TensorTypeId::VariableTensorId);
} else {
type_set_ = type_set_.remove(TensorTypeId::VariableTensorId);
}
}
/**
@@ -839,6 +844,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
* Detach the autograd metadata unique_ptr from this tensor, and return it.
*/
std::unique_ptr<c10::AutogradMetaInterface> detach_autograd_meta() {
type_set_ = type_set_.remove(TensorTypeId::VariableTensorId);
return std::move(autograd_meta_);
}
@@ -1522,7 +1528,15 @@ protected:
dest_impl->storage_offset_ = src_impl->storage_offset_;
dest_impl->data_type_ = src_impl->data_type_;
dest_impl->device_opt_ = src_impl->device_opt_;
// This may temporarily violate invariant that
// type_set_.has(VariableTensorId) iff autograd_meta_ != nullptr...
dest_impl->type_set_ = src_impl->type_set_;
// ...so refresh Variable in autograd_meta_
if (dest_impl->autograd_meta_) {
dest_impl->type_set_ = dest_impl->type_set_.add(TensorTypeId::VariableTensorId);
} else {
dest_impl->type_set_ = dest_impl->type_set_.remove(TensorTypeId::VariableTensorId);
}
dest_impl->is_contiguous_ = src_impl->is_contiguous_;
dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
dest_impl->reserved_ = src_impl->reserved_;
@@ -1543,12 +1557,17 @@ protected:
static const char * const err_msg_tensor_metadata_change_not_allowed;
Storage storage_;
private:
// This pointer points to an AutogradMeta struct that stores autograd-specific fields
// (such as grad_ / grad_fn_ / grad_accumulator_).
// This pointer always has unique ownership (meaning only one TensorImpl can own it
// at a time).
// This is private because we must maintain dispatcher invariants on it
// in type_set_.
std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
protected:
#ifdef BUILD_NAMEDTENSOR
std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
#endif

View File

@@ -378,8 +378,9 @@ struct C10_API TensorOptions {
// Resolves the tensor type set specified by the current construction axes.
TensorTypeSet type_set() const noexcept {
// TODO: This should also contain variable eventually
return TensorTypeSet(computeTensorTypeId());
auto r = TensorTypeSet(computeTensorTypeId());
if (is_variable()) r = r.add(TensorTypeId::VariableTensorId);
return r;
}
inline TensorTypeId computeTensorTypeId() const {

View File

@@ -38,6 +38,8 @@ const char* toString(TensorTypeId t) {
return "ComplexCPUTensorId";
case TensorTypeId::ComplexCUDATensorId:
return "ComplexCUDATensorId";
case TensorTypeId::VariableTensorId:
return "VariableTensorId";
default:
return "UNKNOWN_TENSOR_TYPE_ID";
}

View File

@@ -40,7 +40,12 @@ enum class TensorTypeId : uint8_t {
ComplexCPUTensorId, // PyTorch only
ComplexCUDATensorId, // PyTorch only
// VariableTensorId, // upcoming!
// WARNING! If you add more "wrapper" style tensor ids (tensor
// ids which don't get kernels directly defined in native_functions.yaml;
// examples are tracing or profiling) here, you need to also adjust
// legacyExtractTypeId in c10/core/TensorTypeId.h to mask them out.
VariableTensorId,
NumTensorIds, // Sentinel
};

View File

@@ -33,7 +33,19 @@ namespace c10 {
// An undefined tensor is one with an empty tensor type set.
class TensorTypeSet final {
public:
TensorTypeSet() {}
enum Full { FULL };
enum Raw { RAW };
// NB: default constructor representation as zero is MANDATORY as
// use of TensorTypeSet in TLS requires this.
TensorTypeSet()
: repr_(0) {}
TensorTypeSet(Full)
: repr_(-1) {}
// Public version of TensorTypeSet(uint64_t) API; external users
// must be explicit when they do this!
TensorTypeSet(Raw, uint64_t x)
: repr_(x) {}
explicit TensorTypeSet(TensorTypeId t)
: repr_(t == TensorTypeId::UndefinedTensorId
? 0
@@ -47,6 +59,14 @@ public:
TensorTypeSet operator|(TensorTypeSet other) const {
return TensorTypeSet(repr_ | other.repr_);
}
// Perform set intersection
TensorTypeSet operator&(TensorTypeSet other) const {
return TensorTypeSet(repr_ & other.repr_);
}
// Compute the set difference self - other
TensorTypeSet operator-(TensorTypeSet other) const {
return TensorTypeSet(repr_ & ~other.repr_);
}
// Perform set equality
bool operator==(TensorTypeSet other) const {
return repr_ == other.repr_;
@@ -66,6 +86,7 @@ public:
bool empty() const {
return repr_ == 0;
}
uint64_t raw_repr() { return repr_; }
// Return the type id in this set with the highest priority (i.e.,
// is the largest in the TensorTypeId enum). Intuitively, this
// type id is the one that should handle dispatch (assuming there
@@ -98,10 +119,10 @@ C10_API std::ostream& operator<<(std::ostream&, TensorTypeSet);
// but s.has(VariableTensorId) will evaluate to true if s has VariableTensorId.
// For non-VariableTensorId equality tests, they are indistinguishable.
//
// TODO: this will need to change when we add VariableTensorId to the
// set of IDs put in TensorTypeSet.
// NB: If you add other non-VariableTensorId other keys to this set, you'll
// have to adjust this some more (sorry.)
static inline TensorTypeId legacyExtractTypeId(TensorTypeSet s) {
return s.highestPriorityTypeId();
return s.remove(TensorTypeId::VariableTensorId).highestPriorityTypeId();
}
}
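
A small usage sketch of the set algebra above (`add`, `remove`, and `has` are pre-existing members used elsewhere in this diff; the values are illustrative):

#include <c10/core/TensorTypeSet.h>

void example() {
  using namespace c10;
  // A Variable CPU tensor carries both bits.
  auto ts = TensorTypeSet(TensorTypeId::CPUTensorId).add(TensorTypeId::VariableTensorId);

  // Dispatch subtracts the thread-local excluded set; excluding Variable makes
  // the same tensor resolve to the plain CPU kernel:
  TensorTypeSet excluded(TensorTypeId::VariableTensorId);
  auto for_dispatch = ts - excluded;
  // for_dispatch.highestPriorityTypeId() == TensorTypeId::CPUTensorId

  // Legacy callers keep seeing only the "backend" part of the set:
  // legacyExtractTypeId(ts) == TensorTypeId::CPUTensorId
  (void)for_dispatch;
}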

View File

@@ -0,0 +1,42 @@
#include <c10/core/impl/LocalTensorTypeSet.h>
#include <iostream>
namespace c10 {
namespace impl {
namespace {
/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting,
/// thread_local is not supported. In that case, we don't provide
/// `at::NonVariableTypeMode`.
#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY
// NB: Zero initialized!
thread_local uint64_t raw_excluded;
#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
uint64_t raw_excluded = 0;
#endif
}
TensorTypeSet tls_excluded_tensor_type_set() {
return TensorTypeSet(TensorTypeSet::RAW, raw_excluded);
}
bool tls_variable_is_enabled() {
return !tls_excluded_tensor_type_set().has(TensorTypeId::VariableTensorId);
}
void tls_variable_set_enabled(bool enabled) {
if (enabled) {
raw_excluded = tls_excluded_tensor_type_set().remove(TensorTypeId::VariableTensorId).raw_repr();
} else {
raw_excluded = tls_excluded_tensor_type_set().add(TensorTypeId::VariableTensorId).raw_repr();
}
}
}} // namespace c10::impl

View File

@@ -0,0 +1,22 @@
#include <c10/core/TensorTypeSet.h>
// TLS management for TensorTypeSet
//
// This manages thread-local TensorTypeSet of excluded keys which disqualify
// tensor types from dispatch. Keys which are in this set, even if they appear
// in a list of potential valid keys on a tensor, are not considered for
// dispatch. This is used to, for example, turn off autograd after we have
// handled autograd for a top-level element.
//
// Originally, I implemented this as storing the inverted set, but
// TLS is defined to be zero-initialized, so this doesn't actually work
// (you want the set to be -1 initialized).
namespace c10 {
namespace impl {
C10_API bool tls_variable_is_enabled();
C10_API void tls_variable_set_enabled(bool enabled);
C10_API TensorTypeSet tls_excluded_tensor_type_set();
}} // namespace c10::impl
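
An illustrative RAII guard (not part of this commit) showing how the boolean-style API above maps onto the excluded set:

#include <c10/core/impl/LocalTensorTypeSet.h>

// Save/restore the TLS state around a region where the Variable key should be
// masked out; this is what NonVariableTypeMode::set_enabled now amounts to.
struct ExcludeVariableGuard {
  bool prev_enabled_;
  ExcludeVariableGuard() : prev_enabled_(c10::impl::tls_variable_is_enabled()) {
    // Disabling "variable" puts VariableTensorId into the excluded set, so
    // dispatchTypeId() falls through to the underlying backend key.
    c10::impl::tls_variable_set_enabled(false);
  }
  ~ExcludeVariableGuard() {
    c10::impl::tls_variable_set_enabled(prev_enabled_);
  }
};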

View File

@@ -45,3 +45,11 @@ TEST(TensorTypeSet, Doubleton) {
}
}
}
TEST(TensorTypeSet, Full) {
TensorTypeSet full(TensorTypeSet::FULL);
for (uint8_t i = 1; i < static_cast<uint8_t>(TensorTypeId::NumTensorIds); i++) {
auto tid = static_cast<TensorTypeId>(i);
ASSERT_TRUE(full.has(tid));
}
}

View File

@@ -153,7 +153,7 @@ ${return_type} VariableType::${api_name}(${type_method_formals}) {
""")
WRAPPER_REGISTRATION = CodeTemplate("""\
.registerVariableOp<${return_type} (${formal_types})>("${schema_string}", &VariableType::${api_name})
.registerOp<${return_type} (${formal_types})>(TensorTypeId::VariableTensorId, "${schema_string}", &VariableType::${api_name})
""")
UNPACK_TENSOR = CodeTemplate("""\