pytorch/test/jit/test_tracer.py

# Owner(s): ["oncall: jit"]
# ruff: noqa: F841
import copy
import io
import os
import sys
import unittest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function, Variable
from torch.testing import FileCheck
# Make the helper files in test/ importable
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
import warnings
# Standard library
from collections import namedtuple
from itertools import chain
from typing import Dict, List, Optional, Tuple
from torch import Tensor
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_utils import (
enable_profiling_mode_for_profiling_tests,
IS_SANDCASTLE,
skipIfCompiledWithoutNumpy,
skipIfCrossRef,
skipIfTorchDynamo,
suppress_warnings,
TemporaryFileName,
)
from torch.testing._internal.jit_utils import (
_tmp_donotuse_dont_inline_everything,
_trace,
enable_cpu_fuser,
JitTestCase,
make_global,
RUN_CUDA,
RUN_CUDA_MULTI_GPU,
)
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestTracer(JitTestCase):
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_large_nbr_kernel_args(self):
class Recurrence(nn.Module):
def __init__(self, seq_len):
super().__init__()
self.seq_len = seq_len
def forward(self, input):
input = input.transpose(0, 1)
# Main loop
output = []
for i in range(self.seq_len):
b = input[i] * 2
output.append(b)
output = torch.cat(output, 0).view(input.size(0), *output[0].size())
output = output.transpose(0, 1)
return output
input_size = 8
batch_size = 2
seq_len = 130
rec = Recurrence(seq_len)
input = torch.rand(batch_size, seq_len, input_size)
torch.cuda.set_device(0)
rec = rec.cuda()
input = input.cuda()
traced_rec = torch.jit.trace(rec, (input))
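# Tracing should handle modules that use the legacy torch.FloatTensor
# constructor inside forward.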
def test_trace_legacy_ctor(self):
class MyModule(nn.Module):
def forward(self, x):
return (x + 1, torch.FloatTensor([0]))
traced_rec = torch.jit.trace(MyModule(), torch.randn(2, 2))
def test_simple(self):
x = torch.tensor([0.4], requires_grad=True)
y = torch.tensor([0.7], requires_grad=True)
def f(x, y):
return torch.sigmoid(torch.tanh(x * (x + y)))
self.checkTrace(f, (x, y))
def test_trace_checking_with_global_name(self):
class MyClass(torch.nn.Module):
def forward(self, xs: List[Tensor]):
y = torch.cat(xs, dim=0)
return y
model = MyClass()
# Simulate these inputs being in the globals, like they would be if,
# e.g., they were defined in the outermost scope of a script
global input1, input2
input1 = torch.ones(2, 2)
input2 = torch.ones(2, 2)
m2 = torch.jit.trace(model, ((input1, input2),))
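# The example input aliases the module's own parameter; the traced module
# must still treat the runtime argument as a separate input.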
def test_trace_aliased_parameter(self):
class M(nn.Module):
def __init__(self, x):
super().__init__()
self.x = nn.Parameter(x)
def forward(self, y):
return self.x + y
m = M(torch.rand(3, 4))
r = torch.jit.trace(m, m.x)
t2 = torch.rand(3, 4)
self.assertEqual(r(t2), m.x + t2)
def test_trace_nested_fn(self):
class TracedInlineDecision(torch.nn.Module):
def forward(self, x, flag):
@torch.jit.script
def make_decision(flag, x):
if flag:
return x
else:
return torch.zeros_like(x)
x = torch.neg(x)
return make_decision(flag, x)
decision = TracedInlineDecision()
torch.jit.trace(
decision,
(torch.rand(3, 4), torch.tensor([True], dtype=torch.bool)),
check_trace=True,
)
def test_trace_single_tuple(self):
x = torch.tensor(2.0)
def f2(x):
return (x,)
jit_f2 = torch.jit.trace(f2, x)
assert f2(x) == jit_f2(x) # fails
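# Tracing an op that writes into a tuple of out= tensors (cummax) should
# capture both outputs.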
def test_trace_out_operator_with_two_output(self):
example_input = torch.rand(2, 8)
out_1, out_2 = torch.cummax(example_input, 1)
def run_cummax(example_input, out_1, out_2):
output_1, output_2 = torch.cummax(example_input, 1, out=(out_1, out_2))
return output_1, output_2
trace_model = torch.jit.trace(run_cummax, (example_input, out_1, out_2))
def test_trace_namedtuple(self):
Point = namedtuple("point", ["x", "y"])
def f(p):
if type(p) is tuple:
p = Point(*p)
return p.x + p.y
p = Point(torch.randn(1), torch.randn(1))
traced = torch.jit.trace(f, (p,))
self.assertEqual(f(p), traced(p))
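# topk with a tensor-valued k should not be baked in as a constant; rerunning
# the trace on new shapes should match eager and not emit slicing warnings.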
def test_trace_topk(self):
class M(torch.nn.Module):
def forward(self, x, y):
return x.topk(y, dim=1)[1]
mod = M()
inputs = (torch.randint(0, 10, (20, 20)), torch.tensor(17))
traced_func = torch.jit.trace(mod, inputs)
test_inputs = (torch.randint(0, 9, (9, 9)), torch.tensor(8))
eager_out = mod(*test_inputs)
traced_out = traced_func(*test_inputs)
self.assertNotWarn(
lambda: traced_func(*test_inputs),
"Shouldn't throw slicing related warn here",
)
self.assertEqual(eager_out, traced_out)
test_inputs = (torch.randint(0, 50, (50, 50)), torch.tensor(12))
eager_out = mod(*test_inputs)
traced_out = traced_func(*test_inputs)
self.assertNotWarn(
lambda: traced_func(*test_inputs),
"Shouldn't throw slicing related warn here",
)
self.assertEqual(eager_out, traced_out)
def test_typeas_trace_check(self):
a = torch.tensor([0.4], requires_grad=True)
b = torch.tensor([0.7], requires_grad=True)
def f(x, y):
return x.type_as(y)
trace = torch.jit.trace(f, (a, b))
def test_trace_index(self):
x = torch.tensor([0.4], requires_grad=True)
y = torch.tensor([0], dtype=torch.int64)
def fn(x, y):
return x[y]
fn_traced = torch.jit.trace(
fn,
(
x,
y,
),
)
self.assertEqual(fn(x, y), fn_traced(x, y))
# Backwards tracing was broken for indexing by a constant,
# because it's internally implemented using as_strided,
# and we attempted to trace its derivative (which is not
# currently supported). It works now because slice()
# is no longer marked as traceable.
def test_trace_index_constant(self):
x = torch.tensor([0.4], requires_grad=True)
def fn(x):
return x[0]
def run(f):
y = f(x)
grad = torch.autograd.grad(y, x)[0].clone()
return y, grad
traced_fn = torch.jit.trace(fn, torch.ones(1))
self.assertEqual(run(fn), run(traced_fn))
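# Assignment through a boolean mask should trace correctly and match eager
# results on a fresh input tensor.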
def test_index_put(self):
ten = torch.zeros(3, 3)
mask = torch.tensor(
[[True, True, True], [True, False, False], [True, True, False]]
)
def test_fn(ten, mask):
ten[mask] = torch.ones(6)
return ten
traced_test_fn = torch.jit.trace(test_fn, (ten, mask))
ten = torch.rand(3, 3)
self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
def test_canonicalize_tensor_iterator(self):
x = torch.randn(4, 4)
def f(x):
x = x + 2
x = x - 4
x = x * 6
x = x / 8
return x
traced = torch.jit.trace(f, (x,))
f(x)
graph = traced.graph_for(x)
# There should be 4 int constants for the right sides of operators, plus one
# for the alpha argument for add and sub
self.assertTrue(str(traced.graph_for(x)).count(": int = prim::Constant") == 5)
@suppress_warnings
def test_constant(self):
x = torch.randn(2, 2, requires_grad=True)
def f(x):
return x.matmul(torch.diag(torch.tensor([2.0, 2.0])))
self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))
def test_wrapped_number(self):
# Scalars get converted to 'wrapped' tensors of the default tensor type.
# Wrapped tensors behave differently in certain promotion operations:
# float_tensor * double -> float but wrapped_float * double -> double.
# This can cause issues in check-trace if not handled correctly in
# `aten::isclose()`.
def foobar():
x = -10000.0
result = x * torch.ones(1, dtype=torch.float)
return result
scripted = torch.jit.trace(foobar, (), check_trace=True)
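# In-place adds applied to a clone should all show up in the graph: one
# aten::clone followed by two aten::add_ nodes.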
def test_inplace_transplant(self):
x = torch.tensor([0.0], requires_grad=True)
def fn(x):
y = x.clone()
y.add_(2)
y.add_(3)
return y
g, _ = torch.jit._get_trace_graph(fn, (x,))
self.run_pass("dce", g)
FileCheck().check_count("aten::clone", 1, exactly=True).check_count(
"aten::add_", 2, exactly=True
).check_next("return").run(str(g))
self.assertExportImport(g, (x,))
def test_inplace_flags(self):
class InplaceFn(Function):
@staticmethod
def forward(ctx, x):
ctx.mark_dirty(x)
return x.add_(1)
@staticmethod
def backward(ctx, go):
return go
class RegularFn(Function):
@staticmethod
def forward(ctx, x):
return x.add(1)
@staticmethod
def backward(ctx, go):
return go
x = torch.tensor([0.0], requires_grad=True)
def fn(x):
y = RegularFn.apply(x)
y = InplaceFn.apply(y)
y = InplaceFn.apply(y)
y = RegularFn.apply(y)
return y
trace_graph, _ = torch.jit._get_trace_graph(fn, (x,), _force_outplace=True)
self.run_pass("dce", trace_graph)
ops = list(trace_graph.nodes())
for op in ops:
self.assertTrue(op.hasAttribute("inplace"))
inplace_flags = [False, True, True, False]
for op, is_inplace in zip(ops, inplace_flags):
self.assertEqual(op.i("inplace"), is_inplace)
def test_inplace_check(self):
class MyInplaceFn(Function):
@staticmethod
def forward(self, x):
x.add_(1)
self.mark_dirty(x)
return x
@staticmethod
def backward(self, grad):
return grad
def fn(x):
return MyInplaceFn.apply(x)
x = torch.randn(5, 5)
ge = torch.jit.trace(fn, (x,), _force_outplace=True, check_trace=False)
with self.assertRaisesRegex(RuntimeError, "inplace MyInplaceFn"):
ge(x)
def test_force_outplace_check_fill(self):
def f(x):
return torch.empty(x.shape).fill_(7)
x = torch.randn(10, 15)
ft = torch.jit.trace(f, x, _force_outplace=True)
self.assertEqual(f(x), ft(x))
def test_force_outplace_check_zero(self):
def f(x):
return torch.empty(x.shape).zero_()
x = torch.randn(10, 15)
ft = torch.jit.trace(f, x, _force_outplace=True)
self.assertEqual(f(x), ft(x))
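# Shared helper: a view() driven by x.shape/x.size() must stay dynamic so the
# trace also works for inputs of a different size.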
def do_trace_size(self, requires_grad):
def fn(x):
return x.view(x.shape[1] * 2, x.size(0), 2)
x = torch.randn(5, 2, 4, requires_grad=requires_grad)
y = torch.randn(4, 8, 4, requires_grad=requires_grad)
# Check that it behaves as expected
traced_fn = torch.jit.trace(fn, x)
self.assertEqual(traced_fn(y), fn(y))
self.assertEqual(traced_fn(x), fn(x))
def test_trace_size(self):
self.do_trace_size(False)
# test the different graph_executor path that happens when
# gradients are required and sizes are involved
def test_trace_size_with_grad(self):
self.do_trace_size(True)
def test_trace_numel(self):
def fn(x):
return x.numel()
x = torch.randn(2, 3, 4)
y = torch.randn(4, 5, 6)
traced_fn = torch.jit.trace(fn, x)
self.assertEqual(traced_fn(y), fn(y))
self.assertEqual(traced_fn(x), fn(x))
def do_trace_arange(self, requires_grad):
def arange(x):
return torch.arange(x.shape[0])
def arange_scalar(x):
return torch.arange(12)
def arange_start_end(x):
return torch.arange(start=x.shape[0], end=x.shape[0] + 5)
x = torch.randn(5, 3, 2, requires_grad=requires_grad)
y = torch.randn(8, 2, 4, requires_grad=requires_grad)
# Check that it behaves as expected
traced_arange = torch.jit.trace(arange, x)
self.assertEqual(traced_arange(y), arange(y))
self.assertEqual(traced_arange(x), arange(x))
traced_arange_scalar = torch.jit.trace(arange_scalar, x)
self.assertEqual(traced_arange_scalar(y), arange_scalar(y))
self.assertEqual(traced_arange_scalar(x), arange_scalar(x))
traced_arange_start_end = torch.jit.trace(arange_start_end, x)
self.assertEqual(traced_arange_start_end(y), arange_start_end(y))
self.assertEqual(traced_arange_start_end(x), arange_start_end(x))
def test_trace_arange(self):
self.do_trace_arange(False)
# test the different graph_executor path that happens when
# gradients are required and sizes are involved
def test_trace_arange_with_grad(self):
self.do_trace_arange(True)
# Test that a trace of torch.full(x.shape) doesn't store the shape as a constant
def test_trace_full_dynamic_shape(self):
def full_with_shape_like(x):
return torch.full(x.shape, 2.0)
x = torch.randn(3, 4)
ge = torch.jit.trace(full_with_shape_like, example_inputs=x)
y = torch.randn(2, 7)
self.assertEqual(ge(y).shape, y.shape)
self.assertEqual(ge(x).shape, x.shape)
# Test that the trace of setitem doesn't store shapes as constants
# Fix https://github.com/pytorch/pytorch/issues/43548
def test_trace_slice_setitem_dynamic_shape(self):
def slice_setitem(x, y):
x[:, 2] = y + 1
return x
x = torch.randn(3, 4)
traced = torch.jit.trace(slice_setitem, (x, x[:, 0]))
x = torch.randn(10, 5)
self.assertEqual(traced(x.clone(), x[:, 0]), slice_setitem(x.clone(), x[:, 0]))
# Suppression: we are intentionally slicing a tensor, we don't care that it
# will be constantified
@suppress_warnings
def do_trace_slice(self, requires_grad):
def slice(x):
results = []
for i in range(4):
results.append(x[: x.size(0) - i, i : x.size(2), i:3])
return tuple(results)
def slice_select(x):
results = []
for i in range(4):
results.append(x[:, i:, x.size(2) - 5])
return tuple(results)
x = torch.randn(5, 6, 7, requires_grad=requires_grad)
y = torch.randn(7, 8, 9, requires_grad=requires_grad)
# Check that it behaves as expected
traced_slice = torch.jit.trace(slice, x)
self.assertEqual(traced_slice(y), slice(y))
self.assertEqual(traced_slice(x), slice(x))
traced_slice_select = torch.jit.trace(slice_select, x)
self.assertEqual(traced_slice_select(y), slice_select(y))
self.assertEqual(traced_slice_select(x), slice_select(x))
def test_trace_slice(self):
self.do_trace_slice(False)
# test the different graph_executor path that happens when
# gradients are required and sizes are involved
def test_trace_slice_with_grad(self):
self.do_trace_slice(True)
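# Each cast variant should lower to exactly one aten::to node and agree with
# eager results.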
def test_trace_casts(self):
casts = [
lambda x: x.byte(),
lambda x: x.float(),
lambda x: x.cpu(),
lambda x: x.to(device="cpu"),
lambda x: x.to(dtype=torch.int64),
lambda x: x.to(device="cpu", dtype=torch.float),
lambda x: x.to(x),
]
def assertContainsCast(trace):
self.assertEqual(
sum(n.kind() == "aten::to" for n in trace.graph.nodes()), 1
)
for cast in casts:
trace = torch.jit.trace(cast, torch.randn(2, 2))
assertContainsCast(trace)
x = torch.randn(2, 2)
self.assertEqual(trace(x), cast(x))
def to_tensor(x, y):
return x.to(y)
to_tensor_trace = torch.jit.trace(
to_tensor, (torch.randn(2, 2), torch.randn(1, 8))
)
assertContainsCast(to_tensor_trace)
x, y = torch.randn(2, 2), torch.randn(1, 10)
self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
@skipIfCompiledWithoutNumpy
@skipIfCrossRef
def test_trace_warn(self):
def fn(x):
int(x) # Warning 1.
y = x * 1
if y: # Warning 2.
pass
q = [x, x * 4]
z = q[y]
float(z) # Warning 3.
z.tolist() # Warning 4.
z.numpy() # Warning 5.
for _ in torch.ones(4, 4): # Warning 6.
pass
return z + 4
with warnings.catch_warnings(record=True) as warns:
traced_fn = torch.jit.trace(fn, torch.tensor([1]))
for warn in warns:
self.assertIs(warn.category, torch.jit.TracerWarning)
warns = [str(w.message) for w in warns]
self.assertIn("a Python integer", warns[0])
self.assertIn("a Python boolean", warns[1])
self.assertIn("a Python float", warns[2])
self.assertIn("a Python list", warns[3])
self.assertIn("a NumPy array", warns[4])
self.assertIn("Iterating over", warns[5])
def test_trace_tuple(self):
def fn(x, y):
return x, (x * y[1], x * y[0])
x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
traced_fn = torch.jit.trace(fn, (x, y))
self.assertEqual(traced_fn(x, y), fn(x, y))
# should be a tuple nested within another tuple
FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next(
"return"
).run(str(traced_fn.graph))
self.assertExportImport(traced_fn.graph, (x, y))
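# torch.normal is nondeterministic, so the trace cannot be checked
# (check_trace=False); the RNG state is controlled via fork_rng so the eager
# and traced outputs can be compared.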
def test_trace_random(self):
def f(mean, std):
return torch.normal(mean, std)
traced = torch.jit.trace(
f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False
)
mean, std = torch.zeros(5, 5), torch.ones(5, 5)
with torch.random.fork_rng(devices=[]):
output = f(mean, std)
traced_output = traced(mean, std)
self.assertEqual(output, traced_output)
def test_trace_tensor_factory(self):
def run(**kwargs):
inputs_require_grads = kwargs.pop("inputs_require_grads", True)
def fn(x):
return x + torch.ones(2, 3, **kwargs)
input_kwargs = kwargs.copy()
if "out" in input_kwargs:
del input_kwargs["out"]
input = torch.ones(2, 3, **input_kwargs)
self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
# check we recorded 'ones' and did not just record a constant
tfn = torch.jit.trace(fn, input)
self.assertTrue("ones" in str(tfn.graph))
run()
run(dtype=torch.int, inputs_require_grads=False)
run(out=torch.tensor([]))
if RUN_CUDA:
run(device="cuda:0")
if RUN_CUDA_MULTI_GPU:
run(device="cuda:1")
def test_trace_indexed_assignment(self):
def stuff(x, y):
x = x.clone()
x[0] = y
return x
example = torch.rand(3, 4)
self.checkTrace(stuff, (example, example[0] + 1))
# TODO: implement
@unittest.expectedFailure
def test_output_unflatten(self):
"""Check that outputs of traced functions retain the original structure and nesting"""
def fn(x):
return (
x * 2,
(
x**2,
x + 4,
(x + 2,),
),
x * 4,
)
self.checkTrace(fn, (torch.randn(2, 2),))
def test_input_flatten(self):
"""Check that inputs to traced functions are flattened"""
def fn(x, t):
y, z = t
return x * y * z
inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
self.checkTrace(fn, inputs)
def test_input_dict_empty(self):
def test(d):
pass
with self.assertRaises(RuntimeError):
self.checkTrace(test, {})
def test_input_dict_remembers_keys(self):
"""Check that the trace remembers which keys were in a dict input"""
class TestModule(torch.nn.Module):
def forward(self, dict_input):
return dict_input["x"]
input_1 = {"x": torch.tensor(1)}
m = TestModule()
m_traced = torch.jit.trace(m, (input_1,))
self.assertEqual(m_traced(input_1), torch.tensor(1))
# should work to change the values and not the keys
input_same_key_different_value = {"x": torch.tensor(2)}
self.assertEqual(m_traced(input_same_key_different_value), torch.tensor(2))
# error to use something that doesn't have `x`
input_different_key = {"y": torch.tensor(3)}
with self.assertRaises(RuntimeError):
m_traced(input_different_key)
# it's okay to have additional elements in the dictionary, so long as 'x' is there
input_additional_key = {"x": torch.tensor(4), "y": torch.tensor(3)}
self.assertEqual(m_traced(input_additional_key), torch.tensor(4))
def test_input_dict_insertion_order(self):
"""Check that dictionary access doesn't care about insertion order"""
class TestModule(torch.nn.Module):
def forward(self, dict_input):
return dict_input["x"], dict_input["y"]
input_x_then_y = {}
input_x_then_y["x"] = torch.tensor(1)
input_x_then_y["y"] = torch.tensor(2)
m = TestModule()
m_traced = torch.jit.trace(m, (input_x_then_y,))
self.assertEqual(m_traced(input_x_then_y), (torch.tensor(1), torch.tensor(2)))
input_y_then_x = {}
input_y_then_x["y"] = torch.tensor(4)
input_y_then_x["x"] = torch.tensor(3)
self.assertEqual(m_traced(input_y_then_x), (torch.tensor(3), torch.tensor(4)))
def test_input_dict_recursive(self):
class TestModule(torch.nn.Module):
def forward(self, dict_input):
return dict_input["x"][1]
input_1 = {"x": {1: torch.tensor(1)}}
m = TestModule()
m_traced = torch.jit.trace(m, (input_1,))
input_2 = {"x": {1: torch.tensor(2)}}
self.assertEqual(m_traced(input_2), torch.tensor(2))
def test_input_dict_checkTrace_mut(self):
def test(d):
d["x"].tanh_()
return d["x"]
inputs = {"x": torch.rand(3, 4), "y": torch.rand(3, 4)}
self.checkTrace(test, (inputs,), inputs_require_grads=False)
def test_input_dict_unify(self):
def test(d):
return d["int"], d["float"]
inputs = {
"int": torch.ones((2, 2), dtype=torch.int32),
"float": torch.ones((2, 2), dtype=torch.float32),
}
self.checkTrace(test, (inputs,), inputs_require_grads=False)
def test_input_tuple_of_dicts(self):
def test(t):
d = t[0]
return d["x"]["y"]
inputs = {"x": {"y": torch.rand(2, 3)}}
self.checkTrace(test, ((inputs, inputs),), allow_unused=True)
def test_input_dict_of_dicts(self):
def test(d):
return d["x"]["y"]
nested_input = {"y": torch.rand(2, 3)}
unified_nested = {"y": torch.rand(3, 2)}
inputs = {"x": nested_input, "force_unify": unified_nested}
self.checkTrace(test, (inputs,), allow_unused=True)
def test_input_dict_of_lists(self):
def test(d):
return d["x"][0]
inputs = {"x": [torch.rand(3, 2)]}
self.checkTrace(test, (inputs,))
def test_input_list_toplevel_flatten(self):
def test(t1, t2):
return torch.add(t1, t2)
inputs = [torch.ones(2, 2), torch.rand(2, 2)]
self.checkTrace(test, inputs)
def test_input_list_toplevel_flatten_direct(self):
class Test(torch.nn.Module):
def forward(self, t1, t2):
return torch.add(t1, t2)
inputs = [torch.ones(2, 2), torch.rand(2, 2)]
torch.jit.trace(Test(), inputs)
def test_input_list_of_tuples(self):
def test(l):
return l[0][0]
inputs = [(torch.ones(2, 2),)]
self.checkTrace(test, (inputs,))
def test_input_dict_empty_list(self):
def test(d):
pass
inputs = {1: []}
with self.assertRaisesRegex(RuntimeError, "List trace"):
self.checkTrace(test, (inputs,))
def test_input_list_mixed_type(self):
def test(d):
pass
inputs = [torch.rand(2, 3), (torch.ones(2), torch.ones(2))]
with self.assertRaisesRegex(RuntimeError, "consistent"):
self.checkTrace(test, (inputs,))
def test_conv(self):
x = torch.ones(20, 16, 50, 40)
g, outputs, inputs = torch.jit._get_trace_graph(
nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True
)
m = self.createFunctionFromGraph(g)
self.assertEqual(outputs, m(*inputs))
def test_max_pool(self):
x = torch.rand(20, 16, 10, 10)
def max_pool2d(x):
return F.max_pool2d(x, 2) + 2
trace = torch.jit.trace(max_pool2d, (x))
graph = trace.graph_for(x)
FileCheck().check("aten::max_pool2d(").run(graph)
self.assertEqual(max_pool2d(x), trace(x))
def test_nested_inplace(self):
x = torch.randn(2, 2)
g, outputs, inputs = torch.jit._get_trace_graph(
lambda x: F.threshold(x, 0, 0, inplace=True), (x,), return_inputs=True
)
m = self.createFunctionFromGraph(g)
self.assertEqual(outputs, m(*inputs))
FileCheck().check("threshold_").run(str(g))
self.assertExportImport(g, (x,))
def test_repeated_input(self):
def fn(a, b):
return a + b
ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
inputs = set(ge.graph.inputs())
# 3 instead of 2 because the export/import in checkTrace adds a
# `self` module argument
self.assertTrue(len(inputs) == 3)
def test_repeated_output(self):
def fn(a, b):
z = a + b
return z, z
ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
tuple_output = list(ge.graph.outputs())[0]
tuple_inputs = list(tuple_output.node().inputs())
self.assertTrue(tuple_inputs[0] == tuple_inputs[1])
def test_inplace_copy(self):
x = torch.randn(4, 4, requires_grad=True)
def f(x):
out = torch.zeros(x.size())
out.copy_(x)
return out
g, outputs, inputs = torch.jit._get_trace_graph(f, (x,), return_inputs=True)
self.run_pass("dce", g)
m = self.createFunctionFromGraph(g)
self.assertEqual(outputs, m(*inputs))
self.assertExportImport(g, (x,))
def test_inplace_copy_force_outplace(self):
x = torch.randn(4, 4, requires_grad=True)
def f(x):
out = torch.zeros(x.size())
out.copy_(x)
return out
g, outputs, inputs = torch.jit._get_trace_graph(
f, (x,), return_inputs=True, _force_outplace=True
)
self.run_pass("dce", g)
m = self.createFunctionFromGraph(g)
self.assertEqual(outputs, m(*inputs))
self.assertExportImport(g, (x,))
FileCheck().check("expand_as").run(str(g))
def test_shared_param(self):
class MyModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.b = self.a = nn.Parameter(torch.randn(2, 2))
def forward(self, x):
return x * self.a + self.b
m = MyModule()
g, _ = torch.jit._get_trace_graph(m, (torch.randn(2, 2),))
self.run_pass("dce", g)
self.assertEqual(len(list(g.inputs())), 2)
FileCheck().check("mul").check("add").run(str(g))
def run_ge_tests(self, optimize, use_cuda):
with enable_profiling_mode_for_profiling_tests():
with torch.jit.optimized_execution(optimize):
def rand(*args):
t = torch.rand(*args).float()
if use_cuda:
t = t.cuda()
return t
self.checkTrace(
lambda a, b: a * b + b, [rand(1), rand(1)], [rand(2, 3), rand(2, 3)]
)
# trivial identity
self.checkTrace(lambda a, b: (b, a), [rand(1), rand(1)])
def foo(a):
t = a * a
return t * t, 4 * t
self.checkTrace(foo, [rand(1)])
# unused input
self.checkTrace(
lambda a, b: a * a, [rand(1), rand(1)], allow_unused=True
)
# test outputs that do not get used in grad
self.checkTrace(foo, [rand(1)], drop=1)
# test autograd fallback
self.checkTrace(
lambda a, b: a * b / (a - 2 * b) + b, [rand(1), rand(1)]
)
def test_ge_unoptimized(self):
self.run_ge_tests(False, False)
@unittest.skipIf(IS_SANDCASTLE, "NYI: fuser support for Sandcastle")
@enable_cpu_fuser
def test_ge_optimized(self):
with enable_profiling_mode_for_profiling_tests():
self.run_ge_tests(True, False)
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
def test_ge_cuda(self):
self.run_ge_tests(True, True)
# more manual test of graph executor that can be used as a scratchpad
def test_ge(self):
def foo(a, b):
return a * b / (a - b) + b
V = Variable
a, b = V(torch.rand(1)), V(torch.rand(1))
ge = torch.jit.trace(foo, (a, b))
a, b = V(torch.rand(1), requires_grad=True), V(
torch.rand(1), requires_grad=True
)
(r,) = ge(a, b)
da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
l2 = da * db + db * db
g2result = torch.autograd.grad(l2, [da, db])
r = foo(a, b)
da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
self.assertEqual(da, da2)
self.assertEqual(db, db2)
l3 = da2 * db2 + db2 * db2
g2result2 = torch.autograd.grad(l3, [da2, db2])
self.assertEqual(g2result, g2result2)
def test_trace_annotation(self):
@_trace(torch.rand(1))
def foo(a):
return a + a + a
x = torch.randn(5, 5)
self.assertEqual(foo(x), x + x + x)
@unittest.skipIf(not RUN_CUDA, "calls .cuda()")
# By default, on Ampere or later GPUs, nn.Linear computes float tensors at TF32 precision.
# We want float tensors to be computed at full precision in order to use the default precision
@with_tf32_off
def test_traced_module_cuda(self):
class Model(nn.Module):
def __init__(self, num_features, num_layers):
super().__init__()
self.num_layers = num_layers
layers = [
[nn.Linear(num_features, num_features), nn.Sigmoid()]
for _ in range(num_layers)
]
self.submodule = nn.Sequential(*chain(*layers))
def forward(self, x):
for i in range(self.num_layers):
x = self.submodule[i](x) + x
return x
model = Model(5, 3)
x = torch.randn(2, 5)
traced_model = torch.jit.trace(model, x)
# We're missing some attributes these modules had initially. Make sure we can
# still get the __repr__()
model.__repr__()
# XXX: indexing sequentials is broken
linear_submodule = next(iter(traced_model.submodule._modules.values()))
# All attributes that aren't parameters should raise
with self.assertRaises(AttributeError):
linear_submodule.in_features
linear_submodule.weight
linear_submodule.weight = nn.Parameter(
torch.randn(linear_submodule.weight.shape)
)
with self.assertRaises(RuntimeError):
del linear_submodule.weight
# Submodules can't be called
with self.assertRaises(RuntimeError):
linear_submodule(x)
# Type casts
linear_submodule.cuda()
traced_model.float().cuda()
cuda_out = traced_model(x.float().cuda())
traced_model.cpu()
cpu_out = traced_model(x.float())
self.assertEqual(cpu_out, cuda_out)
traced_model.to("cuda")
cuda_out = traced_model(x.float().cuda())
traced_model.to("cpu")
cpu_out = traced_model(x.float())
self.assertEqual(cpu_out, cuda_out)
traced_model.to(torch.get_default_dtype())
# state_dict + load_state_dict
state = {k: v.clone() for k, v in traced_model.state_dict().items()}
new_state = {k: v.clone().fill_(1) for k, v in state.items()}
out = traced_model(x)
traced_model.load_state_dict(new_state)
out_ones = traced_model(x)
traced_model.load_state_dict(state)
out_state = traced_model(x)
self.assertEqual(out, out_state)
self.assertNotEqual(out, out_ones)
@unittest.skipIf(not RUN_CUDA, "uses cuda")
def test_type_same_device(self):
class Model(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.dtype = torch.float16
def forward(self, x=None):
h = x.type(self.dtype)
return h
a = Model()
b = torch.jit.trace(
a, example_inputs=(torch.ones([1], device=torch.device("cuda")),)
)
FileCheck().check_not("device").run(b.code)
def test_export_no_reorder(self):
def func(a, b):
return a * b / (a - 2 * b) + b
recording_inputs = [
torch.tensor(
[0.55619788169860839844], dtype=torch.float32, requires_grad=True
),
torch.tensor(
[0.25947844982147216797], dtype=torch.float32, requires_grad=True
),
]
ge1 = torch.jit.trace(func, recording_inputs)
ge2 = self.getExportImportCopy(ge1)
outputs_ge1 = ge1(*recording_inputs)
outputs_ge2 = ge2(*recording_inputs)
grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs)
grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs)
self.assertTrue(outputs_ge1 == outputs_ge2)
self.assertTrue(grad_ge1 == grad_ge2)
def test_python_function(self):
class MyFn(Function):
@staticmethod
def forward(ctx, x):
return x + 1
@staticmethod
def backward(ctx, grad_output):
return grad_output
@_trace(torch.zeros(2))
def fn(x):
return MyFn.apply(x + 2) + 3
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.randn(2, 2, requires_grad=True)
fn(x)
fn(y)
def test_python_function_tup(self):
class MyFn(Function):
@staticmethod
def forward(ctx, x):
return x + 1, x - 1
@staticmethod
def backward(ctx, grad_output):
return grad_output, grad_output
@_trace(torch.zeros(2))
def fn(x):
a, b = MyFn.apply(x + 2)
return a + b + 3
x = torch.tensor([1.0, 2.0, 3.0])
y = torch.randn(2, 2, requires_grad=True)
fn(x)
fn(y)
def test_trace_detach(self):
def foo(x, w):
return torch.matmul(x, w).detach()
traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
FileCheck().check("matmul").check("detach").run(str(traced.graph))
x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
traced_result = traced(x, w)
self.assertEqual(foo(x, w), traced_result)
self.assertFalse(traced_result.requires_grad)
self.assertIsNone(traced_result.grad_fn)
def test_trace_detach_redispatch(self):
def foo(x, w):
y = torch.matmul(x, w)
assert y.requires_grad
y = y.detach()
# Make sure trace kernel redispatches to the right lower kernel.
assert not y.requires_grad
return y
x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
# With `check_trace=True` the check would run under `torch.no_grad()` and break the asserts above.
torch.jit.trace(foo, (x, w), check_trace=False)
def test_trace_detach_inplace(self):
def foo(x, w):
y = torch.matmul(x, w)
y.detach_()
return y
traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
FileCheck().check("matmul").check("detach(").run(str(traced.graph))
x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
traced_result = traced(x, w)
self.assertEqual(foo(x, w), traced_result)
self.assertFalse(traced_result.requires_grad)
self.assertIsNone(traced_result.grad_fn)
def test_trace_detach_inplace_redispatch(self):
def foo(x, w):
y = torch.matmul(x, w)
assert y.requires_grad
y.detach_()
# Make sure trace kernel redispatches to the right lower kernel.
assert not y.requires_grad
return y
x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
# With `check_trace=True` the check would run under `torch.no_grad()` and break the asserts above.
torch.jit.trace(foo, (x, w), check_trace=False)
def test_trace_slice_full_dim(self):
def foo(x):
return x[0:5, 0] + 1.0
traced = torch.jit.trace(foo, (torch.rand(5, 4),))
test_x = torch.rand(6, 3)
self.assertEqual(foo(test_x), traced(test_x))
def test_trace_dict_input(self):
class Bar(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Foo()
def forward(self, a, b):
return self.foo({"a": a, "b": b})["a"]
class Foo(torch.nn.Module):
def forward(self, x):
return {"a": x["a"] * x["b"]}
x = (torch.rand(3), torch.rand(3))
model = Bar()
self.checkTrace(model, x)
def test_trace_dict_output(self):
class TraceDictStrTensor(torch.nn.Module):
def forward(self, a, b):
return {"a": a, "b": b}
class TraceDictTensorTensor(torch.nn.Module):
def forward(self, a, b):
return {a: b, b: a}
x = (torch.rand(3), torch.rand(3))
with self.assertRaisesRegex(RuntimeError, r"Encountering a dict at the output"):
torch.jit.trace(TraceDictStrTensor(), x)
traced_dict_str_mod = torch.jit.trace(TraceDictStrTensor(), x, strict=False)
self.assertEqual(traced_dict_str_mod(*x), {"a": x[0], "b": x[1]})
traced_dict_tensor_mod = torch.jit.trace(
TraceDictTensorTensor(), x, strict=False
)
self.assertEqual(traced_dict_tensor_mod(*x), {x[0]: x[1], x[1]: x[0]})
def test_trace_with_tensor_list_output(self):
def f():
return [torch.zeros(1), torch.zeros(5)]
with self.assertWarnsRegex(
torch.jit.TracerWarning, "cause the trace to be incorrect"
):
torch.jit.trace(f, [])
traced_non_strict_f = torch.jit.trace(f, [], strict=False)
self.assertEqual(traced_non_strict_f(), f())
def test_trace_with_number_list_output(self):
def f():
return [1, 5]
with self.assertRaisesRegex(
RuntimeError, r"Only tensors.+can be output from traced functions"
):
traced_f = torch.jit.trace(f, [])
def test_trace_with_nested_tensor_list_output(self):
def f():
return [[torch.zeros(1)], [torch.zeros(5)]]
with self.assertRaisesRegex(
RuntimeError, r"Only tensors.+can be output from traced functions"
):
traced_f = torch.jit.trace(f, [])
def test_trace_with_nested_strided_tensor_output(self):
@torch.jit.script
def nt_construct(values, kv_lengths):
kv_lengths_list: List[int] = kv_lengths.tolist()
return torch._nested_tensor_from_tensor_list(
list(values.split(kv_lengths_list, dim=0)), None, None, None, None
)
def f(x, offsets):
kv_lengths = offsets[1:] - offsets[:-1]
return nt_construct(x, kv_lengths).cos()
x = torch.rand(5, 4)
offsets = torch.tensor([0, 2, 5])
ref = f(x, offsets)
f_t = torch.jit.trace(f, (x, offsets))
res = f_t(x, offsets)
self.assertEqual(ref, res)
x2 = torch.rand((8, 4))
offsets2 = torch.tensor([0, 2, 4, 8])
self.assertEqual(f(x2, offsets2), f_t(x2, offsets2))
def test_trace_variable_instantiation(self):
def random_foo(x):
return Variable(Variable(x) + 1.0)
random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
x = torch.rand(5, 6)
self.assertEqual(random_foo(x), random_foo_traced(x))
def test_trace_slice_expr_complete_type(self):
def random_foo(x):
return x + 1.0
random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
@torch.jit.script
def random_bar(x):
return random_foo_traced(x)[0:1]
x = torch.rand(3, 4)
self.assertEqual(random_bar(x), (x + 1)[0:1])
def test_trace_inline_shape(self):
# Test that the peephole optimization turns the size() call into a constant
# in the scripted fn
@torch.jit.script
def tensor_size(x: torch.Tensor) -> torch.Tensor:
return torch.tensor([x.size()[0]])
self.assertEqual(
tensor_size(
torch.rand(
15,
)
),
torch.tensor([15]),
)
traced_tensor_size = torch.jit.trace(
tensor_size,
torch.rand(
7,
),
)
self.assertEqual(
traced_tensor_size(
torch.rand(
15,
)
),
torch.tensor([15]),
)
@torch.jit.script
def use_device(x):
return torch.zeros_like(x, device=x.device)
def foo(x):
return use_device(x)
traced_tensor_size = torch.jit.trace(
foo,
torch.rand(
7,
),
)
self.run_pass("inline", traced_tensor_size.graph)
FileCheck().check("prim::device").run(traced_tensor_size.graph)
def test_trace_save(self):
def fn(x):
return x + 2
def check(func):
with TemporaryFileName() as fname:
func.save(fname)
loaded = torch.jit.load(fname)
input = torch.randn(2, 2)
self.assertEqual(func(input), loaded(input))
out = torch.jit.trace(fn, (torch.ones(2, 2),))
check(out)
def test_trace_optional_dtype(self):
class Test(torch.nn.Module):
def forward(self):
return torch.arange(5)
traced = torch.jit.trace(Test(), ())
self.assertTrue(torch.allclose(traced(), Test()()))
def test_trace_save_load_copy(self):
class Test(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = torch.nn.Conv2d(3, 3, 3)
def forward(self, x):
return self.conv(x)
traced = torch.jit.trace(Test(), torch.rand(1, 3, 224, 224))
buffer = io.BytesIO()
torch.jit.save(traced, buffer)
buffer.seek(0)
loaded = torch.jit.load(buffer)
# should work
copy.copy(loaded)
copy.deepcopy(loaded)
def test_trace_export_fns(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.a = 3
@torch.jit.export
def __getstate__(self):
return (3, self.training)
@torch.jit.export
def __setstate__(self, state):
self.a = state[0]
self.training = state[1]
def forward(self, x):
return x + self.a
f = Foo()
traced = torch.jit.trace(f, (torch.rand(3, 4),))
expected_names = ["__getstate__", "__setstate__"]
def check(mod):
self.assertTrue(
all(name in mod._c._method_names() for name in expected_names)
)
check(traced)
imported = self.getExportImportCopy(traced)
check(imported)
def test_trace_export_fns_recursive(self):
class Foo(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.a = 3
@torch.jit.export
def __getstate__(self):
return (3, self.training)
@torch.jit.export
def __setstate__(self, state):
self.a = state[0]
self.training = state[1]
def forward(self, x):
return x + self.a
class Wrapper(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.foo = Foo()
def forward(self, x):
return self.foo(x)
f = Wrapper()
traced = torch.jit.trace(f, (torch.rand(3, 4),))
expected_names = ["__getstate__", "__setstate__"]
def check(mod):
self.assertTrue(
all(name in mod._c._method_names() for name in expected_names)
)
check(traced.foo)
imported = self.getExportImportCopy(traced)
check(imported.foo)
# Note that Bar's forward can only be traced, not scripted
class Bar(nn.Module):
@torch.jit.export
def addTwo(self, x):
return x + 2
def forward(self, input):
return (lambda a: a + 1)(input) # noqa: PLC3002
# When tracing Bar as a submodule, we only want to script the
# exported methods and keep forward traced.
class WrapperExports(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.bar = Bar()
@torch.jit.export
def addOne(self, x):
return x + 1
def forward(self, x):
return self.bar(x)
f = WrapperExports()
traced = torch.jit.trace(f, (torch.rand(3, 4),))
expected_names = ["addOne"]
check(traced)
def test_trace_autograd_function(self):
class TestFunc(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
return torch.neg(input)
@staticmethod
def backward(ctx, grad_output):
return torch.neg(grad_output)
class TracedModule(torch.nn.Module):
def forward(self, x):
return torch.relu(TestFunc.apply(x))
class Wrapper(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.tm = TracedModule()
def forward(self, x):
return self.tm(x)
traced = torch.jit.trace(Wrapper(), (torch.rand(3, 4),))
def test_trace_multi_output_function(self):
# An autograd.Function with two outputs.
# It swaps inputs so we can check if shape
# handling is correct in TorchScript.
class Foo(torch.autograd.Function):
@staticmethod
def forward(ctx, x, y):
return y, x
@staticmethod
def backward(ctx, du, dv):
return dv, du
class Bar(torch.nn.Module):
def forward(self, x, y):
x = x.relu()
y = y.relu()
z = Foo.apply(x, y)
return z
x = torch.rand(3, 2, dtype=torch.double)
y = torch.rand(1, 2, dtype=torch.double)
# Generate JIT IR.
traced = torch.jit.trace(Bar(), (x, y))
print(traced.graph)
# Expected output schema of the custom autograd.Function.
schema = (
"(Double(1, 2, strides=[2, 1], requires_grad=0, device=cpu), "
"Double(3, 2, strides=[2, 1], requires_grad=0, device=cpu)) "
"= ^Foo"
)
# See if expected schema exists.
FileCheck().check(schema).run(traced.graph)
# Also examine if the graph is runnable and produces
# the right result.
u, v = traced(x, y)
self.assertEqual(u, y)
self.assertEqual(v, x)
def test_interpolate_trace(self):
class test(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(1, 32, kernel_size=3, padding=1)
def forward(self, x):
y = self.conv(x)
w = nn.functional.interpolate(
y, mode="bilinear", align_corners=False, scale_factor=3
)
return w
f = test()
# no failure
g = torch.jit.trace(f, (torch.zeros(1, 1, 28, 28),))
x = torch.zeros(1, 1, 14, 14)
# constants not baked in
self.assertEqual(g(x), f(x))
@_tmp_donotuse_dont_inline_everything
def test_trace_optional(self):
@torch.jit.script
def test(x: Optional[Tensor]):
if x is None:
return torch.zeros(1)
else:
return x
def test_none():
return test(None)
def test_tensor():
return test(torch.zeros(2))
f_none = torch.jit.trace(test_none, ())
self.assertEqual(f_none(), torch.zeros(1))
f_tensor = torch.jit.trace(test_tensor, ())
self.assertEqual(f_tensor(), torch.zeros(2))
graph = f_tensor.graph
FileCheck().check('name="test"').check_next("prim::CallFunction").run(graph)
def test_trace_nested_datatypes(self):
@torch.jit.script
def foo(x):
return [[x + 1, x - 1], [x + 2, x - 2]]
def bar(x):
list_stuff = foo(x)
return list_stuff[0][0], list_stuff[1][1]
traced = torch.jit.trace(bar, torch.rand(3, 4))
x = torch.rand(5, 6)
self.assertEqual(bar(x), traced(x))
@_tmp_donotuse_dont_inline_everything
def test_call_traced_fn_from_traced_module(self):
@_trace(torch.rand(3, 4))
def traced_fn(x):
return torch.neg(x)
class TracedModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(4, 5))
def forward(self, x):
return traced_fn(torch.mm(x, self.param))
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
# Note: neg op from the traced function should be properly inlined
FileCheck().check("aten::mm").check('name="traced_fn"').check_next(
"prim::CallFunction"
).run(str(tm.graph))
@_tmp_donotuse_dont_inline_everything
def test_call_traced_module_from_traced_module(self):
class TracedModule1(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(5, 7))
def forward(self, x):
return torch.mm(x, self.param)
class TracedModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(4, 5))
self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))
def forward(self, x):
return self.mod(torch.mm(x, self.param)) + 1.0
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
FileCheck().check("aten::mm").check("prim::CallMethod").check_same(
"forward"
).check("aten::add").run(str(tm.graph))
def test_index_put_trace_with_view(self):
@_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
def test_index_put(target, indices, rhs):
target[indices] = rhs
return target
FileCheck().check("aten::view").check("index_put_").run(
str(test_index_put.graph)
)
def test_index_put_trace_without_view(self):
@_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
def test_index_put(target, indices, rhs):
target[indices] = rhs
return target
FileCheck().check_not("aten::view").check("index_put_").run(
str(test_index_put.graph)
)
@suppress_warnings
def test_trace_checker_dot_data(self):
with self.assertRaisesRegex(
torch.jit.TracingCheckError,
r"Tensor-valued Constant nodes differed in value across invocations",
):
@_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
def foo(x):
y = x.data
return x + y
@suppress_warnings
def test_trace_checker_control_flow(self):
def foo(x):
for _ in range(x.size(0)):
x = torch.neg(x)
return x
with self.assertRaisesRegex(
torch.jit.TracingCheckError, r"Graphs differed across invocations!"
):
torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
@suppress_warnings
def test_trace_checker_memoization(self):
with self.assertRaisesRegex(
torch.jit.TracingCheckError, r"Graphs differed across invocations!"
):
def foo(x):
if not hasattr(foo, "cache"):
foo.cache = torch.neg(x)
return x + foo.cache
traced = torch.jit.trace(
foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)]
)
def test_trace_checker_slice_lhs(self):
def foo(x):
for i in range(3):
x[i, :] = torch.zeros(4)
return x
self.checkTrace(foo, (torch.rand(3, 4),), inputs_require_grads=False)
def test_trace_checker_inplace_on_view(self):
def foo(x):
x.view(-1).add_(-x.view(-1))
return x
with self.assertWarnsRegex(
torch.jit.TracerWarning,
"Output nr 1. of the traced function does not match the "
"corresponding output of the Python function",
):
torch.jit.trace(
foo,
torch.rand(3, 4),
check_inputs=[torch.rand(5, 6)],
_force_outplace=True,
)
def test_lhs_index_fails(self):
def foo(x):
x[0, 1] = 4
return x
with self.assertWarnsRegex(
torch.jit.TracerWarning, "cause the trace to be incorrect"
):
torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
def test_lhs_index_trivial(self):
def foo(y, x):
y[...] = x
return y
self.checkTrace(
foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False
)
def test_inplace_warn(self):
def foo(x):
x.view(-1).add_(-x.view(-1))
return x
with self.assertWarnsRegex(
torch.jit.TracerWarning, "cause the trace to be incorrect"
):
torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
@suppress_warnings
def test_trace_checker_dropout_train(self):
def foo(x):
return torch.dropout(x, p=0.5, train=True)
with self.assertWarnsRegex(
torch.jit.TracerWarning,
"Output nr 1. of the traced function does not match the "
"corresponding output of the Python function",
):
torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
with self.assertWarnsRegex(
torch.jit.TracerWarning, "Trace had nondeterministic nodes"
):
torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
def test_trace_checker_dropout_notrain(self):
input = torch.rand(3, 4)
@_trace(input)
def foo(x):
return torch.dropout(x, p=0.5, train=False)
self.assertEqual(foo(input), input)
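# contiguous() on a strided slice must materialize a copy; the traced output
# should not share storage with the input.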
def test_trace_contiguous(self):
def foo(x):
return x[:, :, ::2].contiguous().view(12)
x = torch.rand(2, 3, 4)
traced = torch.jit.trace(foo, (x,))
y = traced(x)
self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
# This tests the logic in THPVariable_contiguous. There is short-circuiting
# code that prevents us from even getting to VariableType::contiguous, since
# it is an optimization that prevents us from acquiring the GIL for touching
# the device. We needed to add the tracing logic directly into the
# THPVariable_contiguous function only for the path where we are skipping
# dispatch into contiguous. We should see an aten::contiguous in this trace!
def test_trace_contiguous_short_circuit(self):
def foo(x):
return x.contiguous()
x = torch.rand(2, 3, 4)
traced = torch.jit.trace(foo, (x,))
FileCheck().check("aten::contiguous").run(str(traced.graph))
def test_trace_inverse(self):
def foo(x):
return ~x
foo_traced = torch.jit.trace(foo, torch.zeros(3, 4, dtype=torch.uint8))
eg = torch.zeros(3, dtype=torch.uint8)
self.assertEqual(foo_traced(eg), foo(eg))
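# Iterating over a ModuleList inside forward should be traceable (tracing
# unrolls the loop).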
def test_trace_modulelist(self):
class MySubmod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(x)
class MyMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.ml = torch.nn.ModuleList([MySubmod(), MySubmod()])
def forward(self, x):
for mod in self.ml:
x = mod(x)
return x
traced = torch.jit.trace(MyMod(), (torch.rand(3, 4),))
def test_trace_fork_join_and_module(self):
class MySubmod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(x), torch.neg(x)
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.ml = torch.nn.ModuleList([MySubmod() for i in range(2)])
def forward(self, x):
futs = []
for i in range(2):
futs.append(torch.jit._fork(self.ml[i], x))
results = []
for i in range(2):
results.append(torch.jit._wait(futs[i])[0])
return torch.stack(results)
m = Mod()
traced = torch.jit.trace(m, torch.rand(3, 4))
def test_trace_invert_module_hierarchy(self):
class MySubmod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.relu = torch.nn.ReLU()
def forward(self, x):
return self.relu(x), torch.neg(x)
class MyFunctionalMod(torch.nn.Module):
def forward(self, x, submod):
return submod(x)
class Mod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.sm = MySubmod()
self.fm = MyFunctionalMod()
def forward(self, x):
return self.fm(x, self.sm)
torch.jit.trace(Mod(), (torch.rand(3, 4),))
@skipIfCrossRef
def test_trace_records_names(self):
def foo(bar, baz):
baz = bar + 3
quick_brown_fox = torch.neg(baz)
for _ in range(20):
yeet = quick_brown_fox - 3.14
return yeet
traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
graph_str = str(traced.graph)
assert "bar" in graph_str
assert "baz" in graph_str
assert "quick_brown_fox" in graph_str
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
def test_tracing_hooks(self):
class Net(nn.Module):
def forward(self, x):
return x + x
def test_hook(is_post_hook, hook, fc):
n = Net()
if is_post_hook:
n.register_forward_hook(hook)
else:
n.register_forward_pre_hook(hook)
module = torch.jit.trace(n, (torch.tensor(1.0),))
eager_input = torch.tensor(1.0)
eager_out = n(eager_input)
fc.run(module.forward.graph)
input = torch.tensor(1.0)
output = module(input)
self.assertEqual(input, eager_input)
self.assertEqual(output, eager_out)
def hook_no_return(mod, input, output):
input[0].add_(1)
output.sub_(1)
fc = FileCheck().check("add(").check("add_(").check("sub_(")
test_hook(True, hook_no_return, fc)
def hook_return(mod, input, output):
input[0].add_(1)
return output - 3
fc = FileCheck().check("add(").check("add_(").check("sub(")
test_hook(True, hook_return, fc)
b = torch.tensor(3.0)
def captured_hook(mod, input, output):
return output - b
fc = FileCheck().check("add(").check("sub(")
test_hook(True, captured_hook, fc)
def pre_hook_no_ret(mod, input):
input[0].add_(3)
fc = FileCheck().check("add_(").check("add(")
test_hook(False, pre_hook_no_ret, fc)
def pre_hook_ret(mod, input):
return input[0] - 4
fc = FileCheck().check("sub(").check("add(")
test_hook(False, pre_hook_ret, fc)
def test_tracing_backward_hook_error(self):
class Net(nn.Module):
def forward(self, x):
return x + x
n = Net()
def backward_hook(module, grad_input, grad_output):
pass
n.register_backward_hook(backward_hook)
with self.assertRaisesRegex(Exception, "backward hooks assigned"):
torch.jit.trace(n, (torch.tensor(1.0),))
def test_tracing_multiple_methods(self):
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(1, 1, 3)
def forward(self, x):
return self.conv(x)
def weighted_kernel_sum(self, weight):
return weight * self.conv.weight
example_weight = torch.rand(1, 1, 3, 3)
example_forward_input = torch.rand(1, 1, 3, 3)
inputs = {
"forward": example_forward_input,
"weighted_kernel_sum": example_weight,
}
n = Net()
module = torch.jit.trace_module(n, inputs)
check_inputs = []
for i in range(2):
check_weight = torch.rand(1, 1, 3, 3)
check_forward_input = torch.rand(1, 1, 3, 3)
check_inputs.append(
{"forward": check_forward_input, "weighted_kernel_sum": check_weight}
)
module = torch.jit.trace_module(
n, inputs, check_trace=True, check_inputs=check_inputs
)
self.assertTrue(module._c._has_method("forward"))
self.assertTrue(module._c._has_method("weighted_kernel_sum"))
module = torch.jit.trace(n.forward, example_forward_input)
module = torch.jit.trace(
n.forward,
example_forward_input,
check_trace=True,
check_inputs=[example_forward_input],
)
with self.assertRaisesRegex(
AttributeError,
"trace doesn't support compiling individual module's functions",
):
module = torch.jit.trace(n.weighted_kernel_sum, inputs)
def test_tensor_with_grad_as_constant(self):
param = torch.randn(3).requires_grad_()
x = torch.randn(3)
def f(x):
return x + param
with self.assertRaisesRegex(
RuntimeError, "Cannot insert a Tensor that requires grad as a constant"
):
torch.jit.trace(f, x)
def test_non_tensor_tracing(self):
def f(x):
return x + param # noqa: F821
with self.assertRaisesRegex(
RuntimeError, r"Type 'Tuple\[int\]' cannot be traced"
):
torch.jit.trace(f, (1,))
def test_trace_skip_none_submodule(self):
class TestModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.submod = torch.nn.Linear(3, 4)
self.submod = None
def forward(self, inputs):
return inputs
m = TestModule()
tm = torch.jit.trace(m, torch.tensor(1.0))
self.assertFalse(hasattr(tm, "submod"))
def test_trace_with_conditional_property(self):
class Net(nn.Module):
def __init__(self, attr=None):
super().__init__()
if attr is not None:
self._attr = attr
self.attr_name = "_attr"
@property
def attr(self):
return getattr(self, self.attr_name)
def forward(self, x):
return x
x = torch.ones(1)
torch.jit.trace(Net(), x)
def test_trace_func_argument_names_captured(self):
def fn(first_arg: torch.Tensor, second_arg: torch.Tensor) -> torch.Tensor:
return first_arg + second_arg
traced_fn = torch.jit.trace(fn, (torch.ones(1), torch.ones(1)))
FileCheck().check("first_arg").check_next("second_arg").run(
str(traced_fn.graph)
)
def test_trace_partial_func_argument_names_captured(self):
def fn(first_arg: torch.Tensor, second_arg=1) -> torch.Tensor:
return first_arg + second_arg
traced_fn = torch.jit.trace(fn, (torch.ones(1),))
FileCheck().check("first_arg").check_not("second_arg").run(str(traced_fn.graph))
def test_trace_module_argument_names_captured(self):
class TestModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(1, 1, 3)
def forward(self, first_arg: torch.Tensor, second_arg: torch.Tensor):
return self.conv(first_arg) + second_arg
m = TestModule()
example_input = (torch.ones(1, 1, 3, 3), torch.ones(1, 1, 3, 3))
# Explicitly tracing module's forward method
traced_module_forward = torch.jit.trace(m.forward, example_input)
FileCheck().check("first_arg").check_next("second_arg").run(
str(traced_module_forward.graph)
)
# Tracing the module directly
traced_module = torch.jit.trace(m, example_input)
FileCheck().check("first_arg").check_next("second_arg").run(
str(traced_module.graph)
)
def test_trace_checking_with_deprecated_name(self):
class MyClass(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y, **deprecated_arguments):
if len(deprecated_arguments) > 0:
raise RuntimeError(
f"Got unexpected arguments: {deprecated_arguments}"
)
return x + y
model = MyClass()
m2 = torch.jit.trace(model, (torch.ones(1), torch.ones(1)))
m3 = torch.jit.trace(
model,
example_kwarg_inputs={"x": torch.ones(1), "y": torch.ones(1)},
strict=False,
)
def test_trace_with_tuple_tensor(self):
class MyClass(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
def forward(self, x, y):
return x + y[0] + y[1]
model = MyClass()
traced_model = torch.jit.trace(
model, (torch.ones(1), (torch.ones(1), torch.ones(1)))
)
input_dict = {
"x": torch.tensor([2, 3]),
"y": (torch.tensor([5, 6]), torch.tensor([7, 8])),
}
self.assertEqual(model(**input_dict), traced_model(**input_dict))
traced_model = torch.jit.trace(
model,
example_kwarg_inputs={
"x": torch.ones(1),
"y": (torch.ones(1), torch.ones(1)),
},
)
self.assertEqual(model(**input_dict), traced_model(**input_dict))
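# Values flowing through nested submodules should not be duplicated as lifted
# inputs/outputs; the traced graph should not need prim::TupleUnpack.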
def test_trace_no_duplicated_lifted_input_output(self):
class Normalize(nn.Module):
def __init__(self) -> None:
super().__init__()
self.norm = nn.GroupNorm(num_groups=32, num_channels=32)
def forward(self, x, y):
if y is None:
y = x
else:
y = self.norm(y)
y = y * 2
return y
class G(nn.Module):
def __init__(self) -> None:
super().__init__()
self.norm = Normalize()
def forward(self, x):
A = self.norm(x, None)
B = F.relu(A)
return A, B
class Net(nn.Module):
def __init__(self) -> None:
super().__init__()
self.g = G()
self.norm_1 = Normalize()
def forward(self, x):
hs = self.g(x)
A, B = hs
h = self.norm_1(B, A)
return h
net = Net()
net = net.eval()
x = torch.randn(1, 32, 16, 16)
traced = torch.jit.trace(net, x)
FileCheck().check_not("prim::TupleUnpack").run(str(traced.graph))
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
class TestMixTracingScripting(JitTestCase):
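# Tracing should handle calls into scripted functions that take tuples,
# lists, and keyword/default arguments.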
def test_trace_script(self):
@torch.jit.script
def func1(x: Tuple[Tensor, Tensor]) -> Tensor:
return x[0] + x[1]
@torch.jit.script
def func2(x: List[Tensor]) -> Tensor:
return x[0] + x[1]
a = torch.randn(5)
b = torch.randn(5)
self.checkTrace(func1, ((a, b),))
self.checkTrace(func2, ((a, b),))
@torch.jit.script
def func3(
x: Tensor, method: str = "bilinear", align_corners: bool = True
) -> Tensor:
hw = x.shape[2:4]
return F.interpolate(x, hw, mode=method, align_corners=align_corners)
inp = torch.rand(1, 3, 6, 6)
self.checkTrace(func3, (inp,))
@torch.jit.script
def func4(x: Tensor, a: List[Optional[str]]) -> Tensor:
if len(a) == 2:
return x + 2
else:
return x
def test_trace_mixed_by_script_with_dict_output(self):
@torch.jit.script
def return_dict(input: torch.Tensor) -> Dict[str, torch.Tensor]:
return {"foo": input + 1}
class TraceModule(torch.nn.Module):
def forward(self, input):
dict = return_dict(input)
return dict["foo"] + dict["foo"]
x = torch.ones(1)
tm = torch.jit.trace(TraceModule(), x)
self.assertEqual(tm(x), x + 1 + x + 1)
def test_trace_of_script(self):
@torch.jit.script
def foo(a, c):
b = 0.0
if bool(a == 0.0):
b = 1.0
return b + c
a = torch.ones(1, dtype=torch.float)
@_trace(torch.zeros(1, dtype=torch.float))
def use(b):
return foo(b - 1.0, a) + 1.0
# test we propagated shapes through the function
self.assertTrue("Dynamic" not in str(use.graph))
self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
def test_trace_with_size(self):
@_trace(torch.zeros(1, 1))
def foo(x):
return x + 1
@torch.jit.script
def bar(x):
y = int(foo(x))
if 1 == 1:
y = 7
return y + 1
self.assertEqual(8, bar(torch.ones(1, 1)))
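    # Slicing with negative bounds should trace the same way it runs in eager
    # and scripted code, even on inputs whose length differs from the example.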
def test_tracing_slicing(self):
@_trace(torch.zeros(10))
def foo_trace(x):
return x[-5:-3]
@torch.jit.script
def foo_script(x):
return x[-5:-3]
def foo(x):
return x[-5:-3]
a = torch.arange(0, 8)
b = torch.arange(0, 20)
self.assertEqual(foo_trace(a), foo_script(a))
self.assertEqual(foo_trace(a), foo(a))
self.assertNotEqual(foo_trace(a), foo_trace(b))
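    # Same as above, but for indexing with a negative index.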
def test_tracing_indexing(self):
@_trace(torch.zeros(10))
def foo_trace(x):
return x[-2]
@torch.jit.script
def foo_script(x):
return x[-2]
def foo(x):
return x[-2]
a = torch.arange(0, 8)
b = torch.arange(0, 20)
self.assertEqual(foo_script(a), foo_trace(a))
self.assertEqual(foo_trace(a), foo(a))
self.assertNotEqual(foo_trace(a), foo_trace(b))
def test_trace_hierarchy(self):
# Test that we preserve the module hierarchy for a ScriptModule
# submodule during tracing
class AnotherScriptMod(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(1, 2, 3))
@torch.jit.script_method
def bar(self):
return torch.zeros(4, 5)
class SomeScriptMod(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.asm = AnotherScriptMod()
@torch.jit.script_method
def foo(self):
return torch.zeros(3, 4)
@torch.jit.script_method
def bar(self):
return torch.zeros(4, 3)
class TraceMe(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.ssm = SomeScriptMod()
def forward(self, x):
return self.ssm.bar() + x
orig = TraceMe()
traced = torch.jit.trace(orig, (torch.rand(4, 3),))
        # For each of these checks, verify that *both* the underlying
        # _C.ScriptModule object and the Python object that wraps it have the
        # expected method/param.
self.assertTrue(traced.ssm._c._has_method("foo"))
self.assertTrue(hasattr(traced.ssm, "foo"))
imported = self.getExportImportCopy(traced)
self.assertTrue(imported.ssm._c._has_method("foo"))
self.assertTrue(hasattr(imported.ssm, "foo"))
self.assertTrue(imported.ssm.asm._c._has_method("bar"))
self.assertTrue(hasattr(imported.ssm.asm, "bar"))
self.assertTrue(hasattr(imported.ssm.asm, "param"))
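    # Parameters registered on an inner module should survive alternating
    # layers of tracing and scripting, and the result should be saveable.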
def test_trace_parameter(self):
class Param(nn.Module):
def __init__(self) -> None:
super().__init__()
self.register_parameter("bias", nn.Parameter(torch.empty(4, 4)))
def forward(self, x):
return x
class M3(torch.jit.ScriptModule):
def __init__(self, model):
super().__init__()
self.traced = torch.jit.trace(model, (torch.rand(3, 3)))
@torch.jit.script_method
def forward(self, x):
return self.traced(x)
class M2(nn.Module):
def __init__(self, model):
super().__init__()
self.module = M3(model)
def forward(self, x):
return self.module(x)
class M1(torch.jit.ScriptModule):
def __init__(self, model):
super().__init__()
self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3)))
@torch.jit.script_method
def forward(self, x):
return self.traced(x)
with torch.jit.optimized_execution(False):
module = M1(Param())
f = io.BytesIO()
torch.jit.save(module, f)
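    # With inlining disabled, a scripted function called from a traced module
    # should appear as a prim::CallFunction node in the traced graph.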
@_tmp_donotuse_dont_inline_everything
def test_call_script_fn_from_traced_module(self):
@torch.jit.script
def scripted_fn(x):
return torch.neg(x)
class TracedModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(4, 5))
def forward(self, x):
return scripted_fn(torch.mm(x, self.param))
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
FileCheck().check("aten::mm").check('name="scripted_fn"').check(
"prim::CallFunction"
).run(str(tm.graph))
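    # Likewise, a ScriptModule called from a traced module should appear as a
    # prim::CallMethod into its forward.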
@_tmp_donotuse_dont_inline_everything
def test_call_script_module_from_traced_module(self):
class ScriptMod(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.param_foo = torch.nn.Parameter(torch.rand(5, 7))
@torch.jit.script_method
def forward(self, x):
return torch.mm(x, self.param_foo)
class TracedModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(4, 5))
self.mod = ScriptMod()
def forward(self, x):
return self.mod(torch.mm(x, self.param)) + 1.0
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
FileCheck().check("aten::mm").check("prim::CallMethod").check_same(
"forward"
).check("aten::add").run(str(tm.graph))
@_tmp_donotuse_dont_inline_everything
def test_call_traced_fn_from_script_fn(self):
@_trace(torch.rand(3, 4))
def traced_fn(x):
return torch.neg(x)
@torch.jit.script
def script_fn(x):
return traced_fn(x) + 1
FileCheck().check("prim::CallFunction").check("aten::add").run(
str(script_fn.graph)
)
def test_call_traced_mod_from_script_fn(self):
with self.assertRaisesRegex(
RuntimeError,
"Cannot call a ScriptModule that is not a submodule of the caller",
):
class TracedModule(torch.nn.Module):
def forward(self, x):
return torch.mm(x, torch.zeros(4, 3))
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
@torch.jit.script
def script_fn(x):
return tm(x) + 1
@_tmp_donotuse_dont_inline_everything
def test_call_tracing_fn_from_script_module(self):
@_trace(torch.rand(3, 3))
def traced_fn(x):
return torch.neg(x)
class ScriptMod(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(4, 3))
@torch.jit.script_method
def forward(self, x):
return traced_fn(torch.mm(x, self.param))
sm = ScriptMod()
FileCheck().check("aten::mm").check("prim::CallFunction").run(
str(sm.forward.graph)
)
@_tmp_donotuse_dont_inline_everything
def test_call_tracing_mod_from_script_module(self):
class TracedMod(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(3, 5))
def forward(self, x):
return torch.mm(x, self.param)
class ScriptMod(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.param = torch.nn.Parameter(torch.rand(4, 3))
self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))
@torch.jit.script_method
def forward(self, x):
return self.tm(torch.mm(x, self.param))
sm = ScriptMod()
FileCheck().check("aten::mm").check("prim::CallMethod").run(str(sm.graph))
def test_script_inline_trace_multiple_args(self):
class M(torch.nn.Module):
def forward(self, input, input2):
return input + input2
class M2(torch.jit.ScriptModule):
def __init__(self) -> None:
super().__init__()
self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3)))
@torch.jit.script_method
def forward(self, inp):
return self.m(inp, inp)
with torch.jit.optimized_execution(False):
m2 = M2()
m2(torch.zeros(4, 3))
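    # Dict[str, List[Tensor]] inputs should flow from a traced module into a
    # scripted submodule, and the trace should generalize to new dict values.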
def test_trace_dict_mix_script(self):
class testB(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.linear = torch.nn.Linear(2, 2)
def forward(self, feature_map: Dict[str, List[Tensor]]) -> Tensor:
output = []
for j in feature_map.values():
output.append(self.linear(j[0]))
return torch.stack(output)
class testA(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.b = torch.jit.script(testB())
def forward(self, input_map: Dict[str, List[Tensor]]) -> Tensor:
feature_map = {}
for i, j in input_map.items():
feature_map[i] = [j[0]]
return self.b(feature_map)
input_map = {
"1": [torch.rand(2, 2), torch.rand(2, 2)],
"3": [torch.rand(2, 2), torch.rand(2, 2)],
}
model = testA()
traced_model = torch.jit.trace(model, input_map)
new_input_map = {
"1": [torch.rand(2, 2), torch.randn(2, 2)],
"3": [torch.rand(2, 2), torch.rand(2, 2)],
}
self.assertEqual(model(new_input_map), traced_model(new_input_map))
def test_trace_script_returning_complex_dict(self):
"""Tracing over a script function returning a dictionary should work.
The dictionary can should be able to contain other containers (like a tuple) recursively.
"""
class ReturnsDict(torch.nn.Module):
def forward(
self,
id_score_list: Dict[
str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
],
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
# do some random operations and then return a dict of the same structure
v = id_score_list["1000"]
idx_keys = v[1] - 1500000
weights = v[2]
result = {"1000": (v[0], idx_keys, weights)}
return result
class ChecksDict(torch.nn.Module):
def forward(
self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
):
v = input["1000"]
return v[1] + 1
class TestModule(torch.nn.Module):
def __init__(self, checks_dict, returns_dict):
super().__init__()
self.checks_dict = checks_dict
self.returns_dict = returns_dict
def forward(
self, input: Dict[str, Tuple[torch.Tensor, torch.Tensor, torch.Tensor]]
):
foo = self.returns_dict(input)
return self.checks_dict(foo)
input1 = {
"1000": (
torch.tensor([0]),
torch.tensor([], dtype=torch.int64),
torch.tensor([]),
)
}
input2 = {
"1000": (
torch.tensor([0]),
torch.tensor([1500000, 1500004], dtype=torch.int64),
torch.tensor([2.0, 3.0]),
)
}
checks_dict = torch.jit.script(ChecksDict())
returns_dict = torch.jit.script(ReturnsDict())
eager_module = TestModule(checks_dict, returns_dict)
traced_module = torch.jit.trace(eager_module, input1)
self.assertEqual(traced_module(input1), eager_module(input1))
self.assertEqual(traced_module(input2), eager_module(input2))
def test_trace_returning_dict_with_tensor_tuples(self):
"""Tracing over a module returning a dictionary whose values are tuples of tensors
should work.
"""
class ReturnsDict(torch.nn.Module):
def forward(
self, k: torch.Tensor, v: torch.Tensor
) -> Dict[str, Tuple[torch.Tensor, torch.Tensor]]:
x = 2 * k
y = 3 * v
result = {"imakey": (x, y)}
return result
class ReturnsBadDict(torch.nn.Module):
def forward(
self, k: torch.Tensor, v: torch.Tensor
) -> Dict[str, Tuple[torch.Tensor, float]]:
x = 2 * k
result = {"imakey": (x, 1)}
return result
mod = ReturnsDict()
traced_module = torch.jit.trace(
mod, [torch.ones(1), torch.ones(1)], strict=False
)
out = traced_module(torch.ones(1), torch.ones(1))
expected = {"imakey": (torch.tensor([2.0]), torch.tensor([3.0]))}
self.assertEqual(out, expected)
with self.assertRaisesRegex(
RuntimeError, "cannot be understood by the tracer, only outputs matching"
):
mod = ReturnsBadDict()
traced_module = torch.jit.trace(
mod, [torch.ones(1), torch.ones(1)], strict=False
)
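    # nn.Linear should show up as aten::linear in the traced graph.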
def test_trace_linear(self):
m = torch.nn.Linear(20, 20)
inp = torch.rand([20, 20])
self.checkTrace(m, (inp,))
g = torch.jit.trace(m, (inp,)).graph
FileCheck().check("aten::linear").run(g)
def test_traced_module_implements_interface(self):
@torch.jit.interface
class TestModuleInterface(nn.Module):
def forward(
self, first_arg: torch.Tensor, second_arg: torch.Tensor
) -> torch.Tensor:
pass
make_global(TestModuleInterface)
class TestModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.conv = nn.Conv2d(1, 1, 3)
def forward(
self, first_arg: torch.Tensor, second_arg: torch.Tensor
) -> torch.Tensor:
return self.conv(first_arg) + second_arg
def fn_takes_interface(x: TestModuleInterface):
ones = torch.ones(1, 1, 3, 3)
return x.forward(ones, ones)
scripted_test_module = torch.jit.script(TestModule())
self.checkScript(fn_takes_interface, (scripted_test_module,))
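    # Tracing a module whose scripted submodule is annotated with a module
    # interface type should work, even when that submodule is shared.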
def test_traced_module_contains_scripted_interface_types(self):
class LeafModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.rand(19))
def forward(self, input: torch.Tensor):
return input + self.weight
class LowerModuleImpl(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.leaf = LeafModule()
def forward(self, input: torch.Tensor) -> torch.Tensor:
return self.leaf(input)
@torch.jit.interface
class LowerModuleInterface(torch.nn.Module):
def forward(self, input: torch.Tensor) -> torch.Tensor:
pass
class MiddleModule(torch.nn.Module):
lower: LowerModuleInterface
def __init__(self, feature_processor_modules=None):
super().__init__()
self.lower = LowerModuleImpl()
def forward(self, input):
return self.lower(input)
class WrapperModule(torch.nn.Module):
def __init__(self, m):
super().__init__()
self.middle = m
def forward(self, input):
return self.middle(input)
class TopModule(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
m = MiddleModule()
m = torch.jit.script(m)
self.sub1 = m
self.sub2 = WrapperModule(m)
def forward(self, input: torch.Tensor):
return self.sub1(input) + self.sub2(input)
top = TopModule()
top_example_input = torch.ones(1)
torch.jit.trace(top, top_example_input)
def test_jit_trace_callfunction_return_shapes(self):
# a torch.jit.script function gets inserted as a CallFunction node
@torch.jit.script
def inner_fn(x):
return torch.cat((x, x))
def outer_fn(x, y):
return inner_fn(x + y).relu()
x, y = [torch.rand((2, 2), dtype=torch.float) for _ in range(2)]
fn_t = torch.jit.trace(outer_fn, (x, y))
# expect that the CallFunction node return type has shape information on it.
FileCheck().check("Float").check("4, 2").check("CallFunction").run(fn_t.graph)
for n in fn_t.graph.nodes():
if n.kind() == "prim::CallFunction":
self.assertTrue(n.output().isCompleteTensor())