pytorch/torch/_dynamo/variables/base.py
Xuehai Pan 5b1cedacde [BE] [2/3] Rewrite super() calls in functorch and torch (#94588)
Rewrite Python built-in class `super()` calls. Only non-semantic changes should be applied.

- #94587
- #94588
- #94592

Also, methods with only a `super()` call are removed:

```diff
class MyModule(nn.Module):
-   def __init__(self):
-       super().__init__()
-
    def forward(self, ...):
        ...
```

Some cases that change the semantics should be kept unchanged. E.g.:

f152a79be9/caffe2/python/net_printer.py (L184-L190)

f152a79be9/test/test_jit_fuser_te.py (L2628-L2635)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/94588
Approved by: https://github.com/ezyang, https://github.com/albanD
2023-02-10 21:16:33 +00:00

297 lines
9.0 KiB
Python

import collections
from typing import Any, Callable, Dict, List, Optional, Set
from .. import variables
from ..exc import unimplemented
from ..source import AttrSource, Source
from ..utils import dict_values, identity, istype, odict_values
class MutableLocal:
    """
    Marker used to indicate this (list, iter, etc) was constructed in
    local scope and can be mutated safely in analysis without leaking
    state.

    Markers compare and hash by object identity: a marker is equal only to
    itself, so every instance is a distinct set member / dict key.
    """

    def __hash__(self):
        # Identity hash, kept consistent with the identity-based __eq__ below
        # (defining __eq__ alone would make the class unhashable).
        return id(self)

    def __eq__(self, other):
        return self is other
# metaclass to call post_init
class HasPostInit(type):
    """Metaclass that calls ``__post_init__`` on each newly created instance.

    Classes using this metaclass must define ``__post_init__``; it receives
    exactly the positional and keyword arguments that were passed to the
    constructor call.
    """

    def __call__(cls, *args, **kwargs):
        # Normal construction first (type.__call__ drives __new__/__init__) ...
        obj = type.__call__(cls, *args, **kwargs)
        # ... then the post-construction hook, with the same arguments.
        obj.__post_init__(*args, **kwargs)
        return obj
class VariableTracker(metaclass=HasPostInit):
    """
    Base class for tracked locals and stack values

    VariableTracker instances are immutable and should be copied in
    order to change them.

    Because of the ``HasPostInit`` metaclass, ``__post_init__`` runs after
    ``__init__`` on every construction (including the ones done by
    ``clone``).
    """

    # fields to leave unmodified in apply()
    _nonvar_fields = ["value"]

    @staticmethod
    def propagate(*vars: List[List["VariableTracker"]]):
        """Combine the guards from many VariableTracker into **kwargs for a new instance"""
        guards = set()

        def visit(var):
            # Exact-type check: only these plain containers are walked;
            # anything else must itself be a VariableTracker.
            if type(var) in (list, tuple, dict_values, odict_values):
                for i in var:
                    visit(i)
            else:
                assert isinstance(var, VariableTracker), typestr(var)
                guards.update(var.guards)

        # `vars` is the tuple of positional arguments, so it is walked like
        # any other container.
        visit(vars)
        return {
            "guards": guards,
        }

    def clone(self, **kwargs):
        """Shallow copy with some (optional) changes"""
        # Every instance attribute is passed back to the constructor as a
        # keyword argument, with `kwargs` taking precedence.
        args = dict(self.__dict__)
        args.update(kwargs)
        return self.__class__(**args)

    @classmethod
    def copy(cls, value):
        """Deeper (but not full) copy, leaving FX and user objects alone"""
        return cls.apply(identity, value)

    @classmethod
    def apply(
        cls,
        fn: Callable[["VariableTracker"], "VariableTracker"],
        value,
        cache=None,
        skip_fn=lambda _: False,  # Whether we should skip applying to this var
    ):
        """
        Walk this object and call fn on all the VariableTracker
        instances to produce a new VariableTracker with the results.

        ``value`` may be a VariableTracker, a list, tuple, OrderedDict or
        dict (exact types, via ``istype``); any other object is returned
        unchanged.  ``cache`` maps ``id(value)`` to ``(result, value)`` so an
        object reached more than once is processed only once.  When
        ``skip_fn(var)`` is true, ``fn`` is called on ``var`` directly and its
        fields are not walked.
        """
        if cache is None:
            cache = {}

        idx = id(value)
        if idx in cache:
            return cache[idx][0]

        if isinstance(value, VariableTracker):
            if not skip_fn(value):
                # Work on a copy of the attribute dict; only values of
                # existing keys are reassigned, so iterating it is safe.
                updated_dict = dict(value.__dict__)
                for key in updated_dict:
                    if key not in value._nonvar_fields:
                        updated_dict[key] = cls.apply(
                            fn, updated_dict[key], cache, skip_fn
                        )
                # fn receives a fresh clone built from the processed fields.
                result = fn(value.clone(**updated_dict))
            else:
                result = fn(value)
        elif istype(value, list):
            result = [cls.apply(fn, v, cache, skip_fn) for v in value]
        elif istype(value, tuple):
            result = tuple(cls.apply(fn, v, cache, skip_fn) for v in value)
        elif istype(value, collections.OrderedDict):
            # Each (key, value) item is processed as a tuple, so - unlike the
            # plain-dict branch below - keys also pass through apply().
            result = collections.OrderedDict(
                cls.apply(fn, v, cache, skip_fn) for v in value.items()
            )
        elif istype(value, dict):
            result = {
                k: cls.apply(fn, v, cache, skip_fn) for k, v in list(value.items())
            }
        else:
            result = value

        # save `value` to keep it alive and ensure id() isn't reused
        cache[idx] = (result, value)
        return result

    def add_guard(self, guard):
        """Return a clone whose guard set additionally contains ``guard``."""
        return self.clone(guards=set.union(self.guards, {guard}))

    def add_guards(self, guards):
        """Return a clone with ``guards`` (a set) merged in.

        ``None`` means "nothing to add" and returns ``self`` unchanged.
        """
        if guards is None:
            return self
        assert isinstance(guards, set)
        return self.clone(guards=set.union(self.guards, guards))

    def add_options(self, options, *more):
        """Merge guards taken from one or more option sources.

        Each source is either a VariableTracker (its ``guards`` are used) or
        a dict (its ``"guards"`` entry is used, defaulting to an empty set).
        Sources are applied left to right.
        """
        if more:
            return self.add_options(options).add_options(*more)
        if isinstance(options, VariableTracker):
            return self.add_guards(options.guards)
        assert isinstance(options, dict)
        return self.add_guards(options.get("guards", set()))

    def __str__(self):
        return f"{self.__class__.__name__}()"

    def __repr__(self):
        return str(self)

    def python_type(self):
        """Base implementation: always raises NotImplementedError."""
        raise NotImplementedError(f"{self} has no type")

    def as_python_constant(self):
        """For constants"""
        raise NotImplementedError(f"{self} is not a constant")

    def is_python_constant(self):
        """True when ``as_python_constant()`` does not raise NotImplementedError."""
        try:
            self.as_python_constant()
            return True
        except NotImplementedError:
            return False

    def as_specialized(self, tx):
        """
        For specialized variables, return itself,
        For unspecialized variables, convert to constant variable and return.
        """
        return self

    def can_make_guard(self):
        """True when ``make_guard(None)`` does not raise NotImplementedError."""
        try:
            self.make_guard(None)
            return True
        except NotImplementedError:
            return False

    def make_guard(self, fn):
        """Delegate to ``self.source.make_guard(fn)``.

        Raises NotImplementedError when this variable has no (truthy) source.
        """
        if self.source:
            return self.source.make_guard(fn)
        raise NotImplementedError()

    def replace_guards(self, guards, *fns):
        """Return a new guard set with this source's guards replaced.

        Guards in ``guards`` (may be None) whose ``name`` equals
        ``self.source.name()`` are dropped, then one guard per function in
        ``fns`` is added via ``self.source.make_guard``.  ``self.source`` is
        dereferenced unconditionally, so it must be set.
        """
        name = self.source.name()
        new_guards = {g for g in (guards or []) if g.name != name}
        new_guards.update(self.source.make_guard(fn) for fn in fns)
        return new_guards

    def const_getattr(self, tx, name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError()

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """getattr(self, name) returning a new variable"""
        options = VariableTracker.propagate(self)
        value = self.const_getattr(tx, name)
        # Only values accepted by ConstantVariable.is_literal can be wrapped.
        if not variables.ConstantVariable.is_literal(value):
            raise NotImplementedError()
        if self.source:
            options["source"] = AttrSource(self.source, name)
        return variables.ConstantVariable(value, **options)

    def is_proxy(self):
        """True when ``as_proxy()`` does not raise NotImplementedError."""
        try:
            self.as_proxy()
            return True
        except NotImplementedError:
            return False

    def as_proxy(self):
        """Base implementation: always raises NotImplementedError."""
        raise NotImplementedError(str(self))

    def reconstruct(self, codegen):
        """Base implementation: always raises NotImplementedError."""
        raise NotImplementedError()

    def unpack_var_sequence(self, tx):
        """Base implementation: always raises NotImplementedError."""
        raise NotImplementedError()

    def has_unpack_var_sequence(self, tx):
        """True when ``unpack_var_sequence(tx)`` does not raise NotImplementedError."""
        try:
            self.unpack_var_sequence(tx)
            return True
        except NotImplementedError:
            return False

    def num_parameters(self):
        """Base implementation: reports the operation through ``unimplemented``."""
        unimplemented(f"num_parameters: {self}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        """Base implementation: reports the operation through ``unimplemented``."""
        unimplemented(f"hasattr: {repr(self)}")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Base implementation: reports the operation through ``unimplemented``."""
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle a method call on this variable.

        Two methods are supported here:

        * ``__len__`` when the variable can be unpacked as a sequence; the
          result is a ConstantVariable holding the unpacked length.
        * ``__getattr__`` with a single python-constant argument and no
          kwargs; handled through ``var_getattr``.

        Everything else is reported through ``unimplemented``.
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            return variables.ConstantVariable(
                len(self.unpack_var_sequence(tx)), **VariableTracker.propagate(self)
            )
        elif (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            return self.var_getattr(tx, args[0].as_python_constant()).add_options(
                self, args[0]
            )
        # NOTE(review): `unimplemented` is called without `raise` elsewhere in
        # this class, so it presumably raises by itself; the explicit `raise`
        # is kept so the end of control flow is visible here.
        raise unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def __init__(
        self,
        guards: Optional[Set] = None,
        source: Optional[Source] = None,
        mutable_local: Optional[MutableLocal] = None,
        recursively_contains: Optional[Set] = None,
    ):
        """Store the tracking metadata.

        ``guards`` falls back to a new empty set when None or empty.
        ``recursively_contains`` may be left as None, in which case
        ``__post_init__`` computes it.
        """
        super().__init__()
        self.guards = guards or set()
        self.source = source
        self.mutable_local = mutable_local
        self.recursively_contains = (
            recursively_contains  # provides hint to replace_all when replacing vars
        )

    def __post_init__(self, *args, **kwargs):
        """Compute ``recursively_contains`` if the constructor did not get one.

        The set collects the ``recursively_contains`` entries and the
        non-None ``mutable_local`` of every tracker visited by ``apply``.
        """
        if self.recursively_contains is None:
            self.recursively_contains = set()

            def aggregate_mutables(var):
                self.recursively_contains.update(var.recursively_contains)
                if var.mutable_local is not None:
                    self.recursively_contains.add(var.mutable_local)
                return var

            # skip_fn is true for every tracker except `self`, so apply()
            # walks only self's own fields and hands nested trackers to
            # aggregate_mutables without descending into them.
            VariableTracker.apply(
                aggregate_mutables, self, skip_fn=lambda var: var is not self
            )

        assert None not in self.recursively_contains
def typestr(*objs):
    """Describe one or more objects for use in diagnostic messages.

    With exactly one argument: a VariableTracker is rendered with ``str()``,
    any other object by the ``__name__`` of its type.  With zero or several
    arguments: each is described individually and the pieces are joined with
    single spaces (so no arguments yields an empty string).
    """
    if len(objs) != 1:
        return " ".join(map(typestr, objs))

    (obj,) = objs
    if isinstance(obj, VariableTracker):
        return str(obj)
    return type(obj).__name__