[PyTorch] Fix const correctness for resize native functions (#55351)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/55351

We incorrectly used `Tensor&` to mean "the underlying TensorImpl cannot be changed", as explained in https://github.com/zdevito/ATen/issues/27#issuecomment-330717839. This diff gets us on the path to fixing this problem: we have an incremental way to fix individual native functions so that we can apply any handwritten fixes a few at a time. It gets the migration started with the `resize` family of native functions.

ghstack-source-id: 127092677

Test Plan: fitsships

Reviewed By: ezyang

Differential Revision: D27583983

fbshipit-source-id: 4eeeec85f5d268e9d0f1645eb9396914a9f9557f
This commit is contained in:
parent 5e695b1271
commit 1211bccc65
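Before the diff itself, a brief aside on the `const` semantics the summary refers to: a `Tensor` is a reference-counted handle to a `TensorImpl`, so `const Tensor&` only prevents re-pointing the handle at a different `TensorImpl`; it does not prevent in-place changes to the data that `TensorImpl` already owns. The sketch below is illustrative only, using hypothetical `Handle`/`Impl` stand-ins rather than the real ATen classes.

```cpp
#include <cstdint>
#include <memory>

// Hypothetical stand-ins for Tensor / TensorImpl, for illustration only.
struct Impl {
  std::int64_t numel = 0;
};

struct Handle {
  std::shared_ptr<Impl> impl_;

  // Allowed on a const Handle: const applies to the handle (the pointer
  // member), not to the Impl it points to, so the data can still change.
  void resize(std::int64_t n) const { impl_->numel = n; }

  // Requires a non-const Handle: this rebinds the handle to a different Impl,
  // which is the one thing in-place ops never actually need to do.
  void set_impl(std::shared_ptr<Impl> other) { impl_ = std::move(other); }
};

void in_place_op(const Handle& self) {
  self.resize(8);                              // fine: underlying data changes
  // self.set_impl(std::make_shared<Impl>());  // error: the handle is const
}

int main() {
  Handle h{std::make_shared<Impl>()};
  in_place_op(h);
  return h.impl_->numel == 8 ? 0 : 1;
}
```

This mirrors the pointer-constness analogy in the README change further down: in-place ops only need to mutate the data behind the handle, never to swap the handle itself.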
@@ -125,7 +125,7 @@ static void assert_names_equal(DimnameList a, DimnameList b) {
       ". Please rename the out tensor's dims with `Tensor.rename`.");
 }
 
-Tensor& propagate_names_if_nonempty(Tensor& result,
+const Tensor& propagate_names_if_nonempty(const Tensor& result,
     DimnameList maybe_names,
     bool validate_names) {
   propagate_names_if_nonempty(result.unsafeGetTensorImpl(), maybe_names, validate_names);
@@ -141,7 +141,7 @@ TensorImpl* propagate_names_if_nonempty(TensorImpl* result,
   return propagate_names(result, maybe_names, validate_names);
 }
 
-Tensor& propagate_names(Tensor& result, DimnameList names, bool validate_names) {
+const Tensor& propagate_names(const Tensor& result, DimnameList names, bool validate_names) {
   propagate_names(result.unsafeGetTensorImpl(), names, validate_names);
   return result;
 }
@@ -162,7 +162,7 @@ TensorImpl* propagate_names(TensorImpl* result, DimnameList names, bool validate
   return result;
 }
 
-void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
+void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs) {
   if (!result.has_names() && !src.has_names()) {
     return;
   }
@@ -190,7 +190,7 @@ void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef exclu
   propagate_names(result, outnames);
 }
 
-void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
+void propagate_names_for_reduction(const Tensor& result, const Tensor& src, IntArrayRef reduced_dims, bool keepdim) {
   if (keepdim) {
     propagate_names(result, src);
     return;
@@ -202,7 +202,7 @@ void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRe
   propagate_names_except(result, src, reduced_dims);
 }
 
-void propagate_names(Tensor& result, const Tensor& src) {
+void propagate_names(const Tensor& result, const Tensor& src) {
   propagate_names(result.unsafeGetTensorImpl(), src.unsafeGetTensorImpl());
 }
 
@@ -409,7 +409,7 @@ void check_names_for_dot(
 // rules for binary ops that expect the named dims to line up positionally
 // from the right. i.e.,
 // Tensor[H, W].expand(3, 3, 3, 3) -> Tensor[None, None, H, W]
-void propagate_names_for_expand(Tensor& result, const Tensor& self) {
+void propagate_names_for_expand(const Tensor& result, const Tensor& self) {
   if (!self.has_names()) {
     return;
   }

@@ -75,28 +75,28 @@ namespace namedinference {
 // `names` can be empty; see [NOTE] Writing name inference rules
 // If `names` is not empty, `names.size()` should equal `result.dim()`.
 // When in doubt, use this overload instead of the others.
-TORCH_API Tensor& propagate_names_if_nonempty(
-    Tensor& result,
+TORCH_API const Tensor& propagate_names_if_nonempty(
+    const Tensor& result,
     DimnameList maybe_names,
     bool validate_names = false);
 
 // Propagates `names` to `result`. Only use this if we are certain that there are
 // names to propagate (that names is not empty).
-TORCH_API Tensor& propagate_names(
-    Tensor& result,
+TORCH_API const Tensor& propagate_names(
+    const Tensor& result,
     DimnameList names,
     bool validate_names = false);
 
 // Propagates all names from src to result.
-TORCH_API void propagate_names(Tensor& result, const Tensor& src);
+TORCH_API void propagate_names(const Tensor& result, const Tensor& src);
 
 // Propagates all names except for those at the excluded_idxs.
-TORCH_API void propagate_names_except(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs);
+TORCH_API void propagate_names_except(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs);
 
 // Used for reduction ops that have a `keepdim` arg.
-TORCH_API void propagate_names_for_reduction(Tensor& result, const Tensor& src, IntArrayRef excluded_idxs, bool keepdim);
+TORCH_API void propagate_names_for_reduction(const Tensor& result, const Tensor& src, IntArrayRef excluded_idxs, bool keepdim);
 
-TORCH_API void propagate_names_for_expand(Tensor& result, const Tensor& self);
+TORCH_API void propagate_names_for_expand(const Tensor& result, const Tensor& self);
 
 TORCH_API std::vector<Dimname> compute_cat_outnames(TensorList tensors);
 

@@ -229,6 +229,33 @@ struct BoxedKernelWrapper<
   }
 };
 
+//
+// 3.5. In-process migration to make in-place ops take and return
+// const references instead.
+template <class... OtherArgs>
+struct BoxedKernelWrapper<
+    const at::Tensor&(const at::Tensor&, OtherArgs...),
+    std::enable_if_t<can_box_all<OtherArgs...>::value, void>
+> {
+  static const at::Tensor& call(
+      KernelFunction::InternalBoxedKernelFunction* boxed_kernel_func,
+      OperatorKernel* functor,
+      const OperatorHandle& opHandle,
+      DispatchKeySet dispatchKeySet,
+      const at::Tensor& outArg, OtherArgs... otherArgs
+  ) {
+    torch::jit::Stack stack = boxArgs(outArg, otherArgs...);
+    (*boxed_kernel_func)(functor, opHandle, dispatchKeySet, &stack);
+    TORCH_INTERNAL_ASSERT_DEBUG_ONLY(
+      stack.size() == 1,
+      "Boxed kernel was expected to return a single value on the stack, ",
+      "but instead returned ", stack.size(), " values."
+    );
+
+    return outArg;
+  }
+};
+
 //
 // 4. out of place ops that take a single non-const Tensor reference as their
 // final argument, and also return it.

@@ -379,6 +379,19 @@ Currently ops have this field set to True should match `MANUAL_CATCHALL` in tool
 (It can be a superset of `MANUAL_CATCHALL` but we don't have a use case for it).
 This field should only be used rarely.
 
+### `use_const_ref_for_mutable_tensors`
+
+```
+use_const_ref_for_mutable_tensors: True
+```
+
+With this flag set, we will generate arguments for Tensors whose underlying data may change as
+`const Tensor&` (or similar), just like we would for other Tensors. Previously, we generated these
+as `Tensor &`, which 1) allowed changing which `TensorImpl` the `Tensor` itself referred to and 2)
+was not necessary to allow the underlying data to change. (This was like using `T * const` when we
+wanted `const T*`.)
+
+
 ## Writing an implementation in C++
 
 Implementations of native functions go in an appropriate C++ file in the

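To make the effect of this flag concrete, here is a rough sketch of the kind of signature change it produces, modeled on the handwritten `resize_` signatures elsewhere in this commit; the function names and namespace below are hypothetical, and the exact generated headers may differ.

```cpp
#include <ATen/ATen.h>

namespace example {  // hypothetical namespace, illustration only

// Old behavior: the mutable `Tensor(a!)` argument is generated as a
// non-const reference, and the op returns one as well.
at::Tensor& resize_old_style_(
    at::Tensor& self,
    at::IntArrayRef size,
    c10::optional<at::MemoryFormat> memory_format);

// With `use_const_ref_for_mutable_tensors: True`: the same argument is taken
// and returned as a const reference; the storage it points to can still be
// resized in place, but the handle cannot be re-pointed at another TensorImpl.
const at::Tensor& resize_new_style_(
    const at::Tensor& self,
    at::IntArrayRef size,
    c10::optional<at::MemoryFormat> memory_format);

}  // namespace example
```
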
@@ -7,7 +7,7 @@
 namespace at { namespace native {
 
 // Returns true if resize is necessary
-bool resize_output_check(Tensor& output, IntArrayRef shape) {
+bool resize_output_check(const Tensor& output, IntArrayRef shape) {
   // Tests for resizing of tensors with one more elements
   if (output.sizes().equals(shape)) {
     return false;
@@ -25,7 +25,7 @@ bool resize_output_check(Tensor& output, IntArrayRef shape) {
   return true;
 }
 
-bool resize_output(Tensor& output, IntArrayRef shape) {
+bool resize_output(const Tensor& output, IntArrayRef shape) {
   if (resize_output_check(output, shape)) {
     // avoid a redispatch for cpu and cuda.
     // TODO: when resize_cuda_ is re-written to be unified with resize_,
@@ -44,7 +44,7 @@ bool resize_output(Tensor& output, IntArrayRef shape) {
 // Call the sparse implementation in SparseTensor.cpp directly.
 // A dynamic dispatch here is NOT necessary, so I didn't put
 // this function in native_functions.yaml
-Tensor& resize_as_sparse_(Tensor& self, const Tensor& src);
+const Tensor& resize_as_sparse_(const Tensor& self, const Tensor& src);
 
 // TODO(VitalyFedyunin): Move it to HTML docs.
 //
@@ -70,8 +70,8 @@ Tensor& resize_as_sparse_(Tensor& self, const Tensor& src);
 //
 // - Otherwise, output tensor will have contiguous memory layout.
 //
-Tensor& resize_as_(
-    Tensor& self,
+const Tensor& resize_as_(
+    const Tensor& self,
     const Tensor& the_template,
     c10::optional<MemoryFormat> optional_memory_format) {
   if (self.is_sparse() && the_template.is_sparse()) {
@@ -81,7 +81,7 @@ Tensor& resize_as_(
         optional_memory_format.value());
     return at::native::resize_as_sparse_(self, the_template);
   }
-  Tensor& result = self.resize_(the_template.sizes());
+  const Tensor& result = self.resize_(the_template.sizes());
   if (optional_memory_format.has_value()) {
     auto memory_format = optional_memory_format.value();
     if (memory_format == MemoryFormat::Preserve) {
@@ -93,8 +93,8 @@ Tensor& resize_as_(
   return result;
 }
 
-Tensor& resize_(
-    Tensor& self,
+const Tensor& resize_(
+    const Tensor& self,
     IntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
   if (self.has_names()) {

@@ -17,7 +17,7 @@ namespace at { namespace native {
 // needs resizing
 // NOTE: In the future the warning will become an error
 // Returns a bool saying whether or not the resize actually happened or not
-TORCH_API bool resize_output(Tensor& output, IntArrayRef shape);
+TORCH_API bool resize_output(const Tensor& output, IntArrayRef shape);
 
 // These functions are called by native::resize_ as well as (legacy) TH resize.
 // They are not in TH/THTensor.cpp because the at namespace is easier

@@ -20,8 +20,8 @@ inline int64_t storage_size_for(IntArrayRef size, IntArrayRef stride) {
   return storage_size;
 }
 
-inline Tensor& resize_named_tensor_(
-    Tensor& self,
+inline const Tensor& resize_named_tensor_(
+    const Tensor& self,
     IntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
   TORCH_INTERNAL_ASSERT(self.has_names());

@@ -7,8 +7,8 @@
 namespace at {
 namespace native {
 
-Tensor& resize_cuda_(
-    Tensor& self,
+const Tensor& resize_cuda_(
+    const Tensor& self,
     IntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
   if (self.has_names()) {

@@ -1531,6 +1531,7 @@
     QuantizedCPU, QuantizedCUDA: empty_per_channel_affine_quantized
 
 - func: resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
   variants: method
   device_guard: False
   dispatch:
@@ -4010,11 +4011,13 @@
     QuantizedCPU, QuantizedCUDA: quantized_clone
 
 - func: resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
   variants: function, method
   dispatch:
     CompositeExplicitAutograd: resize_as_
 
 - func: resize_as_sparse_(Tensor(a!) self, Tensor the_template) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
   variants: function
   dispatch:
     SparseCPU, SparseCUDA: resize_as_sparse_
@@ -4262,11 +4265,13 @@
     SparseCPU, SparseCUDA: new_with_dims_and_tensor_sparse
 
 - func: sparse_resize_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
   variants: method
   dispatch:
     SparseCPU, SparseCUDA: sparse_resize_
 
 - func: sparse_resize_and_clear_(Tensor(a!) self, int[] size, int sparse_dim, int dense_dim) -> Tensor(a!)
+  use_const_ref_for_mutable_tensors: True
   variants: method
   dispatch:
     SparseCPU, SparseCUDA: sparse_resize_and_clear_

@@ -60,8 +60,8 @@ AT_FORALL_OPERATORS(DEFINE_COMPARATOR)
 #undef AT_FORALL_OPERATORS
 #undef DEFINE_COMPARATOR
 
-Tensor& quantized_resize_cpu_(
-    Tensor& self,
+const Tensor& quantized_resize_cpu_(
+    const Tensor& self,
     IntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
   TORCH_CHECK(

@@ -141,8 +141,8 @@ bool _is_same_size_as_sparse_csr(
   return self.sizes().equals(src.sizes());
 }
 
-SparseCsrTensor& resize_as_sparse_csr_(
-    SparseCsrTensor& self,
+const SparseCsrTensor& resize_as_sparse_csr_(
+    const SparseCsrTensor& self,
     const SparseCsrTensor& src) {
   TORCH_CHECK(
       src.is_sparse_csr() && self.is_sparse_csr(),

@@ -406,8 +406,8 @@ SparseTensor clone_sparse(
  * reshaping methods
  ******************************************************************************/
 
-SparseTensor& sparse_resize_(
-    SparseTensor& self,
+const SparseTensor& sparse_resize_(
+    const SparseTensor& self,
     ArrayRef<int64_t> size,
     int64_t sparse_dim,
     int64_t dense_dim) {
@@ -415,8 +415,8 @@ SparseTensor& sparse_resize_(
   return self;
 }
 
-SparseTensor& sparse_resize_and_clear_(
-    SparseTensor& self,
+const SparseTensor& sparse_resize_and_clear_(
+    const SparseTensor& self,
     ArrayRef<int64_t> size,
     int64_t sparse_dim,
     int64_t dense_dim) {
@@ -434,7 +434,7 @@ bool _is_same_size_as_sparse(
 } // namespace
 
 // Invoked from native/Resize.cpp (no dynamic dispatch necessary)
-SparseTensor& resize_as_sparse_(SparseTensor& self, const SparseTensor& src) {
+const SparseTensor& resize_as_sparse_(const SparseTensor& self, const SparseTensor& src) {
   if (!_is_same_size_as_sparse(self, src)) {
     sparse_resize_(self, src.sizes(), src.sparse_dim(), src.dense_dim());
   }

tools/autograd/context.py (new file, 15 lines)
@@ -0,0 +1,15 @@
+from tools.codegen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI
+from tools.codegen.context import native_function_manager
+from tools.codegen.utils import T
+
+import functools
+from typing import Callable
+
+# Like tools.api.context.with_native_function, but for
+# NativeFunctionWithDifferentiabilityInfo.
+def with_native_function_with_differentiability_info(func: Callable[[NFWDI], T]) -> Callable[[NFWDI], T]:
+    @functools.wraps(func)
+    def wrapper(f: NFWDI) -> T:
+        with native_function_manager(f.func):
+            return func(f)
+    return wrapper

@@ -14,6 +14,7 @@ from tools.codegen.model import (
 from typing import List, Optional, Sequence, Tuple
 from tools.codegen.gen import FileManager
 from tools.codegen.utils import mapMaybe
+from .context import with_native_function_with_differentiability_info
 from .gen_trace_type import (
     MANUAL_AUTOGRAD, type_wrapper_name, tie_return_values, get_return_value
 )
@@ -321,6 +322,7 @@ def emit_view_body(fn: NativeFunctionWithDifferentiabilityInfo, var: str) -> Tup
 def modifies_arguments(f: NativeFunction) -> bool:
     return f.func.kind() in [SchemaKind.inplace, SchemaKind.out]
 
+@with_native_function_with_differentiability_info
 def emit_inplace_or_view_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
     f = fn.func
     inplace_view_body: List[str] = []
@@ -373,6 +375,7 @@ def gen_formals(f: NativeFunction) -> str:
                  for a in f.func.schema_order_arguments()]
     )
 
+@with_native_function_with_differentiability_info
 def inplace_or_view_method_definition(fn: NativeFunctionWithDifferentiabilityInfo) -> Optional[str]:
     f = fn.func
     if get_view_info(fn) is None and (not modifies_arguments(f) or is_foreach_op(str(f.func.name))):
@@ -384,6 +387,7 @@ def inplace_or_view_method_definition(fn: NativeFunctionWithDifferentiabilityInf
         type_definition_body=emit_inplace_or_view_body(fn),
     )
 
+@with_native_function_with_differentiability_info
 def inplace_or_view_method_registration(fn: NativeFunctionWithDifferentiabilityInfo) -> Optional[str]:
     f = fn.func
     if get_view_info(fn) is None and (not modifies_arguments(f) or is_foreach_op(str(f.func.name))):

@@ -22,6 +22,7 @@
 # which will in turn dispatch back to VariableType for its
 # differentiable subcomponents.
 #
+from .context import with_native_function_with_differentiability_info
 from .gen_trace_type import (
     MANUAL_BACKEND, MANUAL_AUTOGRAD_AND_TRACER, declare_returned_variables,
     tie_return_values, get_return_value, type_wrapper_name,
@@ -41,7 +42,7 @@ from tools.codegen.api.autograd import (
     is_differentiable)
 from tools.codegen.api import cpp
 from tools.codegen.code_template import CodeTemplate
-from tools.codegen.context import with_native_function
+from tools.codegen.context import native_function_manager, with_native_function
 from tools.codegen.gen import FileManager
 from tools.codegen.utils import mapMaybe
 from tools.codegen.model import (Argument, NativeFunction, SchemaKind,
@@ -363,16 +364,17 @@ def gen_variable_type_shard(
     filtered_fns_with_diff_infos = list(filter(use_derived, fns_with_diff_infos))
     for fn in filtered_fns_with_diff_infos:
         f = fn.func
-        name = cpp.name(f.func)
-        formals = gen_formals(f)
+        with native_function_manager(f):
+            name = cpp.name(f.func)
+            formals = gen_formals(f)
 
-        type_definitions.append(METHOD_DEFINITION.substitute(
-            return_type=cpp.returns_type(f.func.returns).cpp_type(),
-            type_wrapper_name=type_wrapper_name(f),
-            type_definition_body=emit_body(fn),
-            formals=formals,
-        ))
-        wrapper_registrations.append(gen_wrapper_registration(f))
+            type_definitions.append(METHOD_DEFINITION.substitute(
+                return_type=cpp.returns_type(f.func.returns).cpp_type(),
+                type_wrapper_name=type_wrapper_name(f),
+                type_definition_body=emit_body(fn),
+                formals=formals,
+            ))
+            wrapper_registrations.append(gen_wrapper_registration(f))
 
         # See Note [Manual Backend kernels]
         assert (name in MANUAL_BACKEND) == f.manual_kernel_registration
@@ -392,6 +394,7 @@ def gen_variable_type_shard(
         'wrapper_registrations': wrapper_registrations,
     })
 
+@with_native_function_with_differentiability_info
 def emit_body(fn: NativeFunctionWithDifferentiabilityInfo) -> List[str]:
     assert dispatch_strategy(fn) == 'use_derived'
     f = fn.func

@@ -7,6 +7,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding, ConstRefCType,
                                      OptionalCType, TupleCType, SpecialArgName, boolT, scalarT,
                                      tensorListT, dimnameListT, tensorT, voidT,
                                      BaseTypeToCppMapping, intArrayRefT, tensorOptionsT)
+from tools.codegen import local
 from typing import Optional, Sequence, Union, List, Set
 
 # This file describes the translation of JIT schema to the public C++
@@ -68,7 +69,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
 
     if isinstance(t, BaseType):
         if t.name == BaseTy.Tensor:
-            if mutable:
+            if mutable and not local.use_const_ref_for_mutable_tensors():
                 return NamedCType(binds, MutRefCType(BaseCType(tensorT)))
             else:
                 return NamedCType(binds, ConstRefCType(BaseCType(tensorT)))
@@ -78,7 +79,7 @@ def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
         raise AssertionError(f"base type should have been value type {t}")
     elif isinstance(t, OptionalType):
         if str(t.elem) == 'Tensor':
-            if mutable:
+            if mutable and not local.use_const_ref_for_mutable_tensors():
                 return NamedCType(binds, MutRefCType(BaseCType(tensorT)))  # TODO: fix this discrepancy
             else:
                 return NamedCType(binds, ConstRefCType(OptionalCType(BaseCType(tensorT))))
@@ -121,7 +122,10 @@ def returntype_type(t: Type, *, mutable: bool) -> CType:
     if isinstance(t, BaseType):
         if t.name == BaseTy.Tensor:
             if mutable:
-                return MutRefCType(BaseCType(tensorT))
+                if local.use_const_ref_for_mutable_tensors():
+                    return ConstRefCType(BaseCType(tensorT))
+                else:
+                    return MutRefCType(BaseCType(tensorT))
             else:
                 return BaseCType(tensorT)
         elif t.name == BaseTy.Scalar:

@@ -7,6 +7,7 @@ from tools.codegen.api.types import (ArgName, BaseCType, Binding,
                                      OptionalCType, tensorT, scalarT, layoutT,
                                      deviceT, boolT, scalarTypeT)
 from tools.codegen.api import cpp
+from tools.codegen import local
 
 from typing import Union, Sequence, List, Optional
 
@@ -30,7 +31,7 @@ def name(func: FunctionSchema) -> str:
 def argumenttype_type(t: Type, *, mutable: bool, binds: ArgName) -> NamedCType:
     if str(t) == 'Tensor?':
         tensor_type: OptionalCType = OptionalCType(BaseCType(tensorT))
-        if mutable:
+        if mutable and not local.use_const_ref_for_mutable_tensors():
             return NamedCType(binds, MutRefCType(tensor_type))
         else:
             return NamedCType(binds, ConstRefCType(tensor_type))

@@ -21,7 +21,7 @@ def native_function_manager(g: Union[NativeFunctionsGroup, NativeFunction]) -> I
     else:
         f = g
     with context(f'in {f.loc}:\n {f.func}'):
-        with local.parametrize():
+        with local.parametrize(use_const_ref_for_mutable_tensors=f.use_const_ref_for_mutable_tensors):
             yield
 
 # Given a function that operates on NativeFunction, wrap it into a new function

@@ -195,6 +195,10 @@ class ComputeFunction:
         if Variant.function not in f.variants and not self.is_redispatching_fn:
             return None
 
+        with native_function_manager(f):
+            return self.callImpl(f)
+
+    def callImpl(self, f: NativeFunction) -> str:
         name = cpp.name(f.func)
 
         sig_group = CppSignatureGroup.from_native_function(f, method=False, fallback_binding=f.manual_cpp_binding)

@@ -1,6 +1,6 @@
 import threading
 from contextlib import contextmanager
-from typing import Iterator
+from typing import Optional, Iterator
 
 # Simple dynamic scoping implementation. The name "parametrize" comes
 # from Racket.
@@ -15,10 +15,21 @@ from typing import Iterator
 # DON'T add a new entry here.
 
 class Locals(threading.local):
-    pass
+    use_const_ref_for_mutable_tensors: Optional[bool] = None
 
 _locals = Locals()
 
+def use_const_ref_for_mutable_tensors() -> bool:
+    assert _locals.use_const_ref_for_mutable_tensors is not None, \
+        "need to initialize local.use_const_ref_for_mutable_tensors with " \
+        "local.parametrize"
+    return _locals.use_const_ref_for_mutable_tensors
+
 @contextmanager
-def parametrize() -> Iterator[None]:
-    yield
+def parametrize(*, use_const_ref_for_mutable_tensors: bool) -> Iterator[None]:
+    old_use_const_ref_for_mutable_tensors = _locals.use_const_ref_for_mutable_tensors
+    try:
+        _locals.use_const_ref_for_mutable_tensors = use_const_ref_for_mutable_tensors
+        yield
+    finally:
+        _locals.use_const_ref_for_mutable_tensors = old_use_const_ref_for_mutable_tensors

@@ -168,6 +168,10 @@ class NativeFunction:
     # classes for expository clarity.)
     func: 'FunctionSchema'
 
+    # Whether or not to generate mutable tensor arguments like regular
+    # ones
+    use_const_ref_for_mutable_tensors: bool
+
     # Whether or not to omit automatic generation of a DeviceGuard
     device_guard: bool
 
@@ -263,6 +267,9 @@ class NativeFunction:
         assert isinstance(cpp_no_default_args_list, list)
         cpp_no_default_args = set(cpp_no_default_args_list)
 
+        use_const_ref_for_mutable_tensors = e.pop('use_const_ref_for_mutable_tensors', False)
+        assert isinstance(use_const_ref_for_mutable_tensors, bool)
+
         variants_s = e.pop('variants', 'function')
         assert isinstance(variants_s, str)
         variants: Set[Variant] = set()
@@ -340,6 +347,7 @@ class NativeFunction:
 
         return NativeFunction(
             func=func,
+            use_const_ref_for_mutable_tensors=use_const_ref_for_mutable_tensors,
             variants=variants,
             structured=structured,
             structured_delegate=structured_delegate,
@@ -741,7 +749,7 @@ class Annotation:
 
     @staticmethod
     def parse(ann: str) -> 'Annotation':
-        m = re.match(r'^([a-z])(!?)$', ann)
+        m = re.match(r'^([a-z])(!?)(!?)$', ann)
         assert m is not None, f'unrecognized alias annotation {ann}'
         alias_set = (m.group(1),)
         is_write = m.group(2) == '!'

@@ -45,8 +45,8 @@ Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking) {
   return self;
 }
 
-Tensor& resize_(
-    Tensor& self,
+const Tensor& resize_(
+    const Tensor& self,
     IntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
   if (torch::jit::tracer::isTracing()) {
@@ -62,8 +62,8 @@ Tensor& resize_(
   return self;
 }
 
-Tensor& resize_as_(
-    Tensor& self,
+const Tensor& resize_as_(
+    const Tensor& self,
     const Tensor& the_template,
     c10::optional<MemoryFormat> optional_memory_format) {
   if (torch::jit::tracer::isTracing()) {

@@ -156,9 +156,9 @@ Tensor & copy_(c10::DispatchKeySet ks, Tensor & self, const Tensor & src, bool n
   return self;
 }
 
-Tensor& resize_(
+const Tensor& resize_(
     c10::DispatchKeySet ks,
-    Tensor& self,
+    const Tensor& self,
     IntArrayRef size,
     c10::optional<MemoryFormat> optional_memory_format) {
   auto& self_ = unpack(self, "self", 0);
@@ -177,9 +177,9 @@ Tensor& resize_(
   return self;
 }
 
-Tensor& resize_as_(
+const Tensor& resize_as_(
     c10::DispatchKeySet ks,
-    Tensor& self,
+    const Tensor& self,
     const Tensor& the_template,
     c10::optional<MemoryFormat> optional_memory_format) {
   auto& self_ = unpack(self, "self", 0);

@@ -112,7 +112,7 @@ inline void rebase_history(std::vector<Variable>&& vars, std::shared_ptr<Node> g
   }
 }
 
-inline void increment_version(Tensor & t) {
+inline void increment_version(const Tensor & t) {
   impl::bump_version(t);
 }
 