mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Inductor] Restore original dtype for rank-0 CPU tensors (#166118)
# Problem

Inductor implicitly upcasts certain rank-0 kernel arguments from float16 to float32. Currently, this happens only on the `"cpu"` device, which appears to be related to float16 support in CPU Triton. However, it can also affect the behavior of GPU kernels, when a model contains tensors from multiple devices.

Upcasting may be undesirable on some platforms, so users can typically disable it with the `config.triton.codegen_upcast_to_fp32` flag. However, this flag was not respected by the rank-0 kernel argument codepath.

Through an improbable series of events, float32 upcasting caused an internal model to fail compilation on MTIA. (Internal reviewers see T242444110.)

# Fix

If `config.triton.codegen_upcast_to_fp32` evaluates to `False`, cast the kernel argument to the original dtype.

# Test plan

Added a new CI test checking for the downcast if and only if the config flag is false. The test mixes GPU and CPU tensors to generate a GPU kernel with the implicit float32 upcast and explicit float16 downcast.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166118

Approved by: https://github.com/jfix71, https://github.com/jansel, https://github.com/kundaMwiza
This commit is contained in:
parent
fdcf402d82
commit
0442125362
|
|
@ -9,7 +9,7 @@ from torch._dynamo.utils import disable_cache_limit
|
|||
from torch._inductor import config
|
||||
from torch._inductor.codegen.triton import OpDtypeSupport
|
||||
from torch._inductor.test_case import TestCase as InductorTestCase
|
||||
from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
|
||||
from torch._inductor.utils import run_and_get_code, run_and_get_triton_code, triton_type
|
||||
from torch.fx.operator_schemas import get_signature_for_torch_op
|
||||
from torch.testing import FileCheck
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
||||
|
|
@ -306,6 +306,34 @@ class TestCase(InductorTestCase):
|
|||
lambda acc, curr: acc + torch.abs(curr), x, dim=-1, combine_mode="pointwise"
|
||||
)
|
||||
|
||||
@parametrize("upcast_to_fp32", (False, True))
@parametrize("dtype", (torch.float16, torch.bfloat16))
def test_upcast_rank_0_cpu(self, dtype: torch.dtype, upcast_to_fp32: bool):
    """
    Check the dtype handling of a rank-0 CPU tensor passed to a GPU kernel.

    Inductor upcasts rank-0 CPU arguments to float32. The generated code is
    expected to contain exactly one cast back to the original dtype when
    ``triton.codegen_upcast_to_fp32`` is disabled, and none when it is enabled.
    """
    # Build a rank-0 CPU operand that is broadcast against a rank-1 GPU operand.
    scalar_cpu = torch.randn(1, dtype=dtype, device="cpu")[0]
    vector_gpu = torch.randn(8, dtype=dtype, device=GPU_TYPE)
    self.assertEqual(len(scalar_cpu.shape), 0)
    self.assertEqual(len(vector_gpu.shape), 1)
    args = (scalar_cpu, vector_gpu)
    op = torch.add

    # Compile and run with the upcast flag set as parametrized.
    with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
        result, (code,) = run_and_get_code(torch.compile(op), *args)

    # The compiled output must agree with the eager result.
    expected = op(*args)
    self.assertTrue(torch.allclose(result, expected))

    # Count explicit casts back to the original dtype in the generated code.
    downcast_pattern = f".to({triton_type(dtype)})"
    expected_downcasts = 0 if upcast_to_fp32 else 1
    self.assertEqual(code.count(downcast_pattern), expected_downcasts)
|
||||
|
||||
|
||||
instantiate_device_type_tests(TestCase, globals(), only_for=("cuda",))
|
||||
|
||||
|
|
|
|||
|
|
@ -3183,7 +3183,10 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
# unwrapped bf16/fp16 0d tensors are passed in as float32 scalars
|
||||
# see triton_utils.py:signature_of
|
||||
if dtype in (torch.float16, torch.bfloat16):
|
||||
dtype = torch.float32
|
||||
if config.triton.codegen_upcast_to_fp32:
|
||||
dtype = torch.float32
|
||||
else:
|
||||
line += f".to({triton_type(dtype)})"
|
||||
shape = ()
|
||||
|
||||
else:
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user