Make user defined Triton kernels serializable for fx_graph_runnable (#160002)
Resolves https://github.com/pytorch/pytorch/issues/153475, where `fx_graph_runnable` did not work with user-defined Triton kernels.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160002
Approved by: https://github.com/eellison
This commit is contained in:
parent b149c7204c
commit cf0a0dcb0a
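For context (not part of the commit itself), the sketch below shows the kind of program this change is about: a user-defined Triton kernel launched from a function that is then run under torch.compile. The kernel and function names here are made up for illustration. When the fx_graph_runnable logging artifact is enabled (for example via TORCH_LOGS="fx_graph_runnable"), the dumped repro script previously could not run because the kernel's definition was never serialized into it.

import torch
import triton
import triton.language as tl


@triton.jit
def my_add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # plain elementwise add; each program handles BLOCK_SIZE contiguous elements
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)


def add(x, y):
    out = torch.empty_like(x)
    n = out.numel()
    # launch the user-defined kernel directly from the traced function
    my_add_kernel[(triton.cdiv(n, 1024),)](x, y, out, n, BLOCK_SIZE=1024)
    return out


x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
torch.compile(add)(x, y)  # the emitted fx_graph_runnable script must re-create my_add_kernel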
@@ -11,12 +11,65 @@ import torch.distributed as dist
 from torch._inductor.codecache import WritableTempFile
 from torch._inductor.test_case import TestCase
 from torch.testing._internal.common_utils import IS_FBCODE, IS_SANDCASTLE
+from torch.utils._triton import has_triton
 
 
 if torch.distributed.is_available():
     from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard
     from torch.testing._internal.distributed.fake_pg import FakeStore
 
 
+if has_triton():
+    import triton
+    import triton.language as tl
+
+    def init_to_zero(name):
+        return lambda nargs: nargs[name].zero_()
+
+    @triton.jit
+    def add_kernel(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
+        pid = tl.program_id(axis=0)
+
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+        tl.atomic_add(output_ptr + offsets, output, mask=mask)
+
+    @triton.autotune(
+        configs=[
+            triton.Config(
+                {"BLOCK_SIZE": 1024},
+                num_warps=4,
+                num_stages=2,
+                pre_hook=init_to_zero("output_ptr"),
+            )
+        ],
+        pre_hook=init_to_zero("output_ptr"),
+        post_hook=init_to_zero("output_ptr"),
+        key=["n_elements"],
+    )
+    @triton.jit
+    def add_kernel_autotune(
+        x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr
+    ):
+        pid = tl.program_id(axis=0)
+
+        block_start = pid * BLOCK_SIZE
+        offsets = block_start + tl.arange(0, BLOCK_SIZE)
+        mask = offsets < n_elements
+
+        x = tl.load(x_ptr + offsets, mask=mask)
+        y = tl.load(y_ptr + offsets, mask=mask)
+        output = x + y
+        tl.atomic_add(output_ptr + offsets, output, mask=mask)
+
+
+from torch.testing._internal.inductor_utils import GPU_TYPE
+from torch.testing._internal.triton_utils import requires_gpu
 
 
 class FxGraphRunnableArtifactFilter(logging.Filter):
     def filter(self, record):
@@ -100,6 +153,41 @@ class FxGraphRunnableTest(TestCase):
         torch.compile(f)(torch.randn(4))
         self._exec_and_verify_payload()
 
+    @unittest.skipUnless(has_triton(), "Triton not available")
+    def test_user_defined_triton_kernel_autotune(self):
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
+            n_elements = output.numel()
+
+            def grid(
+                meta,
+            ):
+                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+
+            add_kernel_autotune[grid](x, y, output, n_elements)
+            return output
+
+        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+
+        torch.compile(add)(x, y)
+        self._exec_and_verify_payload()
+
+    @unittest.skipUnless(has_triton(), "Triton not available")
+    @requires_gpu
+    def test_user_defined_triton_kernel(self):
+        def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            output = torch.ones(x.shape, device=x.device, dtype=x.dtype)
+            n_elements = x.numel()
+            add_kernel[n_elements,](x, y, output, n_elements, BLOCK_SIZE=4)
+            return output
+
+        x = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+        y = torch.ones((4096,), device=GPU_TYPE, dtype=torch.float16)
+
+        torch.compile(add)(x, y)
+        self._exec_and_verify_payload()
+
     def test_two_inputs_matmul(self):
         def f(a, b):
             return (a @ b).relu()
@@ -34,6 +34,24 @@ from tempfile import TemporaryFile
 from typing import Any, Callable, IO, Optional, TYPE_CHECKING, Union
 from typing_extensions import Unpack
 
+from torch.utils._triton import has_triton
+
+
+if has_triton():
+    from triton.runtime.autotuner import Autotuner, Heuristics
+    from triton.runtime.jit import JITFunction
+else:
+
+    class Autotuner:  # type: ignore[no-redef]
+        pass
+
+    class JITFunction:  # type: ignore[no-redef]
+        pass
+
+    class Heuristics:  # type: ignore[no-redef]
+        pass
+
+
 import torch
 import torch.fx as fx
 import torch.nn as nn
@@ -58,6 +76,7 @@ from torch._dynamo.debug_utils import (
 )
 from torch._dynamo.utils import clone_inputs, counters, same
 from torch._environment import is_fbcode
+from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
 from torch._inductor.cpp_builder import normalize_path_separator
 from torch._inductor.output_code import OutputCode
 from torch._library.fake_class_registry import FakeScriptObject
@@ -302,6 +321,16 @@ from torch.testing._internal.distributed.fake_pg import FakeStore
             """
         ).strip()
 
+    triton_imports = ""
+
+    if len(kernel_side_table.id_to_kernel) > 0:
+        triton_imports = textwrap.dedent(
+            """
+            import triton
+            import triton.language as tl
+            """
+        ).strip()
+
     model_str = textwrap.dedent(
         f"""
 {generate_env_vars_string(stable_output=stable_output)}
@@ -312,6 +341,7 @@ from torch._dynamo.testing import rand_strided
 from math import inf
 import torch._inductor.inductor_prims
 {distributed_imports}
+{triton_imports}
 
 {generate_config_string(stable_output=stable_output)}
@@ -330,6 +360,53 @@ isolate_fails_code_str = None
     model_str += f"# torch git version: {torch.version.git_version}\n\n\n"
     model_str += _cuda_system_info_comment()
 
+    kernel_side_table_prefix = (
+        "torch._higher_order_ops.triton_kernel_wrap.kernel_side_table"
+    )
+    # Track which grid entry corresponds to the best config
+    for id in kernel_side_table.id_to_kernel:
+        kernel = kernel_side_table.get_kernel(id)
+
+        if isinstance(kernel, Autotuner):
+            if isinstance(kernel.fn, Heuristics):
+                model_str += "ERROR: Repro will not work as intended, "
+                model_str += (
+                    "triton.runtime.autotuner.Heuristics is not currently supported\n"
+                )
+                break
+
+            config_strs = []
+            for kernel_config in kernel.configs:
+                config_strs.append(f"""triton.Config(
+                    {str(kernel_config.kwargs)},
+                    num_warps={kernel_config.num_warps},
+                    num_stages={kernel_config.num_stages},
+                )""")
+
+            config_str = ",".join(config_strs)
+            model_str += textwrap.dedent(f"""
+            @triton.autotune(
+                configs=[
+                    {config_str}
+                ],
+                key=[]
+            )
+            """).strip()
+
+        model_str += "\n@triton.jit\n"
+        src_code = kernel.src if isinstance(kernel, JITFunction) else kernel.fn.src
+        fn_name = (
+            kernel._fn_name if isinstance(kernel, JITFunction) else kernel.fn._fn_name
+        )
+        fn_name = fn_name.split(".")[-1]
+
+        model_str += src_code
+        model_str += "\n"
+        model_str += f"{kernel_side_table_prefix}.add_kernel({fn_name})\n"
+
+    if len(kernel_side_table.constant_args) > 0:
+        model_str += f"{kernel_side_table_prefix}.constant_args={kernel_side_table.constant_args}\n"
+
     model_str += NNModuleToString.convert(gm)
 
     writer = InputWriter(save_dir, stable_hash=stable_hash)
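To make the effect concrete, here is a hedged sketch, reconstructed from the generator logic above rather than copied from an actual dump, of the kernel section that the updated repro-string generator appends when the autotuned test kernel was captured; exact whitespace, config values, and the surrounding preamble will differ in a real fx_graph_runnable script.

import torch
import triton
import triton.language as tl


@triton.autotune(
    configs=[
        triton.Config(
            {'BLOCK_SIZE': 1024},
            num_warps=4,
            num_stages=2,
        )
    ],
    key=[]
)
@triton.jit
def add_kernel_autotune(x_ptr, y_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # body copied verbatim from kernel.src / kernel.fn.src by the generator
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    tl.atomic_add(output_ptr + offsets, output, mask=mask)

# register the reconstructed kernel in the side table, mirroring the
# kernel_side_table.add_kernel(<fn_name>) line the generator writes out
torch._higher_order_ops.triton_kernel_wrap.kernel_side_table.add_kernel(add_kernel_autotune)

Note that the generator only serializes each config's kwargs, num_warps, and num_stages and emits key=[], so autotuning hooks such as pre_hook and post_hook from the original kernel definition are not carried over into the repro.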