オンラインOCSVMに基づく外れ値検知をPythonで実装し、訓練時間削減の効果をベンチマークデータで検証した

はじめに

scikit-learn 1.0の新機能として、OC-SVMのオンライン版が紹介されている。 scikit-learn.org

オンライン化のご利益は訓練時間の大幅な削減である。 そこで本記事では、「結局オンライン化でどれくらい速くなったの?」という疑問に答えるべく、オンライン版OCSVMを使いやすい形となるようPyODのフォーマットで実装し、外れ値検知のベンチマークデータセットの上で性能を評価した。

準備

pip3 install pyod
pip3 install scikit-learn -U  # 最新版1.0系を入れるという意味

作成したクラス

SGD_OCSVMクラスを作成した。以下がソースコードである。

カーネル近似のためのNyström methodを実装したNystroemクラスと、SGDOneClassSVMクラスをそれぞれインスタンス化して実装している。

Nyström methodによりグラム行列を低ランク近似&feature mapを手に入れる。feature mapによりデータを(一般には)高次元空間に写像した後に、OC-SVMの目的関数が最小となるべく、パラメタをSGDで更新する。

つまり内部的には

  • 訓練データに対するNystroemのfit & transform
  • transform後のデータに対して、SGDOneClassSVMのfit

ということであり、sklearn的なパイプラインである。

参考までにそれらクラスの公式ドキュメントを示しておく。

今回は2クラスの引数をそのまま受け継ぐ形でSGD_OCSVMクラスを実装しているので、引数の数が多いのはイケてないが、やむを得ない。

使い方の例

上記のPythonファイルを、例えばsgd_ocsvm.pyとして保存する。使い勝手はPyODに合わせるようにした。

# PyODのOCSVM(バッチ型)の訓練
from pyod.models.ocsvm import OCSVM
clf = OCSVM(gamma="auto")  # デフォルトはRBFカーネル
clf.fit(X)

# SGDによるオンライン型のOCSVMの訓練
from sgd_ocsvm import SGD_OCSVM
clf = SGD_OCSVM(gamma="auto")  # デフォルトはRBFカーネル
clf.fit(X)

実験

PyODのベンチマークセット15種類に対して、バッチ型とオンライン型それぞれでモデルを訓練し、その所要時間および異常検知精度を比較する。 データセットは以下のサイトから入手できる。

これらデータセットの特性を簡単にまとめておく。

データセット サンプル数 特徴量次元 外れ値比率
arrhythmia 452 274 14.6%
cardio 1831 21 9.61%
glass 214 9 4.21%
ionosphere 351 33 35.9%
letter 1600 32 6.25%
mnist 7603 100 9.21%
musk 3062 166 3.17%
optdigits 5216 64 2.88%
pendigits 6870 16 2.27%
pima 768 8 34.90%
satellite 6435 36 31.6%
satimage-2 5803 36 1.22%
vertebral 240 6 12.5%
vowels 1456 12 3.43%
wbc 378 30 5.56%

実験に使用したノートブックを示す。なお訓練時のデータには外れ値は含まれている(訓練データとテストデータを6:4で分割)。

gist.github.com

結果

ノートブックより得られるグラフを図1に示す。左から、「訓練時間」「ROC」「Precision @ n」である。OCSVMはバッチ型、SGD-OCSVMはオンライン型である。

図1: 性能比較(バッチ型とオンライン型; 15データセットの平均)

上図より、訓練時間はデータセット平均でも劇的に削減できていることが分かる。なおかつ、異常検知精度は据え置きであり、今回はカーネル近似の影響をほぼ受けていない結果となった。

各データセットごとの訓練時間を示したのが図2である。

図2: 各データセットの訓練時間

上図より、サンプル数の多いmnistやoptdigits、pendigits、satellite、satimage-2で時間削減の効果が大きいことが読み取れる。特にmnistに関しては99%の時間削減となった (1- 0.053 / 6.1873 = 0.99) 。それ以外のデータセットについてはサンプル数が少なく、時間削減の効果は小さい。

参考までに、各データセットごとのROCスコアとPrecision @ nスコアを図3と図4にそれぞれ示す。ROCはほぼ同じグラフである。Precision @ nがオンライン化でやや劣るケースが見られるのは近似の影響であろう。

図3: 各データセットROCスコア

図4: 各データセットのPrecision @ n スコア

おわりに

今回の要点:

  • SGDを用いた最適化に基づくOC-SVMを、PyODフォーマットで利用可能となるように実装。カーネル近似手法の一つであるNyström methodの組み合わせ(パイプライン)。
  • 15ベンチマーク用データセットを用いて、訓練時間削減の効果を検証。サンプル数が多い場合に時間削減の効果が大きいことを確認。

みんなPyOD使おうぜ!