
Lecture 5: Convolutional Neural Networks - AlexNet8 on CIFAR-10

import tensorflow as tf
import os
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
from tensorflow.keras import Model


np.set_printoptions(threshold=np.inf)  # print full arrays (needed for the weights dump below)

cifar10 = tf.keras.datasets.cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0  # normalize pixel values to [0, 1]
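A quick shape check is a useful sanity step here (optional; the expected values are the standard CIFAR-10 dimensions):

print(x_train.shape, y_train.shape)  # (50000, 32, 32, 3) (50000, 1)
print(x_test.shape, y_test.shape)    # (10000, 32, 32, 3) (10000, 1)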


class AlexNet8(Model):
  def __init__(self):
    super(AlexNet8, self).__init__()
    # Block 1: Conv -> BatchNorm -> ReLU -> MaxPool
    self.c1 = Conv2D(filters=96, kernel_size=(3, 3))
    self.b1 = BatchNormalization()
    self.a1 = Activation('relu')
    self.p1 = MaxPool2D(pool_size=(3, 3), strides=2)

    # Block 2: Conv -> BatchNorm -> ReLU -> MaxPool
    self.c2 = Conv2D(filters=256, kernel_size=(3, 3))
    self.b2 = BatchNormalization()
    self.a2 = Activation('relu')
    self.p2 = MaxPool2D(pool_size=(3, 3), strides=2)

    # Blocks 3-5: three stacked convolutions, then one final MaxPool
    self.c3 = Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu')
    self.c4 = Conv2D(filters=384, kernel_size=(3, 3), padding='same', activation='relu')
    self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding='same', activation='relu')
    self.p3 = MaxPool2D(pool_size=(3, 3), strides=2)

    self.flatten = Flatten()
    self.f1 = Dense(2048, activation='relu')
    self.d1 = Dropout(0.5)
    self.f2 = Dense(2048, activation='relu')
    self.d2 = Dropout(0.5)
    self.f3 = Dense(10, activation='softmax')

  def call(self, x):
    x = self.c1(x)
    x = self.b1(x)
    x = self.a1(x)
    x = self.p1(x)

    x = self.c2(x)
    x = self.b2(x)
    x = self.a2(x)
    x = self.p2(x)

    x = self.c3(x)
    x = self.c4(x)
    x = self.c5(x)
    x = self.p3(x)

    x = self.flatten(x)
    x = self.f1(x)
    x = self.d1(x)
    x = self.f2(x)
    x = self.d2(x)
    y = self.f3(x)
    return y


model = AlexNet8()
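Before compiling, a one-batch forward pass can verify that the layer stack maps a 32x32x3 image to a 10-way probability vector (an optional sanity check, not part of the original script; the random input is purely illustrative):

print(model(tf.random.normal((1, 32, 32, 3))).shape)  # expected: (1, 10)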

model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

checkpoint_save_path = "./checkpoint/Baseline.ckpt"
# Resume training from an existing checkpoint if one is found
if os.path.exists(checkpoint_save_path + ".index"):
  print("--------------------load the model-----------------")
  model.load_weights(checkpoint_save_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)  # keeps only the checkpoint with the best val_loss

history = model.fit(x_train, y_train, batch_size=32, epochs=100,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

model.summary()
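Because save_weights_only=True stores only the weights rather than the full model, a later session has to rebuild the architecture before restoring the checkpoint. A minimal sketch (eval_model is an illustrative name; it reuses the class and path defined above):

eval_model = AlexNet8()
eval_model.compile(optimizer='adam',
                   loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                   metrics=['sparse_categorical_accuracy'])
eval_model.load_weights(checkpoint_save_path)
eval_model.evaluate(x_test, y_test, verbose=2)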



# Dump all trainable variables (names, shapes, values) to a text file
with open('./weights.txt', 'w') as file:
  for v in model.trainable_variables:
    file.write(str(v.name) + '\n')
    file.write(str(v.shape) + '\n')
    file.write(str(v.numpy()) + '\n')



def plot_acc_loss_curve(history):
    # Plot training/validation accuracy and loss curves
    from matplotlib import pyplot as plt
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    #plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    #plt.grid()
    plt.show()

plot_acc_loss_curve(history)
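On a headless machine (an assumption about the runtime environment), the figure can be written to disk instead of displayed; replace the plt.show() call inside plot_acc_loss_curve with:

plt.savefig('./acc_loss_curves.png', dpi=150)  # hypothetical output path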



Source: https://www.cnblogs.com/wbloger/p/12853766.html
