mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add simple correctness check for native MHA (#72941)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72941
Simple test for MHA, use cos similarity as metric since scaling generate mismatch. Cuda is validated, CPU fix a following (We can land this with onlyCuda flag, and remove it once CPU is also done)
Test Plan:
For cuda:
buck build mode/opt -c fbcode.enable_gpu_sections=true caffe2/test:nn && buck-out/gen/caffe2/test/nn\#binary.par -r test_native_multihead_attention_cuda_float32 2>&1 | pastry
Reviewed By: swolchok
Differential Revision: D33906921
fbshipit-source-id: ad447401eb7002f22ed533d620a6b544524b3f58
(cherry picked from commit 45b778da27)
This commit is contained in:
parent
b1bd2268f8
commit
f41db99a56
|
|
@ -17533,7 +17533,6 @@ class TestNNDeviceType(NNTestCase):
|
|||
)
|
||||
self.assertEqual(output_non_contig, output_contig)
|
||||
|
||||
|
||||
@onlyCUDA
|
||||
@dtypes(*itertools.product((torch.int, torch.long), (torch.int, torch.long)))
|
||||
def test_embedding_bag_bfloat16(self, device, dtypes):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user