[inductor] Fix create_specialize_impl error in latest Triton (#148933)

```py
$ python test/inductor/test_triton_kernels.py KernelTests.test_triton_kernel_2d_autotune_grad_False_dynamic_True_backend_inductor_grid_type_1
WARNING:torch._dynamo:Encountered an exception in identify_mutated_tensors, assuming every input is mutated
Traceback (most recent call last):
  File "/home/jansel/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 715, in identify_mutated_tensors
    ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 289, in generate_ttir
    specialization = _get_specialization(ordered_args.values())
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/jansel/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 262, in _get_specialization
    specialize_impl = triton.runtime.jit.create_specialize_impl()
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: create_specialize_impl() missing 1 required positional argument: 'specialize_extra'
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148933
Approved by: https://github.com/yanboliang, https://github.com/davidberard98
This commit is contained in:
Jason Ansel 2025-03-10 17:48:13 -07:00 committed by PyTorch MergeBot
parent 16560d4e8f
commit 09029010e5
2 changed files with 19 additions and 4 deletions

View File

@ -1,6 +1,7 @@
import collections
import copy
import dataclasses
import functools
import inspect
import logging
import threading
@ -259,9 +260,24 @@ def generate_ttir(
)
# specialize_impl switched to create_specialize_impl in https://github.com/triton-lang/triton/pull/6099
if hasattr(triton.runtime.jit, "create_specialize_impl"):
specialize_impl = triton.runtime.jit.create_specialize_impl()
try:
# Latest versions of Triton take specialize_extra as an arg to create_specialize_impl
specialize_impl = triton.runtime.jit.create_specialize_impl(
specialize_extra=backend.get_arg_specialization
)
except TypeError: # Unknown arg `specialize_extra`
# Older versions of Triton take specialize_extra as an arg to specialize_impl
specialize_impl = functools.partial(
triton.runtime.jit.create_specialize_impl(),
specialize_extra=backend.get_arg_specialization,
)
else:
from triton.runtime.jit import specialize_impl # type: ignore[no-redef]
from triton.runtime.jit import specialize_impl as specialize_impl_orig
specialize_impl = functools.partial(
specialize_impl_orig,
specialize_extra=backend.get_arg_specialization,
)
from triton._utils import find_paths_if, get_iterable_path
@ -273,7 +289,6 @@ def generate_ttir(
else:
spec = specialize_impl(
arg,
specialize_extra=backend.get_arg_specialization,
is_const=kp.is_const,
specialize_value=not kp.do_not_specialize,
align=not kp.do_not_specialize_on_alignment,

View File

@ -412,7 +412,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
"""
bool _check_aoti_runtime_check_inputs_env() {
const static char* env_var_value = getenv("AOTI_RUNTIME_CHECK_INPUTS");
const static bool result = env_var_value != nullptr && env_var_value[0] != '\0';
const static bool result = env_var_value != nullptr && env_var_value[0] != 0;
return result;
}