mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
For some tensor x, x.type(torch.FloatTensor) does essentially the same thing as x.to(torch.float). x.type can be called with at least 3 kinds of input:
* a string "torch.FloatTensor"
* a dtype torch.float
* a tensor type torch.FloatTensor
The third option (torch.FloatTensor) fails in fx, because fx cannot trace torch.FloatTensor objects. So this PR replaces the torch.FloatTensor type with the string "torch.FloatTensor".
Why not fix this in fx? Well, it's possible, but I'm not sure of a nice way to do it. We would want to update `torch.fx.node.BaseArgumentTypes` (torch/fx/node.py, line 17, at commit d88bc38b0c) to contain torch.FloatTensor etc. We could hard-code a list of tensor types there (the types vary depending on build type, e.g. whether or not cuda tensors are available), but that's not great in case our hard-coded list differs from the actual list registered by python_tensor.cpp. Another option is to dynamically populate the list of types with `Union[tuple(...)]`, filling the tuple with `torch._tensor_classes` (which is directly populated by python_tensor.cpp), but apparently this breaks most typecheckers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93043
Approved by: https://github.com/jansel
911 lines
22 KiB
Python
911 lines
22 KiB
Python
# Owner(s): ["module: dynamo"]
|
|
# flake8: noqa
|
|
import collections
|
|
import functools
|
|
import inspect
|
|
import itertools
|
|
import operator
|
|
import unittest
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
|
|
import torch
|
|
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
from torch import sub
|
|
from torch._dynamo.testing import requires_static_shapes
|
|
from torch._dynamo.utils import same
|
|
from torch.nn import functional as F
|
|
|
|
# Module-level tensor that nothing in the visible part of this file reads
# (presumably imported by another test module -- TODO confirm).
tensor_for_import_testing = torch.ones(10, 10)
# Global tensor read by FunctionTests.test_globalvar.
d = torch.ones(10, 10)
# Global module called by FunctionTests.test_globalmodule.
e = torch.nn.Linear(10, 10)
# Global bool branched on by FunctionTests.test_load_global_bool.
flag = True
|
|
|
|
|
|
def constant3(a, b):
    """Return ``a - b`` plus the constant ``1.0 + 2``.

    Called (positionally and by keyword) from the ``test_methodcall*`` tests.
    """
    shift = 1.0 + 2
    difference = a - b
    return difference + shift
|
|
|
|
|
|
def func_with_default(a, b, some_default_arg=True):
    """Return ``a - b`` when ``some_default_arg`` is truthy, otherwise ``None``.

    Called without the third argument by ``test_inline_with_default``.
    """
    if not some_default_arg:
        return None
    return a - b
|
|
|
|
|
|
def make_test(fn):
    """Turn ``fn`` into a test method that runs it through ``standard_test``.

    ``fn`` declares only its own inputs (no ``self``); the number of
    parameters it declares is forwarded to
    ``torch._dynamo.testing.standard_test`` as ``nargs``.

    The returned wrapper takes over ``fn``'s ``__name__`` and ``__doc__`` so
    that each generated test is identifiable in tracebacks instead of all
    being called ``test_fn``. ``functools.wraps`` is deliberately not used:
    it would set ``__wrapped__``, and tools that follow it when inspecting
    signatures would see ``fn``'s parameters rather than ``self``.

    Args:
        fn: the function under test.

    Returns:
        A function taking a single ``self`` argument (the test case) that
        returns whatever ``standard_test`` returns.
    """
    nargs = len(inspect.signature(fn).parameters)

    def test_fn(self):
        return torch._dynamo.testing.standard_test(self, fn=fn, nargs=nargs)

    # Copy identifying metadata only; fall back to the wrapper's own name for
    # callables that have no ``__name__``.
    test_fn.__name__ = getattr(fn, "__name__", test_fn.__name__)
    test_fn.__doc__ = getattr(fn, "__doc__", None)
    return test_fn
|
|
|
|
|
|
@torch.jit.script_if_tracing
def inline_script_if_tracing(x):
    """Return ``x + 1.2``; decorated with ``torch.jit.script_if_tracing``."""
    offset = 1.2
    return x + offset
