Revert D32175963: Converting hardswish to structured kernels with metatensor support

Test Plan: revert-hammer

Differential Revision: D32175963 (57335a9ee3)

Original commit changeset: f4d749c6aeaf

fbshipit-source-id: 6d68a96cf872c2d7b518c061875b9336bca0043a
Alban Desmaison 2021-11-05 07:01:30 -07:00 committed by Facebook GitHub Bot
parent 4d5338228f
commit bb8978f605
10 changed files with 142 additions and 167 deletions

View File

@@ -165,10 +165,6 @@ TORCH_META_FUNC(softshrink_backward) (
build_borrowing_binary_op(maybe_get_output(), grad, self);
}
TORCH_META_FUNC(hardswish)(const Tensor& self) {
build_unary_op(maybe_get_output(), self);
}
TORCH_META_FUNC(gelu) (const Tensor & self) {
build_unary_op(maybe_get_output(), self);
}
@@ -389,13 +385,34 @@ Tensor hardtanh_backward(const Tensor& grad_output, const Tensor& self, const Sc
return iter.output();
}
TORCH_IMPL_FUNC(hardswish_out)(const Tensor& self, const Tensor& result) {
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
Tensor hardswish(const Tensor& self) {
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
if (xnnpack::use_hardswish(self)) {
xnnpack::hardswish_out(self, result);
return xnnpack::hardswish(self);
}
#endif
hardswish_stub(device_type(), *this);
#endif
Tensor result;
auto iter = TensorIterator::unary_op(result, self);
hardswish_stub(iter.device_type(), iter);
return iter.output();
}
Tensor& hardswish_out(const Tensor& self, Tensor& result) {
auto iter = TensorIterator::unary_op(result, self);
hardswish_stub(iter.device_type(), iter);
return result;
}
Tensor& hardswish_(Tensor& self) {
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
if (xnnpack::use_hardswish(self)) {
xnnpack::hardswish_(self);
return self;
}
#endif
auto iter = TensorIterator::unary_op(self, self);
hardswish_stub(iter.device_type(), iter);
return self;
}
Tensor hardswish_backward(const Tensor& grad_output, const Tensor& self) {
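All of the entry points in this hunk (hardswish, hardswish_out, hardswish_) feed the same elementwise stub, which computes hardswish(x) = x * relu6(x + 3) / 6. A quick reference check from Python, assuming a torch build that exposes F.hardswish:

```python
import torch
import torch.nn.functional as F

x = torch.randn(8)
# Reference formula: hardswish(x) = x * clamp(x + 3, 0, 6) / 6
ref = x * (x + 3).clamp(min=0, max=6) / 6
print(torch.allclose(F.hardswish(x), ref))  # True
```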

View File

@@ -23,6 +23,10 @@ using softplus_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10:
using softplus_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using threshold_fn = void (*)(TensorIteratorBase&, const c10::Scalar&, const c10::Scalar&);
using hardtanh_backward_fn = void (*)(TensorIterator&, const c10::Scalar&, const c10::Scalar&);
using hardsigmoid_fn = void(*)(TensorIteratorBase&);
using hardsigmoid_backward_fn = void(*)(TensorIteratorBase&);
using hardswish_fn = void(*)(TensorIterator&);
using hardswish_backward_fn = void(*)(TensorIterator&);
using shrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using softshrink_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
using shrink_backward_fn = void (*)(TensorIteratorBase&, const c10::Scalar&);
@@ -42,10 +46,10 @@ DECLARE_DISPATCH(threshold_fn, threshold_stub);
DECLARE_DISPATCH(structured_activation_fn, GeluKernel);
DECLARE_DISPATCH(structured_activation_backward_fn, GeluBackwardKernel);
DECLARE_DISPATCH(hardtanh_backward_fn, hardtanh_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, hardsigmoid_stub);
DECLARE_DISPATCH(structured_activation_backward_fn, hardsigmoid_backward_stub);
DECLARE_DISPATCH(structured_activation_fn, hardswish_stub);
DECLARE_DISPATCH(activation_backward_fn, hardswish_backward_stub);
DECLARE_DISPATCH(hardsigmoid_fn, hardsigmoid_stub);
DECLARE_DISPATCH(hardsigmoid_backward_fn, hardsigmoid_backward_stub);
DECLARE_DISPATCH(hardswish_fn, hardswish_stub);
DECLARE_DISPATCH(hardswish_backward_fn, hardswish_backward_stub);
DECLARE_DISPATCH(shrink_fn, hardshrink_stub);
DECLARE_DISPATCH(softshrink_fn, softshrink_stub);
DECLARE_DISPATCH(shrink_backward_fn, shrink_backward_stub);

View File

@@ -336,7 +336,7 @@ void hardtanh_backward_kernel(TensorIterator& iter, const Scalar& min, const Sca
});
}
void hardswish_kernel(TensorIteratorBase& iter) {
void hardswish_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES(iter.dtype(), "hardswish_cpu", [&]() {
const scalar_t zero(0.0f);
const scalar_t three(3.0f);

View File

@@ -438,7 +438,7 @@ void leaky_relu_backward_kernel(TensorIteratorBase& iter, const Scalar& negval_)
});
}
void hardswish_kernel(TensorIteratorBase& iter) {
void hardswish_kernel(TensorIterator& iter) {
AT_DISPATCH_FLOATING_TYPES_AND2(at::ScalarType::Half, at::ScalarType::BFloat16, iter.dtype(), "hardswish_cuda", [&]() {
using T_ACC = acc_type<scalar_t, true>;
const T_ACC zero(0.0f);

View File

@@ -8685,22 +8685,22 @@
QuantizedCPU: hardtanh_quantized_cpu_
- func: hardswish.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
structured: True
structured_inherits: TensorIteratorBase
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU, CUDA: hardswish_out
- func: hardswish(Tensor self) -> Tensor
structured_delegate: hardswish.out
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU, CUDA: hardswish
- func: hardswish_(Tensor(a!) self) -> Tensor(a!)
structured_delegate: hardswish.out
device_check: NoCheck # TensorIterator
python_module: nn
dispatch:
CPU, CUDA: hardswish_
- func: hardswish_backward(Tensor grad_output, Tensor self) -> Tensor
python_module: nn
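Dropping structured_delegate here means hardswish no longer gets a meta kernel for free, so shape-only (meta tensor) dispatch is expected to fail once only CPU/CUDA kernels are registered. A rough illustration from Python; the exact exception type is an assumption and depends on the build:

```python
import torch
import torch.nn.functional as F

x = torch.empty(2, 3, device="meta")
try:
    # Succeeds when hardswish is structured and therefore has a meta kernel
    print(F.hardswish(x).shape)  # torch.Size([2, 3])
except (RuntimeError, NotImplementedError):
    # Expected when only CPU/CUDA kernels are registered for the op
    print("hardswish has no meta kernel in this build")
```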

View File

@@ -32,21 +32,16 @@ Tensor empty_with_tail_padding(
maybe_names);
}
bool is_padded_contiguous(
Tensor allocate_padded_contiguous_if_needed(
const Tensor& input,
const c10::MemoryFormat memory_format) {
const auto* const allocator = input.storage().allocator();
const auto* const mobile_allocator = c10::GetDefaultMobileCPUAllocator();
return (allocator == mobile_allocator) && input.is_contiguous(memory_format);
}
Tensor allocate_padded_contiguous_if_needed(
const Tensor& input,
const c10::MemoryFormat memory_format) {
// If the allocators are the same and the memory is contiguous in the requested
// format, then there is no need to reallocate the tensor.
if (is_padded_contiguous(input, memory_format)) {
if ((allocator == mobile_allocator) && input.is_contiguous(memory_format)) {
return input;
}

View File

@@ -6,10 +6,6 @@ namespace at {
namespace native {
namespace mobile {
bool is_padded_contiguous(
const Tensor& input,
const c10::MemoryFormat memory_format);
Tensor allocate_padded_contiguous_if_needed(
const Tensor& input,
c10::MemoryFormat memory_format);

View File

@@ -1,76 +1,97 @@
#ifdef USE_XNNPACK
#include <ATen/native/utils/Factory.h>
#include <ATen/native/xnnpack/Common.h>
#include <ATen/native/utils/Factory.h>
namespace at {
namespace native {
namespace xnnpack {
namespace {
void hardswish_impl(const Tensor& input, const Tensor& output) {
bool use_hardswish(
const Tensor& input) {
return xnnpack::internal::available() &&
(1 <= input.ndimension()) &&
(input.device().is_cpu()) &&
(kFloat == input.scalar_type()) &&
!input.requires_grad() &&
true;
}
Tensor& hardswish_impl(Tensor& input, Tensor& output) {
using namespace internal;
xnn_operator_t hardswish_op{};
const xnn_status create_status = xnn_create_hardswish_nc_f32(
1, // channels
1, // input stride
1, // output stride
0, // flags
&hardswish_op);
1, // channels
1, // input stride
1, // output stride
0, // flags
&hardswish_op);
TORCH_CHECK(
xnn_status_success == create_status,
"xnn_create_hardswish_nc_f32 failed!");
xnn_status_success == create_status,
"xnn_create_hardswish_nc_f32 failed!");
Operator hardswish_scoped_op(hardswish_op);
const xnn_status setup_status = xnn_setup_hardswish_nc_f32(
hardswish_op,
input.numel(), // Batch
input.data_ptr<float>(),
output.data_ptr<float>(),
caffe2::pthreadpool_()); // threadpool
hardswish_op,
input.numel(), // Batch
input.data_ptr<float>(),
output.data_ptr<float>(),
caffe2::pthreadpool_()); // threadpool
TORCH_CHECK(
xnn_status_success == setup_status, "xnn_setup_hardswish_nc_f32 failed!");
xnn_status_success == setup_status,
"xnn_setup_hardswish_nc_f32 failed!");
const xnn_status run_status = xnn_run_operator(
hardswish_op,
caffe2::pthreadpool_()); // threadpool
hardswish_op,
caffe2::pthreadpool_()); // threadpool
TORCH_INTERNAL_ASSERT(
xnn_status_success == run_status, "xnn_run_operator failed!");
xnn_status_success == run_status,
"xnn_run_operator failed!");
return output;
}
} // namespace
bool use_hardswish(const Tensor& input) {
return xnnpack::internal::available() && (1 <= input.ndimension()) &&
(input.device().is_cpu()) && (kFloat == input.scalar_type()) &&
!input.requires_grad();
}
const Tensor& hardswish_out(const Tensor& input, const Tensor& result) {
Tensor hardswish(const Tensor& input) {
Tensor padded_input = mobile::allocate_padded_contiguous_if_needed(
input, input.suggest_memory_format());
input, input.suggest_memory_format());
if (mobile::is_padded_contiguous(result, result.suggest_memory_format())) {
hardswish_impl(padded_input, result);
Tensor output = mobile::empty_with_tail_padding(
padded_input.sizes(),
padded_input.options().dtype(),
input.suggest_memory_format(),
padded_input.names());
hardswish_impl(padded_input, output);
return output.contiguous(input.suggest_memory_format());
}
Tensor& hardswish_(Tensor& input) {
Tensor padded_input = mobile::allocate_padded_contiguous_if_needed(
input, input.suggest_memory_format());
// Don't need to allocate output if input is contiguous & already padded
if (input.data_ptr() == padded_input.data_ptr()) {
hardswish_impl(input, input);
return input;
} else {
Tensor output = mobile::empty_with_tail_padding(
padded_input.sizes(),
padded_input.options().dtype(),
input.suggest_memory_format(),
padded_input.names());
padded_input.sizes(),
padded_input.options().dtype(),
input.suggest_memory_format(),
padded_input.names());
hardswish_impl(padded_input, output);
result.copy_(output);
return input.copy_(output);
}
return result;
}
} // namespace xnnpack
} // namespace native
} // namespace at
}
}
}
#endif /* USE_XNNPACK */

View File

@@ -90,7 +90,8 @@ Tensor channel_shuffle(
// Activations
//
bool use_hardswish(const Tensor& input);
const Tensor& hardswish_out(const Tensor& input, const Tensor& result);
Tensor hardswish(const Tensor& input);
Tensor& hardswish_(Tensor& input);
} // namespace xnnpack
} // namespace native

View File

@@ -1,19 +1,19 @@
# Owner(s): ["module: nn"]
import io
import itertools
import math
import pickle
import random
import string
import unittest
import io
import unittest.mock as mock
import itertools
import warnings
from collections import OrderedDict
import pickle
from copy import deepcopy
from functools import partial, reduce
from itertools import product, repeat
from itertools import repeat, product
from functools import reduce, partial
from operator import mul
from collections import OrderedDict
import torch
@@ -21,102 +21,43 @@ import torch
# NN tests use double as the default dtype
torch.set_default_dtype(torch.double)
from hypothesis import given
from torch._six import inf, nan
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.nn.utils.rnn as rnn_utils
from torch.nn.utils import clip_grad_norm_, clip_grad_value_
import torch.nn.utils.parametrize as parametrize
import torch.nn.utils.prune as prune
import torch.nn.utils.rnn as rnn_utils
import torch.testing._internal.hypothesis_utils as hu
from torch._six import inf, nan
from torch.nn import MultiheadAttention, Parameter
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.nn import Parameter
from torch.nn.parameter import UninitializedParameter, UninitializedBuffer
from torch.nn.parallel._functions import Broadcast
from torch.nn.parameter import UninitializedBuffer, UninitializedParameter
from torch.nn.utils import (
clip_grad_norm_,
clip_grad_value_,
parameters_to_vector,
vector_to_parameters,
)
from torch.testing._internal.common_cuda import (
TEST_CUDA,
TEST_CUDNN,
TEST_CUDNN_VERSION,
TEST_MULTIGPU,
tf32_is_not_fp32,
tf32_off,
tf32_on,
tf32_on_and_off,
)
from torch.testing._internal.common_device_type import (
deviceCountAtLeast,
dtypes,
dtypesIfCUDA,
expectedFailureMeta,
get_all_device_types,
instantiate_device_type_tests,
largeTensorTest,
onlyCPU,
onlyCUDA,
onlyNativeDeviceTypes,
precisionOverride,
skipCUDAIf,
skipCUDAIfCudnnVersionLessThan,
skipCUDAIfNoCudnn,
skipCUDAIfNotMiopenSuggestNHWC,
skipCUDAIfNotRocm,
skipCUDAIfRocm,
skipCUDAIfRocmVersionLessThan,
skipMeta,
)
from torch.testing._internal.common_dtype import (
get_all_fp_dtypes,
get_all_math_dtypes,
integral_types,
)
from torch.testing._internal.common_nn import (
CriterionTest,
NewModuleTest,
NNTestCase,
criterion_tests,
ctcloss_reference,
loss_reference_fns,
module_tests,
new_module_tests,
single_batch_reference_fn,
)
from torch.testing._internal.common_utils import (
ALL_TENSORTYPES,
ALL_TENSORTYPES2,
GRADCHECK_NONDET_TOL,
IS_PPC,
TEST_NUMPY,
TEST_SCIPY,
TEST_WITH_ROCM,
TEST_WITH_UBSAN,
TemporaryFileName,
TestCase,
_assertGradAndGradgradChecks,
download_file,
dtype2prec_DONTUSE,
freeze_rng_state,
get_function_arglist,
gradcheck,
gradgradcheck,
load_tests,
repeat_test_for_types,
run_tests,
skipIfNoLapack,
skipIfNotMiopenSuggestNHWC,
skipIfRocm,
skipIfRocmVersionLessThan,
suppress_warnings,
)
from torch.testing._internal.common_dtype import integral_types, get_all_fp_dtypes, get_all_math_dtypes
from torch.testing._internal.common_utils import freeze_rng_state, run_tests, TestCase, skipIfNoLapack, skipIfRocm, \
skipIfRocmVersionLessThan, skipIfNotMiopenSuggestNHWC, TEST_NUMPY, TEST_SCIPY, TEST_WITH_ROCM, download_file, \
get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \
ALL_TENSORTYPES2, suppress_warnings, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
module_tests, criterion_tests, loss_reference_fns, \
ctcloss_reference, new_module_tests, single_batch_reference_fn
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
dtypesIfCUDA, precisionOverride, skipCUDAIfNoCudnn, skipCUDAIfCudnnVersionLessThan, onlyCUDA, onlyCPU, \
skipCUDAIfRocm, skipCUDAIf, skipCUDAIfNotRocm, skipCUDAIfRocmVersionLessThan, skipCUDAIfNotMiopenSuggestNHWC, \
onlyNativeDeviceTypes, deviceCountAtLeast, largeTensorTest, expectedFailureMeta, skipMeta, get_all_device_types
from torch.nn import MultiheadAttention
from hypothesis import given
import torch.testing._internal.hypothesis_utils as hu
from torch.testing._internal.common_utils import _assertGradAndGradgradChecks, gradcheck, gradgradcheck, \
GRADCHECK_NONDET_TOL
from torch.testing._internal.common_utils import dtype2prec_DONTUSE
from torch.testing._internal.common_cuda import tf32_on_and_off, tf32_is_not_fp32, tf32_off, tf32_on
from torch.types import _TensorOrTensors
AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
# load_tests from common_utils is used to automatically filter tests for
@@ -124,8 +65,8 @@ AMPERE_OR_ROCM = TEST_WITH_ROCM or tf32_is_not_fp32()
load_tests = load_tests
if TEST_SCIPY:
import scipy.ndimage
from scipy import stats
import scipy.ndimage
if TEST_NUMPY:
import numpy as np
@@ -478,7 +419,6 @@ class TestNN(NNTestCase):
def test_conv_backcompat(self):
from torch.serialization import SourceChangeWarning
# This file was generated by running on PyTorch 1.0.1 on Python 2:
#
# import torch
@@ -18028,6 +17968,7 @@ class TestNNDeviceType(NNTestCase):
with self.assertRaisesRegex(RuntimeError, "call out-of-place version"):
b.backward(torch.ones(2, device=device))
@expectedFailureMeta # https://github.com/pytorch/pytorch/issues/54897
def test_hardswish_inplace_overlap(self, device):
x = torch.randn((1, 6), device=device).expand((6, 6))
with self.assertRaisesRegex(RuntimeError, 'unsupported operation'):
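The test above exercises the internal-overlap guard: an in-place hardswish through an expanded (self-overlapping) view must be rejected rather than silently writing the same memory several times. A minimal standalone sketch of the same failure mode, assuming a recent torch build:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 6).expand(6, 6)  # every row aliases the same underlying storage
try:
    F.hardswish(x, inplace=True)
except RuntimeError as e:
    print(e)  # expected to complain about an unsupported operation on overlapping memory
```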