mirror of https://github.com/zebrajr/pytorch.git
61 lines | 2.1 KiB | C++
#include <ATen/Context.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>

namespace torch::jit::mobile::quantization {

void PTQQuanizationHelper::quantize_dynamic(
    torch::jit::mobile::Module& m,
    const std::string& method_name) {
  // Keep prepacked weights around so the original method can be replaced by
  // its quantized counterpart after calibration.
  at::globalContext().setReleaseWeightsWhenPrepacking(false);
  std::string reset_observers_method_name = "reset_observers_" + method_name;
  std::string observe_method_name = "observe_" + method_name;
  std::string quantize_method_name = "quantize_" + method_name;
  std::string quantized_method_name = "quantized_" + method_name;

  // A PTQ-ready module must carry the helper methods generated at export
  // time, plus bundled inputs for calibration.
  TORCH_CHECK(
      m.find_method(reset_observers_method_name).has_value(),
      "PTQ ready module must have ",
      reset_observers_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(observe_method_name),
      "PTQ ready module must have ",
      observe_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantize_method_name),
      "PTQ ready module must have ",
      quantize_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method(quantized_method_name),
      "PTQ ready module must have ",
      quantized_method_name,
      " method.");
  TORCH_CHECK(
      m.find_method("get_all_bundled_inputs"),
      "PTQ ready module must have get_all_bundled_inputs method.");

  // Calibrate: reset the observers, run the bundled inputs through the
  // observe method, then materialize the quantized method.
  auto inputs = m.run_method("get_all_bundled_inputs")
                    .toList()
                    .get(0)
                    .toTupleRef()
                    .elements()
                    .vec();
  m.get_method(reset_observers_method_name)({});
  m.get_method(observe_method_name)(inputs);
  m.get_method(quantize_method_name)(inputs);

  // Swap the quantized implementation in under the original method name and
  // strip the now-unneeded helper methods.
  m.compareMethodSchemas(method_name, quantized_method_name);
  m.unsafeRemoveMethod(method_name);
  const Function& to_be_copied =
      m.find_method(quantized_method_name).value().function();
  m.unsafeCopyMethod(method_name, to_be_copied);
  m.unsafeRemoveMethod(quantized_method_name);
  m.unsafeRemoveMethod(quantize_method_name);
  m.unsafeRemoveMethod(observe_method_name);
  m.unsafeRemoveMethod(reset_observers_method_name);
}

} // namespace torch::jit::mobile::quantization
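For context, a minimal usage sketch of how this helper might be driven from application code. The model path "model_ptq_ready.ptl" and the surrounding main() are assumptions for illustration; the sketch presumes a model exported with the reset_observers_/observe_/quantize_/quantized_ variants of forward and with bundled inputs attached, loaded via torch::jit::_load_for_mobile.

// Usage sketch (assumption: a PTQ-ready model at the hypothetical path
// "model_ptq_ready.ptl" with bundled inputs and PTQ helper methods).
#include <torch/csrc/jit/mobile/import.h>
#include <torch/csrc/jit/mobile/module.h>
#include <torch/csrc/jit/mobile/quantization.h>

int main() {
  // Load a mobile (lite interpreter) module.
  torch::jit::mobile::Module m =
      torch::jit::_load_for_mobile("model_ptq_ready.ptl");

  // Run dynamic PTQ in place: after this call, "forward" dispatches to the
  // quantized implementation and the helper methods have been removed.
  torch::jit::mobile::quantization::PTQQuanizationHelper helper;
  helper.quantize_dynamic(m, "forward");
  return 0;
}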