add modulo operator

Summary: as desc.

Reviewed By: chocjy

Differential Revision: D6240026

fbshipit-source-id: fa4dcccebc44b0a713946823b6f56e73d5d6146b
This commit is contained in:
Xianjie Chen 2017-11-06 16:18:53 -08:00 committed by Facebook Github Bot
parent 84067bc17d
commit cbb03b8db8
3 changed files with 172 additions and 0 deletions

View File

@ -0,0 +1,60 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "caffe2/operators/mod_op.h"
#include "caffe2/core/operator.h"
#include "caffe2/core/tensor.h"
namespace caffe2 {
template <>
template <typename T>
bool ModOp<CPUContext>::DoRunWithType() {
auto& data = Input(DATA);
auto N = data.size();
const auto* data_ptr = data.template data<T>();
auto* output = Output(0);
output->ResizeLike(Input(DATA));
auto* output_ptr = output->template mutable_data<T>();
for (auto i = 0; i < N; i++) {
output_ptr[i] = data_ptr[i] % divisor_;
}
return true;
}
namespace {
REGISTER_CPU_OPERATOR(Mod, ModOp<CPUContext>);
OPERATOR_SCHEMA(Mod)
.NumInputs(1)
.NumOutputs(1)
.Arg("divisor", "The divisor of the modulo operation. Must >= 1")
.IdenticalTypeAndShape()
.AllowInplace({{0, 0}})
.SetDoc(R"DOC(
Elementwise modulo operation. Each element in the output is the modulo result
of the corresponding elment in the input data. The divisor of the modulo is
provided by the operator argument `divisor`.
)DOC")
.Input(0, "data", "input int32 or int64 data")
.Output(0, "output", "output of data with modulo operation applied");
SHOULD_NOT_DO_GRADIENT(ModOp);
} // namespace
} // namespace caffe2

52
caffe2/operators/mod_op.h Normal file
View File

@ -0,0 +1,52 @@
/**
* Copyright (c) 2016-present, Facebook, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef CAFFE_OPERATORS_MOD_OP_H_
#define CAFFE_OPERATORS_MOD_OP_H_
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
namespace caffe2 {
template <class Context>
class ModOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
ModOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws) {
divisor_ = OperatorBase::GetSingleArgument<int64_t>("divisor", -1);
CAFFE_ENFORCE_GE(divisor_, 1, "divisor must be given with value >= 1");
}
bool RunOnDevice() override {
return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(DATA));
}
template <typename T>
bool DoRunWithType();
protected:
INPUT_TAGS(DATA);
private:
int64_t divisor_;
};
} // namespace caffe2
#endif // CAFFE_OPERATORS_MOD_OP_H_

View File

@ -0,0 +1,60 @@
# Copyright (c) 2016-present, Facebook, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
##############################################################################
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals
from caffe2.python import core
from hypothesis import given
import caffe2.python.hypothesis_test_util as hu
import hypothesis.strategies as st
import numpy as np
@st.composite
def _data(draw):
dtype = draw(st.sampled_from([np.int32, np.int64]))
return draw(hu.tensor(dtype=dtype))
class TestMod(hu.HypothesisTestCase):
@given(
data=_data(),
divisor=st.integers(min_value=1, max_value=np.iinfo(np.int64).max),
inplace=st.booleans(),
**hu.gcs_cpu_only
)
def test_mod(self, data, divisor, inplace, gc, dc):
def ref(data):
output = data % divisor
return [output]
op = core.CreateOperator(
'Mod',
['data'],
['data' if inplace else 'output'],
divisor=divisor,
)
self.assertReferenceChecks(gc, op, [data], ref)
if __name__ == "__main__":
import unittest
unittest.main()