Making ops c10-full: list of optional tensors (#49138)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/49138

For details, see: https://fb.quip.com/QRtJAin66lPN

We need to model optional types explicitly, mostly for schema inference. So we cannot pass a `Tensor?[]` as `ArrayRef<Tensor>`; instead, we need to pass it as a list of optional tensors. This PR changes the C++ representation to `torch::List<c10::optional<Tensor>>`. It also makes the ops c10-full that were previously blocked by this.
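
As a rough illustration of the new argument type, here is a minimal sketch (assuming a post-PR build; the function and tensor names are hypothetical, while the argument type of `at::index` follows this PR's diff):

```cpp
#include <torch/torch.h>

// Sketch: building a Tensor?[] argument as torch::List<c10::optional<Tensor>>.
void sketch() {
  torch::Tensor x = torch::ones({4, 4, 4});
  torch::Tensor idx = torch::tensor({0, 2});

  // One entry per indexed dimension; c10::nullopt marks a dimension that is
  // not indexed (previously represented by an undefined Tensor in a TensorList).
  c10::List<c10::optional<at::Tensor>> indices;
  indices.reserve(2);
  indices.push_back(idx);          // Tensor implicitly converts to optional<Tensor>
  indices.push_back(c10::nullopt); // leave the second dimension untouched

  // After this PR, at::index takes a const c10::List<c10::optional<Tensor>>&.
  torch::Tensor out = at::index(x, indices);  // shape (2, 4, 4)
}
```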

## Backwards Compatibility

- This should not break the Python API because the representation in Python stays the same; python_arg_parser just transforms the Python list into a `List<optional<Tensor>>` instead of a `List<Tensor>`.
- This should not break serialized models because there's some logic that allows loading a serialized `List<Tensor>` as `List<optional<Tensor>>`, see https://github.com/pytorch/pytorch/pull/49138/files#diff-9315f5dd045f47114c677174dcaa2f982721233eee1aa19068a42ff3ef775315R57
- This will break backwards compatibility for the C++ API. There is no implicit conversion from `ArrayRef<Tensor>` (the old argument type) to `List<optional<Tensor>>`. One common call pattern, `tensor.index({indices_tensor})` where `indices_tensor` is another `Tensor`, keeps working because the `{}` initializer_list constructor of `List<optional<Tensor>>` accepts `Tensor` elements and implicitly converts them to `optional<Tensor>`. Another common pattern, `tensor.index(indices_tensor)`, previously relied on the `Tensor` being implicitly converted to `ArrayRef<Tensor>`; converting `Tensor -> optional<Tensor> -> List<optional<Tensor>>` would require two implicit conversions, and C++ does not chain implicit conversions. Those call sites have to be rewritten to `tensor.index({indices_tensor})`, as shown in the sketch after this list.
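
A minimal sketch of the C++ call-site migration described above (assuming a post-PR libtorch build; `x` and `indices_tensor` are hypothetical names):

```cpp
#include <torch/torch.h>

// Sketch of the call-site change for the C++ API break described above.
void migrate() {
  torch::Tensor x = torch::ones({4, 4, 4});
  torch::Tensor indices_tensor = torch::tensor({0, 2});

  // Still compiles: the {} initializer_list constructor of List<optional<Tensor>>
  // accepts Tensor elements and converts each one to optional<Tensor>.
  torch::Tensor a = x.index({indices_tensor});

  // No longer compiles: Tensor -> optional<Tensor> -> List<optional<Tensor>>
  // would require chaining two implicit conversions.
  // torch::Tensor b = x.index(indices_tensor);

  // Rewrite such call sites by wrapping the tensor in braces:
  torch::Tensor b = x.index({indices_tensor});
}
```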

ghstack-source-id: 119269131

Test Plan:
## Benchmarks (C++ instruction counts):
### Forward
#### Script
```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4});
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```
#### Results
| Op call | before | after | delta | delta % |
|---------|--------|-------|-------|---------|
|x[0] = 1                                                                |11566015 |11566015|0      |0.00% |
|x.index({0})                                                            |6807019  |6801019 |-6000  |-0.09%|
|x.index({0, 0})                                                         |13529019 |13557019|28000  |0.21% |
|x.index({0, 0, 0})                                                      |10677004 |10692004|15000  |0.14% |
|x.index({"..."})                                                        |5512015  |5506015 |-6000  |-0.11%|
|x.index({Slice(None, None, None)})                                      |6866016  |6936016 |70000  |1.02% |
|x.index({None})                                                         |8554015  |8548015 |-6000  |-0.07%|
|x.index({false})                                                        |22400000 |22744000|344000 |1.54% |
|x.index({true})                                                         |27624088 |27264393|-359695|-1.30%|
|x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})})|123472000|123463306|-8694|-0.01%|

### Autograd
#### Script
```py
from torch.utils.benchmark import Timer

counts = Timer(
    stmt="""
        auto t = {{op call to measure}};
    """,
    setup="""
        using namespace torch::indexing;
        auto x = torch::ones({4, 4, 4}, torch::requires_grad());
    """,
    language="cpp",
).collect_callgrind(number=1_000)
print(counts)
```
Note: the script measures the **forward** path of an op call with autograd enabled (i.e. calls into VariableType). It does not measure the backward path.

#### Results
| Op call | before | after | delta | delta % |
|---------|--------|-------|-------|---------|
|x.index({0})                                                            |14839019|14833019|-6000| 0.00% |
|x.index({0, 0})                                                         |28342019|28370019|28000| 0.00% |
|x.index({0, 0, 0})                                                      |24434004|24449004|15000| 0.00% |
|x.index({"..."})                                                       |12773015|12767015|-6000| 0.00% |
|x.index({Slice(None, None, None)})                                      |14837016|14907016|70000| 0.47% |
|x.index({None})                                                        |15926015|15920015|-6000| 0.00% |
|x.index({false})                                                        |36958000|37477000|519000| 1.40% |
|x.index({true})                                                         |41971408|42426094|454686| 1.08% |
|x.index({"...", 0, true, Slice(1, None, 2), torch::tensor({1, 2})}) |168184392|164545682|-3638710| -2.16% |

Reviewed By: bhosmer

Differential Revision: D25454632

fbshipit-source-id: 28ab0cffbbdbdff1c40b4130ca62ee72f981b76d
Sebastian Messmer 2021-01-04 05:01:02 -08:00 committed by Facebook GitHub Bot
parent e44b2b72bd
commit c7e9abb66a
45 changed files with 511 additions and 306 deletions

View File

@ -31,3 +31,4 @@
#include <c10/util/Exception.h>
#include <ATen/core/UnsafeFromTH.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>

View File

@ -1,4 +1,5 @@
#include <ATen/Config.h>
#include <ATen/core/jit_type.h>
#if AT_PARALLEL_OPENMP
#include <ATen/Parallel.h>

View File

@ -10,6 +10,8 @@
// There is some back story, see https://github.com/pytorch/pytorch/issues/48684
#include <ATen/NativeFunctions.h>
#include <ATen/core/List.h>
namespace at {
namespace indexing {
@ -261,14 +263,15 @@ static inline void recordTensorIndex(const Tensor& tensor, std::vector<Tensor>&
(*dim_ptr)++;
};
static inline std::vector<Tensor> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
std::vector<Tensor> converted_inds(indices.size());
static inline c10::List<c10::optional<Tensor>> typeConvertIndices(const Tensor& self, std::vector<Tensor>&& indices) {
c10::List<c10::optional<Tensor>> converted_inds;
converted_inds.reserve(indices.size());
for (size_t i = 0; i < indices.size(); ++i) {
const auto &ind = indices[i];
if (ind.defined()) {
converted_inds[i] = ind.to(ind.options().device(self.device()));
converted_inds.push_back(ind.to(ind.options().device(self.device())));
} else {
converted_inds[i] = std::move(indices[i]);
converted_inds.push_back(std::move(indices[i]));
}
}
return converted_inds;

View File

@ -406,7 +406,7 @@ TORCH_LIBRARY_IMPL(aten, Autocast, m) {
KERNEL(ADD_NS(cross), "cross", Tensor (const Tensor &, const Tensor &, c10::optional<int64_t>), promote)
KERNEL(ADD_NS(dot), "dot", Tensor (const Tensor &, const Tensor &), promote)
KERNEL(ADD_NS(equal), "equal", bool (const Tensor &, const Tensor &), promote)
KERNEL_UNBOXED_ONLY(ADD_NS(index_put), "index_put", Tensor (const Tensor &, TensorList, const Tensor &, bool), promote)
KERNEL(ADD_NS(index_put), "index_put", Tensor (const Tensor &, const torch::List<c10::optional<Tensor>>&, const Tensor &, bool), promote)
KERNEL(ADD_NS(stack), "stack", Tensor (TensorList, int64_t), promote)
KERNEL(ADD_NS(tensordot), "tensordot", Tensor (const Tensor &, const Tensor &, IntArrayRef, IntArrayRef), promote)

View File

@ -243,7 +243,7 @@ public:
* Example:
* List<int> a({2, 3, 4});
*/
explicit List(std::initializer_list<T> initial_values);
List(std::initializer_list<T> initial_values);
explicit List(ArrayRef<T> initial_values);
/**

View File

@ -1,7 +1,7 @@
#pragma once
#include <ATen/core/jit_type_base.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
namespace c10 {
@ -50,7 +50,17 @@ List<T>::List(TypePtr elementType)
namespace impl {
template<class T>
List<T> toTypedList(impl::GenericList list) {
TORCH_INTERNAL_ASSERT(*getTypePtr<T>() == *list.impl_->elementType, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
// If there's other instances of the list (i.e. list.use_count() > 1), then we have to be invariant
// because upcasting would allow people to add types into the new list that would break the old list.
// However, if there aren't any other instances of this list (i.e. list.use_count() == 1), then we can
// allow upcasting. This can be a perf improvement since we can cast List<T> to List<optional<T>>
// without having to copy it. This is also used to provide backwards compatibility with some old models
// that serialized the index arguments to aten::index, aten::index_put, aten::index_put_ and aten::index_put_impl_
// as List<Tensor> before we changed that argument to be List<optional<Tensor>>. When deserializing, we
// have list.use_count() == 1 and can deserialize the List<Tensor> directly as List<optional<Tensor>>.
TORCH_CHECK(*list.impl_->elementType == *getTypePtr<T>()
|| (list.use_count() == 1 && list.impl_->elementType->isSubtypeOf(getTypePtr<T>()))
, "Tried to cast a List<", toString(list.impl_->elementType), "> to a List<", toString(getTypePtr<T>()), ">. Types mismatch.");
return List<T>(std::move(list.impl_));
}
@ -312,3 +322,5 @@ void List<T>::unsafeSetElementType(TypePtr t) {
impl_->elementType = std::move(t);
}
}
#include <ATen/core/jit_type.h>

View File

@ -6,6 +6,7 @@
#include <utility>
#include <c10/util/ArrayRef.h>
#include <ATen/core/List.h>
namespace at {
@ -56,6 +57,15 @@ struct IterArgs {
}
}
template <typename T>
void operator()(const torch::List<T>& args) {
for (const auto& arg : args) {
self()(arg);
if (self().short_circuit())
return;
}
}
// NB: we need to specify std::vector manually as C++ won't
// do an implicit conversion to make a template deduction go through.
template <typename T>

View File

@ -1,10 +1,11 @@
#pragma once
#include <ATen/core/jit_type_base.h>
#include <ATen/core/TensorBody.h>
#include <ATen/core/functional.h>
#include <ATen/core/interned_strings.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/qualified_name.h>
#include <ATen/core/ivalue.h>
#include <c10/util/TypeList.h>
#include <c10/util/Optional.h>
@ -17,197 +18,17 @@ struct ClassType;
namespace torch {
namespace jit {
struct CompilationUnit;
struct Function;
} // namespace jit
} // namespace torch
namespace c10 {
struct IValue;
struct FunctionSchema;
struct NamedType;
using OptNameList = c10::optional<std::vector<std::string>>;
#define C10_FORALL_TYPES(_) \
_(AnyType) \
_(EnumType) \
_(AnyEnumType) \
_(TensorType) \
_(StorageType) \
_(TupleType) \
_(ListType) \
_(DictType) \
_(NumberType) \
_(FloatType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
_(NoneType) \
_(StringType) \
_(GeneratorType) \
_(QuantizerType) \
_(BoolType) \
_(OptionalType) \
_(VarType) \
_(DeviceObjType) \
_(StreamObjType) \
_(FunctionType) \
_(ClassType) \
_(PyObjectType) \
_(CapsuleType) \
_(InterfaceType) \
_(QSchemeType) \
_(LayoutType) \
_(ScalarTypeType) \
_(AnyListType) \
_(AnyTupleType) \
_(AnyClassType)
enum class TypeKind {
#define DEFINE_TYPE(T) T,
C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};
TORCH_API const char* typeKindToString(TypeKind kind);
struct Type;
using TypePtr = std::shared_ptr<Type>;
using ConstTypePtr = std::shared_ptr<const Type>;
// Use this to customize how a Type is printed using `annotation_str()`. If
// c10::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter =
std::function<c10::optional<std::string>(const ConstTypePtr&)>;
struct TORCH_API Type : std::enable_shared_from_this<Type> {
private:
TypeKind kind_;
protected:
Type(TypeKind kind) : kind_(kind) {}
virtual std::string annotation_str_impl(TypePrinter printer) const {
return str();
}
public:
virtual bool operator==(const Type& rhs) const = 0;
// subtyping relation. By default, we return true for the case
// when the type is exactly equal or if this <: T where rhs = Optional[T]
// if this returns false and the why_not stream is non-null, it contains
// additional details that describe why this is not a subtype of 'rhs'.
// This additional information should only contain details that are not obvious
// from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
// but not clear why `Foo <: InterfaceBar` might be false.
virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
virtual bool is_module() const;
bool isSubtypeOf(const TypePtr& rhs) const {
return isSubtypeOfExt(rhs, nullptr);
}
// How this type will appear in FunctionSchema declarations
virtual std::string str() const = 0;
// How this type will appear as if it were a type annotation in Python
// which is sometimes different than how it appears in declarations (e.g.
// int[] vs List[int])
//
// Takes a custom printer that users can pass in to customize the output of
// this method.
std::string annotation_str(TypePrinter printer) const {
if (printer) {
// the printer can return nullopt to fall through to the default impl
if (auto renamed = printer(shared_from_this())) {
return *renamed;
}
}
return annotation_str_impl(printer);
}
std::string annotation_str() const {
// Overload instead of define a default value for `printer` to help
// debuggers out.
return annotation_str(nullptr);
}
// Returns a human readable string that includes additional information like
// "type is inferred rather than explictly defined" to help construct more
// user-friendly messages.
virtual std::string repr_str() const {
return annotation_str();
}
TypeKind kind() const {
return kind_;
}
virtual bool requires_grad() const {
for (const auto& ct : containedTypes()) {
if (ct->requires_grad()) {
return true;
}
}
return false;
}
// Dynamically cast this object to the subclass indicated by the
// template variable, returning nullptr if the cast is invalid.
template <typename T>
std::shared_ptr<T> cast() {
if (T::Kind == kind()) {
return std::static_pointer_cast<T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<const T> cast() const {
if (T::Kind == kind()) {
return std::static_pointer_cast<const T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<T> expect() {
auto r = cast<T>();
AT_ASSERT(r);
return r;
}
template <typename T>
std::shared_ptr<const T> expect() const {
auto r = cast<const T>();
AT_ASSERT(r);
return r;
}
virtual ~Type() = default;
virtual bool hasFreeVariables() const {
return false;
}
// list of types this type contains, e.g. for a List then element type of a
// list for a tuple, the types of the tuple elements
virtual at::ArrayRef<TypePtr> containedTypes() const {
return {};
}
// create a new version of this type, replacing its contained types with
// contained_types
TypePtr withContained(std::vector<TypePtr> contained_types) {
auto current_contained = containedTypes();
AT_ASSERT(current_contained.size() == contained_types.size());
if (current_contained.equals(contained_types)) {
return shared_from_this();
}
return createWithContained(std::move(contained_types));
}
// per-type constructor, you only need to override this if the
// containedTypes() is not empty
virtual TypePtr createWithContained(
std::vector<TypePtr> contained_types) const {
AT_ERROR(
"type with contained types did not overload createWithContained: ",
str());
}
};
struct AnyType;
using AnyTypePtr = std::shared_ptr<AnyType>;
// Any is the top of the type hierarchy, all other types are subtypes

View File

@ -0,0 +1,195 @@
#pragma once
#include <functional>
#include <string>
#include <memory>
#include <c10/macros/Macros.h>
#include <c10/util/Optional.h>
#include <c10/util/Exception.h>
#include <c10/util/ArrayRef.h>
namespace c10 {
#define C10_FORALL_TYPES(_) \
_(AnyType) \
_(EnumType) \
_(AnyEnumType) \
_(TensorType) \
_(StorageType) \
_(TupleType) \
_(ListType) \
_(DictType) \
_(NumberType) \
_(FloatType) \
_(FutureType) \
_(RRefType) \
_(IntType) \
_(NoneType) \
_(StringType) \
_(GeneratorType) \
_(QuantizerType) \
_(BoolType) \
_(OptionalType) \
_(VarType) \
_(DeviceObjType) \
_(StreamObjType) \
_(FunctionType) \
_(ClassType) \
_(PyObjectType) \
_(CapsuleType) \
_(InterfaceType) \
_(QSchemeType) \
_(LayoutType) \
_(ScalarTypeType) \
_(AnyListType) \
_(AnyTupleType) \
_(AnyClassType)
enum class TypeKind {
#define DEFINE_TYPE(T) T,
C10_FORALL_TYPES(DEFINE_TYPE)
#undef DEFINE_TYPE
};
TORCH_API const char* typeKindToString(TypeKind kind);
struct Type;
using TypePtr = std::shared_ptr<Type>;
using ConstTypePtr = std::shared_ptr<const Type>;
// Use this to customize how a Type is printed using `annotation_str()`. If
// c10::nullopt is returned, `annotation_str()` falls through to its default
// implementation.
using TypePrinter =
std::function<c10::optional<std::string>(const ConstTypePtr&)>;
struct TORCH_API Type : std::enable_shared_from_this<Type> {
private:
TypeKind kind_;
protected:
Type(TypeKind kind) : kind_(kind) {}
virtual std::string annotation_str_impl(TypePrinter printer) const {
return str();
}
public:
virtual bool operator==(const Type& rhs) const = 0;
// subtyping relation. By default, we return true for the case
// when the type is exactly equal or if this <: T where rhs = Optional[T]
// if this returns false and the why_not stream is non-null, it contains
// additional details that describe why this is not a subtype of 'rhs'.
// This additional information should only contain details that are not obvious
// from the annotation_str() that describes the type. For instance it is clear that `int <: str` is false
// but not clear why `Foo <: InterfaceBar` might be false.
virtual bool isSubtypeOfExt(const TypePtr& rhs, std::ostream* why_not) const;
virtual bool is_module() const;
bool isSubtypeOf(const TypePtr& rhs) const {
return isSubtypeOfExt(rhs, nullptr);
}
// How this type will appear in FunctionSchema declarations
virtual std::string str() const = 0;
// How this type will appear as if it were a type annotation in Python
// which is sometimes different than how it appears in declarations (e.g.
// int[] vs List[int])
//
// Takes a custom printer that users can pass in to customize the output of
// this method.
std::string annotation_str(TypePrinter printer) const {
if (printer) {
// the printer can return nullopt to fall through to the default impl
if (auto renamed = printer(shared_from_this())) {
return *renamed;
}
}
return annotation_str_impl(printer);
}
std::string annotation_str() const {
// Overload instead of define a default value for `printer` to help
// debuggers out.
return annotation_str(nullptr);
}
// Returns a human readable string that includes additional information like
// "type is inferred rather than explictly defined" to help construct more
// user-friendly messages.
virtual std::string repr_str() const {
return annotation_str();
}
TypeKind kind() const {
return kind_;
}
virtual bool requires_grad() const {
for (const auto& ct : containedTypes()) {
if (ct->requires_grad()) {
return true;
}
}
return false;
}
// Dynamically cast this object to the subclass indicated by the
// template variable, returning nullptr if the cast is invalid.
template <typename T>
std::shared_ptr<T> cast() {
if (T::Kind == kind()) {
return std::static_pointer_cast<T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<const T> cast() const {
if (T::Kind == kind()) {
return std::static_pointer_cast<const T>(shared_from_this());
}
return nullptr;
}
template <typename T>
std::shared_ptr<T> expect() {
auto r = cast<T>();
AT_ASSERT(r);
return r;
}
template <typename T>
std::shared_ptr<const T> expect() const {
auto r = cast<const T>();
AT_ASSERT(r);
return r;
}
virtual ~Type() = default;
virtual bool hasFreeVariables() const {
return false;
}
// list of types this type contains, e.g. for a List then element type of a
// list for a tuple, the types of the tuple elements
virtual at::ArrayRef<TypePtr> containedTypes() const {
return {};
}
// create a new version of this type, replacing its contained types with
// contained_types
TypePtr withContained(std::vector<TypePtr> contained_types) {
auto current_contained = containedTypes();
AT_ASSERT(current_contained.size() == contained_types.size());
if (current_contained.equals(contained_types)) {
return shared_from_this();
}
return createWithContained(std::move(contained_types));
}
// per-type constructor, you only need to override this if the
// containedTypes() is not empty
virtual TypePtr createWithContained(
std::vector<TypePtr> contained_types) const {
AT_ERROR(
"type with contained types did not overload createWithContained: ",
str());
}
};
}

View File

@ -68,7 +68,7 @@ Tensor embedding_sparse_backward(
Tensor indices = indices_;
Tensor grad = grad_;
if (padding_idx != -1) {
auto c = indices != padding_idx;
torch::List<c10::optional<Tensor>> c({indices != padding_idx});
indices = indices.index(c);
grad = grad.index(c);
}

View File

@ -1,6 +1,7 @@
#pragma once
#include <ATen/ExpandUtils.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/core/List.h>
#include <limits>
@ -15,40 +16,45 @@ static void invalid_mask(const Tensor & self, int64_t idx, const Tensor & mask,
}
static std::vector<Tensor> expandTensors(const Tensor & self, TensorList indices) {
static std::vector<Tensor> expandTensors(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
// If indices come in as ByteTensor or BoolTensor (masks), expand them into the equivalent indexing by LongTensors
std::vector<Tensor> result;
for (const auto & index : indices) {
if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
if (index.scalar_type() == kByte) {
TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
" please use a dtype torch.bool instead.");
}
// The sizes of the ByteTensor mask or bool tensor must match the sizes of the
// corresponding dimensions in self
for (int64_t j = 0; j < index.dim(); j++) {
int64_t srcIdx = result.size() + j;
if (index.size(j) != self.size(srcIdx)) {
invalid_mask(self, srcIdx, index, j);
}
}
// Replace with nonzeros
auto nonzero = index.nonzero();
for (int64_t j = 0; j < index.dim(); j++) {
result.emplace_back(nonzero.select(1, j));
}
for (c10::optional<Tensor> index_opt : indices) {
if (!index_opt.has_value()) {
result.emplace_back();
} else {
result.emplace_back(index);
Tensor index = std::move(*index_opt);
if (index.scalar_type() == kByte || index.scalar_type() == kBool) {
if (index.scalar_type() == kByte) {
TORCH_WARN("indexing with dtype torch.uint8 is now deprecated," \
" please use a dtype torch.bool instead.");
}
// The sizes of the ByteTensor mask or bool tensor must match the sizes of the
// corresponding dimensions in self
for (int64_t j = 0; j < index.dim(); j++) {
int64_t srcIdx = result.size() + j;
if (index.size(j) != self.size(srcIdx)) {
invalid_mask(self, srcIdx, index, j);
}
}
// Replace with nonzeros
auto nonzero = index.nonzero();
for (int64_t j = 0; j < index.dim(); j++) {
result.emplace_back(nonzero.select(1, j));
}
} else {
result.emplace_back(std::move(index));
}
}
}
return result;
}
static void checkIndexTensorTypes(TensorList indices) {
for (auto& tensor : indices) {
if (tensor.defined()) {
auto scalarType = tensor.scalar_type();
static void checkIndexTensorTypes(const torch::List<c10::optional<Tensor>>& indices) {
for (c10::optional<Tensor> tensor : indices) {
if (tensor.has_value() && tensor->defined()) {
auto scalarType = tensor->scalar_type();
if (scalarType != kLong && scalarType != kByte && scalarType != kBool) {
TORCH_CHECK_INDEX(false, "tensors used as indices must be long, byte or bool tensors");
}
@ -56,6 +62,15 @@ static void checkIndexTensorTypes(TensorList indices) {
}
}
inline torch::List<c10::optional<Tensor>> toListOfOptionalTensors(ArrayRef<Tensor> list) {
torch::List<c10::optional<Tensor>> result;
result.reserve(list.size());
for (const Tensor& a : list) {
result.push_back(a);
}
return result;
}
static bool hasContiguousSubspace(TensorList tl) {
// true if all the non-null tensors are adjacent
auto isDefined = [](const Tensor & tensor){ return tensor.defined(); };

View File

@ -8,6 +8,7 @@
#include <ATen/native/Resize.h>
#include <ATen/native/TensorIterator.h>
#include <ATen/native/LinearAlgebra.h>
#include <ATen/native/IndexingUtils.h>
#include <ATen/TensorUtils.h>
#include <ATen/Parallel.h>
#include <ATen/LegacyTHFunctionsCPU.h>
@ -73,7 +74,8 @@ Tensor logdet(const Tensor& self) {
// U is singular when U(i, i) = 0 for some i in [1, self.size(-1)].
Tensor logdet_vals = diag_U.abs_().log_().sum(-1);
if (self.dim() > 2) {
logdet_vals.index_put_((det_sign < 0).nonzero_numpy(), at::full({}, NAN, self.options()));
auto indices = toListOfOptionalTensors((det_sign < 0).nonzero_numpy());
logdet_vals.index_put_(std::move(indices), at::full({}, NAN, self.options()));
} else if (det_sign.item<double>() < 0) {
logdet_vals.fill_(NAN);
}

View File

@ -206,7 +206,7 @@ AdvancedIndex::AdvancedIndex(const Tensor& src, TensorList indices_list)
}
}
static AdvancedIndex make_info(Tensor self, TensorList orig) {
static AdvancedIndex make_info(Tensor self, const torch::List<c10::optional<at::Tensor>>& orig) {
checkIndexTensorTypes(orig);
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
auto indices = expandTensors(self, orig);
@ -281,7 +281,7 @@ static TensorIterator make_index_out_iterator(const AdvancedIndex& info, Tensor&
return config.build();
}
Tensor index(const Tensor & self, TensorList indices) {
Tensor index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
auto info = make_info(self, indices);
@ -290,7 +290,7 @@ Tensor index(const Tensor & self, TensorList indices) {
return iter.output();
}
Tensor quantized_index(const Tensor & self, TensorList indices) {
Tensor quantized_index(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
TORCH_INTERNAL_ASSERT(
self.qscheme() == c10::kPerTensorAffine ||
self.qscheme() == c10::kPerTensorSymmetric,
@ -311,12 +311,14 @@ Tensor quantized_index(const Tensor & self, TensorList indices) {
res, self.q_scale(), self.q_zero_point(), self.scalar_type());
}
Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
Tensor& index_out(Tensor& result, const Tensor & self, const torch::List<c10::optional<Tensor>>& indices) {
TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
at::assert_no_internal_overlap(result);
at::assert_no_overlap(result, self);
for (auto& index: indices) {
at::assert_no_overlap(result, index);
for (const c10::optional<Tensor>& index: indices) {
if (index.has_value()) {
at::assert_no_overlap(result, *index);
}
}
auto info = make_info(self, indices);
@ -325,11 +327,11 @@ Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices) {
return result;
}
Tensor index_put(const Tensor & self, TensorList indices, const Tensor & value, bool accumulate) {
Tensor index_put(const Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, bool accumulate) {
return self.clone(at::MemoryFormat::Preserve).index_put_(indices, value, accumulate);
}
Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate, const bool unsafe) {
Tensor & _index_put_impl_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate, const bool unsafe) {
TORCH_CHECK_INDEX(indices.size() <= (size_t)self.dim(), "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
if (at::has_internal_overlap(self) == MemOverlap::YES) {
TORCH_WARN(
@ -338,8 +340,10 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu
"This also applies to advanced indexing e.g. tensor[indices] = tensor");
}
at::assert_no_overlap(self, value);
for (auto& index: indices) {
at::assert_no_overlap(self, index);
for (const c10::optional<Tensor>& index: indices) {
if (index.has_value()) {
at::assert_no_overlap(self, *index);
}
}
if (accumulate && self.device().type() == kCUDA) {
@ -356,7 +360,7 @@ Tensor & _index_put_impl_(Tensor & self, TensorList indices, const Tensor & valu
}
Tensor & index_put_(Tensor & self, TensorList indices, const Tensor & value, const bool accumulate) {
Tensor & index_put_(Tensor & self, const torch::List<c10::optional<Tensor>>& indices, const Tensor & value, const bool accumulate) {
return at::_index_put_impl_(self, indices, value, accumulate, /*unsafe=*/false);
}

View File

@ -15,7 +15,7 @@ enum class SCATTER_GATHER_OP: uint8_t {REDUCE_ADD, REDUCE_MULTIPLY};
using index_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides);
using index_put_fn = void(*)(TensorIterator &, IntArrayRef indexed_sizes, IntArrayRef indexed_strides, bool accumulate);
using index_put_accum_fn = void(*)(Tensor &, TensorList , const Tensor &, bool unsafe);
using index_put_accum_fn = void(*)(Tensor &, const c10::List<c10::optional<Tensor>> &, const Tensor &, bool unsafe);
using masked_fill_fn = void(*)(TensorIterator &, Scalar scalar);
using masked_select_fn = void(*)(TensorIterator &, int64_t orig_stride);
@ -42,6 +42,6 @@ DECLARE_DISPATCH(scatter_add_fn, scatter_add_stub);
DECLARE_DISPATCH(scatter_reduce_fn, scatter_reduce_stub);
DECLARE_DISPATCH(scatter_scalar_reduce_fn, scatter_scalar_reduce_stub);
TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, TensorList indices);
TORCH_API Tensor& index_out(Tensor& result, const Tensor & self, const c10::List<c10::optional<at::Tensor>>& indices);
}} // namespace at::native

View File

@ -190,7 +190,7 @@ static Tensor & masked_select_out_cuda_impl(Tensor & result, const Tensor & self
Tensor _mask = (mask.dim() == 0) ? mask.unsqueeze(0) : mask;
Tensor _self = (self.dim() == 0) ? self.unsqueeze(0) : self;
std::tie(_mask, _self) = expand_outplace(_mask, _self);
at::native::index_out(result, _self, _mask);
at::native::index_out(result, _self, c10::List<c10::optional<at::Tensor>>({_mask}));
return result;
}

View File

@ -160,7 +160,7 @@ computeLinearIndex(const Tensor & src, TensorList indices, bool check_range) {
}
static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, TensorList orig, bool check_range) {
static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t>> makeLinearIndex(Tensor self, const c10::List<c10::optional<at::Tensor>>& orig, bool check_range) {
checkIndexTensorTypes(orig);
// first expand BoolTensor (masks) or ByteTensor (masks) into 1 or more LongTensors
auto indices = expandTensors(self, orig);
@ -184,7 +184,7 @@ static std::tuple<Tensor, Tensor, int64_t, int64_t, int64_t, std::vector<int64_t
namespace {
void index_put_accum_kernel(Tensor & self, TensorList indices, const Tensor & value, bool unsafe) {
void index_put_accum_kernel(Tensor & self, const c10::List<c10::optional<Tensor>>& indices, const Tensor & value, bool unsafe) {
if (indices.size() > (size_t)self.dim()) {
TORCH_CHECK_INDEX(false, "too many indices for tensor of dimension ", self.dim(), " (got ", indices.size(), ")");
}

View File

@ -2226,6 +2226,7 @@
use_c10_dispatcher: full
- func: index.Tensor(Tensor self, Tensor?[] indices) -> Tensor
use_c10_dispatcher: full
variants: function, method
dispatch:
CPU, CUDA: index
@ -2254,6 +2255,7 @@
variants: function, method
- func: index_put_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor(a!)
use_c10_dispatcher: full
variants: function, method
dispatch:
DefaultBackend: index_put_
@ -2264,9 +2266,11 @@
# - Tensor & Tensor::index_put_(std::initializer_list<TensorIndex> indices, Scalar v)
- func: index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor
use_c10_dispatcher: full
variants: function, method
- func: _index_put_impl_(Tensor(a!) self, Tensor?[] indices, Tensor values, bool accumulate=False, bool unsafe=False) -> Tensor(a!)
use_c10_dispatcher: full
variants: function
dispatch:
CPU, CUDA: _index_put_impl_

View File

@ -7,6 +7,7 @@
#include <ATen/NativeFunctions.h>
#include <ATen/InitialTensorOptions.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <TH/THBlasUtils.h>
@ -14,7 +15,6 @@ namespace at { namespace native {
using namespace at::sparse;
/******************************************************************************
* access methods
******************************************************************************/
@ -328,7 +328,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
Tensor values;
if (self.dim() > 0) {
std::vector<Tensor> ix = indices.chunk(indices.size(0), 0);
auto ix = toListOfOptionalTensors(indices.chunk(indices.size(0), 0));
values = self.index(ix).squeeze(0).clone(at::MemoryFormat::Preserve);
} else {
AT_ASSERT(nz.sizes().equals({0, 1}));

View File

@ -28,6 +28,7 @@ class Tensor;
}
namespace c10{
struct TensorOptions;
template<class T> class List;
}
namespace at {
struct Generator;

View File

@ -6,13 +6,17 @@ namespace caffe2 {
namespace internal {
at::Tensor index_with_uint8_handling(
const at::Tensor& self,
at::TensorList indices) {
const torch::List<c10::optional<at::Tensor>>& indices) {
// Support BC only for the simplest case of mask indexing
if (indices.size() == 1 && indices[0].scalar_type() == at::kByte) {
TORCH_WARN(
"Indexing with uint8 mask tensor in ATenOp is now deprecated,"
" please use a bool mask instead.");
return at::index(self, {indices[0].to(at::kBool)});
if (indices.size() == 1) {
c10::optional<at::Tensor> first = indices[0];
if (first.has_value()
&& first->scalar_type() == at::kByte) {
TORCH_WARN(
"Indexing with uint8 mask tensor in ATenOp is now deprecated,"
" please use a bool mask instead.");
return at::index(self, {first->to(at::kBool)});
}
}
return at::index(self, indices);
}

View File

@ -21,7 +21,7 @@ using at::Half; // for AT_FORALL_SCALAR_TYPES_AND3(Bool, Half, BFloat16, ...)
namespace internal {
TORCH_API at::Tensor index_with_uint8_handling(
const at::Tensor& self,
at::TensorList indices);
const torch::List<c10::optional<at::Tensor>>& indices);
}
template <class Context>
@ -86,6 +86,16 @@ private:
std::vector<at::Tensor> peekSlice(size_t i, size_t len, size_t N) {
std::vector<at::Tensor> results;
results.reserve(len);
for (size_t ii = i; ii < i + len; ++ii) {
results.push_back(peek(ii, N));
}
return results;
}
torch::List<c10::optional<at::Tensor>> peekSliceOptionals(size_t i, size_t len, size_t N) {
torch::List<c10::optional<at::Tensor>> results;
results.reserve(len);
for (size_t ii = i; ii < i + len; ++ii) {
results.push_back(peek(ii, N));
}

View File

@ -68,7 +68,7 @@ def value_has_tensors(v):
def value_is_tensor_type(v):
return value_has_tensors(v) and v['dynamic_type'] != 'TensorList'
return value_has_tensors(v) and v['dynamic_type'] not in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']
# for each aten type, how do we handle a return value of that type?
@ -208,7 +208,7 @@ def self_as_first_argument(arguments):
def get_num_inputs(o):
args = 0
for a in o['arguments']:
if a['type'] == 'TensorList':
if a['type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
return '*'
elif value_has_tensors(a):
args += 1
@ -277,10 +277,10 @@ if __name__ == '__main__':
# e.g. "Float" is at::kFloat
assert('Type' in o['method_of'])
static_tensor_inputs = sum(arg['type'] != 'TensorList' and value_is_tensor_type(arg) for arg in o['arguments'])
has_tensorlist = any(arg['type'] == 'TensorList' for arg in o['arguments'])
static_tensor_inputs = sum(arg['type'] not in ['TensorList', 'const c10::List<c10::optional<Tensor>> &'] and value_is_tensor_type(arg) for arg in o['arguments'])
has_tensorlist = any(arg['type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &'] for arg in o['arguments'])
if has_tensorlist:
tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] == 'TensorList'][0]
tensorlist_idx = [i for i, arg in enumerate(o['arguments']) if arg['type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']][0]
real_inputs = 0
for i, arg in enumerate(o['arguments']):
@ -290,10 +290,16 @@ if __name__ == '__main__':
view_length = 'InputSize()' if has_tensorlist and i < tensorlist_idx else static_tensor_inputs
if arg['type'] == 'TensorList':
# NOTE: do not advance real_inputs here. After this we will
# switch to indexing the "stack" from the end as if we only had
# switch to indexing the "stack" from the end
env['statements'].append(
'auto {} = peekSlice({}, InputSize() - {}, InputSize());'
.format(arg['name'], real_inputs, static_tensor_inputs))
elif arg['type'] == 'const c10::List<c10::optional<Tensor>> &':
# NOTE: do not advance real_inputs here. After this we will
# switch to indexing the "stack" from the end
env['statements'].append(
'auto {} = peekSliceOptionals({}, InputSize() - {}, InputSize());'
.format(arg['name'], real_inputs, static_tensor_inputs))
elif value_is_tensor_type(arg):
# load tensor inputs from Caffe2
env['statements'].append(

View File

@ -83,27 +83,27 @@ TEST(TensorIndexingTest, TestNoIndices) {
ASSERT_THROWS_WITH(tensor.index_put_(indices, value), "Passing an empty index list to Tensor::index_put_() is not valid syntax");
}
TEST(TensorIndexingTest, TestAdvancedIndexingWithArrayRefOfTensor) {
TEST(TensorIndexingTest, TestAdvancedIndexingWithListOfTensor) {
{
torch::Tensor tensor = torch::randn({20, 20});
torch::Tensor index = torch::arange(10, torch::kLong).cpu();
torch::Tensor result_with_array_ref = tensor.index(at::ArrayRef<torch::Tensor>({index}));
torch::Tensor result = at::index(tensor, {index});
torch::Tensor result_with_init_list = tensor.index({index});
ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list));
ASSERT_TRUE(result.equal(result_with_init_list));
}
{
torch::Tensor tensor = torch::randn({20, 20});
torch::Tensor index = torch::arange(10, torch::kLong).cpu();
torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef<torch::Tensor>({index}), torch::ones({20}));
torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({20}));
torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({20}));
ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list));
ASSERT_TRUE(result.equal(result_with_init_list));
}
{
torch::Tensor tensor = torch::randn({20, 20});
torch::Tensor index = torch::arange(10, torch::kLong).cpu();
torch::Tensor result_with_array_ref = tensor.index_put_(at::ArrayRef<torch::Tensor>({index}), torch::ones({1, 20}));
torch::Tensor result = at::index_put_(tensor, {index}, torch::ones({1, 20}));
torch::Tensor result_with_init_list = tensor.index_put_({index}, torch::ones({1, 20}));
ASSERT_TRUE(result_with_array_ref.equal(result_with_init_list));
ASSERT_TRUE(result.equal(result_with_init_list));
}
}
@ -173,7 +173,7 @@ TEST(TensorIndexingTest, TestBoolIndices) {
TEST(TensorIndexingTest, TestBoolIndicesAccumulate) {
auto mask = torch::zeros({10}, torch::kBool);
auto y = torch::ones({10, 10});
y.index_put_({mask}, y.index({mask}), /*accumulate=*/true);
y.index_put_({mask}, {y.index({mask})}, /*accumulate=*/true);
assert_tensor_equal(y, torch::ones({10, 10}));
}

View File

@ -563,6 +563,8 @@ def generate_tensor_like_override_tests(cls):
func_args.append(instance_gen())
elif t == 'TensorList':
func_args.append([instance_gen(), instance_gen()])
elif t == 'c10::List<c10::optional<Tensor>>':
func_args.append([instance_gen(), instance_gen()])
elif t == 'IntArrayRef':
size = arg.get('size', 2)
if size == 1:

View File

@ -141,7 +141,7 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
compute_index_ranges: List[str] = []
for arg in info.args_with_derivatives:
if arg.type == 'TensorList':
if arg.type == 'TensorList' or arg.type == 'const c10::List<c10::optional<Tensor>> &':
size = f'{arg.name}_size_'
saved_list_sizes.append(f'size_t {arg.name}_size_;')
else:
@ -166,6 +166,15 @@ def process_function(info: DifferentiabilityInfo, template: CodeTemplate) -> str
release_variables.append(f'{name}_released_ = true;')
unpack.append(f'auto {name} = unpack_list({name}_);')
asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
elif var.type == 'c10::List<c10::optional<Tensor>>':
saved_variables.append(f'std::vector<SavedVariable> {name}_;')
saved_variables.append(f'bool {name}_released_ = false;')
# Just clear() is sufficient, we don't need to loop and clear each variable.
# Because the SavedVariable owns a tensor and a grad_fn, removing the SavedVariable makes them go away as well.
release_variables.append(f'{name}_.clear();')
release_variables.append(f'{name}_released_ = true;')
unpack.append(f'auto {name} = unpack_opt_list({name}_);')
asserts.append(f'TORCH_CHECK(!{name}_released_, ERR_BACKWARD_TWICE);')
elif var.type == 'IntArrayRef':
saved_variables.append(f'std::vector<int64_t> {name};')
elif var.type == 'c10::optional<IntArrayRef>':

View File

@ -112,9 +112,8 @@ def format_trace_inputs(f: NativeFunction) -> str:
]
else:
name = arg.name
# XXX: For arg that have type of Tensor?[], tracer will pass allow_undefined to addInputs
if str(arg.type) == 'Tensor?[]':
return [f'jit::tracer::addInputs(node, "{name}", {name}, true);']
return [f'jit::tracer::addInputs(node, "{name}", {name});']
else:
return [ADD_TRACE_INPUT.substitute(name=name, input=name)]

View File

@ -118,6 +118,21 @@ for (size_t i=0; i<${tensorlist_name}.size(); i++) {
}
""")
SAVE_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\
std::vector<c10::optional<Storage>> ${tensorlist_name}_storage_saved(${tensorlist_name}.size());
for (const c10::optional<Tensor>& tensor : ${tensorlist_name})
${tensorlist_name}_storage_saved.push_back(
tensor.has_value() && tensor->has_storage() ? c10::optional<Storage>(tensor->storage()) : c10::nullopt);
""")
ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE = CodeTemplate("""\
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
if (${tensorlist_name}_storage_saved[i].has_value())
AT_ASSERT(${tensorlist_name}_storage_saved[i].value().is_alias_of(
static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->storage()));
}
""")
SAVE_TENSOR_IMPL = CodeTemplate("""\
c10::intrusive_ptr<TensorImpl> ${tensor_name}_impl_saved;
if (${tensor_name}.defined()) ${tensor_name}_impl_saved = ${tensor_name}.getIntrusivePtr();
@ -140,6 +155,21 @@ for (size_t i=0; i<${tensorlist_name}.size(); i++) {
}
""")
SAVE_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\
std::vector<c10::intrusive_ptr<TensorImpl>> ${tensorlist_name}_impl_saved(${tensorlist_name}.size());
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
c10::optional<Tensor> t = ${tensorlist_name}[i];
if (t.has_value() && t->defined()) ${tensorlist_name}_impl_saved[i] = t->getIntrusivePtr();
}
""")
ENFORCE_SAME_OPTIONALTENSORLIST_IMPL = CodeTemplate("""\
for (size_t i=0; i<${tensorlist_name}.size(); i++) {
if (${tensorlist_name}_impl_saved[i])
AT_ASSERT(${tensorlist_name}_impl_saved[i] == static_cast<c10::optional<Tensor>>(${tensorlist_name}[i])->getIntrusivePtr());
}
""")
# The following list contains functions that we don't enforce the invariant on.
DONT_ENFORCE_SAME_TENSOR_IMPL_OR_STORAGE = {
# These functions are expected to change impl or storage of input tensors
@ -466,7 +496,8 @@ def emit_body(declaration):
if func is None:
return setup
has_tensorlist_arg = any(arg.type == 'TensorList' for arg in func.args_with_derivatives)
has_tensorlist_arg = \
any(arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &'] for arg in func.args_with_derivatives)
# We don't want to save tensors if we know that they will never be used
# when computing the derivative, so we add guards to those statements
@ -515,7 +546,7 @@ def emit_body(declaration):
setup.extend(save_variables(func.all_saved_inputs, False, guard_for))
for arg in func.args_with_derivatives:
if arg.type == 'TensorList':
if arg.type in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
setup.append(f'grad_fn->{arg.name}_size_ = {arg.name}.size();')
return setup
@ -554,7 +585,7 @@ def emit_body(declaration):
return body
for arg in differentiable_outputs:
name = arg['name']
if arg['type'] == 'Tensor' or arg['type'] == 'TensorList':
if arg['type'] in ['Tensor', 'TensorList', 'const c10::List<c10::optional<Tensor>> &']:
body.append('throw_error_for_complex_autograd({}, "{}");'.format(name, base_name))
return body
@ -599,7 +630,7 @@ def emit_body(declaration):
expr = f'SavedVariable({var}, {str(is_output).lower()}, {is_inplace_view})'
else:
expr = f'SavedVariable({var}, {str(is_output).lower()})'
elif arg.type == 'TensorList':
elif arg.type in ['TensorList', 'c10::List<c10::optional<Tensor>>']:
name += '_'
expr = f'make_saved_variable_list({arg.name})'
elif arg.type == 'IntArrayRef':
@ -699,7 +730,7 @@ def emit_body(declaration):
# Only allow rebasing of the history if we return a single Tensor
# If we are in a no grad block, raise a warning
# See NOTE [ View + Inplace detection ] for more details about this logic
if return_info['dynamic_type'] == 'TensorList':
if return_info['dynamic_type'] in ['TensorList', 'const c10::List<c10::optional<Tensor>> &']:
if base_name in MULTI_OUTPUT_SAFE_FUNCTIONS:
creation_meta = "CreationMeta::MULTI_OUTPUT_SAFE"
else:
@ -736,6 +767,11 @@ def emit_body(declaration):
SAVE_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
enforce_same_ptrs_stmts += [ENFORCE_SAME_TENSORLIST_STORAGE.substitute(tensorlist_name=arg),
ENFORCE_SAME_TENSORLIST_IMPL.substitute(tensorlist_name=arg)]
elif simple_type == 'c10::List<c10::optional<Tensor>>':
save_ptrs_stmts += [SAVE_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
SAVE_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
enforce_same_ptrs_stmts += [ENFORCE_SAME_OPTIONALTENSORLIST_STORAGE.substitute(tensorlist_name=arg),
ENFORCE_SAME_OPTIONALTENSORLIST_IMPL.substitute(tensorlist_name=arg)]
elif simple_type == 'Tensor':
save_ptrs_stmts += [SAVE_TENSOR_STORAGE.substitute(tensor_name=arg),
SAVE_TENSOR_IMPL.substitute(tensor_name=arg)]
@ -836,7 +872,7 @@ def emit_body(declaration):
def unpack_args(env, declaration):
def requires_unpack(arg):
return 'Tensor' in arg['dynamic_type']
return 'Tensor' in arg['dynamic_type'] and 'c10::optional' not in arg['type']
body = []
unpacked_args = []
@ -855,9 +891,8 @@ def unpack_args(env, declaration):
dynamic_type = arg['dynamic_type']
if 'TensorOptions' not in dynamic_type:
is_nullable = arg.get('is_nullable', False)
ref = (not is_nullable) and dynamic_type not in ['TensorList']
ref = (not is_nullable) and dynamic_type != 'TensorList'
suffix = '_opt' if is_nullable and dynamic_type != 'TensorList' else ''
body.append(UNPACK_TENSOR.substitute(
arg_name=arg['name'],
arg_pos=i,

View File

@ -32,6 +32,15 @@ inline std::vector<Tensor> unpack_list(at::ArrayRef<SavedVariable> xs) {
});
}
inline c10::List<c10::optional<Tensor>> unpack_opt_list(at::ArrayRef<SavedVariable> xs) {
torch::List<c10::optional<Tensor>> result;
result.reserve(xs.size());
for (const SavedVariable& v : xs) {
result.push_back(v.unpack());
}
return result;
}
struct TypeAndSize {
TypeAndSize() : options(at::TensorOptions()) {}
/* implicit */

View File

@ -49,7 +49,6 @@ namespace VariableType {
at::Tensor & unpack(Tensor & t, const char * name, int pos);
const at::Tensor & unpack(const Tensor & t, const char * name, int pos);
at::Tensor unpack_opt(const Tensor & t, const char * name, int pos);
c10::optional<at::Tensor> unpack_opt(const c10::optional<Tensor> & t, const char * name, int pos);
std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos);
};

View File

@ -104,9 +104,11 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
return BaseCType("TensorList", binds)
elif str(t.elem) == 'Dimname':
return BaseCType("DimnameList", binds)
# TODO: do something reasonable about lists of optional tensors
elif (not local.use_c10_dispatcher().dispatcher_uses_new_style()) and str(t.elem) == 'Tensor?':
return BaseCType("TensorList", binds)
elif str(t.elem) == 'Tensor?':
if local.use_c10_dispatcher().dispatcher_uses_new_style():
return BaseCType("const c10::List<c10::optional<Tensor>> &", binds)
else:
return BaseCType("TensorList", binds)
elem = argumenttype_type(t.elem, mutable=mutable, binds=binds)
# TODO: explicitly qualify namespace here
return BaseCType(f"ArrayRef<{elem.cpp_type()}>", binds)

View File

@ -34,7 +34,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> CType:
else:
return ConstRefCType(BaseCType('Tensor', binds))
elif str(t) == 'Tensor?[]':
return BaseCType('TensorList', binds)
return BaseCType('const c10::List<c10::optional<Tensor>> &', binds)
return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
def returns_type(rs: Sequence[Return]) -> str:

View File

@ -228,7 +228,7 @@ class PythonArgument:
# Compute argument formal for python argument parsing.
# Needs to be consistent with torch/csrc/utils/python_arg_parser.h.
def argument_str(self, *, method: bool = False) -> str:
type_str = argument_type_str(self.type)
type_str = argument_type_str(self.type).replace('const ', '').replace(' &', '')
name = self.name
# s/self/input/ outside method bindings
@ -624,10 +624,9 @@ def argument_type_str(t: Type, *, simple_type: bool = False) -> str:
return f'ScalarList[{size}]' if size is not None else 'ScalarList'
elif str(t.elem) == 'Tensor?':
if simple_type:
return 'TensorList'
return 'c10::List<c10::optional<Tensor>>'
else:
# TODO: clone the old codegen behavior but does it make sense?
return 'TensorList?'
return 'const c10::List<c10::optional<Tensor>> &'
elif str(t.elem) == 'Dimname':
return f'DimnameList[{size}]' if size is not None else 'DimnameList'
elem = argument_type_str(t.elem, simple_type=simple_type)
@ -1051,12 +1050,14 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
return 'toDimnameListOptional'
elif isinstance(t, ListType):
if str(t.elem) == 'Tensor' or str(t.elem) == 'Tensor?':
if str(t.elem) == 'Tensor':
# accept and use definite size
if t.size is not None:
return f'tensorlist_n<{t.size}>'
else:
return 'tensorlist'
elif str(t.elem) == 'Tensor?':
return 'list_of_optional_tensors'
elif str(t.elem) == 'Dimname':
# accept definite size
return 'dimnamelist'

View File

@ -14,6 +14,7 @@
#include <ATen/ScalarOps.h>
#include <ATen/native/LinearAlgebraUtils.h>
#include <ATen/SparseTensorUtils.h>
#include <ATen/native/IndexingUtils.h>
#include <ciso646>
#include <algorithm>
@ -2211,15 +2212,17 @@ Tensor det_backward(const Tensor & grad, const Tensor& self, const Tensor& det)
return nonsingular_case_backward(grad, self, det);
}
} else {
auto nonzero_det_indices = at::where(det);
auto nonzero_det_indices = at::native::toListOfOptionalTensors(at::where(det));
c10::optional<Tensor> first_nonzero_det_index = nonzero_det_indices[0];
if (nonzero_det_indices[0].size(0) == det.numel()) { // all determinants are nonzero (non-singular)
if (first_nonzero_det_index->size(0) == det.numel()) { // all determinants are nonzero (non-singular)
return nonsingular_case_backward(grad, self, det);
}
auto zero_det_indices = at::where(det == 0);
auto zero_det_indices = at::native::toListOfOptionalTensors(at::where(det == 0));
c10::optional<Tensor> first_zero_det_index = zero_det_indices[0];
if (zero_det_indices[0].size(0) == det.numel()) { // all determinants are zero (singular)
if (first_zero_det_index->size(0) == det.numel()) { // all determinants are zero (singular)
return singular_case_backward(grad, self, det);
}
@ -2261,15 +2264,17 @@ Tensor logdet_backward(const Tensor & grad, const Tensor& self, const Tensor& lo
return singular_case_backward(grad, self);
}
} else {
auto finite_logdet_indices = at::where(logdet != -INFINITY);
auto finite_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet != -INFINITY));
c10::optional<Tensor> first_finite_logdet_index = finite_logdet_indices[0];
if (finite_logdet_indices[0].size(0) == logdet.numel()) { // all log determinants are finite (non-singular)
if (first_finite_logdet_index->size(0) == logdet.numel()) { // all log determinants are finite (non-singular)
return nonsingular_case_backward(grad, self);
}
auto neginf_logdet_indices = at::where(logdet == -INFINITY);
auto neginf_logdet_indices = at::native::toListOfOptionalTensors(at::where(logdet == -INFINITY));
c10::optional<Tensor> first_neginf_logdet_index = neginf_logdet_indices[0];
if (neginf_logdet_indices[0].size(0) == logdet.numel()) { // all log determinants are -inf (singular)
if (first_neginf_logdet_index->size(0) == logdet.numel()) { // all log determinants are -inf (singular)
return singular_case_backward(grad, self);
}
@ -2313,15 +2318,17 @@ Tensor slogdet_backward(const Tensor& grad_logabsdet,
return nonsingular_case_backward(grad_logabsdet, self);
}
} else {
auto nonzero_signdet_indices = at::where(signdet);
auto nonzero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet));
c10::optional<Tensor> first_nonzero_signdet_index = nonzero_signdet_indices[0];
if (nonzero_signdet_indices[0].size(0) == logabsdet.numel()) { // all log determinants are finite (non-singular)
if (first_nonzero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are finite (non-singular)
return nonsingular_case_backward(grad_logabsdet, self);
}
auto zero_signdet_indices = at::where(signdet == 0);
auto zero_signdet_indices = at::native::toListOfOptionalTensors(at::where(signdet == 0));
c10::optional<Tensor> first_zero_signdet_index = zero_signdet_indices[0];
if (zero_signdet_indices[0].size(0) == logabsdet.numel()) { // all log determinants are -inf (singular)
if (first_zero_signdet_index->size(0) == logabsdet.numel()) { // all log determinants are -inf (singular)
return singular_case_backward(grad_logabsdet, self);
}
@ -2873,8 +2880,8 @@ Tensor embedding_dense_double_backward(const Tensor & grad, const Tensor & indic
return gg_weight.view(size);
}
Tensor index_backward(Tensor zeros_like_self, TensorList indices, const Tensor& grad) {
return at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
Tensor index_backward(Tensor zeros_like_self, const torch::List<c10::optional<Tensor>>& indices, const Tensor& grad) {
return at::_index_put_impl_(zeros_like_self, indices, grad, true, true);
}
Tensor _cudnn_ctc_loss_backward(const Tensor& grad_out, const Tensor& loss, const Tensor& raw_grad, bool zero_infinity) {

View File

@ -124,7 +124,7 @@ at::Tensor slogdet_backward(const at::Tensor& grad_logabsdet, const at::Tensor&
at::Tensor log1p_backward(const at::Tensor& grad, const at::Tensor& self);
at::Tensor sparse_constructor_values_backward(const at::Tensor& sparse_grad_out, const at::Tensor& indices, at::IntArrayRef values_shape);
at::Tensor embedding_dense_double_backward(const at::Tensor & grad, const at::Tensor & indices, int64_t padding_idx);
at::Tensor index_backward(at::Tensor zeros_like_self, at::TensorList indices, const at::Tensor& grad);
at::Tensor index_backward(at::Tensor zeros_like_self, const torch::List<c10::optional<Tensor>>& indices, const at::Tensor& grad);
at::Tensor _cudnn_ctc_loss_backward(const at::Tensor& grad_out, const at::Tensor& loss, const at::Tensor& raw_grad, bool zero_infinity);
Tensor svd_backward(const std::vector<torch::autograd::Variable> &grads, const Tensor& self,

View File

@ -66,10 +66,6 @@ Tensor unpack_opt(const Tensor & t, const char * name, int pos) {
return unpack(t, name, pos);
}
c10::optional<Tensor> unpack_opt(const c10::optional<Tensor> & t, const char * name, int pos) {
return t;
}
std::vector<at::Tensor> unpack(at::TensorList tl, const char *name, int pos) {
std::vector<at::Tensor> ret(tl.size());
for (size_t i = 0; i < tl.size(); ++i) {
@ -94,7 +90,7 @@ void _backward(
// instead of us having to unwrap it to Tensor _gradient here.
Tensor _gradient = gradient.has_value() ? *gradient : Tensor();
std::vector<torch::autograd::Variable> input_vars(inputs.begin(), inputs.end());
torch::autograd::backward({self}, {_gradient}, std::move(keep_graph), create_graph, input_vars);
torch::autograd::backward({self}, {_gradient}, keep_graph, create_graph, input_vars);
}
void set_data(Tensor & self, const Tensor & new_data) {
@ -230,7 +226,6 @@ Tensor _fw_primal(const Tensor & self, int64_t level) {
// We don't have an outplace copy, so this can't be generated automatically
Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
jit::Value* output = nullptr;
// TODO: once copy is exposed in Declarations.yaml we may be able to bind
// it automatically
auto& self_ = unpack(self, "self", 0);
@@ -282,7 +277,7 @@ Tensor& resize_(
}
{
at::AutoNonVariableTypeMode non_var_type_mode(true);
self_.resize_(size, std::move(optional_memory_format));
self_.resize_(size, optional_memory_format);
}
if (self.fw_grad(/* level */ 0).defined()) {
@@ -303,7 +298,7 @@ Tensor& resize_as_(
}
{
at::AutoNonVariableTypeMode non_var_type_mode(true);
at::resize_as_(self_, the_template_, std::move(optional_memory_format));
at::resize_as_(self_, the_template_, optional_memory_format);
}
// Handle fw grad

@@ -266,12 +266,31 @@ inline void check_no_requires_grad(TensorList tensors, const char* name) {
}
}
inline void check_no_requires_grad(const c10::List<c10::optional<Tensor>>& tensors, const char* name) {
for (c10::optional<Tensor> tensor : tensors) {
if (tensor.has_value()) {
check_no_requires_grad(*tensor, name);
}
}
}
// Assumed that saved tensor lists are never inplace outputs
inline std::vector<SavedVariable> make_saved_variable_list(TensorList tensors) {
return fmap(tensors, [](const Tensor& tensor) -> SavedVariable {
return SavedVariable{tensor, false /* is output */}; });
}
// Assumed that saved tensor lists are never inplace outputs
inline std::vector<SavedVariable> make_saved_variable_list(const c10::List<c10::optional<at::Tensor>>& tensors) {
return fmap(tensors, [](const c10::optional<Tensor>& tensor) -> SavedVariable {
if (tensor.has_value()) {
return SavedVariable{*tensor, false /* is output */};
} else {
return SavedVariable{Tensor(), false /* is output */};
}
});
}
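
The `fmap` lambda in `make_saved_variable_list` above substitutes an undefined `Tensor` whenever the list element is `nullopt`. The same pattern in standalone form (helper name is illustrative):

```cpp
#include <vector>
#include <ATen/ATen.h>
#include <ATen/core/List.h>
#include <ATen/core/functional.h>  // c10::fmap

// Illustrative helper: flatten a Tensor?[] into plain tensors, using an
// undefined Tensor wherever the list holds nullopt, mirroring the lambda above.
std::vector<at::Tensor> materializeOptionalTensors(
    const c10::List<c10::optional<at::Tensor>>& tensors) {
  return c10::fmap(tensors, [](const c10::optional<at::Tensor>& t) -> at::Tensor {
    return t.has_value() ? *t : at::Tensor();
  });
}
```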
inline std::vector<std::vector<int64_t>> to_args_sizes(TensorList tensors) {
std::vector<std::vector<int64_t>> args_sizes(tensors.size());
for (size_t i = 0; i < tensors.size(); ++i) {

@@ -1,5 +1,6 @@
#pragma once
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
namespace torch {

@@ -103,6 +103,9 @@ void TracingState::delValue(const IValue& var) {
Value* getValueTrace(const IValue& var) {
return getTracingState()->getValue(var);
}
Value* getOptTensorValueTrace(const c10::optional<at::Tensor>& var) {
return getValueTrace(IValue(var));
}
Value* TracingState::getValue(const IValue& var) {
// allow tracing of tuples passed to List[Tensor] or Tuple[Tensor...]
// arguments
@@ -686,6 +689,16 @@ void addInputs(
}
n->addInput(list_node->output());
}
TORCH_API void addInputs(
Node* n,
const char* name,
const List<c10::optional<at::Tensor>>& value) {
Graph* g = n->owningGraph();
Node* list_node = nullptr;
list_node = g->insertNode(g->createList(
OptionalType::ofTensor(), fmap(value, getOptTensorValueTrace)));
n->addInput(list_node->output());
}
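
`getOptTensorValueTrace` above leans on `IValue`'s constructor from `c10::optional<T>`: as far as I understand it, `nullopt` boxes to a None `IValue` while a present tensor boxes to a Tensor `IValue`, which is what lets the traced `ListConstruct` carry `Tensor?` elements. A quick standalone check of that assumption:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>

int main() {
  // Assumption being illustrated: the optional<T> constructor of IValue maps
  // nullopt to None and a present value to the corresponding tagged IValue.
  c10::IValue absent{c10::optional<at::Tensor>()};
  c10::IValue present{c10::optional<at::Tensor>(at::ones({2}))};
  return (absent.isNone() && present.isTensor()) ? 0 : 1;
}
```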
void addInputs(
Node* n,

@@ -255,6 +255,10 @@ TORCH_API void addInputs(
const char* name,
ArrayRef<at::Tensor> value,
bool allow_undefined = false);
TORCH_API void addInputs(
Node* n,
const char* name,
const List<c10::optional<at::Tensor>>& value);
TORCH_API void addInputs(
Node* n,
const char* name,

@@ -1,5 +1,6 @@
#pragma once
//#include <ATen/core/function_schema.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/jit/mobile/function.h>
#include <torch/csrc/jit/mobile/method.h>

@@ -5,6 +5,7 @@
#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <torch/csrc/WindowsTorchApiMacro.h>
#include <torch/csrc/jit/frontend/source_range.h>

@@ -908,7 +908,7 @@ RegisterOperators reg(
TORCH_SELECTIVE_SCHEMA(
"aten::index.Tensor_hacked_twin(Tensor self, Tensor[] indices) -> Tensor"),
[](Stack* stack) {
auto indices = pop(stack).toTensorVector();
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
auto self = pop(stack).toTensor();
auto result = at::index(self, indices);
push(stack, std::move(result));
@@ -921,7 +921,7 @@ RegisterOperators reg(
auto unsafe = pop(stack).toBool();
auto accumulate = pop(stack).toBool();
auto values = pop(stack).toTensor();
auto indices = pop(stack).toTensorVector();
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
auto self = pop(stack).toTensor();
auto result =
at::_index_put_impl_(self, indices, values, accumulate, unsafe);
@@ -934,7 +934,7 @@
[](Stack* stack) {
auto accumulate = pop(stack).toBool();
auto values = pop(stack).toTensor();
auto indices = pop(stack).toTensorVector();
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
auto self = pop(stack).toTensor();
auto result = at::index_put_(self, indices, values, accumulate);
push(stack, std::move(result));
@@ -946,7 +946,7 @@
[](Stack* stack) {
auto accumulate = pop(stack).toBool();
auto values = pop(stack).toTensor();
auto indices = pop(stack).toTensorVector();
auto indices = pop(stack).to<List<c10::optional<at::Tensor>>>();
auto self = pop(stack).toTensor();
auto result = at::index_put_(self, indices, values, accumulate);
push(stack, std::move(result));
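
These interpreter ops now unbox the `indices` argument with a typed `to<List<c10::optional<Tensor>>>()` instead of `toTensorVector()`. A minimal sketch of the same boxing/unboxing round trip outside the interpreter (stack plumbing omitted):

```cpp
#include <torch/torch.h>

int main() {
  c10::List<c10::optional<at::Tensor>> indices;
  indices.push_back(torch::tensor({1, 2}));
  indices.push_back(c10::nullopt);

  // Box the list as it would sit on the interpreter stack, then unbox it with
  // the same typed conversion the registered ops above use.
  c10::IValue boxed(indices);
  auto unboxed = boxed.to<c10::List<c10::optional<at::Tensor>>>();

  auto x = torch::ones({4, 4});
  auto result = at::index(x, unboxed);
  return result.defined() ? 0 : 1;
}
```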

@@ -2,6 +2,7 @@
#include <ATen/core/List.h>
#include <ATen/core/functional.h>
#include <ATen/core/ivalue.h>
#include <ATen/core/jit_type.h>
#include <ATen/core/stack.h>
namespace torch {

@@ -24,6 +24,7 @@ static std::unordered_map<std::string, ParameterType> type_map = {
{"double", ParameterType::DOUBLE},
{"complex", ParameterType::COMPLEX},
{"TensorList", ParameterType::TENSOR_LIST},
{"c10::List<c10::optional<Tensor>>", ParameterType::TENSOR_LIST},
{"IntArrayRef", ParameterType::INT_LIST},
{"ArrayRef<double>", ParameterType::FLOAT_LIST},
{"Generator", ParameterType::GENERATOR},
@@ -390,7 +391,7 @@ bool is_float_or_complex_list(PyObject* obj) {
}
auto size = tuple ? PyTuple_GET_SIZE(obj) : PyList_GET_SIZE(obj);
if (size > 0) {
PyObject* iobj = tuple ? PyTuple_GET_ITEM(obj, 0) : PyList_GET_ITEM(obj, 0);
if (!THPUtils_checkDouble(iobj) && !PyComplex_Check(iobj)) {
return false;

@@ -160,6 +160,7 @@ struct PythonArgs {
inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
inline std::vector<at::Scalar> scalarlist(int i);
inline std::vector<at::Tensor> tensorlist(int i);
inline torch::List<c10::optional<at::Tensor>> list_of_optional_tensors(int i);
template<int N>
inline std::array<at::Tensor, N> tensorlist_n(int i);
inline std::vector<int64_t> intlist(int i);
@@ -327,6 +328,22 @@ inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
return res;
}
inline torch::List<c10::optional<at::Tensor>> PythonArgs::list_of_optional_tensors(int i) {
if (!args[i]) return torch::List<c10::optional<at::Tensor>>();
auto tuple = six::isTuple(args[i]);
THPObjectPtr arg = six::maybeAsTuple(args[i]);
auto size = tuple ? PyTuple_GET_SIZE(arg.get()) : PyList_GET_SIZE(arg.get());
torch::List<c10::optional<at::Tensor>> res;
res.reserve(size);
for (int idx = 0; idx < size; idx++) {
PyObject* obj = tuple ? PyTuple_GET_ITEM(arg.get(), idx) : PyList_GET_ITEM(arg.get(), idx);
// This is checked by the argument parser so it's safe to cast without checking
// if this is a tensor first
res.push_back(reinterpret_cast<THPVariable*>(obj)->cdata);
}
return res;
}
template<int N>
inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
auto res = std::array<at::Tensor, N>();