検索
連載

ディープラーニングで自動筆記 − Kerasを用いた文書生成(後編)ディープラーニング習得、次の一歩(3/3 ページ)

「文書生成」チャレンジの後編。ネットワークにLSTM、ライブラリにKeras+TensorFlowを採用し、さらに精度を改善していく。最後に、全然関係ない入力文章から、江戸川乱歩風文書が生成されるかを試す。

Share
Tweet
LINE
Hatena
前のページへ |       

江戸川乱歩に無関係な文章を入力にしてみる

 江戸川乱歩に無関係の、訓練データでない文章を入力にして、どの程度、江戸川乱歩風の文章が生成されるか試してみる。

 文書生成メイン処理(つまり先ほどのリスト8-2)を、以下のリストに取り換える。それ以外は、先ほどと同じリストの実行順で実行する。

df_input = csv.reader(open('input.csv', 'r'))
data_input = [ v for v in df_input]
mat_input = np.array(data_input)
mat_input = mat_input[:, 0]

print(mat_input.shape)
text_gen = ''                         # 生成テキスト
x_validation = np.zeros((1, maxlen))  # 入力データ
# 入力データx_validationに入力文の単語インデックスを設定
for i in range(0, len(mat_input)) :
  text_gen += mat_input[i]
  #x_validation 1文字シフト
  x_validation[0, 0:maxlen-1] = x_validation[0, 1:maxlen]
  if mat_input[i] in words :
    x_validation[0, maxlen-1] = word_indices[mat_input[i]]
  else :
    x_validation[0, maxlen-1] = word_indices['UNK']

print(text_gen)
print()

row = x_validation.shape[0]            # 評価データ数

print()

flag_1 = flag_2 = flag_3 = 0
# 応答文生成
for k in range (0, 400) :
  # 単語予測
  # 300
  ret_0 = model_classify_freq_0.predict(x_validation, batch_size=batch_size, verbose=0)     # 評価結果
  ret_0 = ret_0.reshape(row, n_sigmoid)
  flag_0 = ret_0[0, 0]
  # 最大値インデックス
  if flag_0 < 0.5 :                     # 300未満
    ret_1 = model_classify_freq_1.predict(x_validation, batch_size=batch_size, verbose=0)   # 評価結果
    ret_1 = ret_1.reshape(row, n_sigmoid)
    flag_1 = ret_1[0, 0]
    if flag_1 < 0.5 :                   # 28未満
      ret_2 = model_classify_freq_2.predict(x_validation, batch_size=batch_size, verbose=0) # 評価結果
      ret_2 = ret_2.reshape(row, n_sigmoid)
      flag_2 = ret_2[0, 0]
      if flag_2< 0.5 :                  # 10未満
        pred_freq = 0
        ret = model_words_0_10.predict(x_validation, batch_size=batch_size, verbose=0)
        ret_w0 = ret.argmax(1)[0]
        ret_word = indices_w0[0][ret_w0]
      else :                            # 10以上28未満
        pred_freq = 1
        ret = model_words_10_28.predict(x_validation, batch_size=batch_size, verbose=0)
        ret_w0 = ret.argmax(1)[0]
        ret_word = indices_w0[1][ret_w0]
    else :                              # 28以上
      ret_3 = model_classify_freq_3.predict(x_validation, batch_size=batch_size, verbose=0) # 評価結果
      ret_3 = ret_3.reshape(row, n_sigmoid)
      flag_3 = ret_3[0, 0]
      if flag_3 <0.5 :                  # 28以上100未満
        pred_freq = 2
        ret = model_words_28_100.predict(x_validation, batch_size=batch_size, verbose=0)
        ret_w0 = ret.argmax(1)[0]
        ret_word = indices_w0[2][ret_w0]
      else :                            # 100以上300未満
        pred_freq = 3
        ret = model_words_100_300.predict(x_validation, batch_size=batch_size, verbose=0)
        ret_w0 = ret.argmax(1)[0]
        ret_word = indices_w0[3][ret_w0]
  else :                                # 300以上
    ret_4 = model_classify_freq_4.predict(x_validation, batch_size=batch_size, verbose=0)   # 評価結果
    ret_4 = ret_4.reshape(row, n_sigmoid)
    flag_4 = ret_4[0, 0]
    if flag_4 <0.5 :                    # 300以上2000未満
      pred_freq = 4
      ret = model_words_300_2000.predict(x_validation, batch_size=batch_size, verbose=0)
      ret_w0 = ret.argmax(1)[0]
      ret_word = indices_w0[4][ret_w0]
    else :                              # 2000以上
      ret_5 = model_classify_freq_5.predict(x_validation, batch_size=batch_size, verbose=0) # 評価結果
      ret_5 = ret_5.reshape(row, n_sigmoid)
      flag_5 = ret_5[0, 0]
      if flag_5 < 0.5 :                 # 2000以上15000未満
        pred_freq = 5
        ret = model_words_2000_15000.predict(x_validation, batch_size=batch_size, verbose=0)
        ret_w0 = ret.argmax(1)[0]
        ret_word = indices_w0[5][ret_w0]
      else :                            # 15000以上
        pred_freq = 6
        ret = model_words_15000_400000.predict(x_validation, batch_size=batch_size, verbose=0)
        ret_w0 = ret.argmax(1)[0]
        ret_word = indices_w0[6][ret_w0]


  print(pred_freq, '\t', indices_word[ret_word])
  text_gen += indices_word[ret_word]    # 生成文字を追加

  x_validation[0, 0:maxlen-1] = x_validation[0, 1:maxlen]
  x_validation[0, maxlen-1] = ret_word  # 1文字シフト

