pytorch/torch/_dynamo/variables/base.py
Jason Ansel c7c09722ad Move TorchDynamo into PyTorch core (#86461)
Context:
https://github.com/pytorch/torchdynamo/issues/1588

This PR moves [TorchDynamo](https://github.com/pytorch/torchdynamo) and TorchInductor into PyTorch core.
- `torchdynamo` becomes `torch._dynamo`
- `torchinductor` becomes `torch._inductor`

This PR was generated by running `copy_to_core.sh` in https://github.com/pytorch/torchdynamo/pull/1538

Pull Request resolved: https://github.com/pytorch/pytorch/pull/86461
Approved by: https://github.com/voznesenskym
2022-10-13 23:18:06 +00:00

263 lines
8.0 KiB
Python

import collections
from typing import Any, Callable, Dict, List, Optional, Set
from .. import variables
from ..exc import unimplemented
from ..source import AttrSource, Source
from ..utils import dict_values, identity, istype, odict_values
class MutableLocal:
    """
    Tag object signalling that a container (list, iter, etc) was built in
    local scope, so analysis may mutate it safely without leaking state.

    Markers compare and hash by object identity: two separately created
    markers are never equal to each other.
    """

    def __eq__(self, other):
        # Equal only to itself.
        return other is self

    def __hash__(self):
        # Keep hashing in line with the identity-based __eq__.
        return id(self)
class VariableTracker:
    """
    Base class for tracked locals and stack values

    VariableTracker instances are immutable and should be copied in
    order to change them.

    Every instance carries three attributes (see ``__init__``):
    ``guards`` (a set), ``source`` and ``mutable_local``.
    """

    # fields to leave unmodified in apply()
    _nonvar_fields = ["value"]

    @staticmethod
    def propagate(*vars: List[List["VariableTracker"]]):
        """Combine the guards from many VariableTracker into **kwargs for a new instance

        Each argument may be a VariableTracker or a (possibly nested) plain
        list / tuple / dict-values view of them.  Returns a dict of the form
        ``{"guards": <set>}`` holding the union of all guards encountered.
        """
        guards = set()

        def visit(var):
            if type(var) in (list, tuple, dict_values, odict_values):
                # Plain containers have no guards of their own; recurse.
                for i in var:
                    visit(i)
            elif isinstance(var, variables.BaseListVariable):
                guards.update(var.guards)
                for i in var.items:
                    visit(i)
            elif isinstance(var, variables.ConstDictVariable):
                guards.update(var.guards)
                visit(var.items.values())
            else:
                assert isinstance(var, VariableTracker), typestr(var)
                guards.update(var.guards)

        visit(vars)
        return {
            "guards": guards,
        }

    def clone(self, **kwargs):
        """Shallow copy with some (optional) changes

        The instance ``__dict__`` is used as constructor keyword arguments,
        with ``kwargs`` overriding individual entries.
        """
        args = dict(self.__dict__)
        args.update(kwargs)
        return self.__class__(**args)

    @classmethod
    def copy(cls, value):
        """Deeper (but not full) copy, leaving FX and user objects alone"""
        return cls.apply(identity, value)

    @classmethod
    def apply(
        cls, fn: Callable[["VariableTracker"], "VariableTracker"], value, cache=None
    ):
        """
        Walk this object and call fn on all the VariableTracker
        instances to produce a new VariableTracker with the results.

        ``value`` may be a VariableTracker, a list, tuple, OrderedDict or
        dict (exact types only); anything else is returned unchanged.
        ``cache`` maps ``id(value)`` to ``(result, value)`` so an object
        reached more than once is only processed once.
        """
        if cache is None:
            cache = {}

        idx = id(value)
        if idx in cache:
            return cache[idx][0]

        if isinstance(value, VariableTracker):
            updated_dict = dict(value.__dict__)
            # Only values are replaced (no keys added or removed), so
            # iterating the dict while assigning into it is safe.
            for key in updated_dict:
                if key not in value._nonvar_fields:
                    updated_dict[key] = cls.apply(fn, updated_dict[key], cache)
            result = fn(value.clone(**updated_dict))
        elif istype(value, list):
            result = [cls.apply(fn, v, cache) for v in value]
        elif istype(value, tuple):
            result = tuple(cls.apply(fn, v, cache) for v in value)
        elif istype(value, collections.OrderedDict):
            # Each (key, value) item is a tuple, so it goes through the
            # tuple branch above and comes back as a (key, value) pair.
            result = collections.OrderedDict(
                cls.apply(fn, v, cache) for v in value.items()
            )
        elif istype(value, dict):
            result = {k: cls.apply(fn, v, cache) for k, v in list(value.items())}
        else:
            result = value

        # save `value` to keep it alive and ensure id() isn't reused
        cache[idx] = (result, value)
        return result

    def add_guard(self, guard):
        """Return a clone whose guards additionally contain ``guard``."""
        return self.clone(guards=set.union(self.guards, {guard}))

    def add_guards(self, guards):
        """Return a clone with ``guards`` (a set) merged in.

        ``None`` is accepted and returns ``self`` unchanged.
        """
        if guards is None:
            return self
        assert isinstance(guards, set)
        return self.clone(guards=set.union(self.guards, guards))

    def add_options(self, options, *more):
        """Merge guards from one or more option sources into a clone.

        Each source is either a VariableTracker (its ``guards`` are used)
        or a dict (its ``"guards"`` entry, if any, is used).
        """
        if more:
            return self.add_options(options).add_options(*more)
        if isinstance(options, VariableTracker):
            return self.add_guards(options.guards)
        assert isinstance(options, dict)
        return self.add_guards(options.get("guards", set()))

    def __str__(self):
        return f"{self.__class__.__name__}()"

    def __repr__(self):
        return str(self)

    def python_type(self):
        """Not available on the base class; raises NotImplementedError."""
        raise NotImplementedError(f"{self} has no type")

    def as_python_constant(self):
        """For constants"""
        raise NotImplementedError(f"{self} is not a constant")

    def is_python_constant(self):
        """Return True if ``as_python_constant()`` does not raise NotImplementedError."""
        try:
            self.as_python_constant()
            return True
        except NotImplementedError:
            return False

    def as_specialized(self, tx):
        """
        For specialized variables, return itself,
        For unspecialized variables, convert to constant variable and return.
        """
        return self

    def can_make_guard(self):
        """Return True if ``make_guard()`` does not raise NotImplementedError."""
        try:
            self.make_guard(None)
            return True
        except NotImplementedError:
            return False

    def make_guard(self, fn):
        """Build a guard via ``self.source``.

        Raises NotImplementedError when this variable has no source.
        """
        if self.source:
            return self.source.make_guard(fn)
        raise NotImplementedError()

    def replace_guards(self, guards, *fns):
        """Return a new guard set with this variable's guards replaced.

        Guards in ``guards`` (may be None) whose ``name`` equals this
        variable's source name are dropped, then one guard per entry of
        ``fns`` is created from the source and added.  Requires
        ``self.source`` to be set.
        """
        name = self.source.name()
        new_guards = {g for g in (guards or []) if g.name != name}
        new_guards.update(self.source.make_guard(fn) for fn in fns)
        return new_guards

    def const_getattr(self, tx, name: str) -> Any:
        """getattr(self, name) returning a python constant"""
        raise NotImplementedError()

    def var_getattr(self, tx, name: str) -> "VariableTracker":
        """getattr(self, name) returning a new variable

        Delegates to ``const_getattr`` and wraps the result in a
        ConstantVariable; raises NotImplementedError when the value is
        not a literal.
        """
        options = VariableTracker.propagate(self)
        value = self.const_getattr(tx, name)
        if not variables.ConstantVariable.is_literal(value):
            raise NotImplementedError()
        if self.source:
            options["source"] = AttrSource(self.source, name)
        return variables.ConstantVariable(value, **options)

    def is_proxy(self):
        """Return True if ``as_proxy()`` does not raise NotImplementedError."""
        try:
            self.as_proxy()
            return True
        except NotImplementedError:
            return False

    def as_proxy(self):
        """Not available on the base class; raises NotImplementedError."""
        raise NotImplementedError(str(self))

    def reconstruct(self, codegen):
        """Not available on the base class; raises NotImplementedError."""
        raise NotImplementedError()

    def unpack_var_sequence(self, tx):
        """Not available on the base class; raises NotImplementedError."""
        raise NotImplementedError()

    def has_unpack_var_sequence(self, tx):
        """Return True if ``unpack_var_sequence(tx)`` does not raise NotImplementedError."""
        try:
            self.unpack_var_sequence(tx)
            return True
        except NotImplementedError:
            return False

    def num_parameters(self):
        """Not supported on the base class; reported through ``unimplemented``."""
        unimplemented(f"num_parameters: {self}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        """Not supported on the base class; reported through ``unimplemented``."""
        unimplemented(f"hasattr: {self}")

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        """Not supported on the base class; reported through ``unimplemented``."""
        unimplemented(f"call_function {self} {args} {kwargs}")

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Handle a method call on this variable.

        Two cases are supported here: ``__len__`` when the variable can be
        unpacked into a sequence, and ``__getattr__`` with a single constant
        argument.  Everything else is reported through ``unimplemented``.
        """
        if name == "__len__" and self.has_unpack_var_sequence(tx):
            assert not (args or kwargs)
            return variables.ConstantVariable(
                len(self.unpack_var_sequence(tx)), **VariableTracker.propagate(self)
            )
        elif (
            name == "__getattr__"
            and len(args) == 1
            and args[0].is_python_constant()
            and not kwargs
        ):
            return self.var_getattr(tx, args[0].as_python_constant()).add_options(
                self, args[0]
            )
        # NOTE(review): unimplemented() presumably raises on its own (the
        # sibling methods rely on that); the `raise` is kept so this method
        # can never fall through and return None.
        raise unimplemented(f"call_method {self} {name} {args} {kwargs}")

    def __init__(
        self,
        guards: Optional[Set] = None,
        source: "Optional[Source]" = None,
        mutable_local: "Optional[MutableLocal]" = None,
    ):
        """Store the guards, source and mutable_local marker for this variable.

        A missing (or empty) ``guards`` argument results in a fresh empty set.
        """
        super().__init__()
        self.guards = guards or set()
        self.source = source
        self.mutable_local = mutable_local
def typestr(*objs):
    """
    Describe one or more objects for use in messages.

    A single VariableTracker is rendered with str(); any other single
    object is rendered as its type name.  With zero or several objects,
    the individual descriptions are joined by single spaces.
    """
    if len(objs) != 1:
        return " ".join(typestr(item) for item in objs)
    target = objs[0]
    return str(target) if isinstance(target, VariableTracker) else type(target).__name__