[Dynamo][Better Engineering] Add typing annotations to guard and source (#158397)

As part of better engineering week, we would like to improve our type support to improve the dev experience in dynamo

This PR adds strict typing support to a critical set of files for dynamo, `source.py` and the base `_guards.py`

Running
```
mypy torch/_dynamo/source.py torch/_guards.py --linecount-report /tmp/coverage_log
```

| -------- | Lines Annotated | Lines Total | % lines covered | Funcs Annotated | Funcs Total | % funcs covered |
| -------- | ------- | -------- | ------- | ------- | ------- | ------- |
| Main  |  1227 | 2208 | 55.57% | 207 | 362 | 57.18% |
| This PR | 2217 | 2217 | 100.00% | 362 | 362 | 100.00% |
| Delta    | +990 | +9 | +44.43% | +155 | 0 | +42.82% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158397
Approved by: https://github.com/anijain2305
This commit is contained in:
Lucas Kabela 2025-07-24 15:55:18 +00:00 committed by PyTorch MergeBot
parent fd48681b6a
commit abcb24f4de
22 changed files with 335 additions and 278 deletions

View File

@ -1846,7 +1846,7 @@ def export(
ignore_fresh_unbacked = null_context() ignore_fresh_unbacked = null_context()
assert ambient_fake_mode is not None assert ambient_fake_mode is not None
if shape_env := ambient_fake_mode.shape_env: if shape_env := ambient_fake_mode.shape_env:
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() # type: ignore[assignment]
with ( with (
ambient_fake_mode, ambient_fake_mode,
@ -1898,7 +1898,9 @@ def export(
fakify_with_ambient, graph_inputs fakify_with_ambient, graph_inputs
) )
graph_captured_result = torch.func.functional_call( graph_captured_result = torch.func.functional_call(
graph, fake_params_buffers, fake_graph_inputs graph,
fake_params_buffers, # type: ignore[arg-type]
fake_graph_inputs, # type: ignore[arg-type]
) )
return graph_captured_result return graph_captured_result

View File

@ -3481,7 +3481,7 @@ def strip_local_scope(s: str) -> str:
def get_guard_fail_reason_helper( def get_guard_fail_reason_helper(
guard_manager: GuardFn, guard_manager: GuardFn,
f_locals: dict[str, object], f_locals: dict[str, object],
compile_id: CompileId, compile_id: Optional[CompileId],
) -> str: ) -> str:
""" """
Return the reason why `guard_manager` failed. Return the reason why `guard_manager` failed.

View File

@ -809,6 +809,7 @@ class OutputGraph(OutputGraphGuardsState):
@property @property
def shape_env(self): def shape_env(self):
assert self.tracing_context.fake_mode is not None
return self.tracing_context.fake_mode.shape_env return self.tracing_context.fake_mode.shape_env
@property @property
@ -1691,6 +1692,7 @@ class OutputGraph(OutputGraphGuardsState):
) )
self.call_cleanup_hooks() self.call_cleanup_hooks()
old_fake_mode = self.tracing_context.fake_mode old_fake_mode = self.tracing_context.fake_mode
assert old_fake_mode is not None
if not self.export: if not self.export:
import torch._functorch.config as _config import torch._functorch.config as _config
@ -1738,6 +1740,7 @@ class OutputGraph(OutputGraphGuardsState):
) )
counters["stats"]["unique_graphs"] += 1 counters["stats"]["unique_graphs"] += 1
assert old_fake_mode.shape_env is not None
if specializations := old_fake_mode.shape_env.specializations: if specializations := old_fake_mode.shape_env.specializations:
specialization_guards = [] specialization_guards = []
specialization_cache: dict[Specialization, Callable[[Any], Any]] = {} specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}

View File

