Commit Graph

152 Commits

Author SHA1 Message Date
Iurii Zdebskyi
4cb8d306e6 Add _foreach_add_(TensorList tensors, Scalar scalar) API (#42531)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/42531

[First PR: Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar)](https://github.com/pytorch/pytorch/pull/41554).

**Motivation**
[GitHub issue](https://github.com/pytorch/pytorch/issues/38655)
Current PyTorch optimizer implementations are not efficient in cases when we work with a lot of small feature tensors. Starting a lot of kernels slows down the whole process. We need to reduce the number of kernels that we start.
As an example, we should be looking at [NVIDIAs Apex](https://github.com/NVIDIA/apex).
In order to track progress, we will pick PyTorchs DCGAN model with Adam optimizer and once the optimizer is reimplemented with tensor lists, benchmark the model performance against original model version, Apexs version with original Adam optimizer and it’s FusedAdam optimizer.

**Current API restrictions**
- List can't be empty (will fixed in upcoming PRs).
- All tensors in the list must have the same dtype, device and size.

**Broadcasting**
At this point we don't support broadcasting.

**What is 'Fast' and 'Slow' route**
In particular cases, we cant process an op with a fast list CUDA kernel. Still, we can do with a regular for-loop where the op will be applied to each tensor individually through the dispatch mechanisms. There are a few checks that decide whether the op will be performed via a 'fast' or 'slow' path.
To go the fast route,
- All tensors must have strided layout
- All tensors must be dense and not have overlapping memory
- The resulting tensor type must be the same.

---------------
**In this PR**
- Adding a `std::vector<Tensor> _foreach_add_(TensorList tensors, Scalar scalar)` API
- Resolving some additional comments from previous [PR](https://github.com/pytorch/pytorch/pull/41554).

**Tests**
Tested via unit tests

**TODO**
1. Properly handle empty lists

**Plan for the next PRs**
1. APIs
- Binary Ops for list with Scalar
- Binary Ops for list with list
- Unary Ops for list
- Pointwise Ops

2. Complete tasks from TODO
3. Rewrite PyTorch optimizers to use for-each operators for performance gains.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D23331892

Pulled By: izdeby

fbshipit-source-id: c585b72e1e87f6f273f904f75445618915665c4c
2020-08-28 14:34:46 -07:00
iurii zdebskyi
e995c3d21e Add private API to support tensor lists: _foreach_add(TensorList tensors, Scalar scalar) (#41554)
Summary:
Initial PR for the Tensor List functionality.

**Motivation**
[GitHub issue](https://github.com/pytorch/pytorch/issues/38655)
Current PyTorch optimizer implementations are not efficient in cases when we work with a lot of small feature tensors. Starting a lot of kernels slows down the whole process. We need to reduce the number of kernels that we start.
As an example, we should be looking at [NVIDIAs Apex](https://github.com/NVIDIA/apex).
In order to track progress, we will pick PyTorchs DCGAN model with Adam optimizer and once the optimizer is reimplemented with tensor lists, benchmark the model performance against original model version, Apexs version with original Adam optimizer and it’s FusedAdam optimizer.

**In this PR**
- Adding `multi_tensor_apply` mechanism which will help to efficiently apply passed functor on a given list of tensors on CUDA.
- Adding a first private API - `std::vector<Tensor> _foreach_add(TensorList tensors, Scalar scalar)`

**Tests**
Tested via unit tests

**Plan for the next PRs**

1. Cover these ops with `multi_tensor_apply` support
- exponent
- division
- mul_
- add_
- addcmul_
- addcdiv_
- Sqrt

2. Rewrite PyTorch optimizers to use for-each operators in order to get performance gains.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/41554

Reviewed By: cpuhrsch

Differential Revision: D22829724

Pulled By: izdeby

fbshipit-source-id: 47febdbf7845cf931958a638567b7428a24782b1
2020-08-04 15:01:09 -07:00