
MXNET: Weight Decay - Gluon Implementation


Building the dataset

# -*- coding: utf-8 -*-
from mxnet import init
from mxnet import ndarray as nd
from mxnet.gluon import loss as gloss
import gb  # companion utility module of the "Dive into Deep Learning" book (gluonbook)

# Deliberately use far fewer training examples (20) than input
# dimensions (200) so the model is prone to overfitting.
n_train = 20
n_test = 100

num_inputs = 200
true_w = nd.ones((num_inputs, 1)) * 0.01
true_b = 0.05
features = nd.random.normal(shape=(n_train + n_test, num_inputs))
labels = nd.dot(features, true_w) + true_b
labels += nd.random.normal(scale=0.01, shape=labels.shape)  # Gaussian noise
train_features, test_features = features[:n_train, :], features[n_train:, :]
train_labels, test_labels = labels[:n_train], labels[n_train:]
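
The synthetic labels follow the linear model y = 0.05 + Σᵢ 0.01·xᵢ over 200 inputs, plus noise ε drawn from a normal distribution with mean 0 and standard deviation 0.01. Because there are only 20 training examples for 200 features, plain linear regression will overfit badly, which is exactly the situation weight decay is designed to mitigate.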

Data iterator

from mxnet import autograd
from mxnet.gluon import data as gdata

batch_size = 1
num_epochs = 10
learning_rate = 0.003

train_iter = gdata.DataLoader(gdata.ArrayDataset(
    train_features, train_labels), batch_size, shuffle=True)
loss = gloss.L2Loss()  # squared loss: 0.5 * (y_hat - y) ** 2
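
As a quick sanity check (not in the original post), you can pull a single minibatch from train_iter to confirm the shapes:

for X, y in train_iter:
    print(X.shape, y.shape)  # expect (1, 200) and (1, 1) with batch_size = 1
    break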

Train and show the results

The gb.semilogy function plots the training and test losses on a semi-logarithmic scale (logarithmic y-axis).
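
gb.semilogy comes from the book's companion gluonbook module. If you do not have it, a minimal stand-in with the same call signature (an assumed reimplementation, inferred from how the function is called below, not the module's actual code) could be:

import matplotlib.pyplot as plt

def semilogy(x_vals, y_vals, x_label, y_label,
             x2_vals=None, y2_vals=None, legend=None):
    # Plot one or two curves with a logarithmic y-axis.
    plt.semilogy(x_vals, y_vals)
    if x2_vals is not None and y2_vals is not None:
        plt.semilogy(x2_vals, y2_vals, linestyle=':')
    if legend is not None:
        plt.legend(legend)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.show()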

from mxnet import gluon
from mxnet.gluon import nn

def fit_and_plot(weight_decay):
    net = nn.Sequential()
    net.add(nn.Dense(1))
    net.initialize(init.Normal(sigma=1))
    # Apply L2-norm regularization (weight decay) to the weight parameters.
    trainer_w = gluon.Trainer(net.collect_params('.*weight'), 'sgd', {
        'learning_rate': learning_rate, 'wd': weight_decay})
    # Do not apply L2-norm regularization to the bias parameter.
    trainer_b = gluon.Trainer(net.collect_params('.*bias'), 'sgd', {
        'learning_rate': learning_rate})
    train_ls = []
    test_ls = []
    for _ in range(num_epochs):
        for X, y in train_iter:
            with autograd.record():
                l = loss(net(X), y)
            l.backward()
            # Call step on both Trainer instances to update the
            # weights and the bias separately.
            trainer_w.step(batch_size)
            trainer_b.step(batch_size)
        train_ls.append(loss(net(train_features),
                             train_labels).mean().asscalar())
        test_ls.append(loss(net(test_features),
                            test_labels).mean().asscalar())
    gb.semilogy(range(1, num_epochs + 1), train_ls, 'epochs', 'loss',
                range(1, num_epochs + 1), test_ls, ['train', 'test'])
    return 'w[:10]:', net[0].weight.data()[:, :10], 'b:', net[0].bias.data()

print(fit_and_plot(5))
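
With wd set, Gluon's sgd optimizer updates each weight as w := w - lr * (grad + wd * w): every step shrinks the weight toward zero in addition to taking the gradient step. This is equivalent to adding a penalty of (wd / 2) * ||w||² to the loss, which is why weight decay is also called L2-norm regularization.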
  • Gluon's wd hyperparameter applies weight decay, which helps counter overfitting, as the comparison run below illustrates.
  • We can create multiple Trainer instances to apply different update rules to different sets of model parameters.
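
The post only runs fit_and_plot(5). To actually see the effect described above, a baseline run without weight decay (an addition, not in the original) is instructive:

# Baseline without weight decay: with 200 inputs and only 20 training
# examples, expect the training loss to keep falling while the test
# loss stays high -- classic overfitting.
print(fit_and_plot(0))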


Original: https://www.cnblogs.com/houkai/p/9521015.html
