首页 > 其他 > 详细

第四讲 网络八股拓展--用mnist数据集实现断点续训

时间:2020-05-07 01:53:20      阅读:47      评论:0      收藏:0      [点我收藏+]
 1 import tensorflow as tf
 2 import os
 3 
 4 
 5 mnist = tf.keras.datasets.mnist
 6 (x_train, y_train), (x_test, y_test) = mnist.load_data()
 7 x_train, x_test = x_train/255.0, x_test/255.0
 8 
 9 
10 model = tf.keras.models.Sequential([
11         tf.keras.layers.Flatten(),
12         tf.keras.layers.Dense(128, activation=relu),
13         tf.keras.layers.Dense(10, activation=softmax)
14 ])
15 
16 model.compile(optimizer=adam,
17                loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
18                metrics=[sparse_categorical_accuracy])
19 
20 
21 
22 checkpoint_save_path = "./checkpoint/mnist.ckpt"
23 if os.path.exists(checkpoint_save_path + ".index"):
24   print("-----------------load the model-----------------------")
25   model.load_weights(checkpoint_save_path)
26 
27 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, 
28                                                  save_weights_only=True,
29                                                  save_best_only=True)
30 
31 history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), 
32                     validation_freq=1, callbacks=[cp_callback])
33 model.summary()

 

第四讲 网络八股拓展--用mnist数据集实现断点续训

原文:https://www.cnblogs.com/wbloger/p/12839556.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!