mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Make empty c10-full (#46092)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/46092 Make empty c10-full without using hacky-wrapper, i.e. port the kernel to the new style signature. This PR also changes the signature of some helpers called by empty to the new style. ghstack-source-id: 116544203 (Note: this ignores all push blocking failures!) Test Plan: vs prev diff (outdated, before c10::optional fix): https://www.internalfb.com/intern/fblearner/details/224735103/ after c10::optional fix: https://www.internalfb.com/intern/fblearner/details/231391773/ Also, after the c10::optional fix, the instruction counting benchmark shows a 2% regression for calling empty from Python. We decided this is acceptable and decided against landing D24425836 which would fix the regression. Reviewed By: ezyang Differential Revision: D24219944 fbshipit-source-id: e554096e90ce438c75b679131c3151ff8e5c5d50
This commit is contained in:
parent
3649a2c170
commit
edf751ca2f
|
|
@ -869,10 +869,13 @@ Tensor new_zeros_batching_rule(
|
|||
Tensor new_empty_batching_rule(
|
||||
const Tensor& self,
|
||||
IntArrayRef size,
|
||||
const TensorOptions& options) {
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
auto physical_view = MultiBatchVmapTransform::logicalToPhysical(self);
|
||||
auto physical_size = physical_view.getPhysicalShape(size);
|
||||
auto result = physical_view.tensor().new_empty(physical_size, options);
|
||||
auto result = physical_view.tensor().new_empty(physical_size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory));
|
||||
return physical_view.newLogicalFromPhysical(result);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -29,10 +29,11 @@ Tensor& scalar_fill(Tensor& self, Scalar value) {
|
|||
return self;
|
||||
}
|
||||
|
||||
Tensor scalar_tensor_static(Scalar s, const TensorOptions& options) {
|
||||
Tensor scalar_tensor_static(Scalar s, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
|
||||
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
||||
at::tracer::impl::NoTracerDispatchMode tracer_guard;
|
||||
at::AutoNonVariableTypeMode non_var_type_mode(true);
|
||||
auto result = at::detail::empty_cpu({}, options);
|
||||
auto result = at::detail::empty_cpu({}, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
||||
scalar_fill(result, s);
|
||||
return result;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,9 @@ namespace detail {
|
|||
// but we also want to skip compute_types which in not avoidable
|
||||
// in TensorIterator for now.
|
||||
Tensor& scalar_fill(Tensor& self, Scalar value);
|
||||
TORCH_API Tensor scalar_tensor_static(Scalar s, const TensorOptions& options);
|
||||
TORCH_API Tensor scalar_tensor_static(Scalar s, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
|
||||
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt,
|
||||
c10::optional<c10::MemoryFormat> memory_format_opt);
|
||||
} // namespace detail
|
||||
} // namespace at
|
||||
|
||||
|
|
@ -25,12 +27,12 @@ inline at::Tensor scalar_to_tensor(Scalar s, const Device device = at::kCPU) {
|
|||
// This is the fast track we have for CPU scalar tensors.
|
||||
if (device == at::kCPU && !s.isComplex()) {
|
||||
if (s.isFloatingPoint()) {
|
||||
return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kDouble));
|
||||
return at::detail::scalar_tensor_static(s, at::kDouble, c10::nullopt, at::kCPU, c10::nullopt, c10::nullopt);
|
||||
} else if (s.isBoolean()) {
|
||||
return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kBool));
|
||||
return at::detail::scalar_tensor_static(s, at::kBool, c10::nullopt, at::kCPU, c10::nullopt, c10::nullopt);
|
||||
} else {
|
||||
AT_ASSERT(s.isIntegral(false));
|
||||
return at::detail::scalar_tensor_static(s, at::device(at::kCPU).dtype(at::kLong));
|
||||
return at::detail::scalar_tensor_static(s, at::kLong, c10::nullopt, at::kCPU, c10::nullopt, c10::nullopt);
|
||||
}
|
||||
}
|
||||
if (s.isFloatingPoint()) {
|
||||
|
|
|
|||
|
|
@ -16,32 +16,24 @@ int _crash_if_asan(int arg) {
|
|||
|
||||
namespace detail {
|
||||
// empty_cpu is used in ScalarOps.h, which can be referenced by other ATen files. Since we want to decouple direct referencing native symbols and only access native symbols through dispatching, we move its implementation here.
|
||||
Tensor empty_cpu(
|
||||
IntArrayRef size,
|
||||
const TensorOptions& options,
|
||||
c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(
|
||||
!(options.has_memory_format() && optional_memory_format.has_value()),
|
||||
"Cannot set memory_format both in TensorOptions and explicit argument; please delete "
|
||||
"the redundant setter.");
|
||||
const MemoryFormat memory_format =
|
||||
optional_memory_format.value_or(
|
||||
options.memory_format_opt().value_or(
|
||||
MemoryFormat::Contiguous));
|
||||
Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
|
||||
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
||||
Device device = device_or_default(device_opt);
|
||||
|
||||
AT_ASSERT(options.device().type() == DeviceType::CPU);
|
||||
TORCH_CHECK(device.type() == DeviceType::CPU);
|
||||
check_size_nonnegative(size);
|
||||
|
||||
bool pin_memory = pinned_memory_or_default(pin_memory_opt);
|
||||
c10::Allocator* allocator;
|
||||
if (options.pinned_memory()) {
|
||||
if (pin_memory) {
|
||||
allocator = detail::getCUDAHooks().getPinnedMemoryAllocator();
|
||||
} else {
|
||||
allocator = at::getCPUAllocator();
|
||||
}
|
||||
|
||||
int64_t nelements = prod_intlist(size);
|
||||
const caffe2::TypeMeta dtype = options.dtype();
|
||||
const int64_t size_bytes = nelements * dtype.itemsize();
|
||||
caffe2::TypeMeta dtype = scalarTypeToTypeMeta(dtype_or_default(dtype_opt));
|
||||
int64_t size_bytes = nelements * dtype.itemsize();
|
||||
auto storage_impl = c10::make_intrusive<StorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
size_bytes,
|
||||
|
|
@ -56,6 +48,7 @@ Tensor empty_cpu(
|
|||
tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
|
||||
}
|
||||
|
||||
auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
|
||||
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
|
||||
|
||||
return tensor;
|
||||
|
|
|
|||
|
|
@ -136,10 +136,8 @@ inline void check_size_nonnegative(IntArrayRef size) {
|
|||
|
||||
namespace detail {
|
||||
CAFFE2_API
|
||||
Tensor empty_cpu(
|
||||
IntArrayRef size,
|
||||
const TensorOptions& options = {},
|
||||
c10::optional<MemoryFormat> memory_format = c10::nullopt);
|
||||
Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
|
||||
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt);
|
||||
} // namespace detail
|
||||
|
||||
} // at
|
||||
|
|
|
|||
|
|
@ -7,34 +7,29 @@ namespace native {
|
|||
// Will be promoted to a public API later, but not now
|
||||
Tensor empty_meta(
|
||||
IntArrayRef size,
|
||||
const TensorOptions& options_,
|
||||
c10::optional<c10::MemoryFormat> optional_memory_format
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory,
|
||||
c10::optional<c10::MemoryFormat> memory_format
|
||||
) {
|
||||
TORCH_CHECK(
|
||||
!(options_.has_memory_format() && optional_memory_format.has_value()),
|
||||
"Cannot set memory_format both in TensorOptions and explicit argument; please delete "
|
||||
"the redundant setter.");
|
||||
TensorOptions options = options_.merge_memory_format(optional_memory_format);
|
||||
|
||||
// TODO: deduplicate this logic with empty_cpu
|
||||
|
||||
auto dtype = options.dtype();
|
||||
auto device = options.device();
|
||||
auto tensor = detail::make_tensor<TensorImpl>(
|
||||
// NB: We include the computed dispatch key, not because it will actually
|
||||
// participate in dispatch, but so that tests like is_sparse/is_cuda
|
||||
// give the correct result (a CUDA meta tensor "is cuda"). If we don't
|
||||
// like this, remove the computeDispatchKey line
|
||||
DispatchKeySet{DispatchKey::Meta, options.computeDispatchKey()},
|
||||
dtype,
|
||||
DispatchKeySet{DispatchKey::Meta, computeDispatchKey(dtype, layout, device)},
|
||||
scalarTypeToTypeMeta(dtype_or_default(dtype)),
|
||||
device
|
||||
);
|
||||
if (size.size() != 1 || size[0] != 0) {
|
||||
tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
|
||||
}
|
||||
|
||||
auto memory_format = options.memory_format_opt().value_or(MemoryFormat::Contiguous);
|
||||
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
|
||||
auto memory_format_ = memory_format.value_or(MemoryFormat::Contiguous);
|
||||
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format_);
|
||||
|
||||
return tensor;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -165,8 +165,9 @@ Tensor polar(const Tensor& abs, const Tensor& angle) {
|
|||
}
|
||||
|
||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ empty ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
Tensor empty_cpu(IntArrayRef size, const TensorOptions& options_, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
return at::detail::empty_cpu(size, options_, optional_memory_format);
|
||||
Tensor empty_cpu(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt,
|
||||
c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
||||
return at::detail::empty_cpu(size, dtype_opt, layout_opt, device_opt, pin_memory_opt, memory_format_opt);
|
||||
}
|
||||
|
||||
Tensor empty(
|
||||
|
|
@ -186,9 +187,10 @@ Tensor empty(
|
|||
return result;
|
||||
}
|
||||
|
||||
Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, const TensorOptions& options) {
|
||||
Tensor empty_strided_cpu(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt,
|
||||
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||
check_size_nonnegative(size);
|
||||
auto t = at::native::empty_cpu({0}, options);
|
||||
auto t = at::native::empty_cpu({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
at::native::resize_impl_cpu_(t.unsafeGetTensorImpl(), size, stride);
|
||||
return t;
|
||||
}
|
||||
|
|
@ -336,9 +338,16 @@ Tensor empty_like(
|
|||
Tensor new_empty(
|
||||
const Tensor& self,
|
||||
IntArrayRef size,
|
||||
const TensorOptions& options
|
||||
c10::optional<ScalarType> dtype_opt,
|
||||
c10::optional<Layout> layout_opt,
|
||||
c10::optional<Device> device_opt,
|
||||
c10::optional<bool> pin_memory_opt
|
||||
) {
|
||||
return at::empty(size, self.options().merge_in(options));
|
||||
auto dtype = dtype_opt.has_value() ? dtype_opt : optTypeMetaToScalarType(self.options().dtype_opt());
|
||||
auto layout = layout_opt.has_value() ? layout_opt : self.options().layout_opt();
|
||||
auto device = device_opt.has_value() ? device_opt : self.options().device_opt();
|
||||
auto pin_memory = pin_memory_opt.has_value() ? pin_memory_opt : self.options().pinned_memory_opt();
|
||||
return at::empty(size, dtype, layout, device, pin_memory, c10::nullopt);
|
||||
}
|
||||
|
||||
Tensor new_empty_strided(
|
||||
|
|
@ -507,7 +516,7 @@ Tensor scalar_tensor(Scalar s, const TensorOptions& options) {
|
|||
// auto result = at::empty({}, options);
|
||||
at::tracer::impl::NoTracerDispatchMode tracer_guard;
|
||||
at::AutoNonVariableTypeMode non_var_type_mode(true);
|
||||
auto result = empty_cpu({}, options);
|
||||
auto result = empty_cpu({}, optTypeMetaToScalarType(options.dtype_opt()), options.layout_opt(), options.device_opt(), options.pinned_memory_opt());
|
||||
at::native::fill_(result, s);
|
||||
return result;
|
||||
}
|
||||
|
|
@ -735,13 +744,14 @@ Tensor range(
|
|||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ triangle ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||
|
||||
Tensor tril_indices_cpu(
|
||||
int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
|
||||
check_args(row, col, options);
|
||||
int64_t row, int64_t col, int64_t offset, c10::optional<ScalarType> dtype_opt,
|
||||
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||
check_args(row, col, layout_opt);
|
||||
|
||||
auto tril_size = get_tril_size(row, col, offset);
|
||||
|
||||
// create an empty Tensor with correct size
|
||||
auto result = at::empty({2, tril_size}, options);
|
||||
auto result = at::native::empty_cpu({2, tril_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
|
||||
// The following three approaches result in very little performance
|
||||
// differences. Hence, the 2nd option is taken for simpler code, and to return
|
||||
|
|
@ -780,13 +790,14 @@ Tensor tril_indices_cpu(
|
|||
}
|
||||
|
||||
Tensor triu_indices_cpu(
|
||||
int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
|
||||
check_args(row, col, options);
|
||||
int64_t row, int64_t col, int64_t offset, c10::optional<ScalarType> dtype_opt,
|
||||
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||
check_args(row, col, layout_opt);
|
||||
|
||||
auto triu_size = row * col - get_tril_size(row, col, offset - 1);
|
||||
|
||||
// create an empty Tensor with correct size
|
||||
auto result = at::empty({2, triu_size}, options);
|
||||
auto result = at::native::empty_cpu({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
|
||||
AT_DISPATCH_ALL_TYPES(result.scalar_type(), "triu_indices", [&]() -> void {
|
||||
// fill the Tensor with correct values
|
||||
|
|
|
|||
|
|
@ -50,14 +50,14 @@ inline int64_t get_tril_size(int64_t row, int64_t col, int64_t offset) {
|
|||
}
|
||||
|
||||
inline void check_args(
|
||||
int64_t row, int64_t col, const TensorOptions& options) {
|
||||
int64_t row, int64_t col, c10::optional<Layout> layout_opt) {
|
||||
TORCH_CHECK(row >= 0, "row must be non-negative, got", row);
|
||||
TORCH_CHECK(col >= 0, "col must be non-negative, got", col);
|
||||
if (options.has_layout()) {
|
||||
if (layout_opt.has_value()) {
|
||||
TORCH_CHECK(
|
||||
options.layout() == at::kStrided,
|
||||
*layout_opt == at::kStrided,
|
||||
"only support layout=torch.strided, got",
|
||||
options.layout())
|
||||
*layout_opt)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -882,7 +882,8 @@ void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
|
|||
//However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced.
|
||||
bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous();
|
||||
at::Tensor out_temp = need_to_copy ?
|
||||
at::native::empty_cuda({self.dim(), num_nonzeros_h}, out.options()) :
|
||||
at::native::empty_cuda({self.dim(), num_nonzeros_h}, optTypeMetaToScalarType(out.options().dtype_opt()),
|
||||
out.options().layout_opt(), out.options().device_opt(), out.options().pinned_memory_opt()) :
|
||||
out.resize_({self.dim(), num_nonzeros_h});
|
||||
//Scalars are expected to produce output of size (1,0), so we can't write to it
|
||||
if (self.dim() > 0) {
|
||||
|
|
@ -931,7 +932,7 @@ Tensor& nonzero_out_cuda(Tensor& out, const Tensor& self){
|
|||
}
|
||||
|
||||
Tensor nonzero_cuda(const Tensor& self){
|
||||
Tensor out = at::native::empty_cuda({0}, self.options().dtype(kLong));
|
||||
Tensor out = at::native::empty_cuda({0}, kLong, self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt());
|
||||
return nonzero_out_cuda(out, self);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -322,7 +322,9 @@ void multinomial_kernel_impl(Tensor& result, const Tensor& self, const int64_t n
|
|||
// To exploit greater parallelism for the sampling, generate the
|
||||
// Uniform random samples in a separate kernel launch, into
|
||||
// temporarily allocated memory. The device RNG is thread-limited
|
||||
Tensor sampled = native::empty_cuda({numDist, n_sample}, self_v.options());
|
||||
Tensor sampled = native::empty_cuda({numDist, n_sample}, optTypeMetaToScalarType(self_v.options().dtype_opt()),
|
||||
self_v.options().layout_opt(), self_v.options().device_opt(),
|
||||
self_v.options().pinned_memory_opt());
|
||||
at::native::uniform_(sampled, 0.0, 1.0, generator);
|
||||
|
||||
dim3 block(numCategories < maxThreads ? numCategories : maxThreads);
|
||||
|
|
|
|||
|
|
@ -497,7 +497,8 @@ void scan_cub(const Tensor& self, Tensor& result, scalar_t init, BinaryFunction
|
|||
at::cuda::getCurrentCUDAStream()));
|
||||
auto temp_storage = at::native::empty_cuda(
|
||||
{static_cast<int64_t>(temp_storage_bytes)},
|
||||
self.options().dtype(kByte));
|
||||
kByte, self.options().layout_opt(), self.options().device_opt(),
|
||||
self.options().pinned_memory_opt());
|
||||
AT_CUDA_CHECK(cub::DeviceScan::InclusiveScan(
|
||||
temp_storage.data_ptr(),
|
||||
temp_storage_bytes,
|
||||
|
|
|
|||
|
|
@ -41,15 +41,16 @@ Tensor& eye_out_cuda(Tensor& result, int64_t n, int64_t m) {
|
|||
return result;
|
||||
}
|
||||
|
||||
Tensor empty_cuda(IntArrayRef size, const TensorOptions& options, c10::optional<MemoryFormat> optional_memory_format) {
|
||||
AT_ASSERT(options.device().type() == at::DeviceType::CUDA);
|
||||
TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned");
|
||||
Tensor empty_cuda(IntArrayRef size, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt, c10::optional<c10::MemoryFormat> memory_format_opt) {
|
||||
AT_ASSERT(device_or_default(device_opt).type() == at::DeviceType::CUDA);
|
||||
TORCH_CHECK(!pin_memory_opt.has_value() || !*pin_memory_opt, "Only dense CPU tensors can be pinned");
|
||||
check_size_nonnegative(size);
|
||||
|
||||
auto* allocator = at::cuda::getCUDADeviceAllocator();
|
||||
int64_t nelements = prod_intlist(size);
|
||||
auto dtype = options.dtype();
|
||||
int64_t size_bytes = nelements * dtype.itemsize();
|
||||
auto dtype = dtype_or_default(dtype_opt);
|
||||
auto dtype_meta = scalarTypeToTypeMeta(dtype);
|
||||
int64_t size_bytes = nelements * dtype_meta.itemsize();
|
||||
auto storage_impl = c10::make_intrusive<StorageImpl>(
|
||||
c10::StorageImpl::use_byte_size_t(),
|
||||
size_bytes,
|
||||
|
|
@ -58,23 +59,19 @@ Tensor empty_cuda(IntArrayRef size, const TensorOptions& options, c10::optional<
|
|||
/*resizeable=*/true);
|
||||
|
||||
auto tensor =
|
||||
detail::make_tensor<TensorImpl>(storage_impl, DispatchKey::CUDA, dtype);
|
||||
detail::make_tensor<TensorImpl>(storage_impl, DispatchKey::CUDA, dtype_meta);
|
||||
// Default TensorImpl has size [0]
|
||||
if (size.size() != 1 || size[0] != 0) {
|
||||
tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size);
|
||||
}
|
||||
|
||||
TORCH_CHECK(
|
||||
!(options.has_memory_format() && optional_memory_format.has_value()),
|
||||
"Cannot set memory_format both in TensorOptions and explicit argument; please delete "
|
||||
"the redundant setter.");
|
||||
auto memory_format = options.memory_format_opt().value_or(optional_memory_format.value_or(MemoryFormat::Contiguous));
|
||||
auto memory_format = memory_format_opt.value_or(MemoryFormat::Contiguous);
|
||||
tensor.unsafeGetTensorImpl()->empty_tensor_restride(memory_format);
|
||||
return tensor;
|
||||
}
|
||||
|
||||
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, const TensorOptions& options) {
|
||||
auto t = at::native::empty_cuda({0}, options);
|
||||
Tensor empty_strided_cuda(IntArrayRef size, IntArrayRef stride, c10::optional<ScalarType> dtype_opt, c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||
auto t = at::native::empty_cuda({0}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
at::native::resize_impl_cuda_(t.unsafeGetTensorImpl(), size, stride);
|
||||
return t;
|
||||
}
|
||||
|
|
@ -325,11 +322,12 @@ void tril_indices_kernel(scalar_t * tensor,
|
|||
// implementation, please enable them in test/test_cuda.py and make sure they
|
||||
// pass on your local server.
|
||||
Tensor tril_indices_cuda(
|
||||
int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
|
||||
check_args(row, col, options);
|
||||
int64_t row, int64_t col, int64_t offset, c10::optional<ScalarType> dtype_opt,
|
||||
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||
check_args(row, col, layout_opt);
|
||||
|
||||
auto tril_size = get_tril_size(row, col, offset);
|
||||
auto tensor = empty_cuda({2, tril_size}, options);
|
||||
auto tensor = empty_cuda({2, tril_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
|
||||
if (tril_size > 0) {
|
||||
auto m_first_row = offset > 0 ?
|
||||
|
|
@ -399,11 +397,12 @@ void triu_indices_kernel(scalar_t * tensor,
|
|||
// implementation, please enable them in test/test_cuda.py and make sure they
|
||||
// pass on your local server.
|
||||
Tensor triu_indices_cuda(
|
||||
int64_t row, int64_t col, int64_t offset, const TensorOptions& options) {
|
||||
check_args(row, col, options);
|
||||
int64_t row, int64_t col, int64_t offset, c10::optional<ScalarType> dtype_opt,
|
||||
c10::optional<Layout> layout_opt, c10::optional<Device> device_opt, c10::optional<bool> pin_memory_opt) {
|
||||
check_args(row, col, layout_opt);
|
||||
|
||||
auto triu_size = row * col - get_tril_size(row, col, offset - 1);
|
||||
auto tensor = empty_cuda({2, triu_size}, options);
|
||||
auto tensor = empty_cuda({2, triu_size}, dtype_opt, layout_opt, device_opt, pin_memory_opt);
|
||||
|
||||
if (triu_size > 0) {
|
||||
// # of triu elements in the first row
|
||||
|
|
|
|||
|
|
@ -70,17 +70,20 @@ at::Tensor& metal_copy_impl_(at::Tensor& dst, const at::Tensor& src) {
|
|||
|
||||
Tensor empty(
|
||||
IntArrayRef size,
|
||||
const TensorOptions& options,
|
||||
optional<ScalarType> dtype,
|
||||
optional<Layout> layout,
|
||||
optional<Device> device,
|
||||
optional<bool> pin_memory,
|
||||
c10::optional<MemoryFormat> memory_format) {
|
||||
TORCH_CHECK(
|
||||
!options.has_pinned_memory(),
|
||||
!pin_memory.has_value(),
|
||||
"'pin_memory' argument is incompatible with Metal tensor");
|
||||
TORCH_CHECK(
|
||||
!options.has_memory_format() && !memory_format,
|
||||
!memory_format.has_value(),
|
||||
"'memory_format' argument is incompatible with Metal tensor");
|
||||
MetalTensor mt{size.vec()};
|
||||
return MetalTensor::toTensor(
|
||||
std::move(mt), at::device(at::kMetal).dtype(options.dtype()));
|
||||
std::move(mt), at::device(at::kMetal).dtype(dtype));
|
||||
};
|
||||
|
||||
at::Tensor empty_strided(
|
||||
|
|
@ -249,7 +252,7 @@ TORCH_LIBRARY_IMPL(aten, Metal, m) {
|
|||
m.impl("add.Tensor", TORCH_FN(add_Tensor));
|
||||
m.impl("add_.Tensor", TORCH_FN(add__Tensor));
|
||||
m.impl("addmm", TORCH_FN(addmm));
|
||||
m.impl_UNBOXED("empty.memory_format", empty);
|
||||
m.impl("empty.memory_format", empty);
|
||||
m.impl("empty_strided", TORCH_FN(empty_strided));
|
||||
m.impl("log_softmax.int", TORCH_FN(log_softmax_int));
|
||||
m.impl("max_pool2d", TORCH_FN(max_pool2d));
|
||||
|
|
|
|||
|
|
@ -68,7 +68,8 @@ Tensor mkldnn_add(const Tensor& self, const Tensor& other, Scalar alpha) {
|
|||
const std::vector<float> scales{1.0, alpha.to<float>()};
|
||||
ideep::sum::compute(scales, {x, y}, z);
|
||||
|
||||
return new_with_itensor_mkldnn(std::move(z), self.options());
|
||||
return new_with_itensor_mkldnn(std::move(z), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().device_opt());
|
||||
}
|
||||
|
||||
Tensor& mkldnn_add_(Tensor& self, const Tensor& other, Scalar alpha) {
|
||||
|
|
@ -99,7 +100,9 @@ Tensor& mkldnn_mul_out(Tensor& result, const Tensor& self, const Tensor& other)
|
|||
}
|
||||
|
||||
Tensor mkldnn_mul(const Tensor& self, const Tensor& other) {
|
||||
Tensor result = empty_mkldnn(self.sizes(), self.options());
|
||||
Tensor result = empty_mkldnn(self.sizes(), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().layout_opt(), self.options().device_opt(),
|
||||
self.options().pinned_memory_opt());
|
||||
return native::mkldnn_mul_out(result, self, other);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -123,10 +123,12 @@ Tensor mkldnn_convolution(
|
|||
groups);
|
||||
|
||||
if (input.is_mkldnn()) {
|
||||
return new_with_itensor_mkldnn(std::move(mkldnn_output), input.options());
|
||||
return new_with_itensor_mkldnn(std::move(mkldnn_output), optTypeMetaToScalarType(input.options().dtype_opt()),
|
||||
input.options().device_opt());
|
||||
} else {
|
||||
return mkldnn_to_dense(
|
||||
new_with_itensor_mkldnn(std::move(mkldnn_output), input.options()));
|
||||
new_with_itensor_mkldnn(std::move(mkldnn_output), optTypeMetaToScalarType(input.options().dtype_opt()),
|
||||
input.options().device_opt()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -150,7 +152,8 @@ Tensor mkldnn_convolution_backward_input(
|
|||
groups);
|
||||
|
||||
return mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_input),
|
||||
grad_output.options()));
|
||||
optTypeMetaToScalarType(grad_output.options().dtype_opt()),
|
||||
grad_output.options().device_opt()));
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
|
||||
|
|
@ -188,9 +191,11 @@ std::tuple<Tensor, Tensor> mkldnn_convolution_backward_weights(
|
|||
|
||||
return std::make_tuple(
|
||||
mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_weight),
|
||||
grad_output.options())),
|
||||
optTypeMetaToScalarType(grad_output.options().dtype_opt()),
|
||||
grad_output.options().device_opt())),
|
||||
mkldnn_to_dense(new_with_itensor_mkldnn(std::move(mkldnn_grad_bias),
|
||||
grad_output.options())));
|
||||
optTypeMetaToScalarType(grad_output.options().dtype_opt()),
|
||||
grad_output.options().device_opt())));
|
||||
}
|
||||
|
||||
std::tuple<Tensor, Tensor, Tensor> mkldnn_convolution_backward(
|
||||
|
|
|
|||
|
|
@ -54,9 +54,11 @@ Tensor mkldnn_linear(
|
|||
output_size.push_back(weight.size(0));
|
||||
|
||||
if (self.dim() > 2) {
|
||||
return new_with_itensor_mkldnn(std::move(y), self.options()).reshape(output_size);
|
||||
return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().device_opt()).reshape(output_size);
|
||||
}
|
||||
return new_with_itensor_mkldnn(std::move(y), self.options());
|
||||
return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().device_opt());
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
|
|
|
|||
|
|
@ -40,14 +40,16 @@ using IDeepTensorWrapperPtr = c10::intrusive_ptr<IDeepTensorWrapper>;
|
|||
using MKLDNNTensorImpl = OpaqueTensorImpl<IDeepTensorWrapperPtr>;
|
||||
using MKLDNNTensor = Tensor;
|
||||
|
||||
Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options) {
|
||||
Tensor new_with_itensor_mkldnn(ideep::tensor&& it, c10::optional<ScalarType> dtype, c10::optional<Device> device) {
|
||||
// NOTE: int32_t dims from ideep::tensor but sizes needs int64_t
|
||||
// TODO: support int64_t dims in ideep::tensor to avoid extra conversion
|
||||
auto dims = it.get_dims();
|
||||
IDeepTensorWrapperPtr handle = c10::make_intrusive<IDeepTensorWrapper>(std::move(it));
|
||||
caffe2::TypeMeta dtype_ = scalarTypeToTypeMeta(dtype_or_default(dtype));
|
||||
Device device_ = device_or_default(device);
|
||||
return detail::make_tensor<MKLDNNTensorImpl>(
|
||||
DispatchKeySet(DispatchKey::MkldnnCPU),
|
||||
options.dtype(), options.device(), handle,
|
||||
dtype_, device_, handle,
|
||||
std::vector<int64_t>(dims.begin(), dims.end()));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
namespace at { namespace native {
|
||||
|
||||
// Construct aten MKL-DNN tensor given an ideep tensor
|
||||
Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options);
|
||||
Tensor new_with_itensor_mkldnn(ideep::tensor&& it, c10::optional<ScalarType> dtype, c10::optional<Device> device);
|
||||
|
||||
// Retrieve `ideep::tensor` from MKL-DNN tensor
|
||||
ideep::tensor& itensor_from_mkldnn(const Tensor& mkldnn_tensor);
|
||||
|
|
|
|||
|
|
@ -32,7 +32,9 @@ Tensor dense_to_mkldnn(const Tensor& cpu_tensor) {
|
|||
"Can't convert cpu tensor with the number of dimensions > 5");
|
||||
// TODO: consider to convert non-contiguous tensor to `ideep::tensor` directly.
|
||||
auto cpu_tensor_cont = cpu_tensor.contiguous();
|
||||
Tensor mkldnn_tensor = empty_mkldnn(cpu_tensor_cont.sizes(), cpu_tensor_cont.options());
|
||||
Tensor mkldnn_tensor = empty_mkldnn(cpu_tensor_cont.sizes(), optTypeMetaToScalarType(cpu_tensor_cont.options().dtype_opt()),
|
||||
cpu_tensor_cont.options().layout_opt(), cpu_tensor_cont.options().device_opt(),
|
||||
cpu_tensor_cont.options().pinned_memory_opt());
|
||||
ideep::tensor& dtensor = itensor_from_mkldnn(mkldnn_tensor);
|
||||
dtensor.feed_from(dtensor.get_dims(),
|
||||
ideep::tensor::data_type::f32,
|
||||
|
|
@ -79,7 +81,8 @@ Tensor mkldnn_reorder_conv2d_weight(
|
|||
result.init(desc);
|
||||
result.feed_from(w);
|
||||
|
||||
return new_with_itensor_mkldnn(std::move(result), self.options());
|
||||
return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().device_opt());
|
||||
}
|
||||
|
||||
Tensor mkldnn_reorder_conv3d_weight(
|
||||
|
|
@ -105,7 +108,7 @@ Tensor mkldnn_reorder_conv3d_weight(
|
|||
result.init(desc);
|
||||
result.feed_from(w);
|
||||
|
||||
return new_with_itensor_mkldnn(std::move(result), self.options());
|
||||
return new_with_itensor_mkldnn(std::move(result), optTypeMetaToScalarType(self.options().dtype_opt()), self.options().device_opt());
|
||||
}
|
||||
|
||||
#else
|
||||
|
|
|
|||
|
|
@ -56,18 +56,24 @@ std::tuple<Tensor, Tensor, Tensor> mkldnn_batch_norm(
|
|||
// ideep::batch_normalization_forward_training::compute<AllocForMKLDNN>(
|
||||
// x, w, b, y, saved_mean, saved_var, m, v, momentum, eps);
|
||||
// return std::make_tuple(
|
||||
// new_with_itensor_mkldnn(std::move(y), input.options()),
|
||||
// new_with_itensor_mkldnn(std::move(saved_mean), input.options()),
|
||||
// new_with_itensor_mkldnn(std::move(saved_var), input.options()));
|
||||
// new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
|
||||
// input.options().device_opt()),
|
||||
// new_with_itensor_mkldnn(std::move(saved_mean), optTypeMetaToScalarType(input.options().dtype_opt()),
|
||||
// input.options().device_opt()),
|
||||
// new_with_itensor_mkldnn(std::move(saved_var), optTypeMetaToScalarType(input.options().dtype_opt()),
|
||||
// input.options().device_opt()));
|
||||
} else {
|
||||
TORCH_CHECK(input.dim() == 4 || input.dim() == 5,
|
||||
"mkldnn_batch_norm: currently mkldnn only support 2d and 3d batchnorm");
|
||||
ideep::batch_normalization_forward_inference::compute(
|
||||
x, m, v, w, b, y, eps);
|
||||
return std::make_tuple(
|
||||
new_with_itensor_mkldnn(std::move(y), input.options()),
|
||||
new_with_itensor_mkldnn(ideep::tensor{}, input.options()),
|
||||
new_with_itensor_mkldnn(ideep::tensor{}, input.options()));
|
||||
new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
|
||||
input.options().device_opt()),
|
||||
new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(input.options().dtype_opt()),
|
||||
input.options().device_opt()),
|
||||
new_with_itensor_mkldnn(ideep::tensor{}, optTypeMetaToScalarType(input.options().dtype_opt()),
|
||||
input.options().device_opt()));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ static Tensor _mkldnn_pooling(
|
|||
algo,
|
||||
ideep::prop_kind::forward);
|
||||
|
||||
return new_with_itensor_mkldnn(std::move(y), input.options());
|
||||
return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()), input.options().device_opt());
|
||||
}
|
||||
|
||||
Tensor mkldnn_max_pool2d(
|
||||
|
|
|
|||
|
|
@ -28,7 +28,8 @@ Tensor mkldnn_relu(const Tensor& input) {
|
|||
ideep::tensor y;
|
||||
ideep::eltwise_forward::compute(
|
||||
x, y, ideep::algorithm::eltwise_relu, ideep::prop_kind::forward_training, /*alpha*/ 0.0);
|
||||
return new_with_itensor_mkldnn(std::move(y), input.options());
|
||||
return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(input.options().dtype_opt()),
|
||||
input.options().device_opt());
|
||||
}
|
||||
|
||||
Tensor& mkldnn_relu_(Tensor& input) {
|
||||
|
|
|
|||
|
|
@ -35,7 +35,8 @@ Tensor mkldnn_softmax(
|
|||
ideep::tensor& x = itensor_from_mkldnn(self);
|
||||
ideep::tensor y;
|
||||
ideep::softmax_forward::compute(x, y, wrapped_dim);
|
||||
return new_with_itensor_mkldnn(std::move(y), self.options());
|
||||
return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().device_opt());
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
|
|
|
|||
|
|
@ -4,10 +4,7 @@ namespace at { namespace native {
|
|||
|
||||
#if AT_MKLDNN_ENABLED()
|
||||
|
||||
Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(
|
||||
!options.has_memory_format(),
|
||||
"'memory_format' argument is incompatible with mkldnn tensor");
|
||||
Tensor empty_mkldnn(IntArrayRef sizes, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(
|
||||
!optional_memory_format.has_value(),
|
||||
"'memory_format' argument is incompatible with mkldnn tensor");
|
||||
|
|
@ -15,12 +12,12 @@ Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::option
|
|||
// TODO: support int64_t dims in ideep::tensor to avoid extra conversion
|
||||
ideep::tensor::dims dst_dims (sizes.begin(), sizes.end());
|
||||
ideep::tensor it {dst_dims, ideep::tensor::data_type::f32};
|
||||
return new_with_itensor_mkldnn(std::move(it), options);
|
||||
return new_with_itensor_mkldnn(std::move(it), dtype, device);
|
||||
}
|
||||
|
||||
#else
|
||||
|
||||
Tensor empty_mkldnn(IntArrayRef sizes, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
Tensor empty_mkldnn(IntArrayRef sizes, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(false, "empty_mkldnn: MKL-DNN build is disabled");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,8 @@ Tensor mkldnn_reshape(const Tensor& self, IntArrayRef size) {
|
|||
const ideep::tensor& x = itensor_from_mkldnn(self);
|
||||
ideep::tensor y{x};
|
||||
y.reshape(inferred_size);
|
||||
return new_with_itensor_mkldnn(std::move(y), self.options());
|
||||
return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().device_opt());
|
||||
}
|
||||
|
||||
Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
|
|
@ -62,7 +63,8 @@ Tensor mkldnn_clone(const Tensor& self, c10::optional<c10::MemoryFormat> optiona
|
|||
ideep::tensor& src = itensor_from_mkldnn(self);
|
||||
ideep::tensor dst;
|
||||
ideep::direct_copy::compute(src, dst);
|
||||
return new_with_itensor_mkldnn(std::move(dst), self.options());
|
||||
return new_with_itensor_mkldnn(std::move(dst), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().device_opt());
|
||||
}
|
||||
|
||||
Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
|
||||
|
|
@ -72,7 +74,8 @@ Tensor mkldnn_transpose(const Tensor& self, int64_t dim0, int64_t dim1) {
|
|||
std::iota(axes.begin(), axes.end(), 0);
|
||||
std::swap(axes[dim0], axes[dim1]);
|
||||
y.transpose_from(x, axes);
|
||||
return new_with_itensor_mkldnn(std::move(y), self.options());
|
||||
return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().device_opt());
|
||||
}
|
||||
|
||||
Tensor& mkldnn_transpose_(Tensor& self, int64_t dim0, int64_t dim1) {
|
||||
|
|
|
|||
|
|
@ -30,7 +30,8 @@ Tensor mkldnn_sigmoid(const Tensor& self) {
|
|||
ideep::tensor y;
|
||||
ideep::eltwise_forward::compute(
|
||||
x, y, ideep::algorithm::eltwise_logistic, ideep::prop_kind::forward);
|
||||
return new_with_itensor_mkldnn(std::move(y), self.options());
|
||||
return new_with_itensor_mkldnn(std::move(y), optTypeMetaToScalarType(self.options().dtype_opt()),
|
||||
self.options().device_opt());
|
||||
}
|
||||
|
||||
Tensor& mkldnn_sigmoid_(Tensor& self) {
|
||||
|
|
|
|||
|
|
@ -1612,13 +1612,13 @@
|
|||
CUDA: _embedding_bag_per_sample_weights_backward_cuda
|
||||
|
||||
- func: empty_meta(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
|
||||
#use_c10_dispatcher: full
|
||||
use_c10_dispatcher: full
|
||||
|
||||
- func: empty.names(int[] size, *, Dimname[]? names, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
|
||||
device_guard: False
|
||||
|
||||
- func: empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor
|
||||
#use_c10_dispatcher: full
|
||||
use_c10_dispatcher: full
|
||||
dispatch:
|
||||
CPU: empty_cpu
|
||||
CUDA: empty_cuda
|
||||
|
|
@ -1626,7 +1626,7 @@
|
|||
SparseCPU, SparseCUDA: empty_sparse
|
||||
|
||||
- func: new_empty(Tensor self, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
#use_c10_dispatcher: full
|
||||
use_c10_dispatcher: full
|
||||
variants: method
|
||||
|
||||
- func: new_empty_strided(Tensor self, int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
|
|
@ -1679,7 +1679,7 @@
|
|||
device_guard: False
|
||||
|
||||
- func: empty_strided(int[] size, int[] stride, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
|
||||
use_c10_dispatcher: full
|
||||
dispatch:
|
||||
CPU: empty_strided_cpu
|
||||
CUDA: empty_strided_cuda
|
||||
|
|
@ -4595,12 +4595,12 @@
|
|||
use_c10_dispatcher: full
|
||||
|
||||
- func: _sparse_coo_tensor_with_dims(int sparse_dim, int dense_dim, int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
|
||||
use_c10_dispatcher: full
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: new_with_dims_sparse
|
||||
|
||||
- func: _sparse_coo_tensor_with_dims_and_tensors(int sparse_dim, int dense_dim, int[] size, Tensor indices, Tensor values, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=False) -> Tensor
|
||||
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
|
||||
use_c10_dispatcher: full
|
||||
dispatch:
|
||||
SparseCPU, SparseCUDA: new_with_dims_and_tensor_sparse
|
||||
|
||||
|
|
@ -5680,13 +5680,13 @@
|
|||
DefaultBackend: tril
|
||||
|
||||
- func: tril_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
|
||||
use_c10_dispatcher: full
|
||||
dispatch:
|
||||
CPU: tril_indices_cpu
|
||||
CUDA: tril_indices_cuda
|
||||
|
||||
- func: triu_indices(int row, int col, int offset=0, *, ScalarType? dtype=long, Layout? layout=None, Device? device=None, bool? pin_memory=None) -> Tensor
|
||||
use_c10_dispatcher: hacky_wrapper_for_legacy_signatures
|
||||
use_c10_dispatcher: full
|
||||
dispatch:
|
||||
CPU: triu_indices_cpu
|
||||
CUDA: triu_indices_cuda
|
||||
|
|
|
|||
|
|
@ -70,22 +70,23 @@ Tensor values_sparse(const Tensor& self) {
|
|||
|
||||
/*** Helper methods ***/
|
||||
|
||||
SparseTensor new_sparse(const TensorOptions& options) {
|
||||
AT_ASSERT(options.layout() == kSparse);
|
||||
SparseTensor new_sparse(c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory) {
|
||||
AT_ASSERT(layout.has_value() && *layout == kSparse);
|
||||
DispatchKey dispatch_key;
|
||||
if (options.device().is_cuda()) {
|
||||
if (device_or_default(device).is_cuda()) {
|
||||
dispatch_key = DispatchKey::SparseCUDA;
|
||||
} else {
|
||||
dispatch_key = DispatchKey::SparseCPU;
|
||||
}
|
||||
return detail::make_tensor<SparseTensorImpl>(
|
||||
DispatchKeySet(dispatch_key), options.dtype());
|
||||
DispatchKeySet(dispatch_key), scalarTypeToTypeMeta(dtype_or_default(dtype)));
|
||||
}
|
||||
|
||||
/** Actual dispatched creation methods ***/
|
||||
|
||||
SparseTensor new_with_dims_sparse(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size, const TensorOptions& options) {
|
||||
SparseTensor self = new_sparse(options);
|
||||
SparseTensor new_with_dims_sparse(int64_t sparse_dim, int64_t dense_dim, ArrayRef<int64_t> size, c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory) {
|
||||
SparseTensor self = new_sparse(dtype, layout, device, pin_memory);
|
||||
get_sparse_impl(self)->resize_and_clear_(sparse_dim, dense_dim, size);
|
||||
return self;
|
||||
}
|
||||
|
|
@ -96,8 +97,11 @@ SparseTensor new_with_dims_and_tensor_sparse(
|
|||
ArrayRef<int64_t> size,
|
||||
const LongTensor& indices,
|
||||
const Tensor& values,
|
||||
const TensorOptions& options) {
|
||||
SparseTensor self = new_sparse(options);
|
||||
c10::optional<ScalarType> dtype,
|
||||
c10::optional<Layout> layout,
|
||||
c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
SparseTensor self = new_sparse(dtype, layout, device, pin_memory);
|
||||
get_sparse_impl(self)->resize_(sparse_dim, dense_dim, size);
|
||||
// NOTE: There is no guarantee that `indices` and `values` don't contain AutogradMeta. However,
|
||||
// we want to maintain the invariant that `indices_` and `values_` of a sparse tensor don't
|
||||
|
|
@ -115,9 +119,9 @@ SparseTensor new_with_dims_and_tensor_sparse(
|
|||
/** Public creation API that dispatch to methods above **/
|
||||
|
||||
/** Empty init **/
|
||||
Tensor empty_sparse(IntArrayRef size, const TensorOptions& options, c10::optional<MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(!options.pinned_memory(), "Only dense CPU tensors can be pinned");
|
||||
return new_with_dims_sparse(size.size(), 0, size, options);
|
||||
Tensor empty_sparse(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> optional_memory_format) {
|
||||
TORCH_CHECK(!pin_memory.has_value() || !*pin_memory, "Only dense CPU tensors can be pinned");
|
||||
return new_with_dims_sparse(size.size(), 0, size, dtype, layout, device, pin_memory);
|
||||
}
|
||||
|
||||
/* Shape init */
|
||||
|
|
@ -260,7 +264,9 @@ SparseTensor clone_sparse(const SparseTensor& self, c10::optional<c10::MemoryFor
|
|||
!optional_memory_format.has_value(),
|
||||
"unsupported memory format option ",
|
||||
optional_memory_format.value());
|
||||
SparseTensor other = new_with_dims_sparse(self.sparse_dim(), self.dense_dim(), self.sizes(), self.options());
|
||||
SparseTensor other = new_with_dims_sparse(self.sparse_dim(), self.dense_dim(), self.sizes(),
|
||||
optTypeMetaToScalarType(self.options().dtype_opt()), self.options().layout_opt(),
|
||||
self.options().device_opt(), self.options().pinned_memory_opt());
|
||||
copy_into_sparse(other, self._indices(), self._values(), true);
|
||||
return other._coalesced_(self.is_coalesced());
|
||||
}
|
||||
|
|
@ -309,7 +315,7 @@ SparseTensor dense_to_sparse(const Tensor& self, int64_t sparse_dim){
|
|||
|
||||
Tensor nz = self.nonzero().transpose(0, 1);
|
||||
if (nz.size(1) == 0) {
|
||||
return new_with_dims_sparse(sparse_dim, dims - sparse_dim, sizes, sparse_options);
|
||||
return new_with_dims_sparse(sparse_dim, dims - sparse_dim, sizes, optTypeMetaToScalarType(sparse_options.dtype_opt()), sparse_options.layout_opt(), sparse_options.device_opt(), sparse_options.pinned_memory_opt());
|
||||
}
|
||||
LongTensor indices;
|
||||
if (sparse_dim == dims) {
|
||||
|
|
@ -376,7 +382,7 @@ SparseTensor coalesce_sparse_cpu(const SparseTensor& self) {
|
|||
|
||||
LongTensor indices_scalar = flatten_indices(indices, self.sizes());
|
||||
|
||||
SparseTensor dst = new_sparse(self.options());
|
||||
SparseTensor dst = new_sparse(optTypeMetaToScalarType(self.options().dtype_opt()), self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt());
|
||||
get_sparse_impl(dst)->resize_(sparse_dim, dense_dim, self.sizes());
|
||||
// TODO: is there a more idiomatic way to do this?
|
||||
LongTensor newIndices = at::empty(indices.sizes(), indices.options());
|
||||
|
|
|
|||
|
|
@ -55,17 +55,20 @@ VulkanTensor& vtensor_from_vulkan(Tensor& tensor) {
|
|||
|
||||
Tensor empty(
|
||||
IntArrayRef size,
|
||||
const TensorOptions& options,
|
||||
optional<ScalarType> dtype,
|
||||
optional<Layout> layout,
|
||||
optional<Device> device,
|
||||
optional<bool> pin_memory,
|
||||
const optional<MemoryFormat> memory_format) {
|
||||
TORCH_CHECK(
|
||||
!options.pinned_memory(),
|
||||
!pin_memory.has_value(),
|
||||
"'pin_memory' argument is incompatible with Vulkan tensor");
|
||||
TORCH_CHECK(
|
||||
!options.has_memory_format() && !memory_format,
|
||||
!memory_format.has_value(),
|
||||
"'memory_format' argument is incompatible with Vulkan tensor");
|
||||
VulkanTensor vt{size.vec()};
|
||||
return new_with_vtensor_vulkan(
|
||||
std::move(vt), at::device(at::kVulkan).dtype(options.dtype()));
|
||||
std::move(vt), at::device(at::kVulkan).dtype(dtype));
|
||||
}
|
||||
|
||||
Tensor empty_strided(
|
||||
|
|
@ -76,7 +79,7 @@ Tensor empty_strided(
|
|||
optional<Device> device,
|
||||
optional<bool> pin_memory) {
|
||||
return vulkan::aten::empty(
|
||||
size, TensorOptions().dtype(dtype).layout(layout).device(device).pinned_memory(pin_memory), c10::nullopt);
|
||||
size, dtype, layout, device, pin_memory, c10::nullopt);
|
||||
}
|
||||
|
||||
Tensor upsample_nearest2d(
|
||||
|
|
@ -548,7 +551,7 @@ TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
|
|||
m.impl_UNBOXED("transpose_", at::native::vulkan::aten::transpose_);
|
||||
m.impl("view", TORCH_FN(at::native::vulkan::aten::view));
|
||||
m.impl("unsqueeze", TORCH_FN(at::native::vulkan::aten::unsqueeze));
|
||||
m.impl_UNBOXED("empty.memory_format", at::native::vulkan::aten::empty);
|
||||
m.impl("empty.memory_format", at::native::vulkan::aten::empty);
|
||||
m.impl("empty_strided", TORCH_FN(at::native::vulkan::aten::empty_strided));
|
||||
m.impl("add.Tensor", TORCH_FN(at::native::vulkan::aten::add));
|
||||
m.impl("clamp", TORCH_FN(at::native::vulkan::aten::clamp));
|
||||
|
|
|
|||
|
|
@ -370,14 +370,8 @@ TEST(BasicTest, FactoryMethodsTest) {
|
|||
ASSERT_FALSE(tensor0.is_pinned());
|
||||
|
||||
// Test setting requires_grad to true.
|
||||
tensor0 = at::empty({4}, at::TensorOptions().requires_grad(true));
|
||||
ASSERT_EQ(tensor0.dtype(), at::kFloat);
|
||||
ASSERT_EQ(tensor0.layout(), at::kStrided);
|
||||
ASSERT_EQ(tensor0.device(), at::kCPU);
|
||||
// This is a bug. Requires_grad was set to TRUE but this is being ignored.
|
||||
// Issue https://github.com/pytorch/pytorch/issues/30405
|
||||
ASSERT_FALSE(tensor0.requires_grad());
|
||||
ASSERT_FALSE(tensor0.is_pinned());
|
||||
// This is a bug. Requires_grad was set to TRUE but this is not implemented.
|
||||
EXPECT_ANY_THROW(at::empty({4}, at::TensorOptions().requires_grad(true)));
|
||||
|
||||
// Test setting dtype
|
||||
at::Tensor tensor1 = at::empty({4}, at::TensorOptions().dtype(at::kHalf));
|
||||
|
|
|
|||
|
|
@ -10,7 +10,8 @@ using namespace at;
|
|||
|
||||
static int test_int;
|
||||
|
||||
Tensor empty_override(IntArrayRef size, const TensorOptions & options, c10::optional<MemoryFormat> optional_memory_format) {
|
||||
Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout,
|
||||
c10::optional<Device> device, c10::optional<bool> pin_memory, c10::optional<MemoryFormat> optional_memory_format) {
|
||||
test_int = 1;
|
||||
auto tensor_impl = c10::make_intrusive<TensorImpl, UndefinedTensorImpl>(
|
||||
Storage(
|
||||
|
|
@ -37,13 +38,13 @@ Tensor empty_strided_override(
|
|||
c10::optional<c10::Device> device,
|
||||
c10::optional<bool> pin_memory) {
|
||||
|
||||
return empty_override(size, at::kMSNPU, c10::nullopt);
|
||||
return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt);
|
||||
}
|
||||
|
||||
TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
|
||||
m.impl_UNBOXED("aten::empty.memory_format", empty_override);
|
||||
m.impl_UNBOXED("aten::empty_strided", empty_strided_override);
|
||||
m.impl_UNBOXED("aten::add.Tensor", add_override);
|
||||
m.impl("aten::empty.memory_format", empty_override);
|
||||
m.impl("aten::empty_strided", empty_strided_override);
|
||||
m.impl("aten::add.Tensor", add_override);
|
||||
}
|
||||
|
||||
TEST(BackendExtensionTest, TestRegisterOp) {
|
||||
|
|
|
|||
|
|
@ -22,23 +22,23 @@ namespace c10 {
|
|||
DispatchKey computeDispatchKey(c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device);
|
||||
|
||||
inline ScalarType dtype_or_default(c10::optional<ScalarType> dtype) {
|
||||
return dtype.has_value() ? *dtype : get_default_dtype_as_scalartype();
|
||||
return value_or_else(dtype, [] {return get_default_dtype_as_scalartype();});
|
||||
}
|
||||
|
||||
inline caffe2::TypeMeta dtype_or_default(c10::optional<caffe2::TypeMeta> dtype) {
|
||||
return dtype.has_value() ? *dtype : get_default_dtype();
|
||||
return value_or_else(dtype, [] {return get_default_dtype();});
|
||||
}
|
||||
|
||||
inline Layout layout_or_default(c10::optional<Layout> layout) {
|
||||
return layout.has_value() ? *layout : kStrided;
|
||||
return layout.value_or(kStrided);
|
||||
}
|
||||
|
||||
inline Device device_or_default(c10::optional<Device> device) {
|
||||
return device.has_value() ? *device : Device(kCPU);
|
||||
return value_or_else(device, [] {return Device(kCPU);});
|
||||
}
|
||||
|
||||
inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
|
||||
return pinned_memory.has_value() ? *pinned_memory : false;
|
||||
return pinned_memory.value_or(false);
|
||||
}
|
||||
|
||||
/// A class to encapsulate construction axes of an Tensor. TensorOptions was
|
||||
|
|
@ -121,6 +121,8 @@ inline bool pinned_memory_or_default(c10::optional<bool> pinned_memory) {
|
|||
/// To get around this, we templatize the `Device` constructor. Since overload
|
||||
/// resolution is done before template resolution, our problem is solved.
|
||||
|
||||
DispatchKey computeDispatchKey(optional<ScalarType> dtype, optional<Layout> layout, optional<Device> device);
|
||||
|
||||
|
||||
struct C10_API TensorOptions {
|
||||
TensorOptions()
|
||||
|
|
@ -402,7 +404,7 @@ struct C10_API TensorOptions {
|
|||
return DispatchKeySet(computeDispatchKey());
|
||||
}
|
||||
|
||||
inline DispatchKey computeDispatchKey() const {
|
||||
DispatchKey computeDispatchKey() const {
|
||||
return c10::computeDispatchKey(optTypeMetaToScalarType(dtype_opt()), layout_opt(), device_opt());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@
|
|||
#include <type_traits>
|
||||
#include <utility>
|
||||
|
||||
#include <c10/util/Metaprogramming.h>
|
||||
|
||||
#define TR2_OPTIONAL_REQUIRES(...) \
|
||||
typename std::enable_if<__VA_ARGS__::value, bool>::type = false
|
||||
|
||||
|
|
@ -643,6 +645,22 @@ class optional : private OptionalBase<T> {
|
|||
}
|
||||
};
|
||||
|
||||
template <class T, class F>
|
||||
constexpr T value_or_else(const optional<T>& v, F&& func) {
|
||||
static_assert(std::is_convertible<typename guts::infer_function_traits_t<F>::return_type, T>::value,
|
||||
"func parameters must be a callable that returns a type convertible to the value stored in the optional");
|
||||
return v.has_value() ? *v : detail_::convert<T>(std::forward<F>(func)());
|
||||
}
|
||||
|
||||
template <class T, class F>
|
||||
constexpr T value_or_else(optional<T>&& v, F&& func) {
|
||||
static_assert(std::is_convertible<typename guts::infer_function_traits_t<F>::return_type, T>::value,
|
||||
"func parameters must be a callable that returns a type convertible to the value stored in the optional");
|
||||
return v.has_value()
|
||||
? constexpr_move(std::move(v).contained_val())
|
||||
: detail_::convert<T>(std::forward<F>(func)());
|
||||
}
|
||||
|
||||
|
||||
// XXX: please refrain from using optional<T&>, since it is being against with
|
||||
// the optional standard in c++ 17, see the debate and the details here:
|
||||
|
|
|
|||
|
|
@ -20,9 +20,10 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) {
|
|||
return Tensor(std::move(tensor_impl));
|
||||
}
|
||||
|
||||
Tensor empty_override(IntArrayRef size, const TensorOptions& options, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::optional<Layout> layout, c10::optional<Device> device,
|
||||
c10::optional<bool> pin_memory, c10::optional<c10::MemoryFormat> optional_memory_format) {
|
||||
test_int = 0;
|
||||
return get_tensor(options.dtype(), size);
|
||||
return get_tensor(scalarTypeToTypeMeta(dtype_or_default(dtype)), size);
|
||||
}
|
||||
|
||||
Tensor add_override(const Tensor & a, const Tensor & b , Scalar c) {
|
||||
|
|
|
|||
|
|
@ -399,6 +399,7 @@ def gen_variable_type_shard(out, aten_declarations, template_path, suffix, heade
|
|||
strategy = dispatch_strategy(declaration)
|
||||
if declaration['name'] not in MANUAL_AUTOGRAD and strategy == 'use_derived':
|
||||
body = emit_body(declaration)
|
||||
|
||||
type_definitions.append(METHOD_DEFINITION.substitute(
|
||||
declaration, type_definition_body=body, formals=formals))
|
||||
if declaration['use_c10_dispatcher'] in ['full', 'hacky_wrapper_for_legacy_signatures']:
|
||||
|
|
|
|||
|
|
@ -175,17 +175,18 @@ at::Tensor inferAndAlloc(
|
|||
}
|
||||
|
||||
auto at_type = data_type_to_aten(tv->getDataType().value());
|
||||
auto tensor_options =
|
||||
at::TensorOptions().dtype(at_type).device(options.device);
|
||||
|
||||
if (zero_init) {
|
||||
auto tensor_options =
|
||||
at::TensorOptions().dtype(at_type).device(options.device);
|
||||
c10::IntArrayRef isizes(sizes);
|
||||
return at::zeros(isizes, tensor_options);
|
||||
} else {
|
||||
c10::IntArrayRef isizes(sizes);
|
||||
// Non Variable type guard for empty_cuda call
|
||||
at::AutoNonVariableTypeMode non_variable_type_mode;
|
||||
return at::native::empty_cuda(isizes, tensor_options);
|
||||
return at::native::empty_cuda(
|
||||
isizes, at_type, c10::nullopt, options.device, c10::nullopt);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -411,18 +412,20 @@ std::vector<at::Tensor> FusionExecutor::runFusion(
|
|||
// take the short-cut for launch if we see a recorded input set again;
|
||||
launch_params = executor_entry->launch_params;
|
||||
for (size_t i = 0; i < executor_entry->output_sizes.size(); i++) {
|
||||
auto tensor_options = at::TensorOptions()
|
||||
.dtype(executor_entry->output_types[i])
|
||||
.device(options_.device);
|
||||
alloced_outputs.push_back(at::native::empty_cuda(
|
||||
executor_entry->output_sizes[i], tensor_options));
|
||||
executor_entry->output_sizes[i],
|
||||
executor_entry->output_types[i],
|
||||
c10::nullopt,
|
||||
options_.device,
|
||||
c10::nullopt));
|
||||
}
|
||||
for (size_t i = 0; i < executor_entry->empty_buffer_sizes.size(); i++) {
|
||||
auto tensor_options = at::TensorOptions()
|
||||
.dtype(executor_entry->empty_buffer_types[i])
|
||||
.device(options_.device);
|
||||
global_buffers.empty_buffers.push_back(at::native::empty_cuda(
|
||||
executor_entry->empty_buffer_sizes[i], tensor_options));
|
||||
executor_entry->empty_buffer_sizes[i],
|
||||
executor_entry->empty_buffer_types[i],
|
||||
c10::nullopt,
|
||||
options_.device,
|
||||
c10::nullopt));
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < executor_entry->zero_buffer_sizes.size(); i++) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user