mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
This will be the last disruptive functorch internals change. Why are we moving these files? - As a part of rationalizing functorch we are moving the code in functorch/_src to torch/_functorch - This is so that we can offer the functorch APIs as native PyTorch APIs (coming soon) and resolve some internal build issues. Why are we moving all of these files at once? - It's better to break developers all at once rather than many times Test Plan: - wait for tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091 Approved by: https://github.com/anijain2305, https://github.com/ezyang
897 lines
30 KiB
Python
897 lines
30 KiB
Python
import torch
|
|
import copy
|
|
from torch.testing._internal.common_methods_invocations import op_db
|
|
from functorch_additional_op_db import additional_op_db
|
|
from enum import Enum
|
|
import torch._functorch.top_operators_github_usage as top_ops
|
|
import pprint
|
|
import unittest
|
|
import enum
|
|
from torch.testing._internal.common_device_type import toleranceOverride
|
|
|
|
# Importing these files makes modifications to the op_db that we need
|
|
import test_ops # noqa: F401
|
|
import test_vmap # noqa: F401
|
|
|
|
# Every callable reported by torch.overrides.get_testing_overrides().
all_overridable = list(torch.overrides.get_testing_overrides().keys())

# (namespace object, dotted namespace name, path of the .rst file documenting
# it).  The paths are joined onto a PyTorch checkout root by
# get_public_overridable_apis.
public_docs = [
    (namespace, dotted_name, f'docs/source/{rst_file}')
    for namespace, dotted_name, rst_file in (
        (torch.nn.functional, 'torch.nn.functional', 'nn.functional.rst'),
        (torch.fft, 'torch.fft', 'fft.rst'),
        (torch.special, 'torch.special', 'special.rst'),
        (torch.linalg, 'torch.linalg', 'linalg.rst'),
        (torch, 'torch', 'torch.rst'),
        (torch.Tensor, 'torch.Tensor', 'tensors.rst'),
    )
]
|
|
|
|
# torch.abs, Tensor.abs, Tensor.abs_ are all considered to be different
|
|
|
|
|
|
def get_public_overridable_apis(pytorch_root='/raid/rzou/pt/debug-cpu'):
    """Collect the documented APIs that are also overridable.

    For every entry of ``public_docs`` the corresponding .rst file under
    ``pytorch_root`` is read and candidate API names are extracted from it.

    Args:
        pytorch_root: directory the ``public_docs`` paths are relative to.

    Returns:
        dict mapping ``'<module_name>.<api name>'`` to the attribute object
        found on the module, restricted to objects that are keys of
        ``torch.overrides.get_testing_overrides()``.
    """
    directive = '.. autofunction:: '
    overridable = set(torch.overrides.get_testing_overrides().keys())
    found = {}
    for module, module_name, src in public_docs:
        with open(f'{pytorch_root}/{src}') as f:
            doc_lines = f.readlines()

        # APIs either begin with 4 spaces or ".. autofunction::"
        indented = [text.strip() for text in doc_lines if text.startswith(' ' * 4)]
        autofunctions = [text.strip()[len(directive):]
                         for text in doc_lines if text.startswith('.. autofunction::')]

        for candidate in indented + autofunctions:
            # tensors.rst spells methods as "Tensor.<name>"; drop that prefix.
            if candidate.startswith('Tensor.'):
                candidate = candidate[7:]
            if not hasattr(module, candidate):
                continue
            api = getattr(module, candidate)
            if api in overridable:
                found[f'{module_name}.{candidate}'] = api
    return found
|
|
|
|
|
|
# Fully qualified torch.Tensor method names that are skipped by
# get_method_only_ops_we_care_about and
# get_public_overridable_outplace_we_care_about.
denylist = {
    f'torch.Tensor.{method}'
    for method in (
        'data_ptr',
        'dim',
        'element_size',
        'backward',
        'as_strided',
        'register_hook',
        'record_stream',
        'qscheme',
        'ndimension',
        'smm',
        'sspaddmm',
        'retain_grad',
        'sparse_mask',
        'sparse_dim',
        'dense_dim',
        'values',
        'indices',
        'numel',
        'size',
        'nelement',
        'q_scale',
        'q_zero_point',
        'q_per_channel_scales',
        'q_per_channel_zero_points',
        'q_per_channel_axis',
        'int_repr',
        'to_sparse',
        'is_inference',
        'storage',
        'storage_type',
    )
}
|
|
|
|
|
|
def get_method_only_ops_we_care_about():
    """Return the names of Tensor methods that have no ``torch.<name>`` twin.

    Only keys starting with ``'torch.Tensor'`` are considered; entries in
    ``denylist`` and names ending in ``_`` are skipped.

    Returns:
        list of bare method names (the third dotted component of the key).
    """
    apis = get_public_overridable_apis()
    method_only = []
    for key in apis:
        if not key.startswith('torch.Tensor') or key in denylist:
            continue
        api = key.split('.')[2]
        # filter out in-place
        if api.endswith('_'):
            continue
        if f'torch.{api}' not in apis:
            method_only.append(api)
    return method_only
