#include "torch/csrc/jit/passes/shape_analysis.h"

#include "torch/csrc/jit/ir.h"
#include "torch/csrc/jit/constants.h"
#include "torch/csrc/jit/argument_spec.h"
#include "torch/csrc/jit/operator.h"
#include "torch/csrc/jit/assertions.h"

#include "torch/csrc/autograd/variable.h"

#include <ATen/DeviceGuard.h>
#include <ATen/ExpandUtils.h>

#include <algorithm>
#include <exception>
#include <functional>
#include <iostream>
#include <memory>
#include <mutex>
#include <sstream>
#include <stdexcept>
#include <utility>
#include <vector>

namespace torch { namespace jit {

// Exception type thrown by SHAPE_ASSERT when a shape invariant is violated.
struct propagation_error : std::exception {};

// Throws propagation_error when `cond` is false.
// Wrapped in do/while(0) so the macro is a single statement: the previous
// bare `if` form would silently capture a following `else` when used inside
// an un-braced if/else.
#define SHAPE_ASSERT(cond)         \
  do {                             \
    if (!(cond))                   \
      throw propagation_error();   \
  } while (0)

namespace {

// Replaces the type of every output of `node` with unshapedType(<old type>),
// i.e. drops whatever shape information the outputs carried.
void setUnshapedType(Node * node) {
  for (auto o : node->outputs()) {
    o->setType(unshapedType(o->type()));
  }
}

// Turns a negative dimension index into its non-negative equivalent by adding
// the rank (sizes.size()). Non-negative values are returned unchanged.
// NOTE: no range check is done here; callers must validate the result.
int64_t wrapDim(int64_t dim, at::IntList sizes) {
  if (dim < 0) {
    dim += static_cast<int64_t>(sizes.size());
  }
  return dim;
}

// Builds a concrete stand-in value for `v` so an operator can actually be run
// (see PropagateShapeOnNodeByRunningIt):
//   - constants are returned as their constant value,
//   - complete tensor types become a zero-filled tensor with the same sizes,
//     strides and scalar type, allocated on the recorded device
//     (device -1 is treated as CPU, anything else as CUDA),
//   - float-typed values become 0.f.
// Throws std::runtime_error for any other type.
IValue representativeValue(Value* v) {
  TypePtr type_ = v->type();
  // if the value is actually constant, just use it!
  if (auto iv = toIValue(v)) {
    return *iv;
  }
  if (CompleteTensorTypePtr type = type_->cast<CompleteTensorType>()) {
    auto backend = type->device() == -1 ? at::Backend::CPU : at::Backend::CUDA;
    at::DeviceGuard device_guard(type->device());
    auto& attype = at::getNonVariableType(backend, type->scalarType());
    auto t = at::empty_strided(type->sizes(), type->strides(), attype.options()).zero_();
    return autograd::make_variable(t, /*requires_grad=*/false);
  } else if (type_->isSubtypeOf(FloatType::get())) {
    return 0.f;
  }
  // we should not get here because isValidArgumentForRunning should have
  // prevented it
  std::stringstream ss;
  ss << "unable to create representative value for: " << type_->str()
     << ". File a bug report.";
  throw std::runtime_error(ss.str());
}

void PropagateShapeOnBlock(Block * block, bool insert_expands=true);

// For each argument in the node's schema with type Tensor, extract the input's
// type cast to T.
// Returns c10::nullopt if:
//   - the schema is vararg (we cannot tell which inputs are tensors),
//   - any argument is a Tensor[] list, or
//   - any Tensor argument's input type is not a T.
// Non-tensor arguments are skipped.
template <typename T>
c10::optional<std::vector<std::shared_ptr<T>>> gatherTensorTypes(Node* node) {
  std::vector<std::shared_ptr<T>> tensor_types;

  auto & schema = node->schema();
  auto & args = schema.arguments();
  // can't handle varargs primitives because we don't know what should be a Tensor
  if (schema.is_vararg()) {
    return c10::nullopt;
  }
  for (size_t i = 0; i < args.size(); ++i) {
    if (args[i].type()->isSubtypeOf(ListType::ofTensors())) {
      return c10::nullopt;
    } else if (args[i].type()->isSubtypeOf(DynamicType::get())) {
      if (auto type = node->input(i)->type()->cast<T>()) {
        tensor_types.push_back(type);
      } else {
        return c10::nullopt;
      }
    } else /* non-tensor type */ {
      continue;
    }
  }

  return tensor_types;
}

// Sets outputs[i] to unifyTypes(lhs[i], rhs[i]) for every i.
// The three ranges must have equal length, and every pair must be unifiable
// (both are asserted).
// Returns true if any output's type changed; loop propagation uses this as
// its fixpoint test.
bool mergeTypes(ArrayRef<Value*> lhs, ArrayRef<Value*> rhs, ArrayRef<Value*> outputs) {
  JIT_ASSERT(lhs.size() == rhs.size() && rhs.size() == outputs.size());
  bool changed = false;
  for (size_t i = 0; i < lhs.size(); ++i) {
    auto old_output_type = outputs[i]->type();
    auto new_type = unifyTypes(lhs[i]->type(), rhs[i]->type());
    JIT_ASSERT(new_type);
    outputs[i]->setType(*new_type);
    if (*old_output_type != *outputs[i]->type())
      changed = true;
  }
  return changed;
}

void PropagateShapeOnNode(Node * node, bool insert_expands=true);

// Makes node inputs idx1 and idx2 agree on the broadcast size
// (at::infer_size of their sizes). Any input whose size differs gets an
// aten::expand inserted right before `node`; the expand node's shape is
// propagated, and it replaces that input.
// Afterwards, types[idx1] and types[idx2] are refreshed from the (possibly
// rewritten) inputs.
// NOTE(review): `types` is indexed with the same positions as node inputs.
// This assumes the gathered tensor types line up with input indices, which
// holds for (0, 1) on binary tensor ops. TODO confirm for other callers.
void broadcastBinary(Node *node, std::vector<CompleteTensorTypePtr>& types, size_t idx1, size_t idx2) {
  auto expected_size = at::infer_size(types[idx1]->sizes(), types[idx2]->sizes());
  auto broadcast = [&](size_t input_idx) {
    CompleteTensorTypePtr input_type = types.at(input_idx);
    if (input_type->sizes() == expected_size)
      return;
    auto graph = node->owningGraph();
    WithInsertPoint point_guard { node };
    Node *expand = graph->create(aten::expand,
                                 {node->inputs().at(input_idx),
                                  graph->insertConstant(expected_size),
                                  graph->insertConstant(false)})
                       ->insertBefore(node);
    PropagateShapeOnNode(expand);
    node->replaceInput(input_idx, expand->output());
  };
  broadcast(idx1);
  broadcast(idx2);
  // Write back to the same slots that were read above. The old code always
  // wrote to slots 0 and 1, which was only correct for (idx1, idx2) == (0, 1).
  types[idx1] = node->inputs().at(idx1)->type()->expect<CompleteTensorType>();
  types[idx2] = node->inputs().at(idx2)->type()->expect<CompleteTensorType>();
}

// Is it ok to try to run the op?
// If an input is a constant, then we assume that the input is valid
// and we can try to run it.
// Otherwise:
// Integral typed _inputs_ are often an indicator that we're indexing into
// a tensor, so we should special-case these ops in the shape propagation.
// Additionally, passing in a zero representative tensor into an integer
// division op causes divide-by-zero errors.
// _Outputs_ must be tensors or primitives.
// We will call inferTypeFrom on the tensors, and ignore the primitives.
// However, we allow primitive returns because we want to support mixed
// primitive/tensor outputs.

// True if a representative value can be built for `v`: it is a constant,
// a complete non-integral tensor, or a float.
bool isValidArgumentForRunning(Value* v) {
  // allow constants
  if (toIValue(v))
    return true;
  if (CompleteTensorTypePtr tt = v->type()->cast<CompleteTensorType>()) {
    return !at::isIntegralType(tt->scalarType());
  }
  return v->type()->isSubtypeOf(FloatType::get());
}

// True if the output is a tensor or a number.
bool isValidReturnForRunning(Value* v) {
  return v->type()->isSubtypeOf(DynamicType::get()) ||
      v->type()->isSubtypeOf(NumberType::get());
}

// Ops that must never be executed on representative (zero-filled) inputs.
OperatorSet cannot_propagate_shape_by_running_it = {
  "aten::gesv(Tensor self, Tensor A) -> (Tensor, Tensor)",
  "aten::inverse(Tensor self) -> Tensor",
};

// True if `node` is not blacklisted above and all of its inputs and outputs
// pass the validity checks for running.
bool canPropagateShapeByRunningIt(Node* node) {
  if (cannot_propagate_shape_by_running_it.find(node)) {
    return false;
  }

  bool valid_args = std::all_of(
      node->inputs().begin(), node->inputs().end(), isValidArgumentForRunning);
  if (!valid_args)
    return false;

  bool valid_returns = std::all_of(
      node->outputs().begin(), node->outputs().end(), isValidReturnForRunning);
  if (!valid_returns)
    return false;

  return true;
}

// Last-resort propagation: actually execute the op on representative inputs
// and copy the resulting tensor types onto the node's outputs.
// Returns false, leaving types untouched, if the op cannot be run.
bool PropagateShapeOnNodeByRunningIt(Node* node) {
  if (!canPropagateShapeByRunningIt(node))
    return false;
  auto op = getOperation(node);
  Stack stack;

  for (auto input : node->inputs()) {
    stack.push_back(representativeValue(input));
  }

  // XXX: we're not catching any exceptions from the op for now. This
  // is to uncover any mistakes we could make when editing this code,
  // and eventually it shouldn't matter, because this phase should be
  // preceded by schema checking.
  op(stack);

  JIT_ASSERT(stack.size() == node->outputs().size());
  for (size_t i = 0; i < stack.size(); ++i) {
    // some ops may have mixed tensor/primitive outputs
    // for primitives, we don't need to change the type because it is already
    // its most constrained form.
    if (stack[i].isTensor())
      node->outputs()[i]->inferTypeFrom(stack[i].toTensor());
  }
  return true;
}

bool PropagateTensorShapeOnNode(Node * node, bool insert_expands);
bool PropagateCompleteShapeOnNode(
    Node * node, bool insert_expands, std::vector<CompleteTensorTypePtr> types);

// Shape propagation for aten::cat when the tensor list is a prim::ListConstruct.
//   1. If every input has a complete type, `dim` is a constant in range, and all
//      inputs have the same rank and agree on every dim except `dim`:
//      output = first input's type with sizes[dim] = sum of the inputs' sizes[dim].
//   2. Otherwise, output = the first input type that casts to TensorType.
//   3. Otherwise (including non-ListConstruct or empty lists): unshaped output.
void PropagateCatShape(Node * cat_node) {
  static const auto propagate_complete = [](Node * node, at::ArrayRef<Value*> tensors) -> bool {
    auto input_types = fmap(tensors, [](Value *v) { return v->type()->cast<CompleteTensorType>(); });
    if (!std::all_of(input_types.begin(), input_types.end(),
        [](const CompleteTensorTypePtr& tp) { return tp != nullptr; })) {
      return false;
    }
    if (!node->is_constant(attr::dim)) return false;
    std::vector<int64_t> sizes = input_types[0]->sizes();
    const int64_t dim = wrapDim(node->get<int64_t>(attr::dim).value(), sizes);
    const int64_t ndim = sizes.size();

    if (dim < 0 || dim >= ndim)
      return false;

    sizes[dim] = 0;
    for (auto & tp : input_types) {
      auto & tp_sizes = tp->sizes();
      if (sizes.size() != tp_sizes.size())
        return false;
      for (int64_t i = 0; i < ndim; ++i) {
        if (sizes[i] != tp_sizes[i] && i != dim) {
          return false;
        }
      }
      sizes[dim] += tp_sizes[dim];
    }
    node->output()->setType(input_types[0]->withSizes(sizes));
    return true;
  };
  static const auto propagate = [](Node * node, at::ArrayRef<Value*> tensors) -> bool {
    for (Value * v : tensors) {
      if (auto type = v->type()->cast<TensorType>()) {
        node->output()->setType(type);
        return true;
      }
    }
    return false;
  };
  auto list_node = cat_node->namedInput(attr::tensors)->node();
  if (list_node->kind() == prim::ListConstruct) {
    auto tensors = list_node->inputs();
    if (!tensors.empty()) {
      if (propagate_complete(cat_node, tensors)) {
        return;
      } else if (propagate(cat_node, tensors)) {
        return;
      }
    }
  }
  setUnshapedType(cat_node);
}

// Propagates type/shape information onto the outputs of `node`.
// Control-flow and structural prim:: nodes are handled directly. Everything
// else tries, in order:
//   aten::cat special case -> complete-shape formulas -> dimension-only
//   formulas -> running the op on representative inputs -> unshaped outputs.
// `insert_expands` controls whether aten::expand nodes may be inserted for
// broadcasting. It must be forwarded to nested blocks: while a loop body is
// still iterating towards a type fixpoint, expands baked from unstable sizes
// would be wrong.
void PropagateShapeOnNode(Node * node, bool insert_expands) {
  // These don't require the types, and have complicated schema. Return early after we process them.
  switch (node->kind()) {
    case prim::If: {
      auto then_block = node->blocks().at(0);
      auto else_block = node->blocks().at(1);
      // Forward insert_expands: an If nested in a loop body must not insert
      // expands during the loop's non-expanding fixpoint iterations.
      PropagateShapeOnBlock(then_block, insert_expands);
      PropagateShapeOnBlock(else_block, insert_expands);
      mergeTypes(then_block->outputs(), else_block->outputs(), node->outputs());
      return;
    }
    case prim::Loop: {
      auto body_block = node->blocks().at(0);
      // propagate counter type
      body_block->inputs().at(0)->setType(node->inputs().at(0)->type());
      // propagate loop-carried input types to block inputs
      auto loop_carried_inputs = node->inputs().slice(2); // skip max, cond
      auto loop_carried_block = body_block->inputs().slice(1); // skip trip
      for (size_t i = 0; i < loop_carried_inputs.size(); ++i) {
        loop_carried_block[i]->setType(loop_carried_inputs[i]->type());
      }
      auto loop_carried_outputs = body_block->outputs().slice(1); // skip cond

      do {
        PropagateShapeOnBlock(body_block, /*insert_expands=*/false);
        // note: inserting expands is unsafe at this point, we don't know
        // if the types are stable yet, so the arguments to expand may change
      } while (mergeTypes(loop_carried_block, loop_carried_outputs, loop_carried_block));

      // Now that this loop's types are stable, we can insert the expands,
      // unless an enclosing loop is itself still iterating (insert_expands == false).
      PropagateShapeOnBlock(body_block, insert_expands);

      for (size_t i = 0; i < loop_carried_inputs.size(); ++i) {
        node->outputs()[i]->setType(loop_carried_block[i]->type());
      }
      return;
    }
    case prim::ImplicitTensorToNum:
    case prim::TensorToNum:
      return; // correct num type is already set
    case prim::NumToTensor: {
      if (node->input()->type()->isSubtypeOf(IntType::get())) {
        node->output()->setType(TensorType::create(at::kLong, -1, 0));
      } else {
        JIT_ASSERT(node->input()->type()->isSubtypeOf(FloatType::get()));
        node->output()->setType(TensorType::create(at::kDouble, -1, 0));
      }
      return;
    }
    case prim::TupleConstruct: {
      // We refresh the tuple type, because the input types could have been refined.
      node->output()->setType(TupleType::create(fmap(node->inputs(), [](Value *v) { return v->type(); })));
      return;
    }
    case prim::TupleUnpack: {
      auto tuple_type = node->input()->type()->cast<TupleType>();
      JIT_ASSERT(tuple_type && tuple_type->elements().size() == node->outputs().size());
      auto elems = tuple_type->elements();
      for (size_t i = 0; i < node->outputs().size(); ++i) {
        node->output(i)->setType(elems[i]);
      }
      return;
    }
    case prim::Constant: {
      if (node->output()->type()->isSubtypeOf(DynamicType::get())) {
        node->output()->inferTypeFrom(node->t(attr::value));
      }
      return;
    }
    case prim::ConstantChunk: {
      Value *tensor = node->input();
      if (auto type = tensor->type()->cast<TensorType>()) {
        for (Value * output : node->outputs()) {
          output->setType(type);
        }
      } else {
        setUnshapedType(node);
      }
      return;
    }
    case prim::PythonOp:
    case prim::Print:
    case prim::RaiseException:
    case prim::Undefined: {
      setUnshapedType(node);
      return;
    }
    default:
      break; // fall-through
  }

  if (node->matches("aten::cat(Tensor[] tensors, int dim) -> Tensor")) {
    return PropagateCatShape(node);
  }

  if (auto maybe_complete_types = gatherTensorTypes<CompleteTensorType>(node)) {
    if (PropagateCompleteShapeOnNode(node, insert_expands, std::move(*maybe_complete_types))) {
      return;
    }
  }

  if (PropagateTensorShapeOnNode(node, insert_expands)) {
    return;
  }

  if (PropagateShapeOnNodeByRunningIt(node)) {
    return;
  }
  return setUnshapedType(node);
}

// Statically determines the length of list value `list`:
//   - a constant int list: its length,
//   - the output of a prim::ListConstruct: the number of its inputs,
//   - otherwise: c10::nullopt.
// Asserts that `list` has a ListType.
static c10::optional<size_t> determineListSize(Value* list) {
  JIT_ASSERT(list->type()->cast<ListType>());
  if (auto shape = constant_as<std::vector<int64_t>>(list)) {
    return shape->size();
  }
  auto input_node = list->node();
  if (input_node->kind() == prim::ListConstruct) {
    return input_node->inputs().size();
  }
  // (function continues: the fallback `return c10::nullopt;` and closing brace follow)
return c10::nullopt; } bool PropagateTensorShapeOnNode(Node * node, bool insert_expands) { static const auto broadcast = [](std::vector& tensor_types) -> TensorTypePtr { if (tensor_types.size() == 1) { return tensor_types[0]; } JIT_ASSERT(!tensor_types.empty()); auto any_type = tensor_types[0]; auto max_dims = any_type->dim(); for (auto & type : tensor_types) { max_dims = std::max(max_dims, type->dim()); } return TensorType::create(any_type->scalarType(), any_type->device(), max_dims); }; using type_vec_t = std::vector; // Formula is expected to return a vector of length equal to the number of tensor // outputs of the node, or an empty vector which implies that it failed to propagate. using formula_t = std::function; static std::mutex shape_formulas_mutex; static std::vector> shape_formulas; struct register_formula_for { register_formula_for(OperatorSet operators, formula_t formula) { std::unique_lock lock {shape_formulas_mutex}; shape_formulas.emplace_back(std::move(operators), std::move(formula)); } }; // Requirements: // dims : preserved // scalar type : preserved // device : preserved // tensor inputs : 1 // tensor outputs : 1 // Additionally: // - First input should be the only tensor input static const register_formula_for simple_unary_ops {{ "aten::abs(Tensor self) -> Tensor", "aten::acos(Tensor self) -> Tensor", "aten::neg(Tensor self) -> Tensor", "aten::t(Tensor self) -> Tensor", "aten::sigmoid(Tensor self) -> Tensor", "aten::tanh(Tensor self) -> Tensor", "aten::exp(Tensor self) -> Tensor", "aten::relu(Tensor self) -> Tensor", "aten::asin(Tensor self) -> Tensor", "aten::atan(Tensor self) -> Tensor", "aten::ceil(Tensor self) -> Tensor", "aten::clone(Tensor self) -> Tensor", "aten::contiguous(Tensor self) -> Tensor", "aten::bernoulli(Tensor self, *, Generator generator) -> Tensor", "aten::celu(Tensor self, Scalar alpha) -> Tensor", "aten::clamp(Tensor self, Scalar? min, Scalar? 
max) -> Tensor", "aten::clamp_max(Tensor self, Scalar max) -> Tensor", "aten::clamp_min(Tensor self, Scalar min) -> Tensor", "aten::alpha_dropout(Tensor input, float p, bool train) -> Tensor", "aten::bernoulli(Tensor self, float p, *, Generator generator) -> Tensor", "aten::cos(Tensor self) -> Tensor", "aten::cosh(Tensor self) -> Tensor", "aten::digamma(Tensor self) -> Tensor", "aten::dropout(Tensor input, float p, bool train) -> Tensor", "aten::elu(Tensor self, Scalar alpha, Scalar scale, Scalar input_scale) -> Tensor", "aten::erf(Tensor self) -> Tensor", "aten::erfc(Tensor self) -> Tensor", "aten::erfinv(Tensor self) -> Tensor", "aten::exp(Tensor self) -> Tensor", "aten::expm1(Tensor self) -> Tensor", "aten::log(Tensor self) -> Tensor", "aten::log10(Tensor self) -> Tensor", "aten::log1p(Tensor self) -> Tensor", "aten::log2(Tensor self) -> Tensor", "aten::log_sigmoid(Tensor self) -> Tensor", "aten::log_softmax(Tensor self, int dim) -> Tensor", "aten::floor(Tensor self) -> Tensor", "aten::frac(Tensor self) -> Tensor", "aten::flip(Tensor self, int[] dims) -> Tensor", "aten::feature_alpha_dropout(Tensor input, float p, bool train) -> Tensor", "aten::feature_dropout(Tensor input, float p, bool train) -> Tensor", "aten::hardshrink(Tensor self, Scalar lambd) -> Tensor", "aten::hardtanh(Tensor self, Scalar min_val, Scalar max_val) -> Tensor", "aten::glu(Tensor self, int dim) -> Tensor", "aten::inverse(Tensor self) -> Tensor", "aten::leaky_relu(Tensor self, Scalar negative_slope) -> Tensor", "aten::lgamma(Tensor self) -> Tensor", "aten::mvlgamma(Tensor self, int p) -> Tensor", "aten::normal(float mean, Tensor std, *, Generator generator) -> Tensor", "aten::normal(Tensor mean, float std, *, Generator generator) -> Tensor", "aten::permute(Tensor self, int[] dims) -> Tensor", "aten::pin_memory(Tensor self) -> Tensor", "aten::pinverse(Tensor self, float rcond) -> Tensor", "aten::reciprocal(Tensor self) -> Tensor", "aten::relu(Tensor self) -> Tensor", "aten::round(Tensor self) 
-> Tensor", "aten::rrelu(Tensor self, Scalar lower, Scalar upper, bool training, Generator generator) -> Tensor", "aten::rsqrt(Tensor self) -> Tensor", "aten::selu(Tensor self) -> Tensor", "aten::sigmoid(Tensor self) -> Tensor", "aten::sign(Tensor self) -> Tensor", "aten::sin(Tensor self) -> Tensor", "aten::sinh(Tensor self) -> Tensor", "aten::softmax(Tensor self, int dim) -> Tensor", "aten::softplus(Tensor self, Scalar beta, Scalar threshold) -> Tensor", "aten::softshrink(Tensor self, Scalar lambd) -> Tensor", "aten::sqrt(Tensor self) -> Tensor", "aten::tan(Tensor self) -> Tensor", "aten::tanh(Tensor self) -> Tensor", "aten::threshold(Tensor self, Scalar threshold, Scalar value) -> Tensor", "aten::transpose(Tensor self, int dim0, int dim1) -> Tensor", "aten::tril(Tensor self, int diagonal) -> Tensor", "aten::triu(Tensor self, int diagonal) -> Tensor", "aten::trunc(Tensor self) -> Tensor", "aten::rot90(Tensor self, int k, int[] dims) -> Tensor", "aten::narrow(Tensor self, int dim, int start, int length) -> Tensor", "aten::slice(Tensor self, int dim, int start, int end, int step) -> Tensor", "aten::alias(Tensor self) -> Tensor", "aten::detach(Tensor self) -> Tensor", "aten::cumprod(Tensor self, int dim) -> Tensor", "aten::cumsum(Tensor self, int dim) -> Tensor", "aten::empty_like(Tensor self) -> Tensor", "aten::full_like(Tensor self, Scalar fill_value) -> Tensor", "aten::ones_like(Tensor self) -> Tensor", "aten::rand_like(Tensor self) -> Tensor", "aten::randint_like(Tensor self, int high) -> Tensor", "aten::randint_like(Tensor self, int low, int high) -> Tensor", "aten::randn_like(Tensor self) -> Tensor", "aten::zeros_like(Tensor self) -> Tensor", }, [](Node * node) -> type_vec_t { auto input_type = node->input(0)->type()->cast(); return input_type ? 
type_vec_t{input_type} : type_vec_t{}; }}; // Requirements: // dims : broadcast all tensor args // scalar type : always matching and preserved // device : always matching and preserved // tensor inputs : * // tensor outputs : 1 static const register_formula_for broadcasting_ops {{ // Tensor-Tensor operators "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", "aten::mul(Tensor self, Tensor other) -> Tensor", "aten::div(Tensor self, Tensor other) -> Tensor", "aten::pow(Tensor self, Tensor exponent) -> Tensor", "aten::min(Tensor self, Tensor other) -> Tensor", "aten::max(Tensor self, Tensor other) -> Tensor", "aten::fmod(Tensor self, Tensor other) -> Tensor", "aten::remainder(Tensor self, Tensor other) -> Tensor", "aten::lerp(Tensor self, Tensor end, Scalar weight) -> Tensor", "aten::max(Tensor self, Tensor other) -> Tensor", "aten::min(Tensor self, Tensor other) -> Tensor", "aten::__and__(Tensor self, Tensor other) -> Tensor", "aten::__or__(Tensor self, Tensor other) -> Tensor", "aten::__xor__(Tensor self, Tensor other) -> Tensor", "aten::__lshift__(Tensor self, Tensor other) -> Tensor", "aten::__rshift__(Tensor self, Tensor other) -> Tensor", "aten::__iand__(Tensor self, Tensor other) -> Tensor", "aten::__ior__(Tensor self, Tensor other) -> Tensor", "aten::__ixor__(Tensor self, Tensor other) -> Tensor", "aten::__ilshift__(Tensor self, Tensor other) -> Tensor", "aten::__irshift__(Tensor self, Tensor other) -> Tensor", // Tensor-Scalar operators "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", "aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor", "aten::mul(Tensor self, Scalar other) -> Tensor", "aten::div(Tensor self, Scalar other) -> Tensor", "aten::pow(Tensor self, Scalar exponent) -> Tensor", "aten::fmod(Tensor self, Scalar other) -> Tensor", "aten::remainder(Tensor self, Scalar other) -> Tensor", "aten::pow(Scalar self, Tensor exponent) -> Tensor", 
"aten::__and__(Tensor self, Scalar other) -> Tensor", "aten::__or__(Tensor self, Scalar other) -> Tensor", "aten::__xor__(Tensor self, Scalar other) -> Tensor", "aten::__lshift__(Tensor self, Scalar other) -> Tensor", "aten::__rshift__(Tensor self, Scalar other) -> Tensor", "aten::__iand__(Tensor self, Scalar other) -> Tensor", "aten::__ior__(Tensor self, Scalar other) -> Tensor", "aten::__ixor__(Tensor self, Scalar other) -> Tensor", "aten::__ilshift__(Tensor self, Scalar other) -> Tensor", "aten::__irshift__(Tensor self, Scalar other) -> Tensor", // Ops with Tensor-Tensor overloads only "aten::atan2(Tensor self, Tensor other) -> Tensor", // Non-binary ops "aten::where(Tensor condition, Tensor self, Tensor other) -> Tensor", "aten::addcdiv(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor", "aten::addcmul(Tensor self, Tensor tensor1, Tensor tensor2, *, Scalar value) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto maybe_tensor_types = gatherTensorTypes(node)) { return {broadcast(*maybe_tensor_types)}; } return {}; }}; static const auto any_tensor_type = [](Node * node) -> TensorTypePtr { for (Value * input : node->inputs()) { if (auto type = input->type()->cast()) { return type; } } return nullptr; }; // Requirements: // dims : always matching and preserved // scalar type : always matching and preserved // device : always matching and preserved // tensor inputs : 2 // tensor outputs : 1 static const register_formula_for binary_ops_strict_match {{ "aten::normal(Tensor mean, Tensor std, *, Generator generator) -> Tensor", "aten::mm(Tensor self, Tensor mat2) -> Tensor", "aten::bmm(Tensor self, Tensor mat2) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto type = any_tensor_type(node)) { return {type}; } return {}; }}; // Requirements: // dims : all tensor args are broadcast // scalar type : byte/uint8 // device : always matching and preserved // tensor inputs : * // tensor outputs : 1 static const register_formula_for comparison_ops 
{{ "aten::lt(Tensor self, Tensor other) -> Tensor", "aten::le(Tensor self, Tensor other) -> Tensor", "aten::gt(Tensor self, Tensor other) -> Tensor", "aten::ge(Tensor self, Tensor other) -> Tensor", "aten::eq(Tensor self, Tensor other) -> Tensor", "aten::ne(Tensor self, Tensor other) -> Tensor", "aten::lt(Tensor self, Scalar other) -> Tensor", "aten::le(Tensor self, Scalar other) -> Tensor", "aten::gt(Tensor self, Scalar other) -> Tensor", "aten::ge(Tensor self, Scalar other) -> Tensor", "aten::eq(Tensor self, Scalar other) -> Tensor", "aten::ne(Tensor self, Scalar other) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto maybe_tensor_types = gatherTensorTypes(node)) { return {broadcast(*maybe_tensor_types)->toScalarType(at::kByte)}; } return {}; }}; // Requirements: // dims : preserved from the first argument // scalar type : preserved from the first argument (doesn't have to match other arguments) // device : always matching and preserved // tensor inputs : * // tensor outputs : 1 // NB: those ops (with slight adjustments) are good candidates for restarts. // Knowing the type and device of weights or biases is usually enough to // infer the output type. 
static const register_formula_for nn_ops_first_input_preserving {{ "aten::batch_norm(Tensor input, Tensor weight, Tensor bias, Tensor running_mean, Tensor running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "aten::conv1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor", "aten::conv2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor", "aten::conv3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, int groups) -> Tensor", "aten::conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad) -> Tensor", "aten::conv_transpose1d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor", "aten::conv_transpose2d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor", "aten::conv_transpose3d(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] output_padding, int groups, int[] dilation) -> Tensor", "aten::convolution(Tensor input, Tensor weight, Tensor bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor", "aten::adaptive_avg_pool1d(Tensor self, int[] output_size) -> Tensor", "aten::adaptive_avg_pool2d(Tensor self, int[] output_size) -> Tensor", "aten::adaptive_avg_pool3d(Tensor self, int[] output_size) -> Tensor", "aten::avg_pool1d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor", "aten::avg_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor", "aten::avg_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, bool ceil_mode, bool count_include_pad) -> Tensor", "aten::max_pool1d(Tensor self, int[] kernel_size, int[] 
stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor", "aten::max_pool2d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor", "aten::max_pool3d(Tensor self, int[] kernel_size, int[] stride, int[] padding, int[] dilation, bool ceil_mode) -> Tensor", "aten::max_unpool2d(Tensor self, Tensor indices, int[] output_size) -> Tensor", "aten::max_unpool3d(Tensor self, Tensor indices, int[] output_size, int[] stride, int[] padding) -> Tensor", "aten::reflection_pad1d(Tensor self, int[] padding) -> Tensor", "aten::reflection_pad2d(Tensor self, int[] padding) -> Tensor", "aten::replication_pad1d(Tensor self, int[] padding) -> Tensor", "aten::replication_pad2d(Tensor self, int[] padding) -> Tensor", "aten::replication_pad3d(Tensor self, int[] padding) -> Tensor", "aten::upsample_bilinear2d(Tensor self, int[] output_size, bool align_corners) -> Tensor", "aten::upsample_linear1d(Tensor self, int[] output_size, bool align_corners) -> Tensor", "aten::upsample_nearest1d(Tensor self, int[] output_size) -> Tensor", "aten::upsample_nearest2d(Tensor self, int[] output_size) -> Tensor", "aten::upsample_nearest3d(Tensor self, int[] output_size) -> Tensor", "aten::upsample_trilinear3d(Tensor self, int[] output_size, bool align_corners) -> Tensor", "aten::prelu(Tensor self, Tensor weight) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto type = node->input(0)->type()->cast()) { return {type}; } return {}; }}; // Requirements: // dims : 0 // scalar type : preserved // device : preserved // tensor inputs : 1 // tensor outputs : 1 // Additionally: // - First input should be the only tensor input static const register_formula_for all_reduce_ops {{ "aten::argmax(Tensor self) -> Tensor", "aten::argmin(Tensor self) -> Tensor", "aten::det(Tensor self) -> Tensor", "aten::logdet(Tensor self) -> Tensor", "aten::max(Tensor self) -> Tensor", "aten::min(Tensor self) -> Tensor", "aten::mean(Tensor self) -> Tensor", "aten::median(Tensor 
self) -> Tensor", "aten::norm(Tensor self, Scalar p) -> Tensor", "aten::std(Tensor self, bool unbiased) -> Tensor", "aten::sum(Tensor self) -> Tensor", "aten::trace(Tensor self) -> Tensor", "aten::var(Tensor self, bool unbiased) -> Tensor", "aten::all(Tensor self) -> Tensor", "aten::any(Tensor self) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto type = node->input(0)->type()->cast()) { return {type->withDim(0)}; } return {}; }}; // Requirements: // dims : 0 // scalar type : preserved if floating point, otherwise long/int64 // device : preserved // tensor inputs : 1 // tensor outputs : 1 // Additionally: // - First input should be the only tensor input static const register_formula_for all_reduce_ops_with_integer_upcast {{ "aten::sum(Tensor self) -> Tensor", "aten::prod(Tensor self) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto type = node->input(0)->type()->cast()) { return {at::isFloatingType(type->scalarType()) ? type->withDim(0) : type->withDim(0)->toScalarType(at::kLong)}; } return {}; }}; static const auto multidim_reduce_with_postprocess = [](Node * node, size_t num_reduced_dim, bool upcast_integer) -> type_vec_t { auto maybe_keepdim = node->get(attr::keepdim); if (!maybe_keepdim) return {}; if (auto type = node->input(0)->type()->cast()) { if (upcast_integer && !at::isFloatingType(type->scalarType())) { type = type->toScalarType(at::kLong); } if (*maybe_keepdim) { return {type}; } else if (type->dim() > num_reduced_dim) { return {type->withDim(type->dim() - num_reduced_dim)}; } } return {}; }; // Requirements: // dims : preserved if keepdim == false, 1 smaller otherwise // scalar type : preserved for first output, byte/uint8 for second output if exists // device : preserved // tensor inputs : 1 // tensor outputs : 1 or 2 // Additionally: // - First input should be the only tensor input // - Has a bool keepdim argument static const register_formula_for dim_reduce_ops {{ "aten::argmax(Tensor self, int dim, bool keepdim) -> Tensor", 
"aten::argmin(Tensor self, int dim, bool keepdim) -> Tensor", "aten::max_values(Tensor self, int dim, bool keepdim) -> Tensor", "aten::min_values(Tensor self, int dim, bool keepdim) -> Tensor", "aten::mean(Tensor self, int dim, bool keepdim) -> Tensor", "aten::norm(Tensor self, Scalar p, int dim, bool keepdim) -> Tensor", "aten::std(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor", "aten::var(Tensor self, int dim, bool unbiased, bool keepdim) -> Tensor", "aten::logsumexp(Tensor self, int dim, bool keepdim) -> Tensor", "aten::all(Tensor self, int dim, bool keepdim) -> Tensor", "aten::any(Tensor self, int dim, bool keepdim) -> Tensor", // Ops returning indices as second output "aten::kthvalue(Tensor self, int k, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::max(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::min(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::median(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", "aten::mode(Tensor self, int dim, bool keepdim) -> (Tensor, Tensor)", }, [](Node * node) -> type_vec_t { // NB: Note that while this function is generally meant to be used with ops that // have a single output, we will fix up its return right below. 
auto output_types = multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/1, /*integer_upcast=*/false); if (!output_types.empty() && node->outputs().size() == 2) { output_types.push_back(output_types.back()->toScalarType(at::kLong)); } return output_types; }}; // Requirements: // dims : preserved if keepdim == false, 1 smaller otherwise // scalar type : preserved if floating point, otherwise long/int64 // device : preserved // tensor inputs : 1 // tensor outputs : 1 // Additionally: // - First input should be the only tensor input // - has a bool keepdim argument static const register_formula_for dim_reduce_ops_with_integer_upcast {{ "aten::prod(Tensor self, int dim, bool keepdim) -> Tensor", }, [](Node * node) -> type_vec_t { return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/1, /*integer_upcast=*/true); }}; // Requirements: // dims : preserved if keepdim == false, 1 smaller otherwise // scalar type : preserved if floating point, otherwise long/int64 // device : preserved // tensor inputs : 1 // tensor outputs : 1 // Additionally: // - has bool keepdim and int[] dim arguments static const register_formula_for multidim_reduce_ops_with_integer_upcast {{ "aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto dim = node->get>(attr::dim)) { // TODO: can dim contain duplicates? return multidim_reduce_with_postprocess(node, /*num_reduce_dim=*/dim->size(), /*integer_upcast=*/true); } return {}; }}; static const auto get_device_int = [](c10::optional dev) -> c10::optional { if (!dev) return {}; if (dev->is_cpu()) { return {-1}; } return dev->has_index() ? 
c10::optional{dev->index()} : c10::nullopt; }; static const auto factory_with_ndim = [](Node * node, int dim) -> type_vec_t{ auto maybe_layout = node->get(attr::layout); if (!maybe_layout || maybe_layout != at::kStrided) return {}; auto maybe_device = get_device_int(node->get(attr::device)); if (!maybe_device) return {}; auto maybe_scalar_type = node->get(attr::dtype); if (!maybe_scalar_type) return {}; return {TensorType::create(*maybe_scalar_type, *maybe_device, dim)}; }; // Requirements: // dims : preserved // scalar type : equal to value of dtype // device : equal to value of device // tensor inputs : 1 // tensor outputs : 1 // Additionally: // - has ScalarType dtype, Layeout layout and Device device arguments static const register_formula_for like_factories_with_options {{ "aten::empty_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor", "aten::full_like(Tensor self, Scalar fill_value, *, int dtype, int layout, int[] device) -> Tensor", "aten::ones_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor", "aten::rand_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor", "aten::randint_like(Tensor self, int high, *, int dtype, int layout, int[] device) -> Tensor", "aten::randint_like(Tensor self, int low, int high, *, int dtype, int layout, int[] device) -> Tensor", "aten::randn_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor", "aten::zeros_like(Tensor self, *, int dtype, int layout, int[] device) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto type = node->namedInput(attr::self)->type()->cast()) { return factory_with_ndim(node, type->dim()); } return {}; }}; // Requirements: // dims : equal to number of elements in size // scalar type : equal to value of dtype // device : equal to value of device // tensor inputs : 1 // tensor outputs : 1 // Additionally: // - has int[] size, ScalarType dtype, Layeout layout and Device device arguments static const register_formula_for 
size_factories_with_options {{ "aten::empty(int[] size, *, int dtype, int layout, int[] device) -> Tensor", "aten::full(int[] size, Scalar fill_value, *, int dtype, int layout, int[] device) -> Tensor", "aten::ones(int[] size, *, int dtype, int layout, int[] device) -> Tensor", "aten::rand(int[] size, *, int dtype, int layout, int[] device) -> Tensor", "aten::randn(int[] size, *, int dtype, int layout, int[] device) -> Tensor", "aten::zeros(int[] size, *, int dtype, int layout, int[] device) -> Tensor", "aten::randint(int high, int[] size, *, int dtype, int layout, int[] device) -> Tensor", "aten::randint(int low, int high, int[] size, *, int dtype, int layout, int[] device) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto maybe_size = node->get>(attr::size)) { return factory_with_ndim(node, maybe_size->size()); } return {}; }}; static const auto get_cast_scalar_type = [](Node *node) -> at::ScalarType { switch (node->kind()) { case aten::_cast_Byte: return at::kByte; case aten::_cast_Char: return at::kChar; case aten::_cast_Double: return at::kDouble; case aten::_cast_Float: return at::kFloat; case aten::_cast_Half: return at::kHalf; case aten::_cast_Int: return at::kInt; case aten::_cast_Long: return at::kLong; case aten::_cast_Short: return at::kShort; default: AT_ASSERTM(false, "unknown node kind in get_cast_scalar_type: ", node->kind().toQualString()); } }; static const register_formula_for cast_ops {{ "aten::_cast_Byte(Tensor self, bool non_blocking) -> Tensor", "aten::_cast_Char(Tensor self, bool non_blocking) -> Tensor", "aten::_cast_Double(Tensor self, bool non_blocking) -> Tensor", "aten::_cast_Float(Tensor self, bool non_blocking) -> Tensor", "aten::_cast_Half(Tensor self, bool non_blocking) -> Tensor", "aten::_cast_Int(Tensor self, bool non_blocking) -> Tensor", "aten::_cast_Long(Tensor self, bool non_blocking) -> Tensor", "aten::_cast_Short(Tensor self, bool non_blocking) -> Tensor", }, [](Node * node) -> type_vec_t { if (auto type = 
node->namedInput(attr::self)->type()->cast()) { return {type->toScalarType(get_cast_scalar_type(node))}; } return {}; }}; // First, try to match one of the registered formulas to their operator sets. for (auto & entry : shape_formulas) { if (entry.first.find(node)) { auto types = entry.second(node); if (types.empty()) { return false; } else { auto outputs = node->outputs(); JIT_ASSERT(types.size() == outputs.size()); for (size_t i = 0; i < types.size(); ++i) { JIT_ASSERT(outputs[i]->type()->isSubtypeOf(DynamicType::get())); outputs[i]->setType(types[i]); } return true; } } } // This section implements shape prop for an assorted set of nodes that only // need partial information about their input types. const auto input_type = [node](size_t index) { return node->input(index)->type()->cast(); }; if (node->matches("aten::masked_select(Tensor self, Tensor mask) -> Tensor")) { auto type = input_type(0); auto mask_type = input_type(1); if (type && mask_type) { if (type->dim() == 0 && mask_type->dim() == 0) { node->output()->setType(type->withDim(0)); } else { node->output()->setType(type->withDim(1)); } return true; } if (auto type = input_type(0)) { node->output()->setType(type->withDim(1)); return true; } } else if (node->matches("aten::dot(Tensor self, Tensor tensor) -> Tensor")) { if (auto type = any_tensor_type(node)) { node->output()->setType(type->withDim(0)); return true; } } else if (node->matches("aten::mv(Tensor self, Tensor vec) -> Tensor") || node->matches("aten::addmv(Tensor self, Tensor mat, Tensor vec, *, Scalar beta, Scalar alpha) -> Tensor")) { if (auto type = any_tensor_type(node)) { node->output()->setType(type->withDim(1)); return true; } } else if (node->matches("aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta, Scalar alpha) -> Tensor") || node->matches("aten::addbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor") || node->matches("aten::addr(Tensor self, Tensor vec1, Tensor vec2, *, Scalar 
beta, Scalar alpha) -> Tensor")) { if (auto type = any_tensor_type(node)) { node->output()->setType(type->withDim(2)); return true; } } else if (node->matches("aten::baddbmm(Tensor self, Tensor batch1, Tensor batch2, *, Scalar beta, Scalar alpha) -> Tensor")) { if (auto type = any_tensor_type(node)) { node->output()->setType(type->withDim(3)); return true; } } else if (node->matches("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor")) { auto type = input_type(0); auto index_type = input_type(1); // index_select behaves very weirdly when self.dim() == 0. It allows both 0D and 1D // indices, and returns a value that has as many dimensions as index. if (type && index_type) { if (type->dim() == 0) { node->output()->setType(type->withDim(index_type->dim())); } else { node->output()->setType(type); } return true; } } else if (node->matches("aten::gather(Tensor self, int dim, Tensor index) -> Tensor")) { auto type = input_type(0); auto index_type = input_type(1); // Gather has this annoying edge case where index always needs to match the number of // dims of self, **except** when self is 1D and index is 0D in which case we return // a 0D output. 
if (type && index_type) { if (index_type->dim() == 0) { node->output()->setType(type->withDim(0)); } else { node->output()->setType(type); } return true; } } else if (node->matches("aten::embedding(Tensor weight, Tensor indices, int padding_idx, bool scale_grad_by_freq, bool sparse) -> Tensor")) { auto weight_type = input_type(0); auto indices_type = input_type(1); if (weight_type && indices_type) { node->output()->setType(weight_type->withDim(indices_type->dim() + 1)); return true; } } else if (node->matches("aten::bilinear(Tensor input1, Tensor input2, Tensor weight, Tensor bias) -> Tensor")) { if (auto type = input_type(0)) { node->output()->setType(type); return true; } if (auto type = input_type(1)) { node->output()->setType(type); return true; } } else if (node->matches("aten::dist(Tensor self, Tensor other, Scalar p) -> Tensor")) { if (auto type = any_tensor_type(node)) { node->output()->setType(type->withDim(0)); return true; } } // The code below implements formulas that need type information for all their // tensor inputs, and have exactly one output. 
std::vector tensor_types; static const auto reshape_prop = [](Node * node, Symbol shape_input, const std::vector& tensor_types) -> TensorTypePtr { if (auto list_size = determineListSize(node->namedInput(shape_input))) { return tensor_types.at(0)->withDim(*list_size); } return nullptr; }; const auto getSingleOutputType = [&]() -> TypePtr { if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { return tensor_types.at(0)->toScalarType(tensor_types.at(1)->scalarType()); } else if (node->matches("aten::view_as(Tensor self, Tensor other) -> Tensor") || node->matches("aten::expand_as(Tensor self, Tensor other) -> Tensor") || node->matches("aten::reshape_as(Tensor self, Tensor other) -> Tensor")) { return tensor_types.at(0)->withDim(tensor_types.at(1)->dim()); } else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor") || node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor") || node->matches("aten::as_strided(Tensor self, int[] size, int[] stride) -> Tensor") || node->matches("aten::as_strided(Tensor self, int[] size, int[] stride, int storage_offset) -> Tensor")) { return reshape_prop(node, attr::size, tensor_types); } else if (node->matches("aten::reshape(Tensor self, int[] shape) -> Tensor")) { return reshape_prop(node, attr::shape, tensor_types); } else if (node->matches("aten::repeat(Tensor self, int[] repeats) -> Tensor")) { return reshape_prop(node, attr::repeats, tensor_types); } else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor")) { auto & t = tensor_types.at(0); return t->withDim(t->dim() + 1); } else if (node->matches("aten::select(Tensor self, int dim, int index) -> Tensor") || node->matches("aten::diagonal(Tensor self, int offset, int dim1, int dim2) -> Tensor")) { auto & t = tensor_types.at(0); return t->dim() > 0 ? 
t->withDim(t->dim() - 1) : nullptr; } else if (node->matches("aten::matmul(Tensor self, Tensor other) -> Tensor")) { int dim1 = tensor_types.at(0)->dim(); int dim2 = tensor_types.at(1)->dim(); if (dim1 == 1 && dim2 == 1) { // Dot product return tensor_types.at(0)->withDim(0); } else if (dim1 == 2 && dim2 == 2) { // Matrix multiply return tensor_types.at(0); } else if (dim1 == 1 && dim2 == 2) { // Unsqueeze + matrix multiply + squeeze return tensor_types.at(0); } else if (dim1 == 2 && dim2 == 1) { // Matrix vector multiply return tensor_types.at(1); } else { // Batched matrix multiply (possibly with squeeze + unsqueeze if one argument is 1D) auto type = broadcast(tensor_types); if (tensor_types.at(0)->dim() == 1 || tensor_types.at(1)->dim() == 1) { type = type->withDim(type->dim() - 1); } return type; } } else if (node->matches("aten::nonzero(Tensor self) -> Tensor")) { return tensor_types.at(0)->toScalarType(at::kLong); } else if (node->matches("aten::take(Tensor self, Tensor index) -> Tensor")) { return tensor_types.at(1)->toScalarType(tensor_types.at(0)->scalarType()); } else if (node->matches("aten::diagflat(Tensor self, int offset) -> Tensor")) { return tensor_types.at(0)->withDim(2); } else if (node->matches("aten::diag(Tensor self, int diagonal) -> Tensor")) { auto & t = tensor_types.at(0); if (t->dim() == 1) { return t->withDim(2); } else if (t->dim() == 2) { return t->withDim(1); } else { return nullptr; } } else if (node->matches("aten::unfold(Tensor self, int dimension, int size, int step) -> Tensor")) { auto & t = tensor_types.at(0); return t->dim() == 0 ? 
t : t->withDim(t->dim() + 1); } else if (node->matches("aten::polygamma(int n, Tensor self) -> Tensor")) { return tensor_types.at(0); } return nullptr; }; if (auto maybe_tensor_types = gatherTensorTypes(node)) { tensor_types = std::move(*maybe_tensor_types); } else { return false; } if (node->outputs().size() == 1) { if (auto type = getSingleOutputType()) { node->output()->setType(type); return true; } } return false; } bool PropagateCompleteShapeOnNode(Node * node, bool insert_expands, std::vector tensor_types) { // For expensive ops we can directly encode their shape propagation // here, otherwise we fallback to running a fake version of the op // to get a quick and dirty propagation. if (node->matches("aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") || node->matches("aten::sub(Tensor self, Tensor other, *, Scalar alpha) -> Tensor") || node->matches("aten::mul(Tensor self, Tensor other) -> Tensor")) { // These nodes and "div" handle tensors of different shapes internally, // so there's no need to insert explicit expand nodes. Note that "div" is // handled by the fallthrough because it's not always safe to run it due // to integer divide-by-zero. 
return PropagateShapeOnNodeByRunningIt(node); } else if (node->matches("aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor") || node->matches("aten::sub(Tensor self, Scalar other, Scalar alpha) -> Tensor") || node->matches("aten::mul(Tensor self, Scalar other) -> Tensor") || node->matches("aten::pow(Tensor self, Scalar exponent) -> Tensor")) { node->output()->setType(tensor_types.at(0)); return true; } else if (insert_expands && ( node->matches("aten::pow(Tensor self, Tensor exponent) -> Tensor") || node->matches("aten::min(Tensor self, Tensor other) -> Tensor") || node->matches("aten::max(Tensor self, Tensor other) -> Tensor") || node->matches("aten::lt(Tensor self, Tensor other) -> Tensor") || node->matches("aten::le(Tensor self, Tensor other) -> Tensor") || node->matches("aten::gt(Tensor self, Tensor other) -> Tensor") || node->matches("aten::ge(Tensor self, Tensor other) -> Tensor") || node->matches("aten::eq(Tensor self, Tensor other) -> Tensor") || node->matches("aten::ne(Tensor self, Tensor other) -> Tensor"))) { // Binary broadcasting ops // NB: we don't handle the nodes in any other way (note the lack of return!), // because the type casting logic in scalar cases is non-trivial. // It's better to just run them. 
broadcastBinary(node, tensor_types, 0, 1); return PropagateShapeOnNodeByRunningIt(node); } else if (node->matches("aten::neg(Tensor self) -> Tensor") || node->matches("aten::sigmoid(Tensor self) -> Tensor") || node->matches("aten::tanh(Tensor self) -> Tensor")) { node->output()->setType(tensor_types.at(0)->contiguous()); return true; } else if (node->matches("aten::mm(Tensor self, Tensor mat2) -> Tensor")) { auto lhs_type = tensor_types.at(0); auto rhs_type = tensor_types.at(1); SHAPE_ASSERT(lhs_type->sizes().size() == 2 && rhs_type->sizes().size() == 2); node->output()->setType(CompleteTensorType::create( lhs_type->scalarType(), lhs_type->device(), at::IntList{lhs_type->sizes().at(0), rhs_type->sizes().at(1)})); return true; } else if (node->matches("aten::t(Tensor self) -> Tensor")) { auto tp = tensor_types.at(0); auto sizes = tp->sizes(); auto strides = tp->strides(); SHAPE_ASSERT(sizes.size() == 2); std::swap(sizes.at(0), sizes.at(1)); std::swap(strides.at(0), strides.at(1)); node->output()->setType(tp->withSizesStrides(sizes, strides)); return true; } else if (node->matches("aten::narrow(Tensor self, int dim, int start, int length) -> Tensor", /*const_inputs=*/{attr::dim, attr::length})) { auto tp = tensor_types.at(0); auto sizes = tp->sizes(); int64_t dim = node->get(attr::dim).value(); int64_t length = node->get(attr::length).value(); SHAPE_ASSERT(dim >= 0 && static_cast(dim) < sizes.size()); sizes.at(dim) = length; node->output()->setType(tp->withSizesStrides(sizes, tp->strides())); return true; } else if (node->matches("aten::sum(Tensor self) -> Tensor")) { node->output()->setType(tensor_types.at(0)->withSizes({})); return true; } else if (node->matches("aten::sum(Tensor self, int[] dim, bool keepdim) -> Tensor", /*const_inputs=*/{attr::dim, attr::keepdim})) { auto & tp = tensor_types.at(0); auto sizes = tp->sizes(); auto dims = node->get>(attr::dim).value(); bool keepdim = node->get(attr::keepdim).value(); std::reverse(dims.begin(), dims.end()); for 
(int64_t dim : dims) { SHAPE_ASSERT(dim >= 0 && static_cast(dim) < sizes.size()); if (keepdim) { sizes.at(dim) = 1; } else { sizes.erase(sizes.begin() + dim); } } node->output()->setType(tp->withSizes(sizes)); return true; } else if (node->matches("aten::squeeze(Tensor self, int dim) -> Tensor", /*const_inputs=*/attr::dim)) { auto & tp = tensor_types.at(0); auto sizes = tp->sizes(); auto strides = tp->strides(); int64_t dim = wrapDim(node->get(attr::dim).value(), sizes); SHAPE_ASSERT(dim >= 0 && static_cast(dim) < sizes.size()); if (sizes.at(dim) == 1) { sizes.erase(sizes.begin() + dim); strides.erase(strides.begin() + dim); } node->output()->setType(tp->withSizesStrides(sizes, strides)); return true; } else if (node->matches("aten::unsqueeze(Tensor self, int dim) -> Tensor", /*const_inputs=*/attr::dim)) { auto & tp = tensor_types.at(0); auto sizes = tp->sizes(); auto strides = tp->strides(); int64_t dim = wrapDim(node->get(attr::dim).value(), sizes); SHAPE_ASSERT(dim >= 0 && static_cast(dim) <= sizes.size()); int64_t new_stride = dim >= static_cast(sizes.size()) ? 
1 : sizes.at(dim) * strides.at(dim); sizes.insert(sizes.begin() + dim, 1); strides.insert(strides.begin() + dim, new_stride); node->output()->setType(tp->withSizesStrides(sizes, strides)); return true; } else if (node->matches("aten::view(Tensor self, int[] size) -> Tensor", /*const_inputs=*/attr::size)) { auto sizes = node->get>(attr::size).value(); bool inferred = false; size_t inferred_idx; int64_t size_product = 1; for (size_t i = 0; i < sizes.size(); ++i) { if (sizes[i] == -1) { if (inferred) throw propagation_error(); inferred = true; inferred_idx = i; } else { size_product *= sizes[i]; } } if (inferred) { SHAPE_ASSERT(size_product != 0); size_t numel = 1; for (int64_t s : tensor_types.at(0)->sizes()) numel *= s; int64_t inferred_size = numel / size_product; sizes[inferred_idx] = inferred_size; } node->output()->setType(tensor_types.at(0)->withSizes(sizes)); return true; } else if (node->matches("aten::type_as(Tensor self, Tensor other) -> Tensor")) { if (tensor_types.at(0)->scalarType() == tensor_types.at(1)->scalarType()) { node->output()->setType(node->namedInput(attr::self)->type()); } else { // This will be a copy, so the result will be contiguous node->output()->setType(tensor_types.at(1)->withSizes(tensor_types.at(0)->sizes())); } return true; } else if (node->matches("aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor", /*const_inputs=*/attr::size)) { auto tp = tensor_types.at(0); std::vector sizes, strides; std::tie(sizes, strides) = at::inferExpandGeometry( tp->sizes(), tp->strides(), node->get>(attr::size).value()); node->output()->setType(tp->withSizesStrides(sizes, strides)); return true; } else if (node->matches("aten::index_select(Tensor self, int dim, Tensor index) -> Tensor", /*const_inputs=*/attr::dim)) { auto ten = tensor_types.at(0); auto index = tensor_types.at(1); int64_t dim = node->get(attr::dim).value(); SHAPE_ASSERT(index->sizes().size() == 1); SHAPE_ASSERT(dim >= 0 && static_cast(dim) < ten->sizes().size()); 
std::vector sizes = ten->sizes(); sizes[dim] = index->sizes()[0]; node->output()->setType(ten->withSizes(sizes)); return true; } else if (node->matches("aten::chunk(Tensor self, int chunks, int dim) -> Tensor[]", /*const_inputs=*/{attr::chunks, attr::dim})) { auto input_type = tensor_types.at(0); auto sizes = input_type->sizes(); const auto & strides = input_type->strides(); int64_t dim = node->get(attr::dim).value(); int64_t chunks = node->get(attr::chunks).value(); sizes[dim] /= chunks; for (Value * output : node->outputs()) { output->setType(input_type->withSizesStrides(sizes, strides)); } if (input_type->sizes().at(dim) % chunks != 0) { sizes[dim] = input_type->sizes().at(dim) % chunks; node->outputs().back()->setType(input_type->withSizesStrides(sizes, strides)); } return true; } else if (node->kind() == onnx::Shape) { SHAPE_ASSERT(node->inputs().size() == 1 && node->outputs().size() == 1); std::vector dim_vec = {(int64_t)tensor_types.at(0)->sizes().size()}; at::IntList dims(dim_vec); node->output()->setType( CompleteTensorType::create(at::kLong, -1, dims)); return true; } else if (node->kind() == onnx::Reshape) { setUnshapedType(node); return true; } setUnshapedType(node); return false; } void PropagateShapeOnBlock(Block * block, bool insert_expands) { for (Node * node : block->nodes()) { try { PropagateShapeOnNode(node, insert_expands); } catch(propagation_error& e) { setUnshapedType(node); } catch(std::exception & e) { if(auto sl = node->getSourceLocation()) { sl->wrapAndRethrowException(e, "operation failed shape propagation"); } else { throw; } } } } } // anonymous namespace void PropagateInputShapes(Graph & graph) { PropagateShapeOnBlock(graph.block()); } namespace { void EraseShapeInformation(at::ArrayRef vals) { for (Value * v : vals) { v->setType(unshapedType(v->type())); } } void EraseShapeInformation(Block * b) { EraseShapeInformation(b->inputs()); EraseShapeInformation(b->outputs()); for (Node * n : b->nodes()) { 
EraseShapeInformation(n->outputs()); for (Block *sb : n->blocks()) { EraseShapeInformation(sb); } } } } // anonymous namespace void EraseShapeInformation(Graph & graph) { EraseShapeInformation(graph.block()); } }}