Merge pull request #331 from tanweersalah/rate-limiter
feat: added rate limiter
kyma-bot authored Jan 23, 2025
2 parents 50a73bc + 9f6c7e0 commit 9c56e33
Showing 15 changed files with 809 additions and 25 deletions.
4 changes: 4 additions & 0 deletions src/agents/common/constants.py
@@ -43,3 +43,7 @@
K8S_AGENT = "KubernetesAgent"

KYMA_AGENT = "KymaAgent"

SUCCESS_CODE = 200

ERROR_RATE_LIMIT_CODE = 429
36 changes: 36 additions & 0 deletions src/agents/common/utils.py
@@ -1,4 +1,6 @@
import hashlib
from collections.abc import Sequence
from datetime import UTC, datetime
from typing import Any, Literal

import tiktoken
@@ -128,6 +130,40 @@ def create_node_output(
}


def get_current_day_timestamps_utc() -> tuple[str, str]:
"""
Returns the start and end timestamps for the current day in UTC.
Start: 00:00:00
End: 23:59:59
"""
# Get the current date in UTC
now = datetime.now(UTC)

# Start of the day (00:00:00)
start_of_day = datetime(now.year, now.month, now.day, 0, 0, 0)

# End of the day (23:59:59)
end_of_day = datetime(now.year, now.month, now.day, 23, 59, 59)

# Format the timestamps in ISO format
from_timestamp = start_of_day.strftime("%Y-%m-%dT%H:%M:%S.%fZ")
to_timestamp = end_of_day.strftime("%Y-%m-%dT%H:%M:%S.%fZ")

return from_timestamp, to_timestamp


def hash_url(url: str) -> str:
"""
Generate a 32-character MD5 hash of a given URL.
    :param url: The URL string to be hashed.
    :return: A 32-character hexadecimal string representing the MD5 hash.
"""
return hashlib.md5(url.encode()).hexdigest()


def compute_string_token_count(text: str, model_type: ModelType) -> int:
"""Returns the token count of the string."""
return len(tiktoken.encoding_for_model(model_type).encode(text=text))
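
For reference, an interpreter sketch of the two new helpers (illustrative only, not part of the commit; the example URL is made up):

from agents.common.utils import get_current_day_timestamps_utc, hash_url

start, end = get_current_day_timestamps_utc()
# e.g. ("2025-01-23T00:00:00.000000Z", "2025-01-23T23:59:59.000000Z")

digest = hash_url("https://api.c-123abc.kyma.example.com")
assert len(digest) == 32  # an MD5 hex digest is always 32 characters
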
15 changes: 12 additions & 3 deletions src/agents/graph.py
@@ -38,7 +38,6 @@
from agents.summarization.summarization import Summarization
from agents.supervisor.agent import SUPERVISOR, SupervisorAgent
from services.k8s import IK8sClient
from utils.langfuse import handler
from utils.logging import get_logger
from utils.models.factory import IModel, ModelType
from utils.settings import (
@@ -100,10 +99,14 @@ class CompanionGraph:
planner_prompt: ChatPromptTemplate

def __init__(
self, models: dict[str, IModel | Embeddings], memory: BaseCheckpointSaver
self,
models: dict[str, IModel | Embeddings],
memory: BaseCheckpointSaver,
handler: Any = None,
):
self.models = models
self.memory = memory
self.handler = handler

gpt_4o_mini = models[ModelType.GPT4O_MINI]
gpt_4o = models[ModelType.GPT4O]
@@ -233,6 +236,9 @@ async def astream(
HumanMessage(content=message.query),
]

x_cluster_url = k8s_client.get_api_server()
cluster_id = x_cluster_url.split(".")[1]

async for chunk in self.graph.astream(
input={
"messages": messages,
@@ -244,7 +250,10 @@
"configurable": {
"thread_id": conversation_id,
},
"callbacks": [handler],
"callbacks": [self.handler],
"tags": [
cluster_id
], # cluster_id as a tag for traceability and rate limiting
},
):
chunk_json = json.dumps(chunk, cls=CustomJSONEncoder)
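
The tag added above assumes the cluster ID is the second dot-separated label of the API server URL returned by k8s_client.get_api_server(); a sketch with a made-up URL:

# Hypothetical API server URL; the real shape depends on the cluster setup.
x_cluster_url = "api.c-123abc.kyma.example.com"
cluster_id = x_cluster_url.split(".")[1]  # -> "c-123abc"

Note that the index-based split raises IndexError on a host with fewer than two dots; the code relies on the Kyma API server hostname layout.
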
77 changes: 74 additions & 3 deletions src/routers/conversations.py
@@ -1,11 +1,14 @@
from datetime import UTC, datetime
from functools import lru_cache
from typing import Annotated
from typing import Annotated, Any

from fastapi import APIRouter, Body, Depends, Header, HTTPException, Path
from fastapi.encoders import jsonable_encoder
from starlette.responses import JSONResponse, StreamingResponse

from agents.common.constants import ERROR_RATE_LIMIT_CODE
from agents.common.data import Message
from agents.common.utils import get_current_day_timestamps_utc
from routers.common import (
API_PREFIX,
SESSION_ID_HEADER,
@@ -16,14 +19,21 @@
from services.conversation import ConversationService, IService
from services.data_sanitizer import DataSanitizer, IDataSanitizer
from services.k8s import IK8sClient, K8sClient
from services.langfuse import ILangfuseService, LangfuseService
from utils.config import Config, get_config
from utils.logging import get_logger
from utils.response import prepare_chunk_response
from utils.settings import TOKEN_LIMIT_PER_CLUSTER
from utils.utils import create_session_id

logger = get_logger(__name__)


def get_langfuse_service() -> ILangfuseService:
"""Dependency to get the langfuse service instance"""
return LangfuseService()


@lru_cache(maxsize=1)
def init_config() -> Config:
"""Initialize the config object once."""
@@ -38,10 +48,11 @@ def init_data_sanitizer(


def init_conversation_service(
config: Annotated[Config, Depends(init_config)]
config: Annotated[Config, Depends(init_config)],
langfuse_service: ILangfuseService = Depends(get_langfuse_service), # noqa B008
) -> IService:
"""Initialize the conversation service instance"""
return ConversationService(config=config)
return ConversationService(langfuse_handler=langfuse_service.handler, config=config)


router = APIRouter(
@@ -143,9 +154,13 @@ async def messages(
x_cluster_certificate_authority_data: Annotated[str, Header()],
conversation_service: Annotated[IService, Depends(init_conversation_service)],
data_sanitizer: Annotated[IDataSanitizer, Depends(init_data_sanitizer)],
langfuse_service: ILangfuseService = Depends(get_langfuse_service), # noqa B008
) -> StreamingResponse:
"""Endpoint to send a message to the Kyma companion"""

# Check rate limitation
await check_token_usage(x_cluster_url, langfuse_service)

# Initialize k8s client for the request.
try:
k8s_client: IK8sClient = K8sClient(
@@ -169,3 +184,59 @@
),
media_type="text/event-stream",
)


async def check_token_usage(
x_cluster_url: str,
langfuse_service: Any,
token_limit: int = TOKEN_LIMIT_PER_CLUSTER,
) -> None:
"""
Checks the total token usage for a specific cluster within the current day (UTC) and raises an HTTPException
if the usage exceeds the predefined token limit.
:param x_cluster_url: The URL of the cluster, from which the cluster ID is extracted.
:param langfuse_service: An instance of a service that provides access to the
Langfuse API to retrieve token usage data.
    :param token_limit: The daily token limit per cluster; defaults to TOKEN_LIMIT_PER_CLUSTER.
    :raises HTTPException: An HTTP 429 error if the total token usage exceeds the daily limit,
        with details about the current usage, the limit, and the time remaining until the
        limit resets at midnight UTC.
"""

    # A token_limit of -1 means rate limiting is disabled; skip the check.
if token_limit == -1:
return

from_timestamp, to_timestamp = get_current_day_timestamps_utc()
cluster_id = x_cluster_url.split(".")[1]
total_token_usage = 0
    try:
total_token_usage = await langfuse_service.get_total_token_usage(
from_timestamp, to_timestamp, cluster_id
)
    except Exception as e:
        logger.error(f"failed to fetch token usage from the Langfuse API: {e}")

if total_token_usage > token_limit:
current_utc = datetime.now(UTC)
midnight_utc = current_utc.replace(hour=23, minute=59, second=59)
time_remaining = midnight_utc - current_utc
seconds_remaining = int(time_remaining.total_seconds())
raise HTTPException(
status_code=ERROR_RATE_LIMIT_CODE,
detail={
"error": "Rate limit exceeded",
"message": f"Daily token limit of {token_limit} exceeded for this cluster",
"current_usage": total_token_usage,
"limit": token_limit,
"time_remaining_seconds": seconds_remaining,
},
headers={"Retry-After": str(seconds_remaining)},
)
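
On the client side, a request rejected by this check receives a 429 whose body mirrors the detail dict above and whose Retry-After header counts down to the UTC-midnight reset. A minimal sketch (the endpoint URL and payload are hypothetical; the real route is assembled from API_PREFIX in routers/common):

import time

import requests

resp = requests.post(
    "https://companion.example.com/api/conversations/123/messages",  # hypothetical URL
    json={"query": "why is my pod crashing?"},
)
if resp.status_code == 429:
    # Back off until the daily limit resets at midnight UTC.
    wait_seconds = int(resp.headers.get("Retry-After", "60"))
    time.sleep(wait_seconds)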
6 changes: 4 additions & 2 deletions src/services/conversation.py
@@ -1,6 +1,8 @@
from collections.abc import AsyncGenerator
from typing import Protocol, cast

from langfuse.callback import CallbackHandler

from agents.common.data import Message
from agents.graph import CompanionGraph, IGraph
from agents.memory.async_redis_checkpointer import AsyncRedisSaver
@@ -58,6 +60,7 @@ def __init__(
initial_questions_handler: IInitialQuestionsHandler | None = None,
model_factory: IModelFactory | None = None,
followup_questions_handler: IFollowUpQuestionsHandler | None = None,
langfuse_handler: CallbackHandler | None = None,
) -> None:
try:
self._model_factory = model_factory or ModelFactory(config=config)
@@ -82,8 +85,7 @@ def __init__(
host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER
)
self._companion_graph = CompanionGraph(
models,
memory=checkpointer,
models, memory=checkpointer, handler=langfuse_handler
)

def new_conversation(self, k8s_client: IK8sClient, message: Message) -> list[str]:
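
A minimal wiring sketch for the new langfuse_handler parameter, assuming Langfuse credentials come from the standard environment variables (the config object here is schematic):

from langfuse.callback import CallbackHandler

# With no arguments, CallbackHandler() reads LANGFUSE_PUBLIC_KEY,
# LANGFUSE_SECRET_KEY and LANGFUSE_HOST from the environment.
handler = CallbackHandler()
service = ConversationService(config=config, langfuse_handler=handler)
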
8 changes: 8 additions & 0 deletions src/services/k8s.py
@@ -18,6 +18,10 @@
class IK8sClient(Protocol):
"""Interface for the K8sClient class."""

def get_api_server(self) -> str:
"""Returns the URL of the Kubernetes cluster."""
...

def model_dump(self) -> None:
"""Dump the model without any confidential data."""
...
@@ -124,6 +128,10 @@ def __del__(self):
except FileNotFoundError:
return

def get_api_server(self) -> str:
"""Returns the URL of the Kubernetes cluster."""
return self.api_server

def model_dump(self) -> None:
"""Dump the model. It should not return any critical information because it is called by checkpointer
to store the object in database."""
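
Because IK8sClient is a typing.Protocol, K8sClient satisfies it structurally and a test double only needs matching method signatures. A sketch (only the two methods visible in this hunk are stubbed; the full interface defines more):

class FakeK8sClient:
    """Test double for IK8sClient; no inheritance required."""

    def get_api_server(self) -> str:
        return "api.c-test123.kyma.example.com"  # made-up API server host

    def model_dump(self) -> None:
        return None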