首页 > 其他 > 详细

朴素贝叶斯

时间:2014-06-06 23:33:59      阅读:493      评论:0      收藏:0      [点我收藏+]

参考《机器学习实战》

朴素(naive)贝叶斯遵循以下原则:

设特征为x,y;类别为c。

在已知x、y特征的情况下,判断其类为ci的概率为:

bubuko.com,布布扣

自然, 我们选取概率较大的为对应的分类结果:

bubuko.com,布布扣

朴素贝叶斯就是根据这一原则进行分类器设计的。书中以垃圾邮件分类为例做了详述。

垃圾邮件分类的基本过程:

    数据集:包含50个文本文件,其中25个为正常邮件,另外25个为垃圾邮件。

    测试集:随机地从这50个样本中选取10个做为测试样本。剩下的数据作为训练集

    分类器训练过程:

1、首先设定两个list变量‘fileWords’,‘allwords’用来保存单词,每篇文档对应的类用classSet中,filewords将每篇文档中的单词以列表的形式保存起来,每篇文档成为filewords中的一个元素。allwords将每篇文档的单词全部合并到一起;

2、分析单词:对allwords中的单词去重并去掉太短的单词。

3、生成文档向量矩阵:对于每篇在filewords中的文档,分析该文档的单词是否存在于allwords列表中。以此对每篇文档生成一个与allwords等长的向量,向量中的1代表allwords中的对应位置的单词存在于该文档中否则不存在。

4、计算先验概率:将不同类别的文档向量分别加起来,除以对应类文档向量中‘1’的总个数,得到bubuko.com,布布扣。各类文档的个数除以总的训练集文档个数得到bubuko.com,布布扣。为避免bubuko.com,布布扣等于0时对计算造成的影响,每个概率都增加一定的偏移值,并用log对数进行表示。

5、分类:对于输入文档,首先利用allwords对输入进行向量化,得到该输入文档的文档向量,然后该向量与每个类的bubuko.com,布布扣向量做矢量乘法,然后再乘以bubuko.com,布布扣。得到bubuko.com,布布扣中的分子,由于分母是常数,所以我们只需要比较该贝叶斯公式中的分子大小即可对输入做分类判断。


   实验代码(原理基本与书中给出的一致):

bubuko.com,布布扣
# -*- coding:cp936 -*-
import re
import random
from math import *
from veusz.windows.tutorial import DataStart

hamFiles = ./email/ham/%d.txt
spamFiles = ./email/spam/%d.txt

def bayesClassify(vectIn, p1Vect, p0Vect, pSpam):
    pC1 = sum(vectIn*p1Vect)+log(pSpam)
    pC0 = sum(vectIn*p0Vect)+log(1-pSpam)
    if pC1 > pC0:
        return 1
    else:
        return 0
import numpy as np
def words2Vect(wordsIn, WordsSet):
    setLen = len(WordsSet)
    vectOut = np.zeros(setLen)
    cnt = 0
    for iw in WordsSet:
        if iw in wordsIn:
            vectOut[cnt]=1
        cnt += 1
    return vectOut

def loadDataSet():
    dataSet = []
    classSet = []
    allWords = []
    fileWords = []
    regEx = re.compile(r\W*)
    for i in range(1,26):
        fileTmp = open(hamFiles%i).read()
        if i ==1:
            print fileTmp
        wordSplit = regEx.split(fileTmp)
        if i == 1:
            print wordSplit
#         exit()
        fileWordTmp = []
        for word in wordSplit:
            if len(word)>3:
                fileWordTmp.append(word.lower())
        allWords.extend(fileWordTmp)
        fileWords.append(fileWordTmp)
        fileWordTmp=[]
        classSet.append(0)

        fileTmp = open(spamFiles%i).read()
        
        wordSplit = regEx.split(fileTmp)
        for word in wordSplit:
            if len(word)>3:
                fileWordTmp.append(word.lower())
        allWords.extend(fileWordTmp)
        fileWords.append(fileWordTmp)
        classSet.append(1)

    #wordsSet
    wordsSet = list(set(allWords))
#     print wordsSet
#     exit()
    #convert file words to file vector
    for i in range(0,50):
        vect = words2Vect(fileWords[i], wordsSet)
        dataSet.append(list(vect))
    return dataSet, classSet, wordsSet

def BayesiTrain(dataSet, classSet):
    
    fileNum = dataSet.shape[0]
    p1add = np.zeros(dataSet.shape[1])
    p0add = np.zeros(dataSet.shape[1])
    spamCnt = 0 
    for i in range(fileNum):
        if classSet[i]==1:
            p1add = p1add + dataSet[i,:]
            spamCnt += 1
        else:
            p0add =p0add + dataSet[i,:]
    p1Sum = np.sum(p1add)
    p0Sum = np.sum(p0add)
    p1Vect = (p1add+1)/(p1Sum+2)
    p0Vect = (p0add+1)/(p0Sum+2)
    pSpa = spamCnt/float(fileNum)
    return  p1Vect, p0Vect, pSpa
    
    
def test():
    #对测试数据进行测试: 将测试集输入分类器,检测输出结果
    
    #Load data set and the class set
    dataSet, classSet, wordSet = loadDataSet()
    #get test set
    testSet = []
    testClass = []
    for i in range(10):#only 10 used for test set
        randIndex = int(random.uniform(0,50-i))
        testSet.append(dataSet[randIndex])
        testClass.append(classSet[randIndex])
        del dataSet[randIndex]
        del classSet[randIndex]
    
    #训练,得到p1Vect/p0Vect及pSpam
    p1Vect, p0Vect, pSpam = BayesiTrain(np.array(dataSet), classSet)
    erCnt = 0
    for iTest in range(10):
        class_res = bayesClassify(np.array(testSet[iTest]), np.array(p1Vect), np.array(p0Vect), pSpam)
        if class_res != testClass[iTest]:
            print Error
            erCnt += 1 #出错检测加1
    print Error Rate: %f%(erCnt/float(10))

test()
bubuko.com,布布扣

朴素贝叶斯,布布扣,bubuko.com

朴素贝叶斯

原文:http://www.cnblogs.com/mmhx/p/3765375.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!