pytorch/torch/_dynamo/graph_bytecode_inputs.py

63 lines
1.9 KiB
Python

import weakref
from typing import Any
from torch._dynamo.source import Source
# This file handles types that we don't want to support as explicit FX graph
# inputs. It uses a side table which we populate in bytecode and which is
# loaded during graph execution.
# We use a dynamo-generated index as a level of indirection: this allows us to
# register objects externally, in pre-graph bytecode, that we want to pass to
# the graph without supporting their types as graph inputs.
# Side table: index -> Source recorded for the user object registered at that
# index (written by register_user_object).
index_to_source: dict[int, Source] = {}
# Side table: index -> weak reference to the user object itself. Weak
# references are stored, so this table alone does not keep the objects alive.
index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
def has_user_objects() -> bool:
    """Report whether any user object source is currently registered.

    Returns:
        ``True`` when ``index_to_source`` holds at least one entry,
        ``False`` when it is empty.
    """
    return len(index_to_source) > 0
def get_user_object_by_index(index: int) -> Any:
    """Return the live user object stored under ``index``.

    Args:
        index: Key into ``index_to_user_object_weakref``.

    Returns:
        The object that the stored weak reference points to.

    Raises:
        AssertionError: If ``index`` has no entry in the table, or if the
            referenced object is no longer alive.
    """
    assert index in index_to_user_object_weakref, (
        "Index not registered in index_to_user_object_weakref"
    )
    # Dereference exactly once: ``obj`` is a strong reference, so returning it
    # avoids a redundant second table lookup and weakref call.
    obj = index_to_user_object_weakref[index]()
    assert obj is not None, "User object is no longer alive"
    return obj
def store_user_object_weakrefs(*args: Any) -> None:
    """Replace the weakref side table with weak references to ``args``.

    The existing contents of ``index_to_user_object_weakref`` are discarded
    first; afterwards each argument is reachable under its positional index
    (0 for the first argument, 1 for the second, and so on).

    Args:
        *args: Objects to reference weakly. ``weakref.ref`` is called on each
            one, so every object must support weak references.
    """
    index_to_user_object_weakref.clear()
    # Build the complete mapping before inserting anything, so the table is
    # only repopulated once every weak reference has been created.
    refs_by_position = {}
    for position, user_object in enumerate(args):
        refs_by_position[position] = weakref.ref(user_object)
    index_to_user_object_weakref.update(refs_by_position)
def reset_user_object_tracking() -> None:
    """Empty both side tables: the index-to-source map and the weakref map."""
    for table in (index_to_source, index_to_user_object_weakref):
        table.clear()
# Register a user object to be used in the graph
def register_user_object(value: Any, source: Source) -> int:
    """Record ``value`` and its ``source`` in the side tables.

    The new index is the current number of entries in ``index_to_source``.
    The source is stored first; a weak reference to ``value`` is then stored
    under the same index.

    Args:
        value: The user object to register; it must support weak references.
        source: The Source to associate with the new index.

    Returns:
        The index under which ``source`` (and, on success, the weak
        reference to ``value``) was stored.

    Note:
        If ``weakref.ref(value)`` raises ``TypeError``, the failure is handed
        to ``unimplemented_v2``. The ``index_to_source`` entry has already
        been written at that point and is not removed here.
    """
    slot = len(index_to_source)
    index_to_source[slot] = source
    try:
        value_ref = weakref.ref(value)
    except TypeError as err:
        # Imported lazily, inside the failure path only.
        from .exc import unimplemented_v2

        unimplemented_v2(
            gb_type="Failed to make weakref to User Object",
            context=f"user_object: {value}",
            explanation="Object does not allow us to make a weakref to it",
            hints=[],
            from_exc=err,
        )
    else:
        index_to_user_object_weakref[slot] = value_ref
    return slot