RBMの学習アルゴリズムと損失関数について

はじめに

前回の記事でRBMを実装した。

tam5917.hatenablog.com

この実装では、RBMの損失関数は

Contrastive Divergence (CD) アルゴリズムを考慮すると,損失関数はvisible層v0から計算される自由エネルギー(F0)と、再構成されたvisible層vkから計算される自由エネルギー(Fk)との差(F0 - Fk)で計算すればよい。

としていた(上記記事より引用)。

本記事はRBMの学習アルゴリズムを先日の実装に即した形で説明しつつ、その損失関数の説明を補うものである。RBM自体の解説は [1] が、実装のTIPSは [2] が大いに参考となる。

RBMのモデル構造と確率分布

まずRBMのモデル構造と確率分布を復習する。ただし本記事では Bernoulli-Bernoulli 型のモデルのみを扱う。

RBMは可視層  \mathbf{v} と隠れ層  \mathbf{h} を持つ。可視層と隠れ層の各状態は 0 または 1 の値を取る確率変数とする。これらに対して、RBMはエネルギーを定義する関数  E(以降、エネルギー関数)を持つ。


\begin{align*}
E(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\theta}) \triangleq - (\mathbf{b}^\top \mathbf{v} + \mathbf{c}^\top \mathbf{h} + \mathbf{h}^\top \mathbf{W} \mathbf{v})
\end{align*}

ここで、 \mathbf{b} は可視層のバイアス、 \mathbf{c} は隠れ層のバイアス、 \mathbf{W} は可視層と隠れ層間の重み行列である。モデルパラメタを  \boldsymbol{\theta} = \left\{\mathbf{b}, \mathbf{c}, \mathbf{W}\right\} とした。可視層と隠れ層の状態の同時分布  P(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\theta}) は、エネルギー関数を用いて以下のように定義される。


\begin{align*}
P(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\theta}) \triangleq \frac{\exp(-E(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\theta}))} {Z (\boldsymbol{\mathbf{\theta}})}
\end{align*}

ここで、 Z(\boldsymbol{\mathbf{\theta}}) は分配関数と呼ばれ、すべての可能な  (\mathbf{v}, \mathbf{h}) の組み合わせについて  \exp(-E(\mathbf{v}, \mathbf{h}) \mid \boldsymbol{\theta}) の総和をとったものである:


\begin{align*}
Z (\boldsymbol{\mathbf{\theta}}) \triangleq \sum_{\mathbf{v}} \sum_{\mathbf{h}} \exp(-E(\mathbf{v}, \mathbf{h}  \mid \boldsymbol{\mathbf{\theta}}))
\end{align*}

なおRBMにおいては、一方の層における各変数の値を所与としたときに、他方の層における各変数が条件付き独立となる性質を持つ。この性質は学習アルゴリズムの実現に重要な役割を果たす。

周辺分布と自由エネルギー

可視層の周辺分布  P(\mathbf{v} \mid \boldsymbol{\theta}) は、隠れ層の状態について周辺化することで得られる。


\begin{align*}
P(\mathbf{v} \mid \boldsymbol{\theta}) = \sum_{\mathbf{h}} P(\mathbf{v}, \mathbf{h}  \mid \boldsymbol{\mathbf{\theta}}) = \sum_{\mathbf{h}} \frac{\exp(-E(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\mathbf{\theta}}))}{Z (\boldsymbol{\mathbf{\theta}})}
\end{align*}

自由エネルギー  F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}}) P(\mathbf{v} \mid \boldsymbol{\theta}) のうち分配関数を除いたものに対応し、次式で定義される。


\begin{align*}
F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}}) \triangleq -\log \left(\sum_{\mathbf{h}} \exp(-E(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\mathbf{\theta}})) \right)
\end{align*}

 P(\mathbf{v} \mid \boldsymbol{\theta}) はまた自由エネルギーを使って書くこともできる。


\begin{align*}
P(\mathbf{v} \mid \boldsymbol{\theta}) = \frac{\exp(-F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}}))}{Z (\boldsymbol{\mathbf{\theta}})}
\end{align*}

最尤推定

