mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[TensorExpr] AOT Compiler: support symbolic shape arguments. (#70374)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70374
Differential Revision:
D33303646
D33303646
Test Plan: Imported from OSS
Reviewed By: navahgar, priyaramani
Pulled By: ZolotukhinM
fbshipit-source-id: da81af93b27e632b34c6f35b0ff3c933cba74c19
(cherry picked from commit 4af5fb18a1)
This commit is contained in:
parent
a7cebda955
commit
64668e61b8
|
|
@ -36,19 +36,78 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {
|
|||
return r;
|
||||
}
|
||||
|
||||
// Construct input-specs vector from the inputs of the original graph
|
||||
std::vector<mobile::nnc::InputSpec> toInputSpecs(
|
||||
const std::vector<std::vector<int64_t>>& inputSizes,
|
||||
const std::vector<at::ScalarType>& inputTypes) {
|
||||
const std::shared_ptr<Graph>& g) {
|
||||
std::vector<mobile::nnc::InputSpec> specs;
|
||||
for (int i = 0; i < inputSizes.size(); ++i) {
|
||||
for (auto v : g->inputs()) {
|
||||
const auto& t = v->type();
|
||||
mobile::nnc::InputSpec spec;
|
||||
spec.sizes_ = inputSizes[i];
|
||||
spec.dtype_ = inputTypes[i];
|
||||
TORCH_CHECK(t->kind() == TypeKind::TensorType, "Unsupported input type");
|
||||
const auto& tt = t->cast<TensorType>();
|
||||
spec.sizes_ = {};
|
||||
auto sizes_vec = *tt->sizes().sizes();
|
||||
for (auto s : sizes_vec) {
|
||||
spec.sizes_.push_back(s ? *s : 0);
|
||||
}
|
||||
spec.dtype_ = *tt->scalarType();
|
||||
specs.emplace_back(std::move(spec));
|
||||
}
|
||||
return specs;
|
||||
}
|
||||
|
||||
// Locate symbolic shapes in shapes of the inputs.
|
||||
//
|
||||
// For each symbolic shape we're trying to find the input from which it can be
|
||||
// extracted and the dimension index in that input.
|
||||
// For instance, if we have
|
||||
// graph(%x : Float(SS(-1), 10), %y : Long(20, SS(-2), %ss_1 : int, %ss_2 : int)
|
||||
// then we would need to find locations of two symbolic shapes: SS(-1) and
|
||||
// SS(-2). The first one corresponds to the first dimension of the first input,
|
||||
// the second one corresponds to the second dimension of the second input,
|
||||
// so we will return {{0, 0}, {1, 1}}.
|
||||
//
|
||||
// If a symbolic shape cannot be found among dimensions of inputs, we
|
||||
// will throw an error (this situation is possible when symbolic shape
|
||||
// corresponds to the size of an intermediate - we don't support this
|
||||
// case here yet).
|
||||
//
|
||||
// If a symbolic shape can be found in several different positions, we
|
||||
// return the first one we find (TODO: maybe we should return all and
|
||||
// verify that they all match at runtime).
|
||||
std::vector<SymbolicShapePosition> findSymbolicShapePositions(
|
||||
std::shared_ptr<tensorexpr::TensorExprKernel> kernel) {
|
||||
std::vector<SymbolicShapePosition> res;
|
||||
for (int64_t sym_idx : kernel->getSymbolicShapeInputs()) {
|
||||
bool found = false;
|
||||
for (int64_t input_idx : c10::irange(kernel->graph()->inputs().size())) {
|
||||
auto input = kernel->graph()->inputs()[input_idx];
|
||||
|
||||
if (!input->type()->cast<TensorType>()) {
|
||||
continue;
|
||||
}
|
||||
auto tt = input->type()->expect<TensorType>();
|
||||
if (!tt->symbolic_sizes().sizes()) {
|
||||
continue;
|
||||
}
|
||||
std::vector<at::ShapeSymbol> shape_vec = *tt->symbolic_sizes().sizes();
|
||||
for (int64_t dim_idx : c10::irange(shape_vec.size())) {
|
||||
if (shape_vec[dim_idx].value() == sym_idx) {
|
||||
res.emplace_back(input_idx, dim_idx);
|
||||
found = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
if (found) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
TORCH_CHECK(
|
||||
found, "Could not locate a symbolic shape among input tensor shapes");
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
std::unique_ptr<Function> compileMethod(
|
||||
std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
|
||||
const std::string& method_name,
|
||||
|
|
@ -56,7 +115,7 @@ std::unique_ptr<Function> compileMethod(
|
|||
const std::vector<at::ScalarType>& types) {
|
||||
auto func = std::make_unique<Function>();
|
||||
func->set_name(method_name);
|
||||
func->set_input_specs(toInputSpecs(sizes, types));
|
||||
func->set_input_specs(toInputSpecs(kernel->graph()));
|
||||
|
||||
auto params = c10::impl::GenericList(c10::AnyType::get());
|
||||
auto const_descriptors = kernel->getConstantDescriptors();
|
||||
|
|
@ -103,6 +162,7 @@ std::unique_ptr<Function> compileMethod(
|
|||
out_spec.push_back(output);
|
||||
}
|
||||
func->set_output_specs(out_spec);
|
||||
func->set_sym_shape_positions(findSymbolicShapePositions(kernel));
|
||||
|
||||
return func;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -165,6 +165,15 @@ c10::IValue Function::serialize() const {
|
|||
|
||||
// memory_plan_
|
||||
dict.insert("memory_plan", memory_plan_.serialize());
|
||||
|
||||
// sym_shape_positions_
|
||||
std::vector<c10::IValue> sym_shape_pos_vec;
|
||||
for (const auto& sym_shape_pos : sym_shape_positions_) {
|
||||
sym_shape_pos_vec.emplace_back(
|
||||
Tup({sym_shape_pos.input_idx_, sym_shape_pos.dim_idx_}));
|
||||
}
|
||||
dict.insert("sym_shape_pos", Tup(std::move(sym_shape_pos_vec)));
|
||||
|
||||
return dict;
|
||||
}
|
||||
|
||||
|
|
@ -224,18 +233,32 @@ c10::impl::GenericList Function::run(
|
|||
|
||||
// Fill in input tensors.
|
||||
TORCH_CHECK(
|
||||
input_specs_.size() == inputs.size(),
|
||||
input_specs_.size() == (inputs.size() + sym_shape_positions_.size()),
|
||||
"Input size doesn't match the spec, expect: ",
|
||||
input_specs_.size(),
|
||||
" actual: ",
|
||||
inputs.size());
|
||||
std::vector<int64_t> scalar_values;
|
||||
int offset = 0;
|
||||
for (const auto i : c10::irange(inputs.size())) {
|
||||
const c10::IValue& input = inputs[i];
|
||||
const auto& spec = input_specs_[i];
|
||||
const auto& input_tensor = input.toTensor();
|
||||
TORCH_CHECK(
|
||||
input_specs_[i].validate(input_tensor), "Invalid input at pos: ", i);
|
||||
args[i] = input_tensor.data_ptr();
|
||||
}
|
||||
offset += inputs.size();
|
||||
|
||||
scalar_values.reserve(sym_shape_positions_.size());
|
||||
for (const auto i : c10::irange(sym_shape_positions_.size())) {
|
||||
const auto& sym_shape_pos = sym_shape_positions_[i];
|
||||
const c10::IValue& input = inputs[sym_shape_pos.input_idx_];
|
||||
auto dim = input.toTensor().size(sym_shape_pos.dim_idx_);
|
||||
scalar_values.push_back(dim);
|
||||
args[i + offset] = &scalar_values[scalar_values.size() - 1];
|
||||
}
|
||||
offset += sym_shape_positions_.size();
|
||||
|
||||
// Preallocate and fill in output tensors.
|
||||
c10::List<at::Tensor> outputs;
|
||||
|
|
@ -243,7 +266,7 @@ c10::impl::GenericList Function::run(
|
|||
for (const auto i : c10::irange(output_specs_.size())) {
|
||||
at::Tensor output = output_specs_[i].allocate();
|
||||
outputs.emplace_back(output);
|
||||
args[inputs.size() + i] = output.data_ptr();
|
||||
args[i + offset] = output.data_ptr();
|
||||
}
|
||||
|
||||
// TODO: check consistency, e.g.: code version, input shape and compiled
|
||||
|
|
|
|||
|
|
@ -91,6 +91,16 @@ struct TORCH_API MemoryPlan {
|
|||
std::vector<int64_t> buffer_sizes_;
|
||||
};
|
||||
|
||||
// Location of a symbolic shape among dimensions of the inputs
|
||||
struct TORCH_API SymbolicShapePosition {
|
||||
SymbolicShapePosition() = default;
|
||||
SymbolicShapePosition(int64_t input_idx, int64_t dim_idx)
|
||||
: input_idx_(input_idx), dim_idx_(dim_idx) {}
|
||||
|
||||
int64_t input_idx_;
|
||||
int64_t dim_idx_;
|
||||
};
|
||||
|
||||
// Represents a compiled NNC function which has a 1-1 correspondence with a
|
||||
// `Method` (e.g. `forward`). It's similar as torch::jit::mobile::Function.
|
||||
class TORCH_API Function {
|
||||
|
|
@ -158,6 +168,15 @@ class TORCH_API Function {
|
|||
memory_plan_ = memory_plan;
|
||||
}
|
||||
|
||||
const std::vector<SymbolicShapePosition>& sym_shape_positions() const {
|
||||
return sym_shape_positions_;
|
||||
}
|
||||
|
||||
void set_sym_shape_positions(
|
||||
const std::vector<SymbolicShapePosition>& sym_shape_pos) {
|
||||
sym_shape_positions_ = sym_shape_pos;
|
||||
}
|
||||
|
||||
private:
|
||||
void init_execution_state() const;
|
||||
|
||||
|
|
@ -166,6 +185,7 @@ class TORCH_API Function {
|
|||
c10::impl::GenericList parameters_{at::AnyType::get()};
|
||||
std::vector<InputSpec> input_specs_;
|
||||
std::vector<OutputSpec> output_specs_;
|
||||
std::vector<SymbolicShapePosition> sym_shape_positions_;
|
||||
MemoryPlan memory_plan_;
|
||||
mutable std::unique_ptr<ExecutionState> execution_state_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -181,6 +181,10 @@ class TORCH_API TensorExprKernel {
|
|||
return codegen_->kernel_func_name();
|
||||
}
|
||||
|
||||
const std::vector<int64_t>& getSymbolicShapeInputs() const {
|
||||
return symbolic_shape_inputs_;
|
||||
}
|
||||
|
||||
private:
|
||||
enum BackendType {
|
||||
kUninitialized,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user