Description:
1. Quantize Linear Layer Weights to 4-bits:
Quantize the weights of the Linear layer to 4 bits, using symmetric quantization.
Pack two 4-bit weights into one uint8 container.
Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.
2. Prepare Quantized Weights, Scales, and Optional Bias:
After quantizing, obtain the quantized_weights, scales, and groupsize.
If the original Linear layer has a bias, prepare it as well.
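As an illustration of steps 1 and 2, a minimal sketch of group-wise symmetric 4-bit quantization with nibble packing (function name and the nibble order are assumptions for illustration, not the kernel's actual layout):
```python
import torch

def quantize_4bit_groupwise(weight: torch.Tensor, groupsize: int = 32):
    # weight: [out_features, in_features] float tensor; groupsize divides in_features
    out_features, in_features = weight.shape
    w = weight.reshape(out_features, in_features // groupsize, groupsize)

    # Symmetric quantization: one scale per group, values mapped into [-8, 7]
    scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7.0
    q = torch.clamp(torch.round(w / scales), -8, 7).to(torch.int8)
    q = q.reshape(out_features, in_features)

    # Pack two 4-bit values per uint8 (assumed order: even column in the low nibble)
    q_u = (q + 8).to(torch.uint8)   # shift into the unsigned nibble range [0, 15]
    packed = (q_u[:, 0::2] & 0xF) | (q_u[:, 1::2] << 4)
    return packed, scales.reshape(out_features, in_features // groupsize)
```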
3. Pack the Weights Efficiently:
Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.
```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```
Input parameters should include:
in_features and out_features (the same as the Linear layer’s corresponding parameters).
4. Perform Dynamic Quantized Matrix Multiplication:
Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights.
```python
output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights, groupsize, in_features, out_features)
```
Inputs required include:
The input tensor, packed_weights, groupsize, and the in_features and out_features.
API Usage: https://github.com/pytorch/pytorch/issues/143289
Model Perf:
7B Transformer model:
  Prefill: 340 t/s
  Decode: 40 t/s
2B Transformer model:
  Prefill: 747 t/s
  Decode: 80 t/s
Tests:
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s
OK
python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s
OK
python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s
Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
Fixes #137936
The PR contains:
* Fix for `matmul_offline_tunableop`
* Clean-up try-finally blocks in UTs that don't use environment variables (`test_validator_tunableop_rocm`, `test_minimum_tuning_iteration_tunableop`, `test_disable_tuning_tunableop`)
* Avoid the use of environment variables in `minimum_tuning_iteration_tunableop`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143322
Approved by: https://github.com/jeffdaily
TunableOp's rotating buffer feature cannot be properly tested because the environment variable that controls this feature is sticky. A Python API is introduced to modify this value.
Additional items in this PR:
* UT for rotating buffer API
* Clean up UTs that were setting the rotating buffer via the environment variable
* Align behavior of environment variable and Python API when a negative value (< 0) is set.
* Update documentation.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143172
Approved by: https://github.com/jeffdaily
Fixes #141652
This PR fixes (at least in part) the unit test failure. However, we may also need to do a separate flush of the untuned results; if this test continues to be flaky, another PR will be needed to flush the untuned results as well.
Tested locally and it seems to be working.
Also fixing code that was accidentally commented out in the unit test from the prior multi-GPU offline tuning PR https://github.com/pytorch/pytorch/pull/139673
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142269
Approved by: https://github.com/jeffdaily
Fixes #141636 Fixes #141635 Fixes #141458
Changes include:
- TunableOp filename that wasn't set properly
- Activate numerical check (see additional test comment)
- Entire test in try-finally clause to avoid OS environment variable leakage (see additional test comment)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142089
Approved by: https://github.com/jeffdaily
Co-authored-by: Jeff Daily <jeff.daily@amd.com>
This PR enhances offline tuning to support multi-GPUs.
High-level description of algorithm:
- Duplicate GEMMs are first eliminated
- GEMMs are distributed to multi-GPUs for tuning
- Results are gathered into a file with `_full` in the filename
Also adding support for GemmAndBias and ScaledGemm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/139673
Approved by: https://github.com/jeffdaily, https://github.com/hongxiayang
Expands the `test_linalg_qr_autograd_errors` unit test to check all cases of differentiability/non-differentiability as given in the docs https://pytorch.org/docs/stable/generated/torch.linalg.qr.html:
- mode='reduced' (default): Returns (Q, R) of shapes (*, m, k), (*, k, n) respectively. It is always differentiable.
- mode='complete': Returns (Q, R) of shapes (*, m, m), (*, m, n) respectively. It is differentiable for m <= n.
- mode='r': Computes only the reduced R. Returns (Q, R) with Q empty and R of shape (*, k, n). It is never differentiable.
(in particular, the happy paths are added)
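For reference, a small illustration of the three modes and their differentiability (shapes as in the docs quoted above):
```python
import torch

A = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)  # m=3, n=5, k=min(m, n)=3

Q, R = torch.linalg.qr(A, mode='reduced')   # Q: (3, 3), R: (3, 5); always differentiable
Q, R = torch.linalg.qr(A, mode='complete')  # Q: (3, 3), R: (3, 5); differentiable since m <= n
(Q @ R).sum().backward()

Q, R = torch.linalg.qr(A, mode='r')         # Q is empty, R: (3, 5); never differentiable
assert Q.numel() == 0
```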
Pull Request resolved: https://github.com/pytorch/pytorch/pull/135097
Approved by: https://github.com/IvanYashchuk, https://github.com/nikitaved
This PR adds:
- composable_kernel as a third_party submodule
- "ck" as a `torch.backends.cuda.preferred_linalg_library()`
- reference CK gemm implementations for float, bfloat16, and half types
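For example, on a ROCm build with the CK submodule compiled in, the backend could be selected as follows (a usage sketch, not taken from the PR):
```python
import torch

# Hypothetical usage on a ROCm build with CK enabled: select composable_kernel
# as the preferred linalg backend, then query the current setting back.
torch.backends.cuda.preferred_linalg_library("ck")
print(torch.backends.cuda.preferred_linalg_library())
```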
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131004
Approved by: https://github.com/xw285cornell, https://github.com/pruthvistony
Co-authored-by: Andres Lugo <Andy.LugoReyes@amd.com>
Co-authored-by: Pruthvi Madugundu <pruthvigithub@gmail.com>
The test is failing (flakily?) on periodic Windows CUDA jobs with the following error:
```
__________ TestLinalgCUDA.test_matmul_offline_tunableop_cuda_float16 __________
Traceback (most recent call last):
File "C:\actions-runner\_work\pytorch\pytorch\test\test_linalg.py", line 4618, in test_matmul_offline_tunableop
os.remove(filename)
PermissionError: [WinError 32] The process cannot access the file because it is being used by another process: 'tunableop_untuned0.csv'
```
For example, https://github.com/pytorch/pytorch/actions/runs/11292745299/job/31410578167#step:15:15097
The test tried to catch and ignore this, but this is Windows. So, the fix is to:
1. Ignore if these files couldn't be removed
2. Write them to a temp directory instead, otherwise, [assert_git_not_dirty](https://github.com/pytorch/pytorch/blob/main/.ci/pytorch/test.sh#L286) won't be happy
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137835
Approved by: https://github.com/atalman
When TunableOp is enabled, it is easy to hit OOM since the application usually needs a large amount of GPU memory, for example when running an LLM for inference. So we need an offline mode to tune the GEMMs. This PR provides an offline mode for TunableOp:
- Record untuned GEMMs to a file.
- Add a Python API named `tune_gemm_in_file` that reads the untuned file and tunes the GEMMs in it (see the sketch below).
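A hedged sketch of the intended flow (the environment-variable names and the `torch.cuda.tunable` location of the API are assumptions based on the TunableOp docs; adjust for your version):
```python
import os

# Recording run: enable TunableOp, disable on-the-fly tuning, record untuned GEMMs.
# The variable names below are assumptions based on the TunableOp docs.
os.environ["PYTORCH_TUNABLEOP_ENABLED"] = "1"
os.environ["PYTORCH_TUNABLEOP_TUNING"] = "0"
os.environ["PYTORCH_TUNABLEOP_RECORD_UNTUNED"] = "1"

import torch

# ... run the inference workload; untuned GEMMs are written to a CSV such as
# 'tunableop_untuned0.csv'.

# Offline tuning run: tune the recorded GEMMs without the model resident in memory.
torch.cuda.tunable.tune_gemm_in_file("tunableop_untuned0.csv")
```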
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128813
Approved by: https://github.com/jeffdaily, https://github.com/hongxiayang, https://github.com/naromero77amd
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Related to #107302.
When built and tested with NumPy 2 the following unit tests failed.
```
=========================================================== short test summary info ============================================================
FAILED [0.0026s] test/test_linalg.py::TestLinalgCPU::test_householder_product_cpu_complex128 - TypeError: expected np.ndarray (got Tensor)
FAILED [0.0024s] test/test_linalg.py::TestLinalgCPU::test_householder_product_cpu_complex64 - TypeError: expected np.ndarray (got Tensor)
FAILED [0.0025s] test/test_linalg.py::TestLinalgCPU::test_householder_product_cpu_float32 - TypeError: expected np.ndarray (got Tensor)
FAILED [0.0024s] test/test_linalg.py::TestLinalgCPU::test_householder_product_cpu_float64 - TypeError: expected np.ndarray (got Tensor)
FAILED [0.0016s] test/test_linalg.py::TestLinalgCPU::test_nuclear_norm_axes_small_brute_force_old_cpu - ValueError: Unable to avoid copy while creating an array as requested.
FAILED [0.0054s] test/test_linalg.py::TestLinalgCPU::test_solve_cpu_complex128 - AssertionError: The values for attribute 'shape' do not match: torch.Size([0, 0]) != torch.Size([0, 0, 0]).
FAILED [0.0055s] test/test_linalg.py::TestLinalgCPU::test_solve_cpu_complex64 - AssertionError: The values for attribute 'shape' do not match: torch.Size([0, 0]) != torch.Size([0, 0, 0]).
FAILED [0.0048s] test/test_linalg.py::TestLinalgCPU::test_solve_cpu_float32 - AssertionError: The values for attribute 'shape' do not match: torch.Size([0, 0]) != torch.Size([0, 0, 0]).
FAILED [0.0054s] test/test_linalg.py::TestLinalgCPU::test_solve_cpu_float64 - AssertionError: The values for attribute 'shape' do not match: torch.Size([0, 0]) != torch.Size([0, 0, 0]).
=========================================== 9 failed, 1051 passed, 118 skipped in 152.51s (0:02:32) ============================================
```
This PR fixes them. The test is now compatible with both NumPy 1 & 2.
Some more details:
1. `np.linalg.solve` has changed its behavior, so I added an adapter function in the unit test to keep its behavior the same whether it is NumPy 1 or NumPy 2.
2. The cause of the failure is that when passing a `torch.Tensor` to `np.linalg.qr`, the return type in NumPy 1 is `(np.ndarray, np.ndarray)`, while it is `(torch.Tensor, torch.Tensor)` in NumPy 2.
3. NumPy 2 does not allow `np.array(obj, copy=False)`; it recommends using `np.asarray(obj)` instead (see the example below).
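Illustrating point 3:
```python
import numpy as np

obj = [1.0, 2.0, 3.0]   # converting a Python list always requires a copy
arr = np.asarray(obj)   # works on both NumPy 1 and NumPy 2

# NumPy 1 treated copy=False as "avoid a copy if possible"; NumPy 2 instead raises
# "ValueError: Unable to avoid copy while creating an array as requested."
# arr = np.array(obj, copy=False)
```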
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136800
Approved by: https://github.com/lezcano
Reland of #128143 but added `alpha` and `bias` initialization to `launchTunableGemmAndBias`
Thus far TunableOp was implemented for gemm, bgemm, and scaled_mm. gemm_and_bias was notably missing. This PR closes that gap.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128919
Approved by: https://github.com/malfet
Summary:
When TunableOp loads selected kernels from a CSV file, it validates the hipblaslt version defined in hipblaslt-version.h.
This PR changes the validator to fetch the hipblaslt version and revision from the hipblaslt runtime instead of the header file, since in our environment we might roll out a new version of the runtime before updating the header file fleet-wide.
Test Plan:
Verified generated tunableops kernel selection has the correct hipblaslt version from runtime:
```
Validator,PT_VERSION,2.5.0
Validator,ROCBLAS_VERSION,4.0.0-72e57364-dirty
Validator,HIPBLASLT_VERSION,800-bf2c3184
Validator,ROCM_VERSION,6.0.0.0-12969-1544e39
Validator,GCN_ARCH_NAME,gfx942:sramecc+:xnack-
GemmTunableOp_BFloat16_TN,tn_8192_2_3584,Gemm_Hipblaslt_TN_572,0.0240676
GemmTunableOp_BFloat16_TN,tn_7168_2_8192,Gemm_Hipblaslt_TN_482,0.0359019
GemmTunableOp_BFloat16_TN,tn_8192_2_1024,Default,0.0173723
GemmTunableOp_BFloat16_TN,tn_1280_2_8192,Gemm_Hipblaslt_TN_491,0.0191047
```
Differential Revision: D59889043
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131078
Approved by: https://github.com/jeffdaily, https://github.com/xw285cornell
Enables a few extra ruff rules, most of which do not have any violations since I already cleaned them up in earlier PRs; these changes just turn the rules on to enforce them. Adds one `noqa`, as we want the suboptimal lambda generation + call kept as a test. Also enables the test in flake8.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130700
Approved by: https://github.com/justinchuby, https://github.com/ezyang
This PR updates the input `weight` of `_convert_weight_to_int4pack` from `[n][k] int32` to `[n][k / 2] uint8` for CPU, CUDA, and MPS, which helps decouple the int4 model checkpoint from different ISAs and platforms in `gpt-fast`. The advantage is that an int4 model checkpoint can be shared across test machines without re-generating it on one specific platform. Meanwhile, the size of the input `weight` is reduced to `1 / 8`.
Before this PR, the packed weight was stored in a CUDA-specific layout: `[n/8][k/(InnerKTiles*16)][32][InnerKTiles/2]`, dtype int32, where InnerKTiles = 2, 4, 8. The CPU packed weight was viewed as the SAME shape but stored in a different layout: `[n/64][k][32]`, dtype uint8. The weight was strongly coupled with platforms (CPU/CUDA) and ISAs (AVX512/AVX2/scalar), and users could not reuse a generated weight on a different ISA or platform, because the compute format differs when loading the weight onto a device.

Now, we use a common serialized layout (`[n][k/2] uint8`) across devices and ISAs as the input `weight` of `_convert_weight_to_int4pack`, and each backend chooses how to interpret it as its compute layout.
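A hedged sketch of converting an old `[n][k] int32` checkpoint (4-bit values) into the new serialized `[n][k/2] uint8` layout; the nibble order shown is an assumption, and the authoritative packing is whatever `_convert_weight_to_int4pack` consumes:
```python
import torch

def to_serialized_int4(weight_int32: torch.Tensor) -> torch.Tensor:
    # weight_int32: [n, k] int32 tensor holding 4-bit values in [0, 15], k even
    w = weight_int32.to(torch.uint8)
    # Pack two 4-bit values per byte -> [n, k // 2] uint8
    # (assumed order: even column in the high nibble, odd column in the low nibble)
    packed = (w[:, ::2] << 4) | (w[:, 1::2] & 0xF)
    return packed

# The serialized tensor is then what _convert_weight_to_int4pack consumes on each
# device/ISA, converting it to that backend's compute layout.
```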

### Performance
Intel (R) Xeon (R) CPU Max 9480, single socket (56 cores)
There is no obvious regression of this PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129940
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/mingfeima
- Add AMD support for int4 kernel
- Only supports CDNA2 and CDNA3 gpus for now
- Uses `mfma_f32_16x16x16bf16` instruction for matrix multiply
- Uses `v_and_or_b32` instruction and `__hfma2` intrinsic for unpacking bf16 values
- Enable hipify for `__nv_bfloat16` and `__nv_bfloat162` data types
- Enable int4 unit tests for CDNA2 and CDNA3 AMD gpus
- Fix torchscript issues due to hipify for `__nv_bfloat16` type
- TorchScript has its own implementation for bfloat16 type
- Implemented in the `__nv_bfloat16` structure at [resource_strings.h](https://github.com/pytorch/pytorch/blob/main/torch/csrc/jit/codegen/fuser/cuda/resource_strings.h)
- So, we shouldn't hipify any reference of `__nv_bfloat16` in the torchscript implementation
- Hence moved the `__nv_bfloat16` direct references in `codegen.cpp` and `cuda_codegen.cpp` to `resource_strings.h` which is already exempted from hipify
Fixes #124699
Fixes pytorch-labs/gpt-fast/issues/154
Co-authored-by: Nikita Shulga <2453524+malfet@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129710
Approved by: https://github.com/malfet
In NVIDIA internal testing on slower devices such as Orin NX, with large dtypes like complex128, test_linalg_solve_triangular_large takes multiple hours to complete and times out CI. This PR adds a slowTest marker so it can be skipped due to speed issues. cc @nWEIdia
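For reference, a minimal sketch of how the marker is applied (class and test names here are illustrative, not the actual test):
```python
from torch.testing._internal.common_utils import TestCase, run_tests, slowTest

class TestExample(TestCase):
    @slowTest  # skipped unless slow tests are enabled, e.g. PYTORCH_TEST_WITH_SLOW=1
    def test_solve_triangular_large(self):
        ...

if __name__ == "__main__":
    run_tests()
```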
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129903
Approved by: https://github.com/lezcano
We fix a number of bugs previously present in the complex implementation. We also heavily simplify the implementation, using, among other things, the fact that we now have conjugate views.
I saw there is a comment regarding how slow some checks on this function are. As such, I removed quite a few of the combinations of inputs to make the OpInfo lighter. I still left a couple of relevant examples so as not to regress coverage.
Fixes https://github.com/pytorch/pytorch/issues/122188
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125580
Approved by: https://github.com/pearu, https://github.com/peterbell10
Summary: This kernel is special-cased on ARM because it's important for LLMs, so let's have test coverage.
Test Plan: Ran locally and it passes. Intentionally broke fp16_gemv_trans and saw it fail, confirming it provides coverage.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/126700
Approved by: https://github.com/malfet
- Implement a very straightforward Metal copy of CPU int4mm kernel
- Implement int8mm kernel by constructing a graph consisting of upcast, transpose and mm
- Add `isCapturing`, `isCaptureEnabled`, `startCapture` and `stopCapture` methods to `MPSProfiler`, which can be used to help one debug/profile Metal kernels by wrapping the calls with the following
```cpp
if (getMPSProfiler().isCaptureEnabled()) {
  getMPSProfiler().startCapture(__func__, mpsStream);
}
...
if (getMPSProfiler().isCapturing()) {
  getMPSProfiler().stopCapture(mpsStream);
}
```
which, if invoked with the `MTL_CAPTURE_ENABLED` environment variable set to one, will produce .gputrace files in the current working directory that can later be loaded and used to debug or profile the kernel.
<img width="1093" alt="image" src="https://github.com/pytorch/pytorch/assets/2453524/a2bf27e8-df8a-442c-a525-1df67b8a376a">
- Added `test_int4mm` to TestLinalgMPS, which is mostly copy-n-paste of the test from `test_linalg`
TODOs:
- Add weight pack
- Perf-tune both kernels
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125163
Approved by: https://github.com/mikekgfb
Following the example of PyTorch supporting a preferred Linalg library (cusolver or magma), this PR introduces a preferred blas library selector of either cublas or cublaslt for CUDA and hipblas or hipblaslt for ROCm via normal hipification of sources.
The default blas implementation remains cublas or hipblas. cublaslt or hipblaslt can be enabled using environment variable TORCH_BLAS_PREFER_CUBLASLT=1 (or TORCH_BLAS_PREFER_HIPBLASLT=1 as an alias) or by calling `torch.backends.cuda.preferred_blas_library(backend="cublaslt")` or as an alias `backend="hipblaslt"`.
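Both selection mechanisms, as a short usage sketch:
```python
import os
# Option 1: environment variable, read at startup
os.environ["TORCH_BLAS_PREFER_CUBLASLT"] = "1"   # TORCH_BLAS_PREFER_HIPBLASLT=1 is an alias on ROCm

import torch
# Option 2: runtime selector ("hipblaslt" is an alias on ROCm builds)
torch.backends.cuda.preferred_blas_library(backend="cublaslt")
```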
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122106
Approved by: https://github.com/lezcano
This replaces a bunch of unnecessary lambdas with the operator package. This is semantically equivalent, but the operator package is faster, and arguably more readable. When the FURB rules are taken out of preview, I will enable it as a ruff check.
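For illustration (not taken from the PR's diff), the kind of substitution involved:
```python
import operator

pairs = [("b", 2), ("a", 1)]

# Before: an ad-hoc lambda
print(sorted(pairs, key=lambda p: p[1]))

# After: the equivalent (and slightly faster) operator callable
print(sorted(pairs, key=operator.itemgetter(1)))
```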
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116027
Approved by: https://github.com/malfet
Disable `test__int_mm` for sm90 or later, where it currently fails with:
```
python test/test_linalg.py -k test__int_mm_k_32_n_32_use_transpose_a_False_use_transpose_b_False_cuda
_ TestLinalgCUDA.test__int_mm_k_32_n_32_use_transpose_a_False_use_transpose_b_False_cuda _
Traceback (most recent call last):
  File "/usr/lib/python3.10/unittest/case.py", line 59, in testPartExecutor
    yield
  File "/usr/lib/python3.10/unittest/case.py", line 591, in run
    self._callTestMethod(testMethod)
  File "/usr/lib/python3.10/unittest/case.py", line 549, in _callTestMethod
    method()
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_utils.py", line 2410, in wrapper
    method(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_utils.py", line 2410, in wrapper
    method(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_device_type.py", line 428, in instantiated_test
    raise rte
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_device_type.py", line 415, in instantiated_test
    result = test(self, **param_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/testing/_internal/common_device_type.py", line 1084, in only_fn
    return fn(slf, *args, **kwargs)
  File "/opt/pytorch/pytorch/test/test_linalg.py", line 5719, in test__int_mm
    _test(17, k, n, use_transpose_a, use_transpose_b)
  File "/opt/pytorch/pytorch/test/test_linalg.py", line 5680, in _test
    c_int32 = torch._int_mm(a_int8, b_int8)
RuntimeError: CUDA error: CUBLAS_STATUS_NOT_SUPPORTED when calling cublasLtMatmul with transpose_mat1 0 transpose_mat2 0 m 32 n 17 k 32 mat1_ld 32 mat2_ld 32 result_ld 32 abType 3 cType 10 computeType 72 scaleType 10
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113327
Approved by: https://github.com/malfet
Currently, for the `matrix_exp` function, if we have NaN values in the input matrices (small batches), it keeps outputting a "normal" result without any NaN values in it, and this can cause problems that we may not notice. This PR prevents such undefined behavior by bringing back those NaN values.
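After the fix, NaNs in the input propagate to the output; a minimal check (assuming a build that includes this change):
```python
import torch

A = torch.randn(4, 3, 3)        # a small batch of matrices
A[0, 0, 0] = float("nan")
out = torch.linalg.matrix_exp(A)
assert out[0].isnan().any()     # the NaN now propagates instead of being silently dropped
```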
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111539
Approved by: https://github.com/lezcano
Fixes#109604
Resubmit gh-109715 + several skips and small fixes to make tests pass.
The main fix here is by @ysiraichi : previously, dynamo did not resume tracing numpy ndarrays after a graph break.
While at it, fix several small issues Yukio's fix uncovers:
- graph break gracefully on numpy dtypes which do not map to torch.dtypes (uint16 etc)
- recognize array scalars in dynamo, treat them as 0D ndarrays
- make sure that iterating over torch.ndarray generates arrays not bare tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110512
Approved by: https://github.com/lezcano
Fixes#108754.
`hf_T5_generate` would encounter a regression when calling `extern_kernels.bmm` if one input is `reinterpret_tensor(buf2, (8, 1, 64), (64, 0, 1))` rather than `reinterpret_tensor(buf2, (8, 1, 64), (64, 512, 1), 0)`. As @jgong5 mentioned in a comment, the two tensors are in fact equivalent: the stride doesn't matter when the corresponding size is 1.
We revise the definition of contiguity in `bmm` to treat the above situation as a contiguous case. Thus, when the stride equals 0, `extern_kernels.bmm` can still use MKL's `gemm` to gain performance.
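A minimal check of the equivalence claim above (the stride of a size-1 dimension is irrelevant):
```python
import torch

buf2 = torch.randn(1024)
a = buf2.as_strided((8, 1, 64), (64, 0, 1))    # stride 0 on the size-1 dimension
b = buf2.as_strided((8, 1, 64), (64, 512, 1))  # stride 512 on the size-1 dimension
torch.testing.assert_close(a, b)               # identical views of the same data
```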
Speedup of `hf_T5_generate` is **1.343x** now and **1.138x** before, with script `bash inductor_single_test.sh multiple inference performance torchbench hf_T5_generate float32 first dynamic default 0`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110811
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/Chillee
Fixes#68972
Relands #107246
To avoid causing Meta-internal CI failures, this PR avoids always asserting that the default dtype is float in the `TestCase.setUp/tearDown` methods. Instead, the assert is only done if `TestCase._default_dtype_check_enabled == True`. `_default_dtype_check_enabled` is set to True in the `if __name__ == "__main__":` blocks of all the relevant test files that have required changes for this issue
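That is, roughly, at the bottom of each affected test file:
```python
from torch.testing._internal.common_utils import TestCase, run_tests

if __name__ == "__main__":
    TestCase._default_dtype_check_enabled = True
    run_tests()
```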
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108088
Approved by: https://github.com/ezyang
This is a follow up to https://github.com/pytorch/pytorch/pull/105881 and replaces https://github.com/pytorch/pytorch/pull/103203
The batched linalg drivers from #103203 were brought in as part of the first PR. This change enables the ROCm unit tests that are unblocked as a result of that change, along with a fix to prioritize hipsolver over magma when the preferred linalg backend is set to `default`.
The following 16 unit tests will be enabled for rocm in this change:
- test_inverse_many_batches_cuda*
- test_inverse_errors_large_cuda*
- test_linalg_solve_triangular_large_cuda*
- test_lu_solve_batched_many_batches_cuda*
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106620
Approved by: https://github.com/lezcano
Current test case causes an edge case tensor input that causes a single generated tensor to fail the tolerance assertion on ROCm only and only for float32. We have reviewed the logic with our libraries team and have discovered the discrepancy is due to a difference in order of operations on AMD GPUs. They came back with "working as intended" and found no perceivable bug. Interestingly, if we change the values in ks, ns, or bs, the test passes on ROCm. These particular sizes in this particular order generates a single problematic input that causes the assertion to fail the tolerance check by ~0.07. Again, this is not a bug, just differences in implementation. This PR loosens the tolerance for ROCm only.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/104425
Approved by: https://github.com/jeffdaily, https://github.com/nikitaved, https://github.com/lezcano