RBMの学習は、与えられたデータセットに対するモデルパラメータ  \boldsymbol{\theta} の最適化により、モデルがそのデータを生成する確率の最大化を目指す。これを最尤推定と呼ぶ。

データセット  \mathcal{D} = \left\{\mathbf{v}_{1}, \mathbf{v}_{2}, ..., \mathbf{v}_{N}\right\} が与えられたとき、対数尤度  L_{\mathcal{D}} (\boldsymbol{\mathbf{\theta}}) は次式で定義される。


\begin{align*}
L_{\mathcal{D}} (\boldsymbol{\mathbf{\theta}}) \triangleq \log P(\mathcal{D} \mid \boldsymbol{\theta}) = \sum_{i=1}^N \log P(\mathbf{v}_{i}  \mid \boldsymbol{\theta})
\end{align*}

対数尤度  L_{\mathcal{D}} (\boldsymbol{\mathbf{\theta}}) を最大化するモデルパラメータ  \boldsymbol{\theta} を見つけることが最尤推定の目的である。対数尤度に負号をつけた「負の対数尤度」を定義して,最小化問題として定式化してもよい。

RBMの最尤推定を実現する方法の一つにEMアルゴリズムがある [3]。EMアルゴリズムではM-stepにおいてパラメータに関する目的関数( \mathcal{Q} 関数)の最大化が要求される。 \mathcal{Q} 関数を最大化する条件(学習方程式)が導出できるものの、分配関数の場合と同様に、そこに現れる期待値の計算量が膨大となる問題がある [4]。

結局、確率的勾配降下法をはじめとする勾配法に基づいて、対数尤度を最大化するパラメータを数値的に探索しなければならない。そこで対数尤度  L_{\mathcal{D}} (\boldsymbol{\mathbf{\theta}}) のパラメータに関する勾配が具体的に計算できることを確認しておく必要がある。

補足:勾配上昇法(勾配降下法)

勾配上昇法では、対数尤度を最大化すべくモデルパラメタを次式に従って更新する。


\begin{align*}
\boldsymbol{\theta}_{\text{new}} = \boldsymbol{\theta}_{\text{old}} + \eta  \left. \nabla_{\boldsymbol{\theta}} L_{\mathcal{D}}(\boldsymbol{\theta}) \right|_{\boldsymbol{\theta} = \boldsymbol{\theta}_{\text{old}}}
\end{align*}

ここで  \eta は学習率と呼ばれるハイパーパラメータである。対数尤度の値が十分大きくなるまで、勾配の計算とパラメタの更新を繰り返す。また上述の「負の対数尤度」を導入し,勾配降下法としても定式化できる。


\begin{align*}
\boldsymbol{\theta}_{\text{new}} &= \boldsymbol{\theta}_{\text{old}} - \eta  \left. \nabla_{\boldsymbol{\theta}}(-L_{\mathcal{D}}(\boldsymbol{\theta})) \right|_{\boldsymbol{\theta} = \boldsymbol{\theta}_{\text{old}}}
\end{align*}

学習率をあらかじめ  N で割った形にしておくと、後述する確率的勾配降下法との親和性が増す。


\begin{align*}
\boldsymbol{\theta}_{\text{new}} &= \boldsymbol{\theta}_{\text{old}} +  \frac{\eta}{N}  \left. \nabla_{\boldsymbol{\theta}}L_{\mathcal{D}}(\boldsymbol{\theta}) \right|_{\boldsymbol{\theta} = \boldsymbol{\theta}_{\text{old}}} \\ &=  \boldsymbol{\theta}_{\text{old}} + \frac{\eta}{N} \sum_{i=1}^N    \left. \nabla_{\boldsymbol{\theta}} \log P(\mathbf{v}_{i}  \mid \boldsymbol{\theta}) \right|_{\boldsymbol{\theta} = \boldsymbol{\theta}_{\text{old}}}
\end{align*}

ここでデータセット  \mathcal{D} 上の経験分布  q_{\mathcal{D}} (\mathbf{v})を導入しておく。


\begin{align*}
q_{\mathcal{D}}(\mathbf{v}) = \frac{1}{N} \sum_{i=1}^{N} \delta (\mathbf{v}, \mathbf{v}_{i})
\end{align*}

