mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
The main thrust of the initial effort here was to capture `register_hook` calls on tensors in compile regions. The first part of this was done in https://github.com/pytorch/pytorch/pull/108903 wherein we added support for `register_hook` on input tensors. The distinction between input and intermediary is due to implementation differences. There are 2 kinds of hooks: 1) Hooks on objects with sources (inputs, params) 2) Hooks on objects w/o sources (intermediaries and outputs). Note: As outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs, but, for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced), or hooks on intermediaries (not sourced). **The plan:** For tensors w/ a source: (The PR above) We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2-modified bytecode with the original eager code, we call register_hook. This registration of hooks in residuals is sound because (a) it happens right after a PT2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the user's invoking frame. This means we can soundly know it will be around to invoke register_hook on. As long as we guard on the identity of the lifted function, this is sound to do. For tensors w/o a source: (This PR) Ostensibly, the most correct and complete solution would be to smuggle hooks into a runtime wrapper in aot_autograd, where all the items the hooks close over are lifted to inputs as necessary and passed alongside the user-provided function. This is necessary so that we can properly trace out and capture all the mutations within the user-defined hook at backwards time.
This is too complicated, so we limited the scope of this initial PR to a simple subset of hooks: - Hooks must have a source (be known to us already, not a lambda or intermediary-defined function) - We must be tracing under compiled autograd **The flow**: We use the higher-order op (HOP) added in https://github.com/pytorch/pytorch/pull/109690/files, referred to as the HOP below. 1) We intercept register_hook calls and wrap the user-defined fn in the HOP 2) We write a `_register_hook_trampoline` to the graph that is a local no-arg function that is invoked as a call_function in the dynamo graph 3) aot_autograd inlines through it during its trace, and sees the HOP 4) the HOP preserves itself in the graph - it does not get traced into 5) During backwards, compiled_autograd installs the HOP under a hook call 6) When compiled_autograd enters compilation over its generated graph, dynamo traces the contents of the hook Pull Request resolved: https://github.com/pytorch/pytorch/pull/109537 Approved by: https://github.com/ezyang
742 lines
26 KiB
Python
742 lines
26 KiB
Python
import functools
|
|
import inspect
|
|
import itertools
|
|
import types
|
|
from typing import Dict, List
|
|
|
|
import torch
|
|
|
|
from .. import variables
|
|
from ..bytecode_transformation import create_call_function, create_rot_n
|
|
from ..exc import unimplemented, Unsupported
|
|
from ..source import (
|
|
AttrSource,
|
|
ConstantSource,
|
|
DefaultsSource,
|
|
GetItemSource,
|
|
GlobalSource,
|
|
)
|
|
from ..utils import make_cell
|
|
from .base import typestr, VariableTracker
|
|
|
|
|
|
def wrap_bound_arg(tx, val, options, source=None):
    """Return ``val`` as a VariableTracker that carries ``options``.

    A value that is already a VariableTracker is returned unchanged. A raw
    Python value is wrapped with ``VariableBuilder`` when a truthy ``source``
    is given, and with ``SourcelessBuilder`` otherwise. Source propagation is
    best effort, because not every object we encounter has a source.

    ``options`` must not contain a ``"source"`` key. As the assertion message
    says, the source is kept separate because of recursive calls for
    lists/dicts.
    """
    assert (
        "source" not in options
    ), "Source needs to be separate from options due to recursive calls for lists/dicts"

    if isinstance(val, VariableTracker):
        return val

    if source:
        from torch._dynamo.variables.builder import VariableBuilder

        return VariableBuilder(tx, source=source)(val).add_options(options)

    from torch._dynamo.variables.builder import SourcelessBuilder

    return SourcelessBuilder()(tx, val).add_options(options)
|
|
|
|
|
|
def wrap_args_kwargs(tx, result, options):
    """Wrap, in place, every tuple/dict value of the ``result`` mapping.

    ``result`` is the name -> value mapping produced by binding a call's
    arguments. Bound arguments are already VariableTrackers. The raw tuple
    and dict values are the packed ``*args`` / ``**kwargs`` containers that
    ``inspect.BoundArguments`` creates, and they are converted through
    ``wrap_bound_arg`` (without a source).
    """
    packed_names = [
        name for name, value in result.items() if isinstance(value, (tuple, dict))
    ]
    for name in packed_names:
        result[name] = wrap_bound_arg(tx, result[name], options)
