mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: X-link: https://github.com/pytorch/executorch/pull/12986 As part of better engineering week, we would like to improve out type support to improve dev experience in dynamo This PR adds strict typing support to a critical set of files for dynamo, `source.py` and the base `_guards.py` Running ``` mypy torch/_dynamo/source.py torch/_guards.py --linecount-report /tmp/coverage_log ``` | -------- | Lines Unannotated | Lines Total | % lines covered | Funcs Unannotated | Funcs Total | % funcs covered | | -------- | ------- | -------- | ------- | ------- | ------- | ------- | | Main | 1227 | 2208 | 55.57% | 207 | 362 | 57.18% | | This PR | 2217 | 2217 | 100.00% | 362 | 362 | 100.00% | | Delta | +990 | +9 | +44.43% | +155 | 0 | +42.82% | cc jgong5 mingfeima XiaobingSuper sanchitintel ashokei jingxu10 jerryzh168 voznesenskym penguinwu EikanWang Guobing-Chen zhuhaozhe blzheng wenzhe-nrv jiayisunx ipiszy chenyang78 kadeng muchulee8 amjames chauhang aakhundov coconutruben Test Plan: Imported from GitHub, without a `Test Plan:` line. Rollback Plan: Reviewed By: JacobSzwejbka, yangw-dev Differential Revision: D79199389 Pulled By: Lucaskabela Pull Request resolved: https://github.com/pytorch/pytorch/pull/159491 Approved by: https://github.com/anijain2305, https://github.com/yangw-dev
This commit is contained in:
parent
1293405c8d
commit
2b1ae29960
|
|
@ -1848,7 +1848,7 @@ def export(
|
||||||
ignore_fresh_unbacked = null_context()
|
ignore_fresh_unbacked = null_context()
|
||||||
assert ambient_fake_mode is not None
|
assert ambient_fake_mode is not None
|
||||||
if shape_env := ambient_fake_mode.shape_env:
|
if shape_env := ambient_fake_mode.shape_env:
|
||||||
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols()
|
ignore_fresh_unbacked = shape_env.ignore_fresh_unbacked_symbols() # type: ignore[assignment]
|
||||||
|
|
||||||
with (
|
with (
|
||||||
ambient_fake_mode,
|
ambient_fake_mode,
|
||||||
|
|
@ -1900,7 +1900,9 @@ def export(
|
||||||
fakify_with_ambient, graph_inputs
|
fakify_with_ambient, graph_inputs
|
||||||
)
|
)
|
||||||
graph_captured_result = torch.func.functional_call(
|
graph_captured_result = torch.func.functional_call(
|
||||||
graph, fake_params_buffers, fake_graph_inputs
|
graph,
|
||||||
|
fake_params_buffers, # type: ignore[arg-type]
|
||||||
|
fake_graph_inputs, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
return graph_captured_result
|
return graph_captured_result
|
||||||
|
|
|
||||||
|
|
@ -3685,7 +3685,7 @@ def strip_local_scope(s: str) -> str:
|
||||||
def get_guard_fail_reason_helper(
|
def get_guard_fail_reason_helper(
|
||||||
guard_manager: GuardFn,
|
guard_manager: GuardFn,
|
||||||
f_locals: dict[str, object],
|
f_locals: dict[str, object],
|
||||||
compile_id: CompileId,
|
compile_id: Optional[CompileId],
|
||||||
) -> str:
|
) -> str:
|
||||||
"""
|
"""
|
||||||
Return the reason why `guard_manager` failed.
|
Return the reason why `guard_manager` failed.
|
||||||
|
|
|
||||||
|
|
@ -809,6 +809,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def shape_env(self):
|
def shape_env(self):
|
||||||
|
assert self.tracing_context.fake_mode is not None
|
||||||
return self.tracing_context.fake_mode.shape_env
|
return self.tracing_context.fake_mode.shape_env
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -1691,6 +1692,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||||
)
|
)
|
||||||
self.call_cleanup_hooks()
|
self.call_cleanup_hooks()
|
||||||
old_fake_mode = self.tracing_context.fake_mode
|
old_fake_mode = self.tracing_context.fake_mode
|
||||||
|
assert old_fake_mode is not None
|
||||||
if not self.export:
|
if not self.export:
|
||||||
import torch._functorch.config as _config
|
import torch._functorch.config as _config
|
||||||
|
|
||||||
|
|
@ -1738,6 +1740,7 @@ class OutputGraph(OutputGraphGuardsState):
|
||||||
)
|
)
|
||||||
|
|
||||||
counters["stats"]["unique_graphs"] += 1
|
counters["stats"]["unique_graphs"] += 1
|
||||||
|
assert old_fake_mode.shape_env is not None
|
||||||
if specializations := old_fake_mode.shape_env.specializations:
|
if specializations := old_fake_mode.shape_env.specializations:
|
||||||
specialization_guards = []
|
specialization_guards = []
|
||||||
specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
|
specialization_cache: dict[Specialization, Callable[[Any], Any]] = {}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,3 @@
|
||||||
# mypy: allow-untyped-defs
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
This module provides Source classes that track the origins of values in PyTorch Dynamo.
|
This module provides Source classes that track the origins of values in PyTorch Dynamo.
|
||||||
Sources represent where values come from (e.g. local variables, globals, attributes) and
|
Sources represent where values come from (e.g. local variables, globals, attributes) and
|
||||||
|
|
@ -22,9 +20,9 @@ the code needed to recreate values.
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import enum
|
import enum
|
||||||
import functools
|
import functools
|
||||||
from typing import Any, Optional, TYPE_CHECKING, Union
|
from typing import Any, Callable, Optional, TYPE_CHECKING, Union
|
||||||
|
|
||||||
from torch._guards import ChainedSource, GuardSource, Source
|
from torch._guards import ChainedSource, Guard, GuardSource, Source
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from .bytecode_transformation import create_call_function, create_instruction
|
from .bytecode_transformation import create_call_function, create_instruction
|
||||||
|
|
@ -96,7 +94,7 @@ _GUARD_SOURCE_FSDP_MODULE = {
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def is_constant_source(source):
|
def is_constant_source(source: Source) -> bool:
|
||||||
if isinstance(source, ConstantSource):
|
if isinstance(source, ConstantSource):
|
||||||
return True
|
return True
|
||||||
try:
|
try:
|
||||||
|
|
@ -124,16 +122,16 @@ class LocalSource(Source):
|
||||||
# or `co_freevars`.
|
# or `co_freevars`.
|
||||||
is_derefed_cell_contents: bool = False
|
is_derefed_cell_contents: bool = False
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
if self.is_derefed_cell_contents:
|
if self.is_derefed_cell_contents:
|
||||||
codegen.load_deref(self.local_name)
|
codegen.load_deref(self.local_name)
|
||||||
else:
|
else:
|
||||||
codegen.append_output(codegen.create_load(self.local_name))
|
codegen.append_output(codegen.create_load(self.local_name))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.LOCAL
|
return GuardSource.LOCAL
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"L[{repr(self.local_name)}]"
|
return f"L[{repr(self.local_name)}]"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -141,13 +139,13 @@ class LocalSource(Source):
|
||||||
class SyntheticLocalSource(Source):
|
class SyntheticLocalSource(Source):
|
||||||
local_name: str
|
local_name: str
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.append_output(codegen.create_load(self.local_name))
|
codegen.append_output(codegen.create_load(self.local_name))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.SYNTHETIC_LOCAL
|
return GuardSource.SYNTHETIC_LOCAL
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
|
return f"SYNTHETIC_LOCAL[{self.local_name!r}]"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -155,15 +153,15 @@ class SyntheticLocalSource(Source):
|
||||||
class RandomValueSource(Source):
|
class RandomValueSource(Source):
|
||||||
random_call_index: int
|
random_call_index: int
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.RANDOM_VALUE
|
return GuardSource.RANDOM_VALUE
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
|
codegen.append_output(codegen.create_load(codegen.tx.output.random_values_var))
|
||||||
codegen.append_output(codegen.create_load_const(self.random_call_index))
|
codegen.append_output(codegen.create_load_const(self.random_call_index))
|
||||||
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"random_value_{self.random_call_index}"
|
return f"random_value_{self.random_call_index}"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -171,13 +169,13 @@ class RandomValueSource(Source):
|
||||||
class GlobalSource(Source):
|
class GlobalSource(Source):
|
||||||
global_name: str
|
global_name: str
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.append_output(codegen.create_load_global(self.global_name, add=True))
|
codegen.append_output(codegen.create_load_global(self.global_name, add=True))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.GLOBAL
|
return GuardSource.GLOBAL
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"G[{repr(self.global_name)}]"
|
return f"G[{repr(self.global_name)}]"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -185,7 +183,7 @@ class GlobalSource(Source):
|
||||||
class GlobalWeakRefSource(Source):
|
class GlobalWeakRefSource(Source):
|
||||||
global_name: str
|
global_name: str
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(
|
codegen.add_push_null(
|
||||||
lambda: codegen.append_output(
|
lambda: codegen.append_output(
|
||||||
codegen.create_load_global(self.global_name, add=True)
|
codegen.create_load_global(self.global_name, add=True)
|
||||||
|
|
@ -193,23 +191,23 @@ class GlobalWeakRefSource(Source):
|
||||||
)
|
)
|
||||||
codegen.extend_output(create_call_function(0, False))
|
codegen.extend_output(create_call_function(0, False))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.GLOBAL
|
return GuardSource.GLOBAL
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"G[{repr(self.global_name)}]()"
|
return f"G[{repr(self.global_name)}]()"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class WeakRefCallSource(ChainedSource):
|
class WeakRefCallSource(ChainedSource):
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(lambda: codegen(self.base))
|
codegen.add_push_null(lambda: codegen(self.base))
|
||||||
codegen.extend_output(create_call_function(0, False))
|
codegen.extend_output(create_call_function(0, False))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"{self.base.name()}()"
|
return f"{self.base.name()}()"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -222,7 +220,7 @@ class CallFunctionNoArgsSource(WeakRefCallSource):
|
||||||
class AttrSource(ChainedSource):
|
class AttrSource(ChainedSource):
|
||||||
member: str
|
member: str
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base, "Can't construct an AttrSource without a valid base source"
|
assert self.base, "Can't construct an AttrSource without a valid base source"
|
||||||
if "." in self.member:
|
if "." in self.member:
|
||||||
member_parts = self.member.split(".")
|
member_parts = self.member.split(".")
|
||||||
|
|
@ -231,14 +229,14 @@ class AttrSource(ChainedSource):
|
||||||
)
|
)
|
||||||
object.__setattr__(self, "member", member_parts[-1])
|
object.__setattr__(self, "member", member_parts[-1])
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
codegen.extend_output(codegen.create_load_attrs(self.member))
|
codegen.extend_output(codegen.create_load_attrs(self.member))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
if not self.member.isidentifier():
|
if not self.member.isidentifier():
|
||||||
return f"getattr({self.base.name()}, {self.member!r})"
|
return f"getattr({self.base.name()}, {self.member!r})"
|
||||||
return f"{self.base.name()}.{self.member}"
|
return f"{self.base.name()}.{self.member}"
|
||||||
|
|
@ -248,7 +246,7 @@ class AttrSource(ChainedSource):
|
||||||
class GenericAttrSource(ChainedSource):
|
class GenericAttrSource(ChainedSource):
|
||||||
member: str
|
member: str
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base, "Can't construct an AttrSource without a valid base source"
|
assert self.base, "Can't construct an AttrSource without a valid base source"
|
||||||
if "." in self.member:
|
if "." in self.member:
|
||||||
member_parts = self.member.split(".")
|
member_parts = self.member.split(".")
|
||||||
|
|
@ -257,14 +255,14 @@ class GenericAttrSource(ChainedSource):
|
||||||
)
|
)
|
||||||
object.__setattr__(self, "member", member_parts[-1])
|
object.__setattr__(self, "member", member_parts[-1])
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
codegen.extend_output(codegen.create_load_attrs(self.member))
|
codegen.extend_output(codegen.create_load_attrs(self.member))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"object.__getattribute__({self.base.name()}, {self.member!r})"
|
return f"object.__getattribute__({self.base.name()}, {self.member!r})"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -277,7 +275,7 @@ class LocalCellSource(Source):
|
||||||
|
|
||||||
local_name: str
|
local_name: str
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
|
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
|
||||||
# Dynamo's bytecode transformation differentiates them slightly, so we
|
# Dynamo's bytecode transformation differentiates them slightly, so we
|
||||||
# always emit `LOAD_CLOSURE` here.
|
# always emit `LOAD_CLOSURE` here.
|
||||||
|
|
@ -295,20 +293,20 @@ class LocalCellSource(Source):
|
||||||
class GradSource(ChainedSource):
|
class GradSource(ChainedSource):
|
||||||
member: str = "grad"
|
member: str = "grad"
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
codegen.extend_output(codegen.create_load_attrs(self.member))
|
codegen.extend_output(codegen.create_load_attrs(self.member))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"{self.base.name()}.{self.member}"
|
return f"{self.base.name()}.{self.member}"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class ParamBufferSource(AttrSource):
|
class ParamBufferSource(AttrSource):
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
|
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -331,16 +329,16 @@ class UnspecializedParamBufferSource(AttrSource):
|
||||||
class EphemeralSource(Source):
|
class EphemeralSource(Source):
|
||||||
desc: Optional[str] = None
|
desc: Optional[str] = None
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.EPHEMERAL
|
return GuardSource.EPHEMERAL
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
|
return f"<ephemeral{': ' + self.desc if self.desc is not None else ''}>"
|
||||||
|
|
||||||
def make_guard(self, fn):
|
def make_guard(self, fn: Callable[..., Any]) -> Guard:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def is_ephemeral(self):
|
def is_ephemeral(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -349,13 +347,15 @@ class TensorProperty(enum.Enum):
|
||||||
STRIDE = 1
|
STRIDE = 1
|
||||||
STORAGE_OFFSET = 2
|
STORAGE_OFFSET = 2
|
||||||
|
|
||||||
def method_name(self):
|
def method_name(self) -> str:
|
||||||
if self is TensorProperty.SIZE:
|
if self is TensorProperty.SIZE:
|
||||||
return "size"
|
return "size"
|
||||||
elif self is TensorProperty.STRIDE:
|
elif self is TensorProperty.STRIDE:
|
||||||
return "stride"
|
return "stride"
|
||||||
elif self is TensorProperty.STORAGE_OFFSET:
|
elif self is TensorProperty.STORAGE_OFFSET:
|
||||||
return "storage_offset"
|
return "storage_offset"
|
||||||
|
else:
|
||||||
|
raise AssertionError(f"unhandled {self}")
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
|
@ -363,14 +363,14 @@ class TensorPropertySource(ChainedSource):
|
||||||
prop: TensorProperty
|
prop: TensorProperty
|
||||||
idx: Optional[int] = None # None for STORAGE_OFFSET
|
idx: Optional[int] = None # None for STORAGE_OFFSET
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base is not None
|
assert self.base is not None
|
||||||
if self.prop is TensorProperty.STORAGE_OFFSET:
|
if self.prop is TensorProperty.STORAGE_OFFSET:
|
||||||
assert self.idx is None
|
assert self.idx is None
|
||||||
else:
|
else:
|
||||||
assert self.idx is not None
|
assert self.idx is not None
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(
|
codegen.add_push_null(
|
||||||
lambda: codegen.load_import_from(
|
lambda: codegen.load_import_from(
|
||||||
utils.__name__, f"call_{self.prop.method_name()}"
|
utils.__name__, f"call_{self.prop.method_name()}"
|
||||||
|
|
@ -384,10 +384,10 @@ class TensorPropertySource(ChainedSource):
|
||||||
create_call_function(2 if self.idx is not None else 1, False)
|
create_call_function(2 if self.idx is not None else 1, False)
|
||||||
)
|
)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
if self.prop is TensorProperty.SIZE:
|
if self.prop is TensorProperty.SIZE:
|
||||||
return f"{self.base.name()}.size()[{self.idx}]"
|
return f"{self.base.name()}.size()[{self.idx}]"
|
||||||
elif self.prop is TensorProperty.STRIDE:
|
elif self.prop is TensorProperty.STRIDE:
|
||||||
|
|
@ -403,88 +403,88 @@ class TensorPropertySource(ChainedSource):
|
||||||
class IndexedSource(ChainedSource):
|
class IndexedSource(ChainedSource):
|
||||||
idx: int
|
idx: int
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base is not None
|
assert self.base is not None
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"({self.idx}, {self.base.name()})"
|
return f"({self.idx}, {self.base.name()})"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class NegateSource(ChainedSource):
|
class NegateSource(ChainedSource):
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base is not None
|
assert self.base is not None
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
# NB: use method call so that function stripping regexes work
|
# NB: use method call so that function stripping regexes work
|
||||||
return f"{self.base.name()}.__neg__()"
|
return f"{self.base.name()}.__neg__()"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class ConvertIntSource(ChainedSource):
|
class ConvertIntSource(ChainedSource):
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base is not None
|
assert self.base is not None
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"cast_symbool_to_symint_guardless({self.base.name()})"
|
return f"cast_symbool_to_symint_guardless({self.base.name()})"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class FlattenScriptObjectSource(ChainedSource):
|
class FlattenScriptObjectSource(ChainedSource):
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base is not None
|
assert self.base is not None
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"{self.base.name()}.__obj_flatten__()"
|
return f"{self.base.name()}.__obj_flatten__()"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class ScriptObjectQualifiedNameSource(ChainedSource):
|
class ScriptObjectQualifiedNameSource(ChainedSource):
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base is not None
|
assert self.base is not None
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"{self.base.name()}._type().qualified_name()"
|
return f"{self.base.name()}._type().qualified_name()"
|
||||||
|
|
||||||
|
|
||||||
class AttrProxySource(ChainedSource):
|
class AttrProxySource(ChainedSource):
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"{self.base.name()}.get_base()"
|
return f"{self.base.name()}.get_base()"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -495,7 +495,7 @@ class DefaultsSource(ChainedSource):
|
||||||
field: str = dataclasses.field(init=False, repr=False, compare=False)
|
field: str = dataclasses.field(init=False, repr=False, compare=False)
|
||||||
_name: str = dataclasses.field(init=False, repr=False, compare=False)
|
_name: str = dataclasses.field(init=False, repr=False, compare=False)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base, (
|
assert self.base, (
|
||||||
"Base must be a valid source in order to properly track and guard this Defaults to its origin."
|
"Base must be a valid source in order to properly track and guard this Defaults to its origin."
|
||||||
)
|
)
|
||||||
|
|
@ -512,16 +512,16 @@ class DefaultsSource(ChainedSource):
|
||||||
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
|
self, "_name", f"{self.base.name()}.{self.field}[{self.idx_key}]"
|
||||||
)
|
)
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
codegen.extend_output(codegen.create_load_attrs(self.field))
|
codegen.extend_output(codegen.create_load_attrs(self.field))
|
||||||
codegen.append_output(codegen.create_load_const(self.idx_key))
|
codegen.append_output(codegen.create_load_const(self.idx_key))
|
||||||
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -530,14 +530,14 @@ class GetItemSource(ChainedSource):
|
||||||
index: Any
|
index: Any
|
||||||
index_is_slice: bool = False
|
index_is_slice: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base is not None
|
assert self.base is not None
|
||||||
if isinstance(self.index, slice):
|
if isinstance(self.index, slice):
|
||||||
# store the hashable version of the slice so the whole GetItemSource is hashable
|
# store the hashable version of the slice so the whole GetItemSource is hashable
|
||||||
super().__setattr__("index", self.index.__reduce__())
|
super().__setattr__("index", self.index.__reduce__())
|
||||||
super().__setattr__("index_is_slice", True)
|
super().__setattr__("index_is_slice", True)
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
if self.index_is_slice:
|
if self.index_is_slice:
|
||||||
codegen.append_output(codegen.create_load_const(self.unpack_slice()))
|
codegen.append_output(codegen.create_load_const(self.unpack_slice()))
|
||||||
|
|
@ -545,15 +545,15 @@ class GetItemSource(ChainedSource):
|
||||||
codegen.append_output(codegen.create_load_const(self.index))
|
codegen.append_output(codegen.create_load_const(self.index))
|
||||||
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def unpack_slice(self):
|
def unpack_slice(self) -> slice:
|
||||||
assert self.index_is_slice
|
assert self.index_is_slice
|
||||||
slice_class, slice_args = self.index
|
slice_class, slice_args = self.index
|
||||||
return slice_class(*slice_args)
|
return slice_class(*slice_args)
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
# Index can be of following types
|
# Index can be of following types
|
||||||
# 1) index is a slice - example 1:4
|
# 1) index is a slice - example 1:4
|
||||||
# 2) index is a constant - example string, integer
|
# 2) index is a constant - example string, integer
|
||||||
|
|
@ -568,10 +568,10 @@ class GetItemSource(ChainedSource):
|
||||||
class ConstDictKeySource(ChainedSource):
|
class ConstDictKeySource(ChainedSource):
|
||||||
index: Any
|
index: Any
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(
|
codegen.add_push_null(
|
||||||
lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
|
lambda: codegen.load_import_from(utils.__name__, "dict_keys_getitem")
|
||||||
)
|
)
|
||||||
|
|
@ -579,11 +579,11 @@ class ConstDictKeySource(ChainedSource):
|
||||||
codegen.append_output(codegen.create_load_const(self.index))
|
codegen.append_output(codegen.create_load_const(self.index))
|
||||||
codegen.extend_output(create_call_function(2, False))
|
codegen.extend_output(create_call_function(2, False))
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
# The list creation will be CSE'd by PyExprCSEPass
|
# The list creation will be CSE'd by PyExprCSEPass
|
||||||
return f"list(dict.keys({self.base.name()}))[{self.index!r}]"
|
return f"list(dict.keys({self.base.name()}))[{self.index!r}]"
|
||||||
|
|
||||||
def is_dict_key(self):
|
def is_dict_key(self) -> bool:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -591,15 +591,15 @@ class ConstDictKeySource(ChainedSource):
|
||||||
class NonSerializableSetGetItemSource(ChainedSource):
|
class NonSerializableSetGetItemSource(ChainedSource):
|
||||||
index: int
|
index: int
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
from .variables import ConstantVariable
|
from .variables import ConstantVariable
|
||||||
|
|
||||||
assert ConstantVariable.is_literal(self.index)
|
assert ConstantVariable.is_literal(self.index)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(
|
codegen.add_push_null(
|
||||||
lambda: codegen.load_import_from(utils.__name__, "set_getitem")
|
lambda: codegen.load_import_from(utils.__name__, "set_getitem")
|
||||||
)
|
)
|
||||||
|
|
@ -607,11 +607,11 @@ class NonSerializableSetGetItemSource(ChainedSource):
|
||||||
codegen.append_output(codegen.create_load_const(self.index))
|
codegen.append_output(codegen.create_load_const(self.index))
|
||||||
codegen.extend_output(create_call_function(2, False))
|
codegen.extend_output(create_call_function(2, False))
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
# set ordering might not be stable
|
# set ordering might not be stable
|
||||||
return f"list({self.base.name()})[{self.index!r}]"
|
return f"list({self.base.name()})[{self.index!r}]"
|
||||||
|
|
||||||
def is_dict_key(self):
|
def is_dict_key(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -623,17 +623,17 @@ class DictGetItemSource(ChainedSource):
|
||||||
# 2) constant - like string, integer
|
# 2) constant - like string, integer
|
||||||
index: Any
|
index: Any
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
from .variables import ConstantVariable
|
from .variables import ConstantVariable
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
self.index, ConstDictKeySource
|
self.index, ConstDictKeySource
|
||||||
) or ConstantVariable.is_literal(self.index)
|
) or ConstantVariable.is_literal(self.index)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
# Load dict
|
# Load dict
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
|
|
||||||
|
|
@ -644,7 +644,7 @@ class DictGetItemSource(ChainedSource):
|
||||||
codegen.append_output(codegen.create_load_const(self.index))
|
codegen.append_output(codegen.create_load_const(self.index))
|
||||||
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
codegen.append_output(create_instruction("BINARY_SUBSCR"))
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
if isinstance(self.index, ConstDictKeySource):
|
if isinstance(self.index, ConstDictKeySource):
|
||||||
return f"{self.base.name()}[{self.index.name()}]"
|
return f"{self.base.name()}[{self.index.name()}]"
|
||||||
else:
|
else:
|
||||||
|
|
@ -660,17 +660,17 @@ class DictSubclassGetItemSource(ChainedSource):
|
||||||
# 2) constant - like string, integer
|
# 2) constant - like string, integer
|
||||||
index: Any
|
index: Any
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
from .variables import ConstantVariable
|
from .variables import ConstantVariable
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
self.index, ConstDictKeySource
|
self.index, ConstDictKeySource
|
||||||
) or ConstantVariable.is_literal(self.index)
|
) or ConstantVariable.is_literal(self.index)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
# reconstruct dict.__getitem__(dct, key)
|
# reconstruct dict.__getitem__(dct, key)
|
||||||
|
|
||||||
# Load dict.__getitem__
|
# Load dict.__getitem__
|
||||||
|
|
@ -689,7 +689,7 @@ class DictSubclassGetItemSource(ChainedSource):
|
||||||
|
|
||||||
codegen.extend_output(create_call_function(2, False))
|
codegen.extend_output(create_call_function(2, False))
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
if isinstance(self.index, ConstDictKeySource):
|
if isinstance(self.index, ConstDictKeySource):
|
||||||
return f"dict.__getitem__({self.base.name()}, {self.index.name()})"
|
return f"dict.__getitem__({self.base.name()}, {self.index.name()})"
|
||||||
else:
|
else:
|
||||||
|
|
@ -702,7 +702,7 @@ class ListGetItemSource(GetItemSource):
|
||||||
Same as GetItemSource with reconstruct and name overridden to be list specific.
|
Same as GetItemSource with reconstruct and name overridden to be list specific.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
# Reconstruct list.__getitem__(lst, index) to avoid any side effects
|
# Reconstruct list.__getitem__(lst, index) to avoid any side effects
|
||||||
# from possibly overridden __getitem__.
|
# from possibly overridden __getitem__.
|
||||||
|
|
||||||
|
|
@ -724,7 +724,7 @@ class ListGetItemSource(GetItemSource):
|
||||||
|
|
||||||
codegen.extend_output(create_call_function(2, False))
|
codegen.extend_output(create_call_function(2, False))
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
# Index can be of following types
|
# Index can be of following types
|
||||||
# 1) index is a slice - example 1:4
|
# 1) index is a slice - example 1:4
|
||||||
# 2) index is a constant - example string, integer
|
# 2) index is a constant - example string, integer
|
||||||
|
|
@ -739,7 +739,7 @@ class ListGetItemSource(GetItemSource):
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class TupleIteratorGetItemSource(GetItemSource):
|
class TupleIteratorGetItemSource(GetItemSource):
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(
|
codegen.add_push_null(
|
||||||
lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
|
lambda: codegen.load_import_from(utils.__name__, "tuple_iterator_getitem")
|
||||||
)
|
)
|
||||||
|
|
@ -747,91 +747,91 @@ class TupleIteratorGetItemSource(GetItemSource):
|
||||||
codegen.append_output(codegen.create_load_const(self.index))
|
codegen.append_output(codegen.create_load_const(self.index))
|
||||||
codegen.extend_output(create_call_function(2, False))
|
codegen.extend_output(create_call_function(2, False))
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
|
return f"___tuple_iterator_getitem({self.base.name()}, {self.index!r})"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class DataclassFieldsSource(ChainedSource):
|
class DataclassFieldsSource(ChainedSource):
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(
|
codegen.add_push_null(
|
||||||
lambda: codegen.load_import_from(utils.__name__, "dataclass_fields")
|
lambda: codegen.load_import_from(utils.__name__, "dataclass_fields")
|
||||||
)
|
)
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
codegen.extend_output(create_call_function(1, False))
|
codegen.extend_output(create_call_function(1, False))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"___dataclass_fields({self.base.name()})"
|
return f"___dataclass_fields({self.base.name()})"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class TypeSource(ChainedSource):
|
class TypeSource(ChainedSource):
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.base is not None
|
assert self.base is not None
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
|
codegen.add_push_null(lambda: codegen.load_import_from("builtins", "type"))
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
codegen.extend_output(create_call_function(1, False))
|
codegen.extend_output(create_call_function(1, False))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"type({self.base.name()})"
|
return f"type({self.base.name()})"
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class OptimizerSource(ChainedSource):
|
class OptimizerSource(ChainedSource):
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return self.base.name()
|
return self.base.name()
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class NNModuleSource(ChainedSource):
|
class NNModuleSource(ChainedSource):
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
|
return _GUARD_SOURCE_SPECIALIZED_NN_MODULE[self.base.guard_source()]
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return self.base.name()
|
return self.base.name()
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class UnspecializedNNModuleSource(NNModuleSource):
|
class UnspecializedNNModuleSource(NNModuleSource):
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()]
|
return _GUARD_SOURCE_UNSPECIALIZED_NN_MODULE[self.base.guard_source()]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
|
class UnspecializedBuiltinNNModuleSource(UnspecializedNNModuleSource):
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()]
|
return _GUARD_SOURCE_UNSPECIALIZED_BUILTIN_NN_MODULE[self.base.guard_source()]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class FSDPNNModuleSource(NNModuleSource):
|
class FSDPNNModuleSource(NNModuleSource):
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()]
|
return _GUARD_SOURCE_FSDP_MODULE[self.base.guard_source()]
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class GlobalStateSource(Source):
|
class GlobalStateSource(Source):
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.GLOBAL
|
return GuardSource.GLOBAL
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -840,16 +840,16 @@ class TorchSource(Source):
|
||||||
"""Points to the actual `torch` module - used instead of GlobalSource
|
"""Points to the actual `torch` module - used instead of GlobalSource
|
||||||
in case the user has overridden `torch` in their local namespace"""
|
in case the user has overridden `torch` in their local namespace"""
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
from .guards import GuardBuilder, install_guard
|
from .guards import GuardBuilder, install_guard
|
||||||
|
|
||||||
install_guard(self.make_guard(GuardBuilder.ID_MATCH))
|
install_guard(self.make_guard(GuardBuilder.ID_MATCH))
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return "__import__('torch')"
|
return "__import__('torch')"
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.extend_output(
|
codegen.extend_output(
|
||||||
[
|
[
|
||||||
codegen.create_load_const(0), # level
|
codegen.create_load_const(0), # level
|
||||||
|
|
@ -858,7 +858,7 @@ class TorchSource(Source):
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.GLOBAL
|
return GuardSource.GLOBAL
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -866,15 +866,15 @@ class TorchSource(Source):
|
||||||
class TorchFunctionModeStackSource(Source):
|
class TorchFunctionModeStackSource(Source):
|
||||||
ind: int
|
ind: int
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return f"___get_torch_function_mode_stack_at({self._get_index()})"
|
return f"___get_torch_function_mode_stack_at({self._get_index()})"
|
||||||
|
|
||||||
def _get_index(self):
|
def _get_index(self) -> int:
|
||||||
from .variables.torch_function import TorchFunctionModeStackVariable
|
from .variables.torch_function import TorchFunctionModeStackVariable
|
||||||
|
|
||||||
return TorchFunctionModeStackVariable.get_mode_index(self.ind)
|
return TorchFunctionModeStackVariable.get_mode_index(self.ind)
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(
|
codegen.add_push_null(
|
||||||
lambda: codegen.load_import_from(
|
lambda: codegen.load_import_from(
|
||||||
utils.__name__, "get_torch_function_mode_stack_at"
|
utils.__name__, "get_torch_function_mode_stack_at"
|
||||||
|
|
@ -883,7 +883,7 @@ class TorchFunctionModeStackSource(Source):
|
||||||
codegen.extend_output([codegen.create_load_const(self._get_index())])
|
codegen.extend_output([codegen.create_load_const(self._get_index())])
|
||||||
codegen.extend_output(create_call_function(1, False))
|
codegen.extend_output(create_call_function(1, False))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.GLOBAL
|
return GuardSource.GLOBAL
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -891,16 +891,16 @@ class TorchFunctionModeStackSource(Source):
|
||||||
class ConstantSource(Source):
|
class ConstantSource(Source):
|
||||||
source_name: str
|
source_name: str
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.append_output(codegen.create_load_global(self.source_name, add=False))
|
codegen.append_output(codegen.create_load_global(self.source_name, add=False))
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.CONSTANT
|
return GuardSource.CONSTANT
|
||||||
|
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return self.source_name
|
return self.source_name
|
||||||
|
|
||||||
def make_guard(self, fn):
|
def make_guard(self, fn: Any) -> Any:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -909,10 +909,10 @@ class NumpyTensorSource(ChainedSource):
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return f"___from_numpy({self.base.name()})"
|
return f"___from_numpy({self.base.name()})"
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
def reconstruct(self, codegen: "PyCodegen"):
|
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||||
codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
|
codegen.add_push_null(lambda: codegen.load_import_from("torch", "as_tensor"))
|
||||||
codegen(self.base)
|
codegen(self.base)
|
||||||
codegen.extend_output(create_call_function(1, False))
|
codegen.extend_output(create_call_function(1, False))
|
||||||
|
|
@ -923,7 +923,7 @@ class SubclassAttrListSource(ChainedSource):
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return f"{self.base.name()}.__tensor_flatten__()[0]"
|
return f"{self.base.name()}.__tensor_flatten__()[0]"
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -934,7 +934,7 @@ class FloatTensorSource(ChainedSource):
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return f"___as_tensor({self.base.name()})"
|
return f"___as_tensor({self.base.name()})"
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -943,7 +943,7 @@ class CallMethodItemSource(ChainedSource):
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return f"{self.base.name()}.item()"
|
return f"{self.base.name()}.item()"
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return self.base.guard_source()
|
return self.base.guard_source()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -952,23 +952,25 @@ class CallMethodItemSource(ChainedSource):
|
||||||
# guard contents from the ambient ShapeEnv
|
# guard contents from the ambient ShapeEnv
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class ShapeEnvSource(Source):
|
class ShapeEnvSource(Source):
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.SHAPE_ENV
|
return GuardSource.SHAPE_ENV
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class BackwardStateSource(Source):
|
class BackwardStateSource(Source):
|
||||||
def name(self):
|
def name(self) -> str:
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
def guard_source(self):
|
def guard_source(self) -> GuardSource:
|
||||||
return GuardSource.BACKWARD_STATE
|
return GuardSource.BACKWARD_STATE
|
||||||
|
|
||||||
|
|
||||||
def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional[str]:
|
def get_local_source_name(
|
||||||
|
source: Source, *, only_allow_input: bool = False
|
||||||
|
) -> Optional[str]:
|
||||||
if isinstance(source, ChainedSource):
|
if isinstance(source, ChainedSource):
|
||||||
return get_local_source_name(source.base, only_allow_input=only_allow_input)
|
return get_local_source_name(source.base, only_allow_input=only_allow_input)
|
||||||
if not isinstance(source, LocalSource):
|
if not isinstance(source, LocalSource):
|
||||||
|
|
@ -978,7 +980,7 @@ def get_local_source_name(source: Source, *, only_allow_input=False) -> Optional
|
||||||
return source.local_name
|
return source.local_name
|
||||||
|
|
||||||
|
|
||||||
def is_from_local_source(source: Source, *, only_allow_input=False):
|
def is_from_local_source(source: Source, *, only_allow_input: bool = False) -> bool:
|
||||||
return get_local_source_name(source, only_allow_input=only_allow_input) is not None
|
return get_local_source_name(source, only_allow_input=only_allow_input) is not None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -994,7 +996,7 @@ def get_global_source_name(source: Source) -> Optional[str]:
|
||||||
return source.global_name
|
return source.global_name
|
||||||
|
|
||||||
|
|
||||||
def is_from_nonlocal_source(source: Source):
|
def is_from_nonlocal_source(source: Source) -> bool:
|
||||||
if isinstance(source, ChainedSource):
|
if isinstance(source, ChainedSource):
|
||||||
return is_from_nonlocal_source(source.base)
|
return is_from_nonlocal_source(source.base)
|
||||||
return (
|
return (
|
||||||
|
|
@ -1004,14 +1006,14 @@ def is_from_nonlocal_source(source: Source):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def is_from_source(source: Source, target: Source):
|
def is_from_source(source: Source, target: Source) -> bool:
|
||||||
if isinstance(source, ChainedSource):
|
if isinstance(source, ChainedSource):
|
||||||
return is_from_source(source.base, target)
|
return is_from_source(source.base, target)
|
||||||
return source == target
|
return source == target
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def is_from_unspecialized_nn_module_source(source: Source):
|
def is_from_unspecialized_nn_module_source(source: Source) -> bool:
|
||||||
if isinstance(source, UnspecializedNNModuleSource):
|
if isinstance(source, UnspecializedNNModuleSource):
|
||||||
return True
|
return True
|
||||||
if isinstance(source, ChainedSource):
|
if isinstance(source, ChainedSource):
|
||||||
|
|
@ -1020,7 +1022,7 @@ def is_from_unspecialized_nn_module_source(source: Source):
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def is_from_unspecialized_builtin_nn_module_source(source: Source):
|
def is_from_unspecialized_builtin_nn_module_source(source: Source) -> bool:
|
||||||
if isinstance(source, UnspecializedBuiltinNNModuleSource):
|
if isinstance(source, UnspecializedBuiltinNNModuleSource):
|
||||||
return True
|
return True
|
||||||
if isinstance(source, ChainedSource):
|
if isinstance(source, ChainedSource):
|
||||||
|
|
@ -1029,7 +1031,7 @@ def is_from_unspecialized_builtin_nn_module_source(source: Source):
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def is_from_unspecialized_param_buffer_source(source: Source):
|
def is_from_unspecialized_param_buffer_source(source: Source) -> bool:
|
||||||
if isinstance(source, UnspecializedParamBufferSource):
|
if isinstance(source, UnspecializedParamBufferSource):
|
||||||
return True
|
return True
|
||||||
if isinstance(source, ChainedSource):
|
if isinstance(source, ChainedSource):
|
||||||
|
|
@ -1038,7 +1040,7 @@ def is_from_unspecialized_param_buffer_source(source: Source):
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def is_from_flatten_script_object_source(source: Source):
|
def is_from_flatten_script_object_source(source: Source) -> bool:
|
||||||
if isinstance(source, FlattenScriptObjectSource):
|
if isinstance(source, FlattenScriptObjectSource):
|
||||||
return True
|
return True
|
||||||
elif isinstance(source, ChainedSource):
|
elif isinstance(source, ChainedSource):
|
||||||
|
|
@ -1047,7 +1049,7 @@ def is_from_flatten_script_object_source(source: Source):
|
||||||
|
|
||||||
|
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def is_from_optimizer_source(source: Source):
|
def is_from_optimizer_source(source: Source) -> bool:
|
||||||
if isinstance(source, OptimizerSource):
|
if isinstance(source, OptimizerSource):
|
||||||
return True
|
return True
|
||||||
if isinstance(source, ChainedSource):
|
if isinstance(source, ChainedSource):
|
||||||
|
|
@ -1058,7 +1060,7 @@ def is_from_optimizer_source(source: Source):
|
||||||
# TODO: can probably write a generic "test this on everything in the chain"
|
# TODO: can probably write a generic "test this on everything in the chain"
|
||||||
# helper
|
# helper
|
||||||
@functools.lru_cache
|
@functools.lru_cache
|
||||||
def is_from_defaults(source: Source):
|
def is_from_defaults(source: Source) -> bool:
|
||||||
if isinstance(source, DefaultsSource):
|
if isinstance(source, DefaultsSource):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2644,7 +2644,9 @@ def set_example_value(node, example_value):
|
||||||
# this to accurately reflect what the state of the value was at the time
|
# this to accurately reflect what the state of the value was at the time
|
||||||
# the program was traced).
|
# the program was traced).
|
||||||
node.meta["example_value"] = example_value
|
node.meta["example_value"] = example_value
|
||||||
shape_env = TracingContext.get().fake_mode.shape_env
|
fake_mode = TracingContext.get().fake_mode
|
||||||
|
assert fake_mode is not None
|
||||||
|
shape_env = fake_mode.shape_env
|
||||||
if (
|
if (
|
||||||
symbol_to_path
|
symbol_to_path
|
||||||
:= torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
|
:= torch.fx.experimental.symbolic_shapes.compute_unbacked_bindings(
|
||||||
|
|
@ -4765,7 +4767,7 @@ def record_pregraph_bytecode_exit(cm: AbstractContextManager[None]) -> None:
|
||||||
|
|
||||||
# Returns a set of code objects present traced in the current TracingContext, or None
|
# Returns a set of code objects present traced in the current TracingContext, or None
|
||||||
# if there is no current TracingContext.
|
# if there is no current TracingContext.
|
||||||
def get_traced_code() -> list[CodeType]:
|
def get_traced_code() -> Optional[list[CodeType]]:
|
||||||
from torch._guards import TracingContext
|
from torch._guards import TracingContext
|
||||||
|
|
||||||
return TracingContext.get_traced_code()
|
return TracingContext.get_traced_code()
|
||||||
|
|
|
||||||
|
|
@ -365,6 +365,7 @@ def make_fake_inputs(
|
||||||
# a toplevel TracingContext with a fake mode, so we do not want to
|
# a toplevel TracingContext with a fake mode, so we do not want to
|
||||||
# create another fake mode.
|
# create another fake mode.
|
||||||
fake_mode = context.fake_mode
|
fake_mode = context.fake_mode
|
||||||
|
assert fake_mode is not None
|
||||||
else:
|
else:
|
||||||
if isinstance(nn_module.forward, functools.partial):
|
if isinstance(nn_module.forward, functools.partial):
|
||||||
# functools handles nesting by itself, no need to recurse
|
# functools handles nesting by itself, no need to recurse
|
||||||
|
|
@ -852,7 +853,7 @@ def _fakify_script_objects(
|
||||||
mod: torch.nn.Module,
|
mod: torch.nn.Module,
|
||||||
args: Sequence[Any],
|
args: Sequence[Any],
|
||||||
kwargs: dict[Any, Any],
|
kwargs: dict[Any, Any],
|
||||||
fake_mode: torch._subclasses.fake_tensor.FakeTensorMode,
|
fake_mode: Optional[torch._subclasses.fake_tensor.FakeTensorMode],
|
||||||
):
|
):
|
||||||
# This context manager is used to fakify script objects into FakeScriptObject.
|
# This context manager is used to fakify script objects into FakeScriptObject.
|
||||||
# Inputs:
|
# Inputs:
|
||||||
|
|
|
||||||
|
|
@ -1129,7 +1129,7 @@ def remove_proxy_from_state_dict(state_dict: dict, in_place: bool) -> dict:
|
||||||
|
|
||||||
def _detect_fake_mode_from_gm(
|
def _detect_fake_mode_from_gm(
|
||||||
gm: torch.fx.GraphModule,
|
gm: torch.fx.GraphModule,
|
||||||
) -> torch._subclasses.fake_tensor.FakeTensorMode:
|
) -> Optional[torch._subclasses.fake_tensor.FakeTensorMode]:
|
||||||
"""
|
"""
|
||||||
For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs.
|
For a given graph module, we look at the "val" of placeholder nodes to find the fake inputs.
|
||||||
Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes.
|
Additionally, if gm doesn't have placeholders, we further look at the "example_value" or "val" of other nodes.
|
||||||
|
|
|
||||||
|
|
@ -12,7 +12,7 @@ It does so by:
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import warnings
|
import warnings
|
||||||
from contextlib import contextmanager, ExitStack, nullcontext
|
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
from typing import Any, Callable, cast, Optional, TypeVar, Union
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
@ -441,9 +441,10 @@ def create_functionalized_rng_ops_wrapper(
|
||||||
# It goes from (primals, tangents) to (seed, offset, primals, tangents)
|
# It goes from (primals, tangents) to (seed, offset, primals, tangents)
|
||||||
# At runtime, we pass on the current seed and offset. This is hidden from
|
# At runtime, we pass on the current seed and offset. This is hidden from
|
||||||
# the user.
|
# the user.
|
||||||
fake_mode = detect_fake_mode()
|
fake_mode_det = detect_fake_mode()
|
||||||
if fake_mode is None:
|
fake_mode: AbstractContextManager[Any] = nullcontext()
|
||||||
fake_mode = nullcontext()
|
if fake_mode_det is not None:
|
||||||
|
fake_mode = fake_mode_det
|
||||||
|
|
||||||
def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
|
def override_get_rng_state(device: Union[int, str, torch.device] = "cuda"):
|
||||||
out = PhiloxStateTracker.get_state_as_tensor()
|
out = PhiloxStateTracker.get_state_as_tensor()
|
||||||
|
|
@ -1343,7 +1344,9 @@ def create_functional_call(
|
||||||
"ignore", "Anomaly Detection has been enabled."
|
"ignore", "Anomaly Detection has been enabled."
|
||||||
)
|
)
|
||||||
with torch.autograd.detect_anomaly(check_nan=False):
|
with torch.autograd.detect_anomaly(check_nan=False):
|
||||||
detect_fake_mode().epoch += 1
|
fake_mode = detect_fake_mode()
|
||||||
|
assert fake_mode is not None
|
||||||
|
fake_mode.epoch += 1
|
||||||
out = PropagateUnbackedSymInts(mod).run(
|
out = PropagateUnbackedSymInts(mod).run(
|
||||||
*args[params_len:], **kwargs
|
*args[params_len:], **kwargs
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -306,11 +306,12 @@ def compute_overlapping_inputs(aot_config, fwd_inputs, aliased_input_indices):
|
||||||
tracing_context = torch._guards.TracingContext.try_get()
|
tracing_context = torch._guards.TracingContext.try_get()
|
||||||
|
|
||||||
if tracing_context is not None:
|
if tracing_context is not None:
|
||||||
|
assert tracing_context.fake_mode is not None
|
||||||
shape_env = tracing_context.fake_mode.shape_env
|
shape_env = tracing_context.fake_mode.shape_env
|
||||||
|
|
||||||
# Check whether we can actually get the dynamo sources from within AOTAutograd.
|
# Check whether we can actually get the dynamo sources from within AOTAutograd.
|
||||||
if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None:
|
if aot_config.aot_autograd_arg_pos_to_source and shape_env is not None:
|
||||||
maybe_suppress_guards = shape_env.suppress_guards
|
maybe_suppress_guards = shape_env.suppress_guards # type: ignore[assignment]
|
||||||
|
|
||||||
# Check whether there are any symbolic values being used.
|
# Check whether there are any symbolic values being used.
|
||||||
# We do this for 2 reasons:
|
# We do this for 2 reasons:
|
||||||
|
|
|
||||||
|
|
@ -495,6 +495,7 @@ class FunctionalizedRngRuntimeWrapper(InductorWrapper):
|
||||||
if config.functionalize_rng_ops:
|
if config.functionalize_rng_ops:
|
||||||
# Update example inputs for the fw_compiler
|
# Update example inputs for the fw_compiler
|
||||||
fake_mode = detect_fake_mode()
|
fake_mode = detect_fake_mode()
|
||||||
|
assert fake_mode is not None
|
||||||
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
|
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
|
||||||
flat_args.extend([seed, offset])
|
flat_args.extend([seed, offset])
|
||||||
# We are not clearing flat_args here because
|
# We are not clearing flat_args here because
|
||||||
|
|
|
||||||
218
torch/_guards.py
218
torch/_guards.py
|
|
@ -1,4 +1,3 @@
|
||||||
# mypy: allow-untyped-defs
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import contextlib
|
import contextlib
|
||||||
|
|
@ -37,10 +36,15 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
from collections.abc import Generator, Iterator
|
||||||
from types import CodeType
|
from types import CodeType
|
||||||
|
|
||||||
import sympy
|
import sympy
|
||||||
|
|
||||||
|
from torch._dynamo.codegen import PyCodegen
|
||||||
|
from torch._functorch._aot_autograd.schemas import ViewAndMutationMeta
|
||||||
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
torch._guards is the definitional source of truth for general purpose guard structures.
|
torch._guards is the definitional source of truth for general purpose guard structures.
|
||||||
|
|
@ -83,7 +87,7 @@ class CompileId:
|
||||||
# TODO: consider also tracking the recompilation count
|
# TODO: consider also tracking the recompilation count
|
||||||
# See Note: Updating CompileId
|
# See Note: Updating CompileId
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
# NOTE: Keep this in sync with both from_string and the tlparse repo
|
# NOTE: Keep this in sync with both from_string and the tlparse repo
|
||||||
if self.compiled_autograd_id is not None:
|
if self.compiled_autograd_id is not None:
|
||||||
assert (self.frame_id is None) == (self.frame_compile_id is None)
|
assert (self.frame_id is None) == (self.frame_compile_id is None)
|
||||||
|
|
@ -97,7 +101,7 @@ class CompileId:
|
||||||
return f"{self.frame_id}/{self.frame_compile_id}"
|
return f"{self.frame_id}/{self.frame_compile_id}"
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_string(cls, compile_id: Optional[str]):
|
def from_string(cls, compile_id: Optional[str]) -> Optional[CompileId]:
|
||||||
"""
|
"""
|
||||||
Factory method that creates a CompileId from its string representation.
|
Factory method that creates a CompileId from its string representation.
|
||||||
Keep this in sync with the __str__ method.
|
Keep this in sync with the __str__ method.
|
||||||
|
|
@ -125,7 +129,7 @@ class TraceId(NamedTuple):
|
||||||
# up by one
|
# up by one
|
||||||
attempt: int
|
attempt: int
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
# Keep this in sync with tlparse repo
|
# Keep this in sync with tlparse repo
|
||||||
if self.attempt == 0:
|
if self.attempt == 0:
|
||||||
return str(self.compile_id)
|
return str(self.compile_id)
|
||||||
|
|
@ -185,7 +189,7 @@ class GuardSource(enum.Enum):
|
||||||
GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
|
GuardSource.LOCAL_UNSPECIALIZED_BUILTIN_NN_MODULE,
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_local(self):
|
def is_local(self) -> bool:
|
||||||
return self in (
|
return self in (
|
||||||
GuardSource.LOCAL,
|
GuardSource.LOCAL,
|
||||||
GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
|
GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
|
||||||
|
|
@ -218,7 +222,7 @@ class SLoc:
|
||||||
framework_loc: Optional[Union[traceback.FrameSummary, str]]
|
framework_loc: Optional[Union[traceback.FrameSummary, str]]
|
||||||
maybe_user_loc: Optional[str]
|
maybe_user_loc: Optional[str]
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
floc = (
|
floc = (
|
||||||
self.framework_loc
|
self.framework_loc
|
||||||
if isinstance(self.framework_loc, str)
|
if isinstance(self.framework_loc, str)
|
||||||
|
|
@ -257,7 +261,7 @@ class Guard:
|
||||||
# it is meaningless. Example create_fns that are like this include
|
# it is meaningless. Example create_fns that are like this include
|
||||||
# GRAD_MODE and SHAPE_ENV.
|
# GRAD_MODE and SHAPE_ENV.
|
||||||
originating_source: Source
|
originating_source: Source
|
||||||
create_fn: Callable[[GuardBuilderBase, Guard], None]
|
create_fn: Callable[[GuardBuilderBase, Guard], Any]
|
||||||
|
|
||||||
# Export only. These values are written to at time of guard check_fn creation.
|
# Export only. These values are written to at time of guard check_fn creation.
|
||||||
guard_types: Optional[list[str]] = None
|
guard_types: Optional[list[str]] = None
|
||||||
|
|
@ -269,12 +273,12 @@ class Guard:
|
||||||
user_stack: Optional[traceback.StackSummary] = None
|
user_stack: Optional[traceback.StackSummary] = None
|
||||||
_hash: Optional[int] = None
|
_hash: Optional[int] = None
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self) -> int:
|
||||||
if self._hash is None:
|
if self._hash is None:
|
||||||
self._hash = hash((self.name, self.source, id(self.create_fn)))
|
self._hash = hash((self.name, self.source, id(self.create_fn)))
|
||||||
return self._hash
|
return self._hash
|
||||||
|
|
||||||
def sort_key(self):
|
def sort_key(self) -> tuple[bool, int, int, str, int]:
|
||||||
# Put the duplicate input guards at the end. The duplicate guards have
|
# Put the duplicate input guards at the end. The duplicate guards have
|
||||||
# two sources while guard.name only considers one source.
|
# two sources while guard.name only considers one source.
|
||||||
|
|
||||||
|
|
@ -290,10 +294,10 @@ class Guard:
|
||||||
self.inner_create_fn().__code__.co_firstlineno,
|
self.inner_create_fn().__code__.co_firstlineno,
|
||||||
)
|
)
|
||||||
|
|
||||||
def __lt__(self, other):
|
def __lt__(self, other: Guard) -> bool:
|
||||||
return self.sort_key() < other.sort_key()
|
return self.sort_key() < other.sort_key()
|
||||||
|
|
||||||
def inner_create_fn(self):
|
def inner_create_fn(self) -> Callable[[GuardBuilderBase, Guard], Any]:
|
||||||
if isinstance(self.create_fn, functools.partial):
|
if isinstance(self.create_fn, functools.partial):
|
||||||
return self.create_fn.func
|
return self.create_fn.func
|
||||||
else:
|
else:
|
||||||
|
|
@ -308,7 +312,7 @@ class Guard:
|
||||||
return self.originating_source.guard_source()
|
return self.originating_source.guard_source()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def weakref_to_str(obj_weakref):
|
def weakref_to_str(obj_weakref: object) -> str:
|
||||||
"""
|
"""
|
||||||
This is a workaround of a Python weakref bug.
|
This is a workaround of a Python weakref bug.
|
||||||
|
|
||||||
|
|
@ -332,7 +336,7 @@ class Guard:
|
||||||
else:
|
else:
|
||||||
return str(obj_weakref)
|
return str(obj_weakref)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self) -> str:
|
||||||
s = f"""
|
s = f"""
|
||||||
{self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
|
{self.source.name.lower() if self.source else ""} {repr(self.name)} {self.inner_create_fn().__name__}
|
||||||
{{
|
{{
|
||||||
|
|
@ -344,7 +348,7 @@ class Guard:
|
||||||
"""
|
"""
|
||||||
return s
|
return s
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self) -> str:
|
||||||
output = f"Name: {repr(self.name)}\n"
|
output = f"Name: {repr(self.name)}\n"
|
||||||
source = self.source.name.lower() if self.source else ""
|
source = self.source.name.lower() if self.source else ""
|
||||||
output += f" Source: {source}\n"
|
output += f" Source: {source}\n"
|
||||||
|
|
@ -355,7 +359,7 @@ class Guard:
|
||||||
output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
|
output += f" Guarded Class Weakref: {self.guarded_class_weakref}\n"
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def create(self, builder: GuardBuilderBase):
|
def create(self, builder: GuardBuilderBase) -> Any:
|
||||||
try:
|
try:
|
||||||
return self.create_fn(builder, self)
|
return self.create_fn(builder, self)
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -364,16 +368,22 @@ class Guard:
|
||||||
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
|
log.error("Created at:\n%s", "".join(self.stack.format()[-4:]).rstrip())
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def is_specialized_nn_module(self):
|
def is_specialized_nn_module(self) -> bool:
|
||||||
return self.source.is_specialized_nn_module()
|
return self.source.is_specialized_nn_module()
|
||||||
|
|
||||||
def is_fsdp_module(self):
|
def is_fsdp_module(self) -> bool:
|
||||||
return self.source.is_fsdp_module()
|
return self.source.is_fsdp_module()
|
||||||
|
|
||||||
def is_local(self):
|
def is_local(self) -> bool:
|
||||||
return self.source.is_local()
|
return self.source.is_local()
|
||||||
|
|
||||||
def set_export_info(self, guard_type, guarded_class, code_list, obj_weakref):
|
def set_export_info(
|
||||||
|
self,
|
||||||
|
guard_type: str,
|
||||||
|
guarded_class: Optional[type],
|
||||||
|
code_list: list[str],
|
||||||
|
obj_weakref: object,
|
||||||
|
) -> None:
|
||||||
if not self.guard_types:
|
if not self.guard_types:
|
||||||
self.guard_types = []
|
self.guard_types = []
|
||||||
|
|
||||||
|
|
@ -428,7 +438,7 @@ class DuplicateInputs(GuardEnvExpr):
|
||||||
input_source_a: Source
|
input_source_a: Source
|
||||||
input_source_b: Source
|
input_source_b: Source
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self) -> None:
|
||||||
assert self.input_source_a != self.input_source_b
|
assert self.input_source_a != self.input_source_b
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -470,7 +480,7 @@ class Checkpointable(Generic[T]):
|
||||||
def copy_graphstate(self) -> T: ...
|
def copy_graphstate(self) -> T: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def restore_graphstate(self, state: T): ...
|
def restore_graphstate(self, state: T) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
class GuardsCheckpointState:
|
class GuardsCheckpointState:
|
||||||
|
|
@ -480,10 +490,10 @@ class GuardsCheckpointState:
|
||||||
|
|
||||||
dynamo_guards: set[Guard] = set()
|
dynamo_guards: set[Guard] = set()
|
||||||
|
|
||||||
def __init__(self, dynamo_guards):
|
def __init__(self, dynamo_guards: set[Guard]) -> None:
|
||||||
self.dynamo_guards = dynamo_guards
|
self.dynamo_guards = dynamo_guards
|
||||||
|
|
||||||
def diff(self, other):
|
def diff(self, other: GuardsCheckpointState) -> Optional[set[Guard]]:
|
||||||
"""
|
"""
|
||||||
Produces a delta against another GuardsCheckpointState.
|
Produces a delta against another GuardsCheckpointState.
|
||||||
|
|
||||||
|
|
@ -495,17 +505,19 @@ class GuardsCheckpointState:
|
||||||
return None
|
return None
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, GuardsCheckpointState):
|
||||||
|
return False
|
||||||
return self.diff(other) is None
|
return self.diff(other) is None
|
||||||
|
|
||||||
|
|
||||||
class ModuleContextCheckpointState:
|
class ModuleContextCheckpointState:
|
||||||
nn_modules: dict[str, torch.nn.Module] = {}
|
nn_modules: dict[str, torch.nn.Module] = {}
|
||||||
|
|
||||||
def __init__(self, nn_modules):
|
def __init__(self, nn_modules: dict[str, torch.nn.Module]) -> None:
|
||||||
self.nn_modules = nn_modules
|
self.nn_modules = nn_modules
|
||||||
|
|
||||||
def diff(self, other):
|
def diff(self, other: ModuleContextCheckpointState) -> Optional[set[str]]:
|
||||||
"""
|
"""
|
||||||
Produces a delta against another ModuleContextCheckpointState.
|
Produces a delta against another ModuleContextCheckpointState.
|
||||||
|
|
||||||
|
|
@ -517,7 +529,9 @@ class ModuleContextCheckpointState:
|
||||||
return None
|
return None
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, ModuleContextCheckpointState):
|
||||||
|
return False
|
||||||
return self.diff(other) is None
|
return self.diff(other) is None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -525,21 +539,21 @@ class ModuleContext(Checkpointable[ModuleContextCheckpointState]):
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.nn_modules: dict[str, Any] = {}
|
self.nn_modules: dict[str, Any] = {}
|
||||||
|
|
||||||
def copy_graphstate(self):
|
def copy_graphstate(self) -> ModuleContextCheckpointState:
|
||||||
return ModuleContextCheckpointState(dict(self.nn_modules))
|
return ModuleContextCheckpointState(dict(self.nn_modules))
|
||||||
|
|
||||||
def restore_graphstate(self, state):
|
def restore_graphstate(self, state: ModuleContextCheckpointState) -> None:
|
||||||
assert isinstance(state, ModuleContextCheckpointState)
|
assert isinstance(state, ModuleContextCheckpointState)
|
||||||
self.nn_modules = state.nn_modules
|
self.nn_modules = state.nn_modules
|
||||||
|
|
||||||
|
|
||||||
class GlobalContextCheckpointState:
|
class GlobalContextCheckpointState:
|
||||||
global_state: dict[str, tuple[Callable, ...]] = {}
|
global_state: dict[str, tuple[Callable, Any]] = {}
|
||||||
|
|
||||||
def __init__(self, global_states):
|
def __init__(self, global_states: dict[str, tuple[Callable, Any]]) -> None:
|
||||||
self.global_state = global_states
|
self.global_state = global_states
|
||||||
|
|
||||||
def diff(self, other):
|
def diff(self, other: GlobalContextCheckpointState) -> Optional[set[str]]:
|
||||||
"""
|
"""
|
||||||
Produces a delta against another GlobalContextCheckpointState.
|
Produces a delta against another GlobalContextCheckpointState.
|
||||||
|
|
||||||
|
|
@ -551,7 +565,9 @@ class GlobalContextCheckpointState:
|
||||||
return None
|
return None
|
||||||
return r
|
return r
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other: object) -> bool:
|
||||||
|
if not isinstance(other, GlobalContextCheckpointState):
|
||||||
|
return False
|
||||||
return self.diff(other) is None
|
return self.diff(other) is None
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -571,12 +587,12 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self.global_state: dict[str, tuple[Callable, ...]] = {}
|
self.global_state: dict[str, tuple[Callable, Any]] = {}
|
||||||
|
|
||||||
def copy_graphstate(self):
|
def copy_graphstate(self) -> GlobalContextCheckpointState:
|
||||||
return GlobalContextCheckpointState(dict(self.global_state))
|
return GlobalContextCheckpointState(self.global_state)
|
||||||
|
|
||||||
def restore_graphstate(self, state):
|
def restore_graphstate(self, state: GlobalContextCheckpointState) -> None:
|
||||||
assert isinstance(state, GlobalContextCheckpointState)
|
assert isinstance(state, GlobalContextCheckpointState)
|
||||||
self.global_state = state.global_state
|
self.global_state = state.global_state
|
||||||
assert (
|
assert (
|
||||||
|
|
@ -590,26 +606,28 @@ class GlobalContext(Checkpointable[GlobalContextCheckpointState]):
|
||||||
# Like a Set[Guard] but will record the user stack on all guards at the
|
# Like a Set[Guard] but will record the user stack on all guards at the
|
||||||
# time they were installed at their destination
|
# time they were installed at their destination
|
||||||
class GuardsSet:
|
class GuardsSet:
|
||||||
def __init__(self, inner=None):
|
def __init__(self, inner: Optional[set[Guard]] = None) -> None:
|
||||||
if inner is None:
|
if inner is None:
|
||||||
inner = set()
|
inner = set()
|
||||||
self.inner = inner
|
self.inner = inner
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self) -> Iterator[Guard]:
|
||||||
return iter(self.inner)
|
return iter(self.inner)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self) -> int:
|
||||||
return len(self.inner)
|
return len(self.inner)
|
||||||
|
|
||||||
# Subtraction along with bool is typically used to determine the delta of
|
# Subtraction along with bool is typically used to determine the delta of
|
||||||
# added guards between checkpoints for higher order ops
|
# added guards between checkpoints for higher order ops
|
||||||
def __sub__(self, other):
|
def __sub__(self, other: GuardsSet) -> GuardsSet:
|
||||||
return GuardsSet(self.inner - other.inner)
|
return GuardsSet(self.inner - other.inner)
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self) -> bool:
|
||||||
return bool(self.inner)
|
return bool(self.inner)
|
||||||
|
|
||||||
def add(self, guard: Guard, *, collect_debug_stack=True, skip=0):
|
def add(
|
||||||
|
self, guard: Guard, *, collect_debug_stack: bool = True, skip: int = 0
|
||||||
|
) -> None:
|
||||||
if guard in self.inner:
|
if guard in self.inner:
|
||||||
return
|
return
|
||||||
if collect_debug_stack:
|
if collect_debug_stack:
|
||||||
|
|
@ -619,12 +637,12 @@ class GuardsSet:
|
||||||
guard.user_stack = TracingContext.extract_stack()
|
guard.user_stack = TracingContext.extract_stack()
|
||||||
self.inner.add(guard)
|
self.inner.add(guard)
|
||||||
|
|
||||||
def update(self, *others: set[Guard]):
|
def update(self, *others: set[Guard]) -> None:
|
||||||
for o in others:
|
for o in others:
|
||||||
for g in o:
|
for g in o:
|
||||||
self.add(g, skip=1)
|
self.add(g, skip=1)
|
||||||
|
|
||||||
def remove_guards_with_source(self, source):
|
def remove_guards_with_source(self, source: Source) -> None:
|
||||||
"""Delete all guards that contains a given source"""
|
"""Delete all guards that contains a given source"""
|
||||||
from ._dynamo.source import is_from_source
|
from ._dynamo.source import is_from_source
|
||||||
|
|
||||||
|
|
@ -646,10 +664,10 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
|
||||||
self.dynamo_guards: GuardsSet = GuardsSet()
|
self.dynamo_guards: GuardsSet = GuardsSet()
|
||||||
self.aotautograd_guards: list[GuardEnvExpr] = []
|
self.aotautograd_guards: list[GuardEnvExpr] = []
|
||||||
|
|
||||||
def copy_graphstate(self):
|
def copy_graphstate(self) -> GuardsCheckpointState:
|
||||||
return GuardsCheckpointState(set(self.dynamo_guards.inner))
|
return GuardsCheckpointState(set(self.dynamo_guards.inner))
|
||||||
|
|
||||||
def restore_graphstate(self, state):
|
def restore_graphstate(self, state: GuardsCheckpointState) -> None:
|
||||||
# NB: "steals" the passed in state
|
# NB: "steals" the passed in state
|
||||||
assert isinstance(state, GuardsCheckpointState)
|
assert isinstance(state, GuardsCheckpointState)
|
||||||
self.dynamo_guards = GuardsSet(state.dynamo_guards)
|
self.dynamo_guards = GuardsSet(state.dynamo_guards)
|
||||||
|
|
@ -657,22 +675,22 @@ class GuardsContext(Checkpointable[GuardsCheckpointState]):
|
||||||
|
|
||||||
class HopSubgraphCache:
|
class HopSubgraphCache:
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str): ...
|
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ...
|
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_autograd_key_entry(self, identifier: str, key: Callable): ...
|
def add_autograd_key_entry(self, identifier: str, key: Callable) -> None: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_autograd_key_entry(self, identifier: str): ...
|
def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_proxy_dispatch_entry(self, identifier: str, key: Callable): ...
|
def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_proxy_dispatch_entry(self, identifier: str): ...
|
def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def add_lazy_bwd_entry(
|
def add_lazy_bwd_entry(
|
||||||
|
|
@ -680,12 +698,12 @@ class HopSubgraphCache:
|
||||||
identifier: str,
|
identifier: str,
|
||||||
tangent_metadata: tuple[object],
|
tangent_metadata: tuple[object],
|
||||||
gmod: torch.fx.GraphModule,
|
gmod: torch.fx.GraphModule,
|
||||||
): ...
|
) -> int: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def get_lazy_bwd_entry(
|
def get_lazy_bwd_entry(
|
||||||
self, identifier: str, tangent_metadata: tuple[object]
|
self, identifier: str, tangent_metadata: tuple[object]
|
||||||
) -> int: ...
|
) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]: ...
|
||||||
|
|
||||||
|
|
||||||
class InvokeSubgraphCache(HopSubgraphCache):
|
class InvokeSubgraphCache(HopSubgraphCache):
|
||||||
|
|
@ -697,22 +715,22 @@ class InvokeSubgraphCache(HopSubgraphCache):
|
||||||
str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
|
str, dict[tuple[object], tuple[torch.fx.GraphModule, int]]
|
||||||
] = defaultdict(dict)
|
] = defaultdict(dict)
|
||||||
|
|
||||||
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str):
|
def add_dynamo_installed_submodule(self, fn_id: int, identifier: str) -> None:
|
||||||
self.dynamo_installed_submodules[fn_id].append(identifier)
|
self.dynamo_installed_submodules[fn_id].append(identifier)
|
||||||
|
|
||||||
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
|
def get_dynamo_installed_submodules(self, fn_id: int) -> list[str]:
|
||||||
return self.dynamo_installed_submodules.get(fn_id, [])
|
return self.dynamo_installed_submodules.get(fn_id, [])
|
||||||
|
|
||||||
def add_autograd_key_entry(self, identifier: str, key: Callable):
|
def add_autograd_key_entry(self, identifier: str, key: Callable) -> None:
|
||||||
self.autograd_cache[identifier] = key
|
self.autograd_cache[identifier] = key
|
||||||
|
|
||||||
def get_autograd_key_entry(self, identifier: str):
|
def get_autograd_key_entry(self, identifier: str) -> Optional[Callable]:
|
||||||
return self.autograd_cache.get(identifier, None)
|
return self.autograd_cache.get(identifier, None)
|
||||||
|
|
||||||
def add_proxy_dispatch_entry(self, identifier: str, key: Callable):
|
def add_proxy_dispatch_entry(self, identifier: str, key: Callable) -> None:
|
||||||
self.proxy_dispatch_cache[identifier] = key
|
self.proxy_dispatch_cache[identifier] = key
|
||||||
|
|
||||||
def get_proxy_dispatch_entry(self, identifier: str):
|
def get_proxy_dispatch_entry(self, identifier: str) -> Optional[Callable]:
|
||||||
return self.proxy_dispatch_cache.get(identifier, None)
|
return self.proxy_dispatch_cache.get(identifier, None)
|
||||||
|
|
||||||
def add_lazy_bwd_entry(
|
def add_lazy_bwd_entry(
|
||||||
|
|
@ -720,13 +738,15 @@ class InvokeSubgraphCache(HopSubgraphCache):
|
||||||
identifier: str,
|
identifier: str,
|
||||||
tangent_metadata: tuple[object],
|
tangent_metadata: tuple[object],
|
||||||
gmod: torch.fx.GraphModule,
|
gmod: torch.fx.GraphModule,
|
||||||
):
|
) -> int:
|
||||||
# Save the number of existing graph modules in the dictionary to get the suffix
|
# Save the number of existing graph modules in the dictionary to get the suffix
|
||||||
num_gmods = len(self.lazy_bwd_cache[identifier])
|
num_gmods = len(self.lazy_bwd_cache[identifier])
|
||||||
self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods)
|
self.lazy_bwd_cache[identifier][tangent_metadata] = (gmod, num_gmods)
|
||||||
return num_gmods
|
return num_gmods
|
||||||
|
|
||||||
def get_lazy_bwd_entry(self, identifier: str, tangent_metadata: tuple[object]):
|
def get_lazy_bwd_entry(
|
||||||
|
self, identifier: str, tangent_metadata: tuple[object]
|
||||||
|
) -> tuple[Optional[torch.fx.GraphModule], Optional[int]]:
|
||||||
if identifier not in self.lazy_bwd_cache:
|
if identifier not in self.lazy_bwd_cache:
|
||||||
return (None, None)
|
return (None, None)
|
||||||
|
|
||||||
|
|
@ -779,7 +799,7 @@ class CompileContext:
|
||||||
def try_get() -> Optional[CompileContext]:
|
def try_get() -> Optional[CompileContext]:
|
||||||
return getattr(_TLS, "compile_context", None)
|
return getattr(_TLS, "compile_context", None)
|
||||||
|
|
||||||
def __init__(self, compile_id):
|
def __init__(self, compile_id: Optional[CompileId]) -> None:
|
||||||
assert compile_id is None or isinstance(compile_id, CompileId)
|
assert compile_id is None or isinstance(compile_id, CompileId)
|
||||||
self.compile_id: Optional[CompileId] = compile_id
|
self.compile_id: Optional[CompileId] = compile_id
|
||||||
self.attempt = 0
|
self.attempt = 0
|
||||||
|
|
@ -787,14 +807,14 @@ class CompileContext:
|
||||||
self.shape_env_guards: list[str] = []
|
self.shape_env_guards: list[str] = []
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def current_compile_id():
|
def current_compile_id() -> Optional[CompileId]:
|
||||||
self = CompileContext.try_get()
|
self = CompileContext.try_get()
|
||||||
if self is None:
|
if self is None:
|
||||||
return None
|
return None
|
||||||
return self.compile_id
|
return self.compile_id
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def current_trace_id():
|
def current_trace_id() -> Optional[TraceId]:
|
||||||
self = CompileContext.try_get()
|
self = CompileContext.try_get()
|
||||||
if self is None:
|
if self is None:
|
||||||
return None
|
return None
|
||||||
|
|
@ -823,28 +843,28 @@ class TracingContext:
|
||||||
"TracingContext.get() must be called within an ongoing trace."
|
"TracingContext.get() must be called within an ongoing trace."
|
||||||
)
|
)
|
||||||
|
|
||||||
def __init__(self, fake_mode):
|
def __init__(self, fake_mode: Optional[FakeTensorMode]) -> None:
|
||||||
self.guards_context = GuardsContext()
|
self.guards_context = GuardsContext()
|
||||||
self.module_context = ModuleContext()
|
self.module_context = ModuleContext()
|
||||||
self.global_context = GlobalContext()
|
self.global_context = GlobalContext()
|
||||||
self.previously_inlined_functions = dict()
|
self.previously_inlined_functions: dict[Any, Any] = dict()
|
||||||
self.previously_cleaned_instructions = dict()
|
self.previously_cleaned_instructions: dict[Any, Any] = dict()
|
||||||
self.fake_mode = fake_mode
|
self.fake_mode: Optional[FakeTensorMode] = fake_mode
|
||||||
self.frame_summary_stack = []
|
self.frame_summary_stack: list[traceback.FrameSummary] = []
|
||||||
# This is morally part of frame_summary_stack, but it is kept separate
|
# This is morally part of frame_summary_stack, but it is kept separate
|
||||||
# for clarity. As we process a frame, this variable gets updated
|
# for clarity. As we process a frame, this variable gets updated
|
||||||
# to keep track of what line we are in the function. We make a
|
# to keep track of what line we are in the function. We make a
|
||||||
# function call, this gets cleared and the frame location is pushed
|
# function call, this gets cleared and the frame location is pushed
|
||||||
# to frame_summary_stack (prepping this variable for the inner frame's
|
# to frame_summary_stack (prepping this variable for the inner frame's
|
||||||
# progress)
|
# progress)
|
||||||
self.loc_in_frame = None
|
self.loc_in_frame: Optional[tuple[str, int, str]] = None
|
||||||
# this is only set after aot_autograd
|
# this is only set after aot_autograd
|
||||||
self.fw_metadata = None
|
self.fw_metadata: Optional[ViewAndMutationMeta] = None
|
||||||
# this is only set after aot_autograd
|
# this is only set after aot_autograd
|
||||||
self.aot_graph_name = None
|
self.aot_graph_name: Optional[list[str]] = None
|
||||||
self.params_flat = None
|
self.params_flat: Optional[list[Any]] = None
|
||||||
self.params_flat_unwrap_subclasses = None
|
self.params_flat_unwrap_subclasses: Optional[list[Any]] = None
|
||||||
self.params_unwrapped_to_flat_index = None
|
self.params_unwrapped_to_flat_index: Optional[list[Any]] = None
|
||||||
# this is for extended return calling convention from backend
|
# this is for extended return calling convention from backend
|
||||||
# compiler to aot_autograd
|
# compiler to aot_autograd
|
||||||
# Per output, what the compiler specified stride of the output is,
|
# Per output, what the compiler specified stride of the output is,
|
||||||
|
|
@ -872,7 +892,7 @@ class TracingContext:
|
||||||
# list of code objects for inlined functions
|
# list of code objects for inlined functions
|
||||||
self.traced_code: list[CodeType] = []
|
self.traced_code: list[CodeType] = []
|
||||||
|
|
||||||
def clear(self):
|
def clear(self) -> None:
|
||||||
# Look at the note in output_graph.py in function `save_global_state`
|
# Look at the note in output_graph.py in function `save_global_state`
|
||||||
# for the context on clearing global context.
|
# for the context on clearing global context.
|
||||||
self.global_context.global_state = {}
|
self.global_context.global_state = {}
|
||||||
|
|
@ -881,7 +901,7 @@ class TracingContext:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def patch(**kwargs):
|
def patch(**kwargs: Any) -> Generator[None, None, None]:
|
||||||
prior = {}
|
prior = {}
|
||||||
ctx = TracingContext.get()
|
ctx = TracingContext.get()
|
||||||
|
|
||||||
|
|
@ -897,7 +917,7 @@ class TracingContext:
|
||||||
setattr(ctx, key, val)
|
setattr(ctx, key, val)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def extract_stack():
|
def extract_stack() -> traceback.StackSummary:
|
||||||
self = TracingContext.try_get()
|
self = TracingContext.try_get()
|
||||||
if self is None:
|
if self is None:
|
||||||
return traceback.StackSummary()
|
return traceback.StackSummary()
|
||||||
|
|
@ -906,7 +926,7 @@ class TracingContext:
|
||||||
stack = stack + [self._populate_loc_in_frame_summary()]
|
stack = stack + [self._populate_loc_in_frame_summary()]
|
||||||
return traceback.StackSummary.from_list(stack)
|
return traceback.StackSummary.from_list(stack)
|
||||||
|
|
||||||
def _populate_loc_in_frame_summary(self):
|
def _populate_loc_in_frame_summary(self) -> traceback.FrameSummary:
|
||||||
assert self.loc_in_frame is not None
|
assert self.loc_in_frame is not None
|
||||||
filename, lineno, frame_name = self.loc_in_frame
|
filename, lineno, frame_name = self.loc_in_frame
|
||||||
return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False)
|
return traceback.FrameSummary(filename, lineno, frame_name, lookup_line=False)
|
||||||
|
|
@ -915,7 +935,7 @@ class TracingContext:
|
||||||
# associated with the current frame state
|
# associated with the current frame state
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def clear_frame():
|
def clear_frame() -> Generator[None, None, None]:
|
||||||
tc = TracingContext.get()
|
tc = TracingContext.get()
|
||||||
with (
|
with (
|
||||||
unittest.mock.patch.object(tc, "frame_summary_stack", []),
|
unittest.mock.patch.object(tc, "frame_summary_stack", []),
|
||||||
|
|
@ -947,7 +967,9 @@ class TracingContext:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def current_frame(frame_summary):
|
def current_frame(
|
||||||
|
frame_summary: Optional[traceback.FrameSummary],
|
||||||
|
) -> Generator[None, None, None]:
|
||||||
# frame_summary can be None to solely take advantage of real_stack
|
# frame_summary can be None to solely take advantage of real_stack
|
||||||
# attachment to thrown exceptions
|
# attachment to thrown exceptions
|
||||||
tc = TracingContext.get()
|
tc = TracingContext.get()
|
||||||
|
|
@ -968,7 +990,9 @@ class TracingContext:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
def report_output_strides():
|
def report_output_strides() -> Generator[
|
||||||
|
Optional[list[Optional[tuple[int, ...]]]], None, None
|
||||||
|
]:
|
||||||
tc = TracingContext.try_get()
|
tc = TracingContext.try_get()
|
||||||
if tc is None:
|
if tc is None:
|
||||||
yield None
|
yield None
|
||||||
|
|
@ -981,13 +1005,13 @@ class TracingContext:
|
||||||
tc.output_strides = old_output_strides
|
tc.output_strides = old_output_strides
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def set_current_loc(filename, lineno, frame_name):
|
def set_current_loc(filename: str, lineno: int, frame_name: str) -> None:
|
||||||
# Save the current location in the frame. Lazily generate the
|
# Save the current location in the frame. Lazily generate the
|
||||||
# framesummary.
|
# framesummary.
|
||||||
TracingContext.get().loc_in_frame = (filename, lineno, frame_name)
|
TracingContext.get().loc_in_frame = (filename, lineno, frame_name)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_traced_code():
|
def get_traced_code() -> Optional[list[CodeType]]:
|
||||||
tc = TracingContext.try_get()
|
tc = TracingContext.try_get()
|
||||||
if tc is None:
|
if tc is None:
|
||||||
return None
|
return None
|
||||||
|
|
@ -995,7 +1019,9 @@ class TracingContext:
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def compile_context(context: Optional[CompileContext]):
|
def compile_context(
|
||||||
|
context: Optional[CompileContext],
|
||||||
|
) -> Generator[Optional[CompileContext], None, None]:
|
||||||
old_context = getattr(_TLS, "compile_context", None)
|
old_context = getattr(_TLS, "compile_context", None)
|
||||||
_TLS.compile_context = context
|
_TLS.compile_context = context
|
||||||
try:
|
try:
|
||||||
|
|
@ -1005,7 +1031,9 @@ def compile_context(context: Optional[CompileContext]):
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def tracing(context: Optional[TracingContext]):
|
def tracing(
|
||||||
|
context: Optional[TracingContext],
|
||||||
|
) -> Generator[Optional[TracingContext], None, None]:
|
||||||
"""
|
"""
|
||||||
This function installs the passed in tracing context as a dynamic scoped
|
This function installs the passed in tracing context as a dynamic scoped
|
||||||
global variable.
|
global variable.
|
||||||
|
|
@ -1035,13 +1063,13 @@ def tracing(context: Optional[TracingContext]):
|
||||||
# TODO(voz): Consider a toplevel torch/_source.py
|
# TODO(voz): Consider a toplevel torch/_source.py
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class Source:
|
class Source:
|
||||||
def is_dict_key(self):
|
def is_dict_key(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def is_ephemeral(self):
|
def is_ephemeral(self) -> bool:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def reconstruct(self, codegen):
|
def reconstruct(self, codegen: PyCodegen) -> None:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def guard_source(self) -> GuardSource:
|
def guard_source(self) -> GuardSource:
|
||||||
|
|
@ -1050,7 +1078,7 @@ class Source:
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def make_guard(self, fn) -> Guard:
|
def make_guard(self, fn: Callable[..., Any]) -> Guard:
|
||||||
if self.guard_source() is GuardSource.CONSTANT:
|
if self.guard_source() is GuardSource.CONSTANT:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
return Guard(self, fn)
|
return Guard(self, fn)
|
||||||
|
|
@ -1058,7 +1086,7 @@ class Source:
|
||||||
def is_specialized_nn_module(self) -> bool:
|
def is_specialized_nn_module(self) -> bool:
|
||||||
return self.guard_source().is_specialized_nn_module()
|
return self.guard_source().is_specialized_nn_module()
|
||||||
|
|
||||||
def subguards_allowed(self):
|
def subguards_allowed(self) -> bool:
|
||||||
"""True if you can guard on attributes of this"""
|
"""True if you can guard on attributes of this"""
|
||||||
return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
|
return self.guard_source() != GuardSource.SYNTHETIC_LOCAL
|
||||||
|
|
||||||
|
|
@ -1068,11 +1096,11 @@ class Source:
|
||||||
class ChainedSource(Source):
|
class ChainedSource(Source):
|
||||||
base: Source
|
base: Source
|
||||||
|
|
||||||
def is_dict_key(self):
|
def is_dict_key(self) -> bool:
|
||||||
# Recurse until you either hit a ConstDictKey or a Source
|
# Recurse until you either hit a ConstDictKey or a Source
|
||||||
return self.base.is_dict_key()
|
return self.base.is_dict_key()
|
||||||
|
|
||||||
def is_ephemeral(self):
|
def is_ephemeral(self) -> bool:
|
||||||
return self.base.is_ephemeral()
|
return self.base.is_ephemeral()
|
||||||
|
|
||||||
def get_base(self) -> Source:
|
def get_base(self) -> Source:
|
||||||
|
|
@ -1082,7 +1110,7 @@ class ChainedSource(Source):
|
||||||
return current
|
return current
|
||||||
|
|
||||||
|
|
||||||
def detect_fake_mode(inputs: Any = None):
|
def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
|
||||||
"""
|
"""
|
||||||
Attempts to "detect" what the current fake mode is. If there is one ambiently
|
Attempts to "detect" what the current fake mode is. If there is one ambiently
|
||||||
available from TracingContext, we preferentially use that. Otherwise, we
|
available from TracingContext, we preferentially use that. Otherwise, we
|
||||||
|
|
@ -1126,7 +1154,7 @@ def detect_fake_mode(inputs: Any = None):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def active_fake_mode():
|
def active_fake_mode() -> Optional[FakeTensorMode]:
|
||||||
"""
|
"""
|
||||||
Inspects the dispatch mode stack for an active fake mode and returns it.
|
Inspects the dispatch mode stack for an active fake mode and returns it.
|
||||||
Returns None if no fake mode is active.
|
Returns None if no fake mode is active.
|
||||||
|
|
|
||||||
|
|
@ -465,6 +465,7 @@ class InvokeSubgraphAutogradOp(torch.autograd.Function):
|
||||||
from torch._subclasses.fake_tensor import extract_tensor_metadata
|
from torch._subclasses.fake_tensor import extract_tensor_metadata
|
||||||
|
|
||||||
fake_mode = detect_fake_mode(primals + filtered_grad_outs)
|
fake_mode = detect_fake_mode(primals + filtered_grad_outs)
|
||||||
|
assert fake_mode is not None, "fake_mode should be enabled for HOPs"
|
||||||
state = _CacheKeyState(fake_mode.shape_env)
|
state = _CacheKeyState(fake_mode.shape_env)
|
||||||
|
|
||||||
tangent_metadata: list[object] = []
|
tangent_metadata: list[object] = []
|
||||||
|
|
@ -607,6 +608,7 @@ def _(proxy_mode: ProxyTorchDispatchMode, subgraph, identifier, *operands):
|
||||||
from torch._guards import detect_fake_mode
|
from torch._guards import detect_fake_mode
|
||||||
|
|
||||||
fake_mode = detect_fake_mode(operands)
|
fake_mode = detect_fake_mode(operands)
|
||||||
|
assert fake_mode is not None and fake_mode.shape_env is not None
|
||||||
insert_deferred_runtime_asserts(
|
insert_deferred_runtime_asserts(
|
||||||
graph,
|
graph,
|
||||||
fake_mode.shape_env,
|
fake_mode.shape_env,
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
import contextlib
|
import contextlib
|
||||||
import functools
|
import functools
|
||||||
from contextlib import contextmanager, ExitStack, nullcontext
|
from contextlib import AbstractContextManager, contextmanager, ExitStack, nullcontext
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, Optional, overload, TypeVar, Union
|
from typing import Any, Callable, Optional, overload, TypeVar, Union
|
||||||
|
|
||||||
|
|
@ -266,11 +266,12 @@ def _set_compilation_env():
|
||||||
|
|
||||||
# The invariant here is that we always trace the branch with fake tensor
|
# The invariant here is that we always trace the branch with fake tensor
|
||||||
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
|
def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
|
||||||
fake_mode = detect_fake_mode(inputs)
|
fake_mode_det = detect_fake_mode(inputs)
|
||||||
tracing_mode = "real"
|
fake_mode: AbstractContextManager = nullcontext()
|
||||||
if fake_mode is None:
|
tracing_mode = "fake"
|
||||||
fake_mode = nullcontext()
|
if fake_mode_det is not None:
|
||||||
tracing_mode = "fake"
|
fake_mode = fake_mode_det
|
||||||
|
tracing_mode = "real"
|
||||||
|
|
||||||
# Note: we need to turn off proxy tensor mode to avoid tracing infra
|
# Note: we need to turn off proxy tensor mode to avoid tracing infra
|
||||||
# code that happens in make_fx e.g. we now call as_strided when wrapping tensor
|
# code that happens in make_fx e.g. we now call as_strided when wrapping tensor
|
||||||
|
|
@ -282,9 +283,12 @@ def _maybe_fake_tracing(fn, inputs: list[Any], pre_dispatch):
|
||||||
pre_dispatch=pre_dispatch,
|
pre_dispatch=pre_dispatch,
|
||||||
_error_on_data_dependent_ops=False,
|
_error_on_data_dependent_ops=False,
|
||||||
)(*inputs)
|
)(*inputs)
|
||||||
if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None:
|
if not isinstance(fake_mode, nullcontext) and fake_mode.shape_env is not None: # type: ignore[attr-defined]
|
||||||
insert_deferred_runtime_asserts(
|
insert_deferred_runtime_asserts(
|
||||||
gm, fake_mode.shape_env, "hoo_maybe_fake_tracing", export=True
|
gm,
|
||||||
|
fake_mode.shape_env, # type: ignore[attr-defined]
|
||||||
|
"hoo_maybe_fake_tracing",
|
||||||
|
export=True, # type: ignore[attr-defined]
|
||||||
)
|
)
|
||||||
return gm
|
return gm
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1065,7 +1065,7 @@ class GuardedCache(Generic[T]):
|
||||||
Helper to get the shape env from the tracing context.
|
Helper to get the shape env from the tracing context.
|
||||||
"""
|
"""
|
||||||
ctx = torch._guards.TracingContext.try_get()
|
ctx = torch._guards.TracingContext.try_get()
|
||||||
if not ctx:
|
if not ctx or not ctx.fake_mode:
|
||||||
return None
|
return None
|
||||||
return ctx.fake_mode.shape_env
|
return ctx.fake_mode.shape_env
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1942,7 +1942,7 @@ def fw_compiler_freezing(
|
||||||
idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node)
|
idx for idx, n in enumerate(model_outputs) if isinstance(n, torch.fx.Node)
|
||||||
]
|
]
|
||||||
|
|
||||||
static_input_idxs = []
|
static_input_idxs: list[Any] = []
|
||||||
# constant params will be real tensors, not fake
|
# constant params will be real tensors, not fake
|
||||||
tracing_context = torch._guards.TracingContext.try_get()
|
tracing_context = torch._guards.TracingContext.try_get()
|
||||||
unwrapped_args_offsets = [0]
|
unwrapped_args_offsets = [0]
|
||||||
|
|
@ -2461,6 +2461,7 @@ def compile_fx(
|
||||||
if node.op == "get_attr" and "val" not in node.meta:
|
if node.op == "get_attr" and "val" not in node.meta:
|
||||||
target = attrgetter(node.target)(gm)
|
target = attrgetter(node.target)(gm)
|
||||||
if isinstance(target, torch.Tensor):
|
if isinstance(target, torch.Tensor):
|
||||||
|
assert fake_mode is not None
|
||||||
node.meta["val"] = fake_mode.from_tensor(
|
node.meta["val"] = fake_mode.from_tensor(
|
||||||
target, static_shapes=True
|
target, static_shapes=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1429,7 +1429,9 @@ def register_replacement(
|
||||||
)
|
)
|
||||||
|
|
||||||
sym_args: list[torch.SymInt] = []
|
sym_args: list[torch.SymInt] = []
|
||||||
with torch._dynamo.utils.detect_fake_mode(args):
|
fake_mode = torch._dynamo.utils.detect_fake_mode(args)
|
||||||
|
assert fake_mode is not None
|
||||||
|
with fake_mode:
|
||||||
for i, grad in enumerate(requires_grad):
|
for i, grad in enumerate(requires_grad):
|
||||||
if isinstance(args[i], torch.Tensor):
|
if isinstance(args[i], torch.Tensor):
|
||||||
if grad and is_integer_dtype(args[i].dtype):
|
if grad and is_integer_dtype(args[i].dtype):
|
||||||
|
|
|
||||||
|
|
@ -203,6 +203,7 @@ def standalone_compile(
|
||||||
# Reuse fake_mode from the TracingContext.
|
# Reuse fake_mode from the TracingContext.
|
||||||
# NB: The TracingContext only exists if we're currently in a torch.compile backend.
|
# NB: The TracingContext only exists if we're currently in a torch.compile backend.
|
||||||
context = torch._guards.TracingContext.get()
|
context = torch._guards.TracingContext.get()
|
||||||
|
assert context.fake_mode is not None
|
||||||
fake_mode = context.fake_mode
|
fake_mode = context.fake_mode
|
||||||
elif dynamic_shapes == "from_graph":
|
elif dynamic_shapes == "from_graph":
|
||||||
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
fake_mode = FakeTensorMode(shape_env=ShapeEnv())
|
||||||
|
|
|
||||||
|
|
@ -2817,10 +2817,9 @@ def maybe_get_suppress_shape_guards_ctx() -> contextlib.AbstractContextManager[N
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
|
||||||
# In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
|
# In standalone inductor compile mode, we might not have a shape_env attached to the fake mode
|
||||||
shape_env = tracing_context.fake_mode.shape_env
|
if not tracing_context.fake_mode or not tracing_context.fake_mode.shape_env:
|
||||||
if not shape_env:
|
|
||||||
return contextlib.nullcontext()
|
return contextlib.nullcontext()
|
||||||
|
shape_env = tracing_context.fake_mode.shape_env
|
||||||
return shape_env.suppress_guards()
|
return shape_env.suppress_guards()
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -3343,12 +3342,13 @@ def tabulate_2d(elements: Sequence[Sequence[T]], headers: Sequence[T]) -> str:
|
||||||
for i, e in enumerate(row):
|
for i, e in enumerate(row):
|
||||||
widths[i] = max(widths[i], len(str(e)))
|
widths[i] = max(widths[i], len(str(e)))
|
||||||
lines = []
|
lines = []
|
||||||
lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths)))
|
# Need nested {} for string formatting; ignore SET_LINTER here
|
||||||
|
lines.append("|".join(f" {h:{w}} " for h, w in zip(headers, widths))) # noqa: set_linter
|
||||||
# widths whitespace horizontal separators
|
# widths whitespace horizontal separators
|
||||||
total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
|
total_width = sum(widths) + (len(widths) * 2) + (len(widths) - 1)
|
||||||
lines.append("-" * total_width)
|
lines.append("-" * total_width)
|
||||||
for row in elements:
|
for row in elements:
|
||||||
lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths)))
|
lines.append("|".join(f" {e:{w}} " for e, w in zip(row, widths))) # noqa: set_linter
|
||||||
return "\n".join(lines)
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ import operator
|
||||||
import typing
|
import typing
|
||||||
import warnings
|
import warnings
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from contextlib import nullcontext
|
from contextlib import AbstractContextManager, nullcontext
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from functools import reduce
|
from functools import reduce
|
||||||
from typing import (
|
from typing import (
|
||||||
|
|
@ -2090,7 +2090,9 @@ def alert_not_deterministic(caller: str):
|
||||||
|
|
||||||
class CUDARngStateHelper:
|
class CUDARngStateHelper:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_torch_state_as_tuple(fake_mode=nullcontext()):
|
def get_torch_state_as_tuple(
|
||||||
|
fake_mode: AbstractContextManager[Any] = nullcontext(),
|
||||||
|
):
|
||||||
if not torch.cuda.is_available():
|
if not torch.cuda.is_available():
|
||||||
raise RuntimeError("CUDA not available")
|
raise RuntimeError("CUDA not available")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -573,8 +573,8 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||||
delattr(ep.graph_module, name)
|
delattr(ep.graph_module, name)
|
||||||
|
|
||||||
# TODO(zhxhchen17) Return the new graph_signature directly.
|
# TODO(zhxhchen17) Return the new graph_signature directly.
|
||||||
fake_mode = detect_fake_mode(fake_args)
|
fake_mode_det = detect_fake_mode(fake_args)
|
||||||
fake_mode = contextlib.nullcontext() if fake_mode is None else fake_mode # type: ignore[assignment]
|
fake_mode_ctx = contextlib.nullcontext() if fake_mode_det is None else fake_mode_det # type: ignore[assignment]
|
||||||
custom_triton_ops_decomposition_ctx = (
|
custom_triton_ops_decomposition_ctx = (
|
||||||
contextlib.nullcontext
|
contextlib.nullcontext
|
||||||
if decompose_custom_triton_ops
|
if decompose_custom_triton_ops
|
||||||
|
|
@ -582,7 +582,7 @@ def _decompose_and_get_gm_with_new_signature_constants(
|
||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
_ignore_backend_decomps(),
|
_ignore_backend_decomps(),
|
||||||
fake_mode,
|
fake_mode_ctx,
|
||||||
_override_composite_implicit_decomp(cia_to_decomp),
|
_override_composite_implicit_decomp(cia_to_decomp),
|
||||||
custom_triton_ops_decomposition_ctx(),
|
custom_triton_ops_decomposition_ctx(),
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@ -7872,7 +7872,9 @@ class PropagateUnbackedSymInts(torch.fx.Interpreter):
|
||||||
from torch._guards import detect_fake_mode
|
from torch._guards import detect_fake_mode
|
||||||
|
|
||||||
result = super().run_node(n)
|
result = super().run_node(n)
|
||||||
rebind_unbacked(detect_fake_mode().shape_env, n, result)
|
fake_mode = detect_fake_mode()
|
||||||
|
assert fake_mode is not None
|
||||||
|
rebind_unbacked(fake_mode.shape_env, n, result)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1028,7 +1028,7 @@ def bound_sympy(
|
||||||
|
|
||||||
# If there's a tracing context, augment available constrained ranges.
|
# If there's a tracing context, augment available constrained ranges.
|
||||||
context = torch._guards.TracingContext.try_get()
|
context = torch._guards.TracingContext.try_get()
|
||||||
if context and context.fake_mode.shape_env:
|
if context and context.fake_mode and context.fake_mode.shape_env:
|
||||||
if ranges:
|
if ranges:
|
||||||
ranges = {**context.fake_mode.shape_env.var_to_range, **ranges}
|
ranges = {**context.fake_mode.shape_env.var_to_range, **ranges}
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user