mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Another code-mod for clang-tidy: Conversion operators should be marked explicit so that they don't cause unwanted implicit conversions. This is especially important for `operator bool()`, see https://stackoverflow.com/questions/39995573/when-can-i-use-explicit-operator-bool-without-a-cast ezyang apaszke Pull Request resolved: https://github.com/pytorch/pytorch/pull/9336 Reviewed By: apaszke Differential Revision: D8807065 Pulled By: goldsborough fbshipit-source-id: 0e9f4ebd0048a2a510c0d05fa410695d7e977eb1
59 lines
1.6 KiB
C++
59 lines
1.6 KiB
C++
#pragma once
|
|
#include <memory>
|
|
#include <vector>
|
|
#include "ATen/optional.h"
|
|
|
|
// Forward declaration only: this header needs at::Tensor solely as a
// parameter type, so we avoid pulling in the full Tensor header.
namespace at {
struct Tensor;
} // namespace at
|
|
namespace torch { namespace jit {

// The interpreter runs Graphs with Tensor inputs and Tensor outputs.
// A separate component in the autograd handles unwrapping and wrapping
// variable objects for use in the interpreter.

// Forward declarations so this header stays light; full definitions live
// in their own headers / the implementation file.
// (Fix: `struct Node;` was previously declared twice — deduplicated.)
struct Node;
struct GraphExecutor;
struct CodeImpl;
struct InterpreterStateImpl;
struct Graph;
struct TensorType;
|
|
|
|
struct Code {
|
|
Code()
|
|
: pImpl(nullptr) {}
|
|
Code(std::shared_ptr<Graph>& graph);
|
|
~Code();
|
|
|
|
// Returns pointers to GraphExecutors created to run GraphExecutor nodes in the given graph.
|
|
const std::vector<GraphExecutor*>& executors();
|
|
|
|
explicit operator bool() const {
|
|
return pImpl != nullptr;
|
|
}
|
|
|
|
private:
|
|
std::shared_ptr<CodeImpl> pImpl;
|
|
friend struct InterpreterStateImpl;
|
|
friend std::ostream & operator<<(std::ostream & out, const Code & code);
|
|
};
|
|
|
|
// A running (or runnable) instance of a compiled Code object: holds the
// interpreter's mutable execution state between stages.
struct InterpreterState {
  // Creates an interpreter positioned at the first stage of `code`.
  // NOTE(review): single-arg ctor is non-explicit — presumably deliberate
  // for existing call sites; confirm before tightening.
  InterpreterState(const Code & code);
  // advance the interpreter state by running one stage. Returning the
  // outputs for that stage, suspending the computation.
  // Call this function again continues computation where it left off.
  // `stack` serves as both input and output: it supplies the stage's
  // inputs and holds its outputs on return.
  void runOneStage(std::vector<at::Tensor> & stack);
  // Expected tensor type (shape/metadata) of input `i`.
  // NOTE(review): no bounds check visible here — presumably `i` must be a
  // valid input index; verify in the implementation.
  const TensorType & tensorTypeForInput(size_t i) const;
  ~InterpreterState();
  // create a copy of InterpreterState with its current state
  // used when retain_graph=True so that stages can be re-run
  InterpreterState clone() const;
private:
  // Wraps an existing impl — presumably the path clone() uses; confirm.
  InterpreterState(InterpreterStateImpl * pImpl);
  // Shared handle to the interpreter's mutable state.
  std::shared_ptr<InterpreterStateImpl> pImpl;
};

}} // namespace torch::jit
|