首页 > 其他 > 详细

pytorch搭建网络,保存参数,恢复参数

时间:2018-08-11 00:42:13      阅读:283      评论:0      收藏:0      [点我收藏+]

这是看过莫凡python的学习笔记。

搭建网络,两种方式

(1)建立Sequential对象

import torch
net = torch.nn.Sequential(
            torch.nn.Linear(2,10),
            torch.nn.ReLU(),
            torch.nn.Linear(10,2))

输出网络结构

Sequential(
  (0): Linear(in_features=2, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=2, bias=True)
)

(2)建立网络类,继承torch.nn.module

class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.hidden = torch.nn.Linear(2,10)
        self.predict = torch.nn.Linear(10,2)
    def forward(self,x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

输出和上面基本一样,略微不同

Net(
  (hidden): Linear(in_features=2, out_features=10, bias=True)
  (predict): Linear(in_features=10, out_features=2, bias=True)
)

 

保存模型,两种方式

(1)保存整个网络,及网络参数

torch.save(net,net.pkl)

(2)只保存网络参数

torch.save(net.state_dict(),net_params.pkl)

 

恢复模型,两种方式

(1)加载整个网络,及参数

net2 = torch.load(net.pkl)

(2)加载参数,但需实现网络

net3 = torch.nn.Sequential(
            torch.nn.Linear(2,10),
            torch.nn.ReLU(),
            torch.nn.Linear(10,2))
net3.load_state_dict(torch.load(net_params.pkl))

 

pytorch搭建网络,保存参数,恢复参数

原文:https://www.cnblogs.com/wzyuan/p/9458008.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!