首页 > 其他 > 详细

LSTM——长短时记忆网络

时间:2019-12-02 23:19:47      阅读:586      评论:0      收藏:0      [点我收藏+]

  LSTM(Long Short-term Memory),长短时记忆网络是1997年Hochreiter和Schmidhuber为了解决预测位置与相关信息之间的间隔增大或者复杂语言场景中,有用信息间隔有大有小、长短不一,造成循环神经网络性能受到限制而提出的。

  LSTM是RNN的一种特殊类型,它可以学习长期依赖的信息。与单一RNN不同,LSTM网络结构是一种拥有3个”门”结构的特殊网络结构,这个特殊设计可以避免长期依赖问题。

  下面介绍LSTM网络结构:

  技术分享图片

  原始的RNN隐藏层只有一个状态h,它对于短期的输入非常敏感。LSTM网络增加了一个状态c,让它来保存长期的状态。新增的状态c,称为单元状态。将(b)按照时间维度展开,如下图所示:

技术分享图片

  如上图可以看出,在t时刻,LSTM网络的输入有3个,即当前时刻网络状态的输入值xt、上一时刻LSTM网络的输入值ht-1以及上一时刻的单元状态ct-1;LSTM网络的输出有两个,即当前时刻LSTM网络输出值ht和当前时刻的单元状态ct。注意,x、c、h都是向量。

  LSTM网络的关键,就是怎样控制长期状态c。在这里,LSTM的思路是使用三个控制开关:第一个开关,负责控制保存长期状态c;第二个开关,负责控制把即时状态输入到长期状态c;第三个开关,负责控制是否把长期状态c作为当前的LSTM网络的输出。

  这三个开关叫做“门”结构,它们可以让信息有选择性的影响循环神经网络中每一个时刻的状态。所谓“门”,实际上就是一层全连接层,它的输入是一个向量,输出是一个0~1之间的实数向量。假设W是门的权重向量,b是偏置项,那么门可以表示为:g(x)=σ(Wx+b)。其中σ为sigmoid函数,因为其值域为(0,1),所以门的状态都是半开半闭的。

  LSTM网络用两个门来来控制单元状态c的内容,一个是遗忘门,它决定了上一时刻的单元状态ct-1有多少保留到当前时刻的单元状态ct;另一个是输入门,它决定了当前时刻网络的输入xt有多少保存到单元状态ct。LSTM网络用输出门来控制单元状态ct有多少输出到LSTM的当前输出值ht

  (1)遗忘门:ƒt=σ(Wf•[ht-1,xt]+bf)。其中Wf是遗忘门的权重矩阵,[ht-1,xt]表示把两个向量连接成一个更长的向量,bf是遗忘门的偏置项。

  (2)输入门:it=σ(Wi•[ht-1,xt]+bi)

  接下来计算用于描述当前输入的单元状态ct,它是根据上一次的输出和本次输入来计算的:

  ct=tanh(Wc•[ht-1,xt]+bc)

  接着计算当前时刻的单元状态ct。它是由上一次的单元状态ct-1按元素乘以遗忘门ft,再用当前输入的单元状态ct按元素乘以输入门it,再将这两个乘积相加而产生的:

  ct=ft•ct-1+it•ct

  这样就把LSTM网络关于当前的记忆ct和长期的记忆ct-1组合在一起,形成了新的单元状态ct。由于遗忘门的控制,LSTM网络可以保存很久很久以前的信息;又由于输入门的控制,它可以避免当前无关紧要的内容进入记忆。

  (3)输出门:ot=σ(Wo•[ht-1,xt]+bo)

  LSTM网络的最终输出,是由输出门和单元状态共同决定的:

  ht=ot•tanh(ct)

  最终LSTM网络结构示意图如图所示:

技术分享图片

  上面介绍的公式,为LSTM前向计算的全部公式。


下面介绍下LSTM网络的训练算法:

  LSTM网络的训练算法仍为反向传播算法,主要步骤如下:

  (1)前向计算每个神经元的输出值,对于LSTM网络来说,即ft、it、ct、ot、ht五个向量的值。

  (2)反向计算每个神经元的误差项。与RNN一样,LSTM网络误差项的反向传播也包括两个方向:一个是沿时间的反向传播,即从当前时刻t开始,计算每个时刻的误差项;另一个是将误差项向上一层传播。

  (3)根据相应的误差项,计算每个权重的梯度。

  (4)用梯度下降的误差后向传播算法更新权重。

 


LSTM网络程序实现:——tensorflow

import tensorflow as tf
#定义一个基本的LSTM网络结构
lstm=tf.contrib.rnn.BasicLSTMCell(lstm_hidden_size)
#将LSTM中的状态初始化为全零数组。返回的state包含两个张量state.c和state.h
state=lstm.zero_state(batch_size,tf.float32)
#定义损失函数
loss=0.0
#定义训练数据的序列长度num_steps
for i in range(num_steps):
    #声明LSTM中使用的变量,在之后的时刻都需要反复用之前定义好的变量
    if i>0:tf.get_variable_scope.reuuse_variables()
    #将当前输入current_input和前一时刻状态state(h_t-1和c_t-1)传入定义的LSTM结构,
    #可以得到当前LSTM输出和lstm_output(ht)和更新后的状态state(ht和ct)
    lstm_output,state=lstm(current_input,state)
    #将当前时刻LSTM输出传入一个全连接层,得到最后的输出
    final_output=fully_connected(lstm_output)
    #计算当前时刻输出的损失函数
    loss+=calc_loss(final_output,expected_optput)
#使用常规的神经网络训练方法训练模型

 


 


 

LSTM——长短时记忆网络

原文:https://www.cnblogs.com/candyRen/p/11973379.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!