u++の備忘録

scikit-learn-contrib の Metric Learning を試す

Metric Learning について

Metric Learning は、データの教師情報を基にデータ間の距離や類似度などの Metric を学習する手法です。日本語で手軽に読める記事だと、*1, *2 などが詳しいです。

このたび、phalanx さんの tweet *3で、 Metric Learning の基礎的なアルゴリズムのいくつかが scikit-learn-contrib *4に搭載されていると知りました。

本記事では、scikit-learn-contrib の metric-learn パッケージを用いて、簡単にMetric Learning を試します。

インストール

README や PyPI *5 に記載のある通り、次の通りにインストールします。

pip install metric-learn

利用するデータセット

今回は、sklearn に含まれている load_digits データセットを利用します*6。64次元の特徴量・0-9の10種類のラベルを持つ手書き数字のデータセットです。

f:id:upura:20190818182045p:plain
画像は*7より引用。

可視化(Metric Learning 前)

特徴量の可視化に当たっては、T-SNE *8 を用いて2次元への削減を行います。

次のコードは、metric-learn の docs に掲載されていた内容*9を、凡例を出すように一部改変しています。

def plot_tsne(X, y):
    plt.figure(figsize=(8, 6))
    
    # clean the figure
    plt.clf()

    tsne = TSNE()
    X_embedded = tsne.fit_transform(X)

    cmap = plt.get_cmap("tab10")
    for idx in range(10):
        plt.scatter(X_embedded[(y==idx), 0], X_embedded[(y==idx), 1], c=cmap(idx), label=idx)

    plt.legend()
    plt.xticks(())
    plt.yticks(())

    plt.show()

load_digits データセットをそのまま可視化したところ、下図のようになりました。大まかに分かれてはいますが、中央付近など少し煩雑になっていると分かります。

f:id:upura:20190818180214p:plain

可視化(Metric Learning 後)

次に、Metric Learning を実施します。

import metric_learn


# setting up LMNN
lmnn = metric_learn.LMNN(k=6, learn_rate=1e-6)

# fit the data!
lmnn.fit(X, y)

# transform our input space
X_lmnn = lmnn.transform(X)

いくつかのアルゴリズムが実装されていますが、ここでは Large Margin Nearest Neighbor (LMNN) を採用します。

Algorithms

  • Large Margin Nearest Neighbor (LMNN)
  • Information Theoretic Metric Learning (ITML)
  • Sparse Determinant Metric Learning (SDML)
  • Least Squares Metric Learning (LSML)
  • Neighborhood Components Analysis (NCA)
  • Local Fisher Discriminant Analysis (LFDA)
  • Relative Components Analysis (RCA)
  • Metric Learning for Kernel Regression (MLKR)
  • Mahalanobis Metric for Clustering (MMC)

READMEから引用。

Metric Learning 実施後の特徴量を可視化したところ、下図のようになりました。Metric Learning 実施前よりも、各クラスがハッキリと分かれているのが確認できます。

f:id:upura:20190818180230p:plain

今回は全データで学習し、全データに適用しています。Metric Learning はデータ間の距離や類似度などの Metric を学習しているので、学習に用いていないデータセットに適用することが可能です。

例えば Kaggle のような教師あり機械学習の文脈で利用する場合には、train データセットで Metric を学習し、test データセットにも適用することになるでしょう。分離に適した新しい特徴量空間を用いることで、より分類性能が高いモデルの構築が期待されます。

おわりに

本記事では、scikit-learn-contrib の metric-learn パッケージに搭載されている Metric Learning を試しました。なかなか使い所が難しい印象もある技術ではありますが、選択肢の一つとして持っておく価値は多分にあると感じています。

実装は notebook 形式で GitHub にて公開しています*10