@ -1,5 +1,3 @@
# mypy: allow-untyped-defs
""" """
This module provides Source classes that track the origins of values in PyTorch Dynamo. This module provides Source classes that track the origins of values in PyTorch Dynamo.
Sources represent where values come from (e.g. local variables, globals, attributes) and Sources represent where values come from (e.g. local variables, globals, attributes) and
@ -22,9 +20,9 @@ the code needed to recreate values.
import dataclasses import dataclasses
import enum import enum
import functools import functools
from typing import Any, Optional, TYPE_CHECKING, Union from typing import Any, Callable, Optional, TYPE_CHECKING, Union
from torch._guards import ChainedSource, GuardSource, Source from torch._guards import ChainedSource, Guard, GuardSource, Source
from . import utils from . import utils
from .bytecode_transformation import create_call_function, create_instruction from .bytecode_transformation import create_call_function, create_instruction
@ -96,7 +94,7 @@ _GUARD_SOURCE_FSDP_MODULE = {
} }
def is_constant_source(source): def is_constant_source(source: Source) -> bool:
if isinstance(source, ConstantSource): if isinstance(source, ConstantSource):
return True return True
try: try:
@ -124,16 +122,16 @@ class LocalSource(Source):
# or `co_freevars`. # or `co_freevars`.
is_derefed_cell_contents: bool = False is_derefed_cell_contents: bool = False
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
if self.is_derefed_cell_contents: if self.is_derefed_cell_contents:
codegen.load_deref(self.local_name) codegen.load_deref(self.local_name)
else: else:
codegen.append_output(codegen.create_load(self.local_name)) codegen.append_output(codegen.create_load(self.local_name))
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.LOCAL return GuardSource.LOCAL
def name(self): def name(self) -> str:
return f"L[{repr(self.local_name)}]" return f"L[{repr(self.local_name)}]"
@ -141,13 +139,13 @@ class LocalSource(Source):
class SyntheticLocalSource(Source): class SyntheticLocalSource(Source):
local_name: str local_name: str
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.append_output(codegen.create_load(self.local_name)) codegen.append_output(codegen.create_load(self.local_name))
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.SYNTHETIC_LOCAL return GuardSource.SYNTHETIC_LOCAL
def name(self): def name(self) -> str:
return f"SYNTHETIC_LOCAL[{self.local_name!r}]" return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
@ -155,15 +153,15 @@ class SyntheticLocalSource(Source):
class RandomValueSource(Source): class RandomValueSource(Source):
random_call_index: int random_call_index: int
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.RANDOM_VALUE return GuardSource.RANDOM_VALUE
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var)) codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
codegen.append_output(codegen.create_load_const(self.random_call_index)) codegen.append_output(codegen.create_load_const(self.random_call_index))
codegen.append_output(create_instruction("BINARY_SUBSCR")) codegen.append_output(create_instruction("BINARY_SUBSCR"))
def name(self): def name(self) -> str:
return f"random_value_{self.random_call_index}" return f"random_value_{self.random_call_index}"
@ -171,13 +169,13 @@ class RandomValueSource(Source):
class GlobalSource(Source): class GlobalSource(Source):
global_name: str global_name: str
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.append_output(codegen.create_load_global(self.global_name, add=True)) codegen.append_output(codegen.create_load_global(self.global_name, add=True))
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL return GuardSource.GLOBAL
def name(self): def name(self) -> str:
return f"G[{repr(self.global_name)}]" return f"G[{repr(self.global_name)}]"
@ -185,7 +183,7 @@ class GlobalSource(Source):
class GlobalWeakRefSource(Source): class GlobalWeakRefSource(Source):
global_name: str global_name: str
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null( codegen.add_push_null(
lambda: codegen.append_output( lambda: codegen.append_output(
codegen.create_load_global(self.global_name, add=True) codegen.create_load_global(self.global_name, add=True)
@ -193,23 +191,23 @@ class GlobalWeakRefSource(Source):
) )
codegen.extend_output(create_call_function(0, False)) codegen.extend_output(create_call_function(0, False))
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL return GuardSource.GLOBAL
def name(self): def name(self) -> str:
return f"G[{repr(self.global_name)}]()" return f"G[{repr(self.global_name)}]()"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class WeakRefCallSource(ChainedSource): class WeakRefCallSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(lambda: codegen(self.base)) codegen.add_push_null(lambda: codegen(self.base))
codegen.extend_output(create_call_function(0, False)) codegen.extend_output(create_call_function(0, False))
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"{self.base.name()}()" return f"{self.base.name()}()"
@ -222,7 +220,7 @@ class CallFunctionNoArgsSource(WeakRefCallSource):
class AttrSource(ChainedSource): class AttrSource(ChainedSource):
member: str member: str
def __post_init__(self): def __post_init__(self) -> None:
assert self.base, "Can't construct an AttrSource without a valid base source" assert self.base, "Can't construct an AttrSource without a valid base source"
if "." in self.member: if "." in self.member:
member_parts = self.member.split(".") member_parts = self.member.split(".")
@ -231,14 +229,14 @@ class AttrSource(ChainedSource):
) )
object.__setattr__(self, "member", member_parts[-1]) object.__setattr__(self, "member", member_parts[-1])
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
codegen.extend_output(codegen.create_load_attrs(self.member)) codegen.extend_output(codegen.create_load_attrs(self.member))
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
if not self.member.isidentifier(): if not self.member.isidentifier():
return f"getattr({self.base.name()}, {self.member!r})" return f"getattr({self.base.name()}, {self.member!r})"
return f"{self.base.name()}.{self.member}" return f"{self.base.name()}.{self.member}"
@ -248,7 +246,7 @@ class AttrSource(ChainedSource):
class GenericAttrSource(ChainedSource): class GenericAttrSource(ChainedSource):
member: str member: str
def __post_init__(self): def __post_init__(self) -> None:
assert self.base, "Can't construct an AttrSource without a valid base source" assert self.base, "Can't construct an AttrSource without a valid base source"
if "." in self.member: if "." in self.member:
member_parts = self.member.split(".") member_parts = self.member.split(".")
@ -257,14 +255,14 @@ class GenericAttrSource(ChainedSource):
) )
object.__setattr__(self, "member", member_parts[-1]) object.__setattr__(self, "member", member_parts[-1])
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
codegen.extend_output(codegen.create_load_attrs(self.member)) codegen.extend_output(codegen.create_load_attrs(self.member))
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"object.__getattribute__({self.base.name()}, {self.member!r})" return f"object.__getattribute__({self.base.name()}, {self.member!r})"
@ -277,7 +275,7 @@ class LocalCellSource(Source):
local_name: str local_name: str
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics, # Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
# Dynamo's bytecode transformation differentiates them slightly, so we # Dynamo's bytecode transformation differentiates them slightly, so we
# always emit `LOAD_CLOSURE` here. # always emit `LOAD_CLOSURE` here.
@ -295,20 +293,20 @@ class LocalCellSource(Source):
class GradSource(ChainedSource): class GradSource(ChainedSource):
member: str = "grad" member: str = "grad"
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
codegen.extend_output(codegen.create_load_attrs(self.member)) codegen.extend_output(codegen.create_load_attrs(self.member))
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"{self.base.name()}.{self.member}" return f"{self.base.name()}.{self.member}"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ParamBufferSource(AttrSource): class ParamBufferSource(AttrSource):
def guard_source(self): def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
@ -331,16 +329,16 @@ class UnspecializedParamBufferSource(AttrSource):
class EphemeralSource(Source): class EphemeralSource(Source):
desc: Optional[str] = None desc: Optional[str] = None
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.EPHEMERAL return GuardSource.EPHEMERAL
def name(self): def name(self) -> str:
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>" return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
def make_guard(self, fn): def make_guard(self, fn: Callable[..., Any]) -> Guard:
raise NotImplementedError raise NotImplementedError
def is_ephemeral(self): def is_ephemeral(self) -> bool:
return True return True
@ -349,13 +347,15 @@ class TensorProperty(enum.Enum):
STRIDE = 1 STRIDE = 1
STORAGE_OFFSET = 2 STORAGE_OFFSET = 2
def method_name(self): def method_name(self) -> str:
if self is TensorProperty.SIZE: if self is TensorProperty.SIZE:
return "size" return "size"
elif self is TensorProperty.STRIDE: elif self is TensorProperty.STRIDE:
return "stride" return "stride"
elif self is TensorProperty.STORAGE_OFFSET: elif self is TensorProperty.STORAGE_OFFSET:
return "storage_offset" return "storage_offset"
else:
raise AssertionError(f"unhandled {self}")
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
@ -363,14 +363,14 @@ class TensorPropertySource(ChainedSource):
prop: TensorProperty prop: TensorProperty
idx: Optional[int] = None # None for STORAGE_OFFSET idx: Optional[int] = None # None for STORAGE_OFFSET
def __post_init__(self): def __post_init__(self) -> None:
assert self.base is not None assert self.base is not None
if self.prop is TensorProperty.STORAGE_OFFSET: if self.prop is TensorProperty.STORAGE_OFFSET:
assert self.idx is None assert self.idx is None
else: else:
assert self.idx is not None assert self.idx is not None
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null( codegen.add_push_null(
lambda: codegen.load_import_from( lambda: codegen.load_import_from(
utils.__name__, f"call_{self.prop.method_name()}" utils.__name__, f"call_{self.prop.method_name()}"
@ -384,10 +384,10 @@ class TensorPropertySource(ChainedSource):
create_call_function(2 if self.idx is not None else 1, False) create_call_function(2 if self.idx is not None else 1, False)
) )
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
if self.prop is TensorProperty.SIZE: if self.prop is TensorProperty.SIZE:
return f"{self.base.name()}.size()[{self.idx}]" return f"{self.base.name()}.size()[{self.idx}]"
elif self.prop is TensorProperty.STRIDE: elif self.prop is TensorProperty.STRIDE:
@ -403,88 +403,88 @@ class TensorPropertySource(ChainedSource):
class IndexedSource(ChainedSource): class IndexedSource(ChainedSource):
idx: int idx: int
def __post_init__(self): def __post_init__(self) -> None:
assert self.base is not None assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
raise NotImplementedError raise NotImplementedError
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"({self.idx}, {self.base.name()})" return f"({self.idx}, {self.base.name()})"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class NegateSource(ChainedSource): class NegateSource(ChainedSource):
def __post_init__(self): def __post_init__(self) -> None:
assert self.base is not None assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
raise NotImplementedError raise NotImplementedError
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
# NB: use method call so that function stripping regexes work # NB: use method call so that function stripping regexes work
return f"{self.base.name()}.__neg__()" return f"{self.base.name()}.__neg__()"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ConvertIntSource(ChainedSource): class ConvertIntSource(ChainedSource):
def __post_init__(self): def __post_init__(self) -> None:
assert self.base is not None assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"cast_symbool_to_symint_guardless({self.base.name()})" return f"cast_symbool_to_symint_guardless({self.base.name()})"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class FlattenScriptObjectSource(ChainedSource): class FlattenScriptObjectSource(ChainedSource):
def __post_init__(self): def __post_init__(self) -> None:
assert self.base is not None assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"{self.base.name()}.__obj_flatten__()" return f"{self.base.name()}.__obj_flatten__()"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ScriptObjectQualifiedNameSource(ChainedSource): class ScriptObjectQualifiedNameSource(ChainedSource):
def __post_init__(self): def __post_init__(self) -> None:
assert self.base is not None assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"{self.base.name()}._type().qualified_name()" return f"{self.base.name()}._type().qualified_name()"
class AttrProxySource(ChainedSource): class AttrProxySource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"{self.base.name()}.get_base()" return f"{self.base.name()}.get_base()"
@ -495,7 +495,7 @@ class DefaultsSource(ChainedSource):
field: str = dataclasses.field(init=False, repr=False, compare=False) field: str = dataclasses.field(init=False, repr=False, compare=False)
_name: str = dataclasses.field(init=False, repr=False, compare=False) _name: str = dataclasses.field(init=False, repr=False, compare=False)
def __post_init__(self): def __post_init__(self) -> None:
assert self.base, ( assert self.base, (
"Base must be a valid source in order to properly track and guard this Defaults to its origin." "Base must be a valid source in order to properly track and guard this Defaults to its origin."
) )
@ -512,16 +512,16 @@ class DefaultsSource(ChainedSource):
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]" self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
) )
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
codegen.extend_output(codegen.create_load_attrs(self.field)) codegen.extend_output(codegen.create_load_attrs(self.field))
codegen.append_output(codegen.create_load_const(self.idx_key)) codegen.append_output(codegen.create_load_const(self.idx_key))
codegen.append_output(create_instruction("BINARY_SUBSCR")) codegen.append_output(create_instruction("BINARY_SUBSCR"))
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return self._name return self._name
@ -530,14 +530,14 @@ class GetItemSource(ChainedSource):
index: Any index: Any
index_is_slice: bool = False index_is_slice: bool = False
def __post_init__(self): def __post_init__(self) -> None:
assert self.base is not None assert self.base is not None
if isinstance(self.index, slice): if isinstance(self.index, slice):
# store the hashable version of the slice so the whole GetItemSource is hashable # store the hashable version of the slice so the whole GetItemSource is hashable
super().__setattr__("index", self.index.__reduce__()) super().__setattr__("index", self.index.__reduce__())
super().__setattr__("index_is_slice", True) super().__setattr__("index_is_slice", True)
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
if self.index_is_slice: if self.index_is_slice:
codegen.append_output(codegen.create_load_const(self.unpack_slice())) codegen.append_output(codegen.create_load_const(self.unpack_slice()))
@ -545,15 +545,15 @@ class GetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index)) codegen.append_output(codegen.create_load_const(self.index))
codegen.append_output(create_instruction("BINARY_SUBSCR")) codegen.append_output(create_instruction("BINARY_SUBSCR"))
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def unpack_slice(self): def unpack_slice(self) -> slice:
assert self.index_is_slice assert self.index_is_slice
slice_class, slice_args = self.index slice_class, slice_args = self.index
return slice_class(*slice_args) return slice_class(*slice_args)
def name(self): def name(self) -> str:
# Index can be of following types # Index can be of following types
# 1) index is a slice - example 1:4 # 1) index is a slice - example 1:4
# 2) index is a constant - example string, integer # 2) index is a constant - example string, integer
@ -568,10 +568,10 @@ class GetItemSource(ChainedSource):
class ConstDictKeySource(ChainedSource): class ConstDictKeySource(ChainedSource):
index: Any index: Any
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null( codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem") lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
) )
@ -579,11 +579,11 @@ class ConstDictKeySource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index)) codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False)) codegen.extend_output(create_call_function(2, False))
def name(self): def name(self) -> str:
# The list creation will be CSE'd by PyExprCSEPass # The list creation will be CSE'd by PyExprCSEPass
return f"list(dict.keys({self.base.name()}))[{self.index!r}]" return f"list(dict.keys({self.base.name()}))[{self.index!r}]"
def is_dict_key(self): def is_dict_key(self) -> bool:
return True return True
@ -591,15 +591,15 @@ class ConstDictKeySource(ChainedSource):
class NonSerializableSetGetItemSource(ChainedSource): class NonSerializableSetGetItemSource(ChainedSource):
index: int index: int
def __post_init__(self): def __post_init__(self) -> None:
from .variables import ConstantVariable from .variables import ConstantVariable
assert ConstantVariable.is_literal(self.index) assert ConstantVariable.is_literal(self.index)
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null( codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "set_getitem") lambda: codegen.load_import_from(utils.__name__, "set_getitem")
) )
@ -607,11 +607,11 @@ class NonSerializableSetGetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index)) codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False)) codegen.extend_output(create_call_function(2, False))
def name(self): def name(self) -> str:
# set ordering might not be stable # set ordering might not be stable
return f"list({self.base.name()})[{self.index!r}]" return f"list({self.base.name()})[{self.index!r}]"
def is_dict_key(self): def is_dict_key(self) -> bool:
return False return False
@ -623,17 +623,17 @@ class DictGetItemSource(ChainedSource):
# 2) constant - like string, integer # 2) constant - like string, integer
index: Any index: Any
def __post_init__(self): def __post_init__(self) -> None:
from .variables import ConstantVariable from .variables import ConstantVariable
assert isinstance( assert isinstance(
self.index, ConstDictKeySource self.index, ConstDictKeySource
) or ConstantVariable.is_literal(self.index) ) or ConstantVariable.is_literal(self.index)
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
# Load dict # Load dict
codegen(self.base) codegen(self.base)
@ -644,7 +644,7 @@ class DictGetItemSource(ChainedSource):
codegen.append_output(codegen.create_load_const(self.index)) codegen.append_output(codegen.create_load_const(self.index))
codegen.append_output(create_instruction("BINARY_SUBSCR")) codegen.append_output(create_instruction("BINARY_SUBSCR"))
def name(self): def name(self) -> str:
if isinstance(self.index, ConstDictKeySource): if isinstance(self.index, ConstDictKeySource):
return f"{self.base.name()}[{self.index.name()}]" return f"{self.base.name()}[{self.index.name()}]"
else: else:
@ -660,17 +660,17 @@ class DictSubclassGetItemSource(ChainedSource):
# 2) constant - like string, integer # 2) constant - like string, integer
index: Any index: Any
def __post_init__(self): def __post_init__(self) -> None:
from .variables import ConstantVariable from .variables import ConstantVariable
assert isinstance( assert isinstance(
self.index, ConstDictKeySource self.index, ConstDictKeySource
) or ConstantVariable.is_literal(self.index) ) or ConstantVariable.is_literal(self.index)
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
# reconstruct dict.__getitem__(dct, key) # reconstruct dict.__getitem__(dct, key)
# Load dict.__getitem__ # Load dict.__getitem__
@ -689,7 +689,7 @@ class DictSubclassGetItemSource(ChainedSource):
codegen.extend_output(create_call_function(2, False)) codegen.extend_output(create_call_function(2, False))
def name(self): def name(self) -> str:
if isinstance(self.index, ConstDictKeySource): if isinstance(self.index, ConstDictKeySource):
return f"dict.__getitem__({self.base.name()}, {self.index.name()})" return f"dict.__getitem__({self.base.name()}, {self.index.name()})"
else: else:
@ -702,7 +702,7 @@ class ListGetItemSource(GetItemSource):
Same as GetItemSource with reconstruct and name overridden to be list specific. Same as GetItemSource with reconstruct and name overridden to be list specific.
""" """
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
# Reconstruct list.__getitem__(lst, index) to avoid any side effects # Reconstruct list.__getitem__(lst, index) to avoid any side effects
# from possibly overridden __getitem__. # from possibly overridden __getitem__.
@ -724,7 +724,7 @@ class ListGetItemSource(GetItemSource):
codegen.extend_output(create_call_function(2, False)) codegen.extend_output(create_call_function(2, False))
def name(self): def name(self) -> str:
# Index can be of following types # Index can be of following types
# 1) index is a slice - example 1:4 # 1) index is a slice - example 1:4
# 2) index is a constant - example string, integer # 2) index is a constant - example string, integer
@ -739,7 +739,7 @@ class ListGetItemSource(GetItemSource):
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TupleIteratorGetItemSource(GetItemSource): class TupleIteratorGetItemSource(GetItemSource):
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null( codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem") lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
) )
@ -747,91 +747,91 @@ class TupleIteratorGetItemSource(GetItemSource):
codegen.append_output(codegen.create_load_const(self.index)) codegen.append_output(codegen.create_load_const(self.index))
codegen.extend_output(create_call_function(2, False)) codegen.extend_output(create_call_function(2, False))
def name(self): def name(self) -> str:
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})" return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class DataclassFieldsSource(ChainedSource): class DataclassFieldsSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null( codegen.add_push_null(
lambda: codegen.load_import_from(utils.__name__, "dataclass_fields") lambda: codegen.load_import_from(utils.__name__, "dataclass_fields")
) )
codegen(self.base) codegen(self.base)
codegen.extend_output(create_call_function(1, False)) codegen.extend_output(create_call_function(1, False))
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"___dataclass_fields({self.base.name()})" return f"___dataclass_fields({self.base.name()})"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class TypeSource(ChainedSource): class TypeSource(ChainedSource):
def __post_init__(self): def __post_init__(self) -> None:
assert self.base is not None assert self.base is not None
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type")) codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
codegen(self.base) codegen(self.base)
codegen.extend_output(create_call_function(1, False)) codegen.extend_output(create_call_function(1, False))
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return f"type({self.base.name()})" return f"type({self.base.name()})"
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class OptimizerSource(ChainedSource): class OptimizerSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def name(self): def name(self) -> str:
return self.base.name() return self.base.name()
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class NNModuleSource(ChainedSource): class NNModuleSource(ChainedSource):
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen(self.base) codegen(self.base)
def guard_source(self): def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()] return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
def name(self): def name(self) -> str:
return self.base.name() return self.base.name()
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class UnspecializedNNModuleSource(NNModuleSource): class UnspecializedNNModuleSource(NNModuleSource):
def guard_source(self): def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()] return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource): class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
def guard_source(self): def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()] return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class FSDPNNModuleSource(NNModuleSource): class FSDPNNModuleSource(NNModuleSource):
def guard_source(self): def guard_source(self) -> GuardSource:
return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()] return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class GlobalStateSource(Source): class GlobalStateSource(Source):
def name(self): def name(self) -> str:
return "" return ""
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL return GuardSource.GLOBAL
@ -840,16 +840,16 @@ class TorchSource(Source):
"""Points to the actual `torch` module - used instead of GlobalSource """Points to the actual `torch` module - used instead of GlobalSource
in case the user has overridden `torch` in their local namespace""" in case the user has overridden `torch` in their local namespace"""
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
from .guards import GuardBuilder, install_guard from .guards import GuardBuilder, install_guard
install_guard(self.make_guard(GuardBuilder.ID_MATCH)) install_guard(self.make_guard(GuardBuilder.ID_MATCH))
def name(self): def name(self) -> str:
return "__import__('torch')" return "__import__('torch')"
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.extend_output( codegen.extend_output(
[ [
codegen.create_load_const(0), # level codegen.create_load_const(0), # level
@ -858,7 +858,7 @@ class TorchSource(Source):
] ]
) )
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL return GuardSource.GLOBAL
@ -866,15 +866,15 @@ class TorchSource(Source):
class TorchFunctionModeStackSource(Source): class TorchFunctionModeStackSource(Source):
ind: int ind: int
def name(self): def name(self) -> str:
return f"___get_torch_function_mode_stack_at({self._get_index()})" return f"___get_torch_function_mode_stack_at({self._get_index()})"
def _get_index(self): def _get_index(self) -> int:
from .variables.torch_function import TorchFunctionModeStackVariable from .variables.torch_function import TorchFunctionModeStackVariable
return TorchFunctionModeStackVariable.get_mode_index(self.ind) return TorchFunctionModeStackVariable.get_mode_index(self.ind)
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null( codegen.add_push_null(
lambda: codegen.load_import_from( lambda: codegen.load_import_from(
utils.__name__, "get_torch_function_mode_stack_at" utils.__name__, "get_torch_function_mode_stack_at"
@ -883,7 +883,7 @@ class TorchFunctionModeStackSource(Source):
codegen.extend_output([codegen.create_load_const(self._get_index())]) codegen.extend_output([codegen.create_load_const(self._get_index())])
codegen.extend_output(create_call_function(1, False)) codegen.extend_output(create_call_function(1, False))
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.GLOBAL return GuardSource.GLOBAL
@ -891,16 +891,16 @@ class TorchFunctionModeStackSource(Source):
class ConstantSource(Source): class ConstantSource(Source):
source_name: str source_name: str
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.append_output(codegen.create_load_global(self.source_name, add=False)) codegen.append_output(codegen.create_load_global(self.source_name, add=False))
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.CONSTANT return GuardSource.CONSTANT
def name(self): def name(self) -> str:
return self.source_name return self.source_name
def make_guard(self, fn): def make_guard(self, fn: Any) -> Any:
raise NotImplementedError raise NotImplementedError
@ -909,10 +909,10 @@ class NumpyTensorSource(ChainedSource):
def name(self) -> str: def name(self) -> str:
return f"___from_numpy({self.base.name()})" return f"___from_numpy({self.base.name()})"
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
def reconstruct(self, codegen: "PyCodegen"): def reconstruct(self, codegen: "PyCodegen") -> None:
codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor")) codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
codegen(self.base) codegen(self.base)
codegen.extend_output(create_call_function(1, False)) codegen.extend_output(create_call_function(1, False))
@ -923,7 +923,7 @@ class SubclassAttrListSource(ChainedSource):
def name(self) -> str: def name(self) -> str:
return f"{self.base.name()}.__tensor_flatten__()[0]" return f"{self.base.name()}.__tensor_flatten__()[0]"
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
@ -934,7 +934,7 @@ class FloatTensorSource(ChainedSource):
def name(self) -> str: def name(self) -> str:
return f"___as_tensor({self.base.name()})" return f"___as_tensor({self.base.name()})"
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
@ -943,7 +943,7 @@ class CallMethodItemSource(ChainedSource):
def name(self) -> str: def name(self) -> str:
return f"{self.base.name()}.item()" return f"{self.base.name()}.item()"
def guard_source(self): def guard_source(self) -> GuardSource:
return self.base.guard_source() return self.base.guard_source()
@ -952,23 +952,25 @@ class CallMethodItemSource(ChainedSource):
# guard contents from the ambient ShapeEnv # guard contents from the ambient ShapeEnv
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class ShapeEnvSource(Source): class ShapeEnvSource(Source):
def name(self): def name(self) -> str:
return "" return ""
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.SHAPE_ENV return GuardSource.SHAPE_ENV
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class BackwardStateSource(Source): class BackwardStateSource(Source):
def name(self): def name(self) -> str:
return "" return ""
def guard_source(self): def guard_source(self) -> GuardSource:
return GuardSource.BACKWARD_STATE return GuardSource.BACKWARD_STATE
def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional[str]: def get_local_source_name(
source: Source, *, only_allow_input: bool = False
) -> Optional[str]:
if isinstance(source, ChainedSource): if isinstance(source, ChainedSource):
return get_local_source_name(source.base, only_allow_input=only_allow_input) return get_local_source_name(source.base, only_allow_input=only_allow_input)
if not isinstance(source, LocalSource): if not isinstance(source, LocalSource):
@ -978,7 +980,7 @@ def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional
return source.local_name return source.local_name
def is_from_local_source(source: Source, *, only_allow_input=False): def is_from_local_source(source: Source, *, only_allow_input: bool = False) -> bool:
return get_local_source_name(source, only_allow_input=only_allow_input) is not None return get_local_source_name(source, only_allow_input=only_allow_input) is not None
@ -994,7 +996,7 @@ def get_global_source_name(source: Source) -> Optional[str]:
return source.global_name return source.global_name
def is_from_nonlocal_source(source: Source): def is_from_nonlocal_source(source: Source) -> bool:
if isinstance(source, ChainedSource): if isinstance(source, ChainedSource):
return is_from_nonlocal_source(source.base) return is_from_nonlocal_source(source.base)
return ( return (
@ -1004,14 +1006,14 @@ def is_from_nonlocal_source(source: Source):
) )
def is_from_source(source: Source, target: Source): def is_from_source(source: Source, target: Source) -> bool:
if isinstance(source, ChainedSource): if isinstance(source, ChainedSource):
return is_from_source(source.base, target) return is_from_source(source.base, target)
return source == target return source == target
@functools.lru_cache @functools.lru_cache
def is_from_unspecialized_nn_module_source(source: Source): def is_from_unspecialized_nn_module_source(source: Source) -> bool:
if isinstance(source, UnspecializedNNModuleSource): if isinstance(source, UnspecializedNNModuleSource):
return True return True
if isinstance(source, ChainedSource): if isinstance(source, ChainedSource):
@ -1020,7 +1022,7 @@ def is_from_unspecialized_nn_module_source(source: Source):
@functools.lru_cache @functools.lru_cache
def is_from_unspecialized_builtin_nn_module_source(source: Source): def is_from_unspecialized_builtin_nn_module_source(source: Source) -> bool:
if isinstance(source, UnspecializedBuiltinNNModuleSource): if isinstance(source, UnspecializedBuiltinNNModuleSource):
return True return True
if isinstance(source, ChainedSource): if isinstance(source, ChainedSource):
@ -1029,7 +1031,7 @@ def is_from_unspecialized_builtin_nn_module_source(source: Source):
@functools.lru_cache @functools.lru_cache
def is_from_unspecialized_param_buffer_source(source: Source): def is_from_unspecialized_param_buffer_source(source: Source) -> bool:
if isinstance(source, UnspecializedParamBufferSource): if isinstance(source, UnspecializedParamBufferSource):
return True return True
if isinstance(source, ChainedSource): if isinstance(source, ChainedSource):
@ -1038,7 +1040,7 @@ def is_from_unspecialized_param_buffer_source(source: Source):
@functools.lru_cache @functools.lru_cache
def is_from_flatten_script_object_source(source: Source): def is_from_flatten_script_object_source(source: Source) -> bool:
if isinstance(source, FlattenScriptObjectSource): if isinstance(source, FlattenScriptObjectSource):
return True return True
elif isinstance(source, ChainedSource): elif isinstance(source, ChainedSource):
@ -1047,7 +1049,7 @@ def is_from_flatten_script_object_source(source: Source):
@functools.lru_cache @functools.lru_cache
def is_from_optimizer_source(source: Source): def is_from_optimizer_source(source: Source) -> bool:
if isinstance(source, OptimizerSource): if isinstance(source, OptimizerSource):
return True return True
if isinstance(source, ChainedSource): if isinstance(source, ChainedSource):
@ -1058,7 +1060,7 @@ def is_from_optimizer_source(source: Source):
# TODO: can probably write a generic "test this on everything in the chain" # TODO: can probably write a generic "test this on everything in the chain"
# helper # helper
@functools.lru_cache @functools.lru_cache
def is_from_defaults(source: Source): def is_from_defaults(source: Source) -> bool:
if isinstance(source, DefaultsSource): if isinstance(source, DefaultsSource):
return True return True

View File

@ -2653,7 +2653,9 @@ def set_example_value(node, example_value):
# this to accurately reflect what the state of the value was at the time # this to accurately reflect what the state of the value was at the time
# the program was traced). # the program was traced).
node.meta["example_value"] = example_value node.meta["example_value"] = example_value
shape_env = TracingContext.get().fake_mode.shape_env fake_mode = TracingContext.get().fake_mode
assert fake_mode is not None
shape_env = fake_mode.shape_env
if ( if (
symbol_to_path symbol_to_path
:= torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings( := torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
@ -4774,7 +4776,7 @@ def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
# Returns a set of code objects present traced in the current TracingContext, or None # Returns a set of code objects present traced in the current TracingContext, or None
# if there is no current TracingContext. # if there is no current TracingContext.
def get_traced_code() -> list[CodeType]: def get_traced_code() -> Optional[list[CodeType]]:
from torch._guards import TracingContext from torch._guards import TracingContext
return TracingContext.get_traced_code() return TracingContext.get_traced_code()

View File

@ -365,6 +365,7 @@ def make_fake_inputs(
# a toplevel TracingContext with a fake mode, so we do not want to # a toplevel TracingContext with a fake mode, so we do not want to
# create another fake mode. # create another fake mode.
fake_mode = context.fake_mode fake_mode = context.fake_mode
assert fake_mode is not None
else: else:
if isinstance(nn_module.forward, functools.partial): if isinstance(nn_module.forward, functools.partial):
# functools handles nesting by itself, no need to recurse # functools handles nesting by itself, no need to recurse
@ -852,7 +853,7 @@ def _fakify_script_objects(
mod: torch.nn.Module, mod: torch.nn.Module,
args: Sequence[Any], args: Sequence[Any],
kwargs: dict[Any, Any], kwargs: dict[Any, Any],
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode, fake_mode: Optional[torch._subclasses.fake_tensor.FakeTensorMode],
): ):
# This context manager is used to fakify script objects into FakeScriptObject. # This context manager is used to fakify script objects into FakeScriptObject.
# Inputs: # Inputs:

View File

@ -1129,7 +1129,7 @@ def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict:
def _detect_fake_mode_from_gm( def _detect_fake_mode_from_gm(
gm: torch.fx.GraphModule, gm: torch.fx.GraphModule,
) -> torch._subclasses.fake_tensor.FakeTensorMode: ) -> Optional[torch._subclasses.fake_tensor.FakeTensorMode]:
""" """
For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs. For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs.
Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes. Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes.

View File

@ -12,7 +12,7 @@ It does so by:
""" """
import warnings import warnings
from contextlib import contextmanager, ExitStack, nullcontext from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from functools import wraps from functools import wraps
from typing import Any, Callable, Optional, TypeVar, Union from typing import Any, Callable, Optional, TypeVar, Union
@ -337,9 +337,10 @@ def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True) -> Any:
# It goes from (primals, tangents) to (seed, offset, primals, tangents) # It goes from (primals, tangents) to (seed, offset, primals, tangents)
# At runtime, we pass on the current seed and offset. This is hidden from # At runtime, we pass on the current seed and offset. This is hidden from
# the user. # the user.
fake_mode = detect_fake_mode() fake_mode_det = detect_fake_mode()
if fake_mode is None: fake_mode: AbstractContextManager[Any] = nullcontext()
fake_mode = nullcontext() if fake_mode_det is not None:
fake_mode = fake_mode_det
def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"): def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
out = PhiloxStateTracker.get_state_as_tensor() out = PhiloxStateTracker.get_state_as_tensor()
@ -1130,7 +1131,9 @@ def create_functional_call(mod, params_spec, params_len, store_orig_mod=False):
"ignore", "Anomaly Detection has been enabled." "ignore", "Anomaly Detection has been enabled."
) )
with torch.autograd.detect_anomaly(check_nan=False): with torch.autograd.detect_anomaly(check_nan=False):
detect_fake_mode().epoch += 1 fake_mode = detect_fake_mode()
assert fake_mode is not None
fake_mode.epoch += 1
out = PropagateUnbackedSymInts(mod).run( out = PropagateUnbackedSymInts(mod).run(
*args[params_len:], **kwargs *args[params_len:], **kwargs
) )

View File

@ -280,11 +280,12 @@ def compute_overlapping_inputs(aot_config, fwd_inputs, aliased_input_indices):
tracing_context = torch._guards.TracingContext.try_get() tracing_context = torch._guards.TracingContext.try_get()
if tracing_context is not None: if tracing_context is not None:
assert tracing_context.fake_mode is not None
shape_env = tracing_context.fake_mode.shape_env shape_env = tracing_context.fake_mode.shape_env
# Check whether we can actually get the dynamo sources from within AOTAutograd. # Check whether we can actually get the dynamo sources from within AOTAutograd.
if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None: if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None:
maybe_suppress_guards = shape_env.suppress_guards maybe_suppress_guards = shape_env.suppress_guards # type: ignore[assignment]
# Check whether there are any symbolic values being used. # Check whether there are any symbolic values being used.
# We do this for 2 reasons: # We do this for 2 reasons:

View File

@ -480,6 +480,7 @@ class FunctionalizedRngRuntimeWrapper(InductorWrapper):
if config.functionalize_rng_ops: if config.functionalize_rng_ops:
# Update example inputs for the fw_compiler # Update example inputs for the fw_compiler
fake_mode = detect_fake_mode() fake_mode = detect_fake_mode()
assert fake_mode is not None
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode) seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
flat_args.extend([seed, offset]) flat_args.extend([seed, offset])
# We are not clearing flat_args here because # We are not clearing flat_args here because

View File

@ -1,4 +1,3 @@
# mypy: allow-untyped-defs
from __future__ import annotations from __future__ import annotations
import contextlib import contextlib
@ -37,10 +36,15 @@ log = logging.getLogger(__name__)
if TYPE_CHECKING: if TYPE_CHECKING:
from collections.abc import Generator, Iterator
from types import CodeType from types import CodeType
import sympy import sympy
from torch._dynamo.codegen import PyCodegen
from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
from torch._subclasses.fake_tensor import FakeTensorMode
""" """
torch._guards is the definitional source of truth for general purpose guard structures. torch._guards is the definitional source of truth for general purpose guard structures.
@ -83,7 +87,7 @@ class CompileId:
# TODO: consider also tracking the recompilation count # TODO: consider also tracking the recompilation count
# See Note: Updating CompileId # See Note: Updating CompileId
def __str__(self): def __str__(self) -> str:
# NOTE: Keep this in sync with both from_string and the tlparse repo # NOTE: Keep this in sync with both from_string and the tlparse repo
if self.compiled_autograd_id is not None: if self.compiled_autograd_id is not None:
assert (self.frame_id is None) == (self.frame_compile_id is None) assert (self.frame_id is None) == (self.frame_compile_id is None)
@ -97,7 +101,7 @@ class CompileId:
return f"{self.frame_id}/{self.frame_compile_id}" return f"{self.frame_id}/{self.frame_compile_id}"
@classmethod @classmethod
def from_string(cls, compile_id: Optional[str]): def from_string(cls, compile_id: Optional[str]) -> Optional[CompileId]:
""" """
Factory method that creates a CompileId from its string representation. Factory method that creates a CompileId from its string representation.
Keep this in sync with the __str__ method. Keep this in sync with the __str__ method.
@ -125,7 +129,7 @@ class TraceId(NamedTuple):
# up by one # up by one
attempt: int attempt: int
def __str__(self): def __str__(self) -> str:
# Keep this in sync with tlparse repo # Keep this in sync with tlparse repo
if self.attempt == 0: if self.attempt == 0:
return str(self.compile_id) return str(self.compile_id)
@ -174,7 +178,7 @@ class GuardSource(enum.Enum):
GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE, GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
) )
def is_local(self): def is_local(self) -> bool:
return self in ( return self in (
GuardSource.LOCAL, GuardSource.LOCAL,
GuardSource.LOCAL_SPECIALIZED_NN_MODULE, GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
@ -207,7 +211,7 @@ class SLoc:
framework_loc: Optional[Union[traceback.FrameSummary, str]] framework_loc: Optional[Union[traceback.FrameSummary, str]]
maybe_user_loc: Optional[str] maybe_user_loc: Optional[str]
def __str__(self): def __str__(self) -> str:
floc = ( floc = (
self.framework_loc self.framework_loc
if isinstance(self.framework_loc, str) if isinstance(self.framework_loc, str)
@ -246,7 +250,7 @@ class Guard:
# it is meaningless. Example create_fns that are like this include # it is meaningless. Example create_fns that are like this include
# GRAD_MODE and SHAPE_ENV. # GRAD_MODE and SHAPE_ENV.
originating_source: Source originating_source: Source
create_fn: Callable[[GuardBuilderBase, Guard], None] create_fn: Callable[[GuardBuilderBase, Guard], Any]
# Export only. These values are written to at time of guard check_fn creation. # Export only. These values are written to at time of guard check_fn creation.
guard_types: Optional[list[str]] = None guard_types: Optional[list[str]] = None
@ -258,12 +262,12 @@ class Guard:
user_stack: Optional[traceback.StackSummary] = None user_stack: Optional[traceback.StackSummary] = None
_hash: Optional[int] = None _hash: Optional[int] = None
def __hash__(self): def __hash__(self) -> int:
if self._hash is None: if self._hash is None:
self._hash = hash((self.name, self.source, id(self.create_fn))) self._hash = hash((self.name, self.source, id(self.create_fn)))
return self._hash return self._hash
def sort_key(self): def sort_key(self) -> tuple[bool, int, int, str, int]:
# Put the duplicate input guards at the end. The duplicate guards have # Put the duplicate input guards at the end. The duplicate guards have
# two sources while guard.name only considers one source. # two sources while guard.name only considers one source.
@ -279,10 +283,10 @@ class Guard:
self.inner_create_fn().__code__.co_firstlineno, self.inner_create_fn().__code__.co_firstlineno,
) )
def __lt__(self, other): def __lt__(self, other: Guard) -> bool:
return self.sort_key() < other.sort_key() return self.sort_key() < other.sort_key()
def inner_create_fn(self): def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]:
if isinstance(self.create_fn, functools.partial): if isinstance(self.create_fn, functools.partial):
return self.create_fn.func return self.create_fn.func
else: else:
@ -297,7 +301,7 @@ class Guard:
return self.originating_source.guard_source() return self.originating_source.guard_source()
@staticmethod @staticmethod
def weakref_to_str(obj_weakref): def weakref_to_str(obj_weakref: object) -> str:
""" """
This is a workaround of a Python weakref bug. This is a workaround of a Python weakref bug.
@ -321,7 +325,7 @@ class Guard:
else: else:
return str(obj_weakref) return str(obj_weakref)
def __repr__(self): def __repr__(self) -> str:
s = f""" s = f"""
{self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__} {self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
{{ {{
@ -333,7 +337,7 @@ class Guard:
""" """
return s return s
def __str__(self): def __str__(self) -> str:
output = f"Name: {repr(self.name)}\n" output = f"Name: {repr(self.name)}\n"
source = self.source.name.lower() if self.source else "" source = self.source.name.lower() if self.source else ""
output += f" Source: {source}\n" output += f" Source: {source}\n"
@ -344,7 +348,7 @@ class Guard:
output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n" output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
return output return output
def create(self, builder: GuardBuilderBase): def create(self, builder: GuardBuilderBase) -> Any:
try: try:
return self.create_fn(builder, self) return self.create_fn(builder, self)
except Exception: except Exception:
@ -353,16 +357,22 @@ class Guard:
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip()) log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
raise raise
def is_specialized_nn_module(self): def is_specialized_nn_module(self) -> bool:
return self.source.is_specialized_nn_module() return self.source.is_specialized_nn_module()
def is_fsdp_module(self): def is_fsdp_module(self) -> bool:
return self.source.is_fsdp_module() return self.source.is_fsdp_module()
def is_local(self): def is_local(self) -> bool:
return self.source.is_local() return self.source.is_local()
def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref): def set_export_info(
self,
guard_type: str,
guarded_class: Optional[type],
code_list: list[str],
obj_weakref: object,
) -> None:
if not self.guard_types: if not self.guard_types:
self.guard_types = [] self.guard_types = []
@ -417,7 +427,7 @@ class DuplicateInputs(GuardEnvExpr):
input_source_a: Source input_source_a: Source
input_source_b: Source input_source_b: Source
def __post_init__(self): def __post_init__(self) -> None:
assert self.input_source_a != self.input_source_b assert self.input_source_a != self.input_source_b
@ -459,7 +469,7 @@ class Checkpointable(Generic[T]):
def copy_graphstate(self) -> T: ... def copy_graphstate(self) -> T: ...
@abstractmethod @abstractmethod
def restore_graphstate(self, state: T): ... def restore_graphstate(self, state: T) -> None: ...
class GuardsCheckpointState: class GuardsCheckpointState:
@ -469,10 +479,10 @@ class GuardsCheckpointState:
dynamo_guards: set[Guard] = set() dynamo_guards: set[Guard] = set()
def __init__(self, dynamo_guards): def __init__(self, dynamo_guards: set[Guard]) -> None:
self.dynamo_guards = dynamo_guards self.dynamo_guards = dynamo_guards
def diff(self, other): def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]:
""" """
Produces a delta against another GuardsCheckpointState. Produces a delta against another GuardsCheckpointState.
@ -484,17 +494,19 @@ class GuardsCheckpointState:
return None return None
return r return r
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, GuardsCheckpointState):
return False
return self.diff(other) is None return self.diff(other) is None
class ModuleContextCheckpointState: class ModuleContextCheckpointState:
nn_modules: dict[str, torch.nn.Module] = {} nn_modules: dict[str, torch.nn.Module] = {}
def __init__(self, nn_modules): def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None:
self.nn_modules = nn_modules self.nn_modules = nn_modules
def diff(self, other): def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]:
""" """
Produces a delta against another ModuleContextCheckpointState. Produces a delta against another ModuleContextCheckpointState.
@ -506,7 +518,9 @@ class ModuleContextCheckpointState:
return None return None
return r return r
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, ModuleContextCheckpointState):
return False
return self.diff(other) is None return self.diff(other) is None
@ -514,21 +528,21 @@ class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
def __init__(self) -> None: def __init__(self) -> None:
self.nn_modules: dict[str, Any] = {} self.nn_modules: dict[str, Any] = {}
def copy_graphstate(self): def copy_graphstate(self) -> ModuleContextCheckpointState:
return ModuleContextCheckpointState(dict(self.nn_modules)) return ModuleContextCheckpointState(dict(self.nn_modules))
def restore_graphstate(self, state): def restore_graphstate(self, state: ModuleContextCheckpointState) -> None:
assert isinstance(state, ModuleContextCheckpointState) assert isinstance(state, ModuleContextCheckpointState)
self.nn_modules = state.nn_modules self.nn_modules = state.nn_modules
class GlobalContextCheckpointState: class GlobalContextCheckpointState:
global_state: dict[str, tuple[Callable, ...]] = {} global_state: dict[str, tuple[Callable, Any]] = {}
def __init__(self, global_states): def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None:
self.global_state = global_states self.global_state = global_states
def diff(self, other): def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]:
""" """
Produces a delta against another GlobalContextCheckpointState. Produces a delta against another GlobalContextCheckpointState.
@ -540,7 +554,9 @@ class GlobalContextCheckpointState:
return None return None
return r return r
def __eq__(self, other): def __eq__(self, other: object) -> bool:
if not isinstance(other, GlobalContextCheckpointState):
return False
return self.diff(other) is None return self.diff(other) is None
@ -560,12 +576,12 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
} }
def __init__(self) -> None: def __init__(self) -> None:
self.global_state: dict[str, tuple[Callable, ...]] = {} self.global_state: dict[str, tuple[Callable, Any]] = {}
def copy_graphstate(self): def copy_graphstate(self) -> GlobalContextCheckpointState:
return GlobalContextCheckpointState(dict(self.global_state)) return GlobalContextCheckpointState(self.global_state)
def restore_graphstate(self, state): def restore_graphstate(self, state: GlobalContextCheckpointState) -> None:
assert isinstance(state, GlobalContextCheckpointState) assert isinstance(state, GlobalContextCheckpointState)
self.global_state = state.global_state self.global_state = state.global_state
assert ( assert (
@ -579,26 +595,28 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
# Like a Set[Guard] but will record the user stack on all guards at the # Like a Set[Guard] but will record the user stack on all guards at the
# time they were installed at their destination # time they were installed at their destination
class GuardsSet: class GuardsSet:
def __init__(self, inner=None): def __init__(self, inner: Optional[set[Guard]] = None) -> None:
if inner is None: if inner is None:
inner = set() inner = set()
self.inner = inner self.inner = inner
def __iter__(self): def __iter__(self) -> Iterator[Guard]:
return iter(self.inner) return iter(self.inner)
def __len__(self): def __len__(self) -> int:
return len(self.inner) return len(self.inner)
# Subtraction along with bool is typically used to determine the delta of # Subtraction along with bool is typically used to determine the delta of
# added guards between checkpoints for higher order ops # added guards between checkpoints for higher order ops
def __sub__(self, other): def __sub__(self, other: GuardsSet) -> GuardsSet:
return GuardsSet(self.inner - other.inner) return GuardsSet(self.inner - other.inner)
def __bool__(self): def __bool__(self) -> bool:
return bool(self.inner) return bool(self.inner)
def add(self, guard: Guard, *, collect_debug_stack=True, skip=0): def add(
self, guard: Guard, *, collect_debug_stack: bool = True, skip: int = 0
) -> None:
if guard in self.inner: if guard in self.inner:
return return
if collect_debug_stack: if collect_debug_stack:
@ -608,12 +626,12 @@ class GuardsSet:
guard.user_stack = TracingContext.extract_stack() guard.user_stack = TracingContext.extract_stack()
self.inner.add(guard) self.inner.add(guard)
def update(self, *others: set[Guard]): def update(self, *others: set[Guard]) -> None:
for o in others: for o in others:
for g in o: for g in o:
self.add(g, skip=1) self.add(g, skip=1)
def remove_guards_with_source(self, source): def remove_guards_with_source(self, source: Source) -> None:
"""Delete all guards that contains a given source""" """Delete all guards that contains a given source"""
from ._dynamo.source import is_from_source from ._dynamo.source import is_from_source
@ -635,10 +653,10 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
self.dynamo_guards: GuardsSet = GuardsSet() self.dynamo_guards: GuardsSet = GuardsSet()
self.aotautograd_guards: list[GuardEnvExpr] = [] self.aotautograd_guards: list[GuardEnvExpr] = []
def copy_graphstate(self): def copy_graphstate(self) -> GuardsCheckpointState:
return GuardsCheckpointState(set(self.dynamo_guards.inner)) return GuardsCheckpointState(set(self.dynamo_guards.inner))
def restore_graphstate(self, state): def restore_graphstate(self, state: GuardsCheckpointState) -> None:
# NB: "steals" the passed in state # NB: "steals" the passed in state
assert isinstance(state, GuardsCheckpointState) assert isinstance(state, GuardsCheckpointState)
self.dynamo_guards = GuardsSet(state.dynamo_guards) self.dynamo_guards = GuardsSet(state.dynamo_guards)
@ -646,22 +664,22 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
class HopSubgraphCache: class HopSubgraphCache:
@abstractmethod @abstractmethod
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str): ... def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: ...
@abstractmethod @abstractmethod
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ... def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ...
@abstractmethod @abstractmethod
def add_autograd_key_entry(self, identifier: str, key: Callable): ... def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ...
@abstractmethod @abstractmethod
def get_autograd_key_entry(self, identifier: str): ... def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: ...
@abstractmethod @abstractmethod
def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ... def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ...
@abstractmethod @abstractmethod
def get_proxy_dispatch_entry(self, identifier: str): ... def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: ...
@abstractmethod @abstractmethod
def add_lazy_bwd_entry( def add_lazy_bwd_entry(
@ -669,12 +687,12 @@ class HopSubgraphCache:
identifier: str, identifier: str,
tangent_metadata: tuple[object], tangent_metadata: tuple[object],
gmod: torch.fx.GraphModule, gmod: torch.fx.GraphModule,
): ... ) -> int: ...
@abstractmethod @abstractmethod
def get_lazy_bwd_entry( def get_lazy_bwd_entry(
self, identifier: str, tangent_metadata: tuple[object] self, identifier: str, tangent_metadata: tuple[object]
) -> int: ... ) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: ...
class InvokeSubgraphCache(HopSubgraphCache): class InvokeSubgraphCache(HopSubgraphCache):
@ -686,22 +704,22 @@ class InvokeSubgraphCache(HopSubgraphCache):
str, dict[tuple[object], tuple[torch.fx.GraphModule, int]] str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
] = defaultdict(dict) ] = defaultdict(dict)
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str): def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
self.dynamo_installed_submodules[fn_id].append(identifier) self.dynamo_installed_submodules[fn_id].append(identifier)
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
return self.dynamo_installed_submodules.get(fn_id, []) return self.dynamo_installed_submodules.get(fn_id, [])
def add_autograd_key_entry(self, identifier: str, key: Callable): def add_autograd_key_entry(self, identifier: str, key: Callable) -> None:
self.autograd_cache[identifier] = key self.autograd_cache[identifier] = key
def get_autograd_key_entry(self, identifier: str): def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]:
return self.autograd_cache.get(identifier, None) return self.autograd_cache.get(identifier, None)
def add_proxy_dispatch_entry(self, identifier: str, key: Callable): def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None:
self.proxy_dispatch_cache[identifier] = key self.proxy_dispatch_cache[identifier] = key
def get_proxy_dispatch_entry(self, identifier: str): def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]:
return self.proxy_dispatch_cache.get(identifier, None) return self.proxy_dispatch_cache.get(identifier, None)
def add_lazy_bwd_entry( def add_lazy_bwd_entry(
@ -709,13 +727,15 @@ class InvokeSubgraphCache(HopSubgraphCache):
identifier: str, identifier: str,
tangent_metadata: tuple[object], tangent_metadata: tuple[object],
gmod: torch.fx.GraphModule, gmod: torch.fx.GraphModule,
): ) -> int:
# Save the number of existing graph modules in the dictionary to get the suffix # Save the number of existing graph modules in the dictionary to get the suffix
num_gmods = len(self.lazy_bwd_cache[identifier]) num_gmods = len(self.lazy_bwd_cache[identifier])
self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods) self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods)
return num_gmods return num_gmods
def get_lazy_bwd_entry(self, identifier: str, tangent_metadata: tuple[object]): def get_lazy_bwd_entry(
self, identifier: str, tangent_metadata: tuple[object]
) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]:
if identifier not in self.lazy_bwd_cache: if identifier not in self.lazy_bwd_cache:
return (None, None) return (None, None)
@ -768,7 +788,7 @@ class CompileContext:
def try_get() -> Optional[CompileContext]: def try_get() -> Optional[CompileContext]:
return getattr(_TLS, "compile_context", None) return getattr(_TLS, "compile_context", None)
def __init__(self, compile_id): def __init__(self, compile_id: Optional[CompileId]) -> None:
assert compile_id is None or isinstance(compile_id, CompileId) assert compile_id is None or isinstance(compile_id, CompileId)
self.compile_id: Optional[CompileId] = compile_id self.compile_id: Optional[CompileId] = compile_id
self.attempt = 0 self.attempt = 0
@ -776,14 +796,14 @@ class CompileContext:
self.shape_env_guards: list[str] = [] self.shape_env_guards: list[str] = []
@staticmethod @staticmethod
def current_compile_id(): def current_compile_id() -> Optional[CompileId]:
self = CompileContext.try_get() self = CompileContext.try_get()
if self is None: if self is None:
return None return None
return self.compile_id return self.compile_id
@staticmethod @staticmethod
def current_trace_id(): def current_trace_id() -> Optional[TraceId]:
self = CompileContext.try_get() self = CompileContext.try_get()
if self is None: if self is None:
return None return None
@ -812,28 +832,28 @@ class TracingContext:
"TracingContext.get() must be called within an ongoing trace." "TracingContext.get() must be called within an ongoing trace."
) )
def __init__(self, fake_mode): def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None:
self.guards_context = GuardsContext() self.guards_context = GuardsContext()
self.module_context = ModuleContext() self.module_context = ModuleContext()
self.global_context = GlobalContext() self.global_context = GlobalContext()
self.previously_inlined_functions = dict() self.previously_inlined_functions: dict[Any, Any] = dict()
self.previously_cleaned_instructions = dict() self.previously_cleaned_instructions: dict[Any, Any] = dict()
self.fake_mode = fake_mode self.fake_mode: Optional[FakeTensorMode] = fake_mode
self.frame_summary_stack = [] self.frame_summary_stack: list[traceback.FrameSummary] = []
# This is morally part of frame_summary_stack, but it is kept separate # This is morally part of frame_summary_stack, but it is kept separate
# for clarity. As we process a frame, this variable gets updated # for clarity. As we process a frame, this variable gets updated
# to keep track of what line we are in the function. We make a # to keep track of what line we are in the function. We make a
# function call, this gets cleared and the frame location is pushed # function call, this gets cleared and the frame location is pushed
# to frame_summary_stack (prepping this variable for the inner frame's # to frame_summary_stack (prepping this variable for the inner frame's
# progress) # progress)
self.loc_in_frame = None self.loc_in_frame: Optional[tuple[str, int, str]] = None
# this is only set after aot_autograd # this is only set after aot_autograd
self.fw_metadata = None self.fw_metadata: Optional[ViewAndMutationMeta] = None
# this is only set after aot_autograd # this is only set after aot_autograd
self.aot_graph_name = None self.aot_graph_name: Optional[list[str]] = None
self.params_flat = None self.params_flat: Optional[list[Any]] = None
self.params_flat_unwrap_subclasses = None self.params_flat_unwrap_subclasses: Optional[list[Any]] = None
self.params_unwrapped_to_flat_index = None self.params_unwrapped_to_flat_index: Optional[list[Any]] = None
# this is for extended return calling convention from backend # this is for extended return calling convention from backend
# compiler to aot_autograd # compiler to aot_autograd
# Per output, what the compiler specified stride of the output is, # Per output, what the compiler specified stride of the output is,
@ -861,7 +881,7 @@ class TracingContext:
# list of code objects for inlined functions # list of code objects for inlined functions
self.traced_code: list[CodeType] = [] self.traced_code: list[CodeType] = []
def clear(self): def clear(self) -> None:
# Look at the note in output_graph.py in function `save_global_state` # Look at the note in output_graph.py in function `save_global_state`
# for the context on clearing global context. # for the context on clearing global context.
self.global_context.global_state = {} self.global_context.global_state = {}
@ -870,7 +890,7 @@ class TracingContext:
@staticmethod @staticmethod
@contextmanager @contextmanager
def patch(**kwargs): def patch(**kwargs: Any) -> Generator[None, None, None]:
prior = {} prior = {}
ctx = TracingContext.get() ctx = TracingContext.get()
@ -886,7 +906,7 @@ class TracingContext:
setattr(ctx, key, val) setattr(ctx, key, val)
@staticmethod @staticmethod
def extract_stack(): def extract_stack() -> traceback.StackSummary:
self = TracingContext.try_get() self = TracingContext.try_get()
if self is None: if self is None:
return traceback.StackSummary() return traceback.StackSummary()
@ -895,7 +915,7 @@ class TracingContext:
stack = stack + [self._populate_loc_in_frame_summary()] stack = stack + [self._populate_loc_in_frame_summary()]
return traceback.StackSummary.from_list(stack) return traceback.StackSummary.from_list(stack)
def _populate_loc_in_frame_summary(self): def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary:
assert self.loc_in_frame is not None assert self.loc_in_frame is not None
filename, lineno, frame_name = self.loc_in_frame filename, lineno, frame_name = self.loc_in_frame
return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False) return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False)
@ -904,7 +924,7 @@ class TracingContext:
# associated with the current frame state # associated with the current frame state
@staticmethod @staticmethod
@contextlib.contextmanager @contextlib.contextmanager
def clear_frame(): def clear_frame() -> Generator[None, None, None]:
tc = TracingContext.get() tc = TracingContext.get()
with ( with (
unittest.mock.patch.object(tc, "frame_summary_stack", []), unittest.mock.patch.object(tc, "frame_summary_stack", []),
@ -936,7 +956,9 @@ class TracingContext:
@staticmethod @staticmethod
@contextlib.contextmanager @contextlib.contextmanager
def current_frame(frame_summary): def current_frame(
frame_summary: Optional[traceback.FrameSummary],
) -> Generator[None, None, None]:
# frame_summary can be None to solely take advantage of real_stack # frame_summary can be None to solely take advantage of real_stack
# attachment to thrown exceptions # attachment to thrown exceptions
tc = TracingContext.get() tc = TracingContext.get()
@ -957,7 +979,9 @@ class TracingContext:
@staticmethod @staticmethod
@contextlib.contextmanager @contextlib.contextmanager
def report_output_strides(): def report_output_strides() -> Generator[
Optional[list[Optional[tuple[int, ...]]]], None, None
]:
tc = TracingContext.try_get() tc = TracingContext.try_get()
if tc is None: if tc is None:
yield None yield None
@ -970,13 +994,13 @@ class TracingContext:
tc.output_strides = old_output_strides tc.output_strides = old_output_strides
@staticmethod @staticmethod
def set_current_loc(filename, lineno, frame_name): def set_current_loc(filename: str, lineno: int, frame_name: str) -> None:
# Save the current location in the frame. Lazily generate the # Save the current location in the frame. Lazily generate the
# framesummary. # framesummary.
TracingContext.get().loc_in_frame = (filename, lineno, frame_name) TracingContext.get().loc_in_frame = (filename, lineno, frame_name)
@staticmethod @staticmethod
def get_traced_code(): def get_traced_code() -> Optional[list[CodeType]]:
tc = TracingContext.try_get() tc = TracingContext.try_get()
if tc is None: if tc is None:
return None return None
@ -984,7 +1008,9 @@ class TracingContext:
@contextmanager @contextmanager
def compile_context(context: Optional[CompileContext]): def compile_context(
context: Optional[CompileContext],
) -> Generator[Optional[CompileContext], None, None]:
old_context = getattr(_TLS, "compile_context", None) old_context = getattr(_TLS, "compile_context", None)
_TLS.compile_context = context _TLS.compile_context = context
try: try:
@ -994,7 +1020,9 @@ def compile_context(context: Optional[CompileContext]):
@contextmanager @contextmanager
def tracing(context: Optional[TracingContext]): def tracing(
context: Optional[TracingContext],
) -> Generator[Optional[TracingContext], None, None]:
""" """
This function installs the passed in tracing context as a dynamic scoped This function installs the passed in tracing context as a dynamic scoped
global variable. global variable.
@ -1024,13 +1052,13 @@ def tracing(context: Optional[TracingContext]):
# TODO(voz): Consider a toplevel torch/_source.py # TODO(voz): Consider a toplevel torch/_source.py
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class Source: class Source:
def is_dict_key(self): def is_dict_key(self) -> bool:
return False return False
def is_ephemeral(self): def is_ephemeral(self) -> bool:
return False return False
def reconstruct(self, codegen): def reconstruct(self, codegen: PyCodegen) -> None:
raise NotImplementedError raise NotImplementedError
def guard_source(self) -> GuardSource: def guard_source(self) -> GuardSource:
@ -1039,7 +1067,7 @@ class Source:
def name(self) -> str: def name(self) -> str:
raise NotImplementedError raise NotImplementedError
def make_guard(self, fn) -> Guard: def make_guard(self, fn: Callable[..., Any]) -> Guard:
if self.guard_source() is GuardSource.CONSTANT: if self.guard_source() is GuardSource.CONSTANT:
raise NotImplementedError raise NotImplementedError
return Guard(self, fn) return Guard(self, fn)
@ -1047,7 +1075,7 @@ class Source:
def is_specialized_nn_module(self) -> bool: def is_specialized_nn_module(self) -> bool:
return self.guard_source().is_specialized_nn_module() return self.guard_source().is_specialized_nn_module()
def subguards_allowed(self): def subguards_allowed(self) -> bool:
"""True if you can guard on attributes of this""" """True if you can guard on attributes of this"""
return self.guard_source() != GuardSource.SYNTHETIC_LOCAL return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
@ -1057,11 +1085,11 @@ class Source:
class ChainedSource(Source): class ChainedSource(Source):
base: Source base: Source
def is_dict_key(self): def is_dict_key(self) -> bool:
# Recurse until you either hit a ConstDictKey or a Source # Recurse until you either hit a ConstDictKey or a Source
return self.base.is_dict_key() return self.base.is_dict_key()
def is_ephemeral(self): def is_ephemeral(self) -> bool:
return self.base.is_ephemeral() return self.base.is_ephemeral()
def get_base(self) -> Source: def get_base(self) -> Source:
@ -1071,7 +1099,7 @@ class ChainedSource(Source):
return current return current
def detect_fake_mode(inputs: Any = None): def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
""" """
Attempts to "detect" what the current fake mode is. If there is one ambiently Attempts to "detect" what the current fake mode is. If there is one ambiently
available from TracingContext, we preferentially use that. Otherwise, we available from TracingContext, we preferentially use that. Otherwise, we
@ -1115,7 +1143,7 @@ def detect_fake_mode(inputs: Any = None):
return None return None
def active_fake_mode(): def active_fake_mode() -> Optional[FakeTensorMode]:
""" """
Inspects the dispatch mode stack for an active fake mode and returns it. Inspects the dispatch mode stack for an active fake mode and returns it.
Returns None if no fake mode is active. Returns None if no fake mode is active.

View File

@ -465,6 +465,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
from torch._subclasses.fake_tensor import extract_tensor_metadata from torch._subclasses.fake_tensor import extract_tensor_metadata
fake_mode = detect_fake_mode(primals + filtered_grad_outs) fake_mode = detect_fake_mode(primals + filtered_grad_outs)
assert fake_mode is not None, "fake_mode should be enabled for HOPs"
state = _CacheKeyState(fake_mode.shape_env) state = _CacheKeyState(fake_mode.shape_env)
tangent_metadata: list[object] = [] tangent_metadata: list[object] = []
@ -607,6 +608,7 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands):
from torch._guards import detect_fake_mode from torch._guards import detect_fake_mode
fake_mode = detect_fake_mode(operands) fake_mode = detect_fake_mode(operands)
assert fake_mode is not None and fake_mode.shape_env is not None
insert_deferred_runtime_asserts( insert_deferred_runtime_asserts(
graph, graph,
fake_mode.shape_env, fake_mode.shape_env,

View File

@ -1,7 +1,7 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
import contextlib import contextlib
import functools import functools
from contextlib import contextmanager, ExitStack, nullcontext from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Optional, overload, TypeVar, Union from typing import Any, Callable, Optional, overload, TypeVar, Union
@ -266,11 +266,12 @@ def _set_compilation_env():
# The invariant here is that we always trace the branch with fake tensor # The invariant here is that we always trace the branch with fake tensor
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch): def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
fake_mode = detect_fake_mode(inputs) fake_mode_det = detect_fake_mode(inputs)
tracing_mode = "real" fake_mode: AbstractContextManager = nullcontext()
if fake_mode is None: tracing_mode = "fake"
fake_mode = nullcontext() if fake_mode_det is not None:
tracing_mode = "fake" fake_mode = fake_mode_det
tracing_mode = "real"
# Note: we need to turn off proxy tensor mode to avoid tracing infra # Note: we need to turn off proxy tensor mode to avoid tracing infra
# code that happens in make_fx e.g. we now call as_strided when wrapping tensor # code that happens in make_fx e.g. we now call as_strided when wrapping tensor
@ -282,9 +283,12 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
pre_dispatch=pre_dispatch, pre_dispatch=pre_dispatch,
_error_on_data_dependent_ops=False, _error_on_data_dependent_ops=False,
)(*inputs) )(*inputs)
if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: # type: ignore[attr-defined]
insert_deferred_runtime_asserts( insert_deferred_runtime_asserts(
gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True gm,
fake_mode.shape_env, # type: ignore[attr-defined]
"hoo_maybe_fake_tracing",
export=True, # type: ignore[attr-defined]
) )
return gm return gm

