Revert "[dynamo] Constant fold torch.autograd._profiler_enabled (#158482)"
This reverts commit d7e1b8b11d.
Reverted https://github.com/pytorch/pytorch/pull/158482 on behalf of https://github.com/borgstrom due to NCCL hangs in S560336 ([comment](https://github.com/pytorch/pytorch/pull/158482#issuecomment-3268426781))
This commit is contained in:
parent 897c4e70a7
commit ed77e23b68
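For context, the change being reverted allowed Dynamo to constant-fold torch.autograd._profiler_enabled(), so a function that branches on the profiler state could be compiled with fullgraph=True; that is exactly what the tests removed in the first hunk below exercised. A minimal sketch of the pattern, assuming (based on the SkipFunctionVariable mapping restored below) that after the revert the call is no longer traced and instead causes a graph break:

import torch

def fn(x):
    x = torch.sin(x)
    # The reverted PR constant-folded this call, letting Dynamo specialize
    # the branch on the profiler state observed at trace time.
    if torch.autograd._profiler_enabled():
        return torch.cos(x)
    return torch.sigmoid(x)

# With the constant fold this compiled under fullgraph=True (see the removed
# tests below); after the revert, fullgraph=False lets Dynamo fall back to a
# graph break at the _profiler_enabled() call (assumed post-revert behavior).
opt_fn = torch.compile(fn, backend="eager", fullgraph=False)

print(opt_fn(torch.randn(4)))        # profiler off
with torch.autograd.profiler.profile():
    print(opt_fn(torch.randn(4)))    # profiler on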
@@ -192,47 +192,6 @@ class DynamoProfilerTests(torch._dynamo.test_case.TestCase):
             ],
         )
 
-    def test_profiler_enabled(self):
-        def fn(x):
-            x = torch.sin(x)
-            if torch.autograd._profiler_enabled():
-                return torch.cos(x)
-            else:
-                return torch.sigmoid(x)
-
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        x = torch.randn(4)
-
-        ref = fn(x)
-        res = opt_fn(x)
-        self.assertEqual(ref, res)
-
-        with torch.autograd.profiler.profile():
-            ref = fn(x)
-            res = opt_fn(x)
-            self.assertEqual(ref, res)
-
-    def test_profiler_record_function_ignore(self):
-        def fn(x):
-            x = torch.sin(x)
-            if torch.autograd._profiler_enabled():
-                with torch.autograd.profiler.record_function("dummy"):
-                    return torch.cos(x)
-            else:
-                return torch.sigmoid(x)
-
-        opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
-        x = torch.randn(4)
-
-        ref = fn(x)
-        res = opt_fn(x)
-        self.assertEqual(ref, res)
-
-        with torch.autograd.profiler.profile():
-            ref = fn(x)
-            res = opt_fn(x)
-            self.assertEqual(ref, res)
-
-
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
@@ -177,6 +177,7 @@ manual_torch_name_rule_map: dict[
     "torch.compiler.is_compiling": TorchInGraphFunctionVariable,
     "torch.compiler.is_dynamo_compiling": TorchInGraphFunctionVariable,
     "torch.compiler.is_exporting": TorchInGraphFunctionVariable,
+    "torch.autograd._profiler_enabled": SkipFunctionVariable,
     "torch._C._to_dlpack": SkipFunctionVariable,
     "torch.to_dlpack": SkipFunctionVariable,
     # We graph break on RNG state setters or getters like
@@ -2440,7 +2441,6 @@ torch_non_c_binding_in_graph_functions = dict.fromkeys(
         "torch.atleast_3d",
         "torch.autograd._calculate_shape",
         "torch.autograd._is_checkpoint_valid",
-        "torch.autograd._profiler_enabled",
         "torch.autograd._make_grads",
         "torch.autograd._register_py_tensor_class_for_device",
         "torch.autograd._tensor_or_tensors_to_tuple",
@@ -149,7 +149,6 @@ constant_fold_functions_need_guards = [
     torch.cuda.is_initialized,
     torch.xpu.current_device,
     torch.xpu.is_initialized,
-    torch.autograd._profiler_enabled,
 ]
 
 constant_fold_functions = [
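For reference, entries in constant_fold_functions_need_guards are calls whose current return value Dynamo folds into the compiled code while also installing a guard on that value, so the code is recompiled when the value changes; removing torch.autograd._profiler_enabled from the list undoes that treatment. The toy sketch below only illustrates the fold-plus-guard idea under that assumption; the helper names are made up and it is not Dynamo's implementation.

# Toy illustration of "constant fold + guard" (hypothetical helpers, not Dynamo internals).
import torch

def fold_with_guard(make_specialized, flag_fn):
    folded = flag_fn()                       # value captured at "compile" time
    specialized = make_specialized(folded)   # code specialized on that value

    def guarded(x):
        if flag_fn() != folded:              # guard check: the value changed
            return fold_with_guard(make_specialized, flag_fn)(x)  # "recompile"
        return specialized(x)

    return guarded

# Usage: specialize on torch.autograd._profiler_enabled() by hand, roughly the
# effect the reverted change had inside the compiler.
fn = fold_with_guard(
    lambda enabled: (lambda x: torch.cos(x) if enabled else torch.sigmoid(x)),
    torch.autograd._profiler_enabled,
)
print(fn(torch.randn(4)))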