机器学习中的矩阵向量求导 - 矩阵向量求导链式法则
本文我们讨论矩阵向量求导链式法则,使用该法则很多时候可以帮我们快速求出导数结果。本文的标量对向量的求导,标量对矩阵的求导使用分母布局, 向量对向量的求导使用分子布局。如果遇到其他资料求导结果不同,请先确认布局是否一样。1. 向量对向量求导的链式法则首先我们来看看向量对向量求导的链式法则。假设多个向量存在依赖关系,比如三个向量x→y→z存在依赖关系,则我们有下面的链式求导法则:∂z∂x=∂z∂y∂y
本文我们讨论矩阵向量求导链式法则,使用该法则很多时候可以帮我们快速求出导数结果。
本文的标量对向量的求导,标量对矩阵的求导使用分母布局, 向量对向量的求导使用分子布局。如果遇到其他资料求导结果不同,请先确认布局是否一样。
1. 向量对向量求导的链式法则
首先我们来看看向量对向量求导的链式法则。假设多个向量存在依赖关系,比如三个向量 x → y → z x→y→z x→y→z存在依赖关系,则我们有下面的链式求导法则:
∂ z ∂ x = ∂ z ∂ y ∂ y ∂ x \frac{∂z}{∂x}=\frac{∂z}{∂y}\frac{∂y}{∂x} ∂x∂z=∂y∂z∂x∂y
该法则也可以推广到更多的向量依赖关系。但是要注意的是要求所有有依赖关系的变量都是向量,如果有一个 Y Y Y是矩阵,比如是 x → Y → z x→Y→z x→Y→z, 则上式并不成立。
从矩阵维度相容的角度也很容易理解上面的链式法则,假设 x , y , z x,y,z x,y,z分别是 m , n , p m,n,p m,n,p维向量,则求导结果 ∂ z ∂ x \frac{∂z}{∂x} ∂x∂z是一个 p × m p×m p×m的雅克比矩阵,而右边 ∂ z ∂ y \frac{∂z}{∂y} ∂y∂z是一个 p × n p×n p×n的雅克比矩阵, ∂ y ∂ x \frac{∂y}{∂x} ∂x∂y是一个n×m的矩阵,两个雅克比矩阵的乘积维度刚好是 p × m p×m p×m,和左边相容。
2. 标量对多个向量的链式求导法则
在我们的机器学习算法中,最终要优化的一般是一个标量损失函数,因此最后求导的目标是标量,无法使用上一节的链式求导法则,比如2向量,最后到1标量的依赖关系:x→y→z,此时很容易发现维度不相容。
假设 x , y x,y x,y分别是 m , n m,n m,n维向量, 那么 ∂ z ∂ x \frac{∂z}{∂x} ∂x∂z的求导结果是一个 m × 1 m×1 m×1的向量, 而 ∂ z ∂ y \frac{∂z}{∂y} ∂y∂z是一个 n × 1 n×1 n×1的向量, ∂ y ∂ x \frac{∂y}{∂x} ∂x∂y是一个 n × m n×m n×m的雅克比矩阵,右边的向量和矩阵是没法直接乘的。
但是假如我们把标量求导的部分都做一个转置,那么维度就可以相容了,也就是:
( ∂ z ∂ x ) T = ( ∂ z ∂ y ) T ∂ y ∂ x (\frac{∂z}{∂x})^T=(\frac{∂z}{∂y})^T\frac{∂y}{∂x} (∂x∂z)T=(∂y∂z)T∂x∂y
但是毕竟我们要求导的是 ∂ z ∂ x \frac{∂z}{∂x} ∂x∂z,而不是它的转置,因此两边转置我们可以得到标量对多个向量求导的链式法则:
∂ z ∂ x = ( ∂ y ∂ x ) T ∂ z ∂ y \frac{∂z}{∂x}=(\frac{∂y}{∂x})^T\frac{∂z}{∂y} ∂x∂z=(∂x∂y)T∂y∂z
如果是标量对更多的向量求导,比如 y 1 → y 2 → . . . → y n → z y1→y2→...→yn→z y1→y2→...→yn→z,则其链式求导表达式可以表示为:
∂ z ∂ y 1 = ( ∂ y n ∂ y n − 1 ∂ y n − 1 ∂ y n − 2 . . . ∂ y 2 ∂ y 1 ) T ∂ z ∂ y n \frac{∂z}{∂y_1}=(\frac{∂y_n}{∂y_{n−1}}\frac{∂y_{n−1}}{∂y_{n−2}}...\frac{∂y_2}{∂y_1})^T\frac{∂z}{∂y_n} ∂y1∂z=(∂yn−1∂yn∂yn−2∂yn−1...∂y1∂y2)T∂yn∂z
这里我们给一个最常见的最小二乘法求导的例子。最小二乘法优化的目标是最小化如下损失函数:
l = ( X θ − y ) T ( X θ − y ) l=(Xθ−y)^T(Xθ−y) l=(Xθ−y)T(Xθ−y)
我们优化的损失函数l是一个标量,而模型参数θ是一个向量,期望L对θ求导,并求出导数等于0时候的极值点。我们假设向量 z = X θ − y z=Xθ−y z=Xθ−y, 则 l = z T z l=z^Tz l=zTz, θ → z → l θ→z→l θ→z→l存在链式求导的关系,因此:
∂ l ∂ θ = ( ∂ z ∂ θ ) T ∂ l ∂ z = X T ( 2 z ) = 2 X T ( X θ − y ) \frac{∂l}{∂θ}=(\frac{∂z}{∂θ})^T\frac{∂l}{∂z}=X^T(2z)=2X^T(Xθ−y) ∂θ∂l=(∂θ∂z)T∂z∂l=XT(2z)=2XT(Xθ−y)
其中最后一步转换使用了如下求导公式:
∂ X θ − y ∂ θ = X ∂ z T z ∂ z = 2 z \begin{aligned} \frac{∂Xθ−y}{∂θ}&=X \\ \frac{∂z^Tz}{∂z}&=2z \end{aligned} ∂θ∂Xθ−y∂z∂zTz=X=2z
这两个式子我们在前几篇里已有求解过,现在可以直接拿来使用了,非常方便。
当然上面的问题使用微分法求导数也是非常简单的,这里只是给出链式求导法的思路。
3. 标量对多个矩阵的链式求导法则
下面我们再来看看标量对多个矩阵的链式求导法则,假设有这样的依赖关系:X→Y→z,那么我们有:
∂ z ∂ x i j = ∑ k , l ∂ z ∂ Y k l ∂ Y k l ∂ X i j = t r ( ( ∂ z ∂ Y ) T ∂ Y ∂ X i j ) \frac{∂z}{∂x_{ij}}=\sum_{k,l}{\frac{∂z}{∂Y_{kl}}\frac{∂Y_{kl}}{∂X_{ij}}}=tr((\frac{∂z}{∂Y})^T\frac{∂Y}{∂X_{ij}}) ∂xij∂z=k,l∑∂Ykl∂z∂Xij∂Ykl=tr((∂Y∂z)T∂Xij∂Y)
这里大家会发现我们没有给出基于矩阵整体的链式求导法则,主要原因是矩阵对矩阵的求导是比较复杂的定义,我们目前也未涉及。因此只能给出对矩阵中一个标量的链式求导方法。这个方法并不实用,因为我们并不想每次都基于定义法来求导最后再去排列求导结果。
虽然我们没有全局的标量对矩阵的链式求导法则,但是对于一些线性关系的链式求导,我们还是可以得到一些有用的结论的。
我们来看这个常见问题: A , X , B , Y A,X,B,Y A,X,B,Y都是矩阵, z z z是标量,其中 z = f ( Y ) , Y = A X + B z=f(Y),Y=AX+B z=f(Y),Y=AX+B,我们要求出 ∂ z ∂ X \frac{∂z}{∂X} ∂X∂z,这个问题在机器学习中是很常见的。此时,我们并不能直接整体使用矩阵的链式求导法则,因为矩阵对矩阵的求导结果不好处理。
这里我们回归初心,使用定义法试一试,先使用上面的标量链式求导公式:
∂ z ∂ x i j = ∑ k , l ∂ z ∂ Y k l ∂ Y k l ∂ X i j \frac{∂z}{∂x_{ij}}=\sum_{k,l}{\frac{∂z}{∂Y_{kl}}\frac{∂Y_{kl}}{∂X_{ij}}} ∂xij∂z=k,l∑∂Ykl∂z∂Xij∂Ykl
我们再来看看后半部分的导数:
∂ Y k l ∂ X i j = ∂ ∑ s ( A k s X s l ) ∂ X i j = ∂ A k i X i l ∂ X i j = A k i δ l j \frac{∂Y_{kl}}{∂X_{ij}}={\frac{∂\sum_{s}{(A_{ks}X_{sl}})}{∂X_{ij}}}=\frac{∂A_{ki}X_{il}}{∂X_{ij}}=A_{ki}δ_{lj} ∂Xij∂Ykl=∂Xij∂∑s(AksXsl)=∂Xij∂AkiXil=Akiδlj
其中 δ l j δ_{lj} δlj在 l = j l=j l=j时为1,否则为0.
那么最终的标签链式求导公式转化为:
∂ z ∂ x i j = ∑ k , l ∂ z ∂ Y k l A k i δ l j = ∑ k ∂ z ∂ Y k j A k i \frac{∂z}{∂x_{ij}}=\sum_{k,l}{\frac{∂z}{∂Y_{kl}}A_{ki}δ_{lj}}=\sum_k{\frac{∂z}{∂Y_{kj}}A_{ki}} ∂xij∂z=k,l∑∂Ykl∂zAkiδlj=k∑∂Ykj∂zAki
即矩阵 A T A^T AT的第 i i i行和 ∂ z ∂ Y \frac{∂z}{∂Y} ∂Y∂z的第 j j j列的内积。排列成矩阵即为:
∂ z ∂ X = A T ∂ z ∂ Y \frac{∂z}{∂X}=A^T\frac{∂z}{∂Y} ∂X∂z=AT∂Y∂z
总结下就是:
z = f ( Y ) , Y = A X + B → ∂ z ∂ X = A T ∂ z ∂ Y z=f(Y),Y=AX+B→\frac{∂z}{∂X}=A^T\frac{∂z}{∂Y} z=f(Y),Y=AX+B→∂X∂z=AT∂Y∂z
这结论在 x x x是一个向量的时候也成立,即:
z = f ( y ) , y = A x + b → ∂ z ∂ x = A T ∂ z ∂ y z=f(y),y=Ax+b→\frac{∂z}{∂x}=A^T\frac{∂z}{∂y} z=f(y),y=Ax+b→∂x∂z=AT∂y∂z
如果要求导的自变量在左边,线性变换在右边,也有类似稍有不同的结论如下,证明方法是类似的,这里直接给出结论:
z = f ( Y ) , Y = X A + B → ∂ z ∂ X = ∂ z ∂ Y A T z = f ( y ) , y = X a + b → ∂ z ∂ X = ∂ z ∂ y a T \begin{aligned} z=f(Y),Y&=XA+B→\frac{∂z}{∂X}=\frac{∂z}{∂Y}A^T \\ z=f(y),y&=Xa+b→\frac{∂z}{∂X}=\frac{∂z}{∂y}a^T \end{aligned} z=f(Y),Yz=f(y),y=XA+B→∂X∂z=∂Y∂zAT=Xa+b→∂X∂z=∂y∂zaT
使用好上述四个结论,对于机器学习尤其是深度学习里的求导问题可以非常快的解决,大家可以试一试。
4. 矩阵向量求导小结
矩阵向量求导在前面我们讨论三种方法,定义法,微分法和链式求导法。在同等情况下,优先考虑链式求导法,尤其是第三节的四个结论。其次选择微分法、在没有好的求导方法的时候使用定义法是最后的保底方案。
基本上大家看了系列里这四篇后对矩阵向量求导就已经很熟悉了,对于机器学习中出现的矩阵向量求导问题已足够。这里还没有讲到的是矩阵对矩阵的求导,还有矩阵对向量,向量对矩阵求导这三种形式,这个我们在系列的下一篇,也是最后一篇简单讨论一下,如果大家只是关注机器学习的优化问题,不涉及其他应用数学问题的,可以不关注。
转载自:

DAMO开发者矩阵,由阿里巴巴达摩院和中国互联网协会联合发起,致力于探讨最前沿的技术趋势与应用成果,搭建高质量的交流与分享平台,推动技术创新与产业应用链接,围绕“人工智能与新型计算”构建开放共享的开发者生态。
更多推荐
所有评论(0)