首页 > 其他 > 详细

Opencv研读笔记:haartraining程序之icvCreateCARTStageClassifier函数详解~

时间:2015-01-04 23:10:03      阅读:731      评论:0      收藏:0      [点我收藏+]

之前介绍了haartraining程序中的cvCreateMTStumpClassifier函数,这个函数的功能是计算最优弱分类器,这篇文章介绍一下自己对haartraining中关于强分类器计算的一些理解,也就是程序中的icvCreateCARTStageClassifier函数。

由于haartraining是基于HAAR特征进行adaboost训练,对于HAAR特征的处理比较繁琐,采用了奇数弱分类器补充针对翻转特征最优弱分类器计算的代码,所以代码看起来较为冗长。此外,其采用了较多的中间结构体变量,例如CvIntHaarClassifier结构体(用于模拟强分类器结构体CvStageHaarClassifier的父类),CvBoostTrainer结构体(用于初始化,更新样本权值等)等等,所以代码看起来比较绕。

强分类器的创建,其中,样本权值的更新,程序中设计了四种经典adaboost算法版本,它们是,Discrete Adaboost、Real Adaboost、Logit Boost、Gentle Adaboost。代码通过函数指针的形式(实际上这也是opencv一直常用的手段)对函数进行回调。

以上说的就是icvCreateCARTStageClassifier中值得注意的几点,下面上代码,是根据自己的理解添加的注释,请各位不吝批评指正哈!

转载请注明:http://blog.csdn.net/wsj998689aa/article/details/42398235

