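"""Service and schema for retrieving relevant context chunks from the vector store."""
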
from typing import TYPE_CHECKING, Literal

from injector import inject, singleton
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.schema import NodeWithScore
from llama_index.core.storage import StorageContext
from pydantic import BaseModel, Field

from private_gpt.components.embedding.embedding_component import EmbeddingComponent
from private_gpt.components.llm.llm_component import LLMComponent
from private_gpt.components.node_store.node_store_component import NodeStoreComponent
from private_gpt.components.vector_store.vector_store_component import (
    VectorStoreComponent,
)
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.server.ingest.model import IngestedDoc

if TYPE_CHECKING:
    from llama_index.core.schema import RelatedNodeInfo


class Chunk(BaseModel):
    """A scored chunk of retrieved context.

    Carries the source document and, optionally, the texts of the
    neighbouring (previous/next) chunks.
    """

    object: Literal["context.chunk"]
    score: float = Field(examples=[0.023])
    document: IngestedDoc
    text: str = Field(examples=["Outbound sales increased 20%, driven by new leads."])
    previous_texts: list[str] | None = Field(
        default=None,
        examples=[["SALES REPORT 2023", "Inbound didn't show major changes."]],
    )
    next_texts: list[str] | None = Field(
        default=None,
        examples=[
            [
                "New leads came from Google Ads campaign.",
                "The campaign was run by the Marketing Department",
            ]
        ],
    )

    @classmethod
    def from_node(cls: type["Chunk"], node: NodeWithScore) -> "Chunk":
        """Build a Chunk from a llama-index retrieval result."""
        # Fall back to "-" when the node is not linked to a source document.
        doc_id = node.node.ref_doc_id if node.node.ref_doc_id is not None else "-"
        return cls(
            object="context.chunk",
            score=node.score or 0.0,
            document=IngestedDoc(
                object="ingest.document",
                doc_id=doc_id,
                doc_metadata=node.metadata,
            ),
            text=node.get_content(),
        )


@singleton
class ChunksService:
    """Retrieves the chunks most relevant to a query from the vector store."""

    @inject
    def __init__(
        self,
        llm_component: LLMComponent,
        vector_store_component: VectorStoreComponent,
        embedding_component: EmbeddingComponent,
        node_store_component: NodeStoreComponent,
    ) -> None:
        self.vector_store_component = vector_store_component
        self.llm_component = llm_component
        self.embedding_component = embedding_component
        self.storage_context = StorageContext.from_defaults(
            vector_store=vector_store_component.vector_store,
            docstore=node_store_component.doc_store,
            index_store=node_store_component.index_store,
        )

    def _get_sibling_nodes_text(
        self, node_with_score: NodeWithScore, related_number: int, forward: bool = True
    ) -> list[str]:
        """Collect text from up to `related_number` sibling nodes.

        Walks the docstore forward (next_node) or backward (prev_node)
        starting from the given node, stopping early at document edges.
        """
        explored_nodes_texts = []
        current_node = node_with_score.node
        for _ in range(related_number):
            explored_node_info: RelatedNodeInfo | None = (
                current_node.next_node if forward else current_node.prev_node
            )
            if explored_node_info is None:
                break

            explored_node = self.storage_context.docstore.get_node(
                explored_node_info.node_id
            )

            explored_nodes_texts.append(explored_node.get_content())
            current_node = explored_node

        return explored_nodes_texts

    def retrieve_relevant(
        self,
        text: str,
        context_filter: ContextFilter | None = None,
        limit: int = 10,
        prev_next_chunks: int = 0,
    ) -> list[Chunk]:
        """Return the `limit` most relevant chunks for `text`, best score first.

        `prev_next_chunks` controls how many sibling texts are attached to
        each chunk as `previous_texts` / `next_texts`.
        """
        index = VectorStoreIndex.from_vector_store(
            self.vector_store_component.vector_store,
            storage_context=self.storage_context,
            llm=self.llm_component.llm,
            embed_model=self.embedding_component.embedding_model,
            show_progress=True,
        )
        vector_index_retriever = self.vector_store_component.get_retriever(
            index=index, context_filter=context_filter, similarity_top_k=limit
        )
        nodes = vector_index_retriever.retrieve(text)
        nodes.sort(key=lambda n: n.score or 0.0, reverse=True)

        retrieved_nodes = []
        for node in nodes:
            chunk = Chunk.from_node(node)
            chunk.previous_texts = self._get_sibling_nodes_text(
                node, prev_next_chunks, False
            )
            chunk.next_texts = self._get_sibling_nodes_text(node, prev_next_chunks)
            retrieved_nodes.append(chunk)

        return retrieved_nodes
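

# Minimal usage sketch (not part of the original module). It assumes the
# application exposes a wired injector, e.g. as `private_gpt.di.global_injector`;
# adjust the import to however your app builds its Injector.
#
#     from private_gpt.di import global_injector
#
#     chunks_service = global_injector.get(ChunksService)
#     chunks = chunks_service.retrieve_relevant(
#         "How did outbound sales perform?",
#         limit=3,
#         prev_next_chunks=1,
#     )
#     for chunk in chunks:
#         print(f"{chunk.score:.3f} {chunk.text[:80]}")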