# legal-ai-assistant/frontend/app.py
# (Python, 122 lines, 3.7 KiB)

import os
from dotenv import load_dotenv
load_dotenv()
import aiohttp
import asyncio
import chainlit as cl
from chainlit.data.sql_alchemy import SQLAlchemyDataLayer
from chainlit.types import ThreadDict
from frontend.services import (
sync_session_model,
post_agent_request,
parse_sse_stream,
dispatch_events,
finalize_message,
)
from configs import (
ALL_MODELS,
ALL_STARTERS,
DEFAULT_MODEL,
AUTH_USER,
AUTH_PASS,
CHAINLIT_DATABASE_URL,
HTTP_TIMEOUT_TOTAL,
HTTP_TIMEOUT_CONNECT,
)
# Timeout policy shared by every HTTP session opened against the agent
# backend. ``connect`` bounds connection establishment; per the aiohttp docs,
# ``total`` bounds the whole operation, including reading the response, so it
# also caps how long a streamed reply may run.
_CLIENT_TIMEOUT = aiohttp.ClientTimeout(
    connect=HTTP_TIMEOUT_CONNECT,
    total=HTTP_TIMEOUT_TOTAL,
)
@cl.password_auth_callback
def auth_callback(username: str, password: str) -> cl.User | None:
    """Check submitted login credentials against the configured account.

    Args:
        username: Login name entered by the user.
        password: Password entered by the user.

    Returns:
        A ``cl.User`` carrying admin/credentials metadata when both values
        match ``AUTH_USER`` / ``AUTH_PASS``; otherwise ``None``, which rejects
        the login.
    """
    import hmac

    # Refuse every login when the account is not configured, so an unset or
    # empty AUTH_USER / AUTH_PASS can never be matched by an empty submission.
    if not AUTH_USER or not AUTH_PASS:
        return None

    # Constant-time comparisons avoid leaking, through response timing, how
    # much of a guess matched. Both checks always run (no short-circuit on the
    # username). They compare UTF-8 bytes because compare_digest raises
    # TypeError on non-ASCII ``str`` input.
    user_ok = hmac.compare_digest(
        username.encode("utf-8"), AUTH_USER.encode("utf-8")
    )
    pass_ok = hmac.compare_digest(
        password.encode("utf-8"), AUTH_PASS.encode("utf-8")
    )
    if user_ok and pass_ok:
        return cl.User(
            identifier=username,
            metadata={"role": "admin", "provider": "credentials"},
        )
    return None
@cl.set_starters
async def set_starters() -> list[cl.Starter]:
    """Build the welcome-screen starter prompts from ``ALL_STARTERS``.

    Each config entry supplies a ``label``, the ``prompt`` text sent when the
    starter is clicked, and an ``icon`` file name served from
    ``/public/icons/``.
    """
    starters: list[cl.Starter] = []
    for entry in ALL_STARTERS:
        starters.append(
            cl.Starter(
                label=entry["label"],
                message=entry["prompt"],
                icon=f"/public/icons/{entry['icon']}",
            )
        )
    return starters
@cl.set_chat_profiles
async def chat_profile() -> list[cl.ChatProfile]:
    """Expose one selectable chat profile per entry in ``ALL_MODELS``.

    The profile name is the model ``id``; ``desc`` is rendered as markdown.
    ``icon`` is a file name served from ``/public/icons/``.
    """
    profiles: list[cl.ChatProfile] = []
    for spec in ALL_MODELS:
        profiles.append(
            cl.ChatProfile(
                name=spec["id"],
                markdown_description=spec["desc"],
                icon=f"/public/icons/{spec['icon']}",
            )
        )
    return profiles
@cl.on_chat_start
async def init_session() -> None:
    """Start a fresh session: choose the model and reset the history.

    The model is the chat profile the user selected. It falls back to
    ``DEFAULT_MODEL`` when no profile (or an empty one) is set.
    """
    selected_profile = cl.user_session.get("chat_profile")
    cl.user_session.set(
        "model", selected_profile if selected_profile else DEFAULT_MODEL
    )
    cl.user_session.set("history", [])
@cl.data_layer
def get_data_layer() -> SQLAlchemyDataLayer:
    """Provide the SQLAlchemy-backed data layer that persists chat threads.

    The database connection string is taken from ``CHAINLIT_DATABASE_URL``.
    """
    database_url = CHAINLIT_DATABASE_URL
    return SQLAlchemyDataLayer(conninfo=database_url)
@cl.on_chat_resume
async def chat_resume(thread: ThreadDict) -> None:
    """Restore session state when a persisted thread is reopened.

    Rebuilds the role-tagged ``history`` list from the thread's
    ``user_message`` and ``assistant_message`` steps, in thread order. Steps
    of any other type are skipped. Also re-selects the model from the chat
    profile, falling back to ``DEFAULT_MODEL``.
    """
    role_by_step_type = {
        "user_message": "user",
        "assistant_message": "assistant",
    }
    restored = [
        {"role": role_by_step_type[step["type"]], "content": step["output"]}
        for step in thread["steps"]
        if step["type"] in role_by_step_type
    ]
    selected_profile = cl.user_session.get("chat_profile")
    cl.user_session.set("model", selected_profile or DEFAULT_MODEL)
    cl.user_session.set("history", restored)
@cl.on_message
async def on_message(message: cl.Message) -> None:
    """Forward a user message to the agent backend and stream back its reply.

    Steps:
    1. Append the user's turn to the session history.
    2. POST it, with the selected model and the history, through
       ``post_agent_request``.
    3. Parse the response with ``parse_sse_stream`` and pass each parsed
       event to ``dispatch_events``. That helper updates ``msg`` and the
       "thinking" step.
    4. Hand the possibly partial reply to ``finalize_message`` and store the
       history it returns in the session.

    Error handling:
    - Task cancellation is swallowed on purpose, so the partial answer is
      still finalized.
    - Any other failure is logged with its traceback and shown as "Error" on
      the thinking step.
    """
    import logging

    model, history = sync_session_model()
    history.append({"role": "user", "content": message.content})
    msg = cl.Message(content="")
    cancelled = False
    # Per-request state passed through to dispatch_events. Presumably it
    # tracks open sub-steps across events; TODO confirm in frontend.services.
    active_steps = {}
    async with cl.Step(name="thinking 💭", type="run") as steps:
        try:
            async with aiohttp.ClientSession(timeout=_CLIENT_TIMEOUT) as session:
                response = await post_agent_request(
                    session, message.content, model, history
                )
                async for parsed in parse_sse_stream(response):
                    await dispatch_events(parsed, msg, steps, active_steps)
        except asyncio.CancelledError:
            # Deliberately not re-raised: the partial reply is still
            # finalized below, with ``cancelled`` set.
            cancelled = True
            steps.output = "Cancelled"
        except Exception:
            # Previously swallowed silently; keep the user-facing "Error"
            # marker but record the real cause for diagnosis.
            logging.getLogger(__name__).exception(
                "Agent request failed (model=%s)", model
            )
            steps.output = "Error"
    history = await finalize_message(msg, cancelled, history)
    cl.user_session.set("history", history)