print()
print(text_gen)

リスト8-3 文書生成メイン処理(出力次元削減対応)

 リスト8-2との違いは、最初の方の、入力文字列設定処理である。リスト8-2では、訓練データの中から初期値を選択していたが、リスト8-3では、入力ファイル「input.csv」で与えられる文字列を入力している。input.csvは入力文章をJUMAN++を用いてCSVファイルに整形したものである。

 今回は、川端康成「雪国」の冒頭を使ってみた。その結果は以下のとおりである。少し長めに(200語)、文書生成してみた。

 まず、「お題」となる入力文字列は以下のとおり。

「国境の長いトンネルを抜けると雪国であった。夜の底が白くなった。信号所に汽車が止まった。向側の座席から娘が立って来て、島村の前のガラス窓を落した。」

 これに続く文章の生成結果は以下のとおり。

「そして、中央に通り、見通しの末が記されていた。そこに転がっている。だが、黒い影が見た。そして、さして残してもどうことだ」そこで私は、これがどんな態度で、死んだ大阪の町から鳥異様な見えたでもあるというよりも、雪子さんのそっくりの全身、そっとあの興奮を取出し、つくさえも、まるで忘れてしまうようなのか、表題潤様子を睨みつけていたが、この世に飽きUNKいた。この玉の中には、一体なんのためにも、その勢いをUNKことなぞできない。だが、どちらかといえば、若い女性であったが、ああした行動をUNK終ると、UNKにUNKれた闇の中に、さいぜん老人は、真赤な糸《くだ》を出し、手足を綺麗にUNK、やっと気を沈めて、ソッと註文を示す。「掘り返して、ぼんやりとしたものだ」のであろう、右の方に、三つの廊下から、イライラし実験室にUNK時、その薄暗いボートをUNK、用意のような様子で、黒い覆面の黒布が、こみ上げてくるおかしさにつれ去る×UNKられていると、サッと大阪の腕を掴んで、口から飛びこんで行った。「あんたは、真にせまったと鉄一円札の束をUNK力で、それに属する女を、この大阪のポケットから上をたよりにUNKいた。そして、その一人はだれもない。服そうに髭にUNKと、また躊躇をUNKない場合が、小さく固っていた。彼らは三尺程の写真を取出して、写真で大丈夫UNK別にUNKを待つ始めた。くれるぞ、それは彼の耳の側へ駈けつけて置いて見たが、UNKも哀れな、一通の病気のステッキをUNK。つまりトランクの男によって千円の如何なる水泳ことの外には何の動作も、直ちにUNK行って、」

 なお、正解は当然ないので示していない。

 訓練データでは100発100中だったので、もう少し小説らしい文章になると期待していたが、汎化性能を犠牲にしているので、これはやむを得ないところかもしれない。

おわりに

 今回の記事には入っていないが、「UNK」の復元にも挑戦してみた。以下のリンク先の、筆者のプレゼンの中で言及している。

 今回の方式は、コンセプトが分かりやすく、当初は実装も容易と思われたので、読める文章が生成されることを期待していた。正解率を向上させれば使い物になると考えたが、任意の文字列を入力にしたときにはいまひとつであった。訓練データの規模をさらに大きくすれば、あるいは、より精度の高い文書生成が可能になるかもしれない。

 今回は、正解率を向上させるのに苦労した。恐らく訓練データのばらつきが大きかったためと想像している。分類数を少なくしたり、似たような傾向のデータをグルーピングしたりすると正解率を向上させることができたが、これは他の局面にも応用できると考えている。

「ディープラーニング習得、次の一歩」のインデックス

ディープラーニング習得、次の一歩

Copyright© Digital Advantage Corp. All Rights Reserved.

前のページへ |       
[an error occurred while processing this directive]
ページトップに戻る