pytorch/torch/_dynamo/variables/constant.py
lezcano a9dca53438 NumPy support in torch.compile (#106211)
RFC: https://github.com/pytorch/rfcs/pull/54
First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/

We have already been using this in core for the last few months as an external dependency. This PR pulls all these into core.

In the next commits, I do a number of things in this order
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number of dynamo tests that were passing before (they were not testing anything, I think) and are not passing now.

Missing from this PR (but not blocking):
- Have a flag that deactivates tracing NumPy functions and simply breaks. There used to be one but after the merge stopped working and I removed it. @lezcano to investigate.
- https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge.

All the tests in `tests/torch_np` take about 75s to run.

This was a work by @ev-br, @rgommers @honno and I. I did not create this PR via ghstack (which would have been convenient) as this is a collaboration, and ghstack doesn't allow for shared contributions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211
Approved by: https://github.com/ezyang
2023-08-11 00:39:32 +00:00

186 lines
6.8 KiB
Python

import operator
from typing import Dict, List
import torch
from .. import variables
from ..exc import unimplemented
from ..utils import istype, np
from .base import typestr, VariableTracker
_type_to_assert_reason = {
# NB - We CAN have ConstantVariable(set) because of how sets interact with guards.
# A locally created set should always become a SetVariable, as the items in the set will already either be sourced
# from somewhere else, or unsourced. An input set would imply sources derived from set contents. For example, an
# input list's contents will have a source like some_list[0], some_list[1][1], etc. For a set, arbitrary access is
# not possible. This is a solvable problem, but one we have not taken on yet. As such, input sets are not allowed to
# become SetVariables. The solution here is to create a ConstantSetVariable that is more like a ConstantVariable.
# As this does not exist, we cannot add sets to this invariant.
list: "List types must use ListVariable.",
dict: "Dict types must use ConstDictVariable.",
torch.Tensor: "Tensor types must use TensorVariable.",
torch.SymInt: "SymInts must use SymNodeVariable. "
"If the underlying value is static, we will create a ConstantVariable and specialize.",
torch.SymFloat: "SymInts must use SymNodeVariable",
}
class ConstantVariable(VariableTracker):
    """VariableTracker for Python literal constants seen during tracing.

    A "literal" is a scalar literal type (int, float, bool, None, str) or a
    list/tuple/set/frozenset whose elements are all literals (see is_literal).
    Types listed in _type_to_assert_reason are rejected in __init__ because
    they must be modeled by dedicated VariableTracker subclasses.
    """

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        # Non-literal values must not be wrapped here; each disallowed type
        # carries a message naming the VariableTracker to use instead.
        if not ConstantVariable.is_literal(value):
            for disallowed_type, reason in _type_to_assert_reason.items():
                assert not isinstance(value, disallowed_type), reason
        # Unwrap NumPy scalars to plain Python scalars so downstream code
        # only ever sees ordinary int/float/bool values.
        if isinstance(value, np.number):
            self.value = value.item()
        else:
            self.value = value

    def as_proxy(self):
        # Constants are traced by value; the proxy is the value itself.
        return self.value

    def __str__(self):
        # return f"ConstantVariable({self.value})"
        return f"ConstantVariable({type(self.value).__name__})"

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    @property
    def items(self):
        """
        Need this when adding a BaseListVariable and a ConstantVariable together.
        Happens in detectron2.
        """
        return self.unpack_var_sequence(tx=None)

    def getitem_const(self, arg: VariableTracker):
        # Constant indexing of a constant folds to another constant.
        return ConstantVariable(
            self.value[arg.as_python_constant()],
            **VariableTracker.propagate([self, arg]),
        )

    @staticmethod
    def is_literal(obj):
        # Scalar literal types are literals; containers are literals iff
        # every element is (checked recursively).
        if type(obj) in (int, float, bool, type(None), str):
            return True
        if type(obj) in (list, tuple, set, frozenset):
            return all(ConstantVariable.is_literal(x) for x in obj)
        return False

    def unpack_var_sequence(self, tx):
        """Wrap each element of the underlying value as a ConstantVariable.

        Raises NotImplementedError (chained from TypeError) when the value
        is not iterable, matching the VariableTracker unpack contract.
        """
        try:
            options = VariableTracker.propagate([self])
            return [ConstantVariable(x, **options) for x in self.as_python_constant()]
        except TypeError as e:
            raise NotImplementedError from e

    def const_getattr(self, tx, name):
        member = getattr(self.value, name)
        # Callable attributes (methods) cannot be constant-folded here.
        if callable(member):
            raise NotImplementedError()
        return member

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Constant-fold a method call on this constant when possible.

        Order matters: tuple/list constants are first re-dispatched to the
        richer list variable types; calls with SymNodeVariable arguments are
        promoted to symbolic; only then is plain constant folding attempted.
        Unhandled calls end in unimplemented() (graph break).
        """
        from .tensor import SymNodeVariable

        options = VariableTracker.propagate(self, args, kwargs.values())

        if istype(self.value, tuple):
            # empty tuple constant etc
            return variables.TupleVariable(
                items=self.unpack_var_sequence(tx), source=self.source, **options
            ).call_method(tx, name, args, kwargs)
        if istype(self.value, list):
            return variables.ListVariable(
                items=self.unpack_var_sequence(tx), source=self.source, **options
            ).call_method(tx, name, args, kwargs)

        if any(isinstance(x, SymNodeVariable) for x in args):
            # Promote to SymNodeVariable for operations involving dynamic shapes.
            return variables.SymNodeVariable(self.as_proxy(), self.value).call_method(
                tx, name, args, kwargs
            )

        try:
            const_args = [a.as_python_constant() for a in args]
            const_kwargs = {k: v.as_python_constant() for k, v in kwargs.items()}
        except NotImplementedError:
            # Some argument is not a compile-time constant; defer to the base
            # class fallback handling.
            return super().call_method(tx, name, args, kwargs)

        def has_arith_binop(num_ty):
            # True when `name` is an operator-module binary op (e.g. "add")
            # applied to a numeric constant with one constant argument.
            return (
                isinstance(self.value, num_ty)
                and hasattr(operator, name)
                and len(args) == 1
                and args[0].is_python_constant()
            )

        if isinstance(self.value, str) and name in str.__dict__.keys():
            # Fold str methods directly, e.g. "ab".upper() -> "AB".
            method = getattr(self.value, name)
            return ConstantVariable(method(*const_args, **const_kwargs), **options)
        elif has_arith_binop(int) or has_arith_binop(float):
            op = getattr(operator, name)
            add_target = const_args[0]
            if isinstance(add_target, (torch.SymInt, torch.SymFloat)):
                from .tensor import SymNodeVariable

                # Addition between a non sym and sym makes a sym
                # sym_num = tx.output.register_attr_or_module(
                #     add_target, f"sym_shape_{add_target}", source=None
                # )
                proxy = tx.output.create_proxy(
                    "call_function", op, (self.value, add_target), {}
                )
                return SymNodeVariable.create(tx, proxy, add_target, **options)
            return ConstantVariable(op(self.value, add_target), **options)
        elif name == "__len__" and not (args or kwargs):
            return ConstantVariable(len(self.value), **options)
        elif name == "__contains__" and len(args) == 1 and args[0].is_python_constant():
            assert not kwargs
            search = args[0].as_python_constant()
            result = search in self.value
            return ConstantVariable(result, **options)

        unimplemented(f"const method call {typestr(self.value)}.{name}")

    def call_hasattr(self, tx, name: str) -> "VariableTracker":
        # hasattr on a constant is itself a compile-time constant.
        result = hasattr(self.value, name)
        return variables.ConstantVariable(result).add_options(self)
class EnumVariable(VariableTracker):
    """VariableTracker wrapping an enum member encountered during tracing."""

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def as_proxy(self):
        # The enum member itself serves as its own proxy.
        return self.value

    def __str__(self):
        return f"EnumVariable({type(self.value)})"

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def const_getattr(self, tx, name):
        """Constant-fold attribute access on the member; callables cannot
        be folded and raise NotImplementedError."""
        attr_value = getattr(self.value, name)
        if not callable(attr_value):
            return attr_value
        raise NotImplementedError()