するとパラメタの更新式は経験分布に関する期待値を用いて書くこともできる。


\begin{align*}
\boldsymbol{\theta}_{\text{new}} &= \boldsymbol{\theta}_{\text{old}} + \frac{\eta}{N} \sum_{i=1}^N    \left. \nabla_{\boldsymbol{\theta}} \log P(\mathbf{v}_{i}  \mid \boldsymbol{\theta}) \right|_{\boldsymbol{\theta} = \boldsymbol{\theta}_{\text{old}}} \\ &=  \boldsymbol{\theta}_{\text{old}} + \eta \left. \mathbb{E}_{q_{\mathcal{D}}(\mathbf{v})} [ \nabla_{\boldsymbol{\theta}}  \log P(\mathbf{v}  \mid \boldsymbol{\theta})] \right|_{\boldsymbol{\theta} = \boldsymbol{\theta}_{\text{old}}} 
\end{align*}

対数尤度の勾配計算

対数尤度の勾配を計算するために、まず個々のデータ  \mathbf{v} に対する対数尤度  \log P(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}}) の勾配を考える。


\begin{align*}
\nabla_{\boldsymbol{\theta}}
\log P(\mathbf{v} \mid \boldsymbol{\theta}) &= \nabla_{\boldsymbol{\theta}} \log \left( \frac{\exp(-F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}}))}{Z (\boldsymbol{\mathbf{\theta}})} \right) \\
&=  \nabla_{\boldsymbol{\theta}} (-F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}})) -  \nabla_{\boldsymbol{\theta}} \log Z (\boldsymbol{\theta}) \\ 
&=  \nabla_{\boldsymbol{\theta}} (-F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}})) -   \mathbb{E}_{P(\mathbf{v} \mid  \boldsymbol{\mathbf{\theta}})}  \left[ \nabla_{\boldsymbol{\theta}} (-F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}}))  \right]\\
\end{align*}

便宜上、第一項をPositive Phase,第二項をNegative Phaseと呼ぶ。 Positive Phaseは観測データ  \mathbf{v} に対する(負の)自由エネルギーの勾配そのものである。自由エネルギー自体は容易に計算可能である(付録を参照)。したがってPositive Phaseの勾配もまた計算可能である。対照的に、Negative Phaseは  Z (\boldsymbol{\theta}) 自体の計算量が膨大なため(組合せ爆発)、その勾配計算もまた計算が難しい。それゆえに、勾配全体として正確な計算が難しい。

Contrastive Divergence (CD) 法による勾配計算の近似

CD法は上記の勾配を、特にNegative Phaseに注目して近似的に計算するための手法である。

  1. 観測データ  \mathbf{v}^{(0)} = \mathbf{v} を用いて、隠れ層をサンプリングし  \mathbf{h}^{(0)} を得る
  2.  \mathbf{h}^{(0)} を用いて可視層をサンプリングし  \mathbf{v}^{(1)} を得る
  3. このプロセスを  k 回繰り返して \mathbf{v}^{(k)} を得る(通常、 kは小さい値でよく、 k = 1 が一般的)
  4.  \mathbf{v} に関する期待値  \mathbb{E}_{P(\mathbf{v} \mid  \boldsymbol{\mathbf{\theta}})}  \left[  \nabla_{\boldsymbol{\theta}} (-F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}})) \right]を、  \mathbf{v}^{(k)} を用いた期待値評価により近似する

RBMの条件付き独立性により、各層のサンプリングは確率変数ごとのサンプリングに帰着されるので容易である。さらに上記のプロセスは、GPUを用いた並列計算によって効率的に実行できる。つまり、ミニバッチ上の並列なサンプリングの連鎖により複数の  \mathbf{v}^{(k)} を一斉に取得し、それらの経験分布による期待値(勾配の標本平均)をもって真の期待値を近似的に評価できる。ここでミニバッチとは、一般にはランダムサンプリングによりデータセットから取り出した、小規模な擬似データセットのことを指す。

