Deep SVDDに基づく外れ値検知をPyTorchで実装した

はじめに

外れ値検知の機械学習モデルの一つとして"Deep SVDD" が知られている。 今回はこれを、異常検知/外れ値検知のためのPythonパッケージPyODの仕様に沿った形で、PyTorchにより実装したということである。

外れ値検知は1クラス分類と捉えることができ、「通常」クラスか「それ以外(=外れ値、つまり異常)」という分類が行われる。 "Deep SVDD"は、外れ値検知の既存手法であるOne-Class SVM / Support Vector Data Description (SVDD) の非線形カーネルニューラルネットワークで置き換えたものである。

準備

PyODはpipでインストール可能である。

pip3 install pyod

ほか、torch, sklearn, numpy, tqdmのインストールを済ませておく。

Deep SVDDについて

概要は以下の記事を参考するのが早いだろう(落合先生のフォーマット!)。 salty-vanilla.github.io

論文は以下から読むことができる(ICML2018)。 proceedings.mlr.press

ほか、参考となる記事とスライドを掲載する。 speakerdeck.com www.skillupai.com

ICLR2021ではDeep One-Class Classificationについて深い考察に基づいた拡張が提案されている。 openreview.net

作成したクラス:DeepSVDD

【コードを表示する】 gist.github.com

なお論文著者によるオリジナルのPyTorch実装は以下から入手できる。 github.com

本記事の実装は上記オリジナル実装のsimplified versionとみなせる。

デモンストレーション

PyODから提供されている各種アルゴリズムのexample用スクリプトを参考に、簡単なデモンストレーションのnotebookを作成した。 今回作成したDeepSVDDクラスをdeep_svdd.pyとして保存した場合のnotebookである。

gist.github.com

ひとまずうまくいっているように見える。

注意点

今回の実装では、ニューラルネット部分は全結合層のみからなっている。 画像データを対象にする場合には畳み込み層やプーリング層へと実装を修正する必要がある(『DeepSVDD_net』クラスの中身)。 この記事の読者ならば修正は容易だろう。

今後の課題

  • 実際の画像データ(MNIST, CIFAR-10)を用いたオリジナル論文の追試
  • 画像データへの対応