本示例使用的是 Fashion MNIST 数据集。
import tensorflow as tf
import pandas as pd  # NOTE(review): unused in this script; kept from the original
import numpy as np
import matplotlib.pyplot as plt  # NOTE(review): unused in this script; kept from the original

# Train a small dense classifier on Fashion-MNIST with one-hot labels and
# categorical cross-entropy, evaluate it on the test split, then compare the
# prediction for the first test image with its true label.

NUM_CLASSES = 10  # must match the width of the softmax output layer

# Load the train/test split shipped with Keras.
(train_image, train_label), (test_image, test_label) = (
    tf.keras.datasets.fashion_mnist.load_data()
)

# Scale pixel values from [0, 255] to [0, 1]. Casting to float32 first avoids
# the float64 arrays that a plain `/ 255` would create (Keras computes in
# float32 anyway), halving the memory used by the image arrays.
train_image = train_image.astype("float32") / 255.0
test_image = test_image.astype("float32") / 255.0

# One-hot encode the integer labels to match `categorical_crossentropy`.
# Passing num_classes explicitly guarantees both arrays have exactly
# NUM_CLASSES columns, independent of which labels occur in each split.
train_label_onehot = tf.keras.utils.to_categorical(
    train_label, num_classes=NUM_CLASSES
)
test_label_onehot = tf.keras.utils.to_categorical(
    test_label, num_classes=NUM_CLASSES
)

# Alternative: keep the integer labels and compile with
# loss="sparse_categorical_crossentropy" to skip the one-hot step entirely.

model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=(28, 28)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(NUM_CLASSES, activation="softmax"),
    ]
)
model.compile(
    optimizer="adam",
    loss="categorical_crossentropy",
    metrics=["acc"],
)
model.fit(train_image, train_label_onehot, epochs=5)

# Measure performance on the held-out test split; the one-hot test labels
# exist precisely for this step.
test_loss, test_acc = model.evaluate(test_image, test_label_onehot)
print(f"test loss: {test_loss:.4f}, test accuracy: {test_acc:.4f}")

# Compare the predicted class of the first test image with its true label.
predict = model.predict(test_image)
print(np.argmax(predict[0]))
print(test_label[0])
原文：https://www.cnblogs.com/xhj1074376195/p/14299793.html