ai-lawyer-agent/core/init_agent.py
2026-03-23 02:55:42 +01:00

61 lines
1.9 KiB
Python

from agents import Agent, AgentHooks
from agents import OpenAIChatCompletionsModel, AsyncOpenAI, ModelSettings
from agents import set_tracing_disabled
from core.config import (
DEFAULT_MODEL, AGENT_TEMPERATURE, LLM_TIMEOUT,
OLLAMA_BASE_URL, OLLAMA_API_KEY,
OPENAI_BASE_URL, OPENAI_API_KEY,
OPENAI_MODELS, OLLAMA_MODELS
)
from core.system_prompt import get_system_prompt
from api.tools import ALL_TOOLS
# Disable the Agents SDK's tracing; this runs once, as a side effect of importing this module.
# NOTE(review): presumably so runs against non-OpenAI backends (e.g. Ollama) don't attempt to export traces — confirm.
set_tracing_disabled(True)
class MyAgentHooks(AgentHooks):
    """Lifecycle hooks that print a console line when the Agents SDK
    reports this agent starting and ending.

    Attached to the agent built by ``assistant_agent`` via ``hooks=``.
    """

    async def on_start(self, context, agent) -> None:
        """Print a start banner with ``agent.name``; ``context`` is unused."""
        label = agent.name
        print(f"\n🏃‍♂️‍➡️ [AgentHooks] {label} started.")

    async def on_end(self, context, agent, output) -> None:
        """Print an end banner with ``agent.name``; ``context`` and ``output`` are unused."""
        label = agent.name
        print(f"🏁 [AgentHooks] {label} ended.")
def _make_client(model_name: str) -> AsyncOpenAI:
    """Build an AsyncOpenAI client for the provider that serves ``model_name``.

    OPENAI_MODELS is checked before OLLAMA_MODELS, so a name listed in both
    resolves to the OpenAI endpoint. The client uses ``LLM_TIMEOUT`` as its
    timeout and has automatic retries disabled (``max_retries=0``).

    Args:
        model_name: Model identifier to look up in the configured model lists.

    Returns:
        An ``AsyncOpenAI`` client bound to the matching base URL and API key.

    Raises:
        ValueError: If ``model_name`` is in neither OPENAI_MODELS nor OLLAMA_MODELS.
    """
    if model_name in OPENAI_MODELS:
        endpoint, key = OPENAI_BASE_URL, OPENAI_API_KEY
    elif model_name in OLLAMA_MODELS:
        endpoint, key = OLLAMA_BASE_URL, OLLAMA_API_KEY
    else:
        raise ValueError(f"Model {model_name} not supported")

    # Both providers expose an OpenAI-compatible API, so one client type serves both.
    return AsyncOpenAI(
        base_url=endpoint,
        api_key=key,
        timeout=LLM_TIMEOUT,
        max_retries=0,
    )
def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
    """Build the "AI Lawyer Assistant" agent backed by ``model_name``.

    The chat-completions client comes from ``_make_client`` (OpenAI or Ollama
    endpoint), and the system prompt is generated for the same model name.

    Args:
        model_name: Model identifier; must be listed in OPENAI_MODELS or
            OLLAMA_MODELS. Defaults to DEFAULT_MODEL.

    Returns:
        An ``Agent`` with ALL_TOOLS attached and MyAgentHooks printing its
        start/end events.

    Raises:
        ValueError: Propagated from ``_make_client`` for an unsupported model.
    """
    chat_model = OpenAIChatCompletionsModel(
        model=model_name,
        openai_client=_make_client(model_name),
    )
    prompt = get_system_prompt(model_name)
    settings = ModelSettings(
        temperature=AGENT_TEMPERATURE,
        tool_choice="auto",
        parallel_tool_calls=False,  # ask the model for at most one tool call per turn
    )
    return Agent(
        name="AI Lawyer Assistant",
        instructions=prompt,
        model=chat_model,
        model_settings=settings,
        tools=ALL_TOOLS,
        # Tool results are sent back to the LLM rather than returned as final output.
        tool_use_behavior="run_llm_again",
        # Per the Agents SDK, resets tool_choice after a tool call to avoid tool-call loops.
        reset_tool_choice=True,
        hooks=MyAgentHooks(),
    )