add llama3 tasks #2556

Open · wants to merge 15 commits into base: main
44 changes: 44 additions & 0 deletions lm_eval/tasks/llama3/README.md
# Llama 3 Tasks

### Paper

Title: ``

Abstract: ``


Homepage: ``


### Citation

```

```

### Groups, Tags, and Tasks

#### Groups

* `mgsm_chat`: size-weighted aggregate over the per-language `mgsm_chat_*` tasks.

#### Tags

* `llama`: tags the base (non-chat) tasks.
* `llama3`: tags the chat-template MGSM tasks.

#### Tasks

* `llama_arc_challenge`: 25-shot multiple-choice ARC-Challenge.
* `mgsm_chat_{bn,de,en,es,fr,ja,ru,sw,te,th,zh}`: 0-shot MGSM benchmark, one task per language; intended to be run with a chat template.

### Checklist

For adding novel benchmarks/datasets to the library:
* [x] Is the task an existing benchmark in the literature?
* [x] Have you referenced the original paper that introduced the task?
* [x] If yes, does the original paper provide a reference implementation? If so, have you checked against the reference implementation and documented how to run such a test?


If other tasks on this dataset are already supported:
* [x] Is the "Main" variant of this task clearly denoted?
* [x] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [x] Have you noted which, if any, published evaluation setups are matched by this variant?
24 changes: 24 additions & 0 deletions lm_eval/tasks/llama3/base/arc_challenge.yaml
tag:
  - llama
task: llama_arc_challenge
dataset_path: allenai/ai2_arc
dataset_name: ARC-Challenge
output_type: multiple_choice
training_split: train
validation_split: validation
test_split: test
fewshot_split: train
doc_to_text: "Question: {{question.strip()}}\nA. {{choices.text[0]}}\nB. {{choices.text[1]}}\nC. {{choices.text[2]}}{% if choices.text|length > 3 %}\nD. {{choices.text[3]}}{% endif %}\nAnswer:"
fewshot_delimiter: "\n\n"
doc_to_target: "{{ 'ABCD'[answerKey|int - 1] if answerKey|string in '1234' else answerKey }}"
doc_to_choice: "{{ choices.label|map('replace', '1', 'A')|map('replace', '2', 'B')|map('replace', '3', 'C')|map('replace', '4', 'D')|list if choices.label[0] in '1234' else choices.label }}"
num_fewshot: 25
metric_list:
  - metric: acc
    aggregation: mean
    higher_is_better: true
  - metric: acc_norm
    aggregation: mean
    higher_is_better: true
metadata:
  version: 1.0
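The `doc_to_target` template above is the fiddliest line: ARC answer keys are sometimes digits (`1`-`4`) and sometimes letters. As a sanity check, the same logic in plain Python (the `arc_target` helper is hypothetical, for illustration only):

```python
def arc_target(answer_key: str) -> str:
    # Mirrors the Jinja doc_to_target: numeric keys '1'-'4' map to 'A'-'D';
    # letter keys pass through unchanged.
    return "ABCD"[int(answer_key) - 1] if answer_key in "1234" else answer_key


print(arc_target("3"))  # → C
print(arc_target("B"))  # → B
```

`doc_to_choice` applies the same digit-to-letter remapping to the choice labels so targets and choices stay aligned.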
15 changes: 15 additions & 0 deletions lm_eval/tasks/llama3/base/utils.py
import datasets


def process_arc_c_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    COLUMNS = dataset.column_names

    def map_(doc):
        # Drop the last two characters of the pre-rendered prompt, then strip
        # surrounding whitespace.
        doc["doc_to_text"] = doc["input_final_prompts"][0].strip()[:-2].strip()
        # Reduce "Answer: X"-style completions to bare choice strings.
        doc["doc_to_choice"] = [
            x.replace("Answer:", "").strip() for x in doc["output_choice_completions"]
        ]
        # The final character of the correct response is the target letter.
        doc["doc_to_target"] = doc["input_correct_responses"][0].strip()[-1]
        return doc

    return dataset.map(map_, remove_columns=COLUMNS)
21 changes: 21 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat.yaml
group: mgsm_chat
group_alias: mgsm (llama)
task:
  - mgsm_chat_bn
  - mgsm_chat_de
  - mgsm_chat_en
  - mgsm_chat_es
  - mgsm_chat_fr
  - mgsm_chat_ja
  - mgsm_chat_ru
  - mgsm_chat_sw
  - mgsm_chat_te
  - mgsm_chat_th
  - mgsm_chat_zh
aggregate_metric_list:
  - metric: exact_match
    aggregation: mean
    weight_by_size: True
    filter_list: [flexible-extract]
metadata:
  version: 0
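`weight_by_size: True` means the group score is not a plain average of the eleven per-language scores but a document-count-weighted mean. A minimal sketch of that aggregation (the `weighted_mean` helper is illustrative, not the harness's implementation):

```python
def weighted_mean(scores, sizes):
    # Each subtask contributes in proportion to its number of documents,
    # which is equivalent to pooling all documents and taking one mean.
    total = sum(sizes)
    return sum(score * n for score, n in zip(scores, sizes)) / total


# Two subtasks: 0.5 over 100 docs, 1.0 over 300 docs.
print(weighted_mean([0.5, 1.0], [100, 300]))  # → 0.875
```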
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_bn.yaml
include: mgsm_chat_template
dataset_name: bn # overrides the template default
process_docs: !function utils.process_docs_bn
task: mgsm_chat_bn
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_de.yaml
include: mgsm_chat_template
dataset_name: de # overrides the template default
process_docs: !function utils.process_docs_de
task: mgsm_chat_de
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_en.yaml
include: mgsm_chat_template
dataset_name: en # overrides the template default
process_docs: !function utils.process_docs_en
task: mgsm_chat_en
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_es.yaml
include: mgsm_chat_template
dataset_name: es # overrides the template default
process_docs: !function utils.process_docs_es
task: mgsm_chat_es
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_fr.yaml
include: mgsm_chat_template
dataset_name: fr # overrides the template default
process_docs: !function utils.process_docs_fr
task: mgsm_chat_fr
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_ja.yaml
include: mgsm_chat_template
dataset_name: ja # overrides the template default
process_docs: !function utils.process_docs_ja
task: mgsm_chat_ja
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_ru.yaml
include: mgsm_chat_template
dataset_name: ru # overrides the template default
process_docs: !function utils.process_docs_ru
task: mgsm_chat_ru
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_sw.yaml
include: mgsm_chat_template
dataset_name: sw # overrides the template default
process_docs: !function utils.process_docs_sw
task: mgsm_chat_sw
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_te.yaml
include: mgsm_chat_template
dataset_name: te # overrides the template default
process_docs: !function utils.process_docs_te
task: mgsm_chat_te
34 changes: 34 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_template
# This file is included by the generated language-specific task configs.
# It has no .yaml extension so the harness does not load it as a
# standalone task config.
tag: llama3
dataset_path: juletxara/mgsm
dataset_name: null # Overridden by language-specific config.
output_type: generate_until
training_split: train
test_split: test
target_delimiter: ""
doc_to_text: "{{question}}"
doc_to_target: answers # list of accepted answer strings
process_results: !function utils.process_results
generation_kwargs:
  until: []
  do_sample: false
  temperature: 0.0
  max_gen_toks: 2048
metric_list:
  - metric: exact_match
    aggregation: mean
    higher_is_better: true
    ignore_case: true
    ignore_punctuation: true # not applied here; punctuation is stripped manually in process_results
filter_list:
  - name: "flexible-extract"
    filter:
      - function: regex
        group_select: -1
        regex_pattern: "(?:Answer|Réponse|Antwort|Ответ|Respuesta|答え|Jibu|答案|คำตอบ|సమాధానం|উত্তর)[::] (-?[$0-9.,]{2,})|(-?[0-9]+)"
      - function: remove_whitespace
      - function: take_first
metadata:
  version: 0.0
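The `flexible-extract` filter above can be exercised in isolation. A rough approximation of the harness's regex filter with `group_select: -1` (the `flexible_extract` helper is hypothetical; the real filter also applies `remove_whitespace` and `take_first`):

```python
import re

# Essentially the regex_pattern above: a localized "Answer:" cue (ASCII or
# fullwidth colon) followed by a number, or any bare integer as a fallback.
PATTERN = (
    r"(?:Answer|Réponse|Antwort|Ответ|Respuesta|答え|Jibu|答案|คำตอบ|సమాధానం|উত্তর)"
    r"[:：] (-?[$0-9.,]{2,})|(-?[0-9]+)"
)


def flexible_extract(completion: str) -> str:
    matches = re.findall(PATTERN, completion)
    if not matches:
        return ""
    # group_select: -1 takes the last match; keep its first non-empty group.
    return next((group for group in matches[-1] if group), "")


print(flexible_extract("There are 6 x 7 = 42 apples.\nAnswer: 42"))  # → 42
```

Because the bare-integer alternative also fires on intermediate numbers, selecting the *last* match is what makes the filter prefer the final "Answer:" line.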
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_th.yaml
include: mgsm_chat_template
dataset_name: th # overrides the template default
process_docs: !function utils.process_docs_th
task: mgsm_chat_th
4 changes: 4 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/mgsm_chat_zh.yaml
include: mgsm_chat_template
dataset_name: zh # overrides the template default
process_docs: !function utils.process_docs_zh
task: mgsm_chat_zh
125 changes: 125 additions & 0 deletions lm_eval/tasks/llama3/instruct/mgsm/utils.py
import re
import string
from functools import partial
from typing import TYPE_CHECKING, Dict, List


if TYPE_CHECKING:
    import datasets

from lm_eval.api.metrics import exact_match_fn


TRANSLATE_TABLE = str.maketrans(
    "", "", string.punctuation.replace(".", "")
)  # decimals are handled by the number_variations function

# extracted from https://huggingface.co/datasets/meta-llama/Llama-3.2-3B-Instruct-evals/viewer/Llama-3.2-3B-Instruct-evals__mgsm__details
PROMPTS = [
    {
        "rep": 'Solve this math problem. Give the reasoning steps before giving the final answer on the last line by itself in the format of "Answer:". Do not add anything other than the integer answer after "Answer:".',
        "subtask_name": "en",
    },
    {
        "rep": 'Решите эту математическую задачу. Объясните шаги рассуждения перед тем, как дать окончательный ответ в последней строке сам по себе в формате "Ответ:". Не добавляйте ничего, кроме целочисленного ответа после "Ответ:".',
        "subtask_name": "ru",
    },
    {
        "rep": 'Suluhisha tatizo hili la hesabu. Toa hatua za mantiki kabla ya kutoa jibu la mwisho kwenye mstari wa mwisho peke yake katika muundo wa "Jibu:". Usiongeze chochote kingine isipokuwa jibu la integer baada ya "Jibu:".',
        "subtask_name": "sw",
    },
    {
        "rep": 'Résolvez ce problème de mathématiques. Donnez les étapes de raisonnement avant de fournir la réponse finale sur la dernière ligne elle-même dans le format de "Réponse:". N\'ajoutez rien d\'autre que la réponse entière après "Réponse:".',
        "subtask_name": "fr",
    },
    {
        "rep": "ఈ గణిత సమస్యను పరిష్కరించండి. చివరి సమాధానాన్ని ఇవ్వదానికి ముందు తర్కాత్మక అదుగులను ఇవ్వండి. చివరి పంక్తిలో మాత్రమే 'సమాధానం:' అనే ఆకారంలో చివరి సమాధానాద్ని ఇవ్వండి సమాధానం: తర్వాత పూర్ణాంక సమాధానానికి తప్పించి ఎదేనా చేర్చవద్దు.",
        "subtask_name": "te",
    },
    {
        "rep": 'แก้ปัญหาคณิตศาสตร์นี้ ให้ให้ขั้นตอนการใช้เหตุผลก่อนที่จะให้คำตอบสุดท้ายในบรรทัดสุดท้ายโดยอยู่ในรูปแบบ "คำตอบ:" ไม่ควรเพิ่มอะไรนอกจากคำตอบที่เป็นจำนวนเต็มหลังจาก "คำตอบ:',
        "subtask_name": "th",
    },
    {
        "rep": 'の数学の問題を解いてください。最終的な答えを出す前に、解答の推論過程を記述してください。そして最後の行には "答え:" の形式で答えを記述し、その後には整数の答え以外何も追加しないでください。',
        "subtask_name": "ja",
    },
    {
        "rep": 'Löse dieses Mathematikproblem. Gib die Schritte zur Begründung an, bevor du die endgültige Antwort in der letzten Zeile alleine im Format "Antwort:" gibst. Füge nichts anderes als die ganzzahlige Antwort nach "Antwort:" hinzu.',
        "subtask_name": "de",
    },
    {
        "rep": 'এই গণিতের সমস্যাটি সমাধান করুন। চূড়ান্ত উত্তর দেওয়ার আগে যুক্তিসম্পন্ন পদক্ষেপ প্রদান করুন। চূড়ান্ত উত্তরটি একক সংখ্যা হিসাবে "উত্তর:" এর পরে শেষ লাইনে দিন। "উত্তর:" এর পরে অন্য কিছু যুক্ত করবেন না।.',
        "subtask_name": "bn",
    },
    {
        "rep": '解决这个数学问题。在最后一行给出答案前,请提供推理步骤。最后一行应该以 "答案: " 的形式独立给出答案。在 "答案:" 后不要添加除整数答案之外的任何内容。',
        "subtask_name": "zh",
    },
    {
        "rep": 'Resuelve este problema matemático. Proporciona los pasos de razonamiento antes de dar la respuesta final en la última línea por sí misma en el formato de "Respuesta:". No añadas nada más que la respuesta entera después de "Respuesta:".',
        "subtask_name": "es",
    },
]


def number_variations(n: int) -> List[str]:
    formats = []
    # Generate each pattern twice
    for _ in range(2):
        # Basic string representation
        formats.append(str(n))
        formats.append(f"{n}.")

        # With one decimal place
        formats.append(f"{n}.0")
        formats.append(f"{n}.0.")

        # With two decimal places
        formats.append(f"{n}.00")
        formats.append(f"{n}.00.")

    return formats


def process_docs(lang: str, df: "datasets.Dataset") -> "datasets.Dataset":
    # Language-specific instruction prompt, prepended to every question.
    prompt = next(x["rep"] for x in PROMPTS if x["subtask_name"] == lang)

    def map_(doc: dict):
        # Drop the localized "Question:" prefix (ASCII or fullwidth colon),
        # then prepend the instruction prompt.
        doc["question"] = (
            prompt
            + "\n\n"
            + re.split("[:：]", doc["question"], maxsplit=1)[-1].strip()
        )
        doc["answers"] = number_variations(doc["answer_number"])
        return doc

    return df.map(map_)


process_docs_bn = partial(process_docs, "bn")
process_docs_de = partial(process_docs, "de")
process_docs_en = partial(process_docs, "en")
process_docs_es = partial(process_docs, "es")
process_docs_fr = partial(process_docs, "fr")
process_docs_ja = partial(process_docs, "ja")
process_docs_ru = partial(process_docs, "ru")
process_docs_sw = partial(process_docs, "sw")
process_docs_te = partial(process_docs, "te")
process_docs_th = partial(process_docs, "th")
process_docs_zh = partial(process_docs, "zh")


def process_results(doc: dict, prediction: List[str]) -> Dict[str, int]:
    gold: List = doc["answers"]
    # Repeat the filtered prediction once per accepted answer format; the item
    # scores 1 if it matches any of them. Punctuation (except ".") is stripped
    # via TRANSLATE_TABLE rather than by exact_match_fn itself.
    return {
        "exact_match": int(
            exact_match_fn(
                predictions=[x.strip().translate(TRANSLATE_TABLE) for x in prediction]
                * len(gold),
                references=gold,
                ignore_case=True,
                ignore_punctuation=False,
            )["exact_match"]
            > 0
        )
    }
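Putting `number_variations`, `TRANSLATE_TABLE`, and `process_results` together, the scoring reduces to: strip non-period punctuation from the prediction, then check it against every accepted rendering of the gold integer. A simplified standalone re-implementation (not the harness's `exact_match_fn`, and without the duplicated format list):

```python
import string


def number_variations(n: int) -> list:
    # Accepted renderings of the gold integer, as in the function above.
    return [f"{n}", f"{n}.", f"{n}.0", f"{n}.0.", f"{n}.00", f"{n}.00."]


# Strip all punctuation except ".", so "$42" and "42" compare equal while
# decimal renderings stay distinguishable.
TABLE = str.maketrans("", "", string.punctuation.replace(".", ""))


def score(prediction: str, answer_number: int) -> int:
    pred = prediction.strip().translate(TABLE)
    return int(pred in number_variations(answer_number))


print(score("$42", 42))    # → 1  ("$" is stripped)
print(score("42.00", 42))  # → 1
print(score("43", 42))     # → 0
```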