pytorch/torch/_dynamo/variables/functions.py
Guilherme Leobas 80cf0ce153 Enhance torch.vmap support from inside torch.compile (#116050)
This work rewrites vmap support in torch.compile by inlining most of
the frames into the existing FX graph. It also enables PyTorch to
support features that were previously missing, such as keyword args.

Fixes: https://github.com/pytorch/pytorch/issues/114306

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116050
Approved by: https://github.com/zou3519
2024-01-22 17:53:45 +00:00

823 lines
30 KiB
Python

import functools
import inspect
import itertools
import types
from typing import Dict, List, Optional, TYPE_CHECKING, Union
import torch
from .. import variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import get_first_attr, make_cell
from .base import typestr, VariableTracker
if TYPE_CHECKING:
    # Only imported for static type checkers (TYPE_CHECKING is False at
    # runtime); `Source` is used in an annotation in this module.
    from torch._guards import Source
def wrap_bound_arg(tx, val, source=None):
    """Wrap ``val`` as a VariableTracker suitable for use as a bound argument.

    VariableTrackers are returned unchanged. Otherwise the value is wrapped
    with a ``VariableBuilder`` when ``source`` is truthy, or with a
    ``SourcelessBuilder`` when it is not. Source propagation is best effort,
    since not every object we encounter has a source to begin with.
    """
    if isinstance(val, VariableTracker):
        return val
    if source:
        from torch._dynamo.variables.builder import VariableBuilder

        builder = VariableBuilder(tx, source=source)
        return builder(val)
    from torch._dynamo.variables.builder import SourcelessBuilder

    return SourcelessBuilder()(tx, val)
def wrap_args_kwargs(tx, result):
    """Wrap, in place, every tuple or dict value of ``result``.

    ``result`` maps bound parameter names to values; tuple values come from
    ``*args`` and dict values from ``**kwargs``. Each is replaced by the
    VariableTracker produced by ``wrap_bound_arg``. Returns ``None``.
    """
    for param_name, bound_value in list(result.items()):
        if not isinstance(bound_value, (tuple, dict)):
            continue
        result[param_name] = wrap_bound_arg(tx, bound_value)
def init_cellvars(parent, result, code):
    """Allocate a new tracked cell for every cellvar of ``code``.

    When a cellvar is also a bound argument in ``result``, that argument is
    removed from ``result`` and stored into the new cell instead.

    Returns a dict mapping each cellvar name to its tracked cell.
    """
    side_effects = parent.output.side_effects
    cells = {}
    for cell_name in code.co_cellvars:
        new_cell = side_effects.track_cell_new()
        cells[cell_name] = new_cell
        if cell_name in result:
            side_effects.store_cell(new_cell, result.pop(cell_name))
    return cells
def _create_nested_fn(
code, f_globals, name, defaults, closure, kwdefaults, annotations
):
from types import FunctionType
func = FunctionType(code, f_globals, name, defaults, closure)
func.__kwdefaults__ = kwdefaults
if isinstance(annotations, tuple):
from itertools import pairwise
annotations = dict(pairwise(annotations))
# TypeError: __annotations__ must be set to a dict object
assert annotations is None or isinstance(annotations, dict)
func.__annotations__ = annotations
return func
class BaseUserFunctionVariable(VariableTracker):
    """Common base for VariableTrackers that model user-level Python functions.

    Subclasses supply ``get_code``, ``get_function`` and ``self_args``.
    """

    def get_filename(self):
        code = self.get_code()
        return code.co_filename

    def get_name(self):
        code = self.get_code()
        return code.co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Inline the call, prepending any implicit leading arguments
        # (``self_args()``) to the explicit positional arguments.
        all_args = [*self.self_args(), *args]
        return tx.inline_user_function_return(self, all_args, kwargs)

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        try:
            found = hasattr(self.get_function(), name)
        except NotImplementedError:
            # get_function() could not materialize a real function object
            # (NestedUserFunctionVariable raises this when it has a closure);
            # such a function is still known to have a __name__.
            found = name == "__name__" and isinstance(
                self, NestedUserFunctionVariable
            )
        return variables.ConstantVariable.create(found)

    def inspect_parameter_names(self):
        signature = inspect.signature(self.get_function())
        return list(signature.parameters)

    def closure_vars(self, tx):
        return {}
class UserFunctionVariable(BaseUserFunctionVariable):
    """VariableTracker for a user-defined Python function.

    Wraps a ``types.FunctionType`` or ``torch.jit.ScriptFunction``. Calls are
    inlined, unless the function is marked constant via
    ``_dynamo_marked_constant``, in which case it is invoked eagerly and its
    result stored (see ``invoke_and_store_as_constant``).
    """

    @classmethod
    def create_with_source(cls, value, source):
        return cls(
            value,
            source=source,
        )

    def __init__(self, fn, is_constant=False, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): the `is_constant` argument is accepted but ignored; the
        # flag is always recomputed from `_dynamo_marked_constant` on `fn`
        # (checked before the unwrapping below).
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of compilation
            self.is_constant = True
        else:
            self.is_constant = False

        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def self_args(self):
        # A plain function has no implicit leading arguments.
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        """Bind call arguments to ``self.fn``'s signature for inlining.

        Returns ``(result, closure_cells)``:
          * ``result`` maps parameter names (with defaults applied) and
            freevar names to VariableTrackers.
          * ``closure_cells`` maps the function's cellvar names to newly
            tracked cells (from ``init_cellvars``).
        """
        assert not self.is_constant
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        # One DefaultsSource per positional default when this function has a
        # source, so defaults are wrapped with guards/reconstruction info.
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        # A stand-in function with the same code/signature whose defaults are
        # already VariableTrackers, so inspect.signature().bind() fills in
        # wrapped defaults directly.
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(tx, result)
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        # Resolve every freevar to a VariableTracker (or a tracked cell).
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            if name == "__class__":
                # NOTE(review): AttrSource(self.source, "__class__") names
                # `fn.__class__` (the function's type), while the value wrapped
                # here is the closure cell's contents — confirm this source is
                # intended.
                source = AttrSource(self.source, "__class__") if self.source else None
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        # Cell is already tracked; reuse its VariableTracker.
                        out = side_effects[cell]
                    else:
                        # Source for fn.__closure__[idx] and its .cell_contents.
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        try:
                            contents_var = VariableBuilder(
                                parent, closure_cell_contents
                            )(cell.cell_contents)
                        except ValueError:
                            # Cell has not yet been assigned
                            contents_var = variables.DeletedVariable()

                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects.  This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects.  If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue

                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )

                    result[name] = out

                else:
                    # No source: wrap the cell contents without guards.
                    from .builder import SourcelessBuilder

                    result[name] = SourcelessBuilder()(tx, cell.cell_contents)

        return result, closure_cells

    def export_freevars(self, parent, child):
        # Nothing to export for a regular (non-nested) function.
        pass

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        result = hasattr(self.fn, name)
        return variables.ConstantVariable.create(result)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Constant-marked functions are run eagerly instead of being inlined.
        if self.is_constant:
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), args, kwargs
            )

        return super().call_function(tx, args, kwargs)