|
|
|
|
# Deduplicates torch.abs and Tensor.abs
|
|
|
|
|
|
def get_public_overridable_ops():
    """Return the public overridable APIs with Tensor-method duplicates removed.

    A ``'torch.Tensor.<name>'`` entry is dropped whenever a ``'torch.<name>'``
    entry is also present, so e.g. torch.abs and Tensor.abs count once.
    """
    results = get_public_overridable_apis()
    # Snapshot the keys: entries are deleted from `results` while we walk them.
    for key in list(results):
        if not key.startswith('torch.Tensor'):
            continue
        api = key.split('.')[2]
        if f'torch.{api}' in results:
            del results[key]
    return results
|
|
|
|
|
|
def get_public_overridable_outplace_ops():
    """Return the de-duplicated public ops whose name does not end in ``_``."""
    results = get_public_overridable_ops()
    # NB: there are no dunder methods bcs we don't document those
    inplace_keys = [key for key in results if key.endswith('_')]
    for key in inplace_keys:
        del results[key]
    return results
|
|
|
|
|
|
def get_public_overridable_outplace_we_care_about():
    """Return the out-of-place public ops that are worth tracking.

    Starting from ``get_public_overridable_outplace_ops()``, an entry is
    dropped when its key

    * contains ``'quant'`` or ``'.q_'`` (quantization),
    * contains ``'.is_'`` (is_cpu, etc.), or
    * is listed in ``denylist``.

    Each key is removed at most once, so a key that matches more than one
    rule (e.g. both the quantization rule and the ``.is_`` rule) no longer
    triggers a second ``del`` and the resulting ``KeyError``.

    Returns:
        dict mapping the qualified op name to the op object.
    """
    def is_excluded(key):
        # quantization
        if 'quant' in key or '.q_' in key:
            return True
        # is_cpu, etc. It doesn't make sense to have OpInfos for these
        if '.is_' in key:
            return True
        return key in denylist

    results = get_public_overridable_outplace_ops()
    return {key: op for key, op in results.items() if not is_excluded(key)}
|
|
|
|
# e.g. nn.functional.softmax
|
|
|
|
|
|
def get_op(dotted_name):
    """Resolve a dotted name (e.g. ``nn.functional.softmax``) under ``torch``.

    Args:
        dotted_name: attribute path relative to the ``torch`` module.

    Returns:
        The attribute found by following the path, or ``None`` as soon as one
        component is missing.
    """
    target = torch
    for attr in dotted_name.split('.'):
        try:
            target = getattr(target, attr)
        except AttributeError:
            return None
    return target
|
|
|
|
# Maps function -> [OpInfo]
|
|
|
|
|
|
def get_ops_covered_by_opinfos():
    """Build a map from callable to the list of OpInfos that exercise it.

    For each entry of ``op_db`` the function resolved from the OpInfo name,
    its method variant and its in-place variant are registered when truthy;
    every alias's ``op`` is registered unconditionally.
    """
    covered = {}
    for opinfo in op_db:
        variants = (
            get_op(opinfo.name),
            opinfo.method_variant,
            opinfo.inplace_variant,
        )
        for fn in variants:
            if fn:
                covered.setdefault(fn, []).append(opinfo)
        for alias in opinfo.aliases:
            covered.setdefault(alias.op, []).append(opinfo)
    return covered
