[user-streams] Have StreamVariable inherit from StreamContextVariable (#164344)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164344
Approved by: https://github.com/williamwen42
ghstack dependencies: #162903, #164343
This commit is contained in:
Michael Lazos 2025-10-28 15:34:07 -07:00 committed by PyTorch MergeBot
parent aab27b051a
commit e105a47575

View File

@ -153,7 +153,7 @@ class StreamContextVariable(ContextWrappingVariable):
return current_stream
class StreamVariable(VariableTracker):
class StreamVariable(StreamContextVariable):
"""Represents the device-agnostic torch.Stream class"""
def __init__(
@ -168,7 +168,9 @@ class StreamVariable(VariableTracker):
assert value.device.type == device.type, (
"stream value is not equal to the passed device"
)
super().__init__(**kwargs)
super().__init__(
target_values=[self], initial_values=None, device=device, **kwargs
)
self.proxy = proxy
self.value = value
# pyrefly: ignore [read-only]
@ -230,6 +232,16 @@ class StreamVariable(VariableTracker):
return super().call_method(tx, name, args, kwargs)
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
# NB: Set initial values and target values when we enter
# Don't do this at object creation, as we need to record the current stream
# at the time the context is entered.
self.initial_values = [
StreamContextVariable._get_current_stream(self.device, tx)
]
self.target_values = [self]
return super().enter(tx)
def as_proxy(self) -> Proxy:
return self.proxy