pytorch/torch/_dynamo/variables/functions.py
Yanbo Liang d855b6aed6 [Dynamo] Add unit test for explicitly calling __call__ (#100146)
@wconstab As we discussed last Friday, I added the unit test for explicitly calling __call__ and added a comment to explain why we are redirecting ```UserMethodVariable.call_function``` to ```NNModuleVariable.call_method``` in a certain case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/100146
Approved by: https://github.com/wconstab
2023-04-27 15:47:11 +00:00

503 lines
17 KiB
Python

import abc
import enum
import functools
import inspect
import itertools
import sys
import types
from typing import Dict, List
import torch
from .. import variables
from ..allowed_functions import is_allowed, is_builtin_callable
from ..bytecode_transformation import (
create_call_function,
create_instruction,
create_rot_n,
)
from ..exc import unimplemented
from ..source import AttrSource, ConstantSource, DefaultsSource, GetItemSource
from ..utils import istensor, istype, make_cell
from .base import typestr, VariableTracker
def wrap_bound_arg(tx, val, options, source=None):
    """Convert a bound argument or default value into a VariableTracker.

    * dicts, tuples and lists are converted element by element (recursively);
    * literals and torch.Size / torch.device / torch.dtype become
      ConstantVariable;
    * builtin callables, values accepted by ``is_allowed``, plain Python
      functions, enums and classes get their dedicated tracker types;
    * tensors are handed to VariableBuilder;
    * anything else must already be a VariableTracker and is returned as-is.

    Source propagation is best effort, since not every object we encounter
    has a source to begin with. ``options`` must not contain ``source``
    because it is re-used for every recursive call on container elements.
    """
    assert (
        "source" not in options
    ), "Source needs to be separate from options due to recursive calls for lists/dicts"

    def recurse(item):
        # Container elements carry their own source (if any), never ours.
        return wrap_bound_arg(tx, item, options, source=getattr(item, "source", None))

    if isinstance(val, dict):
        converted = {key: recurse(item) for key, item in val.items()}
        return variables.ConstDictVariable(converted, dict, **options)
    if isinstance(val, (tuple, list)):
        container_cls = variables.BaseListVariable.cls_for(type(val))
        return container_cls([recurse(item) for item in val], **options)

    is_constant_like = variables.ConstantVariable.is_literal(val) or istype(
        val, (torch.Size, torch.device, torch.dtype)
    )
    if is_constant_like:
        return variables.ConstantVariable(val, **options)
    if is_builtin_callable(val):
        return variables.BuiltinVariable(val, source=source, **options)
    if is_allowed(val):
        return variables.TorchVariable(val, source=source, **options)
    if isinstance(val, types.FunctionType):
        return variables.UserFunctionVariable(val, source=source, **options)
    if isinstance(val, enum.Enum):
        return variables.EnumVariable(val, source=source, **options)
    if isinstance(val, (type, abc.ABCMeta)):
        return variables.UserDefinedClassVariable(val, source=source, **options)
    if istensor(val):
        from torch._dynamo.variables.builder import VariableBuilder

        return VariableBuilder(tx, source=source, **options)(val)

    # Fallthrough: the value is expected to be tracked already.
    assert isinstance(val, VariableTracker), typestr(val)
    return val
def wrap_args_kwargs(tx, result, options):
    """Wrap, in place, every tuple or dict value of ``result``.

    ``result`` maps bound parameter names to values; the tuple/dict entries
    are the collected ``*args`` / ``**kwargs``. Those are replaced by the
    output of ``wrap_bound_arg``; every other entry is left untouched.
    """
    # Snapshot the entries to rewrite before mutating the mapping.
    to_wrap = [
        (param, value)
        for param, value in result.items()
        if isinstance(value, (tuple, dict))
    ]
    for param, value in to_wrap:
        result[param] = wrap_bound_arg(tx, value, options)
def init_cellvars(parent, result, code):
    """Allocate a freshly tracked cell for each cell variable of ``code``.

    If a cell variable's name is also a key of ``result``, that entry is
    popped from ``result`` and its value is stored into the new cell.

    Returns a dict mapping each cell-variable name to its tracked cell.
    """
    side_effects = parent.output.side_effects
    cells = {}
    for cell_name in code.co_cellvars:
        new_cell = side_effects.track_cell_new()
        cells[cell_name] = new_cell
        if cell_name in result:
            # Move the bound value out of `result` and into the cell.
            side_effects.store_cell(new_cell, result.pop(cell_name))
    return cells
class BaseUserFunctionVariable(VariableTracker):
    """Shared behaviour for trackers of user-written Python functions.

    Subclasses provide ``get_code``, ``get_function`` and ``self_args``.
    This base derives the filename and name from the code object and
    implements calls by inlining the function into the current translator.
    """

    def get_filename(self):
        code_obj = self.get_code()
        return code_obj.co_filename

    def get_name(self):
        code_obj = self.get_code()
        return code_obj.co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Any implicit leading arguments (e.g. a bound receiver) go first.
        positional = [*self.self_args(), *args]
        return tx.inline_user_function_return(self, positional, kwargs)

    def num_parameters(self):
        signature = inspect.signature(self.get_function())
        return len(signature.parameters)

    def closure_vars(self, tx):
        # No closure variables by default.
        return dict()
class UserFunctionVariable(BaseUserFunctionVariable):
    """Some unsupported user-defined global function.

    Wraps a ``types.FunctionType`` (or ``torch.jit.ScriptFunction``). Calls are
    inlined through the base class. The exception is a function carrying a
    truthy ``_dynamo_marked_constant`` attribute: it is run eagerly on concrete
    values, and its result is registered as a constant.
    """

    def __init__(self, fn, is_constant=False, **kwargs):
        super().__init__(**kwargs)
        # NOTE(review): the `is_constant` argument is accepted but ignored;
        # constness is derived only from `fn._dynamo_marked_constant` below.
        # Presumably the parameter exists so that passing the existing
        # attribute back in as a keyword does not fail — TODO confirm.
        if getattr(fn, "_dynamo_marked_constant", False):
            # This method should be treated as a constant for the purposes of compilation
            self.is_constant = True
        else:
            self.is_constant = False

        assert isinstance(
            fn, (types.FunctionType, torch.jit.ScriptFunction)
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # unpack @torch._dynamo.optimize()(fn) wrapped function
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # unpack torch.jit.script_if_tracing
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def self_args(self):
        # A plain function has no implicit leading arguments.
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        # True when the wrapped object exposes a non-None `__self__`.
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        """Bind ``args``/``kwargs`` (VariableTrackers) to this function's signature.

        Returns ``(result, closure_cells)``:
          * ``result`` maps local names to VariableTrackers. It covers the
            parameters, with defaults applied and ``*args``/``**kwargs``
            wrapped, plus the function's free variables.
          * ``closure_cells`` maps each of the code's cell variables to a newly
            tracked cell (see ``init_cellvars``).

        Calls ``unimplemented`` if the function has a closure cell it cannot
        resolve without a source.
        """
        assert not self.is_constant
        options = VariableTracker.propagate([self])
        # Defaults and *args/**kwargs are wrapped against the root translator.
        tx = parent.output.root_tx
        wrap = functools.partial(wrap_bound_arg, tx=tx, options=options)

        fn: types.FunctionType = self.fn
        defaults = fn.__defaults__ or []
        # Each positional default gets a DefaultsSource when this function
        # itself has a source, so that guards can refer back to it.
        defaults_sources = [
            None if self.source is None else DefaultsSource(self.source, idx)
            for idx, _ in enumerate(defaults)
        ]
        # Stand-in function whose defaults are already VariableTrackers, so
        # that `bind` + `apply_defaults` below yield trackers for every
        # defaulted parameter.
        fake_func = types.FunctionType(
            fn.__code__,
            fn.__globals__,
            fn.__name__,
            tuple(
                [
                    wrap(val=arg, source=source)
                    for arg, source in zip(defaults, defaults_sources)
                ]
            ),
            fn.__closure__,
        )
        if fn.__kwdefaults__:
            # Keyword-only defaults get the same treatment.
            kwdefaults_sources = {
                k: None
                if self.source is None
                else DefaultsSource(self.source, k, is_kw=True)
                for k in fn.__kwdefaults__
            }
            fake_func.__kwdefaults__ = {
                k: wrap(val=v, source=kwdefaults_sources[k])
                for k, v in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        # Wrap the collected *args tuple / **kwargs dict.
        wrap_args_kwargs(tx, result, options)
        # Cell variables that are also parameters are moved out of `result`.
        closure_cells = init_cellvars(parent, result, fn.__code__)
        closure = self.fn.__closure__ or ()
        assert len(closure) == len(self.fn.__code__.co_freevars)
        for idx, name, cell in zip(
            itertools.count(), self.fn.__code__.co_freevars, closure
        ):
            if name == "__class__":
                # Implicit cell that Python creates for zero-argument super().
                # NOTE(review): the source below is `fn.__class__` (the
                # function's own type), not `fn.__closure__[idx].cell_contents`
                # (the class actually held in the cell). Confirm this is the
                # intended guard target.
                source = AttrSource(self.source, "__class__") if self.source else None
                result[name] = variables.UserDefinedClassVariable(
                    cell.cell_contents,
                    source=source,
                )
            else:
                var = tx.match_nested_cell(name, cell)
                if var is not None:
                    # optimization for cleaner codegen
                    result[name] = var
                elif self.source:
                    # The function has a source, so each cell's contents can be
                    # addressed as `<source>.__closure__[idx].cell_contents`.
                    from .builder import VariableBuilder

                    side_effects = parent.output.side_effects
                    if cell in side_effects:
                        # Cell already tracked: reuse its tracker.
                        out = side_effects[cell]
                    else:
                        closure_cell = GetItemSource(
                            AttrSource(self.source, "__closure__"), idx
                        )
                        closure_cell_contents = AttrSource(
                            closure_cell, "cell_contents"
                        )
                        contents_var = VariableBuilder(parent, closure_cell_contents)(
                            cell.cell_contents
                        )

                        if (
                            closure_cell_contents.name()
                            not in tx.mutated_closure_cell_contents
                        ):
                            # Optimistically don't allocate the cell, to
                            # reduce the number of side effects.  This is
                            # important for cond, as without it, any accesses
                            # to closures create side effects and cond doesn't
                            # support side effects.  If we're wrong and this
                            # closure cell gets written to, we will restart
                            # the analysis with this cell's name in the
                            # mutated list here
                            result[name] = contents_var
                            continue

                        # cells are written to with "cell_contents",
                        # so the source should just be the closure_cell, not its contents
                        out = side_effects.track_cell_existing(closure_cell, cell)
                        side_effects.store_cell(
                            out,
                            contents_var,
                        )

                    result[name] = out
                else:
                    # No source: the cell cannot be reconstructed/guarded.
                    unimplemented("inline with __closure__")

        return result, closure_cells

    def export_freevars(self, parent, child):
        # No-op for this tracker type.
        pass

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        if self.is_constant:
            # Constant-marked functions are evaluated eagerly instead of inlined.
            options = VariableTracker.propagate(self, args, kwargs.values())
            return invoke_and_store_as_constant(
                tx, self.fn, self.get_name(), options, args, kwargs
            )

        return super().call_function(tx, args, kwargs)
class UserMethodVariable(UserFunctionVariable):
    """A user-defined function bound to a receiver ``obj``.

    The receiver is supplied as the implicit first argument whenever the
    method is inlined.
    """

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

    def __str__(self):
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        # The bound receiver is passed ahead of the explicit arguments.
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # Methods of nn.Module receivers are routed to
        # NNModuleVariable.call_method instead of being inlined. This applies
        # when the method is defined in a `torch.nn.` module or is marked
        # constant. It lets e.g. `forward` of allowed modules appear as a
        # `call_method` op in the FX graph, since those can be traced by AOT
        # safely. This is not limited to allowed modules: a user module may
        # subclass an allowed module and reuse the parent's `forward`, and
        # that case is also covered here.
        if not isinstance(self.obj, variables.NNModuleVariable):
            return super().call_function(tx, args, kwargs)

        defining_module = getattr(self.fn, "__module__", "")
        defined_in_torch_nn = defining_module is not None and defining_module.startswith(
            "torch.nn."
        )
        if not (defined_in_torch_nn or self.is_constant):
            return super().call_function(tx, args, kwargs)

        redirected = self.obj.call_method(
            tx, self.fn.__name__, args, kwargs, constant=self.is_constant
        )
        return redirected.add_options(self)

    def num_parameters(self):
        # The receiver is supplied implicitly, so it is not counted.
        return super().num_parameters() - 1
class WrappedUserMethodVariable(UserMethodVariable):
    """A UserMethodVariable whose calls are traced inside ``context``.

    ``context.enter(tx)`` runs before the method call is traced, and
    ``context.exit(tx)`` runs after it returns.
    NOTE(review): ``exit`` is not reached if tracing the call raises; confirm
    that callers restore any state changed by ``enter``.
    """

    def __init__(self, wrapped, context, **kwargs):
        # `fn` and `obj` are always taken from `wrapped`; drop any copies
        # supplied as keywords so they are not passed twice.
        for duplicate in ("fn", "obj"):
            kwargs.pop(duplicate, None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        traced = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return traced
class WrappedUserFunctionVariable(UserFunctionVariable):
    """A UserFunctionVariable whose calls are traced inside ``context``.

    ``context.enter(tx)`` runs before the function call is traced, and
    ``context.exit(tx)`` runs after it returns.
    NOTE(review): ``exit`` is not reached if tracing the call raises; confirm
    that callers restore any state changed by ``enter``.
    """

    def __init__(self, wrapped, context, **kwargs):
        # The function is always taken from `wrapped`; discard any `fn`/`obj`
        # supplied as keywords.
        for duplicate in ("fn", "obj"):
            kwargs.pop(duplicate, None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        self.context.enter(tx)
        traced = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return traced
def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
    """Run ``fn`` eagerly on concrete values and register its result.

    Tensor arguments are replaced by their real values, and every other
    argument by its Python constant. ``fn`` is then called directly. The
    result is registered through ``tx.output.register_attr_or_module`` under
    ``name`` with a ``ConstantSource(name)``.
    """

    def to_concrete(var):
        if isinstance(var, variables.TensorVariable):
            return var.get_real_value()
        return var.as_python_constant()

    concrete_args = list(map(to_concrete, args))
    concrete_kwargs = {key: to_concrete(value) for key, value in kwargs.items()}
    value = fn(*concrete_args, **concrete_kwargs)
    return tx.output.register_attr_or_module(
        value, name, source=ConstantSource(name), **options
    )
class NestedUserFunctionVariable(BaseUserFunctionVariable):
    """A function object whose parts are themselves VariableTrackers.

    The name, code, defaults, kwdefaults, annotations and closure are held
    as trackers, and ``f_globals`` as a real dict. ``reconstruct`` re-creates
    the function with a ``MAKE_FUNCTION`` instruction, optionally wrapped in
    ``functools.wraps(<wraps_source>)``.
    """

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        wraps_source=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        # A closure scope is only meaningful when there is a closure.
        if closure is None:
            closure_scope = None
        self.closure_scope = closure_scope
        self.wraps_source = wraps_source

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        """Materialize a real Python function from the constant parts.

        Raises NotImplementedError when the function has a closure, because
        closure cells cannot be recreated here.
        """
        if self.closure:
            raise NotImplementedError()
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                # Since Python 3.10, MAKE_FUNCTION receives annotations as a
                # flat tuple (name0, ann0, name1, ann1, ...). Pair consecutive
                # *non-overlapping* elements. itertools.pairwise would produce
                # overlapping pairs and invent bogus entries
                # (e.g. an annotation value used as a parameter name).
                annotations = dict(zip(annotations[::2], annotations[1::2]))

            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        """Bind ``args``/``kwargs`` to the nested function's signature.

        Returns ``(result, closure_cells)``. ``result`` maps parameter names to
        VariableTrackers. ``closure_cells`` maps cell variables to new tracked
        cells, and free variables to the corresponding item of
        ``self.closure``.
        """
        code = self.get_code()
        # Throwaway function used only for signature binding; its free
        # variables get placeholder cells (make_cell(None)).
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(self.get_code().co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.items

        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(parent.output.root_tx, result, VariableTracker.propagate(self))
        closure_cells = init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            # NOTE(review): this only fails if the closure item has an
            # attribute literally named `name` whose value differs from
            # `name`; intent unclear, confirm.
            assert getattr(self.closure.items[idx], name, name) == name
            assert name not in result
            closure_cells[name] = self.closure.items[idx]

        return result, closure_cells

    def export_freevars(self, parent, child):
        # Copy back free variables that the inlined child translator holds
        # as symbolic locals into the parent's symbolic locals.
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this function object."""
        # MAKE_FUNCTION flag bits: 0x01 defaults, 0x02 kwdefaults,
        # 0x04 annotations, 0x08 closure.
        flags = 0x00
        if self.defaults:
            flags |= 0x01
            codegen(self.defaults)
        if self.kwdefaults:
            flags |= 0x02
            codegen(self.kwdefaults)
        if isinstance(
            self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
        ):
            flags |= 0x04
            try:
                # Prefer loading the annotations as a single constant.
                if isinstance(self.annotations, variables.ConstDictVariable):
                    annotations = {
                        k: v.as_python_constant()
                        for k, v in self.annotations.items.items()
                    }
                else:
                    annotations = tuple(
                        [v.as_python_constant() for v in self.annotations.items]
                    )
                codegen.extend_output([codegen._create_load_const(annotations)])
            except NotImplementedError:
                # Not constant-foldable: reconstruct the tracker instead.
                codegen(self.annotations)
        if self.closure:
            flags |= 0x08
            codegen(self.closure)
        codegen(self.code)
        if sys.version_info < (3, 11):
            # Before 3.11, MAKE_FUNCTION also consumes the function's name.
            codegen(self.fn_name)
        codegen.extend_output([create_instruction("MAKE_FUNCTION", arg=flags)])
        if self.wraps_source:
            # Apply functools.wraps(<wraps_source>) to the new function:
            # build the decorator, rotate the function above it, then call it.
            codegen.load_import_from("functools", "wraps")
            codegen(self.wraps_source)
            codegen.extend_output(create_call_function(1, True))
            codegen.extend_output(create_rot_n(2))
            codegen.extend_output(create_call_function(1, True))

        return []