Support non-contiguous NestedTensors for elementwise ops (#87888)

Enables benchmarking of the math path of the SDP kernel.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87888
Approved by: https://github.com/drisspg
Christian Puhrsch 2022-10-28 11:26:17 +00:00 committed by PyTorch MergeBot
parent f150e70ca2
commit b192e7e415
6 changed files with 118 additions and 27 deletions

View File

@@ -1881,6 +1881,7 @@
   dispatch:
     SparseCPU, SparseCUDA: div_sparse
     ZeroTensor: div_zerotensor
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Tensor
   tags: canonical
 
 - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
@@ -1928,6 +1929,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: div
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Scalar
   tags: canonical
 
 - func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
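
With these two dispatch entries in place, at::div on NestedTensor inputs routes to the new kernels added below. A minimal sketch of the user-visible behavior this enables (the shapes are illustrative, not taken from this PR):

import torch

# Two nested tensors with matching nested sizes (ragged first dims).
a = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])
b = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])

c = a / b    # div.Tensor -> NestedTensor_div_Tensor
d = a / 2.0  # Python converts the scalar to a 0-dim tensor, so this also
             # reaches the Tensor overload; div.Scalar serves C++ callers.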

View File

@@ -500,10 +500,21 @@ get_elementwise_nested_tensor_impl(
       op_name,
       " does not support broadcasting when given a NestedTensor");
   TORCH_CHECK(
-      nested_tensor_impl_is_contiguous(self_ptr) &&
-          nested_tensor_impl_is_contiguous(other_ptr),
+      at::equal(
+          self_ptr->get_nested_stride_tensor(),
+          other_ptr->get_nested_stride_tensor()),
       op_name,
-      " does not support non-contiguous NestedTensor inputs");
+      " requires strides to match when given NestedTensors");
+  auto self_offsets = self_ptr->get_storage_offsets();
+  auto other_offsets = other_ptr->get_storage_offsets();
+  bool offsets_match = true;
+  for (size_t i = 0; i < self_offsets.size(); i++) {
+    offsets_match = offsets_match && (self_offsets[i] == other_offsets[i]);
+  }
+  TORCH_CHECK(
+      offsets_match,
+      op_name,
+      " requires offsets to match when given NestedTensors");
   return std::make_pair(self_ptr, other_ptr);
 }
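
The old contiguity requirement is relaxed to two weaker checks: the nested stride tensors must be equal, and the per-constituent storage offsets must match. Matching non-contiguous views (e.g. the same transpose applied to both operands) now pass, while views that begin at different positions in their buffers are still rejected. A sketch of both outcomes, mirroring the tests added in this PR:

import torch

a = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])
b = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])

# Both operands are non-contiguous views with identical strides and
# offsets, so the relaxed check accepts them.
out = a.transpose(1, 2) / b.transpose(1, 2)

# chunk() produces views into one buffer that begin at different storage
# offsets; dividing them trips the new offsets check.
c0, c1 = a.chunk(2, -1)
# c0 / c1  # RuntimeError: div requires offsets to match when given NestedTensors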
@@ -517,16 +528,20 @@ Tensor NestedTensor_elementwise_Tensor(
   if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) {
     auto other_impl = get_nested_tensor_impl(other);
     return wrap_buffer(
-      f(self, other_impl->get_buffer()),
-      other_impl->get_nested_size_tensor().clone()
+      f(self, other_impl->get_unsafe_storage_as_tensor()),
+      other_impl->get_nested_size_tensor().clone(),
+      other_impl->get_nested_stride_tensor().clone(),
+      other_impl->get_storage_offsets()
     );
   }
   // other is a scalar
   if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) {
     auto self_impl = get_nested_tensor_impl(self);
     return wrap_buffer(
-      f(self_impl->get_buffer(), other),
-      self_impl->get_nested_size_tensor().clone()
+      f(self_impl->get_unsafe_storage_as_tensor(), other),
+      self_impl->get_nested_size_tensor().clone(),
+      self_impl->get_nested_stride_tensor().clone(),
+      self_impl->get_storage_offsets()
     );
   }
   NestedTensorImpl* self_impl = nullptr;
@@ -535,13 +550,12 @@ Tensor NestedTensor_elementwise_Tensor(
       get_elementwise_nested_tensor_impl(self, other, op_name);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl);
-  const auto& nt_self = *self_impl;
-  const auto& nt_other = *other_impl;
-  const auto& self_sizes = nt_self.get_nested_size_tensor();
   return wrap_buffer(
-      f(nt_self.get_buffer().reshape({-1}),
-        nt_other.get_buffer().reshape({-1})),
-      self_sizes);
+      f(self_impl->get_unsafe_storage_as_tensor(),
+        other_impl->get_unsafe_storage_as_tensor()),
+      self_impl->get_nested_size_tensor(),
+      self_impl->get_nested_stride_tensor(),
+      self_impl->get_storage_offsets());
 }
 
 Tensor NestedTensor_add_Tensor(
@@ -566,6 +580,18 @@ Tensor NestedTensor_mul_Scalar(const Tensor& self, const Scalar& other) {
   return NestedTensor_mul_Tensor(self, wrapped_scalar_tensor(other));
 }
 
+Tensor NestedTensor_div_Tensor(const Tensor& self, const Tensor& other) {
+  return NestedTensor_elementwise_Tensor(
+      self, other, "div", [](const Tensor& b1, const Tensor& b2) {
+        return at::div(b1, b2);
+      });
+}
+
+// Only usable on the C++ side; scalars are converted to tensors coming from Python.
+Tensor NestedTensor_div_Scalar(const Tensor& self, const Scalar& other) {
+  return NestedTensor_div_Tensor(self, wrapped_scalar_tensor(other));
+}
+
 template <typename Func>
 Tensor& NestedTensor_elementwise__Tensor(
     Tensor& self,
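
Because the layout checks above guarantee both operands share one layout, the elementwise path can apply f directly to the raw storage tensors and re-wrap the result with self's size/stride/offset metadata; storage elements not covered by any constituent view may hold garbage, but they are never observed. A hypothetical Python rendering of that strategy (the function and its arguments are illustrative, not PyTorch API):

import torch

def nested_elementwise_div_sketch(buf_a, buf_b, sizes, strides, offsets):
    # Strides and offsets match, so corresponding elements of the two
    # operands occupy the same positions in their flat buffers.
    out_buf = buf_a / buf_b
    # Rebuild the constituents as strided views over the result buffer,
    # mirroring what wrap_buffer does on the C++ side.
    return [out_buf.as_strided(size, stride, offset)
            for size, stride, offset in zip(sizes, strides, offsets)]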

View File

@@ -1,7 +1,7 @@
 #pragma once
 
-#include <ATen/core/Tensor.h>
 #include <ATen/NestedTensorImpl.h>
+#include <ATen/core/Tensor.h>
 #include <c10/core/DispatchKeySet.h>
 #include <c10/core/TensorImpl.h>
 #include <c10/macros/Macros.h>
@@ -50,6 +50,18 @@ inline at::Tensor wrap_buffer(
       std::move(offsets));
 }
 
+inline at::Tensor wrap_buffer(
+    at::Tensor buffer,
+    at::Tensor nested_size_tensor,
+    at::Tensor nested_stride_tensor,
+    const std::vector<int64_t>& offsets) {
+  std::vector<int64_t> offsets_copy(offsets);
+  return wrap_buffer(
+      buffer,
+      nested_size_tensor,
+      nested_stride_tensor,
+      std::move(offsets_copy));
+}
+
 inline at::Tensor get_buffer(const at::Tensor& tensor) {
   return get_nested_tensor_impl(tensor)->get_buffer();
 }
@@ -119,7 +131,6 @@ inline std::vector<IntArrayRef> NestedTensor_get_sizes(
   return sizes;
 }
 
-
 TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
     const NestedTensorImpl& nt);
@@ -161,17 +172,18 @@ inline std::vector<IntArrayRef> NestedTensor_get_strides(
 inline void check_numel_equals_buffer_size(const at::Tensor& self) {
   auto self_impl = get_nested_tensor_impl(self);
   TORCH_CHECK(
-      self.numel() == self_impl -> get_buffer_size(),
+      self.numel() == self_impl->get_buffer_size(),
       "Number of elements in nested tensor must match number of elements in buffer.");
 }
 
 inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) {
   TORCH_CHECK(
-      self_ptr-> numel() == self_ptr -> get_buffer_size(),
+      self_ptr->numel() == self_ptr->get_buffer_size(),
       "Number of elements in nested tensor must match number of elements in buffer.");
 }
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// Data structures and functions for generically applying a function on a nested tensor.
+// Data structures and functions for generically applying a function on a nested
+// tensor.
 
 namespace impl {
 
 template <typename T>

View File

@@ -703,7 +703,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
   auto attn_mask = attn_mask_;
   // Naive, composite implementation defined here.
   const auto embed_size = query_.size(-1);
-  const auto query = query_ * (1. / ::sqrt(static_cast<double>(embed_size)));
+  const auto query = query_ / ::sqrt(static_cast<double>(embed_size));
   if (is_causal) {
     TORCH_CHECK(!attn_mask.has_value(),
         "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");

View File

@@ -7,6 +7,8 @@ import random
 import warnings
 warnings.filterwarnings("ignore")
+
+
 class CompositeMHA(torch.nn.Module):
     def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
         super().__init__()
@@ -90,8 +92,8 @@ def benchmark_torch_function(iters, f, *args, **kwargs):
     return (start_event.elapsed_time(end_event) * 1.0e-3) / iters
 
 
-def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, writer):
-    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True):
+def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, enable_math, enable_flash, writer):
+    with torch.backends.cuda.sdp_kernel(enable_math=enable_math, enable_flash=enable_flash):
         with torch.inference_mode():
             dropout_p = 0.0
             mask = None
@@ -122,6 +124,8 @@ def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len,
             results["cp_time"] = cp_time
             results["speedup"] = pt_time / cp_time
             results["dtype"] = str(x.dtype)
+            results["enable_math"] = str(enable_math)
+            results["enable_flash"] = str(enable_flash)
 
             writer.writerow(results)
@@ -131,15 +135,22 @@ def main():
     np.random.seed(seed)
     torch.manual_seed(seed)
 
-    headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time", "cp_time", "speedup", "dtype"]
+    headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time",
+               "cp_time", "speedup", "dtype", "enable_math", "enable_flash"]
     writer = csv.DictWriter(sys.stdout, headers)
     writer.writeheader()
 
     batch_size = 64
     pad_percentage = 0.5
 
-    for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]):
-        run_timing(iters, batch_size, 1024, num_heads, max_seq_len, pad_percentage, writer)
+    for (enable_math, enable_flash) in [(False, True), (True, False), (True, True)]:
+        for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]):
+            run_timing(iters, batch_size, 1024, num_heads, max_seq_len,
+                       pad_percentage, enable_math, enable_flash, writer)
+            run_timing(iters, batch_size, 1024, num_heads, max_seq_len,
+                       pad_percentage, enable_math, enable_flash, writer)
+            run_timing(iters, batch_size, 1024, num_heads, max_seq_len,
+                       pad_percentage, enable_math, enable_flash, writer)
 
 
 if __name__ == "__main__":
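
run_timing now receives enable_math/enable_flash and forwards them to torch.backends.cuda.sdp_kernel, so the sweep covers flash-only (False, True), math-only (True, False), and both-enabled (True, True) configurations. A minimal sketch of the toggle in isolation (shapes are illustrative; the private _scaled_dot_product_attention entry point and its tuple return reflect nightlies of this vintage and may differ in other builds):

import torch

q = k = v = torch.randn(8, 4, 128, 64, device="cuda", dtype=torch.float16)

# Allowing only the math backend forces the composite implementation,
# the path this PR makes benchmarkable.
with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False):
    out, _ = torch.nn.functional._scaled_dot_product_attention(q, k, v)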

View File

@@ -24,6 +24,8 @@ from torch.testing._internal.common_utils import (
 # Tests are ported from pytorch/nestedtensor.
 # This makes porting as_nested_tensor easier in the future.
+
+
 def _iter_constructors():
     # yield as_nested_tensor
     yield torch.nested.nested_tensor
@@ -33,6 +35,8 @@ def _iter_constructors():
 # an output nested tensor consists of
 # * `len(ragged_sizes)` matrices
 # * matrices[i].shape == (20, ragged_sizes[i])
+
+
 def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
     xs = []
     for size in ragged_sizes:
@@ -49,6 +53,8 @@ def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16
 # Helper functions to pad a noncontiguous nested tensor
 # can be replaced once to_padded_tensor supports noncontiguous memory
+
+
 def noncontiguous_to_padded_tensor(input, shape=None):
     tensors = input.unbind()
     ntensors = len(tensors)
@@ -72,6 +78,8 @@ def noncontiguous_to_padded_tensor(input, shape=None):
     return result
+
+
 # Helper function to generate a random nested tensor
 def random_nt(device, dtype, num_tensors, max_dims, min_dims=None):
     if min_dims is None:
         min_dims = tuple([0] * len(max_dims))
@@ -83,6 +91,7 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None):
         ts1.append(t1)
 
     return torch.nested.nested_tensor(ts1, device=device, dtype=dtype)
+
 
 class TestNestedTensor(TestCase):
     @torch.inference_mode()
@@ -478,7 +487,6 @@
         self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
         self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))
 
-
     @dtypes(torch.float, torch.float16, torch.double)
     def test_unbind_noncontiguous(self, device, dtype):
         nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
@@ -857,6 +865,39 @@
             lambda: vector.mul(nt1)
         )
 
+    @dtypes(torch.float, torch.float16)
+    @skipMeta
+    @torch.inference_mode()
+    def test_nested_tensor_div(self, device, dtype):
+        nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
+        scale = 4.0
+        ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
+        out = nt / 4.0
+        self.assertEqual(ref, out)
+
+        ref_transposed = ref.transpose(1, 2)
+        out = nt.transpose(1, 2) / 4.0
+        self.assertEqual(ref_transposed, out)
+
+        ref = torch.nested.nested_tensor([t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())])
+        out = nt / nt2
+        self.assertEqual(ref, out)
+
+        out = nt.transpose(1, 2) / nt2.transpose(1, 2)
+        self.assertEqual(ref.transpose(1, 2), out)
+
+        nt_transpose_copy = torch.nested.nested_tensor([t.transpose(0, 1) for t in nt.unbind()])
+        self.assertRaisesRegex(
+            RuntimeError, "div requires strides to match when given NestedTensors",
+            lambda: nt_transpose_copy.transpose(1, 2) / nt2)
+
+        nt = torch.nested.nested_tensor([torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype)
+        nt_chunks = nt.chunk(2, -1)
+        self.assertRaisesRegex(
+            RuntimeError, "div requires offsets to match when given NestedTensors",
+            lambda: nt_chunks[0] / nt_chunks[1])
+
     @dtypes(torch.float, torch.float16)
     @skipMeta
     @torch.inference_mode()
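
The final case in the new test fails because chunking along the last dimension yields views that share storage but begin at different element offsets, which is exactly what the new offsets check rejects. The same effect is visible on a plain dense tensor:

import torch

x = torch.randn(3, 4)
left, right = x.chunk(2, -1)  # two views into x's storage
print(left.storage_offset(), right.storage_offset())  # 0 2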
@@ -1732,7 +1773,6 @@
         return torch.nested.as_nested_tensor([torch.randn(1, 2, requires_grad=requires_grad),
                                               torch.randn(7, 8, requires_grad=requires_grad)])
 
-
     def _create_nested_tensor_from_mask(self, requires_grad=False):
         data = torch.randn(2, 3, 4, requires_grad=requires_grad)
         mask = torch.ones_like(data[:, :, 0]).bool()
@@ -1772,7 +1812,6 @@
         self.assertEqual(a.grad, None)
         self.assertEqual(b.grad, None)
 
-
     def test_set_requires_grad_from_list(self):
         nt = self._create_nested_tensor_from_list()
         nt.requires_grad_()
@@ -2139,6 +2178,7 @@
         expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4))])
         self.assertEqual(nt.grad, expected_grad)
 
+
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())
 
 if __name__ == '__main__':