Operator to Merge ID_LIST features
Summary: As an alternative to sharing embeddings, we want to explore merging the ID_LISTs in the net. This commit adds an operator to merge many ID_LIST features into a single one.

Differential Revision: D5481523

fbshipit-source-id: 446121122a32de5682d5d75a165370bc8d776d03
parent 58838baa75
commit ae2aad9c0d
caffe2/operators/merge_id_lists_op.cc (new file, 32 lines)
@@ -0,0 +1,32 @@
#include "caffe2/operators/merge_id_lists_op.h"

namespace caffe2 {
namespace {

REGISTER_CPU_OPERATOR(MergeIdLists, MergeIdListsOp<CPUContext>);

OPERATOR_SCHEMA(MergeIdLists)
    .NumInputs([](int n) { return (n > 0 && n % 2 == 0); })
    .NumOutputs(2)
    .SetDoc(R"DOC(
MergeIdLists: Merge multiple ID_LISTs into a single ID_LIST.

An ID_LIST is a list of IDs (may be ints, often longs) that represents a single
feature. As described in https://caffe2.ai/docs/sparse-operations.html, a batch
of ID_LIST examples is represented as a pair of lengths and values where the
`lengths` (int32) segment the `values` or ids (int32/int64) into examples.

Given multiple inputs of the form lengths_0, values_0, lengths_1, values_1, ...
which correspond to lengths and values of ID_LISTs of different features, this
operator produces a merged ID_LIST that combines the ID_LIST features. The
final merged output is described by a lengths and values vector.

WARNING: The merge makes no guarantee about the relative order of ID_LISTs
within a batch. This can be an issue if ID_LISTs are order sensitive.
)DOC")
    .Input(0, "lengths_0", "Lengths of the ID_LISTs batch for first feature")
    .Input(1, "values_0", "Values of the ID_LISTs batch for first feature")
    .Output(0, "merged_lengths", "Lengths of the merged ID_LISTs batch")
    .Output(1, "merged_values", "Values of the merged ID_LISTs batch");

NO_GRADIENT(MergeIdLists);
} // namespace
} // namespace caffe2
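For reference, the registered operator can be exercised from Python roughly as follows. This is a minimal sketch, not part of the commit: it assumes the standard caffe2.python core/workspace bindings, and the blob names and example data are illustrative.

import numpy as np
from caffe2.python import core, workspace

# Two ID_LIST features over a batch of two examples (illustrative data).
# Feature 0: example 0 -> [1, 5, 6], example 1 -> [2, 4]
workspace.FeedBlob("lengths_0", np.array([3, 2], dtype=np.int32))
workspace.FeedBlob("values_0", np.array([1, 5, 6, 2, 4], dtype=np.int64))
# Feature 1: example 0 -> [5, 8], example 1 -> [4]
workspace.FeedBlob("lengths_1", np.array([2, 1], dtype=np.int32))
workspace.FeedBlob("values_1", np.array([5, 8, 4], dtype=np.int64))

op = core.CreateOperator(
    "MergeIdLists",
    ["lengths_0", "values_0", "lengths_1", "values_1"],
    ["merged_lengths", "merged_values"],
)
workspace.RunOperatorOnce(op)

# Each example's merged IDs are deduplicated (and, with this CPU
# implementation, sorted):
print(workspace.FetchBlob("merged_lengths"))  # [4 2]
print(workspace.FetchBlob("merged_values"))   # [1 5 6 8 2 4]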
caffe2/operators/merge_id_lists_op.h (new file, 82 lines)
@@ -0,0 +1,82 @@
#ifndef CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
#define CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_

#include <set>
#include <vector>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"

namespace caffe2 {

template <class Context>
class MergeIdListsOp : public Operator<Context> {
 public:
  USE_OPERATOR_CONTEXT_FUNCTIONS;
  USE_SIMPLE_CTOR_DTOR(MergeIdListsOp);

  template <typename T>
  bool DoRunWithType() {
    auto& first_lengths = Input(0);
    CAFFE_ENFORCE_EQ(first_lengths.ndim(), 1, "LENGTHS should be 1-D");
    const auto batch_size = first_lengths.size();

    auto* out_lengths = Output(0);
    out_lengths->ResizeLike(first_lengths);

    auto* out_lengths_data = out_lengths->template mutable_data<int32_t>();

    /**
     * Loop to figure out how much space to reserve for output
     * and perform checks.
     */
    auto M = 0;
    for (size_t i = 0; i < InputSize(); i += 2) {
      auto& lengths = Input(i);
      CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS should be 1-D");
      CAFFE_ENFORCE_EQ(lengths.size(), batch_size, "LENGTHS should be equal");
      auto& values = Input(i + 1);
      CAFFE_ENFORCE_EQ(values.ndim(), 1, "VALUES should be 1-D");
      M += values.size();
    }

    auto* out_values = Output(1);
    out_values->Resize(M);

    T* out_values_data = out_values->template mutable_data<T>();
    auto pos = 0;

    // TODO(badri): Use unordered_set if performance is an issue
    std::set<T> deduped;
    std::vector<int> offsets(InputSize(), 0);
    for (auto sample = 0; sample < batch_size; sample++) {
      // Collect the IDs of every feature for this example into a sorted,
      // deduplicated set.
      for (size_t i = 0; i < InputSize(); i += 2) {
        auto& lengths = Input(i);
        const auto* lengths_data = lengths.template data<int32_t>();

        auto& values = Input(i + 1);
        const T* values_data = values.template data<T>();
        const auto length = lengths_data[sample];

        for (auto j = offsets[i]; j < offsets[i] + length; j++) {
          deduped.insert(values_data[j]);
        }
        offsets[i] += length;
      }
      // Emit the merged IDs for this example and record its length.
      for (auto val : deduped) {
        out_values_data[pos++] = val;
      }
      out_lengths_data[sample] = deduped.size();
      deduped.clear();
    }
    // Shrink the values output to the deduplicated size.
    out_values->Resize(pos);
    return true;
  }

  bool RunOnDevice() override {
    return DispatchHelper<TensorTypes<int32_t, int64_t>>::call(this, Input(1));
  }
};

} // namespace caffe2

#endif // CAFFE2_OPERATORS_MERGE_ID_LISTS_OP_H_
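The per-sample loop above reads each feature's values through running offsets derived from its lengths, deduplicates them in a std::set, and writes the sorted set out for that example. The same bookkeeping, written out compactly in Python purely for illustration (a sketch that mirrors the loop above; the helper name merge_id_lists is not part of this commit):

import numpy as np

def merge_id_lists(inputs):
    """inputs = [lengths_0, values_0, lengths_1, values_1, ...], 1-D arrays."""
    assert len(inputs) > 0 and len(inputs) % 2 == 0
    batch_size = len(inputs[0])
    offsets = [0] * len(inputs)   # running read position into each values input
    out_lengths, out_values = [], []
    for sample in range(batch_size):
        merged = set()            # deduplicates, like the std::set above
        for i in range(0, len(inputs), 2):
            length = int(inputs[i][sample])
            merged.update(inputs[i + 1][offsets[i]:offsets[i] + length])
            offsets[i] += length
        ids = sorted(merged)      # std::set iteration order is already sorted
        out_lengths.append(len(ids))
        out_values.extend(ids)
    return np.array(out_lengths, dtype=np.int32), np.array(out_values)

This is essentially what merge_id_lists_ref in the test below computes with vectorized numpy operations.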
caffe2/python/operator_test/merge_id_lists_op_test.py (new file, 84 lines)
@@ -0,0 +1,84 @@
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from __future__ import unicode_literals

import numpy as np

from hypothesis import given
import hypothesis.strategies as st

from caffe2.python import core
import caffe2.python.hypothesis_test_util as hu

import hypothesis.extra.numpy as hnp


@st.composite
def id_list_batch(draw):
    num_inputs = draw(st.integers(1, 3))
    batch_size = draw(st.integers(5, 10))
    values_dtype = draw(st.sampled_from([np.int32, np.int64]))
    inputs = []
    for _ in range(num_inputs):
        size = draw(st.integers(5, 10))
        values = draw(hnp.arrays(values_dtype, size, st.integers(1, 10)))
        lengths = draw(hu.lengths(len(values),
                                  min_segments=batch_size,
                                  max_segments=batch_size))
        inputs.append(lengths)
        inputs.append(values)
    return inputs


def merge_id_lists_ref(*args):
    n = len(args)
    assert n > 0
    assert n % 2 == 0
    batch_size = len(args[0])
    num_inputs = int(n / 2)
    lengths = np.array([np.insert(args[2 * i], 0, 0)
                        for i in range(num_inputs)])
    values = [args[2 * i + 1] for i in range(num_inputs)]
    offsets = [np.cumsum(lengths[j]) for j in range(num_inputs)]

    def merge_arrays(vs, offs, j):
        concat = np.concatenate([vs[i][offs[i][j]:offs[i][j + 1]]
                                 for i in range(num_inputs)])
        return np.sort(np.unique(concat))

    merged = [merge_arrays(values, offsets, j) for j in range(batch_size)]
    merged_lengths = np.array([len(x) for x in merged])
    merged_values = np.concatenate(merged)
    return merged_lengths, merged_values


class TestMergeIdListsOp(hu.HypothesisTestCase):
    def test_merge_id_lists_ref(self):
        # Verify that the reference implementation is correct!
        lengths_0 = np.array([3, 0, 4], dtype=np.int32)
        values_0 = np.array([1, 5, 6, 2, 4, 5, 6], dtype=np.int64)
        lengths_1 = np.array([3, 2, 1], dtype=np.int32)
        values_1 = np.array([5, 8, 9, 14, 9, 5], dtype=np.int64)

        merged_lengths, merged_values = merge_id_lists_ref(
            lengths_0, values_0, lengths_1, values_1)
        expected_lengths = np.array([5, 2, 4], dtype=np.int32)
        expected_values = np.array([1, 5, 6, 8, 9, 9, 14, 2, 4, 5, 6],
                                   dtype=np.int64)

        np.testing.assert_array_equal(merged_lengths, expected_lengths)
        np.testing.assert_array_equal(merged_values, expected_values)

    @given(inputs=id_list_batch(),
           **hu.gcs_cpu_only)
    def test_merge_id_lists_op(self, inputs, gc, dc):
        num_inputs = int(len(inputs) / 2)
        op = core.CreateOperator(
            "MergeIdLists",
            ["{prefix}_{i}".format(prefix=p, i=i)
             for i in range(num_inputs)
             for p in ["lengths", "values"]],
            ["merged_lengths", "merged_values"]
        )
        self.assertDeviceChecks(dc, op, inputs, [0])
        self.assertReferenceChecks(gc, op, inputs, merge_id_lists_ref)