diff --git a/test/profiler/test_cpp_thread.py b/test/profiler/test_cpp_thread.py index 5dd12277e18..1e7acc155ec 100644 --- a/test/profiler/test_cpp_thread.py +++ b/test/profiler/test_cpp_thread.py @@ -7,6 +7,7 @@ from unittest import skipIf import torch import torch.utils.cpp_extension +from torch._environment import is_fbcode from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase @@ -21,10 +22,6 @@ def remove_build_path(): shutil.rmtree(default_build_root) -def is_fbcode(): - return not hasattr(torch.version, "git_version") - - if is_fbcode(): import caffe2.test.profiler_test_cpp_thread_lib as cpp # @manual=//caffe2/test:profiler_test_cpp_thread_lib else: diff --git a/test/quantization/pt2e/test_numeric_debugger.py b/test/quantization/pt2e/test_numeric_debugger.py index bed1567d2b6..364345a5c5a 100644 --- a/test/quantization/pt2e/test_numeric_debugger.py +++ b/test/quantization/pt2e/test_numeric_debugger.py @@ -6,6 +6,7 @@ from collections import Counter from typing import Dict import torch +from torch._environment import is_fbcode from torch._export import capture_pre_autograd_graph from torch.ao.quantization import ( compare_results, @@ -40,10 +41,6 @@ def _extract_debug_handles(model) -> Dict[str, int]: return debug_handle_map -def is_fbcode(): - return not hasattr(torch.version, "git_version") - - @unittest.skipIf(IS_WINDOWS, "Windows not yet supported for torch.compile") class TestNumericDebugger(TestCase): def test_simple(self): diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index 6e0375fa0fa..d3b4becf36e 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -9,10 +9,7 @@ from os.path import abspath, dirname from typing import Any, Callable, Dict, Optional, Set, Type, TYPE_CHECKING, Union import torch - - -def is_fbcode(): - return not hasattr(torch.version, "git_version") +from torch._environment import is_fbcode # to configure logging for dynamo, aot, and inductor diff --git a/torch/_environment.py b/torch/_environment.py new file mode 100644 index 00000000000..65cbd5d35ad --- /dev/null +++ b/torch/_environment.py @@ -0,0 +1,2 @@ +def is_fbcode() -> bool: + return False diff --git a/torch/_inductor/config.py b/torch/_inductor/config.py index aaa8d9bf6b1..455d050e907 100644 --- a/torch/_inductor/config.py +++ b/torch/_inductor/config.py @@ -3,10 +3,7 @@ import sys from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING, Union import torch - - -def is_fbcode() -> bool: - return not hasattr(torch.version, "git_version") +from torch._environment import is_fbcode def _get_tristate_env(name: str) -> Optional[bool]: