将分布之间的Fisher information matrix (FIM)看成是统计流形上的黎曼度量,然后用流形上的最速下降方向作为搜索方向,就是自然梯度。这是一种概念上比较简洁漂亮的处理方式,但显然不是一种最容易理解的方式,很多人第一次接触到的时候都是有些懵的。
思路上更简单直接的方式可能是从约束优化来理解。考虑一个基本的函数 ,考虑概率分布
上的优化问题
如果我们想找参数化的分布 ,使得E(f)的改进程度最大,最直观的方法自然是直接对E(f)做一步梯度下降。但是由于
和
是概率分布,他们之间的距离不是用参数之间的欧式距离来定义的(简单来说,沿梯度下降一步之后的
可能不满足分布参数的要求,比如正态分布的协方差矩阵变得不正定了),而是用分布之间的KL-divergence来定义的
由于这个KL-div 不对称,它不满足距离的定义。同时由于 比较小,我们可以对此式展开做二阶近似
其中 就是Fisher information matrix 的分量,换句话说,FIM就是KL-div的二阶近似
这可以进一步的写成 , 即Fisher矩阵是log- p的Hessian的期望,与二阶信息密切相关。
回到原来的优化问题,我们面对的问题变成了
将上面的KL-div的二阶近似带入,构造Lagrange 函数,就有
此式可以写成矩阵形式 对此式稍作推导,就得到最速下降方向
这里的 是一个
的无穷小量。这个方向就是所谓的自然梯度方向。
可以看到,这里的推导没有用到任何微分几何和黎曼度量的概念,唯一用到的就是概率分布之间的KL-div 和它的二阶近似,然后套用约束优化的拉格朗日乘子,也就无所谓“自然”了。当然,这里的推导会比黎曼度量-自然梯度 更加技术化一些,技术化的东西相对来说不容易推广。
自然梯度和牛顿法是有关联的,在某些特殊情况下可以认为是Gauss-Newton法的近似。
以上内容最后编辑于 2018.3
更新:多角度理解自然梯度
高票适合对统计,ml,optimization有很高造诣的人理解。我这个回答尝试通过简单的数学推导来给理解natural gradient descent.
在标准的stochastic gradient descent里面, 我们假定用N个可train的parameters, 就像torch.nn里面的parameters, k 代表迭代中的第k步:
假定cost function 为J(w). 在 linear regression中, J(w) 就是mean squared error; 在logistic regression中,J(w) 即为cross entropy。同理所有的regression和classification问题。
这里也同时顺便提一句为什么policy gradient 的基础上搞出了trpo和ppo,原因:
回到正题,现在我们假定 J(w) 在Euclidean coordinates 中,即有 。在Euclidean coordinates,用标准的stochastic gradient descent 是没有问题的。直观来讲,因为坐标轴之间是相互垂直的,在
上的增长,并不会影响到
;
代表的还是最陡峭的方向。
但是一旦到了黎曼空间里面 or Riemannian manifold, 代表的就不是最陡峭的方向了。直观来讲,把地球展开成二维(假设我们有二向箔),它将变成(一幅画)一个地图,在这个地图上的确两点之间线段最短(standard gradients,切线可以理解成无限接近的两点),但是我们一旦回到三维,测量距离的时候,我们要考虑地球的弯曲(the global curvature of the earth‘s surface),切线某种意义上失去了这个curvature。
所以在黎曼几何里面,距离就定义成了 。G(w)这里叫Riemannian metric tensor,黎曼度量。
在统计机器学习里面,G(w) 可以简单地认为是fisher information matrix (the variance of the score function)。
在trpo 里面,如果只泰勒展开近似,结合拉格朗日对偶,并且不backtrack line search,
基本上就是由上面的一套操作(许多套操作)得到的了。
最后回答不易,欢迎点赞,评论。如果有任何错误,欢迎指正。
References:
2. A Natural Policy Gradient
https://papers.nips.cc/paper/2073-a-natural-policy-gradient.pdf3. Natural Gradient Works Efficiently in Learning (Shun-ichi Amari)
4. WHY NATURAL GRADIENT? (Shun-ichi Amari)
其实有另一种理解方式:
natural gradient 实际上是一种测度,而这种测度可以通过线性空间连续变换而来——连续变换怎么理解?连续变换就是原来平直的薄膜在不弄破的情况下进行任意拉伸——这种拉伸是可逆的。
因此,natural gradient decent可以同样应用这种连续变换的逆运算获得gradient,这个逆运算就是你说的fisher information matrix的逆。
原文:https://www.cnblogs.com/cx2016/p/13746220.html