Merge pull request #27749 from MykhailoTrushch:fc4

Add auto white balance DNN algorithm FC4
Alexander Smorkalov 2025-09-29 16:16:23 +03:00 committed by GitHub
commit da8c313a90
4 changed files with 481 additions and 0 deletions

BIN  samples/data/castle.jpg (new file, 162 KiB; binary file not shown)
@@ -0,0 +1,265 @@
// This file is part of OpenCV project.
// It is subject to the license terms in the LICENSE file found in the top-level
// directory of this distribution and at http://opencv.org/license.html.
/*
Auto white balance using FC4: https://github.com/yuanming-hu/fc4
Color constancy is a method to make the colors of objects render correctly in a photo.
White balance aims to make white objects appear white in an image rather than a shade of
any other color, independent of the actual lighting. White balance correction produces
a neutral-looking coloring of the objects, and generally makes colors appear closer
to their 'true' colors under different lighting conditions.
Given an RGB image, the FC4 model predicts the scene illuminant (R,G,B). We then divide
the illuminant out of the image, applying the correction in linear RGB space.
The transformation between linear and sRGB spaces is done as described in the sRGB standard,
which is a nonlinear Gamma correction with exponent 2.4 and extra handling of very small values.
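For reference, a sketch of the piecewise decoding from sRGB to linear, matching the
constants used in the code below (the encoding is its inverse, with the breakpoint
at 0.0031308 on the linear side):
    linear = srgb / 12.92                     if srgb <= 0.04045
    linear = ((srgb + 0.055) / 1.055)^2.4     otherwise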
This sample is written for 8-bit images. The FC4 model accepts RGB images with Gamma scaling applied.
The FC4 model was trained on the Gehler-Shi dataset. The dataset includes
568 images with ground truth corrections as well as ground truth illuminants. The linear
RGB images from the dataset were used with a Gamma correction of 2.2 applied.
The model is a pretrained fold 0 of a training pipeline on the Gehler-Shi dataset, from the PyTorch
implementation of the FC4 algorithm by Matteo Rizzo. The model was converted from a .pth file to ONNX
using torch.onnx.export. The model can be downloaded from the following link:
https://raw.githubusercontent.com/MykhailoTrushch/opencv/d6ab21353a87e4c527e38e464384c7ee78e96e22/samples/dnn/models/fc4_fold_0.onnx
Copyright (c) 2017 Yuanming Hu, Baoyuan Wang, Stephen Lin
Copyright (c) 2021 Matteo Rizzo
Licensed under the MIT license.
References:
Yuanming Hu, Baoyuan Wang, and Stephen Lin. FC4: Fully Convolutional Color
Constancy with Confidence-Weighted Pooling. CVPR, 2017, pp. 4085-4094.
Implementations of FC4:
https://github.com/yuanming-hu/fc4/
https://github.com/matteo-rizzo/fc4-pytorch
Lilong Shi and Brian Funt, "Re-processed Version of the Gehler Color Constancy Dataset of 568 Images,"
accessed from http://www.cs.sfu.ca/~colour/data/
IEC 61966-2-1:1999. Multimedia Systems and Equipment - Colour Measurement and Management - Part 2-1: Colour Management - Default RGB Colour Space - sRGB. IEC Standard, 1999.
*/
#include <iostream>
#include <opencv2/dnn.hpp>
#include <opencv2/highgui.hpp>
#include <opencv2/imgproc.hpp>
#include "common.hpp"
using namespace cv;
using namespace cv::dnn;
using namespace std;
const string param_keys =
    "{ help h  |                   | Print help message }"
    "{ @alias  | fc4               | Model alias from models.yml (optional) }"
    "{ zoo     | ../dnn/models.yml | Path to models.yml file (optional) }"
    "{ input i | castle.jpg        | Path to input image }";
const string backend_keys =
    format("{ backend | default | Choose one of computation backends: "
           "default: automatically (by default), "
           "openvino: Intel's Deep Learning Inference Engine "
           "(https://software.intel.com/openvino-toolkit), "
           "opencv: OpenCV implementation, "
           "vkcom: VKCOM, "
           "cuda: CUDA, "
           "webnn: WebNN }");
const string target_keys =
    format("{ target | cpu | Choose one of target computation devices: "
           "cpu: CPU target (by default), "
           "opencl: OpenCL, "
           "opencl_fp16: OpenCL fp16 (half-float precision), "
           "vpu: VPU, "
           "vulkan: Vulkan, "
           "cuda: CUDA, "
           "cuda_fp16: CUDA fp16 (half-float preprocess) }");
// Normalization constant for 8-bit values
const float NORMALIZE_FACTOR = 1.0f / 255.0f;
// Constants for the sRGB <-> linear conversion:
// SRGB_THRESHOLD / LINEAR_THRESHOLD: breakpoints between linear and gamma regions
// SRGB_SLOPE: slope of the linear segment near black
// SRGB_ALPHA: offset to ensure continuity at the threshold
// SRGB_EXP: gamma exponent
const float SRGB_THRESHOLD = 0.04045f;
const float SRGB_ALPHA = 0.055f;
const float SRGB_SLOPE = 12.92f;
const float SRGB_EXP = 2.4f;
const float LINEAR_THRESHOLD = 0.0031308f;
const float EPS = 1e-10f;
static Mat srgbToLinear(const Mat &srgb32f) {
    CV_Assert(srgb32f.type() == CV_32FC3);
    const float a = SRGB_ALPHA;
    Mat y = srgb32f;
    // Mask of values that fall on the linear segment near black.
    Mat mask_low;
    compare(y, SRGB_THRESHOLD, mask_low, CMP_LE);
    Mat low = y / SRGB_SLOPE;
    // Use Scalar::all so the offset is added to every channel
    // (a bare float would only affect the first channel).
    Mat t = (y + Scalar::all(a)) / (1.0f + a);
    Mat high;
    pow(t, SRGB_EXP, high);
    // Merge the two branches of the piecewise curve.
    Mat lin(y.size(), y.type(), Scalar(0, 0, 0));
    low.copyTo(lin, mask_low);
    Mat mask_high;
    bitwise_not(mask_low, mask_high);
    high.copyTo(lin, mask_high);
    return lin;
}
static Mat linearToSrgb(const Mat &lin32f) {
    CV_Assert(lin32f.type() == CV_32FC3);
    const float a = SRGB_ALPHA;
    Mat x = lin32f;
    // Mask of values that fall on the linear segment near black.
    Mat mask_low;
    compare(x, LINEAR_THRESHOLD, mask_low, CMP_LE);
    Mat low = x * SRGB_SLOPE;
    Mat powPart;
    pow(x, 1.0 / SRGB_EXP, powPart);
    // Use Scalar::all so the offset is subtracted from every channel.
    Mat high = (1.0f + a) * powPart - Scalar::all(a);
    // Merge the two branches of the piecewise curve.
    Mat srgb(x.size(), x.type(), Scalar(0, 0, 0));
    low.copyTo(srgb, mask_low);
    Mat mask_high;
    bitwise_not(mask_low, mask_high);
    high.copyTo(srgb, mask_high);
    return srgb;
}
static Mat correct(const Mat &bgr8u, const Vec3f &illumRGB_linear) {
    Mat f32;
    bgr8u.convertTo(f32, CV_32F, NORMALIZE_FACTOR);
    Mat lin = srgbToLinear(f32);
    const float eR = std::max(illumRGB_linear[0], EPS);
    const float eG = std::max(illumRGB_linear[1], EPS);
    const float eB = std::max(illumRGB_linear[2], EPS);
    // Scale by sqrt(3) so that a neutral, unit-norm illuminant estimate
    // (1/sqrt(3) per channel) yields a divisor of 1. The Scalar is in BGR
    // order to match the image channel layout.
    float s3 = std::sqrt(3.0f);
    Scalar corr(eB * s3 + EPS, eG * s3 + EPS, eR * s3 + EPS);
    Mat corrected;
    divide(lin, corr, corrected);
    // Renormalize by the global maximum to keep values in [0, 1].
    std::vector<Mat> ch;
    split(corrected, ch);
    double m0, m1, m2;
    minMaxLoc(ch[0], nullptr, &m0);
    minMaxLoc(ch[1], nullptr, &m1);
    minMaxLoc(ch[2], nullptr, &m2);
    float maxVal = static_cast<float>(std::max({m0, m1, m2})) + EPS;
    corrected /= maxVal;
    min(corrected, 1.0, corrected);
    max(corrected, 0.0, corrected);
    Mat srgb = linearToSrgb(corrected);
    Mat out;
    srgb.convertTo(out, CV_8U, 255.0);
    return out;
}
static void annotate(Mat &img, const string &title) {
    double fs = std::max(0.5, std::min(img.cols, img.rows) / 800.0);
    int th = std::max(1, (int)std::round(fs * 2));
    putText(img, title, Point(10, 30), FONT_HERSHEY_SIMPLEX, fs,
            Scalar(0, 255, 0), th);
}
int main(int argc, char **argv) {
    const string about = "FC4 Color Constancy (ONNX) sample.\n"
                         "Predicts scene illuminant and corrects the white "
                         "balance of the image.\n";
    string keys = param_keys + backend_keys + target_keys;
    CommandLineParser parser(argc, argv, keys);
    if (parser.has("help")) {
        cout << about << endl;
        parser.printMessage();
        return 0;
    }
    string modelName = parser.get<String>("@alias");
    string zooFile = samples::findFile(parser.get<String>("zoo"));
    keys += genPreprocArguments(modelName, zooFile);
    parser = CommandLineParser(argc, argv, keys);
    float scale = parser.get<float>("scale");
    Scalar mean = parser.get<Scalar>("mean");
    bool swapRB = parser.get<bool>("rgb");
    String backend = parser.get<String>("backend");
    String target = parser.get<String>("target");
    String sha1 = parser.get<String>("sha1");
    string model = findModel(parser.get<String>("model"), sha1);
    string inputPath = findFile(parser.get<String>("input"));
    if (model.empty()) {
        cerr << "Model file not found\n";
        return -1;
    }
    Net net;
    try {
        net = readNetFromONNX(model);
        net.setPreferableBackend(getBackendID(backend));
        net.setPreferableTarget(getTargetID(target));
    } catch (const Exception &e) {
        cerr << "Error loading model: " << e.what() << endl;
        return -1;
    }
    Mat img = imread(inputPath, IMREAD_COLOR);
    if (img.empty()) {
        cerr << "Cannot load image: " << inputPath << endl;
        return -1;
    }
    Mat blob = blobFromImage(img, scale, img.size(), mean, swapRB, /*crop=*/false,
                             /*type=*/CV_32F);
    net.setInput(blob);
    Mat out;
    try {
        out = net.forward();
    } catch (const Exception &e) {
        cerr << "Forward error: " << e.what() << endl;
        return -1;
    }
    // The model outputs a single (R, G, B) illuminant estimate.
    CV_Assert(out.total() == 3);
    const float *p = out.ptr<float>(0);
    Vec3f illum = Vec3f(p[0], p[1], p[2]);
    Mat corrected = correct(img, illum);
    Mat origVis = img.clone();
    Mat corrVis = corrected.clone();
    annotate(origVis, "Original");
    annotate(corrVis, "FC4-corrected");
    Mat stacked;
    hconcat(origVis, corrVis, stacked);
    imshow("Original and Corrected Images", stacked);
    waitKey(0);
    destroyAllWindows();
    return 0;
}


@@ -0,0 +1,202 @@
#!/usr/bin/env python3
# This file is part of OpenCV project.
# It is subject to the license terms in the LICENSE file found in the top-level
# directory of this distribution and at http://opencv.org/license.html.
'''
Auto white balance using FC4: https://github.com/yuanming-hu/fc4
Color constancy is a method to make the colors of objects render correctly in a photo.
White balance aims to make white objects appear white in an image rather than a shade of
any other color, independent of the actual lighting. White balance correction produces
a neutral-looking coloring of the objects, and generally makes colors appear closer
to their 'true' colors under different lighting conditions.
Given an RGB image, the FC4 model predicts the scene illuminant (R,G,B). We then divide
the illuminant out of the image, applying the correction in linear RGB space.
The transformation between linear and sRGB spaces is done as described in the sRGB standard,
which is a nonlinear Gamma correction with exponent 2.4 and extra handling of very small values.
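For reference, a sketch of the piecewise decoding from sRGB to linear, matching the
constants used in the code below (the encoding is its inverse, with the breakpoint
at 0.0031308 on the linear side):
    linear = srgb / 12.92                     if srgb <= 0.04045
    linear = ((srgb + 0.055) / 1.055)^2.4     otherwise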
This sample is written for 8-bit images. The FC4 model accepts RGB images with Gamma scaling applied.
The FC4 model was trained on the Gehler-Shi dataset. The dataset includes
568 images with ground truth corrections as well as ground truth illuminants. The linear
RGB images from the dataset were used with a Gamma correction of 2.2 applied.
The model is a pretrained fold 0 of a training pipeline on the Gehler-Shi dataset, from the PyTorch
implementation of the FC4 algorithm by Matteo Rizzo. The model was converted from a .pth file to ONNX
using torch.onnx.export. The model can be downloaded from the following link:
https://raw.githubusercontent.com/MykhailoTrushch/opencv/d6ab21353a87e4c527e38e464384c7ee78e96e22/samples/dnn/models/fc4_fold_0.onnx
Copyright (c) 2017 Yuanming Hu, Baoyuan Wang, Stephen Lin
Copyright (c) 2021 Matteo Rizzo
Licensed under the MIT license.
References:
Yuanming Hu, Baoyuan Wang, and Stephen Lin. FC⁴: Fully Convolutional Color
Constancy with Confidence-Weighted Pooling. CVPR, 2017, pp. 4085-4094.
Implementations of FC4:
https://github.com/yuanming-hu/fc4/
https://github.com/matteo-rizzo/fc4-pytorch
Lilong Shi and Brian Funt, "Re-processed Version of the Gehler Color
Constancy Dataset of 568 Images," accessed from http://www.cs.sfu.ca/~colour/data/
IEC 61966-2-1:1999. Multimedia Systems and Equipment - Colour Measurement and Management -
Part 2-1: Colour Management - Default RGB Colour Space - sRGB. IEC Standard, 1999.
'''
import argparse
import os
import sys

import numpy as np
import cv2 as cv

from common import *
# Normalization constant for 8-bit values
NORMALIZE_FACTOR = 1.0 / 255.0
# Constants for the sRGB <-> linear conversion:
# SRGB_THRESHOLD / LINEAR_THRESHOLD: breakpoints between linear and gamma regions
# SRGB_SLOPE: slope of the linear segment near black
# SRGB_ALPHA: offset to ensure continuity at the threshold
# SRGB_EXP: gamma exponent
SRGB_THRESHOLD = 0.04045
SRGB_ALPHA = 0.055
SRGB_SLOPE = 12.92
SRGB_EXP = 2.4
LINEAR_THRESHOLD = 0.0031308
EPS = 1e-10
def srgb_to_linear(rgb: np.ndarray) -> np.ndarray:
    # Piecewise sRGB decoding: linear segment near black, gamma curve elsewhere.
    low = rgb / SRGB_SLOPE
    high = np.power((rgb + SRGB_ALPHA) / (1.0 + SRGB_ALPHA), SRGB_EXP, dtype=np.float32)
    return np.where(rgb <= SRGB_THRESHOLD, low, high).astype(np.float32)


def linear_to_srgb(lin: np.ndarray) -> np.ndarray:
    # Inverse of srgb_to_linear, with the breakpoint on the linear side.
    low = lin * SRGB_SLOPE
    high = (1.0 + SRGB_ALPHA) * np.power(lin, 1.0 / SRGB_EXP, dtype=np.float32) - SRGB_ALPHA
    return np.where(lin <= LINEAR_THRESHOLD, low, high).astype(np.float32)
def correct(bgr8u: np.ndarray, illum_rgb_linear: np.ndarray) -> np.ndarray:
    assert bgr8u.dtype == np.uint8 and bgr8u.ndim == 3 and bgr8u.shape[2] == 3
    bgr = bgr8u.astype(np.float32) * NORMALIZE_FACTOR
    lin = srgb_to_linear(bgr)
    e_r = max(float(illum_rgb_linear[0]), EPS)
    e_g = max(float(illum_rgb_linear[1]), EPS)
    e_b = max(float(illum_rgb_linear[2]), EPS)
    # Scale by sqrt(3) so that a neutral, unit-norm illuminant estimate
    # (1/sqrt(3) per channel) yields a divisor of 1; the array is in BGR
    # order to match the image channel layout.
    s3 = np.float32(np.sqrt(3.0))
    corr_bgr = np.array([e_b * s3 + EPS,
                         e_g * s3 + EPS,
                         e_r * s3 + EPS],
                        dtype=np.float32)
    corrected = lin / corr_bgr.reshape(1, 1, 3)
    # Renormalize by the global maximum to keep values in [0, 1].
    max_val = float(corrected.max()) + EPS
    corrected /= max_val
    corrected = np.clip(corrected, 0.0, 1.0)
    srgb = linear_to_srgb(corrected)
    out_bgr8 = (srgb * 255.0 + 0.5).astype(np.uint8)
    return out_bgr8
def annotate(img_bgr: np.ndarray, title: str) -> None:
    fs = max(0.5, min(img_bgr.shape[1], img_bgr.shape[0]) / 800.0)
    th = max(1, int(round(fs * 2)))
    cv.putText(img_bgr, title, (10, 30), cv.FONT_HERSHEY_SIMPLEX, fs, (0, 255, 0), th)
def get_args_parser(func_args):
    backends = ("default", "openvino", "opencv", "vkcom", "cuda", "webnn")
    targets = ("cpu", "opencl", "opencl_fp16", "ncs2_vpu", "hddl_vpu", "vulkan",
               "cuda", "cuda_fp16")
    p = argparse.ArgumentParser(add_help=False)
    p.add_argument('--zoo',
                   default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models.yml'),
                   help='An optional path to file with preprocessing parameters.')
    p.add_argument("--input", help="Path to input image", default="castle.jpg")
    p.add_argument('--backend', default="default", type=str, choices=backends,
                   help="Choose one of computation backends: "
                        "default: automatically (by default), "
                        "openvino: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
                        "opencv: OpenCV implementation, "
                        "vkcom: VKCOM, "
                        "cuda: CUDA, "
                        "webnn: WebNN")
    p.add_argument('--target', default="cpu", type=str, choices=targets,
                   help="Choose one of target computation devices: "
                        "cpu: CPU target (by default), "
                        "opencl: OpenCL, "
                        "opencl_fp16: OpenCL fp16 (half-float precision), "
                        "ncs2_vpu: NCS2 VPU, "
                        "hddl_vpu: HDDL VPU, "
                        "vulkan: Vulkan, "
                        "cuda: CUDA, "
                        "cuda_fp16: CUDA fp16 (half-float preprocess)")
    args, _ = p.parse_known_args()
    add_preproc_args(args.zoo, p, 'auto_white_balance', prefix="", alias="fc4")
    p = argparse.ArgumentParser(
        parents=[p],
        description="FC4 Color Constancy (ONNX): "
                    "predicts illuminant and applies white balance.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter
    )
    return p.parse_args(func_args)
def main(func_args=None):
    args = get_args_parser(func_args)
    args.model = findModel(args.model, args.sha1)
    try:
        net = cv.dnn.readNetFromONNX(args.model)
        net.setPreferableBackend(get_backend_id(args.backend))
        net.setPreferableTarget(get_target_id(args.target))
    except cv.error as e:
        print(f"Error loading model: {e}", file=sys.stderr)
        sys.exit(1)
    img = cv.imread(findFile(args.input), cv.IMREAD_COLOR)
    if img is None:
        print(f"Cannot load image: {args.input}", file=sys.stderr)
        sys.exit(1)
    blob = cv.dnn.blobFromImage(
        img, scalefactor=args.scale, size=(img.shape[1], img.shape[0]),
        mean=args.mean, swapRB=args.rgb, crop=False, ddepth=cv.CV_32F
    )
    net.setInput(blob)
    try:
        out = net.forward()
    except cv.error as e:
        print(f"Forward error: {e}", file=sys.stderr)
        sys.exit(1)
    # The model outputs a single (R, G, B) illuminant estimate.
    if out.size != 3:
        print("Error: model output size is not 3 (expected an RGB illuminant)",
              file=sys.stderr)
        sys.exit(1)
    illum = out.astype(np.float32).reshape(-1)
    corrected = correct(img, illum)
    orig_vis = img.copy()
    corr_vis = corrected.copy()
    annotate(orig_vis, "Original")
    annotate(corr_vis, "FC4-corrected")
    stacked = np.hstack([orig_vis, corr_vis])
    cv.imshow("Original and Corrected Images", stacked)
    cv.waitKey(0)
    cv.destroyAllWindows()


if __name__ == "__main__":
    main()
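
A quick round-trip sanity check of the sRGB conversion helpers, as a self-contained
sketch (it re-declares the same constants so it can run standalone, outside the sample):

import numpy as np

SRGB_THRESHOLD, SRGB_ALPHA, SRGB_SLOPE, SRGB_EXP = 0.04045, 0.055, 12.92, 2.4
LINEAR_THRESHOLD = 0.0031308

def srgb_to_linear(x):
    # Linear segment near black, gamma curve elsewhere.
    return np.where(x <= SRGB_THRESHOLD, x / SRGB_SLOPE,
                    ((x + SRGB_ALPHA) / (1.0 + SRGB_ALPHA)) ** SRGB_EXP)

def linear_to_srgb(x):
    return np.where(x <= LINEAR_THRESHOLD, x * SRGB_SLOPE,
                    (1.0 + SRGB_ALPHA) * x ** (1.0 / SRGB_EXP) - SRGB_ALPHA)

x = np.linspace(0.0, 1.0, 1001)
assert np.allclose(linear_to_srgb(srgb_to_linear(x)), x, atol=1e-6)  # round trip

# With the sqrt(3) factor used by correct(), a neutral unit-norm illuminant
# (1/sqrt(3) per channel) divides every channel by exactly 1, i.e. a no-op.
illum = np.full(3, 1.0 / np.sqrt(3.0))
assert np.allclose(illum * np.sqrt(3.0), 1.0)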


@@ -538,3 +538,17 @@ seemoredetails:
  height: 512
  sample: "super_resolution"
  input: true
################################################################################
# Auto white balance models.
################################################################################

fc4:
  load_info:
    url: "https://raw.githubusercontent.com/MykhailoTrushch/fc4-models/main/fc4_fold_0.onnx"
    sha1: "e8a9a65ec0baaae3e4c97b34274a620eb362e905"
  model: "fc4_fold_0.onnx"
  sample: "auto_white_balance"
  scale: 0.00392156862
  rgb: true
  mean: 0
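
For readers wiring the preprocessing up by hand, a minimal sketch of how these
fields map onto blobFromImage (mirroring what the samples do; fc4_fold_0.onnx is
the file fetched per load_info above, and castle.jpg is the image added by this PR):

import cv2 as cv

img = cv.imread("castle.jpg")  # 8-bit BGR
# scale 0.00392156862 is 1/255, so input values land in [0, 1];
# rgb: true becomes swapRB=True (the model expects RGB, imread returns BGR);
# mean: 0 means no mean subtraction.
blob = cv.dnn.blobFromImage(img, scalefactor=1.0 / 255.0,
                            size=(img.shape[1], img.shape[0]),
                            mean=(0, 0, 0), swapRB=True, crop=False,
                            ddepth=cv.CV_32F)
net = cv.dnn.readNetFromONNX("fc4_fold_0.onnx")
net.setInput(blob)
illum_rgb = net.forward().reshape(-1)  # predicted (R, G, B) illuminant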