PyTorchのRNNクラスとRNNCellクラスを再発明しよう:作って試そう! ディープラーニング工作室(1/2 ページ)
RNNクラスがどのような処理をしているのかを、自分だけのRNNクラスを定義しながら、見ていくことにします。
前回はRNNの概要と、PyTorchが提供するRNNクラスを用いて、サイン波の推測を行ってみました。今回は、RNNクラスを模した簡素なクラスを自分で作りながら、その動作を見ていくことにしましょう。また、PyTorchが提供するRNNクラスやRNNCellクラスはより高度で複雑な処理を行っているので、実際のコードは本稿のものとはかなり異なっていることには注意してください(PyTorchのRNNクラスやRNNCellクラスでは最終的にはPythonではなく、C++で実装されているコードが呼び出されているので、これはあくまでも多分そんな感じというお話です)
PyTorchなどのフレームワークが提供するクラスを模したクラスを自分で作ることを「車輪の再発明」と呼ぶことがあります。車輪の再発明は「既にあるものをまた発明する」ということで「無駄なこと」と考えられがちですが、自分で何かを理解するためには再発明をすることが役に立つこともあります。プロダクションコードでこうしたことをするのは確かに無駄なのですが、今回はRNNが内部でどんなことをしているのかを勉強する目的で再発明をしてみましょう。
自分で作るRNNクラス
前回は、PyTorchが提供するRNNクラスを利用して、次のようなニューラルネットワーククラスを記述しました。
class Net(torch.nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.rnn = torch.nn.RNN(input_size, hidden_size)
self.fc = torch.nn.Linear(hidden_size, 1)
def forward(self, x, hidden):
output, h = self.rnn(x, hidden)
output = self.fc(output[:, -1])
return output, h
今回、目指すのは上のプログラムにある「self.rnn = torch.nn.RNN(input_size, hidden_size)」という行を「self.rnn = MyRNN(input_size, hidden_size)」のように置き換えることです。
前回に見た通り、PyTorchが提供しているRNNクラスは、時系列データ(ここではサイン波を形成する25個の値)のシーケンス(つまり、複数の時系列データ)を一気に処理してくれます。
そこで今回は、RNNクラスの動作を解きほぐしながら、2つのクラスを定義することにします。サイン波を構成するデータセットの生成や、学習/評価に関連するコードは基本的には前回と変わりません。そこで、最初にデータセットを生成するコードを以下に示しておきます。なお、今回のコードはこのノートブックで公開しているので、必要に応じて参照してください。
import numpy as np
import matplotlib.pyplot as plt
import torch
from random import uniform
def make_data(num_div, cycles, offset=0):
step = 2 * np.pi / num_div
res0 = [np.sin(step * i + offset) for i in range(num_div * cycles + 1)]
res1 = [np.sin(step * i + offset) + uniform(-0.02, 0.02) for i in range(num_div * cycles + 1)]
return res0, res1
def make_train_data(num_div, cycles, num_batch, offset=0):
x, x_w_noise = make_data(num_div, cycles, offset)
count = len(x) - num_batch
data = [x_w_noise[idx:idx+num_batch] for idx in range(count)]
labels = [x[idx+num_batch] for idx in range(count)]
num_items = len(data)
train_data = torch.tensor(data, dtype=torch.float).reshape(num_items, num_batch, -1)
train_labels = torch.tensor(labels, dtype=torch.float).reshape(num_items, -1)
return train_data, train_labels
学習/評価に使うコードは後ほど、紹介します(が、既に述べた通り、前回のコードとほぼ同じです)。
ここで作成するMyRNNクラスでは、どんな処理をしなければならないかをざっくりと示します。
- 時系列データのシーケンスを、シーケンスごと(ひとまとまりの時系列データごと)にRNNで処理する
- ひとかたまりの時系列データについては、それを構成する個々のデータを順次ニューラルネットワークに入力し、次のサイン波の値を推測(できるように学習)
前回は1周期を100分割した、2周期分のデータを作成しました(計200個+最後に1個足していました)。
ひとかたまりの時系列データは25個の値で構成されていたので(バッチサイズが25)、シーケンスの数は「201−25=176」個です。
要するに、MyRNNクラスは25個の数値で構成される時系列データ(と学習時には対応する正解ラベル)を、176個受け取り、それらをRNNで処理しているということです。
また、隠れ状態のサイズとしては前回は「32」を指定していたことも思い出してください。これに合わせて、RNNクラスのインスタンスには今述べた176個のシーケンスと隠れ状態の初期値を渡してもいました(「self.rnn = torch.nn.RNN(input_size, hidden_size)」でinput_size=1、hidden_size=32)。
以上がおおよその前提となります。それでは、MyRNNクラスの定義を見ていきましょう。
MyRNNクラス
上で述べたように、PyTorchのRNNクラスは複雑な処理をひとまとめに実行してくれますが、その一方で、RNN層で行う「時系列データを1つだけ受け取って、計算を行い、その結果を出力する」というクラスもPyTorchでは提供されています。それが、RNNCellクラス(torch.nn.RNNCellクラス)です。
このクラスでは、RNNクラスと同様に、そのインスタンス生成時に「入力サイズ」と「隠れ状態のサイズ」の2つの引数を指定します(バイアスの有無、活性化関数の種類なども指定できますが、本稿ではデフォルトのままとします)。
これもRNNクラスと同様ですが、このクラスのインスタンスを呼び出すときには、1つの時系列データと隠れ状態の初期値を渡します。ただし、RNNクラスとは異なり、その戻り値は1つの時系列データを処理した結果、すなわち、現在の隠れ状態を表すテンソルのみとなります。
ここで隠れ状態について少し話をしておきましょう。
隠れ状態とは、RNNに入力される時系列データの一つ一つについて存在します。前回そして今回は時系列データは25個の数値で構成されるものとしています。隠れ状態のサイズも同様に、前回と今回は32としています。つまり、25個のデータのそれぞれについて、そのデータを処理した時点での隠れ状態を表す32個のデータ、つまり、合計で25×32というサイズのテンソルが存在するということです(RNNの手法によっては、「隠れ状態のサイズの3倍、4倍」など、より多くのデータで隠れ状態を表現することもあります)。これらを「時刻t-1、t、……における隠れ状態」などと表現することもあるでしょう。
そして、この隠れ状態を、以降のデータを基に計算する際に関与させることで(隠れ状態に対して用いる重みやバイアスも用意されています)、時系列データに含まれるある値と以降の値とが関連を持つ(影響を受ける)ようになっています。
ここで少し気になるのは、上でも述べた通り、RNNクラスではそのインスタンスを呼び出すと、2つの値が戻されていたことです。例えば、「output, hidden = net(X_train, hidden)」のようにNetクラスのインスタンスを介して、RNNクラスのインスタンスを呼び出したとしましょう。このとき、hiddenには、全ての時系列データ(今回ならば176個)の処理を完了した時点での隠れ状態が返されます(RNNCellクラスのインスタンスが戻す値と同様)。対して、outputには全ての時系列データを処理するごとに得られた隠れ状態が全て含まれています。
RNNCellクラスでは、一度に1つの時系列データのみを処理するので、得られる隠れ状態は1つだけですが、RNNクラスは幾つもの時系列データ(シーケンス)を処理するので、その数だけ隠れ状態があり、それらをひとまとめにしたものが最終的な出力となっているということです。
逆にいうと、RNNCellクラスを使って、全ての時系列データを処理するときには、戻り値である隠れ状態を、保存しておいて、最後にそれを返すようにする必要があるということです。
というわけで、RNNCellクラスを使用したMyRNNクラスのコードを以下に示します。
class MyRNN(torch.nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.rnncell = torch.nn.RNNCell(input_size, hidden_size)
def forward(self, x, hidden):
count = len(x) # sequence length
output = torch.Tensor()
for idx in range(count):
hidden = self.rnncell(x[idx], hidden)
output = torch.cat((output, hidden))
output = output.reshape(len(x), -1, self.hidden_size)
return output, hidden
__init__メソッドでは、入力サイズ(もちろん、これは今回は1つの数値だけなので、1に決まっていますが)と隠れ状態のサイズをパラメーターに受け取り、それをインスタンス変数に保存した後(ただし、self.input_sizeは未使用)、それらを指定して、RNNCellクラスのインスタンスを生成し、それをインスタンス変数self.rnncellに代入しています。
forwardメソッドでは、時系列データが何個あるのかを数えて(ここでは176個になります)、その数だけfor文でループ処理を行っています。ループの内部では、パラメーターxに受け取った全ての時系列データから個々の時系列データを取り出して、隠れ状態の初期値(パラメーターhiddenに受け取るもの)と共にRNNCellクラスのインスタンスに渡して、計算を行っています。その戻りが1つの時系列データを処理し終わった時点での隠れ状態となります。
このようにして返された隠れ状態は、torch.catメソッドで、全ての隠れ状態を保存するためのテンソルであるoutputへ追加していきます。これにより、今回は全ての時系列データの処理が完了した際には「176×25×32」というサイズのテンソルができあがります(最後に、reshapeメソッドでそうなるようにテンソルの形状も変更しています)。
では、このMyRNNクラスを使って、前回のNetクラスを書き直してみましょう。
class Net(torch.nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super().__init__()
self.rnn = MyRNN(input_size, hidden_size)
self.fc = torch.nn.Linear(hidden_size, output_size)
def forward(self, x, hidden):
output, h = self.rnn(x, hidden)
output = self.fc(output[:, -1])
return output, h
といっても、コードは前回のものとほぼ同じで、PyTorchが提供するRNNクラスの代わりにMyRNNクラスを使っているだけです。
Copyright© Digital Advantage Corp. All Rights Reserved.