Support dropout(nested tensor) (#79318)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/79318
Approved by: https://github.com/jbschlosser
parent 91c5fc323b
commit 1211ab679c
@@ -240,8 +240,14 @@
 - func: _shape_as_tensor(Tensor self) -> Tensor
 
 - func: dropout(Tensor input, float p, bool train) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: dropout
+    NestedTensorCPU, NestedTensorCUDA: dropout_nested
 
 - func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+  dispatch:
+    CompositeImplicitAutograd: dropout_
+    NestedTensorCPU, NestedTensorCUDA: dropout_nested_
 
 - func: feature_dropout(Tensor input, float p, bool train) -> Tensor
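These schema entries appear to come from native_functions.yaml (the `- func:` declarations and dispatch keys suggest so): the existing CompositeImplicitAutograd path is kept for ordinary tensors, while nested-tensor inputs are routed to the new dropout_nested / dropout_nested_ kernels. A minimal usage sketch under that assumption, using only calls exercised by the tests in this commit (the constituent shapes are illustrative):

import torch

# Two constituents of different lengths; dtype/device left at the defaults.
nt = torch.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])

# Both the functional and the module form now accept a nested tensor and
# dispatch to dropout_nested for NestedTensorCPU / NestedTensorCUDA inputs.
y0 = torch.nn.functional.dropout(nt, p=0.2, training=True)
y1 = torch.nn.Dropout(p=0.2)(nt)

# The in-place variant dispatches to dropout_nested_.
torch.nn.functional.dropout(nt, p=0.2, training=True, inplace=True)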
@@ -647,5 +647,19 @@ at::Tensor NestedTensor_get_nested_size_tensor(const at::Tensor& self){
   return get_nested_size_tensor(self);
 }
 
+Tensor dropout_nested(const Tensor& input, double p, bool train) {
+  auto input_ptr = get_nested_tensor_impl(input);
+  const Tensor& input_buffer = input_ptr->get_buffer(),
+      sizemat = input_ptr->get_nested_size_tensor();
+  Tensor output_buffer = at::dropout(input_buffer, p, train);
+  return wrap_buffer(output_buffer, sizemat.clone());
+}
+
+Tensor& dropout_nested_(Tensor& input, double p, bool train) {
+  Tensor input_buffer = get_buffer(input);
+  at::dropout_(input_buffer, p, train);
+  return input;
+}
+
 } // namespace native
 } // namespace at
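dropout_nested above applies ordinary at::dropout to the nested tensor's flat buffer and rewraps the result with a clone of the input's nested size matrix, so the output keeps the input's nested structure; the in-place variant mutates the buffer and returns the same nested tensor. A small sketch of that invariant from the Python side (shapes are illustrative, and the constituent count is tracked by hand rather than queried from the nested tensor):

import torch

ts = [torch.randn(2, 3), torch.randn(4, 3)]
nt = torch.nested_tensor(ts)
y = torch.nn.functional.dropout(nt, p=0.5, training=True)

# Dropout acts element-wise on the underlying buffer, so the number of
# constituents and each constituent's shape are unchanged in the output.
for i in range(len(ts)):
    assert y[i].shape == nt[i].shape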
@@ -9,7 +9,7 @@ from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     skipMeta,
 )
-from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests
+from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state
 from torch import nested_tensor
 
 # Tests are ported from pytorch/nestedtensor.
@@ -482,6 +482,73 @@ class TestNestedTensorDeviceType(TestCase):
         with self.assertRaisesRegex(RuntimeError, msg):
             nt1.clone(memory_format=torch.channels_last)
 
+    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
+    @dtypes(torch.float, torch.double)
+    def test_dropout(self, device, dtype):
+        # edge case: empty nested tensor
+        nt0 = torch.nested_tensor([])
+        y = torch.nn.functional.dropout(nt0, 0.5)
+        self.nt_equal(nt0, y)
+        # normal nested tensor
+        ntensors = 4
+        nt = self.random_nt(device, dtype, ntensors, (4, 4))
+        # edge case: invalid dropout
+        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
+        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
+        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
+        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
+        # edge case: no dropout
+        dropouter = torch.nn.Dropout(0.0)
+        y0 = dropouter(nt)
+        y1 = torch.nn.functional.dropout(nt, 0.0)
+        self.nt_equal(nt, y0)
+        self.nt_equal(nt, y1)
+        # edge case: all dropout
+        dropouter = torch.nn.Dropout(1.0)
+        y0 = dropouter(nt)
+        y1 = torch.nn.functional.dropout(nt, 1.0)
+        nt0 = nt.clone()
+        for i in range(ntensors):
+            nt0[i].fill_(0.0)
+        self.nt_equal(nt0, y0)
+        self.nt_equal(nt0, y1)
+        # normal case: normal dropout
+        p = 0.2
+        y = torch.nn.functional.dropout(nt, p)
+        expect = nt.clone()
+        for i in range(ntensors):
+            actual_tensor = y[i].view(-1)
+            expect_tensor = expect[i].view(-1)
+            for j in range(actual_tensor.shape[0]):
+                if actual_tensor[j].item() == 0.0:
+                    expect_tensor[j] = 0.0
+                else:
+                    expect_tensor[j] /= 1.0 - p
+        self.nt_equal(y, expect)
+        with freeze_rng_state():
+            dropouter = torch.nn.Dropout(p)
+            y0 = dropouter(nt)
+        with freeze_rng_state():
+            y1 = torch.nn.functional.dropout(nt, p)
+        self.nt_equal(y0, y1)
+        # inplace
+        # in principle, since we have established the correctness of functional, we could simply compare inplace vs functional
+        # in practice, cuda functional has its own implementation to skip `bernoulli_`
+        # so cuda functional will differ from cuda inplace causing test failure
+        # in `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
+        # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`
+        expect = nt.clone()
+        torch.nn.functional.dropout(nt, p, inplace=True)
+        for i in range(ntensors):
+            actual_tensor = nt[i].view(-1)
+            expect_tensor = expect[i].view(-1)
+            for j in range(actual_tensor.shape[0]):
+                if actual_tensor[j].item() == 0.0:
+                    expect_tensor[j] = 0.0
+                else:
+                    expect_tensor[j] /= 1.0 - p
+        self.nt_equal(nt, expect)
+
 class TestNestedTensorAutograd(TestCase):
     def nt_equal(self, nt1, nt2):
         self.assertEqual(nt1.dtype, nt2.dtype)
|||
Loading…
Reference in New Issue
Block a user