# torch/testing/_internal/common_jit.py
#
# Shared JIT testing utilities. The method_tests from
# common_methods_invocations.py are being migrated into the new OpInfo
# class-based testing framework; this module pulls the helper functions out
# of the old method_tests logic (originally in
# torch/testing/_internal/jit_utils.py and test/test_jit.py) so that both
# the old method_tests and the new OpInfo tests can reuse them. The
# intermediate class JitCommonTestCase houses the moved methods.
# See https://github.com/pytorch/pytorch/pull/47695.


# Torch
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized
# Testing utils
from torch.testing import floating_and_complex_types_and
from torch.testing._internal.common_utils import TestCase, \
freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests
from torch.testing._internal.common_utils import enable_profiling_mode # noqa: F401
# Standard library
from itertools import chain
import io


def check_output_types(self, func, ref_outputs, args, kwargs):
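    """Check the output types recorded in `func`'s last executed graph against
    the eager-mode reference outputs: the graph must have exactly one output,
    and `ref_outputs` must be an instance of that output's type. Assumes
    `func` exposes a `last_graph` attribute."""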
graph = getattr(func, 'last_graph', None)
types = [o.type() for o in graph.outputs()]
self.assertTrue(len(types) == 1)
t = types[0]
    torch._C._jit_assert_is_instance(ref_outputs, t)


# Test names in this set are only checked for a single derivative
nn_functional_single_grad = frozenset('test_nn_' + name for name in [
'pdist',
'multilabel_margin_loss',
'max_unpool3d',
'multi_margin_loss',
'binary_cross_entropy',
'binary_cross_entropy_size_average',
'ctc_loss',
'grid_sample',
])


def check_against_reference(self, func, reference_func, args, kwargs=None,
allow_unused=True, check_types=True, no_grad=False):
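    """Check a JIT-compiled `func` against an eager `reference_func` on the
    same inputs: forward outputs, first-order gradients, and (unless the test
    is listed in `nn_functional_single_grad`) second-order gradients, all run
    under a frozen RNG state. With `no_grad=True`, only outputs are compared;
    with `check_types=True`, graph output types are validated as well."""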
kwargs = kwargs if kwargs else {}
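
    # Reduce a tensor (or tuple of tensors) to a scalar loss. The (i + 1)
    # weights keep the loss sensitive to each output's position, and outputs
    # without a floating/complex dtype are skipped since they carry no grad.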
def allSum(vs):
if isinstance(vs, torch.Tensor):
vs = (vs,)
return sum((i + 1) * v.sum()
for i, v in enumerate(vs)
if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))
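
    # Clone the args so each phase of the test gets fresh leaf tensors;
    # returns the full input list plus the tensors that will receive grads.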
def clone_inputs(requires_grad):
inputs = [
arg.detach().clone().requires_grad_(requires_grad and arg.requires_grad)
if isinstance(arg, torch.Tensor) else arg for arg in args
]
return inputs, [input for input in inputs if isinstance(input, torch.Tensor) and input.requires_grad]
nograd_inputs, nograd_tensors = clone_inputs(False)
    recording_inputs, recording_tensors = clone_inputs(True)

    # test no gradients case
outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
with enable_profiling_mode_for_profiling_tests():
outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
self.assertEqual(outputs, outputs_test)
if check_types:
check_output_types(self, func, outputs_test, nograd_inputs, kwargs)
if no_grad:
# skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
# test single grad case
outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
grads = torch.autograd.grad(allSum(outputs), recording_tensors,
allow_unused=allow_unused)
outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
allow_unused=allow_unused)
self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)

        # test the grad grad case
if self._testMethodName in nn_functional_single_grad:
return
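
        # Build a second-order loss that depends on the first-order grads
        # (l2 = allSum(grads) * l1), so differentiating it exercises
        # grad-of-grad for both the reference and the JIT function.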
outputs = self.runAndSaveRNG(reference_func, recording_inputs, kwargs)
l1 = allSum(outputs)
grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
allow_unused=allow_unused)
l2 = (allSum(grads) * l1)
grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
recording_inputs, recording_tensors = clone_inputs(True)
outputs_test = self.runAndSaveRNG(func, recording_inputs, kwargs)
l1_test = allSum(outputs_test)
grads_test = torch.autograd.grad(
l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)
l2_test = (allSum(grads_test) * l1_test)
grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)
self.assertEqual(outputs, outputs_test)
self.assertEqual(grads, grads_test)
for g2, g2_test in zip(grads2, grads2_test):
if g2 is None and g2_test is None:
continue
            self.assertTrue(torch.allclose(g2, g2_test, atol=5e-4, rtol=1e-4))


class JitCommonTestCase(TestCase):
def createFunctionFromGraph(self, trace):
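        """Wrap a torch._C.Graph (or a trace exposing `.graph()`) in a
        standalone script function named "forward"."""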
graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
        return torch._C._create_function_from_graph("forward", graph)

    def assertExportImport(self, trace, inputs):
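        """Turn a trace/graph into a function and check that a save/load
        round trip leaves its outputs on `inputs` unchanged."""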
m = self.createFunctionFromGraph(trace)
        self.assertExportImportModule(m, inputs)

    def assertExportImportModule(self, m, inputs):
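        """Check that module `m` produces the same results before and after
        an export/import round trip, running both under the same RNG state."""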
m_import = self.getExportImportCopy(m)
a = self.runAndSaveRNG(m, inputs)
b = self.runAndSaveRNG(m_import, inputs)
        self.assertEqual(a, b, "Results of original model and "
                               "exported/imported version of model differed")

    def runAndSaveRNG(self, func, inputs, kwargs=None):
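        """Run `func(*inputs, **kwargs)` with the RNG state saved and
        restored around the call, so repeated invocations (e.g. original vs.
        reloaded module) see identical random streams."""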
kwargs = kwargs if kwargs else {}
with freeze_rng_state():
results = func(*inputs, **kwargs)
        return results

    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
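        """Round-trip `m` through torch.jit.save/load via an in-memory
        buffer and, unless `also_test_file=False`, through an on-disk file
        as well, returning the reloaded copy."""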
buffer = io.BytesIO()
torch.jit.save(m, buffer)
buffer.seek(0)
imported = torch.jit.load(buffer, map_location=map_location)
if not also_test_file:
return imported
with TemporaryFileName() as fname:
torch.jit.save(imported, fname)
            return torch.jit.load(fname, map_location=map_location)

    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
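        """Check whether autodiff grouped the graph as expected: every node in
        `nonfusible_nodes` should appear in some prim::DifferentiableGraph
        subgraph, and every node in `fusible_nodes` in some prim::FusionGroup
        within those subgraphs. The combined result is compared against
        `should_autodiff_node`."""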
diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]
        # Every non-fusible node must show up in one of the DifferentiableGraph subgraphs.
found_all_nonfusible_nodes = (len(diff_subgraphs) == 0 and len(nonfusible_nodes) == 0) \
or all([any(g.findNode(n) is not None for g in diff_subgraphs) for n in nonfusible_nodes])
        # Every fusible node must show up in one of the FusionGroups inside a DifferentiableGraph subgraph.
fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]
found_all_fusible_nodes = (len(fusion_nodes) == 0 and len(fusible_nodes) == 0) \
or all([any(g.findNode(n) is not None for g in fusion_subgraphs) for n in fusible_nodes])
self.assertEqual(should_autodiff_node, found_all_nonfusible_nodes and found_all_fusible_nodes)
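

# A minimal sketch of how these helpers compose (hypothetical; `fn` and
# `TestMyOp` are not part of this module). `check_types=False` is used because
# a plain scripted function may not expose the `last_graph` attribute that
# check_output_types expects:
#
#     class TestMyOp(JitCommonTestCase):
#         def test_add(self):
#             def fn(x, y):
#                 return x + y
#             scripted = torch.jit.script(fn)
#             args = (torch.randn(3, requires_grad=True),
#                     torch.randn(3, requires_grad=True))
#             check_against_reference(self, scripted, fn, args,
#                                     check_types=False)
#             self.assertExportImport(scripted.graph, args)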