前言

本来是想看模型uncertainty的,也不知怎么就回溯到了这里,建议提前看下极大似然估计,最大后验估计,贝叶斯公式

线性回归

这里以一个多维的特征向量举例:假设输入样本为x\bm{x}x,模型的输出为不同参数在该样本上的线性组合f(x)=wTxf(\bm{x})=w^T\bm{x}f(x)=wTx,样本的标签为y=f(x)+ϵ\bm{y}=f(\bm{x}) + \epsilony=f(x)+ϵ,其中ϵ∼N(0,σ2)\epsilon \sim N(0, \sigma^2)ϵN(0,σ2)
首先有一种很朴素的方法就是使用最小二乘法来求解,即对真实值和预测值的函数进行求导,找到极值点最小情况下对应的www,这种方法从贝叶斯的角度考虑就是计算
wMLE=argmaxxlog⁡P(D∣w)w_{MLE}=\mathop{argmax} \limits_x\log P(D|w)wMLE=xargmaxlogP(Dw),相当于极大似然估计。
为了防止过拟合,后面又有人提出了带有正则化的最小二乘估计,这种方式实际上就是最大后验估计:
wMAP=argmaxxlog⁡P(D∣w)P(w)w_{MAP}=\mathop{argmax} \limits_x\log P(D|w)P(w)wMAP=xargmaxlogP(Dw)P(w),即认为www不是可以无限取值的,而是服从一种先验分布,关于最小二乘估计与贝叶斯的关系会单独写一节。我们知道,贝叶斯学派喜欢从已知数据推导参数,即求解P(w∣D)P(w|D)P(wD),并且还不是求解具体的www是多少,而是计算已知数据的情况下模型参数www应该对应什么样的后验分布。

贝叶斯推断

根据贝叶斯公式展开:
P(w∣D)=P(D∣w)P(w)P(D) P(w|D)=\frac{P(D|w)P(w)}{P(D)} P(wD)=P(D)P(Dw)P(w)
其中P(D)=P(Y∣X)=∫P(Y∣w,X)P(w∣W)dwP(D)=P(Y|X)=\int P(Y|w,X)P(w|W) dwP(D)=P(YX)=P(Yw,X)P(wW)dw,这是一个固定值,所以可以得到下面的计算:
P(w∣D)∝P(D∣w)P(w) P(w|D) \varpropto P(D|w)P(w) P(wD)P(Dw)P(w)
我们之前定义的真实值y\bm{y}yx\bm{x}x是一种线性高斯模型,所以得到P(D∣w)P(D|w)P(Dw)的表示为:
P(D∣w)=P(Y∣w,X)=∏i=1NP(yi∣w,xi)=∏i=1NP(yi∣wTxi,σ2) P(\bm{D}|w)=P(\bm{Y}|w,\bm{X})=\prod_{i=1}^N P(\bm{y_i}|w, \bm{x_i})=\prod_{i=1}^N P(\bm{y_i}|w^T \bm{x_i},\sigma^2) P(Dw)=P(Yw,X)=i=1NP(yiw,xi)=i=1NP(yiwTxi,σ2)
解释一下上面的公式:

