[Static Runtime] Support aten::to.prim_dtype overload (#64928)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64928

Added support for this overload of `aten::to`:
```
aten::to.prim_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)
```
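
For illustration only (not part of this diff), here is a minimal TorchScript sketch of the kind of call that lowers to this overload; the function name is made up, and the body mirrors the `to_script_prim_dtype` test script added below, i.e. `Tensor.to` called with an `int?` dtype and no `memory_format` argument:

```
from typing import Optional

import torch

@torch.jit.script
def to_prim_dtype(x: torch.Tensor, dtype: Optional[int],
                  non_blocking: bool, copy: bool) -> torch.Tensor:
    a = x + x
    # An Optional[int] dtype with no memory_format argument matches
    # aten::to.prim_dtype rather than aten::to.dtype.
    return a.to(dtype, non_blocking, copy).clone()
```

As in the JIT interpreter, when `dtype` is `None` and `copy` is `False` the op simply returns `self`; otherwise `dtype` must be set.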

Test Plan: `buck test caffe2/benchmarks/static_runtime:static_runtime_cpptest -- IndividualOps_to`

Reviewed By: hlu1

Differential Revision: D30901398

fbshipit-source-id: 38ce807c30185e92dd472b404b362f22ac7e4efb
Author: Mike Iovine, 2021-10-07 10:20:30 -07:00 (committed by Facebook GitHub Bot)
Parent: a8c0b362ce
Commit: d5f64afc38
4 changed files with 66 additions and 48 deletions


@@ -270,36 +270,32 @@ const auto pow_script_sca_ten = R"JIT(
       return torch.pow(input, exponent).clone()
 )JIT";
 // to.dtype
-const auto to_script_0 = R"JIT(
+const auto to_script_dtype = R"JIT(
   def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
       a = input + input
       return torch.to(a, dtype, non_blocking, copy, memory_format).clone()
 )JIT";
 // to.dtype, strided
-const auto to_script_1 = R"JIT(
+const auto to_script_dtype_strided = R"JIT(
   def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
       b = input.permute(0, 2, 3, 1)
       return torch.to(b, dtype, non_blocking, copy, memory_format).clone()
 )JIT";
 // to.prim_dtype
-const auto to_script_2 = R"JIT(
-  def forward(self, input:Tensor, dtype: int, non_blocking: bool, copy: bool):
+const auto to_script_prim_dtype = R"JIT(
+  def forward(self, input:Tensor, dtype: Optional[int], non_blocking: bool, copy: bool):
       a = input + input
       return torch.to(a, dtype, non_blocking, copy).clone()
 )JIT";
 // to.other
-const auto to_script_3 = R"JIT(
+const auto to_script_other = R"JIT(
   def forward(self, input:Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
       a = input + input
       return torch.to(a, other, non_blocking, copy, memory_format).clone()
 )JIT";
 // if input is float tensor, b could be alias of a
-const auto to_script_4 = R"JIT(
+const auto to_script_alias = R"JIT(
   def forward(self, input:Tensor):
       a = input + input
       b = a.float()


@@ -624,19 +624,26 @@ TEST(StaticRuntime, IndividualOps_to) {
     std::vector<IValue> args0{a, b, c, d, e};
     std::vector<IValue> args1{a, b, c, d};
     std::vector<IValue> args2{a, other, c, d, e};
+    std::vector<IValue> args3{a, c10::nullopt, c, d};
-    testStaticRuntime(to_script_0, args0); // to.dtype
-    testStaticRuntime(to_script_1, args0); // to.dtype, strided
-    testStaticRuntime(to_script_2, args1); // to.prim_dtype
-    testStaticRuntime(to_script_3, args2); // to.other
-    testStaticRuntime(to_script_4, {a}); // alias
+    testStaticRuntime(to_script_dtype, args0);
+    testStaticRuntime(to_script_dtype_strided, args0);
+    testStaticRuntime(to_script_prim_dtype, args1);
+    if (!d) {
+      testStaticRuntime(to_script_prim_dtype, args3);
+    }
+    testStaticRuntime(to_script_other, args2);
+    testStaticRuntime(to_script_alias, {a});
     // dynamic shapes
-    testStaticRuntime(to_script_0, args0, {a2, b, c, d, e}); // to.dtype
-    testStaticRuntime(to_script_1, args0, {a2, b, c, d, e}); // to.dtype
-    testStaticRuntime(to_script_2, args1, {a2, b, c, d}); // to.prim_dtype
-    testStaticRuntime(to_script_3, args2, {a2, a2_other, c, d, e}); // to.other
-    testStaticRuntime(to_script_4, {a}, {a2});
+    testStaticRuntime(to_script_dtype, args0, {a2, b, c, d, e});
+    testStaticRuntime(to_script_dtype_strided, args0, {a2, b, c, d, e});
+    testStaticRuntime(to_script_prim_dtype, args1, {a2, b, c, d});
+    if (!d) {
+      testStaticRuntime(to_script_prim_dtype, args3, {a2, c10::nullopt, c, d});
+    }
+    testStaticRuntime(to_script_other, args2, {a2, a2_other, c, d, e});
+    testStaticRuntime(to_script_alias, {a}, {a2});
   };
   for (const bool non_blocking : {false, true}) {
     for (const bool copy : {false, true}) {


@@ -340,35 +340,50 @@ REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::narrow, aten_narrow, [](Node* n) -> SROperator {
 });
 REGISTER_NATIVE_OPERATOR_FUNCTOR(aten::to, aten_to, [](Node* n) -> SROperator {
-  if (!n->matches(torch::schema(
-          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)")) &&
-      !n->matches(torch::schema(
-          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
-    LogAndDumpSchema(n);
-    return nullptr;
-  }
-  return [](ProcessedNode* p_node) {
-    const auto& in0_t = p_node->Input(0).toTensor();
-    const auto in2_i = p_node->Input(2).toBool();
-    const auto in3_i = p_node->Input(3).toBool();
-    const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
-    if (p_node->Input(1).isTensor()) {
-      // to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool
-      // copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
-      const auto in1_t = p_node->Input(1).toTensor();
+  if (n->matches(torch::schema(
+          "aten::to.other(Tensor(a) self, Tensor other, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
+    return [](ProcessedNode* p_node) {
+      const auto& in0_t = p_node->Input(0).toTensor();
+      const auto& in1_t = p_node->Input(1).toTensor();
+      const auto in2_i = p_node->Input(2).toBool();
+      const auto in3_i = p_node->Input(3).toBool();
+      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
       p_node->Output(0) = at::native::to(in0_t, in1_t, in2_i, in3_i, in4_o);
-    } else {
-      // to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False,
-      // bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)
+    };
+  }
+  if (n->matches(torch::schema(
+          "aten::to.dtype(Tensor(a) self, ScalarType dtype, bool non_blocking=False, bool copy=False, MemoryFormat? memory_format=None) -> Tensor(a)"))) {
+    return [](ProcessedNode* p_node) {
+      const auto& in0_t = p_node->Input(0).toTensor();
       const auto in1_i = p_node->Input(1).toScalarType();
+      const auto in2_i = p_node->Input(2).toBool();
+      const auto in3_i = p_node->Input(3).toBool();
+      const auto in4_o = p_node->Input(4).toOptional<at::MemoryFormat>();
       p_node->Output(0) = at::native::to(in0_t, in1_i, in2_i, in3_i, in4_o);
-    }
-    // in case that Output(0) is an alias of in0_t, copy the tensor.
-    if (p_node->Output(0).toTensor().unsafeGetTensorImpl() ==
-        in0_t.unsafeGetTensorImpl()) {
-      p_node->Output(0) = in0_t.clone();
-    }
-  };
+    };
+  }
+  if (n->matches(torch::schema(
+          "aten::to.prim_dtype(Tensor(a) self, int? dtype, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"))) {
+    return [](ProcessedNode* p_node) {
+      const auto& in0_t = p_node->Input(0).toTensor();
+      const auto in1_i = p_node->Input(1).toOptional<at::ScalarType>();
+      const auto in2_i = p_node->Input(2).toBool();
+      const auto in3_i = p_node->Input(3).toBool();
+      // To mimic the behavior of the JIT interpreter, if both dtype
+      // and copy are not set, we return self. Otherwise, we assume
+      // that dtype is set.
+      if (!in1_i && !in3_i) {
+        p_node->Output(0) = in0_t;
+      } else {
+        TORCH_CHECK(
+            in1_i,
+            "dtype cannot be None when copy is True for aten::to.prim_dtype");
+        p_node->Output(0) = at::native::to(in0_t, *in1_i, in2_i, in3_i);
+      }
+    };
+  }
+  LogAndDumpSchema(n);
+  return nullptr;
 });
 REGISTER_NATIVE_OPERATOR_FUNCTOR(


@@ -1024,7 +1024,7 @@ REGISTER_OPERATOR_FUNCTOR(
     if (p_node->Output(0).isNone()) {
       // handle dtype, layout, and device
-      at::ScalarType dtype;
+      c10::optional<at::ScalarType> dtype;
       c10::Layout layout = self.layout();
       c10::Device device = self.device();
       if (p_node->Input(1).isTensor()) {
@@ -1033,7 +1033,7 @@ REGISTER_OPERATOR_FUNCTOR(
         layout = other.layout();
         device = other.device();
       } else {
-        dtype = p_node->Input(1).toScalarType();
+        dtype = p_node->Input(1).toOptional<at::ScalarType>();
       }
       if (memory_format == c10::MemoryFormat::Preserve) {