"Differential Entropic Clustering of Multivariate Gaussians"をNumbaを使って高速化してみた話

はじめに

前回記事で実装した Differential Entropic Clustering をもう少し高速化したいなぁ,という話.

tam5917.hatenablog.com

実装

やり方は簡単で,numbaをインストールして,@jit デコレータをBurg matrix divergence およびMahalanobis距離を計算する関数につけるだけ.

@jit(nopython=True)
def comp_burg_div(mat_x, mat_y):
    """Compute Burg matrix divergence."""
    dim = mat_x.shape[0]
    mat_y_inv = LA.inv(mat_y)
    mat = mat_x @ mat_y_inv
    burg_div = np.trace(mat) - np.log(LA.det(mat)) - dim
    return burg_div

@jit(nopython=True)
def comp_maha_dist(vec_x, vec_y, cov):
    """Compute Mahalanobis distance."""
    delta = vec_x - vec_y
    m = np.dot(np.dot(delta, LA.inv(cov)), delta)
    return np.sqrt(m)

前回記事のdemoスクリプトを上記の関数たちに置き換えて,time moduleで簡単に計測してみたところ,

  • 高速化なし ... 58.50秒
  • Burg matrix divergence を計算する関数のみ高速化 ... 33.59秒
  • Mahalanobis 距離を計算する関数のみ高速化 ... 45.33秒
  • 両方の関数を高速化 ... 12.31秒

となった.

高速化に成功した.行列計算関連で繰り返し呼び出される関数には,jitデコレータをつけておくのが良さそうである.

以下,参考まで,今回の実装.main関数の前後で時間計測する(面倒なので).

A demo script of "Differential Entropic Clustering of Multivariate Gaussians" in Python. This is a faster version thanks to numba and jit. · GitHub

おわりに

Numbaによる高速化の効果は大いにあったので,薦めたい.