- 为什么数据集的后验概率是对应多个样本得到的后验概率的乘积?
贝叶斯线性估计有一个前提:条件独立,即在相同的www下由不同的样本xix_ixi得到的输出yiy_iyi的分布是相互独立的。多元高斯分布的联合概率密度在所有变量互相独立的前提下等于各个变量的概率密度函数的乘积。高斯过程建模取消了这种假设,这里不做讨论。
我们要计算P(w∣D)P(w|D)P(wD)还需要P(w)P(w)P(w),一般假设其服从高斯分布,所以这样后面两项就都可以计算了,原式变为:
P(w∣D)∝∏i=1NP(yi∣wTxi,ϵ)⋅N(0,σ2) P(w|D) \varpropto \prod_{i=1}^N P(\bm{y_i}|w^T \bm{x_i},\epsilon) \cdot N(0, \sigma^2) P(wD)i=1NP(yiwTxi,ϵ)N(0,σ2)
后验概率P(w∣D)P(w|D)P(wD)也是一个高斯分布,这个是通过高斯分布的共轭性质推导的,这里不详细展开,只要明确这一点就行,既然已经知道它是高斯分布,那么我们只需要知道它的期望和方差就获得了整个分布的表达式。对上式进行展开:
∏i=1NP(yi∣w,xi)=∏i=1N12πσe−(yi−wTxi)22σ2=1(2π)N2σNe−12σ2∑i=1N(yi−wTxi)2 \prod_{i=1}^N P(\bm{y_i}|w, \bm{x_i}) \\ = \prod_{i=1}^N\frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(y_i-w^Tx_i)^2}{2\sigma^2}} \\ =\frac{1}{(2\pi)^{\frac{N}{2}}\sigma^N}e^{-\frac{1}{2\sigma^2}\sum_{i=1}^N(y_i-w^Tx_i)^2} i=1NP(yiw,xi)=i=1N2π σ1e2σ2(yiwTxi)2=(2π)2NσN1e2σ21i=1N(yiwTxi)2
将上式进一步整理,可得:
∑i=1N(yi−wTxi)2=(y1−wTx1,y2−wTx2,...,yN−wTxN)(y1−wTx1y2−wTx2...yN−wTxN)=(YT−wTXT)(Y−wTX)=(Y−Xw)T(Y−Xw) \sum_{i=1}^N(y_i-w^Tx_i)^2=(y_1-w^Tx_1, y_2-w^Tx_2, ..., y_N-w^Tx_N)\left( \begin{array}{cc} y_1-w^Tx_1 \\ y_2-w^Tx_2 \\ ... \\ y_N-w^Tx_N \end{array}\right)\\ =(Y^T-w^TX^T)(Y-w^TX)\\=(Y-Xw)^T(Y-Xw) i=1N(yiwTxi)2=(y1wTx1,y2wTx2,...,yNwTxN)y1wTx1y2wTx2...yNwTxN=(YTwTXT)(YwTX)=(YXw)T(YXw)
此时:
∏i=1NP(yi∣w,xi)=1(2π)N2σNe−12σ2∑i=1N(yi−wTxi)2=1(2π)N2σNe−12σ2(Y−Xw)T(Y−Xw)=1(2π)N2σNe−12(Y−Xw)Tσ−2I(Y−Xw)∼N(Xw,σ−2I) \prod_{i=1}^N P(\bm{y_i}|w, \bm{x_i})\\=\frac{1}{(2\pi)^{\frac{N}{2}}\sigma^N}e^{-\frac{1}{2\sigma^2}\sum_{i=1}^N(y_i-w^Tx_i)^2}\\=\frac{1}{(2\pi)^{\frac{N}{2}}\sigma^N}e^{-\frac{1}{2\sigma^2}(Y-Xw)^T(Y-Xw)}\\=\frac{1}{(2\pi)^{\frac{N}{2}}\sigma^N}e^-\frac{1}{2}(Y-Xw)^T\sigma^{-2}I(Y-Xw)\\\sim N(Xw,\sigma^{-2}I) i=1NP(yiw,xi)=(2π)2NσN1e2σ21i=1N(yiwTxi)2=(2π)2NσN1e2σ21(YXw)T(YXw)=(2π)2NσN1e21(YXw)Tσ2I(YXw)N(Xw,σ2I)
将计算得到的∏i=1NP(yi∣w,xi)\prod_{i=1}^N P(\bm{y_i}|w, \bm{x_i})i=1NP(yiw,xi)P(w)P(w)P(w)带入到目标的求解中:
P(w∣D)∝∏i=1NP(yi∣wTxi,ϵ)⋅N(0,σ2)∝e−12(Y−Xw)Tσ−2I(Y−Xw)⋅e−12wTΣpw∝e−12σ2(YTY−2YTXw+wXTXw)−12wTΣpw P(w|D) \varpropto \prod_{i=1}^N P(\bm{y_i}|w^T \bm{x_i},\epsilon) \cdot N(0, \sigma^2)\\ \varpropto e^{-\frac{1}{2}(Y-Xw)^T\sigma^{-2}I(Y-Xw)} \cdot e^{-\frac{1}{2}w^T\Sigma_{p}w}\\ \varpropto e^{-\frac{1}{2\sigma^2}(Y^TY-2Y^TXw+wX^TXw)-\frac{1}{2}w^T\Sigma_{p}w} P(wD)i=1NP(yiwTxi,ϵ)N(0,σ2)e21(YXw)Tσ2I(YXw)e21wTΣpwe2σ21(YTY2YTXw+wXTXw)21wTΣpw
上式中的∑p\sum_{p}p是一个以σ2\sigma^2σ2为主对角线元素的单位矩阵,反映的是噪声在样本集上的表现,对于多元高斯分布来说是一个协方差矩阵。
现在我们得到了P(w∣D)P(w|D)P(wD)的表示形式,但是不能立即看出这个分布的均值和方差是多少,所以我们需要进一步求解,首先,用一个多元高斯分布的公式展开,写出均值和方差的通用表示形式。一个多元高斯分布的指数部分展开为:
e−12(X−μ)TΣ−1(X−μ)=−12(XTΣ−1X−2μTΣ−1X+μTΣ−1μ) e^{-\frac{1}{2}(X-\mu)^T\Sigma^{-1}(X-\mu)}\\=-\frac{1}{2}(X^T \Sigma^{-1} X-2\mu^T\Sigma^{-1}X+\mu^T\Sigma^{-1}\mu ) e21(Xμ)TΣ1(Xμ)=21(XTΣ1X2μTΣ1X+μTΣ1μ)
多元高斯分布是一个关于XXX的函数,我们的目标函数是一个关于www的函数,所以我们需要把上式和前面的一次项,二次项分别对应起来,即:
−12σ2wXTXw−12wTΣpw=−12wT(σ−2XTX+Σp−1)w⇔−12(XTΣ−1X)−12σ2(−2YTXw)=σ−2YTXw⇔μTΣ−1X -\frac{1}{2\sigma^2}wX^TXw-\frac{1}{2}w^T\Sigma_{p}w \\=-\frac{1}{2}w^T(\sigma^{-2}X^TX+\Sigma_p^{-1})w \Leftrightarrow -\frac{1}{2}(X^T \Sigma^{-1} X) \\ -\frac{1}{2\sigma^2}(-2Y^TXw)\\ =\sigma^{-2}Y^TXw \Leftrightarrow \mu^T\Sigma^{-1}X 2σ21wXTXw21wTΣpw=21wT(σ2XTX+Σp1)w21(XTΣ1X)2σ21(2YTXw)=σ2YTXwμTΣ1X
通过第一个对照可以求解后验分布的协方差为:
Σw−1=σ−2XTX+Σp−1Σw=(σ−2XTX+Σp−1)−1 \Sigma_w^{-1}=\sigma^{-2}X^TX+\Sigma_p^{-1}\\ \Sigma_w = (\sigma^{-2}X^TX+\Sigma_p^{-1})^{-1} Σw1=σ2XTX+Σp1Σw=(σ2XTX+Σp1)1
将计算的协方差带入第二个对照:
σ−2YTX=μTΣw−1 \sigma^{-2}Y^TX=\mu^T\Sigma_w^{-1} σ2YTX=μTΣw1
计算可得均值为:
μw=σ−2ΣwYTX \mu_w = \sigma^{-2}\Sigma_wY^TX μw=σ2ΣwYTX
这样我们就通过现有的已知量得到了后验概率分布的表达式了。

