from typing import Any, Sequence

import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger
from rich.console import Console
from rich.table import Table

from mmchat.registry import METRICS, TOKENIZER


@METRICS.register_module()
class MMLUMetric(BaseMetric):
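    """MMLU multiple-choice accuracy metric.

    For every sample, the logits at the answer position are restricted to
    the four option tokens ('A'/'B'/'C'/'D') and the argmax is taken as the
    prediction. Accuracy is reported per subject, per subcategory, per
    category, and as an overall average, following the MMLU taxonomy below.
    """
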
    METAINFO = {
        'subcategories': {
            'abstract_algebra': ['math'],
            'anatomy': ['health'],
            'astronomy': ['physics'],
            'business_ethics': ['business'],
            'clinical_knowledge': ['health'],
            'college_biology': ['biology'],
            'college_chemistry': ['chemistry'],
            'college_computer_science': ['computer science'],
            'college_mathematics': ['math'],
            'college_medicine': ['health'],
            'college_physics': ['physics'],
            'computer_security': ['computer science'],
            'conceptual_physics': ['physics'],
            'econometrics': ['economics'],
            'electrical_engineering': ['engineering'],
            'elementary_mathematics': ['math'],
            'formal_logic': ['philosophy'],
            'global_facts': ['other'],
            'high_school_biology': ['biology'],
            'high_school_chemistry': ['chemistry'],
            'high_school_computer_science': ['computer science'],
            'high_school_european_history': ['history'],
            'high_school_geography': ['geography'],
            'high_school_government_and_politics': ['politics'],
            'high_school_macroeconomics': ['economics'],
            'high_school_mathematics': ['math'],
            'high_school_microeconomics': ['economics'],
            'high_school_physics': ['physics'],
            'high_school_psychology': ['psychology'],
            'high_school_statistics': ['math'],
            'high_school_us_history': ['history'],
            'high_school_world_history': ['history'],
            'human_aging': ['health'],
            'human_sexuality': ['culture'],
            'international_law': ['law'],
            'jurisprudence': ['law'],
            'logical_fallacies': ['philosophy'],
            'machine_learning': ['computer science'],
            'management': ['business'],
            'marketing': ['business'],
            'medical_genetics': ['health'],
            'miscellaneous': ['other'],
            'moral_disputes': ['philosophy'],
            'moral_scenarios': ['philosophy'],
            'nutrition': ['health'],
            'philosophy': ['philosophy'],
            'prehistory': ['history'],
            'professional_accounting': ['other'],
            'professional_law': ['law'],
            'professional_medicine': ['health'],
            'professional_psychology': ['psychology'],
            'public_relations': ['politics'],
            'security_studies': ['politics'],
            'sociology': ['culture'],
            'us_foreign_policy': ['politics'],
            'virology': ['health'],
            'world_religions': ['philosophy'],
        },
        'categories': {
            'STEM': ['physics', 'chemistry', 'biology', 'computer science',
                     'math', 'engineering'],
            'humanities': ['history', 'philosophy', 'law'],
            'social sciences': ['politics', 'culture', 'economics',
                                'geography', 'psychology'],
            'other (business, health, misc.)': ['other', 'business',
                                                'health'],
        },
    }
    # Deduplicated, deterministically ordered list of all subcategory names.
    METAINFO['subcategories_list'] = sorted(
        {subcat
         for subcats in METAINFO['subcategories'].values()
         for subcat in subcats})

    def __init__(self, tokenizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger: MMLogger = MMLogger.get_current_instance()
        tokenizer = TOKENIZER.build(tokenizer)
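        # Cache the vocabulary ids of the four option letters. This assumes
        # each letter is encoded as a single leading token by the tokenizer.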
        self.abcd_idx = [
            tokenizer(letter, add_special_tokens=False).input_ids[0]
            for letter in 'ABCD'
        ]

    @staticmethod
    def ABCD_to_0123(abcd):
        return {'A': 0, 'B': 1, 'C': 2, 'D': 3}[abcd]

    @staticmethod
    def accuracy(preds, gts):
        """Compute the accuracy (%) of ``preds`` against ``gts``."""
        correct = [1 if pred == gt else 0 for pred, gt in zip(preds, gts)]
        return np.mean(correct) * 100

    def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions.

        The processed results are stored in ``self.results``, which will be
        used to compute the metrics once all batches have been processed.

        Args:
            data_batch (Any): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        subjects = data_batch['subject']
        gts = [self.ABCD_to_0123(gt) for gt in data_batch['output']]
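        # The first label that is not the ignore index (-100) marks the
        # answer token; the logits one position earlier are the model's
        # prediction for it. Restrict those logits to the A/B/C/D token ids
        # and take the argmax.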
        for sample, subject, gt in zip(data_samples, subjects, gts):
            pred_logits = sample['logits']
            labels = sample['labels']
            labels_non_zero_id = (labels != -100).nonzero()[0][0]
            pred_logits_abcd = pred_logits[labels_non_zero_id - 1,
                                           self.abcd_idx]
            pred = torch.argmax(pred_logits_abcd).item()
            self.results.append((subject, pred, gt))

    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the
                metrics, and the values are the corresponding results.
        """
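        # Bucket each (subject, pred, gt) triple at the three levels of the
        # MMLU taxonomy: subject, subcategory, and category.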
        subjects_results = {
            subject: {'preds': [], 'gts': []}
            for subject in self.METAINFO['subcategories']
        }
        subcats_results = {
            subcat: {'preds': [], 'gts': []}
            for subcat in self.METAINFO['subcategories_list']
        }
        cats_results = {
            cat: {'preds': [], 'gts': []}
            for cat in self.METAINFO['categories']
        }
        for subject, pred, gt in results:
            subjects_results[subject]['preds'].append(pred)
            subjects_results[subject]['gts'].append(gt)
            for subcat in self.METAINFO['subcategories'][subject]:
                subcats_results[subcat]['preds'].append(pred)
                subcats_results[subcat]['gts'].append(gt)
        for cat, subcats in self.METAINFO['categories'].items():
            for subcat in subcats:
                if subcat in subcats_results:
                    cats_results[cat]['preds'].extend(
                        subcats_results[subcat]['preds'])
                    cats_results[cat]['gts'].extend(
                        subcats_results[subcat]['gts'])

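        # Compute accuracy at each level, skipping buckets that received no
        # samples (e.g. when evaluating on only a subset of subjects).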
        subjects_metrics = dict()
        subcats_metrics = dict()
        cats_metrics = dict()
        for subject in self.METAINFO['subcategories']:
            subject_results = subjects_results[subject]
            assert len(subject_results['preds']) == len(subject_results['gts'])
            if len(subject_results['preds']) == 0:
                self.logger.info(f'Skip subject {subject} for mmlu')
            else:
                subjects_metrics[subject] = self.accuracy(
                    subject_results['preds'], subject_results['gts'])
        for subcat in self.METAINFO['subcategories_list']:
            subcat_results = subcats_results[subcat]
            assert len(subcat_results['preds']) == len(subcat_results['gts'])
            if len(subcat_results['preds']) == 0:
                self.logger.info(f'Skip subcategory {subcat} for mmlu')
            else:
                subcats_metrics[subcat] = self.accuracy(
                    subcat_results['preds'], subcat_results['gts'])
        for cat in self.METAINFO['categories']:
            cat_results = cats_results[cat]
            assert len(cat_results['preds']) == len(cat_results['gts'])
            if len(cat_results['preds']) == 0:
                self.logger.info(f'Skip category {cat} for mmlu')
            else:
                cats_metrics[cat] = self.accuracy(
                    cat_results['preds'], cat_results['gts'])

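        # 'average' is the unweighted (macro) mean of per-subject accuracies,
        # so every subject counts equally regardless of its sample count.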
        metrics = dict()
        metrics.update(subjects_metrics)
        metrics.update(subcats_metrics)
        metrics.update(cats_metrics)
        metrics['average'] = np.mean(list(subjects_metrics.values()))

        table_metrics = dict()
        table_metrics.update(cats_metrics)
        table_metrics['average'] = metrics['average']
        self._print_results(table_metrics)
        return metrics

    def _print_results(self, table_metrics: dict) -> None:
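        """Log per-category accuracies and the average as a rich table."""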
        table_title = ' MMLU Benchmark '
        table = Table(title=table_title)
        console = Console()
        table.add_column('Categories', justify='left')
        table.add_column('Accuracy (%)', justify='right')
        for cat, acc in table_metrics.items():
            table.add_row(cat, f'{acc:.1f}')
        with console.capture() as capture:
            console.print(table, end='')
        self.logger.info('\n' + capture.get())
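
# A minimal usage sketch (illustrative only, not part of the class): the
# metric is registered in METRICS, so it is normally built from a config.
# The tokenizer config below is a hypothetical placeholder; its exact form
# depends on how the TOKENIZER registry is populated in mmchat.
#
#   val_evaluator = dict(
#       type='MMLUMetric',
#       tokenizer=dict(type=...))  # tokenizer config resolved by TOKENIZER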