首页 > 其他 > 详细

PyTorch实现Softmax代码

时间:2020-02-14 21:59:22      阅读:266      评论:0      收藏:0      [点我收藏+]
 1 # 加载各种包或者模块
 2 import torch
 3 from torch import nn
 4 from torch.nn import init
 5 import numpy as np
 6 import sys
 7 sys.path.append("/home/kesci/input")
 8 import d2lzh1981 as d2l
 9 
10 print(torch.__version__)
1 # 初始化参数和获取数据
2 
3 batch_size = 256
4 train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, root=/home/kesci/input/FashionMNIST2065)
 1 num_inputs = 784
 2 num_outputs = 10
 3 
 4 class LinearNet(nn.Module):
 5     def __init__(self, num_inputs, num_outputs):
 6         super(LinearNet, self).__init__()
 7         self.linear = nn.Linear(num_inputs, num_outputs)
 8     def forward(self, x): # x 的形状: (batch, 1, 28, 28)
 9         y = self.linear(x.view(x.shape[0], -1))
10         return y
11     
12 # net = LinearNet(num_inputs, num_outputs)
13 
14 class FlattenLayer(nn.Module):
15     def __init__(self):
16         super(FlattenLayer, self).__init__()
17     def forward(self, x): # x 的形状: (batch, *, *, ...)
18         return x.view(x.shape[0], -1)
19 
20 from collections import OrderedDict
21 net = nn.Sequential(
22         # FlattenLayer(),
23         # LinearNet(num_inputs, num_outputs) 
24         OrderedDict([
25            (flatten, FlattenLayer()),
26            (linear, nn.Linear(num_inputs, num_outputs))]) # 或者写成我们自己定义的 LinearNet(num_inputs, num_outputs) 也可以
27         )
 1 # 初始化模型参数
 2 init.normal_(net.linear.weight, mean=0, std=0.01)
 3 init.constant_(net.linear.bias, val=0)
 4 
 5 # 定义损失函数
 6 loss = nn.CrossEntropyLoss() # 下面是他的函数原型
 7 # class torch.nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction=‘mean‘)
 8 
 9 # 定义优化函数
10 optimizer = torch.optim.SGD(net.parameters(), lr=0.1) # 下面是函数原型
11 # class torch.optim.SGD(params, lr=, momentum=0, dampening=0, weight_decay=0, nesterov=False)
12 
13 # 训练
14 num_epochs = 5
15 d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, batch_size, None, None, optimizer)

 

PyTorch实现Softmax代码

原文:https://www.cnblogs.com/hahasd/p/12309695.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!