[TensorExpr] AOT Compiler: support symbolic shape arguments. (#70374)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/70374

Differential Revision:
D33303646
D33303646

Test Plan: Imported from OSS

Reviewed By: navahgar, priyaramani

Pulled By: ZolotukhinM

fbshipit-source-id: da81af93b27e632b34c6f35b0ff3c933cba74c19
(cherry picked from commit 4af5fb18a1)
This commit is contained in:
Mikhail Zolotukhin 2022-02-01 18:25:44 -08:00 committed by PyTorch MergeBot
parent a7cebda955
commit 64668e61b8
4 changed files with 115 additions and 8 deletions

View File

@ -36,19 +36,78 @@ std::vector<int64_t> getConstSizes(const BufPtr b) {
return r;
}
// Construct input-specs vector from the inputs of the original graph
// NOTE(review): this span is a diff rendering with the removed (old) and
// added (new) lines interleaved and the +/- markers stripped; it is not
// valid C++ as shown. The old overload took explicit size/type vectors,
// the new one derives both from the graph's input types.
std::vector<mobile::nnc::InputSpec> toInputSpecs(
// old signature (removed by this commit):
const std::vector<std::vector<int64_t>>& inputSizes,
const std::vector<at::ScalarType>& inputTypes) {
// new signature (added by this commit): specs come from the graph itself.
const std::shared_ptr<Graph>& g) {
std::vector<mobile::nnc::InputSpec> specs;
// old loop header (removed):
for (int i = 0; i < inputSizes.size(); ++i) {
// new loop header (added): one InputSpec per graph input.
for (auto v : g->inputs()) {
const auto& t = v->type();
mobile::nnc::InputSpec spec;
// old body (removed): sizes/dtype copied from the parallel vectors.
spec.sizes_ = inputSizes[i];
spec.dtype_ = inputTypes[i];
// new body (added): only tensor-typed inputs are supported.
TORCH_CHECK(t->kind() == TypeKind::TensorType, "Unsupported input type");
const auto& tt = t->cast<TensorType>();
spec.sizes_ = {};
// sizes().sizes() yields optional dims; dereference assumes the rank is
// known — presumably guaranteed by earlier shape propagation; confirm.
auto sizes_vec = *tt->sizes().sizes();
for (auto s : sizes_vec) {
// Unknown/symbolic dims are recorded as 0 — looks like a sentinel for
// "dynamic"; verify against how InputSpec::validate treats it.
spec.sizes_.push_back(s ? *s : 0);
}
// NOTE(review): dereferences the optional scalar type — will fail if the
// input's dtype is not statically known; confirm callers guarantee it.
spec.dtype_ = *tt->scalarType();
specs.emplace_back(std::move(spec));
}
return specs;
}
// Locate symbolic shapes in shapes of the inputs.
//
// For each symbolic shape the kernel refers to, find an input tensor from
// which its runtime value can be read, and the dimension index within that
// input. For instance, given
//   graph(%x : Float(SS(-1), 10), %y : Long(20, SS(-2)), %ss_1 : int, %ss_2 : int)
// we need the locations of SS(-1) and SS(-2): the first is dimension 0 of
// input 0, the second is dimension 1 of input 1, so the result is
// {{0, 0}, {1, 1}}.
//
// If a symbolic shape cannot be found among dimensions of the inputs we
// throw (this can happen when the symbolic shape is the size of an
// intermediate - that case is not supported here yet).
//
// If a symbolic shape occurs in several positions, the first one found is
// returned (TODO: maybe return all of them and verify at runtime that they
// match).
std::vector<SymbolicShapePosition> findSymbolicShapePositions(
    std::shared_ptr<tensorexpr::TensorExprKernel> kernel) {
  std::vector<SymbolicShapePosition> positions;
  const auto& graph_inputs = kernel->graph()->inputs();
  for (int64_t sym : kernel->getSymbolicShapeInputs()) {
    bool located = false;
    for (size_t in_idx = 0; in_idx < graph_inputs.size() && !located;
         ++in_idx) {
      // Only tensor-typed inputs can carry dimensions.
      auto tt = graph_inputs[in_idx]->type()->cast<TensorType>();
      if (!tt) {
        continue;
      }
      // Copy the optional dims out so we don't hold a reference into a
      // temporary SymbolicShape.
      auto dims_opt = tt->symbolic_sizes().sizes();
      if (!dims_opt) {
        continue;
      }
      const std::vector<at::ShapeSymbol>& dims = *dims_opt;
      for (size_t d = 0; d < dims.size(); ++d) {
        if (dims[d].value() == sym) {
          positions.emplace_back(
              static_cast<int64_t>(in_idx), static_cast<int64_t>(d));
          located = true;
          break;
        }
      }
    }
    TORCH_CHECK(
        located, "Could not locate a symbolic shape among input tensor shapes");
  }
  return positions;
}
std::unique_ptr<Function> compileMethod(
std::shared_ptr<tensorexpr::TensorExprKernel> kernel,
const std::string& method_name,
@ -56,7 +115,7 @@ std::unique_ptr<Function> compileMethod(
const std::vector<at::ScalarType>& types) {
auto func = std::make_unique<Function>();
func->set_name(method_name);
func->set_input_specs(toInputSpecs(sizes, types));
func->set_input_specs(toInputSpecs(kernel->graph()));
auto params = c10::impl::GenericList(c10::AnyType::get());
auto const_descriptors = kernel->getConstantDescriptors();
@ -103,6 +162,7 @@ std::unique_ptr<Function> compileMethod(
out_spec.push_back(output);
}
func->set_output_specs(out_spec);
func->set_sym_shape_positions(findSymbolicShapePositions(kernel));
return func;
}

View File

@ -165,6 +165,15 @@ c10::IValue Function::serialize() const {
// memory_plan_
dict.insert("memory_plan", memory_plan_.serialize());
// sym_shape_positions_
std::vector<c10::IValue> sym_shape_pos_vec;
for (const auto& sym_shape_pos : sym_shape_positions_) {
sym_shape_pos_vec.emplace_back(
Tup({sym_shape_pos.input_idx_, sym_shape_pos.dim_idx_}));
}
dict.insert("sym_shape_pos", Tup(std::move(sym_shape_pos_vec)));
return dict;
}
@ -224,18 +233,32 @@ c10::impl::GenericList Function::run(
// Fill in input tensors.
TORCH_CHECK(
input_specs_.size() == inputs.size(),
input_specs_.size() == (inputs.size() + sym_shape_positions_.size()),
"Input size doesn't match the spec, expect: ",
input_specs_.size(),
" actual: ",
inputs.size());
std::vector<int64_t> scalar_values;
int offset = 0;
for (const auto i : c10::irange(inputs.size())) {
const c10::IValue& input = inputs[i];
const auto& spec = input_specs_[i];
const auto& input_tensor = input.toTensor();
TORCH_CHECK(
input_specs_[i].validate(input_tensor), "Invalid input at pos: ", i);
args[i] = input_tensor.data_ptr();
}
offset += inputs.size();
scalar_values.reserve(sym_shape_positions_.size());
for (const auto i : c10::irange(sym_shape_positions_.size())) {
const auto& sym_shape_pos = sym_shape_positions_[i];
const c10::IValue& input = inputs[sym_shape_pos.input_idx_];
auto dim = input.toTensor().size(sym_shape_pos.dim_idx_);
scalar_values.push_back(dim);
args[i + offset] = &scalar_values[scalar_values.size() - 1];
}
offset += sym_shape_positions_.size();
// Preallocate and fill in output tensors.
c10::List<at::Tensor> outputs;
@ -243,7 +266,7 @@ c10::impl::GenericList Function::run(
for (const auto i : c10::irange(output_specs_.size())) {
at::Tensor output = output_specs_[i].allocate();
outputs.emplace_back(output);
args[inputs.size() + i] = output.data_ptr();
args[i + offset] = output.data_ptr();
}
// TODO: check consistency, e.g.: code version, input shape and compiled

View File

@ -91,6 +91,16 @@ struct TORCH_API MemoryPlan {
std::vector<int64_t> buffer_sizes_;
};
// Location of a symbolic shape among dimensions of the inputs:
// the `dim_idx_`-th dimension of the `input_idx_`-th input tensor is the
// runtime source of this symbolic shape's value.
struct TORCH_API SymbolicShapePosition {
  SymbolicShapePosition() = default;
  SymbolicShapePosition(int64_t input_idx, int64_t dim_idx)
      : input_idx_(input_idx), dim_idx_(dim_idx) {}

  // Default to -1 ("unset") so a default-constructed instance doesn't carry
  // indeterminate values (reading those would be UB); valid positions are
  // always non-negative.
  int64_t input_idx_ = -1;
  int64_t dim_idx_ = -1;
};
// Represents a compiled NNC function which has a 1-1 correspondence with a
// `Method` (e.g. `forward`). It's similar as torch::jit::mobile::Function.
class TORCH_API Function {
@ -158,6 +168,15 @@ class TORCH_API Function {
memory_plan_ = memory_plan;
}
// Read-only access to the recorded symbolic-shape locations (where in the
// input tensors each symbolic dimension's runtime value comes from).
const std::vector<SymbolicShapePosition>& sym_shape_positions() const {
return sym_shape_positions_;
}
// Record the symbolic-shape locations computed at compile time; stored by
// copy so the caller keeps ownership of its vector.
void set_sym_shape_positions(
const std::vector<SymbolicShapePosition>& sym_shape_pos) {
sym_shape_positions_ = sym_shape_pos;
}
private:
void init_execution_state() const;
@ -166,6 +185,7 @@ class TORCH_API Function {
c10::impl::GenericList parameters_{at::AnyType::get()};
std::vector<InputSpec> input_specs_;
std::vector<OutputSpec> output_specs_;
std::vector<SymbolicShapePosition> sym_shape_positions_;
MemoryPlan memory_plan_;
mutable std::unique_ptr<ExecutionState> execution_state_;
};

View File

@ -181,6 +181,10 @@ class TORCH_API TensorExprKernel {
return codegen_->kernel_func_name();
}
// Identifiers of the kernel's symbolic shape inputs — presumably the
// negative SS(-N) values that appear in the input tensor types; confirm
// against where symbolic_shape_inputs_ is populated.
const std::vector<int64_t>& getSymbolicShapeInputs() const {
return symbolic_shape_inputs_;
}
private:
enum BackendType {
kUninitialized,