# Commit notes:
# * Eliminate handle_zero_dim when broadcasting is applied earlier. This ends
#   up doing nothing unless all the broadcast tensors are scalars, which gives
#   inconsistent behavior in that case only, because the type promotion rules
#   are different; real type promotion logic is the better fix.
# * Change type of script comparison to long.
# * Fix jit tests.
# * Fix cpp jit test by being consistent about long-vs-float.
# * Consistent float and long.
# * Use int64_t rather than long.
import torch
import torch.jit
import torch.nn as nn
import torch.nn.functional as F
from contextlib import contextmanager
from itertools import product, chain
import torch.jit.frontend
from torch.autograd import Variable, Function
from torch.autograd.function import traceable
from common import TestCase, run_tests, IS_WINDOWS
import io
import sys
import unittest
import inspect
import textwrap
import numpy as np
import tempfile
import shutil

from torch.jit.frontend import NotSupportedError

try:
    import torchvision
    HAS_TORCHVISION = True
except ImportError:
    HAS_TORCHVISION = False


skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")

RUN_CUDA = torch.cuda.is_available()
if torch.cuda.is_available():
    CUDA_VERSION = torch._C._cuda_getCompiledVersion()
    for d in range(torch.cuda.device_count()):
        major = torch.cuda.get_device_capability(d)[0]
        if (CUDA_VERSION < 8000 and major >= 6) or (CUDA_VERSION < 9000 and major >= 7):
            RUN_CUDA = False

RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1

PY2 = sys.version_info[0] == 2
WINDOWS = sys.platform == 'win32'


def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
    hx, cx = hidden
    gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)

    ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
    ingate = F.sigmoid(ingate)
    forgetgate = F.sigmoid(forgetgate)
    cellgate = F.tanh(cellgate)
    outgate = F.sigmoid(outgate)

    cy = (forgetgate * cx) + (ingate * cellgate)
    hy = outgate * F.tanh(cy)
    return hy, cy


def LSTMCellC(*args, **kwargs):
    hy, cy = LSTMCell(*args, **kwargs)
    return torch.cat((hy, cy))

class TestJit(TestCase):
|
|
maxDiff = None
|
|
|
|
@contextmanager
|
|
def assertCompiled(self, compiled_fn):
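        """Assert that the code in the with-block was served from the compiled
        trace cache: hits must increase while misses stay the same."""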
|
|
self.assertIsInstance(compiled_fn, torch._C.CompiledFunction)
|
|
hits, misses = compiled_fn.hits, compiled_fn.misses
|
|
yield
|
|
self.assertLess(hits, compiled_fn.hits)
|
|
self.assertEqual(misses, compiled_fn.misses)
|
|
|
|
def assertExpectedTrace(self, trace, *args, **kwargs):
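        """Lint, dead-code-eliminate and canonicalize the trace graph, then
        compare its string form against the recorded expected output."""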
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
trace.set_graph(torch._C._jit_pass_canonicalize(trace.graph()))
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
self.assertExpected(str(trace), *args, **kwargs)
|
|
|
|
def test_simple(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
def f(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
# matmul is currently implemented as a native function, which
|
|
# exercises different codepaths in the JIT. The following two
|
|
# tests ensure that (1) matmul indeed traces into an atomic,
|
|
# native operation, and (2) the JIT knows how to run it
|
|
|
|
def test_matmul_native(self):
|
|
x = Variable(torch.Tensor([[0.4]]), requires_grad=True)
|
|
y = Variable(torch.Tensor([[0.7]]), requires_grad=True)
|
|
|
|
trace, z = torch.jit.get_trace_graph(lambda x, y: x.matmul(y), (x, y), nderivs=0)
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_matmul_native_run(self):
|
|
x = Variable(torch.Tensor([[0.4]]), requires_grad=True)
|
|
y = Variable(torch.Tensor([[0.7]]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def fn(x, y):
|
|
return x.matmul(y)
|
|
|
|
z = fn(x, y)
|
|
with self.assertCompiled(fn):
|
|
z2 = fn(x, y)
|
|
self.assertEqual(z, z2)
|
|
|
|
# index-2 is not implemented in interpreter
|
|
@unittest.expectedFailure
|
|
def test_index(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.LongTensor([0]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def fn(x, y):
|
|
return x[y]
|
|
|
|
z = fn(x, y)
|
|
with self.assertCompiled(fn):
|
|
z2 = fn(x, y)
|
|
self.assertEqual(z, z2)
|
|
|
|
# Backwards tracing was broken for indexing by a constant,
|
|
# because it's internally implemented using as_strided,
|
|
# and we attempted to trace its derivative (which is not
|
|
    # currently supported). It works now because slice() is no longer
    # marked as traceable.
|
|
def test_index_constant(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(x):
|
|
return x[0]
|
|
|
|
z = fn(x)
|
|
z.backward()
|
|
grad = x.grad.clone()
|
|
x.grad.zero_()
|
|
with self.assertCompiled(fn):
|
|
z2 = fn(x)
|
|
z2.backward()
|
|
grad2 = x.grad.clone()
|
|
self.assertEqual(z, z2)
|
|
self.assertEqual(grad, grad2)
|
|
|
|
def test_scopes(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
def f(x, y):
|
|
out = x + y
|
|
with torch.jit.scope('Foo', out):
|
|
out = x * out
|
|
with torch.jit.scope('Bar', out):
|
|
out = torch.tanh(out)
|
|
out = torch.sigmoid(out)
|
|
return out
|
|
|
|
trace, z = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_scopes_intermediate_node(self):
|
|
|
|
class Net(nn.Module):
|
|
def forward(self, x):
|
|
return F.log_softmax(x, dim=0)
|
|
|
|
net = Net()
|
|
t = Variable(torch.ones(2), requires_grad=True)
|
|
trace, _ = torch.jit.get_trace_graph(net, (t, ))
|
|
torch.onnx._optimize_trace(trace, False)
|
|
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_scopes_identity_node(self):
|
|
|
|
class Net(nn.Module):
|
|
|
|
def __init__(self):
|
|
super(Net, self).__init__()
|
|
self.features = nn.Sequential(
|
|
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
|
|
nn.ReLU(inplace=True),
|
|
nn.MaxPool2d(kernel_size=3, stride=2),
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.features(x)
|
|
return x
|
|
|
|
model = Net()
|
|
|
|
t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)
|
|
|
|
with torch.onnx.set_training(model, False):
|
|
trace, _ = torch.jit.get_trace_graph(model, (t, ))
|
|
|
|
torch.onnx._optimize_trace(trace, False)
|
|
|
|
self.assertExpectedTrace(trace)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_lstm_fusion(self):
|
|
input = Variable(torch.randn(3, 10).float().cuda())
|
|
hx = Variable(torch.randn(3, 20).float().cuda())
|
|
cx = Variable(torch.randn(3, 20).float().cuda())
|
|
module = nn.LSTMCell(10, 20).float().cuda() # Just to allocate weights with correct sizes
|
|
|
|
trace, _ = torch.jit.get_trace_graph(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
torch._C._jit_pass_fuse(trace.graph())
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def run_lstm_fusion(self, use_cuda):
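        """Compile LSTMCell with the JIT, run it twice on random inputs (CPU or
        CUDA) and check that the cached, compiled call matches the first call."""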
|
|
def to_type(x):
|
|
x = x.float()
|
|
if use_cuda:
|
|
x = x.cuda()
|
|
return x
|
|
|
|
def rand_v(a, b):
|
|
return Variable(to_type(torch.randn(a, b)))
|
|
|
|
input = rand_v(3, 10)
|
|
hx = rand_v(3, 20)
|
|
cx = rand_v(3, 20)
|
|
module = to_type(nn.LSTMCell(10, 20)) # Just to allocate weights with correct sizes
|
|
|
|
CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCell)
|
|
|
|
z = CompiledLSTMCell(input, (hx, cx), *module.parameters())
|
|
with self.assertCompiled(CompiledLSTMCell):
|
|
z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters())
|
|
self.assertEqual(z, z2)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_run_lstm_fusion_cuda(self):
|
|
self.run_lstm_fusion(True)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
def test_run_lstm_fusion_cpu(self):
|
|
self.run_lstm_fusion(False)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_run_lstm_fusion_concat(self):
|
|
input = Variable(torch.randn(3, 10).float().cuda())
|
|
hx = Variable(torch.randn(3, 20).float().cuda())
|
|
cx = Variable(torch.randn(3, 20).float().cuda())
|
|
module = nn.LSTMCell(10, 20).float().cuda() # Just to allocate weights with correct sizes
|
|
|
|
CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCellC)
|
|
|
|
z = CompiledLSTMCell(input, (hx, cx), *module.parameters())
|
|
with self.assertCompiled(CompiledLSTMCell):
|
|
z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters())
|
|
self.assertEqual(z, z2)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_concat_fusion(self):
|
|
hx = Variable(torch.randn(3, 20).float().cuda())
|
|
cx = Variable(torch.randn(3, 20).float().cuda())
|
|
|
|
def Foo(hx, cx):
|
|
return torch.cat((hx + cx, hx * cx))
|
|
|
|
trace, _ = torch.jit.get_trace_graph(Foo, (hx, cx))
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
torch._C._jit_pass_fuse(trace.graph())
|
|
self.assertExpectedTrace(trace)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_fusion_distribute(self):
|
|
def f(x, y):
|
|
z1, z2 = (x + y).chunk(2, dim=1)
|
|
return z1 * z2
|
|
x = Variable(torch.randn(4, 4).float().cuda())
|
|
y = Variable(torch.randn(4, 4).float().cuda())
|
|
trace, _ = torch.jit.get_trace_graph(f, (x, y), nderivs=0)
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
self.assertExpectedTrace(trace, 'raw')
|
|
torch._C._jit_pass_fuse(trace.graph())
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_arg_configurations(self):
|
|
"""Different arg configurations should trigger different traces"""
|
|
x = Variable(torch.FloatTensor(4, 4).uniform_())
|
|
x_double = Variable(x.data.double())
|
|
x_grad = Variable(x.data.clone(), requires_grad=True)
|
|
y = Variable(torch.randn(4))
|
|
|
|
configurations = [
|
|
(x,),
|
|
(x_double,),
|
|
(x_grad,),
|
|
(y,),
|
|
([x, x],),
|
|
([x, y],),
|
|
]
|
|
if torch.cuda.is_available():
|
|
x_cuda = Variable(x.data.cuda())
|
|
configurations += [
|
|
(x_cuda,),
|
|
([x, x_cuda],),
|
|
([x_cuda, x],),
|
|
([[x_cuda, x]],),
|
|
]
|
|
if torch.cuda.device_count() > 1:
|
|
x_cuda_1 = Variable(x.data.cuda(1))
|
|
configurations += [
|
|
(x_cuda_1,),
|
|
([x_cuda, x_cuda_1],),
|
|
]
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def fn(*args):
|
|
in_vars, _ = torch._C._jit_flatten(args)
|
|
return in_vars[0] + 1
|
|
|
|
for i, config in enumerate(configurations):
|
|
self.assertFalse(fn.has_trace_for(*config))
|
|
fn(*config)
|
|
self.assertTrue(fn.has_trace_for(*config))
|
|
for unk_config in configurations[i + 1:]:
|
|
self.assertFalse(fn.has_trace_for(*unk_config))
|
|
self.assertEqual(fn.hits, 0)
|
|
|
|
def test_cse(self):
|
|
x = Variable(torch.Tensor([0.4, 0.3]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7, 0.5]), requires_grad=True)
|
|
|
|
trace, inputs = torch._C._tracer_enter((x, y), 0)
|
|
|
|
def fn(x, y):
|
|
w = (x + y) * (x + y) * (x + y)
|
|
t = torch.tanh(w) + torch.tanh(w)
|
|
z = (x + y) * (x + y) * (x + y) + t
|
|
return z
|
|
z = fn(*inputs)
|
|
torch._C._tracer_exit((z,))
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
torch._C._jit_pass_cse(trace.graph())
|
|
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_compile_run_twice(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=0, optimize=False)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
z = doit(x, y)
|
|
with self.assertCompiled(doit):
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
|
|
self.assertEqual(z, z2)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_compile_addc(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda()
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda()
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y) + 1))
|
|
|
|
z = doit(x, y)
|
|
with self.assertCompiled(doit):
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y) + 1)))
|
|
self.assertEqual(z, z2)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
|
|
def test_compile_fuse_last_device(self):
|
|
max_device = torch.cuda.device_count() - 1
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda(max_device)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda(max_device)
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y) + 1))
|
|
|
|
z = doit(x, y)
|
|
with self.assertCompiled(doit):
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y) + 1)))
|
|
self.assertEqual(z, z2)
|
|
|
|
def test_traced_function(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
z = doit(x, y)
|
|
with self.assertCompiled(doit):
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
|
|
self.assertEqual(z, z2)
|
|
|
|
def test_disabled_traced_function(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
@torch.jit.compile(enabled=False)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
z = doit(x, y)
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
|
|
self.assertEqual(z, z2)
|
|
|
|
def test_assign_traces(self):
|
|
"""Check that output Variables are assigned traces before they are saved."""
|
|
@traceable
|
|
class MyFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, a):
|
|
out = a * 2
|
|
ctx.save_for_backward(out)
|
|
return out
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_a):
|
|
a, = ctx.saved_tensors
|
|
return a * grad_a
|
|
|
|
x = Variable(torch.randn(10, 10), requires_grad=True)
|
|
trace, out = torch.jit.get_trace_graph(MyFn.apply, x, nderivs=1)
|
|
out.sum().backward()
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_legacy_traced_module(self):
|
|
input = Variable(torch.randn(3, 10))
|
|
hx = Variable(torch.randn(3, 20))
|
|
cx = Variable(torch.randn(3, 20))
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
class MyLSTMCell(nn.LSTMCell):
|
|
pass
|
|
|
|
lstm = MyLSTMCell(10, 20)
|
|
|
|
out = lstm(input, (hx, cx))
|
|
with self.assertCompiled(lstm):
|
|
out2 = lstm(input, (hx, cx))
|
|
self.assertEqual(out, out2)
|
|
|
|
def test_autograd_closure(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
trace, inputs = torch._C._tracer_enter((x, y), 1)
|
|
|
|
def fn(x, y):
|
|
z = torch.sigmoid(x * (x + y))
|
|
w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
|
|
return z, w
|
|
z, w = fn(*inputs)
|
|
|
|
torch._C._tracer_exit((z, w))
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
|
|
(z * w).backward()
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
|
|
x_grad = x.grad.data.clone()
|
|
x.grad.data.zero_()
|
|
|
|
function = torch._C._jit_createInterpreterFactory(trace)
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
z2, w2 = function()(x, y)
|
|
(z2 * w2).backward()
|
|
self.assertEqual(z, z2)
|
|
self.assertEqual(w, w2)
|
|
self.assertEqual(x.grad.data, x_grad)
|
|
|
|
def test_verify(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
@torch.jit.compile
|
|
def f(x, y):
|
|
z = torch.sigmoid(x * (x + y))
|
|
w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
|
|
return z, w
|
|
|
|
torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
|
|
|
|
def test_constant(self):
|
|
x = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
trace, (tx,) = torch._C._tracer_enter((x,), 0)
|
|
|
|
y = Variable(torch.diag(torch.Tensor([2, 2])))
|
|
z = tx.matmul(y)
|
|
|
|
torch._C._tracer_exit((z,))
|
|
function = torch._C._jit_createInterpreterFactory(trace)
|
|
|
|
z2 = function()(x)
|
|
self.assertEqual(z, z2)
|
|
|
|
y.data.fill_(1000) # make sure the data has been cloned
|
|
|
|
x2 = Variable(torch.ones(2, 2) * 2, requires_grad=True)
|
|
z3 = function()(x2)
|
|
self.assertEqual(z3.data, torch.ones(2, 2) * 4)
|
|
|
|
def test_c_function(self):
|
|
x = Variable(torch.randn(1, 3, 10, 10))
|
|
m = nn.Conv2d(3, 8, 3, 1)
|
|
|
|
trace, inputs = torch._C._tracer_enter((x,) + tuple(m.parameters()), 0)
|
|
y = m(inputs[0])
|
|
torch._C._tracer_exit((y,))
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_legacy_fail(self):
|
|
|
|
class MyLegacyFn(Function):
|
|
def forward(self, x):
|
|
return x
|
|
|
|
def backward(self, grad_output):
|
|
return grad_output
|
|
|
|
x = Variable(torch.Tensor([0]), requires_grad=True)
|
|
trace, inputs = torch._C._tracer_enter((x,), 0)
|
|
self.assertRaisesRegex(RuntimeError, "MyLegacyFn", lambda: MyLegacyFn()(*inputs))
|
|
torch._C._tracer_exit(inputs)
|
|
|
|
def test_inplace_transplant(self):
|
|
x = Variable(torch.Tensor([0]), requires_grad=True)
|
|
trace, inputs = torch._C._tracer_enter((x,), 0)
|
|
|
|
def fn(x):
|
|
y = x.clone()
|
|
y.add_(2)
|
|
y.add_(3)
|
|
return y
|
|
y = fn(*inputs)
|
|
torch._C._tracer_exit((y,))
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_inplace_flags(self):
|
|
class InplaceFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
ctx.mark_dirty(x)
|
|
return x.add_(1)
|
|
|
|
@staticmethod
|
|
def backward(ctx, go):
|
|
return go
|
|
|
|
class RegularFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
return x.add(1)
|
|
|
|
@staticmethod
|
|
def backward(ctx, go):
|
|
return go
|
|
|
|
x = Variable(torch.Tensor([0]), requires_grad=True)
|
|
trace, inputs = torch._C._tracer_enter((x,), 0)
|
|
|
|
def fn(x):
|
|
y = RegularFn.apply(x)
|
|
y = InplaceFn.apply(y)
|
|
y = InplaceFn.apply(y)
|
|
y = RegularFn.apply(y)
|
|
return y
|
|
y = fn(*inputs)
|
|
torch._C._tracer_exit((y,))
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
ops = [n for n in trace.graph().nodes()]
|
|
for op in ops:
|
|
self.assertTrue(op.hasAttribute('inplace'))
|
|
inplace_flags = [False, True, True, False]
|
|
for op, is_inplace in zip(ops, inplace_flags):
|
|
self.assertEqual(op.i('inplace'), is_inplace)
|
|
|
|
def test_inplace_check(self):
|
|
class MyInplaceFn(Function):
|
|
@staticmethod
|
|
def forward(self, x):
|
|
x.add_(1)
|
|
self.mark_dirty(x)
|
|
return x
|
|
|
|
@staticmethod
|
|
def backward(self, grad):
|
|
return grad
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def fn(x):
|
|
return MyInplaceFn.apply(x)
|
|
x = Variable(torch.randn(5, 5))
|
|
fn(x) # trace
|
|
with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
|
|
fn(x)
|
|
|
|
def test_backward(self):
|
|
a = Variable(torch.randn(2, 2), requires_grad=True)
|
|
b = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
x = a
|
|
y = a * b
|
|
|
|
trace, inputs = torch._C._tracer_enter((x, y), 2)
|
|
|
|
def fn(x, y):
|
|
return y * 2 * x
|
|
z = fn(*inputs)
|
|
torch._C._tracer_exit((z,))
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
|
|
# Run first backward
|
|
grad, = torch.autograd.grad(z, x, Variable(torch.ones(2, 2), requires_grad=True), create_graph=True)
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
|
|
# Run second backward
|
|
grad.sum().backward(create_graph=True)
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
|
|
# Run dead code elimination to remove unused trace nodes
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
# This is nondeterministic, see:
|
|
# https://github.com/ezyang/pytorch/issues/227
|
|
# self.assertExpectedTrace(trace)
|
|
self.skipTest("output is nondeterministic on Travis/Python 3.5")
|
|
|
|
def test_backward_opaque(self):
|
|
x = Variable(torch.randn(3, 3), requires_grad=True)
|
|
y = Variable(torch.randn(3, 3), requires_grad=True)
|
|
|
|
trace, inputs = torch._C._tracer_enter((x, y), 2)
|
|
|
|
def fn(x, y):
|
|
return x.cross(y)
|
|
z = fn(*inputs)
|
|
torch._C._tracer_exit((z,))
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
|
|
# Run first backward
|
|
grad, = torch.autograd.grad(z, x, Variable(torch.ones(3, 3), requires_grad=True), create_graph=True)
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
|
|
# Run dead code elimination to remove unused trace nodes
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
# This is nondeterministic, see:
|
|
# https://github.com/ezyang/pytorch/issues/227
|
|
# self.assertExpectedTrace(trace)
|
|
self.skipTest("output is nondeterministic on Travis/Python 3.5")
|
|
|
|
def test_backward_closure(self):
|
|
"""Check that autograd closures handle multiple stages correctly."""
|
|
x = Variable(torch.randn(1), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=2)
|
|
def fn(x):
|
|
return x * x
|
|
|
|
# Generate trace
|
|
grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
|
|
self.assertFalse(fn.has_trace_for(x))
|
|
grad_x.backward()
|
|
self.assertTrue(fn.has_trace_for(x))
|
|
|
|
x_grad = x.grad.data.clone()
|
|
x.grad.data.zero_()
|
|
|
|
# Run the trace
|
|
with self.assertCompiled(fn):
|
|
output = fn(x)
|
|
grad_x, = torch.autograd.grad(output, (x,), create_graph=True)
|
|
grad_x.backward()
|
|
|
|
self.assertEqual(x.grad.data, x_grad)
|
|
|
|
def test_trace_expire(self):
|
|
x = Variable(torch.randn(2, 2), requires_grad=True)
|
|
y = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
def record_trace(num_backwards):
|
|
trace, inputs = torch._C._tracer_enter((x, y), num_backwards)
|
|
|
|
def fn(x, y):
|
|
return y * 2 * x
|
|
z = fn(*inputs)
|
|
torch._C._tracer_exit((z,))
|
|
return z, trace
|
|
|
|
def check(expired, complete):
|
|
self.assertEqual(trace.is_expired, expired)
|
|
self.assertEqual(trace.is_complete, complete)
|
|
|
|
z, trace = record_trace(0)
|
|
check(False, True)
|
|
del z
|
|
check(False, True)
|
|
|
|
z, trace = record_trace(1)
|
|
check(False, False)
|
|
del z
|
|
check(True, False)
|
|
|
|
z, trace = record_trace(1)
|
|
check(False, False)
|
|
z.sum().backward()
|
|
check(False, True)
|
|
del z
|
|
check(False, True)
|
|
|
|
def test_multiuse_fn(self):
|
|
x = Variable(torch.randn(2, 2), requires_grad=True)
|
|
w = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
@torch.jit.compile
|
|
def cell(x, w):
|
|
return x * w + 2
|
|
|
|
out = cell(cell(cell(x, w), w), w)
|
|
self.assertFalse(cell.has_trace_for(x, w))
|
|
|
|
out.sum().backward()
|
|
self.assertTrue(cell.has_trace_for(x, w))
|
|
|
|
torch.jit.verify(cell, (x, w), devices=[])
|
|
|
|
def test_output_unflatten(self):
|
|
"""Check that outputs of traced functions retain the original structure and nesting"""
|
|
x = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
def fn(x):
|
|
return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
|
|
|
|
expected_out = fn(x)
|
|
fn = torch.jit.compile(fn)
|
|
|
|
def recursive_sum(obj):
|
|
if isinstance(obj, Variable):
|
|
return obj.sum()
|
|
else:
|
|
return sum(recursive_sum(o) for o in obj)
|
|
|
|
recursive_sum(fn(x)).backward()
|
|
self.assertTrue(fn.has_trace_for(x))
|
|
with self.assertCompiled(fn):
|
|
self.assertEqual(fn(x), expected_out)
|
|
|
|
def test_input_flatten(self):
|
|
"""Check that inputs to traced functions are flattened"""
|
|
def make_var():
|
|
return Variable(torch.randn(1), requires_grad=True)
|
|
x = (make_var(), (make_var(), make_var()))
|
|
|
|
def fn(x, t):
|
|
y, z = t
|
|
return x * y * z
|
|
|
|
expected_out = fn(*x)
|
|
fn = torch.jit.compile(fn)
|
|
fn(*x).backward()
|
|
self.assertTrue(fn.has_trace_for(*x))
|
|
with self.assertCompiled(fn):
|
|
self.assertEqual(fn(*x), expected_out)
|
|
|
|
def test_flags(self):
|
|
x = Variable(torch.randn(2, 2))
|
|
y = Variable(torch.randn(2, 2))
|
|
|
|
@torch.jit.compile
|
|
def fn(x, y):
|
|
return (x * x + y * y + x * y).sum()
|
|
|
|
grads = {}
|
|
for rx, ry in product((True, False), repeat=2):
|
|
x.requires_grad = rx
|
|
y.requires_grad = ry
|
|
|
|
self.assertFalse(fn.has_trace_for(x, y))
|
|
out = fn(x, y)
|
|
|
|
self.assertFalse(fn.has_trace_for(x, y))
|
|
for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
|
|
if not compute:
|
|
continue
|
|
grad_v, = torch.autograd.grad(out, v, retain_graph=True)
|
|
expected_grad = grads.setdefault(name, grad_v)
|
|
self.assertEqual(grad_v, expected_grad)
|
|
self.assertEqual(fn.has_trace_for(x, y), rx or ry)
|
|
|
|
def test_no_grad_fallback(self):
|
|
"""Check that Traceable falls back to num_backwards=0 if in no-backprop mode"""
|
|
x = Variable(torch.randn(2, 2))
|
|
y = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
@torch.jit.compile
|
|
def fn(x, y):
|
|
return x * x + x * y
|
|
|
|
out = fn(x, y)
|
|
self.assertFalse(fn.has_trace_for(x, y))
|
|
with torch.no_grad():
|
|
out = fn(x, y)
|
|
self.assertTrue(fn.has_trace_for(x, y))
|
|
with self.assertCompiled(fn):
|
|
out2 = fn(x, y)
|
|
self.assertEqual(out, out2)
|
|
|
|
def test_backward_flag_checks(self):
|
|
x = Variable(torch.randn(1), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=2)
|
|
def fn(x):
|
|
return x * x
|
|
|
|
grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
|
|
self.assertFalse(fn.has_trace_for(x))
|
|
grad_x.backward()
|
|
self.assertTrue(fn.has_trace_for(x))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'was compiled with'):
|
|
fn(x).backward(Variable(torch.ones(1), requires_grad=True))
|
|
with self.assertRaisesRegex(RuntimeError, 'was compiled with'):
|
|
grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
|
|
grad_x.backward(Variable(torch.ones(1), requires_grad=True))
|
|
|
|
# TODO: Test executing this
|
|
|
|
def test_python_ir(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
traced, _ = torch.jit.get_trace_graph(doit, (x, y))
|
|
g = torch._C._jit_get_graph(traced)
|
|
g2 = torch._C.Graph()
|
|
g_to_g2 = {}
|
|
for node in g.inputs():
|
|
g_to_g2[node] = g2.addInput()
|
|
for node in g.nodes():
|
|
n_ = g2.createClone(node, lambda x: g_to_g2[x])
|
|
g2.appendNode(n_)
|
|
for o, no in zip(node.outputs(), n_.outputs()):
|
|
g_to_g2[o] = no
|
|
|
|
for node in g.outputs():
|
|
g2.registerOutput(g_to_g2[node])
|
|
|
|
t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
|
|
assert(t_node.attributeNames() == ["a"])
|
|
g2.appendNode(t_node)
|
|
assert(torch.equal(torch.ones([2, 2]), t_node.t("a")))
|
|
self.assertExpected(str(g2))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
|
|
def test_cpp(self):
|
|
# rather than rebuild assertExpected in cpp,
|
|
# just glob all the cpp outputs into one file for now
|
|
self.assertExpected(torch._C._jit_run_cpp_tests())
|
|
|
|
def test_batchnorm(self):
|
|
x = Variable(torch.randn(2, 2, 2, 2).fill_(1.0), requires_grad=True)
|
|
trace, _ = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_dropout(self):
|
|
x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
|
|
trace, _ = torch.jit.get_trace_graph(nn.Dropout(0.6), x)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_batchnorm_run_twice(self):
|
|
@torch.jit.compile(nderivs=0)
|
|
class MyBatchNorm2d(nn.BatchNorm2d):
|
|
pass
|
|
|
|
bn = MyBatchNorm2d(1)
|
|
x = Variable(torch.randn(5, 1, 2, 1))
|
|
z = bn(x)
|
|
with self.assertCompiled(bn):
|
|
z2 = bn(x)
|
|
self.assertEqual(z, z2)
|
|
|
|
def test_non_decorator_use_fails(self):
|
|
MyLSTM = torch.jit.compile(nn.LSTM)
|
|
self.assertRaisesRegex(TypeError, "class decorator", lambda: MyLSTM(2, 2))
|
|
|
|
def test_conv(self):
|
|
x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True)
|
|
trace, _ = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_reuse_function(self):
|
|
@torch.jit.compile(nderivs=0)
|
|
def clinear(*args):
|
|
return F.linear(*args)
|
|
|
|
def cast(x):
|
|
return x
|
|
|
|
input = Variable(cast(torch.randn(1, 1)))
|
|
weights = Variable(cast(torch.randn(1, 1)))
|
|
bias = Variable(cast(torch.randn(1, 1)))
|
|
|
|
# linear AKA addmm without bias is of particular interest
|
|
# because we allocate a zero-filled new variable when we execute,
|
|
# and then *fill* it with the result
|
|
|
|
r1_ = clinear(input, weights)
|
|
with self.assertCompiled(clinear):
|
|
r1 = clinear(r1_, weights)
|
|
r2 = F.linear(F.linear(input, weights), weights)
|
|
|
|
self.assertEqual(r1, r2)
|
|
|
|
def test_unused_input(self):
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(a, b, c):
|
|
return a + b
|
|
|
|
a, b, c = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(3)]
|
|
fn(a, b, c).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(a, b, c).sum().backward()
|
|
|
|
def test_repeated_input(self):
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(a, b):
|
|
return a + b
|
|
|
|
a, b = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(2)]
|
|
fn(a, a).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(a, a).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(a, b).sum().backward()
|
|
self.assertExpected(str(fn.graph_for(a, a)))
|
|
|
|
def test_repeated_output(self):
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(a, b):
|
|
z = a + b
|
|
return z, z
|
|
|
|
a, b = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(2)]
|
|
sum(fn(a, b)).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
sum(fn(a, b)).sum().backward()
|
|
self.assertExpected(str(fn.graph_for(a, b)))
|
|
|
|
def test_re_enter(self):
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(a, b):
|
|
return a + b
|
|
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn2(a, b, c):
|
|
return fn(a, b) + c
|
|
|
|
a, b, c = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(3)]
|
|
|
|
fn(a, b).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(a, b).sum().backward()
|
|
|
|
fn2(a, b, c).sum().backward()
|
|
with self.assertCompiled(fn2):
|
|
fn2(a, b, c).sum().backward()
|
|
|
|
def test_mini_wlm(self):
|
|
"""Exercise null-edge pruning in the tracer."""
|
|
|
|
@torch.jit.compile
|
|
class MyModel(nn.Module):
|
|
def __init__(self):
|
|
super(MyModel, self).__init__()
|
|
self.encoder = nn.Embedding(2, 2)
|
|
|
|
def forward(self, input, hidden):
|
|
emb = self.encoder(input)
|
|
hidden = hidden.clone() # simulate some RNN operation
|
|
return emb, hidden
|
|
|
|
model = MyModel()
|
|
|
|
x = Variable(torch.LongTensor([[0, 1], [1, 0]]))
|
|
y = Variable(torch.FloatTensor([0]))
|
|
|
|
z, _ = model(x, y)
|
|
z.sum().backward()
|
|
self.assertTrue(model.has_trace_for(x, y))
|
|
|
|
with self.assertCompiled(model):
|
|
z, _ = model(x, y)
|
|
z.sum().backward()
|
|
|
|
def test_module_cast(self):
|
|
"""Compiled modules can be casted to other data types"""
|
|
@torch.jit.compile(nderivs=0)
|
|
class Adder(nn.Module):
|
|
def __init__(self):
|
|
super(Adder, self).__init__()
|
|
self.y = nn.Parameter(torch.randn(2, 2))
|
|
|
|
def forward(self, x):
|
|
return x + self.y
|
|
|
|
x = Variable(torch.randn(2, 2).float())
|
|
# Wrap it in a sequential to make sure it works for submodules
|
|
a = nn.Sequential(Adder()).float()
|
|
|
|
def check_type(caster):
|
|
caster(a)
|
|
a(caster(x))
|
|
with self.assertCompiled(a[0]):
|
|
a(caster(x))
|
|
|
|
check_type(lambda x: x)
|
|
check_type(lambda x: x.double())
|
|
if torch.cuda.is_available():
|
|
check_type(lambda x: x.float().cuda())
|
|
check_type(lambda x: x.double().cuda())
|
|
self.assertEqual(a[0].hits, 4 if torch.cuda.is_available() else 2)
|
|
|
|
    # Tracer fails when it receives the same grad variable as multiple inputs to a
    # traced region. The problem is that it's not immediately obvious how to
|
|
# assign multiple inputs to this Variable. It might be possible to solve
|
|
# this using the view mechanism, but this requires some thought.
|
|
# In general, it should be supported, because the user has no control
|
|
# over this (and it's quite common, e.g. the sum call below will pass the same
|
|
# grad variable as both inputs to grad of fn).
|
|
@unittest.skip("Broken - repeated grads trigger an assertion failure.")
|
|
def test_repeated_grad(self):
|
|
@torch.jit.compile
|
|
def fn(x):
|
|
return x * x, x + x
|
|
|
|
x = Variable(torch.randn(5, 5), requires_grad=True)
|
|
# This shouldn't raise!
|
|
sum(fn(x)).sum().backward()
|
|
|
|
def test_input_pruning(self):
|
|
"""Check that stage 1 will return only one value"""
|
|
# One of the inputs doesn't require grad, so it should be pruned
|
|
@torch.jit.compile
|
|
def fn(x, y):
|
|
return x * y, x + y
|
|
|
|
x = Variable(torch.randn(5, 5), requires_grad=True)
|
|
y = Variable(torch.randn(5, 5))
|
|
|
|
out = fn(x, y)
|
|
(out[0] * out[1]).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(x, y)
|
|
self.assertExpected(str(fn.graph_for(x, y)))
|
|
|
|
def test_output_pruning(self):
|
|
"""Check that stage 1 will take one value as an argument"""
|
|
# One of the outputs doesn't require grad, so it should be pruned
|
|
@torch.jit.compile
|
|
def fn(x, y):
|
|
return x * y, y + y
|
|
|
|
x = Variable(torch.randn(5, 5), requires_grad=True)
|
|
y = Variable(torch.randn(5, 5))
|
|
|
|
out = fn(x, y)
|
|
(out[0] * out[1]).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(x, y)
|
|
self.assertExpected(str(fn.graph_for(x, y)))
|
|
|
|
@skipIfNoTorchVision
|
|
def test_alexnet(self):
|
|
return
|
|
x = Variable(torch.randn(10, 3, 224, 224).fill_(1.0), requires_grad=True)
|
|
trace, _ = torch.jit.get_trace_graph(torchvision.models.AlexNet(), x)
|
|
self.assertExpectedTrace(trace)
|
|
# NB: Purposely NOT testing protobuf export here
|
|
|
|
def test_debug_info(self):
|
|
"""Check that debug info doesn't crash and has some reasonable info"""
|
|
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(x, y):
|
|
return x * y + x + y
|
|
|
|
x = Variable(torch.randn(5, 5), requires_grad=True)
|
|
y = Variable(torch.randn(5, 5), requires_grad=True)
|
|
|
|
out = fn(x, y)
|
|
|
|
out.sum().backward()
|
|
|
|
for _ in range(0, 100):
|
|
out = fn(x, y)
|
|
info_str = fn.jit_debug_info()
|
|
self.assertTrue("hits: 100" in info_str)
|
|
self.assertTrue("stage 1" in info_str)
|
|
|
|
# Inplace copies don't work with tracer yet.
|
|
# This is actually somewhat important to support correctly
|
|
# as all backwards functions of views are implemented
|
|
# as a zero filled tensor with a gradient fill on the
|
|
# viewed portion.
|
|
@unittest.expectedFailure
|
|
def test_inplace_copy(self):
|
|
x = Variable(torch.randn(4, 4), requires_grad=True)
|
|
|
|
def f(x):
|
|
out = Variable(torch.zeros(x.size()))
|
|
out.copy_(x)
|
|
return out
|
|
|
|
trace, z = torch.jit.get_trace_graph(f, (x, ), nderivs=0)
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_index_trace(self):
|
|
x = Variable(torch.randn(4, 4), requires_grad=True)
|
|
trace, z = torch.jit.get_trace_graph(lambda x: x[0], (x, ), nderivs=1)
|
|
z.sum().backward()
|
|
torch._C._jit_pass_lint(trace.graph())
|
|
torch._C._jit_pass_dce(trace.graph())
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_saved_output(self):
|
|
x = Variable(torch.randn(4, 4), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(x):
|
|
return x.sigmoid()
|
|
|
|
fn(x).sum().backward()
|
|
self.assertExpected(str(fn.graph_for(x)))
|
|
|
|
def test_shared_param(self):
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(MyModule, self).__init__()
|
|
self.b = self.a = nn.Parameter(torch.randn(2, 2))
|
|
|
|
def forward(self, x):
|
|
return x * self.a + self.b
|
|
|
|
m = MyModule()
|
|
trace, _ = torch.jit.get_trace_graph(m, (Variable(torch.randn(2, 2)),), nderivs=0)
|
|
self.assertEqual(len(list(trace.graph().inputs())), 2)
|
|
self.assertExpected(str(trace))
|
|
|
|
def test_nested_inplace(self):
|
|
x = Variable(torch.randn(2, 2))
|
|
trace, _ = torch.jit.get_trace_graph(lambda x: F.threshold(x, 0, 0, inplace=True), (x,), nderivs=0)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def checkGraphExecutor(self, func, reference_tensors, input_tensors=None,
|
|
optimize=True, drop=None, allow_unused=False):
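        """Run func both eagerly and through torch._C.GraphExecutor and compare
        the outputs, the first-order gradients, and the second-order gradients."""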
|
|
def allSum(vs):
|
|
# drop allows us to remove some values from ever being used
|
|
# to test unused outputs
|
|
if drop is not None:
|
|
vs = vs[:-drop]
|
|
# we don't want all the grad for all the outputs to be the same
|
|
# so we multiply each by a constant
|
|
return sum([(i + 1) * v.sum() for i, v in enumerate(vs) if v is not None])
|
|
if input_tensors is None:
|
|
input_tensors = reference_tensors
|
|
|
|
nograd_inputs = [Variable(t) for t in reference_tensors]
|
|
recording_inputs = [Variable(t, requires_grad=True)
|
|
for t in reference_tensors]
|
|
|
|
ge = torch._C.GraphExecutor(func, [Variable(t) for t in input_tensors], optimize)
|
|
|
|
# test no gradients case
|
|
|
|
outputs = func(*nograd_inputs)
|
|
outputs_ge = ge(*nograd_inputs)
|
|
self.assertEqual(outputs, outputs_ge)
|
|
|
|
# test single grad case
|
|
|
|
outputs = func(*recording_inputs)
|
|
grads = torch.autograd.grad(allSum(outputs), recording_inputs,
|
|
allow_unused=allow_unused)
|
|
|
|
outputs_ge = ge(*recording_inputs)
|
|
grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
|
|
allow_unused=allow_unused)
|
|
self.assertEqual(outputs, outputs_ge)
|
|
self.assertEqual(grads, grads_ge)
|
|
|
|
# test the grad grad case
|
|
|
|
outputs = func(*recording_inputs)
|
|
l1 = allSum(outputs)
|
|
grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
|
|
allow_unused=allow_unused)
|
|
l2 = (allSum(grads) * l1)
|
|
grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
|
|
|
|
recording_inputs = [Variable(t, requires_grad=True)
|
|
for t in reference_tensors]
|
|
|
|
outputs_ge = ge(*recording_inputs)
|
|
l1_ge = allSum(outputs_ge)
|
|
grads_ge = torch.autograd.grad(
|
|
l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
|
|
l2_ge = (allSum(grads_ge) * l1_ge)
|
|
grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
|
|
|
|
self.assertEqual(outputs, outputs_ge)
|
|
self.assertEqual(grads, grads_ge)
|
|
self.assertEqual(grads2, grads2_ge)
|
|
|
|
def run_ge_tests(self, optimize, use_cuda):
|
|
def rand(*args):
|
|
t = torch.rand(*args).float()
|
|
if use_cuda:
|
|
t = t.cuda()
|
|
return t
|
|
self.checkGraphExecutor(lambda a, b: a * b + b,
|
|
[rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
|
|
optimize=optimize)
|
|
# trivial identity
|
|
self.checkGraphExecutor(lambda a, b: (
|
|
b, a), [rand(1), rand(1)], optimize=optimize)
|
|
|
|
def foo(a):
|
|
t = a * a
|
|
return t * t, 4 * t
|
|
self.checkGraphExecutor(foo, [rand(1)], optimize=optimize)
|
|
# unused input
|
|
self.checkGraphExecutor(
|
|
lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
|
|
allow_unused=True)
|
|
# test outputs that do not get used in grad
|
|
self.checkGraphExecutor(foo, [rand(1)], drop=1, optimize=optimize)
|
|
# test autograd fallback
|
|
self.checkGraphExecutor(lambda a, b: a * b /
|
|
(a - 2 * b) + b, [rand(1), rand(1)],
|
|
optimize=optimize)
|
|
|
|
def test_ge_unoptimized(self):
|
|
self.run_ge_tests(False, False)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
def test_ge_optimized(self):
|
|
self.run_ge_tests(True, False)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
def test_ge_cuda(self):
|
|
self.run_ge_tests(True, True)
|
|
|
|
# more manual test of graph executor that can be used as a scratchpad
|
|
def test_ge(self):
|
|
def foo(a, b):
|
|
return a * b / (a - b) + b
|
|
V = Variable
|
|
a, b = V(torch.rand(1)), V(torch.rand(1))
|
|
ge = torch._C.GraphExecutor(foo, (a, b))
|
|
a, b = V(torch.rand(1), requires_grad=True), V(
|
|
torch.rand(1), requires_grad=True)
|
|
r, = ge(a, b)
|
|
da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
|
|
|
|
l2 = (da * db + db * db)
|
|
g2result = torch.autograd.grad(l2, [da, db])
|
|
|
|
r = foo(a, b)
|
|
da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
|
|
self.assertEqual(da, da2)
|
|
self.assertEqual(db, db2)
|
|
l3 = (da2 * db2 + db2 * db2)
|
|
g2result2 = torch.autograd.grad(l3, [da2, db2])
|
|
self.assertEqual(g2result, g2result2)
|
|
|
|
def test_trace_annotation(self):
|
|
@torch.jit.trace(Variable(torch.rand(1)))
|
|
def foo(a):
|
|
return a + a + a
|
|
s = Variable(torch.rand(2))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "calls .cuda()")
|
|
def test_traced_module(self):
|
|
class Model(nn.Module):
|
|
def __init__(self, num_features, num_layers):
|
|
super(Model, self).__init__()
|
|
self.num_layers = num_layers
|
|
layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
|
|
for _ in range(num_layers)]
|
|
self.submodule = nn.Sequential(*chain(*layers))
|
|
|
|
def forward(self, x):
|
|
for i in range(self.num_layers):
|
|
x = self.submodule[i](x) + x
|
|
return x
|
|
|
|
model = Model(5, 3)
|
|
x = torch.randn(2, 5)
|
|
traced_model = torch.jit.trace(x)(model)
|
|
|
|
# We're missing some attributes these modules had initially. Make sure we can
|
|
# still get the __repr__()
|
|
model.__repr__()
|
|
|
|
# XXX: indexing sequentials is broken
|
|
linear_submodule = next(iter(traced_model.submodule._modules.values()))
|
|
|
|
# All attributes that aren't parameters should raise
|
|
with self.assertRaises(AttributeError):
|
|
linear_submodule.in_features
|
|
linear_submodule.weight
|
|
with self.assertRaises(RuntimeError):
|
|
traced_model.asdf = 4
|
|
linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
|
|
with self.assertRaises(RuntimeError):
|
|
del linear_submodule.weight
|
|
|
|
# Submodules can't be called
|
|
with self.assertRaises(RuntimeError):
|
|
linear_submodule(x)
|
|
|
|
# Type casts
|
|
linear_submodule.cuda()
|
|
traced_model.float().cuda()
|
|
cuda_out = traced_model(x.float().cuda())
|
|
traced_model.cpu()
|
|
cpu_out = traced_model(x.float())
|
|
self.assertEqual(cpu_out, cuda_out)
|
|
traced_model.double()
|
|
|
|
# state_dict + load_state_dict
|
|
state = {k: v.clone() for k, v in traced_model.state_dict().items()}
|
|
new_state = {k: v.clone().fill_(1) for k, v in state.items()}
|
|
out = traced_model(x)
|
|
traced_model.load_state_dict(new_state)
|
|
out_ones = traced_model(x)
|
|
traced_model.load_state_dict(state)
|
|
out_state = traced_model(x)
|
|
self.assertEqual(out, out_state)
|
|
self.assertNotEqual(out, out_ones)
|
|
|
|
def test_shape_prop_mismatch_output(self):
|
|
with self.assertRaises(RuntimeError):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def test_shape_prop_mismatch_output(a):
|
|
b = slice(a, dim=0, end=-2, start=2, step=1)
|
|
b = topk(a, dim=0, k=2, largest=True, sorted=True)
|
|
return b
|
|
''')
|
|
inputs = [torch.zeros(10)]
|
|
outputs = [torch.zeros(2), torch.from_numpy(np.array([1, 5])).long()]
|
|
|
|
real_outs = cu.test_shape_prop_mismatch_output(*inputs)
|
|
self.assertEqual(real_outs, outputs)
|
|
|
|
def test_view_shape_prop(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def test_view_shape_prop(a):
|
|
return view(a, size=[-1])
|
|
''')
|
|
inputs = [torch.zeros(10, 10)]
|
|
outputs = torch.zeros(100)
|
|
|
|
real_outs = cu.test_view_shape_prop(*inputs)
|
|
self.assertEqual(real_outs, outputs)
|
|
|
|
def test_integral_shape_inference(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def test_integral_shape_inference(a):
|
|
return a / a
|
|
''')
|
|
inputs = [torch.ones(10, 10).type(torch.LongTensor)]
|
|
outputs = torch.ones(10, 10)
|
|
|
|
self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
|
|
|
|
def test_shape_analysis_broadcast(self):
|
|
def broadcast(a, b):
|
|
return a + b
|
|
|
|
x = torch.randn(3, 1, 5, requires_grad=True)
|
|
y = torch.randn(4, 1, 8, 5, requires_grad=True)
|
|
|
|
graph = torch.jit._script_graph(broadcast)
|
|
torch._C._jit_pass_shape_analysis(graph, (x, y), False)
|
|
self.assertExpected(str(graph))
|
|
|
|
def test_fuser_multiple_blocks(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def test_fuser_multiple_blocks(this, that, theother, meme):
|
|
i = 0
|
|
while i < 20:
|
|
this = cat(this, meme, dim=0)
|
|
that = cat(that, meme, dim=0)
|
|
theother = cat(theother, meme, dim=0)
|
|
i = i + 1
|
|
return this, that, theother
|
|
''')
|
|
|
|
inputs = [torch.ones(0, 10, 10)] * 3
|
|
inputs += [torch.ones(1, 10, 10)]
|
|
outputs = [torch.ones(20, 10, 10)] * 3
|
|
|
|
self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
|
|
|
|
|
|
class TestScript(TestCase):
|
|
|
|
@contextmanager
|
|
def capture_stdout(self):
|
|
# No idea how to capture stdout from C++ on Windows
|
|
if WINDOWS:
|
|
yield ['']
|
|
return
|
|
import os
|
|
import fcntl
|
|
import errno
|
|
sys.stdout.flush()
|
|
stdout_fd = os.dup(1)
|
|
r, w = os.pipe()
|
|
try:
|
|
# Override stdout with r - dup is guaranteed to return the lowest free fd
|
|
os.close(1)
|
|
os.dup(w)
|
|
|
|
captured_stdout = ['']
|
|
yield captured_stdout
|
|
sys.stdout.flush() # Make sure that Python hasn't buffered anything
|
|
|
|
# Do the ugly dance to read all the data that was written into the pipe
|
|
fcntl.fcntl(r, fcntl.F_SETFL, os.O_NONBLOCK)
|
|
total_stdout = ''
|
|
while True:
|
|
try:
|
|
total_stdout += os.read(r, 1000).decode('ascii')
|
|
except OSError as e:
|
|
if e.errno != errno.EAGAIN:
|
|
raise
|
|
break
|
|
captured_stdout[0] = total_stdout
|
|
finally:
|
|
# Revert the change, and clean up all fds
|
|
os.close(1)
|
|
os.dup(stdout_fd)
|
|
os.close(stdout_fd)
|
|
os.close(r)
|
|
os.close(w)
|
|
|
|
def checkScript(self, script, inputs, optimize, outputs=None, name='func', capture_output=False):
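        """Check that the string frontend (CompilationUnit) and, for callables,
        the Python frontend (torch.jit.script) reproduce the expected outputs."""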
|
|
if isinstance(script, str):
|
|
cu = torch.jit.CompilationUnit(script, optimize)
|
|
ge = getattr(cu, name)
|
|
else:
|
|
if capture_output:
|
|
with self.capture_stdout() as captured:
|
|
outputs = script(*inputs)
|
|
else:
|
|
outputs = script(*inputs)
|
|
# Check the string frontend first
|
|
source = textwrap.dedent(inspect.getsource(script))
|
|
self.checkScript(source, inputs, optimize, outputs, script.__name__, capture_output)
|
|
# Continue checking the Python frontend
|
|
ge = torch.jit.script(script)
|
|
|
|
if capture_output:
|
|
with self.capture_stdout() as captured:
|
|
outputs_ge = ge(*inputs)
|
|
if not WINDOWS:
|
|
self.assertExpected(captured[0], subname='stdout')
|
|
else:
|
|
outputs_ge = ge(*inputs)
|
|
self.assertEqual(outputs, outputs_ge)
|
|
|
|
def test_script_cu(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(a):
|
|
b = a
|
|
return b
|
|
''')
|
|
a = Variable(torch.rand(1))
|
|
self.assertEqual(a, cu.foo(a))
|
|
|
|
def test_script_annotation(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
return a + a + a
|
|
s = Variable(torch.rand(2))
|
|
self.assertEqual(s + s + s, foo(s))
|
|
|
|
def test_add(self):
|
|
def func(a, b):
|
|
c = a + b
|
|
c += a
|
|
return c
|
|
|
|
a = torch.rand(1, requires_grad=True)
|
|
b = torch.rand(1, requires_grad=True)
|
|
self.checkScript(func, (a, b), optimize=True)
|
|
|
|
def test_mul(self):
|
|
def func(a, b):
|
|
return a * b
|
|
|
|
a = torch.rand(1, requires_grad=True)
|
|
b = torch.rand(1, requires_grad=True)
|
|
self.checkScript(func, (a, b), optimize=True)
|
|
|
|
def test_triple(self):
|
|
def func(x):
|
|
return 3. * x
|
|
|
|
x = torch.rand(1, dtype=torch.float, requires_grad=True)
|
|
self.checkScript(func, [x], optimize=True)
|
|
|
|
def test_slice(self):
|
|
def func(x):
|
|
return x[:5]
|
|
|
|
x = torch.rand(10, dtype=torch.float, requires_grad=True)
|
|
self.checkScript(func, [x], optimize=True)
|
|
|
|
def test_gather(self):
|
|
def func(x):
|
|
return x[0]
|
|
|
|
x = torch.rand(10, dtype=torch.float, requires_grad=True)
|
|
self.checkScript(func, [x], optimize=True)
|
|
|
|
def test_keyword(self):
|
|
@torch.jit.script
|
|
def func(x):
|
|
return torch.sum(x, dim=0, keepdim=True)
|
|
|
|
x = torch.rand(10, dtype=torch.float, requires_grad=True)
|
|
y = func(x)
|
|
y2 = torch.sum(x, dim=0, keepdim=True)
|
|
self.assertEqual(y, y2)
|
|
|
|
def test_func_call(self):
|
|
script = '''
|
|
def add(a, b):
|
|
return a + b
|
|
|
|
def mul(a, x):
|
|
return a * x
|
|
|
|
def func(alpha, beta, x, y):
|
|
return add(mul(alpha, x), mul(beta, y))
|
|
'''
|
|
alpha = torch.rand(1, dtype=torch.float, requires_grad=True)
|
|
beta = torch.rand(1, dtype=torch.float, requires_grad=True)
|
|
x = torch.rand(3, dtype=torch.float, requires_grad=True)
|
|
y = torch.rand(3, dtype=torch.float, requires_grad=True)
|
|
outputs = alpha * x + beta * y
|
|
# NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
|
|
self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)
|
|
|
|
@unittest.skip("RuntimeError: VariableType::ID() not implemented")
|
|
def test_cast(self):
|
|
script = '''
|
|
def to_int(x):
|
|
return int(x)
|
|
'''
|
|
x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
|
|
out = Variable(torch.IntTensor([1, 2]), requires_grad=True)
|
|
        self.checkScript(script, [x], optimize=True, outputs=[out], name='to_int')
|
|
|
|
def test_python_frontend(self):
|
|
def fn(x, y, z):
|
|
q = x + y - z.sigmoid()
|
|
print(q)
|
|
w = -z
|
|
if not x and not y and z:
|
|
m = x if not z else y
|
|
while x < y > z:
|
|
q = x
|
|
return x
|
|
|
|
ast = torch.jit.frontend.get_jit_ast(fn)
|
|
self.assertExpected(str(ast))
|
|
|
|
def _make_scalar_vars(self, arr, dtype):
|
|
return [torch.tensor(val, dtype=dtype) for val in arr]
|
|
|
|
def test_while(self):
|
|
def func(a, b, max):
|
|
while a < max:
|
|
a = a + 1
|
|
b = b + 1
|
|
c = a + b
|
|
return c
|
|
|
|
inputs = self._make_scalar_vars([1, 1, 10], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_fibb(self):
|
|
def func(lim):
|
|
first = 1
|
|
second = 1
|
|
i = 1
|
|
somenum = 5
|
|
dontmutateme = 3
|
|
third = 0
|
|
while i < lim:
|
|
third = first + second
|
|
first = second
|
|
second = third
|
|
j = 0
|
|
while j < 10:
|
|
somenum = somenum * 2
|
|
j = j + 1
|
|
i = i + j
|
|
i = i + dontmutateme
|
|
|
|
st = second + third
|
|
fs = first + second
|
|
return third, st, fs
|
|
|
|
inputs = self._make_scalar_vars([10], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_if(self):
|
|
def func(a, b):
|
|
d = 3
|
|
if a > 10:
|
|
a = 3 + d
|
|
else:
|
|
b = 3 + d
|
|
d = 4
|
|
c = a + b
|
|
return c
|
|
|
|
inputs = self._make_scalar_vars([1, -1], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_if_noelse(self):
|
|
def func(a, b):
|
|
if a > 10:
|
|
a = 3 + b
|
|
c = a + b
|
|
return c
|
|
|
|
inputs = self._make_scalar_vars([-1, 1], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_while_nonexistent_value(self):
|
|
with self.assertRaisesRegex(RuntimeError, "undefined value x"):
|
|
torch.jit.CompilationUnit('''
|
|
def test_while(a, b):
|
|
while a < 10:
|
|
a = a + x
|
|
b = b + 1
|
|
return a + b
|
|
''')
|
|
|
|
def test_while_nonexistent_cond_value(self):
|
|
with self.assertRaisesRegex(RuntimeError, "undefined value x"):
|
|
torch.jit.CompilationUnit('''
|
|
def test_while(a, b):
|
|
while a < x:
|
|
a = a + 1
|
|
b = b + 1
|
|
return a + b
|
|
''')
|
|
|
|
def test_while_write_outer_then_read(self):
|
|
def func(a, b):
|
|
while a < 10:
|
|
a = a + 1
|
|
b = a + 1
|
|
return a + b
|
|
|
|
inputs = self._make_scalar_vars([42, 1337], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_while_nest_if(self):
|
|
def func(a, b):
|
|
c = 0
|
|
while a < 10:
|
|
a = a + 1
|
|
b = b + 1
|
|
if a > b:
|
|
c = -a
|
|
else:
|
|
c = -b
|
|
return c + 1
|
|
|
|
inputs = self._make_scalar_vars([-1234, 4321], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_if_nest_while(self):
|
|
def func(a, b):
|
|
c = 0
|
|
if a > b:
|
|
while a > b:
|
|
b = b + 1
|
|
c = -b
|
|
return c
|
|
|
|
inputs = self._make_scalar_vars([4321, 1234], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_script_for_in_range(self):
|
|
script = '''
|
|
def test_for_in_range():
|
|
c = 0
|
|
for i in range(100):
|
|
c += i
|
|
return c
|
|
'''
|
|
self.checkScript(script, [], outputs=[4950], optimize=True, name='test_for_in_range')
|
|
|
|
def test_script_for_in_range_dynamic(self):
|
|
script = '''
|
|
def test_script_for_in_range_dynamic():
|
|
c = 0
|
|
for i in range(100):
|
|
acc = 0
|
|
for j in range(i):
|
|
acc += j
|
|
c += acc
|
|
return c
|
|
'''
|
|
self.checkScript(script, [], outputs=[161700], optimize=True, name='test_script_for_in_range_dynamic')
|
|
|
|
def test_script_for_in_range_ast(self):
|
|
@torch.jit.script
|
|
def test_script_for_in_range_ast(zero):
|
|
c = zero
|
|
for i in range(100):
|
|
acc = zero
|
|
for j in range(i):
|
|
acc += j
|
|
c += acc
|
|
return c
|
|
|
|
inputs = self._make_scalar_vars([0], torch.int64)
|
|
|
|
self.assertEqual(test_script_for_in_range_ast(*inputs), 161700)
|
|
|
|
def test_script_bool_constant(self):
|
|
script = '''
|
|
def test_script_bool_constant():
|
|
a = True
|
|
return a
|
|
'''
|
|
outputs = [1]
|
|
        self.checkScript(script, [], optimize=True, outputs=outputs[0], name='test_script_bool_constant')
|
|
|
|
def test_ternary(self):
|
|
def func(a, b):
|
|
c = 3
|
|
c = a + b if a > 3 else b
|
|
return c
|
|
|
|
inputs_true = self._make_scalar_vars([5, 2], torch.int64)
|
|
inputs_false = self._make_scalar_vars([1, 0], torch.int64)
|
|
self.checkScript(func, inputs_true, optimize=True)
|
|
self.checkScript(func, inputs_false, optimize=True)
|
|
|
|
def test_print(self):
|
|
def func(x, y):
|
|
q = (x + y).sigmoid()
|
|
print(q)
|
|
w = -q
|
|
return w * w
|
|
|
|
x = torch.arange(4, requires_grad=True)
|
|
y = torch.arange(0, 8, 2, requires_grad=True)
|
|
self.checkScript(func, [x, y], optimize=True, capture_output=True)
|
|
|
|
def test_multiple_assignment(self):
|
|
def outer_func(x):
|
|
return x * 2, x + 2
|
|
|
|
@torch.jit.script
|
|
def func(x):
|
|
y, z = outer_func(x)
|
|
return y + z
|
|
|
|
x = torch.arange(4)
|
|
self.assertEqual(func(x), x * 2 + x + 2)
|
|
|
|
def test_literals(self):
|
|
def func(a):
|
|
return a.view(size=[1, 2, 3])
|
|
|
|
a = torch.randn(6)
|
|
self.checkScript(func, [a], optimize=True)
|
|
|
|
def test_return(self):
|
|
def no_return(a):
|
|
a + 1
|
|
|
|
def void_return(a):
|
|
return
|
|
|
|
def one_return(a):
|
|
return a + 1.
|
|
|
|
def multiple_returns(a):
|
|
return a * 1., a * 2., a * 3.
|
|
|
|
a = torch.randn(1, dtype=torch.float)
|
|
self.checkScript(no_return, [a], optimize=True)
|
|
self.checkScript(void_return, [a], optimize=True)
|
|
self.checkScript(one_return, [a], optimize=True)
|
|
self.checkScript(multiple_returns, [a], optimize=True)
|
|
|
|
def test_error(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
return a.t()
|
|
s = Variable(torch.rand(10))
|
|
        # XXX: this should stay quiet in shape propagation and only fail in the interpreter
|
|
with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
|
|
foo(s)
|
|
|
|
@torch.jit.script
|
|
def bar(c, b):
|
|
return c / b
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
|
|
bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
|
|
|
|
    def test_binop_unsupported_error(self):
        with self.assertRaisesRegex(NotSupportedError, "unsupported binary operator:"):
            @torch.jit.script
            def binop(x, y):
                # Replace this with another unsupported op when/if it gets supported
                return x ** y

    def test_python_call(self):
        def pyfunc(a):
            return a * 3.0

        cu = torch.jit.CompilationUnit('''
        def other_func(a):
            return a + a

        def test_call_python(a):
            b = pyfunc(a)
            b = other_func(b)
            i = 0
            step = 1
            while i < 10:
                b = pyfunc(b)
                if b > 3.0:
                    b = pyfunc(b)
                i = 11
            return b
        ''')
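        # Expected result for a == 1: pyfunc(1) == 3, other_func(3) == 6, then inside the
        # loop pyfunc(6) == 18 and (since 18 > 3) pyfunc(18) == 54 before i is set to 11.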
        inputs = self._make_scalar_vars([1], torch.float)
        outputs = self._make_scalar_vars([54], torch.float)

        self.assertEqual(cu.test_call_python(*inputs), outputs[0])

    def test_python_call_failure(self):
        with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
            def pyfunc(a):
                return a * 3.0

            cu = torch.jit.CompilationUnit('''
            def other_func(a):
                return a + a

            def test_call_python(a):
                b = pyfunc(a)
                b = other_func(b)
                i = 0
                step = 1
                while i < 10:
                    b = pyfunc2(b)
                    if b > 3.0:
                        b = pyfunc(b)
                    i = 11
                return b
            ''')
            inputs = self._make_scalar_vars([1], torch.float)
            outputs = self._make_scalar_vars([54], torch.float)

            self.assertEqual(cu.test_call_python(*inputs), outputs)

    def test_python_call_annotation(self):
        def pyfunc(a):
            return a * 3.0

        @torch.jit.script
        def foo(a):
            return pyfunc(a) + pyfunc(a)

        inputs = self._make_scalar_vars([1], torch.float)
        outputs = self._make_scalar_vars([6], torch.float)
        self.assertEqual(foo(*inputs), outputs[0])

    def test_python_call_annotation_failure(self):
        with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
            def pyfunc(a):
                return a * 3.0

            @torch.jit.script
            def foo(a):
                return pyfunc2(a) + pyfunc(a)

            inputs = self._make_scalar_vars([1], torch.float)
            outputs = self._make_scalar_vars([6], torch.float)

            self.assertEqual(foo(*inputs), outputs[0])

    def test_desugar_module(self):
        import torch.nn.functional as F

        def fn(x, slope):
            a = torch.abs(x)
            b = torch.nn.functional.prelu(x, slope)
            c = F.prelu(x, slope)
            return a, b, c

        x = torch.arange(-3, 4)
        slope = torch.tensor([0.5])
        self.checkScript(fn, [x, slope], optimize=True)

    def test_script_module(self):
        class M1(torch.jit.ScriptModule):
            def __init__(self):
                super(M1, self).__init__(False)
                self.weight = nn.Parameter(torch.randn(2))

            @torch.jit.script_method
            def forward(self, thing):
                return self.weight + thing

        class PModule(nn.Module):
            def __init__(self):
                super(PModule, self).__init__()
                self.a = nn.Parameter(torch.randn(2, 3))

            def forward(self, a):
                return self.a.mm(a)

        class M2(torch.jit.ScriptModule):
            def __init__(self):
                super(M2, self).__init__(False)
                # test submodule
                self.sub = M1()
                self.sub2 = PModule()
                # test parameters
                self.weight = nn.Parameter(torch.randn(2, 3))
                self.bias = nn.Parameter(torch.randn(2))
                # test defining a method from a string
                self.define("""
                    def hi(self, a):
                        return self.weight.mm(a)
                """)
                # test script methods

            @torch.jit.script_method
            def doit(self, input):
                # test use of parameter
                return self.weight.mm(input)

            @torch.jit.script_method
            def doit2(self, input):
                return self.weight.mm(input)

            @torch.jit.script_method
            def forward(self, input):
                a = self.doit(input)
                b = self.doit2(input)
                c = self.hi(input)
                d = self.sub2(input)
                return a + b + self.bias + self.sub(a) + c + d

        m2 = M2()
        input = torch.randn(3, 2)
        a = m2.weight.mm(input)
        b = m2.weight.mm(input)
        c = m2.weight.mm(input)
        d = m2.sub2.a.mm(input)
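        # self.sub(a) returns m2.sub.weight + a, which is why the reference below
        # contains both m2.sub.weight and a second copy of a.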
        ref = a + b + m2.bias + m2.sub.weight + a + c + d
        self.assertEqual(ref, m2.forward(input))
        m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
        m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
        m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
        m2.sub2.a.data.zero_()
        self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))

    def test_script_module_call_noscript(self):
        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__(False)
                self.value = 1

            def foo(self):
                return torch.ones(2, 2) + self.value

            @torch.jit.script_method
            def forward(self, input):
                return input + self.foo()

        m = M()
        input = torch.randn(2, 2)
        o = m(input)
        self.assertEqual(o, input + torch.ones(2, 2) + 1)
        # check that we can change python attributes
        # and that those changes are picked up in script methods
        m.value = 2
        o = m(input)
        self.assertEqual(o, input + torch.ones(2, 2) + 2)

    def test_script_module_nochange_submodule(self):
        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__(False)
                self.sub = nn.Linear(5, 5)

            @torch.jit.script_method
            def forward(self, input):
                return self.sub(input)

        m = M()
        input = torch.randn(1, 5, 5)
        o = m(input)
        self.assertEqual(o, m.sub(input))
        with self.assertRaisesRegex(RuntimeError, "cannot re-assign"):
            m.sub = nn.Linear(5, 5)

    def test_script_inline_trace_multiple_args(self):
        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__(False)

            def forward(self, input, input2):
                return input + input2

        class M2(torch.jit.ScriptModule):
            def __init__(self):
                super(M2, self).__init__(False)
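                # torch.jit.trace(*example_inputs)(module) is the curried tracing API used
                # throughout this file: trace(...) returns a callable that is applied to
                # the module (or function) to be traced.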
                self.m = torch.jit.trace(torch.zeros(4, 3), torch.zeros(4, 3))(M())

            @torch.jit.script_method
            def forward(self, inp):
                return self.m(inp, inp)

        m2 = M2()
        m2(torch.zeros(4, 3))

    def test_script_module_const(self):
        class M(torch.jit.ScriptModule):

            __constants__ = ['b', 'i', 'c']

            def __init__(self):
                super(M, self).__init__(False)
                self.b = False
                self.i = 1
                self.c = 3.5

            @torch.jit.script_method
            def forward(self):
                return self.b, self.i, self.c

        m = M()
        o0, o1, o2 = m()
        self.assertEqual(o0, 0)
        self.assertEqual(o1, 1)
        self.assertEqual(o2, 3.5)

    def test_script_module_fail_const(self):
        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__(False)
                self.b = False

            @torch.jit.script_method
            def forward(self):
                return self.b
        with self.assertRaisesRegex(RuntimeError, "is not usable in a script method"):
            M()

    def test_script_module_valid_consts(self):
        class Foo(torch.jit.ScriptModule):
            __constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']

            def __init__(self):
                super(Foo, self).__init__(False)
                self.a = 1
                self.b = 1.2
                self.c = False
                self.d = [nn.Linear(3, 4)]
                self.e = lambda x: x
                self.f = [3, 4, 5]
                self.assertTrue(type(self.f) is tuple)
                self.g = [3, (3, 4), 5]
                with self.assertRaisesRegex(TypeError, "is not a valid constant"):
                    self.h = type(1)
                with self.assertRaisesRegex(TypeError, "is not a valid constant"):
                    self.i = (3, 4, {})

    # https://github.com/pytorch/pytorch/issues/6714
    @unittest.expectedFailure
    def test_script_module_for(self):
        class M(torch.jit.ScriptModule):
            __constants__ = ['b']

            def __init__(self):
                super(M, self).__init__(False)
                self.b = [1, 2, 3, 4]

            @torch.jit.script_method
            def forward(self):
                sum = 0
                for i in self.b:
                    sum += i
                return sum

        m = M()
        self.assertEqual(m(), 10)

    def test_script_module_for2(self):
        class Sub(torch.jit.ScriptModule):
            def __init__(self):
                super(Sub, self).__init__(False)
                self.weight = nn.Parameter(torch.randn(2))

            @torch.jit.script_method
            def forward(self, thing):
                return self.weight + thing

        class M(torch.jit.ScriptModule):
            __constants__ = ['mods']

            def __init__(self):
                super(M, self).__init__(False)
                self.mods = nn.ModuleList([Sub() for i in range(10)])

            @torch.jit.script_method
            def forward(self, v):
                for m in self.mods:
                    v = m(v)
                return v

        i = torch.Tensor(2)
        m = M()
        o = m(i)
        v = i
        for sub in m.mods:
            v = sub(v)
        self.assertEqual(o, v)

    def test_script_module_const_submodule_fail(self):
        class Sub(torch.jit.ScriptModule):
            def __init__(self):
                super(Sub, self).__init__(False)
                self.weight = nn.Parameter(torch.randn(2))

            @torch.jit.script_method
            def forward(self, thing):
                return self.weight + thing

        class M(torch.jit.ScriptModule):
            def __init__(self):
                super(M, self).__init__(False)
                self.mods = [Sub() for _ in range(10)]

            @torch.jit.script_method
            def forward(self):
                for _ in self.mods:
                    print(1)
                return 4

        with self.assertRaisesRegex(RuntimeError, "did you forget to add it __constants__"):
            M()

    def test_script_module_not_tuple(self):
        class M(torch.jit.ScriptModule):
            __constants__ = ['mods']

            def __init__(self):
                super(M, self).__init__(False)
                self.mods = 1

            @torch.jit.script_method
            def forward(self, v):
                for m in self.mods:
                    print(m)
                return v
        with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
            M()

    class StarTestSumStarred(torch.nn.Module):
        def __init__(self):
            super(TestScript.StarTestSumStarred, self).__init__()

        def forward(self, *inputs):
            output = inputs[0]
            for i in range(1, len(inputs)):
                output += inputs[i]
            return output

    class StarTestReturnThree(torch.nn.Module):
        def __init__(self):
            super(TestScript.StarTestReturnThree, self).__init__()

        def forward(self, rep):
            return rep, rep, rep

    def test_script_star_expr(self):

        class M2(torch.jit.ScriptModule):
            def __init__(self):
                super(M2, self).__init__(True)
                self.m = torch.jit.trace(
                    torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3))(TestScript.StarTestSumStarred())
                self.g = torch.jit.trace(torch.ones(4, 3))(TestScript.StarTestReturnThree())

            @torch.jit.script_method
            def forward(self, rep):
                tup = self.g(rep)
                return self.m(*tup)

        m = M2()
        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))

    def test_script_star_expr_string(self):
        class M2(torch.jit.ScriptModule):
            def __init__(self):
                super(M2, self).__init__(True)
                self.m = torch.jit.trace(
                    torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3))(TestScript.StarTestSumStarred())
                self.g = torch.jit.trace(torch.ones(4, 3))(TestScript.StarTestReturnThree())

                self.define('''
                    def forward(self, rep):
                        tup = self.g(rep)
                        return self.m(*tup)
                ''')

        m = M2()
        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))

    class StarTestSumAndReturnThree(torch.nn.Module):
        def __init__(self):
            super(TestScript.StarTestSumAndReturnThree, self).__init__()

        def forward(self, *inputs):
            output = inputs[0]
            for i in range(1, len(inputs)):
                output += inputs[i]
            return output, output, output

    def test_script_star_assign(self):
        class M2(torch.jit.ScriptModule):
            def __init__(self):
                super(M2, self).__init__(True)
                self.g = torch.jit.trace(torch.ones(4, 3))(TestScript.StarTestSumAndReturnThree())
                self.define('''
                    def forward(self, rep):
                        head, *tail = self.g(rep)
                        return head
                ''')

        m = M2()
        self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))

    def test_script_module_star_assign2(self):
        class M2(torch.jit.ScriptModule):
            def __init__(self):
                super(M2, self).__init__(True)
                self.g = torch.jit.trace(
                    torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)
                )(
                    TestScript.StarTestSumAndReturnThree()
                )
                self.define('''
                    def forward(self, rep):
                        *head, tail = self.g(rep, rep, rep)
                        return tail
                ''')

        m = M2()
        self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))

    def test_script_module_star_assign_fail_pythonop(self):

        with self.assertRaisesRegex(RuntimeError, "value cannot be used as a tuple"):
            class M2(torch.jit.ScriptModule):
                def __init__(self):
                    super(M2, self).__init__(True)

                    def myfunc():
                        return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)

                    self.define('''
                        def forward(self, rep):
                            a, *b = myfunc()
                            return a
                    ''')

            m = M2()
            m(torch.zeros(4, 3))

    def test_script_module_star_assign_fail_builtin(self):
        with self.assertRaisesRegex(RuntimeError, "value cannot be used as a tuple"):
            class M2(torch.jit.ScriptModule):
                def __init__(self):
                    super(M2, self).__init__(True)

                    self.define('''
                        def forward(self, rep):
                            a, *b = torch.neg(rep)
                            return a
                    ''')

            m = M2()
            m(torch.zeros(4, 3))

    def test_pack_padded_pad_packed_trace(self):
        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        T, B, C = 3, 5, 7

        class PadPackedWrapper(torch.nn.Module):
            def __init__(self):
                super(PadPackedWrapper, self).__init__()

            def forward(self, x, seq_lens):
                x = pack_padded_sequence(x, seq_lens)
                x, _ = pad_packed_sequence(x)
                return x

        x = np.ones((T, B, C))
        seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
        # set padding value so we can test equivalence
        for b in range(B):
            if seq_lens[b] < T:
                x[seq_lens[b]:, b, :] = 0
        seq_lens = torch.from_numpy(seq_lens)
        x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)

        m = PadPackedWrapper()
        m_traced = torch.jit.trace(x, seq_lens)(m)

        y = m(x, seq_lens)
        loss = torch.sum(y)
        loss.backward()
        grad = x.grad.clone()
        x.grad.zero_()

        y_traced = m_traced(x, seq_lens)
        loss_traced = torch.sum(y_traced)
        loss_traced.backward()
        grad_traced = x.grad.clone()

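        # Because the padded positions of x were zeroed out above, packing and then
        # re-padding reproduces the input exactly, so the traced output should equal x.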
        self.assertEqual(y_traced, x)
        self.assertEqual(y_traced, y)
        self.assertEqual(grad, grad_traced)

        f = io.BytesIO()
        torch.onnx._export(m, (x, seq_lens), f, verbose=False)

    def test_script_outputs(self):
        with self.assertRaisesRegex(RuntimeError, "value cannot be used as a tuple"):
            @torch.jit.script
            def foo(a):
                c, d = a + a
                return c + d

        @torch.jit.script
        def return3():
            return 1, 2, 3

        with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
            @torch.jit.script
            def bind2():
                a, b = return3()
                print(a)
                print(b)

    def test_script_chunk(self):
        @torch.jit.script
        def foo(a):
            b, c = torch.chunk(a, dim=0, chunks=2)
            return b
        v = torch.rand(10, 3)
        self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))

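        # chunks=3 produces three outputs, so unpacking into only two names should be
        # rejected when the script is compiled.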
with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
b, c = torch.chunk(a, dim=0, chunks=3)
|
|
return b
|
|
|
|
    def test_rnn_trace_override(self):
        from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
        num_layers = 3
        T, B, C = 11, 5, 7

        class RNNTraceWrapper(torch.nn.Module):
            def __init__(self, cell_type):
                super(RNNTraceWrapper, self).__init__()
                if cell_type == 'RNN':
                    self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
                elif cell_type == 'LSTM':
                    self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
                elif cell_type == 'GRU':
                    self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)

            def forward(self, x, seq_lens):
                x = pack_padded_sequence(x, seq_lens)
                x, _ = self.rnn(x)
                x, _ = pad_packed_sequence(x)
                return x

        for cell_type in ['RNN', 'LSTM', 'GRU']:
            x = torch.ones(T, B, C, requires_grad=True)
            seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))

            m = RNNTraceWrapper(cell_type)
            m_traced = torch.jit.trace(x, seq_lens)(m)

            y = m(x, seq_lens)
            loss = torch.sum(y)
            loss.backward()
            grad = x.grad.clone()
            x.grad.zero_()

            y_traced = m_traced(x, seq_lens)
            loss_traced = torch.sum(y_traced)
            loss_traced.backward()
            grad_traced = x.grad.clone()

            self.assertEqual(y_traced, y)
            self.assertEqual(grad, grad_traced)

            f = io.BytesIO()
            torch.onnx._export(m, (x, seq_lens), f, verbose=False)

    def test_tuples(self):
        @torch.jit.script
        def foo(i):
            a = torch.chunk(i, dim=0, chunks=2)
            c = a
            # some nonsense with if-statements and loops to check
            # that tuple lowering doesn't fail
            if True:
                c = torch.chunk(i, dim=0, chunks=2)
            t0, t1 = c
            while False:
                t0, t1 = c
                c = torch.chunk(i, dim=0, chunks=2)
            return t0

        v = torch.rand(10, 3)
        self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))

        with self.assertRaisesRegex(RuntimeError, "variable 'a' previously has type"):
            @torch.jit.script
            def mixtypes():
                a = torch.chunk(1, dim=0, chunks=2)
                if True:
                    a = 4

    def test_script_define_order(self):
        class M(torch.jit.ScriptModule):
            def __init__(self):
                pass

            @torch.jit.script_method
            def call_foo(self, input):
                return self.foo(input)

            @torch.jit.script_method
            def foo(self, input):
                return input + 1
        m = M()
        self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))

    def test_script_define_order_recursive_fail(self):
        class M(torch.jit.ScriptModule):
            def __init__(self):
                pass

            @torch.jit.script_method
            def call_foo(self, input):
                return self.foo(input)

            @torch.jit.script_method
            def foo(self, input):
                self.call_foo(input)

        with self.assertRaisesRegex(RuntimeError, 'called recursively involving'):
            M()

    def test_trace_of_script(self):
        @torch.jit.script
        def foo(a, c):
            b = 0
            if a == 0:
                b = 1
            return b + c

        a = torch.ones(1, dtype=torch.long)

        @torch.jit.trace(torch.zeros(1, dtype=torch.long))
        def use(b):
            return foo(b - 1, a) + 1

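        # use(ones): foo(0, a) takes the a == 0 branch, so b == 1 and the result is 1 + 1 + 1 == 3.
        # use(zeros): foo(-1, a) leaves b == 0, so the result is 0 + 1 + 1 == 2.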
        self.assertEqual(3, use(torch.ones(1, dtype=torch.long)))
        self.assertEqual(2, use(torch.zeros(1, dtype=torch.long)))


# Smoke tests for export methods
class TestPytorchExportModes(unittest.TestCase):
    class MyModel(nn.Module):
        def __init__(self):
            super(TestPytorchExportModes.MyModel, self).__init__()

        def forward(self, x):
            return x.t()

    def test_protobuf(self):
        torch_model = TestPytorchExportModes.MyModel()
        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
        f = io.BytesIO()
        torch.onnx._export(torch_model, (fake_input), f, verbose=False,
                           export_type=torch.onnx.ExportTypes.PROTOBUF_FILE)

    def test_zipfile(self):
        torch_model = TestPytorchExportModes.MyModel()
        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
        f = io.BytesIO()
        torch.onnx._export(torch_model, (fake_input), f, verbose=False,
                           export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE)

    def test_compressed_zipfile(self):
        torch_model = TestPytorchExportModes.MyModel()
        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
        f = io.BytesIO()
        torch.onnx._export(torch_model, (fake_input), f, verbose=False,
                           export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)

    def test_directory(self):
        torch_model = TestPytorchExportModes.MyModel()
        fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
        d = tempfile.mkdtemp()
        torch.onnx._export(torch_model, (fake_input), d, verbose=False,
                           export_type=torch.onnx.ExportTypes.DIRECTORY)
        shutil.rmtree(d)


if __name__ == '__main__':
    run_tests()