[Inductor] Restore original dtype for rank-0 CPU tensors (#166118)

# Problem
Inductor implicitly upcasts certain rank-0 kernel arguments from float16 to float32. Currently, this happens only on the `"cpu"` device, which appears to be related to float16 support in CPU Triton, but it can also affect the behavior of GPU kernels when a model contains tensors from multiple devices. Upcasting may be undesirable on some platforms, so users can typically disable it with the `config.triton.codegen_upcast_to_fp32` flag. However, this flag was not respected by the rank-0 kernel-argument codepath.
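
For illustration, the failure mode can be sketched as follows (a minimal sketch modeled on the new test; the shapes and the choice of `torch.add` are illustrative):

```python
import torch
from torch._inductor import config

# A rank-0 CPU tensor mixed with a GPU tensor in the same op.
x = torch.randn(1, dtype=torch.float16, device="cpu")[0]
y = torch.randn(8, dtype=torch.float16, device="cuda")

# Before this fix, the rank-0 CPU argument still reached the generated GPU
# kernel as a float32 scalar, even though the upcast flag is disabled here.
with config.patch("triton.codegen_upcast_to_fp32", False):
    out = torch.compile(torch.add)(x, y)
```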

Through an improbable series of events, float32 upcasting caused an internal model to fail compilation on MTIA. (Internal reviewers: see T242444110.)

# Fix
If `config.triton.codegen_upcast_to_fp32` evaluates to `False`, cast the kernel argument to the original dtype.
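
In the generated Triton code, the rank-0 argument is still passed in as a float32 scalar (see `triton_utils.py:signature_of`), but the load is now followed by a cast back to the original dtype. Schematically, for float16 (variable names are illustrative; `triton_type` renders `torch.float16` as `tl.float16`):

```python
# codegen_upcast_to_fp32=True (the default): the scalar stays float32 in-kernel.
tmp0 = arg0
# codegen_upcast_to_fp32=False (after this fix): a downcast is emitted inline.
tmp0 = arg0.to(tl.float16)
```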

# Test plan
Added a new CI test checking that the downcast appears if and only if the config flag is false. The test mixes GPU and CPU tensors to generate a GPU kernel containing the implicit float32 upcast and the explicit downcast to the original dtype.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166118
Approved by: https://github.com/jfix71, https://github.com/jansel, https://github.com/kundaMwiza
Blaine Burton Rister 2025-10-24 19:59:25 +00:00 committed by PyTorch MergeBot
parent fdcf402d82
commit 0442125362
2 changed files with 33 additions and 2 deletions


@@ -9,7 +9,7 @@ from torch._dynamo.utils import disable_cache_limit
 from torch._inductor import config
 from torch._inductor.codegen.triton import OpDtypeSupport
 from torch._inductor.test_case import TestCase as InductorTestCase
-from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
+from torch._inductor.utils import run_and_get_code, run_and_get_triton_code, triton_type
 from torch.fx.operator_schemas import get_signature_for_torch_op
 from torch.testing import FileCheck
 from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -306,6 +306,34 @@ class TestCase(InductorTestCase):
             lambda acc, curr: acc + torch.abs(curr), x, dim=-1, combine_mode="pointwise"
         )
 
+    @parametrize("upcast_to_fp32", (False, True))
+    @parametrize("dtype", (torch.float16, torch.bfloat16))
+    def test_upcast_rank_0_cpu(self, dtype: torch.dtype, upcast_to_fp32: bool):
+        """
+        Test whether we implicitly upcast CPU tensors of rank 0 to float32.
+        """
+        # Test broadcasting a rank-0 CPU tensor to rank 1.
+        x = torch.randn(1, dtype=dtype, device="cpu")[0]
+        y = torch.randn(8, dtype=dtype, device=GPU_TYPE)
+        self.assertEqual(len(x.shape), 0)
+        self.assertEqual(len(y.shape), 1)
+        inps = (x, y)
+        func = torch.add
+
+        with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
+            compiled = torch.compile(func)
+            result, (code,) = run_and_get_code(compiled, *inps)
+
+        # Check numerics.
+        ref = func(*inps)
+        self.assertTrue(torch.allclose(result, ref))
+
+        # Inductor upcasts CPU arguments of rank 0 to float32. Check for a downcast to
+        # the original dtype.
+        num_downcasts = code.count(f".to({triton_type(dtype)})")
+        self.assertEqual(num_downcasts, 0 if upcast_to_fp32 else 1)
+
 
 instantiate_device_type_tests(TestCase, globals(), only_for=("cuda",))


@@ -3183,7 +3183,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
                 # unwrapped bf16/fp16 0d tensors are passed in as float32 scalars
                 # see triton_utils.py:signature_of
                 if dtype in (torch.float16, torch.bfloat16):
-                    dtype = torch.float32
+                    if config.triton.codegen_upcast_to_fp32:
+                        dtype = torch.float32
+                    else:
+                        line += f".to({triton_type(dtype)})"
                 shape = ()
             else: