Support dropout(nested tensor) (#79318)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79318
Approved by: https://github.com/jbschlosser
Yifan Shen 2022-06-17 00:46:05 +00:00 committed by PyTorch MergeBot
parent 91c5fc323b
commit 1211ab679c
3 changed files with 88 additions and 1 deletion


@@ -240,8 +240,14 @@
- func: _shape_as_tensor(Tensor self) -> Tensor

- func: dropout(Tensor input, float p, bool train) -> Tensor
  dispatch:
    CompositeImplicitAutograd: dropout
    NestedTensorCPU, NestedTensorCUDA: dropout_nested

- func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
  dispatch:
    CompositeImplicitAutograd: dropout_
    NestedTensorCPU, NestedTensorCUDA: dropout_nested_

- func: feature_dropout(Tensor input, float p, bool train) -> Tensor
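
The two dispatch blocks above route aten::dropout and aten::dropout_ to the new nested kernels whenever the input carries the NestedTensorCPU/NestedTensorCUDA dispatch key, while dense tensors keep the existing CompositeImplicitAutograd path. A minimal usage sketch, assuming the era-appropriate torch.nested_tensor constructor used elsewhere in this commit:

import torch
import torch.nn.functional as F

# Two ragged components; F.dropout on a nested tensor now
# dispatches to dropout_nested instead of the composite kernel.
nt = torch.nested_tensor([torch.randn(2, 5), torch.randn(3, 5)])
out = F.dropout(nt, p=0.5, training=True)

print(out.is_nested)                # True: ragged structure is preserved
print(out[0].shape, out[1].shape)   # torch.Size([2, 5]) torch.Size([3, 5])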


@@ -647,5 +647,19 @@ at::Tensor NestedTensor_get_nested_size_tensor(const at::Tensor& self){
  return get_nested_size_tensor(self);
}
Tensor dropout_nested(const Tensor& input, double p, bool train) {
  auto input_ptr = get_nested_tensor_impl(input);
  const Tensor& input_buffer = input_ptr->get_buffer();
  const Tensor& sizemat = input_ptr->get_nested_size_tensor();
  // Dropout is elementwise, so it can run once over the flat buffer;
  // the result is rewrapped with a copy of the nested size metadata.
  Tensor output_buffer = at::dropout(input_buffer, p, train);
  return wrap_buffer(output_buffer, sizemat.clone());
}

Tensor& dropout_nested_(Tensor& input, double p, bool train) {
  // In-place variant: mutate the underlying buffer directly.
  Tensor input_buffer = get_buffer(input);
  at::dropout_(input_buffer, p, train);
  return input;
}
} // namespace native
} // namespace at
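
Both kernels lean on the nested tensor storage model: one contiguous buffer plus a size matrix describing the ragged components. The sketch below emulates that buffer trick with plain dense tensors; the list/offset bookkeeping is illustrative only, not the ATen API.

import torch

# Store the components as one flat buffer plus their shapes, run
# dropout once over the buffer, then re-split by the stored shapes.
components = [torch.randn(2, 3), torch.randn(4, 3)]
buffer = torch.cat([c.reshape(-1) for c in components])  # flat storage
sizes = [c.shape for c in components]                     # "nested size tensor"

dropped = torch.dropout(buffer, 0.2, True)                # one elementwise op

out, offset = [], 0
for shape in sizes:
    n = shape.numel()
    out.append(dropped[offset:offset + n].reshape(shape))
    offset += n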


@@ -9,7 +9,7 @@ from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    skipMeta,
)
from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state
from torch import nested_tensor

# Tests are ported from pytorch/nestedtensor.
@@ -482,6 +482,73 @@ class TestNestedTensorDeviceType(TestCase):
        with self.assertRaisesRegex(RuntimeError, msg):
            nt1.clone(memory_format=torch.channels_last)
    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
    @dtypes(torch.float, torch.double)
    def test_dropout(self, device, dtype):
        # edge case: empty nested tensor
        nt0 = torch.nested_tensor([])
        y = torch.nn.functional.dropout(nt0, 0.5)
        self.nt_equal(nt0, y)
        # normal nested tensor
        ntensors = 4
        nt = self.random_nt(device, dtype, ntensors, (4, 4))
        # edge case: invalid dropout probability
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
        # edge case: no dropout
        dropouter = torch.nn.Dropout(0.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 0.0)
        self.nt_equal(nt, y0)
        self.nt_equal(nt, y1)
        # edge case: all dropout
        dropouter = torch.nn.Dropout(1.0)
        y0 = dropouter(nt)
        y1 = torch.nn.functional.dropout(nt, 1.0)
        nt0 = nt.clone()
        for i in range(ntensors):
            nt0[i].fill_(0.0)
        self.nt_equal(nt0, y0)
        self.nt_equal(nt0, y1)
        # normal case: dropout with 0 < p < 1
        p = 0.2
        y = torch.nn.functional.dropout(nt, p)
        expect = nt.clone()
        for i in range(ntensors):
            actual_tensor = y[i].view(-1)
            expect_tensor = expect[i].view(-1)
            for j in range(actual_tensor.shape[0]):
                if actual_tensor[j].item() == 0.0:
                    expect_tensor[j] = 0.0
                else:
                    expect_tensor[j] /= 1.0 - p
        self.nt_equal(y, expect)
        with freeze_rng_state():
            dropouter = torch.nn.Dropout(p)
            y0 = dropouter(nt)
        with freeze_rng_state():
            y1 = torch.nn.functional.dropout(nt, p)
        self.nt_equal(y0, y1)
        # inplace
        # Since the functional path is verified above, we could in principle just
        # compare inplace against functional. In practice the CUDA functional path
        # has its own implementation that skips `bernoulli_`, so it draws different
        # random numbers than the inplace path; a direct comparison fails in
        # `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
        # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`.
        # So verify inplace against the original values instead.
        expect = nt.clone()
        torch.nn.functional.dropout(nt, p, inplace=True)
        for i in range(ntensors):
            actual_tensor = nt[i].view(-1)
            expect_tensor = expect[i].view(-1)
            for j in range(actual_tensor.shape[0]):
                if actual_tensor[j].item() == 0.0:
                    expect_tensor[j] = 0.0
                else:
                    expect_tensor[j] /= 1.0 - p
        self.nt_equal(nt, expect)
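
The element-by-element checks in this test encode the inverted-dropout convention: in training mode every surviving element is scaled by 1/(1 - p), so the layer preserves the expected value. A minimal standalone illustration on a dense tensor:

import torch

x = torch.full((1000,), 2.0)
p = 0.2
y = torch.nn.functional.dropout(x, p=p, training=True)

# Survivors are exactly x / (1 - p); zeros are the dropped elements.
survivors = y[y != 0]
assert torch.allclose(survivors, torch.full_like(survivors, 2.0 / (1 - p)))
# In expectation the activation is unchanged: E[y] = (1 - p) * x / (1 - p) = x.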
class TestNestedTensorAutograd(TestCase):
    def nt_equal(self, nt1, nt2):
        self.assertEqual(nt1.dtype, nt2.dtype)