求 f(x1, x2) = x1^2 + x2^2 的最小值，其中 -10 <= x1, x2 <= 10
# -*- coding: utf-8 -*-
# author: 'boliang'
# date: 2018/5/22 14:46
import random
import matplotlib.pyplot as plt
class Solution(object):
    """Particle swarm optimisation (PSO) minimiser over a box-bounded domain.

    Every coordinate of every particle is kept inside ``[l, r]``.  Each
    particle in ``self.p`` is a dict with keys:

    * ``'x'``     -- current position (list, one entry per dimension)
    * ``'v'``     -- current velocity (list, one entry per dimension)
    * ``'pBest'`` -- best position this particle has visited (own copy)
    * ``'f'``     -- objective value at ``'pBest'``

    ``self.g`` holds the swarm-wide best as ``{'f': value, 'best': position}``.
    """

    # Inertia weight and cognitive / social acceleration coefficients.
    w = 0.5
    c1 = 2
    c2 = 2
    # Kept for backward compatibility only: update_speedPos draws fresh random
    # factors on every update instead of reusing these two values, which are
    # fixed once when the class is defined.
    r1 = random.random()
    r2 = random.random()

    def __init__(self, N, l, r, x_len, fun):
        """Create a swarm of ``N`` random particles and record initial bests.

        :param N: number of particles.
        :param l: lower bound of every coordinate (integer, used by randint).
        :param r: upper bound of every coordinate (integer, used by randint).
        :param x_len: dimensionality of the search space.
        :param fun: objective to minimise; called with a list of length x_len.
        """
        self.l = l
        self.r = r
        self.x_len = x_len
        self.fun = fun
        self.p = [{
            'v': self.generate_speed(l, r, x_len),
            'x': [self.generate_value(l, r) for _ in range(x_len)],
        } for _ in range(N)]
        # Start from +inf so the global best always comes from a real
        # particle; a hard-coded seed only fits one particular objective.
        self.g = {
            'f': float('inf'),
            'best': None,
        }
        for p in self.p:
            # Copy: 'x' is mutated in place by update_speedPos, so sharing the
            # list would make the personal/global best move with the particle.
            p['pBest'] = list(p['x'])
            p['f'] = self.fun(p['x'])
            if p['f'] < self.g['f']:
                self.g['f'] = p['f']
                self.g['best'] = list(p['pBest'])

    def init_speed(self):
        """Re-draw a random non-zero velocity for every particle."""
        for p in self.p:
            p['v'] = self.generate_speed(self.l, self.r, self.x_len)

    def update_speedPos(self):
        """Perform one PSO velocity and position update for every particle.

        For each coordinate:
        ``v <- w*v + c1*r1*(pBest - x) + c2*r2*(gBest - x)``, then
        ``x <- clamp(x + v, l, r)``.  ``r1`` and ``r2`` are drawn anew from
        ``[0, 1)`` for every coordinate of every particle.
        """
        g_best = self.g['best']
        for p in self.p:
            v, x, p_best = p['v'], p['x'], p['pBest']
            # Each index is read (via zip) before it is written, so the
            # in-place assignments below never affect values still to be read.
            for i, (vi, xi, pbi, gbi) in enumerate(zip(v, x, p_best, g_best)):
                r1 = random.random()
                r2 = random.random()
                new_v = (self.w * vi
                         + self.c1 * r1 * (pbi - xi)
                         + self.c2 * r2 * (gbi - xi))
                v[i] = new_v
                x[i] = self.adjust_bound(xi + new_v, self.l, self.r)

    def update_eval(self):
        """Evaluate every particle and refresh the personal and global bests."""
        for p in self.p:
            tmp_f = self.fun(p['x'])
            if tmp_f < p['f']:
                p['f'] = tmp_f
                # Snapshot the position: p['x'] keeps moving on later steps.
                p['pBest'] = list(p['x'])
                if p['f'] < self.g['f']:
                    self.g['f'] = p['f']
                    self.g['best'] = list(p['pBest'])

    def solve(self, step):
        """Run ``step`` PSO iterations and plot the best value per iteration.

        Each iteration in which the global best value does not change
        increments a counter.  Once that counter exceeds one fifth of the
        iterations run so far, all velocities are re-randomised and the
        counter is reset.

        :param step: number of iterations to run.
        :return: the global best record ``{'f': value, 'best': position}``.
        """
        x = list(range(step))
        y = []
        same_cal = 0
        for _ in range(step):
            self.update_speedPos()
            self.update_eval()
            y.append(self.g['f'])
            if len(y) >= 2 and y[-1] == y[-2]:
                same_cal += 1
                if same_cal > len(y) / 5:
                    self.init_speed()
                    same_cal = 0
        plt.plot(x, y, '-o')
        plt.xlabel('Step')
        plt.ylabel('Eval')
        plt.title('Operation')
        plt.show()
        return self.g

    def adjust_bound(self, val, l, r):
        """Clamp ``val`` into the closed interval ``[l, r]``."""
        if val > r:
            return r
        if val < l:
            return l
        return val

    def generate_speed(self, l, r, dim=2):
        """Return a velocity of ``dim`` random non-zero integers in ``[l, r]``.

        :param l: lower bound (integer).
        :param r: upper bound (integer).
        :param dim: number of components; defaults to 2 (previous behaviour).
        :raises ValueError: if ``[l, r]`` contains no non-zero integer
            (the retry loop below would otherwise never terminate).
        """
        if l == 0 and r == 0:
            raise ValueError('[l, r] must contain a non-zero integer')
        speed = []
        for _ in range(dim):
            v = 0
            # Zero components are rejected and re-drawn, as before.
            while v == 0:
                v = random.randint(l, r)
            speed.append(v)
        return speed

    def generate_value(self, l, r):
        """Return a random integer coordinate in ``[l, r]``."""
        return random.randint(l, r)
def fun(x):
    """Objective: sum of the squares of the first two coordinates of ``x``."""
    x1, x2 = x[0], x[1]
    return x1 ** 2 + x2 ** 2
if __name__ == '__main__':
    # Minimise the 2-D objective `fun` on [-10, 10] with 3 particles for 30
    # iterations; solve() plots the best value found at each iteration.
    solver = Solution(N=3, l=-10, r=10, x_len=2, fun=fun)
    solver.solve(step=30)
迭代 30 次，算法快速收敛。
Comments