#define TORCH_ASSERT_ONLY_METHOD_OPERATORS #include #include #include #ifndef AT_PER_OPERATOR_HEADERS #include #include #else #include #include #include #include #include #include #include #endif namespace at::native { // NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor _s_poisson_cuda(const Tensor& lambda, std::optional gen_) { auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); Tensor ret = at::empty(lambda.sizes(), lambda.options()); launch_poisson_cuda_kernel(ret, lambda, gen); return ret; } // NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor _s_binomial_cuda(const Tensor& count, const Tensor& prob, std::optional gen_) { auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); Tensor ret = at::empty(count.sizes(), count.options()); at::TensorIterator iter = at::TensorIteratorConfig() .add_output(ret) .add_input(count) .add_input(prob) .build(); launch_binomial_cuda_kernel(iter, gen); return ret; } // NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor _s_gamma_cuda(const Tensor& alpha, std::optional gen_) { auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); Tensor ret = at::empty(alpha.sizes(), alpha.options()); launch_gamma_kernel(ret, alpha, gen); return ret; } // NOLINTNEXTLINE(performance-unnecessary-value-param) Tensor _s_dirichlet_cuda(const Tensor& alpha, std::optional gen_) { auto gen = get_generator_or_default(gen_, cuda::detail::getDefaultCUDAGenerator()); Tensor ret = at::empty(alpha.sizes(), alpha.options()); launch_gamma_kernel(ret, alpha, gen); auto gamma_sum = ret.sum(/*dim=*/-1, /*keepdim=*/true); at::TensorIterator iter = at::TensorIteratorConfig() .add_output(ret) .add_input(ret) .add_input(gamma_sum) .build(); launch_dirichlet_kernel(iter); return ret; } Tensor _standard_gamma_grad_cuda(const Tensor& self, const Tensor& output) { Tensor ret = at::empty(self.sizes(), self.options()); TensorIterator iter = at::TensorIteratorConfig() .add_output(ret) .add_input(self) .add_input(output) .build(); launch_standard_gamma_grad_kernel(iter); return ret; } Tensor _dirichlet_grad_cuda(const Tensor& x, const Tensor& alpha, const Tensor& total) { Tensor ret = at::empty(x.sizes(), x.options()); TensorIterator iter = at::TensorIteratorConfig() .add_output(ret) .add_input(x) .add_input(alpha) .add_input(total) .build(); launch_dirichlet_grad_kernel(iter); return ret; } } // namespace at::native