[ao_migration] torch/nn/quantized: torch.quantization -> torch.ao.quantization (#65900)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/65900

This changes the imports in the `caffe2/torch/nn/quantized` to include the new import locations.

```
codemod -d torch/nn/quantized --extensions py 'torch.quantization' 'torch.ao.quantization'
```

Test Plan: `python test/run_test.py`

Reviewed By: jerryzh168

Differential Revision: D31301193

fbshipit-source-id: 58efb1ad51a8b441e2a3bd5b91af11eab6b9331f
This commit is contained in:
Zafar Takhirov 2021-10-08 16:16:01 -07:00 committed by Facebook GitHub Bot
parent f1f3bd8c36
commit b23709df03
5 changed files with 12 additions and 12 deletions

View File

@@ -77,7 +77,7 @@ class Linear(nnq.Linear):
 r"""Create a dynamic quantized module from a float module or qparams_dict
 Args:
-mod (Module): a float module, either produced by torch.quantization
+mod (Module): a float module, either produced by torch.ao.quantization
 utilities or provided by the user
 """
 float_modules = [torch.nn.Linear, torch.nn.modules.linear.NonDynamicallyQuantizableLinear,
@@ -95,7 +95,7 @@ class Linear(nnq.Linear):
 # We have the circular import issues if we import the qconfig in the beginning of this file:
 # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
 # import until we need it.
-from torch.quantization.qconfig import default_dynamic_qconfig
+from torch.ao.quantization.qconfig import default_dynamic_qconfig
 weight_observer = default_dynamic_qconfig.weight()
 dtype = weight_observer.dtype
 assert dtype in [torch.qint8, torch.float16], "The only supported dtypes for " \

View File

@@ -214,7 +214,7 @@ class RNNBase(torch.nn.Module):
 # We have the circular import issues if we import the qconfig in the beginning of this file:
 # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
 # import until we need it.
-from torch.quantization.qconfig import default_dynamic_qconfig
+from torch.ao.quantization.qconfig import default_dynamic_qconfig
 weight_observer_method = default_dynamic_qconfig.weight
 dtype = weight_observer_method().dtype
@@ -731,7 +731,7 @@ class RNNCellBase(torch.nn.Module):
 # We have the circular import issues if we import the qconfig in the beginning of this file:
 # https://github.com/pytorch/pytorch/pull/24231. The current workaround is to postpone the
 # import until we need it.
-from torch.quantization.qconfig import default_dynamic_qconfig
+from torch.ao.quantization.qconfig import default_dynamic_qconfig
 weight_observer_method = default_dynamic_qconfig.weight
 dtype = weight_observer_method().dtype

View File

@@ -333,7 +333,7 @@ class Conv1d(_ConvNd):
 r"""Creates a quantized module from a float module or qparams_dict.
 Args:
-mod (Module): a float module, either produced by torch.quantization
+mod (Module): a float module, either produced by torch.ao.quantization
 utilities or provided by the user
 """
 return _ConvNd.from_float(cls, mod)
@@ -430,7 +430,7 @@ class Conv2d(_ConvNd):
 r"""Creates a quantized module from a float module or qparams_dict.
 Args:
-mod (Module): a float module, either produced by torch.quantization
+mod (Module): a float module, either produced by torch.ao.quantization
 utilities or provided by the user
 """
 return _ConvNd.from_float(cls, mod)
@@ -528,7 +528,7 @@ class Conv3d(_ConvNd):
 r"""Creates a quantized module from a float module or qparams_dict.
 Args:
-mod (Module): a float module, either produced by torch.quantization
+mod (Module): a float module, either produced by torch.ao.quantization
 utilities or provided by the user
 """
 return _ConvNd.from_float(cls, mod)
@@ -564,7 +564,7 @@ class _ConvTransposeNd(_ConvNd):
 def from_float(cls, mod):
 r"""Creates a quantized module from a float module or qparams_dict.
 Args:
-mod (Module): a float module, either produced by torch.quantization
+mod (Module): a float module, either produced by torch.ao.quantization
 utilities or provided by the user
 """
 # derived classes override cls._FLOAT_MODULE attribute

View File

@@ -135,13 +135,13 @@ class Embedding(torch.nn.Module):
 r"""Create a quantized embedding module from a float module
 Args:
-mod (Module): a float module, either produced by torch.quantization
+mod (Module): a float module, either produced by torch.ao.quantization
 utilities or provided by user
 """
 assert type(mod) == nn.Embedding, 'nnq.' + cls.__name__ + '.from_float only works for ' + \
 nn.Embedding.__name__
 assert hasattr(mod, 'qconfig'), 'Embedding input float module must have qconfig defined'
-from torch.quantization import float_qparams_weight_only_qconfig
+from torch.ao.quantization import float_qparams_weight_only_qconfig
 if mod.qconfig is not None and mod.qconfig.weight is not None:
 weight_observer = mod.qconfig.weight()
 else:
@@ -220,7 +220,7 @@ class EmbeddingBag(Embedding):
 r"""Create a quantized embedding_bag module from a float module
 Args:
-mod (Module): a float module, either produced by torch.quantization
+mod (Module): a float module, either produced by torch.ao.quantization
 utilities or provided by user
 """
 if hasattr(mod, 'weight_fake_quant'):

View File

@@ -255,7 +255,7 @@ class Linear(torch.nn.Module):
 r"""Create a quantized module from a float module or qparams_dict
 Args:
-mod (Module): a float module, either produced by torch.quantization
+mod (Module): a float module, either produced by torch.ao.quantization
 utilities or provided by the user
 """
 if hasattr(mod, 'weight_fake_quant'):