Move exception to C10 (#12354)
Summary: There is still some work to be done:
- Move logging and unify AT_WARN with LOG(ERROR).
- A few header files are still being plumbed through and need cleaning.
- caffe2::EnforceNotMet aliasing is not done yet.
- Unify the macros; see c10/util/Exception.h.

This is mainly a codemod and does not cause functional changes. If you find your job failing and trace it back to this diff, it can usually be fixed by one of the following approaches:
(1) Add //caffe2/c10:c10 to your dependency (or transitive dependency).
(2) Change objects such as at::Error and at::Optional to the c10 namespace.
(3) Change functions to the c10 namespace. In particular, caffe2::MakeString is not overridden by the unified c10::str function.

Nothing else changes. Please kindly consider not reverting this diff - it involved multiple rounds of rebasing and the fix is usually simple. Contact jiayq@ or AI Platform Dev for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/12354
Reviewed By: orionr
Differential Revision: D10238910
Pulled By: Yangqing
fbshipit-source-id: 7794d5bf2797ab0ca6ebaccaa2f7ebbd50ff8f32
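As a hedged illustration of fix approaches (2) and (3) above, the snippet below shows the kind of spelling change this codemod expects in downstream code; the function and values are made up and not part of this diff. Before this change the same code would have used `at::optional`, `at::nullopt`, and `at::str` from the ATen core headers.

```cpp
#include <string>
#include "c10/util/Exception.h"  // assumed new home of the error utilities and c10::str
#include "c10/util/Optional.h"   // assumed new home of optional / nullopt

// Illustrative only: spell optional and string building in the c10 namespace.
std::string describe(c10::optional<int> device_index) {
  // c10::str is the unified string builder; note that caffe2::MakeString is
  // *not* overridden by it, per the summary above.
  return c10::str("device index: ", device_index.value_or(-1));
}
```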
This commit is contained in:
parent aef8cadb9a
commit 713e706618
@@ -325,7 +325,7 @@ Here are a few well known pitfalls and workarounds:
   catch all of these problems: stay vigilant to the possibility that
   your crash is due to a real memory problem.

-* (NVCC) `at::optional` does not work when used from device code. Don't use
+* (NVCC) `c10::optional` does not work when used from device code. Don't use
   it from kernels. Upstream issue: https://github.com/akrzemi1/Optional/issues/58
   and our local issue #10329.
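The pitfall above implies that any optional handling has to stay on the host side. Below is a minimal, hypothetical sketch of that pattern (the launcher and kernel names are made up): unwrap the `c10::optional` before the kernel launch and pass plain values into device code.

```cpp
#include <cstdint>
#include "c10/util/Optional.h"

// Hypothetical device-side entry point; a real kernel would be a __global__
// function compiled by NVCC. Shown only to illustrate the host/device split.
void scale_kernel_launcher(float* data, int64_t n, float scale);

void launch_scale(float* data, int64_t n, c10::optional<float> maybe_scale) {
  // Resolve the optional on the host...
  const float scale = maybe_scale.value_or(1.0f);
  // ...and hand the kernel a plain float, never a c10::optional.
  scale_kernel_launcher(data, n, scale);
}
```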
@@ -1,11 +1,11 @@
 #pragma once

-#include <ATen/core/Device.h>
-#include <ATen/core/ScalarType.h>
 #include <ATen/Tensor.h>
+#include <ATen/core/Device.h>
 #include <ATen/core/Error.h>
-#include <ATen/core/optional.h>
+#include <ATen/core/ScalarType.h>
 #include <ATen/detail/CUDAHooksInterface.h>
+#include "c10/util/Optional.h"

 #include <cstddef>

@@ -29,7 +29,7 @@ struct DeviceGuard {
     }
   }

-  explicit DeviceGuard(optional<Device> device_opt) {
+  explicit DeviceGuard(c10::optional<Device> device_opt) {
     if (device_opt.has_value() && device_opt.value().is_cuda()) {
       set_index(device_opt.value().index());
     }
@@ -12,7 +12,7 @@ namespace at {
 static std::vector<int64_t> infer_size(IntList shape, int64_t numel) {
   auto res = shape.vec();
   int64_t newsize = 1;
-  auto infer_dim = at::optional<int64_t>();
+  auto infer_dim = c10::optional<int64_t>();
   for (int64_t dim = 0, ndim = shape.size(); dim != ndim; dim++) {
     if (shape[dim] == -1) {
       if (infer_dim) {
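For context on the function being touched here: `infer_size` fills in the single dimension marked `-1` so that the product of all dimensions equals the element count. The snippet below is a self-contained, illustrative re-implementation of that rule (names are made up), not the ATen code itself.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Illustrative only: infer the dimension marked -1 so the product of all
// dimensions equals numel, mirroring what infer_size does.
std::vector<int64_t> infer_shape(std::vector<int64_t> shape, int64_t numel) {
  int64_t known = 1;
  int64_t infer_at = -1;
  for (size_t i = 0; i < shape.size(); ++i) {
    if (shape[i] == -1) {
      infer_at = static_cast<int64_t>(i);  // remember the placeholder position
    } else {
      known *= shape[i];
    }
  }
  if (infer_at >= 0) {
    if (known == 0 || numel % known != 0)
      throw std::runtime_error("shape is incompatible with the number of elements");
    shape[infer_at] = numel / known;  // e.g. {2, -1} with numel 6 becomes {2, 3}
  }
  return shape;
}
```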
@@ -1,28 +1,2 @@
-#pragma once
-
-#include <cstddef>
-#include <string>
-#include <typeinfo>
-
-#include <ATen/core/Macros.h>
-
-namespace at {
-/// Utility to demangle a C++ symbol name.
-CAFFE2_API std::string demangle(const char* name);
-
-/// Returns the printable name of the type.
-template <typename T>
-inline const char* demangle_type() {
-#ifdef __GXX_RTTI
-  static const std::string name = demangle(typeid(T).name());
-  return name.c_str();
-#else // __GXX_RTTI
-  return "(RTTI disabled, cannot show name)";
-#endif // __GXX_RTTI
-}
-
-CAFFE2_API std::string get_backtrace(
-    size_t frames_to_skip = 0,
-    size_t maximum_number_of_frames = 64,
-    bool skip_python_frames = true);
-} // namespace at
+#include "c10/util/Backtrace.h"
+#include "c10/util/Type.h"
@@ -1,235 +1,2 @@
 #pragma once
+#include "c10/util/Exception.h"
-#include <ATen/core/Macros.h>
-#include <ATen/core/optional.h>
-
-#include <cstddef>
-#include <exception>
-#include <ostream>
-#include <sstream>
-#include <string>
-#include <vector>
-
-#if defined(_MSC_VER) && _MSC_VER <= 1900
-#define __func__ __FUNCTION__
-#endif
-
-namespace at {
-
-namespace detail {
-
-// Obtains the base name from a full path.
-CAFFE2_API std::string StripBasename(const std::string& full_path);
-
-inline std::ostream& _str(std::ostream& ss) {
-  return ss;
-}
-
-template <typename T>
-inline std::ostream& _str(std::ostream& ss, const T& t) {
-  ss << t;
-  return ss;
-}
-
-template <typename T, typename... Args>
-inline std::ostream& _str(std::ostream& ss, const T& t, const Args&... args) {
-  return _str(_str(ss, t), args...);
-}
-
-} // namespace detail
-
-// Convert a list of string-like arguments into a single string.
-template <typename... Args>
-inline std::string str(const Args&... args) {
-  std::ostringstream ss;
-  detail::_str(ss, args...);
-  return ss.str();
-}
-
-// Specializations for already-a-string types.
-template <>
-inline std::string str(const std::string& str) {
-  return str;
-}
-inline std::string str(const char* c_str) {
-  return c_str;
-}
-
-/// Represents a location in source code (for debugging).
-struct CAFFE2_API SourceLocation {
-  const char* function;
-  const char* file;
-  uint32_t line;
-};
-
-std::ostream& operator<<(std::ostream& out, const SourceLocation& loc);
-
-/// The primary ATen error class.
-/// Provides a complete error message with source location information via
-/// `what()`, and a more concise message via `what_without_backtrace()`. Should
-/// primarily be used with the `AT_ERROR` macro.
-///
-/// NB: at::Error is handled specially by the default torch to suppress the
-/// backtrace, see torch/csrc/Exceptions.h
-class CAFFE2_API Error : public std::exception {
-  std::vector<std::string> msg_stack_;
-  std::string backtrace_;
-
-  // These two are derived fields from msg_stack_ and backtrace_, but we need
-  // fields for the strings so that we can return a const char* (as the
-  // signature of std::exception requires).
-  std::string msg_;
-  std::string msg_without_backtrace_;
-
-  // This is a little debugging trick: you can stash a relevant pointer
-  // in caller, and then when you catch the exception, you can compare
-  // against pointers you have on hand to get more information about
-  // where the exception came from. In Caffe2, this is used to figure
-  // out which operator raised an exception.
-  const void* caller_;
-
- public:
-  Error(
-      const std::string& msg,
-      const std::string& backtrace,
-      const void* caller = nullptr);
-  Error(SourceLocation source_location, const std::string& msg);
-  Error(
-      const char* file,
-      const int line,
-      const char* condition,
-      const std::string& msg,
-      const std::string& backtrace,
-      const void* caller = nullptr);
-
-  void AppendMessage(const std::string& msg);
-
-  // Compute the full message from msg_ and msg_without_backtrace_
-  // TODO: Maybe this should be private
-  std::string msg() const;
-  std::string msg_without_backtrace() const;
-
-  const std::vector<std::string>& msg_stack() const {
-    return msg_stack_;
-  }
-
-  /// Returns the complete error message, including the source location.
-  const char* what() const noexcept override {
-    return msg_.c_str();
-  }
-
-  const void* caller() const noexcept {
-    return caller_;
-  }
-
-  /// Returns only the error message string, without source location.
-  const char* what_without_backtrace() const noexcept {
-    return msg_without_backtrace_.c_str();
-  }
-};
-
-class CAFFE2_API Warning {
-  using handler_t =
-      void (*)(const SourceLocation& source_location, const char* msg);
-
- public:
-  /// Issue a warning with a given message. Dispatched to the current
-  /// warning handler.
-  static void warn(SourceLocation source_location, std::string msg);
-
-  /// Sets the global warning handler. This is not thread-safe, so it should
-  /// generally be called once during initialization.
-  static void set_warning_handler(handler_t handler);
-
-  /// The default warning handler. Prints the message to stderr.
-  static void print_warning(
-      const SourceLocation& source_location,
-      const char* msg);
-
- private:
-  static handler_t warning_handler_;
-};
-
-// A utility function to return an exception std::string by prepending its
-// exception type before its what() content
-CAFFE2_API std::string GetExceptionString(const std::exception& e);
-
-CAFFE2_API void ThrowEnforceNotMet(
-    const char* file,
-    const int line,
-    const char* condition,
-    const std::string& msg,
-    const void* caller);
-
-} // namespace at
-
-// TODO: variants that print the expression tested and thus don't require
-// strings
-
-#define AT_ENFORCE(condition, ...) \
-  do { \
-    if (!(condition)) { \
-      at::ThrowEnforceNotMet( \
-          __FILE__, __LINE__, #condition, at::str(__VA_ARGS__), nullptr); \
-    } \
-  } while (false)
-
-#define AT_ENFORCE_WITH_CALLER(condition, ...) \
-  do { \
-    if (!(condition)) { \
-      at::ThrowEnforceNotMet( \
-          __FILE__, \
-          __LINE__, \
-          #condition, \
-          at::str(__VA_ARGS__), \
-          this); \
-    } \
-  } while (false)
-
-#define AT_ENFORCE_GT(x, y, ...) \
-  AT_ENFORCE(x > y, __VA_ARGS__)
-
-#define AT_ENFORCE_GE(x, y, ...) \
-  AT_ENFORCE(x >= y, __VA_ARGS__)
-
-#define AT_ENFORCE_LT(x, y, ...) \
-  AT_ENFORCE(x < y, __VA_ARGS__)
-
-#define AT_ENFORCE_GE_WITH_CALLER(x, y, ...) \
-  AT_ENFORCE_WITH_CALLER(x >= y, __VA_ARGS__)
-
-#define AT_ENFORCE_EQ_WITH_CALLER(x, y, ...) \
-  AT_ENFORCE_WITH_CALLER(x == y, __VA_ARGS__)
-
-#define AT_ERROR(...) \
-  throw at::Error({__func__, __FILE__, __LINE__}, at::str(__VA_ARGS__))
-
-#define AT_WARN(...) \
-  at::Warning::warn({__func__, __FILE__, __LINE__}, at::str(__VA_ARGS__))
-
-#define AT_ASSERT(cond) \
-  if (!(cond)) { \
-    AT_ERROR( \
-        #cond " ASSERT FAILED at ", \
-        __FILE__, \
-        ":", \
-        __LINE__, \
-        ", please report a bug to PyTorch."); \
-  }
-
-#define AT_ASSERTM(cond, ...) \
-  if (!(cond)) { \
-    AT_ERROR(at::str( \
-        #cond, \
-        " ASSERT FAILED at ", \
-        __FILE__, \
-        ":", \
-        __LINE__, \
-        ", please report a bug to PyTorch. ", \
-        __VA_ARGS__)); \
-  }
-
-#define AT_CHECK(cond, ...) \
-  if (!(cond)) { \
-    AT_ERROR(at::str(__VA_ARGS__)); \
-  }
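The file gutted above defined ATen's error and warning machinery (`at::Error`, `at::Warning`, and the `AT_*` macros); after this change the header is a thin wrapper over `c10/util/Exception.h`, and later hunks in this diff swap the `AT_ENFORCE*` family for `AT_ASSERT`/`AT_ASSERTM`. A minimal sketch of how these check macros are typically used, assuming the same macro names keep working through the new c10 header (the function and values below are made up):

```cpp
#include <cstdint>
#include "ATen/core/Error.h"  // now just forwards to c10/util/Exception.h

// Illustrative use of the variadic-message check macros declared above.
void check_index(int64_t index, int64_t size) {
  AT_CHECK(index >= 0 && index < size,
           "index ", index, " is out of range for a dimension of size ", size);
  AT_ASSERTM(size >= 0, "size must be non-negative, got ", size);
  if (size == 0) {
    AT_ERROR("cannot index into an empty dimension");
  }
}
```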
@@ -1,7 +1,8 @@
 #pragma once

-#include <functional>
 #include <ATen/core/Macros.h>
+#include <functional>
+#include <utility>

 namespace at {

@@ -1,6 +1,6 @@
 #include <ATen/core/OptionsGuard.h>
-#include <ATen/core/optional.h>
 #include <ATen/core/Layout.h>
+#include "c10/util/Optional.h"

 namespace at {

@@ -10,7 +10,7 @@ namespace at {
 #if !AT_MOBILE && !defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY)

 DefaultTensorOptions& mutateDefaultTensorOptions() {
-  static thread_local at::optional<DefaultTensorOptions> options;
+  static thread_local c10::optional<DefaultTensorOptions> options;
   /// This is an optional because of compiler bugs that mis-initialize static
   /// thread local variables. The workaround is lazy initialization, i.e.
   /// `getDefaultTensorOptions()` will initialize the `options` to a proper
@@ -1,6 +1,7 @@
 #pragma once

 #include "ATen/core/Device.h"
+#include "ATen/core/Error.h"
 #include "ATen/core/Layout.h"
 #include "ATen/core/Scalar.h"
 #include "ATen/core/ScalarType.h"
@@ -8,9 +9,8 @@
 #include "ATen/core/Storage.h"
 #include "ATen/core/TensorAccessor.h"
 #include "ATen/core/TensorImpl.h"
-#include "ATen/core/optional.h"
 #include "ATen/core/UndefinedTensorImpl.h"
-#include "ATen/core/Error.h"
+#include "c10/util/Optional.h"

 namespace at {
 struct Generator;
@@ -241,7 +241,7 @@ public:

   /// Computes the gradient of current tensor w.r.t. graph leaves.
   void backward(
-      at::optional<Tensor> gradient = at::nullopt,
+      c10::optional<Tensor> gradient = c10::nullopt,
       bool keep_graph = false,
       bool create_graph = false);

@@ -656,7 +656,7 @@ struct CAFFE2_API WeakTensor {
   WeakTensor(const Tensor& t) : weak_impl_(t.impl_) {}

   // XXX: this can return undefined tensors
-  // Ideally it would be at::optional<Tensor>, but MSVC is too cool for that
+  // Ideally it would be c10::optional<Tensor>, but MSVC is too cool for that
   Tensor lock() const {
     return Tensor(weak_impl_.lock());
   }
@@ -1,9 +1,9 @@
 #include <ATen/core/TensorImpl.h>

-#include <ATen/core/optional.h>
 #include <ATen/core/Backend.h>
-#include <ATen/core/WrapDimMinimal.h>
 #include <ATen/core/LegacyTypeDispatch.h>
+#include <ATen/core/WrapDimMinimal.h>
+#include "c10/util/Optional.h"

 #include <ATen/core/VariableHooksInterface.h>

@@ -9,7 +9,7 @@
 #include <ATen/core/TensorTypeId.h>
 #include <ATen/core/TensorTypeIdRegistration.h>
 #include <ATen/core/context_base.h>
-#include <ATen/core/optional.h>
+#include "c10/util/Optional.h"

 #include "c10/util/Flags.h"

@@ -65,7 +65,7 @@ inline int64_t size_from_dim_(int k, IntList dims) {

 // Product of all dims up to k (not including dims[k])
 inline int64_t size_to_dim_(int k, IntList dims) {
-  AT_ENFORCE((unsigned)k <= dims.size());
+  AT_ASSERT((unsigned)k <= dims.size());
   int64_t r = 1;
   for (int i = 0; i < k; ++i) {
     r *= dims[i];
@@ -75,7 +75,7 @@ inline int64_t size_to_dim_(int k, IntList dims) {

 // Product of all dims between k and l (not including dims[k] and dims[l])
 inline int64_t size_between_dim_(int k, int l, IntList dims) {
-  AT_ENFORCE((unsigned)l < dims.size());
+  AT_ASSERT((unsigned)l < dims.size());
   int64_t r = 1;
   if (k < l) {
     for (int i = k + 1; i < l; ++i) {
@@ -91,8 +91,8 @@ inline int64_t size_between_dim_(int k, int l, IntList dims) {

 // Wrap around axis_index if it is negative, s.t., -1 is the last dim
 inline int canonical_axis_index_(int axis_index, int ndims) {
-  AT_ENFORCE_GE(axis_index, -ndims);
-  AT_ENFORCE_LT(axis_index, ndims);
+  AT_ASSERT(axis_index >= -ndims);
+  AT_ASSERT(axis_index < ndims);
   if (axis_index < 0) {
     return axis_index + ndims;
   }
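The wrap-around rule enforced above is easy to get off by one, so here is a self-contained, illustrative re-implementation with a few worked values; it is a sketch of the same arithmetic, not the ATen helper itself.

```cpp
#include <cassert>

// Illustrative re-implementation of the wrap-around used by
// canonical_axis_index_ above: -1 maps to the last dimension.
int canonical_axis(int axis_index, int ndims) {
  assert(axis_index >= -ndims && axis_index < ndims);
  return axis_index < 0 ? axis_index + ndims : axis_index;
}

int main() {
  assert(canonical_axis(-1, 4) == 3);  // last dim
  assert(canonical_axis(-4, 4) == 0);  // first dim
  assert(canonical_axis(2, 4) == 2);   // non-negative axes are unchanged
  return 0;
}
```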
@@ -264,12 +264,12 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
   template <typename T>
   inline T * data() const {
     AT_ASSERT(!is_variable());
-    AT_ENFORCE_WITH_CALLER(
+    AT_ASSERTM(
         storage_.data() || numel_ == 0,
         "The tensor has a non-zero number of elements, but its data is not allocated yet. "
         "Caffe2 uses a lazy allocation, so you will need to call "
         "mutable_data() or raw_mutable_data() to actually allocate memory.");
-    AT_ENFORCE_WITH_CALLER(
+    AT_ASSERTM(
         storage_.IsType<T>(),
         "Tensor type mismatch, caller expects elements to be ",
         caffe2::TypeMeta::TypeName<T>(),
@@ -282,7 +282,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {

   inline void* data() const {
     AT_ASSERT(!is_variable());
-    AT_ENFORCE_WITH_CALLER(storage_.data() || numel_ == 0);
+    AT_ASSERT(storage_.data() || numel_ == 0);
     return static_cast<void*>(
         static_cast<char*>(storage_.data()) +
         data_type_.itemsize() * storage_offset_);
@@ -421,7 +421,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
       return;
     }
     if (data_type_ != src.dtype()) {
-      AT_ENFORCE_WITH_CALLER(
+      AT_ASSERTM(
           src.is_contiguous(),
           "Right now only copy of contiguous source Tensor is supported.");
       storage_ = at::Storage(GetDevice(), src.dtype());
@@ -439,10 +439,10 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
     Resize(src.sizes());
     if (numel() > 0) {
       if (data_type_.copy()) {
-        AT_ENFORCE(
+        AT_ASSERTM(
             device_type() == ::at::DeviceType::CPU,
             "In CopyFrom source and dest tensors must both be CPU for meta copy");
-        AT_ENFORCE(
+        AT_ASSERTM(
             src.device_type() == ::at::DeviceType::CPU,
             "In CopyFrom source and dest tensors must both be CPU for meta copy");
         data_type_.copy()(src.data(), raw_mutable_data(data_type_), numel());
@@ -459,7 +459,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
             raw_mutable_data(data_type_),
             device_type());
       } else {
-        AT_ENFORCE(
+        AT_ASSERTM(
             context->device_type() == src.device_type(),
             "Type for provided context does not match the type of source");
         context->CopyBytesToDevice(
@@ -489,10 +489,9 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
    * complexity.
    */
   void Extend(int64_t num, float growthPct, at::BaseContext* context) {
-    AT_ENFORCE_GE_WITH_CALLER(sizes_.size(), 1u);
-    AT_ENFORCE_GE_WITH_CALLER(
-        num, 0, "`num` must be non-negative for Extend");
-    AT_ENFORCE_WITH_CALLER(
+    AT_ASSERT(sizes_.size() >= 1u);
+    AT_ASSERTM(num >= 0, "`num` must be non-negative for Extend");
+    AT_ASSERTM(
         is_contiguous_,
         "Right now Extend is only supported for contiguous Tensor.");
     auto newDims = sizes_;
@@ -519,7 +518,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
     auto oldDims = sizes_;
     Resize(newCapacity);
     auto* newData = raw_mutable_data(data_type_);
-    AT_ENFORCE(
+    AT_ASSERTM(
         context != nullptr, "Context must be provided to Extend the tensor");
     context->CopyItemsSameDevice(
         data_type_, oldSize, oldData.get(), newData);
@@ -536,13 +535,12 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
    */
   template <class T>
   void ReserveSpace(const T& outer_dim) {
-    AT_ENFORCE_WITH_CALLER(
+    AT_ASSERTM(
         is_contiguous_,
         "Right now ReserveSpace is only supported for contiguous Tensor.");
-    AT_ENFORCE(
+    AT_ASSERTM(
         numel_ != -1, "size should be initialized before calling ReserveSpace");
-    AT_ENFORCE(
-        storage_.unique(), "Can't call ReserveSpace on shared storage.");
+    AT_ASSERTM(storage_.unique(), "Can't call ReserveSpace on shared storage.");
     auto newCapacity = sizes_;
     newCapacity[0] = outer_dim;
     auto newNumel = std::accumulate(
@@ -614,15 +612,15 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
    * This requires the total size of the tensor to remains constant.
    */
   inline void Reshape(const std::vector<int64_t>& dims) {
-    AT_ENFORCE_WITH_CALLER(
+    AT_ASSERTM(
         is_contiguous_,
         "Right now Reshape is only supported for contiguous Tensor.");
     int64_t new_size = 1;
     for (auto d : dims) {
-      AT_ENFORCE_GE_WITH_CALLER(d, 0);
+      AT_ASSERT(d >= 0);
       new_size *= d;
     }
-    AT_ENFORCE_WITH_CALLER(
+    AT_ASSERTM(
         new_size == numel_,
         "New size and old size are not equal. You cannot use Reshape, "
         "but should use Resize."
@@ -662,14 +660,13 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
     // Right now, we are assuming the device_type are the same, since it is
     // inherently the same in the non-templatized code. We should probably add
     // an ENFORCE here which might affect perf a little bit.
-    AT_ENFORCE_EQ_WITH_CALLER(
-        src.numel_,
-        numel_,
+    AT_ASSERTM(
+        src.numel_ == numel_,
         "Size mismatch - did you call reshape before sharing the data?");
     // It is possible that the source tensor hasn't called mutable_data() yet,
     // in which case ShareData() doesn't make much sense since we don't really
     // know what to share yet.
-    AT_ENFORCE_WITH_CALLER(
+    AT_ASSERTM(
         src.storage_.data() || src.numel_ == 0,
         "Source tensor has no content and has size > 0");
     // Finally, do sharing.
@@ -685,7 +682,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
       at::DataPtr&& data_ptr,
       const caffe2::TypeMeta& data_type,
       size_t capacity) {
-    AT_ENFORCE_WITH_CALLER(
+    AT_ASSERTM(
         data_type.id() != caffe2::TypeIdentifier::uninitialized(),
         "To share with a raw external pointer you need to pass in an "
         "initialized data_type(TypeMeta).");
@@ -693,7 +690,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
       capacity = numel_ * data_type.itemsize();
     }
     if (storage_.unique()) {
-      AT_ENFORCE_WITH_CALLER(
+      AT_ASSERTM(
           numel_ >= 0,
           "To share data with a raw pointer, you need to set shape first.");
       storage_.UniqueStorageShareExternalPointer(
@@ -725,7 +722,7 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
     if (data_type_ == meta && (storage_.data() || numel_ == 0)) {
       return static_cast<void*>(static_cast<char*>(storage_.data()) + storage_offset_ * meta.itemsize());
     } else {
-      AT_ENFORCE_WITH_CALLER(
+      AT_ASSERTM(
           numel_ >= 0,
           "Tensor is not initialized. You probably need to call Resize() "
           "before calling mutable_data()");
@@ -750,7 +747,8 @@ struct CAFFE2_API TensorImpl : public c10::intrusive_ptr_target {
       return storage_.data();
     }
     const at::Allocator* allocator = storage_.allocator();
-    AT_ENFORCE(
+    // TODO: Get rid of StaticContext
+    AT_ASSERTM(
         allocator == nullptr,
         "Allocator in storage_ is not used within Caffe2 functions. \
         we are using global function to get the allocator based on device \
@@ -42,7 +42,7 @@ inline TensorOptions Tensor::options() const {
 }

 inline void Tensor::backward(
-    at::optional<Tensor> gradient,
+    c10::optional<Tensor> gradient,
     bool keep_graph,
     bool create_graph) {
   type().backward(*this, std::move(gradient), keep_graph, create_graph);
@@ -4,8 +4,7 @@
 #include <ATen/core/Layout.h>
 #include <ATen/core/OptionsGuard.h>
 #include <ATen/core/ScalarType.h>
-#include <ATen/core/optional.h>
-#include <ATen/core/ScalarType.h>
+#include "c10/util/Optional.h"

 #include <iostream>

@@ -6,6 +6,8 @@
 #include <ATen/core/ScalarType.h>
 #include <ATen/core/DefaultTensorOptions.h>

+#include "c10/util/Optional.h"
+
 #include <cstddef>
 #include <iosfwd>
 #include <utility>
@@ -124,7 +126,7 @@ struct CAFFE2_API TensorOptions {
   /// (This overload ensures that initializer lists for Device work
   /// correctly.)
   C10_NODISCARD TensorOptions device(Device d) const noexcept {
-    return device(make_optional(d));
+    return device(c10::make_optional(d));
   }

   /// Return a copy of `TensorOptions`, but with device set to CUDA, and the
@@ -169,10 +171,10 @@ struct CAFFE2_API TensorOptions {
     return has_device_ ? device_ : getDefaultTensorOptions().device();
   }

-  /// Returns the device of the `TensorOptions`, or `nullopt` if
+  /// Returns the device of the `TensorOptions`, or `c10::nullopt` if
   /// device is not specified.
   optional<Device> device_opt() const noexcept {
-    return has_device_ ? make_optional(device_) : nullopt;
+    return has_device_ ? c10::make_optional(device_) : c10::nullopt;
   }

   /// Returns the device index of the `TensorOptions`.
@@ -185,10 +187,10 @@ struct CAFFE2_API TensorOptions {
     return has_dtype_ ? dtype_ : getDefaultTensorOptions().dtype();
   }

-  /// Returns the dtype of the `TensorOptions`, or `nullopt` if
+  /// Returns the dtype of the `TensorOptions`, or `c10::nullopt` if
   /// device is not specified.
   optional<ScalarType> dtype_opt() const noexcept {
-    return has_dtype_ ? make_optional(dtype_) : nullopt;
+    return has_dtype_ ? c10::make_optional(dtype_) : c10::nullopt;
   }

   /// Returns the layout of the `TensorOptions`.
@@ -196,10 +198,10 @@ struct CAFFE2_API TensorOptions {
     return has_layout_ ? layout_ : getDefaultTensorOptions().layout();
   }

-  /// Returns the layout of the `TensorOptions`, or `nullopt` if
+  /// Returns the layout of the `TensorOptions`, or `c10::nullopt` if
   /// layout is not specified.
   optional<Layout> layout_opt() const noexcept {
-    return has_layout_ ? make_optional(layout_) : nullopt;
+    return has_layout_ ? c10::make_optional(layout_) : c10::nullopt;
   }

   /// Returns the `requires_grad` property of the `TensorOptions`.
@@ -207,10 +209,11 @@ struct CAFFE2_API TensorOptions {
     return has_requires_grad_ ? requires_grad_ : getDefaultTensorOptions().requires_grad();
   }

-  /// Returns the `requires_grad` property of the `TensorOptions`, or `nullopt`
-  /// if `requires_grad` is not specified.
+  /// Returns the `requires_grad` property of the `TensorOptions`, or
+  /// `c10::nullopt` if `requires_grad` is not specified.
   optional<bool> requires_grad_opt() const noexcept {
-    return has_requires_grad_ ? make_optional(requires_grad_) : nullopt;
+    return has_requires_grad_ ? c10::make_optional(requires_grad_)
+                              : c10::nullopt;
   }

   /// Returns the `is_variable` property of the `TensorOptions`.
@@ -219,9 +222,9 @@ struct CAFFE2_API TensorOptions {
   }

   /// Returns the `is_variable` property of the `TensorOptions`, or
-  /// `nullopt` if `is_variable` is not specified.
+  /// `c10::nullopt` if `is_variable` is not specified.
   optional<bool> is_variable_opt() const noexcept {
-    return has_is_variable_ ? make_optional(is_variable_) : nullopt;
+    return has_is_variable_ ? c10::make_optional(is_variable_) : c10::nullopt;
   }

   // Resolves the ATen backend specified by the current construction axes.
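The `*_opt()` accessors touched above return `c10::optional`, so callers have to distinguish "not set" from a concrete value. A minimal sketch of the calling pattern, assuming a hypothetical `opts` value:

```cpp
#include <ATen/core/TensorOptions.h>

// Illustrative only: consuming the optional accessors shown in this hunk.
void describe(const at::TensorOptions& opts) {
  c10::optional<at::Device> dev = opts.device_opt();
  if (dev.has_value()) {
    at::Device d = dev.value();  // a device was explicitly requested
    (void)d;
  }
  // value_or() collapses the "not specified" case into a default.
  bool requires_grad = opts.requires_grad_opt().value_or(false);
  (void)requires_grad;
}
```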
@@ -302,7 +305,7 @@ struct CAFFE2_API TensorOptions {
   // WARNING: If you edit TensorOptions to add more options, you
   // must adjust the implementation of Tensor::options

-  // NB: We didn't use at::optional here, because then we can't pack
+  // NB: We didn't use c10::optional here, because then we can't pack
   // the has_***_ boolean fields.

   Device device_ = at::kCPU; // 64-bit (TODO: this should be 32-bit)
@@ -1,9 +1,7 @@
 #pragma once

 #include <iostream>
-#include <mutex>
 #include <string>
-#include <unordered_set>
 #include "ATen/core/IdWrapper.h"
 #include "ATen/core/Macros.h"

@@ -12,6 +12,7 @@
 #include "ATen/core/TensorTypeId.h"

 #include <atomic>
+#include <mutex>
 #include <unordered_set>

 namespace at {
@@ -14,6 +14,8 @@
 #include "ATen/core/Reduction.h"
 #include "ATen/core/TensorOptions.h"

+#include "c10/util/Optional.h"
+
 #include <array>
 #include <cstddef>
 #include <functional>
@@ -132,7 +134,11 @@ struct CAFFE2_API Type {
   virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const = 0;
   virtual Tensor & _s_copy_from(const Tensor & self, Tensor & dst, bool non_blocking) const = 0;

-  virtual void backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const = 0;
+  virtual void backward(
+      Tensor& self,
+      c10::optional<Tensor> gradient,
+      bool keep_graph,
+      bool create_graph) const = 0;
   virtual void set_data(Tensor & self, Tensor new_data) const = 0;

   virtual Tensor tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;
@@ -1,4 +1,6 @@
+#include "ATen/core/interned_strings.h"
 #include <cstdint>
+#include <cstring>
 #include <iostream>
 #include <mutex>
 #include <sstream>
@@ -6,10 +8,8 @@
 #include <unordered_map>
 #include <vector>
 #include "ATen/core/Error.h"
-#include "ATen/core/interned_strings.h"
 #include "ATen/core/interned_strings_class.h"
-#include "ATen/core/optional.h"
-#include <cstring>
+#include "c10/util/Optional.h"

 namespace torch { namespace jit {

@@ -7,7 +7,6 @@
 #include <vector>
 #include "ATen/core/Error.h"
 #include "ATen/core/interned_strings.h"
-#include "ATen/core/optional.h"
 #include <cstring>

 namespace torch {
@@ -1 +0,0 @@
-#include <ATen/core/optional.h>
@@ -25,6 +25,8 @@
 #include "ATen/core/IdWrapper.h"
 #include "ATen/core/Macros.h"

+#include "c10/util/Type.h"
+
 /*
  * TypeIdentifier is a small type containing an id.
  * Types must be registered using CAFFE_KNOWN_TYPE() for them to have a type id.
@@ -156,7 +158,7 @@ inline void _Ctor(void* ptr, size_t n) {
 template <typename T>
 inline void _CtorNotDefault(void* /*ptr*/, size_t /*n*/) {
   _ThrowRuntimeTypeLogicError(
-      "Type " + std::string(at::demangle_type<T>()) +
+      "Type " + std::string(c10::demangle_type<T>()) +
       " is not default-constructible.");
 }

@@ -206,7 +208,7 @@ inline void _Copy(const void* src, void* dst, size_t n) {
 template <typename T>
 inline void _CopyNotAllowed(const void* /*src*/, void* /*dst*/, size_t /*n*/) {
   _ThrowRuntimeTypeLogicError(
-      "Type " + std::string(at::demangle_type<T>()) +
+      "Type " + std::string(c10::demangle_type<T>()) +
       " does not allow assignment.");
 }

@@ -273,7 +275,7 @@ const char* _TypeName() noexcept {
   static const char* literal_name = __TypeName<T>();
 #ifdef __GXX_RTTI
   std::ignore = literal_name; // suppress unused warning
-  static const std::string name = at::demangle(typeid(T).name());
+  static const std::string name = c10::demangle(typeid(T).name());
   return name.c_str();
 #else
   return literal_name;
@@ -404,7 +404,10 @@ The behavior depends on the dimensionality of the Tensors as follows:
   must be broadcastable). For example, if tensor1 is a (j x 1 x n x m) Tensor
   and tensor2 is a (k x m x p) Tensor, the returned tensor will be an (j x k x n x p) Tensor.
 */
-Tensor matmul(at::optional<Tensor> out_opt, const Tensor& tensor1, const Tensor& tensor2) {
+Tensor matmul(
+    c10::optional<Tensor> out_opt,
+    const Tensor& tensor1,
+    const Tensor& tensor2) {
   auto dim_tensor1 = tensor1.dim();
   auto dim_tensor2 = tensor2.dim();
   auto has_out = out_opt.has_value();
@@ -486,15 +489,14 @@ Tensor matmul(at::optional<Tensor> out_opt, const Tensor& tensor1, const Tensor&

   AT_ERROR("both arguments to matmul need to be at least 1D, but they are ",
            dim_tensor1, "D and ", dim_tensor2, "D");
 }

 Tensor matmul(const Tensor & tensor1, const Tensor & tensor2) {
-  return at::native::matmul(at::nullopt, tensor1, tensor2);
+  return at::native::matmul(c10::nullopt, tensor1, tensor2);
 }

 Tensor& matmul_out(Tensor &result, const Tensor & tensor1, const Tensor & tensor2) {
-  at::native::matmul(at::optional<Tensor>(result), tensor1, tensor2);
+  at::native::matmul(c10::optional<Tensor>(result), tensor1, tensor2);
   return result;
 }
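The internal overload above threads an optional `out` tensor through a single implementation; public callers without a preallocated output simply pass `c10::nullopt`. The broadcasting behavior described in the doc comment can be seen with a small, hedged example (shapes are made up):

```cpp
#include <ATen/ATen.h>

// Illustrative only: the shape rule from the comment above.
// tensor1: (j x 1 x n x m), tensor2: (k x m x p) -> result: (j x k x n x p).
void matmul_broadcast_example() {
  at::Tensor a = at::randn({5, 1, 2, 3});  // j=5, n=2, m=3
  at::Tensor b = at::randn({4, 3, 6});     // k=4, p=6
  at::Tensor c = at::matmul(a, b);         // broadcasts to shape (5, 4, 2, 6)
  AT_ASSERT(c.dim() == 4 && c.size(0) == 5 && c.size(1) == 4);
}
```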
@@ -36,7 +36,7 @@ Tensor cumsum(const Tensor& self, int64_t dim, ScalarType dtype) {
 }

 Tensor cumsum(const Tensor& self, int64_t dim) {
-  return at::native::cumsum(self, dim, nullopt);
+  return at::native::cumsum(self, dim, c10::nullopt);
 }

 static inline Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
@@ -56,7 +56,7 @@ Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType d
 }

 Tensor& cumsum_out(Tensor& result, const Tensor& self, int64_t dim) {
-  return at::native::cumsum_out(result, self, dim, nullopt);
+  return at::native::cumsum_out(result, self, dim, c10::nullopt);
 }

 static inline Tensor cumprod(const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
@@ -68,7 +68,7 @@ Tensor cumprod(const Tensor& self, int64_t dim, ScalarType dtype) {
 }

 Tensor cumprod(const Tensor& self, int64_t dim) {
-  return at::native::cumprod(self, dim, nullopt);
+  return at::native::cumprod(self, dim, c10::nullopt);
 }

 static inline Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim, optional<ScalarType> dtype) {
@@ -88,7 +88,7 @@ Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType
 }

 Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim) {
-  return at::native::cumprod_out(result, self, dim, nullopt);
+  return at::native::cumprod_out(result, self, dim, c10::nullopt);
 }

 // ALL REDUCE #################################################################
@@ -113,7 +113,7 @@ Tensor mean(const Tensor &self, ScalarType dtype) {
 }

 Tensor mean(const Tensor &self) {
-  return at::native::mean(self, nullopt);
+  return at::native::mean(self, c10::nullopt);
 }

 static inline Tensor sum(const Tensor &self, optional<ScalarType> dtype) {
@@ -125,13 +125,13 @@ Tensor sum(const Tensor &self, ScalarType dtype) {
 }

 Tensor sum(const Tensor &self) {
-  return at::native::sum(self, nullopt);
+  return at::native::sum(self, c10::nullopt);
 }

 Tensor _sum_cpu(const Tensor& self) {
   if (self.is_contiguous()) {
     Tensor result = at::empty({}, self.type());
-    sum_kernel(kCPU, result, self, at::nullopt);
+    sum_kernel(kCPU, result, self, c10::nullopt);
     return result;
   }
   return at::_sumall(self);
@@ -146,13 +146,13 @@ Tensor prod(const Tensor &self, ScalarType dtype) {
 }

 Tensor prod(const Tensor &self) {
-  return at::native::prod(self, nullopt);
+  return at::native::prod(self, c10::nullopt);
 }

 Tensor _prod_cpu(const Tensor &self) {
   if (self.is_contiguous()) {
     Tensor result = at::empty({}, self.type());
-    prod_kernel(kCPU, result, self, at::nullopt);
+    prod_kernel(kCPU, result, self, c10::nullopt);
     return result;
   }
   return at::_prodall(self);
@@ -185,10 +185,11 @@ static inline Tensor &mean_out(Tensor &result, const Tensor &self, int64_t dim,
 }

 Tensor& mean_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
-  return at::native::mean_out(result, self, dim, keepdim, at::optional<ScalarType>(dtype));
+  return at::native::mean_out(
+      result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
 }
 Tensor& mean_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
-  return at::native::mean_out(result, self, dim, keepdim, nullopt);
+  return at::native::mean_out(result, self, dim, keepdim, c10::nullopt);
 }

 Tensor& mean_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
@@ -209,10 +210,11 @@ static inline Tensor &sum_out(Tensor &result, const Tensor &self, IntList dim,
 }

 Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
-  return at::native::sum_out(result, self, dim, keepdim, at::optional<ScalarType>(dtype));
+  return at::native::sum_out(
+      result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
 }
 Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, bool keepdim) {
-  return at::native::sum_out(result, self, dim, keepdim, nullopt);
+  return at::native::sum_out(result, self, dim, keepdim, c10::nullopt);
 }

 Tensor& sum_out(Tensor& result, const Tensor& self, IntList dim, ScalarType dtype) {
@@ -247,10 +249,11 @@ static inline Tensor &prod_out(Tensor &result, const Tensor &self, int64_t dim,
 }

 Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
-  return at::native::prod_out(result, self, dim, keepdim, at::optional<ScalarType>(dtype));
+  return at::native::prod_out(
+      result, self, dim, keepdim, c10::optional<ScalarType>(dtype));
 }
 Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, bool keepdim) {
-  return at::native::prod_out(result, self, dim, keepdim, nullopt);
+  return at::native::prod_out(result, self, dim, keepdim, c10::nullopt);
 }

 Tensor& prod_out(Tensor& result, const Tensor& self, int64_t dim, ScalarType dtype) {
@@ -292,11 +295,11 @@ static inline Tensor mean(const Tensor &self, int64_t dim, bool keepdim, optiona
 }

 Tensor mean(const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
-  return at::native::mean(self, dim, keepdim, at::optional<ScalarType>(dtype));
+  return at::native::mean(self, dim, keepdim, c10::optional<ScalarType>(dtype));
 }

 Tensor mean(const Tensor& self, int64_t dim, bool keepdim) {
-  return at::native::mean(self, dim, keepdim, nullopt);
+  return at::native::mean(self, dim, keepdim, c10::nullopt);
 }

 Tensor mean(const Tensor& self, int64_t dim, ScalarType dtype) {
@@ -308,11 +311,11 @@ static inline Tensor sum(const Tensor &self, IntList dim_, bool keepdim, optiona
 }

 Tensor sum(const Tensor& self, IntList dim, bool keepdim, ScalarType dtype) {
-  return at::native::sum(self, dim, keepdim, at::optional<ScalarType>(dtype));
+  return at::native::sum(self, dim, keepdim, c10::optional<ScalarType>(dtype));
 }

 Tensor sum(const Tensor& self, IntList dim, bool keepdim) {
-  return at::native::sum(self, dim, keepdim, nullopt);
+  return at::native::sum(self, dim, keepdim, c10::nullopt);
 }

 Tensor sum(const Tensor& self, IntList dim, ScalarType dtype) {
@@ -330,11 +333,11 @@ static inline Tensor prod(const Tensor &self, int64_t dim_, bool keepdim, option
 }

 Tensor prod(const Tensor& self, int64_t dim, bool keepdim, ScalarType dtype) {
-  return at::native::prod(self, dim, keepdim, at::optional<ScalarType>(dtype));
+  return at::native::prod(self, dim, keepdim, c10::optional<ScalarType>(dtype));
 }

 Tensor prod(const Tensor& self, int64_t dim, bool keepdim) {
-  return at::native::prod(self, dim, keepdim, nullopt);
+  return at::native::prod(self, dim, keepdim, c10::nullopt);
 }

 Tensor prod(const Tensor& self, int64_t dim, ScalarType dtype) {
@ -629,7 +632,7 @@ Tensor _norm(const Tensor &self, Scalar p) {
|
||||||
} else {
|
} else {
|
||||||
if (self.is_contiguous()) {
|
if (self.is_contiguous()) {
|
||||||
Tensor result = CPU(kFloat).scalarTensor(0).toType(self.type());
|
Tensor result = CPU(kFloat).scalarTensor(0).toType(self.type());
|
||||||
norm_kernel(kCPU, result, self, p, nullopt);
|
norm_kernel(kCPU, result, self, p, c10::nullopt);
|
||||||
return result;
|
return result;
|
||||||
} else {
|
} else {
|
||||||
return at::th_norm(self, p);
|
return at::th_norm(self, p);
|
||||||
|
|
|
||||||
|
|
@ -44,12 +44,13 @@ static bool _dimreduce_return_trivial_no_ident(Tensor &result, const Tensor &sel
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
static at::optional<Tensor> _allreduce_return_trivial(const Tensor &self, Scalar ident) {
|
static c10::optional<Tensor> _allreduce_return_trivial(
|
||||||
|
const Tensor& self,
|
||||||
|
Scalar ident) {
|
||||||
// Return identity
|
// Return identity
|
||||||
if (self.numel() == 0) {
|
if (self.numel() == 0) {
|
||||||
return self.type().scalarTensor(ident);
|
return self.type().scalarTensor(ident);
|
||||||
}
|
}
|
||||||
return at::nullopt;
|
return c10::nullopt;
|
||||||
}
|
}
|
||||||
|
|
||||||
}} // at::native
|
}} // at::native
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,8 @@
|
||||||
|
|
||||||
#include <ATen/ATen.h>
|
#include <ATen/ATen.h>
|
||||||
#include <ATen/core/SmallVector.h>
|
#include <ATen/core/SmallVector.h>
|
||||||
#include <ATen/core/optional.h>
|
|
||||||
#include <ATen/detail/ScalarTypeConversions.h>
|
#include <ATen/detail/ScalarTypeConversions.h>
|
||||||
|
#include "c10/util/Optional.h"
|
||||||
|
|
||||||
// TensorIterator is a helper class for element-wise operations, such as
|
// TensorIterator is a helper class for element-wise operations, such as
|
||||||
// arithmetic, comparisions, and trigonometric functions. It handles
|
// arithmetic, comparisions, and trigonometric functions. It handles
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,14 @@
|
||||||
|
#include <ATen/native/sparse/SparseUtils.h>
|
||||||
#include <TH/THTensor.hpp>
|
#include <TH/THTensor.hpp>
|
||||||
|
#include <algorithm>
|
||||||
|
#include <vector>
|
||||||
#include "ATen/ATen.h"
|
#include "ATen/ATen.h"
|
||||||
#include "ATen/ExpandUtils.h"
|
#include "ATen/ExpandUtils.h"
|
||||||
#include "ATen/InferSize.h"
|
#include "ATen/InferSize.h"
|
||||||
#include "ATen/NativeFunctions.h"
|
#include "ATen/NativeFunctions.h"
|
||||||
#include "ATen/WrapDimUtils.h"
|
#include "ATen/WrapDimUtils.h"
|
||||||
#include "ATen/core/Error.h"
|
#include "ATen/core/Error.h"
|
||||||
#include "ATen/core/optional.h"
|
#include "c10/util/Optional.h"
|
||||||
#include <ATen/native/sparse/SparseUtils.h>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
namespace at {
|
namespace at {
|
||||||
namespace native {
|
namespace native {
|
||||||
|
|
|
||||||
|
|
@@ -6,8 +6,8 @@
 
 #include "ATen/Dispatch.h"
 #include "ATen/Parallel.h"
-#include "ATen/core/optional.h"
 #include "ATen/cpu/vec256/vec256.h"
+#include "c10/util/Optional.h"
 
 namespace at { namespace native { namespace {
 
@@ -46,7 +46,10 @@ struct Reduction {
   using Reduce = Op<Vec>;
   using ReduceScalar = Op<scalar_t>;
 
-  static void apply(Tensor& res, const Tensor& self, at::optional<int64_t> dim) {
+  static void apply(
+      Tensor& res,
+      const Tensor& self,
+      c10::optional<int64_t> dim) {
     auto out_ = res.data<scalar_t>();
     auto data_ = self.data<scalar_t>();
     auto numel = self.numel();
@@ -171,13 +174,19 @@ struct Reduction {
   }
 };
 
-static void sum_kernel_impl(Tensor& result, const Tensor& self, at::optional<int64_t> dim) {
+static void sum_kernel_impl(
+    Tensor& result,
+    const Tensor& self,
+    c10::optional<int64_t> dim) {
   AT_DISPATCH_ALL_TYPES(self.type(), "sum", [&] {
     Reduction<scalar_t, std::plus, 0>::apply(result, self, dim);
   });
 }
 
-static void prod_kernel_impl(Tensor& result, const Tensor& self, at::optional<int64_t> dim) {
+static void prod_kernel_impl(
+    Tensor& result,
+    const Tensor& self,
+    c10::optional<int64_t> dim) {
   AT_DISPATCH_ALL_TYPES(self.type(), "prod", [&] {
     Reduction<scalar_t, std::multiplies, 1>::apply(result, self, dim);
   });
@@ -189,7 +198,11 @@ struct NormReduction {
   static constexpr int WIDTH = 128 / sizeof(scalar_t);
   using Vec = Vec256<scalar_t>;
 
-  static void apply(Tensor& res, const Tensor& self, Scalar p, at::optional<int64_t> dim) {
+  static void apply(
+      Tensor& res,
+      const Tensor& self,
+      Scalar p,
+      c10::optional<int64_t> dim) {
     auto out_ = res.data<scalar_t>();
     auto data_ = self.data<scalar_t>();
     auto numel = self.numel();
@@ -330,7 +343,11 @@ struct NormReduction {
   }
 };
 
-static void norm_kernel_impl(Tensor& result, const Tensor& self, Scalar p, at::optional<int64_t> dim) {
+static void norm_kernel_impl(
+    Tensor& result,
+    const Tensor& self,
+    Scalar p,
+    c10::optional<int64_t> dim) {
   AT_DISPATCH_FLOATING_TYPES(self.type(), "norm", [&] {
     NormReduction<scalar_t>::apply(result, self, p, dim);
   });
@@ -1,17 +1,18 @@
 #pragma once
 
 #include <ATen/ATen.h>
-#include <ATen/core/optional.h>
 #include <ATen/native/DispatchStub.h>
+#include "c10/util/Optional.h"
 
 namespace at { namespace native {
 
-using reduce_fn = void(*)(Tensor &, const Tensor &, at::optional<int64_t>);
+using reduce_fn = void (*)(Tensor&, const Tensor&, c10::optional<int64_t>);
 
 DECLARE_DISPATCH(reduce_fn, sum_kernel);
 DECLARE_DISPATCH(reduce_fn, prod_kernel);
 
-using reduce_norm_fn = void(*)(Tensor &, const Tensor &, Scalar, at::optional<int64_t>);
+using reduce_norm_fn =
+    void (*)(Tensor&, const Tensor&, Scalar, c10::optional<int64_t>);
 DECLARE_DISPATCH(reduce_norm_fn, norm_kernel);
 
 }} // namespace at::native

@@ -6,9 +6,9 @@
 
 #include "ATen/Dispatch.h"
 #include "ATen/Parallel.h"
-#include "ATen/core/optional.h"
 #include "ATen/cpu/vec256/functional.h"
 #include "ATen/cpu/vec256/vec256.h"
+#include "c10/util/Optional.h"
 
 // [Note AVX-SSE transitions] In general we avoid calls into cmath for code
 // compiled with AVX/AVX2 This is because of SSE-AVX transitions and a bug in
@@ -27,7 +27,12 @@ bool _isnan(double val) {
 
 template <typename scalar_t, typename index_t>
 struct Reduction {
-  static void apply(Tensor& res, Tensor& res_indices, const Tensor& self, at::optional<int64_t> dim, bool greater) {
+  static void apply(
+      Tensor& res,
+      Tensor& res_indices,
+      const Tensor& self,
+      c10::optional<int64_t> dim,
+      bool greater) {
     auto out_ = res.data<scalar_t>();
     auto indices_ = res_indices.data<index_t>();
     auto data_ = self.data<scalar_t>();
@@ -87,13 +92,21 @@ struct Reduction {
   }
 };
 
-static void max_kernel_impl(Tensor& max, Tensor& max_indices, const Tensor& self, at::optional<int64_t> dim) {
+static void max_kernel_impl(
+    Tensor& max,
+    Tensor& max_indices,
+    const Tensor& self,
+    c10::optional<int64_t> dim) {
  AT_DISPATCH_ALL_TYPES(self.type(), "max", [&] {
    Reduction<scalar_t, int64_t>::apply(max, max_indices, self, dim, true);
  });
 }
 
-static void min_kernel_impl(Tensor& min, Tensor& min_indices, const Tensor& self, at::optional<int64_t> dim) {
+static void min_kernel_impl(
+    Tensor& min,
+    Tensor& min_indices,
+    const Tensor& self,
+    c10::optional<int64_t> dim) {
  AT_DISPATCH_ALL_TYPES(self.type(), "min", [&] {
    Reduction<scalar_t, int64_t>::apply(min, min_indices, self, dim, false);
  });

@@ -6,7 +6,8 @@
 
 namespace at { namespace native {
 
-using reduce_fn = void(*)(Tensor &, Tensor &, const Tensor &, at::optional<int64_t>);
+using reduce_fn =
+    void (*)(Tensor&, Tensor&, const Tensor&, c10::optional<int64_t>);
 
 DECLARE_DISPATCH(reduce_fn, max_kernel);
 DECLARE_DISPATCH(reduce_fn, min_kernel);
@@ -146,7 +146,8 @@ public:
   // See NOTE [ cuFFT Embedded Strides ].
   //
   // TODO: Figure out why windows fails to compile
-  // at::optional<std::vector<long long int>> inembed_opt = at::nullopt;
+  // c10::optional<std::vector<long long int>> inembed_opt =
+  //     c10::nullopt;
   // Then move the following to a helper function.
 #ifdef __HIP_PLATFORM_HCC__
   std::vector<int> inembed(signal_ndim);

@@ -1105,7 +1105,7 @@ struct DropoutState {
   // for the first time. Note that in this case needed != used, as we don't need
   // a bufer to e.g. run RNNs in test mode.
   at::Tensor buffer;
-  at::optional<cuda::CUDAEvent> event;
+  c10::optional<cuda::CUDAEvent> event;
   std::mutex mutex;
 
   // Every time we use a dropout state, we need to synchronize with its event,

@@ -22,7 +22,7 @@ def parse_default(s):
         return '{}'
     elif re.match(r'{.*}', s):
         return s
-    elif s == 'nullopt':
+    elif s == 'c10::nullopt':
         return s
     try:
         return int(s)

@@ -1 +1 @@
-#include <ATen/core/optional.h>
+#include "c10/util/Optional.h"
@@ -11,11 +11,11 @@
 #include "ATen/Allocator.h"
 #include "ATen/DeviceGuard.h"
 #include "ATen/NativeFunctions.h"
-#include "ATen/core/UndefinedTensorImpl.h"
 #include "ATen/Utils.h"
 #include "ATen/WrapDimUtils.h"
 #include "ATen/core/Half.h"
-#include "ATen/core/optional.h"
+#include "ATen/core/UndefinedTensorImpl.h"
+#include "c10/util/Optional.h"
 
 #include <cstddef>
 #include <functional>

@@ -1,6 +1,7 @@
 #pragma once
 
 #include "ATen/core/Device.h"
+#include "ATen/core/Error.h"
 #include "ATen/core/Layout.h"
 #include "ATen/core/Scalar.h"
 #include "ATen/core/ScalarType.h"
@@ -8,9 +9,8 @@
 #include "ATen/core/Storage.h"
 #include "ATen/core/TensorAccessor.h"
 #include "ATen/core/TensorImpl.h"
-#include "ATen/core/optional.h"
 #include "ATen/core/UndefinedTensorImpl.h"
-#include "ATen/core/Error.h"
+#include "c10/util/Optional.h"
 
 namespace at {
 struct Generator;
@@ -241,7 +241,7 @@ public:
 
   /// Computes the gradient of current tensor w.r.t. graph leaves.
   void backward(
-      at::optional<Tensor> gradient = at::nullopt,
+      c10::optional<Tensor> gradient = c10::nullopt,
       bool keep_graph = false,
       bool create_graph = false);
 
@@ -267,7 +267,7 @@ struct CAFFE2_API WeakTensor {
   WeakTensor(const Tensor& t) : weak_impl_(t.impl_) {}
 
   // XXX: this can return undefined tensors
-  // Ideally it would be at::optional<Tensor>, but MSVC is too cool for that
+  // Ideally it would be c10::optional<Tensor>, but MSVC is too cool for that
   Tensor lock() const {
     return Tensor(weak_impl_.lock());
   }

@@ -42,7 +42,7 @@ inline TensorOptions Tensor::options() const {
 }
 
 inline void Tensor::backward(
-    at::optional<Tensor> gradient,
+    c10::optional<Tensor> gradient,
     bool keep_graph,
     bool create_graph) {
   type().backward(*this, std::move(gradient), keep_graph, create_graph);
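Call sites now spell the optional gradient in the c10 namespace. A minimal sketch of the migrated signature in use (variable names are illustrative, and per the TypeDefault hunk below a plain ATen tensor still dispatches to a stub that raises AT_ERROR):

// Illustrative only: the migrated backward() signature as seen by callers.
#include <ATen/ATen.h>

void run_backward(at::Tensor loss, at::Tensor grad) {
  loss.backward();  // gradient defaults to c10::nullopt
  loss.backward(c10::optional<at::Tensor>(grad),
                /*keep_graph=*/false,
                /*create_graph=*/false);
}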
@@ -14,6 +14,8 @@
 #include "ATen/core/Reduction.h"
 #include "ATen/core/TensorOptions.h"
 
+#include "c10/util/Optional.h"
 
 #include <array>
 #include <cstddef>
 #include <functional>
@@ -103,7 +105,11 @@ struct CAFFE2_API Type {
   virtual Tensor & s_copy_(Tensor & self, const Tensor & src, bool non_blocking) const = 0;
   virtual Tensor & _s_copy_from(const Tensor & self, Tensor & dst, bool non_blocking) const = 0;
 
-  virtual void backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const = 0;
+  virtual void backward(
+      Tensor& self,
+      c10::optional<Tensor> gradient,
+      bool keep_graph,
+      bool create_graph) const = 0;
   virtual void set_data(Tensor & self, Tensor new_data) const = 0;
 
   virtual Tensor tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter=noop_deleter) const = 0;

@@ -36,7 +36,11 @@ Tensor TypeDefault::copy(const Tensor & src, bool non_blocking, optional<Device>
   return r;
 }
 
-void TypeDefault::backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const {
+void TypeDefault::backward(
+    Tensor& self,
+    c10::optional<Tensor> gradient,
+    bool keep_graph,
+    bool create_graph) const {
   AT_ERROR("backward is not implemented for Tensor");
 }
 

@@ -28,7 +28,11 @@ struct CAFFE2_API TypeDefault : public TypeExtendedInterface {
   Tensor copy(const Tensor & src, bool non_blocking=false, optional<Device> to_device={}) const override;
   Tensor & copy_(Tensor & self, const Tensor & src, bool non_blocking=false) const override;
 
-  void backward(Tensor & self, at::optional<Tensor> gradient, bool keep_graph, bool create_graph) const override;
+  void backward(
+      Tensor& self,
+      c10::optional<Tensor> gradient,
+      bool keep_graph,
+      bool create_graph) const override;
   void set_data(Tensor & self, Tensor new_data) const override;
 
   Tensor tensorFromBlob(void * data, IntList sizes, const std::function<void(void*)> & deleter=noop_deleter) const override;
@@ -10,15 +10,15 @@
 $th_headers
 $storage_tensor_headers
 #include "ATen/${Generator}.h"
-#include "ATen/core/TensorImpl.h"
 #include "ATen/Allocator.h"
 #include "ATen/DeviceGuard.h"
 #include "ATen/NativeFunctions.h"
-#include "ATen/core/UndefinedTensorImpl.h"
 #include "ATen/Utils.h"
 #include "ATen/WrapDimUtils.h"
 #include "ATen/core/Half.h"
-#include "ATen/core/optional.h"
+#include "ATen/core/TensorImpl.h"
+#include "ATen/core/UndefinedTensorImpl.h"
+#include "c10/util/Optional.h"
 
 #include <cstddef>
 #include <functional>

@@ -9,8 +9,8 @@ using namespace at;
 
 // optional in cuda files
 TEST(OptionalTest, OptionalTestCUDA) {
-  at::optional<int64_t> trivially_destructible;
-  at::optional<std::vector<int64_t>> non_trivially_destructible;
+  c10::optional<int64_t> trivially_destructible;
+  c10::optional<std::vector<int64_t>> non_trivially_destructible;
   ASSERT_FALSE(trivially_destructible.has_value());
   ASSERT_FALSE(non_trivially_destructible.has_value());
@@ -6,4 +6,8 @@
 
 SET(TH_FOUND 1)
 SET(TH_INCLUDE_DIR "@TH_INCLUDE_DIR@")
-SET(TH_LIBRARIES "@TH_LIBRARIES@")
+# TODO: TH right now uses old-style cmake targets, and due to
+# transitive dependency, now libraries such as libshm depend
+# on C10 as well. As a result, we manually add that TH_LIBRARIES
+# will contain C10 as well.
+SET(TH_LIBRARIES "@TH_LIBRARIES@;@C10_LIBRARIES")
@@ -137,8 +137,10 @@ void THTensor_resizeNd(THTensor *self, int nDimension, const int64_t *size, cons
 // 2. newshape must be able to be separated into same number of chunks as oldshape was separated into,
 //    where each chunk of newshape has matching ``numel'', i.e., number of subspaces,
 //    as the corresponding chunk of oldshape.
-at::optional<std::vector<int64_t>>
-THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride, at::IntList newshape) {
+c10::optional<std::vector<int64_t>> THTensor_compute_stride(
+    at::IntList oldshape,
+    at::IntList oldstride,
+    at::IntList newshape) {
   if (oldshape.empty()) {
     return std::vector<int64_t>(newshape.size(), 1);
   }
@@ -182,7 +184,7 @@ THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride, at::IntList
       view_d--;
     }
     if (view_numel != tensor_numel) {
-      return at::nullopt;
+      return c10::nullopt;
     }
     if (tensor_d > 0) {
       chunk_base_stride = oldstride[tensor_d - 1];
@@ -192,7 +194,7 @@ THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride, at::IntList
     }
   }
   if (view_d != -1) {
-    return at::nullopt;
+    return c10::nullopt;
   }
   return newstride;
 }

@@ -110,8 +110,10 @@ TH_API void THTensor_resizeNd(THTensor *self, int nDimension, const int64_t *siz
 
 TH_CPP_API void THTensor_resize(THTensor *self, at::IntList size, at::IntList stride);
 TH_CPP_API void THTensor_setStorage(THTensor *self, THStorage *storage_, ptrdiff_t storageOffset_, at::IntList size_, at::IntList stride_);
-TH_CPP_API at::optional<std::vector<int64_t>> THTensor_compute_stride(at::IntList oldshape, at::IntList oldstride,
-                                                                      at::IntList newshape);
+TH_CPP_API c10::optional<std::vector<int64_t>> THTensor_compute_stride(
+    at::IntList oldshape,
+    at::IntList oldstride,
+    at::IntList newshape);
 
 #include "generic/THTensor.hpp"
 #include "THGenerateAllTypes.h"
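THTensor_compute_stride keeps its optional return; only the namespace changes. A hedged sketch of a caller (not part of this diff):

// Illustrative only: consuming the c10::optional result of the stride computation.
bool can_view(at::IntList oldshape, at::IntList oldstride, at::IntList newshape) {
  c10::optional<std::vector<int64_t>> maybe_stride =
      THTensor_compute_stride(oldshape, oldstride, newshape);
  return maybe_stride.has_value();  // c10::nullopt means a copy is required
}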
@@ -53,7 +53,7 @@
 // building, and the instruction is as follows: assuming that you are building
 // a library called libawesome.so. You should:
 // (1) for your cmake target (usually done by "add_library(awesome, ...)"),
-//     define a macro called AWESOME_BUILD_MAIN_DLL using
+//     define a macro called AWESOME_BUILD_MAIN_LIB using
 //     target_compile_options.
 // (2) define the AWESOME_API macro similar to the one below.
 //     And in the source file of your awesome library, use AWESOME_API to
@@ -1,5 +1,6 @@
-#include <ATen/core/Backtrace.h>
-#include <ATen/core/optional.h>
+#include "c10/util/Backtrace.h"
+#include "c10/util/Optional.h"
+#include "c10/util/Type.h"
 
 #include <functional>
 #include <memory>
@@ -7,62 +8,20 @@
 #include <string>
 #include <vector>
 
-#if defined(__ANDROID__)
-#define AT_CORE_MOBILE 1
-#elif ( \
-    defined(__APPLE__) && \
-    (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE))
-#define AT_CORE_MOBILE 1
-#else
-#define AT_CORE_MOBILE 0
-#endif
-
-#if !AT_CORE_MOBILE && !defined(_WIN32) && !defined(__EMSCRIPTEN__)
-#define SUPPORTS_BACKTRACE 1
-#else
+#if (defined(__ANDROID__)) || \
+    (defined(__APPLE__) && \
+     (TARGET_IPHONE_SIMULATOR || TARGET_OS_SIMULATOR || TARGET_OS_IPHONE)) || \
+    defined(_WIN32) || defined(__EMSCRIPTEN__)
+// No backtrace on mobile, windows and emscripten platforms.
 #define SUPPORTS_BACKTRACE 0
-#endif
+#else
+#define SUPPORTS_BACKTRACE 1
 
-#if SUPPORTS_BACKTRACE
 #include <cxxabi.h>
 #include <execinfo.h>
-#endif // !defined(_WIN32)
 
-namespace at {
-
-#if SUPPORTS_BACKTRACE
-std::string demangle(const char* name) {
-  int status = -1;
-
-  // This function will demangle the mangled function name into a more human
-  // readable format, e.g. _Z1gv -> g().
-  // More information:
-  // https://github.com/gcc-mirror/gcc/blob/master/libstdc%2B%2B-v3/libsupc%2B%2B/cxxabi.h
-  // NOTE: `__cxa_demangle` returns a malloc'd string that we have to free
-  // ourselves.
-  std::unique_ptr<char, std::function<void(char*)>> demangled(
-      abi::__cxa_demangle(
-          name,
-          /*__output_buffer=*/nullptr,
-          /*__length=*/0,
-          &status),
-      /*deleter=*/free);
-
-  // Demangling may fail, for example when the name does not follow the
-  // standard C++ (Itanium ABI) mangling scheme. This is the case for `main`
-  // or `clone` for example, so the mangled name is a fine default.
-  if (status == 0) {
-    return demangled.get();
-  } else {
-    return name;
-  }
-}
-#else
-std::string demangle(const char* name) {
-  return std::string(name);
-}
 #endif
 
+namespace c10 {
 
 // TODO: This backtrace retrieval can be implemented on Windows via the Windows
 // API using `CaptureStackBackTrace` and `SymFromAddr`.
 // https://stackoverflow.com/questions/5693192/win32-backtrace-from-c-code
@@ -86,12 +45,11 @@ struct FrameInformation {
 };
 
 bool is_python_frame(const FrameInformation& frame) {
-  return frame.object_file == "python" ||
-      frame.object_file == "python3" ||
+  return frame.object_file == "python" || frame.object_file == "python3" ||
       (frame.object_file.find("libpython") != std::string::npos);
 }
 
-at::optional<FrameInformation> parse_frame_information(
+c10::optional<FrameInformation> parse_frame_information(
     const std::string& frame_string) {
   FrameInformation frame;
 
@@ -107,19 +65,19 @@ at::optional<FrameInformation> parse_frame_information(
 
   auto function_name_start = frame_string.find("(");
   if (function_name_start == std::string::npos) {
-    return at::nullopt;
+    return c10::nullopt;
   }
   function_name_start += 1;
 
   auto offset_start = frame_string.find('+', function_name_start);
   if (offset_start == std::string::npos) {
-    return at::nullopt;
+    return c10::nullopt;
   }
   offset_start += 1;
 
   const auto offset_end = frame_string.find(')', offset_start);
   if (offset_end == std::string::npos) {
-    return at::nullopt;
+    return c10::nullopt;
   }
 
   frame.object_file = frame_string.substr(0, function_name_start - 1);
@@ -143,7 +101,7 @@ at::optional<FrameInformation> parse_frame_information(
   skip >> frame.offset_into_function;
 #else
 #warning Unknown standard library, backtraces may have incomplete debug information
-  return at::nullopt;
+  return c10::nullopt;
 #endif // defined(__GLIBCXX__)
 
   // Some system-level functions don't have sufficient debug information, so
@@ -241,4 +199,4 @@ std::string get_backtrace(
 #endif // SUPPORTS_BACKTRACE
 }
 
-} // namespace at
+} // namespace c10

c10/util/Backtrace.h (new file, 17 lines)
@@ -0,0 +1,17 @@
+#ifndef C10_UTIL_BACKTRACE_H_
+#define C10_UTIL_BACKTRACE_H_
+
+#include <cstddef>
+#include <string>
+#include <typeinfo>
+
+#include "c10/macros/Macros.h"
+
+namespace c10 {
+C10_API std::string get_backtrace(
+    size_t frames_to_skip = 0,
+    size_t maximum_number_of_frames = 64,
+    bool skip_python_frames = true);
+} // namespace c10
+
+#endif // C10_UTIL_BACKTRACE_H_
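The new header exposes the backtrace helper under the c10 namespace; a hedged usage sketch (not part of this diff):

// Illustrative only: fetching a backtrace through the relocated c10 helper.
#include "c10/util/Backtrace.h"
#include <iostream>

void dump_backtrace() {
  // Skip this frame itself; the other arguments keep the defaults declared above.
  std::cout << c10::get_backtrace(/*frames_to_skip=*/1) << std::endl;
}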
@@ -1,30 +1,12 @@
-#include <ATen/core/Error.h>
-#include <ATen/core/Backtrace.h>
+#include "c10/util/Exception.h"
+#include "c10/util/Backtrace.h"
+#include "c10/util/Type.h"
 
 #include <iostream>
 #include <numeric>
 #include <string>
 
-namespace at {
+namespace c10 {
 
-namespace detail {
-
-std::string StripBasename(const std::string& full_path) {
-  const char kSeparator = '/';
-  size_t pos = full_path.rfind(kSeparator);
-  if (pos != std::string::npos) {
-    return full_path.substr(pos + 1, std::string::npos);
-  } else {
-    return full_path;
-  }
-}
-
-} // namespace detail
-
-std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
-  out << loc.function << " at " << loc.file << ":" << loc.line;
-  return out;
-}
-
 Error::Error(
     const std::string& new_msg,
@@ -102,26 +84,10 @@ Warning::handler_t Warning::warning_handler_ = &Warning::print_warning;
 
 std::string GetExceptionString(const std::exception& e) {
 #ifdef __GXX_RTTI
-  return at::demangle(typeid(e).name()) + ": " + e.what();
+  return demangle(typeid(e).name()) + ": " + e.what();
 #else
   return std::string("Exception (no RTTI available): ") + e.what();
 #endif // __GXX_RTTI
 }
 
-std::function<std::string(void)>* GetFetchStackTrace() {
-  static std::function<std::string(void)> func = []() { return ""; };
-  return &func;
-};
-
-void ThrowEnforceNotMet(
-    const char* file,
-    const int line,
-    const char* condition,
-    const std::string& msg,
-    const void* caller)
-{
-  at::Error e(file, line, condition, msg, (*GetFetchStackTrace())(), caller);
-  throw e;
-}
-
-} // namespace at
+} // namespace c10
c10/util/Exception.h (new file, 153 lines)
@@ -0,0 +1,153 @@
+#ifndef C10_UTIL_EXCEPTION_H_
+#define C10_UTIL_EXCEPTION_H_
+
+#include "c10/macros/Macros.h"
+#include "c10/util/StringUtil.h"
+
+#include <cstddef>
+#include <exception>
+#include <ostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+#if defined(_MSC_VER) && _MSC_VER <= 1900
+#define __func__ __FUNCTION__
+#endif
+
+namespace c10 {
+
+/// The primary ATen error class.
+/// Provides a complete error message with source location information via
+/// `what()`, and a more concise message via `what_without_backtrace()`. Should
+/// primarily be used with the `AT_ERROR` macro.
+///
+/// NB: c10::Error is handled specially by the default torch to suppress the
+/// backtrace, see torch/csrc/Exceptions.h
+class C10_API Error : public std::exception {
+  std::vector<std::string> msg_stack_;
+  std::string backtrace_;
+
+  // These two are derived fields from msg_stack_ and backtrace_, but we need
+  // fields for the strings so that we can return a const char* (as the
+  // signature of std::exception requires).
+  std::string msg_;
+  std::string msg_without_backtrace_;
+
+  // This is a little debugging trick: you can stash a relevant pointer
+  // in caller, and then when you catch the exception, you can compare
+  // against pointers you have on hand to get more information about
+  // where the exception came from. In Caffe2, this is used to figure
+  // out which operator raised an exception.
+  const void* caller_;
+
+ public:
+  Error(
+      const std::string& msg,
+      const std::string& backtrace,
+      const void* caller = nullptr);
+  Error(SourceLocation source_location, const std::string& msg);
+  Error(
+      const char* file,
+      const int line,
+      const char* condition,
+      const std::string& msg,
+      const std::string& backtrace,
+      const void* caller = nullptr);
+
+  void AppendMessage(const std::string& msg);
+
+  // Compute the full message from msg_ and msg_without_backtrace_
+  // TODO: Maybe this should be private
+  std::string msg() const;
+  std::string msg_without_backtrace() const;
+
+  const std::vector<std::string>& msg_stack() const {
+    return msg_stack_;
+  }
+
+  /// Returns the complete error message, including the source location.
+  const char* what() const noexcept override {
+    return msg_.c_str();
+  }
+
+  const void* caller() const noexcept {
+    return caller_;
+  }
+
+  /// Returns only the error message string, without source location.
+  const char* what_without_backtrace() const noexcept {
+    return msg_without_backtrace_.c_str();
+  }
+};
+
+class C10_API Warning {
+  using handler_t =
+      void (*)(const SourceLocation& source_location, const char* msg);
+
+ public:
+  /// Issue a warning with a given message. Dispatched to the current
+  /// warning handler.
+  static void warn(SourceLocation source_location, std::string msg);
+  /// Sets the global warning handler. This is not thread-safe, so it should
+  /// generally be called once during initialization.
+  static void set_warning_handler(handler_t handler);
+  /// The default warning handler. Prints the message to stderr.
+  static void print_warning(
+      const SourceLocation& source_location,
+      const char* msg);
+
+ private:
+  static handler_t warning_handler_;
+};
+
+// A utility function to return an exception std::string by prepending its
+// exception type before its what() content
+C10_API std::string GetExceptionString(const std::exception& e);
+
+} // namespace c10
+
+// TODO: variants that print the expression tested and thus don't require
+// strings
+// TODO: CAFFE_ENFORCE_WITH_CALLER style macro
+
+// TODO: move AT_ERROR to C10_ERROR
+// TODO: consolidate the enforce and assert messages. Assert is a bit confusing
+// as c++ assert quits, while this throws.
+// TODO: merge AT_CHECK with AT_ASSERTM. CHECK in fbcode means strict failure if
+// not met.
+
+#define AT_ERROR(...) \
+  throw ::c10::Error({__func__, __FILE__, __LINE__}, ::c10::str(__VA_ARGS__))
+
+#define AT_WARN(...) \
+  ::c10::Warning::warn({__func__, __FILE__, __LINE__}, ::c10::str(__VA_ARGS__))
+
+#define AT_ASSERT(cond) \
+  if (!(cond)) { \
+    AT_ERROR( \
+        #cond " ASSERT FAILED at ", \
+        __FILE__, \
+        ":", \
+        __LINE__, \
+        ", please report a bug to PyTorch."); \
+  }
+
+#define AT_ASSERTM(cond, ...) \
+  if (!(cond)) { \
+    AT_ERROR(::c10::str( \
+        #cond, \
+        " ASSERT FAILED at ", \
+        __FILE__, \
+        ":", \
+        __LINE__, \
+        ", please report a bug to PyTorch. ", \
+        __VA_ARGS__)); \
+  }
+
+#define AT_CHECK(cond, ...) \
+  if (!(cond)) { \
+    AT_ERROR(::c10::str(__VA_ARGS__)); \
+  }
+
+#endif // C10_UTIL_EXCEPTION_H_
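A hedged sketch of how the relocated macros and error type fit together; the function and message are illustrative and not from this diff:

// Illustrative only: throwing and catching the relocated error type.
#include "c10/util/Exception.h"
#include <iostream>

void check_positive(int64_t n) {
  AT_CHECK(n > 0, "expected a positive value, got ", n);  // throws c10::Error
}

int main() {
  try {
    check_positive(-1);
  } catch (const c10::Error& e) {
    std::cerr << e.what_without_backtrace() << std::endl;
  }
  return 0;
}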
@@ -1,4 +1,5 @@
-#pragma once
+#ifndef C10_UTIL_FLAGS_H_
+#define C10_UTIL_FLAGS_H_
 
 /* Commandline flags support for C10.
  *
@@ -210,3 +211,5 @@ C10_DECLARE_REGISTRY(C10FlagsRegistry, C10FlagParser, const std::string&);
 ////////////////////////////////////////////////////////////////////////////////
 
 #endif // C10_USE_GFLAGS
 
+#endif // C10_UTIL_FLAGS_H_

c10/util/Optional.cpp (new file, 1 line)
@@ -0,0 +1 @@
+#include "c10/util/Optional.h"
@@ -9,12 +9,13 @@
 //
 // From https://github.com/akrzemi1/Optional
 //
-// ATen:
-// - Move to `at` namespace.
+// C10
+// - Move to `c10` namespace.
 // - Remove macro use in line 478 because the nvcc device compiler cannot handle
 //   it.
 
-#pragma once
+#ifndef C10_UTIL_OPTIONAL_H_
+#define C10_UTIL_OPTIONAL_H_
 
 #include <cassert>
 #include <functional>
@@ -105,7 +106,7 @@
 #define OPTIONAL_MUTABLE_CONSTEXPR constexpr
 #endif
 
-namespace at {
+namespace c10 {
 
 // 20.5.4, optional for object types
 template <class T>
@@ -301,8 +302,7 @@ using OptionalBase = typename std::conditional<
 
 template <class T>
 class optional : private OptionalBase<T> {
-  template <class U> // re-declaration for nvcc on Windows.
+  template <class U> // re-declaration for nvcc on Windows.
   using OptionalBase = typename std::conditional<
       std::is_trivially_destructible<U>::value, // if possible
       constexpr_optional_base<typename std::remove_const<
@@ -1007,13 +1007,13 @@ constexpr optional<X&> make_optional(std::reference_wrapper<X> v) {
   return optional<X&>(v.get());
 }
 
-} // namespace at
+} // namespace c10
 
 namespace std {
 template <typename T>
-struct hash<at::optional<T>> {
+struct hash<c10::optional<T>> {
   typedef typename hash<T>::result_type result_type;
-  typedef at::optional<T> argument_type;
+  typedef c10::optional<T> argument_type;
 
   constexpr result_type operator()(argument_type const& arg) const {
     return arg ? std::hash<T>{}(*arg) : result_type{};
@@ -1021,9 +1021,9 @@ struct hash<at::optional<T>> {
 };
 
 template <typename T>
-struct hash<at::optional<T&>> {
+struct hash<c10::optional<T&>> {
   typedef typename hash<T>::result_type result_type;
-  typedef at::optional<T&> argument_type;
+  typedef c10::optional<T&> argument_type;
 
   constexpr result_type operator()(argument_type const& arg) const {
     return arg ? std::hash<T>{}(*arg) : result_type{};
@@ -1031,5 +1031,13 @@ struct hash<at::optional<T&>> {
 };
 } // namespace std
 
+// TODO: remove at::optional when done moving files
+namespace at {
+template <class T>
+using optional = c10::optional<T>;
+}
 
 #undef TR2_OPTIONAL_REQUIRES
 #undef TR2_OPTIONAL_ASSERTED_EXPRESSION
 
+#endif // C10_UTIL_OPTIONAL_H_
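With the alias in place, both spellings refer to the same template during the migration; a hedged sketch (not from this diff):

// Illustrative only: the temporary at::optional alias forwards to c10::optional.
#include "c10/util/Optional.h"

c10::optional<int> maybe_answer(bool have_it) {
  if (have_it) {
    return 42;
  }
  return c10::nullopt;
}

// The old spelling still compiles while files are being moved:
at::optional<int> legacy = maybe_answer(true);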
@@ -19,6 +19,7 @@
 #include <unordered_map>
 #include <vector>
 
+#include "c10/macros/Macros.h"
 #include "c10/util/Type.h"
 
 namespace c10 {
c10/util/StringUtil.cpp (new file, 43 lines)
@@ -0,0 +1,43 @@
+#include "c10/util/StringUtil.h"
+#include "c10/util/Exception.h"
+
+#include <cstring>
+#include <string>
+
+namespace c10 {
+
+namespace detail {
+
+std::string StripBasename(const std::string& full_path) {
+  const char kSeparator = '/';
+  size_t pos = full_path.rfind(kSeparator);
+  if (pos != std::string::npos) {
+    return full_path.substr(pos + 1, std::string::npos);
+  } else {
+    return full_path;
+  }
+}
+
+} // namespace detail
+
+std::ostream& operator<<(std::ostream& out, const SourceLocation& loc) {
+  out << loc.function << " at " << loc.file << ":" << loc.line;
+  return out;
+}
+
+size_t ReplaceAll(std::string& s, const char* from, const char* to) {
+  AT_CHECK(from && *from, "");
+  AT_CHECK(to, "");
+
+  size_t numReplaced = 0;
+  std::string::size_type lenFrom = std::strlen(from);
+  std::string::size_type lenTo = std::strlen(to);
+  for (auto pos = s.find(from); pos != std::string::npos;
+       pos = s.find(from, pos + lenTo)) {
+    s.replace(pos, lenFrom, to);
+    numReplaced++;
+  }
+  return numReplaced;
+}
+
+} // namespace c10

c10/util/StringUtil.h (new file, 78 lines)
@@ -0,0 +1,78 @@
+#ifndef C10_UTIL_STRINGUTIL_H_
+#define C10_UTIL_STRINGUTIL_H_
+
+#include <c10/macros/Macros.h>
+
+#include <cstddef>
+#include <ostream>
+#include <sstream>
+#include <string>
+#include <vector>
+
+namespace c10 {
+
+namespace detail {
+
+// Obtains the base name from a full path.
+C10_API std::string StripBasename(const std::string& full_path);
+
+inline std::ostream& _str(std::ostream& ss) {
+  return ss;
+}
+
+template <typename T>
+inline std::ostream& _str(std::ostream& ss, const T& t) {
+  ss << t;
+  return ss;
+}
+
+template <typename T, typename... Args>
+inline std::ostream& _str(std::ostream& ss, const T& t, const Args&... args) {
+  return _str(_str(ss, t), args...);
+}
+
+} // namespace detail
+
+// Convert a list of string-like arguments into a single string.
+template <typename... Args>
+inline std::string str(const Args&... args) {
+  std::ostringstream ss;
+  detail::_str(ss, args...);
+  return ss.str();
+}
+
+// Specializations for already-a-string types.
+template <>
+inline std::string str(const std::string& str) {
+  return str;
+}
+inline std::string str(const char* c_str) {
+  return c_str;
+}
+
+template <class Container>
+inline std::string Join(const std::string& delimiter, const Container& v) {
+  std::stringstream s;
+  int cnt = static_cast<int64_t>(v.size()) - 1;
+  for (auto i = v.begin(); i != v.end(); ++i, --cnt) {
+    s << (*i) << (cnt ? delimiter : "");
+  }
+  return s.str();
+}
+
+// Replace all occurrences of "from" substring to "to" string.
+// Returns number of replacements
+size_t C10_API ReplaceAll(std::string& s, const char* from, const char* to);
+
+/// Represents a location in source code (for debugging).
+struct C10_API SourceLocation {
+  const char* function;
+  const char* file;
+  uint32_t line;
+};
+
+std::ostream& operator<<(std::ostream& out, const SourceLocation& loc);
+
+} // namespace c10
+
+#endif // C10_UTIL_STRINGUTIL_H_
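The new string helpers take over the role MakeString plays in the hunks below; a hedged usage sketch (not from this diff):

// Illustrative only: building messages with the relocated string helpers.
#include "c10/util/StringUtil.h"
#include <iostream>

int main() {
  std::string key = c10::str("output_size_hint_", 3);  // "output_size_hint_3"
  std::string path = "a/b/c.cpp";
  c10::ReplaceAll(path, "/", "::");                    // path becomes "a::b::c.cpp"
  std::cout << key << " " << path << std::endl;
  return 0;
}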
@@ -112,7 +112,7 @@ TensorRTOp::TensorRTOp(const OperatorDef& operator_def, Workspace* ws)
     is_input_.push_back(is_input);
     if (!is_input) {
       // For output, we try to get its output size hint
-      const std::string key = MakeString("output_size_hint_", output_idx);
+      const std::string key = c10::str("output_size_hint_", output_idx);
       auto output_size_hint = OperatorBase::GetRepeatedArgument<int>(key);
       if (!output_size_hint.empty()) {
         std::vector<int64_t> dims;

@@ -188,7 +188,7 @@ void TensorRTTransformer::AddTrtOptions(
     if (it != output_size_hints.end()) {
       const auto& dims = it->second;
       auto* output_size_hint_arg = op->add_arg();
-      output_size_hint_arg->set_name(MakeString("output_size_hint_", i));
+      output_size_hint_arg->set_name(c10::str("output_size_hint_", i));
       for (const auto& d : dims) {
         output_size_hint_arg->add_ints(d);
       }

@@ -114,7 +114,7 @@ void TensorSerializer::SerializeWithChunkSize(
     this->Serialize(
         tensor, name, blob_proto.mutable_tensor(), chunkStart, chunk_size);
     acceptor(
-        MakeString(name, kChunkIdSeparator, chunkStart / chunk_size),
+        c10::str(name, kChunkIdSeparator, chunkStart / chunk_size),
         blob_proto.SerializeAsString());
   };

@@ -949,7 +949,7 @@ class DummyTypeSerializer : public BlobSerializerBase {
     const auto& container = blob.template Get<DummyType>();
     for (int k = 0; k < container.n_chunks; ++k) {
       std::string serialized_chunk = container.serialize(name, k);
-      acceptor(MakeString(name, kChunkIdSeparator, k), serialized_chunk);
+      acceptor(c10::str(name, kChunkIdSeparator, k), serialized_chunk);
     }
   }
 };
@@ -1,8 +1,8 @@
 #pragma once
 
-#include "caffe2/core/dispatch/OpSchema.h"
+#include "c10/util/Optional.h"
 #include "caffe2/core/dispatch/Dispatcher.h"
-#include <ATen/core/optional.h>
+#include "caffe2/core/dispatch/OpSchema.h"
 
 /**
  * To register your own kernel for an operator, do in one (!) cpp file:
@@ -89,14 +89,17 @@ private:
   static constexpr uint64_t KERNEL_PRESENT = 0x01 << 0;
   static constexpr uint64_t DISPATCH_KEY_PRESENT = 0x01 << 1;
 
-  at::optional<typename Schema::signature::func_type*> kernel_;
-  at::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;
+  c10::optional<typename Schema::signature::func_type*> kernel_;
+  c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key_;
 
  public:
-  constexpr KernelRegistrationBuilder(): KernelRegistrationBuilder(at::nullopt, at::nullopt) {}
+  constexpr KernelRegistrationBuilder()
+      : KernelRegistrationBuilder(c10::nullopt, c10::nullopt) {}
 
-  constexpr KernelRegistrationBuilder(at::optional<typename Schema::signature::func_type*> kernel, at::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key)
-  : kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {}
+  constexpr KernelRegistrationBuilder(
+      c10::optional<typename Schema::signature::func_type*> kernel,
+      c10::optional<typename Schema::dispatch::dispatch_key_type> dispatch_key)
+      : kernel_(std::move(kernel)), dispatch_key_(std::move(dispatch_key)) {}
 
   /**
    * Implicit coercion to KernelRegistrar<OpSchemaDef> that finalizes the builder and
@@ -21,21 +21,6 @@ namespace enforce_detail {
 }
 } // namespace enforce_detail
 
-size_t ReplaceAll(string& s, const char* from, const char* to) {
-  CAFFE_ENFORCE(from && *from);
-  CAFFE_ENFORCE(to);
-
-  size_t numReplaced = 0;
-  string::size_type lenFrom = std::strlen(from);
-  string::size_type lenTo = std::strlen(to);
-  for (string::size_type pos = s.find(from); pos != string::npos;
-       pos = s.find(from, pos + lenTo)) {
-    s.replace(pos, lenFrom, to);
-    numReplaced++;
-  }
-  return numReplaced;
-}
-
 namespace {
 std::function<string(void)>* GetFetchStackTrace() {
   static std::function<string(void)> func = []() { return ""; };
@@ -53,7 +38,7 @@ void ThrowEnforceNotMet(
     const char* condition,
     const std::string& msg,
     const void* caller) {
-  at::Error e(file, line, condition, msg, (*GetFetchStackTrace())(), caller);
+  c10::Error e(file, line, condition, msg, (*GetFetchStackTrace())(), caller);
   if (c10::FLAGS_caffe2_use_fatal_for_enforce) {
     LOG(FATAL) << e.msg_stack()[0];
   }
@@ -204,14 +189,16 @@ MessageLogger::MessageLogger(const char *file, int line, int severity)
       std::chrono::duration_cast<std::chrono::nanoseconds>(
           std::chrono::high_resolution_clock::now().time_since_epoch());
   */
-  stream_ << "[" << CAFFE2_SEVERITY_PREFIX[std::min(4, FATAL - severity_)]
+  stream_ << "["
+          << CAFFE2_SEVERITY_PREFIX[std::min(4, FATAL - severity_)]
           //<< (timeinfo->tm_mon + 1) * 100 + timeinfo->tm_mday
           //<< std::setfill('0')
           //<< " " << std::setw(2) << timeinfo->tm_hour
           //<< ":" << std::setw(2) << timeinfo->tm_min
           //<< ":" << std::setw(2) << timeinfo->tm_sec
           //<< "." << std::setw(9) << ns.count() % 1000000000
-          << " " << at::detail::StripBasename(std::string(file)) << ":" << line << "] ";
+          << " " << c10::detail::StripBasename(std::string(file)) << ":" << line
+          << "] ";
 }
 
 // Output the contents of the stream to the proper channel on destruction.
@@ -8,6 +8,7 @@
 #include <sstream>
 
 #include <ATen/core/Error.h>
+#include "c10/util/StringUtil.h"
 #include "caffe2/core/common.h"
 #include "caffe2/core/flags.h"
 
@@ -60,77 +61,28 @@ constexpr bool IsUsingGoogleLogging() {
 */
 CAFFE2_API void ShowLogInfoToStderr();
 
-inline void MakeStringInternal(std::stringstream& /*ss*/) {}
-
-template <typename T>
-inline void MakeStringInternal(std::stringstream& ss, const T& t) {
-  ss << t;
-}
-
-template <typename T, typename... Args>
-inline void
-MakeStringInternal(std::stringstream& ss, const T& t, const Args&... args) {
-  MakeStringInternal(ss, t);
-  MakeStringInternal(ss, args...);
-}
-
-template <typename... Args>
-string MakeString(const Args&... args) {
-  std::stringstream ss;
-  MakeStringInternal(ss, args...);
-  return string(ss.str());
-}
-
-// Specializations for already-a-string types.
-template <>
-inline string MakeString(const string& str) {
-  return str;
-}
-inline string MakeString(const char* c_str) {
-  return string(c_str);
-}
-
-template <class Container>
-inline string Join(const string& delimiter, const Container& v) {
-  std::stringstream s;
-  int cnt = static_cast<int64_t>(v.size()) - 1;
-  for (auto i = v.begin(); i != v.end(); ++i, --cnt) {
-    s << (*i) << (cnt ? delimiter : "");
-  }
-  return s.str();
-}
-
-// Replace all occurrences of "from" substring to "to" string.
-// Returns number of replacements
-size_t ReplaceAll(string& s, const char* from, const char* to);
-
 CAFFE2_API void SetStackTraceFetcher(std::function<string(void)> fetcher);
 
-using EnforceNotMet = at::Error;
+using EnforceNotMet = ::c10::Error;
 
 #define CAFFE_ENFORCE(condition, ...) \
   do { \
     if (!(condition)) { \
       ::caffe2::ThrowEnforceNotMet( \
-          __FILE__, __LINE__, #condition, ::caffe2::MakeString(__VA_ARGS__)); \
+          __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__)); \
     } \
   } while (false)
 
 #define CAFFE_ENFORCE_WITH_CALLER(condition, ...) \
   do { \
     if (!(condition)) { \
       ::caffe2::ThrowEnforceNotMet( \
-          __FILE__, \
-          __LINE__, \
-          #condition, \
-          ::caffe2::MakeString(__VA_ARGS__), \
-          this); \
-    } \
+          __FILE__, __LINE__, #condition, ::c10::str(__VA_ARGS__), this); \
+    } \
   } while (false)
 
 #define CAFFE_THROW(...) \
-  ::caffe2::ThrowEnforceNotMet( \
-      __FILE__, __LINE__, "", ::caffe2::MakeString(__VA_ARGS__))
+  ::caffe2::ThrowEnforceNotMet(__FILE__, __LINE__, "", ::c10::str(__VA_ARGS__))
 
 /**
  * Rich logging messages
@@ -148,7 +100,7 @@ using EnforceNotMet = at::Error;
 * namespace caffe2 { namespace enforce_detail {
 * inline EnforceFailMessage IsVector(const vector<int64_t>& shape) {
 *   if (shape.size() == 1) { return EnforceOK(); }
-*   return MakeString("Shape ", shape, " is not a vector");
+*   return c10::str("Shape ", shape, " is not a vector");
 * }
 * }}
 *
@@ -197,7 +149,7 @@ class CAFFE2_API EnforceFailMessage {
     if (extra.empty()) {
       r = std::move(*msg_);
     } else {
-      r = ::caffe2::MakeString(std::move(*msg_), ". ", std::move(extra));
+      r = ::c10::str(std::move(*msg_), ". ", std::move(extra));
     }
     delete msg_;
     return r;
@@ -213,7 +165,7 @@ class CAFFE2_API EnforceFailMessage {
     if (x op y) { \
       return EnforceOK(); \
     } \
-    return MakeString(x, " vs ", y); \
+    return c10::str(x, " vs ", y); \
   }
 BINARY_COMP_HELPER(Equals, ==)
 BINARY_COMP_HELPER(NotEquals, !=)
@@ -233,7 +185,7 @@ BINARY_COMP_HELPER(LessEquals, <=)
         __LINE__, \
         expr, \
         CAFFE_ENFORCE_THAT_IMPL_r_.get_message_and_free( \
-            ::caffe2::MakeString(__VA_ARGS__))); \
+            ::c10::str(__VA_ARGS__))); \
     } \
   } while (false)
 
@@ -248,7 +200,7 @@ BINARY_COMP_HELPER(LessEquals, <=)
         __LINE__, \
         expr, \
         CAFFE_ENFORCE_THAT_IMPL_WITH_CALLER_r_.get_message_and_free( \
-            ::caffe2::MakeString(__VA_ARGS__)), \
+            ::c10::str(__VA_ARGS__)), \
         this); \
     } \
   } while (false)

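After this hunk, CAFFE_ENFORCE formats its variadic message arguments through ::c10::str instead of ::caffe2::MakeString. A brief usage sketch; the enclosing function is illustrative only:

    #include "caffe2/core/logging.h"

    // On failure CAFFE_ENFORCE throws caffe2::EnforceNotMet (now an alias for
    // ::c10::Error); the trailing arguments are stringified by ::c10::str.
    void CheckBatch(int batch_size) {
      CAFFE_ENFORCE(batch_size > 0, "Invalid batch size: ", batch_size);
    }
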
@@ -65,11 +65,11 @@ TEST(LoggingTest, EnforceShowcase) {
 }
 
 TEST(LoggingTest, Join) {
-  auto s = Join(", ", vector<int>({1, 2, 3}));
+  auto s = c10::Join(", ", vector<int>({1, 2, 3}));
   EXPECT_EQ(s, "1, 2, 3");
-  s = Join(":", vector<string>());
+  s = c10::Join(":", vector<string>());
   EXPECT_EQ(s, "");
-  s = Join(", ", set<int>({3, 1, 2}));
+  s = c10::Join(", ", set<int>({3, 1, 2}));
   EXPECT_EQ(s, "1, 2, 3");
 }

@@ -242,7 +242,7 @@ void DAGNetBase::WorkerFunction() {
           operator_nodes_[idx].operator_->debug_def());
     }
   } catch (std::exception& e) {
-    std::string exception_str = at::GetExceptionString(e);
+    std::string exception_str = c10::GetExceptionString(e);
     HandleException(idx, exception_str);
   } catch (...) {
     std::string exception_str = "Unknown exception";

@@ -927,9 +927,9 @@ class CAFFE2_API UnsupportedOperatorFeature : public std::exception {
 // A helper macro that should ONLY be used in the operator constructor to check
 // if needed features are met. If not, throws the UnsupportedOperatorFeature
 // exception with the given message.
 #define OPERATOR_NEEDS_FEATURE(condition, ...) \
   if (!(condition)) { \
-    throw UnsupportedOperatorFeature(::caffe2::MakeString(__VA_ARGS__)); \
+    throw UnsupportedOperatorFeature(::c10::str(__VA_ARGS__)); \
   }
 
 // Creates an operator with the given operator definition.

@@ -418,7 +418,7 @@ bool ExecuteStepRecursive(ExecutionStepWrapper& stepWrapper) {
     } catch (const std::exception& ex) {
       std::lock_guard<std::mutex> guard(exception_mutex);
       if (!first_exception.size()) {
-        first_exception = at::GetExceptionString(ex);
+        first_exception = c10::GetExceptionString(ex);
         LOG(ERROR) << "Parallel worker exception:\n" << first_exception;
       }
       compiledStep->gotFailure = true;

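GetExceptionString, now in the c10 namespace, turns any std::exception into a printable description. A sketch of the catch-and-log pattern above; the wrapper function and the choice of header are assumptions, not part of this diff:

    #include <exception>
    #include <functional>
    #include "caffe2/core/logging.h"  // assumed to pull in LOG(...) and the c10 error utilities

    void RunGuarded(const std::function<void()>& fn) {
      try {
        fn();
      } catch (const std::exception& ex) {
        // Log the formatted exception instead of letting it escape the worker.
        LOG(ERROR) << "Worker exception:\n" << c10::GetExceptionString(ex);
      }
    }
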
@@ -37,7 +37,7 @@ TEST(TypeMetaTest, Names) {
   EXPECT_TRUE(
       string(string_meta.name()) != typeid(string).name());
   EXPECT_TRUE(
-      string(string_meta.name()) == at::demangle(typeid(string).name()));
+      string(string_meta.name()) == c10::demangle(typeid(string).name()));
 #endif // __GXX_RTTI
 }

@@ -16,6 +16,8 @@
 #include <direct.h> // for _mkdir
 #endif
 
+#include "c10/util/StringUtil.h"
+
 #include "caffe2/utils/murmur_hash3.h"
 
 namespace caffe2 {
@@ -151,7 +153,8 @@ void FileStoreHandler::wait(
     const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
         std::chrono::steady_clock::now() - start);
     if (timeout != kNoTimeout && elapsed > timeout) {
-      STORE_HANDLER_TIMEOUT("Wait timeout for name(s): ", Join(" ", names));
+      STORE_HANDLER_TIMEOUT(
+          "Wait timeout for name(s): ", c10::Join(" ", names));
     }
     /* sleep override */
     std::this_thread::sleep_for(std::chrono::milliseconds(10));

@@ -107,7 +107,8 @@ void RedisStoreHandler::wait(
     const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
         std::chrono::steady_clock::now() - start);
     if (timeout != kNoTimeout && elapsed > timeout) {
-      STORE_HANDLER_TIMEOUT("Wait timeout for name(s): ", Join(" ", names));
+      STORE_HANDLER_TIMEOUT(
+          "Wait timeout for name(s): ", c10::Join(" ", names));
     }
     /* sleep override */
     std::this_thread::sleep_for(std::chrono::milliseconds(10));

@@ -64,7 +64,7 @@ struct CAFFE2_API StoreHandlerNotAvailableException
 
 #define STORE_HANDLER_NOT_AVAILABLE(...) \
   throw ::caffe2::StoreHandlerNotAvailableException( \
-      ::caffe2::MakeString("[", __FILE__, ":", __LINE__, "] ", __VA_ARGS__));
+      ::c10::str("[", __FILE__, ":", __LINE__, "] ", __VA_ARGS__));
 
 /*
  * Timeout accessing the store.
@@ -77,5 +77,5 @@ struct CAFFE2_API StoreHandlerTimeoutException : public std::runtime_error {
 
 #define STORE_HANDLER_TIMEOUT(...) \
   throw ::caffe2::StoreHandlerTimeoutException( \
-      ::caffe2::MakeString("[", __FILE__, ":", __LINE__, "] ", __VA_ARGS__));
+      ::c10::str("[", __FILE__, ":", __LINE__, "] ", __VA_ARGS__));
 } // namespace caffe2

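The store-handler timeouts above now build their messages with c10::Join, which concatenates any iterable of streamable items with a delimiter (it replaces the old caffe2 Join helper removed from logging.h). A hedged sketch; the helper function and names are illustrative:

    #include <string>
    #include <vector>
    #include "c10/util/StringUtil.h"

    // For {"blob_a", "blob_b"} and a space delimiter this yields
    // "Wait timeout for name(s): blob_a blob_b".
    std::string DescribePending(const std::vector<std::string>& names) {
      return c10::str("Wait timeout for name(s): ", c10::Join(" ", names));
    }
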
@@ -158,8 +158,8 @@ The convolution fusion operator consumes an input vector, a {dim}filter blob,
 a bias blob and another input vector and computes the output. This operator
 gives the chance to fuse the ReLU or element-wise Sum with a convolution
 operator. {conv_fusion_doc})DOC";
-  ReplaceAll(doc, "{dim}", dim);
-  ReplaceAll(doc, "{conv_fusion_doc}", kConvFusionDoc);
+  c10::ReplaceAll(doc, "{dim}", dim);
+  c10::ReplaceAll(doc, "{conv_fusion_doc}", kConvFusionDoc);
   schema.SetDoc(doc);
   schema.Input(
       0,

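The doc-generator hunks in this diff all call c10::ReplaceAll, which substitutes every occurrence of a literal placeholder in place and returns the number of replacements. A minimal sketch of the pattern, with an invented operator doc string:

    #include <string>
    #include "c10/util/StringUtil.h"

    // ReplaceAll mutates `doc` directly; the return value (replacement count)
    // can be checked, as FormatDoc() does further down in this diff.
    std::string RenderDoc(const std::string& dim) {
      std::string doc = "MyOp{dim} consumes an input blob. {extra}";
      c10::ReplaceAll(doc, "{dim}", dim.c_str());
      c10::ReplaceAll(doc, "{extra}", "");
      return doc;
    }
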
@@ -7,7 +7,7 @@ namespace caffe2 { namespace onnx {
 
 std::string DummyName::NewDummyName() {
   while (true) {
-    const std::string name = caffe2::MakeString("OC2_DUMMY_", counter_++);
+    const std::string name = c10::str("OC2_DUMMY_", counter_++);
     auto ret = used_names_.insert(name);
     if (ret.second) {
       return name;

@@ -89,7 +89,7 @@ TensorProto CreateOnnxShapeTensor(
 }
 
 std::string SsaName(const std::string& n, int version) {
-  return MakeString(n, "_", version);
+  return c10::str(n, "_", version);
 }
 } // namespace
 
@@ -283,8 +283,7 @@ void OnnxExporter::CopyCaffe2ArgToOnnxAttr(
     attr->mutable_strings()->CopyFrom(arg.strings());
     attr->set_type(AttributeProto::STRINGS);
   } else {
-    CAFFE_THROW(
-        caffe2::MakeString("Unsupported Caffe2 argument: ", arg.name()));
+    CAFFE_THROW(c10::str("Unsupported Caffe2 argument: ", arg.name()));
   }
 }
 
@@ -934,7 +933,7 @@ void OnnxExporter::InitOpToTensorProto(
     }
   } else {
     CAFFE_THROW(
-        MakeString("Cannot convert C2 op ", op.type(), "to ONNX TensorProto"));
+        c10::str("Cannot convert C2 op ", op.type(), "to ONNX TensorProto"));
   }
 }

@@ -97,8 +97,8 @@ std::function<void(OpSchema&)> ConvDocGenerator(const char* dim) {
     string doc = R"DOC(
 The convolution operator consumes an input vector, a {dim}filter blob
 and a bias blob and computes the output. {conv_doc})DOC";
-    ReplaceAll(doc, "{dim}", dim);
-    ReplaceAll(doc, "{conv_doc}", kConvDoc);
+    c10::ReplaceAll(doc, "{dim}", dim);
+    c10::ReplaceAll(doc, "{conv_doc}", kConvDoc);
     schema.SetDoc(doc);
     schema.Input(
         0,

@@ -222,9 +222,9 @@ Performs element-wise binary {name} (with limited broadcast support).
 
 {extra}
 )DOC";
-    ReplaceAll(doc, "{name}", name);
-    ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
-    ReplaceAll(doc, "{extra}", extra);
+    c10::ReplaceAll(doc, "{name}", name);
+    c10::ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
+    c10::ReplaceAll(doc, "{extra}", extra);
     schema.SetDoc(doc);
     schema.Arg("broadcast", "*(type: int; default: 0)* Pass 1 to enable broadcasting");
     schema.Arg(
@@ -598,10 +598,10 @@ Performs element-wise {desc} comparison **{name}** (with limited broadcast suppo
 
 {extra}
 )DOC";
-    ReplaceAll(doc, "{name}", name);
-    ReplaceAll(doc, "{desc}", desc);
-    ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
-    ReplaceAll(doc, "{extra}", extra);
+    c10::ReplaceAll(doc, "{name}", name);
+    c10::ReplaceAll(doc, "{desc}", desc);
+    c10::ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
+    c10::ReplaceAll(doc, "{extra}", extra);
     schema.SetDoc(doc);
     schema.Arg("broadcast", "*(type: int; default: 0)* Pass 1 to enable broadcasting.");
     schema.Arg(
@@ -804,9 +804,9 @@ Both input operands should be of type `bool`.
 
 {extra}
 )DOC";
-    ReplaceAll(doc, "{name}", name);
-    ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
-    ReplaceAll(doc, "{extra}", extra);
+    c10::ReplaceAll(doc, "{name}", name);
+    c10::ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
+    c10::ReplaceAll(doc, "{extra}", extra);
     schema.SetDoc(doc);
     schema.Arg("broadcast", "*(type: int; default: 0)* Pass 1 to enable broadcasting.");
     schema.Arg(
@@ -843,8 +843,8 @@ std::function<void(OpSchema&)> BitwiseDocGenerator(const char* name) {
 Performs element-wise bitwise operation `{name}` (with limited broadcast support).
 Both input operands should be of type `bool`.
 {broadcast_doc})DOC";
-    ReplaceAll(doc, "{name}", name);
-    ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
+    c10::ReplaceAll(doc, "{name}", name);
+    c10::ReplaceAll(doc, "{broadcast_doc}", kBroadcastDoc);
     schema.SetDoc(doc);
     schema.Arg("broadcast", "*(type: int; default: 0)* Pass 1 to enable broadcasting.");
     schema.Arg(

@@ -32,7 +32,7 @@ void fc_op_cpu_impl(
                                       : W.size_from_dim(canonical_axis_w);
 
   auto dimErrorString = [&]() {
-    return caffe2::MakeString(
+    return c10::str(
         "Dimension mismatch: ",
         "X: ",
         X.dims(),
@@ -45,7 +45,7 @@ class FullyConnectedOp final : public Operator<Context> {
                                       : W.size_from_dim(canonical_axis_w);
 
     auto dimErrorString = [&]() {
-      return MakeString(
+      return c10::str(
           "Dimension mismatch: ",
           "X: ",
           X.dims(),
@@ -187,7 +187,7 @@ class FullyConnectedGradientOp : public Operator<Context> {
                                       : W.size_from_dim(canonical_axis_w);
 
     auto dimErrorString = [&]() {
-      return MakeString(
+      return c10::str(
          "Dimension mismatch: ",
          "X: ",
          X.dims(),

@@ -57,9 +57,9 @@ REGISTER_CPU_OPERATOR_STR(
 template <typename Def>
 string FormatDoc() {
   string doc = Def::doc;
-  ReplaceAll(doc, "{op}", Def::OpDef::name);
-  ReplaceAll(doc, "{op_doc}", Def::OpDef::doc);
-  auto replaced = ReplaceAll(doc, "{extra}", "");
+  c10::ReplaceAll(doc, "{op}", Def::OpDef::name);
+  c10::ReplaceAll(doc, "{op_doc}", Def::OpDef::doc);
+  auto replaced = c10::ReplaceAll(doc, "{extra}", "");
   CAFFE_ENFORCE_EQ(replaced, 0);
   return doc;
 }

@@ -26,8 +26,8 @@ std::function<void(OpSchema&)> LCDocGenerator(const char* dim) {
     string doc = R"DOC(
 The locally connected operator consumes an input vector, a {dim}filter blob
 and a bias blob and computes the output. {lc_doc})DOC";
-    ReplaceAll(doc, "{dim}", dim);
-    ReplaceAll(doc, "{lc_doc}", kLCDoc);
+    c10::ReplaceAll(doc, "{dim}", dim);
+    c10::ReplaceAll(doc, "{lc_doc}", kLCDoc);
     schema.SetDoc(doc);
     schema.Input(
         1,

@@ -50,7 +50,7 @@ class MatMulOp final : public Operator<Context> {
     }
 
     auto dimErrorString = [&]() {
-      return MakeString(
+      return c10::str(
           "Dimension mismatch: ",
           trans_a_ ? "trans(A): " : "A: ",
           a_dim0,

@@ -35,7 +35,7 @@ class OnnxifiOp final : public Operator<Context> {
       output_desc_.back().name = output.c_str();
 
       // For output, we try to get its output size hint
-      const std::string key = MakeString("output_size_hint_", output_idx);
+      const std::string key = c10::str("output_size_hint_", output_idx);
       auto output_size_hint = this->template GetRepeatedArgument<int>(key);
       if (!output_size_hint.empty()) {
         std::vector<int64_t> dims;

@@ -869,8 +869,8 @@ Y:
 std::function<void(OpSchema&)> AveragePoolDocGenerator(const char* dim) {
   return [=](OpSchema& schema) {
     string doc = "AveragePool{dim} {pool_doc}";
-    ReplaceAll(doc, "{dim}", dim);
-    ReplaceAll(doc, "{pool_doc}", kAveragePoolDoc);
+    c10::ReplaceAll(doc, "{dim}", dim);
+    c10::ReplaceAll(doc, "{pool_doc}", kAveragePoolDoc);
     schema.SetDoc(doc);
     schema.Input(
         0,
@@ -893,8 +893,8 @@ std::function<void(OpSchema&)> AveragePoolDocGenerator(const char* dim) {
 std::function<void(OpSchema&)> MaxPoolDocGenerator(const char* dim) {
   return [=](OpSchema& schema) {
     string doc = "MaxPool{dim} {pool_doc}";
-    ReplaceAll(doc, "{dim}", dim);
-    ReplaceAll(doc, "{pool_doc}", kMaxPoolDoc);
+    c10::ReplaceAll(doc, "{dim}", dim);
+    c10::ReplaceAll(doc, "{pool_doc}", kMaxPoolDoc);
     schema.SetDoc(doc);
     schema.Input(
         0,

@@ -101,7 +101,7 @@ struct GetRecurrentNetworkGradient : public GradientMakerBase {
       gradientOutputs.push_back(GI(id));
     }
 
-    VLOG(1) << "Gradient blobs: " << Join(", ", gradientOutputs);
+    VLOG(1) << "Gradient blobs: " << c10::Join(", ", gradientOutputs);
 
     return SingleGradientDef(
         "RecurrentNetworkGradient", "", gradientInputs, gradientOutputs);

@@ -336,18 +336,18 @@ OUTPUT:
 template <typename Def>
 string FormatDoc() {
   string doc = Def::doc;
-  ReplaceAll(doc, "{op}", Def::OpDef::name);
-  ReplaceAll(doc, "{op_doc}", Def::OpDef::doc);
+  c10::ReplaceAll(doc, "{op}", Def::OpDef::name);
+  c10::ReplaceAll(doc, "{op_doc}", Def::OpDef::doc);
   if (strcmp(Def::OpDef::name, "Max") == 0) {
-    ReplaceAll(doc, "{extra}", kLengthsMaxExtra);
+    c10::ReplaceAll(doc, "{extra}", kLengthsMaxExtra);
   } else if (strcmp(Def::OpDef::name, "Mean") == 0) {
-    ReplaceAll(doc, "{extra}", kLengthsMeanExtra);
+    c10::ReplaceAll(doc, "{extra}", kLengthsMeanExtra);
   } else if (strcmp(Def::OpDef::name, "Sum") == 0) {
-    ReplaceAll(doc, "{extra}", kLengthsSumExtra);
+    c10::ReplaceAll(doc, "{extra}", kLengthsSumExtra);
   } else if (strcmp(Def::OpDef::name, "WeightedSum") == 0) {
-    ReplaceAll(doc, "{extra}", kLengthsWeightedSumExtra);
+    c10::ReplaceAll(doc, "{extra}", kLengthsWeightedSumExtra);
   } else {
-    ReplaceAll(doc, "{extra}", " ");
+    c10::ReplaceAll(doc, "{extra}", " ");
   }
   return doc;
 }

@@ -40,12 +40,12 @@ struct VisitorContext {
 std::string ShowNode(NodeRef node) {
   if (nn::is<NeuralNetData>(node)) {
     const auto* nn_tensor = nn::get<NeuralNetData>(node);
-    return MakeString("Tensor: ", nn_tensor->getName());
+    return c10::str("Tensor: ", nn_tensor->getName());
   } else if (nn::is<NeuralNetOperator>(node)) {
     const auto* nn_op = nn::get<NeuralNetOperator>(node);
     const auto& op_def =
         dyn_cast<Caffe2Annotation>(nn_op->getAnnotation())->getOperatorDef();
-    return MakeString("Op: ", op_def.type());
+    return c10::str("Op: ", op_def.type());
   } else {
     CAFFE_THROW("Known node");
   }

@@ -135,7 +135,7 @@ OperatorDef OnnxifiTransformer::BuildOnnxifiOp(
     if (it != output_size_hints.end()) {
       const auto& dims = it->second;
       auto* output_size_hint_arg = op.add_arg();
-      output_size_hint_arg->set_name(MakeString("output_size_hint_", i));
+      output_size_hint_arg->set_name(c10::str("output_size_hint_", i));
       for (const auto& d : dims) {
        output_size_hint_arg->add_ints(d);
       }
@@ -199,7 +199,7 @@ NetDef OnnxifiTransformer::SubnetToOnnxifiOp(
       extra_weights.emplace_back(t.name());
       CAFFE_ENFORCE(
           input_mapping_.emplace(t.name(), t.name()).second,
-          MakeString("Tensor ", t.name(), " already exists in the workspace"));
+          c10::str("Tensor ", t.name(), " already exists in the workspace"));
     }
   }

@@ -55,7 +55,6 @@ set(Caffe2_HIP_TEST_SRCS ${Caffe2_HIP_TEST_SRCS}
 
 set(LIB_SOURCES_CPU
   Array.cpp
-  Optional.cpp
   Metaprogramming.cpp
   TypeList.cpp
   TypeTraits.cpp

@@ -1 +0,0 @@
-#include "Optional.h"

@@ -179,13 +179,11 @@ constexpr struct in_place_t{} in_place{};
 
 
 // 20.5.7, Disengaged state indicator
-struct nullopt_t
-{
+struct c10::nullopt_t {
   struct init{};
-  constexpr explicit nullopt_t(init){}
+  constexpr explicit c10::nullopt_t(init) {}
 };
-constexpr nullopt_t nullopt{nullopt_t::init()};
+constexpr c10::nullopt_t c10::nullopt{c10::nullopt_t::init()};
 
 
 // 20.5.8, class bad_optional_access
 class bad_optional_access : public std::logic_error {
@@ -282,7 +280,9 @@ using OptionalBase = typename std::conditional<
 template <class T>
 class optional : private OptionalBase<T>
 {
-  static_assert( !std::is_same<typename std::decay<T>::type, nullopt_t>::value, "bad T" );
+  static_assert(
+      !std::is_same<typename std::decay<T>::type, c10::nullopt_t>::value,
+      "bad T");
   static_assert( !std::is_same<typename std::decay<T>::type, in_place_t>::value, "bad T" );
 
 
@@ -330,7 +330,7 @@ public:
 
   // 20.5.5.1, constructors
   constexpr optional() noexcept : OptionalBase<T>() {};
-  constexpr optional(nullopt_t) noexcept : OptionalBase<T>() {};
+  constexpr optional(c10::nullopt_t) noexcept : OptionalBase<T>(){};
 
   optional(const optional& rhs)
   : OptionalBase<T>()
@@ -366,8 +366,7 @@ public:
   ~optional() = default;
 
   // 20.5.4.3, assignment
-  optional& operator=(nullopt_t) noexcept
-  {
+  optional& operator=(c10::nullopt_t) noexcept {
     clear();
     return *this;
   }
@@ -538,7 +537,7 @@ public:
 template <class T>
 class optional<T&>
 {
-  static_assert( !std::is_same<T, nullopt_t>::value, "bad T" );
+  static_assert(!std::is_same<T, c10::nullopt_t>::value, "bad T");
   static_assert( !std::is_same<T, in_place_t>::value, "bad T" );
   T* ref;
 
@@ -547,7 +546,7 @@ public:
   // 20.5.5.1, construction/destruction
   constexpr optional() noexcept : ref(nullptr) {}
 
-  constexpr optional(nullopt_t) noexcept : ref(nullptr) {}
+  constexpr optional(c10::nullopt_t) noexcept : ref(nullptr) {}
 
   constexpr optional(T& v) noexcept : ref(detail_::static_addressof(v)) {}
 
@@ -562,7 +561,7 @@ public:
   ~optional() = default;
 
   // 20.5.5.2, mutation
-  optional& operator=(nullopt_t) noexcept {
+  optional& operator=(c10::nullopt_t) noexcept {
     ref = nullptr;
     return *this;
   }
@@ -680,70 +679,67 @@ template <class T> constexpr bool operator>=(const optional<T>& x, const optiona
   return !(x < y);
 }
 
-// 20.5.9, Comparison with nullopt
-template <class T> constexpr bool operator==(const optional<T>& x, nullopt_t) noexcept
-{
+// 20.5.9, Comparison with c10::nullopt
+template <class T>
+constexpr bool operator==(const optional<T>& x, c10::nullopt_t) noexcept {
   return (!x);
 }
 
-template <class T> constexpr bool operator==(nullopt_t, const optional<T>& x) noexcept
-{
+template <class T>
+constexpr bool operator==(c10::nullopt_t, const optional<T>& x) noexcept {
  return (!x);
 }
 
-template <class T> constexpr bool operator!=(const optional<T>& x, nullopt_t) noexcept
-{
+template <class T>
+constexpr bool operator!=(const optional<T>& x, c10::nullopt_t) noexcept {
  return bool(x);
 }
 
-template <class T> constexpr bool operator!=(nullopt_t, const optional<T>& x) noexcept
-{
+template <class T>
+constexpr bool operator!=(c10::nullopt_t, const optional<T>& x) noexcept {
  return bool(x);
 }
 
-template <class T> constexpr bool operator<(const optional<T>&, nullopt_t) noexcept
-{
+template <class T>
+constexpr bool operator<(const optional<T>&, c10::nullopt_t) noexcept {
  return false;
 }
 
-template <class T> constexpr bool operator<(nullopt_t, const optional<T>& x) noexcept
-{
+template <class T>
+constexpr bool operator<(c10::nullopt_t, const optional<T>& x) noexcept {
  return bool(x);
 }
 
-template <class T> constexpr bool operator<=(const optional<T>& x, nullopt_t) noexcept
-{
+template <class T>
+constexpr bool operator<=(const optional<T>& x, c10::nullopt_t) noexcept {
  return (!x);
 }
 
-template <class T> constexpr bool operator<=(nullopt_t, const optional<T>&) noexcept
-{
+template <class T>
+constexpr bool operator<=(c10::nullopt_t, const optional<T>&) noexcept {
  return true;
 }
 
-template <class T> constexpr bool operator>(const optional<T>& x, nullopt_t) noexcept
-{
+template <class T>
+constexpr bool operator>(const optional<T>& x, c10::nullopt_t) noexcept {
  return bool(x);
 }
 
-template <class T> constexpr bool operator>(nullopt_t, const optional<T>&) noexcept
-{
+template <class T>
+constexpr bool operator>(c10::nullopt_t, const optional<T>&) noexcept {
  return false;
 }
 
-template <class T> constexpr bool operator>=(const optional<T>&, nullopt_t) noexcept
-{
+template <class T>
+constexpr bool operator>=(const optional<T>&, c10::nullopt_t) noexcept {
  return true;
 }
 
-template <class T> constexpr bool operator>=(nullopt_t, const optional<T>& x) noexcept
-{
+template <class T>
+constexpr bool operator>=(c10::nullopt_t, const optional<T>& x) noexcept {
  return (!x);
 }
 
 
 // 20.5.10, Comparison with T
 template <class T> constexpr bool operator==(const optional<T>& x, const T& v)
 {

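As the comparison operators above encode, a disengaged optional compares equal to c10::nullopt and an engaged one does not. A tiny test-style sketch of that behavior (not part of the diff):

    #include "c10/util/Optional.h"

    // Returns true: empty == c10::nullopt and filled != c10::nullopt.
    inline bool nullopt_comparison_sketch() {
      c10::optional<int> empty;
      c10::optional<int> filled = 7;
      return (empty == c10::nullopt) && (filled != c10::nullopt);
    }
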
@@ -150,7 +150,7 @@ C10_EXPORT bool ParseFromString(const string& spec, Message* proto) {
   string bc_spec = spec;
 
   {
-    auto num_replaced = ReplaceAll(bc_spec, "cuda_gpu_id", "device_id");
+    auto num_replaced = c10::ReplaceAll(bc_spec, "cuda_gpu_id", "device_id");
     if (num_replaced) {
       LOG(ERROR) << "Your model was serialized in Protobuf TextFormat and "
                  << "it has "

setup.py (9 lines changed)
@@ -853,7 +853,9 @@ include_dirs += [
     library_dirs.append(lib_path)
 
     # we specify exact lib names to avoid conflict with lua-torch installs
-    CAFFE2_LIBS = [os.path.join(lib_path, 'libcaffe2.so')]
+    CAFFE2_LIBS = [
+        os.path.join(lib_path, 'libcaffe2.so'),
+        os.path.join(lib_path, 'libc10.so')]
     if USE_CUDA:
         CAFFE2_LIBS.extend(['-Wl,--no-as-needed', os.path.join(lib_path, 'libcaffe2_gpu.so'), '-Wl,--as-needed'])
     if USE_ROCM:
@@ -877,7 +879,10 @@ if IS_DARWIN:
         NCCL_LIB = os.path.join(lib_path, 'libnccl.2.dylib')
 
 if IS_WINDOWS:
-    CAFFE2_LIBS = [os.path.join(lib_path, 'caffe2.lib')]
+    CAFFE2_LIBS = [
+        os.path.join(lib_path, 'caffe2.lib'),
+        os.path.join(lib_path, 'c10.lib')
+    ]
     if USE_CUDA:
         CAFFE2_LIBS.append(os.path.join(lib_path, 'caffe2_gpu.lib'))
     if USE_ROCM:

@@ -8,8 +8,8 @@ struct TestValue {
   explicit TestValue(const int& x) : lvalue_(x) {}
   explicit TestValue(int&& x) : rvalue_(x) {}
 
-  at::optional<int> lvalue_;
-  at::optional<int> rvalue_;
+  c10::optional<int> lvalue_;
+  c10::optional<int> rvalue_;
 };
 
 TEST(MakeUniqueTest, ForwardRvaluesCorrectly) {

@@ -171,7 +171,7 @@ TEST_F(ModuleTest, CallingCloneOnModuleThatDoesNotOverrideCloneThrows) {
 TEST_F(ModuleTest, CallingCloneOnModuleThatDoesOverrideCloneDoesNotThrow) {
   struct Cloneable : Module {
     std::shared_ptr<Module> clone(
-        at::optional<torch::Device> device = at::nullopt) const override {
+        c10::optional<torch::Device> device = c10::nullopt) const override {
       return nullptr;
     }
   };

@@ -189,7 +189,7 @@ TEST_F(
   auto output = parallel::data_parallel(
       m,
       input,
-      /*devices=*/at::nullopt,
+      /*devices=*/c10::nullopt,
       /*output_device=*/torch::Device(torch::kCUDA, 1));
   ASSERT_TRUE(output.defined());
   ASSERT_TRUE(output.device().is_cuda());

@@ -10,11 +10,11 @@
 #include "ATen/DeviceGuard.h"
 #include "ATen/NativeFunctions.h"
 #include "ATen/TensorImpl.h"
-#include "ATen/core/UndefinedTensorImpl.h"
 #include "ATen/Utils.h"
 #include "ATen/WrapDimUtils.h"
 #include "ATen/core/Half.h"
-#include "ATen/core/optional.h"
+#include "ATen/core/UndefinedTensorImpl.h"
+#include "c10/util/Optional.h"
 
 #include <cstddef>
 #include <functional>

@@ -20,7 +20,7 @@ struct MatrixMultiplier {
   at::Tensor tensor_;
 };
 
-bool function_taking_optional(at::optional<at::Tensor> tensor) {
+bool function_taking_optional(c10::optional<at::Tensor> tensor) {
   return tensor.has_value();
 }

@@ -53,7 +53,7 @@ void test_argument_checking_for_serialized_modules(
   try {
     module->forward({torch::jit::IValue(1), torch::jit::IValue(2)});
     assert(false);
-  } catch (const at::Error& error) {
+  } catch (const c10::Error& error) {
     assert(
         std::string(error.what_without_backtrace())
             .find("Expected at most 1 argument(s) for operator 'forward', "
@@ -63,7 +63,7 @@ void test_argument_checking_for_serialized_modules(
   try {
     module->forward({torch::jit::IValue(5)});
     assert(false);
-  } catch (const at::Error& error) {
+  } catch (const c10::Error& error) {
     assert(
         std::string(error.what_without_backtrace())
             .find("Expected value of type Dynamic for argument 'input' in "
@@ -73,7 +73,7 @@ void test_argument_checking_for_serialized_modules(
   try {
     module->forward({});
     assert(false);
-  } catch (const at::Error& error) {
+  } catch (const c10::Error& error) {
     assert(
         std::string(error.what_without_backtrace())
             .find("forward() is missing value for argument 'input'") == 0);

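Errors raised by the moved exception type are now caught as c10::Error; per the summary, at::Error keeps working as an alias during the transition. A hedged sketch of the catch pattern these tests use (the helper function, and the choice of header, are assumptions):

    #include <functional>
    #include <string>
    #include <ATen/core/Error.h>  // assumed: still exposes c10::Error at this point

    // Run fn and report whether the thrown c10::Error mentions the expected text.
    bool error_message_contains(
        const std::function<void()>& fn,
        const std::string& needle) {
      try {
        fn();
      } catch (const c10::Error& error) {
        // what_without_backtrace() omits the captured stack trace.
        return std::string(error.what_without_backtrace()).find(needle) !=
            std::string::npos;
      }
      return false;
    }
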
@@ -331,7 +331,7 @@ def create_python_bindings(python_functions, has_self, is_module=False):
         elif dispatch_type == 'Tensor &':
             dispatch_type = 'Tensor'
         elif dispatch_type == 'const Device &':
-            dispatch_type = 'at::optional<int32_t>'
+            dispatch_type = 'c10::optional<int32_t>'
         formal = '{} {}'.format(dispatch_type, name)
         return expr, formal

@@ -1387,7 +1387,7 @@ Tensor as_strided_backward(Tensor grad, TensorGeometry input_geometry, IntList s
   auto storage = at::zeros({base_size}, grad.options());
 
   // prepare indices tensor if we will do index_add_ later
-  at::optional<at::Tensor> flatten_full_indices;
+  c10::optional<at::Tensor> flatten_full_indices;
   if (inp_maybe_overlap || out_maybe_overlap) {
     flatten_full_indices = at::arange(0, base_size, grad.options().dtype(at::kLong));
   }
