mirror of
https://github.com/zebrajr/opencv.git
synced 2025-12-06 00:19:46 +01:00
Merge pull request #27749 from MykhailoTrushch:fc4
Add auto white balance DNN algorithm FC4
This commit is contained in:
commit
da8c313a90
BIN
samples/data/castle.jpg
Normal file
BIN
samples/data/castle.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 162 KiB |
265
samples/dnn/auto_white_balance.cpp
Normal file
265
samples/dnn/auto_white_balance.cpp
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level
|
||||
// directory of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
/*
|
||||
Auto white balance using FC4: https://github.com/yuanming-hu/fc4
|
||||
|
||||
Color constancy is a method to make colors of objects render correctly on a photo.
|
||||
White balance aims to make white objects appear white on an image and not a shade of any
|
||||
other color, independent of the actual light setting. White balance correction creates
|
||||
a neutral looking coloring of the objects, and generally makes colors look more similar
|
||||
to their 'true' colors under different light conditions.
|
||||
|
||||
Given an RGB image, the FC4 model predicts scene illuminant (R,G,B). We then apply
|
||||
the illuminant to the image, applying the correction in the linear RGB space.
|
||||
The transformation between linear and sRGB spaces is done as described in the sRGB standard,
|
||||
which is a nonlinear Gamma correction with exponent 2.4 and extra handling of very small values.
|
||||
This sample is written for 8bit images. The FC4 model accepts RGB images with applied Gamma scaling.
|
||||
|
||||
The training of the FC4 model was done on the Gehler-Shi dataset. The dataset includes
|
||||
568 images and ground truth corrections, as well as ground truth illuminants. The linear
|
||||
RGB images from the dataset were used with Gamma correction of 2.2 applied.
|
||||
|
||||
The model is a pretrained fold 0 of a training pipeline on the Gehler-Shi dataset, from the PyTorch
|
||||
implementation of the FC4 algorithm by Mateo Rizzo. The model was converted from a .pth file to onnx
|
||||
using torch.onnx.export. The model can be downloaded in the following link:
|
||||
https://raw.githubusercontent.com/MykhailoTrushch/opencv/d6ab21353a87e4c527e38e464384c7ee78e96e22/samples/dnn/models/fc4_fold_0.onnx
|
||||
|
||||
Copyright (c) 2017 Yuanming Hu, Baoyuan Wang, Stephen Lin
|
||||
Copyright (c) 2021 Matteo Rizzo
|
||||
|
||||
Licensed under the MIT license.
|
||||
|
||||
References:
|
||||
|
||||
Yuanming Hu, Baoyuan Wang, and Stephen Lin. “FC⁴: Fully Convolutional Color
|
||||
Constancy with Confidence-Weighted Pooling.” CVPR, 2017, pp. 4085–4094.
|
||||
|
||||
Implementations of FC4:
|
||||
https://github.com/yuanming-hu/fc4/
|
||||
https://github.com/matteo-rizzo/fc4-pytorch
|
||||
|
||||
Lilong Shi and Brian Funt, "Re-processed Version of the Gehler Color Constancy Dataset of 568 Images,"
|
||||
accessed from http://www.cs.sfu.ca/~colour/data/
|
||||
|
||||
“IEC 61966-2-1:1999 – Multimedia Systems and Equipment – Colour Measurement and Management – Part 2-1: Colour Management – Default RGB Colour Space – sRGB.” IEC Standard, 1999.
|
||||
*/
|
||||
|
||||
#include <iostream>
|
||||
#include <opencv2/dnn.hpp>
|
||||
#include <opencv2/highgui.hpp>
|
||||
#include <opencv2/imgproc.hpp>
|
||||
|
||||
#include "common.hpp"
|
||||
|
||||
using namespace cv;
|
||||
using namespace cv::dnn;
|
||||
using namespace std;
|
||||
|
||||
const string param_keys =
|
||||
"{ help h | | Print help message }"
|
||||
"{ @alias | fc4 | Model alias from models.yml "
|
||||
"(optional) }"
|
||||
"{ zoo | ../dnn/models.yml | Path to models.yml file "
|
||||
"(optional) }"
|
||||
"{ input i | castle.png | Path to input image }";
|
||||
;
|
||||
|
||||
const string backend_keys =
|
||||
format("{ backend | default | Choose one of computation backends: "
|
||||
"default: automatically (by default), "
|
||||
"openvino: Intel's Deep Learning Inference Engine "
|
||||
"(https://software.intel.com/openvino-toolkit), "
|
||||
"opencv: OpenCV implementation, "
|
||||
"vkcom: VKCOM, "
|
||||
"cuda: CUDA, "
|
||||
"webnn: WebNN }");
|
||||
|
||||
const string target_keys =
|
||||
format("{ target | cpu | Choose one of target computation devices: "
|
||||
"cpu: CPU target (by default), "
|
||||
"opencl: OpenCL, "
|
||||
"opencl_fp16: OpenCL fp16 (half-float precision), "
|
||||
"vpu: VPU, "
|
||||
"vulkan: Vulkan, "
|
||||
"cuda: CUDA, "
|
||||
"cuda_fp16: CUDA fp16 (half-float preprocess) }");
|
||||
|
||||
// Normalization constant for 8bit values
|
||||
const float NORMALIZE_FACTOR = 1.0f / 255.0f;
|
||||
|
||||
// sRGB to linear conversion constants (or vice versa):
|
||||
// SRGB_THRESHOLD / LINEAR_THRESHOLD: breakpoints between linear and gamma regions
|
||||
// SRGB_SLOPE: slope of the linear segment near black
|
||||
// SRGB_ALPHA: offset to ensure continuity at the threshold
|
||||
// SRGB_EXP: gamma exponent
|
||||
const float SRGB_THRESHOLD = 0.04045f;
|
||||
const float SRGB_ALPHA = 0.055f;
|
||||
const float SRGB_SLOPE = 12.92f;
|
||||
const float SRGB_EXP = 2.4f;
|
||||
const float LINEAR_THRESHOLD = 0.0031308f;
|
||||
const float EPS = 1e-10f;
|
||||
|
||||
static Mat srgbToLinear(const Mat &srgb32f) {
|
||||
CV_Assert(srgb32f.type() == CV_32FC3);
|
||||
const float a = SRGB_ALPHA;
|
||||
|
||||
Mat y = srgb32f;
|
||||
Mat mask_low;
|
||||
compare(y, SRGB_THRESHOLD, mask_low, CMP_LE);
|
||||
|
||||
Mat low = y / SRGB_SLOPE;
|
||||
|
||||
Mat t = (y + a) / (1.0f + a);
|
||||
Mat high;
|
||||
pow(t, SRGB_EXP, high);
|
||||
|
||||
Mat lin(y.size(), y.type(), Scalar(0, 0, 0));
|
||||
low.copyTo(lin, mask_low);
|
||||
Mat mask_high;
|
||||
bitwise_not(mask_low, mask_high);
|
||||
high.copyTo(lin, mask_high);
|
||||
|
||||
return lin;
|
||||
}
|
||||
|
||||
static Mat linearToSrgb(const Mat &lin32f) {
|
||||
CV_Assert(lin32f.type() == CV_32FC3);
|
||||
const float a = SRGB_ALPHA;
|
||||
|
||||
Mat x = lin32f;
|
||||
Mat mask_low;
|
||||
compare(x, LINEAR_THRESHOLD, mask_low, CMP_LE);
|
||||
Mat low = x * SRGB_SLOPE;
|
||||
Mat powPart;
|
||||
pow(x, 1.0 / SRGB_EXP, powPart);
|
||||
Mat high = (1.0f + a) * powPart - a;
|
||||
|
||||
Mat srgb(x.size(), x.type(), Scalar(0, 0, 0));
|
||||
low.copyTo(srgb, mask_low);
|
||||
Mat mask_high;
|
||||
bitwise_not(mask_low, mask_high);
|
||||
high.copyTo(srgb, mask_high);
|
||||
|
||||
return srgb;
|
||||
}
|
||||
|
||||
static Mat correct(const Mat &bgr8u, const Vec3f &illumRGB_linear) {
|
||||
Mat f32;
|
||||
bgr8u.convertTo(f32, CV_32F, NORMALIZE_FACTOR);
|
||||
|
||||
Mat lin = srgbToLinear(f32);
|
||||
|
||||
const float eR = std::max(illumRGB_linear[0], EPS);
|
||||
const float eG = std::max(illumRGB_linear[1], EPS);
|
||||
const float eB = std::max(illumRGB_linear[2], EPS);
|
||||
|
||||
float s3 = std::sqrt(3.0f);
|
||||
Scalar corr(eB * s3 + EPS, eG * s3 + EPS, eR * s3 + EPS);
|
||||
|
||||
Mat corrected;
|
||||
divide(lin, corr, corrected);
|
||||
|
||||
std::vector<Mat> ch;
|
||||
split(corrected, ch);
|
||||
double m0, m1, m2;
|
||||
minMaxLoc(ch[0], nullptr, &m0);
|
||||
minMaxLoc(ch[1], nullptr, &m1);
|
||||
minMaxLoc(ch[2], nullptr, &m2);
|
||||
float maxVal = static_cast<float>(std::max({m0, m1, m2})) + EPS;
|
||||
corrected /= maxVal;
|
||||
min(corrected, 1.0, corrected);
|
||||
max(corrected, 0.0, corrected);
|
||||
|
||||
Mat srgb = linearToSrgb(corrected);
|
||||
|
||||
Mat out;
|
||||
srgb.convertTo(out, CV_8U, 255.0);
|
||||
return out;
|
||||
}
|
||||
|
||||
static void annotate(Mat &img, const string &title) {
|
||||
double fs = std::max(0.5, std::min(img.cols, img.rows) / 800.0);
|
||||
int th = std::max(1, (int)std::round(fs * 2));
|
||||
putText(img, title, Point(10, 30), FONT_HERSHEY_SIMPLEX, fs,
|
||||
Scalar(0, 255, 0), th);
|
||||
}
|
||||
|
||||
int main(int argc, char **argv) {
|
||||
const string about = "FC4 Color Constancy (ONNX) sample.\n"
|
||||
"Predicts scene illuminant and corrects the white "
|
||||
"balance of the image.\n";
|
||||
|
||||
string keys = param_keys + backend_keys + target_keys;
|
||||
|
||||
CommandLineParser parser(argc, argv, keys);
|
||||
if (parser.has("help")) {
|
||||
cout << about << endl;
|
||||
parser.printMessage();
|
||||
return 0;
|
||||
}
|
||||
|
||||
string modelName = parser.get<String>("@alias");
|
||||
string zooFile = samples::findFile(parser.get<String>("zoo"));
|
||||
keys += genPreprocArguments(modelName, zooFile);
|
||||
parser = CommandLineParser(argc, argv, keys);
|
||||
|
||||
float scale = parser.get<float>("scale");
|
||||
Scalar mean = parser.get<Scalar>("mean");
|
||||
bool swapRB = parser.get<bool>("rgb");
|
||||
String backend = parser.get<String>("backend");
|
||||
String target = parser.get<String>("target");
|
||||
String sha1 = parser.get<String>("sha1");
|
||||
string model = findModel(parser.get<String>("model"), sha1);
|
||||
string inputPath = findFile(parser.get<String>("input"));
|
||||
|
||||
if (model.empty()) {
|
||||
cerr << "Model file not found\n";
|
||||
return -1;
|
||||
}
|
||||
|
||||
Net net;
|
||||
try {
|
||||
net = readNetFromONNX(model);
|
||||
net.setPreferableBackend(getBackendID(backend));
|
||||
net.setPreferableTarget(getTargetID(target));
|
||||
} catch (const Exception &e) {
|
||||
cerr << "Error loading model: " << e.what() << endl;
|
||||
return -1;
|
||||
}
|
||||
Mat img = imread(inputPath, IMREAD_COLOR);
|
||||
if (img.empty()) {
|
||||
cerr << "Cannot load image: " << inputPath << endl;
|
||||
return -1;
|
||||
}
|
||||
Mat blob;
|
||||
blob = blobFromImage(img, scale, img.size(), mean, swapRB, /*crop=*/false,
|
||||
/*type=*/CV_32F);
|
||||
net.setInput(blob);
|
||||
|
||||
Mat out;
|
||||
try {
|
||||
out = net.forward();
|
||||
} catch (const Exception &e) {
|
||||
cerr << "Forward error: " << e.what() << endl;
|
||||
return -1;
|
||||
}
|
||||
|
||||
const float *p = out.ptr<float>(0);
|
||||
CV_Assert(out.total() == 3);
|
||||
Vec3f illum = Vec3f(p[0], p[1], p[2]);
|
||||
|
||||
Mat corrected = correct(img, illum);
|
||||
|
||||
Mat origVis = img.clone();
|
||||
Mat corrVis = corrected.clone();
|
||||
annotate(origVis, "Original");
|
||||
annotate(corrVis, "FC4-corrected");
|
||||
Mat stacked;
|
||||
hconcat(origVis, corrVis, stacked);
|
||||
imshow("Original and Corrected Images", stacked);
|
||||
waitKey(0);
|
||||
destroyAllWindows();
|
||||
return 0;
|
||||
}
|
||||
202
samples/dnn/auto_white_balance.py
Normal file
202
samples/dnn/auto_white_balance.py
Normal file
|
|
@ -0,0 +1,202 @@
|
|||
#!/usr/bin/env python3
|
||||
# This file is part of OpenCV project.
|
||||
# It is subject to the license terms in the LICENSE file found in the top-level
|
||||
# directory of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
'''
|
||||
Auto white balance using FC4: https://github.com/yuanming-hu/fc4
|
||||
|
||||
Color constancy is a method to make colors of objects render correctly on a photo.
|
||||
White balance aims to make white objects appear white on an image and not a shade of any
|
||||
other color, independent of the actual light setting. White balance correction creates
|
||||
a neutral looking coloring of the objects, and generally makes colors look more similar
|
||||
to their 'true' colors under different light conditions.
|
||||
|
||||
Given an RGB image, the FC4 model predicts scene illuminant (R,G,B). We then apply
|
||||
the illuminant to the image, applying the correction in the linear RGB space.
|
||||
The transformation between linear and sRGB spaces is done as described in the sRGB standard,
|
||||
which is a nonlinear Gamma correction with exponent 2.4 and extra handling of very small values.
|
||||
This sample is written for 8bit images. The FC4 model accepts RGB images with applied Gamma scaling.
|
||||
|
||||
The training of the FC4 model was done on the Gehler-Shi dataset. The dataset includes
|
||||
568 images and ground truth corrections, as well as ground truth illuminants. The linear
|
||||
RGB images from the dataset were used with Gamma correction of 2.2 applied.
|
||||
|
||||
The model is a pretrained fold 0 of a training pipeline on the Gehler-Shi dataset, from the PyTorch
|
||||
implementation of the FC4 algorithm by Mateo Rizzo. The model was converted from a .pth file to onnx
|
||||
using torch.onnx.export. The model can be downloaded in the following link:
|
||||
https://raw.githubusercontent.com/MykhailoTrushch/opencv/d6ab21353a87e4c527e38e464384c7ee78e96e22/samples/dnn/models/fc4_fold_0.onnx
|
||||
|
||||
Copyright (c) 2017 Yuanming Hu, Baoyuan Wang, Stephen Lin
|
||||
Copyright (c) 2021 Matteo Rizzo
|
||||
|
||||
Licensed under the MIT license.
|
||||
|
||||
References:
|
||||
|
||||
Yuanming Hu, Baoyuan Wang, and Stephen Lin. “FC⁴: Fully Convolutional Color
|
||||
Constancy with Confidence-Weighted Pooling.” CVPR, 2017, pp. 4085–4094.
|
||||
|
||||
Implementations of FC4:
|
||||
https://github.com/yuanming-hu/fc4/
|
||||
https://github.com/matteo-rizzo/fc4-pytorch
|
||||
|
||||
Lilong Shi and Brian Funt, "Re-processed Version of the Gehler Color
|
||||
Constancy Dataset of 568 Images," accessed from http://www.cs.sfu.ca/~colour/data/
|
||||
|
||||
“IEC 61966-2-1:1999 – Multimedia Systems and Equipment – Colour Measurement and Management –
|
||||
Part 2-1: Colour Management – Default RGB Colour Space – sRGB.” IEC Standard, 1999.
|
||||
'''
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import numpy as np
|
||||
import cv2 as cv
|
||||
|
||||
from common import *
|
||||
|
||||
|
||||
# Normalization constant for 8bit values
|
||||
NORMALIZE_FACTOR = 1.0 / 255.0
|
||||
|
||||
# sRGB to linear conversion constants (or vice versa):
|
||||
# SRGB_THRESHOLD / LINEAR_THRESHOLD: breakpoints between linear and gamma regions
|
||||
# SRGB_SLOPE: slope of the linear segment near black
|
||||
# SRGB_ALPHA: offset to ensure continuity at the threshold
|
||||
# SRGB_EXP: gamma exponent
|
||||
SRGB_THRESHOLD = 0.04045
|
||||
SRGB_ALPHA = 0.055
|
||||
SRGB_SLOPE = 12.92
|
||||
SRGB_EXP = 2.4
|
||||
LINEAR_THRESHOLD = 0.0031308
|
||||
EPS = 1e-10
|
||||
|
||||
def srgb_to_linear(rgb: np.ndarray) -> np.ndarray:
|
||||
low = rgb / SRGB_SLOPE
|
||||
high = np.power((rgb + SRGB_ALPHA) / (1.0 + SRGB_ALPHA), SRGB_EXP, dtype=np.float32)
|
||||
return np.where(rgb <= SRGB_THRESHOLD, low, high).astype(np.float32)
|
||||
|
||||
def linear_to_srgb(lin: np.ndarray) -> np.ndarray:
|
||||
low = lin * SRGB_SLOPE
|
||||
high = (1.0 + SRGB_ALPHA) * np.power(lin, 1.0 / SRGB_EXP, dtype=np.float32) - SRGB_ALPHA
|
||||
return np.where(lin <= LINEAR_THRESHOLD, low, high).astype(np.float32)
|
||||
|
||||
def correct(bgr8u: np.ndarray, illum_rgb_linear: np.ndarray) -> np.ndarray:
|
||||
assert bgr8u.dtype == np.uint8 and bgr8u.ndim == 3 and bgr8u.shape[2] == 3
|
||||
|
||||
bgr = bgr8u.astype(np.float32) * NORMALIZE_FACTOR
|
||||
lin = srgb_to_linear(bgr)
|
||||
e_r = max(float(illum_rgb_linear[0]), EPS)
|
||||
e_g = max(float(illum_rgb_linear[1]), EPS)
|
||||
e_b = max(float(illum_rgb_linear[2]), EPS)
|
||||
s3 = np.float32(np.sqrt(3.0))
|
||||
corr_bgr = np.array([e_b * s3 + EPS,
|
||||
e_g * s3 + EPS,
|
||||
e_r * s3 + EPS],
|
||||
dtype=np.float32)
|
||||
|
||||
corrected = lin / corr_bgr.reshape(1, 1, 3)
|
||||
|
||||
max_val = float(corrected.max()) + EPS
|
||||
corrected /= max_val
|
||||
corrected = np.clip(corrected, 0.0, 1.0)
|
||||
|
||||
srgb = linear_to_srgb(corrected)
|
||||
|
||||
out_bgr8 = (srgb * 255.0 + 0.5).astype(np.uint8)
|
||||
return out_bgr8
|
||||
|
||||
def annotate(img_bgr: np.ndarray, title: str) -> None:
|
||||
fs = max(0.5, min(img_bgr.shape[1], img_bgr.shape[0]) / 800.0)
|
||||
th = max(1, int(round(fs * 2)))
|
||||
cv.putText(img_bgr, title, (10, 30), cv.FONT_HERSHEY_SIMPLEX, fs, (0,255,0), th)
|
||||
|
||||
def get_args_parser(func_args):
|
||||
backends = ("default", "openvino", "opencv", "vkcom", "cuda", "webnn")
|
||||
targets = ("cpu", "opencl", "opencl_fp16", "ncs2_vpu", "hddl_vpu", "vulkan",
|
||||
"cuda", "cuda_fp16")
|
||||
|
||||
p = argparse.ArgumentParser(add_help=False)
|
||||
p.add_argument('--zoo',
|
||||
default=os.path.join(os.path.dirname(os.path.abspath(__file__)), 'models.yml'),
|
||||
help='An optional path to file with preprocessing parameters.')
|
||||
p.add_argument("--input", help="Path to input image", default="castle.png")
|
||||
p.add_argument('--backend', default="default", type=str, choices=backends,
|
||||
help="Choose one of computation backends: "
|
||||
"default: automatically (by default), "
|
||||
"openvino: Intel's Deep Learning Inference Engine (https://software.intel.com/openvino-toolkit), "
|
||||
"opencv: OpenCV implementation, "
|
||||
"vkcom: VKCOM, "
|
||||
"cuda: CUDA, "
|
||||
"webnn: WebNN")
|
||||
p.add_argument('--target', default="cpu", type=str, choices=targets,
|
||||
help="Choose one of target computation devices: "
|
||||
"cpu: CPU target (by default), "
|
||||
"opencl: OpenCL, "
|
||||
"opencl_fp16: OpenCL fp16 (half-float precision), "
|
||||
"ncs2_vpu: NCS2 VPU, "
|
||||
"hddl_vpu: HDDL VPU, "
|
||||
"vulkan: Vulkan, "
|
||||
"cuda: CUDA, "
|
||||
"cuda_fp16: CUDA fp16 (half-float preprocess)")
|
||||
|
||||
args, _ = p.parse_known_args()
|
||||
add_preproc_args(args.zoo, p, 'auto_white_balance', prefix="", alias="fc4")
|
||||
p = argparse.ArgumentParser(
|
||||
parents=[p],
|
||||
description="FC4 Color Constancy (ONNX): " \
|
||||
"predicts illuminant and applies white balance.",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
return p.parse_args(func_args)
|
||||
|
||||
|
||||
|
||||
def main(func_args=None):
|
||||
args = get_args_parser(func_args)
|
||||
args.model = findModel(args.model, args.sha1)
|
||||
|
||||
try:
|
||||
net = cv.dnn.readNetFromONNX(args.model)
|
||||
net.setPreferableBackend(get_backend_id(args.backend))
|
||||
net.setPreferableTarget(get_target_id(args.target))
|
||||
except cv.error as e:
|
||||
print(f"Error loading model: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
img = cv.imread(findFile(args.input), cv.IMREAD_COLOR)
|
||||
if img is None:
|
||||
print(f"Cannot load image: {args.input}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
blob = cv.dnn.blobFromImage(
|
||||
img, scalefactor=args.scale, size=(img.shape[1], img.shape[0]),
|
||||
mean=args.mean, swapRB=args.rgb, crop=False, ddepth=cv.CV_32F
|
||||
)
|
||||
net.setInput(blob)
|
||||
|
||||
try:
|
||||
out = net.forward()
|
||||
except cv.error as e:
|
||||
print(f"Forward error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
illum = out.astype(np.float32).reshape(-1)
|
||||
if out.size != 3:
|
||||
print("Error: model output of size not equal to 3 (should output 3 illuminants in RGB order)")
|
||||
sys.exit(-1)
|
||||
|
||||
corrected = correct(img, illum)
|
||||
|
||||
orig_vis = img.copy()
|
||||
corr_vis = corrected.copy()
|
||||
annotate(orig_vis, "Original")
|
||||
annotate(corr_vis, "FC4-corrected")
|
||||
stacked = np.hstack([orig_vis, corr_vis])
|
||||
cv.imshow("Original and Corrected Images", stacked)
|
||||
cv.waitKey(0)
|
||||
cv.destroyAllWindows()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -538,3 +538,17 @@ seemoredetails:
|
|||
height: 512
|
||||
sample: "super_resolution"
|
||||
input: true
|
||||
|
||||
################################################################################
|
||||
# Auto white balance models.
|
||||
################################################################################
|
||||
|
||||
fc4:
|
||||
load_info:
|
||||
url: "https://raw.githubusercontent.com/MykhailoTrushch/fc4-models/main/fc4_fold_0.onnx"
|
||||
sha1: "e8a9a65ec0baaae3e4c97b34274a620eb362e905"
|
||||
model: "fc4_fold_0.onnx"
|
||||
sample: "auto_white_balance"
|
||||
scale: 0.00392156862
|
||||
rgb: true
|
||||
mean: 0
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user