Implement MM fusion (MM with add reduction tree). A tree whose leaves are matrix multiplies and whose inner vertices are adds can be computed as a single mm. Such subgraphs often appear in the backward pass when a single weight is reused multiple times (e.g. in RNNs). NOTE: this seems to be slightly slower on the GPU than the naive implementation, but it's a huge win on the CPU (think 100x lower overhead).
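The fusion exploits the block-matrix identity mm(A1, B1) + mm(A2, B2) == mm(cat({A1, A2}, 1), cat({B1, B2}, 0)): concatenating the left operands along columns and the right operands along rows turns the whole add-reduction tree into one matrix multiply, paying one kernel call instead of one per leaf. Below is a minimal, illustrative ATen sketch of that identity; the standalone main, the shapes, and the variable names are assumptions for the example, not part of this change.

// mm_fusion_sketch.cpp -- illustrative only; not part of this file.
#include <ATen/ATen.h>
#include <iostream>

int main() {
  // Two matrix multiplies feeding an add: mm(A1, B1) + mm(A2, B2).
  auto A1 = at::randn({4, 5});
  auto B1 = at::randn({5, 3});
  auto A2 = at::randn({4, 7});
  auto B2 = at::randn({7, 3});

  // Naive evaluation of the add-reduction tree: two mms and one add.
  auto naive = at::mm(A1, B1) + at::mm(A2, B2);

  // Fused evaluation: concatenate the left operands along columns and the
  // right operands along rows, then do a single mm of the block matrices.
  auto fused = at::mm(at::cat({A1, A2}, /*dim=*/1), at::cat({B1, B2}, /*dim=*/0));

  std::cout << std::boolalpha << at::allclose(naive, fused) << "\n";  // prints: true
  return 0;
}

This also suggests why the win shows up on the CPU: the fused form replaces many small launches and dispatches with one, which matters most where per-call overhead dominates, as the NOTE above observes.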
321 lines
11 KiB
C++
#include "python_compiled_function.h"
|
|
|
|
#include "torch/csrc/jit/pybind.h"
|
|
#include "torch/csrc/autograd/grad_mode.h"
|
|
#include "torch/csrc/autograd/variable.h"
|
|
#include "torch/csrc/jit/tracer.h"
|
|
#include "torch/csrc/jit/passes/dead_code_elimination.h"
|
|
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
|
|
#include "torch/csrc/jit/passes/peephole.h"
|
|
#include "torch/csrc/jit/passes/graph_fuser.h"
|
|
#include "torch/csrc/jit/passes/inplace_check.h"
|
|
#include "torch/csrc/jit/passes/batch_mm.h"
|
|
#include "torch/csrc/jit/python_arg_flatten.h"
|
|
#include "torch/csrc/jit/interpreter.h"
|
|
#include "torch/csrc/jit/interpreter_autograd_function.h"
|
|
|
|
#include <algorithm>
|
|
#include <functional>
|
|
#include <atomic>
|
|
|
|
namespace torch { namespace jit { namespace python {
|
|
|
|
using namespace torch::autograd;
|
|
using namespace torch::jit::tracer;
|
|
|
|
namespace {
|
|
|
|
// pybind casts are really verobse...
|
|
py::object steal(py::handle x) {
|
|
return py::reinterpret_steal<py::object>(x);
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
// Lifecycle of a CompiledFunction:
//
// - It is given an underlying function, which knows how to actually
//   execute the code that we want to compile.
// - When we encounter an input configuration for which we don't
//   have an optimized trace, we run the underlying function, tracing its
//   result. The trace is not done yet, so we save it into our set of pending
//   traces for that configuration.
// - When we encounter an input configuration whose trace is "ready"
//   (that is, we've seen all of the passes, so the trace contains
//   forwards/backwards/etc), we compile it, and then register this
//   as the compiled trace.
// - When we encounter an input configuration whose trace is compiled,
//   we just directly run the compiled trace.
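//
// Illustrative walk-through (added for exposition, not from the original source):
// with nderivs = 1 and gradients enabled, add_trace() starts the tracer with
// num_stages == 2, so a pending trace only counts as complete once the backward
// stage has been traced as well; only then does ready() run the optimization
// passes below and build the interpreter factory for that input configuration.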
struct CompiledFunction {

  struct TraceForKey {
    TraceForKey(CompiledFunction& fn, bool grad_enabled)
      : fn_(fn)
      , grad_enabled_(grad_enabled) {}

    bool ready() {
      if (is_ready_) return true;

      // Remove expired traces
      traces_.erase(std::remove_if(traces_.begin(),
                                   traces_.end(),
                                   [](const std::shared_ptr<TracingState>& state) {
                                     return state->is_expired();
                                   }),
                    traces_.end());

      // Check if any trace is complete
      auto complete_it = std::find_if(traces_.begin(),
                                      traces_.end(),
                                      [](const std::shared_ptr<TracingState>& state) {
                                        return state->is_complete();
                                      });
      if (complete_it == traces_.end())
        return false;

      auto complete_trace = *complete_it; // NOTE: copy, because we clear right after
      traces_.clear();

      // Now, we have a complete trace. Compile it.
      EliminateDeadCode(complete_trace->graph);
      CheckInplace(complete_trace->graph);
      if (fn_.optimize_) {
        PeepholeOptimize(complete_trace->graph);
        BatchMM(complete_trace->graph);
        FuseGraph(complete_trace->graph);
        EliminateCommonSubexpression(complete_trace->graph);
      }
      factory_ = std::make_shared<InterpreterFunctionFactory>(complete_trace.get());
      graph_ = complete_trace->graph;
      is_ready_ = true;
      return true;
    }

    variable_list run(variable_list inputs) {
      JIT_ASSERT(is_ready_);
      AutoNoGIL _gil_guard;
      auto fn = factory_->construct();
      fn->willReleaseVariables(); // forward pass is never reused, so it is safe to release anything it can
      return fn->apply(inputs);
    }

    PyObject* add_trace(PyObject *args, ParsedArgs input_info) {
      JIT_ASSERT(!is_ready_);
      // Start tracing
      AutoGradMode grad_mode(grad_enabled_);
      auto num_stages = grad_enabled_ ? fn_.nderivs_ + 1 : 1;
      auto enter_info = tracer::enter(fmap<TraceInput>(input_info.vars), num_stages);
      auto & trace = enter_info.first;
      auto & new_vars = enter_info.second;

      // Enter returns us a new list of Variables to use as inputs, so handle that.
      std::size_t num_all_inputs = input_info.vars.size();
      std::size_t num_captured = fn_.captured_vars_.size();
      // Check that no captured Variables were replaced by enter. It's hard to handle that.
      for (std::size_t i = num_all_inputs - num_captured; i < num_all_inputs; ++i) {
        TORCH_EXPECTM(input_info.vars[i].get() == new_vars[i].get(),
                      "Some of the Variables captured by the JIT are repeated");
      }
      // Now only arguments to this function could have changed. Slice their vars out, and
      // re-create the structure of args, but using new Variables.
      variable_list new_inputs(new_vars.begin(),
                               new_vars.end() - num_captured);
      THPObjectPtr new_args { unflatten(new_inputs, input_info.desc) };

      // Call back into Python function
      auto out = PyObject_CallObject(fn_.function_.get(), new_args.get());
      if (!out) throw py::error_already_set();

      // Flatten outputs and update fields
      auto out_info = flatten(out);
      if (out_desc_.structure.empty()) {
        out_desc_ = std::move(out_info.desc);
      } else {
        // TODO: assert matches but only in debug mode
      }

      // Finish tracing and save the current trace
      tracer::exit(out_info.vars);
      traces_.emplace_back(std::move(trace));
      return out;
    }

    CompiledFunction& fn_;
    IODescriptor out_desc_;
    std::vector<std::shared_ptr<TracingState>> traces_;
    bool grad_enabled_ = false;
    bool is_ready_ = false;

    std::shared_ptr<InterpreterFunctionFactory> factory_;
    std::shared_ptr<jit::Graph> graph_;
  };

  TraceForKey& getTrace(ParsedArgs& args) {
    auto it = ktraces_.find(args.desc);
    if (it == ktraces_.end()) {
      bool grad_enabled = args.desc.grad_enabled;
      std::tie(it, std::ignore) = ktraces_.emplace(args.desc,
                                                   TraceForKey(*this, grad_enabled));
    }
    return it->second;
  }

  ParsedArgs flattenArgs(py::handle pyargs) {
    auto args = flatten(pyargs);
    // We need to take captured_var types into account when choosing the trace
    args.extend(captured_vars_);
    return args;
  }

  py::object fallback(py::handle pyargs) {
    return steal(PyObject_CallObject(function_.get(), pyargs.ptr()));
  }

  py::object call(py::handle pyargs) {
    if (!enabled_) {
      return fallback(pyargs);
    }
    auto args = flattenArgs(pyargs);

    if(isTracingVar(args.vars)) {
      // Some outer compiled function has called another compiled function.
      // In this case we just fall back to the original python function,
      // allowing the inner trace to be inlined into the outer.
      // This scenario occurs when blocking an lstm loop.
      return fallback(pyargs);
    }

    auto& ktrace = getTrace(args);

    variable_list out_vars;
    if (ktrace.ready()) {
      hits_++;
      return steal(unflatten(ktrace.run(std::move(args.vars)), ktrace.out_desc_));
    } else {
      misses_++;
      return steal(ktrace.add_trace(pyargs.ptr(), std::move(args)));
    }
  }

  void clearCache() {
    ktraces_.clear();
  }

  CompiledFunction(int nderivs, bool optimize, bool enabled, py::object function,
                   std::string name)
    : nderivs_(nderivs)
    , optimize_(optimize)
    , enabled_(enabled)
    , hits_(0)
    , misses_(0)
    , function_(function.release().ptr())
    , name_(std::move(name))
    , captured_vars_()
    , ktraces_() {}

  int nderivs_;
  bool optimize_;
  bool enabled_;
  std::atomic<uint64_t> hits_;
  std::atomic<uint64_t> misses_;
  THPObjectPtr function_;
  std::string name_;
  variable_list captured_vars_;
  std::unordered_map<IODescriptor, TraceForKey, torch::hash<IODescriptor>> ktraces_;
};

std::ostream& operator<<(std::ostream& out, const CompiledFunction::TraceForKey & trace) {
  if(!const_cast<CompiledFunction::TraceForKey&>(trace).ready()) {
    out << "<trace has been started but has not been completed>";
    return out;
  }
  out << *trace.graph_ << "\n";
  return out;
}

std::ostream& operator<<(std::ostream& out, const CompiledFunction & cf) {
  out << "CompiledFunction: " << cf.name_ << "(nderivs=" << cf.nderivs_ << ", optimized=" << cf.optimize_ << ", enabled=" << cf.enabled_ << "):\n";
  out << "trace cache hits: " << cf.hits_ << "\n";
  out << "trace cache misses: " << cf.misses_ << "\n";
  std::vector<std::string> trace_info;
  for(auto & v : cf.ktraces_) {
    std::stringstream ss;
    ss << v.first << v.second << "\n\n";
    trace_info.push_back(ss.str());
  }
  // ktraces_ is an unordered map, so sort the entries to make the output deterministic;
  // the IODescriptors differ early on, so the comparison won't read most of each string.
  std::sort(trace_info.begin(), trace_info.end());
  out << trace_info.size() << " traces found.\n";

  for(size_t i = 0; i < trace_info.size(); ++i) {
    out << "Trace " << i << " for input with layout " << trace_info[i];
  }
  return out;
}

namespace {

CompiledFunction::TraceForKey* getTraceFor(CompiledFunction& fn,
                                           py::handle pyargs) {
  auto args = fn.flattenArgs(pyargs);
  auto it = fn.ktraces_.find(args.desc);
  if (it == fn.ktraces_.end())
    return nullptr;
  return it->second.ready() ? &it->second : nullptr;
}

} // anonymous namespace

static py::tuple tuple_tail(const py::tuple & tup) {
  py::tuple r(tup.size() - 1);
  for(std::size_t i = 1; i < tup.size(); i++) {
    r[i-1] = tup[i];
  }
  return r;
}

void initCompilerMixin(PyObject *module) {
  auto m = py::handle(module).cast<py::module>();
  py::class_<CompiledFunction>(m, "CompiledFunction", py::dynamic_attr())
    .def(py::init<int, bool, bool, py::object, std::string>())
    .def("__call__", [](py::args args_) -> py::object {
      auto fn = py::cast<CompiledFunction*>(args_[0]);
      auto args = tuple_tail(args_);
      return fn->call(args);
    })
    .def("has_trace_for", [](py::args args_) -> bool {
      auto fn = py::cast<CompiledFunction*>(args_[0]);
      auto args = tuple_tail(args_);
      return getTraceFor(*fn, args) != nullptr;
    })
    .def("graph_for", [](py::args args_) -> py::object {
      auto fn = py::cast<CompiledFunction*>(args_[0]);
      auto args = tuple_tail(args_);
      auto trace = getTraceFor(*fn, args);
      return trace ? py::cast(trace->graph_) : py::none();
    })
    .def("clear_cache", [](CompiledFunction& fn) {
      fn.clearCache();
    })
    .def("set_captured_vars", [](CompiledFunction& fn, variable_list vars) {
      fn.captured_vars_ = std::move(vars);
    })
    .def("jit_debug_info", [](const CompiledFunction& s) -> std::string {
      std::ostringstream ss;
      ss << s;
      return ss.str();
    })
    .def_property_readonly("hits", [](CompiledFunction& fn) {
      return fn.hits_.load();
    })
    .def_property_readonly("misses", [](CompiledFunction& fn) {
      return fn.misses_.load();
    })
    .def_readwrite("enabled", &CompiledFunction::enabled_);
}
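
// Usage sketch (illustrative, added for exposition; not from the original source).
// On the module passed to initCompilerMixin, Python code can drive these bindings
// roughly as follows:
//   cf = CompiledFunction(nderivs, optimize, enabled, fn, name)
//   out = cf(*args)          # run the compiled trace, or trace fn on a cache miss
//   cf.has_trace_for(*args)  # True once a compiled trace exists for this input layout
//   cf.graph_for(*args)      # the optimized jit::Graph for that layout, or None
//   cf.hits, cf.misses       # trace-cache statistics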

}}} // namespace torch::jit::python