Update on "Turn on strict dtype checking for test_torch.py"

Partially addresses #20376

I do this by overriding assertEqual in classes that opt into
strict checking, which means I also have to fix #33821.  The fix
is a little unsatisfactory, as idiomatic Python 2 super() calls
don't work (the class is no longer in scope at the call site);
hopefully this will just work once we move to Python 3.
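
The opt-in shape of the override can be sketched roughly as below.  This is
an illustrative stand-in, not the actual test_torch.py code: the Value class
fakes a tensor's dtype attribute, and the class names are invented.

```python
import unittest

class Value:
    """Tiny stand-in for a tensor: carries a payload and a dtype string."""
    def __init__(self, value, dtype):
        self.value = value
        self.dtype = dtype

class StrictDtypeTestCase(unittest.TestCase):
    """Opting in = subclassing this; assertEqual is strict by default."""
    def assertEqual(self, x, y, *args, exact_dtype=True, **kwargs):
        if exact_dtype and hasattr(x, "dtype") and hasattr(y, "dtype"):
            # Fail on dtype mismatch before comparing values.  Python 3's
            # zero-argument super() works here; the Python 2 spelling
            # needs the class name in scope, which is the #33821 wrinkle.
            super().assertEqual(x.dtype, y.dtype)
        vx = x.value if isinstance(x, Value) else x
        vy = y.value if isinstance(y, Value) else y
        super().assertEqual(vx, vy, *args, **kwargs)

class Demo(StrictDtypeTestCase):
    def test_strict_default_flags_dtype_mismatch(self):
        with self.assertRaises(AssertionError):
            self.assertEqual(Value(1.0, "float32"), Value(1.0, "float64"))

    def test_exact_dtype_false_compares_values_only(self):
        self.assertEqual(Value(1.0, "float32"), Value(1.0, "float64"),
                         exact_dtype=False)
```

Classes that have not opted in keep the base assertEqual, so the stricter
default can spread incrementally.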

General approach taken:
- A lot of dtype mismatches arise because a tensor constant infers
  to one dtype while the test actually needs another.  Those are
  easy: annotate the tensor() constructor (often a legacy
  Tensor/FloatTensor call) with an explicit dtype.
- In a few cases the promotion rules are nontrivial.  For some of
  these I wrote out the expected promotion rules manually (based on
  trial and error).
- Some cases are more complex still; when it gets too hairy I just
  set exact_dtype=False and move on.
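
The "write out the promotion rules manually" bullet can be sketched like
this.  The table and helper below are a toy illustration, not PyTorch's
actual promotion logic (the entries happen to match common PyTorch
behavior, but they were chosen by hand, mirroring the trial-and-error
approach described above):

```python
# Hand-recorded promotion results for the dtype pairs a test exercises.
# Illustrative only; not derived from PyTorch's real promotion tables.
PROMOTE = {
    ("int64", "float32"): "float32",
    ("uint8", "int64"): "int64",
    ("float32", "float64"): "float64",
}

def expected_dtype(a, b):
    """Look up the manually recorded promotion result for a dtype pair."""
    if a == b:
        return a
    # Promotion is symmetric, so check the pair in either order.
    return PROMOTE.get((a, b)) or PROMOTE.get((b, a))
```

A test can then assert that a result's dtype equals
expected_dtype(x.dtype, y.dtype) while leaving exact checking on.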

I don't have time to do this for all the other test classes, but
the setup should work if people incrementally add the override to
more classes and then eventually flip the default.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Differential Revision: [D20125791](https://our.internmc.facebook.com/intern/diff/D20125791)

[ghstack-poisoned]
This commit is contained in:
Edward Z. Yang 2020-02-27 12:13:21 -08:00
commit d5f7429756


@@ -14435,8 +14435,8 @@ class TestDevicePrecision(TestCase):
                     torch.uint8)
     def test_from_sequence(self, device, dtype):
         seq = [list(range(i * 4, i * 4 + 4)) for i in range(5)]
-        reference = torch.arange(0, 20, dtype=dtype).resize_(5, 4)
-        self.assertEqual(torch.tensor(seq, dtype=dtype, device=device), reference)
+        reference = torch.arange(0, 20).resize_(5, 4)
+        self.assertEqual(torch.tensor(seq, dtype=dtype, device=device), reference, exact_dtype=False)

     def test_cat(self, device):
         SIZE = 10
@@ -15343,9 +15343,9 @@ def generate_test_function(cls,
         # Compares CPU and device inputs and outputs
         precision = dtype2precision.get(dtype, float_precision)
-        self.assertEqual(cpu_tensor, device_tensor, prec=precision)
-        self.assertEqual(cpu_args, device_args, prec=precision)
-        self.assertEqual(cpu_result, device_result, prec=precision)
+        self.assertEqual(cpu_tensor, device_tensor, prec=precision, exact_dtype=False)
+        self.assertEqual(cpu_args, device_args, prec=precision, exact_dtype=False)
+        self.assertEqual(cpu_result, device_result, prec=precision, exact_dtype=False)
         test_name = "test_" + op_str + subtest_str
         assert not hasattr(cls, test_name), "{0} already in TestDevicePrecision".format(test_name)