[MPS][BE] Migrate complex_mul to tensor iterator (#149728)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149728
Approved by: https://github.com/dcci
ghstack dependencies: #149727
This commit is contained in:
parent e35ef61066
commit 64d22b9fad
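For reference, the component-wise formula used by the new complex_mul_functor in the diff below, (a.x*b.x - a.y*b.y, a.x*b.y + a.y*b.x), is ordinary complex multiplication on (real, imaginary) pairs. The following standalone C++ sketch is not part of the PR; the float2_like struct is an illustrative stand-in for Metal's float2, and the check simply compares the functor's arithmetic against std::complex:

#include <cassert>
#include <cmath>
#include <complex>

// Stand-in for Metal's float2: x holds the real part, y the imaginary part.
struct float2_like {
  float x, y;
};

// Same arithmetic as complex_mul_functor::operator() in the diff below.
float2_like complex_mul(float2_like a, float2_like b) {
  return {a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x};
}

int main() {
  float2_like a{1.5f, -2.0f}, b{0.25f, 3.0f};
  auto ref = std::complex<float>(a.x, a.y) * std::complex<float>(b.x, b.y);
  auto got = complex_mul(a, b);
  assert(std::fabs(got.x - ref.real()) < 1e-6f);
  assert(std::fabs(got.y - ref.imag()) < 1e-6f);
  return 0;
}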
@@ -83,6 +83,7 @@ struct nextafter_functor {
   }
 };
 
+// Complex binary functors
 struct polar_functor {
   template <typename U>
   using ret_type = c10::metal::vec2type_t<U>;
@@ -102,6 +103,13 @@ struct make_complex_functor {
   }
 };
 
+struct complex_mul_functor {
+  template <typename T>
+  inline T operator()(const T a, const T b) {
+    return T(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
+  }
+};
+
 REGISTER_BINARY_INDEXING_OP(copysign, long);
 REGISTER_BINARY_INDEXING_OP(copysign, int);
 REGISTER_BINARY_INDEXING_OP(copysign, float);
@@ -135,28 +143,5 @@ REGISTER_BINARY_INDEXING_OP(polar, float);
 REGISTER_BINARY_INDEXING_OP(polar, half);
 REGISTER_BINARY_INDEXING_OP(make_complex, float);
 REGISTER_BINARY_INDEXING_OP(make_complex, half);
-
-template <typename T>
-kernel void complex_mul(
-    constant void* input_ [[buffer(0)]],
-    constant void* other_ [[buffer(1)]],
-    device void* out_ [[buffer(2)]],
-    constant uint3* offsets [[buffer(3)]],
-    uint tid [[thread_position_in_grid]]) {
-  device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
-  constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
-  constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
-  out[0] = input[0] * other[0] - input[1] * other[1];
-  out[1] = input[0] * other[1] + input[1] * other[0];
-}
-
-#define REGISTER_BINARY_OP(NAME, DTYPE) \
-  template [[host_name(#NAME "_" #DTYPE)]] kernel void NAME<DTYPE>( \
-      constant void* input_, \
-      constant void* other_, \
-      device void* out_, \
-      constant uint3* offsets, \
-      uint tid)
-
-REGISTER_BINARY_OP(complex_mul, float);
-REGISTER_BINARY_OP(complex_mul, half);
+REGISTER_BINARY_INDEXING_OP(complex_mul, float2);
+REGISTER_BINARY_INDEXING_OP(complex_mul, half2);
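The float2/half2 registrations above lean on the fact that a complex element is stored as two adjacent scalars (real, then imaginary), so the functor can treat complex64/complex32 data as Metal float2/half2 values. A one-line host-side sanity check of that layout, phrased in terms of std::complex (an illustrative sketch, not code from the PR):

#include <complex>
// The standard guarantees std::complex<float> is laid out as {real, imag}.
static_assert(sizeof(std::complex<float>) == 2 * sizeof(float),
              "a complex64 element occupies exactly two adjacent floats");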
@@ -42,13 +42,11 @@ void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& out
     return;
   }
   auto common_dtype = output.scalar_type();
-  auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
-  auto input_as_real = at::view_as_real(input.to(kMPS, common_dtype)).select(input.dim(), 0);
-  auto other_as_real = at::view_as_real(other.to(kMPS, common_dtype)).select(other.dim(), 0);
-  auto iter =
-      TensorIteratorConfig().add_output(output_as_real).add_input(input_as_real).add_input(other_as_real).build();
+  auto input_cast = input.to(kMPS, common_dtype);
+  auto other_cast = other.to(kMPS, common_dtype);
+  auto iter = TensorIteratorConfig().add_output(output).add_input(input_cast).add_input(other_cast).build();
 
-  lib.exec_binary_kernel(iter, "complex_mul", /*supports_dense=*/false);
+  lib.exec_binary_kernel(iter, "complex_mul");
 }
 
 } // namespace mps
@@ -115,8 +113,7 @@ Tensor& complex_out_mps(const Tensor& real, const Tensor& imag, Tensor& output)
   if (!output.sizes().equals(new_size)) {
     output.resize_(new_size);
   }
-  uint32_t length = output.numel();
-  if (length == 0) {
+  if (output.numel() == 0) {
    return output;
  }
  auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
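At the user level the behavior should be unchanged: complex multiplication on MPS tensors produces the same values as before, it just now flows through the shared binary TensorIterator path. A short end-to-end sketch in C++ against the ATen API follows; it assumes an MPS-enabled build, and the claim that a * b for complex dtypes reaches the complex_mul kernel via complex_mul_out is an inference from the diff above, not something stated in the PR:

#include <ATen/ATen.h>

int main() {
  if (!at::hasMPS()) return 0;  // skip on machines without an MPS device
  // Generate complex inputs on CPU, then move them to MPS for the multiply.
  auto a_cpu = at::randn({8}, at::dtype(at::kComplexFloat));
  auto b_cpu = at::randn({8}, at::dtype(at::kComplexFloat));
  auto prod = (a_cpu.to(at::kMPS) * b_cpu.to(at::kMPS)).cpu();
  auto ref = a_cpu * b_cpu;  // CPU reference result
  TORCH_CHECK(at::allclose(prod, ref), "MPS and CPU complex products differ");
  return 0;
}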