Tacotron2を調べてみた3
前回はDecoderの概要部分でしたので、次は中身を見ていきたいと思います。
2.6 decode
クラスDecoderのinference中の次の文の中身を見てみます。
mel_output, gate_output, alignment = self.decode(decoder_input)
この文は、クラスDecoderのメソッドdecodeを呼び出しています。メソッドdecodeを次に示します。decoder_inputは前回推定したmel spectrogramをprenetに通したものでで大きさ(1,256)のテンソルです。
def decode(self, decoder_input):
""" Decoder step using stored states, attention and memory
PARAMS
------
decoder_input: previous mel output
RETURNS
-------
mel_output:
gate_output: gate output energies
attention_weights:
"""
cell_input = torch.cat((decoder_input, self.attention_context), -1)
self.attention_hidden, self.attention_cell = self.attention_rnn(
cell_input, (self.attention_hidden, self.attention_cell))
self.attention_hidden = F.dropout(
self.attention_hidden, self.p_attention_dropout, self.training)
attention_weights_cat = torch.cat(
(self.attention_weights.unsqueeze(1),
self.attention_weights_cum.unsqueeze(1)), dim=1)
self.attention_context, self.attention_weights = self.attention_layer(
self.attention_hidden, self.memory, self.processed_memory,
attention_weights_cat, self.mask)
self.attention_weights_cum += self.attention_weights
decoder_input = torch.cat(
(self.attention_hidden, self.attention_context), -1)
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
decoder_input, (self.decoder_hidden, self.decoder_cell))
self.decoder_hidden = F.dropout(
self.decoder_hidden, self.p_decoder_dropout, self.training)
decoder_hidden_attention_context = torch.cat(
(self.decoder_hidden, self.attention_context), dim=1)
decoder_output = self.linear_projection(
decoder_hidden_attention_context)
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
return decoder_output, gate_prediction, self.attention_weights
ソースコードだと全体像がわかりずらいので図にしました。
decodeの入力
・decoder_input 前回のmel spectrogramをprenetに通したもの
・encoder_output encoderの出力
decoderの出力
・attention_weights 位置検出結果、出力は状態確認(表示)で使用しているのみ
・gate_prediction 終了判断に使用
・decoder_output mel spectrogram
構造を見てわかるように、attention_hidden、processed_memory、attention_weights_catの3信号をQueryとして、encoder出力をKey及びValueとしてAttentionを行っています。
ここから先は推測まじりですが、
終了を最後のgate_layerで判定できているので、attention_layerの出力attention_contextにどこまで進んだかを示す信号が含まれていると思われます。
encoderの出力は文字単位なので、attention_contextも基本的には文字単位のデータなのですが、このデータの中に1文字の間に変化する音の情報も含まれていると思われます。この後の処理のどこかで、このデータから今の瞬間の音に関する情報を抽出してmel spectrogramを生成しているのでしょう。ただし、attentionの後ろに複雑なネットワーク構造は無いので、1文字中ではあまり複雑な変化に対応できないかも知れません。
2.7 Location Sensitive Attenation
こちらもわかりやすよう図にしました。
get_alignment_energiesで現在が各文字の位置にある可能性を推測、softmaxで正規化しています。その後、現在の位置に相当する文字に相当する情報をattention_contextとしてエンコーダの出力から抽出します。
2.6 decodeの項に戻り、さらにいくつか処理をして現在のmel_spectrogramと、終了を示すフラグgate_predictionを推測します。
2.8 Post Net
Post Netは5層の1次元Convolution Network(畳み込みニューラルネットワーク)です。
mel_outputs_postnet = self.postnet(mel_outputs)
mel_outputs_postnet mel_outputs_postnet = mel_outputs + mel_outputs_postnet
Postnetの仕様は次の通りです。
・入力 80
・フィルタ数 512 フィルタ長 5
・出力80
mel_outputsの大きさは(1,80,n)、nは時間方向のspectrum数です。
時間方向に畳み込みをしてmel spectrogramの時間方向のなめらかさ?改善のための補正量を推定し、Post Netを通る前のmel_outputsに加算して補正をかけます。
個人的には補正前後で大きな違いは感じられませんでした。
2.9 Tactron2.inferenceの出力
クラスTactoron2のinferenceからparse_output経由で出力されます。
outputs = self.parse_output(
[mel_outputs, mel_outputs_postnet, gate_outputs, alignments])
return outputs
・mel_outputs、mel_outputs_postnet
mel spectrogramです。テンソルのサイズは(1,80,n)でnは時間軸方向のデータ数です。
・gate_outputs
終了フラグです。サイズは(1,n,1)です。その時々のgateの値を保管したものです。基本的に、音声合成では使いません。
・alignments
位置あわせを示すテンソルです。サイズは(1,n,文字数)でnは時間軸方向のデータ数です。
3. 補足 Tacotron2のinference実行
Tacotron2のプロジェクトフォルダにあるJupitorノートブックでは次の流れで呼び出しています。一応補足まで。
①from train import load_model
load_modelをインポート
②model = load_model(hparams)
load_model呼び出して、modelを生成。
③train.pyのload_model
def load_model(hparams):
model = Tacotron2(hparams).cuda()
・・・
return model
ここで、クラスTacotron2を生成してリターンしているので、②のmodelはTacotron2になる。
④inference呼び出し
mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
model=Tacotron2なので、Tacotron2のメソッドinferenceを呼び出します。ここから今まで見てきた処理が始まります。
4.まとめ
NVIDIA実装をもとにTacotron2の構造を見てみました。次回はJetson Nanoで動かしてみます。