mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: This adds the torch.arccosh alias and updates alias testing to validate the consistency of the aliased and original operations. The alias testing is also updated to run on CPU and CUDA, which revealed a memory leak when tracing (see https://github.com/pytorch/pytorch/issues/43119). Pull Request resolved: https://github.com/pytorch/pytorch/pull/43107 Reviewed By: ngimel Differential Revision: D23156472 Pulled By: mruberry fbshipit-source-id: 6155fac7954fcc49b95e7c72ed917c85e0eabfcd
145 lines
6.2 KiB
Python
import torch
|
|
from torch.testing import FileCheck
|
|
|
|
from torch.testing._internal.common_utils import \
|
|
(run_tests)
|
|
from torch.testing._internal.jit_utils import JitTestCase
|
|
from torch.testing._internal.common_device_type import \
|
|
(instantiate_device_type_tests, skipCPUIfNoLapack, skipCUDAIfNoMagma, onlyCPU)
|
|
|
|
# Information for generating an alias test
|
|
# NOTE: ending the alias_name with an underscore will interpret the test
|
|
# as the test for an inplace method of that name
|
|
class AliasInfo(object):
    """Metadata for generating a single alias test.

    Bundles an aliased op with the original op it maps to, plus callables
    that produce the tensor input and any extra positional arguments for a
    given device.

    NOTE: an alias_name ending in an underscore marks the test as covering
    the inplace method of that name.
    """

    __slots__ = ['alias_name', 'alias_op', 'original_name', 'original_op',
                 'get_input', 'get_args', 'decorators']

    def __init__(self,
                 alias_name,  # name of the alias
                 alias_op,  # the aliased op
                 original_name,  # name of the original function
                 original_op,  # the original op
                 get_input,  # callable (device)->tensor producing the first tensor argument
                 *,
                 get_args=lambda d: (),  # callable (device)->tuple of extra positional args
                 decorators=()):  # decorators to apply to the generated test
        # Assign every attribute by walking the slot list in declaration
        # order, keeping the field names and the argument order in sync.
        values = (alias_name, alias_op, original_name, original_op,
                  get_input, get_args, decorators)
        for slot, value in zip(self.__slots__, values):
            setattr(self, slot, value)
|
|
|
|
# Registry of alias/original op pairs to test. Each entry supplies a
# device->tensor input factory and (optionally) extra positional args.
alias_infos = (
    AliasInfo('absolute', torch.absolute, 'abs', torch.abs,
              lambda d: torch.randn(20, device=d)),
    # trailing underscore: tested as an inplace Tensor method
    AliasInfo('absolute_', torch.Tensor.absolute_, 'abs_', torch.Tensor.abs_,
              lambda d: torch.randn(20, device=d)),
    AliasInfo('clip', torch.clip, 'clamp', torch.clamp,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (.4, .6)),
    AliasInfo('clip_', torch.Tensor.clip_, 'clamp_', torch.Tensor.clamp_,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (.4, .6)),
    # det needs LAPACK (CPU) / MAGMA (CUDA) backends
    AliasInfo('linalg.det', torch.linalg.det, 'det', torch.det,
              lambda d: torch.randn(10, 10, device=d),
              decorators=(skipCPUIfNoLapack, skipCUDAIfNoMagma)),
    # NOTE: only runs on CPU because it leaks CUDA memory
    # (see https://github.com/pytorch/pytorch/issues/43119)
    AliasInfo('outer', torch.outer, 'ger', torch.ger,
              lambda d: torch.randn(20, device=d), get_args=lambda d: (torch.randn(20, device=d),),
              decorators=(onlyCPU,)),
    # +2 shifts inputs into acosh's domain (x >= 1)
    AliasInfo('arccosh', torch.arccosh, 'acosh', torch.acosh,
              lambda d: torch.randn(20, device=d) + 2),
    AliasInfo('arccosh_', torch.Tensor.arccosh_, 'acosh_', torch.Tensor.acosh_,
              lambda d: torch.randn(20, device=d) + 2),
)
|
|
|
|
# Placeholder test class for validating that aliases are correctly
|
|
# translated when scripted and traced
|
|
class TestOpNormalization(JitTestCase):
    # Intentionally empty: test methods are attached dynamically by
    # create_alias_tests(...) before device-type instantiation.
    pass
|
|
|
|
# Generates alias tests and adds them to the specified class (cls)
|
|
def create_alias_tests(cls):
    """Generates alias tests from alias_infos and attaches them to cls.

    For each AliasInfo two tests are added:
      - test_jit_op_alias_normalization_<alias>: checks that scripting and
        tracing both remap the alias to its original op in the JIT graph.
      - test_alias_computation_<alias>: checks that the alias and the
        original op compute identical results.
    """
    for info in alias_infos:

        # Tests that the JIT remaps aliases to their original ops
        def _test_jit_op_alias_normalization(self, device, info=info):
            # `tensor` and `op` must be local so the CompilationUnit can
            # resolve them when they appear in the generated source (tensor
            # args are inlined via str(), which renders as tensor([...])).
            tensor = torch.tensor
            op = info.alias_op
            is_inplace = info.alias_name.endswith('_')

            # Checks that scripting converts aliases
            # NOTE: the code to test scripting must be generated since
            #   scripting does not support splatting args or directly
            #   calling torch.Tensor methods. The following
            #   splats args after the first tensor by inlining them as constants.
            if is_inplace:
                fn_template = '''
                    def _fn(t):
                        return t.{alias_name}({args})
                '''
                arg_string = ', '.join((str(arg) for arg in info.get_args(device)))
                script = fn_template.format(alias_name=info.alias_name, args=arg_string)
            else:
                fn_template = '''
                    def _fn(t):
                        return op(t{args})
                '''
                # leading ", " separates the extra args from the tensor arg
                arg_string = ", " + ', '.join((str(arg) for arg in info.get_args(device)))
                script = fn_template.format(args=arg_string)

            # Compiles script
            scripted = torch.jit.CompilationUnit(script)._fn

            # Acquires and checks the graph remaps the alias
            inp = info.get_input(device)
            scripted(inp.clone())
            graph = scripted.graph_for(inp.clone())
            FileCheck().check(info.original_name).check_not(info.alias_name).run(graph)

            # Checks that tracing converts aliases
            # NOTE: tracing has no problem splatting args
            args = info.get_args(device)

            def _fn(t, info=info, args=args):
                return info.alias_op(t, *args)

            traced = torch.jit.trace(_fn, (inp.clone(),))
            traced(inp.clone())
            graph = traced.graph_for(inp.clone())
            FileCheck().check(info.original_name).check_not(info.alias_name).run(graph)

        # Applies decorators
        for decorator in info.decorators:
            _test_jit_op_alias_normalization = decorator(_test_jit_op_alias_normalization)

        test_name = "test_jit_op_alias_normalization_" + info.alias_name
        setattr(cls, test_name, _test_jit_op_alias_normalization)

        # Tests that the alias functions perform the same operation as the original
        def _test_alias_computation(self, device, info=info):
            alias_op = info.alias_op
            original_op = info.original_op

            inp = info.get_input(device)
            args = info.get_args(device)
            alias_result = alias_op(inp.clone(), *args)
            # BUG FIX: this previously called alias_op again, making the
            # comparison vacuous — it must call the original op.
            original_result = original_op(inp.clone(), *args)

            # atol=0/rtol=0: the alias must produce bitwise-equal results
            self.assertEqual(alias_result, original_result, atol=0, rtol=0)

        # Applies decorators
        for decorator in info.decorators:
            _test_alias_computation = decorator(_test_alias_computation)

        test_name = "test_alias_computation_" + info.alias_name
        setattr(cls, test_name, _test_alias_computation)
|
|
|
|
|
|
# Attach the generated alias tests to the placeholder class, then create
# per-device (CPU/CUDA) variants of each test.
create_alias_tests(TestOpNormalization)
instantiate_device_type_tests(TestOpNormalization, globals())

if __name__ == '__main__':
    run_tests()
|