mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Partially addresses https://github.com/pytorch/pytorch/issues/118785. This diff fixes three things: 1. Adds get_function to FunctoolsPartialVariable. Note that it is available only if all args are constant; otherwise it throws unimplemented in the call to as_python_constant. 2. NamedTupleVariable takes its args dispatched individually rather than as a list (e.g. NamedTuple(a, b, c) vs. NamedTuple([a, b, c])), so this is fixed by specializing as_proxy. 3. A call to create_arg from within create_proxy changes a Python NamedTuple into a function-call node without associating an example value! get_fake_values_from_nodes is updated to handle this case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/119435 Approved by: https://github.com/jansel, https://github.com/anijain2305 ghstack dependencies: #119314
920 lines
33 KiB
Python
920 lines
33 KiB
Python
# mypy: ignore-errors
|
|
|
|
import collections
|
|
import functools
|
|
import inspect
|
|
import itertools
|
|
import types
|
|
from typing import Dict, List, Optional, TYPE_CHECKING, Union
|
|
|
|
import torch
|
|
|
|
from .. import variables
|
|
from ..bytecode_transformation import create_call_function, create_rot_n
|
|
from ..exc import unimplemented, Unsupported
|
|
from ..guards import GuardBuilder, install_guard
|
|
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
|
|
from ..utils import check_constant_args, get_first_attr, identity, istype, make_cell
|
|
from .base import MutableLocal, typestr, VariableTracker
|
|
from .constant import ConstantVariable
|
|
|
|
if TYPE_CHECKING:
|
|
from torch._guards import Source
|
|
|
|
|
|
def wrap_bound_arg(tx, val, source=None):
    """Turn a bound argument value into a VariableTracker.

    Values that are already VariableTrackers are returned unchanged. Anything
    else is wrapped with ``VariableBuilder`` when a source is known, or with
    ``SourcelessBuilder`` when it is not. Source propagation is best effort
    because not every object we encounter has a source to begin with.
    """
    if isinstance(val, VariableTracker):
        return val

    if source:
        from torch._dynamo.variables.builder import VariableBuilder

        builder = VariableBuilder(tx, source=source)
        return builder(val)

    from torch._dynamo.variables.builder import SourcelessBuilder

    return SourcelessBuilder()(tx, val)
|
|
|
|
|
|
def wrap_args_kwargs(tx, result):
    """Wrap packed ``*args`` / ``**kwargs`` entries of a bound-argument dict in place.

    ``inspect.BoundArguments.arguments`` stores var-positional arguments as a
    tuple and var-keyword arguments as a dict. Those raw containers (any
    tuple/dict value) are replaced by VariableTrackers via ``wrap_bound_arg``,
    without a source.
    """
    packed_entries = [
        (param_name, raw)
        for param_name, raw in result.items()
        if isinstance(raw, (tuple, dict))
    ]
    for param_name, raw in packed_entries:
        result[param_name] = wrap_bound_arg(tx, raw)
|
|
|
|
|
|
def init_cellvars(parent, result, code):
    """Create a fresh tracked cell for every cell variable of ``code``.

    For each name in ``code.co_cellvars``, a new cell is obtained from
    ``parent.output.side_effects.track_cell_new()``. If that name is also a
    bound argument in ``result``, the argument is removed from ``result`` and
    stored into the new cell instead.

    Returns:
        dict mapping each cell-variable name to its tracked cell.
    """
    side_effects = parent.output.side_effects
    cells = {}
    for cellvar in code.co_cellvars:
        fresh_cell = side_effects.track_cell_new()
        cells[cellvar] = fresh_cell
        if cellvar in result:
            side_effects.store_cell(fresh_cell, result.pop(cellvar))
    return cells
