首页 > 其他 > 详细

第五讲 卷积神经网络 VGG16 cifar10

时间:2020-05-09 22:01:24      阅读:99      评论:0      收藏:0      [点我收藏+]
  1 import tensorflow as tf
  2 import os
  3 import numpy as np
  4 from matplotlib import pyplot as plt
  5 from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Dropout, Flatten, Dense
  6 from tensorflow.keras import Model
  7 
  8 
  9 np.set_printoptions(threshold=np.inf)
 10 
 11 cifar10 = tf.keras.datasets.cifar10
 12 (x_train, y_train), (x_test, y_test) = cifar10.load_data()
 13 x_train, x_test = x_train/255.0, x_test/255.0
 14 
 15 
 16 class VGG16(Model):
 17   def __init__(self):
 18     super(VGG16, self).__init__()
 19     self.c1 = Conv2D(filters=64, kernel_size=(3, 3), padding=same)
 20     self.b1 = BatchNormalization()
 21     self.a1 = Activation(relu)
 22     self.c2 = Conv2D(filters=64, kernel_size=(3, 3), padding=same)
 23     self.b2 = BatchNormalization()
 24     self.a2 = Activation(relu)
 25     self.p1 = MaxPool2D(pool_size = (2, 2), strides=2, padding=same)
 26     self.d1 = Dropout(0.2)
 27 
 28     self.c3 = Conv2D(filters=128, kernel_size=(3, 3), padding=same)
 29     self.b3 = BatchNormalization()
 30     self.a3 = Activation(relu)
 31     self.c4 = Conv2D(filters=128, kernel_size=(3, 3), padding=same)
 32     self.b4 = BatchNormalization()
 33     self.a4 = Activation(relu)
 34     self.p2 = MaxPool2D(pool_size=(2,2), strides=2, padding=same)
 35     self.d2 = Dropout(0.2)
 36 
 37     self.c5 = Conv2D(filters=256, kernel_size=(3, 3), padding=same)
 38     self.b5 = BatchNormalization()
 39     self.a5 = Activation(relu)
 40     self.c6 = Conv2D(filters=256, kernel_size=(3, 3), padding=same)
 41     self.b6 = BatchNormalization()
 42     self.a6 = Activation(relu)
 43     self.c7 = Conv2D(filters=256, kernel_size=(3, 3), padding=same)
 44     self.b7 = BatchNormalization()
 45     self.a7 = Activation(relu)
 46     self.p3 = MaxPool2D(pool_size=(2, 2), strides=2, padding=same)
 47     self.d3 = Dropout(0.2)
 48 
 49     self.c8 = Conv2D(filters=512, kernel_size=(3, 3), padding=same)
 50     self.b8 = BatchNormalization()
 51     self.a8 = Activation(relu)
 52     self.c9 = Conv2D(filters=512, kernel_size=(3, 3), padding=same)
 53     self.b9 = BatchNormalization()
 54     self.a9 = Activation(relu)
 55     self.c10 = Conv2D(filters=512, kernel_size=(3, 3), padding=same)
 56     self.b10 = BatchNormalization()
 57     self.a10 = Activation(relu)
 58     self.p4 = MaxPool2D(pool_size=(2, 2), strides=2, padding=same)
 59     self.d4 = Dropout(0.2)
 60 
 61     self.c11 = Conv2D(filters=512, kernel_size=(3, 3), padding=same)
 62     self.b11 = BatchNormalization()
 63     self.a11 = Activation(relu)
 64     self.c12 = Conv2D(filters=512, kernel_size=(3, 3), padding=same)
 65     self.b12 = BatchNormalization()
 66     self.a12 = Activation(relu)
 67     self.c13 = Conv2D(filters=512, kernel_size=(3, 3), padding=same)
 68     self.b13 = BatchNormalization()
 69     self.a13 = Activation(relu)
 70     self.p5 = MaxPool2D(pool_size=(2,2), strides=2, padding=same)
 71     self.d5 = Dropout(0.2)
 72 
 73     self.flatten = Flatten()
 74     self.f1 = Dense(512, activation=relu)
 75     self.d6 = Dropout(0.2)
 76     self.f2 = Dense(512, activation=relu)
 77     self.d7 = Dropout(0.7)
 78     self.f3 = Dense(10, activation=softmax)
 79 
 80   def call(self, x):
 81     x = self.c1(x)
 82     x = self.b1(x)
 83     x = self.a1(x)
 84     x = self.c2(x)
 85     x = self.b2(x)
 86     x = self.a2(x)
 87     x = self.p1(x)
 88     x = self.d1(x)
 89 
 90     x = self.c3(x)
 91     x = self.b3(x)
 92     x = self.a3(x)
 93     x = self.c4(x)
 94     x = self.b4(x)
 95     x = self.a4(x)
 96     x = self.p2(x)
 97     x = self.d2(x)
 98 
 99     x = self.c5(x)
