pytorch/torch/testing/_internal/jit_utils.py
Yanan Cao c068233300 Add CHECK-SOURCE-HIGHLIGHTED to file check utils. (#39692)
Summary:
Enhance FileCheck util to check for highlighted source ranges. This is useful when writing tests regarding generated error messages that require source code highlighting.

Here is what the error looks like in different cases:

- When the expected source code token is not found at all in the input string:
```
RuntimeError: Expected to find "invalid_token" but did not find it
Searched string:

...  <--- HERE
def to_list_missing_type_annotation(x):
    # type: (torch.Tensor) -> List[float]
From CHECK-SOURCE-HIGHLIGHTED: invalid_token
```

- When the source code token is found but not highlighted:
```
Traceback (most recent call last):
  File "test_range.py", line 11, in <module>
    FileCheck().check_source_highlighted("x.tolist()").run(s)
RuntimeError: Expected to find "~~~~~~~~~~" but did not find it
Searched string:
    # type: (torch.Tensor) -> List[float]
    li = x.tolist()
         ~~~~~~~~~ <--- HERE
         ~~~~~~~~~~~~~~~~~~~...  <--- HERE
    return li
```

This is a bit confusing, since both the input text (usually an error message) and the generated error message have their own highlighted portions, but it is consistent with previous behavior. Another option would be to generate plain error messages without additional range highlighting on the input text.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39692

Test Plan:
Added unit test.

Closes https://github.com/pytorch/pytorch/issues/38698

Differential Revision: D22001765

Pulled By: gmagogsfm

fbshipit-source-id: 6681441eee5853ab061d198ccfe55ebffddca202
2020-06-11 23:47:07 -07:00

667 lines
25 KiB
Python

# Torch
from torch.autograd import Variable
from torch.autograd.function import _nested_map
from torch.jit.annotations import BroadcastingList2, BroadcastingList3 # noqa: F401
from torch.onnx import OperatorExportTypes
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized
import zipfile
import functools
# Testing utils
from torch.testing import FileCheck
from torch.testing._internal.common_utils import TestCase, IS_WINDOWS, \
freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, ProfilingMode, TEST_BAILOUTS
from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401
# Standard library
from contextlib import contextmanager
from functools import reduce
from itertools import chain
from torch._six import StringIO
import inspect
import io
import math
import os
import pickle
import sys
import tempfile
import textwrap
# True when a CUDA device is usable in this process.
RUN_CUDA = torch.cuda.is_available()
# True only when CUDA is usable *and* more than one device is visible; the
# device count is not queried at all when CUDA is unavailable.
RUN_CUDA_MULTI_GPU = torch.cuda.device_count() > 1 if RUN_CUDA else RUN_CUDA
def execWrapper(code, glob, loc):
    """Execute ``code`` with the builtin ``exec`` using ``glob`` as the globals
    mapping and ``loc`` as the locals mapping; names it defines land in ``loc``.
    """
    namespace_globals, namespace_locals = glob, loc
    exec(code, namespace_globals, namespace_locals)
def do_input_map(fn, input):
    """Apply ``fn`` to every ``torch.Tensor`` nested inside ``input``, rebuilding
    the surrounding container structure with ``_nested_map``.
    """
    def _is_tensor(obj):
        return isinstance(obj, torch.Tensor)

    mapper = _nested_map(_is_tensor, fn)
    return mapper(input)
def clear_class_registry():
    """Clear TorchScript's global class registry and replace the module-level
    ``torch.jit._recursive.concrete_type_store`` with a fresh ``ConcreteTypeStore``.
    """
    recursive = torch.jit._recursive
    torch._C._jit_clear_class_registry()
    recursive.concrete_type_store = recursive.ConcreteTypeStore()
def get_execution_plan(graph_executor_state):
    """Return the single execution plan held by ``graph_executor_state``.

    Raises:
        RuntimeError: if ``execution_plans`` does not contain exactly one plan.
    """
    plans = list(graph_executor_state.execution_plans.values())
    if len(plans) != 1:
        raise RuntimeError('This test assumes this GraphExecutor should '
                           'only have one execution plan, got: {}'.format(len(plans)))
    (only_plan,) = plans
    return only_plan
class _AssertRaisesRegexWithHighlightContext:
    """
    A context manager that is useful for checking that error messages highlight
    the correct part of the source code.

    On exit, any exception raised in the body is re-raised under
    ``test_case.assertRaisesRegex(exception, regex)`` so its type and message are
    validated (a body that raises nothing fails that assertion). If
    ``highlight`` is truthy, the exception text is then checked with
    ``FileCheck().check_source_highlighted(highlight)``. A validated exception
    is suppressed.
    """

    def __init__(self, test_case, exception, regex, highlight):
        self.test_case, self.exception_type = test_case, exception
        self.regex, self.highlight = regex, highlight

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        expect_raise = self.test_case.assertRaisesRegex(self.exception_type, self.regex)
        with expect_raise:
            if type:
                raise value

        if not self.highlight:
            return True
        FileCheck().check_source_highlighted(self.highlight).run(str(value))
        return True
class JitTestCase(TestCase):
    """Base test case for TorchScript tests.

    ``setUp`` installs emit hooks (``emitModuleHook`` / ``emitFunctionHook``)
    that round-trip every emitted function or module through
    ``torch.jit.save`` / ``torch.jit.load`` and check that the serialized code
    is unchanged. ``tearDown`` removes the hooks and clears the class registry.
    The ``check*`` helpers compare scripted or traced results against eager
    Python execution.
    """
    _do_cuda_memory_leak_check = True
    # Class-level flag so TracerWarning.ignore_lib_warnings() is installed only
    # by the first setUp() call.
    _restored_warnings = False

    class capture_stdout(list):
        """
        Replace sys.stdout with a temporary StringIO. On exit the captured text
        is appended to this list and the previous sys.stdout is restored.
        """
        def __enter__(self):
            self.sys_stdout = sys.stdout
            self.stringio = StringIO()
            sys.stdout = self.stringio
            return self

        def __exit__(self, *args):
            self.append(str(self.stringio.getvalue()))
            del self.stringio
            sys.stdout = self.sys_stdout

    def setHooks(self):
        # Every emitted module/function is routed through the save/load round-trip check.
        torch._C._jit_set_emit_hooks(self.emitModuleHook, self.emitFunctionHook)

    def clearHooks(self):
        torch._C._jit_set_emit_hooks(None, None)

    def setUp(self):
        super().setUp()
        # unittest overrides all warning filters and forces all of them to show up
        # after we install our own to silence those coming from inside PyTorch.
        # This will ensure that our filter still takes precedence.
        if not JitTestCase._restored_warnings:
            torch.jit.TracerWarning.ignore_lib_warnings()
            JitTestCase._restored_warnings = True
        self.setHooks()

    def tearDown(self):
        super().tearDown()
        # needs to be cleared because python might be unloaded before
        # the callback gets destructed
        self.clearHooks()
        clear_class_registry()

    def _isHookExceptionOk(self, e):
        """Return True if ``e`` is an export failure the emit hooks may skip."""
        se = str(e)
        allowed = ("Could not export Python function",
                   "closures are not exportable")
        return any(a in se for a in allowed)

    def _compared_saved_loaded(self, m):
        """Check that ``m`` survives a save -> load -> save round trip.

        The code files written by the first ``torch.jit.save`` must equal the
        ones written after reloading and saving again. For a ``ScriptModule``
        the IValue tags of the original and the reloaded module must also
        match. Export failures accepted by ``_isHookExceptionOk`` skip the check.
        """
        def extract_files(buffer):
            # crack open the zip format to get at the main module code
            archive = zipfile.ZipFile(buffer)
            names = archive.namelist()
            # check that we have no duplicate names
            self.assertEqual(len(set(names)), len(names))
            files = [x for x in names if x.startswith('archive/code/')]
            # read each code file into a string; ZipFile.read opens and closes
            # the member itself, so no member handles are left open
            code_files = [archive.read(f).decode() for f in files if f.endswith('.py')]
            # unpickle the debug files lazily (nothing below consumes them)
            debug_files = (pickle.load(archive.open(f)) for f in files if f.endswith('.debug_pkl'))
            return code_files, debug_files

        # disable the hook while we parse code, otherwise we will re-enter the hook
        with torch.jit._disable_emit_hooks():
            try:
                # short-circuit if this is an empty function or module
                if len(m.code) == 0:
                    return
                if isinstance(m, torch._C.ScriptModule):
                    if len(m._method_names()) == 0:
                        return

                # save the module to a buffer
                buffer = io.BytesIO()
                torch.jit.save(m, buffer)
                # copy the data in the buffer so we can restore it later. This
                # is because py2 and py3 have different semantics with zipfile
                # and it's easier to just work with a fresh copy each time.
                buffer_copy = buffer.getvalue()

                code_files, debug_files = extract_files(buffer)

            except RuntimeError as e:
                if not self._isHookExceptionOk(e):
                    raise
                else:
                    return

            # import the model again (from the copy we made of the original)
            buffer2 = io.BytesIO(buffer_copy)
            imported = torch.jit.load(buffer2)

            # save it again
            saved_module_buffer_2 = io.BytesIO()
            torch.jit.save(imported, saved_module_buffer_2)

            saved_module_buffer_2.seek(0)
            code_files_2, debug_files_2 = extract_files(saved_module_buffer_2)

            # NOTE(review): zip() stops at the shorter list, so a differing
            # number of code files is not reported here.
            for a, b in zip(code_files, code_files_2):
                self.assertMultiLineEqual(a, b)

            if isinstance(m, torch._C.ScriptModule):
                self.assertTrue(torch._C._ivalue_tags_match(m, imported._c))

    def emitFunctionHook(self, func):
        # func has invalid names for export, skip the jitter check
        if func.name == "<lambda>" or "aten::" in func.name:
            return
        self._compared_saved_loaded(func)

    def emitModuleHook(self, module):
        self._compared_saved_loaded(module)

    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
        """Return a copy of ``m`` obtained by saving and loading it in memory.

        If ``also_test_file`` is true, that copy is additionally saved to and
        reloaded from a temporary file, and the file-loaded copy is returned.
        """
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        imported = torch.jit.load(buffer, map_location=map_location)

        if not also_test_file:
            return imported

        with TemporaryFileName() as fname:
            torch.jit.save(imported, fname)
            return torch.jit.load(fname, map_location=map_location)

    def getExportImportCopyWithPacking(self, m, also_test_file=True, map_location=None):
        """Like ``getExportImportCopy``, but calls ``_pack`` on every submodule
        that defines it before saving, and ``_unpack`` (where defined) on ``m``
        and on each loaded copy afterwards.

        ``m`` is unpacked again even if saving fails.
        """
        def pack(s):
            return s._pack() if s._c._has_method('_pack') else None

        def unpack(s):
            return s._unpack() if s._c._has_method('_unpack') else None

        buffer = io.BytesIO()
        m.apply(pack)
        try:
            torch.jit.save(m, buffer)
        finally:
            # restore the caller's module even when torch.jit.save raises
            m.apply(unpack)
        buffer.seek(0)
        imported = torch.jit.load(buffer, map_location=map_location)
        imported.apply(unpack)

        if not also_test_file:
            return imported

        # Ideally we would like to not have to manually delete the file, but NamedTemporaryFile
        # opens the file, and it cannot be opened multiple times in Windows. To support Windows,
        # close the file after creation and try to remove it manually
        f = tempfile.NamedTemporaryFile(delete=False)
        try:
            f.close()
            imported.save(f.name)
            result = torch.jit.load(f.name, map_location=map_location)
        finally:
            os.unlink(f.name)

        result.apply(unpack)
        return result

    def assertGraphContains(self, graph, kind):
        """Assert that at least one top-level node of ``graph`` has kind ``kind``."""
        self.assertTrue(any(n.kind() == kind for n in graph.nodes()))

    def assertGraphContainsExactly(self, graph, kind, num_kind_nodes, consider_subgraphs=False):
        """Assert that ``graph`` contains exactly ``num_kind_nodes`` nodes of ``kind``.

        With ``consider_subgraphs`` the count is taken textually from
        ``str(graph)``: occurrences of ``kind`` minus occurrences of
        ``'with ' + kind``. Otherwise only top-level ``graph.nodes()`` are counted.

        Raises:
            AssertionError: with the graph text and both counts on mismatch.
        """
        def perform_assert(graph, kind, actual, expected, consider_subgraphs):
            if actual == expected:
                return
            subgraph = 'including' if consider_subgraphs else 'excluding'
            raise AssertionError(
                '{}\nError: graph contains {} {} nodes ({} subgraphs) but expected {}'.format(
                    graph, actual, kind, subgraph, expected))

        if consider_subgraphs:
            strgraph = str(graph)
            count = strgraph.count(kind) - strgraph.count('with {}'.format(kind))
            perform_assert(graph, kind, count, num_kind_nodes,
                           consider_subgraphs)
            return

        nodes = [node for node in graph.nodes()
                 if node.kind() == kind]
        perform_assert(graph, kind, len(nodes), num_kind_nodes,
                       consider_subgraphs)

    def assertExpectedONNXGraph(self, g, *args, **kwargs):
        g = torch.onnx._optimize_trace(g, operator_export_type=OperatorExportTypes.ONNX)
        self.assertExpectedGraph(g, *args, **kwargs)

    def assertExpectedGraph(self, trace, *args, **kwargs):
        """Lint, DCE and canonicalize the graph of ``trace``, then compare its
        string form with the recorded expect file via ``assertExpected``.
        """
        if isinstance(trace, torch._C.Graph):
            graph = trace
        else:
            graph = trace.graph()

        torch._C._jit_pass_lint(graph)
        torch._C._jit_pass_dce(graph)
        torch._C._jit_pass_lint(graph)
        graph = torch._C._jit_pass_canonicalize(graph)
        torch._C._jit_pass_lint(graph)
        self.assertExpected(str(graph), *args, **kwargs)

    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
        """Assert whether the listed nodes ended up inside DifferentiableGraphs.

        Every node kind in ``nonfusible_nodes`` must appear in some
        ``prim::DifferentiableGraph`` subgraph, and every kind in
        ``fusible_nodes`` in some ``prim::FusionGroup`` inside those subgraphs.
        The combined result must equal ``should_autodiff_node``.
        """
        diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
        diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]

        # For any non-fusible node, it must show up in one of the DifferentiableGraph.
        found_all_nonfusible_nodes = (len(diff_subgraphs) == 0 and len(nonfusible_nodes) == 0)\
            or all(any(g.findNode(n) is not None for g in diff_subgraphs) for n in nonfusible_nodes)

        # For any fusible node, it must show up in one of the FusionGroup in the DifferentiableGraph.
        fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
        fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]
        found_all_fusible_nodes = (len(fusion_nodes) == 0 and len(fusible_nodes) == 0)\
            or all(any(g.findNode(n) is not None for g in fusion_subgraphs) for n in fusible_nodes)

        self.assertEqual(should_autodiff_node, found_all_nonfusible_nodes and found_all_fusible_nodes)

    def run_pass(self, name, trace):
        """Run ``torch._C._jit_pass_<name>`` on the graph of ``trace``.

        The graph is linted before and after. If the pass returns a graph, that
        graph replaces the input. When ``trace`` is not a bare ``Graph`` the
        result is written back with ``trace.set_graph``. Returns the final graph.
        """
        if isinstance(trace, torch._C.Graph):
            graph = trace
            set_graph = False
        else:
            set_graph = True
            graph = trace.graph()

        torch._C._jit_pass_lint(graph)
        result = getattr(torch._C, '_jit_pass_' + name)(graph)
        if result is not None:
            graph = result
        torch._C._jit_pass_lint(graph)

        if set_graph:
            trace.set_graph(graph)
        return graph

    def get_frame_vars(self, frames_up):
        """Return the locals and globals of the frame ``frames_up`` levels above
        this method's caller, merged into one dict (globals win on name clashes).

        Raises:
            RuntimeError: if frame introspection is unavailable or the stack
                is shallower than requested.
        """
        frame = inspect.currentframe()
        if frame is None:
            raise RuntimeError("failed to inspect frame")
        # +1 skips this method's own frame.
        for _ in range(frames_up + 1):
            frame = frame.f_back
            if frame is None:
                raise RuntimeError("failed to get frame")
        defined_vars = {}
        defined_vars.update(frame.f_locals)
        defined_vars.update(frame.f_globals)
        return defined_vars

    def assertRaisesRegexWithHighlight(self, exception, regex, highlight):
        return _AssertRaisesRegexWithHighlightContext(self, exception, regex, highlight)

    def checkScriptRaisesRegex(self, script, inputs, exception, regex,
                               outputs=None, capture_output=False, profiling=ProfilingMode.PROFILING):
        """
        Checks that a given function will throw the correct exception,
        when executed with normal python, the string frontend, and the AST frontend

        ``outputs``, ``capture_output`` and ``profiling`` are accepted but unused.
        """

        with enable_profiling_mode_for_profiling_tests():
            # normal python
            with self.assertRaisesRegex(exception, regex):
                script(*inputs)
            # string frontend
            with self.assertRaisesRegex(exception, regex):
                source = textwrap.dedent(inspect.getsource(script))
                cu = torch.jit.CompilationUnit(source)
                ge = getattr(cu, script.__name__)
                # profiling run
                with self.assertRaisesRegex(exception, regex):
                    ge(*inputs)
                # optimized run
                ge(*inputs)
            # python AST frontend
            with self.assertRaisesRegex(exception, regex):
                ge = torch.jit.script(script)
                # profiling run
                with self.assertRaisesRegex(exception, regex):
                    ge(*inputs)
                # optimized run
                ge(*inputs)

    def checkBailouts(self, model, inputs, expected):
        """Force each bailout of ``model``'s single execution plan and check
        that the output still equals ``expected``.
        """
        state = model.get_debug_state()
        plan = get_execution_plan(state)
        num_bailouts = plan.code.num_bailouts()
        for i in range(num_bailouts):
            plan.code.request_bailout(i)
            bailout_outputs = model(*inputs)
            self.assertEqual(bailout_outputs, expected)

    def checkScript(self,
                    script,
                    inputs,
                    name='func',
                    optimize=True,
                    inputs_requires_grad=False,
                    capture_output=False,
                    frames_up=1,
                    profiling=ProfilingMode.PROFILING):
        """Compile ``script`` with TorchScript and compare it against Python.

        ``script`` is either source text (the function ``name`` defined in it
        is used) or a Python function. For a function, the string frontend is
        checked first through a recursive call, then ``torch.jit.script``.
        The scripted function is run twice (profiling and optimized runs), and
        both results must equal the Python result. Returns the scripted function.

        ``frames_up`` locates the caller frame whose variables resolve names in
        ``script``; the call depth inside this method must therefore stay fixed.
        ``profiling`` is only forwarded to the recursive call.
        """
        with torch.jit.optimized_execution(optimize):
            with enable_profiling_mode_for_profiling_tests():
                if isinstance(script, str):
                    # Compile the string to a Script function
                    # with enable_profiling_mode():
                    cu = torch.jit.CompilationUnit(script, _frames_up=frames_up)

                    # Execute the Python function so we can run it later and get its
                    # outputs

                    frame = self.get_frame_vars(frames_up)
                    the_locals = {}
                    execWrapper(script, glob=frame, loc=the_locals)
                    frame.update(the_locals)

                    python_fn = frame[name]
                    scripted_fn = getattr(cu, name)
                else:

                    # Check the string frontend first
                    source = textwrap.dedent(inspect.getsource(script))
                    self.checkScript(
                        source,
                        inputs,
                        script.__name__,
                        optimize=optimize,
                        inputs_requires_grad=inputs_requires_grad,
                        capture_output=capture_output,
                        profiling=profiling,
                        frames_up=2)

                    # Continue checking the Python frontend
                    scripted_fn = torch.jit.script(script, _frames_up=1)
                    python_fn = script

                if inputs_requires_grad:
                    recording_inputs = do_input_map(lambda t: t.detach().requires_grad_(), inputs)
                else:
                    recording_inputs = inputs

                if capture_output:
                    with self.capture_stdout() as script_stdout:
                        script_outputs = scripted_fn(*recording_inputs)
                    with self.capture_stdout() as opt_script_stdout:
                        opt_script_outputs = scripted_fn(*recording_inputs)
                    with self.capture_stdout() as _python_stdout:
                        python_outputs = python_fn(*inputs)
                    if not IS_WINDOWS:
                        self.assertExpected(script_stdout[0], subname='stdout')
                    self.assertEqual(python_outputs, opt_script_outputs)
                else:
                    # profiling run
                    script_outputs = scripted_fn(*recording_inputs)
                    # optimized run
                    opt_script_outputs = scripted_fn(*recording_inputs)
                    if TEST_BAILOUTS:
                        self.checkBailouts(scripted_fn, inputs, opt_script_outputs)
                    python_outputs = python_fn(*inputs)
                self.assertEqual(python_outputs, script_outputs)
                self.assertEqual(script_outputs, opt_script_outputs)
                return scripted_fn

    def checkTrace(self, func, reference_tensors, input_tensors=None,
                   drop=None, allow_unused=False, verbose=False,
                   inputs_require_grads=True, check_tolerance=1e-5, export_import=True,
                   _force_outplace=False):
        """Trace ``func`` and compare the traced function against eager ``func``.

        Compares outputs without gradients, outputs plus first-order input
        gradients, and (when ``inputs_require_grads``) second-order gradients.
        ``drop`` excludes the last ``drop`` outputs from the summed loss so
        unused outputs are exercised. Returns the traced (and, with
        ``export_import``, save/load round-tripped) function.
        """
        # TODO: check gradients for parameters, not just inputs
        def allSum(vs):
            # drop allows us to remove some values from ever being used
            # to test unused outputs; drop=0 (or None) keeps every output
            if drop:
                vs = vs[:-drop]
            # we don't want all the grad for all the outputs to be the same
            # so we multiply each by a constant
            return sum(math.log(i + 2) * v.sum() for i, v in enumerate(vs) if v is not None)

        if input_tensors is None:
            input_tensors = reference_tensors

        def flatten_inputs(inputs):
            # Collect every tensor nested in `inputs` (sequences and dict
            # values, recursively) into a flat tuple in traversal order.
            def input_reduce(input, fn, acc):
                if isinstance(input, torch.Tensor):
                    fn(input, acc)
                elif isinstance(input, dict):
                    reduce(lambda acc, key: input_reduce(input[key], fn, acc), input, acc)
                else:
                    reduce(lambda acc, val: input_reduce(val, fn, acc), input, acc)
                return acc
            return tuple(input_reduce(inputs, lambda t, acc: acc.append(t), []))

        nograd_inputs = reference_tensors
        if inputs_require_grads:
            recording_inputs = do_input_map(lambda t: t.clone().requires_grad_(), reference_tensors)
            flattened_recording_inputs = flatten_inputs(recording_inputs)
        else:
            recording_inputs = reference_tensors

        # `check_trace` is set to False because check_trace is run with @no_grad
        # Also, `checkTrace` already does all the checks
        # against python function
        ge = torch.jit.trace(func, input_tensors, check_tolerance=check_tolerance,
                             _force_outplace=_force_outplace, check_trace=False)

        if export_import:
            ge = self.getExportImportCopy(ge)

        if verbose:
            print(ge.graph)

        # test no gradients case
        outputs = func(*nograd_inputs)
        outputs_ge = ge(*nograd_inputs)
        self.assertEqual(outputs, outputs_ge)

        # test gradients case
        outputs = func(*recording_inputs)
        if inputs_require_grads:
            grads = torch.autograd.grad(allSum(outputs), flattened_recording_inputs,
                                        allow_unused=allow_unused)

        outputs_ge = ge(*recording_inputs)
        if inputs_require_grads:
            grads_ge = torch.autograd.grad(allSum(outputs_ge), flattened_recording_inputs,
                                           allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_ge)
        if inputs_require_grads:
            self.assertEqual(grads, grads_ge)

        # test the grad grad case
        outputs = func(*recording_inputs)
        l1 = allSum(outputs)
        if inputs_require_grads:
            grads = torch.autograd.grad(l1, flattened_recording_inputs, create_graph=True,
                                        allow_unused=allow_unused)
            l2 = (allSum(grads) * l1)
            grads2 = torch.autograd.grad(l2, flattened_recording_inputs, allow_unused=allow_unused)

            # new requires-grad inputs built from reference_tensors for the traced run
            recording_inputs = do_input_map(lambda t: Variable(t, requires_grad=True), reference_tensors)
            flattened_recording_inputs = flatten_inputs(recording_inputs)

        outputs_ge = ge(*recording_inputs)
        l1_ge = allSum(outputs_ge)
        if inputs_require_grads:
            grads_ge = torch.autograd.grad(
                l1_ge, flattened_recording_inputs, create_graph=True, allow_unused=allow_unused)
            l2_ge = (allSum(grads_ge) * l1_ge)
            grads2_ge = torch.autograd.grad(l2_ge, flattened_recording_inputs, allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_ge)
        if inputs_require_grads:
            self.assertEqual(grads, grads_ge)
            for g2, g2_ge in zip(grads2, grads2_ge):
                if g2 is None and g2_ge is None:
                    continue
                self.assertTrue(torch.allclose(g2, g2_ge, atol=8e-4, rtol=8e-4))

        return ge

    def createFunctionFromGraph(self, trace):
        graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
        return torch._C._create_function_from_graph("forward", graph)

    def assertExportImport(self, trace, inputs):
        m = self.createFunctionFromGraph(trace)
        self.assertExportImportModule(m, inputs)

    def assertExportImportModule(self, m, inputs):
        """Assert that ``m`` and its save/load copy give equal results on
        ``inputs`` under the same frozen RNG state.
        """
        m_import = self.getExportImportCopy(m)
        a = self.runAndSaveRNG(m, inputs)
        b = self.runAndSaveRNG(m_import, inputs)
        self.assertEqual(a, b)

    def runAndSaveRNG(self, func, inputs, kwargs=None):
        """Call ``func(*inputs, **kwargs)`` inside ``freeze_rng_state()``."""
        kwargs = kwargs or {}
        with freeze_rng_state():
            results = func(*inputs, **kwargs)
        return results

    def checkModule(self, nn_module, args):
        """
        Check that a nn.Module's results in Script mode match eager and that it
        can be exported
        """
        sm = torch.jit.script(nn_module)

        with freeze_rng_state():
            eager_out = nn_module(*args)

        with freeze_rng_state():
            script_out = sm(*args)

        self.assertEqual(eager_out, script_out)
        self.assertExportImportModule(sm, args)

        return sm
