[PyTorch] Fix const correctness for resize native functions (#55351)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55351

We incorrectly used `Tensor&` to mean "the underlying
TensorImpl cannot be changed", as explained in
https://github.com/zdevito/ATen/issues/27#issuecomment-330717839 .
This diff gets us on the path to fixing this problem: we have an
incremental way to fix individual native functions so that we can
apply any handwritten fixes a few at a time. It gets the migration
started with the `resize` family of native functions.
ghstack-source-id: 127092677

Test Plan: fitsships

Reviewed By: ezyang

Differential Revision: D27583983

fbshipit-source-id: 4eeeec85f5d268e9d0f1645eb9396914a9f9557f
This commit is contained in:
Scott Wolchok 2021-04-21 14:46:54 -07:00 committed by Facebook GitHub Bot
parent 5e695b1271
commit 1211bccc65
24 changed files with 160 additions and 65 deletions

View File

@ -125,7 +125,7 @@ static void assert_names_equal(DimnameList a, DimnameList b) {
". Please rename the out tensor's dims with `Tensor.rename`.");
}
Tensor& propagate_names_if_nonempty(Tensor& result,
const Tensor& propagate_names_if_nonempty(const Tensor& result,
DimnameList maybe_names,
bool validate_names) {
propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_names, validate_names);
@ -141,7 +141,7 @@ TensorImpl* propagate_names_if_nonempty(TensorImpl* result,
return propagate_names(result, maybe_names, validate_names);
}
Tensor& propagate_names(Tensor& result, DimnameList names, bool validate_names) {
const Tensor& propagate_names(const Tensor& result, DimnameList names, bool validate_names) {
propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
return result;
}
@ -162,7 +162,7 @@ TensorImpl* propagate_names(TensorImpl* result, DimnameList names, bool validate
return result;
}
void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
if (!result.has_names() && !src.has_names()) {
return;
}
@ -190,7 +190,7 @@ void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef exclu
propagate_names(result, outnames);
}
void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
void propagate_names_for_reduction(const Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
if (keepdim) {
propagate_names(result, src);
return;
@ -202,7 +202,7 @@ void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRe
propagate_names_except(result, src, reduced_dims);
}
void propagate_names(Tensor& result, const Tensor& src) {
void propagate_names(const Tensor& result, const Tensor& src) {
propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
}
@ -409,7 +409,7 @@ void check_names_for_dot(
// rules for binary ops that expect the named dims to line up positionally
// from the right. i.e.,
// Tensor[H, W].expand(3, 3, 3, 3) -> Tensor[None, None, H, W]
void propagate_names_for_expand(Tensor& result, const Tensor& self) {
void propagate_names_for_expand(const Tensor& result, const Tensor& self) {
if (!self.has_names()) {
return;
}

View File

@ -75,28 +75,28 @@ namespace namedinference {
// `names` can be empty; see [NOTE] Writing name inference rules
// If `names` is not empty, `names.size()` should equal `result.dim()`.
// When in doubt, use this overload instead of the others.
TORCH_API Tensor& propagate_names_if_nonempty(
Tensor& result,
TORCH_API const Tensor& propagate_names_if_nonempty(
const Tensor& result,
DimnameList maybe_names,
bool validate_names = false);
// Propagates `names` to `result`. Only use this if we are certain that there are
// names to propagate (that names is not empty).
TORCH_API Tensor& propagate_names(
Tensor& result,
TORCH_API const Tensor& propagate_names(
const Tensor& result,
DimnameList names,
bool validate_names = false);
// Propagates all names from src to result.
TORCH_API void propagate_names(Tensor& result, const Tensor& src);
TORCH_API void propagate_names(const Tensor& result, const Tensor& src);
// Propagates all names except for those at the excluded_idxs.
TORCH_API void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs);
TORCH_API void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs);
// Used for reduction ops that have a `keepdim` arg.
TORCH_API void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs, bool keepdim);
TORCH_API void propagate_names_for_reduction(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs, bool keepdim);
TORCH_API void propagate_names_for_expand(Tensor& result, const Tensor& self);
TORCH_API void propagate_names_for_expand(const Tensor& result, const Tensor& self);
TORCH_API std::vector<Dimname> compute_cat_outnames(TensorList tensors);

View File

@ -229,6 +229,33 @@ struct BoxedKernelWrapper<
}
};
//
// 3.5. In-process migration to make in-place ops take and return
// const references instead.
template <class... OtherArgs>
struct BoxedKernelWrapper<
const at::Tensor&(const at::Tensor&, OtherArgs...),
std::enable_if_t<can_box_all<OtherArgs...>::value, void>
> {
static const at::Tensor& call(
KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func,
OperatorKernel* functor,
const OperatorHandle& opHandle,
DispatchKeySet dispatchKeySet,
const at::Tensor& outArg, OtherArgs... otherArgs
) {
torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
(*boxed_kernel_func)(functor, opHandle, dispatchKeySet, &stack);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
stack.size() == 1,
"Boxed kernel was expected to return a single value on the stack, ",
"but instead returned ", stack.size(), " values."
);
return outArg;
}
};
//
// 4. out of place ops that take a single non-const Tensor reference as their
// final argument, and also return it.

View File

@ -379,6 +379,19 @@ Currently ops have this field set to True should match `MANUAL_CATCHALL` in tool
(It can be a superset of `MANUAL_CATCHALL` but we don't have a use case for it).
This field should only be used rarely.
### `use_const_ref_for_mutable_tensors`
```
use_const_ref_for_mutable_tensors: True
```
With this flag set, we will generate arguments for Tensors whose underlying data may change as
`const Tensor&` (or similar), just like we would for other Tensors. Previously, we generated these
as `Tensor &`, which 1) allowed changing which `TensorImpl` the `Tensor` itself referred to and 2)
was not necessary to allow the underlying data to change. (This was like using `T * const` when we
wanted `const T*`.)
## Writing an implementation in C++
Implementations of native functions go in an appropriate C++ file in the

View File

@ -7,7 +7,7 @@
namespace at { namespace native {
// Returns true if resize is necessary
bool resize_output_check(Tensor& output, IntArrayRef shape) {
bool resize_output_check(const Tensor& output, IntArrayRef shape) {
// Tests for resizing of tensors with one more elements
if (output.sizes().equals(shape)) {
return false;
@ -25,7 +25,7 @@ bool resize_output_check(Tensor& output, IntArrayRef shape) {
return true;
}
bool resize_output(Tensor& output, IntArrayRef shape) {
bool resize_output(const Tensor& output, IntArrayRef shape) {
if (resize_output_check(output, shape)) {
// avoid a redispatch for cpu and cuda.
// TODO: when resize_cuda_ is re-written to be unified with resize_,
@ -44,7 +44,7 @@ bool resize_output(Tensor& output, IntArrayRef shape) {
// Call the sparse implementation in SparseTensor.cpp directly.
// A dynamic dispatch here is NOT necessary, so I didn't put
// this function in native_functions.yaml
Tensor& resize_as_sparse_(Tensor& self, const Tensor& src);
const Tensor& resize_as_sparse_(const Tensor& self, const Tensor& src);
// TODO(VitalyFedyunin): Move it to HTML docs.
//
@ -70,8 +70,8 @@ Tensor& resize_as_sparse_(Tensor& self, const Tensor& src);
//
// - Otherwise, output tensor will have contiguous memory layout.
//
Tensor& resize_as_(
Tensor& self,
const Tensor& resize_as_(
const Tensor& self,
const Tensor& the_template,
c10::optional<MemoryFormat> optional_memory_format) {
if (self.is_sparse() && the_template.is_sparse()) {
@ -81,7 +81,7 @@ Tensor& resize_as_(
optional_memory_format.value());
return at::native::resize_as_sparse_(self, the_template);
}
Tensor& result = self.resize_(the_template.sizes());
const Tensor& result = self.resize_(the_template.sizes());
if (optional_memory_format.has_value()) {
auto memory_format = optional_memory_format.value();
if (memory_format == MemoryFormat::Preserve) {
@ -93,8 +93,8 @@ Tensor& resize_as_(
return result;
}
Tensor& resize_(
Tensor& self,
const Tensor& resize_(
const Tensor& self,
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
if (self.has_names()) {

View File

@ -17,7 +17,7 @@ namespace at { namespace native {
// needs resizing
// NOTE: In the future the warning will become an error
// Returns a bool saying whether or not the resize actually happened or not
TORCH_API bool resize_output(Tensor& output, IntArrayRef shape);
TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
// These functions are called by native::resize_ as well as (legacy) TH resize.
// They are not in TH/THTensor.cpp because the at namespace is easier

View File

@ -20,8 +20,8 @@ inline int64_t storage_size_for(IntArrayRef size, IntArrayRef stride) {
return storage_size;
}
inline Tensor& resize_named_tensor_(
Tensor& self,
inline const Tensor& resize_named_tensor_(
const Tensor& self,
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
TORCH_INTERNAL_ASSERT(self.has_names());

View File

@ -7,8 +7,8 @@
namespace at {
namespace native {
Tensor& resize_cuda_(
Tensor& self,
const Tensor& resize_cuda_(
const Tensor& self,
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
if (self.has_names()) {

View File

@ -1531,6 +1531,7 @@
QuantizedCPU, QuantizedCUDA: empty_per_channel_affine_quantized
- func: resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: method
device_guard: False
dispatch:
@ -4010,11 +4011,13 @@
QuantizedCPU, QuantizedCUDA: quantized_clone
- func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: function, method
dispatch:
CompositeExplicitAutograd: resize_as_
- func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: function
dispatch:
SparseCPU, SparseCUDA: resize_as_sparse_
@ -4262,11 +4265,13 @@
SparseCPU, SparseCUDA: new_with_dims_and_tensor_sparse
- func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_resize_
- func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
use_const_ref_for_mutable_tensors: True
variants: method
dispatch:
SparseCPU, SparseCUDA: sparse_resize_and_clear_

View File

@ -60,8 +60,8 @@ AT_FORALL_OPERATORS(DEFINE_COMPARATOR)
#undef AT_FORALL_OPERATORS
#undef DEFINE_COMPARATOR
Tensor& quantized_resize_cpu_(
Tensor& self,
const Tensor& quantized_resize_cpu_(
const Tensor& self,
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
TORCH_CHECK(

View File

@ -141,8 +141,8 @@ bool _is_same_size_as_sparse_csr(
return self.sizes().equals(src.sizes());
}
SparseCsrTensor& resize_as_sparse_csr_(
SparseCsrTensor& self,
const SparseCsrTensor& resize_as_sparse_csr_(
const SparseCsrTensor& self,
const SparseCsrTensor& src) {
TORCH_CHECK(
src.is_sparse_csr() && self.is_sparse_csr(),

View File

@ -406,8 +406,8 @@ SparseTensor clone_sparse(
* reshaping methods
******************************************************************************/
SparseTensor& sparse_resize_(
SparseTensor& self,
const SparseTensor& sparse_resize_(
const SparseTensor& self,
ArrayRef<int64_t> size,
int64_t sparse_dim,
int64_t dense_dim) {
@ -415,8 +415,8 @@ SparseTensor& sparse_resize_(
return self;
}
SparseTensor& sparse_resize_and_clear_(
SparseTensor& self,
const SparseTensor& sparse_resize_and_clear_(
const SparseTensor& self,
ArrayRef<int64_t> size,
int64_t sparse_dim,
int64_t dense_dim) {
@ -434,7 +434,7 @@ bool _is_same_size_as_sparse(
} // namespace
// Invoked from native/Resize.cpp (no dynamic dispatch necessary)
SparseTensor& resize_as_sparse_(SparseTensor& self, const SparseTensor& src) {
const SparseTensor& resize_as_sparse_(const SparseTensor& self, const SparseTensor& src) {
if (!_is_same_size_as_sparse(self, src)) {
sparse_resize_(self, src.sizes(), src.sparse_dim(), src.dense_dim());
}

15
tools/autograd/context.py Normal file
View File

@ -0,0 +1,15 @@
from tools.codegen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
from tools.codegen.context import native_function_manager
from tools.codegen.utils import T
import functools
from typing import Callable
# Like tools.api.context.with_native_function, but for
# NativeFunctionWithDifferentiabilityInfo.
def with_native_function_with_differentiability_info(func: Callable[[NFWDI], T]) -> Callable[[NFWDI], T]:
@functools.wraps(func)
def wrapper(f: NFWDI) -> T:
with native_function_manager(f.func):
return func(f)
return wrapper

View File

@ -14,6 +14,7 @@ from tools.codegen.model import (
from typing import List, Optional, Sequence, Tuple
from tools.codegen.gen import FileManager
from tools.codegen.utils import mapMaybe
from .context import with_native_function_with_differentiability_info
from .gen_trace_type import (
MANUAL_AUTOGRAD, type_wrapper_name, tie_return_values, get_return_value
)
@ -321,6 +322,7 @@ def emit_view_body(fn: NativeFunctionWithDifferentiabilityInfo, var: str) -> Tup
def modifies_arguments(f: NativeFunction) -> bool:
return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]
@with_native_function_with_differentiability_info
def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
f = fn.func
inplace_view_body: List[str] = []
@ -373,6 +375,7 @@ def gen_formals(f: NativeFunction) -> str:
for a in f.func.schema_order_arguments()]
)
@with_native_function_with_differentiability_info
def inplace_or_view_method_definition(fn: NativeFunctionWithDifferentiabilityInfo) -> Optional[str]:
f = fn.func
if get_view_info(fn) is None and (not modifies_arguments(f) or is_foreach_op(str(f.func.name))):
@ -384,6 +387,7 @@ def inplace_or_view_method_definition(fn: NativeFunctionWithDifferentiabilityInf
type_definition_body=emit_inplace_or_view_body(fn),
)
@with_native_function_with_differentiability_info
def inplace_or_view_method_registration(fn: NativeFunctionWithDifferentiabilityInfo) -> Optional[str]:
f = fn.func
if get_view_info(fn) is None and (not modifies_arguments(f) or is_foreach_op(str(f.func.name))):

View File

@ -22,6 +22,7 @@
# which will in turn dispatch back to VariableType for its
# differentiable subcomponents.
#
from .context import with_native_function_with_differentiability_info
from .gen_trace_type import (
MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER, declare_returned_variables,
tie_return_values, get_return_value, type_wrapper_name,
@ -41,7 +42,7 @@ from tools.codegen.api.autograd import (
is_differentiable)
from tools.codegen.api import cpp
from tools.codegen.code_template import CodeTemplate
from tools.codegen.context import with_native_function
from tools.codegen.context import native_function_manager, with_native_function
from tools.codegen.gen import FileManager
from tools.codegen.utils import mapMaybe
from tools.codegen.model import (Argument, NativeFunction, SchemaKind,
@ -363,16 +364,17 @@ def gen_variable_type_shard(
filtered_fns_with_diff_infos = list(filter(use_derived, fns_with_diff_infos))
for fn in filtered_fns_with_diff_infos:
f = fn.func
name = cpp.name(f.func)
formals = gen_formals(f)
with native_function_manager(f):
name = cpp.name(f.func)
formals = gen_formals(f)
type_definitions.append(METHOD_DEFINITION.substitute(
return_type=cpp.returns_type(f.func.returns).cpp_type(),
type_wrapper_name=type_wrapper_name(f),
type_definition_body=emit_body(fn),
formals=formals,
))
wrapper_registrations.append(gen_wrapper_registration(f))
type_definitions.append(METHOD_DEFINITION.substitute(
return_type=cpp.returns_type(f.func.returns).cpp_type(),
type_wrapper_name=type_wrapper_name(f),
type_definition_body=emit_body(fn),
formals=formals,
))
wrapper_registrations.append(gen_wrapper_registration(f))
# See Note [Manual Backend kernels]
assert (name in MANUAL_BACKEND) == f.manual_kernel_registration
@ -392,6 +394,7 @@ def gen_variable_type_shard(
'wrapper_registrations': wrapper_registrations,
})
@with_native_function_with_differentiability_info
def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
assert dispatch_strategy(fn) == 'use_derived'
f = fn.func

View File

@ -7,6 +7,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding, ConstRefCType,
OptionalCType, TupleCType, SpecialArgName, boolT, scalarT,
tensorListT, dimnameListT, tensorT, voidT,
BaseTypeToCppMapping, intArrayRefT, tensorOptionsT)
from tools.codegen import local
from typing import Optional, Sequence, Union, List, Set
# This file describes the translation of JIT schema to the public C++
@ -68,7 +69,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable:
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
else:
return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
@ -78,7 +79,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
raise AssertionError(f"base type should have been value type {t}")
elif isinstance(t, OptionalType):
if str(t.elem) == 'Tensor':
if mutable:
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(BaseCType(tensorT))) # TODO: fix this discrepancy
else:
return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(tensorT))))
@ -121,7 +122,10 @@ def returntype_type(t: Type, *, mutable: bool) -> CType:
if isinstance(t, BaseType):
if t.name == BaseTy.Tensor:
if mutable:
return MutRefCType(BaseCType(tensorT))
if local.use_const_ref_for_mutable_tensors():
return ConstRefCType(BaseCType(tensorT))
else:
return MutRefCType(BaseCType(tensorT))
else:
return BaseCType(tensorT)
elif t.name == BaseTy.Scalar:

View File

@ -7,6 +7,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding,
OptionalCType, tensorT, scalarT, layoutT,
deviceT, boolT, scalarTypeT)
from tools.codegen.api import cpp
from tools.codegen import local
from typing import Union, Sequence, List, Optional
@ -30,7 +31,7 @@ def name(func: FunctionSchema) -> str:
def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
if str(t) == 'Tensor?':
tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
if mutable:
if mutable and not local.use_const_ref_for_mutable_tensors():
return NamedCType(binds, MutRefCType(tensor_type))
else:
return NamedCType(binds, ConstRefCType(tensor_type))

View File

@ -21,7 +21,7 @@ def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunction]) -> I
else:
f = g
with context(f'in {f.loc}:\n {f.func}'):
with local.parametrize():
with local.parametrize(use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors):
yield
# Given a function that operates on NativeFunction, wrap it into a new function

View File

@ -195,6 +195,10 @@ class ComputeFunction:
if Variant.function not in f.variants and not self.is_redispatching_fn:
return None
with native_function_manager(f):
return self.callImpl(f)
def callImpl(self, f: NativeFunction) -> str:
name = cpp.name(f.func)
sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

View File

@ -1,6 +1,6 @@
import threading
from contextlib import contextmanager
from typing import Iterator
from typing import Optional, Iterator
# Simple dynamic scoping implementation. The name "parametrize" comes
# from Racket.
@ -15,10 +15,21 @@ from typing import Iterator
# DON'T add a new entry here.
class Locals(threading.local):
pass
use_const_ref_for_mutable_tensors: Optional[bool] = None
_locals = Locals()
def use_const_ref_for_mutable_tensors() -> bool:
assert _locals.use_const_ref_for_mutable_tensors is not None, \
"need to initialize local.use_const_ref_for_mutable_tensors with " \
"local.parametrize"
return _locals.use_const_ref_for_mutable_tensors
@contextmanager
def parametrize() -> Iterator[None]:
yield
def parametrize(*, use_const_ref_for_mutable_tensors: bool) -> Iterator[None]:
old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
try:
_locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
yield
finally:
_locals.use_const_ref_for_mutable_tensors = old_use_const_ref_for_mutable_tensors

View File

@ -168,6 +168,10 @@ class NativeFunction:
# classes for expository clarity.)
func: 'FunctionSchema'
# Whether or not to generate mutable tensor arguments like regular
# ones
use_const_ref_for_mutable_tensors: bool
# Whether or not to omit automatic generation of a DeviceGuard
device_guard: bool
@ -263,6 +267,9 @@ class NativeFunction:
assert isinstance(cpp_no_default_args_list, list)
cpp_no_default_args = set(cpp_no_default_args_list)
use_const_ref_for_mutable_tensors = e.pop('use_const_ref_for_mutable_tensors', False)
assert isinstance(use_const_ref_for_mutable_tensors, bool)
variants_s = e.pop('variants', 'function')
assert isinstance(variants_s, str)
variants: Set[Variant] = set()
@ -340,6 +347,7 @@ class NativeFunction:
return NativeFunction(
func=func,
use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
variants=variants,
structured=structured,
structured_delegate=structured_delegate,
@ -741,7 +749,7 @@ class Annotation:
@staticmethod
def parse(ann: str) -> 'Annotation':
m = re.match(r'^([a-z])(!?)$', ann)
m = re.match(r'^([a-z])(!?)(!?)$', ann)
assert m is not None, f'unrecognized alias annotation {ann}'
alias_set = (m.group(1),)
is_write = m.group(2) == '!'

View File

@ -45,8 +45,8 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
return self;
}
Tensor& resize_(
Tensor& self,
const Tensor& resize_(
const Tensor& self,
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
if (torch::jit::tracer::isTracing()) {
@ -62,8 +62,8 @@ Tensor& resize_(
return self;
}
Tensor& resize_as_(
Tensor& self,
const Tensor& resize_as_(
const Tensor& self,
const Tensor& the_template,
c10::optional<MemoryFormat> optional_memory_format) {
if (torch::jit::tracer::isTracing()) {

View File

@ -156,9 +156,9 @@ Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool n
return self;
}
Tensor& resize_(
const Tensor& resize_(
c10::DispatchKeySet ks,
Tensor& self,
const Tensor& self,
IntArrayRef size,
c10::optional<MemoryFormat> optional_memory_format) {
auto& self_ = unpack(self, "self", 0);
@ -177,9 +177,9 @@ Tensor& resize_(
return self;
}
Tensor& resize_as_(
const Tensor& resize_as_(
c10::DispatchKeySet ks,
Tensor& self,
const Tensor& self,
const Tensor& the_template,
c10::optional<MemoryFormat> optional_memory_format) {
auto& self_ = unpack(self, "self", 0);

View File

@ -112,7 +112,7 @@ inline void rebase_history(std::vector<Variable>&& vars, std::shared_ptr<Node> g
}
}
inline void increment_version(Tensor & t) {
inline void increment_version(const Tensor & t) {
impl::bump_version(t);
}