mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[inductor] Fix create_specialize_impl error in latest Triton (#148933)
```py
$ python test/inductor/test_triton_kernels.py KernelTests.test_triton_kernel_2d_autotune_grad_False_dynamic_True_backend_inductor_grid_type_1
WARNING:torch._dynamo:Encountered an exception in identify_mutated_tensors, assuming every input is mutated
Traceback (most recent call last):
File "/home/jansel/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 715, in identify_mutated_tensors
ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 289, in generate_ttir
specialization = _get_specialization(ordered_args.values())
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/jansel/pytorch/torch/_higher_order_ops/triton_kernel_wrap.py", line 262, in _get_specialization
specialize_impl = triton.runtime.jit.create_specialize_impl()
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
TypeError: create_specialize_impl() missing 1 required positional argument: 'specialize_extra'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148933
Approved by: https://github.com/yanboliang, https://github.com/davidberard98
This commit is contained in:
parent
16560d4e8f
commit
09029010e5
|
|
@ -1,6 +1,7 @@
|
|||
import collections
|
||||
import copy
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
import threading
|
||||
|
|
@ -259,9 +260,24 @@ def generate_ttir(
|
|||
)
|
||||
# specialize_impl switched to create_specialize_impl in https://github.com/triton-lang/triton/pull/6099
|
||||
if hasattr(triton.runtime.jit, "create_specialize_impl"):
|
||||
specialize_impl = triton.runtime.jit.create_specialize_impl()
|
||||
try:
|
||||
# Latest versions of Triton take specialize_extra as an arg to create_specialize_impl
|
||||
specialize_impl = triton.runtime.jit.create_specialize_impl(
|
||||
specialize_extra=backend.get_arg_specialization
|
||||
)
|
||||
except TypeError: # Unknown arg `specialize_extra`
|
||||
# Older versions of Triton take specialize_extra as an arg to specialize_impl
|
||||
specialize_impl = functools.partial(
|
||||
triton.runtime.jit.create_specialize_impl(),
|
||||
specialize_extra=backend.get_arg_specialization,
|
||||
)
|
||||
else:
|
||||
from triton.runtime.jit import specialize_impl # type: ignore[no-redef]
|
||||
from triton.runtime.jit import specialize_impl as specialize_impl_orig
|
||||
|
||||
specialize_impl = functools.partial(
|
||||
specialize_impl_orig,
|
||||
specialize_extra=backend.get_arg_specialization,
|
||||
)
|
||||
|
||||
from triton._utils import find_paths_if, get_iterable_path
|
||||
|
||||
|
|
@ -273,7 +289,6 @@ def generate_ttir(
|
|||
else:
|
||||
spec = specialize_impl(
|
||||
arg,
|
||||
specialize_extra=backend.get_arg_specialization,
|
||||
is_const=kp.is_const,
|
||||
specialize_value=not kp.do_not_specialize,
|
||||
align=not kp.do_not_specialize_on_alignment,
|
||||
|
|
|
|||
|
|
@ -412,7 +412,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
"""
|
||||
bool _check_aoti_runtime_check_inputs_env() {
|
||||
const static char* env_var_value = getenv("AOTI_RUNTIME_CHECK_INPUTS");
|
||||
const static bool result = env_var_value != nullptr && env_var_value[0] != '\0';
|
||||
const static bool result = env_var_value != nullptr && env_var_value[0] != 0;
|
||||
return result;
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user