首页 > 其他 > 详细

pytorch批训练数据构造

时间:2018-08-11 15:00:50      阅读:184      评论:0      收藏:0      [点我收藏+]

这是对莫凡python的学习笔记。

1.创建数据

import torch
import torch.utils.data as Data

BATCH_SIZE = 8
x = torch.linspace(1,10,10)
y = torch.linspace(10,1,10)

可以看到创建了两个一维数据,x:1~10,y:10~1

2.构造数据集对象,及数据加载器对象

torch_dataset = Data.TensorDataset(x,y)
loader = Data.DataLoader(
            dataset = torch_dataset,
            batch_size = BATCH_SIZE,
            shuffle = False,
            num_workers = 2)

num_workers应该指的是多线程

3.输出数据集,这一步主要是看一下batch长什么样子

for epoch in range(3):
    for step, (batch_x, batch_y) in  enumerate(loader):
        print(Epoch:,epoch,| Step:, step, | batch x:,
                 batch_x.numpy(), | batch y:, batch_y.numpy())

输出如下

(‘Epoch:‘, 0, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
(‘Epoch:‘, 0, ‘| Step:‘, 1, ‘| batch x:‘, array([ 9., 10.], dtype=float32), ‘| batch y:‘, array([2., 1.], dtype=float32))
(‘Epoch:‘, 1, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
(‘Epoch:‘, 1, ‘| Step:‘, 1, ‘| batch x:‘, array([ 9., 10.], dtype=float32), ‘| batch y:‘, array([2., 1.], dtype=float32))
(‘Epoch:‘, 2, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5., 6., 7., 8.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.,  5.,  4.,  3.], dtype=float32))
(‘Epoch:‘, 2, ‘| Step:‘, 1, ‘| batch x:‘, array([ 9., 10.], dtype=float32), ‘| batch y:‘, array([2., 1.], dtype=float32))

可以看到,batch_size等于8,则第二个bacth的数据只有两个。

将batch_size改为5,输出如下

(‘Epoch:‘, 0, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.], dtype=float32))
(‘Epoch:‘, 0, ‘| Step:‘, 1, ‘| batch x:‘, array([ 6.,  7.,  8.,  9., 10.], dtype=float32), ‘| batch y:‘, array([5., 4., 3., 2., 1.], dtype=float32))
(‘Epoch:‘, 1, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.], dtype=float32))
(‘Epoch:‘, 1, ‘| Step:‘, 1, ‘| batch x:‘, array([ 6.,  7.,  8.,  9., 10.], dtype=float32), ‘| batch y:‘, array([5., 4., 3., 2., 1.], dtype=float32))
(‘Epoch:‘, 2, ‘| Step:‘, 0, ‘| batch x:‘, array([1., 2., 3., 4., 5.], dtype=float32), ‘| batch y:‘, array([10.,  9.,  8.,  7.,  6.], dtype=float32))
(‘Epoch:‘, 2, ‘| Step:‘, 1, ‘| batch x:‘, array([ 6.,  7.,  8.,  9., 10.], dtype=float32), ‘| batch y:‘, array([5., 4., 3., 2., 1.], dtype=float32))

 

pytorch批训练数据构造

原文:https://www.cnblogs.com/wzyuan/p/9459744.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!