#include <torch/extension.h>
#include <c10/xpu/XPUStream.h>
#include <sycl/sycl.hpp>

// Element-wise kernel: writes sigmoid(x[i]) + sigmoid(y[i]) into output[i].
// sycl::native::exp trades a little accuracy for speed.
void sigmoid_add_kernel(const float* x,
                        const float* y,
                        float* output,
                        const int size,
                        const sycl::nd_item<3>& item_ct1) {
  // Global linear index: work-group id * work-group size + local id.
  const int index = item_ct1.get_group(2) * item_ct1.get_local_range(2) +
                    item_ct1.get_local_id(2);
  if (index < size) {
    const float sigmoid_x = 1.0f / (1.0f + sycl::native::exp(-x[index]));
    const float sigmoid_y = 1.0f / (1.0f + sycl::native::exp(-y[index]));
    output[index] = sigmoid_x + sigmoid_y;
  }
}

// Functor wrapper so the kernel can be passed to sycl::handler::parallel_for.
class SigmoidAddKernel {
 public:
  void operator()(const sycl::nd_item<3>& item_ct1) const {
    sigmoid_add_kernel(x, y, output, size, item_ct1);
  }

  SigmoidAddKernel(const float* _x, const float* _y, float* _output, int _size)
      : x(_x), y(_y), output(_output), size(_size) {}

 private:
  const float* x;
  const float* y;
  float* output;
  int size;
};

// Launches the kernel on the SYCL queue backing the current XPU stream.
void sigmoid_add_xpu(const float* x, const float* y, float* output, int size) {
  SigmoidAddKernel krn(x, y, output, size);
  const int threads = 1024;
  // Round up so every element is covered by a work-item.
  const int blocks = (size + threads - 1) / threads;

  sycl::queue& queue = c10::xpu::getCurrentXPUStream().queue();
  queue.submit([&](sycl::handler& cgh) {
    cgh.parallel_for(
        sycl::nd_range<3>(
            sycl::range<3>(1, 1, blocks) * sycl::range<3>(1, 1, threads),
            sycl::range<3>(1, 1, threads)),
        krn);
  });
}

// Python-facing entry point: validates devices, allocates the output tensor,
// and dispatches to the SYCL launcher. Assumes float32 inputs, since the raw
// pointers are taken as data_ptr<float>().
torch::Tensor sigmoid_add(torch::Tensor x, torch::Tensor y) {
  TORCH_CHECK(x.device().is_xpu(), "x must be an XPU tensor");
  TORCH_CHECK(y.device().is_xpu(), "y must be an XPU tensor");
  auto output = torch::zeros_like(x);
  sigmoid_add_xpu(
      x.data_ptr<float>(),
      y.data_ptr<float>(),
      output.data_ptr<float>(),
      output.numel());
  return output;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("sigmoid_add", &sigmoid_add, "sigmoid(x) + sigmoid(y)");
}
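For reference, a minimal sketch of exercising the extension from Python, assuming it has already been compiled and is importable as sigmoid_add_ext (a hypothetical module name; the actual name is whatever TORCH_EXTENSION_NAME resolves to at build time). The kernel indexes memory linearly and reads float pointers, so the sketch uses contiguous float32 tensors on an XPU device.

import torch
import sigmoid_add_ext  # hypothetical name for the built extension module

x = torch.randn(2048, device="xpu", dtype=torch.float32)
y = torch.randn(2048, device="xpu", dtype=torch.float32)

out = sigmoid_add_ext.sigmoid_add(x, y)

# Cross-check against eager PyTorch; a slightly loose tolerance because
# sycl::native::exp is a reduced-precision exponential.
expected = torch.sigmoid(x) + torch.sigmoid(y)
assert torch.allclose(out, expected, atol=1e-4)

Note that sigmoid_add only checks the device, not the dtype or layout; calling it with non-float32 tensors raises an error from data_ptr<float>(), and non-contiguous inputs would be read incorrectly by the linear indexing.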