|
|
|
|
|
|
def _create_nested_fn(
    code, f_globals, name, defaults, closure, kwdefaults, annotations
):
    """Rebuild a nested function at runtime from its reconstructed parts.

    This is the target of the bytecode emitted by
    ``NestedUserFunctionVariable.reconstruct``.

    Args:
        code: the function's code object.
        f_globals: globals dict for the new function.
        name: the new function's ``__name__``.
        defaults: tuple of positional defaults, or None.
        closure: tuple of cells, or None.
        kwdefaults: dict of keyword-only defaults, or None.
        annotations: None, a dict, or a flat tuple
            ``(name0, annot0, name1, annot1, ...)``. The flat tuple is the
            layout ``MAKE_FUNCTION`` uses for annotations on Python 3.10+.

    Returns:
        The newly created function object.
    """
    from types import FunctionType

    func = FunctionType(code, f_globals, name, defaults, closure)
    func.__kwdefaults__ = kwdefaults

    if isinstance(annotations, tuple):
        # The flat tuple alternates names and values, so pair elements
        # (0, 1), (2, 3), ...  ``itertools.pairwise`` would instead produce the
        # overlapping pairs (0, 1), (1, 2), (2, 3), ... and inject bogus
        # ``{annotation_value: next_name}`` entries.
        annotations = dict(zip(annotations[::2], annotations[1::2]))

    # TypeError: __annotations__ must be set to a dict object
    assert annotations is None or isinstance(annotations, dict)
    func.__annotations__ = annotations

    return func
|
|
|
|
|
|
class BaseUserFunctionVariable(VariableTracker):
    """Shared behaviour for VariableTrackers that wrap user-defined Python functions.

    Subclasses supply ``get_code``, ``get_function`` and ``self_args``.
    """

    def get_filename(self):
        """Return ``co_filename`` of the wrapped code object."""
        code = self.get_code()
        return code.co_filename

    def get_name(self):
        """Return ``co_name`` of the wrapped code object."""
        code = self.get_code()
        return code.co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Inline the function, prepending any implicit leading arguments from ``self_args``."""
        full_args = [*self.self_args(), *args]
        return tx.inline_user_function_return(self, full_args, kwargs)

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        """Answer ``hasattr(fn, name)`` as a ConstantVariable.

        If the concrete function cannot be materialized (``get_function``
        raises ``NotImplementedError``), only ``__name__`` on a
        ``NestedUserFunctionVariable`` is reported as present.
        """
        try:
            found = hasattr(self.get_function(), name)
        except NotImplementedError:
            found = name == "__name__" and isinstance(
                self, NestedUserFunctionVariable
            )
        return variables.ConstantVariable.create(found)

    def inspect_parameter_names(self):
        """Return the parameter names reported by ``inspect.signature``."""
        signature = inspect.signature(self.get_function())
        return list(signature.parameters)

    def closure_vars(self, tx):
        """Default: no closure variables."""
        return {}
|
|
|
|
|
|
class UserFunctionVariable(BaseUserFunctionVariable):
    """A user-defined Python function (or ``torch.jit.ScriptFunction``).

    Calls are inlined through ``BaseUserFunctionVariable.call_function``.
    If the function carries ``_dynamo_marked_constant``, it is instead run
    eagerly via ``invoke_and_store_as_constant``.
    """

    @classmethod
    def create_with_source(cls, value, source):
        """Build from ``value`` and install a CLOSURE_MATCH guard on ``source``."""
        install_guard(source.make_guard(GuardBuilder.CLOSURE_MATCH))
        return cls(
            value,
            source=source,
        )

    def __init__(self, fn, is_constant=False, **kwargs):
        """
        Args:
            fn: a ``types.FunctionType`` or ``torch.jit.ScriptFunction``.
            is_constant: accepted but never read. Constness is derived only
                from ``fn._dynamo_marked_constant``.
                NOTE(review): confirm no caller relies on passing True here.
            **kwargs: forwarded to the base VariableTracker (e.g. ``source``).
        """
        super().__init__(**kwargs)
        # Checked on the function as passed in, before the unwrapping below.
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of compilation
            self.is_constant = True
        else:
            self.is_constant = False

        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def as_python_constant(self):
        """Return the raw function, but only for this exact class (not subclasses)."""
        if istype(self, UserFunctionVariable):
            return self.fn
        # subclasses (such as methods) usually aren't a constant
        return super().as_python_constant()

    def self_args(self):
        """Plain functions have no implicit leading arguments."""
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        """True if the wrapped callable has a non-None ``__self__``."""
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        """Bind call arguments and free variables to names for inlining.

        Args:
            parent: the calling instruction translator.
            args / kwargs: VariableTrackers for the call.

        Returns:
            ``(result, closure_cells)``. ``result`` maps parameter names, plus
            free variables resolved here, to VariableTrackers. ``closure_cells``
            maps this function's cell variables to cells created by
            ``init_cellvars``.
        """
        assert not self.is_constant
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        # Each default is sourced as ``DefaultsSource(fn_source, idx)`` when
        # the function itself has a source.
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        # A clone of ``fn`` whose defaults are already VariableTrackers, so that
        # inspect.signature(...).bind + apply_defaults yields wrapped values.
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        # *args tuple / **kwargs dict are still raw containers; wrap them.
        wrap_args_kwargs(tx, result)
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            if name == "__class__":
                # Free variable named ``__class__``; its contents are wrapped
                # as a UserDefinedClassVariable.
                source = AttrSource(self.source, "__class__") if self.source else None
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        # Cell already tracked: reuse its VariableTracker.
                        out = side_effects[cell]
                    else:
                        # Source path: fn.__closure__[idx].cell_contents
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        try:
                            contents_var = VariableBuilder(
                                parent, closure_cell_contents
                            )(cell.cell_contents)
                        except ValueError:
                            # Cell has not yet been assigned
                            contents_var = variables.DeletedVariable()

                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects.  This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects.  If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue

                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )

                    result[name] = out

                else:
                    # No source for the function: wrap the cell contents
                    # without guards.
                    from .builder import SourcelessBuilder

                    result[name] = SourcelessBuilder()(tx, cell.cell_contents)

        return result, closure_cells

    def export_freevars(self, parent, child):
        """No-op for top-level user functions."""
        pass

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        """Answer ``hasattr`` directly against the real function object."""
        result = hasattr(self.fn, name)
        return variables.ConstantVariable.create(result)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Run constant-marked functions eagerly; inline everything else."""
        if self.is_constant:
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), args, kwargs
            )

        return super().call_function(tx, args, kwargs)
