Recurrent Batch NormalizationをTensorFlowで実装した

要するにLSTMの内部でバッチ正規化を行うということ。

論文と実装は以下の通り。

本実装は以下の先行実装に依拠しており、ここに感謝する次第である。

前者はTensorFlow実装、後者はTheano実装であるが、前者は後者を参考にして実装された。本実装はさらに前者の実装に対して自分用に手を加えたということである。なおTensorFlow 0.10でのみ動作確認をしている。

本実装ではBN_LSTMCellクラスが定義されている。使い方の注意点は、BN_LSTMCellのインスタンスを作成する際、引数に'is_training'を取るということである。これはバッチ正規化に関して、学習時と評価時で振る舞いを変えるためのものである。より具体的に言うと、学習時には各ミニバッチについて統計量(平均と分散)を計算してバッチ正規化を行う必要があるのだが、評価時の各ミニバッチに対しては改めて統計量を計算する必要はなく、学習データ全体から求められる確定した統計量に基づいてバッチ正規化を行うのである。

今後はGRUやSGU, MGUなどにrecurrent batch normalizationを実装する予定である。