class UserMethodVariable(UserFunctionVariable):
    """A user-defined function bound to a receiver ``obj``.

    ``obj`` is passed as the implicit first argument when the method is
    inlined.
    """

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

    def __str__(self):
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Methods of nn.Modules defined under torch.nn (and constant-marked
        # methods) are routed to NNModuleVariable.call_method rather than
        # inlined. For example, `forward` of allowed modules becomes a
        # `call_method` op in the FX graph, since AOT can trace those safely.
        # This also covers user modules that subclass an allowed module but
        # keep its `forward`.
        #
        # The redirection only happens under the root tracer. When tracing a
        # higher order op, Dynamo must step inside the module call so that it
        # can see the parameters/buffers and lift them as graph inputs. Falling
        # through to super().call_function is equivalent to inlining.
        if tx.output.is_root_tracer() and isinstance(
            self.obj, variables.NNModuleVariable
        ):
            module_attr = getattr(self.fn, "__module__", "")
            defined_in_torch_nn = module_attr is not None and module_attr.startswith(
                "torch.nn."
            )
            if defined_in_torch_nn or self.is_constant:
                return self.obj.call_method(
                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
                )
        return super().call_function(tx, args, kwargs)

    def inspect_parameter_names(self):
        # The first parameter is the bound receiver, which callers never pass.
        names = super().inspect_parameter_names()
        return names[1:]
class WrappedUserMethodVariable(UserMethodVariable):
    """A UserMethodVariable whose calls are traced inside ``context``.

    ``context.enter(tx)`` runs before the method is traced, and
    ``context.exit(tx)`` runs after the trace returns normally.
    """

    def __init__(self, wrapped, context, **kwargs):
        # fn and obj always come from ``wrapped``. Drop any fn/obj entries from
        # kwargs so they cannot collide with the positional arguments below.
        for stale_key in ("fn", "obj"):
            kwargs.pop(stale_key, None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        outcome = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return outcome
class WrappedUserFunctionVariable(UserFunctionVariable):
    """A UserFunctionVariable whose calls are traced inside ``context``.

    ``context.enter(tx)`` runs before the function is traced, and
    ``context.exit(tx)`` runs after the trace returns normally.
    """

    def __init__(self, wrapped, context, **kwargs):
        # Only ``wrapped.fn`` is forwarded. Discard any fn/obj entries that
        # arrived in kwargs.
        for stale_key in ("fn", "obj"):
            kwargs.pop(stale_key, None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        outcome = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return outcome
def invoke_and_store_as_constant(tx, fn, name, args, kwargs):
    """Run ``fn`` eagerly on concrete values and register the result.

    Tensor arguments are materialized with ``get_real_value()``. All other
    arguments must be Python constants. The result is registered on the
    output graph under ``name`` with a ``ConstantSource``.
    """

    def to_python(var):
        if isinstance(var, variables.TensorVariable):
            return var.get_real_value()
        return var.as_python_constant()

    real_args = list(map(to_python, args))
    real_kwargs = {key: to_python(value) for key, value in kwargs.items()}
    outcome = fn(*real_args, **real_kwargs)
    return tx.output.register_attr_or_module(
        outcome,
        name,
        source=ConstantSource(name),
    )
class NestedUserFunctionVariable(BaseUserFunctionVariable):
    """VariableTracker for a function whose parts are themselves tracked.

    The name, code, defaults, kwdefaults, annotations and closure are held as
    VariableTrackers. This is presumably a function defined while tracing
    (e.g. via MAKE_FUNCTION) — confirm against the creation site.
    ``reconstruct`` rebuilds it at runtime through ``_create_nested_fn``.
    """

    _nonvar_fields = {
        "closure_scope",
        "f_globals",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        wrapped_reconstructible=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        # A closure scope is meaningless without a closure.
        if closure is None:
            closure_scope = None
        self.closure_scope = closure_scope
        # Either a source or a VT with .can_reconstruct() == True
        self.wrapped_reconstructible: Optional[
            Union[Source, VariableTracker]
        ] = wrapped_reconstructible

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        """Materialize a real function object.

        Raises:
            NotImplementedError: if the function has a closure, since the
                tracked cells cannot be turned into real cells here.
        """
        if self.closure:
            raise NotImplementedError()
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                # The tuple is flat: (k0, v0, k1, v1, ...). Pair even entries
                # with odd entries. (itertools.pairwise would yield
                # overlapping pairs (k0, v0), (v0, k1), ... and inject bogus
                # keys.)
                annotations = dict(zip(annotations[::2], annotations[1::2]))

            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        """Bind call arguments to this function's signature for inlining.

        Returns ``(result, closure_cells)``:
          * ``result`` maps parameter names (defaults applied) to
            VariableTrackers. It also maps freevars backed by an
            ``InlinedClosureVariable`` to the value found in an enclosing
            translator's ``symbolic_locals``.
          * ``closure_cells`` maps cellvar names to freshly tracked cells
            and other freevar names to their closure cell VariableTrackers.

        Raises:
            RuntimeError: if an inlined closure variable cannot be found in
                any enclosing translator's symbolic_locals.
        """
        from .misc import InlinedClosureVariable

        code = self.get_code()
        # Stand-in function used only to compute the argument binding; its
        # closure cells are placeholders.
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(code.co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant()
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())
        wrap_args_kwargs(parent.output.root_tx, result)
        closure_cells = init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            cell = self.closure.items[idx]
            # Closure entries that carry a `.name` must match the freevar they
            # fill. (Previously `getattr(cell, name, name)` looked up an
            # attribute named after the freevar, making this check vacuous.)
            assert getattr(cell, "name", name) == name
            assert name not in result
            if isinstance(cell, InlinedClosureVariable):
                # InlinedClosureVariable's are created from LOAD_CLOSURE's from
                # InliningInstructionTranslators when the variable name is not found in closure_cells.
                # They should remain outside of closure_cells, so that our callee (the
                # InliningInstructionTranslator that traces `func`) handles
                # the cell correctly - that is, the cell's contents are treated as if they
                # are local variables, like in UserFunctionVariable's bind_args for freevars.
                cand = parent
                while cand and name not in cand.symbolic_locals:
                    cand = cand.parent
                if cand is None:
                    raise RuntimeError(
                        f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack"
                    )
                result[name] = cand.symbolic_locals[name]
            else:
                closure_cells[name] = cell

        return result, closure_cells

    def export_freevars(self, parent, child):
        # Propagate freevar values the child translator holds as locals back
        # into the parent's symbolic_locals.
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        # Emit: _create_nested_fn(code, f_globals, name, defaults, closure,
        #                         kwdefaults, annotations)
        # and, if wrapped_reconstructible is set, wrap the result with
        # functools.wraps(wrapped_reconstructible)(fn).
        codegen.load_import_from(__name__, "_create_nested_fn")
        codegen(self.code)
        codegen.extend_output([codegen._create_load_const(self.f_globals)])
        codegen(self.fn_name)

        if self.defaults:
            codegen(self.defaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.closure:
            codegen(self.closure)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.kwdefaults:
            codegen(self.kwdefaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.annotations:
            try:
                annotations = self.annotations.as_python_constant()
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                codegen(self.annotations)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        codegen.extend_output(create_call_function(7, push_null=True))

        if self.wrapped_reconstructible:
            codegen.load_import_from("functools", "wraps")
            codegen(self.wrapped_reconstructible)
            codegen.extend_output(create_call_function(1, True))
            codegen.extend_output(create_rot_n(2))
            codegen.extend_output(create_call_function(1, True))

        return []
def _traceable_collective_remaps():
# We can't rely on importing from distributed, since it's not always built
if torch.distributed.is_available():
from torch.distributed._functional_collectives import (
traceable_collective_remaps,
)
return traceable_collective_remaps
return {}
def _traceable_collectives_source(tx, fn):
    """Return a Source for ``fn`` as an attribute of the imported
    ``torch.distributed._functional_collectives`` module.

    ``fn`` must be one of the traceable replacements returned by
    ``_traceable_collective_remaps()``.
    """
    assert torch.distributed.is_available(), "Illegal invocation."
    assert fn in _traceable_collective_remaps().values()

    attr_name = fn.__name__
    module_source = tx.import_source("torch.distributed._functional_collectives")
    return AttrSource(module_source, attr_name)
class CollectiveFunctionRewriteVariable(UserFunctionVariable):
    """
    Rewrites certain torch.distributed.* collective APIs to their 'traceable'
    counterparts.

    The class can check whether a function is remappable (``can_rewrite``) and
    perform the remapping (``rewrite`` / ``create``). Some functions are only
    remappable for certain call-time arguments, so those arguments are checked
    in ``call_function``, which graph-breaks when needed. This is no worse
    than the status quo, since all distributed.* collectives used to
    graph-break.
    """

    def __init__(self, fn, *, replacement_var, **kwargs):
        super().__init__(fn, **kwargs)
        assert isinstance(replacement_var, UserFunctionVariable)
        self.replacement_var = replacement_var

    @staticmethod
    def create(tx, old_fn, source, **options):
        new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
        replacement = UserFunctionVariable(new_fn, source=new_source, **options)
        return CollectiveFunctionRewriteVariable(
            old_fn,
            replacement_var=replacement,
            source=source,
            **options,
        )

    @staticmethod
    def can_rewrite(variable):
        if not inspect.isfunction(variable):
            return False
        return variable in _traceable_collective_remaps()

    @staticmethod
    def rewrite(tx, fn):
        replacement_fn = _traceable_collective_remaps()[fn]
        return replacement_fn, _traceable_collectives_source(tx, replacement_fn)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Unsupported arguments must graph-break here. Forwarding args/kwargs
        # unchanged is safe because entries in `traceable_collective_remaps`
        # are required to accept the same args/kwargs as the original.
        if "async_op" in kwargs and kwargs["async_op"].as_python_constant():
            unimplemented(
                f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
            )

        return self.replacement_var.call_function(tx, args, kwargs)
class FunctoolsPartialVariable(VariableTracker):
    """Tracks a ``functools.partial`` of ``func``.

    ``args`` (a list) and ``keywords`` (a dict) hold the pre-bound
    arguments. ``original``, when truthy, is returned as-is by
    ``as_python_constant``.
    """

    def __init__(self, func: VariableTracker, args, keywords, original=None, **kwargs):
        super().__init__(**kwargs)
        self.func = func
        assert isinstance(args, list)
        self.args = args
        assert isinstance(keywords, dict)
        self.keywords = keywords
        self.original = original

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Pre-bound positionals come first; call-site keywords override
        # pre-bound ones of the same name.
        all_args = self.args + args
        all_kwargs = {**self.keywords, **kwargs}
        return self.func.call_function(tx, all_args, all_kwargs)

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        from .constant import ConstantVariable

        # Rebuild the partial around the bare function, without the bound
        # arguments. This works because PyTorch does not allow mutating the
        # partial variable.
        probe = functools.partial(self.func.get_function())
        return ConstantVariable.create(hasattr(probe, name))

    def as_python_constant(self):
        if self.original:
            return self.original

        def get_val(v):
            # User-defined objects are unwrapped to their underlying value;
            # everything else must be a Python constant.
            if isinstance(v, variables.UserDefinedObjectVariable):
                return v.value
            return v.as_python_constant()

        target = self.func.get_function()
        positional = [get_val(arg) for arg in self.args]
        keyword = {k: get_val(v) for k, v in self.keywords.items()}
        return functools.partial(target, *positional, **keyword)
class TritonKernelVariable(VariableTracker):
    """Tracks a user-defined Triton kernel, optionally wrapped in an
    ``Autotuner``.

    The kernel is registered in ``kernel_side_table``. Once it is launched
    with a grid, the launch is recorded as a
    ``triton_kernel_wrapper_mutation`` call in the FX graph.
    """

    def __init__(self, kernel, kernel_idx, grid, **kwargs):
        from triton.runtime.autotuner import Autotuner

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        super().__init__(**kwargs)

        assert kernel is not None

        self.kernel = kernel
        self.kernel_idx = kernel_side_table.add_kernel(kernel)

        assert kernel_idx is None or self.kernel_idx == kernel_idx

        self.grid = grid

        if isinstance(kernel, Autotuner):
            # We only support configs and keys arguments of triton.autotune
            # Make sure other arguments are defaulted
            defaults = inspect.signature(Autotuner.__init__).parameters

            # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
            # The call to get_first_attr is to maintain backward-compatibility.
            if (
                (
                    "warmup" in defaults
                    and defaults["warmup"].default
                    != get_first_attr(kernel, "num_warmups", "warmup")
                )
                or (
                    "rep" in defaults
                    and defaults["rep"].default
                    != get_first_attr(kernel, "num_reps", "rep")
                )
                or (
                    "prune_configs_by" in defaults
                    and defaults["prune_configs_by"].default
                    != kernel.early_config_prune
                )
            ):
                raise Unsupported(
                    "Only configs and keys are supported for triton.autotune"
                )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Record a launch of the kernel on ``self.grid`` in the FX graph.

        Raises:
            Unsupported: if there is no grid, or the grid is not a tuple of
                rank 1 to 3.
        """
        from triton.runtime.autotuner import Autotuner

        from .constant import ConstantVariable
        from .dicts import ConstDictVariable
        from .lists import BaseListVariable

        if self.grid is None:
            raise Unsupported("Triton kernels should always be called with a grid")

        # Both for grid's meta as well as for the kernel, we need combined
        # args and kwargs normalized
        names = (
            variables.ConstantVariable.create(name) for name in self.kernel.arg_names
        )
        kwargs = {variables.ConstantVariable.create(k): v for k, v in kwargs.items()}
        normalized_args = {**dict(zip(names, args)), **kwargs}

        # One grid per autotuner config (a single, empty config otherwise).
        configs = (
            [config.kwargs for config in self.kernel.configs]
            if isinstance(self.kernel, Autotuner)
            else [{}]
        )
        grids = []
        for config_args in configs:
            # If the grid is a function, then lets execute it and convert it to
            # a list
            grid = self.grid
            if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
                # Populate the special "meta" argument to call the grid function
                config_args = {
                    ConstantVariable.create(k): ConstantVariable.create(v)
                    for k, v in config_args.items()
                }
                meta = ConstDictVariable({**normalized_args, **config_args}, dict)
                grid = grid.call_function(tx, [meta], {})

            # Now, the grid must be a list either originally or through above
            # modification
            if isinstance(grid, BaseListVariable):
                grids.append(grid.as_proxy())
            else:
                unimplemented(f"grid for the triton kernel is {type(grid)}")

        # inductor expects all grids to be 3-tuples, so pad missing trailing
        # dimensions with 1. Reject grids that cannot be normalized, including
        # the empty grid, which previously slipped through unchanged.
        normalized_grids = []
        for grid_tuple in grids:
            if not isinstance(grid_tuple, tuple):
                raise Unsupported("Only tuple grids are supported")
            if len(grid_tuple) == 0:
                raise Unsupported("Grid must have at least rank 1")
            if len(grid_tuple) > 3:
                raise Unsupported("Grid can have at most rank 3")
            normalized_grids.append(grid_tuple + (1,) * (3 - len(grid_tuple)))
        grids = normalized_grids

        assert len(grids) != 0
        if len(set(grids)) == 1:
            # If there's only one unique grid, lets simplify
            grids = [grids[0]]

        from torch._higher_order_ops.triton_kernel_wrap import (
            triton_kernel_wrapper_mutation,
        )

        # Combine args and kwargs and pass as a dict so that if user defined triton
        # kernel uses variables as 'grid' or 'kernel', it does not conflict with
        # parameters of the wrapper function
        meta = ConstDictVariable(normalized_args, dict)
        tx.output.create_proxy(
            "call_function",
            triton_kernel_wrapper_mutation,
            (),
            {
                "kernel_idx": self.kernel_idx,
                "grid": grids,
                "kwargs": meta.as_proxy(),
            },
        )

        # The launch itself returns None.
        return variables.ConstantVariable.create(None)

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``kernel[grid]`` and ``kernel.run(..., grid=grid)``.

        Other method names are delegated to the parent class.
        """
        if name == "__getitem__":
            # __getitem__ should only be called if we don't already have a grid
            # Only grid needs to be passed
            if self.grid is not None or len(args) != 1:
                raise Unsupported(
                    "Triton kernels should be called with only a single grid"
                )

            return TritonKernelVariable(
                kernel=self.kernel,
                kernel_idx=self.kernel_idx,
                grid=args[0],
            )
        elif name == "run":
            if "grid" not in kwargs:
                raise Unsupported("Triton kernel requires to be called with a grid")
            # Work on a copy so the caller's kwargs dict is not mutated.
            kwargs = dict(kwargs)
            grid = kwargs.pop("grid")
            kwargs.pop("warmup", None)
            # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args)
            return TritonKernelVariable(
                kernel=self.kernel, kernel_idx=self.kernel_idx, grid=grid
            ).call_function(tx, args, kwargs)

        # Bail out to parent's implementation
        return super().call_method(tx, name, args, kwargs)