mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Dropout is now eligible for fusion, and generated fused kernels are just as fast as dropout in ATen. Change its lowering in symbolic script so that it can actually be fused. Still special-cased for cuda, because without fusion this lowering is less efficient than current (bernoulli_ * input). Testing is covered by the test case that ailzhang added (test_dropout_cuda). Pull Request resolved: https://github.com/pytorch/pytorch/pull/18375 Differential Revision: D14611938 Pulled By: soumith fbshipit-source-id: 11b18f4784e6c9265e382a8f8deca7add8df3b37
14339 lines
502 KiB
Python
14339 lines
502 KiB
Python
from __future__ import division
|
|
import torch
|
|
import torch.jit
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
import torch.nn.parallel as dp
|
|
import torch.optim as optim
|
|
import torch.cuda
|
|
import torch.jit.quantized
|
|
from contextlib import contextmanager
|
|
from itertools import product, chain
|
|
import torch.jit.frontend
|
|
from torch.autograd import Variable, Function
|
|
from torch.nn import Module
|
|
from torch.autograd.function import traceable
|
|
from torch.testing import assert_allclose
|
|
from torch.onnx import OperatorExportTypes
|
|
from torch._six import inf, PY2, builtins
|
|
from common_utils import TestCase, run_tests, IS_WINDOWS, TEST_WITH_UBSAN, \
|
|
skipIfRocm, skipIfNoLapack, suppress_warnings, load_tests, IS_SANDCASTLE, \
|
|
freeze_rng_state, set_rng_seed
|
|
from common_nn import module_tests, new_module_tests, criterion_tests
|
|
from textwrap import dedent
|
|
from functools import wraps
|
|
import os
|
|
import io
|
|
import itertools
|
|
import sys
|
|
import unittest
|
|
import inspect
|
|
import textwrap
|
|
import numpy as np
|
|
import tempfile
|
|
import shutil
|
|
import warnings
|
|
import math
|
|
import types
|
|
import pickle
|
|
import copy
|
|
|
|
from common_methods_invocations import method_tests as autograd_method_tests
|
|
from common_methods_invocations import create_input, unpack_variables, \
|
|
exclude_tensor_method, non_differentiable, EXCLUDE_GRADCHECK, EXCLUDE_FUNCTIONAL
|
|
from torch.testing import FileCheck
|
|
from torch._C import TensorType, TupleType, FloatType, IntType, \
|
|
ListType, StringType, DictType
|
|
from copy import deepcopy
|
|
import random
|
|
from typing import List, Dict, Optional, Tuple
|
|
from torch.jit.frontend import NotSupportedError
|
|
from torch.jit import BatchTensor
|
|
from torch import Tensor
|
|
from torch.jit.annotations import BroadcastingList2, BroadcastingList3
|
|
|
|
# For testing truediv in python 2
|
|
from test_module.future_div import div_int_future, div_float_future
|
|
from test_module.no_future_div import div_int_nofuture, div_float_nofuture
|
|
|
|
|
|
# load_tests from common_utils is used to automatically filter tests for
|
|
# sharding on sandcastle. This line silences flake warnings
|
|
load_tests = load_tests
|
|
|
|
try:
|
|
import torchvision
|
|
HAS_TORCHVISION = True
|
|
except ImportError:
|
|
HAS_TORCHVISION = False
|
|
|
|
|
|
skipIfNoTorchVision = unittest.skipIf(not HAS_TORCHVISION, "no torchvision")
|
|
|
|
RUN_CUDA = torch.cuda.is_available()
|
|
RUN_CUDA_HALF = RUN_CUDA
|
|
if torch.cuda.is_available():
|
|
CUDA_VERSION = torch._C._cuda_getCompiledVersion()
|
|
for d in range(torch.cuda.device_count()):
|
|
major = torch.cuda.get_device_capability(d)[0]
|
|
if (CUDA_VERSION < 8000 and major >= 6) or (CUDA_VERSION < 9000 and major >= 7):
|
|
RUN_CUDA = False
|
|
if (CUDA_VERSION < 9000 or major < 6):
|
|
RUN_CUDA_HALF = False
|
|
|
|
RUN_CUDA_MULTI_GPU = RUN_CUDA and torch.cuda.device_count() > 1
|
|
|
|
PY35 = sys.version_info >= (3, 5)
|
|
WINDOWS = sys.platform == 'win32'
|
|
|
|
|
|
if WINDOWS:
|
|
@contextmanager
|
|
def TemporaryFileName():
|
|
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
|
|
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
|
|
# close the file after creation and try to remove it manually
|
|
f = tempfile.NamedTemporaryFile(delete=False)
|
|
try:
|
|
f.close()
|
|
yield f.name
|
|
finally:
|
|
os.unlink(f.name)
|
|
else:
|
|
@contextmanager # noqa: T484
|
|
def TemporaryFileName():
|
|
with tempfile.NamedTemporaryFile() as f:
|
|
yield f.name
|
|
|
|
|
|
def LSTMCellF(input, hx, cx, *params):
|
|
return LSTMCell(input, (hx, cx), *params)
|
|
|
|
|
|
def LSTMCell(input, hidden, w_ih, w_hh, b_ih=None, b_hh=None):
|
|
hx, cx = hidden
|
|
gates = F.linear(input, w_ih, b_ih) + F.linear(hx, w_hh, b_hh)
|
|
|
|
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
|
|
ingate = torch.sigmoid(ingate)
|
|
forgetgate = torch.sigmoid(forgetgate)
|
|
cellgate = torch.tanh(cellgate)
|
|
outgate = torch.sigmoid(outgate)
|
|
|
|
cy = (forgetgate * cx) + (ingate * cellgate)
|
|
hy = outgate * torch.tanh(cy)
|
|
return hy, cy
|
|
|
|
|
|
def LSTMCellC(*args, **kwargs):
|
|
hy, cy = LSTMCellF(*args, **kwargs)
|
|
return torch.cat((hy, cy))
|
|
|
|
|
|
def LSTMCellS(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
|
|
gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh
|
|
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
|
|
ingate = torch.sigmoid(ingate)
|
|
forgetgate = torch.sigmoid(forgetgate)
|
|
cellgate = torch.tanh(cellgate)
|
|
outgate = torch.sigmoid(outgate)
|
|
cy = (forgetgate * cx) + (ingate * cellgate)
|
|
hy = outgate * torch.tanh(cy)
|
|
return hy, cy
|
|
|
|
|
|
# Code reference: https://github.com/pytorch/translate/blob/master/pytorch_translate/rnn_cell.py#L27:44
|
|
def MiLSTMCell(x, hx, cx, w_ih, w_hh, alpha, beta_i, beta_h, bias):
|
|
Wx = x.mm(w_ih.t())
|
|
Uz = hx.mm(w_hh.t())
|
|
# Section 2.1 in https://arxiv.org/pdf/1606.06630.pdf
|
|
gates = alpha * Wx * Uz + beta_i * Wx + beta_h * Uz + bias
|
|
# Same as LSTMCell after this point
|
|
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
|
|
ingate = ingate.sigmoid()
|
|
forgetgate = forgetgate.sigmoid()
|
|
cellgate = cellgate.tanh()
|
|
outgate = outgate.sigmoid()
|
|
cy = (forgetgate * cx) + (ingate * cellgate)
|
|
hy = outgate * cy.tanh()
|
|
return hy, cy
|
|
|
|
|
|
def canonical(graph):
|
|
return str(torch._C._jit_pass_canonicalize(graph))
|
|
|
|
|
|
def get_lstm_inputs(device, training=False, seq_length=None):
|
|
input_shape = (3, 10) if seq_length is None else (seq_length, 3, 10)
|
|
input = torch.randn(*input_shape, dtype=torch.float, device=device, requires_grad=training)
|
|
hx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
|
|
cx = torch.randn(3, 20, dtype=torch.float, device=device, requires_grad=training)
|
|
module = nn.LSTMCell(10, 20).to(device, torch.float) # Just to allocate weights with correct sizes
|
|
if training:
|
|
params = tuple(module.parameters())
|
|
else:
|
|
params = tuple(p.requires_grad_(False) for p in module.parameters())
|
|
return (input, hx, cx) + params
|
|
|
|
|
|
def get_milstm_inputs(device, training=False):
|
|
minibatch = 3
|
|
input_size = 10
|
|
hidden_size = 20
|
|
x = torch.randn(minibatch, input_size, device=device, dtype=torch.float)
|
|
hx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)
|
|
cx = torch.randn(minibatch, hidden_size, device=device, dtype=torch.float)
|
|
|
|
ih = torch.randn(4 * hidden_size, input_size, device=device, dtype=torch.float, requires_grad=training)
|
|
hh = torch.randn(4 * hidden_size, hidden_size, device=device, dtype=torch.float, requires_grad=training)
|
|
alpha = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
|
|
ibeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
|
|
hbeta = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
|
|
bias = torch.randn(4 * hidden_size, dtype=torch.float, device=device, requires_grad=training)
|
|
return x, hx, cx, ih, hh, alpha, ibeta, hbeta, bias
|
|
|
|
|
|
def get_fn(file_name, script_path):
|
|
import importlib.util
|
|
spec = importlib.util.spec_from_file_location(file_name, script_path)
|
|
module = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(module)
|
|
fn = module.fn
|
|
return fn
|
|
|
|
|
|
def get_execution_plan(graph_executor_state):
|
|
execution_plans = list(graph_executor_state.execution_plans.values())
|
|
num_plans = len(execution_plans)
|
|
if num_plans != 1:
|
|
raise RuntimeError('This test assumes this GraphExecutor should '
|
|
'only have one execution plan, got: {}'.format(num_plans))
|
|
return execution_plans[0]
|
|
|
|
|
|
def get_grad_executor(plan_state, diff_graph_idx=None):
|
|
if diff_graph_idx is None:
|
|
nodes = list(plan_state.graph.nodes())
|
|
if len(nodes) == 1 or (len(nodes) == 2 and nodes[1].kind() == "prim::TupleConstruct"):
|
|
pass
|
|
else:
|
|
raise RuntimeError("Can't get a grad_executor for a non-differentiable graph")
|
|
grad_executors = list(plan_state.code.grad_executors())
|
|
return grad_executors[diff_graph_idx or 0]
|
|
|
|
|
|
def backward_graph(script_module, diff_graph_idx=None):
|
|
if not isinstance(script_module, torch.jit.ScriptModule):
|
|
raise RuntimeError('Expected ScriptModule')
|
|
ge_state = script_module.get_debug_state()
|
|
fwd_plan = get_execution_plan(ge_state)
|
|
grad_executor = get_grad_executor(fwd_plan, diff_graph_idx=diff_graph_idx)
|
|
bwd_plan = get_execution_plan(grad_executor.get_debug_state())
|
|
# Running JIT passes requires that we own the graph (with a shared_ptr).
|
|
# The debug state struct does not own its graph so we make a copy of it.
|
|
return bwd_plan.graph.copy()
|
|
|
|
|
|
# make it easy to quicky define/trace a function for these tests
|
|
def _trace(*args, **kwargs):
|
|
def wrapper(func):
|
|
return torch.jit.trace(func, args, **kwargs)
|
|
return wrapper
|
|
|
|
|
|
def enable_cpu_fuser(fn):
|
|
def wrapper(*args, **kwargs):
|
|
torch._C._jit_override_can_fuse_on_cpu(True)
|
|
try:
|
|
fn(*args, **kwargs)
|
|
finally:
|
|
torch._C._jit_override_can_fuse_on_cpu(False)
|
|
return wrapper
|
|
|
|
|
|
class JitTestCase(TestCase):
|
|
_do_cuda_memory_leak_check = True
|
|
_restored_warnings = False
|
|
|
|
def setUp(self):
|
|
super(JitTestCase, self).setUp()
|
|
# unittest overrides all warning filters and forces all of them to show up
|
|
# after we install our own to silence those coming from inside PyTorch.
|
|
# This will ensure that our filter still takes precedence.
|
|
if not JitTestCase._restored_warnings:
|
|
torch.jit.TracerWarning.ignore_lib_warnings()
|
|
JitTestCase._restored_warnings = True
|
|
torch._C._jit_set_emit_module_hook(self.emitModuleHook)
|
|
|
|
def tearDown(self):
|
|
super(JitTestCase, self).tearDown()
|
|
# needs to be cleared because python might be unloaded before
|
|
# the callback gets destucted
|
|
torch._C._jit_set_emit_module_hook(None)
|
|
torch._C._jit_clear_class_registry()
|
|
|
|
@contextmanager
|
|
def disableModuleHook(self):
|
|
torch._C._jit_set_emit_module_hook(None)
|
|
yield None
|
|
torch._C._jit_set_emit_module_hook(self.emitModuleHook)
|
|
|
|
def emitModuleHook(self, module):
|
|
def copy_structure_and_params(m):
|
|
c = torch.jit.ScriptModule()
|
|
for name, v in m._get_parameters():
|
|
c._register_parameter(name, v, False)
|
|
for name, the_type, v in m._get_attributes():
|
|
c._register_attribute(name, the_type, v)
|
|
for name, s in m._get_modules():
|
|
c._register_module(name, copy_structure_and_params(s))
|
|
return c
|
|
|
|
# disable the hook while we parse code, otherwise we will re-enter the hook
|
|
with self.disableModuleHook():
|
|
try:
|
|
pp, constant_table = module._python_print()
|
|
except RuntimeError as e:
|
|
se = str(e)
|
|
if "could not export python function" not in se and \
|
|
"closures are not exportable" not in se:
|
|
raise
|
|
else:
|
|
return
|
|
ppv = "op_version_set = 0\n{}".format(pp)
|
|
sm = copy_structure_and_params(module)
|
|
torch._C._jit_import_methods(sm, ppv, constant_table)
|
|
pp2, _ = sm._python_print()
|
|
if pp != pp2:
|
|
self.assertMultiLineEqual(pp, pp2)
|
|
|
|
def getExportImportCopy(self, m, also_test_file=True, map_location=None):
|
|
buffer = io.BytesIO()
|
|
torch.jit.save(m, buffer)
|
|
buffer.seek(0)
|
|
imported = torch.jit.load(buffer, map_location=map_location)
|
|
|
|
if not also_test_file:
|
|
return imported
|
|
|
|
with TemporaryFileName() as fname:
|
|
imported.save(fname)
|
|
return torch.jit.load(fname, map_location=map_location)
|
|
|
|
def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
|
|
buffer = io.BytesIO()
|
|
m.apply(lambda s: s._pack() if s._has_method('_pack') else None)
|
|
torch.jit.save(m, buffer)
|
|
m.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
|
|
buffer.seek(0)
|
|
imported = torch.jit.load(buffer, map_location=map_location)
|
|
imported.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
|
|
|
|
if not also_test_file:
|
|
return imported
|
|
|
|
# Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
|
|
# opens the file, and it cannot be opened multiple times in Windows. To support Windows,
|
|
# close the file after creation and try to remove it manually
|
|
f = tempfile.NamedTemporaryFile(delete=False)
|
|
try:
|
|
f.close()
|
|
imported.save(f.name)
|
|
result = torch.jit.load(f.name, map_location=map_location)
|
|
finally:
|
|
os.unlink(f.name)
|
|
|
|
result.apply(lambda s: s._unpack() if s._has_method('_unpack') else None)
|
|
return result
|
|
|
|
def assertGraphContains(self, graph, kind):
|
|
self.assertTrue(any(n.kind() == kind for n in graph.nodes()))
|
|
|
|
def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
|
|
def perform_assert(graph, kind, actual, expected, consider_subgraphs):
|
|
if actual == expected:
|
|
return
|
|
subgraph = 'including' if consider_subgraphs else 'excluding'
|
|
raise AssertionError(
|
|
'{}\nError: graph contains {} {} nodes ({} subgraphs) but expected {}'.format(
|
|
graph, actual, kind, subgraph, expected))
|
|
|
|
if consider_subgraphs:
|
|
strgraph = str(graph)
|
|
count = strgraph.count(kind) - strgraph.count('with {}'.format(kind))
|
|
perform_assert(graph, kind, count, num_kind_nodes,
|
|
consider_subgraphs)
|
|
return
|
|
|
|
nodes = [node for node in graph.nodes()
|
|
if node.kind() == kind]
|
|
perform_assert(graph, kind, len(nodes), num_kind_nodes,
|
|
consider_subgraphs)
|
|
|
|
def assertExpectedONNXGraph(self, trace, *args, **kwargs):
|
|
torch.onnx._optimize_trace(trace, operator_export_type=OperatorExportTypes.ONNX)
|
|
self.assertExpectedGraph(trace, *args, **kwargs)
|
|
|
|
def assertExpectedGraph(self, trace, *args, **kwargs):
|
|
if isinstance(trace, torch._C.Graph):
|
|
graph = trace
|
|
else:
|
|
graph = trace.graph()
|
|
|
|
torch._C._jit_pass_lint(graph)
|
|
torch._C._jit_pass_dce(graph)
|
|
torch._C._jit_pass_lint(graph)
|
|
graph = torch._C._jit_pass_canonicalize(graph)
|
|
torch._C._jit_pass_lint(graph)
|
|
self.assertExpected(str(graph), *args, **kwargs)
|
|
|
|
def run_pass(self, name, trace):
|
|
if isinstance(trace, torch._C.Graph):
|
|
graph = trace
|
|
set_graph = False
|
|
else:
|
|
set_graph = True
|
|
graph = trace.graph()
|
|
|
|
torch._C._jit_pass_lint(graph)
|
|
result = getattr(torch._C, '_jit_pass_' + name)(graph)
|
|
if result is not None:
|
|
graph = result
|
|
torch._C._jit_pass_lint(graph)
|
|
|
|
if set_graph:
|
|
trace.set_graph(graph)
|
|
return graph
|
|
|
|
def checkScript(self,
|
|
script,
|
|
inputs,
|
|
optimize=True,
|
|
outputs=None,
|
|
name='func',
|
|
capture_output=False,
|
|
frames_up=1,
|
|
check_expected=False):
|
|
if isinstance(script, str):
|
|
cu = torch.jit.CompilationUnit(script, optimize, _frames_up=frames_up)
|
|
ge = getattr(cu, name)
|
|
else:
|
|
if capture_output:
|
|
with self.capture_stdout() as captured:
|
|
outputs = script(*inputs)
|
|
else:
|
|
outputs = script(*inputs)
|
|
# Check the string frontend first
|
|
source = textwrap.dedent(inspect.getsource(script))
|
|
self.checkScript(
|
|
source,
|
|
inputs,
|
|
optimize,
|
|
outputs,
|
|
script.__name__,
|
|
capture_output,
|
|
frames_up=2,
|
|
check_expected=check_expected)
|
|
# Continue checking the Python frontend
|
|
ge = torch.jit.script(script, optimize, _frames_up=1)
|
|
|
|
if capture_output:
|
|
with self.capture_stdout() as captured:
|
|
outputs_ge = ge(*inputs)
|
|
if not WINDOWS:
|
|
self.assertExpected(captured[0], subname='stdout')
|
|
else:
|
|
outputs_ge = ge(*inputs)
|
|
self.assertEqual(outputs, outputs_ge)
|
|
|
|
if check_expected:
|
|
self.assertExpectedGraph(ge.graph)
|
|
|
|
return ge
|
|
|
|
def checkTrace(self, func, reference_tensors, input_tensors=None,
|
|
optimize=True, drop=None, allow_unused=False, verbose=False,
|
|
inputs_require_grads=True, check_tolerance=1e-5, export_import=True,
|
|
_force_outplace=False):
|
|
# TODO: check gradients for parameters, not just inputs
|
|
def allSum(vs):
|
|
# drop allows us to remove some values from ever being used
|
|
# to test unused outputs
|
|
if drop is not None:
|
|
vs = vs[:-drop]
|
|
# we don't want all the grad for all the outputs to be the same
|
|
# so we multiply each by a constant
|
|
return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None)
|
|
if input_tensors is None:
|
|
input_tensors = reference_tensors
|
|
|
|
nograd_inputs = reference_tensors
|
|
if inputs_require_grads:
|
|
recording_inputs = [t.clone().requires_grad_() for t in reference_tensors]
|
|
else:
|
|
recording_inputs = reference_tensors
|
|
|
|
if isinstance(func, torch._C.Graph):
|
|
ge = torch._C.GraphExecutor(func, optimize)
|
|
else:
|
|
ge = torch.jit.trace(func, input_tensors, optimize=optimize, check_tolerance=check_tolerance,
|
|
_force_outplace=_force_outplace)
|
|
|
|
if export_import:
|
|
ge = self.getExportImportCopy(ge)
|
|
|
|
if verbose:
|
|
print(ge.graph)
|
|
|
|
# test no gradients case
|
|
outputs = func(*nograd_inputs)
|
|
outputs_ge = ge(*nograd_inputs)
|
|
self.assertEqual(outputs, outputs_ge)
|
|
|
|
# test single grad case
|
|
outputs = func(*recording_inputs)
|
|
if inputs_require_grads:
|
|
grads = torch.autograd.grad(allSum(outputs), recording_inputs,
|
|
allow_unused=allow_unused)
|
|
|
|
outputs_ge = ge(*recording_inputs)
|
|
if inputs_require_grads:
|
|
grads_ge = torch.autograd.grad(allSum(outputs_ge), recording_inputs,
|
|
allow_unused=allow_unused)
|
|
self.assertEqual(outputs, outputs_ge)
|
|
if inputs_require_grads:
|
|
self.assertEqual(grads, grads_ge)
|
|
|
|
# test the grad grad case
|
|
|
|
outputs = func(*recording_inputs)
|
|
l1 = allSum(outputs)
|
|
if inputs_require_grads:
|
|
grads = torch.autograd.grad(l1, recording_inputs, create_graph=True,
|
|
allow_unused=allow_unused)
|
|
if inputs_require_grads:
|
|
l2 = (allSum(grads) * l1)
|
|
grads2 = torch.autograd.grad(l2, recording_inputs, allow_unused=allow_unused)
|
|
|
|
if inputs_require_grads:
|
|
recording_inputs = [Variable(t, requires_grad=True)
|
|
for t in reference_tensors]
|
|
|
|
outputs_ge = ge(*recording_inputs)
|
|
l1_ge = allSum(outputs_ge)
|
|
if inputs_require_grads:
|
|
grads_ge = torch.autograd.grad(
|
|
l1_ge, recording_inputs, create_graph=True, allow_unused=allow_unused)
|
|
|
|
if inputs_require_grads:
|
|
l2_ge = (allSum(grads_ge) * l1_ge)
|
|
grads2_ge = torch.autograd.grad(l2_ge, recording_inputs, allow_unused=allow_unused)
|
|
|
|
self.assertEqual(outputs, outputs_ge)
|
|
if inputs_require_grads:
|
|
self.assertEqual(grads, grads_ge)
|
|
for g2, g2_ge in zip(grads2, grads2_ge):
|
|
if g2 is None and g2_ge is None:
|
|
continue
|
|
self.assertTrue(torch.allclose(g2, g2_ge, atol=8e-4, rtol=8e-4))
|
|
|
|
return ge
|
|
|
|
def createScriptModuleFromGraph(self, trace):
|
|
graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
|
|
m = torch.jit.ScriptModule()
|
|
m._create_method_from_graph("forward", graph)
|
|
return m
|
|
|
|
def assertExportImport(self, trace, inputs):
|
|
m = self.createScriptModuleFromGraph(trace)
|
|
self.assertExportImportModule(m, inputs)
|
|
|
|
def assertExportImportModule(self, m, inputs):
|
|
m_import = self.getExportImportCopy(m)
|
|
self.assertEqual(self.runAndSaveRNG(m.forward, inputs),
|
|
self.runAndSaveRNG(m_import.forward, inputs))
|
|
|
|
def runAndSaveRNG(self, func, inputs, kwargs=None):
|
|
kwargs = kwargs if kwargs else {}
|
|
with freeze_rng_state():
|
|
results = func(*inputs, **kwargs)
|
|
return results
|
|
|
|
|
|
# has to be at top level or Pickle complains
|
|
class FooToPickle(torch.nn.Module):
|
|
def __init__(self):
|
|
super(FooToPickle, self).__init__()
|
|
self.bar = torch.jit.ScriptModule()
|
|
|
|
|
|
class TestJit(JitTestCase):
|
|
|
|
@unittest.skip("Requires a lot of RAM")
|
|
def test_big(self):
|
|
m = torch.jit.ScriptModule()
|
|
gig = int(1024 * 1024 * 1024 / 4)
|
|
# a small tensor in the first 4GB
|
|
m.v0 = nn.Parameter(torch.full((2,), 1, dtype=torch.float))
|
|
# a large tensor in the first 4GB that ends outside of it
|
|
m.v1 = nn.Parameter(torch.full((5, gig), 2, dtype=torch.float))
|
|
# a small tensor in >4GB space
|
|
m.v2 = nn.Parameter(torch.full((2,), 3, dtype=torch.float))
|
|
# s large tensor in the > 4GB space
|
|
m.v3 = nn.Parameter(torch.full((5, gig), 4, dtype=torch.float))
|
|
|
|
m2 = self.getExportImportCopy(m)
|
|
|
|
self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
|
|
|
|
def test_simple(self):
|
|
x = torch.tensor([0.4], requires_grad=True)
|
|
y = torch.tensor([0.7], requires_grad=True)
|
|
|
|
def f(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
self.checkTrace(f, (x, y))
|
|
|
|
def test_restore_device(self):
|
|
# main purpose is checking map_location works
|
|
m = torch.jit.ScriptModule()
|
|
cpu_device_str = 'cpu'
|
|
m.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
|
|
device=cpu_device_str))
|
|
m.register_buffer('b0', torch.tensor([0.9], dtype=torch.float,
|
|
device=cpu_device_str))
|
|
m2 = self.getExportImportCopy(m)
|
|
self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
|
|
self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
|
|
self.assertFalse(m2.p0.is_cuda)
|
|
self.assertFalse(m2.b0.is_cuda)
|
|
|
|
def test_model_save_error(self):
|
|
with TemporaryFileName() as fname:
|
|
with self.assertRaisesRegex(pickle.PickleError, "not supported"):
|
|
torch.save(FooToPickle(), fname)
|
|
|
|
def test_single_tuple_trace(self):
|
|
x = torch.tensor(2.)
|
|
|
|
def f2(x):
|
|
return (x,)
|
|
jit_f2 = torch.jit.trace(f2, x)
|
|
assert f2(x) == jit_f2(x) # fails
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
|
|
def test_restore_device_cuda(self):
|
|
class MyModule(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(MyModule, self).__init__(False)
|
|
self.register_buffer('b0', torch.randn(1, 3))
|
|
self.p0 = nn.Parameter(torch.randn(2, 3))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x + self.b0 + self.p0
|
|
|
|
m = MyModule()
|
|
m.cuda(torch.cuda.device_count() - 1)
|
|
cuda_device_str = 'cuda:' + str(torch.cuda.device_count() - 1)
|
|
|
|
self.assertTrue(m.p0.is_cuda)
|
|
self.assertTrue(m.b0.is_cuda)
|
|
|
|
# restore to the saved devices
|
|
m2 = self.getExportImportCopy(m)
|
|
self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
|
|
self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
|
|
self.assertEqual(str(m2.p0.device), cuda_device_str)
|
|
self.assertEqual(str(m2.b0.device), cuda_device_str)
|
|
|
|
# restore all to cpu using string
|
|
cpu_device_str = 'cpu'
|
|
m3 = self.getExportImportCopy(m, map_location=cpu_device_str)
|
|
self.assertEqual(str(m3.p0.device), cpu_device_str)
|
|
self.assertEqual(str(m3.b0.device), cpu_device_str)
|
|
|
|
# restore all to first gpu using device
|
|
m4 = self.getExportImportCopy(
|
|
m3, map_location=torch.device('cuda:0'))
|
|
self.assertEqual(str(m4.p0.device), 'cuda:0')
|
|
self.assertEqual(str(m4.b0.device), 'cuda:0')
|
|
|
|
# compute and compare the results
|
|
input = torch.rand(2, 3).cuda(torch.cuda.device_count() - 1)
|
|
origin_result = m(input)
|
|
self.assertEqual(origin_result, m2(input))
|
|
self.assertEqual(origin_result, m3(input.cpu()))
|
|
self.assertEqual(origin_result, m4(input.cuda(0)))
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "restore device requires CUDA")
|
|
def test_restore_shared_storage_on_cuda(self):
|
|
whole_tensor = torch.randn(4, 5, dtype=torch.float, device='cpu')
|
|
m = torch.jit.ScriptModule()
|
|
m.p0 = nn.Parameter(whole_tensor.narrow(0, 0, 1))
|
|
m.register_buffer('b0', whole_tensor.narrow(0, 3, 1))
|
|
m2 = self.getExportImportCopy(m, map_location=torch.device('cuda:0'))
|
|
self.assertEqual(tuple(m.parameters()), tuple(m2.parameters()))
|
|
self.assertEqual(tuple(m.buffers()), tuple(m2.buffers()))
|
|
self.assertTrue(m2.p0.is_cuda)
|
|
self.assertTrue(m2.b0.is_cuda)
|
|
self.assertTrue(m2.p0.is_shared())
|
|
self.assertTrue(m2.b0.is_shared())
|
|
self.assertEqual(m2.b0.storage().data_ptr(), m2.p0.storage().data_ptr())
|
|
|
|
def test_typeas_trace_check(self):
|
|
a = torch.tensor([0.4], requires_grad=True)
|
|
b = torch.tensor([0.7], requires_grad=True)
|
|
|
|
def f(x, y):
|
|
return x.type_as(y)
|
|
|
|
trace = torch.jit.trace(f, (a, b))
|
|
|
|
def test_peephole(self):
|
|
a = torch.tensor([0.4])
|
|
b = torch.tensor([0.7])
|
|
c = torch.tensor([0], dtype=torch.int32)
|
|
|
|
def f(x, y):
|
|
return x.type_as(y)
|
|
|
|
tf = torch.jit.trace(f, (a, b))
|
|
FileCheck().check("type_as").run(str(tf.graph))
|
|
self.run_pass('peephole', tf.graph)
|
|
FileCheck().check_not("type_as").run(str(tf.graph))
|
|
tf2 = torch.jit.trace(f, (a, c))
|
|
s = str(tf2.graph)
|
|
self.run_pass('peephole', tf2.graph)
|
|
self.assertEqual(s, str(s))
|
|
|
|
def test_peephole_dynamic(self):
|
|
def f(x, y):
|
|
return x.type_as(y)
|
|
|
|
fn = torch.jit.script(f)
|
|
s = str(fn.graph)
|
|
torch._C._jit_pass_peephole(fn.graph)
|
|
self.assertEqual(s, str(fn.graph))
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
|
|
def test_peephole_cuda(self):
|
|
a = torch.tensor([0.4], device='cpu')
|
|
b = torch.tensor([0.7], device='cuda')
|
|
c = torch.tensor([0.7], device='cuda')
|
|
|
|
def f(x, y):
|
|
return x.type_as(y)
|
|
|
|
trace = torch.jit.trace(f, (a, c))
|
|
s = str(trace.graph)
|
|
self.run_pass('peephole', trace.graph)
|
|
self.assertEqual(s, str(trace.graph))
|
|
trace = torch.jit.trace(f, (b, c))
|
|
self.run_pass('peephole', trace.graph)
|
|
self.assertTrue(len(list(trace.graph.nodes())) == 0)
|
|
|
|
def test_index(self):
|
|
x = torch.tensor([0.4], requires_grad=True)
|
|
y = torch.tensor([0], dtype=torch.int64)
|
|
|
|
def fn(x, y):
|
|
return x[y]
|
|
|
|
fn_traced = torch.jit.trace(fn, (x, y,))
|
|
|
|
self.assertEqual(fn(x, y), fn_traced(x, y))
|
|
|
|
def test_disabled(self):
|
|
torch.jit._enabled = False
|
|
try:
|
|
def f(x, y):
|
|
return x + y
|
|
|
|
self.assertIs(torch.jit.trace(f, (torch.randn(2, 2), torch.randn(2, 2))), f)
|
|
self.assertIs(torch.jit.script(f), f)
|
|
|
|
class MyModule(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def method(self, x):
|
|
return x
|
|
|
|
# XXX: Unfortunately ScriptModule won't simply become Module now,
|
|
# because that requires disabling the JIT at startup time, which
|
|
# we can't do in here.
|
|
# We need to or those two conditions to make it work with all versions of Python
|
|
self.assertTrue(inspect.ismethod(MyModule.method) or inspect.isfunction(MyModule.method))
|
|
finally:
|
|
torch.jit._enabled = True
|
|
|
|
def test_train_eval(self):
|
|
class Sub(nn.Module):
|
|
def forward(self, input):
|
|
if self.training:
|
|
return input
|
|
else:
|
|
return -input
|
|
|
|
class MyModule(torch.jit.ScriptModule):
|
|
def __init__(self, module):
|
|
super(MyModule, self).__init__()
|
|
self.module = module
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
return self.module(input) + 1
|
|
|
|
m = MyModule(Sub())
|
|
input = torch.rand(3, 4)
|
|
self.assertEqual(input + 1, m(input))
|
|
m.eval()
|
|
self.assertEqual(-input + 1, m(input))
|
|
|
|
# test batchnorm and dropout train/eval
|
|
input = torch.randn(6, 10)
|
|
batchnorm = nn.BatchNorm1d(10)
|
|
dropout = nn.Dropout(p=0.2)
|
|
|
|
m_batchnorm = MyModule(batchnorm)
|
|
self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
|
|
batchnorm.eval()
|
|
m_batchnorm.eval()
|
|
self.assertEqual(batchnorm(input) + 1, m_batchnorm(input))
|
|
|
|
m_dropout = MyModule(dropout)
|
|
dropout.eval()
|
|
m_dropout.eval()
|
|
self.assertEqual(dropout(input) + 1, m_dropout(input))
|
|
|
|
def test_diff_subgraph_clones_constants(self):
|
|
@torch.jit.script
|
|
def f(x, y):
|
|
return x + x + y + x + y + x + y + x + y + x
|
|
|
|
def count_constants(graph):
|
|
return sum(node.kind() == 'prim::Constant' for node in graph.nodes())
|
|
|
|
graph = f.graph.copy()
|
|
self.run_pass('cse', graph)
|
|
self.run_pass('create_autodiff_subgraphs', graph)
|
|
nodes = list(graph.nodes())
|
|
self.assertEqual(count_constants(graph), 1)
|
|
self.assertEqual(count_constants(nodes[1].g('Subgraph')), 1)
|
|
|
|
# Backwards tracing was broken for indexing by a constant,
|
|
# because it's internally implemented using as_strided,
|
|
# and we attempted to trace its derivative (which is not
|
|
# currently supported.) It currently works because
|
|
# slice() is now not marked as traceable.
|
|
def test_index_constant(self):
|
|
x = torch.tensor([0.4], requires_grad=True)
|
|
|
|
def fn(x):
|
|
return x[0]
|
|
|
|
def run(f):
|
|
y = f(x)
|
|
grad = torch.autograd.grad(y, x)[0].clone()
|
|
return y, grad
|
|
|
|
traced_fn = torch.jit.trace(fn, torch.ones(1))
|
|
self.assertEqual(run(fn), run(traced_fn))
|
|
|
|
def test_scopes(self):
|
|
x = torch.tensor([0.4], requires_grad=True)
|
|
y = torch.tensor([0.7], requires_grad=True)
|
|
|
|
def f(x, y):
|
|
out = x + y
|
|
with torch.jit.scope('Foo'):
|
|
out = x * out
|
|
with torch.jit.scope('Bar'):
|
|
out = torch.tanh(out)
|
|
out = torch.sigmoid(out)
|
|
return out
|
|
|
|
self.checkTrace(f, (x, y))
|
|
|
|
def test_scopes_intermediate_node(self):
|
|
class Net(nn.Module):
|
|
def forward(self, x):
|
|
return F.log_softmax(x, dim=0)
|
|
|
|
net = Net()
|
|
t = torch.ones(2, requires_grad=True)
|
|
trace, outputs, inputs = torch.jit.get_trace_graph(net, (t,), return_inputs=True)
|
|
self.assertEqual(outputs, self.createScriptModuleFromGraph(trace)(*inputs))
|
|
self.assertExportImport(trace, (t,))
|
|
torch.onnx._optimize_trace(trace, operator_export_type=OperatorExportTypes.ONNX)
|
|
FileCheck().check("onnx::LogSoftmax").check("scope: Net").run(str(trace))
|
|
|
|
def test_scopes_identity_node(self):
|
|
|
|
class Net(nn.Module):
|
|
|
|
def __init__(self):
|
|
super(Net, self).__init__()
|
|
self.features = nn.Sequential(
|
|
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),
|
|
nn.ReLU(inplace=True),
|
|
nn.MaxPool2d(kernel_size=3, stride=2),
|
|
)
|
|
|
|
def forward(self, x):
|
|
x = self.features(x)
|
|
return x
|
|
|
|
model = Net()
|
|
|
|
t = torch.ones(1, 3, 227, 227, requires_grad=True)
|
|
|
|
with torch.onnx.set_training(model, False):
|
|
trace, _ = torch.jit.get_trace_graph(model, (t,))
|
|
|
|
self.assertExportImport(trace, (t,) + tuple(model.parameters()))
|
|
torch.onnx._optimize_trace(trace, operator_export_type=OperatorExportTypes.ONNX)
|
|
FileCheck().check("Net/Sequential[features]/Conv2d[0]").check("ReLU").check("MaxPool").run(str(trace))
|
|
|
|
def test_canonicalize_tensor_iterator(self):
|
|
x = torch.randn(4, 4)
|
|
|
|
def f(x):
|
|
x = x + 2
|
|
x = x - 4
|
|
x = x * 6
|
|
x = x / 8
|
|
return x
|
|
|
|
traced = torch.jit.trace(f, (x,))
|
|
f(x)
|
|
graph = traced.graph_for(x)
|
|
# There should be 4 int constants for the right sides of operators, plus one
|
|
# for the alpha argument for add and sub
|
|
self.assertTrue(str(traced.graph_for(x)).count(': int = prim::Constant') == 5)
|
|
|
|
# TODO: adapt this test to check that GraphExecutor treats them differently
|
|
@unittest.skip("Need to be adjusted to Graph Executor")
|
|
def test_arg_configurations(self):
|
|
"""Different arg configurations should trigger different traces"""
|
|
x = Variable(torch.FloatTensor(4, 4).uniform_())
|
|
x_double = Variable(x.data.double())
|
|
x_grad = Variable(x.data.clone(), requires_grad=True)
|
|
y = Variable(torch.randn(4))
|
|
|
|
configurations = [
|
|
(x,),
|
|
(x_double,),
|
|
(x_grad,),
|
|
(y,),
|
|
([x, x],),
|
|
([x, y],),
|
|
]
|
|
if torch.cuda.is_available():
|
|
x_cuda = Variable(x.data.cuda())
|
|
configurations += [
|
|
(x_cuda,),
|
|
([x, x_cuda],),
|
|
([x_cuda, x],),
|
|
([[x_cuda, x]],),
|
|
]
|
|
if torch.cuda.device_count() > 1:
|
|
x_cuda_1 = Variable(x.data.cuda(1))
|
|
configurations += [
|
|
(x_cuda_1,),
|
|
([x_cuda, x_cuda_1],),
|
|
]
|
|
|
|
@torch.jit.compile(nderivs=0)
|
|
def fn(*args):
|
|
in_vars, _ = torch._C._jit_flatten(args)
|
|
return in_vars[0] + 1
|
|
|
|
for i, config in enumerate(configurations):
|
|
self.assertFalse(fn.has_trace_for(*config))
|
|
fn(*config)
|
|
self.assertTrue(fn.has_trace_for(*config))
|
|
for unk_config in configurations[i + 1:]:
|
|
self.assertFalse(fn.has_trace_for(*unk_config))
|
|
self.assertEqual(fn.hits, 0)
|
|
|
|
def test_cse(self):
|
|
x = torch.tensor([0.4, 0.3], requires_grad=True)
|
|
y = torch.tensor([0.7, 0.5], requires_grad=True)
|
|
|
|
def fn(x, y):
|
|
w = (x + y) * (x + y) * (x + y)
|
|
t = torch.tanh(w) + torch.tanh(w)
|
|
z = (x + y) * (x + y) * (x + y) + t
|
|
return z
|
|
|
|
trace, _ = torch.jit.get_trace_graph(fn, (x, y))
|
|
self.run_pass('cse', trace)
|
|
do_exactly = True
|
|
FileCheck().check_count("add", 1).check_count("mul", 2, do_exactly) \
|
|
.check_count("tanh", 1, do_exactly).check_count("add", 2, do_exactly).check_next("return") \
|
|
.run(str(trace))
|
|
|
|
self.assertExportImport(trace, (x, y))
|
|
|
|
def test_recursive_cse(self):
|
|
x = torch.tensor([0.1])
|
|
y = torch.tensor([0.2])
|
|
|
|
def fn(x, y):
|
|
z = x
|
|
if bool(x + y > x):
|
|
z = x + y
|
|
return z
|
|
|
|
graph = torch.jit.script(fn).graph
|
|
self.run_pass('cse', graph)
|
|
FileCheck().check("block").check_not("aten::add").check_not("aten::gt").run(str(graph))
|
|
|
|
def test_expand_fakequant(self):
|
|
pass
|
|
|
|
def test_expand_propagate_qinfo(self):
|
|
pass
|
|
|
|
def test_expand_insert_observers(self):
|
|
pass
|
|
|
|
def test_expand_insert_fakequant(self):
|
|
pass
|
|
|
|
def test_expand_quantlint(self):
|
|
pass
|
|
|
|
def test_expand_fold_quant_inputs(self):
|
|
pass
|
|
|
|
def test_shape_analysis_broadcast(self):
|
|
def broadcast(a, b):
|
|
return a + b
|
|
|
|
x = torch.randn(3, 1, 5, requires_grad=True)
|
|
y = torch.randn(4, 1, 8, 5, requires_grad=True)
|
|
|
|
graph = torch.jit.script(broadcast).graph
|
|
torch._C._jit_pass_complete_shape_analysis(graph, (x, y), False)
|
|
FileCheck().check("Double(4, 3, 8, 5)").run(str(graph))
|
|
|
|
# TODO: update verify to work with GraphExecutors
|
|
@unittest.skip("verify needs to be updated to work with GraphExecutors")
|
|
def test_verify(self):
|
|
x = torch.tensor([0.4], requires_grad=True)
|
|
y = torch.tensor([0.7], requires_grad=True)
|
|
|
|
@torch.jit.compile
|
|
def f(x, y):
|
|
z = torch.sigmoid(x * (x + y))
|
|
w = torch.abs(x * x * x + y) + Variable(torch.ones(1))
|
|
return z, w
|
|
|
|
torch.jit.verify(f, (x, y), loss_fn=lambda z, w: z * w, devices=[])
|
|
|
|
@suppress_warnings
|
|
def test_constant(self):
|
|
x = torch.randn(2, 2, requires_grad=True)
|
|
|
|
def f(x):
|
|
return x.matmul(torch.diag(torch.tensor([2., 2.])))
|
|
|
|
self.checkTrace(f, (x,), (torch.ones(2, 2, requires_grad=True),))
|
|
|
|
def test_legacy_fail(self):
|
|
class MyLegacyFn(Function):
|
|
def forward(self, x):
|
|
return x
|
|
|
|
def backward(self, grad_output):
|
|
return grad_output
|
|
|
|
x = torch.tensor([0.], requires_grad=True)
|
|
with self.assertRaisesRegex(RuntimeError, "MyLegacyFn"):
|
|
torch.jit.get_trace_graph(lambda x: MyLegacyFn()(x), (x,))
|
|
|
|
def test_inplace_transplant(self):
|
|
x = torch.tensor([0.], requires_grad=True)
|
|
|
|
def fn(x):
|
|
y = x.clone()
|
|
y.add_(2)
|
|
y.add_(3)
|
|
return y
|
|
|
|
trace, _ = torch.jit.get_trace_graph(fn, (x,))
|
|
self.run_pass('dce', trace)
|
|
FileCheck().check_count("aten::clone", 1, exactly=True) \
|
|
.check_count("aten::add_", 2, exactly=True) \
|
|
.check_next("return").run(str(trace))
|
|
self.assertExportImport(trace, (x,))
|
|
|
|
def test_inplace_flags(self):
|
|
class InplaceFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
ctx.mark_dirty(x)
|
|
return x.add_(1)
|
|
|
|
@staticmethod
|
|
def backward(ctx, go):
|
|
return go
|
|
|
|
class RegularFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
return x.add(1)
|
|
|
|
@staticmethod
|
|
def backward(ctx, go):
|
|
return go
|
|
|
|
x = torch.tensor([0.], requires_grad=True)
|
|
|
|
def fn(x):
|
|
y = RegularFn.apply(x)
|
|
y = InplaceFn.apply(y)
|
|
y = InplaceFn.apply(y)
|
|
y = RegularFn.apply(y)
|
|
return y
|
|
|
|
trace, _ = torch.jit.get_trace_graph(fn, (x,), _force_outplace=True)
|
|
self.run_pass('dce', trace)
|
|
ops = [n for n in trace.graph().nodes()]
|
|
for op in ops:
|
|
self.assertTrue(op.hasAttribute('inplace'))
|
|
inplace_flags = [False, True, True, False]
|
|
for op, is_inplace in zip(ops, inplace_flags):
|
|
self.assertEqual(op.i('inplace'), is_inplace)
|
|
|
|
def test_inplace_check(self):
|
|
class MyInplaceFn(Function):
|
|
@staticmethod
|
|
def forward(self, x):
|
|
x.add_(1)
|
|
self.mark_dirty(x)
|
|
return x
|
|
|
|
@staticmethod
|
|
def backward(self, grad):
|
|
return grad
|
|
|
|
def fn(x):
|
|
return MyInplaceFn.apply(x)
|
|
|
|
x = torch.randn(5, 5)
|
|
ge = torch._C.GraphExecutor(fn, (x,), lambda var: '', _force_outplace=True)
|
|
with self.assertRaisesRegex(RuntimeError, 'inplace MyInplaceFn'):
|
|
ge(x)
|
|
|
|
def do_trace_size(self, requires_grad):
|
|
def fn(x):
|
|
return x.view(x.shape[1] * 2, x.size(0), 2)
|
|
|
|
x = torch.randn(5, 2, 4, requires_grad=requires_grad)
|
|
y = torch.randn(4, 8, 4, requires_grad=requires_grad)
|
|
|
|
# Check that it behaves as expected
|
|
traced_fn = torch.jit.trace(fn, x)
|
|
self.assertEqual(traced_fn(y), fn(y))
|
|
self.assertEqual(traced_fn(x), fn(x))
|
|
|
|
def test_trace_size(self):
|
|
self.do_trace_size(False)
|
|
|
|
# test the different graph_executor path that happens when
|
|
# gradients are required and sizes are involved
|
|
def test_trace_size_with_grad(self):
|
|
self.do_trace_size(True)
|
|
|
|
def test_trace_casts(self):
|
|
casts = [
|
|
lambda x: x.byte(),
|
|
lambda x: x.float(),
|
|
lambda x: x.cpu(),
|
|
lambda x: x.to(device='cpu'),
|
|
lambda x: x.to(dtype=torch.int64),
|
|
lambda x: x.to(device='cpu', dtype=torch.float),
|
|
lambda x: x.to(x)
|
|
]
|
|
|
|
def assertContainsCast(trace):
|
|
self.assertEqual(sum(n.kind() == 'aten::to' for n in trace.graph.nodes()), 1)
|
|
|
|
for cast in casts:
|
|
trace = torch.jit.trace(cast, torch.randn(2, 2))
|
|
assertContainsCast(trace)
|
|
x = torch.randn(2, 2)
|
|
self.assertEqual(trace(x), cast(x))
|
|
|
|
def to_tensor(x, y):
|
|
return x.to(y)
|
|
|
|
to_tensor_trace = torch.jit.trace(to_tensor, (torch.randn(2, 2), torch.randn(1, 8)))
|
|
assertContainsCast(to_tensor_trace)
|
|
x, y = torch.randn(2, 2), torch.randn(1, 10)
|
|
self.assertEqual(to_tensor_trace(x, y), to_tensor(x, y))
|
|
|
|
def test_trace_warn(self):
|
|
def fn(x):
|
|
int(x) # Warning 1.
|
|
y = x * 1
|
|
if y: # Warning 2.
|
|
pass
|
|
q = [x, x * 4]
|
|
z = q[y] # Warning 3.
|
|
float(z) # Warning 4.
|
|
z.tolist() # Warning 5.
|
|
z.numpy() # Warning 6.
|
|
for _ in torch.ones(4, 4): # Warning 7.
|
|
pass
|
|
return z + 4
|
|
|
|
with warnings.catch_warnings(record=True) as warns:
|
|
traced_fn = torch.jit.trace(fn, torch.tensor([1]))
|
|
warns = [str(w.message) for w in warns]
|
|
self.assertEqual(len(warns), 7)
|
|
self.assertIn('a Python integer', warns[0])
|
|
self.assertIn('a Python boolean', warns[1])
|
|
self.assertIn('a Python index', warns[2])
|
|
self.assertIn('a Python float', warns[3])
|
|
self.assertIn('a Python list', warns[4])
|
|
self.assertIn('a NumPy array', warns[5])
|
|
self.assertIn('Iterating over', warns[6])
|
|
|
|
def test_trace_tuple(self):
|
|
def fn(x, y):
|
|
return x, (x * y[1], x * y[0])
|
|
|
|
x, y = torch.randn(2, 2), (torch.ones(2, 2), torch.randn(2, 2))
|
|
traced_fn = torch.jit.trace(fn, (x, y))
|
|
self.assertEqual(traced_fn(x, y), fn(x, y))
|
|
# should be a tuple nested within another tuple
|
|
FileCheck().check_count("prim::TupleConstruct", 2, exactly=True).check_next("return") \
|
|
.run(str(traced_fn.graph))
|
|
self.assertExportImport(traced_fn.graph, (x, y))
|
|
|
|
def test_trace_random(self):
|
|
def f(mean, std):
|
|
return torch.normal(mean, std)
|
|
|
|
traced = torch.jit.trace(f, (torch.zeros(2, 3), torch.ones(2, 3)), check_trace=False)
|
|
mean, std = torch.zeros(5, 5), torch.ones(5, 5)
|
|
with torch.random.fork_rng(devices=[]):
|
|
output = f(mean, std)
|
|
traced_output = traced(mean, std)
|
|
self.assertEqual(output, traced_output)
|
|
|
|
def test_trace_tensor_factory(self):
|
|
def run(**kwargs):
|
|
inputs_require_grads = kwargs.pop('inputs_require_grads', True)
|
|
|
|
def fn(x):
|
|
return x + torch.ones(2, 3, **kwargs)
|
|
|
|
input_kwargs = kwargs.copy()
|
|
if 'out' in input_kwargs:
|
|
del input_kwargs['out']
|
|
input = torch.ones(2, 3, **input_kwargs)
|
|
self.checkTrace(fn, (input,), inputs_require_grads=inputs_require_grads)
|
|
# check we recorded 'ones' and did not just record a constant
|
|
tfn = torch.jit.trace(fn, input)
|
|
self.assertTrue("ones" in str(tfn.graph))
|
|
run()
|
|
run(dtype=torch.int, inputs_require_grads=False)
|
|
run(out=torch.tensor([]))
|
|
if RUN_CUDA:
|
|
run(device="cuda:0")
|
|
if RUN_CUDA_MULTI_GPU:
|
|
run(device="cuda:1")
|
|
|
|
def test_trace_indexed_assignment(self):
|
|
def stuff(x, y):
|
|
x = x.clone()
|
|
x[0] = y
|
|
return x
|
|
example = torch.rand(3, 4)
|
|
self.checkTrace(stuff, (example, example[0] + 1))
|
|
|
|
# TODO: implement
|
|
@unittest.expectedFailure
|
|
def test_output_unflatten(self):
|
|
"""Check that outputs of traced functions retain the original structure and nesting"""
|
|
def fn(x):
|
|
return (x * 2, (x ** 2, x + 4, (x + 2,), ), x * 4)
|
|
|
|
self.checkTrace(fn, (torch.randn(2, 2),))
|
|
|
|
# TODO: implement
|
|
@unittest.expectedFailure
|
|
def test_input_flatten(self):
|
|
"""Check that inputs to traced functions are flattened"""
|
|
|
|
def fn(x, t):
|
|
y, z = t
|
|
return x * y * z
|
|
|
|
inputs = (torch.randn(1), (torch.randn(1), torch.randn(1)))
|
|
self.checkTrace(fn, inputs)
|
|
|
|
# TODO: adapt to a GraphExecutor test
|
|
@unittest.skip("Need to instrument GraphExecutors a bit more")
|
|
def test_flags(self):
|
|
x, y = torch.randn(2, 2)
|
|
y = Variable(torch.randn(2, 2))
|
|
|
|
@torch.jit.compile
|
|
def fn(x, y):
|
|
return (x * x + y * y + x * y).sum()
|
|
|
|
grads = {}
|
|
for rx, ry in product((True, False), repeat=2):
|
|
x.requires_grad = rx
|
|
y.requires_grad = ry
|
|
|
|
self.assertFalse(fn.has_trace_for(x, y))
|
|
out = fn(x, y)
|
|
|
|
self.assertFalse(fn.has_trace_for(x, y))
|
|
for v, name, compute in [(x, 'x', rx), (y, 'y', ry)]:
|
|
if not compute:
|
|
continue
|
|
grad_v, = torch.autograd.grad(out, v, retain_graph=True)
|
|
expected_grad = grads.setdefault(name, grad_v)
|
|
self.assertEqual(grad_v, expected_grad)
|
|
self.assertEqual(fn.has_trace_for(x, y), rx or ry)
|
|
|
|
def test_python_ir(self):
|
|
x = torch.tensor([0.4], requires_grad=True)
|
|
y = torch.tensor([0.7], requires_grad=True)
|
|
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y)))
|
|
|
|
trace, _ = torch.jit.get_trace_graph(doit, (x, y))
|
|
self.run_pass('dce', trace)
|
|
self.run_pass('canonicalize', trace)
|
|
g = trace.graph()
|
|
g2 = torch._C.Graph()
|
|
g_to_g2 = {}
|
|
for node in g.inputs():
|
|
g_to_g2[node] = g2.addInput()
|
|
for node in g.nodes():
|
|
n_ = g2.createClone(node, lambda x: g_to_g2[x])
|
|
g2.appendNode(n_)
|
|
for o, no in zip(node.outputs(), n_.outputs()):
|
|
g_to_g2[o] = no
|
|
|
|
for node in g.outputs():
|
|
g2.registerOutput(g_to_g2[node])
|
|
|
|
t_node = g2.create("prim::TensorTest").t_("a", torch.ones([2, 2]))
|
|
self.assertEqual(t_node.attributeNames(), ["a"])
|
|
g2.appendNode(t_node)
|
|
self.assertTrue(torch.equal(torch.ones(2, 2), t_node.t("a")))
|
|
for node in g.nodes():
|
|
self.assertTrue(g2.findNode(node.kind()) is not None)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA")
|
|
@skipIfRocm
|
|
def test_cpp_cuda(self):
|
|
from cpp.jit import tests_setup
|
|
tests_setup.setup()
|
|
torch._C._jit_run_cpp_tests()
|
|
tests_setup.shutdown()
|
|
|
|
def test_batchnorm(self):
|
|
x = torch.ones(2, 2, 2, 2)
|
|
trace, outputs, inputs = torch.jit.get_trace_graph(nn.BatchNorm2d(2), x,
|
|
_force_outplace=True, return_inputs=True)
|
|
m = self.createScriptModuleFromGraph(trace)
|
|
self.assertEqual(outputs, m(*inputs))
|
|
|
|
def test_dropout(self):
|
|
x = torch.ones(2, 2)
|
|
with torch.random.fork_rng(devices=[]):
|
|
trace, outputs, inputs = torch.jit.get_trace_graph(nn.Dropout(0.6), x, return_inputs=True)
|
|
with torch.random.fork_rng(devices=[]):
|
|
m = self.createScriptModuleFromGraph(trace)
|
|
self.assertEqual(outputs, m(*inputs))
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "test_dropout_cuda require CUDA")
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@skipIfRocm
|
|
def test_dropout_cuda(self):
|
|
# Dropout AD is dispatched to _fused_dropout in CUDA case,
|
|
# which is not included in TestJitGeneratedFunctional
|
|
x = torch.ones(4, 4).cuda().requires_grad_()
|
|
|
|
@torch.jit.script
|
|
def func(x):
|
|
return torch.nn.functional.dropout(x)
|
|
|
|
with freeze_rng_state():
|
|
out_ref = torch.nn.functional.dropout(x)
|
|
grad_ref = torch.autograd.grad(out_ref.sum(), x)
|
|
|
|
with freeze_rng_state():
|
|
out = func(x)
|
|
grad = torch.autograd.grad(out.sum(), x)
|
|
|
|
self.assertEqual(out, out_ref)
|
|
self.assertEqual(grad, grad_ref)
|
|
|
|
def test_conv(self):
|
|
x = torch.ones(20, 16, 50, 40)
|
|
trace, outputs, inputs = torch.jit.get_trace_graph(nn.Conv2d(16, 13, 3, bias=False), x, return_inputs=True)
|
|
m = self.createScriptModuleFromGraph(trace)
|
|
self.assertEqual(outputs, m(*inputs))
|
|
|
|
def test_repeated_input(self):
|
|
def fn(a, b):
|
|
return a + b
|
|
|
|
ge = self.checkTrace(fn, [torch.randn(2, 2)] * 2)
|
|
inputs = set(ge.graph.inputs())
|
|
self.assertTrue(len(inputs) == 2)
|
|
|
|
def test_repeated_output(self):
|
|
def fn(a, b):
|
|
z = a + b
|
|
return z, z
|
|
|
|
ge = self.checkTrace(fn, [torch.randn(2, 2) for _ in range(2)])
|
|
tuple_output = list(ge.graph.outputs())[0]
|
|
tuple_inputs = list(tuple_output.node().inputs())
|
|
self.assertTrue(tuple_inputs[0] == tuple_inputs[1])
|
|
|
|
@skipIfNoTorchVision
|
|
def test_alexnet(self):
|
|
x = torch.ones(1, 3, 224, 224)
|
|
model = torchvision.models.AlexNet()
|
|
with torch.random.fork_rng(devices=[]):
|
|
trace, outputs, inputs = torch.jit.get_trace_graph(model, x, return_inputs=True)
|
|
self.run_pass('cse', trace)
|
|
m = self.createScriptModuleFromGraph(trace)
|
|
with torch.random.fork_rng(devices=[]):
|
|
self.assertEqual(outputs, m(*inputs))
|
|
|
|
def test_inplace_copy(self):
|
|
x = torch.randn(4, 4, requires_grad=True)
|
|
|
|
def f(x):
|
|
out = Variable(torch.zeros(x.size()))
|
|
out.copy_(x)
|
|
return out
|
|
|
|
trace, outputs, inputs = torch.jit.get_trace_graph(f, (x, ), return_inputs=True)
|
|
self.run_pass('dce', trace)
|
|
m = self.createScriptModuleFromGraph(trace)
|
|
self.assertEqual(outputs, m(*inputs))
|
|
self.assertExportImport(trace, (x,))
|
|
|
|
def test_shared_param(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(MyModule, self).__init__()
|
|
self.b = self.a = nn.Parameter(torch.randn(2, 2))
|
|
|
|
def forward(self, x):
|
|
return x * self.a + self.b
|
|
|
|
m = MyModule()
|
|
trace, _ = torch.jit.get_trace_graph(m, (torch.randn(2, 2),))
|
|
self.run_pass('dce', trace)
|
|
self.assertEqual(len(list(trace.graph().inputs())), 2)
|
|
FileCheck().check("mul").check("add").run(str(trace))
|
|
|
|
def test_trace_c10_ops(self):
|
|
class MyModel(torch.nn.Module):
|
|
def __init__(self):
|
|
super(MyModel, self).__init__()
|
|
|
|
def forward(self, scores, bbox_deltas, im_info, anchors):
|
|
a, b = torch.ops._caffe2.GenerateProposals(
|
|
(scores), (bbox_deltas), (im_info), (anchors),
|
|
2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0,
|
|
)
|
|
return a, b
|
|
model = MyModel()
|
|
A = 4
|
|
H = 10
|
|
W = 8
|
|
img_count = 3
|
|
scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
|
|
bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
|
|
dtype=torch.float32)
|
|
bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
|
|
im_info = torch.ones(img_count, 3, dtype=torch.float32)
|
|
anchors = torch.ones(A, 4, dtype=torch.float32)
|
|
inputs = (scores, bbox_deltas, im_info, anchors)
|
|
traced_model = torch.jit.trace(model, inputs)
|
|
self.assertEqual(traced_model(*inputs), model(*inputs))
|
|
self.assertExportImport(traced_model.graph, (scores, bbox_deltas, im_info, anchors))
|
|
|
|
def test_nested_inplace(self):
|
|
x = torch.randn(2, 2)
|
|
trace, outputs, inputs = torch.jit.get_trace_graph(
|
|
lambda x: F.threshold(x, 0, 0, inplace=True), (x, ), return_inputs=True)
|
|
m = self.createScriptModuleFromGraph(trace)
|
|
self.assertEqual(outputs, m(*inputs))
|
|
FileCheck().check("threshold_").run(str(trace))
|
|
self.assertExportImport(trace, (x,))
|
|
|
|
def run_ge_tests(self, optimize, use_cuda):
|
|
def rand(*args):
|
|
t = torch.rand(*args).float()
|
|
if use_cuda:
|
|
t = t.cuda()
|
|
return t
|
|
self.checkTrace(lambda a, b: a * b + b,
|
|
[rand(1), rand(1)], [rand(2, 3), rand(2, 3)],
|
|
optimize=optimize)
|
|
# trivial identity
|
|
self.checkTrace(lambda a, b: (
|
|
b, a), [rand(1), rand(1)], optimize=optimize)
|
|
|
|
def foo(a):
|
|
t = a * a
|
|
return t * t, 4 * t
|
|
self.checkTrace(foo, [rand(1)], optimize=optimize)
|
|
# unused input
|
|
self.checkTrace(
|
|
lambda a, b: a * a, [rand(1), rand(1)], optimize=optimize,
|
|
allow_unused=True)
|
|
# test outputs that do not get used in grad
|
|
self.checkTrace(foo, [rand(1)], drop=1, optimize=optimize)
|
|
# test autograd fallback
|
|
self.checkTrace(lambda a, b: a * b /
|
|
(a - 2 * b) + b, [rand(1), rand(1)],
|
|
optimize=optimize)
|
|
|
|
def test_ge_unoptimized(self):
|
|
self.run_ge_tests(False, False)
|
|
|
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
|
@enable_cpu_fuser
|
|
def test_ge_optimized(self):
|
|
self.run_ge_tests(True, False)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
def test_ge_cuda(self):
|
|
self.run_ge_tests(True, True)
|
|
|
|
# more manual test of graph executor that can be used as a scratchpad
|
|
def test_ge(self):
|
|
def foo(a, b):
|
|
return a * b / (a - b) + b
|
|
V = Variable
|
|
a, b = V(torch.rand(1)), V(torch.rand(1))
|
|
ge = torch._C.GraphExecutor(foo, (a, b), lambda var: '')
|
|
a, b = V(torch.rand(1), requires_grad=True), V(
|
|
torch.rand(1), requires_grad=True)
|
|
r, = ge(a, b)
|
|
da, db = torch.autograd.grad(r + 3, [a, b], create_graph=True)
|
|
|
|
l2 = (da * db + db * db)
|
|
g2result = torch.autograd.grad(l2, [da, db])
|
|
|
|
r = foo(a, b)
|
|
da2, db2 = torch.autograd.grad(r + 3, [a, b], create_graph=True)
|
|
self.assertEqual(da, da2)
|
|
self.assertEqual(db, db2)
|
|
l3 = (da2 * db2 + db2 * db2)
|
|
g2result2 = torch.autograd.grad(l3, [da2, db2])
|
|
self.assertEqual(g2result, g2result2)
|
|
|
|
def test_trace_annotation(self):
|
|
@_trace(torch.rand(1))
|
|
def foo(a):
|
|
return a + a + a
|
|
|
|
x = torch.randn(5, 5)
|
|
self.assertEqual(foo(x), x + x + x)
|
|
|
|
def test_trace_script(self):
|
|
@torch.jit.script
|
|
def func1(x):
|
|
# type: (Tuple[Tensor, Tensor]) -> Tensor
|
|
return x[0] + x[1]
|
|
|
|
@torch.jit.script
|
|
def func2(x):
|
|
# type: (List[Tensor]) -> Tensor
|
|
return x[0] + x[1]
|
|
|
|
a = torch.randn(5)
|
|
b = torch.randn(5)
|
|
|
|
expected = func1((a, b))
|
|
traced = torch.jit.trace(func1, ((a, b),))
|
|
result = traced((a, b))
|
|
self.assertEqual(expected, result)
|
|
|
|
expected = func2((a, b))
|
|
traced = torch.jit.trace(func2, ((a, b),))
|
|
result = traced((a, b))
|
|
self.assertEqual(expected, result)
|
|
|
|
def test_einsum(self):
|
|
def outer(x, y):
|
|
return torch.einsum('i,j->ij', (x, y))
|
|
|
|
traced = torch.jit.trace(outer, (torch.randn(4), torch.randn(5)))
|
|
script = torch.jit.script(outer)
|
|
fns = [traced, script]
|
|
x, y = torch.randn(10), torch.randn(2)
|
|
for fn in [traced, script]:
|
|
self.assertGraphContains(fn.graph, kind='aten::einsum')
|
|
self.assertEqual(fn(x, y), outer(x, y))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "calls .cuda()")
|
|
def test_traced_module_cuda(self):
|
|
class Model(nn.Module):
|
|
def __init__(self, num_features, num_layers):
|
|
super(Model, self).__init__()
|
|
self.num_layers = num_layers
|
|
layers = [[nn.Linear(num_features, num_features), nn.Sigmoid()]
|
|
for _ in range(num_layers)]
|
|
self.submodule = nn.Sequential(*chain(*layers))
|
|
|
|
def forward(self, x):
|
|
for i in range(self.num_layers):
|
|
x = self.submodule[i](x) + x
|
|
return x
|
|
|
|
model = Model(5, 3)
|
|
x = torch.randn(2, 5)
|
|
traced_model = torch.jit.trace(model, x)
|
|
|
|
# We're missing some attributes these modules had initially. Make sure we can
|
|
# still get the __repr__()
|
|
model.__repr__()
|
|
|
|
# XXX: indexing sequentials is broken
|
|
linear_submodule = next(iter(traced_model.submodule._modules.values()))
|
|
|
|
# All attributes that aren't parameters should raise
|
|
with self.assertRaises(AttributeError):
|
|
linear_submodule.in_features
|
|
linear_submodule.weight
|
|
with self.assertRaises(RuntimeError):
|
|
traced_model.asdf = 4
|
|
linear_submodule.weight = nn.Parameter(torch.randn(linear_submodule.weight.shape))
|
|
with self.assertRaises(RuntimeError):
|
|
del linear_submodule.weight
|
|
|
|
# Submodules can't be called
|
|
with self.assertRaises(RuntimeError):
|
|
linear_submodule(x)
|
|
|
|
# Type casts
|
|
linear_submodule.cuda()
|
|
traced_model.float().cuda()
|
|
cuda_out = traced_model(x.float().cuda())
|
|
traced_model.cpu()
|
|
cpu_out = traced_model(x.float())
|
|
self.assertEqual(cpu_out, cuda_out)
|
|
traced_model.to('cuda')
|
|
cuda_out = traced_model(x.float().cuda())
|
|
traced_model.to('cpu')
|
|
cpu_out = traced_model(x.float())
|
|
self.assertEqual(cpu_out, cuda_out)
|
|
traced_model.double()
|
|
|
|
# state_dict + load_state_dict
|
|
state = {k: v.clone() for k, v in traced_model.state_dict().items()}
|
|
new_state = {k: v.clone().fill_(1) for k, v in state.items()}
|
|
out = traced_model(x)
|
|
traced_model.load_state_dict(new_state)
|
|
out_ones = traced_model(x)
|
|
traced_model.load_state_dict(state)
|
|
out_state = traced_model(x)
|
|
self.assertEqual(out, out_state)
|
|
self.assertNotEqual(out, out_ones)
|
|
|
|
def test_export_no_reorder(self):
|
|
def func(a, b):
|
|
return a * b / (a - 2 * b) + b
|
|
|
|
recording_inputs = [torch.tensor([0.55619788169860839844], dtype=torch.float32, requires_grad=True),
|
|
torch.tensor([0.25947844982147216797], dtype=torch.float32, requires_grad=True)]
|
|
|
|
ge1 = torch.jit.trace(func, recording_inputs, optimize=True)
|
|
ge2 = self.getExportImportCopy(ge1)
|
|
|
|
outputs_ge1 = ge1(*recording_inputs)
|
|
outputs_ge2 = ge2(*recording_inputs)
|
|
|
|
grad_ge1 = torch.autograd.grad(outputs_ge1, recording_inputs)
|
|
grad_ge2 = torch.autograd.grad(outputs_ge2, recording_inputs)
|
|
self.assertTrue(outputs_ge1 == outputs_ge2)
|
|
self.assertTrue(grad_ge1 == grad_ge2)
|
|
|
|
def test_python_function(self):
|
|
class MyFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
return x + 1
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
return grad_output
|
|
|
|
@_trace(torch.zeros(2))
|
|
def fn(x):
|
|
return MyFn.apply(x + 2) + 3
|
|
|
|
x = torch.tensor([1., 2., 3.])
|
|
y = torch.randn(2, 2, requires_grad=True)
|
|
fn(x)
|
|
fn(y)
|
|
|
|
def test_python_function_tup(self):
|
|
class MyFn(Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
return x + 1, x - 1
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_output):
|
|
return grad_output, grad_output
|
|
|
|
@_trace(torch.zeros(2))
|
|
def fn(x):
|
|
a, b = MyFn.apply(x + 2)
|
|
return a + b + 3
|
|
x = torch.tensor([1., 2., 3.])
|
|
y = torch.randn(2, 2, requires_grad=True)
|
|
fn(x)
|
|
fn(y)
|
|
|
|
def test_decompose_addmm(self):
|
|
def does_decompose():
|
|
@torch.jit.script
|
|
def addmm(mat, mat1, mat2, alpha, beta):
|
|
a = mat.addmm(mat1, mat2, alpha=4.20, beta=2.0)
|
|
b = mat.addmm(mat1, mat2, alpha=int(alpha), beta=int(beta))
|
|
|
|
return a + b
|
|
|
|
mat = torch.randn(2, 2)
|
|
mat1 = torch.randn(2, 4)
|
|
mat2 = torch.randn(4, 2)
|
|
alpha = torch.FloatTensor([123.0])
|
|
beta = torch.FloatTensor([321.0])
|
|
|
|
out_ref = addmm(mat, mat1, mat2, alpha, beta)
|
|
self.run_pass('canonicalize_ops', addmm.graph)
|
|
out_test = addmm(mat, mat1, mat2, alpha, beta)
|
|
self.assertEqual(out_ref, out_test)
|
|
FileCheck().check_not("addmm").run(str(addmm.graph))
|
|
|
|
def doesnt_decompose():
|
|
@torch.jit.script
|
|
def addmm(mat, mat1, mat2, alpha, beta):
|
|
a = mat.addmm(mat1, mat2)
|
|
b = mat.addmm(mat1, mat2, alpha=1.0, beta=1.0)
|
|
|
|
orig = str(addm.graph)
|
|
self.run_pass('canonicalize_ops', addmm.graph)
|
|
self.assertTrue(orig == str(addmm.graph))
|
|
|
|
def test_index_put(self):
|
|
ten = torch.zeros(3, 3)
|
|
mask = torch.Tensor([[True, True, True],
|
|
[True, False, False],
|
|
[True, True, False]]).byte()
|
|
|
|
def test_fn(ten, mask):
|
|
ten[mask] = torch.ones(6)
|
|
return ten
|
|
|
|
traced_test_fn = torch.jit.trace(test_fn, (ten, mask))
|
|
|
|
ten = torch.rand(3, 3)
|
|
self.assertEqual(test_fn(ten, mask), traced_test_fn(ten, mask))
|
|
|
|
def test_sparse_tensors_error(self):
|
|
def get_sparse():
|
|
return torch.sparse.FloatTensor(2, 3)
|
|
|
|
@torch.jit.script
|
|
def sparse(input):
|
|
output = get_sparse()
|
|
return output, input
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
|
|
sparse(get_sparse())
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "sparse tensors not supported"):
|
|
sparse(torch.tensor([1]))
|
|
|
|
def test_tuple_specialization(self):
|
|
@torch.jit.script
|
|
def f(t):
|
|
# type: (Tuple[Tensor, Tensor]) -> Tensor
|
|
x, y = t
|
|
return x + y
|
|
|
|
t = torch.randn(2, 2), torch.randn(2, 2)
|
|
f(t)
|
|
graph = f.graph_for(t)
|
|
input_types = list(next(graph.inputs()).type().elements())
|
|
for t in input_types:
|
|
self.assertEqual(t.kind(), 'DimensionedTensorType')
|
|
|
|
def test_constant_prop_simple(self):
|
|
@torch.jit.script
|
|
def constant_prop(input_int):
|
|
# type: (int) -> int
|
|
a = 2 * 3
|
|
b = a + 2
|
|
return b - input_int
|
|
|
|
out_ref = constant_prop(2)
|
|
self.run_pass('constant_propagation', constant_prop.graph)
|
|
out_test = constant_prop(2)
|
|
self.assertEqual(out_ref, out_test)
|
|
graph_str = str(constant_prop.graph)
|
|
self.assertTrue("aten::add" not in graph_str and "aten::mul" not in graph_str)
|
|
const = constant_prop.graph.findNode("prim::Constant").output().toIValue()
|
|
self.assertEqual(const, 8)
|
|
|
|
def test_constant_prop_nested(self):
|
|
@torch.jit.script
|
|
def constant_prop(a):
|
|
b = 2 + 1
|
|
if bool(a < 2):
|
|
c = b + 2
|
|
else:
|
|
c = b - 2
|
|
return c
|
|
out_ref = constant_prop(torch.tensor(2))
|
|
self.run_pass('constant_propagation', constant_prop.graph)
|
|
out_test = constant_prop(torch.tensor(2))
|
|
self.assertEqual(out_ref, out_test)
|
|
if_node = constant_prop.graph.findNode("prim::If")
|
|
for block in if_node.blocks():
|
|
for node in block.nodes():
|
|
self.assertTrue(node.kind() == "prim::Constant")
|
|
|
|
def test_constant_prop_print(self):
|
|
@torch.jit.script
|
|
def constant_prop(input_tensor):
|
|
a = 2 * 3
|
|
print(a)
|
|
b = a + 2
|
|
return b + input_tensor
|
|
|
|
self.run_pass('constant_propagation', constant_prop.graph)
|
|
graph = constant_prop.graph
|
|
print_node = graph.findNode("prim::Print")
|
|
self.assertTrue(print_node.input().toIValue() == 6)
|
|
|
|
def test_constant_prop_rand(self):
|
|
@torch.jit.script
|
|
def constant_prop():
|
|
a = torch.randn([3])
|
|
b = a + 2
|
|
return b
|
|
|
|
self.run_pass('constant_propagation', constant_prop.graph)
|
|
self.assertTrue("aten::randn" in str(constant_prop.graph))
|
|
|
|
def test_constant_prop_none(self):
|
|
@torch.jit.script
|
|
def typed_none():
|
|
# type: () -> Optional[int]
|
|
return None
|
|
|
|
@torch.jit.script
|
|
def constant_prop():
|
|
a = typed_none()
|
|
b = typed_none()
|
|
if (a is None and b is None):
|
|
a = 2
|
|
else:
|
|
a = 1
|
|
return a
|
|
|
|
self.run_pass('constant_propagation', constant_prop.graph)
|
|
graph_str = str(constant_prop.graph)
|
|
self.assertTrue(graph_str.count("prim::Constant") == 1)
|
|
|
|
def test_constant_prop_if_inline(self):
|
|
@torch.jit.script
|
|
def constant_prop():
|
|
cond = True
|
|
a = 1
|
|
if cond:
|
|
a = 1 * 2
|
|
else:
|
|
a = 1 // 0
|
|
return a
|
|
|
|
# testing that 1 // 0 error is not thrownn
|
|
self.run_pass('constant_propagation', constant_prop.graph)
|
|
|
|
def test_trace_records_names(self):
|
|
def foo(bar, baz):
|
|
baz = bar + 3
|
|
quick_brown_fox = torch.neg(baz)
|
|
for _ in range(20):
|
|
yeet = quick_brown_fox - 3.14
|
|
return yeet
|
|
|
|
traced = torch.jit.trace(foo, (torch.rand(3, 3), torch.rand(3, 3)))
|
|
graph_str = str(traced.graph)
|
|
assert 'bar' in graph_str
|
|
assert 'baz' in graph_str
|
|
assert 'quick_brown_fox' in graph_str
|
|
|
|
def test_constant_prop_if_constant(self):
|
|
@torch.jit.script
|
|
def constant_prop(a, b):
|
|
c0 = 1
|
|
c1 = 1
|
|
c2 = 1
|
|
if bool(a): # -> c0, c1
|
|
if bool(b): # -> c0
|
|
if True: # -> c0
|
|
c0 = c0 + 1
|
|
if False:
|
|
c1 = c1 + 1
|
|
c2 = c2 + 1
|
|
else: # -> c0, c1
|
|
c1 = c1 + 1
|
|
|
|
if True: # inlined
|
|
c0 = c0 + 1 # dynamic
|
|
c2 = c2 + 4 # set to 5
|
|
return a + c0 + c1 + c2
|
|
|
|
graph = constant_prop.graph
|
|
self.run_pass('constant_propagation', graph)
|
|
ifs = graph.findAllNodes("prim::If", recurse=False)
|
|
snd_if_inlined = len(ifs) == 1
|
|
self.assertTrue(snd_if_inlined)
|
|
first_if = ifs[0]
|
|
self.assertTrue(first_if.outputsSize() == 2)
|
|
second_if = first_if.findNode("prim::If", recurse=False)
|
|
self.assertTrue(second_if.outputsSize() == 1)
|
|
self.assertTrue(second_if.findNode("prim::If") is None)
|
|
|
|
def test_constant_prop_loop_constant(self):
|
|
@torch.jit.script
|
|
def constant_prop(cond, iter):
|
|
# type: (bool, int) -> int
|
|
b = 0
|
|
while True:
|
|
print("stays")
|
|
for _ in range(2):
|
|
print("stays")
|
|
for _ in range(iter):
|
|
print("stays")
|
|
while cond:
|
|
print("stays")
|
|
while False:
|
|
print("removed")
|
|
for _i in range(0):
|
|
print("removed")
|
|
for _i in range(-4):
|
|
print("removed")
|
|
return b
|
|
|
|
self.run_pass('constant_propagation', constant_prop.graph)
|
|
graph = canonical(constant_prop.graph)
|
|
self.assertTrue(graph.count("removed") == 0)
|
|
self.assertTrue(graph.count("stays") == 1) # constant gets pooled
|
|
self.assertTrue(graph.count("prim::Print") == 4)
|
|
|
|
def test_constant_prop_remove_output(self):
|
|
@torch.jit.script
|
|
def constant_prop(iter):
|
|
# type: (int) -> None
|
|
a = 1
|
|
b = 1
|
|
c = 1
|
|
for i in range(iter):
|
|
if False:
|
|
a = 10
|
|
if i == 5:
|
|
b = 2
|
|
c = 3
|
|
print(a, b, c)
|
|
|
|
graph = constant_prop.graph
|
|
self.run_pass('constant_propagation', graph)
|
|
self.assertTrue(graph.findNode("prim::Loop").outputsSize() == 2)
|
|
|
|
def test_trace_detach(self):
|
|
def foo(x, w):
|
|
return torch.matmul(x, w).detach()
|
|
|
|
traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
|
|
|
|
FileCheck().check("matmul").check("detach").run(str(traced.graph))
|
|
x, w = torch.rand(3, 4), torch.rand(4, 5, requires_grad=True)
|
|
traced_result = traced(x, w)
|
|
self.assertEqual(foo(x, w), traced_result)
|
|
self.assertFalse(traced_result.requires_grad)
|
|
self.assertIsNone(traced_result.grad_fn)
|
|
|
|
def test_trace_detach_inplace(self):
|
|
def foo(x, w):
|
|
y = torch.matmul(x, w)
|
|
y.detach_()
|
|
return y
|
|
|
|
traced = torch.jit.trace(foo, (torch.rand(3, 4), torch.rand(4, 5)))
|
|
|
|
FileCheck().check("matmul").check("detach(").run(str(traced.graph))
|
|
x, w = torch.rand(3, 4), torch.rand(4, 5)
|
|
traced_result = traced(x, w)
|
|
self.assertEqual(foo(x, w), traced_result)
|
|
self.assertFalse(traced_result.requires_grad)
|
|
self.assertIsNone(traced_result.grad_fn)
|
|
|
|
def test_trace_detach_onnx_erase(self):
|
|
class Mod(torch.nn.Module):
|
|
def forward(self, x, w):
|
|
return torch.matmul(x, w).detach()
|
|
|
|
f = io.BytesIO()
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
Mod(), (torch.rand(3, 4), torch.rand(4, 5)), f))
|
|
|
|
def test_trace_slice_full_dim(self):
|
|
def foo(x):
|
|
return x[0:5, 0] + 1.0
|
|
|
|
traced = torch.jit.trace(foo, (torch.rand(5, 4),))
|
|
test_x = torch.rand(6, 3)
|
|
self.assertEqual(foo(test_x), traced(test_x))
|
|
|
|
def test_export_dropout(self):
|
|
test = torch.nn.Dropout()
|
|
test.eval()
|
|
|
|
traced = torch.jit.trace(test, (torch.rand(3, 4),), check_trace=False)
|
|
imported = self.getExportImportCopy(traced)
|
|
x = torch.randn(3, 4)
|
|
self.assertEqual(traced(x), imported(x))
|
|
|
|
def test_onnx_transpose_incomplete_tensor_type(self):
|
|
# Smoke test to get us into the state where we are attempting to export
|
|
# a transpose op, where the input is a TensorType rather than a
|
|
# CompleteTensorType. This would previously not work, since we would
|
|
# take the size of the input and use the length of its sizes as the
|
|
# number of dimensions in the permutation.
|
|
class Foo(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x.contiguous().transpose(0, 1).sum()
|
|
|
|
class TraceMe(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TraceMe, self).__init__()
|
|
self.foo = Foo()
|
|
|
|
def forward(self, x):
|
|
return self.foo(x)
|
|
|
|
tm = TraceMe()
|
|
tm = torch.jit.trace(tm, torch.rand(3, 4))
|
|
example_outputs = (tm(torch.rand(3, 4)),)
|
|
f = io.BytesIO()
|
|
torch.onnx._export(tm, (torch.rand(3, 4),), f, example_outputs=example_outputs)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
def test_cuda_export_restore(self):
|
|
class Sub(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Sub, self).__init__()
|
|
self.weight = nn.Parameter(torch.randn(3, 4))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, thing):
|
|
return self.weight + thing
|
|
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.mod = Sub()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, v):
|
|
return self.mod(v)
|
|
m = M()
|
|
m.cuda()
|
|
m2 = self.getExportImportCopy(m)
|
|
m2.cuda()
|
|
input = torch.rand(3, 4).cuda()
|
|
self.assertEqual(m(input), m2(input))
|
|
|
|
def test_export_batchnorm(self):
|
|
for mode in ['eval', 'train']:
|
|
for clazz in [
|
|
torch.nn.BatchNorm1d(100),
|
|
torch.nn.BatchNorm1d(100, affine=False),
|
|
torch.nn.BatchNorm2d(100),
|
|
torch.nn.BatchNorm2d(100, affine=False)]:
|
|
getattr(clazz, mode)()
|
|
input = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
|
|
torch.randn(20, 100, 35, 45)
|
|
traced = torch.jit.trace(clazz, (input,))
|
|
imported = self.getExportImportCopy(traced)
|
|
x = torch.randn(20, 100) if isinstance(clazz, torch.nn.BatchNorm1d) else \
|
|
torch.randn(20, 100, 35, 45)
|
|
self.assertEqual(traced(x), imported(x))
|
|
|
|
def test_export_rnn(self):
|
|
for clazz in [nn.RNN(10, 20, 2), nn.GRU(10, 20, 2)]:
|
|
class RNNTest(torch.nn.Module):
|
|
def __init__(self):
|
|
super(RNNTest, self).__init__()
|
|
self.rnn = clazz
|
|
|
|
def forward(self, x, lengths, h0):
|
|
packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
|
|
out, h = self.rnn(packed, h0)
|
|
padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
|
|
return padded_outs
|
|
|
|
test = RNNTest()
|
|
|
|
traced = torch.jit.trace(test, (torch.randn(5, 3, 10), torch.LongTensor([3, 2, 1]), torch.randn(2, 3, 20)))
|
|
imported = self.getExportImportCopy(traced)
|
|
# NB: We make sure to pass in a batch with a different max sequence
|
|
# length to ensure that the argument stashing for pad_packed works
|
|
# properly.
|
|
x, lengths, h0 = torch.randn(7, 4, 10), torch.LongTensor([7, 3, 2, 1]), torch.randn(2, 4, 20)
|
|
self.assertEqual(traced(x, lengths, h0), imported(x, lengths, h0))
|
|
|
|
def test_export_lstm(self):
|
|
class LSTMTest(torch.nn.Module):
|
|
def __init__(self):
|
|
super(LSTMTest, self).__init__()
|
|
self.rnn = nn.LSTM(10, 20, 2)
|
|
|
|
def forward(self, x, lengths, hiddens):
|
|
h0, c0 = hiddens
|
|
packed = torch.nn.utils.rnn.pack_padded_sequence(x, lengths)
|
|
out, (h, c) = self.rnn(packed, (h0, c0))
|
|
padded_outs, _ = torch.nn.utils.rnn.pad_packed_sequence(out)
|
|
return padded_outs
|
|
|
|
test = LSTMTest()
|
|
|
|
traced = torch.jit.trace(test, (torch.randn(5, 3, 10),
|
|
torch.LongTensor([3, 2, 1]),
|
|
(torch.randn(2, 3, 20), torch.randn(2, 3, 20))))
|
|
imported = self.getExportImportCopy(traced)
|
|
x, lengths, h0, c0 = \
|
|
torch.randn(7, 3, 10), torch.LongTensor([7, 5, 2]), torch.randn(2, 3, 20), torch.randn(2, 3, 20)
|
|
self.assertEqual(traced(x, lengths, (h0, c0)), imported(x, lengths, (h0, c0)))
|
|
|
|
def test_trace_dict_input(self):
|
|
class Bar(torch.nn.Module):
|
|
def __init__(self):
|
|
super(Bar, self).__init__()
|
|
self.foo = Foo()
|
|
|
|
def forward(self, a, b):
|
|
return self.foo({'a': a, 'b': b})['a']
|
|
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x):
|
|
return {'a': x['a'] * x['b']}
|
|
|
|
x = (torch.rand(3), torch.rand(3))
|
|
model = Bar()
|
|
self.checkTrace(model, x)
|
|
|
|
def test_trace_variable_instantiation(self):
|
|
def random_foo(x):
|
|
return Variable(Variable(x) + 1.0)
|
|
|
|
random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
|
|
|
|
x = torch.rand(5, 6)
|
|
self.assertEqual(random_foo(x), random_foo_traced(x))
|
|
|
|
def test_trace_slice_expr_complete_type(self):
|
|
def random_foo(x):
|
|
return x + 1.0
|
|
|
|
random_foo_traced = torch.jit.trace(random_foo, (torch.rand(3, 4),))
|
|
|
|
@torch.jit.script
|
|
def random_bar(x):
|
|
return random_foo_traced(x)[0:1]
|
|
|
|
x = torch.rand(3, 4)
|
|
self.assertEqual(random_bar(x), (x + 1)[0:1])
|
|
|
|
def test_export_tensoroption_to(self):
|
|
def foo(x):
|
|
return x.new_tensor(x[0]).cpu() + x
|
|
|
|
traced = torch.jit.trace(foo, (torch.rand([2])))
|
|
example_outputs = traced(torch.rand([2]))
|
|
|
|
f = io.BytesIO()
|
|
self.assertExpected(torch.onnx._export_to_pretty_string(traced, (torch.rand([2]),), f,
|
|
example_outputs=example_outputs))
|
|
|
|
def test_pretty_printer(self):
|
|
@torch.jit.script
|
|
def if_test(a, b):
|
|
# FIXME: use 0 instead of a.
|
|
# c = 0
|
|
c = a
|
|
if bool(a < b):
|
|
c = b
|
|
else:
|
|
c = a
|
|
return c
|
|
|
|
@torch.jit.script
|
|
def if_one(a, b):
|
|
c = b
|
|
if bool(a < b):
|
|
c = a
|
|
return c
|
|
|
|
@torch.jit.script
|
|
def while_test(a, i):
|
|
while bool(i < 3):
|
|
a *= a
|
|
i += 1
|
|
return a
|
|
|
|
@torch.jit.script
|
|
def while_if_test(a, b):
|
|
c = 0
|
|
while bool(a < 10):
|
|
a = a + 1
|
|
b = b + 1
|
|
if bool(a > b):
|
|
c = 2
|
|
else:
|
|
c = 3
|
|
return a + 1 + c
|
|
|
|
@torch.jit.script
|
|
def loop_use_test(y):
|
|
x = y + 1
|
|
z = x + 5
|
|
while bool(y < 8):
|
|
y += 1
|
|
z = x
|
|
return x, z
|
|
|
|
def python_fn(x):
|
|
return x + 10
|
|
|
|
@torch.jit.script
|
|
def python_op_name_test(y):
|
|
return python_fn(y)
|
|
|
|
@torch.jit.script
|
|
def empty_int_list_test(y):
|
|
x = torch.jit.annotate(List[int], [])
|
|
return x[0]
|
|
|
|
@torch.jit.script
|
|
def empty_float_list_test(y):
|
|
return [1.0, 2.0, 3.0]
|
|
|
|
@torch.jit.script
|
|
def print_weird_test(y):
|
|
print("hi\016")
|
|
|
|
self.assertExpected(if_test.graph.pretty_print(), "if_test")
|
|
self.assertExpected(if_one.graph.pretty_print(), "if_one")
|
|
self.assertExpected(while_test.graph.pretty_print(), "while_test")
|
|
self.assertExpected(while_if_test.graph.pretty_print(), "while_if_test")
|
|
self.assertExpected(loop_use_test.graph.pretty_print(), "loop_use_test")
|
|
self.assertExpected(python_op_name_test.graph.pretty_print(), "python_op_name_test")
|
|
self.assertExpected(empty_int_list_test.graph.pretty_print(), "empty_int_list_test")
|
|
self.assertExpected(empty_float_list_test.graph.pretty_print(), "empty_float_list_test")
|
|
self.assertExpected(print_weird_test.graph.pretty_print(), "print_weird_test")
|
|
|
|
def test_cu_escaped_number(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(a):
|
|
print("hi\016")
|
|
''')
|
|
self.assertExpected(cu.foo.graph.pretty_print())
|
|
|
|
def test_import_method(self):
|
|
@torch.jit.script
|
|
def foo(x, y):
|
|
return 2 * x + y
|
|
|
|
r, _ = foo._python_print()
|
|
mod = torch.jit.ScriptModule()
|
|
torch._C._jit_import_methods(mod, "op_version_set = 0\n{}".format(r), [])
|
|
self.assertExpected(mod.graph.pretty_print())
|
|
|
|
def test_function_default_values(self):
|
|
outer_var = torch.tensor(20)
|
|
outer_var2 = torch.tensor(30)
|
|
a = torch.tensor(0.5)
|
|
b = torch.tensor(10)
|
|
|
|
@torch.jit.script
|
|
def simple_fn(x, a=a, b=b, c=outer_var + outer_var2):
|
|
return x + a + b + c
|
|
|
|
self.assertEqual(
|
|
simple_fn(torch.ones(1)),
|
|
torch.ones(1) + 0.5 + 10 + (20 + 30))
|
|
self.assertEqual(
|
|
simple_fn(torch.ones(1), torch.tensor(1), torch.tensor(3), torch.tensor(4)),
|
|
torch.ones(1) + 1 + 3 + 4)
|
|
|
|
outer_c = torch.tensor(9)
|
|
outer_flag = torch.tensor(False)
|
|
|
|
@torch.jit.script
|
|
def bool_fn(x, a=outer_c, flag=outer_flag):
|
|
if bool(flag):
|
|
result = x
|
|
else:
|
|
result = x + a
|
|
return result
|
|
|
|
self.assertEqual(bool_fn(torch.ones(1)), torch.ones(1) + 9)
|
|
self.assertEqual(
|
|
bool_fn(torch.ones(1), torch.tensor(1), torch.tensor(True)),
|
|
torch.ones(1))
|
|
|
|
@torch.jit.script
|
|
def none_fn(x=None):
|
|
# type: (Optional[int]) -> Optional[int]
|
|
return x
|
|
|
|
self.assertEqual(none_fn(), None)
|
|
self.assertEqual(none_fn(1), 1)
|
|
|
|
@torch.jit.script
|
|
def hints(x, a=0.5, b=10):
|
|
# type: (Tensor, float, int) -> Tensor
|
|
return x + a + b
|
|
|
|
self.assertEqual(hints(torch.ones(1)), torch.ones(1) + 0.5 + 10)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Expected a default value"):
|
|
|
|
@torch.jit.script
|
|
def hints_bad_types(x, a=10, b=0.5): # noqa: T484
|
|
# type: (Tensor, float, int) -> Tensor
|
|
return x + a + b
|
|
|
|
def test_module_default_values(self):
|
|
four = torch.tensor(4)
|
|
|
|
class Test(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Test, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input, other=four):
|
|
return input + other
|
|
|
|
t = Test()
|
|
self.assertEqual(t(torch.ones(1)), torch.ones(1) + 4)
|
|
|
|
def test_warnings(self):
|
|
import warnings
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
if bool(x < 2):
|
|
warnings.warn("x is less than 2")
|
|
return x
|
|
|
|
FileCheck().check("aten::warn").run(str(fn.graph))
|
|
|
|
def test_no_erroneous_warnings(self):
|
|
import warnings
|
|
|
|
def fn(x):
|
|
if bool(x > 0):
|
|
warnings.warn('This should NOT be printed')
|
|
x += 1
|
|
return x
|
|
|
|
with warnings.catch_warnings(record=True) as warns:
|
|
fn_script = torch.jit.script(fn)
|
|
fn_script(torch.tensor(0))
|
|
warns = [str(w.message) for w in warns]
|
|
self.assertEqual(len(warns), 0)
|
|
|
|
@unittest.skipIf(sys.platform == "win32", "TODO: need to fix this test case for Windows")
|
|
def test_torch_load_error(self):
|
|
class J(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(J, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
return input + 100
|
|
|
|
j = J()
|
|
with tempfile.NamedTemporaryFile() as f:
|
|
j.save(f.name)
|
|
with self.assertRaisesRegex(RuntimeError, "is a zip"):
|
|
torch.load(f.name)
|
|
|
|
def test_legacy_constructors(self):
|
|
def fn(x):
|
|
return x.new_zeros(5, 5, requires_grad=False)
|
|
|
|
with warnings.catch_warnings(record=True) as warns:
|
|
torch.jit.trace(fn, (torch.ones(2, 2)))
|
|
warns = [str(w.message) for w in warns]
|
|
self.assertEqual(len(warns), 1)
|
|
self.assertEqual(warns[0], "new_zeros is a legacy constructor and is not supported in the JIT.")
|
|
|
|
def test_python_bindings(self):
|
|
lstm_cell = torch.jit.script(LSTMCellS)
|
|
|
|
def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
|
|
for i in range(x.size(0)):
|
|
hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
|
|
return hx
|
|
|
|
slstm = torch.jit.script(lstm)
|
|
|
|
inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
|
|
slstm(*inputs).sum().backward()
|
|
global fw_graph
|
|
fw_graph = slstm.graph_for(*inputs)
|
|
nodes = [n for n in fw_graph.nodes()]
|
|
tested_blocks = False
|
|
for node in nodes:
|
|
for output in [o for o in node.outputs()]:
|
|
self.assertTrue(hasattr(output, 'type'))
|
|
self.assertTrue(output.type() is not None)
|
|
for input in [i for i in node.inputs()]:
|
|
self.assertTrue(hasattr(input, 'type'))
|
|
self.assertTrue(input.type() is not None)
|
|
for block in [b for b in node.blocks()]:
|
|
tested_blocks = True
|
|
self.assertTrue(hasattr(block, 'inputs'))
|
|
self.assertTrue(hasattr(block, 'outputs'))
|
|
for output in [o for o in block.outputs()]:
|
|
self.assertTrue(hasattr(output, 'type'))
|
|
self.assertTrue(output.type() is not None)
|
|
for input in [i for i in block.inputs()]:
|
|
self.assertTrue(hasattr(input, 'type'))
|
|
self.assertTrue(input.type() is not None)
|
|
self.assertTrue(hasattr(block, 'returnNode'))
|
|
self.assertTrue(type(block.returnNode()) == torch._C.Node)
|
|
self.assertTrue(hasattr(block, 'paramNode'))
|
|
self.assertTrue(type(block.paramNode()) == torch._C.Node)
|
|
self.assertTrue(tested_blocks)
|
|
|
|
|
|
class TestBatched(TestCase):
|
|
# generate random examples and create an batchtensor with them
|
|
def rand_batch(self, *dims):
|
|
dims = [dim for dim in dims if dim != ()]
|
|
xs = [torch.rand(1, *(random.randint(1, size) if b else size for b, size in dims[1:]),
|
|
requires_grad=True) for i in range(dims[0])]
|
|
xb = BatchTensor(xs, torch.tensor([b for b, d in dims[1:]]).byte())
|
|
return xs, xb
|
|
|
|
def test_create_batchtensor(self):
|
|
# create from tensorlist
|
|
xs, batch = self.rand_batch(4, (True, 3), (False, 2), (True, 5))
|
|
self.assertEqual(xs, batch.examples())
|
|
# create from data, mask, dims
|
|
batch2 = BatchTensor(batch.get_data(), batch.get_mask(), batch.get_dims())
|
|
self.assertEqual(xs, batch2.examples())
|
|
# expand a tensor to a batchtensor given batch_size
|
|
xs = torch.rand(3, 4, 5)
|
|
batch3 = BatchTensor(xs, 2)
|
|
xs = xs.unsqueeze(0)
|
|
self.assertEqual([xs, xs], batch3.examples())
|
|
|
|
def test_batch_elementwise_unary(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def tanh(a):
|
|
return torch.tanh(a)
|
|
|
|
xs, batch = self.rand_batch(4, (True, 3), (False, 2))
|
|
res_batch = tanh(batch)
|
|
res = [torch.tanh(xs[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_elementwise_binary(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def add(a, b):
|
|
return a + b
|
|
|
|
xs, batch = self.rand_batch(4, (True, 3), (False, 2))
|
|
xs2, batch2 = xs, batch
|
|
res_batch = add(batch, batch2)
|
|
res = [torch.add(xs[j], xs2[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
# test broadcast
|
|
xs, batch = self.rand_batch(4, (False, 3), (False, 2))
|
|
b = torch.rand(3, 2)
|
|
res_batch = add(batch, b)
|
|
res = [torch.add(xs[j], b) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_mm(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def mm(a, b):
|
|
return torch.mm(a, b)
|
|
|
|
xs, batch = self.rand_batch(4, (True, 3), (False, 2))
|
|
xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
|
|
res_batch = mm(batch, batch2)
|
|
res = [torch.mm(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
# test broadcast
|
|
b = torch.rand(2, 4)
|
|
res_batch = mm(batch, b)
|
|
res = [torch.mm(xs[j].squeeze(0), b).unsqueeze(0) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_matmul(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def matmul(a, b):
|
|
return torch.matmul(a, b)
|
|
|
|
def matmul_test(xs, batch, xs2, batch2):
|
|
ys = [torch.matmul(xs[j].squeeze(0), xs2[j].squeeze(0)).unsqueeze(0) for j in range(4)]
|
|
ybs = matmul(batch, batch2)
|
|
self.assertEqual(ys, ybs.examples())
|
|
|
|
# 1 dimension * 1 dimension
|
|
xs, batch = self.rand_batch(4, (False, 2))
|
|
xs2, batch2 = self.rand_batch(4, (False, 2))
|
|
matmul_test(xs, batch, xs2, batch2)
|
|
# 1 dimension * 2 dimension
|
|
xs, batch = self.rand_batch(4, (False, 2))
|
|
xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
|
|
matmul_test(xs, batch, xs2, batch2)
|
|
# 2 dimension * 1 dimensions
|
|
xs, batch = self.rand_batch(4, (True, 3), (False, 2))
|
|
xs2, batch2 = self.rand_batch(4, (False, 2))
|
|
matmul_test(xs, batch, xs2, batch2)
|
|
# 2 dimension * 2 dimension
|
|
xs, batch = self.rand_batch(4, (True, 3), (False, 2))
|
|
xs2, batch2 = self.rand_batch(4, (False, 2), (True, 3))
|
|
matmul_test(xs, batch, xs2, batch2)
|
|
|
|
def test_batch_select(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def select(x):
|
|
return torch.select(x, 1, 0)
|
|
|
|
xs, batch = self.rand_batch(4, (True, 3), (True, 2))
|
|
res_batch = select(batch)
|
|
res = [torch.select(xs[j], 1, 0) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
xs, batch = self.rand_batch(4, (False, 3), (True, 2))
|
|
res_batch = select(batch)
|
|
res = [torch.select(xs[j], 1, 0) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_index_select(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def index_select(x, ind):
|
|
return x.index_select(1, ind)
|
|
|
|
xs, batch = self.rand_batch(4, (False, 5), (True, 2))
|
|
ind = [torch.randint(0, 4, (1,), dtype=torch.long) for i in range(4)]
|
|
ind_batch = BatchTensor(ind, torch.tensor([]).byte())
|
|
res_batch = index_select(batch, ind_batch)
|
|
res = [torch.index_select(xs[j], 1, ind[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_where(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def where(c, a, b):
|
|
return torch.where(c, a, b)
|
|
|
|
xs, batch = self.rand_batch(4, (False, 3), (False, 2))
|
|
xs2, batch2 = self.rand_batch(4, (False, 3), (False, 2))
|
|
|
|
dims = [4, (False, 3), (False, 2)]
|
|
xs_cond = [torch.rand(1, 3, 2).byte() for i in range(dims[0])]
|
|
batch_cond = BatchTensor(xs_cond, torch.tensor([b for b, d in dims[1:]]))
|
|
|
|
res_batch = where(batch_cond, batch, batch2)
|
|
res = [torch.where(xs_cond[j], xs[j], xs2[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_argmax(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def argmax(a):
|
|
return torch.argmax(a, 1)
|
|
|
|
xs, batch = self.rand_batch(4, (True, 5), (True, 6))
|
|
res_batch = argmax(batch)
|
|
res = [torch.argmax(xs[j], 1) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
@torch.jit.batch(batch_size=4)
|
|
def argmax(a):
|
|
return torch.argmax(a, 1, False)
|
|
|
|
res_batch = argmax(batch)
|
|
res = [torch.argmax(xs[j], 1, False) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_topk(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def topk(a):
|
|
return torch.topk(a, 3, 1)
|
|
|
|
xs, batch = self.rand_batch(4, (False, 5), (True, 6))
|
|
|
|
# along static dim
|
|
res_batch = topk(batch)
|
|
res = [torch.topk(xs[j], 3, 1)[0] for j in range(4)]
|
|
res_idx = [torch.topk(xs[j], 3, 1)[1] for j in range(4)]
|
|
self.assertEqual(res, res_batch[0].examples())
|
|
self.assertEqual(res_idx, res_batch[1].examples())
|
|
|
|
@torch.jit.batch(batch_size=4)
|
|
def topk(a):
|
|
return torch.topk(a, 1, 2)
|
|
|
|
# along dynamic dim
|
|
res_batch = topk(batch)
|
|
res = [torch.topk(xs[j], 1, 2)[0] for j in range(4)]
|
|
res_idx = [torch.topk(xs[j], 1, 2)[1] for j in range(4)]
|
|
self.assertEqual(res, res_batch[0].examples())
|
|
self.assertEqual(res_idx, res_batch[1].examples())
|
|
|
|
def test_batch_softmax(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def softmax(a):
|
|
return torch.softmax(a, 1)
|
|
|
|
xs, batch = self.rand_batch(4, (False, 5), (True, 6))
|
|
|
|
# along static dim
|
|
res_batch = softmax(batch)
|
|
res = [torch.softmax(xs[j], 1) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
@torch.jit.batch(batch_size=4)
|
|
def softmax(a):
|
|
return torch.softmax(a, 2)
|
|
|
|
# along dynamic dim
|
|
res_batch = softmax(batch)
|
|
res = [torch.softmax(xs[j], 2) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_view(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def view(a):
|
|
return a.view([4, -1, 3])
|
|
|
|
xs, batch = self.rand_batch(4, (True, 5), (False, 3))
|
|
res_batch = view(batch)
|
|
res = [xs[j].view([1, -1, 3]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_cat(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def cat2(a, b):
|
|
return torch.cat([a, b], 2)
|
|
|
|
xs, batch = self.rand_batch(4, (True, 5), (False, 3))
|
|
xs2, batch2 = xs, batch
|
|
res_batch = cat2(batch, batch2)
|
|
res = [torch.cat([xs[j], xs2[j]], 2) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_batch_sum(self):
|
|
@torch.jit.batch(batch_size=4)
|
|
def batch_sum(a):
|
|
return a.sum()
|
|
|
|
xs, batch = self.rand_batch(4, (True, 5), (False, 3))
|
|
res_batch = batch_sum(batch)
|
|
res = [xs[j].sum().unsqueeze(0) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
def test_if_else(self):
|
|
def single_if(a, b):
|
|
if bool(a > b):
|
|
a = a + b
|
|
else:
|
|
a = a - b
|
|
return a
|
|
|
|
batch_if = torch.jit.batch(batch_size=4)(single_if)
|
|
|
|
a, batch_a = self.rand_batch(4, ())
|
|
b, batch_b = self.rand_batch(4, ())
|
|
res_batch = batch_if(batch_a, batch_b)
|
|
res = [single_if(a[j], b[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
script_if = torch.jit.script(single_if)
|
|
torch.to_batch_graph(script_if.graph)
|
|
|
|
def test_if_else_with_scalar(self):
|
|
def single_if(a, b):
|
|
if bool(a > 0.1):
|
|
a = a + b
|
|
else:
|
|
a = a - b
|
|
return a
|
|
|
|
batch_if = torch.jit.batch(batch_size=4)(single_if)
|
|
|
|
a, batch_a = self.rand_batch(4, ())
|
|
b, batch_b = self.rand_batch(4, ())
|
|
res_batch = batch_if(batch_a, batch_b)
|
|
res = [single_if(a[j], b[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
script_if = torch.jit.script(single_if)
|
|
torch.to_batch_graph(script_if.graph)
|
|
|
|
def test_if_noelse(self):
|
|
def single_if(a, b):
|
|
if bool(a > b):
|
|
a = a + b
|
|
return a
|
|
|
|
batch_if = torch.jit.batch(batch_size=4)(single_if)
|
|
|
|
a, batch_a = self.rand_batch(4, ())
|
|
b, batch_b = self.rand_batch(4, ())
|
|
res_batch = batch_if(batch_a, batch_b)
|
|
res = [single_if(a[j], b[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
script_if = torch.jit.script(single_if)
|
|
torch.to_batch_graph(script_if.graph)
|
|
|
|
def test_if_noelse_with_scalar(self):
|
|
def single_if(a, b):
|
|
if bool(a > 0.1):
|
|
a = a + b
|
|
return a
|
|
|
|
batch_if = torch.jit.batch(batch_size=4)(single_if)
|
|
|
|
a, batch_a = self.rand_batch(4, ())
|
|
b, batch_b = self.rand_batch(4, ())
|
|
res_batch = batch_if(batch_a, batch_b)
|
|
res = [single_if(a[j], b[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
script_if = torch.jit.script(single_if)
|
|
torch.to_batch_graph(script_if.graph)
|
|
|
|
def test_while(self):
|
|
def single_while(a, b):
|
|
while bool(a > b):
|
|
a = a - b
|
|
return a
|
|
|
|
batch_while = torch.jit.batch(batch_size=4)(single_while)
|
|
|
|
a, batch_a = self.rand_batch(4, ())
|
|
b = [torch.abs(torch.rand(1)) for i in range(4)]
|
|
batch_b = BatchTensor(b, torch.tensor([]).byte())
|
|
res_batch = batch_while(batch_a, batch_b)
|
|
res = [single_while(a[j], b[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
script_while = torch.jit.script(single_while)
|
|
torch.to_batch_graph(script_while.graph)
|
|
|
|
def test_for(self):
|
|
def single_for(x, y):
|
|
for _ in range(10):
|
|
x = x + y
|
|
return x
|
|
|
|
batch_for = torch.jit.batch(batch_size=4)(single_for)
|
|
|
|
a, batch_a = self.rand_batch(4, ())
|
|
b, batch_b = self.rand_batch(4, ())
|
|
res_batch = batch_for(batch_a, batch_b)
|
|
res = [single_for(a[j], b[j]) for j in range(4)]
|
|
self.assertEqual(res, res_batch.examples())
|
|
|
|
script_for = torch.jit.script(single_for)
|
|
torch.to_batch_graph(script_for.graph)
|
|
|
|
def test_lstm(self):
|
|
def LSTM(x_all, h, c, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c):
|
|
for i in range(x_all.size(1)):
|
|
x = x_all.select(1, i)
|
|
i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
|
|
f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
|
|
o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
|
|
# activations
|
|
i_t = torch.sigmoid(i_t)
|
|
f_t = torch.sigmoid(f_t)
|
|
o_t = torch.sigmoid(o_t)
|
|
# cell computations
|
|
c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
|
|
c_t = torch.tanh(c_t)
|
|
c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
|
|
h_t = torch.mul(o_t, torch.tanh(c_t))
|
|
h = h_t
|
|
c = c_t
|
|
return h
|
|
|
|
LSTM_batch = torch.jit.batch(batch_size=4)(LSTM)
|
|
|
|
batch_size, input_size, hidden_size = 4, 3, 2
|
|
xs, batch = self.rand_batch(batch_size, (True, 4), (False, input_size))
|
|
hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
|
|
cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
|
|
|
|
# input to hidden weights
|
|
w_xi = torch.rand(input_size, hidden_size)
|
|
w_xf = torch.rand(input_size, hidden_size)
|
|
w_xo = torch.rand(input_size, hidden_size)
|
|
w_xc = torch.rand(input_size, hidden_size)
|
|
# hidden to hidden weights
|
|
w_hi = torch.rand(hidden_size, hidden_size)
|
|
w_hf = torch.rand(hidden_size, hidden_size)
|
|
w_ho = torch.rand(hidden_size, hidden_size)
|
|
w_hc = torch.rand(hidden_size, hidden_size)
|
|
# bias terms
|
|
b_i = torch.rand(hidden_size)
|
|
b_f = torch.rand(hidden_size)
|
|
b_o = torch.rand(hidden_size)
|
|
b_c = torch.rand(hidden_size)
|
|
|
|
ys = [LSTM(xs[j], hx[j], cx[j], w_xi, w_xf, w_xo, w_xc,
|
|
w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c) for j in range(batch_size)]
|
|
ybs = LSTM_batch(batch, h_batch, c_batch, w_xi, w_xf, w_xo, w_xc,
|
|
w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c)
|
|
self.assertEqual(ys, ybs.examples())
|
|
|
|
def test_greedy_search(self):
|
|
def greedy(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
|
|
b_i, b_f, b_o, b_c, w_hs, b_s, iter_num):
|
|
iter_count = torch.zeros_like(iter_num)
|
|
while bool(iter_count < iter_num):
|
|
iter_count = iter_count + 1
|
|
# LSTM Cell
|
|
i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
|
|
f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
|
|
o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
|
|
# activations
|
|
i_t = torch.sigmoid(i_t)
|
|
f_t = torch.sigmoid(f_t)
|
|
o_t = torch.sigmoid(o_t)
|
|
# cell computations
|
|
c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
|
|
c_t = torch.tanh(c_t)
|
|
c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
|
|
h_t = torch.mul(o_t, torch.tanh(c_t))
|
|
h = h_t
|
|
c = c_t
|
|
# calculate feature with max probability
|
|
s_t = torch.matmul(h_t, w_hs) + b_s
|
|
p_t = torch.softmax(s_t, 1)
|
|
i_t = torch.argmax(p_t, 1)
|
|
x = embed.index_select(1, i_t).squeeze(1)
|
|
return h
|
|
|
|
greedy_batch = torch.jit.batch(batch_size=4)(greedy)
|
|
|
|
batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
|
|
xs, batch = self.rand_batch(batch_size, (False, input_size))
|
|
hx, h_batch = self.rand_batch(batch_size, (False, hidden_size))
|
|
cx, c_batch = self.rand_batch(batch_size, (False, hidden_size))
|
|
embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
|
|
iter_num = [torch.randint(2, 5, (1,)) for i in range(batch_size)]
|
|
iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
|
|
|
|
# input to hidden weights
|
|
w_xi = torch.rand(input_size, hidden_size)
|
|
w_xf = torch.rand(input_size, hidden_size)
|
|
w_xo = torch.rand(input_size, hidden_size)
|
|
w_xc = torch.rand(input_size, hidden_size)
|
|
# hidden to hidden weights
|
|
w_hi = torch.rand(hidden_size, hidden_size)
|
|
w_hf = torch.rand(hidden_size, hidden_size)
|
|
w_ho = torch.rand(hidden_size, hidden_size)
|
|
w_hc = torch.rand(hidden_size, hidden_size)
|
|
# bias terms
|
|
b_i = torch.rand(hidden_size)
|
|
b_f = torch.rand(hidden_size)
|
|
b_o = torch.rand(hidden_size)
|
|
b_c = torch.rand(hidden_size)
|
|
# hidden to vocab weights, bias
|
|
w_hs = torch.rand(hidden_size, vocab_size)
|
|
b_s = torch.rand(vocab_size)
|
|
|
|
ys = [greedy(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc,
|
|
w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j]) for j in range(batch_size)]
|
|
ybs = greedy_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
|
|
w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch)
|
|
self.assertEqual(ys, ybs.examples())
|
|
|
|
def test_beam_search(self):
|
|
def beam(x, h, c, embed, w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
|
|
b_i, b_f, b_o, b_c, w_hs, b_s, iter_num, idx):
|
|
k = 5
|
|
vocab_size = embed.size(1)
|
|
iter_count = torch.zeros_like(iter_num)
|
|
max_len = idx.size(2)
|
|
while bool(iter_count < iter_num):
|
|
iter_count = iter_count + 1
|
|
# LSTM Cell
|
|
i_t = torch.matmul(x, w_xi) + torch.matmul(h, w_hi) + b_i
|
|
f_t = torch.matmul(x, w_xf) + torch.matmul(h, w_hf) + b_f
|
|
o_t = torch.matmul(x, w_xo) + torch.matmul(h, w_ho) + b_o
|
|
# activations
|
|
i_t = torch.sigmoid(i_t)
|
|
f_t = torch.sigmoid(f_t)
|
|
o_t = torch.sigmoid(o_t)
|
|
# cell computations
|
|
c_t = torch.matmul(x, w_xc) + torch.matmul(h, w_hc) + b_c
|
|
c_t = torch.tanh(c_t)
|
|
c_t = torch.mul(c_t, f_t) + torch.mul(i_t, c_t)
|
|
h_t = torch.mul(o_t, torch.tanh(c_t))
|
|
h = h_t
|
|
c = c_t
|
|
# calculate features with max probability
|
|
s_t = torch.matmul(h_t, w_hs) + b_s
|
|
s_t = s_t.view([1, s_t.size(1) * s_t.size(2)])
|
|
p_t = torch.softmax(s_t, 1)
|
|
prob_t, idx_t = torch.topk(p_t, k, 1)
|
|
if(int(idx_t.dim()) > 1):
|
|
idx_t_tmp = idx_t.squeeze(0)
|
|
else:
|
|
idx_t_tmp = idx_t
|
|
new_y = torch.fmod(idx_t_tmp, vocab_size)
|
|
pre_y = idx_t_tmp / vocab_size
|
|
x = embed.index_select(1, new_y)
|
|
h = h_t.index_select(1, pre_y)
|
|
c = c_t.index_select(1, pre_y)
|
|
iter = int(iter_count[0])
|
|
idx = torch.cat([idx.narrow(2, 0, iter).index_select(1, pre_y),
|
|
torch.fmod(idx_t, vocab_size).unsqueeze(-1),
|
|
idx.narrow(2, iter, max_len - iter)], 2)
|
|
idx = idx.narrow(2, 0, max_len)
|
|
return idx
|
|
|
|
beam_batch = torch.jit.batch(batch_size=4)(beam)
|
|
|
|
k = 5
|
|
batch_size, input_size, hidden_size, vocab_size = 4, 6, 8, 7
|
|
max_len = 5
|
|
xs, batch = self.rand_batch(batch_size, (False, 1), (False, input_size))
|
|
hx, h_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
|
|
cx, c_batch = self.rand_batch(batch_size, (False, 1), (False, hidden_size))
|
|
embed, embed_batch = self.rand_batch(batch_size, (False, vocab_size), (False, input_size))
|
|
iter_num = [torch.randint(2, max_len + 1, (1,)) for i in range(batch_size)]
|
|
iter_num_batch = BatchTensor(iter_num, torch.tensor([]).byte())
|
|
|
|
# input to hidden weights
|
|
w_xi = torch.rand(input_size, hidden_size)
|
|
w_xf = torch.rand(input_size, hidden_size)
|
|
w_xo = torch.rand(input_size, hidden_size)
|
|
w_xc = torch.rand(input_size, hidden_size)
|
|
# hidden to hidden weights
|
|
w_hi = torch.rand(hidden_size, hidden_size)
|
|
w_hf = torch.rand(hidden_size, hidden_size)
|
|
w_ho = torch.rand(hidden_size, hidden_size)
|
|
w_hc = torch.rand(hidden_size, hidden_size)
|
|
# bias terms
|
|
b_i = torch.rand(1, hidden_size)
|
|
b_f = torch.rand(1, hidden_size)
|
|
b_o = torch.rand(1, hidden_size)
|
|
b_c = torch.rand(1, hidden_size)
|
|
# hidden to vocab weights, bias
|
|
w_hs = torch.rand(hidden_size, vocab_size)
|
|
b_s = torch.rand(1, vocab_size)
|
|
|
|
idx_batch = torch.jit.BatchTensor(torch.zeros([batch_size, k, max_len], dtype=torch.long),
|
|
torch.zeros([batch_size, 1, max_len]).byte(),
|
|
torch.tensor([0, 1]).byte())
|
|
idx = [torch.zeros([1, k, max_len], dtype=torch.long) for _ in range(batch_size)]
|
|
|
|
ys = [beam(xs[j], hx[j], cx[j], embed[j], w_xi, w_xf, w_xo, w_xc, w_hi, w_hf, w_ho, w_hc,
|
|
b_i, b_f, b_o, b_c, w_hs, b_s, iter_num[j], idx[j]).narrow(2, 0, int(iter_num[j]))
|
|
for j in range(batch_size)]
|
|
ybs = beam_batch(batch, h_batch, c_batch, embed_batch, w_xi, w_xf, w_xo, w_xc,
|
|
w_hi, w_hf, w_ho, w_hc, b_i, b_f, b_o, b_c, w_hs, b_s, iter_num_batch, idx_batch)
|
|
self.assertEqual(ys, ybs.examples())
|
|
|
|
|
|
def execWrapper(code, glob, loc):
|
|
if PY2:
|
|
exec(code) in glob, loc
|
|
else:
|
|
exec(code, glob, loc)
|
|
|
|
|
|
class TestScript(JitTestCase):
|
|
@contextmanager
|
|
def capture_stdout(self):
|
|
# No idea how to capture stdout from C++ on Windows
|
|
if WINDOWS:
|
|
yield ['']
|
|
return
|
|
import os
|
|
import fcntl
|
|
import errno
|
|
sys.stdout.flush()
|
|
stdout_fd = os.dup(1)
|
|
r, w = os.pipe()
|
|
try:
|
|
# Override stdout with r - dup is guaranteed to return the lowest free fd
|
|
os.close(1)
|
|
os.dup(w)
|
|
|
|
captured_stdout = ['']
|
|
yield captured_stdout
|
|
sys.stdout.flush() # Make sure that Python hasn't buffered anything
|
|
|
|
# Do the ugly dance to read all the data that was written into the pipe
|
|
fcntl.fcntl(r, fcntl.F_SETFL, os.O_NONBLOCK)
|
|
total_stdout = ''
|
|
while True:
|
|
try:
|
|
total_stdout += os.read(r, 1000).decode('ascii')
|
|
except OSError as e:
|
|
if e.errno != errno.EAGAIN:
|
|
raise
|
|
break
|
|
captured_stdout[0] = total_stdout
|
|
finally:
|
|
# Revert the change, and clean up all fds
|
|
os.close(1)
|
|
os.dup(stdout_fd)
|
|
os.close(stdout_fd)
|
|
os.close(r)
|
|
os.close(w)
|
|
|
|
def checkScriptRaisesRegex(self, script, inputs, exception, regex,
|
|
optimize=True, outputs=None, capture_output=False):
|
|
"""
|
|
Checks that a given function will throw the correct exception,
|
|
when executed with normal python, the string frontend, and the AST frontend
|
|
"""
|
|
# normal python
|
|
with self.assertRaisesRegex(exception, regex):
|
|
script(*inputs)
|
|
# string frontend
|
|
with self.assertRaisesRegex(exception, regex):
|
|
source = textwrap.dedent(inspect.getsource(script))
|
|
cu = torch.jit.CompilationUnit(source, optimize)
|
|
ge = getattr(cu, script.__name__)
|
|
ge(*inputs)
|
|
# python AST frontend
|
|
with self.assertRaisesRegex(exception, regex):
|
|
ge = torch.jit.script(script, optimize)
|
|
ge(*inputs)
|
|
|
|
def test_training_param(self):
|
|
class What(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
# type: (int) -> int
|
|
if self.training:
|
|
r = x
|
|
else:
|
|
r = x + 4
|
|
# check double use of training
|
|
if self.training:
|
|
r = r + 1
|
|
return r
|
|
|
|
w = What()
|
|
self.assertEqual(4, w(3))
|
|
w.train(False)
|
|
self.assertEqual(7, w(3))
|
|
|
|
def test_jitter_bug(self):
|
|
@torch.jit.script
|
|
def fn2(input, kernel_size):
|
|
# type: (Tensor, List[int]) -> Tensor
|
|
if kernel_size[0] > 1:
|
|
_stride = [2]
|
|
else:
|
|
_stride = kernel_size
|
|
print(_stride, kernel_size)
|
|
return input
|
|
|
|
@torch.jit.script
|
|
def fn(input):
|
|
# type: (Tensor) -> Tensor
|
|
return fn2(input, [1])
|
|
|
|
def test_parser_kwargonly(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(x, *, y) -> Tuple[Tensor, Tensor]:
|
|
return x, x
|
|
def bar(x):
|
|
return foo(x, y=x)
|
|
''')
|
|
self.assertTrue('*' in cu.module._get_method('foo').pretty_print_schema())
|
|
with self.assertRaisesRegex(RuntimeError, "not provided"):
|
|
torch.jit.CompilationUnit('''
|
|
def foo(x, *, y) -> Tuple[Tensor, Tensor]:
|
|
return x, x
|
|
def bar(x):
|
|
return foo(x, x)
|
|
''')
|
|
|
|
def test_annoying_doubles(self):
|
|
mod = types.ModuleType("temp")
|
|
mod.inf = float("inf")
|
|
mod.ninf = float("-inf")
|
|
mod.nan = float("nan")
|
|
|
|
with self.disableModuleHook():
|
|
@torch.jit.script
|
|
def foo():
|
|
return math.pi, 0.1, mod.inf, mod.ninf, 2.225073858507201e-308, mod.nan
|
|
|
|
pp, table = foo._get_method('forward').python_print()
|
|
ppv = "op_version_set = 0\n{}".format(pp)
|
|
sm = torch.jit.ScriptModule()
|
|
torch._C._jit_import_methods(sm, ppv, table)
|
|
r = foo()
|
|
r2 = sm()
|
|
# use precise assert, we are checking floating point details
|
|
self.assertTrue(r[:-1] == r2[:-1])
|
|
self.assertTrue(math.isnan(r[-1]) and math.isnan(r2[-1]))
|
|
|
|
def test_type_annotate(self):
|
|
|
|
def foo(a):
|
|
return torch.jit.annotate(torch.Tensor, a)
|
|
|
|
self.checkScript(foo, (torch.rand(3),))
|
|
|
|
def bar():
|
|
a = torch.jit.annotate(List[int], [])
|
|
for _ in range(10):
|
|
a.append(4)
|
|
return a
|
|
|
|
self.checkScript(bar, ())
|
|
|
|
def baz(a):
|
|
return torch.jit.annotate(float, a)
|
|
self.checkScript(baz, (torch.rand(()),))
|
|
|
|
# test annotate none types
|
|
def annotate_none():
|
|
return torch.jit.annotate(Optional[torch.Tensor], None)
|
|
|
|
def annotate_none_no_optional():
|
|
return torch.jit.annotate(torch.Tensor, None)
|
|
|
|
self.checkScript(annotate_none, ())
|
|
self.checkScript(annotate_none_no_optional, ())
|
|
|
|
def test_robust_op_resolution(self):
|
|
neg = torch.add # misleading name to make sure we resolve by function
|
|
|
|
def stuff(x):
|
|
return neg(x, x)
|
|
|
|
a = (torch.rand(3),)
|
|
self.checkScript(stuff, a)
|
|
|
|
def test_tuple_io(self):
|
|
def stuff(x):
|
|
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
|
|
a, b = x
|
|
return b, a
|
|
|
|
a = (torch.rand(3), torch.rand(3))
|
|
self.checkScript(stuff, (a,))
|
|
|
|
def test_tuple_create_return(self):
|
|
def stuff2(x):
|
|
# type: (int) -> Tuple[Tensor, Tensor]
|
|
a = (torch.ones(x), torch.zeros(x))
|
|
return a
|
|
self.checkScript(stuff2, (3,))
|
|
|
|
def test_list_io(self):
|
|
def stuff3(x):
|
|
# type: (List[int]) -> Tuple[Tensor, List[int]]
|
|
return torch.ones(x), x
|
|
self.checkScript(stuff3, ([3, 2],))
|
|
|
|
# to avoid defining sum_list in multiple tests
|
|
def get_sum_list_fn(self):
|
|
def sum_list(a):
|
|
# type: (List[int]) -> int
|
|
sum = 0
|
|
for i in a:
|
|
sum += i
|
|
|
|
return sum
|
|
|
|
return sum_list
|
|
|
|
def test_sum_list_diff_elms(self):
|
|
self.checkScript(self.get_sum_list_fn(), ([1, 2, 3, 4, 5],))
|
|
|
|
def test_sum_list_empty(self):
|
|
self.checkScript(self.get_sum_list_fn(), ([],))
|
|
|
|
def test_sum_list_one(self):
|
|
self.checkScript(self.get_sum_list_fn(), ([1],))
|
|
|
|
def test_sum_list_literal(self):
|
|
|
|
def sum_list():
|
|
# type: () -> int
|
|
sum = 0
|
|
for i in [1, 2, 3, 4, 5]:
|
|
sum += i
|
|
|
|
return sum
|
|
|
|
self.checkScript(sum_list, ())
|
|
|
|
def test_sum_list_wrong_type(self):
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
|
|
@torch.jit.script
|
|
def sum_list(a):
|
|
# type: (int) -> int
|
|
sum = 0
|
|
for i in a: # noqa: T484
|
|
sum += i
|
|
|
|
return sum
|
|
|
|
sum_list(1)
|
|
|
|
def test_bool_list_io(self):
|
|
@torch.jit.script
|
|
def stuff4(x):
|
|
# type: (List[bool]) -> Tuple[List[bool], List[bool], List[List[bool]]]
|
|
return x, [True, False], [[True]]
|
|
|
|
li_1, li_2, li_3 = stuff4([True])
|
|
li_3 = li_3[0]
|
|
for li in [li_1, li_2, li_3]:
|
|
self.assertTrue(type(li[0]) == type(True))
|
|
|
|
def test_nested_list(self):
|
|
def foo(z):
|
|
# type: (Tuple[int, List[List[int]]]) -> int
|
|
x, y = z
|
|
return y[0][1]
|
|
self.checkScript(foo, ((1, [[1, 2], [3, 4]]),))
|
|
|
|
def test_nested_list_construct(self):
|
|
def foo():
|
|
return [[4]] + [[4, 5]]
|
|
self.checkScript(foo, ())
|
|
|
|
def test_tensor_shape(self):
|
|
x = torch.empty(34, 56, 78)
|
|
|
|
def f(x):
|
|
return x.shape
|
|
|
|
self.checkScript(f, (x,))
|
|
|
|
def test_tensor_grad(self):
|
|
x = torch.tensor(1.0, requires_grad=True)
|
|
y = torch.tensor(1.0, requires_grad=False)
|
|
|
|
def f(x):
|
|
return x.requires_grad
|
|
|
|
self.checkScript(f, (x,))
|
|
self.checkScript(f, (y,))
|
|
|
|
def test_tensor_dtype(self):
|
|
x_byte = torch.empty(34, 56, 78, dtype=torch.uint8)
|
|
x_long = torch.empty(34, 56, 78, dtype=torch.long)
|
|
x_float32 = torch.empty(34, 56, 78, dtype=torch.float32)
|
|
|
|
@torch.jit.script
|
|
def byte(x):
|
|
return x.dtype == torch.uint8
|
|
|
|
@torch.jit.script
|
|
def long(x):
|
|
return x.dtype == torch.long
|
|
|
|
@torch.jit.script
|
|
def float32(x):
|
|
return x.dtype == torch.float32
|
|
|
|
self.assertTrue(byte(x_byte))
|
|
self.assertFalse(byte(x_long))
|
|
self.assertFalse(byte(x_float32))
|
|
self.assertFalse(long(x_byte))
|
|
self.assertTrue(long(x_long))
|
|
self.assertFalse(long(x_float32))
|
|
self.assertFalse(float32(x_byte))
|
|
self.assertFalse(float32(x_long))
|
|
self.assertTrue(float32(x_float32))
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
|
|
def test_tensor_device(self):
|
|
cpu = torch.empty(34, 56, 78, device='cpu')
|
|
gpu = torch.empty(34, 56, 78, device='cuda')
|
|
|
|
@torch.jit.script
|
|
def same_device(x, y):
|
|
return x.device == y.device
|
|
|
|
self.assertTrue(same_device(cpu, cpu))
|
|
self.assertTrue(same_device(gpu, gpu))
|
|
self.assertFalse(same_device(cpu, gpu))
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
|
|
def test_tensor_to_device(self):
|
|
def to_device(x):
|
|
return x.to(device="cuda").to(device=torch.device("cpu"))
|
|
|
|
self.checkScript(to_device, (torch.ones(3, 4),))
|
|
|
|
def test_tensor_to_cpu(self):
|
|
def to_cpu(x):
|
|
return x.cpu()
|
|
|
|
x = torch.ones(3, 4)
|
|
script_fn = torch.jit.script(to_cpu)
|
|
self.assertEqual(to_cpu(x).device, script_fn(x).device)
|
|
self.checkScript(to_cpu, (x,))
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "device tests require CUDA")
|
|
def test_tensor_to_cuda(self):
|
|
def to_cuda(x):
|
|
return x.cuda()
|
|
|
|
x = torch.ones(3, 4)
|
|
script_fn = torch.jit.script(to_cuda)
|
|
self.assertEqual(to_cuda(x).device, script_fn(x).device)
|
|
self.checkScript(to_cuda, (x,))
|
|
|
|
def test_generic_list_errors(self):
|
|
with self.assertRaisesRegex(RuntimeError, "previously matched to type"):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
return [[x]] + [[1]]
|
|
|
|
def test_script_cu(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(a):
|
|
b = a
|
|
return b
|
|
''')
|
|
a = Variable(torch.rand(1))
|
|
self.assertEqual(a, cu.foo(a))
|
|
|
|
# because the compilation unit ingests python strings
|
|
# to use an escape sequence escape the backslash (\\n = \n)
|
|
def test_string_cu(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(a):
|
|
print(a, """a\\n\tb\\n""", 2, "a\
|
|
a")
|
|
return a
|
|
''')
|
|
FileCheck().check("aa").check("a\\n\\tb\\n").run(str(cu.foo.graph))
|
|
|
|
def test_string_ops(self):
|
|
def foo():
|
|
a = "a" + "b"
|
|
return a + a, "ab" == "b", "ab" != "b", "ab" == "ab", "ab" != "ab"
|
|
|
|
self.checkScript(foo, ())
|
|
|
|
def test_string_new_line(self):
|
|
with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
|
|
torch.jit.CompilationUnit('''
|
|
def test_while(a):
|
|
print("
|
|
a")
|
|
return a
|
|
''')
|
|
|
|
def test_string_single_escape(self):
|
|
with self.assertRaisesRegex(RuntimeError, "expected a valid token*"):
|
|
torch.jit.CompilationUnit('''
|
|
def test_while(a):
|
|
print("\\")
|
|
return a
|
|
''')
|
|
|
|
def test_script_annotation(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
return a + a + a
|
|
s = Variable(torch.rand(2))
|
|
self.assertEqual(s + s + s, foo(s))
|
|
|
|
def test_inf(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
return a < float('inf')
|
|
s = torch.rand(1)
|
|
self.assertTrue(foo(s))
|
|
|
|
@torch.jit.script
|
|
def bar(a):
|
|
return a > float('-inf')
|
|
s = torch.rand(1)
|
|
self.assertTrue(foo(s))
|
|
|
|
def test_add(self):
|
|
def func(a, b):
|
|
c = a + b
|
|
c += a
|
|
return c
|
|
|
|
a = torch.rand(1, requires_grad=True)
|
|
b = torch.rand(1, requires_grad=True)
|
|
self.checkScript(func, (a, b), optimize=True)
|
|
|
|
def test_mul(self):
|
|
def func(a, b):
|
|
return a * b
|
|
|
|
a = torch.rand(1, requires_grad=True)
|
|
b = torch.rand(1, requires_grad=True)
|
|
self.checkScript(func, (a, b), optimize=True)
|
|
|
|
@unittest.skipIf(not PY35, "Python 3.5 needed")
|
|
def test_matmul_py3(self):
|
|
code = dedent("""
|
|
def fn(a, b):
|
|
return a @ b
|
|
""")
|
|
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
script_path = os.path.join(tmp_dir, 'script.py')
|
|
with open(script_path, 'w') as f:
|
|
f.write(code)
|
|
fn = get_fn('test_matmul_py3', script_path)
|
|
|
|
a = torch.rand(4, 3, requires_grad=True)
|
|
b = torch.rand(3, 2, requires_grad=True)
|
|
self.checkScript(fn, (a, b), optimize=True)
|
|
|
|
def test_pow(self):
|
|
def func(a, b):
|
|
return a ** b
|
|
|
|
def func2(a, b, c, d):
|
|
return c + a ** b ** d
|
|
|
|
a = torch.rand(1, requires_grad=True)
|
|
b = torch.rand(1, requires_grad=True)
|
|
c = torch.rand(1, requires_grad=True)
|
|
d = torch.rand(1, requires_grad=True)
|
|
self.checkScript(func, (a, b), optimize=True)
|
|
self.checkScript(func2, (a, b, c, d), optimize=True)
|
|
|
|
def test_triple(self):
|
|
def func(x):
|
|
return 3. * x
|
|
|
|
x = torch.rand(1, dtype=torch.float, requires_grad=True)
|
|
self.checkScript(func, [x], optimize=True)
|
|
|
|
def test_slice(self):
|
|
def func(x):
|
|
return x[:5]
|
|
|
|
x = torch.rand(10, dtype=torch.float, requires_grad=True)
|
|
self.checkScript(func, [x], optimize=True)
|
|
|
|
def func2(x):
|
|
return x[5:]
|
|
|
|
self.checkScript(func2, [x], optimize=True)
|
|
|
|
def test_gather(self):
|
|
def func(x):
|
|
return x[0]
|
|
|
|
x = torch.rand(10, dtype=torch.float, requires_grad=True)
|
|
self.checkScript(func, [x], optimize=True)
|
|
|
|
def test_random(self):
|
|
@torch.jit.script
|
|
def f(mean, std):
|
|
return torch.normal(mean, std)
|
|
|
|
mean, std = torch.zeros(5, 5), torch.ones(5, 5)
|
|
with torch.random.fork_rng(devices=[]):
|
|
output = torch.normal(mean, std)
|
|
with torch.random.fork_rng(devices=[]):
|
|
script_output = f(mean, std)
|
|
self.assertEqual(output, script_output)
|
|
|
|
def _check_code(self, code_str, fn_name, inputs):
|
|
scope = {}
|
|
exec(code_str, globals(), scope)
|
|
cu = torch.jit.CompilationUnit(code_str)
|
|
self.assertEqual(cu.func(*inputs), scope[fn_name](*inputs))
|
|
|
|
@unittest.skipIf(not RUN_CUDA, 'no CUDA')
|
|
def test_scriptmodule_releases_tensors_cuda(self):
|
|
@torch.jit.script
|
|
def fn(x, y):
|
|
return x.sigmoid() * y.tanh()
|
|
|
|
def test(backward=False):
|
|
x = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
|
|
y = torch.randn(3, 3, dtype=torch.double, device='cuda', requires_grad=True)
|
|
out = fn(x, y)
|
|
if backward:
|
|
out.sum().backward()
|
|
|
|
with self.assertLeaksNoCudaTensors():
|
|
test()
|
|
test()
|
|
test()
|
|
|
|
with self.assertLeaksNoCudaTensors():
|
|
test(backward=True)
|
|
test(backward=True)
|
|
test(backward=True)
|
|
|
|
def test_index(self):
|
|
def consec(size, start=0):
|
|
numel = torch.tensor(size).prod().item()
|
|
return torch.arange(numel).view(size)
|
|
|
|
def check_indexing(indexing, tensor):
|
|
template = dedent("""
|
|
def func(x):
|
|
return x{}
|
|
""")
|
|
|
|
self._check_code(template.format(indexing), "func", [tensor])
|
|
|
|
def check_dynamic_indexing(indexing, tensor, value1, value2):
|
|
value1 = torch.tensor(value1)
|
|
value2 = torch.tensor(value2)
|
|
|
|
template = dedent("""
|
|
def func(x, value1, value2):
|
|
i = int(value1)
|
|
j = int(value2)
|
|
return x{}
|
|
""")
|
|
|
|
self._check_code(template.format(indexing), "func", [tensor, value1, value2])
|
|
|
|
# basic slices
|
|
check_indexing('[0]', consec((3, 3)))
|
|
check_indexing('[1]', consec((3, 3), 10))
|
|
check_indexing('[2]', consec((3, 3), 19))
|
|
check_indexing('[2]', consec((3,)))
|
|
check_indexing('[-1]', consec((3, 3), 19))
|
|
check_indexing('[0:2]', consec((3, 3, 3)))
|
|
check_indexing('[1:-1]', consec((3, 3, 3)))
|
|
check_indexing('[-3:-1]', consec((6, 3)))
|
|
check_indexing('[1:]', consec((3, 3)))
|
|
check_indexing('[:1]', consec((3, 3)))
|
|
check_indexing('[:]', consec((3, 2)))
|
|
|
|
# multi-dim: indexes
|
|
check_indexing('[0, 1]', consec((3, 3)))
|
|
check_indexing('[0, 1]', consec((3, 3, 2)))
|
|
check_indexing('[1, 0, 2]', consec((3, 3, 3)))
|
|
check_indexing('[2, -1]', consec((3, 3)))
|
|
|
|
# multi-dim: mixed slicing and indexing
|
|
check_indexing('[0, 1:2]', consec((3, 3)))
|
|
check_indexing('[0, :1]', consec((3, 3, 2)))
|
|
check_indexing('[1, 2:]', consec((3, 3, 3)))
|
|
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
|
|
check_indexing('[1:, -1, 0]', consec((3, 3, 3, 3)))
|
|
check_indexing('[-1, 2:, 1:2]', consec((3, 3, 3, 3)))
|
|
check_indexing('[-1, 1:, 0]', consec((3, 3, 3, 3)))
|
|
check_indexing('[-1, :, 0, 2]', consec((3, 3, 3, 3)))
|
|
|
|
# zero-sized slices
|
|
check_indexing('[0:0]', consec((2, 2)))
|
|
check_indexing('[0:0, 1]', consec((3, 3)))
|
|
|
|
# trivial expression usage
|
|
check_indexing('[1+1]', consec((3, 3)))
|
|
check_indexing('[1:(0 + 2)]', consec((3, 3, 3)))
|
|
|
|
# dynamic expression usage
|
|
check_dynamic_indexing("[i + j]", consec((3, 3)), 0, 1)
|
|
check_dynamic_indexing("[i:j, i]", consec((3, 3, 2)), 0, 2)
|
|
|
|
def test_tensor_item(self):
|
|
def test_scalar_to_float_coercion(x):
|
|
return x.item() == 1
|
|
|
|
self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1.0),))
|
|
self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1),))
|
|
|
|
def test_scalar_cast(x):
|
|
scalar = x.item()
|
|
return int(scalar), float(scalar)
|
|
|
|
self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1.0),))
|
|
self.checkScript(test_scalar_to_float_coercion, (torch.tensor(1),))
|
|
|
|
expected_str = r"Use int\(tensor\) or float\(tensor\) to retrieve"
|
|
with self.assertRaisesRegex(RuntimeError, expected_str):
|
|
@torch.jit.script
|
|
def int_fn(a):
|
|
# type: (int) -> int
|
|
return a
|
|
|
|
@torch.jit.script
|
|
def test_error_msg(x):
|
|
return int_fn(x.item())
|
|
|
|
def test_method_on_number(self):
|
|
def func():
|
|
c = 1
|
|
return c.add(1)
|
|
with self.assertRaisesRegex(RuntimeError, 'Cannot call methods on numbers'):
|
|
torch.jit.script(func)
|
|
|
|
# testing implicit conversion of tensors to scalars to match function arguments
|
|
def test_scalar_to_num_conversions(self):
|
|
@torch.jit.script
|
|
def multiple_defs(x):
|
|
c = 1
|
|
x = x + c
|
|
return x
|
|
|
|
self.assertTrue("ImplicitTensorToNum" not in str(multiple_defs.graph))
|
|
|
|
@torch.jit.script
|
|
def tensor_to_int_script(x, tensor):
|
|
return x.unsqueeze(tensor)
|
|
|
|
def tensor_to_int(x, tensor):
|
|
return x.unsqueeze(tensor)
|
|
|
|
@torch.jit.script
|
|
def tensor_to_float_script(x, tensor):
|
|
return x.addcmul(tensor, tensor, value=tensor)
|
|
|
|
def tensor_to_float(x, tensor):
|
|
return x.addcmul(tensor, tensor, value=tensor)
|
|
|
|
x = torch.zeros(10)
|
|
# float tensor, float tensor with grad, int tensor (can't set grad on int tensor)
|
|
tensors = [torch.tensor(1.1),
|
|
torch.tensor(1.1, requires_grad=True),
|
|
torch.tensor(0),
|
|
torch.tensor([2])]
|
|
|
|
script_funs = [tensor_to_int_script, tensor_to_float_script]
|
|
funs = [tensor_to_int, tensor_to_float]
|
|
|
|
# return the result, or whether exception was thrown
|
|
def test_func(func, x, tensor):
|
|
try:
|
|
result = func(x, tensor)
|
|
except RuntimeError as e:
|
|
result = True
|
|
except TypeError as e:
|
|
result = True
|
|
return result
|
|
|
|
# assert result or exception equal for each (function, inputs)
|
|
for tensor in tensors:
|
|
for i in range(len(script_funs)):
|
|
self.assertEqual(test_func(script_funs[i], x, tensor), test_func(funs[i], x, tensor))
|
|
|
|
def test_tuple_to_opt_list(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
# type: (Optional[List[int]]) -> int
|
|
return 1
|
|
|
|
@torch.jit.script
|
|
def tuple_call():
|
|
return foo((1, 2))
|
|
|
|
def test_advancedindex(self):
|
|
def consec(size, start=0):
|
|
numel = torch.tensor(size).prod().item()
|
|
return torch.arange(numel).view(size)
|
|
|
|
def check_indexing(indexing, tensor, **kwargs):
|
|
indices_dict = kwargs
|
|
|
|
template = dedent("""
|
|
def func(x{formals}):
|
|
return x{expr}
|
|
""")
|
|
|
|
formals = []
|
|
values = []
|
|
for formal, value in indices_dict.items():
|
|
formals.append(formal)
|
|
values.append(value)
|
|
|
|
formals = ''.join(map(', {}'.format, formals))
|
|
inputs = [tensor] + values
|
|
self._check_code(template.format(formals=formals, expr=indexing),
|
|
"func", inputs)
|
|
|
|
# Indexing with tensor (basic)
|
|
check_indexing('[i]', consec((3, 3)), i=torch.tensor([0]))
|
|
check_indexing('[i]', consec((3, 3)), i=torch.tensor(1))
|
|
check_indexing('[i]', consec((3, 3)), i=torch.tensor([-2]))
|
|
check_indexing('[i]', consec((3, 3), 2), i=torch.tensor([0, 0]))
|
|
check_indexing('[i]', consec((3, 3, 2, 2)), i=torch.tensor([0, -2, 1]))
|
|
|
|
# NB: indexing with tensors and indexing with sequences can be implemented
|
|
# in a very similar way (sequences are converted to tensors), so only one
|
|
# case needs to be tested extensively.
|
|
# XXX: When we can index with sequences, replace these cases with
|
|
# sequence indexing expressions; those are much easier to read.
|
|
|
|
# Misc sequence advanced indexing
|
|
inp = consec((4, 8, 5))
|
|
to_check = [
|
|
# [[0, 2], [1, 3]]
|
|
['[i, j]', {'i': [0, 2], 'j': [1, 3]}],
|
|
# [[0, 2], [1, 3], [1, 1]]
|
|
['[i, j, k]', {'i': [0, 2], 'j': [1, 3], 'k': [1, 1]}],
|
|
# [[0, 2], 1, [1, 1]]
|
|
['[i, j, k]', {'i': [0, 2], 'j': 1, 'k': [1, 1]}],
|
|
# [:, :, [0, 3, 4]]
|
|
['[:, :, i]', {'i': [0, 3, 4]}],
|
|
# [:, [2, 4, 5, 7], 2:4]
|
|
['[:, i, 2:4]', {'i': [0, 2, 3]}],
|
|
# [[2, 3], :, :]
|
|
['[i, :, :]', {'i': [2, 3]}],
|
|
# [:, [0, 2, 3], [1, 3, 4]]
|
|
['[:, i, j]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
|
|
# [:, [0], [1, 2, 4]]
|
|
['[:, i, j]', {'i': [0], 'j': [1, 2, 4]}],
|
|
# [:, [0, 1, 3], [4]]
|
|
['[:, i, j]', {'i': [0, 1, 3], 'j': [4]}],
|
|
# [:, [[0, 1], [1, 0]], [[2, 3]]]
|
|
['[:, i, j]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
|
|
# [:, [[0, 1], [2, 3]], [[0]]]
|
|
['[:, i, j]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
|
|
# [:, [[5, 6]], [[0, 3], [4, 4]]]
|
|
['[:, i, j]', {'i': [[5, 6]], 'j': [[0, 3], [4, 4]]}],
|
|
# [[0, 2, 3], [1, 3, 4], :]
|
|
['[i, j, :]', {'i': [0, 2, 3], 'j': [1, 3, 4]}],
|
|
# [0, [1, 2, 4], :]
|
|
['[i, j, :]', {'i': 0, 'j': [1, 2, 4]}],
|
|
# [[0, 1, 3], 4, :]
|
|
['[i, j, :]', {'i': [0, 1, 3], 'j': 4}],
|
|
# [[[0, 1], [1, 0]], [[2, 1], [3, 5]], :]
|
|
['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 1], [3, 5]]}],
|
|
# [[[0, 1], [1, 0]], [[2, 3]], :]
|
|
['[i, j, :]', {'i': [[0, 1], [1, 0]], 'j': [[2, 3]]}],
|
|
# [[[0, 1], [2, 3]], [[0]], :]
|
|
['[i, j, :]', {'i': [[0, 1], [2, 3]], 'j': [[0]]}],
|
|
# [[[2, 1]], [[0, 3], [4, 4]], :]
|
|
['[i, j, :]', {'i': [[2, 1]], 'j': [[0, 3], [4, 4]]}],
|
|
# [[[2]], [[0, 3], [4, 1]], 0:2]
|
|
['[i, j, 0:2]', {'i': [[2]], 'j': [[0, 3], [4, 1]]}],
|
|
]
|
|
|
|
for expr, argdict in to_check:
|
|
tensordict = {k: torch.tensor(v) for (k, v) in argdict.items()}
|
|
check_indexing(expr, inp, **tensordict)
|
|
|
|
def test_keyword(self):
|
|
@torch.jit.script
|
|
def func(x):
|
|
return torch.sum(x, dim=0)
|
|
|
|
x = torch.rand(10, dtype=torch.float, requires_grad=True)
|
|
y = func(x)
|
|
y2 = torch.sum(x, dim=0)
|
|
self.assertEqual(y, y2)
|
|
|
|
def test_constant_pooling_none(self):
|
|
@torch.jit.script
|
|
def typed_nones(a=None, b=None, c=None):
|
|
# type: (Optional[int], Optional[bool], Optional[Tensor]) -> Tuple[Optional[int], Optional[bool], Optional[Tensor]] # noqa
|
|
return a, b, c
|
|
|
|
@torch.jit.script
|
|
def test(a):
|
|
# type: (bool) -> None
|
|
if a:
|
|
print(typed_nones())
|
|
else:
|
|
print(typed_nones())
|
|
|
|
graph_str = str(test.graph)
|
|
self.assertTrue(graph_str.count("bool? = prim::Constant") == 1)
|
|
self.assertTrue(graph_str.count("int? = prim::Constant") == 1)
|
|
self.assertTrue(graph_str.count("None = prim::Constant") == 1)
|
|
|
|
def test_literal(self):
|
|
def func1(a, b):
|
|
c = a, b
|
|
d, e = c
|
|
return d + e
|
|
|
|
def func2(a, b):
|
|
c = a, (a, b)
|
|
d, e = c
|
|
f, g = e
|
|
return d + f + g
|
|
|
|
def func3(a, b):
|
|
# type: (float, float) -> float
|
|
c = 0., (0., 0.)
|
|
x = True
|
|
while x:
|
|
x = False
|
|
c = a, (a, b)
|
|
d, e = c
|
|
f, g = e
|
|
return d + f + g
|
|
|
|
a = torch.rand(1, requires_grad=True)
|
|
b = torch.rand(1, requires_grad=True)
|
|
self.checkScript(func1, (a, b), optimize=True)
|
|
self.checkScript(func2, (a, b), optimize=True)
|
|
self.checkScript(func3, (a.item(), b.item()), optimize=True)
|
|
|
|
def test_expand(self):
|
|
@torch.jit.script
|
|
def func(x, y):
|
|
return x + y
|
|
|
|
x = torch.rand(2, 3, dtype=torch.float, requires_grad=True)
|
|
y = torch.rand(3, dtype=torch.float, requires_grad=True)
|
|
out = func(x, y)
|
|
self.assertEqual(func(x, y), x + y)
|
|
|
|
grad = torch.randn(2, 3, dtype=torch.float)
|
|
out.backward(grad)
|
|
self.assertEqual(x.grad, grad)
|
|
self.assertEqual(y.grad, grad.sum(dim=0))
|
|
|
|
def test_sum(self):
|
|
@torch.jit.script
|
|
def func(x):
|
|
return x.sum(dim=[4])
|
|
|
|
@torch.jit.script
|
|
def func2(x):
|
|
return x.sum(dim=4)
|
|
|
|
# test that shape analysis is written correctly for sum with IntArrayRef[1] dim argument
|
|
self.run_pass('constant_propagation', func.graph)
|
|
self.run_pass('constant_propagation', func2.graph)
|
|
torch._C._jit_pass_shape_analysis(
|
|
func.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
|
|
torch._C._jit_pass_shape_analysis(
|
|
func2.graph, (torch.zeros(1, 1, 1, 1, 4),), False)
|
|
self.assertTrue(func.graph.findNode("aten::sum").output().type().kind()
|
|
== "DimensionedTensorType")
|
|
self.assertTrue(func2.graph.findNode("aten::sum").output().type().kind()
|
|
== "DimensionedTensorType")
|
|
|
|
def test_cat(self):
|
|
@torch.jit.script
|
|
def func(x):
|
|
return torch.cat((x, x), dim=0)
|
|
|
|
x = torch.rand(10, dtype=torch.float, requires_grad=True)
|
|
self.assertEqual(func(x), torch.cat((x, x), dim=0))
|
|
|
|
@torch.jit.script
|
|
def func2(x, y):
|
|
return torch.cat((x, x), y)
|
|
|
|
x = torch.rand([2, 2])
|
|
y = torch.tensor(1)
|
|
self.assertEqual(func2(x, y), torch.cat((x, x), y))
|
|
|
|
def test_cat_lifts(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
return torch.cat([x, x], dim=1)
|
|
|
|
@torch.jit.script
|
|
def foo2(x):
|
|
return torch.cat([], dim=1)
|
|
|
|
@torch.jit.script
|
|
def foo3(x):
|
|
return torch.cat([x], dim=1)
|
|
|
|
for g in [foo.graph, foo2.graph, foo3.graph]:
|
|
FileCheck().check("int =").check("ListConstruct").check("aten::cat").run(str(g))
|
|
|
|
def test_list_literal(self):
|
|
def reassign():
|
|
x = [1]
|
|
if True:
|
|
x = [2, 3]
|
|
return
|
|
self.checkScript(reassign, (), optimize=False)
|
|
|
|
def reassign_arity_change():
|
|
x = [1]
|
|
if True:
|
|
x = [1, 2, 3]
|
|
return
|
|
self.checkScript(reassign_arity_change, (), optimize=False)
|
|
|
|
def reassign_from_empty_literal():
|
|
x = []
|
|
if True:
|
|
x = [1, 2, 3]
|
|
return
|
|
with self.assertRaisesRegex(RuntimeError, r"previously has type Tensor\[\]"):
|
|
self.checkScript(reassign_from_empty_literal, (), optimize=False)
|
|
|
|
def reassign_from_empty_builtin():
|
|
x = torch.jit.annotate(List[int], [])
|
|
if True:
|
|
x = [1, 2, 3]
|
|
y = torch.jit.annotate(List[float], [])
|
|
if True:
|
|
y = [1.0, 2.0, 3.0]
|
|
z = []
|
|
if True:
|
|
z = [torch.randn([1])]
|
|
return
|
|
self.checkScript(reassign_from_empty_builtin, (), optimize=False)
|
|
|
|
def reassign_bad_type():
|
|
x = [1]
|
|
if True:
|
|
x = [1.0]
|
|
return
|
|
with self.assertRaisesRegex(RuntimeError, "previously has type"):
|
|
self.checkScript(reassign_bad_type, (), optimize=False)
|
|
|
|
def reassign_nested():
|
|
x = torch.jit.annotate(List[int], [])
|
|
if True:
|
|
x = [1, 2, 3]
|
|
if True:
|
|
x = [1.0]
|
|
return
|
|
with self.assertRaisesRegex(RuntimeError, "previously has type"):
|
|
self.checkScript(reassign_nested, (), optimize=False)
|
|
|
|
def test_list_gather(self):
|
|
def index():
|
|
a = [1, 2, 3]
|
|
return a[1]
|
|
|
|
self.checkScript(index, ())
|
|
|
|
def negative_index():
|
|
a = [1, 2, 3]
|
|
return a[-1]
|
|
|
|
self.checkScript(negative_index, ())
|
|
|
|
def bad_index():
|
|
a = [1, 2, 3]
|
|
return a[4]
|
|
|
|
self.checkScriptRaisesRegex(bad_index, (), IndexError,
|
|
"list index out of range")
|
|
|
|
def bad_negative_index():
|
|
a = [1, 2, 3]
|
|
return a[-5]
|
|
|
|
self.checkScriptRaisesRegex(bad_negative_index, (), IndexError,
|
|
"list index out of range")
|
|
|
|
def test_tensor_len(self):
|
|
def func(x):
|
|
return len(x)
|
|
|
|
self.checkScript(func, [torch.ones(4, 5, 6)])
|
|
|
|
def test_list_len(self):
|
|
def func():
|
|
a = [1, 2, 3]
|
|
return len(a) == 3
|
|
|
|
self.checkScript(func, ())
|
|
|
|
def func2():
|
|
a = []
|
|
return len(a) == 0
|
|
|
|
self.checkScript(func2, ())
|
|
|
|
def test_list_ops(self):
|
|
def test_equality():
|
|
a = [1, 2, 3]
|
|
b = [1, 2, 3]
|
|
return a == b
|
|
|
|
self.checkScript(test_equality, (), optimize=True)
|
|
|
|
def test_inequality():
|
|
a = [1, 2, 3]
|
|
b = [1, 2, 3]
|
|
return a != b
|
|
|
|
self.checkScript(test_equality, (), optimize=True)
|
|
|
|
def test_non_equality():
|
|
a = [1, 2, 3]
|
|
b = [3]
|
|
return a == b
|
|
|
|
self.checkScript(test_non_equality, (), optimize=True)
|
|
|
|
def test_non_inequality():
|
|
a = [1, 2, 3]
|
|
b = [3]
|
|
return a != b
|
|
|
|
self.checkScript(test_non_equality, (), optimize=True)
|
|
|
|
def test_list_equality_as_cond():
|
|
a = [1, 2, 3]
|
|
b = [3]
|
|
if a == b:
|
|
c = 1
|
|
else:
|
|
c = 2
|
|
return c
|
|
|
|
self.checkScript(test_list_equality_as_cond, (), optimize=True)
|
|
|
|
def test_list_add():
|
|
a = [1, 2, 3]
|
|
b = [2]
|
|
c = a + b
|
|
return c == [1, 2, 3, 2]
|
|
|
|
self.checkScript(test_list_add, (), optimize=True)
|
|
|
|
def test_list_add_empty():
|
|
a = [1, 2, 3]
|
|
b = torch.jit.annotate(List[int], [])
|
|
c = a + b
|
|
return c == [1, 2, 3]
|
|
|
|
self.checkScript(test_list_add_empty, (), optimize=True)
|
|
|
|
def test_tensor_list_equality():
|
|
t1 = torch.ones([1, 1])
|
|
t2 = torch.ones([1, 1])
|
|
x = [t1, t2]
|
|
y = [t2, t1]
|
|
return x == y
|
|
|
|
self.checkScript(test_tensor_list_equality, (), optimize=True)
|
|
|
|
def test_invalid_list_equality():
|
|
t1 = torch.ones([2, 2])
|
|
t2 = torch.ones([2, 2])
|
|
x = [t1, t2]
|
|
y = [t2, t1]
|
|
# will throw since the tensors have more than one element
|
|
return x == y
|
|
|
|
self.checkScriptRaisesRegex(
|
|
test_invalid_list_equality,
|
|
(),
|
|
RuntimeError,
|
|
"bool value of Tensor")
|
|
|
|
def test_list_slice(self):
|
|
def test_regular_slice():
|
|
a = [0, 1, 2, 3, 4]
|
|
return a[2:3] == [2]
|
|
self.checkScript(test_regular_slice, ())
|
|
|
|
def test_open_ended_slice():
|
|
a = [0, 1, 2, 3, 4]
|
|
return a[2:] == [2, 3, 4]
|
|
self.checkScript(test_open_ended_slice, ())
|
|
|
|
def test_open_ended_slice2():
|
|
a = [0, 1, 2, 3, 4]
|
|
return a[:2] == [0, 1]
|
|
self.checkScript(test_open_ended_slice2, ())
|
|
|
|
def test_negative_slice():
|
|
a = [0, 1, 2, 3, 4]
|
|
return a[:-1] == [0, 1, 2, 3]
|
|
self.checkScript(test_negative_slice, ())
|
|
|
|
def test_negative_slice2():
|
|
a = [0, 1, 2, 3, 4]
|
|
return a[-3:-1] == [2, 3]
|
|
self.checkScript(test_negative_slice2, ())
|
|
|
|
def test_backward_slice():
|
|
a = [0, 1, 2, 3, 4]
|
|
return a[3:2] == torch.jit.annotate(List[int], [])
|
|
self.checkScript(test_backward_slice, ())
|
|
|
|
def test_over_slice():
|
|
a = [0, 1, 2, 3, 4]
|
|
return a[3:10] == [3, 4]
|
|
self.checkScript(test_backward_slice, ())
|
|
|
|
def test_mutable_list_append(self):
|
|
def test_append():
|
|
a = [0, 1]
|
|
a.append(2)
|
|
a.append(3)
|
|
return a == [0, 1, 2, 3]
|
|
self.checkScript(test_append, ())
|
|
|
|
def test_comprehensions_basic(self):
|
|
def comp(l):
|
|
# type: (List[int]) -> List[int]
|
|
|
|
n = [x * 3 for x in l]
|
|
return n
|
|
|
|
comp([1, 2, 3])
|
|
self.checkScript(comp, ([1, 2, 3],))
|
|
|
|
def test_comprehensions_basic_float(self):
|
|
def comp(l):
|
|
# type: (List[float]) -> List[float]
|
|
|
|
n = [x * 3 for x in l]
|
|
return n
|
|
|
|
self.checkScript(comp, ([1.0, 2.0, 3.0],))
|
|
|
|
def test_comprehensions_two_comps(self):
|
|
@torch.jit.script
|
|
def comp(l1, l2):
|
|
# type: (List[int], List[int]) -> List[int]
|
|
|
|
n = [x * 3 for x in l1]
|
|
n2 = [x + 2 for x in l2]
|
|
return n + n2
|
|
|
|
self.assertEqual(comp([1, 2, 3], [4, 5]), [3, 6, 9, 6, 7])
|
|
|
|
def test_comprehensions_wrong_expr_type(self):
|
|
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
|
|
@torch.jit.script
|
|
def comp(l):
|
|
# type: (List[int]) -> List[float]
|
|
|
|
n = [float(x) for x in l]
|
|
return n
|
|
|
|
comp([1, 2, 3])
|
|
|
|
def test_mutable_list_append_2(self):
|
|
def test_append_2():
|
|
a = [0, 1]
|
|
a.append(2)
|
|
a = [1]
|
|
a.append(4)
|
|
return a == [1, 4]
|
|
self.checkScript(test_append_2, ())
|
|
|
|
def test_mutable_list_append_if(self):
|
|
def test_append_if():
|
|
a = [1]
|
|
if True:
|
|
a.append(4)
|
|
return a == [1, 4]
|
|
self.checkScript(test_append_if, ())
|
|
|
|
def test_mutable_list_append_if_else(self):
|
|
def test_append_if_else():
|
|
a = [1]
|
|
if False:
|
|
a.append(4)
|
|
else:
|
|
a.append(10)
|
|
return a == [1, 10]
|
|
self.checkScript(test_append_if_else, ())
|
|
|
|
def test_mutable_list_append_loop(self):
|
|
def test_append_loop():
|
|
a = torch.jit.annotate(List[int], [])
|
|
for i in range(5):
|
|
a.append(i)
|
|
|
|
return a == [0, 1, 2, 3, 4]
|
|
self.checkScript(test_append_loop, ())
|
|
|
|
def test_mutable_list_append_loop_if(self):
|
|
def test_append_loop_if():
|
|
a = torch.jit.annotate(List[int], [])
|
|
for i in range(5):
|
|
if i > 3:
|
|
a.append(i)
|
|
else:
|
|
a.append(0)
|
|
|
|
return a == [0, 0, 0, 0, 4]
|
|
self.checkScript(test_append_loop_if, ())
|
|
|
|
def test_mutable_list_nested_loop(self):
|
|
def test_nested_loop():
|
|
a = torch.jit.annotate(List[int], [])
|
|
for i in range(2):
|
|
for j in range(2):
|
|
a.append(i + j)
|
|
|
|
return a == [0, 1, 1, 2]
|
|
self.checkScript(test_nested_loop, ())
|
|
|
|
def test_mutable_list_function_inline(self):
|
|
@torch.jit.script
|
|
def bar(y):
|
|
# type: (List[int]) -> None
|
|
y.append(4)
|
|
|
|
@torch.jit.script
|
|
def foo():
|
|
x = [1, 2, 3]
|
|
bar(x)
|
|
return x
|
|
|
|
self.assertEqual(foo(), [1, 2, 3, 4])
|
|
|
|
def test_mutable_list_reverse_empty(self):
|
|
def test_reverse_empty():
|
|
a = []
|
|
a.reverse()
|
|
|
|
return a == []
|
|
self.checkScript(test_reverse_empty, ())
|
|
|
|
def test_mutable_list_reverse(self):
|
|
def test_reverse():
|
|
a = [1, 2, 3, 4]
|
|
a.reverse()
|
|
|
|
return a == [4, 3, 2, 1]
|
|
self.checkScript(test_reverse, ())
|
|
|
|
def test_mutable_tensor_list_reverse(self):
|
|
def test_tensor_reverse():
|
|
a = [torch.tensor(1), torch.tensor(2)]
|
|
a.reverse()
|
|
|
|
return a == [torch.tensor(2), torch.tensor(1)]
|
|
self.checkScript(test_tensor_reverse, ())
|
|
|
|
def test_mutable_list_pop_empty(self):
|
|
@torch.jit.script
|
|
def test_pop_empty():
|
|
a = torch.jit.annotate(List[int], [])
|
|
return a.pop()
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "pop from empty list"):
|
|
test_pop_empty()
|
|
|
|
def test_mutable_list_pop(self):
|
|
def test_pop():
|
|
a = [1, 2, 3, 4]
|
|
b = a.pop()
|
|
|
|
return b == 4
|
|
|
|
self.checkScript(test_pop, ())
|
|
|
|
def test_mutable_list_pop2(self):
|
|
def test_pop2():
|
|
a = [1, 2, 3, 4]
|
|
b = a.pop()
|
|
|
|
return len(a) == 3
|
|
|
|
self.checkScript(test_pop2, ())
|
|
|
|
def test_mutable_list_pop_at(self):
|
|
def test_pop_at():
|
|
a = [1, 2, 3, 4]
|
|
b = a.pop(1)
|
|
|
|
return b == 2
|
|
|
|
self.checkScript(test_pop_at, ())
|
|
|
|
def test_mutable_list_pop_at2(self):
|
|
def test_pop_at2():
|
|
a = [1, 2, 3, 4]
|
|
b = a.pop(1)
|
|
|
|
return len(a) == 3
|
|
|
|
self.checkScript(test_pop_at2, ())
|
|
|
|
def test_mutable_list_pop_at_negative(self):
|
|
def test_pop_at_negative():
|
|
a = [1, 2, 3, 4]
|
|
b = a.pop(-2)
|
|
|
|
return b == 3
|
|
|
|
self.checkScript(test_pop_at_negative, ())
|
|
|
|
def test_mutable_list_pop_at_negative2(self):
|
|
def test_pop_at_negative2():
|
|
a = [1, 2, 3, 4]
|
|
b = a.pop(-2)
|
|
|
|
return len(a) == 3
|
|
|
|
self.checkScript(test_pop_at_negative2, ())
|
|
|
|
def test_mutable_list_pop_slice(self):
|
|
def test_pop_slice():
|
|
a = [1, 2, 3, 4]
|
|
b = [1, 2, 3, 4]
|
|
|
|
a.pop()
|
|
b = b[:-1]
|
|
|
|
return a == b
|
|
|
|
self.checkScript(test_pop_slice, ())
|
|
|
|
@unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3")
|
|
def test_mutable_list_clear_empty(self):
|
|
def test_clear_empty():
|
|
a = torch.jit.annotate(List[int], [])
|
|
a.clear()
|
|
|
|
return len(a) == 0
|
|
self.checkScript(test_clear_empty, ())
|
|
|
|
@unittest.skipIf(sys.version_info < (3, 3), "clear not supported in version < 3.3")
|
|
def test_mutable_list_clear(self):
|
|
def test_clear():
|
|
a = [1, 2, 3, 4]
|
|
a.clear()
|
|
|
|
return len(a) == 0
|
|
self.checkScript(test_clear, ())
|
|
|
|
def test_mutable_list_insert(self):
|
|
def test_list_insert():
|
|
a = [1, 2, 3, 4]
|
|
a.insert(2, 5)
|
|
|
|
return a == [1, 2, 5, 3, 4]
|
|
self.checkScript(test_list_insert, ())
|
|
|
|
def test_mutable_list_insert_negative(self):
|
|
def test_list_insert_negative():
|
|
a = [1, 2, 3, 4]
|
|
a.insert(-1, 5)
|
|
|
|
return a == [1, 2, 3, 5, 4]
|
|
self.checkScript(test_list_insert_negative, ())
|
|
|
|
def test_mutable_list_insert_neg_out_of_bounds(self):
|
|
def test_list_insert_neg_out_of_bounds():
|
|
a = [1, 2, 3, 4]
|
|
a.insert(-10, 5)
|
|
|
|
return a == [5, 1, 2, 3, 4]
|
|
self.checkScript(test_list_insert_neg_out_of_bounds, ())
|
|
|
|
def test_mutable_list_insert_out_of_bounds(self):
|
|
def test_list_insert_out_of_bounds():
|
|
a = [1, 2, 3, 4]
|
|
a.insert(10, 5)
|
|
|
|
return a == [1, 2, 3, 4, 5]
|
|
self.checkScript(test_list_insert_out_of_bounds, ())
|
|
|
|
def test_mutable_list_remove_not_existing(self):
|
|
@torch.jit.script
|
|
def test_list_remove_not_existing():
|
|
a = [1, 2, 3, 4]
|
|
a.remove(5)
|
|
|
|
return a
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "x not in list"):
|
|
test_list_remove_not_existing()
|
|
|
|
def test_mutable_list_remove(self):
|
|
def test_list_remove():
|
|
a = [1, 2, 3, 4]
|
|
a.remove(3)
|
|
|
|
return a == [1, 2, 4]
|
|
self.checkScript(test_list_remove, ())
|
|
|
|
def test_list_index_not_existing(self):
|
|
@torch.jit.script
|
|
def list_index_not_existing():
|
|
a = [4, 1, 3, 2]
|
|
i = a.index(5)
|
|
|
|
return i
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "'5' is not in list"):
|
|
list_index_not_existing()
|
|
|
|
def test_list_index(self):
|
|
def list_index():
|
|
a = [4, 1, 3, 2]
|
|
i = a.index(3)
|
|
|
|
return i == 2
|
|
self.checkScript(list_index, ())
|
|
|
|
def test_tensor_list_index(self):
|
|
def tensor_list_index():
|
|
a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
|
|
i = a.index(torch.tensor(3))
|
|
|
|
return i == 2
|
|
self.checkScript(tensor_list_index, ())
|
|
|
|
def test_tensor_list_index_not_existing(self):
|
|
@torch.jit.script
|
|
def tensor_list_index_not_existing():
|
|
a = [torch.tensor(4), torch.tensor(1), torch.tensor(3), torch.tensor(2)]
|
|
i = a.index(torch.tensor(5))
|
|
|
|
return i
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "is not in list"):
|
|
tensor_list_index_not_existing()
|
|
|
|
def test_list_count(self):
|
|
def list_count():
|
|
a = [4, 1, 4, 2, 4]
|
|
i = a.count(4)
|
|
|
|
return i == 3
|
|
self.checkScript(list_count, ())
|
|
|
|
def test_list_count_not_existing(self):
|
|
def list_count_not_existing():
|
|
a = [4, 1, 4, 2, 4]
|
|
i = a.count(5)
|
|
|
|
return i == 0
|
|
self.checkScript(list_count_not_existing, ())
|
|
|
|
def test_tensor_list_count(self):
|
|
def tensor_list_count():
|
|
a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
|
|
i = a.count(torch.tensor(4))
|
|
|
|
return i == 3
|
|
self.checkScript(tensor_list_count, ())
|
|
|
|
def test_tensor_list_count_not_existing(self):
|
|
def tensor_list_count_not_existing():
|
|
a = [torch.tensor(4), torch.tensor(1), torch.tensor(4), torch.tensor(4)]
|
|
i = a.count(torch.tensor(5))
|
|
|
|
return i == 0
|
|
self.checkScript(tensor_list_count_not_existing, ())
|
|
|
|
def test_mutable_list_remove_tensor(self):
|
|
def test_list_remove_tensor():
|
|
a = [torch.ones(1), torch.zeros(1), torch.ones(2)]
|
|
a.remove(torch.zeros(1))
|
|
|
|
return len(a) == 2
|
|
self.checkScript(test_list_remove_tensor, ())
|
|
|
|
def test_mutable_list_remove2(self):
|
|
def test_list_remove2():
|
|
a = [1]
|
|
a.remove(1)
|
|
|
|
return len(a) == 0
|
|
self.checkScript(test_list_remove2, ())
|
|
|
|
def test_extend_list_mutable(self):
|
|
@torch.jit.script
|
|
def extend_list(a, b):
|
|
# type: (List[Tensor], List[Tensor]) -> List[Tensor]
|
|
|
|
a.extend(b)
|
|
return a
|
|
|
|
for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
|
|
for r in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
|
|
self.assertEqual(extend_list(l, r), l + r)
|
|
|
|
def test_extend_list_immutable(self):
|
|
@torch.jit.script
|
|
def extend_list(a, b):
|
|
# type: (List[int], List[int]) -> List[int]
|
|
|
|
a.extend(b)
|
|
return a
|
|
|
|
for l in [[], [1], [1, 2, 3]]:
|
|
for r in [[], [1], [1, 2, 3]]:
|
|
self.assertEqual(extend_list(l, r), l + r)
|
|
|
|
def test_copy_list_mutable(self):
|
|
@torch.jit.script
|
|
def copy_list(a):
|
|
# type: (List[Tensor]) -> List[Tensor]
|
|
return a.copy()
|
|
|
|
for l in [[], [torch.rand(2)], [torch.rand(2), torch.rand(2), torch.rand(2)]]:
|
|
self.assertEqual(copy_list(l), l)
|
|
|
|
def test_copy_list_immutable(self):
|
|
@torch.jit.script
|
|
def copy_list(a):
|
|
# type: (List[int]) -> List[int]
|
|
return a.copy()
|
|
|
|
for l in [[], [1], [1, 2, 3]]:
|
|
self.assertEqual(copy_list(l), l)
|
|
|
|
def test_func_call(self):
|
|
script = '''
|
|
def add(a, b):
|
|
return a + b
|
|
|
|
def mul(a, x):
|
|
return a * x
|
|
|
|
def func(alpha, beta, x, y):
|
|
return add(mul(alpha, x), mul(beta, y))
|
|
'''
|
|
alpha = torch.rand(1, dtype=torch.float, requires_grad=True)
|
|
beta = torch.rand(1, dtype=torch.float, requires_grad=True)
|
|
x = torch.rand(3, dtype=torch.float, requires_grad=True)
|
|
y = torch.rand(3, dtype=torch.float, requires_grad=True)
|
|
outputs = alpha * x + beta * y
|
|
# NOTE: cannot optimize yet because broadcasts are not inserted before the fuser runs
|
|
self.checkScript(script, [alpha, beta, x, y], optimize=False, outputs=outputs)
|
|
|
|
def test_resize_input_ops(self):
|
|
# resize_ and resize_as resize the input tensor. because our shape analysis
|
|
# is flow invariant, we set any Tensor that can alias a resized Tensor
|
|
# to the base Tensor Type, without size information.
|
|
|
|
# testing that value which is an input of a graph gets handled
|
|
def out_op_graph_input():
|
|
@torch.jit.script
|
|
def test(x, y, z):
|
|
torch.mul(x, y, out=z)
|
|
return z
|
|
|
|
torch._C._jit_pass_shape_analysis(
|
|
test.graph, (torch.zeros(2, 1), torch.zeros(1, 2), torch.zeros(1, 1, 1)), False)
|
|
self.assertTrue(next(test.graph.outputs()).type() == TensorType.get())
|
|
out_op_graph_input()
|
|
|
|
def test_resize():
|
|
@torch.jit.script
|
|
def test(x):
|
|
after_resize_alias = torch.zeros([2])
|
|
for _i in range(5):
|
|
b = x + 1
|
|
f = [1]
|
|
before_resize_alias = b.sub_(1)
|
|
# for i in range(10):
|
|
f.append(1)
|
|
b.resize_(f)
|
|
after_resize_alias = b.add_(1)
|
|
return after_resize_alias
|
|
|
|
g = test.graph
|
|
self.run_pass('constant_propagation', g)
|
|
torch._C._jit_pass_shape_analysis(
|
|
g, (torch.zeros(1, 1),), False)
|
|
resize_node = g.findNode("aten::resize_")
|
|
# first input and output of b.resize_ is b
|
|
self.assertTrue(next(resize_node.inputs()).type() == TensorType.get())
|
|
self.assertTrue(next(resize_node.outputs()).type() == TensorType.get())
|
|
|
|
# correctly propagates to b alias set
|
|
before_resize = g.findNode("aten::sub_")
|
|
self.assertTrue(next(before_resize.outputs()).type() == TensorType.get())
|
|
|
|
after_resize = g.findNode("aten::add_")
|
|
self.assertTrue(next(after_resize.outputs()).type() == TensorType.get())
|
|
|
|
test_resize()
|
|
|
|
def test_resize_as():
|
|
@torch.jit.script
|
|
def test(x):
|
|
b = torch.zeros([2, 2])
|
|
b.resize_as_(x)
|
|
return b
|
|
|
|
g = test.graph
|
|
self.run_pass('constant_propagation', g)
|
|
torch._C._jit_pass_shape_analysis(
|
|
g, (torch.zeros(1, 1),), False)
|
|
|
|
# x doesn't alias a resized op so it shouldn't be set to base Tensor type
|
|
self.assertTrue(next(g.inputs()).type() != TensorType.get())
|
|
# return is resized
|
|
self.assertTrue(next(g.outputs()).type() == TensorType.get())
|
|
|
|
test_resize_as()
|
|
|
|
def test_requires_grad_loop(self):
|
|
@torch.jit.script
|
|
def test(x, y, z):
|
|
# type: (Tensor, Tensor, int) -> Tensor
|
|
for _ in range(z):
|
|
x = y
|
|
return x
|
|
|
|
# x requires grad, y does not
|
|
# testing that requires grad analysis correctly exits, with its input
|
|
# to the loop (x) requiring grad and its output to the loop not requiring grad
|
|
# and the output of the node conservatively setting grad to true
|
|
|
|
inps = (torch.tensor(1.0, requires_grad=True), torch.tensor(1), 10)
|
|
test(*inps)
|
|
|
|
graph = test.graph_for(*inps)
|
|
loop = graph.findNode("prim::Loop")
|
|
loop_body = next(loop.blocks())
|
|
loop_inputs = list(loop_body.inputs())
|
|
loop_outputs = list(loop_body.outputs())
|
|
|
|
self.assertTrue(loop_inputs[1].requires_grad())
|
|
self.assertFalse(loop_outputs[1].requires_grad())
|
|
self.assertTrue(loop.output().requires_grad())
|
|
|
|
def test_view_shape_prop(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def test_view_shape_prop(a):
|
|
return a.view(size=[-1])
|
|
''')
|
|
inputs = [torch.zeros(10, 10)]
|
|
outputs = torch.zeros(100)
|
|
|
|
real_outs = cu.test_view_shape_prop(*inputs)
|
|
self.assertEqual(real_outs, outputs)
|
|
|
|
def test_view_listconstruct_shape_prop(self):
|
|
def fn(x):
|
|
B = x.size(0)
|
|
C = x.size(1)
|
|
T = x.size(2)
|
|
return x.view(T, B, C)
|
|
|
|
x = torch.randn(3, 1, 5, requires_grad=True)
|
|
graph = torch.jit.script(fn).graph
|
|
torch._C._jit_pass_shape_analysis(graph, (x,), False)
|
|
a = next(graph.outputs()).type().kind()
|
|
self.assertTrue(next(graph.outputs()).type().kind() != 'TensorType')
|
|
|
|
def test_integral_shape_inference(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def test_integral_shape_inference(a):
|
|
return a / a
|
|
''')
|
|
inputs = [torch.ones(10, 10).type(torch.LongTensor)]
|
|
outputs = torch.ones(10, 10)
|
|
|
|
self.assertEqual(cu.test_integral_shape_inference(*inputs), outputs)
|
|
|
|
def test_fuser_multiple_blocks(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def test_fuser_multiple_blocks(this, that, theother, meme):
|
|
i = 0
|
|
while i < 20:
|
|
this = torch.cat([this, meme], dim=0)
|
|
that = torch.cat([that, meme], dim=0)
|
|
theother = torch.cat([theother, meme], dim=0)
|
|
i = i + 1
|
|
return this, that, theother
|
|
''')
|
|
|
|
inputs = [torch.ones(0, 10, 10)] * 3
|
|
inputs += [torch.ones(1, 10, 10)]
|
|
outputs = [torch.ones(20, 10, 10)] * 3
|
|
|
|
self.assertEqual(cu.test_fuser_multiple_blocks(*inputs), outputs)
|
|
|
|
def test_dropout_script(self):
|
|
|
|
eg = torch.zeros(1, 2, 3, requires_grad=True)
|
|
|
|
@_trace(eg)
|
|
def foo(x):
|
|
x = torch.neg(x)
|
|
return F.dropout(x)
|
|
|
|
class MyDrop(nn.Module):
|
|
def forward(self, x):
|
|
return foo(x)
|
|
|
|
f = io.BytesIO()
|
|
torch.onnx.export(MyDrop(), (eg,), f, verbose=False)
|
|
|
|
@unittest.skip("RuntimeError: VariableType::ID() not implemented")
|
|
def test_cast(self):
|
|
script = '''
|
|
def to_int(x):
|
|
return int(x)
|
|
'''
|
|
x = Variable(torch.FloatTensor([1.1, 2.3]), requires_grad=True)
|
|
out = Variable(torch.IntTensor([1, 2]), requires_grad=True)
|
|
self.checkScript(script, [x], optimize=True, outputs=[out], func='to_int')
|
|
|
|
def test_python_frontend(self):
|
|
def fn(x, y, z):
|
|
q = None
|
|
q = x + y - z.sigmoid()
|
|
print(q)
|
|
w = -z
|
|
if not x and not y and z:
|
|
m = x if not z else y
|
|
while x < y > z:
|
|
q = x
|
|
assert 1 == 1, "hello"
|
|
return x
|
|
|
|
ast = torch.jit.frontend.get_jit_def(fn)
|
|
self.assertExpected(str(ast))
|
|
|
|
@unittest.skipIf(not PY2, "Requires python 2")
|
|
def test_python_frontend_py2(self):
|
|
def fn():
|
|
raise Exception("hello")
|
|
ast = torch.jit.frontend.get_jit_def(fn)
|
|
self.assertExpected(str(ast))
|
|
|
|
@unittest.skipIf(PY2, "Requires python 3")
|
|
def test_python_frontend_py3(self):
|
|
def fn():
|
|
raise Exception("hello")
|
|
ast = torch.jit.frontend.get_jit_def(fn)
|
|
self.assertExpected(str(ast))
|
|
|
|
def _make_scalar_vars(self, arr, dtype):
|
|
return [torch.tensor(val, dtype=dtype) for val in arr]
|
|
|
|
def test_string_print(self):
|
|
def func(a):
|
|
print(a, "a" 'b' '''c''' """d""", 2, 1.5)
|
|
return a
|
|
|
|
inputs = self._make_scalar_vars([1], torch.int64)
|
|
self.checkScript(func, inputs, capture_output=True)
|
|
|
|
def test_while(self):
|
|
def func(a, b, max):
|
|
while bool(a < max):
|
|
a = a + 1
|
|
b = b + 1
|
|
c = a + b
|
|
return c
|
|
|
|
inputs = self._make_scalar_vars([1, 1, 10], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_fibb(self):
|
|
def func(lim):
|
|
first = 1
|
|
second = 1
|
|
i = 1
|
|
somenum = 5
|
|
dontmutateme = 3
|
|
third = 0
|
|
while bool(i < lim):
|
|
third = first + second
|
|
first = second
|
|
second = third
|
|
j = 0
|
|
while j < 10:
|
|
somenum = somenum * 2
|
|
j = j + 1
|
|
i = i + j
|
|
i = i + dontmutateme
|
|
|
|
st = second + third
|
|
fs = first + second
|
|
return third, st, fs
|
|
|
|
inputs = self._make_scalar_vars([10], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_if(self):
|
|
def func(a, b):
|
|
# type: (int, int) -> int
|
|
d = 3
|
|
if bool(a > 10):
|
|
a = 3 + d
|
|
else:
|
|
b = 3 + d
|
|
d = 4
|
|
c = a + b
|
|
return c
|
|
|
|
inputs = self._make_scalar_vars([1, -1], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_if_for_in_range(self):
|
|
def func(a, b):
|
|
# type: (int, int) -> int
|
|
d = 3
|
|
for _ in range(20):
|
|
if bool(a > 10):
|
|
a = 3 + d
|
|
else:
|
|
b = 3 + d
|
|
d = 4
|
|
c = a + b
|
|
return d
|
|
inputs = self._make_scalar_vars([1, -1], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_if_noelse(self):
|
|
def func(a, b):
|
|
if bool(a > 10):
|
|
a = 3 + b
|
|
c = a + b
|
|
return c
|
|
|
|
inputs = self._make_scalar_vars([-1, 1], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_if_is_none_dispatch(self):
|
|
|
|
@torch.jit.script
|
|
def test_lhs_none_rhs_none():
|
|
# LHS, RHS both alwaysNone, dispatch always_none_branch
|
|
# only emit one prim::Constant
|
|
if None is None:
|
|
return 1
|
|
elif None is not None:
|
|
return 2
|
|
else:
|
|
return 3
|
|
|
|
self.assertTrue(str(test_lhs_none_rhs_none.graph).count(': int = prim::Constant') == 1)
|
|
|
|
@torch.jit.script
|
|
def test_lhs_opt_rhs_none(lhs=None):
|
|
# type: (Optional[Tensor]) -> int
|
|
# LHS maybeNone: emit normal if stmt that contains 3 constants
|
|
if lhs is not None:
|
|
return 2
|
|
elif lhs is None:
|
|
return 1
|
|
else:
|
|
return 3
|
|
|
|
self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3)
|
|
|
|
@torch.jit.script
|
|
def test_lhs_none_rhs_opt(rhs=None):
|
|
# type: (Optional[Tensor]) -> int
|
|
# RHS maybeNone, emit normal if stmt that contains 3 constants
|
|
if None is rhs:
|
|
return 1
|
|
elif None is not rhs:
|
|
return 2
|
|
else:
|
|
return 3
|
|
|
|
self.assertTrue(str(test_lhs_opt_rhs_none.graph).count(': int = prim::Constant') == 3)
|
|
|
|
@torch.jit.script
|
|
def test_lhs_never_rhs_none(lhs):
|
|
# LHS neverNone, RHS alwaysNone dispatch never_none_branch
|
|
# only emit one prim::Constant
|
|
if lhs is None:
|
|
return 1
|
|
elif lhs is not None:
|
|
return 2
|
|
else:
|
|
return 3
|
|
|
|
self.assertTrue(str(test_lhs_never_rhs_none.graph).count(': int = prim::Constant') == 1)
|
|
|
|
@torch.jit.script
|
|
def test_lhs_none_rhs_never(rhs):
|
|
# LHS alwaysNone, RHS neverNone dispatch never_none_branch
|
|
# only emit one prim::Constant
|
|
if None is rhs:
|
|
return 1
|
|
elif None is not rhs:
|
|
return 2
|
|
else:
|
|
return 3
|
|
|
|
self.assertTrue(str(test_lhs_none_rhs_never.graph).count(': int = prim::Constant') == 1)
|
|
|
|
def test_explicit_bool_cast(self):
|
|
with self.assertRaisesRegex(RuntimeError, "expected a boolean"):
|
|
@torch.jit.script
|
|
def test_bool_cast(a):
|
|
if a:
|
|
return a + 2
|
|
return a + 1
|
|
|
|
def test_while_nonexistent_value(self):
|
|
with self.assertRaisesRegex(RuntimeError, "undefined value x"):
|
|
torch.jit.CompilationUnit('''
|
|
def test_while(a, b):
|
|
while bool(a < 10):
|
|
a = a + x
|
|
b = b + 1
|
|
return a + b
|
|
''')
|
|
|
|
def test_while_nonexistent_cond_value(self):
|
|
with self.assertRaisesRegex(RuntimeError, "undefined value x"):
|
|
torch.jit.CompilationUnit('''
|
|
def test_while(a, b):
|
|
while a < x:
|
|
a = a + 1
|
|
b = b + 1
|
|
return a + b
|
|
''')
|
|
|
|
def test_optional_refinement(self):
|
|
@torch.jit.script
|
|
def test_if_none_assignment(x):
|
|
# type: (Optional[int]) -> int
|
|
if x is None:
|
|
x = 1
|
|
return x + 1
|
|
|
|
self.assertEqual(test_if_none_assignment(1), 2)
|
|
|
|
@torch.jit.script
|
|
def test_ternary(x):
|
|
# type: (Optional[int]) -> int
|
|
x = x if x is not None else 2
|
|
return x
|
|
|
|
@torch.jit.script
|
|
def test_not_none(x):
|
|
# type: (Optional[int]) -> None
|
|
if x is not None:
|
|
print(x + 1)
|
|
|
|
@torch.jit.script
|
|
def test_and(x, y):
|
|
# type: (Optional[int], Optional[int]) -> None
|
|
if x is not None and y is not None:
|
|
print(x + y)
|
|
|
|
@torch.jit.script
|
|
def test_not(x, y):
|
|
# type: (Optional[int], Optional[int]) -> None
|
|
if not (x is not None and y is not None):
|
|
pass
|
|
else:
|
|
print(x + y)
|
|
|
|
@torch.jit.script
|
|
def test_bool_expression(x):
|
|
# type: (Optional[int]) -> None
|
|
if x is not None and x < 2:
|
|
print(x + 1)
|
|
|
|
@torch.jit.script
|
|
def test_nested_bool_expression(x, y):
|
|
# type: (Optional[int], Optional[int]) -> int
|
|
if x is not None and x < 2 and y is not None:
|
|
x = x + y
|
|
else:
|
|
x = 5
|
|
return x + 2
|
|
|
|
@torch.jit.script
|
|
def test_or(x, y):
|
|
# type: (Optional[int], Optional[int]) -> None
|
|
if y is None or x is None:
|
|
pass
|
|
else:
|
|
print(x + y)
|
|
|
|
# backwards compatibility
|
|
@torch.jit.script
|
|
def test_manual_unwrap_opt(x):
|
|
# type: (Optional[int]) -> int
|
|
if x is None:
|
|
x = 1
|
|
else:
|
|
x = torch.jit._unwrap_optional(x)
|
|
return x # noqa: T484
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
|
|
@torch.jit.script
|
|
def or_error(x, y):
|
|
# type: (Optional[int], Optional[int]) -> None
|
|
if x is None or y is None:
|
|
print(x + y) # noqa: T484
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
|
|
@torch.jit.script
|
|
def and_error(x, y):
|
|
# type: (Optional[int], Optional[int]) -> None
|
|
if x is None and y is None:
|
|
pass
|
|
else:
|
|
print(x + y) # noqa: T484
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
|
|
@torch.jit.script
|
|
def named_var(x):
|
|
# type: (Optional[int]) -> None
|
|
x_none = x is not None
|
|
if x_none:
|
|
print(x + 1) # noqa: T484
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
|
|
@torch.jit.script
|
|
def named_var_and(x, y):
|
|
# type: (Optional[int], Optional[int]) -> None
|
|
x_none = x is not None
|
|
if y is not None and x_none:
|
|
print(x + y) # noqa: T484
|
|
|
|
def test_while_write_outer_then_read(self):
|
|
def func(a, b):
|
|
while bool(a < 10):
|
|
a = a + 1
|
|
b = a + 1
|
|
return a + b
|
|
|
|
inputs = self._make_scalar_vars([42, 1337], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_while_nest_if(self):
|
|
def func(a, b):
|
|
# type: (int, int) -> int
|
|
c = 0
|
|
while a < 10:
|
|
a = a + 1
|
|
b = b + 1
|
|
if a > b:
|
|
c = -a
|
|
else:
|
|
c = -b
|
|
return c + 1
|
|
|
|
inputs = self._make_scalar_vars([-1234, 4321], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_math_ops(self):
|
|
|
|
def test_floor():
|
|
return math.floor(1.5)
|
|
|
|
self.checkScript(test_floor, ())
|
|
|
|
def test_if_nest_while(self):
|
|
def func(a, b):
|
|
# type: (int, int) -> int
|
|
c = 0
|
|
if a > b:
|
|
while a > b:
|
|
b = b + 1
|
|
c = -b
|
|
return c
|
|
|
|
inputs = self._make_scalar_vars([4321, 1234], torch.int64)
|
|
self.checkScript(func, inputs, optimize=True)
|
|
|
|
def test_script_for_in_range(self):
|
|
def fn():
|
|
c = 0
|
|
for i in range(100):
|
|
c += i
|
|
return c
|
|
self.checkScript(fn, (), outputs=4950, optimize=True)
|
|
|
|
def test_script_for_in_range_dynamic(self):
|
|
def fn():
|
|
c = 0
|
|
for i in range(100):
|
|
acc = 0
|
|
for j in range(i):
|
|
acc += j
|
|
c += acc
|
|
return c
|
|
self.checkScript(fn, (), optimize=False)
|
|
|
|
def test_script_for_in_range_ast(self):
|
|
@torch.jit.script
|
|
def test_script_for_in_range_ast():
|
|
c = 0
|
|
for i in range(100):
|
|
acc = 0
|
|
for j in range(i):
|
|
acc += j
|
|
c += acc
|
|
return c
|
|
|
|
self.assertEqual(test_script_for_in_range_ast(), 161700)
|
|
|
|
def test_script_for_in_range_if_ast(self):
|
|
@torch.jit.script
|
|
def test_script_for_in_range_if_ast(x):
|
|
output = x
|
|
for i in range(20):
|
|
if i == 0:
|
|
output = x.unsqueeze(0)
|
|
else:
|
|
output = torch.cat((output, x.unsqueeze(0)), dim=0)
|
|
return output
|
|
inputs = self._make_scalar_vars([0], torch.int64)
|
|
|
|
self.assertEqual(test_script_for_in_range_if_ast(*inputs).shape[0], 20)
|
|
|
|
def test_script_optional_none(self):
|
|
def none_stmt(x):
|
|
output = None
|
|
output = x
|
|
return output
|
|
|
|
def none_args(x):
|
|
# type: (Optional[Tensor]) -> Optional[Tensor]
|
|
return None
|
|
|
|
self.checkScript(none_stmt, [torch.arange(0, 2)], optimize=True)
|
|
self.checkScript(none_args, [None], optimize=True)
|
|
|
|
# test undefined tensor None as default param
|
|
def test_script_optional_tensor_none(x=None):
|
|
# type: (Optional[Tensor]) -> Tensor
|
|
res = torch.zeros(1, dtype=torch.int8)
|
|
if x is None:
|
|
res = res + 1
|
|
else:
|
|
res = x
|
|
return res
|
|
|
|
fn = test_script_optional_tensor_none
|
|
scripted_fn = torch.jit.script(fn)
|
|
self.assertEqual(fn(), scripted_fn())
|
|
self.assertEqual(fn(torch.zeros(1)), scripted_fn(torch.zeros(1)))
|
|
|
|
# test typical None as default param
|
|
def test_script_optional_other_none(x=None):
|
|
# type: (Optional[float]) -> float
|
|
res = 2.0
|
|
if x is None:
|
|
res = res + 1.0
|
|
else:
|
|
res = x
|
|
return res
|
|
|
|
fn = test_script_optional_other_none
|
|
scripted_fn = torch.jit.script(fn)
|
|
self.assertEqual(fn(), scripted_fn())
|
|
self.assertEqual(fn(1.0), scripted_fn(1.0))
|
|
|
|
def test_script_clamp_none(self):
|
|
def test_script_clamp_max_none(x):
|
|
return torch.clamp(x, min=2, max=None)
|
|
|
|
def test_script_clamp_max(x):
|
|
return torch.clamp(x, max=2)
|
|
|
|
def test_script_clamp_min_none(x):
|
|
return torch.clamp(x, min=None, max=2)
|
|
|
|
def test_script_clamp_min(x):
|
|
return torch.clamp(x, min=2)
|
|
|
|
input = [torch.arange(0, 3)]
|
|
self.checkScript(test_script_clamp_max_none, input, optimize=True)
|
|
self.checkScript(test_script_clamp_max, input, optimize=True)
|
|
self.checkScript(test_script_clamp_min_none, input, optimize=True)
|
|
self.checkScript(test_script_clamp_min, input, optimize=True)
|
|
|
|
def test_script_bool_constant(self):
|
|
script = '''
|
|
def test_script_bool_constant():
|
|
a = True
|
|
return a
|
|
'''
|
|
outputs = [1]
|
|
self.checkScript(script, [], outputs[0], True, 'test_script_bool_constant')
|
|
|
|
def test_ternary(self):
|
|
def func(a, b):
|
|
c = 3
|
|
c = a + b if bool(a > 3) else b
|
|
return c
|
|
|
|
inputs_true = self._make_scalar_vars([5, 2], torch.int64)
|
|
inputs_false = self._make_scalar_vars([1, 0], torch.int64)
|
|
self.checkScript(func, inputs_true, optimize=True)
|
|
self.checkScript(func, inputs_false, optimize=True)
|
|
|
|
def test_print(self):
|
|
def func(x, y):
|
|
q = (x + y).sigmoid()
|
|
print(q, 1, 2, [1, 2], [1.0, 2.0])
|
|
w = -q
|
|
return w * w
|
|
|
|
x = torch.arange(4., requires_grad=True)
|
|
y = torch.arange(0., 8, 2, requires_grad=True)
|
|
self.checkScript(func, [x, y], optimize=True, capture_output=True)
|
|
|
|
def test_format(self):
|
|
def func(x):
|
|
print("{}, I'm a {}".format("Hello", "test"))
|
|
print("format blank".format())
|
|
print("stuff before {}".format("hi"))
|
|
print("{} stuff after".format("hi"))
|
|
return x + 1
|
|
|
|
x = torch.arange(4., requires_grad=True)
|
|
self.checkScript(func, [x], optimize=True, capture_output=True)
|
|
|
|
def test_logical_short_circuit(self):
|
|
@torch.jit.script
|
|
def testNoThrows(t):
|
|
c1 = 1
|
|
if (False and bool(t[1])) or (True or bool(t[1])):
|
|
c1 = 0
|
|
return c1
|
|
|
|
self.assertEqual(0, testNoThrows(torch.randn(0)))
|
|
ifs = testNoThrows.graph.findAllNodes("prim::If", recurse=False)
|
|
|
|
# three ifs at the top level, and the second one has a nested if for
|
|
# the or (True or bool(t[1])) expression
|
|
self.assertTrue(len(ifs) == 3)
|
|
self.assertTrue(ifs[0].findNode("prim::If") is None)
|
|
self.assertTrue(ifs[1].findNode("prim::If").findNode("prim::If") is None)
|
|
self.assertTrue(ifs[2].findNode("prim::If") is None)
|
|
|
|
@torch.jit.script
|
|
def throwsOr(t):
|
|
c0 = False or bool(t[1])
|
|
print(c0)
|
|
|
|
@torch.jit.script
|
|
def throwsAnd(t):
|
|
c0 = True and bool(t[1])
|
|
print(c0)
|
|
|
|
t = torch.randn(0)
|
|
with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
|
|
throwsOr(t)
|
|
with self.assertRaisesRegex(RuntimeError, "index 1 out of range for tensor of size"):
|
|
throwsAnd(t)
|
|
|
|
def test_type_cast(self):
|
|
template = dedent('''
|
|
def cast(v):
|
|
# type: ({from_type}) -> {to_type}
|
|
return {to_type}(v)
|
|
''')
|
|
|
|
def check_cast(from_type, to_type, value, raises=False):
|
|
code = template.format(from_type=from_type, to_type=to_type)
|
|
expected = getattr(builtins, to_type)(value)
|
|
if raises:
|
|
with self.assertRaisesRegex(RuntimeError, "Cannot cast"):
|
|
cu = torch.jit.CompilationUnit(code)
|
|
else:
|
|
self.checkScript(code, (value,), name='cast', outputs=expected)
|
|
|
|
check_cast('int', 'float', 1)
|
|
check_cast('int', 'bool', 1)
|
|
check_cast('int', 'bool', 0)
|
|
|
|
check_cast('float', 'int', 1.)
|
|
check_cast('float', 'bool', 1.)
|
|
check_cast('float', 'bool', 0.)
|
|
|
|
check_cast('bool', 'int', True)
|
|
check_cast('bool', 'float', True)
|
|
|
|
def test_multiple_assignment(self):
|
|
def outer_func(x):
|
|
return x * 2, x + 2
|
|
|
|
@torch.jit.script
|
|
def func(x):
|
|
y, z = outer_func(x)
|
|
return y + z
|
|
|
|
x = torch.arange(4)
|
|
self.assertEqual(func(x), x * 2 + x + 2)
|
|
|
|
def test_literals(self):
|
|
def func(a):
|
|
return a.view(size=[1, 2, 3])
|
|
|
|
a = torch.randn(6)
|
|
self.checkScript(func, [a], optimize=True)
|
|
|
|
def test_return(self):
|
|
def no_return(a):
|
|
a + 1
|
|
|
|
def void_return(a):
|
|
return
|
|
|
|
def one_return(a):
|
|
return a + 1.
|
|
|
|
def multiple_returns(a):
|
|
return a * 1., a * 2., a * 3.
|
|
|
|
a = torch.randn(1, dtype=torch.float)
|
|
self.checkScript(no_return, [a], optimize=True)
|
|
self.checkScript(void_return, [a], optimize=True)
|
|
self.checkScript(one_return, [a], optimize=True)
|
|
self.checkScript(multiple_returns, [a], optimize=True)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "but is actually of type None"):
|
|
torch.jit.CompilationUnit('''
|
|
def no_return_bad_annotation(a):
|
|
# type: (Tensor) -> Tensor
|
|
a + 1
|
|
''')
|
|
|
|
def test_error(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
return a.t()
|
|
s = Variable(torch.rand(5, 5, 5))
|
|
# XXX: this should stay quiet in stay propagation and only fail in the interpreter
|
|
with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
|
|
foo(s)
|
|
|
|
@torch.jit.script
|
|
def bar(c, b):
|
|
return c + b
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "failed in interpreter"):
|
|
bar(Variable(torch.rand(10), requires_grad=True), Variable(torch.rand(9), requires_grad=True))
|
|
|
|
def test_binop_unsupported_error(self):
|
|
with self.assertRaisesRegex(NotSupportedError, "unsupported binary operator:"):
|
|
@torch.jit.script
|
|
def binop(x, y):
|
|
# Replace this with another unsupported op when/if it gets supported
|
|
return x << y
|
|
|
|
def test_bitwise_ops(self):
|
|
|
|
def int_test():
|
|
return 2 & 3, 2 ^ 3, 2 | 3
|
|
|
|
self.checkScript(int_test, ())
|
|
|
|
def bool_test(x, y):
|
|
# type: (bool, bool) -> Tuple[bool, bool, bool]
|
|
return x & y, x ^ y, x | y
|
|
|
|
self.checkScript(bool_test, (True, False))
|
|
self.checkScript(bool_test, (True, True))
|
|
|
|
def tensor_test(x, y):
|
|
return x & y, x ^ y, x | y
|
|
|
|
x = torch.tensor(2)
|
|
y = torch.tensor(3)
|
|
|
|
self.checkScript(tensor_test, (x, y))
|
|
|
|
def test_number_math(self):
|
|
ops_template = dedent('''
|
|
def func():
|
|
return {scalar1} {op} {scalar2}
|
|
''')
|
|
ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '//']
|
|
funcs_template = dedent('''
|
|
def func():
|
|
return {func}({scalar1}, {scalar2})
|
|
''')
|
|
funcs = ['min', 'max']
|
|
scalars = ['7', '2', '3', '-3', '3.14', '0.125', '-0.5', '2.0', '-2.0']
|
|
scalar_pairs = [(scalar1, scalar2) for scalar1 in scalars for scalar2 in scalars]
|
|
|
|
def run_test(code):
|
|
scope = {}
|
|
execWrapper(code, globals(), scope)
|
|
cu = torch.jit.CompilationUnit(code)
|
|
|
|
self.assertEqual(cu.func(), scope['func']())
|
|
|
|
for scalar1, scalar2 in scalar_pairs:
|
|
for op in ops:
|
|
code = ops_template.format(op=op, scalar1=scalar1, scalar2=scalar2)
|
|
run_test(code)
|
|
for func in funcs:
|
|
code = funcs_template.format(func=func, scalar1=scalar1, scalar2=scalar2)
|
|
run_test(code)
|
|
|
|
def test_number_div(self):
|
|
self.checkScript(div_int_future, (), optimize=True)
|
|
self.checkScript(div_float_future, (), optimize=True)
|
|
|
|
if PY2:
|
|
with self.assertRaisesRegex(RuntimeError, 'from __future__ import division'):
|
|
torch.jit.script(div_int_nofuture)
|
|
with self.assertRaisesRegex(RuntimeError, 'from __future__ import division'):
|
|
torch.jit.script(div_float_nofuture)
|
|
else:
|
|
self.checkScript(div_int_nofuture, (), optimize=True)
|
|
self.checkScript(div_float_nofuture, (), optimize=True)
|
|
|
|
def test_floor_div(self):
|
|
@torch.jit.script
|
|
def foo(a, b):
|
|
# type: (int, int) -> int
|
|
return a // b
|
|
for i in range(-8, 8):
|
|
for j in range(-8, 8):
|
|
if j != 0:
|
|
self.assertEqual(foo(i, j), i // j)
|
|
else:
|
|
with self.assertRaisesRegex(RuntimeError, 'division by 0'):
|
|
foo(i, j)
|
|
|
|
def test_number_augassign(self):
|
|
def func():
|
|
z = 1
|
|
z += 2
|
|
return z
|
|
|
|
self.checkScript(func, (), optimize=True)
|
|
|
|
def test_number_neg(self):
|
|
# int -> int
|
|
def func1():
|
|
return -8
|
|
|
|
# float -> float
|
|
def func2():
|
|
return -3.14
|
|
|
|
self.checkScript(func1, (), optimize=True)
|
|
self.checkScript(func2, (), optimize=True)
|
|
|
|
def _test_tensor_number_math(self, device='cpu'):
|
|
template = dedent('''
|
|
def func(t):
|
|
return {lhs} {op} {rhs}
|
|
''')
|
|
|
|
def test(op, const, swap_args):
|
|
args = ('t', const)
|
|
if swap_args:
|
|
args = (const, 't')
|
|
|
|
code = template.format(lhs=args[0], rhs=args[1], op=op)
|
|
scope = {}
|
|
execWrapper(code, globals(), scope)
|
|
cu = torch.jit.CompilationUnit(code)
|
|
self.assertEqual(cu.func(tensor), scope['func'](tensor))
|
|
|
|
var_int = [2, -2]
|
|
var_float = [1.4321, -1.2]
|
|
|
|
ops = ['+', '-', '*', '%', '<', '<=', '>', '>=', '==', '!=', '/']
|
|
|
|
float_tensor = torch.randn(5, 5, device=device)
|
|
double_tensor = torch.randn(5, 5, dtype=torch.double, device=device)
|
|
long_tensor = torch.randint(-5, 5, (5, 5), dtype=torch.long, device=device)
|
|
long_tensor[long_tensor == 0] = 2
|
|
|
|
tensors = [float_tensor, double_tensor, long_tensor]
|
|
consts = var_int + var_float
|
|
|
|
for op, tensor, const, swap_args in product(ops, tensors, consts, [True, False]):
|
|
# FIXME: things like 2 / long_tensor are not implemented correctly
|
|
# Look in torch/tensor.py to see how pytorch implements it.
|
|
if op == '/' and tensor.data_ptr() == long_tensor.data_ptr():
|
|
continue
|
|
|
|
# % operator does not take: const % tensor
|
|
if op == '%' and swap_args is True:
|
|
continue
|
|
|
|
test(op, const, swap_args)
|
|
|
|
def test_tensor_number_math(self):
|
|
self._test_tensor_number_math()
|
|
|
|
def test_torch_tensor_bad_input(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Input list to torch.tensor must be of ints, floats, "
|
|
"or bools, got None"):
|
|
@torch.jit.script
|
|
def test():
|
|
return torch.tensor([None])
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Note: empty lists are constructed as Tensor"):
|
|
@torch.jit.script
|
|
def tmp():
|
|
return torch.tensor([])
|
|
|
|
@torch.jit.script
|
|
def foo():
|
|
return torch.tensor([[2, 2], [1]])
|
|
with self.assertRaisesRegex(RuntimeError, "Expected sequence of length"):
|
|
foo()
|
|
|
|
@suppress_warnings
|
|
def test_torch_tensor_empty_list(self):
|
|
def func():
|
|
return torch.tensor(torch.jit.annotate(List[int], []))
|
|
cu = torch.jit.script(func)
|
|
t1 = cu()
|
|
t2 = func()
|
|
|
|
# torchscript returns int tensor, python returns float tensor
|
|
self.assertNotEqual(t1.dtype, t2.dtype)
|
|
|
|
def func():
|
|
li = torch.jit.annotate(List[int], [])
|
|
return torch.tensor([li, li])
|
|
|
|
self.checkScript(func, ())
|
|
|
|
def func():
|
|
li = torch.jit.annotate(List[int], [])
|
|
return torch.tensor([[[li]]])
|
|
|
|
self.checkScript(func, ())
|
|
|
|
def test_torch_tensor(self):
|
|
template = dedent('''
|
|
def func():
|
|
li = {list_create}
|
|
return torch.tensor(li {options})
|
|
''')
|
|
|
|
lists = ["2.5", "4", "True", "False", "[2]", "[-.5]", "[False, True, False]", "[2, 2]",
|
|
"torch.jit.annotate(List[int], [])", "[2.5, 2.5]", "[[2], [2]]", "[[-.5], [2.2]]", "[[False], [True]]"]
|
|
|
|
dtypes = ["", ", dtype=torch.float", ", dtype=torch.double", ", dtype=torch.half",
|
|
", dtype=torch.uint8", ", dtype=torch.int8", ", dtype=torch.short",
|
|
", dtype=torch.int", ", dtype=torch.long"]
|
|
|
|
devices = ['', ", device='cpu'"]
|
|
if RUN_CUDA:
|
|
devices.append(", device='cuda'")
|
|
|
|
option_pairs = [dtype + device for dtype in dtypes for device in devices]
|
|
for li in lists:
|
|
for option in option_pairs:
|
|
# tensor from empty list is type float in python and annotated type in torchscript
|
|
if "annotate" in li and "dtype" not in option:
|
|
continue
|
|
code = template.format(list_create=li, options=option)
|
|
scope = {}
|
|
exec(code, globals(), scope)
|
|
cu = torch.jit.CompilationUnit(code)
|
|
t1 = cu.func()
|
|
t2 = scope['func']()
|
|
if t1.dtype == torch.float16: # equality NYI for half tensor
|
|
self.assertTrue(str(t1) == str(t2))
|
|
else:
|
|
self.assertEqual(t1, t2)
|
|
self.assertEqual(t1.dtype, t2.dtype)
|
|
self.assertEqual(t1.device, t2.device)
|
|
|
|
# adapted from test in test_torch
|
|
def test_tensor_to(self):
|
|
template = dedent('''
|
|
def func(t):
|
|
cuda = "{cuda}"
|
|
device = "{device}"
|
|
non_blocking = {non_blocking}
|
|
return {to_str}
|
|
''')
|
|
|
|
def s(t, to_str, non_blocking=None, device=None, cuda=None):
|
|
device = device if device is not None else str(t.device)
|
|
non_blocking = non_blocking if non_blocking is not None else False
|
|
cuda = "cuda" if cuda is None else cuda
|
|
code = template.format(to_str=to_str, device=device, non_blocking=non_blocking, cuda=cuda)
|
|
scope = {}
|
|
cu = torch.jit.CompilationUnit(code)
|
|
return cu.func(t)
|
|
|
|
def test_copy_behavior(t, non_blocking=False):
|
|
self.assertIs(t, s(t, 't.to(t, non_blocking=non_blocking)', non_blocking))
|
|
self.assertIs(t, s(t, 't.to(t.dtype, non_blocking=non_blocking)', non_blocking))
|
|
self.assertIs(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking)', non_blocking))
|
|
self.assertIsNot(t, s(t, 't.to(t, non_blocking=non_blocking, copy=True)', non_blocking))
|
|
self.assertIsNot(t, s(t, 't.to(t.dtype, non_blocking=non_blocking, copy=True)', non_blocking))
|
|
self.assertIsNot(t, s(t, 't.to(torch.empty_like(t), non_blocking=non_blocking, copy=True)', non_blocking))
|
|
|
|
devices = [t.device]
|
|
if t.device.type == 'cuda':
|
|
if t.device.index == -1:
|
|
devices.append('cuda:{}'.format(torch.cuda.current_device()))
|
|
elif t.device.index == torch.cuda.current_device():
|
|
devices.append('cuda')
|
|
for device in devices:
|
|
self.assertIs(t, s(t, 't.to(device, non_blocking=non_blocking)', non_blocking, device))
|
|
self.assertIs(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking)', non_blocking, device))
|
|
self.assertIsNot(t, s(t, 't.to(device, non_blocking=non_blocking, copy=True)', non_blocking, device))
|
|
self.assertIsNot(t, s(t, 't.to(device, t.dtype, non_blocking=non_blocking, copy=True)',
|
|
non_blocking, device))
|
|
|
|
t = torch.tensor(5)
|
|
test_copy_behavior(t)
|
|
|
|
self.assertEqual(t.device, s(t, "t.to('cpu')").device)
|
|
self.assertEqual(t.device, s(t, "t.to('cpu', dtype=torch.float32)").device)
|
|
self.assertIs(torch.float32, s(t, "t.to('cpu', dtype=torch.float32)").dtype)
|
|
self.assertEqual(t.device, s(t, "t.to(torch.float32)").device)
|
|
self.assertIs(torch.float32, s(t, "t.to(dtype=torch.float32)").dtype)
|
|
self.assertEqual(t.data_ptr(), s(t, "t.to('cpu')").data_ptr())
|
|
self.assertEqual(t.data_ptr(), s(t, "t.to(dtype=t.dtype, device=t.device, copy=False)").data_ptr())
|
|
self.assertEqual(t.data_ptr(), s(t, "t.to('cpu', copy=False)").data_ptr())
|
|
self.assertNotEqual(t.data_ptr(), s(t, "t.to('cpu', copy=True)").data_ptr())
|
|
|
|
a = torch.tensor(5)
|
|
if torch.cuda.is_available():
|
|
for non_blocking in [True, False]:
|
|
for cuda in ['cuda', 'cuda:0' if torch.cuda.device_count() == 1 else 'cuda:1']:
|
|
b = torch.tensor(5., device=cuda)
|
|
test_copy_behavior(b, non_blocking)
|
|
self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
|
|
self.assertEqual(a.device, s(b, "t.to('cpu', non_blocking=non_blocking).device"))
|
|
self.assertEqual(b.device, s(b, "t.to(cuda, non_blocking=non_blocking).device", cuda=cuda))
|
|
self.assertIs(torch.int32, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").dtype)
|
|
self.assertEqual(a.device, s(b, "t.to('cpu', dtype=torch.int32, non_blocking=non_blocking)").device)
|
|
self.assertIs(torch.int32, s(b, "t.to(dtype=torch.int32)").dtype)
|
|
self.assertEqual(b.device, s(b, "t.to(dtype=torch.int32)").device)
|
|
|
|
# Test AD: aten::to(Tensor self, int dtype, bool non_blocking, bool copy) -> Tensor
|
|
t = torch.tensor(5).float().requires_grad_()
|
|
out_ref = t.to(torch.float32)
|
|
out = s(t, "t.to(torch.float32)")
|
|
self.assertEqual(out_ref, out)
|
|
|
|
grad_ref = torch.autograd.grad(out_ref.sum(), t)
|
|
grad = torch.autograd.grad(out.sum(), t)
|
|
self.assertEqual(grad_ref, grad)
|
|
|
|
# Test AD: aten::to(Tensor self, Device? device, int? dtype, bool non_blocking, bool copy) -> Tensor
|
|
out_ref = t.to('cpu')
|
|
out = s(t, "t.to('cpu')")
|
|
self.assertEqual(out_ref, out)
|
|
|
|
grad_ref = torch.autograd.grad(out_ref.sum(), t)
|
|
grad = torch.autograd.grad(out.sum(), t)
|
|
self.assertEqual(grad_ref, grad)
|
|
|
|
# Test AD: aten::to(Tensor self, Tensor other, bool non_blocking, bool copy) -> Tensor
|
|
@torch.jit.script
|
|
def func2(t, t_ref):
|
|
return t.to(t_ref)
|
|
|
|
func2.debug_disable_autodiff_subgraph_inlining()
|
|
|
|
t_ref = torch.tensor(4).double()
|
|
out_ref = t.to(t_ref)
|
|
out = func2(t, t_ref)
|
|
grad_ref = torch.autograd.grad(out_ref.sum(), t)
|
|
grad = torch.autograd.grad(out.sum(), t)
|
|
self.assertEqual(grad_ref, grad)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "No CUDA")
|
|
def test_tensor_number_math_cuda(self):
|
|
self._test_tensor_number_math(device='cuda')
|
|
|
|
def test_not(self):
|
|
# test not operator in python
|
|
# TODO: add more tests when bool conversions ready
|
|
def test_not_op(a):
|
|
return not bool(a > 1)
|
|
|
|
self.checkScript(test_not_op, (torch.tensor(2), ), optimize=True)
|
|
|
|
def test_is_isnot(self):
|
|
# test is and is not operator in python
|
|
template = dedent('''
|
|
def func():
|
|
# type: () -> bool
|
|
return {lhs} {op} {rhs}
|
|
''')
|
|
|
|
def test(op, args):
|
|
code = template.format(lhs=args[0], rhs=args[1], op=op)
|
|
scope = {}
|
|
execWrapper(code, globals(), scope)
|
|
cu = torch.jit.CompilationUnit(code)
|
|
self.assertEqual(
|
|
cu.func(),
|
|
scope['func'](),
|
|
"Failed with op: {}, lhs: {}, rhs: {}"
|
|
.format(op, args[0], args[1])
|
|
)
|
|
|
|
ops = ['is', 'is not']
|
|
type_literals = [True, False, None, [1, 1]]
|
|
|
|
# do literals product to try any types combinations
|
|
for op, lhs, rhs in product(ops, type_literals, type_literals):
|
|
test(op, [lhs, rhs])
|
|
|
|
def test_isinstance(self):
|
|
# test isinstance operator for static type checking
|
|
template = dedent('''
|
|
def func(x):
|
|
# type: ({type_hint}) -> bool
|
|
return isinstance(x, {typ})
|
|
''')
|
|
|
|
def test(inp, typ, type_hint):
|
|
code = template.format(typ=typ, type_hint=type_hint)
|
|
scope = {}
|
|
execWrapper(code, globals(), scope)
|
|
cu = torch.jit.CompilationUnit(code)
|
|
self.assertEqual(
|
|
cu.func(inp),
|
|
scope['func'](inp),
|
|
"Failed with typ: {}"
|
|
.format(typ)
|
|
)
|
|
|
|
inputs = [True, 1, 1.0, torch.tensor(1), [1, 2], (1.0,), [1, 2], 1]
|
|
type_literals = ['bool', 'int', 'float', 'torch.Tensor', 'list', 'tuple',
|
|
'(list, tuple)', '(int, float, bool)']
|
|
type_annotations = ['bool', 'int', 'float', 'Tensor', 'List[int]', 'Tuple[float]',
|
|
'List[int]', 'int']
|
|
|
|
# do zipping to try different types
|
|
for inp, typ, type_hint in zip(inputs, type_literals, type_annotations):
|
|
test(inp, typ, type_hint)
|
|
|
|
# test optional isintance check
|
|
with self.assertRaisesRegex(RuntimeError, "Optional isinstance check is not supported"):
|
|
@torch.jit.script
|
|
def opt_func(x):
|
|
# type: (Optional[int]) -> bool
|
|
return isinstance(x, int)
|
|
|
|
def test_python_call(self):
|
|
def pyfunc(a):
|
|
return a * 3.0
|
|
|
|
cu = torch.jit.CompilationUnit('''
|
|
def other_func(a):
|
|
return a + a
|
|
|
|
def test_call_python(a):
|
|
b = pyfunc(a)
|
|
b = other_func(b)
|
|
i = 0
|
|
step = 1
|
|
while i < 10:
|
|
b = pyfunc(b)
|
|
if bool(b > 3.0):
|
|
b = pyfunc(b)
|
|
i = 11
|
|
return b
|
|
''')
|
|
inputs = self._make_scalar_vars([1], torch.float)
|
|
outputs = self._make_scalar_vars([54], torch.float)
|
|
|
|
self.assertEqual(cu.test_call_python(*inputs), outputs[0])
|
|
|
|
def test_python_call_failure(self):
|
|
with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
|
|
def pyfunc(a):
|
|
return a * 3.0
|
|
|
|
cu = torch.jit.CompilationUnit('''
|
|
def other_func(a):
|
|
return a + a
|
|
|
|
def test_call_python(a):
|
|
b = pyfunc(a)
|
|
b = other_func(b)
|
|
i = 0
|
|
step = 1
|
|
while i < 10:
|
|
b = pyfunc2(b)
|
|
if b > 3.0:
|
|
b = pyfunc(b)
|
|
i = 11
|
|
return b
|
|
''')
|
|
inputs = self._make_scalar_vars([1], torch.float)
|
|
outputs = self._make_scalar_vars([54], torch.float)
|
|
|
|
self.assertEqual(cu.test_call_python(*inputs), outputs)
|
|
|
|
def test_python_call_annotation(self):
|
|
def pyfunc(a):
|
|
return a * 3.0
|
|
|
|
@torch.jit.script
|
|
def foo(a):
|
|
return pyfunc(a) + pyfunc(a)
|
|
|
|
inputs = self._make_scalar_vars([1], torch.float)
|
|
outputs = self._make_scalar_vars([6], torch.float)
|
|
self.assertEqual(foo(*inputs), outputs[0])
|
|
|
|
def test_python_call_annoytation_failure(self):
|
|
with self.assertRaisesRegex(RuntimeError, "undefined value pyfunc2"):
|
|
def pyfunc(a):
|
|
return a * 3.0
|
|
|
|
@torch.jit.script
|
|
def foo(a):
|
|
return pyfunc2(a) + pyfunc(a)
|
|
|
|
inputs = self._make_scalar_vars([1], torch.float)
|
|
outputs = self._make_scalar_vars([6], torch.float)
|
|
|
|
self.assertEqual(foo(*inputs), outputs[0])
|
|
|
|
def test_desugar_module(self):
|
|
import torch.nn.functional as F
|
|
|
|
def fn(x, slope):
|
|
a = torch.abs(x)
|
|
b = torch.nn.functional.prelu(x, slope)
|
|
c = F.prelu(x, slope)
|
|
return a, b, c
|
|
|
|
x = torch.arange(-3., 4)
|
|
slope = torch.tensor([0.5])
|
|
self.checkScript(fn, [x, slope], optimize=True)
|
|
|
|
def test_script_docstring(self):
|
|
@torch.jit.script
|
|
def with_docstring(x):
|
|
"""test str"""
|
|
y = x
|
|
"""y is the same as x"""
|
|
return y
|
|
self.assertEqual(with_docstring.__doc__, 'test str')
|
|
|
|
def test_script_method_docstring(self):
|
|
class A(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def with_docstring(self, x):
|
|
"""test str"""
|
|
y = x
|
|
"""y is the same as x"""
|
|
return y
|
|
a = A()
|
|
self.assertEqual(a.with_docstring.__doc__, 'test str')
|
|
|
|
@unittest.skipIf(TEST_WITH_UBSAN or not torch.fbgemm_is_cpu_supported(),
|
|
'Quantized RNN requires FBGEMM. FBGEMM does not play'
|
|
' well with UBSAN at the moment, so we skip the test if'
|
|
' we are in a UBSAN environment.')
|
|
def test_rnn_cell_quantized(self):
|
|
d_in, d_hid = 2, 2
|
|
|
|
for cell in [
|
|
torch.nn.LSTMCell(d_in, d_hid).float(),
|
|
torch.nn.GRUCell(d_in, d_hid).float(),
|
|
torch.nn.RNNCell(d_in, d_hid).float(),
|
|
]:
|
|
if isinstance(cell, torch.nn.LSTMCell):
|
|
num_chunks = 4
|
|
elif isinstance(cell, torch.nn.GRUCell):
|
|
num_chunks = 3
|
|
elif isinstance(cell, torch.nn.RNNCell):
|
|
num_chunks = 1
|
|
|
|
# Replace parameter values s.t. the range of values is exactly
|
|
# 255, thus we will have 0 quantization error in the quantized
|
|
# GEMM call. This i s for testing purposes.
|
|
#
|
|
# Note that the current implementation does not support
|
|
# accumulation values outside of the range representable by a
|
|
# 16 bit integer, instead resulting in a saturated value. We
|
|
# must take care that in our test we do not end up with a dot
|
|
# product that overflows the int16 range, e.g.
|
|
# (255*127+255*127) = 64770. So, we hardcode the test values
|
|
# here and ensure a mix of signedness.
|
|
vals = [[100, -155],
|
|
[100, -155],
|
|
[-155, 100],
|
|
[-155, 100],
|
|
[100, -155],
|
|
[-155, 100],
|
|
[-155, 100],
|
|
[100, -155]]
|
|
vals = vals[:d_hid * num_chunks]
|
|
cell.weight_ih = torch.nn.Parameter(
|
|
torch.tensor(vals, dtype=torch.float),
|
|
requires_grad=False)
|
|
cell.weight_hh = torch.nn.Parameter(
|
|
torch.tensor(vals, dtype=torch.float),
|
|
requires_grad=False)
|
|
|
|
ref = copy.deepcopy(cell)
|
|
|
|
cell = torch.jit.quantized.quantize_rnn_cell_modules(cell)
|
|
x = torch.tensor([[100, -155],
|
|
[-155, 100],
|
|
[100, -155]], dtype=torch.float)
|
|
h0_vals = [[-155, 100],
|
|
[-155, 155],
|
|
[100, -155]]
|
|
hx = torch.tensor(h0_vals, dtype=torch.float)
|
|
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
|
|
cx = torch.tensor(h0_vals, dtype=torch.float)
|
|
hiddens = (hx, cx)
|
|
else:
|
|
hiddens = hx
|
|
|
|
if isinstance(cell, torch.jit.quantized.QuantizedLSTMCell):
|
|
class ScriptWrapper(torch.jit.ScriptModule):
|
|
def __init__(self, cell):
|
|
super(ScriptWrapper, self).__init__()
|
|
self.cell = cell
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x, hiddens):
|
|
# type: (torch.Tensor, Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]
|
|
return self.cell(x, hiddens)
|
|
else:
|
|
|
|
class ScriptWrapper(torch.jit.ScriptModule):
|
|
def __init__(self, cell):
|
|
super(ScriptWrapper, self).__init__()
|
|
self.cell = cell
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x, hiddens):
|
|
# type: (torch.Tensor, torch.Tensor) -> torch.Tensor
|
|
return self.cell(x, hiddens)
|
|
|
|
cell = ScriptWrapper(cell)
|
|
outs = cell(x, hiddens)
|
|
cell = self.getExportImportCopyWithPacking(cell)
|
|
|
|
outs = cell(x, hiddens)
|
|
ref_outs = ref(x, hiddens)
|
|
|
|
self.assertEqual(len(outs), len(ref_outs))
|
|
for out, ref_out in zip(outs, ref_outs):
|
|
torch.testing.assert_allclose(out, ref_out)
|
|
|
|
def test_script_module(self):
|
|
class M1(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M1, self).__init__(False)
|
|
self.weight = nn.Parameter(torch.randn(2))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, thing):
|
|
return self.weight + thing
|
|
|
|
class PModule(nn.Module):
|
|
def __init__(self):
|
|
super(PModule, self).__init__()
|
|
self.a = nn.Parameter(torch.randn(2, 3))
|
|
|
|
def forward(self, a):
|
|
return self.a.mm(a)
|
|
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(False)
|
|
# test submodule
|
|
self.sub = M1()
|
|
self.sub2 = PModule()
|
|
# test parameters
|
|
self.weight = nn.Parameter(torch.randn(2, 3))
|
|
self.bias = nn.Parameter(torch.randn(2))
|
|
# test defining a method from a string
|
|
self.define("""
|
|
def hi(self, a):
|
|
return self.weight.mm(a)
|
|
""")
|
|
# test script methods
|
|
|
|
@torch.jit.script_method
|
|
def doit(self, input):
|
|
# test use of parameter
|
|
return self.weight.mm(input)
|
|
|
|
@torch.jit.script_method
|
|
def doit2(self, input):
|
|
return self.weight.mm(input)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
a = self.doit(input)
|
|
b = self.doit2(input)
|
|
c = self.hi(input)
|
|
d = self.sub2(input)
|
|
return a + b + self.bias + self.sub(a) + c + d
|
|
m2 = M2()
|
|
input = torch.randn(3, 2)
|
|
a = m2.weight.mm(input)
|
|
b = m2.weight.mm(input)
|
|
c = m2.weight.mm(input)
|
|
d = m2.sub2.a.mm(input)
|
|
ref = a + b + m2.bias + m2.sub.weight + a + c + d
|
|
self.assertEqual(ref, m2.forward(input))
|
|
m2.weight = nn.Parameter(torch.zeros_like(m2.weight))
|
|
m2.bias = nn.Parameter(torch.zeros_like(m2.bias))
|
|
m2.sub.weight = nn.Parameter(torch.zeros_like(m2.sub.weight))
|
|
m2.sub2.a.data.zero_()
|
|
self.assertEqual(torch.zeros(2, 2), m2.forward(torch.randn(3, 2)))
|
|
|
|
def test_filecheck(self):
|
|
def test_check():
|
|
file = "232"
|
|
FileCheck().check("2").check("3").check("2").run(file)
|
|
FileCheck().check("232").run(file)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
|
|
FileCheck().check("22").run(file)
|
|
with self.assertRaisesRegex(RuntimeError, "CHECK: 3"):
|
|
FileCheck().check("3").check("3").run(file)
|
|
|
|
test_check()
|
|
|
|
def test_check_count():
|
|
file = "22222"
|
|
FileCheck().check_count("2", 5).run(file)
|
|
FileCheck().check_count("22", 2).run(file)
|
|
FileCheck().check_count("222", 1).run(file)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected to not find'):
|
|
FileCheck().check_count("2", 4, exactly=True).run(file)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected to find "22"'):
|
|
FileCheck().check_count("22", 3).run(file)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "CHECK-COUNT-6: 2"):
|
|
FileCheck().check_count("2", 6).run(file)
|
|
|
|
test_check_count()
|
|
|
|
def test_check_same():
|
|
file = "22\n33"
|
|
FileCheck().check_same("22").run(file)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
|
|
FileCheck().check_same("33").run(file)
|
|
|
|
file = "22 1 3"
|
|
|
|
FileCheck().check("2").check_same("3").run(file)
|
|
FileCheck().check_count("2", 2).check_same("3").run(file)
|
|
|
|
test_check_same()
|
|
|
|
def test_check_next():
|
|
file = "\n1\n2\n3"
|
|
FileCheck().check("1").check_next("2").check_next("3").run(file)
|
|
FileCheck().check_next("1").check_next("2").check_next("3").run(file)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Expected to find"):
|
|
FileCheck().check("1").check_next("2").run("12")
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Expected to not find"):
|
|
FileCheck().check("1").check_next("2").run("1\n\n2")
|
|
|
|
test_check_next()
|
|
|
|
def test_check_dag():
|
|
fc = FileCheck().check_dag("1").check_dag("2").check_not("2")
|
|
fc.run("12")
|
|
fc.run("21")
|
|
|
|
fc = FileCheck()
|
|
fc.check_not("3").check_dag("1").check_dag("2").check_not("3")
|
|
fc.run("1 3 2")
|
|
fc.run("2 3 1")
|
|
|
|
fc = FileCheck().check_dag("1").check_dag("2").check("3")
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected to find "3" but did not find it'):
|
|
fc.run("1 3 2")
|
|
|
|
test_check_dag()
|
|
|
|
def test_check_not():
|
|
FileCheck().check_not("2").check("1").run("12")
|
|
FileCheck().check("2").check_not("2").run("12")
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
|
|
FileCheck().check_not("2").check("1").run("21")
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
|
|
FileCheck().check("2").check_not("1").run("21")
|
|
|
|
# checks with distinct range matchings
|
|
fb = FileCheck().check_count("2", 2).check_count("2", 2).check_not("2")
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected to not find "2"'):
|
|
fb.run("22 2 22")
|
|
|
|
fb = FileCheck().check_count("2", 2).check_not("1").check_count("2", 2)
|
|
with self.assertRaisesRegex(RuntimeError, 'Expected to not find "1"'):
|
|
fb.run("22 1 22")
|
|
|
|
def test_script_module_call_noscript(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.value = 1
|
|
|
|
def foo(self):
|
|
return torch.ones(2, 2) + self.value
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
return input + self.foo()
|
|
|
|
m = M()
|
|
input = torch.randn(2, 2)
|
|
o = m(input)
|
|
self.assertEqual(o, input + torch.ones(2, 2) + 1)
|
|
# check that we can change python attributes
|
|
# and that those changes are picked up in script methods
|
|
m.value = 2
|
|
o = m(input)
|
|
self.assertEqual(o, input + torch.ones(2, 2) + 2)
|
|
|
|
def test_script_module_nochange_submodule(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.sub = nn.Linear(5, 5)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
return self.sub(input)
|
|
|
|
m = M()
|
|
input = torch.randn(1, 5, 5)
|
|
o = m(input)
|
|
self.assertEqual(o, m.sub(input))
|
|
with self.assertRaisesRegex(RuntimeError, "cannot re-assign"):
|
|
m.sub = nn.Linear(5, 5)
|
|
|
|
def test_script_inline_trace_multiple_args(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
|
|
def forward(self, input, input2):
|
|
return input + input2
|
|
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(False)
|
|
self.m = torch.jit.trace(M(), (torch.zeros(4, 3), torch.zeros(4, 3)))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, inp):
|
|
return self.m(inp, inp)
|
|
|
|
m2 = M2()
|
|
m2(torch.zeros(4, 3))
|
|
|
|
def test_script_module_const(self):
|
|
class M(torch.jit.ScriptModule):
|
|
|
|
__constants__ = ['b', 'i', 'c']
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.b = False
|
|
self.i = 1
|
|
self.c = 3.5
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
return self.b, self.i, self.c
|
|
|
|
m = M()
|
|
o0, o1, o2 = m()
|
|
self.assertEqual(o0, 0)
|
|
self.assertEqual(o1, 1)
|
|
self.assertEqual(o2, 3.5)
|
|
|
|
def test_script_module_fail_const(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.b = False
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
return self.b
|
|
with self.assertRaisesRegex(RuntimeError, "is not usable in a script method"):
|
|
M()
|
|
|
|
def test_script_module_valid_consts(self):
|
|
tester = self
|
|
|
|
class Foo(torch.jit.ScriptModule):
|
|
__constants__ = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i']
|
|
|
|
def __init__(self):
|
|
super(Foo, self).__init__(False)
|
|
self.a = 1
|
|
self.b = 1.2
|
|
self.c = False
|
|
with tester.assertRaisesRegex(
|
|
TypeError,
|
|
"'Linear' object for attribute 'd' is not a valid constant"):
|
|
self.d = [nn.Linear(3, 4)]
|
|
self.e = lambda x: x
|
|
self.f = [3, 4, 5]
|
|
tester.assertTrue(type(self.f) is tuple)
|
|
self.g = [3, (3, 4), 5]
|
|
with tester.assertRaisesRegex(TypeError, "not a valid constant"):
|
|
self.h = type(1)
|
|
with tester.assertRaisesRegex(TypeError, "not a valid constant"):
|
|
self.i = (3, 4, {})
|
|
|
|
f = Foo()
|
|
|
|
def test_script_module_param_buffer_mutation(self):
|
|
# TODO: add param mutation test case after JIT support it
|
|
class ModuleBufferMutate(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleBufferMutate, self).__init__(False)
|
|
self.register_buffer('running_var', torch.tensor(0, dtype=torch.long))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
if self.training:
|
|
self.running_var += 1
|
|
return self.running_var
|
|
|
|
m = ModuleBufferMutate()
|
|
self.assertEqual(m(), 1)
|
|
m.eval()
|
|
self.assertEqual(m(), 1)
|
|
|
|
def test_script_module_for(self):
|
|
class M(torch.jit.ScriptModule):
|
|
__constants__ = ['b']
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.b = [1, 2, 3, 4]
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
sum = 0
|
|
for i in self.b:
|
|
sum += i
|
|
return sum
|
|
|
|
m = M()
|
|
self.assertEqual(m(), 10)
|
|
|
|
def test_script_module_for2(self):
|
|
class Sub(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Sub, self).__init__(False)
|
|
self.weight = nn.Parameter(torch.randn(2))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, thing):
|
|
return self.weight + thing
|
|
|
|
class M(torch.jit.ScriptModule):
|
|
__constants__ = ['mods']
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.mods = nn.ModuleList([Sub() for i in range(10)])
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, v):
|
|
for m in self.mods:
|
|
v = m(v)
|
|
return v
|
|
|
|
i = torch.Tensor(2)
|
|
m = M()
|
|
o = m(i)
|
|
v = i
|
|
for sub in m.mods:
|
|
v = sub(v)
|
|
self.assertEqual(o, v)
|
|
|
|
def test_script_module_const_submodule_fail(self):
|
|
class Sub(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Sub, self).__init__(False)
|
|
self.weight = nn.Parameter(torch.randn(2))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, thing):
|
|
return self.weight + thing
|
|
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.mods = [Sub() for _ in range(10)]
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
for _ in self.mods:
|
|
print(1)
|
|
return 4
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "did you forget to add it __constants__"):
|
|
M()
|
|
|
|
# Specialized error for Tensors
|
|
class S(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
self.tensor_constant = torch.ones(2)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
return self.tensor_constant + 2
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Tensors must be added to a module as a buffer or parameter"):
|
|
S()
|
|
|
|
class DerivedStateModule(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(TestScript.DerivedStateModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.ones(3, 4, dtype=torch.float))
|
|
self.register_buffer('derived', torch.neg(self.param).detach().clone())
|
|
|
|
# This is a flag so we can test that the pack method was called
|
|
self.register_buffer('pack_called', torch.zeros(1, dtype=torch.long))
|
|
# This is a flag so we can test that the unpack method was called
|
|
self.register_buffer('unpack_called', torch.zeros(1, dtype=torch.long))
|
|
|
|
@torch.jit.script_method
|
|
def _pack(self):
|
|
self.pack_called.set_(torch.ones(1, dtype=torch.long))
|
|
self.derived.set_(torch.rand(1, dtype=torch.float).detach())
|
|
|
|
@torch.jit.script_method
|
|
def _unpack(self):
|
|
self.unpack_called.set_(torch.ones(1, dtype=torch.long))
|
|
self.derived.set_(torch.neg(self.param).detach())
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x + self.derived
|
|
|
|
def test_pack_unpack_state(self):
|
|
sm = TestScript.DerivedStateModule()
|
|
x = torch.rand(3, 4, dtype=torch.float)
|
|
torch.testing.assert_allclose(sm(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
|
|
|
|
# Test save path
|
|
self.assertFalse(sm.pack_called.item())
|
|
self.assertFalse(sm.unpack_called.item())
|
|
imported = self.getExportImportCopyWithPacking(sm)
|
|
# ensure pack was called before serialization
|
|
self.assertTrue(sm.pack_called.item())
|
|
# ensure unpack was called after serialization so as to leave the module in an initialized state
|
|
self.assertTrue(sm.unpack_called.item())
|
|
|
|
torch.testing.assert_allclose(sm.derived, torch.neg(sm.param))
|
|
|
|
# Test load paths
|
|
self.assertTrue(imported.unpack_called.item())
|
|
torch.testing.assert_allclose(imported(x), x + torch.neg(torch.ones(3, 4, dtype=torch.float)))
|
|
|
|
def test_pack_unpack_nested(self):
|
|
class SubSubMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(SubSubMod, self).__init__()
|
|
self.register_buffer('buf', torch.ones(3, 4) * 3)
|
|
|
|
@torch.jit.script_method
|
|
def _pack(self):
|
|
self.buf.set_(torch.zeros(1, dtype=torch.double))
|
|
|
|
@torch.jit.script_method
|
|
def _unpack(self):
|
|
self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 3)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x + self.buf
|
|
|
|
class SubMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(SubMod, self).__init__()
|
|
self.register_buffer('buf', torch.ones(3, 4) * 2)
|
|
self.ssm = SubSubMod()
|
|
|
|
@torch.jit.script_method
|
|
def _pack(self):
|
|
self.buf.set_(torch.zeros(1, dtype=torch.double))
|
|
|
|
@torch.jit.script_method
|
|
def _unpack(self):
|
|
self.buf.set_(torch.ones(3, 4, dtype=torch.double) * 2)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.ssm(x + self.buf)
|
|
|
|
class Mod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Mod, self).__init__()
|
|
self.submod = SubMod()
|
|
self.register_buffer('buf', torch.ones(3, 4) * 1)
|
|
|
|
@torch.jit.script_method
|
|
def _pack(self):
|
|
self.buf.set_(torch.zeros(1, dtype=torch.double))
|
|
|
|
@torch.jit.script_method
|
|
def _unpack(self):
|
|
self.buf.set_(torch.ones(3, 4, dtype=torch.double))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.submod(x + self.buf)
|
|
|
|
m = Mod()
|
|
torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
|
|
m.apply(lambda s: s._pack())
|
|
torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.zeros(3, 4))
|
|
m.apply(lambda s: s._unpack())
|
|
torch.testing.assert_allclose(m(torch.zeros(3, 4)), torch.ones(3, 4) * 6)
|
|
|
|
def test_script_module_not_tuple(self):
|
|
class M(torch.jit.ScriptModule):
|
|
__constants__ = ['mods']
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.mods = 1
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, v):
|
|
for m in self.mods:
|
|
print(m)
|
|
return v
|
|
with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
|
|
M()
|
|
|
|
def test_script_module_list_sequential_error(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self, mod_list):
|
|
super(M, self).__init__(False)
|
|
self.mods = mod_list
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, v):
|
|
for m in self.mods:
|
|
v = m(v)
|
|
return v
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
|
|
a = M(nn.Sequential(nn.ReLU()))
|
|
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
|
|
a = M(nn.ModuleList([nn.ReLU()]))
|
|
|
|
def test_attr_module_constants_error(self):
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self, mod_list):
|
|
super(M2, self).__init__(False)
|
|
self.mods = mod_list
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, v):
|
|
return self.mods.forward(x)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
|
|
M2(nn.Sequential(nn.ReLU()))
|
|
with self.assertRaisesRegex(RuntimeError, "Did you forget to add it to __constants"):
|
|
M2(nn.ModuleList([nn.ReLU()]))
|
|
|
|
def test_script_sequential_for(self):
|
|
class Sub(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Sub, self).__init__(False)
|
|
self.weight = nn.Parameter(torch.randn(2))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, thing):
|
|
return self.weight + thing
|
|
|
|
class M(torch.jit.ScriptModule):
|
|
__constants__ = ['mods']
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.mods = nn.Sequential(Sub(), Sub(), Sub())
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, v):
|
|
for m in self.mods:
|
|
v = m(v)
|
|
return v
|
|
|
|
@torch.jit.script_method
|
|
def forward2(self, v):
|
|
return self.mods(v)
|
|
|
|
i = torch.Tensor(2)
|
|
m = M()
|
|
o = m(i)
|
|
v = i
|
|
for sub in m.mods:
|
|
v = sub(v)
|
|
self.assertEqual(o, v)
|
|
|
|
o2 = m.forward2(i)
|
|
self.assertEqual(o2, v)
|
|
|
|
def test_script_sequential_multi_output_fail(self):
|
|
class Sub(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Sub, self).__init__(False)
|
|
self.weight = nn.Parameter(torch.randn(2))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, thing):
|
|
return self.weight + thing
|
|
|
|
class ReturnMulti(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ReturnMulti, self).__init__(False)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x, x, x
|
|
|
|
class HaveSequential(torch.jit.ScriptModule):
|
|
__constants__ = ['someseq']
|
|
|
|
def __init__(self):
|
|
super(HaveSequential, self).__init__(False)
|
|
self.someseq = nn.Sequential(
|
|
Sub(),
|
|
ReturnMulti(),
|
|
Sub()
|
|
)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.someseq(x)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "(Tensor, Tensor, Tensor)"):
|
|
hs = HaveSequential()
|
|
i = torch.Tensor(2)
|
|
hs(i)
|
|
|
|
def test_constant_insert_fail_lint(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
y = x + 1
|
|
z = torch.tensor([[1.0, 2.5]])
|
|
print(x, z)
|
|
|
|
# check that it doesnt error
|
|
self.run_pass('constant_propagation', foo.graph)
|
|
self.assertTrue("aten::tensor" in str(foo.graph)) # not constant propped
|
|
|
|
def test_script_sequential_in_mod_list(self):
|
|
class Sub(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Sub, self).__init__(False)
|
|
self.weight = nn.Parameter(torch.randn(2))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, thing):
|
|
return self.weight + thing
|
|
|
|
class M(torch.jit.ScriptModule):
|
|
__constants__ = ['mods']
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.mods = nn.ModuleList([Sub(), nn.Sequential(Sub(), nn.Sequential(Sub(), Sub()), Sub())])
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, v):
|
|
for mod in self.mods:
|
|
v = mod(v)
|
|
return v
|
|
|
|
m = M()
|
|
graph = str(m.graph)
|
|
self.assertTrue(graph.count("aten::add") == 5)
|
|
self.assertTrue("python" not in graph)
|
|
|
|
def test_script_nested_mod_list(self):
|
|
class Sub(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Sub, self).__init__(False)
|
|
self.weight = nn.Parameter(torch.randn(2))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, thing):
|
|
return self.weight + thing
|
|
|
|
class M(torch.jit.ScriptModule):
|
|
__constants__ = ['mods']
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.mods = nn.ModuleList([nn.ModuleList([Sub()]), nn.Sequential(Sub()), nn.ModuleList([Sub(), Sub()])])
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, v):
|
|
for mod in self.mods:
|
|
for m in mod:
|
|
v = m(v)
|
|
return v
|
|
|
|
m = M()
|
|
graph = str(m.graph)
|
|
self.assertTrue(graph.count("aten::add") == 4)
|
|
self.assertTrue("python" not in graph)
|
|
|
|
def test_constant_as_attr(self):
|
|
class M(torch.jit.ScriptModule):
|
|
__constants__ = ['dim']
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.dim = 1
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, v):
|
|
return torch.cat([v, v, v], dim=self.dim)
|
|
v = torch.zeros(1, 1)
|
|
self.assertEqual(torch.cat([v, v, v], dim=1), M()(v))
|
|
|
|
class StarTestSumStarred(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TestScript.StarTestSumStarred, self).__init__()
|
|
|
|
def forward(self, *inputs):
|
|
output = inputs[0]
|
|
for i in range(1, len(inputs)):
|
|
output += inputs[i]
|
|
return output
|
|
|
|
class StarTestReturnThree(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TestScript.StarTestReturnThree, self).__init__()
|
|
|
|
def forward(self, rep):
|
|
return rep, rep, rep
|
|
|
|
def test_script_star_expr(self):
|
|
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(True)
|
|
self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
|
|
(torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
|
|
self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, rep):
|
|
tup = self.g(rep)
|
|
return self.m(*tup)
|
|
|
|
m = M2()
|
|
self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
|
|
|
|
def test_script_star_expr_string(self):
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(True)
|
|
self.m = torch.jit.trace(TestScript.StarTestSumStarred(),
|
|
(torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)))
|
|
self.g = torch.jit.trace(TestScript.StarTestReturnThree(), torch.ones(4, 3))
|
|
|
|
self.define('''
|
|
def forward(self, rep):
|
|
tup = self.g(rep)
|
|
return self.m(*tup)
|
|
''')
|
|
|
|
m = M2()
|
|
self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
|
|
|
|
class StarTestSumAndReturnThree(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TestScript.StarTestSumAndReturnThree, self).__init__()
|
|
|
|
def forward(self, *inputs):
|
|
output = inputs[0]
|
|
for i in range(1, len(inputs)):
|
|
output += inputs[i]
|
|
return output, output, output
|
|
|
|
def test_script_star_assign(self):
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(True)
|
|
self.g = torch.jit.trace(TestScript.StarTestSumAndReturnThree(), torch.ones(4, 3))
|
|
self.define('''
|
|
def forward(self, rep):
|
|
head, *tail = self.g(rep)
|
|
return head
|
|
''')
|
|
|
|
m = M2()
|
|
self.assertEqual(m(torch.zeros(4, 3)), 3 * torch.zeros(4, 3))
|
|
|
|
def test_script_module_star_assign2(self):
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(True)
|
|
self.g = torch.jit.trace(
|
|
TestScript.StarTestSumAndReturnThree(),
|
|
(torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
|
|
_force_outplace=True)
|
|
self.define('''
|
|
def forward(self, rep):
|
|
*head, tail = self.g(rep, rep, rep)
|
|
return tail
|
|
''')
|
|
|
|
m = M2()
|
|
self.assertEqual(m(torch.ones(4, 3)), 3 * torch.ones(4, 3))
|
|
|
|
def test_script_module_star_assign2_inplace(self):
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(True)
|
|
self.g = torch.jit.trace(
|
|
TestScript.StarTestSumAndReturnThree(),
|
|
(torch.ones(4, 3), torch.ones(4, 3), torch.ones(4, 3)),
|
|
_force_outplace=False)
|
|
self.define('''
|
|
def forward(self, rep):
|
|
*head, tail = self.g(rep, rep, rep)
|
|
return tail
|
|
''')
|
|
|
|
m = M2()
|
|
# since forward() makes three aliases to the input `rep` before passing
|
|
# it to StarTestSumAndReturnThree(), in-place behavior will be different
|
|
# than the above out of place.
|
|
self.assertEqual(m(torch.ones(4, 3)), 4 * torch.ones(4, 3))
|
|
|
|
def test_script_module_star_assign_fail_pythonop(self):
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(True)
|
|
|
|
def myfunc():
|
|
return torch.zeros(1, 2, 3), torch.zeros(1, 2, 3)
|
|
|
|
self.define('''
|
|
def forward(self, rep):
|
|
a, *b = myfunc()
|
|
return a
|
|
''')
|
|
|
|
m = M2()
|
|
m(torch.zeros(4, 3))
|
|
|
|
def test_script_module_star_assign_fail_builtin(self):
|
|
with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(True)
|
|
|
|
self.define('''
|
|
def forward(self, rep):
|
|
a, *b = torch.neg(rep)
|
|
return a
|
|
''')
|
|
|
|
m = M2()
|
|
m(torch.zeros(4, 3))
|
|
|
|
def test_pack_padded_pad_packed_trace(self):
|
|
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
|
T, B, C = 3, 5, 7
|
|
|
|
class PadPackedWrapper(torch.nn.Module):
|
|
def __init__(self):
|
|
super(PadPackedWrapper, self).__init__()
|
|
|
|
def forward(self, x, seq_lens):
|
|
x = pack_padded_sequence(x, seq_lens)
|
|
x, _ = pad_packed_sequence(x)
|
|
return x
|
|
|
|
x = np.ones((T, B, C))
|
|
seq_lens = np.array([3, 3, 2, 2, 1], dtype=np.int32)
|
|
# set padding value so we can test equivalence
|
|
for b in range(B):
|
|
if seq_lens[b] < T:
|
|
x[seq_lens[b]:, b, :] = 0
|
|
seq_lens = torch.from_numpy(seq_lens)
|
|
x = torch.autograd.Variable(torch.from_numpy(x), requires_grad=True)
|
|
|
|
m = PadPackedWrapper()
|
|
m_traced = torch.jit.trace(m, (x, seq_lens,))
|
|
|
|
y = m(x, seq_lens)
|
|
loss = torch.sum(y)
|
|
loss.backward()
|
|
grad = x.grad.clone()
|
|
x.grad.zero_()
|
|
|
|
y_traced = m_traced(x, seq_lens)
|
|
loss_traced = torch.sum(y_traced)
|
|
loss_traced.backward()
|
|
grad_traced = x.grad.clone()
|
|
|
|
self.assertEqual(y_traced, x)
|
|
self.assertEqual(y_traced, y)
|
|
self.assertEqual(grad, grad_traced)
|
|
|
|
f = io.BytesIO()
|
|
torch.onnx._export(m, (x, seq_lens), f, verbose=False)
|
|
|
|
def test_script_outputs(self):
|
|
with self.assertRaisesRegex(RuntimeError, "cannot be used as a tuple"):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
c, d = a + a
|
|
return c + d
|
|
|
|
@torch.jit.script
|
|
def return3():
|
|
return 1, 2, 3
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "too many values to unpack"):
|
|
@torch.jit.script
|
|
def bind2():
|
|
a, b = return3()
|
|
print(a)
|
|
print(b)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
def test_script_get_device_cuda(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
return a.get_device()
|
|
|
|
v = torch.randn(1, device='cuda')
|
|
self.assertEqual(foo(v), 0)
|
|
|
|
def test_script_chunk(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
b, c = torch.chunk(a, dim=0, chunks=2)
|
|
return b
|
|
v = torch.rand(10, 3)
|
|
self.assertEqual(torch.chunk(v, dim=0, chunks=2)[0], foo(v))
|
|
|
|
def test_rnn_trace_override(self):
|
|
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
|
num_layers = 3
|
|
T, B, C = 11, 5, 7
|
|
|
|
class RNNTraceWrapper(torch.nn.Module):
|
|
def __init__(self, cell_type):
|
|
super(RNNTraceWrapper, self).__init__()
|
|
if cell_type == 'RNN':
|
|
self.rnn = torch.nn.RNN(input_size=C, hidden_size=C, num_layers=num_layers)
|
|
elif cell_type == 'LSTM':
|
|
self.rnn = torch.nn.LSTM(input_size=C, hidden_size=C, num_layers=num_layers)
|
|
elif cell_type == 'GRU':
|
|
self.rnn = torch.nn.GRU(input_size=C, hidden_size=C, num_layers=num_layers)
|
|
|
|
def forward(self, x, seq_lens):
|
|
x = pack_padded_sequence(x, seq_lens)
|
|
x, _ = self.rnn(x)
|
|
x, _ = pad_packed_sequence(x)
|
|
return x
|
|
|
|
for cell_type in ['RNN', 'LSTM', 'GRU']:
|
|
x = torch.ones(T, B, C, requires_grad=True)
|
|
seq_lens = torch.from_numpy(np.array([11, 3, 2, 2, 1], dtype=np.int32))
|
|
|
|
m = RNNTraceWrapper(cell_type)
|
|
m_traced = torch.jit.trace(m, (x, seq_lens,))
|
|
|
|
y = m(x, seq_lens)
|
|
loss = torch.sum(y)
|
|
loss.backward()
|
|
grad = x.grad.clone()
|
|
x.grad.zero_()
|
|
|
|
y_traced = m_traced(x, seq_lens)
|
|
loss_traced = torch.sum(y_traced)
|
|
loss_traced.backward()
|
|
grad_traced = x.grad.clone()
|
|
|
|
self.assertEqual(y_traced, y)
|
|
self.assertEqual(grad, grad_traced)
|
|
|
|
f = io.BytesIO()
|
|
torch.onnx._export(m, (x, seq_lens), f, verbose=False)
|
|
|
|
def test_python_call_non_tensor(self):
|
|
def foo(a, b, c):
|
|
# type: (Tensor, int, Tuple[Tensor, int]) -> Tuple[int, Tensor]
|
|
d, e = c
|
|
return b + e, a + d
|
|
|
|
@torch.jit.script
|
|
def bar():
|
|
x = torch.ones(3, 4)
|
|
a, b = foo(x, 3, (x, 3))
|
|
return a, b
|
|
|
|
self.assertEqual((6, torch.ones(3, 4) + 1), bar())
|
|
|
|
def test_python_call_non_tensor_wrong(self):
|
|
with self.assertRaisesRegex(RuntimeError, r"but instead got value of type tuple"):
|
|
def foo():
|
|
# type: () -> Tensor
|
|
return ((3, 4),) # noqa: T484
|
|
|
|
@torch.jit.script
|
|
def bar():
|
|
return foo()
|
|
|
|
bar()
|
|
|
|
def test_tuples(self):
|
|
def foo(i):
|
|
a = (i + 4, i * 2)
|
|
c = a
|
|
# some nonsense with if-statements and loops to check
|
|
# that tuple lowering doesn't fail
|
|
if True:
|
|
c = (i * 9, i + 1)
|
|
t0, t1 = c
|
|
while False:
|
|
t0, t1 = c
|
|
c = (t1, t0)
|
|
x = (1,)
|
|
y = 1,
|
|
return t0, x, y
|
|
|
|
v = torch.rand(10, 3)
|
|
self.checkScript(foo, (v,))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r"variable 'a' previously has type \(Tensor, Tensor\)"):
|
|
@torch.jit.script
|
|
def mixtypes(x):
|
|
a = (x, x)
|
|
if True:
|
|
a = 4
|
|
|
|
def test_if_tuple_sizes(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Type mismatch"):
|
|
@torch.jit.script
|
|
def diff_tuple_sizes(x):
|
|
if False:
|
|
c0 = ((x, x), (x, x, x))
|
|
else:
|
|
c0 = ((x, x, x), (x, x))
|
|
return c0
|
|
|
|
def test_if_different_type(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Type mismatch: c0 is set to type int "
|
|
"in the true branch and type float in the false branch:"):
|
|
@torch.jit.script
|
|
def diff_type_used():
|
|
if False:
|
|
c0 = 1
|
|
else:
|
|
c0 = 1.0
|
|
return c0
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "variable 'c0' previously has type float"):
|
|
@torch.jit.script
|
|
def diff_existing_type(x):
|
|
c0 = 1.0
|
|
if False:
|
|
c0 = 1
|
|
print(x)
|
|
return x
|
|
|
|
@torch.jit.script
|
|
def diff_type_unused():
|
|
if True:
|
|
c0 = 1
|
|
print(c0)
|
|
else:
|
|
c0 = 1.0
|
|
print(c0)
|
|
return 1
|
|
|
|
def test_if_list_cat(self):
|
|
# testing that different length lists don't throw error on cat in shape prop
|
|
@torch.jit.script
|
|
def test_list(x):
|
|
if bool(x.sum() < 1):
|
|
c = [x, x]
|
|
else:
|
|
c = [x, x, x]
|
|
return torch.cat(c)
|
|
|
|
b = torch.zeros(2, 4)
|
|
test_list.graph.propagate_shapes((b,), False)
|
|
|
|
def test_if_supertype(self):
|
|
@torch.jit.script
|
|
def tensor_unifying(x, y, z):
|
|
# testing dynamic is appropriately set for y and z
|
|
if True:
|
|
x, y, z = x, y, z
|
|
else:
|
|
x, y, z = x, x, y
|
|
|
|
return x, y, z
|
|
|
|
a = torch.zeros(2, 2, dtype=torch.float)
|
|
b = torch.zeros(2, 4, dtype=torch.long)
|
|
c = torch.zeros(2, 4, dtype=torch.float)
|
|
|
|
tensor_unifying.graph.propagate_shapes((a, b, c), False)
|
|
if_outputs = list(tensor_unifying.graph.findNode("prim::If").outputs())
|
|
self.assertTrue(if_outputs[0].type().str() == "Float(*, *)")
|
|
self.assertTrue(if_outputs[1].type().str() == "Tensor")
|
|
self.assertTrue(if_outputs[2].type().str() == "Tensor")
|
|
|
|
def test_list_unify(self):
|
|
# allowing a unififed int?[] would cause a runtime error b/c
|
|
# the index operation expects int?[] to be a generic list,
|
|
# but in the true branch the IValue will be a int list
|
|
with self.assertRaisesRegex(RuntimeError, "int[] in the true branch and type None[]"):
|
|
@torch.jit.script
|
|
def list_optional_fails(x):
|
|
# type: (bool) -> Optional[int]
|
|
if x:
|
|
y = [1]
|
|
else:
|
|
y = [None] # noqa: T484
|
|
return y[0]
|
|
|
|
@torch.jit.script
|
|
def list_tensors(x):
|
|
# type: (bool) -> Tuple[Tensor, List[Tensor]]
|
|
if x:
|
|
a = torch.zeros([1, 1])
|
|
y = [a]
|
|
else:
|
|
a = torch.zeros([1, 2])
|
|
y = [a]
|
|
return a, y
|
|
|
|
self.run_pass('constant_propagation', list_tensors.graph)
|
|
m = torch.jit.ScriptModule()
|
|
m._create_method_from_graph("forward", list_tensors.graph)
|
|
# testing that tensor type of lists is unified
|
|
self.getExportImportCopy(m)
|
|
|
|
def test_type_annotations_repeated_list(self):
|
|
@torch.jit.script
|
|
def float_fn(x, y):
|
|
# type: (float, BroadcastingList3[float]) -> List[float]
|
|
return y
|
|
self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, [1.0, 1.0, 1.0]))
|
|
self.assertEqual(float_fn(2.0, 1.0), float_fn(2.0, (1.0, 1.0, 1.0)))
|
|
|
|
@torch.jit.script
|
|
def float_fn_call():
|
|
print(float_fn(1.0, 1.0))
|
|
print(float_fn(1.0, (1.0, 1.0, 1.0)))
|
|
|
|
@torch.jit.script
|
|
def int_fn(x):
|
|
# type: (BroadcastingList3[int]) -> List[int]
|
|
return x
|
|
self.assertEqual(int_fn(1), int_fn([1, 1, 1]))
|
|
self.assertEqual(int_fn(1), int_fn((1, 1, 1)))
|
|
|
|
@torch.jit.script
|
|
def int_fn_call():
|
|
print(int_fn(1))
|
|
print(int_fn((1, 1, 1)))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "must be a positive integer:"):
|
|
@torch.jit.script # noqa: T484
|
|
def fn(x):
|
|
# type: (BroadcastingListx[int]) -> List[int] # noqa: T484
|
|
return x
|
|
|
|
# using CU so that flake8 error on int[2] is not raised (noqa not working)
|
|
with self.assertRaisesRegex(RuntimeError, "Unknown type constructor"):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def nested(x, y):
|
|
# type: (int, Tuple[int, int[2]]) -> List[int]
|
|
return x # noqa: T484
|
|
''')
|
|
|
|
def test_ntuple_builtins(self):
|
|
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
|
|
|
|
def test_ints():
|
|
return _single(1), _pair(2), _triple(3), _quadruple(4)
|
|
|
|
def test_floats():
|
|
return _single(1), _pair(2.1), _triple(3.1), _quadruple(4.1)
|
|
|
|
self.checkScript(test_ints, ())
|
|
self.checkScript(test_floats, ())
|
|
|
|
def test_embedding_renorm_grad_error(self):
|
|
# Testing that the builtin call to embedding_renorm_ correctly throws
|
|
# Error when .backward() is called on its input
|
|
|
|
def embedding_norm(input, embedding_matrix, max_norm):
|
|
F.embedding(input, embedding_matrix, max_norm=0.01)
|
|
|
|
@torch.jit.script
|
|
def embedding_norm_script(input, embedding_matrix, max_norm):
|
|
# type: (Tensor, Tensor, float) -> None
|
|
F.embedding(input, embedding_matrix, max_norm=0.01)
|
|
|
|
for _ in [embedding_norm, embedding_norm_script]:
|
|
input = torch.tensor([[1, 2, 4, 5], [4, 3, 2, 9]])
|
|
embedding_matrix = torch.randn(10, 3)
|
|
|
|
var1 = torch.randn(10, 3, requires_grad=True)
|
|
var2 = var1.detach().requires_grad_()
|
|
output1 = var1 * embedding_matrix
|
|
output2 = var2 * embedding_matrix
|
|
|
|
output1.sum().backward()
|
|
|
|
ignore = F.embedding(input, embedding_matrix, max_norm=0.01)
|
|
with self.assertRaisesRegex(RuntimeError, "modified"):
|
|
output2.sum().backward()
|
|
|
|
def test_type_annotations(self):
|
|
def fn(x, y):
|
|
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
|
return x, x * 2, x * 3
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
|
|
@torch.jit.script
|
|
def script_fn(x):
|
|
x, y, z, w = fn(x, x)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
|
|
@torch.jit.script
|
|
def script_fn2(x):
|
|
x, y = fn(x, x)
|
|
|
|
def fn_unpack(x):
|
|
y, z, w = fn(x, x)
|
|
return y
|
|
|
|
def fn_index(x):
|
|
q = fn(x, x)
|
|
return x
|
|
|
|
def fn_string(str, strpair):
|
|
# type: (str, Tuple[str, str]) -> Tuple[str, int, str, str]
|
|
str1, str2 = strpair
|
|
return str, 2, str1, str2
|
|
|
|
x = torch.ones(2, 2)
|
|
self.checkScript(fn_unpack, (x,), optimize=True)
|
|
self.checkScript(fn_index, (x,), optimize=True)
|
|
self.checkScript(fn_string, ("1", ("3", "4")), optimize=True)
|
|
|
|
def test_type_annotations_varargs(self):
|
|
def fn_varargs(x, *args):
|
|
return args[0] if args else x
|
|
|
|
def fn1(x, y, z):
|
|
return fn_varargs(x)
|
|
|
|
def fn2(x, y, z):
|
|
return fn_varargs(x, y)
|
|
|
|
def fn3(x, y, z):
|
|
return fn_varargs(x, y, z)
|
|
|
|
x, y, z = [torch.randn(2, 2) for _ in range(3)]
|
|
self.checkScript(fn1, (x, y, z), optimize=True)
|
|
self.checkScript(fn2, (x, y, z), optimize=True)
|
|
self.checkScript(fn3, (x, y, z), optimize=True)
|
|
|
|
@unittest.skipIf(not PY35, "Python 3.5 needed")
|
|
def test_type_annotation_py3(self):
|
|
import importlib.util
|
|
|
|
code = dedent("""
|
|
import torch
|
|
from torch import Tensor
|
|
from typing import Tuple
|
|
|
|
def fn(x : torch.Tensor, y : Tensor, z) -> Tuple[Tensor, Tensor, Tensor]:
|
|
return (x, y + z, z)
|
|
""")
|
|
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
script_path = os.path.join(tmp_dir, 'script.py')
|
|
with open(script_path, 'w') as f:
|
|
f.write(code)
|
|
fn = get_fn('test_type_annotation_py3', script_path)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r"expected a value of type Tensor for argument"
|
|
r" '0' but found \(Tensor, Tensor\)"):
|
|
@torch.jit.script
|
|
def bad_fn(x):
|
|
x, y = fn((x, x), x, x)
|
|
return y
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r"too many values .* need 2 but found 3"):
|
|
@torch.jit.script
|
|
def bad_fn2(x):
|
|
x, y = fn(x, x, x)
|
|
return y
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r"need 4 values .* found only 3"):
|
|
@torch.jit.script
|
|
def bad_fn3(x):
|
|
x, y, z, w = fn(x, x, x)
|
|
return y
|
|
|
|
def good_fn(x):
|
|
y, z, w = fn(x, x, x)
|
|
return y, z, w
|
|
|
|
self.checkScript(good_fn, (torch.ones(2, 2),), optimize=True)
|
|
|
|
def test_type_annotation_module(self):
|
|
class BaseModule(torch.jit.ScriptModule):
|
|
def foo(self, x):
|
|
# type: (Tensor) -> Tensor
|
|
return x + 1
|
|
|
|
def bar(self, x, y):
|
|
# type: (Tensor, Tensor) -> Tuple[Tensor, Tensor]
|
|
return x + y, y
|
|
|
|
def baz(self, x, y):
|
|
return x
|
|
|
|
class ModuleTooMany(BaseModule):
|
|
@torch.jit.script_method
|
|
def method(self, x):
|
|
return self.foo(x, x)
|
|
|
|
class ModuleTooFew(BaseModule):
|
|
@torch.jit.script_method
|
|
def method(self, x):
|
|
return self.bar(x)
|
|
|
|
class ModuleTooManyAssign(BaseModule):
|
|
@torch.jit.script_method
|
|
def method(self, x):
|
|
y, z, w = self.bar(x, x)
|
|
return x
|
|
|
|
class ModuleDefault(BaseModule):
|
|
@torch.jit.script_method
|
|
def method(self, x):
|
|
y = self.baz(x)
|
|
return x
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "expected at most 1 arguments but found 2"):
|
|
ModuleTooMany()
|
|
with self.assertRaisesRegex(RuntimeError, "argument 1 not provided"):
|
|
ModuleTooFew()
|
|
with self.assertRaisesRegex(RuntimeError, "need 3 values .* found only 2"):
|
|
ModuleTooManyAssign()
|
|
with self.assertRaisesRegex(RuntimeError, "argument 1 not provided."):
|
|
ModuleDefault()
|
|
|
|
def test_script_define_order(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
pass
|
|
|
|
@torch.jit.script_method
|
|
def call_foo(self, input):
|
|
return self.foo(input)
|
|
|
|
@torch.jit.script_method
|
|
def foo(self, input):
|
|
return input + 1
|
|
m = M()
|
|
self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
|
|
|
|
def test_script_define_order_recursive_fail(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
pass
|
|
|
|
@torch.jit.script_method
|
|
def call_foo(self, input):
|
|
return self.foo(input)
|
|
|
|
@torch.jit.script_method
|
|
def foo(self, input):
|
|
self.call_foo(input)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'called recursively involving'):
|
|
M()
|
|
|
|
def test_script_kwargs_fn_call(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
pass
|
|
|
|
@torch.jit.script_method
|
|
def call_foo(self, input):
|
|
return self.foo(input=input, bar=1)
|
|
|
|
@torch.jit.script_method
|
|
def foo(self, bar, input):
|
|
# type: (int, Tensor) -> Tensor
|
|
return input + bar
|
|
m = M()
|
|
self.assertEqual(2, m.call_foo(torch.ones((), dtype=torch.int64)))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
def test_trace_of_script(self):
|
|
@torch.jit.script
|
|
def foo(a, c):
|
|
b = 0.0
|
|
if bool(a == 0.0):
|
|
b = 1.0
|
|
return b + c
|
|
|
|
a = torch.ones(1, dtype=torch.float)
|
|
|
|
@_trace(torch.zeros(1, dtype=torch.float))
|
|
def use(b):
|
|
return foo(b - 1.0, a) + 1.0
|
|
|
|
# test we propagated shapes through the function
|
|
self.assertTrue("Dynamic" not in str(use.graph))
|
|
|
|
self.assertEqual(3, use(torch.ones(1, dtype=torch.float)))
|
|
self.assertEqual(2, use(torch.zeros(1, dtype=torch.float)))
|
|
|
|
def test_if_define(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
if bool(a == 0):
|
|
b = 1
|
|
else:
|
|
b = 0
|
|
return b + 1
|
|
|
|
@torch.jit.script
|
|
def foo2(a):
|
|
b = 0
|
|
if bool(a == 0):
|
|
b = 1
|
|
return b + 1
|
|
|
|
@torch.jit.script
|
|
def foo3(a):
|
|
b = 1
|
|
if bool(a == 0):
|
|
c = 4
|
|
else:
|
|
b = 0
|
|
return b + 1
|
|
|
|
a = torch.ones(1, dtype=torch.long)
|
|
b = torch.zeros(1, dtype=torch.long)
|
|
self.assertEqual(1, foo(a))
|
|
self.assertEqual(2, foo(b))
|
|
self.assertEqual(1, foo2(a))
|
|
self.assertEqual(2, foo2(b))
|
|
self.assertEqual(1, foo3(a))
|
|
self.assertEqual(2, foo3(b))
|
|
|
|
def test_script_module_export_submodule(self):
|
|
class M1(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M1, self).__init__(False)
|
|
self.weight = nn.Parameter(torch.randn(2))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, thing):
|
|
return self.weight + thing
|
|
|
|
class M2(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M2, self).__init__(False)
|
|
# test submodule
|
|
self.sub = M1()
|
|
self.weight = nn.Parameter(torch.randn(2, 3))
|
|
self.bias = nn.Parameter(torch.randn(2))
|
|
self.define("""
|
|
def hi(self, a):
|
|
return self.weight.mm(a)
|
|
""")
|
|
|
|
@torch.jit.script_method
|
|
def doit(self, input):
|
|
return self.weight.mm(input)
|
|
|
|
@torch.jit.script_method
|
|
def doit2(self, input):
|
|
return self.weight.mm(input)
|
|
|
|
@torch.jit.script_method
|
|
def doit3(self, input):
|
|
return input + torch.ones([1], dtype=torch.double)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
a = self.doit(input)
|
|
b = self.doit2(input)
|
|
c = self.hi(input)
|
|
return a + b + self.bias + c
|
|
|
|
m_orig = M2()
|
|
m_import = self.getExportImportCopy(m_orig)
|
|
|
|
input = torch.randn(3, 2)
|
|
self.assertEqual(m_orig.doit(input), m_import.doit(input))
|
|
self.assertEqual(m_orig.hi(input), m_import.hi(input))
|
|
self.assertEqual(m_orig.doit3(input), m_import.doit3(input))
|
|
self.assertEqual(m_orig.forward(input), m_import.forward(input))
|
|
|
|
@skipIfNoTorchVision
|
|
def test_script_module_trace_resnet18(self):
|
|
x = torch.ones(1, 3, 224, 224)
|
|
m_orig = torch.jit.trace(torchvision.models.resnet18(), torch.ones(1, 3, 224, 224))
|
|
m_import = self.getExportImportCopy(m_orig)
|
|
|
|
input = torch.randn(1, 3, 224, 224, requires_grad=True)
|
|
output_orig = m_orig(input)
|
|
output_orig.sum().backward()
|
|
grad_orig = input.grad.clone()
|
|
input.grad.zero_()
|
|
|
|
output_import = m_import(input)
|
|
output_import.sum().backward()
|
|
grad_import = input.grad.clone()
|
|
|
|
self.assertEqual(output_orig, output_import)
|
|
self.assertEqual(grad_orig, grad_import)
|
|
|
|
@skipIfNoTorchVision
|
|
def test_script_module_script_resnet(self):
|
|
def conv1x1(in_planes, out_planes, stride=1):
|
|
"""1x1 convolution"""
|
|
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
|
|
|
|
def conv3x3(in_planes, out_planes, stride=1):
|
|
"""3x3 convolution with padding"""
|
|
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
|
padding=1, bias=False)
|
|
|
|
class BasicBlock(torch.jit.ScriptModule):
|
|
expansion = 1
|
|
__constants__ = ['downsample']
|
|
|
|
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
|
super(BasicBlock, self).__init__()
|
|
self.conv1 = conv3x3(inplanes, planes, stride)
|
|
self.bn1 = nn.BatchNorm2d(planes)
|
|
self.relu = nn.ReLU(inplace=True)
|
|
self.conv2 = conv3x3(planes, planes)
|
|
self.bn2 = nn.BatchNorm2d(planes)
|
|
self.downsample = downsample
|
|
self.stride = stride
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
residual = x
|
|
|
|
out = self.conv1(x)
|
|
out = self.bn1(out)
|
|
out = self.relu(out)
|
|
|
|
out = self.conv2(out)
|
|
out = self.bn2(out)
|
|
|
|
if self.downsample is not None:
|
|
residual = self.downsample(x)
|
|
|
|
out += residual
|
|
out = self.relu(out)
|
|
|
|
return out
|
|
|
|
class ResNet(torch.jit.ScriptModule):
|
|
__constants__ = ['layer1', 'layer2', 'layer3', 'layer4']
|
|
|
|
def __init__(self, block, layers, num_classes=1000):
|
|
super(ResNet, self).__init__()
|
|
self.inplanes = 64
|
|
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
|
bias=False)
|
|
self.bn1 = nn.BatchNorm2d(64)
|
|
self.relu = nn.ReLU(inplace=True)
|
|
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
|
self.layer1 = self._make_layer(block, 64, layers[0])
|
|
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
|
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
|
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
|
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
|
|
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
|
|
|
for m in self.modules():
|
|
if isinstance(m, nn.Conv2d):
|
|
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
|
elif isinstance(m, nn.BatchNorm2d):
|
|
nn.init.constant_(m.weight, 1)
|
|
nn.init.constant_(m.bias, 0)
|
|
|
|
def _make_layer(self, block, planes, blocks, stride=1):
|
|
downsample = None
|
|
if stride != 1 or self.inplanes != planes * block.expansion:
|
|
downsample = nn.Sequential(
|
|
conv1x1(self.inplanes, planes * block.expansion, stride),
|
|
nn.BatchNorm2d(planes * block.expansion),
|
|
)
|
|
|
|
layers = []
|
|
layers.append(block(self.inplanes, planes, stride, downsample))
|
|
self.inplanes = planes * block.expansion
|
|
for _ in range(1, blocks):
|
|
layers.append(block(self.inplanes, planes))
|
|
|
|
return nn.Sequential(*layers)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
x = self.conv1(x)
|
|
x = self.bn1(x)
|
|
x = self.relu(x)
|
|
x = self.maxpool(x)
|
|
|
|
x = self.layer1(x)
|
|
x = self.layer2(x)
|
|
x = self.layer3(x)
|
|
x = self.layer4(x)
|
|
|
|
x = self.avgpool(x)
|
|
x = x.view(x.size(0), -1)
|
|
x = self.fc(x)
|
|
|
|
return x
|
|
|
|
resnet18 = ResNet(BasicBlock, [2, 2, 2, 2])
|
|
|
|
resnet18_imported = self.getExportImportCopy(resnet18)
|
|
|
|
input = torch.randn(1, 3, 224, 224, requires_grad=True)
|
|
output_orig = resnet18(input)
|
|
output_orig.sum().backward()
|
|
grad_orig = input.grad.clone()
|
|
input.grad.zero_()
|
|
output_import = resnet18_imported(input)
|
|
output_import.sum().backward()
|
|
grad_import = input.grad.clone()
|
|
|
|
self.assertEqual(output_orig, output_import)
|
|
self.assertEqual(grad_orig, grad_import)
|
|
|
|
def test_script_module_export_tensor_type(self):
|
|
class M(torch.jit.ScriptModule):
|
|
|
|
def __init__(self, type):
|
|
super(M, self).__init__(False)
|
|
self.param = torch.nn.Parameter(torch.zeros((5, 5), dtype=type).random_())
|
|
|
|
@torch.jit.script_method
|
|
def foo(self):
|
|
return self.param
|
|
|
|
for type in [torch.float, torch.double]:
|
|
m_orig = M(type)
|
|
m_import = self.getExportImportCopy(m_orig)
|
|
# check to make sure the storage wasn't resized
|
|
self.assertTrue(m_orig.param.storage().size() == 25)
|
|
self.assertEqual(m_orig.foo(), m_import.foo())
|
|
self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "testing cuda tensors require CUDA")
|
|
def test_script_module_export_tensor_cuda(self):
|
|
class M(torch.jit.ScriptModule):
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.param = torch.nn.Parameter(torch.zeros((5, 5), device='cuda:0').random_())
|
|
|
|
@torch.jit.script_method
|
|
def foo(self):
|
|
return self.param
|
|
|
|
m_orig = M()
|
|
m_import = self.getExportImportCopy(m_orig)
|
|
# check to make sure the storage wasn't resized
|
|
self.assertTrue(m_orig.param.storage().size() == 25)
|
|
self.assertTrue(m_import.foo().device == torch.device('cuda:0'))
|
|
self.assertEqual(m_orig.foo(), m_import.foo())
|
|
self.assertTrue(m_orig.foo().dtype == m_import.foo().dtype)
|
|
|
|
def test_script_module_export_blocks(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self, n, m):
|
|
super(M, self).__init__()
|
|
self.weight = torch.nn.Parameter(torch.rand(n, m))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
if bool(input.sum() > 0):
|
|
output = self.weight.mv(input)
|
|
else:
|
|
output = self.weight + input
|
|
return output
|
|
|
|
m_orig = M(200, 200)
|
|
m_import = self.getExportImportCopy(m_orig)
|
|
|
|
t = torch.rand(200)
|
|
self.assertEqual(m_orig(t), m_import(t))
|
|
|
|
def test_script_module_export_shared_storage(self):
|
|
class M(torch.jit.ScriptModule):
|
|
|
|
def __init__(self):
|
|
super(M, self).__init__(False)
|
|
self.param1 = torch.nn.Parameter(torch.rand(5, 5))
|
|
self.param2 = torch.nn.Parameter(self.param1[3])
|
|
self.param3 = torch.nn.Parameter(torch.rand(5, 5))
|
|
self.param4 = torch.nn.Parameter(torch.rand(11, 5)[1:6])
|
|
|
|
@torch.jit.script_method
|
|
def foo(self):
|
|
return self.param1 + self.param2 + self.param3 + self.param4
|
|
|
|
m_orig = M()
|
|
m_import = self.getExportImportCopy(m_orig)
|
|
|
|
self.assertEqual(m_orig.foo(), m_import.foo())
|
|
self.assertTrue(m_import.param1.storage().data_ptr() == m_import.param2.storage().data_ptr())
|
|
self.assertTrue(m_import.param1.storage().data_ptr() != m_import.param3.storage().data_ptr())
|
|
|
|
def test_onnx_export_script_module(self):
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
y = x - x
|
|
return x + x
|
|
|
|
mte = ModuleToExport()
|
|
outputs = mte(torch.zeros(1, 2, 3))
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
|
example_outputs=outputs))
|
|
|
|
def test_trace_nested_datatypes(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
return [[x + 1, x - 1], [x + 2, x - 2]]
|
|
|
|
def bar(x):
|
|
list_stuff = foo(x)
|
|
return list_stuff[0][0], list_stuff[1][1]
|
|
|
|
traced = torch.jit.trace(bar, torch.rand(3, 4))
|
|
x = torch.rand(5, 6)
|
|
self.assertEqual(bar(x), traced(x))
|
|
|
|
@suppress_warnings
|
|
def test_onnx_export_func_with_warnings(self):
|
|
@torch.jit.script
|
|
def func_with_warning(inp):
|
|
return torch.nn.functional.sigmoid(inp) # triggers a deprecation warning
|
|
|
|
class WarningTest(torch.nn.Module):
|
|
def __init__(self):
|
|
super(WarningTest, self).__init__()
|
|
|
|
def forward(self, x):
|
|
return func_with_warning(x)
|
|
|
|
outputs = WarningTest()(torch.randn(42))
|
|
# no exception
|
|
torch.onnx.export_to_pretty_string(
|
|
WarningTest(), torch.randn(42), None, verbose=False,
|
|
example_outputs=outputs)
|
|
|
|
def test_onnx_export_script_python_fail(self):
|
|
class ModuleToInline(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToInline, self).__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.neg(x)
|
|
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
self.mod = ModuleToInline()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
y = self.mod(x)
|
|
return y + y
|
|
|
|
mte = ModuleToExport()
|
|
outputs = mte(torch.zeros(1, 2, 3))
|
|
f = io.BytesIO()
|
|
with self.assertRaisesRegex(RuntimeError, "Couldn't export Python operator"):
|
|
torch.onnx._export(mte, (torch.zeros(1, 2, 3),), f, verbose=False,
|
|
example_outputs=outputs)
|
|
|
|
def test_onnx_export_script_inline_trace(self):
|
|
class ModuleToInline(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToInline, self).__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.neg(x)
|
|
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
self.mod = torch.jit.trace(ModuleToInline(), torch.zeros(1, 2, 3))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
y = self.mod(x)
|
|
return y + y
|
|
|
|
mte = ModuleToExport()
|
|
outputs = mte(torch.zeros(1, 2, 3))
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
|
example_outputs=outputs))
|
|
|
|
def test_onnx_export_script_inline_script(self):
|
|
class ModuleToInline(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToInline, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return torch.neg(x)
|
|
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
self.mod = ModuleToInline()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
y = self.mod(x)
|
|
return y + y
|
|
|
|
mte = ModuleToExport()
|
|
outputs = mte(torch.zeros(1, 2, 3))
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
|
example_outputs=outputs))
|
|
|
|
def test_onnx_export_script_module_loop(self):
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
# test if we support end to end onnx export on loop and
|
|
# nested loops with and without loop index
|
|
for _ in range(5):
|
|
for i in range(3):
|
|
x = x + i
|
|
return x
|
|
|
|
mte = ModuleToExport()
|
|
outputs = mte(torch.zeros(1, 2, 3))
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
|
example_outputs=outputs))
|
|
|
|
def test_onnx_export_script_truediv(self):
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
z = x.size(0) / 2
|
|
return x + z
|
|
|
|
mte = ModuleToExport()
|
|
outputs = mte(torch.zeros(1, 2, 3))
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
|
example_outputs=outputs))
|
|
|
|
def test_onnx_raw_export_script_truediv(self):
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
z = x.size(0) / 2
|
|
return x + z
|
|
|
|
mte = ModuleToExport()
|
|
outputs = mte(torch.zeros(1, 2, 3))
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
|
example_outputs=outputs, export_raw_ir=True))
|
|
|
|
def test_onnx_export_script_non_alpha_add_sub(self):
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
bs = x.size(0) + 1
|
|
return bs - 1
|
|
|
|
mte = ModuleToExport()
|
|
outputs = torch.LongTensor([mte(torch.rand(3, 4))])
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
mte, (torch.rand(3, 4),), None, verbose=False,
|
|
example_outputs=outputs))
|
|
|
|
def test_onnx_export_script_module_if(self):
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
if bool(torch.sum(x) > 0):
|
|
x = torch.neg(x)
|
|
return x
|
|
|
|
mte = ModuleToExport()
|
|
outputs = mte(torch.zeros(1, 2, 3, dtype=torch.long))
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
mte, (torch.zeros(1, 2, 3),), None, verbose=False,
|
|
example_outputs=outputs))
|
|
|
|
def test_onnx_export_script_inline_params(self):
|
|
class ModuleToInline(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToInline, self).__init__()
|
|
self.m = torch.nn.Parameter(torch.ones(3, 3))
|
|
self.unused = torch.nn.Parameter(torch.ones(1, 2, 3))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return torch.mm(x, self.m)
|
|
|
|
class ModuleToExport(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ModuleToExport, self).__init__()
|
|
self.mod = ModuleToInline()
|
|
self.param = torch.nn.Parameter(torch.ones(3, 4))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
y = self.mod(x)
|
|
return torch.mm(y, self.param)
|
|
|
|
mte = ModuleToExport()
|
|
result = mte(torch.zeros(2, 3))
|
|
reference = torch.mm(torch.mm(torch.zeros(2, 3), torch.ones(3, 3)), torch.ones(3, 4))
|
|
self.assertEqual(result, reference)
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
mte, (torch.ones(2, 3),), None, verbose=False,
|
|
example_outputs=result, propagate=True))
|
|
|
|
def test_trace_with_size(self):
|
|
@_trace(torch.zeros(1, 1))
|
|
def foo(x):
|
|
return x + 1
|
|
|
|
@torch.jit.script
|
|
def bar(x):
|
|
y = int(foo(x))
|
|
if True:
|
|
y = 7
|
|
return y + 1
|
|
|
|
self.assertEqual(8, bar(torch.ones(1, 1)))
|
|
|
|
def test_tracing_slicing(self):
|
|
@_trace(torch.zeros(10))
|
|
def foo_trace(x):
|
|
return x[-5:-3]
|
|
|
|
@torch.jit.script
|
|
def foo_script(x):
|
|
return x[-5:-3]
|
|
|
|
def foo(x):
|
|
return x[-5:-3]
|
|
|
|
a = torch.arange(0, 8)
|
|
b = torch.arange(0, 20)
|
|
self.assertEqual(foo_trace(a), foo_script(a))
|
|
self.assertEqual(foo_trace(a), foo(a))
|
|
self.assertNotEqual(foo_trace(a), foo_trace(b))
|
|
|
|
def test_tracing_indexing(self):
|
|
@_trace(torch.zeros(10))
|
|
def foo_trace(x):
|
|
return x[-2]
|
|
|
|
@torch.jit.script
|
|
def foo_script(x):
|
|
return x[-2]
|
|
|
|
def foo(x):
|
|
return x[-2]
|
|
|
|
a = torch.arange(0, 8)
|
|
b = torch.arange(0, 20)
|
|
self.assertEqual(foo_script(a), foo_trace(a))
|
|
self.assertEqual(foo_trace(a), foo(a))
|
|
self.assertNotEqual(foo_trace(a), foo_trace(b))
|
|
|
|
def test_index_select_shape_prop(self):
|
|
|
|
@torch.jit.script
|
|
def foo(x, y):
|
|
return torch.index_select(x, index=y, dim=1)
|
|
|
|
a = torch.zeros(2, 2)
|
|
b = torch.zeros(4, dtype=torch.long)
|
|
torch._C._jit_pass_complete_shape_analysis(foo.graph, (a, b), False)
|
|
FileCheck().check("Double(2, 4)").run(str(foo.graph))
|
|
|
|
def test_onnx_export_speculate(self):
|
|
|
|
class Foo(torch.jit.ScriptModule):
|
|
def __init__(self, m):
|
|
super(Foo, self).__init__()
|
|
self.m = m
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
x += x
|
|
# because we are testing if we emit `if` statement correctly
|
|
# we cannot use `True` as the condition. Constant prop
|
|
# would remove the `if` statements.
|
|
c = torch.sum(x) > 4
|
|
if bool(c):
|
|
if bool(c):
|
|
y = self.m(x)
|
|
else:
|
|
y = self.m(x)
|
|
else:
|
|
y = self.m(x)
|
|
return y
|
|
|
|
linear = torch.jit.trace(nn.Linear(10, 20).float(), torch.zeros(1, 10, dtype=torch.float))
|
|
|
|
@torch.jit.script
|
|
def transpose(x):
|
|
return x.t()
|
|
|
|
f1 = Foo(transpose)
|
|
outputs_f1 = f1(torch.ones(1, 10, dtype=torch.float))
|
|
f2 = Foo(linear)
|
|
outputs_f2 = f2(torch.ones(1, 10, dtype=torch.float))
|
|
|
|
onnx_ish = torch.onnx.export_to_pretty_string(
|
|
f1,
|
|
(torch.ones(1, 10, dtype=torch.float), ),
|
|
None, verbose=False, example_outputs=outputs_f1)
|
|
self.assertExpected(onnx_ish, subname='f1')
|
|
onnx_ish = torch.onnx.export_to_pretty_string(
|
|
f2,
|
|
(torch.ones(1, 10, dtype=torch.float), ),
|
|
None, verbose=False, example_outputs=outputs_f2)
|
|
self.assertExpected(onnx_ish, subname='f2')
|
|
|
|
def test_onnx_export_shape_reshape(self):
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x):
|
|
import torch.onnx.operators
|
|
x = x.repeat(5, 1, 1)
|
|
shape = torch.onnx.operators.shape_as_tensor(x)
|
|
reshaped = torch.onnx.operators.reshape_from_tensor_shape(x, shape)
|
|
return reshaped
|
|
|
|
foo = torch.jit.trace(Foo(), torch.zeros(1, 2, 3))
|
|
outputs = foo(torch.zeros(1, 2, 3))
|
|
f = io.BytesIO()
|
|
s = torch.onnx.export_to_pretty_string(foo, (torch.zeros(1, 2, 3)), f,
|
|
example_outputs=outputs)
|
|
self.assertExpected(s)
|
|
|
|
def test_shape_analysis_loop(self):
|
|
def foo(a, b, x):
|
|
c = a
|
|
# on the first iteration of the loop it appears that
|
|
# c should have a expand to the size of b
|
|
# but on the second+ iterations, there is no broadcast and the
|
|
# sizes are different.
|
|
# previously this would cause the compiler to (1) enter an infinite
|
|
# loop trying to compute the shape, and (2) insert invalid
|
|
# broadcasts.
|
|
# this test ensure we don't regress on these issues
|
|
for _ in range(2):
|
|
a = c + b
|
|
c = x
|
|
b = x
|
|
return a
|
|
|
|
self.checkScript(foo, (torch.zeros(1), torch.zeros(4), torch.zeros(5)), optimize=False)
|
|
|
|
def test_intlist_args(self):
|
|
def func_1(x):
|
|
return torch.nn.functional.adaptive_avg_pool1d(x, 1)
|
|
|
|
def func_2(x):
|
|
return torch.nn.functional.adaptive_avg_pool1d(x, output_size=1)
|
|
|
|
def func_3(x):
|
|
return torch.nn.functional.adaptive_avg_pool1d(x, output_size=[1])
|
|
|
|
x = torch.randn(8, 8, 8)
|
|
self.checkScript(func_1, [x], optimize=True)
|
|
self.checkScript(func_2, [x], optimize=True)
|
|
self.checkScript(func_3, [x], optimize=True)
|
|
|
|
def test_wrong_implicit_expand(self):
|
|
|
|
@_trace(torch.zeros(3), torch.zeros(1))
|
|
def foo(a, b):
|
|
return a + b
|
|
|
|
a = torch.rand(4)
|
|
b = torch.rand(4)
|
|
self.assertEqual(a + b, foo(a, b))
|
|
|
|
def test_builtin_args_fails(self):
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'expected at most'):
|
|
@torch.jit.script
|
|
def f0(a):
|
|
torch.sum(a, a, a, a)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'argument self not provided'):
|
|
@torch.jit.script
|
|
def f1(a):
|
|
torch.sum(foo=4)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'specified twice'):
|
|
@torch.jit.script
|
|
def f2(a):
|
|
torch.sum(a, self=a)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'not provided'):
|
|
@torch.jit.script
|
|
def f3(a):
|
|
torch.sum(dim=4)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'for argument \'tensors\' but found Tensor'):
|
|
@torch.jit.script
|
|
def f4(a):
|
|
torch.cat(a)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, r'argument \'tensors\' but found int\[\]'):
|
|
@torch.jit.script
|
|
def f5(a):
|
|
torch.cat([3])
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'Lists must contain only a single type'):
|
|
@torch.jit.script
|
|
def f6(a):
|
|
a.expand(size=[3, [4]])
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'xpected a value of type Tensor for argument \'self\''):
|
|
@torch.jit.script
|
|
def f7(a):
|
|
torch.sum([4])
|
|
|
|
def test_builtin_args(self):
|
|
|
|
def t0(a):
|
|
# default arg dim
|
|
return torch.cat([a, a])
|
|
|
|
self.checkScript(t0, (torch.zeros(1, 1),))
|
|
|
|
def t1(a):
|
|
# keywords out of order
|
|
return torch.cat(dim=1, tensors=[a, a])
|
|
|
|
self.checkScript(t1, (torch.zeros(1, 1, 2),))
|
|
|
|
def t2(a):
|
|
# mix const/non-const attributes
|
|
if True:
|
|
b = 1
|
|
else:
|
|
b = 0
|
|
return torch.sum(a, dim=b, keepdim=False)
|
|
|
|
self.checkScript(t2, (torch.zeros(1, 1, 2),))
|
|
|
|
def test_parser_type_annotations(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(x : Tensor, y : Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
|
|
return x, x
|
|
''')
|
|
|
|
self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
|
|
|
|
def test_parser_type_annotations_comment(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(x, y):
|
|
# type: (Tensor, Tuple[Tuple[Tensor, Tensor], Tensor]) -> Tuple[Tensor, Tensor]
|
|
return x, x
|
|
''')
|
|
|
|
self.assertExpected(cu.__getattr__('foo').pretty_print_schema())
|
|
|
|
def test_parser_type_annotations_unknown_type(self):
|
|
with self.assertRaisesRegex(RuntimeError, r'Unknown type name Foo'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(x : Tensor, y : Tuple[Tuple[Foo, Tensor], Tensor]) -> Tuple[Tensor, Tensor]:
|
|
return x, x
|
|
''')
|
|
|
|
def test_parser_type_annotations_subscript_non_ident(self):
|
|
with self.assertRaisesRegex(RuntimeError, r'Subscripted type must be a type identifier'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(x : Tensor, y : Tuple[Tensor, Tensor][Tensor]) -> Tuple[Tensor, Tensor]:
|
|
return x, x
|
|
''')
|
|
|
|
def test_parser_type_annotations_subscript_tensor(self):
|
|
with self.assertRaisesRegex(RuntimeError, r'Unknown type constructor Tensor'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(x : Tensor, y : Tensor[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
|
|
return x, x
|
|
''')
|
|
|
|
def test_parser_type_annotations_incompatible_expression(self):
|
|
with self.assertRaisesRegex(RuntimeError, r'Expression of type \+ cannot be used in a type expression'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(x : Tensor, y : Tuple[3 + 4, Tensor]) -> Tuple[Tensor, Tensor]:
|
|
return x, x
|
|
''')
|
|
|
|
def test_gather_dynamic_index(self):
|
|
def t(x):
|
|
gather1 = x[0]
|
|
idx = 0 + 1
|
|
gather2 = x[idx]
|
|
return gather1 + gather2
|
|
|
|
self.checkScript(t, (torch.zeros(3, 2, 3),))
|
|
|
|
def test_slice_dynamic_index(self):
|
|
def t(x):
|
|
slice1 = x[0:1]
|
|
zero = 0
|
|
one = zero + 1
|
|
slice2 = x[zero:one]
|
|
return slice1 + slice2
|
|
|
|
self.checkScript(t, (torch.zeros(3, 2, 3),))
|
|
|
|
def test_addmm_grad(self):
|
|
""" This test checks several things:
|
|
1. An expand node was inserted before the addmm operating on the
|
|
bias term.
|
|
2. The fused form of addmm appears in the ultimate graph that's
|
|
executed.
|
|
3. A sum op was emitted for accumulating gradients along the 0th
|
|
(expanded) dimension of the bias term.
|
|
4. The correct symbolic representation for the backward pass of the
|
|
mm operator was emitted (x.t() -> mm)
|
|
|
|
TODO: we should actually check these conditions once we have a way
|
|
to dump the GraphExecutor state. Namely the processed forward graph
|
|
and the backward graph.
|
|
"""
|
|
@torch.jit.script
|
|
def addmm_grad_test(b, x, w):
|
|
return torch.addmm(b, x, w)
|
|
|
|
# Initialize param and input values
|
|
w_init = torch.rand(2, 5)
|
|
b_init = torch.rand(5)
|
|
x = torch.rand(3, 2)
|
|
|
|
# Clone trainable params
|
|
b = b_init.clone()
|
|
b.requires_grad_()
|
|
w = w_init.clone()
|
|
w.requires_grad_()
|
|
|
|
# Test symbolic differentiation
|
|
y = addmm_grad_test(b, x, w)
|
|
y.sum().backward()
|
|
|
|
# clone params for autograd reference
|
|
b_ref = b_init.clone()
|
|
b_ref.requires_grad_()
|
|
w_ref = w_init.clone()
|
|
w_ref.requires_grad_()
|
|
y_ref = torch.addmm(b_ref, x, w_ref)
|
|
y_ref.sum().backward()
|
|
|
|
self.assertEqual(w.grad, w_ref.grad)
|
|
self.assertEqual(b.grad, b_ref.grad)
|
|
|
|
def test_zeros(self):
|
|
class M(torch.jit.ScriptModule):
|
|
__constants__ = ['d']
|
|
|
|
def __init__(self):
|
|
self.d = torch.device('cpu')
|
|
|
|
@torch.jit.script_method
|
|
def create(self):
|
|
return torch.zeros([1, 1, 2], dtype=torch.float, device=self.d, layout=torch.strided)
|
|
|
|
r = M().create()
|
|
self.assertEqual(r.dtype, torch.float)
|
|
self.assertEqual(torch.zeros([1, 1, 2], dtype=torch.float), r)
|
|
|
|
def test_vararg_zeros(self):
|
|
def foo():
|
|
return torch.zeros(3, 4, 5, dtype=torch.int)
|
|
|
|
self.checkScript(foo, ())
|
|
|
|
def test_rand(self):
|
|
def test_rand():
|
|
a = torch.rand([3, 4])
|
|
return a + 1.0 - a
|
|
|
|
self.checkScript(test_rand, ())
|
|
|
|
def test_erase_number_types(self):
|
|
def func(a):
|
|
b = 7 + 1 + 3
|
|
c = a + b
|
|
c += b
|
|
return c
|
|
|
|
graph = torch.jit.script(func).graph
|
|
FileCheck().check("int = prim::Constant").check("aten::add_").run(str(graph))
|
|
self.run_pass('remove_inplace_ops', graph)
|
|
self.run_pass('erase_number_types', graph)
|
|
self.run_pass('dce', graph)
|
|
FileCheck().check_not("int = prim::Constant").check_not("aten::add_").run(str(graph))
|
|
|
|
def test_mm_batching(self):
|
|
lstm_cell = torch.jit.script(LSTMCellS)
|
|
|
|
def lstm(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
|
|
for i in range(x.size(0)):
|
|
hx, cx = lstm_cell(x[i], hx, cx, w_ih, w_hh, b_ih, b_hh)
|
|
return hx
|
|
|
|
slstm = torch.jit.script(lstm)
|
|
|
|
inputs = get_lstm_inputs('cpu', training=True, seq_length=10)
|
|
slstm(*inputs).sum().backward()
|
|
|
|
fw_graph = slstm.graph_for(*inputs)
|
|
bw_graph = backward_graph(slstm, diff_graph_idx=0)
|
|
self.assertTrue('prim::MMBatchSide' in str(fw_graph))
|
|
self.assertTrue('prim::MMTreeReduce' in str(bw_graph))
|
|
|
|
sout = slstm(*inputs)
|
|
out = lstm(*inputs)
|
|
self.assertEqual(slstm(*inputs), lstm(*inputs))
|
|
self.assertEqual(torch.autograd.grad(slstm(*inputs).sum(), inputs),
|
|
torch.autograd.grad(lstm(*inputs).sum(), inputs))
|
|
|
|
def test_loop_unrolling(self):
|
|
def fn(x):
|
|
y = 0
|
|
for i in range(int(x)):
|
|
y -= i
|
|
return y
|
|
|
|
graph = torch.jit.script(fn).graph
|
|
self.run_pass('loop_unrolling', graph)
|
|
unroll_factor = 8
|
|
FileCheck().check("prim::Loop").check_count("aten::sub", unroll_factor) \
|
|
.check("prim::Loop").check("aten::sub").run(str(graph))
|
|
self.checkScript(fn, (torch.tensor(10),))
|
|
|
|
def test_loop_unrolling_const(self):
|
|
def fn():
|
|
y = 0
|
|
for _ in range(10):
|
|
y -= 1
|
|
return y
|
|
|
|
def fn2():
|
|
y = 0
|
|
for i in range(10):
|
|
y -= i
|
|
return y
|
|
|
|
def check(fn, name):
|
|
graph = torch.jit.script(fn).graph
|
|
self.run_pass('loop_unrolling', graph)
|
|
# entirely unrolled
|
|
FileCheck().check_not("prim::Loop'").run(str(graph))
|
|
self.checkScript(fn, ())
|
|
|
|
check(fn, 'add_const')
|
|
check(fn2, 'add_iter')
|
|
|
|
def test_loop_unrolling_nested(self):
|
|
def fn(x):
|
|
y = 0
|
|
for _ in range(10):
|
|
for j in range(int(x)):
|
|
y -= j
|
|
return y
|
|
|
|
graph = torch.jit.script(fn).graph
|
|
self.run_pass('loop_unrolling', graph)
|
|
# inner loop with 8 subs followed by loop epilogue
|
|
unroll_factor = 8
|
|
FileCheck().check("prim::Loop").check("prim::Loop").check_count('aten::sub', unroll_factor) \
|
|
.check("prim::Loop").check("aten::sub").run(str(graph))
|
|
self.checkScript(fn, (torch.tensor(10),))
|
|
|
|
def test_loop_unroll_unused_counter(self):
|
|
def fn(x):
|
|
y = 0
|
|
for _ in range(int(x)):
|
|
y -= 1
|
|
return y
|
|
|
|
graph = torch.jit.script(fn).graph
|
|
self.run_pass('loop_unrolling', graph)
|
|
FileCheck().check("prim::Loop").check_not("aten::add").check("return") \
|
|
.run(str(graph))
|
|
|
|
def test_loop_unroll_negative(self):
|
|
def fn(x):
|
|
y = 0
|
|
for _ in range(int(x)):
|
|
y += 1
|
|
return y
|
|
|
|
self.checkScript(fn, (torch.tensor(-20),))
|
|
self.checkScript(fn, (torch.tensor(-2),))
|
|
self.checkScript(fn, (torch.tensor(-1),))
|
|
self.checkScript(fn, (torch.tensor(0),))
|
|
self.checkScript(fn, (torch.tensor(1),))
|
|
self.checkScript(fn, (torch.tensor(2),))
|
|
|
|
def test_where(self):
|
|
def fn(x, y):
|
|
return torch.where(x > 0.0, x, y)
|
|
|
|
self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
|
|
|
|
def test_where_method(self):
|
|
def fn(x, y):
|
|
return x.where(x > 0.0, y)
|
|
|
|
self.checkScript(fn, (torch.randn(3, 2, dtype=torch.float), torch.ones(3, 2, dtype=torch.float)))
|
|
|
|
def test_reassign_module_lhs(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'self\' because it has type value and self is'
|
|
' not a first-class value. Only reassignments to first-class values are allowed'):
|
|
class ReassignSelfLHS(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
for _ in range(20):
|
|
self = x
|
|
return self
|
|
|
|
ReassignSelfLHS()
|
|
|
|
def test_reassign_module_rhs(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'Cannot re-assign \'x\' to a value of type module because x is not a'
|
|
' first-class value. Only reassignments to first-class values are allowed'):
|
|
class ReassignSelfRHS(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
for _ in range(20):
|
|
x = self
|
|
return self
|
|
|
|
ReassignSelfRHS()
|
|
|
|
def test_unknown_builtin(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'unknown builtin op'):
|
|
@torch.jit.script
|
|
def unknown_builtin(x):
|
|
return x.splork(3)
|
|
|
|
def test_return_tuple(self):
|
|
def return_tuple(x):
|
|
a = (x, x)
|
|
return a, x
|
|
self.checkScript(return_tuple, (torch.rand(4),))
|
|
|
|
def test_method_no_self(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'methods must have a self argument'):
|
|
class MethodNoSelf(torch.jit.ScriptModule):
|
|
@torch.jit.script_method # noqa: B902
|
|
def forward():
|
|
return torch.zeros(3, 4)
|
|
|
|
MethodNoSelf()
|
|
|
|
def test_return_stmt_not_at_end(self):
|
|
def return_stmt(x):
|
|
if bool(x > 3):
|
|
return x + 3
|
|
else:
|
|
return x
|
|
self.checkScript(return_stmt, (torch.rand(1),))
|
|
|
|
def test_for_range_no_arg(self):
|
|
with self.assertRaisesRegex(RuntimeError, r'range\(\) expects 1 argument but got 0'):
|
|
@torch.jit.script
|
|
def range_no_arg(x):
|
|
for _ in range():
|
|
x += 1
|
|
return x
|
|
|
|
def test_list_iterables(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'List of iterables is not supported currently'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def list_iterables(x):
|
|
for i, j in [2, 3, 4], [5, 6, 7]:
|
|
x += i
|
|
x += j
|
|
return x
|
|
''')
|
|
|
|
def test_for_tuple_unpack(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'Iteration variable unpacking is not supported'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def for_tuple_unpack(x, y):
|
|
for i, j in [[3, 4], [5, 6], [7, 8]]:
|
|
x += i
|
|
y += j
|
|
return x, y
|
|
''')
|
|
|
|
def test_single_starred_lhs(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'A Starred expression may only appear on the lhs within the presence'
|
|
' of another non-starred expression'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def single_starred_lhs(x):
|
|
a = (x, x, x)
|
|
*b, = a
|
|
return b
|
|
''')
|
|
|
|
def test_singleton_tuple_unpack(self):
|
|
def foo(a):
|
|
b, = (a,)
|
|
return b + 1
|
|
self.checkScript(foo, (torch.rand(3),))
|
|
|
|
def test_multi_reduction(self):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
'augmented assignment can only have one LHS expression'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def multi_reduction(x):
|
|
a, b += x
|
|
return a, b
|
|
''')
|
|
|
|
def test_invalid_call_arguments(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'arguments for call are not valid'):
|
|
@torch.jit.script
|
|
def invalid_call_arguments(x):
|
|
return torch.unsqueeze(3, 4, 5, 6, 7, 8)
|
|
|
|
def test_invalid_lhs_assignment(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def invalid_lhs_assignment(x):
|
|
x + 1 = x
|
|
return x
|
|
''')
|
|
|
|
def test_multi_starred_expr_lhs(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'Only one starred expression is allowed on the lhs'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def multi_starred_expr_lhs():
|
|
a, *b, *c = [1, 2, 3, 4, 5, 6]
|
|
return a
|
|
''')
|
|
|
|
def test_pack_tuple_into_non_var(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'Cannot pack a tuple into a non-variable'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def pack_tuple_into_non_var(x):
|
|
a, *1 = (3, 4, 5)
|
|
return x
|
|
''')
|
|
|
|
def test_print_kwargs(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'print doesn\'t accept any keyword arguments'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def print_kwargs(x):
|
|
print(x, flush=True)
|
|
return x
|
|
''')
|
|
|
|
def test_builtin_use_as_value(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'builtin cannot be used as a value'):
|
|
@torch.jit.script
|
|
def builtin_use_as_value(x):
|
|
return x.unsqueeze
|
|
|
|
def test_wrong_use_as_tuple(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'cannot be used as a tuple'):
|
|
def test_fn():
|
|
return 3
|
|
|
|
@torch.jit.script
|
|
def wrong_use_as_tuple(self):
|
|
a, b = test_fn
|
|
return a
|
|
|
|
def test_wrong_attr_lookup(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'attribute lookup is not defined on builtin'):
|
|
@torch.jit.script
|
|
def wrong_attr_lookup(self, x):
|
|
a = x.unsqueeze.myattr
|
|
return a
|
|
|
|
def test_wrong_use_as_callable(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
|
|
@torch.jit.script
|
|
def wrong_use_as_callable(x):
|
|
return x(3, 4, 5)
|
|
|
|
def test_python_val_doesnt_have_attr(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'object has no attribute abcd'):
|
|
|
|
@torch.jit.script
|
|
def python_val_doesnt_have_attr():
|
|
# this has to be a module otherwise attr lookup would not be
|
|
# allowed in the first place
|
|
return shutil.abcd
|
|
|
|
def test_wrong_module_attr_lookup(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'python value of type \'type\' cannot be used as a value:'):
|
|
import io
|
|
|
|
@torch.jit.script
|
|
def wrong_module_attr_lookup():
|
|
return io.BytesIO
|
|
|
|
def test_wrong_method_call_inputs(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'argument y not provided'):
|
|
class SomeModule(torch.jit.ScriptModule):
|
|
|
|
@torch.jit.script_method
|
|
def foo(self, x, y):
|
|
return x
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x, y):
|
|
return self.foo(x)
|
|
SomeModule()
|
|
|
|
def test_single_starred_expr_for_loop(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'unexpected expression'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def test():
|
|
x = 0
|
|
for *a in [1, 2, 3]:
|
|
x = x + 1
|
|
return x
|
|
''')
|
|
|
|
def test_duplicate(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'Method \'test\' already defined'):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def test():
|
|
return 1
|
|
|
|
def test():
|
|
return 2
|
|
''')
|
|
|
|
def test_call_ge(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'expected at most 1 arguments but found 3'):
|
|
@_trace(torch.zeros(1, 2, 3))
|
|
def foo(x):
|
|
return x
|
|
|
|
@torch.jit.script
|
|
def test_fn():
|
|
return foo(torch.full([1], 1), torch.full([1], 2), torch.full([1], 3))
|
|
|
|
def test_wrong_return_type(self):
|
|
with self.assertRaisesRegex(RuntimeError, 'but instead got value of type tuple'):
|
|
def somefunc():
|
|
# type: () -> Tuple[Tuple[Tensor, Tensor]]
|
|
return torch.zeros(3, 4), torch.zeros(4, 5) # noqa: T484
|
|
|
|
@torch.jit.script
|
|
def wrong_return_type():
|
|
return somefunc()
|
|
wrong_return_type()
|
|
|
|
# Tests for calling between different front-end modes
|
|
def test_call_python_fn_from_tracing_fn(self):
|
|
def python_fn(x):
|
|
return torch.neg(x)
|
|
|
|
@_trace(torch.rand(3, 4))
|
|
def traced_fn(x):
|
|
return python_fn(x) + 1
|
|
|
|
# The neg op in the python function should be properly inlined to the
|
|
# graph
|
|
FileCheck().check("aten::neg").run(str(traced_fn.graph))
|
|
|
|
def test_call_python_mod_from_tracing_fn(self):
|
|
class PythonMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super(PythonMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
|
|
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param)
|
|
|
|
pm = PythonMod()
|
|
|
|
@_trace(torch.rand(3, 4))
|
|
def traced_fn(x):
|
|
return pm(x) + 1.0
|
|
|
|
# Note: the parameter self.param from the Python module is inlined
|
|
# into the graph
|
|
self.assertTrue(len(list(traced_fn.graph.inputs())) == 1)
|
|
FileCheck().check("aten::mm").check("aten::add").run(str(traced_fn.graph))
|
|
|
|
def test_call_traced_fn_from_tracing_fn(self):
|
|
@_trace(torch.rand(3, 4))
|
|
def traced_fn1(x):
|
|
return torch.neg(x)
|
|
|
|
@_trace(torch.rand(3, 4))
|
|
def traced_fn(x):
|
|
return traced_fn1(x) + 1
|
|
|
|
FileCheck().check("aten::neg").check_same("scope: traced_fn1").check("aten::add") \
|
|
.run(str(traced_fn.graph))
|
|
|
|
def test_call_traced_mod_from_tracing_fn(self):
|
|
class TracedModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 3), requires_grad=False)
|
|
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param)
|
|
|
|
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
|
|
|
|
@_trace(torch.rand(3, 4))
|
|
def traced_fn(x):
|
|
return tm(x) + 1.0
|
|
|
|
# Note: the parameter self.param from the Python module is inlined
|
|
# into the graph
|
|
FileCheck().check("prim::Constant[value=<Tensor>]").check("aten::mm") \
|
|
.check("aten::add").run(str(traced_fn.graph))
|
|
|
|
def test_call_script_fn_from_tracing_fn(self):
|
|
@torch.jit.script
|
|
def script_fn(x):
|
|
return torch.neg(x)
|
|
|
|
@_trace(torch.rand(3, 4))
|
|
def traced_fn(x):
|
|
return script_fn(x) + 1
|
|
|
|
FileCheck().check("aten::neg").check("aten::add").run(str(traced_fn.graph))
|
|
|
|
def test_call_script_mod_from_tracing_fn(self):
|
|
with self.disableModuleHook():
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 4), requires_grad=False)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
for _i in range(4):
|
|
x += self.param
|
|
return x
|
|
|
|
sm = ScriptMod()
|
|
|
|
@_trace(torch.rand(3, 4))
|
|
def traced_fn(x):
|
|
return sm(x) + 1.0
|
|
|
|
# parameter turns into constant and loop is perserved
|
|
FileCheck().check("prim::Constant[value=<Tensor>]").check("Loop") \
|
|
.run(str(traced_fn.graph))
|
|
|
|
def test_call_python_fn_from_traced_module(self):
|
|
def python_fn(x):
|
|
return torch.neg(x)
|
|
|
|
class TracedModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 3))
|
|
|
|
def forward(self, x):
|
|
return torch.mm(python_fn(x), self.param)
|
|
|
|
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
|
|
|
|
# Note: parameter self.param from the traced module should appear as
|
|
# an input to the graph and the neg op from the Python function should
|
|
# be properly inlined
|
|
self.assertTrue(len(list(tm.graph.inputs())) == 2)
|
|
FileCheck().check("aten::neg").check("aten::mm").run(str(tm.graph))
|
|
|
|
def test_call_python_mod_from_traced_module(self):
|
|
class PythonModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(PythonModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(5, 7))
|
|
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param)
|
|
|
|
class TracedModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 5))
|
|
self.mod = PythonModule()
|
|
|
|
def forward(self, x):
|
|
return self.mod(torch.mm(x, self.param)) + 1.0
|
|
|
|
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
|
|
|
|
# Note: the parameters from both modules should appear in the flattened
|
|
# inputs of the graph. All ops from both modules should be inlined.
|
|
self.assertTrue(len(list(tm.graph.inputs())) == 3)
|
|
FileCheck().check_not("value=<Tensor>").check_count("aten::mm", 2).check("aten::add") \
|
|
.run(str(tm.graph))
|
|
|
|
def test_call_traced_fn_from_traced_module(self):
|
|
@_trace(torch.rand(3, 4))
|
|
def traced_fn(x):
|
|
return torch.neg(x)
|
|
|
|
class TracedModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 5))
|
|
|
|
def forward(self, x):
|
|
return traced_fn(torch.mm(x, self.param))
|
|
|
|
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
|
|
# Note: neg op from the traced function should be properly inlined
|
|
FileCheck().check("aten::mm").check_same("scope: TracedModule") \
|
|
.check_next("aten::neg").check("scope: TracedModule/traced_fn") \
|
|
.run(str(tm.graph))
|
|
|
|
def test_trace_hierarchy(self):
|
|
# Test that we preserve the module hierarchy for a ScriptModule
|
|
# submodule during tracing
|
|
|
|
class AnotherScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(AnotherScriptMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(1, 2, 3))
|
|
|
|
@torch.jit.script_method
|
|
def bar(self):
|
|
return torch.zeros(4, 5)
|
|
|
|
class SomeScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(SomeScriptMod, self).__init__()
|
|
self.asm = AnotherScriptMod()
|
|
|
|
@torch.jit.script_method
|
|
def foo(self):
|
|
return torch.zeros(3, 4)
|
|
|
|
@torch.jit.script_method
|
|
def bar(self):
|
|
return torch.zeros(4, 3)
|
|
|
|
class TraceMe(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TraceMe, self).__init__()
|
|
self.ssm = SomeScriptMod()
|
|
|
|
def forward(self, x):
|
|
return self.ssm.bar() + x
|
|
|
|
orig = TraceMe()
|
|
traced = torch.jit.trace(orig, (torch.rand(4, 3),))
|
|
# for each of these checks, check that *BOTH* the underlying
|
|
# _C.ScriptModule object has the expected method/param, as well as the
|
|
# Python object that wraps it.
|
|
self.assertTrue(traced.ssm._has_method('foo'))
|
|
self.assertTrue(hasattr(traced.ssm, 'foo'))
|
|
|
|
imported = self.getExportImportCopy(traced)
|
|
|
|
self.assertTrue(imported.ssm._has_method('foo'))
|
|
self.assertTrue(hasattr(imported.ssm, 'foo'))
|
|
|
|
self.assertTrue(imported.ssm.asm._has_method('bar'))
|
|
self.assertTrue(hasattr(imported.ssm.asm, 'bar'))
|
|
|
|
self.assertTrue(imported.ssm.asm._has_parameter('param'))
|
|
self.assertTrue(hasattr(imported.ssm.asm, 'param'))
|
|
|
|
def test_trace_parameter(self):
|
|
class Param(nn.Module):
|
|
def __init__(self):
|
|
super(Param, self).__init__()
|
|
self.register_parameter("bias", nn.Parameter(torch.Tensor(4, 4)))
|
|
|
|
def forward(self, x):
|
|
return x
|
|
|
|
class M3(torch.jit.ScriptModule):
|
|
def __init__(self, model):
|
|
super(M3, self).__init__(False)
|
|
self.traced = torch.jit.trace(model, (torch.rand(3, 3)))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.traced(x)
|
|
|
|
class M2(nn.Module):
|
|
def __init__(self, model):
|
|
super(M2, self).__init__()
|
|
self.module = M3(model)
|
|
|
|
def forward(self, x):
|
|
return self.module(x)
|
|
|
|
class M1(torch.jit.ScriptModule):
|
|
def __init__(self, model):
|
|
super(M1, self).__init__(False)
|
|
self.traced = torch.jit.trace(M2(model), (torch.rand(3, 3)))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.traced(x)
|
|
|
|
module = M1(Param())
|
|
f = io.BytesIO()
|
|
torch.jit.save(module, f)
|
|
|
|
def test_call_traced_module_from_traced_module(self):
|
|
class TracedModule1(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedModule1, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(5, 7))
|
|
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param)
|
|
|
|
class TracedModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 5))
|
|
self.mod = torch.jit.trace(TracedModule1(), torch.rand(3, 5))
|
|
|
|
def forward(self, x):
|
|
return self.mod(torch.mm(x, self.param)) + 1.0
|
|
|
|
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
|
|
|
|
# Note: the parameters from both modules should appear in the flattened
|
|
# inputs of the graph. All ops from both modules should be inlined.
|
|
self.assertTrue(len(list(tm.graph.inputs())) == 3)
|
|
FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph))
|
|
|
|
def test_call_script_fn_from_traced_module(self):
|
|
@torch.jit.script
|
|
def traced_fn(x):
|
|
return torch.neg(x)
|
|
|
|
class TracedModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 5))
|
|
|
|
def forward(self, x):
|
|
return traced_fn(torch.mm(x, self.param))
|
|
|
|
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
|
|
# Note: neg op from the script function should be properly inlined
|
|
FileCheck().check("aten::mm").check("aten::neg").run(str(tm.graph))
|
|
|
|
def test_call_script_module_from_traced_module(self):
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
self.param_foo = torch.nn.Parameter(torch.rand(5, 7))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param_foo)
|
|
|
|
class TracedModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 5))
|
|
self.mod = ScriptMod()
|
|
|
|
def forward(self, x):
|
|
return self.mod(torch.mm(x, self.param)) + 1.0
|
|
|
|
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
|
|
|
|
# Note: the parameters from both modules should appear in the flattened
|
|
# inputs of the graph. All ops from both modules should be inlined.
|
|
self.assertTrue(len(list(tm.graph.inputs())) == 3)
|
|
FileCheck().check_count("aten::mm", 2).check("aten::add").run(str(tm.graph))
|
|
|
|
def test_call_python_fn_from_script_fn(self):
|
|
def python_fn(x):
|
|
return torch.neg(x)
|
|
|
|
@torch.jit.script
|
|
def script_fn(x):
|
|
return python_fn(x) + 1
|
|
|
|
# Note: the call to python_fn appears as `^python_fn()` and is called
|
|
# as a PythonOp in the interpreter
|
|
a = torch.tensor(1)
|
|
self.assertEqual(script_fn(a), torch.tensor(0))
|
|
FileCheck().check("python_fn").run(str(script_fn.graph))
|
|
|
|
def test_call_python_mod_from_script_fn(self):
|
|
class PythonModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(PythonModule, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(5, 7))
|
|
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param)
|
|
|
|
pm = PythonModule()
|
|
|
|
@torch.jit.script
|
|
def script_fn(x):
|
|
return pm(x) + 1
|
|
|
|
# Note: call to pm(x) appears as ^<python_value>() in the trace.
|
|
# Parameters are NOT inlined.
|
|
FileCheck().check("python_value").check("aten::add").run(str(script_fn.graph))
|
|
|
|
def test_call_traced_fn_from_script_fn(self):
|
|
@_trace(torch.rand(3, 4))
|
|
def traced_fn(x):
|
|
return torch.neg(x)
|
|
|
|
@torch.jit.script
|
|
def script_fn(x):
|
|
return traced_fn(x) + 1
|
|
|
|
# Note: the neg op from traced_fn should be properly inlined into the
|
|
# script function's graph
|
|
FileCheck().check("aten::neg").check("aten::add").run(str(script_fn.graph))
|
|
|
|
def test_call_traced_mod_from_script_fn(self):
|
|
class TracedModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedModule, self).__init__()
|
|
|
|
def forward(self, x):
|
|
return torch.mm(x, torch.zeros(4, 3))
|
|
|
|
tm = torch.jit.trace(TracedModule(), torch.rand(3, 4))
|
|
|
|
@torch.jit.script
|
|
def script_fn(x):
|
|
return tm(x) + 1
|
|
|
|
FileCheck().check("aten::zeros").check_same("scope: TracedModule").check("aten::mm") \
|
|
.check("aten::add").run(str(script_fn.graph))
|
|
|
|
def test_call_script_fn_from_script_fn(self):
|
|
@torch.jit.script
|
|
def script_fn1(x):
|
|
return torch.neg(x)
|
|
|
|
@torch.jit.script
|
|
def script_fn(x):
|
|
return script_fn1(x) + 1
|
|
|
|
# Note: the neg op from script_fn1 should be properly inlined into the
|
|
# graph of script_fn
|
|
FileCheck().check("aten::neg").run(str(script_fn.graph))
|
|
|
|
def test_call_script_mod_from_script_fn(self):
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return torch.mm(x, torch.zeros([4, 3]))
|
|
|
|
sm = ScriptMod()
|
|
|
|
@torch.jit.script
|
|
def script_fn(x):
|
|
return sm(x) + 1
|
|
|
|
FileCheck().check("zeros").check("aten::mm").check("add").run(str(script_fn.graph))
|
|
|
|
def test_call_python_fn_from_script_module(self):
|
|
def python_fn(x):
|
|
return torch.neg(x)
|
|
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 3))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return python_fn(torch.mm(x, self.param))
|
|
|
|
sm = ScriptMod()
|
|
FileCheck().check("aten::mm").check("python_fn") \
|
|
.run(str(sm.__getattr__('forward').graph))
|
|
|
|
def test_call_python_mod_from_script_module(self):
|
|
class PythonMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super(PythonMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 5))
|
|
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param)
|
|
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 3))
|
|
self.pm = PythonMod()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.pm(torch.mm(x, self.param))
|
|
|
|
sm = ScriptMod()
|
|
# Note: the call into PythonMod appears as ^<python_value>(). Parameters
|
|
# are NOT inlined
|
|
FileCheck().check("aten::mm").check("python_value").run(str(sm.graph))
|
|
|
|
def test_call_tracing_fn_from_script_module(self):
|
|
@_trace(torch.rand(3, 3))
|
|
def traced_fn(x):
|
|
return torch.neg(x)
|
|
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 3))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return traced_fn(torch.mm(x, self.param))
|
|
|
|
sm = ScriptMod()
|
|
FileCheck().check("aten::mm").check("aten::neg").run(str(sm.__getattr__('forward').graph))
|
|
|
|
def test_call_tracing_mod_from_script_module(self):
|
|
class TracedMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TracedMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 5))
|
|
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param)
|
|
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 3))
|
|
self.tm = torch.jit.trace(TracedMod(), torch.rand(3, 3))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.tm(torch.mm(x, self.param))
|
|
|
|
sm = ScriptMod()
|
|
# Note: the parameters from both modules should appear in the flattened
|
|
# input list to the graph. The mm op from TracedMod should be properly
|
|
# inlined
|
|
self.assertTrue(len(list(sm.graph.inputs())) == 3)
|
|
FileCheck().check("aten::mm").check("aten::mm").run(str(sm.graph))
|
|
|
|
def test_call_script_fn_from_script_module(self):
|
|
@torch.jit.script
|
|
def script_fn(x):
|
|
return torch.neg(x)
|
|
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 3))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return script_fn(torch.mm(x, self.param))
|
|
|
|
sm = ScriptMod()
|
|
graph = (sm.__getattr__('forward').graph)
|
|
FileCheck().check("aten::mm").check("aten::neg").run(str(graph))
|
|
|
|
def test_call_script_mod_from_script_module(self):
|
|
class ScriptMod1(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod1, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 5))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param)
|
|
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(4, 3))
|
|
self.tm = ScriptMod1()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.tm(torch.mm(x, self.param))
|
|
|
|
sm = ScriptMod()
|
|
# Note: the parameters from both modules should appear in the flattened
|
|
# input list to the graph. The mm op from ScriptMod1 should be properly
|
|
# inlined
|
|
# 3 % values in graph input lists, two mms in body
|
|
FileCheck().check_count('%', 3).check(":").check_count("mm", 2).run(str(sm.graph))
|
|
|
|
def test_module_with_params_called_fails(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with parameters. Stateful "
|
|
"modules to be inlined must be submodules of the callee."):
|
|
class ScriptMod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(ScriptMod, self).__init__()
|
|
self.param = torch.nn.Parameter(torch.rand(3, 3))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return torch.mm(x, self.param)
|
|
|
|
sm = ScriptMod()
|
|
|
|
@torch.jit.script
|
|
def some_func(x):
|
|
return sm(x)
|
|
|
|
def test_index_put_trace_with_view(self):
|
|
@_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(1, 1, 1, 4))
|
|
def test_index_put(target, indices, rhs):
|
|
target[indices] = rhs
|
|
return target
|
|
|
|
FileCheck().check("aten::view").check("index_put_").run(str(test_index_put.graph))
|
|
|
|
def test_index_put_trace_without_view(self):
|
|
@_trace(torch.rand(100), torch.tensor([1, 2, 3, 4]), torch.rand(4))
|
|
def test_index_put(target, indices, rhs):
|
|
target[indices] = rhs
|
|
return target
|
|
|
|
FileCheck().check_not("aten::view").check("index_put_").run(str(test_index_put.graph))
|
|
|
|
def test_tuple_indexing(self):
|
|
def tuple_index(a):
|
|
if bool(a):
|
|
b = (1, 2)
|
|
else:
|
|
b = (0, 2)
|
|
return b[-2], b[1]
|
|
|
|
self.checkScript(tuple_index, (torch.tensor([0]),))
|
|
self.checkScript(tuple_index, (torch.tensor([1]),))
|
|
self.checkScript(tuple_index, (torch.tensor([1]),), optimize=True)
|
|
tuple_comp = torch.jit.script(tuple_index)
|
|
FileCheck().check_count("TupleIndex", 2, exactly=True).run(str(tuple_comp.graph))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "tuple indices must be integer constants"):
|
|
@torch.jit.script
|
|
def test_non_constant_input(a):
|
|
if bool(a):
|
|
b = 1
|
|
else:
|
|
b = 0
|
|
c = (0, 1)
|
|
return c[b]
|
|
|
|
def test_indexing_float():
|
|
c = (1, 2)
|
|
return c[0.1]
|
|
self.checkScriptRaisesRegex(test_indexing_float, (), Exception,
|
|
"tuple indices must")
|
|
|
|
def test_indexing_out_of_bounds_pos():
|
|
c = (1, 2)
|
|
return c[2]
|
|
|
|
self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
|
|
"out of range")
|
|
|
|
def test_indexing_out_of_bounds_neg():
|
|
c = (1, 2)
|
|
return c[-3]
|
|
|
|
self.checkScriptRaisesRegex(test_indexing_out_of_bounds_pos, (), Exception,
|
|
"out of range")
|
|
|
|
def test_namedtuple_attr(self):
|
|
def f(x):
|
|
return x.max(dim=1).indices + torch.max(x, dim=1).indices
|
|
|
|
self.checkScript(f, (torch.rand(20, 20, 20),), optimize=True)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Unknown attribute to named tuple"):
|
|
@torch.jit.script
|
|
def g1(x):
|
|
return x.max(dim=1).unknown_symbol
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Getting attributes of tuples is not supported"):
|
|
@torch.jit.script
|
|
def g2(x):
|
|
print((x, x, x).__doc__)
|
|
return x
|
|
|
|
def test_tuple_slicing(self):
|
|
def tuple_slice(a):
|
|
if bool(a):
|
|
b = (1, 2, 3, 4)
|
|
else:
|
|
b = (4, 3, 2, 1)
|
|
c = b[-4:4]
|
|
e = c[1:-1]
|
|
return e
|
|
|
|
self.checkScript(tuple_slice, (torch.tensor([1]),), optimize=True)
|
|
tuple_graph = torch.jit.script(tuple_slice).graph
|
|
slices = tuple_graph.findAllNodes("prim::TupleSlice")
|
|
num_outputs = set(map(lambda x: len(x.output().type().elements()), slices))
|
|
# one tuple slice should have an output with 2 elements, other 4
|
|
self.assertTrue(num_outputs == {2, 4})
|
|
self.run_pass('lower_all_tuples', tuple_graph)
|
|
self.assertTrue('Tuple' not in str(tuple_graph))
|
|
tuple_comp = torch.jit.script(tuple_slice)
|
|
self.assertEqual(tuple_comp(torch.tensor(1)), (2, 3))
|
|
|
|
@torch.jit.script
|
|
def test_indexing_end_out_of_bounds():
|
|
c = (1, 2)
|
|
return c[2:10]
|
|
|
|
self.assertEqual(test_indexing_end_out_of_bounds(), ())
|
|
|
|
def test_unwrap_optional_builtin(self):
|
|
def test(x):
|
|
# type: (Optional[int]) -> int
|
|
x = torch.jit._unwrap_optional(x)
|
|
x = x + x # noqa: T484
|
|
return x
|
|
|
|
self.checkScript(test, (3,))
|
|
|
|
with self.assertRaisesRegex(AssertionError, "Unwrapping null optional"):
|
|
test(None)
|
|
|
|
test_script = torch.jit.script(test)
|
|
with self.assertRaisesRegex(RuntimeError, "Unwrapping null optional"):
|
|
test_script(None)
|
|
|
|
@torch.jit.script
|
|
def test_test():
|
|
return torch.jit._unwrap_optional(1)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "cannot match an Optional\\[T\\] to None"):
|
|
@torch.jit.script
|
|
def test_no_type():
|
|
# type: () -> int
|
|
return torch.jit._unwrap_optional(None)
|
|
|
|
def test_indexing_error(self):
|
|
with self.assertRaisesRegex(RuntimeError, "only supported on lists, dictionaries, tensors, and tuples"):
|
|
@torch.jit.script
|
|
def test_wrong_type():
|
|
a = 8
|
|
return a[0]
|
|
|
|
def test_annotated_script_fn(self):
|
|
@torch.jit.script
|
|
def foo(x, y, z):
|
|
# type: (Tensor, Tuple[Tensor, Tensor, Tensor], Tuple[Tensor, Tuple[Tensor, Tensor]]) -> Tensor
|
|
return x
|
|
|
|
self.assertExpected(foo.__getattr__('forward').pretty_print_schema())
|
|
|
|
def test_annotated_script_method(self):
|
|
class SM(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, x, y):
|
|
# type: (Tuple[Tensor, Tensor], Tensor) -> Tuple[Tensor, Tensor, Tensor]
|
|
return y, y, y
|
|
|
|
sm = SM()
|
|
|
|
self.assertExpected(sm.__getattr__('forward').pretty_print_schema())
|
|
|
|
def test_annotated_script_fn_return_mismatch(self):
|
|
with self.assertRaisesRegex(RuntimeError, "but is actually of type"):
|
|
@torch.jit.script
|
|
def return_tup(x):
|
|
# type: (Tensor) -> Tuple[Tuple[Tensor, Tensor], Tensor]
|
|
return x, x # noqa: T484
|
|
|
|
def test_annotated_script_fn_arg_mismatch(self):
|
|
with self.assertRaisesRegex(RuntimeError, r"arguments for call are not valid"):
|
|
@torch.jit.script
|
|
def tuple_arg(x):
|
|
# type: (Tuple[Tensor, Tensor]) -> Tensor
|
|
return x + 1 # noqa: T484
|
|
|
|
def test_script_non_tensor_args_outputs(self):
|
|
@torch.jit.script
|
|
def fn(x, y):
|
|
# type: (Tensor, float) -> float
|
|
return float((x + y).sum())
|
|
|
|
x = torch.ones(2, 2)
|
|
z = fn(x, 1)
|
|
self.assertIsInstance(z, float)
|
|
self.assertEqual(z, 8.)
|
|
|
|
@unittest.skip('https://github.com/pytorch/pytorch/issues/9595')
|
|
def test_inline_and_run_annotated_script_fn(self):
|
|
@torch.jit.script
|
|
def to_inline(x, y):
|
|
# type: (Tuple[Tensor, Tensor], Tensor) -> Tensor
|
|
return y
|
|
|
|
@torch.jit.script
|
|
def some_func(x):
|
|
return to_inline((x, x), x)
|
|
|
|
x = torch.rand(3, 4)
|
|
self.assertEqual(some_func(x), x)
|
|
|
|
def test_file_format_serialization(self):
|
|
import tempfile
|
|
filename = tempfile.mktemp()
|
|
writer = torch._C.PyTorchFileWriter(filename)
|
|
import os
|
|
import random
|
|
buffers = [os.urandom(size) for size in [random.randint(1, 100) for i in range(20)]]
|
|
offsets = []
|
|
for i, buf in enumerate(buffers):
|
|
writer.write_record(str(i), buf, len(buf))
|
|
offsets.append(i)
|
|
import pickle
|
|
serialized_offsets = pickle.dumps(offsets)
|
|
writer.write_record("meta", serialized_offsets, len(serialized_offsets))
|
|
writer.write_end_of_file()
|
|
|
|
reader = torch._C.PyTorchFileReader(filename)
|
|
serialized_offsets_read = reader.get_record("meta")
|
|
parsed_serialized_offsets = pickle.loads(serialized_offsets)
|
|
|
|
for i, offset in enumerate(parsed_serialized_offsets):
|
|
data = reader.get_record(str(offset))
|
|
assert(data == buffers[i])
|
|
|
|
# for each type, the input type annotation and corresponding return type annotation
|
|
def type_input_return_pairs(self):
|
|
return [
|
|
('Tensor', 'Tensor'),
|
|
('torch.Tensor', 'Tensor'),
|
|
('str', 'str'),
|
|
('int', 'int'),
|
|
('bool', 'bool'),
|
|
('BroadcastingList3[float]', 'List[float]'),
|
|
('BroadcastingList2[int]', 'List[int]'),
|
|
('List[int]', 'List[int]'),
|
|
('Optional[int]', 'Optional[int]'),
|
|
]
|
|
|
|
# replacing code input & return type pair
|
|
def format_code(self, code, pair):
|
|
return code.format(input=pair[0], output=pair[1])
|
|
|
|
# ***** Type annotation tests ****
|
|
# Test combinations of:
|
|
# {String frontend, Python AST Frontend}
|
|
# {Python 3-style type annotations, MyPy-style type comments}
|
|
# {Script method, Script function}
|
|
|
|
# String frontend , Python 3-style type annotations , Script function
|
|
def test_annot_string_py3_fn(self):
|
|
code = '''
|
|
def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
|
|
return x, x
|
|
'''
|
|
test_str = []
|
|
for pair in self.type_input_return_pairs():
|
|
cu = torch.jit.CompilationUnit(self.format_code(code, pair))
|
|
test_str.append(cu.__getattr__('foo').pretty_print_schema())
|
|
self.assertExpected("\n".join(test_str))
|
|
|
|
# String frontend , Python 3-style type annotations , Script method
|
|
def test_annot_string_py3_method(self):
|
|
class TestModule(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(TestModule, self).__init__()
|
|
|
|
code = '''
|
|
def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
|
|
return x, x
|
|
'''
|
|
test_str = []
|
|
for pair in self.type_input_return_pairs():
|
|
tm = TestModule()
|
|
tm.define(self.format_code(code, pair))
|
|
test_str.append(tm.__getattr__('foo').pretty_print_schema())
|
|
self.assertExpected("\n".join(test_str))
|
|
|
|
# String frontend , MyPy-style type comments , Script function
|
|
def test_annot_string_mypy_fn(self):
|
|
code = '''
|
|
def foo(x, y):
|
|
# type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
|
|
return x, x
|
|
'''
|
|
test_str = []
|
|
for pair in self.type_input_return_pairs():
|
|
cu = torch.jit.CompilationUnit(self.format_code(code, pair))
|
|
test_str.append(cu.__getattr__('foo').pretty_print_schema())
|
|
self.assertExpected("\n".join(test_str))
|
|
|
|
# String frontend , MyPy-style type comments , Script method
|
|
def test_annot_string_mypy_method(self):
|
|
class TestModule(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(TestModule, self).__init__()
|
|
|
|
code = '''
|
|
def foo(self, x, y):
|
|
# type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
|
|
return x, x
|
|
'''
|
|
|
|
test_str = []
|
|
for pair in self.type_input_return_pairs():
|
|
tm = TestModule()
|
|
tm.define(self.format_code(code, pair))
|
|
test_str.append(tm.__getattr__('foo').pretty_print_schema())
|
|
self.assertExpected("\n".join(test_str))
|
|
|
|
# Helper function to eval Python3 code without causing a syntax error for
|
|
# this file under py2
|
|
def _get_py3_code(self, code, fn_name):
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
script_path = os.path.join(tmp_dir, 'script.py')
|
|
with open(script_path, 'w') as f:
|
|
f.write(code)
|
|
import importlib.util
|
|
spec = importlib.util.spec_from_file_location(fn_name, script_path)
|
|
module = importlib.util.module_from_spec(spec)
|
|
spec.loader.exec_module(module)
|
|
fn = getattr(module, fn_name)
|
|
return fn
|
|
|
|
# Python AST Frontend , Python 3-style type annotations , Script function
|
|
@unittest.skipIf(not PY35, "Python 3.5 needed")
|
|
def test_annot_ast_py3_fn(self):
|
|
code = dedent('''
|
|
from typing import Tuple, List, Optional
|
|
from torch import Tensor
|
|
from torch.jit.annotations import BroadcastingList2, BroadcastingList3
|
|
import torch
|
|
@torch.jit.script
|
|
def foo(x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
|
|
return x, x
|
|
''')
|
|
test_str = []
|
|
for pair in self.type_input_return_pairs():
|
|
fn = self._get_py3_code(self.format_code(code, pair), 'foo')
|
|
test_str.append(fn.__getattr__('forward').pretty_print_schema())
|
|
self.assertExpected("\n".join(test_str))
|
|
|
|
# Python AST Frontend , Python 3-style type annotations , Script method
|
|
@unittest.skipIf(not PY35, "Python 3.5 needed")
|
|
def test_annot_ast_py3_method(self):
|
|
code = dedent('''
|
|
from typing import Tuple, List, Optional
|
|
from torch import Tensor
|
|
from torch.jit.annotations import BroadcastingList2, \\
|
|
BroadcastingList3
|
|
import torch
|
|
class FooModule(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def foo(self, x : {input}, y : Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]:
|
|
return x, x
|
|
instance = FooModule()
|
|
''')
|
|
|
|
test_str = []
|
|
for pair in self.type_input_return_pairs():
|
|
fn = self._get_py3_code(self.format_code(code, pair), 'instance')
|
|
test_str.append(fn.__getattr__('foo').pretty_print_schema())
|
|
self.assertExpected("\n".join(test_str))
|
|
|
|
# Python AST Frontend , MyPy-style type comments , Script function
|
|
@unittest.skipIf(not PY35, "Python 3.5 needed")
|
|
def test_annot_ast_mypy_fn(self):
|
|
code = dedent('''
|
|
import torch
|
|
@torch.jit.script
|
|
def foo(x, y):
|
|
# type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
|
|
return x, x
|
|
''')
|
|
|
|
test_str = []
|
|
for pair in self.type_input_return_pairs():
|
|
fn = self._get_py3_code(self.format_code(code, pair), 'foo')
|
|
test_str.append(fn.__getattr__('forward').pretty_print_schema())
|
|
self.assertExpected("\n".join(test_str))
|
|
|
|
# Python AST Frontend , MyPy-style type comments , Script method
|
|
@unittest.skipIf(not PY35, "Python 3.5 needed")
|
|
def test_annot_ast_mypy_method(self):
|
|
code = dedent('''
|
|
import torch
|
|
class FooModule(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def foo(self, x, y):
|
|
# type: ({input}, Tuple[Tensor, Tensor]) -> Tuple[{output}, {output}]
|
|
return x, x
|
|
instance = FooModule()
|
|
''')
|
|
|
|
test_str = []
|
|
for pair in self.type_input_return_pairs():
|
|
fn = self._get_py3_code(self.format_code(code, pair), 'instance')
|
|
test_str.append(fn.__getattr__('foo').pretty_print_schema())
|
|
self.assertExpected("\n".join(test_str))
|
|
|
|
def test_method_casts_script(self):
|
|
cast_types = [
|
|
'byte', 'char', 'double', 'float', 'int', 'long', 'short'
|
|
]
|
|
|
|
for cast_type in cast_types:
|
|
cu = torch.jit.CompilationUnit('''
|
|
def cast_to(x):
|
|
return x.{cast_type}()
|
|
'''.format(cast_type=cast_type))
|
|
|
|
x = torch.rand(3, 4, 5) * 128
|
|
cu_result = cu.cast_to(x)
|
|
reference = getattr(x, cast_type)()
|
|
self.assertEqual(cu_result, reference)
|
|
|
|
def test_listconstruct_erasure(self):
|
|
class FooMod(torch.nn.Module):
|
|
def forward(self, x):
|
|
mask = x < 0.0
|
|
return x[mask]
|
|
|
|
import io
|
|
f = io.BytesIO()
|
|
self.assertExpected(torch.onnx.export_to_pretty_string(
|
|
FooMod(), (torch.rand(3, 4),), f,
|
|
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK))
|
|
|
|
def test_trace_checker_arange_as_constant(self):
|
|
with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
|
|
@_trace(torch.rand(3, 4), check_inputs=[(torch.rand(4, 5),)])
|
|
def foo(x):
|
|
y = torch.arange(0, x.shape[0]).double()
|
|
return x + y.unsqueeze(1)
|
|
|
|
@suppress_warnings
|
|
def test_trace_checker_dot_data(self):
|
|
with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Tensor-valued Constant nodes differed in value '
|
|
r'across invocations'):
|
|
@_trace(torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
|
|
def foo(x):
|
|
y = x.data
|
|
return x + y
|
|
|
|
@suppress_warnings
|
|
def test_trace_checker_control_flow(self):
|
|
def foo(x):
|
|
for _ in range(x.size(0)):
|
|
x = torch.neg(x)
|
|
return x
|
|
|
|
with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
|
|
torch.jit.trace(foo, torch.randn(3, 4), check_inputs=[torch.randn(4, 4)])
|
|
|
|
@suppress_warnings
|
|
def test_trace_checker_memoization(self):
|
|
with self.assertRaisesRegex(torch.jit.TracingCheckError, r'Graphs differed across invocations!'):
|
|
def foo(x):
|
|
if not hasattr(foo, 'cache'):
|
|
foo.cache = torch.neg(x)
|
|
return x + foo.cache
|
|
|
|
traced = torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[(torch.rand(3, 4),)])
|
|
|
|
# These tests don't work because UBSAN has a false positive about accessing
|
|
# out of bounds on a dynamically sized struct internal to asmjit
|
|
if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
|
|
def test_int8_quantization_module(self):
|
|
K1, N1 = 2, 2
|
|
|
|
class FooBar(torch.nn.Module):
|
|
def __init__(self):
|
|
super(FooBar, self).__init__()
|
|
self.linear1 = torch.nn.Linear(K1, N1).float()
|
|
|
|
def forward(self, x):
|
|
x = self.linear1(x)
|
|
return x
|
|
|
|
fb = FooBar()
|
|
fb.linear1.weight = torch.nn.Parameter(
|
|
torch.tensor([[-150, 100], [100, -150]], dtype=torch.float), requires_grad=False)
|
|
fb.linear1.bias = torch.nn.Parameter(torch.zeros_like(fb.linear1.bias), requires_grad=False)
|
|
fb_ref = FooBar()
|
|
fb_ref.linear1.weight = torch.nn.Parameter(fb.linear1.weight.clone(), requires_grad=False)
|
|
fb_ref.linear1.bias = torch.nn.Parameter(fb.linear1.bias.clone(), requires_grad=False)
|
|
fb = torch.jit.quantized.quantize_linear_modules(fb)
|
|
|
|
x = (torch.rand(1, K1).float() - 0.5) / 10.0
|
|
traced = torch.jit.trace(fb, (x,))
|
|
fb = self.getExportImportCopyWithPacking(traced)
|
|
|
|
x = torch.tensor([[100, -150]], dtype=torch.float)
|
|
y = fb(x)
|
|
y_ref = fb_ref(x)
|
|
torch.testing.assert_allclose(y, y_ref, rtol=0.0001, atol=1e-3)
|
|
|
|
def checkTracerWarning(self, *args, **kwargs):
|
|
with warnings.catch_warnings(record=True) as warns:
|
|
torch.jit.trace(*args, **kwargs)
|
|
self.assertGreater(len(warns), 0)
|
|
for warn in warns:
|
|
self.assertIn("cause the trace to be incorrect", str(warn.message))
|
|
|
|
def test_trace_checker_slice_lhs(self):
|
|
def foo(x):
|
|
for i in range(3):
|
|
x[i, :] = torch.zeros(4)
|
|
return x
|
|
|
|
self.checkTrace(foo, (torch.rand(3, 4),))
|
|
|
|
def test_trace_checker_inplace_on_view(self):
|
|
def foo(x):
|
|
x.view(-1).add_(-x.view(-1))
|
|
return x
|
|
|
|
self.assertWarnsRegex(lambda: torch.jit.trace(foo,
|
|
torch.rand(3, 4),
|
|
check_inputs=[torch.rand(5, 6)],
|
|
_force_outplace=True),
|
|
'Output nr 1. of the traced function does not match the '
|
|
'corresponding output of the Python function')
|
|
|
|
def test_lhs_index_fails(self):
|
|
def foo(x):
|
|
x[0, 1] = 4
|
|
return x
|
|
self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
|
|
|
|
def test_lhs_index_trivial(self):
|
|
def foo(y, x):
|
|
y[...] = x
|
|
return y
|
|
self.checkTrace(foo, (torch.rand(3, 4), torch.rand(4)), inputs_require_grads=False)
|
|
|
|
def test_inplace_warn(self):
|
|
def foo(x):
|
|
x.view(-1).add_(-x.view(-1))
|
|
return x
|
|
self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
|
|
|
|
@suppress_warnings
|
|
def test_trace_checker_dropout_train(self):
|
|
def foo(x):
|
|
return torch.dropout(x, p=0.5, train=True)
|
|
|
|
self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
|
|
'Output nr 1. of the traced function does not match the '
|
|
'corresponding output of the Python function')
|
|
self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
|
|
'Trace had nondeterministic nodes')
|
|
|
|
def test_trace_checker_dropout_notrain(self):
|
|
input = torch.rand(3, 4)
|
|
|
|
@_trace(input)
|
|
def foo(x):
|
|
return torch.dropout(x, p=0.5, train=False)
|
|
|
|
self.assertEqual(foo(input), input)
|
|
|
|
def test_export_dynamic_slice(self):
|
|
class DynamicSliceExportMod(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
retval = x[0]
|
|
for i in range(x.size(1)):
|
|
retval += torch.sum(x[0:i], dim=0)
|
|
return retval
|
|
|
|
mod = DynamicSliceExportMod()
|
|
|
|
input = torch.rand(3, 4, 5)
|
|
example_outs = mod(input)
|
|
|
|
f = io.BytesIO()
|
|
exported = torch.onnx.export_to_pretty_string(
|
|
DynamicSliceExportMod(), (input,), f, example_outputs=example_outs)
|
|
self.assertExpected(exported)
|
|
|
|
def test_string_frontend_elif(self):
|
|
code = '''
|
|
def elif_test(niter : int):
|
|
rv = 0
|
|
for i in range(niter):
|
|
if i % 3 == 0 and i % 5 == 0:
|
|
rv += 35
|
|
elif i % 3 == 0:
|
|
rv += 3
|
|
elif i % 5 == 0:
|
|
rv += 5
|
|
else:
|
|
rv += i
|
|
return rv
|
|
'''
|
|
|
|
self.checkScript(code, (101,), name='elif_test', outputs=3028)
|
|
|
|
def test_pyop_exception_message(self):
|
|
class Foo(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Foo, self).__init__()
|
|
self.conv = nn.Conv2d(1, 10, kernel_size=5)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.conv(x)
|
|
foo = Foo()
|
|
# testing that the correct error message propagates
|
|
with self.assertRaisesRegex(RuntimeError, "Expected 4-dimensional input for 4-dimensional weight"):
|
|
foo(torch.ones([123])) # wrong size
|
|
|
|
def test_builtin_error_messsage(self):
|
|
from torch.nn.modules.utils import _single, _pair, _triple, _quadruple
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "arguments for call are not valid"):
|
|
@torch.jit.script
|
|
def close_match(x):
|
|
return x.masked_fill(True)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "This op may not exist or may not be currently "
|
|
"supported in TorchScript"):
|
|
@torch.jit.script
|
|
def unknown_op(x):
|
|
torch.set_grad_enabled(True)
|
|
return x
|
|
|
|
def test_exceptions(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(cond):
|
|
if bool(cond):
|
|
raise ValueError(3)
|
|
return 1
|
|
''')
|
|
|
|
cu.foo(torch.tensor(0))
|
|
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
|
|
cu.foo(torch.tensor(1))
|
|
|
|
@torch.jit.script
|
|
def foo(cond):
|
|
a = 3
|
|
if bool(cond):
|
|
raise ArbitraryError(a, "hi")
|
|
if False:
|
|
raise ArbitraryError
|
|
return a
|
|
|
|
foo(torch.tensor(0))
|
|
# we don't currently validate the name of the exception
|
|
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
|
|
foo(torch.tensor(1))
|
|
|
|
@torch.jit.script
|
|
def foo_except_used():
|
|
a = Exception()
|
|
print(a)
|
|
raise a
|
|
|
|
# a not DCEd
|
|
with self.assertRaisesRegex(RuntimeError, "expected value of type Tensor"):
|
|
foo_except_used()
|
|
|
|
# We don't validate the expr following raise
|
|
@torch.jit.script
|
|
def foo():
|
|
raise 3 + 4
|
|
|
|
# no control flow analysis yet
|
|
with self.assertRaisesRegex(RuntimeError, "undefined value a"):
|
|
@torch.jit.script
|
|
def foo():
|
|
if True:
|
|
a = 1
|
|
else:
|
|
raise Exception("Hi")
|
|
return a
|
|
|
|
def test_assertions(self):
|
|
cu = torch.jit.CompilationUnit('''
|
|
def foo(cond):
|
|
assert bool(cond), "hi"
|
|
return 0
|
|
''')
|
|
|
|
cu.foo(torch.tensor(1))
|
|
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
|
|
cu.foo(torch.tensor(0))
|
|
|
|
@torch.jit.script
|
|
def foo(cond):
|
|
assert bool(cond), "hi"
|
|
|
|
foo(torch.tensor(1))
|
|
# we don't currently validate the name of the exception
|
|
with self.assertRaisesRegex(torch.jit.Error, "Exception"):
|
|
foo(torch.tensor(0))
|
|
|
|
def test_weak_script_function(self):
|
|
outer_var = 10
|
|
outer_var2 = 11
|
|
|
|
def not_a_script_fn(x):
|
|
return x + 2
|
|
|
|
@torch.jit.script
|
|
def even_more_inner(x):
|
|
return x + 1
|
|
|
|
@torch.jit.script
|
|
def inner(x):
|
|
return not_a_script_fn(x) + x + even_more_inner(x)
|
|
|
|
@torch.jit.script
|
|
def strong_script_fn(x):
|
|
if bool(x.norm() > 2):
|
|
x = x + 3
|
|
return x + 4 + inner(x)
|
|
|
|
@torch._jit_internal.weak_script
|
|
def weak_script_fn_inner(x):
|
|
return x + 6 + not_a_script_fn(x)
|
|
|
|
@torch._jit_internal.weak_script
|
|
def weak_script_fn(x):
|
|
return x + 5 + weak_script_fn_inner(x) + weak_script_fn_inner(x)
|
|
|
|
def fn(x):
|
|
x = not_a_script_fn(x)
|
|
x = strong_script_fn(x)
|
|
return weak_script_fn(x)
|
|
|
|
input = torch.randn(3, 4, 5)
|
|
self.checkScript(fn, (input,))
|
|
|
|
def test_python_op_exception(self):
|
|
def python_op(x):
|
|
raise Exception("bad!")
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
return python_op(x)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "operation failed in interpreter"):
|
|
fn(torch.tensor(4))
|
|
|
|
def test_trace_contiguous(self):
|
|
def foo(x):
|
|
return x[:, :, ::2].contiguous().view(12)
|
|
|
|
x = torch.rand(2, 3, 4)
|
|
traced = torch.jit.trace(foo, (x,))
|
|
y = traced(x)
|
|
self.assertNotEqual(x.storage().data_ptr(), y.storage().data_ptr())
|
|
|
|
# This tests the logic in THPVariable_contiguous. There is short-circuiting
|
|
# code that prevents us from even getting to VariableType::contiguous, since
|
|
# it is an optimization that prevents us from acquiring the GIL for touching
|
|
# the device. We needed to add the tracing logic directly into the
|
|
# THPVariable_contiguous function only for the path where we are skipping
|
|
# dispatch into contiguous. We should see an aten::contiguous in this trace!
|
|
def test_trace_contiguous_short_circuit(self):
|
|
def foo(x):
|
|
return x.contiguous()
|
|
|
|
x = torch.rand(2, 3, 4)
|
|
traced = torch.jit.trace(foo, (x,))
|
|
FileCheck().check("aten::contiguous").run(str(traced.graph))
|
|
|
|
def test_weak_module(self):
|
|
|
|
@torch._jit_internal.weak_module
|
|
class Weak(torch.nn.Module):
|
|
__constants__ = ['number']
|
|
|
|
def __init__(self):
|
|
super(Weak, self).__init__()
|
|
self.number = 199
|
|
|
|
def python_op_in_weak_module(self, x):
|
|
return x + 123
|
|
|
|
@torch._jit_internal.weak_script_method
|
|
def forward(self, x):
|
|
return 55 + self.number + self.python_op_in_weak_module(x)
|
|
|
|
class OtherStrong(torch.jit.ScriptModule):
|
|
__constants__ = ['number']
|
|
|
|
def __init__(self):
|
|
super(OtherStrong, self).__init__()
|
|
self.number = 357
|
|
|
|
def python_op_in_strong_module(self, x):
|
|
return x + 456
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x + self.number + self.python_op_in_strong_module(x)
|
|
|
|
class Passthrough(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Passthrough, self).__init__()
|
|
self.weak = Weak()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.weak(x)
|
|
|
|
weak_mod = Weak()
|
|
x = torch.ones(1)
|
|
expected_result = 55 + 199 + (x + 123)
|
|
|
|
# Ensure weak mod is running without the JIT by passing the wrong type
|
|
# (i.e. not a tensor)
|
|
weak_mod(2)
|
|
|
|
python_result = weak_mod(x)
|
|
strong_mod = Passthrough()
|
|
script_result = strong_mod(x)
|
|
|
|
self.assertEqual(python_result, expected_result)
|
|
self.assertEqual(script_result, expected_result)
|
|
|
|
class Strong(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Strong, self).__init__()
|
|
self.weak = Weak()
|
|
self.strong = OtherStrong()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
y = 2 * x
|
|
return y + 1 + self.weak(y) + self.strong(y)
|
|
|
|
strong_mod = Strong()
|
|
strong_mod2 = Strong()
|
|
x = torch.ones(1)
|
|
expected_result = (x * 2) + 1 + (55 + 199 + x * 2 + 123) + (x * 2 + 357 + x * 2 + 456)
|
|
script_result = strong_mod(x)
|
|
script_result2 = strong_mod2(x)
|
|
self.assertEqual(script_result, expected_result)
|
|
self.assertEqual(script_result, script_result2)
|
|
|
|
def test_weak_module_parameters_and_buffers(self):
|
|
weights = torch.randn(10, 10)
|
|
bias = torch.randn(10)
|
|
weights2 = torch.randn(10, 10)
|
|
bias2 = torch.randn(10)
|
|
|
|
@torch._jit_internal.weak_module
|
|
class TestLinear(torch.nn.Module):
|
|
def __init__(self, in_features, out_features):
|
|
super(TestLinear, self).__init__()
|
|
self.in_features = in_features
|
|
self.out_features = out_features
|
|
self.weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
|
|
self.bias = torch.nn.Parameter(torch.Tensor(out_features))
|
|
self.register_buffer('counter', torch.ones(out_features))
|
|
self.reset_parameters()
|
|
|
|
def reset_parameters(self):
|
|
torch.nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
|
|
if self.bias is not None:
|
|
fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight)
|
|
bound = 1 / math.sqrt(fan_in)
|
|
torch.nn.init.uniform_(self.bias, -bound, bound)
|
|
|
|
@torch._jit_internal.weak_script_method
|
|
def forward(self, input):
|
|
return F.linear(input, self.weight, self.bias) + self.counter
|
|
|
|
# Initialize a ScriptModule that uses the weak module above multiple times
|
|
class Strong(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Strong, self).__init__()
|
|
self.fc1 = TestLinear(10, 10)
|
|
self.fc1.weight = torch.nn.Parameter(weights)
|
|
self.fc1.bias = torch.nn.Parameter(bias)
|
|
self.fc2 = TestLinear(10, 10)
|
|
self.fc2.weight = torch.nn.Parameter(weights2)
|
|
self.fc2.bias = torch.nn.Parameter(bias2)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x + self.fc1(x) + self.fc1(x) + self.fc2(x)
|
|
|
|
strong_mod = Strong()
|
|
|
|
# Run same calculation as module
|
|
inp = torch.ones(10)
|
|
lin = torch.nn.Linear(10, 10)
|
|
lin.weight = torch.nn.Parameter(weights)
|
|
lin.bias = torch.nn.Parameter(bias)
|
|
lin2 = torch.nn.Linear(10, 10)
|
|
lin2.weight = torch.nn.Parameter(weights2)
|
|
lin2.bias = torch.nn.Parameter(bias2)
|
|
expected_result = inp + (lin(inp) + torch.ones(10)) * 2 + lin2(inp) + torch.ones(10)
|
|
|
|
self.assertEqual(strong_mod(inp), expected_result)
|
|
self.assertExportImportModule(strong_mod, (inp,))
|
|
|
|
def test_weak_module_nested(self):
|
|
@torch._jit_internal.weak_module
|
|
class OtherWeak(torch.nn.Module):
|
|
__constants__ = ['constant']
|
|
|
|
def __init__(self, in_features, out_features):
|
|
super(OtherWeak, self).__init__()
|
|
self.in_features = in_features
|
|
self.out_features = out_features
|
|
self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
|
|
self.bias = torch.nn.Parameter(torch.ones(out_features))
|
|
self.constant = 3
|
|
|
|
@torch._jit_internal.weak_script_method
|
|
def forward(self, x):
|
|
return x * x + self.constant + F.linear(x, self.weight, self.bias)
|
|
|
|
class OtherStrong(torch.jit.ScriptModule):
|
|
|
|
def __init__(self):
|
|
super(OtherStrong, self).__init__()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x + 27
|
|
|
|
@torch._jit_internal.weak_module
|
|
class Weak(torch.nn.Module):
|
|
def __init__(self, in_features, out_features):
|
|
super(Weak, self).__init__()
|
|
self.in_features = in_features
|
|
self.out_features = out_features
|
|
self.weight = torch.nn.Parameter(2 * torch.ones(out_features, in_features))
|
|
self.bias = torch.nn.Parameter(2 * torch.ones(out_features))
|
|
self.weak_submodule = OtherWeak(10, 10)
|
|
self.strong_submodule = OtherStrong()
|
|
|
|
@torch._jit_internal.weak_script_method
|
|
def forward(self, x):
|
|
return x + self.weak_submodule(x) + self.strong_submodule(x) \
|
|
+ F.linear(x, self.weight, self.bias)
|
|
|
|
class Strong(torch.jit.ScriptModule):
|
|
__constants__ = ['constant']
|
|
|
|
def __init__(self):
|
|
super(Strong, self).__init__()
|
|
self.weak = Weak(10, 10)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x + self.weak(x)
|
|
|
|
strong_mod = Strong()
|
|
inp = torch.randn(10)
|
|
result = strong_mod(inp)
|
|
expected_result = inp + (inp + inp * inp + inp + 27) + 3 \
|
|
+ F.linear(inp, torch.ones(10, 10), torch.ones(10)) \
|
|
+ F.linear(inp, 2 * torch.ones(10, 10), 2 * torch.ones(10))
|
|
self.assertEqual(result, expected_result)
|
|
|
|
def test_weak_module_submodule(self):
|
|
@torch._jit_internal.weak_module
|
|
class Weak(torch.nn.Module):
|
|
def __init__(self):
|
|
super(Weak, self).__init__()
|
|
self.param = torch.nn.Parameter(100 * torch.ones(5))
|
|
|
|
@torch._jit_internal.weak_script_method
|
|
def forward(self, x):
|
|
return x + self.param
|
|
|
|
weak = Weak()
|
|
|
|
class OtherStrong(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(OtherStrong, self).__init__()
|
|
self.weak = weak
|
|
self.weak2 = weak
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return x + self.weak(x)
|
|
|
|
class Strong(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Strong, self).__init__()
|
|
self.weak = Weak()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.weak(x) + weak(x)
|
|
|
|
other_strong_mod = OtherStrong()
|
|
|
|
self.assertIs(other_strong_mod.weak, other_strong_mod.weak2)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "Attempted to inline a Module with param"):
|
|
strong_mod = Strong()
|
|
|
|
def test_weak_module_copying(self):
|
|
class Submodule(torch.nn.Module):
|
|
def __init__(self):
|
|
super(Submodule, self).__init__()
|
|
|
|
def forward(self, x):
|
|
return x + 100
|
|
|
|
@torch._jit_internal.weak_module
|
|
class Weak(torch.nn.Module):
|
|
def __init__(self, in_features, out_features):
|
|
super(Weak, self).__init__()
|
|
self.weight = torch.nn.Parameter(torch.ones(out_features, in_features))
|
|
self.bias = torch.nn.Parameter(torch.ones(out_features))
|
|
self.register_buffer("buffer", torch.ones(out_features))
|
|
self.submodule = Submodule()
|
|
|
|
@torch._jit_internal.weak_script_method
|
|
def forward(self, x):
|
|
return F.linear(x, self.weight, self.bias) \
|
|
+ self.buffer + self.submodule(x)
|
|
|
|
class Strong(torch.jit.ScriptModule):
|
|
def __init__(self, weak):
|
|
super(Strong, self).__init__()
|
|
self.weak = weak
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.weak(x)
|
|
|
|
inp = torch.ones(5, 5) * 5
|
|
weak_mod = Weak(5, 5)
|
|
strong_mod = Strong(weak_mod)
|
|
|
|
self.assertTrue(isinstance(strong_mod.weak, torch.jit.ScriptModule))
|
|
self.assertFalse(isinstance(weak_mod, torch.jit.ScriptModule))
|
|
|
|
self.assertIs(strong_mod.weak.weight, weak_mod.weight)
|
|
self.assertIs(strong_mod.weak.buffer, weak_mod.buffer)
|
|
self.assertIs(strong_mod.weak.submodule, weak_mod.submodule)
|
|
|
|
# Test lookup fallback
|
|
weak_mod.new_attribute = 10
|
|
self.assertIs(strong_mod.weak.new_attribute, weak_mod.new_attribute)
|
|
|
|
weak_mod.weight.data += torch.ones(5, 5) * 100
|
|
self.assertTrue(strong_mod(inp).allclose(weak_mod(inp)))
|
|
|
|
# Re-assignment is not tracked
|
|
weak_mod.weight = torch.nn.Parameter(torch.ones(5, 5) * 100)
|
|
self.assertFalse(strong_mod(inp).allclose(weak_mod(inp)))
|
|
|
|
def test_backend_cudnn_enabled(self):
|
|
# Only test that this compiles
|
|
@torch.jit.script
|
|
def fn(x):
|
|
if torch.backends.cudnn.enabled:
|
|
x = x + 2
|
|
else:
|
|
x = x + 3
|
|
return x
|
|
|
|
def test_inplace_add(self):
|
|
|
|
def foo(a, b):
|
|
c = a + b
|
|
c.add_(b)
|
|
return c
|
|
self.checkScript(foo, (torch.rand(3), torch.rand(3)))
|
|
|
|
def test_add_out(self):
|
|
def foo(a, b):
|
|
c = a + b
|
|
e = 2 * a
|
|
torch.add(c, b, out=e)
|
|
return e
|
|
self.checkScript(foo, (torch.rand(3), torch.rand(3)))
|
|
|
|
def test_augmented_assign(self):
|
|
def foo(a, b):
|
|
a += b
|
|
a -= b
|
|
a /= b
|
|
a *= b
|
|
return a, b
|
|
self.checkScript(foo, (torch.rand(3), torch.rand(3)))
|
|
|
|
def test_pass(self):
|
|
def foo(x):
|
|
# type: (bool) -> int
|
|
for _i in range(3):
|
|
pass
|
|
if x:
|
|
pass
|
|
else:
|
|
pass
|
|
return 3
|
|
|
|
self.checkScript(foo, (True,))
|
|
|
|
def test_optional_conversion(self):
|
|
@torch.jit.script
|
|
def other_fn(x=None):
|
|
# type: (Optional[int]) -> int
|
|
return torch.jit._unwrap_optional(x)
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
# type: (int) -> int
|
|
return other_fn(x)
|
|
|
|
self.assertEqual(fn(2), 2)
|
|
|
|
@torch.jit.script
|
|
def unify_to_optional(x):
|
|
# type: (bool) -> Optional[int]
|
|
if x:
|
|
a = None
|
|
else:
|
|
a = 2
|
|
return a
|
|
|
|
self.assertEqual(unify_to_optional(True), None)
|
|
self.assertEqual(unify_to_optional(False), 2)
|
|
|
|
@torch.jit.script
|
|
def opt_list(x):
|
|
# type: (Optional[List[float]]) -> int
|
|
return 2
|
|
|
|
@torch.jit.script
|
|
def broadcast_opt_list(x):
|
|
# type: (Optional[BroadcastingList2[float]]) -> int
|
|
return 2
|
|
|
|
@torch.jit.script
|
|
def opt_list_tuple_caller(x):
|
|
# type: (Tuple[float, float]) -> int
|
|
return opt_list(x) + broadcast_opt_list(x)
|
|
|
|
self.assertEqual(opt_list_tuple_caller((2., 3.)), 4)
|
|
|
|
def test_lhs_indexing(self):
|
|
def foo(a, b):
|
|
a = a.clone()
|
|
a[0] = b
|
|
return a
|
|
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
|
|
|
|
def test_lhs_advanced_indexing_assignment(self):
|
|
def foo(x, y):
|
|
a = torch.exp(x)
|
|
b = x == 1
|
|
a[b] = y[b]
|
|
return a
|
|
self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
|
|
|
|
def test_lhs_advanced_indexing_augmented_assignment(self):
|
|
def foo(x, y):
|
|
a = torch.exp(x)
|
|
b = x == 1
|
|
a[b] += y[b]
|
|
return a
|
|
self.checkScript(foo, (torch.ones(4, 3), torch.ones(4, 3)))
|
|
|
|
def test_lhs_indexing_list(self):
|
|
def foo(a, b):
|
|
ls = [a]
|
|
ls[0] = b
|
|
return ls
|
|
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
|
|
|
|
def test_inplace_copy_script(self):
|
|
def foo(x):
|
|
a = torch.rand(3, 4)
|
|
a.copy_(x)
|
|
return a
|
|
self.checkScript(foo, (torch.rand(3, 4),))
|
|
|
|
def test_lhs_indexing_increment(self):
|
|
def foo(a, b):
|
|
a[0] += b
|
|
return a
|
|
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
|
|
|
|
def test_lhs_indexing_increment_list(self):
|
|
def foo(a, b):
|
|
a = a.clone()
|
|
ls = [a, b]
|
|
ls[0] += b
|
|
return ls
|
|
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
|
|
|
|
def test_lhs_indexing_increment_list_prim(self):
|
|
def foo():
|
|
ls = [1, 2, 3]
|
|
ls[0] += 5
|
|
return ls
|
|
self.checkScript(foo, ())
|
|
|
|
def test_lhs_indexing_multi(self):
|
|
def foo(a, b):
|
|
a = a.clone()
|
|
foo, a[0], bar = (1, b, 3)
|
|
return foo, a, bar
|
|
self.checkScript(foo, (torch.rand(2, 3), torch.rand(3)))
|
|
|
|
def test_bool_dispatch(self):
|
|
with self.disableModuleHook(): # TODO: Python print broadcasting list
|
|
def kwarg_false(x):
|
|
# type: (Tensor) -> Tensor
|
|
return F.max_pool1d(x, 1, 1, return_indices=False)
|
|
self.checkScript(kwarg_false, (torch.randn(3, 3, 3),))
|
|
|
|
def kwarg_true(x):
|
|
# type: (Tensor) -> Tuple[Tensor, Tensor]
|
|
return F.max_pool1d(x, 1, 1, return_indices=True)
|
|
self.checkScript(kwarg_true, (torch.randn(3, 3, 3),))
|
|
|
|
def full_kwarg_false(x):
|
|
# type: (Tensor) -> Tensor
|
|
return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=False)
|
|
self.checkScript(full_kwarg_false, (torch.randn(3, 3, 3),))
|
|
|
|
def full_kwarg_true(x):
|
|
# type: (Tensor) -> Tuple[Tensor, Tensor]
|
|
return F.max_pool1d(x, 1, 1, ceil_mode=False, return_indices=True)
|
|
self.checkScript(full_kwarg_true, (torch.randn(3, 3, 3),))
|
|
|
|
def use_default(x):
|
|
# type: (Tensor) -> Tensor
|
|
return F.max_pool1d(x, 1, 1)
|
|
self.checkScript(use_default, (torch.randn(3, 3, 3),))
|
|
|
|
def arg_false(x):
|
|
# type: (Tensor) -> Tensor
|
|
return F.max_pool1d(x, 1, 1, 0, 1, False, False)
|
|
self.checkScript(arg_false, (torch.randn(3, 3, 3),))
|
|
|
|
def arg_true(x):
|
|
# type: (Tensor) -> Tuple[Tensor, Tensor]
|
|
return F.max_pool1d(x, 1, 1, 0, 1, False, True)
|
|
self.checkScript(arg_true, (torch.randn(3, 3, 3),))
|
|
|
|
def test_infer_size(self):
|
|
from torch._C import _infer_size
|
|
|
|
def fn(x, y):
|
|
# type: (Tensor, Tensor) -> List[int]
|
|
return _infer_size(x.size(), y.size())
|
|
|
|
self.checkScript(fn, (torch.ones(2, 4, 2), torch.ones(2, 4, 2)))
|
|
|
|
def test_mutable_dce(self):
|
|
@torch.jit.script
|
|
def foo():
|
|
a = torch.rand(2, 3)
|
|
a += torch.rand(2, 3)
|
|
b = torch.rand(2, 3)
|
|
b += torch.rand(2, 3)
|
|
# b should be cleaned up but not a
|
|
return a
|
|
|
|
FileCheck().check_count("aten::rand", 2, exactly=True) \
|
|
.check_count("aten::add", 1, exactly=True).run(str(foo.graph))
|
|
|
|
def test_mutable_dce_block(self):
|
|
@torch.jit.script
|
|
def foo():
|
|
a = torch.rand(2, 3)
|
|
a += torch.rand(2, 3)
|
|
b = torch.rand(2, 3)
|
|
if bool(a > torch.zeros(2, 3)):
|
|
b += torch.rand(2, 3)
|
|
a += torch.rand(2, 3)
|
|
# a should be cleaned up but not b
|
|
return b
|
|
|
|
FileCheck().check("prim::If").check_count("aten::rand", 1, exactly=True) \
|
|
.run(str(foo.graph))
|
|
|
|
def test_mutable_dce_graph_input(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
a += torch.rand(2, 3)
|
|
# shouldn't clean up `a` even though it's not used in the output
|
|
|
|
FileCheck().check("aten::rand").check("aten::add").run(str(foo.graph))
|
|
|
|
def test_mutable_dce_list(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
l = []
|
|
l.append(a)
|
|
c = l[0]
|
|
b = torch.rand(2, 3)
|
|
c += torch.rand(2, 3)
|
|
return b
|
|
|
|
# c does not get cleaned up because there is a wildcard + mutation
|
|
FileCheck().check_count("aten::rand", 2, exactly=True).run(str(foo.graph))
|
|
|
|
def test_mutable_dce_loop(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
l = []
|
|
l.append(a)
|
|
i = 0
|
|
b = torch.rand(2, 3)
|
|
while i < 1:
|
|
dead = torch.rand(2, 3)
|
|
c = l[0]
|
|
c += torch.rand(2, 3)
|
|
i += 1
|
|
return b
|
|
|
|
FileCheck().check("prim::Loop").check_not("aten::rand").check("aten::select") \
|
|
.check_count("aten::rand", 1, exactly=True).run(str(foo.graph))
|
|
|
|
def test_mutable_dce_wildcards(self):
|
|
def fn():
|
|
x = torch.ones(2, 3)
|
|
l = []
|
|
l.append(x)
|
|
x_view = l[0]
|
|
x.add_(torch.ones(2, 3))
|
|
return x_view
|
|
|
|
self.checkScript(fn, ())
|
|
|
|
def test_cpp_function_tensor_str(self):
|
|
x = torch.randn(2, 2)
|
|
scale = torch.randn(2, 2, requires_grad=True)
|
|
shift = torch.randn(2, 2, requires_grad=True)
|
|
|
|
@torch.jit.script
|
|
def fn(x, scale, shift):
|
|
return scale * x + shift
|
|
|
|
with self.capture_stdout() as captured:
|
|
print(fn(x, scale, shift))
|
|
|
|
def test_non_final_return(self):
|
|
|
|
def simple(x):
|
|
if bool(x > 3):
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
raise RuntimeError("nope")
|
|
|
|
def nest(x):
|
|
x = x + 1
|
|
if bool(x > 3):
|
|
if bool(x > 4):
|
|
x += 1
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
|
|
def early_ret(x):
|
|
x = x + 1
|
|
if bool(x > 3):
|
|
return x + 1
|
|
x = x + 1
|
|
return x + 2
|
|
|
|
def nest_early_ret(x):
|
|
x = x + 1
|
|
if bool(x > 3):
|
|
if bool(x > 4):
|
|
return x + 2
|
|
return x + 1
|
|
x = x + 1
|
|
return x + 2
|
|
|
|
self.checkScript(simple, torch.rand(1))
|
|
self.checkScript(nest, torch.rand(1))
|
|
self.checkScript(early_ret, torch.rand(1))
|
|
self.checkScript(nest_early_ret, torch.rand(1))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "early"):
|
|
@torch.jit.script
|
|
def not_early_ret(x):
|
|
if bool(x > 3):
|
|
if bool(x > 4):
|
|
return 1
|
|
print("foo")
|
|
else:
|
|
print("5")
|
|
return 7
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "some paths"):
|
|
@torch.jit.script
|
|
def not_total_ret(x):
|
|
if bool(x > 3):
|
|
if bool(x > 4):
|
|
return 1
|
|
else:
|
|
return 2
|
|
else:
|
|
print("5")
|
|
return 7
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "from a loop"):
|
|
@torch.jit.script
|
|
def nest_while_ret(x):
|
|
while bool(x > 4):
|
|
if bool(x < 3):
|
|
return 4
|
|
return 5
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "from a loop"):
|
|
@torch.jit.script
|
|
def nest_for_ret(x):
|
|
for _ in range(3):
|
|
if bool(x < 3):
|
|
return 4
|
|
return 5
|
|
|
|
def test_overloading(self):
|
|
@torch._jit_internal.weak_module
|
|
class W(torch.nn.Module):
|
|
__overloads__ = {'forward': ['forward_tuple', 'forward_tensor']}
|
|
|
|
def __init__(self):
|
|
super(W, self).__init__()
|
|
|
|
@torch._jit_internal.weak_script_method
|
|
def forward_tuple(self, x):
|
|
# type: (Tuple[Tensor, Tensor]) -> Tensor
|
|
return x[0] + 5
|
|
|
|
def forward(self, x):
|
|
# manually do argument switching
|
|
if isinstance(x, tuple):
|
|
return self.forward_tuple(x)
|
|
else:
|
|
return self.forward_tensor(x)
|
|
|
|
@torch._jit_internal.weak_script_method
|
|
def forward_tensor(self, x):
|
|
# type: (Tensor) -> Tensor
|
|
return x + 20
|
|
|
|
class S(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(S, self).__init__()
|
|
self.weak = W()
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
return self.weak(x) + self.weak((x, x))
|
|
|
|
s = S()
|
|
x = torch.ones(1)
|
|
self.assertEqual(s(x), x + 20 + 5 + x)
|
|
|
|
w = W()
|
|
self.assertEqual(w((x, x)), x + 5)
|
|
self.assertEqual(w((x)), x + 20)
|
|
|
|
def test_select_after_chunk(self):
|
|
def foo(x):
|
|
chunked = torch.chunk(x, 1)
|
|
foo = chunked[0]
|
|
foo.add_(5)
|
|
return x
|
|
|
|
self.checkScript(foo, [torch.rand(2, 3)])
|
|
|
|
def test_nn_LSTM(self):
|
|
input = torch.nn.utils.rnn.pack_sequence([torch.randn(5, 5)])
|
|
|
|
class S(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(S, self).__init__()
|
|
self.x = torch.nn.LSTM(5, 5)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
# type: (Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]) -> Tuple[Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]], Tuple[Tensor, Tensor]] # noqa
|
|
return self.x(input)
|
|
|
|
eager_out = self.runAndSaveRNG(lambda x: torch.nn.LSTM(5, 5)(x), (input,))[0]
|
|
script_out = self.runAndSaveRNG(lambda x: S()(x), (input,))[0]
|
|
|
|
self.assertEqual(eager_out, script_out)
|
|
|
|
def test_list_python_op(self):
|
|
def python_list_op(lst):
|
|
# type: (List[Tensor]) -> Tensor
|
|
return lst[0]
|
|
|
|
def fn(lst):
|
|
# type: (List[Tensor]) -> Tensor
|
|
return python_list_op(lst)
|
|
|
|
self.checkScript(fn, ([torch.ones(2) + 2, torch.ones(2)],))
|
|
|
|
def test_ignore_decorator(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
tensor = torch.zeros(1, requires_grad=False)
|
|
self.register_buffer('some_state', torch.nn.Parameter(tensor))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
self.ignored_code(x)
|
|
return x
|
|
|
|
@torch.jit.ignore
|
|
def ignored_code(self, x):
|
|
self.some_state = torch.tensor((100,))
|
|
|
|
# Assert ignored code is run
|
|
m = M()
|
|
self.assertEqual(m.some_state, torch.zeros(1))
|
|
m(torch.ones(1))
|
|
self.assertEqual(m.some_state, torch.zeros(1) + 100)
|
|
|
|
# Export and ensure ignored code not present
|
|
pp, constants = m._python_print()
|
|
printed = torch.jit.ScriptModule()
|
|
ppv = "op_version_set = 0\n{}".format(pp)
|
|
torch._C._jit_import_methods(printed, ppv, constants)
|
|
self.assertIn('IgnoredPythonOp', ppv)
|
|
self.assertNotIn('ignored_code', ppv)
|
|
|
|
with self.assertRaisesRegex(torch.jit.Error, "This Python function is annotated to be ignored"):
|
|
printed(torch.ones(1))
|
|
|
|
def test_view_write(self):
|
|
def fn(x, y):
|
|
l = []
|
|
l.append(x)
|
|
x_view = l[0]
|
|
a = x + x
|
|
x_view.add_(y)
|
|
b = x + x
|
|
return a == b
|
|
self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
|
|
|
|
def test_dict_view(self):
|
|
def fn(x, y):
|
|
l = {"a": x}
|
|
x_view = l["a"]
|
|
a = x + x
|
|
x_view.add_(y)
|
|
b = x + x
|
|
return a == b
|
|
self.checkScript(fn, (torch.rand(2, 3), torch.rand(2, 3)))
|
|
|
|
def test_dict_ops(self):
|
|
d = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
|
|
|
|
@torch.jit.script
|
|
def keys(x):
|
|
# type: (Dict[str, Tensor]) -> List[str]
|
|
return list(x.keys())
|
|
|
|
self.assertEqual(set(keys(d)), set(d.keys()))
|
|
|
|
@torch.jit.script
|
|
def values(x):
|
|
# type: (Dict[str, Tensor]) -> List[Tensor]
|
|
return list(x.values())
|
|
|
|
self.assertEqual(set(values(d)), set(d.values()))
|
|
|
|
def length(x):
|
|
# type: (Dict[str, Tensor]) -> int
|
|
return len(x)
|
|
|
|
self.checkScript(length, (d,))
|
|
|
|
def test_dict(self):
|
|
def simple(x):
|
|
# type: (Dict[str, int]) -> Dict[str, int]
|
|
return x
|
|
|
|
self.checkScript(simple, ({'item': 20, 'other_item': 120},))
|
|
|
|
def index(x):
|
|
# type: (Dict[str, int]) -> int
|
|
return x['item']
|
|
|
|
self.checkScript(index, ({'item': 20, 'other_item': 120},))
|
|
|
|
def type_default():
|
|
# type: () -> Dict[str, Tensor]
|
|
return {}
|
|
|
|
self.checkScript(type_default, ())
|
|
|
|
@torch.jit.script
|
|
def missing_index(x):
|
|
# type: (Dict[str, int]) -> int
|
|
return x['dne']
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "KeyError"):
|
|
missing_index({'item': 20, 'other_item': 120})
|
|
|
|
code = dedent('''
|
|
def literal1():
|
|
return torch.jit.annotate(Dict[int, float], {})
|
|
def literal2():
|
|
return torch.jit.annotate(Dict[int, float], {10: 1.2})
|
|
''')
|
|
cu = torch.jit.CompilationUnit(code)
|
|
self.assertEqual({}, cu.literal1())
|
|
self.assertEqual({10: 1.2}, cu.literal2())
|
|
|
|
cu = torch.jit.CompilationUnit(dedent('''
|
|
def literal3():
|
|
return torch.jit.annotate(Dict[int, float], {10: 1.2, 11: 1.3})
|
|
'''))
|
|
self.assertEqual({10: 1.2, 11: 1.3}, cu.literal3())
|
|
|
|
def list_of_dicts():
|
|
# type: () -> List[Dict[str, Tensor]]
|
|
return [{'word': torch.ones(2) + 3}, {'other word': torch.ones(1) + 2}]
|
|
|
|
self.checkScript(list_of_dicts, ())
|
|
|
|
def test_dict_mutability(self):
|
|
@torch.jit.script
|
|
def fn():
|
|
# type: () -> Dict[str, int]
|
|
a = torch.jit.annotate(Dict[str, int], {})
|
|
a['ok'] = 10
|
|
return a
|
|
|
|
self.assertEqual(fn(), {'ok': 10})
|
|
|
|
def test_dict_membership(self):
|
|
def fn(x, y):
|
|
# type: (Dict[int, int], int) -> int
|
|
return x.get(y, 3)
|
|
|
|
d = {1: 2, 3: 4}
|
|
self.checkScript(fn, (d, 3))
|
|
self.checkScript(fn, (d, 2))
|
|
|
|
def optional(x, y):
|
|
# type: (Dict[int, int], int) -> bool
|
|
res = x.get(y)
|
|
return res is None
|
|
|
|
self.checkScript(fn, (d, 3))
|
|
self.checkScript(fn, (d, 2))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "is actually of type Optional"):
|
|
@torch.jit.script
|
|
def bad_types(x, y):
|
|
# type: (Dict[int, int], int) -> int
|
|
return x.get(y) # noqa: T484
|
|
|
|
def dict_to_python(self):
|
|
def python_lookup(my_dict, keys):
|
|
# type: (Dict[str, int], List[str]) -> List[int]
|
|
return [my_dict[k] for k in keys]
|
|
|
|
def fn(my_dict, keys):
|
|
# type: (Dict[str, int], List[str]) -> List[int]
|
|
return python_lookup(my_dict, keys)
|
|
|
|
a_dict = {'a': torch.ones(1), 'b': torch.ones(1) + 1, 'c': torch.ones(1) + 2}
|
|
self.checkScript(fn, (a_dict, ('a', 'c')))
|
|
|
|
def test_module_attrs(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self, table):
|
|
super(M, self).__init__()
|
|
self.table = torch.jit.Attribute(table, Dict[str, torch.Tensor])
|
|
self.x = torch.nn.Parameter(torch.tensor([100.0]))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, key):
|
|
# type: (str) -> Tensor
|
|
return self.table[key] + self.x
|
|
|
|
with self.disableModuleHook():
|
|
# TODO: re-enable module hook when Python printing of attributes is
|
|
# supported
|
|
m = M({char : torch.ones(1) + ord(char) - ord("a") for char in "abcdefg"})
|
|
self.assertEqual(m("c"), torch.tensor([103]))
|
|
|
|
def test_tensor_import_export(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
a = torch.tensor(1)
|
|
b = torch.tensor([1, 2])
|
|
c = [a, b]
|
|
return c
|
|
|
|
self.run_pass('constant_propagation', foo.graph)
|
|
m = torch.jit.ScriptModule()
|
|
m._create_method_from_graph("forward", foo.graph)
|
|
self.getExportImportCopy(m)
|
|
|
|
def test_attribute_serialization(self):
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
|
|
self.float = torch.jit.Attribute(2.3, float)
|
|
self.int = torch.jit.Attribute(99, int)
|
|
self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
|
|
self.list = torch.jit.Attribute([(1, 2), (3, 4)], List[Tuple[int, int]])
|
|
self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
|
|
self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
return (self.table, self.float, self.int, self.tuple, self.list, self.int_list)
|
|
|
|
m = M()
|
|
imported_m = self.getExportImportCopy(m)
|
|
self.assertEqual(m(), imported_m())
|
|
|
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: TemporaryFileName support for Windows or Sandcastle")
|
|
def test_attribute_unpickling(self):
|
|
import zipfile
|
|
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
|
|
self.float = torch.jit.Attribute(2.3, float)
|
|
self.int = torch.jit.Attribute(99, int)
|
|
self.tuple = torch.jit.Attribute((1, 2, 3, 4), Tuple[int, int, int, int])
|
|
self.list = torch.jit.Attribute([(1, 2), (3, 4)], List[Tuple[int, int]])
|
|
self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
|
|
self.int_list = torch.jit.Attribute([1, 2, 3, 4], List[int])
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
return (self.table, self.float, self.int, self.tuple, self.list, self.int_list)
|
|
|
|
class TensorID(object):
|
|
def __setstate__(self, id):
|
|
self.id = id
|
|
|
|
class IntList(object):
|
|
def __setstate__(self, data):
|
|
self.data = data
|
|
|
|
class JitUnpickler(pickle.Unpickler):
|
|
def find_class(self, module, name):
|
|
if not module == '__main__':
|
|
return None
|
|
|
|
if name == 'TensorID':
|
|
return TensorID
|
|
elif name == 'IntList':
|
|
return IntList
|
|
|
|
with TemporaryFileName() as fname:
|
|
M().save(fname)
|
|
archive_name = os.path.basename(os.path.normpath(fname))
|
|
archive = zipfile.ZipFile(fname, 'r')
|
|
pickled_data = archive.read(os.path.join(archive_name, 'attributes.pkl'))
|
|
JitUnpickler(io.BytesIO(pickled_data)).load()
|
|
|
|
def test_submodule_attribute_serialization(self):
|
|
class S(torch.jit.ScriptModule):
|
|
def __init__(self, list_data):
|
|
super(S, self).__init__()
|
|
self.table = torch.jit.Attribute({"I": "am", "a test": "test"}, Dict[str, str])
|
|
self.list = torch.jit.Attribute(list_data, List[Tuple[int, int]])
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
return (self.table, self.list)
|
|
|
|
class M(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(M, self).__init__()
|
|
self.table = torch.jit.Attribute({"this": "is", "a different": "dict"}, Dict[str, str])
|
|
self.tensor = torch.jit.Attribute(torch.randn(2, 2), torch.Tensor)
|
|
self.s1 = S([(1, 2)])
|
|
self.s2 = S([(4, 5)])
|
|
|
|
@torch.jit.script_method
|
|
def forward(self):
|
|
return (self.table, self.tensor, self.s1.table, self.s2.list, self.s1.list)
|
|
|
|
m = M()
|
|
imported_m = self.getExportImportCopy(m)
|
|
self.assertEqual(m(), imported_m())
|
|
|
|
def test_optional_tuple(self):
|
|
def fn(x=None):
|
|
# type: (Optional[Tuple[int, int]]) -> Tuple[int, int]
|
|
if x is None:
|
|
new_x = (1, 2)
|
|
else:
|
|
new_x = x
|
|
return new_x
|
|
|
|
self.checkScript(fn, ((3, 4),))
|
|
self.checkScript(fn, ())
|
|
|
|
def test_split(self):
|
|
def split_two(tensor):
|
|
a, b, c = torch.split(tensor, 2, dim=1)
|
|
return a, b, c
|
|
x = torch.randn(3, 6)
|
|
y = torch.randn(3, 6)
|
|
self.checkScript(split_two, [(x + y)])
|
|
|
|
|
|
class MnistNet(nn.Module):
|
|
def __init__(self):
|
|
super(MnistNet, self).__init__()
|
|
self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
|
|
self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
|
|
self.conv2_drop = nn.Dropout2d()
|
|
self.fc1 = nn.Linear(320, 50)
|
|
self.fc2 = nn.Linear(50, 10)
|
|
|
|
def forward(self, x):
|
|
x = F.relu(F.max_pool2d(self.conv1(x), 2))
|
|
x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
|
|
x = x.view(-1, 320)
|
|
x = F.relu(self.fc1(x))
|
|
x = F.dropout(x, training=self.training)
|
|
x = self.fc2(x)
|
|
return F.log_softmax(x, dim=1)
|
|
|
|
|
|
class TestEndToEndHybridFrontendModels(JitTestCase):
|
|
@staticmethod
|
|
def _test_dcgan_models(self, device, check_export_import=True):
|
|
class DCGANGenerator(nn.Module):
|
|
def __init__(self, nz, ngf, nc):
|
|
super(DCGANGenerator, self).__init__()
|
|
self.main = nn.Sequential(
|
|
# input is Z, going into a convolution
|
|
nn.ConvTranspose2d(nz, ngf * 8, 4, 1, 0, bias=False),
|
|
nn.BatchNorm2d(ngf * 8),
|
|
nn.ReLU(True),
|
|
# state size. (ngf*8) x 4 x 4
|
|
nn.ConvTranspose2d(ngf * 8, ngf * 4, 4, 2, 1, bias=False),
|
|
nn.BatchNorm2d(ngf * 4),
|
|
nn.ReLU(True),
|
|
# state size. (ngf*4) x 8 x 8
|
|
nn.ConvTranspose2d(ngf * 4, ngf * 2, 4, 2, 1, bias=False),
|
|
nn.BatchNorm2d(ngf * 2),
|
|
nn.ReLU(True),
|
|
# state size. (ngf*2) x 16 x 16
|
|
nn.ConvTranspose2d(ngf * 2, ngf, 4, 2, 1, bias=False),
|
|
nn.BatchNorm2d(ngf),
|
|
nn.ReLU(True),
|
|
# state size. (ngf) x 32 x 32
|
|
nn.ConvTranspose2d(ngf, nc, 4, 2, 1, bias=False),
|
|
nn.Tanh()
|
|
# state size. (nc) x 64 x 64
|
|
)
|
|
|
|
def forward(self, input):
|
|
return self.main(input)
|
|
|
|
class DCGANDiscriminator(nn.Module):
|
|
def __init__(self, nc, ndf):
|
|
super(DCGANDiscriminator, self).__init__()
|
|
self.main = nn.Sequential(
|
|
# input is (nc) x 64 x 64
|
|
nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
# state size. (ndf) x 32 x 32
|
|
nn.Conv2d(ndf, ndf * 2, 4, 2, 1, bias=False),
|
|
nn.BatchNorm2d(ndf * 2),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
# state size. (ndf*2) x 16 x 16
|
|
nn.Conv2d(ndf * 2, ndf * 4, 4, 2, 1, bias=False),
|
|
nn.BatchNorm2d(ndf * 4),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
# state size. (ndf*4) x 8 x 8
|
|
nn.Conv2d(ndf * 4, ndf * 8, 4, 2, 1, bias=False),
|
|
nn.BatchNorm2d(ndf * 8),
|
|
nn.LeakyReLU(0.2, inplace=True),
|
|
# state size. (ndf*8) x 4 x 4
|
|
nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False),
|
|
nn.Sigmoid()
|
|
)
|
|
|
|
def forward(self, input):
|
|
return self.main(input).view(-1, 1).squeeze(1)
|
|
|
|
bs, nz, ngf, nc, ndf = 5, 6, 9, 3, 10
|
|
self.checkTrace(DCGANGenerator(nz, ngf, nc).to(device),
|
|
(torch.rand(bs, nz, 1, 1, device=device),),
|
|
export_import=check_export_import)
|
|
example_input = DCGANGenerator(nz, ngf, nc).to(device)(torch.rand(bs, nz, 1, 1, device=device))
|
|
self.checkTrace(DCGANDiscriminator(nc, ndf).to(device), (example_input,),
|
|
export_import=check_export_import)
|
|
|
|
def test_dcgan_models(self):
|
|
self._test_dcgan_models(self, device='cpu')
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
|
def test_dcgan_models_cuda(self):
|
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
|
self._test_dcgan_models(self, device='cuda', check_export_import=False)
|
|
|
|
@staticmethod
|
|
def _test_neural_style(self, device, check_export_import=True):
|
|
class TransformerNet(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TransformerNet, self).__init__()
|
|
# Initial convolution layers
|
|
self.conv1 = ConvLayer(3, 32, kernel_size=9, stride=1)
|
|
self.in1 = torch.nn.InstanceNorm2d(32, affine=True)
|
|
self.conv2 = ConvLayer(32, 64, kernel_size=3, stride=2)
|
|
self.in2 = torch.nn.InstanceNorm2d(64, affine=True)
|
|
self.conv3 = ConvLayer(64, 128, kernel_size=3, stride=2)
|
|
self.in3 = torch.nn.InstanceNorm2d(128, affine=True)
|
|
# Residual layers
|
|
self.res1 = ResidualBlock(128)
|
|
self.res2 = ResidualBlock(128)
|
|
self.res3 = ResidualBlock(128)
|
|
self.res4 = ResidualBlock(128)
|
|
self.res5 = ResidualBlock(128)
|
|
# Upsampling Layers
|
|
self.deconv1 = UpsampleConvLayer(128, 64, kernel_size=3, stride=1, upsample=2)
|
|
self.in4 = torch.nn.InstanceNorm2d(64, affine=True)
|
|
self.deconv2 = UpsampleConvLayer(64, 32, kernel_size=3, stride=1, upsample=2)
|
|
self.in5 = torch.nn.InstanceNorm2d(32, affine=True)
|
|
self.deconv3 = ConvLayer(32, 3, kernel_size=9, stride=1)
|
|
# Non-linearities
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, X):
|
|
y = self.relu(self.in1(self.conv1(X)))
|
|
y = self.relu(self.in2(self.conv2(y)))
|
|
y = self.relu(self.in3(self.conv3(y)))
|
|
y = self.res1(y)
|
|
y = self.res2(y)
|
|
y = self.res3(y)
|
|
y = self.res4(y)
|
|
y = self.res5(y)
|
|
y = self.relu(self.in4(self.deconv1(y)))
|
|
y = self.relu(self.in5(self.deconv2(y)))
|
|
y = self.deconv3(y)
|
|
return y
|
|
|
|
class ConvLayer(torch.nn.Module):
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride):
|
|
super(ConvLayer, self).__init__()
|
|
reflection_padding = kernel_size // 2
|
|
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
|
|
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
|
|
|
|
def forward(self, x):
|
|
out = self.reflection_pad(x)
|
|
out = self.conv2d(out)
|
|
return out
|
|
|
|
class ResidualBlock(torch.nn.Module):
|
|
"""ResidualBlock
|
|
introduced in: https://arxiv.org/abs/1512.03385
|
|
recommended architecture: http://torch.ch/blog/2016/02/04/resnets.html
|
|
"""
|
|
|
|
def __init__(self, channels):
|
|
super(ResidualBlock, self).__init__()
|
|
self.conv1 = ConvLayer(channels, channels, kernel_size=3, stride=1)
|
|
self.in1 = torch.nn.InstanceNorm2d(channels, affine=True)
|
|
self.conv2 = ConvLayer(channels, channels, kernel_size=3, stride=1)
|
|
self.in2 = torch.nn.InstanceNorm2d(channels, affine=True)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
residual = x
|
|
out = self.relu(self.in1(self.conv1(x)))
|
|
out = self.in2(self.conv2(out))
|
|
out = out + residual
|
|
return out
|
|
|
|
class UpsampleConvLayer(torch.nn.Module):
|
|
"""UpsampleConvLayer
|
|
Upsamples the input and then does a convolution. This method gives better results
|
|
compared to ConvTranspose2d.
|
|
ref: http://distill.pub/2016/deconv-checkerboard/
|
|
"""
|
|
|
|
def __init__(self, in_channels, out_channels, kernel_size, stride, upsample=None):
|
|
super(UpsampleConvLayer, self).__init__()
|
|
self.upsample = upsample
|
|
if upsample:
|
|
self.upsample_layer = torch.nn.Upsample(mode='nearest', scale_factor=upsample)
|
|
reflection_padding = kernel_size // 2
|
|
self.reflection_pad = torch.nn.ReflectionPad2d(reflection_padding)
|
|
self.conv2d = torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride)
|
|
|
|
def forward(self, x):
|
|
x_in = x
|
|
if self.upsample:
|
|
x_in = self.upsample_layer(x_in)
|
|
out = self.reflection_pad(x_in)
|
|
out = self.conv2d(out)
|
|
return out
|
|
|
|
self.checkTrace(TransformerNet(), (torch.rand(5, 3, 16, 16),), export_import=check_export_import)
|
|
|
|
def test_neural_style(self):
|
|
self._test_neural_style(self, device='cpu')
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
|
def test_neural_style_cuda(self):
|
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
|
self._test_neural_style(self, device='cuda', check_export_import=False)
|
|
|
|
@staticmethod
|
|
def _test_mnist(self, device, check_export_import=True):
|
|
# eval() is present because dropout makes this nondeterministic
|
|
self.checkTrace(MnistNet().to(device).eval(), (torch.rand(5, 1, 28, 28, device=device),),
|
|
export_import=check_export_import)
|
|
|
|
def test_mnist(self):
|
|
self._test_mnist(self, device='cpu')
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
|
def test_mnist_cuda(self):
|
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
|
self._test_mnist(self, device='cuda', check_export_import=False)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
|
def test_mnist_training_leaks_no_memory_cuda(self):
|
|
net = MnistNet().cuda()
|
|
# MnistNet uses dropout, don't check its trace
|
|
traced_net = torch.jit.trace(net, [torch.randn(5, 1, 28, 28, device='cuda')],
|
|
check_trace=False)
|
|
|
|
def train(iters):
|
|
for _ in range(iters):
|
|
# Get some fake data
|
|
inp = torch.randn(5, 1, 28, 28, device='cuda')
|
|
out = traced_net(inp)
|
|
|
|
# Here's some fake loss
|
|
out.sum().backward()
|
|
|
|
# Zero out grads
|
|
traced_net.zero_grad()
|
|
|
|
# Set it up so the params have .grad fields so they are not reported as leaks
|
|
train(1)
|
|
|
|
with self.assertLeaksNoCudaTensors():
|
|
train(5)
|
|
|
|
@staticmethod
|
|
def _test_reinforcement_learning(self, device, test_export_import=True):
|
|
class Policy(nn.Module):
|
|
def __init__(self):
|
|
super(Policy, self).__init__()
|
|
self.affine1 = nn.Linear(4, 128)
|
|
self.affine2 = nn.Linear(128, 2)
|
|
|
|
def forward(self, x):
|
|
x = F.relu(self.affine1(x))
|
|
action_scores = self.affine2(x)
|
|
return F.softmax(action_scores, dim=1)
|
|
|
|
self.checkTrace(Policy().to(device), (torch.rand(1, 4, device=device),),
|
|
export_import=test_export_import)
|
|
|
|
def test_reinforcement_learning(self):
|
|
self._test_reinforcement_learning(self, device='cpu')
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
|
def test_reinforcement_learning_cuda(self):
|
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
|
self._test_reinforcement_learning(self, device='cuda', test_export_import=False)
|
|
|
|
@staticmethod
|
|
def _test_snli(self, device, check_export_import=True, quantized=False):
|
|
class Bottle(nn.Module):
|
|
|
|
def forward(self, input):
|
|
if len(input.size()) <= 2:
|
|
return super(Bottle, self).forward(input)
|
|
size = input.size()[:2]
|
|
out = super(Bottle, self).forward(input.view(size[0] * size[1], -1))
|
|
return out.view(size[0], size[1], -1)
|
|
|
|
class Linear(Bottle, nn.Linear):
|
|
pass
|
|
|
|
class Encoder(nn.Module):
|
|
|
|
def __init__(self, config):
|
|
super(Encoder, self).__init__()
|
|
self.config = config
|
|
input_size = config.d_proj if config.projection else config.d_embed
|
|
dropout = 0 if config.n_layers == 1 else config.dp_ratio
|
|
self.rnn = nn.LSTM(input_size=input_size, hidden_size=config.d_hidden,
|
|
num_layers=config.n_layers, dropout=dropout,
|
|
bidirectional=config.birnn)
|
|
|
|
def forward(self, inputs):
|
|
batch_size = inputs.size()[1]
|
|
state_shape = self.config.n_cells, batch_size, self.config.d_hidden
|
|
h0 = c0 = inputs.new_zeros(state_shape)
|
|
outputs, (ht, ct) = self.rnn(inputs, (h0, c0))
|
|
return ht[-1] if not self.config.birnn else ht[-2:].transpose(0, 1).contiguous().view(batch_size, -1)
|
|
|
|
class SNLIClassifier(nn.Module):
|
|
|
|
def __init__(self, config):
|
|
super(SNLIClassifier, self).__init__()
|
|
self.config = config
|
|
self.embed = nn.Embedding(config.n_embed, config.d_embed)
|
|
self.projection = Linear(config.d_embed, config.d_proj)
|
|
self.encoder = Encoder(config)
|
|
self.dropout = nn.Dropout(p=config.dp_ratio)
|
|
self.relu = nn.ReLU()
|
|
seq_in_size = 2 * config.d_hidden
|
|
if self.config.birnn:
|
|
seq_in_size *= 2
|
|
lin_config = [seq_in_size] * 2
|
|
self.out = nn.Sequential(
|
|
Linear(*lin_config),
|
|
self.relu,
|
|
self.dropout,
|
|
Linear(*lin_config),
|
|
self.relu,
|
|
self.dropout,
|
|
Linear(*lin_config),
|
|
self.relu,
|
|
self.dropout,
|
|
Linear(seq_in_size, config.d_out))
|
|
|
|
def forward(self, premise, hypothesis):
|
|
prem_embed = self.embed(premise)
|
|
hypo_embed = self.embed(hypothesis)
|
|
if self.config.fix_emb:
|
|
prem_embed = prem_embed.detach()
|
|
hypo_embed = hypo_embed.detach()
|
|
if self.config.projection:
|
|
prem_embed = self.relu(self.projection(prem_embed))
|
|
hypo_embed = self.relu(self.projection(hypo_embed))
|
|
premise = self.encoder(prem_embed)
|
|
hypothesis = self.encoder(hypo_embed)
|
|
scores = self.out(torch.cat([premise, hypothesis], 1))
|
|
return scores
|
|
|
|
class Config:
|
|
n_embed = 100
|
|
d_embed = 100
|
|
d_proj = 300
|
|
dp_ratio = 0.0 # For deterministic testing TODO: change by fixing seed in checkTrace?
|
|
d_hidden = 30
|
|
birnn = True
|
|
d_out = 300
|
|
fix_emb = True
|
|
projection = True
|
|
n_layers = 2
|
|
n_cells = 4 # 2 * n_layers because birnn = True
|
|
|
|
premise = torch.LongTensor(48, 64).random_(0, 100).to(device)
|
|
hypothesis = torch.LongTensor(24, 64).random_(0, 100).to(device)
|
|
|
|
if quantized:
|
|
snli = SNLIClassifier(Config()).cpu()
|
|
torch.jit.quantized.quantize_linear_modules(snli)
|
|
# we don't do export/import checks because we would need to call
|
|
# _pack/_unpack
|
|
self.checkTrace(snli, (premise, hypothesis), inputs_require_grads=False,
|
|
export_import=False)
|
|
else:
|
|
self.checkTrace(SNLIClassifier(Config()).to(device), (premise, hypothesis),
|
|
inputs_require_grads=False, export_import=check_export_import)
|
|
|
|
def test_snli(self):
|
|
self._test_snli(self, device='cpu')
|
|
|
|
if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
|
|
def test_snli_quantized(self):
|
|
self._test_snli(self, device='cpu', quantized=True)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
|
def test_snli_cuda(self):
|
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
|
self._test_snli(self, device='cuda', check_export_import=False)
|
|
|
|
@staticmethod
|
|
def _test_super_resolution(self, device, check_export_import=True):
|
|
import torch.nn.init as init
|
|
|
|
class Net(nn.Module):
|
|
|
|
def __init__(self, upscale_factor):
|
|
super(Net, self).__init__()
|
|
|
|
self.relu = nn.ReLU()
|
|
self.conv1 = nn.Conv2d(1, 64, (5, 5), (1, 1), (2, 2))
|
|
self.conv2 = nn.Conv2d(64, 64, (3, 3), (1, 1), (1, 1))
|
|
self.conv3 = nn.Conv2d(64, 32, (3, 3), (1, 1), (1, 1))
|
|
self.conv4 = nn.Conv2d(32, upscale_factor ** 2, (3, 3), (1, 1), (1, 1))
|
|
self.pixel_shuffle = nn.PixelShuffle(upscale_factor)
|
|
|
|
def forward(self, x):
|
|
x = self.relu(self.conv1(x))
|
|
x = self.relu(self.conv2(x))
|
|
x = self.relu(self.conv3(x))
|
|
x = self.pixel_shuffle(self.conv4(x))
|
|
return x
|
|
|
|
net = Net(upscale_factor=4).to(device)
|
|
self.checkTrace(net, (torch.rand(5, 1, 32, 32, device=device),),
|
|
export_import=check_export_import)
|
|
|
|
def test_super_resolution(self):
|
|
self._test_super_resolution(self, device='cpu')
|
|
|
|
@unittest.skipIf(not RUN_CUDA, 'no CUDA')
|
|
def test_super_resolution_cuda(self):
|
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
|
self._test_super_resolution(self, device='cuda', check_export_import=False)
|
|
|
|
@suppress_warnings
|
|
def test_time_sequence_prediction(self):
|
|
class Sequence(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Sequence, self).__init__()
|
|
self.lstm1 = nn.LSTMCell(1, 51)
|
|
self.lstm2 = nn.LSTMCell(51, 51)
|
|
self.linear = nn.Linear(51, 1)
|
|
|
|
# TODO: could not pass tuple to a python Op and type annotations
|
|
# is not descending to python signature, hence the wrapper
|
|
# see https://github.com/pytorch/pytorch/issues/8778
|
|
# and https://github.com/pytorch/pytorch/issues/8777
|
|
def test_lstm1(self, input, hx, cx):
|
|
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
|
|
return self.lstm1(input, (hx, cx))
|
|
|
|
def test_lstm2(self, input, hx, cx):
|
|
# type: (Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor]
|
|
return self.lstm2(input, (hx, cx))
|
|
|
|
# TODO: could not support tensor constructors in script
|
|
# see https://github.com/pytorch/pytorch/issues/8814
|
|
def test_tensor(self):
|
|
return torch.tensor([], dtype=torch.double)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
# TODO: add future as input with default val
|
|
# see https://github.com/pytorch/pytorch/issues/8724
|
|
outputs = self.test_tensor()
|
|
h_t = torch.zeros((3, 51), dtype=torch.double)
|
|
c_t = torch.zeros((3, 51), dtype=torch.double)
|
|
h_t2 = torch.zeros((3, 51), dtype=torch.double)
|
|
c_t2 = torch.zeros((3, 51), dtype=torch.double)
|
|
|
|
output = torch.zeros([3, 51])
|
|
future = 2
|
|
|
|
# TODO: chunk call should appear as the for loop iterable
|
|
# We hard-code it to 4 for now.
|
|
a, b, c, d = input.chunk(input.size(1), dim=1)
|
|
for input_t in (a, b, c, d):
|
|
h_t, c_t = self.test_lstm1(input_t, h_t, c_t)
|
|
h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
|
|
output = self.linear(h_t2)
|
|
outputs = torch.cat((outputs, output), 1)
|
|
for _ in range(future): # if we should predict the future
|
|
h_t, c_t = self.test_lstm1(output, h_t, c_t)
|
|
h_t2, c_t2 = self.test_lstm2(h_t, h_t2, c_t2)
|
|
output = self.linear(h_t2)
|
|
outputs = torch.cat((outputs, output), 1)
|
|
return outputs
|
|
|
|
# TODO: toggle export_import once above issues are fixed
|
|
self.checkTrace(Sequence(), (torch.rand(3, 4),),
|
|
export_import=False)
|
|
|
|
@staticmethod
|
|
def _test_vae(self, device, check_export_import=True, quantized=False):
|
|
class VAE(nn.Module):
|
|
def __init__(self):
|
|
super(VAE, self).__init__()
|
|
|
|
self.fc1 = nn.Linear(784, 400)
|
|
self.fc21 = nn.Linear(400, 20)
|
|
self.fc22 = nn.Linear(400, 20)
|
|
self.fc3 = nn.Linear(20, 400)
|
|
self.fc4 = nn.Linear(400, 784)
|
|
|
|
def encode(self, x):
|
|
h1 = F.relu(self.fc1(x))
|
|
return self.fc21(h1), self.fc22(h1)
|
|
|
|
def reparameterize(self, mu, logvar):
|
|
if self.training:
|
|
std = torch.exp(0.5 * logvar)
|
|
eps = torch.randn_like(std)
|
|
return eps.mul(std).add_(mu)
|
|
else:
|
|
return mu
|
|
|
|
def decode(self, z):
|
|
h3 = F.relu(self.fc3(z))
|
|
return torch.sigmoid(self.fc4(h3))
|
|
|
|
def forward(self, x):
|
|
mu, logvar = self.encode(x.view(-1, 784))
|
|
z = self.reparameterize(mu, logvar)
|
|
return self.decode(z), mu, logvar
|
|
|
|
if quantized:
|
|
vae = VAE().to(device).eval()
|
|
torch.jit.quantized.quantize_linear_modules(vae)
|
|
# We don't do export/import checks because we would need to call
|
|
# _unpack and _pack
|
|
self.checkTrace(vae, (torch.rand(128, 1, 28, 28, device=device),),
|
|
export_import=False, allow_unused=True,
|
|
inputs_require_grads=False)
|
|
else:
|
|
# eval() is present because randn_like makes this nondeterministic
|
|
self.checkTrace(VAE().to(device).eval(), (torch.rand(128, 1, 28, 28, device=device),),
|
|
export_import=check_export_import)
|
|
|
|
def test_vae(self):
|
|
self._test_vae(self, device='cpu')
|
|
|
|
if not TEST_WITH_UBSAN and torch.fbgemm_is_cpu_supported():
|
|
def test_vae_quantized(self):
|
|
self._test_vae(self, device='cpu', quantized=True)
|
|
|
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
|
def test_vae_cuda(self):
|
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
|
self._test_vae(self, device='cuda', check_export_import=False)
|
|
|
|
|
|
# Smoke tests for export methods
|
|
class TestPytorchExportModes(JitTestCase):
|
|
class MyModel(nn.Module):
|
|
def __init__(self):
|
|
super(TestPytorchExportModes.MyModel, self).__init__()
|
|
|
|
def forward(self, x):
|
|
return x.transpose(0, 1)
|
|
|
|
def test_protobuf(self):
|
|
torch_model = TestPytorchExportModes.MyModel()
|
|
fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
|
|
f = io.BytesIO()
|
|
torch.onnx._export(torch_model, (fake_input), f, verbose=False,
|
|
export_type=torch.onnx.ExportTypes.PROTOBUF_FILE)
|
|
|
|
def test_zipfile(self):
|
|
torch_model = TestPytorchExportModes.MyModel()
|
|
fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
|
|
f = io.BytesIO()
|
|
torch.onnx._export(torch_model, (fake_input), f, verbose=False,
|
|
export_type=torch.onnx.ExportTypes.ZIP_ARCHIVE)
|
|
|
|
def test_compressed_zipfile(self):
|
|
torch_model = TestPytorchExportModes.MyModel()
|
|
fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
|
|
f = io.BytesIO()
|
|
torch.onnx._export(torch_model, (fake_input), f, verbose=False,
|
|
export_type=torch.onnx.ExportTypes.COMPRESSED_ZIP_ARCHIVE)
|
|
|
|
def test_directory(self):
|
|
torch_model = TestPytorchExportModes.MyModel()
|
|
fake_input = Variable(torch.randn(1, 1, 224, 224), requires_grad=True)
|
|
d = tempfile.mkdtemp()
|
|
torch.onnx._export(torch_model, (fake_input), d, verbose=False,
|
|
export_type=torch.onnx.ExportTypes.DIRECTORY)
|
|
shutil.rmtree(d)
|
|
|
|
def test_onnx_multiple_return(self):
|
|
@torch.jit.script
|
|
def foo(a):
|
|
return (a, a)
|
|
f = io.BytesIO()
|
|
x = torch.ones(3)
|
|
torch.onnx._export(foo, (x,), f, example_outputs=(x, x))
|
|
|
|
@skipIfNoLapack
|
|
def test_aten_fallback(self):
|
|
class ModelWithAtenNotONNXOp(nn.Module):
|
|
def forward(self, x, y):
|
|
abcd = x + y
|
|
defg = torch.qr(abcd)
|
|
return defg
|
|
|
|
x = torch.rand(3, 4)
|
|
y = torch.rand(3, 4)
|
|
f = io.BytesIO()
|
|
exported = torch.onnx.export_to_pretty_string(
|
|
ModelWithAtenNotONNXOp(), (x, y), f,
|
|
operator_export_type=OperatorExportTypes.ONNX_ATEN_FALLBACK)
|
|
self.assertExpected(exported)
|
|
|
|
# torch.fmod is using to test ONNX_ATEN.
|
|
# If you plan to remove fmod from aten, or found this test failed.
|
|
# please contact @Rui.
|
|
def test_onnx_aten(self):
|
|
class ModelWithAtenFmod(nn.Module):
|
|
def forward(self, x, y):
|
|
return torch.fmod(x, y)
|
|
|
|
f = io.BytesIO()
|
|
x = torch.randn(3, 4, dtype=torch.float32)
|
|
y = torch.randn(3, 4, dtype=torch.float32)
|
|
exported = torch.onnx.export_to_pretty_string(
|
|
ModelWithAtenFmod(), (x, y), f,
|
|
operator_export_type=OperatorExportTypes.ONNX_ATEN)
|
|
self.assertExpected(exported)
|
|
|
|
|
|
# known to be failing in tracer
|
|
EXCLUDE_TRACED = {
|
|
# The following fail due to #12024.
|
|
# A prim::ListConstruct is involved and the indices get traced as TensorType,
|
|
# which always require_grad. This causes a crash in autodiff.
|
|
'test___getitem___adv_index',
|
|
'test___getitem___adv_index_beg',
|
|
'test___getitem___adv_index_comb',
|
|
'test___getitem___adv_index_dup',
|
|
'test___getitem___adv_index_sub',
|
|
'test___getitem___adv_index_sub_2',
|
|
'test___getitem___adv_index_sub_3',
|
|
'test___getitem___adv_index_var',
|
|
|
|
}
|
|
|
|
EXCLUDE_TYPE_CHECK = {
|
|
# slogdet tests use itemgetter to select its only differentiable output,
|
|
# but this happens outside of the graph we handle, so there are fewer
|
|
# reference outputs than graph outputs.
|
|
'test_slogdet_1x1_neg_det',
|
|
'test_slogdet_1x1_pos_det',
|
|
'test_slogdet_distinct_singular_values',
|
|
'test_slogdet_neg_det',
|
|
'test_slogdet_pos_det',
|
|
'test_slogdet_symmetric',
|
|
'test_slogdet_symmetric_pd',
|
|
}
|
|
|
|
# known to be failing in script
|
|
EXCLUDE_SCRIPT = {
|
|
'test_norm_fro',
|
|
'test_norm_fro_default',
|
|
'test_norm_nuc',
|
|
|
|
# aten op has additional cudnn argument
|
|
'test_nn_unfold',
|
|
|
|
# flaky test - TODO fix
|
|
'test_nn_ctc_loss',
|
|
|
|
# unknown builtin op
|
|
'test_nn_fold',
|
|
}
|
|
|
|
EXCLUDE_PYTHON_PRINT = {
|
|
# no support for BroadcastingList in python printer
|
|
'test_nn_max_unpool1d',
|
|
'test_nn_max_unpool2d',
|
|
'test_nn_max_unpool3d',
|
|
'test_nn_max_pool1d',
|
|
'test_nn_max_pool2d',
|
|
'test_nn_max_pool3d',
|
|
'test_nn_max_pool1d_with_indices',
|
|
}
|
|
|
|
EXCLUDE_SCRIPT_MODULES = {
|
|
'test_nn_AdaptiveAvgPool2d_tuple_none',
|
|
'test_nn_AdaptiveAvgPool3d_tuple_none',
|
|
'test_nn_AdaptiveMaxPool2d_tuple_none',
|
|
'test_nn_AdaptiveMaxPool3d_tuple_none',
|
|
}
|
|
|
|
DISABLE_AUTODIFF_SUBGRAPH_INLINING = {
|
|
'test_nn_avg_pool2d',
|
|
'test_nn_adaptive_avg_pool1d',
|
|
'test_nn_adaptive_avg_pool2d',
|
|
'test_nn_adaptive_avg_pool3d',
|
|
'test_nn_batch_norm',
|
|
'test_nn_embedding',
|
|
'test_nn_log_softmax',
|
|
'test_nn_softmax',
|
|
'test_nn_softmax_with_all_args',
|
|
'test_nn_threshold',
|
|
'test_nn_nll_loss',
|
|
# Should have added all test_nn_interpolate_* here,
|
|
# but it's using autodiff since its subgraph is over
|
|
# 2 nodes.
|
|
}
|
|
|
|
|
|
# make a new function where all non-tensor arguments in 'args' have been partially
|
|
# applied, and all tensor arguments remain.
|
|
# used to trace functions when some arguments are not tensors
|
|
def partial_apply_nontensors(fn, args, **kwargs):
|
|
source = ['t' if isinstance(arg, torch.Tensor) else 's' for arg in args]
|
|
|
|
def new_fn(*tensors_):
|
|
tensors = iter(tensors_)
|
|
return fn(*(args[i] if s == 's' else next(tensors) for i, s in enumerate(source)), **kwargs)
|
|
|
|
return new_fn, [arg for arg in args if isinstance(arg, torch.Tensor)]
|
|
|
|
|
|
# create a trace function from input fn
|
|
#
|
|
# disable_autodiff_subgraph_inlining:
|
|
# Don't inline autodiff subgraphs so we can test autodiff
|
|
def create_traced_fn(self, fn,
|
|
disable_autodiff_subgraph_inlining=False):
|
|
def traced_fn(*inputs, **kwargs):
|
|
fn_tensors, inputs_tensors = partial_apply_nontensors(fn, inputs, **kwargs)
|
|
traced = torch.jit.trace(fn_tensors, inputs_tensors)
|
|
self.assertExportImport(traced.graph, inputs_tensors)
|
|
if disable_autodiff_subgraph_inlining:
|
|
traced.debug_disable_autodiff_subgraph_inlining()
|
|
output = traced(*inputs_tensors)
|
|
traced_fn.last_graph = traced.graph_for(*inputs_tensors)
|
|
return output
|
|
return traced_fn
|
|
|
|
script_template = '''
|
|
def the_method({}):
|
|
return {}
|
|
'''
|
|
|
|
script_method_template = '''
|
|
def forward({}):
|
|
return {}
|
|
'''
|
|
|
|
|
|
def get_constant(x):
|
|
if x == inf:
|
|
return 'float(\'inf\')' if PY2 else 'math.inf'
|
|
if x == -inf:
|
|
return 'float(\'-inf\')' if PY2 else '-math.inf'
|
|
return x
|
|
|
|
|
|
def get_script_args(args):
|
|
formals = []
|
|
tensors = []
|
|
actuals = []
|
|
for arg in args:
|
|
if isinstance(arg, torch.Tensor):
|
|
name = 'i{}'.format(len(formals))
|
|
formals.append(name)
|
|
actuals.append(name)
|
|
tensors.append(arg)
|
|
elif isinstance(arg, str):
|
|
actuals.append("'{}'".format(arg))
|
|
else:
|
|
actuals.append(str(get_constant(arg)))
|
|
return (formals, tensors, actuals)
|
|
|
|
|
|
# create a script function from (name, func_type, output_process_fn),
|
|
# returns a function takes in (args, kwargs) and runs the compiled function and
|
|
# then applies the post process fn to the outputs
|
|
def create_script_fn(self, method_name, func_type, output_process_fn,
|
|
disable_autodiff_subgraph_inlining=False):
|
|
def script_fn(*args, **kwargs):
|
|
formals, tensors, actuals = get_script_args(args)
|
|
kwargs_str = ''
|
|
for k, v in kwargs.items():
|
|
kwargs_str += ', ' + k + '=' + str(v)
|
|
if func_type == 'functional':
|
|
call = 'torch.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
|
|
elif func_type == 'method':
|
|
call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
|
|
elif func_type == 'nn_functional':
|
|
call = 'torch.nn.functional.{}({}{})'.format(method_name, ', '.join(actuals), kwargs_str)
|
|
else:
|
|
raise 'Unsupported function type'
|
|
|
|
script = script_template.format(', '.join(formals), call)
|
|
|
|
CU = torch.jit.CompilationUnit(script)
|
|
if disable_autodiff_subgraph_inlining:
|
|
CU.the_method.debug_disable_autodiff_subgraph_inlining()
|
|
self.assertExportImport(CU.the_method.graph, tensors)
|
|
output = output_process_fn(CU.the_method(*tensors))
|
|
script_fn.last_graph = CU.the_method.graph_for(*tensors)
|
|
return output
|
|
return script_fn
|
|
|
|
|
|
def check_alias_annotation(method_name, args, kwargs):
|
|
formals, tensors, actuals = get_script_args(args)
|
|
kwargs_str = ''
|
|
for k, v in kwargs.items():
|
|
kwargs_str += ', ' + k + '=' + str(v)
|
|
call = '{}.{}({}{})'.format(actuals[0], method_name, ', '.join(actuals[1:]), kwargs_str)
|
|
script = script_template.format(', '.join(formals), call)
|
|
CU = torch.jit.CompilationUnit(script)
|
|
torch._C._jit_check_alias_annotation(CU.the_method.graph, tuple(tensors), method_name)
|
|
|
|
|
|
def check_output_types(self, func, ref_outputs, args, kwargs):
|
|
graph = getattr(func, 'last_graph', None)
|
|
types = [o.type() for o in graph.outputs()]
|
|
self.assertTrue(len(types) == 1)
|
|
t = types[0]
|
|
torch._C._jit_assert_is_instance(ref_outputs, t)
|
|
|
|
|
|
def check_against_reference(self, func, reference_func, args, kwargs=None,
|
|
allow_unused=True, check_types=True, no_grad=False):
|
|
kwargs = kwargs if kwargs else {}
|
|
|
|
def allSum(vs):
|
|
if isinstance(vs, torch.Tensor):
|
|
vs = (vs,)
|
|
return sum((i + 1) * v.sum()
|
|
for i, v in enumerate(vs)
|
|
if v is not None and v.dtype.is_floating_point)
|
|
|
|
def clone_inputs(requires_grad):
|
|
inputs = [
|
|
arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
|
|
if isinstance(arg, torch.Tensor) else arg for arg in args
|
|
]
|
|
return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad]
|
|
|
|
nograd_inputs, nograd_tensors = clone_inputs(False)
|
|
recording_inputs, recording_tensors = clone_inputs(True)
|
|
|
|
# test no gradients case
|
|
outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
|
|
outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
|
|
self.assertEqual(outputs, outputs_test)
|
|
|
|
if check_types:
|
|
check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
|
|
|
|
if no_grad:
|
|
# skip grad tests
|
|
return
|
|
|
|
# test single grad case
|
|
outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
|
|
grads = torch.autograd.grad(allSum(outputs), recording_tensors,
|
|
allow_unused=allow_unused)
|
|
|
|
outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
|
|
grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
|
|
allow_unused=allow_unused)
|
|
self.assertEqual(outputs, outputs_test)
|
|
self.assertEqual(grads, grads_test)
|
|
|
|
# test the grad grad case
|
|
if self._testMethodName in nn_functional_single_grad:
|
|
return
|
|
|
|
outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
|
|
l1 = allSum(outputs)
|
|
grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
|
|
allow_unused=allow_unused)
|
|
l2 = (allSum(grads) * l1)
|
|
grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
|
|
|
|
recording_inputs, recording_tensors = clone_inputs(True)
|
|
|
|
outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
|
|
l1_test = allSum(outputs_test)
|
|
grads_test = torch.autograd.grad(
|
|
l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
|
|
l2_test = (allSum(grads_test) * l1_test)
|
|
grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
|
|
|
|
self.assertEqual(outputs, outputs_test)
|
|
self.assertEqual(grads, grads_test)
|
|
for g2, g2_test in zip(grads2, grads2_test):
|
|
if g2 is None and g2_test is None:
|
|
continue
|
|
self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))
|
|
|
|
|
|
class TestFuser(JitTestCase):
|
|
def assertAllFused(self, graph, except_for=()):
|
|
if [n.kind() for n in graph.nodes()] == ['prim::DifferentiableGraph']:
|
|
graph = next(graph.nodes()).g('Subgraph')
|
|
allowed_nodes = {'prim::Constant', 'prim::FusionGroup'} | set(except_for)
|
|
self.assertTrue(all(node.kind() in allowed_nodes for node in graph.nodes()),
|
|
'got {}'.format(graph))
|
|
self.assertTrue([node.kind() for node in graph.nodes()].count('prim::FusionGroup') == 1)
|
|
|
|
def _test_fused_abs(self, device='cpu'):
|
|
|
|
@torch.jit.script
|
|
def func(x):
|
|
return x.abs() * 2
|
|
|
|
a = torch.randn(5, device=device)
|
|
self.assertEqual(func(a), a.abs() * 2)
|
|
self.assertAllFused(func.graph_for(a))
|
|
|
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
|
@enable_cpu_fuser
|
|
def test_abs_cpu(self):
|
|
self._test_fused_abs()
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
|
|
@skipIfRocm
|
|
def test_abs_cuda(self):
|
|
self._test_fused_abs(device="cuda")
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_arg_configurations_smoke_cuda(self):
|
|
# A smoke test to make sure we won't use the same kernel for contiguous
|
|
# and non-contiguous arguments.
|
|
# TODO: add optionally enabled debug counters to the fuser to verify
|
|
# that we really can tell the difference between configurations
|
|
def f(x, y):
|
|
z1, z2 = (x + y).chunk(2, dim=1)
|
|
return z1 * z2
|
|
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
traced_f = torch.jit.trace(f, (x, y,))
|
|
self.assertEqual(traced_f(x.t().contiguous(), y), traced_f(x.t(), y))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_broadcast_cuda(self):
|
|
def scaleshift(x, scale, shift):
|
|
return x * scale + shift
|
|
|
|
inputs = [
|
|
torch.randn(4, 4, dtype=torch.float, device='cuda'),
|
|
torch.randn(4, dtype=torch.float, device='cuda'),
|
|
torch.randn(4, dtype=torch.float, device='cuda'),
|
|
]
|
|
ge = self.checkTrace(scaleshift, inputs)
|
|
self.assertAllFused(ge.graph_for(*inputs))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@unittest.skipIf(not RUN_CUDA_HALF, "no half support")
|
|
def test_cuda_half(self):
|
|
x = torch.randn(4, 4, dtype=torch.half, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.half, device='cuda')
|
|
|
|
funcs = [
|
|
self.fn_test_comparison_gt_lt,
|
|
self.fn_test_relu,
|
|
self.fn_test_exp
|
|
]
|
|
|
|
# Note: Non fused inputs must be float to prevent loss of precision
|
|
inputs = (x.float(), y.float())
|
|
fusion_inputs = (x, y)
|
|
for fn in funcs:
|
|
local_inputs = [t.clone().requires_grad_() for t in inputs]
|
|
local_fusion_inputs = [t.clone().requires_grad_() for t in fusion_inputs]
|
|
|
|
# Verifies outputs
|
|
fusion = torch.jit.trace(fn, local_fusion_inputs, check_trace=False, optimize=True)
|
|
outputs = fn(*local_inputs)
|
|
fusion_outputs = fusion(*local_fusion_inputs)
|
|
outputs_half = [t.half() for t in outputs]
|
|
self.assertEqual(outputs_half, fusion_outputs)
|
|
|
|
# Verifies gradients
|
|
for output, fusion_output in zip(outputs_half, fusion_outputs):
|
|
grads = torch.autograd.grad(
|
|
output.float().sum(), local_inputs, allow_unused=True, retain_graph=True)
|
|
fusion_grads = torch.autograd.grad(
|
|
fusion_output.sum(), local_fusion_inputs, allow_unused=True, retain_graph=True)
|
|
grads_half = [t.half() for t in grads]
|
|
self.assertEqual(grads_half, fusion_grads)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_checks_cat_inputs(self):
|
|
# We shouldn't treat cat nodes as broadcasting. All their inputs
|
|
# need to be checked for having the same map size, before we can
|
|
# run the kernel.
|
|
@torch.jit.script
|
|
def f(x, y):
|
|
return torch.cat([x + 2 * x + x ** 2, y + 4 * y + y ** 3], dim=0)
|
|
|
|
# NOTE: y is broadcastable to x, but output of f(x, y) should have
|
|
# shape 3x4, and not 4x4.
|
|
x = torch.randn(2, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(1, 4, dtype=torch.float, device='cuda')
|
|
|
|
self.assertEqual(f(x, y).shape, (3, 4))
|
|
self.assertAllFused(f.graph_for(x, y))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "No CUDA")
|
|
@skipIfRocm
|
|
def test_chunk_cuda(self):
|
|
def fn(x):
|
|
a, b, c = x.chunk(3, 1)
|
|
return a * b + c
|
|
|
|
inputs = [torch.randn(10, 6, dtype=torch.float, device='cuda')]
|
|
|
|
ge = self.checkScript(fn, inputs)
|
|
graph = ge.graph_for(*inputs)
|
|
self.assertAllFused(graph)
|
|
FileCheck().check("prim::ConstantChunk[chunks=3, dim=1]").run(str(graph))
|
|
|
|
@staticmethod
|
|
def _test_chunk_correctness(self, device='cpu'):
|
|
def chunk_4_0(x):
|
|
x0, x1, x2, x3 = x.chunk(4, 0)
|
|
return x0 + x1 + x2 + x3
|
|
|
|
def chunk_4_1(x):
|
|
x0, x1, x2, x3 = x.chunk(4, 1)
|
|
return x0 + x1 + x2 + x3
|
|
|
|
def chunk_4_last(x):
|
|
x0, x1, x2, x3 = x.chunk(4, 2)
|
|
return x0 + x1 + x2 + x3
|
|
|
|
fns = [chunk_4_0, chunk_4_1, chunk_4_last]
|
|
tensors = [
|
|
# splitSize = 1
|
|
torch.randn(4, 4, 4, dtype=torch.float, device=device),
|
|
|
|
# contiguous case
|
|
torch.randn(12, 8, 16, dtype=torch.float, device=device),
|
|
|
|
# non-contiguous case
|
|
torch.randn(12, 8, 16, dtype=torch.float, device=device).transpose(1, 2),
|
|
]
|
|
|
|
for tensor in tensors:
|
|
for fn in fns:
|
|
self.checkScript(fn, [tensor])
|
|
|
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
|
@enable_cpu_fuser
|
|
def test_chunk_correctness(self):
|
|
return self._test_chunk_correctness(self, 'cpu')
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "No CUDA")
|
|
def test_chunk_correctness_cuda(self):
|
|
return self._test_chunk_correctness(self, 'cuda')
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_chunk_distributes_cuda(self):
|
|
def f(x, y):
|
|
z1, z2 = (x + y).chunk(2, dim=1)
|
|
return z1 * z2
|
|
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
|
|
ge = self.checkTrace(f, (x, y))
|
|
graph = ge.graph_for(x, y)
|
|
FileCheck().check("broadcast_tensors").check('with prim::FusionGroup_0') \
|
|
.check_count('ConstantChunk', 2, exactly=True).run(str(graph))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_chunk_motion_deduplicates_inputs(self):
|
|
def func1(x):
|
|
z = x * x
|
|
z0, z1 = z.chunk(2)
|
|
return z0 * z1
|
|
|
|
def func2(x):
|
|
z = x * x * x
|
|
z0, z1 = z.chunk(2)
|
|
return z0 * z1
|
|
|
|
inputs = [
|
|
torch.tensor([1.1, 1.2], device='cuda', dtype=torch.float),
|
|
]
|
|
for func in [func1, func2]:
|
|
module = self.checkScript(func, inputs)
|
|
forward_graph = module.graph_for(*inputs)
|
|
self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
|
|
fusion_group = list(forward_graph.nodes())[-1]
|
|
self.assertEqual(len(list(fusion_group.inputs())), 1)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "No CUDA")
|
|
@skipIfRocm
|
|
def test_chunk_multiple_cuda(self):
|
|
# The arguments are intentionally used out of order as a test to see
|
|
# if the fusion compiler adds extra args in the correct order
|
|
def fn(s, x, y, z):
|
|
z1, z2 = z.chunk(2, 2)
|
|
x1, x2, x3 = x.chunk(3, 1)
|
|
y1, y2 = y.chunk(2, 0)
|
|
return s + x1 + x2 + x3 + y1 + y2 + z1 + z2
|
|
|
|
inputs = [
|
|
torch.randn(5, 2, 3, dtype=torch.float, device='cuda'),
|
|
torch.randn(5, 6, 3, dtype=torch.float, device='cuda'),
|
|
torch.randn(10, 2, 3, dtype=torch.float, device='cuda'),
|
|
torch.randn(5, 2, 6, dtype=torch.float, device='cuda'),
|
|
]
|
|
|
|
ge = self.checkScript(fn, inputs)
|
|
self.assertAllFused(ge.graph_for(*inputs))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_clamp(self):
|
|
def func2(a, b):
|
|
return torch.clamp(a + b, min=0, max=2)
|
|
|
|
def funcInf(a, b):
|
|
return torch.clamp(a + b, min=0, max=float('inf'))
|
|
|
|
def funcOptMin(a, b):
|
|
return torch.clamp(a + b, max=2)
|
|
|
|
def funcOptMax(a, b):
|
|
return torch.clamp(a + b, min=0)
|
|
|
|
a = torch.randn(4, 4, dtype=torch.float, device='cuda', requires_grad=True)
|
|
b = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
nan = torch.tensor(float('nan'))
|
|
|
|
funcs = (func2, funcInf, funcOptMin, funcOptMax)
|
|
for f, inputs in product(funcs, [[a, b], [a, nan]]):
|
|
inp1, inp2 = inputs
|
|
s = self.checkScript(f, (inp1, inp2))
|
|
self.assertAllFused(s.graph_for(inp1, inp2), except_for={'aten::size'})
|
|
|
|
c = s(inp1, inp2)
|
|
c.sum().backward()
|
|
graph = backward_graph(s)
|
|
self.assertAllFused(graph)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_comparison_eq_ne(self):
|
|
def f(x, y):
|
|
mask = (x == 0).type_as(x)
|
|
z = x * mask + y
|
|
mask = (x != 0).type_as(x)
|
|
z = z * mask + y
|
|
return z
|
|
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
|
|
ge = self.checkTrace(f, (x, y))
|
|
self.assertAllFused(ge.graph_for(x, y))
|
|
|
|
@staticmethod
|
|
def fn_test_comparison_gt_lt(x, y):
|
|
mask = (x > 0).type_as(x)
|
|
z = x * mask + y
|
|
mask = (x < 0).type_as(x)
|
|
z = z * mask + y
|
|
return z
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_comparison_gt_lt_cuda(self):
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
|
|
ge = self.checkTrace(self.fn_test_comparison_gt_lt, (x, y))
|
|
self.assertAllFused(ge.graph_for(x, y))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_comparison_ge_le_cuda(self):
|
|
def f(x, y):
|
|
mask = (x >= 0).type_as(x)
|
|
z = x * mask + y
|
|
mask = (x <= 0).type_as(x)
|
|
z = z * mask + y
|
|
return z
|
|
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
|
|
ge = self.checkTrace(f, (x, y))
|
|
self.assertAllFused(ge.graph_for(x, y))
|
|
x.requires_grad_(True)
|
|
y.requires_grad_(True)
|
|
self.assertAllFused(ge.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_addcmul_cuda(self):
|
|
t = torch.randn(1, 4, dtype=torch.float, device='cuda')
|
|
t1 = torch.randn(4, 1, dtype=torch.float, device='cuda')
|
|
t2 = torch.randn(1, 4, dtype=torch.float, device='cuda')
|
|
|
|
def foo(t, t1, t2):
|
|
return t.addcmul(t + 1, t2, value=0.1)
|
|
|
|
ge = self.checkTrace(foo, (t, t1, t2), allow_unused=True)
|
|
graph = ge.graph_for(t, t1, t2)
|
|
self.assertAllFused(graph)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_lerp_cuda(self):
|
|
start = torch.randn(4, 1, dtype=torch.float, device='cuda')
|
|
end = torch.randn(1, 4, dtype=torch.float, device='cuda')
|
|
weight = torch.tensor(0.5, dtype=torch.float, device='cuda')
|
|
|
|
# scalar weight overload
|
|
def foo_weight_scalar(start, end):
|
|
return torch.lerp(start + 1, end, 0.5)
|
|
|
|
# tensor weight overload
|
|
def foo_weight_tensor(start, end):
|
|
return torch.lerp(start + 1, end, weight)
|
|
|
|
ge_weight_scalar = self.checkTrace(foo_weight_scalar, (start, end))
|
|
graph = ge_weight_scalar.graph_for(start, end)
|
|
self.assertAllFused(graph)
|
|
|
|
ge_weight_tensor = self.checkTrace(foo_weight_tensor, (start, end))
|
|
graph = ge_weight_tensor.graph_for(start, end)
|
|
self.assertAllFused(graph)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_concat_cuda(self):
|
|
hx = torch.randn(3, 20, dtype=torch.float, device='cuda')
|
|
cx = torch.randn(3, 20, dtype=torch.float, device='cuda')
|
|
|
|
def foo(hx, cx):
|
|
return torch.cat((hx + cx, hx * cx))
|
|
|
|
ge = self.checkTrace(foo, (hx, cx))
|
|
graph = ge.graph_for(hx, cx)
|
|
self.assertAllFused(graph)
|
|
FileCheck().check("FusedConcat").check_next("return").run(str(graph))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_concat_invariant_cuda(self):
|
|
# Invariant: the output of prim::FusedConcat may
|
|
# not be an input to any node inside the FusionGroup.
|
|
def fn(x, y, z):
|
|
x1 = x + y
|
|
y1 = x - y
|
|
w = torch.cat([x1, y1])
|
|
return w + z
|
|
|
|
x = torch.randn(2, 2, dtype=torch.float, device='cuda')
|
|
y = torch.randn(2, 2, dtype=torch.float, device='cuda')
|
|
z = torch.randn(4, 2, dtype=torch.float, device='cuda')
|
|
ge = self.checkTrace(fn, (x, y, z))
|
|
graph = ge.graph_for(x, y, z)
|
|
self.assertAllFused(graph, except_for={'aten::add'})
|
|
FileCheck().check("FusedConcat").check_next("return").run(str(graph))
|
|
|
|
@staticmethod
|
|
def fn_test_exp(x, y):
|
|
return (x + .5 * y).exp()
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_exp_cuda(self):
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
|
|
ge = self.checkTrace(self.fn_test_exp, (x, y))
|
|
self.assertAllFused(ge.graph_for(x, y))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_fuse_batch_norm(self):
|
|
|
|
class ResLike(torch.jit.ScriptModule):
|
|
def __init__(self, optimize=True):
|
|
super(ResLike, self).__init__(optimize)
|
|
self.bn = nn.BatchNorm2d(16)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x, y):
|
|
return y + torch.relu(self.bn(x))
|
|
|
|
model = ResLike().cuda()
|
|
model_noopt = ResLike(optimize=False).cuda()
|
|
model_noopt.load_state_dict(model.state_dict())
|
|
x = torch.randn(2, 16, 8, 8, device='cuda')
|
|
y = torch.randn(2, 16, 8, 8, device='cuda')
|
|
# FIXME: We need differentiation for CNNs for this optimization to trigger
|
|
with torch.no_grad():
|
|
out = model(x, y)
|
|
graph = model.graph_for(x, y)
|
|
rep = str(graph)
|
|
|
|
out_noopt = model_noopt(x, y)
|
|
rep_noopt = str(model_noopt.graph_for(x, y))
|
|
self.assertEqual(out, out_noopt, prec=3e-5)
|
|
|
|
# Check that batch_norm has really been decomposed
|
|
self.assertIn('aten::batch_norm_update_stats', rep)
|
|
self.assertNotIn('aten::batch_norm(', rep)
|
|
self.assertIn('aten::batch_norm(', rep_noopt)
|
|
|
|
# Make sure the fusion group is big, and contains aten::sqrt, which could
|
|
# originate only from decomposing batch_norm in this case
|
|
fusion_groups = [node for node in graph.nodes() if node.kind() == 'prim::FusionGroup']
|
|
self.assertEqual(len(fusion_groups), 1)
|
|
fused_graph = fusion_groups[0].g('Subgraph')
|
|
self.assertTrue(any(node.kind() == 'aten::sqrt' for node in fused_graph.nodes()))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_threshold(self):
|
|
def f(x):
|
|
return torch.threshold(x, 0, -10) + x + x + x
|
|
|
|
x = torch.tensor([-1, -0.5, 0, 1, 2, 3], device='cuda')
|
|
scripted = torch.jit.script(f)
|
|
|
|
self.assertEqual(f(x), scripted(x))
|
|
self.assertAllFused(scripted.graph_for(x))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_scalar_arg_cuda(self):
|
|
def fn_test_scalar_arg(x, p):
|
|
# type: (Tensor, float) -> Tensor
|
|
return p * (x * x + x)
|
|
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
p = 3
|
|
scripted = torch.jit.script(fn_test_scalar_arg, (x, p))
|
|
self.assertEqual(fn_test_scalar_arg(x, p), scripted(x, p))
|
|
self.assertAllFused(scripted.graph_for(x, p))
|
|
x.requires_grad_(True)
|
|
out = scripted(x, p)
|
|
self.assertAllFused(scripted.graph_for(x, p), except_for=("aten::size", "prim::BroadcastSizes"))
|
|
|
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
|
@enable_cpu_fuser
|
|
def test_fuser_deduplication(self):
|
|
# See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation
|
|
# see the discussion in PR #14957.
|
|
def f(x, y):
|
|
return torch.sigmoid(x + y)
|
|
|
|
b = torch.randn(5, 5, requires_grad=True)
|
|
a = torch.randn(5, 5, requires_grad=True)
|
|
s = self.checkScript(f, (a, b))
|
|
self.assertAllFused(s.graph_for(a, b), except_for={'aten::size'})
|
|
|
|
c = s(a, b)
|
|
ga, gb = torch.autograd.grad(c.sum(), [a, b])
|
|
graph = backward_graph(s)
|
|
self.assertAllFused(graph)
|
|
# check that a, b share storage, i.e. were generated as a single output in the fuser
|
|
self.assertEqual(ga.data_ptr(), gb.data_ptr())
|
|
|
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
|
@enable_cpu_fuser
|
|
def test_fuser_iou(self):
|
|
# This checks if most of Intersection over Union is fused.
|
|
# In particular, the backward contains many _grad_sum_to_size.
|
|
def iou(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2):
|
|
ltx = torch.max(b1x1, b2x1) # [N,M]
|
|
lty = torch.max(b1y1, b2y1)
|
|
rbx = torch.min(b1x2, b2x2)
|
|
rby = torch.min(b1y2, b2y2)
|
|
|
|
w = (rbx - ltx).clamp(min=0, max=float('inf')) # [N,M]
|
|
h = (rby - lty).clamp(min=0, max=float('inf')) # [N,M]
|
|
inter = w * h # [N,M]
|
|
|
|
area1 = (b1x2 - b1x1) * (b1y2 - b1y2) # [N,1]
|
|
area2 = (b2x2 - b2x1) * (b2y2 - b2y2) # [1,M]
|
|
iou = inter / (area1 + area2 - inter)
|
|
return iou
|
|
|
|
box1 = torch.randn(5, 4, requires_grad=True)
|
|
box2 = torch.randn(5, 4, requires_grad=True)
|
|
# unsqueezing can currently not be fused
|
|
b1x1 = box1[:, 0].unsqueeze(1) # [N,1]
|
|
b1y1 = box1[:, 1].unsqueeze(1)
|
|
b1x2 = box1[:, 2].unsqueeze(1)
|
|
b1y2 = box1[:, 3].unsqueeze(1)
|
|
b2x1 = box2[:, 0].unsqueeze(0) # [1,N]
|
|
b2y1 = box2[:, 1].unsqueeze(0)
|
|
b2x2 = box2[:, 2].unsqueeze(0)
|
|
b2y2 = box2[:, 3].unsqueeze(0)
|
|
|
|
s = self.checkScript(iou, (b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2))
|
|
self.assertAllFused(s.graph_for(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2),
|
|
except_for={'aten::size', 'prim::BroadcastSizes'})
|
|
|
|
c = s(b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2)
|
|
torch.autograd.grad(c.sum(), [b1x1, b1y1, b1x2, b1y2, b2x1, b2y1, b2x2, b2y2])
|
|
graph = backward_graph(s)
|
|
self.assertAllFused(graph, except_for={'aten::size', 'prim::BroadcastSizes'})
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
|
|
@skipIfRocm
|
|
@enable_cpu_fuser
|
|
def test_fusion_reuse_multi_gpu(self):
|
|
def fn(x, y):
|
|
return x * y * x * y
|
|
|
|
inputs_cpu = [
|
|
torch.randn(4, 4, dtype=torch.float),
|
|
torch.randn(4, 4, dtype=torch.float),
|
|
]
|
|
inputs_cuda0 = [x.cuda(0) for x in inputs_cpu]
|
|
inputs_cuda1 = [y.cuda(1) for y in inputs_cpu]
|
|
|
|
# Should not crash; these should compile different kernels.
|
|
ge = self.checkScript(fn, inputs_cpu)
|
|
self.assertAllFused(ge.graph_for(*inputs_cpu))
|
|
ge(*inputs_cuda0)
|
|
ge(*inputs_cuda1)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
|
|
@skipIfRocm
|
|
@enable_cpu_fuser
|
|
def test_kernel_cache_multi_gpu(self):
|
|
def not_fusible(x):
|
|
return x
|
|
|
|
def fn(x, y, z):
|
|
x_out = x * x * x * x * x # fusion: lambda x. x * x * x * x * x
|
|
y_out = y * y * y * y * y
|
|
z_out = z * z * z * z * z
|
|
return not_fusible(x_out), not_fusible(y_out), not_fusible(z_out)
|
|
|
|
inputs = [
|
|
torch.randn(4, 4, dtype=torch.float),
|
|
torch.randn(4, 4, dtype=torch.float, device='cuda:0'),
|
|
torch.randn(4, 4, dtype=torch.float, device='cuda:1'),
|
|
]
|
|
|
|
prev_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
|
|
|
|
# There are 3 FusionGroups. Because they have the same graph, they
|
|
# should reuse the same KernelSpec in the KernelSpec cache.
|
|
ge = self.checkScript(fn, inputs)
|
|
self.assertGraphContainsExactly(
|
|
ge.graph_for(*inputs), 'prim::FusionGroup', 3, True)
|
|
new_cache_size = torch._C._jit_debug_fuser_num_cached_kernel_specs()
|
|
# XXX: This assumes that the same kernel isn't already used by another test
|
|
self.assertEqual(new_cache_size - prev_cache_size, 1)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "needs non-zero device")
|
|
@skipIfRocm
|
|
def test_nonzero_device_cuda(self):
|
|
device = 'cuda:' + str(1)
|
|
x = torch.tensor([0.4], dtype=torch.float, device=device)
|
|
y = torch.tensor([0.7], dtype=torch.float, device=device)
|
|
|
|
def doit(x, y):
|
|
return torch.sigmoid(torch.tanh(x * (x + y) + x))
|
|
|
|
ge = self.checkTrace(doit, (x, y))
|
|
self.assertAllFused(ge.graph_for(x, y))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_lstm_cuda(self):
|
|
inputs = get_lstm_inputs('cuda', training=True)
|
|
module = self.checkScript(LSTMCellS, inputs)
|
|
forward_graph = module.graph_for(*inputs)
|
|
self.assertGraphContainsExactly(
|
|
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
|
|
self.assertTrue(len(list(forward_graph.nodes())) == 2)
|
|
# Everything is differentiable but TupleConstruct return
|
|
FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
|
|
.check_next("return").run(str(forward_graph))
|
|
|
|
hy, cy = module(*inputs)
|
|
(hy + cy).sum().backward()
|
|
backward = backward_graph(module)
|
|
FileCheck().check("FusionGroup_0").check_next("FusionGroup_1") \
|
|
.check_not("FusionGroup_2").run(str(backward))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_lstm_concat_cuda(self):
|
|
inputs = get_lstm_inputs('cuda')
|
|
ge = self.checkTrace(LSTMCellC, inputs)
|
|
graph = ge.graph_for(*inputs)
|
|
FileCheck().check("FusedConcat").check_next("return").run(str(graph))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_lstm_gates_permutations_cuda(self):
|
|
# lstm has gates = x.mm(w_ih.t()) + hx.mm(w_hh.t()) + b_ih + b_hh.
|
|
# Test that any permutation of this will still result in one FusionGroup.
|
|
choices = ['x.mm(w_ih.t())', 'hx.mm(w_hh.t())', 'b_ih', 'b_hh']
|
|
template = dedent('''
|
|
def cell(x, hx, cx, w_ih, w_hh, b_ih, b_hh):
|
|
gates = {} + {} + {} + {}
|
|
ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
|
|
return ingate * forgetgate * cellgate * outgate
|
|
''')
|
|
for permutation in itertools.permutations(choices, len(choices)):
|
|
code = template.format(*permutation)
|
|
scope = {}
|
|
exec(code, globals(), scope)
|
|
cu = torch.jit.CompilationUnit(code)
|
|
|
|
inputs = get_lstm_inputs('cuda', training=False)
|
|
self.assertEqual(cu.cell(*inputs), scope['cell'](*inputs))
|
|
forward_graph = cu.cell.graph_for(*inputs)
|
|
self.assertGraphContainsExactly(forward_graph, 'prim::FusionGroup', 1)
|
|
|
|
# TODO: Fuser doesn't work at all when inputs require grad. Fix that
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_lstm_traced_cuda(self):
|
|
inputs = get_lstm_inputs('cuda')
|
|
ge = self.checkTrace(LSTMCellF, inputs)
|
|
graph = ge.graph_for(*inputs)
|
|
FileCheck().check_not("Chunk").check_not("aten::add").check_not("aten::sigmoid") \
|
|
.check_not("aten::tanh").check("FusionGroup").check_next("TupleConstruct") \
|
|
.check_next("return").check_not("FusionGroup_1").run(str(graph))
|
|
|
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
|
@unittest.skip("Test is flaky, see https://github.com/pytorch/pytorch/issues/8746")
|
|
@enable_cpu_fuser
|
|
def test_lstm_traced_cpu(self):
|
|
inputs = get_lstm_inputs('cpu')
|
|
try:
|
|
ge = self.checkTrace(LSTMCellF, inputs)
|
|
graph = ge.graph_for(*inputs)
|
|
FileCheck.check("FusionGroup").run(str(graph))
|
|
except RuntimeError as e:
|
|
if 'Failed to compile' in e.args[0]:
|
|
warnings.warn('CPU fuser test has failed! This is not a hard failure, '
|
|
'because the kernels sometimes trigger bugs in compilers '
|
|
'(most notably GCC 7.2).')
|
|
raise unittest.SkipTest('Failed to compile')
|
|
else:
|
|
raise
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_milstm_cuda(self):
|
|
inputs = get_milstm_inputs('cuda', training=True)
|
|
module = self.checkScript(MiLSTMCell, inputs)
|
|
forward_graph = module.graph_for(*inputs)
|
|
self.assertGraphContainsExactly(
|
|
forward_graph, 'prim::FusionGroup', 1, consider_subgraphs=True)
|
|
FileCheck().check("DifferentiableGraph").check_next("TupleConstruct") \
|
|
.check_next("return").check("FusionGroup").run(str(forward_graph))
|
|
hy, cy = module(*inputs)
|
|
(hy + cy).sum().backward()
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_rand_cuda(self):
|
|
class M(torch.jit.ScriptModule):
|
|
__constants__ = ['d']
|
|
|
|
def __init__(self):
|
|
self.d = torch.device('cuda')
|
|
|
|
@torch.jit.script_method
|
|
def create(self, x):
|
|
return x * x + x + torch.rand_like(x)
|
|
|
|
x = torch.zeros([3, 4, 5], dtype=torch.float, device='cuda')
|
|
m = M()
|
|
out1 = m.create(x)
|
|
out2 = m.create(x)
|
|
self.assertNotEqual(out1, out2)
|
|
self.assertTrue(torch.all(out1 >= 0))
|
|
self.assertTrue(torch.all(out1 < 1))
|
|
self.assertTrue(torch.all(out2 >= 0))
|
|
self.assertTrue(torch.all(out2 < 1))
|
|
self.assertAllFused(m.create.graph_for(x))
|
|
|
|
@staticmethod
|
|
def fn_test_relu(x, y):
|
|
return F.relu(x + .5 * y)
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_relu_cuda(self):
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
|
|
ge = self.checkTrace(self.fn_test_relu, (x, y))
|
|
self.assertAllFused(ge.graph_for(x, y))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_erf_cuda(self):
|
|
def fn_test_erf(x):
|
|
return F.relu(torch.erf(x) - torch.erfc(x))
|
|
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
ge = self.checkTrace(fn_test_erf, (x,))
|
|
self.assertAllFused(ge.graph_for(x))
|
|
x.requires_grad_(True)
|
|
self.assertAllFused(ge.graph_for(x), except_for=("aten::size", "prim::BroadcastSizes"))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_rand_broadcast_cuda(self):
|
|
def fn_test_rand(x, y):
|
|
r = torch.rand_like(y)
|
|
return r * x + x
|
|
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
script_f = torch.jit.script(fn_test_rand, (x, y))
|
|
out = script_f(x, y)
|
|
self.assertAllFused(script_f.graph_for(x, y))
|
|
x.requires_grad_(True)
|
|
out = script_f(x, y)
|
|
self.assertAllFused(script_f.graph_for(x, y), except_for=("aten::size", "prim::BroadcastSizes"))
|
|
# test that broadcasting random produces correct results
|
|
x = torch.ones(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.ones(4, dtype=torch.float, device='cuda')
|
|
out = script_f(x, y)
|
|
self.assertEqual(out[0], out[1])
|
|
|
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
|
@enable_cpu_fuser
|
|
def test_scalar(self):
|
|
def fn(x, y):
|
|
return 2 * x + y
|
|
|
|
x = torch.tensor(0.1, dtype=torch.float, device='cpu')
|
|
y = torch.tensor(1, dtype=torch.float, device='cpu')
|
|
ge = self.checkScript(fn, (x, y))
|
|
self.assertAllFused(ge.graph_for(x, y))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_small_constant_cuda(self):
|
|
def fn_test_small_constant(x, y):
|
|
return (1e-8 * x + 5e-9 * y) * 1e8
|
|
x = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
y = torch.randn(4, 4, dtype=torch.float, device='cuda')
|
|
|
|
ge = self.checkTrace(fn_test_small_constant, (x, y))
|
|
self.assertAllFused(ge.graph_for(x, y))
|
|
|
|
@unittest.skipIf(IS_WINDOWS, "NYI: fuser support for Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
@skipIfRocm
|
|
def test_tensor_scalar_ops_cuda(self):
|
|
def should_fuse(x):
|
|
z = 3.
|
|
y = x + z
|
|
return x * y
|
|
|
|
# XXX: right now we only support fusing scalars if
|
|
# they're constant (#9940)
|
|
def should_not_fuse(x, z):
|
|
y = x + int(z)
|
|
return x * y
|
|
|
|
inputs = [torch.randn(2, 2, dtype=torch.float, device='cuda')]
|
|
ge = self.checkScript(should_fuse, inputs)
|
|
self.assertAllFused(ge.graph_for(*inputs))
|
|
|
|
inputs = [
|
|
torch.randn(2, 2, dtype=torch.float, device='cuda'),
|
|
torch.tensor(3., dtype=torch.float, device='cuda'),
|
|
]
|
|
ge = self.checkScript(should_not_fuse, inputs)
|
|
self.assertGraphContainsExactly(
|
|
ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
|
|
|
|
@unittest.skipIf(IS_WINDOWS or IS_SANDCASTLE, "NYI: fuser support for Windows or Sandcastle")
|
|
@enable_cpu_fuser
|
|
def test_where_and_typing(self):
|
|
def f(x, y):
|
|
mask = x > y
|
|
res = torch.where(mask, x, y)
|
|
return mask, res
|
|
|
|
script_f = torch.jit.script(f)
|
|
|
|
x = torch.randn(4, 4, dtype=torch.double)
|
|
y = torch.randn(4, 4, dtype=torch.double)
|
|
|
|
result1, result2 = script_f(x, y)
|
|
expected1, expected2 = f(x, y)
|
|
self.assertEqual(result1, expected1)
|
|
self.assertEqual(result2, expected2)
|
|
self.assertAllFused(script_f.graph_for(x, y), except_for={'prim::TupleConstruct'})
|
|
|
|
@unittest.skipIf(not IS_WINDOWS, "Test that the fuser is disabled on Windows")
|
|
@unittest.skipIf(not RUN_CUDA, "fuser requires CUDA")
|
|
def test_windows_cuda(self):
|
|
def scaleshift(x, scale, shift):
|
|
return x * scale + shift
|
|
|
|
inputs = [
|
|
torch.randn(4, 4, dtype=torch.float, device='cuda'),
|
|
torch.randn(4, dtype=torch.float, device='cuda'),
|
|
torch.randn(4, dtype=torch.float, device='cuda'),
|
|
]
|
|
|
|
ge = self.checkScript(scaleshift, inputs)
|
|
self.assertGraphContainsExactly(
|
|
ge.graph_for(*inputs), 'prim::FusionGroup', 0, consider_subgraphs=True)
|
|
|
|
|
|
# NB: torch.jit.script, when used as a function, uses the current scope
|
|
# to resolve variable names. This function cannot be made local to
|
|
# TestAutodiffSubgraphSlicing because those tests call torch.jit.script on functions
|
|
# in a different scope than they are defined in.
|
|
def pyfn(a, b):
|
|
return a * b
|
|
|
|
|
|
class TestAutodiffSubgraphSlicing(JitTestCase):
|
|
# TODO: It is better if we can test directly on graphs instead of the current
|
|
# end-to-end fashion.
|
|
def _perform_ad_subgraph_slicing(self, fn, *input_sizes):
|
|
ge = torch.jit.script(fn)
|
|
ge.debug_disable_autodiff_subgraph_inlining()
|
|
inputs = [torch.randn(size, requires_grad=True) for size in input_sizes]
|
|
ge(*inputs)
|
|
return ge.graph_for(*inputs)
|
|
|
|
def assertGraphSize(self, graph, size):
|
|
self.assertEqual(len(list(graph.nodes())), size)
|
|
|
|
def test_simple_merge(self):
|
|
# o --> o
|
|
def fn(x, y, z):
|
|
a = x * y
|
|
b = a * z
|
|
return b
|
|
|
|
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
|
|
|
|
self.assertGraphSize(graph, 1)
|
|
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
|
|
|
def test_simple_no_merge(self):
|
|
# o: autodiff supported. x: not autodiff supported.
|
|
# o --> x
|
|
def fn(x, y, z):
|
|
a = x * y
|
|
b = pyfn(a, z)
|
|
return b
|
|
|
|
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
|
|
|
|
self.assertGraphSize(graph, 2)
|
|
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
|
|
|
def test_does_not_merge_unrelated(self):
|
|
# o o
|
|
def fn(w, x, y, z):
|
|
a = x * y
|
|
b = w * z
|
|
return a, b
|
|
|
|
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
|
|
|
|
self.assertGraphSize(graph, 3)
|
|
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
|
|
|
|
def test_merges_without_cycles(self):
|
|
# o --> o --> o
|
|
# | ^
|
|
# \_________/
|
|
def fn(w, x, y):
|
|
a = w * x
|
|
b = a * y
|
|
c = a * b
|
|
return c
|
|
|
|
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
|
|
|
|
self.assertGraphSize(graph, 1)
|
|
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
|
|
|
def test_merges_dense(self):
|
|
# o o
|
|
# |\ /|
|
|
# | \ / |
|
|
# | /\ |
|
|
# vv vv
|
|
# o o
|
|
def fn(x, y):
|
|
a, b = x.chunk(2)
|
|
c, d = y.chunk(2)
|
|
return a + c, b + d
|
|
|
|
graph = self._perform_ad_subgraph_slicing(fn, 2, 2)
|
|
|
|
self.assertGraphSize(graph, 2)
|
|
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
|
|
|
def test_does_not_create_cycles(self):
|
|
# o --> x --> o
|
|
# | ^
|
|
# \_________/
|
|
def fn(w, x, y):
|
|
a = w * x
|
|
b = pyfn(a, y)
|
|
c = a * b
|
|
return c
|
|
|
|
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1)
|
|
|
|
self.assertGraphSize(graph, 3)
|
|
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
|
|
|
|
def test_merges_up(self):
|
|
# o --> x o
|
|
# | ^
|
|
# \_________/
|
|
def fn(w, x, y, z):
|
|
a = w * x
|
|
b = pyfn(a, y)
|
|
c = a * z
|
|
return b, c
|
|
|
|
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
|
|
|
|
self.assertGraphSize(graph, 3)
|
|
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
|
|
|
def test_merges_down(self):
|
|
# o x --> o
|
|
# | ^
|
|
# \_________/
|
|
def fn(v, w, x, y):
|
|
a = v * w
|
|
b = pyfn(x, y)
|
|
c = b * a
|
|
return a, c
|
|
|
|
graph = self._perform_ad_subgraph_slicing(fn, 1, 1, 1, 1)
|
|
|
|
self.assertGraphSize(graph, 3)
|
|
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 1)
|
|
|
|
def test_respects_lexical_scoping(self):
|
|
def fn(x, k):
|
|
y = x * 1.1
|
|
if bool(k):
|
|
k = k + y
|
|
z = y * k
|
|
return z, k
|
|
|
|
graph = self._perform_ad_subgraph_slicing(fn, 1, 1)
|
|
|
|
# We should not have combined the two multiplications into
|
|
# the same group; they should each be a separate DiffGraph
|
|
self.assertGraphContainsExactly(graph, 'prim::DifferentiableGraph', 2)
|
|
|
|
|
|
class TestCustomOperators(JitTestCase):
|
|
|
|
def test_dynamic_op_registry(self):
|
|
from torch._ops import _OpNamespace
|
|
self.assertTrue(hasattr(torch, 'ops'))
|
|
|
|
if '_test' in torch.ops.__dict__:
|
|
torch.ops.__dict__.pop('_test')
|
|
|
|
# Don't use `hasattr()` because it will call `__getattr__`.
|
|
self.assertNotIn('_test', torch.ops.__dict__)
|
|
torch.ops._test
|
|
self.assertIn('_test', torch.ops.__dict__)
|
|
self.assertEqual(type(torch.ops._test), _OpNamespace)
|
|
|
|
self.assertNotIn('leaky_relu', torch.ops._test.__dict__)
|
|
op = torch.ops._test.leaky_relu
|
|
self.assertTrue(callable(op))
|
|
self.assertIn('leaky_relu', torch.ops._test.__dict__)
|
|
op2 = torch.ops._test.leaky_relu
|
|
self.assertEqual(op, op2)
|
|
|
|
def test_simply_calling_an_operator(self):
|
|
input = torch.randn(100)
|
|
output = torch.ops.aten.relu(input)
|
|
self.assertEqual(output, input.relu())
|
|
|
|
def test_default_arguments_are_used(self):
|
|
output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
|
|
self.assertEqual(output, torch.tensor([-0.01, 1]))
|
|
|
|
def test_only_kwargs(self):
|
|
output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))
|
|
self.assertEqual(output, torch.tensor(-0.01))
|
|
|
|
def test_passing_too_many_args(self):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"aten::relu\(\) expected at most 1 argument\(s\) but received 2 argument\(s\)"
|
|
):
|
|
torch.ops.aten.relu(1, 2)
|
|
|
|
def test_passing_too_few_args(self):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"aten::relu\(\) is missing value for argument 'self'."
|
|
):
|
|
torch.ops.aten.relu()
|
|
|
|
def test_passing_one_positional_but_not_the_second(self):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
r"aten::transpose\(\) is missing value for argument 'dim0'."
|
|
):
|
|
torch.ops.aten.transpose(torch.ones(5, 5))
|
|
|
|
def test_passing_an_argument_both_as_positional_and_kwarg(self):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Argument 'self' specified both as positional and keyword argument"
|
|
):
|
|
torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5))
|
|
|
|
def test_passing_unknown_kwargs(self):
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"Unknown keyword argument 'foo' for operator '_test::leaky_relu'"
|
|
):
|
|
torch.ops._test.leaky_relu(torch.ones(5), foo=torch.ones(5))
|
|
|
|
def test_passing_and_returning_lists(self):
|
|
# Replace with actual test once we support lists.
|
|
a, b = torch.rand(5), torch.rand(5)
|
|
output = torch.ops._test.cat([a, b])
|
|
output_ref = torch.cat([a, b])
|
|
self.assertEqual(output, output_ref)
|
|
|
|
def test_calling_scripted_custom_op(self):
|
|
@torch.jit.script
|
|
def func(x):
|
|
return torch.ops.aten.relu(x)
|
|
input = torch.ones(5, 5)
|
|
self.assertEqual(func(input), input.relu())
|
|
|
|
def test_calling_traced_custom_op(self):
|
|
input = torch.ones(5, 5)
|
|
func = torch.jit.trace(torch.ops.aten.relu, [input])
|
|
self.assertEqual(func(input), input.relu())
|
|
|
|
def test_script_graph_for_custom_ops_matches_traced_graph(self):
|
|
input = torch.ones(5, 5)
|
|
trace = torch.jit.trace(torch.ops.aten.relu, [input])
|
|
self.assertExpectedInline(canonical(trace.graph), '''\
|
|
graph(%0 : Double(5, 5)):
|
|
%1 : Double(5, 5) = aten::relu(%0)
|
|
return (%1)
|
|
''')
|
|
|
|
def test_script_graph_contains_custom_op(self):
|
|
@torch.jit.script
|
|
def func(x):
|
|
return torch.ops.aten.relu(x)
|
|
self.assertExpectedInline(canonical(func.graph), '''\
|
|
graph(%x : Tensor):
|
|
%1 : Tensor = aten::relu(%x)
|
|
return (%1)
|
|
''')
|
|
|
|
def test_generic_list(self):
|
|
self.assertEqual(torch.ops._test.get_first([['hello']]), 'hello')
|
|
|
|
|
|
class TestJitGeneratedAutograd(JitTestCase):
|
|
pass
|
|
|
|
|
|
class TestJitGeneratedModule(JitTestCase):
|
|
pass
|
|
|
|
|
|
class TestJitGeneratedFunctional(JitTestCase):
|
|
pass
|
|
|
|
|
|
# UBSAN per-function exclusions don't seem to work with OpenMP pragmas,
|
|
# and we have to disable the failing tests here instead.
|
|
UBSAN_BLACKLISTED_TESTS = [
|
|
"test___rdiv___constant",
|
|
"test___rdiv___scalar_constant",
|
|
"test_addcdiv",
|
|
"test_addcdiv_broadcast_all",
|
|
"test_addcdiv_broadcast_rhs",
|
|
"test_addcdiv_scalar",
|
|
"test_addcdiv_scalar_broadcast_lhs",
|
|
"test_addcdiv_scalar_broadcast_rhs",
|
|
"test_addcdiv_scalar_scale",
|
|
"test_addcdiv_scalar_scale_broadcast_lhs",
|
|
"test_addcdiv_scalar_scale_broadcast_rhs",
|
|
"test_addcdiv_scale",
|
|
"test_addcdiv_scale_broadcast_all",
|
|
"test_addcdiv_scale_broadcast_rhs",
|
|
"test_add_broadcast_all",
|
|
"test_add_broadcast_lhs",
|
|
"test_add_broadcast_rhs",
|
|
"test_add_constant",
|
|
"test_add_scalar",
|
|
"test_add_scalar_broadcast_lhs",
|
|
"test_add_scalar_broadcast_rhs",
|
|
"test_div",
|
|
"test_div_broadcast_all",
|
|
"test_div_broadcast_lhs",
|
|
"test_div_broadcast_rhs",
|
|
"test_div_scalar",
|
|
"test_div_scalar_broadcast_lhs",
|
|
"test_div_scalar_broadcast_rhs",
|
|
"test_rsqrt",
|
|
"test_rsqrt_scalar",
|
|
"test_add",
|
|
"test_reciprocal",
|
|
"test_reciprocal_scalar",
|
|
]
|
|
|
|
L = 20
|
|
M = 10
|
|
S = 5
|
|
|
|
# module cannot be exported /imported currently
|
|
EXCLUDE_MODULE_EXPORT_IMPORT = {
|
|
'EmbeddingBag',
|
|
'MaxPool1d',
|
|
'MaxPool2d',
|
|
'MaxPool3d',
|
|
'AdaptiveAvgPool2d',
|
|
'AdaptiveAvgPool3d',
|
|
'Fold',
|
|
'Unfold',
|
|
}
|
|
|
|
# NB: JIT script tests for all nn functional interfaces, script mode does
|
|
# not support in_place operations yet, so no inplace operation tests added.
|
|
# removed all the deprecated functions
|
|
#
|
|
# (
|
|
# method name,
|
|
# input size/constructing fn,
|
|
# args (tuple represents shape of a tensor arg),
|
|
# test variant name(will be used at test name suffix,
|
|
# 'inplace' skips grad tests), // optional
|
|
# fn to determine if test should be skipped, // optional
|
|
# fn mapping output to part that should be gradcheck'ed, // optional
|
|
# kwargs for function, // optional
|
|
# )
|
|
nn_functional_tests = [
|
|
('conv1d', (S, S, S), ((S, S, S),)),
|
|
('conv2d', (S, S, S, S), ((S, S, S, S),)),
|
|
('conv3d', (S, S, S, S, S), ((S, S, S, S, S),)),
|
|
('conv_transpose1d', (S, S, S), ((S, S, S),)),
|
|
('conv_transpose2d', (S, S, S, S), ((S, S, S, S),)),
|
|
('conv_transpose3d', (S, S, S, S, S), ((S, S, S, S, S),)),
|
|
('conv_tbc', (S, S, S), ((S, S, S), (S,), 2)),
|
|
('avg_pool1d', (S, S, S), (3,)),
|
|
('avg_pool2d', (S, S, S, S), (3,)),
|
|
('avg_pool3d', (S, S, S, S, S), (3,)),
|
|
('fractional_max_pool2d', (S, S, S, S), (3, [2, 3],)),
|
|
('max_pool1d', (S, S, S), (2, 1)),
|
|
('max_pool1d', (S, S, S), (2, 1, 1, 1, False, True), 'with_indices'),
|
|
('max_pool2d', (S, S, S, S), (2, 1)),
|
|
('max_pool3d', (S, S, S, S, S), (2, 1)),
|
|
('max_unpool1d', torch.tensor([[[2., 4]]]), (torch.tensor([[[1, 3]]]), 2, 2, 0)),
|
|
('max_unpool2d', torch.tensor([[[[2., 4]]]]), (torch.tensor([[[[1, 3]]]]), 2, 2, 0)),
|
|
('max_unpool3d', torch.tensor([[[[[2., 4]]]]]), (torch.tensor([[[[[1, 3]]]]]), 2, 2, 0)),
|
|
('lp_pool1d', (S, S, S), (2., 3, 2,)),
|
|
('lp_pool2d', (S, S, S, S), (2., 3, 2,)),
|
|
('adaptive_max_pool1d', (S, S, S), (5,)),
|
|
('adaptive_max_pool2d', (S, S, S, S), ([5, 7],)),
|
|
('adaptive_max_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
|
|
('adaptive_avg_pool1d', (S, S, S), (5,)),
|
|
('adaptive_avg_pool2d', (S, S, S, S), ([5, 7],)),
|
|
('adaptive_avg_pool3d', (S, S, S, S, S), ([3, 2, 2],)),
|
|
('dropout', (S, S, S), (0.5,)),
|
|
('alpha_dropout', (S, S, S), (0.5,)),
|
|
('dropout2d', (S, S, S), (0.5,)),
|
|
('dropout3d', (S, S, S), (0.5,)),
|
|
('feature_alpha_dropout', (S, S, S), (0.5,)),
|
|
('threshold', (S, S, S), (0.1, 2.),),
|
|
('threshold', (S, S, S), (0.1, 2., True), 'inplace'),
|
|
('relu', (S, S, S), (),),
|
|
('relu', (S, S, S), (), 'inplace'),
|
|
('glu', (S - 1, S - 1, S - 1), (),),
|
|
('hardtanh', (S, S, S), (-0.5, 0.5),),
|
|
('hardtanh', (S, S, S), (-0.5, 0.5, True), 'inplace'),
|
|
('relu6', (S, S, S), (),),
|
|
('relu6', (S, S, S), (True), 'inplace'),
|
|
('elu', (S, S, S), (0.9,),),
|
|
('elu', (S, S, S), (0.9, True), 'inplace'),
|
|
('selu', (S, S, S), (),),
|
|
('selu', (S, S, S), (True), 'inplace'),
|
|
('celu', (S, S, S), (0.9,),),
|
|
('celu', (S, S, S), (0.9, True), 'inplace'),
|
|
('leaky_relu', (S, S, S), (0.02,),),
|
|
('leaky_relu', (S, S, S), (0.02,), 'inplace'),
|
|
('rrelu', (S, S), (0.1, 0.3, False),),
|
|
('rrelu', (S, S), (0.1, 0.3, False, True), 'inplace'),
|
|
('hardshrink', (S, S, S), (0.4,),),
|
|
('tanhshrink', (S, S, S), (),),
|
|
('softsign', (S, S, S), (),),
|
|
('softplus', (S, S, S), (),),
|
|
('softmin', (S, S, S), (0,),),
|
|
('softmax', (S, S, S), (0,),),
|
|
('softmax', (S, S, S), (0, 3, torch.double), 'with_all_args'),
|
|
('tanh', (S, S, S), (),),
|
|
('sigmoid', (S, S, S), (),),
|
|
('log_softmax', (S, S, S), (0,),),
|
|
('linear', (S, S), ((M, S),),),
|
|
('bilinear', (S, S, S), ((S, S, M), torch.zeros(M, S, M),),),
|
|
('embedding', torch.tensor([[1, 2, 4, 5], [4, 3, 2, 5]]), (torch.rand(6, 3), ),),
|
|
('embedding_bag', torch.tensor([1, 2, 4, 2]), (torch.rand(5, 3), torch.tensor([0, 4]),),),
|
|
('batch_norm', (S, S), (non_differentiable(torch.randn(S)), non_differentiable(torch.ones(S)), ),),
|
|
('instance_norm', (S, S, S), (non_differentiable(torch.zeros(S)), non_differentiable(torch.ones(S))),),
|
|
('layer_norm', (S, S, S, S), ([5],),),
|
|
('layer_norm', (S, S, S, S), ([5], (S,)), 'with_only_weight'),
|
|
('layer_norm', (S, S, S, S), ([5], None, (S,)), 'with_only_bias'),
|
|
('layer_norm', (S, S, S, S), ([5], (S,), (S,)), 'with_weight_and_bias'),
|
|
('group_norm', (S, S, S), (1, torch.rand(5),),),
|
|
('local_response_norm', (S, S, S), (2, ),),
|
|
('nll_loss', F.log_softmax(torch.randn(3, 5), dim=0), (torch.tensor([1, 0, 4]),),),
|
|
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2),),),
|
|
('poisson_nll_loss', torch.rand(S, 2), (torch.rand(S, 2), True, True), 'full'),
|
|
('kl_div', F.log_softmax(torch.randn(S, 10), 1), (F.softmax(torch.randn(S, 10), 1),),),
|
|
('cross_entropy', (3, S), (torch.randint(S, (3,), dtype=torch.int64),),),
|
|
('binary_cross_entropy_with_logits', (3,), (torch.empty(3).random_(2), ),),
|
|
('smooth_l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
|
('l1_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
|
('mse_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
|
('smooth_l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
|
('l1_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
|
('mse_loss', (3, S), ((torch.rand(3, S)),), 'with_grad'),
|
|
('margin_ranking_loss', (3, S), ((3, S), (S,)),),
|
|
('hinge_embedding_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
|
('soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
|
('multilabel_soft_margin_loss', (3, S), (non_differentiable(torch.rand(3, S)),),),
|
|
('cosine_embedding_loss', (S, S), ((S, S), non_differentiable(torch.rand(S,))),),
|
|
('pixel_shuffle', (1, 9, 4, 4), (3,),),
|
|
('affine_grid', (S, 2, 3), (torch.Size([S, 1, 7, 7]),),),
|
|
('pad', (3, 3, 4, 2), ([1, 1],),),
|
|
('pairwise_distance', (S, S), ((S, S),),),
|
|
('pdist', (S, S), (),),
|
|
('cosine_similarity', (S, S), ((S, S),),),
|
|
('triplet_margin_loss', (S, S), ((S, S), (S, S)),),
|
|
('normalize', (S, S, S), (),),
|
|
('unfold', (S, S, S, S), ([2, 3]),),
|
|
('fold', (1, 3 * 2 * 2, 12), ([4, 5], [2, 2]),),
|
|
('grid_sample', (S, S, S, S), (non_differentiable(torch.rand(S, S, S, 2)),),),
|
|
('gumbel_softmax', (S, S), (2.,),),
|
|
('gumbel_softmax', (S, S), (2., True,), 'hard'),
|
|
('multilabel_margin_loss', torch.tensor([[0.2, -0.2, 0.07]]), (torch.tensor([[0, 0, 1]]),),),
|
|
('multi_margin_loss', (S, S), (non_differentiable(torch.randint(S, (S, ), dtype=torch.int64)),
|
|
1, 1., non_differentiable(torch.randn(S))),),
|
|
('binary_cross_entropy', torch.randn(3, 2).sigmoid(), (non_differentiable(torch.rand(3, 2)),
|
|
non_differentiable(torch.randn(3, 2))),),
|
|
('binary_cross_entropy', torch.randn(3, 2).sigmoid(),
|
|
(non_differentiable(torch.rand(3, 2)),
|
|
non_differentiable(torch.randn(3, 2)), None, None, 'mean'), 'size_average'),
|
|
('ctc_loss', torch.rand(S, S, S).log_softmax(2).detach().requires_grad_(),
|
|
(torch.randint(1, S, (S, S), dtype=torch.long), torch.full((S,), S, dtype=torch.long),
|
|
torch.randint(1, S, (S,), dtype=torch.long))),
|
|
('upsample', torch.randn(S, S, M, M), (None, 2), 'with_scale'),
|
|
('upsample', torch.randn(S, S, M, M), (4,), 'with_size'),
|
|
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'nearest_4d'),
|
|
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'nearest_4d_with_scale'),
|
|
('interpolate', torch.randn(S, S, M, M), (4,), 'nearest_4d_with_size'),
|
|
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'area_4d'),
|
|
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'area_4d_with_scale'),
|
|
('interpolate', torch.randn(S, S, M, M), (4,), 'area_4d_with_size'),
|
|
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bilinear_4d'),
|
|
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bilinear_4d_with_scale'),
|
|
('interpolate', torch.randn(S, S, M, M), (4,), 'bilinear_4d_with_size'),
|
|
('interpolate', torch.zeros(3, 3).view(1, 1, 3, 3), (2,), 'bicubic_4d'),
|
|
('interpolate', torch.randn(S, S, M, M), (None, 2.), 'bicubic_4d_with_scale'),
|
|
('interpolate', torch.randn(S, S, M, M), (4,), 'bicubic_4d_with_size'),
|
|
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'nearest_3d'),
|
|
('interpolate', torch.randn(S, M, M), (None, 2.), 'nearest_3d_with_scale'),
|
|
('interpolate', torch.randn(S, M, M), (4,), 'nearest_3d_with_size'),
|
|
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'area_3d'),
|
|
('interpolate', torch.randn(S, M, M), (None, 2.), 'area_3d_with_scale'),
|
|
('interpolate', torch.randn(S, M, M), (4,), 'area_3d_with_size'),
|
|
('interpolate', torch.zeros(3, 3).view(1, 3, 3), (2,), 'linear_3d'),
|
|
('interpolate', torch.randn(S, M, M), (None, 2.), 'linear_3d_with_scale'),
|
|
('interpolate', torch.randn(S, M, M), (4,), 'linear_3d_with_size'),
|
|
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'nearest_5d_with_scale'),
|
|
('interpolate', torch.randn(S, M, M, M, M), (4,), 'nearest_5d_with_size'),
|
|
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'area_5d'),
|
|
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'area_5d_with_scale'),
|
|
('interpolate', torch.randn(S, M, M, M, M), (4,), 'area_5d_with_size'),
|
|
('interpolate', torch.zeros(3, 3, 3).view(1, 1, 3, 3, 3), (2,), 'trilinear_5d'),
|
|
('interpolate', torch.randn(S, M, M, M, M), (None, 2.), 'trilinear_5d_with_scale'),
|
|
('interpolate', torch.randn(S, M, M, M, M), (4,), 'trilinear_5d_with_size'),
|
|
]
|
|
|
|
|
|
# Test names in this set are only checked for a single derivative
|
|
nn_functional_single_grad = frozenset('test_nn_' + name for name in [
|
|
'pdist',
|
|
'multilabel_margin_loss',
|
|
'max_unpool3d',
|
|
'multi_margin_loss',
|
|
'binary_cross_entropy',
|
|
'binary_cross_entropy_size_average',
|
|
'ctc_loss',
|
|
'grid_sample',
|
|
])
|
|
|
|
# additional modules test
|
|
# TODO: delete this list once we make all nn_tests work
|
|
additional_module_tests = [
|
|
{
|
|
'module_name': 'Bilinear',
|
|
'constructor_args': (S, S, M),
|
|
'input_size': (S, S),
|
|
'extra_args': ((S, S),)
|
|
},
|
|
{
|
|
'module_name': 'RNNCell',
|
|
'constructor_args': (S, S),
|
|
'input_size': (S, S),
|
|
},
|
|
{
|
|
'module_name': 'LSTMCell',
|
|
'constructor_args': (S, S),
|
|
'input_size': (S, S),
|
|
},
|
|
{
|
|
'module_name': 'GRUCell',
|
|
'constructor_args': (S, S),
|
|
'input_size': (S, S),
|
|
},
|
|
]
|
|
|
|
|
|
def add_autograd_test(
|
|
name,
|
|
self_size,
|
|
args,
|
|
variant_name='',
|
|
dim_args_idx=(),
|
|
skipTestIf=(),
|
|
output_process_fn=lambda x: x,
|
|
kwargs=None):
|
|
basic_test_name = 'test_' + name
|
|
if variant_name != '':
|
|
basic_test_name += '_' + variant_name
|
|
|
|
for dim_perm in product([-1, 1], repeat=len(dim_args_idx)):
|
|
test_name = basic_test_name
|
|
new_args = [arg * dim_perm[dim_args_idx.index(i)] if i in dim_args_idx else arg for i, arg in enumerate(args)]
|
|
test_name = basic_test_name + ''.join('_neg' + str(i) for i, idx in enumerate(dim_perm) if idx < 0)
|
|
new_args = tuple(new_args)
|
|
|
|
# for-loop bodies don't define scopes, so we have to save the variables
|
|
# we want to close over in some way
|
|
def do_test(self, name=name, self_size=self_size, args=new_args, test_name=test_name,
|
|
output_process_fn=output_process_fn):
|
|
def check(name):
|
|
set_rng_seed(2)
|
|
is_magic_method = name[:2] == '__' and name[-2:] == '__'
|
|
is_inplace = name[-1] == "_" and not is_magic_method
|
|
self_variable = create_input((self_size,))[0][0]
|
|
# FixMe: run grad checks on inplace self
|
|
if is_inplace:
|
|
self_variable.requires_grad = False
|
|
# need to record this because methods can change the size (e.g. unsqueeze)
|
|
args_variable, kwargs_variable = create_input(args, requires_grad=not is_inplace, call_kwargs=kwargs)
|
|
self_tensor = deepcopy(self_variable.data)
|
|
args_tensor = deepcopy(unpack_variables(args_variable))
|
|
|
|
def fn(*inputs, **kwargs):
|
|
output = getattr(inputs[0], name)(*inputs[1:], **kwargs)
|
|
return output_process_fn(output)
|
|
|
|
check_types = test_name not in EXCLUDE_TYPE_CHECK
|
|
|
|
if not is_inplace and name not in EXCLUDE_GRADCHECK and not exclude_tensor_method(name, test_name):
|
|
# Test with disable_autodiff_subgraph_inlining, which forces the graph
|
|
# to contain DifferentiableGraph nodes whenever possible. This allows us
|
|
# to test autodiff; we assume that autograd is correct and use autodiff for backprop
|
|
if test_name not in EXCLUDE_TRACED:
|
|
check_against_reference(self,
|
|
create_traced_fn(self, fn,
|
|
disable_autodiff_subgraph_inlining=True),
|
|
fn, (self_variable,) + args_variable, kwargs_variable,
|
|
check_types=check_types)
|
|
|
|
if not is_magic_method and test_name not in EXCLUDE_SCRIPT:
|
|
check_against_reference(self,
|
|
create_script_fn(self, name, 'method', output_process_fn,
|
|
disable_autodiff_subgraph_inlining=True),
|
|
fn, (self_variable,) + args_variable, kwargs_variable,
|
|
check_types=check_types)
|
|
|
|
# functional interface tests
|
|
if hasattr(torch, name) and name not in EXCLUDE_FUNCTIONAL:
|
|
def fn(*inputs, **kwargs):
|
|
output = getattr(torch, name)(*inputs, **kwargs)
|
|
return output_process_fn(output)
|
|
|
|
f_args_variable = (self_variable,) + args_variable
|
|
f_args_tensor = (self_tensor,) + args_tensor
|
|
|
|
if not is_inplace and test_name not in EXCLUDE_TRACED:
|
|
check_against_reference(self,
|
|
create_traced_fn(self, fn,
|
|
disable_autodiff_subgraph_inlining=True),
|
|
fn, f_args_variable, kwargs_variable, check_types=check_types)
|
|
|
|
if not is_inplace and test_name not in EXCLUDE_SCRIPT:
|
|
check_against_reference(self,
|
|
create_script_fn(self, name, 'functional', output_process_fn,
|
|
disable_autodiff_subgraph_inlining=True),
|
|
fn, f_args_variable, kwargs_variable,
|
|
check_types=check_types)
|
|
|
|
# alias annotation testing
|
|
if is_inplace and test_name not in EXCLUDE_SCRIPT:
|
|
check_alias_annotation(name, (self_variable,) + args_variable, kwargs_variable)
|
|
|
|
check(name)
|
|
inplace_name = name + '_'
|
|
# can't broadcast inplace to left hand side
|
|
broadcast_skip_inplace = 'broadcast_lhs' in test_name or 'broadcast_all' in test_name
|
|
if hasattr(torch.ones(1), inplace_name) and not broadcast_skip_inplace:
|
|
check(inplace_name)
|
|
|
|
post_add_test(test_name, skipTestIf, do_test, TestJitGeneratedAutograd)
|
|
|
|
|
|
def suppress_warnings(fn):
|
|
@wraps(fn)
|
|
def wrapper(*args, **kwargs):
|
|
with warnings.catch_warnings(record=True):
|
|
return fn(*args, **kwargs)
|
|
return wrapper
|
|
|
|
|
|
def add_nn_functional_test(name, self_size, args, variant_name='', skipTestIf=(),
|
|
output_process_fn=lambda x: x, kwargs=None):
|
|
test_name = 'test_nn_' + name
|
|
|
|
if variant_name != '':
|
|
test_name = test_name + '_' + variant_name
|
|
|
|
no_grad = variant_name == 'inplace'
|
|
|
|
@suppress_warnings
|
|
def do_test(self, name=name, args=args, test_name=test_name):
|
|
torch.manual_seed(2)
|
|
|
|
self_variable = create_input((self_size,))[0][0]
|
|
|
|
# need to record this because methods can change the size (e.g. unsqueeze)
|
|
args_variable, kwargs_variable = create_input(args, call_kwargs=kwargs)
|
|
|
|
self_tensor = deepcopy(self_variable.data)
|
|
args_tensor = deepcopy(unpack_variables(args_variable))
|
|
|
|
if not no_grad:
|
|
output_variable = getattr(F, name)(self_variable, *args_variable, **kwargs_variable)
|
|
|
|
def fn(*inputs, **kwargs):
|
|
output = getattr(F, name)(*inputs, **kwargs)
|
|
return output_process_fn(output)
|
|
|
|
f_args_variable = (self_variable,) + args_variable
|
|
f_args_tensor = (self_tensor,) + args_tensor
|
|
|
|
if test_name not in EXCLUDE_SCRIPT:
|
|
disable_ad_subgraph_inlining = test_name in DISABLE_AUTODIFF_SUBGRAPH_INLINING
|
|
|
|
def run_test():
|
|
script_fn = create_script_fn(self, name, 'nn_functional', output_process_fn,
|
|
disable_autodiff_subgraph_inlining=disable_ad_subgraph_inlining)
|
|
check_against_reference(self, script_fn, fn, f_args_variable, kwargs_variable, no_grad=no_grad)
|
|
|
|
if test_name in EXCLUDE_PYTHON_PRINT:
|
|
with self.disableModuleHook():
|
|
run_test()
|
|
else:
|
|
run_test()
|
|
|
|
post_add_test(test_name, skipTestIf, do_test, TestJitGeneratedFunctional)
|
|
|
|
|
|
def add_nn_module_test(*args, **kwargs):
|
|
if 'module_name' in kwargs:
|
|
name = kwargs['module_name']
|
|
elif 'fullname' in kwargs:
|
|
name = kwargs['fullname']
|
|
elif 'constructor' in kwargs:
|
|
name = kwargs['constructor'].__name__
|
|
|
|
no_grad = False if 'no_grad' not in kwargs else kwargs['no_grad']
|
|
|
|
module_name = name.split("_")[0]
|
|
|
|
module = getattr(torch.nn, module_name, None)
|
|
if module is None or torch._jit_internal.weak_types.get(module) is None:
|
|
return
|
|
|
|
if 'desc' in kwargs and 'eval' in kwargs['desc']:
|
|
# eval() is not supported, so skip these tests
|
|
return
|
|
|
|
test_name = name
|
|
if 'desc' in kwargs:
|
|
test_name = "{}_{}".format(test_name, kwargs['desc'])
|
|
test_name = 'test_nn_{}'.format(test_name)
|
|
|
|
@suppress_warnings
|
|
def do_test(self):
|
|
if test_name in EXCLUDE_SCRIPT_MODULES:
|
|
return
|
|
if 'constructor' in kwargs:
|
|
nn_module = kwargs['constructor']
|
|
else:
|
|
nn_module = getattr(torch.nn, name)
|
|
|
|
if "FunctionalModule" in str(nn_module):
|
|
return
|
|
|
|
if 'constructor_args_fn' in kwargs:
|
|
constructor_args = kwargs['constructor_args_fn']()
|
|
else:
|
|
constructor_args = kwargs.get('constructor_args', ())
|
|
|
|
# Construct a script module that passes arguments through
|
|
# to self.submodule
|
|
def create_script_module(*args, **kwargs):
|
|
formals, tensors, actuals = get_script_args(args)
|
|
|
|
method_args = ', '.join(['self'] + actuals)
|
|
call_args_str = ', '.join(actuals)
|
|
call = "self.submodule({})".format(call_args_str)
|
|
script = script_method_template.format(method_args, call)
|
|
|
|
submodule_constants = []
|
|
if kwargs.get('is_constant'):
|
|
submodule_constants = ['submodule']
|
|
|
|
# Create module to use the script method
|
|
class TheModule(torch.jit.ScriptModule):
|
|
__constants__ = submodule_constants
|
|
|
|
def __init__(self):
|
|
super(TheModule, self).__init__()
|
|
self.submodule = nn_module(*constructor_args)
|
|
# module cannot be imported / exported
|
|
if module_name in EXCLUDE_MODULE_EXPORT_IMPORT:
|
|
with self.disableModuleHook():
|
|
module = TheModule()
|
|
module.define(script)
|
|
create_script_module.last_graph = module.graph
|
|
mod = module(*args)
|
|
else:
|
|
module = TheModule()
|
|
module.define(script)
|
|
self.assertExportImportModule(module, tensors)
|
|
create_script_module.last_graph = module.graph
|
|
mod = module(*args)
|
|
return mod
|
|
|
|
# Construct a normal nn module to stay consistent with create_script_module
|
|
# and make use of a single global rng_state in module initialization
|
|
def create_nn_module(*args, **kwargs):
|
|
module = nn_module(*constructor_args)
|
|
return module(*args)
|
|
|
|
# Set up inputs from tuple of sizes or constructor fn
|
|
if 'input_fn' in kwargs:
|
|
input = kwargs['input_fn']()
|
|
else:
|
|
input = (kwargs['input_size'],)
|
|
|
|
# Extra parameters to forward()
|
|
if 'extra_args' in kwargs:
|
|
input = input + kwargs['extra_args']
|
|
|
|
if 'target_size' in kwargs:
|
|
input = input + (kwargs['target_size'],)
|
|
elif 'target_fn' in kwargs:
|
|
if torch.is_tensor(input):
|
|
input = (input,)
|
|
input = input + (kwargs['target_fn'](),)
|
|
|
|
args_variable, kwargs_variable = create_input(input)
|
|
f_args_variable = deepcopy(unpack_variables(args_variable))
|
|
|
|
# Check against Python module as reference
|
|
check_against_reference(self, create_script_module, create_nn_module, f_args_variable, no_grad=no_grad)
|
|
|
|
post_add_test(test_name, (), do_test, TestJitGeneratedModule)
|
|
|
|
|
|
def post_add_test(test_name, skipTestIf, do_test, test_class):
|
|
assert not hasattr(test_class, test_name), 'Two tests have the same name: ' + test_name
|
|
|
|
for skip in skipTestIf:
|
|
do_test = skip(do_test)
|
|
|
|
if not (TEST_WITH_UBSAN and test_name in UBSAN_BLACKLISTED_TESTS):
|
|
setattr(test_class, test_name, do_test)
|
|
|
|
|
|
class TestAsync(JitTestCase):
|
|
def test_async_python(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
return torch.neg(x)
|
|
|
|
x = torch.rand(3, 4)
|
|
fut = torch.jit._fork(foo, x)
|
|
y_hat = foo(x)
|
|
y = torch.jit._wait(fut)
|
|
# assert nothing; only to make sure the fake python path works
|
|
|
|
def test_async_parsing(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
# type: (Tensor) -> List[Tensor]
|
|
return [torch.neg(x), x.t()]
|
|
|
|
@torch.jit.script
|
|
def bar(x):
|
|
futures = torch.jit.annotate(List[Future[List[Tensor]]], [])
|
|
for _ in range(3):
|
|
future = torch.jit.annotate(
|
|
Future[List[Tensor]],
|
|
torch.jit._fork(foo, x)
|
|
)
|
|
futures.append(future)
|
|
|
|
output = torch.jit.annotate(List[List[Tensor]], [])
|
|
for i in range(3):
|
|
output.append(torch.jit._wait(futures[i]))
|
|
return output
|
|
|
|
x = torch.rand(3, 3)
|
|
result = bar(x)
|
|
self.assertEqual(len(result), 3)
|
|
|
|
def test_async_script(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
return torch.neg(x), x
|
|
|
|
x = torch.rand(3, 4)
|
|
|
|
@torch.jit.script
|
|
def wait_script(x):
|
|
fut = torch.jit._fork(foo, x)
|
|
y_hat = foo(x)
|
|
y = torch.jit._wait(fut)
|
|
return y, y_hat
|
|
|
|
y, y_hat = wait_script(x)
|
|
|
|
self.assertEqual(y, y_hat)
|
|
|
|
def test_async_script_capture(self):
|
|
class Mod(torch.jit.ScriptModule):
|
|
__constants__ = ['const']
|
|
|
|
def __init__(self):
|
|
super(Mod, self).__init__(False)
|
|
self.const = 42
|
|
self.param = nn.Parameter(torch.randn(2, 2))
|
|
|
|
@torch.jit.script_method
|
|
def foo(self, x1, x2):
|
|
return torch.neg(x1), self.param, self.const, torch.neg(x2), self.param
|
|
|
|
@torch.jit.script_method
|
|
def wait_script(self, x1, x2):
|
|
fut = torch.jit._fork(self.foo, x1, x2)
|
|
y_hat = self.foo(x1, x2)
|
|
y = torch.jit._wait(fut)
|
|
return y, y_hat
|
|
|
|
x1 = torch.rand(3, 4)
|
|
x2 = torch.rand(5, 6)
|
|
|
|
m = Mod()
|
|
y, y_hat = m.wait_script(x1, x2)
|
|
|
|
self.assertEqual(y, y_hat)
|
|
|
|
def test_async_script_nested(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
return torch.neg(x), x
|
|
|
|
x = torch.rand(3, 4)
|
|
|
|
@torch.jit.script
|
|
def wait_script(x):
|
|
fut = torch.jit._fork(foo, x)
|
|
y_hat = foo(x)
|
|
y = torch.jit._wait(fut)
|
|
return y, y_hat
|
|
|
|
@torch.jit.script
|
|
def wait_script_nest(x):
|
|
fut = torch.jit._fork(wait_script, x)
|
|
return torch.jit._wait(fut)
|
|
|
|
y, y_hat = wait_script_nest(x)
|
|
|
|
self.assertEqual(y, y_hat)
|
|
|
|
def test_async_script_no_script_mod(self):
|
|
x = torch.rand(3, 4)
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'cannot call a value'):
|
|
@torch.jit.script
|
|
def wait_script(x):
|
|
fut = torch.jit._fork(x)
|
|
return fut
|
|
|
|
def test_async_script_multi_waits(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
return torch.neg(x).t() + x
|
|
|
|
@torch.jit.script
|
|
def wait_script(x):
|
|
fut = torch.jit._fork(foo, x)
|
|
|
|
# wait twice on the same future
|
|
y1 = torch.jit._wait(fut)
|
|
y2 = torch.jit._wait(fut)
|
|
return y1, y2
|
|
|
|
x = torch.rand(2, 2)
|
|
y1, y2 = wait_script(x)
|
|
self.assertEqual(y1, y2)
|
|
|
|
def test_async_script_multi_forks(self):
|
|
@torch.jit.script
|
|
def foo1(x):
|
|
return torch.neg(x).t() + x
|
|
|
|
@torch.jit.script
|
|
def foo2(x, y):
|
|
return torch.neg(x).t() + x + torch.neg(y).t()
|
|
|
|
@torch.jit.script
|
|
def foo3(x, y, z):
|
|
return torch.neg(z).t() + y.t() + x
|
|
|
|
x1 = torch.rand(10, 10)
|
|
x2 = torch.rand(10, 10)
|
|
x3 = torch.rand(10, 10)
|
|
|
|
@torch.jit.script
|
|
def wait_script(x1, x2, x3):
|
|
f1 = torch.jit._fork(foo1, x1)
|
|
f2 = torch.jit._fork(foo2, x1, x2)
|
|
f3 = torch.jit._fork(foo3, x1, x2, x3)
|
|
f4 = torch.jit._fork(foo1, x2)
|
|
f5 = torch.jit._fork(foo2, x2, x3)
|
|
|
|
# ignore some forks
|
|
y1 = torch.jit._wait(f1)
|
|
y2 = torch.jit._wait(f2)
|
|
y3 = torch.jit._wait(f3)
|
|
|
|
return y1, y2, y3
|
|
|
|
y1, y2, y3 = wait_script(x1, x2, x3)
|
|
self.assertEqual(y1, foo1(x1))
|
|
self.assertEqual(y2, foo2(x1, x2))
|
|
self.assertEqual(y3, foo3(x1, x2, x3))
|
|
|
|
def test_async_script_trace(self):
|
|
class Traced(nn.Module):
|
|
def __init__(self):
|
|
super(Traced, self).__init__()
|
|
|
|
def forward(self, x):
|
|
return (torch.neg(x), x)
|
|
|
|
class Mod(torch.jit.ScriptModule):
|
|
def __init__(self):
|
|
super(Mod, self).__init__(False)
|
|
x = torch.rand(3, 3)
|
|
self.traced = torch.jit.trace(Traced(), (x), _force_outplace=True)
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, x):
|
|
# type: (Tensor) -> Tuple[List[Tensor], Tuple[Tensor, Tensor], Tensor]
|
|
future1 = torch.jit._fork(self.traced, x)
|
|
future2 = torch.jit._fork(torch.neg, x)
|
|
|
|
tensor_tuple = torch.jit._wait(future1)
|
|
tensor_single = torch.jit._wait(future2)
|
|
|
|
tensor_list = []
|
|
tensor_list.append(tensor_tuple[0])
|
|
tensor_list.append(tensor_single)
|
|
|
|
# return a nested structure of tensors
|
|
return (tensor_list, tensor_tuple, tensor_tuple[1])
|
|
|
|
class TupleCl(nn.Module):
|
|
def __init__(self):
|
|
super(TupleCl, self).__init__()
|
|
self.module = Mod()
|
|
|
|
def forward(self, x):
|
|
z = torch.neg(x)
|
|
y = self.module(x)
|
|
list = [z, y[0][0], y[0][1], y[1][0], y[1][1], y[2]]
|
|
return tuple(list)
|
|
|
|
x = torch.rand(3, 3)
|
|
module = torch.jit.trace(TupleCl(), (x), _force_outplace=True)
|
|
|
|
# Make sure we have forks
|
|
self.assertGraphContainsExactly(module.graph, kind='prim::fork', num_kind_nodes=2)
|
|
# Make sure 1 ::neg is in the root graph and 2 ::negs are in the subgraphs
|
|
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=1)
|
|
self.assertGraphContainsExactly(module.graph, kind='aten::neg', num_kind_nodes=3, consider_subgraphs=True)
|
|
|
|
y = torch.neg(x)
|
|
self.assertEqual(module(x), (y, y, y, y, x, x))
|
|
|
|
def test_async_script_error(self):
|
|
x = torch.rand(3, 4)
|
|
|
|
@torch.jit.script
|
|
def foo(x):
|
|
# error here
|
|
return x.t() + x
|
|
|
|
@torch.jit.script
|
|
def wait_script(x):
|
|
fut = torch.jit._fork(foo, x)
|
|
return torch.jit._wait(fut)
|
|
|
|
@torch.jit.script
|
|
def wait_script_nest(x):
|
|
fut = torch.jit._fork(wait_script, x)
|
|
return torch.jit._wait(fut)
|
|
|
|
# no future
|
|
error_msg = 'The size.*must match the size of tensor'
|
|
with self.assertRaisesRegex(Exception, error_msg):
|
|
foo(x)
|
|
|
|
# one future
|
|
with self.assertRaisesRegex(Exception, error_msg):
|
|
wait_script(x)
|
|
|
|
# two futures with a different error
|
|
x = torch.rand(3, 4, 5)
|
|
with self.assertRaisesRegex(Exception, 'expects a tensor with <= 2 dimensions'):
|
|
wait_script_nest(x)
|
|
|
|
def test_async_grad_guard_with_grad(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
y = x * 2
|
|
return y.requires_grad
|
|
|
|
@torch.jit.script
|
|
def bar(x):
|
|
fut = torch.jit._fork(foo, x)
|
|
requires_grad_in_fork = torch.jit._wait(fut)
|
|
z = x * 2
|
|
return (requires_grad_in_fork, z.requires_grad)
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
|
|
with torch.enable_grad():
|
|
(inside_fork, after_wait) = bar(x)
|
|
|
|
self.assertEqual(inside_fork, True)
|
|
self.assertEqual(after_wait, True)
|
|
|
|
def test_async_grad_guard_no_grad(self):
|
|
@torch.jit.script
|
|
def foo(x):
|
|
y = x * 2
|
|
return y.requires_grad
|
|
|
|
@torch.jit.script
|
|
def bar(x):
|
|
fut = torch.jit._fork(foo, x)
|
|
requires_grad_in_fork = torch.jit._wait(fut)
|
|
z = x * 2
|
|
return (requires_grad_in_fork, z.requires_grad)
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
|
|
with torch.no_grad():
|
|
(inside_fork, after_wait) = bar(x)
|
|
|
|
self.assertEqual(inside_fork, False)
|
|
self.assertEqual(after_wait, False)
|
|
|
|
def test_trace_fork_wait(self):
|
|
def fork_body(x):
|
|
return x.neg(), x.neg() + 1
|
|
|
|
def fn(x):
|
|
fut = torch.jit._fork(fork_body, x)
|
|
vals = torch.jit._wait(fut)
|
|
return vals[0], vals[1], x - 1
|
|
|
|
traced = torch.jit.trace(fn, (torch.rand(3, 4),))
|
|
x = torch.rand(3, 4)
|
|
self.assertEqual(fn(x), traced(x))
|
|
|
|
self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=1)
|
|
self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=1)
|
|
self.assertGraphContainsExactly(traced.graph, kind='aten::neg', num_kind_nodes=2, consider_subgraphs=True)
|
|
|
|
def test_trace_fork_wait_leaking(self):
|
|
my_list = []
|
|
|
|
def fork_body(x):
|
|
my_list.append(x + 1)
|
|
return x + 1
|
|
|
|
def fn(x):
|
|
fut = torch.jit._fork(fork_body, x)
|
|
val = torch.jit._wait(fut)
|
|
return my_list[0]
|
|
|
|
with self.assertRaisesRegex(RuntimeError, 'did not have observable data dependence with trace inputs; '
|
|
'this probably indicates your program cannot be understood '
|
|
'by the tracer.'):
|
|
traced = torch.jit.trace(fn, (torch.rand(3, 4),), check_trace=False)
|
|
|
|
def test_trace_fork_wait_inline(self):
|
|
def fork_body(x):
|
|
return x + 1, x + 2
|
|
|
|
def fn(x):
|
|
fut = torch.jit._fork(fork_body, x)
|
|
val = torch.jit._wait(fut)
|
|
return val[1]
|
|
|
|
traced = torch.jit.trace(fn, (torch.rand(3, 4),))
|
|
torch._C._jit_pass_inline_fork_wait(traced.graph)
|
|
torch._C._jit_pass_dce(traced.graph)
|
|
self.assertGraphContainsExactly(traced.graph, kind='prim::fork', num_kind_nodes=0)
|
|
self.assertGraphContainsExactly(traced.graph, kind='aten::wait', num_kind_nodes=0)
|
|
self.assertGraphContainsExactly(traced.graph, kind='aten::add', num_kind_nodes=2)
|
|
|
|
def test_trace_fork_wait_inline_onnx(self):
|
|
def fork_body(x):
|
|
return torch.neg(x), torch.neg(x)
|
|
|
|
class MyMod(torch.nn.Module):
|
|
def forward(self, x):
|
|
fut = torch.jit._fork(fork_body, x)
|
|
val = torch.jit._wait(fut)
|
|
return val[1]
|
|
|
|
# smoke test for ONNX export
|
|
f = io.BytesIO()
|
|
torch.onnx.export(MyMod(), (torch.rand(3, 4),), f)
|
|
|
|
def test_save_load_with_extra_files(self):
|
|
class MyMod(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, a):
|
|
return a
|
|
|
|
expected_extra_files = torch._C.ExtraFilesMap()
|
|
expected_extra_files['foo'] = 'bar'
|
|
m = MyMod()
|
|
|
|
# Save to file.
|
|
with TemporaryFileName() as fname:
|
|
m.save(fname, _extra_files=expected_extra_files)
|
|
extra_files = torch._C.ExtraFilesMap()
|
|
extra_files['foo'] = ''
|
|
torch.jit.load(fname, _extra_files=extra_files)
|
|
self.assertEqual('bar', extra_files['foo'])
|
|
|
|
# Use torch.jit API
|
|
torch.jit.save(m, fname, _extra_files=expected_extra_files)
|
|
extra_files['foo'] = ''
|
|
torch.jit.load(fname, _extra_files=extra_files)
|
|
self.assertEqual('bar', extra_files['foo'])
|
|
|
|
# Save to buffer.
|
|
buffer = io.BytesIO(m.save_to_buffer(_extra_files=expected_extra_files))
|
|
extra_files = torch._C.ExtraFilesMap()
|
|
extra_files['foo'] = ''
|
|
torch.jit.load(buffer, _extra_files=extra_files)
|
|
self.assertEqual('bar', extra_files['foo'])
|
|
|
|
# Use torch.jit API
|
|
buffer = io.BytesIO()
|
|
torch.jit.save(m, buffer, _extra_files=expected_extra_files)
|
|
buffer.seek(0)
|
|
extra_files = torch._C.ExtraFilesMap()
|
|
extra_files['foo'] = ''
|
|
torch.jit.load(buffer, _extra_files=extra_files)
|
|
self.assertEqual('bar', extra_files['foo'])
|
|
|
|
# Non-existent file 'bar'
|
|
with self.assertRaises(RuntimeError):
|
|
extra_files['bar'] = ''
|
|
torch.jit.load(buffer, _extra_files=extra_files)
|
|
|
|
|
|
class TestDataParallel(JitTestCase):
|
|
class Mpy(torch.nn.Module):
|
|
def __init__(self):
|
|
super(TestDataParallel.Mpy, self).__init__()
|
|
self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
|
|
nn.ReLU(), nn.Linear(2, 2))
|
|
|
|
def forward(self, input):
|
|
return self.m(input)
|
|
|
|
class Mpy1(torch.nn.Module):
|
|
def __init__(self, block):
|
|
super(TestDataParallel.Mpy1, self).__init__()
|
|
self.m = block
|
|
|
|
def forward(self, input):
|
|
return self.m.forward(input)
|
|
|
|
class Mpy2(torch.nn.Module):
|
|
def __init__(self, block1, block2):
|
|
super(TestDataParallel.Mpy2, self).__init__()
|
|
self.m1 = block1
|
|
self.m2 = block2
|
|
|
|
def forward(self, input):
|
|
x = self.m1.forward(input)
|
|
return self.m2(x)
|
|
|
|
class Msm(torch.jit.ScriptModule):
|
|
|
|
__constants__ = ['m']
|
|
|
|
def __init__(self):
|
|
super(TestDataParallel.Msm, self).__init__(False)
|
|
self.m = nn.Sequential(nn.Linear(2, 2), nn.BatchNorm1d(2),
|
|
nn.ReLU(), nn.Linear(2, 2))
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
return self.m(input)
|
|
|
|
class Msm1(torch.jit.ScriptModule):
|
|
def __init__(self, block):
|
|
super(TestDataParallel.Msm1, self).__init__(False)
|
|
self.block = block
|
|
|
|
@torch.jit.script_method
|
|
def forward(self, input):
|
|
x = self.block(input)
|
|
return x
|
|
|
|
def check_replicas(self, module, replicas, input_shape=(2, 2)):
|
|
input = torch.randn(input_shape).cuda()
|
|
expected_output = module(input).data
|
|
for i, replica in enumerate(replicas):
|
|
for p in replica.parameters():
|
|
self.assertEqual(p.get_device(), i)
|
|
for b in replica.buffers():
|
|
self.assertEqual(b.get_device(), i)
|
|
replica_input = input.cuda(i)
|
|
self.assertEqual(replica(replica_input).data, expected_output)
|
|
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
|
|
@skipIfRocm
|
|
def test_python_submodule_exception(self):
|
|
module = self.Msm1(self.Mpy()).cuda()
|
|
msg = "Cannot replicate.*"
|
|
with self.assertRaisesRegex(Exception, msg):
|
|
dp.replicate(module, {0, 1})
|
|
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
|
|
@skipIfRocm
|
|
def test_python_submodule_script(self):
|
|
module = self.Mpy1(self.Msm()).cuda()
|
|
replicas = dp.replicate(module, {0, 1})
|
|
self.check_replicas(module, replicas)
|
|
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
|
|
@skipIfRocm
|
|
def test_shared_module(self):
|
|
s = self.Msm()
|
|
p1 = self.Mpy1(s)
|
|
module = self.Mpy2(p1, s).cuda()
|
|
replicas = dp.replicate(module, {0, 1})
|
|
self.check_replicas(module, replicas)
|
|
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
|
|
@skipIfRocm
|
|
def test_traced_module(self):
|
|
module = torch.jit.trace(self.Mpy1(self.Mpy()), torch.ones(2, 2)).cuda()
|
|
replicas = dp.replicate(module, {0, 1})
|
|
self.check_replicas(module, replicas)
|
|
|
|
@unittest.skipIf(not RUN_CUDA_MULTI_GPU, "multi-GPU not supported")
|
|
@skipIfRocm
|
|
def test_tensor_sharing(self):
|
|
module = self.Msm1(self.Msm()).cuda()
|
|
replica = dp.replicate(module, {0, 1})
|
|
optimizer = optim.SGD(module.parameters(), lr=1, momentum=1)
|
|
x = torch.ones(2, 2, requires_grad=True).cuda()
|
|
first_forward = module.forward(x)
|
|
first_forward.sum().backward()
|
|
optimizer.step()
|
|
second_forward = module.forward(first_forward)
|
|
|
|
# replica which is on the same GPU has a shallow copy of the original
|
|
# params and buffers
|
|
r0_forward = replica[0].forward(x)
|
|
self.assertEqual(second_forward, r0_forward)
|
|
|
|
# replca which is on a different GPU has a deep copy of the original
|
|
# params and buffers
|
|
x1 = torch.ones(2, 2, requires_grad=True).cuda(device=1)
|
|
r1_forward = replica[1].forward(x1)
|
|
self.assertEqual(first_forward, r1_forward)
|
|
|
|
|
|
class TestClassType(JitTestCase):
|
|
def test_get_with_method(self):
|
|
@torch.jit.script
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
self.foo = x
|
|
|
|
def getFooTest(self):
|
|
return self.foo
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
foo = FooTest(x)
|
|
return foo.getFooTest()
|
|
|
|
input = torch.ones(2, 3)
|
|
self.assertEqual(fn(input), input)
|
|
|
|
def test_get_attr(self):
|
|
@torch.jit.script # noqa: B903
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
self.foo = x
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
foo = FooTest(x)
|
|
return foo.foo
|
|
|
|
input = torch.ones(2, 3)
|
|
self.assertEqual(fn(input), input)
|
|
|
|
def test_set_attr_in_method(self):
|
|
@torch.jit.script
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
# type: (int) -> None
|
|
self.foo = x
|
|
|
|
def incFooTest(self, y):
|
|
# type: (int) -> None
|
|
self.foo = self.foo + y
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
# type: (int) -> int
|
|
foo = FooTest(x)
|
|
foo.incFooTest(2)
|
|
return foo.foo
|
|
|
|
self.assertEqual(fn(1), 3)
|
|
|
|
def test_set_attr_type_mismatch(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Wrong type for attribute assignment"):
|
|
@torch.jit.script
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
self.foo = x
|
|
self.foo = 10 # should error since int != Tensor
|
|
|
|
def test_get_attr_not_initialized(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Tried to access to nonexistent attribute"):
|
|
@torch.jit.script
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
self.foo = x
|
|
|
|
def get_non_initialized(self):
|
|
return self.asdf # asdf isn't an attr
|
|
|
|
def test_set_attr_non_initialized(self):
|
|
with self.assertRaisesRegex(RuntimeError, "Tried to set nonexistent attribute"):
|
|
@torch.jit.script
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
self.foo = x
|
|
|
|
def set_non_initialized(self, y):
|
|
self.bar = y # can't assign to non-initialized attr
|
|
|
|
def test_type_annotations(self):
|
|
with self.assertRaisesRegex(RuntimeError, "expected a value of type bool"):
|
|
@torch.jit.script # noqa: B903
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
# type: (bool) -> None
|
|
self.foo = x
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
FooTest(x)
|
|
|
|
fn(2)
|
|
|
|
def test_conditional_set_attr(self):
|
|
with self.assertRaisesRegex(RuntimeError, "assignment cannot be in a control-flow block"):
|
|
@torch.jit.script
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
if True:
|
|
self.attr = x
|
|
|
|
def test_class_type_as_param(self):
|
|
@torch.jit.script # noqa: B903
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
self.attr = x
|
|
|
|
@torch.jit.script
|
|
def fn(foo):
|
|
# type: (FooTest) -> Tensor
|
|
return foo.attr
|
|
|
|
@torch.jit.script
|
|
def fn2(x):
|
|
foo = FooTest(x)
|
|
return fn(foo)
|
|
|
|
input = torch.ones(1)
|
|
self.assertEqual(fn2(input), input)
|
|
|
|
def test_out_of_order_methods(self):
|
|
@torch.jit.script
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
self.x = x
|
|
self.x = self.get_stuff(x)
|
|
|
|
def get_stuff(self, y):
|
|
return self.x + y
|
|
|
|
@torch.jit.script
|
|
def fn(x):
|
|
f = FooTest(x)
|
|
return f.x
|
|
|
|
input = torch.ones(1)
|
|
self.assertEqual(fn(input), input + input)
|
|
|
|
def test_save_load_with_classes(self):
|
|
@torch.jit.script
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
self.x = x
|
|
|
|
def get_x(self):
|
|
return self.x
|
|
|
|
class MyMod(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, a):
|
|
foo = FooTest(a)
|
|
return foo.get_x()
|
|
|
|
m = MyMod()
|
|
|
|
buffer = io.BytesIO()
|
|
torch.jit.save(m, buffer)
|
|
|
|
# classes are globally registered for now, so we need to clear the JIT
|
|
# registry to simulate loading a new model
|
|
torch._C._jit_clear_class_registry()
|
|
|
|
buffer.seek(0)
|
|
m_loaded = torch.jit.load(buffer)
|
|
|
|
input = torch.rand(2, 3)
|
|
output = m_loaded(input)
|
|
self.assertEqual(input, output)
|
|
|
|
def test_save_load_with_classes_nested(self):
|
|
@torch.jit.script # noqa: B903
|
|
class FooNestedTest(object):
|
|
def __init__(self, y):
|
|
self.y = y
|
|
|
|
@torch.jit.script
|
|
class FooNestedTest2(object):
|
|
def __init__(self, y):
|
|
self.y = y
|
|
self.nested = FooNestedTest(y)
|
|
|
|
@torch.jit.script
|
|
class FooTest(object):
|
|
def __init__(self, x):
|
|
self.class_attr = FooNestedTest(x)
|
|
self.class_attr2 = FooNestedTest2(x)
|
|
self.x = self.class_attr.y + self.class_attr2.y
|
|
|
|
class MyMod(torch.jit.ScriptModule):
|
|
@torch.jit.script_method
|
|
def forward(self, a):
|
|
foo = FooTest(a)
|
|
return foo.x
|
|
|
|
m = MyMod()
|
|
|
|
buffer = io.BytesIO()
|
|
torch.jit.save(m, buffer)
|
|
|
|
# classes are globally registered for now, so we need to clear the JIT
|
|
# registry to simulate loading a new model
|
|
torch._C._jit_clear_class_registry()
|
|
|
|
buffer.seek(0)
|
|
m_loaded = torch.jit.load(buffer)
|
|
|
|
input = torch.rand(2, 3)
|
|
output = m_loaded(input)
|
|
self.assertEqual(2 * input, output)
|
|
|
|
def test_python_interop(self):
|
|
@torch.jit.script # noqa: B903
|
|
class Foo(object):
|
|
def __init__(self, x, y):
|
|
self.x = x
|
|
self.y = y
|
|
|
|
@torch.jit.script
|
|
def use_foo(foo):
|
|
# type: (Foo) -> Foo
|
|
return foo
|
|
|
|
# create from python
|
|
x = torch.ones(2, 3)
|
|
y = torch.zeros(2, 3)
|
|
f = Foo(x, y)
|
|
|
|
self.assertEqual(x, f.x)
|
|
self.assertEqual(y, f.y)
|
|
|
|
# pass in and out of script
|
|
f2 = use_foo(f)
|
|
|
|
self.assertEqual(x, f2.x)
|
|
self.assertEqual(y, f2.y)
|
|
|
|
|
|
for test in autograd_method_tests():
|
|
add_autograd_test(*test)
|
|
|
|
for test in nn_functional_tests:
|
|
add_nn_functional_test(*test)
|
|
|
|
for test in module_tests + new_module_tests + additional_module_tests:
|
|
add_nn_module_test(**test)
|
|
|
|
for test in criterion_tests:
|
|
test['no_grad'] = True
|
|
add_nn_module_test(**test)
|
|
|
|
if __name__ == '__main__':
|
|
run_tests()
|