Psuedo-LikelihoodをRBMの損失関数として使ったら学習はうまくいくのか

はじめに

Psuedo-Likelihood (PL) はRBMの学習の進行を測定するための指標として使われている。scikit-learnの BernoulliRBM クラスには PLを計算するための score_samples メソッドが実装されている。

scikit-learn.org

さてscikit-learnで計算可能ならば、当然PyTorchでも計算可能である。そこでRBMのPyTorch実装においてPLを損失関数に採用すれば、ある程度は学習がうまく進むはずだろうと見込める。

本記事はPLをRBMの損失関数として使ったときの実験結果を報告するものである。PLの詳しい理論や導出はさておき、まずはやってみようの精神である。

実装

以下に置いた。Enjoy!

使い方:

$ python train_rbm_pl.py --log_dir log_rbm --epochs 300 --learning_rate 0.002 --batch_size 128

PLの具体的な計算箇所を示す。反転(flip)のところは scikit-learn を一部参考にした。

def psuedo_likelihood(self, v: torch.Tensor) -> torch.Tensor:
    """Calculates psuedo-likelihood.

    Args:
        v (torch.Tensor): Input state of the visible layer.

    Returns:
        psuedo_likelihood (torch.Tensor): Psuedo-Likelihood.
    """
    v_flip = v.clone()
    ind = (
        np.arange(v.shape[0]),
        np.random.RandomState().randint(0, v.shape[1], v.shape[0]),
    )
    v_flip[ind] = 1 - v_flip[ind]
    fe = self.free_energy(v)
    fe_flip = self.free_energy(v_flip)
    zeros = torch.zeros(v.shape[0]).to(v.device)
    psuedo_likelihood = -v.shape[1] * torch.logaddexp(zeros, -(fe_flip - fe))
    return psuedo_likelihood

損失関数として用いる場合は、計算したPLの符号を反転して「Negative Psuedo-Likelihood」とする。

PLによる本来の尤度の近似計算に対して、その計算を簡略化する追加の近似を重ねたのが今回のPL実装である。つまり正確なPLすら実は計算していない点に留意する必要がある。

実験結果

MNISTの2値化した画像データを対象に訓練を300エポック回した。学習率は0.002、バッチサイズは128、optimizerはAdamである。訓練後、生成した再構成画像を図1に示す。

図1:テストデータの再構成結果(上段:テストデータ,下段:再構成結果(2値サンプリング))

当初の期待通り、RBMの学習がうまくいくことは確認できた。しかしながらCD法よりも質の低い局所解に収束したことも確認できた。実際、今回の実装とは別にCD法で学習させたときのPLの推移は以下に示す通りである。学習率、バッチサイズ、optimizerは揃えてある。

Epoch: 1, Psuedo-Likelihood = -153.2711
Epoch: 2, Psuedo-Likelihood = -98.9846
Epoch: 3, Psuedo-Likelihood = -84.6411
Epoch: 4, Psuedo-Likelihood = -76.5742
Epoch: 5, Psuedo-Likelihood = -70.7524
Epoch: 6, Psuedo-Likelihood = -68.8260
Epoch: 7, Psuedo-Likelihood = -65.8515
Epoch: 8, Psuedo-Likelihood = -63.2211
Epoch: 9, Psuedo-Likelihood = -60.9840
Epoch: 10, Psuedo-Likelihood = -59.3827

CD法ではわずか10エポックで-60程度に到達したことが分かる。 それに対して今回の実装におけるPLの推移は以下の図2に示す通りであり、300エポック回しても-60には到達しなかった。

図2:PLのエポックごとの推移

考察

PL自体が持つ近似,すなわち可視ユニット間の依存性を無視する近似が、勾配の推定に与えた影響は小さくないと推測できる。PLの勾配の向きは、あくまで「ある1つの可視ユニットを反転させた場合の尤度変化」を最大化する方向であって、全体の同時分布の尤度を最大化する方向とは限らない。特に学習初期においてはデータ分布からモデルが大きく外れており、PLの勾配は不安定である。誤った方向に学習が進んだ結果、局所解に陥った可能性が考えられる。一方、CD法も勾配計算に近似が入っているものの、Gibbsサンプリングを用いてデータ分布を近似し、そこから得られた勾配は全体の尤度最大化により近い方向を向いている。それゆえ、たとえ学習初期であっても勾配の向きは大きく外していない。

おわりに

PLがRBMの損失関数としてもある程度使えることは実験的に確認できた。ただ本記事の結果は、あくまで参考程度に留めておきたい。続く記事にてPLの導出を行う予定である。