"""Helpers for serialized operator tests.

A test decorated with :func:`given` runs twice: once with
``should_serialize`` set, where the operator, its inputs/outputs and its
gradient operators are either written to a zip archive or compared against a
previously written archive, and once more with ``should_serialize`` cleared.
"""

import argparse
import inspect
import os
import shutil
import sys
import tempfile
import threading
from contextlib import contextmanager
from zipfile import ZipFile

import hypothesis as hy
import numpy as np

import caffe2.python.hypothesis_test_util as hu
from caffe2.proto import caffe2_pb2
from caffe2.python import gradient_checker
from caffe2.python.serialized_test import coverage

operator_test_type = 'operator_test'
TOP_DIR = os.path.dirname(os.path.realpath(__file__))
DATA_SUFFIX = 'data'
DATA_DIR = os.path.join(TOP_DIR, DATA_SUFFIX)
# Per-thread settings populated by testWithArgs() and
# SerializedTestCase.set_disable_serialized_check().
_output_context = threading.local()


def given(*given_args, **given_kwargs):
    """Decorate a test method so it runs in a serialized and a plain pass.

    The arguments are forwarded unchanged to ``hypothesis.given``. The
    returned method first runs the test with ``self.should_serialize = True``
    and then again with ``self.should_serialize = False``. Both passes use
    seed 0 and ``max_examples=1``.
    """
    def wrapper(f):
        hyp_func = hy.seed(0)(hy.settings(max_examples=1)(
            hy.given(*given_args, **given_kwargs)(f)))
        fixed_seed_func = hy.seed(0)(hy.settings(max_examples=1)(
            hy.given(*given_args, **given_kwargs)(f)))

        def func(self, *args, **kwargs):
            self.should_serialize = True
            fixed_seed_func(self, *args, **kwargs)
            self.should_serialize = False
            hyp_func(self, *args, **kwargs)

        return func

    return wrapper


def _getGradientOrNone(op_proto):
    """Return the gradient ops for ``op_proto``, or ``[]`` if lookup fails."""
    try:
        grad_ops, _ = gradient_checker.getGradientForOp(op_proto)
        return grad_ops
    except Exception:
        # Deliberate best-effort: any failure to build a gradient is treated
        # as "no gradient ops" rather than failing the test.
        return []


def _transformList(l):
    """Pack a list into a 1-D object array, one element per list item.

    Necessary to support converting jagged lists into numpy arrays.
    """
    # The builtin `object` is used instead of the `np.object` alias, which was
    # deprecated in NumPy 1.20 and removed in 1.24.
    ret = np.empty(len(l), dtype=object)
    for (i, arr) in enumerate(l):
        ret[i] = arr
    return ret


def _prepare_dir(path):
    """Create ``path`` as an empty directory, removing any existing tree."""
    if os.path.exists(path):
        shutil.rmtree(path)
    os.makedirs(path)


