Commit c9cc855

bugfix: adapt new changes on codec inference and process_batch
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
1 parent d1dab37 commit c9cc855

File tree

1 file changed (+29, -12 lines)

nemo/collections/tts/models/magpietts.py

Lines changed: 29 additions & 12 deletions
@@ -18,8 +18,8 @@
 import time
 from dataclasses import dataclass, field, fields
 from functools import partial
-from typing import Any, Dict, List, Optional, Tuple, Union
 from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import soundfile as sf
@@ -1758,7 +1758,7 @@ def _prepare_audio_examples(
         self,
         logits: torch.Tensor,
         target_audio_codes: torch.Tensor,
-        audio_codes_lens_target: torch.Tensor,
+        audio_codes_lens: torch.Tensor,
         context_audio_codes: Optional[torch.Tensor] = None,
         context_audio_codes_lens: Optional[torch.Tensor] = None,
         max_examples: int = 3,
@@ -1769,7 +1769,7 @@ def _prepare_audio_examples(
         Args:
             logits: Model output logits to convert to predicted audio.
             target_audio_codes: Ground truth audio codes.
-            audio_codes_lens_target: Lengths of target audio codes.
+            audio_codes_lens: Lengths of target audio codes.
             context_audio_codes: Optional context audio codes for voice cloning.
             context_audio_codes_lens: Lengths of context audio codes.
             max_examples: Maximum number of examples to process.
@@ -1779,15 +1779,28 @@
                 each containing a list of numpy arrays (or None for context if unavailable).
         """
         with torch.no_grad():
-            # Decode predictions and targets
-            pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens_target)
-            pred_audio, pred_audio_lens = self.codes_to_audio(pred_audio_codes, audio_codes_lens_target)
-            target_audio, target_audio_lens = self.codes_to_audio(target_audio_codes, audio_codes_lens_target)
+            # Decode predictions: convert logits to codes, remove EOS token, then decode to audio
+            pred_audio_codes = self.logits_to_audio_codes(logits, audio_codes_lens)
+            pred_audio_codes, pred_audio_codes_lens = self.remove_eos_token(
+                codes=pred_audio_codes, codes_len=audio_codes_lens
+            )
+            pred_audio, pred_audio_lens, _ = self.codes_to_audio(pred_audio_codes, pred_audio_codes_lens)
+
+            # Decode targets: remove EOS token, then decode to audio
+            target_audio_codes, target_audio_codes_lens = self.remove_eos_token(
+                codes=target_audio_codes, codes_len=audio_codes_lens
+            )
+            target_audio, target_audio_lens, _ = self.codes_to_audio(target_audio_codes, target_audio_codes_lens)
 
             # Decode context audio if available (shape check ensures it's not a dummy tensor used in text context)
             context_audio, context_audio_lens = None, None
             if context_audio_codes is not None and context_audio_codes.shape[2] > 3:
-                context_audio, context_audio_lens = self.codes_to_audio(context_audio_codes, context_audio_codes_lens)
+                context_audio_codes, context_audio_codes_lens = self.remove_special_tokens(
+                    codes=context_audio_codes, codes_len=context_audio_codes_lens
+                )
+                context_audio, context_audio_lens, _ = self.codes_to_audio(
+                    context_audio_codes, context_audio_codes_lens
+                )
 
             pred_audios = []
             target_audios = []
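
The substantive change in this hunk: predicted and target codes are now passed through remove_eos_token (and context codes through remove_special_tokens) before codec decoding, and codes_to_audio now returns a third value that is discarded here. Below is a minimal sketch of what such an EOS-trimming step could look like, assuming codes shaped [batch, codebooks, time] with exactly one trailing EOS frame counted in codes_len; the helper name and logic are illustrative assumptions, not the NeMo implementation.

import torch

def trim_trailing_eos(codes: torch.Tensor, codes_len: torch.Tensor):
    # Drop each sequence's trailing EOS frame by shortening its length by one,
    # then crop the time axis to the new maximum length.
    trimmed_len = (codes_len - 1).clamp(min=0)
    max_len = int(trimmed_len.max().item())
    return codes[:, :, :max_len], trimmed_len

# Example: 2 sequences, 8 codebooks, lengths 10 and 7 (EOS frame included).
codes = torch.randint(0, 1024, (2, 8, 10))
codes_len = torch.tensor([10, 7])
trimmed_codes, trimmed_len = trim_trailing_eos(codes, codes_len)
print(trimmed_codes.shape, trimmed_len)  # torch.Size([2, 8, 9]) tensor([9, 6])
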
@@ -1857,8 +1870,12 @@ def _log_media_to_wandb_and_tb(self, media_data: Dict, global_step: int) -> None
                     audio_list.append(
                         wandb.Audio(context_audio_np, sample_rate=self.output_sample_rate, caption="context")
                     )
-                audio_list.append(wandb.Audio(pred_audio_np, sample_rate=self.output_sample_rate, caption="prediction"))
-                audio_list.append(wandb.Audio(target_audio_np, sample_rate=self.output_sample_rate, caption="target"))
+                audio_list.append(
+                    wandb.Audio(pred_audio_np, sample_rate=self.output_sample_rate, caption="prediction")
+                )
+                audio_list.append(
+                    wandb.Audio(target_audio_np, sample_rate=self.output_sample_rate, caption="target")
+                )
                 wandb_log_dict[f"Audio:{dataset_prefix}/Example_{idx}"] = audio_list
 
             if is_tb:
@@ -2702,7 +2719,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
             batch_idx: Batch index
             dataloader_idx: Index of the dataloader (0 for single dataloader)
         """
-        batch_output = self.process_batch(batch, mode="val")
+        batch_output = self.process_batch(batch)
         # self.process_batch returns a dict. We currently only log "logits" which come from the parallel prediction
         # head. If we use local_transformer, then the local_transformer returns "local_transformer_logits"
 
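With the mode argument dropped from process_batch, the validation step relies only on the returned dict; per the in-code comment above, logged logits come either from the parallel prediction head ("logits") or, when a local transformer is used, from "local_transformer_logits". A hypothetical helper sketching that key selection follows; the function name and fallback logic are assumptions, only the two key names come from the comment.

from typing import Any, Dict

import torch

def select_logits_for_logging(batch_output: Dict[str, Any]) -> torch.Tensor:
    # Prefer the local-transformer head if it produced logits; otherwise fall
    # back to the parallel prediction head's logits.
    if "local_transformer_logits" in batch_output:
        return batch_output["local_transformer_logits"]
    return batch_output["logits"]

# Example with a dummy output dict (shapes are illustrative).
dummy_output = {"logits": torch.zeros(2, 10, 1024)}
print(select_logits_for_logging(dummy_output).shape)  # torch.Size([2, 10, 1024])
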
@@ -2744,7 +2761,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
             audio_data = self._prepare_audio_examples(
                 logits=logits,
                 target_audio_codes=audio_codes_target,
-                audio_codes_lens_target=audio_codes_lens_target,
+                audio_codes_lens=audio_codes_lens_target,
                 context_audio_codes=context_audio_codes,
                 context_audio_codes_lens=context_audio_codes_lens,
                 max_examples=3,
