Skip to content

Commit

Permalink
Merge pull request #6190 from JieShenAI/main
Browse files Browse the repository at this point in the history
add vllm_infer script
  • Loading branch information
hiyouga authored Dec 4, 2024
2 parents 263cb82 + 4c61368 commit dc78355
Show file tree
Hide file tree
Showing 2 changed files with 362 additions and 0 deletions.
223 changes: 223 additions & 0 deletions scripts/async_call_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pip install langchain langchain_openai

import os
import sys
import json
import asyncio


import fire
from tqdm import tqdm
from dataclasses import dataclass
from aiolimiter import AsyncLimiter
from typing import List
import pandas as pd
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv

from llamafactory.hparams import get_train_args
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.data.loader import _get_merged_dataset

load_dotenv()


class AsyncLLM:
def __init__(
self,
model: str = "gpt-3.5-turbo",
base_url: str = "http://localhost:{}/v1/".format(
os.environ.get("API_PORT", 8000)
),
api_key: str = "{}".format(os.environ.get("API_KEY", "0")),
num_per_second: int = 6,
**kwargs,
):
self.model = model
self.base_url = base_url
self.api_key = api_key
self.num_per_second = num_per_second

# 创建限速器,每秒最多发出 5 个请求
self.limiter = AsyncLimiter(self.num_per_second, 1)

self.llm = ChatOpenAI(
model=self.model, base_url=self.base_url, api_key=self.api_key, **kwargs
)

async def __call__(self, text):
# 限速
async with self.limiter:
return await self.llm.ainvoke([text])


llm = AsyncLLM(
base_url="http://localhost:{}/v1/".format(os.environ.get("API_PORT", 8000)),
api_key="{}".format(os.environ.get("API_KEY", "0")),
num_per_second=10,
)
llms = [llm]


@dataclass
class AsyncAPICall:
uid: str = "0"

@staticmethod
async def _run_task_with_progress(task, pbar):
result = await task
pbar.update(1)
return result

@staticmethod
def async_run(
llms: List[AsyncLLM],
data: List[str],
keyword: str = "",
output_dir: str = "output",
chunk_size=500,
) -> List[str]:

async def infer_chunk(llms: List[AsyncLLM], data: List):
"""
逐块进行推理,为避免处理庞大数据时,程序崩溃导致已推理数据丢失
"""
results = [llms[i % len(llms)](text) for i, text in enumerate(data)]

with tqdm(total=len(results)) as pbar:
results = await asyncio.gather(
*[
AsyncAPICall._run_task_with_progress(task, pbar)
for task in results
]
)
return results

idx = 0
all_df = []
file_exist_skip = False
user_confirm = False

while idx < len(data):
file_path = os.path.join(output_dir, "tmp", f"{idx}.csv.temp")

if os.path.exists(file_path):
if not user_confirm:
while True:
user_response = input(
f"Find {file_path} file already exists. Do you want to skip them forever?\ny or Y to skip, n or N to rerun to overwrite: "
)
if user_response.lower() == "y":
user_confirm = True
file_exist_skip = True
break
elif user_response.lower() == "n":
user_confirm = True
file_exist_skip = False
break

if file_exist_skip:
tmp_df = pd.read_csv(file_path)
all_df.append(tmp_df)
idx += chunk_size
continue

tmp_data = data[idx : idx + chunk_size]
loop = asyncio.get_event_loop()
tmp_result = loop.run_until_complete(infer_chunk(llms=llms, data=tmp_data))
tmp_result = [item.content for item in tmp_result]

tmp_df = pd.DataFrame({"infer": tmp_result})

if not os.path.exists(p := os.path.dirname(file_path)):
os.makedirs(p, exist_ok=True)

tmp_df.to_csv(file_path, index=False)
all_df.append(tmp_df)
idx += chunk_size

all_df = pd.concat(all_df)
return all_df["infer"]


def async_api_infer(
model_name_or_path: str = "",
eval_dataset: str = "",
template: str = "",
dataset_dir: str = "data",
do_predict: bool = True,
predict_with_generate: bool = True,
max_samples: int = None,
output_dir: str = "output",
chunk_size=50,
):

