mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[caffe2] explicitly pass use_offsets=false when calling fbgemm embedding kernels (#35711)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35711

As the title says: explicitly pass use_offsets=false when calling the fbgemm embedding kernels.

Test Plan: CI

Reviewed By: jianyuh

Differential Revision: D20747290

fbshipit-source-id: fc9fced744cc8f0c61a671cb4b424ff067c2573d
This commit is contained in:
parent
81c2412721
commit
ada647214f
|
|
@ -76,14 +76,18 @@ class SparseLengthsFused8BitRowwiseOp : public Operator<Context> {
|
||||||
block_size,
|
block_size,
|
||||||
with_weights,
|
with_weights,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
} else {
|
} else {
|
||||||
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
||||||
kernel64_ = fbgemm::GenerateEmbeddingSpMDM<std::uint8_t, std::int64_t>(
|
kernel64_ = fbgemm::GenerateEmbeddingSpMDM<std::uint8_t, std::int64_t>(
|
||||||
block_size,
|
block_size,
|
||||||
with_weights,
|
with_weights,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -92,7 +92,9 @@ class SparseLengthsFusedNBitRowwiseOp final : public Operator<Context> {
|
||||||
block_size,
|
block_size,
|
||||||
weights != nullptr,
|
weights != nullptr,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
} else {
|
} else {
|
||||||
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
||||||
kernel64_ = fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>(
|
kernel64_ = fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>(
|
||||||
|
|
@ -100,7 +102,9 @@ class SparseLengthsFusedNBitRowwiseOp final : public Operator<Context> {
|
||||||
block_size,
|
block_size,
|
||||||
weights != nullptr,
|
weights != nullptr,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -408,7 +412,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
|
||||||
block_size,
|
block_size,
|
||||||
weights != nullptr,
|
weights != nullptr,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
} else {
|
} else {
|
||||||
kernel32_ =
|
kernel32_ =
|
||||||
fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<std::int32_t>(
|
fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<std::int32_t>(
|
||||||
|
|
@ -416,7 +422,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
|
||||||
block_size,
|
block_size,
|
||||||
weights != nullptr,
|
weights != nullptr,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
||||||
|
|
@ -426,7 +434,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
|
||||||
block_size,
|
block_size,
|
||||||
weights != nullptr,
|
weights != nullptr,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
} else {
|
} else {
|
||||||
kernel64_ =
|
kernel64_ =
|
||||||
fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<std::int64_t>(
|
fbgemm::GenerateEmbeddingSpMDMNBitRowWiseSparse<std::int64_t>(
|
||||||
|
|
@ -434,7 +444,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
|
||||||
block_size,
|
block_size,
|
||||||
weights != nullptr,
|
weights != nullptr,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else { // fallback_to_no_sparse == true
|
} else { // fallback_to_no_sparse == true
|
||||||
|
|
@ -446,7 +458,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
|
||||||
block_size,
|
block_size,
|
||||||
with_weights,
|
with_weights,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
} else {
|
} else {
|
||||||
kernel32_no_sparse_ =
|
kernel32_no_sparse_ =
|
||||||
fbgemm::GenerateEmbeddingSpMDMNBit<std::int32_t>(
|
fbgemm::GenerateEmbeddingSpMDMNBit<std::int32_t>(
|
||||||
|
|
@ -454,7 +468,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
|
||||||
block_size,
|
block_size,
|
||||||
weights != nullptr,
|
weights != nullptr,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
||||||
|
|
@ -464,7 +480,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
|
||||||
block_size,
|
block_size,
|
||||||
with_weights,
|
with_weights,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
} else {
|
} else {
|
||||||
kernel64_no_sparse_ =
|
kernel64_no_sparse_ =
|
||||||
fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>(
|
fbgemm::GenerateEmbeddingSpMDMNBit<std::int64_t>(
|
||||||
|
|
@ -472,7 +490,9 @@ class SparseLengthsNBitRowwiseSparseOp final : public Operator<CPUContext> {
|
||||||
block_size,
|
block_size,
|
||||||
weights != nullptr,
|
weights != nullptr,
|
||||||
is_mean,
|
is_mean,
|
||||||
/*prefetch distance*/ 16);
|
/*prefetch distance*/ 16,
|
||||||
|
/*is_weight_positional*/ false,
|
||||||
|
/*use_offsets*/ false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -100,7 +100,8 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
|
||||||
USE_WEIGHT,
|
USE_WEIGHT,
|
||||||
USE_MEAN,
|
USE_MEAN,
|
||||||
/*prefetch distance*/ 16,
|
/*prefetch distance*/ 16,
|
||||||
USE_POSITIONAL_WEIGHT);
|
USE_POSITIONAL_WEIGHT,
|
||||||
|
/*use_offsets*/ false);
|
||||||
} else {
|
} else {
|
||||||
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
||||||
kernel_fp32_i64_ =
|
kernel_fp32_i64_ =
|
||||||
|
|
@ -109,7 +110,8 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
|
||||||
USE_WEIGHT,
|
USE_WEIGHT,
|
||||||
USE_MEAN,
|
USE_MEAN,
|
||||||
/*prefetch distance*/ 16,
|
/*prefetch distance*/ 16,
|
||||||
USE_POSITIONAL_WEIGHT);
|
USE_POSITIONAL_WEIGHT,
|
||||||
|
/*use_offsets*/ false);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value));
|
CAFFE_ENFORCE((std::is_same<InputType, at::Half>::value));
|
||||||
|
|
@ -120,7 +122,8 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
|
||||||
USE_WEIGHT,
|
USE_WEIGHT,
|
||||||
USE_MEAN,
|
USE_MEAN,
|
||||||
/*prefetch distance*/ 16,
|
/*prefetch distance*/ 16,
|
||||||
USE_POSITIONAL_WEIGHT);
|
USE_POSITIONAL_WEIGHT,
|
||||||
|
/*use_offsets*/ false);
|
||||||
} else {
|
} else {
|
||||||
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
CAFFE_ENFORCE((std::is_same<IndexType, std::int64_t>::value));
|
||||||
kernel_fp16_i64_ =
|
kernel_fp16_i64_ =
|
||||||
|
|
@ -129,7 +132,8 @@ class CPUSparseLengthsReductionOp : public Operator<CPUContext> {
|
||||||
USE_WEIGHT,
|
USE_WEIGHT,
|
||||||
USE_MEAN,
|
USE_MEAN,
|
||||||
/*prefetch distance*/ 16,
|
/*prefetch distance*/ 16,
|
||||||
USE_POSITIONAL_WEIGHT);
|
USE_POSITIONAL_WEIGHT,
|
||||||
|
/*use_offsets*/ false);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user