1818import time
1919from dataclasses import dataclass , field , fields
2020from functools import partial
21- from typing import Any , Dict , List , Optional , Tuple , Union
2221from pathlib import Path
22+ from typing import Any , Dict , List , Optional , Tuple , Union
2323
2424import numpy as np
2525import soundfile as sf
@@ -1758,7 +1758,7 @@ def _prepare_audio_examples(
17581758 self ,
17591759 logits : torch .Tensor ,
17601760 target_audio_codes : torch .Tensor ,
1761- audio_codes_lens_target : torch .Tensor ,
1761+ audio_codes_lens : torch .Tensor ,
17621762 context_audio_codes : Optional [torch .Tensor ] = None ,
17631763 context_audio_codes_lens : Optional [torch .Tensor ] = None ,
17641764 max_examples : int = 3 ,
@@ -1769,7 +1769,7 @@ def _prepare_audio_examples(
17691769 Args:
17701770 logits: Model output logits to convert to predicted audio.
17711771 target_audio_codes: Ground truth audio codes.
1772- audio_codes_lens_target : Lengths of target audio codes.
1772+ audio_codes_lens : Lengths of target audio codes.
17731773 context_audio_codes: Optional context audio codes for voice cloning.
17741774 context_audio_codes_lens: Lengths of context audio codes.
17751775 max_examples: Maximum number of examples to process.
@@ -1779,15 +1779,28 @@ def _prepare_audio_examples(
17791779 each containing a list of numpy arrays (or None for context if unavailable).
17801780 """
17811781 with torch .no_grad ():
1782- # Decode predictions and targets
1783- pred_audio_codes = self .logits_to_audio_codes (logits , audio_codes_lens_target )
1784- pred_audio , pred_audio_lens = self .codes_to_audio (pred_audio_codes , audio_codes_lens_target )
1785- target_audio , target_audio_lens = self .codes_to_audio (target_audio_codes , audio_codes_lens_target )
1782+ # Decode predictions: convert logits to codes, remove EOS token, then decode to audio
1783+ pred_audio_codes = self .logits_to_audio_codes (logits , audio_codes_lens )
1784+ pred_audio_codes , pred_audio_codes_lens = self .remove_eos_token (
1785+ codes = pred_audio_codes , codes_len = audio_codes_lens
1786+ )
1787+ pred_audio , pred_audio_lens , _ = self .codes_to_audio (pred_audio_codes , pred_audio_codes_lens )
1788+
1789+ # Decode targets: remove EOS token, then decode to audio
1790+ target_audio_codes , target_audio_codes_lens = self .remove_eos_token (
1791+ codes = target_audio_codes , codes_len = audio_codes_lens
1792+ )
1793+ target_audio , target_audio_lens , _ = self .codes_to_audio (target_audio_codes , target_audio_codes_lens )
17861794
17871795 # Decode context audio if available (shape check ensures it's not a dummy tensor used in text context)
17881796 context_audio , context_audio_lens = None , None
17891797 if context_audio_codes is not None and context_audio_codes .shape [2 ] > 3 :
1790- context_audio , context_audio_lens = self .codes_to_audio (context_audio_codes , context_audio_codes_lens )
1798+ context_audio_codes , context_audio_codes_lens = self .remove_special_tokens (
1799+ codes = context_audio_codes , codes_len = context_audio_codes_lens
1800+ )
1801+ context_audio , context_audio_lens , _ = self .codes_to_audio (
1802+ context_audio_codes , context_audio_codes_lens
1803+ )
17911804
17921805 pred_audios = []
17931806 target_audios = []
@@ -1857,8 +1870,12 @@ def _log_media_to_wandb_and_tb(self, media_data: Dict, global_step: int) -> None
18571870 audio_list .append (
18581871 wandb .Audio (context_audio_np , sample_rate = self .output_sample_rate , caption = "context" )
18591872 )
1860- audio_list .append (wandb .Audio (pred_audio_np , sample_rate = self .output_sample_rate , caption = "prediction" ))
1861- audio_list .append (wandb .Audio (target_audio_np , sample_rate = self .output_sample_rate , caption = "target" ))
1873+ audio_list .append (
1874+ wandb .Audio (pred_audio_np , sample_rate = self .output_sample_rate , caption = "prediction" )
1875+ )
1876+ audio_list .append (
1877+ wandb .Audio (target_audio_np , sample_rate = self .output_sample_rate , caption = "target" )
1878+ )
18621879 wandb_log_dict [f"Audio:{ dataset_prefix } /Example_{ idx } " ] = audio_list
18631880
18641881 if is_tb :
@@ -2702,7 +2719,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
27022719 batch_idx: Batch index
27032720 dataloader_idx: Index of the dataloader (0 for single dataloader)
27042721 """
2705- batch_output = self .process_batch (batch , mode = "val" )
2722+ batch_output = self .process_batch (batch )
27062723 # self.process_batch returns a dict. We currently only log "logits" which come from the parallel prediction
27072724 # head. If we use local_transformer, then the local_transformer returns "local_transformer_logits"
27082725
@@ -2744,7 +2761,7 @@ def validation_step(self, batch, batch_idx, dataloader_idx=0):
27442761 audio_data = self ._prepare_audio_examples (
27452762 logits = logits ,
27462763 target_audio_codes = audio_codes_target ,
2747- audio_codes_lens_target = audio_codes_lens_target ,
2764+ audio_codes_lens = audio_codes_lens_target ,
27482765 context_audio_codes = context_audio_codes ,
27492766 context_audio_codes_lens = context_audio_codes_lens ,
27502767 max_examples = 3 ,
0 commit comments