mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Reviewed By: florazzz Differential Revision: D50380126 Pull Request resolved: https://github.com/pytorch/pytorch/pull/111494 Approved by: https://github.com/Skylion007
299 lines
9.6 KiB
Python
299 lines
9.6 KiB
Python
|
|
|
|
|
|
import inspect
|
|
import os
|
|
import shutil
|
|
import sys
|
|
import tempfile
|
|
import threading
|
|
from contextlib import contextmanager
|
|
from zipfile import ZipFile
|
|
|
|
import argparse
|
|
import hypothesis as hy
|
|
import numpy as np
|
|
|
|
import caffe2.python.hypothesis_test_util as hu
|
|
from caffe2.proto import caffe2_pb2
|
|
from caffe2.python import gradient_checker
|
|
from caffe2.python.serialized_test import coverage
|
|
|
|
# Per-thread holder for the run settings (output directory, generate/compare
# mode, disabled checks) that testWithArgs() stores and the test case reads.
_output_context = threading.local()

# Name of the sub-directory, below the data directory, that holds the
# serialized operator tests.
operator_test_type = 'operator_test'

# Name of the data directory itself.
DATA_SUFFIX = 'data'

# Directory containing this module, with symlinks resolved.
TOP_DIR = os.path.dirname(os.path.realpath(__file__))

# Default data directory: the `data` folder next to this module.
DATA_DIR = os.path.join(TOP_DIR, DATA_SUFFIX)
|
|
|
|
|
|
def given(*given_args, **given_kwargs):
    """Replacement for ``hypothesis.given`` used by serialized tests.

    All positional and keyword arguments are forwarded unchanged to
    ``hy.given``. The decorated test body is executed twice per call, each
    time through a hypothesis wrapper pinned to seed 0 with
    ``max_examples=1``:

    1. first with ``self.should_serialize = True``;
    2. then with ``self.should_serialize = False``.

    Returns:
        A decorator for test methods taking ``self`` as first argument.
    """
    def _pinned(test_fn):
        # Build a fresh hypothesis wrapper: seed 0, a single example.
        with_strategies = hy.given(*given_args, **given_kwargs)(test_fn)
        single_example = hy.settings(max_examples=1)(with_strategies)
        return hy.seed(0)(single_example)

    def wrapper(f):
        # One separate hypothesis wrapper per run.
        plain_run = _pinned(f)
        serializing_run = _pinned(f)

        def func(self, *args, **kwargs):
            # The serializing run goes first, then the plain run.
            for flag, runner in ((True, serializing_run), (False, plain_run)):
                self.should_serialize = flag
                runner(self, *args, **kwargs)

        return func

    return wrapper