|
|
|
|
|
|
def init_cellvars(parent, result, code):
    """Create a tracked cell for every cell variable of ``code``.

    Each name in ``code.co_cellvars`` gets a fresh cell from
    ``parent.output.side_effects.track_cell_new()``. If that name is also
    present in ``result`` (e.g. a parameter captured by an inner function),
    the value is removed from ``result`` and stored into the new cell instead.

    Returns a dict mapping cell variable names to their tracked cells.
    """
    side_effects = parent.output.side_effects
    cells = {}
    for cell_name in code.co_cellvars:
        new_cell = side_effects.track_cell_new()
        cells[cell_name] = new_cell
        if cell_name in result:
            # The bound value now lives in the cell, not as a plain local.
            side_effects.store_cell(new_cell, result.pop(cell_name))
    return cells
|
|
|
|
|
|
def _create_nested_fn(
|
|
code, f_globals, name, defaults, closure, kwdefaults, annotations
|
|
):
|
|
from types import FunctionType
|
|
|
|
func = FunctionType(code, f_globals, name, defaults, closure)
|
|
func.__kwdefaults__ = kwdefaults
|
|
|
|
if isinstance(annotations, tuple):
|
|
from itertools import pairwise
|
|
|
|
annotations = dict(pairwise(annotations))
|
|
|
|
# TypeError: __annotations__ must be set to a dict object
|
|
assert annotations is None or isinstance(annotations, dict)
|
|
func.__annotations__ = annotations
|
|
|
|
return func
|
|
|
|
|
|
class BaseUserFunctionVariable(VariableTracker):
    """Common behaviour for trackers that wrap user-written Python functions.

    Subclasses provide ``get_code``, ``get_function`` and ``self_args``;
    this base class derives names, arity and call handling from them.
    """

    def get_filename(self):
        code = self.get_code()
        return code.co_filename

    def get_name(self):
        code = self.get_code()
        return code.co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Any bound receiver(s) from self_args() go in front of the call-site
        # positionals before handing off to the translator's inliner.
        positional = [*self.self_args(), *args]
        return tx.inline_user_function_return(self, positional, kwargs)

    def num_parameters(self):
        signature = inspect.signature(self.get_function())
        return len(signature.parameters)

    def closure_vars(self, tx):
        # No extra closure variables by default.
        return {}
