torch.utils.data.DataLoader()
:构建可迭代的数据装载器, 训练的时候,每一个for循环,每一次iteration,就是从DataLoader中获取一个batch_size大小的数据的。
Dataloader()参数:
torch.utils.data.Dataset()
:Dataset抽象类, 所有自定义的Dataset都需要继承它,并且必须复写__getitem__()
这个类方法。
__getitem__
方法的是Dataset的核心,作用是接收一个索引, 返回一个样本, 看上面的函数,参数里面接收index,然后我们需要编写究竟如何根据这个索引去读取我们的数据部分。
torchvision已经预先实现了常用的Dataset, 其他预先实现的有: torchvision.datasets.CIFAR10
, 可以读取CIFAR-10,以及ImageNet、COCO、MNIST、LSUN等数据集。
ImageFolder假设所有的文件按文件夹保存,每个文件夹下存储同一个类别的图片,文件夹名为类名,其构造函数如下:
ImageFolder(root, transform=None, target_transform=None, loader=default_loader)
参数:
示例:
文件夹格式:
train_path = r‘datasets/myDataSet/train‘
预处理格式:
train_transform = transforms.Compose([
transforms.Resize((40,40)),
transforms.RandomCrop(40,padding=4),
transforms.ToTensor(),
transforms.Normalize([0.485,0.456,0.406],
[0.229,0.224,0.225],)
])
dataset:
trainset = ImageFolder(train_path,transform = train_transform)
# print(trainset[30]) # 元组类型,第30号图片的(像素信息,label)
Data.DataLoader:
train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=False)
for i,(img, target) in enumerate(train_loader):
print(i)
print(img.shape) # (batchsize, channel, H, W)
print(target.shape) # (batch)
print(target) # 一个batch图片对应的label
class myDataset(Data.Dataset):
def __init__(self, path, transform):
self.path = path
self.transform = transform
self.data_info = self.get_img_info(path)
self.label = []
for i in range(len(self.data_info)):
self.label.append(list(self.data_info[i])[1])
def __getitem__(self, idx):
path_img = self.data_info[idx][0]
label = self.label[idx]
img = Image.open(path_img).convert(‘RGB‘) # 0~255
if self.transform is not None:
img = self.transform(img) # 在这里做transform,转为tensor等等
return img, label, idx
def __len__(self):
return len(self.data_info)
@staticmethod
def get_img_info(data_dir):
data_info = list()
for root, dirs, _ in os.walk(data_dir):
# 遍历类别
for sub_dir in dirs:
img_names = os.listdir(os.path.join(root, sub_dir))
img_names = list(filter(lambda x: x.endswith(‘.jpg‘), img_names))
# 遍历图片
for i in range(len(img_names)):
img_name = img_names[i]
path_img = os.path.join(root, sub_dir, img_name)
label = int(sub_dir)
data_info.append((path_img, int(label)))
return data_info
trainset = myDataset(train_path, train_transform)
train_loader = Data.DataLoader(dataset=trainset, batch_size=4,shuffle=True)
for i,(img, target, index) in enumerate(train_loader):
print(i)
print(img.shape) # (batchsize, channel, H, W)
print(target.shape) # (batch)
print(target) # 一个batch的图片对应的label
print(index) # 一个batch的图片在数据集中对应的index
s
Pytorch数据读取机制(DataLoader)与图像预处理模块(transforms)
原文:https://www.cnblogs.com/bin888/p/15036953.html