79 lines
2.7 KiB
Python
79 lines
2.7 KiB
Python
import json
|
|
import asyncio
|
|
from typing import AsyncIterator
|
|
from pydantic import BaseModel
|
|
from fastapi import APIRouter
|
|
from fastapi.responses import StreamingResponse
|
|
from backend.agent import build_agent, make_mcp_server, stream_response
|
|
from configs import DEFAULT_MODEL
|
|
from backend.logger import setup_logger
|
|
|
|
# Module-level logger obtained from the project's setup_logger helper.
logger = setup_logger(__name__)

# Router collecting this module's HTTP endpoints.
router = APIRouter()
|
|
|
|
class Message(BaseModel):
    """One earlier conversation turn supplied by the client in ``Query.history``."""

    # Speaker of the turn; build_messages forwards it to the agent unchanged.
    role: str
    # Text of the turn.
    content: str
|
|
|
|
class Query(BaseModel):
    """Request body accepted by ``POST /api/run``.

    Carries the user's current question, the model to answer it with, and
    any earlier conversation turns.
    """

    # The user's current question; becomes the final "user" message.
    query: str
    # Model identifier handed to build_agent as ``model_name``.
    model: str = DEFAULT_MODEL
    # Earlier turns, forwarded to the agent in the order given.
    # NOTE: pydantic copies mutable defaults for each instance, so this
    # empty list is not shared between requests.
    history: list[Message] = []
|
|
|
|
def build_messages(query: Query) -> list[dict]:
    """Build the chat message list that is sent to the agent.

    Each turn in ``query.history`` is copied, in order, into a plain
    ``{"role": ..., "content": ...}`` dict. The current ``query.query`` is
    then added as the final ``"user"`` message.

    Returns:
        A new list; ``query`` is not modified.
    """
    prior_turns = [
        {"role": turn.role, "content": turn.content}
        for turn in query.history
    ]
    current_turn = {"role": "user", "content": query.query}
    return [*prior_turns, current_turn]
|
|
|
|
async def run_agent_task(query: Query, queue: asyncio.Queue, messages: list[dict]) -> None:
    """Connect to the MCP server, build the agent and stream its events into ``queue``.

    Every event produced by ``stream_response`` is put on ``queue`` unchanged.
    If anything raises an ``Exception`` along the way, the error is logged and
    a single ``{"type": "error", "data": ...}`` event is queued. This includes
    creating the MCP server. A ``None`` sentinel is always queued last,
    including on cancellation, so the consumer can tell that the stream has
    ended.

    Args:
        query: The incoming request; only ``query.model`` is read here.
        queue: Queue drained by the SSE consumer (``generate_response``).
        messages: Prior conversation turns plus the current user message.
    """
    try:
        # Created inside ``try`` so that a failure here still yields an error
        # event and the sentinel below. Otherwise the consumer would wait on
        # ``queue.get()`` forever.
        mcp_server = make_mcp_server()
        async with mcp_server:
            logger.info("[MCP CONNECTED]")
            agent = build_agent(mcp_server=mcp_server, model_name=query.model)
            logger.info(f"[AGENT MODEL] | {query.model}")
            async for event in stream_response(agent, messages):
                await queue.put(event)
    except Exception as e:
        # asyncio.CancelledError is not an Exception subclass (3.8+), so
        # cancellation skips this branch but still reaches ``finally``.
        logger.error(f"[AGENT TASK ERROR] | {e}", exc_info=True)
        await queue.put({"type": "error", "data": f"\u26A0 {str(e)}"})
    finally:
        # End-of-stream sentinel: generate_response emits [DONE] and stops.
        await queue.put(None)
|
|
|
|
async def generate_response(queue: asyncio.Queue, task: asyncio.Task[None]) -> AsyncIterator[str]:
    """Read events from ``queue`` and yield them as Server-Sent Events chunks.

    Each event is JSON-encoded and yielded as a ``data:`` line followed by a
    blank line. When the ``None`` sentinel arrives, a final ``data: [DONE]``
    chunk is yielded and iteration stops.

    If iteration ends before the sentinel arrives, the producer ``task`` is
    cancelled so the agent does not keep running with nobody listening.
    Early exits include the generator being closed or cancelled (presumably
    when the client disconnects) and an event failing to serialize. In every
    case the task is awaited before the generator finishes. The task's own
    exceptions are not re-raised here, because ``return_exceptions=True`` is
    used.

    Args:
        queue: Queue filled by the producer task; ``None`` marks the end.
        task: The producer task writing to ``queue``.

    Yields:
        SSE-formatted strings.
    """
    producer_finished = False
    try:
        while True:
            token = await queue.get()
            if token is None:
                producer_finished = True
                yield "data: [DONE]\n\n"
                break
            yield f"data: {json.dumps(token)}\n\n"
    finally:
        # Runs on normal completion, cancellation (CancelledError is not
        # caught, so it still propagates), aclose() (GeneratorExit) and
        # errors. Only stop the producer if it has not already signalled
        # completion.
        if not producer_finished:
            task.cancel()
        await asyncio.gather(task, return_exceptions=True)
|
|
|
|
@router.post("/api/run")
async def run_agent(query: Query) -> StreamingResponse:
    """
    Run the Legal AI Assistant for one request and stream its output.

    The agent runs in a background task (``run_agent_task``) that pushes
    events onto an ``asyncio.Queue``. ``generate_response`` drains that queue
    into the response body.

    Args:
        query: Request body with the user's question, the model identifier
            (e.g. 'llama-3.1-8b') and any earlier conversation turns.

    Returns:
        A ``text/event-stream`` StreamingResponse carrying the agent's events.
    """
    logger.info(f"[REQUEST] | query={query.query} | model={query.model}")

    chat_messages = build_messages(query)
    event_queue = asyncio.Queue()
    agent_task = asyncio.create_task(run_agent_task(query, event_queue, chat_messages))

    # Disable client caching, and ask nginx-style reverse proxies not to
    # buffer the stream.
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        generate_response(event_queue, agent_task),
        media_type="text/event-stream",
        headers=sse_headers,
    )