[TorchArrow][efficiency][3/n] variadic versions of op fused /unfused inference_wrapper_run_flat (#81133)

Summary:
Added `variadic` version (just an optimization) of the registered fused and unfused ops.

Reviewed By: tenpercent

Differential Revision: D37456033

Pull Request resolved: https://github.com/pytorch/pytorch/pull/81133
Approved by: https://github.com/tenpercent, https://github.com/qxy11
This commit is contained in:
Mandar Deshpande 2022-07-13 10:30:32 +00:00 committed by PyTorch MergeBot
parent e423354b91
commit 937ca69f15
2 changed files with 7 additions and 0 deletions

View File

@ -167,6 +167,10 @@ void OptimizeGraph(
graph,
fromQualString("fb::sigrid_transforms_torch_bind"),
fromQualString("fb::variadic_sigrid_transforms_torch_bind"));
UseVariadicOp(
graph,
fromQualString("torcharrow::inference_wrapper_run_flat"),
fromQualString("torcharrow::variadic_inference_wrapper_run_flat"));
// These fused ops only have out variants - we can't do the fusion when
// out variants are disabled.
FuseSignLog1P(graph);

View File

@ -872,6 +872,9 @@ void FuseListUnpack(std::shared_ptr<torch::jit::Graph>& graph) {
OP_PAIR(
"torcharrow::inference_wrapper_run_flat",
"static_runtime::fused_inference_wrapper_run_flat"),
OP_PAIR(
"torcharrow::variadic_inference_wrapper_run_flat",
"static_runtime::fused_variadic_inference_wrapper_run_flat"),
OP_PAIR("fb::equally_split", "static_runtime::fused_equally_split"),
OP_PAIR(
"fb::sigrid_transforms", "static_runtime::fused_sigrid_transforms"),