Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Revert "[dynamo][DebugMode] mask python keys in dispatch_key_set guard checks (#164992)"
This reverts commit 306b344a18. Reverted https://github.com/pytorch/pytorch/pull/164992 on behalf of https://github.com/jeffdaily because it broke ROCm CI: test/inductor/test_inductor_scheduler.py::TestSchedulerCUDA::test_flop_counter_op_options0_cuda_float32 [GH job link](https://github.com/pytorch/pytorch/actions/runs/18417066364/job/52485636942) [HUD commit link](306b344a18) ([comment](https://github.com/pytorch/pytorch/pull/164992#issuecomment-3397927142))
This commit is contained in:
parent 4874cce52f
commit 8580112682
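For context, here is a minimal sketch of the behavior the reverted test exercised: compile a function, call it once under DebugMode and once outside it, and check that the mode does not force a recompile. This is not the commit's code; the DebugMode import path (torch.utils._debug_mode) and the final frame count are assumptions based on the test removed in the diff below.

```python
# Minimal sketch, not part of this commit. Assumes DebugMode lives in
# torch.utils._debug_mode, as in the PyTorch test suite.
import torch
from torch._dynamo.testing import CompileCounterWithBackend
from torch.utils._debug_mode import DebugMode

cnt = CompileCounterWithBackend("inductor")  # counts how many frames get compiled

@torch.compile(backend=cnt)
def f(x):
    return x.sin().cos()

x = torch.randn(8)
with DebugMode():
    f(x)  # first call: compiles once while the mode is active
f(x)      # second call, outside the mode

# With the masking from #164992 in place this stays at 1; after the revert the
# dispatch_key_set guard may see different keys and trigger a recompile.
print(cnt.frame_count)
```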
@@ -4,7 +4,6 @@ import contextlib
 
 import torch
 import torch.distributed as dist
-from torch._dynamo.testing import CompileCounterWithBackend
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.distributed.tensor import DeviceMesh, DTensor, Partial, Replicate, Shard
 from torch.distributed.tensor._dtensor_spec import ShardOrderEntry
@@ -322,21 +321,14 @@ class TestDTensorDebugMode(TestCase):
         self.assertIn("torch.ops.higher_order.cond", debug_mode.debug_string())
 
     def test_compile(self):
-        cnt = CompileCounterWithBackend("inductor")
-
-        @torch.compile(backend=cnt)
+        @torch.compile
         def f(x):
             return x.sin().cos()
 
         x = torch.randn(8)
         with DebugMode() as debug_mode:
             f(x)
         self.assertEqual(len(debug_mode.debug_string()), 0)
-        f(x)
-        f(x)
-        self.assertEqual(
-            cnt.frame_count, 1
-        )  # check DebugMode doesn't trigger additional recompilations
 
 
 instantiate_parametrized_tests(TestDTensorDebugMode)
@@ -21,10 +21,7 @@ struct LocalState {
 
   at::DispatchKeySet apply(at::DispatchKeySet ks) const {
     if (override_dispatch_key_set.empty()) {
-      return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_ -
-          c10::DispatchKeySet(
-              {c10::DispatchKey::Python,
-               c10::DispatchKey::PythonTLSSnapshot});
+      return (ks | dispatch_modifier.included_) - dispatch_modifier.excluded_;
     } else {
       return override_dispatch_key_set;
     }
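The C++ hunk above is where the revert takes effect: the reverted change subtracted the Python and PythonTLSSnapshot keys from the key set that dynamo's dispatch_key_set guard compares, so entering or leaving a TorchDispatchMode such as DebugMode would not change the guarded value. A small sketch of why those keys matter, assuming the private helper torch._C._dispatch_tls_local_include_set() (not part of this commit, and subject to change):

```python
# Minimal sketch, not part of this commit. torch._C._dispatch_tls_local_include_set()
# is a private helper and an assumption here; it reports the thread-local included
# dispatch keys that LocalState captures as dispatch_modifier.
import torch
from torch.utils._debug_mode import DebugMode

print(torch._C._dispatch_tls_local_include_set())      # typically empty outside any mode
with DebugMode():
    # While a TorchDispatchMode is on the stack, Python and PythonTLSSnapshot are
    # included here, which is exactly what the reverted masking filtered out.
    print(torch._C._dispatch_tls_local_include_set())
```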