mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add type hints to cuda kernel (#147471)
Missed this in a previous PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/147471 Approved by: https://github.com/eellison
This commit is contained in:
parent
48203bec63
commit
004d65aeb0
|
|
@ -1,13 +1,17 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING, Union
|
||||
from typing import Any, Callable, List, Literal, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from sympy import Expr, symbols
|
||||
|
||||
from torch import dtype as torch_dtype
|
||||
from torch._inductor.codegen.cpp_wrapper_cpu import CppWrapperCpu
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .cuda_template import ArgInfo
|
||||
|
||||
from ...autotune_process import CUDABenchmarkRequest
|
||||
from ...ir import (
|
||||
Buffer,
|
||||
|
|
@ -153,7 +157,12 @@ class CUDATemplateKernel(CUDAKernel):
|
|||
|
||||
_EXTRA_CPP_ARGS = "size_t* workspace_size, uint8_t* workspace, cudaStream_t stream"
|
||||
|
||||
def __init__(self, kernel_name, runtime_arg_info, runtime_arg_values) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
kernel_name: str,
|
||||
runtime_arg_info: List["ArgInfo"],
|
||||
runtime_arg_values: List[Any],
|
||||
) -> None:
|
||||
"""
|
||||
Initializes a new instance of the CUDATemplateKernel class.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user