[Static Runtime] Support aten::to.prim_dtype overload (#64928)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64928

Added support for this overload of `aten::to`:
```
aten::to.prim_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)
```

Test Plan: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- IndividualOps_to`

Reviewed By: hlu1

Differential Revision: D30901398

fbshipit-source-id: 38ce807c30185e92dd472b404b362f22ac7e4efb
parent a8c0b362ce
commit d5f64afc38
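For reference, the behavior of the new overload can be summarized with a small standalone ATen sketch (written for this note, not code from the PR; the function name is made up): when neither `dtype` nor `copy` is set, `to` is a no-op that returns `self`, and a `dtype` is required whenever `copy` is true.

```cpp
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Sketch of the aten::to.prim_dtype semantics handled by the diff below.
at::Tensor to_prim_dtype_ref(
    const at::Tensor& self,
    c10::optional<at::ScalarType> dtype,
    bool non_blocking,
    bool copy) {
  if (!dtype && !copy) {
    // Mirrors the JIT interpreter: nothing to do, return the input itself
    // (the result aliases `self`).
    return self;
  }
  TORCH_CHECK(dtype, "dtype cannot be None when copy is True");
  return self.to(*dtype, non_blocking, copy);
}
```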
@@ -270,36 +270,32 @@ const auto pow_script_sca_ten = R"JIT(
       return torch.pow(input, exponent).clone()
 )JIT";

 // to.dtype
-const auto to_script_0 = R"JIT(
+const auto to_script_dtype = R"JIT(
   def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
       a = input + input
       return torch.to(a, dtype, non_blocking, copy, memory_format).clone()
 )JIT";

 // to.dtype, strided
-const auto to_script_1 = R"JIT(
+const auto to_script_dtype_strided = R"JIT(
   def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
       b = input.permute(0, 2, 3, 1)
       return torch.to(b, dtype, non_blocking, copy, memory_format).clone()
 )JIT";

 // to.prim_dtype
-const auto to_script_2 = R"JIT(
-  def forward(self, input:Tensor, dtype: int, non_blocking: bool, copy: bool):
+const auto to_script_prim_dtype = R"JIT(
+  def forward(self, input:Tensor, dtype: Optional[int], non_blocking: bool, copy: bool):
       a = input + input
       return torch.to(a, dtype, non_blocking, copy).clone()
 )JIT";

 // to.other
-const auto to_script_3 = R"JIT(
+const auto to_script_other = R"JIT(
   def forward(self, input:Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
       a = input + input
       return torch.to(a, other, non_blocking, copy, memory_format).clone()
 )JIT";

 // if input is float tensor, b could be alias of a
-const auto to_script_4 = R"JIT(
+const auto to_script_alias = R"JIT(
   def forward(self, input:Tensor):
       a = input + input
       b = a.float()
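The `to_script_alias` snippet (and its "b could be alias of a" comment) relies on ATen's behavior that a conversion which changes nothing returns the input tensor itself. A small self-contained illustration, written for this note (names and shapes are arbitrary, not taken from the test):

```cpp
#include <ATen/ATen.h>
#include <cassert>

int main() {
  at::Tensor a = at::randn({2, 2});  // float tensor
  // Same dtype, copy=false: to() returns `a` itself, so b aliases a.
  at::Tensor b = a.to(a.scalar_type(), /*non_blocking=*/false, /*copy=*/false);
  assert(b.unsafeGetTensorImpl() == a.unsafeGetTensorImpl());
  // A real conversion produces a fresh tensor.
  at::Tensor c = a.to(at::kHalf);
  assert(c.unsafeGetTensorImpl() != a.unsafeGetTensorImpl());
  return 0;
}
```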
@@ -624,19 +624,26 @@ TEST(StaticRuntime, IndividualOps_to) {
     std::vector<IValue> args0{a, b, c, d, e};
     std::vector<IValue> args1{a, b, c, d};
     std::vector<IValue> args2{a, other, c, d, e};
+    std::vector<IValue> args3{a, c10::nullopt, c, d};

-    testStaticRuntime(to_script_0, args0); // to.dtype
-    testStaticRuntime(to_script_1, args0); // to.dtype, strided
-    testStaticRuntime(to_script_2, args1); // to.prim_dtype
-    testStaticRuntime(to_script_3, args2); // to.other
-    testStaticRuntime(to_script_4, {a}); // alias
+    testStaticRuntime(to_script_dtype, args0);
+    testStaticRuntime(to_script_dtype_strided, args0);
+    testStaticRuntime(to_script_prim_dtype, args1);
+    if (!d) {
+      testStaticRuntime(to_script_prim_dtype, args3);
+    }
+    testStaticRuntime(to_script_other, args2);
+    testStaticRuntime(to_script_alias, {a});

     // dynamic shapes
-    testStaticRuntime(to_script_0, args0, {a2, b, c, d, e}); // to.dtype
-    testStaticRuntime(to_script_1, args0, {a2, b, c, d, e}); // to.dtype
-    testStaticRuntime(to_script_2, args1, {a2, b, c, d}); // to.prim_dtype
-    testStaticRuntime(to_script_3, args2, {a2, a2_other, c, d, e}); // to.other
-    testStaticRuntime(to_script_4, {a}, {a2});
+    testStaticRuntime(to_script_dtype, args0, {a2, b, c, d, e});
+    testStaticRuntime(to_script_dtype_strided, args0, {a2, b, c, d, e});
+    testStaticRuntime(to_script_prim_dtype, args1, {a2, b, c, d});
+    if (!d) {
+      testStaticRuntime(to_script_prim_dtype, args3, {a2, c10::nullopt, c, d});
+    }
+    testStaticRuntime(to_script_other, args2, {a2, a2_other, c, d, e});
+    testStaticRuntime(to_script_alias, {a}, {a2});
   };
   for (const bool non_blocking : {false, true}) {
     for (const bool copy : {false, true}) {
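For readability of the argument vectors above: the surrounding test (not fully shown in this hunk) builds them from an input tensor and the loop variables. A hedged reconstruction of what they plausibly contain; only the names come from the test, while the shapes and the concrete dtype here are assumptions:

```cpp
#include <ATen/ATen.h>
#include <ATen/core/ivalue.h>
#include <vector>

// Illustrative only; the real test also defines a2, other, and a2_other for
// the dynamic-shape runs and loops over c (non_blocking) and d (copy).
void example_args() {
  at::Tensor a = at::randn({2, 2, 2, 2});      // 4-D so permute(0, 2, 3, 1) works
  c10::IValue b = at::ScalarType::Half;        // target dtype for the dtype overloads
  bool c = false;                              // non_blocking
  bool d = false;                              // copy
  c10::IValue e = c10::MemoryFormat::Contiguous;
  std::vector<c10::IValue> args1{a, b, c, d};             // to.prim_dtype with a real dtype
  std::vector<c10::IValue> args3{a, c10::nullopt, c, d};  // dtype=None fast path (only when !d)
}
```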
@@ -340,35 +340,50 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROpe
 });

 REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
-  if (!n->matches(torch::schema(
-          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)")) &&
-      !n->matches(torch::schema(
-          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
-    LogAndDumpSchema(n);
-    return nullptr;
-  }
-  return [](ProcessedNode* p_node) {
-    const auto& in0_t = p_node->Input(0).toTensor();
-    const auto in2_i = p_node->Input(2).toBool();
-    const auto in3_i = p_node->Input(3).toBool();
-    const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
-    if (p_node->Input(1).isTensor()) {
-      // to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool
-      // copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
-      const auto in1_t = p_node->Input(1).toTensor();
-      p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
-    } else {
-      // to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False,
-      // bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
-      const auto in1_i = p_node->Input(1).toScalarType();
-      p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
-    }
-    // in case that Output(0) is an alias of in0_t, copy the tensor.
-    if (p_node->Output(0).toTensor().unsafeGetTensorImpl() ==
-        in0_t.unsafeGetTensorImpl()) {
-      p_node->Output(0) = in0_t.clone();
-    }
-  };
+  if (n->matches(torch::schema(
+          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
+    return [](ProcessedNode* p_node) {
+      const auto& in0_t = p_node->Input(0).toTensor();
+      const auto& in1_t = p_node->Input(1).toTensor();
+      const auto in2_i = p_node->Input(2).toBool();
+      const auto in3_i = p_node->Input(3).toBool();
+      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
+      p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
+    };
+  }
+  if (n->matches(torch::schema(
+          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
+    return [](ProcessedNode* p_node) {
+      const auto& in0_t = p_node->Input(0).toTensor();
+      const auto in1_i = p_node->Input(1).toScalarType();
+      const auto in2_i = p_node->Input(2).toBool();
+      const auto in3_i = p_node->Input(3).toBool();
+      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
+      p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
+    };
+  }
+  if (n->matches(torch::schema(
+          "aten::to.prim_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"))) {
+    return [](ProcessedNode* p_node) {
+      const auto& in0_t = p_node->Input(0).toTensor();
+      const auto in1_i = p_node->Input(1).toOptional<at::ScalarType>();
+      const auto in2_i = p_node->Input(2).toBool();
+      const auto in3_i = p_node->Input(3).toBool();
+      // To mimic the behavior of the JIT interpreter, if both dtype
+      // and copy are not set, we return self. Otherwise, we assume
+      // that dtype is set.
+      if (!in1_i && !in3_i) {
+        p_node->Output(0) = in0_t;
+      } else {
+        TORCH_CHECK(
+            in1_i,
+            "dtype cannot be None when copy is True for aten::to.prim_dtype");
+        p_node->Output(0) = at::native::to(in0_t, *in1_i, in2_i, in3_i);
+      }
+    };
+  }
   LogAndDumpSchema(n);
   return nullptr;
 });

 REGISTER_NATIVE_OPERATOR_FUNCTOR(
@@ -1024,7 +1024,7 @@ REGISTER_OPERATOR_FUNCTOR(

       if (p_node->Output(0).isNone()) {
         // handle dtype, layout, and device
-        at::ScalarType dtype;
+        c10::optional<at::ScalarType> dtype;
         c10::Layout layout = self.layout();
         c10::Device device = self.device();
         if (p_node->Input(1).isTensor()) {
@@ -1033,7 +1033,7 @@ REGISTER_OPERATOR_FUNCTOR(
           layout = other.layout();
           device = other.device();
         } else {
-          dtype = p_node->Input(1).toScalarType();
+          dtype = p_node->Input(1).toOptional<at::ScalarType>();
         }

         if (memory_format == c10::MemoryFormat::Preserve) {