mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Update on "Turn on strict dtype checking for test_torch.py"
Partially addresses #20376 I do this by overriding assertEqual in classes that opt into this. This means I have to fix #33821. The fix is a little unsatisfactory as idiomatic Python 2 super() calls don't work (since the class is no longer in scope); hopefully this will just work when we go to Python 3. General approach taken: - A lot of dtype mismatches are because we specified tensor constants that infer to some dtype, but the actual dtype needed is something else. Those are easy, just annotate the tensor() constructor (often a legacy Tensor/FloatTensor call) with dtype - There are a few cases where the promotion rules are nontrivial. Some of them I just typed out the expected promotion rules manually (based on trial and error) - There are some more complex cases; if it gets too hairy I just set exact_dtype=False and nope the fuck out I don't have time to do it for all the other classes. But the setup should work if people just incrementally add the overrides to classes, and then eventually flip the default. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Differential Revision: [D20125791](https://our.internmc.facebook.com/intern/diff/D20125791) [ghstack-poisoned]
This commit is contained in:
commit
d5f7429756
|
|
@ -14435,8 +14435,8 @@ class TestDevicePrecision(TestCase):
|
|||
torch.uint8)
|
||||
def test_from_sequence(self, device, dtype):
|
||||
seq = [list(range(i * 4, i * 4 + 4)) for i in range(5)]
|
||||
reference = torch.arange(0, 20, dtype=dtype).resize_(5, 4)
|
||||
self.assertEqual(torch.tensor(seq, dtype=dtype, device=device), reference)
|
||||
reference = torch.arange(0, 20).resize_(5, 4)
|
||||
self.assertEqual(torch.tensor(seq, dtype=dtype, device=device), reference, exact_dtype=False)
|
||||
|
||||
def test_cat(self, device):
|
||||
SIZE = 10
|
||||
|
|
@ -15343,9 +15343,9 @@ def generate_test_function(cls,
|
|||
# Compares CPU and device inputs and outputs
|
||||
precision = dtype2precision.get(dtype, float_precision)
|
||||
|
||||
self.assertEqual(cpu_tensor, device_tensor, prec=precision)
|
||||
self.assertEqual(cpu_args, device_args, prec=precision)
|
||||
self.assertEqual(cpu_result, device_result, prec=precision)
|
||||
self.assertEqual(cpu_tensor, device_tensor, prec=precision, exact_dtype=False)
|
||||
self.assertEqual(cpu_args, device_args, prec=precision, exact_dtype=False)
|
||||
self.assertEqual(cpu_result, device_result, prec=precision, exact_dtype=False)
|
||||
|
||||
test_name = "test_" + op_str + subtest_str
|
||||
assert not hasattr(cls, test_name), "{0} already in TestDevicePrecision".format(test_name)
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user