mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Initial PR for the Tensor List functionality. **Motivation** [GitHub issue](https://github.com/pytorch/pytorch/issues/38655) Current PyTorch optimizer implementations are not efficient in cases when we work with a lot of small feature tensors. Starting a lot of kernels slows down the whole process. We need to reduce the number of kernels that we start. As an example, we should be looking at [NVIDIA's Apex](https://github.com/NVIDIA/apex). In order to track progress, we will pick PyTorch's DCGAN model with the Adam optimizer and, once the optimizer is reimplemented with tensor lists, benchmark the model performance against the original model version, Apex's version with the original Adam optimizer, and its FusedAdam optimizer. **In this PR** - Adding a `multi_tensor_apply` mechanism which will help to efficiently apply a passed functor to a given list of tensors on CUDA. - Adding a first private API - `std::vector<Tensor> _foreach_add(TensorList tensors, Scalar scalar)` **Tests** Tested via unit tests **Plan for the next PRs** 1. Cover these ops with `multi_tensor_apply` support - exponent - division - mul_ - add_ - addcmul_ - addcdiv_ - Sqrt 2. Rewrite PyTorch optimizers to use for-each operators in order to get performance gains. Pull Request resolved: https://github.com/pytorch/pytorch/pull/41554 Reviewed By: cpuhrsch Differential Revision: D22829724 Pulled By: izdeby fbshipit-source-id: 47febdbf7845cf931958a638567b7428a24782b1
103 lines
3.7 KiB
Python
103 lines
3.7 KiB
Python
import torch
|
|
import torch.cuda
|
|
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes
|
|
|
|
class TestForeach(TestCase):
    """Tests for the private ``torch._foreach_add(tensors, scalar)`` operator.

    Every test receives a ``device`` argument; tests decorated with
    ``@dtypes`` additionally receive a ``dtype`` argument.
    """

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_same_size_tensors(self, device, dtype):
        """Add 1 to N equally sized zero tensors and expect N tensors of ones."""
        N = 20
        H = 20
        W = 20
        tensors = [torch.zeros(H, W, device=device, dtype=dtype) for _ in range(N)]

        res = torch._foreach_add(tensors, 1)

        # Without this check the loop below would pass vacuously if the
        # operator returned an empty (or short) list.
        self.assertEqual(len(res), N)

        # For bool inputs the result is expected to be int64.
        expected_dtype = torch.int64 if dtype == torch.bool else dtype
        expected = torch.ones(H, W, device=device, dtype=expected_dtype)
        for t in res:
            self.assertEqual(t, expected)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_different_size_tensors(self, device, dtype):
        """Add 1 to N zero tensors whose sides grow by one element each."""
        N = 20
        H = 20
        W = 20

        # The i-th tensor has shape (H + i, W + i).
        tensors = [
            torch.zeros(H + i, W + i, device=device, dtype=dtype)
            for i in range(N)
        ]

        res = torch._foreach_add(tensors, 1)

        # Without this check the loop below would pass vacuously if the
        # operator returned an empty (or short) list.
        self.assertEqual(len(res), N)

        # For bool inputs the result is expected to be int64.
        expected_dtype = torch.int64 if dtype == torch.bool else dtype
        for i, t in enumerate(res):
            expected = torch.ones(H + i, W + i, device=device, dtype=expected_dtype)
            self.assertEqual(t, expected)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_empty_list(self, device, dtype):
        """An empty tensor list must be rejected with a RuntimeError.

        ``device`` and ``dtype`` are not used by the body; they are accepted
        because of the ``@dtypes`` decorator.
        """
        tensors = []
        with self.assertRaises(RuntimeError):
            torch._foreach_add(tensors, 1)

    @dtypes(*torch.testing.get_all_dtypes())
    def test_add_scalar_with_overlapping_tensors(self, device, dtype):
        """Add 1 to a (1, 1) tensor of ones that was expanded to (2, 1, 3)."""
        tensors = [torch.ones(1, 1, device=device, dtype=dtype).expand(2, 1, 3)]

        # For bool inputs the result is expected to be int64 with value 2,
        # exactly like every other dtype.
        expected_dtype = torch.int64 if dtype == torch.bool else dtype
        expected = [
            torch.tensor([[[2, 2, 2]], [[2, 2, 2]]], dtype=expected_dtype, device=device)
        ]

        res = torch._foreach_add(tensors, 1)
        self.assertEqual(res, expected)

    def test_add_scalar_with_different_tensor_dtypes(self, device):
        """A list mixing float and int tensors keeps each tensor's own dtype."""
        tensors = [
            torch.tensor([1], dtype=torch.float, device=device),
            torch.tensor([1], dtype=torch.int, device=device),
        ]

        expected = [
            torch.tensor([2], dtype=torch.float, device=device),
            torch.tensor([2], dtype=torch.int, device=device),
        ]

        res = torch._foreach_add(tensors, 1)
        self.assertEqual(res, expected)

    def test_add_scalar_with_different_scalar_type(self, device):
        """Check results when the scalar's type differs from the tensor's dtype."""
        # int tensor with float scalar
        # should go 'slow' route
        scalar = 1.1
        tensors = [torch.tensor([1], dtype=torch.int, device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([2.1], device=device)])

        # float tensor with int scalar
        # should go 'fast' route
        scalar = 1
        tensors = [torch.tensor([1.1], device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([2.1], device=device)])

        # bool tensor with int scalar
        # should go 'slow' route
        scalar = 1
        tensors = [torch.tensor([False], device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([1], device=device)])

        # bool tensor with float scalar
        # should go 'slow' route
        scalar = 1.1
        tensors = [torch.tensor([False], device=device)]
        res = torch._foreach_add(tensors, scalar)
        self.assertEqual(res, [torch.tensor([1.1], device=device)])
|
|
|
|
# Hand TestForeach and this module's global namespace to the device-type
# test instantiation helper; this happens at import time, before any runner
# is started.
instantiate_device_type_tests(TestForeach, globals())


if __name__ == '__main__':
    # Start the test runner only when the file is executed as a script.
    run_tests()
|