[Inductor/Triton] Customize triton codegen to optionally preserve input dtype on tl.load (#132406)
Differential Revision: D60536337
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132406
Approved by: https://github.com/jfix71, https://github.com/blaine-rister
This commit is contained in:
parent 8ff3a5be1b
commit 1f19ccb5b3
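In practice the new knob is exercised the same way as the test added below; here is a minimal sketch of the user-facing behavior, assuming a CUDA build of PyTorch with Triton available (the function and tensor names are illustrative, not part of the patch):

import torch
import torch._dynamo
from torch._inductor import config
from torch._inductor.utils import run_and_get_triton_code

def mul(a, b):
    return a * b

a = b = torch.rand((32, 32), device="cuda", dtype=torch.float16)

for upcast in (True, False):
    torch._dynamo.reset()  # force a fresh compile under the new config value
    with config.patch("triton.codegen_upcast_to_fp32", upcast):
        compiled = torch.compile(mul)
        code = run_and_get_triton_code(compiled, a, b)
        # With upcasting (the default) the kernel casts its loads to float32;
        # with it disabled the loads keep their float16 dtype.
        print(upcast, "float32" in code)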
@@ -11396,6 +11396,7 @@ if HAS_GPU and not TEST_WITH_ASAN:
     copy_tests(CommonTemplate, GPUTests, GPU_TYPE)
 
+    @instantiate_parametrized_tests
     class TritonCodeGenTests(TestCase):
         from torch._inductor.runtime.triton_heuristics import CachingAutotuner
 
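The class-level @instantiate_parametrized_tests decorator is what expands the @parametrize marker used by the new test further down into separate concrete test cases. A minimal standalone sketch of that pattern (class and test names here are illustrative):

from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
    run_tests,
)

@instantiate_parametrized_tests
class ExampleTests(TestCase):
    # Generates one concrete test per parameter value (flag=False and flag=True).
    @parametrize("flag", [False, True])
    def test_flag(self, flag):
        self.assertIsInstance(flag, bool)

if __name__ == "__main__":
    run_tests()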
@@ -11900,6 +11901,20 @@ if HAS_GPU and not TEST_WITH_ASAN:
             # it does not move the tensor constructor to cuda and keeps it on CPU.
             self.assertFalse("empty_strided_cuda(()" in code)
 
+        @requires_gpu()
+        @parametrize("upcast_to_fp32", [False, True])
+        def test_codegen_upcast_to_fp32(self, upcast_to_fp32):
+            @torch.compile
+            def func(a, b):
+                return a * b
+
+            inps = (torch.rand((32, 32), device=GPU_TYPE, dtype=torch.float16),) * 2
+            with config.patch("triton.codegen_upcast_to_fp32", upcast_to_fp32):
+                func_opt = torch._dynamo.optimize("inductor")(func)
+                code = run_and_get_triton_code(func_opt, *inps)
+                fp32_cast_in_code = "float32" in code
+                self.assertEqual(fp32_cast_in_code, upcast_to_fp32)
+
         @config.patch("triton.use_block_ptr", False)
         def test_evict_last_non_coalesced_loads(self):
             @torch.compile
@@ -539,7 +539,10 @@ def triton_compute_type(dtype):
     triton_type_name = str(dtype).split(".")[-1]
     if triton_type_name == "bool":
         triton_type_name = "int1"
-    elif triton_type_name in ("float16", "bfloat16"):
+    elif (
+        triton_type_name in ("float16", "bfloat16")
+        and config.triton.codegen_upcast_to_fp32
+    ):
         # float16 math is done in float32 inside the kernel
         triton_type_name = "float32"
     elif triton_type_name == "float8_e4m3fn":
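The effect of the newly guarded branch on the dtype-to-Triton-type mapping can be restated as a standalone sketch (simplified; the float8 and other branches of the real helper are omitted):

import torch

def compute_type_name(dtype: torch.dtype, upcast_to_fp32: bool) -> str:
    name = str(dtype).split(".")[-1]  # e.g. torch.float16 -> "float16"
    if name == "bool":
        return "int1"
    if name in ("float16", "bfloat16") and upcast_to_fp32:
        # float16 / bfloat16 math is done in float32 inside the kernel
        return "float32"
    return name

assert compute_type_name(torch.float16, True) == "float32"
assert compute_type_name(torch.float16, False) == "float16"
assert compute_type_name(torch.bfloat16, False) == "bfloat16"
assert compute_type_name(torch.bool, True) == "int1"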
@@ -557,7 +560,10 @@ def _get_primitive_bitwidth(dtype):
     if hasattr(dtype, "is_floating_point"):
         if dtype.is_floating_point:
             # triton_compute_type changes the bitwidth
-            if dtype in [torch.bfloat16, torch.float16]:
+            if (
+                dtype in [torch.bfloat16, torch.float16]
+                and config.triton.codegen_upcast_to_fp32
+            ):
                 return 32
             return torch.finfo(dtype).bits
         else:
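For reference, the bitwidths that torch.finfo reports (plain PyTorch, outside this diff), which is why the helper has to answer 32 for fp16 / bf16 only while upcasting is enabled:

import torch

# torch.finfo reports the storage bitwidth of the dtype itself.
assert torch.finfo(torch.float16).bits == 16
assert torch.finfo(torch.bfloat16).bits == 16
assert torch.finfo(torch.float32).bits == 32
# With triton.codegen_upcast_to_fp32 enabled, fp16 / bf16 values are widened to
# float32 inside the kernel, so _get_primitive_bitwidth reports 32 for them.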
@@ -669,7 +675,10 @@ class TritonOverrides(OpOverrides):
         # In such as case, we will have to convert the input tensor to
         # its src_type, perform bitcast, and then convert the bit-casted
         # tensor back to float to ensure we use values with the right precision.
-        if src_dtype in (torch.float16, torch.bfloat16):
+        if (
+            src_dtype in (torch.float16, torch.bfloat16)
+            and config.triton.codegen_upcast_to_fp32
+        ):
             triton_src_dtype = str(src_dtype).split(".")[-1]
             cast_x = f"{x}.to(tl.{triton_src_dtype})"
             if dtype in (torch.float16, torch.bfloat16):
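A standalone sketch of the string assembly in this branch; only the two visible lines are reproduced, and the variable name "tmp0" is hypothetical:

import torch

x = "tmp0"                 # in-kernel value, held as float32 when upcasting is on
src_dtype = torch.float16  # dtype the bitcast source was declared with

triton_src_dtype = str(src_dtype).split(".")[-1]
cast_x = f"{x}.to(tl.{triton_src_dtype})"
# The value is narrowed back to its declared dtype before the bitcast so the
# reinterpreted bits carry the right precision.
print(cast_x)  # -> tmp0.to(tl.float16)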
@@ -1778,7 +1787,10 @@ class TritonKernel(SIMDKernel):
             line = f"tl.load({var} + ({indexing.index_str}), {indexing.mask_str}{ep}{other})"
 
         dtype = V.graph.get_dtype(name)
-        if dtype in (torch.float16, torch.bfloat16):
+        if (
+            dtype in (torch.float16, torch.bfloat16)
+            and config.triton.codegen_upcast_to_fp32
+        ):
             line += ".to(tl.float32)"
         if dtype == torch.bool and torch.version.hip is None:
             # Workaround for https://github.com/openai/triton/issues/2151
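The resulting difference in the emitted load is easiest to see on a concrete line; a standalone sketch of how the line is assembled (pointer, index, and mask names are hypothetical, modeled on typical Inductor output):

var, index_str, mask_str = "in_ptr0", "x0", "xmask"
line = f"tl.load({var} + ({index_str}), {mask_str})"

codegen_upcast_to_fp32 = True  # flip to False to preserve the input dtype
if codegen_upcast_to_fp32:
    line += ".to(tl.float32)"

print(line)
# True  -> tl.load(in_ptr0 + (x0), xmask).to(tl.float32)
# False -> tl.load(in_ptr0 + (x0), xmask)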
@@ -899,6 +899,9 @@ class triton:
     # Valid values: "compile_error", "runtime_error", "accuracy"
     inject_relu_bug_TESTING_ONLY: Optional[str] = None
 
+    # Whether to upcast float16 / bfloat16 to float32 in triton codegen (Experimental)
+    codegen_upcast_to_fp32 = True
+
 
 class aot_inductor:
     # AOTInductor output path
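Two equivalent ways to flip the new knob from user code, given the default of True shown above; the context-manager form is the one the new test uses:

from torch._inductor import config

# Process-wide override: keep fp16 / bf16 dtypes on tl.load in all generated kernels.
config.triton.codegen_upcast_to_fp32 = False

# Scoped override: the previous value is restored when the block exits.
with config.patch("triton.codegen_upcast_to_fp32", False):
    ...  # compile and run models here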