This is the continuation of #90134 and hopefully the final PR in this series. Pull Request resolved: https://github.com/pytorch/pytorch/pull/90271 Approved by: https://github.com/kit1980
import operator
from typing import Dict, List

import torch

from .. import variables
from ..exc import unimplemented
from ..utils import istype
from .base import typestr, VariableTracker


class ConstantVariable(VariableTracker):
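    """VariableTracker for a Python value that is a compile-time constant
    (int, float, bool, str, None, or containers of these). Because the value
    is known while tracing, most operations on it can be constant-folded
    rather than traced into the graph."""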
    def __init__(self, value, **kwargs):
        super(ConstantVariable, self).__init__(**kwargs)
        # Tensors and symbolic numbers have their own VariableTracker
        # subclasses and must never be wrapped as constants.
        assert not isinstance(value, torch.Tensor)
        assert not isinstance(value, torch.SymInt)
        assert not isinstance(value, torch.SymFloat)
        self.value = value

    def as_proxy(self):
        return self.value

    def __str__(self):
        # return f"ConstantVariable({self.value})"
        return f"ConstantVariable({type(self.value).__name__})"

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @property
    def items(self):
        """
        Need this when adding a BaseListVariable and a ConstantVariable together.
        Happens in detectron2.
        """
        return self.unpack_var_sequence(tx=None)

    def getitem_const(self, arg: VariableTracker):
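        # e.g. value ("a", "b") indexed by ConstantVariable(0) folds to
        # ConstantVariable("a") (illustrative example).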
        return ConstantVariable(
            self.value[arg.as_python_constant()],
            **VariableTracker.propagate([self, arg]),
        )

    @staticmethod
    def is_literal(obj):
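        # Illustratively: 1, "x", None, and (1, 2.0) are literals, while
        # torch.ones(1) and [torch.ones(1)] are not.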
        if type(obj) in (int, float, bool, type(None), str):
            return True
        if type(obj) in (list, tuple, set, frozenset):
            return all(ConstantVariable.is_literal(x) for x in obj)
        return False

    def unpack_var_sequence(self, tx):
        try:
            options = VariableTracker.propagate([self])
            return [ConstantVariable(x, **options) for x in self.as_python_constant()]
        except TypeError as e:
            # Non-iterable constants (e.g. an int) cannot be unpacked.
            raise NotImplementedError from e

    def const_getattr(self, tx, name):
        member = getattr(self.value, name)
        if callable(member):
            raise NotImplementedError()
        return member

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        from .tensor import DynamicShapeVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        if istype(self.value, tuple):
            # empty tuple constant etc
            return variables.TupleVariable(
                items=self.unpack_var_sequence(tx), source=self.source, **options
            ).call_method(tx, name, args, kwargs)

        if any(isinstance(x, DynamicShapeVariable) for x in args):
            # NOTE! DANGER! THIS ONLY WORKS FOR COMMUTATIVE OPS.
            # We rely on __add__ having args[0] be a DynamicShapeVariable,
            # because we are in ConstantVariable land. This transforms
            #   constant + dynamic
            # into
            #   dynamic + constant,
            # which already has infra built for writing to the graph.
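            # Concretely: 2 + s, with s a symbolic size, is re-dispatched
            # below as s.__add__(2) (illustrative).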
            if name == "__add__":
                assert len(args) == 1
                return args[0].call_method(tx, name, [self], {})
            # Unfortunate constant
            return super(ConstantVariable, self).call_method(tx, name, args, kwargs)

        try:
            const_args = [a.as_python_constant() for a in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        except NotImplementedError:
            # Some argument is not a compile-time constant; fall back.
            return super(ConstantVariable, self).call_method(tx, name, args, kwargs)

        def has_arith_binop(num_ty):
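            # True when, e.g., name is "__mul__", self.value is an int, and the
            # single arg is constant: operator.__mul__ exists (illustrative).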
            return (
                isinstance(self.value, num_ty)
                and hasattr(operator, name)
                and len(args) == 1
                and args[0].is_python_constant()
            )

        if isinstance(self.value, str) and name in str.__dict__.keys():
            assert not kwargs
            # Constant-fold str methods, e.g. "ab".upper() -> "AB".
            method = getattr(self.value, name)
            return ConstantVariable(method(*const_args, **const_kwargs), **options)
        elif has_arith_binop(int) or has_arith_binop(float):
            op = getattr(operator, name)
            add_target = const_args[0]
            if isinstance(add_target, (torch.SymInt, torch.SymFloat)):
                from .tensor import DynamicShapeVariable

                # Addition between a non-sym and a sym makes a sym.
                # dyn_shape = tx.output.register_attr_or_module(
                #     add_target, f"sym_shape_{add_target}", source=None
                # )
                proxy = tx.output.create_proxy(
                    "call_function", op, (self.value, add_target), {}
                )
                return DynamicShapeVariable.create(tx, proxy, add_target, **options)
            return ConstantVariable(op(self.value, add_target), **options)
        elif name == "__len__" and not (args or kwargs):
            return ConstantVariable(len(self.value), **options)
        elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
            assert not kwargs
            search = args[0].as_python_constant()
            result = search in self.value
            return ConstantVariable(result, **options)

        unimplemented(f"const method call {typestr(self.value)}.{name}")


class EnumVariable(VariableTracker):
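    """VariableTracker for an enum member; like ConstantVariable, the wrapped
    value is known at trace time."""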
    def __init__(self, value, **kwargs):
        super(EnumVariable, self).__init__(**kwargs)
        self.value = value

    def as_proxy(self):
        return self.value

    def __str__(self):
        return f"EnumVariable({type(self.value)})"

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value
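

# Rough usage sketch (illustrative; `tx` stands for the InstructionTranslator
# that dynamo threads through call_method):
#   v = ConstantVariable("hello")
#   v.call_method(tx, "upper", [], {})    # -> ConstantVariable("HELLO")
#   v.call_method(tx, "__len__", [], {})  # -> ConstantVariable(5)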