"""FastAPI entry point for the Legal AI Assistant backend."""

import json

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from backend.core.agent import assistant_agent
from backend.core.config import ALL_MODELS, DEFAULT_MODEL
from backend.core.streaming import stream_response

app = FastAPI(title="Legal AI Assistant")

# Allow the local frontend dev server to call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_methods=["*"],
    allow_headers=["*"],
)


class Message(BaseModel):
    role: str
    content: str


class Request(BaseModel):
    messages: list[Message]
    model: str = DEFAULT_MODEL


@app.get("/")
async def health_check():
    return {"status": "ok"}


@app.get("/api/models")
async def get_models():
    return {"models": list(ALL_MODELS)}


@app.post("/api/chat")
async def chat(request: Request):
    agent = assistant_agent(request.model)
    messages = [{"role": m.role, "content": m.content} for m in request.messages]

    async def stream():
        # Emit each token as a Server-Sent Events "data:" frame, then a terminal [DONE] marker.
        async for token in stream_response(agent, messages):
            chunk = json.dumps({"type": "text", "delta": token})
            yield f"data: {chunk}\n\n"
        yield "data: [DONE]\n\n"

    return StreamingResponse(
        stream(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            # Disable proxy buffering (e.g. nginx) so tokens reach the client immediately.
            "X-Accel-Buffering": "no",
        },
    )
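
# A minimal sketch of how to run and exercise this module locally. The module path
# backend.main is an assumption (not confirmed by this file); adjust to wherever the
# `app` object actually lives.
#
#   uvicorn backend.main:app --reload --port 8000
#
# Example request against the streaming endpoint (-N keeps curl from buffering the SSE stream):
#
#   curl -N http://localhost:8000/api/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages": [{"role": "user", "content": "Hello"}]}'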