import warnings
from typing import TYPE_CHECKING

import numpy as np

if TYPE_CHECKING:
    # Only needed for the type annotation; matplotlib is imported lazily in main().
    from mpl_toolkits.mplot3d.axes3d import Axes3D


# Objective: f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8
def targetFunc(x, y):
    """Return f(x, y) = 2x^2 + 6y^2 + 6xy + x + 4y + 8 (works element-wise on arrays)."""
    return 2 * (x ** 2) + 6 * y ** 2 + 6 * x * y + x + 4 * y + 8


# Partial derivatives:
#   f'x(x, y) = 4x + 6y + 1
#   f'y(x, y) = 12y + 6x + 4
def derivativeFunc(x, y):
    """Return the gradient of targetFunc as the tuple (df/dx, df/dy)."""
    rx = 4 * x + 6 * y + 1
    ry = 12 * y + 6 * x + 4
    return (rx, ry)


def linerFunc(initPoint: tuple, targetFunc, derivativeFunc, step=0.01,
              limitValue=0.00000001, timeout=1000000, ax: "Axes3D | None" = None):
    """Run fixed-step gradient descent on a two-variable function.

    Args:
        initPoint: starting coordinates (x, y).
        targetFunc: callable f(x, y) returning a scalar.
        derivativeFunc: callable returning the gradient (df/dx, df/dy) at (x, y).
        step: learning rate; each update moves the point by -step * gradient.
        limitValue: tolerance. Iteration stops once every gradient component
            changes by at most this much between consecutive points; a
            coordinate whose update would be smaller than this is left unchanged.
        timeout: maximum number of updates (the initial update counts as 1).
        ax: optional 3D axes; when given, every visited point is scattered on it
            (red) and the final point is marked in blue.

    Returns:
        (value, point): f at the final point, and the final point as an ndarray.

    Warns:
        RuntimeWarning: if ``timeout`` is reached before the tolerance is met,
            or if the iterate becomes non-finite (step too large, divergence).
    """
    point = np.asarray(initPoint, dtype=float)
    value, grad = targetFunc(*point), np.asarray(derivativeFunc(*point))
    if ax is not None:
        ax.scatter(*point, value, c='r', label="descent path")

    # The first update is unconditional; the loop below then repeats it
    # until the gradient stops changing.
    new_point = point - grad * step
    new_value, new_grad = targetFunc(*new_point), np.asarray(derivativeFunc(*new_point))
    diff = np.abs(grad - new_grad)
    count = 1

    while (diff > limitValue).any() and count < timeout:
        # The previous candidate becomes the current point; its value and
        # gradient were computed last iteration, so reuse them.
        point, value, grad = new_point, new_value, new_grad
        update = grad * step
        # Freeze coordinates whose update has fallen below the tolerance.
        new_point = np.where(np.abs(update) >= limitValue, point - update, point)
        new_value, new_grad = targetFunc(*new_point), np.asarray(derivativeFunc(*new_point))
        diff = np.abs(grad - new_grad)

        if ax is not None:
            ax.scatter(*point, value, c='r')
        count += 1

    # NaN comparisons are False, so a diverged run would otherwise look converged.
    if not np.all(np.isfinite(new_point)) or (diff > limitValue).any():
        warnings.warn(
            "gradient descent did not converge after {0} iterations".format(count),
            RuntimeWarning,
            stacklevel=2,
        )

    if ax is not None:
        ax.scatter(*new_point, new_value, c='b', label="stationary point")
    print("最终运算次数为 : {0}".format(count))
    return new_value, new_point


def main():
    """Minimise targetFunc from (20, 20) and plot the descent path over the surface."""
    # Imported here so the optimiser above can be reused without matplotlib.
    import matplotlib.pyplot as plt
    import mpl_toolkits.mplot3d  # noqa: F401  registers the "3d" projection on older matplotlib

    fig = plt.figure()
    # Axes3D(fig) no longer attaches itself to the figure in recent matplotlib
    # (deprecated in 3.4), which leaves an empty plot; add_subplot is the supported way.
    ax = fig.add_subplot(111, projection="3d")

    # Coarse, translucent surface over a window that contains the descent path.
    grid = np.linspace(-30, 30, 61)
    x, y = np.meshgrid(grid, grid)
    ax.plot_surface(x, y, targetFunc(x, y), alpha=0.3)

    limit_value, limit_point = linerFunc((20, 20), targetFunc, derivativeFunc, ax=ax)
    print("该函数在({0},{1})处有驻点,值为{2}".format(limit_point[0], limit_point[1], limit_value))
    ax.legend()

    plt.show()


if __name__ == "__main__":
    main()
# Source (原文): https://www.cnblogs.com/dofstar/p/11462941.html