pytorch/torch/csrc/utils/functional.h
Zachary DeVito 9ed2190bdb Add a tagged union type that replaces tensor in the interpreter. (#9368)
Summary:
IValue is short for interpreter value. It is used frequently so a short name is important.
This will allow us to implement more non-tensor types in an efficient way and remove
many hacks from the compiler.

This PR is limited. It only introduces IValue and changes interpreter to use it.
Follow up PRs will:
* Change the way aten_ops consume non-tensor types so that integer lists
  are no longer represented as Tensors.
* Introduce TensorList as a fundamental type and remove all vararg handling in gen_jit_dispatch
* Change the compiler to implement math on primitive numbers rather than converting to tensors.

jamesr66a  apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9368

Reviewed By: ezyang

Differential Revision: D8817598

Pulled By: zdevito

fbshipit-source-id: 29dce80611ce5f6384234de9d12a67861d2b112f
2018-07-16 15:40:22 -07:00

64 lines
1.7 KiB
C++

#pragma once
#include <type_traits>
#include <vector>
#include <ATen/ATen.h>
namespace torch {
// fmap(inputs, fn): applies `fn` to every element of `inputs` and collects
// the results, in iteration order, into a newly built std::vector.
//
// `inputs` may be any range that provides begin()/end() and size()
// (e.g. std::vector or at::ArrayRef).
//
// This overload takes `inputs` by const reference, so `fn` receives const
// elements: it must take the element by value (T) or by const reference
// (const T&). A callable taking T by non-const reference is not viable here
// and the call fails to compile with "no matching function"; a non-const
// container dispatches to the overload below, which also accepts T&.
//
// The result's element type is the *decayed* return type of `fn`, so a
// callable returning a reference (e.g. `const std::string&`) or a
// const-qualified value still produces a vector of plain values.
//
// No explicit template parameters are required.
// Overload for explicit function and ArrayRef
template<typename F, typename T>
inline auto fmap(const T& inputs, const F& fn)
    -> std::vector<typename std::decay<decltype(fn(*inputs.begin()))>::type> {
  std::vector<typename std::decay<decltype(fn(*inputs.begin()))>::type> r;
  r.reserve(inputs.size());
  for (const auto& input : inputs)
    r.push_back(fn(input));
  return r;
}
// Mutable-container overload of fmap: takes `inputs` by non-const reference,
// so it is preferred when the caller passes a non-const lvalue container, and
// `fn` may take each element by non-const reference (T&) and modify it in
// place. Results are collected in iteration order; the element type of the
// returned vector is the decayed return type of `fn` (so reference- or
// const-returning callables yield a vector of plain values).
template<typename F, typename T>
inline auto fmap(T& inputs, const F& fn)
    -> std::vector<typename std::decay<decltype(fn(*inputs.begin()))>::type> {
  std::vector<typename std::decay<decltype(fn(*inputs.begin()))>::type> r;
  r.reserve(inputs.size());
  // auto&& binds lvalue elements as T& (so fn can still mutate them) and also
  // accepts elements yielded by value/proxy, e.g. those of std::vector<bool>.
  for (auto&& input : inputs)
    r.push_back(fn(input));
  return r;
}
// C++ forbids taking an address of a constructor, so here's a workaround...
// Overload for constructor (R) application: returns a std::vector<R> holding
// R(element) for each element of `inputs`, in iteration order.
// R must be spelled out explicitly, e.g. fmap<std::string>(c_strings).
template<typename R, typename T>
inline std::vector<R> fmap(const T& inputs) {
  std::vector<R> r;
  r.reserve(inputs.size());
  // const auto& (not auto&) also binds elements that a const iterator yields
  // by value, such as those of std::vector<bool>.
  for (const auto& input : inputs)
    // R(input) is an explicit (functional-style) conversion, so explicit
    // constructors and conversions such as double -> int are allowed.
    r.push_back(R(input));
  return r;
}
// filter(inputs, fn): returns a new vector containing copies of the elements
// of `inputs` for which `fn` evaluates to true, preserving their order.
// Capacity for every input element is reserved up front, so the result may
// hold more capacity than it ends up using.
template<typename F, typename T>
inline std::vector<T> filter(at::ArrayRef<T> inputs, const F& fn) {
  std::vector<T> kept;
  kept.reserve(inputs.size());
  for (const auto& candidate : inputs) {
    if (fn(candidate))
      kept.push_back(candidate);
  }
  return kept;
}
template<typename F, typename T>
inline std::vector<T> filter(const std::vector<T>& inputs, const F& fn) {
return filter<F, T>(static_cast<at::ArrayRef<T>>(inputs), fn);
}
}