mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[user-cuda-streams] Add cuda streams test suite (#162901)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/162901 Approved by: https://github.com/williamwen42 ghstack dependencies: #162903, #164343, #164344, #164507
This commit is contained in:
parent
e8d887ae3f
commit
23669d02a6
35
test/dynamo/test_streams.py
Normal file
35
test/dynamo/test_streams.py
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
# Owner(s): ["module: dynamo"]
|
||||
|
||||
import torch
|
||||
import torch._dynamo.test_case
|
||||
import torch._dynamo.testing
|
||||
from torch.testing._internal.common_utils import requires_cuda
|
||||
|
||||
|
||||
class TestStreams(torch._dynamo.test_case.TestCase):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
super().setUpClass()
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
super().tearDownClass()
|
||||
|
||||
@requires_cuda
|
||||
def test_run_opcheck(self):
|
||||
from torch._dynamo.variables.streams import fork_stream, join_stream
|
||||
from torch.library import opcheck
|
||||
|
||||
sample_inputs = [
|
||||
(0, torch.device("cuda:0"), 1, torch.device("cuda:1")),
|
||||
(2, torch.device("cuda:2"), 3, torch.device("cuda:1")),
|
||||
]
|
||||
for args in sample_inputs:
|
||||
opcheck(fork_stream, args)
|
||||
opcheck(join_stream, args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
||||
run_tests()
|
||||
Loading…
Reference in New Issue
Block a user