mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[FakeLowP] T76913842 Make AddFakeFp16 take int inputs (#45992)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45992
Created a templated version of AddFakeFp16 that accepts both float and int inputs.
Test Plan: notebook with local bento kernel: N369049
Reviewed By: amylittleyang
Differential Revision: D24169720
fbshipit-source-id: 679de391224f65f6c5b3ca890eb0d157f09712f6
This commit is contained in:
parent
c86655a815
commit
a36f11a3a5
|
|
@ -22,7 +22,21 @@ int getSizeFromDims(const std::vector<int>& dims) {
|
|||
|
||||
template <class OP>
|
||||
struct FP16PairWiseCPUFunctor : public OP {
|
||||
template <typename TIn, typename TOut>
|
||||
bool Forward(
|
||||
const std::vector<int>& A_dims,
|
||||
const std::vector<int>& B_dims,
|
||||
const TIn* A,
|
||||
const TIn* B,
|
||||
TOut* C,
|
||||
CPUContext* context) const {
|
||||
OP::Forward(A_dims, B_dims, A, B, C, context);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
template<>
|
||||
bool Forward<float, float>(
|
||||
const std::vector<int>& A_dims,
|
||||
const std::vector<int>& B_dims,
|
||||
const float* A,
|
||||
|
|
@ -54,7 +68,7 @@ OPERATOR_SCHEMA(SumFakeFp16).NumInputs(1, INT_MAX).NumOutputs(1, INT_MAX);
|
|||
REGISTER_CPU_OPERATOR(
|
||||
AddFakeFp16,
|
||||
BinaryElementwiseOp<
|
||||
TensorTypes<float>,
|
||||
TensorTypes<float, int>,
|
||||
CPUContext,
|
||||
FP16PairWiseCPUFunctor<AddFunctor<CPUContext>>>);
|
||||
OPERATOR_SCHEMA(AddFakeFp16).NumInputs(2).NumOutputs(1);
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user