make format happy
Coobiw committed Oct 20, 2024
1 parent b63ff54 commit f4bc35c
Showing 9 changed files with 342 additions and 189 deletions.
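The changes below are consistent with an auto-formatter pass: double quotes, wrapped signatures, two spaces before inline comments, reordered imports. As an illustration only — the commit message does not name its tools — black's Python API reproduces the quote normalization seen throughout this diff:

# Illustration, not part of the commit: black is assumed from the style of the
# changes ("make format happy" names no formatter). Requires: pip install black
import black

src = "save_dir = './datasets/code_sft'\n"
print(black.format_str(src, mode=black.Mode()), end="")
# -> save_dir = "./datasets/code_sft"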
5 changes: 3 additions & 2 deletions aria/data.py
@@ -22,9 +22,10 @@
 from typing import Dict, Iterable, List
 
 import torch
-from datasets import DatasetDict, concatenate_datasets, load_dataset
 from datasets.features import Features, Sequence, Value
+
+from datasets import DatasetDict, concatenate_datasets, load_dataset
 
 
 def apply_chat_template_and_tokenize(
     messages_batch: List[List[Dict]],
@@ -96,7 +97,7 @@ def create_target(role, input_id):
         if pad_length > 0:
             input_ids[i] = input_ids[i] + [tokenizer.pad_token_id] * pad_length
             targets[i] = targets[i] + [IGNORE_TOKEN_ID] * pad_length
-        else: # truncate
+        else:  # truncate
             input_ids[i] = input_ids[i][:max_batch_len]
             targets[i] = targets[i][:max_batch_len]

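The hunk above only re-spaces a comment, but it sits in the batch padding logic of apply_chat_template_and_tokenize; a minimal standalone sketch of that pad-or-truncate step (the constant values are assumptions, not taken from aria):

# Standalone sketch of the logic in the hunk above: pad every sequence to
# max_batch_len, or truncate it when it is already longer.
IGNORE_TOKEN_ID = -100  # assumed: the usual "ignore" label id for loss masking
PAD_TOKEN_ID = 0        # hypothetical stand-in for tokenizer.pad_token_id

def pad_or_truncate(input_ids, targets, max_batch_len):
    for i in range(len(input_ids)):
        pad_length = max_batch_len - len(input_ids[i])
        if pad_length > 0:
            input_ids[i] = input_ids[i] + [PAD_TOKEN_ID] * pad_length
            targets[i] = targets[i] + [IGNORE_TOKEN_ID] * pad_length
        else:  # truncate
            input_ids[i] = input_ids[i][:max_batch_len]
            targets[i] = targets[i][:max_batch_len]
    return input_ids, targets

print(pad_or_truncate([[5, 6], [7, 8, 9, 10]], [[5, 6], [7, 8, 9, 10]], 3))
# -> ([[5, 6, 0], [7, 8, 9]], [[5, 6, -100], [7, 8, 9]])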
18 changes: 14 additions & 4 deletions aria/train.py
@@ -114,7 +114,13 @@ def setup_peft(model, model_config):
     return model
 
 
-def collate_fn(examples, tokenizer, processor, split_image: bool = False, max_seq_length: int = 1024):
+def collate_fn(
+    examples,
+    tokenizer,
+    processor,
+    split_image: bool = False,
+    max_seq_length: int = 1024,
+):
     images = []
     messages = []
     for example in examples:
@@ -172,7 +178,7 @@ def collate_fn(examples, tokenizer, processor, split_image: bool = False, max_seq_length: int = 1024):
message["content"].insert(cont_idx + img_i, insert_item)
messages.append(example["messages"])
else:
if example['images']:
if example["images"]:
images.extend(example["images"])
messages.append(example["messages"])

@@ -192,7 +198,7 @@ def collate_fn(examples, tokenizer, processor, split_image: bool = False, max_seq_length: int = 1024):

         batch.update(image_inputs)
         batch["pixel_values"] = batch["pixel_values"].to(torch.bfloat16)
-    else: # text-only
+    else:  # text-only
         batch = apply_chat_template_and_tokenize(
             messages,
             tokenizer,
@@ -225,7 +231,11 @@ def main():
         model=model,
         args=training_args,
         data_collator=lambda examples: collate_fn(
-            examples, tokenizer, processor, model_config.split_image, training_args.max_seq_length
+            examples,
+            tokenizer,
+            processor,
+            model_config.split_image,
+            training_args.max_seq_length,
         ),
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
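The last hunk reflows the data_collator lambda; the pattern itself — a lambda closing over the tokenizer and config so the trainer only supplies examples — can be sketched with hypothetical stand-ins:

# Sketch of the data_collator pattern (all names are stand-ins, not aria code):
# the trainer calls data_collator(examples); everything else is captured by the
# closure, which keeps collate_fn's growing signature out of the Trainer.
def collate(examples, tokenizer, max_seq_length=1024):
    return {"n_examples": len(examples), "max_seq_length": max_seq_length}

tokenizer = object()  # placeholder for a real tokenizer
data_collator = lambda examples: collate(examples, tokenizer, max_seq_length=2048)
print(data_collator([{"messages": []}]))  # {'n_examples': 1, 'max_seq_length': 2048}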
11 changes: 6 additions & 5 deletions examples/code_sft/evaluation.py
@@ -1,11 +1,12 @@
-from human_eval.evaluation import evaluate_functional_correctness
 import json
 import argparse
 
-
 import os
 
-
+from human_eval.evaluation import evaluate_functional_correctness
+
 parser = argparse.ArgumentParser()
-parser.add_argument("--save_root", required=True, type=str, help="the result directory of humaneval")
+parser.add_argument(
+    "--save_root", required=True, type=str, help="the result directory of humaneval"
+)
 args = parser.parse_args()
 
 tmp_dir = os.path.join(args.save_root, "tmp")
@@ -20,4 +21,4 @@
language="python",
)

print(result)
print(result)
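The script defines a single required flag; a usage sketch, assuming HumanEval generation results already live under --save_root (the path here is hypothetical):

# Hypothetical invocation of the script above; --save_root is the only flag
# its argument parser defines.
import subprocess

subprocess.run(
    ["python", "examples/code_sft/evaluation.py", "--save_root", "./results/humaneval"],
    check=True,
)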
33 changes: 19 additions & 14 deletions examples/code_sft/get_data.py
@@ -1,28 +1,33 @@
-import os
 import json
+import os
+
 from tqdm import trange
 
 from datasets import load_dataset
 
-save_dir = './datasets/code_sft'
+save_dir = "./datasets/code_sft"
 os.makedirs(save_dir, exist_ok=True)
 
-dataset = load_dataset("ise-uiuc/Magicoder-Evol-Instruct-110K")['train']
+dataset = load_dataset("ise-uiuc/Magicoder-Evol-Instruct-110K")["train"]
 l_ds = len(dataset)
 
-with open(os.path.join(save_dir, "train.jsonl"),"w") as fo:
+with open(os.path.join(save_dir, "train.jsonl"), "w") as fo:
     for i in trange(l_ds):
-        instruction, response = dataset[i]['instruction'], dataset[i]['response']
+        instruction, response = dataset[i]["instruction"], dataset[i]["response"]
         item = {
             "messages": [
-                {
-                    "content": [{"text": instruction, "type": "text"},],
-                    "role": "user"
-                },
-                {
-                    "content": [{"text": response, "type": "text"},],
-                    "role": "assistant"
-                },
-            ],
+                {
+                    "content": [
+                        {"text": instruction, "type": "text"},
+                    ],
+                    "role": "user",
+                },
+                {
+                    "content": [
+                        {"text": response, "type": "text"},
+                    ],
+                    "role": "assistant",
+                },
+            ],
         }
         fo.write(f"{json.dumps(item, ensure_ascii=False)}\n")
15 changes: 7 additions & 8 deletions examples/code_sft/human_eval/data.py
@@ -1,8 +1,7 @@
-from typing import Iterable, Dict
 import gzip
 import json
 import os
-
+from typing import Dict, Iterable
 
 ROOT = os.path.dirname(os.path.abspath(__file__))
 HUMAN_EVAL = os.path.join(ROOT, "..", "data", "HumanEval.jsonl.gz")
@@ -18,7 +17,7 @@ def stream_jsonl(filename: str) -> Iterable[Dict]:
"""
if filename.endswith(".gz"):
with open(filename, "rb") as gzfp:
with gzip.open(gzfp, 'rt') as fp:
with gzip.open(gzfp, "rt") as fp:
for line in fp:
if any(not x.isspace() for x in line):
yield json.loads(line)
@@ -34,16 +33,16 @@ def write_jsonl(filename: str, data: Iterable[Dict], append: bool = False):
     Writes an iterable of dictionaries to jsonl
     """
     if append:
-        mode = 'ab'
+        mode = "ab"
     else:
-        mode = 'wb'
+        mode = "wb"
     filename = os.path.expanduser(filename)
     if filename.endswith(".gz"):
         with open(filename, mode) as fp:
-            with gzip.GzipFile(fileobj=fp, mode='wb') as gzfp:
+            with gzip.GzipFile(fileobj=fp, mode="wb") as gzfp:
                 for x in data:
-                    gzfp.write((json.dumps(x) + "\n").encode('utf-8'))
+                    gzfp.write((json.dumps(x) + "\n").encode("utf-8"))
     else:
         with open(filename, mode) as fp:
             for x in data:
-                fp.write((json.dumps(x) + "\n").encode('utf-8'))
+                fp.write((json.dumps(x) + "\n").encode("utf-8"))
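write_jsonl and stream_jsonl are symmetric, including the gzip path touched above; a round-trip sketch (the import path assumes this module is importable as human_eval.data, matching the repo layout):

# Round-trip sketch: write_jsonl gzip-compresses when the name ends in .gz,
# and stream_jsonl transparently reads it back line by line.
from human_eval.data import stream_jsonl, write_jsonl

records = [{"task_id": "HumanEval/0", "completion": "    return x\n"}]
write_jsonl("samples.jsonl.gz", records)
assert list(stream_jsonl("samples.jsonl.gz")) == records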
@@ -1,7 +1,7 @@
-import fire
 import sys
 
 from .data import HUMAN_EVAL
+import fire
 
 from .evaluation import evaluate_functional_correctness
 
@@ -18,7 +18,9 @@ def entry_point(
     results to f"{sample_file}_results.jsonl.gz"
     """
     k = list(map(int, k.split(",")))
-    results = evaluate_functional_correctness(sample_file, k, n_workers, timeout, problem_file, is_mbpp)
+    results = evaluate_functional_correctness(
+        sample_file, k, n_workers, timeout, problem_file, is_mbpp
+    )
     print(results)
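entry_point parses k from a comma-separated string before delegating to evaluate_functional_correctness; a hedged sketch of calling it directly (the module path and file name are assumptions — only the k parsing is shown in the hunk):

# Hypothetical direct call; fire exposes the same entry point on the command
# line. Parameters other than sample_file and k are assumed to have defaults.
from human_eval.evaluate_functional_correctness import entry_point

entry_point(sample_file="samples.jsonl", k="1,10")  # k is split on "," as above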


