深度学习-LogSumExp技巧

引言

今天来学习下 LogSumExp(LSE)技巧,主要解决计算 Softmax 或 CrossEntropy时出现的上溢 (overflow) 或下溢 (underflow) 问题。

我们知道编程语言中的数值都有一个表示范围的,如果数值过大,超过最大的范围,就是上溢;如果过小,超过最小的范围,就是下溢。

什么是 LSE

LSE 被定义为参数指数之和的对数:

logSumExp(x1xn)=log(i=1nexi)=log(i=1nexibeb)=log(ebi=1nexib)=log(i=1nexib)+log(eb)=log(i=1nexib)+b\begin{aligned}\operatorname{logSumExp}\left(x_{1} \ldots x_{n}\right) &=\log \left(\sum_{i=1}^{n} e^{x_{i}}\right) \\ &=\log \left(\sum_{i=1}^{n} e^{x_{i}-b} e^{b}\right) \\ &=\log \left(e^{b} \sum_{i=1}^{n} e^{x_{i}-b}\right) \\ &=\log \left(\sum_{i=1}^{n} e^{x_{i}-b}\right)+\log \left(e^{b}\right) \\ &=\log \left(\sum_{i=1}^{n} e^{x_{i}-b}\right)+b \end{aligned}

其中, b=maxi=1nxib = \max_{i=1}^n x_i

输入可以看成是一个 n 维的向量,输出是一个标量。

为什么需要 LSE

假设我们有 NN 个值的数据集 {xn}n=1N\{x_n\}_{n=1}^N,我们想要求 z=logn=1Nexp{xn}z=\log \sum_{n=1}^{N} \exp \left\{x_{n}\right\} 的值,应该如何计算?

看上去这个问题可能比较奇怪,但是实际上我们在神经网络中经常能碰到这个问题。

在神经网络中,假设我们的最后一层是使用 softmax 去得到一个概率分布,它的公式应该很熟悉了吧

Softmax(xi,x1xn)=exij=1nexj(1)\text{Softmax}(x_i\,, x_{1} \ldots x_{n}) = \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}} \tag{1}

这里的 xjx_j 是其中的一个值。最终 loss 函数如果使用 cross entropy,那么就涉及到需要对该式求 log\log ,也就是

log(exji=1nexi)=log(exj)log(i=1nexi)=xjlog(i=1nexi)\begin{aligned} \log \left(\frac{e^{x_{j}}}{\sum_{i=1}^{n} e^{x_{i}}}\right) &=\log \left(e^{x_{j}}\right)-\log \left(\sum_{i=1}^{n} e^{x_{i}}\right) \\ &=x_{j}-\log \left(\sum_{i=1}^{n} e^{x_{i}}\right) \end{aligned}

这里的减号后面的部分,也就是我们上面所要求的 zz ,即 LogSumExp(之后简称 LSE)。

但是 Softmax 存在上溢和下溢的问题。如果 xix_i 太大,对应的指数函数也非常大,此时很容易就溢出,得到nan结果;如果 xix_i 太小,或者说负的太多,就会导致出现下溢而变成 0,如果分母变成 0,就会出现除 0 的结果。

此时我们经常看到一个常见的做法是 (其实用到的是指数归一化技巧, Exp-normalize),先计算 xx 中的最大值 b=maxi=1nxib = \max_{i=1}^n x_i,然后根据

Softmax(xi,x1xn)=exij=1nexj=exibebj=1n(exjbeb)=exibebebj=1nexjb=exibj=1nexjb=Softmax(xib,x1bxnb)(2)\begin{aligned} \text{Softmax}(x_i\,, x_{1} \ldots x_{n}) &= \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}} \\ &= \frac{e^{x_i-b} \cdot e^b}{\sum_{j=1}^n \left (e^{x_j-b} \cdot e^b \right)} \\ &= \frac{e^{x_i-b} \cdot e^b}{ e^b \cdot \sum_{j=1}^n e^{x_j-b} } \\ &= \frac{e^{x_i-b}}{\sum_{j=1}^n e^{x_j-b}} \\ &= \text{Softmax}(x_i - b, x_{1}-b \ldots x_{n}-b) \end{aligned}\tag{2}

这种转换是等价的,经过这一变换,就避免了上溢,最大值变成了exp(0)=1\exp(0)=1;同时分母中也会有一个 1,就避免了下溢。

我们通过实例来理解一下。

