pytorch/test/dynamo/test_ctx_manager.py
drisspg ad90ab31f2 Flash Attention v2 (#105602)
# Summary
## PR Dependencies
I don't use ghstack :( and this is a PR where it would have been helpful. That being said, I am going to peel off some PRs to make reviewing this easier:
- [x] Separate build flags for Flash and MemEff: #107985

### Description
This pull request updates `_scaled_dot_product_flash_attention` from version 1 to version 2. The changes are based on the flash attention code originally authored by @tridao.

### Changes Made
The majority of the changes in this pull request involve:

- Copying over the flash_attention sources.
- Updating header files.
- Removing padding and slicing code from within the flash_attention kernel and relocating it to the composite implicit region of the SDPA. This was needed to make the kernel functional and to appease autograd (see the padding sketch after this list).
- Introducing a simple kernel generator to generate different instantiations of the forward and backward flash templates.
- Adding conditional compilation (ifdef) to prevent building when nvcc is invoked with gencode < sm80.
- Introducing a separate dependent option for mem_eff_attention, as flash_attention v2 lacks support for Windows and cannot be built for sm50 generation codes.
- Modifying build.sh to reduce parallelization on sm86 runners and to lower the maximum parallelization on the manywheel builds. This adjustment was made to address out-of-memory issues during the compilation of FlashAttentionV2 sources.
- Adding/Updating tests.
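
For the padding/slicing item above, here is a minimal sketch of the idea (not the actual ATen implementation; the SDPA call below stands in for the real flash kernel, and the `scale` kwarg requires a recent PyTorch): the composite-implicit wrapper pads `head_dim` up to the alignment the kernel expects and slices the padding back off afterwards, so the kernel itself can stay purely functional.

```python
import torch
import torch.nn.functional as F

def sdpa_with_headdim_padding(query, key, value, multiple=8):
    # Hedged sketch: handle head_dim alignment at the composite-implicit
    # level instead of inside the flash kernel.
    head_dim = query.size(-1)
    pad = (-head_dim) % multiple      # amount needed to reach the next multiple
    scale = head_dim ** -0.5          # keep the softmax scale of the *original* head_dim
    if pad:
        # F.pad pads the last dimension when given a (left, right) pair.
        query = F.pad(query, (0, pad))
        key = F.pad(key, (0, pad))
        value = F.pad(value, (0, pad))
    # Stand-in for the flash kernel; the real op is _scaled_dot_product_flash_attention.
    out = F.scaled_dot_product_attention(query, key, value, scale=scale)
    if pad:
        # Slice the padding back off outside the kernel.
        out = out[..., :head_dim]
    return out
```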

### Notes for Reviewers
This is not a fun review, and I apologize in advance.
Most of the changed files are in the flash_attn/ folder. The only files of interest here, IMO, are:
- aten/src/ATen/native/transformers/cuda/flash_attn/flash_api.cpp
- aten/src/ATen/native/transformers/cuda/flash_attn/kernels/generate_kernels.py (this has since been incorporated upstream into the flash-attention GitHub repo; see the sketch below)
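
As a rough illustration of what a generator like generate_kernels.py does (the template text, dtypes, and head dims below are placeholders, not the real ones), it stamps out one small .cu source per (direction, dtype, head_dim) combination so the build can compile them in parallel:

```python
from itertools import product
from pathlib import Path

# Placeholder values for illustration only; the real lists live in generate_kernels.py.
DIRECTIONS = ["fwd", "bwd"]
DTYPES = ["fp16", "bf16"]
HEAD_DIMS = [32, 64, 96, 128, 160, 192, 224, 256]

TEMPLATE = """\
// Auto-generated file; do not edit by hand.
// Placeholder instantiation for direction={direction}, dtype={dtype}, head_dim={head_dim}.
"""

def write_kernels(out_dir: str) -> None:
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    for direction, dtype, head_dim in product(DIRECTIONS, DTYPES, HEAD_DIMS):
        name = f"flash_{direction}_hdim{head_dim}_{dtype}_sm80.cu"
        (out / name).write_text(
            TEMPLATE.format(direction=direction, dtype=dtype, head_dim=head_dim)
        )
```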

There are a number of other changed files, all related to avoiding OOMs in CI/CD; these are typically shell scripts.

### Follow up items
- Include the updates from e07aa036db and 9e5e8bc91e | https://github.com/pytorch/pytorch/issues/108108

### Work Items
- [x] I don't think Windows will be supported for 3.1.0 - need to update cmake
- [x] Let multi_query/attention pass through and test | UPDATE: I have the fast path implemented here: https://github.com/pytorch/pytorch/pull/106730, but since this will require changing the semantics of the math fallback to call repeat_interleave, I think this should be done as a follow-up (see the sketch after this list).
- [x] Had to drop cutlass back to 3.0.0 to get it to compile. Need to figure out how to upgrade to 3.1.0 and later. Spoke with Tri and he is going to be taking a look. Note: compiling with clang currently errors for the cute headers.
- [x] Update tests to exercise the above codepath
- [x] Still need to disable on seq_len % 128 != 0 for backward (Tri beat me to it: a4f148b6ab)
- [x] Add determinism warning to BWD, Tri got to this one as well: 1c41d2b
- [x] Update dispatcher to universally prefer FlashV2
- [x] Update tests to exercise new head_dims
- [x] Move the head_dim padding from the kernel to the top-level composite implicit function in order to make it purely functional
- [x] Create template generator script
- [x] Initial cmake support for building kernels/ folder
- [x] Replay CudaGraph changes
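
For the multi_query/attention item above, here is a minimal sketch of the semantics in question (assumed shapes; this is not the code in #106730): the math fallback would need to expand the grouped KV heads with repeat_interleave so they match the query heads before calling SDPA.

```python
import torch
import torch.nn.functional as F

def sdpa_math_gqa(query, key, value):
    # query: (B, Hq, Sq, D); key/value: (B, Hkv, Skv, D) with Hq a multiple of Hkv.
    n_rep = query.size(1) // key.size(1)
    if n_rep > 1:
        # Repeat each KV head so the math path sees matching head counts.
        key = key.repeat_interleave(n_rep, dim=1)
        value = value.repeat_interleave(n_rep, dim=1)
    return F.scaled_dot_product_attention(query, key, value)
```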

### Results
#### Forward only
The TFLOPs reported here are from an A100 that is underclocked.
![flashv2_tflops_vs_seq_len](https://github.com/pytorch/pytorch/assets/32754868/152de46d-8fa6-42f0-9a9c-ef1eb7ae29e7)

#### Forward+Backward
Ran a sweep; for large compute-bound sizes we see a ~2x performance increase for forward+backward.
<img width="1684" alt="Screenshot 2023-07-20 at 3 47 47 PM" src="https://github.com/pytorch/pytorch/assets/32754868/fdd26e07-0077-4878-a417-f3a418b6fb3b">

Pull Request resolved: https://github.com/pytorch/pytorch/pull/105602
Approved by: https://github.com/huydhn, https://github.com/cpuhrsch
2023-09-13 13:59:05 +00:00


# Owner(s): ["module: dynamo"]
import unittest
import torch
import torch._dynamo.test_case
import torch._dynamo.testing
import torch.onnx.operators
from torch._dynamo.testing import EagerAndRecordGraphs, normalize_gm, same
from torch.nn import functional as F
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
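# Minimal user-defined context manager: sets grad mode to `mode` on __enter__
# and restores the grad state captured at construction on __exit__. Used below
# to exercise how Dynamo handles generic (non-torch) context managers.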
class CutomizedCtxManager:
def __init__(self, mode):
self.prev = torch.is_grad_enabled()
self.mode = mode
def __enter__(self):
torch._C._set_grad_enabled(self.mode)
def __exit__(self, exc_type, exc_value, traceback):
torch._C._set_grad_enabled(self.prev)
class CtxManagerTests(torch._dynamo.test_case.TestCase):
def test_no_grad(self):
def fn1(a, b):
x = a + 1
# redundant no_grad should get ignored
with torch.no_grad():
x = x + b
x = x + 2
return x
def fn2(a, b):
x = a + 1
with torch.set_grad_enabled(False):
x = x + b
x = x + 2
return x
def fn3(a, b):
x = a + 1
with torch.enable_grad():
x = x + b
x = x + 2
return x
def fn4(a, b):
x = a + 1
with torch.set_grad_enabled(True):
if torch.is_grad_enabled():
x = x + b
x = x + 2
return x
with torch.no_grad():
torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
with torch.enable_grad():
torch._dynamo.testing.standard_test(self, fn=fn1, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn2, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn3, nargs=2, expected_ops=5)
torch._dynamo.testing.standard_test(self, fn=fn4, nargs=2, expected_ops=5)
def test_grad_mode_guard(self):
def fn(a, b):
prev_grad = torch.is_grad_enabled()
torch.set_grad_enabled(False)
a = a + 1
a.tolist() # graph break
ret = a + b
torch.set_grad_enabled(prev_grad)
return ret
a = torch.randn([3, 4])
b = torch.randn([3, 4])
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
for _ in range(10):
opt_fn(a, b)
self.assertEqual(cnts.frame_count, 2)
def test_nested_grad_mode_graph_break(self):
def fn(x):
before = torch.is_grad_enabled()
with torch.set_grad_enabled(False):
torch._dynamo.graph_break()
with torch.set_grad_enabled(True):
x = torch.mul(x, 5)
torch._dynamo.graph_break()
x = torch.sqrt(x)
assert torch.is_grad_enabled()
assert not torch.is_grad_enabled()
assert torch.is_grad_enabled() == before
return x
a = torch.randn([3, 4])
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
for _ in range(10):
opt_fn(a)
self.assertEqual(cnts.frame_count, 2)
def test_torch_profiler(self):
# wrap torch.profiler.* as NullContextVariable and do nothing
def fn(x):
y = x**2
with torch.profiler.profile():
y = y + 2
with torch.profiler.record_function("my_function"):
z = y**3
z.tolist() # graph break
z = z + 1
return z
x = torch.randn((2, 2), requires_grad=True)
ref = fn(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 2)
def test_autograd_profiler(self):
# wrap torch.autograd.profiler.* as NullContextVariable and do nothing
def fn(x):
y = x**2
with torch.autograd.profiler.profile():
y = y + 2
with torch.autograd.profiler.record_function("my_function"):
z = y**3
z.tolist() # graph break
z = z + 1
return z
x = torch.randn((2, 2), requires_grad=True)
ref = fn(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 2)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_stream_context_manager1(self):
def fn(x):
s = torch.cuda.Stream()
x = torch.mul(x, 5)
x = torch.add(x, 2)
with torch.cuda.stream(s):
x = torch.relu(x)
x = torch.add(x, 1)
x = torch.cos(x)
return x
x = torch.randn((2, 2))
ref = fn(x)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 1)
self.assertEqual(cnts.op_count, 9)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_stream_context_manager2(self):
def fn(x, s):
x = torch.mul(x, 5)
x = torch.add(x, 2)
with torch.cuda.stream(s):
x = torch.relu(x)
x = torch.add(x, 1)
x = torch.cos(x)
return x
x = torch.randn((2, 2))
s = torch.cuda.Stream()
ref = fn(x, s)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported,
"CUDAStreamVariable does not currently work soundly.",
):
res = opt_fn(x, s)
def test_autograd_profiler_enabled(self):
def fn(x):
if torch.autograd._profiler_enabled():
return x + 1
else:
return x - 1
x = torch.randn((2, 2), requires_grad=True)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts)(fn)
if torch.autograd._profiler_enabled():
torch.autograd._disable_profiler()
assert not torch.autograd._profiler_enabled()
ref = fn(x)
res = opt_fn(x)
self.assertTrue(same(ref, res))
with torch.autograd.profiler.profile():
assert torch.autograd._profiler_enabled()
ref = fn(x)
res = opt_fn(x)
self.assertTrue(same(ref, res))
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_autocast(self):
if not torch.cuda.is_bf16_supported():
raise unittest.SkipTest("requires bf16")
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
e_float16 = torch.mm(a_float32, b_float32)
f_float16 = torch.mm(d_float32, e_float16)
return f_float16
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.type, "cuda")
self.assertEqual(exported.device.index, 0)
self.assertEqual(exported.dtype, torch.bfloat16)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_cuda_amp_autocast(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
with torch.cuda.amp.autocast(dtype=torch.torch.float64):
c_float64 = torch.mm(a_float32, b_float32)
return c_float64
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, _ = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.type, "cuda")
self.assertEqual(exported.device.index, 0)
self.assertEqual(exported.dtype, torch.float64)
def test_is_autocast_cpu_enabled(self):
def fn(a_float32, b_float32):
with torch.cpu.amp.autocast(dtype=torch.bfloat16):
c_float16 = torch.mm(a_float32, b_float32)
if torch.is_autocast_cpu_enabled():
c_float16 = c_float16 + 1
return c_float16
a = torch.rand((8, 8))
b = torch.rand((8, 8))
ref = fn(a, b)
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
res = opt_fn(a, b)
self.assertTrue(same(ref, res))
@unittest.skipIf(
not PLATFORM_SUPPORTS_FLASH_ATTENTION,
"Can't run fused SDPA on this platform",
)
def test_autocast_sdpa(self):
class MyModule(torch.nn.Module):
def forward(self, query, key, value):
with torch.autocast("cpu"):
with torch.autocast("cuda", dtype=torch.float32):
out = F.scaled_dot_product_attention(
query, key, value, None, 0.0, True
)
return out
dtype = torch.float32
seq_len_q = 1
seq_len_k = 1
head_dim = 8
query = torch.ones(
1, 8, seq_len_q, head_dim, device="cuda", dtype=dtype, requires_grad=True
)
key = torch.ones(
1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
)
value = torch.ones(
1, 8, seq_len_k, head_dim, device="cuda", dtype=dtype, requires_grad=True
)
module = MyModule()
real = module(query, key, value)
real_device = real.device
real_dtype = real.dtype
opt_mod = torch._dynamo.optimize("inductor")(module)
compiled = opt_mod(query, key, value)
self.assertEqual(compiled.device, real_device)
self.assertEqual(compiled.dtype, real_dtype)
self.assertEqual(compiled.device.type, "cuda")
self.assertEqual(compiled.device.index, 0)
self.assertEqual(compiled.dtype, torch.float32)
def test_autocast_cpu(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cpu")
b_float32 = torch.rand((8, 8), device="cpu")
d_float32 = torch.rand((8, 8), device="cpu")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
e_float16 = torch.mm(a_float32, b_float32)
f_float16 = torch.mm(d_float32, e_float16)
return f_float16
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.type, "cpu")
self.assertEqual(exported.dtype, torch.bfloat16)
def test_autocast_cpu_graph_break(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cpu")
b_float32 = torch.rand((8, 8), device="cpu")
torch._dynamo.graph_break()
d_float32 = torch.rand((8, 8), device="cpu")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
e_float16 = torch.mm(a_float32, b_float32)
torch._dynamo.graph_break()
f_float16 = torch.mm(d_float32, e_float16)
return f_float16
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
opt = torch._dynamo.optimize("eager")(module)
res = opt(torch.tensor([0.5]))
self.assertEqual(res.device, real_device)
self.assertEqual(res.dtype, real_dtype)
self.assertEqual(res.device.type, "cpu")
self.assertEqual(res.dtype, torch.bfloat16)
def test_autocast_cpu_graph_break_2(self):
# Regression for: https://github.com/pytorch/pytorch/issues/93890
def fn(x):
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
x = torch.mm(x, x)
torch._dynamo.graph_break()
x = torch.relu(x)
return x
x = torch.rand([4, 4])
self.assertEqual(x.dtype, torch.float32)
res = fn(x)
opt_fn = torch._dynamo.optimize("eager")(fn)
opt_res = opt_fn(x)
self.assertTrue(torch.allclose(res, opt_res))
self.assertEqual(res.dtype, torch.bfloat16)
self.assertEqual(opt_res.dtype, torch.bfloat16)
def test_autocast_cpu_graph_break_inner_fn(self):
class MyModule(torch.nn.Module):
@staticmethod
def mm_breaks(x, y):
torch._dynamo.graph_break()
return torch.mm(x, y)
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cpu")
b_float32 = torch.rand((8, 8), device="cpu")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
torch._dynamo.graph_break()
with torch.autocast(
device_type="cpu", dtype=torch.bfloat16, enabled=False
):
torch._dynamo.graph_break()
g_float32 = torch.mm(a_float32, b_float32)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
# Check nesting with a non-inlineable function that graph breaks
torch._dynamo.graph_break()
f_float16_1 = self.mm_breaks(a_float32, b_float32)
# Check that we exit the inner autocast back to the outer one correctly,
# even after graph breaks
f_float16 = self.mm_breaks(a_float32, b_float32)
assert f_float16.dtype == f_float16_1.dtype
return f_float16, g_float32
module = MyModule()
real_16, real_32 = module(torch.tensor([0.5]))
real_device_16 = real_16.device
real_dtype_16 = real_16.dtype
real_device_32 = real_32.device
real_dtype_32 = real_32.dtype
graph = torch._dynamo.optimize("eager")(module)
out_16, out_32 = graph(torch.tensor([0.5]))
self.assertEqual(out_16.device, real_device_16)
self.assertEqual(out_16.dtype, real_dtype_16)
self.assertEqual(out_32.device, real_device_32)
self.assertEqual(out_32.dtype, real_dtype_32)
self.assertEqual(out_16.device.type, "cpu")
self.assertEqual(out_16.dtype, torch.bfloat16)
self.assertEqual(out_32.device.type, "cpu")
self.assertEqual(out_32.dtype, torch.float32)
def test_autocast_graph_break_method(self):
class MyModule(torch.nn.Module):
def __init__(self, bias):
super().__init__()
self.bias = bias
def mm_not_break(self, x, y):
return torch.mm(x, y) + self.bias
def mm_breaks(self, x, y):
torch._dynamo.graph_break()
return torch.mm(x, y) + self.bias
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cpu")
b_float32 = torch.rand((8, 8), device="cpu")
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
with torch.autocast(
device_type="cpu", dtype=torch.bfloat16, enabled=False
):
g_float32 = torch.mm(a_float32, b_float32)
f_float16 = self.mm_breaks(a_float32, b_float32)
assert (
f_float16[0][0] == self.mm_not_break(a_float32, b_float32)[0][0]
)
return f_float16, g_float32
module = MyModule(bias=torch.rand((8, 8), device="cpu", dtype=torch.bfloat16))
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
# Autocast doesn't work on addition, so we need the bias to be `bfloat16`
res = torch.rand((8, 8), device="cpu", dtype=torch.float32) + torch.rand(
(8, 8), device="cpu", dtype=torch.bfloat16
)
self.assertEqual(res.dtype, torch.float32)
real_16, real_32 = module(torch.tensor([0.5]))
real_device_16 = real_16.device
real_dtype_16 = real_16.dtype
real_device_32 = real_32.device
real_dtype_32 = real_32.dtype
graph = torch._dynamo.optimize("eager")(module)
out_16, out_32 = graph(torch.tensor([0.5]))
self.assertEqual(out_16.device, real_device_16)
self.assertEqual(out_16.dtype, real_dtype_16)
self.assertEqual(out_32.device, real_device_32)
self.assertEqual(out_32.dtype, real_dtype_32)
self.assertEqual(out_16.device.type, "cpu")
self.assertEqual(out_16.dtype, torch.bfloat16)
self.assertEqual(out_32.device.type, "cpu")
self.assertEqual(out_32.dtype, torch.float32)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_autocast_float64(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.float64):
e_float64 = torch.mm(a_float32, b_float32)
f_float64 = torch.mm(d_float32, e_float64)
return f_float64
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.index, 0)
self.assertEqual(exported.dtype, torch.float64)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_autocast_device(self):
class MyModule(torch.nn.Module):
def forward(self, x):
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with torch.autocast("cuda"):
e_float64 = torch.mm(a_float32, b_float32)
f_float64 = torch.mm(d_float32, e_float64)
return f_float64
module = MyModule()
real = module(torch.tensor([0.5]))
real_device = real.device
real_dtype = real.dtype
graph, guards = torch._dynamo.export(module)(torch.tensor([[0.0, 0], [0, 0]]))
exported = graph(torch.tensor([0.5]))
self.assertEqual(exported.device, real_device)
self.assertEqual(exported.dtype, real_dtype)
self.assertEqual(exported.device.index, 0)
self.assertEqual(exported.dtype, torch.torch.float16)
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
def test_autocast_arguments_binding(self):
def f1(x):
with torch.cuda.amp.autocast(False):
x = torch.sin(x + 1)
return x
def f2(x):
with torch.cpu.amp.autocast(False):
x = torch.cos(x + 1)
return x
x = torch.rand([2, 3])
ref1 = f1(x)
ref2 = f2(x)
opt_f1 = torch.compile(backend="eager")(f1)
opt_f2 = torch.compile(backend="eager")(f2)
res1 = opt_f1(x)
res2 = opt_f2(x)
self.assertTrue(same(ref1, res1))
self.assertTrue(same(ref2, res2))
def test_generic_context_manager(self):
def fn(x):
with CutomizedCtxManager(True):
x = x + 1
if torch.is_grad_enabled():
x = x * 2
x = torch.relu(x)
return x - 1
with torch.no_grad():
torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6)
with torch.enable_grad():
torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=6)
def test_nested_generic_context_manager(self):
def fn(x):
with CutomizedCtxManager(True):
x = x + 1
if torch.is_grad_enabled():
x = x * 2
with CutomizedCtxManager(False):
if torch.is_grad_enabled():
x = x - 3
x = x * 1.5
x = torch.relu(x)
return x - 1
with torch.no_grad():
torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9)
with torch.enable_grad():
torch._dynamo.testing.standard_test(self, fn=fn, nargs=1, expected_ops=9)
def test_generic_context_manager_with_graph_break(self):
def fn(x):
with CutomizedCtxManager(True):
x = x + 1
if torch.is_grad_enabled():
x = x * 2
torch._dynamo.graph_break()
x = torch.relu(x)
return x - 1
x = torch.rand(2, 3)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
with torch.no_grad():
ref = fn(x)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 2)
self.assertEqual(cnts.op_count, 2)
with torch.enable_grad():
ref = fn(x)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 4)
self.assertEqual(cnts.op_count, 4)
def test_nested_generic_context_manager_with_graph_break(self):
def fn(x):
with CutomizedCtxManager(True):
x = x + 1
if torch.is_grad_enabled():
x = x * 2
with CutomizedCtxManager(False):
if torch.is_grad_enabled():
x = x - 3
torch._dynamo.graph_break()
x = x * 1.5
x = torch.relu(x)
return x - 1
x = torch.rand(2, 3)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(backend=cnts, fullgraph=False)(fn)
with torch.no_grad():
ref = fn(x)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 4)
self.assertEqual(cnts.op_count, 4)
torch._dynamo.reset()
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
with torch.enable_grad():
ref = fn(x)
res = opt_fn(x)
self.assertTrue(same(ref, res))
self.assertEqual(cnts.frame_count, 4)
self.assertEqual(cnts.op_count, 4)
def test_graph_break_inlining_grad(self):
def gn(z):
with torch.no_grad():
torch._dynamo.graph_break()
return torch.sin(z)
def fn(x, y, z):
a = torch.mm(x, y)
z = gn(z)
return a
torch._dynamo.reset()
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch._dynamo.optimize(cnts, nopython=False)(fn)
x = torch.randn(4, 4, requires_grad=True)
y = torch.randn(4, 4, requires_grad=True)
z = torch.randn(4)
opt_fn(x, y, z).sum().backward()
self.assertEqual(cnts.frame_count, 2)
def _graph_break_inlining_autocast_test_helper(self, device):
def gn(x, y):
with torch.autocast(device_type=device, dtype=torch.bfloat16):
z = torch.mm(x, y)
torch._dynamo.graph_break()
return torch.sin(z)
def fn(x, y):
z = torch.mm(x, y)
z = z + gn(x, y)
return z
x = torch.rand(3, 3).to(device)
y = torch.rand(3, 3).to(device)
opt_fn = torch.compile(backend="eager")(fn)
ref = fn(x, y)
res = opt_fn(x, y)
self.assertEqual(ref, res)
def test_graph_break_inlining_autocast(self):
for device in ["cuda", "cpu"]:
if device == "cuda" and not (
torch.cuda.is_available() and torch.cuda.is_bf16_supported()
):
continue
self._graph_break_inlining_autocast_test_helper(device)
def test_disable_saved_tensors_hooks(self):
def fn(z):
@torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
def f(x, y):
return x + y
x, y = torch.ones(
1,
), torch.zeros(
1,
)
return f(x, y)
eager = EagerAndRecordGraphs()
torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))
graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))
expected = """\
class GraphModule(torch.nn.Module):
def forward(self):
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
x = torch.ones(1)
y = torch.zeros(1)
add = x + y; x = y = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (add,)
"""
self.assertExpectedInline(actual, expected)
def test_disable_saved_tensors_hooks_prev_disabled(self):
def fn(z):
@torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
def f(x, y):
return x + y
x, y = torch.ones(
1,
), torch.zeros(
1,
)
return f(x, y)
eager = EagerAndRecordGraphs()
with torch.autograd.graph.disable_saved_tensors_hooks(
"Previously disabled message"
):
torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))
graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))
expected = """\
class GraphModule(torch.nn.Module):
def forward(self):
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
x = torch.ones(1)
y = torch.zeros(1)
add = x + y; x = y = None
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message')
return (add,)
"""
self.assertExpectedInline(actual, expected)
def test_disable_saved_tensors_hooks_prev_disabled_nested(self):
def fn(z):
@torch.autograd.graph.disable_saved_tensors_hooks("This is not supported")
def f(x, y):
@torch.autograd.graph.disable_saved_tensors_hooks(
"This is not supported inner"
)
def inner_fn(x, y):
return x + y
return inner_fn(x, y) + x
x, y = torch.ones(
1,
), torch.zeros(
1,
)
return f(x, y)
eager = EagerAndRecordGraphs()
with torch.autograd.graph.disable_saved_tensors_hooks(
"Previously disabled message"
):
torch.compile(fn, backend=eager, fullgraph=True)(torch.randn(()))
graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))
expected = """\
class GraphModule(torch.nn.Module):
def forward(self):
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
x = torch.ones(1)
y = torch.zeros(1)
_saved_tensors_hooks_disable_1 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported inner')
add = x + y; y = None
_saved_tensors_hooks_disable_2 = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
add_1 = add + x; add = x = None
_saved_tensors_hooks_disable_3 = torch._C._autograd._saved_tensors_hooks_disable('Previously disabled message')
return (add_1,)
"""
self.assertExpectedInline(actual, expected)
def test_disable_saved_tensors_hooks_graph_break(self):
def fn(x):
with torch.autograd.graph.disable_saved_tensors_hooks(
"This is not supported"
):
y = x + 1
torch._dynamo.graph_break()
return y * 2
eager = EagerAndRecordGraphs()
torch.compile(fn, backend=eager, fullgraph=False)(torch.randn(()))
def check_graph(actual, expected):
self.assertExpectedInline(actual, expected)
expected = """\
class GraphModule(torch.nn.Module):
def forward(self, L_x_ : torch.Tensor):
l_x_ = L_x_
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
y = l_x_ + 1; l_x_ = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (y,)
"""
graph = eager.graphs[0]
actual = normalize_gm(graph.print_readable(False))
check_graph(actual, expected)
expected = """\
class GraphModule(torch.nn.Module):
def forward(self, L_y_ : torch.Tensor):
l_y_ = L_y_
_saved_tensors_hooks_disable = torch._C._autograd._saved_tensors_hooks_disable('This is not supported')
mul = l_y_ * 2; l_y_ = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable()
return (mul,)
"""
graph = eager.graphs[1]
actual = normalize_gm(graph.print_readable(False))
check_graph(actual, expected)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests
run_tests()