Merge pull request #747 from sbokov:color_constancy

Maksim Shabunin 9 years ago
commit 191d5c1de9
17 changed files:

  1. modules/xphoto/include/opencv2/xphoto/white_balance.hpp (231 changes)
  2. modules/xphoto/perf/perf_grayworld.cpp (4 changes)
  3. modules/xphoto/perf/perf_learning_based_color_balance.cpp (6 changes)
  4. modules/xphoto/samples/color_balance.cpp (68 changes)
  5. modules/xphoto/samples/color_balance_benchmark.py (48 changes)
  6. modules/xphoto/samples/grayworld_color_balance.cpp (62 changes)
  7. modules/xphoto/samples/learn_color_balance.py (94 changes)
  8. modules/xphoto/samples/learning_based_color_balance.cpp (51 changes)
  9. modules/xphoto/samples/simple_color_balance.cpp (60 changes)
  10. modules/xphoto/src/grayworld_white_balance.cpp (85 changes)
  11. modules/xphoto/src/learning_based_color_balance.cpp (250 changes)
  12. modules/xphoto/src/learning_based_color_balance_model.hpp (10 changes)
  13. modules/xphoto/src/simple_color_balance.cpp (133 changes)
  14. modules/xphoto/test/simple_color_balance.cpp (3 changes)
  15. modules/xphoto/test/test_grayworld.cpp (6 changes)
  16. modules/xphoto/test/test_learning_based_color_balance.cpp (11 changes)
  17. modules/xphoto/tutorials/training_white_balance.markdown (42 changes)

@@ -58,129 +58,172 @@ namespace xphoto
//! @addtogroup xphoto
//! @{

/** @brief The base class for auto white balance algorithms.
*/
class CV_EXPORTS_W WhiteBalancer : public Algorithm
{
  public:
    /** @brief Applies white balancing to the input image

    @param src Input image
    @param dst White balancing result
    @sa cvtColor, equalizeHist
    */
    CV_WRAP virtual void balanceWhite(InputArray src, OutputArray dst) = 0;
};
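All three algorithms introduced by this PR are used through this common interface. A minimal sketch (the image file name is illustrative, not part of the module):

#include "opencv2/imgcodecs.hpp"
#include "opencv2/xphoto.hpp"

int main()
{
    cv::Mat src = cv::imread("photo.jpg"), dst;  // "photo.jpg" is an assumed example input
    cv::Ptr<cv::xphoto::WhiteBalancer> wb = cv::xphoto::createGrayworldWB();
    wb->balanceWhite(src, dst);                  // dst holds the white-balanced result
    return 0;
}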
/** @brief A simple white balance algorithm that works by independently stretching
    each of the input image channels to the specified range. For increased robustness
    it ignores the top and bottom \f$p\%\f$ of pixel values.
*/
class CV_EXPORTS_W SimpleWB : public WhiteBalancer
{
  public:
    /** @brief Input image range minimum value
    @see setInputMin */
    CV_WRAP virtual float getInputMin() const = 0;
    /** @copybrief getInputMin @see getInputMin */
    CV_WRAP virtual void setInputMin(float val) = 0;

    /** @brief Input image range maximum value
    @see setInputMax */
    CV_WRAP virtual float getInputMax() const = 0;
    /** @copybrief getInputMax @see getInputMax */
    CV_WRAP virtual void setInputMax(float val) = 0;

    /** @brief Output image range minimum value
    @see setOutputMin */
    CV_WRAP virtual float getOutputMin() const = 0;
    /** @copybrief getOutputMin @see getOutputMin */
    CV_WRAP virtual void setOutputMin(float val) = 0;

    /** @brief Output image range maximum value
    @see setOutputMax */
    CV_WRAP virtual float getOutputMax() const = 0;
    /** @copybrief getOutputMax @see getOutputMax */
    CV_WRAP virtual void setOutputMax(float val) = 0;

    /** @brief Percent of top/bottom values to ignore
    @see setP */
    CV_WRAP virtual float getP() const = 0;
    /** @copybrief getP @see getP */
    CV_WRAP virtual void setP(float val) = 0;
};

/** @brief Creates an instance of SimpleWB
*/
CV_EXPORTS_W Ptr<SimpleWB> createSimpleWB();
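A sketch of the per-channel mapping these parameters imply, assuming \f$v_{low}\f$ and \f$v_{high}\f$ denote the channel's \f$p\f$-th and \f$(100-p)\f$-th quantile values (the implementation in simple_color_balance.cpp below estimates them with a coarse-to-fine histogram search):

\f[ \texttt{dst} = (\texttt{src} - v_{low}) \cdot \frac{\texttt{outputMax} - \texttt{outputMin}}{v_{high} - v_{low}} + \texttt{outputMin} \f]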
/** @brief Gray-world white balance algorithm

This algorithm scales the values of pixels based on a
gray-world assumption which states that the average of all channels
should result in a gray image.

It adds a modification which thresholds pixels based on their
saturation value and only uses pixels below the provided threshold in
finding average pixel values.

Saturation is calculated using the following for a 3-channel RGB image per
pixel I and is in the range [0, 1]:

\f[ \texttt{Saturation} [I] = \frac{\textrm{max}(R,G,B) - \textrm{min}(R,G,B)}{\textrm{max}(R,G,B)} \f]

A threshold of 1 means that all pixels are used to white-balance, while a
threshold of 0 means no pixels are used. Lower thresholds are useful in
white-balancing saturated images.

Currently supports images of type @ref CV_8UC3 and @ref CV_16UC3.
*/
class CV_EXPORTS_W GrayworldWB : public WhiteBalancer
{
  public:
    /** @brief Maximum saturation for a pixel to be included in the
        gray-world assumption
    @see setSaturationThreshold */
    CV_WRAP virtual float getSaturationThreshold() const = 0;
    /** @copybrief getSaturationThreshold @see getSaturationThreshold */
    CV_WRAP virtual void setSaturationThreshold(float val) = 0;
};

/** @brief Creates an instance of GrayworldWB
*/
CV_EXPORTS_W Ptr<GrayworldWB> createGrayworldWB();
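The gains this reduces to are just inverse channel averages normalized by the largest one, as the implementation in grayworld_white_balance.cpp below computes. A worked example, assuming masked channel sums \f$(S_B, S_G, S_R) = (100, 150, 200)\f$:

\f[ g_B = \frac{200}{100} = 2.0, \qquad g_G = \frac{200}{150} \approx 1.33, \qquad g_R = \frac{200}{200} = 1.0 \f]

so the channel with the largest sum is left untouched and the dimmer channels are scaled up until all three means match.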
/** @brief More sophisticated learning-based automatic white balance algorithm.

As @ref GrayworldWB, this algorithm works by applying different gains to the input
image channels, but their computation is a bit more involved compared to the
simple gray-world assumption. More details about the algorithm can be found in
@cite Cheng2015 .

To mask out saturated pixels this function uses only pixels that satisfy the
following condition:

\f[ \frac{\textrm{max}(R,G,B)}{\texttt{range_max_val}} < \texttt{saturation_thresh} \f]

Currently supports images of type @ref CV_8UC3 and @ref CV_16UC3.
*/
class CV_EXPORTS_W LearningBasedWB : public WhiteBalancer
{
  public:
    /** @brief Implements the feature extraction part of the algorithm.

    In accordance with @cite Cheng2015 , computes the following features for the input image:
    1. Chromaticity of an average (R,G,B) tuple
    2. Chromaticity of the brightest (R,G,B) tuple (while ignoring saturated pixels)
    3. Chromaticity of the dominant (R,G,B) tuple (the one that has the highest value in the RGB histogram)
    4. Mode of the chromaticity palette, that is constructed by taking 300 most common colors according to
       the RGB histogram and projecting them on the chromaticity plane. Mode is the most high-density point
       of the palette, which is computed by a straightforward fixed-bandwidth kernel density estimator with
       an Epanechnikov kernel function.

    @param src Input three-channel image (BGR color space is assumed).
    @param dst An array of four (r,g) chromaticity tuples corresponding to the features listed above.
    */
    CV_WRAP virtual void extractSimpleFeatures(InputArray src, OutputArray dst) = 0;

    /** @brief Maximum possible value of the input image (e.g. 255 for 8 bit images,
        4095 for 12 bit images)
    @see setRangeMaxVal */
    CV_WRAP virtual int getRangeMaxVal() const = 0;
    /** @copybrief getRangeMaxVal @see getRangeMaxVal */
    CV_WRAP virtual void setRangeMaxVal(int val) = 0;

    /** @brief Threshold that is used to determine saturated pixels, i.e. pixels where at least one of the
        channels exceeds \f$\texttt{saturation_threshold}\times\texttt{range_max_val}\f$ are ignored.
    @see setSaturationThreshold */
    CV_WRAP virtual float getSaturationThreshold() const = 0;
    /** @copybrief getSaturationThreshold @see getSaturationThreshold */
    CV_WRAP virtual void setSaturationThreshold(float val) = 0;

    /** @brief Defines the size of one dimension of a three-dimensional RGB histogram that is used internally
        by the algorithm. It often makes sense to increase the number of bins for images with higher bit depth
        (e.g. 256 bins for a 12 bit image).
    @see setHistBinNum */
    CV_WRAP virtual int getHistBinNum() const = 0;
    /** @copybrief getHistBinNum @see getHistBinNum */
    CV_WRAP virtual void setHistBinNum(int val) = 0;
};

/** @brief Creates an instance of LearningBasedWB

@param path_to_model Path to a .yml file with the model. If not specified, the default model is used
*/
CV_EXPORTS_W Ptr<LearningBasedWB> createLearningBasedWB(const String& path_to_model = String());
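A minimal sketch of calling the feature extractor directly (the image file name is again an assumed example):

    cv::Mat img = cv::imread("photo.jpg");
    cv::Mat features;
    cv::Ptr<cv::xphoto::LearningBasedWB> wb = cv::xphoto::createLearningBasedWB();
    wb->extractSimpleFeatures(img, features); // four (r,g) chromaticity pairs, one per feature above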
/** @brief Implements an efficient fixed-point approximation for applying channel gains, which is
    the last step of multiple white balance algorithms.

@param src Input three-channel image in the BGR color space (either CV_8UC3 or CV_16UC3)
@param dst Output image of the same size and type as src.
@param gainB gain for the B channel
@param gainG gain for the G channel
@param gainR gain for the R channel
*/
CV_EXPORTS_W void applyChannelGains(InputArray src, OutputArray dst, float gainB, float gainG, float gainR);

//! @}
}
}
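The fixed-point trick can be sketched as follows; the shift width and rounding below are assumptions for illustration, not necessarily the constants applyChannelGains uses internally:

#include "opencv2/core.hpp"

// A sketch of one possible fixed-point scheme: the float gain is converted to an
// integer multiplier once, so each pixel costs an integer multiply and a shift
// instead of a float multiply.
static void applyGainsFixedPoint8u(const cv::Mat &src, cv::Mat &dst, float gainB, float gainG, float gainR)
{
    CV_Assert(src.type() == CV_8UC3 && src.isContinuous());
    dst.create(src.size(), src.type());
    const int shift = 8;  // assumed precision
    const int ig[3] = { cvRound(gainB * (1 << shift)),
                        cvRound(gainG * (1 << shift)),
                        cvRound(gainR * (1 << shift)) };
    const uchar *sp = src.ptr<uchar>();
    uchar *dp = dst.ptr<uchar>();
    const int n = (int)src.total() * 3;
    for (int i = 0; i < n; i++)
        dp[i] = cv::saturate_cast<uchar>((sp[i] * ig[i % 3]) >> shift);  // saturating store
}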

@@ -21,8 +21,10 @@ PERF_TEST_P( Size_WBThresh, autowbGrayworld,
    Mat dst(size, CV_8UC3);
    declare.in(src, WARMUP_RNG).out(dst);

    Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
    wb->setSaturationThreshold(wb_thresh);

    TEST_CYCLE() wb->balanceWhite(src, dst);

    SANITY_CHECK(dst);
}

@@ -65,8 +65,12 @@ PERF_TEST_P(learningBasedWBPerfTest, perf, Combine(SZ_ALL_HD, Values(CV_8UC3, CV
    RNG rng(1234);
    rng.fill(src_dscl, RNG::UNIFORM, 0, range_max_val);
    resize(src_dscl, src, src.size());

    Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
    wb->setRangeMaxVal(range_max_val);
    wb->setSaturationThreshold(0.98f);
    wb->setHistBinNum(hist_bin_num);

    TEST_CYCLE() wb->balanceWhite(src, dst);

    SANITY_CHECK_NOTHING();
}

@@ -0,0 +1,68 @@
#include "opencv2/xphoto.hpp"
#include "opencv2/highgui.hpp"

using namespace cv;
using namespace std;

const char *keys = { "{help h usage ? | | print this message}"
                     "{i | | input image name }"
                     "{o | | output image name }"
                     "{a |grayworld| color balance algorithm (simple, grayworld or learning_based)}"
                     "{m | | path to the model for the learning-based algorithm (optional) }" };

int main(int argc, const char **argv)
{
    CommandLineParser parser(argc, argv, keys);
    parser.about("OpenCV color balance demonstration sample");
    if (parser.has("help") || argc < 2)
    {
        parser.printMessage();
        return 0;
    }

    string inFilename = parser.get<string>("i");
    string outFilename = parser.get<string>("o");
    string algorithm = parser.get<string>("a");
    string modelFilename = parser.get<string>("m");

    if (!parser.check())
    {
        parser.printErrors();
        return -1;
    }

    Mat src = imread(inFilename, 1);
    if (src.empty())
    {
        printf("Cannot read image file: %s\n", inFilename.c_str());
        return -1;
    }

    Mat res;
    Ptr<xphoto::WhiteBalancer> wb;
    if (algorithm == "simple")
        wb = xphoto::createSimpleWB();
    else if (algorithm == "grayworld")
        wb = xphoto::createGrayworldWB();
    else if (algorithm == "learning_based")
        wb = xphoto::createLearningBasedWB(modelFilename);
    else
    {
        printf("Unsupported algorithm: %s\n", algorithm.c_str());
        return -1;
    }

    wb->balanceWhite(src, res);

    if (outFilename == "")
    {
        namedWindow("after white balance", 1);
        imshow("after white balance", res);
        waitKey(0);
    }
    else
        imwrite(outFilename, res);

    return 0;
}

@@ -5,6 +5,7 @@ import numpy as np
import scipy.io
import cv2
import timeit
from learn_color_balance import load_ground_truth


def load_json(path):
@@ -39,15 +40,24 @@ def stretch_to_8bit(arr, clip_percentile = 2.5):
    return arr.astype(np.uint8)


def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder, model_folder):
    new_im = None
    start_time = timeit.default_timer()
    if algo=="grayworld":
        inst = cv2.xphoto.createGrayworldWB()
        inst.setSaturationThreshold(0.95)
        new_im = inst.balanceWhite(im)
    elif algo=="nothing":
        new_im = im
    elif algo.split(":")[0]=="learning_based":
        model_path = ""
        if len(algo.split(":"))>1:
            model_path = os.path.join(model_folder, algo.split(":")[1])
        inst = cv2.xphoto.createLearningBasedWB(model_path)
        inst.setRangeMaxVal(range_thresh)
        inst.setSaturationThreshold(0.98)
        inst.setHistBinNum(bin_num)
        new_im = inst.balanceWhite(im)
    elif algo=="GT":
        gains = gt_illuminant / min(gt_illuminant)
        g1 = float(1.0 / gains[2])
@@ -59,7 +69,7 @@ def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder):
    if len(dst_folder)>0:
        if not os.path.exists(dst_folder):
            os.makedirs(dst_folder)
        im_name = ("%04d_" % i) + algo.replace(":","_") + ".jpg"
        cv2.imwrite(os.path.join(dst_folder, im_name), stretch_to_8bit(new_im))

    #recover the illuminant from the color balancing result, assuming the standard model:
@@ -140,7 +150,9 @@ if __name__ == '__main__':
        metavar="ALGORITHMS",
        default="",
        help=("Comma-separated list of color balance algorithms to evaluate. "
              "Currently available: GT,learning_based,grayworld,nothing. "
              "Use a colon to set a specific model for the learning-based "
              "algorithm, e.g. learning_based:model1.yml,learning_based:model2.yml"))
    parser.add_argument(
        "-i",
        "--input_folder",
@@ -196,6 +208,12 @@ if __name__ == '__main__':
        default="0,0",
        help=("Comma-separated range of images from the dataset to evaluate on (for instance: 0,568). "
              "All available images are used by default."))
    parser.add_argument(
        "-m",
        "--model_folder",
        metavar="MODEL_FOLDER",
        default="",
        help=("Path to the folder containing models for the learning-based color balance algorithm (optional)"))
    args, other_args = parser.parse_known_args()

    if not os.path.exists(args.input_folder):
@@ -218,22 +236,8 @@ if __name__ == '__main__':
        print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>")
        sys.exit(1)

    img_files = sorted(os.listdir(args.input_folder))
    (gt_illuminants,black_levels) = load_ground_truth(args.ground_truth)

    for algorithm in algorithm_list:
        i = 0
@@ -254,7 +258,7 @@ if __name__ == '__main__':
                im = stretch_to_8bit(im)

                (time,angular_err) = evaluate(im, algorithm, gt_illuminants[i], i, range_thresh,
                                              256 if range_thresh > 255 else 64, args.dst_folder, args.model_folder)
                state[algorithm][file] = {"angular_error": angular_err, "time": time}
                sys.stdout.write("Algorithm: %-20s Done: [%3d/%3d]\r" % (algorithm, i, sz)),
                sys.stdout.flush()

@@ -1,62 +0,0 @@
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/core/utility.hpp"

using namespace cv;
using namespace std;

const char* keys =
{
    "{i || input image name}"
    "{o || output image name}"
};

int main( int argc, const char** argv )
{
    bool printHelp = ( argc == 1 );
    printHelp = printHelp || ( argc == 2 && string(argv[1]) == "--help" );
    printHelp = printHelp || ( argc == 2 && string(argv[1]) == "-h" );

    if ( printHelp )
    {
        printf("\nThis sample demonstrates the grayworld balance algorithm\n"
               "Call:\n"
               "    simple_color_blance -i=in_image_name [-o=out_image_name]\n\n");
        return 0;
    }

    CommandLineParser parser(argc, argv, keys);
    if ( !parser.check() )
    {
        parser.printErrors();
        return -1;
    }

    string inFilename = parser.get<string>("i");
    string outFilename = parser.get<string>("o");

    Mat src = imread(inFilename, 1);
    if ( src.empty() )
    {
        printf("Cannot read image file: %s\n", inFilename.c_str());
        return -1;
    }

    Mat res(src.size(), src.type());
    xphoto::autowbGrayworld(src, res);

    if ( outFilename == "" )
    {
        namedWindow("after white balance", 1);
        imshow("after white balance", res);
        waitKey(0);
    }
    else
        imwrite(outFilename, res);

    return 0;
}

@@ -80,7 +80,7 @@ def get_tree_node_lists(tree, tree_depth):
    return (dst_feature_idx, dst_thresh_vals, dst_leaf_vals)


def generate_code(model, input_params, use_YML, out_file):
    feature_idx = []
    thresh_vals = []
    leaf_vals = []
@@ -95,31 +95,60 @@ def generate_code(model, input_params):
        feature_idx += local_feature_idx
        thresh_vals += local_thresh_vals
        leaf_vals += local_leaf_vals

    if use_YML:
        fs = cv2.FileStorage(out_file, 1)
        fs.write("num_trees", len(model))
        fs.write("num_tree_nodes", 2**depth)
        fs.write("feature_idx", np.array(feature_idx).astype(np.uint8))
        fs.write("thresh_vals", np.array(thresh_vals).astype(np.float32))
        fs.write("leaf_vals", np.array(leaf_vals).astype(np.float32))
        fs.release()
    else:
        res = "/* This file was automatically generated by learn_color_balance.py script\n" +\
              " * using the following parameters:\n"
        for key in input_params:
            res += " " + key + " " + input_params[key]
        res += "\n */\n"
        res += "const int num_features = 4;\n"
        res += "const int _num_trees = " + str(len(model)) + ";\n"
        res += "const int _num_tree_nodes = " + str(2**depth) + ";\n"
        res += "unsigned char _feature_idx[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + str(feature_idx[0])
        for i in range(1,len(feature_idx)):
            res += "," + str(feature_idx[i])
        res += "};\n"
        res += "float _thresh_vals[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + ("%.3ff" % thresh_vals[0])[1:]
        for i in range(1,len(thresh_vals)):
            res += "," + ("%.3ff" % thresh_vals[i])[1:]
        res += "};\n"
        res += "float _leaf_vals[_num_trees*num_features*2*_num_tree_nodes] = {" + ("%.3ff" % leaf_vals[0])[1:]
        for i in range(1,len(leaf_vals)):
            res += "," + ("%.3ff" % leaf_vals[i])[1:]
        res += "};\n"
        f = open(out_file,"w")
        f.write(res)
        f.close()


def load_ground_truth(gt_path):
    gt = scipy.io.loadmat(gt_path)
    base_gt_illuminants = []
    black_levels = []
    if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
        #NUS 8-camera dataset format
        base_gt_illuminants = gt["groundtruth_illuminants"]
        black_levels = len(base_gt_illuminants) * [gt["darkness_level"][0][0]]
    elif "real_rgb" in gt.keys():
        #Gehler-Shi dataset format
        base_gt_illuminants = gt["real_rgb"]
        black_levels = 87 * [0] + (len(base_gt_illuminants) - 87) * [129]
    else:
        print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
        sys.exit(1)
    return (base_gt_illuminants, black_levels)


if __name__ == '__main__':
@@ -153,8 +182,9 @@ if __name__ == '__main__':
        "-o",
        "--out",
        metavar="OUT",
        default="color_balance_model.yml",
        help="Path to the output learnt model. Either a .yml (for loading during runtime) "
             "or .hpp (for compiling with the main code) file ")
    parser.add_argument(
        "--hist_bin_num",
        metavar="HIST_BIN_NUM",
@@ -196,39 +226,37 @@ if __name__ == '__main__':
        print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>")
        sys.exit(1)

    use_YML = None
    if args.out.endswith(".yml"):
        use_YML = True
    elif args.out.endswith(".hpp"):
        use_YML = False
    else:
        print("Error: Only .hpp and .yml are supported as output formats")
        sys.exit(1)

    hist_bin_num = int(args.hist_bin_num)
    num_trees = int(args.num_trees)
    max_tree_depth = int(args.max_tree_depth)
    img_files = sorted(os.listdir(args.input_folder))
    (base_gt_illuminants,black_levels) = load_ground_truth(args.ground_truth)

    features = []
    gt_illuminants = []
    i=0
    sz = len(img_files)
    random.seed(1234)
    inst = cv2.xphoto.createLearningBasedWB()
    inst.setRangeMaxVal(255)
    inst.setSaturationThreshold(0.98)
    inst.setHistBinNum(hist_bin_num)
    for file in img_files:
        if (i>=img_range[0] and i<img_range[1]) or (img_range[0]==img_range[1]==0):
            cur_path = os.path.join(args.input_folder,file)
            im = cv2.imread(cur_path, -1).astype(np.float32)
            im -= black_levels[i]
            im_8bit = convert_to_8bit(im)
            cur_img_features = inst.extractSimpleFeatures(im_8bit, None)
            features.append(cur_img_features.tolist())
            gt_illuminants.append(base_gt_illuminants[i].tolist())
@@ -241,7 +269,7 @@ if __name__ == '__main__':
            im_8bit[:,:,1] *= G_coef
            im_8bit[:,:,2] *= R_coef
            im_8bit = convert_to_8bit(im)
            cur_img_features = inst.extractSimpleFeatures(im_8bit, None)
            features.append(cur_img_features.tolist())
            illum = base_gt_illuminants[i]
            illum[0] *= R_coef
@@ -255,10 +283,8 @@ if __name__ == '__main__':
    print("\nLearning the model...")
    model = learn_regression_tree_ensemble(features, gt_illuminants, num_trees, max_tree_depth)
    print("Writing the model...")
    generate_code(model,{"-r":args.range, "--hist_bin_num": args.hist_bin_num, "--num_trees": args.num_trees,
                         "--max_tree_depth": args.max_tree_depth, "--num_augmented": args.num_augmented},
                  use_YML, args.out)
    print("Done")

@@ -1,51 +0,0 @@
#include "opencv2/xphoto.hpp"
#include "opencv2/highgui.hpp"

using namespace cv;
using namespace std;

const char *keys = {"{help h usage ? | | print this message}"
                    "{i | | input image name }"
                    "{o | | output image name }"};

int main(int argc, const char **argv)
{
    CommandLineParser parser(argc, argv, keys);
    parser.about("OpenCV learning-based color balance demonstration sample");
    if (parser.has("help") || argc < 2)
    {
        parser.printMessage();
        return 0;
    }

    string inFilename = parser.get<string>("i");
    string outFilename = parser.get<string>("o");

    if (!parser.check())
    {
        parser.printErrors();
        return -1;
    }

    Mat src = imread(inFilename, 1);
    if (src.empty())
    {
        printf("Cannot read image file: %s\n", inFilename.c_str());
        return -1;
    }

    Mat res;
    xphoto::autowbLearningBased(src, res);

    if (outFilename == "")
    {
        namedWindow("after white balance", 1);
        imshow("after white balance", res);
        waitKey(0);
    }
    else
        imwrite(outFilename, res);

    return 0;
}

@@ -1,60 +0,0 @@
#include "opencv2/xphoto.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/highgui.hpp"
#include "opencv2/core/utility.hpp"
#include "opencv2/imgproc/types_c.h"

const char* keys =
{
    "{i || input image name}"
    "{o || output image name}"
};

int main( int argc, const char** argv )
{
    bool printHelp = ( argc == 1 );
    printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "--help" );
    printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "-h" );

    if ( printHelp )
    {
        printf("\nThis sample demonstrates simple color balance algorithm\n"
               "Call:\n"
               "    simple_color_blance -i=in_image_name [-o=out_image_name]\n\n");
        return 0;
    }

    cv::CommandLineParser parser(argc, argv, keys);
    if ( !parser.check() )
    {
        parser.printErrors();
        return -1;
    }

    std::string inFilename = parser.get<std::string>("i");
    std::string outFilename = parser.get<std::string>("o");

    cv::Mat src = cv::imread(inFilename, 1);
    if ( src.empty() )
    {
        printf("Cannot read image file: %s\n", inFilename.c_str());
        return -1;
    }

    cv::Mat res(src.size(), src.type());
    cv::xphoto::balanceWhite(src, res, cv::xphoto::WHITE_BALANCE_SIMPLE);

    if ( outFilename == "" )
    {
        cv::namedWindow("after white balance", 1);
        cv::imshow("after white balance", res);
        cv::waitKey(0);
    }
    else
        cv::imwrite(outFilename, res);

    return 0;
}

@@ -49,6 +49,54 @@ namespace xphoto
void calculateChannelSums(uint &sumB, uint &sumG, uint &sumR, uchar *src_data, int src_len, float thresh);
void calculateChannelSums(uint64 &sumB, uint64 &sumG, uint64 &sumR, ushort *src_data, int src_len, float thresh);

class GrayworldWBImpl : public GrayworldWB
{
  private:
    float thresh;

  public:
    GrayworldWBImpl() { thresh = 0.9f; }

    float getSaturationThreshold() const { return thresh; }
    void setSaturationThreshold(float val) { thresh = val; }

    void balanceWhite(InputArray _src, OutputArray _dst)
    {
        CV_Assert(!_src.empty());
        CV_Assert(_src.isContinuous());
        CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3);
        Mat src = _src.getMat();

        int N = src.cols * src.rows, N3 = N * 3;

        double dsumB = 0.0, dsumG = 0.0, dsumR = 0.0;
        if (src.type() == CV_8UC3)
        {
            uint sumB = 0, sumG = 0, sumR = 0;
            calculateChannelSums(sumB, sumG, sumR, src.ptr<uchar>(), N3, thresh);
            dsumB = (double)sumB;
            dsumG = (double)sumG;
            dsumR = (double)sumR;
        }
        else if (src.type() == CV_16UC3)
        {
            uint64 sumB = 0, sumG = 0, sumR = 0;
            calculateChannelSums(sumB, sumG, sumR, src.ptr<ushort>(), N3, thresh);
            dsumB = (double)sumB;
            dsumG = (double)sumG;
            dsumR = (double)sumR;
        }

        // Find inverse of averages
        double max_sum = max(dsumB, max(dsumR, dsumG));
        const double eps = 0.1;
        float dinvB = dsumB < eps ? 0.f : (float)(max_sum / dsumB),
              dinvG = dsumG < eps ? 0.f : (float)(max_sum / dsumG),
              dinvR = dsumR < eps ? 0.f : (float)(max_sum / dsumR);

        // Use the inverse of averages as channel gains:
        applyChannelGains(src, _dst, dinvB, dinvG, dinvR);
    }
};

/* Computes sums for each channel, while ignoring saturated pixels which are determined by thresh
 * (version for CV_8UC3)
 */
@@ -297,41 +345,6 @@ void applyChannelGains(InputArray _src, OutputArray _dst, float gainB, float gai
    }
}

Ptr<GrayworldWB> createGrayworldWB() { return makePtr<GrayworldWBImpl>(); }
}
}

@@ -64,56 +64,115 @@ struct hist_elem
    hist_elem(float _hist_val, Vec2f chromaticity) : hist_val(_hist_val), r(chromaticity[0]), g(chromaticity[1]) {}
};
bool operator<(const hist_elem &a, const hist_elem &b);

bool operator<(const hist_elem &a, const hist_elem &b) { return a.hist_val > b.hist_val; }

class LearningBasedWBImpl : public LearningBasedWB
{
  private:
    int range_max_val, hist_bin_num, palette_size;
    float saturation_thresh, palette_bandwidth, prediction_thresh;
    int num_trees, num_tree_nodes, tree_depth;
    uchar *feature_idx;
    float *thresh_vals, *leaf_vals;
    Mat feature_idx_Mat, thresh_vals_Mat, leaf_vals_Mat;
    Mat mask;
    int src_max_val;

    void preprocessing(Mat &src);
    void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f &brightest_chromaticity, Mat &src);
    void getColorPaletteMode(Vec2f &dst, hist_elem *palette);
    void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_palette_mode, Mat &src);
    float regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tree_thresh_vals, float *tree_leaf_vals);
    Vec2f predictIlluminant(vector<Vec2f> features);

  public:
    LearningBasedWBImpl(String path_to_model)
    {
        range_max_val = 255;
        saturation_thresh = 0.98f;
        hist_bin_num = 64;
        palette_size = 300;
        palette_bandwidth = 0.1f;
        prediction_thresh = 0.025f;
        if (path_to_model.empty())
        {
            /* use the default model */
            num_trees = _num_trees;
            num_tree_nodes = _num_tree_nodes;
            feature_idx = _feature_idx;
            thresh_vals = _thresh_vals;
            leaf_vals = _leaf_vals;
        }
        else
        {
            /* load model from file */
            FileStorage fs(path_to_model, 0);
            num_trees = fs["num_trees"];
            num_tree_nodes = fs["num_tree_nodes"];
            fs["feature_idx"] >> feature_idx_Mat;
            fs["thresh_vals"] >> thresh_vals_Mat;
            fs["leaf_vals"] >> leaf_vals_Mat;
            feature_idx = feature_idx_Mat.ptr<uchar>();
            thresh_vals = thresh_vals_Mat.ptr<float>();
            leaf_vals = leaf_vals_Mat.ptr<float>();
        }
    }

    int getRangeMaxVal() const { return range_max_val; }
    void setRangeMaxVal(int val) { range_max_val = val; }

    float getSaturationThreshold() const { return saturation_thresh; }
    void setSaturationThreshold(float val) { saturation_thresh = val; }

    int getHistBinNum() const { return hist_bin_num; }
    void setHistBinNum(int val) { hist_bin_num = val; }

    void extractSimpleFeatures(InputArray _src, OutputArray _dst)
    {
        CV_Assert(!_src.empty());
        CV_Assert(_src.isContinuous());
        CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3);
        Mat src = _src.getMat();
        vector<Vec2f> dst(num_features);
        preprocessing(src);
        getAverageAndBrightestColorChromaticity(dst[0], dst[1], src);
        getHistogramBasedFeatures(dst[2], dst[3], src);
        Mat(dst).convertTo(_dst, CV_32F);
    }

    void balanceWhite(InputArray _src, OutputArray _dst)
    {
        CV_Assert(!_src.empty());
        CV_Assert(_src.isContinuous());
        CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3);
        Mat src = _src.getMat();
        vector<Vec2f> features;
        extractSimpleFeatures(src, features);
        Vec2f illuminant = predictIlluminant(features);

        float denom = 1 - illuminant[0] - illuminant[1];
        float gainB = 1.0f;
        float gainG = denom / illuminant[1];
        float gainR = denom / illuminant[0];
        applyChannelGains(src, _dst, gainB, gainG, gainR);
    }
};

/* Computes a mask for non-saturated pixels and maximum pixel value
 * which are then used for feature computation
 */
void LearningBasedWBImpl::preprocessing(Mat &src)
{
    mask.create(src.size(), CV_8U);
    uchar *mask_ptr = mask.ptr<uchar>();
    int src_len = src.rows * src.cols;
    int thresh = (int)(saturation_thresh * range_max_val);
    int i = 0;
    int local_max;
    src_max_val = -1;

    if (src.type() == CV_8UC3)
    {
@@ -133,15 +192,15 @@ void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val,
        v_store(global_max, v_global_max);
        for (int j = 0; j < 16; j++)
        {
            if (global_max[j] > src_max_val)
                src_max_val = global_max[j];
        }
#endif
        for (; i < src_len; i++)
        {
            local_max = max(src_ptr[3 * i], max(src_ptr[3 * i + 1], src_ptr[3 * i + 2]));
            if (local_max > src_max_val)
                src_max_val = local_max;
            if (local_max < thresh)
                mask_ptr[i] = 255;
            else
@@ -166,15 +225,15 @@ void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val,
        v_store(global_max, v_global_max);
        for (int j = 0; j < 8; j++)
        {
            if (global_max[j] > src_max_val)
                src_max_val = global_max[j];
        }
#endif
        for (; i < src_len; i++)
        {
            local_max = max(src_ptr[3 * i], max(src_ptr[3 * i + 1], src_ptr[3 * i + 2]));
            if (local_max > src_max_val)
                src_max_val = local_max;
            if (local_max < thresh)
                mask_ptr[i] = 255;
            else
@@ -183,8 +242,8 @@ void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val,
    }
}

void LearningBasedWBImpl::getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity,
                                                                  Vec2f &brightest_chromaticity, Mat &src)
{
    int i = 0;
    int src_len = src.rows * src.cols;
@@ -376,15 +435,42 @@ void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f
    }
}

/* Returns the most high-density point (i.e. mode) of the color palette.
 * Uses a simplistic kernel density estimator with an Epanechnikov kernel and
 * fixed bandwidth.
 */
void LearningBasedWBImpl::getColorPaletteMode(Vec2f &dst, hist_elem *palette)
{
    float max_density = -1.0f;
    float denom = palette_bandwidth * palette_bandwidth;
    for (int i = 0; i < palette_size; i++)
    {
        float cur_density = 0.0f;
        float cur_dist_sq;

        for (int j = 0; j < palette_size; j++)
        {
            cur_dist_sq = (palette[i].r - palette[j].r) * (palette[i].r - palette[j].r) +
                          (palette[i].g - palette[j].g) * (palette[i].g - palette[j].g);
            cur_density += max((1.0f - (cur_dist_sq / denom)), 0.0f);
        }

        if (cur_density > max_density)
        {
            max_density = cur_density;
            dst[0] = palette[i].r;
            dst[1] = palette[i].g;
        }
    }
}
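Written out, each palette entry is scored by a truncated-quadratic (Epanechnikov-style) kernel sum over the whole palette, with bandwidth \f$h\f$ = palette_bandwidth, and the highest-scoring entry is returned as the mode:

\f[ \textrm{density}(i) = \sum_{j} \max\left(1 - \frac{(r_i - r_j)^2 + (g_i - g_j)^2}{h^2},\; 0\right) \f]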
void LearningBasedWBImpl::getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_palette_mode,
                                                    Mat &src)
{
    MatND hist;
    int channels[] = {0, 1, 2};
    int histSize[] = {hist_bin_num, hist_bin_num, hist_bin_num};
    float range[] = {0, (float)max(hist_bin_num, src_max_val)};
    const float *ranges[] = {range, range, range};
    calcHist(&src, 1, channels, mask, hist, 3, histSize, ranges);
@@ -406,10 +492,10 @@ void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity
    }
    getChromaticity(dominant_chromaticity, (float)dominant_R, (float)dominant_G, (float)dominant_B);

    vector<hist_elem> palette;
    palette.reserve(palette_size);
    hist_ptr = hist.ptr<float>();
    // extract top palette_size most common colors and add them to the palette:
    for (int i = 0; i < hist_bin_num; i++)
        for (int j = 0; j < hist_bin_num; j++)
            for (int k = 0; k < hist_bin_num; k++)
@@ -424,45 +510,28 @@ void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity
                getChromaticity(chromaticity, (float)k, (float)j, (float)i);
                hist_elem el(bin_count, chromaticity);

                if (palette.size() < (uint)palette_size)
                {
                    palette.push_back(el);
                    if (palette.size() == (uint)palette_size)
                        make_heap(palette.begin(), palette.end());
                }
                else if (bin_count > palette.front().hist_val)
                {
                    pop_heap(palette.begin(), palette.end());
                    palette.back() = el;
                    push_heap(palette.begin(), palette.end());
                }
                hist_ptr++;
            }
    getColorPaletteMode(chromaticity_palette_mode, (hist_elem *)(&palette[0]));
}

float LearningBasedWBImpl::regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tree_thresh_vals,
                                                 float *tree_leaf_vals)
{
    int node_idx = 0;
    for (int i = 0; i < tree_depth; i++)
    {
        if (src[tree_feature_idx[node_idx]] <= tree_thresh_vals[node_idx])
            node_idx = 2 * node_idx + 1;
@@ -472,22 +541,14 @@ inline float regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tr
    return tree_leaf_vals[node_idx - num_tree_nodes + 1];
}
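The traversal relies on the implicit binary-heap layout, where the children of node \f$n\f$ sit at \f$2n+1\f$ and \f$2n+2\f$. A quick sanity check for the default model (tree_depth = 4, num_tree_nodes = 16): after four steps node_idx lands in \f$[15, 30]\f$, and subtracting \f$\texttt{num_tree_nodes} - 1 = 15\f$ yields a leaf index in \f$[0, 15]\f$, exactly the num_tree_nodes leaf slots stored per tree.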
Vec2f LearningBasedWBImpl::predictIlluminant(vector<Vec2f> features)
{
    int feature_model_size = 2 * (num_tree_nodes - 1);
    int local_model_size = num_features * feature_model_size;
    int feature_model_size_leaf = 2 * num_tree_nodes;
    int local_model_size_leaf = num_features * feature_model_size_leaf;
    tree_depth = (int)round(log(num_tree_nodes) / log(2));

    vector<float> consensus_r, consensus_g;
    vector<float> all_r, all_g;
    for (int i = 0; i < num_trees; i++)
@@ -538,12 +599,13 @@ void autowbLearningBased(InputArray _src, OutputArray _dst, int range_max_val, f
        nth_element(consensus_g.begin(), consensus_g.begin() + consensus_g.size() / 2, consensus_g.end());
        illuminant_g = consensus_g[consensus_g.size() / 2];
    }
    return Vec2f(illuminant_r, illuminant_g);
}

Ptr<LearningBasedWB> createLearningBasedWB(const String& path_to_model)
{
    Ptr<LearningBasedWB> inst = makePtr<LearningBasedWBImpl>(path_to_model);
    return inst;
}
}
}

@@ -2,10 +2,10 @@
 * using the following parameters:
 *   --num_trees 20 --hist_bin_num 64 --max_tree_depth 4 --num_augmented 2 -r 0,0
 */
const int num_features = 4;
const int _num_trees = 20;
const int _num_tree_nodes = 16;
unsigned char _feature_idx[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
    0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1,
    0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1,
    1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
@@ -68,7 +68,7 @@ unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] =
    1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1,
    1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
float _thresh_vals[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
    .193f, .098f, .455f, .040f, .145f, .316f, .571f, .016f, .058f, .137f, .174f, .276f, .356f, .515f, .730f, .606f, .324f,
    .794f, .230f, .440f, .683f, .878f, .134f, .282f, .406f, .532f, .036f, .747f, .830f, .931f, .196f, .145f, .363f, .047f,
    .351f, .279f, .519f, .013f, .887f, .191f, .193f, .361f, .316f, .576f, .445f, .524f, .368f, .752f, .271f, .477f, .636f,
@@ -211,7 +211,7 @@ float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
    .550f, .000f, .195f, .377f, .500f, .984f, .000f, .479f, .183f, .704f, .082f, .310f, .567f, .875f, .043f, .141f, .271f,
    .372f, .511f, .630f, .762f, .896f, .325f, .164f, .602f, .086f, .230f, .414f, .761f, .040f, .131f, .197f, .283f, .352f,
    .516f, .685f, .855f};
float _leaf_vals[_num_trees * num_features * 2 * _num_tree_nodes] = {
    .011f, .029f, .047f, .064f, .075f, .102f, .141f, .172f, .212f, .259f, .308f, .364f, .443f, .497f, .592f, .767f, .069f,
    .165f, .241f, .278f, .357f, .412f, .463f, .540f, .562f, .623f, .676f, .734f, .797f, .838f, .894f, .944f, .014f, .040f,
    .061f, .033f, .040f, .160f, .181f, .101f, .123f, .047f, .195f, .282f, .374f, .775f, .248f, .068f, .064f, .155f, .177f,

@@ -37,50 +37,39 @@
//
//M*/

#include <algorithm>
#include <iostream>
#include <iterator>
#include <vector>

#include "opencv2/core.hpp"
#include "opencv2/imgproc.hpp"
#include "opencv2/xphoto.hpp"

namespace cv
{
namespace xphoto
{

template <typename T>
void balanceWhiteSimple(std::vector<Mat_<T> > &src, Mat &dst, const float inputMin, const float inputMax,
                        const float outputMin, const float outputMax, const float p)
{
    /********************* Simple white balance *********************/
    const float s1 = p; // low quantile
    const float s2 = p; // high quantile

    int depth = 2; // depth of histogram tree
    if (src[0].depth() != CV_8U)
        ++depth;
    int bins = 16; // number of bins at each histogram level

    int nElements = int(pow((float)bins, (float)depth));
    // number of elements in histogram tree

    for (size_t i = 0; i < src.size(); ++i)
    {
        std::vector<int> hist(nElements, 0);

        typename Mat_<T>::iterator beginIt = src[i].begin();
        typename Mat_<T>::iterator endIt = src[i].end();
@@ -97,19 +86,19 @@ namespace xphoto
             for (int j = 0; j < depth; ++j)
             {
-                int currentBin = int( (val - minValue + 1e-4f) / interval );
+                int currentBin = int((val - minValue + 1e-4f) / interval);
                 ++hist[pos + currentBin];
-                pos = (pos + currentBin)*bins;
+                pos = (pos + currentBin) * bins;
 
-                minValue = minValue + currentBin*interval;
+                minValue = minValue + currentBin * interval;
                 maxValue = minValue + interval;
                 interval /= bins;
             }
         }
 
-        int total = int( src[i].total() );
+        int total = int(src[i].total());
         int p1 = 0, p2 = bins - 1;
         int n1 = 0, n2 = total;
@@ -134,78 +123,90 @@ namespace xphoto
             n2 -= hist[p2--];
             maxValue -= interval;
         }
-        p2 = p2*bins - 1;
+        p2 = p2 * bins - 1;
         interval /= bins;
     }
 
-        src[i] = (outputMax - outputMin) * (src[i] - minValue)
-                 / (maxValue - minValue) + outputMin;
+        src[i] = (outputMax - outputMin) * (src[i] - minValue) / (maxValue - minValue) + outputMin;
     }
     /****************************************************************/
-        break;
-    }
-    default:
-        CV_Error_( CV_StsNotImplemented,
-            ("Unsupported algorithm type (=%d)", algorithmType) );
-    }
 
-    dst.create(/**/ src[0].size(), CV_MAKETYPE( src[0].depth(), int( src.size() ) ) /**/);
+    dst.create(/**/ src[0].size(), CV_MAKETYPE(src[0].depth(), int(src.size())) /**/);
     cv::merge(src, dst);
 }
+
+class SimpleWBImpl : public SimpleWB
+{
+  private:
+    float inputMin, inputMax, outputMin, outputMax, p;
+
+  public:
+    SimpleWBImpl()
+    {
+        inputMin = 0.0f;
+        inputMax = 255.0f;
+        outputMin = 0.0f;
+        outputMax = 255.0f;
+        p = 2.0f;
+    }
+
+    float getInputMin() const { return inputMin; }
+    void setInputMin(float val) { inputMin = val; }
+
+    float getInputMax() const { return inputMax; }
+    void setInputMax(float val) { inputMax = val; }
+
+    float getOutputMin() const { return outputMin; }
+    void setOutputMin(float val) { outputMin = val; }
+
+    float getOutputMax() const { return outputMax; }
+    void setOutputMax(float val) { outputMax = val; }
+
+    float getP() const { return p; }
+    void setP(float val) { p = val; }
 
-/*!
- * Wrappers over different white balance algorithm
- *
- * \param src : source image (RGB)
- * \param dst : destination image
- *
- * \param inputMin : minimum input value
- * \param inputMax : maximum input value
- * \param outputMin : minimum output value
- * \param outputMax : maximum output value
- *
- * \param algorithmType : type of the algorithm to use
- */
-void balanceWhite(const Mat &src, Mat &dst, const int algorithmType,
-                  const float inputMin, const float inputMax,
-                  const float outputMin, const float outputMax)
-{
-    switch ( src.depth() )
+    void balanceWhite(InputArray _src, OutputArray _dst)
     {
+        CV_Assert(!_src.empty());
+        CV_Assert(_src.depth() == CV_8U || _src.depth() == CV_16S || _src.depth() == CV_32S || _src.depth() == CV_32F);
+        Mat src = _src.getMat();
+        Mat &dst = _dst.getMatRef();
+        switch (src.depth())
+        {
         case CV_8U:
         {
-            std::vector < Mat_<uchar> > mv;
+            std::vector<Mat_<uchar> > mv;
             split(src, mv);
-            balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
+            balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
             break;
         }
         case CV_16S:
         {
-            std::vector < Mat_<short> > mv;
+            std::vector<Mat_<short> > mv;
             split(src, mv);
-            balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
+            balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
             break;
         }
         case CV_32S:
         {
-            std::vector < Mat_<int> > mv;
+            std::vector<Mat_<int> > mv;
             split(src, mv);
-            balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
+            balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
             break;
         }
         case CV_32F:
         {
-            std::vector < Mat_<float> > mv;
+            std::vector<Mat_<float> > mv;
             split(src, mv);
-            balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
+            balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
             break;
         }
-        default:
-            CV_Error_( CV_StsNotImplemented,
-                ("Unsupported source image format (=%d)", src.type()) );
-            break;
+        }
     }
-    }
-}
+};
+
+Ptr<SimpleWB> createSimpleWB() { return makePtr<SimpleWBImpl>(); }
 }
 }
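For users migrating from the removed free-function interface, a minimal end-to-end sketch of the new class-based API follows. The input path and the explicit setP(2.0f) call are illustrative assumptions, not part of this commit (2.0f is simply the implementation default shown above):

#include <opencv2/core.hpp>
#include <opencv2/imgcodecs.hpp>
#include <opencv2/xphoto.hpp>

int main()
{
    // Hypothetical input file; any CV_8U/CV_16S/CV_32S/CV_32F image passes the asserts above.
    cv::Mat src = cv::imread("input.jpg");
    if (src.empty())
        return 1;

    // Replacement for the old balanceWhite(src, dst, WHITE_BALANCE_SIMPLE) call:
    cv::Ptr<cv::xphoto::SimpleWB> wb = cv::xphoto::createSimpleWB();
    wb->setP(2.0f); // ignore the top and bottom 2% of pixel values when stretching each channel

    cv::Mat dst;
    wb->balanceWhite(src, dst);
    cv::imwrite("balanced.jpg", dst);
    return 0;
}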

@@ -7,6 +7,7 @@ namespace cvtest
     cv::String dir = cvtest::TS::ptr()->get_data_path() + "cv/xphoto/simple_white_balance/";
     int nTests = 12;
     float threshold = 0.005f;
+    cv::Ptr<cv::xphoto::WhiteBalancer> wb = cv::xphoto::createSimpleWB();
 
     for (int i = 0; i < nTests; ++i)
     {
@@ -18,7 +19,7 @@ namespace cvtest
         cv::Mat previousResult = cv::imread( previousResultName, 1 );
 
         cv::Mat currentResult;
-        cv::xphoto::balanceWhite(src, currentResult, cv::xphoto::WHITE_BALANCE_SIMPLE);
+        wb->balanceWhite(src, currentResult);
 
         cv::Mat sqrError = ( currentResult - previousResult )
             .mul( currentResult - previousResult );

@@ -69,6 +69,8 @@ namespace cvtest {
     const int nTests = 14;
     const float wb_thresh = 0.5f;
     const float acc_thresh = 2.f;
+    Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
+    wb->setSaturationThreshold(wb_thresh);
 
     for ( int i = 0; i < nTests; ++i )
     {
@@ -80,13 +82,13 @@ namespace cvtest {
         ref_autowbGrayworld(src, referenceResult, wb_thresh);
 
         Mat currentResult;
-        xphoto::autowbGrayworld(src, currentResult, wb_thresh);
+        wb->balanceWhite(src, currentResult);
         ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
 
         // test the 16-bit depth:
         Mat currentResult_16U, src_16U;
         src.convertTo(src_16U, CV_16UC3, 256.0);
-        xphoto::autowbGrayworld(src_16U, currentResult_16U, wb_thresh);
+        wb->balanceWhite(src_16U, currentResult_16U);
         currentResult_16U.convertTo(currentResult, CV_8UC3, 1/256.0);
         ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
     }
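The same migration applies outside the tests. A minimal sketch of the replacement for the removed autowbGrayworld function, assuming src is an 8UC3 or 16UC3 BGR image and reusing the 0.5f threshold from the test above:

cv::Ptr<cv::xphoto::GrayworldWB> wb = cv::xphoto::createGrayworldWB();
wb->setSaturationThreshold(0.5f); // pixels closer to full saturation than this are excluded from the gray-world statistics
cv::Mat dst;
wb->balanceWhite(src, dst);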

@@ -18,7 +18,11 @@ TEST(xphoto_simplefeatures, regression)
     Vec2f ref2(200.0f / (240 + 220 + 200), 220.0f / (240 + 220 + 200));
 
     vector<Vec2f> dst_features;
-    xphoto::extractSimpleFeatures(test_im, dst_features, 255, 0.98f, 64);
+    Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
+    wb->setRangeMaxVal(255);
+    wb->setSaturationThreshold(0.98f);
+    wb->setHistBinNum(64);
+    wb->extractSimpleFeatures(test_im, dst_features);
     ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
     ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
     ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
@@ -26,7 +30,10 @@ TEST(xphoto_simplefeatures, regression)
 
     // check 16 bit depth:
     test_im.convertTo(test_im, CV_16U, 256.0);
-    xphoto::extractSimpleFeatures(test_im, dst_features, 65535, 0.98f, 64);
+    wb->setRangeMaxVal(65535);
+    wb->setSaturationThreshold(0.98f);
+    wb->setHistBinNum(128);
+    wb->extractSimpleFeatures(test_im, dst_features);
     ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
     ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
     ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
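Note that createLearningBasedWB() without a model path, as used in the test above, presumably falls back to the built-in default model (the coefficient tables earlier in this commit). A hedged usage sketch for full inference, assuming src is an 8-bit or 16-bit BGR image:

cv::Ptr<cv::xphoto::LearningBasedWB> wb = cv::xphoto::createLearningBasedWB();
wb->setSaturationThreshold(0.98f); // discard pixels this close to saturation when estimating the illuminant
cv::Mat dst;
wb->balanceWhite(src, dst);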

@@ -0,0 +1,42 @@
Training the learning-based white balance algorithm {#tutorial_xphoto_training_white_balance}
===================================================
Introduction
------------
Many traditional white balance algorithms are statistics-based, i.e. they rely on certain assumptions that should hold in properly white-balanced images,
such as the well-known grey-world assumption. However, better results can often be achieved by leveraging large datasets of images with ground-truth
illuminants in a learning-based framework. This tutorial demonstrates how to train a learning-based white balance algorithm and evaluate the quality of the results.
How to train a model
--------------------
-# Download a dataset for training. In this tutorial we will use the [Gehler-Shi dataset](http://www.cs.sfu.ca/~colour/data/shi_gehler/). Extract all 568 training images
into one folder. A file containing the ground-truth illuminant values (real_illum_568..mat) is downloaded separately.
-# We will be using a [Python script](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/learn_color_balance.py) for training.
Call it with the following parameters:
@code
python learn_color_balance.py -i <path to the folder with training images> -g <path to real_illum_568..mat> -r 0,378 --num_trees 30 --max_tree_depth 6 --num_augmented 0
@endcode
This should start training a model on the first 378 images (2/3 of the whole dataset). We set the size of the model to 30 regression tree pairs per feature and limit
the tree depth to no more than 6. By default the resulting model is saved to color_balance_model.yml.
-# Use the trained model by passing its path when constructing an instance of LearningBasedWB:
@code{.cpp}
Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB(modelFilename);
@endcode
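Applying the loaded model then goes through the common WhiteBalancer interface (test.jpg below is a placeholder path):
@code{.cpp}
Mat src = imread("test.jpg"), dst;
wb->balanceWhite(src, dst);
@endcode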
How to evaluate a model
-----------------------
-# We will use a [benchmarking script](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/color_balance_benchmark.py) to compare
the model that we've trained with the classic grey-world algorithm on the remaining 1/3 of the dataset. Call the script with the following parameters:
@code
python color_balance_benchmark.py -a grayworld,learning_based:color_balance_model.yml -m <full path to folder containing the model> -i <path to the folder with training images> -g <path to real_illum_568..mat> -r 379,567 -d "img"
@endcode
-# The objective evaluation results are stored in white_balance_eval_result.html and the resulting white-balanced images are stored in the img folder for a qualitative
comparison of the algorithms. Algorithms are compared in terms of the angular error between the estimated and the ground-truth illuminants.
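For reference, the angular error used here is the standard metric from the color constancy literature: for an estimated illuminant \f$e_{est}\f$ and a ground-truth illuminant \f$e_{gt}\f$ it is the angle between the two RGB vectors,
\f[ err = \cos^{-1}\left(\frac{e_{est} \cdot e_{gt}}{\|e_{est}\| \, \|e_{gt}\|}\right) \f]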