[caffe2] explicitly pass use_offsets=false when calling fbgemm embedding kernels (#35711)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35711

As title

Test Plan: CI

Reviewed By: jianyuh

Differential Revision: D20747290

fbshipit-source-id: fc9fced744cc8f0c61a671cb4b424ff067c2573d
This commit is contained in:
Jongsoo Park 2020-03-31 08:33:19 -07:00 committed by Facebook GitHub Bot
parent 81c2412721
commit ada647214f
3 changed files with 44 additions and 16 deletions

View File

@ -76,14 +76,18 @@ class SparseLengthsFused8BitRowwiseOp : public Operator<Context> {
block_size, block_size,
with_weights, with_weights,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} else { } else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value)); CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel64_ = fbgemm::GenerateEmbeddingSpMDM<std::uint8_t, std::int64_t>( kernel64_ = fbgemm::GenerateEmbeddingSpMDM<std::uint8_t, std::int64_t>(
block_size, block_size,
with_weights, with_weights,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} }
} }

View File

@ -92,7 +92,9 @@ class SparseLengthsFusedNBitRowwiseOp final : public Operator<Context> {
block_size, block_size,
weights != nullptr, weights != nullptr,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} else { } else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value)); CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel64_ = fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>( kernel64_ = fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>(
@ -100,7 +102,9 @@ class SparseLengthsFusedNBitRowwiseOp final : public Operator<Context> {
block_size, block_size,
weights != nullptr, weights != nullptr,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} }
} }
@ -408,7 +412,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
block_size, block_size,
weights != nullptr, weights != nullptr,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} else { } else {
kernel32_ = kernel32_ =
fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<std::int32_t>( fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<std::int32_t>(
@ -416,7 +422,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
block_size, block_size,
weights != nullptr, weights != nullptr,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} }
} else { } else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value)); CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
@ -426,7 +434,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
block_size, block_size,
weights != nullptr, weights != nullptr,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} else { } else {
kernel64_ = kernel64_ =
fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<std::int64_t>( fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<std::int64_t>(
@ -434,7 +444,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
block_size, block_size,
weights != nullptr, weights != nullptr,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} }
} }
} else { // fallback_to_no_sparse == true } else { // fallback_to_no_sparse == true
@ -446,7 +458,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
block_size, block_size,
with_weights, with_weights,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} else { } else {
kernel32_no_sparse_ = kernel32_no_sparse_ =
fbgemm::GenerateEmbeddingSpMDMNBit<std::int32_t>( fbgemm::GenerateEmbeddingSpMDMNBit<std::int32_t>(
@ -454,7 +468,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
block_size, block_size,
weights != nullptr, weights != nullptr,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} }
} else { } else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value)); CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
@ -464,7 +480,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
block_size, block_size,
with_weights, with_weights,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} else { } else {
kernel64_no_sparse_ = kernel64_no_sparse_ =
fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>( fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>(
@ -472,7 +490,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
block_size, block_size,
weights != nullptr, weights != nullptr,
is_mean, is_mean,
/*prefetch distance*/ 16); /*prefetch distance*/ 16,
/*is_weight_positional*/ false,
/*use_offsets*/ false);
} }
} }
} }

View File

@ -100,7 +100,8 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
USE_WEIGHT, USE_WEIGHT,
USE_MEAN, USE_MEAN,
/*prefetch distance*/ 16, /*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT); USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
} else { } else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value)); CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel_fp32_i64_ = kernel_fp32_i64_ =
@ -109,7 +110,8 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
USE_WEIGHT, USE_WEIGHT,
USE_MEAN, USE_MEAN,
/*prefetch distance*/ 16, /*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT); USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
} }
} else { } else {
CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value)); CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value));
@ -120,7 +122,8 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
USE_WEIGHT, USE_WEIGHT,
USE_MEAN, USE_MEAN,
/*prefetch distance*/ 16, /*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT); USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
} else { } else {
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value)); CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
kernel_fp16_i64_ = kernel_fp16_i64_ =
@ -129,7 +132,8 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
USE_WEIGHT, USE_WEIGHT,
USE_MEAN, USE_MEAN,
/*prefetch distance*/ 16, /*prefetch distance*/ 16,
USE_POSITIONAL_WEIGHT); USE_POSITIONAL_WEIGHT,
/*use_offsets*/ false);
} }
} }
} }