Convert InstructionTranslatorGraphState and OutputGraphState to NamedTuple (#90186)

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90186
Approved by: https://github.com/voznesenskym
This commit is contained in:
Edward Z. Yang 2022-12-07 12:01:35 -08:00 committed by PyTorch MergeBot
parent 1119aac485
commit ca5f69ef19
2 changed files with 32 additions and 19 deletions

View File

@ -7,7 +7,18 @@ import operator
import re
import traceback
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, OrderedDict, Set, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
NamedTuple,
Optional,
OrderedDict,
Set,
Tuple,
Union,
)
import sympy
from typing_extensions import Protocol
@ -56,9 +67,12 @@ class CompiledFn(Protocol):
CompilerFn = Callable[[fx.GraphModule, List[torch.Tensor]], CompiledFn]
OutputGraphState = Tuple[
List[GraphArg], Set[Guard], Optional[Dict[str, torch.nn.Module]], SideEffects, int
]
class OutputGraphState(NamedTuple):
graphargs: List[GraphArg]
guards: Set[Guard]
nn_modules: Optional[Dict[str, torch.nn.Module]]
side_effects: SideEffects
timestamp: int
@functools.lru_cache(None)
@ -207,7 +221,7 @@ class OutputGraph(fx.Tracer):
def copy_graphstate(self) -> OutputGraphState:
"""Create a checkpoint of the current state by copying everything"""
assert self.nn_modules is not None
state = (
state = OutputGraphState(
list(self.graphargs),
set(self.guards),
dict(self.nn_modules),
@ -217,7 +231,7 @@ class OutputGraph(fx.Tracer):
self.timestamp += 1
return state
def restore_graphstate(self, state):
def restore_graphstate(self, state: OutputGraphState):
"""Restore a checkpoint created by self.copy_graphstate()"""
(
self.graphargs,

View File

@ -13,7 +13,7 @@ import types
import typing
import weakref
from collections.abc import Sized
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Tuple
from unittest.mock import patch
import torch
@ -106,16 +106,15 @@ class BlockStackEntry:
return self.with_context.exit(tx)
InstructionTranslatorGraphState = Tuple[
OutputGraphState,
Dict[str, VariableTracker],
List[VariableTracker],
List[BlockStackEntry],
Optional[int],
Instruction,
Optional[Instruction],
int,
]
class InstructionTranslatorGraphState(NamedTuple):
output: OutputGraphState
symbolic_locals: Dict[str, VariableTracker]
stack: List[VariableTracker]
block_stack: List[BlockStackEntry]
instruction_pointer: Optional[int]
current_instruction: Instruction
next_instruction: Optional[Instruction]
lineno: int
def stack_op(fn: typing.Callable[..., object]):
@ -1441,7 +1440,7 @@ class InstructionTranslatorBase(object):
def copy_graphstate(self) -> InstructionTranslatorGraphState:
"""Create a checkpoint of the current state by copying everything"""
return (
return InstructionTranslatorGraphState(
self.output.copy_graphstate(),
collections.OrderedDict(self.symbolic_locals),
list(self.stack),
@ -1452,7 +1451,7 @@ class InstructionTranslatorBase(object):
self.lineno,
)
def restore_graphstate(self, state):
def restore_graphstate(self, state: InstructionTranslatorGraphState):
"""Restore a checkpoint created by self.copy_graphstate()"""
(
output_state,