u++の備忘録

LightGBMでクリスマスツリーを描く

本記事は、kaggle Advent Calendar 2018 その2の25日目の記事です。意図的にフライングして前日の24日、クリスマスイブに投稿します。

qiita.com

クリスマス用の記事として、LightGBMでクリスマスツリーを描いてみました。

なお「決定境界を用いて絵を描く」というアイディアは、4年前にTJOさんの投稿を見て以来、頭の片隅にありました。今年はKaggleに打ち込んでLightGBMに大変お世話になったので、最後までLightGBMを使い倒して、このアイディアを昇華させようと考えた次第です。

tjo.hatenablog.com

データセットはTJOさんのGitHubからダウンロードしました。

github.com

そのままでは工夫がないので「木の根元」に当たる新しいラベルもデータセットに加えました。

import random
random.seed(100)
x_add = [random.random() *6 - 3 for i in range(100)]
y_add = [-1 * random.random() *1.5 - 2.7 for i in range(100)]
label_add = [2 if abs(i) < 0.6 else 1 for i in x_add]

df_add = pd.DataFrame({
    'x': x_add,
    'y': y_add,
    'label': label_add
})

df = pd.concat([df, df_add])

LightGBMの決定境界を描く際には、mlxtendライブラリのplot_decision_regionsを使いました。学習済モデルとデータを渡すだけで、非常に簡単に利用可能です。

qiita.com

rasbt.github.io

from mlxtend.plotting import plot_decision_regions
import lightgbm as lgb
from sklearn.model_selection import train_test_split

X = df[['x', 'y']]
y = df['label']
X_train, X_valid, y_train, y_valid= train_test_split(X, y, random_state = 0)
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_valid, y_valid, reference=lgb_train)

lgbm_params = {
    'learning_rate': 0.2,
    'num_leaves': 8,
    'boosting_type': 'gbdt',
    'reg_alpha': 1,
    'reg_lambda': 1,
    'objective': 'regression',
    'metric': 'mae',
}

model = lgb.train(
    lgbm_params, lgb_train,
    valid_sets=lgb_eval,
    num_boost_round=1000,
    early_stopping_rounds=10,
)

plt.figure(figsize=(10,10))
plt.xticks(color="None")
plt.yticks(color="None")
plt.tick_params(length=0)
plot_decision_regions(np.array(X), np.array(y), clf=model, res=0.02, legend=2, colors='limegreen,white,brown')

f:id:upura:20181215165835p:plain

特にオチはないです。Merry Christmas!

実装はGitHubで公開しました。

github.com