[FakeLowP] T76913842 Make AddFakeFp16 take int inputs (#45992)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45992

Created a templated version of the `Forward` method in `FP16PairWiseCPUFunctor` (with a float/float specialization) so that AddFakeFp16 can take both float and int inputs.

Test Plan: notebook with local bento kernel: N369049

Reviewed By: amylittleyang

Differential Revision: D24169720

fbshipit-source-id: 679de391224f65f6c5b3ca890eb0d157f09712f6
This commit is contained in:
Venkata Chintapalli 2020-10-07 17:41:03 -07:00 committed by Facebook GitHub Bot
parent c86655a815
commit a36f11a3a5

View File

@ -22,7 +22,21 @@ int getSizeFromDims(const std::vector<int>& dims) {
template <class OP>
struct FP16PairWiseCPUFunctor : public OP {
template <typename TIn, typename TOut>
bool Forward(
const std::vector<int>& A_dims,
const std::vector<int>& B_dims,
const TIn* A,
const TIn* B,
TOut* C,
CPUContext* context) const {
OP::Forward(A_dims, B_dims, A, B, C, context);
return true;
}
template<>
bool Forward<float, float>(
const std::vector<int>& A_dims,
const std::vector<int>& B_dims,
const float* A,
@ -54,7 +68,7 @@ OPERATOR_SCHEMA(SumFakeFp16).NumInputs(1, INT_MAX).NumOutputs(1, INT_MAX);
REGISTER_CPU_OPERATOR(
AddFakeFp16,
BinaryElementwiseOp<
TensorTypes<float>,
TensorTypes<float, int>,
CPUContext,
FP16PairWiseCPUFunctor<AddFunctor<CPUContext>>>);
OPERATOR_SCHEMA(AddFakeFp16).NumInputs(2).NumOutputs(1);