|
|
|
|
|
|
# Names of tensor factory functions; get_top_ops_not_covered_by_opinfo filters
# these out of its result.
# Fixed: 'barlett_window' was a misspelling of torch.bartlett_window, so that
# op was never filtered.  The duplicate 'arange' entry was dropped (no effect
# on the set).
factory_fns = {
    'tensor', 'zeros', 'ones', 'randn', 'arange', 'rand', 'empty', 'randperm',
    'linspace', 'logspace', 'hann_window', 'full', 'eye', 'blackman_window',
    'bartlett_window', 'randint', 'range',
}
|
|
|
|
|
|
def get_top_ops(torch_threshold, nn_fn_threshold, with_counts=False):
    """Return the most used torch / nn.functional operators.

    Entries of ``top_ops.top_torch`` and
    ``top_ops.get_nn_functional_top_list()`` whose name is excluded below are
    dropped, then the first ``torch_threshold`` torch entries and the first
    ``nn_fn_threshold`` nn.functional entries are merged and sorted by their
    second field, largest first.

    Args:
        torch_threshold: how many torch entries to keep.
        nn_fn_threshold: how many nn.functional entries to keep.
        with_counts: when true, return the full entries instead of only the
            names.

    Returns:
        list of names, or list of entries when ``with_counts`` is true.
    """
    # These are either not real "operators", factory functions
    # that trivially work, or not-documented ops.
    excluded = {
        'load', 'no_grad', 'save', 'from_numpy',
        'manual_seed', 'set_grad_enabled',
        'set_default_tensor_type', 'set_num_threads',
        'set_printoptions', 'numel',
        'set_default_dtype', 'sparse_coo_tensor', 'set_rng_state',
        'get_rng_state', 'get_default_dtype', 'initial_seed',
        'get_num_threads', 'quantize_per_tensor',
        'hann_window', 'is_tensor', 'as_tensor',
        'equal', 'enable_grad', 'seed', 'is_storage',
        'is_floating_point', 'nn.functional.torch',
        'set_flush_denormal', 'set_num_interop_threads', 'dequantize',
        'get_num_interop_threads', 'nn.functional.math',
        'nn.functional.threshold_',
        'nn.functional.selu_',
        'nn.functional.elu_',
        'nn.functional.rrelu_',
        'nn.functional.leaky_relu_',
        'nn.functional.hardtanh_',
        'nn.functional.has_torch_function',
        'nn.functional.has_torch_function_unary',
        'nn.functional.has_torch_function_variadic',
        'nn.functional.handle_torch_function',
        'nn.functional.adaptive_max_pool1d_with_indices',
        'nn.functional.adaptive_max_pool2d_with_indices',
        'nn.functional.adaptive_max_pool3d_with_indices',
        'nn.functional.fractional_max_pool2d_with_indices',
        'nn.functional.fractional_max_pool3d_with_indices',
        'is_complex',
        'grad',
        'quantize_per_channel',
        'nn.functional.max_pool2d_with_indices',
        'nn.functional.max_pool3d_with_indices',
        'nn.functional.max_pool1d_with_indices',
        'nn.functional.celu_',
        'nn.functional.grad',
        'nn.functional.relu_',
        'nn.functional.boolean_dispatch',
        'nn.functional.assert_int_or_pair',
        'fft',  # is namespace
    }

    def allowed(entries):
        return [entry for entry in entries if entry[0] not in excluded]

    torch_entries = allowed(top_ops.top_torch)
    nn_fn_entries = allowed(top_ops.get_nn_functional_top_list())
    selected = torch_entries[:torch_threshold] + nn_fn_entries[:nn_fn_threshold]

    # Now, sort by priority
    selected = sorted(selected, key=lambda entry: entry[1], reverse=True)
    if with_counts:
        return selected
    return [entry[0] for entry in selected]
|
|
|
|
|
|
def get_ops_percentage(torch_threshold, nn_fn_threshold):
    """Return the share of recorded usages covered by the selected top ops.

    The numerator sums the usage counts of ``get_top_ops(torch_threshold,
    nn_fn_threshold)``; the denominator does the same for
    ``get_top_ops(999999, 999999)``.  The op named ``'t'`` contributes 0 to
    both sums.
    """
    data = top_ops.top_torch + top_ops.get_nn_functional_top_list()

    def get_num_usages(opname):
        # Ignore this, this is heavily inflated
        if opname == 't':
            return 0
        matches = [entry[1] for entry in data if entry[0] == opname]
        assert len(matches) == 1
        return matches[0]

    def total_usages(op_names):
        return sum([get_num_usages(op_name) for op_name in op_names])

    # get all operators that are not in the denylist
    total_op_usages = total_usages(get_top_ops(999999, 999999))

    # get subset of all operators
    subset_op_usages = total_usages(get_top_ops(torch_threshold, nn_fn_threshold))
    return subset_op_usages / total_op_usages
|
|
|
|
|
|
def get_top_ops_not_covered_by_opinfo(torch_threshold=0, nn_fn_threshold=0):
    """Return the top ops that no OpInfo (or OpInfo alias) is named after.

    Names found in ``denylist`` or ``factory_fns`` are removed as well.
    NOTE(review): ``denylist`` holds ``'torch.Tensor.*'`` keys while the names
    from ``get_top_ops`` appear to be unprefixed, so that filter may never
    match - confirm the intent.
    """
    covered_names = set()
    for opinfo in op_db:
        covered_names.add(opinfo.name)
        covered_names.update(alias.name for alias in opinfo.aliases)

    return [
        op
        for op in get_top_ops(torch_threshold, nn_fn_threshold)
        if op not in covered_names
        and op not in denylist
        and op not in factory_fns
    ]
|
|
|
|
|
|
def get_covered_ops(ops_list, invert=False):
    """Select the entries of ``ops_list`` whose op has (or lacks) an OpInfo.

    Args:
        ops_list: dict mapping a name to an op object.
        invert: when true, keep the ops that are *not* covered instead.

    Returns:
        dict with the selected ``name -> op`` entries.
    """
    ops_covered_by_opinfo = get_ops_covered_by_opinfos()
    want_uncovered = bool(invert)
    return {
        key: op
        for key, op in ops_list.items()
        if (op in ops_covered_by_opinfo) != want_uncovered
    }
