# Metrics reported for every task in the results table.
METRICS = ["accuracy", "balanced_accuracy", "f1", "jaccard", "roc_auc"]


# Per-task first-level keys ("settings") under which results are stored:
# adversarial_attack -> test-set condition, simple_shot -> number of shots,
# image_retrieval -> retrieval cutoff k.
TASK2CATS = {
    "adversarial_attack": ["adversarial", "clean", "drop"],
    "simple_shot": ["1", "2", "4", "8", "16"],
    "image_retrieval": ["1", "3", "5", "10"],
}


# Value kinds stored per metric: the point estimate plus the lower/upper
# bounds of the 95% bootstrap confidence interval.
VAL_TYPES = ["metric_score", "ci_low", "ci_high"]

def extract_value_from_large_json(
    file_path: str, target_keys: list, val_types: list, categories: list = None
) -> dict:
    """
    Extracts values from a large JSON file using ijson (streaming parser).

    :param file_path: Path to the JSON file.
    :param target_keys: List of keys to extract values for.
    :param val_types: List of types of values to consider
        (e.g. "metric_score", "ci_low", "ci_high").
    :param categories: List of first-level keys (each pointing to a dict).

    :return: Dictionary with keys and their corresponding values —
        ``{key: {val_type: value}}``, or ``{category: {key: {val_type: value}}}``
        when ``categories`` is given. Entries never seen in the file stay None.
    """
    import ijson  # local import: only needed when results are gathered

    # Scalar events whose payload we want to capture.
    scalar_events = ("string", "number", "boolean", "null")

    # Pre-build the result skeleton so keys absent from the file remain None.
    if categories is None:
        values_dict = {
            key: {val_type: None for val_type in val_types} for key in target_keys
        }
    else:
        values_dict = {
            cat: {
                key: {val_type: None for val_type in val_types} for key in target_keys
            }
            for cat in categories
        }

    with open(file_path, "rb") as f:  # ijson consumes a binary stream
        current_key = None
        current_val_type = None
        current_cat = None
        for _, event, value in ijson.parse(f):
            if event == "map_key":
                # Precedence mirrors the original branch order:
                # category, then target key, then value type.
                if categories is not None and value in categories:
                    current_cat = value
                elif value in target_keys:
                    current_key = value
                elif value in val_types:
                    current_val_type = value
            elif (
                current_key is not None
                and current_val_type is not None
                and event in scalar_events
            ):
                # NOTE(review): current_key is never reset between sibling
                # maps, so a val_type key appearing under an unrelated
                # first-level key would be attributed to the last seen
                # target key — assumes the JSON layout never does this;
                # TODO confirm against the result-file schema.
                if categories is None:
                    values_dict[current_key][current_val_type] = value
                else:
                    values_dict[current_cat][current_key][current_val_type] = value
                current_val_type = None

    return values_dict
|
@@ -86,27 +111,37 @@ def gather_results(): |
86 | 111 | # Extracting results |
87 | 112 | if task in TASK2CATS.keys(): |
88 | 113 | categories = TASK2CATS[task] |
89 | | - result_dicts = extract_value_from_large_json(file, METRICS, categories) |
| 114 | + result_dicts = extract_value_from_large_json( |
| 115 | + file, METRICS, VAL_TYPES, categories |
| 116 | + ) |
90 | 117 | else: |
91 | 118 | result_dicts = { |
92 | | - "": extract_value_from_large_json(file, METRICS), |
| 119 | + "": extract_value_from_large_json(file, METRICS, VAL_TYPES), |
93 | 120 | } |
94 | 121 |
|
95 | 122 | # Populating rows of results |
96 | 123 | for setting, result_dict in result_dicts.items(): |
97 | 124 | for metric in METRICS: |
98 | 125 | metric_dict = base_dict.copy() |
99 | 126 | metric_dict["metric"] = metric |
100 | | - value = result_dict[metric] |
101 | | - if value is not None: |
102 | | - value = round(100 * value, 1) |
103 | | - metric_dict["value"] = value |
104 | | - metric_dict["setting"] = setting |
105 | | - res.append(metric_dict) |
| 127 | + value_dict = result_dict[metric] |
| 128 | + write_line = True |
| 129 | + for val_type in value_dict.keys(): |
| 130 | + value = value_dict[val_type] |
| 131 | + if value is not None: |
| 132 | + value = round(100 * value, 1) |
| 133 | + else: |
| 134 | + write_line = False |
| 135 | + metric_dict[val_type] = value |
| 136 | + |
| 137 | + if write_line: |
| 138 | + metric_dict["setting"] = setting |
| 139 | + res.append(metric_dict) |
106 | 140 |
|
107 | 141 | df = pd.DataFrame(res) |
108 | 142 | df.to_csv(os.path.join(results_dir, "results.csv"), index=False) |
109 | | - logging.info("Saved at:", os.path.join(results_dir, "results.csv")) |
| 143 | + logging.info(f"Saved at: {os.path.join(results_dir, 'results.csv')}") |
110 | 144 | logging.info( |
111 | | - "The setting column corresponds to: (i) Adversarial/Clean for adversarial_attack, (ii) nb shots for simple_shot, (iii) k for image_retrieval." |
| 145 | + "The setting column corresponds to: (i) Adversarial/Clean/Drop for adversarial_attack, (ii) nb shots for simple_shot, (iii) k for image_retrieval. " |
| 146 | + "'metric_score' is the metric value for the considered test set, 'ci_low' and 'ci_high' are lower and upper bounds of 95% bootstrap confidence interval." |
112 | 147 | ) |
0 commit comments