from agents import Agent, AgentHooks
from agents import OpenAIChatCompletionsModel, AsyncOpenAI, ModelSettings
from agents import set_tracing_disabled

from core.config import (
    DEFAULT_MODEL,
    AGENT_TEMPERATURE,
    LLM_TIMEOUT,
    OLLAMA_BASE_URL,
    OLLAMA_API_KEY,
    OPENAI_BASE_URL,
    OPENAI_API_KEY,
    OPENAI_MODELS,
    OLLAMA_MODELS,
)
from core.system_prompt import get_system_prompt
from api.tools import ALL_TOOLS

# Import-time side effect: turns off the Agents SDK tracing for this process.
set_tracing_disabled(True)

# Emoji prefixes for the hook log lines, written as escapes so the source stays
# ASCII-only and cannot be mangled by an encoding round-trip (the previous
# literals had been mis-decoded into mojibake such as "šŸƒā€...").
# "man running facing right": U+1F3C3 ZWJ U+2642 VS16 ZWJ U+27A1 VS16
_START_ICON = "\U0001F3C3\u200d\u2642\ufe0f\u200d\u27a1\ufe0f"
# "chequered flag": U+1F3C1
_END_ICON = "\U0001F3C1"


class MyAgentHooks(AgentHooks):
    """Agent lifecycle hooks that print a line to stdout when a run starts and ends."""

    async def on_start(self, context, agent):
        """Print a start marker with the agent's name."""
        print(f"\n{_START_ICON} [AgentHooks] {agent.name} started.")

    async def on_end(self, context, agent, output):
        """Print an end marker with the agent's name (``output`` is not used)."""
        print(f"{_END_ICON} [AgentHooks] {agent.name} ended.")


def _make_client(model_name: str) -> AsyncOpenAI:
    """Return an AsyncOpenAI client pointed at the provider that serves ``model_name``.

    Providers are checked in order: a model listed in both ``OPENAI_MODELS`` and
    ``OLLAMA_MODELS`` resolves to the OpenAI endpoint.

    Args:
        model_name: Model identifier to look up in the configured model collections.

    Returns:
        A new ``AsyncOpenAI`` client using ``LLM_TIMEOUT`` and ``max_retries=0``.

    Raises:
        ValueError: If ``model_name`` is not in any configured model collection.
    """
    providers = (
        (OPENAI_MODELS, OPENAI_BASE_URL, OPENAI_API_KEY),
        (OLLAMA_MODELS, OLLAMA_BASE_URL, OLLAMA_API_KEY),
    )
    for models, base_url, api_key in providers:
        if model_name in models:
            return AsyncOpenAI(
                base_url=base_url,
                api_key=api_key,
                timeout=LLM_TIMEOUT,
                # SDK-level retries disabled; presumably so failures surface
                # immediately to the caller -- confirm this is intended.
                max_retries=0,
            )
    raise ValueError(f"Model {model_name} not supported")


def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
    """Initialize the assistant agent for legal work.

    Builds a chat-completions model on the client chosen by ``_make_client`` and
    wraps it in an ``Agent`` named "AI Lawyer Assistant", with the model-specific
    system prompt, every tool in ``ALL_TOOLS``, and ``MyAgentHooks`` attached.

    Args:
        model_name: Model to run; defaults to ``DEFAULT_MODEL``.

    Returns:
        The configured ``Agent``.

    Raises:
        ValueError: If ``model_name`` is not a supported OpenAI or Ollama model.
    """
    client = _make_client(model_name)
    model = OpenAIChatCompletionsModel(model=model_name, openai_client=client)

    settings = ModelSettings(
        temperature=AGENT_TEMPERATURE,
        tool_choice="auto",
        # Ask the model for at most one tool call per turn.
        parallel_tool_calls=False,
    )

    agent = Agent(
        name="AI Lawyer Assistant",
        instructions=get_system_prompt(model_name),
        model=model,
        model_settings=settings,
        tools=ALL_TOOLS,
        # After a tool runs, its result goes back to the LLM for another turn.
        tool_use_behavior="run_llm_again",
        reset_tool_choice=True,
        hooks=MyAgentHooks(),
    )
    return agent