mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Removes always restore, assuming that a HOP will cleanup any leftover state from tracing fwd + bwd This required a minor change to the autograd fn variable higher order op. If we are tracing forward DON'T add the call_function node into the main graph, since we are only tracing it for the purposes of speculation. Instead return the result directly to be passed to the backward for speculation. This was the only observable side effect on the output graph that I found. Test plan: test_smoke_from_test_autograd in test_autograd_function.py Pull Request resolved: https://github.com/pytorch/pytorch/pull/115317 Approved by: https://github.com/voznesenskym, https://github.com/jansel
1163 lines
42 KiB
Python
1163 lines
42 KiB
Python
import collections
|
||
import dataclasses
|
||
import functools
|
||
import inspect
|
||
import itertools
|
||
import operator
|
||
import sys
|
||
import types
|
||
from typing import Dict, List
|
||
|
||
import torch._C
|
||
import torch._numpy as tnp
|
||
from .. import config, polyfill, variables
|
||
from ..bytecode_transformation import create_call_function, create_instruction
|
||
from ..exc import unimplemented
|
||
from ..guards import GuardBuilder, install_guard
|
||
from ..source import AttrSource, GetItemSource, ODictGetItemSource, TypeSource
|
||
from ..utils import (
|
||
check_constant_args,
|
||
identity,
|
||
is_tensor_base_attr_getter,
|
||
proxy_args_kwargs,
|
||
)
|
||
from .base import MutableLocal, VariableTracker
|
||
from .dicts import DefaultDictVariable
|
||
from .functions import (
|
||
NestedUserFunctionVariable,
|
||
UserFunctionVariable,
|
||
UserMethodVariable,
|
||
)
|
||
from .user_defined import UserDefinedObjectVariable
|
||
|
||
|
||
class SuperVariable(VariableTracker):
|
||
def __init__(self, typevar, objvar=None, specialized=False, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.typevar = typevar
|
||
self.objvar = objvar
|
||
self.specialized = specialized # directly get attr from self.typevar if true
|
||
|
||
def reconstruct(self, codegen):
|
||
codegen(variables.BuiltinVariable(super))
|
||
codegen(self.typevar)
|
||
if self.objvar is not None:
|
||
codegen(self.objvar)
|
||
return create_call_function(2, True)
|
||
else:
|
||
return create_call_function(1, True)
|
||
|
||
def _resolved_getattr_and_source(self, tx, name):
|
||
assert self.objvar, "1-arg super not implemented"
|
||
if self.specialized:
|
||
return getattr(self.typevar.as_python_constant(), name)
|
||
search_type = self.typevar.as_python_constant()
|
||
|
||
# We default to the python type of the object. However, if this is
|
||
# a `type` or subclass of `type`, then the original object represents
|
||
# the user defined type.
|
||
type_to_use = self.objvar.python_type()
|
||
type_to_use_source = (
|
||
TypeSource(self.objvar.source) if self.objvar.source else None
|
||
)
|
||
if issubclass(type_to_use, type):
|
||
type_to_use = self.objvar.value
|
||
type_to_use_source = self.objvar.source
|
||
|
||
source = None
|
||
if self.objvar.source is not None:
|
||
# Walk the mro tuple to find out the actual class where the
|
||
# attribute resides.
|
||
search_mro = type_to_use.__mro__
|
||
start_index = search_mro.index(search_type) + 1
|
||
for index in range(start_index, len(search_mro)):
|
||
if hasattr(search_mro[index], name):
|
||
# Equivalent of something like type(L['self']).__mro__[1].attr_name
|
||
source = AttrSource(
|
||
GetItemSource(AttrSource(type_to_use_source, "__mro__"), index),
|
||
name,
|
||
)
|
||
break
|
||
|
||
# TODO(jansel): there is a small chance this could trigger user code, prevent that
|
||
return getattr(super(search_type, type_to_use), name), source
|
||
|
||
def var_getattr(self, tx, name: str) -> "VariableTracker":
|
||
# Check if getattr is a constant. If not, delay the actual work by
|
||
# wrapping the result in GetAttrVariable. Mostly super is called with a
|
||
# method, so most of the work is delayed to call_function.
|
||
#
|
||
# We could have just implemented a const_getattr. However, super is
|
||
# special when it comes to finding sources. Compared to other VTs, super
|
||
# requires the attr name to walk the mro and find the actual source (and
|
||
# not just AttrSource).
|
||
value, source = self._resolved_getattr_and_source(self, name)
|
||
if not variables.ConstantVariable.is_literal(value):
|
||
return GetAttrVariable(self, name)
|
||
if source:
|
||
install_guard(source.make_guard(GuardBuilder.CONSTANT_MATCH))
|
||
return variables.ConstantVariable.create(value, source=source)
|
||
return variables.ConstantVariable.create(value)
|
||
|
||
def call_method(
|
||
self,
|
||
tx,
|
||
name,
|
||
args: "List[VariableTracker]",
|
||
kwargs: "Dict[str, VariableTracker]",
|
||
) -> "VariableTracker":
|
||
inner_fn, source = self._resolved_getattr_and_source(self, name)
|
||
|
||
if inner_fn is object.__init__:
|
||
return LambdaVariable(identity)
|
||
elif inner_fn is torch.nn.Module.__init__:
|
||
objvar = self.objvar
|
||
from ..side_effects import AttributeMutationNew
|
||
|
||
if (
|
||
isinstance(objvar, variables.UserDefinedObjectVariable)
|
||
and isinstance(objvar.mutable_local, AttributeMutationNew)
|
||
and not (args or kwargs)
|
||
):
|
||
tx.output.side_effects.store_attr(
|
||
objvar,
|
||
"__call_nn_module_init",
|
||
variables.ConstantVariable.create(True),
|
||
)
|
||
return variables.ConstantVariable.create(None)
|
||
else:
|
||
unimplemented("super() nn.Module.__init__")
|
||
elif isinstance(inner_fn, types.FunctionType):
|
||
return variables.UserFunctionVariable(
|
||
inner_fn, source=source
|
||
).call_function(tx, [self.objvar] + args, kwargs)
|
||
elif isinstance(inner_fn, types.MethodType):
|
||
return variables.UserMethodVariable(
|
||
inner_fn.__func__, self.objvar, source=source
|
||
).call_function(tx, args, kwargs)
|
||
elif (
|
||
inner_fn is collections.OrderedDict.__getitem__
|
||
and isinstance(self.objvar, variables.UserDefinedObjectVariable)
|
||
and self.objvar.source
|
||
and len(args) == 1
|
||
and len(kwargs) == 0
|
||
and args[0].is_python_constant()
|
||
):
|
||
from .builder import VariableBuilder
|
||
|
||
key = args[0].as_python_constant()
|
||
return VariableBuilder(tx, ODictGetItemSource(self.objvar.source, key))(
|
||
collections.OrderedDict.__getitem__(self.objvar.value, key)
|
||
)
|
||
elif (
|
||
inner_fn in (collections.OrderedDict.__setitem__, object.__setattr__)
|
||
and isinstance(self.objvar, variables.CustomizedDictVariable)
|
||
and args
|
||
and variables.ConstDictVariable.is_valid_key(args[0])
|
||
and self.objvar.mutable_local
|
||
):
|
||
assert not kwargs and len(args) == 2
|
||
k = variables.ConstDictVariable.get_key(args[0])
|
||
|
||
newval = dict(self.objvar.items)
|
||
newval[k] = args[1]
|
||
return tx.replace_all(
|
||
self.objvar,
|
||
self.objvar.modifed(newval),
|
||
)
|
||
else:
|
||
unimplemented(f"non-function or method super: {inner_fn}")
|
||
|
||
|
||
class UnknownVariable(VariableTracker):
|
||
"""
|
||
It could be anything!
|
||
"""
|
||
|
||
|
||
class DelayGraphBreakVariable(UnknownVariable):
|
||
"""
|
||
Used to insert a dummy variable in the stack to do the graph break at CALL_FUNCTION.
|
||
"""
|
||
|
||
|
||
class ComptimeVariable(VariableTracker):
|
||
"""
|
||
This variable is special, it lets you execute arbitrary code at
|
||
Dynamo compile time
|
||
"""
|
||
|
||
def reconstruct(self, codegen):
|
||
raise NotImplementedError("comptime is special form")
|
||
|
||
def var_getattr(self, tx, name: str) -> "VariableTracker":
|
||
from ..comptime import comptime
|
||
|
||
# To support the comptime.print_graph convenience accessors
|
||
from .functions import UserFunctionVariable
|
||
|
||
return UserFunctionVariable(
|
||
getattr(comptime, name), source=AttrSource(self.source, name)
|
||
)
|
||
|
||
def call_function(
|
||
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
|
||
) -> "VariableTracker":
|
||
from ..comptime import ComptimeContext
|
||
|
||
# TODO: support an expression form as well
|
||
|
||
assert not kwargs
|
||
assert len(args) == 1
|
||
fn = args[0]
|
||
if isinstance(fn, UserFunctionVariable):
|
||
fn.get_function()(ComptimeContext(tx))
|
||
elif isinstance(fn, NestedUserFunctionVariable):
|
||
# We have to manually bind the freevars ourselves
|
||
code = fn.get_code()
|
||
assert not fn.closure, (
|
||
"comptime function must not have free variables, "
|
||
f"but these variables were free: {code.co_freevars}"
|
||
)
|
||
func = types.FunctionType(
|
||
code,
|
||
fn.f_globals,
|
||
fn.fn_name.as_python_constant(),
|
||
tuple(fn.defaults.items) if fn.defaults else None,
|
||
# We could automatically promote free variables into
|
||
# ComptimeVar but this is confusing if you access
|
||
# a free variable that we actually DO have the runtime
|
||
# value for
|
||
# tuple(make_cell(ComptimeVar(i)) for i in fn.closure.items)
|
||
tuple(),
|
||
)
|
||
func(ComptimeContext(tx))
|
||
else:
|
||
raise RuntimeError(f"unsupported argument to comptime: {type(fn)}")
|
||
|
||
return variables.ConstantVariable.create(None)
|
||
|
||
|
||
class ClosureVariable(UnknownVariable):
|
||
def __init__(self, name, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.name = name
|
||
|
||
def reconstruct(self, codegen):
|
||
return [codegen.create_load_closure(self.name)]
|
||
|
||
|
||
# closure variable created by an inlined function
|
||
class InlinedClosureVariable(UnknownVariable):
|
||
def __init__(self, name, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.name = name
|
||
|
||
def reconstruct(self, codegen):
|
||
return [codegen.create_load_closure(self.name)]
|
||
|
||
|
||
class NewCellVariable(VariableTracker):
|
||
def __init__(self, **kwargs):
|
||
super().__init__(**kwargs)
|
||
|
||
|
||
class NewGlobalVariable(VariableTracker):
|
||
def __init__(self, **kwargs):
|
||
super().__init__(**kwargs)
|
||
|
||
|
||
class InspectSignatureVariable(VariableTracker):
|
||
"""represents inspect.signature(...)"""
|
||
|
||
@staticmethod
|
||
def create(callable, **kwargs):
|
||
if kwargs:
|
||
unimplemented(f"inspect.signature with {kwargs}")
|
||
return InspectSignatureVariable(callable)
|
||
|
||
def __init__(self, inspected: VariableTracker, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.inspected = inspected
|
||
|
||
def var_getattr(self, tx, name: str) -> "VariableTracker":
|
||
if name == "parameters":
|
||
return variables.ConstDictVariable(
|
||
{
|
||
name: InspectParameterVariable()
|
||
for name in self.inspected.inspect_parameter_names()
|
||
},
|
||
user_cls=dict,
|
||
)
|
||
return super().var_getattr(tx, name)
|
||
|
||
|
||
class InspectParameterVariable(VariableTracker):
|
||
"""This is not implemented, if used will graph break."""
|
||
|
||
pass
|
||
|
||
|
||
def produce_trampoline_autograd_fwd(fn_cls):
|
||
def trampoline_autograd_fwd(*args, **kwargs):
|
||
return fn_cls.forward(*args, **kwargs)
|
||
|
||
trampoline_autograd_fwd._origin = produce_trampoline_autograd_fwd
|
||
return trampoline_autograd_fwd
|
||
|
||
|
||
def produce_trampoline_autograd_bwd(fn_cls):
|
||
def trampoline_autograd_bwd(*args, **kwargs):
|
||
return fn_cls.backward(*args, **kwargs)
|
||
|
||
trampoline_autograd_bwd._origin = produce_trampoline_autograd_bwd
|
||
return trampoline_autograd_bwd
|
||
|
||
|
||
def produce_trampoline_autograd_apply(fn_cls):
|
||
def trampoline_autograd_apply(*args, **kwargs):
|
||
return fn_cls.apply(*args, **kwargs)
|
||
|
||
trampoline_autograd_apply._origin = produce_trampoline_autograd_apply
|
||
return trampoline_autograd_apply
|
||
|
||
|
||
class AutogradFunctionVariable(VariableTracker):
|
||
"""represents a torch.autograd.Function subclass"""
|
||
|
||
def __init__(self, fn_cls, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.fn_cls = fn_cls
|
||
|
||
def call_apply(self, tx, args, kwargs):
|
||
requires_grad = False
|
||
|
||
def visit(node):
|
||
nonlocal requires_grad
|
||
if isinstance(node, variables.TensorVariable):
|
||
if node.requires_grad is not False:
|
||
requires_grad = True
|
||
if isinstance(node, variables.NNModuleVariable):
|
||
if node.is_training(tx):
|
||
requires_grad = True
|
||
return node
|
||
|
||
VariableTracker.apply(visit, (args, kwargs))
|
||
|
||
ctx = AutogradFunctionContextVariable.create(tx)
|
||
args = [ctx, *args]
|
||
|
||
if (
|
||
requires_grad
|
||
and torch.is_grad_enabled()
|
||
and config.capture_autograd_function
|
||
):
|
||
# Note - this is the same check used in autograd/function.py, except inverted.
|
||
# If we want to support functorch transforms here, we will need to enable this.
|
||
if (
|
||
self.fn_cls.setup_context
|
||
!= torch.autograd.function._SingleLevelFunction.setup_context
|
||
):
|
||
unimplemented(
|
||
"NYI - autograd.Function with custom setup_context method"
|
||
)
|
||
|
||
vjp_fn = self.fn_cls.vjp # type: ignore[attr-defined]
|
||
if vjp_fn is not torch.autograd.Function.vjp:
|
||
unimplemented("NYI - User defind vjp")
|
||
|
||
jvp_fn = self.fn_cls.jvp # type: ignore[attr-defined]
|
||
if jvp_fn is not torch.autograd.Function.jvp:
|
||
unimplemented("NYI - User defind jvp")
|
||
|
||
from .higher_order_ops import TorchHigherOrderOperatorVariable
|
||
|
||
trampoline_autograd_apply = produce_trampoline_autograd_apply(self.fn_cls)
|
||
trampoline_autograd_fwd = produce_trampoline_autograd_fwd(self.fn_cls)
|
||
trampoline_autograd_bwd = produce_trampoline_autograd_bwd(self.fn_cls)
|
||
|
||
# NOTE [On Tracing autograd.Function w/ grad]
|
||
# The complex system described here revolves around the soundness evaluation of an autograd.Function in
|
||
# PyTorch. The system follows a well-defined strategy for tracing, which involves three key steps: tracing
|
||
# forward, tracing backward, and if both are sound the potential recording of an "apply" operation into the
|
||
# graph.We trace forward, and evaluate soundness. Soundness, in this context, refers to the absence of side
|
||
# effects, the avoidance of lifting new arguments into the trace, the production of a single tensor output,
|
||
# and a limited input scope confined to contexts, tensors, and constants. If the forward trace is sound,
|
||
# we install any guards accumulated from tracing. If not, we graph break. We trace backward, and evaluate
|
||
# for soundness, same as forward, except with more strictness. We enable a strict mode on the tx, and
|
||
# reject certain ops when running under this strict mode. If both the forward and backward traces are sound,
|
||
# we write the autograd function’s apply into the graph.
|
||
|
||
# For tracing forward and backward, we use UserFunctionVariable. Although it does not directly contribute
|
||
# to soundness evaluation, it plus a GlobalSource makes sure we can produce valid guards,
|
||
# and that we can inline properly here. Inlining is required in order to be able to ensure that the
|
||
# soundness evaluation works as described above.
|
||
|
||
module_source = AttrSource(
|
||
tx.import_source(self.fn_cls.__module__), self.fn_cls.__name__
|
||
)
|
||
fwd_bwd_tracer = torch._dynamo.output_graph.SubgraphTracer(
|
||
tx.output,
|
||
parent=tx.output.current_tracer,
|
||
source_target="autograd.Function",
|
||
)
|
||
higher_order_autograd_fn = TorchHigherOrderOperatorVariable.make(
|
||
trampoline_autograd_fwd,
|
||
source=AttrSource(module_source, "forward"),
|
||
fwd_bwd_tracer=fwd_bwd_tracer,
|
||
)
|
||
speculated_fwd_result = higher_order_autograd_fn.call_function(
|
||
tx, args, kwargs
|
||
)
|
||
|
||
if isinstance(speculated_fwd_result, variables.TupleVariable):
|
||
bwd_args = [ctx, *speculated_fwd_result.items]
|
||
else:
|
||
bwd_args = [ctx, speculated_fwd_result]
|
||
|
||
TorchHigherOrderOperatorVariable.make(
|
||
trampoline_autograd_bwd,
|
||
source=AttrSource(module_source, "backward"),
|
||
fwd_bwd_tracer=fwd_bwd_tracer,
|
||
).call_function(tx, bwd_args, {})
|
||
|
||
# If fwd and backward are sound, we want apply in the graph.
|
||
# We don't want backward because we are tracing forwards.
|
||
args = args[1:]
|
||
return TorchHigherOrderOperatorVariable.make(
|
||
trampoline_autograd_apply,
|
||
fwd_bwd_tracer=None,
|
||
).call_function(tx, args, kwargs)
|
||
|
||
if self.source:
|
||
source = AttrSource(AttrSource(self.source, "__class__"), "forward")
|
||
else:
|
||
source = None
|
||
fn = self.fn_cls.forward
|
||
if isinstance(fn, types.FunctionType):
|
||
return variables.UserFunctionVariable(fn, source=source).call_function(
|
||
tx, args, kwargs
|
||
)
|
||
elif isinstance(fn, types.MethodType):
|
||
return variables.UserMethodVariable(
|
||
fn.__func__,
|
||
variables.UserDefinedClassVariable(self.fn_cls),
|
||
source=source,
|
||
).call_function(tx, args, kwargs)
|
||
else:
|
||
unimplemented(
|
||
f"non-function or method in subclass of torch.autograd.Function: {fn}"
|
||
)
|
||
|
||
def call_function(self, tx, args, kwargs):
|
||
return AutogradFunctionVariable(self.fn_cls)
|
||
|
||
def call_method(
|
||
self,
|
||
tx,
|
||
name,
|
||
args: "List[VariableTracker]",
|
||
kwargs: "Dict[str, VariableTracker]",
|
||
):
|
||
from ..allowed_functions import is_user_defined_allowed
|
||
from .builder import wrap_fx_proxy
|
||
|
||
if name == "apply":
|
||
if is_user_defined_allowed(self.fn_cls):
|
||
trampoline_autograd_apply = produce_trampoline_autograd_apply(
|
||
self.fn_cls
|
||
)
|
||
return wrap_fx_proxy(
|
||
tx=tx,
|
||
proxy=tx.output.create_proxy(
|
||
"call_function",
|
||
trampoline_autograd_apply,
|
||
*proxy_args_kwargs(args, kwargs),
|
||
),
|
||
)
|
||
else:
|
||
return self.call_apply(tx, args, kwargs)
|
||
elif name == "backward":
|
||
with tx.strict_translation_mode():
|
||
if isinstance(self.fn_cls.backward, types.FunctionType):
|
||
backward = UserFunctionVariable(self.fn_cls.backward)
|
||
elif isinstance(self.fn_cls.backward, types.MethodType):
|
||
backward = UserMethodVariable(
|
||
self.fn_cls.backward.__func__,
|
||
variables.UserDefinedClassVariable(self.fn_cls),
|
||
)
|
||
args = [backward.obj] + args
|
||
else:
|
||
unimplemented(
|
||
f"backward is a non-function or method: {self.fn_cls.backward}"
|
||
)
|
||
|
||
return tx.inline_call(tx, backward, args, kwargs)
|
||
|
||
elif name == "forward":
|
||
if isinstance(self.fn_cls.forward, types.FunctionType):
|
||
forward = UserFunctionVariable(self.fn_cls.forward)
|
||
elif isinstance(self.fn_cls.forward, types.MethodType):
|
||
forward = UserMethodVariable(
|
||
self.fn_cls.forward.__func__,
|
||
variables.UserDefinedClassVariable(self.fn_cls),
|
||
)
|
||
args = [forward.obj] + args
|
||
else:
|
||
unimplemented(
|
||
f"forward is a non-function or method: {self.fn_cls.forward}"
|
||
)
|
||
|
||
return tx.inline_call(tx, forward, args, kwargs)
|
||
|
||
else:
|
||
unimplemented(f"Unsupported method: {name}")
|
||
|
||
|
||
@dataclasses.dataclass
|
||
class SavedTensorBox:
|
||
tensors: List[VariableTracker] = dataclasses.field(default_factory=list)
|
||
|
||
|
||
class AutogradFunctionContextVariable(UserDefinedObjectVariable):
|
||
"""
|
||
Tracks an autograd.Function() context using mutation tracking in side_effects.py
|
||
"""
|
||
|
||
_nonvar_fields = {
|
||
"proxy",
|
||
"inference",
|
||
*UserDefinedObjectVariable._nonvar_fields,
|
||
}
|
||
|
||
def __init__(
|
||
self,
|
||
value,
|
||
value_type=None,
|
||
inference=False,
|
||
proxy=None,
|
||
saved_tensors=None,
|
||
**kwargs,
|
||
):
|
||
super().__init__(value=value, value_type=value_type, **kwargs)
|
||
self.inference = inference
|
||
self.proxy = proxy
|
||
self.saved_tensors = saved_tensors
|
||
|
||
@staticmethod
|
||
def create(tx):
|
||
proxy = tx.output.create_proxy(
|
||
"call_function", torch.autograd.function.FunctionCtx, tuple(), {}
|
||
)
|
||
out = tx.output.side_effects.track_object_new(
|
||
None,
|
||
torch.autograd.function.FunctionCtx,
|
||
functools.partial(
|
||
AutogradFunctionContextVariable,
|
||
inference=True,
|
||
proxy=proxy,
|
||
saved_tensors=SavedTensorBox(),
|
||
),
|
||
{},
|
||
)
|
||
proxy.node.meta["example_value"] = out.value
|
||
return out
|
||
|
||
def as_proxy(self):
|
||
if self.proxy is None:
|
||
unimplemented("proxy not set")
|
||
return self.proxy
|
||
|
||
def call_method(
|
||
self,
|
||
tx,
|
||
name,
|
||
args: "List[VariableTracker]",
|
||
kwargs: "Dict[str, VariableTracker]",
|
||
) -> "VariableTracker":
|
||
if name != "save_for_backward":
|
||
unimplemented(f"autograd.Function context method: {name}")
|
||
if self.saved_tensors is None:
|
||
unimplemented(
|
||
"save_for_backward only supported on a newly constructed FunctionCtx"
|
||
)
|
||
|
||
if not self.inference:
|
||
assert self.source and not kwargs
|
||
tx.output.side_effects.track_save_for_backward(self, args)
|
||
|
||
for arg in args:
|
||
self.saved_tensors.tensors.append(arg)
|
||
return variables.ConstantVariable.create(None)
|
||
|
||
def var_getattr(self, tx, name):
|
||
if name == "save_for_backward":
|
||
return LambdaVariable(
|
||
lambda *args, **kwargs: self.call_method(tx, name, args, kwargs)
|
||
)
|
||
if name == "saved_tensors" and self.saved_tensors is not None:
|
||
return variables.TupleVariable(list(self.saved_tensors.tensors))
|
||
return super().var_getattr(tx, name)
|
||
|
||
|
||
class LambdaVariable(VariableTracker):
|
||
def __init__(self, fn, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.fn = fn
|
||
|
||
def call_function(
|
||
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
|
||
) -> "VariableTracker":
|
||
return self.fn(*args, **kwargs)
|
||
|
||
|
||
class GetAttrVariable(VariableTracker):
|
||
def __init__(self, obj, name, **kwargs):
|
||
super().__init__(**kwargs)
|
||
assert isinstance(obj, VariableTracker)
|
||
assert isinstance(name, str)
|
||
self.obj = obj
|
||
self.name = name
|
||
|
||
def __str__(self):
|
||
return f"{self.__class__.__name__}({self.obj}, {self.name})"
|
||
|
||
@staticmethod
|
||
def create_getattr_proxy(base_proxy: torch.fx.Proxy, attr):
|
||
return getattr(base_proxy, attr)
|
||
|
||
def as_proxy(self):
|
||
return GetAttrVariable.create_getattr_proxy(self.obj.as_proxy(), self.name)
|
||
|
||
def const_getattr(self, tx, name):
|
||
if not isinstance(self.obj, variables.NNModuleVariable):
|
||
raise NotImplementedError()
|
||
step1 = tx.output.get_submodule(self.obj.module_key)
|
||
if self.name not in step1.__dict__:
|
||
raise NotImplementedError()
|
||
step2 = inspect.getattr_static(step1, self.name)
|
||
if name not in step2.__dict__:
|
||
raise NotImplementedError()
|
||
return inspect.getattr_static(step2, name)
|
||
|
||
def reconstruct(self, codegen):
|
||
codegen(self.obj)
|
||
return codegen.create_load_attrs(self.name)
|
||
|
||
def call_function(
|
||
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
|
||
) -> "VariableTracker":
|
||
return self.obj.call_method(tx, self.name, args, kwargs)
|
||
|
||
|
||
class MethodWrapperVariable(VariableTracker):
|
||
def __init__(self, method_wrapper, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.method_wrapper = method_wrapper
|
||
|
||
def call_function(
|
||
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
|
||
) -> "VariableTracker":
|
||
if is_tensor_base_attr_getter(self.method_wrapper) and isinstance(
|
||
args[0], variables.TensorVariable
|
||
):
|
||
assert len(args) == 1 and len(kwargs) == 0
|
||
|
||
return args[0].var_getattr(tx, self.method_wrapper.__self__.__name__)
|
||
|
||
super().call_function(tx, args, kwargs)
|
||
|
||
def is_python_constant(self):
|
||
return True
|
||
|
||
def as_python_constant(self):
|
||
return self.method_wrapper
|
||
|
||
|
||
class GetSetDescriptorVariable(VariableTracker):
|
||
def __init__(self, desc, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.desc = desc
|
||
|
||
def var_getattr(self, tx, name):
|
||
if name == "__get__" and self.source:
|
||
from .builder import VariableBuilder
|
||
|
||
return VariableBuilder(tx, AttrSource(self.source, "__get__"))(
|
||
self.desc.__get__
|
||
)
|
||
else:
|
||
return super().var_getattr(tx, name)
|
||
|
||
def is_python_constant(self):
|
||
return True
|
||
|
||
def as_python_constant(self):
|
||
return self.desc
|
||
|
||
|
||
class PythonModuleVariable(VariableTracker):
|
||
def __init__(self, value: types.ModuleType, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.value = value
|
||
|
||
def python_type(self):
|
||
return types.ModuleType
|
||
|
||
|
||
class SkipFilesVariable(VariableTracker):
|
||
def __init__(self, value, reason, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.value = value
|
||
self.reason = reason
|
||
|
||
def python_type(self):
|
||
return type(self.value)
|
||
|
||
def as_python_constant(self):
|
||
return self.value
|
||
|
||
@staticmethod
|
||
@functools.lru_cache(None)
|
||
def fold_through_function_to_wrapper():
|
||
return {
|
||
collections.namedtuple: variables.UserDefinedClassVariable,
|
||
}
|
||
|
||
def call_function(
|
||
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
|
||
) -> "VariableTracker":
|
||
from .builtin import BuiltinVariable
|
||
|
||
if inspect.getattr_static(self.value, "_torchdynamo_disable", False):
|
||
unimplemented(f"call torch._dynamo.disable() wrapped function {self.value}")
|
||
# Allowlist a few popular classes(e.g, collections.OrderedDict) calls in skip files.
|
||
elif self.value is collections.OrderedDict:
|
||
return BuiltinVariable.call_custom_dict(
|
||
tx, collections.OrderedDict, *args, **kwargs
|
||
)
|
||
elif (
|
||
self.value is collections.defaultdict
|
||
and len(args) <= 1
|
||
and DefaultDictVariable.is_supported_arg(args[0])
|
||
):
|
||
return DefaultDictVariable(
|
||
{},
|
||
collections.defaultdict,
|
||
args[0],
|
||
mutable_local=MutableLocal(),
|
||
)
|
||
# Fold through the functions(e.g, collections.namedtuple)
|
||
# that inputs & outputs are all python constants
|
||
elif (
|
||
self.value in self.fold_through_function_to_wrapper().keys()
|
||
and check_constant_args(args, kwargs)
|
||
):
|
||
value = self.value(
|
||
*[x.as_python_constant() for x in args],
|
||
**{k: v.as_python_constant() for k, v in kwargs.items()},
|
||
)
|
||
return self.fold_through_function_to_wrapper().get(self.value)(
|
||
value, mutable_local=MutableLocal()
|
||
)
|
||
elif (
|
||
self.value is itertools.product
|
||
and not kwargs
|
||
and all(arg.has_unpack_var_sequence(tx) for arg in args)
|
||
):
|
||
seqs = [arg.unpack_var_sequence(tx) for arg in args]
|
||
items = []
|
||
for item in itertools.product(*seqs):
|
||
items.append(variables.TupleVariable(list(item)))
|
||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
||
elif (
|
||
self.value is itertools.chain
|
||
and not kwargs
|
||
and all(arg.has_unpack_var_sequence(tx) for arg in args)
|
||
):
|
||
seqs = [arg.unpack_var_sequence(tx) for arg in args]
|
||
items = []
|
||
for item in itertools.chain(*seqs):
|
||
items.append(item)
|
||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
||
elif self.value is itertools.accumulate:
|
||
from .builtin import BuiltinVariable
|
||
|
||
if any(key not in ["initial", "func"] for key in kwargs.keys()):
|
||
unimplemented(
|
||
"Unsupported kwargs for itertools.accumulate: "
|
||
f"{','.join(set(kwargs.keys()) - {'initial', 'func'})}"
|
||
)
|
||
|
||
acc = kwargs.get("initial")
|
||
|
||
if len(args) in [1, 2] and args[0].has_unpack_var_sequence(tx):
|
||
seq = args[0].unpack_var_sequence(tx)
|
||
|
||
if "func" in kwargs and len(args) == 1:
|
||
func = kwargs["func"].call_function
|
||
elif len(args) == 2:
|
||
func = args[1].call_function
|
||
elif len(args) == 1:
|
||
# Default to operator.add
|
||
func = BuiltinVariable(operator.add).call_function
|
||
else:
|
||
unimplemented(
|
||
"itertools.accumulate can only accept one of: `func` kwarg, pos 2 arg"
|
||
)
|
||
else:
|
||
unimplemented("Unsupported arguments for itertools.accumulate")
|
||
|
||
items = []
|
||
if acc is not None:
|
||
items.append(acc)
|
||
for item in seq:
|
||
if acc is None:
|
||
acc = item
|
||
else:
|
||
try:
|
||
acc = func(tx, [acc, item], {})
|
||
except Exception:
|
||
raise unimplemented( # noqa: TRY200
|
||
f"Unexpected failure in invoking function during accumulate. Failed running func {func}({item}{acc})"
|
||
)
|
||
items.append(acc)
|
||
|
||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
||
elif (
|
||
self.value is itertools.combinations
|
||
and not kwargs
|
||
and len(args) == 2
|
||
and args[0].has_unpack_var_sequence(tx)
|
||
and args[1].is_python_constant()
|
||
):
|
||
iterable = args[0].unpack_var_sequence(tx)
|
||
r = args[1].as_python_constant()
|
||
|
||
items = []
|
||
for item in itertools.combinations(iterable, r):
|
||
items.append(variables.TupleVariable(list(item)))
|
||
return variables.ListIteratorVariable(items, mutable_local=MutableLocal())
|
||
elif self.value is itertools.groupby:
|
||
if any(kw != "key" for kw in kwargs.keys()):
|
||
unimplemented(
|
||
"Unsupported kwargs for itertools.groupby: "
|
||
f"{','.join(set(kwargs.keys()) - {'key'})}"
|
||
)
|
||
|
||
def retrieve_const_key(key):
|
||
if isinstance(key, variables.SymNodeVariable):
|
||
return key.evaluate_expr()
|
||
elif isinstance(key, variables.ConstantVariable):
|
||
return key.as_python_constant()
|
||
else:
|
||
raise unimplemented(
|
||
"Unsupported key type for itertools.groupby: " + str(type(key))
|
||
)
|
||
|
||
if len(args) == 1 and args[0].has_unpack_var_sequence(tx):
|
||
seq = args[0].unpack_var_sequence(tx)
|
||
keyfunc = (
|
||
(
|
||
lambda x: (
|
||
retrieve_const_key(
|
||
kwargs.get("key").call_function(tx, [x], {})
|
||
)
|
||
)
|
||
)
|
||
if "key" in kwargs
|
||
else None
|
||
)
|
||
else:
|
||
unimplemented("Unsupported arguments for itertools.groupby")
|
||
|
||
result = []
|
||
try:
|
||
for k, v in itertools.groupby(seq, key=keyfunc):
|
||
result.append(
|
||
variables.TupleVariable(
|
||
[
|
||
variables.ConstantVariable.create(k)
|
||
if variables.ConstantVariable.is_literal(k)
|
||
else k,
|
||
variables.ListIteratorVariable(
|
||
list(v), mutable_local=MutableLocal()
|
||
),
|
||
],
|
||
mutable_local=MutableLocal(),
|
||
)
|
||
)
|
||
except Exception:
|
||
raise unimplemented( # noqa: TRY200
|
||
"Unexpected failure when calling itertools.groupby"
|
||
)
|
||
return variables.ListIteratorVariable(result, mutable_local=MutableLocal())
|
||
elif (
|
||
self.value is functools.wraps
|
||
and not kwargs
|
||
and len(args) == 1
|
||
and (
|
||
args[0].source is not None or args[0].can_reconstruct(tx.output.root_tx)
|
||
)
|
||
):
|
||
|
||
def wraps(fn):
|
||
if isinstance(fn, variables.NestedUserFunctionVariable):
|
||
if args[0].source:
|
||
reconstructible = args[0].source
|
||
else:
|
||
reconstructible = args[0]
|
||
return fn.clone(wrapped_reconstructible=reconstructible)
|
||
unimplemented(f"functools.wraps({fn})")
|
||
|
||
return variables.LambdaVariable(wraps)
|
||
elif self.value is collections.deque and not kwargs:
|
||
if len(args) == 0:
|
||
items = []
|
||
elif len(args) == 1 and args[0].has_unpack_var_sequence(tx):
|
||
items = args[0].unpack_var_sequence(tx)
|
||
else:
|
||
unimplemented("deque() with more than 1 arg not supported")
|
||
return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
|
||
elif self.value is functools.partial:
|
||
if not args:
|
||
unimplemented("functools.partial malformed")
|
||
# The first arg, a callable (the ctor below will assert on types)
|
||
fn = args[0]
|
||
rest_args = args[1:]
|
||
# guards for the produced FunctoolsPartialVariable are installed in FunctoolsPartialVariable ctor from the
|
||
# args and keywords
|
||
return variables.functions.FunctoolsPartialVariable(
|
||
fn, args=rest_args, keywords=kwargs
|
||
)
|
||
elif self.value is itertools.repeat:
|
||
if len(args) < 2:
|
||
return variables.RepeatIteratorVariable(
|
||
*args, mutable_local=MutableLocal()
|
||
)
|
||
|
||
from .builder import SourcelessBuilder
|
||
|
||
return tx.inline_user_function_return(
|
||
SourcelessBuilder()(tx, polyfill.repeat), args, kwargs
|
||
)
|
||
elif self.value is itertools.count:
|
||
return variables.CountIteratorVariable(*args, mutable_local=MutableLocal())
|
||
elif self.value is itertools.cycle:
|
||
return variables.CycleIteratorVariable(*args, mutable_local=MutableLocal())
|
||
else:
|
||
try:
|
||
path = inspect.getfile(self.value)
|
||
except TypeError:
|
||
path = f"Builtin {self.value.__name__}"
|
||
unimplemented(
|
||
f"'call_function {self.value.__qualname__} in skip_files {path}, {self.reason}'"
|
||
)
|
||
|
||
def call_method(
|
||
self,
|
||
tx,
|
||
name,
|
||
args: "List[VariableTracker]",
|
||
kwargs: "Dict[str, VariableTracker]",
|
||
) -> "VariableTracker":
|
||
if (
|
||
self.value in {collections.OrderedDict, collections.defaultdict}
|
||
and name == "fromkeys"
|
||
):
|
||
from .builtin import BuiltinVariable
|
||
|
||
return BuiltinVariable.call_custom_dict_fromkeys(
|
||
tx, self.value, *args, **kwargs
|
||
)
|
||
return super().call_method(tx, name, args, kwargs)
|
||
|
||
|
||
class TypingVariable(VariableTracker):
|
||
def __init__(self, value, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.value = value
|
||
|
||
def call_method(
|
||
self,
|
||
tx,
|
||
name,
|
||
args: "List[VariableTracker]",
|
||
kwargs: "Dict[str, VariableTracker]",
|
||
) -> "VariableTracker":
|
||
if name == "__getitem__" and len(args) == 1:
|
||
return variables.ConstantVariable.create(
|
||
self.value[args[0].as_python_constant()],
|
||
)
|
||
unimplemented("typing")
|
||
|
||
def python_type(self):
|
||
return type(self.value)
|
||
|
||
def as_python_constant(self):
|
||
return self.value
|
||
|
||
|
||
@functools.lru_cache(maxsize=1)
|
||
def get_np_to_tnp_map():
|
||
from ..utils import NP_TO_TNP_MODULE
|
||
|
||
np_fn_to_tnp_fn = {}
|
||
|
||
for np_mod, tnp_mod in NP_TO_TNP_MODULE.items():
|
||
for fn_name, tnp_fn in tnp_mod.__dict__.items():
|
||
if callable(tnp_fn):
|
||
# some internal details do leak from tnp
|
||
# which are not part of numpy API.
|
||
if np_fn := getattr(np_mod, fn_name, None):
|
||
np_fn_to_tnp_fn[np_fn] = tnp_fn
|
||
|
||
return np_fn_to_tnp_fn
|
||
|
||
|
||
class NumpyVariable(VariableTracker):
|
||
"""
|
||
Wrapper around `numpy.*`. Currently, is able to trace a small subset of numpy functions as well as numpy dtypes.
|
||
"""
|
||
|
||
def __init__(self, value, **kwargs):
|
||
super().__init__(**kwargs)
|
||
self.value = value
|
||
|
||
def call_function(
|
||
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
|
||
) -> "VariableTracker":
|
||
if not config.trace_numpy:
|
||
unimplemented(f"numpy.{self.value}()")
|
||
|
||
from ..utils import numpy_to_tensor_wrapper
|
||
|
||
from .tensor import NumpyNdarrayVariable
|
||
|
||
# lookup method name in tnp. Things like np.dtype(float) are not supported yet.
|
||
if self.value.__name__ == "dtype":
|
||
unimplemented(
|
||
f"numpy dtype function is not supported yet. Got type {type(self.value)}."
|
||
)
|
||
else: # We are dealing with a callable.
|
||
func = get_np_to_tnp_map().get(self.value)
|
||
if func is None:
|
||
unimplemented(
|
||
f"Can't find numpy function {self.value} in torch._numpy. "
|
||
" Please file an issue to request support for this function."
|
||
)
|
||
|
||
if (
|
||
func.__module__ == "torch._numpy.random"
|
||
and config.use_numpy_random_stream
|
||
):
|
||
msg = f"delegate '{func.__qualname__}' to NumPy itself via "
|
||
msg += f"confg.use_numpy_random_stream={config.use_numpy_random_stream}"
|
||
unimplemented(msg)
|
||
|
||
# TODO(larryliu0820): currently assuming all numpy.* functions are returning a ndarray that can be
|
||
# wrapped by NumpyNdarrayVariable which is wrong!
|
||
proxy = tx.output.create_proxy(
|
||
"call_function",
|
||
numpy_to_tensor_wrapper(func),
|
||
*proxy_args_kwargs(args, kwargs),
|
||
)
|
||
return NumpyNdarrayVariable.create(tx, proxy)
|
||
|
||
def call_method(
|
||
self,
|
||
tx,
|
||
name,
|
||
args: "List[VariableTracker]",
|
||
kwargs: "Dict[str, VariableTracker]",
|
||
) -> "VariableTracker":
|
||
unimplemented("numpy")
|
||
|
||
def python_type(self):
|
||
return type(self.value)
|
||
|
||
def as_python_constant(self):
|
||
return self.value
|
||
|
||
def as_proxy(self):
|
||
# this handles numpy dtype attribute such as np.float32. TODO(larryliu0820): we should split NumpyVariable
|
||
# into NumpyVariable for instances/objects and NumpyVariable for types.
|
||
if config.trace_numpy and isinstance(self.value, type):
|
||
# retrieve attribute str. E.g., "float32" if given np.float32
|
||
|
||
attr = self.value.__name__
|
||
# get tnp equivalent
|
||
tnp_dtype = tnp.dtype(attr)
|
||
# returning a string here because we are assuming all `dtype` kwargs for numpy
|
||
# functions can take an equivalent string and the behavior of the function would
|
||
# be the same as taking a numpy dtype.
|
||
return tnp_dtype.name
|
||
|
||
return super().as_proxy()
|
||
|
||
|
||
# Used to keep track of NULLs pushed on the stack for Python 3.11 function calls
|
||
class NullVariable(VariableTracker):
|
||
def __init__(self, **kwargs):
|
||
super().__init__(**kwargs)
|
||
|
||
def __str__(self):
|
||
return "NullVariable"
|
||
|
||
def reconstruct(self, codegen):
|
||
if sys.version_info < (3, 11):
|
||
unimplemented("cannot reconstruct NullVariable in < Python 3.11")
|
||
return [create_instruction("PUSH_NULL")]
|
||
|
||
|
||
class DeletedVariable(VariableTracker):
|
||
"""Marker used to implement delattr()"""
|
||
|
||
|
||
class StringFormatVariable(VariableTracker):
|
||
"""
|
||
Represents a call to str.format(), we delay calling format until after the graph.
|
||
"""
|
||
|
||
_nonvar_fields = {"format_string", *VariableTracker._nonvar_fields}
|
||
|
||
@classmethod
|
||
def create(cls, format_string, sym_args, sym_kwargs):
|
||
if all(
|
||
x.is_python_constant()
|
||
for x in itertools.chain(sym_args, sym_kwargs.values())
|
||
):
|
||
return variables.ConstantVariable.create(
|
||
format_string.format(
|
||
*[v.as_python_constant() for v in sym_args],
|
||
**{k: v.as_python_constant() for k, v in sym_kwargs.items()},
|
||
)
|
||
)
|
||
return cls(format_string, list(sym_args), dict(sym_kwargs))
|
||
|
||
def __init__(self, format_string, sym_args, sym_kwargs, **kwargs):
|
||
super().__init__(**kwargs)
|
||
assert isinstance(format_string, str)
|
||
self.format_string = format_string
|
||
self.sym_args = sym_args
|
||
self.sym_kwargs = sym_kwargs
|
||
|
||
def __repr__(self):
|
||
return f"{self.__class__.__name__}({self.format_string!r}, {self.sym_args!r}, {self.sym_kwargs!r})"
|
||
|
||
def reconstruct(self, codegen):
|
||
if sys.version_info >= (3, 11):
|
||
codegen.append_output(create_instruction("PUSH_NULL"))
|
||
codegen.append_output(codegen.create_load_const(self.format_string))
|
||
codegen.append_output(codegen.create_load_attr("format"))
|
||
codegen.extend_output(
|
||
variables.TupleVariable(self.sym_args).reconstruct(codegen)
|
||
)
|
||
codegen.extend_output(
|
||
variables.ConstDictVariable(self.sym_kwargs, user_cls=dict).reconstruct(
|
||
codegen
|
||
)
|
||
)
|
||
codegen.append_output(create_instruction("CALL_FUNCTION_EX", arg=1))
|
||
return []
|