pytorch/torch/testing/_internal/triton_utils.py
Edward Z. Yang 9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` in our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, this means that whenever a checked file imports a module that is not itself listed as a file to be typechecked, mypy typechecks that module as normal but suppresses all errors reported in it.
* When mypy is run inside lintrunner, the list of files is exactly the set of files matched by the glob in .lintrunner.toml, minus the files matched by its exclude patterns.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00

187 lines
5.8 KiB
Python

# mypy: ignore-errors
import unittest
from torch.testing._internal.inductor_utils import HAS_CUDA
# Test decorator: skips the decorated test unless HAS_CUDA is truthy.
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
if HAS_CUDA:
import triton
from triton import language as tl
# Define here so that multiple tests can take advantage of it
@triton.jit
def add_kernel(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Store the element-wise sum of ``in_ptr0`` and ``in_ptr1`` at ``out_ptr``.

    Each program instance along axis 0 covers ``BLOCK_SIZE`` consecutive
    offsets. Offsets at or beyond ``n_elements`` are masked out of both
    loads and of the store.
    """
    start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
@triton.jit
def add_kernel_with_optional_param(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    ARGS_PASSED: "tl.constexpr",
    BLOCK_SIZE: "tl.constexpr",
):
    """Add two buffers or copy one, depending on ``ARGS_PASSED``.

    When ``ARGS_PASSED == "two"`` the kernel stores the element-wise sum of
    ``in_ptr0`` and ``in_ptr1``. For any other value ``in_ptr1`` is never
    read and the values loaded from ``in_ptr0`` are stored unchanged.
    Offsets at or beyond ``n_elements`` are masked out.
    """
    start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    first = tl.load(in_ptr0 + idx, mask=in_bounds)
    if ARGS_PASSED == "two":
        result = first + tl.load(in_ptr1 + idx, mask=in_bounds)
    else:
        result = first
    tl.store(out_ptr + idx, result, mask=in_bounds)
@triton.autotune(
    configs=[
        triton.Config({"BLOCK_SIZE": block}, num_stages=stages, num_warps=warps)
        for block in (128, 64)
        for stages, warps in ((3, 8), (4, 4))
    ],
    key=[],
)
@triton.jit
def add_kernel_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Element-wise add whose ``BLOCK_SIZE`` is supplied by the autotuner.

    The candidate configurations pair block sizes 128 and 64 with
    (num_stages, num_warps) of (3, 8) and (4, 4); ``key`` is empty.
    Offsets at or beyond ``n_elements`` are masked out of the loads and
    the store.
    """
    start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    lhs = tl.load(in_ptr0 + idx, mask=in_bounds)
    rhs = tl.load(in_ptr1 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, lhs + rhs, mask=in_bounds)
@triton.autotune(
    configs=[
        triton.Config(
            {"BLOCK_SIZE_X": side, "BLOCK_SIZE_Y": side},
            num_stages=stages,
            num_warps=warps,
        )
        for side in (128, 64)
        for stages, warps in ((3, 8), (4, 4))
    ],
    key=[],
)
@triton.jit
def add_kernel_2d_autotuned(
    in_ptr0,
    in_ptr1,
    out_ptr,
    x_elements,
    y_elements,
    BLOCK_SIZE_X: "tl.constexpr",
    BLOCK_SIZE_Y: "tl.constexpr",
):
    """Two-dimensional add over a ``BLOCK_SIZE_X`` by ``BLOCK_SIZE_Y`` tile.

    Program ids 0 and 1 select the tile along x and y. Two values are
    loaded at offsets ``x + x_elements * y`` and ``y + y_elements * x``,
    summed, and stored at ``out_ptr + (x + x_elements * y)``. Positions
    with ``x >= x_elements`` or ``y >= y_elements`` are masked out.
    """
    x_idx = tl.program_id(0) * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)[:, None]
    y_idx = tl.program_id(1) * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)[None, :]
    in_bounds = (x_idx < x_elements) & (y_idx < y_elements)
    offs_a = x_idx + (x_elements * y_idx)
    offs_b = y_idx + (y_elements * x_idx)
    first = tl.load(in_ptr0 + offs_a, in_bounds)
    # NOTE(review): both loads read from in_ptr0 and in_ptr1 is never used.
    # This may be intentional for the tests that consume this kernel or a
    # copy-paste slip -- confirm before changing it.
    second = tl.load(in_ptr0 + offs_b, in_bounds)
    tl.store(out_ptr + offs_a, first + second, in_bounds)
@triton.jit
def mul2_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Store twice each value loaded from ``in_ptr0`` at ``out_ptr``.

    Each program instance along axis 0 covers ``BLOCK_SIZE`` consecutive
    offsets; offsets at or beyond ``n_elements`` are masked out.
    """
    start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(in_ptr0 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, 2 * vals, mask=in_bounds)
@triton.jit
def mul2_inplace_kernel(
    ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
):
    """Double the values behind ``ptr``, writing back to the same offsets.

    Each program instance along axis 0 covers ``BLOCK_SIZE`` consecutive
    offsets; offsets at or beyond ``n_elements`` are masked out of both
    the load and the store.
    """
    start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    vals = tl.load(ptr + idx, mask=in_bounds)
    tl.store(ptr + idx, 2 * vals, mask=in_bounds)
@triton.jit
def zero_negs(x):
    """Return ``x`` where ``x >= 0`` and 0 elsewhere."""
    keep = x >= 0
    return tl.where(keep, x, 0)
@triton.jit
def indirection_kernel(
    in_ptr0,
    out_ptr,
    n_elements,
    BLOCK_SIZE: "tl.constexpr",
    ACTIVATION: "tl.constexpr",
):
    """Copy ``in_ptr0`` to ``out_ptr``, optionally calling another kernel first.

    When ``ACTIVATION == "mul2_inplace_kernel"``, ``mul2_inplace_kernel`` is
    invoked on ``in_ptr0`` before the copy. Any other value of ``ACTIVATION``
    results in a plain masked copy. Offsets at or beyond ``n_elements`` are
    masked out of the load and the store.
    """
    start = tl.program_id(axis=0) * BLOCK_SIZE
    idx = start + tl.arange(0, BLOCK_SIZE)
    in_bounds = idx < n_elements
    # The nested call runs before the load below, so the load happens after
    # the in-place pass over in_ptr0.
    if ACTIVATION == "mul2_inplace_kernel":
        mul2_inplace_kernel(in_ptr0, n_elements, BLOCK_SIZE=BLOCK_SIZE)
    vals = tl.load(in_ptr0 + idx, mask=in_bounds)
    tl.store(out_ptr + idx, vals, mask=in_bounds)
@triton.jit
def double_strided_kernel(
    in_ptr,
    out_ptr,
    in_y_stride,
    out_y_stride,
    X_BLOCK_SIZE: "tl.constexpr",
    Y_BLOCK_SIZE: "tl.constexpr",
):
    """Store ``2.0 *`` a 2D tile of ``in_ptr`` into ``out_ptr``.

    Program ids 0 and 1 select the tile along x and y. Source offsets are
    ``y * in_y_stride + x`` and destination offsets are
    ``y * out_y_stride + x``, so input and output may use different y
    strides. No mask is applied to the load or the store.
    """
    x_idx = tl.program_id(axis=0) * X_BLOCK_SIZE + tl.arange(0, X_BLOCK_SIZE)
    y_idx = tl.program_id(axis=1) * Y_BLOCK_SIZE + tl.arange(0, Y_BLOCK_SIZE)
    y_col = y_idx[:, None]
    x_row = x_idx[None, :]
    tile = tl.load(in_ptr + (y_col * in_y_stride + x_row))
    tl.store(out_ptr + (y_col * out_y_stride + x_row), tile * 2.0)