mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary:
This is a first step towards adding exceptions. We need minimal support in order to begin converting the torch library to weak script mode (which is the main goal here).
Some limitations (that are documented in the tests & compiler):
1. Cannot assign exceptions to variables
2. Any name after raise is being treated as a valid Exception
3. No control flow analysis yet. Below a will be undefined:
if True:
a = 1
else:
raise Exception("Hi")
return a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12789
Differential Revision: D12848936
Pulled By: eellison
fbshipit-source-id: 1f60ceef2381040486123ec797e97d65b074862d
1366 lines
58 KiB
C++
1366 lines
58 KiB
C++
#include "torch/csrc/jit/passes/shape_analysis.h"
|
|
|
|
#include "torch/csrc/jit/ir.h"
|
|
#include "torch/csrc/jit/constants.h"
|
|
#include "torch/csrc/jit/argument_spec.h"
|
|
#include "torch/csrc/jit/operator.h"
|
|
#include "torch/csrc/jit/assertions.h"
|
|
|
|
#include "torch/csrc/autograd/variable.h"
|
|
|
|
#include <ATen/DeviceGuard.h>
|
|
#include <ATen/ExpandUtils.h>
|
|
|
|
#include <exception>
|
|
#include <iostream>
|
|
#include <memory>
|
|
#include <utility>
|
|
#include <vector>
|
|
|
|
namespace torch { namespace jit {
|
|
|
|
struct propagation_error : std::exception {};
|
|
|
|
#define SHAPE_ASSERT(cond) if (!(cond)) throw propagation_error()
|
|
|
|
namespace {
|
|
|
|
void setUnshapedType(Node * node) {
|
|
for(auto o : node->outputs()) {
|
|
o->setType(unshapedType(o->type()));
|
|
}
|
|
}
|
|
|
|
int64_t wrapDim(int64_t dim, at::IntList sizes) {
|
|
if (dim < 0) {
|
|
dim += sizes.size();
|
|
}
|
|
return dim;
|
|
}
|
|
|
|
IValue representativeValue(Value* v) {
|
|
TypePtr type_ = v->type();
|
|
// if the value is actually constant, just use it!
|
|
if(auto iv = toIValue(v)) {
|
|
return *iv;
|
|
}
|
|
if (CompleteTensorTypePtr type = type_->cast<CompleteTensorType>()) {
|
|
auto backend = type->device() == -1 ? at::Backend::CPU : at::Backend::CUDA;
|
|
at::DeviceGuard device_guard(type->device());
|
|
auto& attype = at::getNonVariableType(backend, type->scalarType());
|
|
auto t = at::empty_strided(type->sizes(), type->strides(), attype.options()).zero_();
|
|
return autograd::make_variable(t, /*requires_grad=*/false);
|
|
} else if (type_->isSubtypeOf(FloatType::get())) {
|
|
return 0.f;
|
|
}
|
|
// we should not get here because isValidArgumentForRunning should have
|
|
// prevented it
|
|
std::stringstream ss;
|
|
ss << "unable to create representative value for: " << type_->str()
|
|
<< ". File a bug report.";
|
|
throw std::runtime_error(ss.str());
|
|
}
|
|
|
|
void PropagateShapeOnBlock(Block * block, bool insert_expands=true);
|
|
|
|
// for each node in the schema with type Tensor, extract the T type
|
|
// returns c10::nullopt if any Tensor in the schema does not have a known shape
|
|
// ignores non-tensor in the list of inputs
|
|
template <typename T>
|
|
c10::optional<std::vector<std::shared_ptr<T>>> gatherTensorTypes(Node* node) {
|
|
std::vector<std::shared_ptr<T>> tensor_types;
|
|
|
|
auto & schema = node->schema();
|
|
auto & args = schema.arguments();
|
|
// can't handle varargs primitives because we don't know what should be a Tensor
|
|
if (schema.is_vararg()) {
|
|
return c10::nullopt;
|
|
}
|
|
for (size_t i = 0; i < args.size(); ++i) {
|
|
if (args[i].type()->isSubtypeOf(ListType::ofTensors())) {
|
|
return c10::nullopt;
|
|
} else if (args[i].type()->isSubtypeOf(DynamicType::get())) {
|
|
if (auto type = node->input(i)->type()->cast<T>()) {
|
|
tensor_types.push_back(type);
|
|
} else {
|
|
return c10::nullopt;
|
|
}
|
|
} else /* non-tensor type */ {
|
|
continue;
|
|
}
|
|
}
|
|
|
|
return tensor_types;
|
|
}
|
|
|
|
bool mergeTypes(ArrayRef<Value*> lhs, ArrayRef<Value*> rhs, ArrayRef<Value*> outputs) {
|
|
JIT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
|
|
bool changed = false;
|
|
for(size_t i = 0; i < lhs.size(); ++i) {
|
|
auto old_output_type = outputs[i]->type();
|
|
auto new_type = unifyTypes(lhs[i]->type(), rhs[i]->type());
|
|
JIT_ASSERT(new_type);
|
|
outputs[i]->setType(*new_type);
|
|
if(*old_output_type != *outputs[i]->type())
|
|
changed = true;
|
|
}
|
|
return changed;
|
|
}
|
|
|
|
void PropagateShapeOnNode(Node * node, bool insert_expands=true);
|
|
|
|
void broadcastBinary(Node *node, std::vector<CompleteTensorTypePtr>& types, size_t idx1, size_t idx2) {
|
|
auto expected_size = at::infer_size(types[idx1]->sizes(), types[idx2]->sizes());
|
|
auto broadcast = [&](size_t input_idx) {
|
|
CompleteTensorTypePtr input_type = types.at(input_idx);
|
|
if (input_type->sizes() == expected_size)
|
|
return;
|
|
auto graph = node->owningGraph();
|
|
WithInsertPoint point_guard { node };
|
|
Node *expand = graph->create(aten::expand,
|
|
{node->inputs().at(input_idx),
|
|
graph->insertConstant(expected_size),
|
|
graph->insertConstant(false)})
|
|
->insertBefore(node);
|
|
PropagateShapeOnNode(expand);
|
|
node->replaceInput(input_idx, expand->output());
|
|
};
|
|
broadcast(idx1);
|
|
broadcast(idx2);
|
|
types[0] = node->inputs().at(idx1)->type()->expect<CompleteTensorType>();
|
|
types[1] = node->inputs().at(idx2)->type()->expect<CompleteTensorType>();
|
|
}
|
|
|
|
bool isValidArgumentForRunning(Value* v) {
|
|
// allow constants
|
|
if(toIValue(v))
|
|
return true;
|
|
if(CompleteTensorTypePtr tt = v->type()->cast<CompleteTensorType>()) {
|
|
return !at::isIntegralType(tt->scalarType());
|
|
}
|
|
return v->type()->isSubtypeOf(FloatType::get());
|
|
}
|
|
bool isValidReturnForRunning(Value* v) {
|
|
return v->type()->isSubtypeOf(DynamicType::get()) ||
|
|
v->type()->isSubtypeOf(NumberType::get());
|
|
}
|
|
|
|
OperatorSet cannot_propagate_shape_by_running_it = {
|
|
"aten::gesv(Tensor self, Tensor A) -> (Tensor, Tensor)",
|
|
"aten::inverse(Tensor self) -> Tensor",
|
|
};
|
|
|
|
bool canPropagateShapeByRunningIt(Node* node) {
|
|
if(cannot_propagate_shape_by_running_it.find(node)) {
|
|
return false;
|
|
}
|
|
|
|
bool valid_args = std::all_of(
|
|
node->inputs().begin(), node->inputs().end(), isValidArgumentForRunning);
|
|
if (!valid_args)
|
|
return false;
|
|
|
|
bool valid_returns = std::all_of(
|
|
node->outputs().begin(), node->outputs().end(), isValidReturnForRunning);
|
|
if (!valid_returns)
|
|
return false;
|
|
|
|
return true;
|
|
}
|
|
|
|
bool PropagateShapeOnNodeByRunningIt(Node* node) {
|
|
if (!canPropagateShapeByRunningIt(node))
|
|
return false;
|
|
auto op = getOperation(node);
|
|
Stack stack;
|
|
|
|
for (auto input : node->inputs()) {
|
|
stack.push_back(representativeValue(input));
|
|
}
|
|
|
|
// XXX: we're not catching any exceptions from the op for now. This
|
|
// is to uncover any mistakes we could make when editing this code,
|
|
// and eventually it shouldn't matter, because this phase should be
|
|
// preceded by schema checking.
|
|
op(stack);
|
|
|
|
JIT_ASSERT(stack.size() == node->outputs().size());
|
|
for (size_t i = 0; i < stack.size(); ++i) {
|
|
// some ops may have mixed tensor/primitive outputs
|
|
// for primitives, we don't need to change the type because it is already
|
|
// its most constrained form.
|
|
if(stack[i].isTensor())
|
|
node->outputs()[i]->inferTypeFrom(stack[i].toTensor());
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// is it ok to try to run the op
|
|
// If an input is a constant, then we assume that the input is valid
|
|
// and we can try to run it.
|
|
// Otherwise:
|
|
// Integral typed _inputs_ are often an indicator that we're indexing into
|
|
// a tensor, so we should special-case these ops in the shape propagation.
|
|
// Additionally, passing in a zero representative tensor into an integer
|
|
// division op causes divide-by-zero errors
|
|
// _Outputs_ must be tensors or primtives
|
|
// We will call inferTypeFrom on the tensors, and ignore the primitives.
|
|
// However, we allow primitive returns because we want to support mixed
|
|
// primitive/tensor outputs.
|
|
|
|
bool PropagateTensorShapeOnNode(Node * node, bool insert_expands);
|
|
bool PropagateCompleteShapeOnNode(
|
|
Node * node, bool insert_expands, std::vector<CompleteTensorTypePtr> types);
|
|
|
|
void PropagateCatShape(Node * cat_node) {
|
|
static const auto propagate_complete = [](Node * node, at::ArrayRef<Value*> tensors) -> bool {
|
|
auto input_types = fmap(tensors, [](Value *v) { return v->type()->cast<CompleteTensorType>(); });
|
|
if (!std::all_of(input_types.begin(), input_types.end(),
|
|
[](const CompleteTensorTypePtr& tp) { return tp != nullptr; })) {
|
|
return false;
|
|
}
|
|
if (!node->is_constant(attr::dim)) return false;
|
|
std::vector<int64_t> sizes = input_types[0]->sizes();
|
|
const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
|
|
const int64_t ndim = sizes.size();
|
|
|
|
if (dim < 0 || dim >= ndim)
|
|
return false;
|
|
|
|
sizes[dim] = 0;
|
|
for (auto & tp : input_types) {
|
|
auto & tp_sizes = tp->sizes();
|
|
if (sizes.size() != tp_sizes.size())
|
|
return false;
|
|
for (int64_t i = 0; i < ndim; ++i) {
|
|
if (sizes[i] != tp_sizes[i] && i != dim) {
|
|
return false;
|
|
}
|
|
}
|
|
sizes[dim] += tp_sizes[dim];
|
|
}
|
|
node->output()->setType(input_types[0]->withSizes(sizes));
|
|
return true;
|
|
};
|
|
static const auto propagate = [](Node * node, at::ArrayRef<Value*> tensors) -> bool {
|
|
for (Value * v : tensors) {
|
|
if (auto type = v->type()->cast<TensorType>()) {
|
|
node->output()->setType(type);
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
};
|
|
auto list_node = cat_node->namedInput(attr::tensors)->node();
|
|
if (list_node->kind() == prim::ListConstruct) {
|
|
auto tensors = list_node->inputs();
|
|
if (!tensors.empty()) {
|
|
if (propagate_complete(cat_node, tensors)) {
|
|
return;
|
|
} else if (propagate(cat_node, tensors)) {
|
|
return;
|
|
}
|
|
}
|
|
}
|
|
setUnshapedType(cat_node);
|
|
}
|
|
|
|
void PropagateShapeOnNode(Node * node, bool insert_expands) {
|
|
// These don't require the types, and have complicated schema. Return early after we process them.
|
|
switch(node->kind()) {
|
|
case prim::If: {
|
|
auto then_block = node->blocks().at(0);
|
|
auto else_block = node->blocks().at(1);
|
|
PropagateShapeOnBlock(then_block);
|
|
PropagateShapeOnBlock(else_block);
|
|
mergeTypes(then_block->outputs(), else_block->outputs(), node->outputs());
|
|
return;
|
|
}
|
|
case prim::Loop: {
|
|
auto body_block = node->blocks().at(0);
|
|
// propagate counter type
|
|
body_block->inputs().at(0)->setType(node->inputs().at(0)->type());
|
|
// propagate loop-carried input types to block inputs
|
|
auto loop_carried_inputs = node->inputs().slice(2); // skip max, cond
|
|
auto loop_carried_block = body_block->inputs().slice(1); // skip trip
|
|
for(size_t i = 0; i < loop_carried_inputs.size(); ++i) {
|
|
loop_carried_block[i]->setType(loop_carried_inputs[i]->type());
|
|
}
|
|
auto loop_carried_outputs = body_block->outputs().slice(1); // skip cond
|
|
|
|
do {
|
|
PropagateShapeOnBlock(body_block, /*insert_expands=*/false);
|
|
// note: inserting expands is unsafe at this point, we don't know
|
|
// if the types are stable yet, so the arguments to expand may change
|
|
} while(mergeTypes(loop_carried_block, loop_carried_outputs, loop_carried_block));
|
|
|
|
// now that the types are stable, we can insert the expands
|
|
PropagateShapeOnBlock(body_block, /*insert_expands=*/true);
|
|
|
|
|
|
for(size_t i = 0; i < loop_carried_inputs.size(); ++i) {
|
|
node->outputs()[i]->setType(loop_carried_block[i]->type());
|
|
}
|
|
return;
|
|
}
|
|
case prim::ImplicitTensorToNum:
|
|
case prim::TensorToNum:
|
|
return; // correct num type is already set
|
|
case prim::NumToTensor: {
|
|
if (node->input()->type()->isSubtypeOf(IntType::get())) {
|
|
node->output()->setType(TensorType::create(at::kLong, -1, 0));
|
|
} else {
|
|
JIT_ASSERT(node->input()->type()->isSubtypeOf(FloatType::get()));
|
|
node->output()->setType(TensorType::create(at::kDouble, -1, 0));
|
|
}
|
|
return;
|
|
}
|
|
case prim::TupleConstruct: {
|
|
// We refresh the tuple type, because the input types could have been refined.
|
|
node->output()->setType(TupleType::create(fmap(node->inputs(), [](Value *v) { return v->type(); })));
|
|
return;
|
|
}
|
|
case prim::TupleUnpack: {
|
|
auto tuple_type = node->input()->type()->cast<TupleType>();
|
|
JIT_ASSERT(tuple_type && tuple_type->elements().size() == node->outputs().size());
|
|
auto elems = tuple_type->elements();
|
|
for (size_t i = 0; i < node->outputs().size(); ++i) {
|
|
node->output(i)->setType(elems[i]);
|
|
}
|
|
return;
|
|
}
|
|
case prim::Constant: {
|
|
if(node->output()->type()->isSubtypeOf(DynamicType::get())) {
|
|
node->output()->inferTypeFrom(node->t(attr::value));
|
|
}
|
|
return;
|
|
}
|
|
case prim::ConstantChunk: {
|
|
Value *tensor = node->input();
|
|
if (auto type = tensor->type()->cast<TensorType>()) {
|
|
for (Value * output : node->outputs()) {
|
|
output->setType(type);
|
|
}
|
|
} else {
|
|
setUnshapedType(node);
|
|
}
|
|
return;
|
|
}
|
|
case prim::PythonOp:
|
|
case prim::Print:
|
|
case prim::RaiseException:
|
|
case prim::Undefined: {
|
|
setUnshapedType(node);
|
|
return;
|
|
}
|
|
default:
|
|
break; // fall-through
|
|
}
|
|
if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor")) {
|
|
return PropagateCatShape(node);
|
|
}
|
|
|
|
if (auto maybe_complete_types = gatherTensorTypes<CompleteTensorType>(node)) {
|
|
if (PropagateCompleteShapeOnNode(node, insert_expands, std::move(*maybe_complete_types))) {
|
|
return;
|
|
}
|
|
}
|
|
|
|
if (PropagateTensorShapeOnNode(node, insert_expands)) {
|
|
return;
|
|
}
|
|
|
|
if (PropagateShapeOnNodeByRunningIt(node)) {
|
|
return;
|
|
}
|
|
return setUnshapedType(node);
|
|
}
|
|
|
|
static c10::optional<size_t> determineListSize(Value* list) {
|
|
JIT_ASSERT(list->type()->cast<ListType>());
|
|
if (auto shape = constant_as<std::vector<int64_t>>(list)) {
|
|
return shape->size();
|
|
}
|
|
auto input_node = list->node();
|
|
if (input_node->kind() == prim::ListConstruct) {
|
|
return input_node->inputs().size();
|
|
}
|
|
return c10::nullopt;
|
|
}
|
|
|
|
bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) {
|
|
static const auto broadcast = [](std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr {
|
|
if (tensor_types.size() == 1) {
|
|
return tensor_types[0];
|
|
}
|
|
JIT_ASSERT(!tensor_types.empty());
|
|
auto any_type = tensor_types[0];
|
|
auto max_dims = any_type->dim();
|
|
for (auto & type : tensor_types) {
|
|
max_dims = std::max(max_dims, type->dim());
|
|
}
|
|
return TensorType::create(any_type->scalarType(), any_type->device(), max_dims);
|
|
};
|
|
|
|
using type_vec_t = std::vector<TensorTypePtr>;
|
|
// Formula is expected to return a vector of length equal to the number of tensor
|
|
// outputs of the node, or an empty vector which implies that it failed to propagate.
|
|
using formula_t = std::function<type_vec_t(Node *)>;
|
|
static std::mutex shape_formulas_mutex;
|
|
static std::vector<std::pair<OperatorSet, formula_t>> shape_formulas;
|
|
struct register_formula_for {
|
|
register_formula_for(OperatorSet operators, formula_t formula) {
|
|
std::unique_lock<std::mutex> lock {shape_formulas_mutex};
|
|
shape_formulas.emplace_back(std::move(operators), std::move(formula));
|
|
}
|
|
};
|
|
|
|
// Requirements:
|
|
// dims : preserved
|
|
// scalar type : preserved
|
|
// device : preserved
|
|
// tensor inputs : 1
|
|
// tensor outputs : 1
|
|
// Additionally:
|
|
// - First input should be the only tensor input
|
|
static const register_formula_for simple_unary_ops {{
|
|
"aten::abs(Tensor self) -> Tensor",
|
|
"aten::acos(Tensor self) -> Tensor",
|
|
"aten::neg(Tensor self) -> Tensor",
|
|
"aten::t(Tensor self) -> Tensor",
|
|
"aten::sigmoid(Tensor self) -> Tensor",
|
|
"aten::tanh(Tensor self) -> Tensor",
|
|
"aten::exp(Tensor self) -> Tensor",
|
|
"aten::relu(Tensor self) -> Tensor",
|
|
"aten::asin(Tensor self) -> Tensor",
|
|
"aten::atan(Tensor self) -> Tensor",
|
|
"aten::ceil(Tensor self) -> Tensor",
|
|
"aten::clone(Tensor self) -> Tensor",
|
|
"aten::contiguous(Tensor self) -> Tensor",
|
|
"aten::bernoulli(Tensor self, *, Generator generator) -> Tensor",
|
|
"aten::celu(Tensor self, Scalar alpha) -> Tensor",
|
|
"aten::clamp(Tensor self, Scalar? min, Scalar? max) -> Tensor",
|
|
"aten::clamp_max(Tensor self, Scalar max) -> Tensor",
|
|
"aten::clamp_min(Tensor self, Scalar min) -> Tensor",
|
|
"aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor",
|
|
"aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor",
|
|
"aten::cos(Tensor self) -> Tensor",
|
|
"aten::cosh(Tensor self) -> Tensor",
|
|
"aten::digamma(Tensor self) -> Tensor",
|
|
"aten::dropout(Tensor input, float p, bool train) -> Tensor",
|
|
"aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor",
|
|
"aten::erf(Tensor self) -> Tensor",
|
|
"aten::erfc(Tensor self) -> Tensor",
|
|
"aten::erfinv(Tensor self) -> Tensor",
|
|
"aten::exp(Tensor self) -> Tensor",
|
|
"aten::expm1(Tensor self) -> Tensor",
|
|
"aten::log(Tensor self) -> Tensor",
|
|
"aten::log10(Tensor self) -> Tensor",
|
|
"aten::log1p(Tensor self) -> Tensor",
|
|
"aten::log2(Tensor self) -> Tensor",
|
|
"aten::log_sigmoid(Tensor self) -> Tensor",
|
|
"aten::log_softmax(Tensor self, int dim) -> Tensor",
|
|
"aten::floor(Tensor self) -> Tensor",
|
|
"aten::frac(Tensor self) -> Tensor",
|
|
"aten::flip(Tensor self, int[] dims) -> Tensor",
|
|
"aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor",
|
|
"aten::feature_dropout(Tensor input, float p, bool train) -> Tensor",
|
|
"aten::hardshrink(Tensor self, Scalar lambd) -> Tensor",
|
|
"aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor",
|
|
"aten::glu(Tensor self, int dim) -> Tensor",
|
|
"aten::inverse(Tensor self) -> Tensor",
|
|
"aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor",
|
|
"aten::lgamma(Tensor self) -> Tensor",
|
|
"aten::mvlgamma(Tensor self, int p) -> Tensor",
|
|
"aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor",
|
|
"aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor",
|
|
"aten::permute(Tensor self, int[] dims) -> Tensor",
|
|
"aten::pin_memory(Tensor self) -> Tensor",
|
|
"aten::pinverse(Tensor self, float rcond) -> Tensor",
|
|
"aten::reciprocal(Tensor self) -> Tensor",
|
|
"aten::relu(Tensor self) -> Tensor",
|
|
"aten::round(Tensor self) -> Tensor",
|
|
"aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor",
|
|
"aten::rsqrt(Tensor self) -> Tensor",
|
|
"aten::selu(Tensor self) -> Tensor",
|
|
"aten::sigmoid(Tensor self) -> Tensor",
|
|
"aten::sign(Tensor self) -> Tensor",
|
|
"aten::sin(Tensor self) -> Tensor",
|
|
"aten::sinh(Tensor self) -> Tensor",
|
|
"aten::softmax(Tensor self, int dim) -> Tensor",
|
|
"aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor",
|
|
"aten::softshrink(Tensor self, Scalar lambd) -> Tensor",
|
|
"aten::sqrt(Tensor self) -> Tensor",
|
|
"aten::tan(Tensor self) -> Tensor",
|
|
"aten::tanh(Tensor self) -> Tensor",
|
|
"aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor",
|
|
"aten::transpose(Tensor self, int dim0, int dim1) -> Tensor",
|
|
"aten::tril(Tensor self, int diagonal) -> Tensor",
|
|
"aten::triu(Tensor self, int diagonal) -> Tensor",
|
|
"aten::trunc(Tensor self) -> Tensor",
|
|
"aten::rot90(Tensor self, int k, int[] dims) -> Tensor",
|
|
"aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
|
|
"aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor",
|
|
"aten::alias(Tensor self) -> Tensor",
|
|
"aten::detach(Tensor self) -> Tensor",
|
|
"aten::cumprod(Tensor self, int dim) -> Tensor",
|
|
"aten::cumsum(Tensor self, int dim) -> Tensor",
|
|
|
|
"aten::empty_like(Tensor self) -> Tensor",
|
|
"aten::full_like(Tensor self, Scalar fill_value) -> Tensor",
|
|
"aten::ones_like(Tensor self) -> Tensor",
|
|
"aten::rand_like(Tensor self) -> Tensor",
|
|
"aten::randint_like(Tensor self, int high) -> Tensor",
|
|
"aten::randint_like(Tensor self, int low, int high) -> Tensor",
|
|
"aten::randn_like(Tensor self) -> Tensor",
|
|
"aten::zeros_like(Tensor self) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
auto input_type = node->input(0)->type()->cast<TensorType>();
|
|
return input_type ? type_vec_t{input_type} : type_vec_t{};
|
|
}};
|
|
|
|
// Requirements:
|
|
// dims : broadcast all tensor args
|
|
// scalar type : always matching and preserved
|
|
// device : always matching and preserved
|
|
// tensor inputs : *
|
|
// tensor outputs : 1
|
|
static const register_formula_for broadcasting_ops {{
|
|
// Tensor-Tensor operators
|
|
"aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
|
|
"aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor",
|
|
"aten::mul(Tensor self, Tensor other) -> Tensor",
|
|
"aten::div(Tensor self, Tensor other) -> Tensor",
|
|
"aten::pow(Tensor self, Tensor exponent) -> Tensor",
|
|
"aten::min(Tensor self, Tensor other) -> Tensor",
|
|
"aten::max(Tensor self, Tensor other) -> Tensor",
|
|
"aten::fmod(Tensor self, Tensor other) -> Tensor",
|
|
"aten::remainder(Tensor self, Tensor other) -> Tensor",
|
|
"aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor",
|
|
"aten::max(Tensor self, Tensor other) -> Tensor",
|
|
"aten::min(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__and__(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__or__(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__xor__(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__lshift__(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__rshift__(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__iand__(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__ior__(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__ixor__(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__ilshift__(Tensor self, Tensor other) -> Tensor",
|
|
"aten::__irshift__(Tensor self, Tensor other) -> Tensor",
|
|
|
|
// Tensor-Scalar operators
|
|
"aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor",
|
|
"aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor",
|
|
"aten::mul(Tensor self, Scalar other) -> Tensor",
|
|
"aten::div(Tensor self, Scalar other) -> Tensor",
|
|
"aten::pow(Tensor self, Scalar exponent) -> Tensor",
|
|
"aten::fmod(Tensor self, Scalar other) -> Tensor",
|
|
"aten::remainder(Tensor self, Scalar other) -> Tensor",
|
|
"aten::pow(Scalar self, Tensor exponent) -> Tensor",
|
|
"aten::__and__(Tensor self, Scalar other) -> Tensor",
|
|
"aten::__or__(Tensor self, Scalar other) -> Tensor",
|
|
"aten::__xor__(Tensor self, Scalar other) -> Tensor",
|
|
"aten::__lshift__(Tensor self, Scalar other) -> Tensor",
|
|
"aten::__rshift__(Tensor self, Scalar other) -> Tensor",
|
|
"aten::__iand__(Tensor self, Scalar other) -> Tensor",
|
|
"aten::__ior__(Tensor self, Scalar other) -> Tensor",
|
|
"aten::__ixor__(Tensor self, Scalar other) -> Tensor",
|
|
"aten::__ilshift__(Tensor self, Scalar other) -> Tensor",
|
|
"aten::__irshift__(Tensor self, Scalar other) -> Tensor",
|
|
|
|
// Ops with Tensor-Tensor overloads only
|
|
"aten::atan2(Tensor self, Tensor other) -> Tensor",
|
|
|
|
// Non-binary ops
|
|
"aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor",
|
|
"aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
|
|
"aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto maybe_tensor_types = gatherTensorTypes<TensorType>(node)) {
|
|
return {broadcast(*maybe_tensor_types)};
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
static const auto any_tensor_type = [](Node * node) -> TensorTypePtr {
|
|
for (Value * input : node->inputs()) {
|
|
if (auto type = input->type()->cast<TensorType>()) {
|
|
return type;
|
|
}
|
|
}
|
|
return nullptr;
|
|
};
|
|
|
|
// Requirements:
|
|
// dims : always matching and preserved
|
|
// scalar type : always matching and preserved
|
|
// device : always matching and preserved
|
|
// tensor inputs : 2
|
|
// tensor outputs : 1
|
|
static const register_formula_for binary_ops_strict_match {{
|
|
"aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor",
|
|
"aten::mm(Tensor self, Tensor mat2) -> Tensor",
|
|
"aten::bmm(Tensor self, Tensor mat2) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto type = any_tensor_type(node)) {
|
|
return {type};
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
// Requirements:
|
|
// dims : all tensor args are broadcast
|
|
// scalar type : byte/uint8
|
|
// device : always matching and preserved
|
|
// tensor inputs : *
|
|
// tensor outputs : 1
|
|
static const register_formula_for comparison_ops {{
|
|
"aten::lt(Tensor self, Tensor other) -> Tensor",
|
|
"aten::le(Tensor self, Tensor other) -> Tensor",
|
|
"aten::gt(Tensor self, Tensor other) -> Tensor",
|
|
"aten::ge(Tensor self, Tensor other) -> Tensor",
|
|
"aten::eq(Tensor self, Tensor other) -> Tensor",
|
|
"aten::ne(Tensor self, Tensor other) -> Tensor",
|
|
"aten::lt(Tensor self, Scalar other) -> Tensor",
|
|
"aten::le(Tensor self, Scalar other) -> Tensor",
|
|
"aten::gt(Tensor self, Scalar other) -> Tensor",
|
|
"aten::ge(Tensor self, Scalar other) -> Tensor",
|
|
"aten::eq(Tensor self, Scalar other) -> Tensor",
|
|
"aten::ne(Tensor self, Scalar other) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto maybe_tensor_types = gatherTensorTypes<TensorType>(node)) {
|
|
return {broadcast(*maybe_tensor_types)->toScalarType(at::kByte)};
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
// Requirements:
|
|
// dims : preserved from the first argument
|
|
// scalar type : preserved from the first argument (doesn't have to match other arguments)
|
|
// device : always matching and preserved
|
|
// tensor inputs : *
|
|
// tensor outputs : 1
|
|
// NB: those ops (with slight adjustments) are good candidates for restarts.
|
|
// Knowing the type and device of weights or biases is usually enough to
|
|
// infer the output type.
|
|
static const register_formula_for nn_ops_first_input_preserving {{
|
|
"aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor",
|
|
"aten::conv1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
|
|
"aten::conv2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
|
|
"aten::conv3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor",
|
|
"aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad) -> Tensor",
|
|
"aten::conv_transpose1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
|
|
"aten::conv_transpose2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
|
|
"aten::conv_transpose3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor",
|
|
"aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor",
|
|
"aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor",
|
|
"aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor",
|
|
"aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor",
|
|
"aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
|
|
"aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
|
|
"aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor",
|
|
"aten::max_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
|
|
"aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
|
|
"aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor",
|
|
"aten::max_unpool2d(Tensor self, Tensor indices, int[] output_size) -> Tensor",
|
|
"aten::max_unpool3d(Tensor self, Tensor indices, int[] output_size, int[] stride, int[] padding) -> Tensor",
|
|
"aten::reflection_pad1d(Tensor self, int[] padding) -> Tensor",
|
|
"aten::reflection_pad2d(Tensor self, int[] padding) -> Tensor",
|
|
"aten::replication_pad1d(Tensor self, int[] padding) -> Tensor",
|
|
"aten::replication_pad2d(Tensor self, int[] padding) -> Tensor",
|
|
"aten::replication_pad3d(Tensor self, int[] padding) -> Tensor",
|
|
"aten::upsample_bilinear2d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
|
|
"aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
|
|
"aten::upsample_nearest1d(Tensor self, int[] output_size) -> Tensor",
|
|
"aten::upsample_nearest2d(Tensor self, int[] output_size) -> Tensor",
|
|
"aten::upsample_nearest3d(Tensor self, int[] output_size) -> Tensor",
|
|
"aten::upsample_trilinear3d(Tensor self, int[] output_size, bool align_corners) -> Tensor",
|
|
"aten::prelu(Tensor self, Tensor weight) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto type = node->input(0)->type()->cast<TensorType>()) {
|
|
return {type};
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
// Requirements:
|
|
// dims : 0
|
|
// scalar type : preserved
|
|
// device : preserved
|
|
// tensor inputs : 1
|
|
// tensor outputs : 1
|
|
// Additionally:
|
|
// - First input should be the only tensor input
|
|
static const register_formula_for all_reduce_ops {{
|
|
"aten::argmax(Tensor self) -> Tensor",
|
|
"aten::argmin(Tensor self) -> Tensor",
|
|
"aten::det(Tensor self) -> Tensor",
|
|
"aten::logdet(Tensor self) -> Tensor",
|
|
"aten::max(Tensor self) -> Tensor",
|
|
"aten::min(Tensor self) -> Tensor",
|
|
"aten::mean(Tensor self) -> Tensor",
|
|
"aten::median(Tensor self) -> Tensor",
|
|
"aten::norm(Tensor self, Scalar p) -> Tensor",
|
|
"aten::std(Tensor self, bool unbiased) -> Tensor",
|
|
"aten::sum(Tensor self) -> Tensor",
|
|
"aten::trace(Tensor self) -> Tensor",
|
|
"aten::var(Tensor self, bool unbiased) -> Tensor",
|
|
"aten::all(Tensor self) -> Tensor",
|
|
"aten::any(Tensor self) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto type = node->input(0)->type()->cast<TensorType>()) {
|
|
return {type->withDim(0)};
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
// Requirements:
|
|
// dims : 0
|
|
// scalar type : preserved if floating point, otherwise long/int64
|
|
// device : preserved
|
|
// tensor inputs : 1
|
|
// tensor outputs : 1
|
|
// Additionally:
|
|
// - First input should be the only tensor input
|
|
static const register_formula_for all_reduce_ops_with_integer_upcast {{
|
|
"aten::sum(Tensor self) -> Tensor",
|
|
"aten::prod(Tensor self) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto type = node->input(0)->type()->cast<TensorType>()) {
|
|
return {at::isFloatingType(type->scalarType()) ? type->withDim(0) : type->withDim(0)->toScalarType(at::kLong)};
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
static const auto multidim_reduce_with_postprocess =
|
|
[](Node * node, size_t num_reduced_dim, bool upcast_integer) -> type_vec_t {
|
|
auto maybe_keepdim = node->get<bool>(attr::keepdim);
|
|
if (!maybe_keepdim) return {};
|
|
if (auto type = node->input(0)->type()->cast<TensorType>()) {
|
|
if (upcast_integer && !at::isFloatingType(type->scalarType())) {
|
|
type = type->toScalarType(at::kLong);
|
|
}
|
|
if (*maybe_keepdim) {
|
|
return {type};
|
|
} else if (type->dim() > num_reduced_dim) {
|
|
return {type->withDim(type->dim() - num_reduced_dim)};
|
|
}
|
|
}
|
|
return {};
|
|
};
|
|
|
|
// Requirements:
|
|
// dims : preserved if keepdim == false, 1 smaller otherwise
|
|
// scalar type : preserved for first output, byte/uint8 for second output if exists
|
|
// device : preserved
|
|
// tensor inputs : 1
|
|
// tensor outputs : 1 or 2
|
|
// Additionally:
|
|
// - First input should be the only tensor input
|
|
// - Has a bool keepdim argument
|
|
static const register_formula_for dim_reduce_ops {{
|
|
"aten::argmax(Tensor self, int dim, bool keepdim) -> Tensor",
|
|
"aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor",
|
|
"aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor",
|
|
"aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor",
|
|
"aten::mean(Tensor self, int dim, bool keepdim) -> Tensor",
|
|
"aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor",
|
|
"aten::std(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
|
|
"aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor",
|
|
"aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor",
|
|
"aten::all(Tensor self, int dim, bool keepdim) -> Tensor",
|
|
"aten::any(Tensor self, int dim, bool keepdim) -> Tensor",
|
|
|
|
// Ops returning indices as second output
|
|
"aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)",
|
|
"aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
|
|
"aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
|
|
"aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
|
|
"aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)",
|
|
}, [](Node * node) -> type_vec_t {
|
|
// NB: Note that while this function is generally meant to be used with ops that
|
|
// have a single output, we will fix up its return right below.
|
|
auto output_types = multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/1, /*integer_upcast=*/false);
|
|
if (!output_types.empty() && node->outputs().size() == 2) {
|
|
output_types.push_back(output_types.back()->toScalarType(at::kLong));
|
|
}
|
|
return output_types;
|
|
}};
|
|
|
|
// Requirements:
|
|
// dims : preserved if keepdim == false, 1 smaller otherwise
|
|
// scalar type : preserved if floating point, otherwise long/int64
|
|
// device : preserved
|
|
// tensor inputs : 1
|
|
// tensor outputs : 1
|
|
// Additionally:
|
|
// - First input should be the only tensor input
|
|
// - has a bool keepdim argument
|
|
static const register_formula_for dim_reduce_ops_with_integer_upcast {{
|
|
"aten::prod(Tensor self, int dim, bool keepdim) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/1, /*integer_upcast=*/true);
|
|
}};
|
|
|
|
// Requirements:
|
|
// dims : preserved if keepdim == false, 1 smaller otherwise
|
|
// scalar type : preserved if floating point, otherwise long/int64
|
|
// device : preserved
|
|
// tensor inputs : 1
|
|
// tensor outputs : 1
|
|
// Additionally:
|
|
// - has bool keepdim and int[] dim arguments
|
|
static const register_formula_for multidim_reduce_ops_with_integer_upcast {{
|
|
"aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto dim = node->get<std::vector<int64_t>>(attr::dim)) {
|
|
// TODO: can dim contain duplicates?
|
|
return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/dim->size(), /*integer_upcast=*/true);
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
static const auto get_device_int =
|
|
[](c10::optional<at::Device> dev) -> c10::optional<int> {
|
|
if (!dev)
|
|
return {};
|
|
if (dev->is_cpu()) {
|
|
return {-1};
|
|
}
|
|
return dev->has_index() ? c10::optional<int>{dev->index()} : c10::nullopt;
|
|
};
|
|
static const auto factory_with_ndim = [](Node * node, int dim) -> type_vec_t{
|
|
auto maybe_layout = node->get<at::Layout>(attr::layout);
|
|
if (!maybe_layout || maybe_layout != at::kStrided) return {};
|
|
auto maybe_device = get_device_int(node->get<at::Device>(attr::device));
|
|
if (!maybe_device) return {};
|
|
auto maybe_scalar_type = node->get<at::ScalarType>(attr::dtype);
|
|
if (!maybe_scalar_type) return {};
|
|
return {TensorType::create(*maybe_scalar_type, *maybe_device, dim)};
|
|
};
|
|
|
|
// Requirements:
|
|
// dims : preserved
|
|
// scalar type : equal to value of dtype
|
|
// device : equal to value of device
|
|
// tensor inputs : 1
|
|
// tensor outputs : 1
|
|
// Additionally:
|
|
// - has ScalarType dtype, Layeout layout and Device device arguments
|
|
static const register_formula_for like_factories_with_options {{
|
|
"aten::empty_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::ones_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::rand_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::randint_like(Tensor self, int high, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::randn_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::zeros_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto type = node->namedInput(attr::self)->type()->cast<TensorType>()) {
|
|
return factory_with_ndim(node, type->dim());
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
// Requirements:
|
|
// dims : equal to number of elements in size
|
|
// scalar type : equal to value of dtype
|
|
// device : equal to value of device
|
|
// tensor inputs : 1
|
|
// tensor outputs : 1
|
|
// Additionally:
|
|
// - has int[] size, ScalarType dtype, Layeout layout and Device device arguments
|
|
static const register_formula_for size_factories_with_options {{
|
|
"aten::empty(int[] size, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::full(int[] size, Scalar fill_value, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::ones(int[] size, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::rand(int[] size, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::randn(int[] size, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::zeros(int[] size, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::randint(int high, int[] size, *, int dtype, int layout, int[] device) -> Tensor",
|
|
"aten::randint(int low, int high, int[] size, *, int dtype, int layout, int[] device) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto maybe_size = node->get<std::vector<int64_t>>(attr::size)) {
|
|
return factory_with_ndim(node, maybe_size->size());
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
static const auto get_cast_scalar_type = [](Node *node) -> at::ScalarType {
|
|
switch (node->kind()) {
|
|
case aten::_cast_Byte: return at::kByte;
|
|
case aten::_cast_Char: return at::kChar;
|
|
case aten::_cast_Double: return at::kDouble;
|
|
case aten::_cast_Float: return at::kFloat;
|
|
case aten::_cast_Half: return at::kHalf;
|
|
case aten::_cast_Int: return at::kInt;
|
|
case aten::_cast_Long: return at::kLong;
|
|
case aten::_cast_Short: return at::kShort;
|
|
default: AT_ASSERTM(false, "unknown node kind in get_cast_scalar_type: ", node->kind().toQualString());
|
|
}
|
|
};
|
|
static const register_formula_for cast_ops {{
|
|
"aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor",
|
|
"aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor",
|
|
"aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor",
|
|
"aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor",
|
|
"aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor",
|
|
"aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor",
|
|
"aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor",
|
|
"aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor",
|
|
}, [](Node * node) -> type_vec_t {
|
|
if (auto type = node->namedInput(attr::self)->type()->cast<TensorType>()) {
|
|
return {type->toScalarType(get_cast_scalar_type(node))};
|
|
}
|
|
return {};
|
|
}};
|
|
|
|
// First, try to match one of the registered formulas to their operator sets.
|
|
for (auto & entry : shape_formulas) {
|
|
if (entry.first.find(node)) {
|
|
auto types = entry.second(node);
|
|
if (types.empty()) {
|
|
return false;
|
|
} else {
|
|
auto outputs = node->outputs();
|
|
JIT_ASSERT(types.size() == outputs.size());
|
|
for (size_t i = 0; i < types.size(); ++i) {
|
|
JIT_ASSERT(outputs[i]->type()->isSubtypeOf(DynamicType::get()));
|
|
outputs[i]->setType(types[i]);
|
|
}
|
|
return true;
|
|
}
|
|
}
|
|
}
|
|
|
|
// This section implements shape prop for an assorted set of nodes that only
|
|
// need partial information about their input types.
|
|
const auto input_type = [node](size_t index) {
|
|
return node->input(index)->type()->cast<TensorType>();
|
|
};
|
|
if (node->matches("aten::masked_select(Tensor self, Tensor mask) -> Tensor")) {
|
|
auto type = input_type(0);
|
|
auto mask_type = input_type(1);
|
|
if (type && mask_type) {
|
|
if (type->dim() == 0 && mask_type->dim() == 0) {
|
|
node->output()->setType(type->withDim(0));
|
|
} else {
|
|
node->output()->setType(type->withDim(1));
|
|
}
|
|
return true;
|
|
}
|
|
if (auto type = input_type(0)) {
|
|
node->output()->setType(type->withDim(1));
|
|
return true;
|
|
}
|
|
} else if (node->matches("aten::dot(Tensor self, Tensor tensor) -> Tensor")) {
|
|
if (auto type = any_tensor_type(node)) {
|
|
node->output()->setType(type->withDim(0));
|
|
return true;
|
|
}
|
|
} else if (node->matches("aten::mv(Tensor self, Tensor vec) -> Tensor") ||
|
|
node->matches("aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) -> Tensor")) {
|
|
if (auto type = any_tensor_type(node)) {
|
|
node->output()->setType(type->withDim(1));
|
|
return true;
|
|
}
|
|
} else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor") ||
|
|
node->matches("aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor") ||
|
|
node->matches("aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar beta, Scalar alpha) -> Tensor")) {
|
|
if (auto type = any_tensor_type(node)) {
|
|
node->output()->setType(type->withDim(2));
|
|
return true;
|
|
}
|
|
} else if (node->matches("aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor")) {
|
|
if (auto type = any_tensor_type(node)) {
|
|
node->output()->setType(type->withDim(3));
|
|
return true;
|
|
}
|
|
} else if (node->matches("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor")) {
|
|
auto type = input_type(0);
|
|
auto index_type = input_type(1);
|
|
// index_select behaves very weirdly when self.dim() == 0. It allows both 0D and 1D
|
|
// indices, and returns a value that has as many dimensions as index.
|
|
if (type && index_type) {
|
|
if (type->dim() == 0) {
|
|
node->output()->setType(type->withDim(index_type->dim()));
|
|
} else {
|
|
node->output()->setType(type);
|
|
}
|
|
return true;
|
|
}
|
|
} else if (node->matches("aten::gather(Tensor self, int dim, Tensor index) -> Tensor")) {
|
|
auto type = input_type(0);
|
|
auto index_type = input_type(1);
|
|
// Gather has this annoying edge case where index always needs to match the number of
|
|
// dims of self, **except** when self is 1D and index is 0D in which case we return
|
|
// a 0D output.
|
|
if (type && index_type) {
|
|
if (index_type->dim() == 0) {
|
|
node->output()->setType(type->withDim(0));
|
|
} else {
|
|
node->output()->setType(type);
|
|
}
|
|
return true;
|
|
}
|
|
} else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) {
|
|
auto weight_type = input_type(0);
|
|
auto indices_type = input_type(1);
|
|
if (weight_type && indices_type) {
|
|
node->output()->setType(weight_type->withDim(indices_type->dim() + 1));
|
|
return true;
|
|
}
|
|
} else if (node->matches("aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor bias) -> Tensor")) {
|
|
if (auto type = input_type(0)) {
|
|
node->output()->setType(type);
|
|
return true;
|
|
}
|
|
if (auto type = input_type(1)) {
|
|
node->output()->setType(type);
|
|
return true;
|
|
}
|
|
} else if (node->matches("aten::dist(Tensor self, Tensor other, Scalar p) -> Tensor")) {
|
|
if (auto type = any_tensor_type(node)) {
|
|
node->output()->setType(type->withDim(0));
|
|
return true;
|
|
}
|
|
}
|
|
|
|
// The code below implements formulas that need type information for all their
|
|
// tensor inputs, and have exactly one output.
|
|
std::vector<TensorTypePtr> tensor_types;
|
|
static const auto reshape_prop =
|
|
[](Node * node, Symbol shape_input, const std::vector<TensorTypePtr>& tensor_types) -> TensorTypePtr {
|
|
if (auto list_size = determineListSize(node->namedInput(shape_input))) {
|
|
return tensor_types.at(0)->withDim(*list_size);
|
|
}
|
|
return nullptr;
|
|
};
|
|
const auto getSingleOutputType = [&]() -> TypePtr {
|
|
if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
|
|
return tensor_types.at(0)->toScalarType(tensor_types.at(1)->scalarType());
|
|
} else if (node->matches("aten::view_as(Tensor self, Tensor other) -> Tensor") ||
|
|
node->matches("aten::expand_as(Tensor self, Tensor other) -> Tensor") ||
|
|
node->matches("aten::reshape_as(Tensor self, Tensor other) -> Tensor")) {
|
|
return tensor_types.at(0)->withDim(tensor_types.at(1)->dim());
|
|
} else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") ||
|
|
node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") ||
|
|
node->matches("aten::as_strided(Tensor self, int[] size, int[] stride) -> Tensor") ||
|
|
node->matches("aten::as_strided(Tensor self, int[] size, int[] stride, int storage_offset) -> Tensor")) {
|
|
return reshape_prop(node, attr::size, tensor_types);
|
|
} else if (node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) {
|
|
return reshape_prop(node, attr::shape, tensor_types);
|
|
} else if (node->matches("aten::repeat(Tensor self, int[] repeats) -> Tensor")) {
|
|
return reshape_prop(node, attr::repeats, tensor_types);
|
|
} else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) {
|
|
auto & t = tensor_types.at(0);
|
|
return t->withDim(t->dim() + 1);
|
|
} else if (node->matches("aten::select(Tensor self, int dim, int index) -> Tensor") ||
|
|
node->matches("aten::diagonal(Tensor self, int offset, int dim1, int dim2) -> Tensor")) {
|
|
auto & t = tensor_types.at(0);
|
|
return t->dim() > 0 ? t->withDim(t->dim() - 1) : nullptr;
|
|
} else if (node->matches("aten::matmul(Tensor self, Tensor other) -> Tensor")) {
|
|
int dim1 = tensor_types.at(0)->dim();
|
|
int dim2 = tensor_types.at(1)->dim();
|
|
if (dim1 == 1 && dim2 == 1) {
|
|
// Dot product
|
|
return tensor_types.at(0)->withDim(0);
|
|
} else if (dim1 == 2 && dim2 == 2) {
|
|
// Matrix multiply
|
|
return tensor_types.at(0);
|
|
} else if (dim1 == 1 && dim2 == 2) {
|
|
// Unsqueeze + matrix multiply + squeeze
|
|
return tensor_types.at(0);
|
|
} else if (dim1 == 2 && dim2 == 1) {
|
|
// Matrix vector multiply
|
|
return tensor_types.at(1);
|
|
} else {
|
|
// Batched matrix multiply (possibly with squeeze + unsqueeze if one argument is 1D)
|
|
auto type = broadcast(tensor_types);
|
|
if (tensor_types.at(0)->dim() == 1 || tensor_types.at(1)->dim() == 1) {
|
|
type = type->withDim(type->dim() - 1);
|
|
}
|
|
return type;
|
|
}
|
|
} else if (node->matches("aten::nonzero(Tensor self) -> Tensor")) {
|
|
return tensor_types.at(0)->toScalarType(at::kLong);
|
|
} else if (node->matches("aten::take(Tensor self, Tensor index) -> Tensor")) {
|
|
return tensor_types.at(1)->toScalarType(tensor_types.at(0)->scalarType());
|
|
} else if (node->matches("aten::diagflat(Tensor self, int offset) -> Tensor")) {
|
|
return tensor_types.at(0)->withDim(2);
|
|
} else if (node->matches("aten::diag(Tensor self, int diagonal) -> Tensor")) {
|
|
auto & t = tensor_types.at(0);
|
|
if (t->dim() == 1) {
|
|
return t->withDim(2);
|
|
} else if (t->dim() == 2) {
|
|
return t->withDim(1);
|
|
} else {
|
|
return nullptr;
|
|
}
|
|
} else if (node->matches("aten::unfold(Tensor self, int dimension, int size, int step) -> Tensor")) {
|
|
auto & t = tensor_types.at(0);
|
|
return t->dim() == 0 ? t : t->withDim(t->dim() + 1);
|
|
} else if (node->matches("aten::polygamma(int n, Tensor self) -> Tensor")) {
|
|
return tensor_types.at(0);
|
|
}
|
|
return nullptr;
|
|
};
|
|
if (auto maybe_tensor_types = gatherTensorTypes<TensorType>(node)) {
|
|
tensor_types = std::move(*maybe_tensor_types);
|
|
} else {
|
|
return false;
|
|
}
|
|
if (node->outputs().size() == 1) {
|
|
if (auto type = getSingleOutputType()) {
|
|
node->output()->setType(type);
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands,
|
|
std::vector<CompleteTensorTypePtr> tensor_types) {
|
|
// For expensive ops we can directly encode their shape propagation
|
|
// here, otherwise we fallback to running a fake version of the op
|
|
// to get a quick and dirty propagation.
|
|
if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
|
|
node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") ||
|
|
node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) {
|
|
// These nodes and "div" handle tensors of different shapes internally,
|
|
// so there's no need to insert explicit expand nodes. Note that "div" is
|
|
// handled by the fallthrough because it's not always safe to run it due
|
|
// to integer divide-by-zero.
|
|
return PropagateShapeOnNodeByRunningIt(node);
|
|
} else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
|
|
node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") ||
|
|
node->matches("aten::mul(Tensor self, Scalar other) -> Tensor") ||
|
|
node->matches("aten::pow(Tensor self, Scalar exponent) -> Tensor")) {
|
|
node->output()->setType(tensor_types.at(0));
|
|
return true;
|
|
} else if (insert_expands && (
|
|
node->matches("aten::pow(Tensor self, Tensor exponent) -> Tensor") ||
|
|
node->matches("aten::min(Tensor self, Tensor other) -> Tensor") ||
|
|
node->matches("aten::max(Tensor self, Tensor other) -> Tensor") ||
|
|
node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") ||
|
|
node->matches("aten::le(Tensor self, Tensor other) -> Tensor") ||
|
|
node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") ||
|
|
node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") ||
|
|
node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") ||
|
|
node->matches("aten::ne(Tensor self, Tensor other) -> Tensor"))) {
|
|
// Binary broadcasting ops
|
|
// NB: we don't handle the nodes in any other way (note the lack of return!),
|
|
// because the type casting logic in scalar cases is non-trivial.
|
|
// It's better to just run them.
|
|
broadcastBinary(node, tensor_types, 0, 1);
|
|
return PropagateShapeOnNodeByRunningIt(node);
|
|
} else if (node->matches("aten::neg(Tensor self) -> Tensor") ||
|
|
node->matches("aten::sigmoid(Tensor self) -> Tensor") ||
|
|
node->matches("aten::tanh(Tensor self) -> Tensor")) {
|
|
node->output()->setType(tensor_types.at(0)->contiguous());
|
|
return true;
|
|
} else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) {
|
|
auto lhs_type = tensor_types.at(0);
|
|
auto rhs_type = tensor_types.at(1);
|
|
SHAPE_ASSERT(lhs_type->sizes().size() == 2 && rhs_type->sizes().size() == 2);
|
|
node->output()->setType(CompleteTensorType::create(
|
|
lhs_type->scalarType(), lhs_type->device(),
|
|
at::IntList{lhs_type->sizes().at(0), rhs_type->sizes().at(1)}));
|
|
return true;
|
|
} else if (node->matches("aten::t(Tensor self) -> Tensor")) {
|
|
auto tp = tensor_types.at(0);
|
|
auto sizes = tp->sizes();
|
|
auto strides = tp->strides();
|
|
SHAPE_ASSERT(sizes.size() == 2);
|
|
std::swap(sizes.at(0), sizes.at(1));
|
|
std::swap(strides.at(0), strides.at(1));
|
|
node->output()->setType(tp->withSizesStrides(sizes, strides));
|
|
return true;
|
|
} else if (node->matches("aten::narrow(Tensor self, int dim, int start, int length) -> Tensor",
|
|
/*const_inputs=*/{attr::dim, attr::length})) {
|
|
auto tp = tensor_types.at(0);
|
|
auto sizes = tp->sizes();
|
|
int64_t dim = node->get<int64_t>(attr::dim).value();
|
|
int64_t length = node->get<int64_t>(attr::length).value();
|
|
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
|
|
sizes.at(dim) = length;
|
|
node->output()->setType(tp->withSizesStrides(sizes, tp->strides()));
|
|
return true;
|
|
} else if (node->matches("aten::sum(Tensor self) -> Tensor")) {
|
|
node->output()->setType(tensor_types.at(0)->withSizes({}));
|
|
return true;
|
|
} else if (node->matches("aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor",
|
|
/*const_inputs=*/{attr::dim, attr::keepdim})) {
|
|
auto & tp = tensor_types.at(0);
|
|
auto sizes = tp->sizes();
|
|
auto dims = node->get<std::vector<int64_t>>(attr::dim).value();
|
|
bool keepdim = node->get<bool>(attr::keepdim).value();
|
|
std::reverse(dims.begin(), dims.end());
|
|
for (int64_t dim : dims) {
|
|
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
|
|
if (keepdim) {
|
|
sizes.at(dim) = 1;
|
|
} else {
|
|
sizes.erase(sizes.begin() + dim);
|
|
}
|
|
}
|
|
node->output()->setType(tp->withSizes(sizes));
|
|
return true;
|
|
} else if (node->matches("aten::squeeze(Tensor self, int dim) -> Tensor", /*const_inputs=*/attr::dim)) {
|
|
auto & tp = tensor_types.at(0);
|
|
auto sizes = tp->sizes();
|
|
auto strides = tp->strides();
|
|
int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
|
|
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < sizes.size());
|
|
if (sizes.at(dim) == 1) {
|
|
sizes.erase(sizes.begin() + dim);
|
|
strides.erase(strides.begin() + dim);
|
|
}
|
|
node->output()->setType(tp->withSizesStrides(sizes, strides));
|
|
return true;
|
|
} else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor", /*const_inputs=*/attr::dim)) {
|
|
auto & tp = tensor_types.at(0);
|
|
auto sizes = tp->sizes();
|
|
auto strides = tp->strides();
|
|
int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
|
|
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) <= sizes.size());
|
|
int64_t new_stride = dim >= static_cast<int64_t>(sizes.size()) ? 1 : sizes.at(dim) * strides.at(dim);
|
|
sizes.insert(sizes.begin() + dim, 1);
|
|
strides.insert(strides.begin() + dim, new_stride);
|
|
node->output()->setType(tp->withSizesStrides(sizes, strides));
|
|
return true;
|
|
} else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor", /*const_inputs=*/attr::size)) {
|
|
auto sizes = node->get<std::vector<int64_t>>(attr::size).value();
|
|
bool inferred = false;
|
|
size_t inferred_idx;
|
|
int64_t size_product = 1;
|
|
for (size_t i = 0; i < sizes.size(); ++i) {
|
|
if (sizes[i] == -1) {
|
|
if (inferred) throw propagation_error();
|
|
inferred = true;
|
|
inferred_idx = i;
|
|
} else {
|
|
size_product *= sizes[i];
|
|
}
|
|
}
|
|
|
|
if (inferred) {
|
|
SHAPE_ASSERT(size_product != 0);
|
|
size_t numel = 1;
|
|
for (int64_t s : tensor_types.at(0)->sizes())
|
|
numel *= s;
|
|
int64_t inferred_size = numel / size_product;
|
|
sizes[inferred_idx] = inferred_size;
|
|
}
|
|
node->output()->setType(tensor_types.at(0)->withSizes(sizes));
|
|
return true;
|
|
} else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) {
|
|
if (tensor_types.at(0)->scalarType() == tensor_types.at(1)->scalarType()) {
|
|
node->output()->setType(node->namedInput(attr::self)->type());
|
|
} else {
|
|
// This will be a copy, so the result will be contiguous
|
|
node->output()->setType(tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes()));
|
|
}
|
|
return true;
|
|
} else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor",
|
|
/*const_inputs=*/attr::size)) {
|
|
auto tp = tensor_types.at(0);
|
|
std::vector<int64_t> sizes, strides;
|
|
std::tie(sizes, strides) = at::inferExpandGeometry(
|
|
tp->sizes(), tp->strides(), node->get<std::vector<int64_t>>(attr::size).value());
|
|
node->output()->setType(tp->withSizesStrides(sizes, strides));
|
|
return true;
|
|
} else if (node->matches("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor",
|
|
/*const_inputs=*/attr::dim)) {
|
|
auto ten = tensor_types.at(0);
|
|
auto index = tensor_types.at(1);
|
|
int64_t dim = node->get<int64_t>(attr::dim).value();
|
|
SHAPE_ASSERT(index->sizes().size() == 1);
|
|
SHAPE_ASSERT(dim >= 0 && static_cast<size_t>(dim) < ten->sizes().size());
|
|
std::vector<int64_t> sizes = ten->sizes();
|
|
sizes[dim] = index->sizes()[0];
|
|
node->output()->setType(ten->withSizes(sizes));
|
|
return true;
|
|
} else if (node->matches("aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]",
|
|
/*const_inputs=*/{attr::chunks, attr::dim})) {
|
|
auto input_type = tensor_types.at(0);
|
|
auto sizes = input_type->sizes();
|
|
const auto & strides = input_type->strides();
|
|
int64_t dim = node->get<int64_t>(attr::dim).value();
|
|
int64_t chunks = node->get<int64_t>(attr::chunks).value();
|
|
sizes[dim] /= chunks;
|
|
for (Value * output : node->outputs()) {
|
|
output->setType(input_type->withSizesStrides(sizes, strides));
|
|
}
|
|
if (input_type->sizes().at(dim) % chunks != 0) {
|
|
sizes[dim] = input_type->sizes().at(dim) % chunks;
|
|
node->outputs().back()->setType(input_type->withSizesStrides(sizes, strides));
|
|
}
|
|
return true;
|
|
} else if (node->kind() == onnx::Shape) {
|
|
SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1);
|
|
std::vector<int64_t> dim_vec = {(int64_t)tensor_types.at(0)->sizes().size()};
|
|
at::IntList dims(dim_vec);
|
|
node->output()->setType(
|
|
CompleteTensorType::create(at::kLong, -1, dims));
|
|
return true;
|
|
} else if (node->kind() == onnx::Reshape) {
|
|
setUnshapedType(node);
|
|
return true;
|
|
}
|
|
setUnshapedType(node);
|
|
return false;
|
|
}
|
|
|
|
void PropagateShapeOnBlock(Block * block, bool insert_expands) {
|
|
for (Node * node : block->nodes()) {
|
|
try {
|
|
PropagateShapeOnNode(node, insert_expands);
|
|
} catch(propagation_error& e) {
|
|
setUnshapedType(node);
|
|
} catch(std::exception & e) {
|
|
if(auto sl = node->getSourceLocation()) {
|
|
sl->wrapAndRethrowException(e, "operation failed shape propagation");
|
|
} else {
|
|
throw;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
void PropagateInputShapes(Graph & graph) {
|
|
PropagateShapeOnBlock(graph.block());
|
|
}
|
|
|
|
namespace {
|
|
|
|
void EraseShapeInformation(at::ArrayRef<Value*> vals) {
|
|
for (Value * v : vals) {
|
|
v->setType(unshapedType(v->type()));
|
|
}
|
|
}
|
|
|
|
void EraseShapeInformation(Block * b) {
|
|
EraseShapeInformation(b->inputs());
|
|
EraseShapeInformation(b->outputs());
|
|
for (Node * n : b->nodes()) {
|
|
EraseShapeInformation(n->outputs());
|
|
for (Block *sb : n->blocks()) {
|
|
EraseShapeInformation(sb);
|
|
}
|
|
}
|
|
}
|
|
|
|
} // anonymous namespace
|
|
|
|
void EraseShapeInformation(Graph & graph) {
|
|
EraseShapeInformation(graph.block());
|
|
}
|
|
|
|
}}
|