
A Simulation Experiment on Optimizing Custom Variables


To Be Continued~

'''
@author: feizzhang
date: 21.01.2021
'''

import torch
from torch import nn
from torch.autograd import Variable  # Variable has been merged into Tensor since PyTorch 0.4; kept to match the original v1 code


'''
# --------------------------------
#  Autograd Version
# --------------------------------
'''
def simulation_v1():
    # --------------------------------
    #  Initialize the variable
    # --------------------------------
    X = Variable(torch.FloatTensor([8.]), requires_grad=True)
    Y = Variable(torch.FloatTensor([6.]), requires_grad=True)
    W = Variable(torch.FloatTensor([6.]), requires_grad=True)
    H = Variable(torch.FloatTensor([4.]), requires_grad=True)

    # --------------------------------
    #  Train for 40 iterations
    # --------------------------------
    for i in range(40):
        # --------------------------------
        #  1) Define loss function
        # --------------------------------
        loss = torch.pow((W * H + X + Y), 2)

        # --------------------------------
        #  2) Update grad once
        # --------------------------------
        loss.backward()

        # --------------------------------
        #  3) Update the param
        # --------------------------------
        W.data -= 0.01 * W.grad.data
        H.data -= 0.01 * H.grad.data
        X.data -= 0.01 * X.grad.data
        Y.data -= 0.01 * Y.grad.data

        # --------------------------------
        #  4) Clean the grad
        # --------------------------------
        W.grad.data.zero_()
        H.grad.data.zero_()
        X.grad.data.zero_()
        Y.grad.data.zero_()
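

'''
# --------------------------------
#  Plain-Tensor Version (sketch)
# --------------------------------
'''
# A minimal sketch (not part of the original post) of the same manual update with
# plain tensors instead of the deprecated Variable wrapper; the parameter updates
# run under torch.no_grad() so they are not recorded in the autograd graph.
def simulation_v1_plain():
    X = torch.tensor([8.], requires_grad=True)
    Y = torch.tensor([6.], requires_grad=True)
    W = torch.tensor([6.], requires_grad=True)
    H = torch.tensor([4.], requires_grad=True)

    for i in range(40):
        loss = torch.pow((W * H + X + Y), 2)
        loss.backward()

        # in-place SGD step, then reset the accumulated grads
        with torch.no_grad():
            for p in (X, Y, W, H):
                p -= 0.01 * p.grad
                p.grad.zero_()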


'''
# --------------------------------
#  Optimizer Scheduler Version
# --------------------------------
'''
class Getparam(nn.Module):
    '''Get the params of the box'''
    def __init__(self):
        super(Getparam, self).__init__()

        self.param = nn.Parameter(torch.FloatTensor([[8.], [6.], [6.], [4.]]), requires_grad=True)

    def forward(self, init_w):
        return init_w * self.param
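

# A quick sanity check of Getparam (a sketch, not in the original post): with a
# tensor of ones as init_w, the forward pass simply returns the current parameter
# values, since it multiplies init_w by the 4x1 parameter elementwise.
def check_getparam():
    getter = Getparam()
    out = getter(torch.ones([4, 1]))
    print(out)  # starts at tensor([[8.], [6.], [6.], [4.]], grad_fn=<MulBackward0>)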


# --------------------------------
#  Define the training
# --------------------------------
def simulation_v2():
    # --------------------------------
    #  Initialize the variable
    # --------------------------------
    getter = Getparam()
    var_data = {}

    # --------------------------------
    #  Define optimizer scheduler
    # --------------------------------
    # optimizer = torch.optim.Adam(getter.parameters(), lr=1, weight_decay=0, betas=(0.9, 0.99))
    optimizer = torch.optim.SGD(getter.parameters(), lr=0.01, weight_decay=0)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[20, 40], gamma=0.5)

    # --------------------------------
    #  Train for 40 iterations
    # --------------------------------
    for step in range(40):
        optimizer.zero_grad()

        # --------------------------------
        #  1) Get the variable
        # --------------------------------
        weight = torch.ones([4, 1], dtype=torch.float32)
        var = getter(weight)

        # --------------------------------
        #  2) Save one iter variable
        # --------------------------------
        for param in getter.parameters():
            print(param.data)
            if step == 5:
                # clone() so the snapshot is not overwritten by later in-place updates
                var_data[step] = param.data.clone()

        # --------------------------------
        #  3) Define the loss
        # --------------------------------
        loss_ = torch.pow((var[2] * var[3] + var[0] + var[1]), 2)

        # --------------------------------
        #  4) Update grad once
        # --------------------------------
        loss_.backward()

        # --------------------------------
        #  5) Update param once
        # --------------------------------
        optimizer.step()

        # --------------------------------
        #  6) Update lr once
        # --------------------------------
        scheduler.step()

    return var_data

var_data = simulation_v2()
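
# A short follow-up (a sketch, not part of the original post): var_data holds the
# parameter snapshot taken at step 5. Note that MultiStepLR with milestones=[20, 40]
# and gamma=0.5 halves the learning rate from 0.01 to 0.005 after 20 scheduler steps.
print(var_data)  # the 4x1 parameter tensor captured after 5 optimizer steps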


Original post: https://www.cnblogs.com/froml77/p/14311637.html
