mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter. You can review these PRs via: ```bash git diff --ignore-all-space --ignore-blank-lines HEAD~1 ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/129767 Approved by: https://github.com/anijain2305
This commit is contained in:
parent
ad9826208c
commit
e74ba1b34a
|
|
@ -48,7 +48,6 @@ ISORT_SKIPLIST = re.compile(
|
|||
# torch/**
|
||||
# torch/_[a-c]*/**
|
||||
# torch/_d*/**
|
||||
"torch/_d*/**",
|
||||
# torch/_[e-h]*/**
|
||||
# torch/_i*/**
|
||||
# torch/_[j-z]*/**
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
|
|||
from torch._prims_common import CustomOutParamAnnotation
|
||||
from torch.utils import _pytree as pytree
|
||||
|
||||
|
||||
__all__ = [
|
||||
"decomposition_table",
|
||||
"pre_autograd_decomposition_table",
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ from torch._prims_common.wrappers import (
|
|||
from torch.utils import _pytree as pytree
|
||||
from torch.utils._pytree import tree_map
|
||||
|
||||
|
||||
DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
|
||||
|
||||
# None of these functions are publicly accessible; get at them
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import torch._decomp
|
|||
from torch import Tensor
|
||||
from torch._prims_common.wrappers import _maybe_remove_out_wrapper
|
||||
|
||||
|
||||
decomposition_table = torch._decomp.decomposition_table
|
||||
decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
|
||||
register_decomposition = torch._decomp.register_decomposition
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import torch._decomp as decomp
|
|||
from torch._decomp import get_decompositions
|
||||
from torch._ops import OpOverload
|
||||
|
||||
|
||||
aten = torch.ops.aten
|
||||
|
||||
rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import torch._ops
|
|||
import torch.utils._python_dispatch
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
|
||||
__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
|
||||
|
||||
no_python_dispatcher = torch._C._DisablePythonDispatcher
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import torch
|
||||
|
||||
from . import convert_frame, eval_frame, resume_execution
|
||||
from .backends.registry import list_backends, lookup_backend, register_backend
|
||||
from .callback import callback_handler, on_compile_end, on_compile_start
|
||||
|
|
@ -32,6 +33,7 @@ from .external_utils import is_compiling
|
|||
from .mutation_guard import GenerationTracker
|
||||
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
|
||||
|
||||
|
||||
__all__ = [
|
||||
"allow_in_graph",
|
||||
"assume_constant_result",
|
||||
|
|
|
|||
|
|
@ -2,11 +2,9 @@
|
|||
import torch
|
||||
from torch._C import DispatchKey
|
||||
from torch._higher_order_ops.utils import autograd_not_implemented
|
||||
|
||||
from torch._ops import HigherOrderOperator
|
||||
from torch._subclasses import FakeTensorMode
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
|
||||
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
|
||||
from torch.utils._python_dispatch import _get_current_dispatch_mode
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from torch._dynamo.utils import counters, defake, flatten_graph_inputs
|
|||
from torch._functorch.aot_autograd import aot_module_simplified
|
||||
from torch.utils._python_dispatch import _disable_current_modes
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -53,7 +54,6 @@ class AotAutograd:
|
|||
)
|
||||
|
||||
from functorch.compile import nop
|
||||
|
||||
from torch._inductor.debug import enable_aot_logging
|
||||
|
||||
# debug asserts slow down compile time noticeably,
|
||||
|
|
|
|||
|
|
@ -23,8 +23,8 @@ from torch._inductor.utils import (
|
|||
num_fw_fixed_arguments,
|
||||
output_node,
|
||||
)
|
||||
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
||||
from .registry import register_backend
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,14 +6,15 @@ from importlib import import_module
|
|||
from typing import Any, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from functorch.compile import min_cut_rematerialization_partition
|
||||
from torch import _guards
|
||||
from torch._functorch import config as functorch_config
|
||||
from torch._functorch.compilers import ts_compile
|
||||
|
||||
from .common import aot_autograd
|
||||
from .registry import register_debug_backend as register_backend
|
||||
|
||||
|
||||
"""
|
||||
This file contains TorchDynamo backends intended for debugging uses.
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
|
|||
from torch._logging import trace_structured
|
||||
from torch.fx.node import Node
|
||||
|
||||
|
||||
# Regular log messages should go through 'log'.
|
||||
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
|
||||
# See docs/source/logging.rst for more info.
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from torch.onnx._internal.onnxruntime import (
|
|||
is_onnxrt_backend_supported,
|
||||
torch_compile_backend,
|
||||
)
|
||||
|
||||
from .registry import register_backend
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from functorch.compile import make_boxed_func
|
|||
from ..backends.common import aot_autograd
|
||||
from .registry import register_backend, register_experimental_backend
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,10 +10,11 @@ from types import MappingProxyType
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from .common import device_from_inputs, fake_tensor_unsupported
|
||||
|
||||
from .common import device_from_inputs, fake_tensor_unsupported
|
||||
from .registry import register_backend
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import dis
|
|||
import sys
|
||||
from typing import Any, Set, Union
|
||||
|
||||
|
||||
TERMINAL_OPCODES = {
|
||||
dis.opmap["RETURN_VALUE"],
|
||||
dis.opmap["JUMP_FORWARD"],
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from torch._guards import CompileId
|
|||
|
||||
from . import config
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
"""
|
||||
[Note on cache size limit]
|
||||
|
|
|
|||
|
|
@ -7,8 +7,8 @@ import types
|
|||
from typing import Counter, Dict, List, Optional
|
||||
|
||||
import torch.nn
|
||||
from . import utils
|
||||
|
||||
from . import utils
|
||||
from .bytecode_transformation import (
|
||||
add_push_null,
|
||||
create_call_function,
|
||||
|
|
|
|||
|
|
@ -29,9 +29,11 @@ from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
|
|||
from torch.fx.traceback import preserve_node_meta, set_stack_trace
|
||||
from torch.utils._traceback import CapturedTraceback
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch.fx.proxy import Proxy
|
||||
|
||||
|
||||
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
|
||||
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
|
||||
|
||||
|
|
|
|||
|
|
@ -472,4 +472,5 @@ if TYPE_CHECKING:
|
|||
|
||||
from torch.utils._config_module import install_config_module
|
||||
|
||||
|
||||
install_config_module(sys.modules[__name__])
|
||||
|
|
|
|||
|
|
@ -23,23 +23,17 @@ from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
|
|||
from typing_extensions import ParamSpec
|
||||
from weakref import ReferenceType
|
||||
|
||||
from torch._utils_internal import maybe_upload_prof_stats_to_manifold
|
||||
|
||||
from torch.fx._lazy_graph_module import _use_lazy_graph_module
|
||||
from torch.utils._traceback import CapturedTraceback
|
||||
|
||||
np: Optional[ModuleType]
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None
|
||||
|
||||
import torch
|
||||
import torch._logging
|
||||
from torch._dynamo.distributed import get_compile_pg
|
||||
from torch._guards import compile_context, CompileContext, CompileId, tracing
|
||||
from torch._logging import structured
|
||||
from torch._utils_internal import compile_time_strobelight_meta, signpost_event
|
||||
from torch._utils_internal import (
|
||||
compile_time_strobelight_meta,
|
||||
maybe_upload_prof_stats_to_manifold,
|
||||
signpost_event,
|
||||
)
|
||||
from torch.fx._lazy_graph_module import _use_lazy_graph_module
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ConstraintViolationError,
|
||||
GuardOnDataDependentSymNode,
|
||||
|
|
@ -47,7 +41,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
|
||||
from torch.nn.parallel.distributed import DistributedDataParallel
|
||||
from torch.utils._python_dispatch import _disable_current_modes
|
||||
from torch.utils._traceback import format_traceback_short
|
||||
from torch.utils._traceback import CapturedTraceback, format_traceback_short
|
||||
|
||||
from . import config, exc, trace_rules
|
||||
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
|
||||
|
|
@ -109,12 +103,21 @@ from .utils import (
|
|||
write_record_to_file,
|
||||
)
|
||||
|
||||
|
||||
np: Optional[ModuleType]
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .backends.registry import CompilerFn
|
||||
from .repro.after_dynamo import WrapBackendDebug
|
||||
from .types import BytecodeHook, CacheEntry
|
||||
from .variables.builder import FrameStateSizeEntry
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
|
||||
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
|
||||
|
|
@ -536,6 +539,7 @@ from collections import OrderedDict
|
|||
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .output_graph import OutputGraph
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from contextlib import contextmanager
|
|||
|
||||
import torch
|
||||
|
||||
|
||||
doc = """
|
||||
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
|
||||
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
import contextlib
|
||||
import threading
|
||||
|
||||
|
||||
# Global variable to identify which SubgraphTracer we are in.
|
||||
# It is sometimes difficult to find an InstructionTranslator to use.
|
||||
_current_scope_id = threading.local()
|
||||
|
|
|
|||
|
|
@ -20,7 +20,6 @@ import torch
|
|||
import torch._prims_common as utils
|
||||
import torch._subclasses.meta_utils
|
||||
from torch import Tensor
|
||||
|
||||
from torch._dynamo.testing import rand_strided
|
||||
from torch._prims_common import is_float_dtype
|
||||
from torch.multiprocessing.reductions import StorageWeakRef
|
||||
|
|
@ -29,6 +28,7 @@ from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
|
|||
from . import config
|
||||
from .utils import clone_inputs, get_debug_dir
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
|
|
|||
|
|
@ -5,12 +5,14 @@ from typing import TYPE_CHECKING
|
|||
|
||||
import torch
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
|
||||
from . import trace_rules, variables
|
||||
from .comptime import comptime
|
||||
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
|
||||
from .exc import IncorrectUsage
|
||||
from .external_utils import is_compiling
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._C._dynamo.eval_frame import ( # noqa: F401
|
||||
reset_code,
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
|
|||
import torch
|
||||
from torch._streambase import _EventBase, _StreamBase
|
||||
|
||||
|
||||
get_cuda_stream: Optional[Callable[[int], int]]
|
||||
if torch.cuda._is_compiled():
|
||||
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
from . import config
|
||||
|
||||
|
||||
_COMPILE_PG: Optional[dist.ProcessGroup] = None
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -37,13 +37,17 @@ from typing import (
|
|||
)
|
||||
from unittest.mock import patch
|
||||
|
||||
import sympy
|
||||
|
||||
import torch
|
||||
import torch.fx
|
||||
import torch.utils._pytree as pytree
|
||||
import torch.utils.checkpoint
|
||||
from torch import _guards
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._utils_internal import justknobs_check, log_export_usage
|
||||
from torch.export.dynamic_shapes import _process_dynamic_shapes
|
||||
from torch.fx import GraphModule
|
||||
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
ConstraintViolationError,
|
||||
|
|
@ -53,10 +57,20 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||
)
|
||||
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
|
||||
|
||||
from ..fx import GraphModule
|
||||
from . import config, convert_frame, external_utils, trace_rules, utils
|
||||
from .backends.registry import CompilerFn, lookup_backend
|
||||
|
||||
from .code_context import code_context
|
||||
from .exc import CondOpArgsMismatchError, UserError, UserErrorType
|
||||
from .hooks import Hooks
|
||||
from .mutation_guard import install_generation_tagging_init
|
||||
from .utils import common_constant_types, compile_times
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._subclasses import fake_tensor
|
||||
|
||||
from .types import CacheEntry, DynamoCallback
|
||||
|
||||
|
||||
# see discussion at https://github.com/pytorch/pytorch/issues/120699
|
||||
reset_code = torch._C._dynamo.eval_frame.reset_code # noqa: F401
|
||||
|
|
@ -65,27 +79,14 @@ set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook # noqa:
|
|||
skip_code = torch._C._dynamo.eval_frame.skip_code # noqa: F401
|
||||
unsupported = torch._C._dynamo.eval_frame.unsupported # noqa: F401
|
||||
|
||||
from . import config, convert_frame, external_utils, trace_rules, utils
|
||||
from .code_context import code_context
|
||||
from .exc import CondOpArgsMismatchError, UserError, UserErrorType
|
||||
from .mutation_guard import install_generation_tagging_init
|
||||
from .utils import common_constant_types, compile_times
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
|
||||
always_optimize_code_objects = utils.ExactWeakKeyDictionary()
|
||||
null_context = contextlib.nullcontext
|
||||
|
||||
|
||||
import sympy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._subclasses import fake_tensor
|
||||
from .types import CacheEntry, DynamoCallback
|
||||
|
||||
|
||||
# See https://github.com/python/typing/pull/240
|
||||
class Unset(Enum):
|
||||
token = 0
|
||||
|
|
@ -1627,7 +1628,7 @@ class TorchPatcher:
|
|||
)
|
||||
torch.distributions.Distribution.set_default_validate_args(False)
|
||||
|
||||
from ..optim import (
|
||||
from torch.optim import (
|
||||
adadelta,
|
||||
adagrad,
|
||||
adam,
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@ from typing import Any, cast, NoReturn, Optional, Tuple, TYPE_CHECKING
|
|||
import torch._guards
|
||||
|
||||
from . import config
|
||||
|
||||
from .utils import counters
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._guards import CompileId
|
||||
|
||||
|
|
@ -25,6 +25,7 @@ def exportdb_error_message(case_name):
|
|||
|
||||
import logging
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ from typing import List
|
|||
import torch
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import tokenize
|
||||
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
|
||||
cache: Dict[str, Dict[int, str]] = {}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -35,12 +35,6 @@ from typing import (
|
|||
)
|
||||
from weakref import ReferenceType
|
||||
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None # type: ignore[assignment]
|
||||
|
||||
import torch
|
||||
import torch.utils._device
|
||||
from torch._dynamo.source import (
|
||||
|
|
@ -60,7 +54,6 @@ from torch._guards import (
|
|||
GuardSource,
|
||||
Source,
|
||||
)
|
||||
|
||||
from torch._logging import structured
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
EqualityConstraint,
|
||||
|
|
@ -72,7 +65,6 @@ from torch.utils.weak import TensorWeakRef
|
|||
|
||||
from . import config, convert_frame, exc, mutation_guard
|
||||
from .eval_frame import set_guard_error_hook
|
||||
|
||||
from .source import (
|
||||
AttrSource,
|
||||
ChainedSource,
|
||||
|
|
@ -115,9 +107,17 @@ from .utils import (
|
|||
verify_guard_fn_signature,
|
||||
)
|
||||
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None # type: ignore[assignment]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sympy import Symbol
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
|
||||
recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles")
|
||||
|
|
@ -310,7 +310,6 @@ if sys.version_info[:2] <= (3, 8):
|
|||
HAS_UNPARSE_FUNCTIONS = True
|
||||
except ImportError:
|
||||
HAS_UNPARSE_FUNCTIONS = False
|
||||
pass
|
||||
else:
|
||||
HAS_UNPARSE_FUNCTIONS = True
|
||||
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
import dataclasses
|
||||
|
||||
from typing import Callable, Optional
|
||||
|
||||
from torch._guards import GuardsSet
|
||||
|
||||
from .types import GuardFail
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import logging
|
|||
|
||||
from torch.hub import _Faketqdm, tqdm
|
||||
|
||||
|
||||
# Disable progress bar by default, not in dynamo config because otherwise get a circular import
|
||||
disable_progress = True
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,8 @@ import weakref
|
|||
|
||||
import torch.nn
|
||||
from torch.nn import Module
|
||||
from . import config
|
||||
|
||||
from . import config
|
||||
from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -18,10 +18,8 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKIN
|
|||
import sympy
|
||||
|
||||
import torch._guards
|
||||
|
||||
import torch._logging
|
||||
import torch.distributed as dist
|
||||
|
||||
import torch.nn
|
||||
import torch.utils._pytree as pytree
|
||||
from torch import fx
|
||||
|
|
@ -104,12 +102,13 @@ from .variables.tensor import (
|
|||
TensorVariable,
|
||||
UnspecializedPythonVariable,
|
||||
)
|
||||
|
||||
from .variables.torch_function import TensorWithTFOverrideVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslatorBase
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
|
||||
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from typing import Any, Dict
|
|||
|
||||
from torch.utils._import_utils import import_dill
|
||||
|
||||
|
||||
dill = import_dill()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from torch.hub import tqdm
|
|||
|
||||
from .. import config
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from typing import Union
|
|||
|
||||
import torch
|
||||
import torch.fx as fx
|
||||
|
||||
from torch._dynamo.debug_utils import (
|
||||
AccuracyError,
|
||||
backend_accuracy_fails,
|
||||
|
|
@ -36,6 +35,7 @@ from .. import config
|
|||
from ..backends.registry import lookup_backend, register_debug_backend
|
||||
from ..debug_utils import clone_inputs_retaining_gradness
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ from .bytecode_transformation import (
|
|||
)
|
||||
from .utils import ExactWeakKeyDictionary
|
||||
|
||||
|
||||
# taken from code.h in cpython
|
||||
CO_OPTIMIZED = 0x0001
|
||||
CO_NEWLOCALS = 0x0002
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from . import utils
|
|||
from .bytecode_transformation import create_call_function, create_instruction
|
||||
from .utils import enum_repr
|
||||
|
||||
|
||||
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
|
||||
# so those cases are omitted intentionally
|
||||
_GUARD_SOURCE_NN_MODULE = {
|
||||
|
|
|
|||
|
|
@ -110,6 +110,7 @@ from .variables.user_defined import (
|
|||
UserDefinedObjectVariable,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
|
||||
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ from torch._prims import _make_prim, RETURN_TYPE
|
|||
from torch._subclasses import FakeTensorMode
|
||||
from torch._subclasses.functional_tensor import FunctionalTensorMode
|
||||
|
||||
|
||||
_tensor_version = _make_prim(
|
||||
schema="_tensor_version(Tensor self) -> SymInt",
|
||||
return_type=RETURN_TYPE.NEW,
|
||||
|
|
|
|||
|
|
@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
|
|||
|
||||
from . import config, reset, utils
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,12 +12,6 @@ import unittest
|
|||
from typing import List, Optional, Sequence, Union
|
||||
from unittest.mock import patch
|
||||
|
||||
np: Optional[types.ModuleType] = None
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None
|
||||
|
||||
import torch
|
||||
from torch import fx
|
||||
from torch._dynamo.output_graph import OutputGraph
|
||||
|
|
@ -32,6 +26,14 @@ from .bytecode_transformation import (
|
|||
from .guards import CheckFunctionManager, CompileId, GuardedCode
|
||||
from .utils import same
|
||||
|
||||
|
||||
np: Optional[types.ModuleType] = None
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None
|
||||
|
||||
|
||||
unsupported = eval_frame.unsupported
|
||||
three = 3
|
||||
|
||||
|
|
|
|||
|
|
@ -35,20 +35,14 @@ import weakref
|
|||
from collections import defaultdict
|
||||
from typing import Any, Callable, cast, Dict, List, Optional, Set, Union
|
||||
|
||||
np: Optional[types.ModuleType] = None
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
import torch
|
||||
import torch._inductor.test_operators
|
||||
import torch.distributed
|
||||
import torch.utils._content_store
|
||||
from ..utils import _config_module
|
||||
from torch.utils import _config_module
|
||||
|
||||
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
|
||||
from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
|
||||
|
||||
from .variables import (
|
||||
BuiltinVariable,
|
||||
FunctionalCallVariable,
|
||||
|
|
@ -61,6 +55,13 @@ from .variables import (
|
|||
)
|
||||
|
||||
|
||||
np: Optional[types.ModuleType] = None
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
pass
|
||||
|
||||
|
||||
if typing.TYPE_CHECKING:
|
||||
from .variables.base import VariableTracker
|
||||
|
||||
|
|
|
|||
|
|
@ -4,6 +4,9 @@ import types
|
|||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
import torch
|
||||
from torch._guards import CompileId
|
||||
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from torch._C._dynamo import eval_frame
|
||||
|
|
@ -12,8 +15,6 @@ if sys.version_info >= (3, 11):
|
|||
else:
|
||||
DynamoFrameType: TypeAlias = types.FrameType
|
||||
|
||||
import torch
|
||||
from torch._guards import CompileId
|
||||
|
||||
# This class has a `check_fn` field for the guard,
|
||||
# and a `code` field for the code object.
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import dis
|
|||
import enum
|
||||
import functools
|
||||
import gc
|
||||
import importlib
|
||||
import inspect
|
||||
import itertools
|
||||
import linecache
|
||||
|
|
@ -52,10 +53,21 @@ from typing import (
|
|||
)
|
||||
from typing_extensions import Literal, ParamSpec, TypeGuard
|
||||
|
||||
from ..utils.hooks import RemovableHandle
|
||||
import torch
|
||||
import torch._functorch.config
|
||||
import torch._inductor.config as inductor_config
|
||||
import torch.fx.experimental.symbolic_shapes
|
||||
import torch.utils._pytree as pytree
|
||||
from torch import fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._guards import TracingContext
|
||||
from torch._subclasses.meta_utils import is_sparse_compressed
|
||||
from torch._utils_internal import log_compilation_event
|
||||
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
|
||||
from torch.nn.modules.lazy import LazyModuleMixin
|
||||
from torch.utils._triton import has_triton, has_triton_package
|
||||
from torch.utils.hooks import RemovableHandle
|
||||
|
||||
T = TypeVar("T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
|
|
@ -67,6 +79,7 @@ try:
|
|||
import torch._numpy as tnp
|
||||
from torch._guards import detect_fake_mode # noqa: F401n
|
||||
from torch._logging import LazyString
|
||||
|
||||
from . import config
|
||||
|
||||
# NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
|
||||
|
|
@ -92,23 +105,9 @@ try:
|
|||
except ImportError:
|
||||
pass
|
||||
|
||||
import importlib
|
||||
|
||||
import torch
|
||||
import torch._functorch.config
|
||||
import torch._inductor.config as inductor_config
|
||||
import torch.fx.experimental.symbolic_shapes
|
||||
import torch.utils._pytree as pytree
|
||||
from torch import fx
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._guards import TracingContext
|
||||
from torch._subclasses.meta_utils import is_sparse_compressed
|
||||
from torch._utils_internal import log_compilation_event
|
||||
|
||||
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
|
||||
from torch.nn.modules.lazy import LazyModuleMixin
|
||||
from torch.utils._triton import has_triton, has_triton_package
|
||||
|
||||
T = TypeVar("T")
|
||||
_P = ParamSpec("_P")
|
||||
|
||||
unpatched_nn_module_getattr = torch.nn.Module.__getattr__
|
||||
|
||||
|
|
@ -1850,6 +1849,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
|
|||
by further wrapping them as this graph's fakes.
|
||||
"""
|
||||
from torch.utils._sympy.value_ranges import ValueRangeError
|
||||
|
||||
from .exc import (
|
||||
TorchRuntimeError,
|
||||
unimplemented,
|
||||
|
|
@ -2403,6 +2403,7 @@ def is_utils_checkpoint(obj):
|
|||
|
||||
def build_checkpoint_variable(**options):
|
||||
import torch._higher_order_ops.wrap as higher_order_ops
|
||||
|
||||
from .variables.higher_order_ops import TorchHigherOrderOperatorVariable
|
||||
|
||||
# TODO - This is a temporary situation where we have two versions of
|
||||
|
|
|
|||
|
|
@ -82,7 +82,6 @@ from .misc import (
|
|||
UnknownVariable,
|
||||
)
|
||||
from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
|
||||
|
||||
from .optimizer import OptimizerVariable
|
||||
from .sdpa import SDPAParamsVariable
|
||||
from .tensor import (
|
||||
|
|
@ -101,6 +100,7 @@ from .user_defined import (
|
|||
WeakRefVariable,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AutogradFunctionContextVariable",
|
||||
"AutogradFunctionVariable",
|
||||
|
|
|
|||
|
|
@ -4,9 +4,6 @@ import collections
|
|||
from enum import Enum
|
||||
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from .. import variables
|
||||
from ..current_scope_id import current_scope_id
|
||||
from ..exc import unimplemented
|
||||
|
|
@ -14,6 +11,10 @@ from ..source import AttrSource, Source
|
|||
from ..utils import istype
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
class MutableLocalSource(Enum):
|
||||
"""
|
||||
If the VariableTracker.mutable_local represents a Variable that:
|
||||
|
|
|
|||
|
|
@ -17,20 +17,7 @@ import types
|
|||
import weakref
|
||||
from typing import Any, List, NamedTuple, Optional, TYPE_CHECKING, Union
|
||||
|
||||
from torch._utils_internal import justknobs_check
|
||||
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None
|
||||
|
||||
import torch
|
||||
|
||||
from torch import SymInt
|
||||
from torch._guards import GuardSource, TracingContext
|
||||
from torch._higher_order_ops.torchbind import call_torchbind
|
||||
|
|
@ -38,6 +25,7 @@ from torch._ops import HigherOrderOperator
|
|||
from torch._streambase import _EventBase, _StreamBase
|
||||
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
|
||||
from torch._subclasses.meta_utils import is_sparse_any
|
||||
from torch._utils_internal import justknobs_check
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
_constrain_range_for_size,
|
||||
|
|
@ -49,10 +37,10 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||
)
|
||||
from torch.fx.immutable_collections import immutable_dict, immutable_list
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
from torch.utils._sympy.value_ranges import ValueRanges
|
||||
from torch.utils.weak import TensorWeakRef
|
||||
|
||||
from .. import config, mutation_guard, replay_record, trace_rules
|
||||
|
||||
from ..device_interface import get_registered_device_interfaces
|
||||
from ..exc import InternalTorchDynamoError, unimplemented
|
||||
from ..guards import GuardBuilder, install_guard, make_dupe_guard
|
||||
|
|
@ -109,7 +97,6 @@ from ..utils import (
|
|||
unwrap_with_attr_name_if_wrapper,
|
||||
wrap_fake_exception,
|
||||
)
|
||||
|
||||
from .base import MutableLocal, typestr, VariableTracker, VariableTrackerMeta
|
||||
from .constant import ConstantVariable, EnumVariable
|
||||
from .ctx_manager import (
|
||||
|
|
@ -182,7 +169,6 @@ from .misc import (
|
|||
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
|
||||
from .optimizer import OptimizerVariable
|
||||
from .script_object import TorchScriptObjectVariable
|
||||
|
||||
from .sdpa import SDPAParamsVariable
|
||||
from .tensor import (
|
||||
NumpyNdarrayVariable,
|
||||
|
|
@ -202,6 +188,16 @@ from .user_defined import (
|
|||
)
|
||||
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -14,9 +14,6 @@ from typing import Dict, List, TYPE_CHECKING
|
|||
|
||||
import torch
|
||||
from torch import sym_float, sym_int
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
|
||||
from .. import config, polyfill, variables
|
||||
|
|
@ -70,6 +67,11 @@ from .tensor import (
|
|||
)
|
||||
from .user_defined import UserDefinedObjectVariable, UserDefinedVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,15 +6,17 @@ from typing import Dict, List, TYPE_CHECKING
|
|||
import torch
|
||||
from torch._dynamo.source import GetItemSource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from .. import variables
|
||||
from ..exc import unimplemented, UserError, UserErrorType
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..utils import common_constant_types, istype, np
|
||||
from .base import typestr, VariableTracker
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
_type_to_assert_reason = {
|
||||
# NB - We CAN have ConstantVariable.create(set) because of how sets interact with guards.
|
||||
# A locally created set should always become a SetVariable, as the items in the set will already either be sourced
|
||||
|
|
|
|||
|
|
@ -6,9 +6,6 @@ import warnings
|
|||
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import torch._C
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch._guards import Guard
|
||||
|
||||
from .. import variables
|
||||
|
|
@ -32,6 +29,10 @@ from .functions import (
|
|||
from .user_defined import UserDefinedObjectVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ContextMangerState:
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -7,9 +7,6 @@ import inspect
|
|||
import sys
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from torch._subclasses.fake_tensor import is_fake
|
||||
|
||||
from .. import polyfill, variables
|
||||
|
|
@ -22,6 +19,11 @@ from ..utils import dict_keys, dict_values, istype, specialize_symnode
|
|||
from .base import MutableLocal, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
# [Adding a new supported class within the keys of ConstDictVarialble]
|
||||
# - Add its tracker type to is_hashable
|
||||
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl
|
||||
|
|
|
|||
|
|
@ -4,10 +4,8 @@ import inspect
|
|||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch.fx.experimental._backward_state import BackwardState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from ...fx.experimental._backward_state import BackwardState
|
||||
from .. import compiled_autograd, variables
|
||||
from .._trace_wrapped_higher_order_op import trace_wrapped
|
||||
from ..exc import unimplemented
|
||||
|
|
@ -19,6 +17,10 @@ from .base import VariableTracker
|
|||
from .constant import ConstantVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
class DistributedVariable(VariableTracker):
|
||||
"""
|
||||
The base distributed variable that encapsulates common methods
|
||||
|
|
|
|||
|
|
@ -9,9 +9,6 @@ from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
|||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from .. import polyfill, variables
|
||||
from ..bytecode_transformation import create_call_function, create_rot_n
|
||||
from ..exc import unimplemented, Unsupported
|
||||
|
|
@ -27,8 +24,6 @@ from ..utils import (
|
|||
from .base import MutableLocal, typestr, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._guards import Source
|
||||
|
||||
try:
|
||||
from torch.distributed._composable.fsdp import _fsdp_param_group
|
||||
|
|
@ -36,6 +31,11 @@ except ModuleNotFoundError:
|
|||
_fsdp_param_group = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch._guards import Source
|
||||
|
||||
|
||||
def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
|
||||
# Source propagation is best effort since not every object we encounter has a source to begin with.
|
||||
if isinstance(val, VariableTracker):
|
||||
|
|
|
|||
|
|
@ -6,16 +6,12 @@ import inspect
|
|||
import itertools
|
||||
import logging
|
||||
import types
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING
|
||||
|
||||
import torch._C
|
||||
import torch.fx
|
||||
import torch.nn
|
||||
import torch.onnx.operators
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch._dynamo.utils import get_fake_value
|
||||
from torch._dynamo.variables import ConstantVariable
|
||||
from torch._dynamo.variables.base import VariableTracker
|
||||
|
|
@ -26,8 +22,8 @@ from torch._guards import Source
|
|||
from torch._ops import HigherOrderOperator
|
||||
from torch.fx.passes.shape_prop import _extract_tensor_metadata
|
||||
from torch.utils import _pytree as pytree
|
||||
from .. import variables
|
||||
|
||||
from .. import variables
|
||||
from ..exc import UncapturedHigherOrderOpError, unimplemented, Unsupported
|
||||
from ..source import AttrSource
|
||||
from ..utils import proxy_args_kwargs
|
||||
|
|
@ -36,6 +32,10 @@ from .lazy import LazyVariableTracker
|
|||
from .lists import ListVariable, TupleVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
|
@ -1467,6 +1467,7 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
|
|||
) -> VariableTracker:
|
||||
from torch._higher_order_ops.wrap import TagActivationCheckpoint
|
||||
from torch.utils.checkpoint import noop_context_fn
|
||||
|
||||
from .builder import wrap_fx_proxy
|
||||
|
||||
context_fn = None
|
||||
|
|
@ -1609,6 +1610,7 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
|
|||
fn_name: str,
|
||||
):
|
||||
from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
|
||||
|
||||
from .builder import SourcelessBuilder
|
||||
|
||||
tx: InstructionTranslator = tx
|
||||
|
|
|
|||
|
|
@ -1,16 +1,10 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
MAX_CYCLE = 3000
|
||||
|
||||
import itertools
|
||||
import operator
|
||||
import sys
|
||||
|
||||
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from .. import polyfill, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
from ..exc import (
|
||||
|
|
@ -20,11 +14,17 @@ from ..exc import (
|
|||
unimplemented,
|
||||
UserError,
|
||||
)
|
||||
|
||||
from .base import MutableLocal, VariableTracker
|
||||
from .constant import ConstantVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
MAX_CYCLE = 3000
|
||||
|
||||
|
||||
class ItertoolsVariable(VariableTracker):
|
||||
def __init__(self, value, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
|
|||
|
|
@ -9,11 +9,7 @@ from typing import Dict, List, Optional, TYPE_CHECKING
|
|||
|
||||
import torch
|
||||
import torch.fx
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from ..._guards import Source
|
||||
from torch._guards import Source
|
||||
|
||||
from .. import polyfill, variables
|
||||
from ..bytecode_transformation import create_call_function, create_instruction
|
||||
|
|
@ -36,6 +32,10 @@ from .functions import UserFunctionVariable, UserMethodVariable
|
|||
from .iter import IteratorVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
class BaseListVariable(VariableTracker):
|
||||
@staticmethod
|
||||
def cls_for_instance(obj):
|
||||
|
|
|
|||
|
|
@ -13,8 +13,6 @@ import torch._C
|
|||
import torch._numpy as tnp
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from .. import config, variables
|
||||
from ..bytecode_transformation import (
|
||||
add_push_null_call_function_ex,
|
||||
|
|
@ -38,6 +36,10 @@ from .functions import NestedUserFunctionVariable, UserFunctionVariable
|
|||
from .user_defined import is_standard_setattr, UserDefinedObjectVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
class SuperVariable(VariableTracker):
|
||||
_nonvar_fields = {
|
||||
"specialized",
|
||||
|
|
|
|||
|
|
@ -9,9 +9,6 @@ from typing import Any, Dict, List, TYPE_CHECKING
|
|||
|
||||
import torch.nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from .. import trace_rules, variables
|
||||
from ..exc import (
|
||||
ObservedException,
|
||||
|
|
@ -49,6 +46,10 @@ from .lists import SliceVariable
|
|||
from .user_defined import UserDefinedObjectVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
|
||||
"""
|
||||
Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable.
|
||||
|
|
|
|||
|
|
@ -4,9 +4,6 @@ import weakref
|
|||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch.utils._pytree import tree_map_only
|
||||
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
|
|
@ -18,14 +15,16 @@ from ..source import (
|
|||
GradSource,
|
||||
)
|
||||
from ..utils import GLOBAL_KEY_PREFIX
|
||||
|
||||
from .constant import ConstantVariable
|
||||
from .dicts import ConstDictVariable
|
||||
from .lists import ListVariable
|
||||
from .misc import GetAttrVariable
|
||||
from .user_defined import UserDefinedObjectVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from .base import VariableTracker
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ import functools
|
|||
from typing import Dict
|
||||
|
||||
import torch
|
||||
from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
|
||||
|
||||
from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
|
||||
from .base import VariableTracker
|
||||
from .user_defined import UserDefinedObjectVariable
|
||||
|
||||
|
|
@ -49,6 +49,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
|
|||
)
|
||||
def var_getattr(self, tx, name: str) -> VariableTracker:
|
||||
from torch._higher_order_ops.torchbind import call_torchbind
|
||||
|
||||
from ..source import AttrSource
|
||||
from .higher_order_ops import TorchHigherOrderOperatorVariable
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,17 @@
|
|||
# mypy: ignore-errors
|
||||
|
||||
from inspect import getattr_static
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from ..bytecode_transformation import create_call_function
|
||||
from ..exc import Unsupported
|
||||
from .base import VariableTracker
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
class SDPAParamsVariable(VariableTracker):
|
||||
"""Represents the c++ params struct for scaled dot product attention.
|
||||
This is a read-only container."""
|
||||
|
|
@ -19,6 +19,7 @@ class SDPAParamsVariable(VariableTracker):
|
|||
@staticmethod
|
||||
def create(tx: "InstructionTranslator", value, source):
|
||||
from torch.backends.cuda import SDPAParams
|
||||
|
||||
from ..source import AttrSource
|
||||
from .builder import VariableBuilder
|
||||
from .torch import TorchInGraphFunctionVariable
|
||||
|
|
@ -64,6 +65,7 @@ class SDPAParamsVariable(VariableTracker):
|
|||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
|
||||
import torch._C
|
||||
|
||||
from ..source import AttrSource
|
||||
from .builder import wrap_fx_proxy
|
||||
from .misc import GetAttrVariable
|
||||
|
|
|
|||
|
|
@ -15,9 +15,6 @@ import torch._numpy as tnp
|
|||
import torch.fx
|
||||
import torch.random
|
||||
from torch._dynamo import compiled_autograd
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch._subclasses.meta_utils import is_sparse_any
|
||||
from torch.fx.experimental.symbolic_shapes import (
|
||||
guard_scalar,
|
||||
|
|
@ -27,6 +24,7 @@ from torch.fx.experimental.symbolic_shapes import (
|
|||
SymTypes,
|
||||
)
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
|
||||
from .. import config, variables
|
||||
from .._trace_wrapped_higher_order_op import trace_wrapped
|
||||
from ..exc import unimplemented, UserError, UserErrorType
|
||||
|
|
@ -49,11 +47,17 @@ from .base import VariableTracker
|
|||
from .constant import ConstantVariable
|
||||
from .lists import SizeVariable
|
||||
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Ops that allow tensor <op> tensor
|
||||
|
|
|
|||
|
|
@ -3,7 +3,6 @@
|
|||
import functools
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
import math
|
||||
import re
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
|
@ -13,13 +12,10 @@ import torch._refs
|
|||
import torch.fx
|
||||
import torch.nn
|
||||
import torch.onnx.operators
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
from torch._guards import TracingContext
|
||||
from torch._logging import warning_once
|
||||
|
||||
from torch._streambase import _StreamBase
|
||||
from ..._guards import TracingContext
|
||||
|
||||
from .. import config, polyfill, variables
|
||||
from ..codegen import PyCodegen
|
||||
from ..create_parameter_op import (
|
||||
|
|
@ -50,6 +46,7 @@ from .distributed import DistributedVariable, ProcessGroupVariable
|
|||
from .lists import ListVariable, TupleVariable
|
||||
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
|
||||
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
|
|
@ -60,6 +57,11 @@ try:
|
|||
except ModuleNotFoundError:
|
||||
_fsdp_param_group = None # type: ignore[assignment]
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
supported_ctx_manager_classes = dict.fromkeys(
|
||||
|
|
@ -353,6 +355,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
|
|||
return _register
|
||||
|
||||
from torch.backends.cuda import SDPAParams
|
||||
|
||||
from . import (
|
||||
ConstantVariable,
|
||||
DeterministicAlgorithmsVariable,
|
||||
|
|
|
|||
|
|
@ -4,11 +4,8 @@ import inspect
|
|||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
import torch.utils._pytree as pytree
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
|
||||
|
||||
from ..exc import unimplemented
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
from ..source import AttrSource, GlobalSource, TypeSource
|
||||
|
|
@ -18,7 +15,10 @@ from .lists import TupleVariable
|
|||
from .tensor import TensorSubclassVariable, TensorVariable
|
||||
from .user_defined import UserDefinedObjectVariable
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from .base import VariableTracker
|
||||
|
||||
|
||||
|
|
@ -223,6 +223,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||
# [Note: __torch_function__] We currently only support attributes that are defined on
|
||||
# base tensors, custom attribute accesses will graph break.
|
||||
import torch
|
||||
|
||||
from .builder import SourcelessBuilder
|
||||
|
||||
if name in banned_attrs:
|
||||
|
|
@ -277,6 +278,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
|
|||
# of `call_method`.
|
||||
if tx.output.torch_function_enabled:
|
||||
import torch
|
||||
|
||||
from .builder import SourcelessBuilder, VariableBuilder
|
||||
|
||||
if _is_attr_overidden(tx, self, name):
|
||||
|
|
|
|||
|
|
@ -11,30 +11,14 @@ import sys
|
|||
import threading
|
||||
import types
|
||||
import warnings
|
||||
|
||||
from typing import Dict, Generic, List, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
from ..bytecode_transformation import create_call_function
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None
|
||||
|
||||
try:
|
||||
from torch.utils._cxx_pytree import PyTreeSpec
|
||||
except ImportError:
|
||||
PyTreeSpec = type(None)
|
||||
|
||||
import torch._dynamo.config
|
||||
|
||||
import torch.nn
|
||||
from torch._guards import TracingContext
|
||||
|
||||
from .. import variables
|
||||
from ..bytecode_transformation import create_call_function
|
||||
from ..create_parameter_op import do_not_convert_to_tracable_parameter
|
||||
from ..exc import ObservedException, unimplemented
|
||||
from ..guards import GuardBuilder, install_guard
|
||||
|
|
@ -64,6 +48,21 @@ from .base import MutableLocal, VariableTracker
|
|||
from .dicts import DefaultDictVariable
|
||||
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
except ModuleNotFoundError:
|
||||
np = None
|
||||
|
||||
try:
|
||||
from torch.utils._cxx_pytree import PyTreeSpec
|
||||
except ImportError:
|
||||
PyTreeSpec = type(None)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from torch._dynamo.symbolic_convert import InstructionTranslator
|
||||
|
||||
|
||||
def is_standard_setattr(val):
|
||||
return val in (object.__setattr__,)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user