1. Removes calls to `replace_all` and `clone` and makes VTs mutable.
2. Properly handles Tuple Iterator mutation. Previously TupleIterator variables would only be properly reconstructed if they were advanced at least once in a frame. On calls to `next`, the source information would be lost (due to constructing a new iterator without using builder), which would ensure that during codegen the variable would be reconstructed from scratch. Now that VTs are mutated, the source is never lost, so we need to properly track mutation and handle it by replaying calls to `next` at the end of the modified bytecode.
3. Added a test for checking `iadd` side effects; this was missing from our unit test coverage.
4. Fixed two incorrect sources: DelayGraphBreakVariable and UserMethodVariable both relied on setting the source to AttrSource(parent, name) at the callsite of `var_getattr`.
5. Fixed a bug in in-place add for lists: it would set the resulting VariableTracker's source to `None`, which would use a different reconstruct path in codegen. Now this is handled explicitly by reconstructing vars when allow_cache=`False`, so that during side effect replay, the mutated var is correctly updated.

In subsequent PRs:
* Refactoring side effect tracking to be significantly simpler (I think we only need an `is_modified` flag)
* Refactor `next_variables` iterator to match the signature of `next`
* Remove all references to `options` in the code
* Refactor VTs representing mutable collections to implement their own mutation update handling
* Remove clone and/or make it specific to lists for creating slices
* Add mutation tracking/replay for sets
* Add mutation tracking/replay for iter.py
* Removing setting source in builder (it's set at the top level after a var is returned)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113725
Approved by: https://github.com/jansel
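To make items 1, 3 and 5 concrete, here is a minimal sketch (not taken from the PR's test suite) of the kind of frame this change targets: an in-place list mutation whose side effect must be replayed onto the caller's original object rather than onto a clone:

    import torch

    def f(xs, t):
        xs += [t.sin()]  # in-place add mutates the caller's list
        return t.cos()

    xs = []
    torch.compile(f, backend="eager")(xs, torch.ones(3))
    assert len(xs) == 1  # the mutation of xs must be visible after the call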
782 lines
28 KiB
Python
import functools
import inspect
import itertools
import types
from typing import Dict, List, Optional, TYPE_CHECKING, Union

import torch

from .. import variables
from ..bytecode_transformation import create_call_function, create_rot_n
from ..exc import unimplemented, Unsupported
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import make_cell
from .base import typestr, VariableTracker

if TYPE_CHECKING:
    from torch._guards import Source

def wrap_bound_arg(tx, val, source=None):
    # Source propagation is best effort since not every object we encounter has a source to begin with.
    if isinstance(val, VariableTracker):
        return val
    elif not source:
        from torch._dynamo.variables.builder import SourcelessBuilder

        return SourcelessBuilder()(tx, val)
    else:
        from torch._dynamo.variables.builder import VariableBuilder

        return VariableBuilder(tx, source=source)(val)

def wrap_args_kwargs(tx, result):
    for k, v in list(result.items()):
        if isinstance(v, (tuple, dict)):
            # args/kwargs
            result[k] = wrap_bound_arg(tx, v)

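# Note (illustrative, not part of the original source): after
# `inspect.signature(...).bind(...)` in the bind_args methods below, a *args
# parameter is bound to a plain tuple and a **kwargs parameter to a plain dict
# of VariableTrackers, e.g. for `def f(*args, **kwargs)` called as f(x, y=1)
# the bound arguments are {"args": (x_vt,), "kwargs": {"y": one_vt}}.
# wrap_args_kwargs re-wraps those raw containers so the rest of the tracer sees
# a uniform VariableTracker representation.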
def init_cellvars(parent, result, code):
    closure_cells = dict()
    side_effects = parent.output.side_effects

    # for name in itertools.chain(code.co_cellvars, code.co_freevars):
    for name in code.co_cellvars:
        closure_cells[name] = side_effects.track_cell_new()
        if name in result:
            side_effects.store_cell(closure_cells[name], result.pop(name))

    return closure_cells

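# Note (illustrative, not part of the original source): `co_cellvars` lists the
# locals of a frame that a nested function closes over, e.g. `x` in
#     def outer():
#         x = 1
#         def inner():
#             return x
# init_cellvars registers a freshly tracked cell for each such name so that any
# write to it can later be replayed by the side-effect machinery.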
def _create_nested_fn(
    code, f_globals, name, defaults, closure, kwdefaults, annotations
):
    from types import FunctionType

    func = FunctionType(code, f_globals, name, defaults, closure)
    func.__kwdefaults__ = kwdefaults

    if isinstance(annotations, tuple):
        from itertools import pairwise

        annotations = dict(pairwise(annotations))

    # TypeError: __annotations__ must be set to a dict object
    assert annotations is None or isinstance(annotations, dict)
    func.__annotations__ = annotations

    return func

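# Note (illustrative, not part of the original source): _create_nested_fn is not
# called while tracing; it is invoked from the bytecode emitted by
# NestedUserFunctionVariable.reconstruct below, which pushes exactly these seven
# arguments and then calls it to rebuild the nested function object at runtime.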
class BaseUserFunctionVariable(VariableTracker):
    def get_filename(self):
        return self.get_code().co_filename

    def get_name(self):
        return self.get_code().co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        return tx.inline_user_function_return(
            self, list(self.self_args()) + list(args), kwargs
        )

    def inspect_parameter_names(self):
        return list(inspect.signature(self.get_function()).parameters)

    def closure_vars(self, tx):
        return {}

class UserFunctionVariable(BaseUserFunctionVariable):
    """Some unsupported user-defined global function"""

    def __init__(self, fn, is_constant=False, **kwargs):
        super().__init__(**kwargs)
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of compilation
            self.is_constant = True
        else:
            self.is_constant = False

        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def self_args(self):
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        assert not self.is_constant
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(tx, result)
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            if name == "__class__":
                source = AttrSource(self.source, "__class__") if self.source else None
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        out = side_effects[cell]
                    else:
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        contents_var = VariableBuilder(parent, closure_cell_contents)(
                            cell.cell_contents
                        )

                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects. This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects. If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue

                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )

                    result[name] = out

                else:
                    from .builder import SourcelessBuilder

                    result[name] = SourcelessBuilder()(tx, cell.cell_contents)

        return result, closure_cells

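    # Note (illustrative, not part of the original source): for a traced closure
    # such as
    #     def outer():
    #         x = torch.ones(1)
    #         def inner():
    #             return x + 1
    # inner's free variable `x` lives in inner.__closure__[0].cell_contents, and
    # bind_args above rebuilds it with a chained source
    # (fn.__closure__[idx].cell_contents) so guards and codegen can locate the
    # value again; a tracked cell is only allocated if the name appears in
    # tx.mutated_closure_cell_contents.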
    def export_freevars(self, parent, child):
        pass

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if self.is_constant:
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), args, kwargs
            )

        return super().call_function(tx, args, kwargs)

class UserMethodVariable(UserFunctionVariable):
    """Some unsupported user-defined method"""

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

    def __str__(self):
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # For nn.Module methods, redirect to NNModuleVariable.call_method for an optimized path
        # rather than simple inlining, e.g. putting a `call_method` op in the FX graph for the
        # `forward` method, since we ensure `forward` of allowed modules can be traced by AOT safely.
        # Note this is not only for allowed modules: user-customized modules can extend from
        # allowed modules but use the parent's `forward` method, which is also covered by this branch.

        # If we are tracing the higher order op, we want Dynamo to step inside
        # the module call so that Dynamo can see the underlying parameters and
        # buffers and raise them as inputs to the graph. The is_root_tracer
        # check bypasses the if condition for non-root tracers and directly
        # calls the super().call_function at the end, which is basically
        # equivalent to inlining the method.
        if tx.output.is_root_tracer() and isinstance(
            self.obj, variables.NNModuleVariable
        ):
            module_attr = getattr(self.fn, "__module__", "")
            if (
                module_attr is not None
                and module_attr.startswith("torch.nn.")
                or self.is_constant
            ):
                return self.obj.call_method(
                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
                )
        return super().call_function(tx, args, kwargs)

    def inspect_parameter_names(self):
        return super().inspect_parameter_names()[1:]

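# Note (illustrative, not part of the original source): for an allowed module,
# e.g. calling an nn.Linear instance inside a compiled region, the call resolves
# to UserMethodVariable(forward, NNModuleVariable), and the branch above routes
# it through NNModuleVariable.call_method so a single node can be placed in the
# FX graph instead of inlining `forward`.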
class WrappedUserMethodVariable(UserMethodVariable):
    def __init__(self, wrapped, context, **kwargs):
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return result

class WrappedUserFunctionVariable(UserFunctionVariable):
    def __init__(self, wrapped, context, **kwargs):
        kwargs.pop("fn", None)
        kwargs.pop("obj", None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        result = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return result

def invoke_and_store_as_constant(tx, fn, name, args, kwargs):
    def convert(x):
        if isinstance(x, variables.TensorVariable):
            return x.get_real_value()
        return x.as_python_constant()

    args = [convert(x) for x in args]
    kwargs = {k: convert(v) for k, v in kwargs.items()}
    res = fn(*args, **kwargs)
    return tx.output.register_attr_or_module(
        res,
        name,
        source=ConstantSource(name),
    )

class NestedUserFunctionVariable(BaseUserFunctionVariable):
    _nonvar_fields = {
        "closure_scope",
        "f_globals",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        wrapped_reconstructible=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        if closure is None:
            closure_scope = None
        self.closure_scope = closure_scope
        # Either a source or a VT with .can_reconstruct() == True
        self.wrapped_reconstructible: Optional[
            Union[Source, VariableTracker]
        ] = wrapped_reconstructible

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        if self.closure:
            raise NotImplementedError()
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                from itertools import pairwise

                annotations = dict(pairwise(annotations))

            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        from .misc import InlinedClosureVariable

        code = self.get_code()
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.items
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())
        wrap_args_kwargs(parent.output.root_tx, result)
        closure_cells = init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            cell = self.closure.items[idx]
            assert getattr(cell, name, name) == name
            assert name not in result
            if isinstance(cell, InlinedClosureVariable):
                # InlinedClosureVariables are created from LOAD_CLOSURE instructions by
                # InliningInstructionTranslators when the variable name is not found in closure_cells.
                # They should remain outside of closure_cells, so that our callee (the
                # InliningInstructionTranslator that traces `func`) handles
                # the cell correctly - that is, the cell's contents are treated as if they
                # are local variables, like in UserFunctionVariable's bind_args for freevars.
                cand = parent
                while cand and name not in cand.symbolic_locals:
                    cand = cand.parent
                if cand is None:
                    raise RuntimeError(
                        f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack"
                    )
                result[name] = cand.symbolic_locals[name]
            else:
                closure_cells[name] = self.closure.items[idx]

        return result, closure_cells

    def export_freevars(self, parent, child):
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        codegen.load_import_from(__name__, "_create_nested_fn")
        codegen(self.code)
        codegen.extend_output([codegen._create_load_const(self.f_globals)])
        codegen(self.fn_name)

        if self.defaults:
            codegen(self.defaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.closure:
            codegen(self.closure)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.kwdefaults:
            codegen(self.kwdefaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.annotations:
            try:
                if isinstance(self.annotations, variables.ConstDictVariable):
                    annotations = {
                        k: v.as_python_constant()
                        for k, v in self.annotations.items.items()
                    }
                else:
                    annotations = tuple(
                        [v.as_python_constant() for v in self.annotations.items]
                    )
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                codegen(self.annotations)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        codegen.extend_output(create_call_function(7, push_null=True))

        if self.wrapped_reconstructible:
            codegen.load_import_from("functools", "wraps")
            codegen(self.wrapped_reconstructible)
            codegen.extend_output(create_call_function(1, True))
            codegen.extend_output(create_rot_n(2))
            codegen.extend_output(create_call_function(1, True))

        return []

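# Note (illustrative, not part of the original source): NestedUserFunctionVariable
# models a function object created inside the traced frame (the result of a
# MAKE_FUNCTION), so unlike UserFunctionVariable it stores the pieces of the
# function (code, defaults, kwdefaults, annotations, closure) as VariableTrackers
# and only materializes a real function, via _create_nested_fn, when the frame is
# reconstructed.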
def _traceable_collective_remaps():
    # We can't rely on importing from distributed, since it's not always built
    if torch.distributed.is_available():
        from torch.distributed._functional_collectives import (
            traceable_collective_remaps,
        )

        return traceable_collective_remaps
    return {}

def _traceable_collectives_source(tx, fn):
    assert torch.distributed.is_available(), "Illegal invocation."
    from torch.distributed._functional_collectives import (
        all_gather_tensor_inplace,
        reduce_scatter_tensor_inplace,
    )

    valid_values = {all_gather_tensor_inplace, reduce_scatter_tensor_inplace}
    assert fn in valid_values
    inner_name = fn.__name__
    path_source = tx.import_source("torch.distributed._functional_collectives")
    return AttrSource(path_source, inner_name)

class CollectiveFunctionRewriteVariable(UserFunctionVariable):
    """
    Some of the torch.distributed.* collective APIs can be rewritten to 'traceable' collectives.

    This class provides both a way to check whether a function is remappable and a way to perform
    the remapping.

    In the case that a function is 'remappable' but only for some combinations of call-time arguments,
    we check the args at `call_function` time and fall back to graph-breaking if needed. This is no worse
    than the status quo, as we currently graph-break on all distributed.* collectives.
    """

    def __init__(self, fn, *, replacement_var, **kwargs):
        super().__init__(fn, **kwargs)
        assert isinstance(replacement_var, UserFunctionVariable)
        self.replacement_var = replacement_var

    @staticmethod
    def create(tx, old_fn, source, **options):
        new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
        return CollectiveFunctionRewriteVariable(
            old_fn,
            replacement_var=UserFunctionVariable(new_fn, source=new_source, **options),
            source=source,
            **options,
        )

    @staticmethod
    def can_rewrite(variable):
        return (
            inspect.isfunction(variable) and variable in _traceable_collective_remaps()
        )

    @staticmethod
    def rewrite(tx, fn):
        new_fn = _traceable_collective_remaps()[fn]
        return new_fn, _traceable_collectives_source(tx, new_fn)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # call_function must check any unsupported arguments and graph-break.
        # It's safe to assume args/kwargs from orig_fn map 1:1 to args/kwargs of remapped_fn,
        # since that's the contract for putting a mapping in `traceable_collective_remaps`
        if kwargs.get("async_op", False):
            unimplemented(
                f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
            )
        return self.replacement_var.call_function(tx, args, kwargs)

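# Note (illustrative, not part of the original source): the remap table pairs an
# eager collective with a traceable functional counterpart; for example,
# torch.distributed.all_gather_into_tensor is assumed here to remap to
# all_gather_tensor_inplace, so a call like
#     dist.all_gather_into_tensor(out, inp, group=pg)
# can be traced through the replacement instead of graph-breaking, while
# async_op=True still falls back to a graph break above.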
class FunctoolsPartialVariable(VariableTracker):
    def __init__(self, func, args, keywords, original=None, **kwargs):
        super().__init__(**kwargs)
        self.func = func
        assert isinstance(args, list)
        self.args = args
        assert isinstance(keywords, dict)
        self.keywords = keywords
        self.original = original

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        merged_args = self.args + args
        merged_kwargs = {**self.keywords, **kwargs}
        return self.func.call_function(tx, merged_args, merged_kwargs)

    def as_python_constant(self):
        if self.original:
            return self.original
        else:

            def get_val(v):
                if isinstance(v, variables.UserDefinedObjectVariable):
                    return v.value
                else:
                    return v.as_python_constant()

            return functools.partial(
                self.func.fn,
                *[get_val(arg) for arg in self.args],
                **{k: get_val(v) for k, v in self.keywords.items()},
            )

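# Note (illustrative, not part of the original source): call_function above
# follows functools.partial semantics, so a traced
#     functools.partial(f, 1, a=2)(3, b=4)
# is inlined as f(1, 3, a=2, b=4), with call-site keyword arguments overriding
# the partial's stored keywords.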
class TritonKernelVariable(VariableTracker):
    def __init__(self, kernel, kernel_idx, grid, **kwargs):
        from triton.runtime.autotuner import Autotuner

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        super().__init__(**kwargs)

        assert kernel is not None

        self.kernel = kernel
        self.kernel_idx = kernel_side_table.add_kernel(kernel)

        assert kernel_idx is None or self.kernel_idx == kernel_idx

        self.grid = grid

        if isinstance(kernel, Autotuner):
            # We only support configs and keys arguments of triton.autotune
            # Make sure other arguments are defaulted
            defaults = inspect.signature(Autotuner).parameters
            if (
                ("warmup" in defaults and defaults["warmup"].default != kernel.warmup)
                or ("rep" in defaults and defaults["rep"].default != kernel.rep)
                or (
                    "prune_configs_by" in defaults
                    and defaults["prune_configs_by"].default
                    != kernel.early_config_prune
                )
            ):
                raise Unsupported(
                    "Only configs and keys are supported for triton.autotune"
                )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from triton.runtime.autotuner import Autotuner

        from .constant import ConstantVariable
        from .dicts import ConstDictVariable
        from .lists import BaseListVariable

        if self.grid is None:
            raise Unsupported("Triton kernels should always be called with a grid")

        # Both for the grid's meta as well as for the kernel, we need the combined
        # args and kwargs normalized
        normalized_args = {**dict(zip(self.kernel.arg_names, args)), **kwargs}

        configs = (
            [config.kwargs for config in self.kernel.configs]
            if isinstance(self.kernel, Autotuner)
            else [{}]
        )
        grids = []
        for config_args in configs:
            # If the grid is a function, then let's execute it and convert it to
            # a list
            grid = self.grid
            if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
                # Populate the special "meta" argument to call the grid function
                config_args = {
                    k: ConstantVariable.create(v) for k, v in config_args.items()
                }
                meta = ConstDictVariable({**normalized_args, **config_args}, dict)
                grid = grid.call_function(tx, [meta], {})

            # Now the grid must be a list, either originally or through the above
            # modification
            if isinstance(grid, BaseListVariable):
                grids.append(grid.as_proxy())
            else:
                unimplemented(f"grid for the triton kernel is {type(grid)}")

        for i in range(len(grids)):
            if not isinstance(grids[i], tuple):
                raise Unsupported("Only tuple grids are supported")
            # Inductor expects all grids to be 3-tuples, so let's pad them
            if len(grids[i]) == 1:
                grids[i] = (grids[i][0], 1, 1)
            elif len(grids[i]) == 2:
                grids[i] = (grids[i][0], grids[i][1], 1)
            elif len(grids[i]) > 3:
                raise Unsupported("Grid can have at most rank 3")

        assert len(grids) != 0
        if len(set(grids)) == 1:
            # If there's only one unique grid, let's simplify
            grids = [grids[0]]

        from torch._higher_order_ops.triton_kernel_wrap import (
            triton_kernel_wrapper_mutation,
        )

        # Combine args and kwargs and pass them as a dict so that if the user-defined triton
        # kernel uses variables named 'grid' or 'kernel', they do not conflict with
        # parameters of the wrapper function
        meta = ConstDictVariable(normalized_args, dict)
        tx.output.create_proxy(
            "call_function",
            triton_kernel_wrapper_mutation,
            (),
            {
                "kernel_idx": self.kernel_idx,
                "grid": grids,
                "kwargs": meta.as_proxy(),
            },
        )

        return variables.ConstantVariable(
            None,
        )

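    # Note (illustrative, not part of the original source): call_function above
    # pads every grid to the 3-tuple form Inductor expects, e.g. a grid callable
    # returning (n,) becomes (n, 1, 1) and (x, y) becomes (x, y, 1); grids with
    # more than three dimensions are rejected.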
    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        if name == "__getitem__":
            # __getitem__ should only be called if we don't already have a grid
            # Only grid needs to be passed
            if self.grid is not None or len(args) != 1:
                raise Unsupported(
                    "Triton kernels should be called with only a single grid"
                )

            return TritonKernelVariable(
                kernel=self.kernel,
                kernel_idx=self.kernel_idx,
                grid=args[0],
            )
        elif name == "run":
            if "grid" not in kwargs:
                raise Unsupported("Triton kernel requires to be called with a grid")
            grid = kwargs.pop("grid")
            # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args)
            return TritonKernelVariable(
                kernel=self.kernel, kernel_idx=self.kernel_idx, grid=grid
            ).call_function(tx, args, kwargs)

        # Bail out to parent's implementation
        return super().call_method(tx, name, args, kwargs)