82 changes: 48 additions & 34 deletions nemo/collections/tts/models/magpietts.py
@@ -2777,6 +2777,17 @@ def infer_batch(
all_predictions = []
end_indices = {}

# Maintain list of decoder embedded inputs that we append to at the end of every iteration
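# (Embedding the prefix once here, then appending one newly embedded frame at
# the end of each loop iteration, avoids re-embedding the full token history
# on every decoding step.)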
audio_codes_embedded = self.embed_audio_tokens(audio_codes_input)
if context_tensors.additional_decoder_input is not None: # additional_decoder_input is the context
_audio_codes_embedded = torch.cat(
[context_tensors.additional_decoder_input, audio_codes_embedded], dim=1
)
_audio_codes_mask = torch.cat([context_tensors.additional_decoder_mask, audio_codes_mask], dim=1)
else:
_audio_codes_embedded = audio_codes_embedded
_audio_codes_mask = audio_codes_mask

if use_cfg:
dummy_cond, dummy_cond_mask, dummy_additional_decoder_input, dummy_addition_dec_mask, _ = (
self.prepare_dummy_cond_for_cfg(
Expand All @@ -2786,6 +2797,29 @@ def infer_batch(
context_tensors.additional_decoder_mask,
)
)
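# With CFG the effective batch is doubled: rows [:batch_size] carry the real
# conditioning and rows [batch_size:] the dummy (unconditioned) stream, so one
# forward pass produces both sets of logits. (Downstream they are typically
# combined as uncond + scale * (cond - uncond); that combination happens in
# the forward/sampling path, not in this hunk.)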
batch_size = audio_codes_embedded.size(0)
if isinstance(context_tensors.cond, list): # multi_encoder setup
cfg_cond = [
torch.cat([cond_item, dummy_cond_item], dim=0)
for cond_item, dummy_cond_item in zip(context_tensors.cond, dummy_cond)
]
cfg_cond_mask = [
torch.cat([cond_mask_item, dummy_cond_mask_item], dim=0)
for cond_mask_item, dummy_cond_mask_item in zip(context_tensors.cond_mask, dummy_cond_mask)
]
else:
cfg_cond = torch.cat([context_tensors.cond, dummy_cond], dim=0)
cfg_cond_mask = torch.cat([context_tensors.cond_mask, dummy_cond_mask], dim=0)
# Maintain list of decoder cfg inputs that we append to at the end of every iteration
cfg_audio_codes_embedded = torch.cat([_audio_codes_embedded, _audio_codes_embedded], dim=0)
cfg_audio_codes_mask = torch.cat([_audio_codes_mask, _audio_codes_mask], dim=0)
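# The two halves start as identical copies; when the conditioned half carries
# a context prefix, the block below overwrites that prefix in the
# unconditioned half with the dummy decoder input and mask.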
if dummy_additional_decoder_input is not None:
cfg_audio_codes_embedded[batch_size:, : dummy_additional_decoder_input.size(1)] = (
dummy_additional_decoder_input
)
cfg_audio_codes_mask[batch_size:, : dummy_additional_decoder_input.size(1)] = (
dummy_addition_dec_mask
)

cross_attention_scores_all_timesteps = []
all_heads_cross_attn_scores_all_timesteps = []
@@ -2800,17 +2834,8 @@ def infer_batch(
for idx in range(max_decoder_steps // self.frame_stacking_factor):
if idx == 1:
time_to_first_prediction = time.time() - start_time
if idx % 20 == 0:
print(f"Decoding timestep {idx}")
audio_codes_embedded = self.embed_audio_tokens(audio_codes_input)
if context_tensors.additional_decoder_input is not None:
_audio_codes_embedded = torch.cat(
[context_tensors.additional_decoder_input, audio_codes_embedded], dim=1
)
_audio_codes_mask = torch.cat([context_tensors.additional_decoder_mask, audio_codes_mask], dim=1)
else:
_audio_codes_embedded = audio_codes_embedded
_audio_codes_mask = audio_codes_mask
if idx % 50 == 0:
logging.info(f"Decoding timestep {idx}")

if apply_prior_to_layers is not None:
attn_prior = [None for _ in range(self.decoder.n_layers)]
@@ -2823,29 +2848,6 @@ def infer_batch(
attn_prior = [attn_prior, None]

if use_cfg:
batch_size = audio_codes_embedded.size(0)
if isinstance(context_tensors.cond, list):
cfg_cond = [
torch.cat([cond_item, dummy_cond_item], dim=0)
for cond_item, dummy_cond_item in zip(context_tensors.cond, dummy_cond)
]
cfg_cond_mask = [
torch.cat([cond_mask_item, dummy_cond_mask_item], dim=0)
for cond_mask_item, dummy_cond_mask_item in zip(context_tensors.cond_mask, dummy_cond_mask)
]
else:
cfg_cond = torch.cat([context_tensors.cond, dummy_cond], dim=0)
cfg_cond_mask = torch.cat([context_tensors.cond_mask, dummy_cond_mask], dim=0)
cfg_audio_codes_embedded = torch.cat([_audio_codes_embedded, _audio_codes_embedded], dim=0)
cfg_audio_codes_mask = torch.cat([_audio_codes_mask, _audio_codes_mask], dim=0)
if dummy_additional_decoder_input is not None:
cfg_audio_codes_embedded[batch_size:, : dummy_additional_decoder_input.size(1)] = (
dummy_additional_decoder_input
)
cfg_audio_codes_mask[batch_size:, : dummy_additional_decoder_input.size(1)] = (
dummy_addition_dec_mask
)

combined_logits, attn_probs, dec_out = self.forward(
dec_input_embedded=cfg_audio_codes_embedded,
dec_input_mask=cfg_audio_codes_mask,
@@ -2990,6 +2992,18 @@ def infer_batch(
# Codec must be at least 4 timesteps long to be decoded properly
print("All ends reached")
break

# If not all inputs have ended, embed for the next step and append it to the relevant tensors:
# _audio_codes_embedded, _audio_codes_mask
# and their CFG counterparts: cfg_audio_codes_embedded, cfg_audio_codes_mask
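# Embedding only the newest frame keeps the per-step cost constant; the CFG
# copies must grow in lockstep so both halves of the doubled batch see the
# same newly generated frame.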
audio_codes_embedded = self.embed_audio_tokens(audio_codes_input[:, :, -1].unsqueeze(-1))
_audio_codes_embedded = torch.cat((_audio_codes_embedded, audio_codes_embedded), dim=1)
_audio_codes_mask = torch.cat((_audio_codes_mask, audio_codes_mask[:, -1].unsqueeze(-1)), dim=1)
if use_cfg:
double_embeds = torch.cat([audio_codes_embedded, audio_codes_embedded], dim=0)
cfg_audio_codes_embedded = torch.cat((cfg_audio_codes_embedded, double_embeds), dim=1)
double_mask = torch.cat([_audio_codes_mask[:, -1], _audio_codes_mask[:, -1]], dim=0)
cfg_audio_codes_mask = torch.cat((cfg_audio_codes_mask, double_mask.unsqueeze(-1)), dim=1)
tts_generation_time = time.time() - start_time
tts_generation_time_per_frame = tts_generation_time / (len(all_predictions) * self.frame_stacking_factor)
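For reference, the incremental-embedding pattern the loop relies on can be sketched in isolation. A minimal toy version (stand-in embedding layer and stubbed sampling; not the NeMo API):

import torch
import torch.nn as nn

emb = nn.Embedding(128, 16)             # stand-in for self.embed_audio_tokens
tokens = torch.randint(0, 128, (2, 1))  # [batch=2, T=1] starting token
embedded = emb(tokens)                  # embed the prefix once, before the loop

for step in range(5):
    # ...decoder forward on `embedded` and real sampling would go here...
    new_token = torch.randint(0, 128, (2, 1))
    tokens = torch.cat((tokens, new_token), dim=1)
    # embed only the newest token and append, instead of re-embedding `tokens`
    embedded = torch.cat((embedded, emb(new_token)), dim=1)

assert embedded.shape == (2, 6, 16)  # prefix (1) + 5 appended steps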
