122 lines
3.7 KiB
Python
122 lines
3.7 KiB
Python
import os
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
import aiohttp
|
|
import asyncio
|
|
import chainlit as cl
|
|
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
|
|
from chainlit.types import ThreadDict
|
|
|
|
from frontend.services import (
|
|
sync_session_model,
|
|
post_agent_request,
|
|
parse_sse_stream,
|
|
dispatch_events,
|
|
finalize_message,
|
|
)
|
|
from configs import (
|
|
ALL_MODELS,
|
|
ALL_STARTERS,
|
|
DEFAULT_MODEL,
|
|
AUTH_USER,
|
|
AUTH_PASS,
|
|
CHAINLIT_DATABASE_URL,
|
|
HTTP_TIMEOUT_TOTAL,
|
|
HTTP_TIMEOUT_CONNECT,
|
|
)
|
|
|
|
# Timeout policy handed to every aiohttp.ClientSession this module opens;
# both limits come from the project configuration.
_CLIENT_TIMEOUT = aiohttp.ClientTimeout(total=HTTP_TIMEOUT_TOTAL, connect=HTTP_TIMEOUT_CONNECT)
|
|
|
|
@cl.password_auth_callback
def auth_callback(username: str, password: str) -> cl.User | None:
    """Validate submitted credentials against the configured account.

    Args:
        username: Login name supplied by the client.
        password: Password supplied by the client.

    Returns:
        A ``cl.User`` identified by ``username`` with ``role="admin"`` and
        ``provider="credentials"`` metadata when both values match
        ``AUTH_USER`` and ``AUTH_PASS``; ``None`` otherwise.
    """
    import hmac  # function-scope import keeps this block self-contained

    # compare_digest only accepts str/bytes operands. A non-string configured
    # value could never have equalled the submitted strings, so the outcome
    # (rejection) is the same as with a plain equality test.
    if not isinstance(AUTH_USER, str) or not isinstance(AUTH_PASS, str):
        return None

    def _as_bytes(text: str) -> bytes:
        # Bytes are compared because compare_digest rejects non-ASCII str;
        # "surrogatepass" keeps odd client input from raising on encode.
        return text.encode("utf-8", "surrogatepass")

    # Constant-time comparison of untrusted input. Both checks are always
    # evaluated so the response time does not reveal which field was wrong.
    user_ok = hmac.compare_digest(_as_bytes(username), _as_bytes(AUTH_USER))
    pass_ok = hmac.compare_digest(_as_bytes(password), _as_bytes(AUTH_PASS))
    if not (user_ok and pass_ok):
        return None

    return cl.User(
        identifier=username,
        metadata={"role": "admin", "provider": "credentials"},
    )
|
|
|
|
@cl.set_starters
async def set_starters() -> list[cl.Starter]:
    """Build one ``cl.Starter`` per entry in ``ALL_STARTERS``.

    Each entry supplies the ``label``, ``prompt`` and ``icon`` keys; the icon
    name is resolved under ``/public/icons/``.
    """
    starters: list[cl.Starter] = []
    for entry in ALL_STARTERS:
        label = entry["label"]
        prompt = entry["prompt"]
        icon_path = f"/public/icons/{entry['icon']}"
        starters.append(cl.Starter(label=label, message=prompt, icon=icon_path))
    return starters
|
|
|
|
@cl.set_chat_profiles
async def chat_profile() -> list[cl.ChatProfile]:
    """Build one ``cl.ChatProfile`` per entry in ``ALL_MODELS``.

    Each entry supplies the ``id``, ``desc`` and ``icon`` keys; the icon name
    is resolved under ``/public/icons/``.
    """
    profiles: list[cl.ChatProfile] = []
    for entry in ALL_MODELS:
        profile_name = entry["id"]
        description = entry["desc"]
        icon_path = f"/public/icons/{entry['icon']}"
        profiles.append(
            cl.ChatProfile(
                name=profile_name,
                markdown_description=description,
                icon=icon_path,
            )
        )
    return profiles
|
|
|
|
@cl.on_chat_start
async def init_session() -> None:
    """Seed a new chat session with its model and an empty history.

    The model is the session's ``chat_profile`` value, or ``DEFAULT_MODEL``
    when that value is missing or falsy.
    """
    session = cl.user_session
    selected = session.get("chat_profile")
    session.set("model", selected if selected else DEFAULT_MODEL)
    session.set("history", [])
|
|
|
|
@cl.data_layer
def get_data_layer() -> SQLAlchemyDataLayer:
    """Create the SQLAlchemy data layer connected via ``CHAINLIT_DATABASE_URL``."""
    data_layer = SQLAlchemyDataLayer(conninfo=CHAINLIT_DATABASE_URL)
    return data_layer
|
|
|
|
@cl.on_chat_resume
async def chat_resume(thread: ThreadDict) -> None:
    """Rebuild session state from the steps of a resumed thread.

    Steps of type ``user_message`` and ``assistant_message`` become
    ``{"role", "content"}`` entries (content taken from the step's
    ``output``); every other step type is ignored. The session's ``model``
    and ``history`` values are then stored.
    """
    restored = []
    for entry in thread["steps"]:
        if entry["type"] == "user_message":
            role = "user"
        elif entry["type"] == "assistant_message":
            role = "assistant"
        else:
            # Not a chat message; nothing to add to the history.
            continue
        restored.append({"role": role, "content": entry["output"]})

    session = cl.user_session
    selected = session.get("chat_profile")
    session.set("model", selected if selected else DEFAULT_MODEL)
    session.set("history", restored)
|
|
|
|
@cl.on_message
async def on_message(message: cl.Message) -> None:
    """Handle an incoming user message and stream the agent's reply.

    The user's text is appended to the session history, sent through
    ``post_agent_request``, and every event yielded by ``parse_sse_stream``
    is passed to ``dispatch_events``. Whatever the outcome, the reply is
    handed to ``finalize_message`` and the history it returns is stored back
    in the session.

    Args:
        message: The message received from the user.
    """
    import logging  # function-scope import keeps this block self-contained

    model, history = sync_session_model()
    history.append({"role": "user", "content": message.content})

    msg = cl.Message(content="")
    cancelled = False
    active_steps = {}

    async with cl.Step(name="thinking 💭", type="run") as steps:
        try:
            async with aiohttp.ClientSession(timeout=_CLIENT_TIMEOUT) as session:
                response = await post_agent_request(session, message.content, model, history)
                async for parsed in parse_sse_stream(response):
                    await dispatch_events(parsed, msg, steps, active_steps)
        except asyncio.CancelledError:
            # Not re-raised on purpose: the flag is forwarded to
            # finalize_message below so the reply can still be finalized.
            cancelled = True
            steps.output = "Cancelled"
        except Exception:
            # Record the traceback instead of discarding it; the user-facing
            # step output stays generic so internals are not exposed.
            logging.getLogger(__name__).exception("Agent request failed")
            steps.output = "Error"

    # NOTE(review): finalization is placed after the step block has closed;
    # confirm against the original indentation.
    history = await finalize_message(msg, cancelled, history)
    cl.user_session.set("history", history)