Natural Gradient Descent

Dec 29, 2022 · 11 min · machine learning, math


Natural gradient descent minimizes the loss function in distribution space, with the KL divergence as the measure of distance, instead of in the usual parameter space with the Euclidean metric.


Gradient Descent

Assume that $L(\theta)$ is a loss function with continuous first-order partial derivatives, where $\theta \in \mathbb{R}^n$. We want to solve the unconstrained optimization problem:

$$\min_\theta L(\theta)$$

Gradient descent is an iterative optimization algorithm. It starts from an initial point in parameter space (often chosen at random) and iteratively updates $\theta$ until the iterates converge, ideally to a minimum of the loss. Formally, in the $(k + 1)$th iteration, we have:

$$\theta^{(k+1)} = \theta^{(k)} + \eta \vec{v} \tag{1}$$

where $\eta$ is the step size and $\vec{v}$ is the update direction. The first-order Taylor series expansion of $L(\theta^{(k+1)})$ at $\theta^{(k)}$ is:

$$
\begin{aligned}
L(\theta^{(k+1)}) &= L(\theta^{(k)}) + \nabla L(\theta^{(k)}) \cdot (\theta^{(k+1)} - \theta^{(k)}) \textcolor{blue}{+ o(\| \theta^{(k+1)} - \theta^{(k)} \|)} \\
&= L(\theta^{(k)}) + \eta \vec{v} \cdot \nabla L(\theta^{(k)}) \textcolor{blue}{+ o(\eta \| \vec{v} \|)} \\
&\approx L(\theta^{(k)}) + \eta \vec{v} \cdot \nabla L(\theta^{(k)})
\end{aligned}
$$

We want the steepest descent direction in a local neighbourhood of $\theta^{(k)}$ in parameter space, that is, the direction $\vec{v}$ of unit length that minimizes:

$$L(\theta^{(k+1)}) - L(\theta^{(k)}) \approx \eta \vec{v} \cdot \nabla L(\theta^{(k)}) = \eta \cdot \| \vec{v} \| \cdot \| \nabla L(\theta^{(k)}) \| \cdot \cos(\alpha)$$

where $\alpha$ is the angle between $\vec{v}$ and $\nabla L(\theta^{(k)})$. When $\alpha = \pi$, i.e. when $\vec{v}$ and $\nabla L(\theta^{(k)})$ point in opposite directions, $\cos(\alpha) = -1$ and the loss decreases the most. This is why gradient descent moves in the direction of the negative gradient:

$$\vec{v} = - \frac{\nabla L(\theta^{(k)})}{\| \nabla L(\theta^{(k)}) \|}$$

Now we can write Eq. 1 as:

$$\theta^{(k+1)} = \theta^{(k)} - \frac{\eta}{\| \nabla L(\theta^{(k)}) \|} \cdot \nabla L(\theta^{(k)})$$

where $\frac{\eta}{\| \nabla L(\theta^{(k)}) \|}$ is the learning rate.
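The update rule above can be sketched in a few lines of Python. This is a minimal illustration, not a reference implementation: the quadratic `loss`, the starting point, and the geometric `decay` of the step size are all assumptions made for the example (a fixed step of constant length would keep bouncing around the minimum).

```python
import math

def loss(theta):
    """Hypothetical loss L(theta) = theta_0^2 + 2 * theta_1^2 (assumed for illustration)."""
    return theta[0] ** 2 + 2.0 * theta[1] ** 2

def grad(theta):
    """Analytic gradient of the loss above."""
    return [2.0 * theta[0], 4.0 * theta[1]]

def normalized_gradient_descent(theta, eta=0.5, decay=0.95, steps=200):
    """Iterate theta <- theta + eta * v with v = -grad / ||grad||."""
    theta = list(theta)
    for _ in range(steps):
        g = grad(theta)
        norm = math.sqrt(sum(gi * gi for gi in g))
        if norm < 1e-12:  # already at a stationary point
            break
        v = [-gi / norm for gi in g]  # unit-length steepest descent direction
        theta = [t + eta * vi for t, vi in zip(theta, v)]
        eta *= decay  # assumed schedule: shrink the step so the iterates settle
    return theta

theta_star = normalized_gradient_descent([2.0, -1.0])
print(theta_star, loss(theta_star))
```

Because $\vec{v}$ has unit length, every iteration moves exactly a distance of `eta` in parameter space, which is the Euclidean notion of step size that the rest of the article contrasts with the KL divergence.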

KL Divergence

The Kullback-Leibler divergence, also known as the relative entropy, measures the difference between the current probability distribution $p$ and the target distribution $q$. It is defined as:

$$KL (p \| q) = \int p(x) \log \frac{p(x)}{q(x)} dx$$

The KL divergence is non-negative, and it is zero if and only if the two distributions are equal (almost everywhere). Note that the KL divergence is not symmetric: $KL(p \| q)$ is not necessarily equal to $KL(q \| p)$.
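As a small sanity check of these properties, here is a sketch for discrete distributions, where the integral becomes a sum. The function name `kl_divergence` and the two example distributions are assumptions chosen for illustration.

```python
import math

def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions given as lists of probabilities.

    Terms with p_i == 0 contribute nothing, by the convention 0 * log 0 = 0.
    """
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

p = [0.5, 0.4, 0.1]  # hypothetical "current" distribution
q = [0.3, 0.3, 0.4]  # hypothetical "target" distribution

print("KL(p || q) =", kl_divergence(p, q))
print("KL(q || p) =", kl_divergence(q, p))
print("KL(p || p) =", kl_divergence(p, p))
```

The first two values differ, which shows the asymmetry, and the last one is zero because the two arguments are the same distribution.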

Fisher Information Matrix

I suggest having a look at this great article. In short, the Fisher Information Matrix is:

  • the second moment of the first-order derivative of the log-likelihood function:

    $$F = \mathbb{E}_{p(x \mid \theta)} \Big [ \nabla \log p(x \mid \theta) \nabla \log p(x \mid \theta)^T \Big ]$$
  • the negative expectation of the Hessian matrix of the log-likelihood function (proof):

    $$F = - \mathbb{E}_{p(x \mid \theta)} \left [ \frac{\partial^2}{\partial \theta \, \partial \theta^T} \log p(x \mid \theta) \right ]$$
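The two definitions can be checked against each other on a one-parameter model. The sketch below assumes a Bernoulli distribution $p(x \mid \theta) = \theta^x (1 - \theta)^{1 - x}$, which is not used elsewhere in this article; it is chosen because the expectation is an exact sum over $x \in \{0, 1\}$ and the Fisher information has the known closed form $1 / (\theta (1 - \theta))$.

```python
def pmf(x, theta):
    """Bernoulli probability p(x | theta) for x in {0, 1}."""
    return theta if x == 1 else 1.0 - theta

def score(x, theta):
    """First derivative of log p(x | theta) with respect to theta."""
    return x / theta - (1 - x) / (1.0 - theta)

def log_lik_hessian(x, theta):
    """Second derivative of log p(x | theta) with respect to theta."""
    return -x / theta ** 2 - (1 - x) / (1.0 - theta) ** 2

def fisher_from_score(theta):
    """F as the second moment of the score."""
    return sum(pmf(x, theta) * score(x, theta) ** 2 for x in (0, 1))

def fisher_from_hessian(theta):
    """F as the negative expected Hessian of the log-likelihood."""
    return -sum(pmf(x, theta) * log_lik_hessian(x, theta) for x in (0, 1))

theta = 0.3
print(fisher_from_score(theta), fisher_from_hessian(theta), 1.0 / (theta * (1.0 - theta)))
```

All three printed numbers agree up to floating-point error.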

KL Divergence & Fisher Information

As $d \to 0$, we have:

$$KL \Big ( p(x \mid \theta) \| p(x \mid \theta + d) \Big ) \approx \frac{1}{2} d^T F d$$

For convenience, we denote $p(x \mid \theta)$ and $p(x \mid \theta + d)$ as $p(\theta)$ and $p(\theta + d)$, respectively.

$$
\begin{aligned}
KL \Big ( p(\theta) \| p(\theta + d) \Big ) &= \int p(\theta) \log \frac{p(\theta)}{p(\theta + d)} dx = \mathbb{E}_{p(\theta)} \left [ \log \frac{p(\theta)}{p(\theta + d)} \right ] \\
&= \mathbb{E}_{p(\theta)} [ \log p(\theta) ] - \textcolor{blue}{\mathbb{E}_{p(\theta)} [ \log p(\theta + d) ]} \\
&\approx \mathbb{E}_{p(\theta)} [ \log p(\theta) ] - \textcolor{blue}{\underbrace{\mathbb{E}_{p(\theta)} \left [ \log p(\theta) + \nabla \log p(\theta)^T d + \frac{1}{2} d^T \nabla^2 \log p(\theta) d \right ]}_{\text{second order Taylor expansion}}} \\
&= - \mathbb{E}_{p(\theta)} [ \nabla \log p(\theta)^T d ] - \mathbb{E}_{p(\theta)} \left [ \frac{1}{2} d^T \nabla^2 \log p(\theta) d \right ] \\
&= - \left [ \int p(\theta) \frac{1}{p(\theta)} \nabla p(\theta) dx \right ]^T d - \frac{1}{2} d^T \mathbb{E}_{p(\theta)} [ \nabla^2 \log p(\theta) ] d \\
&= - \underbrace{\left [ \nabla \int p(\theta) dx \right ]^T}_{= \nabla 1 = 0} d - \frac{1}{2} d^T \mathbb{E}_{p(\theta)} [ \nabla^2 \log p(\theta) ] d \\
&= - \frac{1}{2} d^T \underbrace{\mathbb{E}_{p(\theta)} [ \nabla^2 \log p(\theta) ]}_{= - F} d \\
&= \frac{1}{2} d^T F d
\end{aligned}
$$

This means that the Fisher Information Matrix $F$ is the Hessian of the KL divergence at $d = 0$: it defines the local curvature of distribution space, where the KL divergence serves as the measure of distance.
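The quality of this second-order approximation can be checked numerically. The sketch below again assumes a Bernoulli model with parameter `theta` (an illustrative choice, not part of the derivation), for which the KL divergence is available in closed form and the Fisher information is $1 / (\theta (1 - \theta))$.

```python
import math

def kl_bernoulli(theta, d):
    """Exact KL( Bernoulli(theta) || Bernoulli(theta + d) )."""
    q = theta + d
    return (theta * math.log(theta / q)
            + (1.0 - theta) * math.log((1.0 - theta) / (1.0 - q)))

def quadratic_approx(theta, d):
    """Second-order approximation (1/2) * d * F * d with F = 1 / (theta * (1 - theta))."""
    fisher = 1.0 / (theta * (1.0 - theta))
    return 0.5 * d * fisher * d

theta = 0.3
for d in (0.1, 0.01, 0.001):
    exact = kl_bernoulli(theta, d)
    approx = quadratic_approx(theta, d)
    print(f"d={d}: exact={exact:.3e}, approx={approx:.3e}, ratio={exact / approx:.4f}")
```

The ratio between the exact value and the approximation moves toward 1 as `d` shrinks, as expected for an approximation that only holds locally.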

Riemannian Manifold


A manifold is a topological space sharing the local properties of Euclidean spaces. Every point on a manifold has a small neighborhood around it that can be locally approximated by a tangent plane, which means that the curvature of this neighborhood is approximately zero. This allows us to use the Euclidean metric, which is based on the properties of flat space, to measure distance within this small neighborhood. For intuition, people on Earth experience their surroundings as flat: the Earth is so large that its curvature is imperceptible at human scale.

More formally: a manifold $M$ of dimension $n$ is a topological space such that every point $x \in M$ has a neighbourhood that is homeomorphic to an open set in the Euclidean space $\mathbb{R}^n$. If the manifold is smooth, each point $x$ also has a tangent space, denoted $T_x M$: the $n$-dimensional vector space of tangent vectors (velocities of curves through $x$).

Riemannian Metric

The distance between two points in a Euclidean space can be easily determined by taking the modulus of the vector connecting the points. However, manifolds are not linear spaces, so we need another way to calculate lengths on a manifold. One approach is to consider a continuously differentiable curve $\gamma: [a, b] \to M$ on the manifold and integrate the length of its tangent vector along the curve:

$$L(\gamma) = \int_a^b \| \gamma'(x) \| dx$$

Thus for each point $\gamma(x)$, we have to define a notion of length in the tangent space at that point. We then use this notion to calculate the modulus $\| \gamma'(x) \|$ of the tangent vector $\gamma'(x) = \frac{d}{dx} \gamma(x)$ at that point, and integrate $\| \gamma'(x) \|$ over $[a, b]$ to obtain the total length of the curve.

To calculate $\| \gamma'(x) \|$:

$$\| \gamma'(x) \|^2 = \langle \gamma'(x), \gamma'(x) \rangle = g_{\gamma(x)} (\gamma'(x), \gamma'(x))$$

In Euclidean space, the square of the modulus of a vector (its Euclidean or $\ell_2$ norm) is equal to the inner product of the vector with itself. On a manifold, the inner product $g_x: T_x M \times T_x M \to \mathbb{R}$ plays the same role, so it defines a metric, which is a way of measuring distance, over a tangent space. A metric that varies smoothly with respect to a point on a manifold is known as a Riemannian metric, and a manifold equipped with a Riemannian metric is called a Riemannian manifold.

As explained here (page 4):

Roughly, a Riemannian manifold $M$ is a smooth set with a smoothly-varying inner product on the tangent spaces.

Formally, for each $x \in M$, a Riemannian metric $g = g_x$ is a bilinear form on $T_x M$ that satisfies:

  • $g(u, v) = g(v, u)$ for all $u, v \in T_x M$
  • $g(u, u) \geq 0$ for all $u \in T_x M$
  • $g(u, u) = 0$ if and only if $u = 0$
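To make the curve-length formula concrete, here is a sketch on the unit sphere, echoing the Earth analogy above. The sphere, its coordinates (colatitude, longitude), and the metric $g = \mathrm{diag}(1, \sin^2(\text{colatitude}))$ are assumptions introduced for this example; the curve is a circle of constant colatitude $\pi / 3$, whose length is known to be $2 \pi \sin(\pi / 3)$.

```python
import math

def sphere_metric(point, u, v):
    """Inner product g_point(u, v) on the unit sphere in (colatitude, longitude) coordinates."""
    colatitude, _longitude = point
    return u[0] * v[0] + math.sin(colatitude) ** 2 * u[1] * v[1]

def curve_length(gamma, a, b, metric, n=10_000):
    """Approximate the integral of ||gamma'(x)|| over [a, b].

    Uses the midpoint rule for the integral and central differences for gamma'.
    """
    h = (b - a) / n
    total = 0.0
    for i in range(n):
        x = a + (i + 0.5) * h
        ahead, behind = gamma(x + 0.5 * h), gamma(x - 0.5 * h)
        velocity = [(p - q) / h for p, q in zip(ahead, behind)]
        total += math.sqrt(metric(gamma(x), velocity, velocity)) * h
    return total

def latitude_circle(x, colatitude=math.pi / 3):
    """Curve that keeps the colatitude fixed and sweeps the longitude x."""
    return (colatitude, x)

length = curve_length(latitude_circle, 0.0, 2.0 * math.pi, sphere_metric)
print(length, 2.0 * math.pi * math.sin(math.pi / 3))
```

With the Euclidean metric on the coordinates, the same curve would have length $2 \pi$ regardless of the colatitude; the Riemannian metric is what makes circles near the pole shorter.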

Natural Gradient Descent

Constrained Optimization

Let’s start by looking at natural gradient descent from the perspective of constrained optimization. In traditional gradient descent, the constrained optimization problem we want to solve is:

$$\begin{gathered} \min_{d} L(\theta + d) \\ \text{s.t.} \enspace \| d \| \leq \epsilon \end{gathered}$$

where the distance in parameter space is constrained and is measured with the Euclidean metric.

As mentioned before, the steepest descent direction is the direction of the negative gradient:

$$- \frac{\nabla L(\theta)}{\| \nabla L(\theta) \|} = \lim_{\epsilon \to 0} \frac{1}{\epsilon} \arg \min_{d \text{ s.t. } \| d \| \leq \epsilon} L(\theta + d)$$

When the parameters are updated by gradient descent, the distribution $p(x \mid \theta)$ that they define changes as well. It is important to control the amount of change in this distribution, as large changes can make the model unstable. The Euclidean distance between parameters, which is based on the properties of flat space, may not be an appropriate measure of how much the distribution changes. Therefore, natural gradient descent works in distribution space and uses the KL divergence to compare the current distribution $p(x \mid \theta)$ with the distribution after the update, $p(x \mid \theta + d)$. Now the constraint becomes:

$$\text{s.t.} \enspace KL \Big ( p(x \mid \theta) \| p(x \mid \theta + d) \Big ) \leq \epsilon$$

We apply the Lagrange multiplier method to it:

$$
\begin{aligned}
d^* &= \arg \min_{d} L(\theta + d) + \lambda \Big ( KL \big ( p(x \mid \theta) \| p(x \mid \theta + d) \big ) - \epsilon \Big ) \\
&\approx \arg \min_{d} \underbrace{L(\theta) + \nabla L(\theta)^T d}_{\text{first order Taylor expansion}} + \underbrace{\frac{1}{2} \lambda d^T F d}_{\text{second order Taylor expansion}} - \lambda \epsilon
\end{aligned}
$$

To solve this minimization, we set its derivative to zero:

$$
\begin{aligned}
0 &= \frac{\partial}{\partial d} \left [ L(\theta) + \nabla L(\theta)^T d + \frac{1}{2} \lambda d^T F d - \lambda \epsilon \right ] \\
&= \nabla L(\theta) + \lambda F d \\
\lambda F d &= -\nabla L(\theta) \\
d &= - \frac{1}{\lambda} F^{-1} \nabla L(\theta)
\end{aligned}
$$

Finally, the optimal descent direction is $-\tilde{\nabla} L(\theta)$, where $\tilde{\nabla} L(\theta) = F^{-1} \nabla L(\theta)$ is called the natural gradient (the constant factor $\frac{1}{\lambda}$ can be absorbed into the learning rate).
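If the multiplier is wanted explicitly, it can be recovered under two assumptions that are not stated above: the KL constraint is active at the optimum, and $F$ is symmetric positive definite. As a sketch, substituting $d = - \frac{1}{\lambda} F^{-1} \nabla L(\theta)$ into the quadratic approximation of the constraint gives:

$$
\begin{aligned}
\epsilon &= \frac{1}{2} d^T F d = \frac{1}{2 \lambda^2} \nabla L(\theta)^T F^{-1} F F^{-1} \nabla L(\theta) = \frac{1}{2 \lambda^2} \nabla L(\theta)^T F^{-1} \nabla L(\theta) \\
\lambda &= \sqrt{\frac{\nabla L(\theta)^T F^{-1} \nabla L(\theta)}{2 \epsilon}} \\
d^* &= - \sqrt{\frac{2 \epsilon}{\nabla L(\theta)^T F^{-1} \nabla L(\theta)}} \, F^{-1} \nabla L(\theta)
\end{aligned}
$$

Under these assumptions the step always points along $-F^{-1} \nabla L(\theta)$, and $\epsilon$ only rescales its length.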

Riemannian Manifold

Now we consider optimizing the objective function on a manifold. The descent direction depends on how $\|d\|$ is calculated, i.e., on which metric the manifold carries.

Traditional gradient descent works in parameter space (a Euclidean space), whose metric is the Euclidean metric. Natural gradient descent operates in distribution space (a Riemannian manifold), in which each point, represented by $p(x \mid \theta)$, is a parameterized probability distribution. When the parameter $\theta$ changes to $\theta + d$, the distance between $p(x \mid \theta)$ and $p(x \mid \theta + d)$, measured by the KL divergence, is approximately $\frac{1}{2} d^T F d$. It can be seen that the Fisher information matrix $F$ serves as the Riemannian metric on this Riemannian manifold.

The Fisher information matrix reflects the local curvature of the likelihood probability distribution space, which means that it encodes information about how much the probability distribution changes as the parameters vary. By using the Fisher information matrix as a measure, we can obtain the descent direction that takes into account the curvature of the probability distribution space. This is important because the larger the curvature, the smaller the range of parameter values that can be used to maintain a given likelihood. Natural gradient descent ensures that the optimization process respects the underlying geometry of the probability distribution space and avoids making large, unstable changes to the parameters.
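The effect can be seen on a toy problem. The sketch below is an illustration under stated assumptions, not a general-purpose optimizer: the model is a univariate Gaussian with parameters $(\mu, \sigma)$, the loss is the expected negative log-likelihood under data from a hypothetical target $N(\mu^*, \sigma^{*2})$, and the Fisher matrix is the known closed form $F = \mathrm{diag}(1 / \sigma^2, 2 / \sigma^2)$ for this parameterization. The function names and the starting point are invented for the example.

```python
def grad_loss(mu, sigma, mu_star, sigma_star):
    """Gradient of the expected negative log-likelihood of N(mu, sigma^2)
    under data drawn from N(mu_star, sigma_star^2)."""
    second_moment = sigma_star ** 2 + (mu_star - mu) ** 2
    d_mu = (mu - mu_star) / sigma ** 2
    d_sigma = 1.0 / sigma - second_moment / sigma ** 3
    return d_mu, d_sigma

def fisher_inverse_diag(sigma):
    """Diagonal of F^{-1}, where F = diag(1 / sigma^2, 2 / sigma^2)."""
    return sigma ** 2, sigma ** 2 / 2.0

def descend(mu, sigma, mu_star, sigma_star, eta=0.1, steps=300, natural=True):
    """Run plain or natural gradient descent on (mu, sigma)."""
    for _ in range(steps):
        d_mu, d_sigma = grad_loss(mu, sigma, mu_star, sigma_star)
        if natural:
            inv_mu, inv_sigma = fisher_inverse_diag(sigma)
            # Natural gradient: precondition the gradient with F^{-1}.
            d_mu, d_sigma = inv_mu * d_mu, inv_sigma * d_sigma
        mu -= eta * d_mu
        sigma -= eta * d_sigma
    return mu, sigma

# Start from a very narrow Gaussian, where the distribution is highly
# sensitive to small changes of the parameters.
mu_nat, sigma_nat = descend(0.0, 0.1, mu_star=2.0, sigma_star=1.5, natural=True)
mu_gd, sigma_gd = descend(0.0, 0.1, mu_star=2.0, sigma_star=1.5, natural=False)
print("natural gradient:", mu_nat, sigma_nat)
print("plain gradient:  ", mu_gd, sigma_gd)
```

With this learning rate the plain gradient takes one enormous first step, because the raw gradient scales like $1 / \sigma^2$ and $1 / \sigma^3$ when $\sigma$ is small, and then barely moves. The natural gradient rescales the step by $F^{-1}$ and approaches the target parameters. A smaller learning rate would also tame plain gradient descent here; the point is only that the Fisher matrix supplies that scaling automatically.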


Last updated: 3/7/2023, 2:57:21 PM