mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/86950 Approved by: https://github.com/Chillee
114 lines
2.6 KiB
Python
114 lines
2.6 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
|
|
import torch._dynamo
|
|
import torch._dynamo.test_case
|
|
from torch._dynamo.testing import CompileCounter
|
|
|
|
|
|
class SkipNonTensorTests(torch._dynamo.test_case.TestCase):
|
|
def test_add_tensor1(self):
|
|
def fn(a, b):
|
|
return a + b
|
|
|
|
counter = CompileCounter()
|
|
x = torch.randn(4)
|
|
y = 5
|
|
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
|
|
opt_fn(x, y)
|
|
|
|
assert counter.op_count == 1
|
|
|
|
def test_add_tensor2(self):
|
|
def fn(a, b):
|
|
return torch.add(a, b)
|
|
|
|
counter = CompileCounter()
|
|
|
|
x = torch.randn(4)
|
|
y = 5
|
|
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
|
|
opt_fn(x, y)
|
|
|
|
assert counter.op_count == 1
|
|
|
|
def test_add_tensor_list(self):
|
|
def fn(lst):
|
|
return lst[0] + lst[1]
|
|
|
|
counter = CompileCounter()
|
|
x = torch.randn(4)
|
|
y = 5
|
|
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
|
|
opt_fn([x, y])
|
|
|
|
assert counter.op_count == 1
|
|
|
|
def test_add_tensor_dict(self):
|
|
def fn(dt):
|
|
return dt["a"] + dt["b"]
|
|
|
|
counter = CompileCounter()
|
|
x = torch.randn(4)
|
|
y = 5
|
|
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
|
|
opt_fn({"a": x, "b": y})
|
|
|
|
assert counter.op_count == 1
|
|
|
|
def test_add_skip(self):
|
|
def fn(a, b):
|
|
return a + b
|
|
|
|
counter = CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(counter)(fn)
|
|
x = 4
|
|
y = 5
|
|
opt_fn(x, y)
|
|
|
|
assert counter.op_count == 0
|
|
|
|
@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
|
|
def test_recursive_list(self):
|
|
def fn(x):
|
|
return x
|
|
|
|
counter = CompileCounter()
|
|
|
|
x = []
|
|
x.append(x)
|
|
with torch._dynamo.optimize_assert(counter):
|
|
fn(x)
|
|
|
|
assert counter.op_count == 0
|
|
|
|
@patch.object(torch._dynamo.config, "raise_on_ctx_manager_usage", False)
|
|
def test_custom_list(self):
|
|
def fn(x):
|
|
return x[0] + x[1]
|
|
|
|
counter = CompileCounter()
|
|
|
|
class Foo(list):
|
|
def __iter__(self):
|
|
raise Exception()
|
|
|
|
def __len__(self):
|
|
raise Exception()
|
|
|
|
x = Foo()
|
|
x.append(torch.randn(4))
|
|
x.append(torch.randn(4))
|
|
with torch._dynamo.optimize_assert(counter):
|
|
fn(x)
|
|
|
|
assert counter.op_count == 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: the test-case module is imported only when this
    # file is executed directly, then its runner is invoked.
    from torch._dynamo import test_case as _dynamo_test_case

    _dynamo_test_case.run_tests()
|