首页 > 其他 > 详细

对tf.nn.softmax的理解

时间:2020-03-22 23:25:21      阅读:117      评论:0      收藏:0      [点我收藏+]

对tf.nn.softmax的理解

Softmax的含义:Softmax简单的说就是把一个N*1的向量归一化为(0,1)之间的值,由于其中采用指数运算,使得向量中数值较大的量特征更加明显。
技术分享图片
如图所示,在等号左边部分就是全连接层做的事。

  1. W是全连接层的参数,我们也称为权值;W是全连接层的参数,是个T*N的矩阵,这个N和X的N对应,T表示类别数,比如你进行手写数字识别,就是10个分类,那么T就是10。
  2. X是全连接层的输入,也就是特征。从图上可以看出特征X是N1的向量,他就是由全连接层前面多个卷积、激活和池化层处理后得到的;
    举一个例子,假设全连接层前面连接的是一个卷积层,这个卷积层的输出是64个特征,每个特征的大小是7X7,那么在将这些特征输入给全连接层之前会将这些特征通过tf.reshape转化为成N
    1的向量(这个时候N就是64X7X7=3136)。
    我们所说的训练一个网络,对于全连接层而言就是寻找最合适的W矩阵。因此全连接层就是执行WX得到一个T1的向量(也就是图中的logits[T1]),这个向量里面的每个数都没有大小限制的,也就是从负无穷大到正无穷大。然后如果你是多分类问题,一般会在全连接层后面接一个softmax层,这个softmax的输入是T1的向量,输出也是T1的向量(也就是图中的prob[T*1],这个向量的每个值表示这个样本属于每个类的概率),只不过输出的向量的每个值的大小范围为0到1。

现在你知道softmax的输出向量是什么意思了,就是概率,该样本属于各个类的概率!

那么softmax执行了什么操作可以得到0到1的概率呢?先来看看softmax的公式(以前自己看这些内容时候对公式也很反感,不过静下心来看就好了):

技术分享图片

公式非常简单,前面说过softmax的输入是WX,假设模型的输入样本是I,讨论一个3分类问题(类别用1,2,3表示),样本I的真实类别是2,那么这个样本I经过网络所有层到达softmax层之前就得到了WX,也就是说WX是一个31的向量,那么上面公式中的aj就表示这个31的向量中的第j个值(最后会得到S1,S2,S3);而分母中的ak则表示31的向量中的3个值,所以会有个求和符号(这里求和是k从1到T,T和上面图中的T是对应相等的,也就是类别数的意思,j的范围也是1到T)。因为e^x恒大于0,所以分子永远是正数,分母又是多个正数的和,所以分母也肯定是正数,因此Sj是正数,而且范围是(0,1)。如果现在不是在训练模型,而是在测试模型,那么当一个样本经过softmax层并输出一个T*1的向量时,就会取这个向量中值最大的那个数的index作为这个样本的预测标签。

因此我们训练全连接层的W的目标就是使得其输出的WX在经过softmax层计算后其对应于真实标签的预测概率要最高。

举个例子:假设你的WX=[1,2,3],那么经过softmax层后就会得到[0.09,0.24,0.67],这三个数字表示这个样本属于第1,2,3类的概率分别是0.09,0.24,0.67。取概率最大的0.67,所以这里得到的预测值就是第三类。

弄懂了softmax,就要来说说softmax loss了。
那softmax loss是什么意思呢?如下:

技术分享图片

首先L是损失。Sj是softmax的输出向量S的第j个值,前面已经介绍过了,表示的是这个样本属于第j个类别的概率。yj前面有个求和符号,j的范围也是1到类别数T,因此y是一个1*T的向量,里面的T个值,而且只有1个值是1,其他T-1个值都是0。那么哪个位置的值是1呢?答案是真实标签对应的位置的那个值是1,其他都是0。所以这个公式其实有一个更简单的形式:

技术分享图片

当然此时要限定j是指向当前样本的真实标签。

