VITS: Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech

Author: smly

smly

Sun, Apr 09, 2023

TTS (Text-to-speech) の代表的な手法である VITS の論文を読んだ。興味のあるところをメモする。

Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech

Several recent end-to-end text-to-speech (TTS) models enabling single-stage training and parallel sampling have been proposed, but their sample quality does not match that of two-stage TTS systems. In this work, we present a parallel end-to-end TTS method that generates more natural sounding audio than current two-stage models. Our method adopts variational inference augmented with normalizing flows and an adversarial training process, which improves the expressive power of generative modeling. We also propose a stochastic duration predictor to synthesize speech with diverse rhythms from input text. With the uncertainty modeling over latent variables and the stochastic duration predictor, our method expresses the natural one-to-many relationship in which a text input can be spoken in multiple ways with different pitches and rhythms. A subjective human evaluation (mean opinion score, or MOS) on the LJ Speech, a single speaker dataset, shows that our method outperforms the best publicly available TTS systems and achieves a MOS comparable to ground truth.

概要

VITS は Variational Inference with adversarial learning for end-to-end Text-to-Speech から文字を取ったもの。 HiFi-GAN V1 の Decoder と GAN、Glow-TTS から Text Encoderや音素継続時間推定、アライメントを組み合わせ改良している。HiFi-GAN V1 の Decoder 入力は潜在変数 zz に変更され、話者の多様性を考慮した形となっている。音素継続時間推定は Glow-TTS や FastSpeech の deterministic な方法ではなく、Flow を使った generative model としてデザインしている。

音素と音声の潜在変数を導入し、VAE により損失関数を介して接続して学習する。最終的な損失関数は以下となる。

