[hoo] Add with_effects to handle side effectful ops (#120296)
Proposal: https://docs.google.com/document/d/179QyhicGzTXJ5jvTAoAosP_Nzgf3PpgZwU_E3VV9PlM/edit#heading=h.bnm38nu3yfno
Implementation discussion: https://docs.google.com/document/d/179QyhicGzTXJ5jvTAoAosP_Nzgf3PpgZwU_E3VV9PlM/edit#heading=h.bj61609o1buq

Result with print:
```
graph():
    %arg0_1 : [num_users=1] = placeholder[target=arg0_1]
    %arg1_1 : [num_users=1] = placeholder[target=arg1_1]
    %with_effects : [num_users=1] = call_function[target=torch._higher_order_ops.effects.with_effects](args = (%arg0_1, aten.print.default, moo), kwargs = {})
    %getitem : [num_users=1] = call_function[target=operator.getitem](args = (%with_effects, 0), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg1_1, %arg1_1), kwargs = {})
    return [getitem, add]
```

Follow-ups:
* Add handling to auto_functionalize
* Add support for tokens on the export side
* Add support for tokens on the inductor side

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120296
Approved by: https://github.com/zou3519
commit a7e93c341f (parent 29976519a1)
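As a quick orientation (not part of the commit), the sketch below mirrors `test_print` from the new test file and shows how the functionalized graph and its token signature are obtained; it assumes the `aot_export_module` and `GraphSignature.input_tokens`/`output_tokens` APIs touched by this PR:

```python
import torch
from torch._functorch.aot_autograd import aot_export_module


class M(torch.nn.Module):
    def forward(self, x):
        torch.ops.aten._print("moo")  # side-effectful op
        return (x + x,)


# Functionalization wraps the print in with_effects() and threads an effect
# token through the graph inputs and outputs.
gm, gs = aot_export_module(M(), (torch.randn(3),), trace_joint=False)
print(gm.code)  # contains torch._higher_order_ops.effects.with_effects(...)
print(len(gs.input_tokens), len(gs.output_tokens))  # 1 1
```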
test/higher_order_ops/test_with_effects.py (new file, 208 lines)
```python
# Owner(s): ["module: functorch"]
import unittest

import torch
import torch._dynamo
import torch._inductor
import torch._inductor.decomposition
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_utils import (
    find_library_location,
    IS_FBCODE,
    IS_MACOS,
    IS_SANDCASTLE,
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TestCase,
)


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support")
class TestWithEffects(TestCase):
    def setUp(self):
        if IS_MACOS:
            raise unittest.SkipTest("non-portable load_library call used in test")
        elif IS_SANDCASTLE or IS_FBCODE:
            torch.ops.load_library(
                "//caffe2/test/cpp/jit:test_custom_class_registrations"
            )
        elif IS_WINDOWS:
            lib_file_path = find_library_location("torchbind_test.dll")
            torch.ops.load_library(str(lib_file_path))
        else:
            lib_file_path = find_library_location("libtorchbind_test.so")
            torch.ops.load_library(str(lib_file_path))

    def test_print(self):
        class M(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # Without functionalization, print should just appear in the graph directly
        gm = make_fx(M())(*inputs)
        FileCheck().check_count("torch.ops.aten._print.default", 2, exactly=True).run(
            gm.code
        )

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    @unittest.expectedFailure  # Will enable this once we enable tokens in export
    def test_torchbind_custom_op(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return (x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x),)

        with enable_torchbind_tracing():
            gm, gs = aot_export_module(M(), (torch.ones(2, 3),), trace_joint=False)

        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1):
    _tensor_constant0 = self._tensor_constant0
    takes_foo = torch.ops._TorchScriptTesting.takes_foo.default(_tensor_constant0, arg0_1); _tensor_constant0 = None
    add = torch.ops.aten.add.Tensor(arg0_1, takes_foo); arg0_1 = takes_foo = None
    return (add,)""",  # noqa: B950
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    def test_print_with_buffer_mutations(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("buf", torch.ones(3))

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                self.buf.add_(res)
                res = self.buf + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch._higher_order_ops.effects.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg2_1, arg2_1)
    add_1 = torch.ops.aten.add.Tensor(arg1_1, add); arg1_1 = add = None
    add_2 = torch.ops.aten.add.Tensor(add_1, arg2_1); arg2_1 = None
    with_effects_1 = torch._higher_order_ops.effects.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add_1, add_2)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.buffers_to_mutate), 1)

    def test_print_with_input_mutations(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                x.add_(res)
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.user_inputs_to_mutate), 1)

    def test_alias_op(self):
        def f(token, x):
            token, out = with_effects(token, torch.ops.aten.absolute_.default, x)
            return token, out

        with self.assertRaisesRegex(
            AssertionError, r"Ops with aliasing is not supported"
        ):
            make_fx(f)(torch.tensor([]), torch.tensor(4))

    def test_compile_aot_eager(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @skipIfTorchDynamo(
        "We're testing if the test works with inductor, which it currently"
        "doesn't, so we expectedFailure-d the test, but the Dynamo tests"
        "override the backend, causing an unexpected success"
    )
    @unittest.expectedFailure  # NYI: AssertionError: with_effects is not an OpOverload
    def test_compile_inductor(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="inductor")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    def test_compile_aot_eager_requires_grad(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3, requires_grad=True),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))


if __name__ == "__main__":
    run_tests()
```
```diff
@@ -105,7 +105,8 @@ def run_functionalized_fw_and_collect_metadata(
 
         # It doesn't matter if we run this under predispatch or not because it is
         # only for figuring out metadata
-        with disable_above, FunctionalTensorMode():
+        mode = FunctionalTensorMode(_allow_token_discovery=True)
+        with disable_above, mode:
             # precondition: The passed in function already handles unflattening inputs + flattening outputs
             flat_f_args = pytree.tree_map(_to_fun, flat_args)
             flat_f_outs = f(*flat_f_args)
@@ -618,6 +619,7 @@ from a multi-output view call"
             subclass_tangent_meta=create_subclass_meta(traced_tangents),
             is_train=is_train,
             grad_enabled_mutation=grad_enabled_mutation,
+            tokens=mode._tokens,
         )
         return metadata
 
```
```diff
@@ -374,9 +374,10 @@ def create_graph_signature(
     graph_output_names = _graph_output_names(fx_g)
 
     num_params_buffers = len(param_names) + len(buffer_names)
+    num_tokens = len(fw_metadata.tokens)
     # We have enough restrictions on the graph (no de-duping, synthetic bases, etc),
     # Such that # graph inps = # user inps + # params + # buffers
-    num_user_args = len(graph_input_names) - num_params_buffers
+    num_user_args = len(graph_input_names) - num_params_buffers - num_tokens
 
     if trace_joint:
         assert num_user_fw_outs is not None
@@ -411,7 +412,9 @@ def create_graph_signature(
     else:
         backward_signature = None
         num_user_fw_outs = (
-            len(graph_output_names) - fw_metadata.num_mutated_inp_runtime_indices
+            len(graph_output_names)
+            - fw_metadata.num_mutated_inp_runtime_indices
+            - num_tokens
         )
 
     return GraphSignature.from_tracing_metadata(
```
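A small worked example (hypothetical counts, not taken from the commit) of the input bookkeeping in the hunks above: the effect-token inputs are excluded from the user-argument count just like params and buffers.

```python
# Hypothetical graph inputs: 1 token, 2 params, 1 buffer, 1 user input
graph_input_names = ["token0", "p0", "p1", "b0", "x"]
num_params_buffers = 3  # p0, p1, b0
num_tokens = 1          # token0

num_user_args = len(graph_input_names) - num_params_buffers - num_tokens
assert num_user_args == 1  # just "x"
```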
```diff
@@ -222,6 +222,9 @@ def aot_dispatch_autograd(
                 + inner_meta.num_outputs
                 + inner_meta.num_intermediate_bases
                 + inner_meta.num_outputs_rng_offset
+                + len(
+                    fw_metadata.tokens
+                )  # See Note [Side-Effectful Tokens in AOTAutograd]
             )
             fw_module, bw_module = aot_config.partition_fn(
                 fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
@@ -493,7 +496,7 @@ def aot_dispatch_autograd(
                 args = (*args, seed, offset)
             # There is a pretty complicated calling convention around what the compiled fw returns.
             # The full list of outputs and their relative order is:
-            # (*mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
+            # (*tokens, *mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
             # - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version
             # of the original view, and not the synthetic base
 
@@ -514,6 +517,7 @@ def aot_dispatch_autograd(
             num_mutated_runtime_inps = (
                 CompiledFunction.metadata.num_mutated_inp_runtime_indices
             )
+            num_tokens = len(CompiledFunction.metadata.tokens)
             num_forward_returns = CompiledFunction.metadata.num_forward_returns
             num_forward = CompiledFunction.metadata.num_forward
 
@@ -538,7 +542,7 @@ def aot_dispatch_autograd(
                 ), str([type(x) for x in symint_outs])
                 ctx.symints = symint_outs
 
-            raw_returns = fw_outs[0:num_forward_returns]
+            raw_returns = fw_outs[0 : num_forward_returns + num_tokens]
 
             # Wrap all autograd.Function.forward() outputs that are aliases
             # so that autograd.Function doesn't treat them as tensors
```
```diff
@@ -69,10 +69,15 @@ def create_runtime_wrapper(
     keep_input_mutations: bool,
     disable_amp: bool,
 ):
+    num_tokens = len(runtime_metadata.tokens)
+
     if not hasattr(compiled_fn, "_boxed_call"):
         compiled_fn = make_boxed_func(compiled_fn)
 
     def runtime_wrapper(*args):
+        # Pass in effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
+        args = (*[torch.tensor([])] * num_tokens, *args)
+
         if trace_joint:
             args_ = list(args)
             # See Note [Detaching inputs that never need gradients]
@@ -120,8 +125,12 @@ def create_runtime_wrapper(
                 == num_mutated_runtime_inps
                 + runtime_metadata.num_outputs
                 + num_intermediate_bases
+                + num_tokens
             )
 
+        # Toss out the effect tokens (See Note [Side-Effectful Tokens in AOTAutograd])
+        all_outs = all_outs[num_tokens:]
+
         # Step 3: After running the compiled fw, apply updates to mutated inputs
         num_mutations_to_apply = runtime_metadata.num_mutated_inp_runtime_indices
         if num_mutations_to_apply > 0:
```
```diff
@@ -5,7 +5,7 @@ input/output types, metadata, config, function signatures etc.
 
 import collections
 import functools
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from enum import Enum
 from typing import Any, Callable, Dict, List, NewType, Optional, Set, Tuple, Union
 
@@ -267,6 +267,11 @@ class ViewAndMutationMeta:
     # raised
     deterministic: Optional[bool] = None
 
+    # Map of effect type (ex. _EffectType.ORDERED) to token. If there are
+    # side-effectful operators, FunctionalTensorMode will populate this
+    # dictionary telling us how many tokens we will need during tracing.
+    tokens: Dict[Any, torch.Tensor] = field(default_factory=dict)
+
     def __post_init__(self):
         # pre-compute the indices of the inputs that are mutated.
         # When keep_input_mutations is set, we don't need to worry about our epilogue
@@ -549,6 +554,9 @@ class GraphSignature:
 
     backward_signature: Optional[BackwardSignature]
 
+    input_tokens: List[GraphInputName]
+    output_tokens: List[GraphOutputName]
+
    @classmethod
    def from_tracing_metadata(
        cls,
@@ -569,35 +577,54 @@ class GraphSignature:
         graph_outputs = graph_output_names
         parameters = list(named_parameters)
         buffers = list(named_buffers)
+        num_tokens = len(view_mutation_metadata.tokens)
 
         # Calling convention assumptions:
-        # (1) graph inputs = (params, buffers, user_inputs)
-        # (2) graph outputs = (mutated_inputs, user_outs, param_gradients)
+        # (1) graph inputs = (input_tokens, params, buffers, user_inputs)
+        # (2) graph outputs = (output_tokens, mutated_inputs, user_outs, param_gradients)
         # (If we are capturing an inference graph, this convention is identical
         # except that param_gradients is empty)
-        user_inputs = graph_inputs[len(parameters) + len(buffers) :]
-        assert num_user_inputs == len(user_inputs)
-        assert len(graph_inputs) == (len(parameters) + len(buffers) + len(user_inputs))
+        # See Note [Side-Effectful Tokens in AOTAutograd] for information on tokens
 
-        inputs_to_parameters = dict(zip(graph_inputs[: len(parameters)], parameters))
+        # Address input calling conventions:
+        start, stop = 0, num_tokens
+        input_tokens = graph_inputs[start:stop]
+
+        start, stop = stop, stop + len(parameters)
+        inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters))
+
+        start, stop = stop, stop + len(buffers)
         inputs_to_buffers = dict(
             zip(
-                graph_inputs[len(parameters) : len(parameters) + len(buffers)],
+                graph_inputs[start:stop],
                 buffers,
             )
         )
 
-        names = [*parameters, *buffers, *user_inputs]
+        start, stop = stop, stop + num_user_inputs
+        user_inputs = graph_inputs[start:stop]
+
+        # We should've gone through all the inputs now
+        assert len(graph_inputs) - stop == 0
+
+        # Address output calling conventions:
+        start, stop = 0, num_tokens
+        output_tokens = graph_outputs[start:stop]
+
+        names = [*input_tokens, *parameters, *buffers, *user_inputs]
         mutations = []
         for idx, input_info in enumerate(view_mutation_metadata.input_info):
             if input_info.mutates_data:
                 # Only buffers can be mutated, not parameters
                 assert idx >= len(parameters)
-                mutations.append(names[idx])
+                mutations.append(names[idx + num_tokens])
 
         assert len(mutations) == view_mutation_metadata.num_mutated_inp_runtime_indices
 
-        start, stop = 0, view_mutation_metadata.num_mutated_inp_runtime_indices
+        start, stop = (
+            stop,
+            stop + view_mutation_metadata.num_mutated_inp_runtime_indices,
+        )
         outputs_to_mutations = dict(zip(graph_outputs[start:stop], mutations))
 
         user_inputs_to_mutate = {}
@@ -631,6 +658,8 @@ class GraphSignature:
             in_spec=in_spec,
             out_spec=out_spec,
             backward_signature=backward_signature,
+            input_tokens=input_tokens,  # type: ignore[arg-type]
+            output_tokens=output_tokens,  # type: ignore[arg-type]
         )
 
 
```
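To make the sliding `start`/`stop` windows in `from_tracing_metadata` above concrete, here is a worked example with made-up input names (purely illustrative, not from the commit):

```python
# graph inputs = (input_tokens, params, buffers, user_inputs)
graph_inputs = ["token0", "p0", "p1", "b0", "x"]
parameters, buffers = ["p0", "p1"], ["b0"]
num_tokens, num_user_inputs = 1, 1

start, stop = 0, num_tokens
input_tokens = graph_inputs[start:stop]  # ["token0"]

start, stop = stop, stop + len(parameters)
inputs_to_parameters = dict(zip(graph_inputs[start:stop], parameters))  # p0, p1

start, stop = stop, stop + len(buffers)
inputs_to_buffers = dict(zip(graph_inputs[start:stop], buffers))  # b0

start, stop = stop, stop + num_user_inputs
user_inputs = graph_inputs[start:stop]  # ["x"]

assert len(graph_inputs) - stop == 0  # every graph input is accounted for
```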
```diff
@@ -23,7 +23,6 @@ from torch import Tensor
 from torch._decomp.decompositions_for_rng import PhiloxStateTracker
 from torch._guards import detect_fake_mode
 from torch._prims_common import CUDARngStateHelper
-from torch._subclasses.functional_tensor import FunctionalTensorMode
 from torch.fx.experimental.symbolic_shapes import definitely_false, sym_eq
 from torch.nn.utils import stateless
 
@@ -350,12 +349,43 @@ def create_functionalized_fn(
         disable_above = torch._C._ExcludeDispatchKeyGuard(
             torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize)
         )
-        with disable_above, FunctionalTensorMode(aot_config.pre_dispatch):
+
+        # See Note [Side-Effectful Tokens in AOTAutograd]
+        if trace_joint:
+            assert (
+                isinstance(args, tuple)
+                and len(args) == 2
+                and isinstance(args[0], (list, tuple))
+            )
+            tokens = args[0][: len(meta.tokens)]
+            actual_args = args[0][len(meta.tokens) :]
+            args = (actual_args, args[1])
+        else:
+            tokens = args[: len(meta.tokens)]
+            args = args[len(meta.tokens) :]
+        assert all(token.numel() == 0 for token in tokens)
+
+        with disable_above:
             # Wrap inputs into functional wrappers
             f_args = pytree.tree_map(to_fun, args)
+            f_tokens = pytree.tree_map(to_fun, tokens)
+
+            # Populate the current FunctionalTensorMode with the tokens per
+            # operator. See Note [FunctionalTensorMode is Stateful]
+            functional_tensor_mode = (
+                torch.utils._python_dispatch._detect_functional_mode()
+            )
+            assert functional_tensor_mode is not None
+            for i, k in enumerate(meta.tokens.keys()):
+                functional_tensor_mode._tokens[k] = f_tokens[i]
 
+            # Run the joint
             f_outs = fn(*f_args)
 
+            # Return both the tokens and the outputs
+            # See Note [Side-Effectful Tokens in AOTAutograd]
+            f_outs = (*functional_tensor_mode._tokens.values(), *f_outs)
+
         if trace_joint:
             # We support a limited amount of mutation of graph inputs during the backward pass.
             # (This is used e.g. by Float8, which needs to update buffers during the backward pass)
@@ -470,6 +500,14 @@ def create_functionalized_fn(
         # Setup the wrapper for functionalization of rng ops
         helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)
 
+    # Additionally pass in tokens as inputs
+    # See Note [Side-Effectful Tokens in AOTAutograd]
+    additional_token_inputs = [torch.tensor([])] * len(meta.tokens)
+    if trace_joint:
+        args = ([*additional_token_inputs, *args[0]], *args[1:])
+    else:
+        args = [*additional_token_inputs, *args]
+
     return helper, args
 
 
```
```diff
@@ -375,6 +375,31 @@ AOT_COUNTER = itertools.count()
 # To work around this, we view every forward output when creating out tangent
 # tensors so that tangents can never be the same as forward inputs even if
 # forward inputs alias forward outputs.
+
+# Note [Side-Effectful Tokens in AOTAutograd]
+#
+# We allow some some side-effectful operators in
+# the post-AOTAutograd (functional) graph, such as prints and torchbind operations.
+# To ensure that these side-effects are compatible to future graph passes that
+# assume that the graph is functional, we will thread "effect tokens" to show
+# data dependence between these side-effectful operators. Practically speaking,
+# effect tokens are just dummy values (torch.tensor([])). The graph would look
+# like the following:
+#
+# def gm(self, token0, reader):
+#     token1, frame = with_token(ordered_effect_op, (reader,), token0)
+#     frame = frame * 2
+#     token2, frame2 = with_token(ordered_effect_op, (reader,), token1)
+#     frame2 = frame2 * 2
+#     return token2, frame, frame2
+#
+# We will pass the token as an input to the graph, thread it through
+# side-effectful operators using the `with_effects` high order operator, and then
+# return the updated token as an output.
+# So the signature of the graph input would look something like
+# (*tokens, *params_buffers, *user_inputs), and the signature of the graph
+# output would look something like (*tokens, *outputs).
+
 #
 #
 # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
```
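A minimal sketch, assuming the calling convention described in the Note above, of what the runtime wrapper in this PR does around a token-threaded graph; `graph_fn` and `num_tokens` are placeholders for the compiled forward and the number of discovered effect types:

```python
import torch


def call_token_threaded_graph(graph_fn, num_tokens, *user_args):
    # Pass in dummy effect tokens as the leading graph inputs
    all_args = (*([torch.tensor([])] * num_tokens), *user_args)
    all_outs = graph_fn(*all_args)
    # Toss out the leading token outputs before returning user-visible outputs
    return all_outs[num_tokens:]
```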
torch/_higher_order_ops/effects.py (new file, 206 lines)
```python
from enum import Enum
from typing import Any, Dict, Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch._C import DispatchKey
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)


class _EffectType(Enum):
    ORDERED = "Ordered"


SIDE_EFFECTS: Dict[torch._ops.OpOverload, _EffectType] = {
    torch.ops.aten._print.default: _EffectType.ORDERED,
}


class WithEffects(HigherOrderOperator):
    """
    with_effects(token, op, args, kwargs) -> (new_token, op_results)

    This HOP helps ensure ordering between side effectful ops like prints or ops
    using torchbind objects. This is needed to ensure a traced graph from
    AOTAutograd is functional so that future optimization passes do not reorder
    these operators. This is done through threading "effect tokens" through the
    graph to enforce data dependence between side effectful ops.

    The tokens are basically dummy values (torch.tensor([])). We create a token
    per "effect type", which are enumerated in the _EffectType enum.
    """

    def __init__(self):
        super().__init__("with_effects")

    def __call__(
        self,
        token,
        op: torch._ops.OpOverload,
        *args: Tuple[Any, ...],
        **kwargs: Dict[str, Any],
    ) -> Tuple[Any, ...]:
        assert isinstance(op, torch._ops.OpOverload)
        assert not has_aliasing(op), "Ops with aliasing is not supported"
        assert has_effects(op, args, kwargs)
        assert isinstance(kwargs, dict)
        return super().__call__(token, op, *args, **kwargs)


with_effects = WithEffects()


def has_aliasing(op: torch._ops.OpOverload):
    for arg in op._schema.arguments:
        if arg.alias_info is not None:
            return True
    for arg in op._schema.returns:
        if arg.alias_info is not None:
            return True
    return False


def has_effects(op, args, kwargs) -> bool:
    return (
        isinstance(op, torch._ops.OpOverload)
        and not has_aliasing(op)
        and get_effect_key(op, args, kwargs) is not None
    )


def get_effect_key(op, args, kwargs) -> Optional[_EffectType]:
    if op in SIDE_EFFECTS:
        return SIDE_EFFECTS[op]

    # TODO(angelayi): Enable this when enabling tokens with export -- this will
    # break some existing export tests right now
    # for arg in args:
    #     if isinstance(arg, torch.ScriptObject):
    #         return _EffectType.ORDERED

    return None


@with_effects.py_impl(DispatchKey.CompositeExplicitAutograd)
def with_effects_dense(
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    out = op(*args, **kwargs)
    new_token = torch.tensor([])
    if isinstance(out, tuple):
        return (new_token, *out)
    return (new_token, out)


@with_effects.py_impl(FakeTensorMode)
def with_effects_fake(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    with mode:
        result = with_effects_dense(token, op, *args, **kwargs)
        return result


@with_effects.py_impl(ProxyTorchDispatchMode)
def with_effects_proxy(
    mode,
    token: torch.Tensor,
    op: torch._ops.OpOverload,
    *args: Tuple[Any, ...],
    **kwargs: Dict[str, Any],
) -> Tuple[torch.Tensor, ...]:
    if not mode.enable_tracing:
        return with_effects(token, op, *args, **kwargs)

    with disable_proxy_modes_tracing():
        out = with_effects(token, op, *args, **kwargs)

    proxy_token = mode.tracer.unwrap_proxy(token)
    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)

    out_proxy = mode.tracer.create_proxy(
        "call_function",
        with_effects,
        (proxy_token, op, *proxy_args),
        proxy_kwargs,
    )
    result = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
    return result


with_effects.fallthrough(DispatchKey.AutogradCPU)
with_effects.fallthrough(DispatchKey.AutogradCUDA)


def handle_effects(
    allow_token_discovery: bool,
    tokens: Dict[_EffectType, torch.Tensor],
    op: torch._ops.OpOverload,
    args: Tuple[Any, ...],
    kwargs: Dict[str, Any],
) -> Any:
    """
    Args:
        allow_token_discovery: Whether or not we are discovering tokens. If this
        is true, we will create a token for every side effect type seen that
        does not have a token assigned yet. If this is false, the tokens
        should've all been created ahead of time, so we will error if there is
        no token mapping to every effect type.

        tokens: Map of effect type to tokens. This is to chain operators of the
        same effects together so that they do not get reordered in later
        optimization passes.
    """

    # Get a token. We can't do `tokens.get(op, torch.tensor([]))` because
    # this will create an empty tensor during proxy mode tracing if the token
    # doesn't exist. But the tokens should always exist during proxy mode tracing.
    key = get_effect_key(op, args, kwargs)
    assert key is not None
    if key not in tokens:
        assert allow_token_discovery, f"Could not find a token for effect {key}"
        tokens[key] = torch.tensor([])
    token = tokens[key]

    from torch._subclasses.functional_tensor import PythonFunctionalizeAPI

    ctx = PythonFunctionalizeAPI()

    unwrapped_token = ctx.unwrap_tensors([token])[0]  # type: ignore[arg-type]
    unwrapped_args = ctx.unwrap_tensors(args)  # type: ignore[arg-type]
    unwrapped_kwargs = ctx.unwrap_tensors(kwargs)  # type: ignore[arg-type]
    with ctx.redispatch_to_next():
        (new_token, *unwrapped_outs) = with_effects(
            unwrapped_token, op, *unwrapped_args, **unwrapped_kwargs  # type: ignore[arg-type]
        )

    if len(op._schema.returns) == 0:
        assert unwrapped_outs[0] is None
        unwrapped_outs = None  # type: ignore[assignment]
    elif len(op._schema.returns) == 1:
        assert len(unwrapped_outs) == 1
        unwrapped_outs = unwrapped_outs[0]
    else:
        assert len(unwrapped_outs) == len(op._schema.returns)

    # Add the newly created token into the tokens map for a following call to
    # use this token.
    wrapped_token = ctx.wrap_tensors(new_token)
    assert isinstance(wrapped_token, torch.Tensor)
    tokens[key] = wrapped_token

    return ctx.wrap_tensors(unwrapped_outs)  # type: ignore[arg-type]
```
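A hedged sketch of the token-chaining contract implemented above, modeled on `test_alias_op` from the new test file: every `with_effects` call consumes a token and returns a fresh one, so consecutive prints become data-dependent and cannot be reordered by later passes.

```python
import torch
from torch._higher_order_ops.effects import with_effects
from torch.fx.experimental.proxy_tensor import make_fx


def f(token):
    token, _ = with_effects(token, torch.ops.aten._print.default, "first")
    token, _ = with_effects(token, torch.ops.aten._print.default, "second")
    return token


gm = make_fx(f)(torch.tensor([]))
print(gm.code)  # two with_effects calls chained through the token
```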
```diff
@@ -1,6 +1,6 @@
 import contextlib
 from abc import ABC, abstractmethod
-from typing import Any, Callable, ContextManager, Optional, Tuple
+from typing import Any, Callable, ContextManager, Dict, Optional, Tuple
 
 import torch
 import torch.utils._pytree as pytree
@@ -215,7 +215,7 @@ class FunctionalTensor(torch.Tensor):
 
 
 class FunctionalTensorMode(TorchDispatchMode):
-    def __init__(self, pre_dispatch=False, export=False):
+    def __init__(self, pre_dispatch=False, export=False, _allow_token_discovery=False):
         self.export = export
         self.is_on_stack = False
         self.enter_stack = []
@@ -225,6 +225,18 @@ class FunctionalTensorMode(TorchDispatchMode):
         self.pre_dispatch = pre_dispatch
         # This will be turned off later for pre-dispatch functionalization
         self._dispatch_key = torch._C.DispatchKey.PreDispatch if pre_dispatch else None  # type: ignore[attr-defined]
+        # Map of effect type (ex. _EffectType.ORDERED) to a token. The tokens help keep
+        # track of the ordering between side effectful operations.
+        self._tokens: Dict[Any, torch.Tensor] = {}
+
+        # Functionalization runs twice in AOTAutograd, once in
+        # `run_functionalized_fw_and_collect_metadata` to collect metadata to
+        # see which tensors need to be functionalized and discover how many
+        # tokens we need, and another time in `make_fx` which does the actual
+        # tracing to replace ops with their functional variants and handling
+        # side-effectful ops. In the second stage there should be no token
+        # discovery. This flag distinguishes between the two stages.
+        self._allow_token_discovery = _allow_token_discovery
 
     # No-op if FunctionalTensorMode is already in use
     def __enter__(self):
@@ -338,6 +350,16 @@ class FunctionalTensorMode(TorchDispatchMode):
             )
             return do_auto_functionalize(func, args, kwargs)
 
+        from torch._higher_order_ops.effects import handle_effects, has_effects
+
+        if has_effects(func, args, kwargs):
+            assert not torch._C._dispatch_has_kernel_for_dispatch_key(
+                func.name(), torch._C.DispatchKey.Functionalize
+            )
+            return handle_effects(
+                self._allow_token_discovery, self._tokens, func, args, kwargs
+            )
+
         args_unwrapped, kwargs_unwrapped = pytree.tree_map_only(
             FunctionalTensor, unwrap, (args, kwargs)
         )
```
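Finally, a small illustration (construction only, the surrounding control flow is elided) of the two functionalization passes described in the comment above: the metadata-collection pass may discover tokens, while the later `make_fx` pass must find them already populated.

```python
from torch._subclasses.functional_tensor import FunctionalTensorMode

# Pass 1: run_functionalized_fw_and_collect_metadata is allowed to create a
# token for each new effect type it encounters.
metadata_mode = FunctionalTensorMode(_allow_token_discovery=True)

# Pass 2: the actual trace; _allow_token_discovery defaults to False, so
# handle_effects asserts if an effect type shows up without a token.
tracing_mode = FunctionalTensorMode()
```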