[SobolEngine] Fix edge case of dtype of first sample (#51578)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/51578

https://github.com/pytorch/pytorch/pull/49710 introduced an edge case in which
drawing a single sample resulted in ignoring the `dtype` arg to `draw`. This
fixes this and adds a unit test to cover this behavior.

Test Plan: Unit tests

Reviewed By: danielrjiang

Differential Revision: D26204393

fbshipit-source-id: 441a44dc035002e7bbe6b662bf6d1af0e2cd88f4
This commit is contained in:
Max Balandat 2021-02-02 14:22:35 -08:00 committed by Facebook GitHub Bot
parent 4746b3d1fb
commit a990ff7001
2 changed files with 13 additions and 1 deletions

View File

@ -1566,6 +1566,18 @@ class AbstractTestCases:
def test_sobolengine_draw_scrambled(self):
self.test_sobolengine_draw(scramble=True)
def test_sobolengine_first_point(self):
for dtype in (torch.float, torch.double):
engine = torch.quasirandom.SobolEngine(2, scramble=False)
sample = engine.draw(1, dtype=dtype)
self.assertTrue(torch.all(sample == 0))
self.assertEqual(sample.dtype, dtype)
for dtype in (torch.float, torch.double):
engine = torch.quasirandom.SobolEngine(2, scramble=True, seed=123456)
sample = engine.draw(1, dtype=dtype)
self.assertTrue(torch.all(sample != 0))
self.assertEqual(sample.dtype, dtype)
def test_sobolengine_continuing(self, scramble: bool = False):
ref_sample = self._sobol_reference_samples(scramble=scramble)
engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)

View File

@ -83,7 +83,7 @@ class SobolEngine(object):
"""
if self.num_generated == 0:
if n == 1:
result = self._first_point
result = self._first_point.to(dtype)
else:
result, self.quasi = torch._sobol_engine_draw(
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,