mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Implement a constructor for nested_tensor that is similar to torch.tensor() (#88213)
Summary: This diff merges both previous implementations of constructors for nested tensors, the one from lists of tensors and the one with arbitrary python lists, adn implements it in pytorch core so no extensions are needed to construct NT. Pull Request resolved: https://github.com/pytorch/pytorch/pull/88213 Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent
72a7351993
commit
c77368d416
|
|
@ -114,7 +114,8 @@ Tensor& copy_nested_(Tensor& self, const Tensor& src, bool non_blocking) {
|
|||
const auto* nt_self = get_nested_tensor_impl(self);
|
||||
const auto* nt_src = get_nested_tensor_impl(src);
|
||||
TORCH_CHECK(
|
||||
at::equal(nt_self->get_nested_size_tensor(), nt_src->get_nested_size_tensor()),
|
||||
at::equal(
|
||||
nt_self->get_nested_size_tensor(), nt_src->get_nested_size_tensor()),
|
||||
"copy_ only supports tensors that are the same size for Nested implementations");
|
||||
nt_self->get_buffer().copy_(nt_src->get_buffer(), non_blocking);
|
||||
return self;
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/Dispatch.h>
|
||||
#include <ATen/NestedTensorImpl.h>
|
||||
#include <ATen/Parallel.h>
|
||||
#include <ATen/core/Tensor.h>
|
||||
#include <c10/core/DispatchKeySet.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
|
|
@ -8,10 +10,12 @@
|
|||
#include <c10/util/Exception.h>
|
||||
|
||||
#ifndef AT_PER_OPERATOR_HEADERS
|
||||
|
||||
#include <ATen/Functions.h>
|
||||
#include <ATen/NativeFunctions.h>
|
||||
#else
|
||||
#include <ATen/ops/cat.h>
|
||||
#include <ATen/ops/empty.h>
|
||||
#include <ATen/ops/ones_native.h>
|
||||
#include <ATen/ops/prod.h>
|
||||
#include <ATen/ops/stack_native.h>
|
||||
|
|
@ -56,10 +60,11 @@ inline at::Tensor wrap_buffer(
|
|||
at::Tensor nested_stride_tensor,
|
||||
const std::vector<int64_t>& offsets) {
|
||||
std::vector<int64_t> offsets_copy(offsets);
|
||||
return wrap_buffer(buffer,
|
||||
nested_size_tensor,
|
||||
nested_stride_tensor,
|
||||
std::move(offsets_copy));
|
||||
return wrap_buffer(
|
||||
buffer,
|
||||
nested_size_tensor,
|
||||
nested_stride_tensor,
|
||||
std::move(offsets_copy));
|
||||
}
|
||||
|
||||
inline at::Tensor get_buffer(const at::Tensor& tensor) {
|
||||
|
|
@ -320,17 +325,84 @@ inline Tensor wrap_tensor_node(
|
|||
if (tensor_node.degree() == 0) {
|
||||
return wrap_buffer(ones({0}, dtype, layout, device), ones({}));
|
||||
}
|
||||
std::vector<Tensor> sizes;
|
||||
std::vector<Tensor> flat_tensors;
|
||||
|
||||
// Fast path: if all tensors are on CPU, have contiguous memory, and the same
|
||||
// dtype, copying can be done much faster.
|
||||
bool all_tensors_cpu = true;
|
||||
bool all_tensors_contiguous = true;
|
||||
bool all_tensors_same_dtype = true;
|
||||
auto first_dtype = tensor_node.children(0).dtype();
|
||||
std::vector<long> start_offsets(tensor_node.degree());
|
||||
start_offsets[0] = 0;
|
||||
long total_size = 0;
|
||||
for (const auto i : c10::irange(tensor_node.degree())) {
|
||||
flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous());
|
||||
sizes.push_back(tensor(c10::IntArrayRef(tensor_node.children(i).sizes())));
|
||||
all_tensors_cpu = all_tensors_cpu && tensor_node.children(i).is_cpu();
|
||||
all_tensors_contiguous =
|
||||
all_tensors_contiguous && tensor_node.children(i).is_contiguous();
|
||||
all_tensors_same_dtype = all_tensors_same_dtype &&
|
||||
(first_dtype == tensor_node.children(i).dtype());
|
||||
if (!(all_tensors_cpu && all_tensors_contiguous &&
|
||||
all_tensors_same_dtype)) {
|
||||
break;
|
||||
}
|
||||
if (i > 0) {
|
||||
start_offsets[i] =
|
||||
start_offsets[i - 1] + tensor_node.children(i - 1).numel();
|
||||
}
|
||||
total_size += tensor_node.children(i).numel();
|
||||
}
|
||||
|
||||
TensorOptions options = flat_tensors[0].options().merge_in(options_);
|
||||
TensorOptions options;
|
||||
Tensor nt_buffer, nt_sizes;
|
||||
if (all_tensors_cpu && all_tensors_contiguous && all_tensors_same_dtype) {
|
||||
nt_buffer = at::empty({total_size}, tensor_node.children(0).options());
|
||||
nt_sizes = at::empty(
|
||||
{static_cast<long>(tensor_node.degree()),
|
||||
static_cast<long>(tensor_node.children(0).sizes().size())},
|
||||
TensorOptions().dtype(kLong));
|
||||
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
|
||||
at::ScalarType::Half,
|
||||
at::ScalarType::Bool,
|
||||
at::ScalarType::BFloat16,
|
||||
c10::typeMetaToScalarType(first_dtype),
|
||||
"create_nt_buffer",
|
||||
[&]() {
|
||||
at::parallel_for(
|
||||
0, tensor_node.degree(), 1, [&](int64_t begin, int64_t end) {
|
||||
for (int64_t i = begin; i < end; ++i) {
|
||||
// Only try copying memory if there is more than 0 elements
|
||||
// for a certain tensor
|
||||
if (tensor_node.children(i).numel() > 0) {
|
||||
memcpy(
|
||||
nt_buffer.data_ptr<scalar_t>() + start_offsets[i],
|
||||
tensor_node.children(i).data_ptr<scalar_t>(),
|
||||
tensor_node.children(i).numel() * sizeof(scalar_t));
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
long sizes_offset = 0;
|
||||
for (size_t i = 0; i < tensor_node.degree(); ++i) {
|
||||
auto tensor_sizes = tensor_node.children(i).sizes();
|
||||
for (size_t j = 0; j < tensor_sizes.size(); ++j) {
|
||||
nt_sizes.data_ptr<int64_t>()[sizes_offset++] = tensor_sizes[j];
|
||||
}
|
||||
}
|
||||
options = nt_buffer.options().merge_in(options_);
|
||||
} else { // Slow path
|
||||
std::vector<Tensor> flat_tensors;
|
||||
std::vector<Tensor> sizes;
|
||||
for (const auto i : c10::irange(tensor_node.degree())) {
|
||||
flat_tensors.push_back(tensor_node.children(i).reshape(-1).contiguous());
|
||||
sizes.push_back(
|
||||
tensor(c10::IntArrayRef(tensor_node.children(i).sizes())));
|
||||
}
|
||||
options = flat_tensors[0].options().merge_in(options_);
|
||||
nt_buffer = at::cat(flat_tensors);
|
||||
nt_sizes = at::native::stack(sizes);
|
||||
}
|
||||
|
||||
return wrap_buffer(
|
||||
at::cat(flat_tensors).to(options), at::native::stack(sizes));
|
||||
return wrap_buffer(nt_buffer.to(options), nt_sizes);
|
||||
}
|
||||
|
||||
} // namespace impl
|
||||
|
|
|
|||
|
|
@ -899,6 +899,7 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/autograd/python_function.cpp",
|
||||
"torch/csrc/autograd/python_hook.cpp",
|
||||
"torch/csrc/autograd/python_legacy_variable.cpp",
|
||||
"torch/csrc/autograd/python_nested_functions_manual.cpp",
|
||||
"torch/csrc/autograd/python_torch_functions_manual.cpp",
|
||||
"torch/csrc/autograd/python_variable.cpp",
|
||||
"torch/csrc/autograd/python_variable_indexing.cpp",
|
||||
|
|
@ -960,6 +961,7 @@ libtorch_python_core_sources = [
|
|||
"torch/csrc/utils.cpp",
|
||||
"torch/csrc/utils/cuda_lazy_init.cpp",
|
||||
"torch/csrc/utils/invalid_arguments.cpp",
|
||||
"torch/csrc/utils/nested.cpp",
|
||||
"torch/csrc/utils/object_ptr.cpp",
|
||||
"torch/csrc/utils/python_arg_parser.cpp",
|
||||
"torch/csrc/utils/python_dispatch.cpp",
|
||||
|
|
|
|||
|
|
@ -525,6 +525,7 @@
|
|||
"Optional"
|
||||
],
|
||||
"torch.nested": [
|
||||
"nested_tensor",
|
||||
"to_padded_tensor"
|
||||
],
|
||||
"torch.nn.common_types": [
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
# Owner(s): ["module: nestedtensor"]
|
||||
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn
|
||||
import unittest
|
||||
import numpy as np
|
||||
from torch.testing._internal.common_device_type import (
|
||||
dtypes,
|
||||
dtypesIfCUDA,
|
||||
|
|
@ -16,11 +17,12 @@ from torch.testing._internal.common_dtype import floating_types_and_half
|
|||
from torch.testing._internal.common_utils import (
|
||||
freeze_rng_state,
|
||||
gradcheck,
|
||||
instantiate_parametrized_tests,
|
||||
IS_FBCODE,
|
||||
parametrize,
|
||||
run_tests,
|
||||
TestCase,
|
||||
subtest,
|
||||
TestCase,
|
||||
)
|
||||
|
||||
# Tests are ported from pytorch/nestedtensor.
|
||||
|
|
@ -94,6 +96,76 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None):
|
|||
|
||||
|
||||
class TestNestedTensor(TestCase):
|
||||
@parametrize("batch_size", [2, 4])
|
||||
@parametrize("max_seq_len", [3, 5])
|
||||
@parametrize("vocab_size", [10, 20])
|
||||
def test_2d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
|
||||
data = []
|
||||
nested_tensor_ref_list = []
|
||||
for _ in range(batch_size):
|
||||
if max_seq_len == 0:
|
||||
length = 0
|
||||
else:
|
||||
length = np.random.randint(low=1, high=max_seq_len)
|
||||
row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
|
||||
data.append(row)
|
||||
nested_tensor_ref_list.append(torch.tensor(row))
|
||||
nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
|
||||
nested_tensor_list = nested_tensor.unbind()
|
||||
for id in range(batch_size):
|
||||
self.assertEqual(
|
||||
nested_tensor_list[id],
|
||||
nested_tensor_ref_list[id].type(torch.int64)
|
||||
)
|
||||
|
||||
@parametrize("batch_size", [2, 4])
|
||||
@parametrize("max_seq_len", [3, 5])
|
||||
@parametrize("vocab_size", [10, 20])
|
||||
def test_3d_nested_tensor(self, batch_size, max_seq_len, vocab_size):
|
||||
data = []
|
||||
nested_tensor_ref_list = []
|
||||
for _ in range(batch_size):
|
||||
if max_seq_len == 0:
|
||||
length = 0
|
||||
else:
|
||||
length = np.random.randint(low=1, high=max_seq_len)
|
||||
row = list(np.random.randint(low=0, high=vocab_size, size=(length,)))
|
||||
row = [list(item * np.arange(max_seq_len)) for item in row]
|
||||
data.append(row)
|
||||
nested_tensor_ref_list.append(torch.Tensor(row))
|
||||
nested_tensor = torch.nested.nested_tensor(data, dtype=torch.int64)
|
||||
nested_tensor_list = nested_tensor.unbind()
|
||||
for id in range(batch_size):
|
||||
self.assertEqual(
|
||||
nested_tensor_list[id],
|
||||
nested_tensor_ref_list[id].type(torch.int64)
|
||||
)
|
||||
|
||||
@parametrize("batch_size", [2, 4])
|
||||
@parametrize("max_seq_len", [3, 5])
|
||||
@parametrize("vocab_size", [10, 20])
|
||||
def test_3d_nested_tensor_float(self, batch_size, max_seq_len, vocab_size):
|
||||
data = []
|
||||
nested_tensor_ref_list = []
|
||||
for _ in range(batch_size):
|
||||
if max_seq_len == 0:
|
||||
length = 0
|
||||
else:
|
||||
length = np.random.randint(low=1, high=max_seq_len)
|
||||
row = list(
|
||||
np.random.randint(low=0, high=vocab_size, size=(length,)).astype(float)
|
||||
)
|
||||
row = [list(item * np.arange(max_seq_len)) for item in row]
|
||||
data.append(row)
|
||||
nested_tensor_ref_list.append(torch.Tensor(row))
|
||||
nested_tensor = torch.nested.nested_tensor(data, dtype=torch.float)
|
||||
nested_tensor_list = nested_tensor.unbind()
|
||||
for id in range(batch_size):
|
||||
self.assertEqual(
|
||||
nested_tensor_list[id],
|
||||
nested_tensor_ref_list[id].type(torch.float)
|
||||
)
|
||||
|
||||
|
||||
@torch.inference_mode()
|
||||
def _test_unbind_case(self, a, b):
|
||||
|
|
@ -151,7 +223,6 @@ class TestNestedTensor(TestCase):
|
|||
|
||||
@torch.inference_mode()
|
||||
def test_nested_tensor(self):
|
||||
self.assertRaises(TypeError, lambda: torch.nested.nested_tensor([3.0]))
|
||||
self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(torch.tensor([3.0])))
|
||||
self.assertRaises(TypeError, lambda: torch.nested.nested_tensor(4.0))
|
||||
|
||||
|
|
@ -227,9 +298,7 @@ class TestNestedTensor(TestCase):
|
|||
a1 = constructor([])
|
||||
self.assertRaisesRegex(
|
||||
RuntimeError,
|
||||
"Tensors of type NestedTensorImpl do not have sym sizes"
|
||||
if IS_FBCODE
|
||||
else "NestedTensorImpl doesn't support sizes",
|
||||
"NestedTensorImpl doesn't support sizes",
|
||||
lambda: a1.size(),
|
||||
)
|
||||
|
||||
|
|
@ -2241,6 +2310,7 @@ class TestNestedTensorAutograd(TestCase):
|
|||
self.assertEqual(nt.grad, expected_grad)
|
||||
|
||||
|
||||
instantiate_parametrized_tests(TestNestedTensor)
|
||||
instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
|
||||
instantiate_device_type_tests(TestNestedTensorAutograd, globals())
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#include "torch/csrc/Device.h"
|
||||
#include "torch/csrc/DynamicTypes.h"
|
||||
#include "torch/csrc/Exceptions.h"
|
||||
#include "torch/csrc/autograd/python_special_functions.h"
|
||||
#include "torch/csrc/autograd/python_nested_functions.h"
|
||||
#include "torch/csrc/autograd/python_return_types.h"
|
||||
#include "torch/csrc/autograd/python_variable.h"
|
||||
#include "torch/csrc/autograd/utils/wrap_outputs.h"
|
||||
|
|
@ -47,6 +47,7 @@ namespace torch { namespace autograd {
|
|||
${py_forwards}
|
||||
|
||||
static PyMethodDef nested_functions[] = {
|
||||
{NULL, NULL, 0, NULL},
|
||||
${py_method_defs}
|
||||
{NULL}
|
||||
};
|
||||
|
|
@ -54,6 +55,7 @@ static PyMethodDef nested_functions[] = {
|
|||
static PyObject* THPNestedVariableFunctionsModule = NULL;
|
||||
|
||||
void initNestedFunctions(PyObject* module) {
|
||||
nested_functions[0] = get_nested_functions_manual()[0];
|
||||
static struct PyModuleDef def = {
|
||||
PyModuleDef_HEAD_INIT,
|
||||
"torch._C._nested",
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/ATen.h>
|
||||
#include <torch/torch.h>
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
|
||||
#include <algorithm>
|
||||
|
||||
namespace torch {
|
||||
namespace nested {
|
||||
|
|
@ -12,19 +14,51 @@ namespace nested {
|
|||
/// https://pytorch.org/docs/master/nested.html#torch.nested.nested_tensor
|
||||
///
|
||||
/// ```
|
||||
inline Tensor nested_tensor(
|
||||
TensorList list,
|
||||
c10::optional<ScalarType> dtype = c10::nullopt,
|
||||
c10::optional<Device> device = c10::nullopt,
|
||||
c10::optional<bool> requires_grad = false,
|
||||
c10::optional<bool> pin_memory = false) {
|
||||
std::vector<Tensor> new_list;
|
||||
for (const auto i : c10::irange(list.size())) {
|
||||
new_list.push_back(list[i].clone().detach());
|
||||
// implemented on python object to allow torch.nested.nested_tensor to be
|
||||
// constructed with arbitrarily nested python objects - for now, only arbitrary
|
||||
// python lists and lists of Tensors
|
||||
// See torch/csrc/autograd/python_nested_functions_manual.cpp for Python
|
||||
// implementation
|
||||
// See here for C++ implementation
|
||||
inline at::Tensor nested_tensor(
|
||||
at::TensorList nested_tensor_data,
|
||||
const at::TensorOptions& options = {}) {
|
||||
auto out = at::_nested_tensor_from_tensor_list(
|
||||
nested_tensor_data,
|
||||
c10::typeMetaToScalarType(options.dtype()),
|
||||
c10::nullopt,
|
||||
options.device(),
|
||||
options.pinned_memory());
|
||||
if (options.has_requires_grad() && options.requires_grad()) {
|
||||
out.requires_grad_(true);
|
||||
}
|
||||
auto out = torch::_nested_tensor_from_tensor_list(
|
||||
new_list, dtype, c10::nullopt, device, pin_memory);
|
||||
if (requires_grad.has_value() && requires_grad.value()) {
|
||||
return out;
|
||||
}
|
||||
|
||||
inline at::Tensor nested_tensor(
|
||||
at::ArrayRef<detail::TensorDataContainer> nested_tensor_data,
|
||||
const at::TensorOptions& options = {}) {
|
||||
for (const auto& tdc : nested_tensor_data) {
|
||||
TORCH_CHECK(
|
||||
tdc.is_init_list(),
|
||||
"nested_tensor() not implemented for these parameters");
|
||||
}
|
||||
// Construct a TensorList using nested_tensor_data
|
||||
std::vector<at::Tensor> tensor_list(nested_tensor_data.size());
|
||||
std::transform(
|
||||
nested_tensor_data.begin(),
|
||||
nested_tensor_data.end(),
|
||||
tensor_list.begin(),
|
||||
[&](const detail::TensorDataContainer& tdc) {
|
||||
return tdc.convert_to_tensor(options);
|
||||
});
|
||||
auto out = at::_nested_tensor_from_tensor_list(
|
||||
tensor_list,
|
||||
c10::typeMetaToScalarType(options.dtype()),
|
||||
c10::nullopt,
|
||||
options.device(),
|
||||
options.pinned_memory());
|
||||
if (options.has_requires_grad() && options.requires_grad()) {
|
||||
out.requires_grad_(true);
|
||||
}
|
||||
return out;
|
||||
|
|
@ -36,10 +70,10 @@ inline Tensor nested_tensor(
|
|||
/// https://pytorch.org/docs/master/nested.html#torch.nested.as_nested_tensor
|
||||
///
|
||||
/// ```
|
||||
inline Tensor as_nested_tensor(
|
||||
TensorList list,
|
||||
c10::optional<ScalarType> dtype = c10::nullopt,
|
||||
c10::optional<Device> device = c10::nullopt) {
|
||||
inline at::Tensor as_nested_tensor(
|
||||
at::TensorList list,
|
||||
c10::optional<at::ScalarType> dtype = c10::nullopt,
|
||||
c10::optional<at::Device> device = c10::nullopt) {
|
||||
return at::_nested_tensor_from_tensor_list(
|
||||
list, dtype, c10::nullopt, device, c10::nullopt);
|
||||
}
|
||||
|
|
@ -50,11 +84,11 @@ inline Tensor as_nested_tensor(
|
|||
/// https://pytorch.org/docs/master/nested.html#torch.nested.to_padded_tensor
|
||||
///
|
||||
/// ```
|
||||
inline Tensor to_padded_tensor(
|
||||
const Tensor& self,
|
||||
inline at::Tensor to_padded_tensor(
|
||||
const at::Tensor& self,
|
||||
double padding,
|
||||
OptionalIntArrayRef output_size = c10::nullopt) {
|
||||
return torch::nested_to_padded_tensor(self, padding, output_size);
|
||||
at::OptionalIntArrayRef output_size = c10::nullopt) {
|
||||
return at::nested_to_padded_tensor(self, padding, output_size);
|
||||
}
|
||||
|
||||
} // namespace nested
|
||||
|
|
|
|||
|
|
@ -3,7 +3,9 @@
|
|||
namespace torch {
|
||||
namespace autograd {
|
||||
|
||||
PyMethodDef* get_nested_functions_manual();
|
||||
|
||||
void initNestedFunctions(PyObject* module);
|
||||
|
||||
}
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
|
|
|
|||
44
torch/csrc/autograd/python_nested_functions_manual.cpp
Normal file
44
torch/csrc/autograd/python_nested_functions_manual.cpp
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
#include <torch/csrc/utils/nested.h>
|
||||
#include <torch/csrc/utils/pycfunction_helpers.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
#include <torch/torch.h>
|
||||
|
||||
namespace torch {
|
||||
namespace autograd {
|
||||
|
||||
static PyObject* THPVariable_nested_tensor(
|
||||
PyObject* /*self*/,
|
||||
PyObject* args,
|
||||
PyObject* kwargs) {
|
||||
HANDLE_TH_ERRORS
|
||||
static PythonArgParser parser({
|
||||
"nested_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False)",
|
||||
});
|
||||
|
||||
constexpr int ctor_num_args = 5;
|
||||
ParsedArgs<ctor_num_args> parsed_args;
|
||||
auto r = parser.parse(args, kwargs, parsed_args);
|
||||
|
||||
jit::tracer::warn(
|
||||
"torch.nested.nested_tensor", jit::tracer::WARN_CONSTRUCTOR);
|
||||
return THPVariable_Wrap(torch::utils::nested_tensor_ctor(
|
||||
torch::tensors::get_default_dispatch_key(),
|
||||
torch::tensors::get_default_scalar_type(),
|
||||
r));
|
||||
END_HANDLE_TH_ERRORS
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(cppcoreguidelines-avoid-c-arrays,modernize-avoid-c-arrays)
|
||||
static PyMethodDef nested_functions_manual[] = {
|
||||
{"nested_tensor",
|
||||
castPyCFunctionWithKeywords(THPVariable_nested_tensor),
|
||||
METH_VARARGS | METH_KEYWORDS,
|
||||
nullptr},
|
||||
};
|
||||
|
||||
PyMethodDef* get_nested_functions_manual() {
|
||||
return nested_functions_manual;
|
||||
}
|
||||
|
||||
} // namespace autograd
|
||||
} // namespace torch
|
||||
91
torch/csrc/utils/nested.cpp
Normal file
91
torch/csrc/utils/nested.cpp
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
#include <ATen/ATen.h>
|
||||
#include <ATen/NestedTensorImpl.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/nested.h>
|
||||
#include <torch/csrc/utils/pybind.h>
|
||||
#include <torch/csrc/utils/tensor_new.h>
|
||||
#include <torch/torch.h>
|
||||
#include <stdexcept>
|
||||
#include <vector>
|
||||
|
||||
namespace torch {
|
||||
namespace utils {
|
||||
|
||||
// NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs
|
||||
c10::TensorOptions typeIdWithDefault(
|
||||
PythonArgs& r,
|
||||
int device_idx,
|
||||
c10::DispatchKey dispatch_key) {
|
||||
auto options = dispatchKeyToTensorOptions(dispatch_key);
|
||||
if (!r.isNone(device_idx)) {
|
||||
options = options.device(r.device(device_idx));
|
||||
}
|
||||
return options;
|
||||
}
|
||||
|
||||
at::Tensor nested_tensor_ctor(
|
||||
c10::DispatchKey dispatch_key,
|
||||
at::ScalarType scalar_type,
|
||||
torch::PythonArgs& r) {
|
||||
TORCH_CHECK(r.idx == 0, "nested_tensor(): invalid arguments");
|
||||
|
||||
PyObject* data = r.pyobject(0);
|
||||
// Check if data is a list: Only List[Tensor] and List[List...[Scalar]] are
|
||||
// accepted for now
|
||||
TORCH_CHECK_TYPE(
|
||||
PyList_Check(data),
|
||||
"Only lists (List[Tensor] and List[List...[Scalar]]) are accepted in nested_tensor");
|
||||
|
||||
auto dtype_val = r.scalartypeWithDefault(1, scalar_type);
|
||||
auto tensor_options = typeIdWithDefault(r, 2, dispatch_key);
|
||||
bool pin_memory = r.toBool(3);
|
||||
bool args_requires_grad = r.toBool(4);
|
||||
|
||||
TORCH_CHECK(
|
||||
PyList_Size(data) >= 0,
|
||||
"Something went really wrong and your list has negative size");
|
||||
|
||||
// Check whether we are dealing with lists of tensors or not
|
||||
std::vector<at::Tensor> new_list(PyList_Size(data));
|
||||
for (const auto i : c10::irange(PyList_Size(data))) {
|
||||
PyObject* elem = PyList_GetItem(data, i);
|
||||
if (THPVariable_Check(elem)) {
|
||||
new_list[i] = THPVariable_Unpack(PyList_GetItem(data, i)).detach();
|
||||
TORCH_CHECK(
|
||||
!new_list[i].is_nested(),
|
||||
"We do not accept nested tensors as input to nested tensors");
|
||||
TORCH_CHECK(
|
||||
new_list[i].layout() == kStrided,
|
||||
"We do not accept non-strided layouts as input to nested tensors");
|
||||
} else {
|
||||
PythonArgs elem_r(r);
|
||||
std::array<PyObject*, 6> elem_args = {
|
||||
elem, // data
|
||||
r.args[1], // dtpye
|
||||
nullptr, // device (cpu)
|
||||
nullptr, // no pinned memory
|
||||
r.args[4], // requires grad
|
||||
nullptr // names
|
||||
};
|
||||
elem_r.args = elem_args.data();
|
||||
new_list[i] = tensor_ctor(dispatch_key, scalar_type, elem_r);
|
||||
}
|
||||
}
|
||||
|
||||
at::ScalarType final_dtype = dtype_val;
|
||||
if (r.isNone(1) && new_list.size() > 0) {
|
||||
final_dtype = c10::typeMetaToScalarType(new_list[0].dtype());
|
||||
}
|
||||
at::Device final_device = tensor_options.device();
|
||||
if (r.isNone(2) && new_list.size() > 0) {
|
||||
final_device = new_list[0].device();
|
||||
}
|
||||
auto out = at::_nested_tensor_from_tensor_list(
|
||||
new_list, final_dtype, c10::nullopt, final_device, pin_memory);
|
||||
out.requires_grad_(args_requires_grad);
|
||||
return out;
|
||||
}
|
||||
|
||||
} // namespace utils
|
||||
} // namespace torch
|
||||
17
torch/csrc/utils/nested.h
Normal file
17
torch/csrc/utils/nested.h
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
#pragma once
|
||||
|
||||
#include <torch/csrc/python_headers.h>
|
||||
#include <torch/csrc/utils/python_arg_parser.h>
|
||||
|
||||
#include <ATen/core/Tensor.h>
|
||||
|
||||
namespace torch {
|
||||
namespace utils {
|
||||
|
||||
at::Tensor nested_tensor_ctor(
|
||||
c10::DispatchKey dispatch_key,
|
||||
at::ScalarType scalar_type,
|
||||
PythonArgs& r);
|
||||
|
||||
} // namespace utils
|
||||
} // namespace torch
|
||||
|
|
@ -1,55 +1,25 @@
|
|||
from typing import List, Optional
|
||||
import torch
|
||||
from torch._C import _add_docstr, _nested # type: ignore[attr-defined]
|
||||
from torch import Tensor
|
||||
|
||||
from torch.types import _dtype as DType
|
||||
from torch.types import _device as Device
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch._C import _add_docstr, _nested # type: ignore[attr-defined]
|
||||
|
||||
from torch.types import _device as Device, _dtype as DType
|
||||
|
||||
__all__ = [
|
||||
'to_padded_tensor',
|
||||
'as_nested_tensor',
|
||||
'nested_tensor',
|
||||
"to_padded_tensor",
|
||||
"as_nested_tensor",
|
||||
"nested_tensor",
|
||||
]
|
||||
|
||||
# Nested Tensor constructor functions
|
||||
# TODO: move these to pybind to accept numpy/nested lists as inputs in the future
|
||||
def nested_tensor(tensor_list: List[Tensor], *, dtype: Optional[DType] = None, device: Optional[Device] = None,
|
||||
requires_grad: Optional[bool] = False, pin_memory: Optional[bool] = False) -> Tensor:
|
||||
r"""
|
||||
Constructs a nested tensor with no autograd history (also known as a “leaf tensor”, see
|
||||
:ref:`Autograd mechanics <autograd-mechanics>`) from :attr:`tensor_list` a list of tensors.
|
||||
|
||||
Args:
|
||||
tensor_list (List[Tensor]): a list of tensors with the same ndim
|
||||
|
||||
Keyword arguments:
|
||||
dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
|
||||
Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
|
||||
device (:class:`torch.device`, optional): the desired device of returned nested tensor.
|
||||
Default: if None, same :class:`torch.device` as leftmost tensor in the list
|
||||
requires_grad (bool, optional): If autograd should record operations on the
|
||||
returned nested tensor. Default: ``False``.
|
||||
pin_memory (bool, optional): If set, returned nested tensor would be allocated in
|
||||
the pinned memory. Works only for CPU tensors. Default: ``False``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
|
||||
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
|
||||
>>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
|
||||
>>> nt.is_leaf
|
||||
True
|
||||
"""
|
||||
if not isinstance(tensor_list, list) or any([not torch.is_tensor(t) for t in tensor_list]):
|
||||
raise TypeError("nested_tensor(): Expected first argument to be a list of tensors ")
|
||||
new_data = [t.detach() for t in tensor_list]
|
||||
nt = torch._nested_tensor_from_tensor_list(new_data, dtype, None, device, pin_memory)
|
||||
if (requires_grad):
|
||||
nt.requires_grad_(requires_grad)
|
||||
return nt
|
||||
|
||||
def as_nested_tensor(tensor_list: List[Tensor], dtype: Optional[DType] = None, device: Optional[Device] = None) -> Tensor:
|
||||
def as_nested_tensor(
|
||||
tensor_list: List[Tensor],
|
||||
dtype: Optional[DType] = None,
|
||||
device: Optional[Device] = None,
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Constructs a nested tensor preserving autograd history from :attr:`tensor_list` a list of tensors.
|
||||
|
||||
|
|
@ -79,15 +49,21 @@ def as_nested_tensor(tensor_list: List[Tensor], dtype: Optional[DType] = None, d
|
|||
>>> b.grad
|
||||
tensor([0., 0., 0., 0., 0.])
|
||||
"""
|
||||
if not isinstance(tensor_list, list) or any([not torch.is_tensor(t) for t in tensor_list]):
|
||||
raise TypeError("nested_tensor(): Expected first argument to be a list of tensors ")
|
||||
if not isinstance(tensor_list, list) or any(
|
||||
[not torch.is_tensor(t) for t in tensor_list]
|
||||
):
|
||||
raise TypeError(
|
||||
"nested_tensor(): Expected first argument to be a list of tensors "
|
||||
)
|
||||
return torch._nested_tensor_from_tensor_list(tensor_list, dtype, None, device, None)
|
||||
|
||||
|
||||
# Note: This not only adds doc strings for the nested ops, but
|
||||
# also connects the torch.nested Python namespace to the torch._C._nested builtins.
|
||||
|
||||
to_padded_tensor = _add_docstr(_nested.nested_to_padded_tensor,
|
||||
r"""
|
||||
to_padded_tensor = _add_docstr(
|
||||
_nested.nested_to_padded_tensor,
|
||||
r"""
|
||||
to_padded_tensor(input, padding, output_size=None, out=None) -> Tensor
|
||||
|
||||
Returns a new (non-nested) Tensor by padding the :attr:`input` nested tensor.
|
||||
|
|
@ -137,4 +113,37 @@ Example::
|
|||
>>> pt_small = torch.nested.to_padded_tensor(nt, 2.0, (2, 2, 2))
|
||||
RuntimeError: Value in output_size is less than NestedTensor padded size. Truncation is not supported.
|
||||
|
||||
""")
|
||||
""",
|
||||
)
|
||||
|
||||
nested_tensor = _add_docstr(
|
||||
_nested.nested_tensor,
|
||||
r"""
|
||||
nested_tensor(tensor_list, *, dtype=None, device=None, requires_grad=False, pin_memory=False) -> Tensor
|
||||
|
||||
Constructs a nested tensor with no autograd history (also known as a “leaf tensor”, see
|
||||
:ref:`Autograd mechanics <autograd-mechanics>`) from :attr:`tensor_list` a list of tensors.
|
||||
|
||||
Args:
|
||||
tensor_list (List[array_like]): a list of tensors (or anything that can be passed to torch.tensor)
|
||||
where their first dimension can be of irregular size, but all other dimensions have to be equal.
|
||||
|
||||
Keyword arguments:
|
||||
dtype (:class:`torch.dtype`, optional): the desired type of returned nested tensor.
|
||||
Default: if None, same :class:`torch.dtype` as leftmost tensor in the list.
|
||||
device (:class:`torch.device`, optional): the desired device of returned nested tensor.
|
||||
Default: if None, same :class:`torch.device` as leftmost tensor in the list
|
||||
requires_grad (bool, optional): If autograd should record operations on the
|
||||
returned nested tensor. Default: ``False``.
|
||||
pin_memory (bool, optional): If set, returned nested tensor would be allocated in
|
||||
the pinned memory. Works only for CPU tensors. Default: ``False``.
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.arange(3, dtype=torch.float, requires_grad=True)
|
||||
>>> b = torch.arange(5, dtype=torch.float, requires_grad=True)
|
||||
>>> nt = torch.nested.nested_tensor([a, b], requires_grad=True)
|
||||
>>> nt.is_leaf
|
||||
True
|
||||
""",
|
||||
)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user