CD法を考慮すると、Negative Phaseは  \mathbf{v}^{(k)} を用いて次のように近似しても差し支えない。


\begin{align*}
\mathbb{E}_{P(\mathbf{v} \mid  \boldsymbol{\mathbf{\theta}})}  \left[ \nabla_{\boldsymbol{\theta}} (-F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}}))  \right]
\approx \nabla_{\boldsymbol{\theta}} (-F(\mathbf{v}^{(k)} \mid \boldsymbol{\mathbf{\theta}}))
\end{align*}

したがって対数尤度の勾配もまた次のように近似できる。


\begin{align*}
\nabla_{\boldsymbol{\theta}}
\log P(\mathbf{v}^{(0)} \mid \boldsymbol{\theta})
\approx \nabla_{\boldsymbol{\theta}} (-F(\mathbf{v}^{(0)} \mid \boldsymbol{\mathbf{\theta}})) - \nabla_{\boldsymbol{\theta}} (-F(\mathbf{v}^{(k)} \mid \boldsymbol{\mathbf{\theta}}))
\end{align*}

損失関数の導入

データ単位の損失関数を次式で定義する。


\begin{align*}
\mathcal{L}(\mathbf{v}^{(0)} = \mathbf{v} \mid \boldsymbol{\theta}) &\triangleq F(\mathbf{v}^{(0)} \mid \boldsymbol{\mathbf{\theta}}) - F(\mathbf{v}^{(k)} \mid \boldsymbol{\mathbf{\theta}})
\end{align*}

勾配降下法を考慮すると、「負の対数尤度」を最小化することになるので、損失関数を上式で定義しておくと都合が良い。 データ単位の対数尤度の勾配はまた,損失関数の勾配によって次のように近似できる。


\begin{align*}
\nabla_{\boldsymbol{\theta}}
\log P(\mathbf{v} \mid \boldsymbol{\theta})
\approx -\nabla_{\boldsymbol{\theta}} \mathcal{L}(\mathbf{v} \mid \boldsymbol{\theta}) 
\end{align*}

損失関数の解釈

データ単位の損失関数の勾配から、以下の解釈が可能である。

  • 学習は  F(\mathbf{v}^{(0)} \mid \boldsymbol{\mathbf{\theta}}) が減少する方向に進む。すなわちRBMは観測データに低い自由エネルギーを与え、生成確率(対数尤度)を大きくしたい。

  • 学習は  F(\mathbf{v}^{(k)} \mid \boldsymbol{\mathbf{\theta}}) F(\mathbf{v}^{(0)} \mid \boldsymbol{\mathbf{\theta}}) に対して相対的に増大する方向に進む。すなわちRBMは再構成データに対して相対的に高い自由エネルギーを与え、生成確率を少しでも小さくしたい。CD法に立ち戻って解釈すれば、近似を入れる前の期待値  \mathbb{E}_{P(\mathbf{v} \mid  \boldsymbol{\mathbf{\theta}})}  \left[ \cdot \right] には、観測データ以外の無秩序なデータをRBMが生成する可能性(状態)が全て考慮されている。RBMは観測データ以外の不要な状態に対しては相対的に自由エネルギーを高くして、生成確率を小さく保ちたい。

  • 学習はこれら自由エネルギーの差を0に近づける方向に進む。最終的な損失関数が自由エネルギーの差で表されるためである。

以上を踏まえて、冒頭の記事からモデル訓練における自由エネルギーの推移のログを引用する。Average Loss が損失関数の値に対応している。学習は確かに損失関数を0に近づける方向に進んでいる。損失関数自体は負の対数尤度ではないので、ここでは0に向かって増加傾向を示している。観測データの自由エネルギー Free Energy (v0) の値と再構成データの自由エネルギー Free Energy (vk) の値は共に減少しているが、後者が相対的に大きくなっていることも確認できる。損失関数が負の値を示しているのはそのためである。

