[ONNX] Fix index_put_ usage (#161263)

Summary:
It's hard to trace how this path is exercised in most of our models, but in general `aten::copy_` is replaced incorrectly.
There are two schemas for `aten::copy_`:
1. `aten::copy_.Tensor(Tensor(a!) self, Tensor other) -> Tensor(a!)`
2. `aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)`
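
For context, here is a minimal sketch (mine, not part of the original report) of how a sliced copy surfaces as an in-place `aten::copy_` node in the TorchScript graph, which the exporter's inplace-removal pass then rewrites into `aten::index_put_`:

```
import torch

# Slice assignment is lowered to an in-place aten::copy_ onto a view;
# the ONNX exporter later rewrites that node into aten::index_put_.
@torch.jit.script
def slice_copy(dst: torch.Tensor, src: torch.Tensor) -> torch.Tensor:
    dst[:, 0] = src[:, 0]
    return dst

# The printed graph contains an aten::copy_ node (the exact graph may
# vary across PyTorch versions).
print(slice_copy.graph)
```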

According to the logic in the comments, the third `aten::copy_` parameter is not needed for `aten::index_put_`.

It seems the logic was inferred from the ordinary `aten::copy`, which can take a third parameter, the `non_blocking` flag.

Depending on the execution environment, the sliced copy can be replaced using either the first schema or the second one with the default parameter explicitly set to `False`.

If the first schema is selected, it leads to a crash (which is easy to reproduce in our prod env). If the second schema is selected, there is no crash, but the third parameter is treated as the `accumulate` parameter of `index_put_`, which makes no sense.

So, in either case, the third parameter must be removed from the `aten::copy_` replacement.
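
To illustrate why forwarding `non_blocking` into that slot is meaningless, a small sketch (mine, not from the pass) of what the `accumulate` argument of `index_put_` actually controls:

```
import torch

t = torch.zeros(3)
idx = (torch.tensor([0, 1]),)
vals = torch.tensor([1.0, 1.0])

# accumulate=False overwrites the addressed elements: tensor([1., 1., 0.])
print(t.clone().index_put_(idx, vals, accumulate=False))

# accumulate=True adds the values onto the existing contents: tensor([2., 2., 1.])
print(torch.ones(3).index_put_(idx, vals, accumulate=True))
```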

For more details, see this post:
https://fb.workplace.com/groups/1405155842844877/permalink/25337687649165028/

Test Plan:

The test fails only in the production environment.
In the test env the `non_blocking` flag is mapped as `False` to the `accumulate` flag, which doesn't cause the test to fail, but makes no sense in terms of flag mapping.

The export now works without errors; before the fix it failed with an out-of-bounds vector access (the pass requests input index 2 from a `copy_` node that only has two inputs), like this:
```
   1095     _C._jit_onnx_log("Torch IR graph at exception: ", graph)
File ~/.bento/kernels/bento_kernel_gaia_ml/1578/bento_kernel_gaia_ml_binary-inplace#link-tree/torch/onnx/utils.py:636, in _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop, fixed_batch_size, params_dict, dynamic_axes, input_names, module)
    629 _C._jit_pass_lower_all_tuples(graph)
    630 # in _jit_pass_onnx, symbolic functions are called for each node for conversion.
    631 # However, there are nodes that cannot be converted without additional context.
    632 # For example, the number of outputs from split (and whether it is static or dynamic) is unknown
    633 # until the point where it is unpacked by listUnpack node.
    634 # This pass does a preprocess, and prepares the nodes such that enough context can be received
    635 # by the symbolic function.
--> 636 _C._jit_pass_onnx_remove_inplace_ops_for_onnx(graph, module)
    637 _C._jit_pass_onnx_preprocess(graph)
    639 # onnx does not support tuples, so try to remove them
RuntimeError: vector::_M_range_check: __n (which is 2) >= this->size() (which is 2)
```

The test script:
```
import torch as th
import tempfile

class CopyTest(th.nn.Module):
    def forward(
        self,
        input_th: th.Tensor
    ):
        to_fill = th.ones((3, 3))
        # Sliced copy: lowered to in-place aten::copy_, which the
        # exporter rewrites to aten::index_put_
        to_fill[:, 0] = input_th[:, 0]
        return to_fill

m = CopyTest()

test_tensor = th.zeros((3, 3))

with tempfile.NamedTemporaryFile() as f:
    th.onnx.export(
        m,
        (test_tensor,),
        f,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["input"],
        output_names=["features"],
        dynamo=False,
    )
```
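
As an additional, hedged sanity check (not part of the original test plan; the output path and the expectation of a `ScatterND` op are my assumptions), one can list the ops the exporter emitted, reusing the `CopyTest` module defined above:

```
import onnx
import torch as th

# Assumes the CopyTest module from the script above is in scope.
th.onnx.export(
    CopyTest(),
    (th.zeros((3, 3)),),
    "copy_test.onnx",  # hypothetical local path
    opset_version=17,
    input_names=["input"],
    output_names=["features"],
    dynamo=False,
)

model = onnx.load("copy_test.onnx")
# index_put_ typically lowers to ScatterND at this opset; with the fix
# there should be no crash and no bogus fourth input.
print([node.op_type for node in model.graph.node])
```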

The exported model test:
```
import torch
import onnx
import onnxruntime

model_name = '/home/ironsided/test_model.onnx'
onnx_model = onnx.load(model_name)
onnx.checker.check_model(onnx_model)

example_inputs = (torch.zeros(3, 3),)

onnx_inputs = [tensor.numpy(force=True) for tensor in example_inputs]
print(f"Input length: {len(onnx_inputs)}")
print(f"Sample input: {onnx_inputs}")

ort_session = onnxruntime.InferenceSession(
    model_name, providers=["CPUExecutionProvider"]
)

onnxruntime_input = {
    input_arg.name: input_value
    for input_arg, input_value in zip(ort_session.get_inputs(), onnx_inputs)
}

# ONNX Runtime returns a list of outputs
onnxruntime_outputs = ort_session.run(None, onnxruntime_input)[0]

print(onnxruntime_outputs)
```

The produced result is correct:
```
Input length: 1
Sample input: [array([[0., 0., 0.],
       [0., 0., 0.],
       [0., 0., 0.]], dtype=float32)]
[[0. 1. 1.]
 [0. 1. 1.]
 [0. 1. 1.]]
```

Rollback Plan:

Differential Revision: D80797028

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161263
Approved by: https://github.com/justinchuby, https://github.com/jermenkoo
The fix in `PrepareCopyForONNX` drops the third argument from the inserted `aten::index_put_` call:

```
@@ -191,8 +191,7 @@ std::pair<Value*, Value*> PrepareCopyForONNX(Node* node) {
   expanded_value->node()->copyMetadata(node);
   auto index_put = graph->insert(
-      aten::index_put_,
-      {node->input(0), dummy_list, expanded_value, node->input(2)});
+      aten::index_put_, {node->input(0), dummy_list, expanded_value});
   index_put->node()->copyMetadata(node);
   index_put->copyMetadata(node->output());
   node->output()->replaceAllUsesWith(index_put);
```