|
|
|
|
|
|
def _getGradientOrNone(op_proto):
|
|
try:
|
|
grad_ops, _ = gradient_checker.getGradientForOp(op_proto)
|
|
return grad_ops
|
|
except Exception:
|
|
return []
|
|
|
|
|
|
# necessary to support converting jagged lists into numpy arrays
|
|
def _transformList(l):
|
|
ret = np.empty(len(l), dtype=object)
|
|
for (i, arr) in enumerate(l):
|
|
ret[i] = arr
|
|
return ret
|
|
|
|
|
|
def _prepare_dir(path):
|
|
if os.path.exists(path):
|
|
shutil.rmtree(path)
|
|
os.makedirs(path)
|
|
|
|
|
|
class SerializedTestCase(hu.HypothesisTestCase):
    """Hypothesis test case that checks operators against serialized data.

    On top of the parent's reference checks, ``assertReferenceChecks`` either
    writes the operator, its gradient operators, inputs and outputs to a zip
    archive (generate mode) or compares the current results against a
    previously written archive (compare mode). The mode and the data
    directory are read from the module-level thread-local ``_output_context``.
    """

    # Toggled per run by the module-level ``given`` decorator; the serialized
    # checks only do anything while this is True.
    should_serialize = False

    def get_output_dir(self):
        """Return the directory that holds the serialized test archives.

        The preferred location is ``<output_dir>/<operator_test_type>``, where
        ``output_dir`` comes from ``_output_context`` (default ``DATA_DIR``).
        If that path does not exist, a fallback below the current working
        directory is returned instead, built from this module's dotted
        package path. The fallback is returned without checking that it
        exists.
        """
        output_dir_arg = getattr(_output_context, 'output_dir', DATA_DIR)
        output_dir = os.path.join(output_dir_arg, operator_test_type)

        if os.path.exists(output_dir):
            return output_dir

        # fall back to pwd
        cwd = os.getcwd()
        # Drop the module's own name to keep only its package path.
        serialized_util_module_components = __name__.split('.')
        serialized_util_module_components.pop()
        serialized_dir = '/'.join(serialized_util_module_components)
        output_dir_fallback = os.path.join(cwd, serialized_dir, DATA_SUFFIX)
        return os.path.join(output_dir_fallback, operator_test_type)

    def get_output_filename(self):
        """Return ``<test file>.<test function>`` for the running test.

        ``<test file>`` is the part before the first dot of the basename of
        the file defining this test class; ``<test function>`` is the last
        dotted component of ``self.id()``.
        """
        class_path = inspect.getfile(self.__class__)
        test_file = os.path.basename(class_path).split('.')[0]
        test_function = self.id().split('.')[-1]
        return test_file + '.' + test_function

    def serialize_test(self, inputs, outputs, grad_ops, op, device_option):
        """Write the test data to ``<output dir>/<test name>.zip``.

        The archive contains ``op.pb`` (the serialized operator),
        ``grad_<i>.pb`` for every entry of ``grad_ops`` and ``inout.npz``
        holding ``inputs``, ``outputs`` and the integer ``device_type`` of
        ``device_option``. Files are staged in a scratch directory named after
        the test, which is removed afterwards even if writing fails.
        """
        output_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        full_dir = os.path.join(output_dir, test_name)
        _prepare_dir(full_dir)

        try:
            inputs = _transformList(inputs)
            outputs = _transformList(outputs)
            device_type = int(device_option.device_type)

            op_path = os.path.join(full_dir, 'op.pb')
            grad_paths = []
            inout_path = os.path.join(full_dir, 'inout')

            with open(op_path, 'wb') as f:
                f.write(op.SerializeToString())
            for (i, grad) in enumerate(grad_ops):
                grad_path = os.path.join(full_dir, 'grad_{}.pb'.format(i))
                grad_paths.append(grad_path)
                with open(grad_path, 'wb') as f:
                    f.write(grad.SerializeToString())

            # savez_compressed appends '.npz' to the given path.
            np.savez_compressed(
                inout_path,
                inputs=inputs,
                outputs=outputs,
                device_type=device_type)

            with ZipFile(os.path.join(output_dir, test_name + '.zip'), 'w') as z:
                z.write(op_path, 'op.pb')
                z.write(inout_path + '.npz', 'inout.npz')
                for path in grad_paths:
                    z.write(path, os.path.basename(path))
        finally:
            # Never leave the staging directory behind, even on failure.
            shutil.rmtree(full_dir)

    def compare_test(self, inputs, outputs, grad_ops, atol=1e-7, rtol=1e-7):
        """Compare current results against the serialized archive of the test.

        The archive ``<output dir>/<test name>.zip`` is extracted to a
        temporary directory. If the given ``inputs`` differ from the
        serialized ones, the serialized operator is run on the serialized
        inputs and its outputs and gradient operators are used instead of the
        ones passed in. Outputs are then compared with
        ``np.testing.assert_allclose`` and every gradient operator with
        ``_assertSameOps``.

        The temporary directory is removed and the ``.npz`` file is closed
        even when a comparison fails.
        """

        def parse_proto(x):
            proto = caffe2_pb2.OperatorDef()
            proto.ParseFromString(x)
            return proto

        source_dir = self.get_output_dir()
        test_name = self.get_output_filename()
        temp_dir = tempfile.mkdtemp()
        try:
            with ZipFile(os.path.join(source_dir, test_name + '.zip')) as z:
                z.extractall(temp_dir)

            op_path = os.path.join(temp_dir, 'op.pb')
            inout_path = os.path.join(temp_dir, 'inout.npz')

            # load serialized input and output
            # NOTE(review): allow_pickle=True unpickles the archive contents,
            # so the data directory must only hold trusted files.
            with np.load(
                    inout_path, encoding='bytes', allow_pickle=True) as loaded:
                loaded_inputs = loaded['inputs'].tolist()
                loaded_outputs = loaded['outputs'].tolist()
                # NOTE: zip() stops at the shorter sequence, so only the
                # common prefix of the two input lists is compared.
                inputs_equal = all(
                    np.array_equal(x, y)
                    for (x, y) in zip(inputs, loaded_inputs))

                # if inputs are not the same, run serialized input through
                # serialized op
                if not inputs_equal:
                    # load operator
                    with open(op_path, 'rb') as f:
                        loaded_op = f.read()

                    op_proto = parse_proto(loaded_op)
                    device_type = loaded['device_type']
                    device_option = caffe2_pb2.DeviceOption(
                        device_type=int(device_type))

                    outputs = hu.runOpOnInput(
                        device_option, op_proto, loaded_inputs)
                    grad_ops = _getGradientOrNone(op_proto)

            # assert outputs are equal
            for (x, y) in zip(outputs, loaded_outputs):
                np.testing.assert_allclose(x, y, atol=atol, rtol=rtol)

            # assert gradient op is equal
            for (i, grad_op) in enumerate(grad_ops):
                grad_path = os.path.join(temp_dir, 'grad_{}.pb'.format(i))
                with open(grad_path, 'rb') as f:
                    loaded_grad = f.read()
                grad_proto = parse_proto(loaded_grad)
                self._assertSameOps(grad_proto, grad_op)
        finally:
            # Clean up the extracted files whether or not the checks passed.
            shutil.rmtree(temp_dir)

    def _assertSameOps(self, op1, op2):
        """Assert two operator definitions are equal, ignoring argument order.

        Both operators are copied first, so sorting their ``arg`` lists by
        name does not modify the operators passed in.
        """
        op1_ = caffe2_pb2.OperatorDef()
        op1_.CopyFrom(op1)
        op1_.arg.sort(key=lambda arg: arg.name)

        op2_ = caffe2_pb2.OperatorDef()
        op2_.CopyFrom(op2)
        op2_.arg.sort(key=lambda arg: arg.name)

        self.assertEqual(op1_, op2_)

    def assertSerializedOperatorChecks(
        self,
        inputs,
        outputs,
        gradient_operator,
        op,
        device_option,
        atol=1e-7,
        rtol=1e-7,
    ):
        """Serialize or compare the test data, depending on the run mode.

        Does nothing unless ``self.should_serialize`` is true. In generate
        mode (``_output_context.should_generate_output``) the data is written
        with ``serialize_test`` and, unless ``disable_gen_coverage`` is set,
        the coverage report is regenerated. Otherwise the data is checked
        with ``compare_test`` using ``atol`` and ``rtol``.
        """
        if not self.should_serialize:
            return

        if getattr(_output_context, 'should_generate_output', False):
            self.serialize_test(
                inputs, outputs, gradient_operator, op, device_option)
            if not getattr(_output_context, 'disable_gen_coverage', False):
                coverage.gen_serialized_test_coverage(
                    self.get_output_dir(), TOP_DIR)
        else:
            self.compare_test(
                inputs, outputs, gradient_operator, atol, rtol)

    def assertReferenceChecks(
        self,
        device_option,
        op,
        inputs,
        reference,
        input_device_options=None,
        threshold=1e-4,
        output_to_grad=None,
        grad_reference=None,
        atol=None,
        outputs_to_check=None,
        ensure_outputs_are_inferred=False,
    ):
        """Run the parent's reference checks, then the serialized checks.

        All arguments are forwarded positionally to the parent
        implementation. Unless ``_output_context.disable_serialized_check``
        is set, the value it returns is then passed as the outputs to
        ``assertSerializedOperatorChecks``, with ``rtol = threshold`` and
        ``atol`` defaulting to ``threshold`` when not given.

        Returns:
            Whatever the parent ``assertReferenceChecks`` returned.
        """
        outs = super().assertReferenceChecks(
            device_option,
            op,
            inputs,
            reference,
            input_device_options,
            threshold,
            output_to_grad,
            grad_reference,
            atol,
            outputs_to_check,
            ensure_outputs_are_inferred,
        )
        if not getattr(_output_context, 'disable_serialized_check', False):
            grad_ops = _getGradientOrNone(op)
            rtol = threshold
            if atol is None:
                atol = threshold
            self.assertSerializedOperatorChecks(
                inputs,
                outs,
                grad_ops,
                op,
                device_option,
                atol,
                rtol,
            )
        # Hand the parent's result back to the caller instead of dropping it.
        return outs

    @contextmanager
    def set_disable_serialized_check(self, val: bool):
        """Temporarily set ``disable_serialized_check`` on ``_output_context``.

        The previous value (``False`` if it was never set) is restored when
        the ``with`` block exits, including on error.
        """
        orig = getattr(_output_context, 'disable_serialized_check', False)
        try:
            # pyre-fixme[16]: `local` has no attribute `disable_serialized_check`.
            _output_context.disable_serialized_check = val
            yield
        finally:
            _output_context.disable_serialized_check = orig
