pytorch/torch/_functorch/aot_autograd.py
import collections
import dataclasses
import itertools
import logging
import warnings
import pprint
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from enum import Enum
from functools import partial, wraps
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, NewType
from unittest.mock import patch
from functorch import make_fx
import torch
import torch.fx.traceback as fx_traceback
import torch.nn as nn
import torch.utils._pytree as pytree
import torch.utils.dlpack
from torch import Tensor
from torch._subclasses.meta_utils import safe_is_leaf
from torch._dispatch.python import enable_python_dispatcher
from torch._dynamo import compiled_autograd
from torch._dynamo.utils import dynamo_timed, lazy_format_graph_code, preserve_rng_state
from torch._guards import detect_fake_mode, tracing
from torch._prims_common import CUDARngStateHelper
from torch._logging import getArtifactLogger
from torch._subclasses import FakeTensor, FakeTensorMode
from torch._subclasses.fake_tensor import is_fake
from torch._subclasses.functional_tensor import FunctionalTensor, FunctionalTensorMode
from torch.fx import immutable_collections, Interpreter
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.symbolic_shapes import ShapeEnv, is_concrete_int, fx_placeholder_vals
from torch.multiprocessing.reductions import StorageWeakRef
from torch.nn.utils import stateless
from torch.utils._python_dispatch import is_traceable_wrapper_subclass, transform_subclass
from torch._decomp.decompositions_for_rng import PhiloxStateTracker, rng_decompositions
from . import config
from .partitioners import default_partition
from torch._guards import TracingContext, DuplicateInputs, Source
original_zip = zip
def strict_zip(*iterables, strict=True, **kwargs):
if not strict:
return original_zip(*iterables, **kwargs)
shortest_length = min(len(it) for it in iterables)
for iterable in iterables:
if len(iterable) != shortest_length:
raise ValueError("The iterables have different lengths and strict mode is enabled.")
return original_zip(*iterables, **kwargs)
zip = strict_zip
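# For illustration, a sketch of the behavior above:
#
#   zip([1, 2, 3], ["a", "b"])                 # raises ValueError: lengths differ
#   zip([1, 2, 3], ["a", "b"], strict=False)   # falls back to the builtin zip (truncates)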
log = logging.getLogger(__name__)
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
MutationType = Enum(
"MutationType", ("none", "metadata_only", "data", "data_and_metadata")
)
OutputType = Enum(
"OutputType", (
# output is not an alias
"non_alias",
# output aliases an input
"alias_of_input",
# output **is** an input tensor
"is_input",
# output has a ._base tensor, which is a graph intermediate.
# We need to return its ._base as a graph output,
# so its requires_grad info is populated correctly.
# Instructs the runtime code to regenerate the current output
# from a base tensor, graph_intermediates[base_idx]
"alias_of_intermediate_save_as_output",
# Same as above; but we don't need to explicitly add its ._base
# as a graph output, because it already **is** a graph output.
"alias_of_intermediate",
# Same as above; but the output's ._base is **already** a user output.
# Instructs the runtime code to regenerate the current output from
# a base tensor, user_outputs[base_idx]
"alias_of_intermediate_base_is_user_output",
# See Note [Intermediate Bases Optimization]
"unsafe_view_alias",
# output is an alias, but has a custom autograd.Function backward.
# In this case, we don't want to do view-replay, since we won't be able to replay the custom function.
# Instead, we'll treat this output "normally", and trace its backward into the graph.
"custom_function_view",
)
)
pytree._register_pytree_node(
immutable_collections.immutable_list,
lambda x: (list(x), None),
lambda x, c: immutable_collections.immutable_list(x),
)
pytree._register_pytree_node(
immutable_collections.immutable_dict,
lambda x: (list(x.values()), list(x.keys())),
lambda x, c: immutable_collections.immutable_dict(
dict(zip(c, x))
),
)
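# For illustration, a sketch of what the registration above gives us: flattening an
# immutable_dict produces its values as leaves and stashes the keys in the spec, and
# unflattening rebuilds the immutable_dict from (keys, values):
#
#   d = immutable_collections.immutable_dict({"a": 1, "b": 2})
#   leaves, spec = pytree.tree_flatten(d)   # leaves == [1, 2]
#   pytree.tree_unflatten(leaves, spec)     # == immutable_dict({"a": 1, "b": 2})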
def partial_asdict(obj: Any) -> Any:
if dataclasses.is_dataclass(obj):
return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}
elif isinstance(obj, (list, tuple)):
return obj.__class__([partial_asdict(item) for item in obj])
elif isinstance(obj, dict):
return {k: partial_asdict(v) for k, v in obj.items()}
else:
return obj
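# For illustration, a sketch using a hypothetical dataclass (not part of this file):
#
#   @dataclasses.dataclass
#   class Point:
#       x: int
#       y: int
#
#   partial_asdict(Point(1, 2))                        # {'x': 1, 'y': 2}
#   partial_asdict([Point(1, 2), {"k": Point(3, 4)}])
#   # -> [{'x': 1, 'y': 2}, {'k': {'x': 3, 'y': 4}}]
#
# Unlike dataclasses.asdict, this does not recurse into dataclass field values,
# hence the "partial" in the name.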
aten = torch.ops.aten
# This global counter increments every time we compile a graph with
# AOTAutograd. You can use this to correlate runtime error messages
# with compile time (e.g., if you get an error at runtime saying
# compiled graph 3 failed, you can set a breakpoint at compile time
# for this graph number to investigate further at compile time.)
#
# NB: this is different from get_aot_compilation_context, which tracks
# each underlying graph that is compiled. In contrast, AOT_COUNTER
# corresponds to top-level invocations of aot_module/aot_function;
# one counter is allocated per entire compiled block (but this block
# may involve compiling multiple subgraphs; e.g., for forwards/backwards)
AOT_COUNTER = itertools.count()
KNOWN_TYPES = tuple(
[torch.Tensor, int, str, float, bool, type(None)] + list(py_sym_types)
)
# Set up hooks so that during backward the fx's stack_trace is properly set
callback_set = False
def setup_stacktrace_preservation_hooks(roots: List):
def iter_graph(roots):
if not roots:
return
seen = set()
q = collections.deque()
for node in roots:
if node is not None:
seen.add(node)
q.append(node)
while q:
node = q.popleft()
for fn, _idx in node.next_functions:
if fn in seen or fn is None:
continue
seen.add(fn)
q.append(fn)
yield node
def get_callback(saved_stack_):
def callback():
global callback_set
fx_traceback.set_stack_trace(saved_stack_)
callback_set = False
return callback
def get_prehook(stack_, seq_nr):
def prehook(grad_output):
global callback_set
if not callback_set:
torch.autograd.variable.Variable._execution_engine.queue_callback(
get_callback(fx_traceback.format_stack())
)
callback_set = True
fx_traceback.set_stack_trace(stack_)
fx_traceback.set_grad_fn_seq_nr(seq_nr)
return prehook
def get_posthook(special_stack_, seq_nr):
def posthook(grad_input, grad_output):
fx_traceback.set_stack_trace(special_stack_)
fx_traceback.reset_grad_fn_seq_nr()
return posthook
for node in iter_graph(roots):
forward_node_stack = node.metadata.get("traceback_", [])
node.register_prehook(get_prehook(forward_node_stack,
node._sequence_nr()))
special_stack = forward_node_stack.copy()
special_stack.append(
"Gradient addition node due to multiple use of tensor around:"
)
node.register_hook(get_posthook(special_stack, node._sequence_nr()))
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# AOT Autograd contains a pretty non-trivial amount of logic to handle edge cases around aliasing and mutation
# that are external to the graph (they show up as side effects in some way when you run the graph).
#
# Take a look at `test_aotdispatch.py TestAOTAutograd.test_input_mutation*` tests for some example functions
# and what their compiled graphs look like.
# Below is a very long comment detailing several edge cases, and showing how AOT Autograd handles them.
#
# Note [AOT Autograd: input data mutations]
#
# If we compile a function that mutates inputs, then those input mutations are real side effects
# that a user expects to see after running the compiled graph.
# However, the graph that we want to send to a backend needs to be *entirely* functional.
# The way we reconcile this difference is that we remove the mutations completely from the graph that we compile
# but we update the graph to return (updated_inputs, user_outputs).
# In the epilogue that runs after the compiled graph is executed, we copy the updated inputs back to the originals.
#
# Example: original user code:
# def f(x):
# x.mul_(2)
# out = x.mul(3)
# return out
#
# After AOT Autograd compiles, we end up with a:
# (a) compiled graph
# (b) autograd.Function.forward() method, that executes the compiled graph
# (c) wrapper function, that calls the autograd.Function.forward() and performs the epilogue
#
# The output of (a, b, c) are all written below.
#
# def compiled_forward_graph(x):
# x_updated = x.mul(2)
# out = x_updated.mul(3)
# return x_updated, out
#
# # x_updated gets a gradient in the compiled backward
# def compiled_backward_graph(grad_x_updated, grad_out):
# grad_x = ...
# return grad_x
#
# def autograd.Function.forward(x):
# x_updated, out = compiled_forward_graph(x)
# return x_updated, out
#
# def compiled_wrapper(x):
# x_updated, out = autograd.Function.apply(x)
# x.copy_(x_updated)
# return out
#
# Another important thing to note is that updated inputs (due to data mutations) *do* participate
# in the compiled backward graph! Since the compiled forward graph gets N extra outputs
# (due to updated inputs showing up as graph outputs),
# the compiled backward gets an additional N inputs.
# That way, during the x.copy_(x_updated) bit in the epilogue, gradients will flow from the updated input
# back to the original input.
# Note [AOT Autograd: input metadata mutations]
#
# For the same reason as input mutations, we also don't put input metadata mutations in the graph.
# Instead, we return the updated version of the input (a view), and mutate the input's metadata outside of the graph
#
# Example: original user code:
# def f(x):
# x.t_()
# out = x.mul(3)
# return out
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
# x_updated = x.t()
# out = x_updated.mul(3)
# return x_updated, out
#
# # x_updated does *not* get a gradient in the compiled backward
# def compiled_backward_graph(grad_out):
# grad_x = ...
# return grad_x
#
# def autograd.Function.forward(x):
# x_updated, out = compiled_forward_graph(x)
# return x_updated, out
#
# def compiled_wrapper(x):
# x_updated, out = autograd.Function.apply(x)
# x.as_strided_(x_updated)
# return out
# Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# AOT Autograd needs special handling for outputs that alias graph inputs or intermediates!
# Why?
# (1) autograd.Function.forward() has a limitation, where views that are returned in the forward cannot later be mutated.
# (2) views don't need to be compiled in the graph anyway - it's cheap to generate them outside of the compiled graph,
# in an epilogue.
# For outputs that alias inputs, we do the following:
# (a) *still* return the aliased output as a graph output
# (b) In the AOT Autograd wrapper/epilogue, we don't return that aliased output. Instead, we use it to regenerate the output.
#
# For outputs that alias *intermediates*, we do the following:
# (a) Return the output in the compiled forward, **and** return its ._base (a graph intermediate) as an output in the forward
# (b) Use (output, graph_intermediate) to regenerate the alias, and return that to the user (instead of the compiled fw output).
# You might wonder why we return the aliased output directly in the graph (and make the graph compute it),
# only to not return it and instead generate a fresh alias off of the intermediate,
# instead of (say) just storing metadata about the size/stride of the output somewhere to generate the alias. There are two reasons:
# (1) Getting the actual alias tensor allows us to use view-replay to generate the alias, instead of an as_strided() call
# (2) Inductor (and other backends) are free to change the memory format of graph outputs, if it results in better performance.
# This can result in problems if a user later tries to .view() that output expecting it to have one set of strides,
# when it has a different set of strides.
# By including the view op directly in the graph, inductor takes that into account when deciding what memory format
# the graph intermediate should be.
#
# Another important thing to note is how our traced backward() graph handles aliases.
# (this applies to outputs aliasing inputs, outputs aliasing intermediates,
# *and* updated inputs returned in the compiled forward due to metadata-only mutations).
# Any outputs that alias (either inputs or intermediates) do NOT participate in the compiled backward graph
# It would be wasteful to include them in the compiled backward(), because we regenerate them eagerly
# at the end of the forward.
#
# Example: original user code:
# def f(x):
# out1 = x.t()
# intermediate = x.mul(2)
# out2 = intermediate.view(-1)
# return out1, out2
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(x):
# out1 = x.t()
# intermediate = x.mul(2)
# out2 = intermediate.view(-1)
# # the compiled graph also returns the intermediate
# return out1, out2, intermediate
#
# # intermediate gets a gradient in the compiled backward.
# # both output aliases (out1 and out2) do not.
# def compiled_backward_graph(grad_intermediate):
# grad_x = ...
# return grad_x
#
# def autograd.Function.forward(x):
# out1, out2, intermediate = compiled_forward_graph(x)
# return out1, out2, intermediate
#
# def compiled_wrapper(x):
# out1, out2, intermediate = autograd.Function.apply(x)
# # regenerate out1 from the input
# out1_regenerated = out1._view_func(x)
# # regenerate out1 from the intermediate
# out2_regenerated = out2._view_func(intermediate)
# return out1_regenerated, out2_regenerated
# Note [AOT Autograd: mutations to inputs that alias other inputs]
#
# Another edge case that is (only partially) handled today is when an input is mutated, but itself aliases another input.
# AOT Autograd needs to **ensure** that functionalization knows that the two inputs are aliased to each other.
# That way, when the aliased input is accessed later in the graph, functionalization knows to "update" the alias
# given the mutation that occurred.
#
# This is handled by updating the calling convention: we create a "synthetic base" that becomes a new input
# in the compiled function, and we regenerate the original (aliased) inputs directly off of the base
# inside of the compiled function.
#
# This logic is fully encapsulated in aot_wrapper_synthetic_base()
#
# Example: original user code:
# def f(x, x_view):
# x.mul_(2)
# out = x * x_view
# return out
# f(x, x.view(-1))
#
# AOT Autograd output (compiled graph, autograd.Function.forward(), wrapper function):
# def compiled_forward_graph(base)
# x = generate_x(base)
# x_view = generate_x_view(base)
# x_updated = x.mul(2)
# x_view_updated = x_updated.view(-1)
# out = x_updated * x_view_updated
# return x_updated, out
#
# # The calling convention change from (aliases) -> (base) happens
# # *outside* of the autograd.Function.forward().
# # That means the forward() only has 1 input (base),
# # and the backward() only has 1 output (grad_base)
# def compiled_backward_graph(grad_out):
# grad_base = ...
# return grad_base
#
# def autograd.Function.forward(base):
# x_updated, out = compiled_forward_graph(base)
# return x_updated, out
#
# # The compiled wrapper is where we create synthetic bases.
# # The info on which inputs are mutated is also tracked *before* synthetic base creation.
# def compiled_wrapper(x, x_view):
# base = merge_view_inputs(x, x_view)
# x_updated, out = autograd.Function.apply(base)
# # x and x_view are aliased in eager mode, so this mutation to x will automatically affect x_view.
# x.copy_(x_updated)
# return out
# Note [AOT Autograd: Views to avoid tangents aliasing inputs]
#
# We view every forward output when creating our tangent tensors to handle the problematic
# case in which a subclass does extra aliasing between graph outputs/inputs in a way that
# is not visible above the subclass.
#
# Ordinarily, when constructing the joint function that we want to trace in AOTAutograd,
# we're guaranteed that the tangent tensors that we pass
# into the joint are distinct tensors from the primals. This is because when we
# decide which forward outputs to create tangents for, we only create tangents
# for forward outputs that are not aliases of inputs (See Note
# [AOT Autograd: outputs aliasing inputs or intermediates!]).
#
# However, when wrapper tensor subclasses enter the picture, it is possible
# to have an output of the forward that is a subclass that is not an
# input / alias of an input, but one of its inner tensors is an alias!
# NestedTensor is an example: Performing an out-of-place pointwise op on a
# NestedTensor constructs a fresh NestedTensor that holds onto the input's
# offsets tensor directly.
#
# Having tangent tensors that are the same as the (primal) forward inputs
# can cause problems during tracing as make_fx() will specialize on our
# duplicate inputs: If we passed in the same tensor for primals_1 and
# tangents_1 during tracing, make_fx() will happily sub out all usages of
# tangents_1 with primals_1 in the graph, which is not what we want.
#
# To work around this, we view every forward output when creating our tangent
# tensors so that tangents can never be the same as forward inputs even if
# forward inputs alias forward outputs.
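#
# Concretely, the workaround amounts to something like the sketch below (the real logic
# lives in view_avoid_dupes_with_primals, defined later in this file):
#
#   tangent = fw_out.view(fw_out.shape)   # same data, but a distinct tensor object,
#                                         # so make_fx cannot collapse it with a primal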
#
#
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# This class stores info about every user output.
@dataclass(frozen=True)
class OutputAliasInfo:
# Tells us if this output is:
# (1) a regular (non-aliased) output
# (2) an alias of a forward input
# (3) **is** a forward input (special case of "alias_of_input")
# (4) an alias of an intermediate (aka an alias of an output of the inner traced forward)
# (5) an alias of an intermediate, that explicitly requires returning the intermediate
# as a graph output
# (6) an alias of an intermediate, where that intermediate is also a user output
output_type: OutputType
# The raw type of the output (torch.Tensor, SymInt, etc)
raw_type: type
# If (1) above, then
# - base_idx is None
# If (2) or (3) above, then
# - Tells us that the base of this alias is user_fwd_input[base_idx]
# (This is an index into the inputs *before* we make synthetic bases)
# If (4) or (5) above, then
# - Tells us that the base of this alias is output_graph_intermediates[base_idx]
# here, this refers to the index of the *direct* traced
# If (6) above, then:
# - Tells us that the base of this alias is output_user_fwds[base_idx]
# here, this refers to the index of the *direct* traced
base_idx: Optional[int]
# If it is a Tensor, what the dynamic dims are (otherwise is None)
dynamic_dims: Optional[Set[int]]
# requires_grad
requires_grad: bool
# This class tells us info about user inputs.
@dataclass(frozen=True)
class InputAliasInfo:
is_leaf: bool
mutates_data: bool
mutates_metadata: bool
mutations_hidden_from_autograd: bool
requires_grad: bool
@dataclasses.dataclass
class SubclassCreationMeta:
"""
Used for AOTDispatch.
This dataclass gives us the information we need to reconstruct a tensor subclass
from our flat inputs.
Why is this important? The graph that we'd like to trace out contains flat tensor inputs,
but the user's original model may have subclass inputs and outputs.
So we need to wrap/unwrap subclasses as necessary to translate between the user's
view (subclass inps/outs), and the backend compiler's view (graph with no subclass args).
Complications arise mostly from the fact that a subclass can hold more than one inner tensor;
so for a given subclass input/output, we need to carefully track which indices map
to the subclass tensor in the corresponding "dense-tensor-only" graph.
"""
# In the inner graph that only takes in dense tensor inputs,
# this maps to the first index of "tensors that should go in this subclass wrapper"
flat_tensor_start_idx: int
# The number of tensors that live in this subclass wrapper
arg_count: int
# Stores the original subclass itself.
# This is needed because we need the autograd metadata on the original subclass
# (this is guaranteed to be a wrapper subclass that holds a fake tensor,
# so holding onto this at runtime shouldn't leak memory)
original_subclass: torch.Tensor
# meta and inner_keys are produced by the subclass's __tensor_flatten__.
# We need to keep them around to plumb them into __tensor_unflatten__.
meta: Any
inner_keys: List[any]
def creation_fn(self, all_args, *, is_runtime: bool):
curr_args = all_args[self.flat_tensor_start_idx:self.flat_tensor_start_idx + self.arg_count]
assert len(curr_args) == len(self.inner_keys), f'inner_keys: {str(self.inner_keys)}. len(curr_args): {len(curr_args)}'
out = type(self.original_subclass).__tensor_unflatten__(dict(zip(self.inner_keys, curr_args)), self.meta)
if not is_runtime:
# After wrapping up the inner dense tensors into a subclass, we need to make sure that our new wrapper
# has correct autograd metadata, since we'll be tracing through the autograd engine with the subclass.
# We don't trace through the autograd engine at runtime though, so no need
# to compute this extra metadata then!
torch._mirror_autograd_meta_to(self.original_subclass, out)
return out
def __post_init__(self):
# sanity assert to make sure we don't leak memory
assert is_fake(self.original_subclass)
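# For illustration, a sketch using a hypothetical TwoTensor subclass whose
# __tensor_flatten__ returns (["a", "b"], meta): given
# SubclassCreationMeta(flat_tensor_start_idx=3, arg_count=2, ...), creation_fn picks
# all_args[3:5] and rebuilds the subclass roughly as
#
#   TwoTensor.__tensor_unflatten__({"a": all_args[3], "b": all_args[4]}, meta)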
# This class encapsulates all aliasing + mutation info we need about the forward graph
# See a more detailed overview of the edge case handling at
# https://docs.google.com/document/d/19UoIh_SVrMy_b2Sx5ZaeOJttm6P0Qmyss2rdBuyfoic/edit
@dataclass(eq=False)
class ViewAndMutationMeta:
# length = # user inputs
# This gives us info about every input, and what sort of mutation happened to it (if any)
input_info: List[InputAliasInfo]
# length = # user outputs
# This gives us info about every output (mostly around whether it aliases other tensors)
output_info: List[OutputAliasInfo]
# length = the number of intermediate bases appended as outputs to the end of the forward graph.
# Note: this is not necessarily the same thing as:
# len([x for x in output_info if x.output_type == OutputType.alias_of_intermediate])
# Because outputs might share a ._base, or an output's ._base might itself be
# another user output (in both cases, we won't redundantly append bases to the end of the graph)
num_intermediate_bases: int
# For inference only: instructs us to keep data-only input mutations directly in the graph
keep_input_mutations: int
# length = (# inputs w data mutations) + (# user outputs that are non_aliasing tensors)
# + (# intermediate bases)
# These are the FakeTensor (or potential SymInt) outputs that we traced from our
# metadata pass of the user's forward function.
# Their only use today is to pass them as a best-guess for tangents when tracing the joint.
# Stashing them as part of our "metadata" makes it simpler if we want to run our analysis
# pass once, and re-use the output throughout AOTAutograd
traced_tangents: List[Any]
# Each of these is a list telling us about subclasses for the inputs/outputs/grad_outs
# They are used throughout AOTDispatch to tell us how to generate a list of subclass tensors,
# Given a (potentially larger) list of plain torch tensors.
# Taking subclass_inp_meta as an example:
# subclass_inp_meta[i] = j (an int) tells us:
# "The i'th user input is not a subclass, and corresponds to inputs[j] of the plain-tensor graph."
# subclass_inp_meta[i] = SubclassCreationMeta(flat_tensor_start_idx=3, arg_count=2)
# "The i'th user input is subclass holding two inner tensors, which are
# inputs[3] and inputs[4] of the plain-tensor graph".
# length = # user inputs
subclass_inp_meta: List[Union[int, SubclassCreationMeta]]
# So, the full set of outputs to the forward graph looks something like:
# (*mutated_inps, *user_outs, *intermediate_bases, *saved_for_bw_tensors)
# where the first 3 of those 4 can be subclasses
# (but not saved_for_bw tensors, since these are internal to the compiler
# and not user visible, so there's no point in wrapping/unwrapping them at runtime).
# This list contains subclass information on all of the fw graph outputs
# except for saved_for_bw_tensors.
subclass_fw_graph_out_meta: List[Union[int, SubclassCreationMeta]]
# length = # backward graph inputs
subclass_tangent_meta: List[Union[int, SubclassCreationMeta]]
# TODO: we should kill this
# (need to default it to not break internal)
is_train: bool = False
num_symints_saved_for_bw: Optional[int] = None
# The grad_enabled mutation that will be emitted in the runtime_wrapper epilogue
# NOTE: AOTAutograd will assume that the ambient `is_grad_enabled` is the grad mode
# that is intended to be in effect prior to running the graph, in keeping with
# equivalence to eager mode. It is the responsibility of upstream graph acquisition
# to reset the grad mode to its pre-graph value prior to calling aot_autograd.
grad_enabled_mutation: Optional[bool] = None
def __post_init__(self):
mutated_inp_indices = [
i for i, m in enumerate(self.input_info) if m.mutates_metadata or m.mutates_data
]
# pre-compute the indices of the inputs that are mutated.
# When keep_input_mutations is set, we don't need to worry about our epilogue
# handling data-only mutations, because we keep them directly in the graph.
mutated_inp_runtime_indices = [
i for i, m in enumerate(self.input_info) if m.mutates_metadata or (not self.keep_input_mutations and m.mutates_data)
]
aliased_out_indices = [
i
for i, m in enumerate(self.output_info)
if m.output_type not in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
]
unsafe_view_out_indices = [
i for i, m in enumerate(self.output_info) if m.output_type is OutputType.unsafe_view_alias
]
self.mutated_inp_indices = mutated_inp_indices
# This is pre-computed in post_init for perf.
# It contains the index of every element
# of input_info that corresponds to a mutation (data or metadata or both)
self.mutated_inp_runtime_indices = mutated_inp_runtime_indices
# This is pre-computed for perf.
# It contains the index of every element
# of output_info that corresponds to an alias (either of an input or intermediate)
self.aliased_out_indices = aliased_out_indices
self.unsafe_view_out_indices = unsafe_view_out_indices
self.num_outputs = len(self.output_info)
self.num_outputs_non_aliased = len(
[x for x in self.output_info
if x.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]]
)
self.num_outputs_aliased_to_inputs = len(
[
x
for x in self.output_info
if x.output_type in [
OutputType.alias_of_input,
OutputType.is_input,
]
]
)
self.num_unsafe_view_outputs = len(self.unsafe_view_out_indices)
self.num_outputs_aliased_to_intermediates = len(
[
x
for x in self.output_info
if x.output_type in [
OutputType.alias_of_intermediate,
OutputType.alias_of_intermediate_save_as_output,
OutputType.alias_of_intermediate_base_is_user_output,
]
]
)
self.num_outputs_aliased = (
self.num_outputs_aliased_to_inputs + self.num_outputs_aliased_to_intermediates
)
self.num_mutated_data_inputs = len(
[x for x in self.input_info if x.mutates_data]
)
self.num_mutated_metadata_inputs = len(
[
x
for x in self.input_info
if x.mutates_metadata
]
)
self.num_mutated_metadata_only_inputs = len(
[
x
for x in self.input_info
if not x.mutates_data and x.mutates_metadata
]
)
self.num_mutated_inputs = self.num_mutated_data_inputs + self.num_mutated_metadata_only_inputs
self.dynamic_outputs = any(
o.dynamic_dims for o in self.output_info
)
# See Note: [AOTAutograd Backward Guards]
# This is pre-computed for fast asserts on the types of our grad_outputs in the backward.
# Eventually, we should kill this and replace with real backward guards.
# (we want to precompute the "runtime" types, so replace FakeTensor with torch.Tensor)
self.output_types = [torch.Tensor if isinstance(x, FakeTensor) else type(x) for x in self.traced_tangents]
self.is_rng_op_functionalized = config.functionalize_rng_ops
# All of the above metadata is collected by tracing the fw function.
# However, extra outputs for rng offsets behave differently. Both fwd
# and bwd graphs have their own outputs for the total consumed offsets.
# Unlike mutated inputs, we don't have to worry about sending the right
# set of tensors between fwd and bwd. Fwd and bwd offsets are
# independent and simpler to handle. Therefore, we track them
# separately.
self.num_outputs_rng_offset = 1 if self.is_rng_op_functionalized else 0
# Our forward() returns (mutated_inputs, outputs, output_intermediate_bases, saved_tensors, saved_symints)
self.num_forward_returns = self.num_mutated_inputs + self.num_outputs + self.num_intermediate_bases
# In case of functionalization of rng ops, the fw_module returns one
# additional output for rng offset. This rng offset is used right
# away to advance the rng state, and is not passed on to the raw
# outputs. However, we need to know the exact boundary to identify
# which tensors to be saved for the bwd graph. num_forward captures
# this information.
self.num_forward = self.num_forward_returns + self.num_outputs_rng_offset
@property
def tensors_saved_for_backwards_slice(self):
assert self.num_symints_saved_for_bw is not None
if self.num_symints_saved_for_bw > 0:
return slice(self.num_forward, -self.num_symints_saved_for_bw)
else:
return slice(self.num_forward, None)
@property
def symints_saved_for_backwards_slice(self):
assert self.num_symints_saved_for_bw is not None
if self.num_symints_saved_for_bw > 0:
return slice(-self.num_symints_saved_for_bw, None)
else:
return slice(0, 0) # empty slice
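# For illustration, a sketch of how these slices carve up the forward outputs:
# with num_forward == 4 and num_symints_saved_for_bw == 2,
#   fw_outs[tensors_saved_for_backwards_slice]  is  fw_outs[4:-2]   (saved tensors)
#   fw_outs[symints_saved_for_backwards_slice]  is  fw_outs[-2:]    (saved symints)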
def __eq__(self, other):
if not isinstance(other, ViewAndMutationMeta):
return NotImplemented
return (self.input_info == other.input_info and
self.output_info == other.output_info and
self.num_intermediate_bases == other.num_intermediate_bases and
self.keep_input_mutations == other.keep_input_mutations and
self.is_rng_op_functionalized == other.is_rng_op_functionalized and
self.num_outputs_rng_offset == other.num_outputs_rng_offset and
len(self.traced_tangents) == len(other.traced_tangents) and
all(x.shape == y.shape and x.dtype == y.dtype for x, y, in zip(self.traced_tangents, other.traced_tangents)))
@dataclass(eq=False)
class SubclassMeta:
# A copy of all forward metadata, but computed on the *dense* tensor forward (after desugaring subclasses)
# So for example, if the user had a model containing two `TwoTensor` inputs,
# then `SubclassMeta.fw_metadata.input_info` would have length 4 here.
fw_metadata: ViewAndMutationMeta
# Note: [Computing Subclass Metadata about grad_inputs]
# Given a list of flattened, plain tensor grad_inputs, this tells us how to reconstruct the grad_input subclasses
#
# You might think: why not just assume that all grad_inputs will have the same subclass-ness as the original inputs?
# (AOTAutograd generally assumes other properties, e.g. that grad_outputs are contiguous)
#
# This doesn't really work though. Take this example:
#
# def f(DoubleTensor, DenseTensor):
# return DoubleTensor * DenseTensor
#
# In the above example, the .grad field of *both* DoubleTensor and DenseTensor will be a DoubleTensor.
# When we trace out a joint fw-bw graph, we'll end up returning two subclasses for the two grad_inputs.
# This means that our backward graph will return 4 outputs (two dense tensors for each DoubleTensor grad_input)
# and we need to properly store the metadata that tells us how to turn these 4 outputs back into DoubleTensors.
#
# Note that this info **cannot** easily be figured out from ViewAndMutationMeta.
# We can only compute this info by tracing the entire joint and examining the grad_inputs that we computed.
#
# See Note: [AOTAutograd Backward Guards]
# This will also eventually require us to install backward guards,
# in case we made incorrect assumptions about the subclass-ness of our grad_outputs
#
# Optional field because we don't compute for inference graphs
grad_input_metas: Optional[List[Union[int, SubclassCreationMeta]]]
def __init__(self):
# The fields in this class get set after its construction.
pass
# This class exists because:
# - the autograd.Function.forward() in aot autograd returns outputs that might alias inputs
# - we only care about the metadata on those aliases, so we can regenerate them.
# We do not want them to participate in the autograd.Function.
# We do that by wrapping them in an opaque class, so the autograd.Function
# does not know to treat them as tensors.
@dataclass(frozen=True)
class TensorAlias:
alias: torch.Tensor
def has_same_metadata(t1, t2):
return (
t1.size() == t2.size()
and t1.stride() == t2.stride()
and t1.storage_offset() == t2.storage_offset()
and t1.is_conj() == t2.is_conj()
and t1.is_neg() == t2.is_neg()
)
def gen_alias_from_base(aliased_base_tensor, target_meta_tensor, target_requires_grad):
# Try to do view-replay if possible.
# fall back to .as_strided() if we can't.
if target_meta_tensor._base is not None:
# The base that we want to replay our view off of might have a different shape than the view's original base.
b = target_meta_tensor._base
abt = aliased_base_tensor
# Don't unnecessarily call as_strided if nothing changed; as_strided's
# backward is poorly implemented and slow
if abt is not b and (
abt.size() != b.size() or
abt.stride() != b.stride() or
abt.storage_offset() != b.storage_offset()
):
reshaped_base_tensor = aliased_base_tensor.as_strided(
b.size(), b.stride(), b.storage_offset()
)
else:
reshaped_base_tensor = aliased_base_tensor
out = target_meta_tensor._view_func(reshaped_base_tensor)
# This shape mismatch can happen due to a bug in inplace/view handling in autograd.
# Try putting a breakpoint here and running
# `test/functorch/test_aotdispatch TestAOTAutograd.test_output_all_alias_types`
# Also, https://github.com/pytorch/pytorch/issues/49825
#
# As a stopgap, we'll fall back to as_strided.
if out is not None and out.shape == target_meta_tensor.shape:
if aliased_base_tensor.requires_grad and not target_requires_grad:
out = out.detach()
elif not aliased_base_tensor.requires_grad and target_requires_grad:
out.requires_grad_(True)
return out
size = target_meta_tensor.size()
stride = target_meta_tensor.stride()
storage_offset = target_meta_tensor.storage_offset()
if aliased_base_tensor.is_complex() and not target_meta_tensor.is_complex():
aliased_out = torch.view_as_real(aliased_base_tensor).as_strided(
size, stride, storage_offset
)
elif not aliased_base_tensor.is_complex() and target_meta_tensor.is_complex():
aliased_out = torch.view_as_complex(aliased_base_tensor).as_strided(
size, stride, storage_offset
)
else:
aliased_out = aliased_base_tensor.as_strided(size, stride, storage_offset)
# For outputs aliasing inputs, we need to check if the requires-gradness has changed.
if aliased_base_tensor.requires_grad and not target_requires_grad:
aliased_out = aliased_out.detach()
elif not aliased_base_tensor.requires_grad and target_requires_grad:
aliased_out.requires_grad_(True)
return aliased_out
def to_fun(t):
if isinstance(t, Tensor):
if is_traceable_wrapper_subclass(t):
# See Note [Functionalization always runs last]
# This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
# goes at the bottom.
# recurse here, so we can support nested wrapper subclasses
out = transform_subclass(t, lambda _, inner_t: to_fun(inner_t))
torch._mirror_autograd_meta_to(t, out)
return out
else:
return FunctionalTensor.to_functional(t)
else:
return t
def sync_functional_tensor(t):
if is_traceable_wrapper_subclass(t):
attrs, ctx = t.__tensor_flatten__()
for attr in attrs:
sync_functional_tensor(getattr(t, attr))
else:
torch._sync(t)
# When subclasses are involved, t here will usually look something like:
# SubclassA(SubclassB(FunctionalTensor(_to_functional_tensor(FakeTensor))))
def from_fun(t):
if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
# See Note [Functionalization always runs last]
# This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
# goes at the bottom.
# recurse here, so we can support nested wrapper subclasses
out = transform_subclass(t, lambda _, inner_t: from_fun(inner_t))
torch._mirror_autograd_meta_to(t, out)
return out
if not isinstance(t, FunctionalTensor):
# quick sanity assert
if isinstance(t, torch.Tensor):
assert not torch._is_functional_tensor(t)
return t
sync_functional_tensor(t)
return torch._from_functional_tensor(t.elem)
def is_fun(t):
if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
# See Note [Functionalization always runs last]
# This means that if we want to "functionalize" a subclass, we need to ensure that the functional wrapper
# goes at the bottom.
# recurse here, so we can support nested wrapper subclasses
t_attrs, _ = t.__tensor_flatten__()
t_inners = [getattr(t, attr) for attr in t_attrs]
any_fun = any(is_fun(x) for x in t_inners)
all_fun = all(is_fun(x) for x in t_inners)
assert any_fun == all_fun
return any_fun
return isinstance(t, FunctionalTensor)
# t here is either
# (1) A FunctionalTensor(_to_functional_tensor(FakeTensor))
# (2) A traceable tensor subclass that holds a FunctionalTensor
def has_metadata_mutation(t):
if is_traceable_wrapper_subclass(t):
attrs, _ = t.__tensor_flatten__()
# A tensor subclass was updated if any of its inner elements were updated
return any(has_metadata_mutation(getattr(t, attr)) for attr in attrs)
else:
assert isinstance(t, FunctionalTensor)
return torch._functionalize_has_metadata_mutation(t.elem)
def are_all_mutations_hidden_from_autograd(t):
if is_traceable_wrapper_subclass(t):
attrs, _ = t.__tensor_flatten__()
# If all mutations on the inner elements are hidden from autograd, then the subclass's mutations are hidden from autograd.
return all(are_all_mutations_hidden_from_autograd(getattr(t, attr)) for attr in attrs)
else:
assert isinstance(t, FunctionalTensor)
return torch._functionalize_are_all_mutations_hidden_from_autograd(t.elem)
# new_arg and arg here are either:
# (1) both a FakeTensor
# (2) both a traceable tensor subclass that holds a FakeTensor
# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
# When we run functionalization and wrap our inputs into FunctionalTensors,
# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed
#
# Normally it would be enough just to check whether arg is new_arg, which suffices for functionalization
# to confirm that inputs were not mutated when running the user's model with functionalization on.
# But when we have subclass inputs, we can't rely on that:
# `from_fun(to_fun(x)) is x` will return False, because the call to `from_fun` constructs
# a brand new subclass instance: we are calling __tensor_unflatten__, and going
# from Subclass(FakeTensor) to Subclass(FunctionalTensor(FakeTensor))
def was_updated(arg, new_arg):
if is_traceable_wrapper_subclass(arg):
assert is_traceable_wrapper_subclass(new_arg)
attrs, _ = arg.__tensor_flatten__()
new_attrs, _ = new_arg.__tensor_flatten__()
assert attrs == new_attrs
# A tensor subclass was updated if any of its inner elements were updated
return any(was_updated(getattr(arg, attr), getattr(new_arg, attr)) for attr in attrs)
else:
return arg is not new_arg
# new_arg and arg here are either:
# (1) both a FakeTensor
# (2) both a traceable tensor subclass that holds a FakeTensor
# Pre-condition: the two args are the "old" and "new" inputs from running functionalization.
# When we run functionalization and wrap our inputs into FunctionalTensors,
# we can detect whether or not an input was mutated by checking to see if the inner tensor has changed,
# but shares storage with the old input
def was_metadata_updated(arg, new_arg):
if is_traceable_wrapper_subclass(arg):
assert is_traceable_wrapper_subclass(new_arg)
attrs, _ = arg.__tensor_flatten__()
new_attrs, _ = new_arg.__tensor_flatten__()
assert attrs == new_attrs
# A tensor subclass had its metadata mutated if any of its inner elements did
return any(was_metadata_updated(getattr(arg, attr), getattr(new_arg, attr)) for attr in attrs)
else:
return arg is not new_arg and StorageWeakRef(arg.untyped_storage()) == StorageWeakRef(new_arg.untyped_storage())
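# For illustration, a sketch of the distinction this check draws: if the user's fn ran
# arg.t_() (a metadata-only mutation), functionalization hands back a new_arg that is a
# different tensor object but still shares arg's storage, so this returns True; if the fn
# ran arg.mul_(2) (a data mutation), new_arg ends up backed by fresh storage and this
# returns False.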
def _get_hints(exprs):
"""
Get the hints of a list/tuple of int/SymInt.
"""
if isinstance(exprs, (list, tuple)):
return type(exprs)(_get_hints(e) for e in exprs)
elif isinstance(exprs, torch.SymInt):
return exprs.node.shape_env.size_hint(exprs.node.expr)
else:
return exprs
def requires_subclass_dispatch(args, fw_metadata: ViewAndMutationMeta) -> bool:
args_flattened = pytree.arg_tree_leaves(*args)
any_subclass_args = any(is_traceable_wrapper_subclass(x) for x in args_flattened if isinstance(x, Tensor))
any_subclass_outputs = any(is_traceable_wrapper_subclass(x) for x in fw_metadata.traced_tangents if isinstance(x, Tensor))
# This tells us whether or not we need to perform any unwrapping/wrapping of tensor subclasses at runtime.
return any_subclass_args or any_subclass_outputs
# Given a flat list of arguments, some of which may be tensor subclasses,
# computes metadata about "how to reconstruct the current list of subclasses,
# if we were given their flattened dense tensors instead"
def create_subclass_meta(curr_args: List[Any]) -> List[Union[int, SubclassCreationMeta]]:
idx = 0
infos = []
for a in curr_args:
if isinstance(a, torch.Tensor) and is_traceable_wrapper_subclass(a):
attrs, meta = a.__tensor_flatten__()
start_idx = idx
cnt = len(attrs)
curr_cnt = cnt
infos.append(SubclassCreationMeta(
flat_tensor_start_idx=start_idx,
arg_count=curr_cnt,
original_subclass=a,
meta=meta,
inner_keys=attrs,
))
else:
infos.append(idx)
cnt = 1
idx += cnt
return infos
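# For illustration, a sketch with a hypothetical TwoTensor subclass that flattens into two
# inner tensors: given curr_args = [dense, TwoTensor(a, b), some_non_subclass_tensor],
# this returns
#   [0, SubclassCreationMeta(flat_tensor_start_idx=1, arg_count=2, ...), 3]
# i.e. the dense arg maps to flat position 0, the subclass owns flat positions 1 and 2,
# and the last arg maps to flat position 3.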
def _get_autocast_states():
return [
torch.is_autocast_enabled(),
torch.is_autocast_cpu_enabled(),
torch.get_autocast_gpu_dtype(),
torch.get_autocast_cpu_dtype(),
torch.is_autocast_cache_enabled(),
]
# This is a version of functionalization that is specifically designed
# for the AOTAutograd use case.
#
# Unlike functorch's variant, this doesn't use the functorch level system,
# instead it directly uses PyTorch's conventional dispatcher to hit the
# functionalization key. In particular, this means that FunctionalTensorWrapper
# can have autograd data stored directly on it.
#
# In typical AOTAutograd usage, the dispatch key order will look like:
#
# Autograd - Functionalization ~~~~> Proxy Mode - Fake Tensor
# outer tensor inner tensor
#
# Returns:
# - ViewAndMutationMeta, telling us metadata about the inputs and outputs, and
# the list of outputs from the forward, but **only** the outputs that we need
# to pass in as tangents into the backward.
# Specifically, aliased outputs from the forward get regenerated, and don't participate
# in the compiled backward function.
def run_functionalized_fw_and_collect_metadata(
f,
*,
keep_input_mutations: bool,
# TODO: refactor to kill this flag
is_train: bool = False,
) -> ViewAndMutationMeta:
memo = {}
def _to_fun(t):
if isinstance(t, Tensor):
if t in memo:
return memo[t]
r = to_fun(t)
memo[t] = r
return r
else:
return t
@wraps(f)
def inner(*flat_args):
# This function is meant to be run with the forward, which expects a flat list of tensor/symint/other args.
assert all(isinstance(a, KNOWN_TYPES) for a in flat_args)
input_info: List[InputAliasInfo] = []
output_info: List[OutputAliasInfo] = []
flat_f_args = pytree.tree_map(_to_fun, flat_args)
prior_grad_enabled = torch.is_grad_enabled()
prior_autocast_states = _get_autocast_states()
# See Note [Disabling Functionalize TLS Above Python Functionalization]
disable_above = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize))
with disable_above, FunctionalTensorMode():
# precondition: The passed in function already handles unflattening inputs + flattening outputs
flat_f_outs = f(*flat_f_args)
if prior_autocast_states != _get_autocast_states():
raise RuntimeError(
"AOTAutograd does not support tracing graphs that mutate the autocast state. "
"Dynamo will only insert autocast context managers (e.g. with torch.autocast(..)) into the graph, "
"which will unwind all of their mutations to autocast state before the graph exits. "
"If you encounter this error while using torch.compile, please file a bug."
)
# Inspect the state of the input tensor functional wrapper to detect input mutation info
# If inp[i] has a metadata-only mutation, then maybe_inputs_with_mutated_metadata[i] contains the updated version
for (i, (arg, f_arg)) in enumerate(zip(flat_args, flat_f_args)):
if not isinstance(arg, Tensor):
new_arg = arg
else:
new_arg = from_fun(f_arg)
if was_updated(arg, new_arg):
if was_metadata_updated(arg, new_arg):
mutates_data = False
mutates_metadata = True
else:
mutates_data = True
mutates_metadata = has_metadata_mutation(f_arg)
mutations_hidden_from_autograd = are_all_mutations_hidden_from_autograd(f_arg)
else:
mutates_data = False
mutates_metadata = False
mutations_hidden_from_autograd = False
input_info.append(InputAliasInfo(
is_leaf=isinstance(arg, torch.Tensor) and safe_is_leaf(arg),
mutates_data=mutates_data,
mutates_metadata=mutates_metadata,
mutations_hidden_from_autograd=mutations_hidden_from_autograd,
requires_grad=isinstance(f_arg, torch.Tensor) and f_arg.requires_grad
))
# If a function involves creating a tensor, and returning a view of it, such that its _base is the intermediate,
# we need to make sure our graph returns the _base as a graph output, and we manually recreate the view
# to return to the user. Why? The backend compiler is free to (incorrectly) not set requires_grad
# on the base tensor, but we are obligated to properly set requires-gradness on the real output.
inp_storage_refs = {
StorageWeakRef(inpt.untyped_storage()): idx
for idx, inpt in enumerate(flat_f_args)
if isinstance(inpt, torch.Tensor)
}
# We need inp tensor id's to be able to tell if any outputs **are** inputs.
inp_tensor_ids = {
id(inpt) for inpt in flat_f_args if isinstance(inpt, torch.Tensor)
}
# We need output tensor id's to tell if any output._base attributes **are** other outputs.
# (This is also a dict because we need to know that output's index, so we can regenerate
# the alias from it).
out_tensor_ids = {id(o): i for i, o in enumerate(flat_f_outs)}
# Keep track of which outputs alias other outputs
out_tensor_alias_counts = collections.defaultdict(int)
# This tells us, for a given group of outputs that alias each other,
# whether they e.g. all came from an unbind call
num_aliased_tensors_that_are_multi_output_views = collections.defaultdict(int)
out_storage_to_tensors = collections.defaultdict(set)
for o in flat_f_outs:
if isinstance(o, torch.Tensor):
curr_storage = StorageWeakRef(o.untyped_storage())
out_tensor_alias_counts[curr_storage] += 1
# Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
# This is an optimization on top of the "alias of intermediates" logic,
# which you can read more about under Note [AOT Autograd: outputs aliasing inputs or intermediates!]
#
# Before describing the optimization: this is important for AOTAutograd to have good
# perf around multi-output views. HOWEVER:
# - There is a more generic change to AOTAutograd that we'd like to make, that subsumes this case,
# around using pre-dispatch tracing to partition out a graph so we can faithfully replay all
# views without having to regenerate them at runtime.
# - It's loosely described in this doc (more details will be added soon):
# https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit
# - Once that change lands, we should just rip out this "optimization", since:
# (1) It will be fully unnecessary
# (2) Although it is only a few lines of code, it is a bit difficult to reason about
# its correctness with the autograd engine in all cases.
#
#
# What is this optimization? Consider the below case:
# def f(x):
# intermediate = x.mul(2)
# # x and intermediate here require grad
# o1, o2, ... o10 = intermediate.unbind(-1)
# return intermediate, o1, o2, ... o10
# Now, the "intermediate base" handling in AOTAutograd implies that we must do the following:
# (1) return "intermediate as an extra output of the compiled graph
# (2) regenerate each aliased output off of "intermediate", **outside** of the autograd.Function.
# The reason AOTAutograd ordinarily does this is for safety: the autograd engine needs to know
# that o1 through o10 are all aliased, and if we blindly return o1 through o10 from the autograd.Function,
# this information will be hidden.
# In particular, mutating one alias might require autograd to update autograd metadata on the other aliases
# (like their grad_fn, for example, when the autograd engine needs to do view-replay).
#
# However, intermediate_base logic can be bad for backward performance (we sometimes generate
# as_strided calls during the intermediate base logic, which can have a slow backward formula).
# Is it possible to find a set of conditions where it is **safe** to hide the output aliasing from autograd?
#
# For a set of outputs of the graph that alias each other, o_1...o_k, consider:
# (1) They came from the same multi-output view op, e.g. o_1, ..., o_k = intermediate.unbind(0)
# (2) If there are any other aliases of o_1 through o_k (in the example above, intermediate),
# **at most** 1 can escape from the graph (e.g. there is not some other graph input/output
# o_other, that aliases these outputs)
# (3) o_1...o_k all require_grad, they all share the same ._base, and their ._base requires grad.
# This condition is important because it's what causes slowness in the intermediate_base
# codepath of aot_autograd. Ordinarily, o_1...o_k would all get a grad_fn, and
# aot_autograd's view-replay might give each output an AsStridedBackward as its grad_fn.
# "K" AsStridedBackward calls will be *much* slower than a single UnbindBackward.
# In this setup, is it possible to mutate one of the outputs o_i in a way that would affect the autograd meta
# of the other aliases?
#
# Claim: No! Consider a few examples (which I'm pretty sure cover all cases of mutation w.r.t. autograd):
# (a) What happens if we mutate any of o_1 through o_k directly?
# Autograd raises an error:
# "RuntimeError: Output 0 of UnbindBackward0 is a view and is being modified inplace. This view is
# the output of a function that returns multiple views. Such functions do not allow the output
# views to be modified inplace. You should replace the inplace operation by an out-of-place one."
# (b) What if we take a view of o_k and mutate it, o_k.view(o_k.shape).mul_(2)?
# Autograd raises the same error- the "multi-output-view"ness of an alias propagates to future views.
# (c) What if we mutate o_k under no_grad?
# Autograd raises the same error
# (d) What if we detach and mutate, e.g. o_k.detach().mul_(2)?
# Autograd allows this, *but* autograd updates all aliases' grad_fn's to be error functions,
# so the same error is raised when those grad_fn's are later accessed.
# (e) What if we try to mutate another alias of o_1...o_k, that was **not** created from a multi-output view?
# We promised that there is at most **one** such alias, e.g. intermediate in the example above.
# You can mutate intermediate, but in eager mode this will change the grad_fn of o_1...o_k
# to be error fn's.
# Since intermediate was the *only* non-multi-output-alias, there are no other aliases
# of `intermediate` around that were produced by the compiled fn and have a valid grad_fn.
#
# Coming back to this optimization:
# Given that it is not possible for mutating one of these aliases to affect the autograd metadata of another alias
# without causing an error in eager mode, we will simply hide the aliasing from autograd during torch.compile
# if all of the above conditions are met.
# This has the slight downside that it's possible to write some "bad" code that autograd will raise an error on
# in eager but fail to during torch.compile, but it has the benefit that this code has much better performance.
# NOTE: if and when we eventually update AOTAutograd to do the "view graph slicing" defined here:
# https://docs.google.com/document/d/1DlfFq8TKbuAn2zyJxLfoW-X1qkkm5PLdHFtySo03QAk/edit,
# then this optimization will probably matter less and might be ok to remove.
is_cur_tensor_multi_out_view = isinstance(o, FunctionalTensor) \
and torch._functionalize_is_multi_output_view(o.elem)
if is_cur_tensor_multi_out_view:
num_aliased_tensors_that_are_multi_output_views[curr_storage] += 1
out_storage_to_tensors[curr_storage].add(o)
# maps the id of an intermediate base to its index in the output of the compiled forward
intermediate_base_tensor_id_to_output_idx: Dict[int, int] = {}
intermediate_bases: List[torch.Tensor] = []
for o in flat_f_outs:
curr_storage = None if not isinstance(o, torch.Tensor) else StorageWeakRef(o.untyped_storage())
outs_with_identical_metadata_that_require_grad = [] if not isinstance(o, torch.Tensor) else [
curr for curr in out_storage_to_tensors[curr_storage]
if has_same_metadata(o, curr) and curr.requires_grad and o is not curr
]
is_result_of_custom_autograd_fn = False
if isinstance(o, torch.Tensor):
# Need to check for both custom cpp (CppFunction) and python (BackwardCFunction) autograd fns
if type(o.grad_fn).__name__ == "CppFunction":
is_result_of_custom_autograd_fn = True
if isinstance(o.grad_fn, torch.autograd.function.BackwardCFunction):
is_result_of_custom_autograd_fn = True
if not isinstance(o, torch.Tensor):
output_type = OutputType.non_alias
base_idx = None
elif curr_storage in inp_storage_refs and o.grad_fn is not None \
and is_result_of_custom_autograd_fn:
output_type = OutputType.custom_function_view
base_idx = None
elif curr_storage in inp_storage_refs:
base_idx = inp_storage_refs[curr_storage]
is_input_tensor = id(o) in inp_tensor_ids
num_aliased_outs = out_tensor_alias_counts[curr_storage]
num_multi_output_view_outs = num_aliased_tensors_that_are_multi_output_views[curr_storage]
num_aliased_outs_that_are_not_multi_output_views = num_aliased_outs - num_multi_output_view_outs
if o.grad_fn is not None and num_aliased_outs_that_are_not_multi_output_views == 0:
# See Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
# In particular, given:
# def f(x):
# return list(x.unbind(0))
# The main reason we ordinarily try to regenerate these output aliases outside of the
# compiled autograd.Function is because if any of the outputs are later mutated,
# autograd needs to perform view-replay to regenerate them.
# However, autograd does not allow users to mutate multi-output views
# in any way that can change the autograd metadata of other aliases.
# So we hide this aliasing from autograd here.
aot_graphs_log.info("Encountered AOTAutograd case: differentiable outputs that \
alias each other from a multi-output view call")
output_type = OutputType.non_alias
elif is_input_tensor:
output_type = OutputType.is_input
else:
output_type = OutputType.alias_of_input
# We only need to handle the intermediate base case when both
# the intermediate base and the output require gradients.
# See Note [AOT Autograd: outputs aliasing inputs or intermediates!]
elif (
o._base is not None
and o.requires_grad
and o._base.requires_grad
):
num_aliased_outs = out_tensor_alias_counts[curr_storage]
num_multi_output_view_outs = num_aliased_tensors_that_are_multi_output_views[curr_storage]
num_aliased_outs_that_are_not_multi_output_views = num_aliased_outs - num_multi_output_view_outs
# Note: [AOTAutograd: differentiable outputs that alias each other from a multi-output view call]
if out_tensor_alias_counts[curr_storage] == 1 or num_aliased_outs_that_are_not_multi_output_views <= 1:
# Note [Intermediate Bases Optimization]
# Normally if we have an output that aliases an intermediate,
# we need to add the extra "intermediate base" logic further down
# to prevent autograd from yelling at us if the user later tries to
# mutate that output.
# However, the common case here is if we have an output that aliases an intermediate,
# but doesn't alias any other outputs.
# In that case, autograd shouldn't have to worry about the aliasing at all
# (if that output is mutated, there are no other live aliases for autograd to worry about).
# The "intermediate bases" can hurt inductor perf by forcing more variables to become outputs.
# So as an optimization, we won't do intermediate base handling in this case.
# Instead, we'll hide the aliasing from autograd using aten._unsafe_view().
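# Illustrative sketch of the case this optimization targets (hypothetical user code,
# not part of this module): the output below aliases a graph intermediate, but no
# other output aliases the same storage, so we can hide the aliasing with
# aten._unsafe_view() instead of returning the intermediate as an extra graph output.
#   def f(x):
#       intermediate = x.mul(2)
#       return intermediate.view(-1)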
if out_tensor_alias_counts[curr_storage] != 1 and num_aliased_outs_that_are_not_multi_output_views <= 1:
aot_graphs_log.info("Encountered AOTAutograd case: differentiable outputs that alias each other \
from a multi-output view call")
output_type = OutputType.unsafe_view_alias
base_idx = None
else:
# First, check if o's ._base is an existing output
maybe_existing_out_idx = out_tensor_ids.get(id(o._base), None)
if maybe_existing_out_idx is not None:
# Special case where the output is an alias of a graph intermediate, but that intermediate
# is itself also a user output.
output_type = OutputType.alias_of_intermediate_base_is_user_output
base_idx = maybe_existing_out_idx
else:
# Next, check if o's ._base is an intermediate base that we already returned
maybe_existing_base_output_idx = intermediate_base_tensor_id_to_output_idx.get(
id(o._base), None
)
if maybe_existing_base_output_idx is not None:
output_type = OutputType.alias_of_intermediate
base_idx = maybe_existing_base_output_idx
else:
# Otherwise, take o._base and explicitly return it as an output in the compiled graph
new_out_idx = len(intermediate_bases)
base_idx = new_out_idx
# Indicate to the logic later on (when we trace the joint)
# that this particular output should get its ._base appended to the forward graph outputs
output_type = OutputType.alias_of_intermediate_save_as_output
intermediate_base_tensor_id_to_output_idx[id(o._base)] = new_out_idx
intermediate_bases.append(o._base)
elif (
# See https://github.com/pytorch/pytorch/issues/100348 for this case.
# This protects against the specific case where a user fn returns (output, output.detach())
out_tensor_alias_counts[curr_storage] > 1
and len(outs_with_identical_metadata_that_require_grad) > 0
and not o.requires_grad
):
assert len(outs_with_identical_metadata_that_require_grad) > 0
# In theory we could use any of these tensors to regenerate the aliased outputs from,
# since they all alias each other and have identical metadata
out_alias = outs_with_identical_metadata_that_require_grad[0]
existing_out_idx = out_tensor_ids[id(out_alias)]
output_type = OutputType.alias_of_intermediate_base_is_user_output
base_idx = existing_out_idx
else:
output_type = OutputType.non_alias
base_idx = None
if isinstance(o, torch.Tensor):
dynamic_dims = {i for i, s in enumerate(o.shape) if not is_concrete_int(s)}
else:
dynamic_dims = None
out_info = OutputAliasInfo(
output_type=output_type,
raw_type=type(o),
base_idx=base_idx,
dynamic_dims=dynamic_dims,
requires_grad=isinstance(o, torch.Tensor) and o.requires_grad
)
output_info.append(out_info)
# See Note [AOT Autograd: Views to avoid tangents aliasing inputs]
def view_avoid_dupes_with_primals(t):
if isinstance(t, Tensor) and is_traceable_wrapper_subclass(t):
return transform_subclass(t, lambda _, inner_t: view_avoid_dupes_with_primals(inner_t))
if isinstance(t, Tensor):
return t.view(t.shape)
return t
# This analysis function returns *only* the outputs that are meant to be tangents to the backwards.
# Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
# is *regenerated* later, and not used directly in the autograd graph
f_input_tangents = [
inp
for inp, info in zip(flat_f_args, input_info)
if info.mutates_data and info.requires_grad
]
f_output_tangents = [
o
for o, info in zip(flat_f_outs, output_info)
if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
and issubclass(info.raw_type, torch.Tensor)
and info.requires_grad
]
# intermediate bases are also included in the backward graph
f_tangents = f_input_tangents + f_output_tangents + intermediate_bases
traced_tangents = pytree.tree_map(from_fun, f_tangents)
traced_tangents = pytree.tree_map(view_avoid_dupes_with_primals, traced_tangents)
user_outs = pytree.tree_map(from_fun, f_output_tangents)
f_mutated_inputs = [
inp
for inp, info in zip(flat_f_args, input_info)
if info.mutates_data or info.mutates_metadata
]
f_metadata_mutated_inputs = [
inp
for inp, info in zip(flat_f_args, input_info)
if info.mutates_metadata
]
# This logic (annoyingly) re-figures out exactly what the outputs of the compiled fw graph will be.
# When handling subclasses, we need info about **all** outputs of compiled forward graph,
# so we know precisely which graph outputs to wrap back into tensor subclasses
# Ideally we would refactor this so as not to have an is_train flag, and have the separate
# inference and training paths decide which inputs/outputs to ask for subclass info on.
# However, we currently stash indexing information on each SubclassMeta about its order
# in the graph outputs list.
f_fw_graph_outs = list(flat_f_outs)
if is_train or not keep_input_mutations:
f_fw_graph_outs = f_mutated_inputs + f_fw_graph_outs
else:
# even when "keep_input_mutations" is True,
# we never keep metadata-only mutations in the fw graph
f_fw_graph_outs = f_metadata_mutated_inputs + f_fw_graph_outs
if is_train:
f_fw_graph_outs = f_fw_graph_outs + intermediate_bases
fw_graph_outs = pytree.tree_map(from_fun, f_fw_graph_outs)
grad_enabled_mutation = None
if torch.is_grad_enabled() != prior_grad_enabled:
grad_enabled_mutation = torch.is_grad_enabled()
torch.set_grad_enabled(prior_grad_enabled) # Restore the prior state after tracing it
aot_graphs_log.info(
("grad_mode mutation encountered in graph. "
"Will emit mutation epilogue, to set grad_mode=%s"),
grad_enabled_mutation
)
metadata = ViewAndMutationMeta(
input_info=input_info,
output_info=output_info,
num_intermediate_bases=len(intermediate_bases),
keep_input_mutations=keep_input_mutations,
traced_tangents=traced_tangents,
subclass_inp_meta=create_subclass_meta(flat_args),
subclass_fw_graph_out_meta=create_subclass_meta(fw_graph_outs),
subclass_tangent_meta=create_subclass_meta(traced_tangents),
is_train=is_train,
grad_enabled_mutation=grad_enabled_mutation,
)
return metadata
return inner
@dataclass
class BackwardSignature:
"""
Provides information about the backward section of an exported
joint forward-backward graph.
For a particular fx GraphModule, this class contains information on:
(1) A mapping from each gradient (backwards output) to the parameter
it corresponds to (forward input)
(2) A mapping from each gradient (backwards output) to the user input
it corresponds to (forward input)
(3) Which of the forward outputs corresponds to the loss that we backprop on.
Each string name is the `node.name` of the corresponding node in the fx graph.
"""
gradients_to_parameters: Dict[str, str]
gradients_to_user_inputs: Dict[str, str]
loss_output: str
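# A minimal sketch of a populated BackwardSignature (hypothetical node and FQN names,
# for illustration only):
#   BackwardSignature(
#       gradients_to_parameters={"getitem_2": "linear.weight"},
#       gradients_to_user_inputs={"getitem_3": "arg0_1"},
#       loss_output="add_1",
#   )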
GraphOutputName = NewType('GraphOutputName', str)
GraphInputName = NewType('GraphInputName', str)
FQN = NewType('FQN', str)
@dataclass
class GraphSignature:
"""
Provides information about an exported module.
For a particular fx GraphModule, this class contains information on:
(1) Which graph inputs are parameters, buffers, or user inputs
(2) (for params/buffers) a mapping from the name of each graph argument
to its parameter/buffer FQN in the original nn.Module.
(3) If there are input mutations, these are represented as extra outputs
in the fx GraphModule. We provide a mapping from these
extra output names to the names of the actual inputs.
(4) The pytree metadata on how to flatten/unflatten inputs and outputs.
The corresponding FX GraphModule only accepts and returns
pytree-flattened inputs/outputs.
(5) (Optionally) if the FX is a joint forward-backward graph, we provide
a signature on the backward section of the joint graph.
"""
parameters: List[FQN]
buffers: List[FQN]
user_inputs: List[GraphInputName]
user_outputs: List[GraphOutputName]
inputs_to_parameters: Dict[GraphInputName, FQN]
inputs_to_buffers: Dict[GraphInputName, FQN]
# If the user's module mutates a buffer,
# it's represented in the graph as an extra graph output.
# This dict is a mapping from
# "graph outputs that correspond to updated buffers"
# to the FQN names of those mutated buffers.
buffers_to_mutate: Dict[GraphOutputName, FQN]
in_spec: pytree.TreeSpec
out_spec: pytree.TreeSpec
backward_signature: Optional[BackwardSignature]
@classmethod
def from_tracing_metadata(
cls,
*,
in_spec: pytree.TreeSpec,
out_spec: pytree.TreeSpec,
graph_input_names: List[str],
graph_output_names: List[str],
view_mutation_metadata: ViewAndMutationMeta,
named_parameters: List[str],
named_buffers: List[str],
num_user_inputs: int,
num_user_outputs: int,
loss_index: Optional[int],
backward_signature: Optional[BackwardSignature],
) -> "GraphSignature":
graph_inputs = graph_input_names
graph_outputs = graph_output_names
parameters = list(named_parameters)
buffers = list(named_buffers)
# Calling convention assumptions:
# (1) graph inputs = (params, buffers, user_inputs)
# (2) graph outputs = (mutated_inputs, user_outs, param_gradients)
# (If we are capturing an inference graph, this convention is identical
# except that param_gradients is empty)
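# Concrete sketch of that convention (hypothetical names and counts, not taken from a
# real trace): for a module with 2 parameters, 1 buffer and 1 user input, where the
# buffer is mutated and there is 1 user output:
#   graph_inputs  = ["p0", "p1", "b0", "x0"]
#   graph_outputs = ["b0_updated", "out0"]   # plus param gradients, if training
#   user_inputs       == ["x0"]
#   buffers_to_mutate == {"b0_updated": <buffer FQN>}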
user_inputs = graph_inputs[len(parameters) + len(buffers) :]
assert num_user_inputs == len(user_inputs)
assert len(graph_inputs) == (len(parameters) + len(buffers) + len(user_inputs))
inputs_to_parameters = dict(zip(graph_inputs[: len(parameters)], parameters))
inputs_to_buffers = dict(zip(
graph_inputs[len(parameters) : len(parameters) + len(buffers)],
buffers,
))
state_names = [*parameters, *buffers]
mutated_buffers = []
for idx, input_info in enumerate(view_mutation_metadata.input_info):
if input_info.mutates_data:
# Only buffers can be mutated, not parameters
assert idx >= len(parameters)
buffer_name = state_names[idx]
mutated_buffers.append(buffer_name)
assert len(mutated_buffers) == view_mutation_metadata.num_mutated_inputs
start, stop = 0, view_mutation_metadata.num_mutated_inputs
buffers_to_mutate = dict(zip(graph_outputs[start:stop], mutated_buffers))
start, stop = stop, stop + num_user_outputs
user_outputs = graph_outputs[start:stop]
unused_outputs = len(graph_outputs) - stop
if backward_signature is not None:
unused_outputs -= len(backward_signature.gradients_to_parameters) + len(
backward_signature.gradients_to_user_inputs
)
assert unused_outputs == 0
return GraphSignature(
parameters=parameters,
buffers=buffers,
user_inputs=user_inputs,
user_outputs=user_outputs,
inputs_to_buffers=inputs_to_buffers,
inputs_to_parameters=inputs_to_parameters,
buffers_to_mutate=buffers_to_mutate,
in_spec=in_spec,
out_spec=out_spec,
backward_signature=backward_signature,
)
@dataclasses.dataclass
class AOTConfig:
"""
Configuration for AOTDispatcher
"""
fw_compiler: Callable
bw_compiler: Callable
partition_fn: Callable
decompositions: Dict[Callable, Callable]
num_params_buffers: int
aot_id: int
keep_inference_input_mutations: bool
is_export: bool = False
no_tangents: bool = False
dynamic_shapes: bool = False
aot_autograd_arg_pos_to_source: Optional[List[Source]] = None
inference_compiler: Optional[Callable] = None
enable_log: bool = True
# This function takes in a tensor t, and returns one of t, t.view(), or t.clone().
# When tracing the joint forward + backward, for any inputs in the graph that are mutated,
# we need to clone them first (and similarly for metadata-only mutations, we need to view them first).
# The idea is that when we trace the backward, we need to pass in the *original* primals
# to autograd.grad(), before they were mutated.
# Note: when we have synthetic base inputs, we need to clone them *before* creating views off of them.
# This means that "idx" here represents the index of the (potentially) synthetic base.
# What we need to do is:
# (1) map the current (post-synthetic-base calling convention) input argument index
# to its index in the pre-synthetic-base calling convention.
# (2) There could be multiple, if this index corresponds to a synthetic base
# that has multiple input aliases.
# (3) If any of those corresponding inputs get metadata mutations, then we clone the base.
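# Illustrative sketch (hypothetical user code, not part of this module) of the two
# flavors of "fresh input" handling described above:
#   def f(a, b):
#       a.mul_(2)            # data mutation: a.clone() is passed in instead (if a requires grad)
#       b.transpose_(0, 1)   # metadata-only mutation: b.view(b.shape) is passed in instead
#       return a + b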
def maybe_to_fresh_input(idx, t, meta):
if not isinstance(t, Tensor):
return t
if idx in meta.mutated_inp_indices:
# We only need to bother cloning mutated inputs that participate in autograd.
mutated_inp_idx = meta.mutated_inp_indices.index(idx)
if meta.input_info[idx].requires_grad and meta.input_info[idx].mutates_data:
# Make sure the primal we pass to autograd.grad()
# sees the tensor before the mutation
return t.clone()
if meta.input_info[idx] and meta.input_info[idx].mutates_metadata:
# Make sure the primal we pass to autograd.grad()
# sees the tensor before the metadata mutation
return t.view(t.shape)
return t
# This function returns a new function that returns mutated inputs as outputs.
# If keep_data_input_mutations is set, then we assume that data-only mutations
# will be left in the graph, and we only return metadata-mutated inputs as outputs.
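# Rough sketch of the transformation (hypothetical user code): given
#   def f(a, b):
#       a.mul_(2)
#       return a + b
# the wrapped function behaves roughly like
#   def inner_fn(a, b):
#       out = f(a, b)
#       return a, out   # the (data-)mutated input is returned as an extra output
# (unless keep_data_input_mutations is set, in which case the data mutation stays in
# the graph and `a` is not returned).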
def fn_input_mutations_to_outputs(
fn: Callable,
meta: ViewAndMutationMeta,
keep_data_input_mutations: bool,
) -> Any:
def inner_fn(*args):
outs = fn(*args)
assert len(meta.output_info) == len(outs)
# The compiled fw will return mutated input tensors, *including* inputs that only had metadata mutations.
# However, if keep_data_input_mutations is set, the compiled fw only needs to return metadata-mutated inputs.
# (because data-only input mutations are handled directly in the compiled graph)
mutated_inputs_to_return = [
x
for (i, x) in enumerate(args)
if meta.input_info[i].mutates_metadata or (meta.input_info[i].mutates_data and not keep_data_input_mutations)
]
return *mutated_inputs_to_return, *outs
return inner_fn
# This function takes in a fn with external aliasing and mutation,
# and returns a new fn with no external aliasing and mutation,
# as needed for autograd.
# The main transformations are:
# - Return mutated inputs as extra outputs
# - Clone mutated inputs that require gradients,
# because autograd will require us to pass the pre-mutated inputs into autograd.grad
# - Return intermediate bases of outputs as additional outputs,
# needed to appease autograd.Function
# The new function returns:
# (1) The updated outputs
# (2) A boolean mask of len(new_fn_outputs),
# that can be used to tell autograd.grad which outputs should get tangents
# if we trace the backward.
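# Rough sketch (hypothetical user code): for
#   def f(a, b):
#       b.add_(1)
#       return a * b
# the prepped function returns roughly
#   (b_updated, out), (b_mask, out_mask)
# where b_mask is True only if `b` requires grad (its mutation is a data mutation),
# and out_mask is True when `out` requires grad (it is a plain, non-aliased output).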
def fn_prepped_for_autograd(
fn: Callable,
meta: ViewAndMutationMeta,
) -> Any:
def inner_fn(*args):
args_maybe_cloned = [
maybe_to_fresh_input(i, t, meta) for i, t in enumerate(args)
]
outs = fn(*args_maybe_cloned)
assert isinstance(outs, (tuple, list))
outs = list(outs)
assert len(meta.output_info) == len(outs)
mutated_inputs_to_return = [
x
for (i, x) in enumerate(args_maybe_cloned)
if meta.input_info[i].mutates_metadata or meta.input_info[i].mutates_data
]
intermediate_bases = []
for i, (o, info) in enumerate(zip(outs, meta.output_info)):
if info.output_type == OutputType.alias_of_intermediate_save_as_output:
intermediate_bases.append(o._base)
assert meta.num_intermediate_bases == len(intermediate_bases)
# the compiled forward should return (mutated_inputs, user_outs, intermediate_bases)
fw_outs_to_return = *mutated_inputs_to_return, *outs, *intermediate_bases
# Also return a boolean mask specifying which outputs of this function will be used as tangents
mutated_inputs_grad_mask = [
meta.input_info[meta.mutated_inp_indices[i]].mutates_data and meta.input_info[meta.mutated_inp_indices[i]].requires_grad
for (i, x) in enumerate(mutated_inputs_to_return)
]
# Pass any (non-aliased) outputs in as tangents, since they'll be returned as outputs in the fw
# For outputs that are aliases of intermediates, we will have returned the output's _base as an output in the graph instead,
# which we *should* send to grad()
output_grad_mask = [
meta.output_info[i].output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
# Also, only tensor outputs should participate in the backward
# (in particular, Symint outputs in the forward graph shouldn't get tangents)
and issubclass(meta.output_info[i].raw_type, torch.Tensor)
and meta.output_info[i].requires_grad
for (i, x) in enumerate(outs)
]
intermediate_base_grad_mask = [True for _ in range(len(intermediate_bases))]
out_grad_mask = mutated_inputs_grad_mask + output_grad_mask + intermediate_base_grad_mask
assert len(out_grad_mask) == len(fw_outs_to_return)
# Take care to grab and sync the updated inputs from args_maybe_cloned (the inputs we actually mutate!)
# and not the original args (the preserved, pre-mutation inputs that we pass to grad())
# This is annoying: our joint function needs to be aware of functionalization
# (syncing mutated inputs before calling autograd.grad())
# In theory, we could make the autograd engine do this automatically, although that probably isn't any cleaner.
for arg in args_maybe_cloned:
if not isinstance(arg, Tensor):
continue
sync_functional_tensor(arg)
return fw_outs_to_return, out_grad_mask
return inner_fn
# Given a fn, computes the joint.
# NOTE: fn is expected to have the following behavior:
# (1) fn() needs to return a tuple of (outs, mask),
# where `mask` tells us which outputs are meant to have tangents.
# we don't know this info automatically, because we don't actually want to blindly
# compute tangents for every output that requires grad.
# Specifically, outputs that alias inputs won't participate in the backward or get tangents.
# (2) fn() cannot mutate any inputs that require gradient.
# otherwise, when we compute autograd.grad(), we will not take those input mutations into account
# (the way this is handled is that we ensure any inputs that normally get mutated are cloned first)
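# Rough sketch of the resulting joint (hypothetical, ignoring the no_tangents export
# branch and the re-alignment of gradients with primals done below):
#   def joint(primals, tangents):
#       outs, tangent_mask = fn(*primals)
#       outs_to_grad = [o for needed, o in zip(tangent_mask, outs) if needed]
#       grads = torch.autograd.grad(
#           outs_to_grad,
#           [p for p in primals if isinstance(p, torch.Tensor) and p.requires_grad],
#           grad_outputs=tangents,
#           allow_unused=True,
#       )
#       return outs, grads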
def create_joint(
fn: Callable, *, aot_config: AOTConfig
) -> Any:
def inner_fn(primals: List[Any], tangents: List[Any]):
outs, tangent_mask = fn(*primals)
assert len(tangent_mask) == len(outs)
outs_to_grad = [o for needs_tangent, o in zip(tangent_mask, outs) if needs_tangent]
assert len(outs_to_grad) == len(tangents)
# Get the inputs that need gradients
grad_primals = []
inputs_needs_grads = []
# Note that we pass the original (pre-mutation) primals here,
# being careful not to pass any mutated inputs into autograd.grad()
for p in primals:
is_grad_tensor = isinstance(p, Tensor) and p.requires_grad
inputs_needs_grads.append(is_grad_tensor)
if is_grad_tensor:
grad_primals.append(p)
# Get the outputs that need gradients
needed_outs = []
needed_tangents = []
for out, tangent in zip(outs_to_grad, tangents):
if isinstance(out, Tensor) and out.requires_grad:
# A bit sketchy, but fixes e.g. test_aot_autograd_exhaustive_matmul_cpu_float32
# The issue is that we are sensitive to decomps that don't accurately maintain
# their output's _base.shape compared to eager mode, and this helps mitigate a bit.
needed_outs.append(
out if out.shape == tangent.shape else out.view(tangent.shape)
)
needed_tangents.append(tangent)
setup_stacktrace_preservation_hooks([out.grad_fn for out in needed_outs])
if config.functionalize_rng_ops:
PhiloxStateTracker.mark_beginning_of_backward()
backward_out = []
# Call the backwards pass
if grad_primals:
with fx_traceback.preserve_node_meta():
# for full graph export, we always export a joint graph where we assume no tangents are needed.
if aot_config.no_tangents:
assert len(needed_tangents) == 1 and needed_tangents[0].numel() == 1
backward_out = torch.autograd.grad(
needed_outs,
grad_primals,
allow_unused=True,
)
else:
backward_out = torch.autograd.grad(
needed_outs,
grad_primals,
grad_outputs=needed_tangents,
allow_unused=True,
)
backward_out_iter = iter(backward_out)
return outs, [
next(backward_out_iter) if i else None for i in inputs_needs_grads
]
def inner_fn_with_anomaly(*args):
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Anomaly Detection has been enabled."
)
with torch.autograd.detect_anomaly(check_nan=False):
return inner_fn(*args)
return inner_fn_with_anomaly
# This creates the final function that we want to trace using make_fx(),
# in both aot_dispatch_autograd and aot_dispatch_base.
# Preconditions:
# - fn corresponds to the user's fw function
# - fn arguments have been flattened, duplicate arguments have been handled
# - In the returned function, the "primals" arguments *includes* synthetic bases.
# This function does the work of functionalizing the input function,
# and performing copy_() calls at the end of the function if `keep_input_mutations` is set.
# The function returned has signature that is either:
# (1) "traced_fn(primals: List[Any])" if trace_joint is False
# (2) "traced_fn(primals: List[Any], tangents: List[Any])" if trace_joint is True
# Returns a new (functionalized) function, and updated arguments to call it with.
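# Rough sketch of what functionalization does to an input mutation (hypothetical user
# code, with keep_input_mutations set):
#   def f(x):
#       x.mul_(2)
#       return x + 1
# traces to (roughly)
#   def f_functionalized(x):
#       x_updated = x.mul(2)
#       out = x_updated + 1
#       x.copy_(x_updated)   # the copy_() epilogue added below
#       return out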
def create_functionalized_fn(
fn,
args,
*,
meta: ViewAndMutationMeta,
aot_config: AOTConfig,
trace_joint: bool,
) -> Tuple[Callable, List[Any]]:
def functionalized_f_helper(*args):
# Wrap inputs into functional wrappers
f_args = pytree.tree_map(to_fun, args)
# See Note [Disabling Functionalize TLS Above Python Functionalization]
disable_above = torch._C._ExcludeDispatchKeyGuard(torch._C.DispatchKeySet(torch._C.DispatchKey.Functionalize))
with disable_above, FunctionalTensorMode():
# Run the joint
f_outs = fn(*f_args)
if aot_config.keep_inference_input_mutations and not trace_joint:
# Note: This is a bit annoying. There's a layering issue here, where:
# (1) functionalization needs to operate on **synthetic base** inputs, before unpacking them into the "real" inputs.
# (2) For keep_input_mutations, we support tracing a call to copy_() directly on mutated inputs.
# However, we **only** want to support this for inputs that have data-only (and no metadata) mutations,
# because inductor (and backends in general) would prefer not to see these (e.g. as_strided_(), resize_()).
# This makes it pretty difficult for this logic to operate on synthetic bases.
# (3) In addition, there are cases where it's significantly cheaper to perform the copy on the individual
# (unpacked) input aliases, instead of the synthetic base.
# Example case where (3) could be important:
#
# def f(x, y):
# x.mul_(2)
# y.mul_(3)
# return x, y
# a = torch.ones(1_000_000)
# x, y = f(a[0:9], a[1:10])
#
# It would be much better to add copy_() calls into the graph for the two tiny slices, instead of materializing
# a giant "updated synthetic base" and copying into a's entire storage.
#
# For now, we are pessimistically not performing the optimization from (3);
# we will materialize an "updated" synthetic base, and copy it back to the synthetic input base.
# This allows us to factor aot autograd much more nicely, since only one area of the code needs to worry
# about synthetic bases.
for i, (inpt_old, inpt_f) in enumerate(zip(args, f_args)):
if not isinstance(inpt_f, torch.Tensor):
continue
assert is_fun(inpt_f)
inpt_new = from_fun(inpt_f)
if meta.input_info[i].mutates_data and not meta.input_info[i].mutates_metadata:
# We found an input that had a (data-only) mutation.
# Since keep_input_mutations is set, we need to faithfully apply a copy_()
# so the compiler will see the input mutation in the graph.
assert inpt_new is not inpt_old
if meta.input_info[i].mutations_hidden_from_autograd:
with torch.no_grad(), torch.autograd._unsafe_preserve_version_counter(inpt_old):
inpt_old.copy_(inpt_new)
else:
inpt_old.copy_(inpt_new)
return pytree.tree_map(from_fun, f_outs)
# Kinda annoying, but needed to make sure that the fx graph we trace out has "primals"
# and "tangents" as its input names (which are special-cased by the partitioner)
def joint_helper(primals, tangents):
return functionalized_f_helper(primals, tangents)
def fwd_helper(*args):
return functionalized_f_helper(*args)
helper = joint_helper if trace_joint else fwd_helper
if config.functionalize_rng_ops:
# Setup the wrapper for functionalization of rng ops
helper, args = create_functionalized_rng_ops_wrapper(helper, args, trace_joint)
return helper, args
def create_graph(f, args, *, aot_config: AOTConfig) -> torch.fx.GraphModule:
with enable_python_dispatcher():
fx_g = make_fx(f, decomposition_table=aot_config.decompositions)(*args)
return fx_g
def normalize_as_list(x):
if isinstance(x, tuple):
return list(x)
elif isinstance(x, list):
return x
return [x]
aot_autograd_decompositions = {}
# This is a list since, looking forward, we can have this arbitrarily nested.
graph_being_compiled: List[str] = []
# TODO: It would be nice to reset the numbering every time aot_id goes
# up, but this is annoying to do right now (because we don't know if
# an aot_id will come back from the dead), so right now this also happens
# to be a globally unique number too (at the cost of wobbling if you change
# how the graphs compile)
nth_graph: int = 0
model_name: str = "model"
def set_model_name(name):
global model_name
model_name = name
def get_aot_compilation_context() -> Tuple[List[str], str, int]:
return list(graph_being_compiled), model_name, nth_graph
def get_aot_graph_name() -> str:
"""
Returns the name of the graph being compiled.
"""
global model_name, graph_being_compiled, nth_graph
return f"{model_name}__{'_'.join(graph_being_compiled)}_{nth_graph}"
get_graph_being_compiled = get_aot_graph_name
@contextmanager
def track_graph_compiling(aot_config, graph_name):
global graph_being_compiled
# TODO: Don't shove the aot_id in here; set it in the context
graph_being_compiled = [f"{aot_config.aot_id}_{graph_name}"]
try:
yield
finally:
global nth_graph
nth_graph += 1
graph_being_compiled = []
def make_boxed_func(f):
def g(args):
return f(*args)
g._boxed_call = True
return g
def make_boxed_compiler(compiler):
@wraps(compiler)
def f(fx_g, inps):
out_f = compiler(fx_g, inps)
fx_g = make_boxed_func(out_f)
return fx_g
return f
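# Usage sketch for the boxed calling convention (hypothetical compiler, for
# illustration only): a boxed callable takes a single list of arguments (which it is
# allowed to clear, so inputs can be freed early) rather than *args.
#   def my_compiler(fx_g: torch.fx.GraphModule, example_inputs):
#       return make_boxed_func(fx_g.forward)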
def call_func_with_args(f, args, steal_args=False, disable_amp=False):
if not steal_args:
args = list(args)
assert isinstance(args, list)
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context():
if hasattr(f, "_boxed_call"):
out = normalize_as_list(f(args))
else:
# TODO: Please remove soon
# https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
warnings.warn(
"Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
"Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
"See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
)
out = normalize_as_list(f(*args))
return out
def aot_dispatch_base_graph(
flat_fn,
flat_args: List[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta
) -> Tuple[Callable, List[Any], Optional[SubclassMeta]]:
# aot_dispatch_base requires functionalization, but doesn't need to handle as many cases as the autograd case.
# The cases that aot_dispatch_base doesn't need to handle include:
# - outputs that are aliases of graph intermediates
# - outputs that are aliases of graph inputs
# While cases that it does need to handle include:
# - input mutations (including when inputs are aliases of each other)
# - input metadata mutations
fn_to_trace = fn_input_mutations_to_outputs(
flat_fn,
fw_metadata,
keep_data_input_mutations=aot_config.keep_inference_input_mutations,
)
fn_to_trace, updated_flat_args = create_functionalized_fn(
fn_to_trace, flat_args, meta=fw_metadata, aot_config=aot_config, trace_joint=False)
fn_to_trace, updated_flat_args_subclasses_desugared, maybe_subclass_meta = aot_dispatch_subclass(
fn_to_trace, updated_flat_args, is_joint_structure=False, meta=fw_metadata, fw_only=flat_fn)
fw_module = create_graph(
fn_to_trace,
updated_flat_args_subclasses_desugared,
aot_config=aot_config,
)
# As long as we opted to remove input mutations,
# there should be *NO* mutating ops in the graph at this point.
copy_count = assert_functional_graph(fw_module.graph, allow_input_mutations=aot_config.keep_inference_input_mutations)
fw_module.graph.eliminate_dead_code()
fw_module.recompile()
copy_count2 = assert_functional_graph(fw_module.graph, allow_input_mutations=aot_config.keep_inference_input_mutations)
assert copy_count == copy_count2
if aot_config.enable_log:
aot_graphs_log.info("%s", lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id))
# TODO: should factor this into a separate function for export that always returns just the graph.
if aot_config.is_export:
assert maybe_subclass_meta is None, "aot_export_module does not support tensor subclass inputs for now."
return fw_module
return fw_module, list(updated_flat_args_subclasses_desugared), maybe_subclass_meta
def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph(
flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
disable_amp = torch._C._is_any_autocast_enabled()
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context(), track_graph_compiling(aot_config, "inference"):
compiler = aot_config.inference_compiler if aot_config.inference_compiler is not None else aot_config.fw_compiler
if config.functionalize_rng_ops:
# Add the seed and offset as example inputs to pass to the compiler
fake_mode = detect_fake_mode()
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
updated_flat_args.extend([seed, offset])
if torch._guards.TracingContext.get():
torch._guards.TracingContext.get().fw_metadata = fw_metadata \
if maybe_subclass_meta is None else maybe_subclass_meta.fw_metadata
compiled_fw = compiler(fw_module, updated_flat_args)
# This boxed_call handling happens inside create_runtime_wrapper as well.
# However, create_runtime_wrapper does not expect the rng offsets in the
# output. So, we have to create another wrapper and take out the offset. As
# a result, we have to account for non-boxed_call compilers as well.
if not hasattr(compiled_fw, "_boxed_call"):
compiled_fw = make_boxed_func(compiled_fw)
# Create a wrapper to set up the rng functionalize bits
@wraps(compiled_fw)
def rng_functionalization_wrapper(args):
# args is a list because compiled_fw is boxed_call
if fw_metadata.is_rng_op_functionalized:
# Add the seed and offset to args
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
args.extend([seed, offset])
out = compiled_fw(args)
out = functionalized_rng_runtime_epilogue(fw_metadata, out)
return out
else:
return compiled_fw(args)
if maybe_subclass_meta is not None:
compiled_fw_func = aot_dispatch_subclass_wrapper(
rng_functionalization_wrapper, subclass_metas=fw_metadata.subclass_fw_graph_out_meta, num_fw_outs_saved_for_bw=None)
else:
compiled_fw_func = rng_functionalization_wrapper
if not hasattr(compiled_fw_func, "_boxed_call"):
compiled_fw_func = make_boxed_func(compiled_fw_func)
compiled_fn = create_runtime_wrapper(
compiled_fw_func,
runtime_metadata=fw_metadata,
indices_of_inps_to_detach=[],
trace_joint=False,
keep_input_mutations=aot_config.keep_inference_input_mutations,
disable_amp=disable_amp
)
return compiled_fn
# Returns the number of detected copy_
def assert_functional_graph(fx_g: torch.fx.Graph, *, allow_input_mutations: bool = False) -> int:
placeholders = set()
copy_count = 0
# NB: It would also be nice to verify that the mutations all happen at the
# end, but we also do some administrative views after mutations so this
# isn't actually true. (TODO: Could this cause problems for Inductor?)
for n in fx_g.nodes:
if n.op == "placeholder":
placeholders.add(n)
if isinstance(n.target, torch._ops.OpOverload):
if n.target is aten.copy_.default and allow_input_mutations:
suffix = True
# Can only copy_ into an input, and can only do so once
assert n.args[0] in placeholders
placeholders.remove(n.args[0])
copy_count += 1
else:
assert not n.target._schema.is_mutable, \
f'aot_autograd expected to have an entirely functional graph, but found {n.format_node()}'
return copy_count
def are_differentiable_views(view1, view2):
if view1 is view2:
return True
if view1._base is None and view2._base is None:
return False
if view1._base is view2._base or view1._base is view2 or view1 is view2._base:
return True
return False
def same_dtype_views(view1, view2):
if view1.dtype != view2.dtype:
return False
if view1._base is not None and view1.dtype != view1._base.dtype:
return False
if view2._base is not None and view2.dtype != view2._base.dtype:
return False
return True
# Assumption: x and y are known to share a storage, and we are trying to determine
# if their memory is actually completely disjoint, based on sizes/strides/storage_offset
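# Illustrative sketch (hypothetical tensors):
#   base = torch.zeros(10)
#   tensors_definitely_do_not_overlap(base[0:4], base[6:10])  # True: disjoint slices
#   tensors_definitely_do_not_overlap(base[0:4], base[2:6])   # False: elements 2-3 overlap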
def tensors_definitely_do_not_overlap(x, y):
if x is y:
return False
if x.numel() == 0 or y.numel() == 0:
return True
# Make x always on the left
if x.storage_offset() > y.storage_offset():
x, y = y, x
# Short-circuit in the "obvious" case where both tensors are contiguous
if x.is_contiguous() and y.is_contiguous():
if x.storage_offset() + x.numel() > y.storage_offset():
# definitely overlap
return False
else:
# definitely no overlap
return True
if x.dim() == 2 and y.dim() == 2 and x.stride(1) == 1 and y.stride(1) == 1:
# This case is needed for the shampoo optimizer.
# All tensors are 2d (non-contiguous), have the same outer stride, and have an inner stride of 1
# (so rows are contiguous)
if x.stride(0) == y.stride(0):
offset_delta = y.storage_offset() - x.storage_offset()
if offset_delta < x.size(1):
# definitely overlaps (row 0 of y overlaps with row 0 of x)
# Example:
# base = torch.arange(32).reshape(4, 8)
# x = base.narrow(1, 0, 4)
# x: size=(4, 4), stride=(8, 1), offset=0
# y = base.narrow(1, 3, 4)
# y: size=(4, 4), stride=(8, 1), offset=3
return False
x_total_elems_covered = x.stride(0) * (x.size(0) - 1) + x.size(1)
if x_total_elems_covered <= offset_delta:
# definitely does not overlap (the last element of x comes before the start of y)
# Example:
# x: size=(4, 4), stride=(8, 1), offset=0 (its last element is at offset 27)
# y: size=(4, 4), stride=(8, 1), offset=28 (its first element is at offset 28)
return True
# At this point, we want to check if the 0th row of y
# overlaps with **some** row of x.
# We can check this by shifting y backward by the shared stride, repeatedly,
# until the first row of y is before the first row of x.
# Then we can check if these rows overlap.
# We can accomplish this by modding our offset by the stride.
offset_delta_mod = offset_delta % x.stride(0)
# Example:
# 0 1 2 3
# 9 10 11 12
# 18 19 20 21
# 27 28 29 30
# x: size=(4, 4), stride=(9, 1), offset=0
# y: size=(4, 4), stride=(9, 1), offset=22 (this would not overlap)
# y: size=(4, 4), stride=(9, 1), offset=23 (this would not overlap)
# y: size=(4, 4), stride=(9, 1), offset=24 (this would overlap)
# y: size=(4, 4), stride=(9, 1), offset=25 (this would overlap)
# If the shifted interval [offset_delta_mod, offset_delta_mod + y.size(1)) ends at or
# before the start of x's next row (i.e. at or before x.stride(0)), we treat the
# two tensors as non-overlapping.
if offset_delta_mod + y.size(1) <= x.stride(0):
return True
else:
return False
return False
def compute_overlapping_inputs(fwd_inputs, aliased_input_indices):
actual_aliased_indices = set()
for j in range(len(aliased_input_indices)):
for i in range(j):
i_ = aliased_input_indices[i]
j_ = aliased_input_indices[j]
if not tensors_definitely_do_not_overlap(fwd_inputs[i_], fwd_inputs[j_]):
actual_aliased_indices.add(i_)
actual_aliased_indices.add(j_)
return actual_aliased_indices
# Note [Handling mutations on an input that aliases other inputs]
# The easiest example to show-case this edge case is here:
#
# def f(a, b):
# a.mul_(2)
# out = a + b
# return out
# b = torch.ones(...)
# a = b.view(-1)
# f(a, b)
#
# In this situation, a and b alias each other, so we need to trace something different!
# In the example above we had a = b.view(-1)
# (which means that `a._base is b`).
#
# We need to ensure that the aliasing relationship between a and b is preserved.
# We do that by detecting the specific situation above (mutating an input that aliases another input),
# and when we do that, we create a synthetic base argument. Then inside of the traced forward,
# we regenerate a and b off of that base.
# The complete example of the transformed function looks like this:
#
# // The traced forward takes in a synthetic base, and regenerates the aliased inputs as views
# // We could consider getting view-replay support here to minimize as_strided_scatter ops in the graph
# def traced_forward(base):
# a = base.as_strided(...)
# b = base.as_strided(...)
# a_updated = a.mul(2)
# base_updated = torch.as_strided_scatter(base, a_updated, ...)
# b_updated = base_updated.as_strided(...)
# out = a_updated + b_updated
# return a_updated, out
#
# def compiled_fn(a, b):
# // we detect that a is the "differentiable base" here
# base = a
# // In other situations, we might do either:
# // (1) a and b are both views off of some larger differentiable base
# // assert a._base is b._base and a._base is not None
# // base = a._base
# // (2) a and b both don't require gradients. Create a base from the storage
# // assert a._base is None and b._base is None
# // base = torch.Tensor(a.storage())
# a_updated, out = traced_forward(base)
# a.copy_(a_updated)
# return out
#
# This function:
# (1) Merges input views into a synthetic base argument, when any of those input views are mutated
# (2) Returns metadata telling the autograd.Function how to modify their arguments properly,
# to respect the new calling convention.
#
# The calling convention is as follows.
# Any inputs that were originally views of one another get yanked, and replaced with a synthetic base.
# The argument list ordering goes [base1, ..., baseN], [arg1, ..., argN],
# Where the ordering of the bases is determined from the ordering of the original view args.
# baseA will come before baseB if the earliest original argument coming from baseA
# showed up earlier in the argument list than the earliest original argument coming from baseB.
#
# Example, given some tensors a, b, c, d
# call site:
# f(a, c.view(-1), b.view(-1), b, c, d)
# Modified argument list:
# c_base comes first because the first c view came earlier in the arg list than the first b view
# a and d still show up in the modified arg list, but b and c don't; they're regenerated from their bases
# b_base = torch.Tensor(b.storage())
# c_base = torch.Tensor(c.storage())
# f(c_base, b_base, a, d)
def merge_view_inputs(
fwd_inputs: List[Any], mutated_input_info: List[InputAliasInfo],
*,
# The autograd case currently has more restrictions than the inference case.
is_inference: bool,
) -> Tuple[List[Any], Optional[List[Union[int, Tuple[int, torch.Tensor]]]]]:
assert len(fwd_inputs) == len(mutated_input_info)
storage_ref_to_idx: Dict[StorageWeakRef, List[int]] = collections.defaultdict(list)
base_args = []
other_args = []
for i, inpt in enumerate(fwd_inputs):
if isinstance(inpt, Tensor):
storage_ref = StorageWeakRef(inpt.untyped_storage())
storage_ref_to_idx[storage_ref].append(i)
else:
other_args.append(inpt)
# Note [Synthetic Base Info Metadata]
# This list contains metadata that tells you what the i'th argument in the inner calling convention should be.
# It's either:
# - another int (corresponding to the index in the argument list of the element from the outer calling convention)
# - idx, view_tensor, where we can generate the new output with view_tensor._view_func(old_args[idx])
# idx corresponds to which synthetic base from the outer calling context to view
inner_calling_convention_meta: Dict[int, Union[int, Tuple[int, torch.Tensor]]] = {}
for aliased_input_indices in storage_ref_to_idx.values():
if len(aliased_input_indices) <= 1 or not any(
# We only care about mutations that affect all aliases,
# so metadata mutations on an input don't require us to do synthetic base handling.
mutated_input_info[inpt_idx].mutates_data
for inpt_idx in aliased_input_indices
):
for curr_idx in aliased_input_indices:
other_args.append(fwd_inputs[curr_idx])
continue
# Here, we attempt to do a more complicated check to detect false aliasing
# (e.g. if all the tensors have the same storage, but don't actually overlap)
# In theory, we could have a large group of tensors that all share storages, where only *some* of them
# have overlapping memory.
# I don't bother with that case for now: here, we only bail out earlier if we detect that **every** pair
# of tensors in the current group that shares a storage is non-overlapping.
aliased_input_indices_no_false_sharing = compute_overlapping_inputs(fwd_inputs, aliased_input_indices)
if len(aliased_input_indices_no_false_sharing) <= 1:
for curr_idx in aliased_input_indices:
other_args.append(fwd_inputs[curr_idx])
continue
# We detected an input that was mutated, AND aliases with another input.
# we need to replace this set of aliased inputs with a single synthetic base.
# For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
# and error out. We can fix them later.
# These checks are transitive, so we don't need to check every pair.
for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:], strict=False):
view1 = fwd_inputs[idx1]
view2 = fwd_inputs[idx2]
# The "inputs that are aliased but have different differentiable bases" case
# is more complicated and hopefully pretty rare. Not currently handled.
if not is_inference:
assert are_differentiable_views(
view1, view2
), "aot_autograd() does not yet handle non-differentiable view input mutations."
# Regenerating views when reinterpreting complex / real tensors seems non-trivial,
# not handling for now
assert same_dtype_views(
view1, view2
), "aot_autograd() does not yet handle input mutations on views with different dtypes."
non_none_bases = [
fwd_inputs[i]._base
for i in aliased_input_indices
if fwd_inputs[i]._base is not None
]
aliases_with_none_bases = [
fwd_inputs[i] for i in aliased_input_indices if fwd_inputs[i]._base is None
]
if len(non_none_bases) == 0:
# Case where none of the aliases have a ._base
# we generate a synthetic base without gradients, and generate views off of it
# We hit this case when we have input tensors to the graph that share a storage,
# but do not have a ._base field.
# Wondering when we hit this case?
# The _base field simply says that autograd knows about the aliasing relationship,
# but sometimes we create tensors which are aliased out of the same storage but guaranteed
# to be disjoint. In these cases, we will skip setting up the _base relationship
# for performance reasons (because the fact that the tensors share the same storage
# is unobservable unless you (1) do naughty things with resize_/as_strided
# or (2) look at the storage--as we are doing here.)
# One particular example of this is optimizer steps on the LSTM module:
# LSTM parameters are packed into a contiguous storage for efficiency reasons when
# calling cuDNN kernels, so when these parameters get passed to the optimizer we will
# find they share the same storage, but do not have _base set since they are all disjoint.
#
# NOTE: There is one case where this is unsafe:
# torch.Tensor(storage) will ALWAYS create a 1D tensor, which is not necessarily
# the same shape as the "actual" base that the tensor came from.
# For the most part this is fine, because we always use as_strided()
# to generate the original aliased inputs again.
# If we were to use view-replay though, this could cause the aliased views
# to have incorrect sizes.
example_idx = aliased_input_indices[0]
example_alias = fwd_inputs[example_idx]
# Note that this function is re-used at both trace time and runtime.
# At trace time, we're under a FakeMode so synthetic_base becomes a FakeTensor.
synthetic_base = torch.empty((0,), dtype=example_alias.dtype, device=example_alias.device)
# We don't actually have a convenient way of going from storage -> tensor,
# So using set_() here (we suffer some minor overhead, but this case is rare).
synthetic_base.set_(example_alias.untyped_storage())
else:
# Case where all of the aliases require gradients, and have the same _base.
synthetic_base = non_none_bases[0]
for other_base in non_none_bases[1:]:
assert (
other_base is synthetic_base
), "aot_autograd() does not yet handle non-differentiable view input mutations."
for alias in aliases_with_none_bases:
assert (
alias is synthetic_base
), "aot_autograd() does not yet handle non-differentiable view input mutations."
base_args.append(synthetic_base)
for curr_view_idx in aliased_input_indices:
curr_view = fwd_inputs[curr_view_idx]
base_idx = len(base_args) - 1
# We store just enough info here so that we can regenerate the view later.
# Regeneration: curr_view._view_func(args[base_idx])
inner_calling_convention_meta[curr_view_idx] = (base_idx, curr_view)
if len(base_args) == 0:
assert len(other_args) == len(fwd_inputs)
# If no synthetic bases are necessary, just return the original inputs.
return fwd_inputs, None
else:
# Otherwise, return:
# (1) The new args according to the updated calling convention: (synthetic_bases, other_args)
# (2) Metadata telling functionalization how to generate the inner argument list given the outer calling convention.
# We post-process it into a list, where meta[i] tells you info about the i'th argument in the inner calling convention.
args_to_functionalization = base_args + other_args
arg_to_old_idx_map = {arg: i for (i, arg) in enumerate(fwd_inputs)}
for i, other_arg in enumerate(other_args):
new_idx = len(base_args) + i
old_idx = arg_to_old_idx_map[other_arg]
inner_calling_convention_meta[old_idx] = new_idx
# post process into a list
post_processed_calling_convention_meta: List[Union[int, Callable]] = [
-1 for _ in range(len(inner_calling_convention_meta))
]
for k, v in inner_calling_convention_meta.items():
post_processed_calling_convention_meta[k] = v
# Quick assert: every argument in the inner calling convention should be accounted for.
for x in post_processed_calling_convention_meta:
assert x != -1
return args_to_functionalization, post_processed_calling_convention_meta
def format_guard_bug_msg(aot_config, expected):
return (
f"At compilation time, graph {aot_config.aot_id} was compiled under the "
f"assumption that {expected}, but at runtime this was not the case. "
"This indicates a guard bug in AOTAutograd or Dynamo, please file a bug to PyTorch."
)
def remove_dupe_metadata(
m: ViewAndMutationMeta,
keep_arg_mask: List[bool],
add_dupe_map: List[int],
) -> ViewAndMutationMeta:
assert len(m.input_info) == len(keep_arg_mask)
# Easy invariant: the first argument should never be a dupe (it will be kept)
assert len(keep_arg_mask) > 0 and keep_arg_mask[0]
# Filter dupe'd mutated inputs out of traced_tangents
num_data_mutations = len([x for x in m.input_info if x.mutates_data])
other_traced_tangents = m.traced_tangents[num_data_mutations:]
inp_traced_tangents = m.traced_tangents[:num_data_mutations]
filtered_inp_traced_tangents = [x for i, x in enumerate(inp_traced_tangents) if keep_arg_mask[m.mutated_inp_indices[i]]]
traced_tangents = filtered_inp_traced_tangents + other_traced_tangents
return ViewAndMutationMeta(
input_info=[x for i, x in enumerate(m.input_info) if keep_arg_mask[i]],
# For outputs that are views of inputs, we store the index of the input that the output
# was generated from. Need to update that index to account for removed dupes.
output_info=[
OutputAliasInfo(
output_type=o.output_type,
raw_type=o.raw_type,
dynamic_dims=o.dynamic_dims,
base_idx=None if o.base_idx is None else add_dupe_map[o.base_idx],
requires_grad=o.requires_grad
)
for o in m.output_info
],
num_intermediate_bases=m.num_intermediate_bases,
keep_input_mutations=m.keep_input_mutations,
traced_tangents=traced_tangents,
# We are guaranteed not to get here, since dupes are not supported today with subclass inputs.
subclass_inp_meta=None,
subclass_fw_graph_out_meta=None,
subclass_tangent_meta=None,
is_train=m.is_train,
)
# Given our ViewAndMutation metadata, this fn constructs a new set of metadata,
# after adding synthetic base arguments to the function.
# Most of the work in this fn is slogging through all of the metadata corresponding to inputs,
# and updating it with our synthetic base calling convention.
#
# When config.debug_assert is set, we automatically regenerate the metadata
# and compare it to this output for sanity.
#
# In addition to the updated metadata, also return the list of input indices
# that will need to be updated in the synthetic base epilogue
def create_synthetic_base_metadata(
m: ViewAndMutationMeta,
# Maps each outer argument idx to its inner idx (or, if this outer arg is generated from a
# synthetic base, you get a tuple of (i, TensorMeta), telling you the base tensor idx, and view metadata)
synthetic_base_info: List[Union[int, Tuple[int, torch.Tensor]]],
outer_args: List[Any],
inner_args: List[Any],
) -> Tuple[ViewAndMutationMeta, List[int]]:
S_Outer = NewType('S_Outer', int)
S_Inner = NewType('S_Inner', int)
synthetic_base_to_indices: Dict[S_Inner, List[S_Outer]] = {}
for inner_idx in range(len(inner_args)):
outer_aliased_indices_of_current_base_arg = [
outer_idx for outer_idx, inner_idx_or_tuple in enumerate(synthetic_base_info)
if (isinstance(inner_idx_or_tuple, int) and inner_idx_or_tuple == inner_idx)
or (isinstance(inner_idx_or_tuple, tuple) and inner_idx_or_tuple[0] == inner_idx)
]
synthetic_base_to_indices[inner_idx] = outer_aliased_indices_of_current_base_arg
# given the requires_grad info on mutated inputs,
# generate the requires_grad info on those same mutated inputs, but after constructing synthetic bases.
input_infos = []
for outer_indices in synthetic_base_to_indices.values():
# leaf-ness should be all-or-nothing for aliased tensors.
# (aka if "a" and "b" are views, then a.is_leaf == b.is_leaf)
any_leaf = any(m.input_info[x].is_leaf for x in outer_indices)
all_leaf = all(m.input_info[x].is_leaf for x in outer_indices)
assert any_leaf == all_leaf
inpt_info = InputAliasInfo(
# If len(outer_indices) > 1, then this input is a synthetic base.
# The invariant is that to the rest of aot autograd, synthetic bases only show up if
# one of their aliases gets a data mutation. And if any of their aliases get metadata
# mutations, they will be hidden from the rest of aot autograd.
mutates_data=True if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_data,
mutates_metadata=False if len(outer_indices) > 1 else m.input_info[outer_indices[0]].mutates_metadata,
mutations_hidden_from_autograd=all(m.input_info[x].mutations_hidden_from_autograd for x in outer_indices),
is_leaf=any_leaf,
requires_grad=any(m.input_info[x].requires_grad for x in outer_indices)
)
input_infos.append(inpt_info)
# Find any inputs that fulfill the following criteria:
# (1) They are part of a synthetic base (because they alias another input,
# and at least one input experiences a data mutation)
# (2) They experience a metadata mutation
outer_aliased_arg_idx_with_metadata_mutations = [
outer_idx for outer_idx, inpt_info in enumerate(m.input_info)
if inpt_info.mutates_metadata and not isinstance(synthetic_base_info[outer_idx], int)
]
# grab the original requires grad info on the outputs, except the ones from the mutated inputs
input_metadata_output_info = [
OutputAliasInfo(
output_type=OutputType.alias_of_input,
raw_type=FunctionalTensor,
dynamic_dims={i for i, s in enumerate(outer_args[outer_idx].shape) if not is_concrete_int(s)},
base_idx=synthetic_base_info[outer_idx][0],
requires_grad=outer_args[outer_idx].requires_grad
) for outer_idx in outer_aliased_arg_idx_with_metadata_mutations]
existing_output_infos = [
OutputAliasInfo(
output_type=o.output_type,
raw_type=o.raw_type,
dynamic_dims=o.dynamic_dims,
# Map the input idx pre-synthetic-bases to the new idx post-synthetic-bases
base_idx=None if o.base_idx is None
else synthetic_base_info[o.base_idx]
if isinstance(synthetic_base_info[o.base_idx], int)
else synthetic_base_info[o.base_idx][0],
requires_grad=o.requires_grad
)
for o in m.output_info]
inner_mutated_tangents = [
x
for inner_idx, x in enumerate(inner_args)
if input_infos[inner_idx].mutates_data and input_infos[inner_idx].requires_grad
]
output_info = existing_output_infos + input_metadata_output_info
# Regenerate traced tangents to include mutated inputs including synthetic bases
traced_tangents = inner_mutated_tangents + m.traced_tangents[len(inner_mutated_tangents):]
return ViewAndMutationMeta(
input_info=input_infos,
output_info=output_info,
num_intermediate_bases=m.num_intermediate_bases,
keep_input_mutations=m.keep_input_mutations,
traced_tangents=traced_tangents,
# We are guaranteed not to get here, since synthetic_base codepaths are not supported today with subclass inputs.
subclass_inp_meta=None,
subclass_fw_graph_out_meta=None,
subclass_tangent_meta=None,
is_train=m.is_train,
), outer_aliased_arg_idx_with_metadata_mutations
# MOTIVATION:
#
# When tracing functions for future execution, one must be careful not to pass
# in the same input tensor multiple times (e.g., f(x, x)), as this can result
# in graphs that are ONLY valid if you later pass a new tensor in exactly the
# same way (e.g., f(y, y)). (NB: we really mean duplicates; passing two distinct
# tensors that alias each other is a different situation that is covered by
# aot_dispatch_deduplicated_autograd). Here are two examples:
#
# (1) Suppose you have a function:
#
# def f(x, y):
# return x + y
#
# If you make_fx(f)(x, x), you will trace out:
#
# def f(x, y):
# return y + y
#
# Oops!
#
# (2) For most tensors x and y, you can compute f's gradient with respect to
# these two inputs by saying torch.autograd.grad(f(x, y), (x, y)). However,
# if x is y, you will trace out a program that gets incorrect gradients:
#
# >>> x = torch.randn(1, requires_grad=True)
# >>> torch.autograd.grad(x + x, (x, x))
# (tensor([2.]), tensor([2.]))
#
# In other words, the gradient is double-counted. Deduplicating the arguments
# gives you an appropriate gradient:
#
# >>> y = torch.randn(1, requires_grad=True)
# >>> torch.autograd.grad(x + y, (x, y))
# (tensor([1.]), tensor([1.]))
#
# HOW TO DEDUPLICATE:
#
# There are a few strategies, in order of preference:
#
# 1. For every duplicate argument to the function, detach it into
# a separate leaf tensor, so that it is no longer duplicated.
#
# PRO: The resulting compiled graph works for any configuration
# of duplicated arguments.
#
# CON: It does not (naively) work if you mutate the metadata of inputs:
#
# def f(x, y):
# x.transpose_(0, 1)
# y.transpose_(0, 2)
#
# x = torch.randn(2, 3, 4)
# f(x, x)
#
# The ordering of the transposes inside f dictates whether or not
# you get [4, 2, 3] or [3, 4, 2]. This means that you cannot precompute
# what metadata mutations should get applied to each input; you need to
# assume they aren't duplicates (what we do today) or preserve
# the original metadata mutations exactly in order, so that they work
# for any duplicate configuration.
#
# CON: It does not (naively) work if you mutate the data of inputs.
# In particular, leaf tensors that require grad cannot be mutated,
# this makes it impossible to differentiate with respect to the original
# base.
#
# 2. For every duplicate argument to the function, remove it, so it is
# no longer part of the "true" signature:
#
# PRO: Implemented naively, it still works for metadata/data mutation.
#
# CON: The resulting compiled graph is duplicate-specialized: it only
# works if future calls duplicate arguments in exactly the same way.
# Horribly, Dynamo doesn't guard on this at the moment. But even if
# it did, you could still end up recompiling once for each distinct duplication pattern.
#
# Our strategy is to do (1) if we can, and do (2) otherwise, erroring if
# Dynamo's guards are not enough. In practice, this seems to cover
# everything.
#
def aot_wrapper_dedupe(
flat_fn,
flat_args: List[Tensor],
aot_config: AOTConfig,
*,
compiler_fn,
fw_metadata,
):
# Use information about whether or not flat_fn mutates its arguments
# to handle dupe args
# Strategy 1: For any input that is not mutated, we can leafify it if we
# need to remove a duplicate.
leaf_flat_args = []
args_set = set()
ok = True
for i, a in enumerate(flat_args):
if not isinstance(a, torch.Tensor):
leaf_flat_args.append(a)
elif a not in args_set:
args_set.add(a)
leaf_flat_args.append(a)
elif not fw_metadata.input_info[i].mutates_data and not fw_metadata.input_info[i].mutates_metadata:
leaf_flat_args.append(a.detach().requires_grad_(a.requires_grad))
else:
ok = False
break
if ok:
return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
if requires_subclass_dispatch(leaf_flat_args, fw_metadata):
raise RuntimeError("""\
Encountered duplicate inputs that are mutated in the graph, but at least one input/output
to the graph is a tensor subclass. This is not supported today. You can try to
remove the aliasing yourself as a workaround, or otherwise file an issue on github.""")
# export path: ban duplicate inputs for now, add later if requested.
if aot_config.is_export:
raise RuntimeError(f"""\
Encountered duplicated inputs that are mutated in the graph you are trying to export.
This functionality is currently not supported. If needed, please file a github issue.
fw_metadata={str(fw_metadata)}
""")
# Strategy 2: Duplicate specialize.
#
# In Haskell types, suppose you have:
#
# add_dupe_args :: DedupedArgs -> Args
# remove_dupe_args :: Args -> DedupedArgs
#
# compiler_fn
# :: (DedupedArgs -> R) -> DedupedArgs -> AOTConfig -> (DedupedArgs -> R)
    #   deduped_compiler_fn
# :: (Args -> R) -> Args -> AOTConfig -> (Args -> R)
#
# Then the code below can be written in point-free style as:
#
# deduped_compiler_fn f a c =
# compiler_fn (f . add_dupe_args) (remove_dupe_args a) c . remove_dupe_args
#
# Suppose you have:
#
# [a, b, a, c]
#
# We want:
#
# remove_dupe_args([a, b, a, c]) == [a, b, c]
# add_dupe_args([a, b, c]) == [a, b, a, c]
#
# This is done via (respectively):
#
# seen_args = {a: 0, b: 1, c: 2}
# enumerate(add_dupe_map) = [ # how to get args from the deduped list
# (0, 0),
# (1, 1),
# (2, 0),
# (3, 2),
# ]
# keep_arg_mask = [True, True, False, True]
seen_args = {}
keep_arg_mask = []
# Implicitly map duped arg position (list index) to de-duped arg position
add_dupe_map: List[int] = []
duped_arg_len = len(flat_args)
j = 0 # index into deduped_flat_args
for t in flat_args:
if isinstance(t, torch.Tensor):
if t in seen_args:
keep_arg_mask.append(False)
add_dupe_map.append(seen_args[t])
continue
seen_args[t] = j
keep_arg_mask.append(True)
add_dupe_map.append(j)
j += 1
assert len(add_dupe_map) == duped_arg_len, (
f"Expects add_dupe_map to have length {duped_arg_len} but got {len(add_dupe_map)}"
)
# NB: Hot path, avoid set lookups here
# TODO: Can avoid the zip here too, probably
def remove_dupe_args(args):
return [t for t, keep in zip(args, keep_arg_mask) if keep]
def add_dupe_args(args):
return [args[add_dupe_map[i]] for i in range(duped_arg_len)]
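    # Illustrative sketch of the two helpers above (assuming a, b, c are distinct
    # tensors; shown for exposition only, not executed here):
    # >>> remove_dupe_args([a, b, a, c])
    # [a, b, c]
    # >>> add_dupe_args([a, b, c])
    # [a, b, a, c]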
deduped_flat_args = remove_dupe_args(flat_args)
# Update our input metadata to remove duped input metadata.
updated_fw_metadata = remove_dupe_metadata(fw_metadata, keep_arg_mask, add_dupe_map)
tracing_context = TracingContext.get()
if tracing_context and aot_config.aot_autograd_arg_pos_to_source:
# TODO(voz): This structure is 1:1, we could consider an alternate structure like
# kept_pos:[dupe_arg_pos], however, add_dupe_map is 1:1 so we would need a new structure there,
# which feels like needless complexity for a tiny bit of efficiency at this point.
for dupe_arg_pos, (kept_pos, keep_arg) in enumerate(zip(add_dupe_map, keep_arg_mask)):
if not keep_arg:
dupe_arg_source = aot_config.aot_autograd_arg_pos_to_source[dupe_arg_pos]
kept_arg_source = aot_config.aot_autograd_arg_pos_to_source[kept_pos]
tracing_context.guards_context.aotautograd_guards.append(DuplicateInputs(kept_arg_source, dupe_arg_source))
@wraps(flat_fn)
def wrapped_flat_fn(*args):
return flat_fn(*add_dupe_args(args))
if config.debug_assert:
ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
wrapped_flat_fn,
keep_input_mutations=fw_metadata.keep_input_mutations,
is_train=fw_metadata.is_train,
)(*deduped_flat_args)
assert ref_fw_metadata == updated_fw_metadata, \
f'ref_metadata={str(ref_fw_metadata)}, actual_metadata={str(updated_fw_metadata)}'
compiled_fn = compiler_fn(wrapped_flat_fn, deduped_flat_args, aot_config, fw_metadata=updated_fw_metadata)
if not hasattr(compiled_fn, "_boxed_call"):
compiled_fn = make_boxed_func(compiled_fn)
@wraps(compiled_fn)
def wrapped_compiled_fn(args):
deduped_args = remove_dupe_args(args)
args.clear()
return compiled_fn(deduped_args)
wrapped_compiled_fn._boxed_call = True
# This can be uncommented when we properly guard for duplicates,
# but right now we must not do it.
# if not config.debug_assert:
# return wrapped_compiled_fn
@wraps(wrapped_compiled_fn)
def debugged_compiled_fn(args):
# Test that the computed remove/add arg functions are an inverse
new_args = add_dupe_args(remove_dupe_args(args))
seen = {}
for i, (x, y) in enumerate(zip(new_args, args)):
seen[y] = None
assert x is y, format_guard_bug_msg(
aot_config,
f"{describe_input(i, aot_config)} would be a duplicate of "
f"{describe_input(add_dupe_map[i], aot_config)}",
)
# This is only an error if there is metadata mutation on both of
# the duped arguments; in this case, we need to know what order
# the metadata mutation applies in. You'll get the correct result
# otherwise, because a graph that assumes distinct inputs works if
# you dupe the inputs (the gradient contributions from each input
# will get summed up appropriately.)
#
# TODO: work out how to setup this assert correctly
"""
assert len(seen) == unique_args, format_guard_bug_msg(aot_config,
f"there would be {unique_args} distinct arguments"
)
"""
return wrapped_compiled_fn(args)
debugged_compiled_fn._boxed_call = True
return debugged_compiled_fn
# This layer handles the situation where you have two inputs that alias each other,
# and one of the inputs is mutated.
# We need to take special care to ensure that the mutation is applied to the other aliases in the graph.
#
# pre-condition: aot_wrapper_dedupe has already run.
# (This function will in theory work if there are duplicate args.
# However, the synthetic base code path is a bit sub-optimal, and running with dupe'd inputs
# would cause us to hit that path more frequently).
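# As an illustrative sketch (hypothetical user code, not part of this file):
#
#   def f(a, b):
#       a.mul_(2)
#       return a + b
#
#   base = torch.ones(4)
#   f(base[0:2], base[2:4])  # a and b are views of the same storage
#
# Because `a` is mutated and shares storage with `b`, we trace f on a single
# synthetic base and regenerate `a` and `b` as views of that base.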
def aot_wrapper_synthetic_base(
flat_fn,
flat_args: List[Tensor],
aot_config: AOTConfig,
*,
fw_metadata: ViewAndMutationMeta,
# Currently, the only reason we need to plumb this bool is because
# the synthetic base code prohibits more cases in the autograd case than the inference case.
needs_autograd: bool,
compiler_fn,
):
is_inference = not needs_autograd
flat_args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
flat_args, fw_metadata.input_info, is_inference=is_inference,
)
# Happy path: we don't need synthetic bases
if synthetic_base_info is None:
return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
if requires_subclass_dispatch(flat_args, fw_metadata):
raise RuntimeError("""\
Encountered aliased inputs that are mutated in the graph, but at least one input/output
to the graph is a tensor subclass. This is not supported today. You can try to
remove the aliasing yourself as a workaround, or otherwise file an issue on github.""")
    # export path: ban synthetic bases for now, add later if requested.
    if aot_config.is_export:
raise RuntimeError(f"""\
Encountered aliased inputs that are mutated in the graph you are trying to export.
This functionality is currently not supported. If needed, please file a github issue.
synthetic_base_info={str(synthetic_base_info)}
fw_metadata={str(fw_metadata)}
""")
assert len(fw_metadata.input_info) == len(synthetic_base_info)
# Update our forward metadata to take synthetic bases into account
fw_metadata_updated, aliased_arg_idx_with_metadata_mutations = \
create_synthetic_base_metadata(fw_metadata, synthetic_base_info, flat_args, flat_args_with_synthetic_bases)
num_aliased_args_with_metadata_mutations = len(aliased_arg_idx_with_metadata_mutations)
def unpack_synthetic_bases(primals: List[Any]) -> List[Any]:
f_args_inner = []
for inner_idx_or_tuple in synthetic_base_info:
if isinstance(inner_idx_or_tuple, int):
f_args_inner.append(primals[inner_idx_or_tuple])
else:
inner_base_idx, view_tensor = inner_idx_or_tuple
base = primals[inner_base_idx]
view_arg = gen_alias_from_base(
base, view_tensor, view_tensor.requires_grad
)
f_args_inner.append(view_arg)
return f_args_inner
@wraps(flat_fn)
def wrapped_flat_fn(*args):
unpacked_args = unpack_synthetic_bases(args)
        # This is a bit subtle. The goal of this entire function (aot_wrapper_synthetic_base)
# is to relieve the downstream logic from having to reason about mutations on inputs that alias
# each other, by replacing aliased inputs with a synthetic base.
# One area where this breaks down a bit however is if one of those aliased inputs
# experienced a metadata mutation.
# We are now obligated to reapply the metadata mutation directly to the user's input;
# it isn't enough to apply mutations back to the synthetic base in the downstream logic.
#
# The way we handle this is by pretending that those aliased inputs that experience metadata mutations
# are additional outputs in the user's forward function.
# The downstream logic will just treat these as "user outputs that alias inputs".
# However, we will manually grab them at runtime here, use them to reapply the metadata mutation
# to the user inputs, and not return them to the user.
aliased_args_with_metadata_mutations = [
x for i, x in enumerate(unpacked_args) if i in aliased_arg_idx_with_metadata_mutations]
if len(aliased_args_with_metadata_mutations) > 0:
return *(flat_fn(*unpacked_args)), *aliased_args_with_metadata_mutations
else:
return flat_fn(*unpacked_args)
if config.debug_assert:
ref_fw_metadata = run_functionalized_fw_and_collect_metadata(
wrapped_flat_fn,
keep_input_mutations=fw_metadata.keep_input_mutations,
is_train=fw_metadata.is_train,
)(*flat_args_with_synthetic_bases)
assert ref_fw_metadata == fw_metadata_updated, (
f'ref_metadata={pprint.pformat(partial_asdict(ref_fw_metadata))}, '
f'\nactual_metadata={pprint.pformat(partial_asdict(fw_metadata_updated))}'
)
compiled_fn = compiler_fn(wrapped_flat_fn, flat_args_with_synthetic_bases, aot_config, fw_metadata=fw_metadata_updated)
if not hasattr(compiled_fn, "_boxed_call"):
compiled_fn = make_boxed_func(compiled_fn)
@wraps(compiled_fn)
def wrapped_compiled_fn(args):
args_with_synthetic_bases, synthetic_base_info = merge_view_inputs(
args, fw_metadata.input_info, is_inference=is_inference
)
assert synthetic_base_info is not None
aliased_args_w_metadata_mutations = [args[i] for i in aliased_arg_idx_with_metadata_mutations]
args.clear()
outs = compiled_fn(args_with_synthetic_bases)
if num_aliased_args_with_metadata_mutations > 0:
# This code does not handle **all** input metadata mutations.
# Instead, it only handles metadata mutations on inputs that were converted into synthetic bases
# (which only happens if at least one aliased input experienced a data mutation).
# e.g:
# def f(a, b):
# a.mul_(2)
            #    b.t_()
# f(x.view(2, 2), x.view(2, 2))
mutated_metadata_inps = outs[-num_aliased_args_with_metadata_mutations:]
user_outs = outs[:-num_aliased_args_with_metadata_mutations]
for inp, mutated_inp in zip(aliased_args_w_metadata_mutations, mutated_metadata_inps):
inp.as_strided_(mutated_inp.size(), mutated_inp.stride(), mutated_inp.storage_offset())
return user_outs
return outs
return wrapped_compiled_fn
def describe_input(i, aot_config):
if i < aot_config.num_params_buffers:
return f"parameter/buffer {i}"
else:
return f"input {i - aot_config.num_params_buffers}"
# The wrapper created by this function handles all of the runtime aliasing and mutation "epilogue" logic
# that needs to run after the compiled function.
#
# This function accepts a trace_joint flag, indicating whether we're generating the runtime
# epilogue for a forward-only inference graph or for an autograd.Function.apply function.
# This is because there are some minor differences in how we treat these cases at runtime:
# - resize_() is currently handled in the inference case, but not fully handled in the autograd case.
# - the autograd case inserts TensorAlias wrapper objects for outputs that alias inputs
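# As an illustrative sketch of what the epilogue has to handle (hypothetical user code,
# not part of this file):
#
#   def f(x):
#       x.add_(1)          # data mutation: re-applied to the real input after the compiled fn runs
#       return x.view(-1)  # output aliasing an input: regenerated outside of the compiled graph
#
# The compiled graph returns the updated input and a plain output; the wrapper below
# copies the update back into the user's `x` and rebuilds the aliasing view from it.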
def create_runtime_wrapper(
compiled_fn,
*,
runtime_metadata: ViewAndMutationMeta,
indices_of_inps_to_detach: List[int],
trace_joint: bool,
keep_input_mutations: bool,
disable_amp: bool
):
if not hasattr(compiled_fn, "_boxed_call"):
compiled_fn = make_boxed_func(compiled_fn)
def runtime_wrapper(*args):
if trace_joint:
args_ = list(args)
# See Note [Detaching inputs that never need gradients]
for idx in indices_of_inps_to_detach:
if isinstance(args_[idx], torch.Tensor):
args_[idx] = args_[idx].detach()
with torch.autograd._force_original_view_tracking(True):
all_outs = call_func_with_args(
compiled_fn,
args_,
disable_amp=disable_amp,
)
else:
# When we have an inference graph, we run with torch.no_grad.
# It's possible to get an inference graph with inputs that require grad,
# in which case we want to make sure autograd is disabled
# (since e.g., inductor will generate aten.addmm.out calls which autograd will complain on)
with torch.no_grad():
all_outs = call_func_with_args(
compiled_fn,
args,
disable_amp=disable_amp,
)
num_mutated_inps = runtime_metadata.num_mutated_inputs
num_metadata_mutated_inps = runtime_metadata.num_mutated_metadata_inputs
num_intermediate_bases = runtime_metadata.num_intermediate_bases
if keep_input_mutations:
assert (
len(all_outs)
== num_metadata_mutated_inps + runtime_metadata.num_outputs + num_intermediate_bases
)
assert (
len(runtime_metadata.mutated_inp_runtime_indices) == num_metadata_mutated_inps
)
else:
assert (
len(all_outs)
== num_mutated_inps + runtime_metadata.num_outputs + num_intermediate_bases
)
assert (
len(runtime_metadata.mutated_inp_runtime_indices) == num_mutated_inps
)
# Step 3: After running the compiled fw, apply updates to mutated inputs
num_mutations_to_apply = len(runtime_metadata.mutated_inp_runtime_indices)
if num_mutations_to_apply > 0:
updated_inputs = all_outs[: num_mutations_to_apply]
fw_outs = all_outs[num_mutations_to_apply :]
for i, inpt_idx in enumerate(
runtime_metadata.mutated_inp_runtime_indices
):
meta = runtime_metadata.input_info[inpt_idx]
if not meta.mutates_data and not meta.mutates_metadata:
continue
original_inpt = args[inpt_idx]
updated_inpt = updated_inputs[i]
if meta.mutates_metadata and not meta.mutates_data:
if trace_joint:
assert isinstance(updated_inpt, TensorAlias)
updated_inpt = updated_inpt.alias
# We need to grab the size/stride/storage_offset from the compiled forward,
# and use that to mutate the metadata of the input
original_inpt.as_strided_(
updated_inpt.size(),
updated_inpt.stride(),
updated_inpt.storage_offset(),
)
else:
if meta.mutates_data and meta.mutates_metadata:
original_inpt.as_strided_(
updated_inpt.size(),
updated_inpt.stride(),
updated_inpt.storage_offset(),
)
else:
assert meta.mutates_data
if meta.is_leaf and original_inpt.requires_grad:
# We can hit this situation in this case:
# def f(x):
# x.detach().mul_(2)
# return x + 1
# AOTAutograd will see a mutation in the above case, and try to
# apply a copy_() here, in the epilogue.
# But if x required gradients, and is a leaf, then autograd
# will yell at us for trying to mutate it.
# However, it's only possible to end up in this scenario (like the above)
# if all of the mutations to the leaf input were non-autograd-tracking mutations
# (aka mutations under no_grad(), or on detached views).
# In that case, we fully want to hide the mutation from autograd, so detaching is ok.
original_inpt.detach().copy_(updated_inpt)
else:
original_inpt.copy_(updated_inpt)
else:
fw_outs = all_outs
# Step 4: Manually regenerate any outputs that are aliased to inputs, instead of
# compiling them.
if runtime_metadata.num_outputs_aliased > 0:
# The compiled forward also returned intermediate bases. We don't want to return them to the user.
if runtime_metadata.num_intermediate_bases > 0:
fw_outs_no_intermediate_bases = fw_outs[
: -runtime_metadata.num_intermediate_bases
]
intermediate_bases = fw_outs[-runtime_metadata.num_intermediate_bases:]
else:
fw_outs_no_intermediate_bases = fw_outs
intermediate_bases = []
assert len(fw_outs_no_intermediate_bases) == len(runtime_metadata.output_info)
fw_outs_including_aliases = []
for i, (o, info) in enumerate(zip(
fw_outs_no_intermediate_bases, runtime_metadata.output_info
)):
if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]:
fw_outs_including_aliases.append(o)
continue
if trace_joint:
assert isinstance(o, TensorAlias)
o_ = o.alias
else:
o_ = o
o_grad = runtime_metadata.output_info[i].requires_grad
if info.output_type == OutputType.alias_of_input:
aliased_base_tensor = args[info.base_idx]
regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad)
fw_outs_including_aliases.append(regenerated_out)
continue
elif info.output_type == OutputType.is_input:
aliased_base_tensor = args[info.base_idx]
regenerated_out = aliased_base_tensor
fw_outs_including_aliases.append(regenerated_out)
continue
elif info.output_type == OutputType.alias_of_intermediate:
base_tensor_list = intermediate_bases
elif info.output_type == OutputType.alias_of_intermediate_save_as_output:
base_tensor_list = intermediate_bases
else:
assert info.output_type == OutputType.alias_of_intermediate_base_is_user_output
base_tensor_list = fw_outs_no_intermediate_bases
aliased_base_tensor = base_tensor_list[info.base_idx]
# TODO: handle the custom autograd function case here.
# We need a way to check whether a tensor came from a custom autograd fn from python,
# AND a way to replay that custom view fn.
regenerated_out = gen_alias_from_base(aliased_base_tensor, o_, o_grad)
fw_outs_including_aliases.append(regenerated_out)
ret_outs = fw_outs_including_aliases
else:
ret_outs = fw_outs
if runtime_metadata.dynamic_outputs:
for t, o in zip(ret_outs, runtime_metadata.output_info):
if o.dynamic_dims is None:
continue
if hasattr(t, '_dynamo_weak_dynamic_indices'):
t._dynamo_weak_dynamic_indices |= o.dynamic_dims
else:
t._dynamo_weak_dynamic_indices = o.dynamic_dims.copy()
if runtime_metadata.grad_enabled_mutation is not None:
torch.set_grad_enabled(runtime_metadata.grad_enabled_mutation)
return ret_outs
return runtime_wrapper
# Calling convention: If we are running functionalized RNG, then outs consists
# of (user_outs, rng_offset)
def functionalized_rng_runtime_epilogue(metadata, outs, return_new_outs=True):
if metadata.is_rng_op_functionalized:
assert metadata.num_outputs_rng_offset == 1
new_rng_offset = outs[-1]
CUDARngStateHelper.set_new_offset(new_rng_offset)
if return_new_outs:
user_outs = outs[:-1]
return user_outs
else:
return None
return outs
def create_functionalized_rng_ops_wrapper(func, args, trace_joint=True):
# Functionalization of rng ops changes the calling convention of the joint graph.
    # It goes from (primals, tangents) to (primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset).
# At runtime, we pass on the current seed and offset. This is hidden from
# the user.
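    # Illustrative sketch of the resulting convention (see traced_joint/traced_forward below):
    # a joint fn called as joint(primals, tangents) is traced as
    #   traced_joint(primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset)
    # and each traced output tuple gains a trailing philox offset, e.g.
    #   (*fwd_outputs, new_fwd_rng_offset).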
fake_mode = detect_fake_mode()
if fake_mode is None:
fake_mode = nullcontext()
def override_get_rng_state(device: Union[int, str, torch.device] = 'cuda'):
out = PhiloxStateTracker.get_state_as_tensor()
return out
def override_set_rng_state(x, device: Union[int, str, torch.device] = 'cuda'):
PhiloxStateTracker.set_state_from_tensor(x)
def append_rng_offsets(args):
if trace_joint:
# args signature before: Tuple(fwd_outputs), Tuple(bwd_outputs)
# args signature after: Tuple(fwd_outputs, new_fwd_rng_offset), Tuple(bwd_offset, new_bwd_rng_offset)
return ((*args[0], PhiloxStateTracker.get_updated_fwd_offset()),
(*args[1], PhiloxStateTracker.get_updated_bwd_offset()))
else:
# args signature before: Tuple(fwd_outputs)
# args signature after: Tuple(fwd_outputs, new_fwd_rng_offset)
return (*args, PhiloxStateTracker.get_updated_fwd_offset())
def traced_joint(primals, tangents, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset):
with patch("torch.cuda.get_rng_state", override_get_rng_state), patch("torch.cuda.set_rng_state", override_set_rng_state):
return append_rng_offsets(func(primals, tangents))
def traced_forward(*primals_fwd_seed_fwd_base_offset):
# The signature is (*primals, seed, offset)
with patch("torch.cuda.get_rng_state", override_get_rng_state), patch("torch.cuda.set_rng_state", override_set_rng_state):
return append_rng_offsets(func(*primals_fwd_seed_fwd_base_offset[:-2]))
if trace_joint:
# Get the current seed and offset to setup tracing.
fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
bwd_seed, bwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
PhiloxStateTracker.record_state(bwd_seed, bwd_base_offset, "backward")
return traced_joint, (*args, fwd_seed, fwd_base_offset, bwd_seed, bwd_base_offset)
else:
# Get the current seed and offset to setup tracing.
fwd_seed, fwd_base_offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
PhiloxStateTracker.record_state(fwd_seed, fwd_base_offset, "forward")
return traced_forward, (*args, fwd_seed, fwd_base_offset)
# Output structure:
# - List[Tensor] if tracing an inference graph
# - Tuple[List[Tensor], List[Tensor]] if tracing a joint graph.
# This function effectively concats each inner list of subclass tensors
# into a (potentially longer) list of inner tensors.
#
# This function takes in a pytree of arguments and unwraps any tensor subclasses.
# Annoyingly, we can't use pytrees to perform the unwrapping, because unwrapping returns
# a list of tensors that we would then need to concat together.
# Instead, we specialize the logic for the inference vs. joint graph case.
# NOTE: this function is hot, since we unwrap tensor subclass inputs at runtime
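# Illustrative sketch (assuming a hypothetical traceable wrapper subclass `two` whose
# __tensor_flatten__ returns (["a", "b"], ctx); shown for exposition only):
# >>> unwrap_tensor_subclasses([two, plain], is_joint_structure=False)
# [two.a, two.b, plain]
# >>> unwrap_tensor_subclasses(([two], [tangent]), is_joint_structure=True)
# ([two.a, two.b], [tangent])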
def unwrap_tensor_subclasses(wrapped_args, *, is_joint_structure: bool):
def concat_inner_tensors_from_subclasses(xs):
xs_inner = []
for x in xs:
if isinstance(x, torch.Tensor) and is_traceable_wrapper_subclass(x):
attrs, _ = x.__tensor_flatten__()
xs_inner += [getattr(x, attr) for attr in attrs]
else:
xs_inner += [x]
return xs_inner
if is_joint_structure:
assert isinstance(wrapped_args, tuple) and len(wrapped_args) == 2
assert isinstance(wrapped_args[0], (tuple, list)) and isinstance(wrapped_args[1], (tuple, list))
unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args[0])
unwrapped_args_tangents = concat_inner_tensors_from_subclasses(wrapped_args[1])
unwrapped_args = (unwrapped_args_fw, unwrapped_args_tangents)
else:
assert isinstance(wrapped_args, (list, tuple))
unwrapped_args_fw = concat_inner_tensors_from_subclasses(wrapped_args)
unwrapped_args = unwrapped_args_fw
return unwrapped_args
# Turns a flattened list of tensor arguments into (maybe) subclass tensors.
# This function is used both at trace time and runtime, so we have an is_runtime flag telling us which context we're in.
def wrap_tensor_subclasses(
unwrapped_args: List[Any],
*,
subclass_metas: List[Union[int, SubclassCreationMeta]],
num_fw_outs_saved_for_bw: Optional[int] = None,
is_runtime: bool = False,
) -> List[Any]:
wrapped_args = []
num_args_tallied = 0
for subclass_meta in subclass_metas:
if isinstance(subclass_meta, int):
wrapped_args.append(unwrapped_args[subclass_meta])
num_args_tallied += 1
else:
assert isinstance(subclass_meta, SubclassCreationMeta)
wrapped_args.append(subclass_meta.creation_fn(unwrapped_args, is_runtime=is_runtime))
num_args_tallied += subclass_meta.arg_count
# Note: [Partitioner handling for Subclasses, Part 2]
# At the beginning of AOTAutograd, we collect metadata on the inputs and outputs of the user fw,
# to figure out which inputs/outputs are subclasses, and how to reconstruct the subclasses after flattening them.
#
# When this function is called at runtime in the forward,
# we have been passed a list of (flattened) dense-tensor fw-outs, and need to reconstruct any subclass fw outs.
#
# One reasonable question that you should ask: when should the dense_tensor -> subclass_tensor wrapping happen?
# Answer: we do it **inside of our compiled autograd.Function**.
# This seems like morally the right place: autograd happens above subclass desugaring,
# so autograd should see actual tensor subclasses at runtime, and not flattened dense tensors.
#
# This causes a tricky interaction though: when we run the min-cut partitioner to divvy up the joint graph
# into a forward and backward graph, we end up with some activations that show up as extra outputs
# in the compiled forward graph, that are **not** user outputs.
# These activations are not visible to the user, and so there's no need for us to wrap them back into subclasses.
#
# On top of that, when we first computed subclass metadata (in `run_functionalized_fw_and_collect_metadata`),
# we computed subclass metadata on every forward output, but this did **not** include activations
# created by the partitioner.
    # As a result, `unwrapped_args` here will correspond to (*unwrapped_user_fw_outs, *activations),
    # but `subclass_metas` will only correspond to subclass metadata on `user_fw_outs`.
# We then need to make sure that we return (*wrapped_user_fw_outs, *activations).
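    # Illustrative sketch (hypothetical numbers): with
    #   subclass_metas = [0, <SubclassCreationMeta with arg_count=2>]
    #   unwrapped_args = [dense_out, inner_a, inner_b, act0, act1]
    #   num_fw_outs_saved_for_bw = 2
    # we tally 3 wrapped args, rebuild the subclass output from (inner_a, inner_b),
    # and return (dense_out, subclass_out, act0, act1).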
if num_fw_outs_saved_for_bw is not None:
assert len(unwrapped_args) == num_args_tallied + num_fw_outs_saved_for_bw
activations = unwrapped_args[num_args_tallied:]
if isinstance(wrapped_args, tuple) and isinstance(activations, tuple):
return wrapped_args + activations
return tuple(list(wrapped_args) + list(activations))
else:
assert len(unwrapped_args) == num_args_tallied
return tuple(wrapped_args)
# Given a bunch of "dense" tensor arguments, this function (potentially) wraps them into tensor subclasses.
# This function carefully handles the inference vs. joint cases:
# - when is_joint_structure is True, args is (primals, tangents)
# - when is_joint_structure is False, args is [*primals]
def wrap_tensor_subclasses_maybe_joint(unwrapped_args, *, is_joint_structure: bool, meta: ViewAndMutationMeta) -> List[Any]:
    # Since this function is re-used for both inference and joint graphs, handle both calling conventions here.
if is_joint_structure:
assert isinstance(unwrapped_args, tuple) and len(unwrapped_args) == 2
assert isinstance(unwrapped_args[0], (tuple, list)) and isinstance(unwrapped_args[1], (tuple, list))
primals, tangents = unwrapped_args[0], unwrapped_args[1]
wrapped_primals = wrap_tensor_subclasses(primals, subclass_metas=meta.subclass_inp_meta)
wrapped_tangents = wrap_tensor_subclasses(tangents, subclass_metas=meta.subclass_tangent_meta)
return (wrapped_primals, wrapped_tangents)
else:
wrapped_args = wrap_tensor_subclasses(unwrapped_args, subclass_metas=meta.subclass_inp_meta)
return wrapped_args
# This wrapper handles the AOTDispatch runtime logic for tensor subclasses.
# At runtime, we have a compiled function that knows how to operate on the domain of DenseTensor -> DenseTensor,
# But the user might have passed us some tensor subclass inputs (or expect some subclass tensor outputs).
# This function handles the wrapping and unwrapping of tensor subclasses at runtime.
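# Illustrative sketch (hypothetical subclass with inner tensors `a` and `b`; not executed):
# given args = [subclass_inp, plain_inp], inner_fn below effectively runs
#   runtime_fn([subclass_inp.a, subclass_inp.b, plain_inp])
# and then rewraps the resulting dense outputs into subclasses according to subclass_metas.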
def aot_dispatch_subclass_wrapper(
runtime_fn: Callable,
*,
subclass_metas: List[Union[int, SubclassCreationMeta]],
num_fw_outs_saved_for_bw: Optional[int],
) -> Callable:
def inner_fn(args):
unwrapped_args = unwrap_tensor_subclasses(args, is_joint_structure=False)
# expectation: runtime_fn is a boxed fn
unwrapped_outs = runtime_fn(unwrapped_args)
wrapped_outs = wrap_tensor_subclasses(
unwrapped_outs, subclass_metas=subclass_metas, num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw, is_runtime=True)
return wrapped_outs
# box it
inner_fn._boxed_call = True
return inner_fn
def create_metadata_for_subclass(meta: ViewAndMutationMeta) -> ViewAndMutationMeta:
# input infos
input_info = []
for inp, subclass_meta in zip(meta.input_info, meta.subclass_inp_meta):
num_inps = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count
for _ in range(num_inps):
input_info.append(inp)
# output infos
output_info = []
subclass_out_meta_user_outs_only = meta.subclass_fw_graph_out_meta[meta.num_mutated_data_inputs:]
if meta.num_intermediate_bases > 0:
subclass_out_meta_user_outs_only = subclass_out_meta_user_outs_only[:-meta.num_intermediate_bases]
# sanity assert
assert len(meta.output_info) == len(subclass_out_meta_user_outs_only)
# Assume that the information on the output is shared by all of its inner tensors.
for out, subclass_meta in zip(meta.output_info, subclass_out_meta_user_outs_only):
num_outs = 1 if isinstance(subclass_meta, int) else subclass_meta.arg_count
for _ in range(num_outs):
output_info.append(out)
# A bit hacky, but we don't actually care about all of the metadata here.
# This metadata is used **underneath** both autograd and subclass de-sugaring,
# So all we really care about is stuff like:
# - num inputs/outputs (needed by the partitioner)
# - input mutations (**not** used today, since we don't handle input mutations inside the subclass,
# although we should handle this eventually)
    # TODO: add a test case to assert we error when this happens, instead of silently computing incorrect results
num_intermediate_bases = None
keep_input_mutations = meta.keep_input_mutations
traced_tangents = None
subclass_inp_meta = None
subclass_fw_graph_out_meta = None
subclass_tangent_meta = None
metadata = ViewAndMutationMeta(
input_info=input_info,
output_info=output_info,
num_intermediate_bases=num_intermediate_bases,
keep_input_mutations=keep_input_mutations,
traced_tangents=traced_tangents,
subclass_inp_meta=subclass_inp_meta,
subclass_fw_graph_out_meta=subclass_fw_graph_out_meta,
subclass_tangent_meta=subclass_tangent_meta,
)
return metadata
SubclassTracingInfo = collections.namedtuple("SubclassTracingInfo", [
"plain_tensor_trace_fn", "plain_tensor_args", "maybe_subclass_meta"
])
# Given a function operating on Subclass -> Subclass, returns a function that operates on Tensor -> Tensor
# Also returns:
# - the new set of arguments to pass into this function (now that tensor subclasses have been eliminated)
# - the updated ViewAndMutationMeta for this dense -> dense function.
# The other important arguments are:
# - flat_fn_maybe_joint: when is_joint_structure=True, this is the joint fw-bw function.
# when is_joint_structure=False, this is just the forward function.
# - fw_only: this is *always* the forward-only function.
# Why do we need this? We need to collect updated ViewAndMutationMeta on our new dense -> dense functions.
# In particular, we need this to tell the partitioner how many dense forward outputs there are.
def aot_dispatch_subclass(
flat_fn_maybe_joint,
args: List[Any],
*,
is_joint_structure: bool,
meta: ViewAndMutationMeta,
fw_only: Callable,
) -> "SubclassTracingInfo":
# Skip logic if we don't need to trace through any subclasses
req_subclass_dispatch = requires_subclass_dispatch(args, meta)
if not req_subclass_dispatch:
return SubclassTracingInfo(
plain_tensor_trace_fn=flat_fn_maybe_joint,
plain_tensor_args=args,
maybe_subclass_meta=None,
)
# TODO: add subclass guards (later PR).
# What's going on here? We need to compute subclass metadata about the outputs of the joint (grad_inputs).
# Annoying: we don't know the grad input metas until we're in the middle of tracing the joint,
# so we set it later, while we're tracing the joint (see inner_fn() below).
# Another option would be to run our run_functionalized_fw_and_collect_metadata() function
# directly on the joint, but this would hurt compile time (adding yet another pass through the joint).
subclass_meta = SubclassMeta()
def inner_fn(fn, args, *, use_trace_joint: bool):
# Step 1: wrap tensor inputs into subclasses if necessary
all_args = wrap_tensor_subclasses_maybe_joint(args, is_joint_structure=use_trace_joint, meta=meta)
# Step 2: call the inner function, with our (maybe subclass) inputs
wrapped_outs = fn(*all_args)
if use_trace_joint:
# See Note: [Computing Subclass Metadata about grad_inputs]
# We also stash subclass info on our grad_inputs, if we're tracing the joint.
nonlocal subclass_meta
assert isinstance(wrapped_outs, tuple) and len(wrapped_outs) == 2
# Don't need fw outs since we already have subclass metadata on them
grad_inputs = wrapped_outs[1]
subclass_meta.grad_input_metas = create_subclass_meta(grad_inputs)
# Step 3: Unwrap any subclass outputs back into dense tensors
unwrapped_outs = unwrap_tensor_subclasses(wrapped_outs, is_joint_structure=use_trace_joint)
return unwrapped_outs
def joint_fn(primals, tangents):
return inner_fn(flat_fn_maybe_joint, (primals, tangents), use_trace_joint=True)
def fw_fn(*primals):
return inner_fn(flat_fn_maybe_joint, primals, use_trace_joint=False)
def metadata_fn(*primals):
return inner_fn(fw_only, primals, use_trace_joint=False)
args_unwrapped = unwrap_tensor_subclasses(args, is_joint_structure=is_joint_structure)
if is_joint_structure:
primals_unwrapped = args_unwrapped[0]
fn_to_trace = joint_fn
else:
primals_unwrapped = args_unwrapped
fn_to_trace = fw_fn
# Note: [Partitioner handling for Subclasses, Part 1]
# The way the partitioner works is that:
    # (1) we pass in a single graph containing the joint fw/bw,
# where the # of graph outputs corresponds to # fw_outputs + # grad_inputs
    # (2) The partitioner accepts an argument, num_fwd_outputs,
# and assumes that the first "num_fwd_outputs" graph outputs correspond
# to outputs of the forward graph.
# How do tensor subclasses enter the picture?
# the num_fwd_outputs in the final graph is actually non-trivial to compute,
# because it can be influenced by input mutations and intermediate bases.
# So we compute it by inspecting the current ViewAndMutationMeta object.
# However, the original ViewAndMutationMeta that we computed was created
# on the subclass -> subclass graph,
# which can have a different number of outputs than the dense -> dense graph.
    # That's why we create a fresh metadata object on the dense -> dense function here,
# and plumb it back up to the partitioner.
# See Note: [Partitioner handling for Subclasses, Part 2] for more info.
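    # Illustrative sketch (hypothetical): if the user forward returns a single subclass
    # output backed by two inner dense tensors, the dense -> dense graph has two forward
    # outputs, so num_fwd_outputs must come from the metadata collected here rather than
    # from the original subclass -> subclass metadata.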
meta_updated = run_functionalized_fw_and_collect_metadata(
metadata_fn,
keep_input_mutations=meta.keep_input_mutations,
is_train=meta.is_train,
)(*primals_unwrapped)
subclass_meta.fw_metadata = meta_updated
return SubclassTracingInfo(
plain_tensor_trace_fn=fn_to_trace,
plain_tensor_args=args_unwrapped,
maybe_subclass_meta=subclass_meta,
)
# Has the precondition that there
# are no duplicate arguments in flat_args (e.g., the same Tensor
# object never shows up twice; however, two tensor inputs MAY alias
# the same storage, so long as they have separate TensorImpls).
def aot_dispatch_autograd_graph(flat_fn, flat_args: List[Any], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
# traced_tangents corresponds to the set of outputs in the traced forward that should get grad_outputs in the traced backward.
# It includes outputs of the original forward, *and* any updated inputs due to input mutations.
# However, it does *not* include any outputs that are aliases of inputs or intermediates, or any metadata-only input mutations.
traced_tangents = pytree.tree_map(
lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
fw_metadata.traced_tangents,
)
joint_inputs = (flat_args, traced_tangents)
fn_prepared_for_autograd = fn_prepped_for_autograd(
flat_fn,
fw_metadata,
)
joint_fn_to_trace = create_joint(fn_prepared_for_autograd, aot_config=aot_config)
joint_fn_to_trace, updated_joint_inputs = create_functionalized_fn(
joint_fn_to_trace,
joint_inputs,
meta=fw_metadata,
aot_config=aot_config,
trace_joint=True,
)
subclass_tracing_info = aot_dispatch_subclass(
joint_fn_to_trace, updated_joint_inputs, is_joint_structure=True, meta=fw_metadata, fw_only=flat_fn)
joint_fn_to_trace = subclass_tracing_info.plain_tensor_trace_fn
updated_joint_inputs = subclass_tracing_info.plain_tensor_args
maybe_subclass_meta = subclass_tracing_info.maybe_subclass_meta
fx_g = create_graph(joint_fn_to_trace, updated_joint_inputs, aot_config=aot_config)
# There should be *NO* mutating ops in the graph at this point.
assert_functional_graph(fx_g.graph)
# Redundant with the check above, but worth having in case tracing introduced
# a fake tensor. Unlikely.
# See Note: [Fake Modules and AOTAutograd]
torch._dynamo.utils.assert_no_fake_params_or_buffers(fx_g)
fx_g.graph.eliminate_dead_code()
fx_g.recompile()
# TODO: in AOTAutograd, we create metadata like _indices_of_inps_to_detach to detect
# when we need to manually detach() some inputs in the forward.
# Higher order ops might eventually need to do the same.
if aot_config.is_export:
assert maybe_subclass_meta is None, "aot_export_module does not support tensor subclass inputs for now."
return fx_g
return fx_g, updated_joint_inputs, maybe_subclass_meta
def aot_dispatch_autograd(flat_fn, flat_args: List[Any], aot_config: AOTConfig, *, fw_metadata: ViewAndMutationMeta):
fx_g, joint_inputs, maybe_subclass_meta = aot_dispatch_autograd_graph(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
# Copied from aot_dispatch_autograd_graph.
traced_tangents = pytree.tree_map(
lambda x: x.detach().contiguous() if isinstance(x, Tensor) else x,
fw_metadata.traced_tangents,
)
disable_amp = torch._C._is_any_autocast_enabled()
if aot_config.enable_log:
aot_joint_log.info("%s", lazy_format_graph_code("Joint graph", fx_g, aot_config.aot_id))
with torch.no_grad():
inner_meta = fw_metadata if maybe_subclass_meta is None else maybe_subclass_meta.fw_metadata
with track_graph_compiling(aot_config, "joint"):
# See Note: [Partitioner handling for Subclasses, Part 1]
num_inner_fwd_outputs = (
inner_meta.num_mutated_inputs
+ inner_meta.num_outputs
+ inner_meta.num_intermediate_bases
+ inner_meta.num_outputs_rng_offset
)
fw_module, bw_module = aot_config.partition_fn(
fx_g, joint_inputs, num_fwd_outputs=num_inner_fwd_outputs
)
fw_outs = next(n for n in fw_module.graph.nodes if n.op == "output").args[0]
# we only need to bookkeep the symints that are saved for bw, not any symints
# the user forward might have returned in its own output
fw_outs_saved_for_bw = fw_outs[num_inner_fwd_outputs:]
num_fw_outs_saved_for_bw = len(fw_outs_saved_for_bw)
symint_outs_saved_for_bw = [
n for n in fw_outs_saved_for_bw if is_sym_node(n)
]
fw_metadata.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
inner_meta.num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
_num_symints_saved_for_bw = len(symint_outs_saved_for_bw)
# Note [Detaching inputs that never need gradients]
# See https://github.com/pytorch/pytorch/issues/97745
# Suppose we have a function like this that we want to compile:
#
# def f(x, y):
# return torch.mul(x, y.detach())
#
# What gradients should we compute for x and y?
# By default, AOTAutograd will compute a gradient for **every** input that requires gradients,
# and so we'll compute:
# x_grad_input = y
# y_grad_input = None
# Does this preserve the semantics of eager mode?
# Unfortunately, no.
# Doing the above will cause autograd to **continue** to backprop the autograd tape
# that was generated from constructing y.
#
# This is **different** from what would have happened in eager mode.
# In eager mode, if we backprop through the output of this function, autograd will only traverse
# the bit of the autograd tape corresponding to "x".
# In particular, if a user had previously backpropped through y's autograd tape,
# And then they try to backprop through the output of the above function,
# then we'll hit the dreaded "Trying to backward through the graph a second time" error.
#
# You might think: If autograd sees that a gradient is None, shouldn't it stop early,
# instead of continuing the backprop through the ancestors of that node in the graph?
#
# Autograd has two passes:
# (1) a first pass that traverses the autograd graph and figures out which nodes need to be executed
# (2) a second pass that actually goes ahead and executes each node when it becomes ready,
# propagating gradients
# By the time we're executing a node and we see that it produces a None, the set of nodes to execute
# is already locked-in.
#
# The fix: instead, we can recognize statically that the graph we're compiling will never contribute
# gradients to y, and prevent autograd from trying to traverse y's autograd tape at all.
# We can do this by manually detach'ing y before sending it through the `CompiledFunction`.
#
# Note that this solution is not bulletproof.
            # It's possible to construct a case where eager may or may not have tried to autograd through y,
# depending on the actual grad_outputs that were passed in during the backward.
# There is no easy fix for this: the simplest fix would be to run with `retain_graph=True`,
# allowing autograd to re-use the graph.
#
# An example of this case is:
# def f(x):
# return x.detach() * 2, x * 3
            # If we only backprop through the first output, then in eager we would stop at the
            # detach() and never send a gradient through x.
# But the custom autograd function doesn't know that: it will materialize zero grads for x * 3
# and we will end up with a zero grad at x.
# If we later backprop through the second output, this will also require backprop'ing through x.
# Meaning we'll need to use `retain_graph=True` to be able to backprop through x the second time.
_indices_of_inps_to_detach = []
bw_outs = next(n for n in bw_module.graph.nodes if n.op == "output").args[0]
# TODO: we should apply the below "detach inputs if their gradients are statically known to be None"
# optimization even if we have subclass inputs/outputs (we do not handle this today).
            # Computing which of our inputs get None gradients is a bit more complicated,
# if any of our inputs are subclasses. Why?
# (a) we need to make sure that we call .detach() on the input subclasses, since autograd sees subclasses.
            # (b) The grad_outputs that we AOT computed in our backward graph are the desugared dense tensors,
# so we need to figure out which subclass fw inputs they map to.
if maybe_subclass_meta is None:
assert len(bw_outs) == len(fw_metadata.input_info) + inner_meta.num_outputs_rng_offset
                for i, bw_out in enumerate(bw_outs):
if bw_out is None:
_indices_of_inps_to_detach.append(i)
if aot_config.enable_log:
aot_graphs_log.info("%s", lazy_format_graph_code("Forward graph", fw_module, aot_config.aot_id))
aot_graphs_log.info("%s", lazy_format_graph_code("Backward graph", bw_module, aot_config.aot_id))
with track_graph_compiling(aot_config, "forward"):
# flat_args at this point might still be subclasses-
# make sure to pass the unwrapped fake tensors into the compiler!
adjusted_flat_args = joint_inputs[0]
if config.functionalize_rng_ops:
# Update example inputs for the fw_compiler
fake_mode = detect_fake_mode()
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple(fake_mode)
adjusted_flat_args.extend([seed, offset])
# We are not clearing flat_args here because
# 1) There is a check in the debug compiler at the end
# 2) It does not matter as these are fake tensors
if torch._guards.TracingContext.get():
torch._guards.TracingContext.get().fw_metadata = inner_meta
with TracingContext.report_output_strides() as fwd_output_strides:
compiled_fw_func = aot_config.fw_compiler(
fw_module, adjusted_flat_args
)
if not hasattr(compiled_fw_func, "_boxed_call"):
compiled_fw_func = make_boxed_func(compiled_fw_func)
if maybe_subclass_meta is not None:
# Why do we need to pass in num_fw_outs_saved_for_bw?
# See Note: [Partitioner handling for Subclasses, Part 2]
compiled_fw_func = aot_dispatch_subclass_wrapper(
compiled_fw_func,
subclass_metas=fw_metadata.subclass_fw_graph_out_meta,
num_fw_outs_saved_for_bw=num_fw_outs_saved_for_bw
)
if not hasattr(compiled_fw_func, "_boxed_call"):
compiled_fw_func = make_boxed_func(compiled_fw_func)
# NB: It's important to compile backwards ahead of time, as this may
# add extra guards which we need to apply to the Dynamo cache at
# forwards
with track_graph_compiling(aot_config, "backward"):
placeholder_list = fx_placeholder_vals(bw_module)
forward_saved_for_backwards_strides = None
if fwd_output_strides is not None:
forward_saved_for_backwards_strides = fwd_output_strides[inner_meta.tensors_saved_for_backwards_slice]
# saved activations can have different stride to eager if
# the compiler does layout optimization. We should restride the
# tensor passed in for compiling the backward graph using the
# saved tensor's stride.
for i in range(len(placeholder_list)):
ph_arg = placeholder_list[i]
if not isinstance(ph_arg, torch.Tensor):
continue
if forward_saved_for_backwards_strides is None:
continue
real_stride = None
# Per all_args calling convention
j = i - len(symint_outs_saved_for_bw)
if 0 <= j < len(forward_saved_for_backwards_strides):
real_stride = forward_saved_for_backwards_strides[j]
if real_stride is None:
continue
# Comparing ph_arg.stride() with real_stride directly may
# cause dynamic dimensions in ph_arg being specialized to static
# value. Using the hints to avoid that.
if _get_hints(ph_arg.stride()) != real_stride:
# Note that here we use the stride of the real tensor to
# restride a FakeTensor. This does not cause trouble
                    # for dynamic shape since this code path only gets
                    # executed if layout optimization is enabled. And we
# disable layout optimization for dynamic shape right
# now.
#
                    # A solution that decides stride order based on the real
                    # tensor's stride and then applies that stride order to
# the FakeTensor does not work smoothly since some
# tensor's layout is not 'dense'. E.g. mixnet_l has a
# tensor with size [8, 64, 112, 112] and strides
# (2408448, 1, 21504, 192). The solution mentioned will
# decide a stride of (802816, 1, 7168, 64) for this
# tensor which is wrong.
placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)
compiled_bw_func = None
if len(symint_outs_saved_for_bw):
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context():
try:
compiled_bw_func = aot_config.bw_compiler(
bw_module, placeholder_list
)
except Exception:
log.warning(
"failed to eagerly compile backwards for dynamic, suppressing in case backwards not needed",
exc_info=True
)
saved_context = TracingContext.get()
class CompiledFunction(torch.autograd.Function):
compiled_fw = compiled_fw_func
compiled_bw = compiled_bw_func
metadata = fw_metadata
maybe_subclass_metadata: Optional[SubclassMeta] = maybe_subclass_meta
num_symints_saved_for_bw = _num_symints_saved_for_bw
@staticmethod
def _compiled_autograd_key(ctx):
return (aot_config.aot_id, *ctx.symints)
@staticmethod
def forward(ctx, *deduped_flat_tensor_args):
args = deduped_flat_tensor_args
if CompiledFunction.metadata.is_rng_op_functionalized:
# Add the seed and offset to args
seed, offset = CUDARngStateHelper.get_torch_state_as_tuple()
args = (*args, seed, offset)
# There is a pretty complicated calling convention around what the compiled fw returns.
# The full list of outputs and their relative order is:
# (*mutated_inputs, *fw_outs, *fw_intermediate_bases, *saved_tensors, *saved_symints)
# - Note that in the synthetic bases case, mutated_inputs will correspond to an updated version
# of the original view, and not the synthetic base
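            # Illustrative example (hypothetical sizes): with 1 mutated input, 2 user outputs,
            # no intermediate bases, 2 saved activations and 1 saved symint, the compiled fw
            # returns (updated_inp0, out0, out1, saved0, saved1, s0).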
fw_outs = call_func_with_args(
CompiledFunction.compiled_fw,
args,
disable_amp=disable_amp,
)
num_outputs = CompiledFunction.metadata.num_outputs
num_outputs_aliased = CompiledFunction.metadata.num_outputs_aliased
num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases
num_symints_saved_for_bw = CompiledFunction.num_symints_saved_for_bw
num_mutated_inputs = CompiledFunction.metadata.num_mutated_inputs
num_mutated_metadata_only_inputs = (
CompiledFunction.metadata.num_mutated_metadata_only_inputs
)
num_forward_returns = CompiledFunction.metadata.num_forward_returns
num_forward = CompiledFunction.metadata.num_forward
# Partitioners must put symint arguments at the end separate from tensor arguments
tensors_saved_for_backwards = fw_outs[
CompiledFunction.metadata.tensors_saved_for_backwards_slice
]
assert all(
isinstance(x, torch.Tensor) for x in tensors_saved_for_backwards
)
# See Note [Detaching saved tensors in AOTAutograd]
ctx.save_for_backward(*(x.detach() if x._is_view() else x for x in tensors_saved_for_backwards))
symint_outs = fw_outs[CompiledFunction.metadata.symints_saved_for_backwards_slice]
assert all(
isinstance(x, (int, float, torch.SymInt, torch.SymFloat))
for x in symint_outs
), str([type(x) for x in symint_outs])
ctx.symints = symint_outs
raw_returns = fw_outs[0:num_forward_returns]
# Wrap all autograd.Function.forward() outputs that are aliases
# so that autograd.Function doesn't treat them as tensors
if num_mutated_metadata_only_inputs > 0:
for i, idx in enumerate(
CompiledFunction.metadata.mutated_inp_indices
):
# We could make this faster by only looping over inputs with metadata-only mutations
# (instead of looping over inputs with either data or metadata mutations), but there shouldn't be many.
info = CompiledFunction.metadata.input_info[idx]
if info.mutates_metadata and not info.mutates_data:
raw_returns[i] = TensorAlias(raw_returns[i])
if config.debug_assert:
user_mutated_inputs_raw = raw_returns[0:num_mutated_inputs]
mut_inp_infos = [
x for x in CompiledFunction.metadata.input_info if x.mutates_data or x.mutates_metadata
]
assert len(user_mutated_inputs_raw) == len(mut_inp_infos)
if CompiledFunction.metadata.num_unsafe_view_outputs > 0:
for idx in CompiledFunction.metadata.unsafe_view_out_indices:
raw_return_idx = num_mutated_inputs + idx
o = raw_returns[raw_return_idx]
raw_returns[raw_return_idx] = torch.ops.aten._unsafe_view(o, o.shape)
if num_outputs_aliased > 0:
for idx in CompiledFunction.metadata.aliased_out_indices:
raw_return_idx = num_mutated_inputs + idx
raw_returns[raw_return_idx] = TensorAlias(raw_returns[raw_return_idx])
if config.debug_assert:
intermediates_raw = raw_returns[num_mutated_inputs + num_outputs:]
assert not any(isinstance(x, TensorAlias) for x in intermediates_raw)
# invariant: intermediate bases always require gradients, so we don't have to
# consider marking them as non-differentiable.
raw_returns_not_including_intermediate_bases = raw_returns[:num_mutated_inputs + num_outputs]
raw_returns_meta = (
[x for x in CompiledFunction.metadata.input_info if x.mutates_data or x.mutates_metadata]
+ CompiledFunction.metadata.output_info
)
fw_outs_not_requiring_grad = [
x
for (i, x) in enumerate(raw_returns_not_including_intermediate_bases)
if isinstance(x, torch.Tensor)
and not raw_returns_meta[i].requires_grad
]
ctx.mark_non_differentiable(*fw_outs_not_requiring_grad)
ctx._materialize_non_diff_grads = False
functionalized_rng_runtime_epilogue(
CompiledFunction.metadata,
fw_outs[num_forward_returns:num_forward],
return_new_outs=False
)
return tuple(raw_returns)
@staticmethod
def backward(ctx, *flat_args):
# Calling convention: we expect a grad_out passed to the backward:
# - for every output of the fw that does *not* alias an input or graph intermediate
# - for every updated_input generated by the fw that does *not* alias an input (aka only data-mutations)
# - for every graph intermediate that we need to use to generate an output later.
# The other outputs in the autograd.Function.forward that do *not* show up in the backward include:
# - outputs that alias inputs or graph intermediates
# - updated inputs due to metadata-only mutations.
# We need to return them in the forward, but ensure that they all do not get gradients in the backward,
# and we filter them out here before passing the remaining grad_outputs into the compiled backward.
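            # Illustrative example (hypothetical): if the user fw had one data-mutated input
            # and returned (out, out.view(-1)), backward() receives three grad slots: one for
            # the updated input, one for `out`, and one for the aliased view. The aliased
            # view's slot is filtered out below and never reaches the compiled backward.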
num_mutated_inps = CompiledFunction.metadata.num_mutated_inputs
num_intermediate_bases = CompiledFunction.metadata.num_intermediate_bases
expected_grad_outs = (
CompiledFunction.metadata.num_outputs + num_mutated_inps + num_intermediate_bases
)
assert len(flat_args) == expected_grad_outs
out_info = CompiledFunction.metadata.output_info
inp_tangents, out_tangents, intermediate_base_tangents = (
flat_args[0:num_mutated_inps],
flat_args[num_mutated_inps:num_mutated_inps + CompiledFunction.metadata.num_outputs],
flat_args[num_mutated_inps + CompiledFunction.metadata.num_outputs:],
)
# input_info contains info on *every* input,
# But in the backward(), we are only given grad outputs for every mutated input
# We then need to filter out the grad outputs that correspond to metadata-only mutations or don't require grad
mutated_inp_indices = CompiledFunction.metadata.mutated_inp_indices
input_info = CompiledFunction.metadata.input_info
assert len(inp_tangents) == len(mutated_inp_indices)
inp_tangents_filtered = [
x
for x, info_idx in zip(inp_tangents, mutated_inp_indices)
if input_info[info_idx].mutates_data and input_info[info_idx].requires_grad
]
# We also need to filter out grad outputs that correspond to outputs aliasing inputs/intermediates
out_tangents_filtered = [
x
for x, info in zip(out_tangents, out_info)
if info.output_type in [OutputType.non_alias, OutputType.unsafe_view_alias, OutputType.custom_function_view]
and issubclass(info.raw_type, torch.Tensor)
and info.requires_grad
]
# intermediate bases always require gradients, and always participate in the backward graph.
flat_bw_args_with_grads = [*inp_tangents_filtered, *out_tangents_filtered, *intermediate_base_tangents]
num_flat_bw_args_with_grads = len(flat_bw_args_with_grads)
# sanity asserts
# metadata_only_inps = [
# x for x, info_idx in zip(inp_tangents, mutated_inp_indices)
# if not input_info[info_idx].mutates_data
# ]
# aliased_outputs = [
# x for x, info in zip(out_tangents, out_info) if info.output_type != OutputType.non_alias]
# assert all(x is None for x in metadata_only_inps)
# assert all(x is None for x in aliased_outputs)
rng_args = []
if CompiledFunction.metadata.is_rng_op_functionalized:
# Add the seed and offset to args
rng_args = CUDARngStateHelper.get_torch_state_as_tuple()
all_args = [
*ctx.symints,
*ctx.saved_tensors,
*flat_bw_args_with_grads,
*rng_args
]
del flat_bw_args_with_grads
tangents_start_idx = len(all_args) - num_flat_bw_args_with_grads - len(rng_args)
tangents_end_idx = len(all_args) - len(rng_args)
# Note: [AOTAutograd Backward Guards]
# During AOTDispatch, we eagerly create and trace out a joint fw-bw graph.
# Doing so requires us to "guess" about some of the metadata of our grad_outputs.
#
# In particular: if an output to the forward is a plain tensor or a subclass,
# its corresponding grad_output in the backward **may or may not** be
# a plain tensor or a subclass. The main cases are:
# (1) If an output is a plain tensor, its grad_out will also be a plain tensor,
# *unless* the output is used in some subclass compute later in the forward graph,
# which will cause its grad_output to become a subclass
# (2) If an output is a subclass, its grad_out will also be a subclass,
# *unless* the output of the forward did not actually participate in the gradient computation,
# in which case autograd will insert a plain tensor of zeros for the grad_output.
# We could avoid this case with `torch.autograd.Function.set_materialize_grads`,
            # although this is not turned on today in AOTAutograd and would require more work.
#
# Today, we make a guess on subclass-ness based on the above examples,
# and hard-error in the backward if we guessed wrong.
#
# In the future, we should add backward guards that would allow us to
# properly handle this case instead of erroring: we would need to retrace the backward graph,
# since we might produce an entirely different trace if our grad_outputs are subclass or not.
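            # Illustrative example of case (2) above (hypothetical): if the forward returns a
            # subclass output that the user never backprops through, autograd materializes a
            # plain dense zeros tensor for that grad_output, so its runtime type can differ
            # from the type we guessed while tracing the joint.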
assert len(CompiledFunction.metadata.output_types) == num_flat_bw_args_with_grads
grad_output_types = [type(x) for x in all_args[-num_flat_bw_args_with_grads:]]
# In general, we can add more asserts/guards here for when we partitioned
# with incorrect assumptions about the grad_outputs.
# Normalize FakeTensor -> torch.Tensor
# - during tracing our types are FakeTensor
# - at runtime in the backward our types are torch.Tensor...
# - unless we're running compiled backward, in which case they are also FakeTensor
grad_output_types_ = [torch.Tensor if x is FakeTensor else x for x in grad_output_types]
assert grad_output_types_ == CompiledFunction.metadata.output_types, f"""\
We incorrectly attempted to compile the backward with incorrect subclass metadata.
If you run into this error, please file an issue.
Expected grad_output types: {str(CompiledFunction.metadata.output_types)}
Got grad_output types: {str(grad_output_types)}"""
# TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
if CompiledFunction.maybe_subclass_metadata is not None:
# Get the number of tangents after unwrapping
len_tangents = len(unwrap_tensor_subclasses(
all_args[tangents_start_idx: tangents_end_idx], is_joint_structure=False
))
all_args = unwrap_tensor_subclasses(all_args, is_joint_structure=False)
tangents_start_idx = len(all_args) - len_tangents - len(rng_args)
tangents_end_idx = tangents_start_idx + len_tangents
# Make the tangents contiguous. Note that we must do this after subclass desugaring
# because inputs to inductor have to be contiguous
all_args = [
t.contiguous() if tangents_start_idx <= i < tangents_end_idx else t
for i, t in enumerate(all_args)
]
def call_compiled_backward():
if ctx._is_compiled_autograd_tracing():
# For compiled autograd, run raw FX graph so that it can be inlined into the larger graph
symints = ctx._get_compiled_autograd_symints()
assert len(symints) == len(ctx.symints)
all_args[:len(symints)] = symints
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context():
out = normalize_as_list(bw_module(*all_args))
out = functionalized_rng_runtime_epilogue(CompiledFunction.metadata, out)
return tuple(out)
ctx.maybe_clear_saved_tensors()
if CompiledFunction.compiled_bw is None:
context = torch._C._DisableAutocast if disable_amp else nullcontext
with tracing(saved_context), context(), track_graph_compiling(aot_config, "backward"):
CompiledFunction.compiled_bw = aot_config.bw_compiler(
bw_module, placeholder_list
)
out = call_func_with_args(
CompiledFunction.compiled_bw,
all_args,
steal_args=True,
disable_amp=disable_amp,
)
out = functionalized_rng_runtime_epilogue(CompiledFunction.metadata, out)
return tuple(out)
if torch.is_grad_enabled() and any(t.requires_grad for t in all_args if isinstance(t, torch.Tensor)):
# Ensure that the graph is connected, and error if double backward is performed.
# See comment for why once_differentiable is not sufficient:
# https://github.com/pytorch/pytorch/pull/92348/files#r1072962107
class CompiledFunctionBackward(torch.autograd.Function):
@staticmethod
def forward(ctx, *unused_args):
outs = call_compiled_backward()
# TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
if CompiledFunction.maybe_subclass_metadata is not None:
outs_wrapped = wrap_tensor_subclasses(
outs, subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas)
return outs_wrapped
return outs
@staticmethod
def backward(ctx, *args):
raise RuntimeError("torch.compile with aot_autograd does not currently support double backward")
CompiledFunctionBackward._compiled_autograd_key = CompiledFunction._compiled_autograd_key
# Pass args even though they're unused, so that the graph is built
out = CompiledFunctionBackward.apply(*all_args)
else:
out = call_compiled_backward()
# TODO: figure out how to refactor the backward properly so I can use aot_dispatch_subclass_wrapper() here.
if CompiledFunction.maybe_subclass_metadata is not None:
outs_wrapped = wrap_tensor_subclasses(
out, subclass_metas=CompiledFunction.maybe_subclass_metadata.grad_input_metas)
return outs_wrapped
return out
compiled_function = create_runtime_wrapper(
CompiledFunction.apply,
runtime_metadata=fw_metadata,
indices_of_inps_to_detach=_indices_of_inps_to_detach,
trace_joint=True,
keep_input_mutations=False,
disable_amp=disable_amp
)
if not config.debug_assert:
return compiled_function
flat_requires_grad = [
a.requires_grad if isinstance(a, Tensor) else None for a in flat_args
]
@wraps(compiled_function)
def debug_compiled_function(*args):
# TODO: Check aliasing relationships
# TODO: Check strides for metadata mutation
# (NB: ideally, this logic is factored out of this function and
# you move these debug checks there)
# Check requires grad. Bad case is when we compiled with
# requires_grad = False, but input requires_grad = True
# (vice versa is OK; we compute a gradient and then throw
# it away when it hits the input.)
for i, a in enumerate(args):
can_require_grad = flat_requires_grad[i]
if can_require_grad is None:
assert not isinstance(a, Tensor)
elif not can_require_grad:
assert not a.requires_grad, format_guard_bug_msg(
aot_config,
f"{describe_input(i, aot_config)} would not require grad",
)
return compiled_function(*args)
return debug_compiled_function
@dynamo_timed
def create_aot_dispatcher_function(
flat_fn, flat_args: List[Any], aot_config: AOTConfig
):
"""
Traces the forward and backward graphs of the attr:`flat_fn` to generate a
joint graph. The joint graph is an Fx graph with Aten ops. Please refer to
the tracing mechanism to understand the graph capturing details.
The joint graph is then passed through attr:`partition_fn` to isolate the
forward and backward portions, which are then respectively compiled via the
provided attr:`fw_compiler` and attr:`bw_compiler`.
The resulting compiled forward and backward graphs are then wrapped up in a
``torch.autograd.Function`` object.
The calling convention here is that the first aot_config.num_params_buffers
inputs in flat_args are parameters and buffers, and the rest are inputs.
    We use this to assume that parameters/buffers' shapes don't change.
Note: this function is used both by aot_function and aot_export (controlled by aot_config.is_export)
When aot_config.is_export is True, we return an FX graph + metadata
When aot_config.is_export is False, we return an ordinary runtime function
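    A rough sketch of the calling convention, as used by the wrappers below
    (a hedged illustration, not the exact internal call sites):
    >>> # flat_args = [*params_and_buffers, *user_inputs]
    >>> # compiled_fn = create_aot_dispatcher_function(flat_fn, flat_args, aot_config)
    >>> # out = compiled_fn(flat_args)  # boxed calling convention: a single list of args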
"""
# This is the main entry point.
# TODO: Chillee argues that dynamo itself should pass in fake tensors to
# the list of arguments when compiling; at the moment we do not do this
if aot_config.decompositions is None:
aot_config.decompositions = {}
aot_config.decompositions = {
**aot_autograd_decompositions,
**aot_config.decompositions,
}
if config.functionalize_rng_ops:
# Update the decompositions with functionalized random decompositions
aot_config.decompositions = {
**rng_decompositions,
**aot_config.decompositions,
}
# Check flat_args to see if they're already fake. If so, use that fake
# mode instead.
fake_mode = detect_fake_mode(flat_args)
if fake_mode is None:
shape_env = ShapeEnv() if aot_config.dynamic_shapes else None
fake_mode = FakeTensorMode(shape_env=shape_env)
else:
shape_env = fake_mode.shape_env
python_dispatcher_mode = (
enable_python_dispatcher() if shape_env is not None else nullcontext()
)
with torch.autograd.set_multithreading_enabled(
False
), preserve_rng_state(), fake_mode, python_dispatcher_mode, PhiloxStateTracker():
def process_inputs(flat_args):
def convert(idx, x):
if shape_env is not None:
from torch._dynamo.source import ConstantSource
if isinstance(x, int):
source = ConstantSource(f"sym_{idx}")
return shape_env.create_symintnode(
shape_env.create_symbol(x, source),
hint=x,
source=source
)
if not isinstance(x, torch.Tensor):
return x
if isinstance(x, FakeTensor):
assert x.fake_mode is fake_mode
return x
if is_traceable_wrapper_subclass(x):
attrs, _ = x.__tensor_flatten__()
if all(isinstance(getattr(x, attr), FakeTensor) for attr in attrs):
assert all(getattr(x, attr).fake_mode is fake_mode for attr in attrs)
return x
# TODO: Ensure that this codepath is never exercised from
# Dynamo
if (
idx < aot_config.num_params_buffers
and config.static_weight_shapes
):
return fake_mode.from_tensor(x, static_shapes=True)
return fake_mode.from_tensor(x, static_shapes=False)
return [convert(idx, x) for idx, x in enumerate(flat_args)]
fake_flat_args = process_inputs(flat_args)
needs_autograd = (
any(x.requires_grad for x in fake_flat_args if isinstance(x, Tensor))
and torch.is_grad_enabled()
)
with enable_python_dispatcher():
# Patch set_rng_state as set_rng_state with fake tensors is
# nonsensical. This does not affect the collection of metadata.
with patch("torch.cuda.set_rng_state", lambda *args: None):
fw_metadata = run_functionalized_fw_and_collect_metadata(
flat_fn,
keep_input_mutations=aot_config.keep_inference_input_mutations and not needs_autograd,
is_train=needs_autograd,
)(*fake_flat_args)
req_subclass_dispatch = requires_subclass_dispatch(fake_flat_args, fw_metadata)
if needs_autograd and not any(x.requires_grad for x in fw_metadata.output_info):
# We realized that none of the outputs require grad,
# so we actually have an inference graph.
needs_autograd = False
# A bit silly: right now in the subclass codepath, our ViewAndMutationMeta
# changes depending on whether we pass in is_train / keep_input_mutations,
# so we're forced to recompute the metadata.
# TODO: refactor the subclass path of run_functionalized_fw_and_collect_metadata
# so that this is unnecessary.
if req_subclass_dispatch:
fw_metadata = run_functionalized_fw_and_collect_metadata(
flat_fn,
keep_input_mutations=aot_config.keep_inference_input_mutations and not needs_autograd,
is_train=needs_autograd,
)(*fake_flat_args)
else:
fw_metadata = ViewAndMutationMeta(
input_info=fw_metadata.input_info,
output_info=fw_metadata.output_info,
num_intermediate_bases=fw_metadata.num_intermediate_bases,
keep_input_mutations=aot_config.keep_inference_input_mutations and not needs_autograd,
traced_tangents=fw_metadata.traced_tangents,
subclass_inp_meta=fw_metadata.subclass_inp_meta,
subclass_fw_graph_out_meta=fw_metadata.subclass_fw_graph_out_meta,
subclass_tangent_meta=fw_metadata.subclass_tangent_meta,
is_train=needs_autograd,
)
if fw_metadata.num_intermediate_bases > 0:
assert not req_subclass_dispatch, f"""\
torch.compile is currently being used with tensor subclass inputs:
{','.join([str(type(x)) for x in fake_flat_args])}. We are attempting to compile a graph with two graph outputs
that alias one another, which is currently unsupported in the subclass use case. If you run into this,
please file a github issue"""
if aot_config.is_export:
# aot_export: ban input metadata mutations for now to keep shared code paths simpler.
# Keeping .resize_() in the graph will require some work
# Allowing it but keeping the graph functional will require some calling convention changes.
if len([x for x in fw_metadata.input_info if x.mutates_metadata]) != 0:
raise RuntimeError(f"""\
Found an input that received a metadata mutation, through e.g. a call to `.resize_()` or `.transpose_()`.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.
fw_metadata={str(fw_metadata)}""")
# In export, banning data mutations on inputs that require grad for now.
# This should be rare, and is tricky to get right. When we trace the backward,
# we currently trace with autograd.grad instead of .backward(), which makes it difficult
# to ensure that we run autograd all the way through the input **before** it saw the mutation.
if len([x for x in fw_metadata.input_info if x.requires_grad and x.mutates_data]) != 0:
raise RuntimeError(f"""\
Found a graph input that requires gradients and received a data mutation.
This is currently banned in the aot_export workflow. If you need this functionality, please file a github issue.
fw_metadata={str(fw_metadata)}""")
if req_subclass_dispatch:
raise RuntimeError("""\
aot_export is not currently supported with traceable tensor subclasses.
If you need this feature, please comment on <CREATE_ISSUE_LINK>""")
# Need to decide on a strategy for functionalized RNG: toggling via global config seems bad,
# and turning it on will require a non-trivial calling convention change for any export runtime.
if config.functionalize_rng_ops:
raise RuntimeError("""\
Functionalized RNG is not currently supported in the aot_export workflow. Please file a github issue,
or otherwise set torch._functorch.config.functionalize_rng_ops = False.""")
# crappy version of dispatcher
# TODO: Do this properly
if needs_autograd:
# For now, aot_dispatch_autograd knows to explicitly return a graph
# when run with export, and an opaque callable otherwise.
# In theory we could factor these out, but I wanted to let the dust
# settle on how functionalized rng fits into export first.
compiler_fn = aot_dispatch_autograd_graph if aot_config.is_export else aot_dispatch_autograd
else:
# aot_dispatch_base_graph contains only the "graph bits", while aot_dispatch_base
# includes some extra work around handling a runtime epilogue.
compiler_fn = aot_dispatch_base_graph if aot_config.is_export else aot_dispatch_base
compiler_fn = partial(aot_wrapper_synthetic_base, compiler_fn=compiler_fn, needs_autograd=needs_autograd)
compiler_fn = partial(aot_wrapper_dedupe, compiler_fn=compiler_fn)
# You can put more passes here
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
if aot_config.is_export:
mutated_user_inp_locs = [
idx - aot_config.num_params_buffers
for idx in fw_metadata.mutated_inp_indices
if idx >= aot_config.num_params_buffers
]
if len(mutated_user_inp_locs) > 0:
raise RuntimeError(f"""
Found the following mutated user inputs, located at {mutated_user_inp_locs}. This is currently banned in the aot_export workflow.
If you need this functionality, please file a github issue.
fw_metadata={str(fw_metadata)}""")
# During export, we don't get back a callable - we get back the raw fx graph
# (either a joint or an inference-only graph)
assert isinstance(compiled_fn, torch.fx.GraphModule)
return compiled_fn, fw_metadata
if not hasattr(compiled_fn, "_boxed_call"):
compiled_fn = make_boxed_func(compiled_fn)
return compiled_fn
# Inspired by autodidax (thanks!)
class PytreeThunk:
spec = None
# These are some kinda dumb microoptimizations that save about 3-4 us of overhead.
is_simple = (
None # if the output spec is a tuple/list, we won't bother unflattening it.
)
is_really_simple = None # if the output spec is a LeafSpec
def set(self, spec):
assert self.spec is None or self.spec == spec
self.spec = spec
if type(self.spec) in [tuple, list] and all(
isinstance(i, pytree.LeafSpec) for i in spec.children_specs
):
self.is_simple = True
if isinstance(self.spec, pytree.LeafSpec):
self.is_really_simple = True
def unflatten(self, x):
if self.is_really_simple:
return x[0]
if self.is_simple:
return x
return pytree.tree_unflatten(x, self.spec)
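# A minimal sketch of how PytreeThunk is typically used (the spec/values here are
# illustrative, not taken from an actual AOTAutograd trace):
#
#   thunk = PytreeThunk()
#   flat_out, spec = pytree.tree_flatten({"a": torch.ones(1), "b": torch.zeros(1)})
#   thunk.set(spec)                      # record the output spec once, at trace time
#   result = thunk.unflatten(flat_out)   # at runtime, recover the original structure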
def create_functional_call(mod, params_spec, params_len):
# Redundant with dynamo, but worth having in case this gets invoked elsewhere.
# https://github.com/pytorch/pytorch/issues/103569
def functional_call(*args, **kwargs):
with stateless._reparametrize_module(
mod, pytree.tree_unflatten(args[:params_len], params_spec)
):
if isinstance(mod, torch.fx.GraphModule):
with fx_traceback.preserve_node_meta(), warnings.catch_warnings():
warnings.filterwarnings(
"ignore", "Anomaly Detection has been enabled."
)
with torch.autograd.detect_anomaly(check_nan=False):
out = Interpreter(mod).run(*args[params_len:], **kwargs)
else:
out = mod(*args[params_len:], **kwargs)
if not isinstance(out, (tuple, list)):
raise RuntimeError(
"Graph output must be a tuple(). This is so that we can avoid "
"pytree processing of the outputs. Please change the module to "
"have tuple outputs or use aot_module instead."
)
return out
return functional_call
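# A rough usage sketch for create_functional_call (the module here is hypothetical;
# note that the wrapped module is expected to return a tuple or list):
#
#   params_and_buffers = {
#       **dict(mod.named_parameters(remove_duplicate=False)),
#       **dict(mod.named_buffers(remove_duplicate=False)),
#   }
#   flat, spec = pytree.tree_flatten(params_and_buffers)
#   fn = create_functional_call(mod, spec, len(flat))
#   out = fn(*flat, *user_args)  # params/buffers come first, then the user inputs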
# Creates a function that returns flattened inputs and outputs
# Also returns the output tree spec, which is needed to recover the "unflattened"
# output tree structure later.
def create_tree_flattened_fn(fn, args, kwargs=None) -> Tuple[Callable, PytreeThunk]:
if kwargs is None:
kwargs = {}
# Save the args_spec for flat_tensor_args to unflatten while tracing
_, tensor_args_spec = pytree.tree_flatten((args, kwargs))
out_spec = PytreeThunk()
def flat_fn(*flat_args):
        # The inputs are flattened tensor args. Prepare the args in the
# order that original function expects. Add static args as well.
# They will appear as tensor constants in the traced graph.
nonlocal out_spec
args, kwargs = pytree.tree_unflatten(flat_args, tensor_args_spec)
tree_out = fn(*args, **kwargs)
flat_out, spec = pytree.tree_flatten(tree_out)
for i in flat_out:
is_known_type = False
for j in KNOWN_TYPES:
if isinstance(i, j):
is_known_type = True
break
if not is_known_type:
raise RuntimeError(
f"Found {type(i)} in output, which is not a known type. "
"If this type holds tensors, you need to register a pytree for it. "
"See https://github.com/pytorch/functorch/issues/475 for a brief "
"explanation why. If you don't need to register a pytree, please "
"leave a comment explaining your use case and we'll make this more "
"ergonomic to deal with"
)
out_spec.set(spec)
return flat_out
return flat_fn, out_spec
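# A small sketch of create_tree_flattened_fn (the function and inputs are illustrative):
#
#   def f(d):
#       return {"out": d["x"].sin() + d["y"]}
#   example_args = ({"x": torch.randn(3), "y": torch.randn(3)},)
#   flat_fn, out_spec = create_tree_flattened_fn(f, example_args)
#   flat_out = flat_fn(*pytree.arg_tree_leaves(*example_args))  # flat tensors in, flat tensors out
#   result = out_spec.unflatten(flat_out)                       # original {"out": ...} structure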
def _graph_input_names(gm):
return [node.name for node in gm.graph.nodes if node.op == "placeholder"]
def _graph_output_names(gm):
output_node = next(iter(reversed(gm.graph.nodes)))
assert output_node.op == "output" and len(output_node.args) == 1
return_args = output_node.args[0]
return [getattr(return_arg, "name", None) for return_arg in return_args]
def create_graph_signature(
fx_g: torch.fx.GraphModule,
fw_metadata: ViewAndMutationMeta,
in_spec: pytree.TreeSpec,
out_spec: pytree.TreeSpec,
*,
user_args_flat: List[torch.Tensor],
params_and_buffers_flat: List[torch.Tensor],
param_names: List[str],
buffer_names: List[str],
trace_joint: bool,
num_user_fw_outs: Optional[int],
loss_index: Optional[int],
) -> GraphSignature:
# Retrieve graph input names
graph_input_names = _graph_input_names(fx_g)
# Retrieve graph output names
graph_output_names = _graph_output_names(fx_g)
num_params_buffers = len(param_names) + len(buffer_names)
    # We have enough restrictions on the graph (no de-duping, no synthetic bases, etc.)
    # such that: num graph inputs = num user inputs + num params + num buffers
num_user_args = len(graph_input_names) - num_params_buffers
if trace_joint:
assert num_user_fw_outs is not None
num_fw_outs = num_user_fw_outs + fw_metadata.num_mutated_inputs
backward_output_names = graph_output_names[num_fw_outs:]
grad_index = itertools.count(0)
gradients_to_parameters = {
backward_output_names[next(grad_index)]: param_names[i]
for i, param in enumerate(params_and_buffers_flat)
if param.requires_grad
}
gradients_to_user_inputs = {
backward_output_names[next(grad_index)]: graph_input_names[i + len(params_and_buffers_flat)]
for i, user_input in enumerate(user_args_flat)
if user_input.requires_grad
}
assert len(gradients_to_parameters) + len(gradients_to_user_inputs) == len(
backward_output_names
)
# Check that we have fully accounted for all graph outputs
backward_signature = BackwardSignature(
gradients_to_parameters,
gradients_to_user_inputs,
graph_output_names[loss_index],
)
else:
backward_signature = None
num_user_fw_outs = len(graph_output_names) - fw_metadata.num_mutated_inputs
return GraphSignature.from_tracing_metadata(
in_spec=in_spec,
out_spec=out_spec,
graph_input_names=graph_input_names,
graph_output_names=graph_output_names,
view_mutation_metadata=fw_metadata,
named_parameters=param_names,
named_buffers=buffer_names,
num_user_inputs=num_user_args,
num_user_outputs=num_user_fw_outs,
loss_index=loss_index,
backward_signature=backward_signature,
)
def aot_function(
fn: Callable,
fw_compiler: Callable,
bw_compiler: Optional[Callable] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
num_params_buffers: int = 0,
keep_inference_input_mutations: bool = False,
inference_compiler: Optional[Callable] = None,
*,
# Whether or not to trace with dynamic shapes
dynamic=False,
enable_log=True,
) -> Callable:
"""
Traces the forward and backward graph of :attr:`fn` using torch dispatch
mechanism, and then compiles the generated forward and backward graphs
through :attr:`fw_compiler` and :attr:`bw_compiler`.
:func:`aot_function` traces the forward and backward graph ahead of time,
and generates a joint forward and backward graph. :attr:`partition_fn` is
then used to separate out forward and backward graphs. The partitioner
function can be used to perform optimizations such as recomputation. One can
    set the `decompositions` dictionary to decompose the operators into a sequence
of core or simpler operators supported by the backend compilers.
.. warning::
This API is experimental and likely to change.
Args:
        fn (Callable): A Python function that takes one or more arguments. Must
return one or more Tensors.
fw_compiler (Callable): A Python function that accepts an Fx graph with
Aten ops and input args, and returns a Callable that semantically is
equivalent to the input Fx graph.
bw_compiler (Optional[Callable]): A Python function that accepts an
Fx graph with Aten ops and input args, and returns a Callable that
semantically is equivalent to the input Fx graph. Default: None
(when None, it defaults to the :attr:`fw_compiler`)
partition_fn (Callable): A Python function that takes a joint forward
and backward graph, and partitions it into separate forward and
backward graphs.
decompositions (Dict): A dictionary to define the decomposition of
larger Aten ops into simpler or core Aten ops.
inference_compiler (Optional[Callable]): A Python function that accepts an
Fx graph with Aten ops and input args, and returns a Callable that
semantically is equivalent to the input Fx graph. inference_compiler is invoked
if no autograd is needed. Default: None
(when None, it defaults to the :attr:`fw_compiler`)
Returns:
Returns a ``Callable`` that retains the eager behavior of the original
:attr:`fn`, but with forward and backward graph compiled via
        :attr:`fw_compiler` and :attr:`bw_compiler`.
A simple example usage of :func:`aot_function` is as follows. This example
will print the forward and backward graphs of the function ``fn``
>>> fn = lambda x : x.sin().cos()
>>> def print_compile_fn(fx_module, args):
>>> print(fx_module)
>>> return fx_module
>>> aot_fn = aot_function(fn, print_compile_fn)
>>> x = torch.randn(4, 5, requires_grad=True)
>>> aot_fn(x)
"""
if bw_compiler is None:
bw_compiler = fw_compiler
if inference_compiler is None:
inference_compiler = fw_compiler
aot_config = AOTConfig(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
inference_compiler=inference_compiler,
partition_fn=partition_fn,
decompositions=decompositions,
num_params_buffers=num_params_buffers,
aot_id=next(AOT_COUNTER),
keep_inference_input_mutations=keep_inference_input_mutations,
dynamic_shapes=dynamic,
aot_autograd_arg_pos_to_source=None,
is_export=False,
no_tangents=False,
enable_log=enable_log,
)
cached_res = None
@wraps(fn)
def returned_function(*args, **kwargs):
nonlocal cached_res
# Now flatten the tensor args
flat_args = pytree.arg_tree_leaves(*args, **kwargs)
# Compile the function and save it in the cache
if cached_res is None:
flat_fn, out_spec = create_tree_flattened_fn(fn, args, kwargs)
compiled_fn = create_aot_dispatcher_function(
flat_fn,
flat_args,
aot_config,
)
cached_res = (compiled_fn, out_spec)
cached_fn, out_spec = cached_res
out = cached_fn(flat_args)
return out_spec.unflatten(out)
return returned_function
def aot_module(mod: nn.Module, *args, **kwargs) -> nn.Module:
"""
Traces the forward and backward graph of :attr:`mod` using torch dispatch
    tracing mechanism. It is a wrapper function that, under the hood, uses
:func:`aot_function` to perform tracing and compilation.
:func:`aot_module` lifts the parameters and buffers of ``nn.Module`` as inputs
to a new callable which is then compiled through :func:`aot_function`.
.. warning::
This API is experimental and likely to change.
Args:
mod (Callable): A ``nn.Module`` module.
args : args to be passed to :func:`aot_function`
kwargs : kwargs to be passed to :func:`aot_function`
Returns:
Returns a ``nn.Module`` that retains the eager behavior of the original
:attr:`mod`, but with forward and backward graph compiled.
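    A simple example, mirroring the :func:`aot_function` example above (the compiler
    here is just an illustrative pass-through):
    >>> mod = nn.Linear(4, 5)
    >>> def print_compile_fn(fx_module, args):
    >>>     print(fx_module)
    >>>     return fx_module
    >>> aot_mod = aot_module(mod, print_compile_fn)
    >>> out = aot_mod(torch.randn(3, 4, requires_grad=True))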
"""
# See Note: [Fake Modules and AOTAutograd]
torch._dynamo.utils.assert_no_fake_params_or_buffers(mod)
def functional_call(named_params, named_buffers, *args, **kwargs):
params_and_buffers = {**named_params, **named_buffers}
return torch.func.functional_call(mod, params_and_buffers, args, kwargs)
named_params = dict(mod.named_parameters(remove_duplicate=False))
named_buffers = dict(mod.named_buffers(remove_duplicate=False))
num_params_buffers = len(named_params) + len(named_buffers)
compiled_f = aot_function(
functional_call, *args, num_params_buffers=num_params_buffers, **kwargs
)
class AOTModule(nn.Module):
def __init__(self):
super().__init__()
self.orig_module = mod
def forward(self, *args, **kwargs):
return compiled_f(
named_params,
named_buffers,
*args,
**kwargs,
)
return AOTModule()
def aot_module_simplified(
mod: nn.Module,
args,
fw_compiler: Callable,
bw_compiler: Optional[Callable] = None,
partition_fn: Callable = default_partition,
decompositions: Optional[Dict] = None,
keep_inference_input_mutations=False,
inference_compiler: Optional[Callable] = None,
) -> nn.Module:
"""
This is the simplified or low overhead version of aot_module. For frontends
like TorchDynamo, the input functions/modules to AOT are static and have
unpacked inputs/outputs. This gives us an opportunity to remove the
(1) pytree overhead to parse inputs/outputs,
(2) AOT Autograd cache,
(3) Reading of params/buffers in every forward call
:func:`aot_module_simplified` removes these overheads.
"""
params = {
**dict(mod.named_parameters(remove_duplicate=False)),
**dict(mod.named_buffers(remove_duplicate=False)),
}
params_flat, params_spec = pytree.tree_flatten(params)
params_flat = list(params_flat)
params_len = len(params_flat)
functional_call = create_functional_call(mod, params_spec, params_len)
if bw_compiler is None:
bw_compiler = fw_compiler
if inference_compiler is None:
inference_compiler = fw_compiler
seen_sources = set()
full_args = []
# First, the params
full_args.extend(params_flat)
if torch._guards.TracingContext.get():
torch._guards.TracingContext.get().params_flat = params_flat
aot_autograd_arg_pos_to_source = None
# Then, the params 1:1 mapped sources, if relevant.
if hasattr(mod, "_param_name_to_source"):
aot_autograd_arg_pos_to_source = []
# We now know this came from dynamo, and (1) we care about guards,
# so setting up aot_autograd_arg_pos_to_source for downstream dedup guards
# can now be done safely. (2) Dynamo logic protects the 1:1 sizing below.
for name in params.keys():
assert name in mod._param_name_to_source, f"{name} not found."
source = mod._param_name_to_source[name]
assert source not in seen_sources, source
seen_sources.add(source)
aot_autograd_arg_pos_to_source.append(source)
# Next, the input args
full_args.extend(args)
if hasattr(mod, "graph"):
# Non dynamo entrypoints can get to here...
for i, node in enumerate(mod.graph.nodes):
if node.op == "placeholder":
if hasattr(node, "_dynamo_source"):
# ... but not here!
if aot_autograd_arg_pos_to_source is None:
aot_autograd_arg_pos_to_source = []
source = node._dynamo_source
assert source not in seen_sources, source
seen_sources.add(source)
aot_autograd_arg_pos_to_source.append(source)
if aot_autograd_arg_pos_to_source is not None:
assert len(full_args) == len(aot_autograd_arg_pos_to_source)
dynamic_shapes = False
for x in full_args:
if isinstance(x, FakeTensor):
dynamic_shapes = x.fake_mode.shape_env is not None
break
aot_config = AOTConfig(
fw_compiler=fw_compiler,
bw_compiler=bw_compiler,
inference_compiler=inference_compiler,
partition_fn=partition_fn,
decompositions=decompositions,
num_params_buffers=params_len,
aot_id=next(AOT_COUNTER),
keep_inference_input_mutations=keep_inference_input_mutations,
dynamic_shapes=dynamic_shapes,
aot_autograd_arg_pos_to_source=aot_autograd_arg_pos_to_source,
is_export=False,
no_tangents=False,
)
with compiled_autograd.disable():
compiled_fn = create_aot_dispatcher_function(
functional_call,
full_args,
aot_config,
)
    # TODO: There is something deeply wrong here; compiled_fn runs with
    # the boxed calling convention, but aot_module_simplified somehow
# historically returned a function that was not the boxed calling
# convention. This should get fixed...
def forward(*runtime_args):
full_args = []
full_args.extend(params_flat)
full_args.extend(runtime_args)
return compiled_fn(full_args)
# Just for convenience
forward.zero_grad = mod.zero_grad
forward.named_parameters = mod.named_parameters
forward.named_buffers = mod.named_buffers
return forward
def aot_export_module(
mod: nn.Module,
args,
*,
decompositions: Optional[Dict] = None,
# If true, we'll return a joint forward-backward graph,
# As well as metadata on the loss + gradients in the backward.
trace_joint: bool,
# If trace_joint is True, we expect your module to return a scalar loss.
# Your module can return multiple outputs, so you must specify which output the loss is.
output_loss_index: Optional[int] = None,
) -> Tuple[torch.fx.GraphModule, GraphSignature]:
"""
This function takes in a module, and returns:
(1) an FX graph that can be exported
(2) some metadata about the graph
If `trace_joint=True` we will return a joint graph of the forward + backward.
The traced FX graph will have the following properties compared to the original module:
(1) Inputs and outputs to the module will be pytree-flattened
(2) Parameters and buffers on the module will be lifted into graph inputs,
graph_inputs = (*parameters, *buffers, *user_inputs)
(3) The graph will be fully functionalized
(4) Any input mutations will be converted into additional outputs in the graph,
meaning whoever calls this graph is responsible for applying the mutations
back to the original inputs.
    (5) If trace_joint is True, the graph will return parameter gradients in addition to user outputs.
The graph output will look like:
graph_outputs = (*updated_inputs, *user_outputs, *param_gradients)
There are also several restrictions on what modules can use this API. In particular:
(1) If trace_joint is specified, we expect the loss function to be **fused**
into the module forward. One of the outputs to the forward must be a scalar loss,
which is specified with `output_loss_index`.
All other outputs to the forward are presumed to not require gradients.
(2) This API cannot capture optimizers (although in theory we could build an API for this).
(3) Metadata mutations on params/buffers/inputs are banned.
    (4) Data mutations are banned on anything that requires gradients (e.g. parameters)
(5) If an input is mutated, it is not allowed to alias any other inputs.
(6) Parameters must not be duplicated.
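    A minimal usage sketch (the module and shapes below are illustrative only):
    >>> mod = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    >>> gm, signature = aot_export_module(mod, [torch.randn(2, 4)], trace_joint=False)
    >>> print(gm.graph)  # functionalized graph; params/buffers are lifted to inputs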
"""
named_parameters = dict(mod.named_parameters(remove_duplicate=False))
named_buffers = dict(mod.named_buffers(remove_duplicate=False))
params_and_buffers = {
**dict(named_parameters),
**dict(named_buffers),
}
params_and_buffers_flat, params_spec = pytree.tree_flatten(params_and_buffers)
params_and_buffers_flat = tuple(params_and_buffers_flat)
params_len = len(params_and_buffers_flat)
functional_call = create_functional_call(mod, params_spec, params_len)
num_fw_outs = None
if trace_joint:
# This helper effectively just adds some extra asserts about what the backward will look like:
# Outputs must include a scalar loss, that we compute gradients w.r.t.
# We don't compute gradients w.r.t. anything else: so just in case we detach()
# and other output tensors.
def fn_to_trace(*args):
nonlocal num_fw_outs
out = functional_call(*args)
if output_loss_index is None:
raise RuntimeError("""\
If trace_joint=True, it is required that one of your forward outputs is a scalar loss.
You must specify which output is the loss (by index) with output_loss_index.""")
            if isinstance(out, torch.Tensor):
out = (out,)
if not isinstance(out, (tuple, list)):
raise RuntimeError(f"Expected forward output to be either a tensor or a list/tuple of tensors. found {type(out)}")
for i, o in enumerate(out):
# We only want to create a backward graph w.r.t. the loss that the user passed in.
# This implies that every other output should not require gradients.
# Instead of making this an error (and forcing the user to detach all other outputs
# of their forward),
# we'll automatically detach them here.
if o.requires_grad and i != output_loss_index:
raise RuntimeError(f"""\
Found an output of the forward that requires gradients and is not the scalar loss.
We require every forward output other than the scalar loss to not require gradients,
because we will only compute a backward graph against the scalar loss.
You can fix this by calling .detach() on each of your forward outputs that is not the loss.
You specified that output index {output_loss_index} is the loss, but we found that
the output at index {i} requires gradients.""")
out_loss = out[output_loss_index]
num_fw_outs = len(out)
if not out_loss.requires_grad:
raise RuntimeError(f"""\
The output at index {output_loss_index} was marked as the loss, but it does not require gradients""")
if out_loss.numel() != 1:
raise RuntimeError(f"""\
We require the output marked as the loss (at index {output_loss_index}) to be a scalar, but it has shape {out_loss.shape}""")
return out
ctx = nullcontext
else:
# Run under no_grad, so our tracing machinery only traces an inference graph.
ctx = torch.no_grad
fn_to_trace = functional_call
full_args = []
# First, the params
# NB: It is REQUIRED that parameters come first, Inductor infers "fixed"
# parameters by looking at the difference in parameter count outside
# and inside AOTAutograd, and assumes the prefix of arguments are fixed
# arguments
full_args.extend(params_and_buffers_flat)
# Next, the input args
full_args.extend(args)
with ctx():
fx_g, metadata, in_spec, out_spec = _aot_export_function(
fn_to_trace,
full_args,
decompositions=decompositions,
num_params_buffers=len(params_and_buffers_flat),
no_tangents=True,
)
if trace_joint:
def flattened_joint(*args):
# The idea here is that the joint graph that AOTAutograd creates has some strict properties:
# (1) It accepts two arguments (primals, tangents), and pytree_flattens them
# (2) It returns a tuple of (fw_outs, gradients)
# This is a very useful convention for anyone who wants to partition the joint graph
# into a separate forward and backward graph.
# However,
# (1) for people exporting a single joint graph, it would be preferable not to have
# any pytrees in the graph.
# (2) We are guaranteed in the aot_export_module case that the forward outputs a loss,
# and there are therefore no tangents that are needed to run the joint graph.
            # (3) AOTAutograd creates a grad_input for every input in the forward,
            #     including None's for inputs that are not grad-requiring tensors.
            #     We don't want these in our export graph.
            # This function "fixes" the above by removing any tangent inputs,
# and removing pytrees from the original FX graph.
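            # Rough shape of the rewrite (names are illustrative):
            #   original joint:   joint(primals, tangents) -> (fw_outs, gradients)
            #   flattened_joint:  flattened_joint(*primals) -> (*fw_outs, *nonzero_gradients)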
fake_tangents = [None for _ in range(metadata.num_outputs + metadata.num_mutated_inputs)]
fw_outs, gradients = fx_g(args, fake_tangents)
assert len(gradients) == len(args)
output_gradients = []
for i, (a, grad) in enumerate(zip(args, gradients)):
if isinstance(a, torch.Tensor) and a.requires_grad:
assert grad is not None, """\
Found a parameter that did not receive a gradient.
"This is most likely a bug, but if this needs to be supported please comment on this Github issue:
https://github.com/pytorch/pytorch/issues/101192
"""
output_gradients.append(grad)
else:
assert grad is None
return *fw_outs, *output_gradients
fx_g = make_fx(flattened_joint)(*full_args)
user_args_flat = pytree.arg_tree_leaves(*args)
return fx_g, create_graph_signature(
fx_g,
metadata,
in_spec,
out_spec,
user_args_flat=user_args_flat,
params_and_buffers_flat=params_and_buffers_flat,
param_names=list(named_parameters.keys()),
buffer_names=list(named_buffers.keys()),
trace_joint=trace_joint,
num_user_fw_outs=num_fw_outs,
loss_index=output_loss_index,
)
def aot_export_joint_simple(
func: Callable,
args,
*,
trace_joint: bool,
# It looks like the main consequence of this API is that for dynamic shapes,
    # it will assume that params/buffers are static.
# With the new inferred dynamic shapes API, maybe this doesn't matter?
num_params_buffers: int = 0,
decompositions: Optional[Dict] = None,
) -> torch.fx.GraphModule:
"""
A simplified version of export. Used by higher order operators.
This function makes a high-level "no calling convention changes" guarantee:
    - If no inputs require grad (so we export an inference graph),
      there are *no* calling convention changes between the exported graph and "func".
    - If at least one input requires grad (so we trace out and export a joint fw-bw graph),
      then if you were to partition the graph into separate forward and backward graphs,
      the forward graph will have no calling convention changes compared to "func".
The above also relies on some strong restrictions around which functions this API accepts:
(1) `args` cannot contain any pytrees (they must have been pytree_flattened already)
(2) `func` cannot mutate any inputs
(3) The outputs of `func` cannot alias any inputs.
Note: this function is only lightly tested today. It will probably be tested more heavily by higher order ops.
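    A minimal usage sketch (the function and inputs below are illustrative only):
    >>> def f(x, w):
    >>>     return x @ w
    >>> inps = [torch.randn(2, 3), torch.randn(3, 3, requires_grad=True)]
    >>> joint_gm = aot_export_joint_simple(f, inps, trace_joint=True)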
"""
if trace_joint:
ctx = nullcontext
else:
# Run under no_grad, so our tracing machinery only traces an inference graph.
ctx = torch.no_grad
with ctx():
fx_g, metadata, in_spec, out_spec = _aot_export_function(
func,
args,
            decompositions=decompositions,
            num_params_buffers=num_params_buffers,
)
# At this point, we can just directly return the (joint or inference graph) that we traced.
# First though: a bunch of assertions to make sure that our graph doesn't require
# any calling convention changes compared to the original function.
# These restrictions are *in addition to* the general restrictions on export.
# No input mutations
if len([x for x in metadata.input_info if x.mutates_data or x.mutates_metadata]) != 0:
raise RuntimeError(f"aot_export_joint_simple does not support input mutations. {str(metadata)}")
# No output aliasing
if len([x for x in metadata.output_info if x.output_type != OutputType.non_alias]) != 0:
raise RuntimeError(f"aot_export_joint_simple does not support outputs that alias inputs. {str(metadata)}")
# No pytrees
if type(in_spec) == pytree.LeafSpec:
raise RuntimeError(f"aot_export_joint_simple requires inputs to be a single list/tuple. in_spec={str(in_spec)}")
if len([x for x in in_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
raise RuntimeError(f"aot_export_joint_simple requires individual inputs not to be pytrees. in_spec={str(in_spec)}")
if type(out_spec) == pytree.LeafSpec:
raise RuntimeError(f"aot_export_joint_simple requires outputs to be a single list/tuple. out_spec={str(out_spec)}")
if len([x for x in out_spec.children_specs if type(x) != pytree.LeafSpec]) != 0:
raise RuntimeError(f"aot_export_joint_simple requires individual outputs not to be pytrees. out_spec={str(out_spec)}")
# TODO: we might have to temporarily patch config.functionalize_rng
# so that it doesn't run when we're exporting a higher order op.
if config.debug_assert:
# Smoke test that after partitioning, we can run the forward without any calling convention changes.
        fw_module, bw_module = default_partition(
            fx_g, args, num_fwd_outputs=len(metadata.output_info)
)
# Attempt to run the fw_module with the original user inputs
fake_mode = detect_fake_mode(args)
if fake_mode is None:
fake_mode = FakeTensorMode()
with fake_mode:
fw_module(*args)
return fx_g
# Private for now because we aren't providing a contract on what to return
# for joint graphs (we could when there's a clearer use case)
# In the future, we may need to add more export API's that provide their own strong guarantees.
# This is meant as a general helper function for handling various export-y use cases.
def _aot_export_function(
func: Callable,
args,
*,
num_params_buffers: int = 0,
decompositions: Optional[Dict] = None,
# If we're exporting a joint graph and we don't want any tangent inputs in the graph
# (because we are backpropping through a scalar 1 loss),
# we need to explicitly specify not to include tangents in the graph.
# It's not enough just to check that our tangent is a scalar, since we also
# need to know if it is a 1 (no need to make it a graph input), or something else
# (requiring it to be a graph input).
# We don't know this info at trace time though, so we need to make it an explicit config.
no_tangents: bool = False,
) -> Tuple[torch.fx.GraphModule, ViewAndMutationMeta, pytree.TreeSpec, pytree.TreeSpec]:
dynamic_shapes = False
for x in args:
if isinstance(x, FakeTensor):
dynamic_shapes = x.fake_mode.shape_env is not None
break
flat_fn, out_spec = create_tree_flattened_fn(func, args)
flat_args, in_spec = pytree.tree_flatten(args)
# The export use case doesn't care about several bits of AOTConfig
# (1) compilers (we just export the graph)
# (2) partitioners (export is only full graph, user can partition themselves)
aot_config = AOTConfig(
fw_compiler=None,
bw_compiler=None,
inference_compiler=None,
partition_fn=None,
decompositions=decompositions,
num_params_buffers=num_params_buffers,
aot_id=next(AOT_COUNTER),
# For now there's no use case involving keeping input mutations in the graph
# (which we can only do in the inference case anyway).
# We can add this later if we need to.
keep_inference_input_mutations=False,
dynamic_shapes=dynamic_shapes,
aot_autograd_arg_pos_to_source=None,
is_export=True,
no_tangents=no_tangents,
)
fx_g, meta = create_aot_dispatcher_function(
flat_fn,
flat_args,
aot_config,
)
return fx_g, meta, in_spec, out_spec.spec
compiled_function = aot_function
compiled_module = aot_module