Is batch streaming possible with the Text Generation functions? #1423

Closed
suhjohn opened this issue Aug 16, 2023 · 5 comments

@suhjohn commented Aug 16, 2023

There doesn't seem to be good documentation on using generate_iterable. From the name alone, I get the sense that it could potentially be used for batch streaming. I'm looking to improve on the performance of TGI / vLLM, and streaming is a crucial feature I would like to support, but it's unclear whether this is possible with CTranslate2.

@suhjohn changed the title from "Batch generation + Streaming possible?" to "Is batch streaming possible with the Text Generation functions?" on Aug 16, 2023
@guillaumekln (Collaborator)

What do you mean by batch streaming exactly?

@SebastianBodza (Contributor) commented Aug 16, 2023

Something like the following:

import ctranslate2
import transformers

generator = ctranslate2.Generator("./models/ctranslate2/gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")
print("model loaded")
prompt = "What is the meaning of life?"
input_id = tokenizer.encode(prompt, add_special_tokens=True)
prompt_tokens = tokenizer.convert_ids_to_tokens(input_id)

step_results = generator.generate_tokens(
    [prompt_tokens, prompt_tokens],
    sampling_temperature=0.8,
    sampling_topk=20,
    max_length=100,
)

output_ids = []

for step_result in step_results:
    print(step_result)
    is_new_word = step_result.token.startswith("Ġ")

    if is_new_word and output_ids:
        word = tokenizer.decode(output_ids)
        print(word, end=" ", flush=True)
        output_ids = []

    output_ids.append(step_result.token_id)

if output_ids:
    word = tokenizer.decode(output_ids)
    print(word)

By itself this is definitely not working, and as far as I can see it is currently not implemented in the generator.generate_tokens function. In extensions.py:

    298         generator.generate_batch,
    299         [prompt],
    300         repetition_penalty=repetition_penalty,
    301         no_repeat_ngram_size=no_repeat_ngram_size,
    302         disable_unk=disable_unk,
    303         suppress_sequences=suppress_sequences,
    304         end_token=end_token,
    305         max_length=max_length,
    306         min_length=min_length,
    307         sampling_topk=sampling_topk,
    308         sampling_topp=sampling_topp,
    309         sampling_temperature=sampling_temperature,
    310         return_scores=return_log_prob,
    311         static_prompt=static_prompt,
    312         cache_static_prompt=cache_static_prompt,
    313         include_prompt_in_result=False,
    314     )

File c:\Users\\anaconda3\lib\site-packages\ctranslate2\extensions.py:337, in _generate_tokens(process_func, *args, **kwargs)
    327     return generator_closed.is_set()
    329 kwargs.update(
    330     {
    331         "asynchronous": True,
   (...)
    334     }
    335 )
--> 337 async_result = process_func(*args, **kwargs)[0]
    339 def _catch_exception():
    340     try:

TypeError: generate_batch(): incompatible function arguments. The following argument types are supported:
    1. (self: ctranslate2._ext.Generator, start_tokens: List[List[str]], *, max_batch_size: int = 0, batch_type: str = 'examples', asynchronous: bool = False, beam_size: int = 1, patience: float = 1, num_hypotheses: int = 1, length_penalty: float = 1, repetition_penalty: float = 1, no_repeat_ngram_size: int = 0, disable_unk: bool = False, suppress_sequences: Optional[List[List[str]]] = None, end_token: Optional[Union[str, List[str], List[int]]] = None, return_end_token: bool = False, max_length: int = 512, min_length: int = 0, static_prompt: Optional[List[str]] = None, cache_static_prompt: bool = True, include_prompt_in_result: bool = True, return_scores: bool = False, return_alternatives: bool = False, min_alternative_expansion_prob: float = 0, sampling_topk: int = 1, sampling_topp: float = 1, sampling_temperature: float = 1, callback: Callable[[ctranslate2._ext.GenerationStepResult], bool] = None) -> Union[List[ctranslate2._ext.GenerationResult], List[ctranslate2._ext.AsyncGenerationResult]]

The generator receives a list. When a list is passed to generate_tokens, it crashes. By the way, I'm not sure the type hint is correct then; I think it should be str.

It should, however, be possible with limited adjustments to the code. Adjusting the code above leads to the following output:

GenerationStepResult(step=0, batch_id=0, token_id=15957, token='ĠTouch', log_prob=None, is_last=False)
GenerationStepResult(step=0, batch_id=1, token_id=15957, token='ĠTouch', log_prob=None, is_last=False)

@guillaumekln (Collaborator)

The type hint is correct. This method takes a list of tokens, not a batch of token lists.

For batch mode, see the "Tip" note in the related documentation:

https://opennmt.net/CTranslate2/generation.html#token-streaming
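
For reference, a minimal sketch of how batch token streaming with that callback parameter could look (this is not taken from the documentation; the model path and prompts are placeholders, and the keyword arguments follow the generate_batch signature shown in the traceback above):

import ctranslate2
import transformers

generator = ctranslate2.Generator("./models/ctranslate2/gpt2")
tokenizer = transformers.AutoTokenizer.from_pretrained("gpt2")

prompts = ["What is the meaning of life?", "What is the meaning of life?"]
batch_tokens = [
    tokenizer.convert_ids_to_tokens(tokenizer.encode(p, add_special_tokens=True))
    for p in prompts
]

def on_step(step_result):
    # Called for every generated token; batch_id identifies which prompt
    # in the batch the token belongs to, so each prompt can be streamed
    # to its own output as tokens are produced.
    print(step_result.batch_id, step_result.token, flush=True)
    return False  # return True to stop generation early

results = generator.generate_batch(
    batch_tokens,
    beam_size=1,  # the callback is meant for greedy/sampling decoding
    sampling_temperature=0.8,
    sampling_topk=20,
    max_length=100,
    include_prompt_in_result=False,
    callback=on_step,
)

# The complete outputs are still available once decoding finishes.
for result in results:
    print(tokenizer.decode(result.sequences_ids[0]))

Since each callback invocation carries a batch_id, tokens from different prompts in the batch can be routed to separate output streams as they are generated.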

@SebastianBodza (Contributor)

My bad, the tokenizer already converts it to a list 😐.

Thanks for the hint of the callback 👍

@suhjohn changed the title to "TGI / VLLM-like CTranslate2 inference server implementation?" on Aug 17, 2023
@suhjohn changed the title back to "Is batch streaming possible with the Text Generation functions?" on Aug 17, 2023
@guillaumekln (Collaborator) commented Sep 5, 2023

I'm closing this issue, even though the term "batch streaming" was not clarified by the OP.

  • If it refers to token streaming in batch mode, there is already a callback parameter as discussed above.
  • If it refers to continuous batching, see the existing issue Continuous batching #1333.
  • If it refers to something else, feel free to reopen the issue with more details.