View File

@ -1065,7 +1065,7 @@ class GuardedCache(Generic[T]):
Helper to get the shape env from the tracing context. Helper to get the shape env from the tracing context.
""" """
ctx = torch._guards.TracingContext.try_get() ctx = torch._guards.TracingContext.try_get()
if not ctx: if not ctx or not ctx.fake_mode:
return None return None
return ctx.fake_mode.shape_env return ctx.fake_mode.shape_env

View File

@ -1942,7 +1942,7 @@ def fw_compiler_freezing(
idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node) idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node)
] ]
static_input_idxs = [] static_input_idxs: list[Any] = []
# constant params will be real tensors, not fake # constant params will be real tensors, not fake
tracing_context = torch._guards.TracingContext.try_get() tracing_context = torch._guards.TracingContext.try_get()
unwrapped_args_offsets = [0] unwrapped_args_offsets = [0]
@ -2461,6 +2461,7 @@ def compile_fx(
if node.op == "get_attr" and "val" not in node.meta: if node.op == "get_attr" and "val" not in node.meta:
target = attrgetter(node.target)(gm) target = attrgetter(node.target)(gm)
if isinstance(target, torch.Tensor): if isinstance(target, torch.Tensor):
assert fake_mode is not None
node.meta["val"] = fake_mode.from_tensor( node.meta["val"] = fake_mode.from_tensor(
target, static_shapes=True target, static_shapes=True
) )

View File

@ -1429,7 +1429,9 @@ def register_replacement(
) )
sym_args: list[torch.SymInt] = [] sym_args: list[torch.SymInt] = []
with torch._dynamo.utils.detect_fake_mode(args): fake_mode = torch._dynamo.utils.detect_fake_mode(args)
assert fake_mode is not None
with fake_mode:
for i, grad in enumerate(requires_grad): for i, grad in enumerate(requires_grad):
if isinstance(args[i], torch.Tensor): if isinstance(args[i], torch.Tensor):
if grad and is_integer_dtype(args[i].dtype): if grad and is_integer_dtype(args[i].dtype):

