連載
» 2020年06月26日 05時00分 公開

PyTorchのRNNクラスとRNNCellクラスを再発明しよう作って試そう! ディープラーニング工作室(2/2 ページ)

[かわさきしんじ,Deep Insider編集部]
前のページへ 1|2       

MyRNNクラスを使っての学習と評価

 次に、今定義したMyRNNクラスとNetクラスを使って、学習と評価を実行してみましょう。学習するコードは関数として、この後にも使うことにします。

def train(epocs, net, X_train, y_train, criterion, optimizer):
    losses = []

    for epoch in range(EPOCHS):
        print('epoch:', epoch)
        optimizer.zero_grad()
        hidden = torch.zeros(num_batch, hidden_size)
        output, hidden = net(X_train, hidden)
        loss = criterion(output, y_train)
        loss.backward()
        optimizer.step()

        print(f'loss: {loss.item() / len(X_train):.6f}')
        losses.append(loss.item() / len(X_train))

    return output, losses

学習を行うコード

 前回の学習を行うコードを関数として、エポック数、ニューラルネットワークモデル、訓練データと正解ラベル、損失関数と最適化アルゴリズムを受け取るようにしたところ以外は変更点はほぼありません。が、強調表示で示した「hidden = torch.zeros(num_batch, hidden_size)」という行だけが、少し違っています。PyTorchのRNNクラスは、複数のRNN層を積み重ねることが可能ですが、今回のMyRNNクラスではこれをサポートしていません。そのため、隠れ状態の初期値となるテンソルの初期化が少し変わっています(詳細な説明は省略します)。

 後は、学習に必要な要素(訓練データ、正解ラベル、Netクラスのインスタンス、損失関数、最適化アルゴリズム)を生成して、上記のtrain関数に渡すだけです。

num_div = 100  # 1周期の分割数
cycles = 2  # 周期数
num_batch = 25  # 1つの時系列データのデータ数
X_train, y_train = make_train_data(num_div, cycles, num_batch)

input_size = 1  # 入力サイズ
hidden_size = 32  # 隠れ状態のサイズ
output_size = 1  # 出力層のノード数
net = Net(input_size, hidden_size, output_size)

criterion = torch.nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.05)

学習に必要な要素の生成

 最後にエポック数と共に、これらをtrain関数に渡します。戻り値は最後の学習結果の出力と平均損失を含んだリストになっています。

EPOCHS = 100
output, losses = train(EPOCHS, net, X_train, y_train, criterion, optimizer)

学習の実行

 実行結果を以下に示します。

実行結果 実行結果

 損失がどんな状態をグラフにプロットしてみましょう。

plt.plot(losses)

損失のプロット

 実行結果を以下に示します。

実行結果 実行結果

 前回と同様なグラフとなりました。最後の学習結果を基に、サイン波の推測ができているかも確認してみましょう。

output = output.reshape(len(output)).detach()
sample_data, _ = make_data(num_div, cycles)
plt.plot(range(24, 200), output)
plt.plot(sample_data)
plt.grid()

サイン波と推測結果のプロット

 実行結果を以下に示します。

実行結果 実行結果

 こちらに関しても前回と同様な結果となりました。

MyRNNCellクラス

 ここでRNNクラスを置き換えるMyRNNクラスができたので、ここで終わってもよいのですが、PyTorchが提供するRNNCellクラスを使っている点にはちょっと不満もあります。そこで、最後にRNNCellクラスを置き換えるMyRNNCellクラスも定義してみましょう。

前のページへ 1|2       

Copyright© Digital Advantage Corp. All Rights Reserved.

RSSについて

アイティメディアIDについて

メールマガジン登録

@ITのメールマガジンは、 もちろん、すべて無料です。ぜひメールマガジンをご購読ください。