首页 > 其他 > 详细

利用感知机做线性分类小demo

时间:2020-10-31 08:35:46      阅读:27      评论:0      收藏:0      [点我收藏+]

如题啦,单个感知机分类

结果展示:

技术分享图片

 

代码:

技术分享图片
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random 

def sgn(x):
    if x>=0:
        return 1;
    else: 
        return -1;

class point:
    x1=0
    x2=0
    y=1

    def __init__(self,x1,x2,y):
        self.x1=x1
        self.x2=x2
        self.y=y

    #def __init__(para):
    #    self.x1=para[0]
    #    self.x2=para[1]
    #    self.y=para[2]

    def cauy(self,para):
        return (self.x1*para[0]+self.x2*para[1]+para[2])


p1=point(1,1,-1)
p2=point(3,3,1)
p3=point(4,3,1)


#画三个点
plt.scatter(1,1,c=r)
plt.scatter(3,3,c=b)
plt.scatter(4,3,c=b)


#参数与训练集
w1=w2=b=0
n=1

trainset=[p1,p2,p3]
parameter=[w1,w2,b]

count=0
#循环训练参数
while(1):
    random.shuffle(trainset)

    for i in trainset:
        #预测值
        y_=sgn(i.cauy(parameter))

        #遇到误分类点
        if(y_!=i.y):
            #参数修正
            w1+=n*(i.y-y_)*i.x1
            w2+=n*(i.y-y_)*i.x2
            b+=n*i.y
            parameter=[w1,w2,b]

            #绘图
            x0=np.array([1,2,3,4,5,6,7,8,9,10])
            if w2!=0:          
                plt.plot(x0,x0*w1/(-w2)-b/w2)
            elif w1!=0:
                 plt.plot([-b/w1]*10,x0)
            else:
                plt.scatter(0,0,c="g")
            
            print(w1,x1 + ,w2,x2 + ,b)

            count+=1
            break

    else:
        break

print("you have try",count,"times")
plt.show()
View Code

ps:下面还有随机产生数据集的版本(但是点多了就很难线性可分,一般5个随机点就很难分了)

技术分享图片
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import random 

def sgn(x):
    if x>=0:
        return 1;
    else: 
        return -1;

class point:
    x1=0
    x2=0
    y=1

    def __init__(self,x1,x2,y):
        self.x1=x1
        self.x2=x2
        self.y=y

    #def __init__(para):
    #    self.x1=para[0]
    #    self.x2=para[1]
    #    self.y=para[2]

    def cauy(self,para):
        return (self.x1*para[0]+self.x2*para[1]+para[2])

#随机生成数据集
trainset=[point(int(random.randrange(1,10,1)),int(random.randrange(1,10,1)),sgn(random.randrange(-5,5,1))) for i in range(4)]


#画点
for p in trainset:
    if p.y==1:
        plt.scatter(p.x1,p.x2,c=b)
    else:
        plt.scatter(p.x1,p.x2,c=r)


#参数与训练集
w1=w2=b=0
n=1

parameter=[w1,w2,b]

count=0
res=False

#循环训练参数
while(1and count<20):
    #理论上不打乱也行
    #random.shuffle(trainset)

    for i in trainset:
        #预测值
        y_=sgn(i.cauy(parameter))

        #遇到误分类点
        if(y_!=i.y):
            #参数修正
            w1+=n*(i.y-y_)*i.x1
            w2+=n*(i.y-y_)*i.x2
            b+=n*i.y
            parameter=[w1,w2,b]

            #绘图
            x0=np.array([1,2,3,4,5,6,7,8,9,10])
            if w2!=0:          
                plt.plot(x0,x0*w1/(-w2)-b/w2)
            elif w1!=0:
                 plt.plot([-b/w1]*10,x0)
            else:
                plt.scatter(0,0,c=g)

            print(w1,x1 + ,w2,x2 + ,b)

            count+=1
            break

    else:
        res=True
        break

if(res):
    print("you have try",count,"times and you are succeed")
else :
    print("it is unclassifiable in 20 times training !!")
plt.show()
View Code

 

利用感知机做线性分类小demo

原文:https://www.cnblogs.com/laozhu1234/p/13904723.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!