Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support regex as a stopping condition #2699

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/references/sampling_params.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ The `sampling_params` follows this format
max_new_tokens: int = 128,
# Stop when hitting any of the strings in this list
stop: Optional[Union[str, List[str]]] = None,
# Stop when hitting any of the regex patterns in this list.
stop_regex: Optional[Union[str, List[str]]] = None,
# Stop when hitting any of the token_ids in this list
stop_token_ids: Optional[List[int]] = [],
# Sampling temperature
Expand Down
6 changes: 6 additions & 0 deletions python/sglang/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def gen(
max_tokens: Optional[int] = None,
min_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
Expand Down Expand Up @@ -121,6 +122,7 @@ def gen(
max_tokens,
min_tokens,
stop,
stop_regex,
stop_token_ids,
temperature,
top_p,
Expand All @@ -143,6 +145,7 @@ def gen_int(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
Expand All @@ -161,6 +164,7 @@ def gen_int(
max_tokens,
None,
stop,
stop_regex,
stop_token_ids,
temperature,
top_p,
Expand All @@ -182,6 +186,7 @@ def gen_string(
name: Optional[str] = None,
max_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
Expand All @@ -200,6 +205,7 @@ def gen_string(
max_tokens,
None,
stop,
stop_regex,
stop_token_ids,
temperature,
top_p,
Expand Down
1 change: 1 addition & 0 deletions python/sglang/lang/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,7 @@ def _resolve_sampling_params(self, sampling_params):
"max_new_tokens",
"min_new_tokens",
"stop",
"stop_regex",
"stop_token_ids",
"temperature",
"top_p",
Expand Down
13 changes: 13 additions & 0 deletions python/sglang/lang/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class SglSamplingParams:
max_new_tokens: int = 128
min_new_tokens: int = 0
stop: Union[str, List[str]] = ()
stop_regex: Union[str, List[str]] = ()
stop_token_ids: Optional[List[int]] = ()
temperature: float = 1.0
top_p: float = 1.0
Expand All @@ -42,6 +43,7 @@ def clone(self):
self.max_new_tokens,
self.min_new_tokens,
self.stop,
self.stop_regex,
self.stop_token_ids,
self.temperature,
self.top_p,
Expand Down Expand Up @@ -117,6 +119,7 @@ def to_srt_kwargs(self):
"max_new_tokens": self.max_new_tokens,
"min_new_tokens": self.min_new_tokens,
"stop": self.stop,
"stop_regex": self.stop_regex,
"stop_token_ids": self.stop_token_ids,
"temperature": self.temperature,
"top_p": self.top_p,
Expand Down Expand Up @@ -154,6 +157,7 @@ def run(
*args,
max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
Expand All @@ -178,10 +182,13 @@ def run(
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []

default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
Expand Down Expand Up @@ -212,6 +219,7 @@ def run_batch(
*,
max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
Expand All @@ -234,6 +242,8 @@ def run_batch(
stop = []
if stop_token_ids is None:
stop_token_ids = []
if stop_regex is None:
stop_regex = []

assert isinstance(batch_kwargs, (list, tuple))
if len(batch_kwargs) == 0:
Expand All @@ -256,6 +266,7 @@ def run_batch(
default_sampling_para = SglSamplingParams(
max_new_tokens=max_new_tokens,
stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
Expand Down Expand Up @@ -438,6 +449,7 @@ def __init__(
max_new_tokens: Optional[int] = None,
min_new_tokens: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
Expand All @@ -461,6 +473,7 @@ def __init__(
max_new_tokens=max_new_tokens,
min_new_tokens=min_new_tokens,
stop=stop,
stop_regex=stop_regex,
stop_token_ids=stop_token_ids,
temperature=temperature,
top_p=top_p,
Expand Down
30 changes: 27 additions & 3 deletions python/sglang/srt/managers/schedule_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import dataclasses
import logging
import re
from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -67,7 +68,6 @@
"enable_ep_moe": ServerArgs.enable_ep_moe,
}


logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -103,6 +103,18 @@ def to_json(self):
}


class FINISH_MATCHED_REGEX(BaseFinishReason):
    """Finish reason: generation stopped because a `stop_regex` pattern matched."""

    def __init__(self, matched: str):
        super().__init__()
        # The regex pattern (as a string) that triggered the stop.
        self.matched = matched

    def to_json(self):
        """Serialize this finish reason for API responses."""
        return {"type": "stop_regex", "matched": self.matched}


class FINISH_LENGTH(BaseFinishReason):
def __init__(self, length: int):
super().__init__()
Expand Down Expand Up @@ -255,6 +267,7 @@ def __init__(
# 2: read_offset
# 3: last token
self.vid = 0 # version id to sync decode status with in detokenizer_manager
self.decoded_text = ""
self.surr_offset = None # Surrounding offset to defeat the cleanup algorithm
self.read_offset = None
self.decoded_text = ""
Expand Down Expand Up @@ -429,9 +442,20 @@ def check_finished(self):

for stop_str in self.sampling_params.stop_strs:
if stop_str in tail_str or stop_str in self.decoded_text:
self.finished_reason = FINISH_MATCHED_STR(matched=stop_str)
self.finished_reason = FINISH_MATCHED_REGEX(matched=stop_str)
return

# Check stop regex
if self.sampling_params.stop_regex_strs:
decode_res, new_text = self.get_next_inc_detokenization()
if decode_res:
for stop_regex_str in self.sampling_params.stop_regex_strs:
if re.search(stop_regex_str, new_text):
self.finished_reason = FINISH_MATCHED_STR(
matched=stop_regex_str
)
return

def jump_forward_and_retokenize(self, jump_forward_str, next_state):
if self.origin_input_text is None:
# Recovering text can only use unpadded ids
Expand Down Expand Up @@ -463,7 +487,7 @@ def jump_forward_and_retokenize(self, jump_forward_str, next_state):
self.surr_offset = prompt_tokens
self.read_offset = len(all_ids)

# NOTE: A trick to reduce the surrouding tokens decoding overhead
# NOTE: A trick to reduce the surrounding tokens decoding overhead
for i in range(0, INIT_INCREMENTAL_DETOKENIZATION_OFFSET):
surr_text_ = self.tokenizer.decode(
all_ids[self.read_offset - i : self.read_offset]
Expand Down
5 changes: 5 additions & 0 deletions python/sglang/srt/openai_api/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,6 +511,7 @@ def v1_generate_request(
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": request.stop,
"stop_regex": request.stop_regex,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"top_k": request.top_k,
Expand Down Expand Up @@ -917,6 +918,7 @@ def v1_chat_generate_request(
if assistant_prefix:
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
stop = request.stop
stop_regex = request.stop_regex
image_data = None
modalities = []
else:
Expand All @@ -925,6 +927,7 @@ def v1_chat_generate_request(
image_data = conv.image_data
modalities = conv.modalities
stop = conv.stop_str or []
stop_regex = []
if request.stop:
if isinstance(request.stop, str):
stop.append(request.stop)
Expand All @@ -935,6 +938,7 @@ def v1_chat_generate_request(
# Use the raw prompt and stop strings if the messages is already a string.
prompt_ids = request.messages
stop = request.stop
stop_regex = request.stop_regex
image_data = None
modalities = []
input_ids.append(prompt_ids)
Expand All @@ -948,6 +952,7 @@ def v1_chat_generate_request(
"max_new_tokens": request.max_tokens,
"min_new_tokens": request.min_tokens,
"stop": stop,
"stop_regex": stop_regex,
"stop_token_ids": request.stop_token_ids,
"top_p": request.top_p,
"top_k": request.top_k,
Expand Down
4 changes: 3 additions & 1 deletion python/sglang/srt/openai_api/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,11 +175,12 @@ class CompletionRequest(BaseModel):
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_regex: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = (None,)


class CompletionResponseChoice(BaseModel):
Expand Down Expand Up @@ -317,6 +318,7 @@ class ChatCompletionRequest(BaseModel):
regex: Optional[str] = None
ebnf: Optional[str] = None
repetition_penalty: float = 1.0
stop_regex: Optional[Union[str, List[str]]] = None
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
Expand Down
7 changes: 7 additions & 0 deletions python/sglang/srt/sampling/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def __init__(
self,
max_new_tokens: int = 128,
stop: Optional[Union[str, List[str]]] = None,
stop_regex: Optional[Union[str, List[str]]] = None,
stop_token_ids: Optional[List[int]] = None,
temperature: float = 1.0,
top_p: float = 1.0,
Expand All @@ -57,6 +58,7 @@ def __init__(
self.presence_penalty = presence_penalty
self.repetition_penalty = repetition_penalty
self.stop_strs = stop
self.stop_regex_strs = stop_regex
if stop_token_ids:
self.stop_token_ids = set(stop_token_ids)
else:
Expand Down Expand Up @@ -146,3 +148,8 @@ def normalize(self, tokenizer):
else:
stop_str_max_len = max(stop_str_max_len, len(stop_str))
self.stop_str_max_len = stop_str_max_len

if self.stop_regex_strs is None:
self.stop_regex_strs = []
elif isinstance(self.stop_regex_strs, str):
self.stop_regex_strs = [self.stop_regex_strs]
30 changes: 30 additions & 0 deletions test/srt/test_matched_stop.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def run_completions_generation(
prompt=MANY_NEW_TOKENS_PROMPT,
max_tokens=1,
stop=None,
stop_regex=None,
finish_reason=None,
matched_stop=None,
):
Expand All @@ -53,6 +54,9 @@ def run_completions_generation(
if stop is not None:
payload["stop"] = stop

if stop_regex is not None:
payload["stop_regex"] = stop_regex

response_completions = requests.post(
self.base_url + "/v1/completions",
json=payload,
Expand All @@ -70,6 +74,7 @@ def run_chat_completions_generation(
prompt=MANY_NEW_TOKENS_PROMPT,
max_tokens=1,
stop=None,
stop_regex=None,
finish_reason=None,
matched_stop=None,
):
Expand All @@ -87,6 +92,9 @@ def run_chat_completions_generation(
if stop is not None:
chat_payload["stop"] = stop

if stop_regex is not None:
chat_payload["stop_regex"] = stop_regex

response_chat = requests.post(
self.base_url + "/v1/chat/completions",
json=chat_payload,
Expand All @@ -105,6 +113,28 @@ def test_finish_stop_str(self):
max_tokens=1000, stop="\n", finish_reason="stop", matched_stop="\n"
)

def test_finish_stop_regex_str(self):
    """Both endpoints should finish with reason "stop" when a stop_regex matches."""
    # Alternation pattern: stop on either conjunction (trailing space included).
    alternation_pattern = r"and |or "
    self.run_completions_generation(
        max_tokens=1000,
        stop_regex=alternation_pattern,
        finish_reason="stop",
        matched_stop=alternation_pattern,
    )
    self.run_chat_completions_generation(
        max_tokens=1000,
        stop_regex=alternation_pattern,
        finish_reason="stop",
        matched_stop=alternation_pattern,
    )
    # End-anchored pattern: stop once the text ends with sentence punctuation.
    sentence_end_pattern = r"[.!?]\s*$"
    self.run_chat_completions_generation(
        max_tokens=1000,
        stop_regex=sentence_end_pattern,
        finish_reason="stop",
        matched_stop=sentence_end_pattern,
    )

def test_finish_stop_eos(self):
llama_format_prompt = """
<|begin_of_text|><|start_header_id|>system<|end_header_id|>
Expand Down
Loading