[Static Runtime] Add native op for aten::expand_as (#64024)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64024

`aten::expand_as` creates a view of the input tensor. This change adds its native op implementation for the static runtime.

Test Plan: - Added `StaticRuntime.IndividualOps_ExpandAs`

Reviewed By: hlu1

Differential Revision: D30546851

fbshipit-source-id: e53483048af890bc41b6192a1ab0c5ba0ee2bdc0
This commit is contained in:
Don Jang 2021-08-26 12:58:05 -07:00 committed by Facebook GitHub Bot
parent 95d0b3199b
commit cbfec02007
3 changed files with 33 additions and 0 deletions

View File

@ -349,6 +349,12 @@ const std::string embedding_bag_max_last_offset = R"JIT(
return torch.embedding_bag(a, b, c, False, 2, False, None, True)
)JIT";
const auto expand_as_script = R"JIT(
def forward(self, input: Tensor, other:Tensor):
a = input.expand_as(other)
return a.clone()
)JIT";
const auto sign_tensor = R"JIT(
def forward(self, input: Tensor):
return torch.sign(input).clone()

View File

@ -610,6 +610,17 @@ TEST(StaticRuntime, IndividualOps_Detach) {
testStaticRuntime(detach_script_1, args, args2);
}
TEST(StaticRuntime, IndividualOps_ExpandAs) {
auto a = at::randn({3,1});
auto b = at::randn({3,2});
auto c = at::randn({4,1});
auto d = at::randn({4,2});
std::vector<IValue> args{a, b};
std::vector<IValue> args2{c, d};
testStaticRuntime(expand_as_script, args);
testStaticRuntime(expand_as_script, args, args2);
}
TEST(StaticRuntime, IndividualOps_Full) {
auto dtype = at::ScalarType::Int;
auto cpu = at::Device(DeviceType::CPU);

View File

@ -370,6 +370,22 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
};
});
REGISTER_NATIVE_OPERATOR_FUNCTOR(
aten::expand_as,
aten_expand_as,
[](Node* n) -> SROperator {
if (!n->matches(torch::schema(
"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
LogAndDumpSchema(n);
return nullptr;
}
return [](ProcessedNode* p_node) {
const auto& self = p_node->Input(0).toTensor();
const auto& other = p_node->Input(1).toTensor();
p_node->Output(0) = self.expand(other.sizes());
};
});
REGISTER_NATIVE_OPERATOR_FUNCTOR(
prim::isinstance,
prim_isinstance,