|
|
|
|
|
|
class Status(Enum):
    """Two-valued status enum with members ``Correct`` (0) and ``Fast`` (1)."""

    Correct = 0
    Fast = 1
|
|
|
|
|
|
# Names of the tests whose per-op status is tracked by get_statuses,
# transpose_statuses and print_coverage_info.
tests = {
    'test_vmap_exhaustive', 'test_op_has_batch_rule',
    'test_vjp', 'test_vmapvjp', 'test_vmapvjp_has_batch_rule',
    'test_jvp', 'test_vmapjvp',
}
|
|
|
|
|
|
def is_decorateinfo_skip_or_xfail(decorateinfo):
|
|
assert len(decorateinfo.decorators) == 1
|
|
actual_decorator = decorateinfo.decorators[0]
|
|
if isinstance(actual_decorator, toleranceOverride):
|
|
return False
|
|
if actual_decorator == unittest.expectedFailure:
|
|
return True
|
|
# Assume the rest are skips
|
|
return True
|
|
|
|
|
|
def get_all_tested_ops():
    """Return the set of OpInfo names attached to the ops we care about."""
    cared_about = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    return {
        opinfo.name
        for op in get_covered_ops(cared_about).values()
        for opinfo in op_to_opinfo[op]
    }
|
|
|
|
|
|
def get_skipped_or_xfailed_ops_for(test_name):
    """Return the OpInfo names that are skipped or xfailed for ``test_name``.

    Only OpInfos attached to the covered ops we care about are inspected.
    A decorator counts when it has a ``test_name`` attribute equal to
    ``test_name`` and ``is_decorateinfo_skip_or_xfail`` accepts it.
    """
    cared_about = get_public_overridable_outplace_we_care_about()
    op_to_opinfo = get_ops_covered_by_opinfos()
    flagged = set()
    for op in get_covered_ops(cared_about).values():
        for opinfo in op_to_opinfo[op]:
            for decorator in opinfo.decorators:
                targets_test = (hasattr(decorator, 'test_name')
                                and decorator.test_name == test_name)
                if targets_test and is_decorateinfo_skip_or_xfail(decorator):
                    flagged.add(opinfo.name)
    return flagged
|
|
|
|
|
|
def get_statuses(for_subset=None, invert=False):
    """Map each covered op to the tracked tests it passes (or fails).

    A tracked test (see ``tests``) counts as failing for an op when one of the
    op's OpInfos carries a decorator for that test that
    ``is_decorateinfo_skip_or_xfail`` classifies as a skip/xfail.  Tolerance
    overrides are therefore no longer counted as failures, matching
    ``get_skipped_or_xfailed_ops_for`` and ``Operator.no_opinfos_skip_test``.

    Args:
        for_subset: optional container of op names without the leading
            ``'torch.'``; when given, only those ops are reported.
        invert: when true, report the failing tests instead of the passing
            ones.

    Returns:
        dict mapping the qualified op name to a set of test names.
    """
    overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()
    if for_subset is not None:
        overridable_outplace_we_care_about = {
            k: v
            for k, v in overridable_outplace_we_care_about.items()
            # Removes "torch."
            if k[6:] in for_subset
        }
    op_to_opinfo = get_ops_covered_by_opinfos()

    def get_covered_tests(op):
        passing = set(tests)
        for opinfo in op_to_opinfo[op]:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, 'test_name'):
                    continue
                if decorator.test_name not in tests:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    passing.discard(decorator.test_name)
        return passing

    result = {}
    for name, op in get_covered_ops(overridable_outplace_we_care_about).items():
        successful_tests = get_covered_tests(op)
        result[name] = (tests - successful_tests) if invert else successful_tests
    return result
|
|
|
|
|
|
def transpose_statuses(for_subset=None, invert=False):
    """Return ``get_statuses`` regrouped by test: ``test name -> set of ops``.

    Every name in ``tests`` is present in the result, possibly with an empty
    set.
    """
    by_test = {test: set() for test in tests}
    for op, reported_tests in get_statuses(for_subset, invert=invert).items():
        for test in reported_tests:
            by_test[test].add(op)
    return by_test
|
|
|
|
|
|
# Compute each successively narrower view of the public API surface once.
overridable_apis = get_public_overridable_apis()
overridable_ops = get_public_overridable_ops()
overridable_outplace_ops = get_public_overridable_outplace_ops()
overridable_outplace_we_care_about = get_public_overridable_outplace_we_care_about()

# Split the ops we care about by whether an OpInfo covers them.
tested_overridable_outplace_ops = get_covered_ops(
    overridable_outplace_we_care_about)
untested_overridable_outplace_ops = get_covered_ops(
    overridable_outplace_we_care_about, invert=True)

# print("List of OpInfos we need:")
# for key in untested_overridable_outplace_ops.keys():
#     print(key)
# print("-" * 80)
# print("")

