mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: IValue is short for interpreter value. It is used frequently so a short name is important. This will allow us to implement more non-tensor types in an efficient way and remove many hacks from the compiler. This PR is limited. It only introduces IValue and changes interpreter to use it. Follow up PRs will: * Change the way aten_ops consume non-tensor types so that integer lists are no longer represented as Tensors. * Introduce TensorList as a fundamental type and remove all vararg handling in gen_jit_dispatch * Change the compiler to implement math on primitive numbers rather than converting to tensors. jamesr66a apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/9368 Reviewed By: ezyang Differential Revision: D8817598 Pulled By: zdevito fbshipit-source-id: 29dce80611ce5f6384234de9d12a67861d2b112f
64 lines
1.7 KiB
C++
64 lines
1.7 KiB
C++
#pragma once
|
|
|
|
#include <type_traits>
#include <vector>

#include <ATen/ATen.h>
|
|
|
|
namespace torch {
|
|
|
|
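// Let E be the element type of `inputs`. When `inputs` is const or a
// temporary, the passed in function must take E by value (E) or by
// const reference (const E&); a function taking E by non-const reference
// (E&) then fails to compile, because no viable fmap overload is found.
// When `inputs` is a non-const lvalue, the `T& inputs` overload binds each
// element by non-const reference, so the function may also take E&.
//
// No explicit template parameters are required.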
|
|
|
|
// Overload for an explicit function applied to a const (or temporary) input
// range. `inputs` may be any type with member begin()/size() that supports
// range-for iteration (e.g. std::vector or at::ArrayRef). Each element is
// passed to `fn` as a const reference, and the results are collected in
// iteration order.
//
// The element type of the returned vector is fn's return type with
// references and cv-qualifiers stripped (std::decay). This lets `fn` return
// by reference; the referred-to value is copied into the result. Without the
// decay, such a function would name an ill-formed std::vector<const E&>.
template<typename F, typename T>
inline auto fmap(const T& inputs, const F& fn)
    -> std::vector<typename std::decay<decltype(fn(*inputs.begin()))>::type> {
  std::vector<typename std::decay<decltype(fn(*inputs.begin()))>::type> r;
  r.reserve(inputs.size());
  for (const auto& input : inputs)
    r.push_back(fn(input));
  return r;
}
|
|
|
|
// Overload for an explicit function applied to an lvalue `inputs`. Elements
// are bound as `auto&`, so when `inputs` is non-const, `fn` may take its
// argument by non-const reference (and may modify the elements of `inputs`
// in place). The results are collected in iteration order.
//
// The element type of the returned vector is fn's return type with
// references and cv-qualifiers stripped (std::decay). A function that
// returns the element by reference (e.g. `E& f(E&)`) is therefore supported;
// the result holds copies rather than an ill-formed std::vector<E&>.
template<typename F, typename T>
inline auto fmap(T& inputs, const F& fn)
    -> std::vector<typename std::decay<decltype(fn(*inputs.begin()))>::type> {
  std::vector<typename std::decay<decltype(fn(*inputs.begin()))>::type> r;
  r.reserve(inputs.size());
  for (auto& input : inputs)
    r.push_back(fn(input));
  return r;
}
|
|
|
|
// Constructor application: converts each element of `inputs` to R using the
// expression R(element), and returns the converted values in iteration order.
// This overload exists because C++ does not allow taking the address of a
// constructor, so R's constructor cannot be handed to the function-taking
// fmap overloads. Call it with R given explicitly: fmap<R>(inputs).
// NOTE(review): with a single argument, R(element) is a functional-notation
// cast, which the standard defines as equivalent to a C-style cast. It can
// therefore also perform const/reinterpret conversions, not only
// construction. Consider static_cast<R> once callers are audited.
template<typename R, typename T>
inline std::vector<R> fmap(const T& inputs) {
  std::vector<R> converted;
  converted.reserve(inputs.size());
  for (auto& element : inputs) {
    converted.push_back(R(element));
  }
  return converted;
}
|
|
|
|
template<typename F, typename T>
|
|
inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) {
|
|
std::vector<T> r;
|
|
r.reserve(inputs.size());
|
|
for(auto & input : inputs) {
|
|
if (fn(input)) {
|
|
r.push_back(input);
|
|
}
|
|
}
|
|
return r;
|
|
}
|
|
|
|
template<typename F, typename T>
|
|
inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) {
|
|
return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn);
|
|
}
|
|
|
|
}
|