75 lines
2.6 KiB
Python
75 lines
2.6 KiB
Python
from typing import AsyncGenerator
|
|
from agents import Agent, Runner, RunItemStreamEvent
|
|
from openai.types.responses import ResponseTextDeltaEvent
|
|
import asyncio
|
|
import time
|
|
from backend.logger import setup_logger
|
|
|
|
# Module-level logger, built by the project's setup_logger helper and keyed by
# this module's name; used below to report streaming failures.
logger = setup_logger(__name__)
|
|
|
|
def parse_run_item_event(event: RunItemStreamEvent, last_tool_name: str | None) -> tuple[dict | None, str | None]:
|
|
"""Parses run item event into SSE payload and updated tool name."""
|
|
match event.name:
|
|
case "reasoning_item_created":
|
|
summary = event.item.raw_item.summary
|
|
if summary and summary[0].text:
|
|
return {
|
|
"type": "reasoning",
|
|
"data": summary[0].text
|
|
}, last_tool_name
|
|
return None, last_tool_name
|
|
|
|
case "tool_called":
|
|
last_tool_name = event.item.raw_item.name
|
|
return {
|
|
"type": "tool_start",
|
|
"tool": last_tool_name,
|
|
"input": event.item.raw_item.arguments,
|
|
}, last_tool_name
|
|
|
|
case "tool_output":
|
|
return {
|
|
"type": "tool_result",
|
|
"tool": last_tool_name,
|
|
"output": event.item.output,
|
|
}, last_tool_name
|
|
|
|
case _:
|
|
return None, last_tool_name
|
|
|
|
async def stream_response(agent: Agent, prompt: list[dict] | str) -> AsyncGenerator[dict, None]:
    """Stream an agent run as a sequence of payload dicts.

    Args:
        agent: Agent handed to ``Runner.run_streamed``.
        prompt: Input for the run, passed through unchanged as ``input``.

    Yields:
        * ``{"type": "text", "data": <delta>}`` for each text delta event.
        * Whatever non-empty payload ``parse_run_item_event`` returns for
          run-item events (reasoning, tool start, tool result).
        * One ``{"type": "usage", ...}`` dict after the stream is exhausted,
          if the run exposes usage data. ``pure_duration`` is the elapsed
          time in seconds from just before the run starts until the event
          stream ends, so it includes the per-delta pauses below.
        * ``{"type": "error", "data": <message>}`` if the run raises an
          ordinary exception; the stream ends after it.

    Raises:
        asyncio.CancelledError: Re-raised when the consuming task is
            cancelled, so cancellation is never mistaken for a normal end of
            stream.
    """
    last_tool_name = None
    # perf_counter is monotonic, unlike time.time(), so the measured duration
    # cannot be skewed by system clock adjustments.
    started_at = time.perf_counter()

    try:
        result = Runner.run_streamed(agent, input=prompt)
        async for event in result.stream_events():
            if event.type == "raw_response_event" and isinstance(event.data, ResponseTextDeltaEvent):
                yield {"type": "text", "data": event.data.delta}
                # Short pause after every text delta, which also hands
                # control back to the event loop between chunks.
                await asyncio.sleep(0.03)

            elif event.type == "run_item_stream_event":
                payload, last_tool_name = parse_run_item_event(event, last_tool_name)
                if payload:
                    yield payload

        pure_elapsed = time.perf_counter() - started_at
        usage = getattr(result.context_wrapper, "usage", None)
        if usage:
            yield {
                "type": "usage",
                "input_tokens": getattr(usage, "input_tokens", 0),
                "output_tokens": getattr(usage, "output_tokens", 0),
                "pure_duration": pure_elapsed,
            }

    except asyncio.CancelledError:
        # Cancellation must reach the caller; swallowing it would make a
        # cancelled task look like a stream that finished normally.
        # GeneratorExit (raised by aclose()) is deliberately not caught: it
        # is not an Exception subclass and simply ends the generator.
        raise
    except Exception as e:
        logger.error(f"[RUN STREAMED ERROR] | {str(e)}", exc_info=True)
        yield {"type": "error", "data": str(e)}
|
|
|