print(f'Overridable public APIs: {len(overridable_apis)}')
print(f'Overridable public ops: {len(overridable_ops)}')
print(f'Overridable public outplace ops: {len(overridable_outplace_ops)}')
print(f'Overridable public outplace ops we care about: {len(overridable_outplace_we_care_about)}')
print(f'OpInfo-tested overridable public outplace ops: {len(tested_overridable_outplace_ops)}')
|
|
|
|
|
|
def remove_torch(name):
    """Strip the leading ``'torch.'`` from ``name`` (asserting it is there)."""
    prefix = 'torch.'
    assert name.startswith(prefix)
    return name[len(prefix):]
|
|
|
|
|
|
def get_list_of_all_tests():
    """Return the OpInfo-tested op names with the ``'torch.'`` prefix removed."""
    return {remove_torch(qualified) for qualified in tested_overridable_outplace_ops.keys()}
|
|
|
|
|
|
# Tests for which the number of non-skipped, non-xfailed ops is reported below.
mytest = {
    'test_vmap_exhaustive', 'test_op_has_batch_rule',
    'test_vjp', 'test_vmapvjp', 'test_vmapvjp_has_batch_rule',
}

print('*' * 80)
all_tests = get_list_of_all_tests()
for test in mytest:
    # Ops that are tested and neither skipped nor xfailed for this test.
    result = get_skipped_or_xfailed_ops_for(test)
    diff = len(all_tests - result)
    print(f'{test}: {diff}')
|
|
|
|
|
|
def get_jvp_coverage(subset=None):
    """Print one ``test_jvp`` coverage line for the OpInfo-tested ops.

    The printed fields are, in order:
    - number of ops supporting forward AD that are not skipped/xfailed for
      test_jvp
    - number of ops that support autograd but not forward AD
    - total number of ops considered

    Only the first OpInfo registered for each op is consulted for the
    ``supports_autograd`` / ``supports_forward_ad`` flags.

    Args:
        subset: optional container of op names (without ``'torch.'``) to
            restrict the report to.
    """
    op_to_opinfo = get_ops_covered_by_opinfos()
    ops_dct = tested_overridable_outplace_ops
    if subset is not None:
        ops_dct = {name: op for name, op in ops_dct.items()
                   if remove_torch(name) in subset}

    def names_with_flag(flag):
        return {remove_torch(name)
                for name, fn in ops_dct.items()
                if getattr(op_to_opinfo[fn][0], flag)}

    ops = {remove_torch(name) for name in ops_dct}
    supports_autograd = names_with_flag('supports_autograd')
    supports_forward_ad = names_with_flag('supports_forward_ad')
    assert supports_forward_ad.issubset(supports_autograd)
    assert supports_autograd.issubset(ops)

    failed_ops = get_skipped_or_xfailed_ops_for('test_jvp')

    coverage = len(supports_forward_ad - failed_ops)
    no_forward_ad = len(supports_autograd) - len(supports_forward_ad)
    print(f'test_jvp, {coverage}, {no_forward_ad}, {len(ops)}')
|
|
|
|
|
|
# jvp coverage: first over every tested op, then over the top (100, 25) ops.
get_jvp_coverage()
get_jvp_coverage(get_top_ops(100, 25))
for op in get_top_ops(100, 25):
    print(op)
print('*' * 80)

# Debugging aids:
# result = get_skipped_or_xfailed_ops_for('test_vmap_exhaustive')
# result = get_skipped_or_xfailed_ops_for('test_op_has_batch_rule')
# result = get_skipped_or_xfailed_ops_for('test_vjp')
# result = get_skipped_or_xfailed_ops_for('test_vmapvjp')
# result = get_skipped_or_xfailed_ops_for('test_vmapvjp_has_batch_rule')
# import pdb; pdb.set_trace()

# Per-test count of ops reported by get_statuses.
statuses = transpose_statuses()
for test in tests:
    print(f'{test} coverage {len(statuses[test])}')

method_only_ops = get_method_only_ops_we_care_about()
# for op in method_only_ops:
#     print(f'    {op},')

# Top ops that still lack an OpInfo, with their usage counts.
top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(100, 25)
print('=' * 80)
for op in top_ops_not_covered_by_opinfo:
    print(f'{op}, {top_ops.usage_count[op]}')

# The same listing was previously run with other thresholds -
# (200, 50), (220, 92) and (999, 999):
# print("top ops not covered by opinfo: ")
# top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(200, 50)
# for op in top_ops_not_covered_by_opinfo:
#     print(f'{op}, {top_ops.usage_count[op]}')
|
|
|
|
|
|
def remove_from_set(parent, to_remove):
    """Remove from ``parent``, in place, every element of ``to_remove`` it holds.

    Elements of ``to_remove`` that are absent from ``parent`` are ignored.
    """
    for candidate in to_remove:
        if candidate not in parent:
            continue
        parent.remove(candidate)