Epoch: 1 Average Loss: -22.8360, Free Energy (v0): -161.2850, Free Energy (vk): -138.4490
Epoch: 2 Average Loss: -12.4768, Free Energy (v0): -192.5474, Free Energy (vk): -180.0706
Epoch: 3 Average Loss: -10.4602, Free Energy (v0): -215.8684, Free Energy (vk): -205.4082
Epoch: 4 Average Loss: -8.8693, Free Energy (v0): -234.3609, Free Energy (vk): -225.4916
Epoch: 5 Average Loss: -7.5784, Free Energy (v0): -249.7175, Free Energy (vk): -242.1391
Epoch: 6 Average Loss: -6.4599, Free Energy (v0): -262.7325, Free Energy (vk): -256.2727
Epoch: 7 Average Loss: -5.5244, Free Energy (v0): -273.9464, Free Energy (vk): -268.4220
Epoch: 8 Average Loss: -4.7158, Free Energy (v0): -283.0165, Free Energy (vk): -278.3007
Epoch: 9 Average Loss: -3.9667, Free Energy (v0): -290.9498, Free Energy (vk): -286.9831
Epoch: 10 Average Loss: -3.4139, Free Energy (v0): -296.5954, Free Energy (vk): -293.1815

ミニバッチ版確率的勾配降下法

データセット上の全データを用いる勾配上昇法に基づくRBMの学習は、データセットのサイズ  N に比例した勾配の計算操作が要求されるため、  N が大きな値をとる大規模データセット上では効率的でない。そこでミニバッチ上で勾配上昇法を効率的に行う方法がしばしば採用されている。先日のPyTorch実装でもミニバッチ学習を採用した。 以下ではミニバッチ版確率的勾配降下法の導出を与える(一部で[5]と[6]を参考にした)。

データセット  \mathcal{D} からミニバッチ  \mathcal{M} \subset \mathcal{D} を取り出せたとする。ミニバッチのサイズは  N_{\mathcal{M}} \ll N とする。 ミニバッチ上で対数尤度の不偏推定を行う。


\begin{align*}
L_{\mathcal{M}} (\boldsymbol{\mathbf{\theta}}) &\triangleq \log P(\mathcal{M} \mid \boldsymbol{\theta}) =  \sum_{\mathbf{v} \in \mathcal{M}} \log P(\mathbf{v}  \mid \boldsymbol{\theta})\\
\frac{1}{N} L_{\mathcal{D}} (\boldsymbol{\mathbf{\theta}}) &=  \mathbb{E}_{q_{\mathcal{D}}(\mathbf{v})} (\log P(\mathbf{v}  \mid \boldsymbol{\theta})) \approx  \mathbb{E}_{q_{\mathcal{M}}(\mathbf{v})} (\log P(\mathbf{v}  \mid \boldsymbol{\theta}))  =  \frac{1}{N_{\mathcal{M}}} L_{\mathcal{M}} (\boldsymbol{\mathbf{\theta}}) \\
\end{align*}

ただしミニバッチ  \mathcal{M} 上の経験分布を  q_{\mathcal{M}} (\mathbf{v}) と置いた。

RBMのミニバッチ版確率的勾配降下法は,上記の不偏推定量の勾配に基づいて行われる。すなわち


\begin{align*}
\mathbb{E}_{q_{\mathcal{D}}(\mathbf{v})} (\nabla_{\boldsymbol{\theta}} \log P(\mathbf{v}  \mid \boldsymbol{\theta})) &\approx \mathbb{E}_{q_{\mathcal{M}}(\mathbf{v})} (\nabla_{\boldsymbol{\theta}} \log P(\mathbf{v}  \mid \boldsymbol{\theta}))
\end{align*}

という近似に基づき、更新式を次のように書く(上述の「補足」を参照)。


\begin{align*}
\boldsymbol{\theta}_{\text{new}} &= \boldsymbol{\theta}_{\text{old}} + \eta \left. \mathbb{E}_{q_{\mathcal{M}}(\mathbf{v})} [ \nabla_{\boldsymbol{\theta}}  \log P(\mathbf{v}  \mid \boldsymbol{\theta})] \right|_{\boldsymbol{\theta} = \boldsymbol{\theta}_{\text{old}}}
\end{align*}

CD法に基づいてデータ単位の損失関数を導入すると、更新式の右辺がさらに近似される。そこで改めて右辺を \boldsymbol{\theta}_{\text{new}}と置いて、ミニバッチ版確率的勾配降下法の更新式を得る。


