Commit Graph

251 Commits

Author SHA1 Message Date
Jerry Zhang
ace645a017 Add support for prototype affine quantization in pt2e flow (#141421)
Summary:
duplicated affine quantization functionality including
observer (https://github.com/pytorch/ao/blob/main/torchao/quantization/observer.py)
and some quant_primitive ops (7c3c51fd0d/torchao/quantization/quant_primitives.py (L26-L30))
to allow for per group quantization min max observer in pt2e flow

Next: We can follow up to add moving average min max observer

Test Plan:
python test/test_quantization.py -k test_channel_group_quantization

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141421
Approved by: https://github.com/cccclai
2024-12-24 04:22:18 +00:00
Nikhil Gupta
94737e8a2a [ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)
Description:
1. Quantize Linear Layer Weights to 4-bits:
Quantize the weights of the Linear layer to 4 bits, using symmetric quantization.
Pack two 4-bit weights into one uint8 container.
Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.

2. Prepare Quantized Weights, Scales, and Optional Bias:
After quantizing, obtain the quantized_weights, scales, and groupsize.
If the original Linear layer has a bias, prepare it as well.

3. Pack the Weights Efficiently:
Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.
```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```
Input parameters should include:
in_features and out_features (the same as the Linear layer’s corresponding parameters).

4. Perform Dynamic Quantized Matrix Multiplication:
Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights.
```python
output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights,  groupsize, in_features, out_features)
```
Inputs required include:
The input tensor, packed_weights , groupsize, and the in_features and out_features.

API Usage: https://github.com/pytorch/pytorch/issues/143289

Model Perf :
7B Transformer model:
Prefill : 340 t/s
Decode  : 40  t/s
2B Transformer model
Prefill : 747 t/s
Decode  : 80  t/s

Tests:
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s

OK

python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s

OK

python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s

Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
2024-12-20 19:32:03 +00:00
PyTorch MergeBot
8136daff5a Revert "[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)"
This reverts commit 4b82251011.

Reverted https://github.com/pytorch/pytorch/pull/134124 on behalf of https://github.com/huydhn due to Sorry for reverting your change but it breaks lots of internal build ([comment](https://github.com/pytorch/pytorch/pull/134124#issuecomment-2555953189))
2024-12-19 23:33:17 +00:00
Nikhil Gupta
4b82251011 [ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)
Description:
1. Quantize Linear Layer Weights to 4-bits:
Quantize the weights of the Linear layer to 4 bits, using symmetric quantization.
Pack two 4-bit weights into one uint8 container.
Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.

2. Prepare Quantized Weights, Scales, and Optional Bias:
After quantizing, obtain the quantized_weights, scales, and groupsize.
If the original Linear layer has a bias, prepare it as well.

3. Pack the Weights Efficiently:
Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.
```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```
Input parameters should include:
in_features and out_features (the same as the Linear layer’s corresponding parameters).

4. Perform Dynamic Quantized Matrix Multiplication:
Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights.
```python
output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights,  groupsize, in_features, out_features)
```
Inputs required include:
The input tensor, packed_weights , groupsize, and the in_features and out_features.

API Usage: https://github.com/pytorch/pytorch/issues/143289

Model Perf :
7B Transformer model:
Prefill : 340 t/s
Decode  : 40  t/s
2B Transformer model
Prefill : 747 t/s
Decode  : 80  t/s

Tests:
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s

OK

python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s

OK

python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s

Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
2024-12-19 18:51:26 +00:00
PyTorch MergeBot
14fe1f7190 Revert "[ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)"
This reverts commit d3ff2d42c2.

Reverted https://github.com/pytorch/pytorch/pull/134124 on behalf of https://github.com/malfet due to This broke S390 builds, includes cpuinfo unconditionally ([comment](https://github.com/pytorch/pytorch/pull/134124#issuecomment-2552560208))
2024-12-19 01:05:11 +00:00
Nikhil Gupta
d3ff2d42c2 [ARM][feat]: Add 4 bit dynamic quantization matmuls & KleidiAI Backend (#134124)
Description:
1. Quantize Linear Layer Weights to 4-bits:
Quantize the weights of the Linear layer to 4 bits, using symmetric quantization.
Pack two 4-bit weights into one uint8 container.
Choose a quantization scheme (channel-wise or group-wise), with the group size being a multiple of 32.

2. Prepare Quantized Weights, Scales, and Optional Bias:
After quantizing, obtain the quantized_weights, scales, and groupsize.
If the original Linear layer has a bias, prepare it as well.

3. Pack the Weights Efficiently:
Use torch.ops.aten._dyn_quant_pack_4bit_weight to optimally pack the weights, scales, and optional bias.
```python
packed_weights = torch.ops.aten._dyn_quant_pack_4bit_weight(weight, scales_and_zeros, bias, groupsize, in_features, out_features)
```
Input parameters should include:
in_features and out_features (the same as the Linear layer’s corresponding parameters).

4. Perform Dynamic Quantized Matrix Multiplication:
Use torch.ops.aten._dyn_quant_matmul_4bit to perform matrix multiplication with quantized weights.
```python
output = torch.ops.aten._dyn_quant_matmul_4bit(input, packed_weights,  groupsize, in_features, out_features)
```
Inputs required include:
The input tensor, packed_weights , groupsize, and the in_features and out_features.

API Usage: https://github.com/pytorch/pytorch/issues/143289

Model Perf :
7B Transformer model:
Prefill : 340 t/s
Decode  : 40  t/s
2B Transformer model
Prefill : 747 t/s
Decode  : 80  t/s

Tests:
python test/test_linalg.py -k test__dyn_quant_pack_4bit_weight
Ran 1 test in 0.016s

OK

python test/test_linalg.py -k test__dyn_quant_matmul_4bit
Ran 8 tests in 0.077s

OK

python test/test_linalg.py -k test_compile_dyn_quant_matmul_4bit
Ran 8 tests in 11.454s

Change-Id: Ia1672bad5e6ec94e64d8bb1971395d60f4b3a452

Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134124
Approved by: https://github.com/digantdesai, https://github.com/malfet
2024-12-18 22:30:07 +00:00
gasoonjia
91261107e0 debug handler maintain through decomposition (#141612)
Add checks in the ao numberic debugger to guard the debug handle consistency between aten op decomposition

Differential Revision: [D66517480](https://our.internmc.facebook.com/intern/diff/D66517480/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141612
Approved by: https://github.com/jerryzh168
2024-12-12 12:26:45 +00:00
gasoonjia
ff059587c6 support condition branch in ao debug handler (#141516)
This diff introduced the supportive of condition statement into ao debug handler generation.

Most of code borrowed from ExecuTorch to avoid circle dependency issue.

Differential Revision: [D66270691](https://our.internmc.facebook.com/intern/diff/D66270691/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/141516
Approved by: https://github.com/jerryzh168
2024-12-10 14:05:12 +00:00
PyTorch MergeBot
09ce760fef Revert "Add missing data types at torch export serialization (#138561)"
This reverts commit 1ef1b3b391.

Reverted https://github.com/pytorch/pytorch/pull/138561 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/138561#issuecomment-2513343401))
2024-12-03 01:32:50 +00:00
yintong-lu
1ef1b3b391 Add missing data types at torch export serialization (#138561)
Related to #131654

Added missing FP8 data types at torch export serialization.
Added test cases of FP8 data types.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138561
Approved by: https://github.com/jerryzh168, https://github.com/jgong5
2024-11-28 08:35:03 +00:00
ZhiweiYan-96
c418a9ac75 [Intel GPU] XPUInductorQuantizer for XPU int8 recipe customization (#139578)
# Motivation
This PR add `XPUInductorQuantizer`, which would defined the recipe of int8 quantization at XPU backend.

# Detailed
The `XPUInductorQuantizer` is class derived from `X86InductorQuantizer` as both quantizer would take the advantage of highly optimized operators in oneDNN library(qconv, qlinear, qconv/qlinear fusion).

We share the same recipe as `X86InductorQuantizer`, so we would have same `annotate_xxxx` methods.  So, in ideal situation, the `XPUInductorQuantizer` would have no class body as all implementation can inherit from base class.

In this PR, we override the `annotate_xxx` method for operators that has NOT be implemented. All operators XPU backend does  not implement would be fallbacked to fp32 implementation as the node in graph is a `dq-op-q` pairs. This would help provide good OOB usability for XPU backend.   On the other hand, the implemented operators would uses `annotate_op` implemented in base class and could be lowered successfully.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139578
Approved by: https://github.com/EikanWang, https://github.com/leslie-fang-intel, https://github.com/CuiYifeng, https://github.com/jerryzh168
ghstack dependencies: #133080
2024-11-26 09:44:14 +00:00
Aaron Gokaslan
12e95aa4ee [BE]: Apply PERF401 autofixes from ruff (#140980)
* Automatically applies ruff rule 401. Turns loops into equivalent list comprehensions which are faster and do not leak the scope of the loop variables.
* list comprehensions not only often have better typing, but are 50+% faster than for loops on overhead. They also preserve length information etc and are better for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
2024-11-20 17:52:07 +00:00
Jiang, Yanbing
f77eb07662 Split int4wo weight packing (#139611)
Fixes https://github.com/pytorch/ao/issues/1117.

This PR is to seperate int4wo weight packing between CPU and other devices, to help implement `INT4CPULayout` in torchao based on https://github.com/pytorch/ao/issues/1117#issuecomment-2451252756.

Now, for CPU, the input `weight` of `_convert_weight_to_int4pack_for_cpu` is [n, k] int32, output is [n, k / 2] uint8. The input packed weight of `_weight_int4pack_mm_for_cpu` is [n, k / 2] uint8.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/139611
Approved by: https://github.com/jerryzh168
2024-11-12 10:12:50 +00:00
Tom Ritchford
c0582fd0f8 Remove unused Python variables in torch/[b-z]* (#136963)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/136963
Approved by: https://github.com/ezyang
2024-10-19 16:45:22 +00:00
Shangdi Yu
c83178d894 Change to export_for_training in XNNPACK tests (#137238)
Summary: as title

Test Plan: CI

Differential Revision: D63344674

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137238
Approved by: https://github.com/tugsbayasgalan
2024-10-03 21:28:05 +00:00
Shangdi Yu
590a3e9f8a [export][training ir migration] quantized_decomposed.quantize_per_tensor decomposition (#134525)
Summary:
In graph of  TestXNNPACKQuantizer.test_dynamic_linear_with_con test, some quantized_decomposed.quantize_per_tensor.default ops are becoming quantized_decomposed.dequantize_per_tensor.tensor ops when using the new training ir.

This is because we lift params/buffers before calling make_fx. So previously, for the graph that’s passed to make_fx,`graph.L__self___linear1.weight` is a tensor
now in training ir, graph.L__self___linear1.weight is a FakeTensor. This caused the node overload to be different.

Test Plan:
```
buck2 run 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- -r test_dynamic_linear_with_conv
```

Differential Revision: D61364547

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134525
Approved by: https://github.com/tugsbayasgalan, https://github.com/jerryzh168
2024-09-06 07:06:06 +00:00
Mikayla Gawarecki
d9576c9440 Fix failures when default is flipped for weights_only (#127627)
Tests on XLA shard not fixed yet but there is an issue here https://github.com/pytorch/xla/issues/7799

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127627
Approved by: https://github.com/albanD
ghstack dependencies: #132349
2024-08-16 00:22:43 +00:00
Oguz Ulgen
72d2dba992 Add None return type to init (#132335)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132335
Approved by: https://github.com/albanD
2024-08-01 15:26:45 +00:00
kausik
4f60a2e39c Set correct output dtype for dequantize op during convert_pt2e in decomposed mode (#128953)
Earlier the signature of dequantize ops for decomposed quantized Tensor was changed for wider use-cases where the output dtype can be different from torch.float and needs to be passed during dequantization.
Please refer: https://github.com/pytorch/pytorch/pull/121450

However, setting of correct output dtype for dequantize ops was still missing in convert_pt2e flow.

This change enables the users to use PT2E quantization flow with non torch.float unquantized dtype, such as torch.bfloat16.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128953
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
2024-07-19 04:58:02 +00:00
leslie-fang-intel
2a1f22e57f Change BN to eval before QAT Convert phase (#130598)
**Summary**
In the QAT convert phase, we fold bn into conv and do DCE to this BN node. We should change `torch.ops.aten._native_batch_norm_legit.default` to `torch.ops.aten._native_batch_norm_legit_no_training.default`  for a safe DCE.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130598
Approved by: https://github.com/jgong5, https://github.com/yushangdi
2024-07-12 16:03:56 +00:00
Jiang, Yanbing
6f662e9575 update the input weight of _convert_weight_to_int4pack to [n][k / 2] uint8 (#129940)
This PR is to update the input `weight` of `_convert_weight_to_int4pack` from `[n][k] int32` to `[n][k / 2] uint8`, both for CPU, CUDA and MPS, which can help decouple int4 model checkpoint with different ISAs and different platforms in `gpt-fast`. The advantage is int4 model checkpoint can be shared in different test machines, without re-generating in one certain platform. Meanwhile, the size of input `weight` can be reduced to `1 / 8`.

Before this PR, packed weight stored in CUDA specific layout: `[n/8][k/(InnerKTiles*16)][32][InnerKTiles/2]`, dtype int32, where InnerKTiles = 2, 4, 8. CPU packed weight viewed as the SAME shape but stored in different layout: `[n/64][k][32]`, dtype uint8. Weight is strongly coupled with platforms (CPU/CUDA) and ISAs (AVX512/AVX2/scalar). And users cannot use a generated weight in another different ISA or platform, because when loading weight into devices, the compute format is different.
![image](https://github.com/pytorch/pytorch/assets/61222868/64971c4b-29b9-42cf-9aeb-ffa01cea93dd)

Now, we use common serialized layout (`[n][k/2] uint8`) for different devices or ISAs as input `weight` of `_convert_weight_to_int4pack`, and each back chooses how to interpret as compute layout.
![image](https://github.com/pytorch/pytorch/assets/61222868/c7990761-c723-417b-aca2-7c60db7785c7)

### Performance
Intel (R) Xeon (R) CPU Max 9480, single socket (56 cores)
There is no obvious regression of this PR.
![image](https://github.com/pytorch/pytorch/assets/61222868/6046dcf4-920b-4c63-9ca3-1c8c3cafebde)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129940
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/mingfeima
2024-07-11 15:26:48 +00:00
PyTorch MergeBot
637cc8d27f Revert "update the input weight of _convert_weight_to_int4pack to [n][k / 2] uint8 (#129940)"
This reverts commit 6367f02a0e.

Reverted https://github.com/pytorch/pytorch/pull/129940 on behalf of https://github.com/albanD due to Broke rocm tests on main 6367f02a0e ([comment](https://github.com/pytorch/pytorch/pull/129940#issuecomment-2220554681))
2024-07-10 13:48:32 +00:00
Jiang, Yanbing
6367f02a0e update the input weight of _convert_weight_to_int4pack to [n][k / 2] uint8 (#129940)
This PR is to update the input `weight` of `_convert_weight_to_int4pack` from `[n][k] int32` to `[n][k / 2] uint8`, both for CPU, CUDA and MPS, which can help decouple int4 model checkpoint with different ISAs and different platforms in `gpt-fast`. The advantage is int4 model checkpoint can be shared in different test machines, without re-generating in one certain platform. Meanwhile, the size of input `weight` can be reduced to `1 / 8`.

Before this PR, packed weight stored in CUDA specific layout: `[n/8][k/(InnerKTiles*16)][32][InnerKTiles/2]`, dtype int32, where InnerKTiles = 2, 4, 8. CPU packed weight viewed as the SAME shape but stored in different layout: `[n/64][k][32]`, dtype uint8. Weight is strongly coupled with platforms (CPU/CUDA) and ISAs (AVX512/AVX2/scalar). And users cannot use a generated weight in another different ISA or platform, because when loading weight into devices, the compute format is different.
![image](https://github.com/pytorch/pytorch/assets/61222868/64971c4b-29b9-42cf-9aeb-ffa01cea93dd)

Now, we use common serialized layout (`[n][k/2] uint8`) for different devices or ISAs as input `weight` of `_convert_weight_to_int4pack`, and each back chooses how to interpret as compute layout.
![image](https://github.com/pytorch/pytorch/assets/61222868/c7990761-c723-417b-aca2-7c60db7785c7)

### Performance
Intel (R) Xeon (R) CPU Max 9480, single socket (56 cores)
There is no obvious regression of this PR.
![image](https://github.com/pytorch/pytorch/assets/61222868/6046dcf4-920b-4c63-9ca3-1c8c3cafebde)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129940
Approved by: https://github.com/jgong5, https://github.com/lezcano, https://github.com/mingfeima
2024-07-10 07:38:42 +00:00
leslie-fang-intel
35a197defa [Inductor][CPP] Enable Quantized Linear GEMM Template with FP32 output (#128825)
**Summary**
Support int8 GEMM Template with refer MicroInt8GEMM kernel for case:

- Activation dtype: uint8
- Weight dtype: int8
- Output dtype: float32/bfloat16
- Post Op Fusion: without unary post operator fusion

**Test Plan**
```
clear && python -u -m pytest -s -v test/inductor/test_cpu_select_algorithm.py -k test_quantized_linear_with_pointwise
```

**Next Step**
- [ ] Unary post op fusion
- [ ] Int8 output
- [ ] Binary Fusion
- [ ] AMX int8 MicroGEMM Kernel

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128825
Approved by: https://github.com/jgong5, https://github.com/jansel
2024-06-30 09:45:43 +00:00
cyy
163847b1bb [1/N] [Caffe2] Remove caffe2_aten_fallback code (#128675)
Fixes #ISSUE_NUMBER

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128675
Approved by: https://github.com/r-barnes
2024-06-17 21:25:59 +00:00
Nikita Shulga
4ff9113e3d [MPS] Add _weight_int8pack_mm tests (#127041)
As well as extend the test to cover MV cases (where A matrix is 1xM) Limit int8 op testing to 32x32 matrix sizes for now

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127041
Approved by: https://github.com/larryliu0820, https://github.com/manuelcandales
2024-05-24 16:08:06 +00:00
Nikita Shulga
30610251ec [MPS] And naive quantized intmm and .gputrace capture hooks (#125163)
- Implement a very straightforward Metal copy of CPU int4mm kernel
- Implement int8mm kernel by constructing a graph consisting of upcast, transpose and mm
- Add `isCapturing`, `isCaptureEnabled`, `startCapture` and `stopCapture` methods to `MPSProfile` which can be used to help one debug/profile Metal kernels by wrapping the calls with the following
  ```cpp
   if (getMPSProfiler().profiler.isCaptureEnabled()) {
     getMPSProfiler().startCapture(__func__, mpsStream);
   }
   ...
   if (getMPSProfiler().isCapturing()) {
     getMPSProfiler().stopCapture(mpsStream);
   }
  ```
  that, if invoked with `MTL_CAPTURE_ENABLED` environment variable set to one, will produce .gputrace files, in the current working directory, which can later be loaded and used to debug or profiler the kernel
<img width="1093" alt="image" src="https://github.com/pytorch/pytorch/assets/2453524/a2bf27e8-df8a-442c-a525-1df67b8a376a">

- Added `test_int4mm` to TestLinalgMPS, which is mostly copy-n-paste of the test from `test_linalg`

TODOs:
 - Add weight pack
 - Perf-tune both kernels
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125163
Approved by: https://github.com/mikekgfb
2024-05-03 15:20:39 +00:00
Aaron Gokaslan
5a1216bb2e [BE]: Update ruff to 0.4.1 (#124549)
Update ruff to 0.4.1 .
This version fixes a lot false negatives/false positives, is 20-40% faster, and has various other bug fixes.

Below is a before and after table showing the execution time of ruff lint and ruff format in milliseconds courtesy of https://astral.sh/blog/ruff-v0.4.0

| Repository                                         | Linter (v0.3) | Linter (v0.4) | Formatter (v0.3) | Formatter (v0.4) |
|----------------------------------------------------|---------------|---------------|------------------|------------------|
| [pytorch/pytorch](https://github.com/pytorch/pytorch) | 328.7         | 251.8         | 351.1            | 274.9            |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/124549
Approved by: https://github.com/ezyang
2024-04-21 14:06:23 +00:00
andrewor14
ea8e0c75c7 [quant][pt2] Fix create FQ with FixedQParamsQSpec (#122104)
Summary: Before we just returned a _PartialWrapper object when
using FixedQParamsQuantizationSpec in QAT. This is wrong and
we should return a FQ object instead.

Differential Revision: [D55021106](https://our.internmc.facebook.com/intern/diff/D55021106)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122104
Approved by: https://github.com/jerryzh168
2024-03-22 14:23:05 +00:00
Jerry Zhang
901ba2be86 [quant][pt2e] Add support for conv transpose + bn + {relu} weights fusion in PTQ (#122046)
Summary:

also added some utils in xnnpack_quantizer_utils.py
* annotate_conv_tranpsose_bn_relu and annotate_conv_transpose_bn -> this is for QAT
* annotate_conv_transpose_relu

conv_transpose + bn weights fusion is performed automatically and can not be disabled currently
we can add support to allow disable this fusion later if needed

Test Plan:
python test/test_quantization.py -k test_conv_transpose_bn_fusion

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/122046
Approved by: https://github.com/andrewor14
2024-03-19 21:00:57 +00:00
Avik Chaudhuri
f351a71dbb remove constraints from capture_pre_autograd_graph (#120981)
Differential Revision: D54407296

Pull Request resolved: https://github.com/pytorch/pytorch/pull/120981
Approved by: https://github.com/zhxchen17
2024-03-02 07:00:51 +00:00
gs-olive
e0f6fa6a7c Windows Dynamo Error Removal CI Check (#115969)
Rebase of #111313 onto `main`, for CI validation

Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115969
Approved by: https://github.com/PaliC, https://github.com/thiagocrepaldi
2024-02-14 21:14:36 +00:00
PyTorch MergeBot
4a5b2cd6cb Revert "Windows Dynamo Error Removal CI Check (#115969)"
This reverts commit 45e7af5818.

Reverted https://github.com/pytorch/pytorch/pull/115969 on behalf of https://github.com/PaliC due to this pr ended up breaking some of our periodic tests ([comment](https://github.com/pytorch/pytorch/pull/115969#issuecomment-1942934386))
2024-02-14 01:11:46 +00:00
Jerry Zhang
7082e24ce8 [quant][pt2e][bc-breaking] Set fold_quantize to True in convert_pt2e (#119425)
Summary: This is a follow up to https://github.com/pytorch/pytorch/pull/118605 to set `fold_quantize` flag to True in `convert_pt2e`

Test Plan: CI

Differential Revision: D53550237

Pull Request resolved: https://github.com/pytorch/pytorch/pull/119425
Approved by: https://github.com/andrewor14
2024-02-09 18:13:43 +00:00
gs-olive
45e7af5818 Windows Dynamo Error Removal CI Check (#115969)
Rebase of #111313 onto `main`, for CI validation

Co-authored-by: Stella Laurenzo <stellaraccident@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115969
Approved by: https://github.com/ezyang
2024-02-08 21:23:45 +00:00
PyTorch MergeBot
81abc2b249 Revert "[quant][pt2e][bc-breaking] Remove fold_quantize flag (#118701)"
This reverts commit 482d952e88.

Reverted https://github.com/pytorch/pytorch/pull/118701 on behalf of https://github.com/facebook-github-bot due to Diff reverted internally ([comment](https://github.com/pytorch/pytorch/pull/118701#issuecomment-1932866964))
2024-02-07 20:56:16 +00:00
Jerry Zhang
482d952e88 [quant][pt2e][bc-breaking] Remove fold_quantize flag (#118701)
Summary:
This is a follow up to https://github.com/pytorch/pytorch/pull/118605 to remove `fold_quantize` flag from
`convert_pt2e`

Test Plan: CI

Differential Revision: D53247301

BC Breaking Note:

flag `fold_quantize` set to True `convert_pt2e` and now we'll fold the quantize op in the weight by default, so users will see model size reduction by default after pt2e quantization.
2.2
```
folded_model = convert_pt2e(model, fold_quantize=True)

non_folded_model = convert_pt2e(model)
```

2.3
```
folded_model = convert_pt2e(model)

non_folded_model = convert_pt2e(model, fold_quantize=False)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118701
Approved by: https://github.com/andrewor14, https://github.com/leslie-fang-intel
2024-02-07 19:10:51 +00:00
Edward Z. Yang
9bce208dfb Replace follow_imports = silent with normal (#118414)
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` for our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, what this does is whenever we have an import to a module which is not listed as a file to be typechecked in mypy, we typecheck it as normal but suppress all errors that occurred in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in lintrunner.toml, but with files in excludes excluded.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal, but ignore all errors.
* Therefore, it should be equivalent to set `follow_imports = normal`, if we put `# mypy: ignore-errors` on all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list, because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors` as this will impede typechecking on the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files as the configurations are not quite the same. I plan to unify the configurations so this is only a temporary state.
* There were some straggler type errors after these changes somehow, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
2024-01-27 02:44:11 +00:00
leslie-fang-intel
263cc12fab Add Dynamo Reset in PT2E Quantization testing (#117200)
**Summary**
Fix https://github.com/pytorch/pytorch/issues/117012 by adding `torch._dynamo.reset()` in `PT2EQuantizationTestCase._quantize`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/117200
Approved by: https://github.com/jerryzh168
2024-01-11 05:53:55 +00:00
Max Ren
d2033a0639 [quant][pt2e][xnnpack_quantizer] add support for linear_relu (#117052)
Add support for linear_relu annotation for XNNPACKQuantizer, this allows the input to linear and the output to relu to share the same quantization parameter.s

Differential Revision: [D52574086](https://our.internmc.facebook.com/intern/diff/D52574086/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117052
Approved by: https://github.com/jerryzh168, https://github.com/digantdesai
2024-01-09 23:19:52 +00:00
Jerry Zhang
28e2e12b2a [quant][be] enable xnnpack_quantizer tests to run in internal CI (#116911)
Summary: fixed an import problem for test_xnnpack_quantizer so that it can run in CI

Test Plan:
internal CI
sanity check: buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/quantization:test_quantization -- --exact 'caffe2/test/quantization:test_quantization - test_conv2d (caffe2.test.quantization.pt2e.test_xnnpack_quantizer.TestXNNPACKQuantizer)'

Differential Revision: D52576449

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116911
Approved by: https://github.com/mcr229
2024-01-08 23:43:47 +00:00
Aaron Gokaslan
bd10fea79a [BE]: Enable F821 and fix bugs (#116579)
Fixes #112371

I tried to fix as many of the bugs as I could, a few I could not figure out what the proper fix for them was though and so I left them with noqas.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116579
Approved by: https://github.com/ezyang
2024-01-01 08:40:46 +00:00
Jerry Zhang
a93b9ee9d8 [quant][be] Add a test for per channel quant for groupwise conv (#115224)
Summary:
just making sure this works

Test Plan:
python test/test_quantization.py -k test_groupwise_per_channel_quant

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115224
Approved by: https://github.com/andrewor14
2023-12-07 04:46:20 +00:00
Jerry Zhang
64fd706b21 [quant][pt2e] Add generate_numeric_debug_handle pass (#114315)
Summary:
This is a util for numeric suite in pt2 export so that we can build
a more streamlined UX for numerical debugging in quant + executorch stack

Test Plan:
python test/test_quantization.py TestGenerateNumericDebugHandle

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/114315
Approved by: https://github.com/zhxchen17
2023-12-01 03:38:17 +00:00
leslie-fang-intel
65e99357ae [Inductor] [Quant] Enable QConv2d Unary int8-mixed-bf16 Lowering (#112550)
**Summary**
- PR 5 for enabling Int8-Mixed-BF16 PT2E PTQ Quantization with Inductor https://github.com/pytorch/pytorch/issues/111640.
- Enable the QConv2d Unary int8-mixed-bf16 weight prepack and post grad lowering inside inductor.

**TestPlan**
```
python -m pytest test_mkldnn_pattern_matcher.py -k test_qconv2d
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112550
Approved by: https://github.com/jgong5, https://github.com/jerryzh168
2023-11-10 08:59:40 +00:00
Jerry Zhang
501d118255 [quant][pt2e] Add transform_for_annotation method in Quantizer (#113115)
Summary:
Adding the method so that people can do some transformations before annotation to make the graph easier to annotate

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_transform_for_annotation

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D51141080](https://our.internmc.facebook.com/intern/diff/D51141080)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113115
Approved by: https://github.com/kimishpatel
2023-11-09 20:23:29 +00:00
Jerry Zhang
12c257cc00 [qunat][pt2e] Support allow_implicit_sharing flag (#112929)
Summary:
For a Node: node1 and edge: (node1, node2), since they are observing the same
Tensor, we may want to implicitly share observers, this flag allows people to
turn off this behavior for the output of the node

See the test_allow_implicit_sharing test for use case

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_allow_implicit_sharing

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112929
Approved by: https://github.com/kimishpatel
2023-11-08 23:47:17 +00:00
Jerry Zhang
43c211facb [quant][pt2e] Actually support transitive sharing for SharedQuantizationSpec (#111172)
Summary:
Previously we actually did not really support this, this PR added the support.

Next
* clean up insert observer logic
* add allow_transitive_sharing boolean flag to allow people to turn this op for certain edges

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_shared_qspec_transitivity

Reviewers:

Subscribers:

Tasks:

Tags:

Differential Revision: [D50250789](https://our.internmc.facebook.com/intern/diff/D50250789)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/111172
Approved by: https://github.com/kimishpatel
2023-10-20 23:25:17 +00:00
Jerry Zhang
c9b8e06060 [quant] Enable quantization for wav2letter (#109830)
Summary:
Also added annotation support for conv1d_relu and conv1d in XNNPACKQuantizer, the quantized results still
matches fx quant path (didn't quantize conv1d) so tests are not disabled

Test Plan: with-proxy buck2 run executorch/examples/quantization:example -- -m=w2l --verify

Differential Revision: D49479546

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109830
Approved by: https://github.com/kimishpatel
2023-09-29 00:47:34 +00:00
Jerry Zhang
3de42995e4 [quant][pt2e] Add quant API re-entrant test (#110125)
Summary:
Add the test to make sure we can call the quantize API multiple times

Test Plan:
python test/test_quantization.py TestQuantizePT2E.test_reentrant

Reviewers:

Subscribers:

Tasks:

Tags:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110125
Approved by: https://github.com/kimishpatel
ghstack dependencies: #110097
2023-09-28 22:41:59 +00:00