Making ops c10-full: list of optional tensors (#49138)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49138

See for details: https://fb.quip.com/QRtJAin66lPN

We need to model optional types explicitly, mostly for schema inference. So we cannot pass a `Tensor?[]` as `ArrayRef<Tensor>`; instead we need to pass it as an optional type. This PR changes it to `torch::List<c10::optional<Tensor>>`. It also makes the ops c10-full that were blocked by this.

## Backwards Compatibility

- This should not break the Python API because the representation in Python is the same and python_arg_parser just transforms the python list into a `List<optional<Tensor>>` instead of into a `List<Tensor>`.
- This should not break serialized models because there's some logic that allows loading a serialized `List<Tensor>` as `List<optional<Tensor>>`, see https://github.com/pytorch/pytorch/pull/49138/files#diff-9315f5dd045f47114c677174dcaa2f982721233eee1aa19068a42ff3ef775315R57
- This will break backwards compatibility for the C++ API. There is no implicit conversion from `ArrayRef<Tensor>` (which was the old argument type) to `List<optional<Tensor>>`. One common call pattern is `tensor.index({indices_tensor})`, where indices_tensor is another `Tensor`, and that will continue working because the `{}` initializer_list constructor for `List<optional<Tensor>>` can take `Tensor` elements that are implicitly converted to `optional<Tensor>`. But another common call pattern was `tensor.index(indices_tensor)`, where previously the `Tensor` got implicitly converted to an `ArrayRef<Tensor>`; to go `Tensor -> optional<Tensor> -> List<optional<Tensor>>` implicitly would be two implicit conversions, and C++ doesn't allow chaining two implicit conversions. So those call sites have to be rewritten to `tensor.index({indices_tensor})` (a small illustrative sketch of this migration appears after the test plan below).

ghstack-source-id: 119269131

Test Plan:
## Benchmarks (C++ instruction counts)

### Forward

#### Script

```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4});
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```

#### Results

| Op call | before | after | delta | delta % |
|---|---|---|---|---|
| x[0] = 1 | 11566015 | 11566015 | 0 | 0.00% |
| x.index({0}) | 6807019 | 6801019 | -6000 | -0.09% |
| x.index({0, 0}) | 13529019 | 13557019 | 28000 | 0.21% |
| x.index({0, 0, 0}) | 10677004 | 10692004 | 15000 | 0.14% |
| x.index({"..."}) | 5512015 | 5506015 | -6000 | -0.11% |
| x.index({Slice(None, None, None)}) | 6866016 | 6936016 | 70000 | 1.02% |
| x.index({None}) | 8554015 | 8548015 | -6000 | -0.07% |
| x.index({false}) | 22400000 | 22744000 | 344000 | 1.54% |
| x.index({true}) | 27624088 | 27264393 | -359695 | -1.30% |
| x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}) | 123472000 | 123463306 | -8694 | -0.01% |

### Autograd

#### Script

```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4}, torch::requires_grad());
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```

Note: the script measures the **forward** path of an op call with autograd enabled (i.e. calls into VariableType). It does not measure the backward path.
#### Results

| Op call | before | after | delta | delta % |
|---|---|---|---|---|
| x.index({0}) | 14839019 | 14833019 | -6000 | 0.00% |
| x.index({0, 0}) | 28342019 | 28370019 | 28000 | 0.00% |
| x.index({0, 0, 0}) | 24434004 | 24449004 | 15000 | 0.00% |
| x.index({"..."}) | 12773015 | 12767015 | -6000 | 0.00% |
| x.index({Slice(None, None, None)}) | 14837016 | 14907016 | 70000 | 0.47% |
| x.index({None}) | 15926015 | 15920015 | -6000 | 0.00% |
| x.index({false}) | 36958000 | 37477000 | 519000 | 1.40% |
| x.index({true}) | 41971408 | 42426094 | 454686 | 1.08% |
| x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}) | 168184392 | 164545682 | -3638710 | -2.16% |

Reviewed By: bhosmer

Differential Revision: D25454632

fbshipit-source-id: 28ab0cffbbdbdff1c40b4130ca62ee72f981b76d
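For C++ API users, a minimal sketch of the call-site migration described above (illustrative only, not code from this PR; the tensor values are made up):

```cpp
#include <torch/torch.h>

int main() {
  auto x = torch::ones({4, 4});
  auto indices_tensor = torch::tensor({0, 2});

  // Old pattern, no longer compiles: x.index(indices_tensor) relied on the
  // implicit Tensor -> ArrayRef<Tensor> conversion. Going
  // Tensor -> optional<Tensor> -> List<optional<Tensor>> would need two
  // chained implicit conversions, which C++ does not allow.

  // New pattern: wrap the index in braces so the initializer_list constructor
  // only needs the single Tensor -> optional<Tensor> conversion per element.
  auto y = x.index({indices_tensor});
  return 0;
}
```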
Parent: e44b2b72bd
Commit: c7e9abb66a
@@ -31,3 +31,4 @@
 #include <c10/util/Exception.h>
 #include <ATen/core/UnsafeFromTH.h>
 #include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
@@ -1,4 +1,5 @@
 #include <ATen/Config.h>
+#include <ATen/core/jit_type.h>
 #if AT_PARALLEL_OPENMP
 #include <ATen/Parallel.h>
 
@@ -10,6 +10,8 @@
 // There is some back story, see https://github.com/pytorch/pytorch/issues/48684
 #include <ATen/NativeFunctions.h>
 
+#include <ATen/core/List.h>
+
 namespace at {
 namespace indexing {
 
@@ -261,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector<Tensor>&
   (*dim_ptr)++;
 };
 
-static inline std::vector<Tensor> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
-  std::vector<Tensor> converted_inds(indices.size());
+static inline c10::List<c10::optional<Tensor>> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
+  c10::List<c10::optional<Tensor>> converted_inds;
+  converted_inds.reserve(indices.size());
   for (size_t i = 0; i < indices.size(); ++i) {
     const auto &ind = indices[i];
     if (ind.defined()) {
-      converted_inds[i] = ind.to(ind.options().device(self.device()));
+      converted_inds.push_back(ind.to(ind.options().device(self.device())));
     } else {
-      converted_inds[i] = std::move(indices[i]);
+      converted_inds.push_back(std::move(indices[i]));
     }
   }
   return converted_inds;
@@ -406,7 +406,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
   KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional<int64_t>), promote)
   KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
   KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote)
-  KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote)
+  KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List<c10::optional<Tensor>>&, const Tensor &, bool), promote)
   KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
   KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)
 
@@ -243,7 +243,7 @@ public:
    * Example:
    *   List<int> a({2, 3, 4});
    */
-  explicit List(std::initializer_list<T> initial_values);
+  List(std::initializer_list<T> initial_values);
   explicit List(ArrayRef<T> initial_values);
 
   /**
@@ -1,7 +1,7 @@
 #pragma once
 
+#include <ATen/core/jit_type_base.h>
 #include <ATen/core/ivalue.h>
-#include <ATen/core/jit_type.h>
 
 namespace c10 {
 
@@ -50,7 +50,17 @@ List<T>::List(TypePtr elementType)
 namespace impl {
 template<class T>
 List<T> toTypedList(impl::GenericList list) {
-  TORCH_INTERNAL_ASSERT(*getTypePtr<T>() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
+  // If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
+  // because upcasting would allow people to add types into the new list that would break the old list.
+  // However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
+  // allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
+  // without having to copy it. This is also used to provide backwards compatibility with some old models
+  // that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
+  // as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
+  // have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
+  TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
+    || (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
+    , "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
   return List<T>(std::move(list.impl_));
 }
 
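An illustrative sketch (not part of this commit) of why the upcast above is only permitted when the list is uniquely owned:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/List.h>

int main() {
  c10::List<at::Tensor> a({at::ones({2})});
  c10::List<at::Tensor> b = a;  // b shares the same underlying list, so a.use_count() > 1

  // If `a` could be reinterpreted as List<optional<Tensor>> here, pushing
  // c10::nullopt through that view would also be visible through `b`, which
  // still claims every element is a defined Tensor. With use_count() == 1
  // (e.g. a freshly deserialized List<Tensor>), no other handle can observe
  // the change, so the copy-free upcast is safe.
  return 0;
}
```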
@@ -312,3 +322,5 @@ void List<T>::unsafeSetElementType(TypePtr t) {
   impl_->elementType = std::move(t);
 }
 }
+
+#include <ATen/core/jit_type.h>
@@ -6,6 +6,7 @@
 #include <utility>
 
 #include <c10/util/ArrayRef.h>
+#include <ATen/core/List.h>
 
 namespace at {
 
@@ -56,6 +57,15 @@ struct IterArgs {
     }
   }
 
+  template <typename T>
+  void operator()(const torch::List<T>& args) {
+    for (const auto& arg : args) {
+      self()(arg);
+      if (self().short_circuit())
+        return;
+    }
+  }
+
   // NB: we need to specify std::vector manually as C++ won't
   // do an implicit conversion to make a template deduction go through.
   template <typename T>
@@ -1,10 +1,11 @@
 #pragma once
 
+#include <ATen/core/jit_type_base.h>
 #include <ATen/core/TensorBody.h>
 #include <ATen/core/functional.h>
 #include <ATen/core/interned_strings.h>
-#include <ATen/core/ivalue.h>
 #include <ATen/core/qualified_name.h>
+#include <ATen/core/ivalue.h>
 #include <c10/util/TypeList.h>
 #include <c10/util/Optional.h>
 
@@ -17,197 +18,17 @@ struct ClassType;
 namespace torch {
 namespace jit {
 struct CompilationUnit;
+struct Function;
 } // namespace jit
 } // namespace torch
 
 namespace c10 {
 
+struct IValue;
 struct FunctionSchema;
 struct NamedType;
 using OptNameList = c10::optional<std::vector<std::string>>;
 
-#define C10_FORALL_TYPES(_) \
-  _(AnyType) \
-  _(EnumType) \
-  _(AnyEnumType) \
-  _(TensorType) \
-  _(StorageType) \
-  _(TupleType) \
-  _(ListType) \
-  _(DictType) \
-  _(NumberType) \
-  _(FloatType) \
-  _(FutureType) \
-  _(RRefType) \
-  _(IntType) \
-  _(NoneType) \
-  _(StringType) \
-  _(GeneratorType) \
-  _(QuantizerType) \
-  _(BoolType) \
-  _(OptionalType) \
-  _(VarType) \
-  _(DeviceObjType) \
-  _(StreamObjType) \
-  _(FunctionType) \
-  _(ClassType) \
-  _(PyObjectType) \
-  _(CapsuleType) \
-  _(InterfaceType) \
-  _(QSchemeType) \
-  _(LayoutType) \
-  _(ScalarTypeType) \
-  _(AnyListType) \
-  _(AnyTupleType) \
-  _(AnyClassType)
-
-enum class TypeKind {
-#define DEFINE_TYPE(T) T,
-  C10_FORALL_TYPES(DEFINE_TYPE)
-#undef DEFINE_TYPE
-};
-
-TORCH_API const char* typeKindToString(TypeKind kind);
-
-struct Type;
-using TypePtr = std::shared_ptr<Type>;
-using ConstTypePtr = std::shared_ptr<const Type>;
-
-// Use this to customize how a Type is printed using `annotation_str()`. If
-// c10::nullopt is returned, `annotation_str()` falls through to its default
-// implementation.
-using TypePrinter =
-    std::function<c10::optional<std::string>(const ConstTypePtr&)>;
-
-struct TORCH_API Type : std::enable_shared_from_this<Type> {
- private:
-  TypeKind kind_;
-
- protected:
-  Type(TypeKind kind) : kind_(kind) {}
-
-  virtual std::string annotation_str_impl(TypePrinter printer) const {
-    return str();
-  }
-
- public:
-  virtual bool operator==(const Type& rhs) const = 0;
-
-  // subtyping relation. By default, we return true for the case
-  // when the type is exactly equal or if this <: T where rhs = Optional[T]
-
-  // if this returns false and the why_not stream is non-null, it contains
-  // additional details that describe why this is not a subtype of 'rhs'.
-  // This additional information should only contain details that are not obvious
-  // from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
-  // but not clear why `Foo <: InterfaceBar` might be false.
-  virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
-  virtual bool is_module() const;
-  bool isSubtypeOf(const TypePtr& rhs) const {
-    return isSubtypeOfExt(rhs, nullptr);
-  }
-
-  // How this type will appear in FunctionSchema declarations
-  virtual std::string str() const = 0;
-
-  // How this type will appear as if it were a type annotation in Python
-  // which is sometimes different than how it appears in declarations (e.g.
-  // int[] vs List[int])
-  //
-  // Takes a custom printer that users can pass in to customize the output of
-  // this method.
-  std::string annotation_str(TypePrinter printer) const {
-    if (printer) {
-      // the printer can return nullopt to fall through to the default impl
-      if (auto renamed = printer(shared_from_this())) {
-        return *renamed;
-      }
-    }
-    return annotation_str_impl(printer);
-  }
-  std::string annotation_str() const {
-    // Overload instead of define a default value for `printer` to help
-    // debuggers out.
-    return annotation_str(nullptr);
-  }
-
-  // Returns a human readable string that includes additional information like
-  // "type is inferred rather than explictly defined" to help construct more
-  // user-friendly messages.
-  virtual std::string repr_str() const {
-    return annotation_str();
-  }
-
-  TypeKind kind() const {
-    return kind_;
-  }
-
-  virtual bool requires_grad() const {
-    for (const auto& ct : containedTypes()) {
-      if (ct->requires_grad()) {
-        return true;
-      }
-    }
-    return false;
-  }
-
-  // Dynamically cast this object to the subclass indicated by the
-  // template variable, returning nullptr if the cast is invalid.
-  template <typename T>
-  std::shared_ptr<T> cast() {
-    if (T::Kind == kind()) {
-      return std::static_pointer_cast<T>(shared_from_this());
-    }
-    return nullptr;
-  }
-  template <typename T>
-  std::shared_ptr<const T> cast() const {
-    if (T::Kind == kind()) {
-      return std::static_pointer_cast<const T>(shared_from_this());
-    }
-    return nullptr;
-  }
-  template <typename T>
-  std::shared_ptr<T> expect() {
-    auto r = cast<T>();
-    AT_ASSERT(r);
-    return r;
-  }
-  template <typename T>
-  std::shared_ptr<const T> expect() const {
-    auto r = cast<const T>();
-    AT_ASSERT(r);
-    return r;
-  }
-  virtual ~Type() = default;
-  virtual bool hasFreeVariables() const {
-    return false;
-  }
-  // list of types this type contains, e.g. for a List then element type of a
-  // list for a tuple, the types of the tuple elements
-  virtual at::ArrayRef<TypePtr> containedTypes() const {
-    return {};
-  }
-  // create a new version of this type, replacing its contained types with
-  // contained_types
-  TypePtr withContained(std::vector<TypePtr> contained_types) {
-    auto current_contained = containedTypes();
-    AT_ASSERT(current_contained.size() == contained_types.size());
-    if (current_contained.equals(contained_types)) {
-      return shared_from_this();
-    }
-    return createWithContained(std::move(contained_types));
-  }
-  // per-type constructor, you only need to override this if the
-  // containedTypes() is not empty
-  virtual TypePtr createWithContained(
-      std::vector<TypePtr> contained_types) const {
-    AT_ERROR(
-        "type with contained types did not overload createWithContained: ",
-        str());
-  }
-};
-
 struct AnyType;
 using AnyTypePtr = std::shared_ptr<AnyType>;
 // Any is the top of the type hierarchy, all other types are subtypes
aten/src/ATen/core/jit_type_base.h (new file, 195 lines)
@@ -0,0 +1,195 @@
+#pragma once
+
+#include <functional>
+#include <string>
+#include <memory>
+#include <c10/macros/Macros.h>
+#include <c10/util/Optional.h>
+#include <c10/util/Exception.h>
+#include <c10/util/ArrayRef.h>
+
+namespace c10 {
+
[... the remaining added lines of the new header are the block removed from jit_type.h in the previous hunk, moved here verbatim: the C10_FORALL_TYPES macro, the TypeKind enum, typeKindToString, the Type/TypePtr/ConstTypePtr declarations, the TypePrinter alias, and the full definition of struct TORCH_API Type ...]
+
+}
@@ -68,7 +68,7 @@ Tensor embedding_sparse_backward(
   Tensor indices = indices_;
   Tensor grad = grad_;
   if (padding_idx != -1) {
-    auto c = indices != padding_idx;
+    torch::List<c10::optional<Tensor>> c({indices != padding_idx});
     indices = indices.index(c);
     grad = grad.index(c);
   }
@@ -1,6 +1,7 @@
 #pragma once
 #include <ATen/ExpandUtils.h>
 #include <ATen/native/TensorIterator.h>
+#include <ATen/core/List.h>
 
 #include <limits>
 
@@ -15,40 +16,45 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
 }
 
 
-static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices) {
+static std::vector<Tensor> expandTensors(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
   // If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
   std::vector<Tensor> result;
-  for (const auto & index : indices) {
-    if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
-      if (index.scalar_type() == kByte) {
-        TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
-        " please use a dtype torch.bool instead.");
-      }
-      // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
-      // corresponding dimensions in self
-      for (int64_t j = 0; j < index.dim(); j++) {
-        int64_t srcIdx = result.size() + j;
-        if (index.size(j) != self.size(srcIdx)) {
-          invalid_mask(self, srcIdx, index, j);
-        }
-      }
-      // Replace with nonzeros
-      auto nonzero = index.nonzero();
-      for (int64_t j = 0; j < index.dim(); j++) {
-        result.emplace_back(nonzero.select(1, j));
-      }
+  for (c10::optional<Tensor> index_opt : indices) {
+    if (!index_opt.has_value()) {
+      result.emplace_back();
     } else {
-      result.emplace_back(index);
+      Tensor index = std::move(*index_opt);
+      if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
+        if (index.scalar_type() == kByte) {
+          TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
+          " please use a dtype torch.bool instead.");
+        }
+        // The sizes of the ByteTensor mask or bool tensor must match the sizes of the
+        // corresponding dimensions in self
+        for (int64_t j = 0; j < index.dim(); j++) {
+          int64_t srcIdx = result.size() + j;
+          if (index.size(j) != self.size(srcIdx)) {
+            invalid_mask(self, srcIdx, index, j);
+          }
+        }
+        // Replace with nonzeros
+        auto nonzero = index.nonzero();
+        for (int64_t j = 0; j < index.dim(); j++) {
+          result.emplace_back(nonzero.select(1, j));
+        }
+      } else {
+        result.emplace_back(std::move(index));
+      }
     }
   }
   return result;
 }
 
 
-static void checkIndexTensorTypes(TensorList indices) {
-  for (auto& tensor : indices) {
-    if (tensor.defined()) {
-      auto scalarType = tensor.scalar_type();
+static void checkIndexTensorTypes(const torch::List<c10::optional<Tensor>>& indices) {
+  for (c10::optional<Tensor> tensor : indices) {
+    if (tensor.has_value() && tensor->defined()) {
+      auto scalarType = tensor->scalar_type();
       if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
         TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
       }
@@ -56,6 +62,15 @@ static void checkIndexTensorTypes(TensorList indices) {
   }
 }
 
+inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
+  torch::List<c10::optional<Tensor>> result;
+  result.reserve(list.size());
+  for (const Tensor& a : list) {
+    result.push_back(a);
+  }
+  return result;
+}
+
 static bool hasContiguousSubspace(TensorList tl) {
   // true if all the non-null tensors are adjacent
   auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };
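An illustrative sketch (not part of this commit) of what the new `Tensor?[]` argument looks like from a C++ caller, with `c10::nullopt` standing in for a dimension that is not advanced-indexed; shapes and values here are made up:

```cpp
#include <torch/torch.h>

int main() {
  auto x = torch::arange(27).reshape({3, 3, 3});

  torch::List<c10::optional<torch::Tensor>> indices;
  indices.push_back(c10::nullopt);            // dim 0 is not tensor-indexed (like `:` in Python)
  indices.push_back(torch::tensor({0, 2}));   // advanced index into dim 1

  auto y = at::index(x, indices);             // expected shape under these assumptions: {3, 2, 3}
  return 0;
}
```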
@@ -8,6 +8,7 @@
 #include <ATen/native/Resize.h>
 #include <ATen/native/TensorIterator.h>
 #include <ATen/native/LinearAlgebra.h>
+#include <ATen/native/IndexingUtils.h>
 #include <ATen/TensorUtils.h>
 #include <ATen/Parallel.h>
 #include <ATen/LegacyTHFunctionsCPU.h>
@@ -73,7 +74,8 @@ Tensor logdet(const Tensor& self) {
   // U is singular when U(i, i) = 0 for some i in [1, self.size(-1)].
   Tensor logdet_vals = diag_U.abs_().log_().sum(-1);
   if (self.dim() > 2) {
-    logdet_vals.index_put_((det_sign < 0).nonzero_numpy(), at::full({}, NAN, self.options()));
+    auto indices = toListOfOptionalTensors((det_sign < 0).nonzero_numpy());
+    logdet_vals.index_put_(std::move(indices), at::full({}, NAN, self.options()));
   } else if (det_sign.item<double>() < 0) {
     logdet_vals.fill_(NAN);
   }
@@ -206,7 +206,7 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
   }
 }
 
-static AdvancedIndex make_info(Tensor self, TensorList orig) {
+static AdvancedIndex make_info(Tensor self, const torch::List<c10::optional<at::Tensor>>& orig) {
   checkIndexTensorTypes(orig);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);
@@ -281,7 +281,7 @@ static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor&
   return config.build();
 }
 
-Tensor index(const Tensor & self, TensorList indices) {
+Tensor index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
 
   auto info = make_info(self, indices);
@@ -290,7 +290,7 @@ Tensor index(const Tensor & self, TensorList indices) {
   return iter.output();
 }
 
-Tensor quantized_index(const Tensor & self, TensorList indices) {
+Tensor quantized_index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
   TORCH_INTERNAL_ASSERT(
       self.qscheme() == c10::kPerTensorAffine ||
       self.qscheme() == c10::kPerTensorSymmetric,
@@ -311,12 +311,14 @@ Tensor quantized_index(const Tensor & self, TensorList indices) {
       res, self.q_scale(), self.q_zero_point(), self.scalar_type());
 }
 
-Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
+Tensor& index_out(Tensor& result, const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   at::assert_no_internal_overlap(result);
   at::assert_no_overlap(result, self);
-  for (auto& index: indices) {
-    at::assert_no_overlap(result, index);
+  for (const c10::optional<Tensor>& index: indices) {
+    if (index.has_value()) {
+      at::assert_no_overlap(result, *index);
+    }
   }
 
   auto info = make_info(self, indices);
@@ -325,11 +327,11 @@ Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
   return result;
 }
 
-Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
+Tensor index_put(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, bool accumulate) {
   return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate);
 }
 
-Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate, const bool unsafe) {
+Tensor & _index_put_impl_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate, const bool unsafe) {
   TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   if (at::has_internal_overlap(self) == MemOverlap::YES) {
     TORCH_WARN(
@@ -338,8 +340,10 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu
       "This also applies to advanced indexing e.g. tensor[indices] = tensor");
   }
   at::assert_no_overlap(self, value);
-  for (auto& index: indices) {
-    at::assert_no_overlap(self, index);
+  for (const c10::optional<Tensor>& index: indices) {
+    if (index.has_value()) {
+      at::assert_no_overlap(self, *index);
+    }
   }
 
   if (accumulate && self.device().type() == kCUDA) {
@@ -356,7 +360,7 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu
 }
 
 
-Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate) {
+Tensor & index_put_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate) {
   return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false);
 }
 
@@ -15,7 +15,7 @@ enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_MULTIPLY};
 
 using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
 using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
-using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe);
+using index_put_accum_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool unsafe);
 using masked_fill_fn = void(*)(TensorIterator &, Scalar scalar);
 using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
 
@@ -42,6 +42,6 @@ DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
 DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
 DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
 
-TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices);
+TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<c10::optional<at::Tensor>>& indices);
 
 }} // namespace at::native
@@ -190,7 +190,7 @@ static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self
   Tensor _mask = (mask.dim() == 0) ? mask.unsqueeze(0) : mask;
   Tensor _self = (self.dim() == 0) ? self.unsqueeze(0) : self;
   std::tie(_mask, _self) = expand_outplace(_mask, _self);
-  at::native::index_out(result, _self, _mask);
+  at::native::index_out(result, _self, c10::List<c10::optional<at::Tensor>>({_mask}));
 
   return result;
 }
@@ -160,7 +160,7 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
 }
 
 
-static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, TensorList orig, bool check_range) {
+static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, const c10::List<c10::optional<at::Tensor>>& orig, bool check_range) {
   checkIndexTensorTypes(orig);
   // first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
   auto indices = expandTensors(self, orig);
@@ -184,7 +184,7 @@ static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t
 
 
 namespace {
-void index_put_accum_kernel(Tensor & self, TensorList indices, const Tensor & value, bool unsafe) {
+void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool unsafe) {
   if (indices.size() > (size_t)self.dim()) {
     TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
   }
@@ -2226,6 +2226,7 @@
   use_c10_dispatcher: full
 
 - func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
+  use_c10_dispatcher: full
   variants: function, method
   dispatch:
     CPU, CUDA: index
@@ -2254,6 +2255,7 @@
   variants: function, method
 
 - func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
+  use_c10_dispatcher: full
   variants: function, method
   dispatch:
     DefaultBackend: index_put_
@@ -2264,9 +2266,11 @@
 # - Tensor & Tensor::index_put_(std::initializer_list<TensorIndex> indices, Scalar v)
 
 - func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
+  use_c10_dispatcher: full
   variants: function, method
 
 - func: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)
+  use_c10_dispatcher: full
   variants: function
   dispatch:
     CPU, CUDA: _index_put_impl_
@@ -7,6 +7,7 @@
 #include <ATen/NativeFunctions.h>
 #include <ATen/InitialTensorOptions.h>
 #include <ATen/SparseTensorUtils.h>
+#include <ATen/native/IndexingUtils.h>
 
 #include <TH/THBlasUtils.h>
 
@@ -14,7 +15,6 @@ namespace at { namespace native {
 
 using namespace at::sparse;
 
-
 /******************************************************************************
  * access methods
  ******************************************************************************/
@@ -328,7 +328,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
 
   Tensor values;
   if (self.dim() > 0) {
-    std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
+    auto ix = toListOfOptionalTensors(indices.chunk(indices.size(0), 0));
     values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve);
   } else {
     AT_ASSERT(nz.sizes().equals({0, 1}));
@@ -28,6 +28,7 @@ class Tensor;
 }
 namespace c10{
 struct TensorOptions;
+template<class T> class List;
 }
 namespace at {
 struct Generator;
@@ -6,13 +6,17 @@ namespace caffe2 {
 namespace internal {
 at::Tensor index_with_uint8_handling(
     const at::Tensor& self,
-    at::TensorList indices) {
+    const torch::List<c10::optional<at::Tensor>>& indices) {
   // Support BC only for the simplest case of mask indexing
-  if (indices.size() == 1 && indices[0].scalar_type() == at::kByte) {
-    TORCH_WARN(
-        "Indexing with uint8 mask tensor in ATenOp is now deprecated,"
-        " please use a bool mask instead.");
-    return at::index(self, {indices[0].to(at::kBool)});
+  if (indices.size() == 1) {
+    c10::optional<at::Tensor> first = indices[0];
+    if (first.has_value()
+        && first->scalar_type() == at::kByte) {
+      TORCH_WARN(
+          "Indexing with uint8 mask tensor in ATenOp is now deprecated,"
+          " please use a bool mask instead.");
+      return at::index(self, {first->to(at::kBool)});
+    }
   }
   return at::index(self, indices);
 }
@@ -21,7 +21,7 @@ using at::Half; // for AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...)
 namespace internal {
 TORCH_API at::Tensor index_with_uint8_handling(
     const at::Tensor& self,
-    at::TensorList indices);
+    const torch::List<c10::optional<at::Tensor>>& indices);
 }
 
 template <class Context>
@@ -86,6 +86,16 @@ private:
 
   std::vector<at::Tensor> peekSlice(size_t i, size_t len, size_t N) {
     std::vector<at::Tensor> results;
+    results.reserve(len);
+    for (size_t ii = i; ii < i + len; ++ii) {
+      results.push_back(peek(ii, N));
+    }
+    return results;
+  }
+
+  torch::List<c10::optional<at::Tensor>> peekSliceOptionals(size_t i, size_t len, size_t N) {
+    torch::List<c10::optional<at::Tensor>> results;
+    results.reserve(len);
     for (size_t ii = i; ii < i + len; ++ii) {
       results.push_back(peek(ii, N));
     }
@@ -68,7 +68,7 @@ def value_has_tensors(v):
 
 
 def value_is_tensor_type(v):
-    return value_has_tensors(v) and v['dynamic_type'] != 'TensorList'
+    return value_has_tensors(v) and v['dynamic_type'] not in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']
 
 
 # for each aten type, how do we handle a return value of that type?
@@ -208,7 +208,7 @@ def self_as_first_argument(arguments):
 def get_num_inputs(o):
     args = 0
     for a in o['arguments']:
-        if a['type'] == 'TensorList':
+        if a['type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
             return '*'
         elif value_has_tensors(a):
             args += 1
@@ -277,10 +277,10 @@ if __name__ == '__main__':
             # e.g. "Float" is at::kFloat
             assert('Type' in o['method_of'])
 
-            static_tensor_inputs = sum(arg['type'] != 'TensorList' and value_is_tensor_type(arg) for arg in o['arguments'])
-            has_tensorlist = any(arg['type'] == 'TensorList' for arg in o['arguments'])
+            static_tensor_inputs = sum(arg['type'] not in ['TensorList', 'const c10::List<c10::optional<Tensor>> &'] and value_is_tensor_type(arg) for arg in o['arguments'])
+            has_tensorlist = any(arg['type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &'] for arg in o['arguments'])
             if has_tensorlist:
-                tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] == 'TensorList'][0]
+                tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']][0]
 
             real_inputs = 0
             for i, arg in enumerate(o['arguments']):
@@ -290,10 +290,16 @@ if __name__ == '__main__':
                 view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs
                 if arg['type'] == 'TensorList':
                     # NOTE: do not advance real_inputs here. After this we will
-                    # switch to indexing the "stack" from the end as if we only had
+                    # switch to indexing the "stack" from the end
                     env['statements'].append(
                         'auto {} = peekSlice({}, InputSize() - {}, InputSize());'
                         .format(arg['name'], real_inputs, static_tensor_inputs))
+                elif arg['type'] == 'const c10::List<c10::optional<Tensor>> &':
+                    # NOTE: do not advance real_inputs here. After this we will
+                    # switch to indexing the "stack" from the end
+                    env['statements'].append(
+                        'auto {} = peekSliceOptionals({}, InputSize() - {}, InputSize());'
+                        .format(arg['name'], real_inputs, static_tensor_inputs))
                 elif value_is_tensor_type(arg):
                     # load tensor inputs from Caffe2
                     env['statements'].append(
```diff
@@ -83,27 +83,27 @@ TEST(TensorIndexingTest, TestNoIndices) {
   ASSERT_THROWS_WITH(tensor.index_put_(indices, value), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
 }
 
-TEST(TensorIndexingTest, TestAdvancedIndexingWithArrayRefOfTensor) {
+TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) {
   {
     torch::Tensor tensor = torch::randn({20, 20});
     torch::Tensor index = torch::arange(10, torch::kLong).cpu();
-    torch::Tensor result_with_array_ref = tensor.index(at::ArrayRef<torch::Tensor>({index}));
+    torch::Tensor result = at::index(tensor, {index});
     torch::Tensor result_with_init_list = tensor.index({index});
-    ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list));
+    ASSERT_TRUE(result.equal(result_with_init_list));
   }
   {
     torch::Tensor tensor = torch::randn({20, 20});
     torch::Tensor index = torch::arange(10, torch::kLong).cpu();
-    torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef<torch::Tensor>({index}), torch::ones({20}));
+    torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({20}));
     torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({20}));
-    ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list));
+    ASSERT_TRUE(result.equal(result_with_init_list));
   }
   {
     torch::Tensor tensor = torch::randn({20, 20});
     torch::Tensor index = torch::arange(10, torch::kLong).cpu();
-    torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef<torch::Tensor>({index}), torch::ones({1, 20}));
+    torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({1, 20}));
     torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({1, 20}));
-    ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list));
+    ASSERT_TRUE(result.equal(result_with_init_list));
   }
 }
```
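The updated test exercises the call patterns that remain valid for the C++ indexing API. For illustration only (not part of the diff), the same index list can also be built explicitly as a `torch::List<c10::optional<torch::Tensor>>`; variable names below are made up for the sketch:

```cpp
#include <torch/torch.h>

int main() {
  torch::Tensor tensor = torch::randn({20, 20});
  torch::Tensor index = torch::arange(10, torch::kLong);

  // The brace form keeps working after this change (see the test above).
  torch::Tensor a = tensor.index({index});

  // The index list can also be built explicitly as a list of optional tensors.
  torch::List<c10::optional<torch::Tensor>> indices;
  indices.push_back(c10::optional<torch::Tensor>(index));
  torch::Tensor b = at::index(tensor, indices);

  TORCH_CHECK(a.equal(b));
  return 0;
}
```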
```diff
@@ -173,7 +173,7 @@ TEST(TensorIndexingTest, TestBoolIndices) {
 TEST(TensorIndexingTest, TestBoolIndicesAccumulate) {
   auto mask = torch::zeros({10}, torch::kBool);
   auto y = torch::ones({10, 10});
-  y.index_put_({mask}, y.index({mask}), /*accumulate=*/true);
+  y.index_put_({mask}, {y.index({mask})}, /*accumulate=*/true);
   assert_tensor_equal(y, torch::ones({10, 10}));
 }
```
```diff
@@ -563,6 +563,8 @@ def generate_tensor_like_override_tests(cls):
             func_args.append(instance_gen())
         elif t == 'TensorList':
             func_args.append([instance_gen(), instance_gen()])
+        elif t == 'c10::List<c10::optional<Tensor>>':
+            func_args.append([instance_gen(), instance_gen()])
         elif t == 'IntArrayRef':
             size = arg.get('size', 2)
             if size == 1:
```
```diff
@@ -141,7 +141,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
     compute_index_ranges: List[str] = []
 
     for arg in info.args_with_derivatives:
-        if arg.type == 'TensorList':
+        if arg.type == 'TensorList' or arg.type == 'const c10::List<c10::optional<Tensor>> &':
             size = f'{arg.name}_size_'
             saved_list_sizes.append(f'size_t {arg.name}_size_;')
         else:
@@ -166,6 +166,15 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
             release_variables.append(f'{name}_released_ = true;')
             unpack.append(f'auto {name} = unpack_list({name}_);')
             asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
+        elif var.type == 'c10::List<c10::optional<Tensor>>':
+            saved_variables.append(f'std::vector<SavedVariable> {name}_;')
+            saved_variables.append(f'bool {name}_released_ = false;')
+            # Just clear() is sufficient, we don't need to loop and clear each variable.
+            # Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
+            release_variables.append(f'{name}_.clear();')
+            release_variables.append(f'{name}_released_ = true;')
+            unpack.append(f'auto {name} = unpack_opt_list({name}_);')
+            asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
         elif var.type == 'IntArrayRef':
             saved_variables.append(f'std::vector<int64_t> {name};')
         elif var.type == 'c10::optional<IntArrayRef>':
```
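Assembled from the f-strings in this hunk, the generated backward node for an op with a `Tensor?[]` argument would have roughly the following shape. This is an illustrative sketch only (the member names match an `indices` argument and the `apply` body is omitted); it is not the literal generated code:

```cpp
#include <torch/csrc/autograd/function.h>
#include <torch/csrc/autograd/saved_variable.h>

using namespace torch::autograd;

// Illustrative shape of a generated node for a Tensor?[] argument `indices`.
struct IndexBackward : public TraceableFunction {
  std::string name() const override { return "IndexBackward"; }
  variable_list apply(variable_list&& grads) override;  // body generated elsewhere

  void release_variables() override {
    // clear() is enough: dropping the SavedVariables also releases the
    // tensors and grad_fns they own.
    indices_.clear();
    indices_released_ = true;
  }

  std::vector<SavedVariable> indices_;
  bool indices_released_ = false;
  size_t indices_size_;
};

// Inside apply(), the generated code would assert
// TORCH_CHECK(!indices_released_, ERR_BACKWARD_TWICE) and then call
// unpack_opt_list(indices_) to recover the optional-tensor list.
```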
```diff
@@ -112,9 +112,8 @@ def format_trace_inputs(f: NativeFunction) -> str:
         ]
     else:
         name = arg.name
-        # XXX: For arg that have type of Tensor?[], tracer will pass allow_undefined to addInputs
         if str(arg.type) == 'Tensor?[]':
-            return [f'jit::tracer::addInputs(node, "{name}", {name}, true);']
+            return [f'jit::tracer::addInputs(node, "{name}", {name});']
         else:
             return [ADD_TRACE_INPUT.substitute(name=name, input=name)]
```
```diff
@@ -118,6 +118,21 @@ for (size_t i=0; i<${tensorlist_name}.size(); i++) {
 }
 """)
 
+SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\
+std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
+for (const c10::optional<Tensor>& tensor : ${tensorlist_name})
+  ${tensorlist_name}_storage_saved.push_back(
+    tensor.has_value() && tensor->has_storage() ? c10::optional<Storage>(tensor->storage()) : c10::nullopt);
+""")
+
+ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\
+for (size_t i=0; i<${tensorlist_name}.size(); i++) {
+  if (${tensorlist_name}_storage_saved[i].has_value())
+    AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(
+        static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->storage()));
+}
+""")
+
 SAVE_TENSOR_IMPL = CodeTemplate("""\
 c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
 if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
@@ -140,6 +155,21 @@ for (size_t i=0; i<${tensorlist_name}.size(); i++) {
 }
 """)
 
+SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\
+std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
+for (size_t i=0; i<${tensorlist_name}.size(); i++) {
+  c10::optional<Tensor> t = ${tensorlist_name}[i];
+  if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr();
+}
+""")
+
+ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\
+for (size_t i=0; i<${tensorlist_name}.size(); i++) {
+  if (${tensorlist_name}_impl_saved[i])
+    AT_ASSERT(${tensorlist_name}_impl_saved[i] == static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->getIntrusivePtr());
+}
+""")
+
 # The following list contains functions that we don't enforce the invariant on.
 DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
     # These functions are expected to change impl or storage of input tensors
```
```diff
@@ -466,7 +496,8 @@ def emit_body(declaration):
         if func is None:
             return setup
 
-        has_tensorlist_arg = any(arg.type == 'TensorList' for arg in func.args_with_derivatives)
+        has_tensorlist_arg = \
+            any(arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &'] for arg in func.args_with_derivatives)
 
         # We don't want to save tensors if we know that they will never be used
         # when computing the derivative, so we add guards to those statements
@@ -515,7 +546,7 @@ def emit_body(declaration):
 
         setup.extend(save_variables(func.all_saved_inputs, False, guard_for))
         for arg in func.args_with_derivatives:
-            if arg.type == 'TensorList':
+            if arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
                 setup.append(f'grad_fn->{arg.name}_size_ = {arg.name}.size();')
 
         return setup
@@ -554,7 +585,7 @@ def emit_body(declaration):
             return body
         for arg in differentiable_outputs:
             name = arg['name']
-            if arg['type'] == 'Tensor' or arg['type'] == 'TensorList':
+            if arg['type'] in ['Tensor', 'TensorList', 'const c10::List<c10::optional<Tensor>> &']:
                 body.append('throw_error_for_complex_autograd({}, "{}");'.format(name, base_name))
         return body
```
```diff
@@ -599,7 +630,7 @@ def emit_body(declaration):
                 expr = f'SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})'
             else:
                 expr = f'SavedVariable({var}, {str(is_output).lower()})'
-        elif arg.type == 'TensorList':
+        elif arg.type in ['TensorList', 'c10::List<c10::optional<Tensor>>']:
             name += '_'
             expr = f'make_saved_variable_list({arg.name})'
         elif arg.type == 'IntArrayRef':
@@ -699,7 +730,7 @@ def emit_body(declaration):
         # Only allow rebasing of the history if we return a single Tensor
         # If we are in a no grad block, raise a warning
         # See NOTE [ View + Inplace detection ] for more details about this logic
-        if return_info['dynamic_type'] == 'TensorList':
+        if return_info['dynamic_type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
             if base_name in MULTI_OUTPUT_SAFE_FUNCTIONS:
                 creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE"
             else:
```
```diff
@@ -736,6 +767,11 @@ def emit_body(declaration):
                                     SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
                 enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
                                             ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
+            elif simple_type == 'c10::List<c10::optional<Tensor>>':
+                save_ptrs_stmts += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                                    SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
+                enforce_same_ptrs_stmts += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
+                                            ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
             elif simple_type == 'Tensor':
                 save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
                                     SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
@@ -836,7 +872,7 @@ def emit_body(declaration):
 
 def unpack_args(env, declaration):
     def requires_unpack(arg):
-        return 'Tensor' in arg['dynamic_type']
+        return 'Tensor' in arg['dynamic_type'] and 'c10::optional' not in arg['type']
 
     body = []
     unpacked_args = []
```
```diff
@@ -855,9 +891,8 @@ def unpack_args(env, declaration):
         dynamic_type = arg['dynamic_type']
         if 'TensorOptions' not in dynamic_type:
             is_nullable = arg.get('is_nullable', False)
-            ref = (not is_nullable) and dynamic_type not in ['TensorList']
+            ref = (not is_nullable) and dynamic_type != 'TensorList'
             suffix = '_opt' if is_nullable and dynamic_type != 'TensorList' else ''
-
             body.append(UNPACK_TENSOR.substitute(
                 arg_name=arg['name'],
                 arg_pos=i,
```
```diff
@@ -32,6 +32,15 @@ inline std::vector<Tensor> unpack_list(at::ArrayRef<SavedVariable> xs) {
   });
 }
 
+inline c10::List<c10::optional<Tensor>> unpack_opt_list(at::ArrayRef<SavedVariable> xs) {
+  torch::List<c10::optional<Tensor>> result;
+  result.reserve(xs.size());
+  for (const SavedVariable& v : xs) {
+    result.push_back(v.unpack());
+  }
+  return result;
+}
+
 struct TypeAndSize {
   TypeAndSize() : options(at::TensorOptions()) {}
   /* implicit */
```
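`unpack_opt_list` relies on a `SavedVariable` that was saved from an absent (`c10::nullopt`) element unpacking back to an undefined tensor, which the optional-list kernels treat as "no index for this dimension". A minimal standalone sketch of that round trip, using the public `SavedVariable` API rather than generated code (illustrative only):

```cpp
#include <torch/torch.h>
#include <torch/csrc/autograd/saved_variable.h>

int main() {
  using torch::autograd::SavedVariable;

  torch::Tensor t = torch::randn({3});

  std::vector<SavedVariable> saved;
  saved.reserve(2);
  saved.emplace_back(t, /*is_output=*/false);
  // An absent element of a Tensor?[] argument is saved as a SavedVariable
  // holding an undefined tensor (see make_saved_variable_list later in this diff)...
  saved.emplace_back(torch::Tensor(), /*is_output=*/false);

  // ...and unpacks back to an undefined tensor.
  TORCH_CHECK(saved[0].unpack().defined());
  TORCH_CHECK(!saved[1].unpack().defined());
  return 0;
}
```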
```diff
@@ -49,7 +49,6 @@ namespace VariableType {
   at::Tensor & unpack(Tensor & t, const char * name, int pos);
   const at::Tensor & unpack(const Tensor & t, const char * name, int pos);
   at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
-  c10::optional<at::Tensor> unpack_opt(const c10::optional<Tensor> & t, const char * name, int pos);
   std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos);
 };
```
```diff
@@ -104,9 +104,11 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
             return BaseCType("TensorList", binds)
         elif str(t.elem) == 'Dimname':
             return BaseCType("DimnameList", binds)
-        # TODO: do something reasonable about lists of optional tensors
-        elif (not local.use_c10_dispatcher().dispatcher_uses_new_style()) and str(t.elem) == 'Tensor?':
-            return BaseCType("TensorList", binds)
+        elif str(t.elem) == 'Tensor?':
+            if local.use_c10_dispatcher().dispatcher_uses_new_style():
+                return BaseCType("const c10::List<c10::optional<Tensor>> &", binds)
+            else:
+                return BaseCType("TensorList", binds)
         elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
         # TODO: explicitly qualify namespace here
         return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds)
```
```diff
@@ -34,7 +34,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
         else:
             return ConstRefCType(BaseCType('Tensor', binds))
     elif str(t) == 'Tensor?[]':
-        return BaseCType('TensorList', binds)
+        return BaseCType('const c10::List<c10::optional<Tensor>> &', binds)
     return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
 
 def returns_type(rs: Sequence[Return]) -> str:
```
```diff
@@ -228,7 +228,7 @@ class PythonArgument:
     # Compute argument formal for python argument parsing.
     # Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
     def argument_str(self, *, method: bool = False) -> str:
-        type_str = argument_type_str(self.type)
+        type_str = argument_type_str(self.type).replace('const ', '').replace(' &', '')
 
         name = self.name
         # s/self/input/ outside method bindings
```
```diff
@@ -624,10 +624,9 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str:
             return f'ScalarList[{size}]' if size is not None else 'ScalarList'
         elif str(t.elem) == 'Tensor?':
             if simple_type:
-                return 'TensorList'
+                return 'c10::List<c10::optional<Tensor>>'
             else:
-                # TODO: clone the old codegen behavior but does it make sense?
-                return 'TensorList?'
+                return 'const c10::List<c10::optional<Tensor>> &'
         elif str(t.elem) == 'Dimname':
             return f'DimnameList[{size}]' if size is not None else 'DimnameList'
         elem = argument_type_str(t.elem, simple_type=simple_type)
```
```diff
@@ -1051,12 +1050,14 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
             return 'toDimnameListOptional'
 
     elif isinstance(t, ListType):
-        if str(t.elem) == 'Tensor' or str(t.elem) == 'Tensor?':
+        if str(t.elem) == 'Tensor':
             # accept and use definite size
             if t.size is not None:
                 return f'tensorlist_n<{t.size}>'
             else:
                 return 'tensorlist'
+        elif str(t.elem) == 'Tensor?':
+            return 'list_of_optional_tensors'
         elif str(t.elem) == 'Dimname':
             # accept definite size
             return 'dimnamelist'
```
```diff
@@ -14,6 +14,7 @@
 #include <ATen/ScalarOps.h>
 #include <ATen/native/LinearAlgebraUtils.h>
 #include <ATen/SparseTensorUtils.h>
+#include <ATen/native/IndexingUtils.h>
 
 #include <ciso646>
 #include <algorithm>
```
```diff
@@ -2211,15 +2212,17 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det)
       return nonsingular_case_backward(grad, self, det);
     }
   } else {
-    auto nonzero_det_indices = at::where(det);
+    auto nonzero_det_indices = at::native::toListOfOptionalTensors(at::where(det));
+    c10::optional<Tensor> first_nonzero_det_index = nonzero_det_indices[0];
 
-    if (nonzero_det_indices[0].size(0) == det.numel()) {  // all determinants are nonzero (non-singular)
+    if (first_nonzero_det_index->size(0) == det.numel()) {  // all determinants are nonzero (non-singular)
       return nonsingular_case_backward(grad, self, det);
     }
 
-    auto zero_det_indices = at::where(det == 0);
+    auto zero_det_indices = at::native::toListOfOptionalTensors(at::where(det == 0));
+    c10::optional<Tensor> first_zero_det_index = zero_det_indices[0];
 
-    if (zero_det_indices[0].size(0) == det.numel()) {  // all determinants are zero (singular)
+    if (first_zero_det_index->size(0) == det.numel()) {  // all determinants are zero (singular)
       return singular_case_backward(grad, self, det);
     }
@@ -2261,15 +2264,17 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& logdet) {
       return singular_case_backward(grad, self);
     }
   } else {
-    auto finite_logdet_indices = at::where(logdet != -INFINITY);
+    auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY));
+    c10::optional<Tensor> first_finite_logdet_index = finite_logdet_indices[0];
 
-    if (finite_logdet_indices[0].size(0) == logdet.numel()) {  // all log determinants are finite (non-singular)
+    if (first_finite_logdet_index->size(0) == logdet.numel()) {  // all log determinants are finite (non-singular)
       return nonsingular_case_backward(grad, self);
     }
 
-    auto neginf_logdet_indices = at::where(logdet == -INFINITY);
+    auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY));
+    c10::optional<Tensor> first_neginf_logdet_index = neginf_logdet_indices[0];
 
-    if (neginf_logdet_indices[0].size(0) == logdet.numel()) {  // all log determinants are -inf (singular)
+    if (first_neginf_logdet_index->size(0) == logdet.numel()) {  // all log determinants are -inf (singular)
       return singular_case_backward(grad, self);
     }
@@ -2313,15 +2318,17 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet,
       return nonsingular_case_backward(grad_logabsdet, self);
     }
   } else {
-    auto nonzero_signdet_indices = at::where(signdet);
+    auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet));
+    c10::optional<Tensor> first_nonzero_signdet_index = nonzero_signdet_indices[0];
 
-    if (nonzero_signdet_indices[0].size(0) == logabsdet.numel()) {  // all log determinants are finite (non-singular)
+    if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) {  // all log determinants are finite (non-singular)
       return nonsingular_case_backward(grad_logabsdet, self);
     }
 
-    auto zero_signdet_indices = at::where(signdet == 0);
+    auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0));
+    c10::optional<Tensor> first_zero_signdet_index = zero_signdet_indices[0];
 
-    if (zero_signdet_indices[0].size(0) == logabsdet.numel()) {  // all log determinants are -inf (singular)
+    if (first_zero_signdet_index->size(0) == logabsdet.numel()) {  // all log determinants are -inf (singular)
       return singular_case_backward(grad_logabsdet, self);
     }
```
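All three backward functions follow the same pattern: `at::where` still returns a plain `std::vector<Tensor>`, so its result is wrapped with `at::native::toListOfOptionalTensors` before being passed where a `List<c10::optional<Tensor>>` is now expected. A hedged, standalone sketch of that pattern outside the derivative formulas (the function name and logic are invented for illustration, and it assumes the helper lives in `ATen/native/IndexingUtils.h` as the include added above suggests):

```cpp
#include <ATen/ATen.h>
#include <ATen/native/IndexingUtils.h>

// Zero out the negative entries of `t` via index_put, going through the
// optional-tensor index list that index / index_put_ now take.
at::Tensor zero_out_negatives(const at::Tensor& t) {
  // at::where(cond) yields one index tensor per dimension (std::vector<Tensor>).
  auto indices = at::native::toListOfOptionalTensors(at::where(t < 0));
  return at::index_put(t, indices, at::zeros({}, t.options()));
}
```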
```diff
@@ -2873,8 +2880,8 @@ Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indices, int64_t padding_idx) {
   return gg_weight.view(size);
 }
 
-Tensor index_backward(Tensor zeros_like_self, TensorList indices, const Tensor& grad) {
+Tensor index_backward(Tensor zeros_like_self, const torch::List<c10::optional<Tensor>>& indices, const Tensor& grad) {
   return at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
 }
 
 Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) {
```
```diff
@@ -124,7 +124,7 @@ at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor&
 at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self);
 at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices, at::IntArrayRef values_shape);
 at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx);
-at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad);
+at::Tensor index_backward(at::Tensor zeros_like_self, const torch::List<c10::optional<Tensor>>& indices, const at::Tensor& grad);
 at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity);
 
 Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,
```
```diff
@@ -66,10 +66,6 @@ Tensor unpack_opt(const Tensor & t, const char * name, int pos) {
   return unpack(t, name, pos);
 }
 
-c10::optional<Tensor> unpack_opt(const c10::optional<Tensor> & t, const char * name, int pos) {
-  return t;
-}
-
 std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos) {
   std::vector<at::Tensor> ret(tl.size());
   for (size_t i = 0; i < tl.size(); ++i) {
```
|
||||||
// instead of us having to unwrap it to Tensor _gradient here.
|
// instead of us having to unwrap it to Tensor _gradient here.
|
||||||
Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
|
Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
|
||||||
std::vector<torch::autograd::Variable> input_vars(inputs.begin(), inputs.end());
|
std::vector<torch::autograd::Variable> input_vars(inputs.begin(), inputs.end());
|
||||||
torch::autograd::backward({self}, {_gradient}, std::move(keep_graph), create_graph, input_vars);
|
torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars);
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_data(Tensor & self, const Tensor & new_data) {
|
void set_data(Tensor & self, const Tensor & new_data) {
|
||||||
|
|
```diff
@@ -230,7 +226,6 @@ Tensor _fw_primal(const Tensor & self, int64_t level) {
 
 // We don't have an outplace copy, so this can't be generated automatically
 Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
-  jit::Value* output = nullptr;
   // TODO: once copy is exposed in Declarations.yaml we may be able to bind
   // it automatically
   auto& self_ = unpack(self, "self", 0);
```
```diff
@@ -282,7 +277,7 @@ Tensor& resize_(
   }
   {
     at::AutoNonVariableTypeMode non_var_type_mode(true);
-    self_.resize_(size, std::move(optional_memory_format));
+    self_.resize_(size, optional_memory_format);
   }
 
   if (self.fw_grad(/* level */ 0).defined()) {
@@ -303,7 +298,7 @@ Tensor& resize_as_(
   }
   {
     at::AutoNonVariableTypeMode non_var_type_mode(true);
-    at::resize_as_(self_, the_template_, std::move(optional_memory_format));
+    at::resize_as_(self_, the_template_, optional_memory_format);
   }
 
   // Handle fw grad
```
```diff
@@ -266,12 +266,31 @@ inline void check_no_requires_grad(TensorList tensors, const char* name) {
   }
 }
 
+inline void check_no_requires_grad(const c10::List<c10::optional<Tensor>>& tensors, const char* name) {
+  for (c10::optional<Tensor> tensor : tensors) {
+    if (tensor.has_value()) {
+      check_no_requires_grad(*tensor, name);
+    }
+  }
+}
+
 // Assumed that saved tensor lists are never inplace outputs
 inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
   return fmap(tensors, [](const Tensor& tensor) -> SavedVariable {
       return SavedVariable{tensor, false /* is output */}; });
 }
 
+// Assumed that saved tensor lists are never inplace outputs
+inline std::vector<SavedVariable> make_saved_variable_list(const c10::List<c10::optional<at::Tensor>>& tensors) {
+  return fmap(tensors, [](const c10::optional<Tensor>& tensor) -> SavedVariable {
+    if (tensor.has_value()) {
+      return SavedVariable{*tensor, false /* is output */};
+    } else {
+      return SavedVariable{Tensor(), false /* is output */};
+    }
+  });
+}
+
 inline std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
   std::vector<std::vector<int64_t>> args_sizes(tensors.size());
   for (size_t i = 0; i < tensors.size(); ++i) {
```
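These helpers (and the tracer overload later in this diff) iterate the optional-tensor list with `fmap`. A hedged, standalone sketch of that iteration pattern with made-up data, showing how `c10::nullopt` and present elements coexist in one `c10::List`:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/List.h>
#include <ATen/core/functional.h>  // c10::fmap

int main() {
  c10::List<c10::optional<at::Tensor>> indices;
  indices.push_back(c10::nullopt);      // "skip" dimension 0
  indices.push_back(at::arange(2));     // advanced-index dimension 1

  // fmap visits each element (as the helpers above do) and collects the
  // results of the lambda into a std::vector.
  std::vector<bool> present =
      c10::fmap(indices, [](const c10::optional<at::Tensor>& t) {
        return t.has_value();
      });

  TORCH_CHECK(!present[0] && present[1]);
  return 0;
}
```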
```diff
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <ATen/core/jit_type.h>
 #include <ATen/core/stack.h>
 
 namespace torch {
```
```diff
@@ -103,6 +103,9 @@ void TracingState::delValue(const IValue& var) {
 Value* getValueTrace(const IValue& var) {
   return getTracingState()->getValue(var);
 }
+Value* getOptTensorValueTrace(const c10::optional<at::Tensor>& var) {
+  return getValueTrace(IValue(var));
+}
 Value* TracingState::getValue(const IValue& var) {
   // allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...]
   // arguments
@@ -686,6 +689,16 @@ void addInputs(
   }
   n->addInput(list_node->output());
 }
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const List<c10::optional<at::Tensor>>& value) {
+  Graph* g = n->owningGraph();
+  Node* list_node = nullptr;
+  list_node = g->insertNode(g->createList(
+      OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace)));
+  n->addInput(list_node->output());
+}
 
 void addInputs(
     Node* n,
```
```diff
@@ -255,6 +255,10 @@ TORCH_API void addInputs(
     const char* name,
     ArrayRef<at::Tensor> value,
     bool allow_undefined = false);
+TORCH_API void addInputs(
+    Node* n,
+    const char* name,
+    const List<c10::optional<at::Tensor>>& value);
 TORCH_API void addInputs(
     Node* n,
     const char* name,
```
```diff
@@ -1,5 +1,6 @@
 #pragma once
 //#include <ATen/core/function_schema.h>
+#include <ATen/core/jit_type.h>
 #include <torch/csrc/jit/mobile/function.h>
 #include <torch/csrc/jit/mobile/method.h>
```
```diff
@@ -5,6 +5,7 @@
 
 #include <ATen/ThreadLocalState.h>
 #include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
 #include <torch/csrc/jit/frontend/source_range.h>
```
```diff
@@ -908,7 +908,7 @@ RegisterOperators reg(
          TORCH_SELECTIVE_SCHEMA(
              "aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
          [](Stack* stack) {
-           auto indices = pop(stack).toTensorVector();
+           auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
            auto self = pop(stack).toTensor();
            auto result = at::index(self, indices);
            push(stack, std::move(result));
@@ -921,7 +921,7 @@ RegisterOperators reg(
            auto unsafe = pop(stack).toBool();
            auto accumulate = pop(stack).toBool();
            auto values = pop(stack).toTensor();
-           auto indices = pop(stack).toTensorVector();
+           auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
            auto self = pop(stack).toTensor();
            auto result =
                at::_index_put_impl_(self, indices, values, accumulate, unsafe);
@@ -934,7 +934,7 @@ RegisterOperators reg(
          [](Stack* stack) {
            auto accumulate = pop(stack).toBool();
            auto values = pop(stack).toTensor();
-           auto indices = pop(stack).toTensorVector();
+           auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
            auto self = pop(stack).toTensor();
            auto result = at::index_put_(self, indices, values, accumulate);
            push(stack, std::move(result));
@@ -946,7 +946,7 @@ RegisterOperators reg(
          [](Stack* stack) {
            auto accumulate = pop(stack).toBool();
            auto values = pop(stack).toTensor();
-           auto indices = pop(stack).toTensorVector();
+           auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
            auto self = pop(stack).toTensor();
            auto result = at::index_put_(self, indices, values, accumulate);
            push(stack, std::move(result));
```
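These kernels still receive plain `Tensor[]` values from TorchScript, so they lean on `IValue::to` accepting a tensor list where a list of optional tensors is requested. A minimal standalone sketch of that conversion (the data is made up; this mirrors how older list values keep working with the new kernels):

```cpp
#include <ATen/ATen.h>
#include <ATen/core/List.h>
#include <ATen/core/ivalue.h>

int main() {
  // An IValue holding a plain Tensor[] ...
  c10::List<at::Tensor> plain({at::ones({2}), at::zeros({2})});
  c10::IValue iv(plain);

  // ... can be read back as a List<optional<Tensor>>, which is what the
  // index / index_put_ kernels above expect.
  auto as_optional = iv.to<c10::List<c10::optional<at::Tensor>>>();
  TORCH_CHECK(as_optional.size() == 2);
  TORCH_CHECK(as_optional.get(0).has_value());
  return 0;
}
```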
```diff
@@ -2,6 +2,7 @@
 #include <ATen/core/List.h>
 #include <ATen/core/functional.h>
 #include <ATen/core/ivalue.h>
+#include <ATen/core/jit_type.h>
 #include <ATen/core/stack.h>
 
 namespace torch {
```
|
@ -24,6 +24,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
|
||||||
{"double", ParameterType::DOUBLE},
|
{"double", ParameterType::DOUBLE},
|
||||||
{"complex", ParameterType::COMPLEX},
|
{"complex", ParameterType::COMPLEX},
|
||||||
{"TensorList", ParameterType::TENSOR_LIST},
|
{"TensorList", ParameterType::TENSOR_LIST},
|
||||||
|
{"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
|
||||||
{"IntArrayRef", ParameterType::INT_LIST},
|
{"IntArrayRef", ParameterType::INT_LIST},
|
||||||
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
|
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
|
||||||
{"Generator", ParameterType::GENERATOR},
|
{"Generator", ParameterType::GENERATOR},
|
||||||
|
|
```diff
@@ -390,7 +391,7 @@ bool is_float_or_complex_list(PyObject* obj) {
 }
 
   auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
   if (size > 0) {
     PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
     if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
       return false;
```
```diff
@@ -160,6 +160,7 @@ struct PythonArgs {
   inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
   inline std::vector<at::Scalar> scalarlist(int i);
   inline std::vector<at::Tensor> tensorlist(int i);
+  inline torch::List<c10::optional<at::Tensor>> list_of_optional_tensors(int i);
   template<int N>
   inline std::array<at::Tensor, N> tensorlist_n(int i);
   inline std::vector<int64_t> intlist(int i);
```
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
inline torch::List<c10::optional<at::Tensor>> PythonArgs::list_of_optional_tensors(int i) {
|
||||||
|
if (!args[i]) return torch::List<c10::optional<at::Tensor>>();
|
||||||
|
auto tuple = six::isTuple(args[i]);
|
||||||
|
THPObjectPtr arg = six::maybeAsTuple(args[i]);
|
||||||
|
auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
|
||||||
|
torch::List<c10::optional<at::Tensor>> res;
|
||||||
|
res.reserve(size);
|
||||||
|
for (int idx = 0; idx < size; idx++) {
|
||||||
|
PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx);
|
||||||
|
// This is checked by the argument parser so it's safe to cast without checking
|
||||||
|
// if this is a tensor first
|
||||||
|
res.push_back(reinterpret_cast<THPVariable*>(obj)->cdata);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
template<int N>
|
template<int N>
|
||||||
inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
|
inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
|
||||||
auto res = std::array<at::Tensor, N>();
|
auto res = std::array<at::Tensor, N>();
|
||||||
|
|
|
||||||