pytorch/caffe2/operators/variable_length_sequence_padding.h
Richard Barnes: use irange for loops (#70248)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/70248

Modified loops in files under fbsource/fbcode/caffe2/ from the format
```
for (TYPE var = x0; var < x_max; var++)
```
to the format
```
for (const auto var : irange(x_max))
```

This was achieved by running r-barnes's loop upgrader script (D28874212) with some modifications to exclude all files under /torch/jit, plus a number of reversions and unused-variable warning suppressions added by hand.
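
As a rough illustration (not taken from the diff itself, and using a hypothetical `zero_fill` function), a typical before/after for one of these loops looks like the following, assuming `<c10/util/irange.h>` is on the include path:
```
#include <c10/util/irange.h>
#include <cstdint>

void zero_fill(float* data, int64_t n) {
  // Before: classic counting loop.
  //   for (int64_t i = 0; i < n; i++) { data[i] = 0.0f; }
  // After: range-based loop over [0, n); i is a const value of the deduced
  // index type, which sidesteps accidental mutation of the loop variable.
  for (const auto i : c10::irange(n)) {
    data[i] = 0.0f;
  }
}
```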

Test Plan: Sandcastle

Reviewed By: malfet

Differential Revision: D32813863

fbshipit-source-id: 527244b4a2b220fdfe7f17dee3599603f492a2ca
2022-01-06 23:14:29 -08:00


#pragma once
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/eigen_utils.h"
#include "caffe2/utils/math.h"
namespace caffe2 {
namespace detail {
template <typename T, typename Context>
void VariableLengthSequencePadding(
    int N,
    int B,
    int M,
    T* X,
    const int32_t* seqLengths,
    const T padValue,
    Context* /*context*/) {
  // X is laid out as (N, B, M): N time steps, B batch entries, M features.
  // For each batch entry j, overwrite every time step at or past
  // seqLengths[j] with padValue.
  for (const auto j : c10::irange(B)) {
    for (int i = seqLengths[j]; i < N; i++) {
      EigenVectorArrayMap<T>(X + B * M * i + M * j, M).setConstant(padValue);
    }
  }
}
} // namespace detail
template <typename T, typename Context>
class VariableLengthSequencePaddingOp : public Operator<Context> {
 public:
  template <class... Args>
  explicit VariableLengthSequencePaddingOp(Args&&... args)
      : Operator<Context>(std::forward<Args>(args)...) {}
  USE_OPERATOR_CONTEXT_FUNCTIONS;

  bool RunOnDevice() override {
    // INPUT is an (N, B, M) tensor; SEQ_LENGTHS holds one length per batch
    // entry. Everything past each sequence's length is padded with zeros.
    const auto N = Input(INPUT).size(0);
    const auto B = Input(INPUT).size(1);
    const auto M = Input(INPUT).size(2);
    auto X = Output(OUTPUT)->template mutable_data<T>();
    auto seqLengths = Input(SEQ_LENGTHS).template data<int32_t>();
    detail::VariableLengthSequencePadding<T, Context>(
        N, B, M, X, seqLengths, 0, &context_);
    return true;
  }

 protected:
  INPUT_TAGS(INPUT, SEQ_LENGTHS);
  OUTPUT_TAGS(OUTPUT);
};
} // namespace caffe2
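
For illustration only, the padding semantics of detail::VariableLengthSequencePadding can be sketched without Caffe2 or Eigen as follows (hypothetical names, plain C++): X is an (N, B, M) row-major buffer, and every time step at or past seqLengths[j] for batch entry j is overwritten with padValue.
```
#include <cstdint>
#include <iostream>
#include <vector>

// Plain-loop equivalent of the Eigen-based helper in the header above.
void PadVariableLengthSequences(
    int N, int B, int M, float* X, const int32_t* seqLengths, float padValue) {
  for (int j = 0; j < B; j++) {
    for (int i = seqLengths[j]; i < N; i++) {
      float* slice = X + B * M * i + M * j; // (time i, batch j), length M
      for (int m = 0; m < M; m++) {
        slice[m] = padValue;
      }
    }
  }
}

int main() {
  const int N = 4, B = 2, M = 3;
  std::vector<float> X(N * B * M, 1.0f);
  const int32_t seqLengths[B] = {2, 3}; // valid steps per batch entry
  PadVariableLengthSequences(N, B, M, X.data(), seqLengths, 0.0f);
  // Batch 0 is padded from step 2 onward, batch 1 from step 3 onward.
  std::cout << X[2 * B * M + 0 * M] << " " << X[3 * B * M + 1 * M] << "\n"; // 0 0
  return 0;
}
```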