 
 import re
 import sys
-from typing import TYPE_CHECKING, Any
+from typing import TYPE_CHECKING, Union
 
 if TYPE_CHECKING:
     import numpy as np
@@ -42,7 +42,7 @@ def _f1(precision: float, recall: float) -> float:
     return 2 * precision * recall / (precision + recall)
 
 
-def _flatten_result(my_dict: dict, sep: str = ":") -> dict[str, Any]:
+def _flatten_result(my_dict: dict, sep: str = ":") -> dict[str, Union[int, str]]:
     """Flatten two-dimension dictionary.
 
     Use keys in the first dimension as a prefix for keys in the second dimension.
@@ -56,7 +56,7 @@ def _flatten_result(my_dict: dict, sep: str = ":") -> dict[str, Any]:
     :param str sep: separator between the two keys (default: ":")
 
     :return: a one-dimension dictionary with keys combined
-    :rtype: dict[str, Any]
+    :rtype: dict[str, Union[int, str]]
     """
     return {
         f"{k1}{sep}{k2}": v
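As a quick illustration of what `_flatten_result` does (a minimal sketch based on the docstring and the comprehension above; the sample dictionary is made up, not from the repository):

```python
# Hypothetical two-level dict of results.
nested = {"a": {"x": 1, "y": 2}, "b": {"z": "ok"}}

flat = _flatten_result(nested)
# First-level keys become prefixes, joined by the separator (default ":"):
# {"a:x": 1, "a:y": 2, "b:z": "ok"}
```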
@@ -133,7 +133,7 @@ def preprocessing(txt: str, remove_space: bool = True) -> str:
     return txt
 
 
-def compute_stats(ref_sample: str, raw_sample: str) -> dict[str, Any]:
+def compute_stats(ref_sample: str, raw_sample: str) -> dict[str, dict[str, Union[int, str]]]:
     """Compute statistics for tokenization quality
 
     These statistics include:
@@ -150,7 +150,7 @@ def compute_stats(ref_sample: str, raw_sample: str) -> dict[str, Any]:
     :param str samples: samples that we want to evaluate
 
     :return: metrics at character- and word-level and indicators of correctly tokenized words
-    :rtype: dict[str, dict[str, Union[int, str]]]
+    :rtype: dict[str, dict[str, Union[int, str]]]
     """
     import numpy as np
 
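For context, a hedged sketch of how the character-level counts computed below could feed precision, recall, and F1 via the `_f1` helper from the first hunk; `_char_metrics` is a hypothetical name, not a function in this module, and it assumes the standard precision/recall definitions:

```python
def _char_metrics(c_tp: int, c_fp: int, c_fn: int) -> dict[str, float]:
    # Standard definitions; guard against empty denominators.
    precision = c_tp / (c_tp + c_fp) if (c_tp + c_fp) else 0.0
    recall = c_tp / (c_tp + c_fn) if (c_tp + c_fn) else 0.0
    return {"precision": precision, "recall": recall, "f1": _f1(precision, recall)}
```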
@@ -166,11 +166,11 @@ def compute_stats(ref_sample: str, raw_sample: str) -> dict[str, Any]:
     c_pos_pred = c_pos_pred[c_pos_pred < ref_sample_arr.shape[0]]
     c_neg_pred = c_neg_pred[c_neg_pred < ref_sample_arr.shape[0]]
 
-    c_tp: np.intp = np.sum(ref_sample_arr[c_pos_pred] == 1)
-    c_fp: np.intp = np.sum(ref_sample_arr[c_pos_pred] == 0)
+    c_tp: int = int(np.sum(ref_sample_arr[c_pos_pred] == 1))
+    c_fp: int = int(np.sum(ref_sample_arr[c_pos_pred] == 0))
 
-    c_tn: np.intp = np.sum(ref_sample_arr[c_neg_pred] == 0)
-    c_fn: np.intp = np.sum(ref_sample_arr[c_neg_pred] == 1)
+    c_tn: int = int(np.sum(ref_sample_arr[c_neg_pred] == 0))
+    c_fn: int = int(np.sum(ref_sample_arr[c_neg_pred] == 1))
 
     # Compute word-level statistics
 
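The point of wrapping these sums in `int(...)` is that `np.sum` returns a NumPy integer scalar rather than a built-in `int`, so the cast is what makes the values match the new `int` annotations. A small standalone check (not part of the diff):

```python
import numpy as np

count = np.sum(np.array([1, 0, 1]) == 1)
print(type(count))       # a NumPy integer scalar (e.g. numpy.int64), not int
print(type(int(count)))  # <class 'int'>, matching the annotation above
```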
@@ -183,7 +183,7 @@ def compute_stats(ref_sample: str, raw_sample: str) -> dict[str, Any]:
         word_boundaries, ss_boundaries
     )
 
-    correctly_tokenised_words: np.intp = np.sum(tokenization_indicators)
+    correctly_tokenised_words: int = int(np.sum(tokenization_indicators))
 
     tokenization_indicators_str = list(map(str, tokenization_indicators))
 
@@ -196,8 +196,8 @@ def compute_stats(ref_sample: str, raw_sample: str) -> dict[str, Any]:
         },
         "word_level": {
             "correctly_tokenised_words": correctly_tokenised_words,
-            "total_words_in_sample": np.sum(sample_arr),
-            "total_words_in_ref_sample": np.sum(ref_sample_arr),
+            "total_words_in_sample": int(np.sum(sample_arr)),
+            "total_words_in_ref_sample": int(np.sum(ref_sample_arr)),
         },
         "global": {
             "tokenisation_indicators": "".join(tokenization_indicators_str)
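Putting the two changed functions together, one assumed downstream use (not shown in this commit) is flattening the nested result and serialising it; with the counts cast to built-in `int`, `json.dumps` no longer trips over NumPy scalar types. `ref_sample` and `raw_sample` stand in for whatever tokenised samples the caller provides:

```python
import json

result = compute_stats(ref_sample, raw_sample)  # hypothetical sample strings
flat = _flatten_result(result)
# Keys such as "word_level:correctly_tokenised_words" and
# "global:tokenisation_indicators"; char-level key names are not visible in this hunk.
print(json.dumps(flat))  # plain int/str values serialise without a custom encoder
```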