pytorch/torch/quantization
Supriya Rao 1cf12b7e53 [quant] Fix histogram observer to work with QAT on GPU (#34232)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34232

By default `torch.zeros` creates the tensor on CPU, so the `device` argument must be specified for the observer to work correctly on GPU during QAT.
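The pattern behind the fix can be sketched as below. This is a hypothetical toy observer, not the actual `HistogramObserver` code from `observer.py`: the point is that state tensors allocated with a bare `torch.zeros` land on CPU, so they must be (re)created with `device=x.device` to match a CUDA input during QAT.

```python
import torch

class ToyHistogramObserver(torch.nn.Module):
    """Hypothetical observer illustrating the device fix."""

    def __init__(self, bins=16):
        super().__init__()
        self.bins = bins
        # A bare torch.zeros defaults to CPU; fine until a CUDA input arrives.
        self.register_buffer("histogram", torch.zeros(bins))

    def forward(self, x):
        # The fix: reallocate the histogram on the input's device instead of
        # mixing a CPU buffer with a GPU activation tensor.
        if self.histogram.device != x.device:
            self.histogram = torch.zeros(self.bins, device=x.device)
        # torch.histc with min=max=0 uses the data's own min/max range.
        self.histogram += torch.histc(x.detach().float(), bins=self.bins)
        return x
```

A pass-through `forward` like this mirrors how PyTorch observers record statistics without altering the observed tensor.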

Test Plan:
1. Tested by running QAT on GPU

2. python test/test_quantization.py

Imported from OSS

Differential Revision: D20286351

fbshipit-source-id: 745723c85d902870c56c1c7492f26cb027ae9dc6
2020-03-05 17:19:12 -08:00
__init__.py Ignore F401 in all __init__.py without putting noqa (#25823) 2019-10-23 15:28:13 -07:00
_quantize_script.py [quant][graphmode][refactor] Better API for fold_convbn (#32380) 2020-01-24 15:46:47 -08:00
default_mappings.py [quant] Add Quantized BatchNorm2d module (#33109) 2020-02-13 12:15:43 -08:00
fake_quantize.py Per channel quantization performance improvement (#33772) 2020-02-26 10:19:25 -08:00
fuse_modules.py Enable inplace relu fusion for training (#33105) 2020-02-14 12:15:58 -08:00
observer.py [quant] Fix histogram observer to work with QAT on GPU (#34232) 2020-03-05 17:19:12 -08:00
qconfig.py Updates to quantization documentation (#30288) 2019-11-23 09:29:30 -08:00
quantize.py [quantization] FP16 dynamic quantized Linear 2020-01-27 15:45:32 -08:00
stubs.py Factored out the default mappings 2019-10-03 11:52:21 -07:00