NCE损失对应的论文为《A fast and simple algorithm for training neural probabilistic language models》,发表于2012年的ICML会议。
背景
在2012年,语言模型一般采用n-gram的方法,统计单词/上下文间的共现关系,比神经概率语言模型(neural probabilistic language models, NPLMs)效果好。
现在主流的语言模型都是神经概率语言模 型,核心思想是已知上下文 h h h,预测下一个词为 w w w的概率,通过一定的解码方法(例如greedy search、beam search等),对概率做解码,得到下一个词。Greedy search可以理解为选择概率最大的那个词。
2012年神经概率语言模型效果不好的原因是难训练。一方面自然是硬件的制约,那一年英伟达刚发布GTX680,和现在的A100、H100完全没法比。当时老黄不给力,学术界也没办法;另一方面是算法效率不行,难以进行大规模的分类学习,将”已知上下文 h h h,预测下一个词为 w i w_i wi的概率“建模成分类学习任务,目的在于把下一个词分类到词表中的某个词上。
举个例子,已知上下文是“我想去”,需要预测下一个词。词表中有4个词,即['北京','上海','天津','广州']
,需要把下一个词归类到词表的4个词里。如果词表有10万个词呢?训不动啊~
这就是当时面临的困境。NCE对分类算法做了优化,使得对大词表做分类任务成为可能。
原理
通俗的背景讲完了,接下来谈谈公式化的原理部分。
问题建模
已知上下文 h h h,预测下一次词为 w w w的概率为:
P θ h ( w ) = e x p ( s θ ( w , h ) ) ∑ w i e x p ( s θ ( w i , h ) ) (1) P_{\theta}^h(w)=\frac{exp(s_{\theta}(w,h))}{\sum_{w_i}{exp(s_{\theta}(w_i,h))}}\tag{1} Pθh(w)=∑wiexp(sθ(wi,h))exp(sθ(w,h))(1)
其中, s θ ( w , h ) s_{\theta}(w,h) sθ(w,h)表示已知上下文 h h h,下一个词为 w w w的预测得分; ∑ w i \sum_{w_i} ∑wi表示词表内的所有词。
一般情况下, s θ ( w , h ) s_{\theta}(w,h) sθ(w,h)通过对上下文 h h h表征以及词类别 w w w表征添加多个全连接层计算得到。最简单的策略,仅对上下文 h h h表征 f h f_h fh用一个全连接层 W W W做一次映射,再和词类别 w w w表征 f w i f_{w_i} fwi做点积即可。
s θ ( w , h ) = ( f h W ) ⋅ f w s_{\theta}(w,h)=(f_h W) \cdot f_{w} sθ(w,h)=(fhW)⋅fw
难度分析
对公式(1)进行分析,
分子部分 e x p ( s θ ( w , h ) ) exp(s_{\theta}(w,h)) exp(sθ(w,h))是好算的,针对单个 w w w,只需要计算一次。
分母部分KaTeX parse error: \tag works only in display equations不好算,针对单个 w w w,需要计算 e x p ( s θ ( w 1 , h ) ) , e x p ( s θ ( w 2 , h ) ) , . . . e x p ( s θ ( w n , h ) ) exp(s_{\theta}(w_1,h)), exp(s_{\theta}(w_2,h)), ...exp(s_{\theta}(w_n,h)) exp(sθ(w1,h)),exp(sθ(w2,h)),...exp(sθ(wn,h)),如果词表中词很多,计算量不小。
目前学术界、工业界对超大规模分类的优化基本上都聚焦在如何优化分母上,例如InfoNCE仅关注batch内的负类样本、KNN softmax对类别聚类,减少类别数目、partial FC对类别做采样以及显存均分来较少计算量、Inf-CL借助FlashAttention的思想,以空间换时间。
优化策略
既然对词表内n个词的大规模分类任务难做,难办,那就掀桌子不办了!!!
将原多分类任务转换成一个更容易实现的任务——新二分类任务。
除了有正常的真实数据之外,从一个噪声分布里采样噪声数据,对真实数据和噪声数据做二分类,可以证明:随着噪声数据越多,转换后任务的优化目标和转换前任务越接近。
新二分类任务
给定上下文 h h h后,现在有两个数据分布,一个是真实数据分布 P d h ( w ) P_d^h(w) Pdh(w)(实际应该写成 P d ( w ∣ h ) P_d(w|h) Pd(w∣h),简化形式写成 P d h ( w ) P_d^h(w) Pdh(w)),另一个是噪声数据分布 P n ( w ) P_n(w) Pn(w),真实数据和噪声数据的比例是1:k
。所以,训练数据的完整分布是 P h ( w ) = 1 k + 1 P d h ( w ) + k k + 1 P n ( w ) P^h(w)=\frac{1}{k+1}P_d^h(w)+\frac{k}{k+1}P_n(w) Ph(w)=k+11Pdh(w)+k+1kPn(w),训练任务是 D = 1 D=1 D=1(分辨真实数据)和 D = 0 D=0 D=0(分辨噪声数据)。
我们希望优化神经网络参数 θ \theta θ,来拟合真实数据分布 P d h ( w ) = P θ h ( w ) P_d^h(w)=P^h_{\theta}(w) Pdh(w)=Pθh(w),后者就是我们学到的数据分布 P θ h ( w ) P^h_{\theta}(w) Pθh(w),于是,训练数据的完整分布写成 P h ( w , θ ) = 1 k + 1 P θ h ( w ) + k k + 1 P n ( w ) P^h(w,\theta)=\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w) Ph(w,θ)=k+11Pθh(w)+k+1kPn(w)
训练目标一般是最大化后验概率 P h ( D ∣ w , θ ) P^h(D|w,\theta) Ph(D∣w,θ)的对数似然期望 E [ l o g ( P h ( D ∣ w , θ ) ) ] E \left[log(P^h(D|w,\theta))\right] E[log(Ph(D∣w,θ))],需要计算后验概率 P h ( D ∣ w , θ ) P^h(D|w,\theta) Ph(D∣w,θ)。
P h ( D ∣ w , θ ) = P h ( D = 1 ∣ w , θ ) + P h ( D = 0 ∣ w , θ ) (2) P^h(D|w,\theta)=P^h(D=1|w,\theta)+P^h(D=0|w,\theta)\tag{2} Ph(D∣w,θ)=Ph(D=1∣w,θ)+Ph(D=0∣w,θ)(2)
真实数据分布的后验概率为:
P h ( D = 1 ∣ w , θ ) = P h ( w , θ ∣ D = 1 ) P h ( w , θ ) P h ( D = 1 ) = P θ h ( w ) 1 k + 1 P θ h ( w ) + k k + 1 P n ( w ) 1 k + 1 = P θ h ( w ) P θ h ( w ) + k P n ( w ) (3) \begin{equation}\begin{aligned} P^h(D=1|w,\theta) &= \frac{P^h(w,\theta|D=1)}{P^h(w,\theta)}P^h(D=1) \\ &=\frac{P_{\theta}^h(w)}{\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)}\frac{1}{k+1} \\ &=\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)} \end{aligned} \end{equation} \tag{3} Ph(D=1∣w,θ)=Ph(w,θ)Ph(w,θ∣D=1)Ph(D=1)=k+11Pθh(w)+k+1kPn(w)Pθh(w)k+11=Pθh(w)+kPn(w)Pθh(w)(3)
我们来看看等式为什么成立
- 边缘概率 P h ( w , θ ) = 1 k + 1 P θ h ( w ) + k k + 1 P n ( w ) P^h(w,\theta)=\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w) Ph(w,θ)=k+11Pθh(w)+k+1kPn(w)
- 先验概率 P h ( D = 1 ) = 1 k + 1 P^h(D=1)=\frac{1}{k+1} Ph(D=1)=k+11,原因是真实数据和噪声数据的比例是
1:k
。 - 似然函数 P h ( w , θ ∣ D = 1 ) = P θ h ( w ) P^h(w,\theta|D=1)=P^h_{\theta}(w) Ph(w,θ∣D=1)=Pθh(w),表明在真实数据分布下,从词表里预测下一个词为 w w w的概率是 P θ h ( w ) P^h_{\theta}(w) Pθh(w),这就是我们想拟合的函数。
类似的,噪声数据分布的后验概率为:
P h ( D = 0 ∣ w , θ ) = P h ( w , θ ∣ D = 0 ) P h ( w , θ ) P h ( D = 0 ) = P n ( w ) 1 k + 1 P θ h ( w ) + k k + 1 P n ( w ) k k + 1 = k P n ( w ) P θ h ( w ) + k P n ( w ) (4) \begin{equation}\begin{aligned} P^h(D=0|w,\theta) &= \frac{P^h(w,\theta|D=0)}{P^h(w,\theta)}P^h(D=0) \\ &=\frac{P_n(w)}{\frac{1}{k+1}P^h_{\theta}(w)+\frac{k}{k+1}P_n(w)}\frac{k}{k+1} \\ &=\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)} \end{aligned} \end{equation} \tag{4} Ph(D=0∣w,θ)=Ph(w,θ)Ph(w,θ∣D=0)Ph(D=0)=k+11Pθh(w)+k+1kPn(w)Pn(w)k+1k=Pθh(w)+kPn(w)kPn(w)(4)
后验概率 P h ( D ∣ w i , θ ) P^h(D|w_i,\theta) Ph(D∣wi,θ)的对数似然的期望 E [ l o g ( P h ( D ∣ w i , θ ) ) ] E \left[log(P^h(D|w_i,\theta))\right] E[log(Ph(D∣wi,θ))]为
J h ( θ ) = E [ l o g ( P h ( D ∣ w , θ ) ) ] = E P d h [ l o g P h ( D = 1 ∣ w , θ ) ] + E P n [ l o g P h ( D = 0 ∣ w , θ ) ] = E P d h [ l o g P θ h ( w ) P θ h ( w ) + k P n ( w ) ] + E P n [ l o g k P n ( w ) P θ h ( w ) + k P n ( w ) ] (5) \begin{equation}\begin{aligned} J^h(\theta)&=E \left[log(P^h(D|w,\theta))\right] \\ &= E_{P_d^h}\left[logP^h(D=1|w,\theta)\right] +E_{P_n}\left[logP^h(D=0|w,\theta)\right] \\ &= E_{P_d^h}\left[log\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\right] +E_{P_n}\left[log\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\right] \\ \end{aligned} \end{equation} \tag{5} Jh(θ)=E[log(Ph(D∣w,θ))]=EPdh[logPh(D=1∣w,θ)]+EPn[logPh(D=0∣w,θ)]=EPdh[logPθh(w)+kPn(w)Pθh(w)]+EPn[logPθh(w)+kPn(w)kPn(w)](5)
我们来算一下梯度,等于
∂ ∂ θ J h ( θ ) = E P d h [ k P n ( w ) P θ h ( w ) + k P n ( w ) ∂ ∂ θ l o g P θ h ( w ) ] − k E P n [ P θ h ( w ) P θ h ( w ) + k P n ( w ) ∂ ∂ θ l o g P θ h ( w ) ] (6) \begin{equation} \begin{aligned} \frac{\partial}{\partial{\theta}}{J^h(\theta)}&= E_{P_d^h}\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)\right] -\\&kE_{P_n}\left[\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)\right] \end{aligned} \end{equation} \tag{6} ∂θ∂Jh(θ)=EPdh[Pθh(w)+kPn(w)kPn(w)∂θ∂logPθh(w)]−kEPn[Pθh(w)+kPn(w)Pθh(w)∂θ∂logPθh(w)](6)
对(6)式做化简,有
∂ ∂ θ J h ( θ ) = E P d h [ k P n ( w ) P θ h ( w ) + k P n ( w i ) ∂ ∂ θ l o g P θ h ( w ) ] − k E P n [ P θ h ( w ) P θ h ( w ) + k P n ( w ) ∂ ∂ θ l o g P θ h ( w ) ] = ∑ w [ P d h ⋅ k P n ( w ) P θ h ( w ) + k P n ( w ) ∂ ∂ θ l o g P θ h ( w ) − k P n ⋅ P θ h ( w ) P θ h ( w ) + k P n ( w ) ∂ ∂ θ l o g P θ h ( w ) ] = ∑ w [ k P n ( w ) P θ h ( w ) + k P n ( w ) × ( P d h ( w ) − P θ h ( w ) ) ∂ ∂ θ l o g P θ h ( w ) ] (7) \begin{equation} \begin{aligned} \frac{\partial}{\partial{\theta}}{J^h(\theta)}&= E_{P_d^h}\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w_i)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)\right] -\\&kE_{P_n}\left[\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)\right]\\ &=\sum_w\left[P_d^h\cdot\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w)-\right.\\ &\left. kP_{n}\cdot\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\frac{\partial}{\partial\theta}logP_{\theta}^h(w) \right]\\ &=\sum_w\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\times\right.\\ &\left. (P_d^h(w)-P_{\theta}^h(w))\frac{\partial}{\partial\theta}logP_{\theta}^h(w) \right] \end{aligned} \end{equation} \tag{7} ∂θ∂Jh(θ)=EPdh[Pθh(w)+kPn(wi)kPn(w)∂θ∂logPθh(w)]−kEPn[Pθh(w)+kPn(w)Pθh(w)∂θ∂logPθh(w)]=w∑[Pdh⋅Pθh(w)+kPn(w)kPn(w)∂θ∂logPθh(w)−kPn⋅Pθh(w)+kPn(w)Pθh(w)∂θ∂logPθh(w)]=w∑[Pθh(w)+kPn(w)kPn(w)×(Pdh(w)−Pθh(w))∂θ∂logPθh(w)](7)
当噪声数据量级巨大, k → ∞ k\to \infty k→∞ , k P n ( w ) P θ h ( w ) + k P n ( w ) → 1 \frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\to1 Pθh(w)+kPn(w)kPn(w)→1 ,有
∂ ∂ θ J h ( θ ) = ∑ w [ k P n ( w ) P θ h ( w ) + k P n ( w ) × ( P d h ( w ) − P θ h ( w ) ) ∂ ∂ θ l o g P θ h ( w ) ] → ∑ w [ ( P d h ( w ) − P θ h ( w ) ) ∂ ∂ θ l o g P θ h ( w ) ] (8) \begin{equation} \begin{aligned} \frac{\partial}{\partial{\theta}}{J^h(\theta)}&= \sum_w\left[\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\times\right.\\ &\left. (P_d^h(w)-P_{\theta}^h(w))\frac{\partial}{\partial\theta}logP_{\theta}^h(w) \right]\\ &\to \sum_w\left[(P_d^h(w)-P_{\theta}^h(w))\frac{\partial}{\partial\theta}logP_{\theta}^h(w) \right] \end{aligned} \end{equation} \tag{8} ∂θ∂Jh(θ)=w∑[Pθh(w)+kPn(w)kPn(w)×(Pdh(w)−Pθh(w))∂θ∂logPθh(w)]→w∑[(Pdh(w)−Pθh(w))∂θ∂logPθh(w)](8)
原多分类任务
我们计算下原多分类任务的对数似然期望和梯度,看看 k → ∞ k\to \infty k→∞ 时的新二分类任务和原多分类任务有什么关系。原多分类任务的优化目标为
J h ( θ ) = E P d h [ l o g ( P θ h ( w ) ] = E P d h [ l o g ( e x p ( s θ ( w , h ) ) ∑ w e x p ( s θ ( w , h ) ) ) ] = E P d h [ s θ ( w , h ) ] − E P d h [ l o g ( ∑ w e x p ( s θ ( w , h ) ) ) ] = E P d h [ s θ ( w , h ) ] − l o g ( ∑ w e x p ( s θ ( w , h ) ) ) (9) \begin{equation}\begin{aligned} J^h(\theta)&=E_{P_d^h} \left[log(P_{\theta}^h(w)\right] \\ &= E_{P_d^h} \left[log\left(\frac{exp(s_{\theta}(w,h))}{\sum_w{exp(s_{\theta}(w,h))}}\right)\right]\\ &=E_{P_d^h}\left[s_{\theta}(w,h)\right]-E_{P_d^h}\left[log\left(\sum_w{exp\left(s_{\theta}(w,h)\right)}\right)\right]\\ &=E_{P_d^h}\left[s_{\theta}(w,h)\right]-log\left(\sum_w{exp\left(s_{\theta}(w,h)\right)}\right) \end{aligned} \end{equation} \tag{9} Jh(θ)=EPdh[log(Pθh(w)]=EPdh[log(∑wexp(sθ(w,h))exp(sθ(w,h)))]=EPdh[sθ(w,h)]−EPdh[log(w∑exp(sθ(w,h)))]=EPdh[sθ(w,h)]−log(w∑exp(sθ(w,h)))(9)
等式最后一步成立的原因是 [ l o g ( ∑ w e x p ( s θ ( w , h ) ) ) ] \left[log\left(\sum_w{exp\left(s_{\theta}(w,h)\right)}\right)\right] [log(∑wexp(sθ(w,h)))]仅和模型预测分布 P θ h P_{\theta}^h Pθh有关,和真实数据分布 P d h P_d^h Pdh无关。
对(9)式求梯度,有 ∂ ∂ θ J h ( θ ) = E P d h [ ∂ ∂ θ s θ ( w , h ) ] − ∂ ∂ θ l o g ( ∑ w e x p ( s θ ( w , h ) ) ) = E P d h [ ∂ ∂ θ s θ ( w , h ) ] − 1 ∑ w e x p ( s θ ( w , h ) ) ∂ ∂ θ ∑ w e x p ( s θ ( w , h ) ) = E P d h [ ∂ ∂ θ s θ ( w , h ) ] − 1 ∑ w e x p ( s θ ( w , h ) ) ∑ w ( s θ ( w , h ) ∂ ∂ θ s θ ( w , h ) ) = E P d h [ ∂ ∂ θ s θ ( w , h ) ] − ∑ w s θ ( w , h ) ∑ w e x p ( s θ ( w , h ) ) ∂ ∂ θ s θ ( w , h ) = E P d h [ ∂ ∂ θ s θ ( w , h ) ] − ∑ w P θ h ( w ) ∂ ∂ θ s θ ( w , h ) = E P d h [ ∂ ∂ θ s θ ( w , h ) ] − ∑ w P θ h ( w ) ∂ ∂ θ s θ ( w , h ) = ∑ w P d h ∂ ∂ θ s θ ( w , h ) − ∑ w P θ h ( w ) ∂ ∂ θ s θ ( w , h ) = ∑ w ( P d h ( w ) − P θ h ( w ) ) ∂ ∂ θ s θ ( w , h ) (10) \begin{equation}\begin{aligned} \frac{\partial}{\partial\theta}J^h(\theta)&=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\frac{\partial}{\partial\theta}log\left(\sum_w{exp\left(s_{\theta}(w,h)\right)}\right)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\frac{1}{\sum_w{exp\left(s_{\theta}(w,h)\right)}}\frac{\partial}{\partial\theta}\sum_w{exp\left(s_{\theta}(w,h)\right)}\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\frac{1}{\sum_w{exp\left(s_{\theta}(w,h)\right)}}\sum_w\left(s_{\theta}(w,h)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\sum_w\frac{s_{\theta}(w,h)}{\sum_w{exp\left(s_{\theta}(w,h)\right)}}\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\sum_wP_{\theta}^h(w)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=E_{P_d^h}\left[\frac{\partial}{\partial\theta}s_{\theta}(w,h)\right]-\sum_wP_{\theta}^h(w)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=\sum_wP_d^h\frac{\partial}{\partial\theta}s_{\theta}(w,h)-\sum_wP_{\theta}^h(w)\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ &=\sum_w(P_d^h(w)-P_{\theta}^h(w))\frac{\partial}{\partial\theta}s_{\theta}(w,h)\\ \end{aligned} \end{equation} \tag{10} ∂θ∂Jh(θ)=EPdh[∂θ∂sθ(w,h)]−∂θ∂log(w∑exp(sθ(w,h)))=EPdh[∂θ∂sθ(w,h)]−∑wexp(sθ(w,h))1∂θ∂w∑exp(sθ(w,h))=EPdh[∂θ∂sθ(w,h)]−∑wexp(sθ(w,h))1w∑(sθ(w,h)∂θ∂sθ(w,h))=EPdh[∂θ∂sθ(w,h)]−w∑∑wexp(sθ(w,h))sθ(w,h)∂θ∂sθ(w,h)=EPdh[∂θ∂sθ(w,h)]−w∑Pθh(w)∂θ∂sθ(w,h)=EPdh[∂θ∂sθ(w,h)]−w∑Pθh(w)∂θ∂sθ(w,h)=w∑Pdh∂θ∂sθ(w,h)−w∑Pθh(w)∂θ∂sθ(w,h)=w∑(Pdh(w)−Pθh(w))∂θ∂sθ(w,h)(10)对比公式(8)和公式(10),很像,但不一样。公式(8)最后是 ∂ ∂ θ l o g P θ h ( w ) \frac{\partial}{\partial\theta}logP_{\theta}^h(w) ∂θ∂logPθh(w),公式(10)最后是 ∂ ∂ θ s θ ( w , h ) \frac{\partial}{\partial\theta}s_{\theta}(w,h) ∂θ∂sθ(w,h),咋回事?
不一样就对了,在NCE中,我们可以将 ∑ w e x p ( s θ ( w , h ) ) \sum_w{exp\left(s_{\theta}(w,h)\right)} ∑wexp(sθ(w,h))等价成1,那公式(8)和公式(10)就一样了。那为什么可以等价呢?论文的说辞是: 模型参数较多,把正则项当做常数,公式中其他项,比如 s θ ,能学到正则项。 \textcolor{red}{{模型参数较多,把正则项当做常数,公式中其他项,比如s_{\theta},能学到正则项。}} 模型参数较多,把正则项当做常数,公式中其他项,比如sθ,能学到正则项。(正则项可以理解为 ∑ w e x p ( s θ ( w , h ) ) \sum_w{exp\left(s_{\theta}(w,h)\right)} ∑wexp(sθ(w,h))),那么 ∑ w e x p ( s θ ( w , h ) ) \sum_w{exp\left(s_{\theta}(w,h)\right)} ∑wexp(sθ(w,h))是1也好,100也好,都不会对模型收敛有影响。简单起见,当做1就行。
这段说辞还是太抽象了,有没有形象一点的解释?
两个任务为什么可以等价
原多分类任务
J h ( θ ) = E P d h [ l o g ( P θ h ( w ) ] = E P d h [ l o g ( e x p ( s θ ( w , h ) ) ∑ w e x p ( s θ ( w , h ) ) ) ] (11) \begin{equation}\begin{aligned} J^h(\theta)&=E_{P_d^h} \left[log(P_{\theta}^h(w)\right] \\ &= E_{P_d^h} \left[log\left(\frac{exp(s_{\theta}(w,h))}{\sum_w{exp(s_{\theta}(w,h))}}\right)\right] \end{aligned} \end{equation} \tag{11} Jh(θ)=EPdh[log(Pθh(w)]=EPdh[log(∑wexp(sθ(w,h))exp(sθ(w,h)))](11)
该任务的对数似然期望见公式(11), l o g log log函数曲线如下:
如果 l o g ( P θ h ( w ) = e x p ( s θ ( w , h ) ) ∈ [ 0 , + ∞ ] log(P_{\theta}^h(w)=exp(s_{\theta}(w,h))\in[0,+\infty] log(Pθh(w)=exp(sθ(w,h))∈[0,+∞], J h ( θ ) = E P d h [ l o g ( P θ h ( w ) ] J^h(\theta)=E_{P_d^h} \left[log(P_{\theta}^h(w)\right] Jh(θ)=EPdh[log(Pθh(w)]不存在极值,无法收敛。
如果对 l o g ( P θ h ( w ) = e x p ( s θ ( w , h ) ) ∈ [ 0 , + ∞ ] log(P_{\theta}^h(w)=exp(s_{\theta}(w,h))\in[0,+\infty] log(Pθh(w)=exp(sθ(w,h))∈[0,+∞]进行归一化, l o g ( P θ h ( w ) = [ l o g ( e x p ( s θ ( w , h ) ) ∑ w e x p ( s θ ( w , h ) ) ) ] ∈ ( 0 , 1 ) log(P_{\theta}^h(w)=\left[log\left(\frac{exp(s_{\theta}(w,h))}{\sum_w{exp(s_{\theta}(w,h))}}\right)\right]\in(0,1) log(Pθh(w)=[log(∑wexp(sθ(w,h))exp(sθ(w,h)))]∈(0,1), J h ( θ ) = E P d h [ l o g ( P θ h ( w ) ] J^h(\theta)=E_{P_d^h} \left[log(P_{\theta}^h(w)\right] Jh(θ)=EPdh[log(Pθh(w)]存在极值,具备收敛条件。
现二分类任务
从公式(5)可知,
J h ( θ ) = E [ l o g ( P h ( D ∣ w , θ ) ) ] = E P d h [ l o g P h ( D = 1 ∣ w , θ ) ] + E P n [ l o g P h ( D = 0 ∣ w , θ ) ] = E P d h [ l o g P θ h ( w ) P θ h ( w ) + k P n ( w ) ] + E P n [ l o g k P n ( w ) P θ h ( w ) + k P n ( w ) ] = E P d h [ l o g ( σ ( Δ ) ) ] + k E P n [ l o g ( 1 − σ ( Δ ) ) ] \begin{equation}\begin{aligned} J^h(\theta)&=E \left[log(P^h(D|w,\theta))\right] \\ &= E_{P_d^h}\left[logP^h(D=1|w,\theta)\right] +E_{P_n}\left[logP^h(D=0|w,\theta)\right] \\ &= E_{P_d^h}\left[log\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\right] +E_{P_n}\left[log\frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}\right] \\ &= E_{P_d^h}\left[log(\sigma({\Delta}))\right] +kE_{P_n}\left[log(1-\sigma({\Delta}))\right] \\ \end{aligned} \tag{12}\end{equation} Jh(θ)=E[log(Ph(D∣w,θ))]=EPdh[logPh(D=1∣w,θ)]+EPn[logPh(D=0∣w,θ)]=EPdh[logPθh(w)+kPn(w)Pθh(w)]+EPn[logPθh(w)+kPn(w)kPn(w)]=EPdh[log(σ(Δ))]+kEPn[log(1−σ(Δ))](12)
,其中 Δ = l o g P θ h ( w ) − l o g k P n ( w ) \Delta=logP_{\theta}^h(w)-logkP_n(w) Δ=logPθh(w)−logkPn(w),将公式(5)推导成具备 σ \sigma σ的公式(12),原因在于求导方便, ∂ ∂ x σ ( x ) = σ ( x ) ( 1 − σ ( x ) ) \frac{\partial}{\partial x}\sigma(x)=\sigma(x)(1-\sigma(x)) ∂x∂σ(x)=σ(x)(1−σ(x)),将公式(5)推导成公式(12)的过程是:
P θ h ( w ) P θ h ( w ) + k P n ( w ) = 1 1 + k P n ( w ) P θ h ( w ) = 1 1 + e x p ( l o g ( k P n ( w ) P θ h ( w ) ) ) = 1 1 + e x p ( l o g k P n ( w ) − l o g P θ h ( w ) ) = 1 1 + e x p ( − ( l o g P θ h ( w ) − l o g k P n ( w ) ) ) = σ ( l o g P θ h ( w ) − l o g k P n ( w ) ) \begin{equation}\begin{aligned} \frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}&=\frac{1}{1+\frac{kP_n(w)}{P_{\theta}^h(w)}}\\ &=\frac{1}{1+exp(log(\frac{kP_n(w)}{P_{\theta}^h(w)}))}\\ &=\frac{1}{1+exp(logkP_n(w)-logP_{\theta}^h(w))}\\ &=\frac{1}{1+exp(-(logP_{\theta}^h(w)-logkP_n(w)))}\\ &=\sigma(logP_{\theta}^h(w)-logkP_n(w))\\ \end{aligned} \tag{12}\end{equation} Pθh(w)+kPn(w)Pθh(w)=1+Pθh(w)kPn(w)1=1+exp(log(Pθh(w)kPn(w)))1=1+exp(logkPn(w)−logPθh(w))1=1+exp(−(logPθh(w)−logkPn(w)))1=σ(logPθh(w)−logkPn(w))(12)
k P n ( w ) P θ h ( w ) + k P n ( w ) = 1 − P θ h ( w ) P θ h ( w ) + k P n ( w ) = 1 − σ ( l o g P θ h ( w ) − l o g k P n ( w ) ) \begin{equation}\begin{aligned} \frac{kP_n(w)}{P_{\theta}^h(w)+kP_n(w)}&=1-\frac{P_{\theta}^h(w)}{P_{\theta}^h(w)+kP_n(w)}\\ &=1-\sigma(logP_{\theta}^h(w)-logkP_n(w))\\ \end{aligned} \tag{13}\end{equation} Pθh(w)+kPn(w)kPn(w)=1−Pθh(w)+kPn(w)Pθh(w)=1−σ(logPθh(w)−logkPn(w))(13)
于是,计算对数似然均值(公式(12))对 l o g P θ h ( w ) logP_{\theta}^h(w) logPθh(w)的一阶导,有
∂ J h ( θ ) ∂ l o g P θ h ( w ) = ∂ J h ( θ ) ∂ Δ ∂ Δ ∂ l o g P θ h ( w ) = ∂ J h ( θ ) ∂ Δ = ∂ ∂ Δ { E P d h [ l o g ( σ ( Δ ) ) ] + k E P n [ l o g ( 1 − σ ( Δ ) ) ] } = E P d h [ ∂ ∂ Δ l o g ( σ ( Δ ) ) ] + k E P n [ ∂ ∂ Δ l o g ( 1 − σ ( Δ ) ) ] = E P d h [ 1 − σ ( Δ ) ] + k E P n [ − σ ( Δ ) ] = ∑ w P θ h ( w ) ( 1 − σ ( Δ ) ) − k P n ( w ) σ ( Δ ) \begin{equation}\begin{aligned} \frac{\partial J^h(\theta)}{\partial logP_{\theta}^h(w)} &=\frac{\partial J^h(\theta)}{\partial \Delta}\frac{\partial \Delta}{\partial logP_{\theta}^h(w)}\\ &=\frac{\partial J^h(\theta)}{\partial \Delta}\\ &=\frac{\partial }{\partial \Delta}\left\{E_{P_d^h}\left[log(\sigma({\Delta}))\right] +kE_{P_n}\left[log(1-\sigma({\Delta}))\right]\right\}\\ &=E_{P_d^h}\left[\frac{\partial }{\partial \Delta}log(\sigma({\Delta}))\right] +kE_{P_n}\left[\frac{\partial }{\partial \Delta}log(1-\sigma({\Delta}))\right]\\ &=E_{P_d^h}\left[1-\sigma({\Delta})\right] +kE_{P_n}\left[-\sigma({\Delta})\right]\\ &=\sum_wP_{\theta}^h(w)(1-\sigma({\Delta}))-kP_n(w)\sigma({\Delta})\\ \end{aligned} \tag{14}\end{equation} ∂logPθh(w)∂Jh(θ)=∂Δ∂Jh(θ)∂logPθh(w)∂Δ=∂Δ∂Jh(θ)=∂Δ∂{EPdh[log(σ(Δ))]+kEPn[log(1−σ(Δ))]}=EPdh[∂Δ∂log(σ(Δ))]+kEPn[∂Δ∂log(1−σ(Δ))]=EPdh[1−σ(Δ)]+kEPn[−σ(Δ)]=w∑Pθh(w)(1−σ(Δ))−kPn(w)σ(Δ)(14)
如果 P θ h ( w ) = P d h ( w ) P_{\theta}^h(w)=P_d^h(w) Pθh(w)=Pdh(w),对数似然均值达到极大值(这个是废话,因为训练目标就是希望 P θ h ( w ) → P d h ( w ) P_{\theta}^h(w)\to P_d^h(w) Pθh(w)→Pdh(w),并且在优化策略章节开始部分,我们就让 P θ h ( w ) = P d h ( w ) P_{\theta}^h(w)= P_d^h(w) Pθh(w)=Pdh(w))其中 P d h ( w ) P_d^h(w) Pdh(w)表示真实分布。
我们再计算对数似然均值(公式(12))对 l o g P θ h ( w ) logP_{\theta}^h(w) logPθh(w)的二阶导,有:
∂ 2 J h ( θ ) ∂ l o g 2 P θ h ( w ) = ∂ 2 J ( θ ) ∂ Δ 2 = ∂ ∂ Δ { E P d h [ 1 − σ ( Δ ) ] + k E P n [ − σ ( Δ ) ] } = E P d h ∂ ∂ Δ [ 1 − σ ( Δ ) ] + k E P n ∂ ∂ Δ [ − σ ( Δ ) ] = E P d h [ − σ ( Δ ) ( 1 − σ ( Δ ) ) ] + k E P n [ − σ ( Δ ) ( 1 − σ ( Δ ) ) ] \begin{equation}\begin{aligned} \frac{\partial^2 J^h(\theta)}{\partial log^2P_{\theta}^h(w)} &=\frac{\partial^2J(\theta)}{\partial \Delta^2}\\ &=\frac{\partial}{\partial \Delta} \left\{E_{P_d^h}\left[1-\sigma({\Delta})\right] +kE_{P_n}\left[-\sigma({\Delta})\right] \right\} \\ &= E_{P_d^h}\frac{\partial}{\partial \Delta}\left[1- \sigma({\Delta})\right] +kE_{P_n}\frac{\partial}{\partial \Delta}\left[-\sigma({\Delta})\right] \\ &= E_{P_d^h}[-\sigma(\Delta)(1-\sigma(\Delta))] +kE_{P_n}[-\sigma(\Delta)(1-\sigma(\Delta))] \\ \end{aligned} \tag{14}\end{equation} ∂log2Pθh(w)∂2Jh(θ)=∂Δ2∂2J(θ)=∂Δ∂{EPdh[1−σ(Δ)]+kEPn[−σ(Δ)]}=EPdh∂Δ∂[1−σ(Δ)]+kEPn∂Δ∂[−σ(Δ)]=EPdh[−σ(Δ)(1−σ(Δ))]+kEPn[−σ(Δ)(1−σ(Δ))](14)
因为 [ − σ ( Δ ) ( 1 − σ ( Δ ) ) ] [-\sigma(\Delta)(1-\sigma(\Delta))] [−σ(Δ)(1−σ(Δ))]始终小于0,所以二阶导始终小于0,说明新二分类任务的对数似然均值是关于 l o g P θ h ( w ) logP_{\theta}^h(w) logPθh(w)的凸函数,有唯一极大值。所以极大值一定是 P θ h ( w ) = P h ( w ) P_{\theta}^h(w)=P^h(w) Pθh(w)=Ph(w)。
最重要的是,整个推导过程对是否需要归一化没有要求,既然没有要求,直接让 ∑ w e x p ( s θ ( w , h ) ) = 1 \sum_w{exp\left(s_{\theta}(w,h)\right)}=1 ∑wexp(sθ(w,h))=1
代码实现
从公式(12),我们可以知道: Δ = l o g P θ h ( w ) − l o g k P n ( w ) \Delta=logP_{\theta}^h(w)-logkP_n(w) Δ=logPθh(w)−logkPn(w)
J h ( θ ) = E [ l o g ( P h ( D ∣ w , θ ) ) ] = E P d h [ l o g σ ( Δ ) ] + k E P n [ l o g ( 1 − σ ( Δ ) ) ] = E P d h [ l o g σ ( l o g P θ h ( w ) − l o g k P n ( w ) ) ] + k E P n [ l o g ( 1 − σ ( l o g P θ h ( w ) − l o g k P n ( w ) ) ) ] = ∑ w { P d h [ l o g σ ( l o g P θ h ( w ) − l o g k P n ( w ) ) ] } + k ∑ w { P n [ l o g ( 1 − σ ( l o g P θ h ( w ) − l o g k P n ( w ) ) ) ] } → l o g ( σ ( l o g P θ h ( w 0 ) − l o g k P n ( w 0 ) ) + ∑ i = 1 k [ l o g ( 1 − σ ( l o g P θ h ( w i ) − l o g k P n ( w i ) ) ) ] = l o g ( σ ( s θ ( w 0 , h ) − l o g k P n ( w 0 ) ) + ∑ i = 1 k [ l o g ( 1 − σ ( s θ ( w i , h ) − l o g k P n ( w i ) ) ) ] \begin{equation}\begin{aligned} J^h(\theta)&=E \left[log(P^h(D|w,\theta))\right] \\ &= E_{P_d^h}\left[log\sigma({\Delta})\right] +kE_{P_n}\left[log(1-\sigma({\Delta}))\right] \\ &= E_{P_d^h}\left[log\sigma(logP_{\theta}^h(w)-logkP_n(w))\right] +\\ &\quad\quad\quad\quad\quad\quad kE_{P_n}\left[log(1-\sigma(logP_{\theta}^h(w)-logkP_n(w)))\right] \\ &= \sum_w\left\{P_d^h\left[log\sigma(logP_{\theta}^h(w)-logkP_n(w))\right] \right\}+\\ &\quad\quad\quad\quad\quad\quad k\sum_w\left\{P_n\left[log(1-\sigma(logP_{\theta}^h(w)-logkP_n(w)))\right]\right\} \\ &\to log(\sigma(logP_{\theta}^h(w_0)-logkP_n(w_0)) +\\ &\quad\quad\quad\quad\quad\quad\sum_{i=1}^k\left[log(1-\sigma(logP_{\theta}^h(w_i)-logkP_n(w_i)))\right] \\ &=log(\sigma(s_{\theta}(w_0,h)-logkP_n(w_0)) +\\ &\quad\quad\quad\quad\quad\quad\sum_{i=1}^k\left[log(1-\sigma(s_{\theta}(w_i,h)-logkP_n(w_i)))\right] \\ \end{aligned} \tag{15}\end{equation} Jh(θ)=E[log(Ph(D∣w,θ))]=EPdh[logσ(Δ)]+kEPn[log(1−σ(Δ))]=EPdh[logσ(logPθh(w)−logkPn(w))]+kEPn[log(1−σ(logPθh(w)−logkPn(w)))]=w∑{Pdh[logσ(logPθh(w)−logkPn(w))]}+kw∑{Pn[log(1−σ(logPθh(w)−logkPn(w)))]}→log(σ(logPθh(w0)−logkPn(w0))+i=1∑k[log(1−σ(logPθh(wi)−logkPn(wi)))]=log(σ(sθ(w0,h)−logkPn(w0))+i=1∑k[log(1−σ(sθ(wi,h)−logkPn(wi)))](15)
具体实现时,正样本项仅考虑目标class,负样本项随机选择k个样本,通过蒙特卡洛来模拟抽样。
那最终损失函数代码应该怎么写呢?
l o s s = − J h ( θ ) = − l o g ( σ ( s θ ( w 0 , h ) − l o g k P n ( w 0 ) ) ) − ∑ i = 1 k [ l o g ( 1 − σ ( s θ ( w i , h ) − l o g k P n ( w i ) ) ) ] \begin{equation}\begin{aligned} loss &= -J^h(\theta) \\ &=-log(\sigma(s_{\theta}(w_0,h)-logkP_n(w_0))) - \\ &\quad\quad\quad\quad\quad\quad\sum_{i=1}^k\left[log(1-\sigma(s_{\theta}(w_i,h)-logkP_n(w_i)))\right] \\ \end{aligned} \tag{16}\end{equation} loss=−Jh(θ)=−log(σ(sθ(w0,h)−logkPn(w0)))−i=1∑k[log(1−σ(sθ(wi,h)−logkPn(wi)))](16)
公式(16)中有四个项输入,分别是
- s θ ( w 0 , h ) s_{\theta}(w_0,h) sθ(w0,h),目标class的logit
- P n ( w 0 ) P_n(w_0) Pn(w0),目标class的噪声分布
- s θ ( w i , h ) s_{\theta}(w_i,h) sθ(wi,h),噪声class的logit
- P n ( w i ) P_n(w_i) Pn(wi),噪声class的噪声分布
from torch import randn, tensor, log, multinomial
import torch.nn.functional as F
from einops import repeat
import torch
import mathbs,k=2,8
num_classes=16#构造噪声:按照类别的频率采样
#(噪声分布约等于实际数据分布,两个分布越接近,nce效果越好)
classes=[0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
class_freq=tensor([20,10,30,5,45,56,76,43,23,11,34,5,6,54,23,7])
class_probs=class_freq/class_freq.sum()
noise_classes=multinomial(class_probs, num_classes)#模型预测的logits
logits=randn(bs, num_classes)
#2个样本的标签
labels=tensor([2, 4])#目标class的logit
true_class_logits=logits.take_along_dim(labels[:, None], dim=1)#目标class的噪声分布
true_class_noise=class_probs[labels]
#噪声class的logit
logits_k = repeat(logits, '(b 1) h -> (b k) h', k=k)
noise_class_logits = logits_k.take_along_dim(noise_classes.reshape(bs * k, -1), dim=1)
#噪声class的噪声分布
noise_class_noise=class_probs[noise_classes]#nce loss计算
true_class_loss = -torch.log( F.sigmoid(true_class_logits - torch.log(k*true_class_noise))).mean()
noise_class_loss = -torch.log( 1-F.sigmoid(noise_class_logits - torch.log(k*noise_class_noise))).mean()loss = true_class_loss+noise_class_loss
print("nce loss is {:.4f}".format(loss))