Merge pull request #18838 from alalek:video_tracking_api
Tracking API: move to video/tracking.hpp * video(tracking): moved code from opencv_contrib/tracking module - Tracker API - MIL, GOTURN trackers - applied clang-format * video(tracking): cleanup unused code * samples: add tracker.py sample * video(tracking): avoid div by zero * static analyzerpull/18825/head
parent
94e8a08d1d
commit
aab6362705
31 changed files with 3843 additions and 4 deletions
@ -1,2 +1,12 @@ |
||||
set(the_description "Video Analysis") |
||||
ocv_define_module(video opencv_imgproc OPTIONAL opencv_calib3d WRAP java objc python js) |
||||
ocv_define_module(video |
||||
opencv_imgproc |
||||
OPTIONAL |
||||
opencv_calib3d |
||||
opencv_dnn |
||||
WRAP |
||||
java |
||||
objc |
||||
python |
||||
js |
||||
) |
||||
|
@ -1,6 +1,44 @@ |
||||
@article{AAM, |
||||
title={Adaptive appearance modeling for video tracking: survey and evaluation}, |
||||
author={Salti, Samuele and Cavallaro, Andrea and Di Stefano, Luigi}, |
||||
journal={Image Processing, IEEE Transactions on}, |
||||
volume={21}, |
||||
number={10}, |
||||
pages={4334--4348}, |
||||
year={2012}, |
||||
publisher={IEEE} |
||||
} |
||||
|
||||
@article{AMVOT, |
||||
title={A survey of appearance models in visual object tracking}, |
||||
author={Li, Xi and Hu, Weiming and Shen, Chunhua and Zhang, Zhongfei and Dick, Anthony and Hengel, Anton Van Den}, |
||||
journal={ACM Transactions on Intelligent Systems and Technology (TIST)}, |
||||
volume={4}, |
||||
number={4}, |
||||
pages={58}, |
||||
year={2013}, |
||||
publisher={ACM} |
||||
} |
||||
|
||||
@inproceedings{GOTURN, |
||||
title={Learning to Track at 100 FPS with Deep Regression Networks}, |
||||
author={Held, David and Thrun, Sebastian and Savarese, Silvio}, |
||||
booktitle={European Conference Computer Vision (ECCV)}, |
||||
year={2016} |
||||
} |
||||
|
||||
@inproceedings{Kroeger2016, |
||||
author={Till Kroeger and Radu Timofte and Dengxin Dai and Luc Van Gool}, |
||||
title={Fast Optical Flow using Dense Inverse Search}, |
||||
booktitle={Proceedings of the European Conference on Computer Vision ({ECCV})}, |
||||
year = {2016} |
||||
year={2016} |
||||
} |
||||
|
||||
@inproceedings{MIL, |
||||
title={Visual tracking with online multiple instance learning}, |
||||
author={Babenko, Boris and Yang, Ming-Hsuan and Belongie, Serge}, |
||||
booktitle={Computer Vision and Pattern Recognition, 2009. CVPR 2009. IEEE Conference on}, |
||||
pages={983--990}, |
||||
year={2009}, |
||||
organization={IEEE} |
||||
} |
||||
|
@ -0,0 +1,406 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_VIDEO_DETAIL_TRACKING_HPP |
||||
#define OPENCV_VIDEO_DETAIL_TRACKING_HPP |
||||
|
||||
/*
|
||||
* Partially based on: |
||||
* ==================================================================================================================== |
||||
* - [AAM] S. Salti, A. Cavallaro, L. Di Stefano, Adaptive Appearance Modeling for Video Tracking: Survey and Evaluation |
||||
* - [AMVOT] X. Li, W. Hu, C. Shen, Z. Zhang, A. Dick, A. van den Hengel, A Survey of Appearance Models in Visual Object Tracking |
||||
* |
||||
* This Tracking API has been designed with PlantUML. If you modify this API please change UML files under modules/tracking/doc/uml |
||||
* |
||||
*/ |
||||
|
||||
#include "opencv2/core.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
/** @addtogroup tracking_detail
|
||||
@{ |
||||
*/ |
||||
|
||||
/************************************ TrackerFeature Base Classes ************************************/ |
||||
|
||||
/** @brief Abstract base class for TrackerFeature that represents the feature.
|
||||
*/ |
||||
class CV_EXPORTS TrackerFeature |
||||
{ |
||||
public: |
||||
virtual ~TrackerFeature(); |
||||
|
||||
/** @brief Compute the features in the images collection
|
||||
@param images The images |
||||
@param response The output response |
||||
*/ |
||||
void compute(const std::vector<Mat>& images, Mat& response); |
||||
|
||||
protected: |
||||
virtual bool computeImpl(const std::vector<Mat>& images, Mat& response) = 0; |
||||
}; |
||||
|
||||
/** @brief Class that manages the extraction and selection of features
|
||||
|
||||
@cite AAM Feature Extraction and Feature Set Refinement (Feature Processing and Feature Selection). |
||||
See table I and section III C @cite AMVOT Appearance modelling -\> Visual representation (Table II, |
||||
section 3.1 - 3.2) |
||||
|
||||
TrackerFeatureSet is an aggregation of TrackerFeature |
||||
|
||||
@sa |
||||
TrackerFeature |
||||
|
||||
*/ |
||||
class CV_EXPORTS TrackerFeatureSet |
||||
{ |
||||
public: |
||||
TrackerFeatureSet(); |
||||
|
||||
~TrackerFeatureSet(); |
||||
|
||||
/** @brief Extract features from the images collection
|
||||
@param images The input images |
||||
*/ |
||||
void extraction(const std::vector<Mat>& images); |
||||
|
||||
/** @brief Add TrackerFeature in the collection. Return true if TrackerFeature is added, false otherwise
|
||||
@param feature The TrackerFeature class
|
||||
*/ |
||||
bool addTrackerFeature(const Ptr<TrackerFeature>& feature); |
||||
|
||||
/** @brief Get the TrackerFeature collection (TrackerFeature name, TrackerFeature pointer)
|
||||
*/ |
||||
const std::vector<Ptr<TrackerFeature>>& getTrackerFeatures() const; |
||||
|
||||
/** @brief Get the responses
|
||||
@note Be sure to call extraction before getResponses Example TrackerFeatureSet::getResponses |
||||
*/ |
||||
const std::vector<Mat>& getResponses() const; |
||||
|
||||
private: |
||||
void clearResponses(); |
||||
bool blockAddTrackerFeature; |
||||
|
||||
std::vector<Ptr<TrackerFeature>> features; // list of features
|
||||
std::vector<Mat> responses; // list of response after compute
|
||||
}; |
||||
|
||||
/************************************ TrackerSampler Base Classes ************************************/ |
||||
|
||||
/** @brief Abstract base class for TrackerSamplerAlgorithm that represents the algorithm for the specific
|
||||
sampler. |
||||
*/ |
||||
class CV_EXPORTS TrackerSamplerAlgorithm |
||||
{ |
||||
public: |
||||
virtual ~TrackerSamplerAlgorithm(); |
||||
|
||||
/** @brief Computes the regions starting from a position in an image.
|
||||
|
||||
Return true if samples are computed, false otherwise |
||||
|
||||
@param image The current frame |
||||
@param boundingBox The bounding box from which regions can be calculated |
||||
|
||||
@param sample The computed samples @cite AAM Fig. 1 variable Sk |
||||
*/ |
||||
virtual bool sampling(const Mat& image, const Rect& boundingBox, std::vector<Mat>& sample) = 0; |
||||
}; |
||||
|
||||
/**
|
||||
* \brief Class that manages the sampler in order to select regions for the update the model of the tracker |
||||
* [AAM] Sampling e Labeling. See table I and section III B |
||||
*/ |
||||
|
||||
/** @brief Class that manages the sampler in order to select regions for the update the model of the tracker
|
||||
|
||||
@cite AAM Sampling e Labeling. See table I and section III B |
||||
|
||||
TrackerSampler is an aggregation of TrackerSamplerAlgorithm |
||||
@sa |
||||
TrackerSamplerAlgorithm |
||||
*/ |
||||
class CV_EXPORTS TrackerSampler |
||||
{ |
||||
public: |
||||
TrackerSampler(); |
||||
|
||||
~TrackerSampler(); |
||||
|
||||
/** @brief Computes the regions starting from a position in an image
|
||||
@param image The current frame |
||||
@param boundingBox The bounding box from which regions can be calculated |
||||
*/ |
||||
void sampling(const Mat& image, Rect boundingBox); |
||||
|
||||
/** @brief Return the collection of the TrackerSamplerAlgorithm
|
||||
*/ |
||||
const std::vector<Ptr<TrackerSamplerAlgorithm>>& getSamplers() const; |
||||
|
||||
/** @brief Return the samples from all TrackerSamplerAlgorithm, @cite AAM Fig. 1 variable Sk
|
||||
*/ |
||||
const std::vector<Mat>& getSamples() const; |
||||
|
||||
/** @brief Add TrackerSamplerAlgorithm in the collection. Return true if sampler is added, false otherwise
|
||||
@param sampler The TrackerSamplerAlgorithm |
||||
*/ |
||||
bool addTrackerSamplerAlgorithm(const Ptr<TrackerSamplerAlgorithm>& sampler); |
||||
|
||||
private: |
||||
std::vector<Ptr<TrackerSamplerAlgorithm>> samplers; |
||||
std::vector<Mat> samples; |
||||
bool blockAddTrackerSampler; |
||||
|
||||
void clearSamples(); |
||||
}; |
||||
|
||||
/************************************ TrackerModel Base Classes ************************************/ |
||||
|
||||
/** @brief Abstract base class for TrackerTargetState that represents a possible state of the target.
|
||||
|
||||
See @cite AAM \f$\hat{x}^{i}_{k}\f$ all the states candidates. |
||||
|
||||
Inherits this class with your Target state, In own implementation you can add scale variation, |
||||
width, height, orientation, etc. |
||||
*/ |
||||
class CV_EXPORTS TrackerTargetState |
||||
{ |
||||
public: |
||||
virtual ~TrackerTargetState() {}; |
||||
/** @brief Get the position
|
||||
* @return The position |
||||
*/ |
||||
Point2f getTargetPosition() const; |
||||
|
||||
/** @brief Set the position
|
||||
* @param position The position |
||||
*/ |
||||
void setTargetPosition(const Point2f& position); |
||||
/** @brief Get the width of the target
|
||||
* @return The width of the target |
||||
*/ |
||||
int getTargetWidth() const; |
||||
|
||||
/** @brief Set the width of the target
|
||||
* @param width The width of the target |
||||
*/ |
||||
void setTargetWidth(int width); |
||||
/** @brief Get the height of the target
|
||||
* @return The height of the target |
||||
*/ |
||||
int getTargetHeight() const; |
||||
|
||||
/** @brief Set the height of the target
|
||||
* @param height The height of the target |
||||
*/ |
||||
void setTargetHeight(int height); |
||||
|
||||
protected: |
||||
Point2f targetPosition; |
||||
int targetWidth; |
||||
int targetHeight; |
||||
}; |
||||
|
||||
/** @brief Represents the model of the target at frame \f$k\f$ (all states and scores)
|
||||
|
||||
See @cite AAM The set of the pair \f$\langle \hat{x}^{i}_{k}, C^{i}_{k} \rangle\f$ |
||||
@sa TrackerTargetState |
||||
*/ |
||||
typedef std::vector<std::pair<Ptr<TrackerTargetState>, float>> ConfidenceMap; |
||||
|
||||
/** @brief Represents the estimate states for all frames
|
||||
|
||||
@cite AAM \f$x_{k}\f$ is the trajectory of the target up to time \f$k\f$ |
||||
|
||||
@sa TrackerTargetState |
||||
*/ |
||||
typedef std::vector<Ptr<TrackerTargetState>> Trajectory; |
||||
|
||||
/** @brief Abstract base class for TrackerStateEstimator that estimates the most likely target state.
|
||||
|
||||
See @cite AAM State estimator |
||||
|
||||
See @cite AMVOT Statistical modeling (Fig. 3), Table III (generative) - IV (discriminative) - V (hybrid) |
||||
*/ |
||||
class CV_EXPORTS TrackerStateEstimator |
||||
{ |
||||
public: |
||||
virtual ~TrackerStateEstimator(); |
||||
|
||||
/** @brief Estimate the most likely target state, return the estimated state
|
||||
@param confidenceMaps The overall appearance model as a list of :cConfidenceMap |
||||
*/ |
||||
Ptr<TrackerTargetState> estimate(const std::vector<ConfidenceMap>& confidenceMaps); |
||||
|
||||
/** @brief Update the ConfidenceMap with the scores
|
||||
@param confidenceMaps The overall appearance model as a list of :cConfidenceMap |
||||
*/ |
||||
void update(std::vector<ConfidenceMap>& confidenceMaps); |
||||
|
||||
/** @brief Create TrackerStateEstimator by tracker state estimator type
|
||||
@param trackeStateEstimatorType The TrackerStateEstimator name |
||||
|
||||
The modes available now: |
||||
|
||||
- "BOOSTING" -- Boosting-based discriminative appearance models. See @cite AMVOT section 4.4 |
||||
|
||||
The modes available soon: |
||||
|
||||
- "SVM" -- SVM-based discriminative appearance models. See @cite AMVOT section 4.5 |
||||
*/ |
||||
static Ptr<TrackerStateEstimator> create(const String& trackeStateEstimatorType); |
||||
|
||||
/** @brief Get the name of the specific TrackerStateEstimator
|
||||
*/ |
||||
String getClassName() const; |
||||
|
||||
protected: |
||||
virtual Ptr<TrackerTargetState> estimateImpl(const std::vector<ConfidenceMap>& confidenceMaps) = 0; |
||||
virtual void updateImpl(std::vector<ConfidenceMap>& confidenceMaps) = 0; |
||||
String className; |
||||
}; |
||||
|
||||
/** @brief Abstract class that represents the model of the target.
|
||||
|
||||
It must be instantiated by specialized tracker |
||||
|
||||
See @cite AAM Ak |
||||
|
||||
Inherits this with your TrackerModel |
||||
*/ |
||||
class CV_EXPORTS TrackerModel |
||||
{ |
||||
public: |
||||
TrackerModel(); |
||||
|
||||
virtual ~TrackerModel(); |
||||
|
||||
/** @brief Set TrackerEstimator, return true if the tracker state estimator is added, false otherwise
|
||||
@param trackerStateEstimator The TrackerStateEstimator |
||||
@note You can add only one TrackerStateEstimator |
||||
*/ |
||||
bool setTrackerStateEstimator(Ptr<TrackerStateEstimator> trackerStateEstimator); |
||||
|
||||
/** @brief Estimate the most likely target location
|
||||
|
||||
@cite AAM ME, Model Estimation table I |
||||
@param responses Features extracted from TrackerFeatureSet |
||||
*/ |
||||
void modelEstimation(const std::vector<Mat>& responses); |
||||
|
||||
/** @brief Update the model
|
||||
|
||||
@cite AAM MU, Model Update table I |
||||
*/ |
||||
void modelUpdate(); |
||||
|
||||
/** @brief Run the TrackerStateEstimator, return true if is possible to estimate a new state, false otherwise
|
||||
*/ |
||||
bool runStateEstimator(); |
||||
|
||||
/** @brief Set the current TrackerTargetState in the Trajectory
|
||||
@param lastTargetState The current TrackerTargetState |
||||
*/ |
||||
void setLastTargetState(const Ptr<TrackerTargetState>& lastTargetState); |
||||
|
||||
/** @brief Get the last TrackerTargetState from Trajectory
|
||||
*/ |
||||
Ptr<TrackerTargetState> getLastTargetState() const; |
||||
|
||||
/** @brief Get the list of the ConfidenceMap
|
||||
*/ |
||||
const std::vector<ConfidenceMap>& getConfidenceMaps() const; |
||||
|
||||
/** @brief Get the last ConfidenceMap for the current frame
|
||||
*/ |
||||
const ConfidenceMap& getLastConfidenceMap() const; |
||||
|
||||
/** @brief Get the TrackerStateEstimator
|
||||
*/ |
||||
Ptr<TrackerStateEstimator> getTrackerStateEstimator() const; |
||||
|
||||
private: |
||||
void clearCurrentConfidenceMap(); |
||||
|
||||
protected: |
||||
std::vector<ConfidenceMap> confidenceMaps; |
||||
Ptr<TrackerStateEstimator> stateEstimator; |
||||
ConfidenceMap currentConfidenceMap; |
||||
Trajectory trajectory; |
||||
int maxCMLength; |
||||
|
||||
virtual void modelEstimationImpl(const std::vector<Mat>& responses) = 0; |
||||
virtual void modelUpdateImpl() = 0; |
||||
}; |
||||
|
||||
/************************************ Specific TrackerStateEstimator Classes ************************************/ |
||||
|
||||
// None
|
||||
|
||||
/************************************ Specific TrackerSamplerAlgorithm Classes ************************************/ |
||||
|
||||
/** @brief TrackerSampler based on CSC (current state centered), used by MIL algorithm TrackerMIL
|
||||
*/ |
||||
class CV_EXPORTS TrackerSamplerCSC : public TrackerSamplerAlgorithm |
||||
{ |
||||
public: |
||||
~TrackerSamplerCSC(); |
||||
|
||||
enum MODE |
||||
{ |
||||
MODE_INIT_POS = 1, //!< mode for init positive samples
|
||||
MODE_INIT_NEG = 2, //!< mode for init negative samples
|
||||
MODE_TRACK_POS = 3, //!< mode for update positive samples
|
||||
MODE_TRACK_NEG = 4, //!< mode for update negative samples
|
||||
MODE_DETECT = 5 //!< mode for detect samples
|
||||
}; |
||||
|
||||
struct CV_EXPORTS Params |
||||
{ |
||||
Params(); |
||||
float initInRad; //!< radius for gathering positive instances during init
|
||||
float trackInPosRad; //!< radius for gathering positive instances during tracking
|
||||
float searchWinSize; //!< size of search window
|
||||
int initMaxNegNum; //!< # negative samples to use during init
|
||||
int trackMaxPosNum; //!< # positive samples to use during training
|
||||
int trackMaxNegNum; //!< # negative samples to use during training
|
||||
}; |
||||
|
||||
/** @brief Constructor
|
||||
@param parameters TrackerSamplerCSC parameters TrackerSamplerCSC::Params |
||||
*/ |
||||
TrackerSamplerCSC(const TrackerSamplerCSC::Params& parameters = TrackerSamplerCSC::Params()); |
||||
|
||||
/** @brief Set the sampling mode of TrackerSamplerCSC
|
||||
@param samplingMode The sampling mode |
||||
|
||||
The modes are: |
||||
|
||||
- "MODE_INIT_POS = 1" -- for the positive sampling in initialization step |
||||
- "MODE_INIT_NEG = 2" -- for the negative sampling in initialization step |
||||
- "MODE_TRACK_POS = 3" -- for the positive sampling in update step |
||||
- "MODE_TRACK_NEG = 4" -- for the negative sampling in update step |
||||
- "MODE_DETECT = 5" -- for the sampling in detection step |
||||
*/ |
||||
void setMode(int samplingMode); |
||||
|
||||
bool sampling(const Mat& image, const Rect& boundingBox, std::vector<Mat>& sample) CV_OVERRIDE; |
||||
|
||||
private: |
||||
Params params; |
||||
int mode; |
||||
RNG rng; |
||||
|
||||
std::vector<Mat> sampleImage(const Mat& img, int x, int y, int w, int h, float inrad, float outrad = 0, int maxnum = 1000000); |
||||
}; |
||||
|
||||
//! @}
|
||||
|
||||
}}} // namespace cv::detail::tracking
|
||||
|
||||
#endif // OPENCV_VIDEO_DETAIL_TRACKING_HPP
|
@ -0,0 +1,168 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_VIDEO_DETAIL_TRACKING_FEATURE_HPP |
||||
#define OPENCV_VIDEO_DETAIL_TRACKING_FEATURE_HPP |
||||
|
||||
#include "opencv2/core.hpp" |
||||
#include "opencv2/imgproc.hpp" |
||||
|
||||
/*
|
||||
* TODO This implementation is based on apps/traincascade/ |
||||
* TODO Changed CvHaarEvaluator based on ADABOOSTING implementation (Grabner et al.) |
||||
*/ |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
//! @addtogroup tracking_detail
|
||||
//! @{
|
||||
|
||||
inline namespace feature { |
||||
|
||||
class CvParams |
||||
{ |
||||
public: |
||||
CvParams(); |
||||
virtual ~CvParams() |
||||
{ |
||||
} |
||||
}; |
||||
|
||||
class CvFeatureParams : public CvParams |
||||
{ |
||||
public: |
||||
enum FeatureType |
||||
{ |
||||
HAAR = 0, |
||||
LBP = 1, |
||||
HOG = 2 |
||||
}; |
||||
|
||||
CvFeatureParams(); |
||||
static Ptr<CvFeatureParams> create(CvFeatureParams::FeatureType featureType); |
||||
int maxCatCount; // 0 in case of numerical features
|
||||
int featSize; // 1 in case of simple features (HAAR, LBP) and N_BINS(9)*N_CELLS(4) in case of Dalal's HOG features
|
||||
int numFeatures; |
||||
}; |
||||
|
||||
class CvFeatureEvaluator |
||||
{ |
||||
public: |
||||
virtual ~CvFeatureEvaluator() |
||||
{ |
||||
} |
||||
virtual void init(const CvFeatureParams* _featureParams, int _maxSampleCount, Size _winSize); |
||||
virtual void setImage(const Mat& img, uchar clsLabel, int idx); |
||||
static Ptr<CvFeatureEvaluator> create(CvFeatureParams::FeatureType type); |
||||
|
||||
int getNumFeatures() const |
||||
{ |
||||
return numFeatures; |
||||
} |
||||
int getMaxCatCount() const |
||||
{ |
||||
return featureParams->maxCatCount; |
||||
} |
||||
int getFeatureSize() const |
||||
{ |
||||
return featureParams->featSize; |
||||
} |
||||
const Mat& getCls() const |
||||
{ |
||||
return cls; |
||||
} |
||||
float getCls(int si) const |
||||
{ |
||||
return cls.at<float>(si, 0); |
||||
} |
||||
|
||||
protected: |
||||
virtual void generateFeatures() = 0; |
||||
|
||||
int npos, nneg; |
||||
int numFeatures; |
||||
Size winSize; |
||||
CvFeatureParams* featureParams; |
||||
Mat cls; |
||||
}; |
||||
|
||||
class CvHaarFeatureParams : public CvFeatureParams |
||||
{ |
||||
public: |
||||
CvHaarFeatureParams(); |
||||
bool isIntegral; |
||||
}; |
||||
|
||||
class CvHaarEvaluator : public CvFeatureEvaluator |
||||
{ |
||||
public: |
||||
class FeatureHaar |
||||
{ |
||||
|
||||
public: |
||||
FeatureHaar(Size patchSize); |
||||
bool eval(const Mat& image, Rect ROI, float* result) const; |
||||
inline int getNumAreas() const { return m_numAreas; } |
||||
inline const std::vector<float>& getWeights() const { return m_weights; } |
||||
inline const std::vector<Rect>& getAreas() const { return m_areas; } |
||||
|
||||
private: |
||||
int m_type; |
||||
int m_numAreas; |
||||
std::vector<float> m_weights; |
||||
float m_initMean; |
||||
float m_initSigma; |
||||
void generateRandomFeature(Size imageSize); |
||||
float getSum(const Mat& image, Rect imgROI) const; |
||||
std::vector<Rect> m_areas; // areas within the patch over which to compute the feature
|
||||
cv::Size m_initSize; // size of the patch used during training
|
||||
cv::Size m_curSize; // size of the patches currently under investigation
|
||||
float m_scaleFactorHeight; // scaling factor in vertical direction
|
||||
float m_scaleFactorWidth; // scaling factor in horizontal direction
|
||||
std::vector<Rect> m_scaleAreas; // areas after scaling
|
||||
std::vector<float> m_scaleWeights; // weights after scaling
|
||||
}; |
||||
|
||||
virtual void init(const CvFeatureParams* _featureParams, int _maxSampleCount, Size _winSize) CV_OVERRIDE; |
||||
virtual void setImage(const Mat& img, uchar clsLabel = 0, int idx = 1) CV_OVERRIDE; |
||||
inline const std::vector<CvHaarEvaluator::FeatureHaar>& getFeatures() const { return features; } |
||||
inline CvHaarEvaluator::FeatureHaar& getFeatures(int idx) |
||||
{ |
||||
return features[idx]; |
||||
} |
||||
inline void setWinSize(Size patchSize) { winSize = patchSize; } |
||||
inline Size getWinSize() const { return winSize; } |
||||
virtual void generateFeatures() CV_OVERRIDE; |
||||
|
||||
/**
|
||||
* \brief Overload the original generateFeatures in order to limit the number of the features |
||||
* @param numFeatures Number of the features |
||||
*/ |
||||
virtual void generateFeatures(int numFeatures); |
||||
|
||||
protected: |
||||
bool isIntegral; |
||||
|
||||
/* TODO Added from MIL implementation */ |
||||
Mat _ii_img; |
||||
void compute_integral(const cv::Mat& img, std::vector<cv::Mat_<float>>& ii_imgs) |
||||
{ |
||||
Mat ii_img; |
||||
integral(img, ii_img, CV_32F); |
||||
split(ii_img, ii_imgs); |
||||
} |
||||
|
||||
std::vector<FeatureHaar> features; |
||||
Mat sum; /* sum images (each row represents image) */ |
||||
}; |
||||
|
||||
} // namespace feature
|
||||
|
||||
//! @}
|
||||
|
||||
}}} // namespace cv::detail::tracking
|
||||
|
||||
#endif |
@ -0,0 +1,32 @@ |
||||
package org.opencv.test.video; |
||||
|
||||
import org.opencv.core.Core; |
||||
import org.opencv.core.CvException; |
||||
import org.opencv.test.OpenCVTestCase; |
||||
|
||||
import org.opencv.video.Tracker; |
||||
import org.opencv.video.TrackerGOTURN; |
||||
import org.opencv.video.TrackerMIL; |
||||
|
||||
public class TrackerCreateTest extends OpenCVTestCase { |
||||
|
||||
@Override |
||||
protected void setUp() throws Exception { |
||||
super.setUp(); |
||||
} |
||||
|
||||
|
||||
public void testCreateTrackerGOTURN() { |
||||
try { |
||||
Tracker tracker = TrackerGOTURN.create(); |
||||
assert(tracker != null); |
||||
} catch (CvException e) { |
||||
// expected, model files may be missing
|
||||
} |
||||
} |
||||
|
||||
public void testCreateTrackerMIL() { |
||||
Tracker tracker = TrackerMIL.create(); |
||||
} |
||||
|
||||
} |
@ -0,0 +1,4 @@ |
||||
#ifdef HAVE_OPENCV_VIDEO |
||||
typedef TrackerMIL::Params TrackerMIL_Params; |
||||
typedef TrackerGOTURN::Params TrackerGOTURN_Params; |
||||
#endif |
@ -0,0 +1,19 @@ |
||||
#!/usr/bin/env python |
||||
import os |
||||
import numpy as np |
||||
import cv2 as cv |
||||
|
||||
from tests_common import NewOpenCVTests, unittest |
||||
|
||||
class tracking_test(NewOpenCVTests): |
||||
|
||||
def test_createTracker(self): |
||||
t = cv.TrackerMIL_create() |
||||
try: |
||||
t = cv.TrackerGOTURN_create() |
||||
except cv.error as e: |
||||
pass # may fail due to missing DL model files |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
NewOpenCVTests.bootstrap() |
@ -0,0 +1,104 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "perf_precomp.hpp" |
||||
|
||||
namespace opencv_test { namespace { |
||||
using namespace perf; |
||||
|
||||
typedef tuple<string, int, Rect> TrackingParams_t; |
||||
|
||||
std::vector<TrackingParams_t> getTrackingParams() |
||||
{ |
||||
std::vector<TrackingParams_t> params { |
||||
TrackingParams_t("david/data/david.webm", 300, Rect(163,62,47,56)), |
||||
TrackingParams_t("dudek/data/dudek.webm", 1, Rect(123,87,132,176)), |
||||
TrackingParams_t("faceocc2/data/faceocc2.webm", 1, Rect(118,57,82,98)) |
||||
}; |
||||
return params; |
||||
} |
||||
|
||||
class Tracking : public perf::TestBaseWithParam<TrackingParams_t> |
||||
{ |
||||
public: |
||||
template<typename ROI_t = Rect2d, typename Tracker> |
||||
void runTrackingTest(const Ptr<Tracker>& tracker, const TrackingParams_t& params); |
||||
}; |
||||
|
||||
template<typename ROI_t, typename Tracker> |
||||
void Tracking::runTrackingTest(const Ptr<Tracker>& tracker, const TrackingParams_t& params) |
||||
{ |
||||
const int N = 10; |
||||
string video = get<0>(params); |
||||
int startFrame = get<1>(params); |
||||
//int endFrame = startFrame + N;
|
||||
Rect boundingBox = get<2>(params); |
||||
|
||||
string videoPath = findDataFile(std::string("cv/tracking/") + video); |
||||
|
||||
VideoCapture c; |
||||
c.open(videoPath); |
||||
if (!c.isOpened()) |
||||
throw SkipTestException("Can't open video file"); |
||||
#if 0 |
||||
// c.set(CAP_PROP_POS_FRAMES, startFrame);
|
||||
#else |
||||
if (startFrame) |
||||
std::cout << "startFrame = " << startFrame << std::endl; |
||||
for (int i = 0; i < startFrame; i++) |
||||
{ |
||||
Mat dummy_frame; |
||||
c >> dummy_frame; |
||||
ASSERT_FALSE(dummy_frame.empty()) << i << ": " << videoPath; |
||||
} |
||||
#endif |
||||
|
||||
// decode frames into memory (don't measure decoding performance)
|
||||
std::vector<Mat> frames; |
||||
for (int i = 0; i < N; ++i) |
||||
{ |
||||
Mat frame; |
||||
c >> frame; |
||||
ASSERT_FALSE(frame.empty()) << "i=" << i; |
||||
frames.push_back(frame); |
||||
} |
||||
|
||||
std::cout << "frame size = " << frames[0].size() << std::endl; |
||||
|
||||
PERF_SAMPLE_BEGIN(); |
||||
{ |
||||
tracker->init(frames[0], (ROI_t)boundingBox); |
||||
for (int i = 1; i < N; ++i) |
||||
{ |
||||
ROI_t rc; |
||||
tracker->update(frames[i], rc); |
||||
ASSERT_FALSE(rc.empty()); |
||||
} |
||||
} |
||||
PERF_SAMPLE_END(); |
||||
|
||||
SANITY_CHECK_NOTHING(); |
||||
} |
||||
|
||||
|
||||
//==================================================================================================
|
||||
|
||||
PERF_TEST_P(Tracking, MIL, testing::ValuesIn(getTrackingParams())) |
||||
{ |
||||
auto tracker = TrackerMIL::create(); |
||||
runTrackingTest<Rect>(tracker, GetParam()); |
||||
} |
||||
|
||||
PERF_TEST_P(Tracking, GOTURN, testing::ValuesIn(getTrackingParams())) |
||||
{ |
||||
std::string model = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.prototxt"); |
||||
std::string weights = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.caffemodel", false); |
||||
TrackerGOTURN::Params params; |
||||
params.modelTxt = model; |
||||
params.modelBin = weights; |
||||
auto tracker = TrackerGOTURN::create(params); |
||||
runTrackingTest<Rect>(tracker, GetParam()); |
||||
} |
||||
|
||||
}} // namespace
|
@ -0,0 +1,25 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
TrackerFeature::~TrackerFeature() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
void TrackerFeature::compute(const std::vector<Mat>& images, Mat& response) |
||||
{ |
||||
if (images.empty()) |
||||
return; |
||||
|
||||
computeImpl(images, response); |
||||
} |
||||
|
||||
}}} // namespace cv::detail::tracking
|
@ -0,0 +1,121 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
#include "opencv2/video/detail/tracking_feature.private.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
inline namespace internal { |
||||
|
||||
class TrackerFeatureHAAR : public TrackerFeature |
||||
{ |
||||
public: |
||||
struct Params |
||||
{ |
||||
Params(); |
||||
int numFeatures; //!< # of rects
|
||||
Size rectSize; //!< rect size
|
||||
bool isIntegral; //!< true if input images are integral, false otherwise
|
||||
}; |
||||
|
||||
TrackerFeatureHAAR(const TrackerFeatureHAAR::Params& parameters = TrackerFeatureHAAR::Params()); |
||||
|
||||
virtual ~TrackerFeatureHAAR() CV_OVERRIDE {} |
||||
|
||||
protected: |
||||
bool computeImpl(const std::vector<Mat>& images, Mat& response) CV_OVERRIDE; |
||||
|
||||
private: |
||||
Params params; |
||||
Ptr<CvHaarEvaluator> featureEvaluator; |
||||
}; |
||||
|
||||
/**
|
||||
* Parameters |
||||
*/ |
||||
|
||||
TrackerFeatureHAAR::Params::Params() |
||||
{ |
||||
numFeatures = 250; |
||||
rectSize = Size(100, 100); |
||||
isIntegral = false; |
||||
} |
||||
|
||||
TrackerFeatureHAAR::TrackerFeatureHAAR(const TrackerFeatureHAAR::Params& parameters) |
||||
: params(parameters) |
||||
{ |
||||
CvHaarFeatureParams haarParams; |
||||
haarParams.numFeatures = params.numFeatures; |
||||
haarParams.isIntegral = params.isIntegral; |
||||
featureEvaluator = makePtr<CvHaarEvaluator>(); |
||||
featureEvaluator->init(&haarParams, 1, params.rectSize); |
||||
} |
||||
|
||||
class Parallel_compute : public cv::ParallelLoopBody |
||||
{ |
||||
private: |
||||
Ptr<CvHaarEvaluator> featureEvaluator; |
||||
std::vector<Mat> images; |
||||
Mat response; |
||||
//std::vector<CvHaarEvaluator::FeatureHaar> features;
|
||||
public: |
||||
Parallel_compute(Ptr<CvHaarEvaluator>& fe, const std::vector<Mat>& img, Mat& resp) |
||||
: featureEvaluator(fe) |
||||
, images(img) |
||||
, response(resp) |
||||
{ |
||||
|
||||
//features = featureEvaluator->getFeatures();
|
||||
} |
||||
|
||||
virtual void operator()(const cv::Range& r) const CV_OVERRIDE |
||||
{ |
||||
for (int jf = r.start; jf != r.end; ++jf) |
||||
{ |
||||
int cols = images[jf].cols; |
||||
int rows = images[jf].rows; |
||||
for (int j = 0; j < featureEvaluator->getNumFeatures(); j++) |
||||
{ |
||||
float res = 0; |
||||
featureEvaluator->getFeatures()[j].eval(images[jf], Rect(0, 0, cols, rows), &res); |
||||
(Mat_<float>(response))(j, jf) = res; |
||||
} |
||||
} |
||||
} |
||||
}; |
||||
|
||||
bool TrackerFeatureHAAR::computeImpl(const std::vector<Mat>& images, Mat& response) |
||||
{ |
||||
if (images.empty()) |
||||
{ |
||||
return false; |
||||
} |
||||
|
||||
int numFeatures = featureEvaluator->getNumFeatures(); |
||||
|
||||
response = Mat_<float>(Size((int)images.size(), numFeatures)); |
||||
|
||||
std::vector<CvHaarEvaluator::FeatureHaar> f = featureEvaluator->getFeatures(); |
||||
//for each sample compute #n_feature -> put each feature (n Rect) in response
|
||||
parallel_for_(Range(0, (int)images.size()), Parallel_compute(featureEvaluator, images, response)); |
||||
|
||||
/*for ( size_t i = 0; i < images.size(); i++ )
|
||||
{ |
||||
int c = images[i].cols; |
||||
int r = images[i].rows; |
||||
for ( int j = 0; j < numFeatures; j++ ) |
||||
{ |
||||
float res = 0; |
||||
featureEvaluator->getFeatures( j ).eval( images[i], Rect( 0, 0, c, r ), &res ); |
||||
( Mat_<float>( response ) )( j, i ) = res; |
||||
} |
||||
}*/ |
||||
|
||||
return true; |
||||
} |
||||
|
||||
}}}} // namespace cv::detail::tracking::internal
|
@ -0,0 +1,60 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
TrackerFeatureSet::TrackerFeatureSet() |
||||
{ |
||||
blockAddTrackerFeature = false; |
||||
} |
||||
|
||||
TrackerFeatureSet::~TrackerFeatureSet() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
void TrackerFeatureSet::extraction(const std::vector<Mat>& images) |
||||
{ |
||||
blockAddTrackerFeature = true; |
||||
|
||||
clearResponses(); |
||||
responses.resize(features.size()); |
||||
|
||||
for (size_t i = 0; i < features.size(); i++) |
||||
{ |
||||
CV_DbgAssert(features[i]); |
||||
features[i]->compute(images, responses[i]); |
||||
} |
||||
} |
||||
|
||||
bool TrackerFeatureSet::addTrackerFeature(const Ptr<TrackerFeature>& feature) |
||||
{ |
||||
CV_Assert(!blockAddTrackerFeature); |
||||
CV_Assert(feature); |
||||
|
||||
features.push_back(feature); |
||||
return true; |
||||
} |
||||
|
||||
const std::vector<Ptr<TrackerFeature>>& TrackerFeatureSet::getTrackerFeatures() const |
||||
{ |
||||
return features; |
||||
} |
||||
|
||||
const std::vector<Mat>& TrackerFeatureSet::getResponses() const |
||||
{ |
||||
return responses; |
||||
} |
||||
|
||||
void TrackerFeatureSet::clearResponses() |
||||
{ |
||||
responses.clear(); |
||||
} |
||||
|
||||
}}} // namespace cv::detail::tracking
|
@ -0,0 +1,85 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "tracker_mil_model.hpp" |
||||
|
||||
/**
|
||||
* TrackerMILModel |
||||
*/ |
||||
|
||||
namespace cv { |
||||
inline namespace tracking { |
||||
namespace impl { |
||||
|
||||
TrackerMILModel::TrackerMILModel(const Rect& boundingBox) |
||||
{ |
||||
currentSample.clear(); |
||||
mode = MODE_POSITIVE; |
||||
width = boundingBox.width; |
||||
height = boundingBox.height; |
||||
|
||||
Ptr<TrackerStateEstimatorMILBoosting::TrackerMILTargetState> initState = Ptr<TrackerStateEstimatorMILBoosting::TrackerMILTargetState>( |
||||
new TrackerStateEstimatorMILBoosting::TrackerMILTargetState(Point2f((float)boundingBox.x, (float)boundingBox.y), boundingBox.width, boundingBox.height, |
||||
true, Mat())); |
||||
trajectory.push_back(initState); |
||||
} |
||||
|
||||
void TrackerMILModel::responseToConfidenceMap(const std::vector<Mat>& responses, ConfidenceMap& confidenceMap) |
||||
{ |
||||
if (currentSample.empty()) |
||||
{ |
||||
CV_Error(-1, "The samples in Model estimation are empty"); |
||||
} |
||||
|
||||
for (size_t i = 0; i < responses.size(); i++) |
||||
{ |
||||
//for each column (one sample) there are #num_feature
|
||||
//get informations from currentSample
|
||||
for (int j = 0; j < responses.at(i).cols; j++) |
||||
{ |
||||
|
||||
Size currentSize; |
||||
Point currentOfs; |
||||
currentSample.at(j).locateROI(currentSize, currentOfs); |
||||
bool foreground = false; |
||||
if (mode == MODE_POSITIVE || mode == MODE_ESTIMATON) |
||||
{ |
||||
foreground = true; |
||||
} |
||||
else if (mode == MODE_NEGATIVE) |
||||
{ |
||||
foreground = false; |
||||
} |
||||
|
||||
//get the column of the HAAR responses
|
||||
Mat singleResponse = responses.at(i).col(j); |
||||
|
||||
//create the state
|
||||
Ptr<TrackerStateEstimatorMILBoosting::TrackerMILTargetState> currentState = Ptr<TrackerStateEstimatorMILBoosting::TrackerMILTargetState>( |
||||
new TrackerStateEstimatorMILBoosting::TrackerMILTargetState(currentOfs, width, height, foreground, singleResponse)); |
||||
|
||||
confidenceMap.push_back(std::make_pair(currentState, 0.0f)); |
||||
} |
||||
} |
||||
} |
||||
|
||||
void TrackerMILModel::modelEstimationImpl(const std::vector<Mat>& responses) |
||||
{ |
||||
responseToConfidenceMap(responses, currentConfidenceMap); |
||||
} |
||||
|
||||
void TrackerMILModel::modelUpdateImpl() |
||||
{ |
||||
} |
||||
|
||||
void TrackerMILModel::setMode(int trainingMode, const std::vector<Mat>& samples) |
||||
{ |
||||
currentSample.clear(); |
||||
currentSample = samples; |
||||
|
||||
mode = trainingMode; |
||||
} |
||||
|
||||
}}} // namespace cv::tracking::impl
|
@ -0,0 +1,67 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef __OPENCV_TRACKER_MIL_MODEL_HPP__ |
||||
#define __OPENCV_TRACKER_MIL_MODEL_HPP__ |
||||
|
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
#include "tracker_mil_state.hpp" |
||||
|
||||
namespace cv { |
||||
inline namespace tracking { |
||||
namespace impl { |
||||
|
||||
using namespace cv::detail::tracking; |
||||
|
||||
/**
|
||||
* \brief Implementation of TrackerModel for MIL algorithm |
||||
*/ |
||||
class TrackerMILModel : public detail::TrackerModel |
||||
{ |
||||
public: |
||||
enum
|
||||
{ |
||||
MODE_POSITIVE = 1, // mode for positive features
|
||||
MODE_NEGATIVE = 2, // mode for negative features
|
||||
MODE_ESTIMATON = 3 // mode for estimation step
|
||||
}; |
||||
|
||||
/**
|
||||
* \brief Constructor |
||||
* \param boundingBox The first boundingBox |
||||
*/ |
||||
TrackerMILModel(const Rect& boundingBox); |
||||
|
||||
/**
|
||||
* \brief Destructor |
||||
*/ |
||||
~TrackerMILModel() {}; |
||||
|
||||
/**
|
||||
* \brief Set the mode |
||||
*/ |
||||
void setMode(int trainingMode, const std::vector<Mat>& samples); |
||||
|
||||
/**
|
||||
* \brief Create the ConfidenceMap from a list of responses |
||||
* \param responses The list of the responses |
||||
* \param confidenceMap The output |
||||
*/ |
||||
void responseToConfidenceMap(const std::vector<Mat>& responses, ConfidenceMap& confidenceMap); |
||||
|
||||
protected: |
||||
void modelEstimationImpl(const std::vector<Mat>& responses) CV_OVERRIDE; |
||||
void modelUpdateImpl() CV_OVERRIDE; |
||||
|
||||
private: |
||||
int mode; |
||||
std::vector<Mat> currentSample; |
||||
|
||||
int width; //initial width of the boundingBox
|
||||
int height; //initial height of the boundingBox
|
||||
}; |
||||
|
||||
}}} // namespace cv::tracking::impl
|
||||
|
||||
#endif |
@ -0,0 +1,159 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
#include "tracker_mil_state.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
/**
|
||||
* TrackerStateEstimatorMILBoosting::TrackerMILTargetState |
||||
*/ |
||||
TrackerStateEstimatorMILBoosting::TrackerMILTargetState::TrackerMILTargetState(const Point2f& position, int width, int height, bool foreground, |
||||
const Mat& features) |
||||
{ |
||||
setTargetPosition(position); |
||||
setTargetWidth(width); |
||||
setTargetHeight(height); |
||||
setTargetFg(foreground); |
||||
setFeatures(features); |
||||
} |
||||
|
||||
void TrackerStateEstimatorMILBoosting::TrackerMILTargetState::setTargetFg(bool foreground) |
||||
{ |
||||
isTarget = foreground; |
||||
} |
||||
|
||||
void TrackerStateEstimatorMILBoosting::TrackerMILTargetState::setFeatures(const Mat& features) |
||||
{ |
||||
targetFeatures = features; |
||||
} |
||||
|
||||
bool TrackerStateEstimatorMILBoosting::TrackerMILTargetState::isTargetFg() const |
||||
{ |
||||
return isTarget; |
||||
} |
||||
|
||||
Mat TrackerStateEstimatorMILBoosting::TrackerMILTargetState::getFeatures() const |
||||
{ |
||||
return targetFeatures; |
||||
} |
||||
|
||||
TrackerStateEstimatorMILBoosting::TrackerStateEstimatorMILBoosting(int nFeatures) |
||||
{ |
||||
className = "BOOSTING"; |
||||
trained = false; |
||||
numFeatures = nFeatures; |
||||
} |
||||
|
||||
TrackerStateEstimatorMILBoosting::~TrackerStateEstimatorMILBoosting() |
||||
{ |
||||
} |
||||
|
||||
void TrackerStateEstimatorMILBoosting::setCurrentConfidenceMap(ConfidenceMap& confidenceMap) |
||||
{ |
||||
currentConfidenceMap.clear(); |
||||
currentConfidenceMap = confidenceMap; |
||||
} |
||||
|
||||
uint TrackerStateEstimatorMILBoosting::max_idx(const std::vector<float>& v) |
||||
{ |
||||
const float* findPtr = &(*std::max_element(v.begin(), v.end())); |
||||
const float* beginPtr = &(*v.begin()); |
||||
return (uint)(findPtr - beginPtr); |
||||
} |
||||
|
||||
Ptr<TrackerTargetState> TrackerStateEstimatorMILBoosting::estimateImpl(const std::vector<ConfidenceMap>& /*confidenceMaps*/) |
||||
{ |
||||
//run ClfMilBoost classify in order to compute next location
|
||||
if (currentConfidenceMap.empty()) |
||||
return Ptr<TrackerTargetState>(); |
||||
|
||||
Mat positiveStates; |
||||
Mat negativeStates; |
||||
|
||||
prepareData(currentConfidenceMap, positiveStates, negativeStates); |
||||
|
||||
std::vector<float> prob = boostMILModel.classify(positiveStates); |
||||
|
||||
int bestind = max_idx(prob); |
||||
//float resp = prob[bestind];
|
||||
|
||||
return currentConfidenceMap.at(bestind).first; |
||||
} |
||||
|
||||
void TrackerStateEstimatorMILBoosting::prepareData(const ConfidenceMap& confidenceMap, Mat& positive, Mat& negative) |
||||
{ |
||||
|
||||
int posCounter = 0; |
||||
int negCounter = 0; |
||||
|
||||
for (size_t i = 0; i < confidenceMap.size(); i++) |
||||
{ |
||||
Ptr<TrackerMILTargetState> currentTargetState = confidenceMap.at(i).first.staticCast<TrackerMILTargetState>(); |
||||
CV_DbgAssert(currentTargetState); |
||||
if (currentTargetState->isTargetFg()) |
||||
posCounter++; |
||||
else |
||||
negCounter++; |
||||
} |
||||
|
||||
positive.create(posCounter, numFeatures, CV_32FC1); |
||||
negative.create(negCounter, numFeatures, CV_32FC1); |
||||
|
||||
//TODO change with mat fast access
|
||||
//initialize trainData (positive and negative)
|
||||
|
||||
int pc = 0; |
||||
int nc = 0; |
||||
for (size_t i = 0; i < confidenceMap.size(); i++) |
||||
{ |
||||
Ptr<TrackerMILTargetState> currentTargetState = confidenceMap.at(i).first.staticCast<TrackerMILTargetState>(); |
||||
Mat stateFeatures = currentTargetState->getFeatures(); |
||||
|
||||
if (currentTargetState->isTargetFg()) |
||||
{ |
||||
for (int j = 0; j < stateFeatures.rows; j++) |
||||
{ |
||||
//fill the positive trainData with the value of the feature j for sample i
|
||||
positive.at<float>(pc, j) = stateFeatures.at<float>(j, 0); |
||||
} |
||||
pc++; |
||||
} |
||||
else |
||||
{ |
||||
for (int j = 0; j < stateFeatures.rows; j++) |
||||
{ |
||||
//fill the negative trainData with the value of the feature j for sample i
|
||||
negative.at<float>(nc, j) = stateFeatures.at<float>(j, 0); |
||||
} |
||||
nc++; |
||||
} |
||||
} |
||||
} |
||||
|
||||
void TrackerStateEstimatorMILBoosting::updateImpl(std::vector<ConfidenceMap>& confidenceMaps) |
||||
{ |
||||
|
||||
if (!trained) |
||||
{ |
||||
//this is the first time that the classifier is built
|
||||
//init MIL
|
||||
boostMILModel.init(); |
||||
trained = true; |
||||
} |
||||
|
||||
ConfidenceMap lastConfidenceMap = confidenceMaps.back(); |
||||
Mat positiveStates; |
||||
Mat negativeStates; |
||||
|
||||
prepareData(lastConfidenceMap, positiveStates, negativeStates); |
||||
//update MIL
|
||||
boostMILModel.update(positiveStates, negativeStates); |
||||
} |
||||
|
||||
}}} // namespace cv::detail::tracking
|
@ -0,0 +1,87 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_VIDEO_DETAIL_TRACKING_MIL_STATE_HPP |
||||
#define OPENCV_VIDEO_DETAIL_TRACKING_MIL_STATE_HPP |
||||
|
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
#include "tracking_online_mil.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
/** @brief TrackerStateEstimator based on Boosting
|
||||
*/ |
||||
class CV_EXPORTS TrackerStateEstimatorMILBoosting : public TrackerStateEstimator |
||||
{ |
||||
public: |
||||
/**
|
||||
* Implementation of the target state for TrackerStateEstimatorMILBoosting |
||||
*/ |
||||
class TrackerMILTargetState : public TrackerTargetState |
||||
{ |
||||
|
||||
public: |
||||
/**
|
||||
* \brief Constructor |
||||
* \param position Top left corner of the bounding box |
||||
* \param width Width of the bounding box |
||||
* \param height Height of the bounding box |
||||
* \param foreground label for target or background |
||||
* \param features features extracted |
||||
*/ |
||||
TrackerMILTargetState(const Point2f& position, int width, int height, bool foreground, const Mat& features); |
||||
|
||||
~TrackerMILTargetState() {}; |
||||
|
||||
/** @brief Set label: true for target foreground, false for background
|
||||
@param foreground Label for background/foreground |
||||
*/ |
||||
void setTargetFg(bool foreground); |
||||
/** @brief Set the features extracted from TrackerFeatureSet
|
||||
@param features The features extracted |
||||
*/ |
||||
void setFeatures(const Mat& features); |
||||
/** @brief Get the label. Return true for target foreground, false for background
|
||||
*/ |
||||
bool isTargetFg() const; |
||||
/** @brief Get the features extracted
|
||||
*/ |
||||
Mat getFeatures() const; |
||||
|
||||
private: |
||||
bool isTarget; |
||||
Mat targetFeatures; |
||||
}; |
||||
|
||||
/** @brief Constructor
|
||||
@param nFeatures Number of features for each sample |
||||
*/ |
||||
TrackerStateEstimatorMILBoosting(int nFeatures = 250); |
||||
~TrackerStateEstimatorMILBoosting(); |
||||
|
||||
/** @brief Set the current confidenceMap
|
||||
@param confidenceMap The current :cConfidenceMap |
||||
*/ |
||||
void setCurrentConfidenceMap(ConfidenceMap& confidenceMap); |
||||
|
||||
protected: |
||||
Ptr<TrackerTargetState> estimateImpl(const std::vector<ConfidenceMap>& confidenceMaps) CV_OVERRIDE; |
||||
void updateImpl(std::vector<ConfidenceMap>& confidenceMaps) CV_OVERRIDE; |
||||
|
||||
private: |
||||
uint max_idx(const std::vector<float>& v); |
||||
void prepareData(const ConfidenceMap& confidenceMap, Mat& positive, Mat& negative); |
||||
|
||||
ClfMilBoost boostMILModel; |
||||
bool trained; |
||||
int numFeatures; |
||||
|
||||
ConfidenceMap currentConfidenceMap; |
||||
}; |
||||
|
||||
}}} // namespace cv::detail::tracking
|
||||
|
||||
#endif |
@ -0,0 +1,132 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
TrackerModel::TrackerModel() |
||||
{ |
||||
stateEstimator = Ptr<TrackerStateEstimator>(); |
||||
maxCMLength = 10; |
||||
} |
||||
|
||||
TrackerModel::~TrackerModel() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
bool TrackerModel::setTrackerStateEstimator(Ptr<TrackerStateEstimator> trackerStateEstimator) |
||||
{ |
||||
if (stateEstimator.get()) |
||||
{ |
||||
return false; |
||||
} |
||||
|
||||
stateEstimator = trackerStateEstimator; |
||||
return true; |
||||
} |
||||
|
||||
Ptr<TrackerStateEstimator> TrackerModel::getTrackerStateEstimator() const |
||||
{ |
||||
return stateEstimator; |
||||
} |
||||
|
||||
void TrackerModel::modelEstimation(const std::vector<Mat>& responses) |
||||
{ |
||||
modelEstimationImpl(responses); |
||||
} |
||||
|
||||
void TrackerModel::clearCurrentConfidenceMap() |
||||
{ |
||||
currentConfidenceMap.clear(); |
||||
} |
||||
|
||||
void TrackerModel::modelUpdate() |
||||
{ |
||||
modelUpdateImpl(); |
||||
|
||||
if (maxCMLength != -1 && (int)confidenceMaps.size() >= maxCMLength - 1) |
||||
{ |
||||
int l = maxCMLength / 2; |
||||
confidenceMaps.erase(confidenceMaps.begin(), confidenceMaps.begin() + l); |
||||
} |
||||
if (maxCMLength != -1 && (int)trajectory.size() >= maxCMLength - 1) |
||||
{ |
||||
int l = maxCMLength / 2; |
||||
trajectory.erase(trajectory.begin(), trajectory.begin() + l); |
||||
} |
||||
confidenceMaps.push_back(currentConfidenceMap); |
||||
stateEstimator->update(confidenceMaps); |
||||
|
||||
clearCurrentConfidenceMap(); |
||||
} |
||||
|
||||
bool TrackerModel::runStateEstimator() |
||||
{ |
||||
if (!stateEstimator) |
||||
{ |
||||
CV_Error(-1, "Tracker state estimator is not setted"); |
||||
} |
||||
Ptr<TrackerTargetState> targetState = stateEstimator->estimate(confidenceMaps); |
||||
if (!targetState) |
||||
return false; |
||||
|
||||
setLastTargetState(targetState); |
||||
return true; |
||||
} |
||||
|
||||
void TrackerModel::setLastTargetState(const Ptr<TrackerTargetState>& lastTargetState) |
||||
{ |
||||
trajectory.push_back(lastTargetState); |
||||
} |
||||
|
||||
Ptr<TrackerTargetState> TrackerModel::getLastTargetState() const |
||||
{ |
||||
return trajectory.back(); |
||||
} |
||||
|
||||
const std::vector<ConfidenceMap>& TrackerModel::getConfidenceMaps() const |
||||
{ |
||||
return confidenceMaps; |
||||
} |
||||
|
||||
const ConfidenceMap& TrackerModel::getLastConfidenceMap() const |
||||
{ |
||||
return confidenceMaps.back(); |
||||
} |
||||
|
||||
Point2f TrackerTargetState::getTargetPosition() const |
||||
{ |
||||
return targetPosition; |
||||
} |
||||
|
||||
void TrackerTargetState::setTargetPosition(const Point2f& position) |
||||
{ |
||||
targetPosition = position; |
||||
} |
||||
|
||||
int TrackerTargetState::getTargetWidth() const |
||||
{ |
||||
return targetWidth; |
||||
} |
||||
|
||||
void TrackerTargetState::setTargetWidth(int width) |
||||
{ |
||||
targetWidth = width; |
||||
} |
||||
int TrackerTargetState::getTargetHeight() const |
||||
{ |
||||
return targetHeight; |
||||
} |
||||
|
||||
void TrackerTargetState::setTargetHeight(int height) |
||||
{ |
||||
targetHeight = height; |
||||
} |
||||
|
||||
}}} // namespace cv::detail::tracking
|
@ -0,0 +1,68 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
|
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
TrackerSampler::TrackerSampler() |
||||
{ |
||||
blockAddTrackerSampler = false; |
||||
} |
||||
|
||||
TrackerSampler::~TrackerSampler() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
void TrackerSampler::sampling(const Mat& image, Rect boundingBox) |
||||
{ |
||||
clearSamples(); |
||||
|
||||
for (size_t i = 0; i < samplers.size(); i++) |
||||
{ |
||||
CV_DbgAssert(samplers[i]); |
||||
std::vector<Mat> current_samples; |
||||
samplers[i]->sampling(image, boundingBox, current_samples); |
||||
|
||||
//push in samples all current_samples
|
||||
for (size_t j = 0; j < current_samples.size(); j++) |
||||
{ |
||||
std::vector<Mat>::iterator it = samples.end(); |
||||
samples.insert(it, current_samples.at(j)); |
||||
} |
||||
} |
||||
|
||||
blockAddTrackerSampler = true; |
||||
} |
||||
|
||||
bool TrackerSampler::addTrackerSamplerAlgorithm(const Ptr<TrackerSamplerAlgorithm>& sampler) |
||||
{ |
||||
CV_Assert(!blockAddTrackerSampler); |
||||
CV_Assert(sampler); |
||||
|
||||
samplers.push_back(sampler); |
||||
return true; |
||||
} |
||||
|
||||
const std::vector<Ptr<TrackerSamplerAlgorithm>>& TrackerSampler::getSamplers() const |
||||
{ |
||||
return samplers; |
||||
} |
||||
|
||||
const std::vector<Mat>& TrackerSampler::getSamples() const |
||||
{ |
||||
return samples; |
||||
} |
||||
|
||||
void TrackerSampler::clearSamples() |
||||
{ |
||||
samples.clear(); |
||||
} |
||||
|
||||
}}} // namespace cv::detail::tracking
|
@ -0,0 +1,124 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
TrackerSamplerAlgorithm::~TrackerSamplerAlgorithm() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
TrackerSamplerCSC::Params::Params() |
||||
{ |
||||
initInRad = 3; |
||||
initMaxNegNum = 65; |
||||
searchWinSize = 25; |
||||
trackInPosRad = 4; |
||||
trackMaxNegNum = 65; |
||||
trackMaxPosNum = 100000; |
||||
} |
||||
|
||||
TrackerSamplerCSC::TrackerSamplerCSC(const TrackerSamplerCSC::Params& parameters) |
||||
: params(parameters) |
||||
{ |
||||
mode = MODE_INIT_POS; |
||||
rng = theRNG(); |
||||
} |
||||
|
||||
TrackerSamplerCSC::~TrackerSamplerCSC() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
bool TrackerSamplerCSC::sampling(const Mat& image, const Rect& boundingBox, std::vector<Mat>& sample) |
||||
{ |
||||
CV_Assert(!image.empty()); |
||||
|
||||
float inrad = 0; |
||||
float outrad = 0; |
||||
int maxnum = 0; |
||||
|
||||
switch (mode) |
||||
{ |
||||
case MODE_INIT_POS: |
||||
inrad = params.initInRad; |
||||
sample = sampleImage(image, boundingBox.x, boundingBox.y, boundingBox.width, boundingBox.height, inrad); |
||||
break; |
||||
case MODE_INIT_NEG: |
||||
inrad = 2.0f * params.searchWinSize; |
||||
outrad = 1.5f * params.initInRad; |
||||
maxnum = params.initMaxNegNum; |
||||
sample = sampleImage(image, boundingBox.x, boundingBox.y, boundingBox.width, boundingBox.height, inrad, outrad, maxnum); |
||||
break; |
||||
case MODE_TRACK_POS: |
||||
inrad = params.trackInPosRad; |
||||
outrad = 0; |
||||
maxnum = params.trackMaxPosNum; |
||||
sample = sampleImage(image, boundingBox.x, boundingBox.y, boundingBox.width, boundingBox.height, inrad, outrad, maxnum); |
||||
break; |
||||
case MODE_TRACK_NEG: |
||||
inrad = 1.5f * params.searchWinSize; |
||||
outrad = params.trackInPosRad + 5; |
||||
maxnum = params.trackMaxNegNum; |
||||
sample = sampleImage(image, boundingBox.x, boundingBox.y, boundingBox.width, boundingBox.height, inrad, outrad, maxnum); |
||||
break; |
||||
case MODE_DETECT: |
||||
inrad = params.searchWinSize; |
||||
sample = sampleImage(image, boundingBox.x, boundingBox.y, boundingBox.width, boundingBox.height, inrad); |
||||
break; |
||||
default: |
||||
inrad = params.initInRad; |
||||
sample = sampleImage(image, boundingBox.x, boundingBox.y, boundingBox.width, boundingBox.height, inrad); |
||||
break; |
||||
} |
||||
return false; |
||||
} |
||||
|
||||
void TrackerSamplerCSC::setMode(int samplingMode) |
||||
{ |
||||
mode = samplingMode; |
||||
} |
||||
|
||||
std::vector<Mat> TrackerSamplerCSC::sampleImage(const Mat& img, int x, int y, int w, int h, float inrad, float outrad, int maxnum) |
||||
{ |
||||
int rowsz = img.rows - h - 1; |
||||
int colsz = img.cols - w - 1; |
||||
float inradsq = inrad * inrad; |
||||
float outradsq = outrad * outrad; |
||||
int dist; |
||||
|
||||
uint minrow = max(0, (int)y - (int)inrad); |
||||
uint maxrow = min((int)rowsz - 1, (int)y + (int)inrad); |
||||
uint mincol = max(0, (int)x - (int)inrad); |
||||
uint maxcol = min((int)colsz - 1, (int)x + (int)inrad); |
||||
|
||||
//fprintf(stderr,"inrad=%f minrow=%d maxrow=%d mincol=%d maxcol=%d\n",inrad,minrow,maxrow,mincol,maxcol);
|
||||
|
||||
std::vector<Mat> samples; |
||||
samples.resize((maxrow - minrow + 1) * (maxcol - mincol + 1)); |
||||
int i = 0; |
||||
|
||||
float prob = ((float)(maxnum)) / samples.size(); |
||||
|
||||
for (int r = minrow; r <= int(maxrow); r++) |
||||
for (int c = mincol; c <= int(maxcol); c++) |
||||
{ |
||||
dist = (y - r) * (y - r) + (x - c) * (x - c); |
||||
if (float(rng.uniform(0.f, 1.f)) < prob && dist < inradsq && dist >= outradsq) |
||||
{ |
||||
samples[i] = img(Rect(c, r, w, h)); |
||||
i++; |
||||
} |
||||
} |
||||
|
||||
samples.resize(min(i, maxnum)); |
||||
return samples; |
||||
} |
||||
|
||||
}}} // namespace cv::detail::tracking
|
@ -0,0 +1,37 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
TrackerStateEstimator::~TrackerStateEstimator() |
||||
{ |
||||
} |
||||
|
||||
Ptr<TrackerTargetState> TrackerStateEstimator::estimate(const std::vector<ConfidenceMap>& confidenceMaps) |
||||
{ |
||||
if (confidenceMaps.empty()) |
||||
return Ptr<TrackerTargetState>(); |
||||
|
||||
return estimateImpl(confidenceMaps); |
||||
} |
||||
|
||||
void TrackerStateEstimator::update(std::vector<ConfidenceMap>& confidenceMaps) |
||||
{ |
||||
if (confidenceMaps.empty()) |
||||
return; |
||||
|
||||
return updateImpl(confidenceMaps); |
||||
} |
||||
|
||||
String TrackerStateEstimator::getClassName() const |
||||
{ |
||||
return className; |
||||
} |
||||
|
||||
}}} // namespace cv::detail::tracking
|
@ -0,0 +1,582 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "opencv2/video/detail/tracking.private.hpp" |
||||
#include "opencv2/video/detail/tracking_feature.private.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
/*
|
||||
* TODO This implementation is based on apps/traincascade/ |
||||
* TODO Changed CvHaarEvaluator based on ADABOOSTING implementation (Grabner et al.) |
||||
*/ |
||||
|
||||
CvParams::CvParams() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
//---------------------------- FeatureParams --------------------------------------
|
||||
|
||||
CvFeatureParams::CvFeatureParams() |
||||
: maxCatCount(0) |
||||
, featSize(1) |
||||
, numFeatures(1) |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
//------------------------------------- FeatureEvaluator ---------------------------------------
|
||||
|
||||
void CvFeatureEvaluator::init(const CvFeatureParams* _featureParams, int _maxSampleCount, Size _winSize) |
||||
{ |
||||
CV_Assert(_featureParams); |
||||
CV_Assert(_maxSampleCount > 0); |
||||
featureParams = (CvFeatureParams*)_featureParams; |
||||
winSize = _winSize; |
||||
numFeatures = _featureParams->numFeatures; |
||||
cls.create((int)_maxSampleCount, 1, CV_32FC1); |
||||
generateFeatures(); |
||||
} |
||||
|
||||
void CvFeatureEvaluator::setImage(const Mat& img, uchar clsLabel, int idx) |
||||
{ |
||||
winSize.width = img.cols; |
||||
winSize.height = img.rows; |
||||
//CV_Assert( img.cols == winSize.width );
|
||||
//CV_Assert( img.rows == winSize.height );
|
||||
CV_Assert(idx < cls.rows); |
||||
cls.ptr<float>(idx)[0] = clsLabel; |
||||
} |
||||
|
||||
CvHaarFeatureParams::CvHaarFeatureParams() |
||||
{ |
||||
isIntegral = false; |
||||
} |
||||
|
||||
//--------------------- HaarFeatureEvaluator ----------------
|
||||
|
||||
void CvHaarEvaluator::init(const CvFeatureParams* _featureParams, int /*_maxSampleCount*/, Size _winSize) |
||||
{ |
||||
CV_Assert(_featureParams); |
||||
int cols = (_winSize.width + 1) * (_winSize.height + 1); |
||||
sum.create((int)1, cols, CV_32SC1); |
||||
isIntegral = ((CvHaarFeatureParams*)_featureParams)->isIntegral; |
||||
CvFeatureEvaluator::init(_featureParams, 1, _winSize); |
||||
} |
||||
|
||||
void CvHaarEvaluator::setImage(const Mat& img, uchar /*clsLabel*/, int /*idx*/) |
||||
{ |
||||
CV_DbgAssert(!sum.empty()); |
||||
|
||||
winSize.width = img.cols; |
||||
winSize.height = img.rows; |
||||
|
||||
CvFeatureEvaluator::setImage(img, 1, 0); |
||||
if (!isIntegral) |
||||
{ |
||||
std::vector<Mat_<float>> ii_imgs; |
||||
compute_integral(img, ii_imgs); |
||||
_ii_img = ii_imgs[0]; |
||||
} |
||||
else |
||||
{ |
||||
_ii_img = img; |
||||
} |
||||
} |
||||
|
||||
void CvHaarEvaluator::generateFeatures() |
||||
{ |
||||
generateFeatures(featureParams->numFeatures); |
||||
} |
||||
|
||||
void CvHaarEvaluator::generateFeatures(int nFeatures) |
||||
{ |
||||
for (int i = 0; i < nFeatures; i++) |
||||
{ |
||||
CvHaarEvaluator::FeatureHaar feature(Size(winSize.width, winSize.height)); |
||||
features.push_back(feature); |
||||
} |
||||
} |
||||
|
||||
#define INITSIGMA(numAreas) (static_cast<float>(sqrt(256.0f * 256.0f / 12.0f * (numAreas)))); |
||||
|
||||
CvHaarEvaluator::FeatureHaar::FeatureHaar(Size patchSize) |
||||
{ |
||||
try |
||||
{ |
||||
generateRandomFeature(patchSize); |
||||
} |
||||
catch (...) |
||||
{ |
||||
// FIXIT
|
||||
throw; |
||||
} |
||||
} |
||||
|
||||
void CvHaarEvaluator::FeatureHaar::generateRandomFeature(Size patchSize) |
||||
{ |
||||
cv::Point2i position; |
||||
Size baseDim; |
||||
Size sizeFactor; |
||||
int area; |
||||
|
||||
CV_Assert(!patchSize.empty()); |
||||
|
||||
//Size minSize = Size( 3, 3 );
|
||||
int minArea = 9; |
||||
|
||||
bool valid = false; |
||||
while (!valid) |
||||
{ |
||||
//choose position and scale
|
||||
position.y = rand() % (patchSize.height); |
||||
position.x = rand() % (patchSize.width); |
||||
|
||||
baseDim.width = (int)((1 - sqrt(1 - (float)rand() * (float)(1.0 / RAND_MAX))) * patchSize.width); |
||||
baseDim.height = (int)((1 - sqrt(1 - (float)rand() * (float)(1.0 / RAND_MAX))) * patchSize.height); |
||||
|
||||
//select types
|
||||
//float probType[11] = {0.0909f, 0.0909f, 0.0909f, 0.0909f, 0.0909f, 0.0909f, 0.0909f, 0.0909f, 0.0909f, 0.0909f, 0.0950f};
|
||||
float probType[11] = { 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.2f, 0.0f, 0.0f, 0.0f, 0.0f, 0.0f }; |
||||
float prob = (float)rand() * (float)(1.0 / RAND_MAX); |
||||
|
||||
if (prob < probType[0]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 2; |
||||
sizeFactor.width = 1; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 1; |
||||
m_numAreas = 2; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -1; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = baseDim.height; |
||||
m_areas[0].width = baseDim.width; |
||||
m_areas[1].x = position.x; |
||||
m_areas[1].y = position.y + baseDim.height; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_initMean = 0; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
|
||||
valid = true; |
||||
} |
||||
else if (prob < probType[0] + probType[1]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 1; |
||||
sizeFactor.width = 2; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 2; |
||||
m_numAreas = 2; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -1; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = baseDim.height; |
||||
m_areas[0].width = baseDim.width; |
||||
m_areas[1].x = position.x + baseDim.width; |
||||
m_areas[1].y = position.y; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_initMean = 0; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else if (prob < probType[0] + probType[1] + probType[2]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 4; |
||||
sizeFactor.width = 1; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 3; |
||||
m_numAreas = 3; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -2; |
||||
m_weights[2] = 1; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = baseDim.height; |
||||
m_areas[0].width = baseDim.width; |
||||
m_areas[1].x = position.x; |
||||
m_areas[1].y = position.y + baseDim.height; |
||||
m_areas[1].height = 2 * baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_areas[2].y = position.y + 3 * baseDim.height; |
||||
m_areas[2].x = position.x; |
||||
m_areas[2].height = baseDim.height; |
||||
m_areas[2].width = baseDim.width; |
||||
m_initMean = 0; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else if (prob < probType[0] + probType[1] + probType[2] + probType[3]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 1; |
||||
sizeFactor.width = 4; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 3; |
||||
m_numAreas = 3; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -2; |
||||
m_weights[2] = 1; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = baseDim.height; |
||||
m_areas[0].width = baseDim.width; |
||||
m_areas[1].x = position.x + baseDim.width; |
||||
m_areas[1].y = position.y; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = 2 * baseDim.width; |
||||
m_areas[2].y = position.y; |
||||
m_areas[2].x = position.x + 3 * baseDim.width; |
||||
m_areas[2].height = baseDim.height; |
||||
m_areas[2].width = baseDim.width; |
||||
m_initMean = 0; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else if (prob < probType[0] + probType[1] + probType[2] + probType[3] + probType[4]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 2; |
||||
sizeFactor.width = 2; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 5; |
||||
m_numAreas = 4; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -1; |
||||
m_weights[2] = -1; |
||||
m_weights[3] = 1; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = baseDim.height; |
||||
m_areas[0].width = baseDim.width; |
||||
m_areas[1].x = position.x + baseDim.width; |
||||
m_areas[1].y = position.y; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_areas[2].y = position.y + baseDim.height; |
||||
m_areas[2].x = position.x; |
||||
m_areas[2].height = baseDim.height; |
||||
m_areas[2].width = baseDim.width; |
||||
m_areas[3].y = position.y + baseDim.height; |
||||
m_areas[3].x = position.x + baseDim.width; |
||||
m_areas[3].height = baseDim.height; |
||||
m_areas[3].width = baseDim.width; |
||||
m_initMean = 0; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else if (prob < probType[0] + probType[1] + probType[2] + probType[3] + probType[4] + probType[5]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 3; |
||||
sizeFactor.width = 3; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 6; |
||||
m_numAreas = 2; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -9; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = 3 * baseDim.height; |
||||
m_areas[0].width = 3 * baseDim.width; |
||||
m_areas[1].x = position.x + baseDim.width; |
||||
m_areas[1].y = position.y + baseDim.height; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_initMean = -8 * 128; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else if (prob < probType[0] + probType[1] + probType[2] + probType[3] + probType[4] + probType[5] + probType[6]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 3; |
||||
sizeFactor.width = 1; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 7; |
||||
m_numAreas = 3; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -2; |
||||
m_weights[2] = 1; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = baseDim.height; |
||||
m_areas[0].width = baseDim.width; |
||||
m_areas[1].x = position.x; |
||||
m_areas[1].y = position.y + baseDim.height; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_areas[2].y = position.y + baseDim.height * 2; |
||||
m_areas[2].x = position.x; |
||||
m_areas[2].height = baseDim.height; |
||||
m_areas[2].width = baseDim.width; |
||||
m_initMean = 0; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else if (prob < probType[0] + probType[1] + probType[2] + probType[3] + probType[4] + probType[5] + probType[6] + probType[7]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 1; |
||||
sizeFactor.width = 3; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
|
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
|
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 8; |
||||
m_numAreas = 3; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -2; |
||||
m_weights[2] = 1; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = baseDim.height; |
||||
m_areas[0].width = baseDim.width; |
||||
m_areas[1].x = position.x + baseDim.width; |
||||
m_areas[1].y = position.y; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_areas[2].y = position.y; |
||||
m_areas[2].x = position.x + 2 * baseDim.width; |
||||
m_areas[2].height = baseDim.height; |
||||
m_areas[2].width = baseDim.width; |
||||
m_initMean = 0; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else if (prob < probType[0] + probType[1] + probType[2] + probType[3] + probType[4] + probType[5] + probType[6] + probType[7] + probType[8]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 3; |
||||
sizeFactor.width = 3; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 9; |
||||
m_numAreas = 2; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -2; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = 3 * baseDim.height; |
||||
m_areas[0].width = 3 * baseDim.width; |
||||
m_areas[1].x = position.x + baseDim.width; |
||||
m_areas[1].y = position.y + baseDim.height; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_initMean = 0; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else if (prob |
||||
< probType[0] + probType[1] + probType[2] + probType[3] + probType[4] + probType[5] + probType[6] + probType[7] + probType[8] + probType[9]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 3; |
||||
sizeFactor.width = 1; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 10; |
||||
m_numAreas = 3; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -1; |
||||
m_weights[2] = 1; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = baseDim.height; |
||||
m_areas[0].width = baseDim.width; |
||||
m_areas[1].x = position.x; |
||||
m_areas[1].y = position.y + baseDim.height; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_areas[2].y = position.y + baseDim.height * 2; |
||||
m_areas[2].x = position.x; |
||||
m_areas[2].height = baseDim.height; |
||||
m_areas[2].width = baseDim.width; |
||||
m_initMean = 128; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else if (prob |
||||
< probType[0] + probType[1] + probType[2] + probType[3] + probType[4] + probType[5] + probType[6] + probType[7] + probType[8] + probType[9] |
||||
+ probType[10]) |
||||
{ |
||||
//check if feature is valid
|
||||
sizeFactor.height = 1; |
||||
sizeFactor.width = 3; |
||||
if (position.y + baseDim.height * sizeFactor.height >= patchSize.height || position.x + baseDim.width * sizeFactor.width >= patchSize.width) |
||||
continue; |
||||
area = baseDim.height * sizeFactor.height * baseDim.width * sizeFactor.width; |
||||
if (area < minArea) |
||||
continue; |
||||
|
||||
m_type = 11; |
||||
m_numAreas = 3; |
||||
m_weights.resize(m_numAreas); |
||||
m_weights[0] = 1; |
||||
m_weights[1] = -1; |
||||
m_weights[2] = 1; |
||||
m_areas.resize(m_numAreas); |
||||
m_areas[0].x = position.x; |
||||
m_areas[0].y = position.y; |
||||
m_areas[0].height = baseDim.height; |
||||
m_areas[0].width = baseDim.width; |
||||
m_areas[1].x = position.x + baseDim.width; |
||||
m_areas[1].y = position.y; |
||||
m_areas[1].height = baseDim.height; |
||||
m_areas[1].width = baseDim.width; |
||||
m_areas[2].y = position.y; |
||||
m_areas[2].x = position.x + 2 * baseDim.width; |
||||
m_areas[2].height = baseDim.height; |
||||
m_areas[2].width = baseDim.width; |
||||
m_initMean = 128; |
||||
m_initSigma = INITSIGMA(m_numAreas); |
||||
valid = true; |
||||
} |
||||
else |
||||
CV_Error(Error::StsAssert, ""); |
||||
} |
||||
|
||||
m_initSize = patchSize; |
||||
m_curSize = m_initSize; |
||||
m_scaleFactorWidth = m_scaleFactorHeight = 1.0f; |
||||
m_scaleAreas.resize(m_numAreas); |
||||
m_scaleWeights.resize(m_numAreas); |
||||
for (int curArea = 0; curArea < m_numAreas; curArea++) |
||||
{ |
||||
m_scaleAreas[curArea] = m_areas[curArea]; |
||||
m_scaleWeights[curArea] = (float)m_weights[curArea] / (float)(m_areas[curArea].width * m_areas[curArea].height); |
||||
} |
||||
} |
||||
|
||||
bool CvHaarEvaluator::FeatureHaar::eval(const Mat& image, Rect /*ROI*/, float* result) const |
||||
{ |
||||
|
||||
*result = 0.0f; |
||||
|
||||
for (int curArea = 0; curArea < m_numAreas; curArea++) |
||||
{ |
||||
*result += (float)getSum(image, Rect(m_areas[curArea].x, m_areas[curArea].y, m_areas[curArea].width, m_areas[curArea].height)) |
||||
* m_scaleWeights[curArea]; |
||||
} |
||||
|
||||
/*
|
||||
if( image->getUseVariance() ) |
||||
{ |
||||
float variance = (float) image->getVariance( ROI ); |
||||
*result /= variance; |
||||
} |
||||
*/ |
||||
|
||||
return true; |
||||
} |
||||
|
||||
float CvHaarEvaluator::FeatureHaar::getSum(const Mat& image, Rect imageROI) const |
||||
{ |
||||
// left upper Origin
|
||||
int OriginX = imageROI.x; |
||||
int OriginY = imageROI.y; |
||||
|
||||
// Check and fix width and height
|
||||
int Width = imageROI.width; |
||||
int Height = imageROI.height; |
||||
|
||||
if (OriginX + Width >= image.cols - 1) |
||||
Width = (image.cols - 1) - OriginX; |
||||
if (OriginY + Height >= image.rows - 1) |
||||
Height = (image.rows - 1) - OriginY; |
||||
|
||||
float value = 0; |
||||
int depth = image.depth(); |
||||
|
||||
if (depth == CV_8U || depth == CV_32S) |
||||
value = static_cast<float>(image.at<int>(OriginY + Height, OriginX + Width) + image.at<int>(OriginY, OriginX) - image.at<int>(OriginY, OriginX + Width) |
||||
- image.at<int>(OriginY + Height, OriginX)); |
||||
else if (depth == CV_64F) |
||||
value = static_cast<float>(image.at<double>(OriginY + Height, OriginX + Width) + image.at<double>(OriginY, OriginX) |
||||
- image.at<double>(OriginY, OriginX + Width) - image.at<double>(OriginY + Height, OriginX)); |
||||
else if (depth == CV_32F) |
||||
value = static_cast<float>(image.at<float>(OriginY + Height, OriginX + Width) + image.at<float>(OriginY, OriginX) - image.at<float>(OriginY, OriginX + Width) |
||||
- image.at<float>(OriginY + Height, OriginX)); |
||||
|
||||
return value; |
||||
} |
||||
|
||||
}}} // namespace cv::detail::tracking
|
@ -0,0 +1,356 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../../precomp.hpp" |
||||
#include "tracking_online_mil.hpp" |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
#define sign(s) ((s > 0) ? 1 : ((s < 0) ? -1 : 0)) |
||||
|
||||
template <class T> |
||||
class SortableElementRev |
||||
{ |
||||
public: |
||||
T _val; |
||||
int _ind; |
||||
SortableElementRev() |
||||
: _val(), _ind(0) |
||||
{ |
||||
} |
||||
SortableElementRev(T val, int ind) |
||||
{ |
||||
_val = val; |
||||
_ind = ind; |
||||
} |
||||
bool operator<(SortableElementRev<T>& b) |
||||
{ |
||||
return (_val < b._val); |
||||
}; |
||||
}; |
||||
|
||||
static bool CompareSortableElementRev(const SortableElementRev<float>& i, const SortableElementRev<float>& j) |
||||
{ |
||||
return i._val < j._val; |
||||
} |
||||
|
||||
template <class T> |
||||
void sort_order_des(std::vector<T>& v, std::vector<int>& order) |
||||
{ |
||||
uint n = (uint)v.size(); |
||||
std::vector<SortableElementRev<T>> v2; |
||||
v2.resize(n); |
||||
order.clear(); |
||||
order.resize(n); |
||||
for (uint i = 0; i < n; i++) |
||||
{ |
||||
v2[i]._ind = i; |
||||
v2[i]._val = v[i]; |
||||
} |
||||
//std::sort( v2.begin(), v2.end() );
|
||||
std::sort(v2.begin(), v2.end(), CompareSortableElementRev); |
||||
for (uint i = 0; i < n; i++) |
||||
{ |
||||
order[i] = v2[i]._ind; |
||||
v[i] = v2[i]._val; |
||||
} |
||||
}; |
||||
|
||||
//implementations for strong classifier
|
||||
|
||||
ClfMilBoost::Params::Params() |
||||
{ |
||||
_numSel = 50; |
||||
_numFeat = 250; |
||||
_lRate = 0.85f; |
||||
} |
||||
|
||||
ClfMilBoost::ClfMilBoost() |
||||
: _numsamples(0) |
||||
, _counter(0) |
||||
{ |
||||
_myParams = ClfMilBoost::Params(); |
||||
_numsamples = 0; |
||||
} |
||||
|
||||
ClfMilBoost::~ClfMilBoost() |
||||
{ |
||||
_selectors.clear(); |
||||
for (size_t i = 0; i < _weakclf.size(); i++) |
||||
delete _weakclf.at(i); |
||||
} |
||||
|
||||
void ClfMilBoost::init(const ClfMilBoost::Params& parameters) |
||||
{ |
||||
_myParams = parameters; |
||||
_numsamples = 0; |
||||
|
||||
//_ftrs = Ftr::generate( _myParams->_ftrParams, _myParams->_numFeat );
|
||||
// if( params->_storeFtrHistory )
|
||||
// Ftr::toViz( _ftrs, "haarftrs" );
|
||||
_weakclf.resize(_myParams._numFeat); |
||||
for (int k = 0; k < _myParams._numFeat; k++) |
||||
{ |
||||
_weakclf[k] = new ClfOnlineStump(k); |
||||
_weakclf[k]->_lRate = _myParams._lRate; |
||||
} |
||||
_counter = 0; |
||||
} |
||||
|
||||
void ClfMilBoost::update(const Mat& posx, const Mat& negx) |
||||
{ |
||||
int numneg = negx.rows; |
||||
int numpos = posx.rows; |
||||
|
||||
// compute ftrs
|
||||
//if( !posx.ftrsComputed() )
|
||||
// Ftr::compute( posx, _ftrs );
|
||||
//if( !negx.ftrsComputed() )
|
||||
// Ftr::compute( negx, _ftrs );
|
||||
|
||||
// initialize H
|
||||
static std::vector<float> Hpos, Hneg; |
||||
Hpos.clear(); |
||||
Hneg.clear(); |
||||
Hpos.resize(posx.rows, 0.0f), Hneg.resize(negx.rows, 0.0f); |
||||
|
||||
_selectors.clear(); |
||||
std::vector<float> posw(posx.rows), negw(negx.rows); |
||||
std::vector<std::vector<float>> pospred(_weakclf.size()), negpred(_weakclf.size()); |
||||
|
||||
// train all weak classifiers without weights
|
||||
#ifdef _OPENMP |
||||
#pragma omp parallel for |
||||
#endif |
||||
for (int m = 0; m < _myParams._numFeat; m++) |
||||
{ |
||||
_weakclf[m]->update(posx, negx); |
||||
pospred[m] = _weakclf[m]->classifySetF(posx); |
||||
negpred[m] = _weakclf[m]->classifySetF(negx); |
||||
} |
||||
|
||||
// pick the best features
|
||||
for (int s = 0; s < _myParams._numSel; s++) |
||||
{ |
||||
|
||||
// compute errors/likl for all weak clfs
|
||||
std::vector<float> poslikl(_weakclf.size(), 1.0f), neglikl(_weakclf.size()), likl(_weakclf.size()); |
||||
#ifdef _OPENMP |
||||
#pragma omp parallel for |
||||
#endif |
||||
for (int w = 0; w < (int)_weakclf.size(); w++) |
||||
{ |
||||
float lll = 1.0f; |
||||
for (int j = 0; j < numpos; j++) |
||||
lll *= (1 - sigmoid(Hpos[j] + pospred[w][j])); |
||||
poslikl[w] = (float)-log(1 - lll + 1e-5); |
||||
|
||||
lll = 0.0f; |
||||
for (int j = 0; j < numneg; j++) |
||||
lll += (float)-log(1e-5f + 1 - sigmoid(Hneg[j] + negpred[w][j])); |
||||
neglikl[w] = lll; |
||||
|
||||
likl[w] = poslikl[w] / numpos + neglikl[w] / numneg; |
||||
} |
||||
|
||||
// pick best weak clf
|
||||
std::vector<int> order; |
||||
sort_order_des(likl, order); |
||||
|
||||
// find best weakclf that isn't already included
|
||||
for (uint k = 0; k < order.size(); k++) |
||||
if (std::count(_selectors.begin(), _selectors.end(), order[k]) == 0) |
||||
{ |
||||
_selectors.push_back(order[k]); |
||||
break; |
||||
} |
||||
|
||||
// update H = H + h_m
|
||||
#ifdef _OPENMP |
||||
#pragma omp parallel for |
||||
#endif |
||||
for (int k = 0; k < posx.rows; k++) |
||||
Hpos[k] += pospred[_selectors[s]][k]; |
||||
#ifdef _OPENMP |
||||
#pragma omp parallel for |
||||
#endif |
||||
for (int k = 0; k < negx.rows; k++) |
||||
Hneg[k] += negpred[_selectors[s]][k]; |
||||
} |
||||
|
||||
//if( _myParams->_storeFtrHistory )
|
||||
//for ( uint j = 0; j < _selectors.size(); j++ )
|
||||
// _ftrHist( _selectors[j], _counter ) = 1.0f / ( j + 1 );
|
||||
|
||||
_counter++; |
||||
/* */ |
||||
return; |
||||
} |
||||
|
||||
std::vector<float> ClfMilBoost::classify(const Mat& x, bool logR) |
||||
{ |
||||
int numsamples = x.rows; |
||||
std::vector<float> res(numsamples); |
||||
std::vector<float> tr; |
||||
|
||||
for (uint w = 0; w < _selectors.size(); w++) |
||||
{ |
||||
tr = _weakclf[_selectors[w]]->classifySetF(x); |
||||
#ifdef _OPENMP |
||||
#pragma omp parallel for |
||||
#endif |
||||
for (int j = 0; j < numsamples; j++) |
||||
{ |
||||
res[j] += tr[j]; |
||||
} |
||||
} |
||||
|
||||
// return probabilities or log odds ratio
|
||||
if (!logR) |
||||
{ |
||||
#ifdef _OPENMP |
||||
#pragma omp parallel for |
||||
#endif |
||||
for (int j = 0; j < (int)res.size(); j++) |
||||
{ |
||||
res[j] = sigmoid(res[j]); |
||||
} |
||||
} |
||||
|
||||
return res; |
||||
} |
||||
|
||||
//implementations for weak classifier
|
||||
|
||||
ClfOnlineStump::ClfOnlineStump() |
||||
: _mu0(0), _mu1(0), _sig0(0), _sig1(0) |
||||
, _q(0) |
||||
, _s(0) |
||||
, _log_n1(0), _log_n0(0) |
||||
, _e1(0), _e0(0) |
||||
, _lRate(0) |
||||
{ |
||||
_trained = false; |
||||
_ind = -1; |
||||
init(); |
||||
} |
||||
|
||||
ClfOnlineStump::ClfOnlineStump(int ind) |
||||
: _mu0(0), _mu1(0), _sig0(0), _sig1(0) |
||||
, _q(0) |
||||
, _s(0) |
||||
, _log_n1(0), _log_n0(0) |
||||
, _e1(0), _e0(0) |
||||
, _lRate(0) |
||||
{ |
||||
_trained = false; |
||||
_ind = ind; |
||||
init(); |
||||
} |
||||
void ClfOnlineStump::init() |
||||
{ |
||||
_mu0 = 0; |
||||
_mu1 = 0; |
||||
_sig0 = 1; |
||||
_sig1 = 1; |
||||
_lRate = 0.85f; |
||||
_trained = false; |
||||
} |
||||
|
||||
void ClfOnlineStump::update(const Mat& posx, const Mat& negx, const Mat_<float>& /*posw*/, const Mat_<float>& /*negw*/) |
||||
{ |
||||
//std::cout << " ClfOnlineStump::update" << _ind << std::endl;
|
||||
float posmu = 0.0, negmu = 0.0; |
||||
if (posx.cols > 0) |
||||
posmu = float(mean(posx.col(_ind))[0]); |
||||
if (negx.cols > 0) |
||||
negmu = float(mean(negx.col(_ind))[0]); |
||||
|
||||
if (_trained) |
||||
{ |
||||
if (posx.cols > 0) |
||||
{ |
||||
_mu1 = (_lRate * _mu1 + (1 - _lRate) * posmu); |
||||
cv::Mat diff = posx.col(_ind) - _mu1; |
||||
_sig1 = _lRate * _sig1 + (1 - _lRate) * float(mean(diff.mul(diff))[0]); |
||||
} |
||||
if (negx.cols > 0) |
||||
{ |
||||
_mu0 = (_lRate * _mu0 + (1 - _lRate) * negmu); |
||||
cv::Mat diff = negx.col(_ind) - _mu0; |
||||
_sig0 = _lRate * _sig0 + (1 - _lRate) * float(mean(diff.mul(diff))[0]); |
||||
} |
||||
|
||||
_q = (_mu1 - _mu0) / 2; |
||||
_s = sign(_mu1 - _mu0); |
||||
_log_n0 = std::log(float(1.0f / pow(_sig0, 0.5f))); |
||||
_log_n1 = std::log(float(1.0f / pow(_sig1, 0.5f))); |
||||
//_e1 = -1.0f/(2.0f*_sig1+1e-99f);
|
||||
//_e0 = -1.0f/(2.0f*_sig0+1e-99f);
|
||||
_e1 = -1.0f / (2.0f * _sig1 + std::numeric_limits<float>::min()); |
||||
_e0 = -1.0f / (2.0f * _sig0 + std::numeric_limits<float>::min()); |
||||
} |
||||
else |
||||
{ |
||||
_trained = true; |
||||
if (posx.cols > 0) |
||||
{ |
||||
_mu1 = posmu; |
||||
cv::Scalar scal_mean, scal_std_dev; |
||||
cv::meanStdDev(posx.col(_ind), scal_mean, scal_std_dev); |
||||
_sig1 = float(scal_std_dev[0]) * float(scal_std_dev[0]) + 1e-9f; |
||||
} |
||||
|
||||
if (negx.cols > 0) |
||||
{ |
||||
_mu0 = negmu; |
||||
cv::Scalar scal_mean, scal_std_dev; |
||||
cv::meanStdDev(negx.col(_ind), scal_mean, scal_std_dev); |
||||
_sig0 = float(scal_std_dev[0]) * float(scal_std_dev[0]) + 1e-9f; |
||||
} |
||||
|
||||
_q = (_mu1 - _mu0) / 2; |
||||
_s = sign(_mu1 - _mu0); |
||||
_log_n0 = std::log(float(1.0f / pow(_sig0, 0.5f))); |
||||
_log_n1 = std::log(float(1.0f / pow(_sig1, 0.5f))); |
||||
//_e1 = -1.0f/(2.0f*_sig1+1e-99f);
|
||||
//_e0 = -1.0f/(2.0f*_sig0+1e-99f);
|
||||
_e1 = -1.0f / (2.0f * _sig1 + std::numeric_limits<float>::min()); |
||||
_e0 = -1.0f / (2.0f * _sig0 + std::numeric_limits<float>::min()); |
||||
} |
||||
} |
||||
|
||||
bool ClfOnlineStump::classify(const Mat& x, int i) |
||||
{ |
||||
float xx = x.at<float>(i, _ind); |
||||
double log_p0 = (xx - _mu0) * (xx - _mu0) * _e0 + _log_n0; |
||||
double log_p1 = (xx - _mu1) * (xx - _mu1) * _e1 + _log_n1; |
||||
return log_p1 > log_p0; |
||||
} |
||||
|
||||
float ClfOnlineStump::classifyF(const Mat& x, int i) |
||||
{ |
||||
float xx = x.at<float>(i, _ind); |
||||
double log_p0 = (xx - _mu0) * (xx - _mu0) * _e0 + _log_n0; |
||||
double log_p1 = (xx - _mu1) * (xx - _mu1) * _e1 + _log_n1; |
||||
return float(log_p1 - log_p0); |
||||
} |
||||
|
||||
inline std::vector<float> ClfOnlineStump::classifySetF(const Mat& x) |
||||
{ |
||||
std::vector<float> res(x.rows); |
||||
|
||||
#ifdef _OPENMP |
||||
#pragma omp parallel for |
||||
#endif |
||||
for (int k = 0; k < (int)res.size(); k++) |
||||
{ |
||||
res[k] = classifyF(x, k); |
||||
} |
||||
return res; |
||||
} |
||||
|
||||
}}} // namespace cv::detail::tracking
|
@ -0,0 +1,79 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#ifndef OPENCV_VIDEO_DETAIL_TRACKING_ONLINE_MIL_HPP |
||||
#define OPENCV_VIDEO_DETAIL_TRACKING_ONLINE_MIL_HPP |
||||
|
||||
#include <limits> |
||||
|
||||
namespace cv { |
||||
namespace detail { |
||||
inline namespace tracking { |
||||
|
||||
//! @addtogroup tracking_detail
|
||||
//! @{
|
||||
|
||||
//TODO based on the original implementation
|
||||
//http://vision.ucsd.edu/~bbabenko/project_miltrack.shtml
|
||||
|
||||
class ClfOnlineStump; |
||||
|
||||
class CV_EXPORTS ClfMilBoost |
||||
{ |
||||
public: |
||||
struct CV_EXPORTS Params |
||||
{ |
||||
Params(); |
||||
int _numSel; |
||||
int _numFeat; |
||||
float _lRate; |
||||
}; |
||||
|
||||
ClfMilBoost(); |
||||
~ClfMilBoost(); |
||||
void init(const ClfMilBoost::Params& parameters = ClfMilBoost::Params()); |
||||
void update(const Mat& posx, const Mat& negx); |
||||
std::vector<float> classify(const Mat& x, bool logR = true); |
||||
|
||||
inline float sigmoid(float x) |
||||
{ |
||||
return 1.0f / (1.0f + exp(-x)); |
||||
} |
||||
|
||||
private: |
||||
uint _numsamples; |
||||
ClfMilBoost::Params _myParams; |
||||
std::vector<int> _selectors; |
||||
std::vector<ClfOnlineStump*> _weakclf; |
||||
uint _counter; |
||||
}; |
||||
|
||||
class ClfOnlineStump |
||||
{ |
||||
public: |
||||
float _mu0, _mu1, _sig0, _sig1; |
||||
float _q; |
||||
int _s; |
||||
float _log_n1, _log_n0; |
||||
float _e1, _e0; |
||||
float _lRate; |
||||
|
||||
ClfOnlineStump(); |
||||
ClfOnlineStump(int ind); |
||||
void init(); |
||||
void update(const Mat& posx, const Mat& negx, const cv::Mat_<float>& posw = cv::Mat_<float>(), const cv::Mat_<float>& negw = cv::Mat_<float>()); |
||||
bool classify(const Mat& x, int i); |
||||
float classifyF(const Mat& x, int i); |
||||
std::vector<float> classifySetF(const Mat& x); |
||||
|
||||
private: |
||||
bool _trained; |
||||
int _ind; |
||||
}; |
||||
|
||||
//! @}
|
||||
|
||||
}}} // namespace cv::detail::tracking
|
||||
|
||||
#endif |
@ -0,0 +1,19 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
|
||||
namespace cv { |
||||
|
||||
Tracker::Tracker() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
Tracker::~Tracker() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
} // namespace cv
|
@ -0,0 +1,140 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
|
||||
#ifdef HAVE_OPENCV_DNN |
||||
#include "opencv2/dnn.hpp" |
||||
#endif |
||||
|
||||
namespace cv { |
||||
|
||||
TrackerGOTURN::TrackerGOTURN() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
TrackerGOTURN::~TrackerGOTURN() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
TrackerGOTURN::Params::Params() |
||||
{ |
||||
modelTxt = "goturn.prototxt"; |
||||
modelBin = "goturn.caffemodel"; |
||||
} |
||||
|
||||
#ifdef HAVE_OPENCV_DNN |
||||
|
||||
class TrackerGOTURNImpl : public TrackerGOTURN |
||||
{ |
||||
public: |
||||
TrackerGOTURNImpl(const TrackerGOTURN::Params& parameters) |
||||
: params(parameters) |
||||
{ |
||||
// Load GOTURN architecture from *.prototxt and pretrained weights from *.caffemodel
|
||||
net = dnn::readNetFromCaffe(params.modelTxt, params.modelBin); |
||||
CV_Assert(!net.empty()); |
||||
} |
||||
|
||||
void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE; |
||||
bool update(InputArray image, Rect& boundingBox) CV_OVERRIDE; |
||||
|
||||
void setBoudingBox(Rect boundingBox) |
||||
{ |
||||
if (image_.empty()) |
||||
CV_Error(Error::StsInternal, "Set image first"); |
||||
boundingBox_ = boundingBox & Rect(Point(0, 0), image_.size()); |
||||
} |
||||
|
||||
TrackerGOTURN::Params params; |
||||
|
||||
dnn::Net net; |
||||
Rect boundingBox_; |
||||
Mat image_; |
||||
}; |
||||
|
||||
void TrackerGOTURNImpl::init(InputArray image, const Rect& boundingBox) |
||||
{ |
||||
image_ = image.getMat().clone(); |
||||
setBoudingBox(boundingBox); |
||||
} |
||||
|
||||
bool TrackerGOTURNImpl::update(InputArray image, Rect& boundingBox) |
||||
{ |
||||
int INPUT_SIZE = 227; |
||||
//Using prevFrame & prevBB from model and curFrame GOTURN calculating curBB
|
||||
InputArray curFrame = image; |
||||
Mat prevFrame = image_; |
||||
Rect2d prevBB = boundingBox_; |
||||
Rect curBB; |
||||
|
||||
float padTargetPatch = 2.0; |
||||
Rect2f searchPatchRect, targetPatchRect; |
||||
Point2f currCenter, prevCenter; |
||||
Mat prevFramePadded, curFramePadded; |
||||
Mat searchPatch, targetPatch; |
||||
|
||||
prevCenter.x = (float)(prevBB.x + prevBB.width / 2); |
||||
prevCenter.y = (float)(prevBB.y + prevBB.height / 2); |
||||
|
||||
targetPatchRect.width = (float)(prevBB.width * padTargetPatch); |
||||
targetPatchRect.height = (float)(prevBB.height * padTargetPatch); |
||||
targetPatchRect.x = (float)(prevCenter.x - prevBB.width * padTargetPatch / 2.0 + targetPatchRect.width); |
||||
targetPatchRect.y = (float)(prevCenter.y - prevBB.height * padTargetPatch / 2.0 + targetPatchRect.height); |
||||
|
||||
targetPatchRect.width = std::min(targetPatchRect.width, (float)prevFrame.cols); |
||||
targetPatchRect.height = std::min(targetPatchRect.height, (float)prevFrame.rows); |
||||
targetPatchRect.x = std::max(-prevFrame.cols * 0.5f, std::min(targetPatchRect.x, prevFrame.cols * 1.5f)); |
||||
targetPatchRect.y = std::max(-prevFrame.rows * 0.5f, std::min(targetPatchRect.y, prevFrame.rows * 1.5f)); |
||||
|
||||
copyMakeBorder(prevFrame, prevFramePadded, (int)targetPatchRect.height, (int)targetPatchRect.height, (int)targetPatchRect.width, (int)targetPatchRect.width, BORDER_REPLICATE); |
||||
targetPatch = prevFramePadded(targetPatchRect).clone(); |
||||
|
||||
copyMakeBorder(curFrame, curFramePadded, (int)targetPatchRect.height, (int)targetPatchRect.height, (int)targetPatchRect.width, (int)targetPatchRect.width, BORDER_REPLICATE); |
||||
searchPatch = curFramePadded(targetPatchRect).clone(); |
||||
|
||||
// Preprocess
|
||||
// Resize
|
||||
resize(targetPatch, targetPatch, Size(INPUT_SIZE, INPUT_SIZE), 0, 0, INTER_LINEAR_EXACT); |
||||
resize(searchPatch, searchPatch, Size(INPUT_SIZE, INPUT_SIZE), 0, 0, INTER_LINEAR_EXACT); |
||||
|
||||
// Convert to Float type and subtract mean
|
||||
Mat targetBlob = dnn::blobFromImage(targetPatch, 1.0f, Size(), Scalar::all(128), false); |
||||
Mat searchBlob = dnn::blobFromImage(searchPatch, 1.0f, Size(), Scalar::all(128), false); |
||||
|
||||
net.setInput(targetBlob, "data1"); |
||||
net.setInput(searchBlob, "data2"); |
||||
|
||||
Mat resMat = net.forward("scale").reshape(1, 1); |
||||
|
||||
curBB.x = cvRound(targetPatchRect.x + (resMat.at<float>(0) * targetPatchRect.width / INPUT_SIZE) - targetPatchRect.width); |
||||
curBB.y = cvRound(targetPatchRect.y + (resMat.at<float>(1) * targetPatchRect.height / INPUT_SIZE) - targetPatchRect.height); |
||||
curBB.width = cvRound((resMat.at<float>(2) - resMat.at<float>(0)) * targetPatchRect.width / INPUT_SIZE); |
||||
curBB.height = cvRound((resMat.at<float>(3) - resMat.at<float>(1)) * targetPatchRect.height / INPUT_SIZE); |
||||
|
||||
// Predicted BB
|
||||
boundingBox = curBB & Rect(Point(0, 0), image_.size()); |
||||
|
||||
// Set new model image and BB from current frame
|
||||
image_ = image.getMat().clone(); |
||||
setBoudingBox(curBB); |
||||
return true; |
||||
} |
||||
|
||||
Ptr<TrackerGOTURN> TrackerGOTURN::create(const TrackerGOTURN::Params& parameters) |
||||
{ |
||||
return makePtr<TrackerGOTURNImpl>(parameters); |
||||
} |
||||
|
||||
#else // OPENCV_HAVE_DNN
|
||||
Ptr<TrackerGOTURN> TrackerGOTURN::create(const TrackerGOTURN::Params& parameters) |
||||
{ |
||||
(void)(parameters); |
||||
CV_Error(cv::Error::StsNotImplemented, "to use GOTURN, the tracking module needs to be built with opencv_dnn !"); |
||||
} |
||||
#endif // OPENCV_HAVE_DNN
|
||||
|
||||
} // namespace cv
|
@ -0,0 +1,227 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "../precomp.hpp" |
||||
#include "detail/tracker_mil_model.hpp" |
||||
|
||||
#include "detail/tracker_feature_haar.impl.hpp" |
||||
|
||||
namespace cv { |
||||
inline namespace tracking { |
||||
namespace impl { |
||||
|
||||
using cv::detail::tracking::internal::TrackerFeatureHAAR; |
||||
|
||||
|
||||
class TrackerMILImpl CV_FINAL : public TrackerMIL |
||||
{ |
||||
public: |
||||
TrackerMILImpl(const TrackerMIL::Params& parameters); |
||||
|
||||
virtual void init(InputArray image, const Rect& boundingBox) CV_OVERRIDE; |
||||
virtual bool update(InputArray image, Rect& boundingBox) CV_OVERRIDE; |
||||
|
||||
void compute_integral(const Mat& img, Mat& ii_img); |
||||
|
||||
TrackerMIL::Params params; |
||||
|
||||
Ptr<TrackerMILModel> model; |
||||
Ptr<TrackerSampler> sampler; |
||||
Ptr<TrackerFeatureSet> featureSet; |
||||
}; |
||||
|
||||
TrackerMILImpl::TrackerMILImpl(const TrackerMIL::Params& parameters) |
||||
: params(parameters) |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
void TrackerMILImpl::compute_integral(const Mat& img, Mat& ii_img) |
||||
{ |
||||
Mat ii; |
||||
std::vector<Mat> ii_imgs; |
||||
integral(img, ii, CV_32F); // FIXIT split first
|
||||
split(ii, ii_imgs); |
||||
ii_img = ii_imgs[0]; |
||||
} |
||||
|
||||
void TrackerMILImpl::init(InputArray image, const Rect& boundingBox) |
||||
{ |
||||
sampler = makePtr<TrackerSampler>(); |
||||
featureSet = makePtr<TrackerFeatureSet>(); |
||||
|
||||
Mat intImage; |
||||
compute_integral(image.getMat(), intImage); |
||||
TrackerSamplerCSC::Params CSCparameters; |
||||
CSCparameters.initInRad = params.samplerInitInRadius; |
||||
CSCparameters.searchWinSize = params.samplerSearchWinSize; |
||||
CSCparameters.initMaxNegNum = params.samplerInitMaxNegNum; |
||||
CSCparameters.trackInPosRad = params.samplerTrackInRadius; |
||||
CSCparameters.trackMaxPosNum = params.samplerTrackMaxPosNum; |
||||
CSCparameters.trackMaxNegNum = params.samplerTrackMaxNegNum; |
||||
|
||||
Ptr<TrackerSamplerAlgorithm> CSCSampler = makePtr<TrackerSamplerCSC>(CSCparameters); |
||||
CV_Assert(sampler->addTrackerSamplerAlgorithm(CSCSampler)); |
||||
|
||||
//or add CSC sampler with default parameters
|
||||
//sampler->addTrackerSamplerAlgorithm( "CSC" );
|
||||
|
||||
//Positive sampling
|
||||
CSCSampler.staticCast<TrackerSamplerCSC>()->setMode(TrackerSamplerCSC::MODE_INIT_POS); |
||||
sampler->sampling(intImage, boundingBox); |
||||
std::vector<Mat> posSamples = sampler->getSamples(); |
||||
|
||||
//Negative sampling
|
||||
CSCSampler.staticCast<TrackerSamplerCSC>()->setMode(TrackerSamplerCSC::MODE_INIT_NEG); |
||||
sampler->sampling(intImage, boundingBox); |
||||
std::vector<Mat> negSamples = sampler->getSamples(); |
||||
|
||||
CV_Assert(!posSamples.empty()); |
||||
CV_Assert(!negSamples.empty()); |
||||
|
||||
//compute HAAR features
|
||||
TrackerFeatureHAAR::Params HAARparameters; |
||||
HAARparameters.numFeatures = params.featureSetNumFeatures; |
||||
HAARparameters.rectSize = Size((int)boundingBox.width, (int)boundingBox.height); |
||||
HAARparameters.isIntegral = true; |
||||
Ptr<TrackerFeature> trackerFeature = makePtr<TrackerFeatureHAAR>(HAARparameters); |
||||
featureSet->addTrackerFeature(trackerFeature); |
||||
|
||||
featureSet->extraction(posSamples); |
||||
const std::vector<Mat> posResponse = featureSet->getResponses(); |
||||
|
||||
featureSet->extraction(negSamples); |
||||
const std::vector<Mat> negResponse = featureSet->getResponses(); |
||||
|
||||
model = makePtr<TrackerMILModel>(boundingBox); |
||||
Ptr<TrackerStateEstimatorMILBoosting> stateEstimator = makePtr<TrackerStateEstimatorMILBoosting>(params.featureSetNumFeatures); |
||||
model->setTrackerStateEstimator(stateEstimator); |
||||
|
||||
//Run model estimation and update
|
||||
model.staticCast<TrackerMILModel>()->setMode(TrackerMILModel::MODE_POSITIVE, posSamples); |
||||
model->modelEstimation(posResponse); |
||||
model.staticCast<TrackerMILModel>()->setMode(TrackerMILModel::MODE_NEGATIVE, negSamples); |
||||
model->modelEstimation(negResponse); |
||||
model->modelUpdate(); |
||||
} |
||||
|
||||
bool TrackerMILImpl::update(InputArray image, Rect& boundingBox) |
||||
{ |
||||
Mat intImage; |
||||
compute_integral(image.getMat(), intImage); |
||||
|
||||
//get the last location [AAM] X(k-1)
|
||||
Ptr<TrackerTargetState> lastLocation = model->getLastTargetState(); |
||||
Rect lastBoundingBox((int)lastLocation->getTargetPosition().x, (int)lastLocation->getTargetPosition().y, lastLocation->getTargetWidth(), |
||||
lastLocation->getTargetHeight()); |
||||
|
||||
//sampling new frame based on last location
|
||||
auto& samplers = sampler->getSamplers(); |
||||
CV_Assert(!samplers.empty()); |
||||
CV_Assert(samplers[0]); |
||||
samplers[0].staticCast<TrackerSamplerCSC>()->setMode(TrackerSamplerCSC::MODE_DETECT); |
||||
sampler->sampling(intImage, lastBoundingBox); |
||||
std::vector<Mat> detectSamples = sampler->getSamples(); |
||||
if (detectSamples.empty()) |
||||
return false; |
||||
|
||||
/*//TODO debug samples
|
||||
Mat f; |
||||
image.copyTo(f); |
||||
|
||||
for( size_t i = 0; i < detectSamples.size(); i=i+10 ) |
||||
{ |
||||
Size sz; |
||||
Point off; |
||||
detectSamples.at(i).locateROI(sz, off); |
||||
rectangle(f, Rect(off.x,off.y,detectSamples.at(i).cols,detectSamples.at(i).rows), Scalar(255,0,0), 1); |
||||
}*/ |
||||
|
||||
//extract features from new samples
|
||||
featureSet->extraction(detectSamples); |
||||
std::vector<Mat> response = featureSet->getResponses(); |
||||
|
||||
//predict new location
|
||||
ConfidenceMap cmap; |
||||
model.staticCast<TrackerMILModel>()->setMode(TrackerMILModel::MODE_ESTIMATON, detectSamples); |
||||
model.staticCast<TrackerMILModel>()->responseToConfidenceMap(response, cmap); |
||||
model->getTrackerStateEstimator().staticCast<TrackerStateEstimatorMILBoosting>()->setCurrentConfidenceMap(cmap); |
||||
|
||||
if (!model->runStateEstimator()) |
||||
{ |
||||
return false; |
||||
} |
||||
|
||||
Ptr<TrackerTargetState> currentState = model->getLastTargetState(); |
||||
boundingBox = Rect((int)currentState->getTargetPosition().x, (int)currentState->getTargetPosition().y, currentState->getTargetWidth(), |
||||
currentState->getTargetHeight()); |
||||
|
||||
/*//TODO debug
|
||||
rectangle(f, lastBoundingBox, Scalar(0,255,0), 1); |
||||
rectangle(f, boundingBox, Scalar(0,0,255), 1); |
||||
imshow("f", f); |
||||
//waitKey( 0 );*/
|
||||
|
||||
//sampling new frame based on new location
|
||||
//Positive sampling
|
||||
samplers[0].staticCast<TrackerSamplerCSC>()->setMode(TrackerSamplerCSC::MODE_INIT_POS); |
||||
sampler->sampling(intImage, boundingBox); |
||||
std::vector<Mat> posSamples = sampler->getSamples(); |
||||
|
||||
//Negative sampling
|
||||
samplers[0].staticCast<TrackerSamplerCSC>()->setMode(TrackerSamplerCSC::MODE_INIT_NEG); |
||||
sampler->sampling(intImage, boundingBox); |
||||
std::vector<Mat> negSamples = sampler->getSamples(); |
||||
|
||||
if (posSamples.empty() || negSamples.empty()) |
||||
return false; |
||||
|
||||
//extract features
|
||||
featureSet->extraction(posSamples); |
||||
std::vector<Mat> posResponse = featureSet->getResponses(); |
||||
|
||||
featureSet->extraction(negSamples); |
||||
std::vector<Mat> negResponse = featureSet->getResponses(); |
||||
|
||||
//model estimate
|
||||
model.staticCast<TrackerMILModel>()->setMode(TrackerMILModel::MODE_POSITIVE, posSamples); |
||||
model->modelEstimation(posResponse); |
||||
model.staticCast<TrackerMILModel>()->setMode(TrackerMILModel::MODE_NEGATIVE, negSamples); |
||||
model->modelEstimation(negResponse); |
||||
|
||||
//model update
|
||||
model->modelUpdate(); |
||||
|
||||
return true; |
||||
} |
||||
|
||||
}} // namespace tracking::impl
|
||||
|
||||
TrackerMIL::Params::Params() |
||||
{ |
||||
samplerInitInRadius = 3; |
||||
samplerSearchWinSize = 25; |
||||
samplerInitMaxNegNum = 65; |
||||
samplerTrackInRadius = 4; |
||||
samplerTrackMaxPosNum = 100000; |
||||
samplerTrackMaxNegNum = 65; |
||||
featureSetNumFeatures = 250; |
||||
} |
||||
|
||||
TrackerMIL::TrackerMIL() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
TrackerMIL::~TrackerMIL() |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
Ptr<TrackerMIL> TrackerMIL::create(const TrackerMIL::Params& parameters) |
||||
{ |
||||
return makePtr<tracking::impl::TrackerMILImpl>(parameters); |
||||
} |
||||
|
||||
} // namespace cv
|
@ -0,0 +1,97 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
#include "test_precomp.hpp" |
||||
|
||||
//#define DEBUG_TEST
|
||||
#ifdef DEBUG_TEST |
||||
#include <opencv2/highgui.hpp> |
||||
#endif |
||||
|
||||
namespace opencv_test { namespace { |
||||
//using namespace cv::tracking;
|
||||
|
||||
#define TESTSET_NAMES testing::Values("david", "dudek", "faceocc2") |
||||
|
||||
const string TRACKING_DIR = "tracking"; |
||||
const string FOLDER_IMG = "data"; |
||||
const string FOLDER_OMIT_INIT = "initOmit"; |
||||
|
||||
#include "test_trackers.impl.hpp" |
||||
|
||||
//[TESTDATA]
|
||||
PARAM_TEST_CASE(DistanceAndOverlap, string) |
||||
{ |
||||
string dataset; |
||||
virtual void SetUp() |
||||
{ |
||||
dataset = GET_PARAM(0); |
||||
} |
||||
}; |
||||
|
||||
TEST_P(DistanceAndOverlap, MIL) |
||||
{ |
||||
TrackerTest<Tracker, Rect> test(TrackerMIL::create(), dataset, 30, .65f, NoTransform); |
||||
test.run(); |
||||
} |
||||
|
||||
TEST_P(DistanceAndOverlap, Shifted_Data_MIL) |
||||
{ |
||||
TrackerTest<Tracker, Rect> test(TrackerMIL::create(), dataset, 30, .6f, CenterShiftLeft); |
||||
test.run(); |
||||
} |
||||
|
||||
/***************************************************************************************/ |
||||
//Tests with scaled initial window
|
||||
|
||||
TEST_P(DistanceAndOverlap, Scaled_Data_MIL) |
||||
{ |
||||
TrackerTest<Tracker, Rect> test(TrackerMIL::create(), dataset, 30, .7f, Scale_1_1); |
||||
test.run(); |
||||
} |
||||
|
||||
TEST_P(DistanceAndOverlap, GOTURN) |
||||
{ |
||||
std::string model = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.prototxt"); |
||||
std::string weights = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.caffemodel", false); |
||||
cv::TrackerGOTURN::Params params; |
||||
params.modelTxt = model; |
||||
params.modelBin = weights; |
||||
TrackerTest<Tracker, Rect> test(TrackerGOTURN::create(params), dataset, 35, .35f, NoTransform); |
||||
test.run(); |
||||
} |
||||
|
||||
INSTANTIATE_TEST_CASE_P(Tracking, DistanceAndOverlap, TESTSET_NAMES); |
||||
|
||||
TEST(GOTURN, memory_usage) |
||||
{ |
||||
cv::Rect roi(145, 70, 85, 85); |
||||
|
||||
std::string model = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.prototxt"); |
||||
std::string weights = cvtest::findDataFile("dnn/gsoc2016-goturn/goturn.caffemodel", false); |
||||
cv::TrackerGOTURN::Params params; |
||||
params.modelTxt = model; |
||||
params.modelBin = weights; |
||||
cv::Ptr<Tracker> tracker = TrackerGOTURN::create(params); |
||||
|
||||
string inputVideo = cvtest::findDataFile("tracking/david/data/david.webm"); |
||||
cv::VideoCapture video(inputVideo); |
||||
ASSERT_TRUE(video.isOpened()) << inputVideo; |
||||
|
||||
cv::Mat frame; |
||||
video >> frame; |
||||
ASSERT_FALSE(frame.empty()) << inputVideo; |
||||
tracker->init(frame, roi); |
||||
string ground_truth_bb; |
||||
for (int nframes = 0; nframes < 15; ++nframes) |
||||
{ |
||||
std::cout << "Frame: " << nframes << std::endl; |
||||
video >> frame; |
||||
bool res = tracker->update(frame, roi); |
||||
ASSERT_TRUE(res); |
||||
std::cout << "Predicted ROI: " << roi << std::endl; |
||||
} |
||||
} |
||||
|
||||
}} // namespace opencv_test::
|
@ -0,0 +1,368 @@ |
||||
// This file is part of OpenCV project.
|
||||
// It is subject to the license terms in the LICENSE file found in the top-level directory
|
||||
// of this distribution and at http://opencv.org/license.html.
|
||||
|
||||
/*
|
||||
* The Evaluation Methodologies are partially based on: |
||||
* ==================================================================================================================== |
||||
* [OTB] Y. Wu, J. Lim, and M.-H. Yang, "Online object tracking: A benchmark," in Computer Vision and Pattern Recognition (CVPR), 2013 |
||||
* |
||||
*/ |
||||
|
||||
enum BBTransformations |
||||
{ |
||||
NoTransform = 0, |
||||
CenterShiftLeft = 1, |
||||
CenterShiftRight = 2, |
||||
CenterShiftUp = 3, |
||||
CenterShiftDown = 4, |
||||
CornerShiftTopLeft = 5, |
||||
CornerShiftTopRight = 6, |
||||
CornerShiftBottomLeft = 7, |
||||
CornerShiftBottomRight = 8, |
||||
Scale_0_8 = 9, |
||||
Scale_0_9 = 10, |
||||
Scale_1_1 = 11, |
||||
Scale_1_2 = 12 |
||||
}; |
||||
|
||||
namespace { |
||||
|
||||
std::vector<std::string> splitString(const std::string& s_, const std::string& delimiter) |
||||
{ |
||||
std::string s = s_; |
||||
std::vector<string> token; |
||||
size_t pos = 0; |
||||
while ((pos = s.find(delimiter)) != std::string::npos) |
||||
{ |
||||
token.push_back(s.substr(0, pos)); |
||||
s.erase(0, pos + delimiter.length()); |
||||
} |
||||
token.push_back(s); |
||||
return token; |
||||
} |
||||
|
||||
float calcDistance(const Rect& a, const Rect& b) |
||||
{ |
||||
Point2f p_a((float)(a.x + a.width / 2), (float)(a.y + a.height / 2)); |
||||
Point2f p_b((float)(b.x + b.width / 2), (float)(b.y + b.height / 2)); |
||||
return sqrt(pow(p_a.x - p_b.x, 2) + pow(p_a.y - p_b.y, 2)); |
||||
} |
||||
|
||||
float calcOverlap(const Rect& a, const Rect& b) |
||||
{ |
||||
float rectIntersectionArea = (float)(a & b).area(); |
||||
return rectIntersectionArea / (a.area() + b.area() - rectIntersectionArea); |
||||
} |
||||
|
||||
} // namespace
|
||||
|
||||
template <typename Tracker, typename ROI_t = Rect2d> |
||||
class TrackerTest |
||||
{ |
||||
public: |
||||
TrackerTest(const Ptr<Tracker>& tracker, const string& video, float distanceThreshold, |
||||
float overlapThreshold, int shift = NoTransform, int segmentIdx = 1, int numSegments = 10); |
||||
~TrackerTest() {} |
||||
void run(); |
||||
|
||||
protected: |
||||
void checkDataTest(); |
||||
|
||||
void distanceAndOverlapTest(); |
||||
|
||||
Ptr<Tracker> tracker; |
||||
string video; |
||||
std::vector<Rect> bbs; |
||||
int startFrame; |
||||
string suffix; |
||||
string prefix; |
||||
float overlapThreshold; |
||||
float distanceThreshold; |
||||
int segmentIdx; |
||||
int shift; |
||||
int numSegments; |
||||
|
||||
int gtStartFrame; |
||||
int endFrame; |
||||
vector<int> validSequence; |
||||
|
||||
private: |
||||
Rect applyShift(const Rect& bb); |
||||
}; |
||||
|
||||
template <typename Tracker, typename ROI_t> |
||||
TrackerTest<Tracker, ROI_t>::TrackerTest(const Ptr<Tracker>& _tracker, const string& _video, float _distanceThreshold, |
||||
float _overlapThreshold, int _shift, int _segmentIdx, int _numSegments) |
||||
: tracker(_tracker) |
||||
, video(_video) |
||||
, overlapThreshold(_overlapThreshold) |
||||
, distanceThreshold(_distanceThreshold) |
||||
, segmentIdx(_segmentIdx) |
||||
, shift(_shift) |
||||
, numSegments(_numSegments) |
||||
{ |
||||
// nothing
|
||||
} |
||||
|
||||
template <typename Tracker, typename ROI_t> |
||||
Rect TrackerTest<Tracker, ROI_t>::applyShift(const Rect& bb_) |
||||
{ |
||||
Rect bb = bb_; |
||||
Point center(bb.x + (bb.width / 2), bb.y + (bb.height / 2)); |
||||
|
||||
int xLimit = bb.x + bb.width - 1; |
||||
int yLimit = bb.y + bb.height - 1; |
||||
|
||||
int h = 0; |
||||
int w = 0; |
||||
float ratio = 1.0; |
||||
|
||||
switch (shift) |
||||
{ |
||||
case CenterShiftLeft: |
||||
bb.x = bb.x - (int)ceil(0.1 * bb.width); |
||||
break; |
||||
case CenterShiftRight: |
||||
bb.x = bb.x + (int)ceil(0.1 * bb.width); |
||||
break; |
||||
case CenterShiftUp: |
||||
bb.y = bb.y - (int)ceil(0.1 * bb.height); |
||||
break; |
||||
case CenterShiftDown: |
||||
bb.y = bb.y + (int)ceil(0.1 * bb.height); |
||||
break; |
||||
case CornerShiftTopLeft: |
||||
bb.x = (int)cvRound(bb.x - 0.1 * bb.width); |
||||
bb.y = (int)cvRound(bb.y - 0.1 * bb.height); |
||||
|
||||
bb.width = xLimit - bb.x + 1; |
||||
bb.height = yLimit - bb.y + 1; |
||||
break; |
||||
case CornerShiftTopRight: |
||||
xLimit = (int)cvRound(xLimit + 0.1 * bb.width); |
||||
|
||||
bb.y = (int)cvRound(bb.y - 0.1 * bb.height); |
||||
bb.width = xLimit - bb.x + 1; |
||||
bb.height = yLimit - bb.y + 1; |
||||
break; |
||||
case CornerShiftBottomLeft: |
||||
bb.x = (int)cvRound(bb.x - 0.1 * bb.width); |
||||
yLimit = (int)cvRound(yLimit + 0.1 * bb.height); |
||||
|
||||
bb.width = xLimit - bb.x + 1; |
||||
bb.height = yLimit - bb.y + 1; |
||||
break; |
||||
case CornerShiftBottomRight: |
||||
xLimit = (int)cvRound(xLimit + 0.1 * bb.width); |
||||
yLimit = (int)cvRound(yLimit + 0.1 * bb.height); |
||||
|
||||
bb.width = xLimit - bb.x + 1; |
||||
bb.height = yLimit - bb.y + 1; |
||||
break; |
||||
case Scale_0_8: |
||||
ratio = 0.8f; |
||||
w = (int)(ratio * bb.width); |
||||
h = (int)(ratio * bb.height); |
||||
|
||||
bb = Rect(center.x - (w / 2), center.y - (h / 2), w, h); |
||||
break; |
||||
case Scale_0_9: |
||||
ratio = 0.9f; |
||||
w = (int)(ratio * bb.width); |
||||
h = (int)(ratio * bb.height); |
||||
|
||||
bb = Rect(center.x - (w / 2), center.y - (h / 2), w, h); |
||||
break; |
||||
case 11: |
||||
//scale 1.1
|
||||
ratio = 1.1f; |
||||
w = (int)(ratio * bb.width); |
||||
h = (int)(ratio * bb.height); |
||||
|
||||
bb = Rect(center.x - (w / 2), center.y - (h / 2), w, h); |
||||
break; |
||||
case 12: |
||||
//scale 1.2
|
||||
ratio = 1.2f; |
||||
w = (int)(ratio * bb.width); |
||||
h = (int)(ratio * bb.height); |
||||
|
||||
bb = Rect(center.x - (w / 2), center.y - (h / 2), w, h); |
||||
break; |
||||
default: |
||||
break; |
||||
} |
||||
|
||||
return bb; |
||||
} |
||||
|
||||
template <typename Tracker, typename ROI_t> |
||||
void TrackerTest<Tracker, ROI_t>::distanceAndOverlapTest() |
||||
{ |
||||
bool initialized = false; |
||||
|
||||
int fc = (startFrame - gtStartFrame); |
||||
|
||||
bbs.at(fc) = applyShift(bbs.at(fc)); |
||||
Rect currentBBi = bbs.at(fc); |
||||
ROI_t currentBB(currentBBi); |
||||
float sumDistance = 0; |
||||
float sumOverlap = 0; |
||||
|
||||
string folder = cvtest::TS::ptr()->get_data_path() + "/" + TRACKING_DIR + "/" + video + "/" + FOLDER_IMG; |
||||
string videoPath = folder + "/" + video + ".webm"; |
||||
|
||||
VideoCapture c; |
||||
c.open(videoPath); |
||||
if (!c.isOpened()) |
||||
throw SkipTestException("Can't open video file"); |
||||
#if 0 |
||||
c.set(CAP_PROP_POS_FRAMES, startFrame); |
||||
#else |
||||
if (startFrame) |
||||
std::cout << "startFrame = " << startFrame << std::endl; |
||||
for (int i = 0; i < startFrame; i++) |
||||
{ |
||||
Mat dummy_frame; |
||||
c >> dummy_frame; |
||||
ASSERT_FALSE(dummy_frame.empty()) << i << ": " << videoPath; |
||||
} |
||||
#endif |
||||
|
||||
for (int frameCounter = startFrame; frameCounter < endFrame; frameCounter++) |
||||
{ |
||||
Mat frame; |
||||
c >> frame; |
||||
|
||||
ASSERT_FALSE(frame.empty()) << "frameCounter=" << frameCounter << " video=" << videoPath; |
||||
if (!initialized) |
||||
{ |
||||
tracker->init(frame, currentBB); |
||||
std::cout << "frame size = " << frame.size() << std::endl; |
||||
initialized = true; |
||||
} |
||||
else if (initialized) |
||||
{ |
||||
if (frameCounter >= (int)bbs.size()) |
||||
break; |
||||
tracker->update(frame, currentBB); |
||||
} |
||||
float curDistance = calcDistance(currentBB, bbs.at(fc)); |
||||
float curOverlap = calcOverlap(currentBB, bbs.at(fc)); |
||||
|
||||
#ifdef DEBUG_TEST |
||||
Mat result; |
||||
repeat(frame, 1, 2, result); |
||||
rectangle(result, currentBB, Scalar(0, 255, 0), 1); |
||||
Rect roi2(frame.cols, 0, frame.cols, frame.rows); |
||||
rectangle(result(roi2), bbs.at(fc), Scalar(0, 0, 255), 1); |
||||
imshow("result", result); |
||||
waitKey(1); |
||||
#endif |
||||
|
||||
sumDistance += curDistance; |
||||
sumOverlap += curOverlap; |
||||
fc++; |
||||
} |
||||
|
||||
float meanDistance = sumDistance / (endFrame - startFrame); |
||||
float meanOverlap = sumOverlap / (endFrame - startFrame); |
||||
|
||||
EXPECT_LE(meanDistance, distanceThreshold); |
||||
EXPECT_GE(meanOverlap, overlapThreshold); |
||||
} |
||||
|
||||
template <typename Tracker, typename ROI_t> |
||||
void TrackerTest<Tracker, ROI_t>::checkDataTest() |
||||
{ |
||||
|
||||
FileStorage fs; |
||||
fs.open(cvtest::TS::ptr()->get_data_path() + TRACKING_DIR + "/" + video + "/" + video + ".yml", FileStorage::READ); |
||||
fs["start"] >> startFrame; |
||||
fs["prefix"] >> prefix; |
||||
fs["suffix"] >> suffix; |
||||
fs.release(); |
||||
|
||||
string gtFile = cvtest::TS::ptr()->get_data_path() + TRACKING_DIR + "/" + video + "/gt.txt"; |
||||
std::ifstream gt; |
||||
//open the ground truth
|
||||
gt.open(gtFile.c_str()); |
||||
ASSERT_TRUE(gt.is_open()) << gtFile; |
||||
string line; |
||||
int bbCounter = 0; |
||||
while (getline(gt, line)) |
||||
{ |
||||
bbCounter++; |
||||
} |
||||
gt.close(); |
||||
|
||||
int seqLength = bbCounter; |
||||
for (int i = startFrame; i < seqLength; i++) |
||||
{ |
||||
validSequence.push_back(i); |
||||
} |
||||
|
||||
//exclude from the images sequence, the frames where the target is occluded or out of view
|
||||
string omitFile = cvtest::TS::ptr()->get_data_path() + TRACKING_DIR + "/" + video + "/" + FOLDER_OMIT_INIT + "/" + video + ".txt"; |
||||
std::ifstream omit; |
||||
omit.open(omitFile.c_str()); |
||||
if (omit.is_open()) |
||||
{ |
||||
string omitLine; |
||||
while (getline(omit, omitLine)) |
||||
{ |
||||
vector<string> tokens = splitString(omitLine, " "); |
||||
int s_start = atoi(tokens.at(0).c_str()); |
||||
int s_end = atoi(tokens.at(1).c_str()); |
||||
for (int k = s_start; k <= s_end; k++) |
||||
{ |
||||
std::vector<int>::iterator position = std::find(validSequence.begin(), validSequence.end(), k); |
||||
if (position != validSequence.end()) |
||||
validSequence.erase(position); |
||||
} |
||||
} |
||||
} |
||||
omit.close(); |
||||
gtStartFrame = startFrame; |
||||
//compute the start and the and for each segment
|
||||
int numFrame = (int)(validSequence.size() / numSegments); |
||||
startFrame += (segmentIdx - 1) * numFrame; |
||||
endFrame = startFrame + numFrame; |
||||
|
||||
std::ifstream gt2; |
||||
//open the ground truth
|
||||
gt2.open(gtFile.c_str()); |
||||
ASSERT_TRUE(gt2.is_open()) << gtFile; |
||||
string line2; |
||||
int bbCounter2 = 0; |
||||
while (getline(gt2, line2)) |
||||
{ |
||||
vector<string> tokens = splitString(line2, ","); |
||||
Rect bb(atoi(tokens.at(0).c_str()), atoi(tokens.at(1).c_str()), atoi(tokens.at(2).c_str()), atoi(tokens.at(3).c_str())); |
||||
ASSERT_EQ((size_t)4, tokens.size()) << "Incorrect ground truth file " << gtFile; |
||||
|
||||
bbs.push_back(bb); |
||||
bbCounter2++; |
||||
} |
||||
gt2.close(); |
||||
|
||||
if (segmentIdx == numSegments) |
||||
endFrame = (int)bbs.size(); |
||||
} |
||||
|
||||
template <typename Tracker, typename ROI_t> |
||||
void TrackerTest<Tracker, ROI_t>::run() |
||||
{ |
||||
srand(1); // FIXIT remove that, ensure that there is no "rand()" in implementation
|
||||
|
||||
ASSERT_TRUE(tracker); |
||||
|
||||
checkDataTest(); |
||||
|
||||
//check for failure
|
||||
if (::testing::Test::HasFatalFailure()) |
||||
return; |
||||
|
||||
distanceAndOverlapTest(); |
||||
} |
@ -0,0 +1,80 @@ |
||||
#!/usr/bin/env python |
||||
|
||||
''' |
||||
Tracker demo |
||||
|
||||
USAGE: |
||||
tracker.py [<video_source>] |
||||
''' |
||||
|
||||
# Python 2/3 compatibility |
||||
from __future__ import print_function |
||||
|
||||
import sys |
||||
|
||||
import numpy as np |
||||
import cv2 as cv |
||||
|
||||
from video import create_capture, presets |
||||
|
||||
class App(object): |
||||
|
||||
def initializeTracker(self, image): |
||||
while True: |
||||
print('==> Select object ROI for tracker ...') |
||||
bbox = cv.selectROI('tracking', image) |
||||
print('ROI: {}'.format(bbox)) |
||||
|
||||
tracker = cv.TrackerMIL_create() |
||||
try: |
||||
tracker.init(image, bbox) |
||||
except Exception as e: |
||||
print('Unable to initialize tracker with requested bounding box. Is there any object?') |
||||
print(e) |
||||
print('Try again ...') |
||||
continue |
||||
|
||||
return tracker |
||||
|
||||
def run(self): |
||||
videoPath = sys.argv[1] if len(sys.argv) >= 2 else 'vtest.avi' |
||||
camera = create_capture(videoPath, presets['cube']) |
||||
if not camera.isOpened(): |
||||
sys.exit("Can't open video stream: {}".format(videoPath)) |
||||
|
||||
ok, image = camera.read() |
||||
if not ok: |
||||
sys.exit("Can't read first frame") |
||||
assert image is not None |
||||
|
||||
cv.namedWindow('tracking') |
||||
tracker = self.initializeTracker(image) |
||||
|
||||
print("==> Tracking is started. Press 'SPACE' to re-initialize tracker or 'ESC' for exit...") |
||||
|
||||
while camera.isOpened(): |
||||
ok, image = camera.read() |
||||
if not ok: |
||||
print("Can't read frame") |
||||
break |
||||
|
||||
ok, newbox = tracker.update(image) |
||||
#print(ok, newbox) |
||||
|
||||
if ok: |
||||
cv.rectangle(image, newbox, (200,0,0)) |
||||
|
||||
cv.imshow("tracking", image) |
||||
k = cv.waitKey(1) |
||||
if k == 32: # SPACE |
||||
tracker = self.initializeTracker(image) |
||||
if k == 27: # ESC |
||||
break |
||||
|
||||
print('Done') |
||||
|
||||
|
||||
if __name__ == '__main__': |
||||
print(__doc__) |
||||
App().run() |
||||
cv.destroyAllWindows() |
Loading…
Reference in new issue