import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import os
device = ('cuda:1' if torch.cuda.is_available() else 'cpu')
# device = ('cpu')
# Training settings
batch_size = 64
root = 'pytorch-master/mnist_data'
train_dataset = datasets.MNIST(root=root,
                               train=True,
                               transform=transforms.ToTensor(),
                               download=True)
test_dataset = datasets.MNIST(root=root,
                              train=False,
                              transform=transforms.ToTensor(),
                              download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=batch_size,
                                           shuffle=True,
                                           drop_last=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=batch_size,
                                          shuffle=False,
                                          drop_last=True)
save_path = os.path.join(root, 'savepath')
os.makedirs(save_path, exist_ok=True)
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.conv3 = nn.Conv2d(20, 40, 3)
        self.mp = nn.MaxPool2d(2)
        self.mp1 = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(2560, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        in_size = x.size(0)
        x = F.relu(self.mp(self.conv1(x)))
        x = F.relu(self.mp(self.conv2(x)))
        x = F.relu(self.mp1(self.conv3(x)))
        x = x.view(in_size, -1)
        x = self.fc1(x)
        x = self.fc2(x)
        return F.log_softmax(x, dim=1)
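
# A quick shape check for the 84x84 canvases produced by data_enhance below:
# conv1 (k=5): 84 -> 80, maxpool -> 40; conv2 (k=5): 40 -> 36, maxpool -> 18;
# conv3 (k=3): 18 -> 16, maxpool -> 8, so the flattened vector is 40 * 8 * 8 = 2560,
# which is why fc1 takes 2560 inputs; e.g. Net()(torch.zeros(1, 1, 84, 84)) has shape (1, 10).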
model = Net().to(device)
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
def data_enhance(data, batch_idx):
    new_data = torch.zeros((data.size(0), data.size(1), 28 * 3, 28 * 3))
    noise = torch.rand(new_data.size())
    index = batch_idx % 9
    if index == 0:
        new_data[:, :, 0:28, 0:28] = data
    elif index == 1:
        new_data[:, :, 28:56, 0:28] = data
    elif index == 2:
        new_data[:, :, 56:, 0:28] = data
    elif index == 3:
        new_data[:, :, 0:28, 28:56] = data
    elif index == 4:
        new_data[:, :, 28:56, 28:56] = data
    elif index == 5:
        new_data[:, :, 56:, 28:56] = data
    elif index == 6:
        new_data[:, :, 0:28, 56:] = data
    elif index == 7:
        new_data[:, :, 28:56, 56:] = data
    elif index == 8:
        new_data[:, :, 56:, 56:] = data
    new_data = noise * 0.7 + new_data * 0.3
    return new_data
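
# Each 28x28 digit is pasted into one of the nine cells of an 84x84 canvas, with the cell
# chosen by batch_idx % 9, and the result is blended as 70% uniform noise + 30% image,
# so a (64, 1, 28, 28) batch becomes a faint, shifted (64, 1, 84, 84) batch.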
def train(epoch):
    for batch_idx, (data, target) in enumerate(train_loader):
        data = data_enhance(data, batch_idx).to(device)
        output = model(data)
        loss = F.nll_loss(output, target.to(device))
        if batch_idx % 200 == 0:
            content = 'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\n'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item())
            print(content)
            with open(os.path.join(root, 'log.txt'), 'a') as f:
                f.write(content)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    torch.save(model.state_dict(), os.path.join(save_path, str(epoch) + '.pth'))
def test():
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for index, (data, target) in enumerate(test_loader):
            data = data_enhance(data, index).to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target.to(device), reduction='sum').item()
            pred = output.max(1, keepdim=True)[1]
            correct += pred.eq(target.to(device).view_as(pred)).cpu().sum()
    test_loss /= len(test_loader.dataset)
    content = 'Test set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset))
    print(content)
    with open(os.path.join(root, 'log.txt'), 'a') as f:
        f.write(content)
from torchvision.utils import save_image

feature = []
def get_features_hook(module, input, output):
    feature.append(output)
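
# The hook receives (module, input, output) on every forward pass of the module it is
# registered on and stores the output in `feature`; show() below registers it on model.mp1,
# so feature[-1] holds the pooled conv3 activations, which are collapsed to one channel
# with a per-pixel max before being saved as an image.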
def show(para_path):
    print('device: {}'.format(device))
    show_path = os.path.join(root, 'show')
    os.makedirs(show_path, exist_ok=True)
    model = Net()
    model.load_state_dict(torch.load(para_path, map_location='cpu'))
    model = model.to(device)
    for index, (data, target) in enumerate(test_loader):
        print(index)
        data = data_enhance(data, index).to(device)
        save_image(data, os.path.join(show_path, str(index) + '_img.jpg'))
        handle = model.mp1.register_forward_hook(get_features_hook)
        model(data)
        handle.remove()
        feat = torch.max(feature[-1], dim=1, keepdim=True)[0]
        save_image(feat, os.path.join(show_path, str(index) + '_feat.jpg'))
        if index > 3:
            break
if __name__ == '__main__':
    act = 2
    if act == 1:
        print('start training...')
        for epoch in range(1, 100):
            train(epoch)
            test()
    else:
        print('start show...')
        show('/pytorch-master/mnist_data/savepath/40.pth')
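# With act = 1 the script trains for 99 epochs, saving one checkpoint per epoch to
# mnist_data/savepath and appending progress to log.txt; with the default act = 2 it
# loads the epoch-40 checkpoint and writes a few augmented inputs and their pooled
# conv3 feature maps to mnist_data/show.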

Input: to make the task harder, the MNIST images are shifted (translated on a larger canvas) and have noise added.

Visualization: the feature maps show that the network really does learn the digit features (at least their position), ultimately reaching an accuracy of 0.96.

Source: https://www.cnblogs.com/jiangnanyanyuchen/p/13325322.html