\begin{align*}
\boldsymbol{\theta}_{\text{new}} 
&=  \boldsymbol{\theta}_{\text{old}} -  \eta  \left. \mathbb{E}_{q_{\mathcal{M}}(\mathbf{v})} [ \nabla_{\boldsymbol{\theta}} \mathcal{L}(\mathbf{v} \mid \boldsymbol{\theta})  ]\right|_{\boldsymbol{\theta} = \boldsymbol{\theta}_{\text{old}}}
\end{align*}

ちなみにPyTorchでは、確率的勾配降下法が実装されたoptimizerのクラスは SGD である。件のRBM実装ではoptimizerに Adam を採用したのであるが。

付録には損失関数の期待値  \mathbb{E}_{q_{\mathcal{M}}(\mathbf{v})} [\mathcal{L}(\mathbf{v} \mid \boldsymbol{\theta}) ] の計算方法を、実際のPyTorchコードを示しながら説明した。続く勾配計算は自動微分におまかせすればよい。

おわりに

本記事ではRBMの学習アルゴリズムを簡単に振り返った。実装のうえでは必要なことであったので、良い勉強の機会となった。

本記事で触れられなかったトピックの1つに、Pseudo-Likelihood (PL) の理論がある。scikit-learnのRBM実装でも PL の計算は score_samples メソッドとして与えられており、主に学習の進行をモニタする目的で使われている。今後の記事に期待であるが、むしろエネルギーベースモデルの記事シリーズが続く予感がする(続くとは言ってない)。

RBMは理論と応用の両面で奥が深い生成モデルであり 、RBMの愛好者(RBM推し)がもっと増えてもよいと思う。

付録

自由エネルギーの具体的な計算方法

隠れ層の各状態ユニット  h_i が二値の場合すなわち0か1のどちらかの値を取るとして、 \exp(-E(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\mathbf{\theta}})) を次のように分解できる。


\begin{align*}
\exp(-E(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\mathbf{\theta}})) = \exp(\mathbf{b}^\top \mathbf{v}) \times \prod_i \exp(c_i h_i + \mathbf{v}^\top \mathbf{W}_i h_i)
\end{align*}

ここで、 \mathbf{W}_i は重み行列  \mathbf{W} の 第  i 列ベクトルであり、 c_i はバイアス項  \mathbf{c} の第  i 成分である。また積は隠れ層の全ユニットに渡って取る。このとき、 \sum_{\mathbf{h}} \exp(-E(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\mathbf{\theta}})) は次のように変形できる:


\begin{align*}
\sum_{\mathbf{h}} \exp(-E(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\mathbf{\theta}})) &= \exp(\mathbf{b}^\top \mathbf{v}) \times \prod_{i} \sum_{h_i \in \{0, 1\}} \exp(c_i h_i + \mathbf{v}^\top \mathbf{W}_i h_i) \\
&= \exp(\mathbf{b}^\top \mathbf{v}) \times \prod_{i} \left(1 + \exp(c_{i} + \mathbf{v}^\top \mathbf{W}_{i}) \right)
\end{align*}

したがって、自由エネルギー  F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}}) は次式に基づき具体的に計算できる。


\begin{align*}
F(\mathbf{v} \mid \boldsymbol{\mathbf{\theta}}) &= -\log \left(\sum_{\mathbf{h}} \exp(-E(\mathbf{v}, \mathbf{h} \mid \boldsymbol{\mathbf{\theta}})) \right) \\ &= -\mathbf{b}^\top \mathbf{v} - \sum_i \log \left(1 + \exp(c_i + \mathbf{v}^\top \mathbf{W}_i) \right)
\end{align*}

  \log \left(1 + \exp(c_i + \mathbf{v}^\top \mathbf{W}_i) \right) の部分はいわゆる Softplus関数 の計算である:


\begin{align*}
\text{Softplus(x)} := \log \left(1 + \exp(x) \right)
\end{align*}

先日のPyTorch実装から自由エネルギーの具体的な計算部分を抜粋する(RBM クラスの free_energy メソッド)。上式と実装はきちんと対応している。

