DAGMMに基づく教師なし異常検知をPyTorchで実装した

はじめに

教師なし異常検知の機械学習モデルの一つとして、"Deep Autoencoding Gaussian Mixture Model" (以降DAGMM)が知られている。 今回はこれを、異常検知/外れ値検知のためのPythonパッケージPyODの仕様に沿った形で、PyTorchにより実装したということである。

異常検知について

以下を参考にするとよい。

DAGMMについて

論文へのリンクを示す(ICLR 2018)。 openreview.net

解説スライドへのリンクを示す。

www.slideshare.net

流れは次の通り。

  1. オートエンコーダによる入力特徴量の次元圧縮

  2. 「圧縮された特徴量+α」の潜在特徴量からガウス混合モデル(Gaussian Mixture Model; GMM)の事後確率を推定

  3. 事後確率からGMMの統計量を計算(潜在特徴量を用いる)

  4. 統計量からエネルギー関数(=負の対数尤度に相当)を計算、異常検知のスコアとして利用

エネルギー関数の値が大きいほど、異常度が大きいということになる(負の対数尤度が大きい、つまり対数尤度が小さいから発生頻度も小さい)

準備

PyODはpipでインストール可能である。

pip3 install pyod

ほか、torch, sklearn, numpy, tqdmのインストールを済ませておく。

実装したクラス:DAGMM

DAGMMについて、先人によるPyTorch実装はいくつか見つかる。

しかしながら、次に挙げる点が不満だった。

  • エネルギー関数の計算時にPyTorchのinverse関数やpinv関数を用いている

  • 平方根や対数計算時にフロア値を適用していないため、計算が不安定になる(おそれがある)

そこで今回の実装ではPyODの仕様に従いつつ、上記の点も改善するようにした。 ちなみにPyLintを適用してコードをチェックしてある(スコア9.92; 満点は10.0)。

【DAGMMクラスのコードを表示するにはここをクリック】 gist.github.com

直接のURLはこちらから。

念のため、クラス引数の説明を載せておく。

引数 説明
comp_neurons compression network(オートエンコーダ)のニューロン
comp_activation compression networkの活性化関数
estim_neurons estimation network(事後確率推定)のニューロン
estim_activation estimation networkの活性化関数
lambda_energy 損失関数におけるエネルギー関数にかかる重み
comp_activation 損失関数における「共分散行列の逆数和」にかかる重み
epochs 学習のエポック数
batch_size ミニバッチサイズ
weight_decay 各層の重みに課す正則化の強さ
validation_size バリデーションデータの比率
batch_norm バッチノルムの有無
dropout_rate ドロップアウトの比率
verbose ログメッセージの段階; 0は非表示、1はプログレスバー
2は各エポックごとに損失関数の値を1行ずつ表示
contamination 外れ値の(想定)比率

ほか、注意点としてはoptimizerがAdam固定である。

異常検知デモンストレーションその1

Toshihiro NAKAE 氏による、DAGMMのデモンストレーション用notebookがある。

NAKAE氏による実装はTensorFlow 1系によるが、データ作成方法などを参考にして、 今回のPyTorch実装(dagmm.pyとして保存し利用)をローカルで動かしたものを示す。

gist.github.com

NAKAE氏とほぼ同様の結果が得られたことがわかる。

異常検知デモンストレーションその2

DAGMMの論文にはKDDCup 10% データセットによる異常検知の実験結果が載っている。 この実験結果を再現する試みである。ちなみにこちらも先と同様Toshihiro NAKAE 氏による、 DAGMMのデモンストレーション用notebookがある。

上記のnotebookを参考にしてPyTorch実装をローカルで動かしたものを以下に示す。

gist.github.com

論文に掲載された結果よりも幾分良いスコアが得られたが、 乱数の引きが良かったのだろう。

おわりに

PyODフォーマットで実装するシリーズも4つ目となった。他にも色々と実装する予定である(公開するとは言ってない)。

DAGMM論文に掲載された実験結果が再現できて安心した。ソースコードの可読性やスタイルにはまだ改善の余地はある。

余談

DAGMMクラスのattributeの数を7個以下に制限するために(さもなくばPyLintに怒られる!)、 今回はPythonNamedTupleを利用した。具体的にはクラスのinitメソッドに与えた引数を「学習条件」や「ネットワーク定義」のくくりで束ねることで、むやみにattributeを増やすことを回避することに成功した。

DAGMMにより得られる低次元特徴表現(estimation networkの入力)の可視化機能の実装も今回は見送った。モデル内部の挙動を観察・把握し、各モデルの異常検知性能をさらに深く考察する際にはとても有用な機能であるが、どちらかというと研究用途でありクラスが提供する異常検知機能とは切り分けて考えることにした。本記事の読者ならば、DAGMMクラスに独自のメソッドを追加し、低次元表現を可視化することは容易に実現できるだろう。