mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Add SqueezeOp in MKLDNN
Summary: SqueezeOp support to drop drop dims of size 1. MKLMemory now supports Reshape() if the buffer is in plain layout, in which case just the dims and layouts are modified similar to caffe2::Tensor. SqueezeOp takes care of converting the input to plain layout if needed via an intermediate buffer before calling Reshape(). Differential Revision: D6735656 fbshipit-source-id: 953309498370e1b8986e8c593bc6963f38036255
This commit is contained in:
parent
e64ad91365
commit
231d6f7b09
85
caffe2/mkl/operators/squeeze_op.cc
Normal file
85
caffe2/mkl/operators/squeeze_op.cc
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "caffe2/mkl/mkl_utils.h"
|
||||
#include "caffe2/operators/expand_squeeze_dims_op.h"
|
||||
|
||||
#ifdef CAFFE2_HAS_MKL_DNN
|
||||
|
||||
namespace caffe2 {
|
||||
namespace mkl {
|
||||
|
||||
template <typename T>
|
||||
class MKLSqueezeOp final : public MKLOperator<T> {
|
||||
public:
|
||||
USE_MKLOPERATOR_FUNCTIONS(T);
|
||||
|
||||
MKLSqueezeOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: MKLOperator<T>(operator_def, ws),
|
||||
dims_(OperatorBase::GetRepeatedArgument<int>("dims")) {
|
||||
auto originalSize = dims_.size();
|
||||
CAFFE_ENFORCE(originalSize > 0, "Parameter `dims` must be provided.");
|
||||
|
||||
std::sort(dims_.begin(), dims_.end());
|
||||
dims_.erase(std::unique(dims_.begin(), dims_.end()), dims_.end());
|
||||
if (dims_.size() < originalSize) {
|
||||
LOG(WARNING) << "Parameter `dims` has repeated dimensions.";
|
||||
}
|
||||
CAFFE_ENFORCE(dims_.front() >= 0, "Dimension ids must be non-negative.");
|
||||
}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
const auto& X = Input(0);
|
||||
auto* Y = Output(0);
|
||||
|
||||
CAFFE_ENFORCE_GT(
|
||||
X.ndim(),
|
||||
dims_.back(),
|
||||
"Input needs at least ",
|
||||
(dims_.back() + 1),
|
||||
" dimensions.");
|
||||
const auto& new_dims = SqueezeOp<MKLContext>::ComputeDims(X.dims(), dims_);
|
||||
|
||||
bool dims_changed;
|
||||
CHECK_INPUT_DIMS(X, dims_changed);
|
||||
if (dims_changed) {
|
||||
// Temp buffer mainly to convert the input to plain layout before
|
||||
// Reshape() if the input has a custom layout.
|
||||
buffer_.Reset(X.dims());
|
||||
}
|
||||
|
||||
// Always copy to temp buffer to avoid subsequent runs throwing layout
|
||||
// mismatch errors for X.
|
||||
buffer_.CopyFrom(X);
|
||||
Y->Reset(X.dims(), nullptr, dnnResourceNumber, true);
|
||||
CAFFE_ENFORCE(dnnLayoutCompare<T>(buffer_.layout(), Y->layout()));
|
||||
CAFFE_ENFORCE(Y->ShareFrom(buffer_));
|
||||
Y->Reshape(new_dims);
|
||||
return true;
|
||||
}
|
||||
|
||||
private:
|
||||
vector<int> dims_;
|
||||
vector<TIndex> cached_input_dims_;
|
||||
};
|
||||
|
||||
} // namespace mkl
|
||||
|
||||
REGISTER_MKL_OPERATOR(Squeeze, mkl::MKLSqueezeOp<float>);
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE2_HAS_MKL_DNN
|
||||
|
|
@ -251,6 +251,41 @@ class MKLMemory {
|
|||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resizes the tensor without touching underlying storage.
|
||||
* This requires the total size of the tensor to remains constant.
|
||||
*/
|
||||
template <typename IndexType>
|
||||
void Reshape(const vector<IndexType>& dims) {
|
||||
CAFFE_ENFORCE(
|
||||
layout_is_user_layout_,
|
||||
"Reshape is not allowed for custom layouts. "
|
||||
"Convert to plain layout before invoking Reshape().");
|
||||
|
||||
TIndex new_size = 1;
|
||||
for (auto i = 0; i < dims.size(); ++i) {
|
||||
CAFFE_ENFORCE_GE_WITH_CALLER(dims[i], 0);
|
||||
new_size *= dims[i];
|
||||
}
|
||||
CAFFE_ENFORCE_WITH_CALLER(
|
||||
new_size == size_,
|
||||
"New size and old size are not equal. Reshape is not possible.");
|
||||
|
||||
vector<TIndex> new_dims(dims.size());
|
||||
vector<size_t> size(dims.size());
|
||||
vector<size_t> strides(dims.size());
|
||||
for (int i = 0; i < dims.size(); ++i) {
|
||||
new_dims[i] = dims[i];
|
||||
size[i] = dims[dims.size() - i - 1];
|
||||
strides[i] = (i == 0) ? 1 : strides[i - 1] * size[i - 1];
|
||||
}
|
||||
dims_ = new_dims;
|
||||
user_layout_.Reset(dims.size(), size.data(), strides.data());
|
||||
layout_.Reset(dims.size(), size.data(), strides.data());
|
||||
convert_in_.Reset(dnnConversionCreate<T>, user_layout_, layout_);
|
||||
convert_out_.Reset(dnnConversionCreate<T>, layout_, user_layout_);
|
||||
}
|
||||
|
||||
// Destructs the MKLMemory.
|
||||
~MKLMemory() {}
|
||||
|
||||
|
|
|
|||
52
caffe2/python/mkl/mkl_squeeze_op_test.py
Normal file
52
caffe2/python/mkl/mkl_squeeze_op_test.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
# Copyright (c) 2016-present, Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
##############################################################################
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import unittest
|
||||
import hypothesis.strategies as st
|
||||
from hypothesis import given
|
||||
import numpy as np
|
||||
from caffe2.python import core, workspace
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
import caffe2.python.mkl_test_util as mu
|
||||
|
||||
|
||||
@unittest.skipIf(
|
||||
not workspace.C.has_mkldnn, "Skipping as we do not have mkldnn."
|
||||
)
|
||||
class MKLSqueezeTest(hu.HypothesisTestCase):
|
||||
@given(
|
||||
squeeze_dims=st.lists(st.integers(0, 3), min_size=1, max_size=3),
|
||||
inplace=st.booleans(),
|
||||
**mu.gcs
|
||||
)
|
||||
def test_mkl_squeeze(self, squeeze_dims, inplace, gc, dc):
|
||||
shape = [
|
||||
1 if dim in squeeze_dims else np.random.randint(1, 5)
|
||||
for dim in range(4)
|
||||
]
|
||||
X = np.random.rand(*shape).astype(np.float32)
|
||||
op = core.CreateOperator(
|
||||
"Squeeze", "X", "X" if inplace else "Y", dims=squeeze_dims
|
||||
)
|
||||
self.assertDeviceChecks(dc, op, [X], [0])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user