mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: I got some tensor->variable conversion exceptions from `torch/csrc/autograd/variable.h`, which used the `TORCH_ASSERTM` macros instead of `AT_CHECK`, so they didn't have backtraces. This was such a substantial loss for debuggability that I decided to update the whole codebase to use the backtrace-enabled ATen macros instead of `TORCH_ASSERT` and `JIT_ASSERT`, the latter having been an alias of the former. ezyang apaszke zdevito Pull Request resolved: https://github.com/pytorch/pytorch/pull/9575 Differential Revision: D8924566 Pulled By: goldsborough fbshipit-source-id: 7a4013b13eec9dbf024cef94cf49fca72f61d441
51 lines
1.3 KiB
C++
51 lines
1.3 KiB
C++
#pragma once
|
|
|
|
#include "torch/csrc/utils/functional.h"
|
|
|
|
#include <ATen/ATen.h>

#include <cstdint>
#include <utility>
#include <vector>
|
|
|
|
namespace torch { namespace utils {
|
|
|
|
// Flattens every tensor in `tensors` to 1-D and concatenates them, in list
// order, into a single 1-D tensor.
//
// When exactly one tensor is given, at::cat is skipped and the flattened
// tensor itself is returned (the result of view(), which may share storage
// with the input). Any other count, including an empty list, goes through
// at::cat.
inline at::Tensor flatten_dense_tensors(at::TensorList tensors) {
  // contiguous() first, so that view({-1}) is always legal on the result.
  const auto to_flat = [](const at::Tensor& t) {
    return t.contiguous().view({-1});
  };

  if (tensors.size() != 1) {
    return at::cat(fmap(tensors, to_flat));
  }
  return to_flat(tensors[0]);
}
|
|
|
|
// Splits `flat` along dimension 0 into consecutive chunks and reshapes each
// chunk like the corresponding entry of `tensors`. This undoes
// flatten_dense_tensors when `flat` came from that function applied to a list
// of tensors with the same shapes.
//
// `tensors` supplies only shapes and element counts; its data is not read.
// Each output is built with narrow() + view() on `flat`.
//
// @param flat     Tensor holding the concatenated elements along dim 0.
// @param tensors  Tensors whose sizes define the shape of each output chunk.
// @return         One tensor per entry of `tensors`, in the same order.
inline std::vector<at::Tensor> unflatten_dense_tensors(const at::Tensor& flat, at::TensorList tensors) {
  std::vector<at::Tensor> outputs;
  outputs.reserve(tensors.size());
  // Signed 64-bit offset, matching the type of numel() and of narrow()'s
  // `start` argument. The previous size_t offset mixed unsigned and signed
  // arithmetic and relied on implicit conversions at both call sites.
  int64_t offset = 0;
  for (const auto& tensor : tensors) {
    const auto numel = tensor.numel();
    outputs.push_back(flat.narrow(0, offset, numel).view(tensor.sizes()));
    offset += numel;
  }
  return outputs;
}
|
|
|
|
|
|
struct TensorGroup {
|
|
std::vector<at::Tensor> tensors;
|
|
size_t size = 0;
|
|
|
|
at::Type& type() {
|
|
AT_ASSERT(!tensors.empty());
|
|
return tensors[0].type();
|
|
}
|
|
};
|
|
|
|
std::vector<TensorGroup> take_tensors(at::TensorList tensors, size_t size_limit);
|
|
void reorder_tensors_like(std::vector<at::Tensor>& tensors, at::TensorList order);
|
|
|
|
std::pair<at::Tensor, at::Tensor> flatten_sparse_tensors(at::TensorList tensors);
|
|
|
|
std::vector<at::Tensor> unflatten_sparse_tensors(
|
|
const at::Tensor& flat_indices,
|
|
const at::Tensor& flat_values,
|
|
at::TensorList tensors);
|
|
|
|
}}
|