43 lines
1.7 KiB
Python
43 lines
1.7 KiB
Python
import json
|
|
import aiohttp
|
|
import chainlit as cl
|
|
from typing import AsyncIterator
|
|
from configs import DEFAULT_MODEL, BACKEND_BASE_URL
|
|
|
|
def sync_session_model() -> tuple[str, list]:
    """Keep the session's model in step with the selected chat profile.

    The selected profile is read from the ``chat_profile`` key of the
    Chainlit user session, falling back to ``DEFAULT_MODEL`` when that key
    is missing or falsy. If it differs from the session's stored ``model``,
    the stored model is replaced and the session ``history`` is reset to a
    new empty list.

    Returns:
        A ``(model, history)`` tuple, both read back from the user session
        after any reset.
    """
    session = cl.user_session
    selected = session.get("chat_profile") or DEFAULT_MODEL

    # On a profile change, switch the model and start the history over.
    if session.get("model") != selected:
        session.set("model", selected)
        session.set("history", [])

    return session.get("model"), session.get("history")
|
|
|
|
async def post_agent_request(
    session: aiohttp.ClientSession,
    query: str,
    model: str,
    history: list,
) -> aiohttp.ClientResponse:
    """POST a query to ``BACKEND_BASE_URL`` and return the raw response.

    The JSON body has three keys: ``query``, ``model``, and ``history``.
    ``history`` contains every entry of *history* except the last one.
    NOTE(review): this presumably assumes the caller has already appended
    the current user turn to *history*, since that turn is sent separately
    as ``query`` — confirm against callers.

    The response is returned without being read or closed. The caller is
    responsible for consuming it and releasing it (for example with
    ``async with`` or ``response.release()``).
    """
    prior_turns = history[:-1]
    payload = {"query": query, "model": model, "history": prior_turns}
    response = await session.post(BACKEND_BASE_URL, json=payload)
    return response
|
|
|
|
async def parse_sse_stream(response: "aiohttp.ClientResponse") -> AsyncIterator[dict]:
    """Yield the JSON payload of each ``data`` line in an SSE response body.

    The body is iterated line by line. Each line is decoded as UTF-8 and
    stripped of surrounding whitespace. Lines that are not ``data`` fields
    (comments, other fields such as ``event:``/``id:``, and blank event
    separators) are skipped, as are ``data`` fields with an empty value.
    Iteration stops at the ``[DONE]`` sentinel.

    Following the Server-Sent Events specification, a single space after
    the ``data:`` colon is optional. Both ``data: {...}`` and ``data:{...}``
    are therefore accepted.

    Each ``data`` line is parsed independently. Multi-line ``data`` events
    are not joined.

    Raises:
        UnicodeDecodeError: if a line is not valid UTF-8.
        json.JSONDecodeError: if a non-empty payload is not valid JSON.
    """
    async for raw_line in response.content:
        line = raw_line.decode("utf-8").strip()
        if not line.startswith("data:"):
            continue

        # The SSE spec removes at most one leading space from a field value.
        data = line.removeprefix("data:").removeprefix(" ")
        if not data:
            # A bare ``data:`` field carries no JSON to parse.
            continue
        if data == "[DONE]":
            break

        yield json.loads(data)
|
|
|
|
async def finalize_message(msg: cl.Message, cancelled: bool, history: list) -> list:
    """Push the final state of *msg* and record it as an assistant turn.

    ``msg.update()`` is always awaited first. Afterwards, an
    ``{"role": "assistant", "content": ...}`` entry is appended to
    *history* in place, but only when the message has non-empty content
    and *cancelled* is false.

    Returns:
        The same *history* list object that was passed in.
    """
    await msg.update()

    if cancelled or not msg.content:
        return history

    history.append({"role": "assistant", "content": msg.content})
    return history