u++の備忘録

言語処理100本ノック 2020「77. ミニバッチ化」

問題文

nlp100.github.io

問題の概要

`DataLoader` を用いたミニバッチ化の処理を追加しました。バッチサイズは 1, 2, 4, 8 を試しています。

import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

def _load_tensor(path, dtype):
    """Load a joblib-serialised array from *path* and return it as a tensor.

    The loaded object is cast to *dtype* with ``astype`` (so it is presumably
    a NumPy array -- confirm against the script that wrote the files),
    wrapped with ``torch.from_numpy`` and then cloned so the resulting tensor
    does not share storage with the intermediate NumPy array.
    """
    array = joblib.load(path)
    return torch.from_numpy(array.astype(dtype)).clone()


# Features are stored as float32 and labels as int64 for every split.
X_train = _load_tensor('ch08/X_train.joblib', np.float32)
y_train = _load_tensor('ch08/y_train.joblib', np.int64)

X_valid = _load_tensor('ch08/X_valid.joblib', np.float32)
y_valid = _load_tensor('ch08/y_valid.joblib', np.int64)

X_test = _load_tensor('ch08/X_test.joblib', np.float32)
y_test = _load_tensor('ch08/y_test.joblib', np.int64)

# Training data shared by every batch-size run.
X = X_train
y = y_train
ds = TensorDataset(X, y)

loss_fn = nn.CrossEntropyLoss()

batchSize = [1, 2, 4, 8]

# Learning curves of every run, keyed by batch size, so that all runs can be
# plotted after training instead of only the last one.
history = {}

for bs in batchSize:
    # Start every run from a freshly initialised model and optimizer.
    # Creating them once outside this loop would make each batch size continue
    # training the weights left behind by the previous one, so the runs would
    # not be comparable.
    net = nn.Linear(X.size()[1], 4)
    optimizer = optim.SGD(net.parameters(), lr=0.01)
    loader = DataLoader(ds, batch_size=bs, shuffle=True)

    train_losses = []
    valid_losses = []
    train_accs = []
    valid_accs = []

    for epoc in tqdm(range(100)):
        train_running_loss = 0.0

        for xx, yy in loader:
            y_pred = net(xx)
            loss = loss_fn(y_pred, yy)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # CrossEntropyLoss averages over the mini-batch by default, so
            # weight by the batch length to accumulate a per-sample sum that
            # does not depend on the batch size.
            train_running_loss += loss.item() * len(xx)

        # NOTE: the file name does not contain the batch size, so each run
        # overwrites the checkpoints of the previous one. The path is kept
        # unchanged because other scripts may read it.
        joblib.dump(net.state_dict(), f'ch08/state_dict_{epoc}.joblib')

        # Evaluate once per epoch, without building autograd graphs.
        with torch.no_grad():
            train_logits = net(X)
            valid_logits = net(X_valid)
            valid_running_loss = loss_fn(valid_logits, y_valid).item()
            y_pred_train = train_logits.argmax(dim=1)
            y_pred_valid = valid_logits.argmax(dim=1)

        # Both losses are now mean per-sample values and directly comparable.
        train_losses.append(train_running_loss / len(ds))
        valid_losses.append(valid_running_loss)
        train_accs.append((y_pred_train == y).sum().item() / len(y))
        valid_accs.append((y_pred_valid == y_valid).sum().item() / len(y_valid))

    history[bs] = {
        'train_loss': train_losses,
        'valid_loss': valid_losses,
        'train_acc': train_accs,
        'valid_acc': valid_accs,
    }

# Loss curves of every batch size on one figure.
for bs, curves in history.items():
    plt.plot(curves['train_loss'], label=f'train loss (bs={bs})')
    plt.plot(curves['valid_loss'], label=f'valid loss (bs={bs})')
plt.legend()
plt.show()

# Accuracy curves of every batch size on one figure.
for bs, curves in history.items():
    plt.plot(curves['train_acc'], label=f'train acc (bs={bs})')
    plt.plot(curves['valid_acc'], label=f'valid acc (bs={bs})')
plt.legend()
plt.show()