mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[SobolEngine] Fix edge case of dtype of first sample (#51578)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/51578 https://github.com/pytorch/pytorch/pull/49710 introduced an edge case in which drawing a single sample resulted in ignoring the `dtype` arg to `draw`. This fixes this and adds a unit test to cover this behavior. Test Plan: Unit tests Reviewed By: danielrjiang Differential Revision: D26204393 fbshipit-source-id: 441a44dc035002e7bbe6b662bf6d1af0e2cd88f4
This commit is contained in:
parent
4746b3d1fb
commit
a990ff7001
|
|
@ -1566,6 +1566,18 @@ class AbstractTestCases:
|
|||
def test_sobolengine_draw_scrambled(self):
|
||||
self.test_sobolengine_draw(scramble=True)
|
||||
|
||||
def test_sobolengine_first_point(self):
|
||||
for dtype in (torch.float, torch.double):
|
||||
engine = torch.quasirandom.SobolEngine(2, scramble=False)
|
||||
sample = engine.draw(1, dtype=dtype)
|
||||
self.assertTrue(torch.all(sample == 0))
|
||||
self.assertEqual(sample.dtype, dtype)
|
||||
for dtype in (torch.float, torch.double):
|
||||
engine = torch.quasirandom.SobolEngine(2, scramble=True, seed=123456)
|
||||
sample = engine.draw(1, dtype=dtype)
|
||||
self.assertTrue(torch.all(sample != 0))
|
||||
self.assertEqual(sample.dtype, dtype)
|
||||
|
||||
def test_sobolengine_continuing(self, scramble: bool = False):
|
||||
ref_sample = self._sobol_reference_samples(scramble=scramble)
|
||||
engine = torch.quasirandom.SobolEngine(2, scramble=scramble, seed=123456)
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class SobolEngine(object):
|
|||
"""
|
||||
if self.num_generated == 0:
|
||||
if n == 1:
|
||||
result = self._first_point
|
||||
result = self._first_point.to(dtype)
|
||||
else:
|
||||
result, self.quasi = torch._sobol_engine_draw(
|
||||
self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated, dtype=dtype,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user