# mypy: allow-untyped-defs
import torch

__all__ = ["Dropout"]


class Dropout(torch.nn.Dropout):
    r"""This is the quantized equivalent of :class:`~torch.nn.Dropout`.

    It acts as a placeholder so that models whose fp32 tensors went through
    dropout also work with quantized tensors in train and eval mode.

    Args:
        p: probability of an element to be zeroed
        inplace: can optionally do the operation in-place. Default: ``False``
    """

    def forward(self, input):
        # Quantized dropout is a no-op: the input (typically a quantized
        # tensor) is returned unchanged in both train and eval mode.
        return input

    def _get_name(self):
        return "QuantizedDropout"

    @classmethod
    def from_float(cls, mod, use_precomputed_fake_quant=False):
        # Build a quantized Dropout from a float torch.nn.Dropout, copying
        # its `p` and `inplace` settings.
        return cls(mod.p, mod.inplace)

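    # Conversion sketch (illustrative, not part of the original file): the
    # eager-mode convert step calls `from_float` on an fp32 dropout module,
    # for example:
    #
    #     float_mod = torch.nn.Dropout(p=0.2)
    #     qmod = Dropout.from_float(float_mod)  # -> Dropout(p=0.2, inplace=False)
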
    @classmethod
    def from_reference(cls, mod, scale, zero_point):
        # Build a quantized Dropout from a reference module; `scale` and
        # `zero_point` are accepted but unused, since dropout does not
        # change the quantization parameters of its input.
        return cls(mod.p, mod.inplace)
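

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original file). Quantized
# Dropout passes quantized tensors through unchanged; the snippet below
# assumes the standard `torch.ao.nn.quantized` import path and
# `torch.quantize_per_tensor`, both available in mainline PyTorch.
#
#     import torch
#     from torch.ao.nn.quantized import Dropout as QuantizedDropout
#
#     qdrop = QuantizedDropout(p=0.5)
#     x = torch.randn(2, 3)
#     qx = torch.quantize_per_tensor(x, scale=0.1, zero_point=0,
#                                    dtype=torch.quint8)
#     out = qdrop(qx)  # identity: `out` is `qx`, unchanged
# ---------------------------------------------------------------------------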