Add ScalarType argument to Type::options() (#19270)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19270
ghimport-source-id: a5ade6131f3260066c5750ea1fa9ed5c998bb791

Differential Revision: D14938707

Pulled By: li-roy

fbshipit-source-id: 018fb3f01706531a06515d6d861e5683a455a705
This commit is contained in:
parent a044ba1af5
commit ab78449e8c
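The change threads an explicit ScalarType through Type::options(): the dtype is no longer inferred from the Type, so every call site now names the scalar type it wants. A minimal sketch of the call-site migration the hunks below perform (the function and variable names here are illustrative, not part of the commit):

#include <ATen/ATen.h>

// Hedged sketch only: old vs. new call shape for Type::options().
at::Tensor make_empty(const at::Type& type, at::ScalarType scalar_type) {
  // Before this commit: at::empty({0}, type.options());  // dtype came from the Type
  // After this commit the ScalarType is passed explicitly:
  return at::empty({0}, type.options(scalar_type));
}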
@@ -176,6 +176,11 @@ CAFFE2_API TypeExtendedInterface& getType(const Tensor&);
 
 CAFFE2_API Allocator* getCPUAllocator();
 
+static inline DeprecatedTypeProperties& getNonVariableDeprecatedTypeProperties(Backend p, ScalarType s) {
+  return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
+      p, s, /*is_variable*/false);
+}
+
 static inline DeprecatedTypeProperties& CPU(ScalarType s) {
   return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties(
     Backend::CPU, s, /*is_variable*/false);
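The helper added above hands call sites a DeprecatedTypeProperties to build options from; the updated tests further down exercise it like this (a sketch assuming the usual ATen headers, not code from the commit):

#include <ATen/ATen.h>

// Illustrative use of getNonVariableDeprecatedTypeProperties(), mirroring the tests.
at::TensorOptions sparse_float_cpu_options() {
  auto& props = at::getNonVariableDeprecatedTypeProperties(
      at::Backend::SparseCPU, at::kFloat);
  return at::TensorOptions(props);  // float dtype, sparse layout, CPU device
}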
@@ -176,9 +176,8 @@ struct CAFFE2_API Type {
     return this != &other;
   }
 
-  /// Constructs the `TensorOptions` from a type and a `device_index`.
-  TensorOptions options(int16_t device_index = -1) const {
-    return TensorOptions().dtype(typeMeta())
+  TensorOptions options(ScalarType s, int16_t device_index = -1) const {
+    return TensorOptions().dtype(s)
       .device(device_type(), device_index)
       .layout(layout())
       .is_variable(is_variable());
@@ -186,20 +185,16 @@ struct CAFFE2_API Type {
 
   /// Constructs the `TensorOptions` from a type and a Device. Asserts that
   /// the device type matches the device type of the type.
-  TensorOptions options(c10::optional<Device> device_opt) const {
+  TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const {
     if (!device_opt.has_value()) {
-      return options(-1);
+      return options(s, -1);
     } else {
       Device device = device_opt.value();
       AT_ASSERT(device.type() == device_type());
-      return options(device.index());
+      return options(s, device.index());
     }
   }
 
-  operator TensorOptions() const {
-    return options();
-  }
-
   // example
   // virtual Tensor * add(Tensor & a, Tensor & b) = 0;
   virtual Tensor abs(const Tensor & self) const = 0;
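With the implicit operator TensorOptions() gone, callers go through the reworked overloads and must supply the ScalarType themselves. A rough sketch of both overloads in use (variable names are illustrative, not from the commit):

#include <ATen/ATen.h>

// Hedged sketch of the two Type::options() overloads after this change.
at::TensorOptions options_for(const at::Type& t, at::ScalarType s) {
  // device_index overload; -1 means the current/default device.
  at::TensorOptions by_index = t.options(s, /*device_index=*/-1);
  // optional<Device> overload; the device type must match t.device_type().
  at::TensorOptions by_device =
      t.options(s, c10::optional<at::Device>(at::Device(t.device_type())));
  (void)by_device;
  return by_index;
}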
@@ -1604,7 +1604,8 @@ def create_derived(backend_type_env, declarations):
             # e.g. x.sum(0) and x.sum() return the same type. We explicitly cast to the
             # ScalarType before constructing the scalar_tensor to avoid overflow checking.
             elif ret['type'] == 'accreal' or ret['type'] == 'real':
-                return_scalar = 'return at::scalar_tensor(convert<${ScalarType}>(${call}), options());'
+                return_scalar = ('return at::scalar_tensor(convert<${ScalarType}>(${call}), '
+                                 'options(ScalarType::${ScalarName}));')
                 case_body.append(CodeTemplate(return_scalar).substitute(case_env, call=call))
             else:
                 # we using int64_t for long in the API, so correct it here...
@@ -119,9 +119,8 @@ struct CAFFE2_API Type {
     return this != &other;
   }
 
-  /// Constructs the `TensorOptions` from a type and a `device_index`.
-  TensorOptions options(int16_t device_index = -1) const {
-    return TensorOptions().dtype(typeMeta())
+  TensorOptions options(ScalarType s, int16_t device_index = -1) const {
+    return TensorOptions().dtype(s)
       .device(device_type(), device_index)
       .layout(layout())
       .is_variable(is_variable());
@@ -129,20 +128,16 @@ struct CAFFE2_API Type {
 
   /// Constructs the `TensorOptions` from a type and a Device. Asserts that
   /// the device type matches the device type of the type.
-  TensorOptions options(c10::optional<Device> device_opt) const {
+  TensorOptions options(ScalarType s, c10::optional<Device> device_opt) const {
     if (!device_opt.has_value()) {
-      return options(-1);
+      return options(s, -1);
     } else {
       Device device = device_opt.value();
       AT_ASSERT(device.type() == device_type());
-      return options(device.index());
+      return options(s, device.index());
     }
   }
 
-  operator TensorOptions() const {
-    return options();
-  }
-
   // example
   // virtual Tensor * add(Tensor & a, Tensor & b) = 0;
   ${pure_virtual_type_method_declarations}
@@ -66,10 +66,10 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTypes) {
   options = TensorOptions(kInt);
   REQUIRE_OPTIONS(kCPU, -1, kInt, kStrided);
 
-  options = TensorOptions(getNonVariableType(Backend::SparseCPU, kFloat));
+  options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kFloat));
   REQUIRE_OPTIONS(kCPU, -1, kFloat, kSparse);
 
-  options = TensorOptions(getNonVariableType(Backend::SparseCPU, kByte));
+  options = TensorOptions(getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte));
   REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
 }
 
@@ -77,7 +77,7 @@ TEST(TensorOptionsTest, ConstructsWellFromCPUTensors) {
   auto options = empty(5, kDouble).options();
   REQUIRE_OPTIONS(kCPU, -1, kDouble, kStrided);
 
-  options = empty(5, getNonVariableType(Backend::SparseCPU, kByte)).options();
+  options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCPU, kByte)).options();
   REQUIRE_OPTIONS(kCPU, -1, kByte, kSparse);
 }
 
@@ -42,17 +42,17 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATypes_CUDA) {
   options = CUDA(kInt).options();
   REQUIRE_OPTIONS(kCUDA, -1, kInt, kStrided);
 
-  options = getNonVariableType(Backend::SparseCUDA, kFloat).options();
+  options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options();
   REQUIRE_OPTIONS(kCUDA, -1, kFloat, kSparse);
 
-  options = getNonVariableType(Backend::SparseCUDA, kByte).options();
+  options = getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte).options();
   REQUIRE_OPTIONS(kCUDA, -1, kByte, kSparse);
 
   options = CUDA(kFloat).options(/*device=*/5);
   REQUIRE_OPTIONS(kCUDA, 5, kFloat, kStrided);
 
   options =
-      getNonVariableType(Backend::SparseCUDA, kFloat).options(/*device=*/5);
+      getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kFloat).options(/*device=*/5);
   REQUIRE_OPTIONS(kCUDA, 5, kFloat, kSparse);
 }
 
@@ -60,7 +60,7 @@ TEST(TensorOptionsTest, ConstructsWellFromCUDATensors_MultiCUDA) {
   auto options = empty(5, device(kCUDA).dtype(kDouble)).options();
   REQUIRE_OPTIONS(kCUDA, 0, kDouble, kStrided);
 
-  options = empty(5, getNonVariableType(Backend::SparseCUDA, kByte)).options();
+  options = empty(5, getNonVariableDeprecatedTypeProperties(Backend::SparseCUDA, kByte)).options();
   REQUIRE_OPTIONS(kCUDA, 0, kByte, kSparse);
 
   if (torch::cuda::device_count() > 1) {
@@ -35,13 +35,13 @@ struct TypeAndSize {
   /* implicit */
   TypeAndSize(const Tensor & t)
     : sizes(t.sizes().vec())
-    , type(&t.dispatch_type()) {}
+    , type(&t.type()) {}
 
   Tensor zeros() { return at::zeros(sizes, *type); }
 
 private:
   std::vector<int64_t> sizes;
-  Type* type;
+  at::DeprecatedTypeProperties* type;
 };
 
 ${autograd_function_declarations}
@@ -334,7 +334,7 @@ static variable_list call_post_hooks(Function& fn, variable_list outputs, const
   return outputs;
 }
 
-static bool is_compatible_type(const at::Type& expected, const at::Type& actual) {
+static bool is_compatible_type(const at::DeprecatedTypeProperties& expected, const at::DeprecatedTypeProperties& actual) {
   // Types are compatible if they exactly match or if the gradient is a sparse
   // version of the expected type.
   return expected == actual || (actual.is_sparse() &&
@@ -372,7 +372,7 @@ static void validate_outputs(const edge_list& edges, variable_list& grads, const
       }
       grads[i] = at::sum_to(std::move(grads[i]), metadata.shape());
     }
-    if (!is_compatible_type(metadata.type(), grads[i].dispatch_type())) {
+    if (!is_compatible_type(metadata.type(), grads[i].type())) {
       std::stringstream ss;
       ss << "invalid gradient at index " << i << " - expected type ";
       ss << metadata.type() << " but got " << grads[i].type();
@@ -130,7 +130,7 @@ struct TORCH_API Function : std::enable_shared_from_this<Function> {
   /// Adds the type and shape metadata for a new input. Returns the index of
   /// of the new input.
   uint32_t add_input_metadata(
-    const at::Type& type
+    const at::DeprecatedTypeProperties& type
   , at::IntArrayRef shape
   , at::Device device) noexcept {
     uint32_t input_nr = input_metadata_.size();
@@ -12,17 +12,17 @@ namespace torch { namespace autograd {
 struct InputMetadata {
   InputMetadata() = default;
 
-  InputMetadata(const at::Type& type, at::IntArrayRef shape, at::Device device)
+  InputMetadata(const at::DeprecatedTypeProperties& type, at::IntArrayRef shape, at::Device device)
   : type_{&type} , shape_{shape}, device_{device} { }
 
   InputMetadata(const at::Tensor& t)
-  : InputMetadata(t.dispatch_type(), t.sizes(), t.device()) { }
+  : InputMetadata(t.type(), t.sizes(), t.device()) { }
 
   bool is_valid() const {
     return type_ != nullptr;
   }
 
-  const at::Type& type() const {
+  const at::DeprecatedTypeProperties& type() const {
     AT_ASSERT(type_);
     return *type_;
   }
@@ -40,7 +40,7 @@ struct InputMetadata {
   }
 
 private:
-  const at::Type* type_ = nullptr;
+  const at::DeprecatedTypeProperties* type_ = nullptr;
   at::DimVector shape_;
   at::Device device_ = at::kCPU;
 };
@@ -46,6 +46,7 @@ namespace torch { namespace autograd {
 VariableInfo::VariableInfo(const Variable& var)
   : type(&var.dispatch_type())
   , device(var.device())
+  , scalar_type(var.scalar_type())
   , size(var.sizes().vec())
   , requires_grad(var.requires_grad()) {
 }
@@ -53,7 +54,7 @@ VariableInfo::VariableInfo(const Variable& var)
 Variable VariableInfo::zeros(at::OptionalDeviceGuard& device_guard) const {
   // NB: This will NOT work if we ever get mixed device gradients
   device_guard.reset_device(device);
-  return at::zeros(size, type->options());
+  return at::zeros(size, type->options(scalar_type));
 }
 
 auto PyFunction::legacy_apply(const variable_list& inputs) -> variable_list {
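VariableInfo now caches the variable's ScalarType so zeros() can forward it to options(). Roughly the same thing as a stand-alone expression (a hedged sketch assuming a Variable var; this is not code from the commit):

#include <torch/csrc/autograd/variable.h>

// Sketch: rebuild a zero tensor from the properties VariableInfo caches
// (dispatch type, scalar type, and sizes).
at::Tensor zeros_like_info(const torch::autograd::Variable& var) {
  return at::zeros(var.sizes(), var.dispatch_type().options(var.scalar_type()));
}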
@@ -25,6 +25,7 @@ struct VariableInfo {
 
   at::Type* type;
   at::Device device = at::kCPU;
+  at::ScalarType scalar_type = at::kFloat;
   std::vector<int64_t> size;
   bool requires_grad;
 };
@@ -46,7 +46,8 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject
   if (!data || data == Py_None) {
     // For legacy serialization code, create an empty tensor. This is also used
     // by nn.Parameter() with no arguments.
-    auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options());
+    auto scalar_type = torch::tensors::get_default_scalar_type();
+    auto var = at::empty({0}, torch::tensors::get_default_tensor_type().options(scalar_type));
     tensor = static_cast<Variable&>(var).data();
   } else if (THPVariable_Check(data)) {
     tensor = ((THPVariable*)data)->cdata.data();
@@ -110,15 +110,15 @@ static Variable sequenceToVariable(const at::Type& type, PyObject* seq) {
   return torch::utils::indexing_tensor_from_data(idx_type, kLong, c10::nullopt, seq);
 }
 
-static Variable valueToTensor(const at::Type & type, PyObject* value) {
+static Variable valueToTensor(const at::Type & type, const ScalarType scalar_type, PyObject* value) {
   if (THPVariable_Check(value)) {
     return reinterpret_cast<THPVariable*>(value)->cdata;
   }
   if (THPUtils_checkLong(value) || PyBool_Check(value)) {
-    return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options());
+    return at::scalar_tensor(Scalar(THPUtils_unpackLong(value)), type.options(scalar_type));
   }
   if (PyFloat_Check(value)) {
-    return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options());
+    return at::scalar_tensor(Scalar(THPUtils_unpackDouble(value)), type.options(scalar_type));
   }
   throw TypeError("can't assign a %s to a %s", Py_TYPE(value)->tp_name, type.toString());
 }
@@ -334,7 +334,7 @@ int THPVariable_setitem(PyObject* self, PyObject* index, PyObject* py_value) {
 
   auto& self_ = reinterpret_cast<THPVariable*>(self)->cdata;
   OptionalDeviceGuard device_guard(device_of(self_));
-  auto value = valueToTensor(self_.dispatch_type(), py_value);
+  auto value = valueToTensor(self_.dispatch_type(), self_.scalar_type(), py_value);
 
   // handle simple types: integers, slices, ellipsis, bool
   if (index == Py_False) { // NOLINT(cppcoreguidelines-pro-type-cstyle-cast)
@@ -214,7 +214,7 @@ const std::shared_ptr<Function>& Variable::grad_fn() const {
     fn->storage_offset = data().storage_offset();
     fn->set_next_edges(collect_next_edges(diff_view_meta->base_));
     fn->add_input_metadata(
-      diff_view_meta->base_.dispatch_type()
+      diff_view_meta->base_.type()
     , sizes() // Note: sizes(), not base_.sizes(), is intentional
     , diff_view_meta->base_.device());
     diff_view_meta->grad_fn_ = std::move(fn);
@@ -59,7 +59,7 @@ std::vector<Tensor> broadcast(const Tensor& tensor, IntArrayRef devices) {
     tensors.push_back(tensor);
     for (auto device : devices.slice(1)) {
       _device_guard.set_index(device);
-      tensors.push_back(at::empty(tensor.sizes(), type.options()));
+      tensors.push_back(at::empty(tensor.sizes(), type.options(tensor.scalar_type())));
     }
     nccl::broadcast(tensors);
   } else {
@@ -157,10 +157,9 @@ class ShapePropagator {
         return *iv;
       }
       if (CompleteTensorTypePtr type = type_->cast<CompleteTensorType>()) {
-        auto backend =
-            type->device().is_cpu() ? at::Backend::CPU : at::Backend::CUDA;
+        auto attype = type->device().is_cpu() ?
+            at::CPU(type->scalarType()) : at::CUDA(type->scalarType());
         at::DeviceGuard device_guard(type->device());
-        auto& attype = at::getNonVariableType(backend, type->scalarType());
         auto t =
             at::empty_strided(type->sizes(), type->strides(), attype.options())
                 .zero_();
@@ -53,32 +53,32 @@ void maybe_initialize_cuda(const Device device) {
   }
 }
 
-Tensor dispatch_zeros(const Type& type, optional<Device> device, IntArrayRef sizes) {
+Tensor dispatch_zeros(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::zeros(sizes, type.options(std::move(device)));
+  return torch::zeros(sizes, type.options(scalar_type, std::move(device)));
 }
 
-Tensor dispatch_ones(const Type& type, optional<Device> device, IntArrayRef sizes) {
+Tensor dispatch_ones(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::ones(sizes, type.options(std::move(device)));
+  return torch::ones(sizes, type.options(scalar_type, std::move(device)));
 }
 
-Tensor dispatch_full(const Type& type, Scalar fill_value, optional<Device> device, IntArrayRef sizes) {
+Tensor dispatch_full(const Type& type, const ScalarType scalar_type, Scalar fill_value, optional<Device> device, IntArrayRef sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::full(sizes, fill_value, type.options(std::move(device)));
+  return torch::full(sizes, fill_value, type.options(scalar_type, std::move(device)));
 }
 
-Tensor new_with_sizes(const Type& type, optional<Device> device, IntArrayRef sizes) {
+Tensor new_with_sizes(const Type& type, const ScalarType scalar_type, optional<Device> device, IntArrayRef sizes) {
   maybe_initialize_cuda(type);
   AutoNoGIL no_gil;
-  return torch::empty(sizes, type.options(std::move(device)));
+  return torch::empty(sizes, type.options(scalar_type, std::move(device)));
 }
 
-Tensor new_with_storage(const Type& type, Storage storage) {
-  auto tensor = at::empty({}, type.options());
+Tensor new_with_storage(const Type& type, const ScalarType scalar_type, Storage storage) {
+  auto tensor = at::empty({}, type.options(scalar_type));
   tensor.set_(std::move(storage));
   return tensor;
 }
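Each dispatch helper above now takes the ScalarType explicitly, and the constructor paths later in this file forward the scalar type parsed from the Python arguments. A hedged sketch of the new call shape, assuming the same file-local types and usings as the helpers above (the wrapper name is illustrative, not new API):

// Sketch only: forwarding the parsed scalar type alongside the Type.
Tensor make_new_zeros(const Type& type, ScalarType scalar_type,
                      optional<Device> device, IntArrayRef sizes) {
  return dispatch_zeros(type, scalar_type, device, sizes);
}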
@@ -281,7 +281,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, ScalarType scalar_type, PyObj
   if (r.idx == 0) {
     auto deviceOptional = r.deviceOptional(0);
     check_legacy_ctor_device(type, deviceOptional);
-    return at::empty({0}, type.options(r.deviceOptional(0)));
+    return at::empty({0}, type.options(scalar_type, r.deviceOptional(0)));
   } else if (r.idx == 1) {
     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
     return type.unsafeTensorFromTH(cdata, true);
@@ -304,7 +304,7 @@ Tensor legacy_sparse_tensor_ctor(const Type& type, ScalarType scalar_type, PyObj
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+    return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
   }
   throw std::runtime_error("new(): invalid arguments");
 }
@@ -323,7 +323,7 @@ Tensor legacy_sparse_tensor_new(const Type& type, ScalarType scalar_type, PyObje
     auto deviceOptional = r.deviceOptional(0);
     check_legacy_ctor_device(type, deviceOptional);
     at::OptionalDeviceGuard device_guard(deviceOptional);
-    return at::empty({0}, type.options());
+    return at::empty({0}, type.options(scalar_type));
   } else if (r.idx == 1) {
     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
     return type.unsafeTensorFromTH(cdata, true);
@@ -350,7 +350,7 @@ Tensor legacy_sparse_tensor_new(const Type& type, ScalarType scalar_type, PyObje
       // unless the sequences is a torch.Size
      return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+    return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
   }
   throw std::runtime_error("new(): invalid arguments");
 }
@@ -384,9 +384,9 @@ Tensor legacy_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* ar
     auto deviceOptional = r.deviceOptional(0);
     check_legacy_ctor_device(type, deviceOptional);
     at::OptionalDeviceGuard device_guard(deviceOptional);
-    return at::empty({0}, type.options());
+    return at::empty({0}, type.options(scalar_type));
   } else if (r.idx == 1) {
-    return new_with_storage(type, r.storage(0));
+    return new_with_storage(type, scalar_type, r.storage(0));
   } else if (r.idx == 2) {
     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
     return type.unsafeTensorFromTH(cdata, true);
@@ -401,7 +401,7 @@ Tensor legacy_tensor_ctor(const Type& type, ScalarType scalar_type, PyObject* ar
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+    return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
   } else if (r.idx == 5) {
     auto deviceOptional = r.deviceOptional(1);
     check_legacy_ctor_device(type, deviceOptional);
@@ -430,9 +430,9 @@ Tensor legacy_tensor_new(const Type& type, ScalarType scalar_type, PyObject* arg
     auto deviceOptional = r.deviceOptional(0);
     check_legacy_ctor_device(type, deviceOptional);
     at::OptionalDeviceGuard device_guard(deviceOptional);
-    return at::empty({0}, type.options());
+    return at::empty({0}, type.options(scalar_type));
   } else if (r.idx == 1) {
-    return new_with_storage(type, r.storage(0));
+    return new_with_storage(type, scalar_type, r.storage(0));
   } else if (r.idx == 2) {
     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
     return type.unsafeTensorFromTH(cdata, true);
@@ -447,7 +447,7 @@ Tensor legacy_tensor_new(const Type& type, ScalarType scalar_type, PyObject* arg
       // unless the sequences is a torch.Size
       return legacy_new_from_sequence(type, scalar_type, deviceOptional, r.pyobject(0));
     }
-    return new_with_sizes(type, r.deviceOptional(1), r.intlist(0));
+    return new_with_sizes(type, scalar_type, r.deviceOptional(1), r.intlist(0));
   } else if (r.idx == 5) {
     auto deviceOptional = r.deviceOptional(1);
     check_legacy_ctor_device(type, deviceOptional);
@@ -504,8 +504,9 @@ Tensor sparse_coo_tensor_ctor(const Type& default_type, ScalarType scalar_type,
     return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5));
   } else if (r.idx == 2) {
     const auto& type = typeWithDefault(r, 1, 2, default_type, scalar_type);
+    const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
     at::OptionalDeviceGuard device_guard(r.deviceOptional(2));
-    return at::sparse_coo_tensor(r.intlist(0), type.options().layout(at::kSparse)).set_requires_grad(r.toBool(3));
+    return at::sparse_coo_tensor(r.intlist(0), type.options(actual_scalar_type).layout(at::kSparse)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
 }
@@ -603,7 +604,8 @@ Tensor new_empty(const Type& type, ScalarType scalar_type, PyObject* args, PyObj
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
-    return new_with_sizes(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4));
+    const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+    return new_with_sizes(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(4));
   }
   throw std::runtime_error("new_empty(): invalid arguments");
 }
@@ -617,7 +619,8 @@ Tensor new_full(const Type& type, ScalarType scalar_type, PyObject* args, PyObje
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 2, 3, type, scalar_type);
-    return dispatch_full(actual_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
+    const auto actual_scalar_type = r.scalartypeWithDefault(2, scalar_type);
+    return dispatch_full(actual_type, actual_scalar_type, r.scalar(1), r.deviceOptional(3), r.intlist(0)).set_requires_grad(r.toBool(4));
   }
   throw std::runtime_error("new_full(): invalid arguments");
 }
@@ -631,7 +634,8 @@ Tensor new_ones(const Type& type, ScalarType scalar_type, PyObject* args, PyObje
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
-    return dispatch_ones(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
+    const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+    return dispatch_ones(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
   }
   throw std::runtime_error("new_ones(): invalid arguments");
 }
@@ -645,7 +649,8 @@ Tensor new_zeros(const Type& type, ScalarType scalar_type, PyObject* args, PyObj
   auto r = parser.parse(args, kwargs, parsed_args);
   if (r.idx == 0) {
     const auto& actual_type = typeWithDefault(r, 1, 2, type, scalar_type);
-    return dispatch_zeros(actual_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
+    const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type);
+    return dispatch_zeros(actual_type, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3));
  }
   throw std::runtime_error("new_zeros(): invalid arguments");
 }