pytorch/tools/codegen/api/native.py
Scott Wolchok 1211bccc65 [PyTorch] Fix const correctness for resize native functions (#55351)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55351

We incorrectly used `Tensor&` to mean "the underlying
TensorImpl cannot be changed", as explained in
https://github.com/zdevito/ATen/issues/27#issuecomment-330717839 .
This diff gets us on the path to fixing this problem: we have an
incremental way to fix individual native functions so that we can
apply any handwritten fixes a few at a time. It gets the migration
started with the `resize` family of native functions.
ghstack-source-id: 127092677

Test Plan: fitsships

Reviewed By: ezyang

Differential Revision: D27583983

fbshipit-source-id: 4eeeec85f5d268e9d0f1645eb9396914a9f9557f
2021-04-21 14:51:41 -07:00

112 lines
4.7 KiB
Python

from tools.codegen.model import (Argument, FunctionSchema, Return,
SelfArgument, TensorOptionsArguments, Type,
assert_never)
from tools.codegen.api.types import (ArgName, BaseCType, Binding,
ConstRefCType, NamedCType, CType, MutRefCType, ListCType,
OptionalCType, tensorT, scalarT, layoutT,
deviceT, boolT, scalarTypeT)
from tools.codegen.api import cpp
from tools.codegen import local
from typing import Union, Sequence, List, Optional
# This file describes the translation of JIT schema to the native functions API.
# This looks a lot like the C++ API (which makes historical sense, because the
# idea was you wrote native functions to implement functions in the C++ API),
# but over time we have evolved the C++ API without actually changing our
# native:: kernels. The intention is to make native API and dispatcher API
# line up as closely as possible, since this results in the least overhead
# (no translation is needed from dispatcher API to native API).
def name(func: FunctionSchema) -> str:
    """Build the native:: function name for ``func``.

    The result is the schema's base name, followed by ``_out`` when
    ``func.is_out_fn()`` is true, followed by ``_<overload_name>`` when the
    schema has a non-empty overload name.
    """
    pieces = [str(func.name.name)]
    # TODO: delete this!
    if func.is_out_fn():
        pieces.append('_out')
    overload = func.name.overload_name
    if overload:
        pieces.append(f'_{overload}')
    return ''.join(pieces)
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
    """Map a schema type to the NamedCType used in a native signature.

    A handful of schema types (matched on their string form) are handled
    here; every other type is delegated to ``cpp.argumenttype_type`` with
    the same ``mutable`` and ``binds`` values.
    """
    schema_str = str(t)

    if schema_str == 'Tensor?':
        optional_tensor: OptionalCType = OptionalCType(BaseCType(tensorT))
        # A mutable reference is only emitted for a mutable argument while
        # the const-ref-for-mutable-tensors setting is off.
        wants_mut_ref = mutable and not local.use_const_ref_for_mutable_tensors()
        ref_type: CType = (
            MutRefCType(optional_tensor) if wants_mut_ref
            else ConstRefCType(optional_tensor)
        )
        return NamedCType(binds, ref_type)

    if schema_str == 'Tensor?[]':
        optional_tensor_list = ListCType(OptionalCType(BaseCType(tensorT)))
        return NamedCType(binds, ConstRefCType(optional_tensor_list))

    if schema_str == 'Scalar':
        return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))

    if schema_str == 'Scalar?':
        optional_scalar = OptionalCType(BaseCType(scalarT))
        return NamedCType(binds, ConstRefCType(optional_scalar))

    # No native-specific handling: fall back to the C++ API translation.
    return cpp.argumenttype_type(t, mutable=mutable, binds=binds)
def returns_type(rs: Sequence[Return]) -> CType:
    """Return the C++ return type for ``rs``, as computed by ``cpp.returns_type``."""
    native_return_type = cpp.returns_type(rs)
    return native_return_type
def argument_type(a: Argument, *, binds: ArgName) -> NamedCType:
    """Translate the type of argument ``a`` via ``argumenttype_type``.

    The argument's ``is_write`` flag is passed through as ``mutable``.
    """
    schema_type = a.type
    writable = a.is_write
    return argumenttype_type(schema_type, mutable=writable, binds=binds)
def argument(a: Union[Argument, SelfArgument, TensorOptionsArguments], *, is_out: bool) -> List[Binding]:
    """Produce the native-signature Binding(s) for one schema argument.

    An ``Argument`` yields a single binding, a ``SelfArgument`` is unwrapped
    to its inner argument, and a ``TensorOptionsArguments`` expands to four
    bindings: ``dtype``, ``layout``, ``device`` and ``pin_memory``.
    """
    # Native functions would ideally never carry defaults, but some callers
    # invoke native:: directly and depend on them. For BC, defaults are
    # generated for non-out variants only; for out variants no appropriate
    # default can be generated.
    emit_defaults = not is_out

    if isinstance(a, Argument):
        arg_default: Optional[str] = None
        if emit_defaults and a.default is not None:
            arg_default = cpp.default_expr(a.default, a.type)
        binding = Binding(
            nctype=argument_type(a, binds=a.name),
            name=a.name,
            default=arg_default,
            argument=a,
        )
        return [binding]

    if isinstance(a, SelfArgument):
        # The self/non-self distinction is dropped: recurse on the wrapped
        # argument.
        return argument(a.argument, is_out=is_out)

    if isinstance(a, TensorOptionsArguments):
        options_default: Optional[str] = '{}' if emit_defaults else None
        # (binding name, base C++ type) for each piece, in signature order.
        option_pieces = (
            ('dtype', scalarTypeT),
            ('layout', layoutT),
            ('device', deviceT),
            ('pin_memory', boolT),
        )
        # TODO: Not sure why each binding's `argument` is the whole
        # TensorOptionsArguments rather than its constituent piece. It
        # seems to matter
        return [
            Binding(
                nctype=NamedCType(piece, OptionalCType(BaseCType(base))),
                name=piece,
                default=options_default,
                argument=a,
            )
            for piece, base in option_pieces
        ]

    assert_never(a)
def arguments(func: FunctionSchema) -> List[Binding]:
    """Return every native binding for ``func``.

    Bindings for the non-out arguments come first, followed by those for
    the out arguments; each argument is expanded through ``argument``.
    """
    ordered: List[Union[Argument, TensorOptionsArguments, SelfArgument]] = [
        *func.arguments.non_out,
        *func.arguments.out,
    ]
    bindings: List[Binding] = []
    for schema_arg in ordered:
        bindings.extend(argument(schema_arg, is_out=func.is_out_fn()))
    return bindings