u++の備忘録

カーネル密度推定による分類方法をPythonで可視化

 『画像認識』(機械学習プロフェッショナルシリーズ)*1の第5章pp.143-144に載っていた、カーネル密度推定による分類方法を可視化。
f:id:upura:20171017202619p:plain

data1(青)は0を中心とした正規分布の乱数、data2(黄)は3を中心とした正規分布の乱数。カーネル密度推定により確率密度が計算されている。例えば x=1のときは青の線が黄の線を上回っているので「data1に分類するのが尤もらしい」という具合に考える。

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style('whitegrid')

data1 = np.random.randn(50)
data2 = 3 + np.random.randn(50)

# data1のプロット
sns.distplot(data1, axlabel="x")
sns.rugplot(data1)

# data2のプロット
sns.distplot(data2, color="y")
sns.rugplot(data2, color="y")

plt.ylabel('p(x)')
plt.show()