if len(sys.argv) == 1:
model_args, data_args, training_args, finetuning_args, generating_args = (
get_train_args(
dict(
model_name_or_path=model_name_or_path,
dataset_dir=dataset_dir,
eval_dataset=eval_dataset,
template=template,
output_dir=output_dir,
do_predict=True,
predict_with_generate=True,
max_samples=max_samples,
)
)
)
else:
model_args, data_args, training_args, finetuning_args, generating_args = (
get_train_args()
)

dataset = _get_merged_dataset(
data_args.eval_dataset, model_args, data_args, training_args, "sft"
)

labels = [item[0]["content"] for item in dataset["_response"]]
prompts = [item[0]["content"] for item in dataset["_prompt"]]

infers = AsyncAPICall.async_run(
llms,
prompts,
chunk_size=chunk_size,
output_dir=training_args.output_dir,
)

if not os.path.exists(training_args.output_dir):
os.makedirs(training_args.output_dir, exist_ok=True)

output_prediction_file = os.path.join(
training_args.output_dir, "generated_predictions.jsonl"
)

with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for text, pred, label in zip(prompts, infers, labels):
res.append(
json.dumps(
{"prompt": text, "predict": pred, "label": label},
ensure_ascii=False,
)
)
writer.write("\n".join(res))


if __name__ == "__main__":
fire.Fire(async_api_infer)
139 changes: 139 additions & 0 deletions scripts/vllm_infer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import sys
from typing import List

import fire
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.model import load_tokenizer


max_tokens = 2048


def vllm_infer(
model_name_or_path: str = None,
adapter_name_or_path: str = None,
dataset_dir: str = "data",
eval_dataset: str = None,
template: str = "default",
max_sample: int = None,
preprocessing_num_workers: int = 16,
predict_with_generate: bool = True,
do_predict: bool = True,
temperature: float = 0.7,
top_p: float = 0.7,
top_k: float = 50,
output_dir: str = "output",
):

if len(sys.argv) == 1:
model_args, data_args, training_args, finetuning_args, generating_args = (
get_train_args(
dict(
model_name_or_path=model_name_or_path,
adapter_name_or_path=adapter_name_or_path,
dataset_dir=dataset_dir,
eval_dataset=eval_dataset,
template=template,
max_sample=max_sample,
preprocessing_num_workers=preprocessing_num_workers,
predict_with_generate=predict_with_generate,
do_predict=do_predict,
temperature=temperature,
top_p=top_p,
top_k=top_k,
output_dir=output_dir,
)
)
)
else:
model_args, data_args, training_args, finetuning_args, generating_args = (
get_train_args()
)

tokenizer = load_tokenizer(model_args)["tokenizer"]
template = get_template_and_fix_tokenizer(tokenizer, data_args)

eval_dataset = get_dataset(
template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
)["eval_dataset"]

prompts = [item["input_ids"] for item in eval_dataset]
prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)

labels = [
list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
for item in eval_dataset
]
labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

sampling_params = SamplingParams(
temperature=generating_args.temperature,
top_k=generating_args.top_k,
top_p=generating_args.top_p,
max_tokens=max_tokens,
)

if model_args.adapter_name_or_path:
if isinstance(model_args.adapter_name_or_path, list):
lora_path = model_args.adapter_name_or_path[0]
else:
lora_path = model_args.adapter_name_or_path

lora_requests = LoRARequest("lora_adapter_0", 0, lora_path=lora_path)
enable_lora = True
else:
lora_requests = None
enable_lora = False

llm = LLM(
model=model_args.model_name_or_path,
trust_remote_code=True,
tokenizer=model_args.model_name_or_path,
enable_lora=enable_lora,
)

outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)

if not os.path.exists(training_args.output_dir):
os.makedirs(training_args.output_dir, exist_ok=True)

output_prediction_file = os.path.join(
training_args.output_dir, "generated_predictions.jsonl"
)

with open(output_prediction_file, "w", encoding="utf-8") as writer:
res: List[str] = []
for text, pred, label in zip(prompts, outputs, labels):
res.append(
json.dumps(
{"prompt": text, "predict": pred.outputs[0].text, "label": label},
ensure_ascii=False,
)
)
writer.write("\n".join(res))


if __name__ == "__main__":
fire.Fire(vllm_infer)

0 comments on commit dc78355

Please sign in to comment.