Revert "Add inductor standalone_compile API (#150670)"
This reverts commit c9aef50898.
Reverted https://github.com/pytorch/pytorch/pull/150670 on behalf of https://github.com/Camyll due to breaking internal builds with torch module not found error ([comment](https://github.com/pytorch/pytorch/pull/150670#issuecomment-2806975267))
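For context, here is the API surface this commit removes, sketched from the docstring and tests in the diff below. This is hypothetical usage: it only runs on a build that still contains #150670.

import torch

def f(x):
    return x.sin() * 2

# Capture an FX GraphModule the way the reverted tests do: a custom dynamo
# backend that records the graph and its flattened example inputs.
captured = {}

def backend(gm, example_inputs, **kwargs):
    captured["gm"], captured["inputs"] = gm, example_inputs
    return gm

torch.compile(f, fullgraph=True, backend=backend)(torch.ones(4, 1))

# Compile once, save the artifact, and reload it (e.g. in another process).
artifact = torch._inductor.standalone_compile(captured["gm"], captured["inputs"])
artifact.save(path="compiled_artifact.bin", format="binary")
loaded = torch._inductor.CompiledArtifact.load(
    path="compiled_artifact.bin", format="binary"
)
print(loaded(*captured["inputs"]))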
This commit is contained in:
parent c0a0761871
commit 74f6bc28a7
@@ -3,8 +3,6 @@ import functools
import os
import pickle
import shutil
import subprocess
import sys
import tempfile
import unittest
from typing import Optional, Union
@@ -13,7 +11,6 @@ from unittest import mock
import torch
from torch._dynamo import reset
from torch._dynamo.utils import counters
from torch._functorch import config as functorch_config
from torch._functorch._aot_autograd.autograd_cache import AOTAutogradCache
from torch._inductor import config, metrics
from torch._inductor.codecache import (
@@ -1379,131 +1376,6 @@ class TestFxGraphCache(TestCase):
        )


@instantiate_parametrized_tests
class TestStandaloneCompile(TestCase):
    def setUp(self):
        super().setUp()
        counters.clear()
        PatchCaches.setUp()
        CacheArtifactManager.clear()

    def tearDown(self):
        super().tearDown()
        PatchCaches.tearDown()

    def reset(self):
        AOTAutogradCache.clear()
        PyCodeCache.cache_clear(purge=True)
        torch._dynamo.reset()
        clear_inductor_caches()

    def capture(self, fn):
        def inner(*args):
            gm = None
            actual_args = None
            kwargs = None

            def backend(gm_, args_, **kwargs_):
                nonlocal gm
                nonlocal actual_args
                nonlocal kwargs
                gm = gm_
                actual_args = args_
                kwargs = kwargs_
                return gm

            _ = torch.compile(fn, fullgraph=True, backend=backend)(*args)
            return gm, actual_args, kwargs

        return inner

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @functorch_config.patch({"enable_autograd_cache": True})
    @parametrize("format", ("binary", "unpacked"))
    @parametrize("dynamic", (False, True))
    def test_basic(self, format: str, dynamic: bool) -> None:
        mod = torch.nn.Linear(1, 3)
        x = torch.randn(4, 1)
        if dynamic:
            torch._dynamo.mark_dynamic(x, 0)

        def f(x):
            with torch.no_grad():
                return mod(x)

        eager_out = f(x)

        with tempfile.TemporaryDirectory() as temp_dir:
            path = (
                temp_dir
                if format == "unpacked"
                else os.path.join(temp_dir, "compiled_artifact.bin")
            )
            with fresh_inductor_cache():
                gm, args, kwargs = self.capture(f)(x)
                assert not kwargs

                compiled_artifact = torch._inductor.standalone_compile(gm, args)
                compiled_artifact.save(path=path, format=format)

            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 0)

            with fresh_inductor_cache():
                loaded = torch._inductor.CompiledArtifact.load(path=path, format=format)
                compiled_out = loaded(*args)
                self.assertEqual(eager_out, compiled_out)

            self.assertEqual(counters["inductor"]["fxgraph_cache_hit"], 1)

    @config.patch({"fx_graph_cache": True})
    @config.patch({"fx_graph_remote_cache": False})
    @functorch_config.patch({"enable_autograd_cache": True})
    def test_different_process(self):
        x = torch.ones(4, 1)

        def f(x):
            return x.sin() * 2

        gm, args, kwargs = self.capture(f)(x)
        assert not kwargs

        with tempfile.TemporaryDirectory() as temp_dir:
            path = os.path.join(temp_dir, "compiled_artifact.bin")

            with fresh_inductor_cache():
                compiled_artifact = torch._inductor.standalone_compile(gm, args)
                compiled_artifact.save(path=path)

            script = f"""
import torch
from torch._inductor.utils import fresh_inductor_cache

arg = torch.ones(4, 1)
with fresh_inductor_cache():
    loaded = torch._inductor.CompiledArtifact.load(path="{path}")
    compiled_result = loaded(arg)

eager_result = arg.sin() * 2

if not torch.allclose(eager_result, compiled_result, atol=0.1, rtol=0.01):
    raise RuntimeError("tensors do not match")
"""
            try:
                subprocess.check_output(
                    [sys.executable, "-c", script],
                    stderr=subprocess.STDOUT,
                    cwd=os.path.dirname(os.path.realpath(__file__)),
                )
            except subprocess.CalledProcessError as e:
                self.fail(
                    msg=(
                        "Subprocess exception while attempting to run test: "
                        + e.output.decode("utf-8")
                    )
                )


class TestFxGraphCacheHashing(TestCase):
    def test_parameter_constants(self):
        """
@@ -488,9 +488,6 @@ class AOTAutogradCacheEntry:
    forward_time_taken_ns: int
    backward_time_taken_ns: int

    # Used by standalone_compile
    sanitized_aot_config: AOTConfig

    # Turn cache entry into the original callable
    def wrap_post_compile(
        self,
@@ -142,27 +142,6 @@ def aot_dispatch_export(
    return compiled_fn, fw_metadata


def sanitize_aot_config(input: AOTConfig) -> AOTConfig:
    return AOTConfig(
        fw_compiler=None,  # type: ignore[arg-type]
        bw_compiler=None,  # type: ignore[arg-type]
        partition_fn=None,  # type: ignore[arg-type]
        decompositions={},
        inference_compiler=None,
        num_params_buffers=input.num_params_buffers,
        aot_id=input.aot_id,
        keep_inference_input_mutations=input.keep_inference_input_mutations,
        is_export=input.is_export,
        no_tangents=input.no_tangents,
        aot_autograd_arg_pos_to_source=input.aot_autograd_arg_pos_to_source,
        dynamic_shapes=input.dynamic_shapes,
        enable_log=input.enable_log,
        static_input_indices=input.static_input_indices,
        pre_dispatch=input.pre_dispatch,
        cache_info=None,
    )


def aot_dispatch_base(
    flat_fn,
    flat_args: list[Any],
@@ -272,7 +251,6 @@ def aot_dispatch_base(
        indices_of_inps_to_detach=[],
        forward_time_taken_ns=time_taken_ns,
        backward_time_taken_ns=0,
        sanitized_aot_config=sanitize_aot_config(aot_config),
    )
    AOTAutogradCache.save(
        cache_info.cache_key, entry, remote=should_use_remote_autograd_cache()
@@ -1327,7 +1305,6 @@ def aot_dispatch_autograd(
        _indices_of_inps_to_detach,
        forward_time_taken_ns,
        backward_time_taken_ns,
        sanitized_aot_config=sanitize_aot_config(aot_config),
    )
    remote = should_use_remote_autograd_cache()
    AOTAutogradCache.save(cache_info.cache_key, entry, remote)
@@ -41,7 +41,7 @@ treat_parameters_as_free_to_save = True
# Applies CSE to the graph before partitioning
cse = True

from torch._environment import is_fbcode
from torch._inductor.config import is_fbcode


enable_autograd_cache: bool = Config(
@@ -9,8 +9,6 @@ from typing import Any, IO, Optional, TYPE_CHECKING, Union
import torch._inductor.config
import torch.fx

from .standalone_compile import CompiledArtifact  # noqa: TC001


if TYPE_CHECKING:
    from torch._inductor.utils import InputType
@@ -22,7 +20,6 @@ __all__ = [
    "list_mode_options",
    "list_options",
    "cudagraph_mark_step_begin",
    "standalone_compile",
]
@@ -361,34 +358,3 @@ def cudagraph_mark_step_begin():
    from .cudagraph_trees import mark_step_begin

    mark_step_begin()


def standalone_compile(
    gm: torch.fx.GraphModule,
    example_inputs: list[InputType],
    options: Optional[dict[str, Any]] = None,
) -> CompiledArtifact:
    """
    Precompilation API for inductor.

    .. code-block:: python

        compiled_artifact = torch._inductor.standalone_compile(gm, args)
        compiled_artifact.save(path=path, format="binary")

        # Later, in a new process
        loaded = torch._inductor.CompiledArtifact.load(path=path, format="binary")
        compiled_out = loaded(*args)

    Args:
        gm: Graph Module
        example_inputs: Inputs for the graph module
        options: Inductor compilation options

    Returns:
        CompiledArtifact that can be saved to disk or invoked directly.
    """
    from .standalone_compile import standalone_compile

    options = options if options else {}
    return standalone_compile(gm, example_inputs, **options)
@@ -1046,20 +1046,16 @@ class FxGraphCache:
            # If there's not a cache hit, we don't want the evaluation to
            # affect the current env, e.g., cause the creation of new guards,
            # so we evaluate with the hints instead of the symbols.
            if config.unsafe_skip_cache_dynamic_shape_guards:
                hit = True
            else:
                hit = bool(
                    shape_env.evaluate_guards_expression(candidate.guards_expr, hints)
                )
                log.debug(
                    "fx graph cache key %s evaluating guards [%s] with values %s => hit=%s",
                    key,
                    candidate.guards_expr,
                    hints,
                    hit,
                )

            hit = bool(
                shape_env.evaluate_guards_expression(candidate.guards_expr, hints)
            )
            log.debug(
                "fx graph cache key %s evaluating guards [%s] with values %s => hit=%s",
                key,
                candidate.guards_expr,
                hints,
                hit,
            )
            if hit:
                graph = candidate
                break
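The first branch above is the code being reverted: it short-circuits guard evaluation whenever unsafe_skip_cache_dynamic_shape_guards is set (the config flag deleted further down in this diff). A minimal sketch of opting in, assuming a build that still contains the flag:

import torch
from torch._inductor import config

# Unsafe by construction: a cached graph may be replayed with shapes its
# dynamic-shape guards would have rejected. The flag exists only on builds
# containing #150670.
with config.patch(unsafe_skip_cache_dynamic_shape_guards=True):
    compiled = torch.compile(lambda x: x * 2)
    compiled(torch.ones(4))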
@@ -1107,7 +1103,7 @@ class FxGraphCache:
        AutotuneCacheBundler.begin_compile(inductor_meta, code=code)

        # Now re-evaluate with the symints to add any guards to the current env.
        if not config.unsafe_skip_cache_dynamic_shape_guards and graph.guards_expr:
        if graph.guards_expr:
            check = bool(
                shape_env.evaluate_guards_expression(graph.guards_expr, symints)
            )
@@ -70,7 +70,7 @@ from torch._inductor.output_code import (
    index_expanded_dims,
    OutputCode,
)
from torch._inductor.runtime.cache_dir_utils import cache_dir
from torch._inductor.runtime.runtime_utils import cache_dir
from torch._inductor.utils import (
    BoxedBool,
    count_tangents,
@@ -117,9 +117,6 @@ bundled_autotune_remote_cache: Optional[bool] = bundled_autotune_remote_cache_de
# Force disabled all inductor level caching -- This will override any other caching flag
force_disable_caches: bool = os.environ.get("TORCHINDUCTOR_FORCE_DISABLE_CACHES") == "1"

# Unsafe way to skip dynamic shape guards to get faster cache load
unsafe_skip_cache_dynamic_shape_guards: bool = False

# sleep in inductor for testing
sleep_sec_TESTING_ONLY: Optional[int] = None
@@ -2,8 +2,6 @@ import getpass
import os
import re
import tempfile
from collections.abc import Generator
from contextlib import contextmanager


# Factoring out to file without torch dependencies
@@ -33,16 +31,3 @@ def triton_cache_dir(device: int) -> str:
        "triton",
        str(device),
    )


@contextmanager
def temporary_cache_dir(directory: str) -> Generator[None, None, None]:
    original = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
    os.environ["TORCHINDUCTOR_CACHE_DIR"] = directory
    try:
        yield
    finally:
        if original is None:
            del os.environ["TORCHINDUCTOR_CACHE_DIR"]
        else:
            os.environ["TORCHINDUCTOR_CACHE_DIR"] = original
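The temporary_cache_dir helper removed above is what the unpacked save/load paths rely on. A small sketch of its contract, assuming a build that still contains it:

import os
import tempfile

from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir

before = os.environ.get("TORCHINDUCTOR_CACHE_DIR")
with tempfile.TemporaryDirectory() as tmp:
    # Inside the block, inductor resolves its cache directory under tmp.
    with temporary_cache_dir(tmp):
        assert os.environ["TORCHINDUCTOR_CACHE_DIR"] == tmp
# On exit the previous value, or its absence, is restored.
assert os.environ.get("TORCHINDUCTOR_CACHE_DIR") == before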
@@ -1,184 +0,0 @@
from __future__ import annotations

import logging
import os
import shutil
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, Literal, Optional, TYPE_CHECKING

import torch.fx
from torch._dynamo.utils import dynamo_timed
from torch._inductor.cudagraph_utils import BoxedDeviceIndex
from torch._inductor.runtime.cache_dir_utils import temporary_cache_dir
from torch._inductor.utils import BoxedBool, InputType
from torch._subclasses import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

from . import config
from .utils import shape_env_from_inputs


if TYPE_CHECKING:
    from collections.abc import Sequence

    from torch.compiler._cache import CacheInfo
    from torch.fx import GraphModule


log = logging.getLogger(__name__)


class CompiledArtifact:
    """
    CompiledArtifact class represents the precompiled inductor artifact that
    can be invoked in order to avoid repeated compilation.

    CompiledArtifact can be obtained by calling standalone_compile(gm, example_inputs)
    to create a fresh CompiledArtifact from a GraphModule and example inputs.

    Later this CompiledArtifact can be saved to disk, either as a binary or unpacked
    into the provided folder via the CompiledArtifact.save function.

    CompiledArtifact.load provides a way to create a CompiledArtifact from the
    binary or unpacked data.

    Finally, the CompiledArtifact can be invoked via the __call__ method
    to execute the precompiled artifact.
    """

    _compiled_fn: Callable[..., Any]
    _artifacts: Optional[tuple[bytes, CacheInfo]]

    def __init__(
        self,
        compiled_fn: Callable[..., Any],
        artifacts: Optional[tuple[bytes, CacheInfo]],
    ):
        self._compiled_fn = compiled_fn
        self._artifacts = artifacts

    def __call__(self, *args: Any) -> Any:
        return self._compiled_fn(*args)[0]

    def save(
        self, *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> None:
        with dynamo_timed("CompiledArtifact.save"):
            if self._artifacts is None:
                raise RuntimeError(
                    "CompiledArtifact.save failed to save since there's no artifact to save"
                )
            artifact_bytes, cache_info = self._artifacts
            assert len(cache_info.aot_autograd_artifacts) == 1
            key = cache_info.aot_autograd_artifacts[0]

            if format == "binary":
                # can't assert that it is a file since it might not exist yet
                assert not os.path.isdir(path)

                from torch.utils._appending_byte_serializer import BytesWriter

                from .codecache import torch_key

                writer = BytesWriter(0)
                writer.write_bytes(torch_key())
                writer.write_str(key)
                writer.write_bytes(artifact_bytes)
                with open(path, "wb") as file:
                    file.write(writer.to_bytes())
            else:
                assert format == "unpacked"
                assert os.path.isdir(path)
                shutil.rmtree(path, ignore_errors=True)

                with temporary_cache_dir(path):
                    # This function unpacks the cache artifacts to disk
                    torch.compiler.load_cache_artifacts(artifact_bytes)

    @staticmethod
    def load(
        *, path: str, format: Literal["binary", "unpacked"] = "binary"
    ) -> CompiledArtifact:
        with dynamo_timed("CompiledArtifact.load"):
            if format == "binary":
                # can't assert that it is a file since it might not exist yet
                assert not os.path.isdir(path)
                with open(path, "rb") as file:
                    artifacts = file.read()
                from torch.utils._appending_byte_serializer import BytesReader

                from .codecache import torch_key

                reader = BytesReader(artifacts)
                assert reader.read_bytes() == torch_key()
                key = reader.read_str()
                artifact_bytes = reader.read_bytes()
                assert reader.is_finished()

                torch.compiler.load_cache_artifacts(artifact_bytes)

                cache_dir_ctx: AbstractContextManager[None] = nullcontext()
            else:
                assert format == "unpacked"
                assert os.path.isdir(path)
                autograd_cache_dir = os.path.join(path, "aotautograd")
                assert os.path.isdir(autograd_cache_dir)
                files = list(os.listdir(autograd_cache_dir))
                assert len(files) == 1
                key = files[0]
                cache_dir_ctx = temporary_cache_dir(path)

            with cache_dir_ctx:
                with torch._functorch.config.patch(strict_autograd_cache=True):
                    from torch._functorch._aot_autograd.autograd_cache import (
                        AOTAutogradCache,
                    )

                    entry = AOTAutogradCache._lookup(key, local=True, remote=False)

                assert entry is not None

                from .compile_fx import _CompileFxKwargs

                fx_config = _CompileFxKwargs(
                    cudagraphs=BoxedBool(False),
                    boxed_forward_device_index=BoxedDeviceIndex(0),
                )

                context = torch._guards.TracingContext(
                    FakeTensorMode(shape_env=ShapeEnv())
                )
                with (
                    torch._guards.tracing(context),
                    config.patch(unsafe_skip_cache_dynamic_shape_guards=True),
                ):
                    compiled_fn = entry.wrap_post_compile(
                        [], entry.sanitized_aot_config, fx_config
                    )
            return CompiledArtifact(lambda *args: compiled_fn(list(args)), None)


def standalone_compile(
    gm: GraphModule, example_inputs: Sequence[InputType], **kwargs: Any
) -> CompiledArtifact:
    from torch.compiler._cache import CacheArtifactManager

    from .compile_fx import compile_fx

    shape_env = shape_env_from_inputs(example_inputs, default=True)
    assert shape_env is not None

    context = torch._guards.TracingContext(FakeTensorMode(shape_env=shape_env))
    with torch._guards.tracing(context):
        with CacheArtifactManager.with_fresh_cache():
            compiled_fn = compile_fx(gm, example_inputs, **kwargs)
            assert callable(compiled_fn)

            artifacts = torch.compiler.save_cache_artifacts()
            if artifacts is None:
                log.warning(
                    "standalone_compile artifact generation failed, cannot save. "
                    "Run with TORCH_LOGS=+torch._inductor.codecache to identify the problem"
                )

    return CompiledArtifact(compiled_fn, artifacts)
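The binary format written by CompiledArtifact.save above is a flat length-prefixed stream: the torch key, then the AOTAutograd cache key, then the artifact bytes. A standalone round-trip sketch using the same serializer, with placeholder values standing in for torch_key() and the real cache payload:

from torch.utils._appending_byte_serializer import BytesReader, BytesWriter

writer = BytesWriter(0)
writer.write_bytes(b"torch-key-placeholder")  # torch_key() in the real code
writer.write_str("aot-autograd-cache-key")    # cache_info.aot_autograd_artifacts[0]
writer.write_bytes(b"artifact-payload")       # serialized cache artifacts
blob = writer.to_bytes()

reader = BytesReader(blob)
assert reader.read_bytes() == b"torch-key-placeholder"
assert reader.read_str() == "aot-autograd-cache-key"
assert reader.read_bytes() == b"artifact-payload"
assert reader.is_finished()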
@@ -50,7 +50,6 @@ import sympy

import torch
from torch._inductor.runtime.hints import DeviceProperties
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map_only
@@ -61,6 +60,7 @@ if TYPE_CHECKING:
    from torch import SymBool, SymFloat, SymInt
    from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND
    from torch.fx import GraphModule
    from torch.fx.experimental.symbolic_shapes import ShapeEnv
    from torch.fx.node import Node

    from .codegen.common import WorkspaceArg
@@ -2412,9 +2412,7 @@ def run_and_get_cpp_code(
    return result, s


def shape_env_from_inputs(
    inputs: Sequence[InputType], default: bool = False
) -> Optional[ShapeEnv]:
def shape_env_from_inputs(inputs: Sequence[InputType]) -> Optional[ShapeEnv]:
    fake_mode = detect_fake_mode(inputs)

    # TODO(voz): It would be nice to enable this assert, but there are lots of tests that
@@ -2430,9 +2428,6 @@ def shape_env_from_inputs(
        if isinstance(input, torch.SymInt):
            return input.node.shape_env

    if default:
        return ShapeEnv()

    # TODO(voz): Should we always have one anyway?
    return None
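The default=True path removed above existed for fully static inputs, which carry neither a fake mode nor any SymInt and therefore yield no ShapeEnv. A short illustration of the surviving behavior:

import torch
from torch._inductor.utils import shape_env_from_inputs

# Plain tensors have no fake mode and no SymInt arguments, so there is no
# ShapeEnv to recover; this is why the reverted standalone_compile passed
# default=True to synthesize an empty ShapeEnv instead.
assert shape_env_from_inputs([torch.ones(4, 1)]) is None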
@@ -2,8 +2,6 @@ import copy
import dataclasses
import logging
import os
from collections.abc import Generator
from contextlib import contextmanager
from enum import Enum
from typing import Optional, Union
@@ -123,26 +121,6 @@ class CacheArtifactManager:
        cls._serializer.clear()
        cls._cache_info.clear()

    @classmethod
    @contextmanager
    def with_fresh_cache(cls) -> Generator[None, None, None]:
        original_new_cache_artifacts = cls._new_cache_artifacts
        original_seen_artifacts = cls._seen_artifacts
        original_serializer = cls._serializer
        original_cache_info = cls._cache_info

        cls._new_cache_artifacts = []
        cls._seen_artifacts = OrderedSet()
        cls._serializer = AppendingByteSerializer(serialize_fn=CacheArtifact.serialize)
        cls._cache_info = CacheInfo()
        try:
            yield
        finally:
            cls._new_cache_artifacts = original_new_cache_artifacts
            cls._seen_artifacts = original_seen_artifacts
            cls._serializer = original_serializer
            cls._cache_info = original_cache_info

    @classmethod
    def record_artifact(
        cls,
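with_fresh_cache, deleted above, swaps the class-level recording state for empty containers and restores it afterwards, so one compilation's artifacts can be captured in isolation. A hedged usage sketch, assuming a build that still contains the method:

import torch
from torch.compiler._cache import CacheArtifactManager

# Record only this compilation's cache artifacts, not whatever the
# process-wide manager accumulated earlier (builds containing #150670 only).
with CacheArtifactManager.with_fresh_cache():
    compiled = torch.compile(lambda x: x + 1)
    compiled(torch.ones(2))
    artifacts = torch.compiler.save_cache_artifacts()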
@@ -1404,10 +1404,10 @@ class ProxyTorchDispatchMode(TorchDispatchMode):
        # ProxyTorchDispatchMode state was (if there was any).
        # This lets us properly reset the state on exit.
        self.enter_stack: list[Optional[ProxyTorchDispatchMode]] = []
        self.decomp_layers: int = 0
        self.decomp_layers = 0
        from torch._inductor import config

        self.emulate_precision_casts: bool = config.emulate_precision_casts
        self.emulate_precision_casts = config.emulate_precision_casts

    @count
    def __torch_dispatch__(