Add ScalarType argument to Type::options() (#19270)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19270
ghimport-source-id: a5ade6131f3260066c5750ea1fa9ed5c998bb791

Differential Revision: D14938707
Pulled By: li-roy
fbshipit-source-id: 018fb3f01706531a06515d6d861e5683a455a705
This commit is contained in:
parent a044ba1af5
commit ab78449e8c
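In practice, call sites that previously let the `Type` object imply the dtype now pass a `ScalarType` explicitly. A minimal hedged sketch of the before/after call pattern (the helper and variable names below are illustrative, not from this patch):

// Hedged sketch of the call-site change: `t` is an at::Tensor, and
// Type::options(ScalarType) is the overload this PR introduces.
#include <ATen/ATen.h>

at::Tensor empty_like_sketch(const at::Tensor& t) {
  // Before: the dtype was implied by the Type object.
  //   auto opts = t.dispatch_type().options();
  // After: the ScalarType is passed explicitly alongside the Type.
  auto opts = t.dispatch_type().options(t.scalar_type());
  return at::empty(t.sizes(), opts);
}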
@@ -176,6 +176,11 @@ CAFFE2_API TypeExtendedInterface& getType(const Tensor&);
 CAFFE2_API Allocator* getCPUAllocator();
 
+static inline DeprecatedTypeProperties& getNonVariableDeprecatedTypeProperties(Backend p, ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      p, s, /*is_variable*/false);
+}
+
 static inline DeprecatedTypeProperties& CPU(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
       Backend::CPU, s, /*is_variable*/false);

@@ -176,9 +176,8 @@ struct CAFFE2_API Type {
     return this != &other;
   }
 
-  /// Constructs the `TensorOptions` from a type and a `device_index`.
-  TensorOptions options(int16_t device_index = -1) const {
-    return TensorOptions().dtype(typeMeta())
+  TensorOptions options(ScalarType s, int16_t device_index = -1) const {
+    return TensorOptions().dtype(s)
         .device(device_type(), device_index)
         .layout(layout())
         .is_variable(is_variable());

@@ -186,20 +185,16 @@ struct CAFFE2_API Type {
 
   /// Constructs the `TensorOptions` from a type and a Device. Asserts that
   /// the device type matches the device type of the type.
-  TensorOptions options(c10::optional<Device> device_opt) const {
+  TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const {
     if (!device_opt.has_value()) {
-      return options(-1);
+      return options(s, -1);
     } else {
       Device device = device_opt.value();
       AT_ASSERT(device.type() == device_type());
-      return options(device.index());
+      return options(s, device.index());
     }
   }
 
-  operator TensorOptions() const {
-    return options();
-  }
-
   // example
   // virtual Tensor * add(Tensor & a, Tensor & b) = 0;
   virtual Tensor abs(const Tensor & self) const = 0;

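Note that the implicit `operator TensorOptions()` conversion is dropped here, so code that handed a `Type` to a function expecting `TensorOptions` now has to spell out the scalar type. A hedged illustration (function and variable names are illustrative):

// Hedged illustration of the effect of removing `operator TensorOptions()`.
#include <ATen/ATen.h>

at::Tensor zeros_for_type(const at::Type& type, at::ScalarType s, at::IntArrayRef sizes) {
  // Previously the Type could convert to TensorOptions implicitly:
  //   return at::zeros(sizes, type);
  // Now the ScalarType must be supplied through the new overload:
  return at::zeros(sizes, type.options(s));
}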
@@ -1604,7 +1604,8 @@ def create_derived(backend_type_env, declarations):
             # e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the
             # ScalarType before constructing the scalar_tensor to avoid overflow checking.
             elif ret['type'] == 'accreal' or ret['type'] == 'real':
-                return_scalar = 'return at::scalar_tensor(convert<${ScalarType}>(${call}), options());'
+                return_scalar = ('return at::scalar_tensor(convert<${ScalarType}>(${call}), '
+                                 'options(ScalarType::${ScalarName}));')
                 case_body.append(CodeTemplate(return_scalar).substitute(case_env, call=call))
             else:
                 # we using int64_t for long in the API, so correct it here...

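For a concrete sense of what this codegen tweak emits, here is a hedged, hand-written expansion of the new template for a hypothetical kLong `accreal` return (the surrounding generated wrapper is elided and `sum_value` is an illustrative name):

// Hedged sketch of the kind of statement the generator now produces for an
// accreal/real return; the explicit Long dtype here is illustrative.
#include <ATen/ATen.h>

at::Tensor generated_accreal_return(const at::Type& type, int64_t sum_value) {
  // old generated form:  at::scalar_tensor(convert<int64_t>(sum_value), options());
  // new generated form keeps the cast and passes the ScalarType explicitly:
  return at::scalar_tensor(at::Scalar(sum_value), type.options(at::ScalarType::Long));
}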
@@ -119,9 +119,8 @@ struct CAFFE2_API Type {
     return this != &other;
   }
 
-  /// Constructs the `TensorOptions` from a type and a `device_index`.
-  TensorOptions options(int16_t device_index = -1) const {
-    return TensorOptions().dtype(typeMeta())
+  TensorOptions options(ScalarType s, int16_t device_index = -1) const {
+    return TensorOptions().dtype(s)
         .device(device_type(), device_index)
         .layout(layout())
         .is_variable(is_variable());

@@ -129,20 +128,16 @@ struct CAFFE2_API Type {
 
   /// Constructs the `TensorOptions` from a type and a Device. Asserts that
   /// the device type matches the device type of the type.
-  TensorOptions options(c10::optional<Device> device_opt) const {
+  TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const {
     if (!device_opt.has_value()) {
-      return options(-1);
+      return options(s, -1);
     } else {
       Device device = device_opt.value();
       AT_ASSERT(device.type() == device_type());
-      return options(device.index());
+      return options(s, device.index());
     }
   }
 
-  operator TensorOptions() const {
-    return options();
-  }
-
   // example
   // virtual Tensor * add(Tensor & a, Tensor & b) = 0;
   ${pure_virtual_type_method_declarations}

@@ -66,10 +66,10 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTypes) {
   options = TensorOptions(kInt);
   REQUIRE_OPTIONS(kCPU, -1, kInt, kStrided);
 
-  options = TensorOptions(getNonVariableType(Backend::SparseCPU, kFloat));
+  options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kFloat));
   REQUIRE_OPTIONS(kCPU, -1, kFloat, kSparse);
 
-  options = TensorOptions(getNonVariableType(Backend::SparseCPU, kByte));
+  options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte));
   REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
 }

@@ -77,7 +77,7 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTensors) {
   auto options = empty(5, kDouble).options();
   REQUIRE_OPTIONS(kCPU, -1, kDouble, kStrided);
 
-  options = empty(5, getNonVariableType(Backend::SparseCPU, kByte)).options();
+  options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte)).options();
   REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
 }

@@ -42,17 +42,17 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATypes_CUDA) {
   options = CUDA(kInt).options();
   REQUIRE_OPTIONS(kCUDA, -1, kInt, kStrided);
 
-  options = getNonVariableType(Backend::SparseCUDA, kFloat).options();
+  options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options();
  REQUIRE_OPTIONS(kCUDA, -1, kFloat, kSparse);
 
-  options = getNonVariableType(Backend::SparseCUDA, kByte).options();
+  options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte).options();
   REQUIRE_OPTIONS(kCUDA, -1, kByte, kSparse);
 
   options = CUDA(kFloat).options(/*device=*/5);
   REQUIRE_OPTIONS(kCUDA, 5, kFloat, kStrided);
 
   options =
-      getNonVariableType(Backend::SparseCUDA, kFloat).options(/*device=*/5);
+      getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options(/*device=*/5);
   REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse);
 }

@@ -60,7 +60,7 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATensors_MultiCUDA) {
   auto options = empty(5, device(kCUDA).dtype(kDouble)).options();
   REQUIRE_OPTIONS(kCUDA, 0, kDouble, kStrided);
 
-  options = empty(5, getNonVariableType(Backend::SparseCUDA, kByte)).options();
+  options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte)).options();
   REQUIRE_OPTIONS(kCUDA, 0, kByte, kSparse);
 
   if (torch::cuda::device_count() > 1) {

@@ -35,13 +35,13 @@ struct TypeAndSize {
   /* implicit */
   TypeAndSize(const Tensor & t)
     : sizes(t.sizes().vec())
-    , type(&t.dispatch_type()) {}
+    , type(&t.type()) {}
 
   Tensor zeros() { return at::zeros(sizes, *type); }
 
 private:
   std::vector<int64_t> sizes;
-  Type* type;
+  at::DeprecatedTypeProperties* type;
 };
 
 ${autograd_function_declarations}

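The autograd hunks below swap `at::Type` for `at::DeprecatedTypeProperties`, the backend/scalar-type bundle that the new Context.h helper above looks up from a global registry. A hedged sketch of obtaining and comparing two such objects (the include path and `at::` qualification are assumptions):

// Hedged sketch: fetch DeprecatedTypeProperties the way the new helper does
// and compare them the way the autograd engine's is_compatible_type() does.
#include <ATen/ATen.h>

bool same_properties(at::Backend b1, at::ScalarType s1,
                     at::Backend b2, at::ScalarType s2) {
  auto& p1 = at::getNonVariableDeprecatedTypeProperties(b1, s1);
  auto& p2 = at::getNonVariableDeprecatedTypeProperties(b2, s2);
  return p1 == p2;  // distinct (backend, scalar type) pairs yield distinct entries
}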
@@ -334,7 +334,7 @@ static variable_list call_post_hooks(Function& fn, variable_list outputs, const
   return outputs;
 }
 
-static bool is_compatible_type(const at::Type& expected, const at::Type& actual) {
+static bool is_compatible_type(const at::DeprecatedTypeProperties& expected, const at::DeprecatedTypeProperties& actual) {
   // Types are compatible if they exactly match or if the gradient is a sparse
   // version of the expected type.
   return expected == actual || (actual.is_sparse() &&

@@ -372,7 +372,7 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const
       }
       grads[i] = at::sum_to(std::move(grads[i]), metadata.shape());
     }
-    if (!is_compatible_type(metadata.type(), grads[i].dispatch_type())) {
+    if (!is_compatible_type(metadata.type(), grads[i].type())) {
       std::stringstream ss;
      ss << "invalid gradient at index " << i << " - expected type ";
      ss << metadata.type() << " but got " << grads[i].type();

@@ -130,7 +130,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   /// Adds the type and shape metadata for a new input. Returns the index of
   /// of the new input.
   uint32_t add_input_metadata(
-    const at::Type& type
+    const at::DeprecatedTypeProperties& type
   , at::IntArrayRef shape
   , at::Device device) noexcept {
     uint32_t input_nr = input_metadata_.size();

@@ -12,17 +12,17 @@ namespace torch { namespace autograd {
 struct InputMetadata {
   InputMetadata() = default;
 
-  InputMetadata(const at::Type& type, at::IntArrayRef shape, at::Device device)
+  InputMetadata(const at::DeprecatedTypeProperties& type, at::IntArrayRef shape, at::Device device)
   : type_{&type} , shape_{shape}, device_{device} { }
 
   InputMetadata(const at::Tensor& t)
-  : InputMetadata(t.dispatch_type(), t.sizes(), t.device()) { }
+  : InputMetadata(t.type(), t.sizes(), t.device()) { }
 
   bool is_valid() const {
     return type_ != nullptr;
   }
 
-  const at::Type& type() const {
+  const at::DeprecatedTypeProperties& type() const {
     AT_ASSERT(type_);
     return *type_;
   }

@@ -40,7 +40,7 @@ struct InputMetadata {
   }
 
 private:
-  const at::Type* type_ = nullptr;
+  const at::DeprecatedTypeProperties* type_ = nullptr;
   at::DimVector shape_;
   at::Device device_ = at::kCPU;
 };

@@ -46,6 +46,7 @@ namespace torch { namespace autograd {
 VariableInfo::VariableInfo(const Variable& var)
   : type(&var.dispatch_type())
   , device(var.device())
+  , scalar_type(var.scalar_type())
   , size(var.sizes().vec())
   , requires_grad(var.requires_grad()) {
 }

@@ -53,7 +54,7 @@ VariableInfo::VariableInfo(const Variable& var)
 Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
   // NB: This will NOT work if we ever get mixed device gradients
   device_guard.reset_device(device);
-  return at::zeros(size, type->options());
+  return at::zeros(size, type->options(scalar_type));
 }
 
 auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {

@@ -25,6 +25,7 @@ struct VariableInfo {
 
   at::Type* type;
   at::Device device = at::kCPU;
+  at::ScalarType scalar_type = at::kFloat;
   std::vector<int64_t> size;
   bool requires_grad;
 };

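Taken together, the VariableInfo changes capture the scalar type at construction and thread it into zeros(). A condensed hedged sketch of that flow (fields and initializers abbreviated relative to the real struct):

// Hedged, condensed sketch of the VariableInfo flow after this change.
#include <ATen/ATen.h>
#include <vector>

struct VariableInfoSketch {
  at::Type* type;                            // still the dispatch type
  at::ScalarType scalar_type = at::kFloat;   // newly captured alongside it
  std::vector<int64_t> size;

  explicit VariableInfoSketch(const at::Tensor& var)
      : type(&var.dispatch_type()),
        scalar_type(var.scalar_type()),
        size(var.sizes().vec()) {}

  at::Tensor zeros() const {
    // the dtype now comes from the stored ScalarType, not from the Type object
    return at::zeros(size, type->options(scalar_type));
  }
};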
@@ -46,7 +46,8 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject
   if (!data || data == Py_None) {
     // For legacy serialization code, create an empty tensor. This is also used
     // by nn.Parameter() with no arguments.
-    auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options());
+    auto scalar_type = torch::tensors::get_default_scalar_type();
+    auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options(scalar_type));
     tensor = static_cast<Variable&>(var).data();
   } else if (THPVariable_Check(data)) {
     tensor = ((THPVariable*)data)->cdata.data();

@@ -110,15 +110,15 @@ static Variable sequenceToVariable(const at::Type& type, PyObject* seq) {
   return torch::utils::indexing_tensor_from_data(idx_type, kLong, c10::nullopt, seq);
 }
 
-static Variable valueToTensor(const at::Type & type, PyObject* value) {
+static Variable valueToTensor(const at::Type & type, const ScalarType scalar_type, PyObject* value) {
   if (THPVariable_Check(value)) {
     return reinterpret_cast<THPVariable*>(value)->cdata;
   }
   if (THPUtils_checkLong(value) || PyBool_Check(value)) {
-    return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options());
+    return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options(scalar_type));
   }
   if (PyFloat_Check(value)) {
-    return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options());
+    return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options(scalar_type));
   }
   throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString());
 }

@@ -334,7 +334,7 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
 
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
-  auto value = valueToTensor(self_.dispatch_type(), py_value);
+  auto value = valueToTensor(self_.dispatch_type(), self_.scalar_type(), py_value);
 
   // handle simple types: integers, slices, ellipsis, bool
   if (index == Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)

@@ -214,7 +214,7 @@ const std::shared_ptr<Function>& Variable::grad_fn() const {
     fn->storage_offset = data().storage_offset();
     fn->set_next_edges(collect_next_edges(diff_view_meta->base_));
     fn->add_input_metadata(
-      diff_view_meta->base_.dispatch_type()
+      diff_view_meta->base_.type()
     , sizes() // Note: sizes(), not base_.sizes(), is intentional
     , diff_view_meta->base_.device());
     diff_view_meta->grad_fn_ = std::move(fn);

@@ -59,7 +59,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
     tensors.push_back(tensor);
     for (auto device : devices.slice(1)) {
       _device_guard.set_index(device);
-      tensors.push_back(at::empty(tensor.sizes(), type.options()));
+      tensors.push_back(at::empty(tensor.sizes(), type.options(tensor.scalar_type())));
     }
     nccl::broadcast(tensors);
   } else {

@@ -157,10 +157,9 @@ class ShapePropagator {
       return *iv;
     }
     if (CompleteTensorTypePtr type = type_->cast<CompleteTensorType>()) {
-      auto backend =
-          type->device().is_cpu() ? at::Backend::CPU : at::Backend::CUDA;
+      auto attype = type->device().is_cpu() ?
+          at::CPU(type->scalarType()) : at::CUDA(type->scalarType());
       at::DeviceGuard device_guard(type->device());
-      auto& attype = at::getNonVariableType(backend, type->scalarType());
       auto t =
           at::empty_strided(type->sizes(), type->strides(), attype.options())
               .zero_();

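The same idea in isolation: the shape propagator now asks at::CPU / at::CUDA for DeprecatedTypeProperties and lets its options() supply the dtype. A hedged standalone sketch of that allocation pattern for the CPU case (names are illustrative):

// Hedged sketch of the allocation pattern used above, outside the JIT pass.
#include <ATen/ATen.h>

at::Tensor probe_tensor(at::IntArrayRef sizes, at::IntArrayRef strides) {
  auto& attype = at::CPU(at::kFloat);  // at::CUDA(...) would be used on the CUDA path
  return at::empty_strided(sizes, strides, attype.options()).zero_();
}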
@@ -53,32 +53,32 @@ void maybe_initialize_cuda(const Device device) {
   }
 }
 
-Tensor dispatch_zeros(const Type& type, optional<Device> device, IntArrayRef sizes) {
+Tensor dispatch_zeros(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::zeros(sizes, type.options(std::move(device)));
+  return torch::zeros(sizes, type.options(scalar_type, std::move(device)));
 }
 
-Tensor dispatch_ones(const Type& type, optional<Device> device, IntArrayRef sizes) {
+Tensor dispatch_ones(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::ones(sizes, type.options(std::move(device)));
+  return torch::ones(sizes, type.options(scalar_type, std::move(device)));
 }
 
-Tensor dispatch_full(const Type& type, Scalar fill_value, optional<Device> device, IntArrayRef sizes) {
+Tensor dispatch_full(const Type& type, const ScalarType scalar_type, Scalar fill_value, optional<Device> device, IntArrayRef sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::full(sizes, fill_value, type.options(std::move(device)));
+  return torch::full(sizes, fill_value, type.options(scalar_type, std::move(device)));
 }
 
-Tensor new_with_sizes(const Type& type, optional<Device> device, IntArrayRef sizes) {
+Tensor new_with_sizes(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::empty(sizes, type.options(std::move(device)));
+  return torch::empty(sizes, type.options(scalar_type, std::move(device)));
 }
 
-Tensor new_with_storage(const Type& type, Storage storage) {
-  auto tensor = at::empty({}, type.options());
+Tensor new_with_storage(const Type& type, const ScalarType scalar_type, Storage storage) {
+  auto tensor = at::empty({}, type.options(scalar_type));
   tensor.set_(std::move(storage));
   return tensor;
 }

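All of these torch::utils helpers follow the same recipe: accept the ScalarType next to the Type and forward both into options(). A hedged, simplified sketch of the dispatch_zeros shape of that recipe (GIL handling and lazy CUDA initialization omitted; the include is an assumption):

// Hedged, simplified sketch of the new dispatch_zeros signature and body.
#include <torch/torch.h>  // assumed to provide torch::zeros and the at:: types
#include <utility>

at::Tensor dispatch_zeros_sketch(const at::Type& type, at::ScalarType scalar_type,
                                 c10::optional<at::Device> device, at::IntArrayRef sizes) {
  return torch::zeros(sizes, type.options(scalar_type, std::move(device)));
}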
@@ -281,7 +281,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, ScalarType scalar_type, PyObj
   if (r.idx == 0) {
     auto deviceOptional = r.deviceOptional(0);
     check_legacy_ctor_device(type, deviceOptional);
-    return at::empty({0}, type.options(r.deviceOptional(0)));
+    return at::empty({0}, type.options(scalar_type, r.deviceOptional(0)));
   } else if (r.idx == 1) {
     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
     return type.unsafeTensorFromTH(cdata, true);

@@ -304,7 +304,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, ScalarType scalar_type, PyObj
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+    return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
   }
   throw std::runtime_error("new(): invalid arguments");
 }

@@ -323,7 +323,7 @@ Tensor legacy_sparse_tensor_new(const Type& type, ScalarType scalar_type, PyObje
     auto deviceOptional = r.deviceOptional(0);
     check_legacy_ctor_device(type, deviceOptional);
     at::OptionalDeviceGuard device_guard(deviceOptional);
-    return at::empty({0}, type.options());
+    return at::empty({0}, type.options(scalar_type));
   } else if (r.idx == 1) {
     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
     return type.unsafeTensorFromTH(cdata, true);

@@ -350,7 +350,7 @@ Tensor legacy_sparse_tensor_new(const Type& type, ScalarType scalar_type, PyObje
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+    return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
   }
   throw std::runtime_error("new(): invalid arguments");
 }

@@ -384,9 +384,9 @@ Tensor legacy_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* ar
     auto deviceOptional = r.deviceOptional(0);
     check_legacy_ctor_device(type, deviceOptional);
     at::OptionalDeviceGuard device_guard(deviceOptional);
-    return at::empty({0}, type.options());
+    return at::empty({0}, type.options(scalar_type));
   } else if (r.idx == 1) {
-    return new_with_storage(type, r.storage(0));
+    return new_with_storage(type, scalar_type, r.storage(0));
   } else if (r.idx == 2) {
     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
     return type.unsafeTensorFromTH(cdata, true);

@@ -401,7 +401,7 @@ Tensor legacy_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* ar
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+    return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
   } else if (r.idx == 5) {
     auto deviceOptional = r.deviceOptional(1);
     check_legacy_ctor_device(type, deviceOptional);

@@ -430,9 +430,9 @@ Tensor legacy_tensor_new(const Type& type, ScalarType scalar_type, PyObject* arg
     auto deviceOptional = r.deviceOptional(0);
     check_legacy_ctor_device(type, deviceOptional);
     at::OptionalDeviceGuard device_guard(deviceOptional);
-    return at::empty({0}, type.options());
+    return at::empty({0}, type.options(scalar_type));
   } else if (r.idx == 1) {
-    return new_with_storage(type, r.storage(0));
+    return new_with_storage(type, scalar_type, r.storage(0));
   } else if (r.idx == 2) {
     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
     return type.unsafeTensorFromTH(cdata, true);

@@ -447,7 +447,7 @@ Tensor legacy_tensor_new(const Type& type, ScalarType scalar_type, PyObject* arg
       // unless the sequences is a torch.Size
      return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+    return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
   } else if (r.idx == 5) {
     auto deviceOptional = r.deviceOptional(1);
     check_legacy_ctor_device(type, deviceOptional);

@@ -504,8 +504,9 @@ Tensor sparse_coo_tensor_ctor(const Type& default_type, ScalarType scalar_type,
     return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5));
   } else if (r.idx == 2) {
     const auto& type = typeWithDefault(r, 1, 2, default_type, scalar_type);
+    const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
     at::OptionalDeviceGuard device_guard(r.deviceOptional(2));
-    return at::sparse_coo_tensor(r.intlist(0), type.options().layout(at::kSparse)).set_requires_grad(r.toBool(3));
+    return at::sparse_coo_tensor(r.intlist(0), type.options(actual_scalar_type).layout(at::kSparse)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
 }

@@ -603,7 +604,8 @@ Tensor new_empty(const Type& type, ScalarType scalar_type, PyObject* args, PyObj
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
-    return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4));
+    const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+    return new_with_sizes(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4));
   }
   throw std::runtime_error("new_empty(): invalid arguments");
 }

@@ -617,7 +619,8 @@ Tensor new_full(const Type& type, ScalarType scalar_type, PyObject* args, PyObje
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 2, 3, type, scalar_type);
-    return dispatch_full(actual_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
+    const auto actual_scalar_type = r.scalartypeWithDefault(2, scalar_type);
+    return dispatch_full(actual_type, actual_scalar_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
   }
   throw std::runtime_error("new_full(): invalid arguments");
 }

@@ -631,7 +634,8 @@ Tensor new_ones(const Type& type, ScalarType scalar_type, PyObject* args, PyObje
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
-    return dispatch_ones(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
+    const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+    return dispatch_ones(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("new_ones(): invalid arguments");
 }

@@ -645,7 +649,8 @@ Tensor new_zeros(const Type& type, ScalarType scalar_type, PyObj
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
-    return dispatch_zeros(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
+    const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+    return dispatch_zeros(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("new_zeros(): invalid arguments");
 }

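The new_empty/new_full/new_ones/new_zeros hunks all apply the same recipe: keep resolving the target type with typeWithDefault, additionally resolve the scalar type with scalartypeWithDefault, and hand both to the dispatch helper. A hedged sketch of that pairing with the parser machinery abstracted away (the optional `requested` dtype stands in for r.scalartypeWithDefault):

// Hedged sketch of the shared new_* recipe; `requested` models the parsed
// optional dtype argument, defaulting to the tensor type's scalar type.
#include <torch/torch.h>

at::Tensor new_zeros_sketch(const at::Type& actual_type,
                            at::ScalarType default_scalar_type,
                            c10::optional<at::ScalarType> requested,
                            at::IntArrayRef sizes) {
  const auto actual_scalar_type = requested.value_or(default_scalar_type);
  return torch::zeros(sizes, actual_type.options(actual_scalar_type));
}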