|
|
|
|
|
|
def print_coverage_info(th=100, nn=25):
    """Print which of the top ops fail each tracked test.

    Args:
        th: number of top torch ops to consider.
        nn: number of top nn.functional ops to consider.

    The failing sets come from ``transpose_statuses(..., invert=True)``;
    known exemptions are removed before counting.  The test_jvp and
    test_vmapjvp entries are neither counted nor pretty-printed.
    """
    print('=' * 80)
    print(f"top {th}, {nn} coverage")
    statuses = transpose_statuses(get_top_ops(th, nn), invert=True)
    top_ops_not_covered_by_opinfo = get_top_ops_not_covered_by_opinfo(th, nn)

    # testing problems
    exemptions = {
        'torch.nn.functional.dropout',  # randomness
    }

    # Allowed exemptions
    vmap_exemptions = {
        'torch.randn_like',  # randomness
        'torch.rand_like',  # randomness
        'torch.allclose',  # number output
        'torch.unique',  # dynamic
        'torch.nonzero',  # dynamic
        'torch.masked_select',  # dynamic
        'torch.prod',  # dynamic (backward)
        'torch.norm',  # norm with nuc is not commonly used; we support the other cases.
        'torch.svd',  # There isn't a bug, it is just nondeterministic so we can't test it.
        'torch.nn.functional.embedding',  # We support everything except the sparse option.
    }
    vmap_tests = (
        'test_vmap_exhaustive',
        'test_vmapvjp',
        'test_vmapvjp_has_batch_rule',
        'test_op_has_batch_rule',
        'test_vmapjvp',
    )
    for vmap_test in vmap_tests:
        remove_from_set(statuses[vmap_test], vmap_exemptions)
    for test in tests:
        remove_from_set(statuses[test], exemptions)

    print(f"total ops in set: {th + nn}")
    print(f"tested by OpInfo: {th + nn - len(top_ops_not_covered_by_opinfo)}")

    # We don't care about these yet
    ignored_tests = ('test_jvp', 'test_vmapjvp')
    for test in tests:
        if test not in ignored_tests:
            print(f'{test} failing coverage {len(statuses[test])}')
    for ignored in ignored_tests:
        del statuses[ignored]

    pprint.pprint(statuses)
|
|
|
|
|
|
def get_name_to_opinfo_map():
    """Build a map from op name to the list of OpInfos registered for it.

    Both ``op_db`` and ``additional_op_db`` are scanned; each OpInfo is filed
    under its own name and under the name of each of its aliases.
    """
    name_to_opinfos = {}
    for opinfo in (op_db + additional_op_db):
        name_to_opinfos.setdefault(opinfo.name, []).append(opinfo)
        for alias in opinfo.aliases:
            name_to_opinfos.setdefault(alias.name, []).append(opinfo)
    return name_to_opinfos


# Built once at import time; Operator looks its OpInfos up here.
NAME_TO_OPINFO = get_name_to_opinfo_map()
|
|
|
|
|
|
class Support(enum.Enum):
    """Answer returned by the ``Operator.supports_*`` queries."""

    NO = 0
    YES = 1
    # Returned by Operator when the operator has no OpInfo to consult.
    UNKNOWN = 2
|
|
|
|
|
|
# Factory functions: every Operator.supports_* query answers YES for these.
FACTORY_FNS = {
    'tensor',
    'zeros',
    'ones',
    'randn',
    'arange',
    'rand',
    'empty',
    'range',
    'full',
    'randperm',
    'eye',
    'randint',
    'linspace',
    'logspace',
}

# Ops reported as supported by Operator.supports_vjp regardless of OpInfo skips.
# All of these: not actually problem, randomness testing artifact
VJP_EXEMPTIONS = {
    'nn.functional.dropout',
    'nn.functional.dropout2d',
    'nn.functional.rrelu',
    'bernoulli',
    'normal',
}

# Ops reported as supported by the vmap-flavoured Operator queries.
VMAP_EXEMPTIONS = {
    # randomness
    'randn_like',
    'rand_like',
    'nn.functional.dropout',
    'nn.functional.dropout2d',
    'bernoulli',
    'multinomial',
    'normal',
    # number output
    'allclose',
    # dynamic
    'unique',
    'nonzero',
    'masked_select',
    # dynamic (backward)
    'prod',
    # norm with nuc is not commonly used; we support the other cases.
    'norm',
    # There isn't a bug, it is just nondeterministic so we can't test it.
    'svd',
    # We support everything except the sparse option.
    'nn.functional.embedding',
}

