[functorch] Update neural_tangent_kernels.ipynb (pytorch/functorch#788)

Fix a small bug
```python3
    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM')        # should be torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK')        # should be torch.einsum('NMKK->NMK', result)
```
This commit is contained in:
Local State 2022-05-09 23:25:55 -04:00 committed by Jon Janzen
parent 126fd93c21
commit ff9558a2ea

View File

@ -288,9 +288,9 @@
" if compute == 'full':\n",
" return result\n",
" if compute == 'trace':\n",
" return torch.einsum('NMKK->NM')\n",
" return torch.einsum('NMKK->NM', result)\n",
" if compute == 'diagonal':\n",
" return torch.einsum('NMKK->NMK')"
" return torch.einsum('NMKK->NMK', result)"
]
},
{