PyTorchにはRNN機能を提供するクラスが幾つか用意されています。今回はその中でも基本的なRNNクラス(torch.nn.RNNクラス)を使用します。このクラスと、既におなじみのLinearクラスを組み合わせることにしましょう。
実際のクラス定義のコードは次の通りです。
class Net(torch.nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.rnn = torch.nn.RNN(input_size, hidden_size)
self.fc = torch.nn.Linear(hidden_size, 1)
def forward(self, x, hidden):
output, h = self.rnn(x, hidden)
output = self.fc(output[:, -1])
return output, h
__init__メソッドでは、インスタンス変数self.rnnにRNNクラスのインスタンスを代入しています。RNNクラスのインスタンス生成時には「torch.nn.RNN(input_size, hidden_size)」のように「入力のサイズ」と「隠れ状態のサイズ」の2つの引数を指定しています(これらは__init__メソッドのパラメーターに渡されるようにしてあります)。
ここで覚えておいてほしいのは、先ほど作成した訓練データではバッチサイズ(num_batch)として25を指定しましたが、これはここでいう「入力のサイズ」とは違うということです。このニューラルネットワークでは「一度に1つのデータを入力して、それを処理」しますが、バッチサイズの25というのは時系列のデータとしてそれらを連続的に入力するということです。先ほどの図に合わせると、次のようになります。
25個の時系列データが順次RNNに入力されると、最初に入力されたデータ(x1)はそれ以降のデータの処理に何らかの影響を与え、2番目に入力されたデータは3番目以降のデータの処理に何らかの影響を与えていくことが分かります。__init__メソッドのhidden_sizeパラメーターに値を指定して作成する隠れ状態は、次の層への出力となったり、次のデータの処理時にRNN層への入力として使われたりします。
全結合層についてはこれまでと同様ですが、PyTorchのRNNクラスと結合させる場合、その入力の数は、RNN層の隠れ状態のサイズに指定した値と同じにします。そのため、インスタンス変数self.fcに代入するLinearクラスのインスタンス生成では「torch.nn.Linear(hidden_size, 1)」としています。出力層のノード数が1なのは、これは時系列データを受け取り、そこから次のサイン波の値を1つだけ推測するからです。
なお、PyTorchのRNN層には活性化関数tanhを使った活性化処理がデフォルトで組み込まれているので、このクラス定義では明示的には活性化関数は登場しません。
forwardメソッドでは、まずパラメーターがselfを含めて3つ(実際の呼び出し時には2つを指定)ある点に注意してください。これは隠れ状態の初期値を、「net(入力, 隠れ状態の初期値)」のようにして、このクラスのインスタンスを呼び出す側が用意する必要があるということです。また、このメソッド内で行っている「self.rnn(x, hidden)」呼び出しでは受け取った入力と隠れ状態をRNNクラスのインスタンスに渡しています。その戻り値が2つある点にも注意が必要です。一つは次の層への出力に、もう一つは現在の入力をした時点での隠れ状態を表すオブジェクトになります。
また、全結合層には「self.fc(output[:, -1])」のようにして、計算結果の一部のみを渡しています。詳しくは次回以降に見ますが、これは25個の時系列データごとに出力される25個の計算結果(これは隠れ状態のサイズと同じサイズのテンソルになります)の中で最後のものだけを取り出す操作です。これを全結合層に送り込むと、最終的な計算結果、つまり「25個の時系列データを基にニューラルネットワークモデルが推測した次の値(を要素とするテンソル)」が得られるというわけです。
最後に、PyTorchのRNNクラスでは、RNNを何層にするかをインスタンス生成時に指定できます(名前付き引数num_layersで指定。デフォルト値は1です)。ここではデフォルト値を採用することにします。RNNを2層以上にする例は次回以降に紹介する予定です。
クラスの定義は以上です。次に学習をしてみましょう。
Copyright© Digital Advantage Corp. All Rights Reserved.