
TensorFlow study notes 7 ---------- RNN

Posted: 2018-08-15 20:07:10

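Unlike an ordinary feed-forward neural network, an RNN works on data whose consecutive elements (time steps of a sequence) are related: the computation at each step depends on the steps before it. The input data therefore has to be arranged as a sequence.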

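1. Arrange the data as a sequence;

2. Feed input #1 through the hidden layer and then into the RNN cell. The RNN cell produces two results: one is used to produce the classification output, the other is passed on to the RNN cell that processes input #2;

3. Repeat until all of the data has been processed.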
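The three steps above can be sketched as an unrolled loop. This is a minimal illustration, not the TensorFlow model used below: it assumes a plain tanh RNN cell with a scalar input and a scalar state, and the weights `w_x`, `w_h`, `b`, `w_y` are hypothetical values chosen only for the demo (the model in this note uses an LSTM cell instead).

```python
import math

def run_rnn(sequence, w_x=0.5, w_h=0.8, b=0.0, w_y=1.0):
    """Unroll a scalar tanh RNN over `sequence` (hypothetical demo weights)."""
    h = 0.0            # initial state, before input #1
    outputs = []
    for x_t in sequence:
        # New state mixes the current input with the state from the previous step
        h = math.tanh(w_x * x_t + w_h * h + b)
        # Result 1: a per-step output (what a classifier would read)
        outputs.append(w_y * h)
        # Result 2: the state h is carried into the next loop iteration
    return outputs, h

outs, final_state = run_rnn([1.0, 0.0, 0.0])
print(len(outs), outs[-1] == final_state)  # 3 True
```

Steps 2 and 3 receive zero input, yet their outputs are non-zero: that is the information passed along through the state.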

 

Import the packages and load the MNIST data

import tensorflow as tf  # this note uses the TensorFlow 1.x graph API
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import matplotlib.pyplot as plt
print ("Packages imported")

# Download (if needed) and load MNIST; labels are one-hot vectors
mnist = input_data.read_data_sets("data/", one_hot=True)
trainimgs, trainlabels, testimgs, testlabels = mnist.train.images, mnist.train.labels, mnist.test.images, mnist.test.labels
# Number of training/test examples, flattened image size (784) and number of classes (10)
ntrain, ntest, dim, nclasses = trainimgs.shape[0], testimgs.shape[0], trainimgs.shape[1], trainlabels.shape[1]
print ("MNIST loaded")

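Each 28*28-pixel image is turned into 28 pieces of data: its 28 rows become 28 time steps, each a 28-pixel input vector. The hidden layer has 128 neurons. Then define the weights and biases: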

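diminput  = 28         # pixels per row = input size at each time step
dimhidden = 128        # neurons in the hidden layer (also the LSTM size)
dimoutput = nclasses   # 10 digit classes
nsteps    = 28         # rows per image = number of time steps
weights = {
    'hidden': tf.Variable(tf.random_normal([diminput, dimhidden])),
    'out': tf.Variable(tf.random_normal([dimhidden, dimoutput]))
}
biases = {
    'hidden': tf.Variable(tf.random_normal([dimhidden])),
    'out': tf.Variable(tf.random_normal([dimoutput]))
}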
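Turning a 28*28 image into 28 time steps is just a row-major reshape: time step r is row r of the image. A pure-Python sketch on one image (the helper `image_to_steps` is hypothetical; the training loop below does the same thing for a whole batch with NumPy's `reshape`):

```python
def image_to_steps(flat_pixels, nsteps=28, diminput=28):
    """Split one flat image (nsteps*diminput values, row-major) into nsteps rows."""
    assert len(flat_pixels) == nsteps * diminput
    # Row r (time step r) holds pixels r*diminput .. r*diminput + diminput - 1,
    # the same row-major order that numpy's reshape((nsteps, diminput)) uses
    return [flat_pixels[r * diminput:(r + 1) * diminput] for r in range(nsteps)]

flat = list(range(784))           # stand-in for one image: pixel i has value i
steps = image_to_steps(flat)
print(len(steps), len(steps[0]))  # 28 28
print(steps[1][:3])               # [28, 29, 30]
```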

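Define the RNN function: rearrange the input data, compute the hidden layer, split the hidden-layer output into one slice per time step, and compute the two results the RNN produces (the outputs and the state). The prediction is computed from the last element of LSTM_O, i.e. the output of the final time step.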

def _RNN(_X, _W, _b, _nsteps, _name):
    # 1. Permute input from [batchsize, nsteps, diminput]
    #    => [nsteps, batchsize, diminput]
    _X = tf.transpose(_X, [1, 0, 2])
    # 2. Reshape input to [nsteps*batchsize, diminput]
    _X = tf.reshape(_X, [-1, diminput])
    # 3. Input layer => Hidden layer
    _H = tf.matmul(_X, _W['hidden']) + _b['hidden']
    # 4. Split data into 'nsteps' chunks along axis 0. The i-th chunk holds the
    #    i-th time step of every example in the batch: [batchsize, dimhidden]
    #    (TF 1.x argument order; very old code wrote tf.split(0, _nsteps, _H))
    _Hsplit = tf.split(_H, _nsteps, 0)
    # 5. Get LSTM's outputs (_LSTM_O, one per time step) and final state (_LSTM_S)
    #    Each output in _LSTM_O has 'batchsize' rows
    #    Only _LSTM_O[-1] will be used to predict the output.
    with tf.variable_scope(_name):
        # No reuse_variables() here: the LSTM weights are created on this first call
        lstm_cell = tf.nn.rnn_cell.BasicLSTMCell(dimhidden, forget_bias=1.0)
        _LSTM_O, _LSTM_S = tf.nn.static_rnn(lstm_cell, _Hsplit, dtype=tf.float32)
    # 6. Output layer applied to the last time step's output
    _O = tf.matmul(_LSTM_O[-1], _W['out']) + _b['out']
    # Return the intermediate tensors as well, for inspection
    return {
        'X': _X, 'H': _H, 'Hsplit': _Hsplit,
        'LSTM_O': _LSTM_O, 'LSTM_S': _LSTM_S, 'O': _O
    }
print ("Network ready")
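Steps 1, 2 and 4 of `_RNN` only rearrange data. The pure-Python sketch below uses nested lists and a tiny made-up batch (2 examples, 3 time steps, 1 value per step, value = 10*example + step); the helper name is hypothetical and the hidden-layer matmul of step 3 is skipped. It shows why the i-th chunk of `_Hsplit` holds time step i of every example, which is the per-time-step list that the RNN call in step 5 consumes.

```python
def transpose_and_split(batch):
    """batch: [batchsize][nsteps][diminput] -> list of nsteps chunks, each [batchsize][diminput]."""
    batchsize, nsteps = len(batch), len(batch[0])
    # Step 1: transpose [batchsize, nsteps, diminput] -> [nsteps, batchsize, diminput]
    transposed = [[batch[b][t] for b in range(batchsize)] for t in range(nsteps)]
    # Step 2: reshape to [nsteps*batchsize, diminput]; rows are now grouped by time step
    flat_rows = [row for step in transposed for row in step]
    # (Step 3, the hidden-layer matmul, would act on flat_rows row by row; skipped)
    # Step 4: split axis 0 into nsteps equal chunks of batchsize rows
    return [flat_rows[t * batchsize:(t + 1) * batchsize] for t in range(nsteps)]

batch = [[[10 * b + t] for t in range(3)] for b in range(2)]
chunks = transpose_and_split(batch)
print(chunks[0])   # [[0], [10]]  (time step 0 of example 0 and of example 1)
print(chunks[2])   # [[2], [12]]
```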

With the RNN defined, set up the placeholders, loss function, optimizer and accuracy:

learning_rate = 0.001
x      = tf.placeholder("float", [None, nsteps, diminput])
y      = tf.placeholder("float", [None, dimoutput])
myrnn  = _RNN(x, weights, biases, nsteps, 'basic')
pred   = myrnn['O']
# TF 1.x only accepts named arguments (labels=..., logits=...) here
cost   = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=pred))
optm   = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost) # plain gradient descent (tf.train.AdamOptimizer is an alternative)
accr   = tf.reduce_mean(tf.cast(tf.equal(tf.argmax(pred,1), tf.argmax(y,1)), tf.float32))
init   = tf.global_variables_initializer()
print ("Network Ready!")
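For reference, here is what `cost` and `accr` compute, as a pure-Python sketch with made-up logits for 2 examples and 3 classes (all helper names are hypothetical). It averages the per-example cross-entropy between softmax(logits) and the one-hot label, and counts how often the argmax of the logits matches the argmax of the label.

```python
import math

def softmax_cross_entropy(logits, one_hot):
    """Cross-entropy between softmax(logits) and a one-hot label, for one example."""
    m = max(logits)  # subtract the max before exp() for numerical stability
    log_sum = m + math.log(sum(math.exp(z - m) for z in logits))
    # -sum_k y_k * log softmax(z)_k, where log softmax(z)_k = z_k - logsumexp(z)
    return -sum(y * (z - log_sum) for z, y in zip(logits, one_hot))

def argmax(v):
    return max(range(len(v)), key=lambda k: v[k])

def mean_cost_and_accuracy(logits_batch, labels_batch):
    """Mean cross-entropy (like `cost`) and fraction of argmax matches (like `accr`)."""
    costs = [softmax_cross_entropy(p, y) for p, y in zip(logits_batch, labels_batch)]
    hits = [1.0 if argmax(p) == argmax(y) else 0.0 for p, y in zip(logits_batch, labels_batch)]
    return sum(costs) / len(costs), sum(hits) / len(hits)

# Example 1 is classified correctly (class 0), example 2 is not (predicts 1, label 2)
logits_demo = [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0]]
labels_demo = [[1, 0, 0], [0, 0, 1]]
demo_cost, demo_acc = mean_cost_and_accuracy(logits_demo, labels_demo)
print(demo_acc)   # 0.5
```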

Train the network:

training_epochs = 5
batch_size      = 16
display_step    = 1
sess = tf.Session()
sess.run(init)
print ("Start optimization")
for epoch in range(training_epochs):
    avg_cost = 0.
    total_batch = int(mnist.train.num_examples/batch_size)
 
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        batch_xs = batch_xs.reshape((batch_size, nsteps, diminput))
        # Fit training using batch data
        feeds = {x: batch_xs, y: batch_ys}
        sess.run(optm, feed_dict=feeds)
        # Compute average loss
        avg_cost += sess.run(cost, feed_dict=feeds)/total_batch
    # Display logs per epoch step
    if epoch % display_step == 0:
        print ("Epoch: %03d/%03d cost: %.9f" % (epoch, training_epochs, avg_cost))
        # Training accuracy, measured on the last mini-batch only
        feeds = {x: batch_xs, y: batch_ys}
        train_acc = sess.run(accr, feed_dict=feeds)
        print (" Training accuracy: %.3f" % (train_acc))
        # Test accuracy on the whole test set, reshaped to [ntest, nsteps, diminput]
        testimgs = testimgs.reshape((ntest, nsteps, diminput))
        feeds = {x: testimgs, y: testlabels}
        test_acc = sess.run(accr, feed_dict=feeds)
        print (" Test accuracy: %.3f" % (test_acc))
print ("Optimization Finished.")

 


Original: https://www.cnblogs.com/xxp17457741/p/9483514.html