Lvae=  Lrecon+Lkl+Ldur+Ladv(G)+Lfm(G)\begin{aligned} L_{vae} = &\; \color{#2A2}L_{recon} + L_{kl}\color{#000} + \color{#22A}L_{dur}\color{#000} \\ &+ \color{#A22}L_{adv}(G) + L_{fm}(G)\color{#000} \end{aligned}

緑は VAE の損失関数、青は Stochastic Duration Predictor の損失関数、赤は GAN training の損失関数。

訓練時と推論時の手続き

(a) 訓練時と (b) 推論時で手続きが異なる。 訓練時はスペクトログラム xlinx_{lin} を入力して Posterior Encoder によって潜在変数 zz を生成し、それを Decoder に入力する。 Decoder から生成された Raw waveform y^\hat{y} は Descriminator によって Adversarial training を行う。これによって潜在変数 zz を介してスペクトログラムから音声波形を復元できるように学習する。

次の列にある Flow は可逆なニューラルネットを使って変換を行う手法。 VITS では Flow を話者の多様性の正規化を行うために用い、これを Normalizing flow fθ(z)f_{\theta}(z) と呼ぶ。 この関数は話者の特徴を取り除く働きをする。 逆関数 fθ1(z)f^{-1}_{\theta}(z) は特定の話者の声に対応するように変換を行う。これは推論時に用いる。 逆関数を使うことで特定話者への音声変換に応用できる。

テキストは IPA音素列 ctextc_{text} に変換し、これを Text Encoder (prior encoder) の入力として使用する。OSS の phonemizer で変換して Glow-TTS の実装に従って blank token で区切り作成している。 Text Encoder から得られた潜在表現 htexth_{text} に対して Projection を行い、音素ごとに μθ\mu_{\theta}, σθ\sigma_{\theta} を計算する。

音素列と fθ(z)f_{\theta}(z) とアライメント AActext×z|c_{text}| \times |z| dimension の2値行列として表現され、 音素ごとの継続時間を dd と表現する。ddAA が monotonic なアライメントであるため d=jAi,jd = \sum_j A_{i, j} な離散値となる。

アライメント手法の Monotonic Alignment Search (MAS) と 音素継続時間の推定を行う Stochastic Duration Predictor については後述する。

Monotonic Alignment Search (MAS)

テキストから得られた音素ごとの表現とスペクトログラムから得られた音声表現のアライメントをとるには、Monotonic Alignment Search (MAS) を用いる。これは Glow-TTS の論文で提案されているため、Glow-TTS の論文を参照する。

上図ではアライメント AA を text representaiton hh と latent representation zz に対する行列として表現し、対数尤度の累積和 Qi,jQ_{i,j}-\infty で初期化した上で左上から Qi,jQ_{i,j} を埋めていき DP によって解く。最後に右下から左と上を比較していくことでアライメントを得ることができる。 DP の手続きは以下の通り。

vits 公式実装では MAS を Cython で実装している。

Stochastic Duration Predictor

訓練時には音声特徴と音素特徴のアライメントをとることで音素継続時間 dd を計算した。 推論時は音声側は生成しなくてはいけないので音素継続時間 dd を推定する必要がある。

先行研究の Glow-TTS は FastSpeech で提案された Duration Predictor を用いている。 FastSpeech では音素継続時間を推定するために Transformer を用いて MSE Loss によって学習している。 これらの deterministic な duration predictor では人が毎回異なる話す速度で発話する様子を表現することができない。そのため VITS の duration predictor は、音素が与えらたとき duration distribution に従ってサンプリングするように generative model としてデザインしている。

この音素継続時間は音声表現 fθ(z)f_{\theta}(z) の長さ(MAS でアライメントをとった、各音素に対応する連続した音声表現の個数)であるため整数値である。連続値ではない。 そのためノイズを加えて連続値を作り、音素継続時間を確率的に生成する。

Stochastic Duration Predictor では ctextc_{text} が与えられたときの dd の対数尤度を最大化したい。 音素継続時間 dd に対して同じ時間解像度かつ次数のランダム変数 u,νu, \nu を導入する。それぞれvariational dequantization と variational data augmentation のための変数。 近似事後分布 qθ(u,νd,ctext)q_{\theta} (u,\nu|d, c_{text}) から u,νu, \nu をサンプリングする。 このとき変分下限は、

logpθ(dctext)Eqθ(u,νd,ctext)[logpθ(du,νctext)qθ(u,νd,ctext)]\log p_{\theta} (d|c_{text}) \leq \mathbb{E}_{q_{\theta}(u,\nu|d, c_{text})} \left[ \log \frac{p_{\theta} (d - u, \nu|c_{text})}{q_{\theta} (u,\nu|d, c_{text})} \right]

学習ではこの変分下限を最大化するように学習する。これは負の変分下限を最小化することと等価。

Ldur=Eqθ(u,νd,ctext)[logpθ(du,νctext)qθ(u,νd,ctext)]L_{dur} = - \mathbb{E}_{q_{\theta}(u,\nu|d, c_{text})} \left[ \log \frac{p_{\theta} (d - u, \nu|c_{text})}{q_{\theta} (u,\nu|d, c_{text})} \right]

また入力の勾配が逆伝播しないように stop gradient operator を input condition に適用する。 これによって Stochastic Duration Predictor の学習は他のモジュールには影響しない。

訓練と推論の手続きについては Supplementary Material の B.2 を参照する。

訓練時は音素表現 htexth_{text}, 音素継続時間 dd, ノイズ ϵQ\epsilon_{Q} を入力として posterior encoder で u,νu, \nu を生成する。また du,νd-u, \nu を入力として flow (gθg_{\theta}) を通して ϵD\epsilon_{D} を生成する。推論時には flow の逆変換によって du,νd-u, \nu を生成する。

Stochastic Duration Predictor の実装

GitHub - jaywalnut310/vits: VITS: Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech

VITS: Conditional Variational Autoencoder with Adversarial Learning for End-to-End Text-to-Speech - GitHub - jaywalnut310/vits: VITS: Conditional Variational Autoencoder with Adversarial Learning f...

vits の公式実装では models.py の StochasticDurationPredictor として実装されている。 forward() メソッドに Text Encoder の出力 htexth_{text} に相当する x を入力として与える。reverse という引数があり、これが True のときは inference のために逆変換を行う。False のときは訓練時の forward を行うため音素継続時間 dd が必要。これは引数 w で与える。推論時には不要。

# https://github.com/jaywalnut310/vits/ よりコメントを追加して引用
  def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
    # Detach して stop gradient する
    x = torch.detach(x)
    x = self.pre(x)

    # 話者の embeddings が与えられている場合は detrach して加算
    if g is not None:
      g = torch.detach(g)
      x = x + self.cond(g)

    # DDSConv
    x = self.convs(x, x_mask)
    x = self.proj(x) * x_mask

訓練時には posterior encoder で u,νu, \nu を生成して Flow へ入力する。

    # 訓練時 (`reverse=False`)
    if not reverse:
      # 順方向の変換
      flows = self.flows
      # 音素継続時間 $d$ が与えられていないときはエラー
      assert w is not None

      # `logq` を計算するための変数
      logdet_tot_q = 0

      # condition encoder で $d$ を encode する
      h_w = self.post_pre(w)
      h_w = self.post_convs(h_w, x_mask)
      h_w = self.post_proj(h_w) * x_mask

      # posterior encoder の入力ノイズをサンプリングする
      e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
      # ノイズ $\epsilon_{Q}$ を初期値として posterior encoder への入力を `z_q` に更新していく
      z_q = e_q
      for flow in self.post_flows:
        z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
        logdet_tot_q += logdet_q

      # z1 は $\nu$ に相当する
      z_u, z1 = torch.split(z_q, [1, 1], 1) 
      u = torch.sigmoid(z_u) * x_mask

      # z0 は $d - u$ に相当する
      z0 = (w - u) * x_mask
      logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1,2])
      logq = torch.sum(-0.5 * (math.log(2*math.pi) + (e_q**2)) * x_mask, [1,2]) - logdet_tot_q

      logdet_tot = 0
      z0, logdet = self.log_flow(z0, x_mask)
      logdet_tot += logdet

      # $d - u, \nu$ を入力として flow ($g_{\theta}$) を通して $\epsilon_{D}$ を生成する
      z = torch.cat([z0, z1], 1)
      for flow in flows:
        z, logdet = flow(z, x_mask, g=x, reverse=reverse)
        logdet_tot = logdet_tot + logdet
      nll = torch.sum(0.5 * (math.log(2*math.pi) + (z**2)) * x_mask, [1,2]) - logdet_tot
      return nll + logq # [b]

reverse=True の推論時は以下のように posterior encoder を適用することなく Flow から続く。

    else:
      flows = list(reversed(self.flows))
      flows = flows[:-2] + [flows[-1]] # remove a useless vflow

      # $\epsilon_{D}$ をサンプリング
      z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale

      # conditional encoder を適用した z と $\epsilon_{D}$ を入力して $d-u, \nu$ を生成する
      for flow in flows:
        z = flow(z, x_mask, g=x, reverse=reverse)
      z0, z1 = torch.split(z, [1, 1], 1)
      logw = z0
      return logw

感想

Flow を使ったモデルにはじめて触れたので興味深く読んだ。Normalizing flow の有無によって評価が大きく改善していると言及されており、実際に検証して確認してみたい。