pytorch/torch/_decomp
Jack Zhang 64d9afd8a7 Register nll_loss2d decompositions for core aten (#133534)
When exporting a training model for ExecuTorch (which requires all ops to be core aten) with cross-entropy loss (`torch.nn.CrossEntropyLoss`), we hit the following error from the FX verifier in `to_edge`:

```
torch._export.verifier.SpecViolationError: Operator torch._ops.aten.nll_loss2d_forward.default is not Aten Canonical.
```
The aten [implementation](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossNLL.cpp#L624) of `torch.nn.CrossEntropyLoss` calls `nll_loss2d_forward` in the forward pass and `nll_loss2d_backward` in the backward pass (needed for training). Decompositions for both ops already exist, so this change adds them to the list of core aten decompositions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133534
Approved by: https://github.com/JacobSzwejbka
2024-08-19 18:26:48 +00:00

| File | Last commit | Date |
| --- | --- | --- |
| `__init__.py` | Register nll_loss2d decompositions for core aten (#133534) | 2024-08-19 18:26:48 +00:00 |
| `decompositions_for_jvp.py` | [BE][Easy][15/19] enforce style for empty lines in import segments in torch/_d*/ (#129767) | 2024-07-31 21:18:11 +00:00 |
| `decompositions_for_rng.py` | Add None return type to init (#132335) | 2024-08-01 15:26:45 +00:00 |
| `decompositions.py` | Decompose _unsafe_index_put into index_put (#133365) | 2024-08-19 18:07:23 +00:00 |