pytorch/torch/csrc/variable_tensor_functions.h
Gregory Chanan 0947712e5d Move Factory functions from Type to TypeExtendedInterface. (#12025)
Summary:
This makes a few changes with respect to Type, with the ultimate goal of removing Type from the public Methods/Functions.  In particular:
1) Removes factory functions from Type, into TypeExtendedInterface.
2) sparse_coo_tensor is now a first class at:: namespace function, with TensorOptions overloads.
3) We move from Type-based sparse_coo_tensor dispatch to function-based.

Note we still require a number of changes to get rid of Type in the public interface; in particular, TensorOptions needs to support CUDA vs non-CUDA dispatch.  That is coming in a future patch.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12025

Reviewed By: ezyang

Differential Revision: D10017205

Pulled By: gchanan

fbshipit-source-id: 00807a37b09ed33f0656aaa165bb925abb026320
2018-09-25 09:40:17 -07:00


#pragma once
#include <ATen/ATen.h>
#include <torch/csrc/THP_export.h>
namespace torch {
// NOTE: This API is currently highly experimental and may change drastically
// in the near future.
// These functions provide a small wrapper around ATen, ensuring that newly
// created tensors are Variables rather than raw tensors. We also provide a
// few accessors, such as `requires_grad`, that make it easier to get at
// variable information when we have an at::Tensor.
/// Returns a `TypeExtendedInterface` object for the given backend (e.g. `at::kCPU`) and
/// `ScalarType` (e.g. `at::kDouble`).
/// TODO: Eliminate this function as much as possible
THP_CLASS at::TypeExtendedInterface& getVariableType(at::Backend backend, at::ScalarType type);
/// Returns a `TypeExtendedInterface` object for the CPU backend and the given `ScalarType`
/// (e.g. `at::kDouble`). Equivalent to `getVariableType(kCPU, type)`.
/// TODO: Eliminate this function as much as possible
THP_CLASS at::TypeExtendedInterface& CPU(at::ScalarType type);
/// Returns a `TypeExtendedInterface` object for the CUDA backend and the given `ScalarType`
/// (e.g. `at::kDouble`). Equivalent to `getVariableType(kCUDA, type)`.
/// TODO: Eliminate this function as much as possible
THP_CLASS at::TypeExtendedInterface& CUDA(at::ScalarType type);
/// Sets the `requires_grad` property of the given `Tensor`.
THP_CLASS void set_requires_grad(at::Tensor& tensor, bool requires_grad) noexcept;
/// Returns the `requires_grad` of the given `Tensor`.
THP_CLASS bool requires_grad(const at::Tensor& tensor) noexcept;
} // namespace torch