View File

@ -203,6 +203,7 @@ def standalone_compile(
# Reuse fake_mode from the TracingContext. # Reuse fake_mode from the TracingContext.
# NB: The TracingContext only exists if we're currently in a torch.compile backend. # NB: The TracingContext only exists if we're currently in a torch.compile backend.
context = torch._guards.TracingContext.get() context = torch._guards.TracingContext.get()
assert context.fake_mode is not None
fake_mode = context.fake_mode fake_mode = context.fake_mode
elif dynamic_shapes == "from_graph": elif dynamic_shapes == "from_graph":
fake_mode = FakeTensorMode(shape_env=ShapeEnv()) fake_mode = FakeTensorMode(shape_env=ShapeEnv())

View File

@ -2719,10 +2719,9 @@ def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[N
return contextlib.nullcontext() return contextlib.nullcontext()
# In standalone inductor compile mode, we might not have a shape_env attached to the fake mode # In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
shape_env = tracing_context.fake_mode.shape_env if not tracing_context.fake_mode or not tracing_context.fake_mode.shape_env:
if not shape_env:
return contextlib.nullcontext() return contextlib.nullcontext()
shape_env = tracing_context.fake_mode.shape_env
return shape_env.suppress_guards() return shape_env.suppress_guards()
@ -3245,12 +3244,13 @@ def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
for i, e in enumerate(row): for i, e in enumerate(row):
widths[i] = max(widths[i], len(str(e))) widths[i] = max(widths[i], len(str(e)))
lines = [] lines = []
lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) # Need nested {} for string formatting; ignore SET_LINTER here
lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) # noqa: set_linter
# widths whitespace horizontal separators # widths whitespace horizontal separators
total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1) total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
lines.append("-" * total_width) lines.append("-" * total_width)
for row in elements: for row in elements:
lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) # noqa: set_linter
return "\n".join(lines) return "\n".join(lines)

