• 机器学习之变分推断(三)基于平均场假设变分推断与广义EM


    引言

    上一节介绍了基于平均场假设 的变分推断推导过程。本节将介绍平均场假设变分推断与广义EM之间的联系

    回顾:基于平均场假设的变分推断

    首先,平均场理论(Mean Theory)是一个物理学的概念,将隐变量在概率图中的状态变量 划分成 M \mathcal M M个组,将整个关于 隐变量的概率分布看作 M \mathcal M M个独立的子概率分布。数学符号表示如下:
    Q ( Z ) = ∏ i = 1 M Q i ( Z ( i ) ) = Q 1 ( Z ( 1 ) ) ⋅ Q 2 ( Z ( 2 ) ) ⋯ Q M ( Z ( M ) ) Q(Z)=Mi=1Qi(Z(i))=Q1(Z(1))Q2(Z(2))QM(Z(M))

    Q(Z)=i=1MQi(Z(i))=Q1(Z(1))Q2(Z(2))QM(Z(M))
    由于平均场假设 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)内部各子概率分布 Q i ( Z ( i ) ) \mathcal Q_{i}(\mathcal Z^{(i)}) Qi(Z(i))之间相互独立,因此,在求解 任一子概率分布 Q j ( Z ( j ) ) ( j ∈ { 1 , 2 , ⋯   , M } ) \mathcal Q_j(\mathcal Z^{(j)})(j \in \{1,2,\cdots,\mathcal M\}) Qj(Z(j))(j{1,2,,M}) 过程中,可以通过固定剩余的 M − 1 \mathcal M - 1 M1项进行求解。令:
    注意:由于只将 Z ( j ) \mathcal Z^{(j)} Z(j)看作变量,因此该期望基于的分布 ∏ i ≠ j M Q i ( Z ( i ) ) \prod_{i \neq j}^{\mathcal M} \mathcal Q_i(\mathcal Z^{(i)}) i=jMQi(Z(i))是已知分布。同理,隐变量 Z = ( Z ( 1 ) , Z ( 2 ) , ⋯   , Z ( M ) ) \mathcal Z = (\mathcal Z^{(1)},\mathcal Z^{(2)},\cdots,\mathcal Z^{(\mathcal M)}) Z=(Z(1),Z(2),,Z(M))中只有 Z ( j ) \mathcal Z^{(j)} Z(j)是变量,其余均是常数。因此,将该期望视作关于 X , Z ( j ) \mathcal X,\mathcal Z^{(j)} X,Z(j)的函数。
    E ∏ i ≠ j M Q i ( Z ( i ) ) [ log ⁡ P ( X , Z ) ] = log ⁡ ϕ ^ ( X , Z ( j ) ) \mathbb E_{\prod_{i \neq j}^{\mathcal M} \mathcal Q_i(\mathcal Z^{(i)})} \left[ \log P(\mathcal X,\mathcal Z)\right] = \log \hat \phi (\mathcal X ,\mathcal Z^{(j)}) Ei=jMQi(Z(i))[logP(X,Z)]=logϕ^(X,Z(j))

    从而求解最优 Q j ^ ( Z ( j ) ) \hat {\mathcal Q_j}(\mathcal Z^{(j)}) Qj^(Z(j))的值:
    Q j ^ ( Z ( j ) ) = arg ⁡ max ⁡ Q j ( Z ( j ) ) L [ Q ( Z ) ] = arg ⁡ max ⁡ Q j ( Z ( j ) ) { − K L [ ϕ ^ ( X , Z ( j ) ) ∣ ∣ Q j ( Z ( j ) ) ] } ^Qj(Z(j))=argmaxQj(Z(j))L[Q(Z)]=argmaxQj(Z(j)){KL[ˆϕ(X,Z(j))||Qj(Z(j))]}

    Qj^(Z(j))=Qj(Z(j))argmaxL[Q(Z)]=Qj(Z(j))argmax{KL[ϕ^(X,Z(j))∣∣Qj(Z(j))]}
    同理,可以尝试求解其他的子概率分布
    Q 1 ^ ( Z ( 1 ) ) , Q 1 ^ ( Z ( 1 ) ) , ⋯   , Q M ^ ( Z ( M ) ) \hat {\mathcal Q_1}(\mathcal Z^{(1)}),\hat {\mathcal Q_1}(\mathcal Z^{(1)}),\cdots, \hat {\mathcal Q_{\mathcal M}}(\mathcal Z^{(\mathcal M)}) Q1^(Z(1)),Q1^(Z(1)),,QM^(Z(M))
    最终,求得最优解 Q ^ ( Z ) \hat {\mathcal Q}(\mathcal Z) Q^(Z)
    Q ^ ( Z ) = ∏ j = 1 M Q ^ j ( Z ( j ) ) \hat {\mathcal Q}(\mathcal Z) = \prod_{j=1}^{\mathcal M}\hat {\mathcal Q}_j(\mathcal Z^{(j)}) Q^(Z)=j=1MQ^j(Z(j))

    深入认识平均场假设

    观察上式,上述的推导过程看似无懈可击,但实际上 存在漏洞
    并不是说 Q ^ ( Z ) = ∏ j = 1 M Q ^ j ( Z ( j ) ) \hat {\mathcal Q}(\mathcal Z) = \prod_{j=1}^{\mathcal M}\hat {\mathcal Q}_j(\mathcal Z^{(j)}) Q^(Z)=j=1MQ^j(Z(j))是错误的,因为该式子是 平均场假设给我们提供的条件。具体漏洞在什么地方?

    如果我们将 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)看成关于 Q 1 ( Z ( 1 ) ) , Q 2 ( Z ( 2 ) ) , ⋯   , Q M ( Z ( M ) ) \mathcal Q_1(\mathcal Z^{(1)}),\mathcal Q_2(\mathcal Z^{(2)}),\cdots,\mathcal Q_{\mathcal M}(\mathcal Z^{(\mathcal M)}) Q1(Z(1)),Q2(Z(2)),,QM(Z(M))的函数。即令:
    将上式展开即可~
    Q ( Z ) = Q 1 ( Z ( 1 ) ) ⋅ Q 2 ( Z ( 2 ) ) ⋯ Q M ( Z ( M ) ) = J ( Q 1 , Q 2 , ⋯   , Q M ) Q(Z)=Q1(Z(1))Q2(Z(2))QM(Z(M))=J(Q1,Q2,,QM)

    Q(Z)=Q1(Z(1))Q2(Z(2))QM(Z(M))=J(Q1,Q2,,QM)
    每一次都固定 M − 1 \mathcal M - 1 M1的变量,只为求出剩余变量的最优结果。那么如果初始隐变量是随机的,即如果每次求解过程中都对随机结果进行固定并求解,那么我们总是得不到一个最优结果
    因此,如何在各自概率分布分别固定的过程中,使 Q ^ ( Z ) \hat {\mathcal Q}(\mathcal Z) Q^(Z)越来越好,最终达到最优
    依然是坐标上升法(Coordinate Ascent)。

    • 假设当前求解 Q 1 ( Z ( 1 ) ) \mathcal Q_1(\mathcal Z^{(1)}) Q1(Z(1)),同时固定其余所有分布,我们会得到如下结果:
      Q 1 ^ ( Z ( 1 ) ) = arg ⁡ max ⁡ Q 1 ( Z ( 1 ) ) { − K L [ ϕ ^ ( X , Z ( 1 ) ) ∣ ∣ Q 1 ( Z ( 1 ) ) ] } \hat {\mathcal Q_1}(\mathcal Z^{(1)}) = \mathop{\arg\max}\limits_{\mathcal Q_1(\mathcal Z^{(1)})} \left\{- \mathcal K\mathcal L \left[ \hat \phi(\mathcal X,\mathcal Z^{(1)}) || \mathcal Q_1(\mathcal Z^{(1)})\right] \right\} Q1^(Z(1))=Q1(Z(1))argmax{KL[ϕ^(X,Z(1))∣∣Q1(Z(1))]}
      此时的 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)表示如下:
      将第一次迭代产生的最优解 Q 1 ^ ( Z ( 1 ) ) \hat {\mathcal Q_1}(\mathcal Z^{(1)}) Q1^(Z(1))带进 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)中。
      Q ( Z ) = J ( Q 1 ^ , Q 2 , ⋯   , Q M ) \mathcal Q(\mathcal Z) = \mathcal J(\hat {\mathcal Q_1},\mathcal Q_2,\cdots,\mathcal Q_{\mathcal M}) Q(Z)=J(Q1^,Q2,,QM)
    • 第一步的基础上,求解 Q ^ 2 ( Z ( 2 ) ) \hat {\mathcal Q}_2(\mathcal Z^{(2)}) Q^2(Z(2))
      Q 2 ^ ( Z ( 2 ) ) = arg ⁡ max ⁡ Q 2 ( Z ( 2 ) ) { − K L [ ϕ ^ ( X , Z ( 2 ) ) ∣ ∣ Q 2 ( Z ( 2 ) ) ] } \hat {\mathcal Q_2}(\mathcal Z^{(2)}) = \mathop{\arg\max}\limits_{\mathcal Q_2(\mathcal Z^{(2)})} \left\{- \mathcal K\mathcal L \left[ \hat \phi(\mathcal X,\mathcal Z^{(2)}) || \mathcal Q_2(\mathcal Z^{(2)})\right] \right\} Q2^(Z(2))=Q2(Z(2))argmax{KL[ϕ^(X,Z(2))∣∣Q2(Z(2))]}
      此时的 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)表示如下:
      同上~
      Q ( Z ) = J ( Q 1 ^ , Q 2 ^ , ⋯   , Q M ) \mathcal Q(\mathcal Z) = \mathcal J(\hat {\mathcal Q_1},\hat {\mathcal Q_2},\cdots,\mathcal Q_{\mathcal M}) Q(Z)=J(Q1^,Q2^,,QM)
    • 以此类推,直到固定最后一个子概率分布 Q M ( Z ( M ) ) \mathcal Q_{\mathcal M}(\mathcal Z^{(\mathcal M)}) QM(Z(M)),最终得到:
      Q ( Z ) = J ( Q 1 ^ , Q 2 ^ , ⋯   , Q ^ M ) \mathcal Q(\mathcal Z) = \mathcal J(\hat {\mathcal Q_1},\hat {\mathcal Q_2},\cdots,\hat {\mathcal Q}_{\mathcal M}) Q(Z)=J(Q1^,Q2^,,Q^M)
      此时,我们已经将 所有的子概率分布 全部求解一遍,并不是说此时的 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)就是最优分布,而是仅完整地执行了第一次迭代
      后续将继续从第一个 Q 1 ^ ( Z ( 1 ) ) \hat {\mathcal Q_1}(\mathcal Z^{(1)}) Q1^(Z(1))再次进行求解。这种方式就可以逐渐得到最优的 Q ( Z ) \mathcal Q(\mathcal Z) Q(Z)

    经典变分推断与广义EM

    基于平均场假设的变分推断与广义EM存在很多相似之处:

    • 它们的求解方式都是基于 K L \mathcal K\mathcal L KL散度的性质,将求解过程转换为如下形式
      arg ⁡ max ⁡ E L B O \mathop{\arg\max} ELBO argmaxELBO
    • 并且它们均转化为坐标上升法来解决最优解问题
      • 广义EM
        { Q ^ ( Z ) = arg ⁡ max ⁡ Q ( Z ) L [ Q ( Z ) , θ ] θ ^ = arg ⁡ max ⁡ θ L [ Q ( Z ) , θ ] L [ Q ( Z ) , θ ] = ∫ Z Q ( Z ) log ⁡ P ( X , Z ∣ θ ( t ) ) Q ( Z ) d Z {ˆQ(Z)=argmaxQ(Z)L[Q(Z),θ]ˆθ=argmaxθL[Q(Z),θ]
        \\ \mathcal L[\mathcal Q(\mathcal Z),\theta] = \int_{\mathcal Z} \mathcal Q(\mathcal Z) \log \frac{P(\mathcal X,\mathcal Z \mid \theta^{(t)})}{\mathcal Q(\mathcal Z)} d\mathcal Z
        Q^(Z)=Q(Z)argmaxL[Q(Z),θ]θ^=θargmaxL[Q(Z),θ]L[Q(Z),θ]=ZQ(Z)logQ(Z)P(X,Zθ(t))dZ
      • 基于平均场假设的变分推断
        Q j ^ ( Z ( j ) ) = arg ⁡ max ⁡ Q j ( Z ( j ) ) L [ Q ( Z ) ] = arg ⁡ max ⁡ Q j ( Z ( j ) ) { − K L [ ϕ ^ ( X , Z ( j ) ) ∣ ∣ Q j ( Z ( j ) ) ] } L [ Q ( Z ) ] = ∫ Z Q ( Z ) ⋅ log ⁡ [ P ( X , Z ) Q ( Z ) ] d Z Q ( Z ) = J ( Q 1 , ⋯   , Q ^ j , ⋯   , Q M ) ^Qj(Z(j))=argmaxQj(Z(j))L[Q(Z)]=argmaxQj(Z(j)){KL[ˆϕ(X,Z(j))||Qj(Z(j))]}
        \\ \mathcal L[\mathcal Q(\mathcal Z)] = \int_{\mathcal Z} \mathcal Q(\mathcal Z) \cdot \log \left[\frac{P(\mathcal X,\mathcal Z)}{\mathcal Q(\mathcal Z)}\right]d\mathcal Z \\ \mathcal Q(\mathcal Z) = \mathcal J(\mathcal Q_1,\cdots,\hat {\mathcal Q}_j,\cdots,\mathcal Q_{\mathcal M})
        Qj^(Z(j))=Qj(Z(j))argmaxL[Q(Z)]=Qj(Z(j))argmax{KL[ϕ^(X,Z(j))∣∣Qj(Z(j))]}L[Q(Z)]=ZQ(Z)log[Q(Z)P(X,Z)]dZQ(Z)=J(Q1,,Q^j,,QM)

    它们之间的核心区别更在于对于问题的理解角度不同

    • 广义EM算法的核心依然是 频率学派角度的求解逻辑——求解概率模型 P ( X ∣ θ ) P(\mathcal X \mid \theta) P(Xθ)中的最优参数 θ ^ \hat \theta θ^。它的底层逻辑依然是 极大似然估计(Maximum Likelihood Estimate,MLE);
    • 相比于广义EM算法基于平均场假设的变分推断的核心是 贝叶斯学派角度的求解逻辑:针对 P ( X ) P(\mathcal X) P(X)积分难的问题
      P ( X ) = ∫ Z P ( X , Z ) d Z = ∫ Z P ( X ∣ Z ) ⋅ P ( Z ) d Z P(X)=ZP(X,Z)dZ=ZP(XZ)P(Z)dZ
      P(X)=ZP(X,Z)dZ=ZP(XZ)P(Z)dZ

      通过对 P ( θ ∣ X ) P(\theta \mid \mathcal X) P(θX)采用近似手段,将关于参数的后验求解出来。换句话说,对于参数结果 θ \theta θ在贝叶斯学派角度中并不是不存在,而是贝叶斯学派角度并不关心 θ \theta θ的具体值,而是关心 θ \theta θ的后验分布
      因此,在整个变分推断的推导过程中,我们总是有意地弱化模型参数 θ \theta θ的作用,而更加关注后验概率本身。

    相关参考:
    机器学习-变分推断3(再回首)

  • 相关阅读:
    回顾.NET系列:Framework、Net Core、Net 过往
    Python 正则表达式:强大的文本处理工具
    Java8特性,Stream流的使用,收集成为map集合
    运算符、流程控制
    漏洞复现-phpmyadmin_SQL注入 (CVE-2020-5504)
    java计算机毕业设计酒店后厨供应商订单合并系统源码+数据库+lw文档+系统
    代码随想录1.5——数组:35搜索插入位置、34在排序数组中查找元素的第一个和最后一个位置、26.删除排序数组中的重复项、283移动零
    RabbitMQ支持的消息模型
    现在学编程还有出路吗?程序员的出路在哪里?
    JVM内存模型解析
  • 原文地址:https://blog.csdn.net/qq_34758157/article/details/126915662