|
|
|
|
|
|
def testWithArgs():
    """Parse the serialized-test command-line flags, then run ``unittest``.

    The parsed flags are stored on the thread-local ``_output_context`` as
    ``should_generate_output``, ``output_dir``, ``disable_serialized_check``
    and ``disable_gen_coverage``. The remaining positional arguments replace
    ``sys.argv[1:]`` before ``unittest.main()`` is called.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '-G', '--generate-serialized', action='store_true', dest='generate',
        help='generate output files (default=false, compares to current files)')
    cli.add_argument(
        '-O', '--output', default=DATA_DIR,
        help='output directory (default: %(default)s)')
    cli.add_argument(
        '-D', '--disable-serialized_check', action='store_true', dest='disable',
        help='disable checking serialized tests')
    cli.add_argument(
        '-C', '--disable-gen-coverage', action='store_true',
        dest='disable_coverage',
        help='disable generating coverage markdown file')
    cli.add_argument('unittest_args', nargs='*')
    parsed = cli.parse_args()

    # Leave only the pass-through arguments for unittest to parse.
    sys.argv[1:] = parsed.unittest_args

    context_settings = (
        ('should_generate_output', parsed.generate),
        ('output_dir', parsed.output),
        ('disable_serialized_check', parsed.disable),
        ('disable_gen_coverage', parsed.disable_coverage),
    )
    for setting_name, setting_value in context_settings:
        setattr(_output_context, setting_name, setting_value)

    import unittest
    unittest.main()
|