legal-ai-assistant/backend/routers/run_agent.py
2026-05-29 14:04:54 +02:00

80 lines
2.7 KiB
Python

import json
import asyncio
from typing import AsyncIterator
from pydantic import BaseModel
from fastapi import APIRouter
from fastapi.responses import StreamingResponse
from backend.agent import build_agent, make_mcp_server, stream_response
from configs import DEFAULT_MODEL
from backend.logger import setup_logger
# Module-level logger for this router, obtained from the project's setup_logger helper.
logger = setup_logger(__name__)
# Router that carries the /api/run endpoint defined below; presumably mounted
# on the FastAPI app elsewhere — confirm.
router = APIRouter()
class Message(BaseModel):
    """One earlier conversation turn supplied by the client in ``Query.history``."""

    # Speaker of the turn; copied verbatim into the agent's message dict as
    # "role". Not validated here — presumably "user" / "assistant"; confirm
    # which roles the agent accepts.
    role: str
    # Text of the turn, copied verbatim as "content".
    content: str
class Query(BaseModel):
    """Request body accepted by ``POST /api/run``."""

    # The user's new question; appended last, with role "user", to the
    # messages sent to the agent.
    query: str
    # Model identifier handed to build_agent as model_name; falls back to
    # configs.DEFAULT_MODEL when the client omits it.
    model: str = DEFAULT_MODEL
    # Earlier turns, placed before the new question in the agent's message
    # list; empty for a fresh conversation.
    # NOTE: a mutable `[]` default is fine on a pydantic model — pydantic
    # copies defaults per instance, unlike a plain function default argument.
    history: list[Message] = []
def build_messages(query: Query) -> list[dict]:
    """Flatten a request into the chat-message list handed to the agent.

    Prior turns from ``query.history`` come first, in the order given, each
    reduced to a plain ``{"role": ..., "content": ...}`` dict. The current
    ``query.query`` text is appended last as a ``"user"`` message.

    Args:
        query: Incoming request carrying the history and the new question.

    Returns:
        A freshly built list of message dicts; ``query`` is not modified.
    """
    past_turns = [
        dict(role=turn.role, content=turn.content) for turn in query.history
    ]
    current_turn = {"role": "user", "content": query.query}
    return [*past_turns, current_turn]
async def run_agent_task(query: Query, queue: asyncio.Queue, messages: list[dict]) -> None:
    """Connect to the MCP server, run the agent and push its events into *queue*.

    Args:
        query: The request; only ``query.model`` is read here, to choose the
            model passed to ``build_agent``.
        queue: Destination for the events yielded by ``stream_response``.
            A single ``None`` sentinel is always enqueued last — on success,
            on error and on cancellation — so the consumer knows to stop.
        messages: Chat messages passed to ``stream_response``.

    Any ``Exception`` raised while creating or connecting the MCP server,
    building the agent or streaming is logged and forwarded to the consumer
    as a ``{"type": "error", "data": ...}`` event instead of propagating.
    Cancellation (``asyncio.CancelledError``, a ``BaseException``) is not
    caught and propagates after the sentinel is enqueued.
    """
    try:
        # Created inside the try: if construction fails, the error still
        # reaches the client and the `finally` still enqueues the sentinel.
        # Otherwise the SSE consumer would block on queue.get() forever.
        mcp_server = make_mcp_server()
        async with mcp_server:
            logger.info("[MCP CONNECTED]")
            agent = build_agent(mcp_server=mcp_server, model_name=query.model)
            logger.info(f"[AGENT MODEL] | {query.model}")
            async for event in stream_response(agent, messages):
                await queue.put(event)
    except Exception as e:
        logger.error(f"[AGENT TASK ERROR] | {e}", exc_info=True)
        await queue.put({"type": "error", "data": f"\u26A0 {e}"})
    finally:
        # End-of-stream marker for the consumer; must be the last item.
        await queue.put(None)
async def generate_response(queue: asyncio.Queue, task: asyncio.Task[None]) -> AsyncIterator[str]:
    """Drain agent events from *queue* and yield them as Server-Sent Events.

    Each event is JSON-encoded into a ``data: ...`` frame. The ``None``
    sentinel produces a final ``data: [DONE]`` frame and ends the stream.

    However the stream ends, the producer *task* is cancelled if it is still
    running and is then awaited, so it does not outlive the response. This
    covers normal completion, cancellation, ``aclose()``, or an error such as
    an event ``json.dumps`` cannot encode. Cancellation of this generator
    propagates to the caller; it is not swallowed.

    Args:
        queue: Events produced by *task*; ``None`` marks the end.
        task: The producer task feeding *queue*.

    Yields:
        SSE-formatted strings.
    """
    try:
        while True:
            token = await queue.get()
            if token is None:
                yield "data: [DONE]\n\n"
                break
            yield f"data: {json.dumps(token)}\n\n"
    finally:
        # On the normal path the producer has already returned after
        # enqueuing the sentinel, so this only fires on early exit.
        if not task.done():
            task.cancel()
        # return_exceptions=True: collect the task's outcome (including
        # CancelledError) so nothing is reported as "never retrieved".
        await asyncio.gather(task, return_exceptions=True)
@router.post("/api/run")
async def run_agent(query: Query) -> StreamingResponse:
    """
    Start the Legal AI Assistant for one request and stream its output.

    The agent runs in a background task that fills an event queue; the
    response body drains that queue as Server-Sent Events.

    Args:
        query: Request body — the user's question, the model identifier
            (e.g. 'llama-3.1-8b') and any earlier conversation turns.

    Returns:
        A ``text/event-stream`` response of JSON-encoded agent events,
        terminated by a ``data: [DONE]`` frame.
    """
    logger.info(f"[REQUEST] | query={query.query} | model={query.model}")
    event_queue = asyncio.Queue()
    agent_task = asyncio.create_task(
        run_agent_task(query, event_queue, build_messages(query))
    )
    # Disable client caching and proxy (e.g. nginx) buffering so events
    # are delivered as soon as they are produced.
    sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(
        generate_response(event_queue, agent_task),
        media_type="text/event-stream",
        headers=sse_headers,
    )