@contextmanager
def inline_everything_mode(should_inline):
    """Temporarily set TorchScript's inline-everything mode to ``should_inline``.

    The mode that was active on entry is put back on exit, even if the body raises.
    """
    saved_mode = torch._C._jit_get_inline_everything_mode()
    set_mode = torch._C._jit_set_inline_everything_mode
    set_mode(should_inline)
    try:
        yield
    finally:
        set_mode(saved_mode)
# note: not re-entrant, use unnested only
@contextmanager
def disable_autodiff_subgraph_inlining(enabled=True):
    """Turn autodiff subgraph inlining off (when ``enabled``) for the body.

    On exit inlining is switched back on unconditionally rather than restored
    to its prior setting, which is why nesting is not supported.
    """
    inline_subgraphs = not enabled
    torch._C._debug_set_autodiff_subgraph_inlining(inline_subgraphs)
    try:
        yield
    finally:
        torch._C._debug_set_autodiff_subgraph_inlining(True)
def _inline_everything(fn):
    """Decorator: call ``fn`` inside ``inline_everything_mode(True)``.

    The wrapper discards ``fn``'s return value and keeps its metadata via
    ``functools.wraps``.
    """
    @functools.wraps(fn)
    def _run_inlined(*args, **kwargs):
        with inline_everything_mode(True):
            fn(*args, **kwargs)
    return _run_inlined
