Add VariableTensorId, store it in TensorTypeSet (#25597)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25597

We now take advantage of the new bitset representation TensorTypeSet to store "Variable-ness" of a tensor directly in the dispatch key. We introduce a new thread-local TensorTypeSet of "excluded" keys and replace the previous thread-local boolean with it; we no longer have to query `is_variable()` to do dispatch (I didn't delete `is_variable`, because there are still a lot of uses of it). The key change is in `dispatchTypeId`.

Knock-on effects:

* Because Variable is now a TensorTypeId, I can eliminate the out-of-line registration `registerVariableOp` for variables; instead, make the registrar take a TensorTypeId (instead of a Backend) and you just register under the Variable key.
* Tensors aren't really ever created with Variable information initialized correctly at the start; instead, a tensor "becomes" a Variable because we set its `autograd_meta_`. These setters now correctly set up invariants on the dispatch type set. The new invariant is that if `autograd_meta_ != nullptr`, then `type_set().has(TensorTypeId::VariableTensorId)`.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Differential Revision: D17265919

Pulled By: ezyang

fbshipit-source-id: a90a7ed14f5cb1086137483ae3d0646fcd4c42d0
This commit is contained in:
  parent ba9fda14a7
  commit 2080a15860
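Note: to make the dispatch mechanism described above concrete, here is a minimal standalone sketch (not PyTorch code; the names mirror the PR but everything is simplified and hypothetical) of how a 64-bit dispatch key set, combined with a thread-local excluded set, selects the dispatch key:

    #include <cassert>
    #include <cstdint>

    // Simplified stand-ins for the TensorTypeId / TensorTypeSet in this PR.
    enum class TensorTypeId : uint8_t { Undefined = 0, CPU, CUDA, Variable, NumIds };

    struct TensorTypeSet {
      // Bit (t - 1) set <=> TensorTypeId t is present; Undefined sets no bits.
      uint64_t repr_ = 0;
      TensorTypeSet() {}
      explicit TensorTypeSet(TensorTypeId t)
        : repr_(t == TensorTypeId::Undefined ? 0 : 1ULL << (static_cast<uint8_t>(t) - 1)) {}
      TensorTypeSet add(TensorTypeId t) const {
        TensorTypeSet r; r.repr_ = repr_ | TensorTypeSet(t).repr_; return r;
      }
      TensorTypeSet operator-(TensorTypeSet o) const {  // set difference
        TensorTypeSet r; r.repr_ = repr_ & ~o.repr_; return r;
      }
      TensorTypeId highestPriorityTypeId() const {
        // The highest set bit wins; Variable outranks CPU/CUDA because its
        // enum value (and therefore its bit position) is larger.
        // (__builtin_clzll is a GCC/Clang builtin.)
        if (repr_ == 0) return TensorTypeId::Undefined;
        return static_cast<TensorTypeId>(64 - __builtin_clzll(repr_));
      }
    };

    // The thread-local excluded set introduced by this PR (zero-initialized).
    thread_local TensorTypeSet tls_excluded;

    TensorTypeId dispatchTypeId(TensorTypeSet ts) {
      return (ts - tls_excluded).highestPriorityTypeId();
    }

    int main() {
      // A Variable CPU tensor carries both bits...
      TensorTypeSet t = TensorTypeSet(TensorTypeId::CPU).add(TensorTypeId::Variable);
      assert(dispatchTypeId(t) == TensorTypeId::Variable);  // autograd handles it first
      // ...but with Variable excluded (e.g. inside autograd), dispatch falls to CPU.
      tls_excluded = TensorTypeSet(TensorTypeId::Variable);
      assert(dispatchTypeId(t) == TensorTypeId::CPU);
    }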
@@ -1,12 +1,15 @@
 #pragma once
 
-#include <c10/core/Backend.h>
 #include <c10/core/TensorTypeSet.h>
+#include <c10/core/Backend.h>
+#include <c10/core/impl/LocalTensorTypeSet.h>
 #include <unordered_map>
 #include <c10/util/C++17.h>
 #include <memory>
 #include <mutex>
 
+// TODO: Rewrite this comment
+//
 // This dispatch class serves as a replacement for our previous dispatch
 // mechanism, in which all functions were members of a Type class. A derived
 // class existed for each backend (and Variable), and the vtable was used to
@@ -19,7 +22,7 @@ namespace at {
 namespace impl {
 
 // Take a TensorTypeSet for a Tensor, and combine it with the current thread
-// local set of inclusions and exclusions (not yet implemented, coming soon!)
+// local valid (implemented) and enabled (not implemented) TensorTypeSets
 // to determine what the actual dispatch TensorTypeId should be.  Unlike
 // Tensor::type_set(), the value of this on a tensor can change depending
 // on TLS.
@@ -30,8 +33,7 @@ namespace impl {
 // question is whether or not we have access to all the relevant TLS at this
 // point.
 static inline TensorTypeId dispatchTypeId(TensorTypeSet ts) {
-  // TODO: Account for TLS!
-  return ts.highestPriorityTypeId();
+  return (ts - c10::impl::tls_excluded_tensor_type_set()).highestPriorityTypeId();
 }
 
 }
@@ -44,70 +46,53 @@ class CAFFE2_API ATenOpTable {
   : schema_(std::move(schema)) {}
 
   template<class FuncType>
-  FuncType* getOp(TensorTypeSet ts, bool is_variable) const {
-    if (is_variable) {
-      return reinterpret_cast<FuncType*>(getVariableOp());
-    }
-    return reinterpret_cast<FuncType*>(getBaseOp(tensorTypeIdToBackend(impl::dispatchTypeId(ts))));
+  FuncType* getOp(TensorTypeSet ts) const {
+    return reinterpret_cast<FuncType*>(getOp(impl::dispatchTypeId(ts)));
   }
 private:
-  void registerOp(Backend backend, void* fn) {
-    TORCH_CHECK(function_table_[static_cast<int64_t>(backend)] == nullptr,
-      "Attempting to register variable function for schema ", schema_,
-      " and backend ", toString(backend),
+  void registerOp(TensorTypeId tid, void* fn) {
+    TORCH_CHECK(function_table_[static_cast<int64_t>(tid)] == nullptr,
+      "Attempting to register function for schema ", schema_,
+      " and tensor type ", toString(tid),
       " but there is already a function registered");
-    function_table_[static_cast<int64_t>(backend)] = fn;
+    function_table_[static_cast<int64_t>(tid)] = fn;
   }
 
-  void registerVariableOp(void* fn) {
-    TORCH_CHECK(variable_function_ == nullptr,
-      "Attempting to register variable function for schema ", schema_,
-      " but there is already a function registered");
-    variable_function_ = fn;
-  }
-
-  void* getBaseOp(Backend backend) const {
-    if (function_table_[static_cast<int64_t>(backend)] == nullptr) {
-      TORCH_CHECK(function_table_[static_cast<int64_t>(Backend::Undefined)] != nullptr,
-        "No function is registered for schema ", schema_, " on backend ", toString(backend));
-      return function_table_[static_cast<int64_t>(Backend::Undefined)];
-    }
-    return function_table_[static_cast<int64_t>(backend)];
-  }
-
-  void* getVariableOp() const {
-    TORCH_CHECK(variable_function_ != nullptr,
-      "No variable function registered for ", schema_);
-    return variable_function_;
+  void* getOp(TensorTypeId tid) const {
+    // You might think we can minorly optimize this further by maintaining a
+    // bitmask of registered operator keys, so we don't select dispatch ids
+    // which don't have implementations here. But the net effect is that if you
+    // get a Variable CPUTensor, if there is no variable registration, you'll
+    // fall back to the CPU implementation. Is this what you want? Unlikely...
+    if (function_table_[static_cast<int64_t>(tid)] == nullptr) {
+      TORCH_CHECK(function_table_[static_cast<int64_t>(TensorTypeId::UndefinedTensorId)] != nullptr,
+        "No function is registered for schema ", schema_, " on tensor type ", toString(tid));
+      return function_table_[static_cast<int64_t>(TensorTypeId::UndefinedTensorId)];
+    }
+    return function_table_[static_cast<int64_t>(tid)];
   }
 
   friend class ATenDispatch;
 
   std::string schema_;
-  void* function_table_[static_cast<int64_t>(Backend::NumOptions)] = {nullptr};
-  void* variable_function_ = nullptr;
+  void* function_table_[static_cast<int64_t>(TensorTypeId::NumTensorIds)] = {nullptr};
 };
 
 class CAFFE2_API ATenDispatch {
 public:
   template<class FuncType>
-  ATenDispatch& registerOp(Backend backend, const char* schema, FuncType* fn) {
+  ATenDispatch& registerOp(TensorTypeId id, const char* schema, FuncType* fn) {
     std::lock_guard<std::mutex> lock(mutex_);
     if (op_tables_.find(schema) == op_tables_.end()) {
      op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
     }
-    op_tables_.at(schema).registerOp(backend, reinterpret_cast<void*>(fn));
+    op_tables_.at(schema).registerOp(id, reinterpret_cast<void*>(fn));
    return *this;
   }
 
   template<class FuncType>
-  ATenDispatch& registerVariableOp(const char* schema, FuncType* fn) {
-    std::lock_guard<std::mutex> lock(mutex_);
-    if (op_tables_.find(schema) == op_tables_.end()) {
-      op_tables_.insert(std::make_pair(schema, ATenOpTable(schema)));
-    }
-    op_tables_.at(schema).registerVariableOp(reinterpret_cast<void*>(fn));
-    return *this;
+  ATenDispatch& registerOp(Backend b, const char* schema, FuncType* fn) {
+    return registerOp(backendToTensorTypeId(b), schema, fn);
   }
 
   const ATenOpTable* getOpTable(const char* schema) const {
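Note: an illustrative model (reusing the simplified types from the sketch near the top; the real table is ATenOpTable above) of the new single-table lookup with fallback to the UndefinedTensorId slot — exactly the "fall back to the CPU implementation ... Unlikely" caveat in the comment:

    #include <cassert>
    #include <cstdint>
    #include <stdexcept>

    enum class TensorTypeId : uint8_t { Undefined = 0, CPU, CUDA, Variable, NumIds };

    struct OpTable {
      void* slots[static_cast<int>(TensorTypeId::NumIds)] = {nullptr};
      void registerOp(TensorTypeId t, void* fn) { slots[static_cast<int>(t)] = fn; }
      void* getOp(TensorTypeId t) const {
        // Missing entry: fall back to the Undefined slot, where the
        // catch-all (TypeDefault-style) implementations are registered.
        if (slots[static_cast<int>(t)] == nullptr) {
          if (slots[0] == nullptr) throw std::runtime_error("no function registered");
          return slots[0];
        }
        return slots[static_cast<int>(t)];
      }
    };

    int main() {
      static int cpu_kernel, default_kernel;  // stand-ins for function pointers
      OpTable table;
      table.registerOp(TensorTypeId::Undefined, &default_kernel);
      table.registerOp(TensorTypeId::CPU, &cpu_kernel);
      assert(table.getOp(TensorTypeId::CPU) == &cpu_kernel);
      // No Variable kernel registered: the lookup silently falls back to the
      // default slot instead of erroring out.
      assert(table.getOp(TensorTypeId::Variable) == &default_kernel);
    }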
@@ -20,11 +20,11 @@ namespace at {
 class CAFFE2_API LegacyTypeDispatch {
  public:
   void initForTensorTypeSet(TensorTypeSet ts) {
-    // TODO: When Variable gets turned on in TensorTypeSet, this
-    // will skip initialization when you initially process a
-    // Variable CUDA tensor, for example (because I'll get Variable
-    // and it's not gonna have any device type.) Is that OK?
-    auto b = tensorTypeIdToBackend(impl::dispatchTypeId(ts));
+    // TODO: Avoid use of legacyExtractTypeId here.  The key
+    // problem is that you may get a TensorTypeSet with
+    // VariableTensorId set; should you initialize the "underlying"
+    // type in that case?  Hard to say.
+    auto b = tensorTypeIdToBackend(legacyExtractTypeId(ts));
     auto p = backendToDeviceType(b);
     static std::once_flag cpu_once;
     static std::once_flag cuda_once;
@@ -218,10 +218,7 @@ class CAFFE2_API Tensor {
 
   DeprecatedTypeProperties & type() const {
     return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-        // TODO: When we build in Variable here, we need to change the
-        // signature of getDeprecatedTypeProperties to collapse backend
-        // and is_variable into TensorTypeSet
-        tensorTypeIdToBackend(type_set().highestPriorityTypeId()),
+        tensorTypeIdToBackend(legacyExtractTypeId(type_set())),
         scalar_type(),
         is_variable());
   }
@@ -932,14 +929,6 @@ inline TensorTypeSet infer_tensor_type_set(TensorList tl) {
   return tl[0].type_set();
 }
 
-inline bool infer_is_variable(const Tensor & t) {
-  TORCH_CHECK(t.defined(), "undefined Tensor");
-  return t.is_variable();
-}
-inline bool infer_is_variable(const TensorList & tl) {
-  TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors");
-  return tl[0].is_variable();
-}
 } // namespace detail
 
 static inline TensorTypeId legacyExtractTypeId(const Tensor& t) {
(File diff suppressed because it is too large)
@@ -216,9 +216,11 @@ private:
       if (tensor_list.size() == 0) {
         throw std::runtime_error("Tried to dispatch operator " + operator_name + " based on an empty tensor list. When the first tensor argument of an operator is a tensor list, then it must not be empty.");
       }
-      return at::impl::dispatchTypeId(tensor_list[0].type_set());
+      // TODO: Don't use legacy extractor; blocked on c10 understanding
+      // variable
+      return c10::legacyExtractTypeId(tensor_list[0].type_set());
     } else {
-      return at::impl::dispatchTypeId(first_tensor_arg.unsafeToTensorImpl()->type_set());
+      return c10::legacyExtractTypeId(first_tensor_arg.unsafeToTensorImpl()->type_set());
     }
   }
 };
@@ -113,10 +113,10 @@ ${return_type} ${Type}::${api_name}(${type_method_formals}) {
 """)
 
 DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\
-.registerOp<${return_type} (${formals_types})>(Backend::Undefined, "${schema_string}", &TypeDefault::${api_name})
+.registerOp<${return_type} (${formals_types})>(TensorTypeId::UndefinedTensorId, "${schema_string}", &TypeDefault::${api_name})
 """)
 BACKEND_FUNCTION_REGISTRATION = CodeTemplate("""\
-.registerOp<${return_type} (${formals_types})>(Backend::${Backend}, "${schema_string}", &${Type}::${api_name})
+.registerOp<${return_type} (${formals_types})>(TensorTypeId::${Backend}TensorId, "${schema_string}", &${Type}::${api_name})
 """)
 
 # Generate a file that lists all functions and their schema string. Used for XLA
@@ -136,7 +136,7 @@ inline ${return_type} Tensor::${api_name}(${method_formals}) const {
 ${static_dispatch_method_body}
 #else
     static auto table = globalATenDispatch().getOpTable("${schema_string}");
-    return table->getOp<${return_type} (${formals_types})>(type_set(), is_variable())(${method_actuals});
+    return table->getOp<${return_type} (${formals_types})>(type_set())(${method_actuals});
 #endif
 }
 """)
@@ -155,7 +155,7 @@ static inline ${return_type} ${api_name}(${formals}) {
 ${static_dispatch_function_body}
 #else
     static auto table = globalATenDispatch().getOpTable("${schema_string}");
-    return table->getOp<${return_type} (${formals_types})>(${inferred_type_set}, ${inferred_is_variable})(${native_actuals});
+    return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${native_actuals});
 #endif
 }
 """)
@@ -191,7 +191,7 @@ static inline ${return_type} ${api_name}(${formals}) {
 #else
     globalLegacyTypeDispatch().initForTensorTypeSet(${inferred_type_set});
     static auto table = globalATenDispatch().getOpTable("${schema_string}");
-    return table->getOp<${return_type} (${formals_types})>(${inferred_type_set}, ${inferred_is_variable})(${native_actuals});
+    return table->getOp<${return_type} (${formals_types})>(${inferred_type_set})(${native_actuals});
 #endif
 }
 """)
@@ -552,7 +552,6 @@ FunctionOption = TypedDict('FunctionOption', {
     'formals': List[str],
     'formals_types': List[str],
     'inferred_type_set': str,
-    'inferred_is_variable': str,
     'inplace': bool,
     'matches_jit_signature': bool,
     # This controls whether or not we generate the interface in Type or
@@ -1095,6 +1094,18 @@ def create_generic(top_env, declarations):
         # type: (Any) -> FunctionCode
         if isinstance(type_method_dispatch, dict):
             static_dispatch_function_switches = []
+            # NB: As this code is currently written, there will NEVER be
+            # a backend generated for variable dispatch. There is nothing
+            # stopping us from actually implementing this, however, if you
+            # really wanted variable on mobile, there's nothing stopping
+            # you from implementing this (however, you would have an
+            # annoying phase problem, since code generation for variable
+            # happens in tools/ which happens later than here.)
+            #
+            # If you pass in a variable to the dispatch, and variable is
+            # enabled, this switch will fail. This is intentional: you
+            # probably need to disable variable globally in the mobile
+            # calling code.
             for backend in static_dispatch_backends:
                 if backend in type_method_dispatch:
                     static_dispatch_function_switches.append(STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT.substitute(
@@ -1104,10 +1115,6 @@ def create_generic(top_env, declarations):
                         native_arguments=option['method_actuals']))
             static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute(
                 option,
-                # TODO: When Variable gets added, this needs to get adjusted
-                # to avoid picking up the Variable bit. The correct way
-                # to encode this is probably to just have Variable in the
-                # disabled set.
                 type_set='type_set()',
                 static_dispatch_function_switches=static_dispatch_function_switches)
         else:
@@ -1122,15 +1129,12 @@ def create_generic(top_env, declarations):
         # type: (Any, Optional[str], Any) -> FunctionCode
         if dispatch_tensor:
             option['inferred_type_set'] = 'at::detail::infer_tensor_type_set({})'.format(dispatch_tensor)
-            option['inferred_is_variable'] = 'at::detail::infer_is_variable({})'.format(dispatch_tensor)
         elif dispatch_options:
             option['inferred_type_set'] = '{}.type_set()'.format(dispatch_options['name'])
-            option['inferred_is_variable'] = '{}.is_variable()'.format(dispatch_options['name'])
         else:
             # doesn't depend on a specific backend, use the empty set
             # TODO: Does this actually work?
             option['inferred_type_set'] = 'TensorTypeSet()'
-            option['inferred_is_variable'] = 'false'
         declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION
         fn_declaration = declaration.substitute(option)
 
@@ -218,10 +218,7 @@ class CAFFE2_API Tensor {
 
   DeprecatedTypeProperties & type() const {
     return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
-        // TODO: When we build in Variable here, we need to change the
-        // signature of getDeprecatedTypeProperties to collapse backend
-        // and is_variable into TensorTypeSet
-        tensorTypeIdToBackend(type_set().highestPriorityTypeId()),
+        tensorTypeIdToBackend(legacyExtractTypeId(type_set())),
         scalar_type(),
         is_variable());
   }
@@ -439,14 +436,6 @@ inline TensorTypeSet infer_tensor_type_set(TensorList tl) {
   return tl[0].type_set();
 }
 
-inline bool infer_is_variable(const Tensor & t) {
-  TORCH_CHECK(t.defined(), "undefined Tensor");
-  return t.is_variable();
-}
-inline bool infer_is_variable(const TensorList & tl) {
-  TORCH_CHECK(tl.size() > 0, "expected a non-empty list of Tensors");
-  return tl[0].is_variable();
-}
 } // namespace detail
 
 static inline TensorTypeId legacyExtractTypeId(const Tensor& t) {
@@ -2,6 +2,7 @@
 
 #include <c10/core/Backend.h>
 #include <c10/core/WrapDimMinimal.h>
+#include <c10/core/impl/LocalTensorTypeSet.h>
 #include <c10/util/Optional.h>
 
 C10_DEFINE_bool(
@@ -57,7 +58,7 @@ TensorImpl::TensorImpl(Storage&& storage, TensorTypeSet type_set, const caffe2::
       numel_(0),
       data_type_(data_type),
       device_opt_(device_opt),
-      type_set_(type_set) {
+      type_set_(type_set.remove(TensorTypeId::VariableTensorId)) {
   if (!type_set.empty()) {
     AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() ||
               device_opt_.has_value());
@@ -210,52 +211,12 @@ int64_t NamedTensorMetaInterface::slow_dim() const {
 }
 #endif
 
-/// NOTE [ Treating Variables as non-Variables in type dispatch ]
-///
-/// Previously, in VariableType_*.cpp (generated by gen_variable_type.py), when
-/// a function is using the 'use_derived' strategy, we call its implementation
-/// on the base non-Variable type (`baseType`), passing unwrapped tensors to the
-/// call so that any `.dispatch_type()` calls in the implementation can treat the passed
-/// tensors as non-Variables and won't dispatch back to functions in VariableType.
-///
-/// However, after the Variable/Tensor merge, there is no concept of unwrapping
-/// a tensor anymore, and directly passing variables to the base type calls will
-/// cause the `.dispatch_type()` dispatch in the implementation to treat the tensor as a
-/// variable, and any function dispatch based on `.dispatch_type()` will dispatch back to
-/// VariableType, which is not what we want.
-///
-/// The solution to the above problem is to add `at::NonVariableTypeMode`, which
-/// when enabled will cause `legacyTensorType()` and `getType()` to always return
-/// non-Variable type, even if the tensor being called on is a variable.
-///
-/// TODO: Since `torch::NoGradGuard` serves the same purpose in libtorch, we should
-/// merge these two thread-local guards.
-
-/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting,
-/// thread_local is not supported. In that case, we don't provide
-/// `at::NonVariableTypeMode`.
-#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY
-
-thread_local bool NonVariableTypeMode_enabled = false;
-
 bool NonVariableTypeMode::is_enabled() {
-  return NonVariableTypeMode_enabled;
+  return !impl::tls_variable_is_enabled();
 }
 
 void NonVariableTypeMode::set_enabled(bool enabled) {
-  NonVariableTypeMode_enabled = enabled;
+  impl::tls_variable_set_enabled(!enabled);
 }
 
-#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
-
-bool NonVariableTypeMode::is_enabled() {
-  throw std::runtime_error("NonVariableTypeMode is not supported on mobile");
-}
-
-void NonVariableTypeMode::set_enabled(bool enabled) {
-  throw std::runtime_error("NonVariableTypeMode is not supported on mobile");
-}
-
-#endif
-
 } // namespace c10
@@ -826,6 +826,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    */
   void set_autograd_meta(std::unique_ptr<c10::AutogradMetaInterface> autograd_meta) {
     autograd_meta_ = std::move(autograd_meta);
+    if (autograd_meta_) {
+      type_set_ = type_set_.add(TensorTypeId::VariableTensorId);
+    } else {
+      type_set_ = type_set_.remove(TensorTypeId::VariableTensorId);
+    }
   }
 
   /**
@@ -839,6 +844,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
    * Detach the autograd metadata unique_ptr from this tensor, and return it.
    */
   std::unique_ptr<c10::AutogradMetaInterface> detach_autograd_meta() {
+    type_set_ = type_set_.remove(TensorTypeId::VariableTensorId);
     return std::move(autograd_meta_);
   }
 
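Note: the invariant introduced by these hunks -- type_set_ contains VariableTensorId iff autograd_meta_ != nullptr -- is re-established at every mutation point (set_autograd_meta, detach_autograd_meta, and the metadata copy below). A minimal sketch of the pattern with simplified stand-in types (not the real TensorImpl):

    #include <cassert>
    #include <memory>

    struct AutogradMeta {};

    struct Impl {
      bool variable_bit = false;  // stands in for type_set_.has(VariableTensorId)
      std::unique_ptr<AutogradMeta> autograd_meta_;

      // The setter keeps the bit in sync with the pointer, as in the diff above.
      void set_autograd_meta(std::unique_ptr<AutogradMeta> m) {
        autograd_meta_ = std::move(m);
        variable_bit = (autograd_meta_ != nullptr);
      }
      std::unique_ptr<AutogradMeta> detach_autograd_meta() {
        variable_bit = false;
        return std::move(autograd_meta_);
      }
    };

    int main() {
      Impl t;
      t.set_autograd_meta(std::make_unique<AutogradMeta>());
      assert(t.variable_bit && t.autograd_meta_ != nullptr);   // invariant holds
      auto meta = t.detach_autograd_meta();
      assert(!t.variable_bit && t.autograd_meta_ == nullptr);  // still holds
    }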
@@ -1522,7 +1528,15 @@ protected:
     dest_impl->storage_offset_ = src_impl->storage_offset_;
     dest_impl->data_type_ = src_impl->data_type_;
     dest_impl->device_opt_ = src_impl->device_opt_;
+    // This may temporarily violate invariant that
+    // type_set_.has(VariableTensorId) iff autograd_meta_ != nullptr...
     dest_impl->type_set_ = src_impl->type_set_;
+    // ...so refresh Variable in autograd_meta_
+    if (dest_impl->autograd_meta_) {
+      dest_impl->type_set_ = dest_impl->type_set_.add(TensorTypeId::VariableTensorId);
+    } else {
+      dest_impl->type_set_ = dest_impl->type_set_.remove(TensorTypeId::VariableTensorId);
+    }
     dest_impl->is_contiguous_ = src_impl->is_contiguous_;
     dest_impl->is_wrapped_number_ = src_impl->is_wrapped_number_;
     dest_impl->reserved_ = src_impl->reserved_;
@@ -1543,12 +1557,17 @@ protected:
   static const char * const err_msg_tensor_metadata_change_not_allowed;
 
   Storage storage_;
 
+private:
   // This pointer points to an AutogradMeta struct that stores autograd-specific fields
   // (such as grad_ / grad_fn_ / grad_accumulator_).
   // This pointer always has unique ownership (meaning only one TensorImpl can own it
   // at a time).
+  // This is private because we must maintain dispatcher invariants on it
+  // in type_set_.
   std::unique_ptr<c10::AutogradMetaInterface> autograd_meta_ = nullptr;
 
+protected:
 #ifdef BUILD_NAMEDTENSOR
   std::unique_ptr<c10::NamedTensorMetaInterface> named_tensor_meta_ = nullptr;
 #endif
@@ -378,8 +378,9 @@ struct C10_API TensorOptions {
 
   // Resolves the tensor type set specified by the current construction axes.
   TensorTypeSet type_set() const noexcept {
-    // TODO: This should also contain variable eventually
-    return TensorTypeSet(computeTensorTypeId());
+    auto r = TensorTypeSet(computeTensorTypeId());
+    if (is_variable()) r = r.add(TensorTypeId::VariableTensorId);
+    return r;
   }
 
   inline TensorTypeId computeTensorTypeId() const {
@@ -38,6 +38,8 @@ const char* toString(TensorTypeId t) {
       return "ComplexCPUTensorId";
     case TensorTypeId::ComplexCUDATensorId:
       return "ComplexCUDATensorId";
+    case TensorTypeId::VariableTensorId:
+      return "VariableTensorId";
     default:
       return "UNKNOWN_TENSOR_TYPE_ID";
   }
@@ -40,7 +40,12 @@ enum class TensorTypeId : uint8_t {
   ComplexCPUTensorId, // PyTorch only
   ComplexCUDATensorId, // PyTorch only
 
-  // VariableTensorId, // upcoming!
+  // WARNING! If you add more "wrapper" style tensor ids (tensor
+  // ids which don't get kernels directly defined in native_functions.yaml;
+  // examples are tracing or profiling) here, you need to also adjust
+  // legacyExtractTypeId in c10/core/TensorTypeId.h to mask them out.
+
+  VariableTensorId,
 
   NumTensorIds, // Sentinel
 };
@@ -33,7 +33,19 @@ namespace c10 {
 // An undefined tensor is one with an empty tensor type set.
 class TensorTypeSet final {
 public:
-  TensorTypeSet() {}
+  enum Full { FULL };
+  enum Raw { RAW };
+
+  // NB: default constructor representation as zero is MANDATORY as
+  // use of TensorTypeSet in TLS requires this.
+  TensorTypeSet()
+    : repr_(0) {}
+  TensorTypeSet(Full)
+    : repr_(-1) {}
+  // Public version of TensorTypeSet(uint64_t) API; external users
+  // must be explicit when they do this!
+  TensorTypeSet(Raw, uint64_t x)
+    : repr_(x) {}
   explicit TensorTypeSet(TensorTypeId t)
     : repr_(t == TensorTypeId::UndefinedTensorId
             ? 0
@@ -47,6 +59,14 @@ public:
   TensorTypeSet operator|(TensorTypeSet other) const {
     return TensorTypeSet(repr_ | other.repr_);
   }
+  // Perform set intersection
+  TensorTypeSet operator&(TensorTypeSet other) const {
+    return TensorTypeSet(repr_ & other.repr_);
+  }
+  // Compute the set difference self - other
+  TensorTypeSet operator-(TensorTypeSet other) const {
+    return TensorTypeSet(repr_ & ~other.repr_);
+  }
   // Perform set equality
   bool operator==(TensorTypeSet other) const {
     return repr_ == other.repr_;
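Note: the new operators are plain bit arithmetic on the 64-bit representation; in particular, the set difference used by dispatchTypeId is repr_ & ~other.repr_. A quick standalone check of the semantics (toy bit patterns, same layout idea as the sketch near the top):

    #include <cassert>
    #include <cstdint>

    int main() {
      uint64_t tensor_keys = 0b101;  // e.g. CPU (bit 0) | Variable (bit 2)
      uint64_t excluded    = 0b100;  // Variable only
      assert((tensor_keys & excluded)  == 0b100);  // operator&: intersection
      assert((tensor_keys & ~excluded) == 0b001);  // operator-: difference; CPU survives
    }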
@@ -66,6 +86,7 @@ public:
   bool empty() const {
     return repr_ == 0;
   }
+  uint64_t raw_repr() { return repr_; }
   // Return the type id in this set with the highest priority (i.e.,
   // is the largest in the TensorTypeId enum).  Intuitively, this
   // type id is the one that should handle dispatch (assuming there
@@ -98,10 +119,10 @@ C10_API std::ostream& operator<<(std::ostream&, TensorTypeSet);
 // but s.has(VariableTensorId) will evaluate to true if s has VariableTensorId.
 // For non-VariableTensorId equality tests, they are indistinguishable.
 //
-// TODO: this will need to change when we add VariableTensorId to the
-// set of IDs put in TensorTypeSet.
+// NB: If you add other non-VariableTensorId other keys to this set, you'll
+// have to adjust this some more (sorry.)
 static inline TensorTypeId legacyExtractTypeId(TensorTypeSet s) {
-  return s.highestPriorityTypeId();
+  return s.remove(TensorTypeId::VariableTensorId).highestPriorityTypeId();
 }
 
 }
c10/core/impl/LocalTensorTypeSet.cpp (new file, +42)
@@ -0,0 +1,42 @@
+#include <c10/core/impl/LocalTensorTypeSet.h>
+
+#include <iostream>
+
+namespace c10 {
+namespace impl {
+
+namespace {
+
+/// In the CAFFE2_FB_LIMITED_MOBILE_CAPABILITY build setting,
+/// thread_local is not supported.  In that case, we don't provide
+/// `at::NonVariableTypeMode`.
+#ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY
+
+// NB: Zero initialized!
+thread_local uint64_t raw_excluded;
+
+#else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)
+
+uint64_t raw_excluded = 0;
+
+#endif
+
+}
+
+TensorTypeSet tls_excluded_tensor_type_set() {
+  return TensorTypeSet(TensorTypeSet::RAW, raw_excluded);
+}
+
+bool tls_variable_is_enabled() {
+  return !tls_excluded_tensor_type_set().has(TensorTypeId::VariableTensorId);
+}
+
+void tls_variable_set_enabled(bool enabled) {
+  if (enabled) {
+    raw_excluded = tls_excluded_tensor_type_set().remove(TensorTypeId::VariableTensorId).raw_repr();
+  } else {
+    raw_excluded = tls_excluded_tensor_type_set().add(TensorTypeId::VariableTensorId).raw_repr();
+  }
+}
+
+}} // namespace c10::impl
c10/core/impl/LocalTensorTypeSet.h (new file, +22)
@@ -0,0 +1,22 @@
+#include <c10/core/TensorTypeSet.h>
+
+// TLS management for TensorTypeSet
+//
+// This manages thread-local TensorTypeSet of excluded keys which disqualify
+// tensor types from dispatch.  Keys which are in this set, even if they appear
+// in a list of potential valid keys on a tensor, are not considered for
+// dispatch.  This is used to, for example, turn off autograd after we have
+// handled autograd for a top-level element.
+//
+// Originally, I implemented this as storing the inverted set, but
+// TLS is defined to be zero-initialized, so this doesn't actually work
+// (you want the set to be -1 initialized).
+
+namespace c10 {
+namespace impl {
+
+C10_API bool tls_variable_is_enabled();
+C10_API void tls_variable_set_enabled(bool enabled);
+C10_API TensorTypeSet tls_excluded_tensor_type_set();
+
+}} // namespace c10::impl
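Note: putting the TLS helpers together, NonVariableTypeMode::set_enabled(true) now just adds VariableTensorId to this thread-local excluded set, which dispatchTypeId subtracts on every dispatch. An illustrative RAII guard over the three functions declared above (the guard class itself is hypothetical, not part of this PR):

    #include <c10/core/impl/LocalTensorTypeSet.h>

    struct NonVariableModeGuard {
      bool prev_;
      NonVariableModeGuard() : prev_(c10::impl::tls_variable_is_enabled()) {
        // Exclude VariableTensorId: dispatch now skips autograd on this thread.
        c10::impl::tls_variable_set_enabled(false);
      }
      ~NonVariableModeGuard() {
        c10::impl::tls_variable_set_enabled(prev_);  // restore the outer state
      }
    };

Inside the guard's scope, a Variable tensor dispatches to its underlying backend kernel (e.g. CPU) instead of back into VariableType.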
@@ -45,3 +45,11 @@ TEST(TensorTypeSet, Doubleton) {
     }
   }
 }
+
+TEST(TensorTypeSet, Full) {
+  TensorTypeSet full(TensorTypeSet::FULL);
+  for (uint8_t i = 1; i < static_cast<uint8_t>(TensorTypeId::NumTensorIds); i++) {
+    auto tid = static_cast<TensorTypeId>(i);
+    ASSERT_TRUE(full.has(tid));
+  }
+}
@@ -153,7 +153,7 @@ ${return_type} VariableType::${api_name}(${type_method_formals}) {
 """)
 
 WRAPPER_REGISTRATION = CodeTemplate("""\
-.registerVariableOp<${return_type} (${formal_types})>("${schema_string}", &VariableType::${api_name})
+.registerOp<${return_type} (${formal_types})>(TensorTypeId::VariableTensorId, "${schema_string}", &VariableType::${api_name})
 """)
 
 UNPACK_TENSOR = CodeTemplate("""\