Skip to content

Commit 025a275

Browse files
committed
Improved results summarization
1 parent 4565da8 commit 025a275

File tree

1 file changed

+69
-34
lines changed

1 file changed

+69
-34
lines changed

src/thunder/utils/results.py

Lines changed: 69 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -4,61 +4,86 @@
44
METRICS = ["accuracy", "balanced_accuracy", "f1", "jaccard", "roc_auc"]
55

66
TASK2CATS = {
7-
"adversarial_attack": ["adversarial", "clean"],
7+
"adversarial_attack": ["adversarial", "clean", "drop"],
88
"simple_shot": ["1", "2", "4", "8", "16"],
99
"image_retrieval": ["1", "3", "5", "10"],
1010
}
1111

12+
VAL_TYPES = ["metric_score", "ci_low", "ci_high"]
13+
1214

1315
def extract_value_from_large_json(
14-
file_path: str, target_keys: list, categories: list = None
16+
file_path: str, target_keys: list, val_types: list, categories: list = None
1517
) -> dict:
1618
"""
1719
Extracts values from a large JSON file using ijson.
1820
:param file_path: Path to the JSON file.
1921
:param target_keys: List of keys to extract values for.
20-
:param categories: List of first-level keys (each pointing to a dict)
22+
:param val_types: List of types of values to consider.
23+
:param categories: List of first-level keys (each pointing to a dict).
24+
25+
:return: Dictionary with keys and their corresponding values.
26+
"""
2127

22-
:return: Dictionary with keys and their corresponding values."""
2328
import ijson
2429

2530
if categories is None:
26-
values_dict = {key: None for key in target_keys}
31+
values_dict = {
32+
key: {val_type: None for val_type in val_types} for key in target_keys
33+
}
2734
else:
28-
values_dict = {cat: {key: None for key in target_keys} for cat in categories}
35+
values_dict = {
36+
cat: {
37+
key: {val_type: None for val_type in val_types} for key in target_keys
38+
}
39+
for cat in categories
40+
}
2941

3042
with open(file_path, "rb") as f: # Open the file in binary mode
3143
parser = ijson.parse(f)
3244
current_key = None
45+
current_val_type = None
3346
if categories is not None:
3447
current_cat = None
3548
for _, event, value in parser:
3649
if categories is None:
3750
if event == "map_key" and value in target_keys:
3851
current_key = value
39-
elif current_key is not None and event in [
40-
"string",
41-
"number",
42-
"boolean",
43-
"null",
44-
]:
45-
values_dict[current_key] = value
46-
current_key = None
47-
if all([val is not None for val in values_dict.values()]):
48-
return values_dict
52+
elif event == "map_key" and value in val_types:
53+
current_val_type = value
54+
elif (
55+
current_key is not None
56+
and current_val_type is not None
57+
and event
58+
in [
59+
"string",
60+
"number",
61+
"boolean",
62+
"null",
63+
]
64+
):
65+
values_dict[current_key][current_val_type] = value
66+
current_val_type = None
4967
else:
5068
if event == "map_key" and value in categories:
5169
current_cat = value
5270
elif event == "map_key" and value in target_keys:
5371
current_key = value
54-
elif current_key is not None and event in [
55-
"string",
56-
"number",
57-
"boolean",
58-
"null",
59-
]:
60-
values_dict[current_cat][current_key] = value
61-
current_key = None
72+
elif event == "map_key" and value in val_types:
73+
current_val_type = value
74+
elif (
75+
current_key is not None
76+
and current_val_type is not None
77+
and event
78+
in [
79+
"string",
80+
"number",
81+
"boolean",
82+
"null",
83+
]
84+
):
85+
values_dict[current_cat][current_key][current_val_type] = value
86+
current_val_type = None
6287

6388
return values_dict
6489

@@ -86,27 +111,37 @@ def gather_results():
86111
# Extracting results
87112
if task in TASK2CATS.keys():
88113
categories = TASK2CATS[task]
89-
result_dicts = extract_value_from_large_json(file, METRICS, categories)
114+
result_dicts = extract_value_from_large_json(
115+
file, METRICS, VAL_TYPES, categories
116+
)
90117
else:
91118
result_dicts = {
92-
"": extract_value_from_large_json(file, METRICS),
119+
"": extract_value_from_large_json(file, METRICS, VAL_TYPES),
93120
}
94121

95122
# Populating rows of results
96123
for setting, result_dict in result_dicts.items():
97124
for metric in METRICS:
98125
metric_dict = base_dict.copy()
99126
metric_dict["metric"] = metric
100-
value = result_dict[metric]
101-
if value is not None:
102-
value = round(100 * value, 1)
103-
metric_dict["value"] = value
104-
metric_dict["setting"] = setting
105-
res.append(metric_dict)
127+
value_dict = result_dict[metric]
128+
write_line = True
129+
for val_type in value_dict.keys():
130+
value = value_dict[val_type]
131+
if value is not None:
132+
value = round(100 * value, 1)
133+
else:
134+
write_line = False
135+
metric_dict[val_type] = value
136+
137+
if write_line:
138+
metric_dict["setting"] = setting
139+
res.append(metric_dict)
106140

107141
df = pd.DataFrame(res)
108142
df.to_csv(os.path.join(results_dir, "results.csv"), index=False)
109-
logging.info("Saved at:", os.path.join(results_dir, "results.csv"))
143+
logging.info(f"Saved at: {os.path.join(results_dir, 'results.csv')}")
110144
logging.info(
111-
"The setting column corresponds to: (i) Adversarial/Clean for adversarial_attack, (ii) nb shots for simple_shot, (iii) k for image_retrieval."
145+
"The setting column corresponds to: (i) Adversarial/Clean/Drop for adversarial_attack, (ii) nb shots for simple_shot, (iii) k for image_retrieval. "
146+
"'metric_score' is the metric value for the considered test set, 'ci_low' and 'ci_high' are lower and upper bounds of 95% bootstrap confidence interval."
112147
)

0 commit comments

Comments
 (0)