[Static Runtime] Remove wrappers for aten::cat (#62067)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/62067
The wrapper for aten::cat is no longer needed after the variadic cat change in D29565344 (ae58a4c45d).
Also added a simple test for dynamic shapes, i.e., the input tensors in args2 are larger than those in args1 (see the sketch after the commit metadata below).
Reviewed By: navahgar, mikeiovine
Differential Revision: D29864600
fbshipit-source-id: 44a712c2e776815c09e0bf5631412149b81274b2
This commit is contained in:
parent 7c09de8384
commit 78f7d8ccfa
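For context on the dynamic-shapes test mentioned above: testStaticRuntime runs the given TorchScript through the static runtime and checks its outputs against the JIT interpreter, once per argument set. The sketch below is illustrative only; var_cat_script, args1, and args2 are defined earlier in the existing VarCat test (not shown in the hunk that follows), and the exact script body and shapes used here are assumptions.

  // Illustrative sketch; the real script and argument sets live in the
  // existing IndividualOps_VarCat test and may differ in detail.
  const auto var_cat_script = R"JIT(
    def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
        return torch.cat([inp1, inp2], dim)
  )JIT";

  // First run: the runtime allocates its output buffer for these shapes.
  std::vector<IValue> args1 = {at::randn({2, 3}), at::randn({2, 4}), 1};

  // Second run: larger inputs, so the preallocated output must be resized.
  std::vector<IValue> args2 = {at::randn({4, 6}), at::randn({4, 8}), 1};

  // Passing both argument sets exercises the dynamic-shape path.
  testStaticRuntime(var_cat_script, args1, args2);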
@@ -752,6 +752,8 @@ TEST(StaticRuntime, IndividualOps_VarCat) {
   // 3D tensors - cat dim = 2
   std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), 2};
   testStaticRuntime(var_cat_script, args3);
+
+  testStaticRuntime(var_cat_script, args1, args2);
 }
 
 TEST(StaticRuntime, LongModel) {
@@ -407,23 +407,6 @@ REGISTER_OPERATOR_FUNCTOR(aten::nan_to_num, aten_nan_to_num, [](Node* n) -> SROp
     at::native::nan_to_num_out(in0_t, in1_d, in2_d, in3_d, out_t);
   };
 });
-REGISTER_OPERATOR_FUNCTOR(aten::cat, aten_cat, [](Node* n) -> SROperator {
-  if (!n->matches(
-          torch::schema("aten::cat(Tensor[] tensors, int dim=0) -> Tensor"))) {
-    LogAndDumpSchema(n);
-    return nullptr;
-  }
-  return [](ProcessedNode* p_node) {
-    const auto in0_tl = p_node->Input(0).toTensorVector();
-    const auto in1_i = p_node->Input(1).toInt();
-    if (p_node->Output(0).isNone()) {
-      p_node->Output(0) = create_empty_from(in0_tl[0]);
-    }
-    auto& out_t = p_node->Output(0).toTensor();
-    fastResizeToZero(out_t);
-    at::native::_cat_out_cpu(in0_tl, in1_i, out_t);
-  };
-});
 
 // Split out into a function to appease MSVC's pre-processor
 SROperator aten_stack(Node* n) {
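The block removed above was the list-based out-variant wrapper for aten::cat: it checked the schema, reused or created the output tensor, and called at::native::_cat_out_cpu. After the variadic cat change referenced in the summary, the cat nodes handled by the static runtime carry their input tensors directly as node inputs rather than as a TensorList, so this wrapper no longer matched anything. The following is only a rough sketch of what such a variadic concatenation handler can look like; the node name (prim::VarConcat), the input layout, and the input-count accessor are assumptions for illustration, not the code introduced by D29565344.

  // Hypothetical sketch of a variadic cat handler; prim::VarConcat and the
  // input-count accessor are assumptions, not the actual D29565344 code.
  REGISTER_OPERATOR_FUNCTOR(prim::VarConcat, prim_varconcat, [](Node* n) -> SROperator {
    return [](ProcessedNode* p_node) {
      // Inputs are the tensors to concatenate followed by the dim argument,
      // so no TensorList has to be materialized at runtime.
      const size_t num_inputs = p_node->node()->inputs().size();
      std::vector<at::Tensor> inputs(num_inputs - 1);
      for (size_t i = 0; i < num_inputs - 1; ++i) {
        inputs[i] = p_node->Input(i).toTensor();
      }
      const auto dim = p_node->Input(num_inputs - 1).toInt();
      if (p_node->Output(0).isNone()) {
        p_node->Output(0) = create_empty_from(inputs[0]);
      }
      auto& out_t = p_node->Output(0).toTensor();
      fastResizeToZero(out_t);
      at::cat_out(out_t, inputs, dim);
    };
  });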