class SerializedTestCase(hu.HypothesisTestCase):
    """Test case that can serialize operator runs or compare against them."""

    # Toggled by the wrapper produced by `given`; when False,
    # assertSerializedOperatorChecks does nothing.
    should_serialize = False

    def get_output_dir(self):
        """Return the directory that holds the serialized test archives.

        Uses ``<output_dir>/operator_test`` when it exists, where
        ``output_dir`` comes from the thread-local context (default
        ``DATA_DIR``). Otherwise falls back to a path under the current
        working directory derived from this module's dotted name.
        """
        output_dir_arg = getattr(_output_context, 'output_dir', DATA_DIR)
        output_dir = os.path.join(output_dir_arg, operator_test_type)

        if os.path.exists(output_dir):
            return output_dir

        # fall back to pwd
        cwd = os.getcwd()
        serialized_util_module_components = __name__.split('.')
        serialized_util_module_components.pop()
        serialized_dir = '/'.join(serialized_util_module_components)
        output_dir_fallback = os.path.join(cwd, serialized_dir, DATA_SUFFIX)
        output_dir = os.path.join(output_dir_fallback, operator_test_type)

        return output_dir

    def get_output_filename(self):
        """Return ``<test file stem>.<test function name>`` for this test."""
        class_path = inspect.getfile(self.__class__)
        file_name_components = os.path.basename(class_path).split('.')
        test_file = file_name_components[0]

        function_name_components = self.id().split('.')
        test_function = function_name_components[-1]

        return test_file + '.' + test_function

    def serialize_test(self, inputs, outputs, grad_ops, op, device_option):
        """Write the op, gradient ops, inputs and outputs to a zip archive.

        The archive is ``<output dir>/<test name>.zip`` and contains
        ``op.pb``, ``inout.npz`` and one ``grad_<i>.pb`` per gradient op.
        Files are staged in a scratch directory that is removed afterwards.
        """
        output_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        full_dir = os.path.join(output_dir, test_name)
        _prepare_dir(full_dir)

        try:
            inputs = _transformList(inputs)
            outputs = _transformList(outputs)
            device_type = int(device_option.device_type)

            op_path = os.path.join(full_dir, 'op.pb')
            grad_paths = []
            inout_path = os.path.join(full_dir, 'inout')

            with open(op_path, 'wb') as f:
                f.write(op.SerializeToString())
            for (i, grad) in enumerate(grad_ops):
                grad_path = os.path.join(full_dir, 'grad_{}.pb'.format(i))
                grad_paths.append(grad_path)
                with open(grad_path, 'wb') as f:
                    f.write(grad.SerializeToString())

            np.savez_compressed(
                inout_path,
                inputs=inputs,
                outputs=outputs,
                device_type=device_type)

            with ZipFile(os.path.join(output_dir, test_name + '.zip'), 'w') as z:
                z.write(op_path, 'op.pb')
                # savez_compressed appends the '.npz' extension itself.
                z.write(inout_path + '.npz', 'inout.npz')
                for path in grad_paths:
                    z.write(path, os.path.basename(path))
        finally:
            # Remove the staging directory even if serialization failed.
            shutil.rmtree(full_dir)

    def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
        """Compare current results against the serialized archive.

        If the given ``inputs`` differ from the serialized inputs, the
        serialized op is re-run on the serialized inputs and its outputs and
        gradient ops are used for the comparison instead of the arguments.

        Raises:
            AssertionError: if outputs are not close within ``atol``/``rtol``
                or a gradient op differs from its serialized counterpart.
        """

        def parse_proto(x):
            proto = caffe2_pb2.OperatorDef()
            proto.ParseFromString(x)
            return proto

        source_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        temp_dir = tempfile.mkdtemp()
        try:
            with ZipFile(os.path.join(source_dir, test_name + '.zip')) as z:
                z.extractall(temp_dir)

            op_path = os.path.join(temp_dir, 'op.pb')
            inout_path = os.path.join(temp_dir, 'inout.npz')

            # load serialized input and output; the context manager closes the
            # archive handle so the temp directory can be removed afterwards
            with np.load(
                    inout_path, encoding='bytes', allow_pickle=True) as loaded:
                loaded_inputs = loaded['inputs'].tolist()
                inputs_equal = all(
                    np.array_equal(x, y)
                    for (x, y) in zip(inputs, loaded_inputs))
                loaded_outputs = loaded['outputs'].tolist()
                # device_type is only needed (and only read) on the re-run path
                device_type = (
                    None if inputs_equal else int(loaded['device_type']))

            # if inputs are not the same, run serialized input through
            # serialized op
            if not inputs_equal:
                # load operator
                with open(op_path, 'rb') as f:
                    loaded_op = f.read()

                op_proto = parse_proto(loaded_op)
                device_option = caffe2_pb2.DeviceOption(
                    device_type=device_type)

                outputs = hu.runOpOnInput(
                    device_option, op_proto, loaded_inputs)
                grad_ops = _getGradientOrNone(op_proto)

            # assert outputs are equal
            for (x, y) in zip(outputs, loaded_outputs):
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

            # assert gradient op is equal
            for (i, grad_op) in enumerate(grad_ops):
                grad_path = os.path.join(temp_dir, 'grad_{}.pb'.format(i))
                with open(grad_path, 'rb') as f:
                    loaded_grad = f.read()
                grad_proto = parse_proto(loaded_grad)
                self._assertSameOps(grad_proto, grad_op)
        finally:
            # Always clean up, including when an assertion above fails.
            shutil.rmtree(temp_dir)

    def _assertSameOps(self, op1, op2):
        """Assert two OperatorDefs are equal, ignoring the order of args."""
        op1_ = caffe2_pb2.OperatorDef()
        op1_.CopyFrom(op1)
        op1_.arg.sort(key=lambda arg: arg.name)

        op2_ = caffe2_pb2.OperatorDef()
        op2_.CopyFrom(op2)
        op2_.arg.sort(key=lambda arg: arg.name)

        self.assertEqual(op1_, op2_)

    def assertSerializedOperatorChecks(
            self,
            inputs,
            outputs,
            gradient_operator,
            op,
            device_option,
            atol=1e-7,
            rtol=1e-7,
    ):
        """Serialize or compare the given run, if in the serializing pass.

        Does nothing unless ``self.should_serialize`` is set. When the
        thread-local ``should_generate_output`` flag is set, the run is
        written out (and the coverage file regenerated unless
        ``disable_gen_coverage`` is set); otherwise it is compared against
        the existing archive.
        """
        if self.should_serialize:
            if getattr(_output_context, 'should_generate_output', False):
                self.serialize_test(
                    inputs, outputs, gradient_operator, op, device_option)
                if not getattr(_output_context, 'disable_gen_coverage', False):
                    coverage.gen_serialized_test_coverage(
                        self.get_output_dir(), TOP_DIR)
            else:
                self.compare_test(
                    inputs, outputs, gradient_operator, atol, rtol)

    def assertReferenceChecks(
            self,
            device_option,
            op,
            inputs,
            reference,
            input_device_options=None,
            threshold=1e-4,
            output_to_grad=None,
            grad_reference=None,
            atol=None,
            outputs_to_check=None,
            ensure_outputs_are_inferred=False,
    ):
        """Run the parent reference checks, then the serialized checks.

        The serialized checks are skipped when the thread-local
        ``disable_serialized_check`` flag is set. They use ``threshold`` as
        ``rtol`` and, when ``atol`` is None, as ``atol`` too.
        """
        outs = super().assertReferenceChecks(
            device_option,
            op,
            inputs,
            reference,
            input_device_options,
            threshold,
            output_to_grad,
            grad_reference,
            atol,
            outputs_to_check,
            ensure_outputs_are_inferred,
        )
        if not getattr(_output_context, 'disable_serialized_check', False):
            grad_ops = _getGradientOrNone(op)
            rtol = threshold
            if atol is None:
                atol = threshold
            self.assertSerializedOperatorChecks(
                inputs,
                outs,
                grad_ops,
                op,
                device_option,
                atol,
                rtol,
            )

    @contextmanager
    def set_disable_serialized_check(self, val: bool):
        """Temporarily set the thread-local ``disable_serialized_check`` flag.

        The previous value (False if unset) is restored on exit.
        """
        orig = getattr(_output_context, 'disable_serialized_check', False)
        try:
            # pyre-fixme[16]: `local` has no attribute `disable_serialized_check`.
            _output_context.disable_serialized_check = val
            yield
        finally:
            _output_context.disable_serialized_check = orig


def testWithArgs():
    """Parse the serialized-test command-line flags and run ``unittest``.

    Recognized flags are stored on the thread-local context; the remaining
    positional arguments are passed on to ``unittest.main`` via ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-G', '--generate-serialized', action='store_true', dest='generate',
        help='generate output files (default=false, compares to current files)')
    parser.add_argument(
        '-O', '--output', default=DATA_DIR,
        help='output directory (default: %(default)s)')
    parser.add_argument(
        '-D', '--disable-serialized_check', action='store_true', dest='disable',
        help='disable checking serialized tests')
    parser.add_argument(
        '-C', '--disable-gen-coverage', action='store_true',
        dest='disable_coverage',
        help='disable generating coverage markdown file')
    parser.add_argument('unittest_args', nargs='*')
    args = parser.parse_args()
    sys.argv[1:] = args.unittest_args
    _output_context.should_generate_output = args.generate
    _output_context.output_dir = args.output
    _output_context.disable_serialized_check = args.disable
    _output_context.disable_gen_coverage = args.disable_coverage

    import unittest
    unittest.main()