[Easy] Fix the function signature of torch.Event (#151221)

As the title states.

The declaration and the implementation of torch.Event differ.
Declaration:
d5a19e4525/torch/_C/__init__.pyi.in (L157-L162)

Implementation:
d5a19e4525/torch/csrc/Event.cpp (L30-L32)

**Question**: Which one should we choose?
- Change enable_timing to False to be consistent with torch.cuda.Event
- Change enable_timing to True to avoid BC-break
Pull Request resolved: https://github.com/pytorch/pytorch/pull/151221
Approved by: https://github.com/albanD
ghstack dependencies: #151226
This commit is contained in:
FFFrog 2025-04-19 11:18:36 +08:00 committed by PyTorch MergeBot
parent 8e5fefedf4
commit 92baeecbdd
3 changed files with 24 additions and 2 deletions

View File

@ -129,6 +129,17 @@ class TestAccelerator(TestCase):
):
event1.elapsed_time(event2)
# check default value of enable_timing: False
event1 = torch.Event()
event2 = torch.Event()
event1.record()
event2.record()
with self.assertRaisesRegex(
ValueError,
"Both events must be created with argument 'enable_timing=True'",
):
event1.elapsed_time(event2)
if __name__ == "__main__":
run_tests()

View File

@ -985,6 +985,17 @@ class TestCuda(TestCase):
):
event1.elapsed_time(event2)
# check default value of enable_timing: False
event1 = torch.cuda.Event()
event2 = torch.cuda.Event()
event1.record()
event2.record()
with self.assertRaisesRegex(
ValueError,
"Both events must be created with argument 'enable_timing=True'",
):
event1.elapsed_time(event2)
def test_generic_stream_event(self):
stream = torch.Stream("cuda")
self.assertEqual(stream.device_index, torch.cuda.current_device())

View File

@ -28,7 +28,7 @@ static PyObject* THPEvent_pynew(
unsigned char interprocess = 0;
static torch::PythonArgParser parser({
"Event(Device device=None, *, bool enable_timing=True, bool blocking=False, bool interprocess=False)",
"Event(Device device=None, *, bool enable_timing=False, bool blocking=False, bool interprocess=False)",
});
torch::ParsedArgs<4> parsed_args;
@ -39,7 +39,7 @@ static PyObject* THPEvent_pynew(
if (!device.has_value()) {
device = at::Device(at::getAccelerator(false).value_or(at::kCPU));
}
enable_timing = r.toBoolWithDefault(1, true);
enable_timing = r.toBoolWithDefault(1, false);
blocking = r.toBoolWithDefault(2, false);
interprocess = r.toBoolWithDefault(3, false);