pytorch/torch/_inductor/template_heuristics
Janani Sriram ff46d5a79b [Inductor][Triton][FP8] Support deepseek-style scaling in Inductor (#164404)
Summary:
Support DeepSeek-style scaling in Inductor Triton for FP8 GEMMs. DeepSeek-style scaling is a colloquial term for the fine-grained mixed-precision FP8 framework used to train [DeepSeek-V3](https://arxiv.org/pdf/2412.19437), DeepSeek AI's recent MoE (Mixture of Experts) model. DeepSeek-style scaling effectively extends the dynamic range of FP8 by mitigating dequantization overhead through increased-precision accumulation, which is key to achieving more accurate FP8 GEMM results.

DeepSeek-style scaling on matmul `A @ B` leverages two different scaling granularities to balance numerical stability against training efficiency (see the sketch after this list):
- Activations (input tensor `A`, shape `(M, K)`): tile-wise scaling, one scale per 1x128 tile along `K`
- Weights (input tensor `B`, shape `(N, K)`): block-wise scaling, one scale per 128x128 block
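
The snippet below is a minimal, standalone sketch (not part of this diff; the helper names and the `e4m3` dtype choice are illustrative assumptions) of how the two scale granularities could be computed when quantizing the inputs, assuming `M`, `N`, and `K` are all multiples of 128:

```
import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3
BLOCK = 128

def quantize_activation_tilewise(a):
    # a: (M, K) activations; one scale per 1x128 tile along K,
    # so scales have shape (M, K // BLOCK).
    M, K = a.shape
    tiles = a.view(M, K // BLOCK, BLOCK)
    scales = tiles.abs().amax(dim=-1).clamp(min=1e-12) / FP8_MAX
    a_fp8 = (tiles / scales.unsqueeze(-1)).view(M, K).to(torch.float8_e4m3fn)
    return a_fp8, scales

def quantize_weight_blockwise(b):
    # b: (N, K) weights; one scale per 128x128 block,
    # so scales have shape (N // BLOCK, K // BLOCK).
    N, K = b.shape
    blocks = b.view(N // BLOCK, BLOCK, K // BLOCK, BLOCK)
    scales = blocks.abs().amax(dim=(1, 3)).clamp(min=1e-12) / FP8_MAX
    b_fp8 = (blocks / scales[:, None, :, None]).view(N, K).to(torch.float8_e4m3fn)
    return b_fp8, scales
```

Inside the GEMM, the partial products are then rescaled by the corresponding entries of the two scale tensors and accumulated in higher precision, which is the increased-precision accumulation referred to above.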

This diff enables Inductor users to replicate DeepSeek's success with this scaling scheme, achieving higher numerical stability while increasing training efficiency.

NOTE: Block-wise 128x128 scaling requires CUDA 12.9+; therefore, DeepSeek-style scaling is currently unsupported in `fbcode` (CUDA 12.4). Use OSS PyTorch to run DeepSeek-style scaling.
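
As a rough, hypothetical guard (not part of this diff), tests or benchmarks could skip DeepSeek-style scaling when the CUDA toolkit PyTorch was built against is too old:

```
import torch

def supports_blockwise_128x128_scaling():
    # CPU-only builds report torch.version.cuda as None.
    if torch.version.cuda is None:
        return False
    major, minor = (int(v) for v in torch.version.cuda.split(".")[:2])
    return (major, minor) >= (12, 9)
```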

NOTE: FP8 accuracy is unstable even with high tolerances, which is why TritonBench accuracy results are unlikely to match a `torch` reference implementation.
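
For a standalone sanity check outside TritonBench, a comparison against a higher-precision reference matmul with the same loose tolerances used in the Test Plan below (`atol=1e-2`, `rtol=0.5`) might look like this (variable names are illustrative):

```
import torch

def fp8_result_matches_reference(out_fp8_gemm, out_reference):
    # Compare the FP8 GEMM output against a bf16/fp32 matmul of the same
    # unquantized inputs, using the Test Plan's tolerances.
    try:
        torch.testing.assert_close(
            out_fp8_gemm.to(torch.float32),
            out_reference.to(torch.float32),
            atol=1e-2,
            rtol=0.5,
        )
        return True
    except AssertionError:
        return False
```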

Test Plan:
In OSS PyTorch, run
```
TORCHINDUCTOR_CACHE_DIR=~/personal/cache_dir_inductor \
CUDA_LAUNCH_BLOCKING=1 \
TORCH_USE_CUDA_DSA=1 \
TRITON_PRINT_AUTOTUNING=1 \
TRITON_ALWAYS_COMPILE=1 \
TORCH_LOGS=+inductor \
TORCHINDUCTOR_FORCE_DISABLE_CACHES=1 \
ENABLE_PERSISTENT_TMA_MATMUL=1 \
TORCHINDUCTOR_MAX_AUTOTUNE_GEMM=1 \
python run.py --op fp8_gemm --only torch_fp8_gemm,pt2_fp8_gemm \
  --metrics tflops,accuracy --m 4096 --n 768 --k 512 \
  --output="{output_dir}/deepseek_bench.csv" --scaling_deepseek \
  --atol=1e-2 --rtol=0.5 \
  2>&1 | tee ~/personal/deepseek_style/deepseek_bench.log
```

Differential Revision: D83609850

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164404
Approved by: https://github.com/slayton58
2025-10-28 03:38:54 +00:00
| File | Latest commit | Date |
| --- | --- | --- |
| `__init__.py` | [inductor][aten] treat like a template in GEMMs (#161342) | 2025-09-05 18:02:10 +00:00 |
| `aten.py` | [inductor] Support out_dtype arg to matmul (#163393) | 2025-09-23 15:37:38 +00:00 |
| `base.py` | [inductor][choices] move extra kwargs out of get_template_configs (#163209) | 2025-09-20 05:30:40 +00:00 |
| `contiguous_mm.py` | [inductor][template heuristics] don't take layout to generate choices (#162238) | 2025-09-09 17:17:04 +00:00 |
| `decompose_k.py` | [inductor][template heuristics] don't take layout to generate choices (#162238) | 2025-09-09 17:17:04 +00:00 |
| `gemm.py` | [inductor][template heuristics] don't take layout to generate choices (#162238) | 2025-09-09 17:17:04 +00:00 |
| `params.py` | [inductor][heuristics] add kernel template params (#162781) | 2025-09-18 02:15:42 +00:00 |
| `registry.py` | [inductor][heuristics registry] missing heuristic is not an error anymore, cross device heuristics (#161767) | 2025-08-29 22:41:27 +00:00 |
| `triton_addmm.py` | [inductor][template heuristics] don't take layout to generate choices (#162238) | 2025-09-09 17:17:04 +00:00 |
| `triton.py` | [Inductor][Triton][FP8] Support deepseek-style scaling in Inductor (#164404) | 2025-10-28 03:38:54 +00:00 |