Skip to content

Commit 656bbaf

Browse files
author
Huy Vu2
committed
Merge remote-tracking branch 'origin/main' into huvu/nemo_data_designer
2 parents bf38aa9 + 53d943a commit 656bbaf

File tree

4 files changed

+13
-3
lines changed

4 files changed

+13
-3
lines changed

nemo_curator/stages/audio/common.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,11 @@ class LegacySpeechStage(ProcessingStage[Task, Task]):
3232
def process(self, task: AudioBatch) -> list[Task]:
3333
result = []
3434
for entry in task.data:
35-
result.extend(self.process_dataset_entry(entry))
35+
entries = self.process_dataset_entry(entry)
36+
for r in entries:
37+
if r is not task and not r._stage_perf:
38+
r._stage_perf = list(task._stage_perf)
39+
result.extend(entries)
3640
return result
3741

3842
@abstractmethod
@@ -54,6 +58,7 @@ class GetAudioDurationStage(LegacySpeechStage):
5458
All the same fields as in the input manifest plus duration_key
5559
"""
5660

61+
name = "GetAudioDurationStage"
5762
audio_filepath_key: str
5863
duration_key: str
5964

@@ -80,14 +85,14 @@ class PreserveByValueStage(LegacySpeechStage):
8085
8186
"""
8287

88+
name = "PreserveByValueStage"
89+
8390
def __init__(
8491
self,
8592
input_value_key: str,
8693
target_value: int | str,
8794
operator: str = "eq",
88-
**kwargs,
8995
):
90-
super().__init__(**kwargs)
9196
self.input_value_key = input_value_key
9297
self.target_value = target_value
9398
if operator == "lt":

nemo_curator/stages/audio/inference/asr_nemo.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,4 +148,5 @@ def process(self, task: FileGroupTask | DocumentBatch | AudioBatch) -> AudioBatc
148148
dataset_name=f"{self.model_name}_inference",
149149
filepath_key=self.filepath_key,
150150
data=audio_items,
151+
_stage_perf=task._stage_perf,
151152
)

nemo_curator/stages/audio/io/convert.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@ class AudioToDocumentStage(ProcessingStage[AudioBatch, DocumentBatch]):
2424
2525
"""
2626

27+
name = "AudioToDocumentStage"
28+
2729
def process(self, task: AudioBatch) -> list[DocumentBatch]:
2830
return [
2931
DocumentBatch(
3032
data=pd.DataFrame(task.data),
3133
task_id=task.task_id,
3234
dataset_name=task.dataset_name,
35+
_stage_perf=task._stage_perf,
3336
)
3437
]

nemo_curator/stages/audio/metrics/get_wer.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class GetPairwiseWerStage(LegacySpeechStage):
6262
The same data as in the input manifest with wer_key and corresponding values.
6363
"""
6464

65+
name = "GetPairwiseWerStage"
6566
text_key: str = "text"
6667
pred_text_key: str = "pred_text"
6768
wer_key: str = "wer"

0 commit comments

Comments
 (0)