mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[MPS][BE] Migrate complex_mul to tensor iterator (#149728)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/149728
Approved by: https://github.com/dcci
ghstack dependencies: #149727
This commit is contained in:
parent
e35ef61066
commit
64d22b9fad
|
|
@ -83,6 +83,7 @@ struct nextafter_functor {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Complex binary functors
|
||||||
struct polar_functor {
|
struct polar_functor {
|
||||||
template <typename U>
|
template <typename U>
|
||||||
using ret_type = c10::metal::vec2type_t<U>;
|
using ret_type = c10::metal::vec2type_t<U>;
|
||||||
|
|
@ -102,6 +103,13 @@ struct make_complex_functor {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct complex_mul_functor {
|
||||||
|
template <typename T>
|
||||||
|
inline T operator()(const T a, const T b) {
|
||||||
|
return T(a.x * b.x - a.y * b.y, a.x * b.y + a.y * b.x);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
REGISTER_BINARY_INDEXING_OP(copysign, long);
|
REGISTER_BINARY_INDEXING_OP(copysign, long);
|
||||||
REGISTER_BINARY_INDEXING_OP(copysign, int);
|
REGISTER_BINARY_INDEXING_OP(copysign, int);
|
||||||
REGISTER_BINARY_INDEXING_OP(copysign, float);
|
REGISTER_BINARY_INDEXING_OP(copysign, float);
|
||||||
|
|
@ -135,28 +143,5 @@ REGISTER_BINARY_INDEXING_OP(polar, float);
|
||||||
REGISTER_BINARY_INDEXING_OP(polar, half);
|
REGISTER_BINARY_INDEXING_OP(polar, half);
|
||||||
REGISTER_BINARY_INDEXING_OP(make_complex, float);
|
REGISTER_BINARY_INDEXING_OP(make_complex, float);
|
||||||
REGISTER_BINARY_INDEXING_OP(make_complex, half);
|
REGISTER_BINARY_INDEXING_OP(make_complex, half);
|
||||||
|
REGISTER_BINARY_INDEXING_OP(complex_mul, float2);
|
||||||
template <typename T>
|
REGISTER_BINARY_INDEXING_OP(complex_mul, half2);
|
||||||
kernel void complex_mul(
|
|
||||||
constant void* input_ [[buffer(0)]],
|
|
||||||
constant void* other_ [[buffer(1)]],
|
|
||||||
device void* out_ [[buffer(2)]],
|
|
||||||
constant uint3* offsets [[buffer(3)]],
|
|
||||||
uint tid [[thread_position_in_grid]]) {
|
|
||||||
device T* out = (device T*)((device uint8_t*)out_ + offsets[tid].x);
|
|
||||||
constant T* input = (constant T*)((constant uint8_t*)input_ + offsets[tid].y);
|
|
||||||
constant T* other = (constant T*)((constant uint8_t*)other_ + offsets[tid].z);
|
|
||||||
out[0] = input[0] * other[0] - input[1] * other[1];
|
|
||||||
out[1] = input[0] * other[1] + input[1] * other[0];
|
|
||||||
}
|
|
||||||
|
|
||||||
#define REGISTER_BINARY_OP(NAME, DTYPE) \
|
|
||||||
template [[host_name(#NAME "_" #DTYPE)]] kernel void NAME<DTYPE>( \
|
|
||||||
constant void* input_, \
|
|
||||||
constant void* other_, \
|
|
||||||
device void* out_, \
|
|
||||||
constant uint3* offsets, \
|
|
||||||
uint tid)
|
|
||||||
|
|
||||||
REGISTER_BINARY_OP(complex_mul, float);
|
|
||||||
REGISTER_BINARY_OP(complex_mul, half);
|
|
||||||
|
|
|
||||||
|
|
@ -42,13 +42,11 @@ void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& out
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
auto common_dtype = output.scalar_type();
|
auto common_dtype = output.scalar_type();
|
||||||
auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
|
auto input_cast = input.to(kMPS, common_dtype);
|
||||||
auto input_as_real = at::view_as_real(input.to(kMPS, common_dtype)).select(input.dim(), 0);
|
auto other_cast = other.to(kMPS, common_dtype);
|
||||||
auto other_as_real = at::view_as_real(other.to(kMPS, common_dtype)).select(other.dim(), 0);
|
auto iter = TensorIteratorConfig().add_output(output).add_input(input_cast).add_input(other_cast).build();
|
||||||
auto iter =
|
|
||||||
TensorIteratorConfig().add_output(output_as_real).add_input(input_as_real).add_input(other_as_real).build();
|
|
||||||
|
|
||||||
lib.exec_binary_kernel(iter, "complex_mul", /*supports_dense=*/false);
|
lib.exec_binary_kernel(iter, "complex_mul");
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace mps
|
} // namespace mps
|
||||||
|
|
@ -115,8 +113,7 @@ Tensor& complex_out_mps(const Tensor& real, const Tensor& imag, Tensor& output)
|
||||||
if (!output.sizes().equals(new_size)) {
|
if (!output.sizes().equals(new_size)) {
|
||||||
output.resize_(new_size);
|
output.resize_(new_size);
|
||||||
}
|
}
|
||||||
uint32_t length = output.numel();
|
if (output.numel() == 0) {
|
||||||
if (length == 0) {
|
|
||||||
return output;
|
return output;
|
||||||
}
|
}
|
||||||
auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
|
auto output_as_real = at::view_as_real(output).select(output.dim(), 0);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user