精度向上により、近年利用が広まっている「ニューラル機械翻訳」。その仕組みを、自分で動かしながら学んでみましょう。第2回はユースケースごとに「JoeyNMT」をカスタマイズする方法や、Discordのチャットbotに組み込む方法を解説します。
この記事は会員限定です。会員登録(無料)すると全てご覧いただけます。
ハイデルベルク大学の博士課程に在籍しながら、八楽という会社で「ヤラクゼン」の開発に携わっている太田です。ヤラクゼンは、AI翻訳から翻訳文の編集、ドキュメントの共有、翻訳会社への発注までを1つにする翻訳プラットフォームです。
第1回は、機械翻訳フレームワーク「JoeyNMT」の概要、インストール方法、モデルを訓練する方法を紹介しました。今回は、JoeyNMTをカスタマイズする方法を具体的なユースケースを交えながら紹介します。
JoeyNMTは、他のフレームワークに比べてコードの行数で9〜10分の1、ファイル数でも4〜5分の1(※1)というミニマルな実装が特長で、核となるモジュールはしっかり入っています。機械学習分野における多くのベンチマークでSOTA(State-of-the-Art)に匹敵するベンチマークスコアを出しています。またデバッグ時にstack traceをたどる際、フラットなディレクトリ構造のおかげで迷わずにエラー箇所を探し当てられるのもメリットです。
※1:OpenNMT-py、XNMTとの比較です。詳細は「Joey NMT: A Minimalist NMT Toolkit for Novices」を参照してください。
それでは、ユースケースごとにJoeyNMTをカスタマイズする方法を見ていきましょう。
JoeyNMTはデフォルトで「subword-nmt」「sentencepiece」という2つのサブワードトークナイザーに対応しています。では、別のトークナイザーを利用したい場合はどうすればよいでしょうか。
トークナイザーは「joeynmt/tokenizers.py」で定義できます。例として、「fastBPE」を新しく導入してみましょう。
fastBPEはsubword-nmtをc++で実装したライブラリです。「SubwordNMTTokenizer」クラスを継承することにします。
class FaseBPETokenizer(SubwordNMTTokenizer): def __init__(self, ...): try: # fastBPEライブラリをインポート import fastBPE except ImportError as e: logger.error(e) raise ImportError from e super().__init__(level, lowercase, normalize, [...], **kwargs) assert self.level == "bpe" # codes_path を取得 self.codes: Path = Path(kwargs["codes_path"]) assert self.codes.is_file(), f"codes file {self.codes} not found." # fastBPEオブジェクト self.bpe = fastBPE.fastBPE(self.codes) def __call__(self, raw_input: str, is_train: bool = False) -> List[str]: # fastBPE.apply() tokenized = self.bpe.apply([raw_input]) tokenized = tokenized[0].strip().split() # 系列の長さが指定の範囲内におさまっているか確認 if is_train and self._filter_by_length(len(tokenized)): return None return tokenized
これでfastBPEでのトークナイズができるようになりました。設定ファイルで「tokenizer_type: "fastbpe"」と選択できるようにするため「_build_tokenizer()」で「FaseBPETokenizer」を呼び出せるようにします。
def _build_tokenizer(cfg: Dict) -> BasicTokenizer: [...] if tokenizer_type == "sentencepiece": [...] elif tokenizer_type == "subword-nmt": [...] elif tokenizer_type == "fastbpe": assert "codes_path" in tokenizer_cfg tokenizer = FaseBPETokenizer( level=cfg["level"], lowercase=cfg.get("lowercase", False), normalize=cfg.get("normalize", False), max_length=cfg.get("max_length", -1), min_length=cfg.get("min_length", -1), **tokenizer_cfg, )
fastBPEにはcodesファイルが必要ですので「codes_path」が設定ファイルで指定されていることを確認しましょう。今回導入した「FaseBPETokenizer」オブジェクトを返すようにしています。
トークナイザーの「__call__()」は、データセットからインスタンスを取り出す際に呼び出されます。例えば「PlaintextDataset」では、「get_item()」内で呼び出されています。
def get_item(self, idx: int, lang: str, is_train: bool = None): [...] item = self.tokenizer[lang](line, is_train=is_train) return item
つまり、訓練、予測時の「for batch in data_iterator:」のイテレーションで「__getitem__()」がコールされるたびにトークナイズの関数も呼び出されることになります。これは、BPE dropoutを可能にするための実装です。もし、新しく導入するトークナイザーが重い計算を必要としたり、いつも決まった値を返したりするのであれば、データ読み込み時に呼び出される「pre_process()」でトークナイズすることを検討してください(「BaseTokenizer」にある「MosesTokenizer」を利用した事前分割の実装が参考になります)。
JoeyNMTは「torch.optim.lr_scheduler」に入っている「ReduceLROnPlateau」「StepLR」「ExponentialLR」の他、transformerでよく使われる「noamスケジューラー」を実装しています。別の学習率スケジューラーを使いたい場合はどうしたらよいでしょうか?
学習率スケジューラーは「joeynmt/builders.py」で定義できます。例として、Inverse Square Rootスケジュールを導入してみます。
class BaseScheduler: def step(self, step): """学習率を更新""" self._step = step + 1 rate = self._compute_rate() for p in self.optimizer.param_groups: p["lr"] = rate self._rate = rate def _compute_rate(self): raise NotImplementedError
「BaseScheduler」クラスに、そのステップでの学習率をオプティマイザのパラメーターに渡す部分が実装されています。学習率を計算する「_compute_rate()」関数をオーバーライドします。
Inverse Square Rootスケジュールは、ステップ数の二乗根に反比例するように学習率を減衰させます。加えて、warmupの期間は、学習率が線形に増加するようにし、warmupの終わりで与えられた学習率に到達するよう係数(decay_rate)を調節します。
class WarmupInverseSquareRootScheduler(BaseScheduler): def __init__( self, optimizer: torch.optim.Optimizer, peak_rate: float = 1.0e-3, warmup: int = 10000, min_rate: float = 1.0e-5, ): super().__init__(optimizer) self.warmup = warmup self.min_rate = min_rate self.peak_rate = peak_rate self.decay_rate = peak_rate * (warmup ** 0.5) def _compute_rate(self): if step < self.warmup: # 線形に増加 rate = self._step * self.peak_rate / self.warmup else: # 2乗のルートに反比例 rate = self.decay_rate * (self._step ** -0.5) return max(rate, self.min_rate)
今回導入したInverse Square Rootスケジューラーを設定ファイルから選択できるように「build_scheduler()」を変更します。
def build_scheduler(): [...] if scheduler_name == "plateau": [...] elif scheduler_name == "decaying": [...] elif scheduler_name == "exponential": [...] elif scheduler_name == "noam": [...] elif scheduler_name == "warmupinversesquareroot": scheduler = WarmupInverseSquareRootScheduler( optimizer=optimizer, peak_rate=config.get("learning_rate", 1.0e-3), min_rate=config.get("learning_rate_min", 1.0e-5), warmup=config.get("learning_rate_warmup", 10000), ) scheduler_step_at = "step"
訓練を途中で中断した際、その中断したところから再開できるよう、学習率の変数をチェックポイントに保存しています。スケジューラーで保存すべき変数が異なるため、スケジューラーごとに、どの変数を保存するのかを指定する必要があります。
Inverse Square Rootスケジューラーの場合、デフォルトで保存されるステップ数とそのステップ時の学習率に加えて「warmup」「decay_rate」「peak_rate」「min_rate」を保存します。
class WarmupInverseSquareRootScheduler(BaseScheduler): [...] def state_dict(self): super().state_dict() self._state_dict["warmup"] = self.warmup self._state_dict["peak_rate"] = self.peak_rate self._state_dict["decay_rate"] = self.decay_rate self._state_dict["min_rate"] = self.min_rate return self._state_dict def load_state_dict(self, state_dict): super().load_state_dict(state_dict) self.warmup = state_dict["warmup"] self.decay_rate = state_dict["decay_rate"] self.peak_rate = state_dict["peak_rate"] self.min_rate = state_dict["min_rate"]
機械翻訳では多くの場合、交差エントロピーが損失関数として使われており、JoeyNMTでもデフォルトになっています。損失関数をカスタマイズしたい場合、どうすればよいでしょうか?
損失関数は「jorynmt/loss.py」で定義できます。第3回で予定している音声翻訳で必要となる「CTC Loss」と呼ばれる損失関数を、少し先取りしてここで導入してみましょう。既存の「XentLoss」クラスを継承して新しいクラス「XentCTCLoss」を作り、PyTorchで実装されているCTC Lossを呼び出します。
CTC Lossを計算するには、blankを特殊なトークンとして扱う必要があり、そのblankのためのトークンIDを指定しなければなりません。新しくblankトークンを定義してもよいのですが、今回はBOSトークン「<s>」で代用することにします。
class XentCTCLoss(XentLoss): def __init__(self, pad_index: int, bos_index: int, smoothing: float = 0.0, zero_infinity: bool = True, ctc_weight: float = 0.3 ): super().__init__(pad_index=pad_index, smoothing=smoothing) self.bos_index = bos_index self.ctc_weight = ctc_weight self.ctc = nn.CTCLoss(blank=bos_index, reduction='sum')
「XentCTCLoss」では、すでにある交差エントロピーとCTCの重み付き和を返すようにします。
class XentCTCLoss(XentLoss): def forward(self, log_probs, **kwargs) -> Tuple[Tensor, Tensor, Tensor]: # CTC Loss の計算に必要な情報がkwargsに入っていることを確認 assert "trg" in kwargs assert "trg_length" in kwargs assert "src_mask" in kwargs assert "ctc_log_probs" in kwargs # 交差エントロピーを計算できるように変形 log_probs_flat, targets_flat = self._reshape(log_probs, kwargs["trg"]) # 交差エントロピーを計算 xent_loss = self.criterion(log_probs_flat, targets_flat) # CTC損失を計算 ctc_loss = self.ctc( kwargs["ctc_log_probs"].transpose(0, 1).contiguous(), targets=kwargs["trg"], # (seq_length, batch_size) input_lengths=kwargs["src_mask"].squeeze(1).sum(dim=1), target_lengths=kwargs["trg_length"] ) # 交差エントロピーとCTCの重み付き和を計算 total_loss = (1.0 - self.ctc_weight) * xent_loss + self.ctc_weight * ctc_loss assert total_loss.item() >= 0.0, "loss has to be non-negative." return total_loss, xent_loss, ctc_loss
損失関数は、モデルの「forward()」で呼ばれます。「joeynmt/model.py」の該当部分を変更し「XentCTCLoss」を呼び出せるようにします。
class Model(nn.Module): def forward(self, return_type: str = None, **kwargs): [...] # 通常のデコーダー出力の他、CTCのためのレイヤーからのデコーダー出力も取得 out, ctc_out = self._encode_decode(**kwargs) # デコーダー出力に対し、log_softmax(各トークンの確率)を計算 log_probs = F.log_softmax(out, dim=-1) # バッチごとに損失を計算 if isinstance(self.loss_function, XentCTCLoss): # CTCレイヤーからの出力についても、log_softmaxを計算 kwargs["ctc_log_probs"] = F.log_softmax(ctc_out, dim=-1) # XentCTCLossのforward()を呼び出す total_loss, nll_loss, ctc_loss = self.loss_function(log_probs, **kwargs) [...]
バックプロパゲーションに使われるのは重み付き和である「total_loss」だけですが、それぞれの損失関数の学習曲線をプロットするため、「nll_loss」「ctc_loss」も返すようにしています。
デコーダー(joeynmt/decoders.py)に、CTCLossの計算のためのレイヤーを追加しました。
class TransformerDecoder(Decoder): def __init__(self, ...): [...] self.ctc_output_layer = nn.Linear(encoder_output_size, vocab_size, bias=False) def forward(self, ...): [...] out = self.output_layer(x) ctc_output = self.ctc_output_layer(encoder_output) return out, x, att, None, ctc_output class Model(nn.Module): def _encode_decode(self, ...): [...] out, x, att, _, ctc_out = self._decode(...) return out, ctc_out
機械翻訳の出力結果でよくあるのが、繰り返しです。例えば、配布している英日モデルを用いたwmt20テストセットで、以下のような出力を確認しました。
入力:"He begged me, "grandma, let me stay, don't do this to me, don't send me back,"" Hernandez said.
出力:「おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん、おばあちゃん」
Copyright © ITmedia, Inc. All Rights Reserved.