[BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767)

See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most of the changes are auto-generated by the linter.

You can review these PRs with whitespace-only and blank-line-only changes hidden via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```
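
Judging from the hunks below, the enforced style separates import segments (standard library, third-party/`torch`, local) with a single blank line, adds a blank line between the import block and the first module-level statement, and floats `try`/`except` import fallbacks and `if TYPE_CHECKING:` blocks below the regular imports. A minimal sketch of the resulting layout, using made-up module and symbol names rather than code from this PR:

```python
# Illustrative layout only; the submodule, helper, and logger names are
# placeholders, not identifiers touched by this PR.
import logging
from typing import TYPE_CHECKING

import torch
import torch.utils._pytree as pytree

from . import config            # local segment (meaningful only inside a package)
from .utils import counters

# Optional-dependency fallbacks sit after the regular import block.
try:
    import numpy as np
except ModuleNotFoundError:
    np = None  # type: ignore[assignment]

# Typing-only imports also follow the regular import block.
if TYPE_CHECKING:
    from .symbolic_convert import InstructionTranslator

# A single blank line separates the imports from module-level code.
log = logging.getLogger(__name__)
```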

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129767
Approved by: https://github.com/anijain2305
Authored by Xuehai Pan on 2024-07-31 21:58:25 +08:00, committed by PyTorch MergeBot
parent ad9826208c
commit e74ba1b34a
69 changed files with 251 additions and 191 deletions


@ -48,7 +48,6 @@ ISORT_SKIPLIST = re.compile(
# torch/**
# torch/_[a-c]*/**
# torch/_d*/**
"torch/_d*/**",
# torch/_[e-h]*/**
# torch/_i*/**
# torch/_[j-z]*/**
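
The hunk above removes the active `"torch/_d*/**"` entry from `ISORT_SKIPLIST`, so paths under `torch/_d*/` are no longer exempt from the import-style lint; the commented-out patterns appear to be a checklist of the path groups handled across this PR series. As a rough sketch of how a glob-based skiplist like this can be compiled and queried (the helper and the remaining patterns are illustrative assumptions, not the linter's actual code):

```python
import fnmatch
import re

# Assumed remaining skiplist entries, for illustration only;
# "torch/_d*/**" has been dropped, so those files are now checked.
SKIP_GLOBS = [
    "torch/_[e-h]*/**",
    "torch/_i*/**",
    "torch/_[j-z]*/**",
]

ISORT_SKIPLIST = re.compile("|".join(fnmatch.translate(g) for g in SKIP_GLOBS))


def is_skipped(path: str) -> bool:
    """Return True if the import-style check should skip this path."""
    return bool(ISORT_SKIPLIST.match(path))


print(is_skipped("torch/_dynamo/utils.py"))    # False: torch/_d*/ is now linted
print(is_skipped("torch/_inductor/utils.py"))  # True under the assumed patterns
```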


@ -12,6 +12,7 @@ from torch._ops import HigherOrderOperator, OpOverload, OpOverloadPacket
from torch._prims_common import CustomOutParamAnnotation
from torch.utils import _pytree as pytree
__all__ = [
"decomposition_table",
"pre_autograd_decomposition_table",


@ -34,6 +34,7 @@ from torch._prims_common.wrappers import (
from torch.utils import _pytree as pytree
from torch.utils._pytree import tree_map
DispatchKey = torch._C.DispatchKey # type: ignore[attr-defined]
# None of these functions are publicly accessible; get at them


@ -8,6 +8,7 @@ import torch._decomp
from torch import Tensor
from torch._prims_common.wrappers import _maybe_remove_out_wrapper
decomposition_table = torch._decomp.decomposition_table
decomposition_table_for_jvp: Dict[torch._ops.OperatorBase, Callable] = {}
register_decomposition = torch._decomp.register_decomposition


@ -9,6 +9,7 @@ import torch._decomp as decomp
from torch._decomp import get_decompositions
from torch._ops import OpOverload
aten = torch.ops.aten
rng_decompositions: Dict[str, Dict[OpOverload, Callable]] = defaultdict(dict)


@ -10,6 +10,7 @@ import torch._ops
import torch.utils._python_dispatch
import torch.utils._pytree as pytree
__all__ = ["enable_python_dispatcher", "no_python_dispatcher", "enable_pre_dispatch"]
no_python_dispatcher = torch._C._DisablePythonDispatcher


@ -1,4 +1,5 @@
import torch
from . import convert_frame, eval_frame, resume_execution
from .backends.registry import list_backends, lookup_backend, register_backend
from .callback import callback_handler, on_compile_end, on_compile_start
@ -32,6 +33,7 @@ from .external_utils import is_compiling
from .mutation_guard import GenerationTracker
from .utils import graph_break_reasons, guard_failures, orig_code_map, reset_frame_count
__all__ = [
"allow_in_graph",
"assume_constant_result",


@ -2,11 +2,9 @@
import torch
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._subclasses import FakeTensorMode
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree
from torch.utils._python_dispatch import _get_current_dispatch_mode
from torch.utils._pytree import tree_map_only


@ -11,6 +11,7 @@ from torch._dynamo.utils import counters, defake, flatten_graph_inputs
from torch._functorch.aot_autograd import aot_module_simplified
from torch.utils._python_dispatch import _disable_current_modes
log = logging.getLogger(__name__)
@ -53,7 +54,6 @@ class AotAutograd:
)
from functorch.compile import nop
from torch._inductor.debug import enable_aot_logging
# debug asserts slow down compile time noticeably,


@ -23,8 +23,8 @@ from torch._inductor.utils import (
num_fw_fixed_arguments,
output_node,
)
from torch.multiprocessing.reductions import StorageWeakRef
from .registry import register_backend


@ -6,14 +6,15 @@ from importlib import import_module
from typing import Any, List, Optional
import torch
from functorch.compile import min_cut_rematerialization_partition
from torch import _guards
from torch._functorch import config as functorch_config
from torch._functorch.compilers import ts_compile
from .common import aot_autograd
from .registry import register_debug_backend as register_backend
"""
This file contains TorchDynamo backends intended for debugging uses.
"""


@ -13,6 +13,7 @@ from torch._dynamo.utils import deepcopy_to_fake_tensor, detect_fake_mode
from torch._logging import trace_structured
from torch.fx.node import Node
# Regular log messages should go through 'log'.
# ddp_graph_log is a separate artifact logger reserved for dumping graphs.
# See docs/source/logging.rst for more info.


@ -8,6 +8,7 @@ from torch.onnx._internal.onnxruntime import (
is_onnxrt_backend_supported,
torch_compile_backend,
)
from .registry import register_backend


@ -7,6 +7,7 @@ from functorch.compile import make_boxed_func
from ..backends.common import aot_autograd
from .registry import register_backend, register_experimental_backend
log = logging.getLogger(__name__)


@ -10,10 +10,11 @@ from types import MappingProxyType
from typing import Optional
import torch
from .common import device_from_inputs, fake_tensor_unsupported
from .common import device_from_inputs, fake_tensor_unsupported
from .registry import register_backend
log = logging.getLogger(__name__)


@ -5,6 +5,7 @@ import dis
import sys
from typing import Any, Set, Union
TERMINAL_OPCODES = {
dis.opmap["RETURN_VALUE"],
dis.opmap["JUMP_FORWARD"],


@ -9,6 +9,7 @@ from torch._guards import CompileId
from . import config
log = logging.getLogger(__name__)
"""
[Note on cache size limit]


@ -7,8 +7,8 @@ import types
from typing import Counter, Dict, List, Optional
import torch.nn
from . import utils
from . import utils
from .bytecode_transformation import (
add_push_null,
create_call_function,


@ -29,9 +29,11 @@ from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
from torch.fx.traceback import preserve_node_meta, set_stack_trace
from torch.utils._traceback import CapturedTraceback
if TYPE_CHECKING:
from torch.fx.proxy import Proxy
compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")


@ -472,4 +472,5 @@ if TYPE_CHECKING:
from torch.utils._config_module import install_config_module
install_config_module(sys.modules[__name__])


@ -23,23 +23,17 @@ from typing import Any, Callable, Dict, List, Optional, Set, TypeVar, Union
from typing_extensions import ParamSpec
from weakref import ReferenceType
from torch._utils_internal import maybe_upload_prof_stats_to_manifold
from torch.fx._lazy_graph_module import _use_lazy_graph_module
from torch.utils._traceback import CapturedTraceback
np: Optional[ModuleType]
try:
import numpy as np
except ModuleNotFoundError:
np = None
import torch
import torch._logging
from torch._dynamo.distributed import get_compile_pg
from torch._guards import compile_context, CompileContext, CompileId, tracing
from torch._logging import structured
from torch._utils_internal import compile_time_strobelight_meta, signpost_event
from torch._utils_internal import (
compile_time_strobelight_meta,
maybe_upload_prof_stats_to_manifold,
signpost_event,
)
from torch.fx._lazy_graph_module import _use_lazy_graph_module
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
GuardOnDataDependentSymNode,
@ -47,7 +41,7 @@ from torch.fx.experimental.symbolic_shapes import (
from torch.fx.graph_module import _forward_from_src as original_forward_from_src
from torch.nn.parallel.distributed import DistributedDataParallel
from torch.utils._python_dispatch import _disable_current_modes
from torch.utils._traceback import format_traceback_short
from torch.utils._traceback import CapturedTraceback, format_traceback_short
from . import config, exc, trace_rules
from .bytecode_analysis import remove_dead_code, remove_pointless_jumps
@ -109,12 +103,21 @@ from .utils import (
write_record_to_file,
)
np: Optional[ModuleType]
try:
import numpy as np
except ModuleNotFoundError:
np = None
if typing.TYPE_CHECKING:
from .backends.registry import CompilerFn
from .repro.after_dynamo import WrapBackendDebug
from .types import BytecodeHook, CacheEntry
from .variables.builder import FrameStateSizeEntry
log = logging.getLogger(__name__)
bytecode_log = torch._logging.getArtifactLogger(__name__, "bytecode")
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
@ -536,6 +539,7 @@ from collections import OrderedDict
from torch.utils.hooks import RemovableHandle
if typing.TYPE_CHECKING:
from .output_graph import OutputGraph


@ -4,6 +4,7 @@ from contextlib import contextmanager
import torch
doc = """
This is used when dynamo traces torch.nn.Parameter, which normally would not trace properly
with AOTAutograd. We instead create a placeholder torch.nn.Parameter before the graph, which


@ -2,6 +2,7 @@
import contextlib
import threading
# Global variable to identify which SubgraphTracer we are in.
# It is sometimes difficult to find an InstructionTranslator to use.
_current_scope_id = threading.local()


@ -20,7 +20,6 @@ import torch
import torch._prims_common as utils
import torch._subclasses.meta_utils
from torch import Tensor
from torch._dynamo.testing import rand_strided
from torch._prims_common import is_float_dtype
from torch.multiprocessing.reductions import StorageWeakRef
@ -29,6 +28,7 @@ from torch.utils._content_store import ContentStoreReader, ContentStoreWriter
from . import config
from .utils import clone_inputs, get_debug_dir
log = logging.getLogger(__name__)
T = TypeVar("T")


@ -5,12 +5,14 @@ from typing import TYPE_CHECKING
import torch
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from . import trace_rules, variables
from .comptime import comptime
from .eval_frame import DisableContext, innermost_fn, RunOnlyContext
from .exc import IncorrectUsage
from .external_utils import is_compiling
if TYPE_CHECKING:
from torch._C._dynamo.eval_frame import ( # noqa: F401
reset_code,


@ -5,6 +5,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union
import torch
from torch._streambase import _EventBase, _StreamBase
get_cuda_stream: Optional[Callable[[int], int]]
if torch.cuda._is_compiled():
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream


@ -1,8 +1,10 @@
from typing import Optional
import torch.distributed as dist
from . import config
_COMPILE_PG: Optional[dist.ProcessGroup] = None


@ -37,13 +37,17 @@ from typing import (
)
from unittest.mock import patch
import sympy
import torch
import torch.fx
import torch.utils._pytree as pytree
import torch.utils.checkpoint
from torch import _guards
from torch._dispatch.python import enable_python_dispatcher
from torch._utils_internal import justknobs_check, log_export_usage
from torch.export.dynamic_shapes import _process_dynamic_shapes
from torch.fx import GraphModule
from torch.fx.experimental.proxy_tensor import make_fx, maybe_disable_fake_tensor_mode
from torch.fx.experimental.symbolic_shapes import (
ConstraintViolationError,
@ -53,10 +57,20 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.fx.graph import _PyTreeCodeGen, _PyTreeInfo
from ..fx import GraphModule
from . import config, convert_frame, external_utils, trace_rules, utils
from .backends.registry import CompilerFn, lookup_backend
from .code_context import code_context
from .exc import CondOpArgsMismatchError, UserError, UserErrorType
from .hooks import Hooks
from .mutation_guard import install_generation_tagging_init
from .utils import common_constant_types, compile_times
if TYPE_CHECKING:
from torch._subclasses import fake_tensor
from .types import CacheEntry, DynamoCallback
# see discussion at https://github.com/pytorch/pytorch/issues/120699
reset_code = torch._C._dynamo.eval_frame.reset_code # noqa: F401
@ -65,27 +79,14 @@ set_guard_error_hook = torch._C._dynamo.eval_frame.set_guard_error_hook # noqa:
skip_code = torch._C._dynamo.eval_frame.skip_code # noqa: F401
unsupported = torch._C._dynamo.eval_frame.unsupported # noqa: F401
from . import config, convert_frame, external_utils, trace_rules, utils
from .code_context import code_context
from .exc import CondOpArgsMismatchError, UserError, UserErrorType
from .mutation_guard import install_generation_tagging_init
from .utils import common_constant_types, compile_times
log = logging.getLogger(__name__)
from torch._dispatch.python import enable_python_dispatcher
always_optimize_code_objects = utils.ExactWeakKeyDictionary()
null_context = contextlib.nullcontext
import sympy
if TYPE_CHECKING:
from torch._subclasses import fake_tensor
from .types import CacheEntry, DynamoCallback
# See https://github.com/python/typing/pull/240
class Unset(Enum):
token = 0
@ -1627,7 +1628,7 @@ class TorchPatcher:
)
torch.distributions.Distribution.set_default_validate_args(False)
from ..optim import (
from torch.optim import (
adadelta,
adagrad,
adam,


@ -8,9 +8,9 @@ from typing import Any, cast, NoReturn, Optional, Tuple, TYPE_CHECKING
import torch._guards
from . import config
from .utils import counters
if TYPE_CHECKING:
from torch._guards import CompileId
@ -25,6 +25,7 @@ def exportdb_error_message(case_name):
import logging
log = logging.getLogger(__name__)
graph_breaks_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")


@ -7,6 +7,7 @@ from typing import List
import torch
import torch.utils._pytree as pytree
try:
import numpy as np
except ModuleNotFoundError:


@ -1,7 +1,7 @@
import tokenize
from typing import Dict, List, Optional
cache: Dict[str, Dict[int, str]] = {}


@ -35,12 +35,6 @@ from typing import (
)
from weakref import ReferenceType
try:
import numpy as np
except ModuleNotFoundError:
np = None # type: ignore[assignment]
import torch
import torch.utils._device
from torch._dynamo.source import (
@ -60,7 +54,6 @@ from torch._guards import (
GuardSource,
Source,
)
from torch._logging import structured
from torch.fx.experimental.symbolic_shapes import (
EqualityConstraint,
@ -72,7 +65,6 @@ from torch.utils.weak import TensorWeakRef
from . import config, convert_frame, exc, mutation_guard
from .eval_frame import set_guard_error_hook
from .source import (
AttrSource,
ChainedSource,
@ -115,9 +107,17 @@ from .utils import (
verify_guard_fn_signature,
)
try:
import numpy as np
except ModuleNotFoundError:
np = None # type: ignore[assignment]
if TYPE_CHECKING:
from sympy import Symbol
log = logging.getLogger(__name__)
guards_log = torch._logging.getArtifactLogger(__name__, "guards")
recompiles_log = torch._logging.getArtifactLogger(__name__, "recompiles")
@ -310,7 +310,6 @@ if sys.version_info[:2] <= (3, 8):
HAS_UNPARSE_FUNCTIONS = True
except ImportError:
HAS_UNPARSE_FUNCTIONS = False
pass
else:
HAS_UNPARSE_FUNCTIONS = True


@ -1,8 +1,8 @@
import dataclasses
from typing import Callable, Optional
from torch._guards import GuardsSet
from .types import GuardFail


@ -4,6 +4,7 @@ import logging
from torch.hub import _Faketqdm, tqdm
# Disable progress bar by default, not in dynamo config because otherwise get a circular import
disable_progress = True


@ -6,8 +6,8 @@ import weakref
import torch.nn
from torch.nn import Module
from . import config
from . import config
from .utils import ExactWeakKeyDictionary, is_lazy_module, nn_module_has_global_hooks


@ -18,10 +18,8 @@ from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TYPE_CHECKIN
import sympy
import torch._guards
import torch._logging
import torch.distributed as dist
import torch.nn
import torch.utils._pytree as pytree
from torch import fx
@ -104,12 +102,13 @@ from .variables.tensor import (
TensorVariable,
UnspecializedPythonVariable,
)
from .variables.torch_function import TensorWithTFOverrideVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslatorBase
log = logging.getLogger(__name__)
graph_tabular_log = torch._logging.getArtifactLogger(__name__, "graph")
graph_code_log = torch._logging.getArtifactLogger(__name__, "graph_code")


@ -6,6 +6,7 @@ from typing import Any, Dict
from torch.utils._import_utils import import_dill
dill = import_dill()


@ -44,6 +44,7 @@ from torch.hub import tqdm
from .. import config
log = logging.getLogger(__name__)


@ -12,7 +12,6 @@ from typing import Union
import torch
import torch.fx as fx
from torch._dynamo.debug_utils import (
AccuracyError,
backend_accuracy_fails,
@ -36,6 +35,7 @@ from .. import config
from ..backends.registry import lookup_backend, register_debug_backend
from ..debug_utils import clone_inputs_retaining_gradness
log = logging.getLogger(__name__)


@ -19,6 +19,7 @@ from .bytecode_transformation import (
)
from .utils import ExactWeakKeyDictionary
# taken from code.h in cpython
CO_OPTIMIZED = 0x0001
CO_NEWLOCALS = 0x0002


@ -10,6 +10,7 @@ from . import utils
from .bytecode_transformation import create_call_function, create_instruction
from .utils import enum_repr
# It shouldn't be supported to construct an NNModuleVariable inside an FSDP module,
# so those cases are omitted intentionally
_GUARD_SOURCE_NN_MODULE = {


@ -110,6 +110,7 @@ from .variables.user_defined import (
UserDefinedObjectVariable,
)
log = logging.getLogger(__name__)
graph_break_log = torch._logging.getArtifactLogger(__name__, "graph_breaks")
trace_call_log = torch._logging.getArtifactLogger(__name__, "trace_call")


@ -4,6 +4,7 @@ from torch._prims import _make_prim, RETURN_TYPE
from torch._subclasses import FakeTensorMode
from torch._subclasses.functional_tensor import FunctionalTensorMode
_tensor_version = _make_prim(
schema="_tensor_version(Tensor self) -> SymInt",
return_type=RETURN_TYPE.NEW,


@ -14,6 +14,7 @@ from torch.testing._internal.common_utils import ( # type: ignore[attr-defined]
from . import config, reset, utils
log = logging.getLogger(__name__)


@ -12,12 +12,6 @@ import unittest
from typing import List, Optional, Sequence, Union
from unittest.mock import patch
np: Optional[types.ModuleType] = None
try:
import numpy as np
except ModuleNotFoundError:
np = None
import torch
from torch import fx
from torch._dynamo.output_graph import OutputGraph
@ -32,6 +26,14 @@ from .bytecode_transformation import (
from .guards import CheckFunctionManager, CompileId, GuardedCode
from .utils import same
np: Optional[types.ModuleType] = None
try:
import numpy as np
except ModuleNotFoundError:
np = None
unsupported = eval_frame.unsupported
three = 3


@ -35,20 +35,14 @@ import weakref
from collections import defaultdict
from typing import Any, Callable, cast, Dict, List, Optional, Set, Union
np: Optional[types.ModuleType] = None
try:
import numpy as np
except ModuleNotFoundError:
pass
import torch
import torch._inductor.test_operators
import torch.distributed
import torch.utils._content_store
from ..utils import _config_module
from torch.utils import _config_module
from .resume_execution import TORCH_DYNAMO_RESUME_IN_PREFIX
from .utils import getfile, hashable, NP_SUPPORTED_MODULES, unwrap_if_wrapper
from .variables import (
BuiltinVariable,
FunctionalCallVariable,
@ -61,6 +55,13 @@ from .variables import (
)
np: Optional[types.ModuleType] = None
try:
import numpy as np
except ModuleNotFoundError:
pass
if typing.TYPE_CHECKING:
from .variables.base import VariableTracker


@ -4,6 +4,9 @@ import types
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Protocol, Union
from typing_extensions import TypeAlias
import torch
from torch._guards import CompileId
if sys.version_info >= (3, 11):
from torch._C._dynamo import eval_frame
@ -12,8 +15,6 @@ if sys.version_info >= (3, 11):
else:
DynamoFrameType: TypeAlias = types.FrameType
import torch
from torch._guards import CompileId
# This class has a `check_fn` field for the guard,
# and a `code` field for the code object.


@ -9,6 +9,7 @@ import dis
import enum
import functools
import gc
import importlib
import inspect
import itertools
import linecache
@ -52,10 +53,21 @@ from typing import (
)
from typing_extensions import Literal, ParamSpec, TypeGuard
from ..utils.hooks import RemovableHandle
import torch
import torch._functorch.config
import torch._inductor.config as inductor_config
import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._triton import has_triton, has_triton_package
from torch.utils.hooks import RemovableHandle
T = TypeVar("T")
_P = ParamSpec("_P")
try:
import numpy as np
@ -67,6 +79,7 @@ try:
import torch._numpy as tnp
from torch._guards import detect_fake_mode # noqa: F401
from torch._logging import LazyString
from . import config
# NOTE: Make sure `NP_SUPPORTED_MODULES` and `NP_TO_TNP_MODULE` are in sync.
@ -92,23 +105,9 @@ try:
except ImportError:
pass
import importlib
import torch
import torch._functorch.config
import torch._inductor.config as inductor_config
import torch.fx.experimental.symbolic_shapes
import torch.utils._pytree as pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._guards import TracingContext
from torch._subclasses.meta_utils import is_sparse_compressed
from torch._utils_internal import log_compilation_event
from torch.fx._utils import _format_graph_code, lazy_format_graph_code
from torch.nn.modules.lazy import LazyModuleMixin
from torch.utils._triton import has_triton, has_triton_package
T = TypeVar("T")
_P = ParamSpec("_P")
unpatched_nn_module_getattr = torch.nn.Module.__getattr__
@ -1850,6 +1849,7 @@ def get_fake_value(node, tx, allow_non_graph_fake=False):
by further wrapping them as this graph's fakes.
"""
from torch.utils._sympy.value_ranges import ValueRangeError
from .exc import (
TorchRuntimeError,
unimplemented,
@ -2403,6 +2403,7 @@ def is_utils_checkpoint(obj):
def build_checkpoint_variable(**options):
import torch._higher_order_ops.wrap as higher_order_ops
from .variables.higher_order_ops import TorchHigherOrderOperatorVariable
# TODO - This is a temporary situation where we have two versions of


@ -82,7 +82,6 @@ from .misc import (
UnknownVariable,
)
from .nn_module import NNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
@ -101,6 +100,7 @@ from .user_defined import (
WeakRefVariable,
)
__all__ = [
"AutogradFunctionContextVariable",
"AutogradFunctionVariable",


@ -4,9 +4,6 @@ import collections
from enum import Enum
from typing import Any, Callable, Dict, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .. import variables
from ..current_scope_id import current_scope_id
from ..exc import unimplemented
@ -14,6 +11,10 @@ from ..source import AttrSource, Source
from ..utils import istype
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
class MutableLocalSource(Enum):
"""
If the VariableTracker.mutable_local represents a Variable that:


@ -17,20 +17,7 @@ import types
import weakref
from typing import Any, List, NamedTuple, Optional, TYPE_CHECKING, Union
from torch._utils_internal import justknobs_check
from torch.utils._sympy.value_ranges import ValueRanges
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
try:
import numpy as np
except ModuleNotFoundError:
np = None
import torch
from torch import SymInt
from torch._guards import GuardSource, TracingContext
from torch._higher_order_ops.torchbind import call_torchbind
@ -38,6 +25,7 @@ from torch._ops import HigherOrderOperator
from torch._streambase import _EventBase, _StreamBase
from torch._subclasses.fake_tensor import FakeTensor, is_fake, maybe_get_fake_mode
from torch._subclasses.meta_utils import is_sparse_any
from torch._utils_internal import justknobs_check
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
@ -49,10 +37,10 @@ from torch.fx.experimental.symbolic_shapes import (
)
from torch.fx.immutable_collections import immutable_dict, immutable_list
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from torch.utils._sympy.value_ranges import ValueRanges
from torch.utils.weak import TensorWeakRef
from .. import config, mutation_guard, replay_record, trace_rules
from ..device_interface import get_registered_device_interfaces
from ..exc import InternalTorchDynamoError, unimplemented
from ..guards import GuardBuilder, install_guard, make_dupe_guard
@ -109,7 +97,6 @@ from ..utils import (
unwrap_with_attr_name_if_wrapper,
wrap_fake_exception,
)
from .base import MutableLocal, typestr, VariableTracker, VariableTrackerMeta
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
@ -182,7 +169,6 @@ from .misc import (
from .nn_module import FSDPManagedNNModuleVariable, UnspecializedNNModuleVariable
from .optimizer import OptimizerVariable
from .script_object import TorchScriptObjectVariable
from .sdpa import SDPAParamsVariable
from .tensor import (
NumpyNdarrayVariable,
@ -202,6 +188,16 @@ from .user_defined import (
)
try:
import numpy as np
except ModuleNotFoundError:
np = None
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
log = logging.getLogger(__name__)


@ -14,9 +14,6 @@ from typing import Dict, List, TYPE_CHECKING
import torch
from torch import sym_float, sym_int
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config, polyfill, variables
@ -70,6 +67,11 @@ from .tensor import (
)
from .user_defined import UserDefinedObjectVariable, UserDefinedVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
log = logging.getLogger(__name__)


@ -6,15 +6,17 @@ from typing import Dict, List, TYPE_CHECKING
import torch
from torch._dynamo.source import GetItemSource
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .. import variables
from ..exc import unimplemented, UserError, UserErrorType
from ..guards import GuardBuilder, install_guard
from ..utils import common_constant_types, istype, np
from .base import typestr, VariableTracker
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
_type_to_assert_reason = {
# NB - We CAN have ConstantVariable.create(set) because of how sets interact with guards.
# A locally created set should always become a SetVariable, as the items in the set will already either be sourced


@ -6,9 +6,6 @@ import warnings
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union
import torch._C
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._guards import Guard
from .. import variables
@ -32,6 +29,10 @@ from .functions import (
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
@dataclasses.dataclass
class ContextMangerState:
"""


@ -7,9 +7,6 @@ import inspect
import sys
from typing import Dict, List, Optional, TYPE_CHECKING
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._subclasses.fake_tensor import is_fake
from .. import polyfill, variables
@ -22,6 +19,11 @@ from ..utils import dict_keys, dict_values, istype, specialize_symnode
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
# [Adding a new supported class within the keys of ConstDictVarialble]
# - Add its tracker type to is_hashable
# - (perhaps) Define how it is compared in _HashableTracker._eq_impl


@ -4,10 +4,8 @@ import inspect
from typing import Dict, List, TYPE_CHECKING
import torch
from torch.fx.experimental._backward_state import BackwardState
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from ...fx.experimental._backward_state import BackwardState
from .. import compiled_autograd, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented
@ -19,6 +17,10 @@ from .base import VariableTracker
from .constant import ConstantVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
class DistributedVariable(VariableTracker):
"""
The base distributed variable that encapsulates common methods


@ -9,9 +9,6 @@ from typing import Dict, List, Optional, TYPE_CHECKING, Union
import torch
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .. import polyfill, variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
@ -27,8 +24,6 @@ from ..utils import (
from .base import MutableLocal, typestr, VariableTracker
from .constant import ConstantVariable
if TYPE_CHECKING:
from torch._guards import Source
try:
from torch.distributed._composable.fsdp import _fsdp_param_group
@ -36,6 +31,11 @@ except ModuleNotFoundError:
_fsdp_param_group = None
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._guards import Source
def wrap_bound_arg(tx: "InstructionTranslator", val, source=None):
# Source propagation is best effort since not every object we encounter has a source to begin with.
if isinstance(val, VariableTracker):


@ -6,16 +6,12 @@ import inspect
import itertools
import logging
import types
from typing import Dict, List, Optional, TYPE_CHECKING
import torch._C
import torch.fx
import torch.nn
import torch.onnx.operators
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._dynamo.utils import get_fake_value
from torch._dynamo.variables import ConstantVariable
from torch._dynamo.variables.base import VariableTracker
@ -26,8 +22,8 @@ from torch._guards import Source
from torch._ops import HigherOrderOperator
from torch.fx.passes.shape_prop import _extract_tensor_metadata
from torch.utils import _pytree as pytree
from .. import variables
from .. import variables
from ..exc import UncapturedHigherOrderOpError, unimplemented, Unsupported
from ..source import AttrSource
from ..utils import proxy_args_kwargs
@ -36,6 +32,10 @@ from .lazy import LazyVariableTracker
from .lists import ListVariable, TupleVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
log = logging.getLogger(__name__)
@ -1467,6 +1467,7 @@ class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
) -> VariableTracker:
from torch._higher_order_ops.wrap import TagActivationCheckpoint
from torch.utils.checkpoint import noop_context_fn
from .builder import wrap_fx_proxy
context_fn = None
@ -1609,6 +1610,7 @@ class FlexAttentionHigherOrderVariable(TorchHigherOrderOperatorVariable):
fn_name: str,
):
from torch._higher_order_ops.flex_attention import TransformGetItemToIndex
from .builder import SourcelessBuilder
tx: InstructionTranslator = tx


@ -1,16 +1,10 @@
# mypy: ignore-errors
MAX_CYCLE = 3000
import itertools
import operator
import sys
from typing import Dict, List, Optional, TYPE_CHECKING, Union
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .. import polyfill, variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..exc import (
@ -20,11 +14,17 @@ from ..exc import (
unimplemented,
UserError,
)
from .base import MutableLocal, VariableTracker
from .constant import ConstantVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
MAX_CYCLE = 3000
class ItertoolsVariable(VariableTracker):
def __init__(self, value, **kwargs):
super().__init__(**kwargs)


@ -9,11 +9,7 @@ from typing import Dict, List, Optional, TYPE_CHECKING
import torch
import torch.fx
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from ..._guards import Source
from torch._guards import Source
from .. import polyfill, variables
from ..bytecode_transformation import create_call_function, create_instruction
@ -36,6 +32,10 @@ from .functions import UserFunctionVariable, UserMethodVariable
from .iter import IteratorVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
class BaseListVariable(VariableTracker):
@staticmethod
def cls_for_instance(obj):


@ -13,8 +13,6 @@ import torch._C
import torch._numpy as tnp
import torch.utils._pytree as pytree
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .. import config, variables
from ..bytecode_transformation import (
add_push_null_call_function_ex,
@ -38,6 +36,10 @@ from .functions import NestedUserFunctionVariable, UserFunctionVariable
from .user_defined import is_standard_setattr, UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
class SuperVariable(VariableTracker):
_nonvar_fields = {
"specialized",


@ -9,9 +9,6 @@ from typing import Any, Dict, List, TYPE_CHECKING
import torch.nn
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .. import trace_rules, variables
from ..exc import (
ObservedException,
@ -49,6 +46,10 @@ from .lists import SliceVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
def initialize_lazy_module(tx: "InstructionTranslator", mod, args, kwargs):
"""
Fairly coupled helper used by NNModuleVariable and UnspecializedNNModuleVariable.


@ -4,9 +4,6 @@ import weakref
from typing import Dict, List, TYPE_CHECKING
import torch
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch.utils._pytree import tree_map_only
from ..guards import GuardBuilder, install_guard
@ -18,14 +15,16 @@ from ..source import (
GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .base import VariableTracker


@ -4,8 +4,8 @@ import functools
from typing import Dict
import torch
from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
from ..exc import unimplemented, UnsafeScriptObjectError, Unsupported
from .base import VariableTracker
from .user_defined import UserDefinedObjectVariable
@ -49,6 +49,7 @@ class TorchScriptObjectVariable(UserDefinedObjectVariable):
)
def var_getattr(self, tx, name: str) -> VariableTracker:
from torch._higher_order_ops.torchbind import call_torchbind
from ..source import AttrSource
from .higher_order_ops import TorchHigherOrderOperatorVariable


@ -1,17 +1,17 @@
# mypy: ignore-errors
from inspect import getattr_static
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
class SDPAParamsVariable(VariableTracker):
"""Represents the c++ params struct for scaled dot product attention.
This is a read-only container."""
@ -19,6 +19,7 @@ class SDPAParamsVariable(VariableTracker):
@staticmethod
def create(tx: "InstructionTranslator", value, source):
from torch.backends.cuda import SDPAParams
from ..source import AttrSource
from .builder import VariableBuilder
from .torch import TorchInGraphFunctionVariable
@ -64,6 +65,7 @@ class SDPAParamsVariable(VariableTracker):
def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
import torch._C
from ..source import AttrSource
from .builder import wrap_fx_proxy
from .misc import GetAttrVariable


@ -15,9 +15,6 @@ import torch._numpy as tnp
import torch.fx
import torch.random
from torch._dynamo import compiled_autograd
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._subclasses.meta_utils import is_sparse_any
from torch.fx.experimental.symbolic_shapes import (
guard_scalar,
@ -27,6 +24,7 @@ from torch.fx.experimental.symbolic_shapes import (
SymTypes,
)
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config, variables
from .._trace_wrapped_higher_order_op import trace_wrapped
from ..exc import unimplemented, UserError, UserErrorType
@ -49,11 +47,17 @@ from .base import VariableTracker
from .constant import ConstantVariable
from .lists import SizeVariable
try:
import numpy as np
except ModuleNotFoundError:
np = None
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
log = logging.getLogger(__name__)
# Ops that allow tensor <op> tensor


@ -3,7 +3,6 @@
import functools
import inspect
import logging
import math
import re
from typing import Dict, List, TYPE_CHECKING
@ -13,13 +12,10 @@ import torch._refs
import torch.fx
import torch.nn
import torch.onnx.operators
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch._guards import TracingContext
from torch._logging import warning_once
from torch._streambase import _StreamBase
from ..._guards import TracingContext
from .. import config, polyfill, variables
from ..codegen import PyCodegen
from ..create_parameter_op import (
@ -50,6 +46,7 @@ from .distributed import DistributedVariable, ProcessGroupVariable
from .lists import ListVariable, TupleVariable
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
try:
import numpy as np
except ModuleNotFoundError:
@ -60,6 +57,11 @@ try:
except ModuleNotFoundError:
_fsdp_param_group = None # type: ignore[assignment]
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
log = logging.getLogger(__name__)
supported_ctx_manager_classes = dict.fromkeys(
@ -353,6 +355,7 @@ class TorchInGraphFunctionVariable(BaseTorchVariable):
return _register
from torch.backends.cuda import SDPAParams
from . import (
ConstantVariable,
DeterministicAlgorithmsVariable,


@ -4,11 +4,8 @@ import inspect
from typing import Dict, List, TYPE_CHECKING
import torch.utils._pytree as pytree
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from torch.overrides import _get_overloaded_args, get_default_nowrap_functions
from ..exc import unimplemented
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, GlobalSource, TypeSource
@ -18,7 +15,10 @@ from .lists import TupleVariable
from .tensor import TensorSubclassVariable, TensorVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from .base import VariableTracker
@ -223,6 +223,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
# [Note: __torch_function__] We currently only support attributes that are defined on
# base tensors, custom attribute accesses will graph break.
import torch
from .builder import SourcelessBuilder
if name in banned_attrs:
@ -277,6 +278,7 @@ class TensorWithTFOverrideVariable(TensorVariable):
# of `call_method`.
if tx.output.torch_function_enabled:
import torch
from .builder import SourcelessBuilder, VariableBuilder
if _is_attr_overidden(tx, self, name):


@ -11,30 +11,14 @@ import sys
import threading
import types
import warnings
from typing import Dict, Generic, List, TYPE_CHECKING
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from ..bytecode_transformation import create_call_function
try:
import numpy as np
except ModuleNotFoundError:
np = None
try:
from torch.utils._cxx_pytree import PyTreeSpec
except ImportError:
PyTreeSpec = type(None)
import torch._dynamo.config
import torch.nn
from torch._guards import TracingContext
from .. import variables
from ..bytecode_transformation import create_call_function
from ..create_parameter_op import do_not_convert_to_tracable_parameter
from ..exc import ObservedException, unimplemented
from ..guards import GuardBuilder, install_guard
@ -64,6 +48,21 @@ from .base import MutableLocal, VariableTracker
from .dicts import DefaultDictVariable
try:
import numpy as np
except ModuleNotFoundError:
np = None
try:
from torch.utils._cxx_pytree import PyTreeSpec
except ImportError:
PyTreeSpec = type(None)
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
def is_standard_setattr(val):
return val in (object.__setattr__,)