让模型接着上次保存好的模型训练,模型加载
#实例化模型、优化器、损失函数 model = MnistModel().to(config.device) optimizer = optim.Adam(model.parameters(),lr=0.01) if os.path.exists("./model/mnist_net.pt"): model.load_state_dict(torch.load("./model/mnist_net.pt")) optimizer.load_state_dict(torch.load("model/mnist_optimizer.pt"))
模型保存
torch.save(model.state_dict(),"model/mnist_net.pt") torch.save(optimizer.state_dict(),"model/mnist_optimizer.pt")
原文:https://www.cnblogs.com/LiuXinyu12378/p/12313880.html