pytorch/torch/_dynamo/variables/constant.py
Ram Rachum 351d73b97f Fix exception causes all over the codebase (#90271)
This is the continuation to #90134 and hopefully the final PR in this series.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271
Approved by: https://github.com/kit1980
2022-12-07 04:29:00 +00:00

159 lines
5.5 KiB
Python

import operator
from typing import Dict, List
import torch
from .. import variables
from ..exc import unimplemented
from ..utils import istype
from .base import typestr, VariableTracker
class ConstantVariable(VariableTracker):
    """Variable tracker wrapping a plain Python value held in ``self.value``.

    The constructor asserts that the wrapped value is not a ``torch.Tensor``,
    ``torch.SymInt`` or ``torch.SymFloat``.
    """

    def __init__(self, value, **kwargs):
        """Store ``value``; ``kwargs`` are forwarded to ``VariableTracker``."""
        super().__init__(**kwargs)
        assert not isinstance(value, torch.Tensor)
        assert not isinstance(value, torch.SymInt)
        assert not isinstance(value, torch.SymFloat)
        self.value = value

    def as_proxy(self):
        """Return the wrapped value itself."""
        return self.value

    def __str__(self):
        # Only the type name of the wrapped value is shown, not the value.
        return f"ConstantVariable({type(self.value).__name__})"

    def python_type(self):
        """Return the Python type of the wrapped value."""
        return type(self.value)

    def as_python_constant(self):
        """Return the wrapped value."""
        return self.value

    @property
    def items(self):
        """
        Need this when adding a BaseListVariable and a ConstantVariable together.
        Happens in detectron2.
        """
        return self.unpack_var_sequence(tx=None)

    def getitem_const(self, arg: VariableTracker):
        """Index the wrapped value with the constant held by ``arg``.

        Returns a new ``ConstantVariable`` whose options are propagated from
        both ``self`` and ``arg``.
        """
        return ConstantVariable(
            self.value[arg.as_python_constant()],
            **VariableTracker.propagate([self, arg]),
        )

    @staticmethod
    def is_literal(obj):
        """Return True if ``obj`` is built only from literal-like values.

        Accepted are exact ``int``, ``float``, ``bool``, ``None`` and ``str``
        values, and exact ``list``/``tuple``/``set``/``frozenset`` containers
        whose elements are all literals themselves (checked recursively).
        Subclasses are deliberately not accepted, hence ``type(obj) in ...``
        rather than ``isinstance``.
        """
        if type(obj) in (int, float, bool, type(None), str):
            return True
        if type(obj) in (list, tuple, set, frozenset):
            return all(ConstantVariable.is_literal(x) for x in obj)
        return False

    def unpack_var_sequence(self, tx):
        """Wrap every element of the wrapped value in a ``ConstantVariable``.

        Raises:
            NotImplementedError: if a ``TypeError`` is raised while doing so
                (for example because the value is not iterable); the original
                error is attached as the cause.
        """
        try:
            options = VariableTracker.propagate([self])
            return [ConstantVariable(x, **options) for x in self.as_python_constant()]
        except TypeError as e:
            raise NotImplementedError from e

    def const_getattr(self, tx, name):
        """Return attribute ``name`` of the wrapped value.

        Raises:
            NotImplementedError: if the attribute is callable.
        """
        member = getattr(self.value, name)
        if callable(member):
            raise NotImplementedError()
        return member

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Evaluate the method call ``self.value.<name>(*args, **kwargs)``.

        The cases are tried in this order:

        1. A tuple value is unpacked into a ``TupleVariable`` and the call is
           delegated to it.
        2. If any argument is a ``DynamicShapeVariable``, ``__add__`` is
           swapped so the dynamic operand handles it; every other method is
           deferred to the base-class ``call_method``.
        3. If some argument cannot be turned into a Python constant, the call
           is deferred to the base-class ``call_method``.
        4. ``str`` methods, binary arithmetic on ``int``/``float`` values,
           ``__len__`` and ``__contains__`` are computed directly and the
           result is wrapped in a new variable.

        Anything else is reported through ``unimplemented``.
        """
        from .tensor import DynamicShapeVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        if istype(self.value, tuple):
            # empty tuple constant etc
            return variables.TupleVariable(
                items=self.unpack_var_sequence(tx), source=self.source, **options
            ).call_method(tx, name, args, kwargs)

        if any(isinstance(x, DynamicShapeVariable) for x in args):
            # NOTE! DANGER! THIS ONLY WORKS FOR COMMUTATIVE OPS
            # we are relying on add to have arg[0] be a DynamicShapeVariable
            # because we are in ConstantVariable land
            # This transforms
            #   constant + dynamic
            # into
            #   dynamic + constant
            # Which already has infra built for writing to the graph
            if name == "__add__":
                assert len(args) == 1
                return args[0].call_method(tx, name, [self], {})
            # Unfortunate constant
            return super().call_method(tx, name, args, kwargs)

        try:
            const_args = [a.as_python_constant() for a in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        except NotImplementedError:
            return super().call_method(tx, name, args, kwargs)

        def has_arith_binop(num_ty):
            # True when ``name`` is also a function in the ``operator`` module
            # and is applied to a ``num_ty`` value with exactly one constant
            # argument.
            return (
                isinstance(self.value, num_ty)
                and hasattr(operator, name)
                and len(args) == 1
                and args[0].is_python_constant()
            )

        if isinstance(self.value, str) and name in str.__dict__:
            assert not kwargs
            method = getattr(self.value, name)
            return ConstantVariable(method(*const_args, **const_kwargs), **options)
        elif has_arith_binop(int) or has_arith_binop(float):
            op = getattr(operator, name)
            add_target = const_args[0]
            if isinstance(add_target, (torch.SymInt, torch.SymFloat)):
                # Addition between a non sym and sym makes a sym
                proxy = tx.output.create_proxy(
                    "call_function", op, (self.value, add_target), {}
                )
                return DynamicShapeVariable.create(tx, proxy, add_target, **options)
            return ConstantVariable(op(self.value, add_target), **options)
        elif name == "__len__" and not (args or kwargs):
            return ConstantVariable(len(self.value), **options)
        elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
            assert not kwargs
            search = args[0].as_python_constant()
            result = search in self.value
            return ConstantVariable(result, **options)

        unimplemented(f"const method call {typestr(self.value)}.{name}")
class EnumVariable(VariableTracker):
    """Variable tracker wrapping a single value stored as ``self.value``.

    NOTE(review): going by the class name the value is presumably an
    ``enum.Enum`` member, but nothing in this class enforces that.
    """

    def __init__(self, value, **kwargs):
        """Store ``value``; ``kwargs`` are forwarded to ``VariableTracker``."""
        super().__init__(**kwargs)
        self.value = value

    def as_proxy(self):
        """Return the wrapped value itself."""
        return self.value

    def __str__(self):
        # Shows the type of the wrapped value, not the value.
        return f"EnumVariable({type(self.value)})"

    def python_type(self):
        """Return the Python type of the wrapped value."""
        return type(self.value)

    def as_python_constant(self):
        """Return the wrapped value."""
        return self.value