[MPS][BE] Delete unused lerp functors (#152443)
For `lerp.Scalar_out`, the weight (aka alpha) is not an optional argument, so there is no point in keeping those specializations.
But the `alpha == 1.0` check is moved ahead of dispatching to Metal shaders, as a plain copy of the tensor should still be faster: a1a4fee3b8/aten/src/ATen/native/mps/operations/BinaryOps.mm (L285-L290)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/152443
Approved by: https://github.com/Skylion007
This commit is contained in:
parent 4a63cab624
commit c01bcc5efb
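For context, lerp computes start + weight * (end - start), so a weight of exactly 1.0 reduces to the end tensor and a plain copy suffices. A minimal Python sketch of that identity (the device pick, shapes, and values below are illustrative only, not part of the change):

import torch

# lerp(a, b, w) = a + w * (b - a)  =>  lerp(a, b, 1.0) == b
device = "mps" if torch.backends.mps.is_available() else "cpu"
a = torch.randn(4, device=device)
b = torch.randn(4, device=device)

out = torch.lerp(a, b, 1.0)
assert torch.allclose(out, b)

# In-place variant: the result is again just a copy of b.
a.lerp_(b, 1.0)
assert torch.allclose(a, b)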
@@ -18,13 +18,6 @@ struct sub_functor {
   }
 };
 
-struct lerp_functor {
-  template <typename T>
-  inline T operator()(const T a, const T b) {
-    return static_cast<T>(b);
-  }
-};
-
 struct add_alpha_functor {
   template <typename T>
   inline T operator()(const T a, const T b, const T alpha) {
@@ -229,13 +222,6 @@ struct complex_lerp_alpha_functor {
   }
 };
 
-struct complex_lerp_functor {
-  template <typename T>
-  inline T operator()(const T a, const T b) {
-    return T(b.x, b.y);
-  }
-};
-
 REGISTER_BINARY_OP(copysign, long, float);
 REGISTER_BINARY_OP(copysign, int, float);
 REGISTER_BINARY_OP(copysign, float, float);
@@ -282,14 +268,6 @@ REGISTER_BINARY_OP(sub, short, short);
 REGISTER_BINARY_OP(sub, uchar, uchar);
 REGISTER_BINARY_OP(sub, char, char);
 REGISTER_BINARY_OP(sub, bool, bool);
-REGISTER_BINARY_OP(lerp, long, long);
-REGISTER_BINARY_OP(lerp, int, int);
-REGISTER_BINARY_OP(lerp, float, float);
-REGISTER_BINARY_OP(lerp, half, half);
-REGISTER_BINARY_OP(lerp, short, short);
-REGISTER_BINARY_OP(lerp, uchar, uchar);
-REGISTER_BINARY_OP(lerp, char, char);
-REGISTER_BINARY_OP(lerp, bool, bool);
 REGISTER_BINARY_ALPHA_OP(add_alpha, long, long);
 REGISTER_BINARY_ALPHA_OP(add_alpha, int, int);
 REGISTER_BINARY_ALPHA_OP(add_alpha, float, float);
@@ -330,7 +308,6 @@ REGISTER_BINARY_OP(hermite_polynomial_h, bfloat, bfloat);
 REGISTER_BINARY_OP(hermite_polynomial_he, bfloat, bfloat);
 REGISTER_BINARY_OP(add, bfloat, bfloat);
 REGISTER_BINARY_OP(sub, bfloat, bfloat);
-REGISTER_BINARY_OP(lerp, bfloat, bfloat);
 REGISTER_BINARY_ALPHA_OP(add_alpha, bfloat, bfloat);
 REGISTER_BINARY_ALPHA_OP(sub_alpha, bfloat, bfloat);
 REGISTER_BINARY_ALPHA_OP(lerp_alpha, bfloat, bfloat);
@@ -347,8 +324,6 @@ REGISTER_BINARY_OP(add, float2, float2);
 REGISTER_BINARY_OP(add, half2, half2);
 REGISTER_BINARY_OP(sub, float2, float2);
 REGISTER_BINARY_OP(sub, half2, half2);
-REGISTER_BINARY_OP(lerp, float2, float2);
-REGISTER_BINARY_OP(lerp, half2, half2);
 REGISTER_BINARY_ALPHA_OP(complex_add_alpha, float2, float2);
 REGISTER_BINARY_ALPHA_OP(complex_add_alpha, half2, half2);
 REGISTER_BINARY_ALPHA_OP(complex_sub_alpha, float2, float2);
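Note that the deleted lerp_functor and complex_lerp_functor simply returned the second operand, i.e., they only covered the weight == 1.0 case; the lerp_alpha variants that take an explicit weight stay registered (see the surrounding context lines above). A short illustrative sketch of the path that still hits the lerp kernels (the 0.25 weight is an arbitrary example value):

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
a = torch.rand(3, device=device)
b = torch.rand(3, device=device)

# A weight other than 1.0 still goes through the regular lerp kernels.
out = torch.lerp(a, b, 0.25)
assert torch.allclose(out, a + 0.25 * (b - a))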
aten/src/ATen/native/mps/operations/BinaryOps.mm
@@ -265,6 +265,13 @@ static void add_sub_lerp_template(const Tensor& self,
   }
 
   const bool alpha_has_value = alpha.toDouble() != 1.0;
+  if (!alpha_has_value && op_name == "lerp") {
+    if (!self.is_alias_of(other)) { // if inplace, no-op
+      output.copy_(other);
+    }
+    return;
+  }
+
   auto self_complex = c10::isComplexType(self.scalar_type());
   auto other_complex = c10::isComplexType(other.scalar_type());
   auto commonDtype = at::result_type(self, other);
@@ -282,13 +289,6 @@ static void add_sub_lerp_template(const Tensor& self,
     return;
   }
 
-  if (!alpha_has_value && op_name == "lerp") {
-    if (!self.is_alias_of(other)) { // if inplace, no-op
-      output.copy_(other);
-    }
-    return;
-  }
-
   BinaryOpBlock add_sub_lerp_op_block = ^BinaryOpFn(cachedGraph, primaryCastTensor, secondaryCastTensor) {
     MPSGraph* mpsGraph = cachedGraph->graph();
     MPSGraphTensor* secondaryTensor = secondaryCastTensor;
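With this move, the early return for a weight of exactly 1.0 now fires before the complex/common-dtype handling and before anything is dispatched to the Metal shaders, so the operation degenerates to a plain copy (or, when the output already aliases the end tensor in the in-place case, to a no-op). A tiny, purely illustrative example of that in-place no-op:

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.randn(4, device=device)

# self and end share storage, so with weight 1.0 there is nothing to copy.
x.lerp_(x, 1.0)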