Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
When exporting a training model for ExecuTorch (which requires all ops to be core ATen) with cross-entropy loss (`torch.nn.CrossEntropyLoss`), we ran into the following error from the FX verifier in `to_edge`:

```
torch._export.verifier.SpecViolationError: Operator torch._ops.aten.nll_loss2d_forward.default is not Aten Canonical.
```

The ATen [implementation](https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/LossNLL.cpp#L624) of `torch.nn.CrossEntropyLoss` uses `nll_loss2d_forward` in the forward pass and `nll_loss2d_backward` in the backward pass needed for training. We therefore need to add the decompositions for both (which already exist) to the list of core ATen decompositions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133534
Approved by: https://github.com/JacobSzwejbka
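The mechanism behind the fix can be illustrated with a toy model: a verifier that rejects any op outside a core set, and a decomposition table that rewrites a non-core op into core ones before verification. Everything below is a hypothetical sketch; `CORE_OPS`, `decompose`, `find_violations`, and `verify` are illustrative names, not the `torch._decomp` or `torch._export` API, and ops are plain strings rather than graph nodes.

```python
"""Toy model of a core-op verifier and a decomposition table.

Every name here is hypothetical; this is not the torch._decomp or
torch._export API, only an illustration of the mechanism.
"""
from typing import Callable, Dict, List

# Hypothetical stand-in for the set of ops the verifier treats as core.
CORE_OPS = {"_log_softmax", "gather", "neg", "ne", "where", "mul", "sum", "div"}

# Hypothetical decomposition table: op name -> function returning the
# sequence of simpler ops that replaces it.
DecompTable = Dict[str, Callable[[], List[str]]]


def nll_loss2d_forward_decomp() -> List[str]:
    # Rough stand-in for a real decomposition: pick each pixel's
    # target-class log-probability, mask ignored pixels, weight, reduce.
    return ["gather", "neg", "ne", "where", "mul", "sum", "div"]


def decompose(graph: List[str], table: DecompTable) -> List[str]:
    """Recursively replace every op that has an entry in `table`."""
    lowered: List[str] = []
    for op in graph:
        if op in table:
            lowered.extend(decompose(table[op](), table))
        else:
            lowered.append(op)
    return lowered


def find_violations(graph: List[str]) -> List[str]:
    """Return the ops a core-only verifier would reject, in graph order."""
    return [op for op in graph if op not in CORE_OPS]


def verify(graph: List[str]) -> None:
    """Raise if any op in the graph is outside the core set."""
    bad = find_violations(graph)
    if bad:
        raise ValueError(f"Operator {bad[0]} is not Aten Canonical.")


if __name__ == "__main__":
    # Toy op sequence for a cross-entropy loss on 4-D input.
    graph = ["_log_softmax", "nll_loss2d_forward"]
    print(find_violations(graph))  # ['nll_loss2d_forward']

    # Once the op has an entry in the table, the lowered graph verifies.
    table: DecompTable = {"nll_loss2d_forward": nll_loss2d_forward_decomp}
    lowered = decompose(graph, table)
    verify(lowered)
    print(lowered)
```

The point of the toy is the ordering: the verifier only sees what remains after decomposition, so an op that already has a decomposition still fails verification until that decomposition is included in the table that is actually applied.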
| File |
|---|
| `__init__.py` |
| `decompositions_for_jvp.py` |
| `decompositions_for_rng.py` |
| `decompositions.py` |
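A decomposition of `nll_loss2d_forward` has to reproduce the op's numerics using simpler ops. As a plain-Python reference for those numerics, the sketch below assumes `reduction="mean"`, optional per-class weights, and PyTorch's default `ignore_index=-100`. The function name `nll_loss2d_forward_ref` is hypothetical, nested lists stand in for tensors, and it is not the decomposition found in `decompositions.py`.

```python
import math
from typing import List, Optional, Tuple


def nll_loss2d_forward_ref(
    log_probs: List[List[List[List[float]]]],  # shape [N][C][H][W]
    target: List[List[List[int]]],             # shape [N][H][W]
    weight: Optional[List[float]] = None,      # shape [C], defaults to all ones
    ignore_index: int = -100,
) -> Tuple[float, float]:
    """Hypothetical reference: return (loss, total_weight) for mean reduction."""
    num_classes = len(log_probs[0])
    w = weight if weight is not None else [1.0] * num_classes
    total = 0.0
    total_weight = 0.0
    for n, plane in enumerate(target):
        for h, row in enumerate(plane):
            for x, cls in enumerate(row):
                if cls == ignore_index:
                    continue  # ignored pixels add nothing to the sum or the weight
                total += -w[cls] * log_probs[n][cls][h][x]
                total_weight += w[cls]
    # Mean reduction divides by the summed weight, not the pixel count;
    # if every pixel is ignored the result is NaN.
    loss = total / total_weight if total_weight else math.nan
    return loss, total_weight
```

For example, with one image, two classes, and a 1x2 spatial grid, targets `[[[0, 1]]]` pick log-probabilities `-0.1` and `-0.2`, giving a loss of about `0.15` and a total weight of `2.0`.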