mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[cuDNN][TF32] Account for TF32 in test_super_resolution_cuda (#161662)
cuDNN seems to be dispatching to TF32 kernels on B200 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161662 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
196232bb93
commit
2e77a08b95
|
|
@ -7,6 +7,7 @@ import unittest
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.testing._internal.common_cuda import tf32_on_and_off
|
||||||
from torch.testing._internal.common_utils import (
|
from torch.testing._internal.common_utils import (
|
||||||
enable_profiling_mode_for_profiling_tests,
|
enable_profiling_mode_for_profiling_tests,
|
||||||
GRAPH_EXECUTOR,
|
GRAPH_EXECUTOR,
|
||||||
|
|
@ -482,6 +483,7 @@ class TestModels(JitTestCase):
|
||||||
self._test_super_resolution(self, device="cpu")
|
self._test_super_resolution(self, device="cpu")
|
||||||
|
|
||||||
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
@unittest.skipIf(not RUN_CUDA, "no CUDA")
|
||||||
|
@tf32_on_and_off(0.02)
|
||||||
def test_super_resolution_cuda(self):
|
def test_super_resolution_cuda(self):
|
||||||
# XXX: export_import on CUDA modules doesn't work (#11480)
|
# XXX: export_import on CUDA modules doesn't work (#11480)
|
||||||
self._test_super_resolution(self, device="cuda", check_export_import=False)
|
self._test_super_resolution(self, device="cuda", check_export_import=False)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user