// NOTE: the original include directives lost their targets in extraction; the
// headers below are a best-effort reconstruction of what this translation
// unit needs.
#include <torch/csrc/jit/runtime/interpreter.h>

#include <ATen/ThreadLocalState.h>
#include <ATen/core/ivalue.h>
#include <ATen/record_function.h>
#include <c10/core/thread_pool.h>
#include <c10/util/Exception.h>
#include <torch/csrc/autograd/grad_mode.h>
#include <torch/csrc/autograd/profiler.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/runtime/exception_message.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/instruction.h>
#include <torch/csrc/jit/runtime/interpreter/code_impl.h>
#include <torch/csrc/jit/runtime/interpreter/frame.h>
#include <torch/csrc/jit/runtime/jit_exception.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/profiling_record.h>
#include <torch/csrc/jit/runtime/script_profile.h>
#include <torch/csrc/jit/runtime/vararg_functions.h>

#ifdef USE_RPC
#include <torch/csrc/distributed/autograd/context/container.h>
using torch::distributed::autograd::DistAutogradContainer;
#endif

#include <algorithm>
#include <exception>
#include <iostream>
#include <memory>
#include <mutex>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <typeinfo>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace jit {

using CodeImpl = interpreter::CodeImpl;

// Before we translate to interpreter instructions, we do
// some preprocessing of the graph to turn it into a form that is closer
// to what the instructions will look like.
// In particular we:
// * Compute whether an input to a node is the last use, so we can issue MOVE
//   rather than LOAD instructions.
// * Insert drop nodes for any node that is unused, to create a dummy use
//   that will cause the interpreter to free the node.
//   A drop node just pops its input off the stack to ensure the interpreter
//   releases references to nodes that are never used. Drop nodes are also
//   inserted when the last use of a node is in some conditionally run control
//   flow (e.g. one side of an If) and the interpreter must free the node only
//   after the control flow has reconverged.
// Outputs are:
// * graph - the post-processed copy of g
// * move_flags[n] - a list of booleans, one for each input,
//   indicating whether this is the last use of the value. The interpreter
//   should generate a MOVE rather than a copy in this case.

TensorTypePtr tensorTypeInCurrentExecutionContext(const at::Tensor& t) {
  if (!t.defined()) {
    return TensorType::get()->withUndefined();
  }
  auto r = TensorType::create(t);
  if (!at::GradMode::is_enabled()) {
    return r->withRequiresGrad(false);
  }
  return r;
}

namespace {
inline int64_t getDistAutogradContextId() {
#ifdef USE_RPC
  return DistAutogradContainer::currentContextId();
#else
  return 0;
#endif
}
} // namespace

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
thread_local InterpreterStateImpl* tls_int_state_ptr_ = nullptr;

struct TLSCurrentInterpreterGuard {
  TLSCurrentInterpreterGuard(InterpreterStateImpl* state) {
    prev_state_ = tls_int_state_ptr_;
    tls_int_state_ptr_ = state;
  }

  ~TLSCurrentInterpreterGuard() {
    tls_int_state_ptr_ = prev_state_;
  }

 private:
  InterpreterStateImpl* prev_state_;
};

// Interpreter state that is used to run a Code.
struct InterpreterStateImpl : c10::intrusive_ptr_target {
  InterpreterStateImpl(const Code& code, TaskLauncher taskLauncher)
      : taskLauncher_(std::move(taskLauncher)) {
    enterFrame(code, 0);
  }

 private:
  using Frame = torch::jit::interpreter::Frame;

  struct WarnedNodes {
   public:
    // Inserts idx into warned_nodes_, returns a boolean indicating whether
    // the insertion actually happened (idx wasn't originally in the set).
    bool insert(int32_t idx) {
      std::unique_lock<std::mutex> lock(mutex_);
      return warned_nodes_.insert(idx).second;
    }

   private:
    std::mutex mutex_;
    std::unordered_set<int32_t> warned_nodes_;
  };

  WarnedNodes warned_nodes_;

  // if we need to suspend, where do we reset the stack?
  // answer: to where it was when we were called, not
  // including any inputs to this function
  int64_t stack_start_ = -1;
  c10::intrusive_ptr<Future> future_;
  TaskLauncher taskLauncher_;

  // this holds all the tensors for this interpreter run
  // we don't bother minimizing the size of this vector, since the extra
  // memory used by the pointers in this will be small
  // instead we are very aggressive about releasing tensors when they become
  // dead to make sure memory management happens efficiently.
  // We optimize for the case where derivatives are run with retain_graph=False;
  // in the case where it is true, the interpreter and this array get copied.
  // If this ever becomes a bottleneck then we _should_ consider minimizing
  // the total number of registers.
  std::vector<IValue> registers;

  // A stack of objects that have been __enter__'d.
  std::vector<IValue> entered_objects;

  std::vector<Frame> frames;

  c10::intrusive_ptr<InterpreterStateImpl> intrusive_from_this() {
    c10::raw::intrusive_ptr::incref(this);
    return c10::intrusive_ptr<InterpreterStateImpl>::reclaim(this);
  }

  void enterFrame(const Code& code, size_t base_pointer) {
    frames.emplace_back(Frame{code.pImpl, 0, base_pointer, c10::nullopt});
    registers.resize(registers.size() + code.pImpl->register_size_);
  }

  void leaveFrame() {
    registers.resize(registers.size() - frames.back().function->register_size_);
    frames.pop_back();
  }

  // registers are indexed relative to the end of the register list so that
  // when we call functions we are referring to the registers of the currently
  // executing function.
  IValue& reg(size_t reg) {
    return *(registers.end() - reg);
  }

  void dump(std::ostream& out, const Stack& stack) const {
    out << "Stack:\n";
    for (const auto& val : stack) {
      out << val;
      out << "\n";
    }
  }

  void runBuiltinFunction(Stack& stack, Function* fn) {
    // BuiltinOpFunction directly invokes a void(Stack&) to implement
    // custom C++ classes. Call run() here with the stack, and we will
    // get the results from that C++ method back in the stack. Advance
    // the PC by 1 without adding any new frame.
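    // Unlike runGraphFunction() below, no new Frame is pushed and no
    // registers are reserved; the builtin operates on the stack in place.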
    fn->run(stack);
    ++frames.back().pc;
  }

  void runGraphFunction(Stack& stack, Function* fn) {
    const Code& code =
        // consider passing
        // `frames.back().function->remaining_bailout_depth_` into
        // `get_executor().getPlanFor()` to propagate caller's depth
        // restrictions onto children. While this strategy has the potential
        // to reduce the number of compilations for too-dynamic callers, we
        // might miss opportunities where a caller is dynamic but a callee
        // gets stable arguments.
        fn->get_executor()
            .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
            .code;
    ++frames.back().pc;
    enterFrame(code, stack.size() - code.num_inputs());
    checkAndStartRecordFunction(frames.back(), stack);
  }

  bool runImpl(Stack& stack) {
    // if we have never run before, then we might have to return the
    // stack when we suspend; record where it starts so we return the right
    // stack
    if (stack_start_ == -1) {
      TORCH_INTERNAL_ASSERT(stack.size() >= frames.back().function->n_inputs);
      stack_start_ = stack.size() - frames.back().function->n_inputs;
    } else {
      // during restarts, all of the stack is always our own, so we leave
      // nothing
      stack_start_ = 0;
    }

    TLSCurrentInterpreterGuard g(this);
    if (frames.back().pc == 0 && stack_start_ == 0) {
      checkAndStartRecordFunction(frames.back(), stack);
    }
    try {
      while (true) {
        Frame& frame = frames.back();
        // std::cout << "RUNNING ";
        // frames.back().function->dump(std::cout, frame.pc);
        Instruction inst = frame.function->instructions_[frame.pc];
        profiling::InstructionSpan instSpan{
            *frame.function->instructions_source()[frame.pc]};
        switch (inst.op) {
          case ENTER: {
            const auto& obj = peek(stack, 0, 1);
            TORCH_INTERNAL_ASSERT(obj.isObject());
            entered_objects.push_back(obj);
            ++frame.pc;
          } break;
          case EXIT: {
            auto obj = entered_objects.back().toObject();
            auto& f = obj->type()->getMethod("__exit__");
            push(stack, std::move(obj));
            entered_objects.pop_back();
            push(stack, IValue());
            push(stack, IValue());
            push(stack, IValue());
            runGraphFunction(stack, &f);
          } break;
          case OP:
            frame.function->operator_table_[inst.X](&stack);
            ++frame.pc;
            break;
          case OPN:
            stack.push_back(inst.N);
            frame.function->operator_table_[inst.X](&stack);
            ++frame.pc;
            break;
          case LOAD:
            stack.emplace_back(reg(inst.X));
            ++frame.pc;
            break;
          case MOVE:
            stack.emplace_back(std::move(reg(inst.X)));
            ++frame.pc;
            break;
          case STORE:
            reg(inst.X) = pop(stack);
            ++frame.pc;
            break;
          case STOREN:
            for (size_t i = inst.N; i > 0; --i) {
              reg(inst.X + i - 1) = pop(stack);
            }
            ++frame.pc;
            break;
          case DROP:
            pop(stack);
            ++frame.pc;
            break;
          case DROPR:
            reg(inst.X) = IValue();
            ++frame.pc;
            break;
          case LOADC:
            stack.emplace_back(frame.function->constant_table_[inst.X]);
            ++frame.pc;
            break;
          case GET_ATTR: {
            auto userObj = pop(stack).toObject();
            auto value = userObj->getSlot(inst.X);
            push(stack, std::move(value));
            ++frame.pc;
          } break;
          case SET_ATTR: {
            auto v = pop(stack);
            auto userObj = pop(stack).toObject();
            userObj->setSlot(inst.X, std::move(v));
            ++frame.pc;
          } break;
          case JF:
            frame.pc += (pop(stack).toBool()) ? 1 : inst.X;
            break;
          case JMP:
            frame.pc += inst.X;
            break;
          case LOOP: {
            // stack: iteration_count, max_iter, cond, loop_carried_deps...
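            // fr points at the (inst.N + 1) loop values on the stack:
            // fr[0] = iteration count, fr[1] = max trip count, fr[2] = loop
            // condition, fr[3..] = loop-carried values. Conceptually the
            // emitted loop behaves like this sketch (not literal code):
            //
            //   int64_t i = 0;
            //   while (i < max_trip_count && cond) {
            //     body(i, carried...);  // body recomputes cond and carried
            //     ++i;
            //   }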
            auto fr = stack.end() - (inst.N + 1);
            int64_t trip_count = fr[0].toInt();
            int64_t max_trip_count = fr[1].toInt();
            bool cond = fr[2].toBool();
            if (trip_count < max_trip_count && cond) {
              fr[2] = trip_count;
              fr[0] = trip_count + 1;
              ++frame.pc;
            } else {
              size_t n_loop_carried = inst.N - 2;
              for (size_t i = 0; i < n_loop_carried; ++i) {
                fr[i] = std::move(fr[i + 3]);
              }
              drop(stack, 3); // iteration_count, max_iter, cond
              frame.pc += inst.X;
            }
          } break;
          case CALL: {
            Function* fn = frame.function->function_table_[inst.X];
            if (!fn->isGraphFunction()) {
              runBuiltinFunction(stack, fn);
            } else {
              runGraphFunction(stack, fn);
            }
          } break;
          case INTERFACE_CALL: {
            // note the hash table lookup to find the function
            // this can be more optimized if necessary, caching parts
            // of the hashing computation or storing the offset when
            // the object is turned into an interface

            // consider passing
            // `frames.back().function->remaining_bailout_depth_` into
            // `get_executor().getPlanFor()` to propagate caller's depth
            // restrictions onto children. While this strategy has the
            // potential to reduce the number of compilations for too-dynamic
            // callers, we might miss opportunities where a caller is dynamic
            // but a callee gets stable arguments.
            Function& function =
                peek(stack, 0, inst.N)
                    .toObject()
                    ->type()
                    ->getMethod(
                        frame.function->constant_table_[inst.X].toStringRef());
            if (!function.isGraphFunction()) {
              runBuiltinFunction(stack, &function);
            } else {
              runGraphFunction(stack, &function);
            }
          } break;
          case RET:
            if (frames.size() > 1) {
              leaveFrame();
              break;
            }
            if (future_) {
              auto num_outputs = frames.back().function->n_outputs;
              if (num_outputs == 1) {
                future_->markCompleted(stack.back());
              } else {
                future_->markCompleted(c10::ivalue::Tuple::create(
                    jit::last(stack, num_outputs).vec()));
              }
            }
            // destroy the last frame and call RecordFunction's end callbacks
            leaveFrame();
            return false;
          case WAIT: {
            auto future = stack.back().toFuture();
            if (!future->completed()) {
              getOrCreateFuture();

              // callback needs to be a struct rather than a lambda so that
              // we can move the stack to the other thread
              struct Callback {
                Callback(
                    c10::intrusive_ptr<InterpreterStateImpl> state,
                    Stack stack)
                    : stateImpl_(std::move(state)),
                      state_(stateImpl_),
                      stack_(std::move(stack)) {
                  dist_autograd_context_id_ = getDistAutogradContextId();
                  state_ = InterpreterState(stateImpl_);
                }
                void operator()(c10::ivalue::Future& /* unused */) {
                  stateImpl_->taskLauncher_(InterpreterContinuation(
                      state_,
                      std::move(stack_),
                      dist_autograd_context_id_,
                      std::move(tls_state_)));
                }

               private:
                c10::intrusive_ptr<InterpreterStateImpl> stateImpl_;
                InterpreterState state_;
                Stack stack_;
                int64_t dist_autograd_context_id_;
                // preserve the original ThreadLocalState
                at::ThreadLocalState tls_state_;
              };

              // we are suspending, so we need to reset the stack to where we
              // started. If it started empty (except for the inputs), we can
              // avoid a true copy by swapping, which leaves the original
              // stack empty.
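              // What follows implements the suspension: the live portion of
              // the stack is moved into `copied`, a Callback holding it is
              // attached to the still-pending future, and runImpl() returns
              // true so callers know ownership of the run has passed to that
              // callback. When the future completes, the callback posts an
              // InterpreterContinuation to taskLauncher_, which re-enters the
              // interpreter at this same WAIT instruction; the future is then
              // found completed and its value is pushed.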
              Stack copied;
              if (stack_start_ == 0) {
                copied.swap(stack);
              } else {
                copied.insert(
                    copied.begin(),
                    std::make_move_iterator(stack.begin() + stack_start_),
                    std::make_move_iterator(stack.end()));
                stack.resize(stack_start_);
              }
              // save pc into the frame so we continue here when restored
              future->addCallback(
                  Callback(intrusive_from_this(), std::move(copied)));

              return true;
            }
            stack.pop_back();
            stack.emplace_back(future->value());
            ++frame.pc;
          } break;
          case PROFILE_OP: {
            auto& frame_id_ref = frame.id;
            if (!frame_id_ref.has_value()) {
              frame_id_ref = Frame::genId();
            }
            const auto& callback =
                frame.function->profile_function_table_[inst.X];
            push(stack, c10::IValue{static_cast<int64_t>(*frame_id_ref)});
            callback(stack);
            ++frame.pc;
            break;
          }
          case FAIL_GUARD: {
            // patch FAIL_GUARD back to GUARD
            GRAPH_DEBUG(
                "Bailout ", inst.X, " triggered via bailout_requests_!");
            frame.function->instructions_[frame.pc].op = GUARD;
            push(stack, false);
            ++frame.pc;
            break;
          }
          case TYPECHECK: {
            int num_inputs = inst.N, i = 0;
            // NOLINTNEXTLINE(clang-diagnostic-sign-compare)
            TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs && num_inputs > 0);
            // Check every input's shape against profiled (expected) shape.
            for (i = 0; i < num_inputs; i++) {
              auto& input = peek(stack, i, num_inputs);
              auto& t = input.toTensor();
              const TypePtr& expected = frame.function->type_table_[inst.X + i];
              auto* expected_type = expected->castRaw<TensorType>();
              if (t.defined() && !expected_type->matchTensor(t)) {
                push(stack, false);
                break;
              }
            }
            if (i == num_inputs) {
              push(stack, true);
            }
            ++frame.pc;
            break;
          }
          case GUARD: {
            if (!stack.back().isTensor()) {
              // stack.back() is an Uninitialized IValue and this is a guard
              // on a block output. Uninitialized IValues are never used
              // so it's safe to pass this guard check
              push(stack, true);
            } else {
              auto& t = stack.back().toTensor();
              const TypePtr& expected = frame.function->type_table_[inst.X];
              auto* expected_type = expected->castRaw<TensorType>();
              if (t.defined() &&
                  !frames.back().symbols2dims.bindSymbolicShapes(
                      t.sizes(), expected_type->symbolic_sizes())) {
                push(stack, false);
              } else {
                push(stack, expected_type->matchTensor(t));
              }
            }
            ++frame.pc;
          } break;
          case TAIL_CALL: {
            GRAPH_DEBUG("running TAIL_CALL for ", inst.X);
            frame.function->function_table_[inst.X]->ensure_defined();
            size_t remaining_bailout_depth =
                frame.function->remaining_bailout_depth_ > 0
                ? frame.function->remaining_bailout_depth_ - 1
                : 0;
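            // TAIL_CALL re-specializes and replaces the current frame rather
            // than pushing a new one: the callee's plan is fetched with one
            // less bailout level (computed above), the inputs are slid down
            // to this frame's base pointer, and the frame is swapped for the
            // callee's before execution continues.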
            const Code& code = frame.function->function_table_[inst.X]
                                   ->get_executor()
                                   .getPlanFor(stack, remaining_bailout_depth)
                                   .code;
            size_t num_inputs = code.num_inputs();
            size_t base_pointer = frame.base_pointer;
            TORCH_INTERNAL_ASSERT(stack.size() >= num_inputs);
            size_t inputs_start = stack.size() - num_inputs;
            for (size_t i = 0; i < num_inputs; ++i) {
              stack.at(base_pointer + i) =
                  std::move(stack.at(inputs_start + i));
            }
            stack.resize(base_pointer + num_inputs);
            leaveFrame();
            enterFrame(code, base_pointer);
            checkAndStartRecordFunction(frames.back(), stack);
          } break;
          case LIST_UNPACK: {
            listUnpack(stack, inst.X);
            ++frame.pc;
          } break;
          case TUPLE_CONSTRUCT: {
            tupleConstruct(stack, inst.X);
            ++frame.pc;
          } break;
          case TUPLE_SLICE: {
            tupleSlice(stack, inst.X, inst.X + inst.N);
            ++frame.pc;
          } break;
          case NAMED_TUPLE_CONSTRUCT: {
            namedTupleConstruct(
                stack,
                frame.function->type_table_[inst.X]->expect<TupleType>(),
                inst.N);
            ++frame.pc;
          } break;
          case LIST_CONSTRUCT: {
            const auto& type =
                frame.function->type_table_[inst.X]->expectRef<ListType>();
            listConstruct(stack, type, inst.N);
            ++frame.pc;
          } break;
          case DICT_CONSTRUCT: {
            const auto& type =
                frame.function->type_table_[inst.X]->expectRef<DictType>();
            dictConstruct(stack, type, inst.N);
            ++frame.pc;
          } break;
          case CREATE_OBJECT: {
            auto type =
                frame.function->type_table_[inst.X]->expect<ClassType>();
            createObject(stack, type);
            ++frame.pc;
          } break;
          case ISINSTANCE: {
            at::ArrayRef<TypePtr> types(
                &(frame.function->type_table_[inst.X]),
                &(frame.function->type_table_[inst.X + inst.N]));
            isinstance(stack, types);
            ++frame.pc;
          } break;
          case FORK: {
            // Move inputs to a separate stack
            Function* forked_fn = frame.function->function_table_[inst.X];
            InterpreterState forked_interpreter(
                forked_fn->get_executor()
                    .getPlanFor(stack, GraphExecutor::getDefaultNumBailOuts())
                    .code,
                taskLauncher_);
            InterpreterContinuation continuation(
                forked_interpreter,
                Stack(stack.end() - inst.N, stack.end()),
                getDistAutogradContextId());
            drop(stack, inst.N);
            push(stack, forked_interpreter.getFuture());
            taskLauncher_(std::move(continuation));
            ++frame.pc;
          } break;
          case WARN: {
            // Keeps track of which WARN instruction has been executed before;
            // we only want to execute each WARN once to match default Python
            // warning behavior.
            bool need_warn = true;
            if (inst.X != -1) {
              need_warn = warned_nodes_.insert(inst.X);
            }
            Node* node =
                frames.back().function->instructions_source_.at(frame.pc);
            auto range = node->sourceRange().source();
            if (range->filename()) {
              drop(stack, 1);
              const auto& msg = stack.back().toStringRef();
              if (need_warn) {
                auto line = range->starting_line_no() +
                    range->lineno_for_offset(node->sourceRange().start());
                c10::SourceLocation location{
                    "", range->filename()->c_str(), uint32_t(line)};
                // Sends the warning to the warning handler with the
                // "verbatim" flag. This flag ensures the warning handler
                // will print the exception as configured.
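                // With a known source location the warning is attributed to
                // the original script line, so a Python warning handler will
                // render something like (illustrative only):
                //   my_script.py:12: UserWarning: <msg>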
                c10::Warning::warn(location, msg, /*verbatim=*/true);
              }
              stack.pop_back();
            } else {
              const auto& msg = stack.back().toStringRef();
              if (need_warn) {
                TORCH_WARN(msg);
              }
              stack.pop_back();
            }
            ++frame.pc;
          } break;
        }
      }
    } catch (std::exception& e) {
      for (auto it = entered_objects.rbegin(), end = entered_objects.rend();
           it != end;
           ++it) {
        auto& f = it->toObject()->type()->getMethod("__exit__");
        Stack stack;
        push(stack, *it);
        push(stack, IValue());
        push(stack, IValue());
        push(stack, IValue());
        try {
          f.run(stack);
        } catch (std::exception& e) {
          std::ostringstream ss;
          ss << "The following operation failed in the TorchScript interpreter.\n";
          formatStackTrace(ss);
          ss << "RuntimeError: " << ExceptionMessage(e) << "\n";
        }
      }
      bool is_jit_exception = dynamic_cast<JITException*>(&e);
      // Janky af. See https://github.com/pytorch/pytorch/issues/54612
      auto* not_implemented_error = dynamic_cast<c10::NotImplementedError*>(&e);
      handleError(ExceptionMessage(e), is_jit_exception, not_implemented_error);
      return false;
    }
  }

  void formatStackTrace(std::ostream& out) {
    format_stack_trace(out, callstack());
  }

  void handleError(
      const ExceptionMessage& msg,
      bool is_jit_exception,
      c10::NotImplementedError* not_implemented_error) {
    std::ostringstream ss;
    ss << "The following operation failed in the TorchScript interpreter.\n";
    formatStackTrace(ss);
    ss << "RuntimeError: " << msg << "\n";
    if (future_) {
      future_->setError(std::make_exception_ptr(Future::FutureError(ss.str())));
    } else if (is_jit_exception) {
      throw JITException(ss.str());
    } else if (not_implemented_error) {
      throw c10::NotImplementedError(
          ss.str(),
          not_implemented_error->backtrace(),
          not_implemented_error->caller());
    } else {
      throw std::runtime_error(ss.str());
    }
  }

  static void checkAndStartRecordFunction(Frame& frame, Stack& stack) {
    bool pre_sampled = false;
    if (!frame.record_function && at::hasCallbacks() &&
        at::shouldRunRecordFunction(&pre_sampled)) {
      auto rec_fn = std::make_unique<at::RecordFunction>(
          at::RecordScope::TORCHSCRIPT_FUNCTION, pre_sampled);
      if (rec_fn->isActive()) {
        if (rec_fn->needsInputs()) {
          rec_fn->before(
              frame.function->function_name_,
              last(stack, frame.function->n_inputs));
        } else {
          rec_fn->before(frame.function->function_name_);
        }
        frame.record_function = std::move(rec_fn);
      }
    }
  }

 public:
  std::vector<StackEntry> callstack() const {
    std::vector<StackEntry> entries;
    for (size_t i = 0; i < frames.size(); ++i) {
      const Frame& frame = frames[i];
      std::string previous_fn_name = frame.function->function_name_;
      size_t pc = frame.pc;
      // CALL nodes have already advanced the pc, so
      // undo that to report the call node
      if (i + 1 < frames.size()) {
        --pc;
      }
      Node* node = frame.function->instructions_source_[pc];
      if (node->callstack()) {
        for (const auto& p : (*node->callstack())->vec()) {
          entries.emplace_back(StackEntry{previous_fn_name, std::get<1>(p)});
          previous_fn_name = std::get<0>(p)->name();
        }
      }
      entries.emplace_back(StackEntry{previous_fn_name, node->sourceRange()});
    }
    return entries;
  }

  c10::intrusive_ptr<Future> getOrCreateFuture() {
    if (!future_) {
      future_ =
          c10::make_intrusive<Future>(frames.front().function->return_type_);
    }
    return future_;
  }

  c10::intrusive_ptr<Future> runAsync(Stack& stack) {
    getOrCreateFuture();
    runImpl(stack);
    return future_;
  }

  void run(Stack& stack) {
    if (runImpl(stack)) {
      future_->wait();

      auto num_outputs = frames.front().function->n_outputs;
      if (num_outputs == 1) {
        push(stack, future_->value());
      } else {
        auto tuple = future_->value().toTuple();
        for (const IValue& value : tuple->elements()) {
          push(stack, value);
        }
      }
    }
  }
};

std::vector<StackEntry> currentCallstack() {
  if (tls_int_state_ptr_) {
    auto cs = tls_int_state_ptr_->callstack();
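    // callstack() builds entries from the outermost frame inward; reverse so
    // the innermost (most recent) call comes first.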
    std::reverse(cs.begin(), cs.end());
    return cs;
  }
  return std::vector<StackEntry>();
}

std::ostream& operator<<(std::ostream& out, const Code& code) {
  out << *code.pImpl->graph_ << "\n";
  code.pImpl->dump(out);
  return out;
}

Code::Code(
    const std::shared_ptr<Graph>& graph,
    std::string function_name,
    size_t remaining_bailout_depth)
    : pImpl(new CodeImpl(
          graph,
          std::move(function_name),
          remaining_bailout_depth)) {}

Code::Code(CodeImpl* codeImpl) : pImpl(codeImpl) {}

Code::~Code() = default;

MobileCode::MobileCode(
    const std::shared_ptr<Graph>& graph,
    std::string function_name,
    bool emit_default_input_instructions,
    size_t remaining_bailout_depth)
    : Code(new interpreter::MobileCodeImpl(
          graph,
          std::move(function_name),
          emit_default_input_instructions,
          remaining_bailout_depth)) {}

MobileCode::~MobileCode() = default;

const std::vector<GraphExecutor*>& Code::grad_executors() {
  return pImpl->grad_executors();
}

const std::vector<GraphExecutor*>& Code::diff_graph_op_executors() {
  return pImpl->diff_graph_op_executors();
}

size_t Code::num_bailouts() const {
  return pImpl->type_table_.size();
}

void Code::request_bailout(size_t index) {
  pImpl->request_bailout(index);
}

size_t Code::num_inputs() const {
  return pImpl->n_inputs;
}

size_t Code::num_outputs() const {
  return pImpl->n_outputs;
}

const std::vector<c10::IValue>& Code::constant_table() const {
  return pImpl->constant_table();
}

const std::vector<Instruction>& Code::instructions() const {
  return pImpl->instructions();
}

const std::unordered_map<std::string, size_t>& Code::op_to_num_specified_args()
    const {
  return pImpl->op_to_num_specified_args();
}

const std::vector<Node*>& Code::instructions_source() const {
  return pImpl->instructions_source();
}

const std::vector<TypePtr>& Code::type_table() const {
  return pImpl->type_table_;
}

size_t Code::register_size() const {
  return pImpl->register_size_;
}

InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
    : pImpl(c10::make_intrusive<InterpreterStateImpl>(
          code,
          std::move(taskLauncher))) {}

InterpreterState::~InterpreterState() = default;

void InterpreterState::run(Stack& stack) {
  static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
}

c10::intrusive_ptr<Future> InterpreterState::runAsync(Stack& stack) {
  return static_cast<InterpreterStateImpl*>(pImpl.get())->runAsync(stack);
}

c10::intrusive_ptr<Future> InterpreterState::getFuture() {
  return static_cast<InterpreterStateImpl*>(pImpl.get())->getOrCreateFuture();
}

InterpreterState::InterpreterState(
    c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl_)
    : pImpl(std::move(pImpl_)) {}

void InterpreterContinuation::operator()() {
#ifdef USE_RPC
  auto prev_dist_id = DistAutogradContainer::currentContextId();
  DistAutogradContainer::forceCurrentContextId(dist_autograd_context_id_);
#endif
  if (tls_state_ != c10::nullopt) {
    at::ThreadLocalStateGuard g(*tls_state_);
    state.runAsync(stack);
  } else {
    state.runAsync(stack);
  }
#ifdef USE_RPC
  DistAutogradContainer::forceCurrentContextId(prev_dist_id);
#endif
}

} // namespace jit
} // namespace torch
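// Typical use of the pieces defined above (an illustrative sketch, not code
// from this file): build a Code from a lowered Graph and execute it with an
// InterpreterState, either synchronously or via the future from runAsync().
//
//   Code code(graph, "forward");
//   Stack stack = {/* inputs */};
//   InterpreterState(code).run(stack);                  // blocking
//   // or
//   auto future = InterpreterState(code).runAsync(stack);
//   future->wait();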