本代码参考自知乎用户 SleepyBag 的实现。
具体实现
import numpy as np
class Perception:
    """Binary perceptron classifier trained with the primal perceptron rule.

    Labels must be -1 or +1. The class keeps the name ``Perception`` (rather
    than "Perceptron") so existing callers keep working.

    Attributes set by :meth:`fit`:
        feature_size: number of features (size of the last axis of ``X``).
        w: weight vector of shape ``(feature_size,)``.
        b: bias, stored as a one-element array of shape ``(1,)``.
    """

    def __init__(self, lr, max_epoch=2000, *, random_state=None, verbose=True):
        """Configure the classifier.

        Args:
            lr: learning rate applied to every weight/bias update.
            max_epoch: upper bound on passes over the training data; training
                also stops early after an epoch with no misclassified point.
            random_state: optional seed (or ``np.random.Generator``) for the
                initial weights and the per-epoch shuffling. ``None`` keeps
                the original behaviour of drawing from NumPy's global RNG.
            verbose: if True, print the number of misclassified points after
                every epoch.
        """
        self.lr = lr
        self.max_epoch = max_epoch
        self.random_state = random_state
        self.verbose = verbose

    def _rng(self):
        """Return the random source: NumPy's global RNG or a seeded Generator."""
        if self.random_state is None:
            # np.random.random / np.random.permutation draw from the same
            # legacy global stream as the original np.random.rand calls.
            return np.random
        return np.random.default_rng(self.random_state)

    def fit(self, X, Y):
        """Train on samples ``X`` (n_samples, n_features) with labels ``Y``.

        Each epoch visits the samples in a random order. For every
        misclassified sample it applies ``w += lr*y*x`` and ``b += lr*y``.
        Training stops after an epoch with no mistakes, or after
        ``max_epoch`` epochs, whichever comes first.

        Returns:
            self, so calls can be chained (``clf.fit(X, Y).predict(X)``).

        Raises:
            ValueError: if ``X`` and ``Y`` differ in length, or ``Y`` contains
                a label other than -1 or +1.
        """
        X = np.asarray(X)
        Y = np.asarray(Y)
        if len(X) != len(Y):
            raise ValueError(
                "X and Y must have the same length, got {} and {}".format(len(X), len(Y))
            )
        # With any other label (e.g. 0) the update lr*y*x is zero or
        # meaningless, so training would silently run until max_epoch.
        if not np.isin(Y, (-1, 1)).all():
            raise ValueError("labels in Y must be -1 or +1")

        rng = self._rng()
        self.feature_size = X.shape[-1]
        self.w = rng.random(self.feature_size)
        self.b = rng.random(1)

        mistakes = 1  # non-zero so the first epoch always runs
        epoch = 0
        while mistakes > 0 and epoch < self.max_epoch:
            mistakes = 0
            for i in rng.permutation(len(X)):
                x, y = X[i], Y[i]
                if self.predict_one(x) != y:
                    self.w += self.lr * y * x
                    self.b += self.lr * y
                    mistakes += 1
            epoch += 1
            if self.verbose:
                print("第{}epoch,仍有{}个分类错误的点".format(epoch, mistakes))
        return self

    def _decision_function(self, X):
        """Return the raw score ``X @ w + b`` for one sample or a batch."""
        # b is stored with shape (1,); squeezing it keeps the score's shape
        # equal to X @ w (0-d for one sample, (n,) for a batch of n).
        return np.asarray(X) @ self.w + np.squeeze(self.b)

    def predict_one(self, x):
        """Return +1 if the score of the single sample ``x`` is positive, else -1."""
        return 1 if self._decision_function(x) > 0 else -1

    def predict(self, X):
        """Predict labels in {-1, +1} for each sample along the last axis of ``X``.

        Vectorized equivalent of applying :meth:`predict_one` to every row.
        Unlike ``np.apply_along_axis``, it also works on an empty batch.
        """
        return np.where(self._decision_function(X) > 0, 1, -1)
跑个demo看看效果吧
# Demo: four linearly separable 2-D points with labels in {-1, +1}.
X = np.array([[0.2, 0], [0, 1], [1, 0], [1, 1]])
Y = np.array([1, 1, -1, -1])
perp = Perception(0.1)
perp.fit(X, Y)
# Outside a notebook a bare expression's value is discarded, so print the
# predictions explicitly; they should equal Y once training has converged.
print(perp.predict(X))
运行结果:
可视化结果:
异或问题可视化结果（异或数据线性不可分，感知机无法收敛，训练会一直进行到 max_epoch 轮才停止）: