diff --git a/modules/xphoto/include/opencv2/xphoto/white_balance.hpp b/modules/xphoto/include/opencv2/xphoto/white_balance.hpp
index 124e55560..1767f1f42 100644
--- a/modules/xphoto/include/opencv2/xphoto/white_balance.hpp
+++ b/modules/xphoto/include/opencv2/xphoto/white_balance.hpp
@@ -58,129 +58,172 @@ namespace xphoto
 //! @addtogroup xphoto
 //! @{
 
-    //! various white balance algorithms
-    enum WhitebalanceTypes
-    {
-        /** perform smart histogram adjustments (ignoring 4% pixels with minimal and maximal
-        values) for each channel */
-        WHITE_BALANCE_SIMPLE = 0,
-        WHITE_BALANCE_GRAYWORLD = 1
-    };
-
-    /** @brief The function implements different algorithm of automatic white balance,
-
-    i.e. it tries to map image's white color to perceptual white (this can be violated due to
-    specific illumination or camera settings).
-
-    @param src
-    @param dst
-    @param algorithmType see xphoto::WhitebalanceTypes
-    @param inputMin minimum value in the input image
-    @param inputMax maximum value in the input image
-    @param outputMin minimum value in the output image
-    @param outputMax maximum value in the output image
-    @sa cvtColor, equalizeHist
-    */
-    CV_EXPORTS_W void balanceWhite(const Mat &src, Mat &dst, const int algorithmType,
-                                   const float inputMin = 0.0f, const float inputMax = 255.0f,
-                                   const float outputMin = 0.0f, const float outputMax = 255.0f);
-
-    /** @brief Implements a simple grayworld white balance algorithm.
-
-    The function autowbGrayworld scales the values of pixels based on a
-    gray-world assumption which states that the average of all channels
-    should result in a gray image.
-
-    This function adds a modification which thresholds pixels based on their
-    saturation value and only uses pixels below the provided threshold in
-    finding average pixel values.
-
-    Saturation is calculated using the following for a 3-channel RGB image per
-    pixel I and is in the range [0, 1]:
-
-    \f[ \texttt{Saturation} [I] = \frac{\textrm{max}(R,G,B) - \textrm{min}(R,G,B)
-    }{\textrm{max}(R,G,B)} \f]
-
-    A threshold of 1 means that all pixels are used to white-balance, while a
-    threshold of 0 means no pixels are used. Lower thresholds are useful in
-    white-balancing saturated images.
-
-    Currently only works on images of type @ref CV_8UC3 and @ref CV_16UC3.
-
-    @param src Input array.
-    @param dst Output array of the same size and type as src.
-    @param thresh Maximum saturation for a pixel to be included in the
-        gray-world assumption.
-
-    @sa balanceWhite
-    */
-    CV_EXPORTS_W void autowbGrayworld(InputArray src, OutputArray dst,
-        float thresh = 0.5f);
-
-    /** @brief Implements a more sophisticated learning-based automatic color balance algorithm.
-
-    As autowbGrayworld, this function works by applying different gains to the input
-    image channels, but their computation is a bit more involved compared to the
-    simple grayworld assumption. More details about the algorithm can be found in
-    @cite Cheng2015 .
-
-    To mask out saturated pixels this function uses only pixels that satisfy the
-    following condition:
-
-    \f[ \frac{\textrm{max}(R,G,B)}{\texttt{range_max_val}} < \texttt{saturation_thresh} \f]
-
-    Currently supports images of type @ref CV_8UC3 and @ref CV_16UC3.
-
-    @param src Input three-channel image in the BGR color space.
-    @param dst Output image of the same size and type as src.
-    @param range_max_val Maximum possible value of the input image (e.g. 255 for 8 bit images, 4095 for 12 bit images)
-    @param saturation_thresh Threshold that is used to determine saturated pixels
-    @param hist_bin_num Defines the size of one dimension of a three-dimensional RGB histogram that is used internally by
-    the algorithm. It often makes sense to increase the number of bins for images with higher bit depth (e.g. 256 bins
-    for a 12 bit image)
+/** @brief The base class for auto white balance algorithms.
+ */
+class CV_EXPORTS_W WhiteBalancer : public Algorithm
+{
+  public:
+    /** @brief Applies white balancing to the input image
 
-    @sa autowbGrayworld
+    @param src Input image
+    @param dst White balancing result
+    @sa cvtColor, equalizeHist
     */
-    CV_EXPORTS_W void autowbLearningBased(InputArray src, OutputArray dst, int range_max_val = 255,
-                                          float saturation_thresh = 0.98f, int hist_bin_num = 64);
-
-    /** @brief Implements the feature extraction part of the learning-based color balance algorithm.
+    CV_WRAP virtual void balanceWhite(InputArray src, OutputArray dst) = 0;
+};
+
+/** @brief A simple white balance algorithm that works by independently stretching
+    each of the input image channels to the specified range. For increased robustness
+    it ignores the top and bottom \f$p\%\f$ of pixel values.
+ */
+class CV_EXPORTS_W SimpleWB : public WhiteBalancer
+{
+  public:
+    /** @brief Input image range minimum value
+    @see setInputMin */
+    CV_WRAP virtual float getInputMin() const = 0;
+    /** @copybrief getInputMin @see getInputMin */
+    CV_WRAP virtual void setInputMin(float val) = 0;
+
+    /** @brief Input image range maximum value
+    @see setInputMax */
+    CV_WRAP virtual float getInputMax() const = 0;
+    /** @copybrief getInputMax @see getInputMax */
+    CV_WRAP virtual void setInputMax(float val) = 0;
+
+    /** @brief Output image range minimum value
+    @see setOutputMin */
+    CV_WRAP virtual float getOutputMin() const = 0;
+    /** @copybrief getOutputMin @see getOutputMin */
+    CV_WRAP virtual void setOutputMin(float val) = 0;
+
+    /** @brief Output image range maximum value
+    @see setOutputMax */
+    CV_WRAP virtual float getOutputMax() const = 0;
+    /** @copybrief getOutputMax @see getOutputMax */
+    CV_WRAP virtual void setOutputMax(float val) = 0;
+
+    /** @brief Percent of top/bottom values to ignore
+    @see setP */
+    CV_WRAP virtual float getP() const = 0;
+    /** @copybrief getP @see getP */
+    CV_WRAP virtual void setP(float val) = 0;
+};
+
+/** @brief Creates an instance of SimpleWB
+ */
+CV_EXPORTS_W Ptr<SimpleWB> createSimpleWB();
+
+/** @brief Gray-world white balance algorithm
+
+This algorithm scales the values of pixels based on a
+gray-world assumption which states that the average of all channels
+should result in a gray image.
+
+It adds a modification which thresholds pixels based on their
+saturation value and only uses pixels below the provided threshold in
+finding average pixel values.
+
+Saturation is calculated using the following formula for a 3-channel RGB image per
+pixel I and is in the range [0, 1]:
+
+\f[ \texttt{Saturation} [I] = \frac{\textrm{max}(R,G,B) - \textrm{min}(R,G,B)
+}{\textrm{max}(R,G,B)} \f]
+
+A threshold of 1 means that all pixels are used to white-balance, while a
+threshold of 0 means no pixels are used. Lower thresholds are useful in
+white-balancing saturated images.
+
+Currently supports images of type @ref CV_8UC3 and @ref CV_16UC3.
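+
+A minimal usage sketch (the input image is assumed to be a BGR image loaded elsewhere,
+e.g. with imread; the file name is a placeholder):
+@code{.cpp}
+Mat src = imread("photo.jpg"), dst;
+Ptr<GrayworldWB> wb = xphoto::createGrayworldWB();
+wb->setSaturationThreshold(0.5f);
+wb->balanceWhite(src, dst);
+@endcode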
+ */
+class CV_EXPORTS_W GrayworldWB : public WhiteBalancer
+{
+  public:
+    /** @brief Maximum saturation for a pixel to be included in the
+        gray-world assumption
+    @see setSaturationThreshold */
+    CV_WRAP virtual float getSaturationThreshold() const = 0;
+    /** @copybrief getSaturationThreshold @see getSaturationThreshold */
+    CV_WRAP virtual void setSaturationThreshold(float val) = 0;
+};
+
+/** @brief Creates an instance of GrayworldWB
+ */
+CV_EXPORTS_W Ptr<GrayworldWB> createGrayworldWB();
+
+/** @brief More sophisticated learning-based automatic white balance algorithm.
+
+As @ref GrayworldWB, this algorithm works by applying different gains to the input
+image channels, but their computation is a bit more involved compared to the
+simple gray-world assumption. More details about the algorithm can be found in
+@cite Cheng2015 .
+
+To mask out saturated pixels this function uses only pixels that satisfy the
+following condition:
+
+\f[ \frac{\textrm{max}(R,G,B)}{\texttt{range_max_val}} < \texttt{saturation_thresh} \f]
+
+Currently supports images of type @ref CV_8UC3 and @ref CV_16UC3.
+ */
+class CV_EXPORTS_W LearningBasedWB : public WhiteBalancer
+{
+  public:
+    /** @brief Implements the feature extraction part of the algorithm.
 
     In accordance with @cite Cheng2015 , computes the following features for the input image:
     1. Chromaticity of an average (R,G,B) tuple
     2. Chromaticity of the brightest (R,G,B) tuple (while ignoring saturated pixels)
     3. Chromaticity of the dominant (R,G,B) tuple (the one that has the highest value in the RGB histogram)
-    4. Mode of the chromaticity pallete, that is constructed by taking 300 most common colors according to
+    4. Mode of the chromaticity palette, which is constructed by taking the 300 most common colors according to
     the RGB histogram and projecting them on the chromaticity plane. Mode is the highest-density point
-    of the pallete, which is computed by a straightforward fixed-bandwidth kernel density estimator with
+    of the palette, which is computed by a straightforward fixed-bandwidth kernel density estimator with
     an Epanechnikov kernel function.
 
-    @param src Input three-channel image in the BGR color space.
+    @param src Input three-channel image (BGR color space is assumed).
     @param dst An array of four (r,g) chromaticity tuples corresponding to the features listed above.
-    @param range_max_val Maximum possible value of the input image (e.g. 255 for 8 bit images, 4095 for 12 bit images)
-    @param saturation_thresh Threshold that is used to determine saturated pixels
-    @param hist_bin_num Defines the size of one dimension of a three-dimensional RGB histogram that is used internally by
-    the algorithm. It often makes sense to increase the number of bins for images with higher bit depth (e.g. 256 bins
-    for a 12 bit image)
-
-    @sa autowbLearningBased
     */
-    CV_EXPORTS_W void extractSimpleFeatures(InputArray src, OutputArray dst, int range_max_val = 255,
-                                            float saturation_thresh = 0.98f, int hist_bin_num = 64);
-
-    /** @brief Implements an efficient fixed-point approximation for applying channel gains.
-
-    @param src Input three-channel image in the BGR color space (either CV_8UC3 or CV_16UC3)
-    @param dst Output image of the same size and type as src.
-    @param gainB gain for the B channel
-    @param gainG gain for the G channel
-    @param gainR gain for the R channel
-
-    @sa autowbGrayworld, autowbLearningBased
-    */
-    CV_EXPORTS_W void applyChannelGains(InputArray src, OutputArray dst, float gainB, float gainG, float gainR);
-    //! @}
-
+    CV_WRAP virtual void extractSimpleFeatures(InputArray src, OutputArray dst) = 0;
+
+    /** @brief Maximum possible value of the input image (e.g. 255 for 8 bit images,
+        4095 for 12 bit images)
+    @see setRangeMaxVal */
+    CV_WRAP virtual int getRangeMaxVal() const = 0;
+    /** @copybrief getRangeMaxVal @see getRangeMaxVal */
+    CV_WRAP virtual void setRangeMaxVal(int val) = 0;
+
+    /** @brief Threshold that is used to determine saturated pixels, i.e. pixels where at least one of the
+        channels exceeds \f$\texttt{saturation_threshold}\times\texttt{range_max_val}\f$ are ignored.
+    @see setSaturationThreshold */
+    CV_WRAP virtual float getSaturationThreshold() const = 0;
+    /** @copybrief getSaturationThreshold @see getSaturationThreshold */
+    CV_WRAP virtual void setSaturationThreshold(float val) = 0;
+
+    /** @brief Defines the size of one dimension of a three-dimensional RGB histogram that is used internally
+        by the algorithm. It often makes sense to increase the number of bins for images with higher bit depth
+        (e.g. 256 bins for a 12 bit image).
+    @see setHistBinNum */
+    CV_WRAP virtual int getHistBinNum() const = 0;
+    /** @copybrief getHistBinNum @see getHistBinNum */
+    CV_WRAP virtual void setHistBinNum(int val) = 0;
+};
+
+/** @brief Creates an instance of LearningBasedWB
+
+@param path_to_model Path to a .yml file with the model. If not specified, the default model is used
+ */
+CV_EXPORTS_W Ptr<LearningBasedWB> createLearningBasedWB(const String& path_to_model = String());
+
+/** @brief Implements an efficient fixed-point approximation for applying channel gains, which is
+    the last step of multiple white balance algorithms.
+
+@param src Input three-channel image in the BGR color space (either CV_8UC3 or CV_16UC3)
+@param dst Output image of the same size and type as src.
+@param gainB gain for the B channel
+@param gainG gain for the G channel
+@param gainR gain for the R channel
+*/
+CV_EXPORTS_W void applyChannelGains(InputArray src, OutputArray dst, float gainB, float gainG, float gainR);
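+
+/* A minimal end-to-end sketch (assuming a loaded BGR image `src`; the gain values
+ * here are hypothetical, chosen only for illustration):
+ *
+ *     Mat dst;
+ *     xphoto::applyChannelGains(src, dst, 1.0f, 0.9f, 1.2f);
+ *
+ * In the learning-based pipeline the gains are derived from the predicted illuminant. */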
+//! @}
 }
 }
diff --git a/modules/xphoto/perf/perf_grayworld.cpp b/modules/xphoto/perf/perf_grayworld.cpp
index 80c034c58..32eea8f19 100644
--- a/modules/xphoto/perf/perf_grayworld.cpp
+++ b/modules/xphoto/perf/perf_grayworld.cpp
@@ -21,8 +21,10 @@ PERF_TEST_P( Size_WBThresh, autowbGrayworld,
     Mat dst(size, CV_8UC3);
 
     declare.in(src, WARMUP_RNG).out(dst);
+    Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
+    wb->setSaturationThreshold(wb_thresh);
 
-    TEST_CYCLE() xphoto::autowbGrayworld(src, dst, wb_thresh);
+    TEST_CYCLE() wb->balanceWhite(src, dst);
 
     SANITY_CHECK(dst);
 }
diff --git a/modules/xphoto/perf/perf_learning_based_color_balance.cpp b/modules/xphoto/perf/perf_learning_based_color_balance.cpp
index 7570a20eb..dbffbd79d 100644
--- a/modules/xphoto/perf/perf_learning_based_color_balance.cpp
+++ b/modules/xphoto/perf/perf_learning_based_color_balance.cpp
@@ -65,8 +65,12 @@ PERF_TEST_P(learningBasedWBPerfTest, perf, Combine(SZ_ALL_HD, Values(CV_8UC3, CV
     RNG rng(1234);
     rng.fill(src_dscl, RNG::UNIFORM, 0, range_max_val);
     resize(src_dscl, src, src.size());
+    Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
+    wb->setRangeMaxVal(range_max_val);
+    wb->setSaturationThreshold(0.98f);
+    wb->setHistBinNum(hist_bin_num);
 
-    TEST_CYCLE() xphoto::autowbLearningBased(src, dst, range_max_val, 0.98f, hist_bin_num);
+    TEST_CYCLE() wb->balanceWhite(src, dst);
 
     SANITY_CHECK_NOTHING();
 }
diff --git a/modules/xphoto/samples/color_balance.cpp b/modules/xphoto/samples/color_balance.cpp
new file mode 100644
index 000000000..b33fe9d7c
--- /dev/null
+++ b/modules/xphoto/samples/color_balance.cpp
@@ -0,0 +1,68 @@
+#include "opencv2/xphoto.hpp"
+#include "opencv2/highgui.hpp"
+
+using namespace cv;
+using namespace std;
+
+const char *keys = { "{help h usage ? | | print this message}"
+                     "{i | | input image name }"
+                     "{o | | output image name }"
+                     "{a |grayworld| color balance algorithm (simple, grayworld or learning_based)}"
+                     "{m | | path to the model for the learning-based algorithm (optional) }" };
+
+int main(int argc, const char **argv)
+{
+    CommandLineParser parser(argc, argv, keys);
+    parser.about("OpenCV color balance demonstration sample");
+    if (parser.has("help") || argc < 2)
+    {
+        parser.printMessage();
+        return 0;
+    }
+
+    string inFilename = parser.get<string>("i");
+    string outFilename = parser.get<string>("o");
+    string algorithm = parser.get<string>("a");
+    string modelFilename = parser.get<string>("m");
+
+    if (!parser.check())
+    {
+        parser.printErrors();
+        return -1;
+    }
+
+    Mat src = imread(inFilename, 1);
+    if (src.empty())
+    {
+        printf("Cannot read image file: %s\n", inFilename.c_str());
+        return -1;
+    }
+
+    Mat res;
+    Ptr<xphoto::WhiteBalancer> wb;
+    if (algorithm == "simple")
+        wb = xphoto::createSimpleWB();
+    else if (algorithm == "grayworld")
+        wb = xphoto::createGrayworldWB();
+    else if (algorithm == "learning_based")
+        wb = xphoto::createLearningBasedWB(modelFilename);
+    else
+    {
+        printf("Unsupported algorithm: %s\n", algorithm.c_str());
+        return -1;
+    }
+
+    wb->balanceWhite(src, res);
+
+    if (outFilename == "")
+    {
+        namedWindow("after white balance", 1);
+        imshow("after white balance", res);
+
+        waitKey(0);
+    }
+    else
+        imwrite(outFilename, res);
+
+    return 0;
+}
diff --git a/modules/xphoto/samples/color_balance_benchmark.py b/modules/xphoto/samples/color_balance_benchmark.py
index 3cf12a30d..405b8f9af 100644
--- a/modules/xphoto/samples/color_balance_benchmark.py
+++ b/modules/xphoto/samples/color_balance_benchmark.py
@@ -5,6 +5,7 @@ import numpy as np
 import scipy.io
 import cv2
 import timeit
+from learn_color_balance import load_ground_truth
 
 
 def load_json(path):
@@ -39,15 +40,24 @@ def stretch_to_8bit(arr, clip_percentile = 2.5):
     return arr.astype(np.uint8)
 
 
-def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder):
+def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder, model_folder):
     new_im = None
     start_time = timeit.default_timer()
     if algo=="grayworld":
-        new_im = cv2.xphoto.autowbGrayworld(im, 0.95)
+        inst = cv2.xphoto.createGrayworldWB()
+        inst.setSaturationThreshold(0.95)
+        new_im = inst.balanceWhite(im)
    elif algo=="nothing":
         new_im = im
-    elif algo=="learning_based":
-        new_im = cv2.xphoto.autowbLearningBased(im, None, range_thresh, 0.98, bin_num)
+    elif algo.split(":")[0]=="learning_based":
+        model_path = ""
+        if len(algo.split(":"))>1:
+            model_path = os.path.join(model_folder, algo.split(":")[1])
+        inst = cv2.xphoto.createLearningBasedWB(model_path)
+        inst.setRangeMaxVal(range_thresh)
+        inst.setSaturationThreshold(0.98)
+        inst.setHistBinNum(bin_num)
+        new_im = inst.balanceWhite(im)
     elif algo=="GT":
         gains = gt_illuminant / min(gt_illuminant)
         g1 = float(1.0 / gains[2])
@@ -59,7 +69,7 @@ def evaluate(im, algo, gt_illuminant, i, range_thresh, bin_num, dst_folder):
     if len(dst_folder)>0:
         if not os.path.exists(dst_folder):
             os.makedirs(dst_folder)
-        im_name = ("%04d_" % i) + algo + ".jpg"
+        im_name = ("%04d_" % i) + algo.replace(":","_") + ".jpg"
         cv2.imwrite(os.path.join(dst_folder, im_name), stretch_to_8bit(new_im))
 
 #recover the illuminant from the color balancing result, assuming the standard model:
@@ -140,7 +150,9 @@ if __name__ == '__main__':
         metavar="ALGORITHMS",
         default="",
         help=("Comma-separated list of color balance algorithms to evaluate. "
-              "Currently available: GT,learning_based,grayworld,nothing."))
+              "Currently available: GT,learning_based,grayworld,nothing. "
+              "Use a colon to set a specific model for the learning-based "
+              "algorithm, e.g. learning_based:model1.yml,learning_based:model2.yml"))
     parser.add_argument(
         "-i",
         "--input_folder",
@@ -196,6 +208,12 @@ if __name__ == '__main__':
         default="0,0",
         help=("Comma-separated range of images from the dataset to evaluate on (for instance: 0,568). "
" "All available images are used by default.")) + parser.add_argument( + "-m", + "--model_folder", + metavar="MODEL_FOLDER", + default="", + help=("Path to the folder containing models for the learning-based color balance algorithm (optional)")) args, other_args = parser.parse_known_args() if not os.path.exists(args.input_folder): @@ -218,22 +236,8 @@ if __name__ == '__main__': print("Error: Please specify the -r parameter in form ,") sys.exit(1) - gt = scipy.io.loadmat(args.ground_truth) img_files = sorted(os.listdir(args.input_folder)) - - gt_illuminants = [] - black_levels = [] - if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys(): - #NUS 8-camera dataset format - gt_illuminants = gt["groundtruth_illuminants"] - black_levels = len(gt_illuminants) * [gt["darkness_level"][0][0]] - elif "real_rgb" in gt.keys(): - #Gehler-Shi dataset format - gt_illuminants = gt["real_rgb"] - black_levels = 87 * [0] + (len(gt_illuminants) - 87) * [129] - else: - print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported") - sys.exit(1) + (gt_illuminants,black_levels) = load_ground_truth(args.ground_truth) for algorithm in algorithm_list: i = 0 @@ -254,7 +258,7 @@ if __name__ == '__main__': im = stretch_to_8bit(im) (time,angular_err) = evaluate(im, algorithm, gt_illuminants[i], i, range_thresh, - 256 if range_thresh > 255 else 64, args.dst_folder) + 256 if range_thresh > 255 else 64, args.dst_folder, args.model_folder) state[algorithm][file] = {"angular_error": angular_err, "time": time} sys.stdout.write("Algorithm: %-20s Done: [%3d/%3d]\r" % (algorithm, i, sz)), sys.stdout.flush() diff --git a/modules/xphoto/samples/grayworld_color_balance.cpp b/modules/xphoto/samples/grayworld_color_balance.cpp deleted file mode 100644 index caaa0a454..000000000 --- a/modules/xphoto/samples/grayworld_color_balance.cpp +++ /dev/null @@ -1,62 +0,0 @@ -#include "opencv2/xphoto.hpp" - -#include "opencv2/imgproc.hpp" -#include "opencv2/highgui.hpp" - -#include "opencv2/core/utility.hpp" - -using namespace cv; -using namespace std; - -const char* keys = -{ - "{i || input image name}" - "{o || output image name}" -}; - -int main( int argc, const char** argv ) -{ - bool printHelp = ( argc == 1 ); - printHelp = printHelp || ( argc == 2 && string(argv[1]) == "--help" ); - printHelp = printHelp || ( argc == 2 && string(argv[1]) == "-h" ); - - if ( printHelp ) - { - printf("\nThis sample demonstrates the grayworld balance algorithm\n" - "Call:\n" - " simple_color_blance -i=in_image_name [-o=out_image_name]\n\n"); - return 0; - } - - CommandLineParser parser(argc, argv, keys); - if ( !parser.check() ) - { - parser.printErrors(); - return -1; - } - - string inFilename = parser.get("i"); - string outFilename = parser.get("o"); - - Mat src = imread(inFilename, 1); - if ( src.empty() ) - { - printf("Cannot read image file: %s\n", inFilename.c_str()); - return -1; - } - - Mat res(src.size(), src.type()); - xphoto::autowbGrayworld(src, res); - - if ( outFilename == "" ) - { - namedWindow("after white balance", 1); - imshow("after white balance", res); - - waitKey(0); - } - else - imwrite(outFilename, res); - - return 0; -} diff --git a/modules/xphoto/src/learn_color_balance.py b/modules/xphoto/samples/learn_color_balance.py similarity index 75% rename from modules/xphoto/src/learn_color_balance.py rename to modules/xphoto/samples/learn_color_balance.py index 985af4557..c3b7fbc2f 100644 --- a/modules/xphoto/src/learn_color_balance.py +++ 
@@ -80,7 +80,7 @@ def get_tree_node_lists(tree, tree_depth):
     return (dst_feature_idx, dst_thresh_vals, dst_leaf_vals)
 
 
-def generate_code(model, input_params):
+def generate_code(model, input_params, use_YML, out_file):
     feature_idx = []
     thresh_vals = []
     leaf_vals = []
@@ -95,31 +95,60 @@ def generate_code(model, input_params):
         feature_idx += local_feature_idx
         thresh_vals += local_thresh_vals
         leaf_vals += local_leaf_vals
+    if use_YML:
+        fs = cv2.FileStorage(out_file, 1)
+        fs.write("num_trees", len(model))
+        fs.write("num_tree_nodes", 2**depth)
+        fs.write("feature_idx", np.array(feature_idx).astype(np.uint8))
+        fs.write("thresh_vals", np.array(thresh_vals).astype(np.float32))
+        fs.write("leaf_vals", np.array(leaf_vals).astype(np.float32))
+        fs.release()
+    else:
+        res = "/* This file was automatically generated by learn_color_balance.py script\n" +\
+              " * using the following parameters:\n"
+        for key in input_params:
+            res += " " + key + " " + input_params[key]
+        res += "\n */\n"
+        res += "const int num_features = 4;\n"
+        res += "const int _num_trees = " + str(len(model)) + ";\n"
+        res += "const int _num_tree_nodes = " + str(2**depth) + ";\n"
+
+        res += "unsigned char _feature_idx[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + str(feature_idx[0])
+        for i in range(1,len(feature_idx)):
+            res += "," + str(feature_idx[i])
+        res += "};\n"
+
+        res += "float _thresh_vals[_num_trees*num_features*2*(_num_tree_nodes-1)] = {" + ("%.3ff" % thresh_vals[0])[1:]
+        for i in range(1,len(thresh_vals)):
+            res += "," + ("%.3ff" % thresh_vals[i])[1:]
+        res += "};\n"
+
+        res += "float _leaf_vals[_num_trees*num_features*2*_num_tree_nodes] = {" + ("%.3ff" % leaf_vals[0])[1:]
+        for i in range(1,len(leaf_vals)):
+            res += "," + ("%.3ff" % leaf_vals[i])[1:]
+        res += "};\n"
+        f = open(out_file,"w")
+        f.write(res)
+        f.close()
+
+
+def load_ground_truth(gt_path):
+    gt = scipy.io.loadmat(gt_path)
+    base_gt_illuminants = []
+    black_levels = []
+    if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
+        #NUS 8-camera dataset format
+        base_gt_illuminants = gt["groundtruth_illuminants"]
+        black_levels = len(base_gt_illuminants) * [gt["darkness_level"][0][0]]
+    elif "real_rgb" in gt.keys():
+        #Gehler-Shi dataset format
+        base_gt_illuminants = gt["real_rgb"]
+        black_levels = 87 * [0] + (len(base_gt_illuminants) - 87) * [129]
+    else:
+        print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
+        sys.exit(1)
 
-    res = "/* This file was automatically generated by learn_color_balance.py script\n" +\
-          " * using the following parameters:\n"
-    for key in input_params:
-        res += " " + key + " " + input_params[key]
-    res += "\n */\n"
-    res += "const int num_trees = " + str(len(model)) + ";\n"
-    res += "const int num_features = 4;\n"
-    res += "const int num_tree_nodes = " + str(2**depth) + ";\n"
-
-    res += "unsigned char feature_idx[num_trees*num_features*2*(num_tree_nodes-1)] = {" + str(feature_idx[0])
-    for i in range(1,len(feature_idx)):
-        res += "," + str(feature_idx[i])
-    res += "};\n"
-
-    res += "float thresh_vals[num_trees*num_features*2*(num_tree_nodes-1)] = {" + ("%.3ff" % thresh_vals[0])[1:]
-    for i in range(1,len(thresh_vals)):
-        res += "," + ("%.3ff" % thresh_vals[i])[1:]
-    res += "};\n"
-
-    res += "float leaf_vals[num_trees*num_features*2*num_tree_nodes] = {" + ("%.3ff" % leaf_vals[0])[1:]
-    for i in range(1,len(leaf_vals)):
-        res += "," + ("%.3ff" % leaf_vals[i])[1:]
-    res += "};\n"
-    return res
+    return (base_gt_illuminants, black_levels)
 
 
 if __name__ == '__main__':
@@ -153,8 +182,9 @@ if __name__ == '__main__':
         "-o",
         "--out",
         metavar="OUT",
-        default="learning_based_color_balance_model.hpp",
-        help="Path to the output learnt model")
+        default="color_balance_model.yml",
+        help="Path to the output learnt model. Either a .yml (for loading during runtime) "
+             "or .hpp (for compiling with the main code) file ")
     parser.add_argument(
         "--hist_bin_num",
         metavar="HIST_BIN_NUM",
@@ -196,39 +226,37 @@ if __name__ == '__main__':
         print("Error: Please specify the -r parameter in form <first_image_index>,<last_image_index>")
         sys.exit(1)
 
+    use_YML = None
+    if args.out.endswith(".yml"):
+        use_YML = True
+    elif args.out.endswith(".hpp"):
+        use_YML = False
+    else:
+        print("Error: Only .hpp and .yml are supported as output formats")
+        sys.exit(1)
+
     hist_bin_num = int(args.hist_bin_num)
     num_trees = int(args.num_trees)
     max_tree_depth = int(args.max_tree_depth)
-
-    gt = scipy.io.loadmat(args.ground_truth)
     img_files = sorted(os.listdir(args.input_folder))
-
-    base_gt_illuminants = []
-    black_levels = []
-    if "groundtruth_illuminants" in gt.keys() and "darkness_level" in gt.keys():
-        #NUS 8-camera dataset format
-        base_gt_illuminants = gt["groundtruth_illuminants"]
-        black_levels = len(gt_illuminants) * [gt["darkness_level"][0][0]]
-    elif "real_rgb" in gt.keys():
-        #Gehler-Shi dataset format
-        base_gt_illuminants = gt["real_rgb"]
-        black_levels = 87 * [0] + (len(base_gt_illuminants) - 87) * [129]
-    else:
-        print("Error: unknown ground-truth format, only formats of Gehler-Shi and NUS 8-camera datasets are supported")
-        sys.exit(1)
+    (base_gt_illuminants,black_levels) = load_ground_truth(args.ground_truth)
 
     features = []
     gt_illuminants = []
     i=0
     sz = len(img_files)
     random.seed(1234)
+    inst = cv2.xphoto.createLearningBasedWB()
+    inst.setRangeMaxVal(255)
+    inst.setSaturationThreshold(0.98)
+    inst.setHistBinNum(hist_bin_num)
     for file in img_files:
+        if (i>=img_range[0] and i<img_range[1]):
diff --git a/modules/xphoto/samples/learning_color_balance.cpp b/modules/xphoto/samples/learning_color_balance.cpp
deleted file mode 100644
--- a/modules/xphoto/samples/learning_color_balance.cpp
+++ /dev/null
-    string inFilename = parser.get<string>("i");
-    string outFilename = parser.get<string>("o");
-
-    if (!parser.check())
-    {
-        parser.printErrors();
-        return -1;
-    }
-
-    Mat src = imread(inFilename, 1);
-    if (src.empty())
-    {
-        printf("Cannot read image file: %s\n", inFilename.c_str());
-        return -1;
-    }
-
-    Mat res;
-    xphoto::autowbLearningBased(src, res);
-
-    if (outFilename == "")
-    {
-        namedWindow("after white balance", 1);
-        imshow("after white balance", res);
-
-        waitKey(0);
-    }
-    else
-        imwrite(outFilename, res);
-
-    return 0;
-}
diff --git a/modules/xphoto/samples/simple_color_balance.cpp b/modules/xphoto/samples/simple_color_balance.cpp
deleted file mode 100644
index 4159544bd..000000000
--- a/modules/xphoto/samples/simple_color_balance.cpp
+++ /dev/null
@@ -1,60 +0,0 @@
-#include "opencv2/xphoto.hpp"
-
-#include "opencv2/imgproc.hpp"
-#include "opencv2/highgui.hpp"
-
-#include "opencv2/core/utility.hpp"
-#include "opencv2/imgproc/types_c.h"
-
-const char* keys =
-{
-    "{i || input image name}"
-    "{o || output image name}"
-};
-
-int main( int argc, const char** argv )
-{
-    bool printHelp = ( argc == 1 );
-    printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "--help" );
-    printHelp = printHelp || ( argc == 2 && std::string(argv[1]) == "-h" );
-
-    if ( printHelp )
-    {
-        printf("\nThis sample demonstrates simple color balance algorithm\n"
-               "Call:\n"
-               "    simple_color_blance -i=in_image_name [-o=out_image_name]\n\n");
-        return 0;
-    }
-
-    cv::CommandLineParser parser(argc, argv, keys);
-    if ( !parser.check() )
-    {
-        parser.printErrors();
-        return -1;
-    }
-
-    std::string inFilename = parser.get<std::string>("i");
-    std::string outFilename = parser.get<std::string>("o");
-
-    cv::Mat src = cv::imread(inFilename, 1);
-    if ( src.empty() )
-    {
-        printf("Cannot read image file: %s\n", inFilename.c_str());
-        return -1;
-    }
-
-    cv::Mat res(src.size(), src.type());
-    cv::xphoto::balanceWhite(src, res, cv::xphoto::WHITE_BALANCE_SIMPLE);
-
-    if ( outFilename == "" )
-    {
-        cv::namedWindow("after white balance", 1);
-        cv::imshow("after white balance", res);
-
-        cv::waitKey(0);
-    }
-    else
-        cv::imwrite(outFilename, res);
-
-    return 0;
-}
\ No newline at end of file
diff --git a/modules/xphoto/src/grayworld_white_balance.cpp b/modules/xphoto/src/grayworld_white_balance.cpp
index 379924fe8..9d38c2675 100644
--- a/modules/xphoto/src/grayworld_white_balance.cpp
+++ b/modules/xphoto/src/grayworld_white_balance.cpp
@@ -49,6 +49,54 @@ namespace xphoto
 void calculateChannelSums(uint &sumB, uint &sumG, uint &sumR, uchar *src_data, int src_len, float thresh);
 void calculateChannelSums(uint64 &sumB, uint64 &sumG, uint64 &sumR, ushort *src_data, int src_len, float thresh);
 
+class GrayworldWBImpl : public GrayworldWB
+{
+  private:
+    float thresh;
+
+  public:
+    GrayworldWBImpl() { thresh = 0.9f; }
+
+    float getSaturationThreshold() const { return thresh; }
+    void setSaturationThreshold(float val) { thresh = val; }
+
+    void balanceWhite(InputArray _src, OutputArray _dst)
+    {
+        CV_Assert(!_src.empty());
+        CV_Assert(_src.isContinuous());
+        CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3);
+        Mat src = _src.getMat();
+
+        int N = src.cols * src.rows, N3 = N * 3;
+
+        double dsumB = 0.0, dsumG = 0.0, dsumR = 0.0;
+        if (src.type() == CV_8UC3)
+        {
+            uint sumB = 0, sumG = 0, sumR = 0;
+            calculateChannelSums(sumB, sumG, sumR, src.ptr<uchar>(), N3, thresh);
+            dsumB = (double)sumB;
+            dsumG = (double)sumG;
+            dsumR = (double)sumR;
+        }
+        else if (src.type() == CV_16UC3)
+        {
+            uint64 sumB = 0, sumG = 0, sumR = 0;
+            calculateChannelSums(sumB, sumG, sumR, src.ptr<ushort>(), N3, thresh);
+            dsumB = (double)sumB;
+            dsumG = (double)sumG;
+            dsumR = (double)sumR;
+        }
+
+        // Find inverse of averages
+        double max_sum = max(dsumB, max(dsumR, dsumG));
+        const double eps = 0.1;
+        float dinvB = dsumB < eps ? 0.f : (float)(max_sum / dsumB),
+              dinvG = dsumG < eps ? 0.f : (float)(max_sum / dsumG),
+              dinvR = dsumR < eps ? 0.f : (float)(max_sum / dsumR);
+
+        // Use the inverse of averages as channel gains:
+        applyChannelGains(src, _dst, dinvB, dinvG, dinvR);
+    }
+};
+
 /* Computes sums for each channel, while ignoring saturated pixels which are determined by thresh
  * (version for CV_8UC3)
  */
@@ -297,41 +345,6 @@ void applyChannelGains(InputArray _src, OutputArray _dst, float gainB, float gai
     }
 }
 
-void autowbGrayworld(InputArray _src, OutputArray _dst, float thresh)
-{
-    Mat src = _src.getMat();
-    CV_Assert(!src.empty());
-    CV_Assert(src.isContinuous());
-    CV_Assert(src.type() == CV_8UC3 || src.type() == CV_16UC3);
-
-    int N = src.cols * src.rows, N3 = N * 3;
-
-    double dsumB = 0.0, dsumG = 0.0, dsumR = 0.0;
-    if (src.type() == CV_8UC3)
-    {
-        uint sumB = 0, sumG = 0, sumR = 0;
-        calculateChannelSums(sumB, sumG, sumR, src.ptr<uchar>(), N3, thresh);
-        dsumB = (double)sumB;
-        dsumG = (double)sumG;
-        dsumR = (double)sumR;
-    }
-    else if (src.type() == CV_16UC3)
-    {
-        uint64 sumB = 0, sumG = 0, sumR = 0;
-        calculateChannelSums(sumB, sumG, sumR, src.ptr<ushort>(), N3, thresh);
-        dsumB = (double)sumB;
-        dsumG = (double)sumG;
-        dsumR = (double)sumR;
-    }
-
-    // Find inverse of averages
-    double max_sum = max(dsumB, max(dsumR, dsumG));
-    const double eps = 0.1;
-    float dinvB = dsumB < eps ? 0.f : (float)(max_sum / dsumB), dinvG = dsumG < eps ? 0.f : (float)(max_sum / dsumG),
-          dinvR = dsumR < eps ? 0.f : (float)(max_sum / dsumR);
-
-    // Use the inverse of averages as channel gains:
-    applyChannelGains(src, _dst, dinvB, dinvG, dinvR);
-}
+Ptr<GrayworldWB> createGrayworldWB() { return makePtr<GrayworldWBImpl>(); }
 }
 }
diff --git a/modules/xphoto/src/learning_based_color_balance.cpp b/modules/xphoto/src/learning_based_color_balance.cpp
index 4ef5ff6f9..49482398b 100644
--- a/modules/xphoto/src/learning_based_color_balance.cpp
+++ b/modules/xphoto/src/learning_based_color_balance.cpp
@@ -64,56 +64,115 @@ struct hist_elem
     hist_elem(float _hist_val, Vec2f chromaticity) : hist_val(_hist_val), r(chromaticity[0]), g(chromaticity[1]) {}
 };
 bool operator<(const hist_elem &a, const hist_elem &b);
-void getColorPalleteMode(Vec2f &dst, hist_elem *pallete, int pallete_sz, float bandwidth);
-void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val, float saturation_thresh);
-void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f &brightest_chromaticity, Mat &src,
-                                             Mat &mask);
-void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_pallete_mode, Mat &src, Mat &mask,
-                               int hist_bin_num, int max_val);
-
 bool operator<(const hist_elem &a, const hist_elem &b) { return a.hist_val > b.hist_val; }
 
-/* Returns the most high-density point (i.e. mode) of the color pallete.
- * Uses a simplistic kernel density estimator with a Epanechnikov kernel and
- * fixed bandwidth.
- */
-void getColorPalleteMode(Vec2f &dst, hist_elem *pallete, int pallete_sz, float bandwidth)
+class LearningBasedWBImpl : public LearningBasedWB
 {
-    float max_density = -1.0f;
-    float denom = bandwidth * bandwidth;
-    for (int i = 0; i < pallete_sz; i++)
-    {
-        float cur_density = 0.0f;
-        float cur_dist_sq;
+  private:
+    int range_max_val, hist_bin_num, palette_size;
+    float saturation_thresh, palette_bandwidth, prediction_thresh;
+    int num_trees, num_tree_nodes, tree_depth;
+    uchar *feature_idx;
+    float *thresh_vals, *leaf_vals;
+    Mat feature_idx_Mat, thresh_vals_Mat, leaf_vals_Mat;
+    Mat mask;
+    int src_max_val;
+
+    void preprocessing(Mat &src);
+    void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f &brightest_chromaticity, Mat &src);
+    void getColorPaletteMode(Vec2f &dst, hist_elem *palette);
+    void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_palette_mode, Mat &src);
 
-        for (int j = 0; j < pallete_sz; j++)
+    float regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tree_thresh_vals, float *tree_leaf_vals);
+    Vec2f predictIlluminant(vector<Vec2f> features);
+
+  public:
+    LearningBasedWBImpl(String path_to_model)
+    {
+        range_max_val = 255;
+        saturation_thresh = 0.98f;
+        hist_bin_num = 64;
+        palette_size = 300;
+        palette_bandwidth = 0.1f;
+        prediction_thresh = 0.025f;
+        if (path_to_model.empty())
         {
-            cur_dist_sq = (pallete[i].r - pallete[j].r) * (pallete[i].r - pallete[j].r) +
-                          (pallete[i].g - pallete[j].g) * (pallete[i].g - pallete[j].g);
-            cur_density += max((1.0f - (cur_dist_sq / denom)), 0.0f);
+            /* use the default model */
+            num_trees = _num_trees;
+            num_tree_nodes = _num_tree_nodes;
+            feature_idx = _feature_idx;
+            thresh_vals = _thresh_vals;
+            leaf_vals = _leaf_vals;
         }
-
-        if (cur_density > max_density)
+        else
        {
-            max_density = cur_density;
-            dst[0] = pallete[i].r;
-            dst[1] = pallete[i].g;
+            /* load model from file */
+            FileStorage fs(path_to_model, 0);
+            num_trees = fs["num_trees"];
+            num_tree_nodes = fs["num_tree_nodes"];
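+            /* Note on the model layout (same for the default model and .yml files):
+             * the trees are stored as flat arrays; for every (tree pair, feature)
+             * combination there are 2*(num_tree_nodes-1) interior-node entries in
+             * feature_idx/thresh_vals and 2*num_tree_nodes leaf entries in
+             * leaf_vals (see the indexing in predictIlluminant below). */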
fs["num_tree_nodes"]; + fs["feature_idx"] >> feature_idx_Mat; + fs["thresh_vals"] >> thresh_vals_Mat; + fs["leaf_vals"] >> leaf_vals_Mat; + feature_idx = feature_idx_Mat.ptr(); + thresh_vals = thresh_vals_Mat.ptr(); + leaf_vals = leaf_vals_Mat.ptr(); } } -} + + int getRangeMaxVal() const { return range_max_val; } + void setRangeMaxVal(int val) { range_max_val = val; } + + float getSaturationThreshold() const { return saturation_thresh; } + void setSaturationThreshold(float val) { saturation_thresh = val; } + + int getHistBinNum() const { return hist_bin_num; } + void setHistBinNum(int val) { hist_bin_num = val; } + + void extractSimpleFeatures(InputArray _src, OutputArray _dst) + { + CV_Assert(!_src.empty()); + CV_Assert(_src.isContinuous()); + CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3); + Mat src = _src.getMat(); + vector dst(num_features); + + preprocessing(src); + getAverageAndBrightestColorChromaticity(dst[0], dst[1], src); + getHistogramBasedFeatures(dst[2], dst[3], src); + Mat(dst).convertTo(_dst, CV_32F); + } + + void balanceWhite(InputArray _src, OutputArray _dst) + { + CV_Assert(!_src.empty()); + CV_Assert(_src.isContinuous()); + CV_Assert(_src.type() == CV_8UC3 || _src.type() == CV_16UC3); + Mat src = _src.getMat(); + + vector features; + extractSimpleFeatures(src, features); + Vec2f illuminant = predictIlluminant(features); + + float denom = 1 - illuminant[0] - illuminant[1]; + float gainB = 1.0f; + float gainG = denom / illuminant[1]; + float gainR = denom / illuminant[0]; + applyChannelGains(src, _dst, gainB, gainG, gainR); + } +}; /* Computes a mask for non-saturated pixels and maximum pixel value * which are then used for feature computation */ -void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val, float saturation_thresh) +void LearningBasedWBImpl::preprocessing(Mat &src) { - dst_mask = Mat(src.size(), CV_8U); - uchar *mask_ptr = dst_mask.ptr(); + mask.create(src.size(), CV_8U); + uchar *mask_ptr = mask.ptr(); int src_len = src.rows * src.cols; int thresh = (int)(saturation_thresh * range_max_val); int i = 0; int local_max; - dst_max_val = -1; + src_max_val = -1; if (src.type() == CV_8UC3) { @@ -133,15 +192,15 @@ void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val, v_store(global_max, v_global_max); for (int j = 0; j < 16; j++) { - if (global_max[j] > dst_max_val) - dst_max_val = global_max[j]; + if (global_max[j] > src_max_val) + src_max_val = global_max[j]; } #endif for (; i < src_len; i++) { local_max = max(src_ptr[3 * i], max(src_ptr[3 * i + 1], src_ptr[3 * i + 2])); - if (local_max > dst_max_val) - dst_max_val = local_max; + if (local_max > src_max_val) + src_max_val = local_max; if (local_max < thresh) mask_ptr[i] = 255; else @@ -166,15 +225,15 @@ void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val, v_store(global_max, v_global_max); for (int j = 0; j < 8; j++) { - if (global_max[j] > dst_max_val) - dst_max_val = global_max[j]; + if (global_max[j] > src_max_val) + src_max_val = global_max[j]; } #endif for (; i < src_len; i++) { local_max = max(src_ptr[3 * i], max(src_ptr[3 * i + 1], src_ptr[3 * i + 2])); - if (local_max > dst_max_val) - dst_max_val = local_max; + if (local_max > src_max_val) + src_max_val = local_max; if (local_max < thresh) mask_ptr[i] = 255; else @@ -183,8 +242,8 @@ void preprocessing(Mat &dst_mask, int &dst_max_val, Mat &src, int range_max_val, } } -void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f 
&brightest_chromaticity, Mat &src, - Mat &mask) +void LearningBasedWBImpl::getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, + Vec2f &brightest_chromaticity, Mat &src) { int i = 0; int src_len = src.rows * src.cols; @@ -376,15 +435,42 @@ void getAverageAndBrightestColorChromaticity(Vec2f &average_chromaticity, Vec2f } } -void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_pallete_mode, Mat &src, Mat &mask, - int hist_bin_num, int max_val) +/* Returns the most high-density point (i.e. mode) of the color palette. + * Uses a simplistic kernel density estimator with a Epanechnikov kernel and + * fixed bandwidth. + */ +void LearningBasedWBImpl::getColorPaletteMode(Vec2f &dst, hist_elem *palette) +{ + float max_density = -1.0f; + float denom = palette_bandwidth * palette_bandwidth; + for (int i = 0; i < palette_size; i++) + { + float cur_density = 0.0f; + float cur_dist_sq; + + for (int j = 0; j < palette_size; j++) + { + cur_dist_sq = (palette[i].r - palette[j].r) * (palette[i].r - palette[j].r) + + (palette[i].g - palette[j].g) * (palette[i].g - palette[j].g); + cur_density += max((1.0f - (cur_dist_sq / denom)), 0.0f); + } + + if (cur_density > max_density) + { + max_density = cur_density; + dst[0] = palette[i].r; + dst[1] = palette[i].g; + } + } +} + +void LearningBasedWBImpl::getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity_palette_mode, + Mat &src) { - const int pallete_size = 300; - const float pallete_bandwidth = 0.1f; MatND hist; int channels[] = {0, 1, 2}; int histSize[] = {hist_bin_num, hist_bin_num, hist_bin_num}; - float range[] = {0, (float)max(hist_bin_num, max_val)}; + float range[] = {0, (float)max(hist_bin_num, src_max_val)}; const float *ranges[] = {range, range, range}; calcHist(&src, 1, channels, mask, hist, 3, histSize, ranges); @@ -406,10 +492,10 @@ void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity } getChromaticity(dominant_chromaticity, (float)dominant_R, (float)dominant_G, (float)dominant_B); - vector pallete; - pallete.reserve(pallete_size); + vector palette; + palette.reserve(palette_size); hist_ptr = hist.ptr(); - // extract top pallete_size most common colors and add them to the pallete: + // extract top palette_size most common colors and add them to the palette: for (int i = 0; i < hist_bin_num; i++) for (int j = 0; j < hist_bin_num; j++) for (int k = 0; k < hist_bin_num; k++) @@ -424,45 +510,28 @@ void getHistogramBasedFeatures(Vec2f &dominant_chromaticity, Vec2f &chromaticity getChromaticity(chromaticity, (float)k, (float)j, (float)i); hist_elem el(bin_count, chromaticity); - if (pallete.size() < pallete_size) + if (palette.size() < (uint)palette_size) { - pallete.push_back(el); - if (pallete.size() == pallete_size) - make_heap(pallete.begin(), pallete.end()); + palette.push_back(el); + if (palette.size() == (uint)palette_size) + make_heap(palette.begin(), palette.end()); } - else if (bin_count > pallete.front().hist_val) + else if (bin_count > palette.front().hist_val) { - pop_heap(pallete.begin(), pallete.end()); - pallete.back() = el; - push_heap(pallete.begin(), pallete.end()); + pop_heap(palette.begin(), palette.end()); + palette.back() = el; + push_heap(palette.begin(), palette.end()); } hist_ptr++; } - getColorPalleteMode(chromaticity_pallete_mode, (hist_elem *)(&pallete[0]), (int)pallete.size(), pallete_bandwidth); -} - -void extractSimpleFeatures(InputArray _src, OutputArray _dst, int range_max_val, float saturation_thresh, - int hist_bin_num) -{ - 
-    Mat src = _src.getMat();
-    CV_Assert(!src.empty());
-    CV_Assert(src.isContinuous());
-    CV_Assert(src.type() == CV_8UC3 || src.type() == CV_16UC3);
-    vector<Vec2f> dst(num_features);
-
-    Mat mask;
-    int max_val = 0;
-    preprocessing(mask, max_val, src, range_max_val, saturation_thresh);
-    getAverageAndBrightestColorChromaticity(dst[0], dst[1], src, mask);
-    getHistogramBasedFeatures(dst[2], dst[3], src, mask, hist_bin_num, max_val);
-    Mat(dst).convertTo(_dst, CV_32F);
+    getColorPaletteMode(chromaticity_palette_mode, (hist_elem *)(&palette[0]));
 }
 
-inline float regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tree_thresh_vals, float *tree_leaf_vals)
+float LearningBasedWBImpl::regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tree_thresh_vals,
+                                                 float *tree_leaf_vals)
 {
     int node_idx = 0;
-    int depth = (int)round(log(num_tree_nodes) / log(2));
-    for (int i = 0; i < depth; i++)
+    for (int i = 0; i < tree_depth; i++)
     {
         if (src[tree_feature_idx[node_idx]] <= tree_thresh_vals[node_idx])
             node_idx = 2 * node_idx + 1;
@@ -472,22 +541,14 @@ inline float regressionTreePredict(Vec2f src, uchar *tree_feature_idx, float *tr
     return tree_leaf_vals[node_idx - num_tree_nodes + 1];
 }
 
-void autowbLearningBased(InputArray _src, OutputArray _dst, int range_max_val, float saturation_thresh,
-                         int hist_bin_num)
+Vec2f LearningBasedWBImpl::predictIlluminant(vector<Vec2f> features)
 {
-    const float prediction_thresh = 0.025f;
-    Mat src = _src.getMat();
-    CV_Assert(!src.empty());
-    CV_Assert(src.isContinuous());
-    CV_Assert(src.type() == CV_8UC3 || src.type() == CV_16UC3);
-
-    vector<Vec2f> features;
-    extractSimpleFeatures(src, features, range_max_val, saturation_thresh, hist_bin_num);
-
     int feature_model_size = 2 * (num_tree_nodes - 1);
     int local_model_size = num_features * feature_model_size;
     int feature_model_size_leaf = 2 * num_tree_nodes;
     int local_model_size_leaf = num_features * feature_model_size_leaf;
+    tree_depth = (int)round(log(num_tree_nodes) / log(2));
+
     vector<float> consensus_r, consensus_g;
     vector<float> all_r, all_g;
     for (int i = 0; i < num_trees; i++)
@@ -538,12 +599,13 @@
         nth_element(consensus_g.begin(), consensus_g.begin() + consensus_g.size() / 2, consensus_g.end());
         illuminant_g = consensus_g[consensus_g.size() / 2];
     }
+    return Vec2f(illuminant_r, illuminant_g);
+}
 
-    float denom = 1 - illuminant_r - illuminant_g;
-    float gainB = 1.0f;
-    float gainG = denom / illuminant_g;
-    float gainR = denom / illuminant_r;
-    applyChannelGains(src, _dst, gainB, gainG, gainR);
+Ptr<LearningBasedWB> createLearningBasedWB(const String& path_to_model)
+{
+    Ptr<LearningBasedWB> inst = makePtr<LearningBasedWBImpl>(path_to_model);
+    return inst;
 }
 }
 }
diff --git a/modules/xphoto/src/learning_based_color_balance_model.hpp b/modules/xphoto/src/learning_based_color_balance_model.hpp
index a462f6261..e46d8f6dc 100644
--- a/modules/xphoto/src/learning_based_color_balance_model.hpp
+++ b/modules/xphoto/src/learning_based_color_balance_model.hpp
@@ -2,10 +2,10 @@
 * using the following parameters: --num_trees 20 --hist_bin_num 64 --max_tree_depth 4 --num_augmented 2 -r 0,0
 */
-const int num_trees = 20;
 const int num_features = 4;
-const int num_tree_nodes = 16;
-unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
+const int _num_trees = 20;
+const int _num_tree_nodes = 16;
+unsigned char _feature_idx[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
     0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 1, 1, 1, 0, 1,
    0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1,
    1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
    1, 1,
@@ -68,7 +68,7 @@ unsigned char feature_idx[num_trees * num_features * 2 * (num_tree_nodes - 1)] =
    1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1,
    1, 1, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1};
-float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
+float _thresh_vals[_num_trees * num_features * 2 * (_num_tree_nodes - 1)] = {
    .193f, .098f, .455f, .040f, .145f, .316f, .571f, .016f, .058f, .137f, .174f, .276f, .356f, .515f, .730f, .606f,
    .324f, .794f, .230f, .440f, .683f, .878f, .134f, .282f, .406f, .532f, .036f, .747f, .830f, .931f, .196f, .145f,
    .363f, .047f, .351f, .279f, .519f, .013f, .887f, .191f, .193f, .361f, .316f, .576f, .445f, .524f, .368f, .752f,
    .271f, .477f, .636f,
@@ -211,7 +211,7 @@ float thresh_vals[num_trees * num_features * 2 * (num_tree_nodes - 1)] = {
    .550f, .000f, .195f, .377f, .500f, .984f, .000f, .479f, .183f, .704f, .082f, .310f, .567f, .875f, .043f, .141f,
    .271f, .372f, .511f, .630f, .762f, .896f, .325f, .164f, .602f, .086f, .230f, .414f, .761f, .040f, .131f, .197f,
    .283f, .352f, .516f, .685f, .855f};
-float leaf_vals[num_trees * num_features * 2 * num_tree_nodes] = {
+float _leaf_vals[_num_trees * num_features * 2 * _num_tree_nodes] = {
    .011f, .029f, .047f, .064f, .075f, .102f, .141f, .172f, .212f, .259f, .308f, .364f, .443f, .497f, .592f, .767f,
    .069f, .165f, .241f, .278f, .357f, .412f, .463f, .540f, .562f, .623f, .676f, .734f, .797f, .838f, .894f, .944f,
    .014f, .040f, .061f, .033f, .040f, .160f, .181f, .101f, .123f, .047f, .195f, .282f, .374f, .775f, .248f, .068f,
    .064f, .155f, .177f,
diff --git a/modules/xphoto/src/simple_color_balance.cpp b/modules/xphoto/src/simple_color_balance.cpp
index f755e71b1..253a0d0af 100644
--- a/modules/xphoto/src/simple_color_balance.cpp
+++ b/modules/xphoto/src/simple_color_balance.cpp
@@ -37,175 +37,176 @@
 //
 //M*/
 
-#include <vector>
 #include <algorithm>
-#include <iterator>
 #include <iostream>
-
-#include "opencv2/xphoto.hpp"
-
-#include "opencv2/imgproc.hpp"
+#include <iterator>
+#include <vector>
 #include "opencv2/core.hpp"
-#include "opencv2/core/core_c.h"
-
-#include "opencv2/core/types.hpp"
-#include "opencv2/core/types_c.h"
+#include "opencv2/imgproc.hpp"
+#include "opencv2/xphoto.hpp"
 
 namespace cv
 {
 namespace xphoto
 {
 
-    template <typename T>
-    void balanceWhite(std::vector < Mat_<T> > &src, Mat &dst,
-                      const float inputMin, const float inputMax,
-                      const float outputMin, const float outputMax, const int algorithmType)
+template <typename T>
+void balanceWhiteSimple(std::vector<Mat_<T> > &src, Mat &dst, const float inputMin, const float inputMax,
+                        const float outputMin, const float outputMax, const float p)
+{
+    /********************* Simple white balance *********************/
+    const float s1 = p; // low quantile
+    const float s2 = p; // high quantile
+
+    int depth = 2; // depth of histogram tree
+    if (src[0].depth() != CV_8U)
+        ++depth;
+    int bins = 16; // number of bins at each histogram level
+
+    int nElements = int(pow((float)bins, (float)depth));
+    // number of elements in histogram tree
+
+    for (size_t i = 0; i < src.size(); ++i)
     {
+        std::vector<int> hist(nElements, 0);
+
+        typename Mat_<T>::iterator beginIt = src[i].begin();
+        typename Mat_<T>::iterator endIt = src[i].end();
+
+        for (typename Mat_<T>::iterator it = beginIt; it != endIt; ++it)
+        // histogram filling
         {
-        case WHITE_BALANCE_SIMPLE:
-        {
-            /********************* Simple white balance *********************/
-            float s1 = 2.0f; // low quantile
-            float s2 = 2.0f; // high quantile
-
-            int depth = 2; // depth of histogram tree
-            if (src[0].depth() != CV_8U)
-                ++depth;
-            int bins = 16; // number of bins at each histogram level
-
-            int nElements = int( pow((float)bins, (float)depth) );
-            // number of elements in histogram tree
-
-            for (size_t i = 0; i < src.size(); ++i)
-            {
-                std::vector <int> hist(nElements, 0);
-
-                typename Mat_<T>::iterator beginIt = src[i].begin();
-                typename Mat_<T>::iterator endIt = src[i].end();
-
-                for (typename Mat_<T>::iterator it = beginIt; it != endIt; ++it)
-                // histogram filling
-                {
-                    int pos = 0;
-                    float minValue = inputMin - 0.5f;
-                    float maxValue = inputMax + 0.5f;
-                    T val = *it;
-
-                    float interval = float(maxValue - minValue) / bins;
-
-                    for (int j = 0; j < depth; ++j)
-                    {
-                        int currentBin = int( (val - minValue + 1e-4f) / interval );
-                        ++hist[pos + currentBin];
-
-                        pos = (pos + currentBin)*bins;
-
-                        minValue = minValue + currentBin*interval;
-                        maxValue = minValue + interval;
-
-                        interval /= bins;
-                    }
-                }
-
-                int total = int( src[i].total() );
-
-                int p1 = 0, p2 = bins - 1;
-                int n1 = 0, n2 = total;
-
-                float minValue = inputMin - 0.5f;
-                float maxValue = inputMax + 0.5f;
-
-                float interval = (maxValue - minValue) / float(bins);
-
-                for (int j = 0; j < depth; ++j)
-                // searching for s1 and s2
-                {
-                    while (n1 + hist[p1] < s1 * total / 100.0f)
-                    {
-                        n1 += hist[p1++];
-                        minValue += interval;
-                    }
-                    p1 *= bins;
-
-                    while (n2 - hist[p2] > (100.0f - s2) * total / 100.0f)
-                    {
-                        n2 -= hist[p2--];
-                        maxValue -= interval;
-                    }
-                    p2 = p2*bins - 1;
-
-                    interval /= bins;
-                }
-
-                src[i] = (outputMax - outputMin) * (src[i] - minValue)
-                            / (maxValue - minValue) + outputMin;
-            }
-            /****************************************************************/
-            break;
-        }
-        default:
-            CV_Error_( CV_StsNotImplemented,
-                ("Unsupported algorithm type (=%d)", algorithmType) );
+            int pos = 0;
+            float minValue = inputMin - 0.5f;
+            float maxValue = inputMax + 0.5f;
+            T val = *it;
+
+            float interval = float(maxValue - minValue) / bins;
+
+            for (int j = 0; j < depth; ++j)
+            {
+                int currentBin = int((val - minValue + 1e-4f) / interval);
+                ++hist[pos + currentBin];
+
+                pos = (pos + currentBin) * bins;
+
+                minValue = minValue + currentBin * interval;
+                maxValue = minValue + interval;
+
+                interval /= bins;
+            }
         }
 
-        dst.create(/**/ src[0].size(), CV_MAKETYPE( src[0].depth(), int( src.size() ) ) /**/);
-        cv::merge(src, dst);
+        int total = int(src[i].total());
+
+        int p1 = 0, p2 = bins - 1;
+        int n1 = 0, n2 = total;
+
+        float minValue = inputMin - 0.5f;
+        float maxValue = inputMax + 0.5f;
+
+        float interval = (maxValue - minValue) / float(bins);
+
+        for (int j = 0; j < depth; ++j)
+        // searching for s1 and s2
+        {
+            while (n1 + hist[p1] < s1 * total / 100.0f)
+            {
+                n1 += hist[p1++];
+                minValue += interval;
+            }
+            p1 *= bins;
+
+            while (n2 - hist[p2] > (100.0f - s2) * total / 100.0f)
+            {
+                n2 -= hist[p2--];
+                maxValue -= interval;
+            }
+            p2 = p2 * bins - 1;
+
+            interval /= bins;
+        }
+
+        src[i] = (outputMax - outputMin) * (src[i] - minValue) / (maxValue - minValue) + outputMin;
     }
+    /****************************************************************/
+
+    dst.create(/**/ src[0].size(), CV_MAKETYPE(src[0].depth(), int(src.size())) /**/);
+    cv::merge(src, dst);
+}
+
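+/* Note on the quantile search above: the histogram is hierarchical. For CV_8U input
+ * depth == 2 and bins == 16, i.e. 16^2 = 256 cells; each level narrows the value
+ * interval by a factor of `bins`, so the s1/s2 percent quantiles are located
+ * without a full sort of the pixel values. */
+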
+class SimpleWBImpl : public SimpleWB
+{
+  private:
+    float inputMin, inputMax, outputMin, outputMax, p;
 
-    /*!
-     * Wrappers over different white balance algorithm
-     *
-     * \param src : source image (RGB)
-     * \param dst : destination image
-     *
-     * \param inputMin : minimum input value
-     * \param inputMax : maximum input value
-     * \param outputMin : minimum output value
-     * \param outputMax : maximum output value
-     *
-     * \param algorithmType : type of the algorithm to use
-     */
-    void balanceWhite(const Mat &src, Mat &dst, const int algorithmType,
-                      const float inputMin, const float inputMax,
-                      const float outputMin, const float outputMax)
+  public:
+    SimpleWBImpl()
     {
-        switch ( src.depth() )
+        inputMin = 0.0f;
+        inputMax = 255.0f;
+        outputMin = 0.0f;
+        outputMax = 255.0f;
+        p = 2.0f;
+    }
+
+    float getInputMin() const { return inputMin; }
+    void setInputMin(float val) { inputMin = val; }
+
+    float getInputMax() const { return inputMax; }
+    void setInputMax(float val) { inputMax = val; }
+
+    float getOutputMin() const { return outputMin; }
+    void setOutputMin(float val) { outputMin = val; }
+
+    float getOutputMax() const { return outputMax; }
+    void setOutputMax(float val) { outputMax = val; }
+
+    float getP() const { return p; }
+    void setP(float val) { p = val; }
+
+    void balanceWhite(InputArray _src, OutputArray _dst)
+    {
+        CV_Assert(!_src.empty());
+        CV_Assert(_src.depth() == CV_8U || _src.depth() == CV_16S || _src.depth() == CV_32S || _src.depth() == CV_32F);
+        Mat src = _src.getMat();
+        Mat &dst = _dst.getMatRef();
+
+        switch (src.depth())
+        {
+        case CV_8U:
+        {
+            std::vector<Mat_<uchar> > mv;
+            split(src, mv);
+            balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
+            break;
+        }
+        case CV_16S:
         {
-        case CV_8U:
-            {
-                std::vector < Mat_<uchar> > mv;
-                split(src, mv);
-                balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
-                break;
-            }
-        case CV_16S:
-            {
-                std::vector < Mat_<short> > mv;
-                split(src, mv);
-                balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
-                break;
-            }
-        case CV_32S:
-            {
-                std::vector < Mat_<int> > mv;
-                split(src, mv);
-                balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
-                break;
-            }
-        case CV_32F:
-            {
-                std::vector < Mat_<float> > mv;
-                split(src, mv);
-                balanceWhite(mv, dst, inputMin, inputMax, outputMin, outputMax, algorithmType);
-                break;
-            }
-        default:
-            CV_Error_( CV_StsNotImplemented,
-                ("Unsupported source image format (=%d)", src.type()) );
-            break;
+            std::vector<Mat_<short> > mv;
+            split(src, mv);
+            balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
+            break;
+        }
+        case CV_32S:
+        {
+            std::vector<Mat_<int> > mv;
+            split(src, mv);
+            balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
+            break;
+        }
+        case CV_32F:
+        {
+            std::vector<Mat_<float> > mv;
+            split(src, mv);
+            balanceWhiteSimple(mv, dst, inputMin, inputMax, outputMin, outputMax, p);
+            break;
+        }
         }
     }
+};
+
+Ptr<SimpleWB> createSimpleWB() { return makePtr<SimpleWBImpl>(); }
 }
 }
diff --git a/modules/xphoto/test/simple_color_balance.cpp b/modules/xphoto/test/simple_color_balance.cpp
index bf30271ba..64f9fb17e 100644
--- a/modules/xphoto/test/simple_color_balance.cpp
+++ b/modules/xphoto/test/simple_color_balance.cpp
@@ -7,6 +7,7 @@ namespace cvtest
         cv::String dir = cvtest::TS::ptr()->get_data_path() + "cv/xphoto/simple_white_balance/";
         int nTests = 12;
         float threshold = 0.005f;
+        cv::Ptr<cv::xphoto::SimpleWB> wb = cv::xphoto::createSimpleWB();
 
         for (int i = 0; i < nTests; ++i)
         {
@@ -18,7 +19,7 @@ namespace cvtest
             cv::Mat previousResult = cv::imread( previousResultName, 1 );
 
             cv::Mat currentResult;
-            cv::xphoto::balanceWhite(src, currentResult, cv::xphoto::WHITE_BALANCE_SIMPLE);
+            wb->balanceWhite(src, currentResult);
 
             cv::Mat sqrError = ( currentResult - previousResult )
                 .mul( currentResult - previousResult );
diff --git a/modules/xphoto/test/test_grayworld.cpp b/modules/xphoto/test/test_grayworld.cpp
index ae494b6e5..a4877003b 100644
--- a/modules/xphoto/test/test_grayworld.cpp
+++ b/modules/xphoto/test/test_grayworld.cpp
@@ -69,6 +69,8 @@ namespace cvtest {
        const int nTests = 14;
         const float wb_thresh = 0.5f;
         const float acc_thresh = 2.f;
+        Ptr<xphoto::GrayworldWB> wb = xphoto::createGrayworldWB();
+        wb->setSaturationThreshold(wb_thresh);
 
         for ( int i = 0; i < nTests; ++i )
         {
@@ -80,13 +82,13 @@ namespace cvtest {
             ref_autowbGrayworld(src, referenceResult, wb_thresh);
 
             Mat currentResult;
-            xphoto::autowbGrayworld(src, currentResult, wb_thresh);
+            wb->balanceWhite(src, currentResult);
             ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
 
             // test the 16-bit depth:
             Mat currentResult_16U, src_16U;
             src.convertTo(src_16U, CV_16UC3, 256.0);
-            xphoto::autowbGrayworld(src_16U, currentResult_16U, wb_thresh);
+            wb->balanceWhite(src_16U, currentResult_16U);
             currentResult_16U.convertTo(currentResult, CV_8UC3, 1/256.0);
             ASSERT_LE(cv::norm(currentResult, referenceResult, NORM_INF), acc_thresh);
         }
diff --git a/modules/xphoto/test/test_learning_based_color_balance.cpp b/modules/xphoto/test/test_learning_based_color_balance.cpp
index 66faca4e8..47b12c4fc 100644
--- a/modules/xphoto/test/test_learning_based_color_balance.cpp
+++ b/modules/xphoto/test/test_learning_based_color_balance.cpp
@@ -18,7 +18,11 @@ TEST(xphoto_simplefeatures, regression)
     Vec2f ref2(200.0f / (240 + 220 + 200), 220.0f / (240 + 220 + 200));
 
     vector<Vec2f> dst_features;
-    xphoto::extractSimpleFeatures(test_im, dst_features, 255, 0.98f, 64);
+    Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB();
+    wb->setRangeMaxVal(255);
+    wb->setSaturationThreshold(0.98f);
+    wb->setHistBinNum(64);
+    wb->extractSimpleFeatures(test_im, dst_features);
     ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
     ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
     ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
@@ -26,7 +30,10 @@ TEST(xphoto_simplefeatures, regression)
 
     // check 16 bit depth:
     test_im.convertTo(test_im, CV_16U, 256.0);
-    xphoto::extractSimpleFeatures(test_im, dst_features, 65535, 0.98f, 64);
+    wb->setRangeMaxVal(65535);
+    wb->setSaturationThreshold(0.98f);
+    wb->setHistBinNum(128);
+    wb->extractSimpleFeatures(test_im, dst_features);
     ASSERT_LE(cv::norm(dst_features[0], ref1, NORM_INF), acc_thresh);
     ASSERT_LE(cv::norm(dst_features[1], ref2, NORM_INF), acc_thresh);
     ASSERT_LE(cv::norm(dst_features[2], ref1, NORM_INF), acc_thresh);
diff --git a/modules/xphoto/tutorials/training_white_balance.markdown b/modules/xphoto/tutorials/training_white_balance.markdown
new file mode 100644
index 000000000..eac4c8498
--- /dev/null
+++ b/modules/xphoto/tutorials/training_white_balance.markdown
@@ -0,0 +1,42 @@
+Training the learning-based white balance algorithm {#tutorial_xphoto_training_white_balance}
+===================================================
+
+Introduction
+------------
+
+Many traditional white balance algorithms are statistics-based, i.e. they rely on certain assumptions that should hold in properly white-balanced images,
+such as the well-known gray-world assumption. However, better results can often be achieved by leveraging large datasets of images with ground-truth
+illuminants in a learning-based framework.
+This tutorial demonstrates how to train a learning-based white balance algorithm and evaluate the quality of the results.
+
+
+How to train a model
+--------------------
+
+-# Download a dataset for training. In this tutorial we will use the [Gehler-Shi dataset](http://www.cs.sfu.ca/~colour/data/shi_gehler/). Extract all 568 training images
+   into one folder. A file containing the ground-truth illuminant values (real_illum_568..mat) should be downloaded separately.
+
+-# We will be using a [Python script](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/learn_color_balance.py) for training.
+   Call it with the following parameters:
+   @code
+   python learn_color_balance.py -i <path to the folder with training images> -g <path to real_illum_568..mat> -r 0,378 --num_trees 30 --max_tree_depth 6 --num_augmented 0
+   @endcode
+   This should start training a model on the first 378 images (2/3 of the whole dataset). We set the size of the model to be 30 regression tree pairs per feature and limit
+   the tree depth to be no more than 6. By default the resulting model will be saved to color_balance_model.yml.
+
+-# Use the trained model by passing its path when constructing an instance of LearningBasedWB:
+   @code{.cpp}
+   Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB(modelFilename);
+   @endcode
+
+
+How to evaluate a model
+----------------------
+
+-# We will use a [benchmarking script](https://github.com/opencv/opencv_contrib/tree/master/modules/xphoto/samples/color_balance_benchmark.py) to compare
+   the model that we've trained with the classic gray-world algorithm on the remaining 1/3 of the dataset. Call the script with the following parameters:
+   @code
+   python color_balance_benchmark.py -a grayworld,learning_based:color_balance_model.yml -m <path to the folder with the model> -i <path to the folder with test images> -g <path to real_illum_568..mat> -r 379,567 -d "img"
+   @endcode
+
+-# The objective evaluation results are stored in white_balance_eval_result.html and the resulting white-balanced images are stored in the img folder for a qualitative
+   comparison of algorithms. Different algorithms are compared in terms of angular error between the estimated and the ground-truth illuminants.
\ No newline at end of file
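+
+For reference, a minimal end-to-end sketch of applying the trained model in C++ (the file names here are placeholders):
+@code{.cpp}
+Mat src = imread("test.jpg"), dst;
+Ptr<xphoto::LearningBasedWB> wb = xphoto::createLearningBasedWB("color_balance_model.yml");
+wb->balanceWhite(src, dst);
+imwrite("test_balanced.jpg", dst);
+@endcode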