来举个例子吧。假设一个5分类问题,然后一个样本I的标签y=[0,0,0,1,0],也就是说样本I的真实标签是4,假设模型预测的结果概率(softmax的输出)p=[0.1,0.15,0.05,0.6,0.1],可以看出这个预测是对的,那么对应的损失L=-log(0.6),也就是当这个样本经过这样的网络参数产生这样的预测p时,它的损失是-log(0.6)。那么假设p=[0.15,0.2,0.4,0.1,0.15],这个预测结果就很离谱了,因为真实标签是4,而你觉得这个样本是4的概率只有0.1(远不如其他概率高,如果是在测试阶段,那么模型就会预测该样本属于类别3),对应损失L=-log(0.1)。那么假设p=[0.05,0.15,0.4,0.3,0.1],这个预测结果虽然也错了,但是没有前面那个那么离谱,对应的损失L=-log(0.3)。我们知道log函数在输入小于1的时候是个负数,而且log函数是递增函数,所以-log(0.6) < -log(0.3) < -log(0.1)。简单讲就是你预测错比预测对的损失要大,预测错得离谱比预测错得轻微的损失要大。

理清了softmax loss,就可以来看看cross entropy了。
corss entropy是交叉熵的意思,它的公式如下:

技术分享图片

是不是觉得和softmax loss的公式很像。当cross entropy的输入P是softmax的输出时,cross entropy等于softmax loss。Pj是输入的概率向量P的第j个值,所以如果你的概率是通过softmax公式得到的,那么cross entropy就是softmax loss。

 

 

 

 

softmax 虽然简单,但是其实这里面有非常的多细节值得一说。

我们挨个捋一捋。

 

1. 什么是 Softmax?


首先,softmax 的作用是把 一个序列,变成概率。

技术分享图片技术分享图片

他能够保证:

  1. 所有的值都是 [0, 1] 之间的(因为概率必须是 [0, 1])
  2. 所有的值加起来等于 1

从概率的角度解释 softmax 的话,就是

技术分享图片

 

2. 文档里面跟 Softmax 有关的坑

这里穿插一个“小坑”,很多deep learning frameworks的 文档 里面 (PyTorch,TensorFlow)是这样描述 softmax 的,

take logits and produce probabilities

很明显,这里面的 logits 就是 全连接层(经过或者不经过 activation都可以)的输出,probability就是 softmax 的输出结果。 这里 logits 有些地方还称之为 unscaled log probabilities。这个就很意思了,unscaled probability可以理解,那又为什么 全连接层直接出来结果会和 log 有关系呢?


原因有两个:

  1. 因为 全连接层 出来的结果,其实是无界的(有正有负),这个跟概率的定义不一致,但是你如果他看成 概率的 log,就可以理解了。
  2. softmax 的作用,我们都知道是 normalize probability。在 softmax 里面,输入 技术分享图片 都是在指数上的 技术分享图片 ,所有把 技术分享图片 想成 log of probability 也就顺理成章了。

 

3. Softmax 就是 Soft 版本的 Max

好的,我们把话题拉回到 softmax。

softmax,顾名思义就是 soft 版本的 max。我们来看一下为什么?

举个栗子,假如 softmax 的输入是:

技术分享图片

softmax 的结果是:

技术分享图片

我们稍微改变一下输入,把 3 改大一点,变成 5,输入是

技术分享图片

softmax 的结果是:

技术分享图片

可见 softmax 是一种非常明显的 “马太效应”:强(大)的更强(大),弱(小)的更弱(小)。假如你要选一个最大的数出来,这个其实就是叫 hardmax。那么 softmax 呢,其实真的就是 soft 版本的 max。

这种 soft 版本的 max 在很多地方有用的上。因为 hard 版本的 max 好是好,但是有很严重的梯度问题,求最大值这个函数本身的梯度是非常非常稀疏的(比如神经网络中的 max pooling),经过hardmax之后,只有被选中的那个变量上面才有梯度,其他都是没有梯度。这对于一些任务(比如文本生成等)来说几乎是不可接受的。所以要么用 hard max 的变种,比如 Gumbel,

