Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13002
The CPU implementation did not handle the batch dimension, so it failed for inputs with N > 1.
This change fixes that.
Differential Revision: D10515159
fbshipit-source-id: ee7e4f489d2d4de793f550b31db7c0e2ba3651e8