mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[functorch] Update neural_tangent_kernels.ipynb (pytorch/functorch#788)
Fix a small bug
```python3
if compute == 'full':
return result
if compute == 'trace':
return torch.einsum('NMKK->NM') # should be torch.einsum('NMKK->NM', result)
if compute == 'diagonal':
return torch.einsum('NMKK->NMK') # should be torch.einsum('NMKK->NMK', result)
```
This commit is contained in:
parent
126fd93c21
commit
ff9558a2ea
|
|
@ -288,9 +288,9 @@
|
|||
" if compute == 'full':\n",
|
||||
" return result\n",
|
||||
" if compute == 'trace':\n",
|
||||
" return torch.einsum('NMKK->NM')\n",
|
||||
" return torch.einsum('NMKK->NM', result)\n",
|
||||
" if compute == 'diagonal':\n",
|
||||
" return torch.einsum('NMKK->NMK')"
|
||||
" return torch.einsum('NMKK->NMK', result)"
|
||||
]
|
||||
},
|
||||
{
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user