pytorch/torch/_dynamo/graph_bytecode_inputs.py
Michael Lazos f5cb9a4c68 [user-streams] Fix stream graph output semantics (#164819)
Previously, we would stash a single stream value we constructed at trace time in a global and return the same value from repeated calls to the graph.

With this PR, we construct the stream value in advance, reference the constructed value in the graph via the lookup table, and if that value is returned as an output, read the value from the lookup table and return it (in bytecode, not as a graph output, since we don't support arbitrary stream outputs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164819
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304, #164522
2025-10-30 04:58:46 +00:00

91 lines
2.9 KiB
Python

import weakref
from typing import Any, Callable
from torch._dynamo.source import Source
# Alias used in this file's annotations wherever a PyCodegen is expected.
# NOTE(review): presumably stands in for the real codegen class to avoid
# importing it here -- confirm.
PyCodegen = Any
# This file handles objects whose types we don't want to support as explicit
# FX graph inputs. It uses side tables which are populated in bytecode and
# read during graph execution.
# A dynamo-generated index serves as a level of indirection: it lets us
# register objects externally, in pre-graph bytecode, that we want to pass to
# the graph without supporting their types as graph inputs.
#
# index -> callable invoked with a PyCodegen. register_user_object stores one
# that calls `cg(source)`; register_graph_created_object stores one that calls
# `construct_fn(index, cg)`. The next index handed out is this dict's length.
index_to_bytecode_constructor: dict[int, Callable[[PyCodegen], None]] = {}
# index -> weak reference to the registered object; dereferenced by
# get_external_object_by_index and rebuilt by store_user_object_weakrefs.
index_to_external_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
# Strong references to the example values passed to
# register_graph_created_object, so their weak references above stay valid
# until reset_user_object_tracking clears this list.
keep_alive: list[Any] = []
def has_user_objects() -> bool:
    """Return True when at least one bytecode constructor is registered."""
    return len(index_to_bytecode_constructor) > 0
def get_external_object_by_index(index: int) -> Any:
    """Return the live object registered under ``index``.

    Args:
        index: Key into ``index_to_external_object_weakref``.

    Returns:
        The object the stored weak reference points to.

    Raises:
        AssertionError: If ``index`` was never registered, or if the
            referenced object has already been garbage collected.
    """
    # Single lookup; stored values are always weakref objects, so a ``None``
    # result can only mean the index is missing.
    ref = index_to_external_object_weakref.get(index)
    assert ref is not None, (
        "Index not registered in index_to_external_object_weakref"
    )
    # Dereference exactly once and return that strong reference, rather than
    # dereferencing the weakref a second time after the liveness check.
    obj = ref()
    assert obj is not None, "User object is no longer alive"
    return obj
def store_user_object_weakrefs(*args: Any) -> None:
    """Replace the weakref table with weak references to ``args``.

    The table is emptied first; afterwards entry ``i`` holds a weak reference
    to the ``i``-th positional argument.
    """
    index_to_external_object_weakref.clear()
    # Build the complete mapping before touching the table again, so a failure
    # while creating a weak reference leaves the table empty.
    fresh_refs = dict(enumerate(map(weakref.ref, args)))
    index_to_external_object_weakref.update(fresh_refs)
def reset_user_object_tracking() -> None:
    """Empty every side table used to track registered objects."""
    tracked_containers = (
        index_to_bytecode_constructor,
        index_to_external_object_weakref,
        keep_alive,
    )
    for container in tracked_containers:
        container.clear()
def register_graph_created_object(
    example_value: Any, construct_fn: Callable[[int, PyCodegen], None]
) -> int:
    """Register an object created during tracing and return its index.

    Args:
        example_value: The object to track. A strong reference to it is kept
            in ``keep_alive`` and a weak reference in
            ``index_to_external_object_weakref``.
        construct_fn: Called as ``construct_fn(index, cg)`` when the stored
            bytecode constructor is invoked with a codegen object ``cg``.

    Returns:
        The index assigned to the object (the number of constructors that
        were registered before this call).

    Raises:
        Whatever ``unimplemented_v2`` raises when ``example_value`` does not
        support weak references.
    """
    # Create the weak reference before mutating any table: if the object does
    # not support weak references, no constructor or keep-alive entry is left
    # behind without a matching weakref entry.
    try:
        example_ref = weakref.ref(example_value)
    except TypeError as e:
        from .exc import unimplemented_v2

        # NOTE(review): relies on unimplemented_v2 always raising (it is
        # declared NoReturn in torch._dynamo.exc) -- confirm.
        unimplemented_v2(
            gb_type="Failed to make weakref to graph-created external object",
            context=f"user_object: {example_value}",
            explanation="Object does not allow us to make a weakref to it",
            hints=[],
            from_exc=e,
        )

    keep_alive.append(example_value)
    index = len(index_to_bytecode_constructor)
    index_to_bytecode_constructor[index] = lambda cg: construct_fn(index, cg)
    index_to_external_object_weakref[index] = example_ref
    return index
def register_user_object(value: Any, source: Source) -> int:
    """Register a user object to be used in the graph and return its index.

    Args:
        value: The object to track; only a weak reference to it is stored.
        source: Passed to the codegen object (``cg(source)``) when the stored
            bytecode constructor is invoked.

    Returns:
        The index assigned to the object (the number of constructors that
        were registered before this call).

    Raises:
        Whatever ``unimplemented_v2`` raises when ``value`` does not support
        weak references.
    """
    # Create the weak reference before mutating any table: if the object does
    # not support weak references, no constructor is left registered without a
    # matching weakref entry.
    try:
        value_ref = weakref.ref(value)
    except TypeError as e:
        from .exc import unimplemented_v2

        # NOTE(review): relies on unimplemented_v2 always raising (it is
        # declared NoReturn in torch._dynamo.exc) -- confirm.
        unimplemented_v2(
            gb_type="Failed to make weakref to User Object",
            context=f"user_object: {value}",
            explanation="Object does not allow us to make a weakref to it",
            hints=[],
            from_exc=e,
        )

    index = len(index_to_bytecode_constructor)
    index_to_bytecode_constructor[index] = lambda cg: cg(source)
    index_to_external_object_weakref[index] = value_ref
    return index