|
|
|
|
|
|
class UserFunctionVariable(BaseUserFunctionVariable):
    """Tracks a user-defined function (``types.FunctionType`` or
    ``torch.jit.ScriptFunction``) so calls to it can be traced.

    Calls are inlined through ``BaseUserFunctionVariable.call_function``
    unless the function is marked constant, in which case it is executed
    eagerly via ``invoke_and_store_as_constant`` (see ``call_function``).
    """

    def __init__(self, fn, is_constant=False, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): the `is_constant` argument is never read; constness is
        # derived solely from the `_dynamo_marked_constant` attribute on `fn`.
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of compilation
            self.is_constant = True
        else:
            self.is_constant = False

        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def self_args(self):
        # A plain function has no bound receiver.
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        """Bind call-site ``args``/``kwargs`` to this function's parameters.

        Returns ``(result, closure_cells)``:

        - ``result`` maps local names to VariableTrackers. It holds the bound
          parameters (defaults filled in, packed ``*args``/``**kwargs``
          wrapped) plus one entry for every free variable of the function.
        - ``closure_cells`` maps each cell variable of the function to a newly
          tracked cell (see ``init_cellvars``). Parameters that are also cell
          variables are moved out of ``result`` into their cell.

        Must not be called on a function marked constant.
        """
        assert not self.is_constant
        options = VariableTracker.propagate([self])
        # Default values are wrapped against the root translator.
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx, options=options)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        # When this function has a source, each default is reachable as
        # DefaultsSource(self.source, idx), so it can be wrapped with a source.
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        # Stand-in function whose defaults are VariableTrackers, so that
        # signature.bind() + apply_defaults() below fill missing arguments
        # with tracked values instead of raw Python objects.
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            # Same treatment for keyword-only defaults.
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        # Packed *args / **kwargs come back as raw tuple/dict; wrap them.
        wrap_args_kwargs(tx, result, options)
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        # Expose every free variable of the function as a local in `result`.
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            if name == "__class__":
                # The implicit `__class__` cell is tracked as a user-defined class.
                source = AttrSource(self.source, "__class__") if self.source else None
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                # NOTE(review): match_nested_cell presumably finds a tracker
                # for this cell among frames already being traced.
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        # The cell is already tracked; reuse its tracker.
                        out = side_effects[cell]
                    else:
                        # The contents are reachable as
                        # <fn>.__closure__[idx].cell_contents.
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        contents_var = VariableBuilder(parent, closure_cell_contents)(
                            cell.cell_contents
                        )

                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects. This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects. If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue

                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )

                    result[name] = out

                else:
                    # No source for this function: wrap the contents as a sourceless value.
                    from .builder import SourcelessBuilder

                    result[name] = SourcelessBuilder()(
                        tx, cell.cell_contents
                    ).add_options(options)

        return result, closure_cells

    def export_freevars(self, parent, child):
        # Intentionally a no-op for real function objects.
        pass

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Trace a call to this function.

        A function marked constant is run eagerly on concrete values, and its
        result is registered via ``invoke_and_store_as_constant``. Any other
        function is inlined by the base class.
        """
        if self.is_constant:
            options = VariableTracker.propagate(self, args, kwargs.values())
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), options, args, kwargs
            )

        return super().call_function(tx, args, kwargs)
|
|
|
|
|
|
class UserMethodVariable(UserFunctionVariable):
    """Tracks a user-defined function bound to a receiver ``obj`` (a method)."""

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

    def __str__(self):
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        # The receiver becomes the implicit first positional argument.
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Methods of nn.Module receivers are routed to
        # NNModuleVariable.call_method instead of being inlined. For example,
        # `forward` can then be emitted as a `call_method` node, because
        # `forward` of allowed modules is known to be AOT-traceable. This also
        # covers user modules that subclass an allowed module and inherit its
        # `forward`.
        #
        # Only the root tracer takes this route. Inside a higher-order op,
        # Dynamo must step into the module call so that parameters and buffers
        # are seen and lifted as graph inputs, which is what plain inlining
        # (super().call_function) gives us.
        module_call_at_root = tx.output.is_root_tracer() and isinstance(
            self.obj, variables.NNModuleVariable
        )
        if module_call_at_root:
            fn_module = getattr(self.fn, "__module__", "")
            defined_in_torch_nn = fn_module is not None and fn_module.startswith(
                "torch.nn."
            )
            if defined_in_torch_nn or self.is_constant:
                redirected = self.obj.call_method(
                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
                )
                return redirected.add_options(self)
        return super().call_function(tx, args, kwargs)

    def num_parameters(self):
        # The bound receiver is not counted as a parameter.
        return super().num_parameters() - 1
|
|
|
|
|
|
class WrappedUserMethodVariable(UserMethodVariable):
    """A UserMethodVariable whose calls are traced inside ``context``.

    ``context.enter(tx)`` runs before the call is traced, and
    ``context.exit(tx)`` runs after it returns normally.
    """

    def __init__(self, wrapped, context, **kwargs):
        # fn/obj are taken from `wrapped`; discard any forwarded duplicates.
        for duplicate in ("fn", "obj"):
            kwargs.pop(duplicate, None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        traced = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return traced
|
|
|
|
|
|
class WrappedUserFunctionVariable(UserFunctionVariable):
    """A UserFunctionVariable whose calls are traced inside ``context``.

    ``context.enter(tx)`` runs before the call is traced, and
    ``context.exit(tx)`` runs after it returns normally.
    """

    def __init__(self, wrapped, context, **kwargs):
        # The function is taken from `wrapped`; discard forwarded duplicates.
        for duplicate in ("fn", "obj"):
            kwargs.pop(duplicate, None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        traced = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return traced
|
|
|
|
|
|
def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
    """Call ``fn`` eagerly on concrete values and register the result.

    Tensor arguments are materialized with ``get_real_value()``. All other
    arguments must be Python constants. The return value of ``fn`` is
    registered on ``tx.output`` under ``name`` with ``ConstantSource(name)``,
    and the resulting tracker is returned.
    """

    def to_concrete(var):
        if isinstance(var, variables.TensorVariable):
            return var.get_real_value()
        return var.as_python_constant()

    concrete_args = list(map(to_concrete, args))
    concrete_kwargs = {key: to_concrete(val) for key, val in kwargs.items()}
    value = fn(*concrete_args, **concrete_kwargs)
    return tx.output.register_attr_or_module(
        value, name, source=ConstantSource(name), **options
    )
|
|
|
|
|
|
class NestedUserFunctionVariable(BaseUserFunctionVariable):
    """Tracks a function object created while tracing (a nested function).

    The function is described by trackers for its name, code object,
    defaults, keyword-only defaults, annotations and closure cells, rather
    than by a live Python function. ``reconstruct`` rebuilds it at runtime
    via ``_create_nested_fn``.
    """

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        wraps_source=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        # Without a closure there is no scope to remember.
        if closure is None:
            closure_scope = None
        self.closure_scope = closure_scope
        # When set, reconstruct() applies functools.wraps(<wraps_source>).
        self.wraps_source = wraps_source

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        """Materialize a real function object; unsupported for closures."""
        if self.closure:
            raise NotImplementedError()
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                # Flat (name0, value0, name1, value1, ...) tuple: pair the
                # elements with stride 2. itertools.pairwise would produce
                # overlapping pairs such as (value0, name1).
                annotations = dict(zip(annotations[::2], annotations[1::2]))

            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        """Bind call arguments and resolve free variables.

        Returns ``(result, closure_cells)``, where ``result`` maps local names
        to VariableTrackers and ``closure_cells`` maps cell/free variable names
        to their cell trackers.
        """
        from .misc import InlinedClosureVariable

        code = self.get_code()
        # Stand-in function used only for signature binding. Its closure is a
        # set of placeholder cells, one per free variable.
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(code.co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.items
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())
        wrap_args_kwargs(parent.output.root_tx, result, VariableTracker.propagate(self))
        closure_cells = init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            cell = self.closure.items[idx]
            # NOTE(review): this only fails if `cell` has an attribute literally
            # called `name` whose value differs from `name`; intent unclear.
            assert getattr(cell, name, name) == name
            assert name not in result
            if isinstance(cell, InlinedClosureVariable):
                # InlinedClosureVariable's are created from LOAD_CLOSURE's from
                # InliningInstructionTranslators when the variable name is not found in closure_cells.
                # They should remain outside of closure_cells, so that our callee (the
                # InliningInstructionTranslator that traces `func`) handles
                # the cell correctly - that is, the cell's contents are treated as if they
                # are local variables, like in UserFunctionVariable's bind_args for freevars.
                cand = parent
                while cand and name not in cand.symbolic_locals:
                    cand = cand.parent
                if cand is None:
                    raise RuntimeError(
                        f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack"
                    )
                result[name] = cand.symbolic_locals[name]
            else:
                closure_cells[name] = cell

        return result, closure_cells

    def export_freevars(self, parent, child):
        """Copy free variables that the callee holds as locals back to ``parent``."""
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this function at runtime.

        The generated call is ``_create_nested_fn(code, f_globals, name,
        defaults, closure, kwdefaults, annotations)``, where missing parts are
        passed as None. If ``wraps_source`` is set, the result is additionally
        wrapped as ``functools.wraps(<wraps_source>)(fn)``.
        """
        codegen.load_import_from(__name__, "_create_nested_fn")
        codegen(self.code)
        codegen.extend_output([codegen._create_load_const(self.f_globals)])
        codegen(self.fn_name)

        if self.defaults:
            codegen(self.defaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.closure:
            codegen(self.closure)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.kwdefaults:
            codegen(self.kwdefaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.annotations:
            try:
                # Prefer embedding annotations as one constant; fall back to
                # generic codegen if any value is not a Python constant.
                if isinstance(self.annotations, variables.ConstDictVariable):
                    annotations = {
                        k: v.as_python_constant()
                        for k, v in self.annotations.items.items()
                    }
                else:
                    annotations = tuple(
                        [v.as_python_constant() for v in self.annotations.items]
                    )
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                codegen(self.annotations)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        codegen.extend_output(create_call_function(7, push_null=True))

        if self.wraps_source:
            # Stack: fn -> fn, wraps(src) -> wraps(src), fn -> wraps(src)(fn)
            codegen.load_import_from("functools", "wraps")
            codegen(self.wraps_source)
            codegen.extend_output(create_call_function(1, True))
            codegen.extend_output(create_rot_n(2))
            codegen.extend_output(create_call_function(1, True))

        return []
|
|
|
|
|
|
def _traceable_collective_remaps():
|
|
# We can't rely on importing from distributed, since its not always built
|
|
if torch.distributed.is_available():
|
|
from torch.distributed._functional_collectives import (
|
|
traceable_collective_remaps,
|
|
)
|
|
|
|
return traceable_collective_remaps
|
|
return {}
|
|
|
|
|
|
def _traceable_collectives_source(fn):
    """Return a source for ``torch.distributed._functional_collectives.<fn>``.

    Only the two in-place traceable collectives imported below are accepted.
    Callers must check that torch.distributed is available first.
    """
    assert torch.distributed.is_available(), "Illegal invocation."
    from torch.distributed._functional_collectives import (
        all_gather_tensor_inplace,
        reduce_scatter_tensor_inplace,
    )

    assert fn in {all_gather_tensor_inplace, reduce_scatter_tensor_inplace}
    torch_source = GlobalSource(global_name="torch")
    distributed_source = AttrSource(base=torch_source, member="distributed")
    module_source = AttrSource(
        base=distributed_source, member="_functional_collectives"
    )
    return AttrSource(module_source, fn.__name__)
|
|
|
|
|
|
class CollectiveFunctionRewriteVariable(UserFunctionVariable):
    """
    Some of the torch.distributed.* collective APIs are possible to rewrite to 'traceable' collectives.

    This class provides both a way to check if a function is remappable, and perform the remapping.

    In the case that a function is 'remappable' but only for some combinations of call-time arguments,
    we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse
    than status-quo as we currently graph-break on all distributed.* collectives.
    """

    def __init__(self, fn, *, orig_fn, orig_source, **kwargs):
        # orig_fn lets us implement any fn-specific args/kwargs restrictions inside call_function
        self.orig_fn = orig_fn
        self.orig_source = orig_source

        # remapped_fn gets stuffed in self.fn and used in super().call_function
        super().__init__(fn, **kwargs)

    @staticmethod
    def can_rewrite(variable):
        """Return True if ``variable`` is a function with a traceable remap."""
        return (
            inspect.isfunction(variable) and variable in _traceable_collective_remaps()
        )

    @staticmethod
    def rewrite(fn):
        """Return ``(remapped_fn, source_of_remapped_fn)`` for ``fn``."""
        new_fn = _traceable_collective_remaps()[fn]
        return new_fn, _traceable_collectives_source(new_fn)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # call_function must check any unsupported arguments and graph-break.
        # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
        # since that's the contract for putting a mapping in `traceable_collective_remaps`
        #
        # `kwargs` holds VariableTrackers, so testing the tracker itself would
        # be true even for an explicit `async_op=False`. Inspect the wrapped
        # constant instead. A value that is not a Python constant cannot be
        # shown to be False, so it still graph-breaks.
        # NOTE(review): only a keyword `async_op` is inspected here; a value
        # passed positionally is not detected.
        async_op = kwargs.get("async_op")
        if async_op is not None and (
            not async_op.is_python_constant() or async_op.as_python_constant()
        ):
            # Put the old source back, this function will always graph break, but this ensures
            # we produce the correct guards.
            self.source = self.orig_source
            unimplemented(
                f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.orig_fn}"
            )
        return super().call_function(tx, args, kwargs)
|
|
|
|
|
|
class FunctoolsPartialVariable(VariableTracker):
    """Tracks a ``functools.partial`` over the tracked callable ``func``.

    ``args`` (a list) and ``keywords`` (a dict) hold the pre-bound
    VariableTrackers. When ``original`` is set (presumably the concrete
    partial object), ``as_python_constant`` returns it as-is.
    """

    def __init__(self, func, args, keywords, original=None, **kwargs):
        super().__init__(**kwargs)
        self.func = func
        assert isinstance(args, list)
        self.args = args
        assert isinstance(keywords, dict)
        self.keywords = keywords
        self.original = original

        # Inherit guards from the wrapped callable and every pre-bound value.
        for tracked in (func, *args, *keywords.values()):
            self.guards.update(VariableTracker.propagate(tracked)["guards"])

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        options = VariableTracker.propagate([self])
        # Pre-bound positionals come first; call-site keywords override
        # pre-bound keywords with the same name.
        combined_args = self.args + args
        combined_kwargs = dict(self.keywords)
        combined_kwargs.update(kwargs)
        result = self.func.call_function(tx, combined_args, combined_kwargs)
        return result.add_options(options)

    def as_python_constant(self):
        if self.original:
            return self.original

        # Resolve the target first, as the partial(...) call would.
        target = self.func.fn

        def unwrap(var):
            if isinstance(var, variables.UserDefinedObjectVariable):
                return var.value
            return var.as_python_constant()

        positional = [unwrap(arg) for arg in self.args]
        keyword = {key: unwrap(val) for key, val in self.keywords.items()}
        return functools.partial(target, *positional, **keyword)
|
|
|
|
|
|
class TritonKernelVariable(VariableTracker):
    """Tracks a user-defined Triton kernel, optionally already given a grid.

    Indexing the tracker (``kernel[grid]``) returns a copy that carries the
    grid. Calling a tracker that has a grid records a
    ``triton_kernel_wrapper_mutation`` node in the output graph.
    """

    def __init__(self, kernel, kernel_idx, grid, **kwargs):
        super().__init__(**kwargs)

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        assert kernel is not None

        self.kernel = kernel
        # The graph refers to the kernel by its index in kernel_side_table.
        self.kernel_idx = kernel_side_table.add_kernel(kernel)
        # An index supplied by the caller must agree with the registered one.
        assert kernel_idx is None or self.kernel_idx == kernel_idx
        self.grid = grid

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .dicts import ConstDictVariable
        from .lists import BaseListVariable

        if self.grid is None:
            raise Unsupported("Triton kernels should always be called with a grid")

        # The grid's `meta` argument and the kernel launch both receive the
        # positional arguments (named via kernel.arg_names) merged with kwargs.
        named_args = dict(zip(self.kernel.arg_names, args))
        named_args.update(kwargs)
        meta = ConstDictVariable(named_args, dict)

        grid = self.grid
        if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
            # A callable grid is evaluated with `meta` to obtain the launch grid.
            grid = grid.call_function(tx, [meta], {})

        # At this point the grid must be a list-like tracker.
        if isinstance(grid, BaseListVariable):
            grid = grid.as_proxy()
        else:
            unimplemented(f"grid for the triton kernel is {type(grid)}")

        from torch._higher_order_ops.triton_kernel_wrap import (
            triton_kernel_wrapper_mutation,
        )

        # All kernel arguments travel inside one `kwargs` dict, so kernel
        # argument names such as 'grid' or 'kernel' cannot collide with the
        # wrapper's own parameters.
        wrapper_kwargs = {
            "kernel_idx": self.kernel_idx,
            "grid": grid,
            "kwargs": meta.as_proxy(),
        }
        tx.output.create_proxy(
            "call_function", triton_kernel_wrapper_mutation, (), wrapper_kwargs
        )

        return variables.ConstantVariable(
            None,
            **VariableTracker.propagate(self, args, kwargs.values()),
        )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name != "__getitem__":
            # Bail out to parent's implementation
            return super().call_method(tx, name, args, kwargs)

        # kernel[grid] is only valid on a kernel that has no grid yet, and the
        # grid must be the single index argument.
        if self.grid is not None or len(args) != 1:
            raise Unsupported(
                "Triton kernels should be called with only a single grid"
            )

        return TritonKernelVariable(
            kernel=self.kernel,
            kernel_idx=self.kernel_idx,
            grid=args[0],
            **VariableTracker.propagate(self),
        )
|