mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/28605 This was added because std::to_string isn't available in libstc++ on Android. Use it in more places to get the PyTorch Android build working with libstdc++. Test Plan: Internal android build. Reviewed By: jerryzh168 Differential Revision: D18099520 fbshipit-source-id: 17a2b617c2d21deadd0fdac1db849823637981fc
81 lines
2.3 KiB
C++
81 lines
2.3 KiB
C++
#include <torch/csrc/jit/function.h>
|
|
#include <torch/csrc/jit/passes/inliner.h>
|
|
|
|
#include <torch/csrc/jit/script/error_report.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace {
|
|
FunctionSchema defaultSchemaFor(const Function& function) {
|
|
std::vector<Argument> args;
|
|
std::vector<Argument> returns;
|
|
Graph& g = *function.graph();
|
|
size_t num_inputs = function.num_inputs();
|
|
for (size_t i = 0; i < num_inputs; ++i) {
|
|
const Value* v = g.inputs().at(i);
|
|
std::string name = v->hasDebugName() ? v->debugNameBase()
|
|
: ("argument_" + c10::to_string(i));
|
|
args.emplace_back(std::move(name), unshapedType(g.inputs()[i]->type()));
|
|
}
|
|
for (size_t i = 0; i < g.outputs().size(); ++i) {
|
|
returns.emplace_back("", unshapedType(g.outputs()[i]->type()));
|
|
}
|
|
return {function.name(), "", std::move(args), std::move(returns)};
|
|
}
|
|
} // namespace
|
|
|
|
struct RecursiveMethodCallError : public std::exception {};
|
|
void placeholderCreator(Function&) {
|
|
throw RecursiveMethodCallError();
|
|
}
|
|
|
|
void Function::run(Stack& stack) {
|
|
get_executor().run(stack);
|
|
}
|
|
|
|
void Function::run(Stack&& stack) {
|
|
run(stack);
|
|
}
|
|
|
|
IValue Function::operator()(
|
|
std::vector<IValue> stack,
|
|
const Kwargs& kwargs) {
|
|
getSchema().checkAndNormalizeInputs(stack, kwargs);
|
|
run(stack);
|
|
return stack.front();
|
|
}
|
|
|
|
void Function::ensure_defined() {
|
|
try {
|
|
if (function_creator_) {
|
|
auto creator = function_creator_;
|
|
function_creator_ = placeholderCreator;
|
|
creator(*this);
|
|
function_creator_ = nullptr;
|
|
}
|
|
} catch (RecursiveMethodCallError&) {
|
|
throw script::ErrorReport() // TODO: once lower_first_class methods is
|
|
// removed re-establish callsite info for
|
|
// debugging
|
|
<< " method '" << name() << "' is called recursively. "
|
|
<< "Recursive calls are not supported";
|
|
}
|
|
check_single_output();
|
|
}
|
|
|
|
const FunctionSchema& Function::getSchema() const {
|
|
if (schema_ == nullptr) {
|
|
schema_ = make_unique<FunctionSchema>(defaultSchemaFor(*this));
|
|
}
|
|
return *schema_;
|
|
}
|
|
|
|
void preoptimizeGraph(std::shared_ptr<Graph>& graph) {
|
|
// TODO: Invoke cleanup passes before and after inlining to reduce amount of
|
|
// code we're copying.
|
|
Inline(*graph);
|
|
}
|
|
|
|
} // namespace jit
|
|
} // namespace torch
|