はじめに
scikit-learn 1.0の新機能として、OC-SVMのオンライン版が紹介されている。 scikit-learn.org
オンライン化のご利益は訓練時間の大幅な削減である。 そこで本記事では、「結局オンライン化でどれくらい速くなったの?」という疑問に答えるべく、オンライン版OCSVMを使いやすい形となるようPyODのフォーマットで実装し、外れ値検知のベンチマークデータセットの上で性能を評価した。
準備
pip3 install pyod pip3 install scikit-learn -U # 最新版1.0系を入れるという意味
作成したクラス
SGD_OCSVMクラスを作成した。以下がソースコードである。
カーネル近似のためのNyström methodを実装したNystroemクラスと、SGDOneClassSVMクラスをそれぞれインスタンス化して実装している。
Nyström methodによりグラム行列を低ランク近似&feature mapを手に入れる。feature mapによりデータを(一般には)高次元空間に写像した後に、OC-SVMの目的関数が最小となるべく、パラメタをSGDで更新する。
つまり内部的には
- 訓練データに対するNystroemのfit & transform
- transform後のデータに対して、SGDOneClassSVMのfit
ということであり、sklearn的なパイプラインである。
参考までにそれらクラスの公式ドキュメントを示しておく。
- sklearn.kernel_approximation.Nystroem — scikit-learn 1.0.2 documentation
- sklearn.linear_model.SGDOneClassSVM — scikit-learn 1.0.2 documentation
今回は2クラスの引数をそのまま受け継ぐ形でSGD_OCSVMクラスを実装しているので、引数の数が多いのはイケてないが、やむを得ない。
使い方の例
上記のPythonファイルを、例えばsgd_ocsvm.pyとして保存する。使い勝手はPyODに合わせるようにした。
# PyODのOCSVM(バッチ型)の訓練 from pyod.models.ocsvm import OCSVM clf = OCSVM(gamma="auto") # デフォルトはRBFカーネル clf.fit(X) # SGDによるオンライン型のOCSVMの訓練 from sgd_ocsvm import SGD_OCSVM clf = SGD_OCSVM(gamma="auto") # デフォルトはRBFカーネル clf.fit(X)
実験
PyODのベンチマークセット15種類に対して、バッチ型とオンライン型それぞれでモデルを訓練し、その所要時間および異常検知精度を比較する。 データセットは以下のサイトから入手できる。
これらデータセットの特性を簡単にまとめておく。
データセット | サンプル数 | 特徴量次元 | 外れ値比率 |
---|---|---|---|
arrhythmia | 452 | 274 | 14.6% |
cardio | 1831 | 21 | 9.61% |
glass | 214 | 9 | 4.21% |
ionosphere | 351 | 33 | 35.9% |
letter | 1600 | 32 | 6.25% |
mnist | 7603 | 100 | 9.21% |
musk | 3062 | 166 | 3.17% |
optdigits | 5216 | 64 | 2.88% |
pendigits | 6870 | 16 | 2.27% |
pima | 768 | 8 | 34.90% |
satellite | 6435 | 36 | 31.6% |
satimage-2 | 5803 | 36 | 1.22% |
vertebral | 240 | 6 | 12.5% |
vowels | 1456 | 12 | 3.43% |
wbc | 378 | 30 | 5.56% |
実験に使用したノートブックを示す。なお訓練時のデータには外れ値は含まれている(訓練データとテストデータを6:4で分割)。
結果
ノートブックより得られるグラフを図1に示す。左から、「訓練時間」「ROC」「Precision @ n」である。OCSVMはバッチ型、SGD-OCSVMはオンライン型である。
上図より、訓練時間はデータセット平均でも劇的に削減できていることが分かる。なおかつ、異常検知精度は据え置きであり、今回はカーネル近似の影響をほぼ受けていない結果となった。
各データセットごとの訓練時間を示したのが図2である。
上図より、サンプル数の多いmnistやoptdigits、pendigits、satellite、satimage-2で時間削減の効果が大きいことが読み取れる。特にmnistに関しては99%の時間削減となった (1- 0.053 / 6.1873 = 0.99) 。それ以外のデータセットについてはサンプル数が少なく、時間削減の効果は小さい。
参考までに、各データセットごとのROCスコアとPrecision @ nスコアを図3と図4にそれぞれ示す。ROCはほぼ同じグラフである。Precision @ nがオンライン化でやや劣るケースが見られるのは近似の影響であろう。
おわりに
今回の要点:
- SGDを用いた最適化に基づくOC-SVMを、PyODフォーマットで利用可能となるように実装。カーネル近似手法の一つであるNyström methodの組み合わせ(パイプライン)。
- 15ベンチマーク用データセットを用いて、訓練時間削減の効果を検証。サンプル数が多い場合に時間削減の効果が大きいことを確認。
みんなPyOD使おうぜ!