mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Wow, I had to sweat so much to get this PR out, lol. This PR enforces the invariant that whenever we allocate SymInts as part of fakeification, the SymInt is associated with a Source; in fact, we store the string source name on SymbolWithSourceName. We use 'sname' as shorthand for "source name", since 'name' is already used by sympy to name symbols. To store source names, we have to plumb them from Dynamo to PyTorch. This made the PR a bit bone-crushing, because there are many points in the Dynamo codebase where we improperly convert intermediate tensors into fake tensors even though there is no source (and there cannot be, because it's a frickin' intermediate tensor). I fixed all of the really awful cases in earlier PRs in the stack; this PR just plumbs source names through from the places where we do have them. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/90295 Approved by: https://github.com/voznesenskym
433 lines
15 KiB
Python
433 lines
15 KiB
Python
import abc
|
|
import enum
|
|
import functools
|
|
import inspect
|
|
import itertools
|
|
import types
|
|
from typing import Dict, List
|
|
|
|
from .. import variables
|
|
from ..bytecode_transformation import create_instruction
|
|
from ..exc import unimplemented
|
|
from ..source import AttrSource, GetItemSource
|
|
from ..utils import make_cell
|
|
from .base import typestr, VariableTracker
|
|
|
|
|
|
def wrap_bound_arg(val, options):
    """Convert a bound argument or default value into a ``VariableTracker``.

    Dicts, tuples and lists are converted element-wise (recursively).
    Literals, plain functions, enum members and classes each get their
    dedicated tracker type. Any other value must already be a
    ``VariableTracker`` and is returned unchanged. ``options`` is forwarded
    as keyword arguments to every tracker constructed here.

    Raises:
        AssertionError: if ``val`` matches none of the handled kinds and is
            not already a ``VariableTracker``.
    """
    if isinstance(val, dict):
        wrapped_items = {key: wrap_bound_arg(item, options) for key, item in val.items()}
        return variables.ConstDictVariable(wrapped_items, dict, **options)

    if isinstance(val, (tuple, list)):
        # Pick the list-like tracker class matching the concrete container type.
        container_cls = variables.BaseListVariable.cls_for(type(val))
        wrapped_elems = [wrap_bound_arg(item, options) for item in val]
        return container_cls(wrapped_elems, **options)

    if variables.ConstantVariable.is_literal(val):
        return variables.ConstantVariable(val, **options)

    if isinstance(val, types.FunctionType):
        return variables.UserFunctionVariable(val, **options)

    if isinstance(val, enum.Enum):
        return variables.EnumVariable(val, **options)

    if isinstance(val, (type, abc.ABCMeta)):
        return variables.UserDefinedClassVariable(val, **options)

    # Fallback: the value must already be a tracker.
    assert isinstance(val, VariableTracker), typestr(val)
    return val
|
|
|
|
|
|
def wrap_args_kwargs(result, options):
    """Wrap tuple- and dict-valued entries of ``result`` in place.

    ``result`` maps parameter names to bound values. Entries whose value is
    a ``tuple`` or ``dict`` (the packed ``*args`` / ``**kwargs`` collections)
    are replaced with ``wrap_bound_arg(value, options)``. All other entries
    are left untouched. Returns ``None``.
    """
    # Iterate over a snapshot so entries can be reassigned during the walk.
    for param_name, bound_value in list(result.items()):
        if not isinstance(bound_value, (tuple, dict)):
            continue
        result[param_name] = wrap_bound_arg(bound_value, options)
|
|
|
|
|
|
def init_cellvars(parent, result, code):
    """Allocate a new tracked cell for every cell variable of ``code``.

    For each name in ``code.co_cellvars``, a fresh cell is requested from
    ``parent.output.side_effects.track_cell_new()``. If the name is also a
    bound argument in ``result``, that argument is removed from ``result``
    and stored into the new cell via ``store_cell``.

    Returns:
        dict mapping each cell-variable name to its newly tracked cell.
    """
    tracker = parent.output.side_effects
    cells = {}
    for cellvar in code.co_cellvars:
        new_cell = tracker.track_cell_new()
        cells[cellvar] = new_cell
        if cellvar in result:
            # The argument is moved out of the plain locals and into its cell.
            tracker.store_cell(new_cell, result.pop(cellvar))
    return cells
