Add evaluation script

parent 8673270a4a
commit 0a499b7e9a

@@ -25,7 +25,7 @@ The open-source ChatGLM2-6B model aims to advance the development of large-model technology together with the open-source community

Although every effort was made at each stage of training to ensure data compliance and accuracy, the ChatGLM2-6B model is relatively small in scale and subject to probabilistic randomness, so the accuracy of its output cannot be guaranteed, and the model is easily misled. **This project assumes no liability for data-security or public-opinion risks caused by the open-source model and code, nor for any risks or liabilities arising from the model being misled, misused, disseminated, or improperly exploited.**

## Evaluation Results

- We selected some typical Chinese and English datasets for evaluation. Below are the results of the ChatGLM2-6B model on [MMLU](https://github.com/hendrycks/test) (English), [C-Eval](https://cevalbenchmark.com/static/leaderboard.html) (Chinese), [GSM8K](https://github.com/openai/grade-school-math) (mathematics), and [BBH](https://github.com/suzgunmirac/BIG-Bench-Hard) (English).
+ We selected some typical Chinese and English datasets for evaluation. Below are the results of the ChatGLM2-6B model on [MMLU](https://github.com/hendrycks/test) (English), [C-Eval](https://cevalbenchmark.com/static/leaderboard.html) (Chinese), [GSM8K](https://github.com/openai/grade-school-math) (mathematics), and [BBH](https://github.com/suzgunmirac/BIG-Bench-Hard) (English). A script for evaluating on C-Eval is provided in [evaluation](./evaluation/README.md).

### MMLU
evaluation/README.md (new file)
@@ -0,0 +1,10 @@

First, download the preprocessed C-Eval dataset from [Tsinghua Cloud](https://cloud.tsinghua.edu.cn/f/e84444333b6d434ea7b0) and unzip it into the `evaluation` directory. Then run

```shell
cd evaluation
python evaluate_ceval.py
```
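
Before running, it can help to confirm that the unzipped data landed where the script expects it. The following quick check (not part of the original README) uses the same glob pattern as `evaluate_ceval.py`:

```python
# Sanity check: evaluate_ceval.py globs this exact pattern relative to
# the evaluation directory, so it should match the unzipped dataset.
import glob

files = glob.glob("./CEval/val/**/*.jsonl", recursive=True)
print(f"Found {len(files)} validation files")
assert files, "No .jsonl files found -- is the dataset unzipped under ./CEval?"
```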
This script runs predictions on the C-Eval validation set and prints the accuracy. To obtain results on the test set, change `./CEval/val/**/*.jsonl` in the code to `./CEval/test/**/*.jsonl`, save the results in the format required by C-Eval, and submit them on the [official website](https://cevalbenchmark.com/).
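
C-Eval's documented submission format is, as far as I can tell, a single JSON object mapping each subject to a dict of question id → predicted letter. Below is a minimal sketch (not from the original repo) of collecting test-set predictions that way, assuming each test file is named `<subject>.jsonl` and each record carries an `id` field; verify the exact format against the official C-Eval documentation before submitting:

```python
# Hypothetical helper for assembling a C-Eval submission file.
import os
import json

choices = ["A", "B", "C", "D"]
submission = {}

def record_predictions(entry_path, question_ids, pred_indices):
    """Accumulate one subject's predictions (pred_indices are 0-3 from argmax)."""
    subject = os.path.splitext(os.path.basename(entry_path))[0]  # assumed naming
    answers = submission.setdefault(subject, {})
    for qid, idx in zip(question_ids, pred_indices):
        answers[str(int(qid))] = choices[idx]

# Inside the batch loop of evaluate_ceval.py one could then call
#     record_predictions(entry, batch["id"], preds.cpu().tolist())
# and finally write the collected answers:
#     with open("submission.json", "w", encoding="utf-8") as f:
#         json.dump(submission, f, ensure_ascii=False, indent=2)
```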
The reported results were produced with an internal parallel testing framework, so the numbers may fluctuate slightly.

evaluation/evaluate_ceval.py (new file)
@@ -0,0 +1,60 @@

```python
import os
import glob
import re
import json
import torch
import torch.utils.data
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).bfloat16().cuda()

# Token ids of the four answer letters; used to score the choices directly from logits.
choices = ["A", "B", "C", "D"]
choice_tokens = [tokenizer.encode(choice, add_special_tokens=False)[0] for choice in choices]


def build_prompt(text):
    # Wrap the question in ChatGLM2's single-round dialogue template.
    return "[Round {}]\n\n问:{}\n\n答:".format(1, text)


# Appended after the model's free-form answer to force a single-letter choice.
extraction_prompt = '综上所述,ABCD中正确的选项是:'

accuracy_dict, count_dict = {}, {}
with torch.no_grad():
    for entry in glob.glob("./CEval/val/**/*.jsonl", recursive=True):
        # Load one subject's validation set.
        dataset = []
        with open(entry, encoding='utf-8') as file:
            for line in file:
                dataset.append(json.loads(line))
        correct = 0
        dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
        for batch in tqdm(dataloader):
            texts = batch["inputs_pretokenized"]
            queries = [build_prompt(query) for query in texts]
            inputs = tokenizer(queries, padding=True, return_tensors="pt", truncation=True,
                               max_length=2048).to('cuda')
            # First pass: greedy generation of the model's free-form reasoning.
            outputs = model.generate(**inputs, do_sample=False, max_new_tokens=512)
            intermediate_outputs = []
            for idx in range(len(outputs)):
                # Strip the prompt tokens, keeping only the generated continuation.
                output = outputs.tolist()[idx][len(inputs["input_ids"][idx]):]
                response = tokenizer.decode(output)
                intermediate_outputs.append(response)
            # Second pass: append the extraction prompt and read the choice off the logits.
            answer_texts = [text + intermediate + "\n" + extraction_prompt for text, intermediate in
                            zip(texts, intermediate_outputs)]
            input_tokens = [build_prompt(answer_text) for answer_text in answer_texts]
            inputs = tokenizer(input_tokens, padding=True, return_tensors="pt", truncation=True,
                               max_length=2048).to('cuda')
            outputs = model(**inputs, return_last_logit=True)
            logits = outputs.logits[:, -1]
            # Restrict to the A/B/C/D token ids and pick the most likely letter.
            logits = logits[:, choice_tokens]
            preds = logits.argmax(dim=-1)
            correct += (preds.cpu() == batch["label"]).sum().item()
        accuracy = correct / len(dataset)
        print(entry, accuracy)
        accuracy_dict[entry] = accuracy
        count_dict[entry] = len(dataset)

# Aggregate a sample-weighted average accuracy over all subjects.
acc_total, count_total = 0.0, 0
for key in accuracy_dict:
    acc_total += accuracy_dict[key] * count_dict[key]
    count_total += count_dict[key]
print(acc_total / count_total)
```
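
A note on the design: the script makes two passes per batch. The first generates the model's free-form answer greedily (`do_sample=False`); the second appends the extraction prompt and compares the logits of only the four choice tokens at the final position, so the predicted letter is read off with a single `argmax` rather than parsed out of generated text.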