[MPS][BE] Delete unused lerp functors (#152443)
For `lerp.Scalar_out`, the weight (aka alpha) is not an optional argument, so there is no point in keeping those specializations.
This change also moves the `alpha == 1.0` check ahead of the dispatch to Metal shaders, since a plain copy of the tensor should still be faster: a1a4fee3b8/aten/src/ATen/native/mps/operations/BinaryOps.mm (L285-L290)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152443
Approved by: https://github.com/Skylion007
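
For context, `lerp(a, b, w)` computes `a + w * (b - a)`, so a weight of exactly 1 yields `b`, and a plain copy of the second operand is a correct answer. A minimal sketch of that identity against the public libtorch C++ API (CPU tensors for portability; nothing below is code from this PR):

#include <torch/torch.h>
#include <iostream>

int main() {
  auto a = torch::rand({4});
  auto b = torch::rand({4});

  // lerp(a, b, w) = a + w * (b - a); with w == 1 this degenerates to b,
  // which is why the MPS backend can answer with output.copy_(other)
  // instead of launching a Metal kernel.
  auto out = torch::lerp(a, b, /*weight=*/1.0);
  std::cout << std::boolalpha << out.allclose(b) << '\n'; // prints: true
}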
This commit is contained in:
parent 4a63cab624
commit c01bcc5efb
aten/src/ATen/native/mps/kernels/BinaryKernel.metal
@@ -18,13 +18,6 @@ struct sub_functor {
   }
 };
 
-struct lerp_functor {
-  template <typename T>
-  inline T operator()(const T a, const T b) {
-    return static_cast<T>(b);
-  }
-};
-
 struct add_alpha_functor {
   template <typename T>
   inline T operator()(const T a, const T b, const T alpha) {
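
The `lerp_functor` deleted above only served the degenerate `weight == 1` case: it ignores `a` and returns `b` unchanged. General weights are handled by the `lerp_alpha` specializations this commit keeps (note the `REGISTER_BINARY_ALPHA_OP(lerp_alpha, ...)` registrations left intact in the hunks below). As a hedged sketch, a three-argument functor of that family computes something along these lines; the exact body in the shader source may differ:

// Sketch only, not the literal kernel source: the math a lerp-with-alpha
// functor must implement is a + alpha * (b - a).
struct lerp_alpha_functor_sketch {
  template <typename T>
  inline T operator()(const T a, const T b, const T alpha) {
    return static_cast<T>(a + alpha * (b - a));
  }
};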
@@ -229,13 +222,6 @@ struct complex_lerp_alpha_functor {
   }
 };
 
-struct complex_lerp_functor {
-  template <typename T>
-  inline T operator()(const T a, const T b) {
-    return T(b.x, b.y);
-  }
-};
-
 REGISTER_BINARY_OP(copysign, long, float);
 REGISTER_BINARY_OP(copysign, int, float);
 REGISTER_BINARY_OP(copysign, float, float);
@@ -282,14 +268,6 @@ REGISTER_BINARY_OP(sub, short, short);
 REGISTER_BINARY_OP(sub, uchar, uchar);
 REGISTER_BINARY_OP(sub, char, char);
 REGISTER_BINARY_OP(sub, bool, bool);
-REGISTER_BINARY_OP(lerp, long, long);
-REGISTER_BINARY_OP(lerp, int, int);
-REGISTER_BINARY_OP(lerp, float, float);
-REGISTER_BINARY_OP(lerp, half, half);
-REGISTER_BINARY_OP(lerp, short, short);
-REGISTER_BINARY_OP(lerp, uchar, uchar);
-REGISTER_BINARY_OP(lerp, char, char);
-REGISTER_BINARY_OP(lerp, bool, bool);
 REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
 REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
 REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
@@ -330,7 +308,6 @@ REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat);
 REGISTER_BINARY_OP(hermite_polynomial_he, bfloat, bfloat);
 REGISTER_BINARY_OP(add, bfloat, bfloat);
 REGISTER_BINARY_OP(sub, bfloat, bfloat);
-REGISTER_BINARY_OP(lerp, bfloat, bfloat);
 REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat);
 REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat);
 REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat);
@@ -347,8 +324,6 @@ REGISTER_BINARY_OP(add, float2, float2);
 REGISTER_BINARY_OP(add, half2, half2);
 REGISTER_BINARY_OP(sub, float2, float2);
 REGISTER_BINARY_OP(sub, half2, half2);
-REGISTER_BINARY_OP(lerp, float2, float2);
-REGISTER_BINARY_OP(lerp, half2, half2);
 REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2);
 REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2);
 REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2);

aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -265,6 +265,13 @@ static void add_sub_lerp_template(const Tensor& self,
   }
 
+  const bool alpha_has_value = alpha.toDouble() != 1.0;
+  if (!alpha_has_value && op_name == "lerp") {
+    if (!self.is_alias_of(other)) { // if inplace, no-op
+      output.copy_(other);
+    }
+    return;
+  }
+
 auto self_complex = c10::isComplexType(self.scalar_type());
 auto other_complex = c10::isComplexType(other.scalar_type());
 auto commonDtype = at::result_type(self, other);
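
The hunk above hoists the `weight == 1` fast path to the very top of `add_sub_lerp_template`, before dtype promotion via `at::result_type` and before any MPSGraph construction, so the degenerate case pays for nothing but a copy; the next hunk deletes the now-redundant copy of the same check further down the function. A hypothetical caller-side sketch of the two branches the early return distinguishes (assumes an MPS-enabled libtorch build):

#include <torch/torch.h>

// Demo only, not part of the PR.
void lerp_weight_one_fast_path() {
  auto self = torch::rand({8}, torch::kMPS);
  auto end = torch::rand({8}, torch::kMPS);

  // Out-of-place with weight == 1.0: serviced by output.copy_(other)
  // before any Metal shader or MPSGraph dispatch happens.
  auto out = torch::lerp(self, end, /*weight=*/1.0);

  // In-place where output aliases other (self.is_alias_of(other) holds):
  // even the copy is skipped, making the call a no-op.
  self.lerp_(self, /*weight=*/1.0);
}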
@@ -282,13 +289,6 @@ static void add_sub_lerp_template(const Tensor& self,
     return;
   }
 
-  if (!alpha_has_value && op_name == "lerp") {
-    if (!self.is_alias_of(other)) { // if inplace, no-op
-      output.copy_(other);
-    }
-    return;
-  }
-
 BinaryOpBlock add_sub_lerp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
   MPSGraph* mpsGraph = cachedGraph->graph();
   MPSGraphTensor* secondaryTensor = secondaryCastTensor;