def free_energy(self, v: torch.Tensor) -> torch.Tensor:
    """Calculates free energy.

    Args:
        v (torch.Tensor): Input state of the visible layer.

    Returns:
        free_energy (torch.Tensor): Free energy.
    """
    vbias_term = torch.matmul(v, self.bv)
    wx_b = torch.matmul(v, self.w) + self.bh
    hidden_term = torch.sum(self.softplus(wx_b), dim=1)
    free_energy = -hidden_term - vbias_term
    return free_energy

損失関数の具体的な計算方法

先日のPyTorch実装から、ミニバッチ学習のための具体的な損失関数の計算箇所を抜粋する(cd_loss 関数)。

def cd_loss(
    rbm: RBM, v0: torch.Tensor, config: TrainingConfig
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Performs Contrastive Divergence algorithm and calculates the loss.

    Args:
        rbm (RBM): The RBM model.
        v0 (torch.Tensor): The visible layer state from the training data.
        config (TrainingConfig): The training configuration.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            loss (torch.Tensor): The calculated loss value.
            free_energy_0 (torch.Tensor): free energy of visible layer.
            free_energy_k (torch.Tensor): free energy of reconstructed visible layer.
    """
    # k-step Gibbs sampling
    vk = v0
    for _ in range(config.cd_steps):
        vk, _ = rbm(vk)

    # Calculate Free Energies and the loss
    free_energy_0 = rbm.free_energy(v0).mean()
    free_energy_k = rbm.free_energy(vk).mean()
    loss: torch.Tensor = free_energy_0 - free_energy_k
    return loss, free_energy_0, free_energy_k

勾配の線形性と自動微分を考慮すれば、結局のところ損失関数の期待値  \mathbb{E}_{q_{\mathcal{M}}(\mathbf{v})} [\mathcal{L}(\mathbf{v} \mid \boldsymbol{\theta}) ] すなわち


\begin{align*}
 \mathbb{E}_{q_{\mathcal{M}}(\mathbf{v})} [F(\mathbf{v}^{(0)} \mid \boldsymbol{\mathbf{\theta}}) - F(\mathbf{v}^{(k)} \mid \boldsymbol{\mathbf{\theta}}) ]
= \mathbb{E}_{q_{\mathcal{M}}(\mathbf{v})} [F(\mathbf{v}^{(0)} \mid \boldsymbol{\mathbf{\theta}}) ] - \mathbb{E}_{q_{\mathcal{M}}(\mathbf{v})} [F(\mathbf{v}^{(k)} \mid \boldsymbol{\mathbf{\theta}}) ]
\end{align*}

までが計算できればよい。CD法に基づいて2つの自由エネルギーを計算する。free_energy_0free_energy_k の計算時に mean() を使っている理由は、ミニバッチ上の標本分布  q_{\mathcal{M}}(\mathbf{v}) による自由エネルギーの期待値つまり標本平均を取るためである。

参考文献

[1] Yoshua Bengio, "Learning Deep Architectures for AI", Foundations and Trends® in Machine Learning: Vol. 2: No. 1, pp 1-127, 2009.

[2] Hinton, G.E. (2012). A Practical Guide to Training Restricted Boltzmann Machines. In: Montavon, G., Orr, G.B., Müller, KR. (eds) Neural Networks: Tricks of the Trade. Lecture Notes in Computer Science, vol 7700. Springer, Berlin, Heidelberg.

[3] 長岡 浩司, 小嶋 徹也, 統計的モデルとしてのボルツマンマシン, 計算機統計学, 1995, 8 巻, 1 号, p. 61-81.

[4] 人工知能学会監修, 神嶌敏弘 編, "深層学習", 近代科学社, 2015.

[5] Diederik P. Kingma and Max Welling, "An Introduction to Variational Autoencoders", Foundations and Trends® in Machine Learning: Vol. 12: No. 4, pp 307-392, 2019.

[6] 中山, 二反田, 田村, 井上, and 牛久, "深層学習からマルチモーダル情報処理へ", サイエンス社, 2022.

参考資料