Nested tensor forward only chunk operations (#85645)
# Summary
Taking over this PR: https://github.com/pytorch/pytorch/pull/83736

Adds forward-only support for `chunk` on nested tensors (no autograd support yet).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85645
Approved by: https://github.com/cpuhrsch
This commit is contained in:
parent 4fc0d5341c
commit 16f65f178a
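For context, a minimal usage sketch of what this change enables, adapted from the new test added below; the view/contiguity behavior and the failure modes are taken from this diff, and the script assumes a build that includes this commit:

```python
import torch

# Two samples with ragged leading dimensions but a consistent last dimension of 12.
a = torch.randn(3, 12)
b = torch.randn(2, 12)
nt = torch.nested.nested_tensor([a, b])

# Forward-only chunk along the last (consistent) dimension: 12 / 3 = 4 per chunk.
chunks = nt.chunk(3, dim=-1)
print(len(chunks))                             # 3
print(chunks[0].is_contiguous())               # False: each chunk is a view into nt's buffer
print(chunks[0].contiguous().is_contiguous())  # True after materializing a copy

# Chunking the ragged or batch dimension, or asking for a non-divisible split, raises a
# RuntimeError, and backward through a chunk is not implemented (forward only).
```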
@@ -1280,6 +1280,9 @@
   variants: function, method
   device_check: NoCheck
   device_guard: False
+  dispatch:
+    CompositeImplicitAutograd: chunk
+    NestedTensorCPU, NestedTensorCUDA: chunk_nested_tensor
 
 - func: tensor_split.sections(Tensor(a -> *) self, int sections, int dim=0) -> Tensor(a)[]
   variants: function, method
@@ -87,7 +87,9 @@ std::vector<at::Tensor> NestedTensor_unbind(
 }
 
 Tensor& NestedTensor_relu_(Tensor& self) {
-  auto buffer = get_nested_tensor_impl(self)->get_buffer();
+  auto self_ptr = get_nested_tensor_impl(self);
+  check_numel_equals_buffer_size(self_ptr);
+  auto buffer = self_ptr->get_buffer();
   at::relu_(buffer);
   return self;
 }
@@ -97,7 +99,9 @@ Tensor NestedTensor_relu(const Tensor& self) {
 }
 
 Tensor& NestedTensor_gelu_(Tensor& self, c10::string_view approximate) {
-  auto buffer = get_nested_tensor_impl(self)->get_buffer();
+  auto self_ptr = get_nested_tensor_impl(self);
+  check_numel_equals_buffer_size(self_ptr);
+  auto buffer = self_ptr->get_buffer();
   at::gelu_(buffer, approximate);
   return self;
 }
@@ -712,11 +716,11 @@ Tensor clone_nested(
   else if (memory_format == c10::MemoryFormat::Contiguous) {
     const Tensor& self_buffer = self_ptr->get_unsafe_storage_as_tensor(),
         sizemat = self_ptr->get_nested_size_tensor();
-    Tensor output_buffer = at::empty_like(self_buffer);
+    Tensor output_buffer = at::empty(self.numel(), self_buffer.options());
     Tensor output = wrap_buffer(output_buffer, sizemat);
     std::vector<Tensor> self_unbind = self.unbind(),
        output_unbind = output.unbind();
-    for (int64_t i = 0; i < self_ptr->size(0); i++) {
+    for (const int64_t i: c10::irange(self_ptr->size(0))) {
      output_unbind[i].copy_(self_unbind[i]);
     }
     return output;
@@ -213,6 +213,7 @@ Tensor NestedTensor_batch_offsets_from_size_tensor(
 
 Tensor NestedTensor_to_mask(const Tensor& nt, c10::optional<int64_t> mask_dim, c10::optional<int64_t> mask_dim_length) {
   auto* nt_impl = get_nested_tensor_impl(nt);
+  TORCH_CHECK(nested_tensor_impl_is_contiguous(nt_impl), "to_mask only works on contiguous NestedTensors.");
   TORCH_CHECK(
       !mask_dim || *mask_dim < nt.dim(),
       "Requested mask dimension ",
@@ -56,6 +56,51 @@ int64_t get_consistent_last_dim_of_nested_tensor(const NestedTensorImpl& nt) {
       nt.get_nested_size_tensor().select(1, -1));
   return *last_dim;
 }
 
+std::vector<Tensor> chunk_nested_tensor(const Tensor& self, int64_t chunks, int64_t dim) {
+  int64_t ndim = self.dim();
+  if (ndim == 0) {
+    TORCH_CHECK_INDEX(false, "chunk() cannot be applied to a 0-dim tensor.");
+  }
+  dim = maybe_wrap_dim(dim, ndim);
+  TORCH_CHECK(self.dim() - 1 == dim,
+      "Chunk for nested tensors is currently only supported for the last dimension.");
+  TORCH_CHECK(chunks > 0, "chunk expects `chunks` to be greater than 0, got: ", chunks);
+  TORCH_CHECK(self.is_contiguous(), "chunk expects `self` to be contiguous.");
+  auto self_impl = get_nested_tensor_impl(self);
+  const int64_t last_dim_size = get_consistent_last_dim_of_nested_tensor(*self_impl);
+  TORCH_CHECK(last_dim_size % chunks == 0,
+      "Chunk for nested tensors is only supported for nested tensors with trailing dimension divisible by chunks, got: ",
+      last_dim_size, " % ", chunks, " != 0");
+  int64_t n_tensors = self.size(0);
+  int64_t split_size = last_dim_size / chunks;
+  std::vector<Tensor> splits(chunks);
+  const auto& sizes = self_impl->get_nested_size_tensor();
+  const auto& strides = self_impl->get_nested_stride_tensor();
+  const std::vector<int64_t>& offsets = self_impl->get_storage_offsets();
+  // Account for the implicit batch dim
+  --dim;
+  int64_t tensor_dim = sizes.size(1);
+  for (const auto split_idx : c10::irange(chunks)) {
+    auto new_sizes = sizes.clone();
+    auto new_strides = strides.clone();
+    // This copies offsets so we are safe to move
+    auto new_offsets = std::vector<int64_t>(offsets);
+    int64_t* size_ptr = new_sizes.data_ptr<int64_t>();
+    int64_t* stride_ptr = new_strides.data_ptr<int64_t>();
+    // Get start val for each split
+    int64_t start_val = split_idx * split_size;
+    for (int64_t i : c10::irange(n_tensors)) {
+      const int64_t index = i * tensor_dim + dim;
+      new_offsets[i] = offsets[i] + start_val * stride_ptr[index];
+      size_ptr[index] = split_size;
+      stride_ptr[index] *= 1;
+    }
+    splits[split_idx] = create_nested_view_tensor(self, new_sizes, new_strides, std::move(new_offsets));
+  }
+  return splits;
+}
+
 std::vector<IntArrayRef> NestedTensor_get_sizes(
     const NestedTensorImpl* self_ptr) {
   int64_t ntensors = self_ptr->size(0);
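To illustrate the metadata arithmetic in the new kernel (not part of the commit): each split reuses the per-tensor strides, shrinks the last-dimension size to `split_size`, and shifts each constituent's storage offset by `split_idx * split_size * stride[last]`. The sketch below mimics that bookkeeping with `as_strided` on a regular dense tensor; the helper name `last_dim_chunk` is mine, and this is only an analogy for a single constituent of the nested tensor.

```python
import torch

def last_dim_chunk(t: torch.Tensor, chunks: int):
    # Same strides, last-dim size shrunk to split_size, offset shifted per split.
    assert t.is_contiguous()
    last = t.size(-1)
    assert last % chunks == 0, "trailing dimension must be divisible by chunks"
    split_size = last // chunks
    sizes = list(t.size())
    sizes[-1] = split_size
    strides = list(t.stride())
    out = []
    for split_idx in range(chunks):
        start_val = split_idx * split_size
        offset = t.storage_offset() + start_val * strides[-1]
        out.append(t.as_strided(sizes, strides, offset))  # a view, no copy
    return out

x = torch.arange(24.).reshape(2, 12)
views = last_dim_chunk(x, 3)
print(torch.equal(views[1], x[:, 4:8]))  # True
```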
@@ -126,6 +126,19 @@ inline std::vector<IntArrayRef> NestedTensor_get_strides(
   const NestedTensorImpl* self_ptr = get_nested_tensor_impl(self);
   return NestedTensor_get_strides(self_ptr);
 }
 
+inline void check_numel_equals_buffer_size(const at::Tensor& self) {
+  auto self_impl = get_nested_tensor_impl(self);
+  TORCH_CHECK(
+      self.numel() == self_impl->get_buffer_size(),
+      "Number of elements in nested tensor must match number of elements in buffer.");
+}
+
+inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) {
+  TORCH_CHECK(
+      self_ptr->numel() == self_ptr->get_buffer_size(),
+      "Number of elements in nested tensor must match number of elements in buffer.");
+}
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 // Data structures and functions for generically applying a function on a nested tensor.
 namespace impl {
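The new check matters because NestedTensor_relu_ and NestedTensor_gelu_ (above) operate on the underlying buffer directly; after this change they reject views, such as the chunks produced by this PR, whose numel no longer covers the whole buffer. A hedged sketch of the expected behavior, not a test from this commit:

```python
import torch

nt = torch.nested.nested_tensor([torch.randn(2, 20), torch.randn(4, 20)])
first_chunk = nt.chunk(5, dim=-1)[0]  # a view covering only 1/5 of nt's buffer

# Expected per the check added in this diff: in-place relu_ on the view is rejected
# rather than being applied to the entire shared buffer.
try:
    first_chunk.relu_()
except RuntimeError as e:
    print(e)  # should mention the numel/buffer-size mismatch from the new check
```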
@@ -618,6 +618,67 @@ class TestNestedTensorDeviceType(TestCase):
         answer = torch.tensor(200.0, device=device, dtype=dtype).expand(4)
         self.assertEqual(nt[1, 1, :], answer)
 
+        # Test that indexing works when requires_grad_(True)
+        # previously this was failing because the backward kernel for select.int uses .sizes()
+        nt = torch.nested.nested_tensor([x0, x1]).requires_grad_(True)
+        self.assertEqual(nt[0], x0)
+        self.assertEqual(nt[-1], x1)
+        grad_x0 = torch.randn((2, 5), device=device, dtype=dtype)
+        nt[0].backward(grad_x0)
+        expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4), device=device, dtype=dtype)])
+        self.assertEqual(nt.grad, expected_grad)
+
+    @dtypes(*floating_types_and_half())
+    def test_nested_tensor_chunk(self, device, dtype):
+        # Transformer use case
+        a = torch.randn(3, 3 * 4, device=device, dtype=dtype)
+        b = torch.randn(2, 3 * 4, device=device, dtype=dtype)
+        c = torch.randn(1, 3 * 4, device=device, dtype=dtype)
+        a_chunks = a.chunk(3, dim=-1)
+        b_chunks = b.chunk(3, dim=-1)
+        c_chunks = c.chunk(3, dim=-1)
+
+        a_nt = [a_chunks[0], b_chunks[0], c_chunks[0]]
+        b_nt = [a_chunks[1], b_chunks[1], c_chunks[1]]
+        c_nt = [a_chunks[2], b_chunks[2], c_chunks[2]]
+
+        nt = torch.nested.nested_tensor([a, b, c])
+        chunked = nt.chunk(3, dim=-1)
+
+        self.assertEqual(chunked[0], torch.nested.nested_tensor(a_nt))
+        self.assertEqual(chunked[1], torch.nested.nested_tensor(b_nt))
+        self.assertEqual(chunked[2], torch.nested.nested_tensor(c_nt))
+
+        for chunk in chunked:
+            self.assertFalse(chunk.is_contiguous())
+
+        # Failure when chunking on the ragged or batch dimension
+        self.assertRaisesRegex(
+            RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.",
+            lambda: torch.chunk(nt, 5, dim=1))
+        self.assertRaisesRegex(
+            RuntimeError, "Chunk for nested tensors is currently only supported for the last dimension.",
+            lambda: torch.chunk(nt, 5, dim=0))
+
+        # Failure on a non-contiguous nt
+        _, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3), device, dtype)
+        self.assertRaisesRegex(
+            RuntimeError, "chunk expects `self` to be contiguous.", lambda: torch.chunk(nt_noncontiguous, 5, dim=-1))
+
+        # Failure when chunks does not evenly divide the trailing dimension
+        self.assertRaisesRegex(
+            RuntimeError, "Chunk for nested tensors is only supported for "
+            "nested tensors with trailing dimension divisible by chunks.",
+            lambda: torch.chunk(nt, 5, dim=-1))
+
+        # Failure when calling backward on a chunk
+        a = torch.randn(3, 3 * 4, device=device, dtype=dtype, requires_grad=True)
+        b = torch.randn(2, 3 * 4, device=device, dtype=dtype, requires_grad=True)
+        nt_grad = torch.nested.as_nested_tensor([a, b])
+        chunked = torch.chunk(nt_grad, 2, dim=-1)
+        self.assertRaisesRegex(RuntimeError, "derivative for aten::chunk is not implemented",
+                               lambda: chunked[0].backward(chunked[0].clone()))
+
     @dtypes(torch.float, torch.float16, torch.double)
     @torch.inference_mode()
     def test_nested_tensor_indexing_noncontiguous(self, device, dtype):
@@ -758,6 +819,24 @@ class TestNestedTensorDeviceType(TestCase):
         with self.assertRaisesRegex(RuntimeError, "NestedTensor always requires keepdim=True for now."):
             torch.nested.nested_tensor([torch.tensor([3, 4, 5]), torch.tensor([1, 2])]).sum(-1)
 
+    @dtypes(torch.float, torch.float16)
+    def test_contiguous(self, device, dtype):
+        # Since we don't have access to the buffer in python this is harder to show what
+        # we are testing for. When we call chunk on a consistent dim of a NT
+        # for chunk_size > 1 the resulting tensors are views of the original NT
+        # whose numel is now less than the size of the buffer. Clone was
+        # previously creating a new NT with a buffer that was the same size as the
+        # original.
+        nt_contiguous = torch.nested.nested_tensor([torch.randn(2, 20, device=device, dtype=dtype),
+                                                    torch.randn(4, 20, device=device, dtype=dtype)])
+        # Split up the last dimension, which has a consistent size of 20, into 5 chunks
+        chunks = nt_contiguous.chunk(5, dim=-1)
+
+        # Check that the chunks are not contiguous, but become contiguous after calling contiguous()
+        for chunk in chunks:
+            self.assertFalse(chunk.is_contiguous())
+            self.assertTrue(chunk.contiguous().is_contiguous())
+
     @dtypes(torch.float, torch.float16)
     @skipMeta
     def test_clone(self, device, dtype):