Support dropout(nested tensor) (#79318)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/79318
Approved by: https://github.com/jbschlosser

Parent: 91c5fc323b
Commit: 1211ab679c
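In short: torch.nn.functional.dropout (and torch.nn.Dropout) now accept nested tensors by dispatching to dedicated nested kernels. A minimal usage sketch, assuming a build that includes this commit (torch.nested_tensor was the constructor at the time):

    import torch

    # a nested tensor holding two differently-sized components
    nt = torch.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])
    y = torch.nn.functional.dropout(nt, p=0.5, training=True)             # out-of-place
    torch.nn.functional.dropout(nt, p=0.5, training=True, inplace=True)   # in-place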
native_functions.yaml:

@@ -240,8 +240,14 @@
 - func: _shape_as_tensor(Tensor self) -> Tensor
 
 - func: dropout(Tensor input, float p, bool train) -> Tensor
+  dispatch:
+    CompositeImplicitAutograd: dropout
+    NestedTensorCPU, NestedTensorCUDA: dropout_nested
 
 - func: dropout_(Tensor(a!) self, float p, bool train) -> Tensor(a!)
+  dispatch:
+    CompositeImplicitAutograd: dropout_
+    NestedTensorCPU, NestedTensorCUDA: dropout_nested_
 
 - func: feature_dropout(Tensor input, float p, bool train) -> Tensor
 
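The two new dispatch entries keep dense tensors on the existing CompositeImplicitAutograd path while routing nested inputs to dropout_nested / dropout_nested_. A quick way to observe the routing, assuming Tensor.is_nested is available in the build:

    import torch

    dense = torch.randn(4, 4)
    nt = torch.nested_tensor([torch.randn(2, 3), torch.randn(4, 3)])

    # dense input: unchanged composite path, dense output
    assert not torch.nn.functional.dropout(dense, 0.5).is_nested
    # nested input: new NestedTensorCPU/NestedTensorCUDA kernel, nested output
    assert torch.nn.functional.dropout(nt, 0.5).is_nested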
NestedTensorMath.cpp:

@@ -647,5 +647,19 @@ at::Tensor NestedTensor_get_nested_size_tensor(const at::Tensor& self){
   return get_nested_size_tensor(self);
 }
 
+Tensor dropout_nested(const Tensor& input, double p, bool train) {
+  auto input_ptr = get_nested_tensor_impl(input);
+  const Tensor& input_buffer = input_ptr->get_buffer(),
+      sizemat = input_ptr->get_nested_size_tensor();
+  Tensor output_buffer = at::dropout(input_buffer, p, train);
+  return wrap_buffer(output_buffer, sizemat.clone());
+}
+
+Tensor& dropout_nested_(Tensor& input, double p, bool train) {
+  Tensor input_buffer = get_buffer(input);
+  at::dropout_(input_buffer, p, train);
+  return input;
+}
+
 } // namespace native
 } // namespace at
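The kernel is this simple because a nested tensor is stored as one flat buffer plus a size matrix, and dropout is elementwise, so applying dropout to the buffer drops the right elements of every component. A conceptual Python model of the same idea (a sketch around a hypothetical dropout_nested_sketch helper, not the actual ATen code):

    import torch

    def dropout_nested_sketch(components, p, training=True):
        # "buffer": every component flattened into one contiguous vector
        buffer = torch.cat([t.reshape(-1) for t in components])
        out = torch.nn.functional.dropout(buffer, p, training)
        # re-split using the stored sizes (the "nested size tensor")
        chunks = out.split([t.numel() for t in components])
        return [c.view(t.shape) for c, t in zip(chunks, components)]

    parts = [torch.randn(2, 3), torch.randn(4, 3)]
    out = dropout_nested_sketch(parts, p=0.5)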
test_nestedtensor.py:

@@ -9,7 +9,7 @@ from torch.testing._internal.common_device_type import (
     instantiate_device_type_tests,
     skipMeta,
 )
-from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests
+from torch.testing._internal.common_utils import TestCase, IS_FBCODE, run_tests, freeze_rng_state
 from torch import nested_tensor
 
 # Tests are ported from pytorch/nestedtensor.
@@ -482,6 +482,73 @@ class TestNestedTensorDeviceType(TestCase):
         with self.assertRaisesRegex(RuntimeError, msg):
             nt1.clone(memory_format=torch.channels_last)
 
+    # cannot test torch.float16 because: RuntimeError: "bernoulli_scalar_cpu_" not implemented for 'Half'
+    @dtypes(torch.float, torch.double)
+    def test_dropout(self, device, dtype):
+        # edge case: empty nested tensor
+        nt0 = torch.nested_tensor([])
+        y = torch.nn.functional.dropout(nt0, 0.5)
+        self.nt_equal(nt0, y)
+        # normal nested tensor
+        ntensors = 4
+        nt = self.random_nt(device, dtype, ntensors, (4, 4))
+        # edge case: invalid dropout
+        self.assertRaises(ValueError, lambda: torch.nn.Dropout(-0.1))
+        self.assertRaises(ValueError, lambda: torch.nn.Dropout(1.1))
+        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, -0.1))
+        self.assertRaises(ValueError, lambda: torch.nn.functional.dropout(nt, 1.1))
+        # edge case: no dropout
+        dropouter = torch.nn.Dropout(0.0)
+        y0 = dropouter(nt)
+        y1 = torch.nn.functional.dropout(nt, 0.0)
+        self.nt_equal(nt, y0)
+        self.nt_equal(nt, y1)
+        # edge case: all dropout
+        dropouter = torch.nn.Dropout(1.0)
+        y0 = dropouter(nt)
+        y1 = torch.nn.functional.dropout(nt, 1.0)
+        nt0 = nt.clone()
+        for i in range(ntensors):
+            nt0[i].fill_(0.0)
+        self.nt_equal(nt0, y0)
+        self.nt_equal(nt0, y1)
+        # normal case: normal dropout
+        p = 0.2
+        y = torch.nn.functional.dropout(nt, p)
+        expect = nt.clone()
+        for i in range(ntensors):
+            actual_tensor = y[i].view(-1)
+            expect_tensor = expect[i].view(-1)
+            for j in range(actual_tensor.shape[0]):
+                if actual_tensor[j].item() == 0.0:
+                    expect_tensor[j] = 0.0
+                else:
+                    expect_tensor[j] /= 1.0 - p
+        self.nt_equal(y, expect)
+        with freeze_rng_state():
+            dropouter = torch.nn.Dropout(p)
+            y0 = dropouter(nt)
+        with freeze_rng_state():
+            y1 = torch.nn.functional.dropout(nt, p)
+        self.nt_equal(y0, y1)
+        # inplace
+        # in principle, since we have established the correctness of functional, we could simply compare inplace vs functional
+        # in practice, cuda functional has its own implementation to skip `bernoulli_`
+        # so cuda functional will differ from cuda inplace causing test failure
+        # in `test_dropout_cuda_float64 (__main__.TestNestedTensorDeviceTypeCUDA)`
+        # on `linux-xenial-cuda11.3-py3.7-gcc7 / test (default, 2, 4, linux.4xlarge.nvidia.gpu)`
+        expect = nt.clone()
+        torch.nn.functional.dropout(nt, p, inplace=True)
+        for i in range(ntensors):
+            actual_tensor = nt[i].view(-1)
+            expect_tensor = expect[i].view(-1)
+            for j in range(actual_tensor.shape[0]):
+                if actual_tensor[j].item() == 0.0:
+                    expect_tensor[j] = 0.0
+                else:
+                    expect_tensor[j] /= 1.0 - p
+        self.nt_equal(nt, expect)
+
 class TestNestedTensorAutograd(TestCase):
     def nt_equal(self, nt1, nt2):
         self.assertEqual(nt1.dtype, nt2.dtype)
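The freeze_rng_state blocks in the test are what make the nn.Dropout and functional comparisons deterministic: the context manager saves the global RNG state on entry and restores it on exit, so both calls draw the same Bernoulli mask. A standalone sketch of that behavior, using the same internal test utility the diff imports:

    import torch
    from torch.testing._internal.common_utils import freeze_rng_state

    x = torch.randn(8)
    with freeze_rng_state():
        a = torch.nn.functional.dropout(x, 0.5)
    with freeze_rng_state():
        b = torch.nn.functional.dropout(x, 0.5)
    assert torch.equal(a, b)  # identical RNG state -> identical dropout mask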