首页 > 其他 > 详细

torch加载参数

时间:2020-03-17 20:54:55      阅读:65      评论:0      收藏:0      [点我收藏+]
 1 from torch.utils.data import DataLoader
 2 from torchvision import datasets
 3 from PIL import Image as img
 4 
 5 dataPath = ./data/imgs/
 6 
 7 dataset = datasets.ImageFolder(./data/, loader=img.open)
 8 dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=0)
 9 
10 # 方式一
11 for epoch in range(100):
12     for i, (img, _)in enumerate(dataloader):
13         # do training
14 
15 # 方式二
16 
17 def data_gen(data_loader):
18     while True:
19         for (images, _) in enumerate(data_loader):
20             yield images
21 
22 gen_img = data_gen(dataloader)
23 
24 for iter in range(100):
25     imgs = gen_img.__next__()
26     # do training

 

torch加载参数

原文:https://www.cnblogs.com/JunzhaoLiang/p/12513245.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!