# this exists for forward compatibility reasons temporarily.
# TODO(suo) remove
def _tmp_donotuse_dont_inline_everything(fn):
    """Decorator: call ``fn`` inside ``inline_everything_mode(False)``.

    The wrapper discards ``fn``'s return value and keeps its metadata via
    ``functools.wraps``.
    """
    @functools.wraps(fn)
    def _run_not_inlined(*args, **kwargs):
        with inline_everything_mode(False):
            fn(*args, **kwargs)
    return _run_not_inlined
# make it easy to quickly define/trace a function for these tests
def _trace(*args, **kwargs):
    """Return a decorator that traces the decorated function with
    ``torch.jit.trace``, using ``args`` as the example inputs and forwarding
    ``kwargs`` to ``torch.jit.trace``.
    """
    example_inputs = args

    def decorator(func):
        return torch.jit.trace(func, example_inputs, **kwargs)
    return decorator
def enable_cpu_fuser(fn):
    """Decorator: run ``fn`` with ``torch._C._jit_override_can_fuse_on_cpu(True)``.

    The override is reset to ``False`` after ``fn`` returns or raises. The
    wrapper returns ``None``. ``functools.wraps`` keeps the decorated test's
    ``__name__``, ``__doc__`` and attributes.
    """
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        torch._C._jit_override_can_fuse_on_cpu(True)
        try:
            fn(*args, **kwargs)
        finally:
            # NOTE(review): resets to False rather than the value seen on entry.
            torch._C._jit_override_can_fuse_on_cpu(False)
    return wrapper