static
CvIntHaarClassifier* icvCreateCARTStageClassifier( CvHaarTrainingData* data,        // 全部训练样本
                                                   CvMat* sampleIdx,                // 实际训练样本序列
                                                   CvIntHaarFeatures* haarFeatures, // 全部HAAR特征
                                                   float minhitrate,    // 最小正检率(用于确定强分类器阈值)
                                                   float maxfalsealarm, // 最大误检率(用于确定是否收敛)            
                                                   int   symmetric,     // HAAR是否对称
                                                   float weightfraction,    // 样本剔除比例(用于剔除小权值样本)
                                                   int numsplits,           // 每个弱分类器特征个数(一般为1)
                                                   CvBoostType boosttype,   // adaboost类型
                                                   CvStumpError stumperror, // Discrete AdaBoost中的阈值计算方式
                                                   int maxsplits )          // 弱分类器最大个数
{

#ifdef CV_COL_ARRANGEMENT
    int flags = CV_COL_SAMPLE;
#else
    int flags = CV_ROW_SAMPLE;
#endif

    CvStageHaarClassifier* stage = NULL;                    // 强分类器
    CvBoostTrainer* trainer;                                // 临时训练器,用于更新样本权值
    CvCARTClassifier* cart = NULL;                          // 弱分类器
    CvCARTTrainParams trainParams;                          // 训练参数
    CvMTStumpTrainParams stumpTrainParams;                  // 弱分类器参数
    //CvMat* trainData = NULL;
    //CvMat* sortedIdx = NULL;
    CvMat eval;                                             // 临时矩阵
    int n = 0;                                              // 特征总数
    int m = 0;                                              // 总样本个数
    int numpos = 0;                                         // 正样本个数
    int numneg = 0;                                         // 负样本个数
    int numfalse = 0;                                       // 误检样本个数
    float sum_stage = 0.0F;                                 // 置信度累积和                              
    float threshold = 0.0F;                                 // 强分类器阈值
    float falsealarm = 0.0F;                                // 误检率
    
    //CvMat* sampleIdx = NULL;
    CvMat* trimmedIdx;                                      // 剔除小权值之后的样本序列
    //float* idxdata = NULL;
    //float* tempweights = NULL;
    //int    idxcount = 0;
    CvUserdata userdata;                                    // 训练数据

    int i = 0;
    int j = 0;
    int idx;
    int numsamples;                                         // 实际样本个数
    int numtrimmed;                                         // 剔除小权值之后的样本个数
    
    CvCARTHaarClassifier* classifier;                       // 弱分类器
    CvSeq* seq = NULL;
    CvMemStorage* storage = NULL;
    CvMat* weakTrainVals;                                   // 样本类别,只有logitboost才会用到
    float alpha;
    float sumalpha;
    int num_splits;                                         // 弱分类器个数                                    

#ifdef CV_VERBOSE
    printf( "+----+----+-+---------+---------+---------+---------+\n" );
    printf( "|  N |%%SMP|F|  ST.THR |    HR   |    FA   | EXP. ERR|\n" );
    printf( "+----+----+-+---------+---------+---------+---------+\n" );
#endif /* CV_VERBOSE */
    
    n = haarFeatures->count;
    m = data->sum.rows;
    numsamples = (sampleIdx) ? MAX( sampleIdx->rows, sampleIdx->cols ) : m;

    userdata = cvUserdata( data, haarFeatures );

    /* 弱分类参数设置 */
    stumpTrainParams.type = ( boosttype == CV_DABCLASS )
        ? CV_CLASSIFICATION_CLASS : CV_REGRESSION;                              // 分类或者回归
    stumpTrainParams.error = ( boosttype == CV_LBCLASS || boosttype == CV_GABCLASS )
        ? CV_SQUARE : stumperror;
    stumpTrainParams.portion = CV_STUMP_TRAIN_PORTION;                          // 每组特征个数
    stumpTrainParams.getTrainData = icvGetTrainingDataCallback;                 // 计算样本的haar值
    stumpTrainParams.numcomp = n;                                               // 特征个数            
    stumpTrainParams.userdata = &userdata; 
    stumpTrainParams.sortedIdx = data->idxcache;                                // 特征-样本序号矩阵(排序之后)

    trainParams.count = numsplits;
    trainParams.stumpTrainParams = (CvClassifierTrainParams*) &stumpTrainParams;
    trainParams.stumpConstructor = cvCreateMTStumpClassifier;                   // 筛选最优弱分类器
    trainParams.splitIdx = icvSplitIndicesCallback;                             // 没用到过
    trainParams.userdata = &userdata;                                           

    // 临时向量,用于存放样本haar特征值
    eval = cvMat( 1, m, CV_32FC1, cvAlloc( sizeof( float ) * m ) );
    
    storage = cvCreateMemStorage();

    // 最优弱分类器存储序列
    seq = cvCreateSeq( 0, sizeof( *seq ), sizeof( classifier ), storage );

    // 样本类别,只有logitboost才会用到
    weakTrainVals = cvCreateMat( 1, m, CV_32FC1 );

    // 初始化样本类别与权重,weakTrainVals为{-1, 1},权重都一样
    trainer = cvBoostStartTraining( &data->cls, weakTrainVals, &data->weights,
                                    sampleIdx, boosttype );
    num_splits = 0;
    sumalpha = 0.0F;
    do
    {     

#ifdef CV_VERBOSE
        int v_wt = 0;
        int v_flipped = 0;
#endif /* CV_VERBOSE */

        // 剔除小权值样本
        trimmedIdx = cvTrimWeights( &data->weights, sampleIdx, weightfraction );

        // 实际样本总数
        numtrimmed = (trimmedIdx) ? MAX( trimmedIdx->rows, trimmedIdx->cols ) : m;

#ifdef CV_VERBOSE
        v_wt = 100 * numtrimmed / numsamples;
        v_flipped = 0;

#endif /* CV_VERBOSE */

        // 重要函数,创建CART树的同时,计算出当前最优弱分类器,一般只有根节点
        cart = (CvCARTClassifier*) cvCreateCARTClassifier( data->valcache,
                        flags,
                        weakTrainVals, 0, 0, 0, trimmedIdx,
                        &(data->weights),
                        (CvClassifierTrainParams*) &trainParams );

        // 创建弱分类器
        classifier = (CvCARTHaarClassifier*) icvCreateCARTHaarClassifier( numsplits );

        // 将CART树转化为弱分类器
        icvInitCARTHaarClassifier( classifier, cart, haarFeatures );

        num_splits += classifier->count;

        cart->release( (CvClassifier**) &cart );
        
        // 为何一定要在奇数个弱分类器处计算?
        if( symmetric && (seq->total % 2) )
        {
            float normfactor = 0.0F;
            CvStumpClassifier* stump;
            
            /* 翻转HAAR特征 */
            for( i = 0; i < classifier->count; i++ )
            {
                if( classifier->feature[i].desc[0] == 'h' )
                {
                    for( j = 0; j < CV_HAAR_FEATURE_MAX &&
                                    classifier->feature[i].rect[j].weight != 0.0F; j++ )
                    {
                        classifier->feature[i].rect[j].r.x = data->winsize.width - 
                            classifier->feature[i].rect[j].r.x -
                            classifier->feature[i].rect[j].r.width;                
                    }
                }
                else
                {
                    int tmp = 0;

                    /* (x,y) -> (24-x,y) */
                    /* w -> h; h -> w    */
                    for( j = 0; j < CV_HAAR_FEATURE_MAX &&
                                    classifier->feature[i].rect[j].weight != 0.0F; j++ )
                    {
                        classifier->feature[i].rect[j].r.x = data->winsize.width - 
                            classifier->feature[i].rect[j].r.x;
                        CV_SWAP( classifier->feature[i].rect[j].r.width,
                                 classifier->feature[i].rect[j].r.height, tmp );
                    }
                }
            }

            // 转化为基于积分图计算的特征
            icvConvertToFastHaarFeature( classifier->feature,
                                         classifier->fastfeature,
                                         classifier->count, data->winsize.width + 1 );

            // 为了验证最新翻转特征是否为最优特征
            stumpTrainParams.getTrainData = NULL;
            stumpTrainParams.numcomp = 1;
            stumpTrainParams.userdata = NULL;
            stumpTrainParams.sortedIdx = NULL;

            // 验证是否新生成的特征可作为最优弱分类器
            for( i = 0; i < classifier->count; i++ )
            {
                for( j = 0; j < numtrimmed; j++ )
                {
                    // 获取训练样本
                    idx = icvGetIdxAt( trimmedIdx, j );

                    // 对每个训练样本计算Haar特征
                    eval.data.fl[idx] = cvEvalFastHaarFeature( &classifier->fastfeature[i],
                                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step) ); 

                    // 归一化因子
                    normfactor = data->normfactor.data.fl[idx];

                    // 对Haar特征归一化
                    eval.data.fl[idx] = ( normfactor == 0.0F )
                        ? 0.0F : (eval.data.fl[idx] / normfactor);
                }

                // 计算最优弱分类器
                stump = (CvStumpClassifier*) trainParams.stumpConstructor( &eval,
                    CV_COL_SAMPLE,
                    weakTrainVals, 0, 0, 0, trimmedIdx,
                    &(data->weights),
                    trainParams.stumpTrainParams );
            
                classifier->threshold[i] = stump->threshold;                // 阈值
                if( classifier->left[i] <= 0 )
                {
                    classifier->val[-classifier->left[i]] = stump->left;    // 左分支输出置信度
                }
                if( classifier->right[i] <= 0 )
                {
                    classifier->val[-classifier->right[i]] = stump->right;  // 右分支输出置信度
                }

                stump->release( (CvClassifier**) &stump );        
                
            }

            // 还原参数,参数支持cvCreateCARTClassifier函数
            stumpTrainParams.getTrainData = icvGetTrainingDataCallback;
            stumpTrainParams.numcomp = n;
            stumpTrainParams.userdata = &userdata;
            stumpTrainParams.sortedIdx = data->idxcache;

#ifdef CV_VERBOSE
            v_flipped = 1;
#endif /* CV_VERBOSE */

        } /* if symmetric */
        if( trimmedIdx != sampleIdx )
        {
            cvReleaseMat( &trimmedIdx );
            trimmedIdx = NULL;
        }
        
        // 基于当前最优弱分类器,更新样本特征值
        for( i = 0; i < numsamples; i++ )
        {
            idx = icvGetIdxAt( sampleIdx, i );

            eval.data.fl[idx] = classifier->eval( (CvIntHaarClassifier*) classifier,
                (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                data->normfactor.data.fl[idx] );
        }

        // 更新样本权重,如果是LogitBoost,也会更新weakTrainVals
        alpha = cvBoostNextWeakClassifier( &eval, &data->cls, weakTrainVals,
                                           &data->weights, trainer );
        
        // 这个变量没什么用
        sumalpha += alpha;
        
        for( i = 0; i <= classifier->count; i++ )
        {
            if( boosttype == CV_RABCLASS ) 
            {
                classifier->val[i] = cvLogRatio( classifier->val[i] );
            }
            classifier->val[i] *= alpha;
        }

        // 添加弱分类器
        cvSeqPush( seq, (void*) &classifier );

        // 正样本个数
        numpos = 0;

        // 遍历sampleIdx中所有样本
        for( i = 0; i < numsamples; i++ )
        {
            // 获得样本序号
            idx = icvGetIdxAt( sampleIdx, i );

            // 如果样本为正样本
            if( data->cls.data.fl[idx] == 1.0F )
            {
                // 初始化特征值
                eval.data.fl[numpos] = 0.0F;

                // 遍历seq中所有弱分类器
                for( j = 0; j < seq->total; j++ )
                {
                    // 获取弱分类器
                    classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));

                    // 累积计算当前正样本的弱分类器输出结果
                    eval.data.fl[numpos] += classifier->eval( 
                        (CvIntHaarClassifier*) classifier,
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                        data->normfactor.data.fl[idx] );
                }
                /* eval.data.fl[numpos] = 2.0F * eval.data.fl[numpos] - seq->total; */
                numpos++;
            }
        }

        // 对输出结果值排序
        icvSort_32f( eval.data.fl, numpos, 0 );

        // 计算阈值,应该是大于threshold则为正类,小于threshold则为负类
        threshold = eval.data.fl[(int) ((1.0F - minhitrate) * numpos)];

        numneg = 0;
        numfalse = 0;

        // 遍历所有样本
        for( i = 0; i < numsamples; i++ )
        {
            idx = icvGetIdxAt( sampleIdx, i );

            // 如果样本为负样本
            if( data->cls.data.fl[idx] == 0.0F )
            {
                numneg++;
                sum_stage = 0.0F;

                // 遍历seq中所有弱分类器
                for( j = 0; j < seq->total; j++ )
                {
                   classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));

                   // 累积当前负样本的分类器输出结果
                   sum_stage += classifier->eval( (CvIntHaarClassifier*) classifier,
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                        data->normfactor.data.fl[idx] );
                }
                /* sum_stage = 2.0F * sum_stage - seq->total; */

                // 因为小于threshold为负类,所以下面是分类错误的情况
                if( sum_stage >= (threshold - CV_THRESHOLD_EPS) )
                {
                    numfalse++;
                }
            }
        }

        // 计算虚警率
        falsealarm = ((float) numfalse) / ((float) numneg);

