Revert "[user-streams] Fix stream graph output semantics (#164819)"
This reverts commit f5cb9a4c68.
Reverted https://github.com/pytorch/pytorch/pull/164819 on behalf of https://github.com/atalman due to breaks CI ([comment](https://github.com/pytorch/pytorch/pull/164819#issuecomment-3469018283))
This commit is contained in:
parent e83be7042e
commit 0a3ac47c0a
@@ -537,7 +537,7 @@ class CtxManagerTests(torch._dynamo.test_case.TestCaseWithNestedGraphBreaks):
             with torch.cuda.stream(new_stream):
                 x = torch.add(x, 4)

-            new_event = torch.Event()
+            new_event = torch.cuda.Event()
             new_event.record(new_stream)

             new_event.wait(cur_stream)
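For readers unfamiliar with the pattern this test exercises, here is a minimal, self-contained sketch of side-stream/event synchronization using the CUDA-specific event type the revert restores. It is not taken from the test file; the tensor, the variable names, and the cuda-availability guard are illustrative assumptions.

import torch

if torch.cuda.is_available():
    x = torch.ones(4, device="cuda")
    cur_stream = torch.cuda.current_stream()
    new_stream = torch.cuda.Stream()

    with torch.cuda.stream(new_stream):
        x = torch.add(x, 4)            # work enqueued on the side stream

    new_event = torch.cuda.Event()     # CUDA-specific event, as in the restored test
    new_event.record(new_stream)       # mark the point after the side-stream work
    new_event.wait(cur_stream)         # current stream waits here before using x

    print(x.cpu())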
@@ -2851,13 +2851,5 @@
         "Move the Placement usage outside the compiled region"
       ]
     }
-  ],
-  "GB0283": [
-    {
-      "Gb_type": "Failed to make weakref to graph-created external object",
-      "Context": "user_object: {example_value}",
-      "Explanation": "Object does not allow us to make a weakref to it",
-      "Hints": []
-    }
   ]
 }
@@ -1,11 +1,9 @@
 import weakref
-from typing import Any, Callable
+from typing import Any

 from torch._dynamo.source import Source
-
-
 PyCodegen = Any

 # This file is to handle types that we don't want to support
 # as explicit FX graph inputs. This uses a sidetable which
 # we populate in bytecode and is loaded during graph execution
@@ -13,70 +11,44 @@ PyCodegen = Any
 # We use a dynamo-generated index as a level of indirection
 # this allows us to register objects externally in pre-graph bytecode that we want
 # to pass to the graph, but not support their types as graph inputs
-index_to_bytecode_constructor: dict[int, Callable[[PyCodegen], None]] = {}
+index_to_source: dict[int, Source] = {}

-index_to_external_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}
-
-keep_alive: list[Any] = []
+index_to_user_object_weakref: dict[int, weakref.ReferenceType[Any]] = {}


 def has_user_objects() -> bool:
-    return bool(index_to_bytecode_constructor)
+    return bool(index_to_source)


-def get_external_object_by_index(index: int) -> Any:
-    assert index in index_to_external_object_weakref, (
+def get_user_object_by_index(index: int) -> Any:
+    assert index in index_to_user_object_weakref, (
         "Index not registered in index_to_user_object_weakref"
     )
-    obj = index_to_external_object_weakref[index]()
+    obj = index_to_user_object_weakref[index]()
     assert obj is not None, "User object is no longer alive"
-    return index_to_external_object_weakref[index]()
+    return index_to_user_object_weakref[index]()


 def store_user_object_weakrefs(*args: Any) -> None:
-    global index_to_external_object_weakref
-    index_to_external_object_weakref.clear()
-    index_to_external_object_weakref.update(
+    global index_to_user_object_weakref
+    index_to_user_object_weakref.clear()
+    index_to_user_object_weakref.update(
         {i: weakref.ref(arg) for i, arg in enumerate(args)}
     )


 def reset_user_object_tracking() -> None:
-    index_to_bytecode_constructor.clear()
-    index_to_external_object_weakref.clear()
-    keep_alive.clear()
-
-
-def register_graph_created_object(
-    example_value: Any, construct_fn: Callable[[int, PyCodegen], None]
-) -> int:
-    global index_to_bytecode_constructor
-    global keep_alive
-    keep_alive.append(example_value)
-    index = len(index_to_bytecode_constructor)
-    index_to_bytecode_constructor[index] = lambda cg: construct_fn(index, cg)
-    try:
-        index_to_external_object_weakref[index] = weakref.ref(example_value)
-    except TypeError as e:
-        from .exc import unimplemented_v2
-
-        unimplemented_v2(
-            gb_type="Failed to make weakref to graph-created external object",
-            context=f"user_object: {example_value}",
-            explanation="Object does not allow us to make a weakref to it",
-            hints=[],
-            from_exc=e,
-        )
-    return index
+    index_to_source.clear()
+    index_to_user_object_weakref.clear()


 # Register a user object to be used in the graph
 def register_user_object(value: Any, source: Source) -> int:
-    global index_to_bytecode_constructor
-    index = len(index_to_bytecode_constructor)
-    index_to_bytecode_constructor[index] = lambda cg: cg(source)
+    global index_to_source
+    index = len(index_to_source)
+    index_to_source[index] = source
     try:
-        index_to_external_object_weakref[index] = weakref.ref(value)
+        index_to_user_object_weakref[index] = weakref.ref(value)
     except TypeError as e:
         from .exc import unimplemented_v2

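The hunk above reverts graph_bytecode_inputs back to its source-plus-weakref sidetable. As orientation, here is a minimal runnable sketch of that index-indirection pattern, with simplified names and a stand-in Widget class rather than Dynamo internals: objects are registered under an integer index at trace time, only weak references are stored, and a lookup by index returns the live object.

import weakref
from typing import Any

index_to_user_object_weakref: dict[int, weakref.ReferenceType] = {}

def register_user_object(value: Any) -> int:
    # Hand out the next index; hold only a weak reference so the sidetable
    # never extends the object's lifetime.
    index = len(index_to_user_object_weakref)
    index_to_user_object_weakref[index] = weakref.ref(value)
    return index

def get_user_object_by_index(index: int) -> Any:
    # Resolve the index back to the live object (what the graph calls).
    obj = index_to_user_object_weakref[index]()
    assert obj is not None, "User object is no longer alive"
    return obj

class Widget:  # stand-in for a stream/event-like user object
    pass

w = Widget()
idx = register_user_object(w)
assert get_user_object_by_index(idx) is w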
@@ -101,7 +101,7 @@ from .exc import (
     unimplemented_v2,
     unimplemented_v2_with_warning,
 )
-from .graph_bytecode_inputs import has_user_objects, index_to_bytecode_constructor
+from .graph_bytecode_inputs import has_user_objects, index_to_source
 from .graph_deduplication import apply_graph_deduplication
 from .graph_region_tracker import GraphRegionTracker
 from .guards import GuardBuilder, install_guard
@@ -1539,19 +1539,9 @@ class OutputGraph(OutputGraphCommon):
                     "store_user_object_weakrefs",
                 )
             )
-            tmp_vars = []
-            for constructor in reversed(index_to_bytecode_constructor.values()):
-                constructor(codegen)
-                var_name = (
-                    self.new_var()
-                )  # keep alive any temp objects for the rest of the frame
-                codegen.store(var_name)
-                tmp_vars.append(var_name)
-
-            for var_name in tmp_vars:
-                codegen.append_output(codegen.create_load(var_name))
-
-            codegen.call_function(len(index_to_bytecode_constructor), False)
+            for source in reversed(index_to_source.values()):
+                codegen(source)
+            codegen.call_function(len(index_to_source), False)
             codegen.pop_top()
             self.add_output_instructions(codegen.get_instructions())

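The restored codegen above emits, before the compiled graph runs, a single call to store_user_object_weakrefs built from each registered Source, then pops the result. Below is a small runnable sketch of that runtime contract, using a stand-in FakeStream class and simplified module-level state (assumptions, not the Dynamo implementation); it also shows why weak references matter: once a user object dies, the lookup fails loudly instead of keeping the object alive.

import gc
import weakref
from typing import Any

index_to_user_object_weakref: dict[int, weakref.ReferenceType] = {}

def store_user_object_weakrefs(*args: Any) -> None:
    # Called by the generated pre-graph bytecode with the user objects in
    # index order; the table is rebuilt for every frame.
    index_to_user_object_weakref.clear()
    index_to_user_object_weakref.update({i: weakref.ref(a) for i, a in enumerate(args)})

def get_user_object_by_index(index: int) -> Any:
    # Called from inside the compiled graph at run time.
    obj = index_to_user_object_weakref[index]()
    assert obj is not None, "User object is no longer alive"
    return obj

class FakeStream:  # stand-in for a stream-like user object
    pass

streams = [FakeStream(), FakeStream()]
store_user_object_weakrefs(*streams)              # "pre-graph bytecode"
assert get_user_object_by_index(1) is streams[1]  # "inside the graph"

del streams                                       # once the objects die...
gc.collect()
try:
    get_user_object_by_index(0)                   # ...lookups fail loudly
except AssertionError:
    pass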
@@ -46,7 +46,7 @@ import torch
 from torch import SymInt
 from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.graph_bytecode_inputs import (
-    get_external_object_by_index,
+    get_user_object_by_index,
     register_user_object,
 )
 from torch._dynamo.utils import (
@@ -1057,7 +1057,7 @@ class VariableBuilder:
         self.install_guards(GuardBuilder.TYPE_MATCH)
         index = register_user_object(value, self.source)
         stream_proxy = self.tx.output.create_proxy(
-            "call_function", get_external_object_by_index, (index,), {}
+            "call_function", get_user_object_by_index, (index,), {}
         )
         set_example_value(stream_proxy.node, value)
         var = StreamVariable(
@@ -1078,7 +1078,7 @@ class VariableBuilder:
         index = register_user_object(value, self.source)
         event_proxy = self.tx.output.create_proxy(
             "call_function",
-            get_external_object_by_index,
+            get_user_object_by_index,
             (index,),
             {},
         )
@@ -3006,8 +3006,8 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
         set_example_value(proxy.node, example_value)
         return SymNodeVariable(proxy, example_value, **options)
     elif (
-        isinstance(example_value, torch.Stream)
-        and proxy.node.target == get_external_object_by_index
+        inspect.isclass(proxy.node.target)
+        and issubclass(proxy.node.target, torch.Stream)
     ) or proxy.node.target in [
         device_interface.current_stream
         for _, device_interface in get_registered_device_interfaces()
@@ -4,10 +4,8 @@ import torch
 from torch.fx import Proxy

 from .. import graph_break_hints
-from ..bytecode_transformation import create_call_function
 from ..device_interface import get_interface_for_device
 from ..exc import TYPE_CHECKING, unimplemented_v2
-from ..source import AttrSource, CallFunctionNoArgsSource, TorchSource
 from .base import VariableTracker
 from .constant import ConstantVariable
 from .ctx_manager import ContextWrappingVariable
@@ -174,9 +172,6 @@ class StreamVariable(StreamContextVariable):
         device: torch.device,
         **kwargs: Any,
     ) -> None:
-        # Index into the user object table
-        # used to pass arbitrary objects to the graph
-        user_object_index = kwargs.pop("user_obj_index", None)
         if proxy is not None and "example_value" in proxy.node.meta:
             assert proxy.node.meta["example_value"] == value
             assert value.device.type == device.type, (
@@ -188,8 +183,6 @@ class StreamVariable(StreamContextVariable):
         # pyrefly: ignore [read-only]
         self.device = device

-        self.user_object_index = user_object_index
-
     def python_type(self) -> type:
         return torch.Stream

@@ -268,27 +261,15 @@ class StreamVariable(StreamContextVariable):
         # If we got here, this stream is fully subsumed by the graph - this means it is
         # not an input or global
         assert not self.source
-        if self.user_object_index is not None:
-            codegen.add_push_null(
-                lambda: codegen.load_import_from(
-                    torch._dynamo.graph_bytecode_inputs.__name__,
-                    "get_external_object_by_index",
-                )
-            )
-            codegen.append_output(codegen.create_load_const(self.user_object_index))
-            codegen.extend_output(create_call_function(1, False))
-        else:
-            # TODO mlazos: evaluate if we still need this
-            prefix = f"_stream_{self.device}"
-            name = codegen.tx.output.install_global_by_id(prefix, self.value)
-            codegen.append_output(codegen.create_load_global(name, add=True))
-
-    @staticmethod
-    def construct_in_graph_stream(index: int, codegen: "PyCodegen") -> None:
-        # Use source to create the right bytecode, this
-        # isn't an actual input
-        source = CallFunctionNoArgsSource(AttrSource(TorchSource(), "Stream"))
-        codegen(source)
+        # Since we just proved that - for other such structures, like lists and dicts, reconstruction
+        # is fine and sound according to dynamo principles of treating collectives. However,
+        # streams are special in that we want to preserve the identity of the stream as the same as in the graph
+        # Normally, we would do this via codegen for the proxy mapping to an output - we cannot do this yet, as we do not
+        # yet have a plan for how we want to handle the case where the stream is used as an input or an output. Pending
+        # design, to unblock current work, we lift the stream into a global and then codegen bytecode to load it from there.
+        prefix = f"_stream_{self.device}"
+        name = codegen.tx.output.install_global_by_id(prefix, self.value)
+        codegen.append_output(codegen.create_load_global(name, add=True))

     def _get_target_values(self) -> list["StreamVariable"]:
         return [self]
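The reconstruct path restored above lifts a graph-created stream into a module-level global (via install_global_by_id) and emits bytecode that loads it by name, so the stream the user sees after the graph is the identical object used inside it. Below is a toy sketch of that "install as global" idea; the registry dict, the FakeStream class, and the simplified install_global_by_id are assumptions standing in for Dynamo's output-graph globals, not the actual helper.

_output_graph_globals: dict[str, object] = {}  # stand-in for the frame's globals

def install_global_by_id(prefix: str, value: object) -> str:
    # Derive a name from the object's id and stash the object itself, so later
    # bytecode can load it by name and get back the same instance.
    name = f"{prefix}_{id(value)}"
    _output_graph_globals[name] = value
    return name

class FakeStream:  # stand-in for torch.cuda.Stream
    pass

stream = FakeStream()
name = install_global_by_id("_stream_cuda", stream)
# "Reconstruction" is just a load by name; identity is preserved.
assert _output_graph_globals[name] is stream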
@@ -58,7 +58,6 @@ from ..exc import (
     raise_observed_exception,
     unimplemented_v2,
 )
-from ..graph_bytecode_inputs import get_external_object_by_index
 from ..guards import GuardBuilder, install_guard
 from ..source import (
     AttrSource,
|
@ -810,31 +809,14 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
)
|
||||
args = [stacked]
|
||||
|
||||
if issubclass(self.value, torch.Stream):
|
||||
# Register newly created stream for reconstruction
|
||||
stream = self.value()
|
||||
from ..graph_bytecode_inputs import register_graph_created_object
|
||||
from .streams import StreamVariable
|
||||
|
||||
ind = register_graph_created_object(
|
||||
stream, StreamVariable.construct_in_graph_stream
|
||||
)
|
||||
tensor_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function", get_external_object_by_index, (ind,), {}
|
||||
),
|
||||
user_obj_index=ind,
|
||||
)
|
||||
else:
|
||||
tensor_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.value,
|
||||
*proxy_args_kwargs(args, kwargs),
|
||||
),
|
||||
)
|
||||
tensor_variable = wrap_fx_proxy(
|
||||
tx=tx,
|
||||
proxy=tx.output.create_proxy(
|
||||
"call_function",
|
||||
self.value,
|
||||
*proxy_args_kwargs(args, kwargs),
|
||||
),
|
||||
)
|
||||
|
||||
return tensor_variable
|
||||
elif self.value is random.Random:
|
||||
|
|
|
|||