mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: This makes a few changes wrt Type, with the ultimate goal of removing Type from the public Methods/Functions. In particular: 1) Removes factory functions from Type, into TypeExtendedInterface. 2) sparse_coo_tensor is now a first-class at:: namespace function, with TensorOptions overloads. 3) We move from Type-based sparse_coo_tensor dispatch to function-based dispatch. Note that we still require a number of changes to get rid of Type in the public interface; in particular, TensorOptions needs to support CUDA vs. non-CUDA dispatch. That is coming in a future patch. Pull Request resolved: https://github.com/pytorch/pytorch/pull/12025 Reviewed By: ezyang Differential Revision: D10017205 Pulled By: gchanan fbshipit-source-id: 00807a37b09ed33f0656aaa165bb925abb026320
38 lines
1.6 KiB
C++
#pragma once

#include <ATen/ATen.h>
#include <torch/csrc/THP_export.h>

namespace torch {

// NOTE: This API is currently highly experimental and may change drastically
// in the near future.

// These functions provide a small wrapper around ATen, ensuring
// that we create tensors with type Variable rather than raw tensors
// when we create new tensors. We also provide a few accessors, like requires_grad,
// that make it easier to get to variable information when we have an at::Tensor.

/// Returns a `TypeExtendedInterface` object for the given backend (e.g. `at::kCPU`) and
/// `ScalarType` (e.g. `at::kDouble`).
/// TODO: Eliminate this function as much as possible
THP_CLASS at::TypeExtendedInterface& getVariableType(at::Backend backend, at::ScalarType type);

/// Returns a `TypeExtendedInterface` object for the CPU backend and the given `ScalarType`
/// (e.g. `at::kDouble`). Equivalent to `getVariableType(kCPU, type)`.
/// TODO: Eliminate this function as much as possible
THP_CLASS at::TypeExtendedInterface& CPU(at::ScalarType type);

/// Returns a `TypeExtendedInterface` object for the CUDA backend and the given `ScalarType`
/// (e.g. `at::kDouble`). Equivalent to `getVariableType(kCUDA, type)`.
/// TODO: Eliminate this function as much as possible
THP_CLASS at::TypeExtendedInterface& CUDA(at::ScalarType type);

/// Sets the `requires_grad` property of the given `Tensor`.
THP_CLASS void set_requires_grad(at::Tensor& tensor, bool requires_grad) noexcept;

/// Returns the `requires_grad` property of the given `Tensor`.
THP_CLASS bool requires_grad(const at::Tensor& tensor) noexcept;

} // namespace torch
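The wrapping idea described in the header comment can be illustrated with a self-contained toy model. In that idea, factories hand back Variables rather than raw tensors, and small accessors expose `requires_grad`. This is a minimal sketch under stated assumptions. `RawTensor`, `Variable`, and `zeros` are hypothetical stand-ins invented here, not ATen or autograd types. The two accessors only mirror the shape of the `set_requires_grad` / `requires_grad` declarations above.

```cpp
#include <cstddef>
#include <memory>
#include <utility>
#include <vector>

// HYPOTHETICAL toy type, not at::Tensor: a plain buffer with no autograd state.
struct RawTensor {
  std::vector<double> data;
};

// HYPOTHETICAL stand-in for a Variable: the raw tensor plus the autograd flag
// that the accessors below read and write.
struct Variable {
  std::shared_ptr<RawTensor> tensor;
  bool requires_grad = false;
};

// HYPOTHETICAL factory wrapper: builds the raw tensor, then wraps it right away,
// so callers of the factory never receive an unwrapped RawTensor.
Variable zeros(std::size_t n, bool requires_grad = false) {
  auto raw = std::make_shared<RawTensor>();
  raw->data.assign(n, 0.0);
  return Variable{std::move(raw), requires_grad};
}

// Accessors with the same shape as the header's declarations, but taking the
// toy Variable instead of at::Tensor.
void set_requires_grad(Variable& v, bool requires_grad) noexcept {
  v.requires_grad = requires_grad;
}

bool requires_grad(const Variable& v) noexcept {
  return v.requires_grad;
}
```

A caller would write `Variable w = zeros(3);` and later call `set_requires_grad(w, true);`. Doing the wrapping inside the factory, rather than asking every caller to wrap, is what keeps raw tensors from escaping.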
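`getVariableType`, `CPU`, and `CUDA` hand out references to one long-lived object per (backend, scalar type) pair. The convenience functions are documented as equivalent to calling `getVariableType` with a fixed backend. The sketch below models that contract with a fixed two-by-two table. The `sketch::Backend`, `sketch::ScalarType`, and `sketch::TypeHandle` types are hypothetical stand-ins for `at::Backend`, `at::ScalarType`, and `at::TypeExtendedInterface`. The table layout is an assumption for illustration, not how ATen actually stores its types.

```cpp
namespace sketch {

// HYPOTHETICAL stand-ins for at::Backend and at::ScalarType; the explicit
// enumerator values are used as table indices below.
enum class Backend { CPU = 0, CUDA = 1 };
enum class ScalarType { Float = 0, Double = 1 };

constexpr int kNumBackends = 2;
constexpr int kNumScalarTypes = 2;

// HYPOTHETICAL stand-in for at::TypeExtendedInterface: it only records which
// (backend, scalar type) pair it represents.
struct TypeHandle {
  Backend backend;
  ScalarType scalar_type;
};

// Returns a reference to the single TypeHandle for (backend, type).
// The table is a function-local static: it is initialized exactly once
// (thread-safe since C++11) and lives until program exit, so handing out
// references to its entries is safe.
TypeHandle& getVariableType(Backend backend, ScalarType type) {
  static TypeHandle table[kNumBackends][kNumScalarTypes] = {
      {{Backend::CPU, ScalarType::Float}, {Backend::CPU, ScalarType::Double}},
      {{Backend::CUDA, ScalarType::Float}, {Backend::CUDA, ScalarType::Double}},
  };
  return table[static_cast<int>(backend)][static_cast<int>(type)];
}

// Convenience wrappers: fix the backend and delegate.
TypeHandle& CPU(ScalarType type) { return getVariableType(Backend::CPU, type); }
TypeHandle& CUDA(ScalarType type) { return getVariableType(Backend::CUDA, type); }

}  // namespace sketch
```

Every call for the same pair returns a reference to the same table entry. So `&CPU(t) == &getVariableType(Backend::CPU, t)` holds, which is the equivalence the doc comments state.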