Named inference rule for abs. (#22151)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22151
ghimport-source-id: 54c1726b578ac162af817f78df6f540b764e46e3

Test Plan:
- `python test/test_namedtensor.py` [namedtensor ci]

Imported from OSS

Differential Revision: D15970326

Pulled By: zou3519

fbshipit-source-id: 4ea25f0a73bbc24b604d3ded2027eeb4ce800de0
Author: Richard Zou, 2019-06-27 12:37:55 -07:00 (committed by Facebook Github Bot)
Parent: 2913f6a26d
Commit: 6386e4d244
5 changed files with 72 additions and 0 deletions
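
For context, a minimal sketch of the behavior this commit enables, at the Python level (assumes a build with named tensors enabled, with `names=` factory support as exercised in test/test_namedtensor.py):

    import torch

    # A 2-D tensor with dimension names 'N' and 'D'.
    t = torch.empty(2, 3, names=('N', 'D'))
    # After this commit, abs propagates the input's names to its output.
    assert t.abs().names == ('N', 'D')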


@@ -603,6 +603,10 @@ def device_guard(option, dispatch_options, dispatch_tensor):
 def named_guard(option, tensors, tensorlists):
     if not option.get('named_guard', True) or (len(tensors) + len(tensorlists) == 0):
         return ''
+    # Override: skip the guard for _th_ functions even when named_guard =
+    # True, because there is always some at:: function that calls them.
+    if option['name'].startswith('_th_'):
+        return ''
     named_conditions = []
     for tensor in tensors:
         named_conditions.append('{}.is_named()'.format(tensor))
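
For reference, `named_guard` returns a C++ guard snippet built from these conditions. A minimal sketch of the shape of the emitted guard; the template text and error message below are assumptions, not taken from this diff:

    # Hypothetical sketch: join the collected conditions into a C++ guard
    # string (the exact template and message are assumptions).
    named_conditions = ['self.is_named()']
    guard = (
        'if ({cond}) {{\n'
        '  AT_ERROR("abs is not yet supported with named tensors");\n'
        '}}\n'
    ).format(cond=' || '.join(named_conditions))
    print(guard)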


@@ -16,6 +16,9 @@
 #include <ATen/Parallel.h>
 #include <ATen/native/UnaryOps.h>
 #include <ATen/native/TensorIterator.h>
+#ifdef NAMEDTENSOR_ENABLED
+#include <ATen/NamedTensorUtils.h>
+#endif
 #include <algorithm>
 #include <cmath>
@@ -137,6 +140,14 @@ Tensor& _sigmoid_out_cpu(Tensor& result, const Tensor& self) {
   return result;
 }
 
+static void propagate_names(Tensor& result, const Tensor& src) {
+#ifdef NAMEDTENSOR_ENABLED
+  if (src.is_named()) {
+    at::internal_set_names_inplace(result, src.names());
+  }
+#endif
+}
+
 // NB: If you use this macro, you may also need to add a CUDA forwarding
 // stub in CUDAUnaryOps
@@ -154,6 +165,7 @@ Tensor& _sigmoid_out_cpu(Tensor& result, const Tensor& self) {
   assert_no_internal_overlap(result, #op);            \
   auto iter = TensorIterator::unary_op(result, self); \
   op##_stub(iter->device_type(), *iter);              \
+  propagate_names(result, self);                      \
   return result;                                      \
 }
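
The `propagate_names` helper above copies names onto the result only when the source tensor is named. A Python-level sketch of the intended semantics (assumes a named-tensor-enabled build; in recent PyTorch, unnamed dimensions read back as None):

    import torch

    named = torch.empty(2, 3, names=('N', 'D'))
    plain = torch.empty(2, 3)  # no names set
    assert named.abs().names == ('N', 'D')            # src.is_named(): names copied
    assert all(n is None for n in plain.abs().names)  # unnamed input stays unnamed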


@@ -92,14 +92,17 @@
 - func: abs(Tensor self) -> Tensor
   variants: function, method
+  named_guard: False
 
 - func: abs_(Tensor(a!) self) -> Tensor(a!)
   variants: function, method
+  named_guard: False
   dispatch:
     CPU: _abs__cpu
     CUDA: _abs__cuda
 
 - func: abs(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
+  named_guard: False
   dispatch:
     CPU: _abs_out_cpu
     CUDA: _abs_out_cuda
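
With `named_guard: False` on all three entries, the generated dispatch no longer rejects named tensors for abs, and the kernels above handle name propagation for each variant. A usage sketch mirroring the test below (assumes a named-tensor-enabled build):

    import torch

    t = torch.empty(2, 2, names=('N', 'D'))
    assert torch.abs(t).names == ('N', 'D')  # function/method variant
    assert t.abs_().names == ('N', 'D')      # in-place variant
    out = t.new_empty([0])
    torch.abs(t, out=out)                    # out= variant resizes and fills out
    assert out.names == ('N', 'D')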


@@ -3,6 +3,9 @@
 #else
 
 #include <ATen/MemoryOverlap.h>
+#ifdef NAMEDTENSOR_ENABLED
+#include <ATen/NamedTensorUtils.h>
+#endif
 
 void THCTensor_(cbitand)(THCState* state, THCTensor *self_, THCTensor *src1, THCTensor *src2)
 {
@@ -172,6 +175,15 @@ void THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor *src, sca
 #if !defined(THC_REAL_IS_BOOL)
 
+static void propagate_names(THCTensor* result, THCTensor* src) {
+#ifdef NAMEDTENSOR_ENABLED
+  if (at::impl::internal_is_named(src)) {
+    const auto names = at::impl::internal_get_names(src);
+    at::impl::internal_set_names_inplace(result, names);
+  }
+#endif
+}
+
 #define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_(NAME, CFUNC, REAL)                        \
   struct Tensor_##NAME##_##REAL##_Op {                                              \
     __device__ __forceinline__ void operator()(scalar_t* out, scalar_t* in) const { \
@@ -199,6 +211,7 @@ void THCTensor_(cminValue)(THCState *state, THCTensor *self, THCTensor *src, sca
   }                                                                                 \
                                                                                     \
   THCudaCheck(cudaGetLastError());                                                  \
+  propagate_names(self_, src);                                                      \
 }
 
 #define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(NAME, CFUNC, REAL)                         \
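
These THC changes mirror the CPU path, so the CUDA kernels stamped out by IMPLEMENT_CUDA_TENSOR_BASIC_FUNC_ propagate names as well. A sketch of the observable effect at the Python level (assumes a CUDA build with named tensors enabled):

    import torch

    if torch.cuda.is_available():
        t = torch.empty(2, 3, names=('N', 'C'), device='cuda')
        # propagate_names(self_, src) runs after the kernel launch and
        # copies the names onto the CUDA result.
        assert t.abs().names == ('N', 'C')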


@@ -1,6 +1,7 @@
 import unittest
 from common_utils import TestCase, run_tests
 from common_cuda import TEST_CUDA
+import itertools
 import torch
 import sys
@@ -69,6 +70,45 @@ class TestNamedTensor(TestCase):
     def test_empty_cuda(self):
         self._test_factory(torch.empty, 'cuda')
 
+    def test_unary_fns(self):
+        def _test(lambd, names=('N', 'D'), device='cpu'):
+            sizes = [2] * len(names)
+            tensor = torch.empty(sizes, names=names, device=device)
+            out = lambd(tensor)
+            self.assertEqual(out.names, tensor.names)
+
+        def method(name, *args, **kwargs):
+            return [lambda t: getattr(t, name)(*args, **kwargs)]
+
+        def out_function(name, *args, **kwargs):
+            out_fn = getattr(torch, name)
+
+            def fn(tensor):
+                result = tensor.new_empty([0])
+                out_fn(tensor, *args, out=result, **kwargs)
+                return result
+
+            return [fn]
+
+        def fn_method_and_inplace(name, *args, **kwargs):
+            return (
+                method(name, *args, **kwargs) +
+                method(name + '_', *args, **kwargs) +
+                out_function(name, *args, **kwargs)
+            )
+
+        def flatten(lst):
+            return [item for sublist in lst for item in sublist]
+
+        tests = [
+            fn_method_and_inplace('abs'),
+        ]
+        tests = flatten(tests)
+        for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()):
+            _test(testcase, device=device)
+
     def test_using_seen_interned_string_doesnt_bump_refcount(self):
         def see_name():
             seen_name = 'N'
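
Note the harness above is written so follow-up commits can cover more unary ops by appending entries to `tests`; a hypothetical extension (the extra op names are illustrative, not part of this commit):

    tests = [
        fn_method_and_inplace('abs'),
        # Hypothetical future entries, one per op that gains a name
        # inference rule:
        # fn_method_and_inplace('ceil'),
        # fn_method_and_inplace('neg'),
    ]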