mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Implement MM fusion (MM with add reduction tree) A tree where leaves are matrix multiplies and inner vertices are adds can be computed as a single mm. Such subgraph often appear in backward if a single weight is reused multiple times (e.g. in RNNs). NOTE: this seems to be slightly slower on the GPU than the naive implementation, but it's a huge win on the CPU (think 100x lower overhead)
1171 lines
38 KiB
Python
1171 lines
38 KiB
Python
import torch
|
|
import torch.jit
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import unittest
|
|
from contextlib import contextmanager
|
|
from itertools import product
|
|
from torch.autograd import Variable, Function
|
|
from torch.autograd.function import traceable
|
|
from common import TestCase, run_tests, IS_WINDOWS
|
|
import io
|
|
import sys
|
|
|
|
try:
|
|
import torchvision
|
|
HAS_TORCHVISION = True
|
|
except ImportError:
|
|
HAS_TORCHVISION = False
|
|
|
|
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
|
|
|
RUN_CUDA = torch.cuda.is_available()
|
|
if torch.cuda.is_available():
|
|
CUDA_VERSION = torch._C._cuda_getCompiledVersion()
|
|
for d in range(torch.cuda.device_count()):
|
|
major = torch.cuda.get_device_capability(d)[0]
|
|
if (CUDA_VERSION < 8000 and major >= 6) or (CUDA_VERSION < 9000 and major >= 7):
|
|
RUN_CUDA = False
|
|
|
|
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
|
|
|
|
|
|
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
|
|
hx, cx = hidden
|
|
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
|
|
|
|
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
|
|
ingate = F.sigmoid(ingate)
|
|
forgetgate = F.sigmoid(forgetgate)
|
|
cellgate = F.tanh(cellgate)
|
|
outgate = F.sigmoid(outgate)
|
|
|
|
cy = (forgetgate * cx) + (ingate * cellgate)
|
|
hy = outgate * F.tanh(cy)
|
|
return hy, cy
|
|
|
|
|
|
def LSTMCellC(*args, **kwargs):
|
|
hy, cy = LSTMCell(*args, **kwargs)
|
|
return torch.cat((hy, cy))
|
|
|
|
|
|
class TestJit(TestCase):
|
|
maxDiff = None
|
|
|
|
@contextmanager
|
|
def assertCompiled(self, compiled_fn):
|
|
self.assertIsInstance(compiled_fn, torch._C.CompiledFunction)
|
|
hits, misses = compiled_fn.hits, compiled_fn.misses
|
|
yield
|
|
self.assertLess(hits, compiled_fn.hits)
|
|
self.assertEqual(misses, compiled_fn.misses)
|
|
|
|
def assertExpectedTrace(self, trace, *args, **kwargs):
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_dce(trace)
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_canonicalize(trace)
|
|
torch._C._jit_pass_lint(trace)
|
|
self.assertExpected(str(trace), *args, **kwargs)
|
|
|
|
def test_simple(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
def f(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
trace, z = torch.jit.trace(f, (x, y), nderivs=0)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
# matmul is currently implemented as a native function, which
|
|
# exercises different codepaths in the JIT. The following two
|
|
# tests ensure that (1) matmul indeed traces into an atomic,
|
|
# native operation, and (2) the JIT knows how to run it
|
|
|
|
def test_matmul_native(self):
|
|
x = Variable(torch.Tensor([[0.4]]), requires_grad=True)
|
|
y = Variable(torch.Tensor([[0.7]]), requires_grad=True)
|
|
|
|
trace, z = torch.jit.trace(lambda x, y: x.matmul(y), (x, y), nderivs=0)
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_dce(trace)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_matmul_native_run(self):
|
|
x = Variable(torch.Tensor([[0.4]]), requires_grad=True)
|
|
y = Variable(torch.Tensor([[0.7]]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def fn(x, y):
|
|
return x.matmul(y)
|
|
|
|
z = fn(x, y)
|
|
with self.assertCompiled(fn):
|
|
z2 = fn(x, y)
|
|
self.assertEqual(z, z2)
|
|
|
|
# index-2 is not implemented in interpreter
|
|
@unittest.expectedFailure
|
|
def test_index(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.LongTensor([0]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def fn(x, y):
|
|
return x[y]
|
|
|
|
z = fn(x, y)
|
|
with self.assertCompiled(fn):
|
|
z2 = fn(x, y)
|
|
self.assertEqual(z, z2)
|
|
|
|
# Backwards tracing was broken for indexing by a constant,
|
|
# because it's internally implemented using as_strided,
|
|
# and we attempted to trace its derivative (which is not
|
|
# currently supported.) It currently works because
|
|
# slice() is now not marked as traceable.
|
|
def test_index_constant(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(x):
|
|
return x[0]
|
|
|
|
z = fn(x)
|
|
z.backward()
|
|
grad = x.grad.clone()
|
|
x.grad.zero_()
|
|
with self.assertCompiled(fn):
|
|
z2 = fn(x)
|
|
z2.backward()
|
|
grad2 = x.grad.clone()
|
|
self.assertEqual(z, z2)
|
|
self.assertEqual(grad, grad2)
|
|
|
|
def test_scopes(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
def f(x, y):
|
|
out = x + y
|
|
with torch.jit.scope('Foo', out):
|
|
out = x * out
|
|
with torch.jit.scope('Bar', out):
|
|
out = torch.tanh(out)
|
|
out = torch.sigmoid(out)
|
|
return out
|
|
|
|
trace, z = torch.jit.trace(f, (x, y), nderivs=0)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_scopes_intermediate_node(self):
|
|
|
|
class Net(nn.Module):
|
|
def forward(self, x):
|
|
return F.log_softmax(x, dim=0)
|
|
|
|
net = Net()
|
|
t = Variable(torch.ones(2), requires_grad=True)
|
|
trace, _ = torch.jit.trace(net, (t, ))
|
|
torch.onnx._optimize_trace(trace, False)
|
|
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_scopes_identity_node(self):
|
|
|
|
class Net(nn.Module):
|
|
|
|
def __init__(self):
|
|
super(Net, self).__init__()
|
|
self.features = nn.Sequential(
|
|
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
|
|
nn.ReLU(inplace=True),
|
|
nn.MaxPool2d(kernel_size=3, stride=2),
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.features(x)
|
|
return x
|
|
|
|
model = Net()
|
|
|
|
t = Variable(torch.ones(1, 3, 227, 227), requires_grad=True)
|
|
|
|
with torch.onnx.set_training(model, False):
|
|
trace, _ = torch.jit.trace(model, (t, ))
|
|
|
|
torch.onnx._optimize_trace(trace, False)
|
|
|
|
self.assertExpectedTrace(trace)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_lstm_fusion(self):
|
|
input = Variable(torch.randn(3, 10).float().cuda())
|
|
hx = Variable(torch.randn(3, 20).float().cuda())
|
|
cx = Variable(torch.randn(3, 20).float().cuda())
|
|
module = nn.LSTMCell(10, 20).float().cuda() # Just to allocate weights with correct sizes
|
|
|
|
trace, _ = torch.jit.trace(LSTMCell, (input, (hx, cx)) + tuple(module.parameters()))
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_dce(trace)
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_fuse(trace)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def run_lstm_fusion(self, use_cuda):
|
|
def to_type(x):
|
|
x = x.float()
|
|
if use_cuda:
|
|
x = x.cuda()
|
|
return x
|
|
|
|
def rand_v(a, b):
|
|
return Variable(to_type(torch.randn(a, b)))
|
|
|
|
input = rand_v(3, 10)
|
|
hx = rand_v(3, 20)
|
|
cx = rand_v(3, 20)
|
|
module = to_type(nn.LSTMCell(10, 20)) # Just to allocate weights with correct sizes
|
|
|
|
CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCell)
|
|
|
|
z = CompiledLSTMCell(input, (hx, cx), *module.parameters())
|
|
with self.assertCompiled(CompiledLSTMCell):
|
|
z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters())
|
|
self.assertEqual(z, z2)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_run_lstm_fusion_cuda(self):
|
|
self.run_lstm_fusion(True)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
def test_run_lstm_fusion_cpu(self):
|
|
self.run_lstm_fusion(False)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_run_lstm_fusion_concat(self):
|
|
input = Variable(torch.randn(3, 10).float().cuda())
|
|
hx = Variable(torch.randn(3, 20).float().cuda())
|
|
cx = Variable(torch.randn(3, 20).float().cuda())
|
|
module = nn.LSTMCell(10, 20).float().cuda() # Just to allocate weights with correct sizes
|
|
|
|
CompiledLSTMCell = torch.jit.compile(nderivs=0)(LSTMCellC)
|
|
|
|
z = CompiledLSTMCell(input, (hx, cx), *module.parameters())
|
|
with self.assertCompiled(CompiledLSTMCell):
|
|
z2 = CompiledLSTMCell(input, (hx, cx), *module.parameters())
|
|
self.assertEqual(z, z2)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_concat_fusion(self):
|
|
hx = Variable(torch.randn(3, 20).float().cuda())
|
|
cx = Variable(torch.randn(3, 20).float().cuda())
|
|
|
|
def Foo(hx, cx):
|
|
return torch.cat((hx + cx, hx * cx))
|
|
|
|
trace, _ = torch.jit.trace(Foo, (hx, cx))
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_fuse(trace)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_fusion_distribute(self):
|
|
def f(x, y):
|
|
z1, z2 = (x + y).chunk(2, dim=1)
|
|
return z1 * z2
|
|
x = Variable(torch.randn(4, 4).float().cuda())
|
|
y = Variable(torch.randn(4, 4).float().cuda())
|
|
trace, _ = torch.jit.trace(f, (x, y), nderivs=0)
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_dce(trace)
|
|
self.assertExpectedTrace(trace, 'raw')
|
|
torch._C._jit_pass_fuse(trace)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Mysteriously fails on Windows")
|
|
def test_arg_configurations(self):
|
|
"""Different arg configurations should trigger different traces"""
|
|
x = Variable(torch.FloatTensor(4, 4).uniform_())
|
|
x_double = Variable(x.data.double())
|
|
x_grad = Variable(x.data.clone(), requires_grad=True)
|
|
y = Variable(torch.randn(4))
|
|
|
|
configurations = [
|
|
(x,),
|
|
(x_double,),
|
|
(x_grad,),
|
|
(y,),
|
|
([x, x],),
|
|
([x, y],),
|
|
]
|
|
if torch.cuda.is_available():
|
|
x_cuda = Variable(x.data.cuda())
|
|
configurations += [
|
|
(x_cuda,),
|
|
([x, x_cuda],),
|
|
([x_cuda, x],),
|
|
([[x_cuda, x]],),
|
|
]
|
|
if torch.cuda.device_count() > 1:
|
|
x_cuda_1 = Variable(x.data.cuda(1))
|
|
configurations += [
|
|
(x_cuda_1,),
|
|
([x_cuda, x_cuda_1],),
|
|
]
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def fn(*args):
|
|
in_vars, _ = torch._C._jit_flatten(args)
|
|
return in_vars[0] + 1
|
|
|
|
for i, config in enumerate(configurations):
|
|
self.assertFalse(fn.has_trace_for(*config))
|
|
fn(*config)
|
|
self.assertTrue(fn.has_trace_for(*config))
|
|
for unk_config in configurations[i + 1:]:
|
|
self.assertFalse(fn.has_trace_for(*unk_config))
|
|
self.assertEqual(fn.hits, 0)
|
|
|
|
def test_cse(self):
|
|
x = Variable(torch.Tensor([0.4, 0.3]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7, 0.5]), requires_grad=True)
|
|
|
|
trace, inputs = torch._C._tracer_enter((x, y), 0)
|
|
|
|
def fn(x, y):
|
|
w = (x + y) * (x + y) * (x + y)
|
|
t = torch.tanh(w) + torch.tanh(w)
|
|
z = (x + y) * (x + y) * (x + y) + t
|
|
return z
|
|
z = fn(*inputs)
|
|
torch._C._tracer_exit((z,))
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_cse(trace)
|
|
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_compile_run_twice(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=0, optimize=False)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
z = doit(x, y)
|
|
with self.assertCompiled(doit):
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
|
|
self.assertEqual(z, z2)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_compile_addc(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda()
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda()
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y) + 1))
|
|
|
|
z = doit(x, y)
|
|
with self.assertCompiled(doit):
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y) + 1)))
|
|
self.assertEqual(z, z2)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
|
|
def test_compile_fuse_last_device(self):
|
|
max_device = torch.cuda.device_count() - 1
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True).float().cuda(max_device)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True).float().cuda(max_device)
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y) + 1))
|
|
|
|
z = doit(x, y)
|
|
with self.assertCompiled(doit):
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y) + 1)))
|
|
self.assertEqual(z, z2)
|
|
|
|
def test_traced_function(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
z = doit(x, y)
|
|
with self.assertCompiled(doit):
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
|
|
self.assertEqual(z, z2)
|
|
|
|
def test_disabled_traced_function(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
@torch.jit.compile(enabled=False)
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
z = doit(x, y)
|
|
z2 = doit(x, y)
|
|
self.assertEqual(z, torch.sigmoid(torch.tanh(x * (x + y))))
|
|
self.assertEqual(z, z2)
|
|
|
|
def test_assign_traces(self):
|
|
"""Check that output Variables are assigned traces before they are saved."""
|
|
@traceable
|
|
class MyFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, a):
|
|
out = a * 2
|
|
ctx.save_for_backward(out)
|
|
return out
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_a):
|
|
a, = ctx.saved_variables
|
|
return a * grad_a
|
|
|
|
x = Variable(torch.randn(10, 10), requires_grad=True)
|
|
trace, out = torch.jit.trace(MyFn.apply, x, nderivs=1)
|
|
out.sum().backward()
|
|
torch._C._jit_pass_dce(trace)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_traced_module(self):
|
|
input = Variable(torch.randn(3, 10))
|
|
hx = Variable(torch.randn(3, 20))
|
|
cx = Variable(torch.randn(3, 20))
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
class MyLSTMCell(nn.LSTMCell):
|
|
pass
|
|
|
|
lstm = MyLSTMCell(10, 20)
|
|
|
|
out = lstm(input, (hx, cx))
|
|
with self.assertCompiled(lstm):
|
|
out2 = lstm(input, (hx, cx))
|
|
self.assertEqual(out, out2)
|
|
|
|
def test_autograd_closure(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
trace, inputs = torch._C._tracer_enter((x, y), 1)
|
|
|
|
def fn(x, y):
|
|
z = torch.sigmoid(x * (x + y))
|
|
w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
|
|
return z, w
|
|
z, w = fn(*inputs)
|
|
|
|
torch._C._tracer_exit((z, w))
|
|
torch._C._jit_pass_lint(trace)
|
|
|
|
(z * w).backward()
|
|
torch._C._jit_pass_dce(trace)
|
|
torch._C._jit_pass_lint(trace)
|
|
|
|
x_grad = x.grad.data.clone()
|
|
x.grad.data.zero_()
|
|
|
|
function = torch._C._jit_createInterpreterFactory(trace)
|
|
torch._C._jit_pass_lint(trace)
|
|
z2, w2 = function()(x, y)
|
|
(z2 * w2).backward()
|
|
self.assertEqual(z, z2)
|
|
self.assertEqual(w, w2)
|
|
self.assertEqual(x.grad.data, x_grad)
|
|
|
|
def test_verify(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
@torch.jit.compile
|
|
def f(x, y):
|
|
z = torch.sigmoid(x * (x + y))
|
|
w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
|
|
return z, w
|
|
|
|
torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
|
|
|
|
def test_constant(self):
|
|
x = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
trace, (tx,) = torch._C._tracer_enter((x,), 0)
|
|
|
|
y = Variable(torch.diag(torch.Tensor([2, 2])))
|
|
z = tx.matmul(y)
|
|
|
|
torch._C._tracer_exit((z,))
|
|
function = torch._C._jit_createInterpreterFactory(trace)
|
|
|
|
z2 = function()(x)
|
|
self.assertEqual(z, z2)
|
|
|
|
y.data.fill_(1000) # make sure the data has been cloned
|
|
|
|
x2 = Variable(torch.ones(2, 2) * 2, requires_grad=True)
|
|
z3 = function()(x2)
|
|
self.assertEqual(z3.data, torch.ones(2, 2) * 4)
|
|
|
|
def test_c_function(self):
|
|
x = Variable(torch.randn(1, 3, 10, 10))
|
|
m = nn.Conv2d(3, 8, 3, 1)
|
|
|
|
trace, inputs = torch._C._tracer_enter((x,) + tuple(m.parameters()), 0)
|
|
y = m(inputs[0])
|
|
torch._C._tracer_exit((y,))
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_legacy_fail(self):
|
|
|
|
class MyLegacyFn(Function):
|
|
def forward(self, x):
|
|
return x
|
|
|
|
def backward(self, grad_output):
|
|
return grad_output
|
|
|
|
x = Variable(torch.Tensor([0]), requires_grad=True)
|
|
trace, inputs = torch._C._tracer_enter((x,), 0)
|
|
self.assertRaisesRegex(RuntimeError, "MyLegacyFn", lambda: MyLegacyFn()(*inputs))
|
|
torch._C._tracer_exit(inputs)
|
|
|
|
def test_inplace_transplant(self):
|
|
x = Variable(torch.Tensor([0]), requires_grad=True)
|
|
trace, inputs = torch._C._tracer_enter((x,), 0)
|
|
|
|
def fn(x):
|
|
y = x.clone()
|
|
y.add_(2)
|
|
y.add_(3)
|
|
return y
|
|
y = fn(*inputs)
|
|
torch._C._tracer_exit((y,))
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_inplace_flags(self):
|
|
class InplaceFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
ctx.mark_dirty(x)
|
|
return x.add_(1)
|
|
|
|
@staticmethod
|
|
def backward(ctx, go):
|
|
return go
|
|
|
|
class RegularFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
return x.add(1)
|
|
|
|
@staticmethod
|
|
def backward(ctx, go):
|
|
return go
|
|
|
|
x = Variable(torch.Tensor([0]), requires_grad=True)
|
|
trace, inputs = torch._C._tracer_enter((x,), 0)
|
|
|
|
def fn(x):
|
|
y = RegularFn.apply(x)
|
|
y = InplaceFn.apply(y)
|
|
y = InplaceFn.apply(y)
|
|
y = RegularFn.apply(y)
|
|
return y
|
|
y = fn(*inputs)
|
|
torch._C._tracer_exit((y,))
|
|
ops = [n for n in trace.graph().nodes()]
|
|
for op in ops:
|
|
self.assertTrue(op.hasAttribute('inplace'))
|
|
inplace_flags = [False, True, True, False]
|
|
for op, is_inplace in zip(ops, inplace_flags):
|
|
self.assertEqual(op.i('inplace'), is_inplace)
|
|
|
|
def test_inplace_check(self):
|
|
class MyInplaceFn(Function):
|
|
@staticmethod
|
|
def forward(self, x):
|
|
x.add_(1)
|
|
self.mark_dirty(x)
|
|
return x
|
|
|
|
@staticmethod
|
|
def backward(self, grad):
|
|
return grad
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def fn(x):
|
|
return MyInplaceFn.apply(x)
|
|
x = Variable(torch.randn(5, 5))
|
|
fn(x) # trace
|
|
with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
|
|
fn(x)
|
|
|
|
def test_backward(self):
|
|
a = Variable(torch.randn(2, 2), requires_grad=True)
|
|
b = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
x = a
|
|
y = a * b
|
|
|
|
trace, inputs = torch._C._tracer_enter((x, y), 2)
|
|
|
|
def fn(x, y):
|
|
return y * 2 * x
|
|
z = fn(*inputs)
|
|
torch._C._tracer_exit((z,))
|
|
torch._C._jit_pass_lint(trace)
|
|
|
|
# Run first backward
|
|
grad, = torch.autograd.grad(z, x, Variable(torch.ones(2, 2), requires_grad=True), create_graph=True)
|
|
torch._C._jit_pass_lint(trace)
|
|
|
|
# Run second backward
|
|
grad.sum().backward(create_graph=True)
|
|
torch._C._jit_pass_lint(trace)
|
|
|
|
# Run dead code elimination to remove unused trace nodes
|
|
torch._C._jit_pass_dce(trace)
|
|
# This is nondeterministic, see:
|
|
# https://github.com/ezyang/pytorch/issues/227
|
|
# self.assertExpectedTrace(trace)
|
|
self.skipTest("output is nondeterministic on Travis/Python 3.5")
|
|
|
|
def test_backward_opaque(self):
|
|
x = Variable(torch.randn(3, 3), requires_grad=True)
|
|
y = Variable(torch.randn(3, 3), requires_grad=True)
|
|
|
|
trace, inputs = torch._C._tracer_enter((x, y), 2)
|
|
|
|
def fn(x, y):
|
|
return x.cross(y)
|
|
z = fn(*inputs)
|
|
torch._C._tracer_exit((z,))
|
|
torch._C._jit_pass_lint(trace)
|
|
|
|
# Run first backward
|
|
grad, = torch.autograd.grad(z, x, Variable(torch.ones(3, 3), requires_grad=True), create_graph=True)
|
|
torch._C._jit_pass_lint(trace)
|
|
|
|
# Run dead code elimination to remove unused trace nodes
|
|
torch._C._jit_pass_dce(trace)
|
|
# This is nondeterministic, see:
|
|
# https://github.com/ezyang/pytorch/issues/227
|
|
# self.assertExpectedTrace(trace)
|
|
self.skipTest("output is nondeterministic on Travis/Python 3.5")
|
|
|
|
def test_backward_closure(self):
|
|
"""Check that autograd closures handle multiple stages correctly."""
|
|
x = Variable(torch.randn(1), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=2)
|
|
def fn(x):
|
|
return x * x
|
|
|
|
# Generate trace
|
|
grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
|
|
self.assertFalse(fn.has_trace_for(x))
|
|
grad_x.backward()
|
|
self.assertTrue(fn.has_trace_for(x))
|
|
|
|
x_grad = x.grad.data.clone()
|
|
x.grad.data.zero_()
|
|
|
|
# Run the trace
|
|
with self.assertCompiled(fn):
|
|
output = fn(x)
|
|
grad_x, = torch.autograd.grad(output, (x,), create_graph=True)
|
|
grad_x.backward()
|
|
|
|
self.assertEqual(x.grad.data, x_grad)
|
|
|
|
def test_trace_expire(self):
|
|
x = Variable(torch.randn(2, 2), requires_grad=True)
|
|
y = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
def record_trace(num_backwards):
|
|
trace, inputs = torch._C._tracer_enter((x, y), num_backwards)
|
|
|
|
def fn(x, y):
|
|
return y * 2 * x
|
|
z = fn(*inputs)
|
|
torch._C._tracer_exit((z,))
|
|
return z, trace
|
|
|
|
def check(expired, complete):
|
|
self.assertEqual(trace.is_expired, expired)
|
|
self.assertEqual(trace.is_complete, complete)
|
|
|
|
z, trace = record_trace(0)
|
|
check(False, True)
|
|
del z
|
|
check(False, True)
|
|
|
|
z, trace = record_trace(1)
|
|
check(False, False)
|
|
del z
|
|
check(True, False)
|
|
|
|
z, trace = record_trace(1)
|
|
check(False, False)
|
|
z.sum().backward()
|
|
check(False, True)
|
|
del z
|
|
check(False, True)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Mysteriously fails on Windows")
|
|
def test_multiuse_fn(self):
|
|
x = Variable(torch.randn(2, 2), requires_grad=True)
|
|
w = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
@torch.jit.compile
|
|
def cell(x, w):
|
|
return x * w + 2
|
|
|
|
out = cell(cell(cell(x, w), w), w)
|
|
self.assertFalse(cell.has_trace_for(x, w))
|
|
|
|
out.sum().backward()
|
|
self.assertTrue(cell.has_trace_for(x, w))
|
|
|
|
torch.jit.verify(cell, (x, w), devices=[])
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "Mysteriously fails on Windows")
|
|
def test_output_unflatten(self):
|
|
"""Check that outputs of traced functions retain the original structure and nesting"""
|
|
x = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
def fn(x):
|
|
return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
|
|
|
|
expected_out = fn(x)
|
|
fn = torch.jit.compile(fn)
|
|
|
|
def recursive_sum(obj):
|
|
if isinstance(obj, Variable):
|
|
return obj.sum()
|
|
else:
|
|
return sum(recursive_sum(o) for o in obj)
|
|
|
|
recursive_sum(fn(x)).backward()
|
|
self.assertTrue(fn.has_trace_for(x))
|
|
with self.assertCompiled(fn):
|
|
self.assertEqual(fn(x), expected_out)
|
|
|
|
def test_input_flatten(self):
|
|
"""Check that inputs to traced functions are flattened"""
|
|
def make_var():
|
|
return Variable(torch.randn(1), requires_grad=True)
|
|
x = (make_var(), (make_var(), make_var()))
|
|
|
|
def fn(x, t):
|
|
y, z = t
|
|
return x * y * z
|
|
|
|
expected_out = fn(*x)
|
|
fn = torch.jit.compile(fn)
|
|
fn(*x).backward()
|
|
self.assertTrue(fn.has_trace_for(*x))
|
|
with self.assertCompiled(fn):
|
|
self.assertEqual(fn(*x), expected_out)
|
|
|
|
def test_flags(self):
|
|
x = Variable(torch.randn(2, 2))
|
|
y = Variable(torch.randn(2, 2))
|
|
|
|
@torch.jit.compile
|
|
def fn(x, y):
|
|
return (x * x + y * y + x * y).sum()
|
|
|
|
grads = {}
|
|
for rx, ry in product((True, False), repeat=2):
|
|
x.requires_grad = rx
|
|
y.requires_grad = ry
|
|
|
|
self.assertFalse(fn.has_trace_for(x, y))
|
|
out = fn(x, y)
|
|
|
|
self.assertFalse(fn.has_trace_for(x, y))
|
|
for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
|
|
if not compute:
|
|
continue
|
|
grad_v, = torch.autograd.grad(out, v, retain_graph=True)
|
|
expected_grad = grads.setdefault(name, grad_v)
|
|
self.assertEqual(grad_v, expected_grad)
|
|
self.assertEqual(fn.has_trace_for(x, y), rx or ry)
|
|
|
|
def test_no_grad_fallback(self):
|
|
"""Check that Traceable falls back to num_backwards=0 if in no-backprop mode"""
|
|
x = Variable(torch.randn(2, 2))
|
|
y = Variable(torch.randn(2, 2), requires_grad=True)
|
|
|
|
@torch.jit.compile
|
|
def fn(x, y):
|
|
return x * x + x * y
|
|
|
|
out = fn(x, y)
|
|
self.assertFalse(fn.has_trace_for(x, y))
|
|
with torch.no_grad():
|
|
out = fn(x, y)
|
|
self.assertTrue(fn.has_trace_for(x, y))
|
|
with self.assertCompiled(fn):
|
|
out2 = fn(x, y)
|
|
self.assertEqual(out, out2)
|
|
|
|
def test_backward_flag_checks(self):
|
|
x = Variable(torch.randn(1), requires_grad=True)
|
|
|
|
@torch.jit.compile(nderivs=2)
|
|
def fn(x):
|
|
return x * x
|
|
|
|
grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
|
|
self.assertFalse(fn.has_trace_for(x))
|
|
grad_x.backward()
|
|
self.assertTrue(fn.has_trace_for(x))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'was compiled with'):
|
|
fn(x).backward(Variable(torch.ones(1), requires_grad=True))
|
|
with self.assertRaisesRegex(RuntimeError, 'was compiled with'):
|
|
grad_x, = torch.autograd.grad(fn(x), (x,), create_graph=True)
|
|
grad_x.backward(Variable(torch.ones(1), requires_grad=True))
|
|
|
|
# TODO: Test executing this
|
|
|
|
def test_python_ir(self):
|
|
x = Variable(torch.Tensor([0.4]), requires_grad=True)
|
|
y = Variable(torch.Tensor([0.7]), requires_grad=True)
|
|
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
traced, _ = torch.jit.trace(doit, (x, y))
|
|
g = torch._C._jit_get_graph(traced)
|
|
g2 = torch._C.Graph()
|
|
g_to_g2 = {}
|
|
for node in g.inputs():
|
|
g_to_g2[node] = g2.addInput()
|
|
for node in g.nodes():
|
|
n_ = g2.createClone(node, lambda x: g_to_g2[x])
|
|
g2.appendNode(n_)
|
|
for o, no in zip(node.outputs(), n_.outputs()):
|
|
g_to_g2[o] = no
|
|
|
|
for node in g.outputs():
|
|
g2.registerOutput(g_to_g2[node])
|
|
|
|
t_node = g2.create("TensorTest").t_("a", torch.ones([2, 2]))
|
|
assert(t_node.attributeNames() == ["a"])
|
|
g2.appendNode(t_node)
|
|
assert(torch.equal(torch.ones([2, 2]), t_node.t("a")))
|
|
self.assertExpected(str(g2))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
|
|
def test_cpp(self):
|
|
torch._C._jit_run_cpp_tests()
|
|
|
|
def test_batchnorm(self):
|
|
x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
|
|
trace, _ = torch.jit.trace(nn.BatchNorm2d(2), x)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_dropout(self):
|
|
x = Variable(torch.randn(2, 2).fill_(1.0), requires_grad=True)
|
|
trace, _ = torch.jit.trace(nn.Dropout(0.6), x)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_batchnorm_run_twice(self):
|
|
@torch.jit.compile(nderivs=0)
|
|
class MyBatchNorm2d(nn.BatchNorm2d):
|
|
pass
|
|
|
|
bn = MyBatchNorm2d(1)
|
|
x = Variable(torch.randn(5, 1))
|
|
z = bn(x)
|
|
with self.assertCompiled(bn):
|
|
z2 = bn(x)
|
|
self.assertEqual(z, z2)
|
|
|
|
def test_non_decorator_use_fails(self):
|
|
MyLSTM = torch.jit.compile(nn.LSTM)
|
|
self.assertRaisesRegex(TypeError, "class decorator", lambda: MyLSTM(2, 2))
|
|
|
|
def test_conv(self):
|
|
x = Variable(torch.randn(20, 16, 50, 40).fill_(1.0), requires_grad=True)
|
|
trace, _ = torch.jit.trace(nn.Conv2d(16, 13, 3, bias=False), x)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_reuse_function(self):
|
|
@torch.jit.compile(nderivs=0)
|
|
def clinear(*args):
|
|
return F.linear(*args)
|
|
|
|
def cast(x):
|
|
return x
|
|
|
|
input = Variable(cast(torch.randn(1, 1)))
|
|
weights = Variable(cast(torch.randn(1, 1)))
|
|
bias = Variable(cast(torch.randn(1, 1)))
|
|
|
|
# linear AKA addmm without bias is of particular interest
|
|
# because we allocate a zero-filled new variable when we execute,
|
|
# and then *fill* it with the result
|
|
|
|
r1_ = clinear(input, weights)
|
|
with self.assertCompiled(clinear):
|
|
r1 = clinear(r1_, weights)
|
|
r2 = F.linear(F.linear(input, weights), weights)
|
|
|
|
self.assertEqual(r1, r2)
|
|
|
|
def test_unused_input(self):
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(a, b, c):
|
|
return a + b
|
|
|
|
a, b, c = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(3)]
|
|
fn(a, b, c).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(a, b, c).sum().backward()
|
|
|
|
def test_repeated_input(self):
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(a, b):
|
|
return a + b
|
|
|
|
a, b = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(2)]
|
|
fn(a, a).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(a, a).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(a, b).sum().backward()
|
|
self.assertExpected(str(fn.graph_for(a, a)))
|
|
|
|
def test_repeated_output(self):
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(a, b):
|
|
z = a + b
|
|
return z, z
|
|
|
|
a, b = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(2)]
|
|
sum(fn(a, b)).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
sum(fn(a, b)).sum().backward()
|
|
self.assertExpected(str(fn.graph_for(a, b)))
|
|
|
|
def test_re_enter(self):
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(a, b):
|
|
return a + b
|
|
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn2(a, b, c):
|
|
return fn(a, b) + c
|
|
|
|
a, b, c = [Variable(torch.randn(2, 2), requires_grad=True) for _ in range(3)]
|
|
|
|
fn(a, b).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(a, b).sum().backward()
|
|
|
|
fn2(a, b, c).sum().backward()
|
|
with self.assertCompiled(fn2):
|
|
fn2(a, b, c).sum().backward()
|
|
|
|
def test_mini_wlm(self):
|
|
"""Exercise null-edge pruning in the tracer."""
|
|
|
|
@torch.jit.compile
|
|
class MyModel(nn.Module):
|
|
def __init__(self):
|
|
super(MyModel, self).__init__()
|
|
self.encoder = nn.Embedding(2, 2)
|
|
|
|
def forward(self, input, hidden):
|
|
emb = self.encoder(input)
|
|
hidden = hidden.clone() # simulate some RNN operation
|
|
return emb, hidden
|
|
|
|
model = MyModel()
|
|
|
|
x = Variable(torch.LongTensor([[0, 1], [1, 0]]))
|
|
y = Variable(torch.FloatTensor([0]))
|
|
|
|
z, _ = model(x, y)
|
|
z.sum().backward()
|
|
self.assertTrue(model.has_trace_for(x, y))
|
|
|
|
with self.assertCompiled(model):
|
|
z, _ = model(x, y)
|
|
z.sum().backward()
|
|
|
|
def test_module_cast(self):
|
|
"""Compiled modules can be casted to other data types"""
|
|
@torch.jit.compile(nderivs=0)
|
|
class Adder(nn.Module):
|
|
def __init__(self):
|
|
super(Adder, self).__init__()
|
|
self.y = nn.Parameter(torch.randn(2, 2))
|
|
|
|
def forward(self, x):
|
|
return x + self.y
|
|
|
|
x = Variable(torch.randn(2, 2).float())
|
|
# Wrap it in a sequential to make sure it works for submodules
|
|
a = nn.Sequential(Adder()).float()
|
|
|
|
def check_type(caster):
|
|
caster(a)
|
|
a(caster(x))
|
|
with self.assertCompiled(a[0]):
|
|
a(caster(x))
|
|
|
|
check_type(lambda x: x)
|
|
check_type(lambda x: x.double())
|
|
if torch.cuda.is_available():
|
|
check_type(lambda x: x.float().cuda())
|
|
check_type(lambda x: x.double().cuda())
|
|
self.assertEqual(a[0].hits, 4 if torch.cuda.is_available() else 2)
|
|
|
|
# Tracer fails when it receives the same grad variable as multiple input to
|
|
# traced region. The problem is that it's not immediately obvious how to
|
|
# assign multiple inputs to this Variable. It might be possible to solve
|
|
# this using the view mechanism, but this requires some thought.
|
|
# In general, it should be supported, because the user has no control
|
|
# over this (and it's quite common, e.g. the sum call below will pass the same
|
|
# grad variable as both inputs to grad of fn).
|
|
@unittest.skip("Broken - repeated grads trigger an assertion failure.")
|
|
def test_repeated_grad(self):
|
|
@torch.jit.compile
|
|
def fn(x):
|
|
return x * x, x + x
|
|
|
|
x = Variable(torch.randn(5, 5), requires_grad=True)
|
|
# This shouldn't raise!
|
|
sum(fn(x)).sum().backward()
|
|
|
|
def test_input_pruning(self):
|
|
"""Check that stage 1 will return only one value"""
|
|
# One of the inputs doesn't require grad, so it should be pruned
|
|
@torch.jit.compile
|
|
def fn(x, y):
|
|
return x * y, x + y
|
|
|
|
x = Variable(torch.randn(5, 5), requires_grad=True)
|
|
y = Variable(torch.randn(5, 5))
|
|
|
|
out = fn(x, y)
|
|
(out[0] * out[1]).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(x, y)
|
|
self.assertExpected(str(fn.graph_for(x, y)))
|
|
|
|
def test_output_pruning(self):
|
|
"""Check that stage 1 will take one value as an argument"""
|
|
# One of the outputs doesn't require grad, so it should be pruned
|
|
@torch.jit.compile
|
|
def fn(x, y):
|
|
return x * y, y + y
|
|
|
|
x = Variable(torch.randn(5, 5), requires_grad=True)
|
|
y = Variable(torch.randn(5, 5))
|
|
|
|
out = fn(x, y)
|
|
(out[0] * out[1]).sum().backward()
|
|
with self.assertCompiled(fn):
|
|
fn(x, y)
|
|
self.assertExpected(str(fn.graph_for(x, y)))
|
|
|
|
@skipIfNoTorchVision
|
|
def test_alexnet(self):
|
|
return
|
|
x = Variable(torch.randn(10, 3, 224, 224).fill_(1.0), requires_grad=True)
|
|
trace, _ = torch.jit.trace(torchvision.models.AlexNet(), x)
|
|
self.assertExpectedTrace(trace)
|
|
# NB: Purposely NOT testing protobuf export here
|
|
|
|
def test_debug_info(self):
|
|
"""Check that debug info doesn't crash and has some reasonable info"""
|
|
|
|
@torch.jit.compile(nderivs=1)
|
|
def fn(x, y):
|
|
return x * y + x + y
|
|
|
|
x = Variable(torch.randn(5, 5), requires_grad=True)
|
|
y = Variable(torch.randn(5, 5), requires_grad=True)
|
|
|
|
out = fn(x, y)
|
|
|
|
out.sum().backward()
|
|
|
|
for _ in range(0, 100):
|
|
out = fn(x, y)
|
|
info_str = fn.jit_debug_info()
|
|
self.assertTrue("hits: 100" in info_str)
|
|
self.assertTrue("stage 1" in info_str)
|
|
|
|
# Inplace copies don't work with tracer yet.
|
|
# This is actually somewhat important to support correctly
|
|
# as all backwards functions of views are implemented
|
|
# as a zero filled tensor with a gradient fill on the
|
|
# viewed portion.
|
|
@unittest.expectedFailure
|
|
def test_inplace_copy(self):
|
|
x = Variable(torch.randn(4, 4), requires_grad=True)
|
|
|
|
def f(x):
|
|
out = Variable(torch.zeros(x.size()))
|
|
out.copy_(x)
|
|
return out
|
|
|
|
trace, z = torch.jit.trace(f, (x, ), nderivs=0)
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_dce(trace)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_index_trace(self):
|
|
x = Variable(torch.randn(4, 4), requires_grad=True)
|
|
trace, z = torch.jit.trace(lambda x: x[0], (x, ), nderivs=1)
|
|
z.sum().backward()
|
|
torch._C._jit_pass_lint(trace)
|
|
torch._C._jit_pass_dce(trace)
|
|
self.assertExpectedTrace(trace)
|
|
|
|
def test_shared_param(self):
|
|
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(MyModule, self).__init__()
|
|
self.b = self.a = nn.Parameter(torch.randn(2, 2))
|
|
|
|
def forward(self, x):
|
|
return x * self.a + self.b
|
|
|
|
m = MyModule()
|
|
trace, _ = torch.jit.trace(m, (Variable(torch.randn(2, 2)),), nderivs=0)
|
|
self.assertEqual(len(list(trace.graph().inputs())), 2)
|
|
self.assertExpected(str(trace))
|
|
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|