from agents import Agent, AgentHooks
|
|
from agents import OpenAIChatCompletionsModel, AsyncOpenAI, ModelSettings
|
|
from agents import set_tracing_disabled
|
|
|
|
from core.config import (
|
|
DEFAULT_MODEL, AGENT_TEMPERATURE, LLM_TIMEOUT,
|
|
OLLAMA_BASE_URL, OLLAMA_API_KEY,
|
|
OPENAI_BASE_URL, OPENAI_API_KEY,
|
|
OPENAI_MODELS, OLLAMA_MODELS
|
|
)
|
|
from core.system_prompt import get_system_prompt
|
|
from api.tools import ALL_TOOLS
|
|
|
|
# Turn off the Agents SDK's built-in tracing for this process.
set_tracing_disabled(True)
class MyAgentHooks(AgentHooks):
    """Lifecycle hooks that log agent start/end events to stdout."""

    async def on_start(self, context, agent):
        """Log that *agent* has begun a run."""
        print(f"\n🏃♂️➡️ [AgentHooks] {agent.name} started.")

    async def on_end(self, context, agent, output):
        """Log that *agent* has finished a run."""
        print(f"🏁 [AgentHooks] {agent.name} ended.")
def _make_client(model_name: str) -> AsyncOpenAI:
    """Return the correct AsyncOpenAI client based on model name.

    OpenAI-hosted models take precedence over Ollama models when a name
    appears in both lists (checked in that order, as before).

    Raises:
        ValueError: if *model_name* is in neither provider's model list.
    """
    if model_name in OPENAI_MODELS:
        base_url, api_key = OPENAI_BASE_URL, OPENAI_API_KEY
    elif model_name in OLLAMA_MODELS:
        base_url, api_key = OLLAMA_BASE_URL, OLLAMA_API_KEY
    else:
        raise ValueError(f"Model {model_name} not supported")

    return AsyncOpenAI(
        base_url=base_url,
        api_key=api_key,
        timeout=LLM_TIMEOUT,
        max_retries=0,
    )
def assistant_agent(model_name: str = DEFAULT_MODEL) -> Agent:
    """Initialize the assistant agent for legal work.

    Builds the provider-appropriate client for *model_name*, wraps it in a
    chat-completions model, and returns a fully configured Agent.
    """
    llm_client = _make_client(model_name)
    chat_model = OpenAIChatCompletionsModel(model=model_name, openai_client=llm_client)

    settings = ModelSettings(
        temperature=AGENT_TEMPERATURE,
        tool_choice="auto",
        parallel_tool_calls=False,
    )

    return Agent(
        name="AI Lawyer Assistant",
        instructions=get_system_prompt(model_name),
        model=chat_model,
        model_settings=settings,
        tools=ALL_TOOLS,
        tool_use_behavior="run_llm_again",
        reset_tool_choice=True,
        hooks=MyAgentHooks(),
    )