PyTorch: Model Inference, Single-Image Inference (Part 2)

Time: 2021-03-30 20:55:51
# coding=utf-8
import os

import cv2
import torch
from PIL import Image
from torchvision import transforms

from vgg import VGG_13  # the author's model definition (not shown in this post)


class Infer(object):

    def __init__(self):
        self.model = VGG_13()
        # map_location="cpu" lets weights saved on a GPU load on a CPU-only machine
        self.model.load_state_dict(torch.load("./models/model_18.pth", map_location="cpu"))
        # Put dropout / batch-norm layers into inference mode
        self.model.eval()

    def _infer(self, img_tensor):
        # Disable gradient tracking: saves memory and time during inference
        with torch.no_grad():
            result = self.model(img_tensor)
        return result

    def predict(self, path):
        img_path_list = [os.path.join(path, x) for x in sorted(os.listdir(path))]
        # Should match the preprocessing used during training
        transform = transforms.Compose([
            transforms.Resize([224, 224]),
            transforms.ToTensor()])  # HWC uint8 in [0, 255] -> CHW float in [0, 1]
        for img_path in img_path_list:
            print(img_path)
            img = cv2.imread(img_path)
            if img is None:  # cv2.imread returns None for unreadable / non-image files
                continue
            # OpenCV loads images as BGR; convert to RGB before wrapping in a PIL image
            img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            img_tensor = transform(img)
            # Add the batch dimension: (3, 224, 224) -> (1, 3, 224, 224)
            img_tensor = img_tensor.unsqueeze(0)
            result = self._infer(img_tensor)
            print(result)  # raw model output for this single image


if __name__ == "__main__":
    path = "./test_images"
    Infer().predict(path)

Original: https://www.cnblogs.com/timelesszxl/p/14598430.html