|
|
|
|
|
|
@torch.jit.ignore
def inline_ignore(x):
    """Return ``x + 3.4``; decorated with ``torch.jit.ignore``."""
    offset = 3.4
    return x + offset
|
|
|
|
|
|
@torch.jit.unused
def inline_unused(x):
    """Return ``x + 5.6``; decorated with ``torch.jit.unused``."""
    offset = 5.6
    return x + offset
|
|
|
|
|
|
class FunctionTests(torch._dynamo.test_case.TestCase):
    """Small functions, each exercising one Python/torch construct.

    A function decorated with ``make_test`` declares only its own inputs (no
    ``self``); ``make_test`` replaces it with a test method that hands the
    function and its parameter count to
    ``torch._dynamo.testing.standard_test``. The bodies are deliberately
    written in one specific form each, so keep them as they are.
    """

    @make_test
    def test_inline_jit_annotations(x):
        x = inline_script_if_tracing(x)
        x = inline_ignore(x)
        x = inline_unused(x)
        # NOTE(review): bare return -- the function yields None and the
        # computed x is discarded. Confirm this is intentional.
        return

    # --- basic arithmetic, in-place ops and constants ---

    @make_test
    def test_add(a, b):
        return a + b

    @make_test
    def test_add_(a, b):
        a_copy = torch.tensor(a)
        return a_copy.add_(b, alpha=5.0)

    @make_test
    def test_addcdiv(a, b, c):
        # dynamo decomposes this to avoid a graph break when
        # the value kwarg is populated
        return torch.addcdiv(a, b, c, value=5.0)

    @make_test
    def test_addcdiv_(a, b, c):
        a_copy = torch.tensor(a)
        return a_copy.addcdiv_(b, c, value=5.0)

    @make_test
    def test_is_not_null(a, b):
        if a is not None and b is not None:
            return a + b

    @make_test
    def test_constant1(a, b, c):
        return a - b * c + 1.0

    @make_test
    def test_constant2(a, b, c):
        return a - b * c + 1

    @make_test
    def test_constant3(a):
        b = 1
        c = 2
        d = 3
        return b + c - d + a

    @make_test
    def test_constant4(a, b):
        c = 2
        d = 3
        if c > d:
            return a - b
        return b - a

    @make_test
    def test_finfo(a, b):
        if torch.iinfo(torch.int32).bits == 32:
            return torch.finfo(a.dtype).min * b

    # --- different ways of reaching and calling the same op ---

    @make_test
    def test_globalfn(a, b):
        return sub(a, b)

    @make_test
    def test_viatorch(a, b):
        return torch.sub(a, b)

    @make_test
    def test_viamethod(a, b):
        return a.sub(b)

    @make_test
    def test_indirect1(a, b):
        t = a.sub
        return t(b)

    @make_test
    def test_indirect2(a, b):
        t = a.sub
        args = (b,)
        return t(*args)

    @make_test
    def test_indirect3(a, b):
        t = a.sub
        args = (b,)
        kwargs = {}
        return t(*args, **kwargs)

    @make_test
    def test_methodcall1(a, b, c):
        return constant3(a, b) * c

    @make_test
    def test_methodcall2(a, b):
        return constant3(a=b, b=a) + 1

    @make_test
    def test_methodcall3(a, b):
        return constant3(a, b=1.0) + b

    @make_test
    def test_device_constant(a):
        return a + torch.ones(1, device=torch.device("cpu"))

    @make_test
    def test_tuple1(a, b):
        args = (a, b)
        return sub(*args)

    @make_test
    def test_tuple2(a, b):
        args = [a, b]
        return sub(*args)

    @make_test
    def test_is_in_onnx_export(x, y):
        if torch.onnx.is_in_onnx_export():
            return x - 1
        else:
            return y + 1

    @make_test
    def test_is_fx_tracing(x, y):
        if torch.fx._symbolic_trace.is_fx_tracing():
            return x - 1
        else:
            return y + 1

    # --- passing sequences to torch.cat positionally / by keyword / unpacked ---

    @make_test
    def test_listarg1(a, b):
        return torch.cat([a, b])

    @make_test
    def test_listarg2(a, b):
        return torch.cat((a, b), dim=0)

    @make_test
    def test_listarg3(a, b):
        kwargs = {"tensors": (a, b), "dim": 0}
        return torch.cat(**kwargs)

    @make_test
    def test_listarg4(a, b):
        return torch.cat(tensors=[a, b], dim=0)

    @make_test
    def test_listarg5(a, b):
        args = [(a, b)]
        kwargs = {"dim": 0}
        return torch.cat(*args, **kwargs)

    # --- indexing and slicing ---

    @make_test
    def test_slice1(a):
        return a[5]

    @make_test
    def test_slice2(a):
        return a[:5]

    @make_test
    def test_slice3(a):
        return a[5:]

    @make_test
    def test_slice4(a):
        return a[2:5]

    @make_test
    def test_slice5(a):
        return a[::2]

    @make_test
    def test_slice6(a):
        return torch.unsqueeze(a, 0)[:, 2:]

    @make_test
    def test_range1(a):
        return torch.tensor(range(a.size(0)))

    @make_test
    def test_range2(x, y):
        r = x + y
        for i in range(x.size(0) + 2):
            r = r / y
        return r

    # --- unpacking ---

    @make_test
    def test_unpack1(a):
        a, b = a[:5], a[5:]
        return a - b

    @make_test
    def test_unpack2(a):
        packed = [a[:5], a[5:]]
        a, b = packed
        return a - b

    @make_test
    def test_unpack3(a):
        packed = (a[:5], a[5:])
        a, b = packed
        return a - b

    @make_test
    def test_fn_with_self_set(a, b):
        # avg_pool2d is an odd one with __self__ set
        return F.avg_pool2d(
            torch.unsqueeze(a, 0) * torch.unsqueeze(b, 1), kernel_size=2, padding=1
        )

    @make_test
    def test_return_tuple1(a, b):
        return (a - b, b - a, a, b)

    # --- module-level globals (d, e, flag) and inlined helpers ---

    @make_test
    def test_globalvar(a, b):
        return a - b + d

    @make_test
    def test_globalmodule(x):
        return e(x)

    @make_test
    def test_inline_with_default(a, b, c):
        return func_with_default(a, b) * c

    @make_test
    def test_inner_function(x):
        def fn(x):
            return torch.add(x, x)

        return fn(x)

    @make_test
    def test_transpose_for_scores(x):
        new_x_shape = x.size()[:-1] + (2, 5)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1)

    @make_test
    def test_return_tuple2(x):
        return (torch.add(x, x), x)

    @make_test
    def test_load_global_bool(x):
        if flag:
            return torch.add(x, x)
        else:
            return x

    # --- len() / dict() / float() builtins ---

    @make_test
    def test_len_tensor(x):
        z = len(x)
        return torch.add(x, z)

    @make_test
    def test_len_constant_list(x):
        z = len([1, 2, 3])
        return torch.add(x, z)

    @make_test
    def test_len_constant_dict(x):
        z = len({"foo": "bar"})
        return torch.add(x, z)

    @make_test
    def test_dict_copy(x):
        z = dict({"foo": x + 1})
        return z

    @make_test
    def test_len_constant_misc_iterables(x):
        a = len((1, 2, 3))
        b = len("test str")
        c = a + b
        return torch.add(x, c)

    @make_test
    def test_float(x):
        y = float(1.2)
        y += float("1.2")
        return torch.add(x, y)

    # --- tensor attribute queries used in branch conditions ---

    @make_test
    def test_dtype(x):
        if x.dtype == torch.float32:
            return x + 1

    @make_test
    def test_get_default_dtype(x):
        if x.dtype == torch.get_default_dtype():
            return x + 1
        else:
            return x - 1

    @make_test
    def test_device(x):
        if not x.is_cuda:
            return x + 1

    # --- Tensor.type() called with a string, a tensor type, or another
    # tensor's type() result ---

    @make_test
    def test_tensor_type(a, b):
        m = a.to(torch.float16)
        return b.type(m.type())

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_tensor_type2(a, b):
        m = a.to("cuda")
        return m + b.type(m.type())

    @make_test
    def test_tensor_type3(a, b):
        m = a.type(torch.HalfTensor)
        return b.type(m.type())

    @make_test
    def test_tensor_type4(a, b):
        m = a.type("torch.HalfTensor")
        return b.type(m.type())

    @unittest.skipIf(not torch.cuda.is_available(), "requires cuda")
    @make_test
    def test_tensor_type5(a, b):
        m = a.type(torch.cuda.HalfTensor)
        return b.type(m.type())

    @make_test
    def test_ndim(x):
        if x.ndim == 2 and x.ndimension() == 2 and x.dim() == 2:
            return x + 1

    @make_test
    def test_T(x):
        return torch.ones_like(x.T)

    @make_test
    def test_is_sparse(x):
        if not x.is_sparse:
            return x + 1

    @requires_static_shapes
    @make_test
    def test_shape1(x):
        if x.shape[0] == 10:
            return x + 1

    @requires_static_shapes
    @make_test
    def test_shape2(x):
        if x.size(1) == 10:
            return x + 1

    @make_test
    def test_del(a, b):
        c = a + 1
        d = c + 2
        del c, a
        return b + d

    @requires_static_shapes
    @make_test
    def test_chunks1(x):
        chunk_size = 5
        assert x.shape[0] % chunk_size == 0
        assert x.shape[0] // chunk_size == 2
        return x[:chunk_size] - x[chunk_size:]

    @make_test
    def test_import1(x, y):
        # Function-local imports are the construct under test here.
        import torch
        from torch import sub

        return sub(torch.add(x, y), y)

    # --- returning / mutating containers, closures ---

    @make_test
    def test_return_dict(x, y):
        z = [x + y, y, False]
        return {"x": x, "z": z, "a": x, "b": z, "c": x}

    @make_test
    def test_return_dict2(x, y):
        tmp = {"x": x}
        tmp["z"] = [x + y, y]
        tmp["y"] = y
        tmp["z"].append(False)
        return tmp

    @make_test
    def test_funcdef_closure(x, y):
        x = x + y + 1.0

        def inner(z):
            # Rebinds both variables of the enclosing function.
            nonlocal x, y
            y = x + z + 20.0
            x = y + z + 10.0

        inner(2.0)
        inner(3.0)

        return x, y

    @make_test
    def test_module_constant(x, y):
        r = x + y
        for i in range(torch._dynamo.testing.three):
            r = r / y
        return r

    @make_test
    def test_inline_softmax(x, y):
        # This is common in some huggingface models
        return torch.nn.Softmax(dim=-1)(x + y * 2)

    @make_test
    def test_dtype_compare(a, b):
        if a.dtype == torch.float16:
            return a + 10
        if a.dtype == torch.float32:
            return a - b * 32

    @make_test
    def test_build_list_unpack(a, b):
        it1 = (x + 1 for x in (a, b))
        it2 = (x - 1 for x in (a, b))
        return torch.cat([*it1, *it2], dim=-1)

    @make_test
    def test_tensor_len(a, b):
        return a + b + len(a) + b.__len__()

    # --- list / tuple / dict methods ---

    @make_test
    def test_pop(a, b):
        ll = [a, b]
        ll.append(a + 1)
        ll.extend(
            [
                b + 2,
                a + b,
            ]
        )
        ll.pop(-1)
        ll.pop(0)
        ll.pop()
        v1, v2 = ll
        return v1 - v2

    @make_test
    def test_list_convert(a, b):
        ll = [a + 2, b]
        ll = tuple(ll)
        tmp = b + 3
        ll = list(ll)
        v1, v2 = ll
        return v1 - v2 + tmp

    @make_test
    def test_list_add(a, b):
        l1 = (a, b)
        l2 = ()  # being a LOAD_CONST in the bytecode
        l3 = l1 + l2
        return l3[0] + l3[1]

    @make_test
    def test_startswith(a, b):
        x = a + b
        if "foobar".startswith("foo") and "test" in constant3.__module__:
            x = x + 1
        return x

    @make_test
    def test_dict_ops(a, b):
        tmp = {"a": a + 1, "b": b + 2}
        v = tmp.pop("b") + tmp.get("a") + tmp.get("missing", 3) + tmp.pop("missing", 4)
        tmp.update({"d": 3})
        tmp["c"] = v + tmp["d"]
        if "c" in tmp and "missing" not in tmp:
            return tmp["c"] - tmp["a"] + len(tmp)

    # The next two are ordinary test methods: the traced function closes over
    # an nn.Parameter, so it is built inside the method and wrapped by hand.

    def test_dict_param_keys(self):
        a_param = torch.nn.Parameter(torch.ones([4, 4]))

        def fn(a):
            tmp = {"a": a, a_param: 3}
            return tmp["a"] + tmp[a_param]

        test = make_test(fn)
        test(self)

    def test_default_dict(self):
        dd = collections.defaultdict(dict)
        param = torch.nn.Parameter(torch.ones([2, 2]))

        def fn(x):
            dd["a"] = x + 1
            dd[param] = 123
            dd["c"] = x * 2
            # "b" is never assigned, so this lookup goes through the
            # defaultdict's default factory (dict).
            return dd["b"], dd

        test = make_test(fn)
        test(self)

    # --- builtins: min / max / sum / map / functools.reduce ---

    @make_test
    def test_min_max(a, b):
        c = a + b
        a = a.sum()
        b = b.sum()
        a = min(max(a, 0), 1)
        b = max(0, min(1, b))
        return max(a, b) - min(a, b) + c

    @make_test
    def test_map_sum(a, b, c, d):
        return sum(map(lambda x: x + 1, [a, b, c, d]))

    @make_test
    def test_reduce(a, b, c, d):
        return functools.reduce(operator.add, [a, b, c, d])

    @make_test
    def test_tuple_contains(a, b):
        v1 = "a"
        v2 = "b"
        v3 = "c"
        vals1 = (v1, v2, v3)
        vals2 = ("d", "e", "f")
        if "a" in vals1 and "b" not in vals2:
            return a + b
        return a - b

    @make_test
    def test_tuple_iadd(a, b):
        output = (a, b)
        output += (a + b, a - b)
        return output

    # --- starred unpacking with the star in first / last / middle position ---

    @make_test
    def test_unpack_ex1(x):
        output = (x, x + 1, x + 2, x + 3)
        a, b, *cd = output
        return a - b / cd[0]

    @make_test
    def test_unpack_ex2(x):
        output = (x, x + 1, x + 2, x + 3)
        *ab, c, d = output
        return c - d / ab[0]

    @make_test
    def test_unpack_ex3(x):
        output = (x, x + 1, x + 2, x + 3)
        a, *bc, d = output
        return a - d / bc[0]

    @make_test
    def test_const_tuple_add1(x):
        output = (x, x + 1, x + 2, x + 3)
        output = () + output + ()
        return output[2] + output[3]

    @make_test
    def test_const_tuple_add2(x):
        output = (x, x + 1, x + 2, x + 3)
        output = (None,) + output + (None,)
        return output[2] + output[3]

    @make_test
    def test_list_truth(a, b):
        tmp = [1, 2, 3]
        if tmp:
            return a + b
        else:
            return a - b

    @make_test
    def test_list_reversed(a, b):
        tmp = [a + 1, a + 2, a + 3]
        return a + b + next(iter(reversed(tmp)))

    @make_test
    def test_list_clear(a, b):
        tmp = [a + 1, a + 2]
        tmp.clear()
        tmp.append(a + b)
        return tmp

    @make_test
    def test_islice_chain(a, b):
        tmp1 = [a + 1, a + 2]
        tmp2 = [a + 3, a + 4]
        a, b = list(itertools.islice(itertools.chain(tmp1, tmp2), 1, 3))
        c = next(itertools.islice(tmp1, 1, None))
        return a - b / c

    @make_test
    def test_is_quantized(a, b):
        if not a.is_quantized:
            return a + b

    # --- f-strings ---

    @make_test
    def test_fstrings1(a, b):
        x = 1.229
        tmp = f"{x:.2f} bar"
        if tmp.startswith("1.23"):
            return a + b

    @requires_static_shapes
    @make_test
    def test_fstrings2(x):
        tmp = f"{x.shape[0]} bar"
        if tmp.startswith("10"):
            return x + 1

    @make_test
    def test_fstrings3(x):
        tmp = f"{x.__class__.__name__} foo"
        if tmp.startswith("Tensor"):
            return x + 1

    # The two Tensor.new tests only assert on the size and return nothing.

    @requires_static_shapes
    @make_test
    def test_tensor_new_with_size(x):
        y = torch.rand(5, 8)
        z = x.new(y.size())
        assert z.size() == y.size()

    @requires_static_shapes
    @make_test
    def test_tensor_new_with_shape(x):
        y = torch.rand(5, 8)
        z = x.new(y.shape)
        assert z.size() == y.size()

    @make_test
    def test_jit_annotate(x):
        y = torch.jit.annotate(Any, x + 1)
        return y + 2

    @requires_static_shapes
    @make_test
    def test_is_contiguous_memory_format(tensor):
        if torch.jit.is_scripting():
            return None
        elif tensor.is_contiguous(memory_format=torch.contiguous_format):
            return tensor + 1

    @make_test
    def test_list_slice_assignment(x):
        m = [1, 2, 3, 4]
        m[1:] = [6] * (len(m) - 1)
        return x + 1

    # --- torch.distributed / torch.distributions ---

    @make_test
    def test_distributed_is_available(x):
        if torch.distributed.is_available():
            return x + 1
        else:
            return x - 1

    @unittest.skipIf(
        not torch.distributed.is_available(), "requires distributed package"
    )
    @make_test
    def test_distributed_is_initialized(x):
        if torch.distributed.is_initialized():
            return x + 1
        else:
            return x - 1

    @make_test
    def test_torch_distributions_functions(x):
        normal = torch.distributions.Normal(x, torch.tensor(1))
        independent = torch.distributions.Independent(normal, 1)
        return independent.log_prob(x)

    @make_test
    def test_context_wrapping_nested_functions_no_closure(x):
        @torch.no_grad()
        def augment(x: torch.Tensor) -> torch.Tensor:
            return (x + 1) * 2

        return augment(x)

    # # This is to test the new syntax for pattern matching
    # # ("match ... case ...") added on python 3.10.
    # # Uncomment these test cases if you run on 3.10+
    # @make_test
    # def test_match_sequence(a):
    #     point = (5, 8)
    #     match point:
    #         case (0, 0):
    #             return a
    #         case (0, y):
    #             return a - y
    #         case (x, 0):
    #             return a + x
    #         case (x, y):
    #             return a + x - y

    # @make_test
    # def test_match_mapping_and_match_keys(x):
    #     param = {"a": 0.5}
    #     match param:
    #         case {"a": param}:
    #             return x * param
    #         case {"b": param}:
    #             return x / param
