Transformerの学習率を調整するSchedulerクラスをPyTorchで書いた

はじめに

Attention is All You Needという論文で「warmup & ステップ数の逆平方根で学習率を減衰」させる学習率スケジューリングが提案されたが、そのようなスケジューリングを手軽に行うスケジューラを書いたということである。

ソースコード

from torch.optim.lr_scheduler import _LRScheduler

class TransformerLR(_LRScheduler):
    """TransformerLR class for adjustment of learning rate.

    The scheduling is based on the method proposed in 'Attention is All You Need'.
    """

    def __init__(self, optimizer, warmup_epochs=1000, last_epoch=-1, verbose=False):
        """Initialize class."""
        self.warmup_epochs = warmup_epochs
        self.normalize = self.warmup_epochs**0.5
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        """Return adjusted learning rate."""
        step = self.last_epoch + 1
        scale = self.normalize * min(step**-0.5, step * self.warmup_epochs**-1.5)
        return [base_lr * scale for base_lr in self.base_lrs]

学習率のスケジューラをオリジナルで書きたい場合は、_LRSchedulerを継承し、get_lr関数を自作するのが良い。PyTorch本家がそのような実装となっている。

github.com

実装において、min(step**-0.5, step * self.warmup_epochs**-1.5)の部分が学習率調整の本質である。なおself.normalizeは値の範囲を0から1に収めるための正規化定数の役割を果たす。 これらの積によりscaleが計算され、それをoptimizerの構築時に指定した学習率に掛けることで学習率が更新される。

注意すべきは、上記の実装はepoch単位で学習率を調整することが前提になっているという点である (step = self.last_epoch + 1というあたり)。

学習の初期段階では0に近い係数(scale)が学習率に掛けられるが、徐々に増えていき、warmup_epochsまでエポック数が達すると係数は1で最大となる。 その後は係数が1より小さくなるため、学習率が減衰されていく。

参考文献

arxiv.org