[MPS][BE] Delete unused lerp functors (#152443)

For `lerp.Scalar_out` weight (aka alpha) is not an optional argument, so no point in having those specializations.
But move `alpha=1.0` ahead of dispatching to Metal shaders, as plain copy of tensor should still be faster a1a4fee3b8/aten/src/ATen/native/mps/operations/BinaryOps.mm (L285-L290)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152443
Approved by: https://github.com/Skylion007
This commit is contained in:
Nikita Shulga 2025-04-29 18:50:27 -07:00 committed by PyTorch MergeBot
parent 4a63cab624
commit c01bcc5efb
2 changed files with 7 additions and 32 deletions

View File

@ -18,13 +18,6 @@ struct sub_functor {
}
};
struct lerp_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return static_cast<T>(b);
}
};
struct add_alpha_functor {
template <typename T>
inline T operator()(const T a, const T b, const T alpha) {
@ -229,13 +222,6 @@ struct complex_lerp_alpha_functor {
}
};
struct complex_lerp_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return T(b.x, b.y);
}
};
REGISTER_BINARY_OP(copysign, long, float);
REGISTER_BINARY_OP(copysign, int, float);
REGISTER_BINARY_OP(copysign, float, float);
@ -282,14 +268,6 @@ REGISTER_BINARY_OP(sub, short, short);
REGISTER_BINARY_OP(sub, uchar, uchar);
REGISTER_BINARY_OP(sub, char, char);
REGISTER_BINARY_OP(sub, bool, bool);
REGISTER_BINARY_OP(lerp, long, long);
REGISTER_BINARY_OP(lerp, int, int);
REGISTER_BINARY_OP(lerp, float, float);
REGISTER_BINARY_OP(lerp, half, half);
REGISTER_BINARY_OP(lerp, short, short);
REGISTER_BINARY_OP(lerp, uchar, uchar);
REGISTER_BINARY_OP(lerp, char, char);
REGISTER_BINARY_OP(lerp, bool, bool);
REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
@ -330,7 +308,6 @@ REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat);
REGISTER_BINARY_OP(hermite_polynomial_he, bfloat, bfloat);
REGISTER_BINARY_OP(add, bfloat, bfloat);
REGISTER_BINARY_OP(sub, bfloat, bfloat);
REGISTER_BINARY_OP(lerp, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat);
REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat);
@ -347,8 +324,6 @@ REGISTER_BINARY_OP(add, float2, float2);
REGISTER_BINARY_OP(add, half2, half2);
REGISTER_BINARY_OP(sub, float2, float2);
REGISTER_BINARY_OP(sub, half2, half2);
REGISTER_BINARY_OP(lerp, float2, float2);
REGISTER_BINARY_OP(lerp, half2, half2);
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2);
REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2);
REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2);

View File

@ -265,6 +265,13 @@ static void add_sub_lerp_template(const Tensor& self,
}
const bool alpha_has_value = alpha.toDouble() != 1.0;
if (!alpha_has_value && op_name == "lerp") {
if (!self.is_alias_of(other)) { // if inplace, no-op
output.copy_(other);
}
return;
}
auto self_complex = c10::isComplexType(self.scalar_type());
auto other_complex = c10::isComplexType(other.scalar_type());
auto commonDtype = at::result_type(self, other);
@ -282,13 +289,6 @@ static void add_sub_lerp_template(const Tensor& self,
return;
}
if (!alpha_has_value && op_name == "lerp") {
if (!self.is_alias_of(other)) { // if inplace, no-op
output.copy_(other);
}
return;
}
BinaryOpBlock add_sub_lerp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
MPSGraph* mpsGraph = cachedGraph->graph();
MPSGraphTensor* secondaryTensor = secondaryCastTensor;