diff --git a/torch/amp/autocast_mode.py b/torch/amp/autocast_mode.py index b5dff9eb6c4..e8af9aca5a5 100644 --- a/torch/amp/autocast_mode.py +++ b/torch/amp/autocast_mode.py @@ -260,8 +260,8 @@ class autocast: self._cache_enabled = torch.is_autocast_cache_enabled() if ( enabled - and torch.cuda.amp.common.amp_definitely_not_available() and self.device == "cuda" + and torch.cuda.amp.common.amp_definitely_not_available() ): warnings.warn( "User provided device_type of 'cuda', but CUDA is not available. Disabling"