|
|
|
|
|
|
class BaseUserFunctionVariable(VariableTracker):
    """Shared behaviour for trackers wrapping user-defined Python functions.

    Subclasses provide ``get_code``, ``get_function`` and ``self_args``.
    This base derives the file and function names from the code object and
    inlines calls through the instruction translator ``tx``.
    """

    def get_filename(self):
        """Return ``co_filename`` of the wrapped code object."""
        code = self.get_code()
        return code.co_filename

    def get_name(self):
        """Return ``co_name`` of the wrapped code object."""
        code = self.get_code()
        return code.co_name

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Inline a call, prepending ``self_args()`` to the positional args."""
        full_args = [*self.self_args(), *args]
        return tx.inline_user_function_return(self, full_args, kwargs)

    def num_parameters(self):
        """Return the parameter count of ``get_function()``'s signature."""
        signature = inspect.signature(self.get_function())
        return len(signature.parameters)

    def closure_vars(self, tx):
        """Return an empty mapping; this base contributes no closure variables."""
        return {}
|
|
|
|
|
|
class UserFunctionVariable(BaseUserFunctionVariable):
    """Tracker for a plain user-defined Python function (``types.FunctionType``)."""

    def __init__(self, fn, is_constant=False, **kwargs):
        super().__init__(**kwargs)
        # Constant-ness is taken only from the ``_dynamo_marked_constant``
        # marker on ``fn``; the ``is_constant`` argument is accepted but not
        # consulted.
        self.is_constant = bool(getattr(fn, "_dynamo_marked_constant", False))

        assert isinstance(
            fn, types.FunctionType
        ), f"expected FunctionType found {typestr(fn)} {fn}"
        # Look through a @torch._dynamo.optimize()(fn) wrapper.
        fn = inspect.getattr_static(fn, "_torchdynamo_inline", fn)
        # Look through a torch.jit.script_if_tracing wrapper.
        if inspect.getattr_static(fn, "__script_if_tracing_wrapper", False):
            fn = inspect.getattr_static(fn, "__original_fn", fn)
        self.fn: types.FunctionType = fn

    def self_args(self):
        """A plain function has no implicit leading arguments."""
        return []

    def get_function(self):
        return self.fn

    def get_code(self):
        return self.fn.__code__

    def python_type(self):
        return types.FunctionType

    def has_self(self):
        return getattr(self.fn, "__self__", None) is not None

    def get_globals(self):
        return self.fn.__globals__

    def bind_args(self, parent, args, kwargs):
        """Bind ``args``/``kwargs`` against ``self.fn``'s signature.

        Defaults and keyword-only defaults are wrapped as trackers before
        binding. Free variables are resolved from ``self.fn.__closure__``.

        Returns:
            ``(result, closure_cells)``. ``result`` maps parameter and
            free-variable names to trackers. ``closure_cells`` is the mapping
            produced by ``init_cellvars`` for the code's cell variables.
        """
        assert not self.is_constant
        options = VariableTracker.propagate([self])
        wrap = functools.partial(wrap_bound_arg, options=options)
        tx = parent.output.root_tx

        fn: types.FunctionType = self.fn
        # Clone fn with tracker-wrapped defaults so that inspect can bind
        # the traced arguments (and fill in defaults) against its signature.
        wrapped_defaults = tuple(wrap(default) for default in (fn.__defaults__ or []))
        fake_func = types.FunctionType(
            fn.__code__, fn.__globals__, fn.__name__, wrapped_defaults, fn.__closure__
        )
        if fn.__kwdefaults__:
            fake_func.__kwdefaults__ = {
                key: wrap(default) for key, default in fn.__kwdefaults__.items()
            }

        bound = inspect.signature(fake_func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(result, options)
        closure_cells = init_cellvars(parent, result, fn.__code__)

        closure = self.fn.__closure__ or ()
        freevars = self.fn.__code__.co_freevars
        assert len(closure) == len(freevars)

        for idx, (name, cell) in enumerate(zip(freevars, closure)):
            if name == "__class__":
                result[name] = variables.UserDefinedClassVariable(cell.cell_contents)
                continue

            var = tx.match_nested_cell(name, cell)
            if var is not None:
                # optimization for cleaner codegen
                result[name] = var
            elif self.source:
                result[name] = self._bind_closure_cell(parent, tx, idx, cell)
            else:
                unimplemented("inline with __closure__")

        return result, closure_cells

    def _bind_closure_cell(self, parent, tx, idx, cell):
        """Return the tracker to bind for closure cell number ``idx`` of ``self.fn``.

        Only called when ``self.source`` is set, so the cell can be addressed
        as ``<source>.__closure__[idx].cell_contents``.
        """
        from .builder import VariableBuilder

        side_effects = parent.output.side_effects
        if cell in side_effects:
            # The cell is already tracked; reuse that tracker.
            return side_effects[cell]

        closure_cell = GetItemSource(AttrSource(self.source, "__closure__"), idx)
        closure_cell_contents = AttrSource(closure_cell, "cell_contents")
        contents_var = VariableBuilder(parent, closure_cell_contents)(
            cell.cell_contents
        )

        if closure_cell_contents.name() not in tx.mutated_closure_cell_contents:
            # Optimistically don't allocate the cell, to reduce the number of
            # side effects. This is important for cond, as without it, any
            # accesses to closures create side effects and cond doesn't support
            # side effects. If we're wrong and this closure cell gets written
            # to, we will restart the analysis with this cell's name in the
            # mutated list here.
            return contents_var

        # Cells are written to via "cell_contents", so the tracked source is
        # the closure cell itself, not its contents.
        tracked_cell = side_effects.track_cell_existing(closure_cell, cell)
        side_effects.store_cell(tracked_cell, contents_var)
        return tracked_cell

    def export_freevars(self, parent, child):
        """Nothing to export: free variables live in ``fn.__closure__``."""
        pass

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Inline the call, or evaluate it eagerly when marked constant."""
        if not self.is_constant:
            return super().call_function(tx, args, kwargs)

        options = VariableTracker.propagate(self, args, kwargs.values())
        return invoke_and_store_as_constant(
            tx, self.fn, self.get_name(), options, args, kwargs
        )
|
|
|
|
|
|
class UserMethodVariable(UserFunctionVariable):
    """A user-defined function bound to the receiver tracker ``obj``."""

    def __init__(self, fn, obj, **kwargs):
        super().__init__(fn=fn, **kwargs)
        self.obj = obj

    def __str__(self):
        return f"{self.__class__.__name__}({self.fn}, {self.obj})"

    def self_args(self):
        """The receiver is passed as the implicit first argument."""
        return [self.obj]

    def python_type(self):
        return types.MethodType

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Dispatch via ``obj.call_method`` or inline the method body.

        ``torch.nn.*`` methods on an ``NNModuleVariable`` receiver, and
        methods marked constant, go through the receiver's ``call_method``.
        All other methods are inlined.
        """
        is_torch_nn_module_method = isinstance(
            self.obj, variables.NNModuleVariable
        ) and getattr(self.fn, "__module__", "").startswith("torch.nn.")

        if not (is_torch_nn_module_method or self.is_constant):
            return super().call_function(tx, args, kwargs)

        return self.obj.call_method(
            tx, self.fn.__name__, args, kwargs, constant=self.is_constant
        ).add_options(self)

    def num_parameters(self):
        """Parameter count excluding the bound receiver."""
        return super().num_parameters() - 1
|
|
|
|
|
|
class WrappedUserMethodVariable(UserMethodVariable):
    """A user method whose inlined call is bracketed by ``context.enter/exit``."""

    def __init__(self, wrapped, context, **kwargs):
        # ``fn`` and ``obj`` are taken from ``wrapped``; discard any copies
        # that arrived through kwargs so they are not passed twice.
        for duplicate in ("fn", "obj"):
            kwargs.pop(duplicate, None)
        super().__init__(wrapped.fn, wrapped.obj, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # NOTE: there is no try/finally, so ``context.exit`` is not reached
        # if the inlined call raises.
        self.context.enter(tx)
        out = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return out
|
|
|
|
|
|
class WrappedUserFunctionVariable(UserFunctionVariable):
    """A user function whose inlined call is bracketed by ``context.enter/exit``."""

    def __init__(self, wrapped, context, **kwargs):
        # ``fn`` is taken from ``wrapped``; drop stale ``fn``/``obj`` kwargs.
        for duplicate in ("fn", "obj"):
            kwargs.pop(duplicate, None)
        super().__init__(wrapped.fn, **kwargs)
        self.wrapped = wrapped
        self.context = context

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        # NOTE: there is no try/finally, so ``context.exit`` is not reached
        # if the inlined call raises.
        self.context.enter(tx)
        out = super().call_function(tx, args, kwargs)
        self.context.exit(tx)
        return out
|
|
|
|
|
|
def invoke_and_store_as_constant(tx, fn, name, options, args, kwargs):
    """Eagerly call ``fn`` on concrete values and register the result.

    Each ``TensorVariable`` argument is replaced by ``get_real_value()``.
    Every other argument is replaced by ``as_python_constant()``. The value
    returned by ``fn`` is passed to ``tx.output.register_attr_or_module``
    under ``name`` with ``source=None`` plus ``options``, and that call's
    return value is returned.
    """

    def to_concrete(var):
        if isinstance(var, variables.TensorVariable):
            return var.get_real_value()
        return var.as_python_constant()

    concrete_args = list(map(to_concrete, args))
    concrete_kwargs = {key: to_concrete(value) for key, value in kwargs.items()}
    value = fn(*concrete_args, **concrete_kwargs)
    return tx.output.register_attr_or_module(
        value,
        name,
        source=None,  # constant!
        **options,
    )
|
|
|
|
|
|
class NestedUserFunctionVariable(BaseUserFunctionVariable):
    """Tracker for a function created during tracing by ``MAKE_FUNCTION``.

    The operands ``MAKE_FUNCTION`` consumes (name, code object, defaults,
    keyword-only defaults, annotations and closure) are kept as trackers.
    This lets the function be inlined (``bind_args``), materialized
    (``get_function``) or re-emitted as bytecode (``reconstruct``).
    """

    def __init__(
        self,
        fn_name,
        code,
        f_globals,
        defaults,
        kwdefaults,
        annotations,
        closure,
        closure_scope,
        **kwargs,
    ):
        super(NestedUserFunctionVariable, self).__init__(**kwargs)
        assert isinstance(fn_name.as_python_constant(), str)
        assert isinstance(code.as_python_constant(), types.CodeType)
        assert isinstance(f_globals, dict)
        self.fn_name = fn_name
        self.code = code
        self.f_globals = f_globals
        self.defaults = defaults
        self.kwdefaults = kwdefaults
        self.annotations = annotations
        self.closure = closure
        # closure_scope is only retained when there actually is a closure.
        if closure is None:
            closure_scope = None
        self.closure_scope = closure_scope

    def self_args(self):
        return []

    def get_code(self):
        return self.code.as_python_constant()

    def get_function(self):
        """Materialize a real ``types.FunctionType`` for this function.

        Raises:
            NotImplementedError: if the function has a closure.
        """
        if self.closure:
            raise NotImplementedError()
        func = types.FunctionType(
            self.code.as_python_constant(),
            self.f_globals,
            self.fn_name.as_python_constant(),
        )
        if self.defaults:
            func.__defaults__ = self.defaults.as_python_constant()
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.as_python_constant()
        if self.annotations:
            annotations = self.annotations.as_python_constant()
            if isinstance(annotations, tuple):
                # Since Python 3.10, MAKE_FUNCTION receives annotations as a
                # flat tuple (name0, value0, name1, value1, ...). Pair
                # consecutive *non-overlapping* elements; itertools.pairwise
                # would produce overlapping pairs and mis-map every key after
                # the first (and requires Python 3.10+).
                annotations = dict(zip(annotations[::2], annotations[1::2]))

            # TypeError: __annotations__ must be set to a dict object
            assert isinstance(annotations, dict)
            func.__annotations__ = annotations
        return func

    def has_closure(self):
        return self.closure is not None

    def has_self(self):
        return False

    def get_globals(self):
        return self.f_globals

    def bind_args(self, parent, args, kwargs):
        """Bind ``args``/``kwargs`` against this function's signature.

        Returns:
            ``(result, closure_cells)``. ``result`` maps parameter names to
            trackers. ``closure_cells`` maps cell variables (from
            ``init_cellvars``) and free variables (from ``self.closure``) to
            their cells.
        """
        code = self.get_code()
        # Placeholder cells: this function object is only used to bind the
        # signature; the real closure cells come from self.closure below.
        func = types.FunctionType(
            code,
            self.f_globals,
            self.fn_name.as_python_constant(),
            tuple(self.defaults.items) if self.defaults else None,
            tuple(make_cell(None) for _ in range(len(code.co_freevars))),
        )
        if self.kwdefaults:
            func.__kwdefaults__ = self.kwdefaults.items

        bound = inspect.signature(func).bind(*args, **kwargs)
        bound.apply_defaults()
        result = dict(bound.arguments.items())

        wrap_args_kwargs(result, VariableTracker.propagate(self))
        closure_cells = init_cellvars(parent, result, code)

        for idx, name in enumerate(code.co_freevars):
            cell = self.closure.items[idx]
            assert getattr(cell, name, name) == name
            assert name not in result
            closure_cells[name] = cell

        return result, closure_cells

    def export_freevars(self, parent, child):
        """Copy free variables the child frame holds back into ``parent``."""
        code = self.get_code()
        for var in code.co_freevars:
            if var in child.symbolic_locals:
                parent.symbolic_locals[var] = child.symbolic_locals[var]

    def reconstruct(self, codegen):
        """Emit bytecode that rebuilds this function with ``MAKE_FUNCTION``.

        Optional operands are pushed in the order defaults (flag 0x01),
        kwdefaults (0x02), annotations (0x04), closure (0x08). The code
        object and function name follow, and the ``MAKE_FUNCTION``
        instruction carrying the accumulated flags is returned.
        """
        flags = 0x00
        if self.defaults:
            flags |= 0x01
            codegen(self.defaults)
        if self.kwdefaults:
            flags |= 0x02
            codegen(self.kwdefaults)
        if isinstance(
            self.annotations, (variables.ConstDictVariable, variables.TupleVariable)
        ):
            flags |= 0x04
            try:
                if isinstance(self.annotations, variables.ConstDictVariable):
                    annotations = {
                        k: v.as_python_constant()
                        for k, v in self.annotations.items.items()
                    }
                else:
                    annotations = tuple(
                        v.as_python_constant() for v in self.annotations.items
                    )
            except NotImplementedError:
                # Some annotation values are not constants: rebuild the
                # container by codegen-ing the tracker itself.
                codegen(self.annotations)
            else:
                codegen.extend_output([codegen._create_load_const(annotations)])
        if self.closure:
            flags |= 0x08
            codegen(self.closure)
        codegen(self.code)
        codegen(self.fn_name)
        return [create_instruction("MAKE_FUNCTION", flags)]
|