diff --git a/README.md b/README.md
index 0c75ecb..53346e7 100644
--- a/README.md
+++ b/README.md
@@ -2,6 +2,12 @@
 A plugin for the [`llm`](https://llm.datasette.io/en/stable/) CLI that allows you to use the text generation models (LLMs) running globally on Cloudflare [Workers AI](https://developers.cloudflare.com/workers-ai/models/#text-generation).
 
+`llm-cloudflare` is useful for:
+
+* Using and building with LLMs that may not run efficiently on your local machine (limited GPU, memory, etc.), instead letting Workers AI run them on a GPU near you.
+* Validating the performance of a model, or comparing multiple models.
+* Experimenting without needing to download models ahead of time.
+
 ## Usage
 
 **Prerequisite**: You'll need the `llm` CLI [installed first](https://llm.datasette.io/en/stable/setup.html).
 
@@ -41,6 +47,8 @@ llm models default
 
 This plugin provides access to the [text generation models](https://developers.cloudflare.com/workers-ai/models/#text-generation) (LLMs) provided by Workers AI.
 
+To see what models are available, invoke `llm models`. Models prefixed with `Cloudflare Workers AI` are provided by this plugin.
+
 The supported models are generated by scripts. New models thus rely on this plugin being updated periodically.
 
 In the future, this plugin may also add support for Workers AI's [embedding models](https://developers.cloudflare.com/workers-ai/models/#text-embeddings) for use with [`llm embed`](https://llm.datasette.io/en/stable/embeddings/index.html).
diff --git a/llm_cloudflare.py b/llm_cloudflare.py
index a029547..27a952a 100644
--- a/llm_cloudflare.py
+++ b/llm_cloudflare.py
@@ -2,7 +2,7 @@
 from llm import ModelError
 import os
 from openai import OpenAI, APIConnectionError, RateLimitError, APIStatusError
-from pydantic import Field, field_validator, model_validator
+from pydantic import Field
 from typing import Optional, List
 
 DEFAULT_MODEL = "@cf/meta/llama-3.1-8b-instruct"
@@ -11,7 +11,51 @@
 @llm.hookimpl
 def register_models(register):
     # Workers AI text generation models: https://developers.cloudflare.com/workers-ai/models/#text-generation
+    #
+    # Generated via:
+    #   curl \
+    #     "https://api.cloudflare.com/client/v4/accounts/${CLOUDFLARE_ACCOUNT_ID}/ai/models/search?per_page=1000" \
+    #     -H "Authorization: Bearer ${WORKERS_AI_TOKEN}" \
+    #     | jq --raw-output '.result[] | select (.task.name | contains("Text Generation")) | "register(WorkersAI(\"\(.name)\"))"'
+    register(
+        WorkersAI("@cf/meta/llama-3.1-8b-instruct"), aliases=("llama3.1-8b-instruct",)
+    )
+    register(WorkersAI("@cf/qwen/qwen1.5-0.5b-chat"))
+    register(WorkersAI("@cf/google/gemma-2b-it-lora"))
+    register(WorkersAI("@hf/nexusflow/starling-lm-7b-beta"))
+    register(WorkersAI("@hf/thebloke/llamaguard-7b-awq"))
+    register(WorkersAI("@hf/thebloke/neural-chat-7b-v3-1-awq"))
+    register(WorkersAI("@cf/meta/llama-2-7b-chat-fp16"))
+    register(WorkersAI("@cf/mistral/mistral-7b-instruct-v0.1"))
+    register(WorkersAI("@cf/mistral/mistral-7b-instruct-v0.2-lora"))
+    register(WorkersAI("@cf/tinyllama/tinyllama-1.1b-chat-v1.0"))
+    register(WorkersAI("@hf/mistral/mistral-7b-instruct-v0.2"))
+    register(WorkersAI("@cf/fblgit/una-cybertron-7b-v2-bf16"))
+    register(WorkersAI("@cf/thebloke/discolm-german-7b-v1-awq"))
+    register(WorkersAI("@cf/meta/llama-2-7b-chat-int8"))
+    register(WorkersAI("@cf/meta/llama-3.1-8b-instruct-fp8"))
+    register(WorkersAI("@hf/thebloke/mistral-7b-instruct-v0.1-awq"))
+    register(WorkersAI("@cf/qwen/qwen1.5-7b-chat-awq"))
+    register(WorkersAI("@hf/thebloke/llama-2-13b-chat-awq"))
+    register(WorkersAI("@hf/thebloke/deepseek-coder-6.7b-base-awq"))
+    register(WorkersAI("@cf/meta-llama/llama-2-7b-chat-hf-lora"))
+    register(WorkersAI("@hf/thebloke/openhermes-2.5-mistral-7b-awq"))
+    register(WorkersAI("@hf/thebloke/deepseek-coder-6.7b-instruct-awq"))
+    register(WorkersAI("@cf/deepseek-ai/deepseek-math-7b-instruct"))
+    register(WorkersAI("@cf/tiiuae/falcon-7b-instruct"))
+    register(WorkersAI("@hf/nousresearch/hermes-2-pro-mistral-7b"))
     register(WorkersAI("@cf/meta/llama-3.1-8b-instruct"))
+    register(WorkersAI("@cf/meta/llama-3.1-8b-instruct-awq"))
+    register(WorkersAI("@hf/thebloke/zephyr-7b-beta-awq"))
+    register(WorkersAI("@cf/google/gemma-7b-it-lora"))
+    register(WorkersAI("@cf/qwen/qwen1.5-1.8b-chat"))
+    register(WorkersAI("@cf/meta/llama-3-8b-instruct-awq"))
+    register(WorkersAI("@cf/defog/sqlcoder-7b-2"))
+    register(WorkersAI("@cf/microsoft/phi-2"))
+    register(WorkersAI("@hf/meta-llama/meta-llama-3-8b-instruct"))
+    register(WorkersAI("@hf/google/gemma-7b-it"))
+    register(WorkersAI("@cf/qwen/qwen1.5-14b-chat-awq"))
+    register(WorkersAI("@cf/openchat/openchat-3.5-0106"))
 
 
 class WorkersAIOptions(llm.Options):
@@ -21,24 +65,23 @@ class WorkersAIOptions(llm.Options):
     max_tokens: Optional[int] = Field(
         default=None,
-        description="TODO",
+        description="The maximum number of tokens to return in a response.",
     )
 
     temperature: Optional[float] = Field(
         default=None,
-        description="TODO",
+        description="A hyperparameter that controls the randomness and creativity of generated text: higher temperatures produce more unpredictable and less coherent output, while lower temperatures produce more precise, factual, and predictable (but less imaginative) text.",
     )
 
     top_p: Optional[float] = Field(
         default=None,
-        description="TODO",
+        description="The probability threshold used when sampling the next token: the model considers only the smallest set of candidate tokens whose cumulative probability exceeds 'top_p'. Higher values give more creative freedom; lower values lead to more coherent but less diverse outputs.",
     )
 
     top_k: Optional[int] = Field(
         default=None,
-        description="TODO",
+        description="Restricts sampling to the 'top_k' most likely next tokens from the vocabulary, rather than considering all possible tokens. A higher value is less restrictive.",
     )
 
-    # TODO: Define options for temperature, top_k and top_p
 
 class WorkersAI(llm.Model):
diff --git a/pyproject.toml b/pyproject.toml
index 2a18270..dae12f0 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "llm-cloudflare"
-version = "0.2.1"
+version = "0.3"
 description = "An LLM CLI plugin for Cloudflare Workers AI models."
 readme = "README.md"
 authors = [{name = "elithrar"}]
diff --git a/uv.lock b/uv.lock
index 3ea1e14..6593860 100644
--- a/uv.lock
+++ b/uv.lock
@@ -170,7 +170,7 @@ wheels = [
 
 [[package]]
 name = "llm-cloudflare"
-version = "0.1"
+version = "0.3"
 source = { virtual = "." }
 dependencies = [
     { name = "llm" },