|
|
|
|
|
class UserMethodVariable(UserFunctionVariable):
    """A user-defined function bound to ``obj``.

    ``obj`` is supplied as the implicit first argument when the method is
    inlined.
    """

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

    def __str__(self):
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        """The bound receiver is passed as the first positional argument."""
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Call the method, routing some nn.Module methods through ``call_method``.

        On the root tracer, methods of an ``NNModuleVariable`` are redirected
        to ``NNModuleVariable.call_method`` when they come from ``torch.nn.*``
        or are marked constant. That way e.g. ``forward`` of allowed modules
        becomes a ``call_method`` node in the FX graph instead of being
        inlined. This also covers user modules that inherit such a
        ``forward``.

        Under a non-root tracer (e.g. while tracing a higher-order op), the
        method is always inlined. Dynamo then steps into the module and can
        lift its parameters and buffers as graph inputs.
        """
        on_nn_module = tx.output.is_root_tracer() and isinstance(
            self.obj, variables.NNModuleVariable
        )
        if on_nn_module:
            defining_module = getattr(self.fn, "__module__", "")
            from_torch_nn = defining_module is not None and defining_module.startswith(
                "torch.nn."
            )
            if from_torch_nn or self.is_constant:
                return self.obj.call_method(
                    tx, self.fn.__name__, args, kwargs, constant=self.is_constant
                )
        return super().call_function(tx, args, kwargs)

    def inspect_parameter_names(self):
        """Parameter names without the leading bound ``self`` parameter."""
        all_names = super().inspect_parameter_names()
        return all_names[1:]
|
|
|
|
|
|
class WrappedUserMethodVariable(UserMethodVariable):
    """A ``UserMethodVariable`` whose calls are bracketed by ``context.enter``/``context.exit``.

    ``fn`` and ``obj`` always come from ``wrapped``; any ``fn``/``obj``
    entries passed in ``kwargs`` are discarded.
    """

    def __init__(self, wrapped, context, **kwargs):
        for superseded in ("fn", "obj"):
            kwargs.pop(superseded, None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # NOTE(review): there is no try/finally, so exit() is skipped if the
        # inner call raises. Confirm that is intended.
        self.context.enter(tx)
        traced = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return traced
|
|
|
|
|
|
class WrappedUserFunctionVariable(UserFunctionVariable):
    """A ``UserFunctionVariable`` whose calls are bracketed by ``context.enter``/``context.exit``.

    ``fn`` always comes from ``wrapped``; any ``fn``/``obj`` entries passed in
    ``kwargs`` are discarded.
    """

    def __init__(self, wrapped, context, **kwargs):
        for superseded in ("fn", "obj"):
            kwargs.pop(superseded, None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # NOTE(review): there is no try/finally, so exit() is skipped if the
        # inner call raises. Confirm that is intended.
        self.context.enter(tx)
        traced = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return traced
|
|
|
|
|
|
def invoke_and_store_as_constant(tx, fn, name, args, kwargs):
    """Call ``fn`` eagerly on concrete values and register the result on the graph.

    ``TensorVariable`` arguments are replaced by ``get_real_value()``; every
    other argument must support ``as_python_constant()``. The result is
    registered via ``tx.output.register_attr_or_module`` under ``name``, with
    a ``ConstantSource(name)``.
    """

    def to_concrete(var):
        if isinstance(var, variables.TensorVariable):
            return var.get_real_value()
        return var.as_python_constant()

    concrete_args = list(map(to_concrete, args))
    concrete_kwargs = {key: to_concrete(value) for key, value in kwargs.items()}
    return tx.output.register_attr_or_module(
        fn(*concrete_args, **concrete_kwargs),
        name,
        source=ConstantSource(name),
    )
|
|
|
|
|
|
class NestedUserFunctionVariable(BaseUserFunctionVariable):
    """A function created during tracing (e.g. by ``MAKE_FUNCTION``).

    Every component (name, code, defaults, kwdefaults, annotations, closure)
    is held as a VariableTracker, except ``f_globals``, which is a real dict.
    """

    _nonvar_fields = {
        "closure_scope",
        "f_globals",
        *BaseUserFunctionVariable._nonvar_fields,
    }

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        wrapped_reconstructible=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        # A closure scope is only meaningful when there is a closure.
        if closure is None:
            closure_scope = None
        self.closure_scope = closure_scope
        # Either a source or a VT with .can_reconstruct() == True
        self.wrapped_reconstructible: Optional[
            Union[Source, VariableTracker]
        ] = wrapped_reconstructible

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        """Materialize a real function object.

        Raises:
            NotImplementedError: if the function has a closure.
        """
        if self.closure:
            raise NotImplementedError()
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                # Flat (name0, annot0, name1, annot1, ...) layout: pair the
                # elements without overlap. ``itertools.pairwise`` would
                # yield (annot0, name1) pairs as bogus extra entries.
                annotations = dict(zip(annotations[::2], annotations[1::2]))

            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        """Bind call arguments and free variables to names for inlining.

        Returns ``(result, closure_cells)``. ``result`` maps parameter names
        (and free variables backed by ``InlinedClosureVariable``) to
        VariableTrackers. ``closure_cells`` maps cell variables, plus the
        remaining free variables, to their cells.
        """
        from .misc import InlinedClosureVariable

        code = self.get_code()
        # A throwaway function used only for signature binding. Its closure
        # holds placeholder cells, one per free variable.
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.keys_as_python_constant()
        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())
        wrap_args_kwargs(parent.output.root_tx, result)
        closure_cells = init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            cell = self.closure.items[idx]
            # NOTE(review): this looks up an attribute literally named after the
            # free variable (not ``"name"``), so it passes unless such an
            # attribute exists. Confirm whether ``getattr(cell, "name", name)``
            # was intended.
            assert getattr(cell, name, name) == name
            assert name not in result
            if isinstance(cell, InlinedClosureVariable):
                # InlinedClosureVariable's are created from LOAD_CLOSURE's from
                # InliningInstructionTranslators when the variable name is not found in closure_cells.
                # They should remain outside of closure_cells, so that our callee (the
                # InliningInstructionTranslator that traces `func`) handles
                # the cell correctly - that is, the cell's contents are treated as if they
                # are local variables, like in UserFunctionVariable's bind_args for freevars.
                cand = parent
                while cand and name not in cand.symbolic_locals:
                    cand = cand.parent
                if cand is None:
                    raise RuntimeError(
                        f"Couldn't find {name} in the symbolic_locals of the inline interpreter stack"
                    )
                result[name] = cand.symbolic_locals[name]
            else:
                closure_cells[name] = self.closure.items[idx]

        return result, closure_cells

    def export_freevars(self, parent, child):
        """Copy the child frame's values of this function's free variables back to ``parent``."""
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this function via ``_create_nested_fn``.

        Pushes code, f_globals, name, defaults, closure, kwdefaults and
        annotations (None for absent parts), then calls ``_create_nested_fn``
        with those 7 arguments. If ``wrapped_reconstructible`` is set, the
        result is additionally passed through
        ``functools.wraps(wrapped_reconstructible)``.
        """
        codegen.load_import_from(__name__, "_create_nested_fn")
        codegen(self.code)
        codegen.extend_output([codegen._create_load_const(self.f_globals)])
        codegen(ConstantVariable.create(self.code.value.co_name))

        if self.defaults:
            codegen(self.defaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.closure:
            codegen(self.closure)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.kwdefaults:
            codegen(self.kwdefaults)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        if self.annotations:
            try:
                # Prefer a single constant load when annotations are constant.
                annotations = self.annotations.as_python_constant()
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                codegen(self.annotations)
        else:
            codegen.extend_output([codegen.create_load_const(None)])

        codegen.extend_output(create_call_function(7, push_null=True))

        if self.wrapped_reconstructible:
            # Build wraps(wrapped), swap it under the new function, then call
            # it on the function.
            codegen.load_import_from("functools", "wraps")
            codegen(self.wrapped_reconstructible)
            codegen.extend_output(create_call_function(1, True))
            codegen.extend_output(create_rot_n(2))
            codegen.extend_output(create_call_function(1, True))

        return []
|
|
|
|
|
|
class SkipFunctionVariable(VariableTracker):
    """A callable that Dynamo does not trace into.

    Calling it graph-breaks through ``unimplemented``, with two exceptions:

    * Callables listed in ``fold_through_function_to_wrapper`` (e.g.
      ``collections.namedtuple``) are evaluated eagerly when every argument
      is a constant.
    * ``functools.wraps(x)`` with a reconstructible ``x`` returns a decorator
      that records ``x`` on a nested function.
    """

    def __init__(self, value, reason=None, **kwargs):
        super().__init__(**kwargs)
        self.value = value
        self.reason = reason

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @classmethod
    def create_with_source(cls, value, source):
        """Build from ``value`` and install a FUNCTION_MATCH guard on ``source``."""
        guard = source.make_guard(GuardBuilder.FUNCTION_MATCH)
        install_guard(guard)
        return cls(value, source=source)

    @staticmethod
    @functools.lru_cache(None)
    def fold_through_function_to_wrapper():
        """Map each foldable callable to the VariableTracker class wrapping its result."""
        return {
            collections.namedtuple: variables.UserDefinedClassVariable,
        }

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        target = self.value

        if inspect.getattr_static(target, "_torchdynamo_disable", False):
            unimplemented(f"call torch._dynamo.disable() wrapped function {target}")
        elif (
            target in self.fold_through_function_to_wrapper().keys()
            and check_constant_args(args, kwargs)
        ):
            # Fold through functions (e.g. collections.namedtuple) whose
            # inputs and outputs are all Python constants.
            folded = target(
                *[arg.as_python_constant() for arg in args],
                **{key: val.as_python_constant() for key, val in kwargs.items()},
            )
            wrapper_cls = self.fold_through_function_to_wrapper().get(target)
            return wrapper_cls(folded, mutable_local=MutableLocal())
        elif (
            target is functools.wraps
            and not kwargs
            and len(args) == 1
            and (
                args[0].source is not None
                or args[0].can_reconstruct(tx.output.root_tx)
            )
        ):

            def apply_wraps(fn):
                # Only nested functions can carry the wrapped object through
                # reconstruction; prefer its source when one exists.
                if isinstance(fn, variables.NestedUserFunctionVariable):
                    wrapped = args[0]
                    reconstructible = wrapped.source if wrapped.source else wrapped
                    return fn.clone(wrapped_reconstructible=reconstructible)
                unimplemented(f"functools.wraps({fn})")

            return variables.LambdaVariable(apply_wraps)
        else:
            try:
                location = inspect.getfile(target)
            except TypeError:
                location = f"Builtin {target.__name__}"
            # NOTE(review): the quoting yields e.g. "'skip function f in file p'', why'".
            # It is kept verbatim because callers may match on this text.
            message = f"'skip function {target.__qualname__} in file {location}'"
            if self.reason:
                message += f"', {self.reason}'"
            unimplemented(message)
|
|
|
|
|
|
def _traceable_collective_remaps():
|
|
# We can't rely on importing from distributed, since it's not always built
|
|
if torch.distributed.is_available():
|
|
from torch.distributed._functional_collectives import (
|
|
traceable_collective_remaps,
|
|
)
|
|
|
|
return traceable_collective_remaps
|
|
return {}
|
|
|
|
|
|
def _traceable_collectives_source(tx, fn):
    """Build an ``AttrSource`` for ``fn`` inside ``torch.distributed._functional_collectives``.

    ``fn`` must be one of the replacement functions returned by
    ``_traceable_collective_remaps()``. The module source is obtained from
    ``tx.import_source``.
    """
    assert torch.distributed.is_available(), "Illegal invocation."
    assert fn in _traceable_collective_remaps().values()

    attr_name = fn.__name__
    module_source = tx.import_source("torch.distributed._functional_collectives")
    return AttrSource(module_source, attr_name)
|
|
|
|
|
|
class CollectiveFunctionRewriteVariable(UserFunctionVariable):
    """A torch.distributed collective that is traced as its 'traceable' counterpart.

    ``can_rewrite`` reports whether a function has a traceable replacement;
    ``create`` and ``rewrite`` perform the swap. Some functions are remappable
    only for certain call-time arguments. ``call_function`` therefore checks
    the arguments and graph-breaks when needed. This is no worse than the
    status quo, which graph-breaks on every distributed.* collective.
    """

    def __init__(self, fn, *, replacement_var, **kwargs):
        super().__init__(fn, **kwargs)
        assert isinstance(replacement_var, UserFunctionVariable)
        self.replacement_var = replacement_var

    @staticmethod
    def create(tx, old_fn, source, **options):
        """Wrap ``old_fn`` together with a variable for its traceable replacement."""
        new_fn, new_source = CollectiveFunctionRewriteVariable.rewrite(tx, old_fn)
        replacement = UserFunctionVariable(new_fn, source=new_source, **options)
        return CollectiveFunctionRewriteVariable(
            old_fn,
            replacement_var=replacement,
            source=source,
            **options,
        )

    @staticmethod
    def can_rewrite(variable):
        """True if ``variable`` is a plain function with a traceable remap."""
        if not inspect.isfunction(variable):
            return False
        return variable in _traceable_collective_remaps()

    @staticmethod
    def rewrite(tx, fn):
        """Return ``(replacement_fn, source_of_replacement_fn)`` for ``fn``."""
        replacement_fn = _traceable_collective_remaps()[fn]
        return replacement_fn, _traceable_collectives_source(tx, replacement_fn)

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Arguments of the original collective map 1:1 onto the replacement;
        # that is the contract for entries in ``traceable_collective_remaps``.
        # Only unsupported argument values need to be rejected here.
        wants_async = "async_op" in kwargs and kwargs["async_op"].as_python_constant()
        if wants_async:
            unimplemented(
                f"CollectiveFunctionRewriteVariable can't support async_op=True for {self.fn}"
            )
        return self.replacement_var.call_function(tx, args, kwargs)
|
|
|
|
|
|
class FunctoolsPartialVariable(VariableTracker):
    """Tracks a ``functools.partial(func, *args, **keywords)`` object.

    ``func`` is the wrapped callable's VariableTracker. ``args`` is a list and
    ``keywords`` a dict of the pre-bound VariableTrackers.
    """

    def __init__(self, func: VariableTracker, args, keywords, **kwargs):
        super().__init__(**kwargs)
        self.func = func
        assert isinstance(args, list)
        self.args = args
        assert isinstance(keywords, dict)
        self.keywords = keywords

    def reconstruct(self, codegen):
        """Emit bytecode for ``functools.partial(func, *args, **keywords)``."""
        codegen.load_import_from("functools", "partial")
        codegen(self.func)
        if self.args:
            codegen.foreach(self.args)
        # ``func`` itself counts as the first positional argument to partial().
        positional_count = 1 + len(self.args)
        if self.keywords:
            codegen.foreach(self.keywords.values())
            keyword_names = tuple(self.keywords)
            return codegen.create_call_function_kw(
                positional_count + len(keyword_names), keyword_names, True
            )
        return create_call_function(positional_count, True)

    def get_function(self):
        """Return the concrete partial; requires every component to be constant."""
        return self.as_python_constant()

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Call ``func`` with pre-bound args prepended and keywords merged."""
        combined_args = self.args + args
        combined_kwargs = dict(self.keywords)
        # Call-site keywords override pre-bound ones, as functools.partial does.
        combined_kwargs.update(kwargs)
        return self.func.call_function(tx, combined_args, combined_kwargs)

    def call_hasattr(self, tx, name: str) -> VariableTracker:
        # functools.partial uses slots, so attributes are constant: probing a
        # throwaway partial answers the question.
        probe = functools.partial(identity)
        return variables.ConstantVariable.create(hasattr(probe, name))

    def as_python_constant(self):
        """Rebuild the real ``functools.partial`` from constant components."""
        func = self.func.as_python_constant()
        const_args = [arg.as_python_constant() for arg in self.args]
        const_kwargs = {
            key: val.as_python_constant() for key, val in self.keywords.items()
        }
        return functools.partial(func, *const_args, **const_kwargs)

    def guard_as_python_constant(self):
        """Similar to as_python_constant(), but add ID_MATCH guards to try to force things to become constants"""
        func = self.func.guard_as_python_constant()
        const_args = [arg.guard_as_python_constant() for arg in self.args]
        const_kwargs = {
            key: val.guard_as_python_constant() for key, val in self.keywords.items()
        }
        return functools.partial(func, *const_args, **const_kwargs)
|
|
|
|
|
|
class TritonKernelVariable(VariableTracker):
    """A user-defined Triton kernel (optionally wrapped by ``triton.autotune``).

    Calling it with a grid emits a ``triton_kernel_wrapper_mutation`` node in
    the FX graph. The kernel itself is referenced by its index in
    ``kernel_side_table``.
    """

    def __init__(self, kernel, kernel_idx, grid, **kwargs):
        """
        Args:
            kernel: the Triton kernel or Autotuner object.
            kernel_idx: expected side-table index, or None. When given, it
                must equal the index returned by ``kernel_side_table.add_kernel``.
            grid: the launch grid VariableTracker, or None if not yet known.

        Raises:
            Unsupported: for an Autotuner whose warmup/rep/prune_configs_by
                differ from the ``Autotuner.__init__`` defaults.
        """
        from triton.runtime.autotuner import Autotuner

        from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

        super().__init__(**kwargs)

        assert kernel is not None

        self.kernel = kernel
        self.kernel_idx = kernel_side_table.add_kernel(kernel)

        assert kernel_idx is None or self.kernel_idx == kernel_idx

        self.grid = grid

        if isinstance(kernel, Autotuner):
            # We only support configs and keys arguments of triton.autotune
            # Make sure other arguments are defaulted
            defaults = inspect.signature(Autotuner.__init__).parameters

            # Newer version of triton change attribute name from warmup to num_warmup and rep to num_rep.
            # The call to get_first_attr is to maintain backward-compatibility.
            if (
                (
                    "warmup" in defaults
                    and defaults["warmup"].default
                    != get_first_attr(kernel, "num_warmups", "warmup")
                )
                or (
                    "rep" in defaults
                    and defaults["rep"].default
                    != get_first_attr(kernel, "num_reps", "rep")
                )
                or (
                    "prune_configs_by" in defaults
                    and defaults["prune_configs_by"].default
                    != kernel.early_config_prune
                )
            ):
                raise Unsupported(
                    "Only configs and keys are supported for triton.autotune"
                )

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Record a kernel launch in the graph and return ``ConstantVariable(None)``.

        Positional args are matched to ``self.kernel.arg_names`` and merged
        with kwargs. A callable grid is evaluated once per autotune config
        (or once with no config). Each resulting grid must be a tuple and is
        padded to 3 dimensions.

        Raises:
            Unsupported: missing grid, non-tuple grid, or grid rank > 3.
        """
        from triton.runtime.autotuner import Autotuner

        from .constant import ConstantVariable
        from .dicts import ConstDictVariable
        from .lists import BaseListVariable

        if self.grid is None:
            raise Unsupported("Triton kernels should always be called with a grid")

        # Both for grid's meta as well as for the kernel, we need combined
        # args and kwargs normalized
        names = (
            variables.ConstantVariable.create(name) for name in self.kernel.arg_names
        )
        kwargs = {variables.ConstantVariable.create(k): v for k, v in kwargs.items()}
        normalized_args = {**dict(zip(names, args)), **kwargs}

        # One entry per autotune config; a plain kernel gets a single empty config.
        configs = (
            [config.kwargs for config in self.kernel.configs]
            if isinstance(self.kernel, Autotuner)
            else [{}]
        )
        grids = []
        for config_args in configs:
            # If the grid is a function, then lets execute it and convert it to
            # a list
            grid = self.grid
            if isinstance(grid, (NestedUserFunctionVariable, UserFunctionVariable)):
                # Populate the special "meta" argument to call the grid function
                config_args = {
                    ConstantVariable.create(k): ConstantVariable.create(v)
                    for k, v in config_args.items()
                }
                meta = ConstDictVariable({**normalized_args, **config_args}, dict)
                grid = grid.call_function(tx, [meta], {})

            # Now, the grid must be a list either originally or through above
            # modification
            if isinstance(grid, BaseListVariable):
                grids.append(grid.as_proxy())
            else:
                unimplemented(f"grid for the triton kernel is {type(grid)}")

        for i in range(len(grids)):
            if not isinstance(grids[i], tuple):
                raise Unsupported("Only tuple grids are supported")
            # inductor expects all grids to be 3-tuple so lets make it
            # NOTE(review): an empty tuple grid falls through every branch
            # below unchanged; confirm that cannot occur.
            if len(grids[i]) == 1:
                grids[i] = (grids[i][0], 1, 1)
            elif len(grids[i]) == 2:
                grids[i] = (grids[i][0], grids[i][1], 1)
            elif len(grids[i]) > 3:
                raise Unsupported("Grid can have at most rank 3")

        assert len(grids) != 0
        if len(set(grids)) == 1:
            # If there's only one unique grid, lets simplify
            grids = [grids[0]]

        from torch._higher_order_ops.triton_kernel_wrap import (
            triton_kernel_wrapper_mutation,
        )

        # Combine args and kwargs and pass as a dict so that if user defined triton
        # kernel uses variables as 'grid' or 'kernel', it does not conflict with
        # parameters of the wrapper function
        meta = ConstDictVariable(normalized_args, dict)
        tx.output.create_proxy(
            "call_function",
            triton_kernel_wrapper_mutation,
            (),
            {
                "kernel_idx": self.kernel_idx,
                "grid": grids,
                "kwargs": meta.as_proxy(),
            },
        )

        return variables.ConstantVariable(
            None,
        )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle ``kernel[grid]`` and ``kernel.run(..., grid=...)``.

        Other method names fall back to ``VariableTracker.call_method``.
        Note that the ``run`` path pops ``grid`` and ``warmup`` from the
        caller's ``kwargs`` dict.
        """
        if name == "__getitem__":
            # __getitem__ should only be called if we don't already have a grid
            # Only grid needs to be passed
            if self.grid is not None or len(args) != 1:
                raise Unsupported(
                    "Triton kernels should be called with only a single grid"
                )

            return TritonKernelVariable(
                kernel=self.kernel,
                kernel_idx=self.kernel_idx,
                grid=args[0],
            )
        elif name == "run":
            if "grid" not in kwargs:
                raise Unsupported("Triton kernel requires to be called with a grid")
            grid = kwargs.pop("grid")
            kwargs.pop("warmup", None)
            # rewrite kernel.run(*args, grid=grid) to kernel[grid](*args)
            return TritonKernelVariable(
                kernel=self.kernel, kernel_idx=self.kernel_idx, grid=grid
            ).call_function(tx, args, kwargs)

        # Bail out to parent's implementation
        return super().call_method(tx, name, args, kwargs)
|