|
|
|
|
|
|
# Increment both arguments in place and return them.
#
# The defaults are tensors created once, when the function is defined, so
# calls that omit the arguments keep mutating the same two tensors.
# DefaultsTests.test_func_default_tensor_args relies on that accumulation and
# patches __defaults__ / __kwdefaults__ on this function, so the signature and
# the two add_ calls should stay exactly as they are.
def global_func_with_default_tensor_args(
    x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))
):
    x.add_(1)
    kw_x.add_(1)
    return x, kw_x
|
|
|
|
|
|
class ModuleWithDefaultTensorArgsMethod(torch.nn.Module):
    """Module whose ``forward`` has tensor default arguments.

    ``forward`` increments both arguments in place and returns them. The
    defaults are created once, when the method is defined, so calls that omit
    the arguments keep mutating the same two tensors.
    ``DefaultsTests.test_meth_default_tensor_args`` patches
    ``forward.__defaults__`` / ``forward.__kwdefaults__``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x=torch.zeros((2, 2)), *, kw_x=torch.zeros((1, 2))):
        x.add_(1)
        kw_x.add_(1)
        return x, kw_x
|
|
|
|
|
|
class WrapperModule(torch.nn.Module):
    """Module that calls a ``ModuleWithDefaultTensorArgsMethod`` with no arguments.

    Because nothing is passed, the inner module's ``forward`` runs with its
    default tensors.
    """

    def __init__(self):
        super().__init__()
        self.m = ModuleWithDefaultTensorArgsMethod()

    def forward(self):
        return self.m()
|
|
|
|
|
|
class DefaultsTests(torch._dynamo.test_case.TestCase):
    """Tests for compiled functions/modules whose callees have torch-typed defaults."""

    def test_func_default_tensor_args(self):
        """
        Tests that we indeed reference (and mutate) "the one" default tensor arg
        stored on the globally allocated function object, both from the orig and
        compiled function
        """

        def func():
            return global_func_with_default_tensor_args()

        cnts = torch._dynamo.testing.CompileCounter()
        compiled_func = torch.compile(func, backend=cnts)
        # Alternate between the eager and the compiled function; both must
        # keep incrementing the same default tensors.
        for i in range(4):
            if i % 2 == 0:
                x, kw_x = func()
            else:
                x, kw_x = compiled_func()
            # the inner func mutates += 1 each call
            self.assertTrue(same(x, torch.ones_like(x) + i))
            self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))
        # Calling compiled_func twice does not recompile
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

        # But with a change to the guarded default tensor, we do recompile
        with patch.object(
            global_func_with_default_tensor_args,
            "__defaults__",
            (torch.ones((3, 4, 5)),),
        ):
            x, kw_x = compiled_func()
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 4)

        # Same check for the keyword-only default.
        with patch.object(
            global_func_with_default_tensor_args,
            "__kwdefaults__",
            {"kw_x": torch.ones((3, 4, 5))},
        ):
            x, kw_x = compiled_func()
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 6)

    def test_meth_default_tensor_args(self):
        """
        Tests that we indeed reference (and mutate) "the one" default tensor arg
        stored on the module's forward method, both from the orig and
        compiled module
        """
        mod = WrapperModule()
        cnts = torch._dynamo.testing.CompileCounter()
        compiled_mod = torch.compile(mod, backend=cnts)
        # Alternate between the eager and the compiled module; both must keep
        # incrementing the same default tensors.
        for i in range(4):
            if i % 2 == 0:
                x, kw_x = mod()
            else:
                x, kw_x = compiled_mod()
            # the inner forward mutates += 1 each call
            self.assertTrue(same(x, torch.ones_like(x) + i))
            self.assertTrue(same(kw_x, torch.ones_like(kw_x) + i))
        # Calling compiled_mod twice does not recompile
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 2)

        # But with a change to the guarded default tensor, we do recompile
        with patch.object(
            ModuleWithDefaultTensorArgsMethod.forward,
            "__defaults__",
            (torch.ones((3, 4, 5)),),
        ):
            x, kw_x = compiled_mod()
        self.assertEqual(cnts.frame_count, 2)
        self.assertEqual(cnts.op_count, 4)

        # Same check for the keyword-only default.
        with patch.object(
            ModuleWithDefaultTensorArgsMethod.forward,
            "__kwdefaults__",
            {"kw_x": torch.ones((3, 4, 5))},
        ):
            x, kw_x = compiled_mod()
        self.assertEqual(cnts.frame_count, 3)
        self.assertEqual(cnts.op_count, 6)

    def test_func_default_torch_args(self):
        """
        Tests other types of torch types as function default (size, dtype, device)
        """

        def func_with_default_torch_args(
            dt=torch.float16, ds=torch.Size((1, 2, 3)), dd=torch.device("cpu")
        ):
            return torch.ones(ds, dtype=dt, device=dd)

        def func():
            return func_with_default_torch_args()

        cnts = torch._dynamo.testing.CompileCounter()
        compiled_func = torch.compile(func, backend=cnts)
        out = func()
        compiled_out = compiled_func()
        # Eager and compiled outputs must agree on dtype, device and size.
        self.assertEqual(out.dtype, compiled_out.dtype)
        self.assertEqual(out.device, compiled_out.device)
        self.assertEqual(out.size(), compiled_out.size())
        self.assertEqual(cnts.frame_count, 1)
        self.assertEqual(cnts.op_count, 1)
|
|
|
|
|
|
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()
|