Add type hints to cuda kernel (#147471)

Missed this in a previous PR

Pull Request resolved: https://github.com/pytorch/pytorch/pull/147471
Approved by: https://github.com/eellison
Michael Lazos 2025-02-19 23:35:07 +00:00 committed by PyTorch MergeBot
parent 48203bec63
commit 004d65aeb0

@@ -1,13 +1,17 @@
 # mypy: allow-untyped-defs
 import logging
 from dataclasses import dataclass
-from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
+from typing import Any, Callable, List, Literal, Optional, TYPE_CHECKING, Union
 from sympy import Expr, symbols
 from torch import dtype as torch_dtype
 from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
 if TYPE_CHECKING:
+    from .cuda_template import ArgInfo
     from ...autotune_process import CUDABenchmarkRequest
     from ...ir import (
         Buffer,
@@ -153,7 +157,12 @@ class CUDATemplateKernel(CUDAKernel):
     _EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
-    def __init__(self, kernel_name, runtime_arg_info, runtime_arg_values) -> None:
+    def __init__(
+        self,
+        kernel_name: str,
+        runtime_arg_info: List["ArgInfo"],
+        runtime_arg_values: List[Any],
+    ) -> None:
         """
         Initializes a new instance of the CUDATemplateKernel class.
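
The new annotations only need ArgInfo while a type checker is running, so it is imported under TYPE_CHECKING and referenced as the forward reference List["ArgInfo"]. A minimal, self-contained sketch of that pattern (the module name "argdefs" and the class ExampleKernel below are hypothetical, not PyTorch's):

# Illustrative sketch of the TYPE_CHECKING + forward-reference pattern used above.
# "argdefs" and ExampleKernel are made-up names for this example.
from typing import TYPE_CHECKING, Any, List

if TYPE_CHECKING:
    # Seen only by static type checkers; never imported at runtime,
    # which avoids import cycles and runtime import cost.
    from argdefs import ArgInfo


class ExampleKernel:
    def __init__(
        self,
        kernel_name: str,
        # Quoted because ArgInfo is not defined at runtime in this module.
        runtime_arg_info: List["ArgInfo"],
        runtime_arg_values: List[Any],
    ) -> None:
        self.kernel_name = kernel_name
        self.runtime_arg_info = runtime_arg_info
        self.runtime_arg_values = runtime_arg_values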