如何使用模型做预测

我们得到了参数www的分布,如何进一步来预测未知数据x∗x^*x的label呢?
首先对于数据x∗x^*x,有f(x∗)=wTx∗f(x^*)=w^Tx^*f(x)=wTx,而www服从后验分布N∼(μw,Σw)N\sim (\mu_w, \Sigma_w)N(μw,Σw),根据高斯分布的性质,f(x∗)f(x^*)f(x)应该服从N∼((x∗)Tμw,(x∗)TΣwx∗)N\sim((x^*)^T\mu_w,(x^*)^T\Sigma_wx^*)N((x)Tμw,(x)TΣwx),另外,考虑到数据的噪声ϵ\epsilonϵ,相应的y∗y^*y应该服从的高斯分布的形式为:
P(y∗∣x∗,D)=N((x∗)Tμw,(x∗)TΣwx∗+σ2) P(y^*|x^*,D) = N((x^*)^T\mu_w,(x^*)^T\Sigma_wx^*+\sigma^2) P(yx,D)=N((x)Tμw,(x)TΣwx+σ2)
实际做预测的时候,一般是对上面的分布求期望,也可以理解为求极值对应的横坐标值,因为在高斯分布下极值点对应的横坐标点就是期望值。

另外补充一句

如果上式在预测的时候。后验概率P(y∗∣x∗,D)P(y^*|x^*,D)P(yx,D)的方差比较大的话,我们就可以理解为模型对这个样本的预测把握程度并不大,因为他已经在一定范围内左右摇摆了,这可能是一种模型uncertainty的最初体现形式吧。

Logo

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。

更多推荐