首页 > Web开发 > 详细

mindspore\lite\examples\train_lenet\model\train_utils.py注解

时间:2021-08-07 15:09:42      阅读:26      评论:0      收藏:0      [点我收藏+]

** "mindspore\lite\examples\train_lenet\model\train_utils.py"**

一、代码用处

这段代码块主要使用于 训练数据的模型

二、代码注释

"""train_utils."""

import mindspore.nn as nn#导入mindspore包
from mindspore.common.parameter 
import ParameterTuple
def TrainWrap(net, loss_fn=None, optimizer=None, weights=None):#定义一个包装函数
    """
    TrainWrap
    """
    if loss_fn is None:#判断是否有损失
        loss_fn = nn.SoftmaxCrossEntropyWithLogits(reduction=‘mean‘, sparse=True)#调用方法使用 Logits 的软最大交叉熵
    loss_net = nn.WithLossCell(net, loss_fn)
    loss_net.set_train()
    if weights is None:
        weights = ParameterTuple(net.trainable_params())
    if optimizer is None:#优化器
        optimizer = nn.Adam(weights, learning_rate=0.003, beta1=0.9, beta2=0.999, eps=1e-5, use_locking=False,
                            use_nesterov=False, weight_decay=4e-5, loss_scale=1.0)#进行优化
    train_net = nn.TrainOneStepCell(loss_net, optimizer)
    return train_net#返回训练数据

mindspore\lite\examples\train_lenet\model\train_utils.py注解

原文:https://www.cnblogs.com/WangLiYuan87/p/15111622.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!