mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Support deepseek-style scaling in Inductor Triton for FP8 GEMMs.

DeepSeek-style scaling is a colloquial term for the fine-grained mixed-precision framework that uses FP8 to train [DeepSeek-V3](https://arxiv.org/pdf/2412.19437), DeepSeek AI's recent MoE (Mixture of Experts) model. It effectively extends the dynamic range of FP8 by giving each small group of elements its own scale. It also keeps the resulting dequantization overhead low under increased-precision accumulation, which is key to achieving more accurate FP8 GEMM results.

DeepSeek-style scaling on matmul `A @ B` combines two scaling strategies to balance numerical stability and training efficiency:

- Activations (input tensor `A`): tile-wise (1x128 across shape `(M, K)`)
- Weights (input tensor `B`): block-wise (128x128 across shape `(N, K)`)

This diff lets Inductor users replicate past successes with deepseek-style scaling, achieving higher numerical stability while improving training efficiency.

NOTE: Block-wise 128x128 scaling is supported only in CUDA 12.9+. Therefore, deepseek-style scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run deepseek-style scaling.

NOTE: FP8 accuracy is unstable even with high tolerances. For this reason, TritonBench accuracy comparisons against a `torch` implementation are unlikely to be reliable.
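The two scaling granularities can be illustrated with a minimal pure-Python sketch. This is not the Inductor Triton template, and every helper name here (`tilewise_scales`, `blockwise_scales`, `quantize`, `scaled_gemm`) is hypothetical.

The sketch rests on these assumptions:

- Each scale is `amax / 448.0`, where 448.0 is the largest finite value of FP8 E4M3 (`torch.float8_e4m3fn`).
- `B` is stored as `(N, K)`, so the product is `C[m][n] = sum_k A[m][k] * B[n][k]`.
- The 1e-12 floor for all-zero groups is an arbitrary choice for the sketch.
- The actual rounding to FP8 is skipped, so only the scale bookkeeping and the per-K-block rescaling are shown.

```python
import math

# Largest finite value of FP8 E4M3 (the format of torch.float8_e4m3fn).
FP8_E4M3_MAX = 448.0
# Tile / block edge length used by deepseek-style scaling.
BLOCK = 128


def tilewise_scales(a, block=BLOCK):
    """One scale per 1 x `block` tile of activations A with shape (M, K).

    Returns a nested list with shape (M, ceil(K / block)).
    """
    k = len(a[0])
    scales = []
    for row in a:
        row_scales = []
        for t in range(math.ceil(k / block)):
            amax = max(abs(v) for v in row[t * block:(t + 1) * block])
            # Floor the amax so an all-zero tile still gets a positive scale.
            row_scales.append(max(amax, 1e-12) / FP8_E4M3_MAX)
        scales.append(row_scales)
    return scales


def blockwise_scales(b, block=BLOCK):
    """One scale per `block` x `block` block of weights B with shape (N, K).

    Returns a nested list with shape (ceil(N / block), ceil(K / block)).
    """
    n, k = len(b), len(b[0])
    scales = []
    for bn in range(math.ceil(n / block)):
        row_scales = []
        for bk in range(math.ceil(k / block)):
            amax = max(
                abs(b[r][c])
                for r in range(bn * block, min((bn + 1) * block, n))
                for c in range(bk * block, min((bk + 1) * block, k))
            )
            row_scales.append(max(amax, 1e-12) / FP8_E4M3_MAX)
        scales.append(row_scales)
    return scales


def quantize(x, scale_at):
    """Divide each element by its scale and clamp to the FP8 E4M3 range.

    `scale_at(r, c)` returns the scale that covers element (r, c). A real
    kernel would also round to FP8; this sketch keeps Python floats.
    """
    return [
        [max(-FP8_E4M3_MAX, min(FP8_E4M3_MAX, v / scale_at(r, c)))
         for c, v in enumerate(row)]
        for r, row in enumerate(x)
    ]


def scaled_gemm(qa, sa, qb, sb, block=BLOCK):
    """C[m][n] = sum_k A[m][k] * B[n][k], rebuilt from quantized inputs.

    Each K-block partial sum is rescaled by the activation tile scale times
    the weight block scale, then added to a Python float accumulator.
    """
    m, k, n = len(qa), len(qa[0]), len(qb)
    out = [[0.0] * n for _ in range(m)]
    for i in range(m):
        for j in range(n):
            acc = 0.0
            for t in range(math.ceil(k / block)):
                lo, hi = t * block, min((t + 1) * block, k)
                partial = sum(qa[i][c] * qb[j][c] for c in range(lo, hi))
                acc += partial * sa[i][t] * sb[j // block][t]
            out[i][j] = acc
    return out
```

Usage follows these steps, with `a` of shape `(M, K)` and `b` of shape `(N, K)`:

1. Compute `sa = tilewise_scales(a)` and `sb = blockwise_scales(b)`.
2. Quantize with `quantize(a, lambda r, c: sa[r][c // BLOCK])` and `quantize(b, lambda r, c: sb[r // BLOCK][c // BLOCK])`.
3. Pass both quantized matrices and their scales to `scaled_gemm`.

Because the sketch skips FP8 rounding, its result matches an unscaled matmul up to floating-point error. In a real FP8 kernel, rounding each scaled value to FP8 introduces additional error.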
Test Plan: In OSS PyTorch, run

```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor \
CUDA_LAUNCH_BLOCKING=1 \
TORCH_USE_CUDA_DSA=1 \
TRITON_PRINT_AUTOTUNING=1 \
TRITON_ALWAYS_COMPILE=1 \
TORCH_LOGS=+inductor \
TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 \
ENABLE_PERSISTENT_TMA_MATMUL=1 \
TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 \
python run.py --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm \
  --metrics tflops,accuracy --m 4096 --n 768 --k 512 \
  --output="{output_dir}/deepseek_bench.csv" \
  --scaling_deepseek --atol=1e-2 --rtol=0.5 \
  2>&1 | tee ~/personal/deepseek_style/deepseek_bench.log
```

Differential Revision: D83609850

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164404

Approved by: https://github.com/slayton58
| File |
|---|
| __init__.py |
| aten.py |
| base.py |
| contiguous_mm.py |
| decompose_k.py |
| gemm.py |
| params.py |
| registry.py |
| triton_addmm.py |
| triton.py |