Add VariableTensorId, store it in TensorTypeSet (#25597)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25597

We now take advantage of the new bitset representation TensorTypeSet to store "Variable-ness" of a tensor directly in the dispatch key. We introduce a new thread local TensorTypeSet "excluded" and replace the previous thread local boolean with it; we no longer have to query `is_variable()` to do dispatch (I didn't delete `is_variable`, because there are still a lot of uses of it). The key change is in `dispatchTypeId`.

Knock-on effects:

* Because Variable is now a TensorTypeId, I can eliminate the out-of-line registration `registerVariableOp` for variables; instead, make the registrar take a TensorTypeId (instead of a Backend) and you just register under the Variable key.
* Tensors aren't really ever created with Variable information initialized correctly at the start; instead, a tensor "becomes" a Variable because we set its `autograd_meta_`. These setters now correctly set up invariants on the dispatch type set. The new invariant is that if `autograd_meta_ != nullptr`, then `type_set().has(TensorTypeId::VariableTensorId)`.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D17265919

Pulled By: ezyang

fbshipit-source-id: a90a7ed14f5cb1086137483ae3d0646fcd4c42d0
This commit is contained in:
parent ba9fda14a7
commit 2080a15860
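To make the mechanism the summary describes concrete, here is a minimal standalone sketch (simplified stand-in types, not the actual c10/ATen API): Variable-ness is one bit in a per-tensor bitset of dispatch keys, a thread-local "excluded" bitset is subtracted from it, and the highest-priority remaining key decides dispatch.

    // Standalone sketch; Key, KeySet and dispatchKey are simplified stand-ins.
    #include <cstdint>
    #include <iostream>

    enum class Key : uint8_t { Undefined = 0, CPU, CUDA, Variable, NumKeys };

    struct KeySet {
      uint64_t repr = 0;
      KeySet add(Key k) const { return {repr | (1ull << static_cast<uint8_t>(k))}; }
      KeySet operator-(KeySet o) const { return {repr & ~o.repr}; }  // set difference
      Key highestPriorityKey() const {   // the largest enum value present wins
        for (int i = static_cast<int>(Key::NumKeys) - 1; i > 0; --i) {
          if (repr & (1ull << i)) return static_cast<Key>(i);
        }
        return Key::Undefined;
      }
    };

    thread_local KeySet tls_excluded;  // keys masked out on the current thread

    Key dispatchKey(KeySet ts) { return (ts - tls_excluded).highestPriorityKey(); }

    int main() {
      KeySet t = KeySet{}.add(Key::CPU).add(Key::Variable);   // a "CPU Variable" tensor
      std::cout << static_cast<int>(dispatchKey(t)) << "\n";  // 3: Variable handles it first
      tls_excluded = tls_excluded.add(Key::Variable);         // e.g. while inside autograd
      std::cout << static_cast<int>(dispatchKey(t)) << "\n";  // 1: falls through to CPU
    }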
@@ -1,12 +1,15 @@
 #pragma once
 
-#include <c10/core/Backend.h>
+#include <c10/core/TensorTypeSet.h>
+#include <c10/core/Backend.h>
+#include <c10/core/impl/LocalTensorTypeSet.h>
 #include <unordered_map>
 #include <c10/util/C++17.h>
 #include <memory>
 #include <mutex>
 
 // TODO: Rewrite this comment
 //
 // This dispatch class serves as a replacement for our previous dispatch
 // mechanism, in which all functions were members of a Type class. A derived
 // class existed for each backend (and Variable), and the vtable was used to
@@ -19,7 +22,7 @@ namespace at {
 namespace impl {
 
 // Take a TensorTypeSet for a Tensor, and combine it with the current thread
-// local set of inclusions and exclusions (not yet implemented, coming soon!)
+// local valid (implemented) and enabled (not implemented) TensorTypeSets
 // to determine what the actual dispatch TensorTypeId should be. Unlike
 // Tensor::type_set(), the value of this on a tensor can change depending
 // on TLS.
@@ -30,8 +33,7 @@
 // question is whether or not we have access to all the relevant TLS at this
 // point.
 static inline TensorTypeId dispatchTypeId(TensorTypeSet ts) {
-  // TODO: Account for TLS!
-  return ts.highestPriorityTypeId();
+  return (ts - c10::impl::tls_excluded_tensor_type_set()).highestPriorityTypeId();
 }
 
 }
@@ -44,70 +46,53 @@ class CAFFE2_API ATenOpTable {
   : schema_(std::move(schema)) {}
 
   template<class FuncType>
-  FuncType* getOp(TensorTypeSet ts, bool is_variable) const {
-    if (is_variable) {
-      return reinterpret_cast<FuncType*>(getVariableOp());
-    }
-    return reinterpret_cast<FuncType*>(getBaseOp(tensorTypeIdToBackend(impl::dispatchTypeId(ts))));
+  FuncType* getOp(TensorTypeSet ts) const {
+    return reinterpret_cast<FuncType*>(getOp(impl::dispatchTypeId(ts)));
   }
 private:
-  void registerOp(Backend backend, void* fn) {
-    TORCH_CHECK(function_table_[static_cast<int64_t>(backend)] == nullptr,
-      "Attempting to register variable function for schema ", schema_,
-      " and backend ", toString(backend),
+  void registerOp(TensorTypeId tid, void* fn) {
+    TORCH_CHECK(function_table_[static_cast<int64_t>(tid)] == nullptr,
+      "Attempting to register function for schema ", schema_,
+      " and tensor type ", toString(tid),
       " but there is already a function registered");
-    function_table_[static_cast<int64_t>(backend)] = fn;
+    function_table_[static_cast<int64_t>(tid)] = fn;
   }
 
-  void registerVariableOp(void* fn) {
-    TORCH_CHECK(variable_function_ == nullptr,
-      "Attempting to register variable function for schema ", schema_,
-      " but there is already a function registered");
-    variable_function_ = fn;
+  void* getOp(TensorTypeId tid) const {
+    // You might think we can minorly optimize this further by maintaining a
+    // bitmask of registered operator keys, so we don't select dispatch ids
+    // which don't have implementations here. But the net effect is that if you
+    // get a Variable CPUTensor, if there is no variable registration, you'll
+    // fall back to the CPU implementation. Is this what you want? Unlikely...
+    if (function_table_[static_cast<int64_t>(tid)] == nullptr) {
+      TORCH_CHECK(function_table_[static_cast<int64_t>(TensorTypeId::UndefinedTensorId)] != nullptr,
+        "No function is registered for schema ", schema_, " on tensor type ", toString(tid));
+      return function_table_[static_cast<int64_t>(TensorTypeId::UndefinedTensorId)];
    }
-
-  void* getBaseOp(Backend backend) const {
-    if (function_table_[static_cast<int64_t>(backend)] == nullptr) {
-      TORCH_CHECK(function_table_[static_cast<int64_t>(Backend::Undefined)] != nullptr,
-        "No function is registered for schema ", schema_, " on backend ", toString(backend));
-      return function_table_[static_cast<int64_t>(Backend::Undefined)];
-    }
-    return function_table_[static_cast<int64_t>(backend)];
-  }
-
-  void* getVariableOp() const {
-    TORCH_CHECK(variable_function_ != nullptr,
-      "No variable function registered for ", schema_);
-    return variable_function_;
+    return function_table_[static_cast<int64_t>(tid)];
   }
 
   friend class ATenDispatch;
 
   std::string schema_;
-  void* function_table_[static_cast<int64_t>(Backend::NumOptions)] = {nullptr};
-  void* variable_function_ = nullptr;
+  void* function_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
 };
 
 class CAFFE2_API ATenDispatch {
 public:
   template<class FuncType>
-  ATenDispatch& registerOp(Backend backend, const char* schema, FuncType* fn) {
+  ATenDispatch& registerOp(TensorTypeId id, const char* schema, FuncType* fn) {
     std::lock_guard<std::mutex> lock(mutex_);
     if (op_tables_.find(schema) == op_tables_.end()) {
       op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
     }
-    op_tables_.at(schema).registerOp(backend, reinterpret_cast<void*>(fn));
+    op_tables_.at(schema).registerOp(id, reinterpret_cast<void*>(fn));
     return *this;
   }
 
   template<class FuncType>
-  ATenDispatch& registerVariableOp(const char* schema, FuncType* fn) {
-    std::lock_guard<std::mutex> lock(mutex_);
-    if (op_tables_.find(schema) == op_tables_.end()) {
-      op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
-    }
-    op_tables_.at(schema).registerVariableOp(reinterpret_cast<void*>(fn));
-    return *this;
+  ATenDispatch& registerOp(Backend b, const char* schema, FuncType* fn) {
+    return registerOp(backendToTensorTypeId(b), schema, fn);
   }
 
   const ATenOpTable* getOpTable(const char* schema) const {

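The table above now indexes kernels by TensorTypeId and falls back to the UndefinedTensorId slot (where TypeDefault kernels are registered) when a key has no entry. A rough standalone model of that lookup, with simplified names that are not the real ATen types:

    #include <array>
    #include <cstdint>
    #include <iostream>
    #include <stdexcept>

    enum class Key : uint8_t { Undefined = 0, CPU, CUDA, Variable, NumKeys };

    struct OpTable {
      // One slot per key; nullptr means "not registered".
      std::array<void*, static_cast<size_t>(Key::NumKeys)> table{};

      void registerOp(Key k, void* fn) {
        if (table[static_cast<size_t>(k)] != nullptr) {
          throw std::runtime_error("already registered");
        }
        table[static_cast<size_t>(k)] = fn;
      }

      void* getOp(Key k) const {
        void* fn = table[static_cast<size_t>(k)];
        if (fn == nullptr) {
          fn = table[static_cast<size_t>(Key::Undefined)];  // default-kernel fallback
          if (fn == nullptr) throw std::runtime_error("no kernel registered");
        }
        return fn;
      }
    };

    void cpu_kernel()      { std::cout << "cpu kernel\n"; }
    void variable_kernel() { std::cout << "autograd wrapper\n"; }
    void default_kernel()  { std::cout << "default kernel\n"; }

    int main() {
      OpTable t;
      t.registerOp(Key::Undefined, reinterpret_cast<void*>(&default_kernel));
      t.registerOp(Key::CPU, reinterpret_cast<void*>(&cpu_kernel));
      t.registerOp(Key::Variable, reinterpret_cast<void*>(&variable_kernel));
      reinterpret_cast<void (*)()>(t.getOp(Key::Variable))();  // autograd wrapper
      reinterpret_cast<void (*)()>(t.getOp(Key::CPU))();       // cpu kernel
      reinterpret_cast<void (*)()>(t.getOp(Key::CUDA))();      // no CUDA entry -> default kernel
    }

As the comment in the new getOp warns, this silent fallback may not be what you want for a Variable tensor whose schema has no Variable kernel registered.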
@@ -20,11 +20,11 @@ namespace at {
 class CAFFE2_API LegacyTypeDispatch {
  public:
   void initForTensorTypeSet(TensorTypeSet ts) {
-    // TODO: When Variable gets turned on in TensorTypeSet, this
-    // will skip initialization when you initially process a
-    // Variable CUDA tensor, for example (because I'll get Variable
-    // and it's not gonna have any device type.) Is that OK?
-    auto b = tensorTypeIdToBackend(impl::dispatchTypeId(ts));
+    // TODO: Avoid use of legacyExtractTypeId here. The key
+    // problem is that you may get a TensorTypeSet with
+    // VariableTensorId set; should you initialize the "underlying"
+    // type in that case? Hard to say.
+    auto b = tensorTypeIdToBackend(legacyExtractTypeId(ts));
     auto p = backendToDeviceType(b);
     static std::once_flag cpu_once;
     static std::once_flag cuda_once;

@@ -218,10 +218,7 @@ class CAFFE2_API Tensor {
 
   DeprecatedTypeProperties & type() const {
     return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-        // TODO: When we build in Variable here, we need to change the
-        // signature of getDeprecatedTypeProperties to collapse backend
-        // and is_variable into TensorTypeSet
-        tensorTypeIdToBackend(type_set().highestPriorityTypeId()),
+        tensorTypeIdToBackend(legacyExtractTypeId(type_set())),
         scalar_type(),
         is_variable());
   }
@@ -932,14 +929,6 @@ inline TensorTypeSet infer_tensor_type_set(TensorList tl) {
   return tl[0].type_set();
 }
 
-inline bool infer_is_variable(const Tensor & t) {
-  TORCH_CHECK(t.defined(), "undefined Tensor");
-  return t.is_variable();
-}
-inline bool infer_is_variable(const TensorList & tl) {
-  TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors");
-  return tl[0].is_variable();
-}
 } // namespace detail
 
 static inline TensorTypeId legacyExtractTypeId(const Tensor& t) {

File diff suppressed because it is too large.

@@ -216,9 +216,11 @@ private:
     if (tensor_list.size() == 0) {
       throw std::runtime_error("Tried to dispatch operator " + operator_name + " based on an empty tensor list. When the first tensor argument of an operator is a tensor list, then it must not be empty.");
     }
-    return at::impl::dispatchTypeId(tensor_list[0].type_set());
+    // TODO: Don't use legacy extractor; blocked on c10 understanding
+    // variable
+    return c10::legacyExtractTypeId(tensor_list[0].type_set());
   } else {
-    return at::impl::dispatchTypeId(first_tensor_arg.unsafeToTensorImpl()->type_set());
+    return c10::legacyExtractTypeId(first_tensor_arg.unsafeToTensorImpl()->type_set());
   }
 }
};

@@ -113,10 +113,10 @@ ${return_type} ${Type}::${api_name}(${type_method_formals}) {
 """)
 
 DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\
-.registerOp<${return_type} (${formals_types})>(Backend::Undefined, "${schema_string}", &TypeDefault::${api_name})
+.registerOp<${return_type} (${formals_types})>(TensorTypeId::UndefinedTensorId, "${schema_string}", &TypeDefault::${api_name})
 """)
 BACKEND_FUNCTION_REGISTRATION = CodeTemplate("""\
-.registerOp<${return_type} (${formals_types})>(Backend::${Backend}, "${schema_string}", &${Type}::${api_name})
+.registerOp<${return_type} (${formals_types})>(TensorTypeId::${Backend}TensorId, "${schema_string}", &${Type}::${api_name})
 """)
 
 # Generate a file that lists all functions and their schema string. Used for XLA
@@ -136,7 +136,7 @@ inline ${return_type} Tensor::${api_name}(${method_formals}) const {
     ${static_dispatch_method_body}
 #else
     static auto table = globalATenDispatch().getOpTable("${schema_string}");
-    return table->getOp<${return_type} (${formals_types})>(type_set(), is_variable())(${method_actuals});
+    return table->getOp<${return_type} (${formals_types})>(type_set())(${method_actuals});
 #endif
 }
 """)
@@ -155,7 +155,7 @@ static inline ${return_type} ${api_name}(${formals}) {
     ${static_dispatch_function_body}
 #else
     static auto table = globalATenDispatch().getOpTable("${schema_string}");
-    return table->getOp<${return_type} (${formals_types})>(${inferred_type_set}, ${inferred_is_variable})(${native_actuals});
+    return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${native_actuals});
 #endif
 }
 """)
@@ -191,7 +191,7 @@ static inline ${return_type} ${api_name}(${formals}) {
 #else
     globalLegacyTypeDispatch().initForTensorTypeSet(${inferred_type_set});
     static auto table = globalATenDispatch().getOpTable("${schema_string}");
-    return table->getOp<${return_type} (${formals_types})>(${inferred_type_set}, ${inferred_is_variable})(${native_actuals});
+    return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${native_actuals});
 #endif
 }
 """)
@@ -552,7 +552,6 @@ FunctionOption = TypedDict('FunctionOption', {
     'formals': List[str],
     'formals_types': List[str],
     'inferred_type_set': str,
-    'inferred_is_variable': str,
     'inplace': bool,
     'matches_jit_signature': bool,
     # This controls whether or not we generate the interface in Type or
@@ -1095,6 +1094,18 @@ def create_generic(top_env, declarations):
         # type: (Any) -> FunctionCode
         if isinstance(type_method_dispatch, dict):
             static_dispatch_function_switches = []
+            # NB: As this code is currently written, there will NEVER be
+            # a backend generated for variable dispatch. There is nothing
+            # stopping us from actually implementing this, however, if you
+            # really wanted variable on mobile, there's nothing stopping
+            # you from implementing this (however, you would have an
+            # annoying phase problem, since code generation for variable
+            # happens in tools/ which happens later than here.)
+            #
+            # If you pass in a variable to the dispatch, and variable is
+            # enabled, this switch will fail. This is intentional: you
+            # probably need to disable variable globally in the mobile
+            # calling code.
             for backend in static_dispatch_backends:
                 if backend in type_method_dispatch:
                     static_dispatch_function_switches.append(STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT.substitute(
@@ -1104,10 +1115,6 @@ def create_generic(top_env, declarations):
                         native_arguments=option['method_actuals']))
             static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
                 option,
-                # TODO: When Variable gets added, this needs to get adjusted
-                # to avoid picking up the Variable bit. The correct way
-                # to encode this is probably to just have Variable in the
-                # disabled set.
                 type_set='type_set()',
                 static_dispatch_function_switches=static_dispatch_function_switches)
         else:
@@ -1122,15 +1129,12 @@ def create_generic(top_env, declarations):
         # type: (Any, Optional[str], Any) -> FunctionCode
         if dispatch_tensor:
             option['inferred_type_set'] = 'at::detail::infer_tensor_type_set({})'.format(dispatch_tensor)
-            option['inferred_is_variable'] = 'at::detail::infer_is_variable({})'.format(dispatch_tensor)
         elif dispatch_options:
             option['inferred_type_set'] = '{}.type_set()'.format(dispatch_options['name'])
-            option['inferred_is_variable'] = '{}.is_variable()'.format(dispatch_options['name'])
         else:
             # doesn't depend on a specific backend, use the empty set
             # TODO: Does this actually work?
             option['inferred_type_set'] = 'TensorTypeSet()'
-            option['inferred_is_variable'] = 'false'
         declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
         fn_declaration = declaration.substitute(option)
 

@@ -218,10 +218,7 @@ class CAFFE2_API Tensor {
 
   DeprecatedTypeProperties & type() const {
     return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-        // TODO: When we build in Variable here, we need to change the
-        // signature of getDeprecatedTypeProperties to collapse backend
-        // and is_variable into TensorTypeSet
-        tensorTypeIdToBackend(type_set().highestPriorityTypeId()),
+        tensorTypeIdToBackend(legacyExtractTypeId(type_set())),
         scalar_type(),
         is_variable());
   }
@@ -439,14 +436,6 @@ inline TensorTypeSet infer_tensor_type_set(TensorList tl) {
   return tl[0].type_set();
 }
 
-inline bool infer_is_variable(const Tensor & t) {
-  TORCH_CHECK(t.defined(), "undefined Tensor");
-  return t.is_variable();
-}
-inline bool infer_is_variable(const TensorList & tl) {
-  TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors");
-  return tl[0].is_variable();
-}
 } // namespace detail
 
 static inline TensorTypeId legacyExtractTypeId(const Tensor& t) {

@@ -2,6 +2,7 @@
 
 #include <c10/core/Backend.h>
 #include <c10/core/WrapDimMinimal.h>
+#include <c10/core/impl/LocalTensorTypeSet.h>
 #include <c10/util/Optional.h>
 
 C10_DEFINE_bool(
@@ -57,7 +58,7 @@ TensorImpl::TensorImpl(Storage&& storage, TensorTypeSet type_set, const caffe2::
       numel_(0),
       data_type_(data_type),
       device_opt_(device_opt),
-      type_set_(type_set) {
+      type_set_(type_set.remove(TensorTypeId::VariableTensorId)) {
   if (!type_set.empty()) {
     AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() ||
               device_opt_.has_value());
@@ -210,52 +211,12 @@ int64_t NamedTensorMetaInterface::slow_dim() const {
 }
 #endif
 
-/// NOTE [ Treating Variables as non-Variables in type dispatch ]
-///
-/// Previously, in VariableType_*.cpp (generated by gen_variable_type.py), when
-/// a function is using the 'use_derived' strategy, we call its implementation
-/// on the base non-Variable type (`baseType`), passing unwrapped tensors to the
-/// call so that any `.dispatch_type()` calls in the implementation can treat the passed
-/// tensors as non-Variables and won't dispatch back to functions in VariableType.
-///
-/// However, after the Variable/Tensor merge, there is no concept of unwrapping
-/// a tensor anymore, and directly passing variables to the base type calls will
-/// cause the `.dispatch_type()` dispatch in the implementation to treat the tensor as a
-/// variable, and any function dispatch based on `.dispatch_type()` will dispatch back to
-/// VariableType, which is not what we want.
-///
-/// The solution to the above problem is to add `at::NonVariableTypeMode`, which
-/// when enabled will cause `legacyTensorType()` and `getType()` to always return
-/// non-Variable type, even if the tensor being called on is a variable.
-///
-/// TODO: Since `torch::NoGradGuard` serves the same purpose in libtorch, we should
-/// merge these two thread-local guards.
-
-/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting,
-/// thread_local is not supported. In that case, we don't provide
-/// `at::NonVariableTypeMode`.
-#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY
-
-thread_local bool NonVariableTypeMode_enabled = false;
-
 bool NonVariableTypeMode::is_enabled() {
-  return NonVariableTypeMode_enabled;
+  return !impl::tls_variable_is_enabled();
 }
 
 void NonVariableTypeMode::set_enabled(bool enabled) {
-  NonVariableTypeMode_enabled = enabled;
+  impl::tls_variable_set_enabled(!enabled);
 }
 
-#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
-
-bool NonVariableTypeMode::is_enabled() {
-  throw std::runtime_error("NonVariableTypeMode is not supported on mobile");
-}
-
-void NonVariableTypeMode::set_enabled(bool enabled) {
-  throw std::runtime_error("NonVariableTypeMode is not supported on mobile");
-}
-
-#endif
-
 } // namespace c10

@@ -826,6 +826,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    */
   void set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
     autograd_meta_ = std::move(autograd_meta);
+    if (autograd_meta_) {
+      type_set_ = type_set_.add(TensorTypeId::VariableTensorId);
+    } else {
+      type_set_ = type_set_.remove(TensorTypeId::VariableTensorId);
+    }
   }
 
   /**
@@ -839,6 +844,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Detach the autograd metadata unique_ptr from this tensor, and return it.
    */
   std::unique_ptr<c10::AutogradMetaInterface> detach_autograd_meta() {
+    type_set_ = type_set_.remove(TensorTypeId::VariableTensorId);
     return std::move(autograd_meta_);
   }
 
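A toy illustration of the invariant these setters maintain (simplified stand-in types, not the real TensorImpl): the Variable bit in the per-tensor key set tracks exactly whether autograd metadata is attached.

    #include <cassert>
    #include <cstdint>
    #include <memory>

    enum class Key : uint8_t { Undefined = 0, CPU, Variable, NumKeys };

    struct KeySet {
      uint64_t repr = 0;
      KeySet add(Key k) const { return {repr | (1ull << static_cast<uint8_t>(k))}; }
      KeySet remove(Key k) const { return {repr & ~(1ull << static_cast<uint8_t>(k))}; }
      bool has(Key k) const { return (repr & (1ull << static_cast<uint8_t>(k))) != 0; }
    };

    struct AutogradMeta {};  // stand-in for c10::AutogradMetaInterface

    struct TensorImplModel {
      KeySet type_set = KeySet{}.add(Key::CPU);
      std::unique_ptr<AutogradMeta> autograd_meta;

      // Mirrors set_autograd_meta above: the Variable key follows the pointer.
      void set_autograd_meta(std::unique_ptr<AutogradMeta> m) {
        autograd_meta = std::move(m);
        type_set = autograd_meta ? type_set.add(Key::Variable)
                                 : type_set.remove(Key::Variable);
      }
    };

    int main() {
      TensorImplModel t;
      assert(!t.type_set.has(Key::Variable));                 // plain tensor
      t.set_autograd_meta(std::make_unique<AutogradMeta>());  // it "becomes" a Variable
      assert(t.type_set.has(Key::Variable));
      t.set_autograd_meta(nullptr);                           // and back again
      assert(!t.type_set.has(Key::Variable));
      return 0;
    }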
@@ -1522,7 +1528,15 @@ protected:
     dest_impl->storage_offset_ = src_impl->storage_offset_;
     dest_impl->data_type_ = src_impl->data_type_;
     dest_impl->device_opt_ = src_impl->device_opt_;
+    // This may temporarily violate invariant that
+    // type_set_.has(VariableTensorId) iff autograd_meta_ != nullptr...
     dest_impl->type_set_ = src_impl->type_set_;
+    // ...so refresh Variable in autograd_meta_
+    if (dest_impl->autograd_meta_) {
+      dest_impl->type_set_ = dest_impl->type_set_.add(TensorTypeId::VariableTensorId);
+    } else {
+      dest_impl->type_set_ = dest_impl->type_set_.remove(TensorTypeId::VariableTensorId);
+    }
     dest_impl->is_contiguous_ = src_impl->is_contiguous_;
     dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
     dest_impl->reserved_ = src_impl->reserved_;
@@ -1543,12 +1557,17 @@ protected:
   static const char * const err_msg_tensor_metadata_change_not_allowed;
 
   Storage storage_;
 
+private:
   // This pointer points to an AutogradMeta struct that stores autograd-specific fields
   // (such as grad_ / grad_fn_ / grad_accumulator_).
   // This pointer always has unique ownership (meaning only one TensorImpl can own it
   // at a time).
+  // This is private because we must maintain dispatcher invariants on it
+  // in type_set_.
   std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
+
+protected:
 #ifdef BUILD_NAMEDTENSOR
   std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
 #endif

@@ -378,8 +378,9 @@ struct C10_API TensorOptions {
 
   // Resolves the tensor type set specified by the current construction axes.
   TensorTypeSet type_set() const noexcept {
-    // TODO: This should also contain variable eventually
-    return TensorTypeSet(computeTensorTypeId());
+    auto r = TensorTypeSet(computeTensorTypeId());
+    if (is_variable()) r = r.add(TensorTypeId::VariableTensorId);
+    return r;
   }
 
   inline TensorTypeId computeTensorTypeId() const {

@@ -38,6 +38,8 @@ const char* toString(TensorTypeId t) {
       return "ComplexCPUTensorId";
     case TensorTypeId::ComplexCUDATensorId:
       return "ComplexCUDATensorId";
+    case TensorTypeId::VariableTensorId:
+      return "VariableTensorId";
     default:
       return "UNKNOWN_TENSOR_TYPE_ID";
   }

@@ -40,7 +40,12 @@ enum class TensorTypeId : uint8_t {
   ComplexCPUTensorId, // PyTorch only
   ComplexCUDATensorId, // PyTorch only
 
-  // VariableTensorId, // upcoming!
+  // WARNING! If you add more "wrapper" style tensor ids (tensor
+  // ids which don't get kernels directly defined in native_functions.yaml;
+  // examples are tracing or profiling) here, you need to also adjust
+  // legacyExtractTypeId in c10/core/TensorTypeId.h to mask them out.
+
+  VariableTensorId,
 
   NumTensorIds, // Sentinel
 };

@@ -33,7 +33,19 @@ namespace c10 {
 // An undefined tensor is one with an empty tensor type set.
 class TensorTypeSet final {
 public:
-  TensorTypeSet() {}
+  enum Full { FULL };
+  enum Raw { RAW };
+
+  // NB: default constructor representation as zero is MANDATORY as
+  // use of TensorTypeSet in TLS requires this.
+  TensorTypeSet()
+    : repr_(0) {}
+  TensorTypeSet(Full)
+    : repr_(-1) {}
+  // Public version of TensorTypeSet(uint64_t) API; external users
+  // must be explicit when they do this!
+  TensorTypeSet(Raw, uint64_t x)
+    : repr_(x) {}
   explicit TensorTypeSet(TensorTypeId t)
     : repr_(t == TensorTypeId::UndefinedTensorId
             ? 0
@@ -47,6 +59,14 @@ public:
   TensorTypeSet operator|(TensorTypeSet other) const {
     return TensorTypeSet(repr_ | other.repr_);
   }
+  // Perform set intersection
+  TensorTypeSet operator&(TensorTypeSet other) const {
+    return TensorTypeSet(repr_ & other.repr_);
+  }
+  // Compute the set difference self - other
+  TensorTypeSet operator-(TensorTypeSet other) const {
+    return TensorTypeSet(repr_ & ~other.repr_);
+  }
   // Perform set equality
   bool operator==(TensorTypeSet other) const {
     return repr_ == other.repr_;
@@ -66,6 +86,7 @@ public:
   bool empty() const {
     return repr_ == 0;
   }
+  uint64_t raw_repr() { return repr_; }
   // Return the type id in this set with the highest priority (i.e.,
   // is the largest in the TensorTypeId enum). Intuitively, this
   // type id is the one that should handle dispatch (assuming there
@@ -98,10 +119,10 @@ C10_API std::ostream& operator<<(std::ostream&, TensorTypeSet);
 // but s.has(VariableTensorId) will evaluate to true if s has VariableTensorId.
 // For non-VariableTensorId equality tests, they are indistinguishable.
 //
-// TODO: this will need to change when we add VariableTensorId to the
-// set of IDs put in TensorTypeSet.
+// NB: If you add other non-VariableTensorId other keys to this set, you'll
+// have to adjust this some more (sorry.)
 static inline TensorTypeId legacyExtractTypeId(TensorTypeSet s) {
-  return s.highestPriorityTypeId();
+  return s.remove(TensorTypeId::VariableTensorId).highestPriorityTypeId();
 }
 
 }

c10/core/impl/LocalTensorTypeSet.cpp (new file, 42 lines)
@@ -0,0 +1,42 @@
+#include <c10/core/impl/LocalTensorTypeSet.h>
+
+#include <iostream>
+
+namespace c10 {
+namespace impl {
+
+namespace {
+
+/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting,
+/// thread_local is not supported. In that case, we don't provide
+/// `at::NonVariableTypeMode`.
+#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY
+
+// NB: Zero initialized!
+thread_local uint64_t raw_excluded;
+
+#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
+
+uint64_t raw_excluded = 0;
+
+#endif
+
+}
+
+TensorTypeSet tls_excluded_tensor_type_set() {
+  return TensorTypeSet(TensorTypeSet::RAW, raw_excluded);
+}
+
+bool tls_variable_is_enabled() {
+  return !tls_excluded_tensor_type_set().has(TensorTypeId::VariableTensorId);
+}
+
+void tls_variable_set_enabled(bool enabled) {
+  if (enabled) {
+    raw_excluded = tls_excluded_tensor_type_set().remove(TensorTypeId::VariableTensorId).raw_repr();
+  } else {
+    raw_excluded = tls_excluded_tensor_type_set().add(TensorTypeId::VariableTensorId).raw_repr();
+  }
+}
+
+}} // namespace c10::impl

c10/core/impl/LocalTensorTypeSet.h (new file, 22 lines)
@@ -0,0 +1,22 @@
+#include <c10/core/TensorTypeSet.h>
+
+// TLS management for TensorTypeSet
+//
+// This manages thread-local TensorTypeSet of excluded keys which disqualify
+// tensor types from dispatch. Keys which are in this set, even if they appear
+// in a list of potential valid keys on a tensor, are not considered for
+// dispatch. This is used to, for example, turn off autograd after we have
+// handled autograd for a top-level element.
+//
+// Originally, I implemented this as storing the inverted set, but
+// TLS is defined to be zero-initialized, so this doesn't actually work
+// (you want the set to be -1 initialized).
+
+namespace c10 {
+namespace impl {
+
+C10_API bool tls_variable_is_enabled();
+C10_API void tls_variable_set_enabled(bool enabled);
+C10_API TensorTypeSet tls_excluded_tensor_type_set();
+
+}} // namespace c10::impl

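A minimal sketch of how a caller is expected to use this interface: flip the thread-local state around a region of code and restore it afterwards. The guard type below is hypothetical (the real code reaches this TLS through at::NonVariableTypeMode, as shown in the TensorImpl.cpp hunk above), and the standalone flag stands in for the excluded-set TLS declared here.

    #include <cassert>

    // Stand-in for the thread-local excluded set: true means VariableTensorId
    // is NOT excluded, i.e. autograd dispatch is active.
    thread_local bool variable_enabled_tls = true;

    // Hypothetical RAII guard, analogous to how NonVariableTypeMode is used.
    struct DisableVariableGuard {
      bool prev;
      DisableVariableGuard() : prev(variable_enabled_tls) { variable_enabled_tls = false; }
      ~DisableVariableGuard() { variable_enabled_tls = prev; }
    };

    int main() {
      assert(variable_enabled_tls);
      {
        DisableVariableGuard g;         // e.g. inside an autograd kernel, before
        assert(!variable_enabled_tls);  // re-dispatching to the base implementation
      }
      assert(variable_enabled_tls);     // previous state restored on scope exit
    }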
@@ -45,3 +45,11 @@ TEST(TensorTypeSet, Doubleton) {
     }
   }
 }
+
+TEST(TensorTypeSet, Full) {
+  TensorTypeSet full(TensorTypeSet::FULL);
+  for (uint8_t i = 1; i < static_cast<uint8_t>(TensorTypeId::NumTensorIds); i++) {
+    auto tid = static_cast<TensorTypeId>(i);
+    ASSERT_TRUE(full.has(tid));
+  }
+}

@@ -153,7 +153,7 @@ ${return_type} VariableType::${api_name}(${type_method_formals}) {
 """)
 
 WRAPPER_REGISTRATION = CodeTemplate("""\
-.registerVariableOp<${return_type} (${formal_types})>("${schema_string}", &VariableType::${api_name})
+.registerOp<${return_type} (${formal_types})>(TensorTypeId::VariableTensorId, "${schema_string}", &VariableType::${api_name})
 """)
 
 UNPACK_TENSOR = CodeTemplate("""\