[MPS][BE] Migrate complex_mul to tensor iterator (#149728)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/149728
Approved by: https://github.com/dcci
ghstack dependencies: #149727
This commit is contained in:
Nikita Shulga 2025-03-21 06:49:36 -07:00 committed by PyTorch MergeBot
parent e35ef61066
commit 64d22b9fad
2 changed files with 15 additions and 33 deletions

View File

@@ -83,6 +83,7 @@ struct nextafter_functor {
} }
}; };
// Complex binary functors
struct polar_functor { struct polar_functor {
template <typename U> template <typename U>
using ret_type = c10::metal::vec2type_t<U>; using ret_type = c10::metal::vec2type_t<U>;
@@ -102,6 +103,13 @@ struct make_complex_functor {
} }
}; };
// Element-wise complex multiplication functor for the binary-op tensor
// iterator. T is a 2-component vector type (registered below for float2 and
// half2) holding a complex number as (.x = real, .y = imaginary):
//   (a.x + i*a.y) * (b.x + i*b.y) = (a.x*b.x - a.y*b.y) + i*(a.x*b.y + a.y*b.x)
struct complex_mul_functor {
template <typename T>
inline T operator()(const T a, const T b) {
return T(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
}
};
REGISTER_BINARY_INDEXING_OP(copysign, long); REGISTER_BINARY_INDEXING_OP(copysign, long);
REGISTER_BINARY_INDEXING_OP(copysign, int); REGISTER_BINARY_INDEXING_OP(copysign, int);
REGISTER_BINARY_INDEXING_OP(copysign, float); REGISTER_BINARY_INDEXING_OP(copysign, float);
@@ -135,28 +143,5 @@ REGISTER_BINARY_INDEXING_OP(polar, float);
REGISTER_BINARY_INDEXING_OP(polar, half); REGISTER_BINARY_INDEXING_OP(polar, half);
REGISTER_BINARY_INDEXING_OP(make_complex, float); REGISTER_BINARY_INDEXING_OP(make_complex, float);
REGISTER_BINARY_INDEXING_OP(make_complex, half); REGISTER_BINARY_INDEXING_OP(make_complex, half);
REGISTER_BINARY_INDEXING_OP(complex_mul, float2);
template <typename T> REGISTER_BINARY_INDEXING_OP(complex_mul, half2);
kernel void complex_mul(
constant void* input_ [[buffer(0)]],
constant void* other_ [[buffer(1)]],
device void* out_ [[buffer(2)]],
constant uint3* offsets [[buffer(3)]],
uint tid [[thread_position_in_grid]]) {
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
out[0] = input[0] * other[0] - input[1] * other[1];
out[1] = input[0] * other[1] + input[1] * other[0];
}
#define REGISTER_BINARY_OP(NAME, DTYPE) \
template [[host_name(#NAME "_" #DTYPE)]] kernel void NAME<DTYPE>( \
constant void* input_, \
constant void* other_, \
device void* out_, \
constant uint3* offsets, \
uint tid)
REGISTER_BINARY_OP(complex_mul, float);
REGISTER_BINARY_OP(complex_mul, half);

View File

@@ -42,13 +42,11 @@ void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& out
return; return;
} }
auto common_dtype = output.scalar_type(); auto common_dtype = output.scalar_type();
auto output_as_real = at::view_as_real(output).select(output.dim(), 0); auto input_cast = input.to(kMPS, common_dtype);
auto input_as_real = at::view_as_real(input.to(kMPS, common_dtype)).select(input.dim(), 0); auto other_cast = other.to(kMPS, common_dtype);
auto other_as_real = at::view_as_real(other.to(kMPS, common_dtype)).select(other.dim(), 0); auto iter = TensorIteratorConfig().add_output(output).add_input(input_cast).add_input(other_cast).build();
auto iter =
TensorIteratorConfig().add_output(output_as_real).add_input(input_as_real).add_input(other_as_real).build();
lib.exec_binary_kernel(iter, "complex_mul", /*supports_dense=*/false); lib.exec_binary_kernel(iter, "complex_mul");
} }
} // namespace mps } // namespace mps
@@ -115,8 +113,7 @@ Tensor& complex_out_mps(const Tensor& real, const Tensor& imag, Tensor& output)
if (!output.sizes().equals(new_size)) { if (!output.sizes().equals(new_size)) {
output.resize_(new_size); output.resize_(new_size);
} }
uint32_t length = output.numel(); if (output.numel() == 0) {
if (length == 0) {
return output; return output;
} }
auto output_as_real = at::view_as_real(output).select(output.dim(), 0); auto output_as_real = at::view_as_real(output).select(output.dim(), 0);