pytorch/torch/csrc/autograd/cpp_hook.h
Peter Bell d701357d92 Factor out TensorBase that doesn't depend on native operators (#63612)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63612

This makes Tensor inherit from a new class TensorBase, that provides a subset of Tensor that doesn't
directly depend on native_functions.yaml. Code that only includes TensorBase.h will thus not need to
be rebuilt every time someone changes an operator signature.

Making `Tensor` inherit from this class means that `const TensorBase&` parameters will be callable
with an ordinary `Tensor`. I've also made `Tensor` constructible and assignable from `TensorBase` to
minimize friction in code mixing the two types.
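
The sketch below is minimal and self-contained. `TensorImpl`, `TensorBase`, `Tensor`, and `rank_of` are simplified, hypothetical stand-ins backed by `std::shared_ptr`, not the real ATen definitions. It shows only the relationship described above: a `Tensor` binds to a `const TensorBase&` parameter, and a `Tensor` can be constructed from or assigned a `TensorBase` while sharing the same impl.

```cpp
#include <cassert>
#include <cstdint>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-ins for at::TensorImpl / at::TensorBase / at::Tensor.
// The real classes hold an intrusive pointer and have many more members.
struct TensorImpl {
  std::vector<int64_t> sizes;
};

class TensorBase {
 public:
  TensorBase() = default;
  explicit TensorBase(std::shared_ptr<TensorImpl> impl) : impl_(std::move(impl)) {}

  // Members like these need no operator declarations, so they fit on the base.
  bool defined() const { return impl_ != nullptr; }
  int64_t dim() const {
    assert(defined() && "dim() called on an undefined tensor");
    return static_cast<int64_t>(impl_->sizes.size());
  }
  const TensorImpl* unsafeGetImpl() const { return impl_.get(); }

 protected:
  std::shared_ptr<TensorImpl> impl_;
};

class Tensor : public TensorBase {
 public:
  Tensor() = default;
  // Constructible and assignable from TensorBase; both share the same impl.
  Tensor(const TensorBase& base) : TensorBase(base) {}
  Tensor& operator=(const TensorBase& base) {
    TensorBase::operator=(base);
    return *this;
  }
  // Operator methods generated from native_functions.yaml would be declared here.
};

// Uses only the base class, so it would not depend on operator declarations.
// An ordinary Tensor binds to the const TensorBase& parameter directly.
int64_t rank_of(const TensorBase& t) { return t.dim(); }
```

Because `rank_of` sees only the base class, a header declaring it would not need to change when an operator signature changes.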

To help enforce that `Tensor.h` and `Functions.h` aren't accidentally included, I've added an error
into `Operators.h` if `TORCH_ASSERT_NO_OPERATORS` is defined. We can either set this in the build
system for certain folders, or just define it at the top of any file.
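
A sketch of such a guard, assuming a plain `#ifdef`/`#error` check. The error message and the `kNoOperatorsAsserted` constant are hypothetical; only the macro name comes from this commit.

```cpp
// --- What a guarded header such as Operators.h might contain (sketch) ---
#ifdef TORCH_ASSERT_NO_OPERATORS
#error "TORCH_ASSERT_NO_OPERATORS is defined, but Operators.h was included"
#endif

// --- How a source file might opt in: define the macro before any #include ---
// #define TORCH_ASSERT_NO_OPERATORS
// #include <ATen/core/TensorBase.h>   // fine: no dependency on operators
// #include <ATen/Tensor.h>            // intended to hit the #error above
// (Alternatively, pass -DTORCH_ASSERT_NO_OPERATORS from the build system.)

// Hypothetical compile-time flag recording whether the guard is active.
constexpr bool kNoOperatorsAsserted =
#ifdef TORCH_ASSERT_NO_OPERATORS
    true;
#else
    false;
#endif
```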

I've also included an example of manually special-casing the commonly used `contiguous` operator.
The inline function's slow path defers to `TensorBase::__dispatch_contiguous` which is defined in
`Tensor.cpp`. I've made it so `OptionalTensorRef` is constructible from `TensorBase`, so I can
materialize a `Tensor` for use in dispatch without actually increasing its refcount.
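
The following sketch shows the fast-path/slow-path split and the non-owning borrow together. It uses hypothetical stand-ins: `Impl` plays the role of TensorImpl with a hand-rolled refcount, `BorrowedTensor` is a simplified analogue of `OptionalTensorRef`, and `dispatch_contiguous` stands in for `TensorBase::__dispatch_contiguous` (renamed because identifiers containing a double underscore are reserved in user code). The real ATen classes differ in detail.

```cpp
#include <cassert>

// Hypothetical stand-in for TensorImpl: an intrusively refcounted object.
struct Impl {
  int refcount = 0;
  bool contiguous = true;
};

// Tag selecting a constructor that shares the pointer without retaining it.
struct unsafe_borrow_t {};

class TensorBase {
 public:
  explicit TensorBase(Impl* impl) : impl_(impl) { retain(); }
  TensorBase(unsafe_borrow_t, Impl* impl) : impl_(impl) {}  // no retain()
  TensorBase(const TensorBase& other) : impl_(other.impl_) { retain(); }
  TensorBase& operator=(const TensorBase&) = delete;  // omitted to keep this short
  ~TensorBase() { release(); }

  bool is_contiguous() const {
    assert(impl_ != nullptr);
    return impl_->contiguous;
  }
  Impl* impl() const { return impl_; }
  // Drop the pointer without decrementing; pairs with the borrowing constructor.
  void forget() { impl_ = nullptr; }

  // Inline fast path: an already-contiguous tensor is returned as-is
  // (one extra reference) without calling the out-of-line slow path.
  TensorBase contiguous() const {
    if (is_contiguous()) return *this;
    return dispatch_contiguous();
  }

 private:
  TensorBase dispatch_contiguous() const;  // out-of-line slow path
  void retain() { if (impl_) ++impl_->refcount; }
  void release() { if (impl_ && --impl_->refcount == 0) delete impl_; }
  Impl* impl_;
};

class Tensor : public TensorBase {
 public:
  explicit Tensor(Impl* impl) : TensorBase(impl) {}
  Tensor(unsafe_borrow_t tag, Impl* impl) : TensorBase(tag, impl) {}
  Tensor(const TensorBase& base) : TensorBase(base) {}

  // Stand-in for the full dispatched operator: returns a new contiguous impl.
  Tensor contiguous_op() const {
    Impl* fresh = new Impl;
    fresh->contiguous = true;
    return Tensor(fresh);
  }
};

// Presents a TensorBase as a Tensor without changing its refcount
// on construction or destruction.
class BorrowedTensor {
 public:
  explicit BorrowedTensor(const TensorBase& base)
      : t_(unsafe_borrow_t{}, base.impl()) {}
  BorrowedTensor(const BorrowedTensor&) = delete;
  BorrowedTensor& operator=(const BorrowedTensor&) = delete;
  ~BorrowedTensor() { t_.forget(); }  // t_'s destructor then releases nothing
  const Tensor& operator*() const { return t_; }

 private:
  Tensor t_;
};

TensorBase TensorBase::dispatch_contiguous() const {
  BorrowedTensor borrowed(*this);      // materialize a Tensor, no refcount bump
  return (*borrowed).contiguous_op();  // call the operator-level function
}
```

The inline check keeps the common case (already contiguous) free of any call into the slow path, while the borrow lets that slow path hand the dispatcher a full `Tensor` without an extra increment/decrement pair.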

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D30728580

Pulled By: ezyang

fbshipit-source-id: 2cbc8eee08043382ee6904ea8e743b1286921c03
2021-09-08 13:28:54 -07:00


#pragma once
#include <torch/csrc/autograd/function_hook.h>
#include <functional>
#include <memory>
#include <vector>

namespace torch { namespace autograd {

// List of C++ hook callables; each maps a tensor to a (possibly replaced) tensor.
using hooks_list = std::vector<std::function<at::TensorBase(const at::TensorBase&)>>;

// Function pre-hook that applies the hooks in `hooks_` to the entry at
// index `value_idx_` of the incoming variable_list.
struct CppFunctionPreHook : public FunctionPreHook {
  CppFunctionPreHook(const std::shared_ptr<hooks_list> &hooks, int value_idx);
  variable_list operator()(const variable_list& values) override;
  std::shared_ptr<hooks_list> hooks_;
  int value_idx_;
};

}} // namespace torch::autograd
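
The header only declares `CppFunctionPreHook::operator()`; its definition lives in a separate source file not shown here. The sketch below models one plausible behavior under stated assumptions: hooks run in order on the single entry at `value_idx_`, an empty `std::function` slot is skipped as a removed hook, and an "undefined" result leaves the current value unchanged. `Value`, `ValueList`, `HookList`, and `PreHookSketch` are hypothetical stand-ins, with an empty `std::optional<double>` playing the role of an undefined tensor.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <memory>
#include <optional>
#include <utility>
#include <vector>

// Hypothetical stand-ins: an empty optional plays the role of an undefined tensor.
using Value = std::optional<double>;
using ValueList = std::vector<Value>;
using HookList = std::vector<std::function<Value(const Value&)>>;

struct PreHookSketch {
  PreHookSketch(std::shared_ptr<HookList> hooks, int value_idx)
      : hooks_(std::move(hooks)), value_idx_(value_idx) {}

  ValueList operator()(const ValueList& values) const {
    assert(value_idx_ >= 0 && static_cast<std::size_t>(value_idx_) < values.size());
    const auto idx = static_cast<std::size_t>(value_idx_);
    Value value = values[idx];
    for (const auto& hook : *hooks_) {
      if (!hook) continue;  // assumption: an empty slot marks a removed hook
      Value res = hook(value);
      if (!res) continue;   // assumption: an "undefined" result keeps the value
      value = std::move(res);
    }
    ValueList results(values);  // every other entry passes through unchanged
    results[idx] = std::move(value);
    return results;
  }

  std::shared_ptr<HookList> hooks_;  // shared: all holders see the same list
  int value_idx_;                    // which entry of the input list the hooks see
};
```

Holding the list through a `shared_ptr` means hooks appended to it after the pre-hook object is created are still visited on the next call.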