legal-ai-assistant/backend/agent/response.py

73 lines
2.6 KiB
Python

from typing import AsyncGenerator
from agents import Agent, Runner, RunItemStreamEvent
from openai.types.responses import ResponseTextDeltaEvent
import asyncio
import time
from backend.logger import setup_logger
# Module-level logger, obtained from the project's setup_logger helper by
# passing this module's __name__.
logger = setup_logger(__name__)
def parse_run_item_event(event: RunItemStreamEvent, last_tool_name: str | None) -> tuple[dict | None, str | None]:
    """Convert a run-item stream event into an SSE payload.

    The event is dispatched on ``event.name``:

    * ``"reasoning_item_created"`` -> ``{"type": "reasoning", ...}`` built from
      the text of the first summary entry, or no payload when the summary is
      empty or that text is falsy.
    * ``"tool_called"`` -> ``{"type": "tool_start", ...}``; the tool's name
      becomes the new remembered tool name.
    * ``"tool_output"`` -> ``{"type": "tool_result", ...}`` labelled with the
      remembered tool name.
    * anything else -> no payload.

    Args:
        event: The run-item event to translate.
        last_tool_name: Tool name remembered from an earlier ``tool_called``
            event, or None if there has not been one.

    Returns:
        A ``(payload, tool_name)`` pair. ``payload`` is None when the event
        produces nothing to send; ``tool_name`` is the value to pass as
        ``last_tool_name`` on the next call.
    """
    kind = event.name

    if kind == "reasoning_item_created":
        parts = event.item.raw_item.summary
        # Only the first summary entry is surfaced, and only if it has text.
        if not parts or not parts[0].text:
            return None, last_tool_name
        reasoning_payload = {"type": "reasoning", "data": parts[0].text}
        return reasoning_payload, last_tool_name

    if kind == "tool_called":
        tool_name = event.item.raw_item.name
        start_payload = {
            "type": "tool_start",
            "tool": tool_name,
            "input": event.item.raw_item.arguments,
        }
        return start_payload, tool_name

    if kind == "tool_output":
        # The output is attributed to the tool remembered from the preceding
        # "tool_called" event.
        result_payload = {
            "type": "tool_result",
            "tool": last_tool_name,
            "output": event.item.output,
        }
        return result_payload, last_tool_name

    return None, last_tool_name
async def stream_response(
    agent: Agent,
    prompt: list[dict] | str,
    *,
    token_delay: float = 0.03,
) -> AsyncGenerator[dict, None]:
    """Stream an agent run as a sequence of payload dicts.

    Payloads are yielded in the order the underlying events arrive:

    * ``{"type": "text", "data": <delta>}`` for every text delta;
    * whatever ``parse_run_item_event`` produces for run-item events
      (reasoning, tool start, tool result);
    * one final ``{"type": "usage", ...}`` with token counts and the elapsed
      time, when the run finishes and usage information is available;
    * ``{"type": "error", "data": <message>}`` if the run raises.

    Args:
        agent: Agent handed to ``Runner.run_streamed``.
        prompt: Input for the run, either a string or a list of message dicts.
        token_delay: Seconds to pause after each text delta. The default keeps
            the previous fixed pacing of 0.03 s; a value <= 0 disables pacing.

    Yields:
        Payload dicts as described above.
    """
    last_tool_name = None
    result = None
    # perf_counter is monotonic, so the measured duration cannot be skewed by
    # wall-clock adjustments while the stream is running.
    started_at = time.perf_counter()
    try:
        result = Runner.run_streamed(agent, input=prompt)
        async for event in result.stream_events():
            if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
                yield {"type": "text", "data": event.data.delta}
                if token_delay > 0:
                    await asyncio.sleep(token_delay)
            elif event.type == "run_item_stream_event":
                payload, last_tool_name = parse_run_item_event(event, last_tool_name)
                if payload:
                    yield payload

        # NOTE(review): this duration includes the pacing sleeps above --
        # confirm that is what consumers of "pure_duration" expect.
        pure_elapsed = time.perf_counter() - started_at
        usage = getattr(result.context_wrapper, "usage", None)
        if usage:
            yield {
                "type": "usage",
                "input_tokens": getattr(usage, "input_tokens", 0),
                "output_tokens": getattr(usage, "output_tokens", 0),
                "pure_duration": pure_elapsed,
            }
    except (asyncio.CancelledError, GeneratorExit):
        # The consumer cancelled or closed the stream. Ask the underlying run
        # to stop so it does not keep working for a client that is gone.
        # The lookup is defensive: `result` is None if run_streamed itself was
        # interrupted, and the result object may not expose cancel().
        cancel_run = getattr(result, "cancel", None)
        if callable(cancel_run):
            cancel_run()
        logger.debug("[RUN STREAMED CANCELLED] | stream closed before completion")
        # NOTE(review): the generator still ends quietly here, as before, so
        # existing callers see a normal end of stream. Consider re-raising
        # CancelledError if no caller relies on that.
    except Exception as e:
        # Boundary handler: log with traceback and surface the failure to the
        # consumer as an error payload instead of breaking the stream.
        logger.exception("[RUN STREAMED ERROR] | %s", e)
        yield {"type": "error", "data": str(e)}