mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
DO NOT MERGE: debugging ROCm failure
ghstack-source-id: 9391899c34
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34091
This commit is contained in:
parent
b6b2ef20b3
commit
6f6ee5cdb1
|
|
@ -41,7 +41,7 @@ void bernoulli_tensor_cuda_kernel(
|
||||||
|
|
||||||
at::native::gpu_kernel(iter,
|
at::native::gpu_kernel(iter,
|
||||||
[seeds] GPU_LAMBDA (prob_t p) -> scalar_t {
|
[seeds] GPU_LAMBDA (prob_t p) -> scalar_t {
|
||||||
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
|
||||||
curandStatePhilox4_32_10_t state;
|
curandStatePhilox4_32_10_t state;
|
||||||
curand_init(
|
curand_init(
|
||||||
seeds.first,
|
seeds.first,
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ void gamma_cuda_kernel(
|
||||||
|
|
||||||
at::native::gpu_kernel(iter,
|
at::native::gpu_kernel(iter,
|
||||||
[seeds] GPU_LAMBDA (scalar_t alpha) {
|
[seeds] GPU_LAMBDA (scalar_t alpha) {
|
||||||
#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
|
#if defined(__CUDA_ARCH__) || defined(__HIP_PLATFORM_HCC__)
|
||||||
curandStatePhilox4_32_10_t state;
|
curandStatePhilox4_32_10_t state;
|
||||||
curand_init(
|
curand_init(
|
||||||
seeds.first,
|
seeds.first,
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user