[ao_migration] torch/nn/quantized: torch.quantization -> torch.ao.quantization (#65900)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65900

This changes the imports in `caffe2/torch/nn/quantized` to use the new import locations.

```
codemod -d torch/nn/quantized --extensions py 'torch.quantization' 'torch.ao.quantization'
```

Test Plan: `python test/run_test.py`

Reviewed By: jerryzh168

Differential Revision: D31301193

fbshipit-source-id: 58efb1ad51a8b441e2a3bd5b91af11eab6b9331f
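The mechanical effect of the codemod is a pure namespace swap on each touched import. A minimal before/after sketch, using one of the imports actually changed in this diff:

```python
# Before the codemod (old namespace):
# from torch.quantization.qconfig import default_dynamic_qconfig

# After the codemod (new torch.ao namespace), as applied throughout this diff:
from torch.ao.quantization.qconfig import default_dynamic_qconfig
```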
This commit is contained in:
parent
f1f3bd8c36
commit
b23709df03
```diff
@@ -77,7 +77,7 @@ class Linear(nnq.Linear):
         r"""Create a dynamic quantized module from a float module or qparams_dict

         Args:
-            mod (Module): a float module, either produced by torch.quantization
+            mod (Module): a float module, either produced by torch.ao.quantization
                           utilities or provided by the user
         """
         float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
@@ -95,7 +95,7 @@ class Linear(nnq.Linear):
         # We have the circular import issues if we import the qconfig in the beginning of this file:
         # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
         # import until we need it.
-        from torch.quantization.qconfig import default_dynamic_qconfig
+        from torch.ao.quantization.qconfig import default_dynamic_qconfig
         weight_observer = default_dynamic_qconfig.weight()
         dtype = weight_observer.dtype
         assert dtype in [torch.qint8, torch.float16], "The only supported dtypes for " \
```
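For context, a minimal sketch of the user-facing path these hunks sit on: dynamically quantizing a float `Linear`. It assumes `quantize_dynamic` is already exported from `torch.ao.quantization` at this stage of the migration (mirroring the old `torch.quantization.quantize_dynamic`); the toy model is illustrative, not taken from this diff.

```python
import torch
import torch.ao.quantization as tq

# Wrap the float Linear in a container; quantize_dynamic swaps child modules.
float_model = torch.nn.Sequential(torch.nn.Linear(8, 8))

# Converts the Linear to its dynamically quantized counterpart; the conversion
# goes through the from_float classmethod shown in the hunks above.
quantized_model = tq.quantize_dynamic(float_model, {torch.nn.Linear}, dtype=torch.qint8)
print(quantized_model(torch.randn(2, 8)).shape)  # torch.Size([2, 8])
```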
```diff
@@ -214,7 +214,7 @@ class RNNBase(torch.nn.Module):
         # We have the circular import issues if we import the qconfig in the beginning of this file:
         # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
         # import until we need it.
-        from torch.quantization.qconfig import default_dynamic_qconfig
+        from torch.ao.quantization.qconfig import default_dynamic_qconfig
         weight_observer_method = default_dynamic_qconfig.weight

         dtype = weight_observer_method().dtype
@@ -731,7 +731,7 @@ class RNNCellBase(torch.nn.Module):
         # We have the circular import issues if we import the qconfig in the beginning of this file:
         # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
         # import until we need it.
-        from torch.quantization.qconfig import default_dynamic_qconfig
+        from torch.ao.quantization.qconfig import default_dynamic_qconfig
         weight_observer_method = default_dynamic_qconfig.weight

         dtype = weight_observer_method().dtype
```
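The RNN hunks make the same postponed-import fix. The corresponding user-facing path, sketched with an LSTM (again assuming `quantize_dynamic` is available under `torch.ao.quantization`; shapes are illustrative):

```python
import torch
import torch.ao.quantization as tq

# Wrap the float LSTM in a container; quantize_dynamic swaps child modules.
float_model = torch.nn.Sequential(torch.nn.LSTM(input_size=8, hidden_size=8))

# The LSTM child is replaced by its dynamically quantized counterpart, whose
# from_float falls back to the default_dynamic_qconfig weight observer imported
# in the hunks above when no qconfig is set on the module.
quantized_model = tq.quantize_dynamic(float_model, {torch.nn.LSTM}, dtype=torch.qint8)
out, (h, c) = quantized_model(torch.randn(5, 1, 8))  # seq_len=5, batch=1, features=8
```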
```diff
@@ -333,7 +333,7 @@ class Conv1d(_ConvNd):
         r"""Creates a quantized module from a float module or qparams_dict.

         Args:
-            mod (Module): a float module, either produced by torch.quantization
+            mod (Module): a float module, either produced by torch.ao.quantization
                           utilities or provided by the user
         """
         return _ConvNd.from_float(cls, mod)
@@ -430,7 +430,7 @@ class Conv2d(_ConvNd):
         r"""Creates a quantized module from a float module or qparams_dict.

         Args:
-            mod (Module): a float module, either produced by torch.quantization
+            mod (Module): a float module, either produced by torch.ao.quantization
                           utilities or provided by the user
         """
         return _ConvNd.from_float(cls, mod)
@@ -528,7 +528,7 @@ class Conv3d(_ConvNd):
         r"""Creates a quantized module from a float module or qparams_dict.

         Args:
-            mod (Module): a float module, either produced by torch.quantization
+            mod (Module): a float module, either produced by torch.ao.quantization
                           utilities or provided by the user
         """
         return _ConvNd.from_float(cls, mod)
@@ -564,7 +564,7 @@ class _ConvTransposeNd(_ConvNd):
     def from_float(cls, mod):
         r"""Creates a quantized module from a float module or qparams_dict.
         Args:
-            mod (Module): a float module, either produced by torch.quantization
+            mod (Module): a float module, either produced by torch.ao.quantization
                           utilities or provided by the user
         """
         # derived classes override cls._FLOAT_MODULE attribute
```
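The conv hunks only touch the `from_float` docstrings, but for completeness here is a hedged sketch of the eager-mode flow those docstrings refer to, written against the `torch.ao.quantization` utilities. It assumes `QuantStub`/`DeQuantStub`, `get_default_qconfig`, `prepare`, and `convert` are already exported from the ao namespace at this point in the migration; the model and qconfig choice are illustrative.

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq

model = nn.Sequential(tq.QuantStub(), nn.Conv2d(3, 8, 3), tq.DeQuantStub()).eval()
model.qconfig = tq.get_default_qconfig('fbgemm')

prepared = tq.prepare(model)             # inserts observers
prepared(torch.randn(1, 3, 16, 16))      # calibration pass to collect activation stats
quantized = tq.convert(prepared)         # swaps Conv2d for its quantized version via from_float
print(type(quantized[1]))                # a quantized Conv2d module
```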
```diff
@@ -135,13 +135,13 @@ class Embedding(torch.nn.Module):
         r"""Create a quantized embedding module from a float module

         Args:
-            mod (Module): a float module, either produced by torch.quantization
+            mod (Module): a float module, either produced by torch.ao.quantization
                           utilities or provided by user
         """
         assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
             nn.Embedding.__name__
         assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
-        from torch.quantization import float_qparams_weight_only_qconfig
+        from torch.ao.quantization import float_qparams_weight_only_qconfig
         if mod.qconfig is not None and mod.qconfig.weight is not None:
             weight_observer = mod.qconfig.weight()
         else:
@@ -220,7 +220,7 @@ class EmbeddingBag(Embedding):
         r"""Create a quantized embedding_bag module from a float module

         Args:
-            mod (Module): a float module, either produced by torch.quantization
+            mod (Module): a float module, either produced by torch.ao.quantization
                           utilities or provided by user
         """
         if hasattr(mod, 'weight_fake_quant'):
```
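A small sketch of the conversion these embedding hunks document: quantizing a float `nn.Embedding` with the weight-only qconfig whose import moves here. The direct `from_float` call and module sizes are illustrative; as the hunk shows, `from_float` asserts that a `qconfig` attribute is set on the float module.

```python
import torch
import torch.nn as nn
import torch.nn.quantized as nnq
from torch.ao.quantization import float_qparams_weight_only_qconfig

float_emb = nn.Embedding(num_embeddings=10, embedding_dim=4)
float_emb.qconfig = float_qparams_weight_only_qconfig   # required by from_float
q_emb = nnq.Embedding.from_float(float_emb)             # weight quantized with float qparams
print(q_emb(torch.tensor([[0, 1, 2]])).shape)           # torch.Size([1, 3, 4])
```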
```diff
@@ -255,7 +255,7 @@ class Linear(torch.nn.Module):
         r"""Create a quantized module from a float module or qparams_dict

         Args:
-            mod (Module): a float module, either produced by torch.quantization
+            mod (Module): a float module, either produced by torch.ao.quantization
                           utilities or provided by the user
         """
         if hasattr(mod, 'weight_fake_quant'):
```
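The static `Linear.from_float` context line also shows the `weight_fake_quant` branch. A hedged sketch of how a QAT-prepared Linear reaches that branch through the ao-namespace helpers (assuming `prepare_qat`, `convert`, and `get_default_qat_qconfig` are exported from `torch.ao.quantization`; the toy model and single forward pass stand in for real training):

```python
import torch
import torch.nn as nn
import torch.ao.quantization as tq

qat_model = nn.Sequential(tq.QuantStub(), nn.Linear(4, 4), tq.DeQuantStub())
qat_model.qconfig = tq.get_default_qat_qconfig('fbgemm')
qat_model.train()

prepared = tq.prepare_qat(qat_model)     # swaps Linear for its fake-quantized QAT version
prepared(torch.randn(2, 4))              # stand-in for a training step
quantized = tq.convert(prepared.eval())  # Linear.from_float takes the weight_fake_quant branch
```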