#coding= utf-8
"""Run the ``VGG_13`` model from ``vgg`` over every image in a directory.

For each file the raw model output is printed; ``Infer.predict`` also returns
the outputs keyed by image path.
"""
import os
import torch
from torchvision import transforms
from data_pipe import get_data
from vgg import VGG_13
import numpy as np
import cv2
from PIL import Image

# Default checkpoint location, relative to the working directory.
DEFAULT_MODEL_PATH = "./models/model_18.pth"
# Side length (pixels) every image is resized to before inference.
IMAGE_SIZE = 224


class Infer(object):
    """Load a ``VGG_13`` checkpoint and run inference on a folder of images."""

    def __init__(self, model_path=DEFAULT_MODEL_PATH):
        """Build the model, load its weights and switch it to eval mode.

        Args:
            model_path: Path to a state dict for ``VGG_13``.
                Defaults to ``./models/model_18.pth``.
        """
        self.model = VGG_13()
        # The model is never moved to a GPU in this script, so load the weights
        # onto the CPU. Without map_location, a checkpoint saved from CUDA
        # fails to load on a CPU-only machine.
        state_dict = torch.load(model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        self.model.eval()
        # Built once and reused for every image.
        # NOTE(review): there is no Normalize step. This assumes the model was
        # trained on the plain [0, 1] tensors ToTensor produces; confirm
        # against the training pipeline.
        self.transform = transforms.Compose([
            transforms.Resize([IMAGE_SIZE, IMAGE_SIZE]),
            transforms.ToTensor(),
        ])

    def _infer(self, img_tensor):
        """Run one forward pass without tracking gradients."""
        with torch.no_grad():
            return self.model(img_tensor)

    def _load_image(self, img_path):
        """Read ``img_path`` as an RGB PIL image, or return None if unreadable."""
        img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None rather than raising.
            # This happens for sub-directories and non-image files.
            return None
        # OpenCV decodes to BGR channel order; PIL/torchvision expect RGB.
        return Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))

    def predict(self, path):
        """Print and return the model output for every image in ``path``.

        Args:
            path: Directory whose entries are run through the model, in
                sorted filename order.

        Returns:
            dict mapping each successfully decoded image path to the model
            output for that image. Entries that cannot be read as images are
            reported and skipped.
        """
        results = {}
        for name in sorted(os.listdir(path)):
            img_path = os.path.join(path, name)
            print(img_path)
            img = self._load_image(img_path)
            if img is None:
                print("skipping unreadable image: %s" % img_path)
                continue
            # Add the batch dimension: (3, H, W) -> (1, 3, H, W).
            img_tensor = self.transform(img).unsqueeze(0)
            result = self._infer(img_tensor)
            print(result)
            results[img_path] = result
        return results


if __name__ == "__main__":
    path = "./test_images"
    Infer().predict(path)
原文:https://www.cnblogs.com/timelesszxl/p/14598430.html