u++の備忘録

テーブルデータ向けのGAN(TGAN)で、titanicのデータを増やす

はじめに

ynktk さんのツイート*1を見て、テーブルデータ向けの GAN の存在を知りました。本記事では、TGAN を用いて titanic のデータを拡張してみます。

TGANとは

テーブルデータに対応した GAN (Generative Adversarial Network, 敵対的生成ネットワーク) *2 です。数値などの連続変数だけではなく、カテゴリ変数にも対応しています。

Titanic のデータを増やす

今回は、著名なデータセットである Titanic のデータを対象にTGANを試します。

データの読み込み

まずはデータを読み込みます。データは Kaggle からダウンロードしました*3。下図のようなデータが格納されています。

df = pd.read_csv('input/train.csv')
df.head()

f:id:upura:20190820092059p:plain

行数は900弱です。

df.shape
(891, 12)

欠損値の削除

TGANは、欠損値に対応していません。最初に各カラムの欠損値の数を確認しておきます。

df.isnull().sum()
PassengerId      0
Survived         0
Pclass           0
Name             0
Sex              0
Age            177
SibSp            0
Parch            0
Ticket           0
Fare             0
Cabin          687
Embarked         2
dtype: int64

欠損値が大半を占める 'Cabin' は、この段階で削除します。合わせて、GANでの増幅に不適切な ['PassengerId', 'Name', 'Ticket'] も削除しておきます。

df.drop(['Cabin', 'PassengerId', 'Name', 'Ticket'], axis=1, inplace=True)

その他の欠損値を含む2カラムについて、'Age' は平均値を四捨五入して整数にした値、'Embarked' は最頻値である 'S' で埋めました。

df['Age'].fillna(round(df['Age'].mean(), 0), inplace=True)
df['Embarked'].fillna(df['Embarked'].value_counts().index[0], inplace=True)

カラム名の保持

現在 PyPI でインストールできる TGAN (ver 0.1.0) には、実行後に DataFrame のカラム名がインデックス番号に置換されてしまう不具合があります*4。そのため、実行後のために事前にカラム名を変数に入れて保持しておく必要があります。

df_columns = df.columns

連続変数の指定

TGAN の実行時には、連続変数のカラムのインデックス番号一覧をリスト型で渡します。今回は、次のように float 型のカラムを抽出しました。

continuous_columns = [df.columns.get_loc(c) for c in df.select_dtypes(include=['float']).columns]

TGAN の実行

いよいよ TGAN を実行します。

from tgan.model import TGANModel
tgan = TGANModel(continuous_columns, batch_size=50)
tgan.fit(df)

このときdocsには記載がありませんが、小さめのbatch_sizeを引数に指定しないと、tensorpack の assertion error *5で実行が止まってしまいます。

実行時間は、900弱のデータセットで15分程度でした。学習済のモデルは、次のように保存可能です。

model_path = 'output/models/mymodel.pkl'
tgan.save(model_path)

サンプルの抽出

学習済のモデルから、次のようにデータを生成できます。今回は、元のデータセットと同数を指定しました。

num_samples = len(df)
samples = tgan.sample(num_samples)

f:id:upura:20190819161132p:plain

'Age' が小数点以下になっているので丸める処理などは必要かもしれませんが、'Sex' などカテゴリ変数も含めてデータが生成できていると分かります。目的変数である 'Survived' も、問題なく増幅されていました。

samples['Survived'].value_counts()
0    540
1    310

おわりに

本記事では、TGAN を用いて titanic のデータを拡張してみました。Kaggle などの文脈で言うと、学習用データの水増しに利用できる可能性があります。ただし、ynktk さんとも議論した通り*6、GAN でまともなデータを作るにはそもそも十分量のデータセットが必要というジレンマがありそうです。

TGAN で増やしたデータで性能が向上するかはデータセットと課題設定次第ですが、機会があれば試してみても面白いなと思いました。

今回の実装は GitHub *7で公開しています。

*1:

*2:papers.nips.cc

*3:www.kaggle.com

*4:github.com

*5:github.com

*6:

*7:github.com