mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[user-streams] Have StreamVariable inherit from StreamContextVariable (#164344)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/164344 Approved by: https://github.com/williamwen42 ghstack dependencies: #162903, #164343
This commit is contained in:
parent
aab27b051a
commit
e105a47575
|
|
@ -153,7 +153,7 @@ class StreamContextVariable(ContextWrappingVariable):
|
|||
return current_stream
|
||||
|
||||
|
||||
class StreamVariable(VariableTracker):
|
||||
class StreamVariable(StreamContextVariable):
|
||||
"""Represents the device-agnostic torch.Stream class"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -168,7 +168,9 @@ class StreamVariable(VariableTracker):
|
|||
assert value.device.type == device.type, (
|
||||
"stream value is not equal to the passed device"
|
||||
)
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(
|
||||
target_values=[self], initial_values=None, device=device, **kwargs
|
||||
)
|
||||
self.proxy = proxy
|
||||
self.value = value
|
||||
# pyrefly: ignore [read-only]
|
||||
|
|
@ -230,6 +232,16 @@ class StreamVariable(VariableTracker):
|
|||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def enter(self, tx: "InstructionTranslator") -> "VariableTracker":
|
||||
# NB: Set initial values and target values when we enter
|
||||
# Don't do this at object creation, as we need to record the current stream
|
||||
# at the time the context is entered.
|
||||
self.initial_values = [
|
||||
StreamContextVariable._get_current_stream(self.device, tx)
|
||||
]
|
||||
self.target_values = [self]
|
||||
return super().enter(tx)
|
||||
|
||||
def as_proxy(self) -> Proxy:
|
||||
return self.proxy
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user