言語処理100本ノック 2020「77. ミニバッチ化」
問題文

> 問題76のコードを改変し，B事例ごとに損失・勾配を計算し，行列Wの値を更新せよ（ミニバッチ化）．Bの値を1,2,4,8,…と変化させながら，1エポックの学習に要する時間を比較せよ．
問題の概要
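問題76のコードを改変し、`DataLoader` の `batch_size` で指定した B 事例ごとに損失・勾配を計算してパラメータを更新する（ミニバッチ化）処理を追加しました。B を 1, 2, 4, 8 と変えて学習します。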
"""NLP 100 Exercise 2020, problem 77: mini-batch training.

Trains a single ``nn.Linear`` classifier with cross-entropy loss and SGD,
updating the parameters once per mini-batch of B examples, for every B in
``BATCH_SIZES``. Each B gets a freshly initialised model (same seed) so that
learning curves and per-epoch training times can be compared fairly.
"""
import time

import joblib
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm

BATCH_SIZES = [1, 2, 4, 8]
NUM_EPOCHS = 100
LEARNING_RATE = 0.01
NUM_CLASSES = 4
SEED = 0


def load_split(name):
    """Load one data split saved under ``ch08/`` as tensors.

    Args:
        name: split name such as ``'train'`` or ``'valid'``; the files
            ``ch08/X_{name}.joblib`` and ``ch08/y_{name}.joblib`` are read.

    Returns:
        ``(X, y)``: ``X`` converted to float32, ``y`` converted to int64.
    """
    X = joblib.load(f'ch08/X_{name}.joblib')
    y = joblib.load(f'ch08/y_{name}.joblib')
    X = torch.from_numpy(X.astype(np.float32)).clone()
    y = torch.from_numpy(y.astype(np.int64)).clone()
    return X, y


def evaluate(net, loss_fn, X, y):
    """Return ``(loss, accuracy)`` of ``net`` on the whole of ``(X, y)``.

    Runs under ``torch.no_grad()`` so evaluation builds no autograd graph.
    """
    with torch.no_grad():
        logits = net(X)
        loss = loss_fn(logits, y).item()
        _, y_pred = torch.max(logits, 1)
        acc = (y_pred == y).sum().item() / len(y)
    return loss, acc


def train_with_batch_size(batch_size, X_train, y_train, X_valid, y_valid):
    """Train a freshly initialised model using mini-batches of ``batch_size``.

    Returns:
        dict of per-epoch lists: ``train_loss``, ``valid_loss``,
        ``train_acc``, ``valid_acc`` (each measured on the full split after
        the epoch) and ``epoch_time`` (seconds spent in the update loop only).
    """
    # Re-seed and rebuild model/optimizer for every batch size; sharing them
    # across batch sizes would make each run continue the previous one.
    torch.manual_seed(SEED)
    net = nn.Linear(X_train.size(1), NUM_CLASSES)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE)
    loader = DataLoader(TensorDataset(X_train, y_train),
                        batch_size=batch_size, shuffle=True)

    history = {key: [] for key in ('train_loss', 'valid_loss',
                                   'train_acc', 'valid_acc', 'epoch_time')}
    for epoch in tqdm(range(NUM_EPOCHS), desc=f'B={batch_size}'):
        # Time only the parameter updates: that is what problem 77 compares.
        start = time.perf_counter()
        for xx, yy in loader:
            optimizer.zero_grad()
            loss = loss_fn(net(xx), yy)
            loss.backward()
            optimizer.step()
        history['epoch_time'].append(time.perf_counter() - start)

        # Per-epoch checkpoint (problem 76). B is part of the file name so
        # runs with different batch sizes do not overwrite each other.
        joblib.dump(net.state_dict(),
                    f'ch08/state_dict_bs{batch_size}_{epoch}.joblib')

        # Evaluate once per epoch on full splits, so the losses are means
        # that are comparable across batch sizes and between train/valid.
        train_loss, train_acc = evaluate(net, loss_fn, X_train, y_train)
        valid_loss, valid_acc = evaluate(net, loss_fn, X_valid, y_valid)
        history['train_loss'].append(train_loss)
        history['valid_loss'].append(valid_loss)
        history['train_acc'].append(train_acc)
        history['valid_acc'].append(valid_acc)
    return history


def plot_histories(histories):
    """Draw one figure for loss and one for accuracy, covering every B.

    Train curves are solid; the matching valid curve is dashed, same colour.
    """
    for metric in ('loss', 'acc'):
        for bs, history in histories.items():
            line, = plt.plot(history[f'train_{metric}'],
                             label=f'train {metric} (B={bs})')
            plt.plot(history[f'valid_{metric}'], linestyle='--',
                     color=line.get_color(), label=f'valid {metric} (B={bs})')
        plt.xlabel('epoch')
        plt.legend()
        plt.show()


def main():
    """Train for every batch size, report sec/epoch, and plot the curves."""
    X_train, y_train = load_split('train')
    X_valid, y_valid = load_split('valid')

    histories = {}
    for bs in BATCH_SIZES:
        histories[bs] = train_with_batch_size(bs, X_train, y_train,
                                              X_valid, y_valid)
        mean_time = float(np.mean(histories[bs]['epoch_time']))
        print(f'B={bs}: {mean_time:.3f} sec/epoch')

    plot_histories(histories)


if __name__ == '__main__':
    main()