def enable_cpu_fuser_if(cond):
    """Return ``enable_cpu_fuser`` if ``cond`` is truthy, else a pass-through
    decorator whose wrapper forwards the call and return value unchanged.

    Both paths preserve the decorated function's metadata (``functools.wraps``).
    """
    if cond:
        return enable_cpu_fuser
    else:
        def noop_fuser(fn):
            @functools.wraps(fn)
            def wrapper(*args, **kwargs):
                return fn(*args, **kwargs)
            return wrapper
        return noop_fuser
def get_forward(c):
    """Return the ``forward`` method of ``c`` via ``c._get_method``."""
    method_name = 'forward'
    return c._get_method(method_name)
def get_forward_graph(c):
    """Return the ``graph`` attribute of ``c``'s ``forward`` method."""
    forward = c._get_method('forward')
    return forward.graph
def get_module_method(m, module, method):
    """Return method ``method`` of submodule ``module`` looked up on ``m._c``."""
    submodule = m._c.getattr(module)
    return submodule._get_method(method)
def attrs_with_prefix(module, prefix):
    """Return, in iteration order, the names from ``module._modules._c.items()``
    that start with ``prefix``.
    """
    names = (name for name, _value in module._modules._c.items())
    return [name for name in names if name.startswith(prefix)]
# Maps an operator alias name to the canonical operator name.
op_alias_mappings = dict(
    absolute="abs",
    absolute_="abs_",
)