Metric Learning について
Metric Learning は、データの教師情報を基にデータ間の距離や類似度などの Metric を学習する手法です。日本語で手軽に読める記事だと、*1, *2 などが詳しいです。
このたび、phalanx さんの tweet *3で、 Metric Learning の基礎的なアルゴリズムのいくつかが scikit-learn-contrib *4に搭載されていると知りました。
本記事では、scikit-learn-contrib の metric-learn パッケージを用いて、簡単にMetric Learning を試します。
利用するデータセット
今回は、sklearn に含まれている load_digits データセットを利用します*6。64次元の特徴量・0-9の10種類のラベルを持つ手書き数字のデータセットです。
画像は*7より引用。
可視化(Metric Learning 前)
特徴量の可視化に当たっては、T-SNE *8 を用いて2次元への削減を行います。
次のコードは、metric-learn の docs に掲載されていた内容*9を、凡例を出すように一部改変しています。
def plot_tsne(X, y): plt.figure(figsize=(8, 6)) # clean the figure plt.clf() tsne = TSNE() X_embedded = tsne.fit_transform(X) cmap = plt.get_cmap("tab10") for idx in range(10): plt.scatter(X_embedded[(y==idx), 0], X_embedded[(y==idx), 1], c=cmap(idx), label=idx) plt.legend() plt.xticks(()) plt.yticks(()) plt.show()
load_digits データセットをそのまま可視化したところ、下図のようになりました。大まかに分かれてはいますが、中央付近など少し煩雑になっていると分かります。
可視化(Metric Learning 後)
次に、Metric Learning を実施します。
import metric_learn # setting up LMNN lmnn = metric_learn.LMNN(k=6, learn_rate=1e-6) # fit the data! lmnn.fit(X, y) # transform our input space X_lmnn = lmnn.transform(X)
いくつかのアルゴリズムが実装されていますが、ここでは Large Margin Nearest Neighbor (LMNN) を採用します。
Algorithms
- Large Margin Nearest Neighbor (LMNN)
- Information Theoretic Metric Learning (ITML)
- Sparse Determinant Metric Learning (SDML)
- Least Squares Metric Learning (LSML)
- Neighborhood Components Analysis (NCA)
- Local Fisher Discriminant Analysis (LFDA)
- Relative Components Analysis (RCA)
- Metric Learning for Kernel Regression (MLKR)
- Mahalanobis Metric for Clustering (MMC)
READMEから引用。
Metric Learning 実施後の特徴量を可視化したところ、下図のようになりました。Metric Learning 実施前よりも、各クラスがハッキリと分かれているのが確認できます。
今回は全データで学習し、全データに適用しています。Metric Learning はデータ間の距離や類似度などの Metric を学習しているので、学習に用いていないデータセットに適用することが可能です。
例えば Kaggle のような教師あり機械学習の文脈で利用する場合には、train データセットで Metric を学習し、test データセットにも適用することになるでしょう。分離に適した新しい特徴量空間を用いることで、より分類性能が高いモデルの構築が期待されます。
おわりに
本記事では、scikit-learn-contrib の metric-learn パッケージに搭載されている Metric Learning を試しました。なかなか使い所が難しい印象もある技術ではありますが、選択肢の一つとして持っておく価値は多分にあると感じています。
*1:copypaste-ds.hatenablog.com
*3: metric learning is part of sklearn contribhttps://t.co/t8XUirRzBa