Summary:
To fix the following failure case:
when `M, K, N = 245760, 656, 6560`, fp8 with `torch.compile` fails with `RuntimeError: mat2 must be col_major`.
---------
From the Inductor-generated code (https://fburl.com/everpaste/epcagkrd):
```
V0625 01:38:55.551000 140329914449920 torch/_inductor/scheduler.py:1623] [0/0] scheduling ComputedBuffer(name='buf12', layout=FixedLayout('cuda', torch.float8_e4m3fn, size=[656, 6560], stride=[6656, 1]),
... ...
V0625 01:38:56.194000 140329914449920 torch/_inductor/graph.py:1680] [0/0] [__output_code] buf12 = empty_strided_cuda((656, 6560), (6656, 1), torch.float8_e4m3fn)
... ...
V0625 01:38:56.194000 140329914449920 torch/_inductor/graph.py:1680] [0/0] [__output_code] return (buf10, buf2, buf5, buf6, reinterpret_tensor(buf11, (245760, 656), (1, 245760), 0), reinterpret_tensor(buf12, (6560, 656), (1, 6656), 0), )
... ...
V0625 01:39:12.098000 140312968167424 torch/_inductor/graph.py:1680] [1/0_1] [__output_code] assert_size_stride(permute_10, (6560, 656), (1, 6656))
... ...
V0625 01:39:12.098000 140312968167424 torch/_inductor/graph.py:1680] [1/0_1] [__output_code] buf8 = aten._scaled_mm.default(buf6, permute_10, buf7, reciprocal_3, None, None, torch.bfloat16)
```
Inductor gives mat2 (`permute_10`, shape `(6560, 656)`) a second-dimension stride of `6656` instead of `shape[0]` (`6560`).
Therefore, the `stride[1] == shape[0]` condition fails.
To fix the issue, the `is_col_major` check is modified to drop this condition, since it does not hold for all valid column-major layouts.
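For reference, a minimal sketch (a hypothetical helper, not the exact PyTorch source) of the failing check and one possible relaxation:
```python
# The padded layout Inductor produces here, shape (6560, 656) with stride (1, 6656),
# is still column-major but fails the strict test because 6656 != 6560.
def is_col_major_strict(shape, stride):
    return stride[0] == 1 and stride[1] == shape[0]   # rejects padded second strides

def is_col_major_relaxed(shape, stride):
    return stride[0] == 1 and stride[1] >= shape[0]   # tolerates padding, e.g. 6656 >= 6560
```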
Test Plan:
Run the failing case again. It works with the fix.
-----
Sandcastle / GitHub CI will make sure the existing tests still pass.
Reviewed By: vkuzo
Differential Revision: D58994704
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129521
Approved by: https://github.com/drisspg
# Summary
The first PR got reverted and needed a redo.
This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:
- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.
It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".
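For reference, a minimal sketch (a hypothetical helper, not the actual dispatch code) of the selection rule these conditions describe:
```python
import torch

def use_rowwise_kernel(x: torch.Tensor, y: torch.Tensor,
                       scale_x: torch.Tensor, scale_y: torch.Tensor) -> bool:
    # x: [M, K] row-major, y: [K, N] column-major (TN format).
    M, _ = x.shape
    _, N = y.shape
    return (scale_x.dim() == 1 and scale_x.numel() == M and
            scale_y.dim() == 1 and scale_y.numel() == N)
```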
The following two PRs were required to enable local builds:
- [PR #126185](https://github.com/pytorch/pytorch/pull/126185)
- [PR #125523](https://github.com/pytorch/pytorch/pull/125523)
### Todo
We still do not build our Python wheels with this architecture.
@ptrblck @malfet, should we replace `sm_90` with `sm_90a`?
The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954
#### ifdef
I tried to use `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel, but I was having a hell of a time with it, so I am not really sure of the right way to do this.
Kernel Credit:
@jwfromm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128989
Approved by: https://github.com/yangsiyu007, https://github.com/vkuzo
# Summary
The primary reason for the change was the lack of a current use case and the need to work around two Inductor issues:
- Tensor arguments as kwarg only
- multiple outputs from triton templates
If the need for the amax return arises, we can consider adding it back or, more likely, creating a separate op.
In principle, PyTorch is moving away from "mega ops" that bundle lots of functionality; we instead rely on the compiler to generate appropriate fused kernels.
### Changes:
- This removes the amax return value from `scaled_mm`. We have found that the common use case is to return the result in "high precision" (a type with more precision than fp8), and amax is only relevant when returning in low precision.
- We currently still allow fp8 returns and a scaled result; perhaps we should ban this as well...
New signature:
```Python
def meta_scaled_mm(
self: torch.Tensor,
mat2: torch.Tensor,
scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: Optional[torch.Tensor] = None,
scale_result: Optional[torch.Tensor] = None,
out_dtype: Optional[torch.dtype] = None,
use_fast_accum: bool = False,
) -> torch.Tensor:
```
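For illustration, a minimal usage sketch based on the signature above (hypothetical sizes; assumes a CUDA device with fp8 support):
```python
import torch

M, K, N = 32, 64, 48                                              # sizes divisible by 16
a = torch.randn(M, K, device="cuda").to(torch.float8_e4m3fn)      # row-major mat1
b = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn).t()  # column-major mat2
scale_a = torch.tensor(1.0, device="cuda")                        # per-tensor scales
scale_b = torch.tensor(1.0, device="cuda")

out = torch._scaled_mm(a, b, scale_a, scale_b, out_dtype=torch.bfloat16)
# A single tensor is returned; the previous (output, amax) tuple is gone.
```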
Pull Request resolved: https://github.com/pytorch/pytorch/pull/128683
Approved by: https://github.com/vkuzo
# Summary
This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:
- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.
It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".
The following two PRs were required to enable local builds:
- [PR #126185](https://github.com/pytorch/pytorch/pull/126185)
- [PR #125523](https://github.com/pytorch/pytorch/pull/125523)
### Todo
We still do not build our Python wheels with this architecture.
@ptrblck @malfet, should we replace `sm_90` with `sm_90a`?
The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954
#### ifdef
I tried to use `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel, but I was having a hell of a time with it, so I am not really sure of the right way to do this.
Kernel Credit:
@jwfromm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125204
Approved by: https://github.com/lw, https://github.com/malfet
# Summary
This pull request introduces an fp8 row-scaling kernel as an optional implementation for `scaled_mm`. The kernel selection is based on the scaling tensors of the inputs. For inputs `x` and `y` of shape `[M, K]` and `[K, N]` respectively, the following conditions must be met:
- `x`'s scale should be a 1-dimensional tensor of length `M`.
- `y`'s scale should be a 1-dimensional tensor of length `N`.
It's important to note that this kernel is not called "rowwise, columnwise" scaling because, although the scales for `y` are semantically along its columns, this implementation only supports the TN format. This means the scaling is along the faster-moving dimension, or the "row".
The following two PRs were required to enable local builds:
- [PR #126185](https://github.com/pytorch/pytorch/pull/126185)
- [PR #125523](https://github.com/pytorch/pytorch/pull/125523)
### Todo
We still do not build our Python wheels with this architecture.
@ptrblck @malfet, should we replace `sm_90` with `sm_90a`?
The NVRTC TMA shadowing feels wrong, but I am not sure of the right way to spoof the symbol for this compilation unit:
https://github.com/pytorch/pytorch/pull/125204/files#r1586986954
#### ifdef
I tried to use `#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12000 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ > 900` to gate the building of the kernel, but I was having a hell of a time with it, so I am not really sure of the right way to do this.
Kernel Credit:
@jwfromm
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125204
Approved by: https://github.com/lw
Recently there has been work in an experimental repo to start implementing the intrinsics necessary to handle F8 workloads (see: https://github.com/pytorch-labs/float8_experimental).
A recent PR was submitted to add support for AMD F8 types (fnuz). That PR uncovered a bug in the ROCm code that caused unit tests to fail due to numerical inaccuracy. This PR fixes that bug by replacing `abs_()` with `abs()`: the former takes the elementwise absolute value of the tensor in place, leaving it with only positive values and causing the final assertion to fail.
It is important to note that this fix is part of a workaround, as hipBLASLt does not yet support amax (`HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER`). This functionality has been implemented internally and is going through the proper channels to propagate to the community.
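To illustrate the in-place vs. out-of-place difference behind the fix (a toy sketch, not the actual unit-test code):
```python
import torch

t = torch.tensor([-3.0, 1.0, -2.0])
u = t.clone()

amax_inplace = u.abs_().max()   # abs_() mutates u; it now holds only positive values
amax_copy = t.abs().max()       # abs() allocates a new tensor; t keeps its signed values

assert torch.equal(t, torch.tensor([-3.0, 1.0, -2.0]))   # t is untouched
assert torch.equal(u, torch.tensor([3.0, 1.0, 2.0]))     # u was overwritten in place
```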
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123275
Approved by: https://github.com/drisspg, https://github.com/jeffdaily
scaled_gemm for ROCm using hipBLASLt. As of ROCm 6.0, `HIPBLASLT_MATMUL_DESC_AMAX_D_POINTER` is not supported, so a work-around is provided that performs the absmax operation on the output buffer; this results in some loss of accuracy for the absmax result. For this reason the feature should be considered beta/preview.
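To illustrate why the work-around loses accuracy (a Python sketch, not the hipBLASLt code path; assumes a build with the fp8 dtypes): the absmax is taken over the already-rounded output buffer rather than the full-precision accumulator, so the two values can differ slightly.
```python
import torch

acc = torch.randn(64, 64)                  # stand-in for the full-precision accumulator
out = acc.to(torch.float8_e4m3fn)          # rounded output buffer
amax_exact = acc.abs().max()               # what an amax epilogue would report
amax_workaround = out.float().abs().max()  # work-around: absmax over the output buffer
```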
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117822
Approved by: https://github.com/jianyuh, https://github.com/xw285cornell
CC @malfet @ptrblck
~~We've been seeing a lot of noise from Ampere and later devices due to reduced precision reductions, so preemptively disabling them for addmm tests.~~
Breaking out the addmm tests into variants with and without reduced-precision reductions.
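For context, a minimal sketch (hypothetical, not the test code itself) of the knob these split test variants exercise:
```python
import torch

# Full-precision accumulation for fp16 matmuls (tighter tolerances apply).
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = False
# ... run the strict addmm variant ...

# Allow the reduced-precision reduction fast path on Ampere and later (looser tolerances).
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
# ... run the relaxed addmm variant ...
```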
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112545
Approved by: https://github.com/malfet
Fixes #68972
Relands #107246
To avoid causing Meta-internal CI failures, this PR avoids always asserting that the default dtype is float in the `TestCase.setUp/tearDown` methods. Instead, the assert is only done if `TestCase._default_dtype_check_enabled == True`. `_default_dtype_check_enabled` is set to `True` in the `if __name__ == "__main__":` blocks of all the relevant test files that required changes for this issue.
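A minimal sketch of the opt-in pattern described above (simplified, not the actual `common_utils.TestCase` code):
```python
import torch

class TestCase:
    _default_dtype_check_enabled = False  # opted into from a test file's __main__ block

    def setUp(self):
        if self._default_dtype_check_enabled:
            assert torch.get_default_dtype() == torch.float

    def tearDown(self):
        if self._default_dtype_check_enabled:
            assert torch.get_default_dtype() == torch.float
```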
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108088
Approved by: https://github.com/ezyang
Summary:
Based on D48377631, with updates to guard the use of cuBLAS features only available after CUDA 11.8.
According to https://docs.nvidia.com/cuda/cublas/#id99, only FP8 matrix types can be scaled, and `Float8_e4m3` x `Float8_e4m3` results can be returned as `Float8_e4m3`, or upcast to `Half`, `BFloat16`, or `Float`; in that case, however, `result_scale` has no effect and `amax` is not computed.
An optional `bias` argument can also be passed, which should be a vector of either `Half` or `BFloat16` whose values are added to each row of the result matrix.
See table below for supported input and output types:
| Mat1 type | Mat2 type | Bias type | Output types |
| ----------- | ----------- | ----------- | ----------- |
| Float8_e4m3 | Float8_e4m3 | Float16 | Float8_e4m3, Float16 |
| Float8_e4m3 | Float8_e4m3 | BFloat16 | Float8_e4m3, BFloat16, Float |
| Float8_e5m2 | Float8_e4m3 | Float16 | Float8_e4m3, Float8_e5m2, Float16 |
| Float8_e5m2 | Float8_e4m3 | BFloat16 | Float8_e4m3, Float8_e5m2, BFloat16, Float |
| Float8_e4m3 | Float8_e5m2 | Float16 | Float8_e4m3, Float8_e5m2, Float16 |
| Float8_e4m3 | Float8_e5m2 | BFloat16 | Float8_e4m3, Float8_e5m2, BFloat16, Float |
| Float8_e5m2 | Float8_e5m2 | Not supported | Not supported |
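As a small illustration of the bias semantics above (a sketch with made-up sizes, not the cuBLASLt epilogue itself): the bias is a length-N vector, one value per output column, broadcast across all M rows.
```python
import torch

M, N = 4, 8
acc = torch.randn(M, N)                       # stand-in for the matmul result
bias = torch.randn(N, dtype=torch.bfloat16)   # one bias value per output column
out = acc + bias.to(acc.dtype)                # bias[j] is added to element (i, j) of every row i
```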
The decomposition implementation is skipped until the fp8-on-Triton story is better defined. A potential decomposition could look something like the following:
```python
@register_decomposition(aten._scaled_mm)
def _scaled_mm(
mat1: Tensor,
mat2: Tensor,
*,
dtype: Optional[torch.dtype] = None,
scale_a: Optional[Tensor] = None,
scale_b: Optional[Tensor] = None,
scale_result: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
rc = torch.mm(mat1.to(torch.float32), mat2.to(torch.float32))
rc = scale_a * rc if scale_a is not None else rc
rc = scale_b * rc if scale_b is not None else rc
rc = scale_result * rc if scale_result is not None else rc
rc = rc.to(dtype if dtype is not None else mat1.dtype)
return rc, torch.tensor(0.0, device=mat1.device)
```
Known limitations:
- Only works for matrix sizes divisible by 16
- The 1st operand must be in row-major and the 2nd in column-major order (i.e. if `x` and `y` are contiguous, then only `torch._scaled_mm(x, y.t())` will work)
Test Plan: Tests in test_matmul_cuda.py
Differential Revision: D48415871
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107341
Approved by: https://github.com/vkuzo
According to https://docs.nvidia.com/cuda/cublas/#id99, only FP8 matrix types can be scaled, and `Float8_e4m3` x `Float8_e4m3` results can be returned as `Float8_e4m3`, or upcast to `Half`, `BFloat16`, or `Float`; in that case, however, `result_scale` has no effect and `amax` is not computed.
An optional `bias` argument can also be passed, which should be a vector of either `Half` or `BFloat16` whose values are added to each row of the result matrix.
See table below for supported input and output types:
| Mat1 type | Mat2 type | Bias type | Output types |
| ----------- | ----------- | ----------- | ----------- |
| Float8_e4m3 | Float8_e4m3 | Float16 | Float8_e4m3, Float16 |
| Float8_e4m3 | Float8_e4m3 | BFloat16 | Float8_e4m3, BFloat16, Float |
| Float8_e5m2 | Float8_e4m3 | Float16 | Float8_e4m3, Float8_e5m2, Float16 |
| Float8_e5m2 | Float8_e4m3 | BFloat16 | Float8_e4m3, Float8_e5m2, BFloat16, Float |
| Float8_e4m3 | Float8_e5m2 | Float16 | Float8_e4m3, Float8_e5m2, Float16 |
| Float8_e4m3 | Float8_e5m2 | BFloat16 | Float8_e4m3, Float8_e5m2, BFloat16, Float |
| Float8_e5m2 | Float8_e5m2 | Not supported | Not supported |
The decomposition implementation is skipped until the fp8-on-Triton story is better defined. A potential decomposition could look something like the following:
```python
@register_decomposition(aten._scaled_mm)
def _scaled_mm(
mat1: Tensor,
mat2: Tensor,
*,
dtype: Optional[torch.dtype] = None,
scale_a: Optional[Tensor] = None,
scale_b: Optional[Tensor] = None,
scale_result: Optional[Tensor] = None,
) -> Tuple[Tensor, Tensor]:
rc = torch.mm(mat1.to(torch.float32), mat2.to(torch.float32))
rc = scale_a * rc if scale_a is not None else rc
rc = scale_b * rc if scale_b is not None else rc
rc = scale_result * rc if scale_result is not None else rc
rc = rc.to(dtype if dtype is not None else mat1.dtype)
return rc, torch.tensor(0.0, device=mat1.device)
```
Known limitations:
- Only works for matrix sizes divisible by 16
- The 1st operand must be in row-major and the 2nd in column-major order (i.e. if `x` and `y` are contiguous, then only `torch._scaled_mm(x, y.t())` will work)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106844
Approved by: https://github.com/albanD
ghstack dependencies: #106977
Fixes the underlying issue previously addressed in #92201 by specifying minimum alignments explicitly to `cuBLAS` rather than relying on a handcrafted rule. ~~We're still investigating some potential failure modes on `sm80` and `sm90` but those would be real `cuBlasLt` heuristics bugs rather than being caused by underspecifying constraints to the heuristics.~~
According to the `cuBLAS` docs, the default alignment is 256 bytes, so that is the maximum currently being checked: https://docs.nvidia.com/cuda/cublas/
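For illustration, a sketch (a hypothetical Python-side helper, not the C++ change in this PR) of determining the largest power-of-two alignment a tensor's data pointer satisfies, up to that 256-byte maximum:
```python
import torch

def max_alignment(t: torch.Tensor, limit: int = 256) -> int:
    # Largest power of two (up to `limit`) dividing the raw data pointer.
    addr = t.data_ptr()
    align = 1
    while align < limit and addr % (align * 2) == 0:
        align *= 2
    return align
```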
CC @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98975
Approved by: https://github.com/ngimel
Follow-up of #89582 to drop flags like `CUDA11OrLater` in tests. Note that in some places `TEST_WITH_ROCM` appears to be _implicitly_ guarded against via the `CUDA11OrLater` version check, based on my best guess of how `torch.version.cuda` would behave in ROCm builds, so I've added `not TEST_WITH_ROCM` in cases where ROCm wasn't previously explicitly allowed.
CC @ptrblck @malfet @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92605
Approved by: https://github.com/ngimel
Fix for this issue surfaced from the discuss forum: https://discuss.pytorch.org/t/cuda-error-cublas-status-not-supported-when-calling-cublasltmatmul-from-torch-nn-functional-linear/170214
Note that PyTorch builds before #71200 should not be affected as there was no `cublasLt` dispatch path. Additionally, the provided repro has the quirk of using a 3D input, which means it will not dispatch to `cublasLt`-backed `addmm` until builds that include #72728. Changing the input to 2D by trivially removing the size `1` dimension will surface the failure on builds after #71200.
Interestingly, the use case where _all_ inputs are 2-byte aligned is supported (runs without crashing), but the case where some are more than 2-byte aligned and some are exactly 2-byte aligned is not. This behavior suggests that the `cuBlasLt` heuristics are incorrect, as the heuristic function has visibility of the raw pointer values via the descriptors when it is called.
We will follow up with `cuBlasLt` but this fix is needed to prevent unnecessary crashes for now.
CC @ptrblck @ngimel
Pull Request resolved: https://github.com/pytorch/pytorch/pull/92201
Approved by: https://github.com/ngimel