首页 > Web开发 > 详细

encode与decode

时间:2018-12-28 14:58:43      阅读:152      评论:0      收藏:0      [点我收藏+]
import torch
from torch import nn
import numpy as np
import matplotlib.pyplot as plt
import torch.utils.data as Data
import torchvision
from mpl_toolkits.mplot3d import Axes3D    #画3D图
from matplotlib import cm
# Hyper Parameters
EPOCH = 10               # number of passes over the training set
BATCH_SIZE = 64          # images per training mini-batch
LR = 0.005               # learning rate
DOWNLOAD_MNIST = False   # set to True on the first run, when ./mnist/ does not hold the data yet
N_TEST_IMG = 5           # number of images shown in the reconstruction preview

# MNIST training split; ToTensor() turns each image into a float tensor scaled to [0, 1].
train_data = torchvision.datasets.MNIST(
    root='./mnist/',
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,
)

# Shuffled mini-batches of (image, label) pairs for training.
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)

class AutoEncoder(nn.Module):
    """Fully connected autoencoder for flattened images.

    The encoder maps a ``(batch, input_dim)`` tensor through Tanh-activated
    linear layers ``input_dim -> 128 -> 64 -> latent_dim``. The last layer has
    no activation, so the code is unbounded. The decoder mirrors this as
    ``latent_dim -> 64 -> 128 -> input_dim`` and ends in a Sigmoid, so every
    reconstructed value lies in (0, 1).

    Args:
        input_dim: Number of input features per sample; defaults to 28 * 28
            (a flattened MNIST image).
        latent_dim: Size of the code produced by the encoder. Defaults to 12,
            the original bottleneck size; pass 3 to obtain a code that can be
            plotted directly in 3-D.
    """

    def __init__(self, input_dim: int = 28 * 28, latent_dim: int = 12):
        # Zero-argument super(): robust even if the global name ``AutoEncoder``
        # is later rebound (e.g. to an instance of this class).
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.Tanh(),
            nn.Linear(128, 64),
            nn.Tanh(),
            nn.Linear(64, latent_dim),  # linear bottleneck: no activation on the code
        )
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim, 64),
            nn.Tanh(),
            nn.Linear(64, 128),
            nn.Tanh(),
            nn.Linear(128, input_dim),
            nn.Sigmoid(),  # squash the reconstruction into (0, 1)
        )

    def forward(self, x):
        """Encode ``x`` and decode the code back to input space.

        Args:
            x: Input tensor whose last dimension is ``input_dim``.

        Returns:
            A ``(encoded, decoded)`` tuple: the latent code and the reconstruction.
        """
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded


# Keep the instance under its own name so the ``AutoEncoder`` class is not shadowed.
autoencoder = AutoEncoder()

optimizer = torch.optim.Adam(autoencoder.parameters(), lr=LR)  # optimize all autoencoder parameters
loss_func = nn.MSELoss()  # pixel-wise reconstruction error

# 2 x N_TEST_IMG grid: top row shows originals, bottom row shows reconstructions.
f, a = plt.subplots(2, N_TEST_IMG, figsize=(5, 2))

plt.ion()  # interactive mode so the figure can be refreshed during training

# First N_TEST_IMG training images, flattened to 28*28 pixels and scaled to
# [0, 1] by dividing by 255.
view_data = train_data.data[:N_TEST_IMG].view(-1, 28 * 28).float() / 255

for i in range(N_TEST_IMG):
    a[0][i].imshow(np.reshape(view_data.numpy()[i], (28, 28)), cmap='gray')
    a[0][i].set_xticks(())
    a[0][i].set_yticks(())

for epoch in range(EPOCH):
    for step, (x, _) in enumerate(train_loader):
        b_x = x.view(-1, 28 * 28)  # flattened input batch
        b_y = x.view(-1, 28 * 28)  # reconstruction target is the input itself

        encoded, decoded = autoencoder(b_x)
        loss = loss_func(decoded, b_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if step % 100 == 0:
            print('Epoch: %d | train loss: %.4f' % (epoch, loss.item()))
            # Reconstruct the fixed preview images shown in the top row (not the
            # current training batch); no gradients are needed for display.
            with torch.no_grad():
                _, decoded_data = autoencoder(view_data)
            for i in range(N_TEST_IMG):
                a[1][i].clear()
                a[1][i].imshow(np.reshape(decoded_data.numpy()[i], (28, 28)), cmap='gray')
                a[1][i].set_xticks(())
                a[1][i].set_yticks(())
            plt.draw()
            plt.pause(0.05)

plt.ioff()
plt.show()

# Encode the first 200 training images and place each one in 3-D at its first
# three latent coordinates, labelled and coloured by its digit class. If the
# encoder produces more than three dimensions, this shows only those first
# three axes.
view_data = train_data.data[:200].view(-1, 28 * 28).float() / 255
with torch.no_grad():
    encoded_data, _ = autoencoder(view_data)
fig = plt.figure(2)
# add_subplot(projection='3d') attaches the axes to the figure on all matplotlib
# versions; Axes3D(fig) no longer does so on matplotlib >= 3.4.
ax = fig.add_subplot(111, projection='3d')
X, Y, Z = encoded_data[:, 0].numpy(), encoded_data[:, 1].numpy(), encoded_data[:, 2].numpy()
values = train_data.targets[:200].numpy()
for x, y, z, s in zip(X, Y, Z, values):
    c = cm.rainbow(int(255 * s / 9))  # spread digits 0-9 across the colormap
    ax.text(x, y, z, s, backgroundcolor=c)
ax.set_xlim(X.min(), X.max())
ax.set_ylim(Y.min(), Y.max())
ax.set_zlim(Z.min(), Z.max())
plt.show()

从训练集中选出前五张图片做测试。

图像以2行×5列显示:上面一行是原始图像,下面一行是经过编码再解码后的重建图像。

encode与decode

原文:https://www.cnblogs.com/wmy-ncut/p/10190482.html

(0)
(0)
   
举报
评论 一句话评论(0)
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!