Fix alignment logits when using fusion models in TDT greedy decoder #15311
Conversation
If you have time for a review, @nithinraok, thanks a lot!!
artbataev left a comment:
First, thanks for this contribution.
Unfortunately, I'm working on a concurrent PR with modifications to the decoding algorithm, so I'm asking you to wait until it is merged.
But I would be happy to fix bugs and inconsistencies in this implementation.
Also, it would be great if you could clarify your use-case for alignments: the alignments data structure (which stores all logits) requires an enormous amount of memory and makes decoding slower.
Regarding the PR itself:
- I agree about using logits consistent with LM scores in alignments (currently inconsistent, for both TDT and RNN-T)
- Also, we need to maintain consistency between `tdt_label_looping` and `rnnt_label_looping`
- I disagree with using `-inf` for blank, since it will break some of our use cases of alignments.
If you clarify the core use-case, we could discuss how to implement it.
Concurrent PR with decoding changes: #15315
Hello @artbataev! Thanks a lot for taking the time to review my PR! No worries, I can wait.

Regarding my use-case for the alignments: it is very interesting to see, at each prediction step, which tokens are the most likely according to the model, and to compare that with how the fusion model is trained. This really allows a deep understanding of the models' behaviour and shows which changes could improve the results (fine-tuning, changing how the fusion model is trained, ...).

What I noticed is that the token predicted at each step was not necessarily the one with the highest score in the logits stored for that step. My proposed change fixes that, but I agree that setting `-inf` for the blank logit is not very satisfying. When blank is not the most likely token, keeping the blank logit at its original value while adding the fusion scores to all the other tokens won't guarantee that blank is not the argmax of the stored logits. WDYT? How would you fix this issue with another implementation?

Also, thanks a lot for linking the other PR, I am having a look! And thanks for all the PRs you merged related to TDT, very impressive work 💪

What does this PR do?
Fixes an inconsistency in alignment logits when using `preserve_alignments=True` with TDT models and fusion models (e.g., LM, boosting tree).
Problem
When fusion models are enabled, the label selection logic correctly applies the fusion model scores to the logits before choosing the label, as sketched below.
However, the logits stored in alignments were always the pre-fusion logits. This caused a mismatch where the stored label (correctly chosen with the above logic) did not correspond to the argmax of the stored logits.
Solution
Update the logits sent to alignments to be consistent with the label selection:
This ensures `argmax(logits) == label` for every alignment entry, making the alignments reliable for downstream analysis.

Scope
Implementation in the `torch_impl` method only. The CUDA graphs implementation is not affected.

Collection: ASR
Changelog
- `tdt_label_looping.py`: Updated `torch_impl` to store alignment-consistent logits when fusion models are used (both the outer loop and the inner loop)
- `test_rnnt_decoding.py`: Added `test_tdt_greedy_decoding_preserve_alignments_with_fusion` to verify label/logits consistency with fusion models (a simplified sketch of the check appears below)
GitHub Actions CI
The Jenkins CI system has been replaced by GitHub Actions self-hosted runners.
The GitHub Actions CI will run automatically when the "Run CICD" label is added to the PR.
To re-run CI, remove and add the label again.
To run CI on an untrusted fork, a NeMo user with write access must first click "Approve and run".
Before your PR is "Ready for review"
Pre checks:
PR Type:
If you haven't finished some of the above items, you can still open a "Draft" PR.
Who can review?
Anyone in the community is free to review the PR once the checks have passed.
The Contributor guidelines list specific people who can review PRs to various areas.
Additional Information
None