1
2
3
4
5
6
def bad_softmax(x):
y = np.exp(x)
return y / y.sum()

x = np.array([1, -10, 1000])
print(bad_softmax(x))
1
2
3
... RuntimeWarning: overflow encountered in exp
... RuntimeWarning: invalid value encountered in true_divide
array([ 0., 0., nan])

接下来进行上面的优化,并进行测试:

1
2
3
4
5
6
def softmax(x):
b = x.max()
y = np.exp(x - b)
return y / y.sum()

print(softmax(x))
1
array([0., 0., 1.])

我们再看下是否会出现下溢:

1
2
3
4
5
x = np.array([-800, -1000, -1000])
print(bad_softmax(x))
# array([nan, nan, nan])
print(softmax(x))
# array([1.00000000e+00, 3.72007598e-44, 3.72007598e-44])

嗯,看来解决了这个两个问题。等等,不是说 LSE 吗,怎么整了个什么归一化技巧。

好吧,回到 LSE。

我们对 Softmax 取对数,得到:

log(Softmax(xi))=logexij=1nexj=xilogj=1nexj(3)\begin{aligned} \log \left( \text{Softmax}(x_i) \right) &= \log \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}} \\ &= x_i - \log \sum_{j=1}^n e^{x_j} \\ \end{aligned}\tag{3}

因为上面最后一项也有上溢的问题,所以应用同样的技巧,得

logj=1nexj=logj=1nexjbeb=b+logj=1nexjb(4)\log \sum_{j=1}^n e^{x_j} = \log \sum_{j=1}^n e^{x_j-b} e^b = b + \log \sum_{j=1}^n e^{x_j-b} \tag{4}

bb 同样是取 xx 中的最大值。

那么:

log(Softmax(xi))=logexij=1nexj=xilogj=1nexj=xilogj=1nexjbb\begin{aligned} \log \left( \text{Softmax}(x_i) \right) &= \log \frac{e^{x_i}}{\sum_{j=1}^n e^{x_j}} \\ &= x_i - \log \sum_{j=1}^n e^{x_j} \\ &= x_i - \log \sum_{j=1}^n e^{x_j-b} - b \end{aligned}

这样,我们就得到了 LSE 的最终表示:

LSE(x1xn)=b+logj=1nexjb(5)\text{LSE}( x_{1} \ldots x_{n}) = b + \log \sum_{j=1}^n e^{x_j-b} \tag{5}

此时,Softmax 也可以这样表示:

Softmax(xi)=exp(xiblogj=1nexjb)(6)\text{Softmax}(x_i) = \exp \left( x_i - b - \log \sum_{j=1}^n e^{x_j-b} \right) \tag{6}

对 LogSumExp 求导就得到了 exp-normalize(Softmax) 的形式,

(b+logj=1nexjb)xj=exibj=1nexjb(7)\frac{\partial \left (b + \log \sum_{j=1}^n e^{x_j-b} \right )}{\partial x_j} = \frac{e^{x_i - b}}{\sum_{j=1}^n e^{x_j-b}} \tag{7}

那我们是使用 exp-normalize 还是使用 LogSumExp 呢?

如果你需要保留 Log 空间,那么就计算 log(Softmax)\log(\text{Softmax}) ,此时使用 LogSumExp 技巧;如果你只需要计算 Softmax,那么就使用 exp-normalize 技巧。

性质

