mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This reverts commit f5cb9a4c68.
Reverted https://github.com/pytorch/pytorch/pull/164819 on behalf of https://github.com/atalman due to breaks CI ([comment](https://github.com/pytorch/pytorch/pull/164819#issuecomment-3469018283))
63 lines
1.9 KiB
Python
63 lines
1.9 KiB
Python
import weakref
|
|
from typing import Any
|
|
|
|
from torch._dynamo.source import Source
|
|
|
|
|
|
# This file handles objects whose types we do not want to support as
# explicit FX graph inputs. It uses a side table that is populated in
# bytecode and loaded during graph execution.
#
# A dynamo-generated index serves as a level of indirection: it lets us
# register objects externally, in pre-graph bytecode, that we want to pass
# to the graph without supporting their types as graph inputs.
|
|
# Side table written by register_user_object: maps each index handed out
# there to the Source that was supplied for the registered object.
index_to_source: dict[int, Source] = {}

# Side table mapping an index to a weak reference to the user object itself.
# Only weak references are stored, so an entry does not keep its object alive.
# register_user_object adds entries; store_user_object_weakrefs rebuilds the
# whole table from its positional arguments.
index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
|
|
|
|
|
|
def has_user_objects() -> bool:
    """Report whether any user object is currently registered.

    Returns:
        True when ``index_to_source`` holds at least one entry, False when
        it is empty.
    """
    return len(index_to_source) > 0
|
|
|
|
|
|
def get_user_object_by_index(index: int) -> Any:
    """Return the live user object stored under ``index``.

    Args:
        index: Key into ``index_to_user_object_weakref``.

    Returns:
        The object the stored weak reference points to.

    Raises:
        AssertionError: If ``index`` is not present in
            ``index_to_user_object_weakref``, or if the referenced object
            has already been garbage collected.
    """
    assert index in index_to_user_object_weakref, (
        "Index not registered in index_to_user_object_weakref"
    )
    # Dereference the weakref exactly once. ``obj`` is a strong reference,
    # so the object that passes the liveness check below is the very object
    # that gets returned.
    obj = index_to_user_object_weakref[index]()
    assert obj is not None, "User object is no longer alive"
    return obj
|
|
|
|
|
|
def store_user_object_weakrefs(*args: Any) -> None:
    """Replace the weakref table with weak references to ``args``.

    Any previous contents of ``index_to_user_object_weakref`` are discarded;
    afterwards entry ``i`` is a weak reference to the ``i``-th positional
    argument.

    Args:
        *args: Objects to reference weakly, keyed by their position.

    Raises:
        TypeError: If an argument does not support weak references. The
            table has already been emptied at that point and stays empty.
    """
    index_to_user_object_weakref.clear()
    # Build the complete mapping before touching the table again, so a
    # failing weakref.ref call never leaves a partially filled table behind.
    refs_by_position = dict(enumerate(map(weakref.ref, args)))
    index_to_user_object_weakref.update(refs_by_position)
|
|
|
|
|
|
def reset_user_object_tracking() -> None:
    """Empty both side tables in place.

    Clears ``index_to_source`` and ``index_to_user_object_weakref``; the
    dict objects themselves are kept, only their entries are removed.
    """
    for table in (index_to_source, index_to_user_object_weakref):
        table.clear()
|
|
|
|
|
|
def register_user_object(value: Any, source: Source) -> int:
    """Register a user object to be used in the graph.

    The new index is the number of sources registered so far. ``source`` is
    recorded under that index in ``index_to_source`` and a weak reference to
    ``value`` is recorded under the same index in
    ``index_to_user_object_weakref``.

    Args:
        value: The user object; it must support weak references.
        source: The Source to associate with the object.

    Returns:
        The index assigned to the object.

    If ``value`` cannot be weakly referenced, the ``TypeError`` is handed to
    ``unimplemented_v2`` (presumably this raises a graph-break error; confirm
    in ``.exc``).
    """
    new_index = len(index_to_source)
    index_to_source[new_index] = source
    try:
        value_ref = weakref.ref(value)
    except TypeError as err:
        # Imported lazily, only on the failure path.
        from .exc import unimplemented_v2

        # NOTE(review): the source entry stored above is not removed on this
        # path, so index_to_source keeps an entry without a matching weakref.
        # Confirm that callers reset the tables after this failure.
        unimplemented_v2(
            gb_type="Failed to make weakref to User Object",
            context=f"user_object: {value}",
            explanation="Object does not allow us to make a weakref to it",
            hints=[],
            from_exc=err,
        )
    else:
        index_to_user_object_weakref[new_index] = value_ref
    return new_index
|