100     x = self.b5(x)
101     x = self.a5(x)
102     x = self.c6(x)
103     x = self.b6(x)
104     x = self.a6(x)
105     x = self.c7(x)
106     x = self.b7(x)
107     x = self.a7(x)
108     x = self.p3(x)
109     x = self.d3(x)
110 
111     x = self.c8(x)
112     x = self.b8(x)
113     x = self.a8(x)
114     x = self.c9(x)
115     x = self.b9(x)
116     x = self.a9(x)
117     x = self.c10(x)
118     x = self.b10(x)
119     x = self.a10(x)
120     x = self.p4(x)
121     x = self.d4(x)
122 
123     x = self.c11(x)
124     x = self.b11(x)
125     x = self.a11(x)
126     x = self.c12(x)
127     x = self.b12(x)
128     x = self.a12(x)
129     x = self.c13(x)
130     x = self.b13(x)
131     x = self.a13(x)
132     x = self.p5(x)
133     x = self.d5(x)
134 
135     x = self.flatten(x)
136     x = self.f1(x)
137     x = self.d6(x)
138     x = self.f2(x)
139     x = self.d7(x)
140     y = self.f3(x)
141     return y
142 
143 
model = VGG16()

# Labels are integer class ids (0-9), hence the sparse loss/metric; the model
# already emits softmax probabilities, so from_logits=False.
# Fix vs. original: 'adam' and 'sparse_categorical_accuracy' were bare
# (undefined) names; Keras expects string identifiers here.
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=['sparse_categorical_accuracy'])

# Resume training from a previous checkpoint if one exists; the '.index' file
# marks a complete TF-format checkpoint on disk.
checkpoint_save_path = "./checkpoint/VGG16.ckpt"
if os.path.exists(checkpoint_save_path + '.index'):
  print('------------------------load the model---------------------')
  model.load_weights(checkpoint_save_path)

# Save weights only (not the full model), and only when validation improves.
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
                                                 save_weights_only=True,
                                                 save_best_only=True)

history = model.fit(x_train, y_train, batch_size=128, epochs=50,
                    validation_data=(x_test, y_test), validation_freq=1,
                    callbacks=[cp_callback])

model.summary()

# Dump every trainable variable (name, shape, values) to a text file for
# inspection; relies on the np.set_printoptions(threshold=np.inf) set above.
with open('./weights.txt', 'w') as f:
  for v in model.trainable_variables:
    f.write(str(v.name) + '\n')
    f.write(str(v.shape) + '\n')
    f.write(str(v.numpy()) + '\n')
172 
def plot_acc_loss_curve(history):
    """Plot training/validation accuracy and loss curves side by side.

    Args:
      history: the Keras History object returned by model.fit(); its
        .history dict is keyed by the metric names used in model.compile().

    Fix vs. original: all dictionary keys, labels and titles were bare
    (undefined) names; they must be string literals.
    """
    from matplotlib import pyplot as plt
    acc = history.history['sparse_categorical_accuracy']
    val_acc = history.history['val_sparse_categorical_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    plt.figure(figsize=(15, 5))
    plt.subplot(1, 2, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.legend()
    #plt.grid()

    plt.subplot(1, 2, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.legend()
    #plt.grid()
    plt.show()

plot_acc_loss_curve(history)

 

第五讲 卷积神经网络 VGG16 cifar10

原文:https://www.cnblogs.com/wbloger/p/12860137.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!