mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enables B027 and applies fixes by adding abstract method decorators. Autofix generated by ruff master. Pull Request resolved: https://github.com/pytorch/pytorch/pull/100715 Approved by: https://github.com/ezyang
5238 lines
163 KiB
Python
5238 lines
163 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
import abc
|
|
import collections
|
|
import copy
|
|
import dataclasses
|
|
import dis
|
|
import enum
|
|
import logging
|
|
import math
|
|
import operator
|
|
import os
|
|
import sys
|
|
import typing
|
|
import unittest
|
|
import unittest.mock as mock
|
|
import weakref
|
|
from unittest.mock import patch
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
import torch.onnx.operators
|
|
from torch._C import FileCheck
|
|
from torch._dynamo import bytecode_analysis, bytecode_transformation
|
|
from torch._dynamo.output_graph import OutputGraph
|
|
from torch._dynamo.source import GetItemSource, LocalSource
|
|
from torch._dynamo.testing import (
|
|
CompileCounter,
|
|
requires_static_shapes,
|
|
same,
|
|
skipIfNotPy311,
|
|
unsupported,
|
|
)
|
|
|
|
from torch._dynamo.utils import CompileProfiler, ifdyn, ifdynstaticdefault, ifunspec
|
|
from torch.ao.quantization import MinMaxObserver
|
|
from torch.ao.quantization.fake_quantize import FakeQuantize
|
|
from torch.ao.quantization.qconfig import QConfig
|
|
from torch.ao.quantization.quantize_fx import prepare_qat_fx
|
|
from torch.autograd.profiler import _enable_dynamo_cache_lookup_profiler
|
|
from torch.fx.experimental.symbolic_shapes import ConstraintViolationError
|
|
from torch.nn import functional as F
|
|
from torch.testing._internal.common_cuda import (
|
|
PLATFORM_SUPPORTS_FUSED_SDPA,
|
|
SM80OrLater,
|
|
)
|
|
from torch.testing._internal.common_utils import freeze_rng_state
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
|
|
mytuple = collections.namedtuple("mytuple", ["a", "b", "ab"])
|
|
|
|
|
|
class MyPickledModule(torch.nn.Module):
|
|
def __init__(self, z):
|
|
super().__init__()
|
|
self.z = z
|
|
|
|
def forward(self, x, y):
|
|
return x * x * x + y + self.z
|
|
|
|
|
|
# These are used for test_{cond/map}_with_quantization
|
|
default_symmetric_fake_quant = FakeQuantize.with_args(
|
|
observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.quint8
|
|
)
|
|
default_weight_symmetric_fake_quant = FakeQuantize.with_args(
|
|
observer=MinMaxObserver, qscheme=torch.per_tensor_symmetric, dtype=torch.qint8
|
|
)
|
|
uniform_qconfig_8bit = QConfig(
|
|
activation=default_symmetric_fake_quant,
|
|
weight=default_weight_symmetric_fake_quant.with_args,
|
|
)
|
|
qconfig_dict = {"object_type": [(torch.nn.Linear, uniform_qconfig_8bit)]}
|
|
|
|
|
|
class MiscTests(torch._dynamo.test_case.TestCase):
|
|
def test_boolarg(self):
|
|
def boolarg(aa, bb, flag):
|
|
if flag:
|
|
return aa - bb
|
|
else:
|
|
return bb - aa
|
|
|
|
a = torch.randn(10, 10)
|
|
b = torch.randn(10, 10)
|
|
correct1 = boolarg(a, b, True)
|
|
correct2 = boolarg(a, b, False)
|
|
correct3 = boolarg(a, b, None)
|
|
counter = CompileCounter()
|
|
opt_boolarg = torch._dynamo.optimize_assert(counter)(boolarg)
|
|
val1 = opt_boolarg(a, b, True)
|
|
val2 = opt_boolarg(a, b, False)
|
|
val3 = opt_boolarg(a, b, None)
|
|
val4 = opt_boolarg(a, b, True)
|
|
self.assertTrue(same(val1, correct1))
|
|
self.assertTrue(same(val2, correct2))
|
|
self.assertTrue(same(val3, correct3))
|
|
self.assertTrue(same(val4, correct1))
|
|
self.assertEqual(counter.frame_count, 3)
|
|
|
|
def test_callpacked(self):
|
|
def call_packed(args):
|
|
a, b, c = args
|
|
return a - b * c
|
|
|
|
counter = CompileCounter()
|
|
a = torch.randn(10, 10)
|
|
b = torch.randn(10, 10)
|
|
c = torch.randn(10, 10)
|
|
correct = call_packed([a, b, c])
|
|
opt_call_packed = torch._dynamo.optimize_assert(counter)(call_packed)
|
|
val1 = opt_call_packed([a, b, c])
|
|
val2 = opt_call_packed((a, b, c))
|
|
val3 = opt_call_packed([a, b, c])
|
|
val4 = opt_call_packed((a, b, c))
|
|
self.assertTrue(same(val1, correct))
|
|
self.assertTrue(same(val2, correct))
|
|
self.assertTrue(same(val3, correct))
|
|
self.assertTrue(same(val4, correct))
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
def test_raises(self):
|
|
def fn(a, b, c, cls):
|
|
x = a + b - c * 10
|
|
raise cls(str(x))
|
|
|
|
counter = CompileCounter()
|
|
a = torch.randn(10, 10)
|
|
b = torch.randn(10, 10)
|
|
c = torch.randn(10, 10)
|
|
opt_fn = torch._dynamo.optimize(counter)(fn)
|
|
self.assertRaises(AssertionError, lambda: opt_fn(a, b, c, AssertionError))
|
|
self.assertEqual(counter.frame_count, 1)
|
|
self.assertEqual(counter.op_count, 3)
|
|
|
|
def test_inplace(self):
|
|
def inplace1(a, b):
|
|
o = torch.empty((10, 10))
|
|
o.copy_(a)
|
|
o -= b
|
|
return o
|
|
|
|
torch._dynamo.testing.standard_test(self, inplace1, 2, expected_ops=3)
|
|
|
|
def test_unpack4(self):
|
|
def unpack4(a, b):
|
|
a = a[:5, :]
|
|
b = b[:5, :]
|
|
x, y = a.size()
|
|
o = torch.empty((x, y))
|
|
o.copy_(a / b)
|
|
return o
|
|
|
|
torch._dynamo.testing.standard_test(
|
|
self,
|
|
unpack4,
|
|
2,
|
|
expected_ops=5,
|
|
expected_ops_dynamic=ifdynstaticdefault(6, 7),
|
|
)
|
|
|
|
def test_unpack5(self):
|
|
def unpack5(a, b):
|
|
a = a[:5, :]
|
|
b = b[:5, :]
|
|
x, y = a.shape
|
|
o = torch.empty((x, y))
|
|
o.copy_(a / b)
|
|
return o
|
|
|
|
torch._dynamo.testing.standard_test(
|
|
self,
|
|
unpack5,
|
|
2,
|
|
expected_ops=5,
|
|
expected_ops_dynamic=ifdynstaticdefault(6, 7),
|
|
)
|
|
|
|
def test_matmul1(self):
|
|
def matmul_op1(a, b):
|
|
return a @ b
|
|
|
|
# TODO(jansel): FX doesn't support this, should add upstream support
|
|
torch._dynamo.testing.standard_test(self, matmul_op1, 2, expected_ops=1)
|
|
|
|
def test_int_shape_binops(self):
|
|
def fn(x):
|
|
# Test reversal by putting int arg first.
|
|
y = 15 - x.shape[0]
|
|
y = 4 + y
|
|
y = 5 * y
|
|
y = 2 % y
|
|
y = 3**y
|
|
y = 10 // y
|
|
y = pow(2, y)
|
|
y = 10 / y
|
|
return x + y
|
|
|
|
torch._dynamo.testing.standard_test(
|
|
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 11)
|
|
)
|
|
|
|
def test_shape_int_inplace_binops(self):
|
|
def fn(x):
|
|
p = x.shape[0]
|
|
p += 2
|
|
p -= 2
|
|
p **= 2
|
|
p /= 2
|
|
p *= 2
|
|
p //= 2
|
|
p %= 2
|
|
return x + p
|
|
|
|
torch._dynamo.testing.standard_test(
|
|
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
|
|
)
|
|
|
|
def test_int_shape_inplace_binops(self):
|
|
def fn(x):
|
|
p = x.shape[0]
|
|
# Test reversal by putting constant first
|
|
y = 2
|
|
y += p
|
|
y = 2
|
|
y -= p
|
|
y = 2
|
|
y **= p
|
|
y = 2
|
|
y /= p
|
|
y = 2
|
|
y *= p
|
|
y = 2
|
|
y //= p
|
|
y = 2
|
|
y %= p
|
|
return x + y
|
|
|
|
torch._dynamo.testing.standard_test(
|
|
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 10)
|
|
)
|
|
|
|
def test_int_int_comparisons(self):
|
|
def fn(x):
|
|
if 2 != 2:
|
|
out = 1
|
|
elif 2 < 1:
|
|
out = 1
|
|
elif 1 > 2:
|
|
out = 1
|
|
elif 1 >= 2:
|
|
out = 1
|
|
elif 2 <= 1:
|
|
out = 1
|
|
elif 2 == 2:
|
|
out = 2
|
|
else:
|
|
out = 1
|
|
return x + out
|
|
|
|
torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
|
|
|
|
def test_shape_int_comparisons(self):
|
|
def fn(x):
|
|
a = x.shape[0]
|
|
# Ensure support for constant on right side
|
|
if a != 10:
|
|
out = 1
|
|
elif a < 2:
|
|
out = 1
|
|
elif a > 12:
|
|
out = 1
|
|
elif a >= 12:
|
|
out = 1
|
|
elif a <= 2:
|
|
out = 1
|
|
elif a == 10:
|
|
out = 2
|
|
else:
|
|
out = 1
|
|
return x + out
|
|
|
|
# expect for dynamic: size, index, 6 comparison ops, add
|
|
torch._dynamo.testing.standard_test(
|
|
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
|
|
)
|
|
|
|
def test_int_shape_comparisons(self):
|
|
def fn(x):
|
|
a = x.shape[0]
|
|
# Ensure support for constant on left side
|
|
if 10 != a:
|
|
out = 1
|
|
elif 12 < a:
|
|
out = 1
|
|
elif 2 > a:
|
|
out = 1
|
|
elif 2 >= a:
|
|
out = 1
|
|
elif 12 <= a:
|
|
out = 1
|
|
elif 10 == a:
|
|
out = 2
|
|
else:
|
|
out = 1
|
|
return x + out
|
|
|
|
# expect for dynamic: size, index, 6 comparison ops, add
|
|
torch._dynamo.testing.standard_test(
|
|
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 9)
|
|
)
|
|
|
|
def test_param_shape_binops(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.randn(15))
|
|
|
|
def forward(self, x):
|
|
# Test reversal by putting param shape arg first.
|
|
p = self.param.shape[0]
|
|
y = p - x.shape[0]
|
|
y = p + y
|
|
y = p * y
|
|
y = p % y
|
|
y = p**y
|
|
y = p // y
|
|
y = pow(p, y)
|
|
y = p / y
|
|
return x + y
|
|
|
|
counts = torch._dynamo.testing.CompileCounter()
|
|
mod = MyModule()
|
|
optimized_mod = torch._dynamo.optimize(counts, nopython=True)(mod)
|
|
|
|
x = torch.randn(3)
|
|
ref = mod(x)
|
|
res = optimized_mod(x)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
self.assertEqual(counts.frame_count, 1)
|
|
|
|
expected_op_count = (
|
|
ifdynstaticdefault(3, 12)
|
|
if torch._dynamo.testing.config.dynamic_shapes
|
|
else 1
|
|
)
|
|
self.assertEqual(counts.op_count, expected_op_count)
|
|
|
|
def test_user_defined_binop(self):
|
|
class MyClass:
|
|
def __init__(self, value):
|
|
self.value = value
|
|
|
|
def __radd__(self, other):
|
|
return self.value + other
|
|
|
|
def fn(x, c):
|
|
y = x.shape[0] + c
|
|
return x + y
|
|
|
|
counts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(counts)(fn)
|
|
|
|
x = torch.randn(3)
|
|
c = MyClass(4)
|
|
ref = fn(x, c)
|
|
res = opt_fn(x, c)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
self.assertEqual(counts.frame_count, 1)
|
|
expected_op_count = (
|
|
ifdynstaticdefault(2, 4)
|
|
if torch._dynamo.testing.config.dynamic_shapes
|
|
else 1
|
|
)
|
|
self.assertEqual(counts.op_count, expected_op_count)
|
|
|
|
def test_compare_shapes_eq(self):
|
|
def compare_shapes(a, b, to_list):
|
|
x = list(a.unsqueeze(-1).shape) if to_list else a.shape
|
|
y = list(b.unsqueeze(-1).shape) if to_list else b.shape
|
|
if x == y:
|
|
return a + 1
|
|
else:
|
|
return a + 2
|
|
|
|
# Test both ListVariable and ShapeVariable
|
|
torch._dynamo.testing.standard_test(
|
|
self, lambda a, b: compare_shapes(a, b, to_list=True), 2
|
|
)
|
|
torch._dynamo.testing.standard_test(
|
|
self, lambda a, b: compare_shapes(a, b, to_list=False), 2
|
|
)
|
|
|
|
def test_compare_shapes_tuple_eq(self):
|
|
def compare_shapes(a, b):
|
|
x = tuple(a.unsqueeze(-1).shape)
|
|
y = tuple(b.unsqueeze(-1).shape)
|
|
if x == y:
|
|
return a + 1
|
|
else:
|
|
return a + 2
|
|
|
|
torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2)
|
|
|
|
def test_compare_shapes_tuple_neq(self):
|
|
def compare_shapes(a, b):
|
|
x = tuple(a.unsqueeze(-1).shape)
|
|
y = tuple(b.unsqueeze(-1).shape)
|
|
if x != y:
|
|
return a + 1
|
|
else:
|
|
return a + 2
|
|
|
|
torch._dynamo.testing.standard_test(self, lambda a, b: compare_shapes(a, b), 2)
|
|
|
|
def test_compare_shapes_neq(self):
|
|
def compare_shapes(a, b, to_list):
|
|
x = list(a.unsqueeze(-1).shape) if to_list else a.shape
|
|
y = list(b.unsqueeze(-1).shape) if to_list else b.shape
|
|
if x != y:
|
|
return a + 1
|
|
else:
|
|
return a + 2
|
|
|
|
# Test both ListVariable and ShapeVariable
|
|
torch._dynamo.testing.standard_test(
|
|
self, lambda a, b: compare_shapes(a, b, to_list=True), 2
|
|
)
|
|
torch._dynamo.testing.standard_test(
|
|
self, lambda a, b: compare_shapes(a, b, to_list=False), 2
|
|
)
|
|
|
|
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
|
def test_compare_shapes_with_constant(self):
|
|
def compare_shapes(a):
|
|
x = a.shape
|
|
if x[0] != 3:
|
|
return a * 4
|
|
return a * 3
|
|
|
|
guard_failure = None
|
|
|
|
def guard_failures(failure):
|
|
nonlocal guard_failure
|
|
guard_failure = failure
|
|
|
|
opt_fn = torch._dynamo.optimize(
|
|
"eager", nopython=True, guard_fail_fn=guard_failures
|
|
)(compare_shapes)
|
|
opt_fn(torch.randn([3, 4]))
|
|
opt_fn(torch.randn([4, 3]))
|
|
self.assertExpectedInline(
|
|
guard_failure.reason,
|
|
"""tensor 'L['a']' size mismatch at index 0. expected 3, actual 4""",
|
|
)
|
|
|
|
def test_builtin_isinstance(self):
|
|
def fn(x):
|
|
t = torch.arange(1, 3)
|
|
a = isinstance(x, torch.Tensor)
|
|
b = isinstance(t, torch.Tensor)
|
|
c = isinstance(x, int)
|
|
d = isinstance(3, int)
|
|
e = isinstance([1, 2, 3], list)
|
|
f = isinstance({"foo": 1, "bar": 2}, dict)
|
|
res = [a, b, c, d, e, f]
|
|
# Can't run yet due to other unimplemented instructions
|
|
# res += [isinstance(torch.nn.LazyLinear(2, 3), torch.nn.Linear)]
|
|
return res
|
|
|
|
torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
|
|
|
|
def test_fold(self):
|
|
def fn(a):
|
|
return a + math.sqrt(63)
|
|
|
|
torch._dynamo.testing.standard_test(self, fn, 1, expected_ops=1)
|
|
|
|
def test_shape_unpack(self):
|
|
def fn(x):
|
|
a, b = x.size()
|
|
return x * b
|
|
|
|
i = torch.randn(5, 10)
|
|
r1 = fn(i)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
r2 = opt_fn(i)
|
|
self.assertTrue(same(r1, r2))
|
|
|
|
def test_tensor_iter(self):
|
|
def fn(x):
|
|
for y in x:
|
|
y.add_(1.0)
|
|
return y
|
|
|
|
# expect extra size node for dynamic
|
|
torch._dynamo.testing.standard_test(
|
|
self, fn, 1, expected_ops=20, expected_ops_dynamic=21
|
|
)
|
|
|
|
def test_empty_list(self):
|
|
def fn(x, ll):
|
|
if len(ll) == 0 and not ll and ll is not None:
|
|
return x + 1
|
|
|
|
i = torch.randn(5, 10)
|
|
r1 = fn(i, [])
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
r2 = opt_fn(i, [])
|
|
r3 = opt_fn(i, tuple())
|
|
self.assertTrue(same(r1, r2))
|
|
self.assertTrue(same(r1, r3))
|
|
|
|
def test_min_max_over_iterable(self):
|
|
def get_test_fn(func):
|
|
def _fn(a, b, func=func):
|
|
# try all of list, iterator, tuple, vararg.
|
|
lst = [a.shape[0] + 1, 8, a.shape[0]]
|
|
x = func(lst)
|
|
y = func(iter(lst))
|
|
z = func(tuple(lst))
|
|
w = func(*lst)
|
|
return a + (x + y + z + w)
|
|
|
|
return _fn
|
|
|
|
torch._dynamo.testing.standard_test(
|
|
self,
|
|
get_test_fn(func=min),
|
|
2,
|
|
expected_ops=1,
|
|
expected_ops_dynamic=ifdynstaticdefault(3, 14),
|
|
)
|
|
torch._dynamo.testing.standard_test(
|
|
self,
|
|
get_test_fn(func=max),
|
|
2,
|
|
expected_ops=1,
|
|
expected_ops_dynamic=ifdynstaticdefault(3, 17),
|
|
)
|
|
|
|
def test_config_obj(self):
|
|
class Cfg:
|
|
def __init__(self):
|
|
self.val = 0.5
|
|
self.count = 3
|
|
|
|
def fn(x, cfg):
|
|
for i in range(cfg.count):
|
|
x = x + cfg.val
|
|
return x
|
|
|
|
cfg1 = Cfg()
|
|
cfg1.val = 1.0
|
|
cfg2 = Cfg()
|
|
v = torch.zeros(1)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
v = opt_fn(v, cfg1) # 3
|
|
v = opt_fn(v, cfg2) # 4.5
|
|
cfg2.count = 1
|
|
v = opt_fn(v, cfg2) # 5
|
|
cfg2.val = 2.0
|
|
v = opt_fn(v, cfg2) # 7
|
|
self.assertEqual(v[0], 7)
|
|
self.assertEqual(cnts.op_count, 8)
|
|
|
|
def test_config_getattr_default(self):
|
|
class Cfg:
|
|
def __init__(self):
|
|
self.val = 0.5
|
|
self.count = 10
|
|
|
|
def fn(x, cfg):
|
|
if getattr(cfg, "just_add_7", False):
|
|
return x + 7
|
|
for i in range(cfg.count):
|
|
x = x + cfg.val
|
|
return x
|
|
|
|
cfg1 = Cfg()
|
|
v = torch.zeros(1)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertEqual(opt_fn(v, cfg1)[0], 5)
|
|
self.assertEqual(opt_fn(v, cfg1)[0], 5)
|
|
cfg1.just_add_7 = True
|
|
self.assertEqual(opt_fn(v, cfg1)[0], 7)
|
|
self.assertEqual(opt_fn(v, cfg1)[0], 7)
|
|
cfg1.just_add_7 = False
|
|
self.assertEqual(opt_fn(v, cfg1)[0], 5)
|
|
self.assertEqual(opt_fn(v, cfg1)[0], 5)
|
|
self.assertEqual(cnts.frame_count, 3)
|
|
|
|
def test_size_input(self):
|
|
def fn(x, s):
|
|
a, b = s
|
|
return x + (a - b)
|
|
|
|
v = torch.zeros(10, 20)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertEqual(opt_fn(v, v.size())[0, 0], -10)
|
|
self.assertEqual(opt_fn(v, (10, 20))[0, 0], -10)
|
|
self.assertEqual(opt_fn(v, [10, 20])[0, 0], -10)
|
|
# One recompile per differing input type
|
|
self.assertEqual(cnts.frame_count, 3)
|
|
|
|
def test_cell_output1(self):
|
|
out = None
|
|
|
|
def fn(a, b):
|
|
nonlocal out
|
|
out = a + b * 10
|
|
|
|
v = torch.Tensor([100])
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertIsNone(opt_fn(v, v))
|
|
self.assertEqual(out[0], 1100)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_cell_output2(self):
|
|
out = None
|
|
|
|
def fn(a, b):
|
|
nonlocal out
|
|
c = unsupported(a, b)
|
|
out = a + b * 10 + c
|
|
|
|
v = torch.Tensor([100])
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertIsNone(opt_fn(v, v))
|
|
self.assertEqual(out[0], 1200)
|
|
self.assertEqual(cnts.op_count, 3)
|
|
|
|
def test_return_nested_function(self):
|
|
out = None
|
|
|
|
def fn(a, b):
|
|
nonlocal out
|
|
c = a + b
|
|
d = a + 1.0
|
|
|
|
def fn2(f: int = 7, g: float = 9.0):
|
|
nonlocal out
|
|
out = a + b * 10
|
|
return c * f - d * g
|
|
|
|
return fn2
|
|
|
|
v1 = torch.Tensor([100])
|
|
v2 = torch.Tensor([200])
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
opt_fn_ret = torch._dynamo.optimize(cnts)(opt_fn(v1, v2))
|
|
self.assertEqual(opt_fn_ret(1.5)[0], -459)
|
|
self.assertEqual(out[0], 2100)
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
self.assertEqual(cnts.op_count, 7)
|
|
|
|
def test_tensor_dict1(self):
|
|
def fn(inputs):
|
|
return inputs["a"] - inputs["b"] * 1.5
|
|
|
|
v1 = torch.Tensor([100])
|
|
v2 = torch.Tensor([200])
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertEqual(opt_fn({"a": v1, "b": v2})[0], -200)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_tensor_dict2(self):
|
|
def fn1(inputs):
|
|
total = torch.zeros(1)
|
|
for k, v in inputs.items():
|
|
total += v
|
|
return total
|
|
|
|
def fn2(inputs):
|
|
total = torch.zeros(1)
|
|
for v in inputs.values():
|
|
total += v
|
|
return total
|
|
|
|
def fn3(inputs):
|
|
total = torch.zeros(1)
|
|
for k in inputs.keys():
|
|
total += inputs[k]
|
|
return total
|
|
|
|
v1 = torch.Tensor([100])
|
|
v2 = torch.Tensor([200])
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn1 = torch._dynamo.optimize(cnts)(fn1)
|
|
opt_fn2 = torch._dynamo.optimize(cnts)(fn2)
|
|
opt_fn3 = torch._dynamo.optimize(cnts)(fn3)
|
|
self.assertEqual(opt_fn1({"a": v1, "b": v2})[0], 300)
|
|
self.assertEqual(opt_fn2({"a": v1, "b": v2})[0], 300)
|
|
self.assertEqual(opt_fn3({"a": v1, "b": v2})[0], 300)
|
|
self.assertEqual(cnts.frame_count, 3)
|
|
self.assertEqual(cnts.op_count, 9)
|
|
|
|
def test_dictcomp(self):
|
|
def fn1(inputs):
|
|
return {k: v + 1 for k, v in inputs.items()}
|
|
|
|
v1 = torch.Tensor([100])
|
|
v2 = torch.Tensor([200])
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn1 = torch._dynamo.optimize(cnts)(fn1)
|
|
self.assertEqual(opt_fn1({"a": v1, "b": v2})["a"], 101)
|
|
self.assertEqual(opt_fn1({"a": v1, "b": v2})["b"], 201)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_listcomp(self):
|
|
def fn2(inputs):
|
|
return torch.sum(torch.cat([v + 1 for k, v in inputs.items()], 0))
|
|
|
|
v1 = torch.Tensor([100])
|
|
v2 = torch.Tensor([200])
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn2 = torch._dynamo.optimize(cnts)(fn2)
|
|
self.assertEqual(opt_fn2({"a": v1, "b": v2}), 302)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 4)
|
|
|
|
def test_is_floating_point(self):
|
|
def fn(a, b):
|
|
x = a + 1.0
|
|
if torch.is_floating_point(b):
|
|
x = x + b
|
|
return x + 2.0
|
|
|
|
return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
|
|
|
|
def test_is_floating_point2(self):
|
|
def fn(a, b):
|
|
x = a + 1.0
|
|
if b.is_floating_point():
|
|
x = x + b
|
|
return x + 2.0
|
|
|
|
return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
|
|
|
|
def test_is_tensor(self):
|
|
def fn(a, b):
|
|
x = a + 1.0
|
|
if torch.is_tensor(b):
|
|
x = x + b
|
|
return x + 2.0
|
|
|
|
return torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
|
|
|
|
def test_is_tensor2(self):
|
|
def fn(x):
|
|
if torch.is_tensor(x):
|
|
return x + 1
|
|
else:
|
|
return torch.ones([2, 3])
|
|
|
|
x1 = {"input": torch.rand(2, 3)}
|
|
x2 = torch.rand(2, 3)
|
|
ref1 = fn(x1)
|
|
ref2 = fn(x2)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res1 = opt_fn(x1)
|
|
res2 = opt_fn(x2)
|
|
self.assertEqual(ref1, res1)
|
|
self.assertEqual(ref2, res2)
|
|
|
|
def test_numel(self):
|
|
def fn(a):
|
|
return (a + a.numel() + torch.numel(a), a + a.nelement())
|
|
|
|
return torch._dynamo.testing.standard_test(
|
|
self, fn=fn, nargs=1, expected_ops=3, expected_ops_dynamic=6
|
|
)
|
|
|
|
def test_pair(self):
|
|
def fn(a):
|
|
return (
|
|
torch.zeros(torch.nn.modules.utils._pair(a.size()))
|
|
+ a
|
|
+ torch.ones(torch.nn.modules.utils._ntuple(3)(3)).sum()
|
|
)
|
|
|
|
return torch._dynamo.testing.standard_test(
|
|
self,
|
|
fn=fn,
|
|
nargs=1,
|
|
expected_ops=5,
|
|
expected_ops_dynamic=ifdynstaticdefault(6, 8),
|
|
)
|
|
|
|
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
|
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
|
|
def test_tensor_item_capture(self):
|
|
def fn(a, b):
|
|
return (a + b).sum().item()
|
|
|
|
v1 = torch.randn((10, 10))
|
|
v2 = torch.randn((10, 10))
|
|
correct = fn(v1, v2)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize((cnts))(fn)
|
|
self.assertEqual(opt_fn(v1, v2), correct)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 3)
|
|
|
|
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
|
@patch.object(torch._dynamo.config, "capture_scalar_outputs", False)
|
|
def test_tensor_item_no_capture(self):
|
|
def fn(a, b):
|
|
return (a + b).sum().item()
|
|
|
|
v1 = torch.randn((10, 10))
|
|
v2 = torch.randn((10, 10))
|
|
correct = fn(v1, v2)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize((cnts))(fn)
|
|
self.assertEqual(opt_fn(v1, v2), correct)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_namedtuple1(self):
|
|
def fn(a, b):
|
|
tmp = mytuple(a, b, a + b)
|
|
return mytuple(tmp.a, tmp[1], tmp.ab + b)
|
|
|
|
v1 = torch.Tensor([10])
|
|
v2 = torch.Tensor([20])
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertEqual(opt_fn(v1, v2).ab, 50)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_namedtuple2(self):
|
|
def fn(packed):
|
|
a, b, c = packed
|
|
if hasattr(packed, "b"):
|
|
b = packed.b + 1
|
|
c = packed[2]
|
|
return a + b + c
|
|
|
|
v1 = torch.Tensor([1])
|
|
v2 = torch.Tensor([2])
|
|
v3 = torch.Tensor([3])
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertEqual(opt_fn(mytuple(v1, v2, v3))[0], 7)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 3)
|
|
|
|
def test_namedtuple3(self):
|
|
def fn(x, packed):
|
|
if isinstance(packed, mytuple):
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
x = torch.rand([2, 3])
|
|
packed = mytuple(1, 2, 3)
|
|
ref = fn(x, packed)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x, packed)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_range_input(self):
|
|
def fn(a, rng):
|
|
x = a
|
|
for i in rng:
|
|
x = x + i
|
|
return x
|
|
|
|
def fn1(a):
|
|
return fn(a, rng=range(3))
|
|
|
|
return torch._dynamo.testing.standard_test(
|
|
self, fn=fn1, nargs=1, expected_ops=3
|
|
)
|
|
|
|
def test_range_with_shape(self):
|
|
def fn(a):
|
|
for i in range(1, a.shape[0]):
|
|
a += 1
|
|
return a
|
|
|
|
# expect 1 more op (size call) for dynamic
|
|
return torch._dynamo.testing.standard_test(
|
|
self, fn=fn, nargs=1, expected_ops=9, expected_ops_dynamic=10
|
|
)
|
|
|
|
def test_build_tuple_unpack(self):
|
|
def fn1(a, b, c):
|
|
return a - b / c
|
|
|
|
def fn2(a, b, c):
|
|
tmp1 = (a,)
|
|
tmp2 = (b, c)
|
|
args = (*tmp1, *tmp2)
|
|
return fn1(*args)
|
|
|
|
def fn3(a, *args):
|
|
return fn1(a, *args)
|
|
|
|
torch._dynamo.testing.standard_test(self, fn=fn2, nargs=3, expected_ops=2)
|
|
torch._dynamo.testing.standard_test(self, fn=fn3, nargs=3, expected_ops=2)
|
|
|
|
def test_list_mul(self):
|
|
def fn(count):
|
|
head_mask = count * [None] * count
|
|
return head_mask
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertEqual(opt_fn(2), [None] * 4)
|
|
# TODO: the captured frame here is a bit goofy, because we don't
|
|
# output anything and none of the traced operations have side
|
|
# effects. Probably need better heuristic for bailing on
|
|
# dynamo if there are no outputs
|
|
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
|
|
self.assertEqual(cnts.op_count, ifunspec(2, 0))
|
|
|
|
def test_list_slice_mul(self):
|
|
def fn(count):
|
|
a = [1, 2, 3]
|
|
head_mask = count * a[1:] * count
|
|
return head_mask
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertEqual(opt_fn(2), [2, 3] * 4)
|
|
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
|
|
self.assertEqual(cnts.op_count, ifunspec(2, 0))
|
|
|
|
def test_tuple_mul(self):
|
|
def fn(count):
|
|
head_mask = count * (2, 3) * count
|
|
return head_mask
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertEqual(opt_fn(2), (2, 3) * 4)
|
|
self.assertEqual(cnts.frame_count, ifunspec(1, 0))
|
|
self.assertEqual(cnts.op_count, ifunspec(ifdynstaticdefault(2, 2), 0))
|
|
|
|
def test_tuple_mul_with_shape(self):
|
|
def fn(a):
|
|
x = a.shape[0]
|
|
y = 2 * (x, 3) * 2
|
|
return a + y[4]
|
|
|
|
# expect 3 ops post folding for dynamic case: size, index, add
|
|
torch._dynamo.testing.standard_test(
|
|
self, fn, 1, expected_ops=1, expected_ops_dynamic=ifdynstaticdefault(2, 3)
|
|
)
|
|
|
|
def test_tuple_iadd_with_shape(self):
|
|
def fn(a):
|
|
output = (a + a.shape[0], a - a.shape[0])
|
|
# tuple += tuple
|
|
output += (a - a.shape[0], a + a.shape[0])
|
|
# tuple += constant tuple
|
|
output += (2, 3)
|
|
return output
|
|
|
|
# expect 4 add / subs for static, 4 * 3 (size, index, math op) for dynamic
|
|
torch._dynamo.testing.standard_test(
|
|
self, fn, 1, expected_ops=4, expected_ops_dynamic=ifdynstaticdefault(8, 12)
|
|
)
|
|
|
|
def test_list_iadd_with_shape(self):
|
|
def fn(a):
|
|
output = [a + a.shape[0], a - a.shape[0]]
|
|
# list += list
|
|
output += [a - a.shape[0], a + a.shape[0]]
|
|
# list += tuple
|
|
output += (a + a.shape[0], a - a.shape[0])
|
|
return output
|
|
|
|
# expect 6 add / subs for static, 6 * 3 (size, index, math op) for dynamic
|
|
|
|
torch._dynamo.testing.standard_test(
|
|
self, fn, 1, expected_ops=6, expected_ops_dynamic=ifdynstaticdefault(12, 18)
|
|
)
|
|
|
|
def test_user_getattr1(self):
|
|
class MyConfig(dict):
|
|
def __getattr__(self, name):
|
|
return self[name]
|
|
|
|
def fn(cfg, x, y):
|
|
return x + y + cfg.offset
|
|
|
|
x = torch.randn(10)
|
|
cfg = MyConfig(offset=5)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_user_getattr2(self):
|
|
class MyConfig:
|
|
defined_on_class = 1
|
|
|
|
def __init__(self):
|
|
self.defined_on_object = 2
|
|
|
|
def __getattr__(self, name):
|
|
return 3
|
|
|
|
def fn(cfg, x):
|
|
return x + cfg.defined_on_class - cfg.defined_on_object + cfg.not_defined
|
|
|
|
x = torch.randn(10)
|
|
cfg = MyConfig()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(cfg, x), x + 1 - 2 + 3))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 3)
|
|
|
|
def test_user_getattribute(self):
|
|
class MyObject:
|
|
def __init__(self):
|
|
self.custom_dict = {"a": torch.rand((2, 2))}
|
|
self.my_number = 42
|
|
|
|
def __getattribute__(self, name):
|
|
custom_dict = super().__getattribute__("custom_dict")
|
|
if name in custom_dict:
|
|
return custom_dict[name]
|
|
return super().__getattribute__(name)
|
|
|
|
def run(self, x):
|
|
return self.my_number * x + self.a * x
|
|
|
|
def fn(obj, x):
|
|
return obj.run(x)
|
|
|
|
obj = MyObject()
|
|
x = torch.rand((2, 2))
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(obj, x), fn(obj, x)))
|
|
|
|
def test_nn_module_getattr(self):
|
|
class MyMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.custom_dict = {"queue": [torch.rand((2, 2)) for _ in range(3)]}
|
|
self.other_attr = torch.rand((2, 2))
|
|
|
|
def __getattr__(self, name):
|
|
custom_dict = self.custom_dict
|
|
if name in custom_dict:
|
|
return custom_dict[name]
|
|
return super().__getattr__(name)
|
|
|
|
def forward(self, x):
|
|
return x @ self.other_attr + self.queue[-1]
|
|
|
|
x = torch.rand((2, 2))
|
|
mod = MyMod()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_mod = torch._dynamo.optimize(cnts)(mod)
|
|
self.assertTrue(same(opt_mod(x), mod(x)))
|
|
self.assertTrue(cnts.frame_count, 1)
|
|
self.assertTrue(cnts.op_count, 2)
|
|
|
|
def test_nn_module_getattribute(self):
|
|
class MyMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.my_number = 42
|
|
|
|
def __getattribute__(self, name):
|
|
if name == "special_attr":
|
|
return torch.tensor([[1, 2], [3, 4]])
|
|
return super().__getattribute__(name)
|
|
|
|
def forward(self, x):
|
|
return self.my_number * x + self.special_attr * x
|
|
|
|
def fn(mod, x):
|
|
return mod(x)
|
|
|
|
mod = MyMod()
|
|
x = torch.rand((2, 2))
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(mod, x), fn(mod, x)))
|
|
|
|
def test_constant_getattr(self):
|
|
# https://github.com/pytorch/pytorch/issues/97480
|
|
def fn():
|
|
return getattr(None, "arg", 3)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
optimized_fn = torch._dynamo.optimize(cnt)(fn)
|
|
res = optimized_fn()
|
|
self.assertTrue(same(res, 3))
|
|
|
|
def test_user_property(self):
|
|
class MyConfig:
|
|
@property
|
|
def prop5(self):
|
|
return 5
|
|
|
|
def fn(cfg, x, y):
|
|
return x + y + cfg.prop5
|
|
|
|
x = torch.randn(10)
|
|
cfg = MyConfig()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(cfg, x, x), 2 * x + 5))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_dataclass_fields(self):
|
|
@dataclasses.dataclass
|
|
class MyDataClass:
|
|
a: torch.Tensor
|
|
b: torch.Tensor = None
|
|
c: torch.Tensor = None
|
|
d: torch.Tensor = None
|
|
e: torch.Tensor = None
|
|
|
|
def fn(obj):
|
|
class_fields = dataclasses.fields(obj)
|
|
assert len(class_fields)
|
|
assert all(field.default is None for field in class_fields[1:])
|
|
other_fields_are_none = all(
|
|
getattr(obj, field.name) is None for field in class_fields[1:]
|
|
)
|
|
assert not other_fields_are_none
|
|
|
|
total = getattr(obj, class_fields[0].name)
|
|
for field in class_fields[1:]:
|
|
v = getattr(obj, field.name)
|
|
if v is not None:
|
|
total += v
|
|
|
|
return total
|
|
|
|
obj1 = MyDataClass(torch.randn(10), torch.randn(10), torch.randn(10))
|
|
obj2 = MyDataClass(torch.randn(10), e=torch.randn(10))
|
|
correct1 = fn(obj1)
|
|
correct2 = fn(obj2)
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(obj1), correct1))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
torch._dynamo.reset()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(obj2), correct2))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 1)
|
|
|
|
@requires_static_shapes
|
|
def test_tensor_build_list_unpack(self):
|
|
def fn(x):
|
|
# seen in fastNLP_Bert
|
|
return torch.cat([*x], dim=-1)
|
|
|
|
val = torch.randn([1, 1, 473, 768])
|
|
correct = fn(val)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(val), correct))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_numpy_int_constant(self):
|
|
def fn(x, a, b):
|
|
return x + (a % b)
|
|
|
|
args = [torch.randn(10), 4096, np.int64(8)]
|
|
correct = fn(*args)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(*args), correct))
|
|
self.assertTrue(same(opt_fn(*args), correct))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_inplace_resize_on_graph_input(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
# graph break when calling resize_() on graph input
|
|
def f1(x):
|
|
x.resize_(6)
|
|
x.mul_(2)
|
|
return x
|
|
|
|
@torch.compile(backend=cnts)
|
|
def f2(x):
|
|
x.resize_(6)
|
|
x.mul_(2)
|
|
return x
|
|
|
|
x = torch.ones(4)
|
|
y = torch.ones(4)
|
|
self.assertTrue(same(f1(x).shape, f2(y).shape))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 1) # mul_
|
|
|
|
def test_dict_mutation_side_effect(self):
|
|
def fn(d):
|
|
d["c"] = d["a"] + d.pop("b")
|
|
return d
|
|
|
|
args1 = {"a": torch.randn(10), "b": torch.randn(10)}
|
|
args2 = dict(args1)
|
|
assert fn(args1) is args1
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertIs(opt_fn(args2), args2)
|
|
self.assertTrue(same(args1, args2))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 1)
|
|
|
|
def test_module_deepcopy(self):
|
|
m1 = torch.nn.Sequential(
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
)
|
|
m2 = torch.nn.Sequential(
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
)
|
|
|
|
def fn(m, x):
|
|
m_copy = copy.deepcopy(m)
|
|
return m_copy(x)
|
|
|
|
v = torch.randn(10)
|
|
correct1 = fn(m1, v)
|
|
correct2 = fn(m2, v)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
for _ in range(10):
|
|
self.assertTrue(same(opt_fn(m1, v), correct1))
|
|
for _ in range(10):
|
|
self.assertTrue(same(opt_fn(m2, v), correct2))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 4)
|
|
|
|
def test_type_copy(self):
|
|
def fn(seq):
|
|
a, b = seq
|
|
return type(seq)([a + 1, b + 2, a + b])
|
|
|
|
args1 = [torch.randn(10), torch.randn(10)]
|
|
args2 = (torch.randn(10), torch.randn(10))
|
|
correct1 = fn(args1)
|
|
correct2 = fn(args2)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertTrue(same(opt_fn(args1), correct1))
|
|
self.assertTrue(same(opt_fn(args2), correct2))
|
|
self.assertIsInstance(opt_fn(args1), list)
|
|
self.assertIsInstance(opt_fn(args2), tuple)
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
self.assertEqual(cnts.op_count, 6)
|
|
|
|
def test_setattr_mutation1(self):
|
|
class MyObj: # noqa: B903
|
|
def __init__(self, a, b):
|
|
self.a = a
|
|
self.b = b
|
|
|
|
def fn(obj):
|
|
obj.c = obj.a * obj.b + 1
|
|
obj.b = obj.a * obj.c + 2
|
|
obj.a = obj.b * obj.c + 3
|
|
obj.c = obj.a * obj.b + 4
|
|
obj.b = obj.a * obj.c + 5
|
|
obj.a = obj.b * obj.c + 6
|
|
return obj
|
|
|
|
x1 = torch.randn(10)
|
|
x2 = torch.randn(10)
|
|
obj1 = MyObj(x1, x2)
|
|
obj2 = MyObj(x1, x2)
|
|
fn(obj2)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
self.assertIs(opt_fn(obj1), obj1)
|
|
self.assertTrue(same(obj1.a, obj2.a))
|
|
self.assertTrue(same(obj1.b, obj2.b))
|
|
self.assertTrue(same(obj1.c, obj2.c))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 12)
|
|
|
|
def test_setattr_mutation2(self):
|
|
class MyObj:
|
|
def __init__(self, x):
|
|
self.a = x + 1
|
|
self.b = x + 2
|
|
|
|
def fn(x):
|
|
x = x / 3.0
|
|
obj = MyObj(x)
|
|
obj.c = obj.a * obj.b + 1
|
|
obj.b = obj.a * obj.c + 2
|
|
obj.a = obj.b * obj.c + 3
|
|
return obj
|
|
|
|
x1 = torch.randn(10)
|
|
obj2 = fn(x1)
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
obj1 = opt_fn(x1)
|
|
self.assertTrue(same(obj1.a, obj2.a))
|
|
self.assertTrue(same(obj1.b, obj2.b))
|
|
self.assertTrue(same(obj1.c, obj2.c))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 9)
|
|
|
|
def test_setattr_mutation3(self):
|
|
# TODO(jansel): dead code eliminate the object creation
|
|
class MyObj:
|
|
def __init__(self, x):
|
|
super().__init__()
|
|
self.a = x + 1
|
|
self.b = x + 2
|
|
|
|
def fn(x):
|
|
x = x / 3.0
|
|
obj = MyObj(x)
|
|
obj.c = obj.a * obj.b + 1
|
|
obj.b = obj.a * obj.c + 2
|
|
obj.a = obj.b * obj.c + 3
|
|
return obj.a, obj.b, obj.c
|
|
|
|
x1 = torch.randn(10)
|
|
obj2 = fn(x1)
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
obj1 = opt_fn(x1)
|
|
self.assertTrue(same(obj1, obj2))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 9)
|
|
|
|
def test_user_defined_class_name(self):
|
|
class MyClassFoo:
|
|
pass
|
|
|
|
def fn1(a, b, c):
|
|
tmp = MyClassFoo()
|
|
if tmp.__class__.__name__ == "MyClassFoo":
|
|
return a - b / c
|
|
|
|
torch._dynamo.testing.standard_test(self, fn=fn1, nargs=3)
|
|
|
|
def test_user_defined_class_python_type(self):
|
|
class MyClass1:
|
|
pass
|
|
|
|
class ExampleMeta(type):
|
|
pass
|
|
|
|
class MyClass2(metaclass=ExampleMeta):
|
|
pass
|
|
|
|
def fn(x, c):
|
|
if isinstance(c, MyClass1):
|
|
return x + 1
|
|
elif isinstance(c, MyClass2):
|
|
return x + 2
|
|
else:
|
|
return x + 3
|
|
|
|
x = torch.rand(3)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
for c in [MyClass1, MyClass2]:
|
|
ref = fn(x, c)
|
|
res = opt_fn(x, c)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_super_calling_with_metaclass(self):
|
|
class ExampleMeta(type):
|
|
pass
|
|
|
|
class MyClass1(metaclass=ExampleMeta):
|
|
@classmethod
|
|
def add(cls, x):
|
|
return x + 1
|
|
|
|
class MyClass2(MyClass1):
|
|
@classmethod
|
|
def add(cls, x):
|
|
torch._dynamo.graph_break()
|
|
return x + super().add(x)
|
|
|
|
def fn(x, obj):
|
|
return x + obj.add(x)
|
|
|
|
x = torch.rand(3)
|
|
obj = MyClass2()
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
ref = fn(x, obj)
|
|
res = opt_fn(x, obj)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_manual_seed(self):
|
|
def fn(a, b):
|
|
x = a + b
|
|
torch.manual_seed(9000)
|
|
return x + 1
|
|
|
|
torch._dynamo.testing.standard_test(self, fn=fn, nargs=2, expected_ops=3)
|
|
|
|
def test_usr_cls_staticmethod(self):
|
|
class Foo:
|
|
@staticmethod
|
|
def bar(a, b):
|
|
return a + b
|
|
|
|
def fn(a, b):
|
|
return Foo.bar(a, b) - 1
|
|
|
|
torch._dynamo.testing.standard_test(self, fn=fn, nargs=2)
|
|
|
|
def test_usr_cls_classmethod(self):
|
|
class Foo:
|
|
@classmethod
|
|
def bar(cls, a, b):
|
|
return a + b
|
|
|
|
def fn(a, b):
|
|
return Foo.bar(a, b) - 1
|
|
|
|
torch._dynamo.testing.standard_test(self, fn=fn, nargs=2)
|
|
|
|
def test_dunder_methods(self):
|
|
class Foo:
|
|
def __init__(self, val):
|
|
super().__init__()
|
|
self.val = val
|
|
|
|
def __add__(self, other):
|
|
return Foo(self.val + other.val)
|
|
|
|
def __mul__(self, other):
|
|
return Foo(self.val * other.val)
|
|
|
|
def __truediv__(self, other):
|
|
return Foo(self.val / other.val)
|
|
|
|
def __sub__(self, other):
|
|
return Foo(self.val - other.val)
|
|
|
|
def fn(a, b, c):
|
|
return Foo(a) + Foo(b) * Foo(c) / Foo(a) - Foo(b)
|
|
|
|
torch._dynamo.testing.standard_test(self, fn=fn, nargs=3, expected_ops=4)
|
|
|
|
def test_function_annotation(self):
|
|
class Variable:
|
|
pass
|
|
|
|
def fn(x):
|
|
x = x / 3.0
|
|
|
|
def inner(y: typing.List[Variable]):
|
|
return x + 1
|
|
|
|
return inner
|
|
|
|
x1 = torch.randn(10)
|
|
obj2 = fn(x1)([])
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
|
|
opt_fn_inner = torch._dynamo.optimize_assert(cnts)(opt_fn(x1))
|
|
obj1 = opt_fn_inner([])
|
|
self.assertTrue(same(obj1, obj2))
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
def test_nested_closure(self):
|
|
v0 = torch.randn(10)
|
|
|
|
def fn1():
|
|
v1 = torch.randn(10)
|
|
|
|
def fn2(*args, **kwargs):
|
|
assert len(args) == 1
|
|
assert len(kwargs) == 1
|
|
v2 = torch.randn(10) + args[0] + kwargs["b"]
|
|
|
|
def fn3(v3=torch.randn(10)):
|
|
def fn4():
|
|
return v0 + v1 + v2 + v3 + 1
|
|
|
|
return fn4
|
|
|
|
return fn3
|
|
|
|
return fn2(1, b=2)()
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1)
|
|
tmp1 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
|
|
tmp2 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
|
|
self.assertTrue(tmp1().shape, (10,))
|
|
self.assertTrue(same(tmp1(), tmp1()))
|
|
self.assertFalse(same(tmp1(), tmp2()))
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
self.assertEqual(cnts.op_count, 9)
|
|
|
|
def test_nested_closure_mutation(self):
|
|
def fn1():
|
|
v1 = torch.randn(10)
|
|
|
|
def fn2():
|
|
v2 = torch.randn(10)
|
|
|
|
def fn3():
|
|
nonlocal v1, v2
|
|
v1 += 1
|
|
v2 += 2
|
|
return v1 + v2
|
|
|
|
return fn3
|
|
|
|
rv = fn2()
|
|
rv()
|
|
rv()
|
|
return rv
|
|
|
|
torch.manual_seed(9000)
|
|
counter1 = fn1()
|
|
result1 = [counter1(), counter1(), counter1()]
|
|
|
|
torch.manual_seed(9000)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn1 = torch._dynamo.optimize_assert(cnts)(fn1)
|
|
counter2 = torch._dynamo.optimize_assert(cnts)(opt_fn1())
|
|
result2 = [counter2(), counter2(), counter2()]
|
|
result1.append(counter1())
|
|
result2.append(counter2())
|
|
|
|
self.assertTrue(same(result1, result2))
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
self.assertEqual(cnts.op_count, 11)
|
|
|
|
def test_write_to_closures_in_inlining(self):
|
|
out = []
|
|
for use_dynamo in [False, True]:
|
|
|
|
def make_counter():
|
|
x = torch.randn(10)
|
|
|
|
def counter():
|
|
nonlocal x
|
|
x = x + 1
|
|
return x
|
|
|
|
return counter
|
|
|
|
torch.manual_seed(0)
|
|
counter = make_counter()
|
|
if not use_dynamo:
|
|
out.append(counter() + counter())
|
|
else:
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch._dynamo.optimize(cnts, nopython=True)
|
|
def fn(counter):
|
|
return counter() + counter()
|
|
|
|
out.append(fn(counter))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 3)
|
|
self.assertFalse(same(counter() + counter(), out[-1]))
|
|
|
|
self.assertTrue(same(out[0], out[1]))
|
|
|
|
def test_top_package_import(self):
|
|
def fn(x):
|
|
import torch.fx
|
|
|
|
assert not isinstance(x, torch.fx.Proxy)
|
|
return torch.sin(x)
|
|
|
|
x = torch.randn(4, 5)
|
|
ref = fn(x)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnts)(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_typing_union_and_optional(self):
|
|
def fn(x):
|
|
a = torch.jit.annotate(typing.Dict[str, typing.Optional[torch.Tensor]], {})
|
|
b = torch.jit.annotate(
|
|
typing.Dict[str, typing.Union[torch.Tensor, None]], {}
|
|
)
|
|
return a, b, x + 1
|
|
|
|
x = torch.randn(3)
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_optimize_on_module(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def custom_member(self):
|
|
# Just for checking that Dynamo returned mod object can redirect
|
|
# to this method
|
|
pass
|
|
|
|
def forward(self, x):
|
|
return self.relu(x)
|
|
|
|
cnts1 = torch._dynamo.testing.CompileCounter()
|
|
mod = MockModule()
|
|
optimized_mod = torch._dynamo.optimize(cnts1, nopython=True)(mod)
|
|
|
|
a = torch.randn(10)
|
|
ref = mod(a)
|
|
res = optimized_mod(a)
|
|
|
|
optimized_mod.custom_member()
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_nested_optimize_decorator(self):
|
|
cnts2 = torch._dynamo.testing.CompileCounter()
|
|
cnts3 = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch._dynamo.run()
|
|
def fn1(x):
|
|
return torch.sin(x) * 10
|
|
|
|
@torch._dynamo.optimize(cnts2, nopython=True)
|
|
def fn2(x):
|
|
return fn1(x) + 1
|
|
|
|
@torch._dynamo.optimize(cnts3, nopython=True)
|
|
def fn3(x):
|
|
return torch.relu(fn2(x))
|
|
|
|
fn3(torch.randn(4, 5))
|
|
self.assertEqual(cnts2.frame_count, 0)
|
|
self.assertEqual(cnts3.frame_count, 1)
|
|
self.assertEqual(cnts3.op_count, 4)
|
|
|
|
def test_nested_optimize_run(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch._dynamo.optimize(cnts, nopython=True)
|
|
def fn(x):
|
|
return torch.relu(torch.cos(x) + torch.sin(x))
|
|
|
|
fn(torch.randn(4))
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
|
|
fn(torch.randn(4, 4))
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
|
|
# Test that run works on a decorated fn
|
|
fn = torch._dynamo.run(fn)
|
|
fn(torch.randn(4, 4, 4))
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
|
|
def test_nested_optimize(self):
|
|
cnts1 = torch._dynamo.testing.CompileCounter()
|
|
cnts2 = torch._dynamo.testing.CompileCounter()
|
|
|
|
def fn(x):
|
|
return torch.relu(torch.cos(x) + torch.sin(x))
|
|
|
|
fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn)
|
|
fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1)
|
|
|
|
# The first optimize in the nesting should be ignored
|
|
fn2(torch.randn(4))
|
|
self.assertEqual(cnts2.frame_count, 1)
|
|
self.assertEqual(cnts1.frame_count, 0)
|
|
|
|
# Since the fn code object is already compiled, calling fn1 should
|
|
# directly call the compiled_fn callable.
|
|
torch._dynamo.run()(fn1)(torch.randn(4))
|
|
self.assertEqual(cnts1.frame_count, 0)
|
|
|
|
# Test same behavior by reversing the calls
|
|
torch._dynamo.reset()
|
|
cnts1 = torch._dynamo.testing.CompileCounter()
|
|
cnts2 = torch._dynamo.testing.CompileCounter()
|
|
fn1 = torch._dynamo.optimize(cnts1, nopython=True)(fn)
|
|
fn2 = torch._dynamo.optimize(cnts2, nopython=True)(fn1)
|
|
fn1(torch.randn(4))
|
|
self.assertEqual(cnts1.frame_count, 1)
|
|
torch._dynamo.run()(fn2)(torch.randn(4))
|
|
self.assertEqual(cnts2.frame_count, 0)
|
|
|
|
def test_torch_size(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
def fn(x):
|
|
output_size = torch.Size([10, 10])
|
|
x = x.view(*output_size)
|
|
return (x,)
|
|
|
|
x = torch.randn(100, requires_grad=True)
|
|
x_clone = x.clone()
|
|
ref = fn(x)
|
|
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
res = opt_fn(x_clone)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_size_dim(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
def fn(x, dim):
|
|
return x.size(dim=dim)
|
|
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
x = torch.empty([4, 9, 8])
|
|
self.assertTrue(opt_fn(x, 1) == 9)
|
|
self.assertTrue(opt_fn(x, -2) == 9)
|
|
|
|
def test_stride_dim(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
def fn(x, dim):
|
|
return x.stride(dim=dim)
|
|
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
x = torch.empty([4, 9, 8])
|
|
self.assertTrue(opt_fn(x, 0) == 72)
|
|
self.assertTrue(opt_fn(x, -2) == 8)
|
|
|
|
def test_torch_seed(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
def fn(x):
|
|
attention_seed = int(torch.seed() % sys.maxsize)
|
|
torch.manual_seed(attention_seed)
|
|
return (x,)
|
|
|
|
x = torch.randn(100, requires_grad=True)
|
|
ref = fn(x)
|
|
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
res = opt_fn(x)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_is_tensor_like(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
def f(x):
|
|
if torch.overrides.is_tensor_like(x):
|
|
return (x * 2,)
|
|
return (torch.ones(10) + x,)
|
|
|
|
x = torch.randn(10)
|
|
ref0 = f(x)
|
|
ref1 = f(4)
|
|
opt_f = torch._dynamo.optimize(cnts, nopython=True)(f)
|
|
res0 = opt_f(x)
|
|
res1 = opt_f(4)
|
|
self.assertTrue(same(ref0, res0))
|
|
self.assertTrue(same(ref1, res1))
|
|
|
|
def test_is_tensor_like2(self):
|
|
class MyTensor:
|
|
@classmethod
|
|
def __torch_function__(cls, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
|
|
if func is torch.max:
|
|
return torch.tensor(123)
|
|
return func(*args, **kwargs)
|
|
|
|
def fn(x):
|
|
if torch.overrides.is_tensor_like(x):
|
|
return torch.max(x)
|
|
else:
|
|
return torch.zeros(1)
|
|
|
|
x = MyTensor()
|
|
ref0 = fn(x)
|
|
ref1 = fn(4)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res0 = opt_fn(x)
|
|
res1 = opt_fn(4)
|
|
self.assertTrue(same(ref0, res0))
|
|
self.assertTrue(same(ref1, res1))
|
|
|
|
def test_tensor_data(self):
|
|
def fn(x, y):
|
|
return x[y.data]
|
|
|
|
x = torch.rand(8)
|
|
y = torch.ones(8).to(torch.int)
|
|
ref = fn(x, y)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_tensor_layout(self):
|
|
def fn(x):
|
|
return torch.zeros(
|
|
[x.size()[0], x.size()[1]],
|
|
dtype=x.dtype,
|
|
layout=x.layout,
|
|
device=x.device,
|
|
)
|
|
|
|
x = torch.rand(2, 3)
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_version_ci(self):
|
|
# temporary test to check that the ci torch version is set correctly
|
|
self.assertTrue(hasattr(torch, "_subclasses"))
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
|
def test_rand(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
device = "cuda"
|
|
|
|
def fn():
|
|
return torch.randn(10, device=device)
|
|
|
|
torch.manual_seed(10)
|
|
ref_run1 = fn()
|
|
|
|
torch.manual_seed(10)
|
|
ref_run2 = fn()
|
|
self.assertTrue(same(ref_run1, ref_run2))
|
|
|
|
torch.manual_seed(10)
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
res = opt_fn()
|
|
|
|
self.assertTrue(same(res, ref_run1))
|
|
|
|
def test_slice_input(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
def getitem(a, idx):
|
|
if isinstance(idx, slice):
|
|
return (
|
|
torch.zeros(1),
|
|
a[idx]
|
|
+ [
|
|
100,
|
|
],
|
|
)
|
|
else:
|
|
return (torch.zeros(1), a[idx])
|
|
|
|
layers = list(range(10))
|
|
ref0 = getitem(layers, slice(0, 2, 1))
|
|
ref1 = getitem(layers, 2)
|
|
ref2 = getitem(layers, slice(3, 8, 2))
|
|
opt_getitem = torch._dynamo.optimize(cnts, nopython=True)(getitem)
|
|
res0 = opt_getitem(layers, slice(0, 2, 1))
|
|
res1 = opt_getitem(layers, 2)
|
|
res2 = opt_getitem(layers, slice(3, 8, 2))
|
|
|
|
self.assertTrue(ref0 == res0)
|
|
self.assertTrue(ref1 == res1)
|
|
self.assertTrue(ref2 == res2)
|
|
|
|
def test_grad(self):
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
def fn(a, b):
|
|
out = a * b
|
|
out.sum().backward()
|
|
real_out = torch.sigmoid(a.grad + b)
|
|
return real_out
|
|
|
|
inps = [torch.randn(4, requires_grad=True) for _ in range(2)]
|
|
for inp in inps:
|
|
inp.grad = None
|
|
ref = fn(*inps)
|
|
|
|
for inp in inps:
|
|
inp.grad = None
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
res = opt_fn(*inps)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
@skipIfNotPy311
|
|
def test_linetable_311_writer1(self):
|
|
def fn():
|
|
a = 10
|
|
b = 20
|
|
c = a + b
|
|
f = "linetable_writer"
|
|
return f"Test if {f} generates correct co_linetable: {c}"
|
|
|
|
# Dynamo doesn't deal with column locations or end line numbers,
|
|
# so we only check that start line numbers in the linetables match.
|
|
keys = bytecode_transformation.get_code_keys()
|
|
code_options = {k: getattr(fn.__code__, k) for k in keys}
|
|
result = bytecode_transformation.clean_and_assemble_instructions(
|
|
bytecode_transformation.cleaned_instructions(fn.__code__),
|
|
keys,
|
|
code_options,
|
|
)
|
|
l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
|
|
self.assertEqual(len(l1), len(l2))
|
|
for p1, p2 in zip(l1, l2):
|
|
# check that start line numbers match
|
|
self.assertEqual(p1[0], p2[0])
|
|
self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)
|
|
|
|
@skipIfNotPy311
|
|
def test_linetable_311_writer2(self):
|
|
"""
|
|
test large ops (LOAD_METHOD) and EXTENDED_ARGS
|
|
fn_str is in the form:
|
|
def fn():
|
|
...
|
|
x0 = 1
|
|
x1 = 1
|
|
...
|
|
l = [x0, x1, ...]
|
|
"""
|
|
fn_str = f"""\
|
|
def fn():
|
|
foo.bar(1, 2, 3)
|
|
{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))}
|
|
l = [{str(' ').join('x' + str(i) + ',' for i in range(1 << 9))}]
|
|
"""
|
|
locals = {}
|
|
exec(fn_str, {}, locals)
|
|
fn = locals["fn"]
|
|
orig_inst_str = "\n".join(list(map(str, dis.get_instructions(fn))))
|
|
self.assertIn("EXTENDED_ARG", orig_inst_str)
|
|
self.assertIn("LOAD_METHOD", orig_inst_str)
|
|
keys = bytecode_transformation.get_code_keys()
|
|
code_options = {k: getattr(fn.__code__, k) for k in keys}
|
|
result = bytecode_transformation.clean_and_assemble_instructions(
|
|
bytecode_transformation.cleaned_instructions(fn.__code__),
|
|
keys,
|
|
code_options,
|
|
)
|
|
new_inst_str = "\n".join(list(map(str, result[0])))
|
|
self.assertIn("EXTENDED_ARG", new_inst_str)
|
|
self.assertIn("LOAD_METHOD", new_inst_str)
|
|
l1, l2 = list(fn.__code__.co_positions()), list(result[1].co_positions())
|
|
self.assertEqual(len(l1), len(l2))
|
|
for p1, p2 in zip(l1, l2):
|
|
# check that start line numbers match
|
|
self.assertEqual(p1[0], p2[0])
|
|
self.assertEqual(fn.__code__.co_lnotab, result[1].co_lnotab)
|
|
|
|
@unittest.skipIf(
|
|
sys.version_info < (3, 10) or sys.version_info >= (3, 11),
|
|
"linetable test for Python 3.10",
|
|
)
|
|
def test_linetable_310_writer(self):
|
|
def fn():
|
|
a = 10
|
|
b = 20
|
|
c = a + b
|
|
f = "linetable_writer"
|
|
return f"Test if {f} generates correct co_linetable: {c}"
|
|
|
|
inst = dis.get_instructions(fn)
|
|
result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
|
|
self.assertTrue(result[1] == fn.__code__.co_linetable)
|
|
|
|
@unittest.skipIf(sys.version_info >= (3, 10), "use lnotab when python < 3.10")
|
|
def test_lnotab_writer(self):
|
|
def fn():
|
|
a = 10
|
|
b = 20
|
|
c = a + b
|
|
f = "lnotab_writer"
|
|
return f"Test if {f} generates correct co_lnotab: {c}"
|
|
|
|
inst = dis.get_instructions(fn)
|
|
result = bytecode_transformation.assemble(inst, fn.__code__.co_firstlineno)
|
|
self.assertTrue(result[1] == fn.__code__.co_lnotab)
|
|
|
|
def test_profiler_cache_lookup(self):
|
|
def fn(x):
|
|
y = x**2
|
|
y = y + 2
|
|
z = y**3
|
|
return z
|
|
|
|
for profiler, get_events in (
|
|
(torch.autograd.profiler.profile, lambda prof: prof.function_events),
|
|
(torch.profiler.profiler.profile, lambda prof: prof.events()),
|
|
):
|
|
x = torch.randn((2, 2), requires_grad=True)
|
|
ref = fn(x)
|
|
opt_fn = torch.compile(fn, backend="aot_eager")
|
|
|
|
# warmup
|
|
opt_fn(x)
|
|
|
|
# whenver we enter the profiler context, hooks are automatically registered
|
|
with profiler() as prof:
|
|
res = opt_fn(x)
|
|
events = list(
|
|
filter(
|
|
lambda event: event.name == "TorchDynamo Cache Lookup",
|
|
get_events(prof),
|
|
)
|
|
)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
self.assertTrue(
|
|
len(events) == 1,
|
|
"Expected one lookup profiler event for one opt_fn run",
|
|
)
|
|
|
|
with profiler() as prof:
|
|
# just make sure the disable functionality works
|
|
_enable_dynamo_cache_lookup_profiler(False)
|
|
res = opt_fn(x)
|
|
events = list(
|
|
filter(
|
|
lambda event: event.name == "TorchDynamo Cache Lookup",
|
|
get_events(prof),
|
|
)
|
|
)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
self.assertTrue(len(events) == 0, "Expected disabled profiling")
|
|
|
|
def test_tensor_is_contiguous(self):
|
|
def fn(x):
|
|
input = torch.randn((1, 16, 1, 1))
|
|
weight = torch.randn((8, 16, 3, 3))
|
|
weight = weight.to(memory_format=x)
|
|
output = torch.conv2d(input, weight, None, (2, 1), (1, 1), (1, 1), 1)
|
|
return output.is_contiguous(memory_format=x)
|
|
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
for x in [torch.contiguous_format, torch.channels_last]:
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
def test_python_slice(self):
|
|
def f1(input):
|
|
y = 0
|
|
for i, x in enumerate(input[2:], 1):
|
|
y = y + x
|
|
return y
|
|
|
|
def f2(input):
|
|
y = 0
|
|
for i, x in enumerate(input.shape[2:], 1):
|
|
y = y + x
|
|
return y
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_f1 = torch._dynamo.optimize(cnts)(f1)
|
|
opt_f2 = torch._dynamo.optimize(cnts)(f2)
|
|
res1 = opt_f1([1, 2, 3, 5])
|
|
res2 = opt_f2(torch.rand([2, 3, 4, 5]))
|
|
|
|
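# f1 sums input[2:] = [3, 5] -> 8; f2 sums shape[2:] = (4, 5) -> 9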
self.assertEqual(res1, 8)
|
|
self.assertEqual(res2, 9)
|
|
|
|
def test_enum_as_dict_key(self):
|
|
class MyEnum(enum.Enum):
|
|
FOO = 10
|
|
BAR = 20
|
|
|
|
def fn(x):
|
|
y = x + 2
|
|
z = {
|
|
MyEnum.FOO: torch.tensor(1),
|
|
MyEnum.BAR: 10,
|
|
"MyEnum.BAR": torch.tensor(8),
|
|
5: torch.rand(3),
|
|
}
|
|
torch._dynamo.graph_break()
|
|
a = z[MyEnum.FOO] + z["MyEnum.BAR"]
|
|
b = y * 2
|
|
return a, b
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
for _ in range(10):
|
|
x = torch.rand(3)
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
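# the explicit graph_break() splits fn into two frames; the remaining iterations hit the cache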
self.assertEqual(cnts.frame_count, 2)
|
|
|
|
def test_const_dict_variable_python_type(self):
|
|
from torch._dynamo.variables import ConstantVariable, ConstDictVariable
|
|
|
|
d1 = {"a": ConstantVariable(10), "b": ConstantVariable(20)}
|
|
d2 = collections.OrderedDict(
|
|
[("x", ConstantVariable(12)), ("y", ConstantVariable(22))]
|
|
)
|
|
self.assertEqual(ConstDictVariable(d1, dict).python_type(), dict)
|
|
self.assertEqual(
|
|
ConstDictVariable(d2, collections.OrderedDict).python_type(),
|
|
collections.OrderedDict,
|
|
)
|
|
|
|
def test_builtin_subclasses_as_method_on_class_type(self):
|
|
class Foo:
|
|
def __init__(self, name):
|
|
self.name_ = name
|
|
|
|
def get_name(self):
|
|
return "Foo " + self.name_
|
|
|
|
class Bar(Foo):
|
|
def __init__(self, name):
|
|
self.name_ = name
|
|
|
|
def get_name(self):
|
|
return "Bar " + self.name_
|
|
|
|
class Baz(Foo):
|
|
def __init__(self, name): # noqa: B903
|
|
self.name_ = name
|
|
|
|
def get_name(self):
|
|
return "Baz " + self.name_
|
|
|
|
subs_of_foo_reg = Foo.__subclasses__()
|
|
|
|
counter = CompileCounter()
|
|
|
|
@torch._dynamo.optimize_assert(counter)
|
|
def fn():
|
|
return Foo.__subclasses__()
|
|
|
|
subs_of_foo_optim = fn()
|
|
|
|
self.assertEqual(len(subs_of_foo_reg), 2)
|
|
self.assertEqual(subs_of_foo_reg, subs_of_foo_optim)
|
|
|
|
def test_builtin_subclasses_as_method_on_var(self):
|
|
class Foo:
|
|
def __init__(self, name):
|
|
self.name_ = name
|
|
|
|
def get_name(self):
|
|
return "Foo " + self.name_
|
|
|
|
class Bar(Foo):
|
|
def __init__(self, name):
|
|
self.name_ = name
|
|
|
|
def get_name(self):
|
|
return "Bar " + self.name_
|
|
|
|
class Baz(Bar):
|
|
def __init__(self, name):
|
|
self.name_ = name
|
|
|
|
def get_name(self):
|
|
return "Baz " + self.name_
|
|
|
|
subs_of_foo_reg = Foo.__subclasses__()
|
|
sub_of_foo_subclass_var_reg = subs_of_foo_reg[0].__subclasses__()
|
|
|
|
sub_of_foo_subclass_var_optim = list()
|
|
counter = CompileCounter()
|
|
|
|
@torch._dynamo.optimize_assert(counter)
|
|
def fn():
|
|
return Foo.__subclasses__()
|
|
|
|
@torch._dynamo.optimize_assert(counter)
|
|
def fn_single(subs_of_foo_optim):
|
|
return subs_of_foo_optim[0].__subclasses__()
|
|
|
|
subs_of_foo_optim = fn()
|
|
sub_of_foo_subclass_var_optim = fn_single(subs_of_foo_optim)
|
|
|
|
self.assertEqual(len(sub_of_foo_subclass_var_optim), 1)
|
|
self.assertEqual(sub_of_foo_subclass_var_optim, sub_of_foo_subclass_var_reg)
|
|
|
|
def test_enum_no_graphbreaks(self):
|
|
class Foo(enum.Enum):
|
|
FOO = 0
|
|
BAR = 1
|
|
|
|
def fn(x, foo):
|
|
if foo is Foo.FOO:
|
|
x = torch.add(x, 1.0)
|
|
x = torch.mul(x, 1.0)
|
|
return x
|
|
|
|
x = torch.randn(1)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
opt_fn(x, Foo.FOO)
|
|
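# with Foo.FOO both the add and the mul are traced (2 ops); with Foo.BAR below only the mul is (1 op)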
self.assertEqual(cnts.op_count, 2)
|
|
|
|
torch._dynamo.reset()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
opt_fn(x, Foo.BAR)
|
|
self.assertEqual(cnts.op_count, 1)
|
|
|
|
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
|
def test_repeat_interleave_graphbreaks(self):
|
|
def fn_no_breaks(x):
|
|
# no breaks on self_int
|
|
x += 1
|
|
x = torch.repeat_interleave(x, 2, 3)
|
|
x += 1
|
|
return x
|
|
|
|
def fn_has_breaks(x):
|
|
# breaks on self_Tensor
|
|
x += 1
|
|
x = torch.repeat_interleave(x, torch.tensor(2), 3)
|
|
x += 1
|
|
return x
|
|
|
|
x = torch.randn([4, 16, 1, 64])
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn_no_breaks)
|
|
opt_fn(x)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
|
|
torch._dynamo.reset()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn_has_breaks)
|
|
opt_fn(x)
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
|
|
def test_id_of_nn_module(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, ref_id):
|
|
self_id = id(self)
|
|
if self_id == ref_id:
|
|
x = torch.mul(x, 1.0)
|
|
x = torch.add(x, 1.0)
|
|
return x
|
|
|
|
m = M().eval()
|
|
data = torch.randn(1)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
correct_ref_id = id(m)
|
|
opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
|
|
opt_m(data, correct_ref_id)
|
|
# Extra op is the recorded equality test (although once
|
|
# the trace is flattened this is dead!)
|
|
self.assertEqual(cnts.op_count, ifunspec(3, 2))
|
|
|
|
torch._dynamo.reset()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
incorrect_ref_id = id(m) + 1
|
|
opt_m = torch._dynamo.optimize(cnts, nopython=True)(m)
|
|
opt_m(data, incorrect_ref_id)
|
|
self.assertEqual(cnts.op_count, ifunspec(2, 1))
|
|
|
|
def test_inline_func_jump_on_tensor_condition(self):
|
|
def f1(input):
|
|
if input == 0:
|
|
return input + 1
|
|
else:
|
|
return input + 2
|
|
|
|
def f2(input):
|
|
return f1(input)
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_f2 = torch._dynamo.optimize(cnts)(f2)
|
|
res1 = opt_f2(torch.tensor([1.0]))
|
|
res2 = opt_f2(torch.tensor([0.0]))
|
|
|
|
self.assertEqual(res1, 3)
|
|
self.assertEqual(res2, 1)
|
|
|
|
def test_frozenset_torch_func_contains(self):
|
|
funcs = frozenset([torch.add])
|
|
|
|
def fn(x, func):
|
|
if func in funcs:
|
|
x = torch.add(x, 1.0)
|
|
x = torch.mul(x, 1.0)
|
|
return x
|
|
|
|
x = torch.randn(1)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
opt_fn(x, torch.add)
|
|
self.assertEqual(cnts.op_count, 2)
|
|
|
|
torch._dynamo.reset()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
opt_fn(x, torch.mul)
|
|
self.assertEqual(cnts.op_count, 1)
|
|
|
|
def test_inline_list_mutation(self):
|
|
def f1(x):
|
|
x.append(torch.ones(8))
|
|
return x
|
|
|
|
def f2():
|
|
x = [torch.ones(6)]
|
|
f1(x)
|
|
return x
|
|
|
|
res1 = f2()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_f2 = torch._dynamo.optimize(cnts)(f2)
|
|
res2 = opt_f2()
|
|
self.assertTrue(same(res1, res2))
|
|
|
|
def test_inline_dict_mutation(self):
|
|
def f1(d):
|
|
d["c"] = d["a"] + d.pop("b")
|
|
return d
|
|
|
|
def f2():
|
|
d = {"a": torch.ones(5), "b": torch.ones(5)}
|
|
f1(d)
|
|
return d
|
|
|
|
res1 = f2()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_f2 = torch._dynamo.optimize(cnts)(f2)
|
|
res2 = opt_f2()
|
|
self.assertTrue(same(res1, res2))
|
|
|
|
def test_recursive_inline_list_mutation(self):
|
|
def f1(x, y):
|
|
x.append(torch.tensor([1.1]))
|
|
y.append(torch.tensor([1.2]))
|
|
return x, y
|
|
|
|
def f2(x, y):
|
|
x.append(torch.tensor([2.1]))
|
|
y.append(torch.tensor([2.2]))
|
|
f1(x, y)
|
|
return x, y
|
|
|
|
def f3(x):
|
|
x.append(torch.tensor([3.1]))
|
|
y = [torch.tensor([3.2])]
|
|
f2(x, y)
|
|
return x, y
|
|
|
|
def f4():
|
|
x = [torch.tensor([4.1])]
|
|
return f3(x)
|
|
|
|
res1 = f4()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_f4 = torch._dynamo.optimize(cnts)(f4)
|
|
res2 = opt_f4()
|
|
self.assertTrue(same(res1, res2))
|
|
|
|
def test_sample_input(self):
|
|
from torch.testing._internal.common_methods_invocations import SampleInput
|
|
|
|
def fn(sample):
|
|
if isinstance(sample.input, torch.Tensor):
|
|
return sample.input * 2
|
|
return torch.zeros(())
|
|
|
|
sample = SampleInput(torch.ones(2))
|
|
ref = fn(sample)
|
|
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(sample)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_release_input_memory(self):
|
|
x = torch.rand([4])
|
|
x_ref = weakref.ref(x)
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch._dynamo.optimize(cnts)
|
|
def foo(x):
|
|
return x + x
|
|
|
|
out = foo(x)
|
|
self.assertTrue(same(out, x + x))
|
|
del x
|
|
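# with the caller's reference gone, the compiled wrapper must not keep x alive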
self.assertIs(x_ref(), None)
|
|
|
|
def test_release_module_memory(self):
|
|
mod = torch.nn.Linear(10, 10)
|
|
x = torch.rand([10, 10])
|
|
mod_weight_ref = weakref.ref(mod.weight)
|
|
mod_ref = weakref.ref(mod)
|
|
|
|
# Modules that are passed into torch._dynamo optimized functions
|
|
# will normally be held onto through the generated GraphModule,
|
|
# which contains the modules. Remove the reference in this backend
|
|
# and test that no additional references are being held.
|
|
class NoLeakBackend:
|
|
def __call__(self, gm: torch.fx.GraphModule, example_inputs):
|
|
gm.mod = None
|
|
|
|
def foo(*args, **kwargs):
|
|
return (1,)
|
|
|
|
return foo
|
|
|
|
no_leak_backend = NoLeakBackend()
|
|
|
|
@torch._dynamo.optimize(no_leak_backend)
|
|
def foo(mod, x):
|
|
return mod(x)
|
|
|
|
foo(mod, x)
|
|
del mod
|
|
del x
|
|
self.assertIsNone(mod_ref(), None)
|
|
self.assertIsNone(mod_weight_ref(), None)
|
|
|
|
def test_update_locals_and_stack_uses_shared_cache(self):
|
|
def fn(x):
|
|
perm = [0, 3, 5]
|
|
perm = list(range(min(perm))) + perm
|
|
perm.extend(i for i in range(x.dim()) if i not in perm)
|
|
return perm
|
|
|
|
x = torch.rand([2, 2, 2, 2, 2, 2])
|
|
res1 = fn(x)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
res2 = opt_fn(x)
|
|
self.assertTrue(same(res1, res2))
|
|
|
|
def test_dict_reconstruct_keeps_original_order(self):
|
|
def fn():
|
|
modules = collections.OrderedDict([("act", torch.nn.ReLU())])
|
|
module_dict = torch.nn.ModuleDict(modules)
|
|
|
|
next_modules = {"fc4": torch.nn.Linear(5, 6), "act3": torch.nn.Sigmoid()}
|
|
modules.update(next_modules.items())
|
|
module_dict.update(next_modules)
|
|
return modules, module_dict
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
modules, module_dict = opt_fn()
|
|
|
|
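# reconstruction after compilation must keep the same key order in both containers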
self.assertEqual(len(module_dict), len(modules))
|
|
for k1, m2 in zip(modules, module_dict.children()):
|
|
self.assertTrue(modules[k1] is m2)
|
|
|
|
def test_side_effects_codegen_update_mutated(self):
|
|
# codegen to update mutated variables with side effects
|
|
# should come after the stack value's codegen
|
|
def f1(x):
|
|
alist = [x]
|
|
alist.append(x + 1)
|
|
alist[0].sum().item() # graph break
|
|
res = alist.pop()
|
|
res.sum().item() # graph break
|
|
return res
|
|
|
|
def f2(a, b):
|
|
d = {"a": a + 1, "b": b + 2}
|
|
x = d.pop("b")
|
|
x.sum().item() # graph break
|
|
y = d["a"] + x
|
|
y.sum().item() # graph break
|
|
d["c"] = y
|
|
return d
|
|
|
|
x = torch.rand([2, 3])
|
|
a = torch.rand([5, 6])
|
|
b = torch.rand([5, 6])
|
|
res11 = f1(x)
|
|
res21 = f2(a, b)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_f1 = torch._dynamo.optimize(cnts)(f1)
|
|
opt_f2 = torch._dynamo.optimize(cnts)(f2)
|
|
res12 = opt_f1(x)
|
|
res22 = opt_f2(a, b)
|
|
self.assertTrue(same(res11, res12))
|
|
self.assertTrue(same(res21, res22))
|
|
|
|
def test_list_append_return_none(self):
|
|
def fn(x):
|
|
alist = []
|
|
blist = alist.append(x + 1)
|
|
return alist, blist
|
|
|
|
x = torch.tensor([2.3])
|
|
res = fn(x)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
res2 = opt_fn(x)
|
|
self.assertEqual(res, res2)
|
|
|
|
def test_tensor_types(self):
|
|
def fn(dtype, tensor_type):
|
|
x = torch.empty(4, dtype=dtype)
|
|
assert isinstance(x, tensor_type)
|
|
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
opt_fn(torch.float32, torch.FloatTensor)
|
|
opt_fn(torch.float64, torch.DoubleTensor)
|
|
opt_fn(torch.float16, torch.HalfTensor)
|
|
opt_fn(torch.bfloat16, torch.BFloat16Tensor)
|
|
opt_fn(torch.uint8, torch.ByteTensor)
|
|
opt_fn(torch.int8, torch.CharTensor)
|
|
opt_fn(torch.int64, torch.LongTensor)
|
|
opt_fn(torch.int, torch.IntTensor)
|
|
opt_fn(torch.int16, torch.ShortTensor)
|
|
opt_fn(torch.bool, torch.BoolTensor)
|
|
|
|
def test_nan(self):
|
|
def f(x, n):
|
|
return x * 2 + n
|
|
|
|
x = torch.randn(4)
|
|
n = float("nan")
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_f = torch._dynamo.optimize(cnts)(f)
|
|
opt_f(x, n)
|
|
opt_f(x, n)
|
|
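# calling twice with the same NaN input should not trigger a recompile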
self.assertEqual(cnts.frame_count, 1)
|
|
|
|
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
|
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
|
|
def test_item(self):
|
|
class MyMod(torch.nn.Module):
|
|
def forward(self, x):
|
|
z = torch.max(x)
|
|
return z.int().item()
|
|
|
|
x = torch.tensor([[10.6763, 11.7445, -2.2369]])
|
|
model = MyMod()
|
|
y = torch._dynamo.optimize("eager", nopython=True)(model)(x)
|
|
|
|
self.assertEqual(y, 11)
|
|
|
|
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
|
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
|
|
def test_item_changes(self):
|
|
class MyMod(torch.nn.Module):
|
|
def forward(self, x):
|
|
z = torch.max(x)
|
|
return z.int().item()
|
|
|
|
x = torch.tensor([[10.6763, 11.7445, -2.2369]])
|
|
model = MyMod()
|
|
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
|
y = opt_model(x)
|
|
z = opt_model(torch.tensor([[y - 5, y + 10, y + 50]]))
|
|
|
|
self.assertEqual(y, 11)
|
|
self.assertEqual(z, 61)
|
|
|
|
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
|
@patch.object(torch._dynamo.config, "capture_scalar_outputs", True)
|
|
def test_item_changes_new_shape(self):
|
|
class MyMod(torch.nn.Module):
|
|
def forward(self, x):
|
|
z = torch.max(x)
|
|
return z.int().item()
|
|
|
|
x = torch.tensor([[10.6763, 11.7445, -2.2369]])
|
|
model = MyMod()
|
|
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
|
y = opt_model(x)
|
|
z = opt_model(torch.tensor([[y - 5, y + 50], [y + 5, y - 50]]))
|
|
|
|
self.assertEqual(y, 11)
|
|
self.assertEqual(z, 61)
|
|
|
|
@unittest.skip("https://github.com/pytorch/pytorch/issues/99726")
|
|
def test_cross_entropy_loss_fancy_ctor1(self):
|
|
rand_5 = torch.randn(5)
|
|
rand_3_5 = torch.randn(3, 5)
|
|
target = torch.empty(3, dtype=torch.long).random_(5)
|
|
|
|
loss = torch.nn.CrossEntropyLoss(
|
|
weight=rand_5, reduce=False, label_smoothing=0.5
|
|
)
|
|
opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
|
|
input = rand_3_5
|
|
dynamo_output = opt_loss(input, target)
|
|
|
|
loss = torch.nn.CrossEntropyLoss(
|
|
weight=rand_5, reduce=False, label_smoothing=0.5
|
|
)
|
|
input = rand_3_5
|
|
output = loss(input, target)
|
|
|
|
self.assertTrue(torch.allclose(dynamo_output, output))
|
|
|
|
@requires_static_shapes
|
|
def test_cross_entropy_loss_fancy_ctor2(self):
|
|
rand_3_5 = torch.randn(3, 5)
|
|
target = torch.empty(3, dtype=torch.long).random_(5)
|
|
|
|
loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5)
|
|
opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
|
|
input = rand_3_5
|
|
dynamo_output = opt_loss(input, target)
|
|
|
|
loss = torch.nn.CrossEntropyLoss(reduce=False, label_smoothing=0.5)
|
|
input = rand_3_5
|
|
output = loss(input, target)
|
|
|
|
self.assertTrue(torch.allclose(dynamo_output, output))
|
|
|
|
def test_cross_entropy_loss_simple_ctor(self):
|
|
output = None
|
|
rand_3_5 = torch.randn(3, 5)
|
|
target = torch.empty(3, dtype=torch.long).random_(5)
|
|
|
|
loss = torch.nn.CrossEntropyLoss()
|
|
opt_loss = torch._dynamo.optimize("eager", nopython=True)(loss)
|
|
input = rand_3_5
|
|
dynamo_output = opt_loss(input, target)
|
|
|
|
loss = torch.nn.CrossEntropyLoss()
|
|
input = rand_3_5
|
|
output = loss(input, target)
|
|
|
|
self.assertTrue(torch.allclose(dynamo_output, output))
|
|
|
|
def test_nn_functional_reduction(self):
|
|
def fn(loss, reduction):
|
|
reduction_enum = F._Reduction.get_enum(reduction)
|
|
if reduction_enum == 0:
|
|
return loss
|
|
elif reduction_enum == 1:
|
|
return loss.mean()
|
|
elif reduction_enum == 2:
|
|
return loss.sum()
|
|
|
|
x = torch.rand([3, 5])
|
|
y = "mean"
|
|
ref = fn(x, y)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_large_reduction_list(self):
|
|
dtype = torch.float32
|
|
device = "cpu"
|
|
|
|
def check_sum_all(tensor: torch.Tensor) -> None:
|
|
pylist = tensor.reshape(-1).tolist()
|
|
self.assertTrue(same(tensor.sum(), torch.tensor(sum(pylist))))
|
|
|
|
check_sum_all(torch.randn(200000, dtype=dtype, device=device))
|
|
|
|
def test_raise_on_backend_error(self):
|
|
def my_compiler(gm, _):
|
|
raise RuntimeError("duck!")
|
|
|
|
@torch._dynamo.optimize(my_compiler)
|
|
def fn(a, b):
|
|
return a + b / (a - b)
|
|
|
|
self.assertRaises(
|
|
torch._dynamo.exc.BackendCompilerFailed,
|
|
lambda: fn(torch.randn(10), torch.randn(10)),
|
|
)
|
|
|
|
def test_named_parameters(self):
|
|
n_embd = 768
|
|
block_size = 128
|
|
vocab_size = 65
|
|
embd_pdrop = 0.1
|
|
|
|
class MyModel2(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
|
|
self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
|
|
self.drop = torch.nn.Dropout(embd_pdrop)
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
class MyModel(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
|
|
self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
|
|
self.drop = torch.nn.Dropout(embd_pdrop)
|
|
self.submod2 = MyModel2()
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
# Regular
|
|
params = []
|
|
mod = MyModel()
|
|
actual_params = list(mod.named_parameters())
|
|
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def fn():
|
|
return list(mod.named_parameters())
|
|
|
|
params = fn()
|
|
|
|
self.assertEqual(len(actual_params), len(params))
|
|
for idx in range(len(params)):
|
|
k_a, v_a = actual_params[idx]
|
|
k, v = params[idx]
|
|
self.assertEqual(k_a, k)
|
|
self.assertTrue(torch.allclose(v_a, v))
|
|
|
|
# Prefix
|
|
params = []
|
|
mod = MyModel()
|
|
actual_params = list(mod.named_parameters(prefix="foo"))
|
|
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def fn1():
|
|
return list(mod.named_parameters(prefix="foo"))
|
|
|
|
params = fn1()
|
|
|
|
self.assertEqual(len(actual_params), len(params))
|
|
for idx in range(len(params)):
|
|
k_a, v_a = actual_params[idx]
|
|
k, v = params[idx]
|
|
self.assertEqual(k_a, k)
|
|
self.assertTrue(torch.allclose(v_a, v))
|
|
|
|
def test_module_complex_iter(self):
|
|
n_embd = 768
|
|
block_size = 128
|
|
vocab_size = 65
|
|
embd_pdrop = 0.1
|
|
|
|
class FakeGPT(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.tok_emb = torch.nn.Embedding(vocab_size, n_embd)
|
|
self.pos_emb = torch.nn.Parameter(torch.zeros(1, block_size, n_embd))
|
|
self.drop = torch.nn.Dropout(embd_pdrop)
|
|
self.ln_f = torch.nn.LayerNorm(n_embd)
|
|
self.head = torch.nn.Linear(n_embd, vocab_size, bias=False)
|
|
|
|
self.block_size = block_size
|
|
self.names = []
|
|
|
|
def forward(self, idx, targets=None):
|
|
b, t = idx.size()
|
|
assert (
|
|
t <= self.block_size
|
|
), "Cannot forward, model block size is exhausted."
|
|
|
|
# forward the GPT model
|
|
token_embeddings = self.tok_emb(
|
|
idx
|
|
) # each index maps to a (learnable) vector
|
|
position_embeddings = self.pos_emb[
|
|
:, :t, :
|
|
] # each position maps to a (learnable) vector
|
|
x = self.drop(token_embeddings + position_embeddings)
|
|
x = self.blocks(x)
|
|
x = self.ln_f(x)
|
|
logits = self.head(x)
|
|
|
|
# if we are given some desired targets also calculate the loss
|
|
loss = None
|
|
if targets is not None:
|
|
loss = F.cross_entropy(
|
|
logits.view(-1, logits.size(-1)), targets.view(-1)
|
|
)
|
|
|
|
return logits, loss
|
|
|
|
def foo(self, memo=None, prefix="", remove_duplicate=False):
|
|
for mn, m in self.named_modules(
|
|
memo=memo, prefix=prefix, remove_duplicate=remove_duplicate
|
|
):
|
|
for pn, p in self.named_parameters():
|
|
fpn = "%s.%s" % (mn, pn) if mn else pn
|
|
self.names.append(fpn)
|
|
|
|
# Test plain recurse
|
|
model_a = FakeGPT()
|
|
model_a.foo()
|
|
a_names = model_a.names
|
|
|
|
model_b = FakeGPT()
|
|
opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b)
|
|
opt_model_b.foo()
|
|
|
|
self.assertEqual(a_names, model_b.names)
|
|
|
|
# Test with prefix
|
|
model_a = FakeGPT()
|
|
model_a.foo(prefix="abc")
|
|
a_names = model_a.names
|
|
|
|
model_b = FakeGPT()
|
|
opt_model_b = torch._dynamo.optimize("eager", nopython=True)(model_b)
|
|
opt_model_b.foo(prefix="abc")
|
|
|
|
self.assertEqual(a_names, model_b.names)
|
|
|
|
def test_numpy_variable_isinstance(self):
|
|
def fn(x, m):
|
|
if isinstance(m, np.ndarray):
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
x = torch.tensor([2.3])
|
|
m = np.array([1, 2, 3])
|
|
ref = fn(x, m)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
res = opt_fn(x, m)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_tensor_dot_grad_no_graph_break(self):
|
|
def fn(a, b):
|
|
y = 3 * a**3 - b**2
|
|
y.backward(gradient=torch.tensor([1.0, 1.0]))
|
|
b.grad.zero_()
|
|
return a.grad, b.grad
|
|
|
|
a = torch.tensor([2.0, 3.0], requires_grad=True)
|
|
b = torch.tensor([6.0, 4.0], requires_grad=True)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts)(fn)
|
|
_, b_grad = opt_fn(a, b)
|
|
self.assertTrue(same(b_grad, torch.tensor([0.0, 0.0])))
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
|
|
def test_torch_nn_parameter_isinstance(self):
|
|
def fn(x):
|
|
a = torch.nn.Parameter(torch.rand(2, 3))
|
|
if isinstance(a, torch.Tensor):
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
x = torch.tensor([2.5])
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
@torch._dynamo.config.patch(raise_on_backend_change=True)
|
|
def test_change_backends(self):
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def fn1():
|
|
return x + 1
|
|
|
|
@torch._dynamo.optimize("ts")
|
|
def fn2():
|
|
return x + 2
|
|
|
|
@torch._dynamo.optimize("eager", nopython=False)
|
|
def fn3():
|
|
return x + 1
|
|
|
|
x = torch.tensor([3, 5])
|
|
|
|
fn1()
|
|
fn1()
|
|
fn3()
|
|
self.assertRaises(torch._dynamo.exc.ResetRequired, fn2)
|
|
fn1()
|
|
torch._dynamo.reset()
|
|
fn2()
|
|
fn2()
|
|
self.assertRaises(torch._dynamo.exc.ResetRequired, fn1)
|
|
self.assertRaises(torch._dynamo.exc.ResetRequired, fn3)
|
|
fn2()
|
|
|
|
def test_dynamo_min_operator_with_shape(self):
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def f(x, a):
|
|
return min(x.shape[0], a)
|
|
|
|
result = f(torch.ones(6), 3)
|
|
self.assertEqual(result, 3)
|
|
|
|
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
|
def test_onnx_shape_as_tensor(self):
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def f(x):
|
|
return 1 + torch._shape_as_tensor(x)[0]
|
|
|
|
gm, _ = torch._dynamo.export(f, torch.ones(6))
|
|
|
|
input_one_dim = torch.ones(6)
|
|
input_two_dims = torch.ones(7, 4)
|
|
self.assertEqual(f(input_one_dim), 7)
|
|
self.assertEqual(f(input_two_dims), 8)
|
|
self.assertEqual(f(input_two_dims), 8)
|
|
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def f_onnx(x):
|
|
return 1 + torch.onnx.operators.shape_as_tensor(x)[0]
|
|
|
|
self.assertEqual(f_onnx(input_one_dim), 7)
|
|
self.assertEqual(f_onnx(input_two_dims), 8)
|
|
self.assertEqual(f_onnx(input_two_dims), 8)
|
|
|
|
def test_cond(self):
|
|
from functorch.experimental.control_flow import cond
|
|
|
|
def true_fn(x):
|
|
return x.sin()
|
|
|
|
def false_fn(x):
|
|
return x.cos()
|
|
|
|
def f(pred, x):
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
opt_fn = torch._dynamo.optimize("eager")(f)
|
|
a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25]))
|
|
self.assertTrue(same(torch.cos(torch.tensor([0.25, 0.25])), a))
|
|
b = opt_fn(torch.tensor(True), torch.tensor([0.25, 0.25]))
|
|
self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), b))
|
|
|
|
def test_nonzero_static(self):
|
|
# invalid size
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
|
|
):
|
|
torch.nonzero_static(torch.tensor([8]), size=-2)
|
|
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "nonzero_static: 'size' must be an non-negative integer"
|
|
):
|
|
torch.nonzero_static(torch.tensor([8]), size=-2, out=torch.tensor(0))
|
|
|
|
# nonzero_static.out: out dtype mismatch
|
|
input_tensor = torch.tensor([8])
|
|
static_size = 1
|
|
out_tensor = torch.empty((static_size, input_tensor.dim()), dtype=torch.float)
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "nonzero_static: Expected out tensor to have scalar type Long"
|
|
):
|
|
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor)
|
|
|
|
# nonzero_static.out: out resize (shrink)
|
|
input_tensor = torch.tensor([8])
|
|
static_size = 1
|
|
out_tensor = torch.empty((10, 10, 10, 10), dtype=torch.long)
|
|
self.assertTrue(
|
|
same(
|
|
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
|
|
torch.tensor([0]),
|
|
)
|
|
)
|
|
self.assertTrue(
|
|
same(
|
|
out_tensor,
|
|
torch.tensor([0]),
|
|
)
|
|
)
|
|
|
|
# nonzero_static.out: out resize (enlarge)
|
|
input_tensor = torch.tensor([8])
|
|
static_size = 1
|
|
out_tensor = torch.empty((0), dtype=torch.long)
|
|
self.assertTrue(
|
|
same(
|
|
torch.nonzero_static(input_tensor, size=static_size, out=out_tensor),
|
|
torch.tensor([0]),
|
|
)
|
|
)
|
|
self.assertTrue(
|
|
same(
|
|
out_tensor,
|
|
torch.tensor([0]),
|
|
)
|
|
)
|
|
|
|
# 0 rank
|
|
input_tensor = torch.tensor(6)
|
|
static_size = 2
|
|
self.assertTrue(
|
|
same(
|
|
torch.nonzero_static(input_tensor, size=static_size),
|
|
torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
|
|
)
|
|
)
|
|
|
|
# 0 size
|
|
input_tensor = torch.tensor([[[1]]])
|
|
static_size = 0
|
|
self.assertTrue(
|
|
same(
|
|
torch.nonzero_static(input_tensor, size=static_size),
|
|
torch.empty((static_size, input_tensor.dim()), dtype=torch.long),
|
|
)
|
|
)
|
|
|
|
# 1D input
|
|
input_tensor = torch.tensor([0, 8])
|
|
static_size = 1
|
|
self.assertTrue(
|
|
same(
|
|
torch.nonzero_static(input_tensor, size=static_size),
|
|
torch.tensor([1]),
|
|
)
|
|
)
|
|
|
|
input_tensor = torch.tensor([8, 0])
|
|
static_size = 2
|
|
self.assertTrue(
|
|
same(
|
|
torch.nonzero_static(input_tensor, size=static_size),
|
|
torch.tensor([[0], [-1]]), # padded with default fill_value "-1"
|
|
)
|
|
)
|
|
|
|
# 2D input
|
|
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
|
|
static_size = 5
|
|
fill_value = -100
|
|
self.assertTrue(
|
|
torch._dynamo.utils.same(
|
|
torch.nonzero_static(
|
|
input_tensor, size=static_size, fill_value=fill_value
|
|
),
|
|
torch.tensor(
|
|
[
|
|
[0, 0],
|
|
[1, 0],
|
|
[1, 1],
|
|
[fill_value, fill_value],
|
|
[fill_value, fill_value],
|
|
]
|
|
),
|
|
)
|
|
)
|
|
input_tensor = torch.tensor([[1.2, 0], [3.4, 5.6]])
|
|
static_size = 2
|
|
fill_value = -100
|
|
self.assertTrue(
|
|
torch._dynamo.utils.same(
|
|
torch.nonzero_static(
|
|
input_tensor, size=static_size, fill_value=fill_value
|
|
),
|
|
torch.tensor([[0, 0], [1, 0]]),
|
|
)
|
|
)
|
|
|
|
# 3D input
|
|
input_tensor = torch.tensor([[[0, 0], [0, -3]], [[0, 0], [5, 0]]])
|
|
static_size = 4
|
|
fill_value = -999
|
|
self.assertTrue(
|
|
torch._dynamo.utils.same(
|
|
torch.nonzero_static(
|
|
input_tensor,
|
|
size=static_size,
|
|
fill_value=fill_value,
|
|
),
|
|
torch.tensor(
|
|
[
|
|
[0, 1, 1],
|
|
[1, 1, 0],
|
|
[fill_value, fill_value, fill_value],
|
|
[fill_value, fill_value, fill_value],
|
|
]
|
|
),
|
|
)
|
|
)
|
|
|
|
def test_cond_with_quantization(self):
|
|
from functorch.experimental.control_flow import cond
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
example_inputs = (torch.randn(5, 5),)
|
|
self.model = torch.nn.Linear(5, 5)
|
|
self.quantized_model = prepare_qat_fx(
|
|
self.model, qconfig_dict, example_inputs=example_inputs
|
|
)
|
|
|
|
def forward(self, pred, x):
|
|
def true_fn(x):
|
|
return x.sin() + self.quantized_model(x)
|
|
|
|
def false_fn(x):
|
|
return x.cos() + self.model(x)
|
|
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
module = MyModule()
|
|
opt_m = torch._dynamo.optimize("eager", nopython=True)(module)
|
|
x = torch.rand((5, 5))
|
|
pred = torch.tensor(True)
|
|
self.assertTrue(same(module(pred, x), opt_m(pred, x)))
|
|
pred = torch.tensor(False)
|
|
self.assertTrue(same(module(pred, x), opt_m(pred, x)))
|
|
|
|
def test_map_with_quantization(self):
|
|
from functorch.experimental.control_flow import map
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
example_inputs = (torch.randn(5, 5),)
|
|
self.model = torch.nn.Linear(5, 5)
|
|
self.quantized_model = prepare_qat_fx(
|
|
self.model, qconfig_dict, example_inputs=example_inputs
|
|
)
|
|
|
|
def forward(self, x):
|
|
def body(x):
|
|
return x.sin() + self.quantized_model(x)
|
|
|
|
return map(body, x)
|
|
|
|
module = MyModule()
|
|
opt_m = torch._dynamo.optimize("eager", nopython=True)(module)
|
|
x = torch.rand((5, 5))
|
|
self.assertTrue(same(module(x), opt_m(x)))
|
|
|
|
def test_cond_side_effects(self):
|
|
from functorch.experimental.control_flow import cond
|
|
|
|
c = 0
|
|
|
|
def true_fn(x):
|
|
return x - c
|
|
|
|
def false_fn(x):
|
|
return x + c
|
|
|
|
def f(pred, x):
|
|
nonlocal c
|
|
c = 1
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
opt_fn = torch._dynamo.optimize("eager")(f)
|
|
c = 0
|
|
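# f sets the closed-over c to 1 before calling cond, so false_fn computes x + 1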
a = opt_fn(torch.tensor(False), torch.tensor([0.25, 0.25]))
|
|
self.assertTrue(same(torch.tensor([1.25, 1.25]), a))
|
|
|
|
def test_map_side_effects(self):
|
|
from functorch.experimental.control_flow import map
|
|
|
|
class Module(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.w = torch.tensor(1)
|
|
|
|
def forward(self, xs):
|
|
def body(x):
|
|
self.w += 1
|
|
return x
|
|
|
|
return map(body, xs)
|
|
|
|
mod = Module()
|
|
with self.assertRaisesRegex(
|
|
TypeError, "missing 1 required positional argument"
|
|
):
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(mod)
|
|
opt_fn(torch.randn(3, 2))
|
|
|
|
def test_cond_nested(self):
|
|
from functorch.experimental.control_flow import cond
|
|
|
|
def true_fn_nested(x):
|
|
return x * 10
|
|
|
|
def false_fn_nested(x):
|
|
return x * -1
|
|
|
|
def true_fn(pred2, x):
|
|
return x.sin()
|
|
|
|
def false_fn(pred2, x):
|
|
return x + cond(pred2, true_fn_nested, false_fn_nested, [x])
|
|
|
|
def f(pred, pred2, x):
|
|
return cond(pred, true_fn, false_fn, [pred2, x])
|
|
|
|
cc = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cc)(f)
|
|
true_true_sin = opt_fn(
|
|
torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
|
|
)
|
|
self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin))
|
|
|
|
true_false_sin = opt_fn(
|
|
torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25])
|
|
)
|
|
self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin))
|
|
|
|
false_true_sum_mult = opt_fn(
|
|
torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
|
|
)
|
|
self.assertTrue(
|
|
same(torch.tensor([2.75, 2.75]), false_true_sum_mult)
|
|
) # * 10 then add x
|
|
|
|
false_false_sum_neg = opt_fn(
|
|
torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25])
|
|
)
|
|
self.assertTrue(
|
|
same(torch.tensor([0.0, 0.0]), false_false_sum_neg)
|
|
) # * -1 then add x
|
|
self.assertEqual(cc.frame_count, 2)
|
|
|
|
def test_cond_export(self):
|
|
from functorch.experimental.control_flow import cond
|
|
|
|
def true_fn_nested(x):
|
|
return x * 10
|
|
|
|
def false_fn_nested(x):
|
|
return x * -1
|
|
|
|
def true_fn(pred2, x):
|
|
return x.sin()
|
|
|
|
def false_fn(pred2, x):
|
|
return x + cond(pred2, true_fn_nested, false_fn_nested, [x])
|
|
|
|
def f(pred, pred2, x):
|
|
return cond(pred, true_fn, false_fn, [pred2, x])
|
|
|
|
graph, guard = torch._dynamo.export(
|
|
f, torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
|
|
)
|
|
true_true_sin = graph(
|
|
torch.tensor(True), torch.tensor(True), torch.tensor([0.25, 0.25])
|
|
)
|
|
self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_true_sin))
|
|
|
|
true_false_sin = graph(
|
|
torch.tensor(True), torch.tensor(False), torch.tensor([0.25, 0.25])
|
|
)
|
|
self.assertTrue(same(torch.sin(torch.tensor([0.25, 0.25])), true_false_sin))
|
|
|
|
false_true_sum_mult = graph(
|
|
torch.tensor(False), torch.tensor(True), torch.tensor([0.25, 0.25])
|
|
)
|
|
self.assertTrue(
|
|
same(torch.tensor([2.75, 2.75]), false_true_sum_mult)
|
|
) # * 10 then add x
|
|
|
|
false_false_sum_neg = graph(
|
|
torch.tensor(False), torch.tensor(False), torch.tensor([0.25, 0.25])
|
|
)
|
|
self.assertTrue(
|
|
same(torch.tensor([0.0, 0.0]), false_false_sum_neg)
|
|
) # * -1 then add x
|
|
|
|
def test_cond_export_single_arg(self):
|
|
from functorch.experimental.control_flow import cond
|
|
|
|
def true_fn(x):
|
|
return x
|
|
|
|
def false_fn(x):
|
|
return x.sin()
|
|
|
|
def f(pred, x):
|
|
return cond(pred, true_fn, false_fn, [x])
|
|
|
|
graph, guard = torch._dynamo.export(
|
|
f, torch.tensor(False), torch.tensor([0.25, 0.25])
|
|
)
|
|
true_mirror = graph(torch.tensor(True), torch.tensor([0.25, 0.25]))
|
|
self.assertTrue(same(torch.tensor([0.25, 0.25]), true_mirror))
|
|
true_mirror_2 = graph(torch.tensor(True), torch.tensor([0.33, 0.33, 0.33]))
|
|
self.assertTrue(same(torch.tensor([0.33, 0.33, 0.33]), true_mirror_2))
|
|
|
|
false_sin = graph(torch.tensor(False), torch.tensor([0.5, 0.5]))
|
|
self.assertTrue(same(torch.sin(torch.tensor([0.5, 0.5])), false_sin))
|
|
|
|
def test_enum_guards(self):
|
|
class MyEnum(enum.Enum):
|
|
FOO = 10
|
|
BAR = 20
|
|
|
|
def fn(x, y):
|
|
if y == MyEnum.FOO:
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
x = torch.rand(3)
|
|
y = MyEnum.BAR
|
|
ref = fn(x, y)
|
|
opt_fn = torch.compile(backend="eager")(fn)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
@patch.object(torch._dynamo.config, "print_graph_breaks", True)
|
|
def test_duplicate_graph_break_warning(self):
|
|
@torch._dynamo.optimize("eager")
|
|
def f1(a, b):
|
|
f2(a, b)
|
|
|
|
def f2(a, b):
|
|
c = a + b
|
|
print("break")
|
|
return a + b + c
|
|
|
|
@torch._dynamo.optimize("eager")
|
|
def g1(a, b):
|
|
g2(a, b)
|
|
|
|
def g2(a, b):
|
|
c = a + b
|
|
print("break")
|
|
return a + b + c
|
|
|
|
def count_graph_break_msgs(msgs):
|
|
return sum(msg.find("Graph break") != -1 for msg in msgs)
|
|
|
|
with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log:
|
|
torch._dynamo.config.verbose = True
|
|
f1(torch.randn(10), torch.randn(10))
|
|
self.assertGreater(count_graph_break_msgs(log.output), 1)
|
|
|
|
with self.assertLogs(logger="torch._dynamo", level=logging.WARNING) as log:
|
|
torch._dynamo.config.verbose = False
|
|
g1(torch.randn(10), torch.randn(10))
|
|
self.assertEqual(count_graph_break_msgs(log.output), 1)
|
|
|
|
def test_inplace_param_update(self):
|
|
def fn(param, y):
|
|
prev_grad = torch.is_grad_enabled()
|
|
try:
|
|
torch.set_grad_enabled(False)
|
|
torch.set_grad_enabled(True)
|
|
torch.set_grad_enabled(False)
|
|
param.add_(y)
|
|
finally:
|
|
torch.set_grad_enabled(prev_grad)
|
|
|
|
y = torch.randn(4)
|
|
x = torch.nn.Parameter(torch.randn(4))
|
|
fn(x, y)
|
|
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
opt_fn(x, y)
|
|
self.assertEqual(cnts.frame_count, 1)
|
|
self.assertEqual(cnts.op_count, 3)
|
|
|
|
@unittest.skipIf(
|
|
not PLATFORM_SUPPORTS_FUSED_SDPA or not SM80OrLater,
|
|
"Can't run fused SDPA on this platform",
|
|
)
|
|
def test_parsing_sdpa(self):
|
|
class MyModule(torch.nn.Module):
|
|
def forward(self, query, key, value):
|
|
out = F.scaled_dot_product_attention(query, key, value, None, 0, True)
|
|
out = F.scaled_dot_product_attention(
|
|
query, key, value, None, 0, True, scale=8
|
|
)
|
|
out = F.scaled_dot_product_attention(
|
|
query=query,
|
|
key=key,
|
|
value=value,
|
|
attn_mask=None,
|
|
dropout_p=0,
|
|
is_causal=True,
|
|
)
|
|
out = F.scaled_dot_product_attention(
|
|
query,
|
|
key=key,
|
|
value=value,
|
|
attn_mask=None,
|
|
dropout_p=0,
|
|
is_causal=True,
|
|
)
|
|
out = F.scaled_dot_product_attention(
|
|
query, key, value, None, dropout_p=0, is_causal=True
|
|
)
|
|
out = F.scaled_dot_product_attention(query, key, value, None, scale=8)
|
|
return out
|
|
|
|
device = "cuda"
|
|
dtype = torch.float16
|
|
seq_len_q = 1
|
|
seq_len_k = 1
|
|
head_dim = 8
|
|
query = torch.ones(
|
|
1, 8, seq_len_q, head_dim, device=device, dtype=dtype, requires_grad=True
|
|
)
|
|
key = torch.ones(
|
|
1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
|
|
)
|
|
value = torch.ones(
|
|
1, 8, seq_len_k, head_dim, device=device, dtype=dtype, requires_grad=True
|
|
)
|
|
module = MyModule()
|
|
opt_mod = torch._dynamo.optimize("inductor")(module)
|
|
opt_mod(query, key, value)
|
|
|
|
def test_generate_tensor_from_list_of_numpy_primitive_type(self):
|
|
# Test something like torch.LongTensor([np.int64, np.int64, ...])
|
|
def fn():
|
|
x = np.array([1, 2, 3, 4, 5, 6], dtype=np.int64)
|
|
y = [x[0], x[2], x[4]]
|
|
z = torch.LongTensor(y)
|
|
return z
|
|
|
|
ref = fn()
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn()
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_autograd_function_equivalence(self):
|
|
for i in range(1, 5):
|
|
model = globals()[f"Module{i}"]()
|
|
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
|
self.assertTrue(
|
|
torch.allclose(opt_model(torch.ones(2, 3)), torch.tensor([2.0]))
|
|
)
|
|
|
|
def test_autograd_function_has_graph_break(self):
|
|
x = torch.randn(10)
|
|
for model in [Module5(), Module6()]:
|
|
torch._dynamo.reset()
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize(cnts)(model)
|
|
for _ in range(3):
|
|
ref = model(x)
|
|
res = opt_model(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
self.assertEqual(cnts.frame_count, 2)
|
|
|
|
def test_object_classmethod(self):
|
|
class C:
|
|
@classmethod
|
|
def fn(cls, x):
|
|
return x + x
|
|
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def f():
|
|
return C().fn(torch.ones(2, 3))
|
|
|
|
self.assertTrue(torch.allclose(f(), torch.tensor([2.0])))
|
|
|
|
def test_object_staticmethod(self):
|
|
class C:
|
|
@staticmethod
|
|
def fn(x):
|
|
return x + x
|
|
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def f():
|
|
return C().fn(torch.ones(2, 3))
|
|
|
|
self.assertTrue(torch.allclose(f(), torch.tensor([2.0])))
|
|
|
|
def test_user_function_variable_supports_enum_argument(self):
|
|
class Foo(enum.Enum):
|
|
FOO = 0
|
|
BAR = 1
|
|
|
|
def gn(x, y=Foo.FOO):
|
|
if y is Foo.FOO:
|
|
return x
|
|
else:
|
|
return x + 1
|
|
|
|
def fn(x):
|
|
return gn(x)
|
|
|
|
x = torch.randn(2, 3)
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_user_function_variable_supports_type_abcmeta_argument(self):
|
|
class Foo(metaclass=abc.ABCMeta):
|
|
@abc.abstractclassmethod
|
|
def read(self): # noqa: B027
|
|
pass
|
|
|
|
class Bar(Foo):
|
|
def read(self):
|
|
return "Hello World!"
|
|
|
|
class Baz:
|
|
pass
|
|
|
|
def gn(x, tys=(Bar, Baz)):
|
|
if Bar in tys:
|
|
return x - 1
|
|
else:
|
|
return x + 1
|
|
|
|
def fn(x):
|
|
return gn(x)
|
|
|
|
x = torch.randn(2, 3)
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_user_function_variable_supports_function_argument(self):
|
|
# Test that user-defined function default arguments can be:
|
|
# 1. user-defined functions (e.g., add1)
|
|
# 2. torch functions (e.g., torch.sin)
|
|
# 3. python builtin functions (e.g., operator.neg)
|
|
def add1(x):
|
|
return x + 1
|
|
|
|
def gn(x, f1=add1, f2=torch.sin, f3=operator.neg):
|
|
return f3(f2(f1(x)))
|
|
|
|
def fn(x):
|
|
return gn(x)
|
|
|
|
x = torch.randn(2, 3)
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_typing_variable_isinstance(self):
|
|
def fn(x, m):
|
|
if isinstance(m, typing.Mapping):
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
x = torch.randn(2, 3)
|
|
m = {"x": torch.randn(3)}
|
|
ref = fn(x, m)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x, m)
|
|
self.assertTrue(torch.allclose(ref, res))
|
|
|
|
def test_repro_graph_breaks_in__get_item_by_idx(self):
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mod = torch.nn.Sequential(
|
|
torch.nn.Linear(3, 3), torch.nn.Linear(3, 3)
|
|
)
|
|
|
|
def forward(self, x):
|
|
return self.mod[0](x)
|
|
|
|
m = Mod()
|
|
graph, _ = torch._dynamo.export(m, torch.randn(3, 3))
|
|
|
|
def test_nn_sequential_invocation(self):
|
|
with freeze_rng_state():
|
|
|
|
class TestModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linears = torch.nn.Sequential(
|
|
torch.nn.Linear(2, 2),
|
|
torch.nn.Linear(2, 2),
|
|
torch.nn.Linear(2, 2),
|
|
torch.nn.Linear(2, 2),
|
|
)
|
|
|
|
def forward(self, x):
|
|
all_but_last = self.linears[:-1]
|
|
return all_but_last(x)
|
|
|
|
m = TestModel()
|
|
x = torch.rand((2, 2))
|
|
real = m(x)
|
|
graph, _ = torch._dynamo.export(m, x)
|
|
dynamo_result = graph(x)
|
|
self.assertTrue(same(real, dynamo_result))
|
|
|
|
def test_nn_sequential_invocation_reposition_indices(self):
|
|
with freeze_rng_state():
|
|
|
|
class TestModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linears = torch.nn.Sequential(
|
|
torch.nn.Linear(2, 2),
|
|
torch.nn.Linear(2, 2),
|
|
torch.nn.Linear(2, 2),
|
|
torch.nn.Linear(2, 2),
|
|
)
|
|
|
|
def forward(self, x):
|
|
all_but_last = self.linears[1:3]
|
|
return all_but_last(x)
|
|
|
|
m = TestModel()
|
|
x = torch.rand((2, 2))
|
|
real = m(x)
|
|
graph, _ = torch._dynamo.export(m, x)
|
|
dynamo_result = graph(x)
|
|
self.assertTrue(same(real, dynamo_result))
|
|
|
|
def test_error_on_nested_fx_trace(self):
|
|
input = torch.rand(2, 3)
|
|
|
|
def f(x):
|
|
x + x
|
|
|
|
real = f(input)
|
|
|
|
optimized = torch._dynamo.optimize("eager")(f)
|
|
self.assertTrue(same(optimized(input), real))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Detected that you are using FX"):
|
|
gm = torch.fx.symbolic_trace(optimized)
|
|
|
|
@patch.object(torch._dynamo.config, "error_on_nested_fx_trace", False)
|
|
def test_no_error_on_nested_fx_trace(self):
|
|
input = torch.rand(2, 3)
|
|
|
|
def f(x):
|
|
x + x
|
|
|
|
real = f(input)
|
|
|
|
optimized = torch._dynamo.optimize("eager")(f)
|
|
self.assertTrue(same(optimized(input), real))
|
|
|
|
# should not error
|
|
gm = torch.fx.symbolic_trace(optimized)
|
|
self.assertTrue(same(gm(input), real))
|
|
|
|
def test_not_dynamic_scope(self):
|
|
def f(y):
|
|
x = 1
|
|
|
|
def g():
|
|
x = 2
|
|
return lambda: x
|
|
|
|
return y + g()()
|
|
|
|
input = torch.zeros(1)
|
|
real = f(input)
|
|
optimized = torch._dynamo.optimize("eager")(f)
|
|
opt = optimized(input)
|
|
self.assertTrue(same(opt, real))
|
|
|
|
def test_inference_mode(self):
|
|
@torch.inference_mode()
|
|
def func(x, y):
|
|
return x.add(1.0) + y
|
|
|
|
x = torch.ones(4, requires_grad=True)
|
|
y = torch.ones(4, requires_grad=True)
|
|
ref = func(x, y)
|
|
opt_func = torch._dynamo.optimize("eager")(func)
|
|
|
|
x1 = torch.ones(4, requires_grad=True)
|
|
res = opt_func(x1, y)
|
|
self.assertTrue(same(ref, res))
|
|
self.assertTrue(same(x, x1))
|
|
|
|
def test_if_cond_nn_mod(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self, output_relu=True):
|
|
super().__init__()
|
|
self.relu = torch.nn.ReLU() if output_relu else None
|
|
|
|
def forward(self, x):
|
|
x = torch.sin(x)
|
|
if self.relu:
|
|
x = self.relu(x)
|
|
return x
|
|
|
|
model = MockModule()
|
|
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
|
|
|
x = torch.rand(4)
|
|
ref = model(x)
|
|
res = opt_model(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
model = MockModule(output_relu=False)
|
|
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
|
|
|
x = torch.rand(4)
|
|
ref = model(x)
|
|
res = opt_model(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_if_cond_user_defined_object(self):
|
|
# obj.__bool__ does not exist
|
|
class A: # noqa: B903
|
|
def __init__(self, x):
|
|
self.x = x
|
|
|
|
# obj.__bool__ is a function that returns a bool
|
|
class B:
|
|
def __init__(self, x):
|
|
self.x = x
|
|
|
|
def __bool__(self):
|
|
return self.x > 0
|
|
|
|
# obj.__bool__ is not a function
|
|
class C:
|
|
def __init__(self, x):
|
|
self.x = x
|
|
self.__bool__ = False
|
|
|
|
def fn(x, obj):
|
|
if not obj:
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
x = torch.rand(4)
|
|
cnts = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnts, nopython=True)(fn)
|
|
obj1 = A(0.5)
|
|
obj2 = B(0.5)
|
|
obj3 = B(-0.5)
|
|
obj4 = C(0.5)
|
|
for obj in [obj1, obj2, obj3, obj4, obj3, obj2]:
|
|
ref = fn(x, obj)
|
|
res = opt_fn(x, obj)
|
|
self.assertTrue(same(ref, res))
|
|
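# six calls over four distinct objects compile four frames; the repeated obj3 and obj2 calls hit the cache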
self.assertEqual(cnts.frame_count, 4)
|
|
|
|
def test_if_cond_user_defined_object2(self):
|
|
# obj.__bool__ is a function that returns a non-bool type
|
|
class MyObj:
|
|
def __init__(self, x):
|
|
self.x = x
|
|
|
|
def __bool__(self):
|
|
self.x = 1
|
|
return self.x
|
|
|
|
def fn(a, obj):
|
|
if not obj:
|
|
return a + obj.x
|
|
else:
|
|
return a - obj.x
|
|
|
|
x = torch.rand(4)
|
|
obj = MyObj(0.5)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
try:
|
|
opt_fn(x, obj)
|
|
self.assertFalse(True)
|
|
except TypeError as e:
|
|
self.assertIn("__bool__ should return bool, returned int", str(e))
|
|
|
|
def test_class_has_instancecheck_method(self):
|
|
class A:
|
|
pass
|
|
|
|
class ExampleMeta(type):
|
|
def __instancecheck__(cls, instance):
|
|
return True
|
|
|
|
class B(metaclass=ExampleMeta):
|
|
pass
|
|
|
|
def fn(x, obj):
|
|
if isinstance(obj, B):
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
x = torch.rand(4)
|
|
obj = A()
|
|
ref = fn(x, obj)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x, obj)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_torch_cuda_is_available(self):
|
|
def fn(x):
|
|
if torch.cuda.is_available():
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
x = torch.rand(4)
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
|
@unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
|
|
def test_torch_cudnn_is_acceptable(self):
|
|
def fn(x):
|
|
if torch.backends.cudnn.is_acceptable(tensor=x):
|
|
return x + 1
|
|
return x
|
|
|
|
x = torch.rand(4).cuda()
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
|
@unittest.skipIf(not torch.backends.cudnn.is_available(), "requires cudnn")
|
|
def test_torch_cudnn_is_acceptable_bad_inputs(self):
|
|
def fn1(x):
|
|
if torch.backends.cudnn.is_acceptable("invalid"):
|
|
return x + 1
|
|
return x
|
|
|
|
def fn2(x):
|
|
if torch.backends.cudnn.is_acceptable(x, 3.14):
|
|
return x + 1
|
|
return x
|
|
|
|
with self.assertRaisesRegex(
|
|
AssertionError, "Expect input to cudnn.is_acceptable to be a tensor"
|
|
):
|
|
x1 = torch.rand(4).cuda()
|
|
opt_fn1 = torch._dynamo.optimize("eager", nopython=True)(fn1)
|
|
res1 = opt_fn1(x1)
|
|
|
|
with self.assertRaisesRegex(
|
|
AssertionError, "Expect 1 input to cudnn.is_acceptable"
|
|
):
|
|
x2 = torch.rand(4).cuda()
|
|
opt_fn2 = torch._dynamo.optimize("eager", nopython=True)(fn2)
|
|
res = opt_fn2(x2)
|
|
|
|
@unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
|
|
def test_get_device(self):
|
|
def fn(x, y):
|
|
x = x + 1
|
|
y = y + 1
|
|
return x.get_device(), y.get_device()
|
|
|
|
x = torch.rand(4, device="cuda")
|
|
y = torch.rand(4, device="cpu")
|
|
ref = fn(x, y)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_disable_flag(self):
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
with patch.dict(os.environ, {"TORCH_COMPILE_DISABLE": "1"}):
|
|
|
|
def fn(x, y):
|
|
x = x + 1
|
|
y = y + 1
|
|
|
|
opt_fn = torch._dynamo.optimize(cnt)
|
|
|
|
self.assertEqual(cnt.frame_count, 0)
|
|
|
|
def test_is_compiling(self):
|
|
def f():
|
|
if torch._dynamo.is_compiling():
|
|
return torch.ones(2, 2)
|
|
else:
|
|
return torch.zeros(2, 2)
|
|
|
|
opt_f = torch._dynamo.optimize("eager")(f)
|
|
|
|
self.assertEqual(f(), torch.zeros(2, 2))
|
|
self.assertEqual(opt_f(), torch.ones(2, 2))
|
|
|
|
def test_torch_generator_set_state(self):
|
|
def fn():
|
|
default_state = torch.default_generator.get_state()
|
|
x = torch.rand([2, 3])
|
|
torch._dynamo.graph_break()
|
|
torch.default_generator.set_state(default_state)
|
|
y = torch.rand([2, 3])
|
|
return x, y
|
|
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
x, y = opt_fn()
|
|
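# restoring the generator state before the second rand() makes y identical to x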
self.assertEqual(x, y)
|
|
|
|
def test_guard_failure_fn(self):
|
|
def fn(x, y, k):
|
|
x = x + 1
|
|
y = y + 1
|
|
return x * y * k
|
|
|
|
x = torch.tensor([0.5, 0.5])
|
|
y = torch.tensor([1.0, 1.0])
|
|
|
|
guard_failure = None
|
|
|
|
def guard_failures(failure):
|
|
nonlocal guard_failure
|
|
guard_failure = failure
|
|
|
|
opt_fn = torch._dynamo.optimize(
|
|
"eager", nopython=True, guard_fail_fn=guard_failures
|
|
)(fn)
|
|
|
|
x2 = torch.tensor([0.5, 0.5, 1.0])
|
|
y2 = torch.tensor([0.5, 0.5, 0.5])
|
|
|
|
opt_fn(x, y, 3)
|
|
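# the second call changes both the tensor shapes and k; any resulting guard failure is recorded by guard_failures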
opt_fn(x2, y2, 5)
|
|
|
|
if (
|
|
torch._dynamo.config.dynamic_shapes
|
|
and not torch._dynamo.config.specialize_int
|
|
and not torch._dynamo.config.assume_static_by_default
|
|
):
|
|
# we didn't actually test guard_failure_fn here but whatever,
|
|
# nice to see no guard failure on the test
|
|
self.assertTrue(guard_failure is None)
|
|
else:
|
|
self.assertTrue(guard_failure is not None)
|
|
if not torch._dynamo.config.dynamic_shapes:
|
|
self.assertExpectedInline(guard_failure[0], """L['k'] == 3""")
|
|
|
|
@patch.object(torch._dynamo.config, "dynamic_shapes", True)
|
|
def test_guard_failure_fn_shape_control(self):
|
|
def fn(x, y):
|
|
if x.shape[0] < 3:
|
|
if y.shape[0] < 3:
|
|
return x * y
|
|
else:
|
|
return x + y
|
|
else:
|
|
return -1
|
|
|
|
x = torch.randn([2, 2])
|
|
y = torch.randn([2, 2])
|
|
|
|
guard_failure = None
|
|
|
|
def guard_failures(failure):
|
|
nonlocal guard_failure
|
|
guard_failure = failure
|
|
|
|
opt_fn = torch._dynamo.optimize(
|
|
"eager", nopython=True, guard_fail_fn=guard_failures
|
|
)(fn)
|
|
|
|
x2 = torch.randn([5, 5])
|
|
y2 = torch.randn([5, 5])
|
|
|
|
opt_fn(x, y)
|
|
opt_fn(x2, y2)
|
|
|
|
self.assertTrue(guard_failure is not None)
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(
|
|
guard_failure[0],
|
|
"""tensor 'L['x']' size mismatch at index 0. expected 2, actual 5""",
|
|
)
|
|
else:
|
|
self.assertExpectedInline(guard_failure[0], """L['x'].size()[0] < 3""")
|
|
|
|
def test_guard_failure_fn2(self):
|
|
def fn(x, y):
|
|
x = x + 1
|
|
y = y + 1
|
|
return x * y
|
|
|
|
x = torch.tensor([0.5, 0.5])
|
|
y = torch.tensor([1.0, 1.0])
|
|
|
|
guard_failure = None
|
|
|
|
def guard_failures(failure):
|
|
nonlocal guard_failure
|
|
guard_failure = failure
|
|
|
|
opt_fn = torch._dynamo.optimize(
|
|
"eager", nopython=True, guard_fail_fn=guard_failures
|
|
)(fn)
|
|
|
|
x2 = torch.tensor([0.5, 0.5, 1.0])
|
|
y2 = torch.tensor([0.5, 0.5, 0.5])
|
|
|
|
opt_fn(x, y)
|
|
opt_fn(x2, y2)
|
|
|
|
if torch._dynamo.config.dynamic_shapes:
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(
|
|
guard_failure[0],
|
|
"""tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
|
|
)
|
|
else:
|
|
self.assertTrue(guard_failure is None)
|
|
else:
|
|
self.assertTrue(guard_failure is not None)
|
|
self.assertExpectedInline(
|
|
guard_failure[0],
|
|
"""tensor 'L['x']' size mismatch at index 0. expected 2, actual 3""",
|
|
)
|
|
|
|
def test_guard_failure_fn_tensor_iter(self):
|
|
def fn(x):
|
|
for y in x:
|
|
y.add_(1.0)
|
|
return y
|
|
|
|
guard_failure = None
|
|
|
|
def guard_failures(failure):
|
|
nonlocal guard_failure
|
|
guard_failure = failure
|
|
|
|
opt_fn = torch._dynamo.optimize(
|
|
"eager", nopython=True, guard_fail_fn=guard_failures
|
|
)(fn)
|
|
|
|
args1 = torch.randn(10, 10)
|
|
out = fn(args1)
|
|
opt_out = opt_fn(args1)
|
|
self.assertTrue(same(out, opt_out))
|
|
|
|
args2 = torch.randn(9, 10)
|
|
out = fn(args2)
|
|
opt_out = opt_fn(args2)
|
|
self.assertTrue(same(out, opt_out))
|
|
|
|
# guard is expected for both static and dynamic shapes
|
|
self.assertTrue(guard_failure is not None)
|
|
self.assertExpectedInline(guard_failure[0], """len(L['x']) == 10""")
|
|
|
|
def test_restore_graphstate(self):
|
|
# This function does some guard accumulation,
|
|
# and then rolls back due to control flow.
|
|
# The idea is that if one were printing guards as they appear,
|
|
# they would see this insert a guard that does not show up in the final set of
|
|
# guards as we rolled back from it.
|
|
def nested_fn(s):
|
|
if x[0] < 10:
|
|
return s * s
|
|
return s
|
|
|
|
def fn(x, y):
|
|
x = x + 1
|
|
y = nested_fn(y)
|
|
y = y + 10
|
|
return x * y
|
|
|
|
all_guards = []
|
|
|
|
def guard_export_print(guards):
|
|
nonlocal all_guards
|
|
all_guards.extend(guards)
|
|
|
|
opt_fn = torch._dynamo.optimize("eager", guard_export_fn=guard_export_print)(fn)
|
|
|
|
x = torch.tensor([0.5, 0.5])
|
|
y = torch.tensor([1.0, 1.0])
|
|
opt_fn(x, y)
|
|
|
|
for guard in all_guards:
|
|
# This guard was created
|
|
self.assertTrue(guard.name != "nested_fn.__closure__[0].cell_contents")
|
|
|
|
# Note - here be mild dragons.
|
|
# This test relies a ton on internal implementation. Future refactor efforts
|
|
# are welcome to delete it if necessary; rewriting this test constantly is a chore, not
|
|
# a feature. We kept it around with some amount of sadness, as it was extremely useful in debugging.
|
|
def test_restore_graphstate_internals(self):
|
|
def fn(x, y):
|
|
x = x + 1
|
|
y = y + 1
|
|
return x * y
|
|
|
|
_, guards = torch._dynamo.export(
|
|
fn, torch.tensor([0.25, 0.25]), torch.tensor([0.25, 0.25])
|
|
)
|
|
# Dummy ctor
|
|
graph = OutputGraph(
|
|
f_globals={},
|
|
code_options={},
|
|
compiler_fn=None,
|
|
root_tx=None,
|
|
export=False,
|
|
export_constraints=None,
|
|
frame_state={"_id": 0},
|
|
)
|
|
graph.nn_modules_sources = {}
|
|
# Contrived generation timestamp
|
|
graph.timestamp = 4
|
|
# Contrived guards
|
|
graph.tracing_context.guards_context.dynamo_guards = guards
|
|
|
|
# Save the state
|
|
state = graph.copy_graphstate()
|
|
# Saving increments the generation
|
|
self.assertEqual(graph.timestamp, 5)
|
|
|
|
# Assure that the saved state is valid
|
|
self.assertEqual(state.timestamp, 4)
|
|
|
|
# Ensure that the guards reflect the expected state
|
|
self.assertEqual(graph.tracing_context.guards_context.dynamo_guards, guards)
|
|
self.assertEqual(graph.guards, guards)
|
|
|
|
# Mess around with the state
|
|
graph.tracing_context.guards_context.dynamo_guards = set()
|
|
self.assertEqual(graph.guards, set())
|
|
|
|
# Restore the state
|
|
graph.restore_graphstate(state)
|
|
|
|
# Make sure it restored correctly
|
|
self.assertEqual(graph.timestamp, 4)
|
|
self.assertEqual(graph.guards, guards)
|
|
self.assertEqual(graph.tracing_context.guards_context.dynamo_guards, guards)
|
|
|
|
def test_call_parent_non_class_methods_from_child(self):
|
|
class A:
|
|
def add(self, x):
|
|
return x + 10
|
|
|
|
def mul(self, x):
|
|
return x * 0.1
|
|
|
|
class B(A):
|
|
def add(self, x):
|
|
return x + 20
|
|
|
|
def mul(self, x):
|
|
return x * 0.2
|
|
|
|
class C(B):
|
|
def add(self, x):
|
|
y = A.add(self, x)
|
|
z = B.mul(self, y)
|
|
return z + 30
|
|
|
|
x = torch.rand(4)
|
|
fn = C().add
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_builder_for_class_with_metaclass(self):
|
|
class ExampleMeta(type):
|
|
pass
|
|
|
|
class MyClass(metaclass=ExampleMeta):
|
|
pass
|
|
|
|
def fn(x, y):
|
|
if isinstance(y, MyClass):
|
|
return x + 1
|
|
else:
|
|
return x - 1
|
|
|
|
x = torch.rand([4, 4])
|
|
y = MyClass()
|
|
ref = fn(x, y)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_tuple_from_tuple_iter(self):
|
|
def inner_fn(*args):
|
|
acc = torch.ones(10, 10)
|
|
for arg in args:
|
|
acc.add_(arg)
|
|
|
|
return acc
|
|
|
|
@torch._dynamo.optimize("eager")
|
|
def fn(inputs, params):
|
|
y = tuple(inputs) + tuple(params)
|
|
return inner_fn(*y)
|
|
|
|
inputs = [torch.randn(10, 10) for _ in range(3)]
|
|
|
|
fn(inputs, iter(tuple(inputs)))
|
|
|
|
def test_torch_package_working_with_trace(self):
|
|
# from torch._dynamo.test_case import run_tests
|
|
|
|
inputs = [torch.randn([2, 2]), torch.randn([2, 2])]
|
|
|
|
optimized_model = torch._dynamo.optimize(backend="eager")(
|
|
MyPickledModule(torch.randn([2, 2]))
|
|
)
|
|
from torch import package
|
|
|
|
path = "/tmp/MyPickledModule.pt"
|
|
package_name = "MyPickledModule"
|
|
resource_name = "MyPickledModule.pkl"
|
|
|
|
model = MyPickledModule(torch.randn([2, 2]))
|
|
|
|
with package.PackageExporter(path) as exp:
|
|
exp.extern("**")
|
|
exp.save_pickle(package_name, resource_name, model)
|
|
|
|
imp = package.PackageImporter(path)
|
|
loaded_model = imp.load_pickle(package_name, resource_name)
|
|
|
|
optimized_loaded_model = torch._dynamo.optimize("eager")(loaded_model)(*inputs)
|
|
|
|
def test_shape_and_tuple_equality(self):
|
|
def fn(x, y, t):
|
|
z = x * y
|
|
if x.size() == t:
|
|
return z.cos()
|
|
return z.sin()
|
|
|
|
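        # nopython=True raises on any graph break, so this also checks that the
        # torch.Size vs. plain tuple comparison is traced in a single graph.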
torch._dynamo.optimize("eager", nopython=True)(fn)(
|
|
torch.randn([4, 4]), torch.randn([4, 4]), (4, 4)
|
|
)
|
|
|
|
def test_int_list(self):
|
|
        # if dynamic_shapes == True: the int list is unspecialized
        # if dynamic_shapes == False: the int list is specialized
|
|
def fn(x, y):
|
|
return torch.sin(x + y[1] % 2)
|
|
|
|
x = torch.randn(6)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
for i in range(10, 25, 3):
|
|
y = [i, i + 1, i + 2]
|
|
ref = fn(x, y)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
self.assertEqual(cnt.frame_count, ifunspec(ifdyn(1, 5), 5))
|
|
|
|
# specifically test for tensor.attribute -> torch.something()
|
|
def test_real_imag_tensor_attribute(self):
|
|
def fn(x, y):
|
|
a = x.real
|
|
b = x.imag
|
|
return torch.mul(torch.add(a, y), b)
|
|
|
|
x_real = torch.rand((4, 4))
|
|
x_imag = torch.rand((4, 4))
|
|
x = torch.complex(x_real, x_imag)
|
|
y = torch.rand((4, 4))
|
|
|
|
ref = fn(x, y)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_T_tensor_attribute(self):
|
|
def fn(x, y):
|
|
a = x.T
|
|
return torch.add(a, y)
|
|
|
|
x = torch.rand((4, 4))
|
|
y = torch.rand((4, 4))
|
|
|
|
ref = fn(x, y)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_recursive_tensor_attribute(self):
|
|
def fn(x, y):
|
|
a = x.real.T
|
|
b = x.imag
|
|
return torch.mul(torch.add(a, y), b)
|
|
|
|
x_real = torch.rand((4, 4))
|
|
x_imag = torch.rand((4, 4))
|
|
x = torch.complex(x_real, x_imag)
|
|
y = torch.rand((4, 4))
|
|
|
|
ref = fn(x, y)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_tagging_tensors_simple(self):
|
|
def foo(x, y):
|
|
return x * y, x, y
|
|
|
|
a = torch.randn([3, 3])
|
|
a.tag = "a"
|
|
a.frog = "ribbity ribbit"
|
|
b = torch.randn([3, 3])
|
|
b.tag = "b"
|
|
b.frog = "ribbit"
|
|
|
|
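        # export records ad-hoc attributes set on input tensors in each
        # placeholder's meta["tensor_dict"]; the assertions below read them back.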
exported = torch._dynamo.export(foo, a, b)
|
|
out_graph = exported[0]
|
|
|
|
nodes = list(out_graph.graph.nodes)
|
|
placeholders = [node for node in nodes if node.op == "placeholder"]
|
|
all_tags = []
|
|
all_frogs = []
|
|
for placeholder in placeholders:
|
|
if "tensor_dict" in placeholder.meta:
|
|
all_tags.append(placeholder.meta["tensor_dict"]["tag"])
|
|
all_frogs.append(placeholder.meta["tensor_dict"]["frog"])
|
|
|
|
self.assertEqual(all_tags, ["a", "b"])
|
|
self.assertEqual(all_frogs, ["ribbity ribbit", "ribbit"])
|
|
|
|
def test_tagging_tensors_mix_used_unused_structure(self):
|
|
def pre_attention_state_ops(input, mems, state):
|
|
lc_key = state[0]
|
|
lc_val = state[1]
|
|
bar = []
|
|
for i in range(0, 4):
|
|
bar2 = []
|
|
for j in range(0, 3):
|
|
bar2.append(
|
|
lc_key + lc_val + torch.tensor([0.1, 0.25, 0.4, 0.5, 0.1])
|
|
)
|
|
bar.append(bar2)
|
|
|
|
return bar
|
|
|
|
mems = torch.tensor([[[1.8364, 0.2724, -1.4917, -0.4367, 0.8640]]])
|
|
state = [
|
|
torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
|
|
torch.tensor([[[1.0517, 0.3848, -0.6472, 0.0823, 0.9116]]]),
|
|
]
|
|
i = torch.tensor(
|
|
[
|
|
[0.0313, -0.1487, -0.3846, -0.5321],
|
|
[-1.7073, 1.3331, -0.0890, -1.4935],
|
|
[-0.8314, -0.1862, -0.5935, 1.5232],
|
|
]
|
|
)
|
|
|
|
mems.tag = "MEMS"
|
|
i.tag = "FOO"
|
|
state[0].tag = "STATE_0"
|
|
state[1].tag = "HMMM"
|
|
|
|
exported = torch._dynamo.export(pre_attention_state_ops, i, mems, state)
|
|
out_graph = exported[0]
|
|
|
|
nodes = list(out_graph.graph.nodes)
|
|
placeholders = [node for node in nodes if node.op == "placeholder"]
|
|
all_tags = []
|
|
for placeholder in placeholders:
|
|
if "tensor_dict" in placeholder.meta:
|
|
all_tags.append(placeholder.meta["tensor_dict"]["tag"])
|
|
|
|
self.assertEqual(all_tags, ["STATE_0", "HMMM"])
|
|
|
|
def test_get_custom_tensor_attribute(self):
|
|
def fn(x):
|
|
return x.custom_attr * x
|
|
|
|
x = torch.rand((2, 2))
|
|
x.custom_attr = 3.14
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_set_custom_tensor_attribute(self):
|
|
def fn(x):
|
|
x.custom_attr = 3.14
|
|
return x.custom_attr * x
|
|
|
|
x = torch.rand((2, 2))
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_if_tensor_is_none(self):
|
|
"""
|
|
Python 3.11 adds new jump instructions that check if
|
|
TOS is None. We do not support these instructions.
|
|
"""
|
|
|
|
def f(x, y):
|
|
z = 1
|
|
if x is None:
|
|
z *= 2
|
|
if y is not None:
|
|
z *= 3
|
|
return z
|
|
|
|
opt_f = torch._dynamo.optimize("eager", nopython=True)(f)
|
|
self.assertEqual(opt_f(None, torch.ones(2)), 6)
|
|
|
|
if sys.version_info >= (3, 11):
|
|
insts = bytecode_transformation.cleaned_instructions(f.__code__)
|
|
for inst in insts:
|
|
self.assertNotIn("_NONE", inst.opname)
|
|
|
|
@skipIfNotPy311
|
|
def test_py311_jump_offset(self):
|
|
new_inst = bytecode_transformation.create_instruction
|
|
load_global = bytecode_transformation.create_load_global
|
|
consts = (None, 1, 2, 3, 4)
|
|
|
|
def create_test_code(jump_opname, target_idx):
|
|
targets = [
|
|
new_inst("LOAD_CONST", argval=1),
|
|
new_inst("LOAD_CONST", argval=3),
|
|
]
|
|
jump_to_target_inst = new_inst(jump_opname, target=targets[target_idx])
|
|
"""
|
|
pseudocode of generated bytecode:
|
|
def test_py311_fn():
|
|
goto target1
|
|
target0:
|
|
return 1
|
|
target1:
|
|
goto [target0/target2] (via fwd or bwd jump)
|
|
return 2
|
|
target2:
|
|
return 3
|
|
return 4
|
|
"""
|
|
# test with LOAD_GLOBAL since it has a different instruction size
|
|
insts = [
|
|
new_inst("RESUME", arg=0),
|
|
new_inst("JUMP_FORWARD", target=jump_to_target_inst),
|
|
targets[0],
|
|
load_global("print", False),
|
|
new_inst("POP_TOP"),
|
|
new_inst("RETURN_VALUE"),
|
|
jump_to_target_inst,
|
|
new_inst("LOAD_CONST", argval=2),
|
|
load_global("print", False),
|
|
new_inst("POP_TOP"),
|
|
new_inst("RETURN_VALUE"),
|
|
targets[1],
|
|
new_inst("RETURN_VALUE"),
|
|
new_inst("LOAD_CONST", argval=4),
|
|
new_inst("RETURN_VALUE"),
|
|
]
|
|
code_options = collections.OrderedDict(
|
|
[
|
|
("co_argcount", 0),
|
|
("co_posonlyargcount", 0),
|
|
("co_kwonlyargcount", 0),
|
|
("co_nlocals", 0),
|
|
("co_stacksize", 2),
|
|
("co_flags", 3),
|
|
("co_code", b""),
|
|
("co_consts", consts),
|
|
("co_names", ("print",)),
|
|
("co_varnames", ()),
|
|
("co_filename", __file__),
|
|
("co_name", "test_py311_fn"),
|
|
("co_qualname", "test_py311_fn"),
|
|
("co_firstlineno", 1),
|
|
("co_linetable", b""),
|
|
("co_exceptiontable", b""),
|
|
("co_freevars", ()),
|
|
("co_cellvars", ()),
|
|
]
|
|
)
|
|
return bytecode_transformation.clean_and_assemble_instructions(
|
|
insts,
|
|
list(code_options.keys()),
|
|
code_options,
|
|
)
|
|
|
|
# format: jump_opname, target_idx, expected forward jump, expected return value
|
|
test_args = (
|
|
("JUMP_FORWARD", 0, False, 1),
|
|
("JUMP_FORWARD", 1, True, 3),
|
|
("JUMP_BACKWARD", 0, False, 1),
|
|
("JUMP_BACKWARD", 1, True, 3),
|
|
)
|
|
|
|
for test in test_args:
|
|
insts, code = create_test_code(test[0], test[1])
|
|
            # check whether the last jump instruction is a forward or a backward jump
|
|
for inst in reversed(insts):
|
|
if inst.opname.startswith("JUMP"):
|
|
if test[2]:
|
|
self.assertIn("FORWARD", inst.opname)
|
|
else:
|
|
self.assertIn("BACKWARD", inst.opname)
|
|
break
|
|
# run the code and check result
|
|
|
|
def dummy_fn():
|
|
pass
|
|
|
|
dummy_fn.__code__ = code
|
|
self.assertEqual(dummy_fn(), test[3])
|
|
|
|
dummy_opt = torch._dynamo.optimize("eager")(dummy_fn)
|
|
self.assertEqual(dummy_opt(), test[3])
|
|
|
|
def test_exception_table_encode_varint(self):
|
|
        # these numbers have no particular meaning
|
|
nums = [
|
|
0b111_101010_000000,
|
|
0b1100_111000_010101_101010,
|
|
]
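        # A rough sketch of the varint scheme these helpers implement (assumed from
        # CPython 3.11's exception-table format, not from this file): each byte
        # carries 6 payload bits and bit 0x40 marks that another byte follows, e.g.
        #
        #   def decode(it):
        #       b = next(it)
        #       val = b & 0x3F
        #       while b & 0x40:
        #           b = next(it)
        #           val = (val << 6) | (b & 0x3F)
        #       return val
        #
        # which is why the test values above are grouped into 6-bit chunks.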
|
|
b = bytecode_transformation.encode_exception_table_varint(
|
|
nums[0]
|
|
) + bytecode_transformation.encode_exception_table_varint(nums[1])
|
|
nums_new = []
|
|
b_iter = iter(bytes(b))
|
|
while True:
|
|
try:
|
|
nums_new.append(
|
|
bytecode_transformation.decode_exception_table_varint(b_iter)
|
|
)
|
|
except StopIteration:
|
|
break
|
|
self.assertEqual(nums, nums_new)
|
|
|
|
@skipIfNotPy311
|
|
def test_exception_table_parsing(self):
|
|
def fn():
|
|
try:
|
|
with a():
|
|
b()
|
|
c()
|
|
except Exception:
|
|
d()
|
|
finally:
|
|
e()
|
|
f()
|
|
|
|
tab = bytecode_transformation.parse_exception_table(
|
|
fn.__code__.co_exceptiontable
|
|
)
|
|
b = bytecode_transformation.assemble_exception_table(tab)
|
|
self.assertEqual(b, fn.__code__.co_exceptiontable)
|
|
|
|
@skipIfNotPy311
|
|
def test_exception_table_e2e(self):
|
|
def fn():
|
|
try:
|
|
with a():
|
|
b()
|
|
c()
|
|
except Exception:
|
|
d()
|
|
finally:
|
|
e()
|
|
f()
|
|
|
|
def nothing(*args):
|
|
pass
|
|
|
|
code = bytecode_transformation.transform_code_object(fn.__code__, nothing)
|
|
self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable)
|
|
|
|
@skipIfNotPy311
|
|
def test_exception_table_e2e_2(self):
|
|
        # The last instruction of an exn_table entry is a large instruction,
        # i.e., LOAD_GLOBAL a
|
|
def fn():
|
|
try:
|
|
return a
|
|
except Exception:
|
|
pass
|
|
|
|
def nothing(*args):
|
|
pass
|
|
|
|
code = bytecode_transformation.transform_code_object(fn.__code__, nothing)
|
|
self.assertEqual(code.co_exceptiontable, fn.__code__.co_exceptiontable)
|
|
|
|
@skipIfNotPy311
|
|
def test_exception_table_entry_propagation(self):
|
|
insts = []
|
|
for _ in range(10):
|
|
insts.append(bytecode_transformation.create_instruction("NOP"))
|
|
insts[8].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[0], insts[9], insts[0], 0, True
|
|
)
|
|
insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[0], insts[0], insts[1], 0, True
|
|
)
|
|
insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[0], insts[2], insts[2], 0, True
|
|
)
|
|
insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[4], insts[6], insts[3], 0, True
|
|
)
|
|
insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[9], insts[9], insts[4], 0, True
|
|
)
|
|
insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[7], insts[9], insts[5], 0, True
|
|
)
|
|
bytecode_transformation.propagate_inst_exn_table_entries(insts)
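        # expected[i] is the index of the instruction that insts[i]'s propagated
        # exception-table entry should target.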
|
|
expected = [1, 2, 2, 0, 3, 3, 3, 5, 5, 4]
|
|
for inst, exp in zip(insts, expected):
|
|
self.assertIsNotNone(inst.exn_tab_entry)
|
|
self.assertIs(inst.exn_tab_entry.target, insts[exp])
|
|
|
|
@skipIfNotPy311
|
|
def test_compute_exception_table_nested(self):
|
|
insts = []
|
|
for _ in range(20):
|
|
insts.append(bytecode_transformation.create_instruction("NOP"))
|
|
insts[10].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[1], insts[10], insts[0], 0, True
|
|
)
|
|
insts[0].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[1], insts[1], insts[1], 0, True
|
|
)
|
|
insts[1].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[1], insts[3], insts[2], 0, True
|
|
)
|
|
insts[5].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[5], insts[7], insts[3], 0, True
|
|
)
|
|
insts[9].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[10], insts[10], insts[4], 0, True
|
|
)
|
|
insts[7].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[8], insts[10], insts[5], 0, True
|
|
)
|
|
insts[14].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[13], insts[17], insts[6], 0, True
|
|
)
|
|
insts[16].exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
insts[15], insts[16], insts[7], 0, True
|
|
)
|
|
bytecode_transformation.update_offsets(insts)
|
|
tab = bytecode_transformation.compute_exception_table(insts)
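        # Each expected entry is (start, end, target) in instruction indices; the
        # assertions below multiply by 2 because every instruction here is a single
        # two-byte code unit.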
|
|
expected = [
|
|
(1, 1, 1),
|
|
(2, 3, 2),
|
|
(4, 4, 0),
|
|
(5, 7, 3),
|
|
(8, 9, 5),
|
|
(10, 10, 4),
|
|
(13, 14, 6),
|
|
(15, 16, 7),
|
|
(17, 17, 6),
|
|
]
|
|
        self.assertEqual(len(tab), len(expected))
        for entry, exp in zip(tab, expected):
            self.assertEqual(entry.start, exp[0] * 2)
            self.assertEqual(entry.end, exp[1] * 2)
            self.assertEqual(entry.target, exp[2] * 2)
|
|
|
|
@skipIfNotPy311
|
|
def test_remove_dead_code_with_exn_table_entries(self):
|
|
create_instruction = bytecode_transformation.create_instruction
|
|
target1 = create_instruction("NOP")
|
|
target2 = create_instruction("NOP")
|
|
target3 = create_instruction("NOP")
|
|
exn_start = create_instruction("NOP")
|
|
exn_end = create_instruction("NOP")
|
|
insts = [
|
|
create_instruction("JUMP_FORWARD", target=target1),
|
|
exn_start, # dead
|
|
target1,
|
|
create_instruction("JUMP_FORWARD", target=target3),
|
|
exn_end, # dead
|
|
target2,
|
|
target3,
|
|
]
|
|
exn_start.exn_tab_entry = bytecode_transformation.InstructionExnTabEntry(
|
|
exn_start, exn_end, target2, 0, True
|
|
)
|
|
bytecode_transformation.propagate_inst_exn_table_entries(insts)
|
|
insts = bytecode_analysis.remove_dead_code(insts)
|
|
        self.assertEqual(len(insts), 5)
|
|
self.assertNotIn(exn_start, insts)
|
|
self.assertNotIn(exn_end, insts)
|
|
self.assertIn(target2, insts)
|
|
self.assertIn(target3, insts)
|
|
bytecode_transformation.update_offsets(insts)
|
|
tab = bytecode_transformation.compute_exception_table(insts)
|
|
        self.assertEqual(len(tab), 1)
        self.assertEqual(tab[0].start, 2)
        self.assertEqual(tab[0].end, 4)
        self.assertEqual(tab[0].target, 6)
|
|
|
|
def test_ordered_dict_alias_reconstruct(self):
|
|
od = collections.OrderedDict
|
|
|
|
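        # The graph break forces d2 to be reconstructed in the resumed frame; the
        # isinstance check then verifies it comes back as an OrderedDict rather
        # than a plain dict.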
def fn():
|
|
d1 = dict()
|
|
d1["a"] = 1
|
|
d2 = od(d1)
|
|
d2["b"] = 2
|
|
torch._dynamo.graph_break()
|
|
if isinstance(d2, od):
|
|
return d2["a"] + d2["b"]
|
|
else:
|
|
return 0
|
|
|
|
dis.dis(fn)
|
|
self.assertEqual(torch._dynamo.optimize("eager")(fn)(), 3)
|
|
|
|
@torch._dynamo.config.patch(dynamic_shapes=True)
|
|
def test_raise_guard_full_constraint(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(x):
|
|
if x.shape[0] == 3:
|
|
return x.sin()
|
|
return x.cos()
|
|
|
|
torch._dynamo.mark_dynamic(y, 0)
|
|
with self.assertRaises(ConstraintViolationError):
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y)
|
|
|
|
def test_mark_static(self):
|
|
counter = CompileCounter()
|
|
|
|
def my_dyn_fn(x):
|
|
return x.cos()
|
|
|
|
y = torch.randn([3])
|
|
torch._dynamo.mark_static(y, 0)
|
|
torch._dynamo.optimize(counter)(my_dyn_fn)(y)
|
|
|
|
z = torch.randn([4])
|
|
torch._dynamo.optimize(counter)(my_dyn_fn)(z)
|
|
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
@torch._dynamo.config.patch(dynamic_shapes=True)
|
|
def test_no_raise_guard_partial_constraint(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(x):
|
|
if x.shape[0] > 3:
|
|
return x.sin()
|
|
return x.cos()
|
|
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y)
|
|
torch._dynamo.mark_dynamic(y, 0)
|
|
torch._dynamo.reset()
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y)
|
|
|
|
@torch._dynamo.config.patch(dynamic_shapes=True)
|
|
def test_no_raise_guard_partial_constraint_across_break(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(x, y):
|
|
z = x * y
|
|
|
|
torch._dynamo.graph_break()
|
|
if z.shape[0] > 2:
|
|
return z.cos()
|
|
|
|
return x.cos()
|
|
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
|
|
torch._dynamo.mark_dynamic(y, 0)
|
|
torch._dynamo.reset()
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
|
|
|
|
    # Sadly, this does not throw - we do not propagate the constraint correctly across the graph break
|
|
@unittest.expectedFailure
|
|
@torch._dynamo.config.patch(dynamic_shapes=True)
|
|
def test_raise_guard_partial_constraint_across_break(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(x, y):
|
|
z = x * y
|
|
|
|
torch._dynamo.graph_break()
|
|
if z.shape[0] == 3:
|
|
return z.cos()
|
|
|
|
return x.cos()
|
|
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
|
|
torch._dynamo.mark_dynamic(y, 0)
|
|
torch._dynamo.reset()
|
|
with self.assertRaisesRegex(
|
|
Exception,
|
|
):
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
|
|
|
|
@torch._dynamo.config.patch(dynamic_shapes=True)
|
|
def test_raise_guard_partial_constraint_no_graph_break(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(x, y):
|
|
z = x * y
|
|
|
|
if z.shape[0] == 3:
|
|
return z.cos()
|
|
|
|
return x.cos()
|
|
|
|
torch._dynamo.mark_dynamic(y, 0)
|
|
with self.assertRaises(ConstraintViolationError):
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y, y)
|
|
|
|
def test_cannot_trace_mark_dynamic(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(x):
|
|
torch._dynamo.mark_dynamic(x, 0)
|
|
return x * x
|
|
|
|
with self.assertRaisesRegex(
|
|
AssertionError, "Attempt to trace forbidden callable"
|
|
):
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y)
|
|
|
|
def test_cannot_trace_mark_dynamic_safe_unreached(self):
|
|
y = torch.randn([3, 3, 3])
|
|
|
|
def my_dyn_fn(x):
|
|
if x.shape[0] == 3:
|
|
return x
|
|
print("Running", torch._dynamo.mark_dynamic(x, 0))
|
|
return x * x
|
|
|
|
torch._dynamo.optimize("eager")(my_dyn_fn)(y)
|
|
|
|
@torch._dynamo.config.patch(dynamic_shapes=True)
|
|
def test_py_guards_mark_dynamic(self):
|
|
def my_dyn_fn(a):
|
|
if a.shape[0] > 2:
|
|
return a.cos()
|
|
return a.sin()
|
|
|
|
counter = CompileCounter()
|
|
|
|
# Run with dynamic
|
|
x0 = torch.randn([3, 3, 3])
|
|
torch._dynamo.mark_dynamic(x0, 0)
|
|
torch._dynamo.optimize(counter)(my_dyn_fn)(x0)
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
# Run without dynamic, no recompile
|
|
x = torch.randn([3, 3, 3])
|
|
torch._dynamo.optimize(counter)(my_dyn_fn)(x)
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
# Mark a new dim, 1, as dynamic
|
|
x1 = torch.randn([3, 3, 3])
|
|
torch._dynamo.mark_dynamic(x1, 1)
|
|
torch._dynamo.optimize(counter)(my_dyn_fn)(x1)
|
|
        # Recompile triggered because we marked a new dim as dynamic
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
# Reset
|
|
torch._dynamo.reset()
|
|
# Reset counter
|
|
counter = CompileCounter()
|
|
|
|
# Run with dynamic 1
|
|
torch._dynamo.optimize(counter)(my_dyn_fn)(x1)
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
# Run with dynamic 0, not subset
|
|
torch._dynamo.optimize(counter)(my_dyn_fn)(x0)
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
# Run with dynamic 0, 1, 2, not subset
|
|
x012 = torch.randn([3, 3, 3])
|
|
torch._dynamo.mark_dynamic(x012, 0)
|
|
torch._dynamo.mark_dynamic(x012, 1)
|
|
torch._dynamo.mark_dynamic(x012, 2)
|
|
torch._dynamo.optimize(counter)(my_dyn_fn)(x012)
|
|
self.assertEqual(counter.frame_count, 3)
|
|
|
|
def test_torch_compile_ctx_on_forward_and_training_step(self):
|
|
class MyModel(torch.nn.Module):
|
|
def forward(self):
|
|
...
|
|
|
|
def training_step(self):
|
|
self()
|
|
|
|
model = MyModel()
|
|
compiled_model = torch.compile(model)
|
|
|
|
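        # Reuse the compiled module's dynamo_ctx to wrap additional methods, so both
        # forward and training_step run under the same compilation context.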
model.forward = compiled_model.dynamo_ctx(model.forward)
|
|
model.training_step = compiled_model.dynamo_ctx(model.training_step)
|
|
|
|
model.training_step()
|
|
|
|
def test_torch_guards_stack_frame_register_inlining(self):
|
|
x = torch.tensor([0.5, 0.5])
|
|
y = torch.tensor([0.75, 0.75, 0.75, 0.75])
|
|
z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25])
|
|
|
|
def uwu_inline_me(x, y, z):
|
|
r = torch.cat((x, x)) + y
|
|
r2 = torch.cat((y, y)) + z
|
|
return r, r2
|
|
|
|
def fn(x, y, z):
|
|
r, r2 = uwu_inline_me(x, y, z)
|
|
return torch.mul(r, r), torch.mul(r2, r2)
|
|
|
|
seen_frames = []
|
|
import contextlib
|
|
|
|
@contextlib.contextmanager
|
|
def global_context_capture_fn(frame_summary):
|
|
seen_frames.append(frame_summary)
|
|
yield
|
|
|
|
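        # Patching TracingContext.current_frame captures the frame summaries Dynamo
        # registers while inlining; the assertions below inspect them.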
with mock.patch(
|
|
"torch._guards.TracingContext.current_frame",
|
|
side_effect=global_context_capture_fn,
|
|
):
|
|
torch._dynamo.optimize("eager")(fn)(x, y, z)
|
|
|
|
self.assertEqual(len(seen_frames), 1)
|
|
self.assertEqual(seen_frames[0].name, "fn")
|
|
self.assertEqual(seen_frames[0].line, "r, r2 = uwu_inline_me(x, y, z)")
|
|
|
|
def test_torch_guards_stack_frame_register_inlining_deep(self):
|
|
x = torch.tensor([0.5, 0.5])
|
|
y = torch.tensor([0.75, 0.75, 0.75, 0.75])
|
|
z = torch.tensor([0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25, 0.25])
|
|
|
|
def uwu_inline_me_deep(x, y):
|
|
return torch.cat((x, x)) + y
|
|
|
|
def uwu_inline_me(x, y, z):
|
|
r = uwu_inline_me_deep(x, y)
|
|
r2 = uwu_inline_me_deep(y, z)
|
|
return r, r2
|
|
|
|
def fn(x, y, z):
|
|
r, r2 = uwu_inline_me(x, y, z)
|
|
return torch.mul(r, r), torch.mul(r2, r2)
|
|
|
|
seen_frames = []
|
|
import contextlib
|
|
|
|
@contextlib.contextmanager
|
|
def global_context_capture_fn(frame_summary):
|
|
seen_frames.append(frame_summary)
|
|
yield
|
|
|
|
with mock.patch(
|
|
"torch._guards.TracingContext.current_frame",
|
|
side_effect=global_context_capture_fn,
|
|
):
|
|
torch._dynamo.optimize("eager")(fn)(x, y, z)
|
|
|
|
self.assertEqual(len(seen_frames), 3)
|
|
self.assertEqual(seen_frames[0].name, "fn")
|
|
self.assertEqual(seen_frames[1].name, "uwu_inline_me")
|
|
self.assertEqual(seen_frames[2].line, "r2 = uwu_inline_me_deep(y, z)")
|
|
|
|
def test_error_on_recompile(self):
|
|
@torch._dynamo.optimize("eager")
|
|
def fn(a, b):
|
|
return a + b
|
|
|
|
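        # With error_on_recompile set, the second call (passing a tuple for b) fails
        # the installed guards and raises RecompileError instead of recompiling.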
with unittest.mock.patch("torch._dynamo.config.error_on_recompile", True):
|
|
with self.assertRaises(torch._dynamo.exc.RecompileError):
|
|
fn(torch.rand(2, 3), torch.rand(2, 3))
|
|
fn(torch.rand(2, 3), (1, 2, 3))
|
|
|
|
def test_compile_profiler(self):
|
|
class Model(torch.nn.Module):
|
|
def forward(self, input):
|
|
return input + input
|
|
|
|
model = Model()
|
|
with CompileProfiler() as prof:
|
|
compiled = torch.compile(model, backend=prof)
|
|
base_checker = (
|
|
lambda: FileCheck()
|
|
.check("Torchdynamo Profiler Report")
|
|
.check("Graph Breaks")
|
|
.check("No graph breaks detected.")
|
|
.check("Recompilation")
|
|
)
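            # base_checker builds a FileCheck matcher for the report sections that
            # every case shares; the branches below extend it with the expected
            # recompilation outcome.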
|
|
input = torch.rand((2, 3, 4))
|
|
_ = compiled(input)
|
|
base_checker().check("No recompilation detected.").run(prof.report())
|
|
|
|
new_shape_input = torch.rand((3, 3, 4))
|
|
_ = compiled(new_shape_input)
|
|
|
|
            # Not an exhaustive test of dynamic shapes behavior, but a sanity check
|
|
if (
|
|
not torch._dynamo.config.dynamic_shapes
|
|
or torch._dynamo.config.assume_static_by_default
|
|
):
|
|
base_checker().check("Recompile Reasons").check("'forward'").check(
|
|
"cache_size_limit to 1"
|
|
).run(prof.report())
|
|
else:
|
|
base_checker().check("No recompilation detected.").run(prof.report())
|
|
|
|
            # Ensure the correct guard failure message is selected to show to the user
|
|
if not torch._dynamo.config.dynamic_shapes:
|
|
new_shape_input = torch.rand((4, 3, 4))
|
|
_ = compiled(new_shape_input)
|
|
|
|
base_checker().check("Recompile Reasons").check("'forward'").check(
|
|
"tensor 'L['input']' size mismatch at index 0. expected 2, actual 3"
|
|
).check(
|
|
"tensor 'L['input']' size mismatch at index 0. expected 3, actual 4"
|
|
).run(
|
|
prof.report()
|
|
)
|
|
|
|
def test_guards_strip_function_call(self):
|
|
from torch._dynamo.guards import strip_function_call
|
|
|
|
test_case = [
|
|
("___odict_getitem(a, 1)", "a"),
|
|
("a.layers[slice(2)][0]._xyz", "a"),
|
|
("getattr(a.layers[slice(2)][0]._abc, '0')", "a"),
|
|
("getattr(getattr(a.x[3], '0'), '3')", "a"),
|
|
("a.layers[slice(None, -1, None)][0]._xyz", "a"),
|
|
("a.layers[func('offset', -1, None)][0]._xyz", "a"),
|
|
]
|
|
# strip_function_call should extract the object from the string.
|
|
for name, expect_obj in test_case:
|
|
self.assertEqual(strip_function_call(name), expect_obj)
|
|
|
|
def test_int_neg(self):
|
|
def int_neg(a, b):
|
|
x = a.shape[0]
|
|
y = b.shape[0]
|
|
return -x * -y * a * b
|
|
|
|
torch._dynamo.testing.standard_test(self, int_neg, 2)
|
|
|
|
def test_hash_getitem_slice(self):
|
|
s = GetItemSource(LocalSource("foo"), slice(None, -1, None))
|
|
s2 = GetItemSource(LocalSource("foo"), slice(None, -1, None))
|
|
s3 = GetItemSource(LocalSource("foo"), slice(None, -1, 2))
|
|
some_set = set()
|
|
|
|
self.assertTrue(s not in some_set)
|
|
self.assertTrue(s2 not in some_set)
|
|
self.assertTrue(s3 not in some_set)
|
|
|
|
some_set.add(s)
|
|
|
|
self.assertTrue(s in some_set)
|
|
# s and s2 should hash the same
|
|
self.assertTrue(s2 in some_set)
|
|
# s3 should be different
|
|
self.assertTrue(s3 not in some_set)
|
|
|
|
self.assertTrue(s == s2)
|
|
self.assertTrue(s != s3)
|
|
|
|
def test_add_sizes(self):
|
|
def func(x):
|
|
y = x.size()
|
|
return y + y
|
|
|
|
eager_out = func(torch.ones(10, 10, 3))
|
|
compile_out = torch._dynamo.optimize("eager")(func)(torch.ones(10, 10, 3))
|
|
self.assertTrue(isinstance(compile_out, torch.Size))
|
|
self.assertEqual(eager_out, compile_out)
|
|
|
|
def test_nested_function_resuming_with_correct_globals(self):
|
|
# https://github.com/pytorch/pytorch/issues/99665
|
|
try:
|
|
from .utils import outer_func
|
|
except ImportError:
|
|
from utils import outer_func
|
|
|
|
def gn(x, y):
|
|
return x + y
|
|
|
|
def fn(x, y):
|
|
return outer_func(gn)(x, y)
|
|
|
|
x = torch.rand([3])
|
|
y = torch.rand([3])
|
|
opt_fn = torch.compile(backend="eager")(fn)
|
|
ref = fn(x, y)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
|
|
class CustomFunc1(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, foo):
|
|
return foo + foo
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
return grad_output
|
|
|
|
|
|
class CustomFunc2(torch.autograd.Function):
|
|
    # the forward function can be a staticmethod or a classmethod
|
|
@classmethod
|
|
def forward(cls, ctx, foo):
|
|
return foo + foo
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
return grad_output
|
|
|
|
|
|
class CustomFunc3(torch.autograd.Function):
|
|
    # Test that there is a graph break in the forward function
|
|
@staticmethod
|
|
def forward(ctx, foo):
|
|
result = foo + foo
|
|
torch._dynamo.graph_break()
|
|
result = result + foo
|
|
ctx.save_for_backward(result)
|
|
return result
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
(result,) = ctx.saved_tensors
|
|
return grad_output * math.sqrt(result.numel())
|
|
|
|
|
|
class Module1(torch.nn.Module):
|
|
def forward(self, foo):
|
|
return CustomFunc1().apply(foo)
|
|
|
|
|
|
class Module2(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fn = CustomFunc1.apply
|
|
|
|
def forward(self, foo):
|
|
return self.fn(foo)
|
|
|
|
|
|
class Module3(torch.nn.Module):
|
|
def forward(self, foo):
|
|
return CustomFunc2().apply(foo)
|
|
|
|
|
|
class Module4(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fn = CustomFunc2.apply
|
|
|
|
def forward(self, foo):
|
|
return self.fn(foo)
|
|
|
|
|
|
class Module5(torch.nn.Module):
|
|
def forward(self, foo):
|
|
return CustomFunc3().apply(foo)
|
|
|
|
|
|
class Module6(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fn = CustomFunc3.apply
|
|
|
|
def forward(self, foo):
|
|
return self.fn(foo)
|
|
|
|
|
|
class TestTracer(JitTestCase):
|
|
def test_jit_save(self):
|
|
def fn():
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.a = 3
|
|
|
|
@torch.jit.export
|
|
def __getstate__(self):
|
|
return (3, self.training)
|
|
|
|
@torch.jit.export
|
|
def __setstate__(self, state):
|
|
self.a = state[0]
|
|
self.training = state[1]
|
|
|
|
def forward(self, x):
|
|
return x + self.a
|
|
|
|
f = Foo()
|
|
|
|
return torch.jit.trace(f, (torch.rand(3, 4),))
|
|
|
|
fn()
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
opt_fn()
|
|
|
|
|
|
if __name__ == "__main__":
|
|
from torch._dynamo.test_case import run_tests
|
|
|
|
run_tests()
|