Optimize channel_stats_op (#16243)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/16243

Optimize channel_stats_op and add NHWC impl

Reviewed By: takatosp1

Differential Revision: D13775515

fbshipit-source-id: decb889e646f5316d4afefdf9f9b6bc6343613cd
Author: Xiaomeng Yang (2019-03-12 11:52:01 -07:00), committed by Facebook GitHub Bot
parent 99f1465c35
commit 54b33503ec
8 changed files with 252 additions and 233 deletions
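
Note: ChannelStats reduces its input to two per-channel vectors, the elementwise sum and the sum of squares; this commit rewrites the CPU and CUDA paths and adds an NHWC code path next to the existing NCHW one. As a rough NumPy sketch of the operator's semantics (the helper name and shapes are illustrative, not part of the diff):

import numpy as np

def channel_stats(X, order="NCHW"):
    # Reduce over every axis except the channel axis (1 for NCHW, last for NHWC).
    c_axis = 1 if order == "NCHW" else X.ndim - 1
    axes = tuple(i for i in range(X.ndim) if i != c_axis)
    return np.sum(X, axis=axes), np.sum(X**2, axis=axes)

X = np.random.randn(2, 8, 4, 4).astype(np.float32)  # N, C, H, W
s, sq = channel_stats(X, "NCHW")                    # both outputs have shape (8,)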

diff --git a/caffe2/operators/channel_stats_op.cc b/caffe2/operators/channel_stats_op.cc

@@ -1,39 +1,51 @@
 #include "caffe2/operators/channel_stats_op.h"
 
 #include "caffe2/utils/eigen_utils.h"
 
 namespace caffe2 {
 
 template <>
-bool ChannelStatsOp<CPUContext>::RunOnDevice() {
-  const auto& X = Input(INPUT);
-  CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
-  const int N = X.dim32(0);
-  const int C = X.dim32(1);
-  const int H = X.dim32(2);
-  const int W = X.dim() > 3 ? X.dim32(3) : 1;
-  const int D = X.dim() > 4 ? X.dim32(4) : 1;
-
-  const int sampleSize = H * W * D;
-
-  Output(SUM)->Resize(C);
-  Output(SUMSQ)->Resize(C);
-  EigenVectorArrayMap<float> sum(
-      Output(SUM)->template mutable_data<float>(), C);
-  EigenVectorArrayMap<float> sumsq(
-      Output(SUMSQ)->template mutable_data<float>(), C);
-
-  sum.setZero();
-  sumsq.setZero();
-  ConstEigenArrayMap<float> X_arr(X.data<float>(), sampleSize, N * C);
-  auto index = 0;
-  for (int n = 0; n < N; ++n) {
-    for (int c = 0; c < C; ++c) {
-      sum(c) += X_arr.col(index).sum();
-      sumsq(c) += X_arr.col(index).matrix().squaredNorm();
-      index++;
-    }
-  }
-  return true;
-}
+template <>
+bool ChannelStatsOp<CPUContext>::ComputeChannelStatsNCHW<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  ConstEigenArrayMap<float> X_arr(X, HxW, N * C);
+  for (int i = 0; i < C; ++i) {
+    sum[i] = X_arr.col(i).sum();
+    sumsq[i] = X_arr.col(i).square().sum();
+  }
+  for (int i = 1; i < N; ++i) {
+    for (int j = 0; j < C; ++j) {
+      const int c = i * C + j;
+      sum[j] += X_arr.col(c).sum();
+      sumsq[j] += X_arr.col(c).square().sum();
+    }
+  }
+  return true;
+}
+
+template <>
+template <>
+bool ChannelStatsOp<CPUContext>::ComputeChannelStatsNHWC<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  ConstEigenArrayMap<float> X_arr(X, C, N * HxW);
+  EigenVectorArrayMap<float> sum_arr(sum, C);
+  EigenVectorArrayMap<float> sumsq_arr(sumsq, C);
+  sum_arr = X_arr.col(0);
+  sumsq_arr = X_arr.col(0).square();
+  for (int i = 1; i < N * HxW; ++i) {
+    sum_arr += X_arr.col(i);
+    sumsq_arr += X_arr.col(i).square();
+  }
+  return true;
+}
@@ -49,7 +61,6 @@ reduced across multiple batches and used to obtain the mean and variance across
 the full set of batches. Using the new mean and variance as input to SpatialBN
 has the effect of changing the batch size over which SpatialBN is applied.
 )DOC")
-
     .Input(0, "X", "The input 4-dimensional tensor of shape NCHW")
     .Output(
         0,
@@ -61,5 +72,7 @@ has the effect of changing the batch size over which SpatialBN is applied.
         "sumsq",
         "The output 1-dimensional tensor of size C containing the sum of "
         "elements squared per channel.");
+
 SHOULD_NOT_DO_GRADIENT(ChannelStats);
+
 } // namespace caffe2

diff --git a/caffe2/operators/channel_stats_op.cu b/caffe2/operators/channel_stats_op.cu

@@ -1,192 +1,117 @@
-#include "caffe2/core/context_gpu.h"
 #include "caffe2/operators/channel_stats_op.h"
 
+#include "caffe2/core/context_gpu.h"
+#include "caffe2/utils/math/reduce.cuh"
+
 namespace caffe2 {
+
 namespace {
 
-// based on "Optimizing Parallel Reduction in CUDA" by Mark Harris
-
-// note - volatile keyword is needed to allow doing a warp reduction without
-// synchronization on recent architectures
-template <unsigned int blockSize>
-__device__ void warpReduce(volatile float* sdata, unsigned int tid) {
-  // note - the if statements are "free" as they are resolved at compile time
-  if (blockSize >= 64)
-    sdata[tid] += sdata[tid + 32];
-  if (blockSize >= 32)
-    sdata[tid] += sdata[tid + 16];
-  if (blockSize >= 16)
-    sdata[tid] += sdata[tid + 8];
-  if (blockSize >= 8)
-    sdata[tid] += sdata[tid + 4];
-  if (blockSize >= 4)
-    sdata[tid] += sdata[tid + 2];
-  if (blockSize >= 2)
-    sdata[tid] += sdata[tid + 1];
-}
-
-template <unsigned int blockSize>
-__global__ void ChannelStatsBlockKernel(
-    int N,
-    int C,
-    int valsPerChannel,
-    const float* inputData,
-    float* sums,
-    float* sumsq) {
-  __shared__ float sumData[blockSize];
-  __shared__ float sumSqData[blockSize];
-
-  auto tid = threadIdx.x;
-  auto numBlocksPerChannel = (valsPerChannel + blockSize - 1) / blockSize;
-  auto localBlockIndex = blockIdx.x % numBlocksPerChannel;
-  auto inputIndex = (blockIdx.x / numBlocksPerChannel) * valsPerChannel +
-      localBlockIndex * blockSize + tid;
-
-  sumData[tid] = 0;
-  sumSqData[tid] = 0;
-  if (localBlockIndex * blockSize + tid < valsPerChannel) {
-    sumData[tid] += inputData[inputIndex];
-    sumSqData[tid] += inputData[inputIndex] * inputData[inputIndex];
-  }
-  __syncthreads();
-  if (blockSize >= 512) {
-    if (tid < 256) {
-      sumData[tid] += sumData[tid + 256];
-      sumSqData[tid] += sumSqData[tid + 256];
-    }
-    __syncthreads();
-  }
-  if (blockSize >= 256) {
-    if (tid < 128) {
-      sumData[tid] += sumData[tid + 128];
-      sumSqData[tid] += sumSqData[tid + 128];
-    }
-    __syncthreads();
-  }
-  if (blockSize >= 128) {
-    if (tid < 64) {
-      sumData[tid] += sumData[tid + 64];
-      sumSqData[tid] += sumSqData[tid + 64];
-    }
-    __syncthreads();
-  }
-  if (tid < 32) {
-    warpReduce<blockSize>(sumData, tid);
-    warpReduce<blockSize>(sumSqData, tid);
-  }
-
-  // output block data sorted by C to simplify second reduction
-  if (tid == 0) {
-    auto n = blockIdx.x / numBlocksPerChannel / C;
-    auto c = (blockIdx.x / numBlocksPerChannel) % C;
-    auto outputIndex = (c * N + n) * numBlocksPerChannel + localBlockIndex;
-    sums[outputIndex] = sumData[0];
-    sumsq[outputIndex] = sumSqData[0];
-  }
-}
-
-template <unsigned int blockSize>
-__global__ void ChannelStatsFinalSumsKernel(
-    int N,
-    int C,
-    int numSumsPerChannel,
-    const float* sumsScratch,
-    const float* sumsqScratch,
-    float* channelSums,
-    float* channelSumsq) {
-  __shared__ float sumData[blockSize];
-  __shared__ float sumSqData[blockSize];
-
-  auto tid = threadIdx.x;
-  auto inputIndex = blockIdx.x * N * numSumsPerChannel + tid;
-  sumData[tid] = 0;
-  sumSqData[tid] = 0;
-  for (auto i = inputIndex; i < (blockIdx.x + 1) * N * numSumsPerChannel;
-       i += blockSize) {
-    sumData[tid] += sumsScratch[i];
-    sumSqData[tid] += sumsqScratch[i];
-  }
-  __syncthreads();
-  if (blockSize >= 512) {
-    if (tid < 256) {
-      sumData[tid] += sumData[tid + 256];
-      sumSqData[tid] += sumSqData[tid + 256];
-    }
-    __syncthreads();
-  }
-  if (blockSize >= 256) {
-    if (tid < 128) {
-      sumData[tid] += sumData[tid + 128];
-      sumSqData[tid] += sumSqData[tid + 128];
-    }
-    __syncthreads();
-  }
-  if (blockSize >= 128) {
-    if (tid < 64) {
-      sumData[tid] += sumData[tid + 64];
-      sumSqData[tid] += sumSqData[tid + 64];
-    }
-    __syncthreads();
-  }
-  if (tid < 32) {
-    warpReduce<blockSize>(sumData, tid);
-    warpReduce<blockSize>(sumSqData, tid);
-  }
-  if (tid == 0) {
-    channelSums[blockIdx.x] = sumData[0];
-    channelSumsq[blockIdx.x] = sumSqData[0];
-  }
-}
+template <typename T, int kBlockDimX, int kBlockDimY>
+__global__ void ChannelStatsNCHWCUDAKernel(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* X,
+    T* sum,
+    T* sumsq) {
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage m_storage;
+  __shared__
+      typename BlockReduce2D<T, kBlockDimX, kBlockDimY>::TempStorage v_storage;
+  const int c = blockIdx.x;
+  T m_val = 0;
+  T v_val = 0;
+  for (int n = threadIdx.x; n < N; n += blockDim.x) {
+    for (int hw = threadIdx.y; hw < HxW; hw += blockDim.y) {
+      const int index = (n * C + c) * HxW + hw;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+      m_val += __ldg(X + index);
+      v_val += __ldg(X + index) * __ldg(X + index);
+#else
+      m_val += X[index];
+      v_val += X[index] * X[index];
+#endif
+    }
+  }
+  m_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(m_storage).Sum(m_val);
+  v_val = BlockReduce2D<T, kBlockDimX, kBlockDimY>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0 && threadIdx.y == 0) {
+    sum[c] = m_val;
+    sumsq[c] = v_val;
+  }
+}
+
+template <typename T>
+__global__ void ChannelStatsNHWCCUDAKernel(
+    const int N,
+    const int C,
+    const int HxW,
+    const T* X,
+    T* sum,
+    T* sumsq) {
+  __shared__ typename BlockReduce<T>::TempStorage m_storage;
+  __shared__ typename BlockReduce<T>::TempStorage v_storage;
+  const int inner_size = N * HxW;
+  const int c = blockIdx.x;
+  T m_val = 0;
+  T v_val = 0;
+  for (int i = threadIdx.x; i < inner_size; i += blockDim.x) {
+    const int index = i * C + c;
+#if __CUDA_ARCH__ >= 350 || defined(__HIP_PLATFORM_HCC__)
+    m_val += __ldg(X + index);
+    v_val += __ldg(X + index) * __ldg(X + index);
+#else
+    m_val += X[index];
+    v_val += X[index] * X[index];
+#endif
+  }
+  m_val = BlockReduce<T>(m_storage).Sum(m_val);
+  v_val = BlockReduce<T>(v_storage).Sum(v_val);
+  if (threadIdx.x == 0) {
+    sum[c] = m_val;
+    sumsq[c] = v_val;
+  }
+}
 
 } // namespace
 
 template <>
-bool ChannelStatsOp<CUDAContext>::RunOnDevice() {
-  const auto& X = Input(INPUT);
-  CAFFE_ENFORCE(X.dim() >= 3 && X.dim() <= 5);
-  const int N = X.dim32(0);
-  const int C = X.dim32(1);
-  const int H = X.dim32(2);
-  const int W = X.dim() > 3 ? X.dim32(3) : 1;
-  const int D = X.dim() > 4 ? X.dim32(4) : 1;
-
-  const auto X_arr = X.data<float>();
-  const auto valsPerChannel = H * W * D;
-
-  const auto numBlocksPerChannel = CAFFE_GET_BLOCKS(valsPerChannel);
-  const auto numBlocksTotal = numBlocksPerChannel * N * C;
-
-  ReinitializeTensor(
-      &sumScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
-  ReinitializeTensor(
-      &sumsqScratch_, {numBlocksTotal}, at::dtype<float>().device(CUDA));
-
-  auto sum = Output(SUM, {C}, at::dtype<float>());
-  auto sumsq = Output(SUMSQ, {C}, at::dtype<float>());
-
-  ChannelStatsBlockKernel<CAFFE_CUDA_NUM_THREADS>
-      <<<numBlocksTotal, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          N,
-          C,
-          valsPerChannel,
-          X_arr,
-          sumScratch_.mutable_data<float>(),
-          sumsqScratch_.mutable_data<float>());
-  ChannelStatsFinalSumsKernel<CAFFE_CUDA_NUM_THREADS>
-      <<<C, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
-          N,
-          C,
-          numBlocksPerChannel,
-          sumScratch_.data<float>(),
-          sumsqScratch_.data<float>(),
-          sum->template mutable_data<float>(),
-          sumsq->template mutable_data<float>());
-
+template <>
+bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNCHW<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  DISPATCH_REDUCE_KERNEL_BY_2D_BLOCK_WITH_TYPE_1(
+      HxW,
+      ChannelStatsNCHWCUDAKernel,
+      float,
+      C,
+      context_.cuda_stream(),
+      N,
+      C,
+      HxW,
+      X,
+      sum,
+      sumsq);
+  return true;
+}
+
+template <>
+template <>
+bool ChannelStatsOp<CUDAContext>::ComputeChannelStatsNHWC<float>(
+    const int N,
+    const int C,
+    const int HxW,
+    const float* X,
+    float* sum,
+    float* sumsq) {
+  ChannelStatsNHWCCUDAKernel<float>
+      <<<C, CAFFE_CUDA_NUM_THREADS, 0, context_.cuda_stream()>>>(
+          N, C, HxW, X, sum, sumsq);
   return true;
 }
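
Note: both new kernels launch one thread block per channel (see the <<<C, CAFFE_CUDA_NUM_THREADS, ...>>> launch in the NHWC case); each thread accumulates a strided subset of the N * HxW values of its channel, and a block-wide reduction (BlockReduce / BlockReduce2D from reduce.cuh) collapses the per-thread partials. A rough Python model of the NHWC pattern, with an illustrative thread count and function name:

import numpy as np

def nhwc_block_reduce_sketch(X, c, threads=128):
    # X is laid out NHWC; one CUDA block handles channel c.
    N, H, W, C = X.shape
    flat = X.reshape(-1)                        # row-major NHWC buffer
    inner = N * H * W                           # the kernel's N * HxW
    partial_sum = np.zeros(threads, X.dtype)
    partial_sq = np.zeros(threads, X.dtype)
    for tid in range(threads):                  # one pass per CUDA thread
        for i in range(tid, inner, threads):    # blockDim.x-strided loop
            x = flat[i * C + c]                 # channel stride is C in NHWC
            partial_sum[tid] += x
            partial_sq[tid] += x * x
    return partial_sum.sum(), partial_sq.sum()  # stands in for BlockReduce(...).Sum(...)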

diff --git a/caffe2/operators/channel_stats_op.h b/caffe2/operators/channel_stats_op.h

@@ -1,5 +1,7 @@
-#ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
-#define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
+#ifndef CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
+#define CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
+
+#include <string>
 
 #include "caffe2/core/context.h"
 #include "caffe2/core/operator.h"
@@ -8,26 +10,51 @@
 namespace caffe2 {
 
 template <class Context>
-class ChannelStatsOp : public Operator<Context> {
+class ChannelStatsOp final : public Operator<Context> {
  public:
   USE_OPERATOR_CONTEXT_FUNCTIONS;
+
   template <class... Args>
   explicit ChannelStatsOp(Args&&... args)
-      : Operator<Context>(std::forward<Args>(args)...) {}
-  ~ChannelStatsOp() {}
-
-  bool RunOnDevice() override {
-    return true;
-  }
-
- protected:
-  INPUT_TAGS(INPUT);
-  OUTPUT_TAGS(SUM, SUMSQ);
-  Tensor sumScratch_;
-  Tensor sumsqScratch_;
+      : Operator<Context>(std::forward<Args>(args)...),
+        order_(StringToStorageOrder(
+            this->template GetSingleArgument<std::string>("order", "NCHW"))) {
+    CAFFE_ENFORCE_NE(order_, StorageOrder::UNKNOWN);
+  }
+
+  bool RunOnDevice() override {
+    return DispatchHelper<TensorTypes<float>>::call(this, Input(0));
+  }
+
+  template <typename T>
+  bool DoRunWithType() {
+    const auto& X = Input(0);
+    const int ndim = X.dim();
+    const int N = X.dim32(0);
+    const int C = order_ == StorageOrder::NCHW ? X.dim32(1) : X.dim32(ndim - 1);
+    const int HxW = X.numel() / (N * C);
+    auto* sum = Output(0, {C}, at::dtype<T>());
+    auto* sumsq = Output(1, {C}, at::dtype<T>());
+    const T* X_data = X.template data<T>();
+    T* sum_data = sum->template mutable_data<T>();
+    T* sumsq_data = sumsq->template mutable_data<T>();
+    return order_ == StorageOrder::NCHW
+        ? ComputeChannelStatsNCHW<T>(N, C, HxW, X_data, sum_data, sumsq_data)
+        : ComputeChannelStatsNHWC<T>(N, C, HxW, X_data, sum_data, sumsq_data);
+  }
+
+ private:
+  template <typename T>
+  bool
+  ComputeChannelStatsNCHW(int N, int C, int HxW, const T* X, T* sum, T* sumsq);
+  template <typename T>
+  bool
+  ComputeChannelStatsNHWC(int N, int C, int HxW, const T* X, T* sum, T* sumsq);
+
+  const StorageOrder order_;
 };
 
 } // namespace caffe2
 
-#endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H
+#endif // CAFFE2_OPERATORS_CHANNEL_STATS_OP_H_
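
Note: the new `order` argument is read once in the constructor, so callers pick the layout when the operator is created. A minimal usage sketch through the Caffe2 Python frontend (assuming a build with this operator registered; it mirrors the tests below):

from caffe2.python import core, workspace
import numpy as np

X = np.random.randn(2, 4, 4, 8).astype(np.float32)  # NHWC: channels last
workspace.FeedBlob("X", X)
op = core.CreateOperator("ChannelStats", ["X"], ["sum", "sumsq"], order="NHWC")
workspace.RunOperatorOnce(op)
print(workspace.FetchBlob("sum").shape)    # (8,)
print(workspace.FetchBlob("sumsq").shape)  # (8,)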

diff --git a/caffe2/python/operator_test/channel_stats_op_test.py b/caffe2/python/operator_test/channel_stats_op_test.py

@@ -1,42 +1,85 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+from __future__ import unicode_literals
 
 from caffe2.python import core
 import caffe2.python.hypothesis_test_util as hu
 import caffe2.python.serialized_test.serialized_test_util as serial
 from hypothesis import assume, given
 import hypothesis.strategies as st
 import numpy as np
 
 import unittest
 
 
-class TestChannelStats(serial.SerializedTestCase):
-    @serial.given(
-        size=st.integers(7, 10),
-        inputChannels=st.integers(1, 10),
-        batchSize=st.integers(1, 3),
-        **hu.gcs
-    )
-    def testChannelStats(self, size, inputChannels, batchSize, gc, dc):
+class TestChannelStatsOp(serial.SerializedTestCase):
+    def channel_stats_nchw_ref(self, X):
+        dims = X.shape
+        N = dims[0]
+        C = dims[1]
+        X = X.reshape(N, C, -1)
+        sum1 = np.sum(X, axis=(0, 2), keepdims=False)
+        sum2 = np.sum(X**2, axis=(0, 2), keepdims=False)
+        return (sum1, sum2)
+
+    def channel_stats_nhwc_ref(self, X):
+        dims = X.shape
+        N = dims[0]
+        C = dims[-1]
+        X = X.reshape(N, -1, C)
+        sum1 = np.sum(X, axis=(0, 1), keepdims=False)
+        sum2 = np.sum(X**2, axis=(0, 1), keepdims=False)
+        return (sum1, sum2)
+
+    @serial.given(
+        N=st.integers(1, 5), C=st.integers(1, 10), H=st.integers(1, 12),
+        W=st.integers(1, 12), order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
+    def test_channel_stats_2d(self, N, C, H, W, order, gc, dc):
         op = core.CreateOperator(
             "ChannelStats",
             ["X"],
             ["sum", "sumsq"],
+            order=order,
         )
 
-        def referenceChannelStatsTest(X):
-            sums = np.sum(X, axis=(0, 2, 3), keepdims=False)
-            sumsq = np.zeros(inputChannels)
-            sumsq = np.sum(X**2, axis=(0, 2, 3), keepdims=False)
-            return sums, sumsq
+        def ref_op(X):
+            if order == "NCHW":
+                return self.channel_stats_nchw_ref(X)
+            else:
+                return self.channel_stats_nhwc_ref(X)
 
-        X = np.random.rand(batchSize, inputChannels, size, size)\
-            .astype(np.float32) - 0.5
-        self.assertReferenceChecks(gc, op, [X], referenceChannelStatsTest)
+        X = np.random.randn(N, C, H, W).astype(np.float32)
+        if order == "NHWC":
+            X = np.transpose(X, [0, 2, 3, 1])
+        self.assertReferenceChecks(gc, op, [X], reference=ref_op)
+        self.assertDeviceChecks(dc, op, [X], [0, 1])
+
+    @serial.given(
+        N=st.integers(1, 5), C=st.integers(1, 10), D=st.integers(1, 6),
+        H=st.integers(1, 6), W=st.integers(1, 6),
+        order=st.sampled_from(["NCHW", "NHWC"]), **hu.gcs)
+    def test_channel_stats_3d(self, N, C, D, H, W, order, gc, dc):
+        op = core.CreateOperator(
+            "ChannelStats",
+            ["X"],
+            ["sum", "sumsq"],
+            order=order,
+        )
+
+        def ref_op(X):
+            if order == "NCHW":
+                return self.channel_stats_nchw_ref(X)
+            else:
+                return self.channel_stats_nhwc_ref(X)
+
+        X = np.random.randn(N, C, D, H, W).astype(np.float32)
+        if order == "NHWC":
+            X = np.transpose(X, [0, 2, 3, 4, 1])
+        self.assertReferenceChecks(gc, op, [X], reference=ref_op)
+        self.assertDeviceChecks(dc, op, [X], [0, 1])
+
 
 if __name__ == "__main__":
     unittest.main()

diff --git a/caffe2/python/operator_test/group_norm_op_test.py b/caffe2/python/operator_test/group_norm_op_test.py

@@ -10,6 +10,8 @@ from hypothesis import given
 import hypothesis.strategies as st
 import numpy as np
 
+import unittest
+
 
 class TestGroupNormOp(serial.SerializedTestCase):
     def group_norm_nchw_ref(self, X, gamma, beta, group, epsilon):
@@ -144,3 +146,7 @@ class TestGroupNormOp(serial.SerializedTestCase):
         inputs = [X, gamma, beta]
         for i in range(len(inputs)):
             self.assertGradientChecks(gc, op, inputs, i, [0])
+
+
+if __name__ == "__main__":
+    unittest.main()

diff --git a/caffe2/python/serialized_test/SerializedTestCoverageReport.md b/caffe2/python/serialized_test/SerializedTestCoverageReport.md

@@ -1,11 +1,11 @@
 # Serialized Test Coverage Report
 This is an automatically generated file. Please see `caffe2/python/serialized_test/README.md` for details. In the case of merge conflicts, please rebase and regenerate.
 ## Summary
-Serialized tests have covered 217/684 (31.7%) operators
+Serialized tests have covered 219/688 (31.8%) operators
 
 ## Not covered operators
 <details>
-<summary>There are 467 not covered operators</summary>
+<summary>There are 469 not covered operators</summary>
 
 * APMeter
 * ATen
@@ -17,6 +17,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 * Adam
 * Add
 * AddGradient
+* AdjustBatch
 * Alias
 * Allgather
 * Allreduce
@@ -96,6 +97,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 * CubeGradient
 * DBExists
 * DataCouple
+* DenseVectorToIdList
 * DepthConcat
 * DepthSplit
 * DequeueBlobs
@@ -478,7 +480,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 ## Covered operators
 <details>
-<summary>There are 217 covered operators</summary>
+<summary>There are 219 covered operators</summary>
 
 * Acos
 * AcosGradient
@@ -543,6 +545,8 @@ Serialized tests have covered 217/684 (31.7%) operators
 * ElementwiseLinearGradient
 * Elu
 * EluGradient
+* Erf
+* ErfGradient
 * Expand
 * ExpandGradient
 * FC
@@ -702,7 +706,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 ## Excluded from coverage statistics
 ### Schemaless operators
 <details>
-<summary>There are 21 schemaless operators</summary>
+<summary>There are 22 schemaless operators</summary>
 
 * C10Add_DontUseThisOpYet
 * C10AveragedLoss_DontUseThisOpYet
@@ -718,6 +722,7 @@ Serialized tests have covered 217/684 (31.7%) operators
 * C10GivenTensorFill_DontUseThisOpYet
 * C10GivenTensorInt64Fill_DontUseThisOpYet
 * C10GivenTensorIntFill_DontUseThisOpYet
+* C10LayerNorm_DontUseThisOpYet
 * C10Mul_DontUseThisOpYet
 * C10Relu_DontUseThisOpYet
 * C10SigmoidCrossEntropyWithLogits_DontUseThisOpYet