diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml
index c481e03d23c..2cd92a5ef19 100644
--- a/aten/src/ATen/native/native_functions.yaml
+++ b/aten/src/ATen/native/native_functions.yaml
@@ -1881,6 +1881,7 @@
   dispatch:
     SparseCPU, SparseCUDA: div_sparse
     ZeroTensor: div_zerotensor
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Tensor
   tags: canonical

 - func: div_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)
@@ -1928,6 +1929,7 @@
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: div
+    NestedTensorCPU, NestedTensorCUDA: NestedTensor_div_Scalar
   tags: canonical

 - func: div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)
diff --git a/aten/src/ATen/native/nested/NestedTensorMath.cpp b/aten/src/ATen/native/nested/NestedTensorMath.cpp
index fc9e11ea449..2d0e8de8b46 100644
--- a/aten/src/ATen/native/nested/NestedTensorMath.cpp
+++ b/aten/src/ATen/native/nested/NestedTensorMath.cpp
@@ -500,10 +500,21 @@ get_elementwise_nested_tensor_impl(
       op_name,
       " does not support broadcasting when given a NestedTensor");
   TORCH_CHECK(
-      nested_tensor_impl_is_contiguous(self_ptr) &&
-          nested_tensor_impl_is_contiguous(other_ptr),
+      at::equal(
+          self_ptr->get_nested_stride_tensor(),
+          other_ptr->get_nested_stride_tensor()),
       op_name,
-      " does not support non-contiguous NestedTensor inputs");
+      " requires strides to match when given NestedTensors");
+  auto self_offsets = self_ptr->get_storage_offsets();
+  auto other_offsets = other_ptr->get_storage_offsets();
+  bool offsets_match = true;
+  for (size_t i = 0; i < self_offsets.size(); i++) {
+    offsets_match = offsets_match && (self_offsets[i] == other_offsets[i]);
+  }
+  TORCH_CHECK(
+      offsets_match,
+      op_name,
+      " requires offsets to match when given NestedTensors");
   return std::make_pair(self_ptr, other_ptr);
 }

@@ -517,16 +528,20 @@ Tensor NestedTensor_elementwise_Tensor(
   if (!self.is_nested() && self.dim() == 0 && self.numel() == 1) {
     auto other_impl = get_nested_tensor_impl(other);
     return wrap_buffer(
-      f(self, other_impl->get_buffer()),
-      other_impl->get_nested_size_tensor().clone()
+      f(self, other_impl->get_unsafe_storage_as_tensor()),
+      other_impl->get_nested_size_tensor().clone(),
+      other_impl->get_nested_stride_tensor().clone(),
+      other_impl->get_storage_offsets()
     );
   }
   // other is a scalar
   if (!other.is_nested() && other.dim() == 0 && other.numel() == 1) {
     auto self_impl = get_nested_tensor_impl(self);
     return wrap_buffer(
-      f(self_impl->get_buffer(), other),
-      self_impl->get_nested_size_tensor().clone()
+      f(self_impl->get_unsafe_storage_as_tensor(), other),
+      self_impl->get_nested_size_tensor().clone(),
+      self_impl->get_nested_stride_tensor().clone(),
+      self_impl->get_storage_offsets()
     );
   }
   NestedTensorImpl* self_impl = nullptr;
@@ -535,13 +550,12 @@ Tensor NestedTensor_elementwise_Tensor(
       get_elementwise_nested_tensor_impl(self, other, op_name);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(self_impl);
   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(other_impl);
-  const auto& nt_self = *self_impl;
-  const auto& nt_other = *other_impl;
-  const auto& self_sizes = nt_self.get_nested_size_tensor();
   return wrap_buffer(
-      f(nt_self.get_buffer().reshape({-1}),
-        nt_other.get_buffer().reshape({-1})),
-      self_sizes);
+      f(self_impl->get_unsafe_storage_as_tensor(),
+        other_impl->get_unsafe_storage_as_tensor()),
+      self_impl->get_nested_size_tensor(),
+      self_impl->get_nested_stride_tensor(),
+      self_impl->get_storage_offsets());
 }

 Tensor NestedTensor_add_Tensor(
@@ -566,6 +580,18 @@ Tensor NestedTensor_mul_Scalar(const Tensor& self, const Scalar& other) {
   return NestedTensor_mul_Tensor(self, wrapped_scalar_tensor(other));
 }

+Tensor NestedTensor_div_Tensor(const Tensor& self, const Tensor& other) {
+  return NestedTensor_elementwise_Tensor(
+      self, other, "div", [](const Tensor& b1, const Tensor& b2) {
+        return at::div(b1, b2);
+      });
+}
+
+// Only usable on the C++ side; scalars are converted to tensors coming from Python.
+Tensor NestedTensor_div_Scalar(const Tensor& self, const Scalar& other) {
+  return NestedTensor_div_Tensor(self, wrapped_scalar_tensor(other));
+}
+
 template <typename Func>
 Tensor& NestedTensor_elementwise__Tensor(
     Tensor& self,
diff --git a/aten/src/ATen/native/nested/NestedTensorUtils.h b/aten/src/ATen/native/nested/NestedTensorUtils.h
index 77d512c519b..ff8ec37dfc5 100644
--- a/aten/src/ATen/native/nested/NestedTensorUtils.h
+++ b/aten/src/ATen/native/nested/NestedTensorUtils.h
@@ -1,7 +1,7 @@
 #pragma once

-#include
 #include
+#include
 #include
 #include
 #include
@@ -50,6 +50,18 @@ inline at::Tensor wrap_buffer(
       std::move(offsets));
 }

+inline at::Tensor wrap_buffer(
+    at::Tensor buffer,
+    at::Tensor nested_size_tensor,
+    at::Tensor nested_stride_tensor,
+    const std::vector<int64_t>& offsets) {
+  std::vector<int64_t> offsets_copy(offsets);
+  return wrap_buffer(buffer,
+      nested_size_tensor,
+      nested_stride_tensor,
+      std::move(offsets_copy));
+}
+
 inline at::Tensor get_buffer(const at::Tensor& tensor) {
   return get_nested_tensor_impl(tensor)->get_buffer();
 }
@@ -119,7 +131,6 @@ inline std::vector<IntArrayRef> NestedTensor_get_sizes(
   return sizes;
 }

-
 TORCH_API std::vector<int64_t> NestedTensor_get_max_size(
     const NestedTensorImpl& nt);

@@ -161,17 +172,18 @@ inline std::vector<IntArrayRef> NestedTensor_get_strides(
 inline void check_numel_equals_buffer_size(const at::Tensor& self) {
   auto self_impl = get_nested_tensor_impl(self);
   TORCH_CHECK(
-      self.numel() == self_impl -> get_buffer_size(),
+      self.numel() == self_impl->get_buffer_size(),
       "Number of elements in nested tensor must match number of elements in buffer.");
 }

 inline void check_numel_equals_buffer_size(const NestedTensorImpl* self_ptr) {
   TORCH_CHECK(
-      self_ptr-> numel() == self_ptr -> get_buffer_size(),
+      self_ptr->numel() == self_ptr->get_buffer_size(),
       "Number of elements in nested tensor must match number of elements in buffer.");
 }
 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
-// Data structures and functions for generically applying a function on a nested tensor.
+// Data structures and functions for generically applying a function on a nested
+// tensor.
 namespace impl {

 template <typename T>
diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index c03935ecfbf..5b3d5f999cf 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -703,7 +703,7 @@ std::tuple<Tensor, Tensor> _scaled_dot_product_attention_math(
   auto attn_mask = attn_mask_;
   // Naive, composite implementation defined here.
   const auto embed_size = query_.size(-1);
-  const auto query = query_ * (1. / ::sqrt(static_cast<double>(embed_size)));
+  const auto query = query_ / ::sqrt(static_cast<double>(embed_size));
   if (is_causal) {
     TORCH_CHECK(!attn_mask.has_value(),
         "_scaled_dot_product_attention: Explicit attn_mask should not be set when is_causal=True");
diff --git a/benchmarks/transformer/sdp.py b/benchmarks/transformer/sdp.py
index 50db76e9f8c..fbd123fc39b 100644
--- a/benchmarks/transformer/sdp.py
+++ b/benchmarks/transformer/sdp.py
@@ -7,6 +7,8 @@ import random
 import warnings
 warnings.filterwarnings("ignore")
+
+
 class CompositeMHA(torch.nn.Module):
     def __init__(self, num_heads, in_proj_weight, in_proj_bias, out_proj):
         super().__init__()
@@ -90,8 +92,8 @@ def benchmark_torch_function(iters, f, *args, **kwargs):
     return (start_event.elapsed_time(end_event) * 1.0e-3) / iters


-def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, writer):
-    with torch.backends.cuda.sdp_kernel(enable_math=False, enable_flash=True):
+def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len, pad_percentage, enable_math, enable_flash, writer):
+    with torch.backends.cuda.sdp_kernel(enable_math=enable_math, enable_flash=enable_flash):
         with torch.inference_mode():
             dropout_p = 0.0
             mask = None
@@ -122,6 +124,8 @@ def run_timing(iters, batch_size, embed_dimension, num_heads, max_sequence_len,
             results["cp_time"] = cp_time
             results["speedup"] = pt_time / cp_time
             results["dtype"] = str(x.dtype)
+            results["enable_math"] = str(enable_math)
+            results["enable_flash"] = str(enable_flash)
             writer.writerow(results)

@@ -131,15 +135,22 @@ def main():
     np.random.seed(seed)
     torch.manual_seed(seed)

-    headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time", "cp_time", "speedup", "dtype"]
+    headers = ["max_sequence_len", "num_heads", "embed_dimension", "pt_time",
+               "cp_time", "speedup", "dtype", "enable_math", "enable_flash"]
     writer = csv.DictWriter(sys.stdout, headers)
     writer.writeheader()

     batch_size = 64
     pad_percentage = 0.5

-    for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]):
-        run_timing(iters, batch_size, 1024, num_heads, max_seq_len, pad_percentage, writer)
+    for (enable_math, enable_flash) in [(False, True), (True, False), (True, True)]:
+        for num_heads, max_seq_len in itertools.product([2, 4, 8, 16, 32], [64, 128, 256]):
+            run_timing(iters, batch_size, 1024, num_heads, max_seq_len,
+                       pad_percentage, enable_math, enable_flash, writer)
+            run_timing(iters, batch_size, 1024, num_heads, max_seq_len,
+                       pad_percentage, enable_math, enable_flash, writer)
+            run_timing(iters, batch_size, 1024, num_heads, max_seq_len,
+                       pad_percentage, enable_math, enable_flash, writer)


 if __name__ == "__main__":
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index f51db599586..f5e9aa1b8d7 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -24,6 +24,8 @@ from torch.testing._internal.common_utils import (

 # Tests are ported from pytorch/nestedtensor.
 # This makes porting as_nested_tensor easier in the future.
+
+
 def _iter_constructors():
     # yield as_nested_tensor
     yield torch.nested.nested_tensor
@@ -33,6 +35,8 @@ def _iter_constructors():
 # an output nested tensor consists of
 # * `len(ragged_sizes)` matrices
 # * matrices[i].shape == (20, ragged_sizes[i])
+
+
 def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16):
     xs = []
     for size in ragged_sizes:
@@ -49,6 +53,8 @@ def random_nt_noncontiguous_pair(ragged_sizes, device="cpu", dtype=torch.float16

 # Helper functions to pad a noncontiguous nested tensor
 # can be replaced once to_padded_tensor supports noncontiguous memory
+
+
 def noncontiguous_to_padded_tensor(input, shape=None):
     tensors = input.unbind()
     ntensors = len(tensors)
@@ -72,6 +78,8 @@ def noncontiguous_to_padded_tensor(input, shape=None):
     return result

 # Helper function to generate a random nested tensor
+
+
 def random_nt(device, dtype, num_tensors, max_dims, min_dims=None):
     if min_dims is None:
         min_dims = tuple([0] * len(max_dims))
@@ -83,6 +91,7 @@ def random_nt(device, dtype, num_tensors, max_dims, min_dims=None):
         ts1.append(t1)
     return torch.nested.nested_tensor(ts1, device=device, dtype=dtype)

+
 class TestNestedTensor(TestCase):

     @torch.inference_mode()
@@ -478,7 +487,6 @@ class TestNestedTensorDeviceType(TestCase):
         self.assertEqual(a.grad, torch.ones(2, 4, device=device, dtype=dtype))
         self.assertEqual(b.grad, torch.ones(5, 4, device=device, dtype=dtype))

-
     @dtypes(torch.float, torch.float16, torch.double)
     def test_unbind_noncontiguous(self, device, dtype):
         nt_contiguous, nt_noncontiguous = random_nt_noncontiguous_pair((2, 3, 6, 7), device, dtype)
@@ -857,6 +865,39 @@ class TestNestedTensorDeviceType(TestCase):
             lambda: vector.mul(nt1)
         )

+    @dtypes(torch.float, torch.float16)
+    @skipMeta
+    @torch.inference_mode()
+    def test_nested_tensor_div(self, device, dtype):
+        nt, nt2 = self.random_nt_pair(device, dtype, 4, (4, 4))
+        scale = 4.0
+        ref = torch.nested.nested_tensor([t / scale for t in nt.unbind()])
+        out = nt / 4.0
+        self.assertEqual(ref, out)
+        ref_transposed = ref.transpose(1, 2)
+        out = nt.transpose(1, 2) / 4.0
+        self.assertEqual(ref_transposed, out)
+
+        ref = torch.nested.nested_tensor([t / t2 for (t, t2) in zip(nt.unbind(), nt2.unbind())])
+        out = nt / nt2
+        self.assertEqual(ref, out)
+
+        out = nt.transpose(1, 2) / nt2.transpose(1, 2)
+        self.assertEqual(ref.transpose(1, 2), out)
+
+        nt_transpose_copy = torch.nested.nested_tensor([t.transpose(0, 1) for t in nt.unbind()])
+
+        self.assertRaisesRegex(
+            RuntimeError, "div requires strides to match when given NestedTensors",
+            lambda: nt_transpose_copy.transpose(1, 2) / nt2)
+
+        nt = torch.nested.nested_tensor([torch.randn(i, 4) for i in [3, 4, 5]], device=device, dtype=dtype)
+        nt_chunks = nt.chunk(2, -1)
+        self.assertRaisesRegex(
+            RuntimeError, "div requires offsets to match when given NestedTensors",
+            lambda: nt_chunks[0] / nt_chunks[1])
+
+
     @dtypes(torch.float, torch.float16)
     @skipMeta
     @torch.inference_mode()
@@ -1732,7 +1773,6 @@ class TestNestedTensorAutograd(TestCase):
         return torch.nested.as_nested_tensor([torch.randn(1, 2, requires_grad=requires_grad),
                                               torch.randn(7, 8, requires_grad=requires_grad)])

-
     def _create_nested_tensor_from_mask(self, requires_grad=False):
         data = torch.randn(2, 3, 4, requires_grad=requires_grad)
         mask = torch.ones_like(data[:, :, 0]).bool()
@@ -1772,7 +1812,6 @@ class TestNestedTensorAutograd(TestCase):
         self.assertEqual(a.grad, None)
         self.assertEqual(b.grad, None)

-
     def test_set_requires_grad_from_list(self):
         nt = self._create_nested_tensor_from_list()
         nt.requires_grad_()
@@ -2139,6 +2178,7 @@ class TestNestedTensorAutograd(TestCase):
         expected_grad = torch.nested.nested_tensor([grad_x0, torch.zeros((3, 4))])
         self.assertEqual(nt.grad, expected_grad)

+
 instantiate_device_type_tests(TestNestedTensorDeviceType, globals())

 if __name__ == '__main__':
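
A minimal sketch of the behavior this diff enables, mirroring the new `test_nested_tensor_div` test above; the tensor shapes and values are illustrative only:

```python
import torch

# Two nested tensors with matching nested sizes (shapes are illustrative).
nt = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])
nt2 = torch.nested.nested_tensor([torch.randn(2, 4), torch.randn(3, 4)])

# Scalar division now dispatches to NestedTensor_div_Scalar.
out_scalar = nt / 4.0

# Element-wise division of two nested tensors dispatches to
# NestedTensor_div_Tensor; nested strides and storage offsets must match.
out_tensor = nt / nt2

# Transposed (non-contiguous) nested tensors also divide, since the kernel
# now compares nested strides/offsets instead of requiring contiguity.
out_transposed = nt.transpose(1, 2) / nt2.transpose(1, 2)
```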