mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Convert InstructionTranslatorGraphState and OutputGraphState to NamedTuple (#90186)
Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/90186 Approved by: https://github.com/voznesenskym
This commit is contained in:
parent
1119aac485
commit
ca5f69ef19
|
|
@ -7,7 +7,18 @@ import operator
|
|||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Set, Tuple, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
OrderedDict,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import sympy
|
||||
from typing_extensions import Protocol
|
||||
|
|
@ -56,9 +67,12 @@ class CompiledFn(Protocol):
|
|||
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]
|
||||
|
||||
|
||||
OutputGraphState = Tuple[
|
||||
List[GraphArg], Set[Guard], Optional[Dict[str, torch.nn.Module]], SideEffects, int
|
||||
]
|
||||
class OutputGraphState(NamedTuple):
|
||||
graphargs: List[GraphArg]
|
||||
guards: Set[Guard]
|
||||
nn_modules: Optional[Dict[str, torch.nn.Module]]
|
||||
side_effects: SideEffects
|
||||
timestamp: int
|
||||
|
||||
|
||||
@functools.lru_cache(None)
|
||||
|
|
@ -207,7 +221,7 @@ class OutputGraph(fx.Tracer):
|
|||
def copy_graphstate(self) -> OutputGraphState:
|
||||
"""Create a checkpoint of the current state by copying everything"""
|
||||
assert self.nn_modules is not None
|
||||
state = (
|
||||
state = OutputGraphState(
|
||||
list(self.graphargs),
|
||||
set(self.guards),
|
||||
dict(self.nn_modules),
|
||||
|
|
@ -217,7 +231,7 @@ class OutputGraph(fx.Tracer):
|
|||
self.timestamp += 1
|
||||
return state
|
||||
|
||||
def restore_graphstate(self, state):
|
||||
def restore_graphstate(self, state: OutputGraphState):
|
||||
"""Restore a checkpoint created by self.copy_graphstate()"""
|
||||
(
|
||||
self.graphargs,
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import types
|
|||
import typing
|
||||
import weakref
|
||||
from collections.abc import Sized
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
|
||||
from unittest.mock import patch
|
||||
|
||||
import torch
|
||||
|
|
@ -106,16 +106,15 @@ class BlockStackEntry:
|
|||
return self.with_context.exit(tx)
|
||||
|
||||
|
||||
InstructionTranslatorGraphState = Tuple[
|
||||
OutputGraphState,
|
||||
Dict[str, VariableTracker],
|
||||
List[VariableTracker],
|
||||
List[BlockStackEntry],
|
||||
Optional[int],
|
||||
Instruction,
|
||||
Optional[Instruction],
|
||||
int,
|
||||
]
|
||||
class InstructionTranslatorGraphState(NamedTuple):
|
||||
output: OutputGraphState
|
||||
symbolic_locals: Dict[str, VariableTracker]
|
||||
stack: List[VariableTracker]
|
||||
block_stack: List[BlockStackEntry]
|
||||
instruction_pointer: Optional[int]
|
||||
current_instruction: Instruction
|
||||
next_instruction: Optional[Instruction]
|
||||
lineno: int
|
||||
|
||||
|
||||
def stack_op(fn: typing.Callable[..., object]):
|
||||
|
|
@ -1441,7 +1440,7 @@ class InstructionTranslatorBase(object):
|
|||
|
||||
def copy_graphstate(self) -> InstructionTranslatorGraphState:
|
||||
"""Create a checkpoint of the current state by copying everything"""
|
||||
return (
|
||||
return InstructionTranslatorGraphState(
|
||||
self.output.copy_graphstate(),
|
||||
collections.OrderedDict(self.symbolic_locals),
|
||||
list(self.stack),
|
||||
|
|
@ -1452,7 +1451,7 @@ class InstructionTranslatorBase(object):
|
|||
self.lineno,
|
||||
)
|
||||
|
||||
def restore_graphstate(self, state):
|
||||
def restore_graphstate(self, state: InstructionTranslatorGraphState):
|
||||
"""Restore a checkpoint created by self.copy_graphstate()"""
|
||||
(
|
||||
output_state,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user