pytorch/torch/csrc/autograd/python_function.h
Peter Goldsborough 372d1d6735
Create ATen tensors via TensorOptions (#7869)
* Created TensorOptions

Storing the type in TensorOptions to solve the Variable problem

Created convenience creation functions for TensorOptions and added tests

Converted zeros to TensorOptions

Converted rand to TensorOptions

Fix codegen for TensorOptions and multiple arguments

Put TensorOptions convenience functions into torch namespace too

All factory functions except *_like support TensorOptions

Integrated with recent JIT changes

Support *_like functions

Fix in place modification

Some cleanups and fixes

Support sparse_coo_tensor

Fix bug in Type.cpp

Fix .empty calls in C++ API

Fix bug in Type.cpp

Trying to fix device placement

Make AutoGPU CPU compatible

Remove some auto_gpu.h uses

Fixing some headers

Fix some remaining CUDA/AutoGPU issues

Fix some AutoGPU uses

Fixes to dispatch_tensor_conversion

Reset version of new variables to zero

Implemented parsing device strings

Random fixes to tests

Self review cleanups

flake8

Undo changes to variable.{h,cpp} because they fail on gcc7.2

Add [cuda] tag to tensor_options_cuda.cpp

Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks

Fix linker error in AutoGPU.cpp

Fix bad merge conflict in native_functions.yaml

Fixed caffe2/contrib/aten

Fix new window functions added to TensorFactories.cpp

* Removed torch::TensorOptions

Added code to generate wrapper functions for factory methods

Add implicit constructor from Backend to TensorOptions

Remove Var() from C++ API and use torch:: functions

Use torch:: functions more subtly in C++ API

Make AutoGPU::set_device more exception safe

Check status directly in DynamicCUDAHooksInterface

Rename AutoGPU to DeviceGuard

Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad

remove python_default_init: self.type()

Add back original factory functions, but with deprecation warnings

Disable DeviceGuard for a couple functions in ATen

Remove print statement

Fix DeviceGuard construction from undefined tensor

Fixing CUDA device compiler issues

Moved as many methods as possible into header files

Don't generate python functions for deprecated factories

Remove merge conflict artefact

Fix tensor_options_cuda.cpp

Fix set_requires_grad not being checked

Fix tensor_new.h

TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac

Fix bug in DeviceGuard.h

Missing includes

TEMPORARILY moving a few more methods into .cpp to see if it fixes windows

Fixing linker errors

* Fix up SummaryOps to use new factories

Undo device agnostic behavior of DeviceGuard

Use -1 instead of optional for default device index

Also move DeviceGuard methods into header

Fixes around device index after optional -> int32_t switch

Fix use of DeviceGuard in new_with_tensor_copy

Fix tensor_options.cpp

* Fix Type::copy(

* Remove test_non_float_params from ONNX tests

* Set requires_grad=False in ONNX tests that use ints

* Put layout/dtype/device on Tensor

* Post merge fixes

* Change behavior of DeviceGuard to match AutoGPU

* Fix C++ API integration tests

* Fix flip functions
2018-06-16 00:40:35 -07:00


#pragma once
#include "torch/csrc/python_headers.h"
#include "torch/csrc/Exceptions.h"
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
#include "torch/csrc/autograd/saved_variable.h"
#include "torch/csrc/utils/object_ptr.h"
#include <ATen/optional.h>
#include <vector>
#include <utility>
#include <memory>
namespace at {
struct DeviceGuard;
} // namespace at
namespace torch { namespace jit { struct Graph; }}
namespace torch { namespace autograd {
struct VariableInfo {
  explicit VariableInfo(const Variable& var);

  Variable zeros(at::DeviceGuard& device_guard) const;

  at::Type* type;
  int32_t device = -1;
  std::vector<int64_t> size;
  bool requires_grad;
};
// A Function which is implemented by a Python object (i.e., a THPFunction).
// Calls to 'apply' are forwarded to the Python method implementation.
struct PyFunction : public Function {
  PyFunction(PyObject* obj) : obj(obj) {}

  virtual variable_list apply(const variable_list& inputs) override;
  variable_list legacy_apply(const variable_list& inputs);

  virtual void release_variables() override;
  virtual std::string name() const override;
  virtual std::shared_ptr<Function> get_shared_ptr() override;
  virtual bool is_traceable() override;

  // THPFunction this Function is wrapping.
  PyObject* obj;
};
/**
 * Cast an object into a tuple, if it is not a tuple already. Returns true
 * if the original object was not a tuple (i.e., it was wrapped in a new
 * 1-tuple), and false if it was left untouched.
 */
inline bool ensure_tuple(THPObjectPtr& obj) {
  if (PyTuple_Check(obj.get()))
    return false;

  PyObject *tuple = PyTuple_New(1);
  if (!tuple) throw python_error();
  // PyTuple_SET_ITEM steals the reference, so ownership moves into the tuple.
  PyTuple_SET_ITEM(tuple, 0, obj.release());
  obj = tuple;
  return true;
}
}} // namespace torch::autograd
struct THPFunction {
  PyObject_HEAD

  PyObject *needs_input_grad;

  // Python tuple of tensors whose variables we should save. Set
  // by Python with 'save_for_backward'. If nullptr, no tensors were
  // saved.
  PyObject *to_save;
  // Python tuple of tensors which are not differentiable. Set by
  // Python with 'mark_non_differentiable'. If nullptr, no tensors were
  // non-differentiable.
  PyObject *non_differentiable;
  // Python tuple of tensors which had inplace updates in the forward()
  // pass. Set by Python with 'mark_dirty'. If nullptr, no tensors were
  // modified inplace.
  PyObject *dirty_tensors;

  std::vector<torch::autograd::VariableInfo> output_info;
  std::vector<torch::autograd::VariableInfo> input_info;
  std::vector<torch::autograd::SavedVariable> saved_variables;
  // For each input, true if the input is a THPVariable
  std::vector<bool> is_variable_input;
  char has_freed_buffers;
  char is_traced;

  // The C++ wrapper for this Python function.
  // See a comment in THPFunction_asFunction for details about this field.
  torch::autograd::PyFunction cdata;
};
bool THPFunction_initModule(PyObject *module);
extern PyTypeObject THPFunctionType;
extern PyObject *THPFunctionClass;
// XXX: this function requires the GIL (it can have side effects).
std::shared_ptr<torch::autograd::PyFunction> THPFunction_asFunction(THPFunction* self);
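// NB: PyObject_IsInstance returns -1 on error (with a Python exception set),
// which converts to true here; callers should not rely on this check alone
// when an error may be pending.
inline bool THPFunction_Check(PyObject* obj) {
  return PyObject_IsInstance(obj, (PyObject*)&THPFunctionType);
}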
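
The `VariableInfo` struct above stores only metadata (type, device index, sizes, `requires_grad`) and exposes a `zeros` method. That suggests a pattern of remembering a variable's shape and placement so a zero-filled replacement can be built later without keeping the original alive. Below is a minimal, self-contained sketch of that pattern. `ToyTensor` and `ToyVariableInfo` are hypothetical stand-ins invented for illustration. They are not part of PyTorch and leave out `at::Type`, `at::DeviceGuard`, and `requires_grad` handling entirely.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a tensor: a flat buffer plus shape and device.
struct ToyTensor {
  std::vector<float> data;
  std::vector<int64_t> size;
  int32_t device = -1;  // -1 means "no specific device", mirroring the header
};

// Hypothetical analogue of VariableInfo: copies only the metadata of a
// tensor, never its data, so the original buffer can be released.
struct ToyVariableInfo {
  explicit ToyVariableInfo(const ToyTensor& t)
      : device(t.device), size(t.size) {}

  // Build a zero-filled tensor with the recorded shape and device.
  ToyTensor zeros() const {
    int64_t numel = 1;
    for (int64_t d : size) {
      assert(d >= 0);  // negative dimensions are not meaningful here
      numel *= d;
    }
    ToyTensor out;
    out.data.assign(static_cast<std::size_t>(numel), 0.0f);
    out.size = size;
    out.device = device;
    return out;
  }

  int32_t device = -1;
  std::vector<int64_t> size;
};
```

For a `ToyTensor` with `size = {2, 3}` and `device = 0`, `ToyVariableInfo(t).zeros()` yields six zero-valued elements with the same shape and device index. The real `VariableInfo::zeros` additionally takes an `at::DeviceGuard&`, presumably so the allocation happens on the recorded device. The sketch does not model that.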