はじめに
先日、RBMの記事を見かけた。 zenn.dev
RBM自体は昔からある有名なモデルであり実装の経験もあったのだが、PyTorchでキレイに書いたことはなかったので、書いてみたということ。
実装
以下に置いた。Enjoy!
今回はデモンストレーションということで、MNIST上で訓練・テストを行うものである。train_rbm.py ではMNIST画像の画素値を [0, 1] で正規化してvisible層の入力としており、確率的な観測値と見なしている。他方、train_rbm2.py では画素値を 0 or 1 で二値化している。 これらvisible層の挙動をオプションで切り替えられるよう、ひとつのスクリプトに統合すること自体は実に簡単だが、今回は分けておいた。
勾配まで手計算で導出したうえで実装しているものが多く見つかるが、本実装では損失関数として自由エネルギーまでを計算し、具体的な勾配計算は自動微分にまかせる方針を取った。
実験結果その1
以下では train_rbm.py (入力visible層の値を[0, 1] で正規化バージョン)を動かしたときの結果を示す。エポック数は30とした。
テストデータの再構成結果の例を図1に示す。レイアウトは冒頭のRBM記事をリスペクトした。ここでの再構成結果はvisible層の確率値を可視化している。

参考まで、visible層の確率値ではなく,2値サンプリングで再構成した結果を図2に示す。

ちなみに1観測データあたりの自由エネルギーの推移のログを以下に抜粋する(v0が入力,vkが再構成を意味する):
Epoch: 1 Average Loss: -22.8360, Free Energy (v0): -161.2850, Free Energy (vk): -138.4490 Epoch: 2 Average Loss: -12.4768, Free Energy (v0): -192.5474, Free Energy (vk): -180.0706 Epoch: 3 Average Loss: -10.4602, Free Energy (v0): -215.8684, Free Energy (vk): -205.4082 Epoch: 4 Average Loss: -8.8693, Free Energy (v0): -234.3609, Free Energy (vk): -225.4916 Epoch: 5 Average Loss: -7.5784, Free Energy (v0): -249.7175, Free Energy (vk): -242.1391 Epoch: 6 Average Loss: -6.4599, Free Energy (v0): -262.7325, Free Energy (vk): -256.2727 Epoch: 7 Average Loss: -5.5244, Free Energy (v0): -273.9464, Free Energy (vk): -268.4220 Epoch: 8 Average Loss: -4.7158, Free Energy (v0): -283.0165, Free Energy (vk): -278.3007 Epoch: 9 Average Loss: -3.9667, Free Energy (v0): -290.9498, Free Energy (vk): -286.9831 Epoch: 10 Average Loss: -3.4139, Free Energy (v0): -296.5954, Free Energy (vk): -293.1815
Contrastive Divergence (CD) アルゴリズムを考慮すると,損失関数はvisible層v0から計算される自由エネルギー(F0)と、再構成されたvisible層vkから計算される自由エネルギー(Fk)との差(F0 - Fk)で計算すればよい。訓練が進むにつれて両者の差は小さくなり,訓練は成功裏に進んだ。
RBMの学習アルゴリズムと損失関数については続く記事にて説明を加えておいた。
実験結果その2
以下では train_rbm2.py(入力visible層の値を2値化したバージョン)を動かしたときの結果を示す。エポック数は30とした。
テストデータの再構成結果の例を図3に示す。再構成結果はvisible layerを2値サンプリングしたもので可視化している。図2と比較してドットが抜けるノイズが少ないように見えるのは、visible層の入力と再構成の間で一貫して2値化が用いられ、条件がマッチしているためであろう。

参考まで、visible layerの確率値(連続値)で再構成した結果を図4に示す。

train_rbm2.py も同様に訓練が進むが,損失関数の挙動からその進みは遅いことが分かった。これは2値化に伴うサンプリングのノイズが大きく影響しているためであろう。
おわりに
RBMであれVAEであれ、手を動かして実装すると愛着が湧く。