Summary:
Fixes #6622.
We used to average over all elements for KL divergence, which is not aligned with its mathematical definition.
This PR corrects the default reduction behavior of KL divergence so that it now averages over the batch dimension.
- In KL, the default `reduction=mean` averages over the batch dimension, while for most other loss functions `reduction=mean` averages over all elements.
- We used to support scalar tensors as well. For BC purposes we still do; no reduction is performed on a scalar tensor.
- Added a new reduction mode called `batchmean` which has the correct behavior for KL. Added a warning that `batchmean` will become the default for KL instead of `mean` in the next major release.
- [deprecated] I chose not to add a new reduction option, since "mean over batch dimension" is kinda special and only makes sense in a few cases like KL. We don't want to explain why there's a `batchmean` option that isn't applicable to all other functions. I'm open to discussion on this one, as I cannot think of a perfect solution for this.
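For illustration, a minimal sketch of how the two reductions differ (the shapes here are arbitrary):
```python
import torch
import torch.nn.functional as F

log_probs = F.log_softmax(torch.randn(4, 10), dim=1)   # input: log-probabilities
target = F.softmax(torch.randn(4, 10), dim=1)          # target: probabilities

# 'mean' divides the summed pointwise KL terms by all 4 * 10 elements;
# 'batchmean' (added in this PR) divides by the batch size only, which
# matches the mathematical definition of KL divergence.
loss_mean = F.kl_div(log_probs, target, reduction='mean')
loss_batchmean = F.kl_div(log_probs, target, reduction='batchmean')
assert torch.allclose(loss_batchmean, loss_mean * 10)
```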
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14457
Differential Revision: D13236016
Pulled By: ailzhang
fbshipit-source-id: 905cc7b3bfc35a11d7cf098b1ebc382170a087a7
Summary:
This moves `new_module_tests` from `test_nn.py` to `common_nn.py` so
that they can be used in `test_jit.py` without running any of
`test_nn.py`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14578
Differential Revision: D13268286
Pulled By: driazati
fbshipit-source-id: 6e8654a4c29ab754d656ac83820c14d1c1843e03
Summary:
To convert `max_unpool` functions to weak script, this PR adds support
for `T` as a default argument for `BroadcastingListN[T]`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14361
Differential Revision: D13192231
Pulled By: driazati
fbshipit-source-id: a25b75a0e88ba3dfa22d6a83775e9778d735e249
Summary:
This PR adds weak modules for all activation modules and uses `test_nn` module tests to test weak modules that have been annotated with `weak_module` and therefore are in `torch._jit_internal._weak_types`
Also depends on #14379
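For orientation, a rough sketch of what such an annotation looks like (the module here is a placeholder; `weak_module` is an internal decorator from `torch._jit_internal` as of this era, and the exact usage in `torch/nn` may differ):
```python
import torch
import torch.nn.functional as F
from torch._jit_internal import weak_module  # internal API referenced by this PR

@weak_module  # records the class in torch._jit_internal._weak_types
class MyReLU(torch.nn.Module):
    def forward(self, input):
        return F.relu(input)
```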
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14238
Differential Revision: D13252887
Pulled By: driazati
fbshipit-source-id: e9638cf74089884a32b8f0f38396cf432c02c988
Summary:
This PR adds weak modules for all activation modules and uses `test_nn` module tests to test weak modules that have been annotated with `weak_module` and therefore are in `torch._jit_internal._weak_types`
Also depends on #14379
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14238
Differential Revision: D13192230
Pulled By: driazati
fbshipit-source-id: 36488960b6c91448b38c0fa65422539a93af8c5e
Summary:
As reported in #13386, the pooling operations can return wrong results for large inputs. The root of the problem is that while the output shape is initially being computed with integer operations, it is converted to float32 for division by the stride and applying either a `ceil` or a `floor` depending on the `ceil_mode`. Since even moderately large integers (the smallest being 16,777,217) cannot be expressed exactly in float32, this leads to wrong result shapes.
This PR relies purely on integer operations to perform the shape computation, including the ceil/floor distinction. Since I could not stand all that duplicated code, I pulled it out into a `pooling_shape.h` header, similar to the existing `linear_upsampling.h` header. I hope this is acceptable, let me know if you'd like to see it solved differently. I've also added tests to `test_nn.py` that fail without my changes and pass with my changes. They cover `{max,avg}_pool{1,2,3}d()` for CPU and GPU.
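For reference, a minimal Python sketch of the integer-only shape computation (assumed to mirror what the new header does; not the actual C++ code):
```python
def pooling_output_shape(input_size, kernel_size, pad, stride, dilation, ceil_mode):
    # All arithmetic stays in integers, so sizes above 2**24 remain exact.
    numerator = input_size + 2 * pad - dilation * (kernel_size - 1) - 1
    if ceil_mode:
        numerator += stride - 1
    output_size = numerator // stride + 1
    # Ensure the last pooling window starts inside the (padded) input.
    if ceil_mode and (output_size - 1) * stride >= input_size + pad:
        output_size -= 1
    return output_size
```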
Fixes #13386.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14405
Differential Revision: D13215260
Pulled By: soumith
fbshipit-source-id: 802588ce6cba8db6c346448c3b3c0dac14d12b2d
Summary:
Addresses #13324: torch.nn.utils.rnn.pack_padded_sequence segfaults if the
lengths are not in decreasing order.
We were seeing this segfault when the error was thrown; pre-emptively checking
the input avoids this:
*** Error in `/home/bvaughan/anaconda3/bin/python': double free or corruption (!prev): 0x00005555566e7510 ***
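For context, a minimal example of the usage the new check guards (shapes and lengths chosen for illustration; at the time of this PR the lengths had to be sorted in decreasing order):
```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

seqs = torch.zeros(3, 5, 8)          # (batch, max_time, features)
lengths = [5, 3, 2]                  # must be in decreasing order
packed = pack_padded_sequence(seqs, lengths, batch_first=True)

# Unsorted lengths, e.g. [3, 5, 2], now raise a clear error instead of
# crashing inside the C++ implementation.
```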
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13933
Differential Revision: D13090389
Pulled By: nairbv
fbshipit-source-id: 6f6b319e74cb55830be799e9c46bc33aa59256d8
Summary:
This includes everything in nn.yaml except for convolutions, multi_margin_loss, multi_label_margin_loss, nll_loss, and nll_loss2d.
Note that scalar_check False just means we don't do any extra scalar checks (we could elide this from the generated code, which I may do in a later commit).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13906
Reviewed By: ezyang
Differential Revision: D13044507
Pulled By: gchanan
fbshipit-source-id: ebd3bdca2bcf512ca44de1ce3be81946f6c0828e
Summary:
This enables the distributions and utils test sets for ROCm.
Individual tests are enabled that now pass due to fixes in HIP/HCC/libraries versions in white rabbit.
For attention: bddppq ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13166
Differential Revision: D12814759
Pulled By: bddppq
fbshipit-source-id: ea70e775c707d7a8d2776fede6154a755adef43e
Summary:
- Move batch norm from TH(CU)NN to native
- Speedups in many cases (e.g. #12006) for CUDA due to the new block/grid layout and Welford-type mean/variance calculations (the latter for training mode; see the sketch after this list)
- It splits the forward kernel into two pieces and reuses the evaluation kernel for the transformation.
- We change the meaning of save_mean and save_invstd (aka save_var) to accscalar to maintain reasonable precision.
Compared to the ill-fated #12368
- I changed the CPU kernel to not call `.sum()` from within parallel for. This seemed to have caused the breakage (NaN-results) in TestModels.test_dcgan_netG (thank you houseroad for the repro, errors in assessment of the fix are my own)
- I updated the Half->Float upcasting in tensors to go through `t.type().scalarType()` instead of `t.dtype()`.
- I have merged master
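A minimal Python sketch of the Welford-type update mentioned above (for illustration only; the actual kernels accumulate per channel in accscalar precision):
```python
def welford_mean_var(values):
    # Numerically stable single-pass mean/variance (Welford's algorithm).
    mean, m2, n = 0.0, 0.0, 0
    for x in values:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
    var = m2 / n if n > 0 else 0.0   # biased variance, as used for normalization
    return mean, var
```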
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13263
Differential Revision: D12946254
Pulled By: SsnL
fbshipit-source-id: 3bb717ee250fbccaf10afe73722996aa4713d10d
Summary:
Problems with SN and DP after #12671 :
1. In eval mode, `weight_orig` does not get the correct gradient (#12737).
Fix: keep `v` vector around as a buffer and always calculate `W = W_orig / (u @ W_orig @ v)` even in eval.
2. In training mode, the `weight` buffer of the parallelized module is never updated, so if someone touches `weight_orig` and/or `weight` and makes them stop sharing storage, the weight used in `eval` is wrong.
Fix: Make `weight` not a buffer anymore and always calculate it as above.
3. #12671 changed SN to update `u` in-place to make DP work correctly, but that breaks backward through two forwards (e.g., the common GAN loss `D(real) - D(fake)`) because the vectors needed to backprop through the 1st forward are changed in the 2nd forward.
Fix: This PR clones `u` and `v` before using them.
To maintain BC, I added a hook interface for producing and loading the state_dict. This is ugly and we should really have a better interface for spectral_norm, but for the purpose of fixing this issue I made this patch. Even with a better interface, a BC mechanism for loading legacy state_dicts would still be needed.
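A rough standalone sketch of the resulting weight computation (not the exact hook code; `weight_orig`, `u`, `v` stand for the parameter and the two power-iteration buffers):
```python
import torch
import torch.nn.functional as F

def compute_sn_weight(weight_orig, u, v, training, eps=1e-12):
    weight_mat = weight_orig.view(weight_orig.size(0), -1)
    if training:
        with torch.no_grad():
            # one step of power iteration, updating the buffers in place
            v.copy_(F.normalize(torch.mv(weight_mat.t(), u), dim=0, eps=eps))
            u.copy_(F.normalize(torch.mv(weight_mat, v), dim=0, eps=eps))
        # clone so a second forward does not clobber the tensors that the
        # first forward's backward still needs (fix 3 above)
        u, v = u.clone(), v.clone()
    sigma = torch.dot(u, torch.mv(weight_mat, v))
    return weight_orig / sigma   # computed in eval mode too (fixes 1 and 2)
```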
cc crcrpar
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13350
Differential Revision: D12931044
Pulled By: SsnL
fbshipit-source-id: 8be6f934eaa62414d76d2c644dedd7e1b7eb31ef
Summary:
- fixes weights-contiguous requirement for THCUNN Convolutions
- Add tests that conv backward pass works for non-contiguous weights
- fix RNN tests / error messages to be consistent and pass
- relax weight grad precision for fp16 for a particular test
- fix regression of CMAKE_PREFIX_PATH not passing through
- add missing skipIfNoLapack annotations where needed
Differential Revision: D12918456
Pulled By: soumith
fbshipit-source-id: 8642d36bffcc6f2957800d6afa1e10bef2a91d05
Summary:
```
The previous threshold implementation was not vectorized or parallelized.
This speeds up ResNet-50 CPU inference [1] from ~88 ms to ~67 ms
CPU timings:
https://gist.github.com/colesbury/d0d1be6974841d62696dbde329a8fde8
1 thread (before vs. after)
10240: 17.4 us vs. 6.9 µs per loop
102400: 141 us vs. 39.8 µs per loop
16 threads (before vs. after)
10240: 17.4 us vs. 6.7 µs per loop
102400: 141 us vs. 14.3 µs per loop
CUDA timings are not measurably different.
[1]: compiled with MKL-DNN, 8 threads, batch norm merged into convolutions
https://gist.github.com/colesbury/8a64897dae97558b3b82da665048c782
```
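For reference, a rough way to reproduce the single-thread measurement (a timeit sketch under the same sizes; thread count pinned for comparability):
```python
import timeit
import torch

torch.set_num_threads(1)
for n in (10240, 102400):
    x = torch.randn(n)
    # threshold(input, threshold, value): the op this PR vectorizes/parallelizes
    t = timeit.timeit(lambda: torch.nn.functional.threshold(x, 0.0, 0.0), number=1000)
    print(f"{n}: {t / 1000 * 1e6:.1f} us per call")
```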
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13182
Reviewed By: soumith
Differential Revision: D12825105
Pulled By: colesbury
fbshipit-source-id: 557da608ebb87db8a04adbb0d2882af4f2eb3c15
Summary:
- Speed up the case of #12006 in the forward
- The backward still isn't as fast as one might hope (factor 2-3 in the #12006 case).
- More extensive benchmarking shows not so great performance compared
to CuDNN for cases with many channels, e.g. bs=8-128 / c=1024 / f=1024.
- We change the meaning of save_mean and save_invstd (aka save_var) to accscalar to
maintain reasonable precision.
Needless to say, I would happily split the TensorAccessor fixes into a separate PR, as they are unrelated fixes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12368
Differential Revision: D10559696
Pulled By: SsnL
fbshipit-source-id: f0d0d1e0912e17b15b8fb7a2c03d0fe757598419
Summary:
Closes #2119.
There was a small bug where the output_size got sliced with `[-2:]`
where we really meant to slice it as `[2:]` (to remove the batch and
channel dimensions).
Added a new test for this.
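For illustration (sizes made up), the two slices only differ once the full size has more than two spatial dimensions:
```python
output_size = [8, 16, 10, 20, 30]   # (N, C, D, H, W) for a 3-d transposed conv
print(output_size[-2:])  # [20, 30]      -- drops the depth dimension (the bug)
print(output_size[2:])   # [10, 20, 30]  -- strips only batch and channel (intended)
```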
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12952
Differential Revision: D10510678
Pulled By: zou3519
fbshipit-source-id: 4c04a5007fc6d002e1806d6fe981b43d33d6a4f2
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12794
common.py is used as a base module for almost all tests in test/. The
name of this file is so common that it can easily conflict with other dependencies
if they happen to have another common.py in the base module. Rename the file to
avoid the conflict.
Reviewed By: orionr
Differential Revision: D10438204
fbshipit-source-id: 6a996c14980722330be0a9fd3a54c20af4b3d380
Summary:
Module.to uses the Tensor.to parsing facility.
It should not, however, accept "copy" as a keyword/fourth positional
argument.
See #12571 for discussion.
Thank you SsnL for noticing.
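A small usage sketch of what Module.to still accepts after this change (device, dtype, non_blocking, but no copy):
```python
import torch

net = torch.nn.Linear(2, 2)
net.to(torch.float64)            # dtype only
net.to("cpu", torch.float32)     # device and dtype, parsed like Tensor.to
# net.to("cpu", torch.float32, True, True)  # fourth positional would be copy:
#                                           # rejected for modules by this PR
```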
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12617
Differential Revision: D10392053
Pulled By: ezyang
fbshipit-source-id: b67a5def7993189b4b47193abc7b741b7d07512c
Summary:
There were two problems with SN + DP:
1. In SN, the updated _u vector is saved back to module via a `setattr`. However, in DP, everything is run on a replica, so those updates are lost.
2. In DP, the buffers are broadcast via a `broadcast_coalesced`, so on replicas they are all views. Therefore, the `detach_` call won't work.
Fixes are:
1. Update the _u vector in-place so that, thanks to the storage shared between the 1st replica and the parallelized module, the update is retained (see the sketch after this list)
2. Do not call `detach_`.
3. Added comments in SN about the subtlety.
4. Added a note to the DP doc on this particular behavior of DP.
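A minimal sketch of fix 1 (standalone illustration; `u` and `u_new` stand for the broadcast buffer on the first replica and the freshly computed power-iteration vector):
```python
import torch

u = torch.randn(64)      # buffer broadcast to replica 0 (shares storage)
u_new = torch.randn(64)  # result of one power-iteration step

# Before: setattr(module, 'weight_u', u_new) rebinds the attribute on the
# replica only, so the parallelized module never sees the update.
# After: write through the shared storage instead.
with torch.no_grad():
    u.copy_(u_new)
```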
cc crcrpar taesung89 yaoshengfu
Fixes https://github.com/pytorch/pytorch/issues/11476
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12671
Differential Revision: D10410232
Pulled By: SsnL
fbshipit-source-id: c447951844a30366d8c196bf9436340e88f3b6d9
Summary:
Add dtype argument to softmax/log_softmax functions.
Computing softmax in fp32 precision is necessary for mixed precision training; converting the output of the previous layer into fp32 and then reading it as fp32 in softmax is expensive memory- and performance-wise, and this PR allows one to avoid it.
For most input data/dtype combinations, input data is converted to dtype and then softmax is computed. If input data is half type and dtype is fp32, kernels with the corresponding template arguments are called.
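A usage sketch of the new argument (a CUDA device is assumed here only to match the half-precision use case):
```python
import torch
import torch.nn.functional as F

logits = torch.randn(32, 1000, dtype=torch.half, device="cuda")
# Accumulate and return in fp32 without materializing an fp32 copy of the input.
log_probs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
probs = F.softmax(logits, dim=-1, dtype=torch.float32)
```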
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11719
Reviewed By: ezyang
Differential Revision: D10175514
Pulled By: zou3519
fbshipit-source-id: 06d285af91a0b659932236d41ad63b787eeed243
Summary:
Obviously, the grads of the conv weight and conv input are not relevant to the bias, but the original `convXd_input` and `convXd_weight` methods receive a `bias` parameter. What's more, while the doc says `bias` should have the shape `(out_channels,)`, one gets a `RuntimeError` if bias != None and in_channels != out_channels, because the weight of a transposed conv has the shape `(in_channels, out_channels, kH, kW)` while the weight of a vanilla conv has the shape `(out_channels, in_channels, kH, kW)`:
```
RuntimeError: Given transposed=1, weight of size [channel1, channel2, kH, kW], expected bias to be 1-dimensional with channel2 elements, but got bias of size [channel1] instead
```
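After this change the gradient helpers can be called without a bias argument; a small sketch (shapes chosen for illustration):
```python
import torch
from torch.nn.grad import conv2d_input, conv2d_weight

x = torch.randn(1, 4, 8, 8)
w = torch.randn(6, 4, 3, 3)
grad_out = torch.randn(1, 6, 6, 6)   # output shape of conv2d(x, w)

grad_x = conv2d_input(x.shape, w, grad_out)   # no bias needed
grad_w = conv2d_weight(x, w.shape, grad_out)
```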
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12281
Differential Revision: D10217370
Pulled By: ezyang
fbshipit-source-id: bc00b439e5ae539276a5e678bdb92af700197bb2
Summary:
- fixes https://github.com/pytorch/pytorch/issues/10723
- migrate PReLU to ATen and deprecate legacy PReLU
- performance:
CPU with weight.numel() = 1
```
>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
100 loops, best of 100: 9.43 ms per loop
>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
10 loops, best of 100: 24.4 ms per loop
>>> m = nn.PReLU()
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 695 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 2.47 ms per loop
```
CPU with weight.numel() = channels
```
>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 603 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 13.3 ms per loop
>>> m = nn.PReLU(100)
>>> x = torch.randn(100, 100, 100, requires_grad=True)
>>> %timeit -r 100 y = m(x)
1000 loops, best of 100: 655 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 y.backward(retain_graph=True)
100 loops, best of 100: 2.45 ms per loop
```
CUDA with weight.numel() = 1
```
>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 187 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.01 ms per loop
>>> m = nn.PReLU().cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
1000 loops, best of 100: 195 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.28 ms per loop
```
CUDA with weight.numel() = channels
```
>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
1000 loops, best of 100: 174 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.27 ms per loop
>>> m = nn.PReLU(100).cuda()
>>> x = torch.randn(100, 100, 100, requires_grad=True).cuda()
>>> %timeit -r 100 torch.cuda.synchronize(); y = m(x); torch.cuda.synchronize();
10000 loops, best of 100: 181 µs per loop
>>> y = m(x).sum()
>>> %timeit -r 100 torch.cuda.synchronize(); y.backward(retain_graph=True); torch.cuda.synchronize();
100 loops, best of 100: 2.26 ms per loop
```
The huge performance regression in CPU when weight.numel() = 1 is addressed by replacing at::CPU_tensor_apply* with parallelized kernels.
ezyang SsnL zou3519 soumith
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11758
Differential Revision: D9995799
Pulled By: weiyangfb
fbshipit-source-id: d289937c78075f46a54dafbde92fab0cc4b5b86e
Summary:
This PR vectorizes the CPU grid sample 2d forward and backward kernels. Specifically,
1. add `.data()` in `TensorAccessor`
2. support non-void return value for declaring CPU kernel stub
3. add `bool at::geometry_is_contiguous(IntList sizes, IntList strides)`
4. The following vectorized CPU primitives are added:
+ `gather<scale>(baseaddr, vindex)`: `result[i] = baseaddr[vindex[i] * scale]`
+ `mask_gather<scale>(src, baseaddr, vindex, mask)`: `result[i] = mask[i] ? baseaddr[vindex[i] * scale] : src[i]`.
+ comparison ops
+ binary logical ops
+ `min(a, b)`
+ `cast<dst_t, src_t>(src_vec)`: changing dtype but keeping the bit representation
+ `blendv(a, b, mask)`: `result[i] = mask[i] ? b[i] : a[i]`.
+ ctor with multiple values (i.e., `setr`)
+ `arange(start = 0, step = 1)`: constructs a vector with values specified by the arange parameters
+ `convert_to_int_of_same_size(vec)`: convert floating point vector to corresponding integral type of same size
+ `interleave2(a, b)` & `deinterleave2(x, y)`: interleave or deinterleave two vectors. E.g., for `interleave`:
```
inputs:
{a0, a1, a2, a3, a4, a5, a6, a7}
{b0, b1, b2, b3, b4, b5, b6, b7}
outputs:
{a0, b0, a1, b1, a2, b2, a3, b3}
{a4, b4, a5, b5, a6, b6, a7, b7}
```
5. Grid sample CPU kernel implementations are described in the following note (also in `GridSampleKernel.cpp`):
```
NOTE [ Grid Sample CPU Kernels ]
Implementation of vectorized grid sample CPU kernels is divided into three
parts:
1. `ComputeLocation` struct
Transforms grid values into interpolation locations of the input tensor
for a particular spatial dimension, based on the size of that dimension
in the input tensor and the padding mode.
```
```cpp
template<typename scalar_t, GridSamplerPadding padding>
struct ComputeLocation {
using Vec = Vec256<scalar_t>;
// ctor
ComputeLocation(int64_t size);
// Given grid values `in`, return the interpolation locations after
// un-normalization and padding mechanism (elementwise).
Vec apply(const Vec &in) const;
// Similar to `apply`, but also returns `d apply(in) / d in`
// (elementwise).
// this is often used in gradient computation.
std::pair<Vec, Vec> apply_get_grad(const Vec &in) const;
};
```
```
2. `ApplyGridSample` struct
Owns N `ComputeLocation` structs, where N is the number of spatial
dimensions. Given N input grid vectors (one for each spatial dimension)
and spatial offset, it gets the interpolation locations from
`ComputeLocation`s, applies interpolation procedure, and then writes to
the output (or grad_input & grad_grid in backward).
```
```cpp
template<typename scalar_t, int spatial_dim,
GridSamplerInterpolation interp,
GridSamplerPadding padding>
struct ApplyGridSample {
// ctor
ApplyGridSample(const TensorAccessor<scalar_t, 4>& input);
// Applies grid sampling (forward) procedure:
// 1. computes interpolation locations from grid values `grid_x` and
// `grid_y`,
// 2. interpolates output values using the locations and input data
// in `inp_slice`, and
// 3. writes the first `len` values in the interpolated vector to
// `out_slice` with spatial offset being `offset`.
//
// This assumes that `grid_x` and `grid_y` all contain valid grid
// values \in [-1, 1], even at indices greater than `len`.
//
// The `*_slice` argument names mean samples within a batch (i.e.,
// with the batch dimension sliced out).
void forward(TensorAccessor<scalar_t, 3>& out_slice,
const TensorAccessor<scalar_t, 3>& inp_slice,
int64_t offset, const Vec& grid_x, const Vec& grid_y,
int64_t len) const;
// Applies grid sampling (backward) procedure. Arguments semantics
// and strategy are similar to those of `forward`.
void backward(TensorAccessor<scalar_t, 3>& gInp_slice,
TensorAccessor<scalar_t, 3>& gGrid_slice,
const TensorAccessor<scalar_t, 3>& gOut_slice,
const TensorAccessor<scalar_t, 3>& inp_slice,
int64_t offset, const Vec& grid_x, const Vec& grid_y,
int64_t len) const;
};
```
```
3. `grid_sample_2d_grid_slice_iterator` function
Among the tensors we work with, we know that the output tensors are
contiguous (i.e., `output` in forward, and `grad_input` & `grad_grid` in
backward), we need to randomly read `input` anyways, and `grad_output`
usually comes from autograd and is often contiguous. So we base our
iterating strategy on the geometry of grid.
The `grid_sample_2d_grid_slice_iterator` function provides an abstraction to
efficiently iterate through a `grid` slice (without the batch dimension).
See comments of that function on the specific cases and strategies used.
```
```cpp
template<typename scalar_t, typename ApplyFn>
void grid_sample_2d_grid_slice_iterator(
const TensorAccessor<scalar_t, 3>& grid_slice,
const ApplyFn &apply_fn);
// `apply_fn` is a function/lambda that can be called as if it has
// declaration:
// void apply_fn(const Vec256<scalar_t>& grid_x,
// const Vec256<scalar_t>& grid_y,
// int64_t spatial_offset, int64_t len);
```
```
`apply_fn` will be called multiple times, and together cover the entire
output spatial space. Therefore, e.g., to implement forward 2d grid
sample, we can do
```
```cpp
ApplyGridSample<scalar_t, 2, interp, padding> grid_sample(input_accessor);
for (int n = 0; n < input_accessor.size(0); n++) {
grid_sample_2d_grid_slice_iterator(
grid_accessor[n],
[&](const Vec256<scalar_t>& grid_x, const Vec256<scalar_t>& grid_y,
int64_t spatial_offset, int64_t len) {
grid_sample.forward(out_accessor[n], input_accessor[n],
spatial_offset, grid_x, grid_y, len);
});
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10980
Differential Revision: D9564867
Pulled By: SsnL
fbshipit-source-id: 5b7c3c7ea63af00eec230ae9ee1c3e6c6c9679b4
Summary:
Add the gpu kernel version.
The parallelism I went with performs poorly when there are a large number of vectors, but they're all short, as I don't allocate the thread pool to wrap in that case.
Test Plan
---------
```
python -m unittest test_torch.TestTorch.test_pdist_{empty,scipy} test_nn.TestNN.test_pdist{,_zeros,_empty_row,_empty_col,_cpu_gradgrad_unimplemented,_cuda_gradgrad_unimplemented} test_jit.TestJitGenerated.test_nn_pdist
```
Current performance specs are a little underwhelming; I'm in the process of debugging.
size | torch | torch cuda | scipy
-----|-------|------------|------
16 x 16 | 9.13 µs ± 3.55 µs | 9.86 µs ± 81.5 ns | 15.8 µs ± 1.2 µs
16 x 1024 | 15 µs ± 224 ns | 9.48 µs ± 88.7 ns | 88.7 µs ± 8.83 µs
1024 x 16 | 852 µs ± 6.03 µs | 7.84 ms ± 6.22 µs | 4.7 ms ± 166 µs
1024 x 1024 | 34.1 ms ± 803 µs | 11.5 ms ± 6.24 µs | 273 ms ± 6.7 ms
2048 x 2048 | 261 ms ± 3.5 ms | 77.5 ms ± 41.5 µs | 2.5 s ± 97.6 ms
4096 x 4096 | 2.37 s ± 154 ms | 636 ms ± 2.97 µs | 25.9 s ± 394 ms
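For reference, a minimal usage sketch of the op being benchmarked (size matches one row of the table; drop `device="cuda"` for the CPU column):
```python
import torch

x = torch.randn(1024, 1024, device="cuda")
d = torch.pdist(x)   # condensed pairwise Euclidean distances
assert d.shape == (1024 * 1023 // 2,)
```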
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11102
Differential Revision: D9697305
Pulled By: erikbrinkman
fbshipit-source-id: 2b4f4b816c02b3715a85d8db3f4e77479d19bb99
Summary:
* purge hcSPARSE now that rocSPARSE is available
* integrate a custom hcc and HIP
* hcc brings two important compiler fixes (fixes hundreds of unit tests)
* HIP brings a smart dispatcher that allows us to avoid a lot of static_casts (we haven't yet removed the automatic static_casts but this catches some occurrences the script did not catch)
* mark 5 unit tests skipping that have regressed w/ the new hcc (we don't know yet what is at fault)
* optimize bitonic sort - the comparator is always an empty struct - therefore passing it by value saves at least 3 bytes. It also removes an ambiguity around passing references to `__global__` functions
Pull Request resolved: https://github.com/pytorch/pytorch/pull/11198
Differential Revision: D9652340
Pulled By: ezyang
fbshipit-source-id: f5af1d891189da820e3d13b7bed91a7a43154690