# Ops reported as supported by Operator.supports_jvp regardless of OpInfo skips.
# All of these: not actually problem, randomness testing artifact
JVP_EXEMPTIONS = {
    'nn.functional.dropout',
    'nn.functional.dropout2d',
    'nn.functional.rrelu',
    'normal',
    'bernoulli',
}
|
|
|
|
|
|
class Operator:
    """An operator name together with the OpInfos registered under it.

    Instances compare equal and hash by ``name``, so a set of operators holds
    at most one entry per name.  (Previously only ``__hash__`` was defined,
    which left equality identity-based and defeated that de-duplication.)
    """

    # we have support (see OpInfo), testing artifact:
    #   nn.functional.dropout2d, nn.functional.dropout
    # exception: we dont even support double backward for this:
    #   nn.functional.hardswish
    # bernoulli and normal are not differentiable
    _JVPVJP_EXEMPTIONS = frozenset({
        'nn.functional.dropout2d',
        'nn.functional.dropout',
        'nn.functional.hardswish',
        'bernoulli',
        'normal',
    })

    _VMAPJVP_EXEMPTIONS = frozenset({
        'prod',  # dynamic (backward)
        'nn.functional.batch_norm',  # testing problem
        'normal',  # not actually problem, randomness testing artifact
        'bernoulli',  # not actually problem, randomness testing artifact
        'nn.functional.dropout2d',  # not actually problem, randomness testing artifact
        'nn.functional.dropout',  # not actually problem, randomness testing artifact
        # Not a problem.
        # It's just that the max_norm testing mutates inputs...
        # (we have our own functorch variant of the OpInfo without max_norm)
        'nn.functional.embedding',
    })

    def __init__(self, name):
        self.name = name
        # None when no OpInfo is registered under this name.
        self.opinfos = NAME_TO_OPINFO.get(name, None)
        assert self.opinfos is None or len(self.opinfos) > 0

    def has_opinfo(self):
        """Return whether at least one OpInfo is registered for this name."""
        return self.opinfos is not None

    def __repr__(self):
        return f'Operator("{self.name}")'

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if not isinstance(other, Operator):
            return NotImplemented
        return self.name == other.name

    def no_opinfos_skip_test(self, test_name):
        """Returns NO if any opinfos have a skip or xfail for the test"""
        if not self.has_opinfo():
            return Support.UNKNOWN
        for opinfo in self.opinfos:
            for decorator in opinfo.decorators:
                if not hasattr(decorator, 'test_name'):
                    continue
                if decorator.test_name != test_name:
                    continue
                if is_decorateinfo_skip_or_xfail(decorator):
                    return Support.NO
        return Support.YES

    def any_opinfo_attr(self, attr):
        """Return whether ``attr`` is truthy on any OpInfo.

        Raises:
            RuntimeError: if the operator has no OpInfo.
        """
        if not self.has_opinfo():
            raise RuntimeError()
        return any(getattr(opinfo, attr) for opinfo in self.opinfos)

    def all_opinfo_attr(self, attr):
        """Return whether ``attr`` is truthy on every OpInfo.

        Raises:
            RuntimeError: if the operator has no OpInfo.
        """
        if not self.has_opinfo():
            raise RuntimeError()
        return all(getattr(opinfo, attr) for opinfo in self.opinfos)

    def _is_exempt(self, exemptions):
        """Factory functions and listed exemptions always count as supported."""
        return self.name in FACTORY_FNS or self.name in exemptions

    def _skip_status(self, exemptions, test_name):
        """YES for exempt ops, otherwise whatever the OpInfo skips say."""
        if self._is_exempt(exemptions):
            return Support.YES
        return self.no_opinfos_skip_test(test_name)

    def _forward_ad_status(self, exemptions, test_name):
        """Like ``_skip_status`` but NO when forward AD is missing.

        An op that supports autograd on some OpInfo but lacks
        ``supports_forward_ad`` on any OpInfo is reported as NO.
        """
        if self._is_exempt(exemptions):
            return Support.YES
        if not self.has_opinfo():
            return Support.UNKNOWN
        if self.any_opinfo_attr('supports_autograd') and \
                not self.all_opinfo_attr('supports_forward_ad'):
            return Support.NO
        return self.no_opinfos_skip_test(test_name)

    def supports_vjp(self):
        return self._skip_status(VJP_EXEMPTIONS, 'test_vjp')

    def supports_vmap(self):
        return self._skip_status(VMAP_EXEMPTIONS, 'test_vmap_exhaustive')

    def supports_fast_vmap(self):
        return self._skip_status(VMAP_EXEMPTIONS, 'test_op_has_batch_rule')

    def supports_vmapvjp(self):
        return self._skip_status(VMAP_EXEMPTIONS, 'test_vmapvjp')

    def supports_fast_vmapvjp(self):
        return self._skip_status(VMAP_EXEMPTIONS, 'test_vmapvjp_has_batch_rule')

    def supports_jvp(self):
        return self._forward_ad_status(JVP_EXEMPTIONS, 'test_jvp')

    def supports_jvpvjp(self):
        return self._skip_status(self._JVPVJP_EXEMPTIONS, 'test_jvpvjp')

    def _supports_vmapjvp_base(self, test):
        return self._forward_ad_status(self._VMAPJVP_EXEMPTIONS, test)

    def supports_vmapjvp(self):
        return self._supports_vmapjvp_base('test_vmapjvpall')

    def supports_fast_vmapjvp(self):
        return self._supports_vmapjvp_base('test_vmapjvpall_has_batch_rule')
