@@ -89,6 +89,8 @@ class CudaGraphsMode(PrettyStrEnum):
    max_symbols: Optional[int]
    allow_cuda_graphs: bool
    biasing_multi_model: GPUBiasingMultiModelBase | None
    fusion_models: list[NGramGPULanguageModel]
    fusion_models_alpha: list[float]

    def force_cuda_graphs_mode(self, mode: Optional[str | CudaGraphsMode]):
        """
@@ -138,6 +140,64 @@ def disable_cuda_graphs(self) -> bool:
        self.reset_cuda_graphs_state()
        return True

    # fusion model-related methods
    @property
    def per_stream_biasing_enabled(self):
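        """Whether a per-stream biasing multi-model is attached."""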
        return self.biasing_multi_model is not None

    def _all_fusion_models(
        self, with_multi_model: bool = True
    ) -> list[NGramGPULanguageModel | GPUBiasingMultiModelBase]:
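        """All fusion models; the per-stream biasing multi-model, if enabled, is appended last."""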
        if with_multi_model and self.per_stream_biasing_enabled:
            return self.fusion_models + [self.biasing_multi_model]
        return self.fusion_models

    def _all_fusion_models_with_params(self, with_multi_model: bool = True) -> list[FusionModelWithParams]:
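        """
        Pair each fusion model with its fusion weight (alpha). The multi-model, if enabled,
        is appended last with alpha=None, since its scores are not scaled by the caller.
        """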
        models_with_params = [
            FusionModelWithParams(model=model, alpha=alpha, is_multi_model=False)
            for model, alpha in zip(self.fusion_models, self.fusion_models_alpha)
        ]
        if with_multi_model and self.per_stream_biasing_enabled:
            models_with_params.append(
                FusionModelWithParams(model=self.biasing_multi_model, alpha=None, is_multi_model=True)
            )
        return models_with_params

    def has_fusion_models(self, with_multi_model: bool = True) -> bool:
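        """Whether at least one fusion model (optionally counting the multi-model) is configured."""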
        if len(self.fusion_models) > 0:
            return True
        return with_multi_model and self.per_stream_biasing_enabled

    def _move_fusion_models_to_device(self, device: torch.device):
        """
        Move all fusion models to the given device.
        Needed because `self` is not an nn.Module instance but owns fusion models (nn.Module instances).
        """
        with torch.inference_mode(mode=False):
            # NB: we avoid inference mode since otherwise all model params/buffers would become
            # inference tensors, making further in-place manipulations impossible
            # (e.g., `remove_model` for the multi-model would throw errors)
            for fusion_model in self._all_fusion_models():
                fusion_model.to(device)  # each fusion model is an nn.Module, but `self` is not

    def advance_fusion_models(
        self, fusion_states_list: list[torch.Tensor], multi_biasing_ids: torch.Tensor | None, float_dtype: torch.dtype
    ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
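        """
        Advance all fusion models by one decoding step.

        Returns per-model fusion scores (cast to `float_dtype` and scaled by the model's alpha,
        except for the multi-model) and per-model state candidates, both ordered as in
        `_all_fusion_models_with_params()`. Typical call pattern (illustrative sketch, mirroring
        the `torch_impl` callers; `logits` of shape [B, V + 1] with blank last):

            scores_list, candidates_list = self.advance_fusion_models(states_list, ids, dtype)
            for scores in scores_list:
                logits[:, :-1] += scores  # fuse into non-blank logits
        """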
        fusion_states_candidates_list = []
        fusion_scores_list = []
        for fusion_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
            fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
                states=fusion_states_list[fusion_idx],
                **({"model_ids": multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
            )
            fusion_scores = fusion_scores.to(dtype=float_dtype)
            if not fusion_model_with_params.is_multi_model:
                fusion_scores *= fusion_model_with_params.alpha
            # save fusion scores and states candidates
            fusion_scores_list.append(fusion_scores)
            fusion_states_candidates_list.append(fusion_states_candidates)
        return fusion_scores_list, fusion_states_candidates_list

    @abstractmethod
    def torch_impl(
        self,
@@ -20,14 +20,10 @@
import torch.nn.functional as F
from omegaconf import DictConfig

from nemo.collections.asr.parts.context_biasing.biasing_multi_model import (
GPUBiasingMultiModel,
GPUBiasingMultiModelBase,
)
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import GPUBiasingMultiModel
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
BatchedLabelLoopingState,
FusionModelWithParams,
GreedyBatchedLabelLoopingComputerBase,
LabelLoopingStateItem,
SeparateGraphsLabelLooping,
@@ -265,33 +261,6 @@ def __init__(
self.cuda_graphs_allow_fallback = True
self.maybe_enable_cuda_graphs()

@property
def per_stream_biasing_enabled(self):
return self.biasing_multi_model is not None

def _all_fusion_models(
self, with_multi_model: bool = True
) -> list[NGramGPULanguageModel | GPUBiasingMultiModelBase]:
if with_multi_model and self.per_stream_biasing_enabled:
return self.fusion_models + [self.biasing_multi_model]
return self.fusion_models

def _all_fusion_models_with_params(self, with_multi_model: bool = True) -> list[FusionModelWithParams]:
models_with_params = [
FusionModelWithParams(model=model, alpha=alpha, is_multi_model=False)
for model, alpha in zip(self.fusion_models, self.fusion_models_alpha)
]
if with_multi_model and self.per_stream_biasing_enabled:
models_with_params.append(
FusionModelWithParams(model=self.biasing_multi_model, alpha=None, is_multi_model=True)
)
return models_with_params

def has_fusion_models(self, with_multi_model: bool = True) -> bool:
if len(self.fusion_models) > 0:
return True
return with_multi_model and self.per_stream_biasing_enabled

def reset_cuda_graphs_state(self):
"""Reset state to release memory (for CUDA graphs implementations)"""
self.state = None
@@ -306,18 +275,6 @@ def _get_frame_confidence(self, logits: torch.Tensor) -> Optional[torch.Tensor]:
else None
)

def _move_fusion_models_to_device(self, device: torch.device):
"""
Move all fusion models to device.
We need to do this since `self` is not nn.Module instance, but owns fusion models (nn.Module instances).
"""
with torch.inference_mode(mode=False):
# NB: we avoid inference mode since otherwise all model params/buffers will be inference tensors,
# which will make further inplace manipulations impossible
# (e.g., `remove_model` for multi-model will throw errors)
for fusion_model in self._all_fusion_models():
fusion_model.to(device) # fusion_models is nn.Module, but self is not; need to move manually

def torch_impl(
self,
encoder_output: torch.Tensor,
@@ -416,21 +373,14 @@ def torch_impl(
scores, labels = logits.max(-1)

if self.has_fusion_models():
fusion_scores_list, fusion_states_candidates_list = [], []
fusion_scores_list, fusion_states_candidates_list = self.advance_fusion_models(
fusion_states_list=fusion_states_list,
multi_biasing_ids=multi_biasing_ids,
float_dtype=float_dtype,
)
logits_with_fusion = logits.clone()
for fusion_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
states=fusion_states_list[fusion_idx],
**({"model_ids": multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
)
fusion_scores = fusion_scores.to(dtype=float_dtype)
if not fusion_model_with_params.is_multi_model:
fusion_scores *= fusion_model_with_params.alpha
# combine logits with fusion scores (blank excluded)
for fusion_scores in fusion_scores_list:
logits_with_fusion[:, :-1] += fusion_scores
# save fusion scores and states candidates
fusion_scores_list.append(fusion_scores)
fusion_states_candidates_list.append(fusion_states_candidates)

# get max scores and labels without blank
fusion_scores_max, fusion_labels_max = logits_with_fusion[:, :-1].max(dim=-1)
@@ -478,7 +428,6 @@ def torch_impl(
if self.has_fusion_models():
logits_with_fusion = logits.clone()
for fusion_scores in fusion_scores_list:
# combined scores with fusion model - without blank
logits_with_fusion[:, :-1] += fusion_scores
# get max scores and labels without blank
more_scores_w_fusion, more_labels_w_fusion = logits_with_fusion[:, :-1].max(dim=-1)
@@ -1132,18 +1081,19 @@ def _before_inner_loop_get_joint_output(self):
torch.max(logits, dim=-1, out=(self.state.scores, self.state.labels))

if self.has_fusion_models():
for fusion_model_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
fusion_scores_list, fusion_states_candidates_list = self.advance_fusion_models(
fusion_states_list=self.state.fusion_states_list,
multi_biasing_ids=self.state.multi_biasing_ids,
float_dtype=self.state.float_dtype,
)
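# copy results into the preallocated state buffers (kept static for CUDA graph replay)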
for fusion_model_idx in range(len(fusion_scores_list)):
# get fusion scores/states
fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
states=self.state.fusion_states_list[fusion_model_idx],
**({"model_ids": self.state.multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
self.state.fusion_states_candidates_list[fusion_model_idx].copy_(
fusion_states_candidates_list[fusion_model_idx]
)
if not fusion_model_with_params.is_multi_model:
fusion_scores *= fusion_model_with_params.alpha
self.state.fusion_states_candidates_list[fusion_model_idx].copy_(fusion_states_candidates)
self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores.to(dtype=self.state.float_dtype))
self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores_list[fusion_model_idx])
# update logits with fusion scores
logits[:, :-1] += fusion_scores
logits[:, :-1] += fusion_scores_list[fusion_model_idx]
# get labels (greedy) and scores from current logits, replace labels/scores with new
scores_w_fusion, labels_w_fusion = logits[:, :-1].max(dim=-1)
# preserve "blank" / "non-blank" category
@@ -20,14 +20,10 @@
import torch.nn.functional as F
from omegaconf import DictConfig, ListConfig

from nemo.collections.asr.parts.context_biasing.biasing_multi_model import (
GPUBiasingMultiModel,
GPUBiasingMultiModelBase,
)
from nemo.collections.asr.parts.context_biasing.biasing_multi_model import GPUBiasingMultiModel
from nemo.collections.asr.parts.submodules.ngram_lm import NGramGPULanguageModel
from nemo.collections.asr.parts.submodules.transducer_decoding.label_looping_base import (
BatchedLabelLoopingState,
FusionModelWithParams,
GreedyBatchedLabelLoopingComputerBase,
LabelLoopingStateItem,
SeparateGraphsLabelLooping,
@@ -292,33 +288,6 @@ def __init__(
self.cuda_graphs_allow_fallback = True
self.maybe_enable_cuda_graphs()

@property
def per_stream_biasing_enabled(self):
return self.biasing_multi_model is not None

def _all_fusion_models(
self, with_multi_model: bool = True
) -> list[NGramGPULanguageModel | GPUBiasingMultiModelBase]:
if with_multi_model and self.per_stream_biasing_enabled:
return self.fusion_models + [self.biasing_multi_model]
return self.fusion_models

def _all_fusion_models_with_params(self, with_multi_model: bool = True) -> list[FusionModelWithParams]:
models_with_params = [
FusionModelWithParams(model=model, alpha=alpha, is_multi_model=False)
for model, alpha in zip(self.fusion_models, self.fusion_models_alpha)
]
if with_multi_model and self.per_stream_biasing_enabled:
models_with_params.append(
FusionModelWithParams(model=self.biasing_multi_model, alpha=None, is_multi_model=True)
)
return models_with_params

def has_fusion_models(self, with_multi_model: bool = True) -> bool:
if len(self.fusion_models) > 0:
return True
return with_multi_model and self.per_stream_biasing_enabled

def reset_cuda_graphs_state(self):
"""Reset state to release memory (for CUDA graphs implementations)"""
self.state = None
@@ -347,18 +316,6 @@ def _get_frame_confidence(self, logits: torch.Tensor, num_durations: int) -> Opt
)
)

def _move_fusion_models_to_device(self, device: torch.device):
"""
Move all fusion models to device.
We need to do this since `self` is not nn.Module instance, but owns fusion models (nn.Module instances).
"""
with torch.inference_mode(mode=False):
# NB: we avoid inference mode since otherwise all model params/buffers will be inference tensors,
# which will make further inplace manipulations impossible
# (e.g., `remove_model` for multi-model will throw errors)
for fusion_model in self._all_fusion_models():
fusion_model.to(device) # fusion_models is nn.Module, but self is not; need to move manually

def torch_impl(
self,
encoder_output: torch.Tensor,
@@ -467,21 +424,14 @@ def torch_impl(
scores, labels = logits[:, :-num_durations].max(dim=-1)

if self.has_fusion_models():
fusion_scores_list, fusion_states_candidates_list = [], []
fusion_scores_list, fusion_states_candidates_list = self.advance_fusion_models(
fusion_states_list=fusion_states_list,
multi_biasing_ids=multi_biasing_ids,
float_dtype=float_dtype,
)
logits_with_fusion = logits.clone()
for fusion_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
states=fusion_states_list[fusion_idx],
**({"model_ids": multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
)
fusion_scores = fusion_scores.to(dtype=float_dtype)
if not fusion_model_with_params.is_multi_model:
fusion_scores *= fusion_model_with_params.alpha
# combine logits with fusion scores (duration outputs and blank excluded)
for fusion_scores in fusion_scores_list:
logits_with_fusion[:, : -num_durations - 1] += fusion_scores
# save fusion scores and states candidates
fusion_scores_list.append(fusion_scores)
fusion_states_candidates_list.append(fusion_states_candidates)

# get max scores and labels without blank
fusion_scores_max, fusion_labels_max = logits_with_fusion[:, : -num_durations - 1].max(dim=-1)
@@ -534,7 +484,6 @@ def torch_impl(
if self.has_fusion_models():
logits_with_fusion = logits.clone()
for fusion_scores in fusion_scores_list:
# combined scores with fusion model - without blank
logits_with_fusion[:, : -num_durations - 1] += fusion_scores
# get max scores and labels without blank
more_scores_w_fusion, more_labels_w_fusion = logits_with_fusion[:, : -num_durations - 1].max(
@@ -1212,18 +1161,19 @@ def _before_inner_loop_get_joint_output(self):
)

if self.has_fusion_models():
for fusion_model_idx, fusion_model_with_params in enumerate(self._all_fusion_models_with_params()):
fusion_scores_list, fusion_states_candidates_list = self.advance_fusion_models(
fusion_states_list=self.state.fusion_states_list,
multi_biasing_ids=self.state.multi_biasing_ids,
float_dtype=self.state.float_dtype,
)
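# copy results into the preallocated state buffers (kept static for CUDA graph replay)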
for fusion_model_idx in range(len(fusion_scores_list)):
# get fusion scores/states
fusion_scores, fusion_states_candidates = fusion_model_with_params.model.advance(
states=self.state.fusion_states_list[fusion_model_idx],
**({"model_ids": self.state.multi_biasing_ids} if fusion_model_with_params.is_multi_model else {}),
self.state.fusion_states_candidates_list[fusion_model_idx].copy_(
fusion_states_candidates_list[fusion_model_idx]
)
if not fusion_model_with_params.is_multi_model:
fusion_scores *= fusion_model_with_params.alpha
self.state.fusion_states_candidates_list[fusion_model_idx].copy_(fusion_states_candidates)
self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores.to(dtype=self.state.float_dtype))
self.state.fusion_scores_list[fusion_model_idx].copy_(fusion_scores_list[fusion_model_idx])
# update logits with fusion scores
logits[:, : -self.state.model_durations.shape[0] - 1] += fusion_scores
logits[:, : -self.state.model_durations.shape[0] - 1] += fusion_scores_list[fusion_model_idx]
# get labels (greedy) and scores from current logits, replace labels/scores with new
scores_w_fusion, labels_w_fusion = logits[:, : -self.state.model_durations.shape[0] - 1].max(dim=-1)
# preserve "blank" / "non-blank" category