147 changes: 147 additions & 0 deletions nemo/collections/asr/metrics/eval_ner.py
@@ -0,0 +1,147 @@
from typing import List
from collections import defaultdict
import numpy as np
import editdistance


def get_ner_scores(all_gt, all_predictions):
"""
    Evaluates per-label and overall (micro and macro) metrics of precision, recall, and fscore

Input:
all_gt/all_predictions:
List of list of tuples: (label, phrase, identifier)
        Each list of tuples corresponds to a sentence:
label: entity tag
phrase: entity phrase
            identifier: index to differentiate repeating (label, phrase) pairs

Returns:
Dictionary of metrics

Example:
List of GT (label, phrase) pairs of a sentence: [(GPE, "eu"), (DATE, "today"), (GPE, "eu")]
all_gt: [(GPE, "eu", 0), (DATE, "today", 0), (GPE, "eu", 1)]
"""
metrics = {}
stats = get_ner_stats(all_gt, all_predictions)
num_correct, num_gt, num_pred = 0, 0, 0
prec_lst, recall_lst, fscore_lst = [], [], []
for tag_name, tag_stats in stats.items():
precision, recall, fscore = get_metrics(
np.sum(tag_stats["tp"]),
np.sum(tag_stats["gt_cnt"]),
np.sum(tag_stats["pred_cnt"]),
)
_ = metrics.setdefault(tag_name, {})
metrics[tag_name]["precision"] = precision
metrics[tag_name]["recall"] = recall
metrics[tag_name]["fscore"] = fscore

num_correct += np.sum(tag_stats["tp"])
num_pred += np.sum(tag_stats["pred_cnt"])
num_gt += np.sum(tag_stats["gt_cnt"])

prec_lst.append(precision)
recall_lst.append(recall)
fscore_lst.append(fscore)

precision, recall, fscore = get_metrics(num_correct, num_gt, num_pred)
metrics["overall_micro"] = {}
metrics["overall_micro"]["precision"] = precision
metrics["overall_micro"]["recall"] = recall
metrics["overall_micro"]["fscore"] = fscore

metrics["overall_macro"] = {}
metrics["overall_macro"]["precision"] = np.mean(prec_lst)
metrics["overall_macro"]["recall"] = np.mean(recall_lst)
metrics["overall_macro"]["fscore"] = np.mean(fscore_lst)

return metrics


def get_ner_stats(all_gt, all_predictions):
stats = {}
cnt = 0
for gt, pred in zip(all_gt, all_predictions):
entities_true = defaultdict(set)
entities_pred = defaultdict(set)
for type_name, entity_info1, entity_info2 in gt:
entities_true[type_name].add((entity_info1, entity_info2))
for type_name, entity_info1, entity_info2 in pred:
entities_pred[type_name].add((entity_info1, entity_info2))
target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
for tag_name in target_names:
_ = stats.setdefault(tag_name, {})
_ = stats[tag_name].setdefault("tp", [])
_ = stats[tag_name].setdefault("gt_cnt", [])
_ = stats[tag_name].setdefault("pred_cnt", [])
entities_true_type = entities_true.get(tag_name, set())
entities_pred_type = entities_pred.get(tag_name, set())
stats[tag_name]["tp"].append(len(entities_true_type & entities_pred_type))
stats[tag_name]["pred_cnt"].append(len(entities_pred_type))
stats[tag_name]["gt_cnt"].append(len(entities_true_type))
return stats


def safe_divide(numerator, denominator):
numerator = np.array(numerator)
denominator = np.array(denominator)
mask = denominator == 0.0
denominator = denominator.copy()
denominator[mask] = 1 # avoid infs/nans
return numerator / denominator


def ner_error_analysis(all_gt, all_predictions, gt_text):
"""
    Build tab-separated lines of GT text, GT entities, and predicted entities for error analysis
all_gt: [GT] list of tuples of (label, phrase, identifier idx)
all_predictions: [hypothesis] list of tuples of (label, phrase, identifier idx)
gt_text: list of GT text sentences
"""
analysis_examples_dct = {}
analysis_examples_dct["all"] = []
for idx, text in enumerate(gt_text):
if isinstance(text, list):
text = " ".join(text)
gt = all_gt[idx]
pred = all_predictions[idx]
entities_true = defaultdict(set)
entities_pred = defaultdict(set)
for type_name, entity_info1, entity_info2 in gt:
entities_true[type_name].add((entity_info1, entity_info2))
for type_name, entity_info1, entity_info2 in pred:
entities_pred[type_name].add((entity_info1, entity_info2))
target_names = sorted(set(entities_true.keys()) | set(entities_pred.keys()))
analysis_examples_dct["all"].append("\t".join([text, str(gt), str(pred)]))
for tag_name in target_names:
_ = analysis_examples_dct.setdefault(tag_name, [])
new_gt = [(item1, item2) for item1, item2, _ in gt]
new_pred = [(item1, item2) for item1, item2, _ in pred]
analysis_examples_dct[tag_name].append(
"\t".join([text, str(new_gt), str(new_pred)])
)

return analysis_examples_dct


def get_metrics(num_correct, num_gt, num_pred):
precision = safe_divide([num_correct], [num_pred])
recall = safe_divide([num_correct], [num_gt])
fscore = safe_divide([2 * precision * recall], [(precision + recall)])
return precision[0], recall[0], fscore[0][0]


def get_wer(refs: List[str], hyps: List[str]):
"""
args:
refs (list of str): reference texts
hyps (list of str): hypothesis/prediction texts
"""
n_words, n_errors = 0, 0
for ref, hyp in zip(refs, hyps):
ref, hyp = ref.split(), hyp.split()
n_words += len(ref)
n_errors += editdistance.eval(ref, hyp)
return safe_divide(n_errors, n_words)
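A minimal usage sketch of this file's two entry points, get_ner_scores and get_wer; the sentences, labels, and scores below are illustrative, and the identifier indices follow the 0-based convention from the get_ner_scores docstring.

```python
from nemo.collections.asr.metrics.eval_ner import get_ner_scores, get_wer

# One sentence; (GPE, "eu") occurs twice in the GT, so the third tuple element
# distinguishes the repeats, as described in the get_ner_scores docstring.
all_gt = [[("GPE", "eu", 0), ("DATE", "today", 0), ("GPE", "eu", 1)]]
all_pred = [[("GPE", "eu", 0), ("DATE", "tomorrow", 0)]]

metrics = get_ner_scores(all_gt, all_pred)
print(metrics["GPE"])            # per-label precision/recall/fscore
print(metrics["overall_micro"])  # micro-averaged over all labels
print(metrics["overall_macro"])  # macro-averaged over labels

# Corpus-level WER over plain reference/hypothesis text: 1 error / 4 words = 0.25
print(get_wer(["the eu met today"], ["the eu met tomorrow"]))
```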
74 changes: 74 additions & 0 deletions nemo/collections/asr/metrics/multi_wer.py
@@ -0,0 +1,74 @@
from typing import List, Tuple
from .wer import word_error_rate
from .eval_ner import get_ner_scores

def get_clean_transcript(sent: str):
    # Inline tags are assumed to be fully upper-case tokens; drop them to recover the plain transcript.
    clean_sent = [word for word in sent.split(' ') if not word.isupper()]
    return ' '.join(clean_sent)

def make_distinct(label_lst):
"""
Make the label_lst distinct
"""
tag2cnt, new_tag_lst = {}, []
if len(label_lst) > 0:
for tag_item in label_lst:
_ = tag2cnt.setdefault(tag_item, 0)
tag2cnt[tag_item] += 1
tag, wrd = tag_item
new_tag_lst.append((tag, wrd, tag2cnt[tag_item]))
assert len(new_tag_lst) == len(set(new_tag_lst))
return new_tag_lst

def get_entity_format(sents: List[str], tags: dict, score_type: str):
def update_label_lst(lst, phrase, label):
if label in tags['NER']:
if score_type == "label":
lst.append((label, "phrase"))
else:
lst.append((label, phrase))

label_lst, sent_lst = [], []
for sent in sents:
sent_label_lst = []
        sent = sent.replace("  ", " ")  # collapse double spaces before splitting into words
wrd_lst = sent.split(" ")
sent_lst.append(sent)
phrase_lst, is_entity, num_illegal_assigments = [], False, 0
for wrd in wrd_lst:
if wrd in tags["NER"]:
if is_entity:
phrase_lst = []
num_illegal_assigments += 1
is_entity = True
entity_tag = wrd
elif wrd in tags["EMOTION"]:
sent_label_lst.append((wrd, "phrase"))
            elif wrd in tags["END"]:
                if is_entity:
                    if len(phrase_lst) > 0:
                        update_label_lst(sent_label_lst, " ".join(phrase_lst), entity_tag)
                    else:
                        num_illegal_assigments += 1
                else:
                    num_illegal_assigments += 1
                # close the span so later words are not attached to a stale entity tag
                is_entity = False
                phrase_lst = []
            else:
                if is_entity:
                    phrase_lst.append(wrd)
label_lst.append(make_distinct(sent_label_lst))
return label_lst, sent_lst

def multi_word_error_rate(hypotheses: List[str], references: List[str], tags: dict, score_type: str, use_cer=False) -> Tuple[float, float]:
hypotheses_without_tags = [get_clean_transcript(hypothesis) for hypothesis in hypotheses]
refs_without_tags = [get_clean_transcript(reference) for reference in references]

wer = word_error_rate(hypotheses_without_tags, refs_without_tags, use_cer)

hypo_label_list, _ = get_entity_format(hypotheses, tags, score_type)
ref_label_list, _ = get_entity_format(references, tags, score_type)

    metrics = get_ner_scores(ref_label_list, hypo_label_list)

    return wer, metrics["overall_micro"]["fscore"]
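
A hedged usage sketch of multi_word_error_rate; the inline tag format (an upper-case entity tag opens a span, END closes it) is inferred from get_entity_format, and the tag inventory and sentences below are illustrative assumptions rather than part of this PR.

```python
from nemo.collections.asr.metrics.multi_wer import multi_word_error_rate

# Hypothetical tag inventory; each value only needs to support `in` membership tests.
tags = {
    "NER": {"GPE", "DATE"},       # entity tags that open a span
    "EMOTION": {"HAPPY", "SAD"},  # sentence-level tags scored as bare labels
    "END": {"END"},               # tag that closes an entity span
}

# Tagged transcripts: "GPE eu END" marks "eu" as a GPE entity.
references = ["the GPE eu END met DATE today END"]
hypotheses = ["the GPE eu END met DATE tomorrow END"]

# score_type="label" scores entity labels only; any other value scores full
# (label, phrase) pairs, so the today/tomorrow mismatch would then count against F1.
wer, entity_f1 = multi_word_error_rate(hypotheses, references, tags, score_type="label")
print(wer)        # WER over the tag-stripped transcripts: 1 error / 4 words = 0.25
print(entity_f1)  # micro-averaged entity F1
```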