LSE 函数是凸函数,且在域内严格单调递增,但是其并非处处严格凸(摘自维基百科

严格凸的 LSE 应该是 LSE0+(x1,,xn)=LSE(0,x1,,xn)L S E_{0}^{+}\left(x_{1}, \ldots, x_{n}\right)=L S E\left(0, x_{1}, \ldots, x_{n}\right)

首先,如果我们使用线性近似 的方法,依据定义 f(x)f(c)+f(c)(xc)f(x)\approx f(c)+f'(c)(x-c)

那么对于 f(x)=log(x+a)f(x)=\log(x+a) 来说,令 c=xac=x-a ,可得

f(x)=log(x+a)f(xa)+f(xa)a=logx+1xaf(x)=\log(x+a)\approx f(x-a)+f'(x-a)\cdot a=\log x+\frac{1}{x}\cdot a

因此 log(x+a)logx+ax\log(x+a)\approx \log x + \frac{a}{x}

那么 LSE(x)=log(iexi)log(exj)+(ijexi)/exj=xj+(ijexi)/exjLSE(x)=\log(\sum_i e^{x_i} )\approx \log \left(e^{x_j} \right)+\left(\sum_{i \neq j} e^{x_i}\right) / e^{x_j}={x_j}+\left(\sum_{i \neq j} e^{x_i}\right) / e^{x_j}

在维基百科中,该式直接 xj=maxixi\approx x_j=\max_i x_i 了。

由于 ixinmaxixi\sum_i x_i\leq n\cdot \max_i x_i ,且对于正数来说 ixixi\sum_i x_i\geq x_i ,因此

max{x1,xn}=log(emaxxi)log(ex1++exn)log(nemaxxi)=max{x1,,xn}+log(n)\begin{aligned} \max \left\{x_{1} \ldots, x_{n}\right\} &=\log (e^{\max x_{i}}) \\ & \leq \log (e^{x_1}+\cdots+e^{x_n}) \\ & \leq \log (n \cdot e^{\max x_{i}}) \\ &=\max \left\{x_{1}, \ldots, x_{n}\right\}+\log (n) \end{aligned}

max{x1,,xn}<LSE(x1,,xn)max{x1,,xn}+log(n)\max \left\{x_{1}, \ldots, x_{n}\right\}<\operatorname{LSE}\left(x_{1}, \ldots, x_{n}\right) \leq \max \left\{x_{1}, \ldots, x_{n}\right\}+\log (n)

因此可以说 LSE(x1,...,xn)maxixiLSE(x_1,...,x_n)\approx \max_i x_i ,所以它实际上是针对 max 函数的一种平滑操作,从字面上理解来说,LSE 函数才是真正意义上的 softmax 函数。而我们在神经网络里所说的 softmax 函数其实是近似于 argmax 函数的,也就是我们平时所用的 softmax 应该叫做 softargmax。

怎么实现 LSE

实现 LSE 就很简单了,我们通过代码实现一下。

1
2
3
4
5
6
def logsumexp(x):
b = x.max()
return b + np.log(np.sum(np.exp(x - b)))

def softmax_lse(x):
return np.exp(x - logsumexp(x))

上面是基于 LSE 实现了 Softmax,下面测试一下:

1
2
3
4
5
6
7
8
9
10
> x1 = np.array([1, -10, 1000])
> x2 = np.array([-900, -1000, -1000])
> softmax_lse(x1)
array([0., 0., 1.])
> softmax(x1)
array([0., 0., 1.])
> softmax_lse(x2)
array([1.00000000e+00, 3.72007598e-44, 3.72007598e-44])
> softmax(x2)
> array([1.00000000e+00, 3.72007598e-44, 3.72007598e-44])

最后我们看一下数值稳定版的 Sigmoid 函数

数值稳定的 Sigmoid 函数

我们知道 Sigmoid 函数公式为:

σ(x)=11+ex(8)\sigma(x) = \frac{1}{1 + e^{-x}} \tag{8}

从上图可以看出,如果 xx 很大, exe^x 会非常大,而很小就没事,变成无限接近 0。

当 Sigmoid 函数中的 xx 负的特别多,那么 exp(x)\exp(-x) 就会变成 \infty ,就出现了上溢;

那么如何解决这个问题呢?σ(x)\sigma(x) 可以表示成两种形式:

σ(x)=11+exp(x)=ex1+ex(9)\sigma(x) = \frac{1}{1 + \exp(-x)} = \frac{e^x}{1 + e^x} \tag{9}

x0x \geq 0 时,我们根据 exe^{x} 的图像,我们取 11+ex\frac{1}{1 + e^{-x}} 的形式;

x<0x < 0 时,我们取 ex1+ex\frac{e^x}{1 + e^x}

1
2
3
4
5
6
7
8
9
10
# 原来的做法
def sigmoid_naive(x):
return 1 / (1 + math.exp(-x))

# 优化后的做法
def sigmoid(x):
if x < 0:
return math.exp(x) / (1 + math.exp(x))
else:
return 1 / (1 + math.exp(-x))

然后用不同的数值进行测试:

1
2
3
4
5
6
7
8
> sigmoid_naive(2000)
1.0
> sigmoid(2000)
1.0
> sigmoid_naive(-2000)
OverflowError: math range error
> sigmoid(-2000)
0.0

References

一文弄懂LogSumExp技巧

https://zhuanlan.zhihu.com/p/153535799