Support non-contiguous NestedTensors for elementwise ops (#87888)
Enables benchmarking of the math path of the SDP kernel.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/87888
Approved by: https://github.com/drisspg
This commit is contained in:
parent f150e70ca2
commit b192e7e415
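For context, a minimal user-level sketch of what the change enables (illustrative shapes and values only, not part of the patch):

    import torch

    # Two nested tensors with matching nested sizes.
    a = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(5, 4)])
    b = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(5, 4)])

    # Transposing the non-batch dims yields non-contiguous nested tensors; with this
    # patch, elementwise ops such as div work as long as strides and offsets match.
    out = a.transpose(1, 2) / b.transpose(1, 2)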
@@ -1881,6 +1881,7 @@
   dispatch:
     SparseCPU, SparseCUDA: div_sparse
     ZeroTensor: div_zerotensor
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Tensor
   tags: canonical
 
 - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
@@ -1928,6 +1929,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: div
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Scalar
   tags: canonical
 
 - func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
@@ -500,10 +500,21 @@ get_elementwise_nested_tensor_impl(
       op_name,
       " does not support broadcasting when given a NestedTensor");
   TORCH_CHECK(
-      nested_tensor_impl_is_contiguous(self_ptr) &&
-          nested_tensor_impl_is_contiguous(other_ptr),
+      at::equal(
+          self_ptr->get_nested_stride_tensor(),
+          other_ptr->get_nested_stride_tensor()),
       op_name,
-      " does not support non-contiguous NestedTensor inputs");
+      " requires strides to match when given NestedTensors");
+  auto self_offsets = self_ptr->get_storage_offsets();
+  auto other_offsets = other_ptr->get_storage_offsets();
+  bool offsets_match = true;
+  for (size_t i = 0; i < self_offsets.size(); i++) {
+    offsets_match = offsets_match && (self_offsets[i] == other_offsets[i]);
+  }
+  TORCH_CHECK(
+      offsets_match,
+      op_name,
+      " requires offsets to match when given NestedTensors");
   return std::make_pair(self_ptr, other_ptr);
 }
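A rough illustration of the two new failure modes from the Python side (a sketch mirroring the tests added below; values are arbitrary):

    import torch

    nt = torch.nested.nested_tensor([torch.randn(i, 4) for i in (3, 4, 5)])
    left, right = nt.chunk(2, -1)  # views of one buffer with different storage offsets
    # left / right  -> RuntimeError: "div requires offsets to match when given NestedTensors"

    nt_t = torch.nested.nested_tensor([t.t() for t in nt.unbind()])
    # nt_t.transpose(1, 2) / nt  -> RuntimeError: "div requires strides to match when given NestedTensors"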
@@ -517,16 +528,20 @@ Tensor NestedTensor_elementwise_Tensor(
   if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) {
     auto other_impl = get_nested_tensor_impl(other);
     return wrap_buffer(
-      f(self, other_impl->get_buffer()),
-      other_impl->get_nested_size_tensor().clone()
+      f(self, other_impl->get_unsafe_storage_as_tensor()),
+      other_impl->get_nested_size_tensor().clone(),
+      other_impl->get_nested_stride_tensor().clone(),
+      other_impl->get_storage_offsets()
     );
   }
   // other is a scalar
   if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) {
     auto self_impl = get_nested_tensor_impl(self);
     return wrap_buffer(
-      f(self_impl->get_buffer(), other),
-      self_impl->get_nested_size_tensor().clone()
+      f(self_impl->get_unsafe_storage_as_tensor(), other),
+      self_impl->get_nested_size_tensor().clone(),
+      self_impl->get_nested_stride_tensor().clone(),
+      self_impl->get_storage_offsets()
     );
   }
   NestedTensorImpl* self_impl = nullptr;
@@ -535,13 +550,12 @@ Tensor NestedTensor_elementwise_Tensor(
       get_elementwise_nested_tensor_impl(self, other, op_name);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl);
-  const auto& nt_self = *self_impl;
-  const auto& nt_other = *other_impl;
-  const auto& self_sizes = nt_self.get_nested_size_tensor();
   return wrap_buffer(
-      f(nt_self.get_buffer().reshape({-1}),
-        nt_other.get_buffer().reshape({-1})),
-      self_sizes);
+      f(self_impl->get_unsafe_storage_as_tensor(),
+        other_impl->get_unsafe_storage_as_tensor()),
+      self_impl->get_nested_size_tensor(),
+      self_impl->get_nested_stride_tensor(),
+      self_impl->get_storage_offsets());
 }
 
 Tensor NestedTensor_add_Tensor(
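Since the result is re-wrapped with self's nested size, stride, and offset metadata instead of a reshaped contiguous buffer, the output of an elementwise op on non-contiguous inputs should keep their layout. A small sketch to observe this (assumes Tensor.is_contiguous() is queryable on nested tensors):

    import torch

    a = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
    b = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
    out = a.transpose(1, 2) / b.transpose(1, 2)
    print(out.is_contiguous())  # expected: False, the transposed layout is carried over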
@@ -566,6 +580,18 @@ Tensor NestedTensor_mul_Scalar(const Tensor& self, const Scalar& other) {
   return NestedTensor_mul_Tensor(self, wrapped_scalar_tensor(other));
 }
 
+Tensor NestedTensor_div_Tensor(const Tensor& self, const Tensor& other) {
+  return NestedTensor_elementwise_Tensor(
+      self, other, "div", [](const Tensor& b1, const Tensor& b2) {
+        return at::div(b1, b2);
+      });
+}
+
+// Only usable on the C++ side; scalars are converted to tensors coming from Python.
+Tensor NestedTensor_div_Scalar(const Tensor& self, const Scalar& other) {
+  return NestedTensor_div_Tensor(self, wrapped_scalar_tensor(other));
+}
+
 template <typename Func>
 Tensor& NestedTensor_elementwise__Tensor(
     Tensor& self,
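To exercise these entry points from Python (a sketch; per the comment above, a Python scalar divisor is converted to a tensor before dispatch, so both spellings reach the Tensor overload):

    import torch

    nt = torch.nested.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
    out1 = nt / 4.0                           # scalar converted to a 0-dim tensor
    out2 = torch.div(nt, torch.tensor(4.0))   # explicit 0-dim dense divisor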
@@ -1,7 +1,7 @@
 #pragma once
 
-#include <ATen/core/Tensor.h>
 #include <ATen/NestedTensorImpl.h>
+#include <ATen/core/Tensor.h>
 #include <c10/core/DispatchKeySet.h>
 #include <c10/core/TensorImpl.h>
 #include <c10/macros/Macros.h>
@@ -50,6 +50,18 @@ inline at::Tensor wrap_buffer(
       std::move(offsets));
 }
 
+inline at::Tensor wrap_buffer(
+    at::Tensor buffer,
+    at::Tensor nested_size_tensor,
+    at::Tensor nested_stride_tensor,
+    const std::vector<int64_t>& offsets) {
+  std::vector<int64_t> offsets_copy(offsets);
+  return wrap_buffer(buffer,
+      nested_size_tensor,
+      nested_stride_tensor,
+      std::move(offsets_copy));
+}
+
 inline at::Tensor get_buffer(const at::Tensor& tensor) {
   return get_nested_tensor_impl(tensor)->get_buffer();
 }
@@ -119,7 +131,6 @@ inline std::vector<IntArrayRef> NestedTensor_get_sizes(
   return sizes;
 }
 
-
 TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
     const NestedTensorImpl& nt);
 
@@ -161,17 +172,18 @@ inline std::vector<IntArrayRef> NestedTensor_get_strides(
 inline void check_numel_equals_buffer_size(const at::Tensor& self) {
   auto self_impl = get_nested_tensor_impl(self);
   TORCH_CHECK(
-      self.numel() == self_impl -> get_buffer_size(),
+      self.numel() == self_impl->get_buffer_size(),
       "Number of elements in nested tensor must match number of elements in buffer.");
 }
 
 inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) {
   TORCH_CHECK(
-      self_ptr-> numel() == self_ptr -> get_buffer_size(),
+      self_ptr->numel() == self_ptr->get_buffer_size(),
       "Number of elements in nested tensor must match number of elements in buffer.");
 }
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// Data structures and functions for generically applying a function on a nested tensor.
+// Data structures and functions for generically applying a function on a nested
+// tensor.
 namespace impl {
 
 template <typename T>
@@ -703,7 +703,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
   auto attn_mask = attn_mask_;
   // Naive, composite implementation defined here.
   const auto embed_size = query_.size(-1);
-  const auto query = query_ * (1. / ::sqrt(static_cast<double>(embed_size)));
+  const auto query = query_ / ::sqrt(static_cast<double>(embed_size));
   if (is_causal) {
     TORCH_CHECK(!attn_mask.has_value(),
         "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
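For reference, the changed line applies the usual 1/sqrt(d_k) query scaling of the composite math path. A simplified dense-tensor sketch of that path (no attn_mask, no dropout, no is_causal handling, all of which the real implementation covers):

    import math
    import torch

    def sdp_math(query, key, value):
        # Scale, score, softmax, weighted sum: the naive composite formulation.
        query = query / math.sqrt(query.size(-1))
        attn = torch.softmax(query @ key.transpose(-2, -1), dim=-1)
        return attn @ value

    out = sdp_math(torch.randn(2, 8, 16), torch.randn(2, 8, 16), torch.randn(2, 8, 16))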
@@ -7,6 +7,8 @@ import random
 
+import warnings
+warnings.filterwarnings("ignore")
 
 
 class CompositeMHA(torch.nn.Module):
     def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
         super().__init__()
@@ -90,8 +92,8 @@ def benchmark_torch_function(iters, f, *args, **kwargs):
     return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
 
 
-def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, writer):
-    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True):
+def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, enable_math, enable_flash, writer):
+    with torch.backends.cuda.sdp_kernel(enable_math=enable_math, enable_flash=enable_flash):
         with torch.inference_mode():
             dropout_p = 0.0
             mask = None
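A minimal sketch of how the two new flags choose a backend (assumes a CUDA build; torch.backends.cuda.sdp_kernel is the same context manager the benchmark uses, and the public scaled_dot_product_attention entry point is assumed to be available):

    import torch
    import torch.nn.functional as F

    q = k = v = torch.randn(8, 4, 128, 64, device="cuda", dtype=torch.float16)

    with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False):
        out_math = F.scaled_dot_product_attention(q, k, v)   # composite math path
    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True):
        out_flash = F.scaled_dot_product_attention(q, k, v)  # flash attention path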
@@ -122,6 +124,8 @@ def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len,
             results["cp_time"] = cp_time
             results["speedup"] = pt_time / cp_time
             results["dtype"] = str(x.dtype)
+            results["enable_math"] = str(enable_math)
+            results["enable_flash"] = str(enable_flash)
             writer.writerow(results)
 
 
@@ -131,15 +135,22 @@ def main():
     np.random.seed(seed)
     torch.manual_seed(seed)
 
-    headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time", "cp_time", "speedup", "dtype"]
+    headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time",
+               "cp_time", "speedup", "dtype", "enable_math", "enable_flash"]
     writer = csv.DictWriter(sys.stdout, headers)
     writer.writeheader()
 
     batch_size = 64
     pad_percentage = 0.5
 
-    for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]):
-        run_timing(iters, batch_size, 1024, num_heads, max_seq_len, pad_percentage, writer)
+    for (enable_math, enable_flash) in [(False, True), (True, False), (True, True)]:
+        for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]):
+            run_timing(iters, batch_size, 1024, num_heads, max_seq_len,
+                       pad_percentage, enable_math, enable_flash, writer)
+            run_timing(iters, batch_size, 1024, num_heads, max_seq_len,
+                       pad_percentage, enable_math, enable_flash, writer)
+            run_timing(iters, batch_size, 1024, num_heads, max_seq_len,
+                       pad_percentage, enable_math, enable_flash, writer)
 
 
 if __name__ == "__main__":
@@ -24,6 +24,8 @@ from torch.testing._internal.common_utils import (
 
 # Tests are ported from pytorch/nestedtensor.
 # This makes porting as_nested_tensor easier in the future.
+
+
 def _iter_constructors():
     # yield as_nested_tensor
     yield torch.nested.nested_tensor
@@ -33,6 +35,8 @@ def _iter_constructors():
 # an output nested tensor consists of
 # * `len(ragged_sizes)` matrices
 # * matrices[i].shape == (20, ragged_sizes[i])
+
+
 def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
     xs = []
     for size in ragged_sizes:
@@ -49,6 +53,8 @@ random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16
 
 # Helper functions to pad a noncontiguous nested tensor
 # can be replaced once to_padded_tensor supports noncontiguous memory
+
+
 def noncontiguous_to_padded_tensor(input, shape=None):
     tensors = input.unbind()
     ntensors = len(tensors)
@@ -72,6 +78,8 @@ noncontiguous_to_padded_tensor(input, shape=None):
     return result
 
 # Helper function to generate a random nested tensor
+
+
 def random_nt(device, dtype, num_tensors, max_dims, min_dims=None):
     if min_dims is None:
         min_dims = tuple([0] * len(max_dims))
@@ -83,6 +91,7 @@ random_nt(device, dtype, num_tensors, max_dims, min_dims=None):
         ts1.append(t1)
     return torch.nested.nested_tensor(ts1, device=device, dtype=dtype)
 
+
 class TestNestedTensor(TestCase):
 
     @torch.inference_mode()
@@ -478,7 +487,6 @@ class TestNestedTensorDeviceType(TestCase):
         self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
         self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
 
-
     @dtypes(torch.float, torch.float16, torch.double)
     def test_unbind_noncontiguous(self, device, dtype):
         nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
@@ -857,6 +865,39 @@
             lambda: vector.mul(nt1)
         )
 
+    @dtypes(torch.float, torch.float16)
+    @skipMeta
+    @torch.inference_mode()
+    def test_nested_tensor_div(self, device, dtype):
+        nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
+        scale = 4.0
+        ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
+        out = nt / 4.0
+        self.assertEqual(ref, out)
+        ref_transposed = ref.transpose(1, 2)
+        out = nt.transpose(1, 2) / 4.0
+        self.assertEqual(ref_transposed, out)
+
+        ref = torch.nested.nested_tensor([t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())])
+        out = nt / nt2
+        self.assertEqual(ref, out)
+
+        out = nt.transpose(1, 2) / nt2.transpose(1, 2)
+        self.assertEqual(ref.transpose(1, 2), out)
+
+        nt_transpose_copy = torch.nested.nested_tensor([t.transpose(0, 1) for t in nt.unbind()])
+
+        self.assertRaisesRegex(
+            RuntimeError, "div requires strides to match when given NestedTensors",
+            lambda: nt_transpose_copy.transpose(1, 2) / nt2)
+
+        nt = torch.nested.nested_tensor([torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype)
+        nt_chunks = nt.chunk(2, -1)
+        self.assertRaisesRegex(
+            RuntimeError, "div requires offsets to match when given NestedTensors",
+            lambda: nt_chunks[0] / nt_chunks[1])
+
     @dtypes(torch.float, torch.float16)
     @skipMeta
     @torch.inference_mode()
@@ -1732,7 +1773,6 @@ class TestNestedTensorAutograd(TestCase):
         return torch.nested.as_nested_tensor([torch.randn(1, 2, requires_grad=requires_grad),
                                               torch.randn(7, 8, requires_grad=requires_grad)])
 
-
     def _create_nested_tensor_from_mask(self, requires_grad=False):
         data = torch.randn(2, 3, 4, requires_grad=requires_grad)
         mask = torch.ones_like(data[:, :, 0]).bool()
@@ -1772,7 +1812,6 @@
         self.assertEqual(a.grad, None)
         self.assertEqual(b.grad, None)
 
-
     def test_set_requires_grad_from_list(self):
         nt = self._create_nested_tensor_from_list()
         nt.requires_grad_()
@@ -2139,6 +2178,7 @@
         expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4))])
         self.assertEqual(nt.grad, expected_grad)
 
+
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
 
 if __name__ == '__main__':