diff --git a/app/Dockerfile b/app/Dockerfile index dfd49e4..d274c5b 100644 --- a/app/Dockerfile +++ b/app/Dockerfile @@ -11,10 +11,6 @@ RUN pip install --no-cache-dir -r requirements.txt COPY . /app/ COPY .chainlit /app/.chainlit -# Patch chainlit library to support cognito auth -COPY patches/oauth_providers.py /usr/local/lib/python3.11/site-packages/chainlit/ -COPY patches/base.py /usr/local/lib/python3.11/site-packages/chainlit/client/ - EXPOSE 8080 CMD ["chainlit", "run", "./app.py", "--port", "8080"] \ No newline at end of file diff --git a/app/Dockerfile-api b/app/Dockerfile-api deleted file mode 100644 index 80f2bd9..0000000 --- a/app/Dockerfile-api +++ /dev/null @@ -1,7 +0,0 @@ -FROM python:3.11.6-slim -RUN apt-get update && apt-get install ca-certificates -WORKDIR /app -COPY . /app/ -RUN pip install --no-cache-dir -r requirements.txt -EXPOSE 8001 -CMD ["python", "./api_server.py"] \ No newline at end of file diff --git a/app/app.py b/app/app.py index 6ab9b2e..0cd0f5b 100644 --- a/app/app.py +++ b/app/app.py @@ -3,7 +3,7 @@ from typing import Optional, Dict import chainlit as cl -from chainlit.client.base import ConversationDict +from chainlit.types import ThreadDict from langchain.agents import AgentExecutor from chainlit.server import app @@ -32,7 +32,7 @@ async def check_text(text: str): if os.getenv("DISABLE_AUTH", "").lower() != "true": @cl.oauth_callback def oauth_callback(provider_id: str, token: str, raw_user_data: Dict[str, str], - default_app_user: cl.AppUser) -> Optional[cl.AppUser]: + default_app_user: cl.User) -> Optional[cl.User]: return default_app_user @@ -65,7 +65,8 @@ async def chat_profile(): @cl.on_chat_start async def on_chat_start(): app_user = cl.user_session.get("user") - session_id = '' if app_user is None else f'{app_user.username}:{cl.user_session.get("id")}' + session_id = '' if app_user is None else f'{app_user.id}:{cl.user_session.get("id")}' + print("starting chat: " + session_id) character = cl.user_session.get("chat_profile") agent = get_agent(session_id, personality=character) cl.user_session.set("agent", agent) @@ -84,19 +85,19 @@ async def main(message: cl.Message): # Handler for resuming chat @cl.on_chat_resume -async def on_chat_resume(conversation: ConversationDict): +async def on_chat_resume(thread: ThreadDict): app_user = cl.user_session.get("user") - session_id = f'{app_user.username}:{conversation["id"]}' + session_id = f'{app_user.id}:{cl.user_session.get("id")}' + print("resumeing chat: " + session_id) agent = get_agent(session_id) cl.user_session.set("agent", agent) memory = agent.memory - root_messages = [m for m in conversation["messages"] if m["parentId"] is None] + root_messages = [m for m in thread["steps"] if m["parentId"] == None] for message in root_messages: - if message["authorIsUser"]: - memory.chat_memory.add_user_message(message["content"]) + if message["type"] == "USER_MESSAGE": + memory.chat_memory.add_user_message(message["output"]) else: - memory.chat_memory.add_ai_message(message["content"]) - + memory.chat_memory.add_ai_message(message["output"]) # Endpoint for readiness check used by GAE @app.get('/readiness_check') diff --git a/app/patches/base.py b/app/patches/base.py deleted file mode 100644 index 9ba877a..0000000 --- a/app/patches/base.py +++ /dev/null @@ -1,169 +0,0 @@ -from typing import ( - Any, - Dict, - Generic, - List, - Literal, - Mapping, - Optional, - TypedDict, - TypeVar, -) - -from chainlit.logger import logger -from chainlit.prompt import Prompt -from dataclasses_json import DataClassJsonMixin -from pydantic import BaseModel, Field -from pydantic.dataclasses import dataclass -from python_graphql_client import GraphqlClient - -ElementType = Literal[ - "image", "avatar", "text", "pdf", "tasklist", "audio", "video", "file", "plotly" -] -ElementDisplay = Literal["inline", "side", "page"] -ElementSize = Literal["small", "medium", "large"] - -Role = Literal["USER", "ADMIN", "OWNER", "ANONYMOUS"] -Provider = Literal[ - "credentials", "header", "cognito", "github", "google", "azure-ad", "okta", "auth0", "descope" -] - - -class AppUserDict(TypedDict): - id: str - username: str - - -# Used when logging-in a user -@dataclass -class AppUser(DataClassJsonMixin): - username: str - role: Role = "USER" - tags: List[str] = Field(default_factory=list) - image: Optional[str] = None - provider: Optional[Provider] = None - - -@dataclass -class PersistedAppUserFields: - id: str - createdAt: int - - -@dataclass -class PersistedAppUser(AppUser, PersistedAppUserFields): - pass - - -class MessageDict(TypedDict): - conversationId: Optional[str] - id: str - createdAt: Optional[int] - content: str - author: str - prompt: Optional[Prompt] - language: Optional[str] - parentId: Optional[str] - indent: Optional[int] - authorIsUser: Optional[bool] - waitForAnswer: Optional[bool] - isError: Optional[bool] - humanFeedback: Optional[int] - disableHumanFeedback: Optional[bool] - - -class ElementDict(TypedDict): - id: str - conversationId: Optional[str] - type: ElementType - url: str - objectKey: Optional[str] - name: str - display: ElementDisplay - size: Optional[ElementSize] - language: Optional[str] - forIds: Optional[List[str]] - mime: Optional[str] - - -class ConversationDict(TypedDict): - id: Optional[str] - metadata: Optional[Dict] - createdAt: Optional[int] - appUser: Optional[AppUserDict] - messages: List[MessageDict] - elements: Optional[List[ElementDict]] - - -@dataclass -class PageInfo: - hasNextPage: bool - endCursor: Optional[str] - - -T = TypeVar("T") - - -@dataclass -class PaginatedResponse(DataClassJsonMixin, Generic[T]): - pageInfo: PageInfo - data: List[T] - - -class Pagination(BaseModel): - first: int - cursor: Optional[str] = None - - -class ConversationFilter(BaseModel): - feedback: Optional[Literal[-1, 0, 1]] = None - username: Optional[str] = None - search: Optional[str] = None - - -class ChainlitGraphQLClient: - def __init__(self, api_key: str, chainlit_server: str): - self.headers = {"content-type": "application/json"} - if api_key: - self.headers["x-api-key"] = api_key - else: - raise ValueError("Cannot instantiate Cloud Client without CHAINLIT_API_KEY") - - graphql_endpoint = f"{chainlit_server}/api/graphql" - self.graphql_client = GraphqlClient( - endpoint=graphql_endpoint, headers=self.headers - ) - - async def query(self, query: str, variables: Dict[str, Any] = {}) -> Dict[str, Any]: - """ - Execute a GraphQL query. - - :param query: The GraphQL query string. - :param variables: A dictionary of variables for the query. - :return: The response data as a dictionary. - """ - return await self.graphql_client.execute_async(query=query, variables=variables) - - def check_for_errors(self, response: Dict[str, Any], raise_error: bool = False): - if "errors" in response: - if raise_error: - raise Exception( - f"{response['errors'][0]['message']}. Path: {str(response['errors'][0]['path'])}" - ) - logger.error(response["errors"][0]) - return True - return False - - async def mutation( - self, mutation: str, variables: Mapping[str, Any] = {} - ) -> Dict[str, Any]: - """ - Execute a GraphQL mutation. - - :param mutation: The GraphQL mutation string. - :param variables: A dictionary of variables for the mutation. - :return: The response data as a dictionary. - """ - return await self.graphql_client.execute_async( - query=mutation, variables=variables - ) diff --git a/app/patches/note.md b/app/patches/note.md deleted file mode 100644 index 1a76634..0000000 --- a/app/patches/note.md +++ /dev/null @@ -1,7 +0,0 @@ -# Note - -Authentication with AWS Cognito is not supported yet, at the time of creating this project. -I am patching files from chainlit library to support OAuth2 with AWS Cognito. -These files are used to replace chainlit library files in `Dockerfile` - -I wil soon open PR to officially support Cognito in chainlit \ No newline at end of file diff --git a/app/patches/oauth_providers.py b/app/patches/oauth_providers.py deleted file mode 100644 index 0bfcb62..0000000 --- a/app/patches/oauth_providers.py +++ /dev/null @@ -1,484 +0,0 @@ -import base64 -import os -import urllib.parse -from typing import Dict, List, Optional, Tuple - -import httpx -from chainlit.client.base import AppUser -from fastapi import HTTPException - - -class OAuthProvider: - id: str - env: List[str] - client_id: str - client_secret: str - authorize_url: str - authorize_params: Dict[str, str] - - def is_configured(self): - return all([os.environ.get(env) for env in self.env]) - - async def get_token(self, code: str, url: str) -> str: - raise NotImplementedError() - - async def get_user_info(self, token: str) -> Tuple[Dict[str, str], AppUser]: - raise NotImplementedError() - - -class CognitoOAuthProvider(OAuthProvider): - id = "cognito" - env = ["COGNITO_CLIENT_ID", "COGNITO_CLIENT_SECRET", "COGNITO_DOMAIN", "COGNITO_REDIRECT_URI"] - - def __init__(self): - self.client_id = os.getenv("COGNITO_CLIENT_ID") - self.client_secret = os.getenv("COGNITO_CLIENT_SECRET") - self.redirect_uri = os.getenv("COGNITO_REDIRECT_URI") - self.domain = os.getenv("COGNITO_DOMAIN") - self.authorize_url = f"https://{self.domain}/login" - self.token_url = f"https://{self.domain}/oauth2/token" - self.user_info_url = f"https://{self.domain}/oauth2/userInfo" - self.authorize_params = { - "response_type": "code", - "scope": "openid email", - "redirect_uri": self.redirect_uri - } - - async def get_token(self, code: str, url: str) -> str: - payload = { - "grant_type": "authorization_code", - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "redirect_uri": self.redirect_uri - } - async with httpx.AsyncClient() as client: - response = await client.post(self.token_url, data=payload) - response.raise_for_status() - content = response.json() - token = content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str) -> Tuple[Dict[str, str], AppUser]: - async with httpx.AsyncClient() as client: - try: - response = await client.get( - self.user_info_url, - headers={"Authorization": f"Bearer {token}"} - ) - response.raise_for_status() - except Exception as err: - print(f"An error occurred: {err}") - else: - user_info = response.json() - - app_user = AppUser( - username=user_info["email"], - image=user_info.get("picture"), - provider="cognito" - ) - return (user_info, app_user) - - -class GithubOAuthProvider(OAuthProvider): - id = "github" - env = ["OAUTH_GITHUB_CLIENT_ID", "OAUTH_GITHUB_CLIENT_SECRET"] - authorize_url = "https://github.com/login/oauth/authorize" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GITHUB_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GITHUB_CLIENT_SECRET") - self.authorize_params = { - "scope": "user:email", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - } - async with httpx.AsyncClient() as client: - response = await client.post( - "https://github.com/login/oauth/access_token", - data=payload, - ) - response.raise_for_status() - content = urllib.parse.parse_qs(response.text) - token = content.get("access_token", [""])[0] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - user_response = await client.get( - "https://api.github.com/user", - headers={"Authorization": f"token {token}"}, - ) - user_response.raise_for_status() - user = user_response.json() - - emails_response = await client.get( - "https://api.github.com/user/emails", - headers={"Authorization": f"token {token}"}, - ) - emails_response.raise_for_status() - emails = emails_response.json() - - user.update({"emails": emails}) - - app_user = AppUser( - username=user["login"], - image=user["avatar_url"], - provider="github", - ) - return (user, app_user) - - -class GoogleOAuthProvider(OAuthProvider): - id = "google" - env = ["OAUTH_GOOGLE_CLIENT_ID", "OAUTH_GOOGLE_CLIENT_SECRET"] - authorize_url = "https://accounts.google.com/o/oauth2/v2/auth" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_GOOGLE_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_GOOGLE_CLIENT_SECRET") - self.authorize_params = { - "scope": "https://www.googleapis.com/auth/userinfo.profile https://www.googleapis.com/auth/userinfo.email", - "response_type": "code", - "access_type": "offline", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - "https://oauth2.googleapis.com/token", - data=payload, - ) - response.raise_for_status() - json = response.json() - token = json.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://www.googleapis.com/userinfo/v2/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - user = response.json() - - app_user = AppUser( - username=user["name"], image=user["picture"], provider="google" - ) - return (user, app_user) - - -class AzureADOAuthProvider(OAuthProvider): - id = "azure-ad" - env = [ - "OAUTH_AZURE_AD_CLIENT_ID", - "OAUTH_AZURE_AD_CLIENT_SECRET", - "OAUTH_AZURE_AD_TENANT_ID", - ] - authorize_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/authorize" - if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" - ) - token_url = ( - f"https://login.microsoftonline.com/{os.environ.get('OAUTH_AZURE_AD_TENANT_ID', '')}/oauth2/v2.0/token" - if os.environ.get("OAUTH_AZURE_AD_ENABLE_SINGLE_TENANT") - else "https://login.microsoftonline.com/common/oauth2/v2.0/token" - ) - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AZURE_AD_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AZURE_AD_CLIENT_SECRET") - self.authorize_params = { - "tenant": os.environ.get("OAUTH_AZURE_AD_TENANT_ID"), - "response_type": "code", - "scope": "https://graph.microsoft.com/User.Read", - "response_mode": "query", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - self.token_url, - data=payload, - ) - response.raise_for_status() - json = response.json() - - token = json["access_token"] - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - "https://graph.microsoft.com/v1.0/me", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - - user = response.json() - - try: - photo_response = await client.get( - "https://graph.microsoft.com/v1.0/me/photos/48x48/$value", - headers={"Authorization": f"Bearer {token}"}, - ) - photo_data = await photo_response.aread() - base64_image = base64.b64encode(photo_data) - user[ - "image" - ] = f"data:{photo_response.headers['Content-Type']};base64,{base64_image.decode('utf-8')}" - except Exception as e: - # Ignore errors getting the photo - pass - - app_user = AppUser( - username=user["userPrincipalName"], - image=user.get("image", ""), - provider="azure-ad", - ) - return (user, app_user) - - -class OktaOAuthProvider(OAuthProvider): - id = "okta" - env = [ - "OAUTH_OKTA_CLIENT_ID", - "OAUTH_OKTA_CLIENT_SECRET", - "OAUTH_OKTA_DOMAIN", - ] - # Avoid trailing slash in domain if supplied - domain = f"https://{os.environ.get('OAUTH_OKTA_DOMAIN', '').rstrip('/')}" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_OKTA_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_OKTA_CLIENT_SECRET") - self.authorization_server_id = os.environ.get( - "OAUTH_OKTA_AUTHORIZATION_SERVER_ID", "" - ) - self.authorize_url = ( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/authorize" - ) - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "response_mode": "query", - } - - def get_authorization_server_path(self): - if not self.authorization_server_id: - return "/default" - if self.authorization_server_id == "false": - return "" - return f"/{self.authorization_server_id}" - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/token", - data=payload, - ) - response.raise_for_status() - json_data = response.json() - - token = json_data.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/oauth2{self.get_authorization_server_path()}/v1/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - user = response.json() - - app_user = AppUser(username=user.get("email"), image="", provider="okta") - return (user, app_user) - - -class Auth0OAuthProvider(OAuthProvider): - id = "auth0" - env = ["OAUTH_AUTH0_CLIENT_ID", "OAUTH_AUTH0_CLIENT_SECRET", "OAUTH_AUTH0_DOMAIN"] - - def __init__(self): - self.client_id = os.environ.get("OAUTH_AUTH0_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_AUTH0_CLIENT_SECRET") - # Ensure that the domain does not have a trailing slash - self.domain = f"https://{os.environ.get('OAUTH_AUTH0_DOMAIN', '').rstrip('/')}" - self.original_domain = ( - f"https://{os.environ.get('OAUTH_AUTH0_ORIGINAL_DOMAIN').rstrip('/')}" - if os.environ.get("OAUTH_AUTH0_ORIGINAL_DOMAIN") - else self.domain - ) - - self.authorize_url = f"{self.domain}/authorize" - - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "audience": f"{self.original_domain}/userinfo", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/oauth/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise HTTPException( - status_code=400, detail="Failed to get the access token" - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.original_domain}/userinfo", - headers={"Authorization": f"Bearer {token}"}, - ) - response.raise_for_status() - user = response.json() - app_user = AppUser( - username=user.get("email"), - image=user.get("picture", ""), - provider="auth0", - ) - return (user, app_user) - - -class DescopeOAuthProvider(OAuthProvider): - id = "descope" - env = ["OAUTH_DESCOPE_CLIENT_ID", "OAUTH_DESCOPE_CLIENT_SECRET"] - # Ensure that the domain does not have a trailing slash - domain = f"https://api.descope.com/oauth2/v1" - - authorize_url = f"{domain}/authorize" - - def __init__(self): - self.client_id = os.environ.get("OAUTH_DESCOPE_CLIENT_ID") - self.client_secret = os.environ.get("OAUTH_DESCOPE_CLIENT_SECRET") - self.authorize_params = { - "response_type": "code", - "scope": "openid profile email", - "audience": f"{self.domain}/userinfo", - } - - async def get_token(self, code: str, url: str): - payload = { - "client_id": self.client_id, - "client_secret": self.client_secret, - "code": code, - "grant_type": "authorization_code", - "redirect_uri": url, - } - async with httpx.AsyncClient() as client: - response = await client.post( - f"{self.domain}/token", - data=payload, - ) - response.raise_for_status() - json_content = response.json() - token = json_content.get("access_token") - if not token: - raise httpx.HTTPStatusError( - "Failed to get the access token", - request=response.request, - response=response, - ) - return token - - async def get_user_info(self, token: str): - async with httpx.AsyncClient() as client: - response = await client.get( - f"{self.domain}/userinfo", headers={"Authorization": f"Bearer {token}"} - ) - response.raise_for_status() # This will raise an exception for 4xx/5xx responses - user = response.json() - - app_user = AppUser(username=user.get("email"), image="", provider="descope") - return (user, app_user) - - -providers = [ - CognitoOAuthProvider(), - GithubOAuthProvider(), - GoogleOAuthProvider(), - AzureADOAuthProvider(), - OktaOAuthProvider(), - Auth0OAuthProvider(), - DescopeOAuthProvider(), -] - - -def get_oauth_provider(provider: str) -> Optional[OAuthProvider]: - for p in providers: - if p.id == provider: - return p - return None - - -def get_configured_oauth_providers(): - return [p.id for p in providers if p.is_configured()] diff --git a/app/requirements.txt b/app/requirements.txt index 9719f27..c23c75e 100644 --- a/app/requirements.txt +++ b/app/requirements.txt @@ -1,10 +1,11 @@ langchain==0.0.348 openai==1.3.8 -chainlit==0.7.700 +chainlit==1.0.101 numexpr==2.8.7 google-search-results==2.4.2 duckduckgo-search==3.8.5 redis==5.0.1 langserve==0.0.34 sse_starlette==1.8.2 -gunicorn==20.0.4 \ No newline at end of file +gunicorn==20.0.4 +literalai==0.0.102 \ No newline at end of file