pytorch/torch/csrc/autograd/functions/basic_ops.h
Sam Gross 1290e586fb Use at::Tensor based autograd Variable (#2676)
Variable is now a subclass of at::Tensor backed by a VariableImpl* pImpl. The implementations of the ATen functions are defined in the auto-generated VariableType.h/.cpp files.

Currently, only functions that fall through to the base type, such as sizes() and isCuda(), are implemented. Differentiable ops like add() and mul() will be added in a subsequent PR.
2017-09-12 11:36:01 -04:00

77 lines
1.7 KiB
C++

#pragma once
#include <Python.h>
#include <memory>
#include <string>
#include "torch/csrc/autograd/function.h"
#include "torch/csrc/autograd/variable.h"
namespace torch { namespace autograd {
// Function node that carries an error message `msg`.
// NOTE(review): apply() is only declared here. That it reports `msg` as an
// error when invoked is inferred from the name and from the DelayedError
// comment ("Error in backward"); confirm against the out-of-line definition.
struct Error : public Function {
  // Message associated with this node; set once at construction.
  std::string msg;

  // Builds the node with explicit function flags, which are forwarded to the
  // Function base.
  Error(std::string msg, FunctionFlags&& flags)
    : Function(std::move(flags)), msg(std::move(msg)) {}

  // Builds the node with a default-constructed Function base.
  Error(std::string msg)
    : msg(std::move(msg)) {}

  variable_list apply(const variable_list& inputs) override;
};
// Identity in forward, Error in backward. Used to implement @once_differentiable.
// The node stores the message `msg`; apply() is defined out of line.
struct DelayedError : public Function {
  // Stores the message; the Function base is default-constructed.
  // (No trailing ';' after the body: a stray one trips -Wextra-semi.)
  DelayedError(std::string msg)
    : msg(std::move(msg)) {}
  virtual variable_list apply(const variable_list& inputs) override;
  // Message for the deferred error.
  std::string msg;
};
// Function node whose apply() ignores its inputs and returns the variables
// captured at construction. Its outgoing edges (next_functions) are the
// supplied `functions`.
// NOTE(review): presumably used as the starting node of a graph traversal;
// confirm against the engine that constructs it.
struct GraphRoot : public Function {
  GraphRoot(function_list functions, variable_list inputs)
    : outputs(std::move(inputs)) {
    // next_functions and is_executable are inherited from Function, so they
    // are assigned in the body rather than in the mem-initializer list.
    next_functions = std::move(functions);
    is_executable = true;
  }

  // Returns a copy of the stored outputs. The argument is intentionally
  // unused, so it is unnamed to avoid -Wunused-parameter.
  // `override` matches the other apply() declarations in this header and
  // turns any future signature drift in Function::apply into a compile error.
  virtual variable_list apply(const variable_list& /*inputs*/) override {
    return outputs;
  }

  // Variables returned by every call to apply().
  variable_list outputs;
};
// Forward-function node for addition; apply() is defined out of line.
// NOTE(review): operand count and broadcasting rules are not visible in this
// header; confirm in the implementation.
struct Add : public ForwardFunction<> {
  // Kept user-provided (not `= default`) so value-initialization semantics
  // stay exactly as before.
  Add() {}

  variable_list apply(const variable_list& inputs) override;
};
// Backward node paired with Add; apply() receives gradients and is defined
// out of line.
struct AddBackward : public Function {
  // Forwards the supplied flags to the Function base.
  AddBackward(FunctionFlags&& flags) : Function(std::move(flags)) {}

  variable_list apply(const variable_list& gradOutputs) override;
};
// Forward-function node for multiplication; apply() is defined out of line.
// NOTE(review): operand count and broadcasting rules are not visible in this
// header; confirm in the implementation.
struct Mul : public ForwardFunction<> {
  // Kept user-provided (not `= default`) so value-initialization semantics
  // stay exactly as before.
  Mul() {}

  variable_list apply(const variable_list& inputs) override;
};
// Backward node paired with Mul; apply() receives gradients and is defined
// out of line.
struct MulBackward : public Function {
  // Forwards the supplied flags to the Function base.
  MulBackward(FunctionFlags&& flags) : Function(std::move(flags)) {}

  variable_list apply(const variable_list& gradOutputs) override;
};
}}