Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 00:21:07 +01:00)
Summary: Need to bring in all of the `Function::call` signatures (https://www.internalfb.com/code/fbsource/[36035b9e4e41813e215ffd5f4377d65b7259237e]/fbcode/caffe2/aten/src/ATen/core/function.h?lines=91-101), so that `GraphFunction` no longer only partially overrides the overloaded virtual.

Test Plan:
```
Action Failed for fbcode//accelerators/pytorch/lib/cuda:ngram_repeat_block_cuda (ovr_config//platform/linux:x86_64-fbcode-platform010-clang-6dbc4bb1b9a32829)#5: cxx_compile ngram_repeat_block_cuda_kernel.cu (pic) failed with non-zero exit code 1
debug information: action_digest=988629a726bc4eabcaf334db2317a969958d5fd2:94
stdout:
stderr:
fbcode/caffe2/torch/csrc/jit/api/function_impl.h(11): warning: overloaded virtual function "torch::jit::Function::call" is only partially overridden in class "torch::jit::GraphFunction"
fbcode/caffe2/torch/csrc/jit/api/function_impl.h(11): warning: overloaded virtual function "torch::jit::Function::call" is only partially overridden in class "torch::jit::GraphFunction"
fbcode/caffe2/torch/csrc/jit/frontend/tree_views.h: In instantiation of 'static torch::jit::Maybe<T> torch::jit::Maybe<T>::create(const torch::jit::SourceRange&, const T&) [with T = torch::jit::List<torch::jit::Property>]':
fbcode/caffe2/torch/csrc/jit/frontend/tree_views.h:505:117:   required from here
fbcode/caffe2/torch/csrc/jit/frontend/tree_views.h:220:33: error: cannot convert 'const torch::jit::List<torch::jit::Property>' to 'torch::jit::TreeList&&' {aka 'c10::SmallVector<c10::intrusive_ptr<torch::jit::Tree>, 4>&&'}
  220 |     return Maybe<T>(Compound::create(TK_OPTION, range, {value}));
      |            ~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~
fbcode/caffe2/torch/csrc/jit/frontend/tree.h:144:1: note:   initializing argument 3 of 'static torch::jit::TreeRef torch::jit::Compound::create(int, const torch::jit::SourceRange&, torch::jit::TreeList&&)'
  143 |       const SourceRange& range_,
      |       ~~~~~~~~~~~~~~~~~~~~~~~~
  144 |       TreeList&& trees_) {
      |       ^
fbcode/caffe2/torch/csrc/jit/frontend/tree_views.h: In instantiation of 'static torch::jit::Maybe<T> torch::jit::Maybe<T>::create(const torch::jit::SourceRange&, const T&) [with T = torch::jit::List<torch::jit::Assign>]':
fbcode/caffe2/torch/csrc/jit/frontend/tree_views.h:505:171:   required from here
fbcode/caffe2/torch/csrc/jit/frontend/tree_views.h:220:33: error: cannot convert 'const torch::jit::List<torch::jit::Assign>' to 'torch::jit::TreeList&&' {aka 'c10::SmallVector<c10::intrusive_ptr<torch::jit::Tree>, 4>&&'}
  220 |     return Maybe<T>(Compound::create(TK_OPTION, range, {value}));
      |            ~~~~~~~~~~~~~~~~~^~~~~~~~~~~~~~~~~~~~~~~~~~
fbcode/caffe2/torch/csrc/jit/frontend/tree.h:144:1: note:   initializing argument 3 of 'static torch::jit::TreeRef torch::jit::Compound::create(int, const torch::jit::SourceRange&, torch::jit::TreeList&&)'
  143 |       const SourceRange& range_,
      |       ~~~~~~~~~~~~~~~~~~~~~~~~
  144 |       TreeList&& trees_) {
      |       ^
cc1plus: note: unrecognized command-line option '-Wno-ignored-optimization-argument' may have been intended to silence earlier diagnostics
cc1plus: note: unrecognized command-line option '-Wno-ambiguous-reversed-operator' may have been intended to silence earlier diagnostics
cc1plus: note: unrecognized command-line option '-Wno-ignored-optimization-argument' may have been intended to silence earlier diagnostics
cc1plus: note: unrecognized command-line option '-Wno-ambiguous-reversed-operator' may have been intended to silence earlier diagnostics
command: buck-out/v2/gen/fbcode/999b02f9444004c1/tools/build/__wrap_nvcc.py__/wrap_nvcc.py -_NVCC_BIN_ fbcode ...<omitted>... ors/pytorch/lib/cuda/__ngram_repeat_block_cuda__/__objects__/ngram_repeat_block_cuda_kernel.cu.pic.o (rerun with -v to view the untruncated command)
```

Differential Revision: D33579670

fbshipit-source-id: 9acb443732feb3e921ce0fa5f38f21ed44f64114
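The partial-override warnings in the log above arise because `GraphFunction` overrides only one signature of the overloaded virtual `Function::call`; the fix in this diff is the `using Function::call;` declaration visible in the header below. Here is a minimal, self-contained sketch of the pattern; the `Base`/`Derived`/`Stack` names are illustrative stand-ins, not the real `torch::jit` classes:

```cpp
#include <cstddef>

struct Stack; // opaque stand-in for torch::jit::Stack

struct Base {
  virtual ~Base() = default;
  // An overload set on a virtual function:
  virtual bool call(Stack& /*stack*/) { return false; }
  virtual bool call(Stack& /*stack*/, size_t /*bailOut*/) { return false; }
};

struct Derived : Base {
  // Without this using-declaration, the single override below would hide
  // Base::call(Stack&) inside Derived, and compilers such as nvcc warn that
  // the overloaded virtual is "only partially overridden".
  using Base::call;
  bool call(Stack& /*stack*/, size_t /*bailOut*/) override { return true; }
};
```

The using-declaration re-exposes the base-class overloads that the derived override would otherwise hide, silencing the warning without having to override every signature.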
169 lines · 4.8 KiB · C++
#pragma once

#include <ATen/core/function.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/utils/memory.h>

namespace torch {
namespace jit {

struct TORCH_API GraphFunction : public Function {
  // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
  GraphFunction(
      c10::QualifiedName name,
      std::shared_ptr<Graph> graph,
      std::function<void(GraphFunction&)> function_creator)
      : name_(std::move(name)),
        graph_(std::move(graph)),
        function_creator_(std::move(function_creator)) {}

  bool isGraphFunction() const override {
    return true;
  }

  void run(Stack& stack) override;

  std::function<void(GraphFunction&)> function_creator() const {
    return function_creator_;
  }

  c10::intrusive_ptr<c10::ivalue::Future> runAsync(
      Stack& stack,
      TaskLauncher taskLauncher = at::launch) override;

  std::shared_ptr<Graph> graph() const {
    return graph_;
  }

  std::shared_ptr<Graph> optimized_graph() const {
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    auto& optimized_graph = optimized_graphs_[currentSpecialization()];
    if (optimized_graph) {
      return *optimized_graph;
    }
    optimized_graph = graph_->copy();
    if (getGraphExecutorOptimize()) {
      preoptimizeGraph(*optimized_graph);
    }
    return *optimized_graph;
  }

  const c10::QualifiedName& qualname() const override {
    return name_;
  }

  // if this isn't yet defined, run its method_creator function
  void ensure_defined() override;

  size_t num_inputs() const override {
    return graph()->inputs().size();
  }

  Function& setSchema(FunctionSchema schema) override {
    schema_ = make_unique<FunctionSchema>(std::move(schema));
    return *this;
  }

  const FunctionSchema& getSchema() const override;

  GraphExecutorState getDebugState() {
    return get_executor().getDebugState();
  }

  bool is_optimized() const {
    TORCH_WARN(
        "GraphFunction::is_optimized() is deprecated and always returns true. "
        "Please use getGraphExecutorOptimize()");
    return true;
  }

  void check_single_output() {
    TORCH_CHECK(
        graph()->outputs().size() == 1,
        "Method (but not graphs in general) require a single output. Use None/Tuple for 0 or 2+ outputs");
  }

  GraphExecutor& get_executor() {
    ensure_defined();
    std::lock_guard<std::recursive_mutex> lock(compile_mutex);
    auto& executor = executors_[currentSpecialization()];
    if (executor) {
      return *executor;
    }
    check_single_output();
    executor = GraphExecutor(optimized_graph(), name_.name());
    return *executor;
  }

  using Function::call;
  bool call(
      Stack& stack,
      size_t bailOut,
      c10::function_ref<void(const Code&)> f) override {
    f(get_executor().getPlanFor(stack, bailOut).code);
    return true;
  }

  void clear_optimized_graphs() {
    optimized_graphs_.fill(c10::nullopt);
  }

 private:
  enum SpecializationKey {
    AutocastOff,
    CpuAutocastOn,
    GpuAutocastOn,
    CpuGpuAutocastOn,

    // This provides the number of specializations
    // (Must be last entry)
    TotalCount
  };

  SpecializationKey currentSpecialization() const;

 private:
  c10::QualifiedName name_;
  // The original, non-optimized graph
  std::shared_ptr<Graph> graph_; // for debugging and for inlining

  // Optimized graph, computed lazily. Used for inlining.
  mutable std::array<
      c10::optional<std::shared_ptr<Graph>>,
      SpecializationKey::TotalCount>
      optimized_graphs_;

  // GraphFunctions are invokable from multiple threads, so this lock needs to
  // be held when we're initializing graph executor for the first time or
  // computing the optimized graph. We're using reentrant mutex so that we
  // don't need to worry about causing a deadlock by calling one method from
  // another (e.g. optimized_graph() from get_executor()).
  mutable std::recursive_mutex compile_mutex;
  // executors_[0] - autocast off
  // executors_[1] - autocast cpu on
  // executors_[2] - autocast gpu on
  // executors_[3] - autocast cpu & gpu on
  std::array<c10::optional<GraphExecutor>, SpecializationKey::TotalCount>
      executors_;

  // an optional function that actually creates the method when
  // ensure_defined() is called. This is used by the compiler so
  // that it can construct methods out of order
  std::function<void(GraphFunction&)> function_creator_;

  // if absent, then we generate a default schema based on the graph
  // mutable because getSchema caches the default schema if one is requested
  // before a call to setSchema
  mutable std::unique_ptr<FunctionSchema> schema_;
};

// Short hands for dynamic_cast<GraphFunction*>.
TORCH_API GraphFunction* tryToGraphFunction(Function&) noexcept;
TORCH_API GraphFunction& toGraphFunction(Function&);
TORCH_API const GraphFunction& toGraphFunction(const Function&);

} // namespace jit
} // namespace torch
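For illustration, a rough sketch of how this API might be exercised. It leans on `torch::jit::compile` and `CompilationUnit::get_function` from elsewhere in the codebase (assumptions on my part, not shown in this header), so treat it as a sketch under those assumptions rather than a vetted snippet:

```cpp
#include <torch/csrc/jit/api/function_impl.h>
#include <torch/jit.h> // assumed: torch::jit::compile -> CompilationUnit

#include <iostream>

int main() {
  // Compile a one-function TorchScript snippet (entry point assumed).
  auto cu = torch::jit::compile(R"JIT(
def add_one(x: int) -> int:
    return x + 1
)JIT");

  // Fetch the compiled function and downcast with the helper declared above.
  torch::jit::Function& fn = cu->get_function(c10::QualifiedName("add_one"));
  auto& graph_fn = torch::jit::toGraphFunction(fn);

  // optimized_graph() lazily copies and pre-optimizes the graph, keyed on
  // the current autocast specialization.
  std::cout << *graph_fn.optimized_graph();

  // run() executes through the lazily created GraphExecutor; inputs go on
  // the stack, outputs come back on it.
  torch::jit::Stack stack{c10::IValue(static_cast<int64_t>(41))};
  graph_fn.run(stack);
  std::cout << stack.back().toInt() << "\n"; // expected: 42
  return 0;
}
```

Note the thread-safety design visible in the header: both `optimized_graph()` and `get_executor()` take the same reentrant `compile_mutex`, so either can safely call the other from concurrent threads.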