mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Enabling the nccl/rccl test for ROCM environment (#32340)
Summary: Enabling the RCCL test on ROCm by adding a temporary grace period to clean up.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/32340
Differential Revision: D19744459
Pulled By: xw285cornell
fbshipit-source-id: 1af3b64113a67f93e622d010ddd3020e5d6c8bc8
This commit is contained in:
parent
e8581869f2
commit
908b451efb
|
|
@ -85,7 +85,7 @@ fi
|
|||
EXTRA_TESTS=()
|
||||
|
||||
# CUDA builds always include NCCL support
|
||||
if [[ "$BUILD_ENVIRONMENT" == *-cuda* ]]; then
|
||||
if [[ "$BUILD_ENVIRONMENT" == *-cuda* ]] || [[ "$BUILD_ENVIRONMENT" == *-rocm* ]]; then
|
||||
EXTRA_TESTS+=("$caffe2_pypath/contrib/nccl")
|
||||
fi
|
||||
|
||||
|
|
|
|||
|
|
@ -28,8 +28,13 @@ class NCCLContext {
|
|||
// get stream priorities
|
||||
int lo_pri, hi_pri;
|
||||
CUDA_ENFORCE(cudaDeviceGetStreamPriorityRange(&lo_pri, &hi_pri));
|
||||
#ifndef __HIP_PLATFORM_HCC__
|
||||
CUDA_ENFORCE(cudaStreamCreateWithPriority(
|
||||
&streams_[i], cudaStreamNonBlocking, hi_pri));
|
||||
#else
|
||||
CUDA_ENFORCE(cudaStreamCreateWithFlags(
|
||||
&streams_[i], cudaStreamNonBlocking));
|
||||
#endif // __HIP_PLATFORM_HCC__
|
||||
CUDA_ENFORCE(cudaEventCreateWithFlags(
|
||||
&events_[i], cudaEventDefault | cudaEventDisableTiming));
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user