
Commit e840b6c

add MMLU benchmark val / test (#1)
* add mmlu dataset configs
* add mmlu metric
* fix bugs
* implement predict for sft model
* remove dummy file
* clean code
* modify prefix for mmlu test
* add test.py
* add mmlu val/test for gunaco config
* use float16 for gunaco
* add METAINFO and add logger
1 parent ca15479 commit e840b6c

11 files changed: +396 -17 lines changed

configs/_base_/datasets/mmlu_fs.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
from datasets import load_dataset
from mmchat.datasets import process_hf_dataset
from mmengine.dataset import DefaultSampler


data_root = 'data/mmlu/'

mmlu_fs_dataset = dict(
    type=load_dataset,
    path='json',
    data_files=dict(
        val=data_root + 'five_shot_mmlu_val.json',
        test=data_root + 'five_shot_mmlu_test.json'))

val_mmlu_fs = dict(
    type=process_hf_dataset,
    dataset=mmlu_fs_dataset,
    mode='val')
val_dataloader = dict(
    batch_size=1,
    num_workers=1,
    dataset=val_mmlu_fs,
    sampler=dict(type=DefaultSampler, shuffle=False))

test_mmlu_fs = dict(
    type=process_hf_dataset,
    dataset=mmlu_fs_dataset,
    mode='test')
test_dataloader = dict(
    batch_size=1,
    num_workers=1,
    dataset=test_mmlu_fs,
    sampler=dict(type=DefaultSampler, shuffle=False))

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

val_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_fs_val')
test_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_fs_test')
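
For context, these dataloaders pass the JSON records through process_hf_dataset largely untouched (the column pruning only applies to a 'train' split, see the change to mmchat/datasets/huggingface.py below), so val/test samples keep their subject field. A plausible shape for one entry of five_shot_mmlu_val.json, inferred from what MMLUMetric later reads and therefore only a hedged sketch, is:

# Hypothetical record layout; field names are inferred from the code in this
# commit (MMLUMetric.process reads data_batch['subject'] and the A/B/C/D
# letter in data_batch['output']); the prompt text is illustrative only.
example_record = {
    'subject': 'abstract_algebra',
    'input': '<five-shot prompt ending with the question and the choices A-D>',
    'output': 'B',
}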

configs/_base_/datasets/mmlu_zs.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
from datasets import load_dataset
from mmchat.datasets import process_hf_dataset
from mmengine.dataset import DefaultSampler


data_root = 'data/mmlu/'

mmlu_zs_dataset = dict(
    type=load_dataset,
    path='json',
    data_files=dict(
        val=data_root + 'zero_shot_mmlu_val.json',
        test=data_root + 'zero_shot_mmlu_test.json'))

val_mmlu_zs = dict(
    type=process_hf_dataset,
    dataset=mmlu_zs_dataset,
    mode='val')
val_dataloader = dict(
    batch_size=1,
    num_workers=1,
    dataset=val_mmlu_zs,
    sampler=dict(type=DefaultSampler, shuffle=False))

test_mmlu_zs = dict(
    type=process_hf_dataset,
    dataset=mmlu_zs_dataset,
    mode='test')
test_dataloader = dict(
    batch_size=1,
    num_workers=1,
    dataset=test_mmlu_zs,
    sampler=dict(type=DefaultSampler, shuffle=False))

val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')

val_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_zs_val')
test_evaluator = dict(type='MMLUMetric', tokenizer=None, prefix='mmlu_zs_test')

configs/guanaco/gunaco_llama_7B.py

Lines changed: 15 additions & 3 deletions
@@ -7,6 +7,7 @@
 import torch
 with read_base():
     from .._base_.datasets.oasst1 import *
+    from .._base_.datasets.mmlu_fs import *
     from .._base_.schedules.guanaco import *
     from .._base_.default_runtime import *

@@ -20,21 +21,21 @@
             use_fast = False,
             padding_side="right",
         ),
-        source_max_len = 16,
+        source_max_len = 2048,
         target_max_len = 512,
         train_on_source = False,
         predict_with_generate = False,
     ),
     llm = dict(
         type=AutoModelForCausalLM.from_pretrained,
         pretrained_model_name_or_path = '/nvme/share_data/llama-7b',
-        torch_dtype = torch.float32,
+        torch_dtype = torch.float16,
         quantization_config=dict(
             type = BitsAndBytesConfig,
             load_in_4bit=True,
             load_in_8bit=False,
             llm_int8_has_fp16_weight=False,
-            bnb_4bit_compute_dtype=torch.float32,
+            bnb_4bit_compute_dtype=torch.float16,
             bnb_4bit_use_double_quant=True,
             bnb_4bit_quant_type = 'nf4'
         )

@@ -50,3 +51,14 @@

 )

+val_evaluator['tokenizer'] = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path='/nvme/share_data/llama-7b',
+    use_fast=False,
+    padding_side="right")
+
+test_evaluator['tokenizer'] = dict(
+    type=AutoTokenizer.from_pretrained,
+    pretrained_model_name_or_path='/nvme/share_data/llama-7b',
+    use_fast=False,
+    padding_side="right")
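
This config wires the MMLU dataloaders, loops and MMLUMetric evaluators together but does not itself show how they are launched; the commit message mentions an added test.py that is not reproduced in the extract above. As a minimal sketch, assuming the file composes into a standard mmengine Runner config (the repo's actual entry point may differ):

from mmengine.config import Config
from mmengine.runner import Runner

# Hedged sketch: build a Runner from the guanaco config and run the TestLoop,
# which iterates test_dataloader and reports MMLUMetric under the
# 'mmlu_fs_test' prefix. The config path is taken from this commit.
cfg = Config.fromfile('configs/guanaco/gunaco_llama_7B.py')
runner = Runner.from_cfg(cfg)
runner.test()   # use runner.val() for the five-shot validation split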

mmchat/datasets/huggingface.py

Lines changed: 4 additions & 3 deletions
@@ -33,9 +33,10 @@ def _prompt_format(example):
         dataset = dataset.rename_column(old, new)

     # Remove unused columns.
-    dataset = dataset.remove_columns(
-        [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
-    )
+    if 'train' in dataset.column_names:
+        dataset = dataset.remove_columns(
+            [col for col in dataset.column_names['train'] if col not in ['input', 'output']]
+        )
     return dataset[mode]

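
The new guard matters because the MMLU configs above register only val and test data_files, so the loaded DatasetDict has no 'train' split and the old column_names['train'] lookup would raise a KeyError. A small illustration (the JSON files are assumed to exist under data/mmlu/):

from datasets import load_dataset

# With only val/test files, column_names is keyed by those two splits alone,
# which is why the code now checks `'train' in dataset.column_names`.
ds = load_dataset(
    'json',
    data_files=dict(
        val='data/mmlu/five_shot_mmlu_val.json',
        test='data/mmlu/five_shot_mmlu_test.json'))
print(ds.column_names)   # {'val': [...], 'test': [...]} -- no 'train' key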

mmchat/evaluation/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .metrics import *
Lines changed: 3 additions & 0 deletions
@@ -0,0 +1,3 @@
from .mmlu_metric import MMLUMetric


__all__ = ['MMLUMetric']
Lines changed: 202 additions & 0 deletions
@@ -0,0 +1,202 @@
from typing import Any, List, Optional, Sequence, Union
from rich.console import Console
from rich.table import Table

import numpy as np
import torch
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger

from mmchat.registry import METRICS, TOKENIZER


@METRICS.register_module()
class MMLUMetric(BaseMetric):
    METAINFO = {
        'subcategories': {
            "abstract_algebra": ["math"],
            "anatomy": ["health"],
            "astronomy": ["physics"],
            "business_ethics": ["business"],
            "clinical_knowledge": ["health"],
            "college_biology": ["biology"],
            "college_chemistry": ["chemistry"],
            "college_computer_science": ["computer science"],
            "college_mathematics": ["math"],
            "college_medicine": ["health"],
            "college_physics": ["physics"],
            "computer_security": ["computer science"],
            "conceptual_physics": ["physics"],
            "econometrics": ["economics"],
            "electrical_engineering": ["engineering"],
            "elementary_mathematics": ["math"],
            "formal_logic": ["philosophy"],
            "global_facts": ["other"],
            "high_school_biology": ["biology"],
            "high_school_chemistry": ["chemistry"],
            "high_school_computer_science": ["computer science"],
            "high_school_european_history": ["history"],
            "high_school_geography": ["geography"],
            "high_school_government_and_politics": ["politics"],
            "high_school_macroeconomics": ["economics"],
            "high_school_mathematics": ["math"],
            "high_school_microeconomics": ["economics"],
            "high_school_physics": ["physics"],
            "high_school_psychology": ["psychology"],
            "high_school_statistics": ["math"],
            "high_school_us_history": ["history"],
            "high_school_world_history": ["history"],
            "human_aging": ["health"],
            "human_sexuality": ["culture"],
            "international_law": ["law"],
            "jurisprudence": ["law"],
            "logical_fallacies": ["philosophy"],
            "machine_learning": ["computer science"],
            "management": ["business"],
            "marketing": ["business"],
            "medical_genetics": ["health"],
            "miscellaneous": ["other"],
            "moral_disputes": ["philosophy"],
            "moral_scenarios": ["philosophy"],
            "nutrition": ["health"],
            "philosophy": ["philosophy"],
            "prehistory": ["history"],
            "professional_accounting": ["other"],
            "professional_law": ["law"],
            "professional_medicine": ["health"],
            "professional_psychology": ["psychology"],
            "public_relations": ["politics"],
            "security_studies": ["politics"],
            "sociology": ["culture"],
            "us_foreign_policy": ["politics"],
            "virology": ["health"],
            "world_religions": ["philosophy"],
        },
        'categories': {
            "STEM": ["physics", "chemistry", "biology", "computer science", "math", "engineering"],
            "humanities": ["history", "philosophy", "law"],
            "social sciences": ["politics", "culture", "economics", "geography", "psychology"],
            "other (business, health, misc.)": ["other", "business", "health"],
        },
    }
    METAINFO['subcategories_list'] = list(set([subcat for subcats in METAINFO['subcategories'].values()
                                               for subcat in subcats]))

    def __init__(self, tokenizer, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.logger: MMLogger = MMLogger.get_current_instance()
        tokenizer = TOKENIZER.build(tokenizer)
        self.abcd_idx = [
            tokenizer("A", add_special_tokens=False).input_ids[0],
            tokenizer("B", add_special_tokens=False).input_ids[0],
            tokenizer("C", add_special_tokens=False).input_ids[0],
            tokenizer("D", add_special_tokens=False).input_ids[0],
        ]

    @staticmethod
    def ABCD_to_0123(abcd):
        return {'A': 0, 'B': 1, 'C': 2, 'D': 3}[abcd]

    @staticmethod
    def accuracy(preds, gts):
        """Computes the accuracy for preds and gts."""
        correct = [1 if pred == gt else 0 for pred, gt in zip(preds, gts)]
        acc = np.mean(correct) * 100
        return acc

    def process(self, data_batch: Any, data_samples: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Any): A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from
                the model.
        """
        subjects = data_batch['subject']
        gts = [self.ABCD_to_0123(gt) for gt in data_batch['output']]
        preds = []
        for sample, subject, gt in zip(data_samples, subjects, gts):
            pred_logits = sample['logits']
            labels = sample['labels']
            labels_non_zero_id = (labels != -100).nonzero()[0][0]
            pred_logits_abcd = pred_logits[labels_non_zero_id - 1, self.abcd_idx]
            pred = torch.argmax(pred_logits_abcd).item()
            preds.append(pred)
            self.results.append((subject, pred, gt))

    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the metrics,
            and the values are corresponding results.
        """
        subjects_results = {subject: {'preds': [], 'gts': []} for subject in self.METAINFO['subcategories'].keys()}
        subcats_results = {subcat: {'preds': [], 'gts': []} for subcat in self.METAINFO['subcategories_list']}
        cats_results = {cat: {'preds': [], 'gts': []} for cat in self.METAINFO['categories'].keys()}
        for subject, pred, gt in results:
            subjects_results[subject]['preds'].append(pred)
            subjects_results[subject]['gts'].append(gt)
            subcats = self.METAINFO['subcategories'][subject]
            for subcat in subcats:
                subcats_results[subcat]['preds'].append(pred)
                subcats_results[subcat]['gts'].append(gt)
        for cat, subcats in self.METAINFO['categories'].items():
            for subcat in subcats:
                if subcat in subcats_results:
                    cats_results[cat]['preds'].extend(subcats_results[subcat]['preds'])
                    cats_results[cat]['gts'].extend(subcats_results[subcat]['gts'])

        subjects_metrics = dict()
        subcats_metrics = dict()
        cats_metrics = dict()
        for subject in self.METAINFO['subcategories'].keys():
            assert len(subjects_results[subject]['preds']) == len(subjects_results[subject]['gts'])
            if len(subjects_results[subject]['preds']) == 0:
                self.logger.info(f'Skip subject {subject} for mmlu')
            else:
                score = self.accuracy(subjects_results[subject]['preds'], subjects_results[subject]['gts'])
                subjects_metrics[f'{subject}'] = score
        for subcat in self.METAINFO['subcategories_list']:
            assert len(subcats_results[subcat]['preds']) == len(subcats_results[subcat]['gts'])
            if len(subcats_results[subcat]['preds']) == 0:
                self.logger.info(f'Skip subcategory {subcat} for mmlu')
            else:
                score = self.accuracy(subcats_results[subcat]['preds'], subcats_results[subcat]['gts'])
                subcats_metrics[f'{subcat}'] = score
        for cat in self.METAINFO['categories'].keys():
            assert len(cats_results[cat]['preds']) == len(cats_results[cat]['gts'])
            if len(cats_results[cat]['preds']) == 0:
                self.logger.info(f'Skip category {cat} for mmlu')
            else:
                score = self.accuracy(cats_results[cat]['preds'], cats_results[cat]['gts'])
                cats_metrics[f'{cat}'] = score

        metrics = dict()
        metrics.update(subjects_metrics)
        metrics.update(subcats_metrics)
        metrics.update(cats_metrics)
        metrics['average'] = np.mean(list(subjects_metrics.values()))

        table_metrics = dict()
        table_metrics.update(cats_metrics)
        table_metrics['average'] = np.mean(list(subjects_metrics.values()))
        self._print_results(table_metrics)
        return metrics

    def _print_results(self, table_metrics: dict) -> None:
        table_title = ' MMLU Benchmark '
        table = Table(title=table_title)
        console = Console()
        table.add_column('Categories', justify='left')
        table.add_column('Accuracy (%)', justify='right')
        for cat, acc in table_metrics.items():
            table.add_row(cat, '{:.1f}'.format(acc))
        with console.capture() as capture:
            console.print(table, end='')
        self.logger.info('\n' + capture.get())
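
The subtle step above is the indexing in process(): for a causal LM, the logits at position i score the token at position i + 1, so the answer letter sitting at the first unmasked label position is predicted by the logits one step earlier, restricted to the four A/B/C/D token ids. A self-contained illustration with made-up values (the real ids come from the tokenizer in __init__):

import torch

# Hypothetical token ids for "A", "B", "C", "D".
abcd_idx = [319, 350, 315, 360]

seq_len, vocab_size = 8, 32000
logits = torch.randn(seq_len, vocab_size)            # per-position next-token logits
labels = torch.full((seq_len,), -100)                 # -100 masks prompt positions
labels[5] = 315                                       # answer token ("C") at position 5

answer_pos = (labels != -100).nonzero()[0][0]         # -> 5, first unmasked label
pred_logits_abcd = logits[answer_pos - 1, abcd_idx]   # logits that predict position 5
pred = torch.argmax(pred_logits_abcd).item()          # 0..3 maps to A..D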

mmchat/evaluation/mmlu.py

Whitespace-only changes.

mmchat/models/algorithms/sft.py

Lines changed: 9 additions & 10 deletions
@@ -57,11 +57,12 @@ def __init__(self, llm, data_preprocessor):
         self.llm = self._build_from_cfg_or_module(llm, LLM)
         self.llm.config.use_cache = False
         self.llm.config.torch_dtype = torch.float32
-        smart_tokenizer_and_embedding_resize(
-            special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
-            tokenizer=self.tokenizer,
-            model=self.llm,
-        )
+        if self.tokenizer._pad_token is None:
+            smart_tokenizer_and_embedding_resize(
+                special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
+                tokenizer=self.tokenizer,
+                model=self.llm,
+            )
         from transformers.models.llama import LlamaTokenizer

         if isinstance(self.tokenizer, LlamaTokenizer):

@@ -110,14 +111,12 @@ def _forward(self, data, data_samples=None):
         return outputs

     def predict(self, data, data_samples=None):
-
         outputs = self.llm(**data)
-
-        return outputs
-
+        logits_dict = [{'labels': labels, 'logits': logits} \
+                       for labels, logits in zip(data['labels'], outputs.logits)]
+        return logits_dict

     def compute_loss(self, data, data_samples=None):
-
         outputs = self.llm(**data)
         # import pdb;pdb.set_trace()
         loss_dict = {'loss_llm': outputs.loss}

mmchat/models/utils/data_processor.py

Lines changed: 1 addition & 1 deletion
@@ -70,5 +70,5 @@ def forward(self,instances: Sequence[Dict], training=True) -> Dict[str, torch.Tensor]:
         if labels is not None:
             data_dict['labels'] = labels

-        return {'data': data_dict, 'data_samples': None}
+        return self.cast_data({'data': data_dict, 'data_samples': None})