View File

@ -5,7 +5,7 @@ import operator
import typing import typing
import warnings import warnings
from collections.abc import Sequence from collections.abc import Sequence
from contextlib import nullcontext from contextlib import AbstractContextManager, nullcontext
from enum import Enum from enum import Enum
from functools import reduce from functools import reduce
from typing import ( from typing import (
@ -2087,7 +2087,9 @@ def alert_not_deterministic(caller: str):
class CUDARngStateHelper: class CUDARngStateHelper:
@staticmethod @staticmethod
def get_torch_state_as_tuple(fake_mode=nullcontext()): def get_torch_state_as_tuple(
fake_mode: AbstractContextManager[Any] = nullcontext(),
):
if not torch.cuda.is_available(): if not torch.cuda.is_available():
raise RuntimeError("CUDA not available") raise RuntimeError("CUDA not available")

View File

@ -573,8 +573,8 @@ def _decompose_and_get_gm_with_new_signature_constants(
delattr(ep.graph_module, name) delattr(ep.graph_module, name)
# TODO(zhxhchen17) Return the new graph_signature directly. # TODO(zhxhchen17) Return the new graph_signature directly.
fake_mode = detect_fake_mode(fake_args) fake_mode_det = detect_fake_mode(fake_args)
fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode # type: ignore[assignment] fake_mode_ctx = contextlib.nullcontext() if fake_mode_det is None else fake_mode_det # type: ignore[assignment]
custom_triton_ops_decomposition_ctx = ( custom_triton_ops_decomposition_ctx = (
contextlib.nullcontext contextlib.nullcontext
if decompose_custom_triton_ops if decompose_custom_triton_ops
@ -582,7 +582,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
) )
with ( with (
_ignore_backend_decomps(), _ignore_backend_decomps(),
fake_mode, fake_mode_ctx,
_override_composite_implicit_decomp(cia_to_decomp), _override_composite_implicit_decomp(cia_to_decomp),
custom_triton_ops_decomposition_ctx(), custom_triton_ops_decomposition_ctx(),
): ):

View File

@ -7870,7 +7870,9 @@ class PropagateUnbackedSymInts(torch.fx.Interpreter):
from torch._guards import detect_fake_mode from torch._guards import detect_fake_mode
result = super().run_node(n) result = super().run_node(n)
rebind_unbacked(detect_fake_mode().shape_env, n, result) fake_mode = detect_fake_mode()
assert fake_mode is not None
rebind_unbacked(fake_mode.shape_env, n, result)
return result return result

View File

@ -1028,7 +1028,7 @@ def bound_sympy(
# If there's a tracing context, augment available constrained ranges. # If there's a tracing context, augment available constrained ranges.
context = torch._guards.TracingContext.try_get() context = torch._guards.TracingContext.try_get()
if context and context.fake_mode.shape_env: if context and context.fake_mode and context.fake_mode.shape_env:
if ranges: if ranges:
ranges = {**context.fake_mode.shape_env.var_to_range, **ranges} ranges = {**context.fake_mode.shape_env.var_to_range, **ranges}
else: else: