首页 > 其他 > 详细

Keras猫狗大战三:加载模型,预测目录中图片,画混淆矩阵

时间:2019-06-22 20:31:09      阅读:201      评论:0      收藏:0      [点我收藏+]

版权声明:本文为博主原创文章,欢迎转载,并请注明出处。联系方式:460356155@qq.com

 一、加载模型,预测测试集

%matplotlib inline
import matplotlib.pyplot as plt

import os
import itertools
import cv2

import numpy as np
from sklearn.metrics import confusion_matrix

from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model

dst_path = rD:\BaiduNetdiskDownload\small
model_file = r"D:\fastai\projects\cats_and_dogs_small_1.h5"
test_dir = os.path.join(dst_path, test)

batch_size = 20

model = load_model(model_file)

test_datagen = ImageDataGenerator(rescale=1. / 255)

test_generator = test_datagen.flow_from_directory(
    test_dir,
    target_size=(150, 150),
    batch_size=batch_size,
    class_mode=binary)

test_loss, test_acc = model.evaluate_generator(test_generator, steps=test_generator.samples / batch_size)
print(test acc: %.3f%% % test_acc)
Found 400 images belonging to 2 classes.
test acc: 0.747%

二、预测测试集,画混淆矩阵
def get_input_xy(src=[]):
    pre_x = []
    true_y = []

    class_indices = {cat: 0, dog: 1}

    for s in src:
        input = cv2.imread(s)
        input = cv2.resize(input, (150, 150))
        input = cv2.cvtColor(input, cv2.COLOR_BGR2RGB)
        pre_x.append(input)

        _, fn = os.path.split(s)
        y = class_indices.get(fn[:3])
        true_y.append(y)

    pre_x = np.array(pre_x) / 255.0

    return pre_x, true_y


def plot_sonfusion_matrix(cm, classes, normalize=False, title=Confusion matrix, cmap=plt.cm.Blues):
    plt.imshow(cm, interpolation=nearest, cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype(float) / cm.sum(axis=1)[:, np.newaxis]

    thresh = cm.max() / 2.0
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j], horizontalalignment=center, color=white if cm[i, j] > thresh else black)

    plt.tight_layout()
    plt.ylabel(True label)
    plt.xlabel(Predict label)


test = os.listdir(test_dir)

images = []

# 获取每张图片的地址,并保存在列表images中
for testpath in test:
    for fn in os.listdir(os.path.join(test_dir, testpath)):
        if fn.endswith(jpg):
            fd = os.path.join(test_dir, testpath, fn)
            images.append(fd)

# 得到规范化图片及true label
pre_x, true_y = get_input_xy(images)

# 预测
pred_y = model.predict_classes(pre_x)

# 画混淆矩阵
confusion_mat = confusion_matrix(true_y, pred_y)
plot_sonfusion_matrix(confusion_mat, classes=range(2))

技术分享图片

 

Keras猫狗大战三:加载模型,预测目录中图片,画混淆矩阵

原文:https://www.cnblogs.com/zhengbiqing/p/11070050.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!