首页 > 其他 > 详细

3.tensorflow——NN

时间:2019-11-15 21:05:48      阅读:86      评论:0      收藏:0      [点我收藏+]
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# Network dimensions.
numClasses = 10        # width of the output layer / one-hot label vector
inputsize = 28 * 28    # number of input features per example (784)
numHiddenUnits = 50    # width of the single hidden layer

# Training schedule.
trainningIterations = 50000  # total number of optimizer steps
batchSize = 64               # examples drawn per step

# 1. Dataset
# read_data_sets takes the data directory as a string; the unquoted `data/`
# was a syntax error. Labels are requested in one-hot form.
mnist = input_data.read_data_sets("data/", one_hot=True)
############################################################
# 2. Train
# Graph inputs: X holds one row of `inputsize` features per example,
# y holds the matching label vectors of length `numClasses`.
X = tf.placeholder(tf.float32, shape=[None, inputsize])
y = tf.placeholder(tf.float32, shape=[None, numClasses])

# 2.1 Initial parameters
# Hidden layer: hiddenLayerOutput = relu(X * W1 + B1)
W1 = tf.Variable(tf.truncated_normal([inputsize, numHiddenUnits], stddev=0.1))
# The shape must be given to tf.constant. Passed as the second positional
# argument of tf.Variable it is taken as `trainable`, which leaves the bias
# as a single scalar shared by every unit instead of one value per unit.
B1 = tf.Variable(tf.constant(0.1, shape=[numHiddenUnits]))

# Output layer: finalOutput = hiddenLayerOutput * W2 + B2
W2 = tf.Variable(tf.truncated_normal([numHiddenUnits, numClasses], stddev=0.1))
B2 = tf.Variable(tf.constant(0.1, shape=[numClasses]))

# Layers
hiddenLayerOutput = tf.nn.relu(tf.matmul(X, W1) + B1)
# Keep the output layer linear: softmax_cross_entropy_with_logits applies the
# softmax itself and expects unscaled logits, so a ReLU here would clamp every
# negative logit to zero.
finalOutput = tf.matmul(hiddenLayerOutput, W2) + B2

# 2.2 Training setup
# Loss: softmax cross-entropy between labels and network output, averaged
# over the batch.
perExampleLoss = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=finalOutput)
loss = tf.reduce_mean(perExampleLoss)

# Training op: one plain gradient-descent step (learning rate 0.1) on the loss.
optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.1)
opt = optimizer.minimize(loss)

# Accuracy: fraction of examples whose arg-max output index matches the
# arg-max label index.
predictedClass = tf.argmax(finalOutput, 1)
trueClass = tf.argmax(y, 1)
correct_prediction = tf.equal(predictedClass, trueClass)
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# 2.3 Run training
sess = tf.Session()
sess.run(tf.global_variables_initializer())

for step in range(trainningIterations):
    # Draw the next mini-batch and take one optimizer step on it.
    images, labels = mnist.train.next_batch(batchSize)
    feed = {X: images, y: labels}
    sess.run(opt, feed_dict=feed)

    # Only report on every 1000th step.
    if step % 1000 != 0:
        continue
    # Accuracy is measured on the batch that was just used for the update.
    train_accuracy = sess.run(accuracy, feed_dict=feed)
    print("step %d, tarinning accuracy %g" % (step, train_accuracy))

# 2.4 Evaluate accuracy on the test set
# Feed the whole test split rather than a single batch of `batchSize`
# examples, so the reported figure is not a noisy 64-sample estimate.
testAccuracy = sess.run(accuracy, feed_dict={X: mnist.test.images, y: mnist.test.labels})
print("test accuracy %g" % (testAccuracy))

# The session is not used after this point; release its resources.
sess.close()

输出结果:

step 0, tarinning accuracy 0.171875
step 1000, tarinning accuracy 0.84375
step 2000, tarinning accuracy 0.953125
step 3000, tarinning accuracy 0.84375
step 4000, tarinning accuracy 0.953125
step 5000, tarinning accuracy 1
step 6000, tarinning accuracy 0.984375
step 7000, tarinning accuracy 1
step 8000, tarinning accuracy 0.984375
step 9000, tarinning accuracy 1
step 10000, tarinning accuracy 1
step 11000, tarinning accuracy 0.96875
step 12000, tarinning accuracy 1
step 13000, tarinning accuracy 0.96875
step 14000, tarinning accuracy 1
step 15000, tarinning accuracy 0.984375
step 16000, tarinning accuracy 0.953125
step 17000, tarinning accuracy 1
step 18000, tarinning accuracy 1
step 19000, tarinning accuracy 1
step 20000, tarinning accuracy 1
step 21000, tarinning accuracy 1
step 22000, tarinning accuracy 1
step 23000, tarinning accuracy 1
step 24000, tarinning accuracy 1
step 25000, tarinning accuracy 1
step 26000, tarinning accuracy 1
step 27000, tarinning accuracy 1
step 28000, tarinning accuracy 1
step 29000, tarinning accuracy 1
step 30000, tarinning accuracy 1
step 31000, tarinning accuracy 1
step 32000, tarinning accuracy 1
step 33000, tarinning accuracy 1
step 34000, tarinning accuracy 1
step 35000, tarinning accuracy 1
step 36000, tarinning accuracy 1
step 37000, tarinning accuracy 1
step 38000, tarinning accuracy 1
step 39000, tarinning accuracy 1
step 40000, tarinning accuracy 0.984375
step 41000, tarinning accuracy 1
step 42000, tarinning accuracy 1
step 43000, tarinning accuracy 1
step 44000, tarinning accuracy 1
step 45000, tarinning accuracy 1
step 46000, tarinning accuracy 1
step 47000, tarinning accuracy 1
step 48000, tarinning accuracy 1
step 49000, tarinning accuracy 1
test accuracy 0.984375

 

3.tensorflow——NN

原文:https://www.cnblogs.com/yrm1160029237/p/11868992.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!