mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
add modulo operator
Summary: as desc. Reviewed By: chocjy Differential Revision: D6240026 fbshipit-source-id: fa4dcccebc44b0a713946823b6f56e73d5d6146b
This commit is contained in:
parent
84067bc17d
commit
cbb03b8db8
60
caffe2/operators/mod_op.cc
Normal file
60
caffe2/operators/mod_op.cc
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#include "caffe2/operators/mod_op.h"
|
||||
|
||||
#include "caffe2/core/operator.h"
|
||||
#include "caffe2/core/tensor.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <>
|
||||
template <typename T>
|
||||
bool ModOp<CPUContext>::DoRunWithType() {
|
||||
auto& data = Input(DATA);
|
||||
auto N = data.size();
|
||||
const auto* data_ptr = data.template data<T>();
|
||||
|
||||
auto* output = Output(0);
|
||||
output->ResizeLike(Input(DATA));
|
||||
auto* output_ptr = output->template mutable_data<T>();
|
||||
|
||||
for (auto i = 0; i < N; i++) {
|
||||
output_ptr[i] = data_ptr[i] % divisor_;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
REGISTER_CPU_OPERATOR(Mod, ModOp<CPUContext>);
|
||||
OPERATOR_SCHEMA(Mod)
|
||||
.NumInputs(1)
|
||||
.NumOutputs(1)
|
||||
.Arg("divisor", "The divisor of the modulo operation. Must >= 1")
|
||||
.IdenticalTypeAndShape()
|
||||
.AllowInplace({{0, 0}})
|
||||
.SetDoc(R"DOC(
|
||||
Elementwise modulo operation. Each element in the output is the modulo result
|
||||
of the corresponding elment in the input data. The divisor of the modulo is
|
||||
provided by the operator argument `divisor`.
|
||||
)DOC")
|
||||
.Input(0, "data", "input int32 or int64 data")
|
||||
.Output(0, "output", "output of data with modulo operation applied");
|
||||
|
||||
SHOULD_NOT_DO_GRADIENT(ModOp);
|
||||
} // namespace
|
||||
} // namespace caffe2
|
||||
52
caffe2/operators/mod_op.h
Normal file
52
caffe2/operators/mod_op.h
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
/**
|
||||
* Copyright (c) 2016-present, Facebook, Inc.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
* You may obtain a copy of the License at
|
||||
*
|
||||
* http://www.apache.org/licenses/LICENSE-2.0
|
||||
*
|
||||
* Unless required by applicable law or agreed to in writing, software
|
||||
* distributed under the License is distributed on an "AS IS" BASIS,
|
||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
* See the License for the specific language governing permissions and
|
||||
* limitations under the License.
|
||||
*/
|
||||
|
||||
#ifndef CAFFE_OPERATORS_MOD_OP_H_
|
||||
#define CAFFE_OPERATORS_MOD_OP_H_
|
||||
|
||||
#include "caffe2/core/context.h"
|
||||
#include "caffe2/core/logging.h"
|
||||
#include "caffe2/core/operator.h"
|
||||
|
||||
namespace caffe2 {
|
||||
|
||||
template <class Context>
|
||||
class ModOp final : public Operator<Context> {
|
||||
public:
|
||||
USE_OPERATOR_CONTEXT_FUNCTIONS;
|
||||
ModOp(const OperatorDef& operator_def, Workspace* ws)
|
||||
: Operator<Context>(operator_def, ws) {
|
||||
divisor_ = OperatorBase::GetSingleArgument<int64_t>("divisor", -1);
|
||||
CAFFE_ENFORCE_GE(divisor_, 1, "divisor must be given with value >= 1");
|
||||
}
|
||||
|
||||
bool RunOnDevice() override {
|
||||
return DispatchHelper<TensorTypes<int, int64_t>>::call(this, Input(DATA));
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool DoRunWithType();
|
||||
|
||||
protected:
|
||||
INPUT_TAGS(DATA);
|
||||
|
||||
private:
|
||||
int64_t divisor_;
|
||||
};
|
||||
|
||||
} // namespace caffe2
|
||||
|
||||
#endif // CAFFE_OPERATORS_MOD_OP_H_
|
||||
60
caffe2/python/operator_test/mod_op_test.py
Normal file
60
caffe2/python/operator_test/mod_op_test.py
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
# Copyright (c) 2016-present, Facebook, Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
##############################################################################
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from caffe2.python import core
|
||||
from hypothesis import given
|
||||
|
||||
import caffe2.python.hypothesis_test_util as hu
|
||||
import hypothesis.strategies as st
|
||||
import numpy as np
|
||||
|
||||
|
||||
@st.composite
|
||||
def _data(draw):
|
||||
dtype = draw(st.sampled_from([np.int32, np.int64]))
|
||||
return draw(hu.tensor(dtype=dtype))
|
||||
|
||||
|
||||
class TestMod(hu.HypothesisTestCase):
|
||||
@given(
|
||||
data=_data(),
|
||||
divisor=st.integers(min_value=1, max_value=np.iinfo(np.int64).max),
|
||||
inplace=st.booleans(),
|
||||
**hu.gcs_cpu_only
|
||||
)
|
||||
def test_mod(self, data, divisor, inplace, gc, dc):
|
||||
|
||||
def ref(data):
|
||||
output = data % divisor
|
||||
return [output]
|
||||
|
||||
op = core.CreateOperator(
|
||||
'Mod',
|
||||
['data'],
|
||||
['data' if inplace else 'output'],
|
||||
divisor=divisor,
|
||||
)
|
||||
|
||||
self.assertReferenceChecks(gc, op, [data], ref)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import unittest
|
||||
unittest.main()
|
||||
Loading…
Reference in New Issue
Block a user