首页 > 其他 > 详细

rnn-tf代码详解

时间:2019-03-08 15:41:37      阅读:295      评论:0      收藏:0      [点我收藏+]

手写数字识别经典案例,旨在熟悉RNN结构,掌握tf编写RNN的方法。

# coding:utf-8
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

mnist=input_data.read_data_sets("./data",one_hot=True)

# 常规参数
train_rate=0.001
train_step=10000
batch_size=1280
display_step=100

# rnn参数
frame_size=28 # 输入特征数
sequence_length=28 # 输入个数
hidden_num=100 # 隐层神经元个数
n_classes=10

# 定义输入,输出
# 此处输入格式是样本数*特征数,特征是把图片拉成一维的,当然一维还是二维自己定,改成相应的代码就行了
x=tf.placeholder(dtype=tf.float32,shape=[None,sequence_length*frame_size],name="inputx")
y=tf.placeholder(dtype=tf.float32,shape=[None,n_classes],name="expected_y")

# 定义权值
# 注意权值设定只设定v, u和w无需设定
weights=tf.Variable(tf.truncated_normal(shape=[hidden_num,n_classes])) # 全连接层权重
bias=tf.Variable(tf.zeros(shape=[n_classes]))

def RNN(x,weights,bias):
x=tf.reshape(x,shape=[-1,sequence_length,frame_size]) # 3维
rnn_cell=tf.nn.rnn_cell.BasicRNNCell(hidden_num)
init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size])

# 其实这是一个深度RNN网络,对于每一个长度为n的序列[x1,x2,x3,...,xn]的每一个xi,都会在深度方向跑一遍RNN,跑上hidden_num个隐层单元
output,states=tf.nn.dynamic_rnn(rnn_cell,x,dtype=tf.float32)

return tf.nn.softmax(tf.matmul(output[:,-1,:],weights)+bias,1) # y=softmax(vh+c)

predy=RNN(x,weights,bias)

# 以下所有神经网络大同小异
cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predy,labels=y))
train=tf.train.AdamOptimizer(train_rate).minimize(cost)

correct_pred=tf.equal(tf.argmax(predy,1),tf.argmax(y,1))
accuracy=tf.reduce_mean(tf.to_float(correct_pred))

sess=tf.Session()
sess.run(tf.global_variables_initializer())
step=1
testx,testy=mnist.test.next_batch(batch_size)
while step<train_step:
batch_x,batch_y=mnist.train.next_batch(batch_size)
_loss,__=sess.run([cost,train],feed_dict={x:batch_x,y:batch_y})
if step % display_step ==0:
acc,loss=sess.run([accuracy,cost],feed_dict={x:testx,y:testy})
print(step,acc,loss)

step+=1

 

这是最简单的RNN,后面还有非常非常非常复杂的在等你。

rnn-tf代码详解

原文:https://www.cnblogs.com/yanshw/p/10495745.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!