[Static Runtime] Remove wrappers for aten::cat (#62067)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62067

The wrapper for aten::cat is no longer needed after the variadic cat change in D29565344 (ae58a4c45d) .
Also added a simple test to test dynamic shapes, i.e., input tensors in args2 are larger than in args1.

Reviewed By: navahgar, mikeiovine

Differential Revision: D29864600

fbshipit-source-id: 44a712c2e776815c09e0bf5631412149b81274b2
This commit is contained in:
Hao Lu 2021-07-23 20:31:58 -07:00 committed by Facebook GitHub Bot
parent 7c09de8384
commit 78f7d8ccfa
2 changed files with 2 additions and 17 deletions

View File

@ -752,6 +752,8 @@ TEST(StaticRuntime, IndividualOps_VarCat) {
// 3D tensors - cat dim = 2
std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), 2};
testStaticRuntime(var_cat_script, args3);
testStaticRuntime(var_cat_script, args1, args2);
}
TEST(StaticRuntime, LongModel) {

View File

@ -407,23 +407,6 @@ REGISTER_OPERATOR_FUNCTOR(aten::nan_to_num, aten_nan_to_num, [](Node* n) -> SROp
at::native::nan_to_num_out(in0_t, in1_d, in2_d, in3_d, out_t);
};
});
REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
if (!n->matches(
torch::schema("aten::cat(Tensor[] tensors, int dim=0) -> Tensor"))) {
LogAndDumpSchema(n);
return nullptr;
}
return [](ProcessedNode* p_node) {
const auto in0_tl = p_node->Input(0).toTensorVector();
const auto in1_i = p_node->Input(1).toInt();
if (p_node->Output(0).isNone()) {
p_node->Output(0) = create_empty_from(in0_tl[0]);
}
auto& out_t = p_node->Output(0).toTensor();
fastResizeToZero(out_t);
at::native::_cat_out_cpu(in0_tl, in1_i, out_t);
};
});
// Split out into a function to appease MSVC's pre-processor
SROperator aten_stack(Node* n) {