Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164522 Approved by: https://github.com/williamwen42 ghstack dependencies: #162903, #164343, #164344, #164507, #162901, #164304
45 lines · 1.1 KiB · Python
# Owner(s): ["module: dynamo"]
|
|
import weakref
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
from torch.testing._internal.common_utils import requires_cuda
|
|
|
|
|
|
class TestStreams(torch._dynamo.test_case.TestCase):
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
super().setUpClass()
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
super().tearDownClass()
|
|
|
|
def test_stream_weakref(self):
|
|
s = torch.Stream()
|
|
weakref.ref(s)
|
|
|
|
def test_event_weakref(self):
|
|
e = torch.Event()
|
|
weakref.ref(e)
|
|
|
|
@requires_cuda
|
|
def test_run_opcheck(self):
|
|
from torch._dynamo.variables.streams import fork_stream, join_stream
|
|
from torch.library import opcheck
|
|
|
|
sample_inputs = [
|
|
(0, torch.device("cuda:0"), 1, torch.device("cuda:1")),
|
|
(2, torch.device("cuda:2"), 3, torch.device("cuda:1")),
|
|
]
|
|
for args in sample_inputs:
|
|
opcheck(fork_stream, args)
|
|
opcheck(join_stream, args)
|
|
|
|
|
|
if __name__ == "__main__":
    # Direct execution: hand control to the dynamo test-case runner.
    from torch._dynamo.test_case import run_tests as _run_dynamo_tests

    _run_dynamo_tests()
|