|
|
|
|
|
|
class OperatorSet:
    """A set of ``Operator`` objects that can be queried for support status."""

    def __init__(self, operators):
        self.data = set(operators)

    @classmethod
    def from_names(cls, names):
        """Build a set with one ``Operator`` per name.

        Uses ``cls`` (not a hard-coded ``OperatorSet``) so that subclasses get
        an instance of their own type from every alternate constructor.
        """
        return cls([Operator(name) for name in names])

    @classmethod
    def from_top_ops_threshold(cls, torch_threshold, nn_fn_threshold):
        """Build the set from ``get_top_ops(torch_threshold, nn_fn_threshold)``."""
        names = get_top_ops(torch_threshold, nn_fn_threshold)
        return cls.from_names(names)

    @classmethod
    def from_top125(cls):
        return cls.from_top_ops_threshold(100, 25)

    @classmethod
    def from_top160(cls):
        return cls.from_top_ops_threshold(107, 53)

    @classmethod
    def all(cls):
        """Build the set from every public out-of-place op we care about.

        The ``'torch.Tensor.'`` or ``'torch.'`` prefix is stripped from each
        key; a key with neither prefix raises ``AssertionError``.
        """
        torch_tensor = 'torch.Tensor.'
        torch_dot = 'torch.'
        names_sanitized = []
        for n in get_public_overridable_outplace_we_care_about().keys():
            # Check the longer prefix first: it also starts with 'torch.'.
            if n.startswith(torch_tensor):
                names_sanitized.append(n[len(torch_tensor):])
            elif n.startswith(torch_dot):
                names_sanitized.append(n[len(torch_dot):])
            else:
                raise AssertionError()
        return cls.from_names(names_sanitized)

    def query(self, operator_method, filter=(Support.NO, Support.YES, Support.UNKNOWN)):
        """Group the operators by the value ``operator_method`` returns.

        Args:
            operator_method: callable taking an operator and returning a
                status.
            filter: the statuses to report; operators with any other status
                are left out.

        Returns:
            dict with one (possibly empty) set of operators per entry of
            ``filter``.
        """
        result = {key: set() for key in filter}
        for op in self.data:
            support_status = operator_method(op)
            if support_status in filter:
                result[support_status].add(op)
        return result

    def summary(self):
        """Return a CSV-style table: one ``check, yes, no, unknown`` row per check."""
        checks = [
            'supports_vjp',
            'supports_vmap',
            'supports_fast_vmap',
            'supports_vmapvjp',
            'supports_fast_vmapvjp',
            'supports_jvp',
            'supports_vmapjvp',
            'supports_fast_vmapjvp',
            'supports_jvpvjp',
        ]
        result = ['test, yes, no, unknown']
        for check in checks:
            accessor = getattr(Operator, check)
            all_results = self.query(accessor)
            yes_amt = len(all_results[Support.YES])
            no_amt = len(all_results[Support.NO])
            unknown_amt = len(all_results[Support.UNKNOWN])
            result.append(f'{check}, {yes_amt}, {no_amt}, {unknown_amt}')
        return '\n'.join(result)
|
|
|
|
|
|
# ---- Summary over every public out-of-place op we care about ----
opset = OperatorSet.all()
has_no_opinfo = opset.query(Operator.has_opinfo, (False,))

print(f'{"=" * 30} Summary {"=" * 30}')
print(f'% of usages on github: {get_ops_percentage(99999, 99999)}')
print(opset.summary())

# sanity checks
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)

# ---- Top 60: 35 torch ops + 25 nn.functional ops ----
print(f'{"=" * 30} Top 60 Summary {"=" * 30}')
print(f'% of usages on github: {get_ops_percentage(35, 25)}')
opset = OperatorSet.from_top_ops_threshold(35, 25)
# Debugging aids - inspect the unsupported/unknown ops for a given check:
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
print(opset.summary())

# ---- Top 125: 100 torch ops + 25 nn.functional ops ----
print(f'{"=" * 30} Top 125 Summary {"=" * 30}')
print(f'% of usages on github: {get_ops_percentage(100, 25)}')
opset = OperatorSet.from_top125()
# result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)

print("supports_vjp")
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)

print("supports_jvp")
result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)

print("supports_vmapjvp")
result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)

print("supports_jvpvjp")
result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
pprint.pprint(result)
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)

print(opset.summary())

# print("=" * 30 + " Top 160 Summary " + "=" * 30)
# opset = OperatorSet.from_top160()
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
# pprint.pprint(result)
# print(opset.summary())

# Print list of everything in order
# all_ops = get_top_ops(999999, 999999, with_counts=True)
# for op, count in all_ops:
#     print(f'{op}, {count}')
|