资源简介
本资源采用自组织学习获取中心和有监督学习获取中心两种方式训练RBF神经网络,支持多维函数逼近,支持批量训练,具有较好的封装性,使用非常简便
代码片段和文件信息
import tensorflow as tf
import numpy as np
from sklearn.cluster import KMeans
class RBF:
#初始化学习率、学习步数
def __init__(selflearning_rate=0.002step_num=10001hidden_size=10):
self.learning_rate=learning_rate
self.step_num=step_num
self.hidden_size=hidden_size
#使用 k-means 获取聚类中心、标准差
def getC_S(selfxclass_num):
estimator=KMeans(n_clusters=class_nummax_iter=10000) #构造聚类器
estimator.fit(x) #聚类
c=estimator.cluster_centers_
n=len(c)
s=0;
for i in range(n):
j=i+1
while j t=np.sum((c[i]-c[j])**2)
s=max(st)
j=j+1
s=np.sqrt(s)/np.sqrt(2*n)
return cs
#高斯核函数(c为中心,s为标准差)
def kernel(selfxcs):
x1=tf.tile(x[1self.hidden_size]) #将x水平复制 hidden次
x2=tf.reshape(x1[-1self.hidden_sizeself.feature])
dist=tf.reduce_sum((x2-c)**22)
return tf.exp(-dist/(2*s**2))
#训练RBF神经网络
def train(selfxy):
self.feature=np.shape(x)[1] #输入值的特征数
self.cself.s=self.getC_S(xself.hidden_size) #获取聚类中心、标准差
x_=tf.placeholder(tf.float32[Noneself.feature]) #定义placeholder
y_=tf.placeholder(tf.float32[None1]) #定义placeholder
#定义径向基层
z=self.kernel(x_self.cself.s)
#定义输出层
w=tf.Variable(tf.random_normal([self.hidden_size1]))
b=tf.Variable(tf.zeros([1]))
yf=tf.matmul(zw)+b
loss=tf.reduce_mean(tf.square(y_-yf))#二次代价函数
optimizer=tf.train.AdamOptimizer(self.learning_rate) #Adam优化器
train=optimizer.minimize(loss) #最小化代价函数
init=tf.global_variables_initializer() #变量初始化
with tf.Session() as sess:
sess.run(init)
for epoch in range(self.step_num):
sess.run(trainfeed_dict={x_:xy_:y})
if epoch>0 and epoch%500==0:
mse=sess.run(lossfeed_dict={x_:xy_:y})
print(epochmse)
self.wself.b=sess.run([wb]feed_dict={x_:xy_:y})
def kernel2(selfxcs): #预测时使用
x1=np.tile(x[1self.hidden_size]) #将x水平复制 hidden次
x2=np.reshape(x1[-1self.hidden_sizeself.feature])
dist=np.sum((x2-c)**22)
return np.exp(-dist/(2*s**2))
def predict(selfx):
z=self.kernel2(xself.cself.s)
pre=np.matmul(zself.w)+self.b
return pre
属性 大小 日期 时间 名称
----------- --------- ---------- ----- ----
目录 0 2019-11-17 15:32 RBF案例\
文件 2790 2019-11-28 10:53 RBF案例\RBF_kmeans.py
文件 2314 2019-11-28 10:44 RBF案例\RBF_Supervised.py
文件 831 2019-11-28 10:50 RBF案例\test_kmeans.py
文件 1090 2019-11-28 11:00 RBF案例\test_kmeans2.py
文件 837 2019-11-28 10:59 RBF案例\test_Supervised.py
文件 1094 2019-11-28 11:06 RBF案例\test_Supervised2.py
目录 0 2019-11-28 11:06 RBF案例\__pycache__\
文件 1439 2019-11-17 15:25 RBF案例\__pycache__\RBF.cpython-37.pyc
文件 2611 2019-11-28 10:54 RBF案例\__pycache__\RBF_kmeans.cpython-37.pyc
文件 2175 2019-11-28 11:06 RBF案例\__pycache__\RBF_Supervised.cpython-37.pyc
评论
共有 条评论