from typing import Literal

from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel

from private_gpt.server.embeddings.embeddings_service import (
    Embedding,
    EmbeddingsService,
)
from private_gpt.server.utils.auth import authenticated
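
# Every route registered on this router is served under the /v1 prefix and is
# protected by the `authenticated` dependency.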
embeddings_router = APIRouter(prefix="/v1", dependencies=[Depends(authenticated)])


class EmbeddingsBody(BaseModel):
    input: str | list[str]


class EmbeddingsResponse(BaseModel):
    object: Literal["list"]
    model: Literal["private-gpt"]
    data: list[Embedding]


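# The request body and response mirror the OpenAI-style embeddings schema
# (object / model / data), so clients built for that schema can be pointed at
# this endpoint.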
@embeddings_router.post("/embeddings", tags=["Embeddings"])
def embeddings_generation(request: Request, body: EmbeddingsBody) -> EmbeddingsResponse:
    """Get a vector representation of a given input.

    That vector representation can be easily consumed
    by machine learning models and algorithms.
    """
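    # Resolve the EmbeddingsService from the injector attached to the request
    # state, then normalise the input so that a single string and a list of
    # strings are handled uniformly before computing the embeddings.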
    service = request.state.injector.get(EmbeddingsService)
    input_texts = body.input if isinstance(body.input, list) else [body.input]
    embeddings = service.texts_embeddings(input_texts)
    return EmbeddingsResponse(object="list", model="private-gpt", data=embeddings)
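

# Example usage (illustrative sketch; assumes a server listening on
# http://localhost:8001 with authentication disabled):
#
#   import httpx
#
#   response = httpx.post(
#       "http://localhost:8001/v1/embeddings",
#       json={"input": ["Hello, world"]},
#   )
#   print(response.json())  # {"object": "list", "model": "private-gpt", "data": [...]}
#
# Each item in "data" is an Embedding as defined in embeddings_service.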