* Created TensorOptions
Storing the type in TensorOptions to solve the Variable problem
Created convenience creation functions for TensorOptions and added tests
Converted zeros to TensorOptions
Converted rand to TensorOptions
Fix codegen for TensorOptions and multiple arguments
Put TensorOptions convenience functions into torch namespace too
All factory functions except *_like support TensorOptions
Integrated with recent JIT changes
Support *_like functions
Fix in-place modification
Some cleanups and fixes
Support sparse_coo_tensor
Fix bug in Type.cpp
Fix .empty calls in C++ API
Fix bug in Type.cpp
Trying to fix device placement
Make AutoGPU CPU compatible
Remove some auto_gpu.h uses
Fixing some headers
Fix some remaining CUDA/AutoGPU issues
Fix some AutoGPU uses
Fixes to dispatch_tensor_conversion
Reset version of new variables to zero
Implemented parsing device strings
Random fixes to tests
Self review cleanups
flake8
Undo changes to variable.{h,cpp} because they fail on gcc7.2
Add [cuda] tag to tensor_options_cuda.cpp
Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks
Fix linker error in AutoGPU.cpp
Fix bad merge conflict in native_functions.yaml
Fixed caffe2/contrib/aten
Fix new window functions added to TensorFactories.cpp
* Removed torch::TensorOptions
Added code to generate wrapper functions for factory methods
Add implicit constructor from Backend to TensorOptions
Remove Var() from C++ API and use torch:: functions
Use torch:: functions more subtly in C++ API
Make AutoGPU::set_device more exception safe
Check status directly in DynamicCUDAHooksInterface
Rename AutoGPU to DeviceGuard
Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad
remove python_default_init: self.type()
Add back original factory functions, but with deprecation warnings
Disable DeviceGuard for a couple functions in ATen
Remove print statement
Fix DeviceGuard construction from undefined tensor
Fixing CUDA device compiler issues
Moved as many methods as possible into header files
Don't generate Python functions for deprecated factories
Remove merge conflict artefact
Fix tensor_options_cuda.cpp
Fix set_requires_grad not being checked
Fix tensor_new.h
TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac
Fix bug in DeviceGuard.h
Missing includes
TEMPORARILY moving a few more methods into .cpp to see if it fixes windows
Fixing linker errors
* Fix up SummaryOps to use new factories
Undo device agnostic behavior of DeviceGuard
Use -1 instead of optional for default device index
Also move DeviceGuard methods into header
Fixes around device index after optional -> int32_t switch
Fix use of DeviceGuard in new_with_tensor_copy
Fix tensor_options.cpp
* Fix Type::copy(
* Remove test_non_float_params from ONNX tests
* Set requires_grad=False in ONNX tests that use ints
* Put layout/dtype/device on Tensor
* Post merge fixes
* Change behavior of DeviceGuard to match AutoGPU
* Fix C++ API integration tests
* Fix flip functions
#pragma once

// Parse arguments to Python functions implemented in C++
// This is similar to PyArg_ParseTupleAndKeywords(), but specifically handles
// the types relevant to PyTorch and distinguishes between overloaded function
// signatures.
//
// Example:
//
//   static PythonArgParser parser({
//     "norm(Scalar p, int64_t dim, bool keepdim=False)",
//     "norm(Scalar p=2)",
//   });
//   ParsedArgs<3> parsed_args;
//   auto r = parser.parse(args, kwargs, parsed_args);
//   if (r.idx == 0) {
//     norm(r.scalar(0), r.toInt64(1), r.toBool(2));
//   } else {
//     norm(r.scalar(0));
//   }
//
// We auto-generate most uses of PythonArgParser; the generated files
// are torch/csrc/autograd/generated/python_*.cpp
//
// Some gotchas that you should watch out for:
//
// - Note [Order of overloads matters]
//   Order of overloads matters. A set of input arguments may
//   bind to multiple argument specs; we will always pick the
//   first one in PythonArgParser. However, when you are writing
//   overloads in, e.g., native_functions.yaml, you don't have to
//   worry about what order you write them, because the code
//   generation logic always gives the overloads a canonical
//   order, where Tensor overloads come first, before Scalar overloads.
//   This logic is in sort_declarations in
//   tools/autograd/gen_python_functions.py
//
// - Zero-dim tensors (e.g., torch.tensor(2)) bind to both
//   Scalar and Tensor, UNLESS they require grad (in which case
//   they only bind to Tensor).
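//
// As an illustration of both gotchas (hypothetical specs, not actual
// overloads): given
//
//   "mul(Tensor other)"
//   "mul(Scalar other)"
//
// a zero-dim tensor argument that does not require grad matches both
// specs, and the parser picks "mul(Tensor other)" because the canonical
// order lists Tensor overloads first.
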
#include "torch/csrc/python_headers.h"
|
|
|
|
#include "torch/csrc/Device.h"
|
|
#include "torch/csrc/Dtype.h"
|
|
#include "torch/csrc/DynamicTypes.h"
|
|
#include "torch/csrc/Exceptions.h"
|
|
#include "torch/csrc/Generator.h"
|
|
#include "torch/csrc/autograd/generated/VariableType.h"
|
|
#include "torch/csrc/autograd/python_variable.h"
|
|
#include "torch/csrc/jit/tracer.h"
|
|
#include "torch/csrc/tensor/python_tensor.h"
|
|
#include "torch/csrc/utils/numpy_stub.h"
|
|
#include "torch/csrc/utils/object_ptr.h"
|
|
#include "torch/csrc/utils/python_numbers.h"
|
|
#include "torch/csrc/utils/python_strings.h"
|
|
|
|
#include <ATen/ATen.h>
|
|
|
|
#include <array>
|
|
#include <cstddef>
|
|
#include <memory>
|
|
#include <sstream>
|
|
#include <string>
|
|
#include <vector>
|
|
|
|
namespace torch {

enum class ParameterType {
  TENSOR, SCALAR, INT64, DOUBLE, TENSOR_LIST, INT_LIST, GENERATOR,
  BOOL, STORAGE, PYOBJECT, SCALARTYPE, LAYOUT, DEVICE, STRING
};

struct FunctionParameter;
struct FunctionSignature;
struct PythonArgs;

// Contains bound Python arguments in declaration order
template<int N>
struct ParsedArgs {
  PyObject* args[N];
};

struct PythonArgParser {
  explicit PythonArgParser(std::vector<std::string> fmts, bool traceable=false);

  template<int N>
  inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs<N>& dst);

private:
  [[noreturn]]
  void print_error(PyObject* args, PyObject* kwargs, PyObject* dst[]);
  PythonArgs raw_parse(PyObject* args, PyObject* kwargs, PyObject* dst[]);

  std::vector<FunctionSignature> signatures_;
  std::string function_name;
  ssize_t max_args;
  bool traceable;
};

struct PythonArgs {
  PythonArgs(int idx, bool traceable, const FunctionSignature& signature, PyObject** args)
    : idx(idx)
    , traceable(traceable)
    , signature(signature)
    , args(args) {}

  int idx;
  bool traceable;
  const FunctionSignature& signature;
  PyObject** args;

  inline at::Tensor tensor(int i);
  inline at::Scalar scalar(int i);
  inline at::Scalar scalarWithDefault(int i, at::Scalar default_scalar);
  inline std::vector<at::Tensor> tensorlist(int i);
  template<int N>
  inline std::array<at::Tensor, N> tensorlist_n(int i);
  inline std::vector<int64_t> intlist(int i);
  inline std::vector<int64_t> intlistWithDefault(int i, std::vector<int64_t> default_intlist);
  inline at::Generator* generator(int i);
  inline std::unique_ptr<at::Storage> storage(int i);
  inline at::ScalarType scalartype(int i);
  inline at::ScalarType scalartypeWithDefault(int i, at::ScalarType default_scalartype);
  inline at::optional<at::ScalarType> scalartypeOptional(int i);
  inline const THPLayout& layout(int i);
  inline const THPLayout& layoutWithDefault(int i, const THPLayout& default_layout);
  inline at::Device device(int i);
  inline at::Device deviceWithDefault(int i, const at::Device& default_device);
  inline at::optional<at::Device> deviceOptional(int i);
  inline std::string string(int i);
  inline PyObject* pyobject(int i);
  inline int64_t toInt64(int i);
  inline int64_t toInt64WithDefault(int i, int64_t default_int);
  inline double toDouble(int i);
  inline double toDoubleWithDefault(int i, double default_double);
  inline bool toBool(int i);
  inline bool toBoolWithDefault(int i, bool default_bool);
  inline bool isNone(int i);
};

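// The index passed to each accessor above is the parameter's position in
// the signature string, in declaration order. For example (illustrative),
// with "norm(Scalar p, int64_t dim, bool keepdim=False)" bound at r.idx == 0,
// the arguments are read as r.scalar(0), r.toInt64(1) and r.toBool(2).
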
struct FunctionSignature {
  explicit FunctionSignature(const std::string& fmt);

  bool parse(PyObject* args, PyObject* kwargs, PyObject* dst[], bool raise_exception);
  std::string toString() const;

  std::string name;
  std::vector<FunctionParameter> params;
  ssize_t min_args;
  ssize_t max_args;
  ssize_t max_pos_args;
  bool hidden;
  bool deprecated;
};

struct FunctionParameter {
  FunctionParameter(const std::string& fmt, bool keyword_only);

  bool check(PyObject* obj);
  void set_default_str(const std::string& str);
  std::string type_name() const;

  ParameterType type_;
  bool optional;
  bool allow_none;
  bool keyword_only;
  int size;
  std::string name;
  // having this as a raw PyObject* will presumably leak it, but these are
  // only held by static objects anyway, and Py_Finalize can already be
  // called when this is destructed.
  PyObject* python_name;
  at::Scalar default_scalar;
  std::vector<int64_t> default_intlist;
  union {
    bool default_bool;
    int64_t default_int;
    double default_double;
    at::ScalarType default_scalartype;
    THPLayout* default_layout;
  };
};

template<int N>
inline PythonArgs PythonArgParser::parse(PyObject* args, PyObject* kwargs, ParsedArgs<N>& dst) {
  if (N < max_args) {
    throw ValueError("PythonArgParser: dst ParsedArgs buffer does not have enough capacity, expected %d (got %d)",
        (int)max_args, N);
  }
  return raw_parse(args, kwargs, dst.args);
}

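// N is checked against max_args at runtime, so an undersized ParsedArgs<N>
// surfaces as a ValueError on the first call rather than as a compile error.
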
inline at::Tensor PythonArgs::tensor(int i) {
  if (!args[i]) return at::Tensor();
  if (!THPVariable_Check(args[i])) {
    // NB: Are you here because you passed None to a Variable method,
    // and you expected an undefined tensor to be returned? Don't add
    // a test for Py_None here; instead, you need to mark the argument
    // as *allowing none*; you can do this by writing 'Tensor?' instead
    // of 'Tensor' in the ATen metadata.
    throw TypeError("expected Tensor as argument %d, but got %s", i,
        Py_TYPE(args[i])->tp_name);
  }
  return reinterpret_cast<THPVariable*>(args[i])->cdata;
}

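// For instance (hypothetical signature, for illustration only), a declaration
// such as "where(Tensor condition, Tensor? self, Tensor? other)" lets callers
// pass None for the optional tensors without tripping the TypeError above;
// tensor(i) then returns an undefined at::Tensor for those arguments.
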
inline at::Scalar PythonArgs::scalar(int i) {
  return scalarWithDefault(i, signature.params[i].default_scalar);
}

inline at::Scalar PythonArgs::scalarWithDefault(int i, at::Scalar default_scalar) {
  if (!args[i]) return default_scalar;
  // Zero-dim tensors are converted to Scalars as-is. Note this doesn't
  // currently handle most NumPy scalar types except np.float64.
  if (THPVariable_Check(args[i])) {
    return at::Scalar(((THPVariable*)args[i])->cdata);
  }
  if (THPUtils_checkLong(args[i])) {
    return at::Scalar(static_cast<int64_t>(THPUtils_unpackLong(args[i])));
  }
  return at::Scalar(THPUtils_unpackDouble(args[i]));
}

inline std::vector<at::Tensor> PythonArgs::tensorlist(int i) {
  if (!args[i]) return std::vector<at::Tensor>();
  PyObject* arg = args[i];
  auto tuple = PyTuple_Check(arg);
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<at::Tensor> res(size);
  for (int idx = 0; idx < size; idx++) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    if (!THPVariable_Check(obj)) {
      // Report the offending element's type, not the type of the list itself.
      throw TypeError("expected Tensor as element %d in argument %d, but got %s",
          idx, i, Py_TYPE(obj)->tp_name);
    }
    res[idx] = reinterpret_cast<THPVariable*>(obj)->cdata;
  }
  return res;
}

template<int N>
inline std::array<at::Tensor, N> PythonArgs::tensorlist_n(int i) {
  auto res = std::array<at::Tensor, N>();
  PyObject* arg = args[i];
  if (!arg) return res;
  auto tuple = PyTuple_Check(arg);
  auto size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  if (size != N) {
    throw TypeError("expected tuple of %d elements but got %d", N, (int)size);
  }
  for (int idx = 0; idx < size; idx++) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    if (!THPVariable_Check(obj)) {
      // As in tensorlist(), report the element's type rather than the list's.
      throw TypeError("expected Tensor as element %d in argument %d, but got %s",
          idx, i, Py_TYPE(obj)->tp_name);
    }
    res[idx] = reinterpret_cast<THPVariable*>(obj)->cdata;
  }
  return res;
}

inline std::vector<int64_t> PythonArgs::intlist(int i) {
  return intlistWithDefault(i, signature.params[i].default_intlist);
}

inline std::vector<int64_t> PythonArgs::intlistWithDefault(int i, std::vector<int64_t> default_intlist) {
  if (!args[i]) return default_intlist;
  PyObject* arg = args[i];
  auto size = signature.params[i].size;
  if (size > 0 && THPUtils_checkLong(arg)) {
    return std::vector<int64_t>(size, THPUtils_unpackIndex(arg));
  }
  auto tuple = PyTuple_Check(arg);
  size = tuple ? PyTuple_GET_SIZE(arg) : PyList_GET_SIZE(arg);
  std::vector<int64_t> res(size);
  for (int idx = 0; idx < size; idx++) {
    PyObject* obj = tuple ? PyTuple_GET_ITEM(arg, idx) : PyList_GET_ITEM(arg, idx);
    try {
      // Elements of torch.Size are tensors during tracing, and we need to
      // record extra information before they are turned into an IntList
      if (traceable && THPVariable_Check(obj)) {
        auto& var = THPVariable_Unpack(obj);
        jit::tracer::ArgumentStash::stashIntListElem(
            signature.params[i].name, size, idx, var);
        res[idx] = var.toCLong();
        continue;
      } else {
        res[idx] = THPUtils_unpackIndex(obj);
      }
    } catch (std::runtime_error& e) {
      throw TypeError("%s(): argument '%s' must be %s, but found element of type %s at pos %d",
          signature.name.c_str(), signature.params[i].name.c_str(),
          signature.params[i].type_name().c_str(), Py_TYPE(obj)->tp_name, idx + 1);
    }
  }
  return res;
}

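// Illustrative behavior of the single-int shortcut above: for a parameter
// declared with a fixed size, e.g. (hypothetically) "IntList[2] kernel_size",
// the Python int 3 expands to the vector {3, 3}.
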
inline at::ScalarType PythonArgs::scalartypeWithDefault(int i, at::ScalarType default_scalartype) {
  if (!args[i]) return default_scalartype;
  return scalartype(i);
}

inline at::ScalarType PythonArgs::scalartype(int i) {
  if (!args[i]) {
    auto scalartype = signature.params[i].default_scalartype;
    return (scalartype == at::ScalarType::Undefined) ?
        torch::tensor::get_default_tensor_type().scalarType() : scalartype;
  }
  return reinterpret_cast<THPDtype*>(args[i])->scalar_type;
}

inline at::optional<at::ScalarType> PythonArgs::scalartypeOptional(int i) {
  if (!args[i]) return at::nullopt;
  return scalartype(i);
}

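// When no dtype argument is supplied and the signature declares no default,
// scalartype() falls back to the scalar type of the current default tensor
// type (e.g. float32 when torch.FloatTensor is the default).
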
inline const THPLayout& PythonArgs::layout(int i) {
  if (!args[i]) return *signature.params[i].default_layout;
  return *reinterpret_cast<THPLayout*>(args[i]);
}

inline const THPLayout& PythonArgs::layoutWithDefault(int i, const THPLayout& default_layout) {
  if (!args[i]) return default_layout;
  return layout(i);
}

static std::string cuda_str = "cuda";
static std::string cpu_str = "cpu";
static std::string cuda_prefix = "cuda:";
static std::string cpu_prefix = "cpu:";

inline at::Device PythonArgs::device(int i) {
  if (!args[i]) {
    const auto& default_tensor_type = torch::tensor::get_default_tensor_type();
    return at::Device(default_tensor_type.backend());
  }
  if (THPDevice_Check(args[i])) {
    const auto device = reinterpret_cast<THPDevice*>(args[i]);
    return device->device;
  }
  if (THPUtils_checkLong(args[i])) {
    const auto device_index = THPUtils_unpackLong(args[i]);
    AT_CHECK(device_index >= 0, "Device index must not be negative");
    return at::Device(at::kCUDA, device_index);
  }
  const std::string device_str = THPUtils_unpackString(args[i]);
  if (device_str == cpu_str) {
    return at::Device(at::kCPU);
  } else if (device_str == cuda_str) {
    return at::Device(at::kCUDA);
  } else if (device_str.compare(0, cpu_prefix.length(), cpu_prefix) == 0) {
    const auto device_index = std::stoi(device_str.substr(cpu_prefix.length()));
    AT_CHECK(device_index >= 0, "Device index must not be negative");
    return at::Device(at::kCPU, device_index);
  } else if (device_str.compare(0, cuda_prefix.length(), cuda_prefix) == 0) {
    const auto device_index = std::stoi(device_str.substr(cuda_prefix.length()));
    AT_CHECK(device_index >= 0, "Device index must not be negative");
    return at::Device(at::kCUDA, device_index);
  }
  throw torch::TypeError("only \"cuda\" and \"cpu\" are valid device types, got %s", device_str.c_str());
}

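// In summary, device() accepts a torch.device object, a non-negative Python
// int (interpreted as a CUDA device index, e.g. 1 -> "cuda:1"), or one of
// the strings "cpu", "cuda", "cpu:<index>" and "cuda:<index>"; anything else
// raises the TypeError above.
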
inline at::Device PythonArgs::deviceWithDefault(int i, const at::Device& default_device) {
  if (!args[i]) return default_device;
  return device(i);
}

inline at::optional<at::Device> PythonArgs::deviceOptional(int i) {
  if (!args[i]) return at::nullopt;
  return device(i);
}

inline std::string PythonArgs::string(int i) {
  if (!args[i]) return "";
  return THPUtils_unpackString(args[i]);
}

inline int64_t PythonArgs::toInt64(int i) {
  if (!args[i]) return signature.params[i].default_int;
  return THPUtils_unpackLong(args[i]);
}

inline int64_t PythonArgs::toInt64WithDefault(int i, int64_t default_int) {
  if (!args[i]) return default_int;
  return toInt64(i);
}

inline double PythonArgs::toDouble(int i) {
  if (!args[i]) return signature.params[i].default_double;
  return THPUtils_unpackDouble(args[i]);
}

inline double PythonArgs::toDoubleWithDefault(int i, double default_double) {
  if (!args[i]) return default_double;
  return toDouble(i);
}

inline bool PythonArgs::toBool(int i) {
  if (!args[i]) return signature.params[i].default_bool;
  return args[i] == Py_True;
}

inline bool PythonArgs::toBoolWithDefault(int i, bool default_bool) {
  if (!args[i]) return default_bool;
  return toBool(i);
}

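// Note that toBool() compares against Py_True by identity, so only the
// Python singleton True counts as true; other truthy values do not.
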
inline bool PythonArgs::isNone(int i) {
  return args[i] == nullptr;
}

inline at::Generator* PythonArgs::generator(int i) {
  if (!args[i]) return nullptr;
  return reinterpret_cast<THPGenerator*>(args[i])->cdata;
}

inline std::unique_ptr<at::Storage> PythonArgs::storage(int i) {
  if (!args[i]) return nullptr;
  return createStorage(args[i]);
}

inline PyObject* PythonArgs::pyobject(int i) {
  if (!args[i]) return Py_None;
  return args[i];
}

} // namespace torch