import json
import asyncio
from typing import AsyncIterator

from pydantic import BaseModel
from fastapi import APIRouter
from fastapi.responses import StreamingResponse

from backend.agent import build_agent, make_mcp_server, stream_response
from configs import DEFAULT_MODEL
from backend.logger import setup_logger

logger = setup_logger(__name__)
router = APIRouter()


class Message(BaseModel):
    """A single prior chat turn supplied by the client."""

    role: str
    content: str


class Query(BaseModel):
    """Request body for ``POST /api/run``."""

    query: str  # the new user message for this turn
    model: str = DEFAULT_MODEL  # forwarded to build_agent as model_name
    history: list[Message] = []  # pydantic copies mutable defaults per instance


def build_messages(query: Query) -> list[dict]:
    """Build the chat message list handed to the agent.

    Prior turns from ``query.history`` come first, in their original order,
    followed by ``query.query`` appended as a ``"user"`` message.
    """
    messages = [{"role": m.role, "content": m.content} for m in query.history]
    messages.append({"role": "user", "content": query.query})
    return messages


async def run_agent_task(query: Query, queue: asyncio.Queue, messages: list[dict]) -> None:
    """Connect to MCP, build the agent and push its streamed events onto *queue*.

    Every event yielded by ``stream_response`` is put on the queue unchanged.
    Any ``Exception`` is logged and forwarded to the consumer as an
    ``{"type": "error", "data": ...}`` event. A ``None`` sentinel is always
    enqueued last (in ``finally``) so the consumer knows the stream has ended,
    including when this task is cancelled.
    """
    mcp_server = make_mcp_server()
    try:
        async with mcp_server:
            logger.info("[MCP CONNECTED]")
            agent = build_agent(mcp_server=mcp_server, model_name=query.model)
            logger.info(f"[AGENT MODEL] | {query.model}")
            async for event in stream_response(agent, messages):
                await queue.put(event)
    except Exception as e:
        logger.error(f"[AGENT TASK ERROR] | {e}", exc_info=True)
        await queue.put({"type": "error", "data": f"\u26A0 {str(e)}"})
    finally:
        # End-of-stream sentinel; the queue is unbounded (see run_agent), so
        # this put never has to wait for space.
        await queue.put(None)


async def generate_response(queue: asyncio.Queue, task: asyncio.Task[None]) -> AsyncIterator[str]:
    """Drain *queue* and yield each event as a Server-Sent Events frame.

    Each event is JSON-encoded into a ``data: ...`` frame. When the producer's
    ``None`` sentinel arrives, a final ``data: [DONE]`` frame is emitted and
    the stream ends.

    However this generator exits (normal completion, cancellation on client
    disconnect, ``aclose()``, or an error while encoding an event), the
    producer *task* is cancelled if it is still running and then awaited.
    This keeps the agent and its MCP connection from outliving the response.
    Cancellation is never swallowed here; it propagates to the caller.
    """
    try:
        while True:
            event = await queue.get()
            if event is None:
                yield "data: [DONE]\n\n"
                break
            yield f"data: {json.dumps(event)}\n\n"
    finally:
        if not task.done():
            # Consumer stopped before the producer finished: stop the agent
            # instead of letting it run to completion with nobody listening.
            task.cancel()
        # return_exceptions=True: the producer logs its own errors, and its
        # CancelledError must not replace whatever ended this generator.
        await asyncio.gather(task, return_exceptions=True)


@router.post("/api/run")
async def run_agent(query: Query) -> StreamingResponse:
    """
    Run Legal AI Assistant and stream response.

    Args:
        query: User query and model identifier (e.g. 'llama-3.1-8b').

    Returns:
        SSE stream of response tokens.
    """
    logger.info(f"[REQUEST] | query={query.query} | model={query.model}")
    # Unbounded on purpose: the producer never blocks on a slow client, and
    # its end-of-stream sentinel can always be enqueued.
    queue: asyncio.Queue = asyncio.Queue()
    messages = build_messages(query)
    # The producer runs concurrently with the response; generate_response
    # owns the task and cancels/awaits it when the stream ends.
    task = asyncio.create_task(run_agent_task(query, queue, messages))
    return StreamingResponse(
        generate_response(queue, task),
        media_type="text/event-stream",
        headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
    )