[Fix]: TSConverter handles call ops with multiple outputs (#129294)
#### Issue
* The current handling of call ops does not support IR nodes with multiple outputs. If an op has multiple outputs, we add an implicit unpack to map each output; see the sketch after this list. E.g.,
```
%5 : Tensor, %6 : Tensor = aten::max(%x.1, %3, %4), scope: export.test_converter.M:: # /data/users/jiashenc/pytorch/test/export/test_converter.py:774:20
```
* There are also cases where `prim::If` sub-blocks do not return any outputs. E.g.,
```
%9 : bool = aten::gt(%8, %3), scope: export.test_converter.M::/torch.nn.modules.pooling.AdaptiveMaxPool2d::pool # <string>:5:9
= prim::If(%9), scope: export.test_converter.M::/torch.nn.modules.pooling.AdaptiveMaxPool2d::pool # <string>:5:2
block0():
-> ()
block1():
= prim::RaiseException(%5, %4), scope: export.test_converter.M::/torch.nn.modules.pooling.AdaptiveMaxPool2d::pool # <string>:5:2
-> ()
```
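Both patterns can be reproduced outside the converter by scripting a small module and printing its inlined graph. This is only an illustrative sketch (the module below is made up for this purpose, and the exact IR depends on the TorchScript version):
```python
import torch


class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)

    def forward(self, x, y):
        # aten::max along a dim yields two outputs: values and indices.
        return torch.max(x, dim=0), self.pool(y)


# The inlined graph contains a single aten::max node with two output Values,
# and prim::If nodes (from AdaptiveMaxPool2d's internal checks) whose
# sub-blocks return no values.
print(torch.jit.script(M()).inlined_graph)
```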
#### Test Plan
We did an exhaustive search of all torch APIs that can return multiple outputs (one way to enumerate them is sketched below), sampled some common ones, and added new test cases based on them.
* `pytest test/export/test_converter.py -s -k test_ts2ep_multi_outputs_on_call_ops`
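One way to approximate that search is to scan the registered ATen schemas for operators that declare more than one return value. This is an illustrative sketch using an internal API (`torch._C._jit_get_all_schemas`), not necessarily the exact method used to produce the appendix below:
```python
import torch

# Collect ATen operators whose schema declares more than one return value.
multi_output_ops = sorted(
    {
        schema.name
        for schema in torch._C._jit_get_all_schemas()
        if schema.name.startswith("aten::") and len(schema.returns) > 1
    }
)
print(len(multi_output_ops))
for name in multi_output_ops[:10]:
    print(name)
```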
#### Appendix
* aten ops that return multiple outputs.
```
aten._batch_norm_impl_index
aten._batch_norm_no_update
aten._batch_norm_with_update
aten._batch_norm_with_update_functional
aten._cudnn_rnn
aten._efficient_attention_backward
aten._efficient_attention_forward
aten._embedding_bag
aten._embedding_bag_forward_only
aten._flash_attention_backward
aten._flash_attention_forward
aten._fused_adam
aten._fused_dropout
aten._fused_moving_avg_obs_fq_helper
aten._linalg_det
aten._linalg_eigh
aten._linalg_slogdet
aten._linalg_solve_ex
aten._linalg_svd
aten._native_batch_norm_legit
aten._native_batch_norm_legit_functional
aten._native_batch_norm_legit_no_training
aten._pack_padded_sequence
aten._prelu_kernel_backward
aten._scaled_dot_product_efficient_attention
aten._scaled_dot_product_efficient_attention_backward
aten._scaled_dot_product_flash_attention
aten._scaled_dot_product_flash_attention_backward
aten._scaled_dot_product_flash_attention_for_cpu
aten._scaled_dot_product_flash_attention_for_cpu_backward
aten._thnn_fused_lstm_cell
aten._thnn_fused_lstm_cell_backward_impl
aten._unique2
aten._weight_norm_interface
aten.adaptive_max_pool2d
aten.adaptive_max_pool3d
aten.aminmax
aten.batch_norm_backward
aten.convolution_backward
aten.cudnn_batch_norm
aten.cudnn_batch_norm_backward
aten.cummax
aten.cummin
aten.fractional_max_pool2d
aten.frexp
aten.grid_sampler_2d_backward
aten.grid_sampler_3d_backward
aten.gru
aten.linalg_cholesky_ex
aten.linalg_eig
aten.linalg_inv_ex
aten.linalg_ldl_factor_ex
aten.linalg_lu
aten.linalg_lu_factor_ex
aten.linalg_qr
aten.linear_backward
aten.log_sigmoid_forward
aten.lstm
aten.lu_unpack
aten.max
aten.max_pool2d_with_indices
aten.max_pool3d_with_indices
aten.median
aten.min
aten.miopen_batch_norm
aten.miopen_batch_norm_backward
aten.mkldnn_rnn_layer
aten.mkldnn_rnn_layer_backward
aten.mode
aten.multilabel_margin_loss_forward
aten.nanmedian
aten.native_batch_norm
aten.native_batch_norm_backward
aten.native_dropout
aten.native_group_norm
aten.native_group_norm_backward
aten.native_layer_norm
aten.native_layer_norm_backward
aten.nll_loss2d_forward
aten.nll_loss_forward
aten.quantized_gru
aten.quantized_lstm
aten.rnn_relu
aten.rnn_tanh
aten.sort
aten.std_mean
aten.topk
aten.triangular_solve
aten.unique_dim
aten.var_mean
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129294
Approved by: https://github.com/angelayi
Commit 686b7f046a (parent 7f1cda1533)
In `test/export/test_converter.py`:

```diff
@@ -1027,6 +1027,23 @@ class TestConverter(TestCase):
         # Cannot script variable length inputs.
         self._check_equal_ts_ep_converter(func2, tuple(values), ["trace"])
 
+    def test_ts2ep_multi_outputs_on_call_ops(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.pool = torch.nn.AdaptiveMaxPool2d((2, 2), return_indices=True)
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                return (
+                    torch.max(x, dim=0),
+                    torch.topk(x, 3),
+                    torch.sort(x, dim=0),
+                    self.pool(y),
+                )
+
+        inp = (torch.randn([4, 4]), torch.randn([1, 1, 10, 10]))
+        self._check_equal_ts_ep_converter(M(), inp)
+
 
 if __name__ == "__main__":
     run_tests()
```
In `TS2FXGraphConverter`:

```diff
@@ -510,12 +510,16 @@ class TS2FXGraphConverter:
         # TODO: covnert sourceRange() into stack_trace
         # fx_node.meta["stack_trace"] = node.sourceRange()
 
-        outs = tuple(node.outputs())
-        if len(outs) == 1:
+        if node.outputsSize() == 1:
             output_name = node.output().debugName()
             self.name_to_node[output_name] = fx_node
-        elif len(outs) > 1:
-            raise RuntimeError("Number of outputs > 1 is not supported yet")
+        else:
+            for i, outp in enumerate(node.outputs()):
+                output_name = outp.debugName()
+                next_fx_node = self.fx_graph.call_function(
+                    operator.getitem, (fx_node, i)
+                )
+                self.name_to_node[output_name] = next_fx_node
 
     def convert_prim_TupleConstruct(self, node: torch._C.Node):
         self._convert_prim_iterator(node)
```
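For illustration, the mapping built by the new `else` branch can be mimicked with a standalone `torch.fx` graph. This is a hypothetical sketch (the debug names are made up), not the converter itself:
```python
import operator

import torch
from torch import fx

# One call_function node for the multi-output op, then one operator.getitem
# node per output, so each TorchScript output name can point to its own FX node.
graph = fx.Graph()
x = graph.placeholder("x")
max_node = graph.call_function(torch.max, (x,), {"dim": 0})

name_to_node = {}
for i, debug_name in enumerate(["values.1", "indices.1"]):  # hypothetical debug names
    name_to_node[debug_name] = graph.call_function(operator.getitem, (max_node, i))

graph.output((name_to_node["values.1"], name_to_node["indices.1"]))
gm = fx.GraphModule(torch.nn.Module(), graph)
print(gm.graph)
print(gm(torch.randn(3, 4)))
```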
```diff
@@ -747,12 +751,10 @@ class TS2FXGraphConverter:
 
         cond_node = self.fx_graph.call_function(torch.cond, args, {})
 
-        outs = tuple(node.outputs())
-        if len(outs) == 1:
+        # prim::If can also have zero output.
+        if node.outputsSize() == 1:
             output_name = node.output().debugName()
             self.name_to_node[output_name] = cond_node
-        elif len(outs) > 1:
-            raise RuntimeError("Number of outputs > 1 is not supported yet")
 
     def convert_aten_Bool(self, node: torch._C.Node):
         self._convert_as_noop(node)
```
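The hunk above lowers `prim::If` to `torch.cond`; a minimal eager-mode sketch of that higher-order op, independent of the converter:
```python
import torch


def true_fn(x):
    return x.sin()


def false_fn(x):
    return x.cos()


x = torch.randn(4)
# torch.cond dispatches to one branch based on the predicate. Per the change
# above, a prim::If node is lowered to torch.cond, and an output name is
# mapped only when the node has exactly one output.
print(torch.cond(x.sum() > 0, true_fn, false_fn, (x,)))
```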
```diff
@@ -856,10 +858,14 @@ class TS2FXGraphConverter:
                 )
             else:
                 raise ValueError(f"Output {output_name} not found")
-        if args:
+
+        if len(args) == 1:
             self.fx_graph.output(
                 args[0]
             )  # Get rid of an extra list wrapped around final output.
+        else:
+            # Sub-block of prim::If can have zero output.
+            self.fx_graph.output([])
 
 
 class ExplainTS2FXGraphConverter(TS2FXGraphConverter):
```
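Finally, `self.fx_graph.output([])` in the hunk above yields a graph whose generated forward simply returns an empty list. A minimal standalone sketch:
```python
import torch
from torch import fx

# A graph for a sub-block that produces no values: the output node wraps an
# empty list, matching the zero-output prim::If sub-block case.
graph = fx.Graph()
graph.placeholder("x")
graph.output([])

gm = fx.GraphModule(torch.nn.Module(), graph)
print(gm(torch.randn(2)))  # prints []
```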