// 输出内容
#ifdef CV_VERBOSE
        {
            // 正样本检出率
            float v_hitrate    = 0.0F;

            // 负样本误检率
            float v_falsealarm = 0.0F;
            /* expected error of stage classifier regardless threshold */

            // 这是什么?
            float v_experr = 0.0F;

            // 遍历所有样本
            for( i = 0; i < numsamples; i++ )
            {
                idx = icvGetIdxAt( sampleIdx, i );

                sum_stage = 0.0F;

                // 遍历seq中所有弱分类器
                for( j = 0; j < seq->total; j++ )
                {
                    classifier = *((CvCARTHaarClassifier**) cvGetSeqElem( seq, j ));
                    sum_stage += classifier->eval( (CvIntHaarClassifier*) classifier,
                        (sum_type*) (data->sum.data.ptr + idx * data->sum.step),
                        (sum_type*) (data->tilted.data.ptr + idx * data->tilted.step),
                        data->normfactor.data.fl[idx] );
                }
                /* sum_stage = 2.0F * sum_stage - seq->total; */

                // 只需要判断单一分支即可
                if( sum_stage >= (threshold - CV_THRESHOLD_EPS) )
                {
                    if( data->cls.data.fl[idx] == 1.0F )
                    {
                        v_hitrate += 1.0F;
                    }
                    else
                    {
                        v_falsealarm += 1.0F;
                    }
                }

                // 正类样本的sum_stage必须大于0
                if( ( sum_stage >= 0.0F ) != (data->cls.data.fl[idx] == 1.0F) )
                {
                    v_experr += 1.0F;
                }
            }
            v_experr /= numsamples;
            printf( "|%4d|%3d%%|%c|%9f|%9f|%9f|%9f|\n",
                seq->total, v_wt, ( (v_flipped) ? '+' : '-' ),
                threshold, v_hitrate / numpos, v_falsealarm / numneg,
                v_experr );
            printf( "+----+----+-+---------+---------+---------+---------+\n" );
            fflush( stdout );
        }
#endif /* CV_VERBOSE */
        
    // 两种收敛方式,一种是误检率小于规定阈值,另一种是弱分类器个数小于规定阈值
    } while( falsealarm > maxfalsealarm && (!maxsplits || (num_splits < maxsplits) ) );
    cvBoostEndTraining( &trainer );

    if( falsealarm > maxfalsealarm )
    {
        stage = NULL;
    }
    else
    {
        stage = (CvStageHaarClassifier*) icvCreateStageHaarClassifier( seq->total,
                                                                       threshold );
        cvCvtSeqToArray( seq, (CvArr*) stage->classifier );
    }
    
    /* CLEANUP */
    cvReleaseMemStorage( &storage );
    cvReleaseMat( &weakTrainVals );
    cvFree( &(eval.data.ptr) );
    
    return (CvIntHaarClassifier*) stage;
}



Opencv研读笔记:haartraining程序之icvCreateCARTStageClassifier函数详解~

原文:http://blog.csdn.net/wsj998689aa/article/details/42398235

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!