-
Notifications
You must be signed in to change notification settings - Fork 110
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from StreetLamb/backend/langchain
Setup endpoints and graph module to build and run graph
- Loading branch information
Showing
9 changed files
with
1,771 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,10 @@ | ||
from fastapi import APIRouter | ||
|
||
from app.api.routes import items, login, users, utils | ||
from app.api.routes import items, login, users, utils, teams | ||
|
||
api_router = APIRouter() | ||
api_router.include_router(login.router, tags=["login"]) | ||
api_router.include_router(users.router, prefix="/users", tags=["users"]) | ||
api_router.include_router(utils.router, prefix="/utils", tags=["utils"]) | ||
api_router.include_router(items.router, prefix="/items", tags=["items"]) | ||
api_router.include_router(teams.router, prefix="/teams", tags=["teams"]) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,152 @@ | ||
from typing import Any | ||
from sqlmodel import func, select | ||
from fastapi import APIRouter, HTTPException | ||
from fastapi.responses import StreamingResponse | ||
|
||
from app.core.graph.build import generator | ||
from app.api.deps import CurrentUser, SessionDep | ||
from app.models import TeamChat, TeamsOut, TeamCreate, TeamOut, Team, Message | ||
|
||
# TODO: To remove | ||
teams = { | ||
"FoodExpertLeader": { | ||
"name": "FoodExperts", | ||
"members": { | ||
"ChineseFoodExpert": { | ||
"type": "worker", | ||
"name": "ChineseFoodExpert", | ||
"backstory": "Studied culinary school in Singapore. Well-verse in hawker to fine-dining experiences. ISFP.", | ||
"role": "Provide chinese food suggestions in Singapore", | ||
"tools": [] | ||
}, | ||
"MalayFoodExpert": { | ||
"type": "worker", | ||
"name": "MalayFoodExpert", | ||
"backstory": "Studied culinary school in Singapore. Well-verse in hawker to fine-dining experiences. INTP.", | ||
"role": "Provide malay food suggestions in Singapore", | ||
"tools": [] | ||
}, | ||
} | ||
}, | ||
"TravelExpertLeader": { | ||
"name": "TravelKakis", | ||
"members": { | ||
"FoodExpertLeader": { | ||
"type": "leader", | ||
"name": "FoodExpertLeader", | ||
"role": "Gather inputs from your team and provide a diverse food suggestions in Singapore.", | ||
"tools": [] | ||
}, | ||
"HistoryExpert": { | ||
"type": "worker", | ||
"name": "HistoryExpert", | ||
"backstory": "Studied Singapore history. Well-verse in Singapore architecture. INTJ.", | ||
"role": "Provide places to sight-see with a history/architecture angle", | ||
"tools": [] | ||
} | ||
} | ||
} | ||
} | ||
team_leader = "TravelExpertLeader" | ||
|
||
router = APIRouter() | ||
|
||
@router.get("/", response_model=TeamsOut) | ||
def read_teams( | ||
session: SessionDep, current_user: CurrentUser, skip: int = 0, limit: int = 100 | ||
) -> Any: | ||
""" | ||
Retrieve teams | ||
""" | ||
|
||
if current_user.is_superuser: | ||
count_statement = select(func.count()).select_from(Team) | ||
count = session.exec(count_statement).one() | ||
statement = select(Team).offset(skip).limit(limit) | ||
teams = session.exec(statement).all() | ||
else: | ||
count_statement = ( | ||
select(func.count()) | ||
.select_from(Team) | ||
.where(Team.owner_id == current_user.id) | ||
) | ||
count = session.exec(count_statement).one() | ||
statement = ( | ||
select(Team) | ||
.where(Team.owner_id == current_user.id) | ||
.offset(skip) | ||
.limit(limit) | ||
) | ||
teams = session.exec(statement).all() | ||
return TeamsOut(data=teams, count=count) | ||
|
||
@router.get("/{id}", response_model=TeamOut) | ||
def read_team(session: SessionDep, current_user: CurrentUser, id: int) -> Any: | ||
""" | ||
Get team by ID. | ||
""" | ||
team = session.get(Team, id) | ||
if not team: | ||
raise HTTPException(status_code=404, detail="Team not found") | ||
if not current_user.is_superuser and (team.owner_id != current_user.id): | ||
raise HTTPException(status_code=400, detail="Not enough permissions") | ||
return team | ||
|
||
@router.post("/", response_model=TeamOut) | ||
def create_team( | ||
*, session: SessionDep, current_user: CurrentUser, team_in: TeamCreate | ||
) -> Any: | ||
""" | ||
Create new team. | ||
""" | ||
team = Team.model_validate(team_in, update={"owner_id": current_user.id}) | ||
session.add(team) | ||
session.commit() | ||
session.refresh(team) | ||
return team | ||
|
||
@router.put("/{id}", response_model=TeamOut) | ||
def update_team( | ||
*, session: SessionDep, current_user: CurrentUser, id: int, team_in: TeamCreate | ||
) -> Any: | ||
""" | ||
Update a team. | ||
""" | ||
team = session.get(Team, id) | ||
if not team: | ||
raise HTTPException(status_code=404, detail="Team not found") | ||
if not current_user.is_superuser and (team.owner_id != current_user.id): | ||
raise HTTPException(status_code=400, detail="Not enough permissions") | ||
update_dict = team_in.model_dump(exclude_unset=True) | ||
team.sqlmodel_update(update_dict) | ||
session.add(team) | ||
session.commit() | ||
session.refresh(team) | ||
return team | ||
|
||
@router.delete("/{id}") | ||
def delete_team(session: SessionDep, current_user: CurrentUser, id: int) -> Any: | ||
""" | ||
Delete a team. | ||
""" | ||
team = session.get(Team, id) | ||
if not team: | ||
raise HTTPException(status_code=404, detail="Team not found") | ||
if not current_user.is_superuser and (team.owner_id != current_user.id): | ||
raise HTTPException(status_code=400, detail="Not enough permissions") | ||
session.delete(team) | ||
session.commit() | ||
return Message(message="Team deleted successfully") | ||
|
||
@router.post("/{id}/stream") | ||
async def stream(session: SessionDep, current_user: CurrentUser, id: int, team_chat: TeamChat): | ||
""" | ||
Stream a response to a user's input. | ||
""" | ||
team = session.get(Team, id) | ||
if not team: | ||
raise HTTPException(status_code=404, detail="Team not found") | ||
if not current_user.is_superuser and (team.owner_id != current_user.id): | ||
raise HTTPException(status_code=400, detail="Not enough permissions") | ||
|
||
return StreamingResponse(generator(teams, team_leader, team_chat.messages), media_type="text/event-stream") |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
from functools import partial | ||
from typing import Dict, List | ||
from app.models import ChatMessage | ||
from langgraph.graph import StateGraph, END | ||
from langchain_openai import ChatOpenAI | ||
from app.core.graph.members import Leader, LeaderNode, Member, SummariserNode, TeamState, WorkerNode | ||
from langchain_core.messages import HumanMessage, AIMessage | ||
from langchain_core.runnables import RunnableLambda | ||
|
||
model = ChatOpenAI(model="gpt-3.5-turbo") | ||
|
||
# Create the Member/Leader class instance in members | ||
def format_teams(teams: Dict[str, any]): | ||
"""Update the team members to use Member/Leader""" | ||
for team in teams: | ||
members = teams[team]["members"] | ||
for k,v in members.items(): | ||
print(v) | ||
teams[team]["members"][k] = Leader(**v) if v["type"] == "leader" else Member(**v) | ||
return teams | ||
|
||
def router(state: TeamState): | ||
return state["next"] | ||
|
||
def enter_chain(state: TeamState, team: Dict[str, str | List[Member | Leader]]): | ||
""" | ||
Initialise the sub-graph state. | ||
This makes it so that the states of each graph don't get intermixed. | ||
""" | ||
task = state["task"] | ||
team_name = team["name"] | ||
team_members = team["members"] | ||
|
||
results = { | ||
"messages": task, | ||
"team_name": team_name, | ||
"team_members": team_members, | ||
} | ||
return results | ||
|
||
def exit_chain(state: TeamState): | ||
""" | ||
Pass the final response back to the top-level graph's state. | ||
""" | ||
answer = state["messages"][-1] | ||
return {"messages": [answer]} | ||
|
||
def create_graph(teams: Dict[str, Dict[str, str | Dict[str, Member | Leader]]], leader_name: str): | ||
""" | ||
Create the team's graph. | ||
""" | ||
build = StateGraph(TeamState) | ||
# Add the start and end node | ||
build.add_node(leader_name, RunnableLambda(LeaderNode(model).delegate)) | ||
build.add_node("summariser", RunnableLambda(SummariserNode(model).summarise)) | ||
|
||
members = teams[leader_name]["members"] | ||
for name, member in members.items(): | ||
if isinstance(member, Member): | ||
build.add_node(name, RunnableLambda(WorkerNode(model).work)) | ||
elif isinstance(member, Leader): | ||
subgraph = create_graph(teams, leader_name=name) | ||
enter = partial(enter_chain, team=teams[name]) | ||
build.add_node(name, enter | subgraph | exit_chain) | ||
else: | ||
continue | ||
build.add_edge(name, leader_name) | ||
|
||
conditional_mapping = {v:v for v in members} | ||
conditional_mapping["FINISH"] = "summariser" | ||
build.add_conditional_edges(leader_name, router, conditional_mapping) | ||
|
||
build.set_entry_point(leader_name) | ||
build.set_finish_point("summariser") | ||
graph = build.compile() | ||
return graph | ||
|
||
|
||
|
||
async def generator(teams: dict, team_leader: str, messages: List[ChatMessage]): | ||
"""Create the graph and strem the response""" | ||
format_teams(teams) | ||
root = create_graph(teams, leader_name=team_leader) | ||
messages = [HumanMessage(message.content) if message.type == "human" else AIMessage(message.content) for message in messages] | ||
|
||
async for output in root.astream({ | ||
"messages": messages, | ||
"team_name": teams[team_leader]["name"], | ||
"team_members": teams[team_leader]["members"] | ||
}): | ||
for key, value in output.items(): | ||
if key != "__end__": | ||
response = {key :value} | ||
formatted_output = f"data: {response}\n\n" | ||
print(formatted_output) | ||
yield formatted_output | ||
|
||
# teams = { | ||
# "FoodExpertLeader": { | ||
# "name": "FoodExperts", | ||
# "members": { | ||
# "ChineseFoodExpert": { | ||
# "type": "worker", | ||
# "name": "ChineseFoodExpert", | ||
# "backstory": "Studied culinary school in Singapore. Well-verse in hawker to fine-dining experiences. ISFP.", | ||
# "role": "Provide chinese food suggestions in Singapore", | ||
# "tools": [] | ||
# }, | ||
# "MalayFoodExpert": { | ||
# "type": "worker", | ||
# "name": "MalayFoodExpert", | ||
# "backstory": "Studied culinary school in Singapore. Well-verse in hawker to fine-dining experiences. INTP.", | ||
# "role": "Provide malay food suggestions in Singapore", | ||
# "tools": [] | ||
# }, | ||
# } | ||
# }, | ||
# "TravelExpertLeader": { | ||
# "name": "TravelKakis", | ||
# "members": { | ||
# "FoodExpertLeader": { | ||
# "type": "leader", | ||
# "name": "FoodExpertLeader", | ||
# "role": "Gather inputs from your team and provide a diverse food suggestions in Singapore.", | ||
# "tools": [] | ||
# }, | ||
# "HistoryExpert": { | ||
# "type": "worker", | ||
# "name": "HistoryExpert", | ||
# "backstory": "Studied Singapore history. Well-verse in Singapore architecture. INTJ.", | ||
# "role": "Provide places to sight-see with a history/architecture angle", | ||
# "tools": ["search"] | ||
# } | ||
# } | ||
# } | ||
# } | ||
|
||
# format_teams(teams) | ||
|
||
# team_leader = "TravelExpertLeader" | ||
|
||
# root = create_graph(teams, team_leader) | ||
|
||
# messages = [ | ||
# HumanMessage(f"What is the best food in Singapore") | ||
# ] | ||
|
||
# initial_state = { | ||
# "messages": messages, | ||
# "team_name": teams[team_leader]["name"], | ||
# "team_members": teams[team_leader]["members"], | ||
# } | ||
|
||
# async def main(): | ||
# async for s in root.astream(initial_state): | ||
# if "__end__" not in s: | ||
# print(s) | ||
# print("----") | ||
|
||
# import asyncio | ||
|
||
# asyncio.run(main()) |
Oops, something went wrong.