mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Fixed a bunch of fbcode imports that happened to work but confused autodeps. After this autodeps still suggests "improvements" to TARGETS (which breaks our builds) but at least it can find all the imports. Test Plan: ``` fbpython fbcode/tools/build/buck/linters/lint_autoformat.py --linter=autodeps --default-exec-timeout=1800 -- fbcode/caffe2/TARGETS fbcode/caffe2/test/TARGETS ``` Before: ``` ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/testing.py:229) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fbur$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export.py:87) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_serdes.py:9) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fb$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_serdes.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https://fburl$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_retraceability.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See https:$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_retraceability.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See ht$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_nonstrict.py:7) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See http$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_nonstrict.py:6) when processing rule "test_export". Please make sure it's listed in the srcs parameter of another rule. See $ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "test_export" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:8) when processing rule "test_export". Please make sure it's listed in the srcs parameter of an$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "testing" (from caffe2/test/export/test_export_training_ir_to_run_decomp.py:10) when processing rule "test_export". Please make sure it's listed in the srcs parameter of anoth$ ERROR while processing caffe2/test/TARGETS: Found "//python/typeshed_internal:typeshed_internal_library" owner for "cv2" but it is protected by visibility rules: [] (from caffe2/test/test_bundled_images.py:7) when processing rule "test_bundled_$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "caffe2.test.profiler_test_cpp_thread_lib" (from caffe2/test/profiler/test_cpp_thread.py:29) when processing rule "profiler_test_cpp_thread". Please make sure it's listed in t$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_custom_ops.py:23) when processing rule "custom_ops". Please make sure it's listed in the srcs parameter of anoth$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._utils_internal.get_file_path_2" (from caffe2/test/test_public_bindings.py:13) when processing rule "public_bindings". Please make sure it's listed in the srcs paramete$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.symbolize_tracebacks" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another $ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for "torch._C._profiler.gather_traceback" (from caffe2/test/test_cuda.py:3348) when processing rule "test_cuda". Please make sure it's listed in the srcs parameter of another rule$ ERROR while processing caffe2/test/TARGETS: Cannot find an owner for include <torch/csrc/autograd/profiler_kineto.h> (from caffe2/test/profiler/test_cpp_thread.cpp:2) when processing profiler_test_cpp_thread_lib. Some things to try: ``` Differential Revision: D62049222 Pull Request resolved: https://github.com/pytorch/pytorch/pull/135614 Approved by: https://github.com/oulgen, https://github.com/laithsakka
1386 lines
50 KiB
Python
1386 lines
50 KiB
Python
# Owner(s): ["module: inductor"]
|
|
import gc
|
|
import math
|
|
import sys
|
|
import unittest
|
|
|
|
import torch
|
|
import torch._dynamo.config as dynamo_config
|
|
import torch.backends.cuda
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
from torch._dynamo.debug_utils import same_two_models
|
|
from torch._dynamo.testing import rand_strided
|
|
from torch._dynamo.utils import same
|
|
from torch._inductor import config
|
|
from torch._inductor.compile_fx import compile_fx_inner
|
|
from torch._inductor.runtime.hints import DeviceProperties
|
|
from torch._inductor.utils import (
|
|
run_and_get_code,
|
|
run_and_get_graph_lowering,
|
|
run_fw_bw_and_get_code,
|
|
)
|
|
from torch.fx.experimental.proxy_tensor import make_fx
|
|
from torch.testing import FileCheck
|
|
from torch.testing._internal.common_cuda import (
|
|
PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
|
SM80OrLater,
|
|
)
|
|
from torch.testing._internal.common_utils import (
|
|
DeterministicGuard,
|
|
freeze_rng_state,
|
|
IS_FBCODE,
|
|
skipIfRocm,
|
|
TEST_WITH_ASAN,
|
|
)
|
|
from torch.testing._internal.inductor_utils import skipCUDAIf
|
|
|
|
|
|
try:
|
|
try:
|
|
import triton # @manual
|
|
from triton import language as tl # @manual
|
|
except ImportError:
|
|
raise unittest.SkipTest("requires triton") # noqa: B904
|
|
|
|
try:
|
|
from . import test_torchinductor
|
|
except ImportError:
|
|
import test_torchinductor # @manual=fbcode//caffe2/test/inductor:test_inductor-library
|
|
except unittest.SkipTest:
|
|
if __name__ == "__main__":
|
|
sys.exit(0)
|
|
raise
|
|
|
|
|
|
TestCase = test_torchinductor.TestCase
|
|
ToTuple = test_torchinductor.ToTuple
|
|
check_model_cuda = test_torchinductor.check_model_cuda
|
|
aten = torch.ops.aten
|
|
|
|
|
|
class CudaReproTests(TestCase):
|
|
device = "cuda"
|
|
common = check_model_cuda
|
|
|
|
def test_index_put_issue(self):
|
|
def forward(
|
|
self,
|
|
arg76_1,
|
|
expand_default,
|
|
full_like_default,
|
|
_to_copy_default_67,
|
|
zeros,
|
|
):
|
|
sum_sym_int_19 = torch.ops.aten.sum(_to_copy_default_67, [0], True)
|
|
view_default_57 = torch.ops.aten.view.default(sum_sym_int_19, [512, 768])
|
|
where_self = torch.ops.aten.where.self(
|
|
expand_default, view_default_57, full_like_default
|
|
)
|
|
clone_default_12 = torch.ops.aten.clone.default(zeros)
|
|
index_put__default = torch.ops.aten.index_put_.default(
|
|
clone_default_12, [arg76_1], where_self, True
|
|
)
|
|
return (index_put__default,)
|
|
|
|
inps = [
|
|
(torch.Size([512]), torch.int64),
|
|
(torch.Size([512, 768]), torch.bool),
|
|
(torch.Size([512, 768]), torch.float16),
|
|
(torch.Size([4, 512, 768]), torch.float16),
|
|
(torch.Size([512, 768]), torch.float16),
|
|
]
|
|
inps = [torch.zeros(())] + [
|
|
torch.ones(shape, dtype=dtype, device="cuda") for (shape, dtype) in inps
|
|
]
|
|
mod = make_fx(forward)(*inps)
|
|
compiled = compile_fx_inner(mod, inps)
|
|
compiled(inps)
|
|
|
|
@skipIfRocm
|
|
def test_input_channels_last(self):
|
|
m = torch.nn.Sequential(
|
|
torch.nn.Conv2d(3, 3, 1, 1),
|
|
ToTuple(),
|
|
).cuda()
|
|
inp = torch.randn([2, 3, 16, 16]).to(memory_format=torch.channels_last).cuda()
|
|
|
|
self.common(
|
|
m,
|
|
(inp,),
|
|
check_lowp=False,
|
|
)
|
|
|
|
@torch._dynamo.optimize()
|
|
def foo(m, inp):
|
|
return m(inp)
|
|
|
|
self.assertTrue(foo(m, inp)[0].is_contiguous(memory_format=torch.channels_last))
|
|
|
|
# https://github.com/pytorch/torchdynamo/issues/1681#issuecomment-1283433527
|
|
def test_unspec_inputs_interop(self):
|
|
class Repro(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
unsqueeze = torch.ops.aten.unsqueeze.default(x, 4)
|
|
permute = torch.ops.aten.permute.default(unsqueeze, [0, 1, 2, 4, 3])
|
|
add = torch.ops.aten.add.Tensor(y, 1)
|
|
return [permute, add]
|
|
|
|
inps = [
|
|
rand_strided((12, 3, 512, 64), (64, 196608, 768, 1), torch.float32, "cuda"),
|
|
rand_strided((), (), torch.int64, "cpu"),
|
|
]
|
|
mod = make_fx(Repro().to(device="cuda"))(*inps)
|
|
compiled = compile_fx_inner(mod, inps)
|
|
compiled(inps)
|
|
|
|
@unittest.skipIf(
|
|
IS_FBCODE, "RuntimeError: Triton Error [CUDA]: invalid device context"
|
|
)
|
|
def test_backward_context(self):
|
|
def fn(x):
|
|
return x * 3
|
|
|
|
x = torch.randn(4, device="cuda", requires_grad=True)
|
|
gO = torch.rand_like(x)
|
|
opt_fn = torch.compile(fn)
|
|
out = opt_fn(x)
|
|
out.backward(gO)
|
|
|
|
@config.patch(fallback_random=True)
|
|
def test_dtype_factory_issue(self):
|
|
def forward():
|
|
randn = torch.ops.aten.randn.default(
|
|
[12, 64, 1, 64],
|
|
dtype=torch.float32,
|
|
device=torch.device(type="cuda", index=0),
|
|
pin_memory=False,
|
|
)
|
|
unsqueeze_default_2 = torch.ops.aten.unsqueeze.default(randn, -1)
|
|
return (unsqueeze_default_2,)
|
|
|
|
mod = make_fx(forward)()
|
|
compiled = compile_fx_inner(mod, ())
|
|
assert compiled([])[0].device.type == "cuda"
|
|
|
|
@config.patch({"triton.cudagraphs": True})
|
|
@dynamo_config.patch(automatic_dynamic_shapes=True)
|
|
def test_no_device_idx_repro_cudagraphs(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self):
|
|
full = torch.ops.aten.full.default(
|
|
[8, 512],
|
|
1,
|
|
dtype=torch.float32,
|
|
layout=torch.strided,
|
|
device=torch.device(type="cuda", index=0),
|
|
pin_memory=False,
|
|
)
|
|
full_1 = torch.ops.aten.full.default(
|
|
[8, 512],
|
|
0,
|
|
dtype=torch.int64,
|
|
layout=torch.strided,
|
|
device=torch.device(type="cuda", index=0),
|
|
pin_memory=False,
|
|
)
|
|
return (full_1, full)
|
|
|
|
self.common(Repro(), ())
|
|
|
|
@config.patch({"triton.cudagraphs": True})
|
|
@dynamo_config.patch(automatic_dynamic_shapes=True)
|
|
def test_expanded_inputs_cudagraphs(self):
|
|
@torch._dynamo.optimize("inductor")
|
|
def fn(x, y):
|
|
return x + y
|
|
|
|
inputs = (
|
|
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
|
|
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
|
|
)
|
|
self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
|
|
|
|
@config.patch({"triton.cudagraphs": True})
|
|
@dynamo_config.patch(
|
|
automatic_dynamic_shapes=True,
|
|
assume_static_by_default=False,
|
|
)
|
|
def test_dynamic_to_static_cudagraphs(self):
|
|
for b in [False, True]:
|
|
with config.patch({"triton.cudagraph_trees": b}):
|
|
|
|
@torch._dynamo.optimize("inductor")
|
|
def fn(x, y):
|
|
r = x + y
|
|
return r, r.size(0)
|
|
|
|
inputs = (
|
|
torch.randn((5, 5), device="cuda"),
|
|
torch.randn((5, 5), device="cuda"),
|
|
)
|
|
self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 5)))
|
|
|
|
inputs = (
|
|
torch.randn((6, 6), device="cuda"),
|
|
torch.randn((6, 6), device="cuda"),
|
|
)
|
|
self.assertTrue(same(fn(*inputs), (inputs[0] + inputs[1], 6)))
|
|
|
|
@config.patch({"emulate_precision_casts": True})
|
|
def test_emulate_low_precision(self):
|
|
def foo(x):
|
|
return torch.nn.functional.gelu(x) * 10.0
|
|
|
|
inp = torch.rand([32], device="cuda", requires_grad=True, dtype=torch.bfloat16)
|
|
out, codes = run_fw_bw_and_get_code(lambda: torch.compile(foo)(inp))
|
|
|
|
# fwd, backward
|
|
for code in codes:
|
|
f = FileCheck()
|
|
# in eager, there are two down casts
|
|
for _ in range(2):
|
|
f.check(".to(tl.bfloat16)").check_next(".to(tl.float32)")
|
|
f.run(code)
|
|
|
|
self.assertEqual(foo(inp), out)
|
|
|
|
# TODO: Abstract this out, test more extensively
|
|
@torch._dynamo.config.patch(assume_static_by_default=False)
|
|
def test_dynamic_shapes(self):
|
|
torch._dynamo.reset() # Needed since everywhere else uses "inductor"
|
|
|
|
def f(x):
|
|
return x.cos().view(x.shape).sin()
|
|
|
|
cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")
|
|
|
|
f2 = torch._dynamo.optimize(cnts)(f)
|
|
|
|
f2(torch.randn(32))
|
|
|
|
inp = torch.randn(16)
|
|
real_out = f(inp)
|
|
compiled_out = f2(inp)
|
|
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(real_out, compiled_out)
|
|
torch._dynamo.reset()
|
|
|
|
@config.patch({"triton.cudagraphs": True, "size_asserts": False})
|
|
@dynamo_config.patch(automatic_dynamic_shapes=True)
|
|
def test_expanded_inputs_cudagraphs_no_size_asserts(self):
|
|
@torch._dynamo.optimize("inductor")
|
|
def fn(x, y):
|
|
return x + y
|
|
|
|
inputs = (
|
|
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
|
|
rand_strided((5, 5, 5, 5), (0, 5, 0, 1), device="cuda"),
|
|
)
|
|
self.assertTrue(same(fn(*inputs), inputs[0] + inputs[1]))
|
|
|
|
@config.patch({"triton.cudagraph_trees": False})
|
|
@config.patch({"triton.cudagraphs": True})
|
|
@dynamo_config.patch(automatic_dynamic_shapes=True)
|
|
def test_inplace_updates_cudagraphs(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.weight1 = torch.nn.Parameter(
|
|
torch.randn(10, 20, requires_grad=True)
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = torch.matmul(x, self.weight1)
|
|
return x
|
|
|
|
from copy import deepcopy
|
|
|
|
model = Repro().cuda()
|
|
model_ref = deepcopy(model)
|
|
model_opt = torch._dynamo.optimize("inductor")(model)
|
|
|
|
input = torch.randn(10, 10, device="cuda", requires_grad=True)
|
|
|
|
for i in range(2):
|
|
output_ref = model_ref(input)
|
|
output_res = model_opt(input)
|
|
output_ref.sum().backward()
|
|
output_res.sum().backward()
|
|
for p_ref, p_res in zip(model_ref.parameters(), model_opt.parameters()):
|
|
self.assertEqual(p_ref.grad, p_res.grad)
|
|
with torch.no_grad():
|
|
for param in model_ref.parameters():
|
|
param.add_(1.0)
|
|
for param in model_opt.parameters():
|
|
param.add_(1.0)
|
|
|
|
# https://github.com/pytorch/torchdynamo/issues/1850
|
|
def test_inductor_output_aliases_intermediate(self):
|
|
def foo(x):
|
|
out = x + x
|
|
return out.t()
|
|
|
|
foo_opt = torch._dynamo.optimize("inductor")(foo)
|
|
|
|
inpt = torch.randn(10, 10, device="cuda", requires_grad=True)
|
|
# TODO: this is broken, fix later
|
|
# out = foo_opt(inpt)
|
|
# out.add_(2)
|
|
|
|
out_ref = foo(inpt)
|
|
out_ref.add_(2)
|
|
# self.assertEqual(out_ref, out)
|
|
|
|
def test_accuracy_issue1(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(
|
|
in_features=768, out_features=2, bias=True
|
|
)
|
|
|
|
def forward(self, start_positions: torch.Tensor, x: torch.Tensor):
|
|
linear = self.linear(x)
|
|
split = linear.split(1, dim=-1)
|
|
getitem = split[0]
|
|
squeeze = getitem.squeeze(-1)
|
|
clamp = start_positions.clamp(0, 128)
|
|
cross_entropy = torch.nn.functional.cross_entropy(
|
|
squeeze, clamp, None, None, 128, None, "mean", 0.0
|
|
)
|
|
return cross_entropy
|
|
|
|
mod = Repro().cuda()
|
|
opt_mod = torch._dynamo.optimize("inductor")(mod)
|
|
mod.eval()
|
|
opt_mod.eval()
|
|
|
|
args = [
|
|
((1,), (1,), torch.int64, "cuda", False),
|
|
((1, 128, 768), (98304, 768, 1), torch.float32, "cuda", True),
|
|
]
|
|
args = [
|
|
rand_strided(sh, st, dt, dev).requires_grad_(rg)
|
|
for (sh, st, dt, dev, rg) in args
|
|
]
|
|
with torch.cuda.amp.autocast(enabled=False):
|
|
assert same_two_models(mod, opt_mod, args), "Dynamo failed"
|
|
|
|
@config.patch(allow_buffer_reuse=False)
|
|
def test_issue103461(self):
|
|
def forward(add_1):
|
|
var_mean = torch.ops.aten.var_mean.correction(
|
|
add_1, [2], correction=0, keepdim=True
|
|
)
|
|
getitem_1 = var_mean[1]
|
|
return getitem_1
|
|
|
|
x = torch.randn(1, 8, 768, device="cuda")
|
|
correct = forward(x)
|
|
actual = torch.compile(forward, fullgraph=True)(x)
|
|
self.assertEqual(actual, correct)
|
|
|
|
def test_full_copy(self):
|
|
def forward(x):
|
|
full_10 = torch.ops.aten.full.default(
|
|
[204, 204, 28],
|
|
0,
|
|
dtype=torch.float64,
|
|
layout=torch.strided,
|
|
device="cuda",
|
|
pin_memory=False,
|
|
)
|
|
return x + full_10.to("cpu")
|
|
|
|
o = torch.randn([204, 204, 28], dtype=torch.float64)
|
|
correct = forward(o)
|
|
actual = torch.compile(forward, fullgraph=True)(o)
|
|
self.assertEqual(actual, correct)
|
|
|
|
def test_autotune_inplace_kernel(self):
|
|
"""
|
|
This UT tests autotune on an inplace kernel. The autotune should not contaminate
|
|
the input buffers when tuning with multiple configs. For more details, refer to
|
|
https://github.com/openai/triton/issues/781
|
|
https://github.com/pytorch/torchdynamo/issues/1670
|
|
"""
|
|
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
|
|
from torch._inductor.runtime.hints import HeuristicType, instance_descriptor
|
|
from torch._inductor.runtime.triton_heuristics import CachingAutotuner, grid
|
|
|
|
def autotune(configs, meta):
|
|
def decorator(fn):
|
|
return CachingAutotuner(
|
|
# force autotune by setting save_cache_hook to False
|
|
fn,
|
|
triton_meta=meta,
|
|
configs=configs,
|
|
save_cache_hook=False,
|
|
mutated_arg_names=["in_out_ptr0"],
|
|
heuristic_type=HeuristicType.POINTWISE,
|
|
)
|
|
|
|
return decorator
|
|
|
|
@autotune(
|
|
configs=[
|
|
triton.Config({"XBLOCK": 1}),
|
|
triton.Config({"XBLOCK": 2}),
|
|
],
|
|
meta={
|
|
"signature": {0: "*fp32", 1: "*fp32", 2: "i32"},
|
|
"device": DeviceProperties.create(torch.device("cuda")),
|
|
"configs": [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())],
|
|
"constants": {},
|
|
},
|
|
)
|
|
@triton.jit
|
|
def kernel(in_out_ptr0, in_ptr0, xnumel, XBLOCK: tl.constexpr):
|
|
pid = tl.program_id(0)
|
|
block_start = pid * XBLOCK
|
|
offsets = block_start + tl.arange(0, XBLOCK)
|
|
mask = offsets < xnumel
|
|
x = tl.load(in_out_ptr0 + offsets, mask=mask, other=0.0)
|
|
y = tl.load(in_ptr0 + offsets, mask=mask, other=0.0)
|
|
output = x + y
|
|
tl.store(in_out_ptr0 + offsets, output, mask=mask)
|
|
|
|
xnumel = 384
|
|
in0 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
|
|
inout1 = rand_strided((xnumel,), (1,), device="cuda", dtype=torch.float32)
|
|
inout2 = inout1.clone()
|
|
|
|
stream0 = get_cuda_stream(0)
|
|
kernel.run(inout1, in0, xnumel, grid=grid(xnumel), stream=stream0)
|
|
kernel.run(inout2, in0, xnumel, grid=grid(xnumel), stream=stream0)
|
|
|
|
assert same(
|
|
inout1, inout2, tol=0.001, equal_nan=True
|
|
), "failed autotune with inplace kernel"
|
|
|
|
def test_sort_stride_issue(self):
|
|
# This minified testcase comes from detectron2_maskrcnn_r_50_fpn
|
|
# There was a false error from our size_assert code
|
|
@torch._dynamo.optimize(nopython=True)
|
|
def forward(pred_objectness_logits_3_: torch.Tensor):
|
|
sort_3 = pred_objectness_logits_3_.sort(descending=True, dim=1)
|
|
getitem_12 = sort_3[0]
|
|
return getitem_12
|
|
|
|
args = [((1, 100), (0, 1), torch.float16, "cuda", False)]
|
|
args = [
|
|
rand_strided(sh, st, dt, dev).requires_grad_(rg)
|
|
for (sh, st, dt, dev, rg) in args
|
|
]
|
|
result = forward(*args)
|
|
assert same(result, torch.sort(args[0], descending=True, dim=1)[0])
|
|
|
|
def test_scalar_triton_index(self):
|
|
# The indirect indexing via a scalar like below used to lead to
|
|
# bad triton code that made triton segfault when compiling.
|
|
# See https://github.com/pytorch/torchdynamo/issues/1515
|
|
def fn(a):
|
|
zero = torch.zeros((16,), device=a.device, dtype=torch.int64)
|
|
return (a[zero],)
|
|
|
|
a = torch.randn((8,), dtype=torch.float32, device="cuda")
|
|
|
|
fn_optimized = torch._dynamo.optimize("inductor")(fn)
|
|
assert same(fn(a), fn_optimized(a))
|
|
|
|
def test_indirect_indexing_dense_mask(self):
|
|
def fn(x, y):
|
|
ne = torch.ops.aten.ne.Scalar(x, 1)
|
|
sum_1 = torch.ops.aten.sum.dim_IntList(ne, [1])
|
|
sub = torch.ops.aten.sub.Tensor(sum_1, 1)
|
|
unsqueeze = torch.ops.aten.unsqueeze.default(sub, -1)
|
|
gather = torch.ops.aten.gather.default(x, 1, unsqueeze)
|
|
squeeze = torch.ops.aten.squeeze.default(gather)
|
|
out = torch.ops.aten.multiply(y, squeeze)
|
|
return (out,)
|
|
|
|
a = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
|
|
b = torch.zeros((1, 128), dtype=torch.int64, device="cuda")
|
|
|
|
fn_optimized = torch._dynamo.optimize("inductor")(fn)
|
|
assert same(fn(a, b), fn_optimized(a, b))
|
|
|
|
def test_simplify_dims(self):
|
|
def fn(a):
|
|
return (a + 1,)
|
|
|
|
self.common(fn, (torch.randn(2, 3, 10, 5, 6, device="cuda")[:, :, 2::2, :, :],))
|
|
|
|
@config.patch(permute_fusion=True)
|
|
def test_permute_fusion(self):
|
|
class Repro(torch.nn.Module):
|
|
def forward(self, view, reshape_2):
|
|
permute = view.permute(0, 2, 1)
|
|
view = None
|
|
reshape = torch.reshape(permute, (-1, 642))
|
|
bmm = torch.bmm(permute, reshape_2)
|
|
return (bmm,)
|
|
|
|
args = [
|
|
((1024, 642, 160), (102720, 160, 1), torch.float32, "cuda", True),
|
|
((1024, 642, 20), (12840, 20, 1), torch.float32, "cuda", True),
|
|
]
|
|
args = [
|
|
rand_strided(sh, st, dt, dev).requires_grad_(rg)
|
|
for (sh, st, dt, dev, rg) in args
|
|
]
|
|
|
|
mod = Repro()
|
|
opt_mod = torch._dynamo.optimize("inductor")(mod)
|
|
|
|
ref = mod(*args)
|
|
res = opt_mod(*args)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
@config.patch({"triton.autotune_pointwise": True})
|
|
def test_inplace_add_alpha_autotune(self):
|
|
def fn(x, y):
|
|
aten.add_.Tensor(x, y, alpha=0.55)
|
|
return (x,)
|
|
|
|
x1 = torch.zeros(2, 3, 4, 10, device="cuda")
|
|
x2 = torch.zeros(2, 3, 4, 10, device="cuda")
|
|
x3 = torch.zeros(2, 3, 4, 10, device="cuda")
|
|
y = torch.randn(2, 3, 4, 10, device="cuda").to(
|
|
memory_format=torch.channels_last
|
|
)
|
|
fn_fx = make_fx(fn)(x1, y)
|
|
fn_compiled = compile_fx_inner(fn_fx, [x1, y])
|
|
fn(x2, y)
|
|
fn_compiled([x3, y])
|
|
assert same(x2, x3)
|
|
|
|
@config.patch({"triton.autotune_pointwise": True})
|
|
def test_inplace_buffer_autotune(self):
|
|
def foo(x, y, z):
|
|
a = x @ y
|
|
return a.unsqueeze(0).unsqueeze(0) + z
|
|
|
|
x = torch.zeros(5, 5, device="cuda")
|
|
y = torch.zeros(5, 5, device="cuda")
|
|
z = torch.zeros(1, 1, 5, 5, device="cuda").to(memory_format=torch.channels_last)
|
|
self.common(
|
|
foo,
|
|
(x, y, z),
|
|
check_lowp=False,
|
|
)
|
|
|
|
def test_memory_history_inductor(self):
|
|
def called_inside_compile(x, w, b):
|
|
a = x @ w + b
|
|
return torch.sigmoid(a)
|
|
|
|
@torch.compile
|
|
def fn(x, w, b):
|
|
x = called_inside_compile(x, w, b)
|
|
return called_inside_compile(x, w, b)
|
|
|
|
w = torch.rand(3, 3, device="cuda")
|
|
b = torch.rand(3, device="cuda")
|
|
x = torch.rand(3, device="cuda")
|
|
try:
|
|
torch.cuda.memory.empty_cache()
|
|
torch.cuda.memory._record_memory_history(True)
|
|
r = fn(x, w, b)
|
|
finally:
|
|
torch.cuda.memory._record_memory_history(False)
|
|
snapshot = str(torch.cuda.memory._snapshot())
|
|
self.assertTrue("called_inside_compile" in snapshot)
|
|
|
|
def test_negative_arange_dynamic_shapes(self):
|
|
# Repro from alibi relative encodings
|
|
def sign(x):
|
|
return (x > 0) - (x < 0)
|
|
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
nheads = 16
|
|
start = math.log2(0.5)
|
|
end = math.log2(1 / (2**8))
|
|
|
|
self.scales = nn.Buffer(
|
|
2
|
|
** torch.arange(
|
|
start,
|
|
end + 1e-6 * sign(end - start),
|
|
(end - start) / (nheads - 1),
|
|
).view(1, nheads, 1, 1),
|
|
)
|
|
self.emb = nn.Embedding(1024, 256)
|
|
self.dec_layer = nn.TransformerDecoderLayer(
|
|
256, 16, 512, batch_first=True, norm_first=True
|
|
)
|
|
self.head = nn.Linear(256, 1024)
|
|
|
|
def forward(self, enc_out: torch.Tensor, dec_in: torch.Tensor):
|
|
padmask = dec_in == 0
|
|
dec_mask = padmask.unsqueeze(-1) == padmask.unsqueeze(-2)
|
|
dec_mask = dec_mask.to(dtype=torch.float32)
|
|
dec_mask = dec_mask.tril(diagonal=0).cuda()
|
|
|
|
q_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
|
|
k_pos = torch.arange(dec_in.size(1), dtype=torch.long, device="cuda")
|
|
rel_pos = k_pos[None, :] - q_pos[:, None]
|
|
values = rel_pos.abs().neg().unsqueeze(0).unsqueeze(0)
|
|
dec_bias = values * self.scales
|
|
dec_bias.tril_(diagonal=0)
|
|
|
|
dec_mask = dec_mask + dec_bias[0]
|
|
out = self.emb(dec_in)
|
|
out = self.dec_layer(out, enc_out, tgt_mask=dec_mask)
|
|
return self.head(out)
|
|
|
|
mod = Repro().cuda()
|
|
opt_mod = torch._dynamo.optimize("inductor", dynamic=True)(mod)
|
|
mod.eval()
|
|
opt_mod.eval()
|
|
|
|
enc_out = torch.rand(1, 512, 256).cuda()
|
|
dec_inputs = [
|
|
torch.randint(0, 512, (1, i + 1), dtype=torch.long).cuda() for i in range(8)
|
|
]
|
|
|
|
for dec_inp in dec_inputs:
|
|
assert same_two_models(
|
|
mod, opt_mod, [enc_out, dec_inp], only_fwd=True
|
|
), "Inductor with dynamic shapes failed"
|
|
|
|
def test_issue97695_1input(self):
|
|
def fn(arg3_1, relu, permute_1):
|
|
addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
|
|
cat_2 = torch.ops.aten.cat.default([addmm_1], 1)
|
|
return (cat_2,)
|
|
|
|
args = [
|
|
((96,), (1,), torch.float32, "cuda"),
|
|
((10, 256), (256, 1), torch.float32, "cuda"),
|
|
((256, 96), (1, 256), torch.float32, "cuda"),
|
|
]
|
|
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
|
|
correct = fn(*args)
|
|
|
|
mod = make_fx(fn, tracing_mode="real")(*args)
|
|
compiled = compile_fx_inner(mod, args)
|
|
ref = compiled(list(args))
|
|
assert same(ref, correct)
|
|
|
|
ref = torch.compile(fn, fullgraph=True)(*args)
|
|
assert same(ref, correct)
|
|
|
|
def test_issue_103924(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.temperature = 1
|
|
self.layer = torch.nn.Softmax(dim=1)
|
|
|
|
def forward(self, x):
|
|
n_samples, _ = x.shape
|
|
y = 1.0 * torch.ones(n_samples, dtype=x.dtype, device=x.device)
|
|
inp = x / y[..., None]
|
|
return self.layer(inp)
|
|
|
|
x = torch.rand([4, 4], device="cuda")
|
|
m = MyModule()
|
|
opt_m = torch.compile(backend="inductor")(m)
|
|
self.assertEqual(opt_m(x), m(x))
|
|
|
|
def test_issue97695_2input(self):
|
|
def fn(arg3_1, arg3_2, relu, permute_1):
|
|
addmm_1 = torch.ops.aten.addmm.default(arg3_1, relu, permute_1)
|
|
addmm_2 = torch.ops.aten.addmm.default(arg3_2, relu, permute_1)
|
|
cat_2 = torch.ops.aten.cat.default([addmm_1, addmm_2], 1)
|
|
return (cat_2,)
|
|
|
|
args = [
|
|
((96,), (1,), torch.float32, "cuda"),
|
|
((96,), (1,), torch.float32, "cuda"),
|
|
((10, 256), (256, 1), torch.float32, "cuda"),
|
|
((256, 96), (1, 256), torch.float32, "cuda"),
|
|
]
|
|
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
|
|
correct = fn(*args)
|
|
|
|
ref = torch.compile(fn, fullgraph=True)(*args)
|
|
assert same(ref, correct)
|
|
|
|
def test_scatter_index_not_wrapped(self):
|
|
src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], device=self.device)
|
|
index = torch.tensor([0, 1, 0, 1, 2, 0], device=self.device)
|
|
input = torch.tensor([1.0, 2.0, 3.0, 4.0], device=self.device)
|
|
compiled_sr = torch.compile(torch.scatter_reduce)
|
|
|
|
input_orig = input.clone()
|
|
out, code = run_and_get_code(compiled_sr, input, 0, index, src, "sum")
|
|
# tmp0 - not wrapping of negative numbers
|
|
FileCheck().check("tl.device_assert(((0 <= tmp0) & (tmp0 < 4))").check_next(
|
|
"atomic_add"
|
|
).run(code[0])
|
|
self.assertEqual(
|
|
out, torch.scatter_reduce(input_orig.clone(), 0, index, src, "sum")
|
|
)
|
|
|
|
def test_embedding_var_mean(self):
|
|
def forward(arg0_1):
|
|
full = torch.ops.aten.full.default(
|
|
[1, 2048],
|
|
1,
|
|
dtype=torch.float32,
|
|
layout=torch.strided,
|
|
device=torch.device(type="cuda", index=0),
|
|
pin_memory=False,
|
|
)
|
|
convert_element_type_1 = torch.ops.prims.convert_element_type.default(
|
|
full, torch.int64
|
|
)
|
|
cumsum = torch.ops.aten.cumsum.default(convert_element_type_1, 1)
|
|
mul = torch.ops.aten.mul.Tensor(cumsum, convert_element_type_1)
|
|
sub_1 = torch.ops.aten.sub.Tensor(mul, 1)
|
|
slice_5 = torch.ops.aten.slice.Tensor(sub_1, 0, 0, 9223372036854775807)
|
|
slice_6 = torch.ops.aten.slice.Tensor(slice_5, 1, 0, 9223372036854775807)
|
|
add_2 = torch.ops.aten.add.Tensor(slice_6, 2)
|
|
embedding_1 = torch.ops.aten.embedding.default(arg0_1, add_2)
|
|
var_mean = torch.ops.aten.var_mean.correction(
|
|
embedding_1, [2], correction=0, keepdim=True
|
|
)
|
|
return [var_mean[0], var_mean[1], add_2]
|
|
|
|
emb = torch.randn([2050, 768], device="cuda")
|
|
gm = make_fx(forward)(emb)
|
|
opt = torch._inductor.compile_fx.compile_fx_inner(gm, [emb])
|
|
opt([emb])
|
|
torch.cuda.synchronize()
|
|
|
|
def test_deterministic_algorithms(self):
|
|
N = 10000
|
|
|
|
@torch.compile
|
|
def fn(idx, values):
|
|
x = torch.zeros(1, device="cuda")
|
|
x[idx] += values
|
|
return x
|
|
|
|
idx = torch.zeros(N, dtype=torch.int64, device="cuda")
|
|
values = torch.randn(N, device="cuda")
|
|
|
|
r0 = fn(idx, values)
|
|
with DeterministicGuard(True):
|
|
r1 = fn(idx, values)
|
|
for _ in range(10):
|
|
rn = fn(idx, values)
|
|
self.assertEqual(r1, rn, atol=0, rtol=0)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/96406
|
|
def test_linear_cpu_input(self):
|
|
class Model(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = nn.Linear(4, 4)
|
|
|
|
def forward(self, data):
|
|
data = data.to("cuda")
|
|
return self.linear(data)
|
|
|
|
mod = Model().cuda().eval()
|
|
with torch.no_grad():
|
|
self.common(mod, (torch.randn(4, 4),))
|
|
|
|
@config.patch({"fallback_random": True, "triton.cudagraphs": True})
|
|
def test_xlnet_lm_stride_repro(self):
|
|
class Repro(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.dropout = nn.Dropout(p=0.1, inplace=False)
|
|
|
|
def forward(self, x):
|
|
y = torch._C._nn.gelu(x)
|
|
return self.dropout(y)
|
|
|
|
mod = Repro()
|
|
x = torch.randn((512, 1, 4096), requires_grad=True, device="cuda")
|
|
y = torch.compile(mod)(x)
|
|
# Inductor claims the output layout of gelu's saved variable for
|
|
# backwards will be (4096, 4096, 1) but in actuality it is (4096,
|
|
# 2097152, 1). Fortunately this doesn't actually matter in practice.
|
|
y.sum().backward()
|
|
|
|
def test_lookup_seed_backward(self):
|
|
@torch.compile(fullgraph=True)
|
|
def forward(inductor_seeds, mul_4, view_15):
|
|
inductor_lookup_seed_2 = torch.ops.prims.inductor_lookup_seed.default(
|
|
inductor_seeds, 2
|
|
)
|
|
inductor_random_2 = torch.ops.prims.inductor_random.default(
|
|
[2, 512, 768], inductor_lookup_seed_2, "rand"
|
|
)
|
|
gt_2 = torch.ops.aten.gt.Scalar(inductor_random_2, 0.1)
|
|
mul_7 = torch.ops.aten.mul.Tensor(gt_2, view_15)
|
|
mul_8 = torch.ops.aten.mul.Tensor(mul_7, 1.1111111111111112)
|
|
add_5 = torch.ops.aten.add.Tensor(mul_8, mul_4)
|
|
var_mean_1 = torch.ops.aten.var_mean.correction(
|
|
add_5, [2], correction=0, keepdim=True
|
|
)
|
|
getitem_3 = var_mean_1[1]
|
|
sub_3 = torch.ops.aten.sub.Tensor(add_5, getitem_3)
|
|
return (sub_3,)
|
|
|
|
buf0 = torch.zeros((37,), dtype=torch.int64, device="cuda")
|
|
buf1 = torch.zeros((2, 512, 768), device="cuda")
|
|
buf2 = torch.zeros((2, 512, 768), device="cuda")
|
|
forward(buf0, buf1, buf2)
|
|
|
|
def test_issue100806(self):
|
|
class Model(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear1 = torch.nn.Linear(10, 20)
|
|
self.linear2 = torch.nn.Linear(20, 30)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.linear1(x)
|
|
x = self.linear2(x)
|
|
x = torch.cat((x, x), dim=1)
|
|
x = x.view(-1, 2, 30)
|
|
x = x[:, 1, :]
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
device = "cuda"
|
|
batch_size = 2
|
|
x = torch.randn(batch_size, 10).to(device)
|
|
func = Model().to(device)
|
|
|
|
with torch.no_grad():
|
|
func.train(False)
|
|
jit_func = torch.compile(func)
|
|
|
|
res1 = func(x)
|
|
res2 = jit_func(x)
|
|
self.assertEqual(res1, res2)
|
|
|
|
def test_issue103481(self):
|
|
def fn(x, y):
|
|
# NOTE: 6 dimensions is important! does not fail for 5 dimensions
|
|
mean = torch.mean(x, [2, 3, 4, 5], keepdim=True)
|
|
add = mean + y
|
|
return add
|
|
|
|
x = torch.rand(4, 4, 4, 4, 4, 4, device="cuda")
|
|
y = torch.rand((), device="cuda")
|
|
expect = fn(x, y)
|
|
|
|
opt_fn = torch.compile(fn)
|
|
actual = opt_fn(x, y)
|
|
|
|
self.assertEqual(expect, actual)
|
|
|
|
@config.patch({"triton.dense_indexing": True})
|
|
@dynamo_config.patch(automatic_dynamic_shapes=True)
|
|
def test_bucketize_dynamic_dense(self):
|
|
"""
|
|
Make sure that ops.bucketize() can handle dense_indexing, which previously
|
|
caused issues due to incorrect handling of the size of offsets.
|
|
"""
|
|
|
|
def fn(values, offsets):
|
|
return torch.bucketize(values, offsets)
|
|
|
|
values = torch.rand((64, 64), device="cuda")
|
|
offsets = torch.tensor([0.05, 0.1, 0.5, 0.8, 0.85, 0.95], device="cuda")
|
|
|
|
expect = fn(values, offsets)
|
|
|
|
opt_fn = torch.compile(fn, dynamic=True)
|
|
actual = opt_fn(values, offsets)
|
|
|
|
self.assertEqual(expect, actual)
|
|
|
|
def test_float64_constants(self):
|
|
def fn():
|
|
# NOTE: tensors of all the same value are constant folded, so we
|
|
# need a tensor with two distinct values
|
|
a = torch.tensor([1 / 10, 2 / 10], dtype=torch.float64, device="cuda")
|
|
return a * 2e50
|
|
|
|
cfn = torch.compile(fn)
|
|
expect = fn()
|
|
actual = cfn()
|
|
self.assertEqual(expect, actual, atol=0, rtol=0)
|
|
|
|
def test_issue104759(self):
|
|
def fn(arg7_1, add_1, permute_2, select_scatter, slice_8):
|
|
slice_scatter_4 = torch.ops.aten.slice_scatter.default(
|
|
permute_2, select_scatter, 0, 1, 9223372036854775807
|
|
)
|
|
permute_3 = torch.ops.aten.permute.default(slice_scatter_4, [1, 3, 0, 2, 4])
|
|
view_6 = torch.ops.aten.view.default(permute_3, [1, 1000, 48])
|
|
view_7 = torch.ops.aten.view.default(view_6, [1000, 48])
|
|
view_8 = torch.ops.aten.view.default(view_7, [1, 1000, 48])
|
|
view_9 = torch.ops.aten.view.default(view_8, [1, 1000, 3, 4, 4])
|
|
permute_4 = torch.ops.aten.permute.default(view_9, [2, 0, 3, 1, 4])
|
|
slice_7 = torch.ops.aten.slice.Tensor(permute_4, 0, 1, 9223372036854775807)
|
|
slice_scatter_5 = torch.ops.aten.slice_scatter.default(
|
|
slice_8, slice_7, 4, 0, 9223372036854775807
|
|
)
|
|
slice_scatter_6 = torch.ops.aten.slice_scatter.default(
|
|
arg7_1, slice_scatter_5, 3, 0, 1000
|
|
)
|
|
mul_8 = torch.ops.aten.mul.Scalar(add_1, 0.7071067811865476)
|
|
slice_9 = torch.ops.aten.slice.Tensor(slice_scatter_6, 3, 0, 1000)
|
|
slice_10 = torch.ops.aten.slice.Tensor(slice_9, 4, 0, 9223372036854775807)
|
|
select_2 = torch.ops.aten.select.int(slice_10, 0, 0)
|
|
permute_5 = torch.ops.aten.permute.default(select_2, [0, 1, 3, 2])
|
|
mul_9 = torch.ops.aten.mul.Scalar(permute_5, 0.7071067811865476)
|
|
expand = torch.ops.aten.expand.default(mul_8, [1, 4, 1000, 4])
|
|
view_10 = torch.ops.aten.view.default(expand, [4, 1000, 4])
|
|
expand_1 = torch.ops.aten.expand.default(mul_9, [1, 4, 4, 1000])
|
|
view_11 = torch.ops.aten.view.default(expand_1, [4, 4, 1000])
|
|
bmm = torch.ops.aten.bmm.default(view_10, view_11)
|
|
return (bmm,)
|
|
|
|
args = []
|
|
args.append(torch.randn((2, 1, 4, 1200, 4), dtype=torch.float16, device="cuda"))
|
|
args.append(
|
|
rand_strided(
|
|
(1, 4, 1000, 4), (16000, 4, 16, 1), dtype=torch.float16, device="cuda"
|
|
)
|
|
)
|
|
args.append(
|
|
rand_strided(
|
|
(3, 1, 4, 1000, 4),
|
|
(16, 48000, 4, 48, 1),
|
|
dtype=torch.float16,
|
|
device="cuda",
|
|
)
|
|
)
|
|
args.append(
|
|
rand_strided(
|
|
(2, 1, 4, 1000, 4),
|
|
(16, 48000, 4, 48, 1),
|
|
dtype=torch.float16,
|
|
device="cuda",
|
|
)
|
|
)
|
|
args.append(
|
|
rand_strided(
|
|
(2, 1, 4, 1000, 4),
|
|
(19200, 19200, 4800, 4, 1),
|
|
dtype=torch.float16,
|
|
device="cuda",
|
|
)
|
|
)
|
|
|
|
correct = fn(*args)
|
|
mod = make_fx(fn, tracing_mode="real")(*args)
|
|
compiled = compile_fx_inner(mod, args)
|
|
ref = compiled(list(args))
|
|
assert same(ref, correct)
|
|
|
|
@config.patch({"triton.cudagraphs": True})
|
|
def test_index_put_inplace_cudagraph(self):
|
|
def fn(x, y, z):
|
|
x = torch.zeros_like(x)
|
|
return x.index_put_([y], z, True)
|
|
|
|
x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
|
|
y = torch.zeros((512,), device="cuda", dtype=torch.int64)
|
|
z = torch.ones((512, 512), device="cuda", dtype=torch.bool)
|
|
|
|
opt_fn = torch._dynamo.optimize("inductor")(fn)
|
|
|
|
ref = fn(x, y, z)
|
|
|
|
# run it twice to test cuda graph issue
|
|
res = opt_fn(x, y, z)
|
|
res = opt_fn(x, y, z)
|
|
|
|
self.assertEqual(ref, res)
|
|
|
|
@config.patch({"triton.cudagraphs": True})
|
|
@config.patch({"fx_graph_cache": True})
|
|
def test_index_put_cudagraph(self):
|
|
for _ in range(2):
|
|
|
|
def fn(x, y, z):
|
|
x = torch.zeros_like(x)
|
|
return x.index_put([y], z, True)
|
|
|
|
x = torch.zeros((512, 512), device="cuda", dtype=torch.bool)
|
|
y = torch.zeros((512,), device="cuda", dtype=torch.int64)
|
|
z = torch.ones((512, 512), device="cuda", dtype=torch.bool)
|
|
|
|
opt_fn = torch._dynamo.optimize("inductor")(fn)
|
|
|
|
ref = fn(x, y, z)
|
|
|
|
# run it twice to test cuda graph issue
|
|
res = opt_fn(x, y, z)
|
|
res = opt_fn(x, y, z)
|
|
|
|
self.assertEqual(ref, res)
|
|
torch._dynamo.reset()
|
|
gc.collect()
|
|
|
|
@unittest.skipIf(
|
|
not PLATFORM_SUPPORTS_FLASH_ATTENTION, "flash attention not supported"
|
|
)
|
|
def test_flash_attention_dynamic(self):
|
|
class Model(nn.Module):
|
|
def __init__(self, *args, **kwargs) -> None:
|
|
super().__init__(*args, **kwargs)
|
|
|
|
self.q = nn.Linear(1024, 1024)
|
|
self.k = nn.Linear(1024, 1024)
|
|
self.v = nn.Linear(1024, 1024)
|
|
|
|
def forward(self, x):
|
|
batch_size, seq_len, _ = x.size()
|
|
|
|
queries = self.q(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
|
|
keys = self.k(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
|
|
values = self.v(x).view(batch_size, seq_len, 8, 128).transpose(2, 1)
|
|
|
|
attn = F.scaled_dot_product_attention(
|
|
queries,
|
|
keys,
|
|
values,
|
|
)
|
|
|
|
return attn
|
|
|
|
cnts = torch._dynamo.testing.CompileCounterWithBackend("inductor")
|
|
|
|
model = Model().cuda().half()
|
|
model = torch.compile(model, backend=cnts, dynamic=True)
|
|
|
|
with torch.backends.cuda.sdp_kernel(
|
|
enable_flash=True,
|
|
enable_math=False,
|
|
enable_mem_efficient=False,
|
|
enable_cudnn=False,
|
|
):
|
|
input1 = torch.rand(5, 512, 1024, device="cuda", dtype=torch.float16)
|
|
input2 = torch.rand(5, 513, 1024, device="cuda", dtype=torch.float16)
|
|
input3 = torch.rand(5, 514, 1024, device="cuda", dtype=torch.float16)
|
|
|
|
out1 = model(input1)
|
|
out2 = model(input2)
|
|
out3 = model(input3)
|
|
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
|
|
@config.patch({"triton.cudagraphs": True})
|
|
def test_index_put_no_fallback_cudagraph(self):
|
|
def fn(x, y, z):
|
|
x = torch.zeros_like(x)
|
|
return x.index_put([y], z, True)
|
|
|
|
x = torch.zeros((512, 512), device="cuda", dtype=torch.int32)
|
|
y = torch.zeros((512,), device="cuda", dtype=torch.int64)
|
|
z = torch.ones((512, 512), device="cuda", dtype=torch.int32)
|
|
|
|
opt_fn = torch._dynamo.optimize("inductor")(fn)
|
|
|
|
ref = fn(x, y, z)
|
|
|
|
# run it twice to test cuda graph issue
|
|
res = opt_fn(x, y, z)
|
|
res = opt_fn(x, y, z)
|
|
|
|
self.assertEqual(ref, res)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/104937
|
|
def test_linear_with_zero_infeature_size(self):
|
|
m = nn.Linear(in_features=0, out_features=0, bias=True).to("cuda")
|
|
x = torch.rand(1, 1, 0, device="cuda")
|
|
expect = m(x)
|
|
opt_fn = torch.compile(m)
|
|
actual = opt_fn(x)
|
|
self.assertEqual(expect, actual)
|
|
|
|
@config.patch(fallback_random=True)
|
|
def test_multi_output_layout_fallback(self):
|
|
mod = nn.RReLU(lower=3.2350976, upper=8.4220314, inplace=True)
|
|
inp = torch.rand([4, 4]).cuda()
|
|
m = torch.compile(mod)
|
|
|
|
with freeze_rng_state():
|
|
o1 = m(inp.clone())
|
|
|
|
o2 = mod(inp.clone())
|
|
|
|
self.assertEqual(o1, o2)
|
|
|
|
def test_cat_int8_one_kernel(self):
|
|
@torch.compile()
|
|
def cat(inps):
|
|
return torch.cat(inps) + 1
|
|
|
|
for dtype in [torch.uint8, torch.int8]:
|
|
inps = [
|
|
torch.empty([256, 256], dtype=dtype, device="cuda") for _ in range(4)
|
|
]
|
|
|
|
out, code = run_and_get_code(cat, inps)
|
|
self.assertEqual(torch.cat(inps) + 1, out)
|
|
FileCheck().check_not("aten.cat.default(").check_count(
|
|
".run(", 1, exactly=True
|
|
).run(code[0])
|
|
|
|
@config.patch("triton.use_block_ptr", True)
|
|
def test_selecsls42b_misaligned_address(self):
|
|
# https://github.com/openai/triton/issues/2836
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def fn(arg207_1, arg208_1, convert_element_type_40, expand, full, mul_3):
|
|
div = torch.ops.aten.div.Scalar(expand, 16)
|
|
where = torch.ops.aten.where.self(arg207_1, full, div)
|
|
convert_element_type_43 = torch.ops.prims.convert_element_type.default(
|
|
where, torch.float32
|
|
)
|
|
sum_2 = torch.ops.aten.sum.dim_IntList(convert_element_type_43, [0, 2, 3])
|
|
sub = torch.ops.aten.sub.Tensor(convert_element_type_40, arg208_1)
|
|
mul = torch.ops.aten.mul.Tensor(convert_element_type_43, sub)
|
|
sum_3 = torch.ops.aten.sum.dim_IntList(mul, [0, 2, 3])
|
|
mul_1 = torch.ops.aten.mul.Tensor(sum_2, 0.0078125)
|
|
unsqueeze = torch.ops.aten.unsqueeze.default(mul_1, 0)
|
|
unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, 2)
|
|
unsqueeze_2 = torch.ops.aten.unsqueeze.default(unsqueeze_1, 3)
|
|
mul_2 = torch.ops.aten.mul.Tensor(sum_3, 0.0078125)
|
|
mul_4 = torch.ops.aten.mul.Tensor(mul_2, mul_3)
|
|
unsqueeze_3 = torch.ops.aten.unsqueeze.default(mul_4, 0)
|
|
unsqueeze_4 = torch.ops.aten.unsqueeze.default(unsqueeze_3, 2)
|
|
unsqueeze_5 = torch.ops.aten.unsqueeze.default(unsqueeze_4, 3)
|
|
mul_6 = torch.ops.aten.mul.Tensor(sub, unsqueeze_5)
|
|
sub_1 = torch.ops.aten.sub.Tensor(convert_element_type_43, mul_6)
|
|
sub_2 = torch.ops.aten.sub.Tensor(sub_1, unsqueeze_2)
|
|
return (sub_2,)
|
|
|
|
args = [
|
|
torch.randn((8, 1024, 4, 4), device="cuda") > 0, # torch.bool tensor
|
|
torch.randn((1, 1024, 1, 1), device="cuda"),
|
|
torch.randn((8, 1024, 4, 4), device="cuda"),
|
|
torch.randn((8, 1024, 1, 1), dtype=torch.float16, device="cuda").expand(
|
|
(8, 1024, 4, 4)
|
|
),
|
|
torch.randn((), device="cuda"),
|
|
torch.randn((1024,), device="cuda"),
|
|
]
|
|
fn(*args)
|
|
torch.cuda.synchronize() # shake out Triton Error [CUDA]: misaligned address
|
|
|
|
@skipIfRocm
|
|
def test_non_commutative_scan_op(self):
|
|
from torch._higher_order_ops.associative_scan import associative_scan
|
|
|
|
a = torch.randn(1024, 8192, dtype=torch.float64, device="cuda")
|
|
b = torch.randn(1024, 8192, dtype=torch.float64, device="cuda")
|
|
|
|
def baseline(v, u):
|
|
A = []
|
|
A.append(b[:, 0])
|
|
for i in range(1, v.shape[1]):
|
|
A.append(a[:, i] * A[i - 1] + b[:, i])
|
|
return torch.stack(A, dim=1)
|
|
|
|
def combine_fn(i, j):
|
|
ia, ib = i
|
|
ja, jb = j
|
|
return ia * ja, ib * ja + jb
|
|
|
|
@torch.compile
|
|
def compiled_scan(a, b):
|
|
return associative_scan(combine_fn, (a, b), dim=-1)[1]
|
|
|
|
out1 = baseline(a, b)
|
|
out2 = compiled_scan(a, b)
|
|
self.assertEqual(out1, out2)
|
|
|
|
def test_dynamic_persistent_reductions(self):
|
|
@torch.compile(dynamic=True)
|
|
def inner_reduce(x):
|
|
assert x.shape[1] <= 1024
|
|
return x.sum(1)
|
|
|
|
a = torch.randn(50, 600, device="cuda")
|
|
out, code = run_and_get_code(inner_reduce, a)
|
|
self.assertEqual(inner_reduce(a), out)
|
|
self.assertTrue("for roffset" not in code)
|
|
|
|
@torch.compile(dynamic=True)
|
|
def outer_reduce(x):
|
|
assert x.shape[0] <= 64
|
|
return x.sum(0)
|
|
|
|
out, code = run_and_get_code(outer_reduce, a)
|
|
self.assertEqual(outer_reduce(a), out)
|
|
self.assertTrue("for roffset" not in code)
|
|
|
|
def test_non_contiguous_unaligned_input_indices(self):
|
|
from torch._inductor.compile_fx import remove_unaligned_input_idxs
|
|
|
|
inputs = [torch.ones(2, 2, device="cuda"), torch.ones(2, 2, device="cuda")[1:]]
|
|
idxs = remove_unaligned_input_idxs(inputs, [1])
|
|
self.assertEqual(idxs, [])
|
|
|
|
inputs = [
|
|
torch.ones(2, 2, device="cuda"),
|
|
torch.ones(2, 2, device="cuda"),
|
|
torch.ones(2, 2, device="cuda")[1:],
|
|
]
|
|
idxs = remove_unaligned_input_idxs(inputs, [0, 2])
|
|
self.assertEqual(idxs, [0])
|
|
|
|
@config.patch("triton.cudagraphs", True)
|
|
def test_unused_cpu_input_cudagraphs(self):
|
|
def fn(x, y):
|
|
return x.sin().sin().sin().sin().cos() + 1
|
|
|
|
fx_graph = torch.fx.symbolic_trace(fn)
|
|
inp = [torch.randn(64, device="cuda"), torch.randn(64, device="cpu")]
|
|
compiled_fn, (graph,) = run_and_get_graph_lowering(
|
|
torch._inductor.compile, fx_graph, inp
|
|
)
|
|
self.assertEqual(graph.disable_cudagraphs_reason, None)
|
|
self.assertEqual(graph.device_types, {"cuda"})
|
|
self.assertEqual(compiled_fn(*inp), fn(*inp))
|
|
|
|
def test_epilogue_fusion_with_view(self):
|
|
class ToyModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = torch.nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
|
self.linear = torch.nn.Linear(262144, 100)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
x = self.conv(x)
|
|
x = x.view(x.size(0), -1)
|
|
return self.relu(self.linear(x))
|
|
|
|
m = ToyModel().to(device="cuda:0")
|
|
input_tensor = torch.randn(32, 3, 64, 64).to(device="cuda:0")
|
|
from torch._inductor.utils import fresh_inductor_cache
|
|
|
|
with fresh_inductor_cache():
|
|
cm = torch.compile(m, mode="max-autotune")
|
|
out = cm(input_tensor)
|
|
out2 = m(input_tensor)
|
|
self.assertEqual(out, out2, atol=1e-3, rtol=1e-3)
|
|
|
|
@config.patch("triton.cudagraphs", True)
|
|
def test_cpu_index(self):
|
|
@torch.compile(fullgraph=True)
|
|
def fn(x):
|
|
return x[torch.arange(32)]
|
|
|
|
result, (graph,) = run_and_get_graph_lowering(
|
|
fn, torch.randn(64, device="cuda")
|
|
)
|
|
self.assertEqual(graph.disable_cudagraphs_reason, None)
|
|
self.assertEqual(graph.device_types, {"cuda"})
|
|
|
|
inp = torch.randn(64, device="cuda", requires_grad=True)
|
|
result, (graph,) = run_and_get_graph_lowering(fn, inp)
|
|
self.assertEqual(graph.disable_cudagraphs_reason, None)
|
|
self.assertEqual(graph.device_types, {"cuda"})
|
|
|
|
result, (graph,) = run_and_get_graph_lowering(lambda: result.sum().backward())
|
|
self.assertEqual(graph.disable_cudagraphs_reason, None)
|
|
self.assertEqual(graph.device_types, {"cuda"})
|
|
|
|
def test_reflection_pad_loop_order(self):
|
|
def fn(x, y):
|
|
a = torch.nn.functional.pad(x, (5, 5, 5, 5), mode="reflect")
|
|
b = torch.nn.functional.pad(y, (5, 5, 5, 5), mode="reflect")
|
|
return a + b
|
|
|
|
cfn = torch.compile(fn)
|
|
a = torch.rand((10, 10, 10), device="cuda")
|
|
b = torch.rand((10, 10, 10), device="cuda")
|
|
expect = fn(a, b)
|
|
actual, code = run_and_get_code(cfn, a, b)
|
|
self.assertEqual(expect, actual)
|
|
|
|
# Expect the code iterates in contiguous order, and is not tiled
|
|
kernel_code = "\n".join(code[0].split("\n")[60:74])
|
|
self.assertExpectedInline(
|
|
kernel_code,
|
|
"""\
|
|
@triton.jit
|
|
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
|
|
xnumel = 4000
|
|
xoffset = tl.program_id(0) * XBLOCK
|
|
xindex = xoffset + tl.arange(0, XBLOCK)[:]
|
|
xmask = xindex < xnumel
|
|
x0 = xindex % 20
|
|
x1 = (xindex // 20) % 20
|
|
x2 = (xindex // 400)
|
|
x3 = xindex
|
|
tmp0 = tl.load(in_ptr0 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
|
|
tmp1 = tl.load(in_ptr1 + (99 + ((-1)*(tl_math.abs((-9) + (tl_math.abs((-5) + x0))))) + ((-10)*(tl_math.abs((-9) + (tl_math.abs((-5) + x1))))) + (100*x2)), xmask, eviction_policy='evict_last')
|
|
tmp2 = tmp0 + tmp1
|
|
tl.store(out_ptr0 + (x3), tmp2, xmask)""", # noqa: B950
|
|
)
|
|
|
|
@skipCUDAIf(not SM80OrLater, "uses bfloat16 which requires SM >= 80")
|
|
def test_int64_index_intermediate(self):
|
|
def foo(inp):
|
|
view_23 = torch.ops.aten.view.default(inp, [-1, 8192, 8192])
|
|
split_1 = torch.ops.aten.split.Tensor(view_23, 1024, 1)
|
|
view_23 = None
|
|
getitem_17 = split_1[0]
|
|
getitem_18 = split_1[1]
|
|
getitem_19 = split_1[2]
|
|
getitem_20 = split_1[3]
|
|
getitem_21 = split_1[4]
|
|
getitem_22 = split_1[5]
|
|
getitem_23 = split_1[6]
|
|
getitem_24 = split_1[7]
|
|
split_1 = None
|
|
cat_1 = torch.ops.aten.cat.default(
|
|
[
|
|
getitem_17,
|
|
getitem_18,
|
|
getitem_19,
|
|
getitem_20,
|
|
getitem_21,
|
|
getitem_22,
|
|
getitem_23,
|
|
getitem_24,
|
|
]
|
|
)
|
|
getitem_17 = (
|
|
getitem_18
|
|
) = (
|
|
getitem_19
|
|
) = getitem_20 = getitem_21 = getitem_22 = getitem_23 = getitem_24 = None
|
|
return cat_1
|
|
|
|
for mark_dynamic in [False, True]:
|
|
inp = torch.rand((65536, 8192), dtype=torch.bfloat16, device="cuda")
|
|
if mark_dynamic:
|
|
torch._dynamo.mark_dynamic(inp, 0)
|
|
foo_c = torch.compile(foo)
|
|
torch.testing.assert_allclose(foo(inp), foo_c(inp))
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._inductor.test_case import run_tests
|
|
from torch.testing._internal.inductor_utils import HAS_CUDA
|
|
|
|
if HAS_CUDA and not TEST_WITH_ASAN:
|
|
run_tests(needs="filelock")
|