mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-08 07:39:33 +01:00
Summary: Since caffe2 and torch have been consolidated, CAFFE2_API should be merged with TORCH_API. Addresses a TODO. Manually edited some references of the removed `CAFFE2_API`: * `CONTRIBUTING.md` * `caffe2/proto/CMakeLists.txt` * `cmake/ProtoBuf.cmake` * `c10/macros/Export.h` * `torch/csrc/WindowsTorchApiMacro.h` Pull Request resolved: https://github.com/pytorch/pytorch/pull/49496 Reviewed By: malfet, samestep Differential Revision: D25600726 Pulled By: janeyx99 fbshipit-source-id: 7e068d959e397ac183c097d7e9a9afeca5ddd782
94 lines
2.9 KiB
C++
94 lines
2.9 KiB
C++
#pragma once
|
|
|
|
#include "caffe2/opt/backend_transformer_base.h"
|
|
|
|
#include <unordered_set>
|
|
|
|
namespace caffe2 {
|
|
|
|
struct TvmTransformOptions final : public BackendTransformOptions {
|
|
explicit TvmTransformOptions() : BackendTransformOptions() {}
|
|
|
|
// Whether to enable profiling based jit
|
|
bool profiling_based_jit{false};
|
|
};
|
|
|
|
class TORCH_API TvmTransformer final : public BackendTransformerBase {
|
|
public:
|
|
explicit TvmTransformer(const TvmTransformOptions& opts)
|
|
: BackendTransformerBase(), opts_(opts) {}
|
|
|
|
~TvmTransformer() override {}
|
|
|
|
// Given workspace and predict net, cluster continuous parts that can be run
|
|
// by TVM and create one TVMJit op for each clustered subgraph.
|
|
// \param ws c2 workspace
|
|
// \param pred_net c2 predict net
|
|
// \param weight_names list of the names of the constant weights
|
|
// \param shape_hints User provided shape info, usually for primary inputs so
|
|
// that bound shape inference can have something to start
|
|
// \param blocklisted_ops a set of ops that we don't want to lower to TVM in
|
|
// terms of their net positions. This is very useful for debugging but for
|
|
// normal runs it should be empty
|
|
void transform(
|
|
Workspace* ws,
|
|
NetDef* pred_net,
|
|
const std::vector<std::string>& weight_names,
|
|
const ShapeInfoMap& shape_hints,
|
|
const std::unordered_set<int>& blocklisted_ops) override;
|
|
|
|
static const std::unordered_set<std::string>& getSupportedOps();
|
|
|
|
static bool canConvertFullGraph(
|
|
const caffe2::NetDef& net,
|
|
const std::unordered_set<int>& blocklisted_ops);
|
|
|
|
private:
|
|
// Given TVM runnable subnets, contract them into one TVMJitOp
|
|
NetDef buildTvmOp(
|
|
const caffe2::NetDef& net,
|
|
const std::unordered_set<std::string>& weights,
|
|
const ShapeInfoMap& shape_hints);
|
|
|
|
// Apply transform to cluster connected TVM runnable ops into one TVMJitOp
|
|
NetDef applyTvmTransform(
|
|
NetDef* pred_net,
|
|
const std::unordered_set<std::string>& weights,
|
|
const std::unordered_set<int>& blocklisted_ops,
|
|
const ShapeInfoMap& shape_hints);
|
|
|
|
// Options
|
|
TvmTransformOptions opts_;
|
|
|
|
// Track number of TVMJitOp we created
|
|
int tvm_op_id_{0};
|
|
|
|
// Model id
|
|
std::string model_id_;
|
|
};
|
|
|
|
// Helper function to clean up a net and run tvm transform.
|
|
TORCH_API void tvmTransform(
|
|
NetDef* net,
|
|
Workspace* ws,
|
|
const std::vector<std::string>& input_names,
|
|
const std::vector<std::string>& output_names,
|
|
const std::vector<std::string>& weight_names,
|
|
const ShapeInfoMap& shape_hints,
|
|
const std::unordered_set<int>& blocklisted_ops,
|
|
int32_t max_batch_size,
|
|
int32_t max_seq_size,
|
|
int32_t num_embeddings,
|
|
int32_t embedding_size,
|
|
int32_t tvm_min_ops,
|
|
bool tvm_profiling_based_jit,
|
|
bool debug);
|
|
|
|
TORCH_API void cleanUpPredictNet(
|
|
NetDef* net,
|
|
const std::vector<std::string>& input_names,
|
|
const std::vector<std::string>& output_names,
|
|
const std::vector<std::string>& weight_names);
|
|
|
|
} // namespace caffe2
|