mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/56830 Opt into formatting on GitHub and format everything. This is a trial run before turning on formatting for more and eventually all of the codebase. Test Plan: CI Reviewed By: zertosh Differential Revision: D27979080 fbshipit-source-id: a80f0c48691c08ae8ca0af06377b87e6a2351151
45 lines
1.0 KiB
C++
45 lines
1.0 KiB
C++
#pragma once
|
|
|
|
#include <c10/core/Backend.h>
|
|
#include <c10/core/Device.h>
|
|
#include <c10/core/Layout.h>
|
|
#include <c10/core/ScalarType.h>
|
|
|
|
namespace c10 {
|
|
|
|
struct TensorOptions;
|
|
|
|
/// Like TensorOptions, but all fields are guaranteed to be filled.
|
|
struct DefaultTensorOptions {
|
|
DefaultTensorOptions() = default;
|
|
|
|
caffe2::TypeMeta dtype() const noexcept {
|
|
return dtype_;
|
|
}
|
|
Device device() const noexcept {
|
|
return device_;
|
|
}
|
|
Layout layout() const noexcept {
|
|
return layout_;
|
|
}
|
|
bool requires_grad() const noexcept {
|
|
return requires_grad_;
|
|
}
|
|
|
|
// Defined in TensorOptions.h
|
|
inline DefaultTensorOptions& merge(const TensorOptions& options);
|
|
|
|
private:
|
|
caffe2::TypeMeta dtype_ = caffe2::TypeMeta::Make<float>(); // 64-bit
|
|
Device device_ = at::kCPU; // 32-bit
|
|
Layout layout_ = at::kStrided; // 8-bit
|
|
bool requires_grad_ = false; // 8-bit
|
|
};
|
|
|
|
/// Returns a reference to a shared, immutable DefaultTensorOptions that holds
/// the default field values.
inline const DefaultTensorOptions& getDefaultTensorOptions() {
  // Function-local static: initialized once, then the same object is
  // returned on every call.
  static const DefaultTensorOptions default_options{};
  return default_options;
}
|
|
|
|
} // namespace c10
|