mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[Static Runtime] Add native op for aten::expand_as (#64024)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/64024 `aten::expand_as` creates a view of the input tensor. This change adds its native op implementation for the static runtime. Test Plan: - Added `StaticRuntime.IndividualOps_ExpandAs` Reviewed By: hlu1 Differential Revision: D30546851 fbshipit-source-id: e53483048af890bc41b6192a1ab0c5ba0ee2bdc0
This commit is contained in:
parent
95d0b3199b
commit
cbfec02007
|
|
@ -349,6 +349,12 @@ const std::string embedding_bag_max_last_offset = R"JIT(
|
|||
return torch.embedding_bag(a, b, c, False, 2, False, None, True)
|
||||
)JIT";
|
||||
|
||||
const auto expand_as_script = R"JIT(
|
||||
def forward(self, input: Tensor, other:Tensor):
|
||||
a = input.expand_as(other)
|
||||
return a.clone()
|
||||
)JIT";
|
||||
|
||||
const auto sign_tensor = R"JIT(
|
||||
def forward(self, input: Tensor):
|
||||
return torch.sign(input).clone()
|
||||
|
|
|
|||
|
|
@ -610,6 +610,17 @@ TEST(StaticRuntime, IndividualOps_Detach) {
|
|||
testStaticRuntime(detach_script_1, args, args2);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, IndividualOps_ExpandAs) {
|
||||
auto a = at::randn({3,1});
|
||||
auto b = at::randn({3,2});
|
||||
auto c = at::randn({4,1});
|
||||
auto d = at::randn({4,2});
|
||||
std::vector<IValue> args{a, b};
|
||||
std::vector<IValue> args2{c, d};
|
||||
testStaticRuntime(expand_as_script, args);
|
||||
testStaticRuntime(expand_as_script, args, args2);
|
||||
}
|
||||
|
||||
TEST(StaticRuntime, IndividualOps_Full) {
|
||||
auto dtype = at::ScalarType::Int;
|
||||
auto cpu = at::Device(DeviceType::CPU);
|
||||
|
|
|
|||
|
|
@ -370,6 +370,22 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
|||
};
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
aten::expand_as,
|
||||
aten_expand_as,
|
||||
[](Node* n) -> SROperator {
|
||||
if (!n->matches(torch::schema(
|
||||
"aten::expand_as(Tensor(a) self, Tensor other) -> Tensor(a)"))) {
|
||||
LogAndDumpSchema(n);
|
||||
return nullptr;
|
||||
}
|
||||
return [](ProcessedNode* p_node) {
|
||||
const auto& self = p_node->Input(0).toTensor();
|
||||
const auto& other = p_node->Input(1).toTensor();
|
||||
p_node->Output(0) = self.expand(other.sizes());
|
||||
};
|
||||
});
|
||||
|
||||
REGISTER_NATIVE_OPERATOR_FUNCTOR(
|
||||
prim::isinstance,
|
||||
prim_isinstance,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user