pytorch/torch/_dynamo/variables/streams.py
Michael Lazos f5cb9a4c68 [user-streams] Fix stream graph output semantics (#164819)
Previously, we would stash a single stream value we constructed at trace time in a global and return the same value from repeated calls to the graph.

With this PR, we construct the stream value in advance, reference the constructed value in the graph via the lookup table, and if that value is returned as an output, read the value from the lookup table and return it (in bytecode, not as a graph output, since we don't support arbitrary stream outputs).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164819
Approved by: https://github.com/anijain2305
ghstack dependencies: #164304, #164522
2025-10-30 04:58:46 +00:00

352 lines
12 KiB
Python

from typing import Any, Optional
import torch
from torch.fx import Proxy
from .. import graph_break_hints
from ..bytecode_transformation import create_call_function
from ..device_interface import get_interface_for_device
from ..exc import TYPE_CHECKING, unimplemented_v2
from ..source import AttrSource, CallFunctionNoArgsSource, TorchSource
from .base import VariableTracker
from .constant import ConstantVariable
from .ctx_manager import ContextWrappingVariable
from .misc import GetAttrVariable
if TYPE_CHECKING:
from torch._dynamo.symbolic_convert import InstructionTranslator
from ..codegen import PyCodegen
from torch._library.custom_ops import custom_op
# Module-level shorthand for ``torch.Tensor``.
Tensor = torch.Tensor
@custom_op("streams::fork", mutates_args=())
def fork_stream(
from_index: int,
from_device: torch.device,
to_index: int,
to_device: torch.device,
) -> None:
pass
@fork_stream.register_fake
def _(
    from_index: int,
    from_device: torch.device,
    to_index: int,
    to_device: torch.device,
) -> None:
    """Fake implementation registered for ``fork_stream``.

    The op has no outputs, so there is nothing to compute here.
    """
    return None
@custom_op("streams::join", mutates_args=())
def join_stream(
from_index: int,
from_device: torch.device,
to_index: int,
to_device: torch.device,
) -> None:
pass
@join_stream.register_fake
def _(
    from_index: int,
    from_device: torch.device,
    to_index: int,
    to_device: torch.device,
) -> None:
    """Fake implementation registered for ``join_stream``.

    The op has no outputs, so there is nothing to compute here.
    """
    return None
class StreamContextVariable(ContextWrappingVariable):
    """Variable tracker standing in for ``torch.cuda.StreamContext``.

    Entering the context records a ``streams::fork`` call in the traced graph
    and exiting records a ``streams::join`` call.  Both calls receive the
    ``stream_id``/``device`` attribute proxies of the target stream and of the
    stream that was current when the context was set up.
    """

    @staticmethod
    def create(
        tx: "InstructionTranslator",
        target_value: "StreamVariable",
        **kwargs: dict[str, Any],
    ) -> "StreamContextVariable":
        """Build a context variable for ``target_value``.

        The initial value is a proxy for the current stream of the target's
        device, recorded in the graph at the time of this call.
        """
        stream_device = target_value.device
        stream_on_entry = StreamContextVariable._get_current_stream(stream_device, tx)
        return StreamContextVariable(
            target_values=[target_value],
            initial_values=[stream_on_entry],
            device=stream_device,
            **kwargs,
        )

    def __init__(
        self,
        target_values: list["StreamVariable"],
        device: torch.device,
        initial_values: Optional[list["StreamVariable"]] = None,
        **kwargs: dict[str, Any],
    ) -> None:
        """Forward the target/initial values to the base class and keep ``device``."""
        super().__init__(
            target_values=target_values,
            initial_values=initial_values,
            **kwargs,
        )
        self.device = device

    def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
        """Record a ``streams::fork`` call and return a constant ``None``."""
        # Argument order is (to stream, from stream): the target stream is
        # being entered and the initial stream is being left.
        fork_op = torch.ops.streams.fork.default
        entering = self._target_stream_proxies()
        leaving = self._initial_stream_proxies()
        tx.output.create_proxy("call_function", fork_op, entering + leaving, {})
        return ConstantVariable.create(None)

    def exit(self, tx: "InstructionTranslator", *args: tuple[Any]) -> "VariableTracker":
        """Record a ``streams::join`` call and return a constant ``None``."""
        # Argument order is (to stream, from stream): the initial stream is
        # being re-entered and the target stream is being left.
        join_op = torch.ops.streams.join.default
        entering = self._initial_stream_proxies()
        leaving = self._target_stream_proxies()
        tx.output.create_proxy("call_function", join_op, entering + leaving, {})
        return ConstantVariable.create(None)

    def _initial_stream_proxies(self) -> tuple[Proxy, Proxy]:
        """Return the (stream_id, device) proxies of the first initial stream."""
        assert self.initial_values, "No initial stream to move from"
        initial_stream = self.initial_values[0]
        return StreamContextVariable._extract_stream_properties(
            initial_stream.as_proxy()
        )

    def _target_stream_proxies(self) -> tuple[Proxy, Proxy]:
        """Return the (stream_id, device) proxies of the first target stream."""
        target_stream = self._get_target_values()[0]
        return StreamContextVariable._extract_stream_properties(
            target_stream.as_proxy()
        )

    @staticmethod
    def _extract_stream_properties(stream_proxy: Proxy) -> tuple[Proxy, Proxy]:
        """Create getattr proxies for ``stream_id`` and ``device``, in that order."""
        id_proxy = GetAttrVariable.create_getattr_proxy(stream_proxy, "stream_id")
        device_proxy = GetAttrVariable.create_getattr_proxy(stream_proxy, "device")
        return id_proxy, device_proxy

    @staticmethod
    def _get_current_stream(
        device: torch.device, tx: "InstructionTranslator"
    ) -> "StreamVariable":
        """Record a call to the device interface's ``current_stream(None)``.

        The resulting proxy is wrapped as a ``StreamVariable``.
        """
        from .builder import wrap_fx_proxy_cls

        device_interface = get_interface_for_device(device)
        stream_proxy = tx.output.create_proxy(
            "call_function",
            device_interface.current_stream,
            (None,),
            {},
        )
        return wrap_fx_proxy_cls(StreamVariable, tx, stream_proxy)

    def _get_target_values(self) -> list["StreamVariable"]:
        """Return the streams this context switches to."""
        # Kept as an overridable hook: StreamVariable stores no target values
        # (it takes no arguments) and captures the current stream only when
        # the context is entered.
        return self.target_values

    def supports_graph_breaks(self) -> bool:
        """Always ``True`` for this context variable."""
        return True
class StreamVariable(StreamContextVariable):
    """Represents the device-agnostic torch.Stream class

    The variable carries the FX proxy for the stream, the concrete
    ``torch.Stream`` value seen at trace time, the stream's device and,
    optionally, an index into the user object table.  Because it subclasses
    ``StreamContextVariable`` it can also be entered as a context manager, in
    which case it acts as its own target stream.
    """

    def __init__(
        self,
        proxy: Proxy,
        value: torch.Stream,
        device: torch.device,
        **kwargs: Any,
    ) -> None:
        """Store the proxy/value/device for this stream.

        Args:
            proxy: FX proxy representing the stream (may be ``None``).
            value: The concrete stream object observed at trace time.
            device: Device of the stream; its type must match ``value.device``.
            **kwargs: Forwarded to the base class, except ``user_obj_index``,
                which is popped and kept as ``self.user_object_index``.
        """
        # Index into the user object table
        # used to pass arbitrary objects to the graph
        user_object_index = kwargs.pop("user_obj_index", None)
        if proxy is not None and "example_value" in proxy.node.meta:
            assert proxy.node.meta["example_value"] == value
        assert value.device.type == device.type, (
            "stream value is not equal to the passed device"
        )
        # Target/initial values are left empty here: the target is ``self``
        # (see _get_target_values) and the initial stream is captured in enter().
        super().__init__(target_values=[], initial_values=None, device=device, **kwargs)
        self.proxy = proxy
        self.value = value
        # pyrefly: ignore [read-only]
        self.device = device
        self.user_object_index = user_object_index

    def python_type(self) -> type:
        """Return ``torch.Stream``."""
        return torch.Stream

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> "VariableTracker":
        """Trace a method call on the stream.

        * ``wait_stream`` / ``synchronize`` / ``wait_event``: recorded in the
          graph as a ``call_method`` node; the result is a constant ``None``.
        * ``query``: recorded in the graph and wrapped as a ``ConstantVariable``.
        * ``record_event``: recorded in the graph and wrapped as an
          ``EventVariable``.
        * Comparison methods (names in ``cmp_name_to_op_mapping``) called with
          exactly one positional argument: evaluated at trace time on the
          concrete stream values, with ``EQUALS_MATCH`` guards installed on
          every operand that has a source.  ``NotImplemented`` is returned when
          the other operand is not a ``StreamVariable``.
        * Anything else is delegated to the base class.
        """
        assert hasattr(self.value, name), f"no stream method found named {name}"

        from ..utils import cmp_name_to_op_mapping, proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait_stream", "synchronize", "wait_event"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return ConstantVariable(None)
        elif name == "query":
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name == "record_event":
            return wrap_fx_proxy_cls(
                target_cls=EventVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
                ),
            )
        elif name in cmp_name_to_op_mapping and len(args) == 1 and not kwargs:
            from ..guards import GuardBuilder, install_guard

            if self.source:
                install_guard(self.source.make_guard(GuardBuilder.EQUALS_MATCH))

            # NB : Checking for mutation is necessary because we compare
            # constant values
            other = args[0]
            if not isinstance(other, StreamVariable):
                return ConstantVariable.create(NotImplemented)

            if other.source:
                # Guard on the *other* operand's own source.  Building this
                # guard from self.source would leave `other` unguarded and
                # would fail outright when self has no source.
                install_guard(other.source.make_guard(GuardBuilder.EQUALS_MATCH))
            return ConstantVariable.create(
                cmp_name_to_op_mapping[name](self.value, other.value)  # type: ignore[arg-type]
            )

        return super().call_method(tx, name, args, kwargs)

    def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
        """Capture the current stream as the initial value, then enter."""
        # NB: Set initial values when we enter
        # Don't do this at object creation, as we need to record the current stream
        # at the time the context is entered.
        self.initial_values = [
            StreamContextVariable._get_current_stream(self.device, tx)
        ]
        return super().enter(tx)

    def as_proxy(self) -> Proxy:
        """Return the FX proxy stored for this stream."""
        return self.proxy

    def module_name(self) -> str:
        """Return the module name string ``"torch._C"``."""
        return "torch._C"

    def fn_name(self) -> str:
        """Return the name string ``"Stream"``."""
        return "Stream"

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that pushes this stream onto the stack.

        With a ``user_object_index`` the bytecode calls
        ``get_external_object_by_index(user_object_index)`` from
        ``torch._dynamo.graph_bytecode_inputs``; otherwise the concrete stream
        value is installed as a global and loaded from there.
        """
        # If we got here, this stream is fully subsumed by the graph - this means it is
        # not an input or global
        assert not self.source
        if self.user_object_index is not None:
            codegen.add_push_null(
                lambda: codegen.load_import_from(
                    torch._dynamo.graph_bytecode_inputs.__name__,
                    "get_external_object_by_index",
                )
            )
            codegen.append_output(codegen.create_load_const(self.user_object_index))
            codegen.extend_output(create_call_function(1, False))
        else:
            # TODO mlazos: evaluate if we still need this
            prefix = f"_stream_{self.device}"
            name = codegen.tx.output.install_global_by_id(prefix, self.value)
            codegen.append_output(codegen.create_load_global(name, add=True))

    @staticmethod
    def construct_in_graph_stream(index: int, codegen: "PyCodegen") -> None:
        """Emit bytecode for a no-argument call of the ``Stream`` attribute.

        The bytecode is produced from a ``CallFunctionNoArgsSource`` wrapping
        ``AttrSource(TorchSource(), "Stream")``.  ``index`` is accepted but not
        used by this method.
        """
        # Use source to create the right bytecode, this
        # isn't an actual input
        source = CallFunctionNoArgsSource(AttrSource(TorchSource(), "Stream"))
        codegen(source)

    def _get_target_values(self) -> list["StreamVariable"]:
        """Return ``[self]``: the stream itself is the context's target."""
        return [self]
class EventVariable(VariableTracker):
    """Variable tracker for a ``torch.Event`` seen while tracing.

    Holds the FX proxy for the event together with the concrete event value.
    """

    def __init__(self, proxy: Proxy, value: torch.Event, **kwargs: Any) -> None:
        """Store ``proxy`` and ``value``; remaining kwargs go to the base class."""
        if proxy is not None:
            node_meta = proxy.node.meta
            if "example_value" in node_meta:
                # The proxy's recorded example must agree with the given value.
                assert node_meta["example_value"] == value
        super().__init__(**kwargs)
        self.proxy = proxy
        self.value = value

    def call_method(
        self,
        tx: "InstructionTranslator",
        name: str,
        args: list[VariableTracker],
        kwargs: dict[str, VariableTracker],
    ) -> VariableTracker:
        """Trace a method call on the event.

        ``wait``, ``record`` and ``synchronize`` are recorded in the graph and
        yield a constant ``None``; ``query`` is recorded and wrapped as a
        ``ConstantVariable``.  Every other method name is reported through
        ``unimplemented_v2``.
        """
        from ..utils import proxy_args_kwargs
        from .builder import wrap_fx_proxy_cls

        if name in ("wait", "record", "synchronize"):
            tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return ConstantVariable(None)

        if name == "query":
            query_proxy = tx.output.create_proxy(
                "call_method", name, *proxy_args_kwargs([self] + args, kwargs)
            )
            return wrap_fx_proxy_cls(
                target_cls=ConstantVariable,
                tx=tx,
                proxy=query_proxy,
            )

        # Unsupported method: report it with its fully qualified name.
        event_cls = type(self.value)
        method_name = f"{event_cls.__module__}.{event_cls.__qualname__}.{name}"
        unimplemented_v2(
            gb_type="Unsupported event method",
            context=str(name),
            explanation=f"Dynamo doesn't support tracing the {method_name} method. "
            f"We currently support wait, record, synchronize, and query.",
            hints=[
                *graph_break_hints.SUPPORTABLE,
            ],
        )

    def as_proxy(self) -> Proxy:
        """Return the FX proxy stored for this event."""
        return self.proxy

    def reconstruct(self, codegen: "PyCodegen") -> None:
        """Emit bytecode that loads this event from an installed global."""
        # Reaching this point means the event is fully subsumed by the graph,
        # i.e. it is neither an input nor a global.
        assert not self.source
        # As with streams, lift the event into a global and generate bytecode
        # that loads it from there.
        global_name = codegen.tx.output.install_global_by_id("_event", self.value)
        codegen.append_output(codegen.create_load_global(global_name, add=True))