Categorical Reparameterization with Gumbel-Softmax?arxiv.org

亦或是 ARSM

ARSM: Augment-REINFORCE-Swap-Merge Estimator for Gradient Backpropagation Through Categorical Variables?proceedings.mlr.press技术分享图片

,要么就直接 softmax。

 

4. Softmax 的实现以及数值稳定性


softmax 的代码实现看似是比较简单的,直接套上面的公式就好

def softmax(x):
    """Compute the softmax of vector x."""
    exps = np.exp(x)
    return exps / np.sum(exps)

但是这种方法非常的不稳定。因为这种方法要算指数,只要你的输入稍微大一点,比如:

技术分享图片

分母上就是

技术分享图片

很明显,在计算上一定会溢出。解决方法也比较简单,就是我们在分子分母上都乘上一个系数,减小数值大小,同时保证整体还是对的

技术分享图片

把常数 C 吸收进指数里面

技术分享图片技术分享图片

这里的D是可以随便选的,一般可以选成

技术分享图片

具体实现可以写成这样

def stablesoftmax(x):
    """Compute the softmax of vector x in a numerically stable way."""
    shiftx = x - np.max(x)
    exps = np.exp(shiftx)
    return exps / np.sum(exps)

这样一种实现数值稳定性已经好了很多,但是仍然会有数值稳定性的问题。比如输入的值差别过大的时候,比如

技术分享图片

这种情况即使用了上面的方法,可能还是报 NaN 的错误。但是这个就是数学本身的问题了,大家使用的时候稍微注意下。

一种可能的替代的方案是使用 LogSoftmax (然后再求 exp),数值稳定性比 softmax 好一些。

技术分享图片

可以看到,LogSoftmax省了一个指数计算,省了一个除法,数值上相对稳定一些。另外,其实 Softmax_Cross_Entropy 里面也是这么实现的


5. Softmax 的梯度

下面我们来看一下 softmax 的梯度问题。整个 softmax 里面的操作都是可微的,所以求梯度就非常简单了,就是基础的求导公式,这里就直接放结果了。

技术分享图片技术分享图片

所以说,如果某个变量做完 softmax 之后很小,比如 技术分享图片 ,那么他的梯度也是非常小的,几乎得不到任何梯度。有些时候,这会造成梯度非常的稀疏,优化不动。


6. Softmax 和 Cross-Entropy 的关系

先说结论,

softmax 和 cross-entropy 本来太大的关系,只是把两个放在一起实现的话,算起来更快,也更数值稳定。

cross-entropy 不是机器学习独有的概念,本质上是用来衡量两个概率分布的相似性的。简单理解(只是简单理解!)就是这样,

如果有两组变量:

技术分享图片

如果你直接求 L2 距离,两个距离就很大了,但是你对这俩做 cross entropy,那么距离就是0。所以 cross-entropy 其实是更“灵活”一些。

 

那么我们知道了,cross entropy 是用来衡量两个概率分布之间的距离的,softmax能把一切转换成概率分布,那么自然二者经常在一起使用。但是你只需要简单推导一下,就会发现,softmax + cross entropy 就好像

“往东走五米,再往西走十米”,

我们为什么不直接

“往西走五米”呢?

cross entropy 的公式是

技术分享图片

这里的 技术分享图片 就是我们前面说的 LogSoftmax。这玩意算起来比 softmax 好算,数值稳定还好一点,为啥不直接算他呢?

所以说,这有了 PyTorch 里面的 torch.nn.CrossEntropyLoss (输入是我们前面讲的 logits,也就是 全连接直接出来的东西)。这个 CrossEntropyLoss 其实就是等于 torch.nn.LogSoftmax + torch.nn.NLLLoss

对tf.nn.softmax的理解

原文:https://www.cnblogs.com/newton001/p/12548828.html

(0)
(0)
   
举报
评论 一句话评论(0
关于我们 - 联系我们 - 留言反馈 - 联系我们:wmxa8@hotmail.com
© 2014 bubuko.com 版权所有
打开技术之扣,分享程序人生!