mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63612

This makes Tensor inherit from a new class TensorBase, which provides the subset of Tensor that doesn't directly depend on native_functions.yaml. Code that only includes TensorBase.h will thus not need to be rebuilt every time someone changes an operator signature. Making `Tensor` inherit from this class means that `const TensorBase&` parameters will be callable with an ordinary `Tensor`. I've also made `Tensor` constructible and assignable from `TensorBase` to minimize friction in code mixing the two types.

To help enforce that `Tensor.h` and `Functions.h` aren't accidentally included, I've added an error to `Operators.h` that fires if `TORCH_ASSERT_NO_OPERATORS` is defined. We can either set this in the build system for certain folders, or just define it at the top of any file.

I've also included an example of manually special-casing the commonly used `contiguous` operator. The inline function's slow path defers to `TensorBase::__dispatch_contiguous`, which is defined in `Tensor.cpp`. I've made `OptionalTensorRef` constructible from `TensorBase`, so I can materialize a `Tensor` for use in dispatch without actually increasing its refcount.

Test Plan: Imported from OSS

Reviewed By: gchanan

Differential Revision: D30728580

Pulled By: ezyang

fbshipit-source-id: 2cbc8eee08043382ee6904ea8e743b1286921c03
18 lines
518 B
C++
#pragma once

#include <torch/csrc/autograd/function_hook.h>

#include <functional>
#include <memory>
#include <vector>

namespace torch { namespace autograd {

using hooks_list = std::vector<std::function<at::TensorBase(const at::TensorBase&)>>;

struct CppFunctionPreHook : public FunctionPreHook {
  CppFunctionPreHook(const std::shared_ptr<hooks_list>& hooks, int value_idx);
  variable_list operator()(const variable_list& values) override;

  std::shared_ptr<hooks_list> hooks_;
  int value_idx_;
};

}} // namespace torch::autograd