mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
# Summary
This PR adds an alternative triton lowering for _scaled_mm. This uses an updated mm template that utilizes persistent scheduling + TMAs on A and B matrices.
Limitations:
* This implementation does not work with bias values: 0602676c8d/torch/_inductor/kernel/mm_scaled.py (L106). The plan is to remove this workaround and enforce that both scaling and bias are properly done as epilogues on the existing templates
* K dim must be 32 or greater for these to take effect
* Gated by a config flag (currently defaults to off; maybe it should be on)
## Testing
We don't have any tests exercising this code in CI/CD, but I updated the relevant tests in test_fp8 and they are all green:
<img width="1680" alt="Screenshot 2024-12-05 at 7 24 07 PM" src="https://github.com/user-attachments/assets/9c520541-d97a-416f-9af7-e68b366ec90f">
## Follow Ups
* Work to update the base mm triton templates and utilize the same template from mm/addmm/scaled_mm w/ respective epilogues
* Tuning on Persistent kernel configs. I found ones that work for my problem shapes but need to do some more NCU work
### Some profiling code I was using
Code I am using to iterate w/
```Python
import torch
from dataclasses import dataclass
from jsonargparse import CLI
import logging
from pathlib import Path
from transformer_nuggets.utils.benchmark import ProfileConfig, profile_function
from torchao.float8.inference import (
addmm_float8_unwrapped_inference,
preprocess_data,
Float8MMConfig,
)
from transformer_nuggets.fp8.fp8_matmul import (
matmul_persistent,
matmul_tma_persistent,
matmul_device_tma_persistent,
)
from enum import Enum
logging.getLogger("transformer_nuggets").setLevel(logging.INFO)
class FP8Kernel(Enum):
    """Names of the FP8 matmul kernel variants an experiment can request.

    Note that ``get_fp8_matmul`` in this script only accepts ``SCALED_MM``.
    """

    PERSISTENT = "Persistent"
    PERSISTENT_TMA = "Persistent-TMA"
    DEVICE_TMA = "Device-TMA"
    SCALED_MM = "Scaled-MM"
class ScalingStrategy(Enum):
    """Selects which scale tensors ``get_fp8_matmul`` builds for A and B."""

    # A single scalar scale per operand.
    PER_TENSOR = "PerTensor"
    # One scale per row of A and one per column of B.
    PER_ROW = "PerRow"
@dataclass(frozen=True)
class ExperimentConfig:
    """Immutable description of one matmul experiment.

    ``run_matmul`` multiplies an ``(M, K)`` matrix by a ``(K, N)`` matrix
    using the settings stored here.
    """

    # Problem sizes: A is (M, K), B is (K, N).
    M: int
    K: int
    N: int
    # How the scale tensors are built.
    scaling_strategy: ScalingStrategy
    # Which kernel variant to run.
    fp8_kernel: FP8Kernel
    # Whether to wrap the matmul in torch.compile (only honoured for SCALED_MM).
    compile: bool
def get_fp8_matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    scaling_strategy: ScalingStrategy,
    fp8_kernel: FP8Kernel,
):
    """Build a zero-argument callable that runs the scaled FP8 matmul of A and B.

    Args:
        A: Left operand; converted to ``torch.float8_e4m3fn``.
        B: Right operand; converted to ``torch.float8_e4m3fn``.
        scaling_strategy: Chooses between scalar (per-tensor) scales and
            per-row/per-column scales. All scales are filled with ones.
        fp8_kernel: Must be ``FP8Kernel.SCALED_MM``; no other kernel is wired up.

    Returns:
        A callable that invokes ``addmm_float8_unwrapped_inference`` with a
        bfloat16 output dtype and fast accumulation enabled.

    Raises:
        ValueError: If ``scaling_strategy`` is not a known strategy.
        AssertionError: If ``fp8_kernel`` is not ``FP8Kernel.SCALED_MM``.
    """
    A_fp8 = A.to(torch.float8_e4m3fn)
    B_fp8 = B.to(torch.float8_e4m3fn)
    A_fp8, B_fp8 = preprocess_data(A_fp8, B_fp8, Float8MMConfig(use_fast_accum=True))

    # Create the scales on the same device as the operands. The caller may
    # fall back to CPU when CUDA is unavailable, so a hard-coded "cuda" would
    # either fail or produce scales on a different device than A and B.
    device = A_fp8.device

    if scaling_strategy == ScalingStrategy.PER_TENSOR:
        a_scale = torch.tensor(1, device=device, dtype=torch.float32)
        b_scale = torch.tensor(1, device=device, dtype=torch.float32)
    elif scaling_strategy == ScalingStrategy.PER_ROW:
        # a_scale has shape (rows of A, 1); b_scale has shape (1, columns of B).
        a_scale = torch.ones((A_fp8.size(0), 1), device=device, dtype=torch.float32)
        b_scale = torch.ones((B_fp8.size(1), 1), device=device, dtype=torch.float32).T
    else:
        raise ValueError(f"Invalid scaling strategy: {scaling_strategy}")

    assert fp8_kernel == FP8Kernel.SCALED_MM, "only FP8Kernel.SCALED_MM is supported"
    return lambda: addmm_float8_unwrapped_inference(
        A_fp8, a_scale, B_fp8, b_scale, output_dtype=torch.bfloat16, use_fast_accum=True
    )
def run_matmul(config: ExperimentConfig):
    """Execute the experiment described by ``config`` once.

    Random bfloat16 operands of shape ``(M, K)`` and ``(K, N)`` are created on
    CUDA when it is available and on CPU otherwise. The matmul callable is
    optionally compiled and then called a single time. Nothing is returned.
    """
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    A = torch.randn(config.M, config.K, device=target, dtype=torch.bfloat16)
    B = torch.randn(config.K, config.N, device=target, dtype=torch.bfloat16)

    matmul_fn = get_fp8_matmul(A, B, config.scaling_strategy, config.fp8_kernel)

    # Compilation is only applied to the scaled-mm kernel.
    should_compile = config.compile and config.fp8_kernel == FP8Kernel.SCALED_MM
    if should_compile:
        matmul_fn = torch.compile(matmul_fn, mode="max-autotune-no-cudagraphs")

    # The result is discarded; the call exists to be profiled.
    matmul_fn()
def main():
    """Seed the RNG and run one hard-coded 8192x8192x8192 scaled-mm experiment."""
    torch.random.manual_seed(123)

    # Define your experiment configuration here
    settings = dict(
        M=8192,
        K=8192,
        N=8192,
        scaling_strategy=ScalingStrategy.PER_TENSOR,
        fp8_kernel=FP8Kernel.SCALED_MM,
        compile=True,
    )
    run_matmul(ExperimentConfig(**settings))


if __name__ == "__main__":
    CLI(main)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/142045
Approved by: https://github.com/eellison
136 lines
3.6 KiB
Python
136 lines
3.6 KiB
Python
# mypy: allow-untyped-defs
|
|
import functools
|
|
import hashlib
|
|
|
|
|
|
@functools.lru_cache(None)
def has_triton_package() -> bool:
    """Report whether ``triton_key`` can be imported from the Triton compiler.

    The answer is memoized, so the import is attempted at most once per
    process.

    Returns:
        ``True`` when ``triton.compiler.compiler.triton_key`` imports and is
        not ``None``; ``False`` when the import raises ``ImportError`` or
        ``RuntimeError``.
    """
    try:
        from triton.compiler.compiler import triton_key
    except (ImportError, RuntimeError):
        # Both failure modes mean the same thing here: Triton is unusable.
        return False
    return triton_key is not None
|
|
|
|
|
|
@functools.lru_cache(None)
def has_triton_tma():
    """Report whether Triton's experimental TMA descriptor helpers are usable.

    All of the following must hold:
      * ``has_triton_package()`` is true,
      * CUDA is available with a device capability of at least (9, 0),
      * ``torch.version.hip`` is not set,
      * ``create_1d_tma_descriptor`` and ``create_2d_tma_descriptor`` can be
        imported from ``triton.tools.experimental_descriptor``.

    The result is memoized.
    """
    if not has_triton_package():
        return False

    import torch

    device_ok = (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (9, 0)
        and not torch.version.hip
    )
    if not device_ok:
        return False

    try:
        # Imported purely as a probe; the names themselves are unused here.
        from triton.tools.experimental_descriptor import (  # noqa: F401
            create_1d_tma_descriptor,
            create_2d_tma_descriptor,
        )
    except ImportError:
        return False
    return True
|
|
|
|
|
|
@functools.lru_cache(None)
def has_triton_tma_device():
    """Report whether Triton's experimental device tensormap helpers are usable.

    All of the following must hold:
      * ``has_triton_package()`` is true,
      * CUDA is available with a device capability of at least (9, 0),
      * ``torch.version.hip`` is not set,
      * ``experimental_device_tensormap_create1d`` and
        ``experimental_device_tensormap_create2d`` can be imported from
        ``triton.language.extra.cuda``.

    The result is memoized.
    """
    if not has_triton_package():
        return False

    import torch

    device_ok = (
        torch.cuda.is_available()
        and torch.cuda.get_device_capability() >= (9, 0)
        and not torch.version.hip
    )
    if not device_ok:
        return False

    try:
        # Imported purely as a probe; the names themselves are unused here.
        from triton.language.extra.cuda import (  # noqa: F401
            experimental_device_tensormap_create1d,
            experimental_device_tensormap_create2d,
        )
    except ImportError:
        return False
    return True
|
|
|
|
|
|
@functools.lru_cache(None)
def has_triton() -> bool:
    """Report whether Triton is importable and some supported device can use it.

    The devices "cuda", "xpu" and "cpu" are probed in that order. A device
    counts when its interface reports it as available and its extra check
    passes:
      * cuda: the device properties' ``major`` is at least 7,
      * xpu:  no further requirement,
      * cpu:  ``"cpu"`` is listed in ``triton.backends.backends``.

    The result is memoized.
    """
    if not has_triton_package():
        return False

    from torch._dynamo.device_interface import get_interface_for_device

    def _cuda_ok(iface):
        return iface.Worker.get_device_properties().major >= 7

    def _xpu_ok(iface):
        return True

    def _cpu_ok(iface):
        import triton.backends

        return "cpu" in triton.backends.backends

    # Probe order matters: the first device that qualifies ends the search.
    extra_checks = {
        "cuda": _cuda_ok,
        "xpu": _xpu_ok,
        "cpu": _cpu_ok,
    }

    def _device_qualifies(device_name, extra_check):
        iface = get_interface_for_device(device_name)
        return iface.is_available() and extra_check(iface)

    return any(
        _device_qualifies(device_name, extra_check)
        for device_name, extra_check in extra_checks.items()
    )
|
|
|
|
|
|
@functools.lru_cache(None)
def triton_backend():
    """Return the Triton backend built for the active driver's current target.

    The backend is created with ``make_backend`` and memoized, so later calls
    return the same object.
    """
    from triton.compiler.compiler import make_backend
    from triton.runtime.driver import driver

    return make_backend(driver.active.get_current_target())
|
|
|
|
|
|
@functools.lru_cache(None)
def triton_hash_with_backend():
    """Return an upper-case SHA-256 hex digest of the Triton key and backend hash.

    The digest is computed over ``"<triton_key()>-<backend.hash()>"`` encoded
    as UTF-8, where the backend comes from ``triton_backend()``. The result is
    memoized.
    """
    from triton.compiler.compiler import triton_key

    active_backend = triton_backend()
    material = f"{triton_key()}-{active_backend.hash()}"

    hasher = hashlib.sha256()
    hasher.update(material.encode("utf-8"))

    # Hash is upper case so that it can't contain any Python keywords.
    return hasher.hexdigest().upper()
|
|
|
|
|
|
def dtype_to_string(dtype):
    """Return the ``triton.language.<name>`` spelling for ``dtype``.

    A leading ``"fp"`` in ``dtype.name`` is expanded to ``"float"`` and a
    leading ``"bf"`` to ``"bfloat"``; any other name is used unchanged.

    Args:
        dtype: Object exposing a string ``name`` attribute.

    Returns:
        The string ``"triton.language."`` followed by the expanded name.
    """
    name = dtype.name
    # (two-character abbreviation, expanded spelling); first match wins.
    expansions = (("fp", "float"), ("bf", "bfloat"))
    for short, spelled_out in expansions:
        if name.startswith(short):
            name = spelled_out + name[2:]
            break
    return f"triton.language.{name}"
|
|
|
|
|
|
def patch_triton_dtype_repr():
    """Replace ``triton.language.dtype.__repr__`` with ``dtype_to_string``.

    After this call, ``repr()`` of a Triton dtype yields the string produced
    by ``dtype_to_string`` for that dtype.
    """
    import triton

    def _evaluatable_repr(self):
        # Look up dtype_to_string at call time rather than at patch time.
        return dtype_to_string(self)

    # Hack to get triton dtype repr to produce an evaluatable expression
    # triton.language.float32 emits triton.language.fp32 which does not
    # exist
    # REMOVE when https://github.com/openai/triton/pull/3342 lands
    triton.language.dtype.__repr__ = _evaluatable_repr
|