ニューラルネットワークの内部では何が行われている?作って試そう! ディープラーニング工作室(3/3 ページ)

» 2020年04月21日 05時00分 公開
[かわさきしんじDeep Insider編集部]
前のページへ 1|2|3       

重みとバイアスの更新

 重みやバイアスは学習の過程で更新されていきます。そこで、今回は学習を一度だけ行って、重みが更新されている様子を確認するだけとしましょう。実際のコードを以下に示します。

criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.03)

print('before learning')
print('weight')
print(net.fc1.weight)

print('learn once')
outputs = net(X_train)
loss = criterion(outputs, y_train)
loss.backward()
optimizer.step()

print('after learning')
print('weight')
print(net.fc1.weight)

学習を一度だけ行い、入力層の重みとバイアスがどう変化したかを確認するコード

 学習を行うコード自体は前回までに見てきたものをさらにシンプルにしただけのものです。Netクラスのインスタンスに訓練データを渡し、その出力から損失関数で損失を計算して、それを基に重みやバイアスを更新していると思ってください。

 実行結果を以下に示します。

実行結果 実行結果

 ほんの少しですが、学習前と学習後で重みとバイアスが更新されているのが分かるはずです。実際の学習では、こうした処理を何度も何度も繰り返して、訓練データを基にした推測結果が正解ラベルに近い値となるまで、重みとバイアスを更新していくことになります。これについては次回に詳しく見ていくことにします。

呼び出し可能オブジェクト

 本稿では「呼び出し可能オブジェクト」という言葉が出てきたので、これについて最後に少し説明しておきましょう。よく知らないし、知りたくもないという方は、読み飛ばしても構いません。ただし、PyTorchでは「ニューラルネットワーククラスのインスタンスは、関数やメソッドのように呼び出せる」ことだけは覚えておきましょう。

 Pythonには「呼び出し可能オブジェクト」というオブジェクトがあります。関数やメソッドはもちろん呼び出し可能オブジェクトの代表的な存在です。

 ですが、上で見たようにLinearクラスのインスタンス(あるいは、上で定義しているNetクラスのインスタンス)もまた呼び出し可能オブジェクトです。Pythonでは、あるオブジェクトが__call__特殊メソッドを持っている場合、それは関数のように呼び出すことができるようになっていて、Linearクラスのインスタンスはまさにこのメソッドを持っているのです。

 実際にあることを確認してみましょう。

linear = nn.Linear(4, 5)
print(linear.__call__)

Linearクラスのインスタンスに__call__メソッドがあるか

 これを実行すると次のようになります。

実行結果 実行結果

 この結果を見ると、Linearクラスのインスタンスの__call_メソッドは、Module.__call__メソッドに束縛されていることが分かります。詳細なことはともかくとして、これでLinearクラスのインスタンスには__call__メソッドがあり、関数のように呼び出せることが分かりました。

 __call__メソッドを持つオブジェクトは、インスタンス名にかっこ「()」を付加して、そこに0個以上の引数を渡すことで呼び出せます。呼び出し可能オブジェクトを使って呼び出しを行うと、Pythonの処理系により、対応する__call__メソッドが呼び出されて、そのパラメーターに呼び出し可能オブジェクトに渡した引数が渡されるようになっています(そのため、__call__メソッドのパラメーターリストと、オブジェクトに渡す引数リストが一致している必要があります)。

 先ほど「Linearクラスのインスタンスの__call_メソッドは、Module.__call__メソッドに束縛されている」と述べたことから分かる通り、PyTorchではニューラルネットワークの基底クラスであるModuleクラスで__call_メソッドが定義されています。そのため、ニューラルネットワーククラスのインスタンスは全て関数のように呼び出せます。

 ちなみに、Module.__call__メソッドの内部ではいろいろと処理を行っていますが、上述したように、ニューラルネットワーククラスのforwardメソッドは__call__メソッドから内部的に呼び出されるような構造になっています。そのため、上で定義したNetクラスのインスタンスをnetとすると、「net(入力データ)」のようにするだけで、そのforwardメソッドが呼び出されます。この構造を利用して、Netクラスでは、forwardメソッドでこのニューラルネットワークで行う計算処理を記述するようになっているのです。


「作って試そう! ディープラーニング工作室」のインデックス

作って試そう! ディープラーニング工作室

前のページへ 1|2|3       

Copyright© Digital Advantage Corp. All Rights Reserved.

アイティメディアからのお知らせ

スポンサーからのお知らせPR

注目のテーマ

Microsoft & Windows最前線2026
人に頼れない今こそ、本音で語るセキュリティ「モダナイズ」
4AI by @IT - AIを作り、動かし、守り、生かす
AI for エンジニアリング
ローコード/ノーコード セントラル by @IT - ITエンジニアがビジネスの中心で活躍する組織へ
Cloud Native Central by @IT - スケーラブルな能力を組織に
システム開発ノウハウ 【発注ナビ】PR
あなたにおすすめの記事PR

RSSについて

アイティメディアIDについて

メールマガジン登録

@ITのメールマガジンは、 もちろん、すべて無料です。ぜひメールマガジンをご購読ください。